21 changes: 21 additions & 0 deletions unsloth/__init__.py
@@ -17,6 +17,27 @@
import os, re, subprocess, inspect
import numpy as np

# Check if modules that need patching are already imported
critical_modules = ['trl', 'transformers', 'peft']
already_imported = [mod for mod in critical_modules if mod in sys.modules]

# This check is critical because Unsloth optimizes these libraries by modifying
# their code at import time. If they're imported first, the original (slower,
# more memory-intensive) implementations will be used instead of Unsloth's
# optimized versions, potentially causing OOM errors or slower training.

if already_imported:
# stacklevel=2 makes the warning point to the user's import line rather than this library code,
# showing exactly where to fix the import order in their script
warnings.warn(
f"WARNING: Unsloth should be imported before {', '.join(already_imported)} "
f"to ensure all optimizations are applied. Your code may run slower or encounter "
f"memory issues without these optimizations.\n\n"
f"Please restructure your imports with 'import unsloth' at the top of your file.",
stacklevel = 2,
)
pass

# Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so
# enabling it will require much more work, so we have to prioritize. Please understand!
# We do have a beta version, which you can contact us about!
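As a usage note for the import-order check added in this hunk: the warning only asks users to reorder their imports. A minimal sketch of the intended ordering in a training script (the script contents are illustrative, not part of this PR):

# Hypothetical user script showing the import order the warning asks for.
# Importing unsloth first lets it patch trl / transformers / peft at import time.
import unsloth
from unsloth import FastLanguageModel

# The patched libraries can now be imported safely afterwards.
from transformers import TrainingArguments
from trl import SFTTrainer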
80 changes: 62 additions & 18 deletions unsloth/models/mistral.py
@@ -35,6 +35,7 @@
MistralSdpaAttention = MistralAttention
MistralFlashAttention2 = MistralAttention
pass
from unsloth_zoo.utils import Version, _get_dtype


def MistralAttention_fast_forward(
@@ -183,6 +184,7 @@ def MistralForCausalLM_fast_forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
num_logits_to_keep: Optional[int] = 0,
logits_to_keep: Optional[int] = 0,
*args, **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:

@@ -194,7 +196,6 @@ def MistralForCausalLM_fast_forward(
elif q_len <= sliding_window:
causal_mask = xformers.attn_bias.LowerTriangularMask()
else:
# Fix from https://github.com/Rypo
causal_mask = xformers.attn_bias.BlockDiagonalCausalMask\
.from_seqlens([q_len]*bsz)\
.make_local_attention(window_size = sliding_window)
@@ -219,20 +220,35 @@ def MistralForCausalLM_fast_forward(
)
else:
outputs = self.model(
input_ids=input_ids,
causal_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
input_ids = input_ids,
causal_mask = causal_mask,
attention_mask = attention_mask,
position_ids = position_ids,
past_key_values = past_key_values,
inputs_embeds = inputs_embeds,
use_cache = use_cache,
output_attentions = output_attentions,
output_hidden_states = output_hidden_states,
return_dict = return_dict,
)
pass

hidden_states = outputs[0]

# If we are in GRPO mode, return raw hidden states
if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1":
num_logits_to_keep = max(num_logits_to_keep, logits_to_keep)
if num_logits_to_keep != 0:
hidden_states = hidden_states[:, -num_logits_to_keep:, :]
return CausalLMOutputWithPast(
loss = None,
logits = hidden_states,
past_key_values = outputs.past_key_values,
hidden_states = outputs.hidden_states,
attentions = outputs.attentions,
)
pass

bsz, q_len, hd = hidden_states.shape
lm_head = self.lm_head.weight
if bsz == 1 and q_len == 1:
@@ -241,9 +257,37 @@ def MistralForCausalLM_fast_forward(
elif num_logits_to_keep != 0:
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :].to(lm_head.dtype))
else:
RETURN_LOGITS = os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1"
# For bsz * q_len <= 1024, normal (non-fused) Unsloth uses less VRAM!
if bsz * q_len <= 1024: RETURN_LOGITS = True

if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None:
n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None)
logit_softcapping = getattr(self.config, "final_logit_softcapping", 0)
loss = fused_linear_cross_entropy(
hidden_states = hidden_states,
lm_weight = lm_head,
labels = labels,
num_items_in_batch = n_items,
logit_softcapping = logit_softcapping,
)

if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output

output = CausalLMOutputWithPast(
loss = loss,
logits = EMPTY_LOGITS,
past_key_values = outputs.past_key_values,
hidden_states = outputs.hidden_states,
attentions = outputs.attentions,
)
return output
pass
logits = self.lm_head(hidden_states.to(lm_head.dtype))
pass
logits = logits.to(self.config.torch_dtype)
logits = logits.to(_get_dtype(self.config.torch_dtype))

loss = None
if labels is not None:
@@ -252,7 +296,7 @@ def MistralForCausalLM_fast_forward(
# Fixes https://github.com/unslothai/unsloth/issues/10
self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda:0")
pass

shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
loss = fast_cross_entropy_loss(
logits = shift_logits,
@@ -266,11 +310,11 @@ def MistralForCausalLM_fast_forward(
return (loss,) + output if loss is not None else output

return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
loss = loss,
logits = logits,
past_key_values = outputs.past_key_values,
hidden_states = outputs.hidden_states,
attentions = outputs.attentions,
)
pass
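A note on the UNSLOTH_RETURN_HIDDEN_STATES branch added above: when the variable is "1", the forward pass skips lm_head entirely and returns the (optionally trimmed) hidden states in the logits field, which GRPO-style trainers can project themselves. A self-contained sketch restating just that gating logic (the helper name is made up for illustration):

import os
import torch

def maybe_return_hidden_states(hidden_states, logits_to_keep = 0):
    # Hypothetical restatement of the GRPO branch: trim to the last
    # `logits_to_keep` positions and hand back raw hidden states untouched.
    if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1":
        if logits_to_keep != 0:
            hidden_states = hidden_states[:, -logits_to_keep:, :]
        return hidden_states
    return None

hidden = torch.randn(2, 8, 16)                     # (bsz, q_len, hidden_size)
os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
print(maybe_return_hidden_states(hidden, logits_to_keep = 1).shape)  # torch.Size([2, 1, 16])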

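The logits/loss branch in the hunk above chooses between materializing full logits and taking the fused linear cross-entropy path. The decision can be summarized as a small predicate; the sketch below restates it with hypothetical function and argument names, taking kernel availability and label presence as inputs:

import os

def use_fused_cross_entropy(bsz, q_len, has_labels, has_fused_kernel = True):
    # Hypothetical restatement of the branch above: return True when the fused
    # linear cross-entropy path should run instead of materializing logits.
    return_logits = os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1"
    if bsz * q_len <= 1024:
        # Small batches: computing logits normally uses less VRAM than the fused path.
        return_logits = True
    return (not return_logits) and has_fused_kernel and has_labels

print(use_fused_cross_entropy(bsz = 2, q_len = 4096, has_labels = True))   # True  -> fused loss
print(use_fused_cross_entropy(bsz = 1, q_len = 512,  has_labels = True))   # False -> plain logits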
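Lastly, the cast at the end of the hunk now routes self.config.torch_dtype through _get_dtype from unsloth_zoo. A likely motivation is that HF configs can carry torch_dtype as a string such as "bfloat16", which Tensor.to() rejects; the stand-in below only illustrates that intent and is not the actual unsloth_zoo implementation:

import torch

def get_dtype_sketch(dtype):
    # Hypothetical stand-in for unsloth_zoo.utils._get_dtype: resolve strings like
    # "bfloat16" or "torch.float16" to torch.dtype objects; pass real dtypes through.
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str):
        return getattr(torch, dtype.replace("torch.", ""))
    raise TypeError(f"Cannot resolve a torch.dtype from {dtype!r}")

print(get_dtype_sketch("bfloat16"))      # torch.bfloat16
print(get_dtype_sketch(torch.float16))   # torch.float16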