
Commit 97a36bc

yaozhewei and lekurile authored
Fix compatibility between only_optimize_lora and activation checkpointing (#658)
Co-authored-by: Lev Kurilenko <[email protected]>
1 parent 1a0b896 commit 97a36bc


5 files changed, +33 -22 lines changed

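Taken together, the changes let only_optimize_lora run with activation checkpointing enabled: the per-script asserts that forbade the combination are removed, and the training scripts now call the new helper make_model_gradient_checkpointing_compatible right after freezing the non-LoRA weights. Below is a minimal sketch of that flow, not part of the commit: it assumes the DeepSpeed-Chat training/utils package is importable, and the base model ("facebook/opt-125m"), the LoRA module name ("decoder.layers."), and the LoRA rank of 8 are illustrative values only.

# Illustrative sketch; model name, module name, and LoRA rank are assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from utils.module.lora import (convert_linear_layer_to_lora,
                               only_optimize_lora_parameters,
                               make_model_gradient_checkpointing_compatible)

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

# Inject LoRA adapters into the decoder layers, then freeze every non-LoRA weight.
model = convert_linear_layer_to_lora(model, "decoder.layers.", 8)
model = only_optimize_lora_parameters(model)

# With only LoRA weights trainable, the hidden states entering each checkpointed
# block come from frozen modules and do not require grad, so activation
# checkpointing would rebuild no autograd graph through them. The new helper
# calls enable_input_require_grads() (or registers an equivalent forward hook on
# the input embeddings) so gradients can flow again.
model = make_model_gradient_checkpointing_compatible(model)
model.gradient_checkpointing_enable()

batch = tokenizer("DeepSpeed-Chat example", return_tensors="pt")
loss = model(**batch, labels=batch["input_ids"]).loss
loss.backward()

# Expect True: every LoRA parameter received a gradient despite checkpointing.
print(all(p.grad is not None
          for n, p in model.named_parameters() if "lora_" in n))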

applications/DeepSpeed-Chat/training/step1_supervised_finetuning/main.py

Lines changed: 2 additions & 7 deletions

@@ -27,7 +27,7 @@
 from utils.data.data_utils import create_prompt_dataset
 from utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, get_optimizer_grouped_parameters, save_zero_three_model, load_hf_tokenizer
 from utils.ds_utils import get_train_ds_config
-from utils.module.lora import convert_linear_layer_to_lora, convert_lora_to_linear_layer, only_optimize_lora_parameters
+from utils.module.lora import convert_linear_layer_to_lora, convert_lora_to_linear_layer, only_optimize_lora_parameters, make_model_gradient_checkpointing_compatible
 from utils.model.model_utils import create_hf_model


@@ -175,12 +175,6 @@ def parse_args():
     parser = deepspeed.add_config_arguments(parser)
     args = parser.parse_args()

-    # Validate settings
-    if args.gradient_checkpointing and args.lora_dim > 0:
-        assert (
-            not args.only_optimize_lora
-        ), "--gradient_checkpointing and --only_optimize_lora cannot be enabled at the same time."
-
     return args


@@ -240,6 +234,7 @@ def main():
                                              args.lora_dim)
         if args.only_optimize_lora:
             model = only_optimize_lora_parameters(model)
+            model = make_model_gradient_checkpointing_compatible(model)

     # Prepare the data
     train_phase = 1

applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/main.py

Lines changed: 2 additions & 7 deletions

@@ -26,7 +26,7 @@
 from utils.data.data_utils import create_prompt_dataset, DataCollatorReward
 from utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, get_optimizer_grouped_parameters, save_zero_three_model, load_hf_tokenizer
 from utils.ds_utils import get_train_ds_config
-from utils.module.lora import convert_linear_layer_to_lora, convert_lora_to_linear_layer, only_optimize_lora_parameters
+from utils.module.lora import convert_linear_layer_to_lora, convert_lora_to_linear_layer, only_optimize_lora_parameters, make_model_gradient_checkpointing_compatible


 def parse_args():

@@ -171,12 +171,6 @@ def parse_args():
     parser = deepspeed.add_config_arguments(parser)
     args = parser.parse_args()

-    # Validate settings
-    if args.gradient_checkpointing and args.lora_dim > 0:
-        assert (
-            not args.only_optimize_lora
-        ), "--gradient_checkpointing and --only_optimize_lora cannot be enabled at the same time."
-
     return args


@@ -225,6 +219,7 @@ def main():
                                                args.lora_dim)
         if args.only_optimize_lora:
             rm_model = only_optimize_lora_parameters(rm_model)
+            rm_model = make_model_gradient_checkpointing_compatible(rm_model)

     train_phase = 2
     train_dataset, eval_dataset = create_prompt_dataset(

applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py

Lines changed: 10 additions & 7 deletions

@@ -318,18 +318,21 @@ def parse_args():
             f"{args.tensorboard_path}/step3_tensorboard_logs")

     # Validate settings
-    if (args.actor_gradient_checkpointing
-            and args.actor_lora_dim > 0) or (args.critic_gradient_checkpointing
-                                             and args.critic_lora_dim > 0):
-        assert (
-            not args.only_optimize_lora
-        ), "--{actor,critic}_gradient_checkpointing and --only_optimize_lora cannot be enabled at the same time."
-
     if args.inference_tp_size > 1:
         assert (
             args.actor_zero_stage == 3
         ), "Zero stage 3 must be used to do Tensor sharding in the hybrid engine"

+    if args.actor_zero_stage == 2 and args.critic_zero_stage == 2 and args.enable_hybrid_engine and args.offload and args.actor_lora_dim == 0:
+        raise ValueError(
+            "The combination of [actor_zero_stage==2, critic_zero_stage==2, enable_hybrid_engine=True, offload=True, lora=False] is currently unsupported due to training instability!"
+        )
+
+    if args.actor_zero_stage == 3 and args.critic_zero_stage == 3 and args.enable_hybrid_engine and args.offload and args.actor_lora_dim > 0:
+        raise ValueError(
+            "The combination of [actor_zero_stage==3, critic_zero_stage==3, enable_hybrid_engine=True, offload=True, lora=True] is currently unsupported due to training instability!"
+        )
+
     return args


applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/rlhf_engine.py

Lines changed: 5 additions & 1 deletion

@@ -10,7 +10,7 @@
 from transformers import AutoModelForCausalLM, get_scheduler

 from utils.ds_utils import get_train_ds_config, get_eval_ds_config
-from utils.module.lora import convert_linear_layer_to_lora, only_optimize_lora_parameters
+from utils.module.lora import convert_linear_layer_to_lora, only_optimize_lora_parameters, make_model_gradient_checkpointing_compatible
 from utils.model.model_utils import create_hf_model, create_critic_model
 from utils.utils import get_optimizer_grouped_parameters
 """

@@ -100,6 +100,8 @@ def _init_actor(self, actor_model_name_or_path):
                 self.args.actor_lora_dim)
             if self.args.only_optimize_lora:
                 actor_model = only_optimize_lora_parameters(actor_model)
+                actor_model = make_model_gradient_checkpointing_compatible(
+                    actor_model)

         # Optimizer
         AdamOptimizer = DeepSpeedCPUAdam if self.args.offload else FusedAdam

@@ -226,6 +228,8 @@ def _init_critic(self, critic_model_name_or_path):
                 self.args.critic_lora_dim)
             if self.args.only_optimize_lora:
                 critic_model = only_optimize_lora_parameters(critic_model)
+                critic_model = make_model_gradient_checkpointing_compatible(
+                    critic_model)

         # Optimizer
         AdamOptimizer = DeepSpeedCPUAdam if self.args.offload else FusedAdam

applications/DeepSpeed-Chat/training/utils/module/lora.py

Lines changed: 14 additions & 0 deletions

@@ -139,3 +139,17 @@ def only_optimize_lora_parameters(model):
         else:
             param.requires_grad = False
     return model
+
+
+def make_model_gradient_checkpointing_compatible(model):
+    # Hugging Face added enable_input_require_grads() to make gradient checkpointing work for LoRA-only optimization
+    if hasattr(model, "enable_input_require_grads"):
+        model.enable_input_require_grads()
+    elif hasattr(model, "get_input_embeddings"):
+
+        def make_inputs_require_grad(module, input, output):
+            output.requires_grad_(True)
+
+        model.get_input_embeddings().register_forward_hook(
+            make_inputs_require_grad)
+    return model
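For the fallback branch, here is a self-contained torch sketch (mine, not the repo's) of the mechanism the hook relies on: when everything feeding a checkpointed block is frozen, the block's output does not require grad and its trainable weights receive no gradients; marking the embedding output as requiring grad restores the flow. The module shapes and token values below are arbitrary stand-ins.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

embed = nn.Embedding(100, 16)       # stands in for frozen input embeddings
block = nn.Linear(16, 16)           # stands in for a block holding trainable LoRA weights
embed.weight.requires_grad_(False)  # frozen, as under only_optimize_lora_parameters

def make_inputs_require_grad(module, input, output):
    output.requires_grad_(True)

tokens = torch.randint(0, 100, (1, 4))

# Without the hook: the hidden states are grad-free, so checkpoint() records no graph.
hidden = embed(tokens)
out = checkpoint(block, hidden, use_reentrant=True)
print(out.requires_grad)              # False -> block.weight would receive no gradient

# With the hook (what make_model_gradient_checkpointing_compatible registers):
handle = embed.register_forward_hook(make_inputs_require_grad)
hidden = embed(tokens)
out = checkpoint(block, hidden, use_reentrant=True)
out.sum().backward()
print(block.weight.grad is not None)  # True: gradients reach the trainable weights
handle.remove()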
