33
44import concurrent .futures
55import logging
6- from dataclasses import dataclass
76from datetime import datetime , timedelta
87from threading import Event , Thread
98from types import GeneratorType
@@ -90,7 +89,7 @@ def __init__(self, *,
9089 log_formatter : Union [logging .Formatter , None ] = None ):
9190 self ._registry = _Registry ()
9291 self ._host_address = host_address if host_address else shared .get_default_host_address ()
93- self ._logger = shared .get_logger (log_handler , log_formatter )
92+ self ._logger = shared .get_logger ("worker" , log_handler , log_formatter )
9493 self ._shutdown = Event ()
9594 self ._response_stream = None
9695 self ._is_running = False
@@ -149,7 +148,7 @@ def run_loop():
149148
150149 except grpc .RpcError as rpc_error :
151150 if rpc_error .code () == grpc .StatusCode .CANCELLED : # type: ignore
152- self ._logger .warning (f'Disconnected from { self ._host_address } ' )
151+ self ._logger .info (f'Disconnected from { self ._host_address } ' )
153152 elif rpc_error .code () == grpc .StatusCode .UNAVAILABLE : # type: ignore
154153 self ._logger .warning (
155154 f'The sidecar at address { self ._host_address } is unavailable - will continue retrying' )
@@ -163,7 +162,7 @@ def run_loop():
163162 self ._logger .info ("No longer listening for work items" )
164163 return
165164
166- self ._logger .info (f"starting gRPC worker that connects to { self ._host_address } " )
165+ self ._logger .info (f"Starting gRPC worker that connects to { self ._host_address } " )
167166 self ._runLoop = Thread (target = run_loop )
168167 self ._runLoop .start ()
169168 self ._is_running = True
@@ -220,12 +219,6 @@ def _execute_activity(self, req: pb.ActivityRequest, stub: stubs.TaskHubSidecarS
220219 f"Failed to deliver activity response for '{ req .name } #{ req .taskId } ' of orchestration ID '{ instance_id } ' to sidecar: { ex } " )
221220
222221
223- @dataclass
224- class _ExternalEvent :
225- name : str
226- data : Any
227-
228-
229222class _RuntimeOrchestrationContext (task .OrchestrationContext ):
230223 _generator : Union [Generator [task .Task , Any , Any ], None ]
231224 _previous_task : Union [task .Task , None ]
@@ -241,8 +234,10 @@ def __init__(self, instance_id: str):
241234 self ._current_utc_datetime = datetime (1000 , 1 , 1 )
242235 self ._instance_id = instance_id
243236 self ._completion_status : Union [pb .OrchestrationStatus , None ] = None
244- self ._received_events : Dict [str , List [_ExternalEvent ]] = {}
237+ self ._received_events : Dict [str , List [Any ]] = {}
245238 self ._pending_events : Dict [str , List [task .CompletableTask ]] = {}
239+ self ._new_input : Union [Any , None ] = None
240+ self ._save_events = False
246241
247242 def run (self , generator : Generator [task .Task , Any , Any ]):
248243 self ._generator = generator
@@ -282,6 +277,9 @@ def set_complete(self, result: Any, status: pb.OrchestrationStatus, is_result_en
282277 return
283278
284279 self ._is_complete = True
280+ self ._completion_status = status
281+ self ._pending_actions .clear () # Cancel any pending actions
282+
285283 self ._result = result
286284 result_json : Union [str , None ] = None
287285 if result is not None :
@@ -296,13 +294,44 @@ def set_failed(self, ex: Exception):
296294
297295 self ._is_complete = True
298296 self ._pending_actions .clear () # Cancel any pending actions
297+ self ._completion_status = pb .ORCHESTRATION_STATUS_FAILED
298+
299299 action = ph .new_complete_orchestration_action (
300300 self .next_sequence_number (), pb .ORCHESTRATION_STATUS_FAILED , None , ph .new_failure_details (ex )
301301 )
302302 self ._pending_actions [action .id ] = action
303303
304+ def set_continued_as_new (self , new_input : Any , save_events : bool ):
305+ if self ._is_complete :
306+ return
307+
308+ self ._is_complete = True
309+ self ._pending_actions .clear () # Cancel any pending actions
310+ self ._completion_status = pb .ORCHESTRATION_STATUS_CONTINUED_AS_NEW
311+ self ._new_input = new_input
312+ self ._save_events = save_events
313+
304314 def get_actions (self ) -> List [pb .OrchestratorAction ]:
305- return list (self ._pending_actions .values ())
315+ if self ._completion_status == pb .ORCHESTRATION_STATUS_CONTINUED_AS_NEW :
316+ # When continuing-as-new, we only return a single completion action.
317+ carryover_events : Union [List [pb .HistoryEvent ], None ] = None
318+ if self ._save_events :
319+ carryover_events = []
320+ # We need to save the current set of pending events so that they can be
321+ # replayed when the new instance starts.
322+ for event_name , values in self ._received_events .items ():
323+ for event_value in values :
324+ encoded_value = shared .to_json (event_value ) if event_value else None
325+ carryover_events .append (ph .new_event_raised_event (event_name , encoded_value ))
326+ action = ph .new_complete_orchestration_action (
327+ self .next_sequence_number (),
328+ pb .ORCHESTRATION_STATUS_CONTINUED_AS_NEW ,
329+ result = shared .to_json (self ._new_input ) if self ._new_input is not None else None ,
330+ failure_details = None ,
331+ carryover_events = carryover_events )
332+ return [action ]
333+ else :
334+ return list (self ._pending_actions .values ())
306335
307336 def next_sequence_number (self ) -> int :
308337 self ._sequence_number += 1
@@ -370,13 +399,13 @@ def wait_for_external_event(self, name: str) -> task.Task:
370399 # arrives. If there are multiple events with the same name, we return
371400 # them in the order they were received.
372401 external_event_task = task .CompletableTask ()
373- event_name = name .upper ()
402+ event_name = name .casefold ()
374403 event_list = self ._received_events .get (event_name , None )
375404 if event_list :
376- event = event_list .pop (0 )
405+ event_data = event_list .pop (0 )
377406 if not event_list :
378407 del self ._received_events [event_name ]
379- external_event_task .complete (event . data )
408+ external_event_task .complete (event_data )
380409 else :
381410 task_list = self ._pending_events .get (event_name , None )
382411 if not task_list :
@@ -385,6 +414,12 @@ def wait_for_external_event(self, name: str) -> task.Task:
385414 task_list .append (external_event_task )
386415 return external_event_task
387416
417+ def continue_as_new (self , new_input , * , save_events : bool = False ) -> None :
418+ if self ._is_complete :
419+ return
420+
421+ self .set_continued_as_new (new_input , save_events )
422+
388423
389424class _OrchestrationExecutor :
390425 _generator : Union [task .Orchestrator , None ]
@@ -415,13 +450,16 @@ def execute(self, instance_id: str, old_events: Sequence[pb.HistoryEvent], new_e
415450 ctx ._is_replaying = False
416451 for new_event in new_events :
417452 self .process_event (ctx , new_event )
418- if ctx ._is_complete :
419- break
453+
420454 except Exception as ex :
421455 # Unhandled exceptions fail the orchestration
422456 ctx .set_failed (ex )
423457
424- if ctx ._completion_status :
458+ if not ctx ._is_complete :
459+ task_count = len (ctx ._pending_tasks )
460+ event_count = len (ctx ._pending_events )
461+ self ._logger .info (f"{ instance_id } : Waiting for { task_count } task(s) and { event_count } event(s)." )
462+ elif ctx ._completion_status and ctx ._completion_status is not pb .ORCHESTRATION_STATUS_CONTINUED_AS_NEW :
425463 completion_status_str = pbh .get_orchestration_status_str (ctx ._completion_status )
426464 self ._logger .info (f"{ instance_id } : Orchestration completed with status: { completion_status_str } " )
427465
@@ -570,9 +608,9 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven
570608 ctx .resume ()
571609 elif event .HasField ("eventRaised" ):
572610 # event names are case-insensitive
573- event_name = event .eventRaised .name .upper ()
611+ event_name = event .eventRaised .name .casefold ()
574612 if not ctx .is_replaying :
575- self ._logger .info (f"Event raised: { event_name } " )
613+ self ._logger .info (f"{ ctx . instance_id } Event raised: { event_name } " )
576614 task_list = ctx ._pending_events .get (event_name , None )
577615 decoded_result : Union [Any , None ] = None
578616 if task_list :
@@ -591,7 +629,7 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven
591629 ctx ._received_events [event_name ] = event_list
592630 if not ph .is_empty (event .eventRaised .input ):
593631 decoded_result = shared .from_json (event .eventRaised .input .value )
594- event_list .append (_ExternalEvent ( event . eventRaised . name , decoded_result ) )
632+ event_list .append (decoded_result )
595633 if not ctx .is_replaying :
596634 self ._logger .info (f"{ ctx .instance_id } : Event '{ event_name } ' has been buffered as there are no tasks waiting for it." )
597635 elif event .HasField ("executionSuspended" ):
0 commit comments