Skip to content

Commit 719daa1

Browse files
lucascolleyev-br
andauthored
ENH: array_namespace: support torch.compile
Co-authored-by: Evgeni Burovski <evgeny.burovskiy@gmail.com>
1 parent 9f3a525 commit 719daa1

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

array_api_compat/common/_helpers.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -641,7 +641,7 @@ def your_function(x, y):
641641
is_pydata_sparse_array
642642
643643
"""
644-
namespaces: set[Namespace] = set()
644+
namespaces: list[Namespace] = []
645645
for x in xs:
646646
xp, info = _cls_to_namespace(cast(Hashable, type(x)), api_version, use_compat)
647647
if info is _ClsToXPInfo.SCALAR:
@@ -663,7 +663,14 @@ def your_function(x, y):
663663
)
664664
xp = get_ns(api_version=api_version)
665665

666-
namespaces.add(xp)
666+
namespaces.append(xp)
667+
668+
# Use a list of modules to avoid a graph break under torch.compile:
669+
# torch._dynamo.exc.Unsupported: Dynamo cannot determine whether the underlying object is hashable
670+
# Explanation: Dynamo does not know whether the underlying python object for
671+
# PythonModuleVariable(<module 'array_api_compat.torch' from ...) is hashable
672+
names = set(x.__name__ for x in namespaces)
673+
namespaces = [ns for ns in namespaces if ns.__name__ in names]
667674

668675
try:
669676
(xp,) = namespaces

0 commit comments

Comments
 (0)