Skip to content

Commit 9954469

Browse files
authored
External events and other improvements (#3)
- Removed simplejson dependency - Added OrchestrationFailedError class to client module - Fixed race condition with WhenAnyTask - Support for timedelta in create_timer - Improved task error messages - Support for custom object roundtripping
1 parent eecc56a commit 9954469

14 files changed

Lines changed: 497 additions & 57 deletions

.vscode/launch.json

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
"type": "python",
77
"request": "launch",
88
"program": "${file}",
9+
"cwd": "${fileDirname}",
910
"purpose": [
1011
"debug-test"
1112
],
1213
"env": {
1314
// pytest-cov breaks debugging, so we have to disable it during debug sessions
14-
"PYTEST_ADDOPTS": "--no-cov"
15+
"PYTEST_ADDOPTS": "--no-cov",
16+
"PYTHONPATH": "${workspaceFolder}"
1517
},
1618
"console": "integratedTerminal",
1719
"justMyCode": false

README.md

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,37 @@ def orchestrator(ctx: task.OrchestrationContext, _):
7171

7272
You can find the full sample [here](./examples/fanout_fanin.py).
7373

74+
### Human interaction and durable timers
75+
76+
An orchestration can wait for a user-defined event, such as a human approval event, before proceeding to the next step. In addition, the orchestration can create a timer with an arbitrary duration that triggers some alternate action if the external event hasn't been received:
77+
78+
```python
79+
def purchase_order_workflow(ctx: task.OrchestrationContext, order: Order):
    """Orchestrator function that represents a purchase order workflow.

    Returns "Auto-approved" for orders under $1000, "Canceled" when the
    24-hour timer wins the race against the approval event, and otherwise a
    message naming the approver.
    """
    # Orders under $1000 are auto-approved
    if order.Cost < 1000:
        return "Auto-approved"

    # Orders of $1000 or more require manager approval
    yield ctx.call_activity(send_approval_request, input=order)

    # Approvals must be received within 24 hours or they will be canceled.
    approval_event = ctx.wait_for_external_event("approval_received")
    timeout_event = ctx.create_timer(timedelta(hours=24))
    winner = yield task.when_any([approval_event, timeout_event])
    if winner == timeout_event:
        return "Canceled"

    # The order was approved. Yield the activity task, exactly like the
    # approval-request call above, so the workflow does not return before
    # the order has been placed.
    yield ctx.call_activity(place_order, input=order)
    approval_details = approval_event.get_result()
    return f"Approved by '{approval_details.approver}'"
99+
```
100+
101+
As an aside, you'll also notice that the example orchestration above works with custom business objects. Support for custom business objects includes support for custom classes, custom data classes, and named tuples. Serialization and deserialization of these objects is handled automatically by the SDK.
102+
103+
You can find the full sample [here](./examples/human_interaction.py).
104+
74105
## Getting Started
75106

76107
### Prerequisites

durabletask/client.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
from dataclasses import dataclass
77
from datetime import datetime
88
from enum import Enum
9-
from typing import TypeVar
9+
from typing import Any, TypeVar
1010

1111
import grpc
12-
import simplejson as json
1312
from google.protobuf import wrappers_pb2
1413

1514
import durabletask.internal.helpers as helpers
@@ -46,14 +45,38 @@ class OrchestrationState:
4645
serialized_input: str | None
4746
serialized_output: str | None
4847
serialized_custom_status: str | None
49-
failure_details: pb.TaskFailureDetails | None
48+
failure_details: task.FailureDetails | None
49+
50+
def raise_if_failed(self):
    """Raise OrchestrationFailedError when failure details are recorded.

    Does nothing when `self.failure_details` is None.
    """
    details = self.failure_details
    if details is None:
        return
    raise OrchestrationFailedError(
        f"Orchestration '{self.instance_id}' failed: {details.message}",
        details)
55+
56+
57+
class OrchestrationFailedError(Exception):
    """Exception carrying the failure details of an orchestration.

    Parameters
    ----------
    message : str
        Text passed to the base `Exception`.
    failure_details : task.FailureDetails
        Details object kept on the exception and exposed read-only through
        the `failure_details` property.
    """

    def __init__(self, message: str, failure_details: task.FailureDetails):
        super().__init__(message)
        # Stored privately; callers read it through the property below.
        self._failure_details = failure_details

    @property
    def failure_details(self):
        """The failure details object supplied at construction time."""
        return self._failure_details
5065

5166

5267
def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> OrchestrationState | None:
5368
if not res.exists:
5469
return None
5570

5671
state = res.orchestrationState
72+
73+
failure_details = None
74+
if state.failureDetails.errorMessage != '' or state.failureDetails.errorType != '':
75+
failure_details = task.FailureDetails(
76+
state.failureDetails.errorMessage,
77+
state.failureDetails.errorType,
78+
state.failureDetails.stackTrace.value if not helpers.is_empty(state.failureDetails.stackTrace) else None)
79+
5780
return OrchestrationState(
5881
instance_id,
5982
state.name,
@@ -63,7 +86,7 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Or
6386
state.input.value if not helpers.is_empty(state.input) else None,
6487
state.output.value if not helpers.is_empty(state.output) else None,
6588
state.customStatus.value if not helpers.is_empty(state.customStatus) else None,
66-
state.failureDetails if state.failureDetails.errorMessage != '' or state.failureDetails.errorType != '' else None)
89+
failure_details)
6790

6891

6992
class TaskHubGrpcClient:
@@ -86,7 +109,7 @@ def schedule_new_orchestration(self, orchestrator: task.Orchestrator[TInput, TOu
86109
req = pb.CreateInstanceRequest(
87110
name=name,
88111
instanceId=instance_id if instance_id else uuid.uuid4().hex,
89-
input=wrappers_pb2.StringValue(value=json.dumps(input)) if input else None,
112+
input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input else None,
90113
scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None)
91114

92115
self._logger.info(f"Starting new '{name}' instance with ID = '{req.instanceId}'.")
@@ -128,6 +151,16 @@ def wait_for_orchestration_completion(self, instance_id: str, *,
128151
else:
129152
raise
130153

154+
def raise_orchestration_event(self, instance_id: str, event_name: str, *,
                              data: Any | None = None):
    """Send a RaiseEvent request for the given orchestration instance.

    Parameters
    ----------
    instance_id : str
        ID of the orchestration instance that should receive the event.
    event_name : str
        Name of the event to raise.
    data : Any | None
        Optional payload. It is serialized with `shared.to_json`; when it is
        None no input is attached to the request.
    """
    # Test against None rather than truthiness so that falsy payloads such as
    # 0, False, "" or [] are still serialized and delivered with the event.
    req = pb.RaiseEventRequest(
        instanceId=instance_id,
        name=event_name,
        input=wrappers_pb2.StringValue(value=shared.to_json(data)) if data is not None else None)

    self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.")
    self._stub.RaiseEvent(req)
163+
131164
def terminate_orchestration(self):
132165
pass
133166

@@ -136,6 +169,3 @@ def suspend_orchestration(self):
136169

137170
def resume_orchestration(self):
138171
pass
139-
140-
def raise_orchestration_event(self):
141-
pass

durabletask/internal/helpers.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33

44
import traceback
55
from datetime import datetime
6-
from typing import Any
76

8-
import simplejson as json
97
from google.protobuf import timestamp_pb2, wrappers_pb2
108

119
import durabletask.internal.orchestrator_service_pb2 as pb
@@ -117,6 +115,14 @@ def new_failure_details(ex: Exception) -> pb.TaskFailureDetails:
117115
)
118116

119117

118+
def new_event_raised_event(name: str, encoded_input: str | None = None) -> pb.HistoryEvent:
    """Build a HistoryEvent whose payload is an EventRaisedEvent.

    Parameters
    ----------
    name : str
        Name of the raised event.
    encoded_input : str | None
        Already-encoded event input; converted with `get_string_value`.

    Returns
    -------
    pb.HistoryEvent
        History event with `eventId` set to -1 and a default-constructed timestamp.
    """
    raised = pb.EventRaisedEvent(name=name, input=get_string_value(encoded_input))
    return pb.HistoryEvent(eventId=-1, timestamp=timestamp_pb2.Timestamp(), eventRaised=raised)
124+
125+
120126
def get_string_value(val: str | None) -> wrappers_pb2.StringValue | None:
121127
if val is None:
122128
return None
@@ -146,8 +152,7 @@ def new_create_timer_action(id: int, fire_at: datetime) -> pb.OrchestratorAction
146152
return pb.OrchestratorAction(id=id, createTimer=pb.CreateTimerAction(fireAt=timestamp))
147153

148154

149-
def new_schedule_task_action(id: int, name: str, input: Any) -> pb.OrchestratorAction:
150-
encoded_input = json.dumps(input) if input is not None else None
155+
def new_schedule_task_action(id: int, name: str, encoded_input: str | None) -> pb.OrchestratorAction:
151156
return pb.OrchestratorAction(id=id, scheduleTask=pb.ScheduleTaskAction(
152157
name=name,
153158
input=get_string_value(encoded_input)
@@ -164,8 +169,7 @@ def new_create_sub_orchestration_action(
164169
id: int,
165170
name: str,
166171
instance_id: str | None,
167-
input: Any) -> pb.OrchestratorAction:
168-
encoded_input = json.dumps(input) if input is not None else None
172+
encoded_input: str | None) -> pb.OrchestratorAction:
169173
return pb.OrchestratorAction(id=id, createSubOrchestration=pb.CreateSubOrchestrationAction(
170174
name=name,
171175
instanceId=instance_id,

durabletask/internal/shared.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33

4+
import dataclasses
5+
import json
46
import logging
7+
from types import SimpleNamespace
8+
from typing import Any, Dict
59

610
import grpc
711

12+
# Field name used to indicate that an object was automatically serialized
13+
# and should be deserialized as a SimpleNamespace
14+
AUTO_SERIALIZED = "__durabletask_autoobject__"
15+
816

917
def get_default_host_address() -> str:
1018
return "localhost:4001"
@@ -35,3 +43,49 @@ def get_logger(
3543
datefmt='%Y-%m-%d %H:%M:%S')
3644
log_handler.setFormatter(log_formatter)
3745
return logger
46+
47+
48+
def to_json(obj):
    """Serialize `obj` to a JSON string using InternalJSONEncoder."""
    encoded = json.dumps(obj, cls=InternalJSONEncoder)
    return encoded
50+
51+
52+
def from_json(json_str):
    """Parse `json_str` into Python objects using InternalJSONDecoder."""
    decoded = json.loads(json_str, cls=InternalJSONDecoder)
    return decoded
54+
55+
56+
class InternalJSONEncoder(json.JSONEncoder):
    """JSON encoder that supports serializing specific Python types.

    Named tuples (top-level only), dataclass instances and SimpleNamespace
    objects are written as JSON objects with an extra AUTO_SERIALIZED key set
    to True, marking them for automatic deserialization by the receiver.
    """

    def encode(self, obj: Any) -> str:
        """Return the JSON text for `obj`, converting a namedtuple to a marked dict first."""
        # if the object is a namedtuple, convert it to a dict with the AUTO_SERIALIZED key added
        if isinstance(obj, tuple) and hasattr(obj, "_fields") and hasattr(obj, "_asdict"):
            d = obj._asdict()  # type: ignore
            d[AUTO_SERIALIZED] = True
            obj = d
        return super().encode(obj)

    def default(self, obj):
        """Convert objects the base encoder cannot serialize.

        Raises
        ------
        TypeError
            Raised by the base implementation for unsupported objects.
        """
        # is_dataclass() is also true for dataclass *types*, which asdict() rejects;
        # only instances are converted here, classes fall through to the base error.
        if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
            # Dataclasses are not serializable by default, so we convert them to a dict and mark them for
            # automatic deserialization by the receiver
            d = dataclasses.asdict(obj)
            d[AUTO_SERIALIZED] = True
            return d
        elif isinstance(obj, SimpleNamespace):
            # Most commonly used for serializing custom objects that were previously serialized using our encoder.
            # vars() returns the namespace's own __dict__, so copy it before adding the marker;
            # otherwise the caller's object would gain an AUTO_SERIALIZED attribute.
            d = dict(vars(obj))
            d[AUTO_SERIALIZED] = True
            return d
        # This will typically raise a TypeError
        return super().default(obj)
81+
82+
83+
class InternalJSONDecoder(json.JSONDecoder):
    """JSON decoder that turns objects carrying the AUTO_SERIALIZED key into SimpleNamespace instances."""

    def __init__(self, *args, **kwargs):
        # Positional unpacking goes before the keyword argument so the call reads
        # in the order Python binds the arguments.
        super().__init__(*args, object_hook=self.dict_to_object, **kwargs)

    def dict_to_object(self, d: Dict[str, Any]):
        """Object hook: return a SimpleNamespace for marked dicts, otherwise the dict itself.

        The AUTO_SERIALIZED key is removed from `d` in either case.
        """
        # If the object was serialized by the InternalJSONEncoder, deserialize it as a SimpleNamespace
        if d.pop(AUTO_SERIALIZED, False):
            return SimpleNamespace(**d)
        return d

durabletask/task.py

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from __future__ import annotations
66

77
from abc import ABC, abstractmethod
8-
from datetime import datetime
8+
from datetime import datetime, timedelta
99
from typing import Any, Callable, Generator, Generic, List, TypeVar
1010

1111
import durabletask.internal.helpers as pbh
@@ -70,13 +70,13 @@ def is_replaying(self) -> bool:
7070
pass
7171

7272
@abstractmethod
73-
def create_timer(self, fire_at: datetime) -> Task:
73+
def create_timer(self, fire_at: datetime | timedelta) -> Task:
7474
"""Create a Timer Task to fire at the specified deadline.
7575
7676
Parameters
7777
----------
78-
fire_at: datetime.datetime
79-
The time for the timer to trigger
78+
fire_at: datetime.datetime | datetime.timedelta
79+
The time for the timer to trigger or a time delta from now.
8080
8181
Returns
8282
-------
@@ -129,12 +129,27 @@ def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput], *,
129129
"""
130130
pass
131131

132+
# TODO: Add a timeout parameter, which allows the task to be canceled if the event is
133+
# not received within the specified timeout. This requires support for task cancellation.
134+
@abstractmethod
135+
def wait_for_external_event(self, name: str) -> Task:
136+
"""Wait asynchronously for an event to be raised with the name `name`.
132137
133-
class TaskFailedError(Exception):
134-
"""Exception type for all orchestration task failures."""
138+
Parameters
139+
----------
140+
name : str
141+
The event name of the event that the task is waiting for.
135142
143+
Returns
144+
-------
145+
Task[TOutput]
146+
A Durable Task that completes when the event is received.
147+
"""
148+
pass
149+
150+
151+
class FailureDetails:
136152
def __init__(self, message: str, error_type: str, stack_trace: str | None):
137-
super().__init__(message)
138153
self._message = message
139154
self._error_type = error_type
140155
self._stack_trace = stack_trace
@@ -152,6 +167,21 @@ def stack_trace(self) -> str | None:
152167
return self._stack_trace
153168

154169

170+
class TaskFailedError(Exception):
    """Exception type for all orchestration task failures."""

    def __init__(self, message: str, details: pb.TaskFailureDetails):
        """Store `message` on the exception and convert `details` to a FailureDetails.

        Parameters
        ----------
        message : str
            Text passed to the base `Exception`.
        details : pb.TaskFailureDetails
            Failure details message; its error message, error type and
            (when not empty) stack trace are copied into a FailureDetails.
        """
        super().__init__(message)
        # An empty stackTrace wrapper is represented as None.
        stack_trace = None if pbh.is_empty(details.stackTrace) else details.stackTrace.value
        self._details = FailureDetails(details.errorMessage, details.errorType, stack_trace)

    @property
    def details(self) -> FailureDetails:
        """The FailureDetails built from the failure message given at construction."""
        return self._details
183+
184+
155185
class NonDeterminismError(Exception):
156186
pass
157187

@@ -208,6 +238,8 @@ def __init__(self, tasks: List[Task]):
208238
self._failed_tasks = 0
209239
for task in tasks:
210240
task._parent = self
241+
if task.is_complete:
242+
self.on_child_completed(task)
211243

212244
def get_tasks(self) -> List[Task]:
213245
return self._tasks
@@ -230,13 +262,10 @@ def complete(self, result: T):
230262
if self._parent is not None:
231263
self._parent.on_child_completed(self)
232264

233-
def fail(self, details: pb.TaskFailureDetails):
265+
def fail(self, message: str, details: pb.TaskFailureDetails):
234266
if self._is_complete:
235267
raise ValueError('The task has already completed.')
236-
self._exception = TaskFailedError(
237-
details.errorMessage,
238-
details.errorType,
239-
details.stackTrace.value if not pbh.is_empty(details.stackTrace) else None)
268+
self._exception = TaskFailedError(message, details)
240269
self._is_complete = True
241270
if self._parent is not None:
242271
self._parent.on_child_completed(self)
@@ -278,6 +307,7 @@ def __init__(self, tasks: List[Task]):
278307
super().__init__(tasks)
279308

280309
def on_child_completed(self, task: Task):
310+
# The first task to complete is the result of the WhenAnyTask.
281311
if not self.is_complete:
282312
self._is_complete = True
283313
self._result = task

0 commit comments

Comments
 (0)