
Conversation

oKatanaaa (Contributor)

This pull request solves two issues:

  1. The MistralForCausalLM_fast_forward method does not account for the 'UNSLOTH_RETURN_HIDDEN_STATES' env variable, which breaks the GRPO trainer.
  2. unsloth must be imported before trl/transformers/peft since it patches their source code (specifically trl). If any of those libraries is imported first, the patching has no effect, which (a) may break training in some cases and (b) can lead to OOM due to the absence of some optimizations.

Mistral GRPO

The current implementation of the MistralForCausalLM_fast_forward method always returns logits, which is incompatible with GRPOTrainer. Running training leads to the following exception:

TorchRuntimeError: Failed running call_function <built-in method matmul of type object at 0x74393685f240>(*(GradTrackingTensor(lvl=1, value=
    FakeTensor(..., device='cuda:0', size=(1, s0, 32001), dtype=torch.bfloat16,
               requires_grad=True)
), GradTrackingTensor(lvl=1, value=
    FakeTensor(..., device='cuda:0', size=(4096, 32001), dtype=torch.bfloat16)
)), **{}):
a and b must have same reduction dim, but got [s0, 32001] X [4096, 32001].

from user code:
   File "/workspace/research/grpo/unsloth_compiled_cache/UnslothGRPOTrainer.py", line 100, in accumulate_chunk
    (chunk_grad_input,), (chunk_loss, (unscaled_loss, chunk_completion_length, chunk_mean_kl,)) = torch.func.grad_and_value(
  File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/apis.py", line 442, in wrapper
    return eager_transforms.grad_and_value_impl(
  File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 48, in fn
    return f(*args, **kwargs)
  File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/eager_transforms.py", line 1407, in grad_and_value_impl
    output = func(*args, **kwargs)
  File "/workspace/research/grpo/unsloth_compiled_cache/UnslothGRPOTrainer.py", line 80, in compute_loss
    new_logits = torch.matmul(new_hidden_states, lm_head.t())

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

I've implemented the necessary functionality following the example of CausalLM_fast_forward in llama.py. With that fix, training runs as expected.
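For readers unfamiliar with the pattern, here is a minimal sketch of what the CausalLM_fast_forward-style behaviour looks like; the function and argument names are illustrative, not the exact code in this PR:

```python
# Illustrative sketch only: when UNSLOTH_RETURN_HIDDEN_STATES is set, the forward
# pass should return hidden states instead of projecting to vocabulary logits,
# because GRPOTrainer performs the lm_head matmul itself (see the traceback above).
import os
import torch

def fast_forward_tail(hidden_states: torch.Tensor, lm_head: torch.Tensor) -> torch.Tensor:
    if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1":
        # GRPO path: hand back hidden states of shape (batch, seq, hidden_dim).
        return hidden_states
    # Default path: project to logits of shape (batch, seq, vocab_size).
    return torch.matmul(hidden_states, lm_head.t())
```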


Import order

Importing trl (or any other dependency that pulls in trl) before unsloth makes the trainer patching lose its effect.

I spent a whole day trying to understand why my code either broke or OOMed mid-training (while Colab notebooks with essentially the same functionality worked just fine) until I had an AHA! moment and realized the import order was the issue.

To avoid such cases, I've added a simple import check in the top-level __init__.py that warns users if unsloth was imported after trl.
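Conceptually the check only needs to look at sys.modules at import time; a minimal sketch (illustrative, not the exact code added in this PR) would be:

```python
# Illustrative sketch: warn if trl/transformers/peft were imported before unsloth,
# since unsloth's patches cannot take effect on modules that are already loaded.
import sys
import warnings

_already_imported = [m for m in ("trl", "transformers", "peft") if m in sys.modules]
if _already_imported:
    warnings.warn(
        "unsloth should be imported before " + ", ".join(_already_imported)
        + " so that its patches take effect; otherwise training may break or run out of memory."
    )
```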

oKatanaaa (Contributor, Author)

Found the issue that describes exactly this problem: #1790. This PR solves it.

danielhanchen (Contributor)

Oh fantastic!! Weird how I missed this!

@danielhanchen danielhanchen changed the base branch from main to nightly February 25, 2025 23:04
@danielhanchen danielhanchen changed the base branch from nightly to main February 25, 2025 23:05
@danielhanchen danielhanchen merged commit 42cbe1f into unslothai:main Feb 25, 2025