|
| 1 | +import pytest |
| 2 | + |
| 3 | +from docstub._analysis import KnownImport, StaticInspector, common_known_imports |
| 4 | +from docstub._docstrings import DoctypeTransformer, _lark |
| 5 | + |
| 6 | + |
| 7 | +@pytest.fixture() |
| 8 | +def transformer(): |
| 9 | + inspector = StaticInspector(known_imports=common_known_imports()) |
| 10 | + transformer = DoctypeTransformer(inspector=inspector, replace_doctypes={}) |
| 11 | + return transformer |
| 12 | + |
| 13 | + |
| 14 | +class Test_DoctypeTransformer: |
| 15 | + # fmt: off |
| 16 | + @pytest.mark.parametrize( |
| 17 | + ("raw", "expected"), |
| 18 | + [ |
| 19 | + ("list[float]", "list[float]"), |
| 20 | + ("dict[str, Union[int, str]]", "dict[str, Union[int, str]]"), |
| 21 | + ("tuple[int, ...]", "tuple[int, ...]"), |
| 22 | +
|
| 23 | + ("list of int", "list[int]"), |
| 24 | + ("tuple of float", "tuple[float]"), |
| 25 | + ("tuple of (float, ...)", "tuple[float, ...]"), |
| 26 | +
|
| 27 | + ("Sequence[int | float]", "Sequence[int | float]"), |
| 28 | +
|
| 29 | + ("dict of {str: int}", "dict[str, int]"), |
| 30 | + ], |
| 31 | + ) |
| 32 | + def test_container(self, raw, expected, transformer): |
| 33 | + tree = _lark.parse(raw) |
| 34 | + annotation = transformer.transform(tree) |
| 35 | + |
| 36 | + assert annotation.value == expected |
| 37 | + # fmt: on |
| 38 | + |
| 39 | + @pytest.mark.parametrize( |
| 40 | + ("raw", "expected"), |
| 41 | + [ |
| 42 | + ("{'a', 1, None, False}", "Literal['a', 1, None, False]"), |
| 43 | + ("dict[{'a', 'b'}, int]", "dict[Literal['a', 'b'], int]"), |
| 44 | + ], |
| 45 | + ) |
| 46 | + def test_literals(self, raw, expected, transformer): |
| 47 | + tree = _lark.parse(raw) |
| 48 | + annotation = transformer.transform(tree) |
| 49 | + |
| 50 | + assert annotation.value == expected |
| 51 | + assert annotation.imports == frozenset( |
| 52 | + {KnownImport(import_path="typing", import_name="Literal")} |
| 53 | + ) |
| 54 | + |
| 55 | + @pytest.mark.parametrize( |
| 56 | + ("raw", "expected"), |
| 57 | + [ |
| 58 | + ("int, optional", "int | None"), |
| 59 | + # None isn't appended, since the type should cover the default |
| 60 | + ("int, default 1", "int"), |
| 61 | + ("int, default = 1", "int"), |
| 62 | + ("int, default: 1", "int"), |
| 63 | + ], |
| 64 | + ) |
| 65 | + @pytest.mark.parametrize("extra_info", [None, "int", ", extra, info"]) |
| 66 | + def test_optional_extra_info(self, raw, expected, extra_info, transformer): |
| 67 | + doctype = raw |
| 68 | + if extra_info: |
| 69 | + doctype = f"{doctype}, {extra_info}" |
| 70 | + |
| 71 | + tree = _lark.parse(doctype) |
| 72 | + annotation = transformer.transform(tree) |
| 73 | + |
| 74 | + assert annotation.value == expected |
| 75 | + |
| 76 | + # fmt: off |
| 77 | + @pytest.mark.parametrize( |
| 78 | + ("fmt", "expected_fmt"), |
| 79 | + [ |
| 80 | + ("{shape} {name}", "{name}"), |
| 81 | + ("{shape} {name} of {dtype}", "{name}[{dtype}]"), |
| 82 | + ("{shape} {dtype} {name}", "{name}[{dtype}]"), |
| 83 | + ("{dtype} {name}", "{name}[{dtype}]"), |
| 84 | + ("{name} of shape {shape} and dtype {dtype}", "{name}[{dtype}]"), |
| 85 | + ("{name} of dtype {dtype} and shape {shape}", "{name}[{dtype}]"), |
| 86 | + ], |
| 87 | + ) |
| 88 | + @pytest.mark.parametrize("name", ["array", "ndarray", "array-like", "array_like"]) |
| 89 | + @pytest.mark.parametrize("dtype", ["int", "np.int8", "~.foo"]) |
| 90 | + @pytest.mark.parametrize("shape", ["(2, 3)", "(N, m)", "3D", "2-D", "(N, ...)"]) |
| 91 | + def test_shape_n_dtype(self, fmt, expected_fmt, name, dtype, shape, transformer): |
| 92 | + doctype = fmt.format(name=name, dtype=dtype, shape=shape) |
| 93 | + expected = expected_fmt.format(name=name, dtype=dtype, shape=shape) |
| 94 | + |
| 95 | + tree = _lark.parse(doctype) |
| 96 | + annotation = transformer.transform(tree) |
| 97 | + |
| 98 | + assert annotation.value == expected |
| 99 | + # fmt: on |
0 commit comments