Skip to content
Merged
77 changes: 71 additions & 6 deletions durabletask/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def set_custom_status(self, custom_status: Any) -> None:
pass

@abstractmethod
def create_timer(self, fire_at: Union[datetime, timedelta]) -> Task:
def create_timer(self, fire_at: Union[datetime, timedelta]) -> CancellableTask:
"""Create a Timer Task to fire after at the specified deadline.

Parameters
Expand Down Expand Up @@ -231,7 +231,7 @@ def call_sub_orchestrator(self, orchestrator: Union[Orchestrator[TInput, TOutput
# TOOD: Add a timeout parameter, which allows the task to be canceled if the event is
# not received within the specified timeout. This requires support for task cancellation.
@abstractmethod
def wait_for_external_event(self, name: str) -> CompletableTask:
def wait_for_external_event(self, name: str) -> CancellableTask:
"""Wait asynchronously for an event to be raised with the name `name`.

Parameters
Expand Down Expand Up @@ -324,6 +324,10 @@ class OrchestrationStateError(Exception):
pass


class TaskCanceledError(Exception):
"""Exception type for canceled orchestration tasks."""


class Task(ABC, Generic[T]):
"""Abstract base class for asynchronous tasks in a durable orchestration."""
_result: T
Expand Down Expand Up @@ -435,6 +439,48 @@ def fail(self, message: str, details: Union[Exception, pb.TaskFailureDetails]):
self._parent.on_child_completed(self)


class CancellableTask(CompletableTask[T]):
Comment thread
andystaples marked this conversation as resolved.
"""A completable task that can be canceled before it finishes."""

def __init__(self) -> None:
super().__init__()
self._is_cancelled = False
self._cancel_handler: Optional[Callable[[], None]] = None

@property
def is_cancelled(self) -> bool:
"""Returns True if the task was canceled, False otherwise."""
return self._is_cancelled

def get_result(self) -> T:
if self._is_cancelled:
raise TaskCanceledError('The task was canceled.')
return super().get_result()
Comment thread
andystaples marked this conversation as resolved.

def set_cancel_handler(self, cancel_handler: Callable[[], None]) -> None:
self._cancel_handler = cancel_handler

def cancel(self) -> bool:
"""Attempts to cancel this task.

Returns
-------
bool
True if cancellation was applied, False if the task had already completed.
"""
if self._is_complete:
return False

if self._cancel_handler is not None:
self._cancel_handler()

self._is_cancelled = True
self._is_complete = True
if self._parent is not None:
self._parent.on_child_completed(self)
return True


class RetryableTask(CompletableTask[T]):
"""A task that can be retried according to a retry policy."""

Expand Down Expand Up @@ -474,13 +520,32 @@ def compute_next_delay(self) -> Optional[timedelta]:
return None


class TimerTask(CompletableTask[T]):
class TimerTask(CancellableTask[None]):
def set_retryable_parent(self, retryable_task: RetryableTask):
self._retryable_parent = retryable_task

def complete(self, *args, **kwargs):
Comment thread
andystaples marked this conversation as resolved.
Outdated
super().complete(None)

def __init__(self) -> None:

class LongTimerTask(TimerTask):
Comment thread
andystaples marked this conversation as resolved.
Outdated
def __init__(self, final_fire_at: datetime, maximum_timer_duration: timedelta):
super().__init__()
self._final_fire_at = final_fire_at
self._maximum_timer_duration = maximum_timer_duration

def set_retryable_parent(self, retryable_task: RetryableTask):
self._retryable_parent = retryable_task
def start(self, current_utc_datetime: datetime) -> datetime:
return self._get_next_fire_at(current_utc_datetime)

def complete(self, current_utc_datetime: datetime):
if current_utc_datetime < self._final_fire_at:
return self._get_next_fire_at(current_utc_datetime)
super().complete(None)
Comment thread
andystaples marked this conversation as resolved.
Outdated

def _get_next_fire_at(self, current_utc_datetime: datetime) -> datetime:
if current_utc_datetime + self._maximum_timer_duration < self._final_fire_at:
return current_utc_datetime + self._maximum_timer_duration
return self._final_fire_at


class WhenAnyTask(CompositeTask[Task]):
Comment thread
andystaples marked this conversation as resolved.
Expand Down
132 changes: 102 additions & 30 deletions durabletask/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
TInput = TypeVar("TInput")
TOutput = TypeVar("TOutput")
DATETIME_STRING_FORMAT = '%Y-%m-%dT%H:%M:%S.%fZ'
DEFAULT_MAXIMUM_TIMER_INTERVAL = timedelta(days=3)


class ConcurrencyOptions:
Expand Down Expand Up @@ -307,7 +308,7 @@ class TaskHubGrpcWorker:
activity function.
"""

_response_stream: Optional[grpc.Future] = None
_response_stream: Optional[Any] = None
_interceptors: Optional[list[shared.ClientInterceptor]] = None

def __init__(
Expand All @@ -320,6 +321,7 @@ def __init__(
secure_channel: bool = False,
interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
concurrency_options: Optional[ConcurrencyOptions] = None,
maximum_timer_interval: Optional[timedelta] = DEFAULT_MAXIMUM_TIMER_INTERVAL
):
self._registry = _Registry()
self._host_address = (
Expand Down Expand Up @@ -348,12 +350,18 @@ def __init__(
self._interceptors = None

self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options, self._logger)
self._maximum_timer_interval = maximum_timer_interval
Comment thread
andystaples marked this conversation as resolved.

@property
def concurrency_options(self) -> ConcurrencyOptions:
"""Get the current concurrency options for this worker."""
return self._concurrency_options

@property
def maximum_timer_interval(self) -> Optional[timedelta]:
"""Get the configured maximum timer interval for long timer chunking."""
return self._maximum_timer_interval

def __enter__(self):
return self

Expand Down Expand Up @@ -512,7 +520,11 @@ def should_invalidate_connection(rpc_error):

def stream_reader():
try:
for work_item in self._response_stream:
response_stream = self._response_stream
if response_stream is None:
return

for work_item in response_stream:
work_item_queue.put(work_item)
except Exception as e:
work_item_queue.put(e)
Expand Down Expand Up @@ -822,7 +834,11 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext):
_generator: Optional[Generator[task.Task, Any, Any]]
_previous_task: Optional[task.Task]

def __init__(self, instance_id: str, registry: _Registry):
def __init__(self,
instance_id: str,
registry: _Registry,
maximum_timer_interval: Optional[timedelta] = DEFAULT_MAXIMUM_TIMER_INTERVAL,
):
self._generator = None
self._is_replaying = True
self._is_complete = False
Expand All @@ -843,10 +859,11 @@ def __init__(self, instance_id: str, registry: _Registry):
self._version: Optional[str] = None
self._completion_status: Optional[pb.OrchestrationStatus] = None
self._received_events: dict[str, list[Any]] = {}
self._pending_events: dict[str, list[task.CompletableTask]] = {}
self._pending_events: dict[str, list[task.CancellableTask]] = {}
self._new_input: Optional[Any] = None
self._save_events = False
self._encoded_custom_status: Optional[str] = None
self._maximum_timer_interval = maximum_timer_interval

def run(self, generator: Generator[task.Task, Any, Any]):
self._generator = generator
Expand Down Expand Up @@ -1022,11 +1039,26 @@ def create_timer_internal(
) -> task.TimerTask:
Comment thread
andystaples marked this conversation as resolved.
id = self.next_sequence_number()
if isinstance(fire_at, timedelta):
fire_at = self.current_utc_datetime + fire_at
action = ph.new_create_timer_action(id, fire_at)
final_fire_at = self.current_utc_datetime + fire_at
else:
final_fire_at = fire_at

next_fire_at: datetime = final_fire_at

if self._maximum_timer_interval is not None and self.current_utc_datetime + self._maximum_timer_interval < final_fire_at:
Comment thread
andystaples marked this conversation as resolved.
Outdated
timer_task = task.LongTimerTask(final_fire_at, self._maximum_timer_interval)
next_fire_at = timer_task.start(self.current_utc_datetime)
else:
timer_task = task.TimerTask()

action = ph.new_create_timer_action(id, next_fire_at)
self._pending_actions[id] = action

timer_task: task.TimerTask = task.TimerTask()
def _cancel_timer() -> None:
self._pending_actions.pop(id, None)
self._pending_tasks.pop(id, None)

timer_task.set_cancel_handler(_cancel_timer)
if retryable_task is not None:
timer_task.set_retryable_parent(retryable_task)
self._pending_tasks[id] = timer_task
Expand Down Expand Up @@ -1234,13 +1266,13 @@ def _exit_critical_section(self) -> None:
action = pb.OrchestratorAction(id=task_id, sendEntityMessage=entity_unlock_message)
self._pending_actions[task_id] = action

def wait_for_external_event(self, name: str) -> task.CompletableTask:
def wait_for_external_event(self, name: str) -> task.CancellableTask:
# Check to see if this event has already been received, in which case we
# can return it immediately. Otherwise, record out intent to receive an
# event with the given name so that we can resume the generator when it
# arrives. If there are multiple events with the same name, we return
# them in the order they were received.
external_event_task: task.CompletableTask = task.CompletableTask()
external_event_task: task.CancellableTask = task.CancellableTask()
event_name = name.casefold()
event_list = self._received_events.get(event_name, None)
if event_list:
Expand All @@ -1254,6 +1286,19 @@ def wait_for_external_event(self, name: str) -> task.CompletableTask:
task_list = []
self._pending_events[event_name] = task_list
task_list.append(external_event_task)

def _cancel_wait() -> None:
waiting_tasks = self._pending_events.get(event_name)
if waiting_tasks is None:
return
try:
waiting_tasks.remove(external_event_task)
except ValueError:
return
if not waiting_tasks:
del self._pending_events[event_name]

external_event_task.set_cancel_handler(_cancel_wait)
return external_event_task

def continue_as_new(self, new_input, *, save_events: bool = False) -> None:
Expand Down Expand Up @@ -1288,9 +1333,13 @@ def __init__(
class _OrchestrationExecutor:
_generator: Optional[task.Orchestrator] = None

def __init__(self, registry: _Registry, logger: logging.Logger):
def __init__(self,
registry: _Registry,
logger: logging.Logger,
maximum_timer_interval: Optional[timedelta] = DEFAULT_MAXIMUM_TIMER_INTERVAL):
self._registry = registry
self._logger = logger
self._maximum_timer_interval = maximum_timer_interval
self._is_suspended = False
self._suspended_events: list[pb.HistoryEvent] = []

Expand All @@ -1314,7 +1363,11 @@ def execute(
"The new history event list must have at least one event in it."
)

ctx = _RuntimeOrchestrationContext(instance_id, self._registry)
ctx = _RuntimeOrchestrationContext(
instance_id,
self._registry,
maximum_timer_interval=self._maximum_timer_interval,
)
try:
# Rebuild local state by replaying old history into the orchestrator function
self._logger.debug(
Expand Down Expand Up @@ -1450,27 +1503,46 @@ def process_event(
f"{ctx.instance_id}: Ignoring unexpected timerFired event with ID = {timer_id}."
)
return
timer_task.complete(None)
if timer_task._retryable_parent is not None:
activity_action = timer_task._retryable_parent._action
if not (isinstance(timer_task, task.TimerTask) or isinstance(timer_task, task.LongTimerTask)):
if not ctx._is_replaying:
self._logger.warning(
Comment thread
andystaples marked this conversation as resolved.
f"{ctx.instance_id}: Ignoring timerFired event with non-timer task ID = {timer_id}."
)
return

if not timer_task._retryable_parent._is_sub_orch:
cur_task = activity_action.scheduleTask
instance_id = None
else:
cur_task = activity_action.createSubOrchestration
instance_id = cur_task.instanceId
ctx.call_activity_function_helper(
id=activity_action.id,
activity_function=cur_task.name,
input=cur_task.input.value,
retry_policy=timer_task._retryable_parent._retry_policy,
is_sub_orch=timer_task._retryable_parent._is_sub_orch,
instance_id=instance_id,
fn_task=timer_task._retryable_parent,
)
next_fire_at = timer_task.complete(event.timerFired.fireAt.ToDatetime())
if next_fire_at is not None:
id = ctx.next_sequence_number()
new_action = ph.new_create_timer_action(id, next_fire_at)
ctx._pending_tasks[id] = timer_task
ctx._pending_actions[id] = new_action

def _cancel_timer() -> None:
ctx._pending_actions.pop(id, None)
ctx._pending_tasks.pop(id, None)

timer_task.set_cancel_handler(_cancel_timer)
else:
ctx.resume()
if timer_task._retryable_parent is not None:
activity_action = timer_task._retryable_parent._action

if not timer_task._retryable_parent._is_sub_orch:
cur_task = activity_action.scheduleTask
instance_id = None
else:
cur_task = activity_action.createSubOrchestration
instance_id = cur_task.instanceId
ctx.call_activity_function_helper(
id=activity_action.id,
activity_function=cur_task.name,
input=cur_task.input.value,
retry_policy=timer_task._retryable_parent._retry_policy,
is_sub_orch=timer_task._retryable_parent._is_sub_orch,
instance_id=instance_id,
fn_task=timer_task._retryable_parent,
)
else:
ctx.resume()
elif event.HasField("taskScheduled"):
# This history event confirms that the activity execution was successfully scheduled.
# Remove the taskScheduled event from the pending action list so we don't schedule it again.
Expand Down
Loading
Loading