torchao/prototype/gptq/api.py (80 changes: 60 additions & 20 deletions)

@@ -263,43 +263,85 @@ def gptq_quantize(H: torch.Tensor, W_t: torch.Tensor, config: GPTQConfig):
    H = torch.linalg.cholesky(H, upper=True)
    Hinv = H
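
[Editorial note: the damping and inversion steps are elided above this hunk, so what `Hinv` holds is an assumption here. In the reference GPTQ implementation, `Hinv` is the upper Cholesky factor of the inverse Hessian, prepared roughly as in this minimal sketch (the `prepare_hinv` name and `damp` value are hypothetical):]

```python
import torch

def prepare_hinv(H: torch.Tensor, damp: float = 1e-2) -> torch.Tensor:
    # Dampen the diagonal so the factorization is well conditioned.
    H = H + damp * H.diagonal().mean() * torch.eye(
        H.shape[0], device=H.device, dtype=H.dtype
    )
    # Invert H via its Cholesky factor, then take the upper Cholesky factor
    # of the inverse; the update loop below indexes into this factor.
    L = torch.linalg.cholesky(H)
    return torch.linalg.cholesky(torch.cholesky_inverse(L), upper=True)
```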

+    # GPTQ update loop:
+    #
+    # W_t (below)
+    #
+    # |------------------ K --------------------|
+    # |---B1----|---B2----| ...
+    # |-G1-|-G2-|-G1-|-G2-| ...
+    # |-----------------------------------------|
+    # N | 0 1 2 3 | 4 5 6 7 | ...
+    # |   ...
+    # |-----------------------------------------|
+    #
+    # 1. start with W_t, with shape [N, K]
+    #    * B1, ..., BN are chunks of size [N, B], where B is a hyperparameter of GPTQ
+    #    * G1, ..., GN are chunks of size [N, G], where G is the group_size of the quantization recipe
+    #
+    # 2. triple for loop, with every loop chunking along the K dimension:

[Contributor review comment on the block above]: is this the same thing the code is doing? If so, would it be better to just do inline comments for the code (including refactoring the code itself to be easier to read by adding helper functions, etc.)?

+    #
+    # for B_cur in (B1, ..., BN):
+    #     # B_cur is of shape (N, B)
+    #     # Hinv_cur corresponding to B_cur is of shape (B, B)
+    #
+    #     for G_cur in (G1, ..., GN):
+    #         # G_cur is of shape (N, group_size)
+    #         # Initialize qparams for all of G_cur; this freezes the quantization
+    #         # grid for G_cur. The rest of this for loop will iteratively optimize
+    #         # the quantized weight values.
+    #
+    #         for k in range(G_k_start - B_cur_k_start, G_k_end - B_cur_k_start):
+    #             # k is relative to the start of B_cur
+    #             w_t = B_cur[:, k]
+    #             w_t_qdq = quant_dequant(w_t, base_config, qparams)
+    #             err1 = (w_t - w_t_qdq) / Hinv_cur[k, k]
+    #             # propagate errors to remaining columns in B_cur
+    #             B_cur[:, k:] -= err1.matmul(Hinv_cur[k, k:])
+    #             B_cur_Err1[:, k] = err1.flatten()
+    #
+    #     # batch propagate errors for all remaining blocks in W_t
+    #     W_t[:, B_cur_k_end:] -= B_cur_Err1.matmul(Hinv[B_cur_k_start:B_cur_k_end, B_cur_k_end:])
+    #

    group_qparams = []
-    for W_t_quantize_block, k_block_start in zip(
+    for B_cur, B_cur_k_start in zip(
        torch.split(W_t, gptq_quantize_block_size, dim=1),
        range(0, columns, gptq_quantize_block_size),
    ):
-        k_block_end = min(k_block_start + gptq_quantize_block_size, columns)
-        Err1 = torch.zeros_like(W_t_quantize_block, dtype=H.dtype)
-        Hinv_quantize_block = Hinv[k_block_start:k_block_end, k_block_start:k_block_end]
+        B_cur_k_end = min(B_cur_k_start + gptq_quantize_block_size, columns)
+        B_cur_Err1 = torch.zeros_like(B_cur, dtype=H.dtype)
+        Hinv_cur = Hinv[B_cur_k_start:B_cur_k_end, B_cur_k_start:B_cur_k_end]

        # If we are doing per-row quantization, the group_size is equal to the number
        # of columns and this loop will only run once. Otherwise, for per-group
        # quantization, we iterate through the block one group at a time.
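
[Note: with hypothetical sizes columns = 256, gptq_quantize_block_size = 128, and group_size = 32, each B_cur spans 128 columns and this loop runs four times per block, with G_k_start taking 0, 32, 64, 96 in the first block; with per-row quantization, group_size equals the number of columns, so the loop body runs once per block.]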
-        for k_group_start in range(k_block_start, k_block_end, group_size):
-            k_group_end = min(k_group_start + group_size, k_block_end)
+        for G_k_start in range(B_cur_k_start, B_cur_k_end, group_size):
+            G_k_end = min(G_k_start + group_size, B_cur_k_end)

            # We only need to calculate initial qparams for the group once
-            if k_group_start % group_size == 0:
+            if G_k_start % group_size == 0:
                if isinstance(base_config, Int4WeightOnlyConfig):
                    _, scale, zero_point = int4_row_quantize_zp(
-                        W_t_quantize_block[
+                        B_cur[
                            :,
-                            k_group_start - k_block_start : k_group_end - k_block_start,
+                            G_k_start - B_cur_k_start : G_k_end - B_cur_k_start,
                        ],
                        group_size,
                    )
                    group_qparams.append((scale, zero_point))
                elif isinstance(base_config, Int8WeightOnlyConfig):
                    quantized_tensor = Int8Tensor.from_hp(
-                        W_t_quantize_block[
+                        B_cur[
                            :,
-                            k_group_start - k_block_start : k_group_end - k_block_start,
+                            G_k_start - B_cur_k_start : G_k_end - B_cur_k_start,
                        ],
                        base_config.granularity,
                    )

            # Quantize each column and propagate errors to subsequent columns
-            for k in range(k_group_start - k_block_start, k_group_end - k_block_start):
-                w_t = W_t_quantize_block[:, k].unsqueeze(1)
+            for k in range(G_k_start - B_cur_k_start, G_k_end - B_cur_k_start):
+                # k is relative to the start of B_cur
+                w_t = B_cur[:, k].unsqueeze(1)
                if isinstance(base_config, Int4WeightOnlyConfig):
                    q = _int4_row_quantize_zp_precomputed_qparams(
                        w_t, scale, zero_point, group_size
@@ -313,16 +355,14 @@ def gptq_quantize(H: torch.Tensor, W_t: torch.Tensor, config: GPTQConfig):
                    )
                dq = q.dequantize(output_dtype=torch.float)

-                err1 = (w_t - dq) / Hinv_quantize_block[k, k]
-                W_t_quantize_block[:, k:] -= err1.matmul(
-                    Hinv_quantize_block[k, k:].unsqueeze(0)
-                )
-                Err1[:, k] = err1.flatten()
+                err1 = (w_t - dq) / Hinv_cur[k, k]
+                B_cur[:, k:] -= err1.matmul(Hinv_cur[k, k:].unsqueeze(0))
+                B_cur_Err1[:, k] = err1.flatten()

        # Lazy Batch-Updates: we process B columns at a time with the local updates
        # above. Once a block is fully processed, perform global updates to H^-1 and W
        # using batched versions of the error-propagation equations.
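
[Shape check for the batched update below: with W_t of shape [N, K] and a full-width block, B_cur_Err1 is [N, B] and Hinv[B_cur_k_start:B_cur_k_end, B_cur_k_end:] is [B, K - B_cur_k_end], so the matmul produces exactly the [N, K - B_cur_k_end] tail of W_t that is being corrected.]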
-        W_t[:, k_block_end:] -= Err1.matmul(
-            Hinv[k_block_start:k_block_end, k_block_end:]
+        W_t[:, B_cur_k_end:] -= B_cur_Err1.matmul(
+            Hinv[B_cur_k_start:B_cur_k_end, B_cur_k_end:]
        )

    torch.cuda.synchronize()
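
[Editorial note: the following is a self-contained toy sketch of the block-wise error-propagation loop documented in the comment block above. It is illustrative only: it substitutes a naive round-to-grid quantizer for the int4/int8 paths, omits the per-group qparams bookkeeping, and all sizes and names are hypothetical.]

```python
import torch

torch.manual_seed(0)
N, K, B = 4, 8, 4  # toy sizes: rows, columns, GPTQ block width (hypothetical)
W_t = torch.randn(N, K)

# Toy damped Hessian from random calibration data, then the upper Cholesky
# factor of its inverse (the role Hinv plays in the code above).
X = torch.randn(32, K)
H = X.T @ X / 32 + 1e-2 * torch.eye(K)
Hinv = torch.linalg.cholesky(
    torch.cholesky_inverse(torch.linalg.cholesky(H)), upper=True
)

def quant_dequant(w: torch.Tensor) -> torch.Tensor:
    # Naive stand-in quantizer: round to a fixed 0.1 grid.
    return torch.round(w / 0.1) * 0.1

for B_cur_k_start in range(0, K, B):
    B_cur_k_end = min(B_cur_k_start + B, K)
    B_cur = W_t[:, B_cur_k_start:B_cur_k_end]  # a view, so updates land in W_t
    B_cur_Err1 = torch.zeros_like(B_cur)
    Hinv_cur = Hinv[B_cur_k_start:B_cur_k_end, B_cur_k_start:B_cur_k_end]

    for k in range(B_cur_k_end - B_cur_k_start):  # column index within B_cur
        w_t = B_cur[:, k].unsqueeze(1)
        err1 = (w_t - quant_dequant(w_t)) / Hinv_cur[k, k]
        # Local update: push this column's quantization error onto the
        # remaining columns of the current block.
        B_cur[:, k:] -= err1.matmul(Hinv_cur[k, k:].unsqueeze(0))
        B_cur_Err1[:, k] = err1.flatten()

    # Lazy batch update: propagate the whole block's errors to all later
    # columns of W_t in one matmul.
    W_t[:, B_cur_k_end:] -= B_cur_Err1.matmul(
        Hinv[B_cur_k_start:B_cur_k_end, B_cur_k_end:]
    )

# The quantized weights are quant_dequant applied to the adjusted W_t.
print(quant_dequant(W_t))
```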