Fix KTO compute_kl to average KL across all batches #216

Open
EphraiemSarabamoun wants to merge 1 commit into PKU-Alignment:main from EphraiemSarabamoun:fix/kto-kl-average-across-batches

@EphraiemSarabamoun

This addresses the KTO bug raised in issue #215.

In KTOTrainer.compute_kl, the loop over self.random_dataloader computes a per-batch KL and assigns it to self.kl on every iteration without accumulating anything. Because the assignment sits inside the loop, only the last batch's KL survives. The intended value is the average over all batches drawn from the random dataloader, so the current behavior makes the KL estimate noisier than it should be and dependent on whichever batch happens to come last.

The fix keeps the change minimal. It sums the per-batch clamped KL into kl_sum, counts the batches in num_batches, and assigns self.kl = kl_sum / num_batches once, after the loop. The per-batch clamp at zero is preserved by accumulating max(kl, 0), and self.kl stays a tensor, so the downstream arithmetic in loss is unaffected.
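For readers who want to see the accumulation pattern in isolation, here is a minimal sketch, not the actual diff. The function name average_clamped_kl and its argument are hypothetical, and it assumes each per-batch KL is already available as a scalar torch tensor; in the trainer those values come from the loop over self.random_dataloader. The empty-input guard is an addition of this sketch and is not described in the PR.

```python
import torch


def average_clamped_kl(per_batch_kls):
    """Average the zero-clamped per-batch KL values (hypothetical helper)."""
    kl_sum = torch.zeros(())  # scalar tensor accumulator
    num_batches = 0
    for kl in per_batch_kls:
        # Clamp each batch at zero before accumulating, mirroring max(kl, 0).
        kl_sum = kl_sum + torch.clamp(kl, min=0)
        num_batches += 1
    if num_batches == 0:
        # Assumption of this sketch: refuse to divide by zero on an empty loader.
        raise ValueError("no batches to average over")
    # Single assignment after the loop; the result is still a tensor.
    return kl_sum / num_batches


result = average_clamped_kl(
    [torch.tensor(3.0), torch.tensor(-1.0), torch.tensor(4.0)]
)
```

Because the clamp is applied per batch, a negative batch contributes zero to the sum but still counts toward num_batches.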

Verification. align-anything is a heavy multimodal training package, so running the full trainer to exercise this path was not practical here. I verified the change by close reading and by a small standalone reproduction of the exact loop logic, using a fake dataloader with known per-batch KL values. With per-batch KLs of 2, 8, 4, and 6, the old logic returns 6, which is the last batch only, while the fixed logic returns 5, which is the correct mean. The reproduction also confirms that the zero clamp is preserved and that the result remains a torch tensor. Running the repo formatters and linter on the changed file shows no new violations on the changed lines.
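The arithmetic in that reproduction can be sketched as follows. This is not the reproduction script itself: the function names are hypothetical, and plain Python floats stand in for the tensors so the snippet runs without any dependencies.

```python
def last_batch_kl(per_batch_kls):
    """Old behavior: reassign inside the loop, so only the last batch survives."""
    kl = None
    for batch_kl in per_batch_kls:
        kl = max(batch_kl, 0.0)
    return kl


def mean_batch_kl(per_batch_kls):
    """Fixed behavior: accumulate the clamped values, divide once after the loop."""
    kl_sum = 0.0
    num_batches = 0
    for batch_kl in per_batch_kls:
        kl_sum += max(batch_kl, 0.0)
        num_batches += 1
    return kl_sum / num_batches


kls = [2.0, 8.0, 4.0, 6.0]
old = last_batch_kl(kls)  # 6.0, the last batch only
new = mean_batch_kl(kls)  # 5.0, i.e. (2 + 8 + 4 + 6) / 4
```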

In KTOTrainer.compute_kl the loop over self.random_dataloader reassigned
self.kl on every iteration, so only the last batch's KL survived instead
of an average over all batches. Accumulate the per batch clamped KL and
divide by the batch count once after the loop. Reported in issue PKU-Alignment#215.