Skip to content

Commit e0e6060

Browse files
committed
Split docname into KnownImport and replace
Reduce responsibility of the former DocName class. Replacing docstring specific type description should be handled separately.
1 parent 87ce931 commit e0e6060

4 files changed

Lines changed: 116 additions & 15 deletions

File tree

src/docstub/_docstrings.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -266,19 +266,21 @@ def contains(self, tree):
266266
return out
267267

268268
def literals(self, tree):
269-
out = " , ".join(tree.children)
269+
out = ", ".join(tree.children)
270270
out = f"Literal[{out}]"
271271
_, known_import = self.inspector.query("Literal")
272-
self._collected_imports.add(known_import)
272+
if known_import:
273+
self._collected_imports.add(known_import)
273274
return out
274275

275276
def _find_import(self, qualname):
276277
"""Match type names to known imports."""
277278
try:
278-
qualname, known_import = self.inspector.query(qualname)
279-
if known_import:
280-
if known_import.has_import:
281-
self._collected_imports.add(known_import)
279+
annotation_name, known_import = self.inspector.query(qualname)
280+
if known_import and known_import.has_import:
281+
self._collected_imports.add(known_import)
282+
if annotation_name:
283+
qualname = annotation_name
282284
else:
283285
logger.warning(
284286
"unknown import for %r in %s",

src/docstub/doctype.lark

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
?start : annotation
22

3-
annotation : (literals | types_or) ("," optional)? ("," extra_info)?
3+
annotation : types_or ("," optional)? ("," extra_info)?
44

5-
literals : "{" literal ("," literal)* "}"
6-
7-
types_or : type (("or" | "|") type)*
5+
?types_or : type (("or" | "|") type)*
86

97
?type : qualname
108
| sphinx_ref
119
| container
1210
| shape_n_dtype
11+
| literals
1312

1413
optional : "optional"
1514
| "default" ("=" | ":")? literal
@@ -18,10 +17,9 @@ extra_info : /[^\r\n]+/
1817

1918
sphinx_ref : ":" (NAME ":")? NAME ":`" qualname "`"
2019

21-
container: qualname "[" types_or ("," types_or)* "]"
22-
| qualname "[" types_or "," PY_ELLIPSES "]"
20+
container: qualname "[" types_or ("," types_or)* ("," PY_ELLIPSES)? "]"
2321
| qualname "of" type
24-
| qualname "of" "(" types_or ("," types_or)* ")"
22+
| qualname "of" "(" types_or ("," types_or)* ("," PY_ELLIPSES)? ")"
2523
| qualname "of" "{" types_or ":" types_or "}"
2624

2725
// Name with leading dot separated path
@@ -42,14 +40,16 @@ dtype : qualname
4240
shape : "(" dim ",)"
4341
| "(" leading_optional? dim (("," dim | insert_optional))* ")"
4442
| NUMBER "-"? "D"
45-
leading_optional : "[" dim ("," dim)* ",]" -> optional
46-
insert_optional : "[," dim ("," dim)* "]" -> optional
43+
leading_optional : "[" dim ("," dim)* ",]"
44+
insert_optional : "[," dim ("," dim)* "]"
4745
?dim : NUMBER
4846
| PY_ELLIPSES
4947
| NAME
5048

5149

5250
// Python
51+
literals : "{" literal ("," literal)* "}"
52+
5353
literal : PY_ELLIPSES
5454
| STRING
5555
| NUMBER

tests/test_dev.py

Whitespace-only changes.

tests/test_docstrings.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import pytest
2+
3+
from docstub._analysis import KnownImport, StaticInspector, common_known_imports
4+
from docstub._docstrings import DoctypeTransformer, _lark
5+
6+
7+
@pytest.fixture()
8+
def transformer():
9+
inspector = StaticInspector(known_imports=common_known_imports())
10+
transformer = DoctypeTransformer(inspector=inspector, replace_doctypes={})
11+
return transformer
12+
13+
14+
class Test_DoctypeTransformer:
15+
# fmt: off
16+
@pytest.mark.parametrize(
17+
("raw", "expected"),
18+
[
19+
("list[float]", "list[float]"),
20+
("dict[str, Union[int, str]]", "dict[str, Union[int, str]]"),
21+
("tuple[int, ...]", "tuple[int, ...]"),
22+
23+
("list of int", "list[int]"),
24+
("tuple of float", "tuple[float]"),
25+
("tuple of (float, ...)", "tuple[float, ...]"),
26+
27+
("Sequence[int | float]", "Sequence[int | float]"),
28+
29+
("dict of {str: int}", "dict[str, int]"),
30+
],
31+
)
32+
def test_container(self, raw, expected, transformer):
33+
tree = _lark.parse(raw)
34+
annotation = transformer.transform(tree)
35+
36+
assert annotation.value == expected
37+
# fmt: on
38+
39+
@pytest.mark.parametrize(
40+
("raw", "expected"),
41+
[
42+
("{'a', 1, None, False}", "Literal['a', 1, None, False]"),
43+
("dict[{'a', 'b'}, int]", "dict[Literal['a', 'b'], int]"),
44+
],
45+
)
46+
def test_literals(self, raw, expected, transformer):
47+
tree = _lark.parse(raw)
48+
annotation = transformer.transform(tree)
49+
50+
assert annotation.value == expected
51+
assert annotation.imports == frozenset(
52+
{KnownImport(import_path="typing", import_name="Literal")}
53+
)
54+
55+
@pytest.mark.parametrize(
56+
("raw", "expected"),
57+
[
58+
("int, optional", "int | None"),
59+
# None isn't appended, since the type should cover the default
60+
("int, default 1", "int"),
61+
("int, default = 1", "int"),
62+
("int, default: 1", "int"),
63+
],
64+
)
65+
@pytest.mark.parametrize("extra_info", [None, "int", ", extra, info"])
66+
def test_optional_extra_info(self, raw, expected, extra_info, transformer):
67+
doctype = raw
68+
if extra_info:
69+
doctype = f"{doctype}, {extra_info}"
70+
71+
tree = _lark.parse(doctype)
72+
annotation = transformer.transform(tree)
73+
74+
assert annotation.value == expected
75+
76+
# fmt: off
77+
@pytest.mark.parametrize(
78+
("fmt", "expected_fmt"),
79+
[
80+
("{shape} {name}", "{name}"),
81+
("{shape} {name} of {dtype}", "{name}[{dtype}]"),
82+
("{shape} {dtype} {name}", "{name}[{dtype}]"),
83+
("{dtype} {name}", "{name}[{dtype}]"),
84+
("{name} of shape {shape} and dtype {dtype}", "{name}[{dtype}]"),
85+
("{name} of dtype {dtype} and shape {shape}", "{name}[{dtype}]"),
86+
],
87+
)
88+
@pytest.mark.parametrize("name", ["array", "ndarray", "array-like", "array_like"])
89+
@pytest.mark.parametrize("dtype", ["int", "np.int8", "~.foo"])
90+
@pytest.mark.parametrize("shape", ["(2, 3)", "(N, m)", "3D", "2-D", "(N, ...)"])
91+
def test_shape_n_dtype(self, fmt, expected_fmt, name, dtype, shape, transformer):
92+
doctype = fmt.format(name=name, dtype=dtype, shape=shape)
93+
expected = expected_fmt.format(name=name, dtype=dtype, shape=shape)
94+
95+
tree = _lark.parse(doctype)
96+
annotation = transformer.transform(tree)
97+
98+
assert annotation.value == expected
99+
# fmt: on

0 commit comments

Comments
 (0)