Skip to content
Open
60 changes: 36 additions & 24 deletions cuda_core/cuda/core/_module.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ from __future__ import annotations
from libc.stddef cimport size_t

from collections import namedtuple
from os import fsencode, fspath, PathLike

from cuda.core._device import Device
from cuda.core._launch_config cimport LaunchConfig
Expand Down Expand Up @@ -606,7 +607,13 @@ cdef class ObjectCode:
self._h_library = LibraryHandle() # Empty handle

self._code_type = str(code_type)
self._module = module

if isinstance(module, (str, bytes, bytearray)):
self._module = module
elif isinstance(module, PathLike):
self._module = fspath(module)
else:
self._module = module
self._sym_map = {} if symbol_mapping is None else symbol_mapping
self._name = name if name else ""

Expand All @@ -620,14 +627,15 @@ cdef class ObjectCode:
return ObjectCode._reduce_helper, (self._module, self._code_type, self._name, self._sym_map)

@staticmethod
def from_cubin(module: bytes | str, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode:
def from_cubin(module: bytes | str | PathLike, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode:
"""Create an :class:`ObjectCode` instance from an existing cubin.

Parameters
----------
module : Union[bytes, str]
module : Union[bytes, str, os.PathLike]
Either a bytes object containing the in-memory cubin to load, or
a file path string pointing to the on-disk cubin to load.
a file path object (or its string representation) pointing to the
on-disk cubin to load.
name : Optional[str]
A human-readable identifier representing this code object.
symbol_mapping : Optional[dict]
Expand All @@ -638,14 +646,15 @@ cdef class ObjectCode:
return ObjectCode._init(module, ObjectCodeFormatType.CUBIN, name=name, symbol_mapping=symbol_mapping)

@staticmethod
def from_ptx(module: bytes | str, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode:
def from_ptx(module: bytes | str | PathLike, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode:
"""Create an :class:`ObjectCode` instance from an existing PTX.

Parameters
----------
module : Union[bytes, str]
module : Union[bytes, str, os.PathLike]
Either a bytes object containing the in-memory ptx code to load, or
a file path string pointing to the on-disk ptx file to load.
a file path object (or its string representation) pointing to the
on-disk ptx file to load.
name : Optional[str]
A human-readable identifier representing this code object.
symbol_mapping : Optional[dict]
Expand All @@ -656,14 +665,15 @@ cdef class ObjectCode:
return ObjectCode._init(module, ObjectCodeFormatType.PTX, name=name, symbol_mapping=symbol_mapping)

@staticmethod
def from_ltoir(module: bytes | str, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode:
def from_ltoir(module: bytes | str | PathLike, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode:
"""Create an :class:`ObjectCode` instance from an existing LTOIR.

Parameters
----------
module : Union[bytes, str]
Either a bytes object containing the in-memory ltoir code to load, or
a file path string pointing to the on-disk ltoir file to load.
module : Union[bytes, str, os.PathLike]
Either a bytes object containing the in-memory ltoir code to load,
or a file path object (or its string representation) pointing to the
on-disk ltoir file to load.
name : Optional[str]
A human-readable identifier representing this code object.
symbol_mapping : Optional[dict]
Expand All @@ -674,14 +684,15 @@ cdef class ObjectCode:
return ObjectCode._init(module, ObjectCodeFormatType.LTOIR, name=name, symbol_mapping=symbol_mapping)

@staticmethod
def from_fatbin(module: bytes | str, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode:
def from_fatbin(module: bytes | str | PathLike, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode:
"""Create an :class:`ObjectCode` instance from an existing fatbin.

Parameters
----------
module : Union[bytes, str]
module : Union[bytes, str, os.PathLike]
Either a bytes object containing the in-memory fatbin to load, or
a file path string pointing to the on-disk fatbin to load.
a file path object (or its string representation) pointing to the
on-disk fatbin to load.
name : Optional[str]
A human-readable identifier representing this code object.
symbol_mapping : Optional[dict]
Expand Down Expand Up @@ -733,21 +744,22 @@ cdef class ObjectCode:
if self._h_library:
return 0
module = self._module
assert_type_str_or_bytes_like(module)
cdef bytes path_bytes
if isinstance(module, str):
path_bytes = module.encode()
self._h_library = create_library_handle_from_file(<const char*>path_bytes)
if not self._h_library:
HANDLE_RETURN(get_last_error())
return 0
if isinstance(module, (bytes, bytearray)):
elif isinstance(module, (bytes, bytearray)):
self._h_library = create_library_handle_from_data(<const void*><char*>module)
if not self._h_library:
HANDLE_RETURN(get_last_error())
return 0
raise_code_path_meant_to_be_unreachable()
return -1
elif isinstance(module, PathLike):
path_bytes = fsencode(module)
self._h_library = create_library_handle_from_file(<const char*>path_bytes)
else:
assert_type_str_or_bytes_like(module)
raise_code_path_meant_to_be_unreachable()
return -1
if not self._h_library:
HANDLE_RETURN(get_last_error())
return 0

def get_kernel(self, name) -> Kernel:
"""Return the :obj:`~_module.Kernel` of a specified name from this object code.
Expand Down
33 changes: 21 additions & 12 deletions cuda_core/tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ def cuda12_4_prerequisite_check():
return binding_version() >= (12, 0, 0) and driver_version() >= (12, 4, 0)


@pytest.fixture(name="convert_path", params=[str, lambda p: p], ids=["str", "path"])
def convert_path_to_arg(request):
    """Return the callable selected by the current fixture parametrization.

    The fixture is parametrized with two callables: ``str`` (id ``"str"``)
    and an identity function that hands its argument back unchanged
    (id ``"path"``). A test requesting ``convert_path`` therefore runs once
    per callable.
    """
    converter = request.param
    return converter


def test_kernel_attributes_init_disabled():
with pytest.raises(RuntimeError, match=r"^KernelAttributes cannot be instantiated directly\."):
cuda.core._module.KernelAttributes() # Ensure back door is locked.
Expand Down Expand Up @@ -231,14 +236,15 @@ def test_object_code_load_ptx(get_saxpy_kernel_ptx):
mod_obj.get_kernel("saxpy<double>") # force loading


def test_object_code_load_ptx_from_file(get_saxpy_kernel_ptx, tmp_path):
def test_object_code_load_ptx_from_file(get_saxpy_kernel_ptx, tmp_path, convert_path):
ptx, mod = get_saxpy_kernel_ptx
sym_map = mod.symbol_mapping
assert isinstance(ptx, bytes)
ptx_file = tmp_path / "test.ptx"
ptx_file.write_bytes(ptx)
mod_obj = ObjectCode.from_ptx(str(ptx_file), symbol_mapping=sym_map)
assert mod_obj.code == str(ptx_file)
arg = convert_path(ptx_file)
mod_obj = ObjectCode.from_ptx(arg, symbol_mapping=sym_map)
assert mod_obj.code == str(arg)
assert mod_obj.code_type == "ptx"
if not _can_load_generated_ptx():
pytest.skip("PTX version too new for current driver")
Expand All @@ -255,15 +261,16 @@ def test_object_code_load_cubin(get_saxpy_kernel_cubin):
mod.get_kernel("saxpy<double>") # force loading


def test_object_code_load_cubin_from_file(get_saxpy_kernel_cubin, tmp_path):
def test_object_code_load_cubin_from_file(get_saxpy_kernel_cubin, tmp_path, convert_path):
_, mod = get_saxpy_kernel_cubin
cubin = mod.code
sym_map = mod.symbol_mapping
assert isinstance(cubin, bytes)
cubin_file = tmp_path / "test.cubin"
cubin_file.write_bytes(cubin)
mod = ObjectCode.from_cubin(str(cubin_file), symbol_mapping=sym_map)
assert mod.code == str(cubin_file)
arg = convert_path(cubin_file)
mod = ObjectCode.from_cubin(arg, symbol_mapping=sym_map)
assert mod.code == str(arg)
mod.get_kernel("saxpy<double>") # force loading


Expand All @@ -286,15 +293,16 @@ def test_object_code_load_ltoir(get_saxpy_kernel_ltoir):
mod_obj.get_kernel("saxpy<float>")


def test_object_code_load_ltoir_from_file(get_saxpy_kernel_ltoir, tmp_path):
def test_object_code_load_ltoir_from_file(get_saxpy_kernel_ltoir, tmp_path, convert_path):
mod = get_saxpy_kernel_ltoir
ltoir = mod.code
sym_map = mod.symbol_mapping
assert isinstance(ltoir, bytes)
ltoir_file = tmp_path / "test.ltoir"
ltoir_file.write_bytes(ltoir)
mod_obj = ObjectCode.from_ltoir(str(ltoir_file), symbol_mapping=sym_map)
assert mod_obj.code == str(ltoir_file)
arg = convert_path(ltoir_file)
mod_obj = ObjectCode.from_ltoir(arg, symbol_mapping=sym_map)
assert mod_obj.code == str(arg)
assert mod_obj.code_type == "ltoir"
# ltoir doesn't support kernel retrieval directly as it's used for linking

Expand All @@ -310,13 +318,14 @@ def test_object_code_load_fatbin(get_saxpy_fatbin):


@nvfatbin_available
def test_object_code_load_fatbin_from_file(get_saxpy_fatbin, tmp_path):
def test_object_code_load_fatbin_from_file(get_saxpy_fatbin, tmp_path, convert_path):
fatbin, sym_map = get_saxpy_fatbin
assert isinstance(fatbin, bytes)
fatbin_file = tmp_path / "test.fatbin"
fatbin_file.write_bytes(fatbin)
mod_obj = ObjectCode.from_fatbin(str(fatbin_file), symbol_mapping=sym_map)
assert mod_obj.code == str(fatbin_file)
arg = convert_path(fatbin_file)
mod_obj = ObjectCode.from_fatbin(arg, symbol_mapping=sym_map)
assert mod_obj.code == str(arg)
assert mod_obj.code_type == "fatbin"
mod_obj.get_kernel("saxpy<double>") # force loading

Expand Down
Loading