diff --git a/cuda_bindings/cuda/bindings/_internal/strdecode.py b/cuda_bindings/cuda/bindings/_internal/strdecode.py
new file mode 100644
index 00000000000..9c723fe9d97
--- /dev/null
+++ b/cuda_bindings/cuda/bindings/_internal/strdecode.py
@@ -0,0 +1,56 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
+"""Decode C strings returned by CUDA libraries with actionable failure context."""
+
+# Cap sized for the #2118 mojibake without flooding exception text.
+_PREVIEW_MAX_BYTES = 64
+
+
+def _bounded_hex_preview(data: bytes, max_bytes: int = _PREVIEW_MAX_BYTES) -> str:
+    """Return a bounded, space-separated hex rendering of ``data`` for error text.
+
+    Only the bytes before the first NUL are rendered, and of those at most
+    ``max_bytes`` (a negative cap is treated as zero). The result looks like
+    ``<N bytes; hex='f8 9a' ...(+K more; stopped at NUL@P)>``, where ``N``
+    counts the bytes before the first NUL, or all of ``data`` when there is
+    none. The ``...()`` part is present only when bytes were left out of the
+    hex or a NUL was found.
+    """
+    # Bytes after the first NUL are not part of the returned C string. The
+    # marker is explicit so truncation cannot be misread as an empty value.
+    nul = data.find(b"\x00")
+    nul_stopped = nul != -1
+    visible_end = nul if nul_stopped else len(data)
+    # Clamp the cap: a negative value would otherwise turn the slice below
+    # into "all but the last N bytes" and inflate the "+N more" count. This
+    # helper runs while a decode error is being reported, so it clamps
+    # rather than raising.
+    snippet_end = min(visible_end, max(max_bytes, 0))
+    body = data[:snippet_end].hex(" ")
+    parts = []
+    if snippet_end < visible_end:
+        parts.append(f"+{visible_end - snippet_end} more")
+    if nul_stopped:
+        parts.append(f"stopped at NUL@{nul}")
+    suffix = f" ...({'; '.join(parts)})" if parts else ""
+    return f"<{visible_end} bytes; hex='{body}'{suffix}>"
+
+
+def decode_c_str(data: bytes, api_name: str) -> str:
+    """Decode ``data`` as UTF-8, or raise ``UnicodeDecodeError`` with ``api_name`` and a bounded hex preview in ``reason``.
+
+    Internal API. ``api_name`` is trusted caller input and embedded verbatim.
+
+    The whole of ``data`` is decoded while the preview stops at the first
+    NUL, so ``start``/``end`` on the raised error index into ``data`` and may
+    point past the previewed bytes. The original error is chained as
+    ``__cause__``.
+    """
+    try:
+        return data.decode("utf-8")
+    except UnicodeDecodeError as e:
+        # Same exception type, not a subclass, so existing handlers still catch.
+        preview = _bounded_hex_preview(data)
+        reason = f"{e.reason} (returned by {api_name}; bytes={preview})"
+        raise UnicodeDecodeError(e.encoding, e.object, e.start, e.end, reason) from e
diff --git a/cuda_bindings/tests/test_strdecode.py b/cuda_bindings/tests/test_strdecode.py
new file mode 100644
index 00000000000..2d68f6cf4c1
--- /dev/null
+++ b/cuda_bindings/tests/test_strdecode.py
@@ -0,0 +1,85 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
+
+import pytest
+
+from cuda.bindings._internal.strdecode import _bounded_hex_preview, decode_c_str
+
+WSL_MOJIBAKE_PREFIX = b"\xf8\x9a\x80\x80\xaf"
+
+
+def test_valid_utf8_passthrough():
+    assert decode_c_str(b"hello world", "fakeApi") == "hello world"
+
+
+def test_invalid_bytes_raise_unicode_decode_error():
+    with pytest.raises(UnicodeDecodeError):
+        decode_c_str(WSL_MOJIBAKE_PREFIX, "nvmlSystemGetProcessName")
+
+
+def test_failure_reason_includes_api_name():
+    with pytest.raises(UnicodeDecodeError) as excinfo:
+        decode_c_str(WSL_MOJIBAKE_PREFIX, "nvmlSystemGetProcessName")
+    assert "nvmlSystemGetProcessName" in excinfo.value.reason
+
+
+def test_failure_reason_includes_hex_preview():
+    with pytest.raises(UnicodeDecodeError) as excinfo:
+        decode_c_str(WSL_MOJIBAKE_PREFIX, "nvmlSystemGetProcessName")
+    assert "f8 9a 80 80 af" in excinfo.value.reason
+
+
+def test_failure_chains_original_error():
+    with pytest.raises(UnicodeDecodeError) as excinfo:
+        decode_c_str(b"\xf8", "fakeApi")
+    assert isinstance(excinfo.value.__cause__, UnicodeDecodeError)
+
+
+def test_failure_preserves_codec_and_position():
+    with pytest.raises(UnicodeDecodeError) as excinfo:
+        decode_c_str(b"\xf8\x9a", "fakeApi")
+    assert excinfo.value.encoding == "utf-8"
+    assert excinfo.value.start == 0
+    assert excinfo.value.end == 1
+
+
+def test_preview_stops_at_first_nul():
+    preview = _bounded_hex_preview(b"\xf8\xf8\x00trailing junk")
+    # The preview is hex, so compare the whole string; a substring check for
+    # the ASCII text "trailing" could never fail.
+    assert preview == "<2 bytes; hex='f8 f8' ...(stopped at NUL@2)>"
+
+
+def test_preview_caps_long_buffers():
+    preview = _bounded_hex_preview(b"\xf8" * 200, max_bytes=8)
+    assert "f8 f8 f8 f8 f8 f8 f8 f8" in preview
+    assert "+192 more" in preview
+    assert "stopped at NUL" not in preview
+
+
+def test_preview_treats_negative_cap_as_zero():
+    preview = _bounded_hex_preview(b"\xf8\x9a\x80", max_bytes=-1)
+    assert preview == "<3 bytes; hex='' ...(+3 more)>"
+
+
+def test_preview_combines_truncation_and_nul_markers():
+    preview = _bounded_hex_preview(b"\xf8" * 20 + b"\x00rest", max_bytes=8)
+    assert "+12 more" in preview
+    assert "stopped at NUL@20" in preview
+
+
+def test_failure_preview_stops_at_embedded_nul_even_with_bad_bytes_before():
+    with pytest.raises(UnicodeDecodeError) as excinfo:
+        decode_c_str(b"\xf8\x9a\x00ignored_after_nul", "fakeApi")
+    reason = excinfo.value.reason
+    # The closing quote right after "9a" shows nothing past the NUL was rendered.
+    assert "hex='f8 9a'" in reason
+    assert "stopped at NUL@2" in reason
+
+
+def test_failure_message_stays_bounded_for_long_garbage():
+    with pytest.raises(UnicodeDecodeError) as excinfo:
+        decode_c_str(b"\xf8" * 1024, "fakeApi")
+    reason = excinfo.value.reason
+    assert "+960 more" in reason
+    assert len(reason) < 500