55import logging
66from collections .abc import Iterable
77from dataclasses import dataclass
8+ from functools import lru_cache
89from pathlib import Path
910
1011import lark
@@ -67,7 +68,7 @@ class Token(str):
6768
6869 flag = TokenFlag
6970
70- __slots__ = ("value " , "kind " , "pos " )
71+ __slots__ = ("kind " , "pos " , "value " )
7172
7273 def __new__ (cls , value , * , kind , pos = None ):
7374 self = super ().__new__ (cls , value )
@@ -302,10 +303,13 @@ def _format_subscription(self, sequence, kind=None):
302303
303304@dataclass (frozen = True , slots = True )
304305class ParsedDoctype :
306+ """Parsed representation of a doctype, a type description in a docstring."""
307+
305308 tokens : tuple [Token , ...]
306309 raw_doctype : str
307310
308311 @classmethod
312+ @lru_cache (maxsize = 100 )
309313 def parse (cls , doctype ):
310314 """Turn a type description in a docstring into a type annotation.
311315
@@ -316,27 +320,39 @@ def parse(cls, doctype):
316320
317321 Returns
318322 -------
319- annotation_list : list of Token
323+ parsed : Self
320324
321325 Examples
322326 --------
323- >>> doctype = ParsedDoctype.parse(
327+ >>> parsed = ParsedDoctype.parse(
324328 ... "tuple of int or ndarray of dtype (float or int)"
325329 ... )
326- >>> doctype
327- <ParsedDoctype: 'tuple[int] | ndarray[float | int]'>
328- >>> doctype.qualnames
329- (Token('tuple', kind='qualname'),
330- Token('int', kind='qualname'),
331- Token('ndarray', kind='qualname'),
332- Token('float', kind='qualname'),
333- Token('int', kind='qualname'))
330+ >>> parsed
331+ <ParsedDoctype 'tuple[int] | ndarray[float | int]'>
332+ >>> str(parsed)
333+ 'tuple[int] | ndarray[float | int]'
334+ >>> parsed.format({"ndarray": "np.ndarray"})
335+ 'tuple[int] | np.ndarray[float | int]'
336+ >>> parsed.qualnames # doctest: +NORMALIZE_WHITESPACE
337+ (Token('tuple', kind=<TokenFlag.NAME: 1>),
338+ Token('int', kind=<TokenFlag.NAME: 1>),
339+ Token('ndarray', kind=<TokenFlag.NAME|ARRAY: 33>),
340+ Token('float', kind=<TokenFlag.NAME: 1>),
341+ Token('int', kind=<TokenFlag.NAME: 1>))
334342 """
335343 tree = _lark .parse (doctype )
336344 tokens = DoctypeTransformer ().transform (tree = tree )
337345 tokens = tuple (flatten_recursive (tokens ))
338346 return cls (tokens , raw_doctype = doctype )
339347
348+ def format (self , replace_names = None ):
349+ replace_names = replace_names or {}
350+ tokens = [
351+ replace_names .get (token , token ) if token .kind == TokenFlag .NAME else token
352+ for token in self .tokens
353+ ]
354+ return "" .join (tokens )
355+
340356 def __str__ (self ):
341357 return "" .join (self .tokens )
342358
@@ -351,7 +367,7 @@ def print_map_tokens_to_raw(self):
351367 for token in self .tokens :
352368 if token .pos is not None :
353369 start , stop = token .pos
354- print (self .raw_doctype )
355- print (" " * start + "^" * (stop - start ))
356- print (" " * start + token )
357- print ()
370+ print (self .raw_doctype ) # noqa: T201
371+ print (" " * start + "^" * (stop - start )) # noqa: T201
372+ print (" " * start + token ) # noqa: T201
373+ print () # noqa: T201
0 commit comments