Skip to content
11 changes: 11 additions & 0 deletions cuda_core/cuda/core/_cpp/resource_handles.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,17 @@ StreamHandle create_stream_handle_with_owner(CUstream stream, PyObject* owner) {
return StreamHandle(box, &box->resource);
}

void py_object_user_object_destroy(void* py_object) noexcept {
if (!py_object) {
return;
}
GILAcquireGuard gil;
if (!gil.acquired()) {
return;
}
Py_DECREF(reinterpret_cast<PyObject*>(py_object));
}

// Look up the context handle stored alongside a stream handle.
//
// Returns a default-constructed ContextHandle when `h` is empty; otherwise
// returns the `h_context` member of the box that backs `h`.
ContextHandle get_stream_context(const StreamHandle& h) noexcept {
    if (!h) {
        return ContextHandle{};
    }
    return get_box(h)->h_context;
}
Expand Down
20 changes: 20 additions & 0 deletions cuda_core/cuda/core/_cpp/resource_handles.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,10 @@ StreamHandle create_stream_handle_ref(CUstream stream);
// The owner is responsible for keeping the stream's context alive.
StreamHandle create_stream_handle_with_owner(CUstream stream, PyObject* owner);

// Destroy a Python-backed CUDA user object by decref'ing it when safe.
// If Python is finalized or finalizing, the object is intentionally leaked.
void py_object_user_object_destroy(void* py_object) noexcept;

// Return the context dependency associated with a stream handle, if any.
ContextHandle get_stream_context(const StreamHandle& h) noexcept;

Expand Down Expand Up @@ -662,6 +666,22 @@ inline std::intptr_t as_intptr(const FileDescriptorHandle& h) noexcept {
extern "C" int _Py_IsFinalizing(void);
#endif

// Best-effort probe for interpreter shutdown.
//
// In CPython this is not a hard guarantee: finalization can begin after this
// returns false but before a later PyGILState_Ensure() or other Python C-API
// call.
//
// If that race is lost on a non-finalizer thread, CPython's behavior is
// version-dependent: on older supported versions (3.10-3.13) it may abruptly
// terminate the current thread (historically via PyThread_exit_thread(),
// without normal C++ unwinding), while on newer versions (3.14+) it may hang
// the thread until process exit.
//
// We still use this check because the policy in this layer is to avoid Python
// work once shutdown is underway and accept an intentional leak or skipped
// Python conversion in that edge case rather than add more complex deferral
// machinery.
inline bool py_is_finalizing() noexcept {
#if PY_VERSION_HEX >= 0x030D0000
return Py_IsFinalizing();
Expand Down
1 change: 1 addition & 0 deletions cuda_core/cuda/core/_resource_handles.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ cdef StreamHandle create_stream_handle(
const ContextHandle& h_ctx, unsigned int flags, int priority) except+ nogil
cdef StreamHandle create_stream_handle_ref(cydriver.CUstream stream) except+ nogil
cdef StreamHandle create_stream_handle_with_owner(cydriver.CUstream stream, object owner) except+ nogil
cdef void py_object_user_object_destroy(void* py_object) noexcept nogil
cdef ContextHandle get_stream_context(const StreamHandle& h) noexcept nogil
cdef StreamHandle get_legacy_stream() except+ nogil
cdef StreamHandle get_per_thread_stream() except+ nogil
Expand Down
2 changes: 2 additions & 0 deletions cuda_core/cuda/core/_resource_handles.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core":
cydriver.CUstream stream) except+ nogil
StreamHandle create_stream_handle_with_owner "cuda_core::create_stream_handle_with_owner" (
cydriver.CUstream stream, object owner) except+ nogil
void py_object_user_object_destroy "cuda_core::py_object_user_object_destroy" (
void* py_object) noexcept nogil
ContextHandle get_stream_context "cuda_core::get_stream_context" (
const StreamHandle& h) noexcept nogil
StreamHandle get_legacy_stream "cuda_core::get_legacy_stream" () except+ nogil
Expand Down
6 changes: 3 additions & 3 deletions cuda_core/cuda/core/graph/_graph_node.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,22 @@ from cuda.core.graph._subclasses cimport (
from cuda.core._resource_handles cimport (
EventHandle,
GraphHandle,
KernelHandle,
GraphNodeHandle,
KernelHandle,
as_cu,
as_intptr,
as_py,
create_graph_handle_ref,
create_graph_node_handle,
graph_node_get_graph,
invalidate_graph_node,
py_object_user_object_destroy,
)
from cuda.core._utils.cuda_utils cimport HANDLE_RETURN, _parse_fill_value

from cuda.core.graph._utils cimport (
_attach_host_callback_to_graph,
_attach_user_object,
_py_host_destructor,
)

import weakref
Expand Down Expand Up @@ -650,7 +650,7 @@ cdef inline KernelNode GN_launch(GraphNode self, LaunchConfig conf, Kernel ker,
if kernel_args is not None:
Py_INCREF(kernel_args)
_attach_user_object(as_cu(h_graph), <void*>kernel_args,
<cydriver.CUhostFn>_py_host_destructor)
<cydriver.CUhostFn>py_object_user_object_destroy)

return _registered(KernelNode._create_with_params(
create_graph_node_handle(new_node, h_graph),
Expand Down
2 changes: 0 additions & 2 deletions cuda_core/cuda/core/graph/_utils.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ from cuda.bindings cimport cydriver

cdef bint _is_py_host_trampoline(cydriver.CUhostFn fn) noexcept nogil

cdef void _py_host_destructor(void* data) noexcept with gil

cdef void _attach_user_object(
cydriver.CUgraph graph, void* ptr,
cydriver.CUhostFn destroy) except *
Expand Down
13 changes: 3 additions & 10 deletions cuda_core/cuda/core/graph/_utils.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,14 @@ from libc.string cimport memcpy as c_memcpy

from cuda.bindings cimport cydriver

from cuda.core._resource_handles cimport py_object_user_object_destroy
from cuda.core._utils.cuda_utils cimport HANDLE_RETURN


cdef extern from "Python.h":
void _py_decref "Py_DECREF" (void*)


cdef void _py_host_trampoline(void* data) noexcept with gil:
    # `data` is an opaque pointer to a Python callable. Acquire the GIL (via
    # the `with gil` signature), view the pointer as a Python object, and
    # invoke it with no arguments. The call's result is discarded.
    callback = <object>data
    callback()


cdef void _py_host_destructor(void* data) noexcept with gil:
_py_decref(data)


cdef bint _is_py_host_trampoline(cydriver.CUhostFn fn) noexcept nogil:
    # True exactly when `fn` is the address of `_py_host_trampoline`, compared
    # as a CUhostFn function pointer.
    cdef cydriver.CUhostFn trampoline = <cydriver.CUhostFn>_py_host_trampoline
    return fn == trampoline

Expand Down Expand Up @@ -73,7 +66,7 @@ cdef void _attach_host_callback_to_graph(
fn_pyobj = <void*>fn
_attach_user_object(
graph, fn_pyobj,
<cydriver.CUhostFn>_py_host_destructor)
<cydriver.CUhostFn>py_object_user_object_destroy)
out_fn[0] = <cydriver.CUhostFn><uintptr_t>ct.cast(
fn, ct.c_void_p).value

Expand Down Expand Up @@ -103,4 +96,4 @@ cdef void _attach_host_callback_to_graph(
out_user_data[0] = fn_pyobj
_attach_user_object(
graph, fn_pyobj,
<cydriver.CUhostFn>_py_host_destructor)
<cydriver.CUhostFn>py_object_user_object_destroy)