
Commit 07e4c4e

update use_custom_fsdp to use_megatron_fsdp
1 parent ab6e271 commit 07e4c4e

File tree

5 files changed: +54 additions, -11 deletions


nemo/collections/diffusion/recipes/flux_12b.py

Lines changed: 1 addition & 1 deletion
@@ -113,7 +113,7 @@ def trainer(
         gradient_accumulation_fusion=True,
         ddp=run.Config(
             DistributedDataParallelConfig,
-            # use_custom_fsdp=True,
+            # use_megatron_fsdp=True,
             # data_parallel_sharding_strategy='optim_grads_params',
             check_for_nan_in_grad=True,
             grad_reduce_in_fp32=True,

scripts/dit/dit_train.py

Lines changed: 6 additions & 6 deletions
@@ -210,7 +210,7 @@ def train_mock() -> run.Partial:
     recipe.data.model_config = recipe.model.config
     recipe.log.log_dir = 'nemo_experiments/train_mock'
 
-    recipe.trainer.strategy.ddp.use_custom_fsdp = True
+    recipe.trainer.strategy.ddp.use_megatron_fsdp = True
     recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'optim_grads_params'
     recipe.trainer.strategy.ddp.overlap_param_gather = True
     recipe.trainer.strategy.ddp.overlap_grad_reduce = True
@@ -236,7 +236,7 @@ def mock_ditllama5b_8k() -> run.Partial:
     recipe.data.model_config = recipe.model.config
     recipe.log.log_dir = 'nemo_experiments/mock_ditllama5b_8k'
     recipe.model.config.attn_mask_type = AttnMaskType.no_mask
-    recipe.trainer.strategy.ddp.use_custom_fsdp = True
+    recipe.trainer.strategy.ddp.use_megatron_fsdp = True
     recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'optim_grads_params'
     recipe.trainer.strategy.ddp.overlap_param_gather = True
     recipe.trainer.strategy.ddp.overlap_grad_reduce = True
@@ -360,7 +360,7 @@ def pretrain_ditllama30b() -> run.Partial:
     recipe.data.task_encoder.seq_length = 256
     recipe.data.virtual_epoch_length = 0
     recipe.log.log_dir = 'nemo_experiments/ditllama30b_stage1_mock'
-    recipe.trainer.strategy.ddp.use_custom_fsdp = True
+    recipe.trainer.strategy.ddp.use_megatron_fsdp = True
     recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'optim_grads_params'
     recipe.trainer.strategy.ddp.overlap_param_gather = True
     recipe.trainer.strategy.ddp.overlap_grad_reduce = True
@@ -386,7 +386,7 @@ def pretrain_ditllama30b_stage2_mock() -> run.Partial:
     recipe.trainer.val_check_interval = 1.0
     recipe.data.model_config = recipe.model.config
     recipe.log.log_dir = 'nemo_experiments/ditllama30b_stage2_mock'
-    recipe.trainer.strategy.ddp.use_custom_fsdp = True
+    recipe.trainer.strategy.ddp.use_megatron_fsdp = True
     recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'optim_grads_params'
     recipe.trainer.strategy.ddp.overlap_param_gather = True
     recipe.trainer.strategy.ddp.overlap_grad_reduce = True
@@ -412,7 +412,7 @@ def pretrain_ditllama30b_stage3_mock() -> run.Partial:
     recipe.trainer.val_check_interval = 1.0
     recipe.data.model_config = recipe.model.config
     recipe.log.log_dir = 'nemo_experiments/ditllama30b_stage3_mock'
-    recipe.trainer.strategy.ddp.use_custom_fsdp = True
+    recipe.trainer.strategy.ddp.use_megatron_fsdp = True
     recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'optim_grads_params'
     recipe.trainer.strategy.ddp.overlap_param_gather = True
     recipe.trainer.strategy.ddp.overlap_grad_reduce = True
@@ -512,7 +512,7 @@ def pretrain_ecditllama1b() -> run.Partial:
     recipe.log.log_dir = 'nemo_experiments/ecditllama1b'
     recipe.trainer.val_check_interval = 3000
 
-    recipe.trainer.strategy.ddp.use_custom_fsdp = True
+    recipe.trainer.strategy.ddp.use_megatron_fsdp = True
     recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'optim_grads_params'
     recipe.trainer.strategy.ddp.overlap_param_gather = True
     recipe.trainer.strategy.ddp.overlap_grad_reduce = True

scripts/flux/flux_controlnet_training.py

Lines changed: 2 additions & 2 deletions
@@ -92,7 +92,7 @@ def flux_controlnet_training() -> run.Partial:
         pipeline_dtype=torch.bfloat16,
         ddp=run.Config(
             DistributedDataParallelConfig,
-            use_custom_fsdp=True,
+            use_megatron_fsdp=True,
             data_parallel_sharding_strategy='optim_grads_params',
             check_for_nan_in_grad=True,
             grad_reduce_in_fp32=True,
@@ -292,7 +292,7 @@ def unit_test(custom_fsdp=True) -> run.Partial:
     def configure_custom_fsdp(recipe) -> run.Partial:
         recipe.trainer.strategy.ddp = run.Config(
             DistributedDataParallelConfig,
-            use_custom_fsdp=True,
+            use_megatron_fsdp=True,
             data_parallel_sharding_strategy='optim_grads_params',  # Custom FSDP
             check_for_nan_in_grad=True,
             grad_reduce_in_fp32=True,

scripts/flux/flux_training.py

Lines changed: 2 additions & 2 deletions
@@ -95,7 +95,7 @@ def flux_training() -> run.Partial:
         gradient_accumulation_fusion=True,
         ddp=run.Config(
             DistributedDataParallelConfig,
-            use_custom_fsdp=True,
+            use_megatron_fsdp=True,
             data_parallel_sharding_strategy='optim_grads_params',
             check_for_nan_in_grad=True,
             grad_reduce_in_fp32=True,
@@ -229,7 +229,7 @@ def fp8_test(custom_fsdp=True) -> run.Partial:
     def configure_custom_fsdp(recipe) -> run.Partial:
         recipe.trainer.strategy.ddp = run.Config(
             DistributedDataParallelConfig,
-            use_custom_fsdp=True,
+            use_megatron_fsdp=True,
             data_parallel_sharding_strategy='optim_grads_params',  # Custom FSDP
             check_for_nan_in_grad=True,
             grad_reduce_in_fp32=True,

scripts/performance/llm/pretrain_llama3_8b.py

Lines changed: 43 additions & 0 deletions
@@ -84,6 +84,49 @@ def override_recipe_configs(
     recipe = set_exp_logging_configs(
         recipe, "pre_train", "llm", "llama3", args.tensorboard, args.wandb, args.wandb_prj_name, args.wandb_job_name
     )
+    # for saving checkpoints
+    ckpt_path = "/lustre/fsw/coreai_devtech_all/jianbinc/playground/nemo_nvfsdp_update/NeMo/checkpoints"
+    recipe.log.log_dir = ckpt_path
+    import nemo.lightning as nl
+    import nemo_run as run
+
+    recipe.log.ckpt = run.Config(
+        nl.ModelCheckpoint,
+        train_time_interval=None,
+        save_last=True,
+        every_n_train_steps=100,
+        save_top_k=1,
+        save_on_train_epoch_end=True,
+        save_optim_on_train_end=True,
+        always_save_context=False,
+        filename="{model_name}--{val_loss:.2f}-{step}-{consumed_samples}",
+    )
+
+    # nl.ModelCheckpoint(
+    #     train_time_interval=None,
+    # )
+    # # recipe.log.ckpt.train_time_interval = None
+    # recipe.log.ckpt.save_last = True
+    # recipe.log.ckpt.every_n_train_steps = 100
+    # recipe.log.ckpt.save_top_k = 1
+    # recipe.log.ckpt.save_on_train_epoch_end = True
+    # recipe.log.ckpt.save_optim_on_train_end = True
+    # recipe.log.ckpt.always_save_context = False
+
+    # for loading checkpoints
+    recipe.resume.resume_if_exists = True
+    recipe.resume.resume_ignore_no_checkpoint = True
+    # recipe.resume.restore_config = RestoreConfig(
+    #     path=ckpt_path,
+    #     load_model_state=True,
+    #     load_optim_state=True,
+    # )
+
+    recipe.trainer.strategy.save_ckpt_format = "fsdp_dtensor"
+    recipe.trainer.strategy.ddp.average_in_collective = False
+    # recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = "optim"
+
+    recipe.optim.config.use_precision_aware_optimizer = False
 
     # data module configs
     if args.use_hf_tokenizer:
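
This last hunk mixes live settings with commented-out experiments and a cluster-specific checkpoint path. Distilled to what actually takes effect, the added checkpoint/resume configuration amounts to the sketch below; the helper name configure_checkpointing is hypothetical, and every attribute shown is copied from the hunk above rather than verified against a particular NeMo release.

import nemo.lightning as nl
import nemo_run as run

def configure_checkpointing(recipe, ckpt_path):  # hypothetical helper, mirrors the hunk above
    recipe.log.log_dir = ckpt_path
    # Save every 100 train steps, keep the single best checkpoint plus the last one,
    # and save optimizer state at the end of training.
    recipe.log.ckpt = run.Config(
        nl.ModelCheckpoint,
        train_time_interval=None,
        save_last=True,
        every_n_train_steps=100,
        save_top_k=1,
        save_on_train_epoch_end=True,
        save_optim_on_train_end=True,
        always_save_context=False,
        filename="{model_name}--{val_loss:.2f}-{step}-{consumed_samples}",
    )
    # Resume from the latest checkpoint under log_dir if one exists;
    # otherwise start fresh instead of raising an error.
    recipe.resume.resume_if_exists = True
    recipe.resume.resume_ignore_no_checkpoint = True
    # Checkpoint format and optimizer settings used with Megatron FSDP in this commit.
    recipe.trainer.strategy.save_ckpt_format = "fsdp_dtensor"
    recipe.trainer.strategy.ddp.average_in_collective = False
    recipe.optim.config.use_precision_aware_optimizer = False
    return recipe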
