Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/openapi_parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def parse(
uri: str | None = None,
spec_string: str | None = None,
strict_enum: bool = True,
recursion_limit: int = 1,
) -> Specification:
"""Parse specification document by URL/filepath or as a string.

Expand All @@ -191,8 +192,9 @@ def parse(
strict_enum (bool): Validate content types and string formats against the
enums defined in openapi-parser. Note that the OpenAPI specification allows
for custom values in these properties.
recursion_limit (int): Maximum recursion depth for resolving references
"""
resolver = OpenAPIResolver(uri, spec_string)
resolver = OpenAPIResolver(uri, spec_string, recursion_limit=recursion_limit)
specification = resolver.resolve()

parser = _create_parser(strict_enum=strict_enum)
Expand Down
25 changes: 24 additions & 1 deletion src/openapi_parser/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,47 @@
logger = logging.getLogger(__name__)


def _default_recursion_limit_handler(
limit: int,
parsed_url: Any,
_recursions: tuple[Any, ...] = (),
) -> dict[str, str]:
"""Log warning and return minimal schema for circular reference."""
logger.warning(
"Recursion limit of %d reached at %s. "
"Replacing circular reference with placeholder schema.",
limit,
str(parsed_url),
)
return {"type": "object"}


class OpenAPIResolver:
"""Resolves and validates OpenAPI specs using prance."""

_resolver: prance.ResolvingParser

def __init__(self, uri: str | None, spec_string: str | None = None) -> None:
def __init__(
self,
uri: str | None,
spec_string: str | None = None,
recursion_limit: int = 1,
) -> None:
"""Initialize resolver.

Args:
uri: Path or URL to the spec file
spec_string: Raw spec string as alternative to uri
recursion_limit: Maximum recursion depth for resolving references
"""
self._resolver = prance.ResolvingParser(
uri,
spec_string=spec_string,
backend=OPENAPI_SPEC_VALIDATOR,
strict=False,
lazy=True,
recursion_limit=recursion_limit,
recursion_limit_handler=_default_recursion_limit_handler,
)

def resolve(self) -> dict[str, Any]:
Expand Down
40 changes: 40 additions & 0 deletions tests/data/recursive.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
openapi: 3.0.0

info:
title: Recursive schema test
version: 1.0.0

paths:
/test:
get:
summary: Test endpoint
operationId: Test
responses:
200:
description: OK

components:
schemas:
Equipment:
title: Equipment
type: object
properties:
Features:
type: array
items:
$ref: '#/components/schemas/Feature'
Id:
type: integer
format: int64

Feature:
title: Feature
type: object
properties:
Equipments:
type: array
items:
$ref: '#/components/schemas/Equipment'
Id:
type: integer
format: int64
27 changes: 26 additions & 1 deletion tests/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import pytest

from openapi_parser.errors import ParserError
from openapi_parser.resolver import OpenAPIResolver
from openapi_parser.resolver import (
OpenAPIResolver,
_default_recursion_limit_handler,
)


@mock.patch("openapi_parser.resolver.prance.ResolvingParser")
Expand All @@ -27,3 +30,25 @@ def test_resolve_generic_error(mock_resolving_parser: mock.MagicMock) -> None:

with pytest.raises(ParserError, match="OpenAPI file parsing error"):
resolver.resolve()


@mock.patch("openapi_parser.resolver.prance.ResolvingParser")
def test_custom_recursion_limit(
mock_resolving_parser: mock.MagicMock,
) -> None:
OpenAPIResolver("fake.yaml", recursion_limit=10)

mock_resolving_parser.assert_called_once_with(
"fake.yaml",
spec_string=None,
backend=mock.ANY,
strict=False,
lazy=True,
recursion_limit=10,
recursion_limit_handler=_default_recursion_limit_handler,
)


def test_default_recursion_limit_handler_returns_placeholder() -> None:
result = _default_recursion_limit_handler(1, "http://example.com#/test")
assert result == {"type": "object"}
54 changes: 53 additions & 1 deletion tests/test_runner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import pytest

from openapi_parser import parse
from openapi_parser.specification import Specification
from openapi_parser.enumeration import DataType
from openapi_parser.specification import Array, Object, Specification
from tests.openapi_fixture import create_specification


Expand All @@ -14,3 +15,54 @@ def test_run_parser(swagger_specification: Specification) -> None:
actual_specification = parse("tests/data/swagger.yml")

assert actual_specification == swagger_specification


def test_parse_recursive_schema() -> None:
actual_specification = parse("tests/data/recursive.yml")

assert actual_specification.version == "3.0.0"
assert actual_specification.info.title == "Recursive schema test"
assert "Equipment" in actual_specification.schemas
assert "Feature" in actual_specification.schemas


def test_parse_recursive_schema_with_recursion_limit_2() -> None:
spec = parse("tests/data/recursive.yml", recursion_limit=2)

equipment = spec.schemas["Equipment"]
assert isinstance(equipment, Object)

features = equipment.properties[0]
assert features.name == "Features"
assert isinstance(features.schema, Array)

feature_level_1 = features.schema.items
assert isinstance(feature_level_1, Object)

equipment_level_2_schema = feature_level_1.properties[0].schema
assert isinstance(equipment_level_2_schema, Array)
equipment_level_2 = equipment_level_2_schema.items
assert isinstance(equipment_level_2, Object)
assert equipment_level_2.type == DataType.OBJECT
assert len(equipment_level_2.properties) == 2
assert equipment_level_2.properties[0].name == "Features"

feature_level_3_schema = equipment_level_2.properties[0].schema
assert isinstance(feature_level_3_schema, Array)
feature_level_3 = feature_level_3_schema.items
assert isinstance(feature_level_3, Object)
assert len(feature_level_3.properties) == 2
assert feature_level_3.properties[0].name == "Equipments"

equipment_level_4_schema = feature_level_3.properties[0].schema
assert isinstance(equipment_level_4_schema, Array)
equipment_level_4 = equipment_level_4_schema.items
assert isinstance(equipment_level_4, Object)
assert len(equipment_level_4.properties) == 2
assert equipment_level_4.properties[0].name == "Features"

placeholder_schema = equipment_level_4.properties[0].schema
assert isinstance(placeholder_schema, Array)
placeholder = placeholder_schema.items
assert isinstance(placeholder, Object)
assert len(placeholder.properties) == 0
Loading