Skip to content

Commit a921a02

Browse files
committed
[mx] RCEIL rounding for non-CUDA backends
The USE_PTX flag in MXFP8 Triton kernels was gated by `not IS_ROCM`, which assumed that only CUDA and ROCm backends exist. This would cause errors on other backends, such as XPU.

Signed-off-by: Ula Golowicz <urszula.golowicz@intel.com>
1 parent 62212e4 commit a921a02

2 files changed

Lines changed: 16 additions & 5 deletions

File tree

torchao/prototype/mx_formats/kernels.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
is_MI350,
2525
is_ROCM,
2626
is_sm_at_least_100,
27+
is_XPU,
2728
torch_version_at_least,
2829
)
2930

@@ -458,9 +459,14 @@ def triton_mxfp8_dequant_dim0(
458459
_triton_kernels_available = (
459460
torch_version_at_least("2.7.0")
460461
and has_triton()
461-
and torch.cuda.is_available()
462-
and (is_sm_at_least_100() and is_cuda_version_at_least(12, 8))
463-
or (is_ROCM() and is_MI350())
462+
and (
463+
(
464+
torch.cuda.is_available()
465+
and (is_sm_at_least_100() and is_cuda_version_at_least(12, 8))
466+
)
467+
or (is_ROCM() and is_MI350())
468+
or is_XPU()
469+
)
464470
)
465471

466472
if _triton_kernels_available:
@@ -469,6 +475,7 @@ def triton_mxfp8_dequant_dim0(
469475
from torch.library import triton_op, wrap_triton
470476

471477
IS_ROCM = tl.constexpr(is_ROCM())
478+
IS_XPU = tl.constexpr(is_XPU())
472479

473480
@triton.jit
474481
def _triton_calculate_scale_rceil(x, axis, USE_PTX: tl.constexpr):
@@ -686,7 +693,7 @@ def to_mxfp8_dim1_kernel(
686693
col_scale_r, col_scale_e8m0_r = _triton_calculate_scale_rceil(
687694
x_block_abs_t_r,
688695
axis=1,
689-
USE_PTX=not IS_ROCM,
696+
USE_PTX=(not IS_ROCM and not IS_XPU),
690697
)
691698
else:
692699
tl.static_assert(SCALING_MODE == "floor")
@@ -796,7 +803,7 @@ def to_mxfp8_dim0_kernel(
796803
scale_fp32_r, scale_e8m0_r = _triton_calculate_scale_rceil(
797804
x_block_abs_r,
798805
axis=1,
799-
USE_PTX=not IS_ROCM,
806+
USE_PTX=(not IS_ROCM and not IS_XPU),
800807
)
801808
else:
802809
tl.static_assert(SCALING_MODE == "floor")

torchao/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,6 +1245,10 @@ def check_cpu_version(device, version="2.6.0"):
12451245
return device == "cpu" and torch_version_at_least(version)
12461246

12471247

1248+
def is_XPU():
1249+
return hasattr(torch, "xpu") and torch.xpu.is_available()
1250+
1251+
12481252
def check_xpu_version(device, version="2.8.0"):
12491253
if isinstance(device, torch.device):
12501254
device = device.type

0 commit comments

Comments (0)