Skip to content

Commit e6f5227

Browse files
committed
raise warning if arrays need to be copied
1 parent 6ee641e commit e6f5227

3 files changed

Lines changed: 55 additions & 6 deletions

File tree

cebra/integrations/sklearn/cebra.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,6 +1253,8 @@ def transform(self,
12531253

12541254
X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_))
12551255

1256+
X = cebra_sklearn_dataset._ensure_writable(X)
1257+
12561258
if isinstance(X, np.ndarray):
12571259
X = torch.from_numpy(X)
12581260

cebra/integrations/sklearn/dataset.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
#
2222
"""Datasets to be used as part of the sklearn framework."""
2323

24+
import traceback
25+
import warnings
2426
from typing import Iterable, Optional
2527

2628
import numpy as np
@@ -34,6 +36,28 @@
3436
import cebra.solver
3537

3638

39+
def _ensure_writable(array: npt.NDArray) -> npt.NDArray:
40+
if not array.flags.writeable:
41+
stack = traceback.extract_stack()[-5:-1]
42+
stack_str = ''.join(traceback.format_list(stack[-4:]))
43+
44+
warnings.warn(
45+
("You passed a non-writable Numpy array to CEBRA. Pytorch does currently "
46+
"not support non-writable tensors. As a result, CEBRA needs to copy the "
47+
"contents of the array, which might yield unnecessary memory overhead. "
48+
"Ideally, adapt the code such that the array you pass to CEBRA is writable "
49+
"to make your code memory efficient. "
50+
"You can find more context and the rationale for this fix here: "
51+
"https://github.com/AdaptiveMotorControlLab/CEBRA/pull/289."
52+
"\n\n"
53+
"Trace:\n" + stack_str),
54+
UserWarning,
55+
stacklevel=2,
56+
)
57+
array = array.copy()
58+
return array
59+
60+
3761
class SklearnDataset(cebra.data.SingleSessionDataset):
3862
"""Dataset for wrapping array-like input/index pairs.
3963
@@ -110,9 +134,7 @@ def _parse_data(self, X: npt.NDArray):
110134
# one sample is a conservative default here to ensure that sklearn tests
111135
# passes with the correct error messages.
112136
X = cebra_sklearn_utils.check_input_array(X, min_samples=2)
113-
# Ensure array is writable (pandas 3.0+ may return read-only arrays)
114-
if not X.flags.writeable:
115-
X = X.copy()
137+
X = _ensure_writable(X)
116138
self.neural = torch.from_numpy(X).float().to(self.device)
117139

118140
def _parse_labels(self, labels: Optional[tuple]):
@@ -146,11 +168,10 @@ def _parse_labels(self, labels: Optional[tuple]):
146168
f"or lists that can be converted to arrays, but got {type(y)}"
147169
)
148170

171+
y = _ensure_writable(y)
172+
149173
# Define the index as either continuous or discrete indices, depending
150174
# on the dtype in the index array.
151-
# Ensure array is writable (pandas 3.0+ may return read-only arrays)
152-
if not y.flags.writeable:
153-
y = y.copy()
154175
if cebra.helper._is_floating(y):
155176
y = torch.from_numpy(y).float()
156177
if y.dim() == 1:

tests/test_sklearn.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1544,3 +1544,29 @@ def test_last_incomplete_batch_smaller_than_offset():
15441544
model.fit(train.neural, train.continuous)
15451545

15461546
_ = model.transform(train.neural, batch_size=300)
1547+
1548+
1549+
def test_non_writable_array():
1550+
import numpy as np
1551+
import pytest
1552+
1553+
import cebra
1554+
1555+
# Create a numpy array and make it non-writable
1556+
X = np.random.randn(100, 10)
1557+
y = np.random.randn(100, 2)
1558+
1559+
X.setflags(write=False)
1560+
y.setflags(write=False)
1561+
1562+
with pytest.raises(ValueError, match="assignment destination is read-only"):
1563+
X[:] = 0
1564+
y[:] = 0
1565+
1566+
cebra_model = cebra.CEBRA(max_iterations=2, batch_size=32, device="cpu")
1567+
1568+
# This should not raise an exception even though arrays are not writable
1569+
cebra_model.fit(X, y)
1570+
embedding = cebra_model.transform(X)
1571+
assert isinstance(embedding, np.ndarray)
1572+
assert embedding.shape[0] == X.shape[0]

0 commit comments

Comments
 (0)