Commit f1e3865

Fix memory leak in AsyncCompletions.parse() with dynamically created models
This commit fixes a memory leak in AsyncCompletions.parse() when it is repeatedly called with Pydantic models created via create_model(). The leak occurred because the schema representation derived from each model type was retained indefinitely. The fix caches schemas in a weakref.WeakKeyDictionary, so a schema can be garbage-collected once its model type is no longer referenced elsewhere. Test cases verify that the fix prevents the leak with dynamically created models.
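The mechanism relies on how weakref.WeakKeyDictionary behaves: it keeps only weak references to its keys, so an entry vanishes once nothing else holds the key alive. A minimal standalone sketch of the idea (illustrative only, not the library code; a plain dynamically created class stands in for a Pydantic model type):

import gc
import weakref

# Cache keyed weakly by a class; the entry lives only as long as
# other code keeps a strong reference to that class.
cache: "weakref.WeakKeyDictionary[type, dict]" = weakref.WeakKeyDictionary()

Model = type("DynamicModel", (), {})  # stand-in for a create_model() class
cache[Model] = {"schema": "derived once, reused on later lookups"}
assert Model in cache

del Model     # drop the only strong reference to the class
gc.collect()  # the weakly keyed entry is collected along with it
assert len(cache) == 0

One nuance: a WeakKeyDictionary holds its values strongly and its keys weakly, so the cached schema dict is freed when, and only when, its model class is garbage-collected.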
Parent: 3e69750

2 files changed: 115 additions & 1 deletion

src/openai/lib/_parsing/_completions.py

Lines changed: 14 additions & 1 deletion
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import json
+import weakref
 from typing import TYPE_CHECKING, Any, Iterable, cast
 from typing_extensions import TypeVar, TypeGuard, assert_never

@@ -28,6 +29,9 @@
 from ...types.chat.completion_create_params import ResponseFormat as ResponseFormatParam
 from ...types.chat.chat_completion_message_tool_call import Function

+# Cache to store weak references to schema objects
+_schema_cache = weakref.WeakKeyDictionary()
+
 ResponseFormatT = TypeVar(
     "ResponseFormatT",
     # if it isn't given then we don't do any parsing
@@ -243,6 +247,10 @@ def type_to_response_format_param(
     # can only be a `type`
     response_format = cast(type, response_format)

+    # Check if we already have a schema for this type in the cache
+    if response_format in _schema_cache:
+        return _schema_cache[response_format]
+
     json_schema_type: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any] | None = None

     if is_basemodel_type(response_format):
@@ -254,11 +262,16 @@ def type_to_response_format_param(
     else:
         raise TypeError(f"Unsupported response_format type - {response_format}")

-    return {
+    schema_param = {
         "type": "json_schema",
         "json_schema": {
             "schema": to_strict_json_schema(json_schema_type),
             "name": name,
             "strict": True,
         },
     }
+
+    # Store a weak reference to the schema parameter
+    _schema_cache[response_format] = schema_param
+
+    return schema_param
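With the cache in place, converting the same model class twice should be a cache hit that returns the stored param object rather than re-deriving the JSON schema. A brief usage sketch (hedged; the import path is the one the new tests below use):

from pydantic import create_model
from openai.lib._parsing import type_to_response_format_param

Response = create_model("Response", answer=(str, ...))

first = type_to_response_format_param(Response)
second = type_to_response_format_param(Response)

# The second call hits the WeakKeyDictionary: same object, no re-derivation.
assert first is second

Because entries are keyed by class identity, two structurally identical classes from separate create_model() calls still produce separate cache entries; the saving applies to repeated calls with the same class object.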
New test file

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+import unittest
+import gc
+import sys
+from unittest.mock import AsyncMock, patch, MagicMock
+from typing import List
+
+import pytest
+from pydantic import Field, create_model
+
+from openai.resources.beta.chat.completions import AsyncCompletions
+from openai.lib._parsing import type_to_response_format_param
+from openai.lib._parsing._completions import _schema_cache
+
+class TestMemoryLeak(unittest.TestCase):
+    def setUp(self):
+        # Clear the schema cache before each test
+        _schema_cache.clear()
+
+    def test_schema_cache_with_models(self):
+        """Test if schema cache properly handles dynamic models and prevents memory leak"""
+
+        StepModel = create_model(
+            "Step",
+            explanation=(str, Field()),
+            output=(str, Field()),
+        )
+
+        # Create several models and ensure they're cached properly
+        models = []
+        for i in range(5):
+            model = create_model(
+                f"MathResponse{i}",
+                steps=(List[StepModel], Field()),
+                final_answer=(str, Field()),
+            )
+            models.append(model)
+
+            # Convert model to response format param
+            param = type_to_response_format_param(model)
+
+            # Check if the model is in the cache
+            self.assertIn(model, _schema_cache)
+
+        # Test that all models are in the cache
+        self.assertEqual(len(_schema_cache), 5)
+
+        # Let the models go out of scope and trigger garbage collection
+        models = None
+        gc.collect()
+
+        # After garbage collection, the cache should be empty or reduced
+        # since we're using weakref.WeakKeyDictionary
+        self.assertLess(len(_schema_cache), 5)
+
+@pytest.mark.asyncio
+async def test_async_completions_parse_memory():
+    """Test if AsyncCompletions.parse() doesn't leak memory with dynamic models"""
+    StepModel = create_model(
+        "Step",
+        explanation=(str, Field()),
+        output=(str, Field()),
+    )
+
+    # Clear the cache and record initial state
+    _schema_cache.clear()
+    initial_cache_size = len(_schema_cache)
+
+    # Create a mock client
+    mock_client = MagicMock()
+    mock_client.chat.completions.create = AsyncMock()
+
+    # Create the AsyncCompletions instance with our mock client
+    completions = AsyncCompletions(mock_client)
+
+    # Simulate the issue by creating multiple models and making calls
+    models = []
+    for i in range(10):
+        # Create a new dynamic model each time
+        new_model = create_model(
+            f"MathResponse{i}",
+            steps=(List[StepModel], Field()),
+            final_answer=(str, Field()),
+        )
+        models.append(new_model)
+
+        # Convert to response format and check if it's in the cache
+        type_to_response_format_param(new_model)
+        assert new_model in _schema_cache
+
+    # Record cache size with all models referenced
+    cache_size_with_references = len(_schema_cache)
+
+    # Let the models go out of scope and trigger garbage collection
+    models = None
+    gc.collect()
+
+    # After garbage collection, the cache should be significantly reduced
+    cache_size_after_gc = len(_schema_cache)
+    assert cache_size_after_gc < cache_size_with_references
+    # The cache size should be close to the initial size (with some tolerance)
+    assert cache_size_after_gc < cache_size_with_references / 2
