2 changes: 1 addition & 1 deletion nemo/collections/diffusion/recipes/flux_12b.py
@@ -113,7 +113,7 @@ def trainer(
gradient_accumulation_fusion=True,
ddp=run.Config(
DistributedDataParallelConfig,
# use_custom_fsdp=True,
# use_megatron_fsdp=True,
# data_parallel_sharding_strategy='optim_grads_params',
check_for_nan_in_grad=True,
grad_reduce_in_fp32=True,
8 changes: 4 additions & 4 deletions nemo/lightning/_strategy_lib.py
@@ -568,11 +568,11 @@ def load_model_state_dict(megatron_parallel, checkpoint: Mapping[str, Any], stri
strict = strict in strict_options

try:
from megatron.core.distributed.custom_fsdp import FullyShardedDataParallel
from megatron.core.distributed import FullyShardedDataParallel

have_custom_fsdp = True
have_megatron_fsdp = True
except (ImportError, ModuleNotFoundError):
have_custom_fsdp = False
have_megatron_fsdp = False

for index, module in enumerate(megatron_parallel):
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
@@ -610,7 +610,7 @@ def load_model_state_dict(megatron_parallel, checkpoint: Mapping[str, Any], stri
else:
_state_dict[key] = value

if have_custom_fsdp and hasattr(module, "module") and isinstance(module.module, FullyShardedDataParallel):
if have_megatron_fsdp and hasattr(module, "module") and isinstance(module.module, FullyShardedDataParallel):
module.module.load_state_dict(_state_dict, strict=strict)
continue

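Note: the guard above now imports FullyShardedDataParallel from its new home in megatron.core.distributed. A minimal sketch of a version-tolerant guard, assuming older Megatron-LM releases still expose the class under megatron.core.distributed.custom_fsdp (the fallback path and variable names here are illustrative):

try:
    # New location used by this PR.
    from megatron.core.distributed import FullyShardedDataParallel
    have_megatron_fsdp = True
except (ImportError, ModuleNotFoundError):
    try:
        # Older Megatron-LM releases kept the class under custom_fsdp (assumption).
        from megatron.core.distributed.custom_fsdp import FullyShardedDataParallel
        have_megatron_fsdp = True
    except (ImportError, ModuleNotFoundError):
        FullyShardedDataParallel = None
        have_megatron_fsdp = False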
23 changes: 13 additions & 10 deletions nemo/lightning/megatron_parallel.py
@@ -61,11 +61,11 @@
from nemo.utils.model_utils import check_lib_version

try:
from megatron.core.distributed.custom_fsdp import FullyShardedDataParallel
from megatron.core.distributed import FullyShardedDataParallel

HAVE_CUSTOM_FSDP = True
HAVE_MEGATRON_FSDP = True
except ImportError:
HAVE_CUSTOM_FSDP = False
HAVE_MEGATRON_FSDP = False

try:
from megatron.core.full_cuda_graph import FullCudaGraphWrapper
@@ -553,7 +553,7 @@ def init_model_parallel(self):
from megatron.core.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes

for model_module in self:
if not self._cpu and (not HAVE_CUSTOM_FSDP or self.fsdp != "megatron"):
if not self._cpu and (not HAVE_MEGATRON_FSDP or self.fsdp != "megatron"):
# If Megatron FSDP is enabled, skip moving the model to GPU here to avoid GPU OOM.
model_module.cuda(torch.cuda.current_device())

@@ -634,29 +634,32 @@ def init_ddp(self):
# Avoid rewrapping the module if it's already wrapped with FSDP
unwrapped_module = unwrap_model(module, Float16Module)
if (
HAVE_CUSTOM_FSDP
HAVE_MEGATRON_FSDP
and self.fsdp == "megatron"
and not isinstance(unwrapped_module, FullyShardedDataParallel)
):
from nemo.utils import logging

if not getattr(module.config, "use_custom_fsdp", False):
setattr(module.config, "use_custom_fsdp", True)
logging.warning("Setting module.config.use_custom_fsdp to True for MCore FSDP.")
if not getattr(module.config, "use_megatron_fsdp", False):
setattr(module.config, "use_megatron_fsdp", True)
logging.warning("Setting module.config.use_megatron_fsdp to True for MCore FSDP.")

if getattr(module.config, "gradient_accumulation_fusion", True):
setattr(module.config, "gradient_accumulation_fusion", False)
logging.warning("Setting module.config.gradient_accumulation_fusion to False for MCore FSDP.")

assert module.config.use_custom_fsdp, "Custom FSDP is not enabled in module.config."
assert self.ddp_config.use_custom_fsdp, "Custom FSDP is not enabled in ddp_config."
assert module.config.use_megatron_fsdp, "MCore FSDP is not enabled in module.config."
assert self.ddp_config.use_megatron_fsdp, "MCore FSDP is not enabled in ddp_config."

dist_module = FullyShardedDataParallel(
module.config,
self.ddp_config,
module,
disable_bucketing=disable_bucketing,
)
dist_module.buffers = [dist_module.param_and_grad_buffer]
dist_module.config = module.config
dist_module.sharded_state_dict = lambda *args, **kwargs: dist_module.state_dict()
elif not isinstance(unwrapped_module, DDP):
dist_module = DDP(
module.config,
137 changes: 125 additions & 12 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -365,9 +365,9 @@ def __init__(

self._fsdp = None

if fsdp is None and self.ddp_config and self.ddp_config.use_custom_fsdp:
if fsdp is None and self.ddp_config and self.ddp_config.use_megatron_fsdp:
logging.warning(
"FSDP option is not set but ddp_config.use_custom_fsdp is set to true. "
"FSDP option is not set but ddp_config.use_megatron_fsdp is set to true. "
"Setting FSDP option to megatron"
)
fsdp = 'megatron'
@@ -376,9 +376,9 @@ def __init__(
raise NotImplementedError("PyTorch FSDP2 is not supported with MegatronParallel.")
elif fsdp == "megatron":
self._fsdp = fsdp
if not self.ddp_config.use_custom_fsdp:
self.ddp_config.use_custom_fsdp = True
logging.warning("Setting ddp_config.use_custom_fsdp to True for MCore FSDP.")
if not self.ddp_config.use_megatron_fsdp:
self.ddp_config.use_megatron_fsdp = True
logging.warning("Setting ddp_config.use_megatron_fsdp to True for MCore FSDP.")
logging.info("FSDP option is set to MCore. Using MCore's Custom FSDP for DP.")
elif fsdp is not None:
raise ValueError(f'Invalid DDP type: {fsdp}, please choose from ["megatron", "pytorch"].')
@@ -930,6 +930,58 @@ def optimizer_sharded_state_dict(self, is_loading: bool = False, metadata: Optio
metadata=metadata,
)

def _get_fsdp_dtensor_state_dict(
self,
raw_state_dict,
model_key="model",
optimizer_key="optimizer_states",
):
from megatron.core.transformer.fsdp_dtensor_checkpoint import (
handle_fp8_extra_state_case,
handle_swiglu_in_state_dict,
)
from megatron.core.distributed.fsdp.src.megatron_fsdp.uneven_dtensor import (
preprocess_state_dict_for_uneven_dtensor,
)

state_dict = raw_state_dict.copy()
handle_fp8_extra_state_case(state_dict[model_key])
module = self.model[0].module
if torch.distributed.get_rank() == 0:
print(self.model, module)
if getattr(module.config, "gated_linear_unit", False):
model_state_dict = state_dict[model_key].copy()
if optimizer_key in state_dict:
optimizer_state_dict = state_dict[optimizer_key].copy()
else:
optimizer_state_dict = {}

Check notice (Code scanning / CodeQL): Cyclic import. Import of module nemo.collections.llm.modelopt.model_utils begins an import cycle.
handle_swiglu_in_state_dict(
module.module, model_state_dict, optimizer_state_dict
)
state_dict[model_key] = model_state_dict
if optimizer_key in state_dict:
state_dict[optimizer_key] = optimizer_state_dict
preprocess_state_dict_for_uneven_dtensor(state_dict)

return state_dict

def _save_fsdp_dtensor_checkpoint(
self,
checkpoint: Dict[str, Any],
path,
storage_options,
):
state_dict = self._get_fsdp_dtensor_state_dict(checkpoint)

torch.distributed.checkpoint.save(
state_dict,
storage_writer=torch.distributed.checkpoint.FileSystemWriter(path),
)
self._save_fsdp_dtensor_common_state(state_dict=state_dict, ckpt_dir=path)

if "finalize_fn" in storage_options:
storage_options["finalize_fn"]()

@override
def save_checkpoint(
self, checkpoint: Dict[str, Any], filepath: Union[str, Path], storage_options: Optional[Any] = None
@@ -963,10 +1015,23 @@ def save_checkpoint(
if not storage_options:
storage_options = {}
storage_options['content_metadata'] = self.sharded_state_dict_metadata
self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
if self.save_ckpt_format == "fsdp_dtensor":
checkpoint = checkpoint.copy()
if "optimizer" in checkpoint:
checkpoint["optimizer_states"] = checkpoint.pop("optimizer")[0]
checkpoint["model"] = checkpoint.pop("sharded_state_dict")
self._save_fsdp_dtensor_checkpoint(
checkpoint=checkpoint,
path=ckpt_to_dir(filepath),
storage_options=storage_options,
)
checkpoint_io = None
else:
self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
checkpoint_io = self.checkpoint_io

# Save ModelOpt state too, if it exists.
save_modelopt_state(self.megatron_parallel, filepath, self.checkpoint_io)
save_modelopt_state(self.megatron_parallel, filepath, checkpoint_io)

def should_restore_optimizer_states(self, selective_restore: bool = False) -> bool:
"""Determines whether to restore optimizer states or not"""
@@ -975,6 +1040,34 @@ def should_restore_optimizer_states(self, selective_restore: bool = False) -> bo

return self.ckpt_load_optimizer

def _save_fsdp_dtensor_common_state(self, state_dict, ckpt_dir):
state_dict = state_dict.copy()
del state_dict["model"]
del state_dict["optimizer_states"]
torch.save(state_dict, os.path.join(ckpt_dir, "common.pt"))

def _load_fsdp_dtensor_common_state(self, ckpt_dir):
return torch.load(os.path.join(ckpt_dir, "common.pt"), weights_only=False)

def _load_fsdp_dtensor_checkpoint(self, path, sharded_state_dict, strict):
from torch.distributed.checkpoint import default_planner

state_dict = self._get_fsdp_dtensor_state_dict(sharded_state_dict)

planner = default_planner.DefaultLoadPlanner(allow_partial_load=not strict)
torch.distributed.checkpoint.load(
state_dict,
checkpoint_id=path,
planner=planner,
)
sharded_state_dict.update(
self._load_fsdp_dtensor_common_state(ckpt_dir=path)
)
if "loops" in sharded_state_dict:
sharded_state_dict["fit_loop"] = sharded_state_dict["loops"]["fit_loop"]

return sharded_state_dict

@override
def load_checkpoint(self, checkpoint_path: Union[str, Path], selective_restore: bool = False) -> Dict[str, Any]:
"""PTL method which we override to integrate distributed checkpoints for model parallel models.
@@ -994,7 +1087,10 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path], selective_restore:
sharded_state_context = nullcontext

# After dist_checkpointing.load, sharded tensors will be replaced with tensors
sharded_sd_metadata = self.unwrapped_checkpoint_io.load_content_metadata(checkpoint_path)
if self.save_ckpt_format == "fsdp_dtensor":
sharded_sd_metadata = self.sharded_state_dict_metadata
else:
sharded_sd_metadata = self.unwrapped_checkpoint_io.load_content_metadata(checkpoint_path)
sharded_state_dict = {}
with sharded_state_context():
sharded_state_dict["state_dict"] = self.megatron_parallel.sharded_state_dict(metadata=sharded_sd_metadata)
@@ -1010,9 +1106,19 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path], selective_restore:
)

try:
checkpoint = self.checkpoint_io.load_checkpoint(
checkpoint_path, sharded_state_dict=sharded_state_dict, strict=strict
)
if self.save_ckpt_format == "fsdp_dtensor":
sharded_state_dict["model"] = sharded_state_dict.pop("state_dict")
if "optimizer" in sharded_state_dict:
sharded_state_dict["optimizer_states"] = sharded_state_dict.pop("optimizer")[0]
checkpoint = self._load_fsdp_dtensor_checkpoint(
path=ckpt_to_dir(checkpoint_path),
sharded_state_dict=sharded_state_dict,
strict=strict,
)
else:
checkpoint = self.checkpoint_io.load_checkpoint(
checkpoint_path, sharded_state_dict=sharded_state_dict, strict=strict
)
except CheckpointException as e:
error_message = f"{e}\n{LOAD_ERROR}"
raise RuntimeError(error_message)
@@ -1031,7 +1137,9 @@ def sharded_state_dict_metadata(self):
"""Metadata used for sharded_state_dict generation during checkpoint save."""
metadata = {}
if isinstance(self.ddp_config, DistributedDataParallelConfig) and self.ddp_config.use_distributed_optimizer:
if self.parallel_save_optim:
if self.ddp_config.use_megatron_fsdp:
metadata["distrib_optim_sharding_type"] = "fsdp_dtensor"
elif self.parallel_save_optim:
metadata["distrib_optim_sharding_type"] = "fully_sharded_model_space"
else:
metadata["distrib_optim_sharding_type"] = "dp_zero_gather_scatter"
@@ -1072,6 +1180,11 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any], selective_res

mesh = DeviceMesh.from_group(parallel_state.get_data_parallel_group(), "cuda")

if self.save_ckpt_format == "fsdp_dtensor":
assert len(self.optimizers) == 1, "FSDP DTensor format requires a single optimizer."
self.optimizers[0].load_state_dict(checkpoint["optimizer_states"])
return

optimizer_states = checkpoint["optimizer"]
for optimizer, opt_state in zip(self.optimizers, optimizer_states):
if self._fsdp is not None:
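Note: for orientation, a minimal sketch of the fsdp_dtensor round trip that _save_fsdp_dtensor_checkpoint and _load_fsdp_dtensor_checkpoint implement, using only public torch.distributed.checkpoint APIs. The function names, the example directory layout, and the rank-0 guard on the common.pt write are assumptions of this sketch; the strategy methods above remain the source of truth.

import os
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import FileSystemWriter
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner

def save_fsdp_dtensor(state_dict, ckpt_dir):
    # DTensor-sharded entries (model weights, optimizer shards) go through
    # distributed checkpointing, so each rank writes only its own shards.
    sharded = {k: state_dict[k] for k in ("model", "optimizer_states") if k in state_dict}
    dcp.save(sharded, storage_writer=FileSystemWriter(ckpt_dir))
    # Small, rank-replicated state (loops, callbacks, rng, ...) is kept in a
    # plain torch.save file next to the DCP shards, as common.pt.
    common = {k: v for k, v in state_dict.items() if k not in ("model", "optimizer_states")}
    if torch.distributed.get_rank() == 0:  # rank guard is an assumption of this sketch
        torch.save(common, os.path.join(ckpt_dir, "common.pt"))

def load_fsdp_dtensor(template_state_dict, ckpt_dir, strict=True):
    # template_state_dict must already contain correctly sharded DTensor entries
    # for "model" and "optimizer_states"; dcp.load fills them in place.
    dcp.load(
        template_state_dict,
        checkpoint_id=ckpt_dir,
        planner=DefaultLoadPlanner(allow_partial_load=not strict),
    )
    template_state_dict.update(torch.load(os.path.join(ckpt_dir, "common.pt"), weights_only=False))
    return template_state_dict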
12 changes: 6 additions & 6 deletions scripts/dit/dit_train.py
@@ -210,7 +210,7 @@ def train_mock() -> run.Partial:
recipe.data.model_config = recipe.model.config
recipe.log.log_dir = 'nemo_experiments/train_mock'

recipe.trainer.strategy.ddp.use_custom_fsdp = True
recipe.trainer.strategy.ddp.use_megatron_fsdp = True
recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'optim_grads_params'
recipe.trainer.strategy.ddp.overlap_param_gather = True
recipe.trainer.strategy.ddp.overlap_grad_reduce = True
@@ -236,7 +236,7 @@ def mock_ditllama5b_8k() -> run.Partial:
recipe.data.model_config = recipe.model.config
recipe.log.log_dir = 'nemo_experiments/mock_ditllama5b_8k'
recipe.model.config.attn_mask_type = AttnMaskType.no_mask
recipe.trainer.strategy.ddp.use_custom_fsdp = True
recipe.trainer.strategy.ddp.use_megatron_fsdp = True
recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'optim_grads_params'
recipe.trainer.strategy.ddp.overlap_param_gather = True
recipe.trainer.strategy.ddp.overlap_grad_reduce = True
@@ -360,7 +360,7 @@ def pretrain_ditllama30b() -> run.Partial:
recipe.data.task_encoder.seq_length = 256
recipe.data.virtual_epoch_length = 0
recipe.log.log_dir = 'nemo_experiments/ditllama30b_stage1_mock'
recipe.trainer.strategy.ddp.use_custom_fsdp = True
recipe.trainer.strategy.ddp.use_megatron_fsdp = True
recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'optim_grads_params'
recipe.trainer.strategy.ddp.overlap_param_gather = True
recipe.trainer.strategy.ddp.overlap_grad_reduce = True
@@ -386,7 +386,7 @@ def pretrain_ditllama30b_stage2_mock() -> run.Partial:
recipe.trainer.val_check_interval = 1.0
recipe.data.model_config = recipe.model.config
recipe.log.log_dir = 'nemo_experiments/ditllama30b_stage2_mock'
recipe.trainer.strategy.ddp.use_custom_fsdp = True
recipe.trainer.strategy.ddp.use_megatron_fsdp = True
recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'optim_grads_params'
recipe.trainer.strategy.ddp.overlap_param_gather = True
recipe.trainer.strategy.ddp.overlap_grad_reduce = True
@@ -412,7 +412,7 @@ def pretrain_ditllama30b_stage3_mock() -> run.Partial:
recipe.trainer.val_check_interval = 1.0
recipe.data.model_config = recipe.model.config
recipe.log.log_dir = 'nemo_experiments/ditllama30b_stage3_mock'
recipe.trainer.strategy.ddp.use_custom_fsdp = True
recipe.trainer.strategy.ddp.use_megatron_fsdp = True
recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'optim_grads_params'
recipe.trainer.strategy.ddp.overlap_param_gather = True
recipe.trainer.strategy.ddp.overlap_grad_reduce = True
@@ -512,7 +512,7 @@ def pretrain_ecditllama1b() -> run.Partial:
recipe.log.log_dir = 'nemo_experiments/ecditllama1b'
recipe.trainer.val_check_interval = 3000

recipe.trainer.strategy.ddp.use_custom_fsdp = True
recipe.trainer.strategy.ddp.use_megatron_fsdp = True
recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'optim_grads_params'
recipe.trainer.strategy.ddp.overlap_param_gather = True
recipe.trainer.strategy.ddp.overlap_grad_reduce = True
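Note: the recipes above all flip the same Megatron FSDP knobs one attribute at a time. A minimal consolidated sketch, assuming a NeMo 2.0 recipe object with the trainer.strategy.ddp field used here; the helper name is illustrative, and the flag values are copied from the recipes touched by this PR.

import nemo_run as run
from megatron.core.distributed import DistributedDataParallelConfig

def enable_megatron_fsdp(recipe):
    # Shard optimizer state, gradients and parameters across data-parallel ranks
    # via Megatron FSDP, and overlap the resulting communication with compute.
    recipe.trainer.strategy.ddp = run.Config(
        DistributedDataParallelConfig,
        use_megatron_fsdp=True,
        data_parallel_sharding_strategy='optim_grads_params',
        overlap_param_gather=True,
        overlap_grad_reduce=True,
        check_for_nan_in_grad=True,
        grad_reduce_in_fp32=True,
    )
    return recipe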
4 changes: 2 additions & 2 deletions scripts/flux/flux_controlnet_training.py
@@ -92,7 +92,7 @@ def flux_controlnet_training() -> run.Partial:
pipeline_dtype=torch.bfloat16,
ddp=run.Config(
DistributedDataParallelConfig,
use_custom_fsdp=True,
use_megatron_fsdp=True,
data_parallel_sharding_strategy='optim_grads_params',
check_for_nan_in_grad=True,
grad_reduce_in_fp32=True,
@@ -292,7 +292,7 @@ def unit_test(custom_fsdp=True) -> run.Partial:
def configure_custom_fsdp(recipe) -> run.Partial:
recipe.trainer.strategy.ddp = run.Config(
DistributedDataParallelConfig,
use_custom_fsdp=True,
use_megatron_fsdp=True,
data_parallel_sharding_strategy='optim_grads_params', # Custom FSDP
check_for_nan_in_grad=True,
grad_reduce_in_fp32=True,
4 changes: 2 additions & 2 deletions scripts/flux/flux_training.py
@@ -95,7 +95,7 @@ def flux_training() -> run.Partial:
gradient_accumulation_fusion=True,
ddp=run.Config(
DistributedDataParallelConfig,
use_custom_fsdp=True,
use_megatron_fsdp=True,
data_parallel_sharding_strategy='optim_grads_params',
check_for_nan_in_grad=True,
grad_reduce_in_fp32=True,
@@ -229,7 +229,7 @@ def fp8_test(custom_fsdp=True) -> run.Partial:
def configure_custom_fsdp(recipe) -> run.Partial:
recipe.trainer.strategy.ddp = run.Config(
DistributedDataParallelConfig,
use_custom_fsdp=True,
use_megatron_fsdp=True,
data_parallel_sharding_strategy='optim_grads_params', # Custom FSDP
check_for_nan_in_grad=True,
grad_reduce_in_fp32=True,