|
12 | 12 | from ._helpers import import_, all_libraries, wrapped_libraries |
13 | 13 |
|
14 | 14 | @pytest.mark.parametrize("use_compat", [True, False, None]) |
15 | | -@pytest.mark.parametrize("api_version", [None, "2021.12"]) |
| 15 | +@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12"]) |
16 | 16 | @pytest.mark.parametrize("library", all_libraries + ['array_api_strict']) |
17 | 17 | def test_array_namespace(library, api_version, use_compat): |
18 | 18 | xp = import_(library) |
@@ -69,14 +69,14 @@ def test_array_namespace_errors_torch(): |
69 | 69 | pytest.raises(TypeError, lambda: array_namespace(x, y)) |
70 | 70 |
|
71 | 71 | def test_api_version(): |
72 | | - x = np.asarray([1, 2]) |
73 | | - np_ = import_("numpy", wrapper=True) |
74 | | - assert array_namespace(x, api_version="2022.12") == np_ |
75 | | - assert array_namespace(x, api_version=None) == np_ |
76 | | - assert array_namespace(x) == np_ |
| 72 | + x = torch.asarray([1, 2]) |
| 73 | + torch_ = import_("torch", wrapper=True) |
| 74 | + assert array_namespace(x, api_version="2022.12") == torch_ |
| 75 | + assert array_namespace(x, api_version=None) == torch_ |
| 76 | + assert array_namespace(x) == torch_ |
77 | 77 | # Should issue a warning |
78 | 78 | with warnings.catch_warnings(record=True) as w: |
79 | | - assert array_namespace(x, api_version="2021.12") == np_ |
| 79 | + assert array_namespace(x, api_version="2021.12") == torch_ |
80 | 80 | assert len(w) == 1 |
81 | 81 | assert "2021.12" in str(w[0].message) |
82 | 82 |
|
|
0 commit comments