Skip to content

Commit 444ef8a

Browse files
committed
WIP add _doctype.py
1 parent feda8bf commit 444ef8a

1 file changed

Lines changed: 280 additions & 0 deletions

File tree

src/docstub/_doctype.py

Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
"""Parsing of doctypes"""
2+
3+
import logging
4+
from collections.abc import Iterable
5+
from dataclasses import dataclass
6+
from pathlib import Path
7+
8+
import lark
9+
import lark.visitors
10+
11+
logger = logging.getLogger(__name__)


# The grammar definition is expected in a file next to this module.
grammar_path = Path(__file__).parent / "doctype.lark"

# Decode explicitly as UTF-8 so that loading the grammar does not depend on
# the locale's preferred encoding.
with grammar_path.open(encoding="utf-8") as file:
    _grammar = file.read()

# Parser shared by this module, built once at import time.
_lark = lark.Lark(_grammar, propagate_positions=True, strict=True)
20+
21+
22+
def flatten_recursive(iterable):
    """Lazily yield the leaf items of an arbitrarily nested iterable.

    Parameters
    ----------
    iterable : Iterable
        May contain further iterables at any depth.

    Yields
    ------
    item : Any
        Every item that is a string or not iterable, in depth-first order.
        Strings nested inside `iterable` are treated as leaves and yielded
        whole.
    """
    for element in iterable:
        # Strings are iterable but must not be split up (a single character
        # is itself an iterable string, so descending would never end).
        is_leaf = isinstance(element, str) or not isinstance(element, Iterable)
        if is_leaf:
            yield element
            continue
        yield from flatten_recursive(element)
28+
29+
30+
def insert_between(iterable, *, sep):
    """Return the items of `iterable` as a list with `sep` between neighbors.

    Parameters
    ----------
    iterable : Iterable
    sep : Any
        Object placed between each pair of consecutive items. The same object
        is reused for every gap.

    Returns
    -------
    out : list
        Empty if `iterable` is empty; no leading or trailing separator.
    """
    result = []
    for position, element in enumerate(iterable):
        if position:
            # Every item but the first is preceded by a separator.
            result.append(sep)
        result.append(element)
    return result
36+
37+
38+
class Token(str):
    """A token representing an atomic part of a doctype.

    A token is a string (it compares, hashes and joins like the text it
    wraps) that additionally remembers what kind of token it is.

    Attributes
    ----------
    value : str
        The text of the token as a plain `str`.
    kind : str
        Label describing the role of the token, e.g. ``"qualname"``.
    """

    __slots__ = ("value", "kind")

    def __new__(cls, value, *, kind):
        """Create a token with the text `value` and the given `kind`.

        Parameters
        ----------
        value : str
            Text of the token. May itself be a `Token`, in which case only
            its text is taken over.
        kind : str
        """
        self = super().__new__(cls, value)
        # The `value` slot is declared above, so it has to be filled here;
        # otherwise reading `token.value` raises AttributeError.
        self.value = str(self)
        self.kind = kind
        return self

    def __repr__(self):
        # Use the string's own repr for the text so that quotes and escape
        # sequences inside the token are rendered unambiguously.
        return f"{type(self).__name__}({str.__repr__(self)}, kind={self.kind!r})"

    @classmethod
    def find_iter(cls, iterable, *, kind):
        """Yield all tokens of the given `kind` inside a nested iterable.

        Parameters
        ----------
        iterable : Iterable
            Possibly nested; it is flattened with `flatten_recursive`.
        kind : str

        Yields
        ------
        token : Token
            Instances of `cls` whose `kind` equals the requested one, in the
            order in which they appear.
        """
        for item in flatten_recursive(iterable):
            if isinstance(item, cls) and item.kind == kind:
                yield item

    @classmethod
    def find_one(cls, iterable, *, kind):
        """Return the single token of the given `kind` inside `iterable`.

        Parameters
        ----------
        iterable : Iterable
            Possibly nested; searched like in `find_iter`.
        kind : str

        Returns
        -------
        token : Token

        Raises
        ------
        ValueError
            If there is no matching token or more than one.
        """
        matching = list(cls.find_iter(iterable, kind=kind))
        if len(matching) != 1:
            msg = (
                f"expected exactly one {cls.__name__} with {kind=}, got {len(matching)}"
            )
            raise ValueError(msg)
        return matching[0]
66+
67+
68+
@lark.visitors.v_args(tree=True)
class DoctypeTransformer(lark.visitors.Transformer):
    """Transform the parse tree of a doctype into (nested) lists of `Token`.

    Each rule callback receives the `lark.Tree` of its rule and returns
    either a single `Token`, a list whose items are `Token` instances or
    further such lists, or `lark.Discard` to drop the rule from the result.
    """

    def qualname(self, tree):
        """Join the parts of a qualified name with dots.

        Parameters
        ----------
        tree : lark.Tree

        Returns
        -------
        out : Token
            Token of kind "qualname".
        """
        children = tree.children
        _qualname = ".".join(children)
        _qualname = Token(_qualname, kind="qualname")
        return _qualname

    def rst_role(self, tree):
        """Reduce a reStructuredText role to the qualname it contains.

        Parameters
        ----------
        tree : lark.Tree

        Returns
        -------
        out : Token
            The single token of kind "qualname" among the children.
        """
        qualname = Token.find_one(tree.children, kind="qualname")
        return qualname

    def union(self, tree):
        """Interleave the members of a union with " | " separators.

        Parameters
        ----------
        tree : lark.Tree

        Returns
        -------
        out : list
            The children of `tree` with a token of kind "union_sep" between
            each pair of neighbors.
        """
        sep = Token(" | ", kind="union_sep")
        out = insert_between(tree.children, sep=sep)
        return out

    def subscription(self, tree):
        """Format a subscription as ``container[item, ...]``.

        Parameters
        ----------
        tree : lark.Tree
            The first child is used as the container, the rest as its items.

        Returns
        -------
        out : list
        """
        return self._format_subscription(tree.children, name="subscription")

    def natlang_literal(self, tree):
        """Format a natural language literal as ``Literal[item, ...]``.

        Logs a warning if the literal has only a single item.

        Parameters
        ----------
        tree : lark.Tree

        Returns
        -------
        out : list
        """
        items = [Token("Literal", kind="qualname"), *tree.children]
        out = self._format_subscription(items, name="nl_literal")

        if len(tree.children) == 1:
            logger.warning(
                "natural language literal with one item `%s`, "
                "consider using `%s` to improve readability",
                tree.children[0],
                "".join(out),
            )
        return out

    def natlang_container(self, tree):
        """Format a natural language container as ``container[item, ...]``.

        Parameters
        ----------
        tree : lark.Tree
            The first child is used as the container, the rest as its items.

        Returns
        -------
        out : list
        """
        return self._format_subscription(tree.children, name="nl_container")

    def natlang_array(self, tree):
        """Format a natural language array as ``array_name[item, ...]``.

        Parameters
        ----------
        tree : lark.Tree
            Must contain exactly one token of kind "array_name".

        Returns
        -------
        out : list
        """
        array_name = Token.find_one(tree.children, kind="array_name")
        items = tree.children.copy()
        # Move the array name to the front and re-label it as a qualname so
        # that it takes the container position of the subscription.
        items.remove(array_name)
        items.insert(0, Token(array_name, kind="qualname"))
        return self._format_subscription(items, name="nl_array")

    def array_name(self, tree):
        """Create the token naming the container of an array expression.

        Parameters
        ----------
        tree : lark.Tree

        Returns
        -------
        out : Token
            Token of kind "array_name".
        """
        # Treat `array_name` as `qualname`, but mark it as an array name,
        # so we know which one to treat as the container in `array_expression`
        # This currently relies on a hack that only allows specific names
        # in `array_expression` (see `ARRAY_NAME` terminal in grammar)
        qualname = self.qualname(tree)
        qualname = Token(qualname, kind="array_name")
        return qualname

    def shape(self, tree):
        """Drop shape information from the result.

        Parameters
        ----------
        tree : lark.Tree

        Returns
        -------
        out : lark.visitors._DiscardType
        """
        logger.debug("dropping shape information")
        return lark.Discard

    def optional(self, tree):
        """Drop "optional" / default information from the result.

        Parameters
        ----------
        tree : lark.Tree

        Returns
        -------
        out : lark.visitors._DiscardType
        """
        logger.debug("dropping optional / default info")
        return lark.Discard

    def extra_info(self, tree):
        """Drop extra information from the result.

        Parameters
        ----------
        tree : lark.Tree

        Returns
        -------
        out : lark.visitors._DiscardType
        """
        logger.debug("dropping extra info")
        return lark.Discard

    def _format_subscription(self, sequence, name):
        """Build the token list for ``container[item, item, ...]``.

        Parameters
        ----------
        sequence : Sequence
            The container followed by at least one item.
        name : str
            Prefix for the kinds of the generated tokens: "<name>_sep",
            "<name>_start" and "<name>_stop".

        Returns
        -------
        out : list
            The container, an opening bracket token, the items separated by
            ", " tokens, and a closing bracket token.
        """
        sep = Token(", ", kind=f"{name}_sep")
        container, *content = sequence
        content = insert_between(content, sep=sep)
        # Internal invariant: a subscription without items would render as
        # `container[]`.
        # NOTE(review): confirm the grammar can never produce a rule whose
        # only items are discarded (e.g. an array with just a shape).
        assert content, "subscription must contain at least one item"
        out = [
            container,
            Token("[", kind=f"{name}_start"),
            *content,
            Token("]", kind=f"{name}_stop"),
        ]
        return out

    def __default_token__(self, token):
        """Convert a lark token into a `Token`.

        The kind of the new token is the lower-cased type of `token`.
        """
        return Token(token.value, kind=token.type.lower())
246+
247+
248+
@dataclass(frozen=True, slots=True)
class ParsedDoctype:
    """Result of parsing a doctype.

    Converting an instance to a string concatenates its tokens.
    """

    # Flat sequence of the tokens produced by transforming the parsed doctype.
    tokens: tuple[Token, ...]
    # The doctype string exactly as it was passed to `parse`.
    raw_doctype: str

    @classmethod
    def parse(cls, doctype):
        """Turn a type description in a docstring into a type annotation.

        Errors raised by the parser or the transformer are not caught here.

        Parameters
        ----------
        doctype : str
            The doctype to parse.

        Returns
        -------
        parsed : ParsedDoctype

        Examples
        --------
        >>> ParsedDoctype.parse("tuple of int or ndarray of dtype (float or int)")
        <ParsedDoctype: 'tuple[int] | ndarray[float | int]'>
        """
        tree = _lark.parse(doctype)
        result = DoctypeTransformer().transform(tree=tree)
        # Wrap the result before flattening: if the transformer returns a
        # single Token (a string) instead of a list, iterating over it
        # directly would split it into characters and lose the token.
        # For list results the wrapping makes no difference.
        result = tuple(flatten_recursive([result]))
        return cls(result, raw_doctype=doctype)

    def __str__(self):
        return "".join(self.tokens)

    def __repr__(self):
        return f"<{type(self).__name__}: '{self}'>"

0 commit comments

Comments
 (0)