fix gc_kwargs #4591

Merged
merged 5 commits on Jun 13, 2025
6 changes: 3 additions & 3 deletions README.md
@@ -117,12 +117,12 @@ Running Environment:
| python | >=3.9 | 3.10 | |
| cuda | | cuda12 | No need to install if using CPU, NPU, MPS |
| torch | >=2.0 | | |
| transformers | >=4.33 | 4.51 | |
| transformers | >=4.33 | 4.51.3 | |
| modelscope | >=1.23 | | |
| peft | >=0.11,<0.16 | ||
| trl | >=0.13,<0.19 | 0.18 |RLHF|
| deepspeed | >=0.14 | 0.14.5 | Training |
| vllm | >=0.5.1 | 0.8 | Inference/Deployment/Evaluation |
| deepspeed | >=0.14 | 0.14.5 / 0.16.9 | Training |
| vllm | >=0.5.1 | 0.8.5.post1 | Inference/Deployment/Evaluation |
| lmdeploy | >=0.5 | 0.8 | Inference/Deployment/Evaluation |
| evalscope | >=0.11 | | Evaluation |

6 changes: 3 additions & 3 deletions README_CN.md
@@ -113,12 +113,12 @@ pip install -e .
| python | >=3.9 | 3.10 ||
| cuda | | cuda12 |No need to install if using CPU, NPU, MPS|
| torch | >=2.0 | ||
| transformers | >=4.33 | 4.51 ||
| transformers | >=4.33 | 4.51.3 ||
| modelscope | >=1.23 | ||
| peft | >=0.11,<0.16 | ||
| trl | >=0.13,<0.19 | 0.18 |RLHF|
| deepspeed | >=0.14 | 0.14.5 |Training|
| vllm | >=0.5.1 | 0.8 |Inference/Deployment/Evaluation|
| deepspeed | >=0.14 | 0.14.5 / 0.16.9 |Training|
| vllm | >=0.5.1 | 0.8.5.post1 |Inference/Deployment/Evaluation|
| lmdeploy | >=0.5 | 0.8 |Inference/Deployment/Evaluation|
| evalscope | >=0.11 | |Evaluation|

6 changes: 3 additions & 3 deletions docs/source/GetStarted/SWIFT安装.md
@@ -72,12 +72,12 @@ modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu2
| python | >=3.9 | 3.10 ||
| cuda | | cuda12 |No need to install if using CPU, NPU, MPS|
| torch | >=2.0 | ||
| transformers | >=4.33 | 4.51 ||
| transformers | >=4.33 | 4.51.3 ||
| modelscope | >=1.23 | ||
| peft | >=0.11,<0.16 | ||
| trl | >=0.13,<0.19 | 0.18 |RLHF|
| deepspeed | >=0.14 | 0.14.5 |Training|
| vllm | >=0.5.1 | 0.8 |Inference/Deployment/Evaluation|
| deepspeed | >=0.14 | 0.14.5 / 0.16.9 |Training|
| vllm | >=0.5.1 | 0.8.5.post1 |Inference/Deployment/Evaluation|
| lmdeploy | >=0.5 | 0.8 |Inference/Deployment/Evaluation|
| evalscope | >=0.11 | |Evaluation|

1 change: 1 addition & 0 deletions docs/source/Instruction/命令行参数.md
@@ -146,6 +146,7 @@
- lr_scheduler_type: Type of lr_scheduler. Defaults to 'cosine'.
- lr_scheduler_kwargs: Additional parameters for the lr_scheduler. Defaults to None.
- 🔥gradient_checkpointing_kwargs: Parameters passed to `torch.utils.checkpoint`. For example, set `--gradient_checkpointing_kwargs '{"use_reentrant": false}'`. Defaults to None.
- Note: When using DDP without DeepSpeed/FSDP and gradient_checkpointing_kwargs is None, it defaults to `'{"use_reentrant": false}'`.
- full_determinism: Ensures reproducible results during training. Note: this will negatively impact performance. Defaults to False.
- 🔥report_to: Defaults to `tensorboard`. You can also specify `--report_to tensorboard wandb swanlab` or `--report_to all`.
- logging_first_step: Whether to log the first step. Defaults to True.
6 changes: 3 additions & 3 deletions docs/source_en/GetStarted/SWIFT-installation.md
@@ -73,12 +73,12 @@ More images can be found [here](https://modelscope.cn/docs/intro/environment-set
| python | >=3.9 | 3.10 | |
| cuda | | cuda12 | No need to install if using CPU, NPU, MPS |
| torch | >=2.0 | | |
| transformers | >=4.33 | 4.51 | |
| transformers | >=4.33 | 4.51.3 | |
| modelscope | >=1.23 | | |
| peft | >=0.11,<0.16 | | |
| trl | >=0.13,<0.19 | 0.18 | RLHF |
| deepspeed | >=0.14 | 0.14.5 | Training |
| vllm | >=0.5.1 | 0.8 | Inference/Deployment/Evaluation |
| deepspeed | >=0.14 | 0.14.5 / 0.16.9 | Training |
| vllm | >=0.5.1 | 0.8.5.post1 | Inference/Deployment/Evaluation |
| lmdeploy | >=0.5 | 0.8 | Inference/Deployment/Evaluation |
| evalscope | >=0.11 | | Evaluation |

1 change: 1 addition & 0 deletions docs/source_en/Instruction/Command-line-parameters.md
@@ -149,6 +149,7 @@ This parameter list inherits from transformers `Seq2SeqTrainingArguments`, with
- lr_scheduler_type: Type of lr_scheduler, defaults to 'cosine'.
- lr_scheduler_kwargs: Other parameters for the lr_scheduler, defaults to None.
- 🔥gradient_checkpointing_kwargs: Parameters for `torch.utils.checkpoint`. For example, set as `--gradient_checkpointing_kwargs '{"use_reentrant": false}'`. Defaults to None.
- Note: When using DDP without DeepSpeed/FSDP, and `gradient_checkpointing_kwargs` is `None`, it will default to `'{"use_reentrant": false}'`.
- full_determinism: Ensures reproducible results during training. Note: This will negatively impact performance. Defaults to False.
- 🔥report_to: Default value is `tensorboard`. You can also specify `--report_to tensorboard wandb swanlab` or `--report_to all`.
- logging_first_step: Whether to log the first step, defaults to True.
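For reference, a minimal sketch of the default described in the note above, assuming it mirrors the logic this PR moves into the trainer (see the `swift/trainers/mixin.py` diff further down); `resolve_use_reentrant` is an illustrative helper name, not an API in the repository.

```python
# Sketch of how the use_reentrant default is resolved when the user does not
# set it explicitly (assumption: mirrors the PR's trainer-side logic).
def resolve_use_reentrant(gradient_checkpointing_kwargs, is_dist, deepspeed_enabled, fsdp_enabled):
    use_reentrant = None
    if gradient_checkpointing_kwargs:
        use_reentrant = gradient_checkpointing_kwargs.get('use_reentrant')
    if use_reentrant is None:
        # Plain DDP (distributed, but neither DeepSpeed nor FSDP) defaults to
        # non-reentrant checkpointing; every other setup keeps the reentrant default.
        use_reentrant = not (is_dist and not deepspeed_enabled and not fsdp_enabled)
    return use_reentrant

# Resolved defaults in a few configurations:
assert resolve_use_reentrant(None, is_dist=True, deepspeed_enabled=False, fsdp_enabled=False) is False
assert resolve_use_reentrant(None, is_dist=False, deepspeed_enabled=False, fsdp_enabled=False) is True
assert resolve_use_reentrant({'use_reentrant': True}, True, False, False) is True
```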
11 changes: 6 additions & 5 deletions examples/train/embedding/train_emb.sh
@@ -1,10 +1,11 @@
nproc_per_node=8
# 4*12G
nproc_per_node=2
# 2*12G
# losses: plugin/loss.py
# data format: docs/source_en/Customization/Custom-dataset.md
# --use_chat_template must be false to use generation template
# --dataloader_drop_last must be true or eval gather will throw error
# --model iic/gte-modernbert-base iic/gte_Qwen2-7B-instruct also supported
CUDA_VISIBLE_DEVICES=0,1 \
NPROC_PER_NODE=$nproc_per_node \
swift sft \
--model Qwen/Qwen3-Embedding-0.6B \
@@ -15,14 +16,14 @@ swift sft \
--split_dataset_ratio 0.05 \
--eval_strategy steps \
--output_dir output \
--eval_steps 20 \
--save_steps 50 \
--eval_steps 50 \
--num_train_epochs 5 \
--save_steps 70 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 4 \
--learning_rate 6e-6 \
--loss_type infonce \
--label_names labels \
--dataloader_drop_last true \
--deepspeed zero3
--deepspeed zero2
2 changes: 1 addition & 1 deletion examples/train/multi-gpu/fsdp_qlora/train.sh
@@ -1,4 +1,4 @@
# 14GiB * 2
# 80GiB * 2
nproc_per_node=2

CUDA_VISIBLE_DEVICES=0,1 \
1 change: 1 addition & 0 deletions examples/train/multimodal/audio.sh
@@ -1,3 +1,4 @@
# pip install "transformers==4.48.*"
CUDA_VISIBLE_DEVICES=0 \
swift sft \
--model Qwen/Qwen2-Audio-7B-Instruct \
2 changes: 1 addition & 1 deletion swift/megatron/init.py
@@ -56,7 +56,7 @@ def _patch_training_log():
from megatron.core.num_microbatches_calculator import get_num_microbatches
from megatron.training.utils import reduce_max_stat_across_model_parallel_group, report_memory

# Code borrowed from megatron-lm
# Code borrowed from NVIDIA/Megatron-LM
def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration, loss_scale,
report_memory_flag, skipped_iter, grad_norm, params_norm, num_zeros_in_grad):
"""Log training information such as losses, timing, ...."""
1 change: 1 addition & 0 deletions swift/megatron/model/gpt_model.py
@@ -87,6 +87,7 @@ def apply_rotary_pos_emb(*args, **kwargs):
finally:
attention.apply_rotary_pos_emb = origin_apply_rotary_pos_emb

# Code borrowed from NVIDIA/Megatron-LM
def forward(
self,
input_ids: torch.Tensor,
4 changes: 2 additions & 2 deletions swift/megatron/train/trainers/trainer.py
@@ -96,7 +96,7 @@ def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_par
except StopIteration:
return {}, True, True, True, 0, None, None

# Code borrowed from megatron-lm
# Code borrowed from NVIDIA/Megatron-LM
def evaluate(self,
forward_step_func,
data_iterator,
@@ -229,7 +229,7 @@ def _patch_megatron(self):
self._origin_evaluate = training.evaluate
training.evaluate = self.evaluate

# Code borrowed from megatron-lm
# Code borrowed from NVIDIA/Megatron-LM
def loss_func(self, output_tensor: torch.Tensor, *, loss_mask: torch.Tensor):
"""Loss function.

27 changes: 0 additions & 27 deletions swift/trainers/arguments.py
@@ -3,11 +3,8 @@
import os
import platform
from dataclasses import dataclass, field
from functools import wraps
from typing import List, Literal, Optional, Union

import torch
import torch.utils.checkpoint
from transformers.training_args import TrainingArguments as HfTrainingArguments
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments as HfSeq2SeqTrainingArguments

@@ -65,29 +62,6 @@ class TrainArgumentsMixin:
eval_datasets_args: Optional[Union[str, dict]] = None
eval_generation_config: Optional[Union[str, dict]] = None

def _fix_gradient_checkpointing(self):
# fix use_reentrant
if hasattr(torch.utils.checkpoint, '_old_checkpoint'): # avoid double patching
return
# Consistent with the default behavior of transformers.
use_reentrant_ = (
self.gradient_checkpointing_kwargs.get('use_reentrant', True)
if self.gradient_checkpointing_kwargs else True)
_old_checkpoint = torch.utils.checkpoint.checkpoint

@wraps(_old_checkpoint)
def _new_checkpoint(*args, use_reentrant=None, **kwargs):
return _old_checkpoint(*args, use_reentrant=use_reentrant_, **kwargs)

torch.utils.checkpoint._old_checkpoint = _old_checkpoint
torch.utils.checkpoint.checkpoint = _new_checkpoint
try:
# Fix the old version of transformers.
import transformers.modeling_utils
transformers.modeling_utils.checkpoint = _new_checkpoint
except (ImportError, AttributeError):
pass

@staticmethod
def _patch_liger_kernel():
# fix logits_to_keep
@@ -129,7 +103,6 @@ def __post_init__(self):
self.vit_gradient_checkpointing = self.gradient_checkpointing
if self.gradient_checkpointing_kwargs:
self.gradient_checkpointing_kwargs = ModelArguments.parse_to_dict(self.gradient_checkpointing_kwargs)
self._fix_gradient_checkpointing()
self._init_liger()
if self.dataloader_num_workers is None:
if platform.system() == 'Windows':
37 changes: 35 additions & 2 deletions swift/trainers/mixin.py
@@ -7,14 +7,15 @@
import time
from contextlib import contextmanager
from copy import copy
from functools import partial
from functools import partial, wraps
from types import MethodType
from typing import Callable, Dict, List, Optional, Tuple, Union

import safetensors
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.utils.checkpoint
import transformers
from datasets import Dataset as HfDataset
from modelscope import check_local_model_is_latest
@@ -34,7 +35,7 @@
from swift.llm import BatchSamplerShard, DataLoaderDispatcher, DataLoaderShard, Template
from swift.plugin import MeanMetric, compute_acc, extra_tuners
from swift.tuners import SwiftModel
from swift.utils import get_logger, is_mp, is_mp_ddp, ms_logger_context, seed_worker, use_torchacc
from swift.utils import get_logger, is_dist, is_mp, is_mp_ddp, ms_logger_context, seed_worker, use_torchacc
from swift.utils.torchacc_utils import ta_trim_graph
from ..utils.torch_utils import get_device_count
from .arguments import TrainingArguments
@@ -109,6 +110,7 @@ def __init__(self,
if self.template.sequence_parallel_size > 1:
from swift.trainers.sequence_parallel import sequence_parallel
sequence_parallel.prepare_trainer(self)
self._fix_gradient_checkpointing()

def get_use_logits_to_keep(self, default_value: bool):
use_logits_to_keep = self.args.use_logits_to_keep
@@ -323,6 +325,37 @@ def clip_grad_norm_(self, parameters, *args, **kwargs):
finally:
Accelerator.clip_grad_norm_ = origin_clip_grad_norm_

def _fix_gradient_checkpointing(self):
# fix use_reentrant
if hasattr(torch.utils.checkpoint, '_old_checkpoint'): # avoid double patching
return
args = self.args
# Consistent with the default behavior of transformers.
if args.gradient_checkpointing_kwargs:
use_reentrant_ = args.gradient_checkpointing_kwargs.get('use_reentrant')
else:
use_reentrant_ = None
if use_reentrant_ is None:
if is_dist() and not self.is_deepspeed_enabled and not self.is_fsdp_enabled:
use_reentrant_ = False
else:
use_reentrant_ = True
logger.info(f'use_reentrant: {use_reentrant_}')
_old_checkpoint = torch.utils.checkpoint.checkpoint

@wraps(_old_checkpoint)
def _new_checkpoint(*args, use_reentrant=None, **kwargs):
return _old_checkpoint(*args, use_reentrant=use_reentrant_, **kwargs)

torch.utils.checkpoint._old_checkpoint = _old_checkpoint
torch.utils.checkpoint.checkpoint = _new_checkpoint
try:
# Fix the old version of transformers.
import transformers.modeling_utils
transformers.modeling_utils.checkpoint = _new_checkpoint
except (ImportError, AttributeError):
pass

def _prepare_gradient_checkpointing(self, model) -> None:
from swift.llm import HfConfigFactory, get_model_arch, deep_getattr, dynamic_gradient_checkpointing
args = self.args
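As a closing aside, the `_fix_gradient_checkpointing` patch above relies on a small monkey-patching idiom: wrap a callable with `functools.wraps` so a keyword argument is pinned to a value decided once, regardless of what later callers pass. Below is a standalone sketch of that idiom with illustrative names (`pin_keyword`, `scale`) that are not from the PR.

```python
import functools

def pin_keyword(func, name, value):
    """Return a wrapper that forces keyword `name` to `value` on every call."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        kwargs[name] = value  # override whatever the caller supplied
        return func(*args, **kwargs)
    return wrapper

def scale(x, factor=1):
    return x * factor

scale = pin_keyword(scale, 'factor', 10)
print(scale(3))            # 30 -- the pinned value applies
print(scale(3, factor=2))  # 30 -- the caller-supplied value is ignored
```

In the PR itself the same pattern is applied to `torch.utils.checkpoint.checkpoint`, with the original function stashed as `torch.utils.checkpoint._old_checkpoint` to guard against double patching.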