Fix: GRPO with Mistral and importing #1831
Merged
This pull request solves two issues:

**Mistral GRPO**

The `MistralForCausalLM_fast_forward` method does not account for the `UNSLOTH_RETURN_HIDDEN_STATES` env variable, which breaks the GRPO trainer: the current implementation always returns logits, which is incompatible with `GRPOTrainer`, so running training raises an exception. I've implemented the necessary functionality following the example of `CausalLM_fast_forward` in llama.py. With that fix, training runs as expected.
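For context, the fix mirrors the branch that `CausalLM_fast_forward` in llama.py already has: when the `UNSLOTH_RETURN_HIDDEN_STATES` env variable is set, the forward pass skips the LM head and hands the hidden states back to the trainer. Below is a minimal sketch of that pattern, not the actual patch; the function name, the `"1"` sentinel value, and the bare `lm_head` call are assumptions, and the real method also handles caching, attention masks, and the rest of the forward signature.

```python
import os
import torch

def forward_tail_sketch(lm_head: torch.nn.Linear, hidden_states: torch.Tensor) -> torch.Tensor:
    """Sketch of the final step of a fast forward pass (names are illustrative)."""
    # GRPO computes per-token log-probs from the hidden states itself, so
    # when this flag is set the forward pass must return them unprojected.
    if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1":
        return hidden_states
    # Default behaviour: project hidden states to vocabulary logits.
    return lm_head(hidden_states)
```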
**Importing order**

Importing trl (or any other dependency that itself imports trl) before unsloth makes the trainer patching lose its effect. I spent a whole day trying to understand why my code either broke or OOMed mid-training (while Colab notebooks with essentially the same functionality worked just fine) until I had an aha moment and realized the import order was the issue. To avoid such cases, I've added a simple check to the top-level `__init__.py` that warns users if unsloth was imported after trl.
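A sketch of the kind of guard this adds, assuming it simply inspects `sys.modules` at import time; the module list and the warning text here are illustrative, not the exact wording in the PR.

```python
import sys
import warnings

# Modules whose trainers unsloth patches; if any of them is already
# imported, the patches were applied too late to take effect.
_MODULES_TO_PATCH = ("trl", "transformers", "peft")  # assumed list

_imported_early = [name for name in _MODULES_TO_PATCH if name in sys.modules]
if _imported_early:
    warnings.warn(
        "unsloth was imported after "
        + ", ".join(_imported_early)
        + "; import unsloth first so its trainer patches take effect.",
        stacklevel=2,
    )
```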