Skip to content

Commit bfef59e

Browse files
authored
Anthropic beta (#497)
* add anthropic beta support * added beta support * bump version * extra char * fix failing tests
1 parent 84a94dd commit bfef59e

5 files changed

Lines changed: 1598 additions & 81 deletions

File tree

agentops/llms/anthropic.py

Lines changed: 101 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def handle_response(
2929
"""Handle responses for Anthropic"""
3030
from anthropic import Stream, AsyncStream
3131
from anthropic.resources import AsyncMessages
32+
import anthropic.resources.beta.messages.messages as beta_messages
3233
from anthropic.types import Message
3334

3435
llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)
@@ -155,6 +156,7 @@ def override(self):
155156

156157
def _override_completion(self):
157158
from anthropic.resources import messages
159+
import anthropic.resources.beta.messages.messages as beta_messages
158160
from anthropic.types import (
159161
Message,
160162
RawContentBlockDeltaEvent,
@@ -167,54 +169,64 @@ def _override_completion(self):
167169

168170
# Store the original method
169171
self.original_create = messages.Messages.create
172+
self.original_create_beta = beta_messages.Messages.create
170173

171-
def patched_function(*args, **kwargs):
172-
init_timestamp = get_ISO_time()
173-
session = kwargs.get("session", None)
174-
if "session" in kwargs.keys():
175-
del kwargs["session"]
174+
def create_patched_function(is_beta=False):
175+
def patched_function(*args, **kwargs):
176+
init_timestamp = get_ISO_time()
177+
session = kwargs.get("session", None)
178+
if "session" in kwargs.keys():
179+
del kwargs["session"]
176180

177-
completion_override = fetch_completion_override_from_time_travel_cache(
178-
kwargs
179-
)
180-
if completion_override:
181-
result_model = None
182-
pydantic_models = (
183-
Message,
184-
RawContentBlockDeltaEvent,
185-
RawContentBlockStartEvent,
186-
RawContentBlockStopEvent,
187-
RawMessageDeltaEvent,
188-
RawMessageStartEvent,
189-
RawMessageStopEvent,
181+
completion_override = fetch_completion_override_from_time_travel_cache(
182+
kwargs
190183
)
184+
if completion_override:
185+
result_model = None
186+
pydantic_models = (
187+
Message,
188+
RawContentBlockDeltaEvent,
189+
RawContentBlockStartEvent,
190+
RawContentBlockStopEvent,
191+
RawMessageDeltaEvent,
192+
RawMessageStartEvent,
193+
RawMessageStopEvent,
194+
)
191195

192-
for pydantic_model in pydantic_models:
193-
try:
194-
result_model = pydantic_model.model_validate_json(
195-
completion_override
196+
for pydantic_model in pydantic_models:
197+
try:
198+
result_model = pydantic_model.model_validate_json(
199+
completion_override
200+
)
201+
break
202+
except Exception as e:
203+
pass
204+
205+
if result_model is None:
206+
logger.error(
207+
f"Time Travel: Pydantic validation failed for {pydantic_models} \n"
208+
f"Time Travel: Completion override was:\n"
209+
f"{pprint.pformat(completion_override)}"
196210
)
197-
break
198-
except Exception as e:
199-
pass
200-
201-
if result_model is None:
202-
logger.error(
203-
f"Time Travel: Pydantic validation failed for {pydantic_models} \n"
204-
f"Time Travel: Completion override was:\n"
205-
f"{pprint.pformat(completion_override)}"
211+
return None
212+
return self.handle_response(
213+
result_model, kwargs, init_timestamp, session=session
206214
)
207-
return None
215+
216+
# Call the original function with its original arguments
217+
original_func = (
218+
self.original_create_beta if is_beta else self.original_create
219+
)
220+
result = original_func(*args, **kwargs)
208221
return self.handle_response(
209-
result_model, kwargs, init_timestamp, session=session
222+
result, kwargs, init_timestamp, session=session
210223
)
211224

212-
# Call the original function with its original arguments
213-
result = self.original_create(*args, **kwargs)
214-
return self.handle_response(result, kwargs, init_timestamp, session=session)
225+
return patched_function
215226

216-
# Override the original method with the patched one
217-
messages.Messages.create = patched_function
227+
# Override the original methods with the patched ones
228+
messages.Messages.create = create_patched_function(is_beta=False)
229+
beta_messages.Messages.create = create_patched_function(is_beta=True)
218230

219231
def _override_async_completion(self):
220232
from anthropic.resources import messages
@@ -227,58 +239,71 @@ def _override_async_completion(self):
227239
RawMessageStartEvent,
228240
RawMessageStopEvent,
229241
)
242+
import anthropic.resources.beta.messages.messages as beta_messages
230243

231244
# Store the original method
232245
self.original_create_async = messages.AsyncMessages.create
246+
self.original_create_async_beta = beta_messages.AsyncMessages.create
233247

234-
async def patched_function(*args, **kwargs):
235-
# Call the original function with its original arguments
236-
init_timestamp = get_ISO_time()
237-
session = kwargs.get("session", None)
238-
if "session" in kwargs.keys():
239-
del kwargs["session"]
248+
def create_patched_async_function(is_beta=False):
249+
async def patched_function(*args, **kwargs):
250+
init_timestamp = get_ISO_time()
251+
session = kwargs.get("session", None)
252+
if "session" in kwargs.keys():
253+
del kwargs["session"]
240254

241-
completion_override = fetch_completion_override_from_time_travel_cache(
242-
kwargs
243-
)
244-
if completion_override:
245-
result_model = None
246-
pydantic_models = (
247-
Message,
248-
RawContentBlockDeltaEvent,
249-
RawContentBlockStartEvent,
250-
RawContentBlockStopEvent,
251-
RawMessageDeltaEvent,
252-
RawMessageStartEvent,
253-
RawMessageStopEvent,
255+
completion_override = fetch_completion_override_from_time_travel_cache(
256+
kwargs
254257
)
258+
if completion_override:
259+
result_model = None
260+
pydantic_models = (
261+
Message,
262+
RawContentBlockDeltaEvent,
263+
RawContentBlockStartEvent,
264+
RawContentBlockStopEvent,
265+
RawMessageDeltaEvent,
266+
RawMessageStartEvent,
267+
RawMessageStopEvent,
268+
)
255269

256-
for pydantic_model in pydantic_models:
257-
try:
258-
result_model = pydantic_model.model_validate_json(
259-
completion_override
270+
for pydantic_model in pydantic_models:
271+
try:
272+
result_model = pydantic_model.model_validate_json(
273+
completion_override
274+
)
275+
break
276+
except Exception as e:
277+
pass
278+
279+
if result_model is None:
280+
logger.error(
281+
f"Time Travel: Pydantic validation failed for {pydantic_models} \n"
282+
f"Time Travel: Completion override was:\n"
283+
f"{pprint.pformat(completion_override)}"
260284
)
261-
break
262-
except Exception as e:
263-
pass
264-
265-
if result_model is None:
266-
logger.error(
267-
f"Time Travel: Pydantic validation failed for {pydantic_models} \n"
268-
f"Time Travel: Completion override was:\n"
269-
f"{pprint.pformat(completion_override)}"
285+
return None
286+
287+
return self.handle_response(
288+
result_model, kwargs, init_timestamp, session=session
270289
)
271-
return None
272290

291+
# Call the original function with its original arguments
292+
original_func = (
293+
self.original_create_async_beta
294+
if is_beta
295+
else self.original_create_async
296+
)
297+
result = await original_func(*args, **kwargs)
273298
return self.handle_response(
274-
result_model, kwargs, init_timestamp, session=session
299+
result, kwargs, init_timestamp, session=session
275300
)
276301

277-
result = await self.original_create_async(*args, **kwargs)
278-
return self.handle_response(result, kwargs, init_timestamp, session=session)
302+
return patched_function
279303

280-
# Override the original method with the patched one
281-
messages.AsyncMessages.create = patched_function
304+
# Override the original methods with the patched ones
305+
messages.AsyncMessages.create = create_patched_async_function(is_beta=False)
306+
beta_messages.AsyncMessages.create = create_patched_async_function(is_beta=True)
282307

283308
def undo_override(self):
284309
if self.original_create is not None and self.original_create_async is not None:

0 commit comments

Comments
 (0)