File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -2411,6 +2411,18 @@ def _is_jax_array(x):
24112411 and isinstance (x , tp ))
24122412
24132413
2414+ def _is_mlx_array (x ):
2415+ """Return whether *x* is a MLX Array."""
2416+ try :
2417+ # We're intentionally not attempting to import mlx. If somebody
2418+ # has created a mlx array, mlx should already be in sys.modules.
2419+ tp = sys .modules .get ("mlx.core" ).array
2420+ except AttributeError :
2421+ return False # Module not imported or a nonstandard module with no Array attr.
2422+ return (isinstance (tp , type ) # Just in case it's a very nonstandard module.
2423+ and isinstance (x , tp ))
2424+
2425+
24142426def _is_pandas_dataframe (x ):
24152427 """Check if *x* is a Pandas DataFrame."""
24162428 try :
@@ -2454,7 +2466,10 @@ def _unpack_to_numpy(x):
24542466 # so in this case we do not want to return a function
24552467 if isinstance (xtmp , np .ndarray ):
24562468 return xtmp
2457- if _is_torch_array (x ) or _is_jax_array (x ) or _is_tensorflow_array (x ):
2469+ if _is_torch_array (x ) \
2470+ or _is_jax_array (x ) \
2471+ or _is_tensorflow_array (x ) \
2472+ or _is_mlx_array (x ):
24582473 # using np.asarray() instead of explicitly __array__(), as the latter is
24592474 # only _one_ of many methods, and it's the last resort, see also
24602475 # https://numpy.org/devdocs/user/basics.interoperability.html#using-arbitrary-objects-in-numpy
Original file line number Diff line number Diff line change @@ -1077,3 +1077,34 @@ def __array__(self):
10771077 # if not mocked, and the implementation does not guarantee it
10781078 # is the same Python object, just the same values.
10791079 assert_array_equal (result , data )
1080+
1081+
1082+ def test_unpack_to_numpy_from_mlx ():
1083+ """
1084+ Test that mlx arrays are converted to NumPy arrays.
1085+
1086+ We don't want to create a dependency on mlx in the test suite, so we mock it.
1087+ """
1088+ class Array :
1089+ def __init__ (self , data ):
1090+ self .data = data
1091+
1092+ def __array__ (self ):
1093+ return self .data
1094+
1095+ # mlx is something peculiar
1096+ # class `array` is in `mlx.core`
1097+ mlx_core = ModuleType ('mlx.core' )
1098+ mlx_core .array = Array
1099+
1100+ sys .modules ['mlx.core' ] = mlx_core
1101+
1102+ data = np .arange (10 )
1103+ mlx_array = mlx_core .array (data )
1104+
1105+ result = cbook ._unpack_to_numpy (mlx_array )
1106+ assert isinstance (result , np .ndarray )
1107+ # compare results, do not check for identity: the latter would fail
1108+ # if not mocked, and the implementation does not guarantee it
1109+ # is the same Python object, just the same values.
1110+ assert_array_equal (result , data )
You can’t perform that action at this time.
0 commit comments