support dpo orpo and simpo #1631
base: main
Conversation
Summary of Changes

Hello @poryfly, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the training capabilities by adding support for advanced preference optimization techniques such as DPO, ORPO, and SIMPO. It achieves this by integrating a specialized DPO trainer and by refactoring the core accelerator logic to allow for more flexible model management and device placement. The internal configuration for training modes has also been generalized to better support these diverse training approaches.
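For context, DPO trains a policy directly on preference pairs without a separate reward model, while SimPO drops the reference model entirely and adds a target margin. A minimal pure-Python sketch of the per-pair losses (illustrative only, not the PR's implementation; real trainers operate on batched, length-normalized log-probabilities):

```python
import math

def _log_sigmoid(x):
    """Numerically plain -- fine for a sketch -- log of the sigmoid."""
    return -math.log(1.0 + math.exp(-x))

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """DPO: -log sigmoid(beta * (policy margin - reference margin))."""
    policy_margin = policy_chosen_logp - policy_rejected_logp
    ref_margin = ref_chosen_logp - ref_rejected_logp
    return -_log_sigmoid(beta * (policy_margin - ref_margin))

def simpo_loss(policy_chosen_logp, policy_rejected_logp, beta=2.0, gamma=0.5):
    """SimPO: reference-free margin with a target offset gamma.

    (SimPO normalizes log-probs by sequence length; omitted here for brevity.)
    """
    return -_log_sigmoid(beta * (policy_chosen_logp - policy_rejected_logp) - gamma)
```

When the policy's preference margin equals the reference's, the DPO loss sits at `log 2`; widening the policy margin on the chosen response pushes it below that.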
Code Review
This pull request introduces support for DPO by adding a new `KTDporainer` and refactors some existing code for better reusability and generality, such as moving `KAccelerator` to a shared utility module and changing the training mode from `sft` to a more general `train`. While the refactoring is a good step, the new `KTDporainer` has several issues that need to be addressed: a typo in the class name, using `print` instead of a logger, a risky `try...except Exception: pass` block that can hide errors, and a critical bug in `post_training_step` that fails to return the correctly processed tensor. Addressing these points will significantly improve the code's quality, maintainability, and correctness.
```python
def post_training_step(self, loss):
    if loss.device != self.args.device:
        ret = loss.to(self.args.device, non_blocking=True)
    return loss
```
There is a bug in this method. The loss tensor is moved to self.args.device and the result is stored in ret. However, the original loss tensor is returned, so the tensor is not actually on the correct device when returned. The value of ret is never used.
Suggested change:

```diff
 def post_training_step(self, loss):
     if loss.device != self.args.device:
-        ret = loss.to(self.args.device, non_blocking=True)
+        loss = loss.to(self.args.device, non_blocking=True)
     return loss
```
```python
try:
    self.accelerator.state.device_ids = [0]
    self.accelerator.state.num_processes = 1
    self.accelerator.state.num_gpus = 1
except Exception:
    pass
```
Using a broad `except Exception: pass` is risky because it silently swallows all errors, which can make debugging extremely difficult. If an exception is expected, it should be caught specifically; at the very least, the exception should be logged to provide visibility into potential issues. The hardcoded values for `device_ids`, `num_processes`, and `num_gpus` also assume a single-GPU setup and could cause problems in a distributed environment.
Suggested change:

```diff
 try:
     self.accelerator.state.device_ids = [0]
     self.accelerator.state.num_processes = 1
     self.accelerator.state.num_gpus = 1
-except Exception:
-    pass
+except Exception as e:
+    logger.warning("Could not override accelerator state. This may be expected in a multi-GPU environment. Error: %s", e)
```
kt-sft/ktransformers/dpo/trainer.py (Outdated)

```python
logger = logging.get_logger(__name__)


class KTDporainer(DPOTrainer):
```
```python
        self.model.save_pretrained(output_dir)

    def _move_model_to_device(self, model, device):
        print("[KTrainer] Due to the placement feature in KTransformers, skip moving model to", device)
```
Using `print` for logging is generally discouraged in a library. It's better to use the `logging` module, which is already imported as `logger`. This provides more control over log levels and output streams, and allows users of the library to configure logging as they see fit.
Suggested change:

```diff
-        print("[KTrainer] Due to the placement feature in KTransformers, skip moving model to", device)
+        logger.info("[KTrainer] Due to the placement feature in KTransformers, skip moving model to %s", device)
```
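To illustrate the practical difference: messages routed through `logging` can be filtered by level and redirected by the application, whereas `print` always writes unconditionally to stdout. A small self-contained sketch (the logger name and messages are illustrative, not from the PR):

```python
import io
import logging

# Route this demo logger to an in-memory stream so its behavior is observable.
stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.setFormatter(logging.Formatter("%(levelname)s:%(message)s"))

logger = logging.getLogger("ktrainer_demo")  # hypothetical logger name
logger.addHandler(handler)
logger.propagate = False          # keep the demo out of the root logger
logger.setLevel(logging.INFO)

device = "cuda:0"
logger.info("[KTrainer] skip moving model to %s", device)  # kept: INFO >= INFO
logger.debug("verbose placement details")                  # dropped: DEBUG < INFO

output = stream.getvalue()
```

Callers who want the verbose message simply lower the level to `DEBUG`; no library code changes are needed.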
…-gpu optimizer rule
What does this PR do?
#1610
Fixes # (issue)
Before submitting