diff --git a/src/openapi_parser/parser.py b/src/openapi_parser/parser.py index d28f33d..a0a7476 100644 --- a/src/openapi_parser/parser.py +++ b/src/openapi_parser/parser.py @@ -182,6 +182,7 @@ def parse( uri: str | None = None, spec_string: str | None = None, strict_enum: bool = True, + recursion_limit: int = 1, ) -> Specification: """Parse specification document by URL/filepath or as a string. @@ -191,8 +192,9 @@ def parse( strict_enum (bool): Validate content types and string formats against the enums defined in openapi-parser. Note that the OpenAPI specification allows for custom values in these properties. + recursion_limit (int): Maximum recursion depth for resolving references """ - resolver = OpenAPIResolver(uri, spec_string) + resolver = OpenAPIResolver(uri, spec_string, recursion_limit=recursion_limit) specification = resolver.resolve() parser = _create_parser(strict_enum=strict_enum) diff --git a/src/openapi_parser/resolver.py b/src/openapi_parser/resolver.py index df56902..eca980c 100644 --- a/src/openapi_parser/resolver.py +++ b/src/openapi_parser/resolver.py @@ -12,17 +12,38 @@ logger = logging.getLogger(__name__) +def _default_recursion_limit_handler( + limit: int, + parsed_url: Any, + _recursions: tuple[Any, ...] = (), +) -> dict[str, str]: + """Log warning and return minimal schema for circular reference.""" + logger.warning( + "Recursion limit of %d reached at %s. " + "Replacing circular reference with placeholder schema.", + limit, + str(parsed_url), + ) + return {"type": "object"} + + class OpenAPIResolver: """Resolves and validates OpenAPI specs using prance.""" _resolver: prance.ResolvingParser - def __init__(self, uri: str | None, spec_string: str | None = None) -> None: + def __init__( + self, + uri: str | None, + spec_string: str | None = None, + recursion_limit: int = 1, + ) -> None: """Initialize resolver. Args: uri: Path or URL to the spec file spec_string: Raw spec string as alternative to uri + recursion_limit: Maximum recursion depth for resolving references """ self._resolver = prance.ResolvingParser( uri, @@ -30,6 +51,8 @@ def __init__(self, uri: str | None, spec_string: str | None = None) -> None: backend=OPENAPI_SPEC_VALIDATOR, strict=False, lazy=True, + recursion_limit=recursion_limit, + recursion_limit_handler=_default_recursion_limit_handler, ) def resolve(self) -> dict[str, Any]: diff --git a/tests/data/recursive.yml b/tests/data/recursive.yml new file mode 100644 index 0000000..cf08afa --- /dev/null +++ b/tests/data/recursive.yml @@ -0,0 +1,40 @@ +openapi: 3.0.0 + +info: + title: Recursive schema test + version: 1.0.0 + +paths: + /test: + get: + summary: Test endpoint + operationId: Test + responses: + 200: + description: OK + +components: + schemas: + Equipment: + title: Equipment + type: object + properties: + Features: + type: array + items: + $ref: '#/components/schemas/Feature' + Id: + type: integer + format: int64 + + Feature: + title: Feature + type: object + properties: + Equipments: + type: array + items: + $ref: '#/components/schemas/Equipment' + Id: + type: integer + format: int64 diff --git a/tests/test_resolver.py b/tests/test_resolver.py index 8f9f3af..ad98de6 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -4,7 +4,10 @@ import pytest from openapi_parser.errors import ParserError -from openapi_parser.resolver import OpenAPIResolver +from openapi_parser.resolver import ( + OpenAPIResolver, + _default_recursion_limit_handler, +) @mock.patch("openapi_parser.resolver.prance.ResolvingParser") @@ -27,3 +30,25 @@ def test_resolve_generic_error(mock_resolving_parser: mock.MagicMock) -> None: with pytest.raises(ParserError, match="OpenAPI file parsing error"): resolver.resolve() + + +@mock.patch("openapi_parser.resolver.prance.ResolvingParser") +def test_custom_recursion_limit( + mock_resolving_parser: mock.MagicMock, +) -> None: + OpenAPIResolver("fake.yaml", recursion_limit=10) + + mock_resolving_parser.assert_called_once_with( + "fake.yaml", + spec_string=None, + backend=mock.ANY, + strict=False, + lazy=True, + recursion_limit=10, + recursion_limit_handler=_default_recursion_limit_handler, + ) + + +def test_default_recursion_limit_handler_returns_placeholder() -> None: + result = _default_recursion_limit_handler(1, "http://example.com#/test") + assert result == {"type": "object"} diff --git a/tests/test_runner.py b/tests/test_runner.py index e7eb9b6..69c144c 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -1,7 +1,8 @@ import pytest from openapi_parser import parse -from openapi_parser.specification import Specification +from openapi_parser.enumeration import DataType +from openapi_parser.specification import Array, Object, Specification from tests.openapi_fixture import create_specification @@ -14,3 +15,54 @@ def test_run_parser(swagger_specification: Specification) -> None: actual_specification = parse("tests/data/swagger.yml") assert actual_specification == swagger_specification + + +def test_parse_recursive_schema() -> None: + actual_specification = parse("tests/data/recursive.yml") + + assert actual_specification.version == "3.0.0" + assert actual_specification.info.title == "Recursive schema test" + assert "Equipment" in actual_specification.schemas + assert "Feature" in actual_specification.schemas + + +def test_parse_recursive_schema_with_recursion_limit_2() -> None: + spec = parse("tests/data/recursive.yml", recursion_limit=2) + + equipment = spec.schemas["Equipment"] + assert isinstance(equipment, Object) + + features = equipment.properties[0] + assert features.name == "Features" + assert isinstance(features.schema, Array) + + feature_level_1 = features.schema.items + assert isinstance(feature_level_1, Object) + + equipment_level_2_schema = feature_level_1.properties[0].schema + assert isinstance(equipment_level_2_schema, Array) + equipment_level_2 = equipment_level_2_schema.items + assert isinstance(equipment_level_2, Object) + assert equipment_level_2.type == DataType.OBJECT + assert len(equipment_level_2.properties) == 2 + assert equipment_level_2.properties[0].name == "Features" + + feature_level_3_schema = equipment_level_2.properties[0].schema + assert isinstance(feature_level_3_schema, Array) + feature_level_3 = feature_level_3_schema.items + assert isinstance(feature_level_3, Object) + assert len(feature_level_3.properties) == 2 + assert feature_level_3.properties[0].name == "Equipments" + + equipment_level_4_schema = feature_level_3.properties[0].schema + assert isinstance(equipment_level_4_schema, Array) + equipment_level_4 = equipment_level_4_schema.items + assert isinstance(equipment_level_4, Object) + assert len(equipment_level_4.properties) == 2 + assert equipment_level_4.properties[0].name == "Features" + + placeholder_schema = equipment_level_4.properties[0].schema + assert isinstance(placeholder_schema, Array) + placeholder = placeholder_schema.items + assert isinstance(placeholder, Object) + assert len(placeholder.properties) == 0