Skip to content

Commit dfa6fe0

Browse files
authored
Add continue_as_new orchestration API (#9)
Compatible with durabletask-go v0.2.1 or newer.
1 parent ff8df5e commit dfa6fe0

8 files changed

Lines changed: 173 additions & 33 deletions

File tree

durabletask/client.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def __init__(self, *,
9797
log_formatter: Union[logging.Formatter, None] = None):
9898
channel = shared.get_grpc_channel(host_address)
9999
self._stub = stubs.TaskHubSidecarServiceStub(channel)
100-
self._logger = shared.get_logger(log_handler, log_formatter)
100+
self._logger = shared.get_logger("client", log_handler, log_formatter)
101101

102102
def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
103103
input: Union[TInput, None] = None,
@@ -109,7 +109,7 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu
109109
req = pb.CreateInstanceRequest(
110110
name=name,
111111
instanceId=instance_id if instance_id else uuid.uuid4().hex,
112-
input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input else None,
112+
input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None,
113113
scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None,
114114
version=wrappers_pb2.StringValue(value=""))
115115

@@ -144,7 +144,19 @@ def wait_for_orchestration_completion(self, instance_id: str, *,
144144
try:
145145
self._logger.info(f"Waiting {timeout}s for instance '{instance_id}' to complete.")
146146
res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion(req, timeout=timeout)
147-
return new_orchestration_state(req.instanceId, res)
147+
state = new_orchestration_state(req.instanceId, res)
148+
if not state:
149+
return None
150+
151+
if state.runtime_status == OrchestrationStatus.FAILED and state.failure_details is not None:
152+
details = state.failure_details
153+
self._logger.info(f"Instance '{instance_id}' failed: [{details.error_type}] {details.message}")
154+
elif state.runtime_status == OrchestrationStatus.TERMINATED:
155+
self._logger.info(f"Instance '{instance_id}' was terminated.")
156+
elif state.runtime_status == OrchestrationStatus.COMPLETED:
157+
self._logger.info(f"Instance '{instance_id}' completed.")
158+
159+
return state
148160
except grpc.RpcError as rpc_error:
149161
if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore
150162
# Replace gRPC error with the built-in TimeoutError

durabletask/internal/helpers.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import traceback
55
from datetime import datetime
6-
from typing import Union
6+
from typing import List, Union
77

88
from google.protobuf import timestamp_pb2, wrappers_pb2
99

@@ -161,14 +161,13 @@ def new_complete_orchestration_action(
161161
id: int,
162162
status: pb.OrchestrationStatus,
163163
result: Union[str, None] = None,
164-
failure_details: Union[pb.TaskFailureDetails, None] = None) -> pb.OrchestratorAction:
165-
164+
failure_details: Union[pb.TaskFailureDetails, None] = None,
165+
carryover_events: Union[List[pb.HistoryEvent], None] = None) -> pb.OrchestratorAction:
166166
completeOrchestrationAction = pb.CompleteOrchestrationAction(
167167
orchestrationStatus=status,
168168
result=get_string_value(result),
169-
failureDetails=failure_details)
170-
171-
# TODO: CarryoverEvents
169+
failureDetails=failure_details,
170+
carryoverEvents=carryover_events)
172171

173172
return pb.OrchestratorAction(id=id, completeOrchestration=completeOrchestrationAction)
174173

durabletask/internal/shared.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@ def get_grpc_channel(host_address: Union[str, None]) -> grpc.Channel:
2626

2727

2828
def get_logger(
29+
name_suffix: str,
2930
log_handler: Union[logging.Handler, None] = None,
3031
log_formatter: Union[logging.Formatter, None] = None) -> logging.Logger:
31-
logger = logging.Logger("durabletask")
32+
logger = logging.Logger(f"durabletask-{name_suffix}")
3233

3334
# Add a default log handler if none is provided
3435
if log_handler is None:

durabletask/task.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,19 @@ def wait_for_external_event(self, name: str) -> Task:
147147
"""
148148
pass
149149

150+
@abstractmethod
151+
def continue_as_new(self, new_input: Any, *, save_events: bool = False) -> None:
152+
"""Continue the orchestration execution as a new instance.
153+
154+
Parameters
155+
----------
156+
new_input : Any
157+
The new input to use for the new orchestration instance.
158+
save_events : bool
159+
A flag indicating whether to add any unprocessed external events in the new orchestration history.
160+
"""
161+
pass
162+
150163

151164
class FailureDetails:
152165
def __init__(self, message: str, error_type: str, stack_trace: Union[str, None]):

durabletask/worker.py

Lines changed: 59 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import concurrent.futures
55
import logging
6-
from dataclasses import dataclass
76
from datetime import datetime, timedelta
87
from threading import Event, Thread
98
from types import GeneratorType
@@ -90,7 +89,7 @@ def __init__(self, *,
9089
log_formatter: Union[logging.Formatter, None] = None):
9190
self._registry = _Registry()
9291
self._host_address = host_address if host_address else shared.get_default_host_address()
93-
self._logger = shared.get_logger(log_handler, log_formatter)
92+
self._logger = shared.get_logger("worker", log_handler, log_formatter)
9493
self._shutdown = Event()
9594
self._response_stream = None
9695
self._is_running = False
@@ -149,7 +148,7 @@ def run_loop():
149148

150149
except grpc.RpcError as rpc_error:
151150
if rpc_error.code() == grpc.StatusCode.CANCELLED: # type: ignore
152-
self._logger.warning(f'Disconnected from {self._host_address}')
151+
self._logger.info(f'Disconnected from {self._host_address}')
153152
elif rpc_error.code() == grpc.StatusCode.UNAVAILABLE: # type: ignore
154153
self._logger.warning(
155154
f'The sidecar at address {self._host_address} is unavailable - will continue retrying')
@@ -163,7 +162,7 @@ def run_loop():
163162
self._logger.info("No longer listening for work items")
164163
return
165164

166-
self._logger.info(f"starting gRPC worker that connects to {self._host_address}")
165+
self._logger.info(f"Starting gRPC worker that connects to {self._host_address}")
167166
self._runLoop = Thread(target=run_loop)
168167
self._runLoop.start()
169168
self._is_running = True
@@ -220,12 +219,6 @@ def _execute_activity(self, req: pb.ActivityRequest, stub: stubs.TaskHubSidecarS
220219
f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}")
221220

222221

223-
@dataclass
224-
class _ExternalEvent:
225-
name: str
226-
data: Any
227-
228-
229222
class _RuntimeOrchestrationContext(task.OrchestrationContext):
230223
_generator: Union[Generator[task.Task, Any, Any], None]
231224
_previous_task: Union[task.Task, None]
@@ -241,8 +234,10 @@ def __init__(self, instance_id: str):
241234
self._current_utc_datetime = datetime(1000, 1, 1)
242235
self._instance_id = instance_id
243236
self._completion_status: Union[pb.OrchestrationStatus, None] = None
244-
self._received_events: Dict[str, List[_ExternalEvent]] = {}
237+
self._received_events: Dict[str, List[Any]] = {}
245238
self._pending_events: Dict[str, List[task.CompletableTask]] = {}
239+
self._new_input: Union[Any, None] = None
240+
self._save_events = False
246241

247242
def run(self, generator: Generator[task.Task, Any, Any]):
248243
self._generator = generator
@@ -282,6 +277,9 @@ def set_complete(self, result: Any, status: pb.OrchestrationStatus, is_result_en
282277
return
283278

284279
self._is_complete = True
280+
self._completion_status = status
281+
self._pending_actions.clear() # Cancel any pending actions
282+
285283
self._result = result
286284
result_json: Union[str, None] = None
287285
if result is not None:
@@ -296,13 +294,44 @@ def set_failed(self, ex: Exception):
296294

297295
self._is_complete = True
298296
self._pending_actions.clear() # Cancel any pending actions
297+
self._completion_status = pb.ORCHESTRATION_STATUS_FAILED
298+
299299
action = ph.new_complete_orchestration_action(
300300
self.next_sequence_number(), pb.ORCHESTRATION_STATUS_FAILED, None, ph.new_failure_details(ex)
301301
)
302302
self._pending_actions[action.id] = action
303303

304+
def set_continued_as_new(self, new_input: Any, save_events: bool):
305+
if self._is_complete:
306+
return
307+
308+
self._is_complete = True
309+
self._pending_actions.clear() # Cancel any pending actions
310+
self._completion_status = pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW
311+
self._new_input = new_input
312+
self._save_events = save_events
313+
304314
def get_actions(self) -> List[pb.OrchestratorAction]:
305-
return list(self._pending_actions.values())
315+
if self._completion_status == pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW:
316+
# When continuing-as-new, we only return a single completion action.
317+
carryover_events: Union[List[pb.HistoryEvent], None] = None
318+
if self._save_events:
319+
carryover_events = []
320+
# We need to save the current set of pending events so that they can be
321+
# replayed when the new instance starts.
322+
for event_name, values in self._received_events.items():
323+
for event_value in values:
324+
encoded_value = shared.to_json(event_value) if event_value else None
325+
carryover_events.append(ph.new_event_raised_event(event_name, encoded_value))
326+
action = ph.new_complete_orchestration_action(
327+
self.next_sequence_number(),
328+
pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW,
329+
result=shared.to_json(self._new_input) if self._new_input is not None else None,
330+
failure_details=None,
331+
carryover_events=carryover_events)
332+
return [action]
333+
else:
334+
return list(self._pending_actions.values())
306335

307336
def next_sequence_number(self) -> int:
308337
self._sequence_number += 1
@@ -370,13 +399,13 @@ def wait_for_external_event(self, name: str) -> task.Task:
370399
# arrives. If there are multiple events with the same name, we return
371400
# them in the order they were received.
372401
external_event_task = task.CompletableTask()
373-
event_name = name.upper()
402+
event_name = name.casefold()
374403
event_list = self._received_events.get(event_name, None)
375404
if event_list:
376-
event = event_list.pop(0)
405+
event_data = event_list.pop(0)
377406
if not event_list:
378407
del self._received_events[event_name]
379-
external_event_task.complete(event.data)
408+
external_event_task.complete(event_data)
380409
else:
381410
task_list = self._pending_events.get(event_name, None)
382411
if not task_list:
@@ -385,6 +414,12 @@ def wait_for_external_event(self, name: str) -> task.Task:
385414
task_list.append(external_event_task)
386415
return external_event_task
387416

417+
def continue_as_new(self, new_input, *, save_events: bool = False) -> None:
418+
if self._is_complete:
419+
return
420+
421+
self.set_continued_as_new(new_input, save_events)
422+
388423

389424
class _OrchestrationExecutor:
390425
_generator: Union[task.Orchestrator, None]
@@ -415,13 +450,16 @@ def execute(self, instance_id: str, old_events: Sequence[pb.HistoryEvent], new_e
415450
ctx._is_replaying = False
416451
for new_event in new_events:
417452
self.process_event(ctx, new_event)
418-
if ctx._is_complete:
419-
break
453+
420454
except Exception as ex:
421455
# Unhandled exceptions fail the orchestration
422456
ctx.set_failed(ex)
423457

424-
if ctx._completion_status:
458+
if not ctx._is_complete:
459+
task_count = len(ctx._pending_tasks)
460+
event_count = len(ctx._pending_events)
461+
self._logger.info(f"{instance_id}: Waiting for {task_count} task(s) and {event_count} event(s).")
462+
elif ctx._completion_status and ctx._completion_status is not pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW:
425463
completion_status_str = pbh.get_orchestration_status_str(ctx._completion_status)
426464
self._logger.info(f"{instance_id}: Orchestration completed with status: {completion_status_str}")
427465

@@ -570,9 +608,9 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven
570608
ctx.resume()
571609
elif event.HasField("eventRaised"):
572610
# event names are case-insensitive
573-
event_name = event.eventRaised.name.upper()
611+
event_name = event.eventRaised.name.casefold()
574612
if not ctx.is_replaying:
575-
self._logger.info(f"Event raised: {event_name}")
613+
self._logger.info(f"{ctx.instance_id} Event raised: {event_name}")
576614
task_list = ctx._pending_events.get(event_name, None)
577615
decoded_result: Union[Any, None] = None
578616
if task_list:
@@ -591,7 +629,7 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven
591629
ctx._received_events[event_name] = event_list
592630
if not ph.is_empty(event.eventRaised.input):
593631
decoded_result = shared.from_json(event.eventRaised.input.value)
594-
event_list.append(_ExternalEvent(event.eventRaised.name, decoded_result))
632+
event_list.append(decoded_result)
595633
if not ctx.is_replaying:
596634
self._logger.info(f"{ctx.instance_id}: Event '{event_name}' has been buffered as there are no tasks waiting for it.")
597635
elif event.HasField("executionSuspended"):

tests/test_activity_executor.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@
22
# Licensed under the MIT License.
33

44
import json
5+
import logging
56
from typing import Any, Tuple, Union
67

7-
import durabletask.internal.shared as shared
88
from durabletask import task, worker
99

10-
TEST_LOGGER = shared.get_logger()
10+
logging.basicConfig(
11+
format='%(asctime)s.%(msecs)03d %(name)s %(levelname)s: %(message)s',
12+
datefmt='%Y-%m-%d %H:%M:%S',
13+
level=logging.DEBUG)
14+
TEST_LOGGER = logging.getLogger("tests")
1115
TEST_INSTANCE_ID = 'abc123'
1216
TEST_TASK_ID = 42
1317

tests/test_orchestration_e2e.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,39 @@ def orchestrator(ctx: task.OrchestrationContext, _):
232232
assert state is not None
233233
assert state.runtime_status == client.OrchestrationStatus.TERMINATED
234234
assert state.serialized_output == json.dumps("some reason for termination")
235+
236+
237+
def test_continue_as_new():
238+
all_results = []
239+
240+
def orchestrator(ctx: task.OrchestrationContext, input: int):
241+
result = yield ctx.wait_for_external_event("my_event")
242+
if not ctx.is_replaying:
243+
# NOTE: Real orchestrations should never interact with nonlocal variables like this.
244+
nonlocal all_results
245+
all_results.append(result)
246+
247+
if len(all_results) <= 4:
248+
ctx.continue_as_new(max(all_results), save_events=True)
249+
else:
250+
return all_results
251+
252+
# Start a worker, which will connect to the sidecar in a background thread
253+
with worker.TaskHubGrpcWorker() as w:
254+
w.add_orchestrator(orchestrator)
255+
w.start()
256+
257+
task_hub_client = client.TaskHubGrpcClient()
258+
id = task_hub_client.schedule_new_orchestration(orchestrator, input=0)
259+
task_hub_client.raise_orchestration_event(id, "my_event", data=1)
260+
task_hub_client.raise_orchestration_event(id, "my_event", data=2)
261+
task_hub_client.raise_orchestration_event(id, "my_event", data=3)
262+
task_hub_client.raise_orchestration_event(id, "my_event", data=4)
263+
task_hub_client.raise_orchestration_event(id, "my_event", data=5)
264+
265+
state = task_hub_client.wait_for_orchestration_completion(id, timeout=30)
266+
assert state is not None
267+
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
268+
assert state.serialized_output == json.dumps(all_results)
269+
assert state.serialized_input == json.dumps(4)
270+
assert all_results == [1, 2, 3, 4, 5]

tests/test_orchestration_executor.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from datetime import datetime, timedelta
77
from typing import List
88

9+
import pytest
10+
911
import durabletask.internal.helpers as helpers
1012
import durabletask.internal.orchestrator_service_pb2 as pb
1113
from durabletask import task, worker
@@ -585,6 +587,41 @@ def orchestrator(ctx: task.OrchestrationContext, _):
585587
assert complete_action.result.value == json.dumps("terminated!")
586588

587589

590+
@pytest.mark.parametrize("save_events", [True, False])
591+
def test_continue_as_new(save_events: bool):
592+
"""Tests the behavior of the continue-as-new API"""
593+
def orchestrator(ctx: task.OrchestrationContext, input: int):
594+
yield ctx.create_timer(ctx.current_utc_datetime + timedelta(days=1))
595+
ctx.continue_as_new(input + 1, save_events=save_events)
596+
597+
registry = worker._Registry()
598+
orchestrator_name = registry.add_orchestrator(orchestrator)
599+
600+
old_events = [
601+
helpers.new_orchestrator_started_event(),
602+
helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input="1"),
603+
helpers.new_event_raised_event("my_event", encoded_input="42"),
604+
helpers.new_event_raised_event("my_event", encoded_input="43"),
605+
helpers.new_event_raised_event("my_event", encoded_input="44"),
606+
helpers.new_timer_created_event(1, datetime.utcnow() + timedelta(days=1))]
607+
new_events = [
608+
helpers.new_timer_fired_event(1, datetime.utcnow() + timedelta(days=1))]
609+
610+
# Execute the orchestration. It should be in a running state waiting for the timer to fire
611+
executor = worker._OrchestrationExecutor(registry, TEST_LOGGER)
612+
actions = executor.execute(TEST_INSTANCE_ID, old_events, new_events)
613+
complete_action = get_and_validate_single_complete_orchestration_action(actions)
614+
assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW
615+
assert complete_action.result.value == json.dumps(2)
616+
assert len(complete_action.carryoverEvents) == (3 if save_events else 0)
617+
for i in range(len(complete_action.carryoverEvents)):
618+
event = complete_action.carryoverEvents[i]
619+
assert type(event) is pb.HistoryEvent
620+
assert event.HasField("eventRaised")
621+
assert event.eventRaised.name.casefold() == "my_event".casefold() # event names are case-insensitive
622+
assert event.eventRaised.input.value == json.dumps(42 + i)
623+
624+
588625
def test_fan_out():
589626
"""Tests that a fan-out pattern correctly schedules N tasks"""
590627
def hello(_, name: str):

0 commit comments

Comments
 (0)