
Releases: unslothai/unsloth

gpt-oss Reinforcement Learning + Auto Kernel Notebook

26 Sep 15:24

We’re introducing gpt-oss RL support, with the fastest RL inference and lowest VRAM use of any implementation. Blog: https://docs.unsloth.ai/new/gpt-oss-reinforcement-learning

  • Unsloth now offers the fastest inference (~3x faster), lowest VRAM use (50% less) and longest context (8x longer) for gpt-oss RL versus any other implementation, with no accuracy loss.
  • Since RL on gpt-oss isn't yet vLLM compatible, we rewrote the Transformers inference code to enable faster inference.
  • gpt-oss-20b GSPO free Colab notebook
  • This notebook automatically creates faster matrix-multiplication kernels and uses a new Unsloth reward function. We also show how to counteract reward hacking, one of RL's biggest challenges (a hedged GRPO setup sketch follows this list).
  • We previously released Vision RL with GSPO support
  • ⚠️ Reminder: do NOT use Flash Attention 3 for gpt-oss, as it will make your training loss wrong.
  • DeepSeek-V3.1-Terminus is here and you can run it locally via our GGUF.
    Read how our 3-bit GGUF beats Claude-4-Opus (thinking) on Aider Polyglot here
  • Magistral 1.2 is here and you can run it locally here or fine-tune it for free by using our Kaggle notebook
  • Fine-tuning the new Qwen3 models, including Qwen3-VL, Qwen3-Omni and Qwen3-Next, should work in Unsloth if you install the latest transformers. The models are large, however, so ensure you have enough VRAM.
  • BERT is now fixed! Feel free to use our BERT fine-tuning notebook
  • ⭐ We’re hosting a Developer event with Mistral AI & NVIDIA at Y Combinator’s Office in San Francisco on Oct 21. Come say hello!
  • We’re also joining Pytorch and AMD for a 2 day Virtual AI Agents Challenge with prizes. Join Hackathon
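
For orientation, here is a minimal, hedged sketch of what a gpt-oss GRPO run looks like using the same FastLanguageModel + TRL GRPOTrainer flow shown elsewhere in these notes. The model name unsloth/gpt-oss-20b, the toy dataset and the toy reward function are illustrative assumptions; the linked notebook contains the actual auto-generated kernels, Unsloth reward function and anti reward-hacking recipe.

from unsloth import FastLanguageModel
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/gpt-oss-20b", # assumed upload name
    max_seq_length = 1024,
    load_in_4bit = True,                # QLoRA to keep VRAM low
    # gpt-oss RL is not vLLM compatible yet (see above), so no fast_inference flag here
)
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    lora_alpha = 16,
    use_gradient_checkpointing = "unsloth",
)

# Tiny toy prompt dataset so the sketch is runnable end to end
dataset = Dataset.from_list([
    {"prompt": [{"role": "user", "content": "What is 2 + 2?"}]},
])

def length_penalty_reward(completions, **kwargs) -> list[float]:
    # Toy reward discouraging overly long answers; NOT the notebook's reward function
    return [max(0.0, 1.0 - len(c[0]["content"]) / 1024) for c in completions]

trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [length_penalty_reward],
    args = GRPOConfig(
        per_device_train_batch_size = 4,
        num_generations = 4,   # must divide the effective batch size
        max_prompt_length = 128,
        max_completion_length = 256,
        max_steps = 10,
        report_to = "none",
    ),
    train_dataset = dataset,
)
trainer.train()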

Don't forget to also join our Reddit: r/unsloth 🥰

What's Changed

New Contributors

Full Changelog: September-2025-v2...September-2025-v3

Vision Reinforcement Learning + Memory Efficient RL

16 Sep 16:14

We're excited to support vision models for RL, with even more memory-efficient and faster RL!

Unsloth now supports vision/multimodal RL with Gemma 3, Qwen2.5-VL and other vision models. Due to Unsloth's unique weight sharing and custom kernels, VLM RL is 1.5–2× faster, uses 90% less VRAM, and enables 10× longer context lengths than FA2 setups, with no accuracy loss. Notebooks: Qwen2.5-VL GSPO notebook
Gemma 3 (4B) Vision GSPO notebook
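
For orientation, a hedged sketch of how a vision model is loaded for RL with Unsloth's FastVisionModel. The model name is an assumed Unsloth upload, and switching GRPO to GSPO via GRPOConfig(importance_sampling_level = "sequence") is an assumption based on TRL's GSPO support; the linked notebooks are the working recipes.

from unsloth import FastVisionModel
from trl import GRPOConfig

model, tokenizer = FastVisionModel.from_pretrained(
    "unsloth/Qwen2.5-VL-7B-Instruct",   # assumed upload name
    load_in_4bit = True,
    use_gradient_checkpointing = "unsloth",
)
model = FastVisionModel.get_peft_model(
    model,
    finetune_vision_layers = True,      # also adapt the vision tower
    finetune_language_layers = True,
    r = 16,
    lora_alpha = 16,
)

# GSPO = sequence-level importance sampling on top of GRPO (assumed TRL flag)
training_args = GRPOConfig(
    importance_sampling_level = "sequence",
    learning_rate = 5e-6,
    per_device_train_batch_size = 4,
    num_generations = 4,
    max_steps = 100,
    report_to = "none",
)
# The reward functions and GRPOTrainer wiring follow the GRPO scripts further down these notes.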

Full details in our blogpost: https://docs.unsloth.ai/new/vision-reinforcement-learning-vlm-rl

  • This update also introduces support for Qwen's GSPO algorithm.
  • Vision RL also arrives even faster and more memory efficient: our new kernels and algorithms enable faster RL for text and vision LLMs with 50% less VRAM and 10× more context.
  • We're introducing a new RL feature called 'Standby'. Previously, RL required splitting the GPU between training and inference; with Unsloth Standby you no longer have to, and Standby uniquely limits speed degradation compared to other implementations, sometimes even making training faster! Read our Blog. (A hedged sketch of enabling Standby follows this list.)
  • We released Aider Polyglot benchmarks for our DeepSeek-V3.1 Dynamic GGUFs, and Unsloth quants perform consistently better than others. Blog
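
A minimal sketch of enabling Standby; the UNSLOTH_VLLM_STANDBY environment variable name is taken from the blog post and should be treated as an assumption, so check the linked blog and notebooks if it errors.

import os
os.environ["UNSLOTH_VLLM_STANDBY"] = "1"   # assumed flag name; must be set before importing unsloth

from unsloth import FastLanguageModel
model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/Qwen2.5-7B-Instruct",          # assumed upload name
    max_seq_length = 2048,
    load_in_4bit = True,
    fast_inference = True,                  # vLLM-backed generation shares weights with training
    gpu_memory_utilization = 0.9,           # with Standby, training and inference share one memory pool
)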

Don't forget to also join our Reddit: r/unsloth 🥰

What's Changed

New Contributors

Full Changelog: August-2025-v2...September-2025-v2

Unsloth Flex Attention + Long context gpt-oss Training

28 Aug 16:38

We’re excited to introduce Unsloth Flex Attention support for OpenAI gpt-oss training that enables >8× longer context lengths, >50% less VRAM usage and >1.5× faster training compared to all implementations including those using Flash Attention 3 (FA3). Unsloth Flex Attention makes it possible to train with a 60K context length on just 80GB of VRAM for BF16 LoRA. Also:

  • You can now export/save your QLoRA fine-tuned gpt-oss model to llama.cpp, vLLM, or HF.
  • We fixed gpt-oss training losses going to infinity on float16 GPUs (like T4 Colab)
  • We fixed gpt-oss implementation issues, most notably ensuring that swiglu_limit = 7.0 is properly applied during MXFP4 inference in transformers
  • Unsloth Flex Attention scales with context: longer sequences yield bigger savings in both VRAM and training time. A hedged long-context loading and export sketch follows this list.
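
A hedged sketch of a long-context load plus the export paths mentioned above, reusing the save_pretrained_merged / save_pretrained_gguf calls shown later in these notes; the model name, context length and quantization type are illustrative assumptions.

from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/gpt-oss-20b",  # assumed upload name
    max_seq_length = 16384,              # Unsloth Flex Attention: savings grow with context length
    load_in_4bit = True,
)

# ... QLoRA fine-tuning as usual ...

# Export paths mentioned above (llama.cpp / vLLM / HF); quantization_type is an assumption
model.save_pretrained_merged("gpt-oss-finetune", tokenizer)
model.save_pretrained_gguf("gpt-oss-finetune", quantization_type = "Q8_0")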

Full details in our blogpost: https://docs.unsloth.ai/basics/long-context-gpt-oss-training

What's Changed

New Contributors

Full Changelog: August-2025...August-2025-v2

gpt-oss Fine-tuning

08 Aug 15:33

gpt-oss is here! ✨

Finetune gpt-oss for free with our Unsloth Colab notebook!

  • We’ve managed to get gpt-oss training on just 14GB of VRAM, making it possible to train on a free Colab thanks to our linear conversions (a hedged sketch follows this list). For more details, read our Guide/Blogpost
  • Fine-tuning gpt-oss is 1.5x faster and uses 50% less VRAM with Unsloth. gpt-oss-120b model fits on 65GB of VRAM.
  • Model uploads: 20b GGUF · 120b GGUF · All uploads
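
A minimal, hedged SFT sketch under the same assumptions (an unsloth/gpt-oss-20b upload, 4-bit loading, TRL's SFTTrainer and a toy dataset); the linked notebook is the canonical recipe.

from unsloth import FastLanguageModel
from datasets import Dataset
from trl import SFTConfig, SFTTrainer

# 4-bit load targets the ~14GB figure above; the model name is an assumed upload
model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/gpt-oss-20b",
    max_seq_length = 1024,
    load_in_4bit = True,
)
model = FastLanguageModel.get_peft_model(model, r = 8, lora_alpha = 16)

# Tiny toy dataset in plain-text form just to make the sketch runnable
dataset = Dataset.from_list([{"text": "### Question: What is 2 + 2?\n### Answer: 4"}])

trainer = SFTTrainer(
    model = model,
    processing_class = tokenizer,
    train_dataset = dataset,
    args = SFTConfig(
        dataset_text_field = "text",
        per_device_train_batch_size = 1,
        max_steps = 10,
        report_to = "none",
    ),
)
trainer.train()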

🦥 Unsloth updates

  • We’ve made algorithmic updates to Unsloth so every model now trains faster and with less VRAM, no matter which model you use.
  • Unsloth now works on RTX 50 and Blackwell GPUs. Read our guide.
  • Official Unsloth Docker image coming very soon!
  • You can now run Unsloth models directly via Docker: docker model pull hf.co/unsloth/gpt-oss-20b-GGUF

🌠 Qwen3-Coder + Qwen3-2507

Qwen released its July 2025 updates, dubbed 'Qwen3-2507', and launched its SOTA coding models!

🔮 New models + Support:

Run these new models:

Unsloth also now supports running + training for:

Don't forget to also join our Reddit: r/unsloth 🥰

What's Changed

New Contributors

Full Changelog: July-2025...August-2025

Less VRAM + bug fixes

10 Jul 14:35

More VRAM reductions, speedups & bug fixes

Please update Unsloth! pip install --upgrade --force-reinstall --no-deps --no-cache-dir unsloth unsloth_zoo

  1. Gemma 3N Vision now works and is fixed! Please re-download all model checkpoints (Unsloth will do this automatically). Try the Kaggle Notebook! There is also a challenge with a $100,000 prize pool!
  2. Gemma 3 text and vision are all fixed for T4 and are much faster. Losses of 6 to 7 are now fixed; they should be 1 to 2.
  3. 10 to 25% less VRAM consumption for all models, plus faster compiling and fewer errors. Unsloth is now more stable!
  4. Downloads stuck at 90 to 95% are fixed!
  5. Qwen 2.5, Qwen 2 and GLM are all fixed as well.
  6. GRPO now works with the latest main TRL
  7. Main TRL, PEFT and Transformers all work
  8. Forced upgrading of transformers is now fixed.
  9. Falcon H1 finetuning should work great! Notebooks incoming
  10. Devstral 1.1 and MedGemma (27B and 4B) support, with vision
  11. Many, many more bug fixes - this release of Unsloth should be much more stable and error tolerant!

Please update Unsloth! pip install --upgrade --force-reinstall --no-deps --no-cache-dir unsloth unsloth_zoo

What's Changed

New Contributors

Full Changelog: June-2025...July-2025

Gemma 3n + Text-to-speech (TTS)

26 Jun 16:25

✨ Gemma 3n now available

  • Google's new Gemma 3n multimodal models support text, image, video & audio. Guide
  • Gemma 3n finetuning notebook + audio, vision, text inference Colab notebook
  • Gemma 3n collection in dynamic GGUF, safetensor 4-bit etc formats: Gemma-3n

🎵 Text-to-Speech (TTS) Fine-tuning

  • Train TTS/STT models like Sesame-CSM, Orpheus-TTS and OpenAI's Whisper locally! Guide
  • Clone voices and learn new emotions, tones & styles with 1.5× faster training and 50% less VRAM (a hedged loading sketch follows this list). Notebooks
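
Orpheus-TTS is Llama-based, so a plain text-style load works; the model name below is an assumed Unsloth upload, so check the linked notebooks for the exact checkpoints and the audio data pipeline.

from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/orpheus-3b-0.1-ft",   # assumed upload name
    max_seq_length = 2048,
    load_in_4bit = False,          # 16-bit LoRA here; set True for lower VRAM
)
model = FastLanguageModel.get_peft_model(model, r = 16, lora_alpha = 16)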

Tip

Update Unsloth via pip install --upgrade --force-reinstall unsloth unsloth_zoo

🧠 DeepSeek-R1-0528 Support with Dynamic 1-bit GGUFs

  • Fine-tune DeepSeek-R1-0528-Qwen3 with GRPO! Our new reward function increases multilingual response rates by 40%+ (an illustrative sketch follows this list). Notebook
  • Dynamic 1-bit GGUFs shrink the full 715GB model to just 175GB (an 80% size reduction)
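
The multilingual reward function itself isn't reproduced in these notes; the sketch below only illustrates the idea of rewarding completions whose script matches the user prompt's language, and is not the notebook's implementation.

import re

def same_script_reward_func(prompts, completions, **kwargs) -> list[float]:
    """Toy reward: 1.0 if the completion uses the same script family (CJK vs. non-CJK)
    as the user prompt. Illustration only, not the notebook's actual function."""
    def is_cjk(text: str) -> bool:
        return bool(re.search(r"[\u4e00-\u9fff\u3040-\u30ff\uac00-\ud7af]", text))
    questions = [prompt[-1]["content"] for prompt in prompts]
    responses = [completion[0]["content"] for completion in completions]
    return [1.0 if is_cjk(q) == is_cjk(r) else 0.0 for q, r in zip(questions, responses)]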

📈 Dynamic 2.0 GGUFs

  • New quantization method that achieves SOTA performance. More info
  • Sets new benchmarks for 5-shot MMLU and KL Divergence and selectively quantizes layers for optimal accuracy

⚡ Advanced Qwen3 GRPO notebook

  • Proximity scoring for better reward functions (an illustrative sketch follows this list). Advanced GRPO notebook
  • New pre-finetuning/priming step to skip GRPO format learning
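
Proximity scoring isn't spelled out in these notes; as an illustration only, here is a reward that grants full credit for an exact numeric match and partial credit that decays with relative distance (not the notebook's exact function).

def proximity_reward_func(completions, answer, **kwargs) -> list[float]:
    """Illustrative proximity score: 2.0 for an exact numeric match, otherwise a
    partial reward that shrinks as the relative error grows. Illustration only."""
    responses = [completion[0]["content"] for completion in completions]
    scores = []
    for response, true_answer in zip(responses, answer):
        try:
            guess, truth = float(response.strip()), float(true_answer)
        except (ValueError, TypeError):
            scores.append(0.0)
            continue
        if guess == truth:
            scores.append(2.0)
        else:
            # Decay the reward with relative error, floored at zero
            scores.append(max(0.0, 1.0 - abs(guess - truth) / (abs(truth) + 1e-6)))
    return scores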

🎯 Magistral Conversational Reasoning

  • Fine-tune Magistral-24B for advanced conversational reasoning. Notebook

👁️ Gemma3 Vision Support

  • Fine-tune Gemma3 vision models for multimodal tasks Notebook

Documentation & Guides

  • Reinforcement Learning Guide: Complete guide on RL for LLMs covering GRPO, RLHF, DPO. Guide
  • LoRA Hyperparameters Guide: Master optimal learning rates, epochs, LoRA rank & alpha settings. Guide

What's Changed

New Contributors


Qwen3

02 May 16:13

Qwen 3 support + bug fixes

Please update Unsloth via pip install --upgrade --force-reinstall unsloth unsloth_zoo

Qwen3 notebook: https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_(14B)-Reasoning-Conversational.ipynb
GRPO with Qwen3 notebook: https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_(4B)-GRPO.ipynb

There are also many bug fixes in this release!

The 30B MoE is also fine-tunable in Unsloth!

from unsloth import FastModel
import torch
model, tokenizer = FastModel.from_pretrained(
    model_name = "unsloth/Qwen3-30B-A3B",
    max_seq_length = 2048, # Choose any for long context!
    load_in_4bit = True,  # 4 bit quantization to reduce memory
    load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory
    full_finetuning = False, # [NEW!] We have full finetuning now!
    # token = "hf_...", # use one if using gated models
)

What's Changed

New Contributors

Full Changelog: 2025-03...May-2025

Gemma 3 + FFT Support

14 Mar 15:58

March Release 🦥

Get the latest stable Unsloth via:

pip install --upgrade --force-reinstall --no-cache-dir unsloth unsloth_zoo

The March release should be stable - you can force the version via:

pip install "unsloth==2025.3.18" "unsloth_zoo==2025.3.16"

New Features

  • Read all details here: https://unsloth.ai/blog/gemma3

  • Gemma 3 1B, 4B, 12B and 27B finetuning all work now! Colab Notebook. We fixed some issues that caused Gemma 3 training loss to be very high, including some tokenization issues, so fine-tuning Gemma 3 will now work correctly if you use Unsloth.

  • We also encountered many infinite gradients during Gemma 3 (1B to 27B) finetuning. We found float16 mixed precision (Tesla T4, RTX 2080 series) does not function well, so we default to float32 precision. Float16 also failed on A100, so this is a hardware-agnostic issue. Bfloat16 is fine though! Unsloth auto-selects the best data type, so you do not have to do anything! Colab Notebook to finetune Gemma 3

  • Preliminary support for full fine-tuning and 8-bit fine-tuning: set full_finetuning = True or load_in_8bit = True. Both will be optimized further in the future! A reminder that you will need more powerful GPUs!

model, tokenizer = FastModel.from_pretrained(
    model_name = "unsloth/gemma-3-4B-it",
    max_seq_length = 2048, # Choose any for long context!
    load_in_4bit = True,  # 4 bit quantization to reduce memory
    load_in_8bit = False, # [NEW!] A bit more accurate, uses 2x memory
    full_finetuning = False, # [NEW!] We have full finetuning now!
    # token = "hf_...", # use one if using gated models
)
  • New Unsloth Auto Model support - nearly all models are now supported! We now support vision and text models out of the box, without the need for custom implementations (and all are optimized!)
  • Mixtral (yes, finally!), Gemma 3, Granite 3.2, Cohere, OLMo, Reka, and generally any vision or language model! There might be the occasional model that doesn't work!
model, tokenizer = FastModel.from_pretrained(
    model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1",
)
pip install unsloth
  • Training on completions / responses only is now supported for vision models! Use it like below:
data_collator = UnslothVisionDataCollator(
    model,
    tokenizer,
    train_on_responses_only = True, # set True to mask the instruction and train only on responses
    instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
    response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",
)
SFTTrainer(..., data_collator = data_collator)
  • Conversions to llama.cpp GGUFs for 16-bit and 8-bit now DO NOT need compiling! This solves many, many issues, and means there's no need to install GCC, Microsoft Visual Studio, etc.
model.save_pretrained_merged("gemma-3-finetune", tokenizer)
model.save_pretrained_gguf(
    "gemma-3-finetune",
    quantization_type = "Q8_0", # For now only Q8_0, BF16, F16 supported
)
  • Vision models now auto-resize images, which stops OOMs and also allows truncating sequence lengths!

  • Many optimizations in Unsloth allow a further 10% reduction in VRAM usage and a >10% speed boost for 4-bit (on top of our original 2× faster, 70% less memory usage). 8-bit and full finetuning also benefit!

  • GRPO in Unsloth now allows non-Unsloth-uploaded models to be loaded in 4-bit as well, which reduces VRAM usage a lot (e.g. your own finetune of Llama)!

  • New training logs and info: training parameter counts, total batch size

  • Vision models now also work for normal text training! This means non vision notebooks can work with vision models!

  • Complete gradient accumulation bug fix coverage for all models!

  • GRPO notebook for Gemma 3 coming soon with Hugging Face's reasoning course!

  • DoRA, Dropout, and other PEFT methods should just work!

Bug fixes

  • A faster, less error-prone and streamlined finetuning experience! Apologies for the recent issues with constant releases and breaking changes; the March release should be stable, i.e. pip install "unsloth==2025.3.14" "unsloth_zoo==2025.3.12"
  • Pixtral and Llava finetuning are now fixed! In fact, nearly all vision models are supported out of the box! Please update transformers for Pixtral: pip install --no-deps git+https://github.com/huggingface/transformers.git
  • Fixed all Colabs not working - cloud instances like Runpod should just work now!
  • Fixed many, many bugs - we will reply to each issue with updates!

Other items

New Contributors

Full Changelog: 2025-02...2025-03

Long Context GRPO

20 Feb 18:51

90% less memory usage GRPO

Update Unsloth via pip install --upgrade --no-cache-dir unsloth unsloth_zoo

More details in blog post: https://unsloth.ai/blog/grpo

Llama 3.1 8B GRPO Colab: https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb

| Metric | Unsloth | TRL + FA2 |
| --- | --- | --- |
| Training Memory Cost (GB) | 42GB | 414GB |
| GRPO Memory Cost (GB) | 9.8GB | 78.3GB |
| Inference Cost (GB) | 0GB | 16GB |
| Inference KV Cache for 20K context (GB) | 2.5GB | 2.5GB |
| Total Memory Usage | 54.3GB (90% less) | 510.8GB |

You automatically get 90% less memory usage! Also, reward logs for each individual reward function will now show up.

Script to run GRPO:

!pip install unsloth vllm
from unsloth import FastLanguageModel
import torch
max_seq_length = 1024 # Can increase for longer reasoning traces
lora_rank = 32 # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "meta-llama/meta-Llama-3.1-8B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.6, # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)
import re
from datasets import load_dataset, Dataset
global COUNTER
COUNTER = 0
global PRINT_EVERY
PRINT_EVERY = 20

# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    global COUNTER
    if COUNTER % PRINT_EVERY == 0:
        print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    COUNTER += 1
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
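    """Reward partial credit (0.125) for each correctly placed XML tag and subtract a
    small penalty proportional to any trailing text after the closing </answer> tag."""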
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

max_prompt_length = 256
from trl import GRPOConfig, GRPOTrainer

# Optional extra params for vLLM
from unsloth import vLLMSamplingParams
vllm_sampling_params = vLLMSamplingParams(
    min_p = 0.01,
    seed = 3407,
)
training_args = GRPOConfig(
    learning_rate = 5e-6,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 6, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 250,
    report_to = "none", # Can use Weights & Biases
    vllm_sampling_params = vllm_sampling_params, # Optional
    temperature = 1.0,
)
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,
)
trainer.train()

What's Changed

New Contributors

Full Changelog: 2025-02...2025-02-v2

GRPO, vLLM

06 Feb 14:47

GRPO is in Unsloth!

  • Experience the "aha moment" from DeepSeek R1's paper now with Unsloth!
  • LoRA (16bit) / QLoRA (4bit) actually work for GRPO now!
  • Unsloth can do GRPO for Phi-4 (14B) and Llama-3.1 (8B) on a free 15GB Colab GPU!
  • Unsloth now has native fast inference (20x more throughput) via vLLM! Use it via model.fast_generate after setting FastLanguageModel.from_pretrained(..., fast_inference = True) and installing vLLM via pip install vllm (a hedged usage sketch follows this list)
  • Llama 3.3 70B QLoRA GRPO should fit in 1x 48GB (best 1x 80GB)
  • Update unsloth via pip install --upgrade --no-cache-dir --force-reinstall unsloth_zoo unsloth vllm
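
For reference, a minimal sketch of the fast inference path; the vLLM SamplingParams usage and the shape of fast_generate's return value follow vLLM conventions and should be treated as assumptions.

from unsloth import FastLanguageModel
from vllm import SamplingParams

model, tokenizer = FastLanguageModel.from_pretrained(
    "meta-llama/meta-Llama-3.1-8B-Instruct",
    max_seq_length = 512,
    load_in_4bit = True,
    fast_inference = True,          # enables the vLLM-backed engine
    gpu_memory_utilization = 0.6,
)

# fast_generate mirrors vLLM's generate(); outputs are vLLM RequestOutput objects (assumed)
outputs = model.fast_generate(
    ["Why is the sky blue?"],
    sampling_params = SamplingParams(temperature = 0.8, max_tokens = 128),
)
print(outputs[0].outputs[0].text)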


GRPO Notebooks

| Model | Type | Colab Link |
| --- | --- | --- |
| Phi 4 (14B) | GRPO | Open in Colab |
| Llama 3.1 (8B) | GRPO | Open in Colab |
| Qwen 2.5 (3B) | GRPO | Open in Colab |

Minimal GRPO example (courtesy of Will Brown):

!pip install unsloth vllm
!pip install git+https://github.com/huggingface/trl.git

from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)

from unsloth import is_bfloat16_supported
import torch
max_seq_length = 512
lora_rank = 32

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "meta-llama/meta-Llama-3.1-8B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True,
    fast_inference = True,
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.6,
)
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,
    lora_alpha = lora_rank,
)

import re
from datasets import load_dataset, Dataset

# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    use_vllm = True, # use vLLM for fast inference!
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    logging_steps = 1,
    bf16 = is_bfloat16_supported(),
    fp16 = not is_bfloat16_supported(),
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1,
    num_generations = 6,
    max_prompt_length = 256,
    max_completion_length = 200,
    # num_train_epochs = 1,
    max_steps = 250,
    save_steps = 250,
    max_grad_norm = 0.1,
    report_to = "none",
    output_dir = "outputs",
)

trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,
)
trainer.train()

Bug Fixes

  • Gemma 2 should be fixed now
  • Mistral base mapping should be fixed
  • Some syntax warning issue fixes
  • And many many more bug fixes!

What's Changed

New Contributors

Full Changelog: 2025-01...2025-02