66from dataclasses import dataclass
77from datetime import datetime
88from enum import Enum
9- from typing import Any , TypeVar
9+ from typing import Any , TypeVar , Union
1010
1111import grpc
1212from google .protobuf import wrappers_pb2
@@ -42,10 +42,10 @@ class OrchestrationState:
4242 runtime_status : OrchestrationStatus
4343 created_at : datetime
4444 last_updated_at : datetime
45- serialized_input : str | None
46- serialized_output : str | None
47- serialized_custom_status : str | None
48- failure_details : task .FailureDetails | None
45+ serialized_input : Union [ str , None ]
46+ serialized_output : Union [ str , None ]
47+ serialized_custom_status : Union [ str , None ]
48+ failure_details : Union [ task .FailureDetails , None ]
4949
5050 def raise_if_failed (self ):
5151 if self .failure_details is not None :
@@ -64,7 +64,7 @@ def failure_details(self):
6464 return self ._failure_details
6565
6666
67- def new_orchestration_state (instance_id : str , res : pb .GetInstanceResponse ) -> OrchestrationState | None :
67+ def new_orchestration_state (instance_id : str , res : pb .GetInstanceResponse ) -> Union [ OrchestrationState , None ] :
6868 if not res .exists :
6969 return None
7070
@@ -92,38 +92,39 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Or
9292class TaskHubGrpcClient :
9393
9494 def __init__ (self , * ,
95- host_address : str | None = None ,
95+ host_address : Union [ str , None ] = None ,
9696 log_handler = None ,
97- log_formatter : logging .Formatter | None = None ):
97+ log_formatter : Union [ logging .Formatter , None ] = None ):
9898 channel = shared .get_grpc_channel (host_address )
9999 self ._stub = stubs .TaskHubSidecarServiceStub (channel )
100100 self ._logger = shared .get_logger (log_handler , log_formatter )
101101
102- def schedule_new_orchestration (self , orchestrator : task .Orchestrator [TInput , TOutput ], * ,
103- input : TInput | None = None ,
104- instance_id : str | None = None ,
105- start_at : datetime | None = None ) -> str :
102+ def schedule_new_orchestration (self , orchestrator : Union [ task .Orchestrator [TInput , TOutput ], str ], * ,
103+ input : Union [ TInput , None ] = None ,
104+ instance_id : Union [ str , None ] = None ,
105+ start_at : Union [ datetime , None ] = None ) -> str :
106106
107- name = task .get_name (orchestrator )
107+ name = orchestrator if isinstance ( orchestrator , str ) else task .get_name (orchestrator )
108108
109109 req = pb .CreateInstanceRequest (
110110 name = name ,
111111 instanceId = instance_id if instance_id else uuid .uuid4 ().hex ,
112112 input = wrappers_pb2 .StringValue (value = shared .to_json (input )) if input else None ,
113- scheduledStartTimestamp = helpers .new_timestamp (start_at ) if start_at else None )
113+ scheduledStartTimestamp = helpers .new_timestamp (start_at ) if start_at else None ,
114+ version = wrappers_pb2 .StringValue (value = "" ))
114115
115116 self ._logger .info (f"Starting new '{ name } ' instance with ID = '{ req .instanceId } '." )
116117 res : pb .CreateInstanceResponse = self ._stub .StartInstance (req )
117118 return res .instanceId
118119
119- def get_orchestration_state (self , instance_id : str , * , fetch_payloads : bool = True ) -> OrchestrationState | None :
120+ def get_orchestration_state (self , instance_id : str , * , fetch_payloads : bool = True ) -> Union [ OrchestrationState , None ] :
120121 req = pb .GetInstanceRequest (instanceId = instance_id , getInputsAndOutputs = fetch_payloads )
121122 res : pb .GetInstanceResponse = self ._stub .GetInstance (req )
122123 return new_orchestration_state (req .instanceId , res )
123124
124125 def wait_for_orchestration_start (self , instance_id : str , * ,
125126 fetch_payloads : bool = False ,
126- timeout : int = 60 ) -> OrchestrationState | None :
127+ timeout : int = 60 ) -> Union [ OrchestrationState , None ] :
127128 req = pb .GetInstanceRequest (instanceId = instance_id , getInputsAndOutputs = fetch_payloads )
128129 try :
129130 self ._logger .info (f"Waiting up to { timeout } s for instance '{ instance_id } ' to start." )
@@ -138,7 +139,7 @@ def wait_for_orchestration_start(self, instance_id: str, *,
138139
139140 def wait_for_orchestration_completion (self , instance_id : str , * ,
140141 fetch_payloads : bool = True ,
141- timeout : int = 60 ) -> OrchestrationState | None :
142+ timeout : int = 60 ) -> Union [ OrchestrationState , None ] :
142143 req = pb .GetInstanceRequest (instanceId = instance_id , getInputsAndOutputs = fetch_payloads )
143144 try :
144145 self ._logger .info (f"Waiting { timeout } s for instance '{ instance_id } ' to complete." )
@@ -152,7 +153,7 @@ def wait_for_orchestration_completion(self, instance_id: str, *,
152153 raise
153154
154155 def raise_orchestration_event (self , instance_id : str , event_name : str , * ,
155- data : Any | None = None ):
156+ data : Union [ Any , None ] = None ):
156157 req = pb .RaiseEventRequest (
157158 instanceId = instance_id ,
158159 name = event_name ,
@@ -162,7 +163,7 @@ def raise_orchestration_event(self, instance_id: str, event_name: str, *,
162163 self ._stub .RaiseEvent (req )
163164
164165 def terminate_orchestration (self , instance_id : str , * ,
165- output : Any | None = None ):
166+ output : Union [ Any , None ] = None ):
166167 req = pb .TerminateRequest (
167168 instanceId = instance_id ,
168169 output = wrappers_pb2 .StringValue (value = shared .to_json (output )) if output else None )
0 commit comments