from temporalio import activity
from openai import AsyncOpenAI
import braintrust
from braintrust import wrap_openai
from typing import Optional, List, cast, Any, TypeVar, Generic
from typing_extensions import Annotated
from pydantic import BaseModel
from pydantic.functional_validators import BeforeValidator
from pydantic.functional_serializers import PlainSerializer
import importlib
import os

T = TypeVar("T", bound=BaseModel)


def _coerce_class(v: Any) -> type[Any]:
    """Pydantic validator: convert string path to class during deserialization."""
    if isinstance(v, str):
        mod_path, sep, qual = v.partition(":")
        if not sep:  # support "package.module.Class"
            mod_path, _, qual = v.rpartition(".")
        module = importlib.import_module(mod_path)
        obj = module
        for attr in qual.split("."):
            obj = getattr(obj, attr)
        return cast(type[Any], obj)
    elif isinstance(v, type):
        return v
    else:
        raise ValueError(f"Cannot coerce {v} to class")


def _dump_class(t: type[Any]) -> str:
    """Pydantic serializer: convert class to string path during serialization."""
    return f"{t.__module__}:{t.__qualname__}"


# Custom type that automatically handles class <-> string conversion in Pydantic serialization
ClassReference = Annotated[
    type[T],
    BeforeValidator(_coerce_class),
    PlainSerializer(_dump_class, return_type=str),
]
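# Illustrative round trip (the module path "my_pkg.models" is hypothetical):
#
#     _dump_class(my_pkg.models.Report)        # -> "my_pkg.models:Report"
#     _coerce_class("my_pkg.models:Report")    # -> <class 'my_pkg.models.Report'>
#     _coerce_class("my_pkg.models.Report")    # dotted form is supported too
#
# This lets a JSON payload converter carry the class across activity
# boundaries as a plain string.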
class InvokeModelRequest(BaseModel, Generic[T]):
    model: str
    instructions: str  # Fallback if Braintrust prompt unavailable
    input: str
    prompt_slug: Optional[str] = None  # Braintrust prompt slug (e.g., "report-synthesis")
    response_format: Optional[ClassReference[T]] = None
    tools: Optional[List[dict]] = None
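# Example request (a sketch; "Report" and the model name are placeholders):
#
#     class Report(BaseModel):
#         title: str
#         summary: str
#
#     req = InvokeModelRequest[Report](
#         model="gpt-4o",
#         instructions="Write a short report.",
#         input="...",
#         response_format=Report,
#     )
#     req.model_dump()["response_format"]  # -> "<module>:Report", a JSON-safe string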
class InvokeModelResponse(BaseModel, Generic[T]):
    # response_format records the type of the response model
    response_format: Optional[ClassReference[T]] = None
    response_model: Any

    @property
    def response(self) -> T:
        """Reconstruct the original response type if response_format was provided."""
        if self.response_format:
            model_cls = self.response_format
            return model_cls.model_validate(self.response_model)
        return self.response_model
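# On the caller's side, the typed result is recovered from the serialized dict
# (continuing the hypothetical Report example above):
#
#     result = InvokeModelResponse[Report].model_validate(raw)  # raw = decoded payload
#     report = result.response  # Report instance, re-validated from the stored dict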
@activity.defn
async def invoke_model(request: InvokeModelRequest[T]) -> InvokeModelResponse[T]:
    instructions = request.instructions

    # Load prompt from Braintrust if a slug was provided
    if request.prompt_slug:
        try:
            prompt = braintrust.load_prompt(
                project=os.environ.get("BRAINTRUST_PROJECT", "deep-research"),
                slug=request.prompt_slug,
            )
            # Extract system message content only.
            # NOTE: Other params (temperature, max_tokens, model) are NOT used.
            built = prompt.build()
            for msg in built.get("messages", []):
                if msg.get("role") == "system":
                    instructions = msg["content"]
                    activity.logger.info(
                        f"Loaded prompt '{request.prompt_slug}' from Braintrust"
                    )
                    break
        except Exception as e:
            # Log a warning but continue with the fallback
            activity.logger.warning(
                f"Failed to load prompt '{request.prompt_slug}': {e}. "
                "Using hardcoded fallback."
            )

    # max_retries=0 disables SDK-level retries (presumably so retry behavior
    # is owned by the surrounding Temporal retry policy instead)
    client = wrap_openai(AsyncOpenAI(max_retries=0))
    kwargs: dict[str, Any] = {
        "model": request.model,
        "instructions": instructions,
        "input": request.input,
    }
    if request.response_format:
        kwargs["text_format"] = request.response_format
    if request.tools:
        kwargs["tools"] = request.tools

    # Use the Responses API consistently
    resp = await client.responses.parse(**kwargs)

    if request.response_format:
        # Convert structured response to dict for managed serialization.
        # This allows us to reconstruct the original response type while
        # maintaining type safety.
        parsed_model = cast(BaseModel, resp.output_parsed)
        return InvokeModelResponse(
            response_model=parsed_model.model_dump(),
            response_format=request.response_format,
        )
    else:
        return InvokeModelResponse(
            response_model=resp.output_text, response_format=None
        )
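# Sketch of invoking this activity from a Temporal workflow (names like
# "Report", "research_notes", and the timeout are placeholders, not part of
# this module):
#
#     from datetime import timedelta
#     from temporalio import workflow
#
#     result = await workflow.execute_activity(
#         invoke_model,
#         InvokeModelRequest[Report](
#             model="gpt-4o",
#             instructions="Write a short report.",
#             input=research_notes,
#             response_format=Report,
#         ),
#         start_to_close_timeout=timedelta(minutes=2),
#     )
#     report = result.response  # typed Report, rebuilt via ClassReference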