Skip to content

Commit b3e0db2

Browse files
Xia-Weiwen (with Copilot)
authored
Remove check_cpu_version and check_xpu_version helpers (#4211)
* Rename check_cpu_version and check_xpu_version
* Remove the helpers
* Add is_on_device and not_on_device
* Update torchao/utils.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update

---------

Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 42ff8ed commit b3e0db2

5 files changed

Lines changed: 17 additions & 33 deletions

File tree

test/quantization/test_quant_primitives.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@
3535
groupwise_affine_quantize_tensor_from_qparams,
3636
)
3737
from torchao.utils import (
38-
check_cpu_version,
39-
check_xpu_version,
38+
_is_device,
4039
get_current_accelerator_device,
4140
is_fbcode,
4241
)
@@ -152,9 +151,9 @@ def _groupwise_affine_quantize_tensor_from_qparams(
152151
.reshape_as(w)
153152
)
154153

155-
if (not (check_cpu_version(w.device))) and (not (check_xpu_version(w.device))):
154+
if (not (_is_device("cpu", w.device))) and (not (_is_device("xpu", w.device))):
156155
w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
157-
if check_xpu_version(w.device):
156+
if _is_device("xpu", w.device):
158157
w_int4x8 = (w_int4x8[::, 1::2] << 4 | w_int4x8[::, ::2]).to(torch.uint8)
159158

160159
return w_int4x8
@@ -708,11 +707,11 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
708707
if zero_point_domain == ZeroPointDomain.INT:
709708
zeros = torch.randint(0, 15, (10, 2), dtype=torch.int32)
710709
input_tmp = input
711-
if (not (check_cpu_version(input.device))) and (
712-
not (check_xpu_version(input.device))
710+
if (not (_is_device("cpu", input.device))) and (
711+
not (_is_device("xpu", input.device))
713712
):
714713
input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
715-
if check_xpu_version(input.device):
714+
if _is_device("xpu", input.device):
716715
input_tmp = (input[::, 1::2] << 4 | input[::, ::2]).to(torch.uint8)
717716
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
718717
input_tmp, scales, zeros, n_bit, groupsize, zero_point_domain

torchao/kernel/intmm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torch._dynamo import is_compiling as dynamo_is_compiling
1111
from torch._higher_order_ops.out_dtype import out_dtype
1212

13-
from torchao.utils import check_cpu_version, torch_version_at_least
13+
from torchao.utils import _is_device, torch_version_at_least
1414

1515
logger = logging.getLogger(__name__)
1616
logger.addHandler(logging.NullHandler())
@@ -192,7 +192,7 @@ def int_scaled_matmul(
192192
assert 1 == scales1.size(1)
193193
assert scales1.is_contiguous()
194194

195-
if check_cpu_version(scales1.device):
195+
if _is_device("cpu", scales1.device):
196196
return _int_scaled_matmul_cpu(a, b, scales1)
197197

198198
scales1 = scales1.expand((M, N))

torchao/prototype/hqq/hqq_tinygemm_linear.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from hqq.core.utils import * # noqa: F401, F403
1717
from torch import Tensor, nn
1818

19-
from torchao.utils import _is_device, check_cpu_version
19+
from torchao.utils import _is_device
2020

2121

2222
class HQQLinearTorchWeightOnlyInt4(torch.nn.Module):
@@ -166,7 +166,7 @@ def process_hqq_quants(self, W_q, meta):
166166
W_q_torch, scales_torch, zeros_torch = self.hqq_quants_to_torch_quants(
167167
W_q=W_q, scales=scales, zeros=zeros, shape=shape, nbits=self.nbits
168168
)
169-
if check_cpu_version(W_q.device):
169+
if _is_device("cpu", W_q.device):
170170
self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
171171
W_q_torch, self.inner_k_tiles
172172
)
@@ -241,7 +241,7 @@ def pack_scales_and_zeros(self, scales, zeros):
241241
def matmul(self, x):
242242
origin_x_size = x.size()
243243
x = x.reshape(-1, origin_x_size[-1])
244-
if check_cpu_version(x.device):
244+
if _is_device("cpu", x.device):
245245
c = torch.ops.aten._weight_int4pack_mm_for_cpu(
246246
x, self.weight_int4pack, self.groupsize, self.scales_and_zeros
247247
)

torchao/quantization/utils.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,7 @@
2525
dequantize_affine,
2626
quantize_affine,
2727
)
28-
from torchao.utils import (
29-
check_cpu_version,
30-
check_xpu_version,
31-
)
28+
from torchao.utils import _is_device
3229

3330
from .granularity import (
3431
Granularity,
@@ -462,11 +459,11 @@ def groupwise_affine_quantize_tensor_from_qparams(
462459
quant_max,
463460
)
464461
if w.shape[-1] > 1:
465-
if (not (check_cpu_version(int_data.device))) and (
466-
not (check_xpu_version(int_data.device))
462+
if (not (_is_device("cpu", int_data.device))) and (
463+
not (_is_device("xpu", int_data.device))
467464
):
468465
int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
469-
if check_xpu_version(int_data.device):
466+
if _is_device("xpu", int_data.device):
470467
int_data = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
471468
return int_data
472469

@@ -483,7 +480,7 @@ def groupwise_affine_dequantize_tensor_from_qparams(
483480
assert w_int4x8.dim() == 2
484481
# need to handle single column case so check for dtype/size from groupwise_affine_quantize_tensor_from_qparams path
485482
if (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1) and not (
486-
check_cpu_version(w_int4x8.device)
483+
_is_device("cpu", w_int4x8.device)
487484
):
488485
data = w_int4x8.to(torch.int32)
489486
high_bits = data >> 4
@@ -493,7 +490,7 @@ def groupwise_affine_dequantize_tensor_from_qparams(
493490
dtype=torch.int32,
494491
device=w_int4x8.device,
495492
)
496-
if not (check_xpu_version(w_int4x8.device)):
493+
if not (_is_device("xpu", w_int4x8.device)):
497494
w_int32[::, ::2] = high_bits
498495
w_int32[::, 1::2] = low_bits
499496
else:

torchao/utils.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,18 +1169,6 @@ def is_cuda_version_at_least(major: int, minor: int) -> bool:
11691169
return (cuda_major, cuda_minor) >= (major, minor)
11701170

11711171

1172-
def check_cpu_version(device, version="2.6.0"):
1173-
if isinstance(device, torch.device):
1174-
device = device.type
1175-
return device == "cpu" and torch_version_at_least(version)
1176-
1177-
1178-
def check_xpu_version(device, version="2.8.0"):
1179-
if isinstance(device, torch.device):
1180-
device = device.type
1181-
return device == "xpu" and torch_version_at_least(version)
1182-
1183-
11841172
def ceil_div(a, b):
11851173
return (a + b - 1) // b
11861174

0 commit comments

Comments (0)