Skip to content

Commit 45622e9

Browse files
committed
Activity support, fixes, and stronger typing (#4)
1 parent fbb8e36 commit 45622e9

15 files changed

Lines changed: 1050 additions & 322 deletions

durabletask/api/state.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from datetime import datetime
33

44
import durabletask.protos.orchestrator_service_pb2 as pb
5+
import durabletask.protos.helpers as helpers
56
from durabletask.protos.orchestrator_service_pb2 import TaskFailureDetails
67

78

@@ -29,7 +30,7 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Or
2930
state.orchestrationStatus,
3031
state.createdTimestamp.ToDatetime(),
3132
state.lastUpdatedTimestamp.ToDatetime(),
32-
state.input.value if state.input is not None else None,
33-
state.output.value if state.output is not None else None,
34-
state.customStatus.value if state.customStatus is not None else None,
33+
state.input.value if not helpers.is_empty(state.input) else None,
34+
state.output.value if not helpers.is_empty(state.output) else None,
35+
state.customStatus.value if not helpers.is_empty(state.customStatus) else None,
3536
state.failureDetails if state.failureDetails.errorMessage != '' or state.failureDetails.errorType != '' else None)

durabletask/api/task_hub_client.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
11
from dataclasses import dataclass
22
from datetime import datetime
33
import logging
4-
from typing import Any
4+
from typing import Any, TypeVar
55
import uuid
66
import grpc
77
from durabletask.api.state import OrchestrationState, new_orchestration_state
8+
import durabletask.protos.helpers as helpers
89
import durabletask.protos.orchestrator_service_pb2 as pb
910
import durabletask.internal.shared as shared
1011
import simplejson as json
1112

12-
from google.protobuf import timestamp_pb2, wrappers_pb2
13+
from google.protobuf import wrappers_pb2
1314

15+
import durabletask.task.registry as registry
1416
from durabletask.protos.orchestrator_service_pb2_grpc import TaskHubSidecarServiceStub
17+
from durabletask.task.orchestration import Orchestrator
18+
19+
TInput = TypeVar('TInput')
20+
TOutput = TypeVar('TOutput')
1521

1622

1723
class TaskHubGrpcClient:
@@ -24,22 +30,19 @@ def __init__(self, *,
2430
self._stub = TaskHubSidecarServiceStub(channel)
2531
self._logger = shared.get_logger(log_handler, log_formatter)
2632

27-
def schedule_new_orchestration(self, name: str, *,
28-
input: Any = None,
33+
def schedule_new_orchestration(self, orchestrator: Orchestrator[TInput, TOutput], *,
34+
input: TInput | None = None,
2935
instance_id: str | None = None,
3036
start_at: datetime | None = None) -> str:
31-
req = pb.CreateInstanceRequest(name=name)
32-
if instance_id is None:
33-
instance_id = uuid.uuid4().hex
34-
req.instanceId = instance_id
35-
36-
if input is not None:
37-
json_input = json.dumps(input)
38-
req.input = wrappers_pb2.StringValue(value=json_input)
39-
40-
if start_at is not None:
41-
req.scheduledStartTimestamp = timestamp_pb2.Timestamp()
42-
req.scheduledStartTimestamp.FromDatetime(start_at)
37+
38+
name = registry.get_name(orchestrator)
39+
40+
req = pb.CreateInstanceRequest(
41+
name=name,
42+
instanceId=instance_id if instance_id else uuid.uuid4().hex,
43+
input=wrappers_pb2.StringValue(value=json.dumps(input)) if input else None,
44+
scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None)
45+
4346
self._logger.info(f"Starting new '{name}' instance with ID = '{instance_id}'.")
4447
res: pb.CreateInstanceResponse = self._stub.StartInstance(req)
4548
return res.instanceId

durabletask/internal/shared.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def get_grpc_channel(host_address: str | None) -> grpc.Channel:
1414
return channel
1515

1616

17-
def get_logger(log_handler: logging.Handler | None, log_formatter: logging.Formatter | None) -> logging.Logger:
17+
def get_logger(log_handler: logging.Handler | None = None, log_formatter: logging.Formatter | None = None) -> logging.Logger:
1818
logger = logging.Logger("durabletask")
1919

2020
# Add a default log handler if none is provided

durabletask/protos/helpers.py

Lines changed: 58 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1+
import simplejson as json
12
import traceback
23

34
from datetime import datetime
5+
from typing import Any
46
from google.protobuf import timestamp_pb2, wrappers_pb2
57

68
from durabletask.protos.orchestrator_service_pb2 import *
79

10+
# TODO: The new_xxx_event methods are only used by test code and should be moved elsewhere
11+
812

913
def new_orchestrator_started_event(timestamp: datetime | None = None) -> HistoryEvent:
1014
ts = timestamp_pb2.Timestamp()
@@ -13,19 +17,15 @@ def new_orchestrator_started_event(timestamp: datetime | None = None) -> History
1317
return HistoryEvent(eventId=-1, timestamp=ts, orchestratorStarted=OrchestratorStartedEvent())
1418

1519

16-
def new_execution_started_event(name: str, instance_id: str, input: str | None = None) -> HistoryEvent:
17-
input_: wrappers_pb2.StringValue | None = None
18-
if input is not None:
19-
input_ = wrappers_pb2.StringValue(value=input)
20-
20+
def new_execution_started_event(name: str, instance_id: str, encoded_input: str | None = None) -> HistoryEvent:
2121
return HistoryEvent(
2222
eventId=-1,
2323
timestamp=timestamp_pb2.Timestamp(),
2424
executionStarted=ExecutionStartedEvent(
25-
name=name, input=input_, orchestrationInstance=OrchestrationInstance(instanceId=instance_id)))
25+
name=name, input=get_string_value(encoded_input), orchestrationInstance=OrchestrationInstance(instanceId=instance_id)))
2626

2727

28-
def new_timer_created_event(timer_id: int, fire_at: datetime):
28+
def new_timer_created_event(timer_id: int, fire_at: datetime) -> HistoryEvent:
2929
ts = timestamp_pb2.Timestamp()
3030
ts.FromDatetime(fire_at)
3131
return HistoryEvent(
@@ -35,7 +35,7 @@ def new_timer_created_event(timer_id: int, fire_at: datetime):
3535
)
3636

3737

38-
def new_timer_fired_event(timer_id: int, fire_at: datetime):
38+
def new_timer_fired_event(timer_id: int, fire_at: datetime) -> HistoryEvent:
3939
ts = timestamp_pb2.Timestamp()
4040
ts.FromDatetime(fire_at)
4141
return HistoryEvent(
@@ -45,6 +45,30 @@ def new_timer_fired_event(timer_id: int, fire_at: datetime):
4545
)
4646

4747

48+
def new_task_scheduled_event(event_id: int, name: str, encoded_input: str | None = None) -> HistoryEvent:
49+
return HistoryEvent(
50+
eventId=event_id,
51+
timestamp=timestamp_pb2.Timestamp(),
52+
taskScheduled=TaskScheduledEvent(name=name, input=get_string_value(encoded_input))
53+
)
54+
55+
56+
def new_task_completed_event(event_id: int, encoded_output: str | None = None) -> HistoryEvent:
57+
return HistoryEvent(
58+
eventId=-1,
59+
timestamp=timestamp_pb2.Timestamp(),
60+
taskCompleted=TaskCompletedEvent(taskScheduledId=event_id, result=get_string_value(encoded_output))
61+
)
62+
63+
64+
def new_task_failed_event(event_id: int, ex: Exception) -> HistoryEvent:
65+
return HistoryEvent(
66+
eventId=-1,
67+
timestamp=timestamp_pb2.Timestamp(),
68+
taskFailed=TaskFailedEvent(taskScheduledId=event_id, failureDetails=new_failure_details(ex))
69+
)
70+
71+
4872
def new_failure_details(ex: Exception) -> TaskFailureDetails:
4973
return TaskFailureDetails(
5074
errorType=type(ex).__name__,
@@ -53,19 +77,22 @@ def new_failure_details(ex: Exception) -> TaskFailureDetails:
5377
)
5478

5579

80+
def get_string_value(val: str | None) -> wrappers_pb2.StringValue | None:
81+
if val is None:
82+
return None
83+
else:
84+
return wrappers_pb2.StringValue(value=val)
85+
86+
5687
def new_complete_orchestration_action(
5788
id: int,
5889
status: OrchestrationStatus,
5990
result: str | None = None,
6091
failure_details: TaskFailureDetails | None = None) -> OrchestratorAction:
6192

62-
result_pb: wrappers_pb2.StringValue | None = None
63-
if result is not None:
64-
result_pb = wrappers_pb2.StringValue(value=result)
65-
6693
completeOrchestrationAction = CompleteOrchestrationAction(
6794
orchestrationStatus=status,
68-
result=result_pb,
95+
result=get_string_value(result),
6996
failureDetails=failure_details)
7097

7198
# TODO: CarryoverEvents
@@ -77,3 +104,21 @@ def new_create_timer_action(id: int, fire_at: datetime) -> OrchestratorAction:
77104
timestamp = timestamp_pb2.Timestamp()
78105
timestamp.FromDatetime(fire_at)
79106
return OrchestratorAction(id=id, createTimer=CreateTimerAction(fireAt=timestamp))
107+
108+
109+
def new_schedule_task_action(id: int, name: str, input: Any) -> OrchestratorAction:
110+
encoded_input = json.dumps(input) if input is not None else None
111+
return OrchestratorAction(id=id, scheduleTask=ScheduleTaskAction(
112+
name=name,
113+
input=get_string_value(encoded_input)
114+
))
115+
116+
117+
def new_timestamp(dt: datetime) -> timestamp_pb2.Timestamp:
118+
ts = timestamp_pb2.Timestamp()
119+
ts.FromDatetime(dt)
120+
return ts
121+
122+
123+
def is_empty(v: wrappers_pb2.StringValue):
124+
return v is None or v.value == ''

durabletask/task/activities.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
2+
from typing import Callable, TypeVar
3+
4+
5+
TInput = TypeVar('TInput')
6+
TOutput = TypeVar('TOutput')
7+
8+
9+
class ActivityContext:
10+
def __init__(self, orchestration_id: str, task_id: int):
11+
self._orchestration_id = orchestration_id
12+
self._task_id = task_id
13+
14+
@property
15+
def orchestration_id(self) -> str:
16+
"""Get the ID of the orchestration instance that scheduled this activity.
17+
18+
Returns
19+
-------
20+
str
21+
The ID of the current orchestration instance.
22+
"""
23+
return self._orchestration_id
24+
25+
@property
26+
def task_id(self) -> int:
27+
"""Get the task ID associated with this activity invocation.
28+
29+
The task ID is an auto-incrementing integer that is unique within
30+
the scope of the orchestration instance. It can be used to distinguish
31+
between multiple activity invocations that are part of the same
32+
orchestration instance.
33+
34+
Returns
35+
-------
36+
str
37+
The ID of the current orchestration instance.
38+
"""
39+
return self._task_id
40+
41+
42+
Activity = Callable[[ActivityContext, TInput], TOutput]

0 commit comments

Comments
 (0)