Commit 64ff9da

Fix galore linears (#541)
1 parent 20b19f9 commit 64ff9da

File tree (3 files changed: +6, -25 lines)

  swift/llm/tuner.py
  swift/llm/utils/argument.py
  swift/trainers/optimizers/galore/utils.py

swift/llm/tuner.py

Lines changed: 5 additions & 7 deletions
@@ -179,20 +179,18 @@ def prepare_model(model, args: SftArguments):
 
     if args.use_galore:
         from swift.trainers.optimizers.galore import GaLoreConfig
-        model_type = args.model_type
-        for key in MODEL_KEYS_MAPPING.keys():
-            if key in model_type.lower():
-                model_type = key
-                break
+        if args.galore_target_modules is None:
+            args.galore_target_modules = find_all_linears(
+                model, 0, args.model_type)
+        if args.galore_with_embedding:
+            args.galore_target_modules += find_embedding(model)
         args.training_args.galore_config = GaLoreConfig(
-            model_type=model_type,
             target_modules=args.galore_target_modules,
             rank=args.galore_rank,
             update_proj_gap=args.galore_update_proj_gap,
             galore_scale=args.galore_scale,
             proj_type=args.galore_proj_type,
             optim_per_parameter=args.galore_optim_per_parameter,
-            with_embedding=args.galore_with_embedding,
         )
 
 class TrainerAdapterCallback(TrainerCallback):

swift/llm/utils/argument.py

Lines changed: 1 addition & 1 deletion
@@ -125,7 +125,7 @@ class SftArguments:
     # galore
     use_galore: bool = False
     galore_rank: int = 128
-    galore_target_modules: Union[str, List[str]] = None
+    galore_target_modules: Optional[List[str]] = None
     galore_update_proj_gap: int = 50
     galore_scale: float = 1.0
     galore_proj_type: str = 'std'
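The old annotation advertised Union[str, List[str]] while defaulting to None, and after this commit the value is concatenated with another list (+= find_embedding(model) in tuner.py), so Optional[List[str]] matches both the default and the way the field is used. A minimal, self-contained sketch of the intended pattern (the dataclass and module names below are illustrative stand-ins, not the real SftArguments):

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class GaloreArgsSketch:
    # illustrative subset of the galore fields of SftArguments
    use_galore: bool = False
    galore_rank: int = 128
    galore_target_modules: Optional[List[str]] = None  # None -> auto-detected in prepare_model
    galore_with_embedding: bool = False

args = GaloreArgsSketch(use_galore=True)
if args.galore_target_modules is None:
    # mirrors the tuner.py change, with the detected linears hard-coded for the example
    args.galore_target_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj']
if args.galore_with_embedding:
    args.galore_target_modules += ['embed_tokens']
print(args.galore_target_modules)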

swift/trainers/optimizers/galore/utils.py

Lines changed: 0 additions & 17 deletions
@@ -8,7 +8,6 @@
 from torch.optim.lr_scheduler import LRScheduler
 from transformers import Trainer, TrainingArguments, get_scheduler
 
-from swift.tuners.module_mapping import MODEL_KEYS_MAPPING
 from swift.utils import get_logger
 
 logger = get_logger()

@@ -23,7 +22,6 @@ class GaLoreConfig:
     See https://arxiv.org/abs/2403.03507
 
     Args:
-        model_type (`str`): The model_type of Galore
         rank (`int`): The galore rank
         target_modules (`Union[str, List[str]]`): The target modules to use, if `None`,
             will use all attn and mlp linears

@@ -33,13 +31,11 @@ class GaLoreConfig:
         galore_scale(float): the scale of gradient
         optim_per_parameter(bool): Gives one optimizer per parameter
     """
-    model_type: str = None
     rank: int = 128
     target_modules: Union[str, List[str]] = None
     update_proj_gap: int = 50
     galore_scale: float = 1.0
     proj_type: str = 'std'
-    with_embedding: bool = False
     optim_per_parameter: bool = False
 
 

@@ -72,19 +68,6 @@ def step(self, *args, **kwargs) -> None:
 def create_optimizer_and_scheduler(model: nn.Module, args: TrainingArguments,
                                    config: GaLoreConfig, max_steps,
                                    **defaults):
-    if not config.target_modules:
-        if config.model_type in MODEL_KEYS_MAPPING:
-            target_modules_list = [
-                MODEL_KEYS_MAPPING[config.model_type].attention.split('.{}.')
-                [1], MODEL_KEYS_MAPPING[config.model_type].mlp.split('.{}.')[1]
-            ]
-            config.target_modules = target_modules_list
-            if config.with_embedding:
-                embedding = MODEL_KEYS_MAPPING[config.model_type].embedding
-                idx = embedding.rfind('.')
-                embedding = embedding[idx + 1:]
-                target_modules_list.append(embedding)
-
     galore_params = []
     for module_name, module in model.named_modules():
         if not isinstance(module, (nn.Linear, nn.Embedding)) or \

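With the fallback removed, create_optimizer_and_scheduler no longer tries to derive target modules itself and simply assumes config.target_modules was populated by the caller (prepare_model above). The loop that this hunk cuts off keeps filtering modules by type and by name; a rough, self-contained sketch of that kind of filtering, not the repo's exact code:

import torch.nn as nn

def collect_galore_params(model: nn.Module, target_modules: list) -> list:
    # Illustrative only: gather trainable parameters of Linear/Embedding modules
    # whose qualified name contains one of the configured target-module names.
    galore_params = []
    for module_name, module in model.named_modules():
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            continue
        if not any(target in module_name for target in target_modules):
            continue
        galore_params.extend(p for p in module.parameters(recurse=False) if p.requires_grad)
    return galore_params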