Skip to content

Commit 62ec412

Browse files
committed
Sub-orchestrations, when_all, and when_any support (#5)
1 parent 45622e9 commit 62ec412

7 files changed

Lines changed: 587 additions & 46 deletions

File tree

.vscode/settings.json

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,18 @@
11
{
2-
"editor.formatOnSave": true,
3-
"python.analysis.typeCheckingMode": "basic",
4-
"python.formatting.autopep8Args": [
5-
"--max-line-length=120"
2+
"[python]": {
3+
"editor.defaultFormatter": "ms-python.autopep8",
4+
"editor.formatOnSave": true,
5+
"editor.codeActionsOnSave": {
6+
"source.organizeImports": true,
7+
},
8+
"editor.rulers": [
9+
119
10+
],
11+
},
12+
"autopep8.args": [
13+
"--max-line-length=119"
614
],
15+
"python.analysis.typeCheckingMode": "basic",
716
"python.testing.pytestArgs": [
817
"-v",
918
"--cov=durabletask/",

durabletask/protos/helpers.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,37 @@ def new_task_failed_event(event_id: int, ex: Exception) -> HistoryEvent:
6969
)
7070

7171

72+
def new_sub_orchestration_created_event(event_id: int, name: str, instance_id: str, encoded_input: str | None = None) -> HistoryEvent:
73+
return HistoryEvent(
74+
eventId=event_id,
75+
timestamp=timestamp_pb2.Timestamp(),
76+
subOrchestrationInstanceCreated=SubOrchestrationInstanceCreatedEvent(
77+
name=name,
78+
input=get_string_value(encoded_input),
79+
instanceId=instance_id)
80+
)
81+
82+
83+
def new_sub_orchestration_completed_event(event_id: int, encoded_output: str | None = None) -> HistoryEvent:
84+
return HistoryEvent(
85+
eventId=-1,
86+
timestamp=timestamp_pb2.Timestamp(),
87+
subOrchestrationInstanceCompleted=SubOrchestrationInstanceCompletedEvent(
88+
result=get_string_value(encoded_output),
89+
taskScheduledId=event_id)
90+
)
91+
92+
93+
def new_sub_orchestration_failed_event(event_id: int, ex: Exception) -> HistoryEvent:
94+
return HistoryEvent(
95+
eventId=-1,
96+
timestamp=timestamp_pb2.Timestamp(),
97+
subOrchestrationInstanceFailed=SubOrchestrationInstanceFailedEvent(
98+
failureDetails=new_failure_details(ex),
99+
taskScheduledId=event_id)
100+
)
101+
102+
72103
def new_failure_details(ex: Exception) -> TaskFailureDetails:
73104
return TaskFailureDetails(
74105
errorType=type(ex).__name__,
@@ -120,5 +151,14 @@ def new_timestamp(dt: datetime) -> timestamp_pb2.Timestamp:
120151
return ts
121152

122153

154+
def new_create_sub_orchestration_action(id: int, name: str, instance_id: str | None, input: Any) -> OrchestratorAction:
155+
encoded_input = json.dumps(input) if input is not None else None
156+
return OrchestratorAction(id=id, createSubOrchestration=CreateSubOrchestrationAction(
157+
name=name,
158+
instanceId=instance_id,
159+
input=get_string_value(encoded_input)
160+
))
161+
162+
123163
def is_empty(v: wrappers_pb2.StringValue):
124164
return v is None or v.value == ''

durabletask/task/execution.py

Lines changed: 60 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
1-
import inspect
2-
from logging import Logger
3-
import simplejson as json
4-
51
from datetime import datetime
2+
from logging import Logger
63
from types import GeneratorType
74
from typing import Any, Dict, Generator, Iterable, List, TypeVar
85

6+
import simplejson as json
7+
98
import durabletask.protos.helpers as ph
109
import durabletask.protos.orchestrator_service_pb2 as pb
11-
from durabletask.task.activities import Activity, ActivityContext
1210
import durabletask.task.task as task
13-
11+
from durabletask.task.activities import Activity, ActivityContext
1412
from durabletask.task.orchestration import OrchestrationContext, Orchestrator
1513
from durabletask.task.registry import Registry, get_name
1614
from durabletask.task.task import Task
@@ -67,11 +65,11 @@ def resume(self):
6765
# case is if the user yielded on a WhenAll task and there are still
6866
# outstanding child tasks that need to be completed.
6967
if self._previous_task is not None:
70-
if self._previous_task.is_failed():
68+
if self._previous_task.is_failed:
7169
# Raise the failure as an exception to the generator. The orchestrator can then either
7270
# handle the exception or allow it to fail the orchestration.
7371
self._generator.throw(self._previous_task.get_exception())
74-
elif self._previous_task.is_complete():
72+
elif self._previous_task.is_complete:
7573
# Resume the generator. This will either return a Task or raise StopIteration if it's done.
7674
next_task = self._generator.send(self._previous_task.get_result())
7775
# TODO: Validate the return value
@@ -138,6 +136,21 @@ def call_activity(self, activity: Activity[TInput, TOutput], *,
138136
self._pending_tasks[id] = activity_task
139137
return activity_task
140138

139+
def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput], *,
140+
input: TInput | None = None,
141+
instance_id: str | None = None) -> task.Task[TOutput]:
142+
id = self.next_sequence_number()
143+
name = get_name(orchestrator)
144+
if instance_id is None:
145+
# Create a deterministic instance ID based on the parent instance ID
146+
instance_id = f"{self.instance_id}:{id:04x}"
147+
action = ph.new_create_sub_orchestration_action(id, name, instance_id, input)
148+
self._pending_actions[id] = action
149+
150+
sub_orch_task = task.CompletableTask[TOutput]()
151+
self._pending_tasks[id] = sub_orch_task
152+
return sub_orch_task
153+
141154

142155
class OrchestrationExecutor:
143156
_generator: Orchestrator | None
@@ -252,6 +265,45 @@ def process_event(self, ctx: RuntimeOrchestrationContext, event: pb.HistoryEvent
252265
return
253266
activity_task.fail(event.taskFailed.failureDetails)
254267
ctx.resume()
268+
elif event.HasField("subOrchestrationInstanceCreated"):
269+
# This history event confirms that the sub-orchestration execution was successfully scheduled.
270+
# Remove the matching createSubOrchestration action from the pending action list so we don't schedule it again.
271+
task_id = event.eventId
272+
action = ctx._pending_actions.pop(task_id, None)
273+
if not action:
274+
raise _get_non_determinism_error(task_id, get_name(ctx.call_sub_orchestrator))
275+
elif not action.HasField("createSubOrchestration"):
276+
expected_method_name = get_name(ctx.call_sub_orchestrator)
277+
raise _get_wrong_action_type_error(task_id, expected_method_name, action)
278+
elif action.createSubOrchestration.name != event.subOrchestrationInstanceCreated.name:
279+
raise _get_wrong_action_name_error(
280+
task_id,
281+
method_name=get_name(ctx.call_sub_orchestrator),
282+
expected_task_name=event.subOrchestrationInstanceCreated.name,
283+
actual_task_name=action.createSubOrchestration.name)
284+
elif event.HasField("subOrchestrationInstanceCompleted"):
285+
task_id = event.subOrchestrationInstanceCompleted.taskScheduledId
286+
sub_orch_task = ctx._pending_tasks.pop(task_id, None)
287+
if not sub_orch_task:
288+
# TODO: Should this be an error? When would it ever happen?
289+
self._logger.warning(
290+
f"Ignoring unexpected subOrchestrationInstanceCompleted event for '{ctx.instance_id}' with ID = {task_id}.")
291+
return
292+
result = None
293+
if not ph.is_empty(event.subOrchestrationInstanceCompleted.result):
294+
result = json.loads(event.subOrchestrationInstanceCompleted.result.value)
295+
sub_orch_task.complete(result)
296+
ctx.resume()
297+
elif event.HasField("subOrchestrationInstanceFailed"):
298+
task_id = event.subOrchestrationInstanceFailed.taskScheduledId
299+
sub_orch_task = ctx._pending_tasks.pop(task_id, None)
300+
if not sub_orch_task:
301+
# TODO: Should this be an error? When would it ever happen?
302+
self._logger.warning(
303+
f"Ignoring unexpected subOrchestrationInstanceFailed event for '{ctx.instance_id}' with ID = {task_id}.")
304+
return
305+
sub_orch_task.fail(event.subOrchestrationInstanceFailed.failureDetails)
306+
ctx.resume()
255307
else:
256308
eventType = event.WhichOneof("eventType")
257309
raise OrchestrationStateError(f"Don't know how to handle event of type '{eventType}'")

durabletask/task/orchestration.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# See https://peps.python.org/pep-0563/
2+
from __future__ import annotations
3+
14
from abc import ABC, abstractmethod
25
from datetime import datetime
36
from typing import Any, Callable, Generator, TypeVar
@@ -99,6 +102,29 @@ def call_activity(self, activity: Activity[TInput, TOutput], *,
99102
"""
100103
pass
101104

105+
@abstractmethod
106+
def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput], *,
107+
input: TInput | None = None,
108+
instance_id: str | None = None) -> task.Task[TOutput]:
109+
"""Schedule sub-orchestrator function for execution.
110+
111+
Parameters
112+
----------
113+
orchestrator: Orchestrator[TInput, TOutput]
114+
A reference to the orchestrator function to call.
115+
input: TInput | None
116+
The optional JSON-serializable input to pass to the orchestrator function.
117+
instance_id: str | None
118+
A unique ID to use for the sub-orchestration instance. If not specified, a
119+
deterministic ID derived from the parent instance ID will be used.
120+
121+
Returns
122+
-------
123+
Task
124+
A Durable Task that completes when the called sub-orchestrator completes or fails.
125+
"""
126+
pass
127+
102128

103129
# Orchestrators are generators that yield tasks and receive/return any type
104130
Orchestrator = Callable[[OrchestrationContext, TInput], Generator[task.Task, Any, Any] | TOutput]

durabletask/task/task.py

Lines changed: 101 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
# See https://peps.python.org/pep-0563/
2+
from __future__ import annotations
3+
14
from abc import ABC, abstractmethod
2-
from typing import Type, TypeVar, Generic
5+
from typing import Generic, List, TypeVar
36

4-
import durabletask.protos.orchestrator_service_pb2 as pb
57
import durabletask.protos.helpers as pbh
8+
import durabletask.protos.orchestrator_service_pb2 as pb
69

710
T = TypeVar('T')
811

@@ -31,57 +34,132 @@ def stack_trace(self) -> str | None:
3134

3235
class Task(ABC, Generic[T]):
3336
"""Abstract base class for asynchronous tasks in a durable orchestration."""
37+
_result: T
38+
_exception: TaskFailedError | None
39+
_parent: CompositeTask[T] | None
3440

3541
def __init__(self) -> None:
3642
super().__init__()
37-
pass
43+
self._is_complete = False
44+
self._exception = None
45+
self._parent = None
3846

39-
@abstractmethod
47+
@property
4048
def is_complete(self) -> bool:
41-
pass
49+
"""Returns True if the task has completed, False otherwise."""
50+
return self._is_complete
4251

43-
@abstractmethod
52+
@property
4453
def is_failed(self) -> bool:
45-
pass
54+
"""Returns True if the task has failed, False otherwise."""
55+
return self._exception is not None
4656

47-
@abstractmethod
4857
def get_result(self) -> T:
49-
pass
58+
"""Returns the result of the task."""
59+
if not self._is_complete:
60+
raise ValueError('The task has not completed.')
61+
elif self._exception is not None:
62+
raise self._exception
63+
return self._result
5064

51-
@abstractmethod
5265
def get_exception(self) -> TaskFailedError:
66+
"""Returns the exception that caused the task to fail."""
67+
if self._exception is None:
68+
raise ValueError('The task has not failed.')
69+
return self._exception
70+
71+
72+
class CompositeTask(Task[T]):
73+
"""A task that is composed of other tasks."""
74+
_tasks: List[Task]
75+
76+
def __init__(self, tasks: List[Task]):
77+
super().__init__()
78+
self._tasks = tasks
79+
self._completed_tasks = 0
80+
self._failed_tasks = 0
81+
for task in tasks:
82+
task._parent = self
83+
84+
def get_tasks(self) -> List[Task]:
85+
return self._tasks
86+
87+
@abstractmethod
88+
def on_child_completed(self, task: Task[T]):
5389
pass
5490

5591

5692
class CompletableTask(Task[T]):
57-
_result: T | None
58-
_exception: TaskFailedError | None
5993

6094
def __init__(self):
6195
super().__init__()
62-
self._is_complete = False
63-
self._result = None
64-
self._exception = None
6596

6697
def complete(self, result: T):
98+
if self._is_complete:
99+
raise ValueError('The task has already completed.')
67100
self._result = result
68101
self._is_complete = True
102+
if self._parent is not None:
103+
self._parent.on_child_completed(self)
69104

70105
def fail(self, details: pb.TaskFailureDetails):
106+
if self._is_complete:
107+
raise ValueError('The task has already completed.')
71108
self._exception = TaskFailedError(
72109
details.errorMessage,
73110
details.errorType,
74111
details.stackTrace.value if not pbh.is_empty(details.stackTrace) else None)
75112
self._is_complete = True
113+
if self._parent is not None:
114+
self._parent.on_child_completed(self)
76115

77-
def is_complete(self) -> bool:
78-
return self._is_complete
79116

80-
def is_failed(self) -> bool:
81-
return self._exception is not None
117+
class WhenAllTask(CompositeTask[List[T]]):
118+
"""A task that completes when all of its child tasks complete."""
82119

83-
def get_result(self) -> T | None:
84-
return self._result
120+
def __init__(self, tasks: List[Task[T]]):
121+
super().__init__(tasks)
122+
self._completed_tasks = 0
123+
self._failed_tasks = 0
85124

86-
def get_exception(self) -> TaskFailedError | None:
87-
return self._exception
125+
@property
126+
def pending_tasks(self) -> int:
127+
"""Returns the number of tasks that have not yet completed."""
128+
return len(self._tasks) - self._completed_tasks
129+
130+
def on_child_completed(self, task: Task[T]):
131+
if self.is_complete:
132+
raise ValueError('The task has already completed.')
133+
self._completed_tasks += 1
134+
if task.is_failed and self._exception is None:
135+
self._exception = task.get_exception()
136+
self._is_complete = True
137+
if self._completed_tasks == len(self._tasks):
138+
# The order of the result MUST match the order of the tasks provided to the constructor.
139+
self._result = [task.get_result() for task in self._tasks]
140+
self._is_complete = True
141+
142+
def get_completed_tasks(self) -> int:
143+
return self._completed_tasks
144+
145+
146+
class WhenAnyTask(CompositeTask[Task]):
147+
"""A task that completes when any of its child tasks complete."""
148+
149+
def __init__(self, tasks: List[Task]):
150+
super().__init__(tasks)
151+
152+
def on_child_completed(self, task: Task):
153+
if not self.is_complete:
154+
self._is_complete = True
155+
self._result = task
156+
157+
158+
def when_all(tasks: List[Task[T]]) -> WhenAllTask[T]:
159+
"""Returns a task that completes when all of the provided tasks complete or when one of the tasks fails."""
160+
return WhenAllTask(tasks)
161+
162+
163+
def when_any(tasks: List[Task]) -> WhenAnyTask:
164+
"""Returns a task that completes when any of the provided tasks complete or fail."""
165+
return WhenAnyTask(tasks)

0 commit comments

Comments
 (0)