1313Output = TypeVar ('Output' , bound = Node )
1414Source = TypeVar ('Source' )
1515Target = TypeVar ('Target' )
16- node_factory_constructor_type = Callable [[Source , "ASTTransformer" , "NodeFactory[Source, Output]" ], Optional [Output ]]
16+ node_factory_constructor_type = Callable [[Source , "ASTTransformer" , "NodeFactory[Source, Output]" ], List [Output ]]
17+ node_factory_single_constructor_type = Callable [[Source , "ASTTransformer" , "NodeFactory[Source, Output]" ], Output ]
1718
1819
1920@dataclass
@@ -86,31 +87,40 @@ def __init__(self, issues: List[Issue] = None, allow_generic_node: bool = True):
8687 self .known_classes = dict ()
8788
8889 def transform (self , source : Optional [Any ], parent : Optional [Node ] = None ) -> Optional [Node ]:
89- if source is None :
90+ result = self .transform_into_nodes (source , parent )
91+ if len (result ) == 0 :
9092 return None
93+ elif len (result ) == 1 :
94+ return result [0 ]
95+ else :
96+ raise Exception (f"Cannot transform { source } into a single Node as multiple nodes where produced" )
97+
98+ def transform_into_nodes (self , source : Optional [Any ], parent : Optional [Node ] = None ) -> List [Node ]:
99+ if source is None :
100+ return []
91101 elif isinstance (source , Iterable ):
92102 raise Exception (f"Mapping error: received collection when value was expected: { source } " )
93103 factory = self .get_node_factory (type (source ))
94104 if factory :
95- node = self .make_node (factory , source )
96- if not node :
97- return None
98- for pd in type ( node ). node_properties :
99- self . process_child ( source , node , pd , factory )
100- factory . finalizer ( node )
101- node . parent = parent
105+ nodes = self .make_nodes (factory , source )
106+ for node in nodes :
107+ for pd in type ( node ). node_properties :
108+ self . process_child ( source , node , pd , factory )
109+ factory . finalizer ( node )
110+ node . parent = parent
111+
102112 else :
103113 if self .allow_generic_node :
104114 origin = self .as_origin (source )
105- node = GenericNode (parent ).with_origin (origin )
115+ nodes = [ GenericNode (parent ).with_origin (origin )]
106116 self .issues .append (
107117 Issue .semantic (
108118 f"Source node not mapped: { type (source ).__qualname__ } " ,
109119 IssueSeverity .INFO ,
110120 origin .position if origin else None ))
111121 else :
112122 raise Exception (f"Unable to transform node { source } (${ type (source )} )" )
113- return node
123+ return nodes
114124
115125 def process_child (self , source , node , pd , factory ):
116126 child_key = type (node ).__qualname__ + "#" + pd .name
@@ -144,15 +154,16 @@ def set_child(self, child_node_factory: ChildNodeFactory, source: Any, node: Nod
144154 def get_source (self , node : Node , source : Any ) -> Any :
145155 return source
146156
147- def make_node (self , factory : NodeFactory [Source , Target ], source : Source ) -> Optional [Node ]:
157+ def make_nodes (self , factory : NodeFactory [Source , Target ], source : Source ) -> List [Node ]:
148158 try :
149- node = factory .constructor (source , self , factory )
150- if node :
151- node = node .with_origin (self .as_origin (source ))
152- return node
159+ nodes = factory .constructor (source , self , factory )
160+ for node in nodes :
161+ if node .origin is None :
162+ node .with_origin (self .as_origin (source ))
163+ return nodes
153164 except Exception as e :
154165 if self .allow_generic_node :
155- return GenericErrorNode (error = e ).with_origin (self .as_origin (source ))
166+ return [ GenericErrorNode (error = e ).with_origin (self .as_origin (source ))]
156167 else :
157168 raise e
158169
@@ -166,10 +177,11 @@ def get_node_factory(self, node_type: Type[Source]) -> Optional[NodeFactory[Sour
166177 return factory
167178
168179 def register_node_factory (
169- self , source : Type [Source ], factory : Union [node_factory_constructor_type , Type [Target ]]
180+ self , source : Type [Source ],
181+ factory : Union [node_factory_constructor_type , node_factory_single_constructor_type , Type [Target ]]
170182 ) -> NodeFactory [Source , Target ]:
171183 if isinstance (factory , type ):
172- node_factory = NodeFactory (lambda _ , __ , ___ : factory ())
184+ node_factory = NodeFactory (lambda _ , __ , ___ : [ factory ()] )
173185 else :
174186 node_factory = NodeFactory (get_node_constructor_wrapper (factory ))
175187 self .factories [source ] = node_factory
@@ -180,24 +192,33 @@ def register_identity_transformation(self, node_class: Type[Target]):
180192
181193
182194def get_node_constructor_wrapper (decorated_function ):
195+ def ensure_list (obj ):
196+ if isinstance (obj , list ):
197+ return obj
198+ else :
199+ return [obj ]
200+
183201 try :
184202 sig = signature (decorated_function )
185203 try :
186204 sig .bind (1 , 2 , 3 )
187- wrapper = decorated_function
205+
206+ def wrapper (node : Node , transformer : ASTTransformer , factory ):
207+ return ensure_list (decorated_function (node , transformer , factory ))
188208 except TypeError :
189209 try :
190210 sig .bind (1 , 2 )
191211
192212 def wrapper (node : Node , transformer : ASTTransformer , _ ):
193- return decorated_function (node , transformer )
213+ return ensure_list ( decorated_function (node , transformer ) )
194214 except TypeError :
195215 sig .bind (1 )
196216
197217 def wrapper (node : Node , _ , __ ):
198- return decorated_function (node )
218+ return ensure_list ( decorated_function (node ) )
199219 except ValueError :
200- wrapper = decorated_function
220+ def wrapper (node : Node , transformer : ASTTransformer , factory ):
221+ return ensure_list (decorated_function (node , transformer , factory ))
201222
202223 functools .update_wrapper (wrapper , decorated_function )
203224 return wrapper
0 commit comments