|
14 | 14 | from . import xps |
15 | 15 | from .typing import Array, Shape |
16 | 16 |
|
17 | | -MAX_SIDE = hh.MAX_ARRAY_SIZE // 64 |
18 | | -MAX_DIMS = min(hh.MAX_ARRAY_SIZE // MAX_SIDE, 32) # NumPy only supports up to 32 dims |
19 | | - |
20 | 17 |
|
21 | 18 | def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]: |
22 | 19 | key = "shape" |
@@ -66,7 +63,7 @@ def test_concat(dtypes, base_shape, data): |
66 | 63 | shape_strat = hh.shapes() |
67 | 64 | else: |
68 | 65 | _axis = axis if axis >= 0 else len(base_shape) + axis |
69 | | - shape_strat = st.integers(0, MAX_SIDE).map( |
| 66 | + shape_strat = st.integers(0, hh.MAX_SIDE).map( |
70 | 67 | lambda i: base_shape[:_axis] + (i,) + base_shape[_axis + 1 :] |
71 | 68 | ) |
72 | 69 | arrays = [] |
@@ -348,26 +345,14 @@ def test_repeat(x, kw, data): |
348 | 345 | kw=kw) |
349 | 346 | start = end |
350 | 347 |
|
351 | | -@st.composite |
352 | | -def reshape_shapes(draw, shape): |
353 | | - size = 1 if len(shape) == 0 else math.prod(shape) |
354 | | - rshape = draw(st.lists(st.integers(0)).filter(lambda s: math.prod(s) == size)) |
355 | | - assume(all(side <= MAX_SIDE for side in rshape)) |
356 | | - if len(rshape) != 0 and size > 0 and draw(st.booleans()): |
357 | | - index = draw(st.integers(0, len(rshape) - 1)) |
358 | | - rshape[index] = -1 |
359 | | - return tuple(rshape) |
360 | | - |
| 348 | +reshape_shape = st.shared(hh.shapes(), key="reshape_shape") |
361 | 349 |
|
362 | 350 | @pytest.mark.unvectorized |
363 | | -@pytest.mark.skip("flaky") # TODO: fix! |
364 | 351 | @given( |
365 | | - x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(max_side=MAX_SIDE)), |
366 | | - data=st.data(), |
| 352 | + x=hh.arrays(dtype=hh.all_dtypes, shape=reshape_shape), |
| 353 | + shape=hh.reshape_shapes(reshape_shape), |
367 | 354 | ) |
368 | | -def test_reshape(x, data): |
369 | | - shape = data.draw(reshape_shapes(x.shape)) |
370 | | - |
| 355 | +def test_reshape(x, shape): |
371 | 356 | out = xp.reshape(x, shape) |
372 | 357 |
|
373 | 358 | ph.assert_dtype("reshape", in_dtype=x.dtype, out_dtype=out.dtype) |
|
0 commit comments