|
35 | 35 | groupwise_affine_quantize_tensor_from_qparams, |
36 | 36 | ) |
37 | 37 | from torchao.utils import ( |
38 | | - check_cpu_version, |
39 | | - check_xpu_version, |
| 38 | + _is_device, |
40 | 39 | get_current_accelerator_device, |
41 | 40 | is_fbcode, |
42 | 41 | ) |
@@ -152,9 +151,9 @@ def _groupwise_affine_quantize_tensor_from_qparams( |
152 | 151 | .reshape_as(w) |
153 | 152 | ) |
154 | 153 |
|
155 | | - if (not (check_cpu_version(w.device))) and (not (check_xpu_version(w.device))): |
| 154 | + if (not (_is_device("cpu", w.device))) and (not (_is_device("xpu", w.device))): |
156 | 155 | w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8) |
157 | | - if check_xpu_version(w.device): |
| 156 | + if _is_device("xpu", w.device): |
158 | 157 | w_int4x8 = (w_int4x8[::, 1::2] << 4 | w_int4x8[::, ::2]).to(torch.uint8) |
159 | 158 |
|
160 | 159 | return w_int4x8 |
@@ -708,11 +707,11 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self): |
708 | 707 | if zero_point_domain == ZeroPointDomain.INT: |
709 | 708 | zeros = torch.randint(0, 15, (10, 2), dtype=torch.int32) |
710 | 709 | input_tmp = input |
711 | | - if (not (check_cpu_version(input.device))) and ( |
712 | | - not (check_xpu_version(input.device)) |
| 710 | + if (not (_is_device("cpu", input.device))) and ( |
| 711 | + not (_is_device("xpu", input.device)) |
713 | 712 | ): |
714 | 713 | input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8) |
715 | | - if check_xpu_version(input.device): |
| 714 | + if _is_device("xpu", input.device): |
716 | 715 | input_tmp = (input[::, 1::2] << 4 | input[::, ::2]).to(torch.uint8) |
717 | 716 | w_bf16 = groupwise_affine_dequantize_tensor_from_qparams( |
718 | 717 | input_tmp, scales, zeros, n_bit, groupsize, zero_point_domain |
|
0 commit comments