Skip to content

Commit 4c996a8

Browse files
committed
Align transformers with Kolasu (wip: tests missing)
1 parent a573689 commit 4c996a8

2 files changed

Lines changed: 48 additions & 25 deletions

File tree

pylasu/transformation/transformation.py

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
Output = TypeVar('Output', bound=Node)
1414
Source = TypeVar('Source')
1515
Target = TypeVar('Target')
16-
node_factory_constructor_type = Callable[[Source, "ASTTransformer", "NodeFactory[Source, Output]"], Optional[Output]]
16+
node_factory_constructor_type = Callable[[Source, "ASTTransformer", "NodeFactory[Source, Output]"], List[Output]]
17+
node_factory_single_constructor_type = Callable[[Source, "ASTTransformer", "NodeFactory[Source, Output]"], Output]
1718

1819

1920
@dataclass
@@ -86,31 +87,40 @@ def __init__(self, issues: List[Issue] = None, allow_generic_node: bool = True):
8687
self.known_classes = dict()
8788

8889
def transform(self, source: Optional[Any], parent: Optional[Node] = None) -> Optional[Node]:
89-
if source is None:
90+
result = self.transform_into_nodes(source, parent)
91+
if len(result) == 0:
9092
return None
93+
elif len(result) == 1:
94+
return result[0]
95+
else:
96+
raise Exception(f"Cannot transform {source} into a single Node as multiple nodes where produced")
97+
98+
def transform_into_nodes(self, source: Optional[Any], parent: Optional[Node] = None) -> List[Node]:
99+
if source is None:
100+
return []
91101
elif isinstance(source, Iterable):
92102
raise Exception(f"Mapping error: received collection when value was expected: {source}")
93103
factory = self.get_node_factory(type(source))
94104
if factory:
95-
node = self.make_node(factory, source)
96-
if not node:
97-
return None
98-
for pd in type(node).node_properties:
99-
self.process_child(source, node, pd, factory)
100-
factory.finalizer(node)
101-
node.parent = parent
105+
nodes = self.make_nodes(factory, source)
106+
for node in nodes:
107+
for pd in type(node).node_properties:
108+
self.process_child(source, node, pd, factory)
109+
factory.finalizer(node)
110+
node.parent = parent
111+
102112
else:
103113
if self.allow_generic_node:
104114
origin = self.as_origin(source)
105-
node = GenericNode(parent).with_origin(origin)
115+
nodes = [GenericNode(parent).with_origin(origin)]
106116
self.issues.append(
107117
Issue.semantic(
108118
f"Source node not mapped: {type(source).__qualname__}",
109119
IssueSeverity.INFO,
110120
origin.position if origin else None))
111121
else:
112122
raise Exception(f"Unable to transform node {source} (${type(source)})")
113-
return node
123+
return nodes
114124

115125
def process_child(self, source, node, pd, factory):
116126
child_key = type(node).__qualname__ + "#" + pd.name
@@ -144,15 +154,16 @@ def set_child(self, child_node_factory: ChildNodeFactory, source: Any, node: Nod
144154
def get_source(self, node: Node, source: Any) -> Any:
145155
return source
146156

147-
def make_node(self, factory: NodeFactory[Source, Target], source: Source) -> Optional[Node]:
157+
def make_nodes(self, factory: NodeFactory[Source, Target], source: Source) -> List[Node]:
148158
try:
149-
node = factory.constructor(source, self, factory)
150-
if node:
151-
node = node.with_origin(self.as_origin(source))
152-
return node
159+
nodes = factory.constructor(source, self, factory)
160+
for node in nodes:
161+
if node.origin is None:
162+
node.with_origin(self.as_origin(source))
163+
return nodes
153164
except Exception as e:
154165
if self.allow_generic_node:
155-
return GenericErrorNode(error=e).with_origin(self.as_origin(source))
166+
return [GenericErrorNode(error=e).with_origin(self.as_origin(source))]
156167
else:
157168
raise e
158169

@@ -166,10 +177,11 @@ def get_node_factory(self, node_type: Type[Source]) -> Optional[NodeFactory[Sour
166177
return factory
167178

168179
def register_node_factory(
169-
self, source: Type[Source], factory: Union[node_factory_constructor_type, Type[Target]]
180+
self, source: Type[Source],
181+
factory: Union[node_factory_constructor_type, node_factory_single_constructor_type, Type[Target]]
170182
) -> NodeFactory[Source, Target]:
171183
if isinstance(factory, type):
172-
node_factory = NodeFactory(lambda _, __, ___: factory())
184+
node_factory = NodeFactory(lambda _, __, ___: [factory()])
173185
else:
174186
node_factory = NodeFactory(get_node_constructor_wrapper(factory))
175187
self.factories[source] = node_factory
@@ -180,24 +192,33 @@ def register_identity_transformation(self, node_class: Type[Target]):
180192

181193

182194
def get_node_constructor_wrapper(decorated_function):
195+
def ensure_list(obj):
196+
if isinstance(obj, list):
197+
return obj
198+
else:
199+
return [obj]
200+
183201
try:
184202
sig = signature(decorated_function)
185203
try:
186204
sig.bind(1, 2, 3)
187-
wrapper = decorated_function
205+
206+
def wrapper(node: Node, transformer: ASTTransformer, factory):
207+
return ensure_list(decorated_function(node, transformer, factory))
188208
except TypeError:
189209
try:
190210
sig.bind(1, 2)
191211

192212
def wrapper(node: Node, transformer: ASTTransformer, _):
193-
return decorated_function(node, transformer)
213+
return ensure_list(decorated_function(node, transformer))
194214
except TypeError:
195215
sig.bind(1)
196216

197217
def wrapper(node: Node, _, __):
198-
return decorated_function(node)
218+
return ensure_list(decorated_function(node))
199219
except ValueError:
200-
wrapper = decorated_function
220+
def wrapper(node: Node, transformer: ASTTransformer, factory):
221+
return ensure_list(decorated_function(node, transformer, factory))
201222

202223
functools.update_wrapper(wrapper, decorated_function)
203224
return wrapper

tests/test_metamodel_builder.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ def open_out_stream(self, other=None):
4343
resource.append(eClass)
4444
with BytesIO() as out:
4545
resource.save(out)
46-
self.assertEqual(json.loads(out.getvalue().decode("utf-8")), json.loads('''{
46+
self.assertEqual(
47+
json.loads('''{
4748
"eClass": "http://www.eclipse.org/emf/2002/Ecore#//EPackage",
4849
"nsPrefix": "test",
4950
"nsURI": "http://test/1.0",
@@ -82,7 +83,8 @@ def open_out_stream(self, other=None):
8283
"name": "A"
8384
}
8485
]
85-
}'''))
86+
}'''),
87+
json.loads(out.getvalue().decode("utf-8")))
8688

8789
def test_can_serialize_starlasu_model(self):
8890
starlasu_package = ASTNode.eClass.ePackage

0 commit comments

Comments
 (0)