Changes from 6 commits
32 changes: 21 additions & 11 deletions src/openai/lib/_parsing/_completions.py
@@ -2,6 +2,7 @@

 import json
 import logging
+import weakref
 from typing import TYPE_CHECKING, Any, Iterable, cast
 from typing_extensions import TypeVar, TypeGuard, assert_never

@@ -30,6 +31,9 @@
 from ...types.chat.completion_create_params import ResponseFormat as ResponseFormatParam
 from ...types.chat.chat_completion_message_function_tool_call import Function

+# Cache of response-format params, keyed weakly by the model type
+_schema_cache: weakref.WeakKeyDictionary[type, ResponseFormatParam] = weakref.WeakKeyDictionary()
+
 ResponseFormatT = TypeVar(
     "ResponseFormatT",
     # if it isn't given then we don't do any parsing

Collaborator: Let's consider adding this in a separate PR, and in this one, only resolve the memory leak.

Author: Good call, removed.
@@ -138,7 +142,7 @@ def parse_chat_completion(

         choices.append(
             construct_type_unchecked(
-                type_=cast(Any, ParsedChoice)[solve_response_format_t(response_format)],
+                type_=ParsedChoice[ResponseFormatT],
                 value={
                     **choice.to_dict(),
                     "message": {
@@ -153,15 +157,12 @@
             )
         )

-    return cast(
-        ParsedChatCompletion[ResponseFormatT],
-        construct_type_unchecked(
-            type_=cast(Any, ParsedChatCompletion)[solve_response_format_t(response_format)],
-            value={
-                **chat_completion.to_dict(),
-                "choices": choices,
-            },
-        ),
+    return construct_type_unchecked(
+        type_=ParsedChatCompletion[ResponseFormatT],
+        value={
+            **chat_completion.to_dict(),
+            "choices": choices,
+        },
     )

@@ -284,6 +285,10 @@ def type_to_response_format_param(
     # can only be a `type`
     response_format = cast(type, response_format)

+    # Check if we already have a schema for this type in the cache
+    if response_format in _schema_cache:
+        return _schema_cache[response_format]
+
     json_schema_type: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any] | None = None

     if is_basemodel_type(response_format):
@@ -295,11 +300,16 @@
     else:
         raise TypeError(f"Unsupported response_format type - {response_format}")

-    return {
+    schema_param: ResponseFormatParam = {
         "type": "json_schema",
         "json_schema": {
             "schema": to_strict_json_schema(json_schema_type),
             "name": name,
             "strict": True,
         },
     }
+
+    # Cache the param; the WeakKeyDictionary holds the type weakly, so it stays collectable
+    _schema_cache[response_format] = schema_param
+
+    return schema_param
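
For context on the pattern: a weakref.WeakKeyDictionary holds its keys weakly, so an entry disappears as soon as the key class has no other strong references, which is why this cache cannot itself pin dynamically created models. A minimal sketch of the idea, with stand-in names rather than the SDK's internals:

import gc
import weakref

# Stand-in for _schema_cache: keys (model types) are held weakly.
_cache: "weakref.WeakKeyDictionary[type, dict]" = weakref.WeakKeyDictionary()

def cached_schema(tp: type) -> dict:
    # Stand-in for type_to_response_format_param(): build once per type.
    if tp in _cache:
        return _cache[tp]
    schema = {"title": tp.__name__}  # imagine an expensive schema build here
    _cache[tp] = schema
    return schema

DynamicModel = type("DynamicModel", (), {})  # a dynamically created class
cached_schema(DynamicModel)
assert DynamicModel in _cache

del DynamicModel
gc.collect()  # classes sit in reference cycles, so CPython needs a gc pass
assert len(_cache) == 0  # the cache did not keep the class alive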
24 changes: 10 additions & 14 deletions src/openai/lib/_parsing/_responses.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

 import json
-from typing import TYPE_CHECKING, Any, List, Iterable, cast
+from typing import TYPE_CHECKING, List, Iterable, cast
 from typing_extensions import TypeVar, assert_never

 import pydantic
@@ -12,7 +12,7 @@
 from ..._compat import PYDANTIC_V1, model_parse_json
 from ..._models import construct_type_unchecked
 from .._pydantic import is_basemodel_type, is_dataclass_like_type
-from ._completions import solve_response_format_t, type_to_response_format_param
+from ._completions import type_to_response_format_param
 from ...types.responses import (
     Response,
     ToolParam,
@@ -56,7 +56,6 @@ def parse_response(
     input_tools: Iterable[ToolParam] | Omit | None,
     response: Response | ParsedResponse[object],
 ) -> ParsedResponse[TextFormatT]:
-    solved_t = solve_response_format_t(text_format)
     output_list: List[ParsedResponseOutputItem[TextFormatT]] = []

     for output in response.output:
@@ -69,7 +68,7 @@ def parse_response(

             content_list.append(
                 construct_type_unchecked(
-                    type_=cast(Any, ParsedResponseOutputText)[solved_t],
+                    type_=ParsedResponseOutputText[TextFormatT],
                     value={
                         **item.to_dict(),
                         "parsed": parse_text(item.text, text_format=text_format),
@@ -79,7 +78,7 @@ def parse_response(

             output_list.append(
                 construct_type_unchecked(
-                    type_=cast(Any, ParsedResponseOutputMessage)[solved_t],
+                    type_=ParsedResponseOutputMessage[TextFormatT],
                     value={
                         **output.to_dict(),
                         "content": content_list,
@@ -118,15 +117,12 @@ def parse_response(
         else:
             output_list.append(output)

-    return cast(
-        ParsedResponse[TextFormatT],
-        construct_type_unchecked(
-            type_=cast(Any, ParsedResponse)[solved_t],
-            value={
-                **response.to_dict(),
-                "output": output_list,
-            },
-        ),
+    return construct_type_unchecked(
+        type_=ParsedResponse[TextFormatT],
+        value={
+            **response.to_dict(),
+            "output": output_list,
+        },
     )

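For context, the leak pattern being removed across these files: subscripting a generic class with a concrete type at runtime, as cast(Any, ParsedResponse)[solved_t] did, creates a parametrized alias that CPython's typing machinery caches keyed by the argument, so each dynamically created model stays reachable from that cache. Subscripting with the unsolved type variable yields one shared alias and allocates nothing per call. A rough self-contained illustration with plain typing generics, assuming CPython's parametrization cache rather than the SDK's exact internals:

import gc
import weakref
from typing import Generic, TypeVar

T = TypeVar("T")

class Parsed(Generic[T]):
    pass

refs = []
for i in range(3):
    Dynamic = type(f"Dynamic{i}", (), {})  # a fresh model type per "request"
    refs.append(weakref.ref(Dynamic))
    _ = Parsed[Dynamic]  # runtime parametrization is cached and pins Dynamic
    del Dynamic

gc.collect()
print(sum(r() is not None for r in refs))  # typically 3: the typing cache pins them

# Parametrizing with the type variable itself reuses a single cached alias:
assert Parsed[T] is Parsed[T]

Pydantic generic models layer their own per-parametrization class creation and caching on top of this, which is why the fix stops solving the type variable at parse time.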
11 changes: 2 additions & 9 deletions src/openai/lib/streaming/chat/_completions.py
@@ -2,7 +2,7 @@

 import inspect
 from types import TracebackType
-from typing import TYPE_CHECKING, Any, Generic, Callable, Iterable, Awaitable, AsyncIterator, cast
+from typing import TYPE_CHECKING, Generic, Callable, Iterable, Awaitable, AsyncIterator, cast
 from typing_extensions import Self, Iterator, assert_never

 from jiter import from_json
@@ -33,7 +33,6 @@
     maybe_parse_content,
     parse_chat_completion,
     get_input_tool_by_name,
-    solve_response_format_t,
     parse_function_tool_arguments,
 )
 from ...._streaming import Stream, AsyncStream
@@ -658,13 +657,7 @@ def _content_done_events(

             events_to_fire.append(
                 build(
-                    # we do this dance so that when the `ContentDoneEvent` instance
-                    # is printed at runtime the class name will include the solved
-                    # type variable, e.g. `ContentDoneEvent[MyModelType]`
-                    cast(  # pyright: ignore[reportUnnecessaryCast]
-                        "type[ContentDoneEvent[ResponseFormatT]]",
-                        cast(Any, ContentDoneEvent)[solve_response_format_t(response_format)],
-                    ),
+                    ContentDoneEvent[ResponseFormatT],
                     type="content.done",
                     content=choice_snapshot.message.content,
                     parsed=parsed,
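The deleted comment explains what that cast dance bought: instances printed with the solved class name, e.g. ContentDoneEvent[MyModelType]. With pydantic v2 semantics, subscripting a generic model by its own type variable is an identity operation, so the fix trades the prettier runtime name for zero per-call class creation. A hypothetical sketch of the difference (the event and model names here are illustrative, not the SDK's classes):

from typing import Generic, Optional, TypeVar

import pydantic

T = TypeVar("T")

class DoneEvent(pydantic.BaseModel, Generic[T]):
    parsed: Optional[T] = None

class MyModel(pydantic.BaseModel):
    answer: str

# Concrete parametrization builds a brand-new model class with a nice name...
print(DoneEvent[MyModel].__name__)  # DoneEvent[MyModel]
# ...while identity parametrization creates nothing new:
assert DoneEvent[T] is DoneEvent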
50 changes: 50 additions & 0 deletions tests/lib/_parsing/test_memory_leak.py

Collaborator: Let's remove this file too—the test doesn't seem to work. Plus, memory leak tests in pytest are super flaky: memory can grow because of pytest itself. So we can also consider not adding a test at all.

Author: Makes sense, removed.

@@ -0,0 +1,50 @@
+import gc
+from typing import List
+
+import pytest
+from pydantic import Field, create_model
+
+from openai.lib._parsing import type_to_response_format_param
+from openai.lib._parsing._completions import _schema_cache
+
+
+@pytest.mark.asyncio
+async def test_async_completions_parse_memory() -> None:
+    """Test that dynamically created response_format models aren't kept alive by the schema cache."""
+    # Create a base step model
+    StepModel = create_model(
+        "Step",
+        explanation=(str, Field()),
+        output=(str, Field()),
+    )
+
+    # Clear the cache before testing
+    _schema_cache.clear()
+
+    # Simulate the issue by creating multiple models and making calls
+    models: list[type] = []
+    for i in range(10):
+        # Create a new dynamic model each time
+        new_model = create_model(
+            f"MathResponse{i}",
+            steps=(List[StepModel], Field()),  # type: ignore[valid-type]
+            final_answer=(str, Field()),
+        )
+        models.append(new_model)
+
+        # Convert to response format and check if it's in the cache
+        type_to_response_format_param(new_model)
+        assert new_model in _schema_cache
+
+    # Record cache size with all models referenced
+    cache_size_with_references = len(_schema_cache)
+
+    # Let the models go out of scope and trigger garbage collection
+    del models
+    gc.collect()
+
+    # After garbage collection, the cache should be significantly reduced
+    cache_size_after_gc = len(_schema_cache)
+    assert cache_size_after_gc < cache_size_with_references
+    # The cache size should be close to the initial size (with some tolerance)
+    assert cache_size_after_gc < cache_size_with_references / 2
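
If a regression test is ever revived, a shape that sidesteps the cache-size flakiness might assert collectability of one model directly with weakref.ref. This is only a sketch, under the same caveats the collaborator raises: pydantic and pytest both keep allocations of their own, so it can still be environment-sensitive.

import gc
import weakref

from pydantic import create_model

from openai.lib._parsing import type_to_response_format_param

def test_dynamic_model_is_collectable() -> None:
    model = create_model("Dynamic", answer=(str, ...))
    type_to_response_format_param(model)
    ref = weakref.ref(model)
    del model
    gc.collect()
    # The schema cache must not be what keeps the dynamic model alive.
    assert ref() is None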