Skip to content

Commit 563c604

Browse files
authored
BUG: Use dict to store ufunc loops (fixes thread-safety) (numpy#31184)
1 parent 6847bf5 commit 563c604

5 files changed

Lines changed: 72 additions & 99 deletions

File tree

numpy/_core/code_generators/generate_umath.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1589,7 +1589,7 @@ def make_ufuncs(funcdict):
15891589
mlist.append(rf"((PyUFuncObject *)f)->type_resolver = &{uf.typereso};")
15901590
for c in uf.indexed:
15911591
# Handle indexed loops by getting the underlying ArrayMethodObject
1592-
# from the list in f._loops and setting its field appropriately
1592+
# from the dict in f._loops and setting its field appropriately
15931593
fmt = textwrap.dedent("""
15941594
{{
15951595
PyArray_DTypeMeta *dtype = PyArray_DTypeFromTypeNum({typenum});

numpy/_core/include/numpy/ufuncobject.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ typedef struct _tagPyUFuncObject {
223223
#if NPY_FEATURE_VERSION >= NPY_1_22_API_VERSION
224224
/* New private fields related to dispatching */
225225
void *_dispatch_cache;
226-
/* A PyListObject of `(tuple of DTypes, ArrayMethod/Promoter)` */
226+
/* Ordered dict `tuple of DTypes -> (tuple of DTypes, ArrayMethod/Promoter)` */
227227
PyObject *_loops;
228228
#endif
229229
#if NPY_FEATURE_VERSION >= NPY_2_1_API_VERSION

numpy/_core/src/umath/dispatching.cpp

Lines changed: 47 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
*
77
* - operand_DTypes: The datatypes as passed in by the user.
88
* - signature: The DTypes fixed by the user with `dtype=` or `signature=`.
9-
* - ufunc._loops: A list of all ArrayMethods and promoters, it contains
10-
* tuples `(dtypes, ArrayMethod)` or `(dtypes, promoter)`.
9+
* - ufunc._loops: Ordered dict of all ArrayMethods and promoters, mapping
10+
* `dtypes` to tuples `(dtypes, ArrayMethod)` or `(dtypes, promoter)`.
1111
* - ufunc._dispatch_cache: A cache to store previous promotion and/or
1212
* dispatching results.
1313
* - The actual arrays are used to support the old code paths where necessary.
@@ -70,8 +70,8 @@ promote_and_get_info_and_ufuncimpl(PyUFuncObject *ufunc,
7070

7171

7272
/**
73-
* Function to add a new loop to the ufunc. This mainly appends it to the
74-
* list (as it currently is just a list).
73+
* Function to add a new loop to the ufunc. This adds it to the
74+
* _loops dict keyed by the DType tuple.
7575
*
7676
* @param ufunc The universal function to add the loop to.
7777
* @param info The tuple (dtype_tuple, ArrayMethod/promoter).
@@ -114,38 +114,16 @@ PyUFunc_AddLoop(PyUFuncObject *ufunc, PyObject *info, int ignore_duplicate)
114114
return -1;
115115
}
116116

117-
if (ufunc->_loops == NULL) {
118-
ufunc->_loops = PyList_New(0);
119-
if (ufunc->_loops == NULL) {
120-
return -1;
121-
}
117+
int found = PyDict_SetDefaultRef(ufunc->_loops, DType_tuple, info, NULL);
118+
if (found < 0) {
119+
return -1;
122120
}
123-
124-
PyObject *loops = ufunc->_loops;
125-
Py_ssize_t length = PyList_Size(loops);
126-
for (Py_ssize_t i = 0; i < length; i++) {
127-
PyObject *item = PyList_GetItemRef(loops, i);
128-
PyObject *cur_DType_tuple = PyTuple_GetItem(item, 0);
129-
Py_DECREF(item);
130-
int cmp = PyObject_RichCompareBool(cur_DType_tuple, DType_tuple, Py_EQ);
131-
if (cmp < 0) {
132-
return -1;
133-
}
134-
if (cmp == 0) {
135-
continue;
136-
}
137-
if (ignore_duplicate) {
138-
return 0;
139-
}
121+
if (found && !ignore_duplicate) {
140122
PyErr_Format(PyExc_TypeError,
141123
"A loop/promoter has already been registered with '%s' for %R",
142124
ufunc_get_name_cstr(ufunc), DType_tuple);
143125
return -1;
144126
}
145-
146-
if (PyList_Append(loops, info) < 0) {
147-
return -1;
148-
}
149127
return 0;
150128
}
151129

@@ -332,7 +310,13 @@ resolve_implementation_info(PyUFuncObject *ufunc,
332310
PyObject **out_info)
333311
{
334312
int nin = ufunc->nin, nargs = ufunc->nargs;
335-
Py_ssize_t size = PySequence_Length(ufunc->_loops);
313+
int ret = -1;
314+
/* PyDict_Values returns a snapshot, safe against concurrent additions. */
315+
PyObject *loops = PyDict_Values(ufunc->_loops);
316+
if (loops == NULL) {
317+
return -1;
318+
}
319+
Py_ssize_t size = PySequence_Length(loops);
336320
PyObject *best_dtypes = NULL;
337321
PyObject *best_resolver_info = NULL;
338322

@@ -349,8 +333,7 @@ resolve_implementation_info(PyUFuncObject *ufunc,
349333

350334
for (Py_ssize_t res_idx = 0; res_idx < size; res_idx++) {
351335
/* Test all resolvers */
352-
PyObject *resolver_info = PySequence_Fast_GET_ITEM(
353-
ufunc->_loops, res_idx);
336+
PyObject *resolver_info = PySequence_Fast_GET_ITEM(loops, res_idx);
354337

355338
if (only_promoters && PyObject_TypeCheck(
356339
PyTuple_GET_ITEM(resolver_info, 1), &PyArrayMethod_Type)) {
@@ -411,7 +394,7 @@ resolve_implementation_info(PyUFuncObject *ufunc,
411394
int subclass = PyObject_IsSubclass(
412395
(PyObject *)given_dtype, (PyObject *)resolver_dtype);
413396
if (subclass < 0) {
414-
return -1;
397+
goto finish;
415398
}
416399
if (!subclass) {
417400
matches = NPY_FALSE;
@@ -509,7 +492,7 @@ resolve_implementation_info(PyUFuncObject *ufunc,
509492
"a better match is not yet implemented. This "
510493
"will pick the better (or bail) in the future.");
511494
*out_info = NULL;
512-
return -1;
495+
goto finish;
513496
}
514497

515498
if (best == -1) {
@@ -541,8 +524,9 @@ resolve_implementation_info(PyUFuncObject *ufunc,
541524
* We just redo it anyway for simplicity.)
542525
*/
543526
if (!only_promoters) {
544-
return resolve_implementation_info(ufunc,
545-
op_dtypes, NPY_TRUE, out_info);
527+
ret = resolve_implementation_info(
528+
ufunc, op_dtypes, NPY_TRUE, out_info);
529+
goto finish;
546530
}
547531
/*
548532
* If this is already the retry, we are out of luck. Promoters
@@ -564,7 +548,8 @@ resolve_implementation_info(PyUFuncObject *ufunc,
564548
Py_DECREF(given);
565549
}
566550
*out_info = NULL;
567-
return 0;
551+
ret = 0;
552+
goto finish;
568553
}
569554
else if (current_best == 0) {
570555
/* The new match is not better, continue looking. */
@@ -578,11 +563,15 @@ resolve_implementation_info(PyUFuncObject *ufunc,
578563
if (best_dtypes == NULL) {
579564
/* The non-legacy lookup failed */
580565
*out_info = NULL;
581-
return 0;
582566
}
567+
else {
568+
*out_info = best_resolver_info;
569+
}
570+
ret = 0;
583571

584-
*out_info = best_resolver_info;
585-
return 0;
572+
finish:
573+
Py_DECREF(loops);
574+
return ret;
586575
}
587576

588577

@@ -816,7 +805,7 @@ legacy_promote_using_legacy_type_resolver(PyUFuncObject *ufunc,
816805

817806

818807
/*
819-
* Note, this function returns a BORROWED references to info since it adds
808+
* Note, this function returns a BORROWED reference to info since it adds
820809
* it to the loops.
821810
*/
822811
NPY_NO_EXPORT PyObject *
@@ -845,8 +834,11 @@ add_and_return_legacy_wrapping_ufunc_loop(PyUFuncObject *ufunc,
845834
Py_DECREF(info);
846835
return NULL;
847836
}
848-
Py_DECREF(info); /* now borrowed from the ufunc's list of loops */
849-
return info;
837+
/* Loop currently borrowed from the _loops (use original if not replaced) */
838+
PyObject *result = PyDict_GetItemWithError( // noqa: borrowed-ref OK
839+
ufunc->_loops, PyTuple_GET_ITEM(info, 0));
840+
Py_DECREF(info);
841+
return result;
850842
}
851843

852844

@@ -1369,28 +1361,20 @@ get_info_no_cast(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtype,
13691361
return NULL;
13701362
}
13711363
for (int i=0; i < ndtypes; i++) {
1372-
PyTuple_SetItem(t_dtypes, i, (PyObject *)op_dtype);
1373-
}
1374-
PyObject *loops = ufunc->_loops;
1375-
Py_ssize_t length = PyList_Size(loops);
1376-
for (Py_ssize_t i = 0; i < length; i++) {
1377-
PyObject *item = PyList_GetItemRef(loops, i);
1378-
PyObject *cur_DType_tuple = PyTuple_GetItem(item, 0);
1379-
Py_DECREF(item);
1380-
int cmp = PyObject_RichCompareBool(cur_DType_tuple,
1381-
t_dtypes, Py_EQ);
1382-
if (cmp < 0) {
1383-
Py_DECREF(t_dtypes);
1384-
return NULL;
1385-
}
1386-
if (cmp == 0) {
1387-
continue;
1388-
}
1389-
/* Got the match */
1364+
Py_INCREF(op_dtype);
1365+
PyTuple_SET_ITEM(t_dtypes, i, (PyObject *)op_dtype);
1366+
}
1367+
PyObject *info;
1368+
if (PyDict_GetItemRef(ufunc->_loops, t_dtypes, &info) < 0) {
13901369
Py_DECREF(t_dtypes);
1391-
return PyTuple_GetItem(item, 1);
1370+
return NULL;
13921371
}
13931372
Py_DECREF(t_dtypes);
1373+
if (info != NULL) {
1374+
PyObject *result = PyTuple_GET_ITEM(info, 1);
1375+
Py_DECREF(info);
1376+
return result;
1377+
}
13941378
Py_RETURN_NONE;
13951379
}
13961380

numpy/_core/src/umath/ufunc_object.c

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4921,7 +4921,7 @@ PyUFunc_FromFuncAndDataAndSignatureAndIdentity(PyUFuncGenericFunction *func, voi
49214921
*/
49224922
ufunc->_dispatch_cache = NULL;
49234923
}
4924-
ufunc->_loops = PyList_New(0);
4924+
ufunc->_loops = PyDict_New();
49254925
if (ufunc->_loops == NULL) {
49264926
Py_DECREF(ufunc);
49274927
return NULL;
@@ -5249,21 +5249,18 @@ PyUFunc_RegisterLoopForType(PyUFuncObject *ufunc,
52495249
* A new-style loop should not be replaced by an old-style one.
52505250
*/
52515251
int add_new_loop = 1;
5252-
for (Py_ssize_t j = 0; j < PyList_GET_SIZE(ufunc->_loops); j++) {
5253-
PyObject *item = PyList_GET_ITEM(ufunc->_loops, j); // noqa: borrowed-ref OK
5254-
PyObject *existing_tuple = PyTuple_GET_ITEM(item, 0);
5255-
5256-
int cmp = PyObject_RichCompareBool(existing_tuple, signature_tuple, Py_EQ);
5257-
if (cmp < 0) {
5258-
goto fail;
5259-
}
5260-
if (!cmp) {
5261-
continue;
5262-
}
5263-
PyObject *registered = PyTuple_GET_ITEM(item, 1);
5264-
if (!PyObject_TypeCheck(registered, &PyArrayMethod_Type) || (
5265-
(PyArrayMethodObject *)registered)->get_strided_loop !=
5266-
&get_wrapped_legacy_ufunc_loop) {
5252+
PyObject *existing_item;
5253+
if (PyDict_GetItemRef(ufunc->_loops, signature_tuple, &existing_item) < 0) {
5254+
goto fail;
5255+
}
5256+
if (existing_item != NULL) {
5257+
PyObject *registered = PyTuple_GET_ITEM(existing_item, 1);
5258+
int not_compatible = (
5259+
!PyObject_TypeCheck(registered, &PyArrayMethod_Type) ||
5260+
((PyArrayMethodObject *)registered)->get_strided_loop !=
5261+
&get_wrapped_legacy_ufunc_loop);
5262+
Py_DECREF(existing_item);
5263+
if (not_compatible) {
52675264
PyErr_Format(PyExc_TypeError,
52685265
"A non-compatible loop was already registered for "
52695266
"ufunc %s and DTypes %S.",
@@ -5272,7 +5269,6 @@ PyUFunc_RegisterLoopForType(PyUFuncObject *ufunc,
52725269
}
52735270
/* The loop was already added */
52745271
add_new_loop = 0;
5275-
break;
52765272
}
52775273
if (add_new_loop) {
52785274
PyObject *info = add_and_return_legacy_wrapping_ufunc_loop(

numpy/_core/src/umath/wrapping_array_method.c

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ PyUFunc_AddWrappingLoop(PyObject *ufunc_obj,
235235
PyObject *wrapped_dt_tuple = NULL;
236236
PyObject *new_dt_tuple = NULL;
237237
PyArrayMethodObject *meth = NULL;
238+
PyObject *existing_info = NULL;
238239

239240
if (!PyObject_TypeCheck(ufunc_obj, &PyUFunc_Type)) {
240241
PyErr_SetString(PyExc_TypeError,
@@ -249,28 +250,19 @@ PyUFunc_AddWrappingLoop(PyObject *ufunc_obj,
249250
}
250251

251252
PyArrayMethodObject *wrapped_meth = NULL;
252-
PyObject *loops = ufunc->_loops;
253-
Py_ssize_t length = PyList_Size(loops);
254-
for (Py_ssize_t i = 0; i < length; i++) {
255-
PyObject *item = PyList_GetItemRef(loops, i);
256-
PyObject *cur_DType_tuple = PyTuple_GetItem(item, 0);
257-
Py_DECREF(item);
258-
int cmp = PyObject_RichCompareBool(cur_DType_tuple, wrapped_dt_tuple, Py_EQ);
259-
if (cmp < 0) {
260-
goto finish;
261-
}
262-
if (cmp == 0) {
263-
continue;
264-
}
265-
wrapped_meth = (PyArrayMethodObject *)PyTuple_GET_ITEM(item, 1);
266-
if (!PyObject_TypeCheck(wrapped_meth, &PyArrayMethod_Type)) {
253+
if (PyDict_GetItemRef(ufunc->_loops, wrapped_dt_tuple, &existing_info) < 0) {
254+
goto finish;
255+
}
256+
if (existing_info != NULL) {
257+
PyObject *existing_meth = PyTuple_GET_ITEM(existing_info, 1);
258+
if (!PyObject_TypeCheck(existing_meth, &PyArrayMethod_Type)) {
267259
PyErr_SetString(PyExc_TypeError,
268260
"Matching loop was not an ArrayMethod.");
269261
goto finish;
270262
}
271-
break;
263+
wrapped_meth = (PyArrayMethodObject *)existing_meth;
272264
}
273-
if (wrapped_meth == NULL) {
265+
else {
274266
PyErr_Format(PyExc_TypeError,
275267
"Did not find the to-be-wrapped loop in the ufunc with given "
276268
"DTypes. Received wrapping types: %S", wrapped_dt_tuple);
@@ -336,5 +328,6 @@ PyUFunc_AddWrappingLoop(PyObject *ufunc_obj,
336328
Py_XDECREF(wrapped_dt_tuple);
337329
Py_XDECREF(new_dt_tuple);
338330
Py_XDECREF(meth);
331+
Py_XDECREF(existing_info);
339332
return res;
340333
}

0 commit comments

Comments
 (0)