Fix KTO compute_kl to average KL across all batches #216
Open
EphraiemSarabamoun wants to merge 1 commit into
In KTOTrainer.compute_kl, the loop over self.random_dataloader reassigned self.kl on every iteration, so only the last batch's KL survived instead of an average over all batches. Accumulate the per-batch clamped KL and divide by the batch count once after the loop. Reported in issue PKU-Alignment#215.
This addresses the KTO bug raised in issue #215.
In KTOTrainer.compute_kl, the loop over self.random_dataloader computes a per-batch KL and assigns it to self.kl on every iteration without accumulating. Because the assignment sits inside the loop, only the last batch's KL survives. The intended value is the average over all batches drawn from the random dataloader, so the current behavior makes the KL estimate noisy and dependent on whichever batch happens to come last.
The fix keeps the change minimal. It sums the per-batch clamped KL into kl_sum, counts the batches in num_batches, and assigns self.kl = kl_sum / num_batches once after the loop. The per-batch clamp at zero is preserved by accumulating max(kl, 0), and self.kl stays a tensor, so the downstream arithmetic in the loss computation is unaffected.
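The accumulate-then-divide pattern described above can be sketched in isolation. This is a minimal illustration, not the trainer code: the function name mean_clamped_kl is hypothetical, plain floats stand in for the tensors the real method works with, and the sketch assumes the dataloader yields at least one batch.

```python
def mean_clamped_kl(per_batch_kls):
    """Average per-batch KL values, clamping each one at zero first.

    Hypothetical stand-in for the loop body in compute_kl; floats
    replace tensors so the sketch has no dependencies.
    """
    kl_sum = 0.0
    num_batches = 0
    for kl in per_batch_kls:
        # Clamp each batch at zero before accumulating, as the fix does.
        kl_sum += max(kl, 0.0)
        num_batches += 1
    # Divide once, after the loop. Assumes num_batches > 0.
    return kl_sum / num_batches


# A negative per-batch KL contributes zero to the sum but still counts
# as a batch in the denominator: (3 + 0 + 4) / 3.
result = mean_clamped_kl([3.0, -1.0, 4.0])
```

Clamping per batch rather than clamping the final mean matters whenever some batch is negative: in the example, clamping only the mean would give (3 - 1 + 4) / 3 = 2 instead of 7 / 3.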
Verification. align-anything is a heavy multimodal training package, so running the full trainer to exercise this path was not practical here. I verified the change by close reading and by a small standalone reproduction of the exact loop logic, using a fake dataloader with known per-batch KL values. With per-batch KLs of 2, 8, 4, and 6, the old logic returns 6 (the last batch only), while the fixed logic returns 5 (the correct mean). The reproduction also confirms that the zero clamp is preserved and that the result remains a torch tensor. Running the repo formatters and linter on the changed file shows no new violations on the changed lines.
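A reproduction along those lines might look like the sketch below. It is not the script used for the verification above: the function names old_compute_kl and fixed_compute_kl and the list fake_kls are hypothetical, and plain floats stand in for tensors, so it does not check the tensor-type claim.

```python
def old_compute_kl(per_batch_kls):
    """Old behavior: the result is reassigned on every iteration."""
    kl = None
    for batch_kl in per_batch_kls:
        kl = max(batch_kl, 0.0)  # overwrites the previous batch's value
    return kl


def fixed_compute_kl(per_batch_kls):
    """Fixed behavior: accumulate clamped values, divide once at the end."""
    kl_sum = 0.0
    num_batches = 0
    for batch_kl in per_batch_kls:
        kl_sum += max(batch_kl, 0.0)
        num_batches += 1
    return kl_sum / num_batches


# Stand-in for a dataloader whose batches have known KL values.
fake_kls = [2.0, 8.0, 4.0, 6.0]

old_result = old_compute_kl(fake_kls)      # 6.0, the last batch only
fixed_result = fixed_compute_kl(fake_kls)  # 5.0, the mean of all four
print(old_result, fixed_result)
```

Reordering fake_kls changes old_result but leaves fixed_result at 5.0, which is the order dependence the fix removes.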