33
44import numpy as np
55import pytest
6- import tensorflow as tf
76import torch
87
98import blosc2
@@ -104,7 +103,7 @@ def _test_binary_func_proxy(np_func, blosc_func, dtype, shape, chunkshape, xp):
104103 dtype = np .dtype (dtype )
105104 not_blosc1 = xp .ones (shape , dtype = dtype_ )
106105 if np_func .__name__ in ("right_shift" , "left_shift" ):
107- a_blosc2 = blosc2 .asarray (2 )
106+ a_blosc2 = blosc2 .asarray (2 , copy = True )
108107 else :
109108 a_blosc2 = blosc2 .linspace (
110109 start = np .prod (shape ) * 2 ,
@@ -114,16 +113,21 @@ def _test_binary_func_proxy(np_func, blosc_func, dtype, shape, chunkshape, xp):
114113 shape = shape ,
115114 dtype = dtype ,
116115 )
117- if not blosc2 .isdtype (dtype , "integral" ):
118- a_blosc2 [tuple (i // 2 for i in shape )] = blosc2 .nan
119- if dtype == blosc2 .complex128 :
120- a_blosc2 = (
121- a_blosc2
122- + blosc2 .linspace (
123- 1j , stop = np .prod (shape ) * 1j , num = np .prod (shape ), chunks = chunkshape , shape = shape , dtype = dtype
124- )
125- ).compute ()
126- a_blosc2 [tuple (i // 2 for i in shape )] = blosc2 .nan + blosc2 .nan * 1j
116+ if not blosc2 .isdtype (dtype , "integral" ):
117+ a_blosc2 [tuple (i // 2 for i in shape )] = blosc2 .nan
118+ if dtype == blosc2 .complex128 :
119+ a_blosc2 = (
120+ a_blosc2
121+ + blosc2 .linspace (
122+ 1j ,
123+ stop = np .prod (shape ) * 1j ,
124+ num = np .prod (shape ),
125+ chunks = chunkshape ,
126+ shape = shape ,
127+ dtype = dtype ,
128+ )
129+ ).compute ()
130+ a_blosc2 [tuple (i // 2 for i in shape )] = blosc2 .nan + blosc2 .nan * 1j
127131 arr1 = np .asarray (not_blosc1 )
128132 arr2 = a_blosc2 [()]
129133 success = False
@@ -299,7 +303,7 @@ def test_unary_funcs(np_func, blosc_func, dtype, shape, chunkshape):
299303@pytest .mark .parametrize (("np_func" , "blosc_func" ), UNARY_FUNC_PAIRS )
300304@pytest .mark .parametrize ("dtype" , STR_DTYPES )
301305@pytest .mark .parametrize ("shape" , [(10 ,), (20 , 20 )])
302- @pytest .mark .parametrize ("xp" , [torch , tf ])
306+ @pytest .mark .parametrize ("xp" , [torch ])
303307def test_unfuncs_proxy (np_func , blosc_func , dtype , shape , chunkshape , xp ):
304308 _test_unary_func_proxy (np_func , blosc_func , dtype , shape , chunkshape , xp )
305309
@@ -322,7 +326,7 @@ def test_binary_funcs(np_func, blosc_func, dtype, shape, chunkshape):
322326@pytest .mark .parametrize (("np_func" , "blosc_func" ), BINARY_FUNC_PAIRS )
323327@pytest .mark .parametrize ("dtype" , STR_DTYPES )
324328@pytest .mark .parametrize (("shape" , "chunkshape" ), SHAPES_CHUNKS )
325- @pytest .mark .parametrize ("xp" , [torch , tf ])
329+ @pytest .mark .parametrize ("xp" , [torch ])
326330def test_binfuncs_proxy (np_func , blosc_func , dtype , shape , chunkshape , xp ):
327331 _test_binary_func_proxy (np_func , blosc_func , dtype , shape , chunkshape , xp )
328332
0 commit comments