Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions unsloth/models/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,6 @@ def from_pretrained(
assert (dtype is None or dtype == torch.float16 or dtype == torch.bfloat16
or dtype == torch.float32)

if use_gradient_checkpointing == "unsloth":
patch_unsloth_smart_gradient_checkpointing(dtype = dtype)

if fast_inference:
if importlib.util.find_spec("vllm") is None:
raise ImportError(
Expand Down Expand Up @@ -369,6 +366,9 @@ def from_pretrained(
)
pass

if use_gradient_checkpointing == "unsloth":
patch_unsloth_smart_gradient_checkpointing(dtype = dtype)

# Check if this is a local model, since the tokenizer gets overwritten
if os.path.exists(os.path.join(old_model_name, "tokenizer_config.json")) and \
os.path.exists(os.path.join(old_model_name, "tokenizer.json")) and \
Expand Down
38 changes: 33 additions & 5 deletions unsloth/models/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,9 +390,17 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
"from unsloth_zoo.vision_utils import UnslothVisionDataCollator\n"\
"if not isinstance(data_collator, UnslothVisionDataCollator):\n"\
" if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:\n"\
" data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)\n"\
" data_collator = TransformersDataCollatorForLanguageModeling(\n"\
" __tokenizer,\n"\
" mlm = False,\n"\
" mlm_probability = 0.0,\n"\
" pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),\n"\
" )\n"\
" elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:\n"\
" data_collator = DataCollatorForSeq2Seq(__tokenizer)\n"\
" data_collator = DataCollatorForSeq2Seq(\n"\
" __tokenizer,\n"\
" pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),\n"\
" )\n"\
"else:\n"\
" if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False\n"\
" if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''\n"\
Expand All @@ -404,9 +412,17 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
"if not isinstance(data_collator, UnslothVisionDataCollator):\n"\
" if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):\n"\
" if isinstance(data_collator, DataCollatorForSeq2Seq):\n"\
" data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)\n"\
" data_collator = DataCollatorForSeq2Seq(\n"\
" __tokenizer.tokenizer,\n"\
" pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),\n"\
" )\n"\
" else:\n"\
" data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)\n"
" data_collator = TransformersDataCollatorForLanguageModeling(\n"\
" __tokenizer.tokenizer,\n"\
" mlm = False,\n"\
" mlm_probability = 0.0,\n"\
" pad_to_multiple_of = getattr(args, 'pad_to_multiple_of', None),\n"\
" )\n"
extra_args += pad_check
pass

Expand Down Expand Up @@ -566,10 +582,22 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
num_proc_check = \
"if dataset_num_proc is None:\n"\
" from multiprocessing import cpu_count\n"\
" dataset_num_proc = min(cpu_count()*2, 2)\n"
" dataset_num_proc = max(cpu_count()+4, 2)\n"
extra_args += num_proc_check
pass

# Default pad_to_multiple_of to the flex attention block size when flex attention is enabled and no value was given
if "pad_to_multiple_of" in call_args:
pad_to_multiple_of = \
"if os.environ.get('UNSLOTH_ENABLE_FLEX_ATTENTION', '0') == '1':\n"\
" from unsloth_zoo.flex_attention import HAS_FLEX_ATTENTION\n"\
" if HAS_FLEX_ATTENTION and pad_to_multiple_of is None:\n"\
" from unsloth_zoo.flex_attention import FLEX_ATTENTION_BLOCK_SIZE\n"\
" pad_to_multiple_of = FLEX_ATTENTION_BLOCK_SIZE\n"\
"\n"
extra_args += pad_to_multiple_of
pass

# Check for loss_type = dr_grpo and scale_rewards for GRPO
if "loss_type" in call_args and "scale_rewards" in call_args:
# See https://github.com/huggingface/trl/issues/3130#issuecomment-2746947835
Expand Down