diff --git a/cuda_core/cuda/core/_module.pyx b/cuda_core/cuda/core/_module.pyx index 96ac65effc..b2dd5b4d8d 100644 --- a/cuda_core/cuda/core/_module.pyx +++ b/cuda_core/cuda/core/_module.pyx @@ -7,6 +7,7 @@ from __future__ import annotations from libc.stddef cimport size_t from collections import namedtuple +from os import fsencode, fspath, PathLike from cuda.core._device import Device from cuda.core._launch_config cimport LaunchConfig @@ -606,7 +607,13 @@ cdef class ObjectCode: self._h_library = LibraryHandle() # Empty handle self._code_type = str(code_type) - self._module = module + + if isinstance(module, (str, bytes, bytearray)): + self._module = module + elif isinstance(module, PathLike): + self._module = fspath(module) + else: + self._module = module self._sym_map = {} if symbol_mapping is None else symbol_mapping self._name = name if name else "" @@ -620,14 +627,15 @@ cdef class ObjectCode: return ObjectCode._reduce_helper, (self._module, self._code_type, self._name, self._sym_map) @staticmethod - def from_cubin(module: bytes | str, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode: + def from_cubin(module: bytes | str | PathLike, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode: """Create an :class:`ObjectCode` instance from an existing cubin. Parameters ---------- - module : Union[bytes, str] + module : Union[bytes, str, os.PathLike] Either a bytes object containing the in-memory cubin to load, or - a file path string pointing to the on-disk cubin to load. + a file path object (or its string representation) pointing to the + on-disk cubin to load. name : Optional[str] A human-readable identifier representing this code object. symbol_mapping : Optional[dict] @@ -638,14 +646,15 @@ cdef class ObjectCode: return ObjectCode._init(module, ObjectCodeFormatType.CUBIN, name=name, symbol_mapping=symbol_mapping) @staticmethod - def from_ptx(module: bytes | str, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode: + def from_ptx(module: bytes | str | PathLike, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode: """Create an :class:`ObjectCode` instance from an existing PTX. Parameters ---------- - module : Union[bytes, str] + module : Union[bytes, str, os.PathLike] Either a bytes object containing the in-memory ptx code to load, or - a file path string pointing to the on-disk ptx file to load. + a file path object (or its string representation) pointing to the + on-disk ptx file to load. name : Optional[str] A human-readable identifier representing this code object. symbol_mapping : Optional[dict] @@ -656,14 +665,15 @@ cdef class ObjectCode: return ObjectCode._init(module, ObjectCodeFormatType.PTX, name=name, symbol_mapping=symbol_mapping) @staticmethod - def from_ltoir(module: bytes | str, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode: + def from_ltoir(module: bytes | str | PathLike, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode: """Create an :class:`ObjectCode` instance from an existing LTOIR. Parameters ---------- - module : Union[bytes, str] - Either a bytes object containing the in-memory ltoir code to load, or - a file path string pointing to the on-disk ltoir file to load. + module : Union[bytes, str, os.PathLike] + Either a bytes object containing the in-memory ltoir code to load, + or a file path object (or its string representation) pointing to the + on-disk ltoir file to load. name : Optional[str] A human-readable identifier representing this code object. symbol_mapping : Optional[dict] @@ -674,14 +684,15 @@ cdef class ObjectCode: return ObjectCode._init(module, ObjectCodeFormatType.LTOIR, name=name, symbol_mapping=symbol_mapping) @staticmethod - def from_fatbin(module: bytes | str, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode: + def from_fatbin(module: bytes | str | PathLike, *, name: str = "", symbol_mapping: dict | None = None) -> ObjectCode: """Create an :class:`ObjectCode` instance from an existing fatbin. Parameters ---------- - module : Union[bytes, str] + module : Union[bytes, str, os.PathLike] Either a bytes object containing the in-memory fatbin to load, or - a file path string pointing to the on-disk fatbin to load. + or a file path object (or its string representation) pointing to the + on-disk fatbin to load. name : Optional[str] A human-readable identifier representing this code object. symbol_mapping : Optional[dict] @@ -733,21 +744,22 @@ cdef class ObjectCode: if self._h_library: return 0 module = self._module - assert_type_str_or_bytes_like(module) cdef bytes path_bytes if isinstance(module, str): path_bytes = module.encode() self._h_library = create_library_handle_from_file(path_bytes) - if not self._h_library: - HANDLE_RETURN(get_last_error()) - return 0 - if isinstance(module, (bytes, bytearray)): + elif isinstance(module, (bytes, bytearray)): self._h_library = create_library_handle_from_data(module) - if not self._h_library: - HANDLE_RETURN(get_last_error()) - return 0 - raise_code_path_meant_to_be_unreachable() - return -1 + elif isinstance(module, PathLike): + path_bytes = fsencode(module) + self._h_library = create_library_handle_from_file(path_bytes) + else: + assert_type_str_or_bytes_like(module) + raise_code_path_meant_to_be_unreachable() + return -1 + if not self._h_library: + HANDLE_RETURN(get_last_error()) + return 0 def get_kernel(self, name) -> Kernel: """Return the :obj:`~_module.Kernel` of a specified name from this object code. diff --git a/cuda_core/tests/test_module.py b/cuda_core/tests/test_module.py index 5c994b6f5e..3a438f825a 100644 --- a/cuda_core/tests/test_module.py +++ b/cuda_core/tests/test_module.py @@ -52,6 +52,11 @@ def cuda12_4_prerequisite_check(): return binding_version() >= (12, 0, 0) and driver_version() >= (12, 4, 0) +@pytest.fixture(name="convert_path", params=[str, lambda p: p], ids=["str", "path"]) +def convert_path_to_arg(request): + return request.param + + def test_kernel_attributes_init_disabled(): with pytest.raises(RuntimeError, match=r"^KernelAttributes cannot be instantiated directly\."): cuda.core._module.KernelAttributes() # Ensure back door is locked. @@ -231,14 +236,15 @@ def test_object_code_load_ptx(get_saxpy_kernel_ptx): mod_obj.get_kernel("saxpy") # force loading -def test_object_code_load_ptx_from_file(get_saxpy_kernel_ptx, tmp_path): +def test_object_code_load_ptx_from_file(get_saxpy_kernel_ptx, tmp_path, convert_path): ptx, mod = get_saxpy_kernel_ptx sym_map = mod.symbol_mapping assert isinstance(ptx, bytes) ptx_file = tmp_path / "test.ptx" ptx_file.write_bytes(ptx) - mod_obj = ObjectCode.from_ptx(str(ptx_file), symbol_mapping=sym_map) - assert mod_obj.code == str(ptx_file) + arg = convert_path(ptx_file) + mod_obj = ObjectCode.from_ptx(arg, symbol_mapping=sym_map) + assert mod_obj.code == str(arg) assert mod_obj.code_type == "ptx" if not _can_load_generated_ptx(): pytest.skip("PTX version too new for current driver") @@ -255,15 +261,16 @@ def test_object_code_load_cubin(get_saxpy_kernel_cubin): mod.get_kernel("saxpy") # force loading -def test_object_code_load_cubin_from_file(get_saxpy_kernel_cubin, tmp_path): +def test_object_code_load_cubin_from_file(get_saxpy_kernel_cubin, tmp_path, convert_path): _, mod = get_saxpy_kernel_cubin cubin = mod.code sym_map = mod.symbol_mapping assert isinstance(cubin, bytes) cubin_file = tmp_path / "test.cubin" cubin_file.write_bytes(cubin) - mod = ObjectCode.from_cubin(str(cubin_file), symbol_mapping=sym_map) - assert mod.code == str(cubin_file) + arg = convert_path(cubin_file) + mod = ObjectCode.from_cubin(arg, symbol_mapping=sym_map) + assert mod.code == str(arg) mod.get_kernel("saxpy") # force loading @@ -286,15 +293,16 @@ def test_object_code_load_ltoir(get_saxpy_kernel_ltoir): mod_obj.get_kernel("saxpy") -def test_object_code_load_ltoir_from_file(get_saxpy_kernel_ltoir, tmp_path): +def test_object_code_load_ltoir_from_file(get_saxpy_kernel_ltoir, tmp_path, convert_path): mod = get_saxpy_kernel_ltoir ltoir = mod.code sym_map = mod.symbol_mapping assert isinstance(ltoir, bytes) ltoir_file = tmp_path / "test.ltoir" ltoir_file.write_bytes(ltoir) - mod_obj = ObjectCode.from_ltoir(str(ltoir_file), symbol_mapping=sym_map) - assert mod_obj.code == str(ltoir_file) + arg = convert_path(ltoir_file) + mod_obj = ObjectCode.from_ltoir(arg, symbol_mapping=sym_map) + assert mod_obj.code == str(arg) assert mod_obj.code_type == "ltoir" # ltoir doesn't support kernel retrieval directly as it's used for linking @@ -310,13 +318,14 @@ def test_object_code_load_fatbin(get_saxpy_fatbin): @nvfatbin_available -def test_object_code_load_fatbin_from_file(get_saxpy_fatbin, tmp_path): +def test_object_code_load_fatbin_from_file(get_saxpy_fatbin, tmp_path, convert_path): fatbin, sym_map = get_saxpy_fatbin assert isinstance(fatbin, bytes) fatbin_file = tmp_path / "test.fatbin" fatbin_file.write_bytes(fatbin) - mod_obj = ObjectCode.from_fatbin(str(fatbin_file), symbol_mapping=sym_map) - assert mod_obj.code == str(fatbin_file) + arg = convert_path(fatbin_file) + mod_obj = ObjectCode.from_fatbin(arg, symbol_mapping=sym_map) + assert mod_obj.code == str(arg) assert mod_obj.code_type == "fatbin" mod_obj.get_kernel("saxpy") # force loading