Skip to content

Commit 8f8bbcb

Browse files
authored
BUG: f2py: restore .r/.i field access on complex types via union typedef (numpy#30983)
1 parent 2205b50 commit 8f8bbcb

4 files changed

Lines changed: 64 additions & 27 deletions

File tree

numpy/f2py/cfuncs.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,9 @@ def errmess(s: str) -> None:
8989
typedef long double long_double;
9090
#endif
9191
"""
92-
typedefs['complex_long_double'] = 'typedef npy_clongdouble complex_long_double;'
93-
typedefs['complex_float'] = 'typedef npy_cfloat complex_float;'
94-
typedefs['complex_double'] = 'typedef npy_cdouble complex_double;'
92+
typedefs['complex_long_double'] = 'typedef union { struct {long double r,i;}; npy_clongdouble _npy; } complex_long_double;'
93+
typedefs['complex_float'] = 'typedef union { struct {float r,i;}; npy_cfloat _npy; } complex_float;'
94+
typedefs['complex_double'] = 'typedef union { struct {double r,i;}; npy_cdouble _npy; } complex_double;'
9595
typedefs['string'] = """typedef char * string;"""
9696
typedefs['character'] = """typedef char character;"""
9797

@@ -289,13 +289,13 @@ def errmess(s: str) -> None:
289289
#define pyobj_from_float1(v) (PyFloat_FromDouble(v))"""
290290
needs['pyobj_from_complex_long_double1'] = ['complex_long_double', 'npy_math.h']
291291
cppmacros['pyobj_from_complex_long_double1'] = """
292-
#define pyobj_from_complex_long_double1(v) (PyComplex_FromDoubles((double)npy_creall(v),(double)npy_cimagl(v)))"""
292+
#define pyobj_from_complex_long_double1(v) (PyComplex_FromDoubles((double)npy_creall(v._npy),(double)npy_cimagl(v._npy)))"""
293293
needs['pyobj_from_complex_double1'] = ['complex_double', 'npy_math.h']
294294
cppmacros['pyobj_from_complex_double1'] = """
295-
#define pyobj_from_complex_double1(v) (PyComplex_FromDoubles(npy_creal(v),npy_cimag(v)))"""
295+
#define pyobj_from_complex_double1(v) (PyComplex_FromDoubles(npy_creal(v._npy),npy_cimag(v._npy)))"""
296296
needs['pyobj_from_complex_float1'] = ['complex_float', 'npy_math.h']
297297
cppmacros['pyobj_from_complex_float1'] = """
298-
#define pyobj_from_complex_float1(v) (PyComplex_FromDoubles((double)npy_crealf(v),(double)npy_cimagf(v)))"""
298+
#define pyobj_from_complex_float1(v) (PyComplex_FromDoubles((double)npy_crealf(v._npy),(double)npy_cimagf(v._npy)))"""
299299
needs['pyobj_from_string1'] = ['string']
300300
cppmacros['pyobj_from_string1'] = """
301301
#define pyobj_from_string1(v) (PyUnicode_FromString((char *)v))"""
@@ -1146,7 +1146,7 @@ def errmess(s: str) -> None:
11461146
static int
11471147
complex_long_double_from_pyobj(complex_long_double* v, PyObject *obj, const char *errmess)
11481148
{
1149-
complex_double cd = npy_cpack(0.0, 0.0);
1149+
complex_double cd = {.r=0, .i=0};
11501150
if (PyArray_CheckScalar(obj)){
11511151
if PyArray_IsScalar(obj, CLongDouble) {
11521152
PyArray_ScalarAsCtype(obj, v);
@@ -1156,15 +1156,15 @@ def errmess(s: str) -> None:
11561156
PyArrayObject *arr = (PyArrayObject *)obj;
11571157
if (PyArray_TYPE(arr)==NPY_CLONGDOUBLE) {
11581158
npy_clongdouble tmp = *(npy_clongdouble *)PyArray_DATA(arr);
1159-
npy_csetreall(v, npy_creall(tmp));
1160-
npy_csetimagl(v, npy_cimagl(tmp));
1159+
npy_csetreall(&v->_npy, npy_creall(tmp));
1160+
npy_csetimagl(&v->_npy, npy_cimagl(tmp));
11611161
return 1;
11621162
}
11631163
}
11641164
}
11651165
if (complex_double_from_pyobj(&cd,obj,errmess)) {
1166-
npy_csetreall(v, (long_double)npy_creal(cd));
1167-
npy_csetimagl(v, (long_double)npy_cimag(cd));
1166+
npy_csetreall(&v->_npy, (long_double)npy_creal(cd._npy));
1167+
npy_csetimagl(&v->_npy, (long_double)npy_cimag(cd._npy));
11681168
return 1;
11691169
}
11701170
return 0;
@@ -1179,22 +1179,22 @@ def errmess(s: str) -> None:
11791179
Py_complex c;
11801180
if (PyComplex_Check(obj)) {
11811181
c = PyComplex_AsCComplex(obj);
1182-
npy_csetreal(v, c.real);
1183-
npy_csetimag(v, c.imag);
1182+
npy_csetreal(&v->_npy, c.real);
1183+
npy_csetimag(&v->_npy, c.imag);
11841184
return 1;
11851185
}
11861186
if (PyArray_IsScalar(obj, ComplexFloating)) {
11871187
if (PyArray_IsScalar(obj, CFloat)) {
11881188
npy_cfloat tmp;
11891189
PyArray_ScalarAsCtype(obj, &tmp);
1190-
npy_csetreal(v, (double)npy_crealf(tmp));
1191-
npy_csetimag(v, (double)npy_cimagf(tmp));
1190+
npy_csetreal(&v->_npy, (double)npy_crealf(tmp));
1191+
npy_csetimag(&v->_npy, (double)npy_cimagf(tmp));
11921192
}
11931193
else if (PyArray_IsScalar(obj, CLongDouble)) {
11941194
npy_clongdouble tmp;
11951195
PyArray_ScalarAsCtype(obj, &tmp);
1196-
npy_csetreal(v, (double)npy_creall(tmp));
1197-
npy_csetimag(v, (double)npy_cimagl(tmp));
1196+
npy_csetreal(&v->_npy, (double)npy_creall(tmp));
1197+
npy_csetimag(&v->_npy, (double)npy_cimagl(tmp));
11981198
}
11991199
else { /* if (PyArray_IsScalar(obj, CDouble)) */
12001200
PyArray_ScalarAsCtype(obj, v);
@@ -1213,20 +1213,20 @@ def errmess(s: str) -> None:
12131213
return 0;
12141214
}
12151215
npy_cdouble tmp = *(npy_cdouble *)PyArray_DATA(arr);
1216-
npy_csetreal(v, npy_creal(tmp));
1217-
npy_csetimag(v, npy_cimag(tmp));
1216+
npy_csetreal(&v->_npy, npy_creal(tmp));
1217+
npy_csetimag(&v->_npy, npy_cimag(tmp));
12181218
Py_DECREF(arr);
12191219
return 1;
12201220
}
12211221
/* Python does not provide PyNumber_Complex function :-( */
1222-
npy_csetimag(v, 0.0);
1222+
npy_csetimag(&v->_npy, 0.0);
12231223
if (PyFloat_Check(obj)) {
1224-
npy_csetreal(v, PyFloat_AsDouble(obj));
1225-
return !(npy_creal(*v) == -1.0 && PyErr_Occurred());
1224+
npy_csetreal(&v->_npy, PyFloat_AsDouble(obj));
1225+
return !(npy_creal(v->_npy) == -1.0 && PyErr_Occurred());
12261226
}
12271227
if (PyLong_Check(obj)) {
1228-
npy_csetreal(v, PyLong_AsDouble(obj));
1229-
return !(npy_creal(*v) == -1.0 && PyErr_Occurred());
1228+
npy_csetreal(&v->_npy, PyLong_AsDouble(obj));
1229+
return !(npy_creal(v->_npy) == -1.0 && PyErr_Occurred());
12301230
}
12311231
if (PySequence_Check(obj) && !(PyBytes_Check(obj) || PyUnicode_Check(obj))) {
12321232
PyObject *tmp = PySequence_GetItem(obj,0);
@@ -1255,10 +1255,10 @@ def errmess(s: str) -> None:
12551255
static int
12561256
complex_float_from_pyobj(complex_float* v,PyObject *obj,const char *errmess)
12571257
{
1258-
complex_double cd = npy_cpack(0.0, 0.0);
1258+
complex_double cd = {.r=0, .i=0};
12591259
if (complex_double_from_pyobj(&cd,obj,errmess)) {
1260-
npy_csetrealf(v, (float)npy_creal(cd));
1261-
npy_csetimagf(v, (float)npy_cimag(cd));
1260+
npy_csetrealf(&v->_npy, (float)npy_creal(cd._npy));
1261+
npy_csetimagf(&v->_npy, (float)npy_cimag(cd._npy));
12621262
return 1;
12631263
}
12641264
return 0;
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
subroutine zero_imag(c, n)
2+
complex*16, intent(inout) :: c(n)
3+
integer, intent(in) :: n
4+
integer :: k
5+
do k = 1, n
6+
c(k) = cmplx(dble(c(k)), 0.0d0, kind=8)
7+
end do
8+
end subroutine
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
python module _complex_struct_compat_test
2+
interface
3+
subroutine zero_imag(c, n)
4+
callstatement { int k; for(k=0;k<n;k++) (c+k)->i = 0.0; }
5+
callprotoargument complex_double*, int*
6+
7+
complex*16 intent(inout), dimension(n) :: c
8+
integer intent(hide), depend(c) :: n = shape(c,0)
9+
10+
end subroutine zero_imag
11+
end interface
12+
end python module _complex_struct_compat_test

numpy/f2py/tests/test_regression.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,23 @@ def test_gh25784():
175175
assert "unknown_subroutine_" in str(rerr)
176176

177177

178+
@pytest.mark.slow
179+
class TestComplexStructCompat(util.F2PyTest):
180+
# Check that .r/.i field access works on complex_double pointers in
181+
# callstatements (scipy compatibility, gh-30966 follow-up)
182+
sources = [
183+
util.getpath("tests", "src", "regression", "complex_struct_compat.pyf"),
184+
util.getpath("tests", "src", "regression", "complex_struct_compat.f90"),
185+
]
186+
module_name = "_complex_struct_compat_test"
187+
188+
def test_complex_struct_field_access(self):
189+
c = np.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=np.complex128)
190+
self.module.zero_imag(c)
191+
npt.assert_array_equal(c.imag, [0.0, 0.0, 0.0])
192+
npt.assert_array_equal(c.real, [1.0, 3.0, 5.0])
193+
194+
178195
@pytest.mark.slow
179196
class TestAssignmentOnlyModules(util.F2PyTest):
180197
# Ensure that variables are exposed without functions or subroutines in a module

0 commit comments

Comments
 (0)