|
| 1 | +from dataclasses import dataclass |
| 2 | +from datetime import datetime |
| 3 | +import logging |
| 4 | +from typing import Any |
| 5 | +import uuid |
| 6 | +import grpc |
| 7 | +from durabletask.api.state import OrchestrationState, new_orchestration_state |
| 8 | +import durabletask.protos.orchestrator_service_pb2 as pb |
| 9 | +import durabletask.internal.shared as shared |
| 10 | +import simplejson as json |
| 11 | + |
| 12 | +from google.protobuf import timestamp_pb2, wrappers_pb2 |
| 13 | + |
| 14 | +from durabletask.protos.orchestrator_service_pb2_grpc import TaskHubSidecarServiceStub |
| 15 | + |
| 16 | + |
| 17 | +class TaskHubGrpcClient: |
| 18 | + |
| 19 | + def __init__(self, *, |
| 20 | + host_address: str | None = None, |
| 21 | + log_handler=None, |
| 22 | + log_formatter: logging.Formatter | None = None): |
| 23 | + channel = shared.get_grpc_channel(host_address) |
| 24 | + self._stub = TaskHubSidecarServiceStub(channel) |
| 25 | + self._logger = shared.get_logger(log_handler, log_formatter) |
| 26 | + |
| 27 | + def schedule_new_orchestration(self, name: str, *, |
| 28 | + input: Any = None, |
| 29 | + instance_id: str | None = None, |
| 30 | + start_at: datetime | None = None) -> str: |
| 31 | + req = pb.CreateInstanceRequest(name=name) |
| 32 | + if instance_id is None: |
| 33 | + instance_id = uuid.uuid4().hex |
| 34 | + req.instanceId = instance_id |
| 35 | + |
| 36 | + if input is not None: |
| 37 | + json_input = json.dumps(input) |
| 38 | + req.input = wrappers_pb2.StringValue(value=json_input) |
| 39 | + |
| 40 | + if start_at is not None: |
| 41 | + req.scheduledStartTimestamp = timestamp_pb2.Timestamp() |
| 42 | + req.scheduledStartTimestamp.FromDatetime(start_at) |
| 43 | + self._logger.info(f"Starting new '{name}' instance with ID = '{instance_id}'.") |
| 44 | + res: pb.CreateInstanceResponse = self._stub.StartInstance(req) |
| 45 | + return res.instanceId |
| 46 | + |
| 47 | + def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = True) -> OrchestrationState | None: |
| 48 | + req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) |
| 49 | + res: pb.GetInstanceResponse = self._stub.GetInstance(req) |
| 50 | + return new_orchestration_state(req.instanceId, res) |
| 51 | + |
| 52 | + def wait_for_orchestration_start(self, instance_id: str, *, |
| 53 | + fetch_payloads: bool = False, |
| 54 | + timeout: int = 60) -> OrchestrationState | None: |
| 55 | + req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) |
| 56 | + try: |
| 57 | + self._logger.info(f"Waiting {timeout}s for instance '{instance_id}' to start.") |
| 58 | + res: pb.GetInstanceResponse = self._stub.WaitForInstanceStart(req, timeout=timeout) |
| 59 | + return new_orchestration_state(req.instanceId, res) |
| 60 | + except grpc.RpcError as rpc_error: |
| 61 | + if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore |
| 62 | + # Replace gRPC error with the built-in TimeoutError |
| 63 | + raise TimeoutError("Timed-out waiting for the orchestration to start") |
| 64 | + else: |
| 65 | + raise |
| 66 | + |
| 67 | + def wait_for_orchestration_completion(self, instance_id: str, *, |
| 68 | + fetch_payloads: bool = True, |
| 69 | + timeout: int = 60) -> OrchestrationState | None: |
| 70 | + req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) |
| 71 | + try: |
| 72 | + self._logger.info(f"Waiting {timeout}s for instance '{instance_id}' to complete.") |
| 73 | + res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion(req, timeout=timeout) |
| 74 | + return new_orchestration_state(req.instanceId, res) |
| 75 | + except grpc.RpcError as rpc_error: |
| 76 | + if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore |
| 77 | + # Replace gRPC error with the built-in TimeoutError |
| 78 | + raise TimeoutError("Timed-out waiting for the orchestration to complete") |
| 79 | + else: |
| 80 | + raise |
| 81 | + |
| 82 | + def terminate_orchestration(self): |
| 83 | + pass |
| 84 | + |
| 85 | + def suspend_orchestration(self): |
| 86 | + pass |
| 87 | + |
| 88 | + def resume_orchestration(self): |
| 89 | + pass |
| 90 | + |
| 91 | + def raise_orchestration_event(self): |
| 92 | + pass |
0 commit comments