142 changes: 81 additions & 61 deletions nemo_automodel/components/distributed/parallelizer.py
@@ -494,6 +494,50 @@ def validate_tp_mesh(model, tp_mesh):
)


def _find_largest_module_list(model: nn.Module) -> Optional[nn.ModuleList]:
"""
Heuristic function to find the largest nn.ModuleList in a model.

This function recursively traverses the model to find all nn.ModuleList instances
and returns the one with the most modules. This is useful as a fallback when
the model architecture is unknown, since transformer layers are typically
organized in ModuleLists.

Args:
model (nn.Module): The model to search through.

Returns:
Optional[nn.ModuleList]: The largest ModuleList found, or None if no ModuleList exists.
"""
largest_module_list = None
largest_size = 0

def _recursive_search(module: nn.Module, path: str = ""):
nonlocal largest_module_list, largest_size

for name, child in module.named_children():
current_path = f"{path}.{name}" if path else name

if isinstance(child, nn.ModuleList):
current_size = len(child)
if current_size > largest_size:
largest_size = current_size
largest_module_list = child
logger.debug(f"Found ModuleList at {current_path} with {current_size} modules")

# Continue recursive search
_recursive_search(child, current_path)

_recursive_search(model)

if largest_module_list is not None:
logger.info(f"Largest ModuleList found with {largest_size} modules")
else:
logger.warning("No ModuleList found in the model")

return largest_module_list
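For context, a minimal sketch of what this heuristic picks on a toy module tree. `TinyModel` and its layer counts are invented for illustration and are not part of this PR:

```python
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # 12 "transformer blocks" vs. a 2-element head: the 12-element list wins
        self.blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(12)])
        self.head = nn.ModuleList([nn.Linear(8, 8), nn.Linear(8, 4)])

# _find_largest_module_list(TinyModel()) recursively walks named_children()
# and returns the 12-element `blocks` ModuleList.
```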


def _extract_model_layers(model: nn.Module) -> List[nn.Module]:
"""
Extract layers from different model architectures for parallelization.
@@ -511,70 +555,46 @@ def _extract_model_layers(model: nn.Module) -> List[nn.Module]:
model_cls = type(model)
layers: List[nn.Module] = []

# Handle different model structures
if model_cls == Gemma3ForConditionalGeneration:
# Collect language model layers
for layer in model.language_model.layers:
layers.append(layer)
# Collect vision model layers (siglip encoder has same structure as clip encoder)
for layer in model.vision_tower.vision_model.encoder.layers:
layers.append(layer)

elif model_cls in [
Qwen2_5_VLForConditionalGeneration,
Qwen2VLForConditionalGeneration,
]:
# VL models have the language model at model.language_model
# Append language model layers
for layer in model.language_model.layers:
layers.append(layer)
# Append visual model layers
for layer in model.visual.blocks:
layers.append(layer)

elif model_cls == SmolVLMForConditionalGeneration:
# Collect text model layers
for layer in model.model.text_model.layers:
layers.append(layer)
# Collect vision model layers
for layer in model.model.vision_model.encoder.layers:
layers.append(layer)

elif model_cls in [
LlavaForConditionalGeneration,
LlavaNextForConditionalGeneration,
LlavaNextVideoForConditionalGeneration,
LlavaOnevisionForConditionalGeneration,
]:
# Collect language model layers
for layer in model.model.language_model.layers:
layers.append(layer)
# Collect vision model layers
for layer in model.vision_tower.vision_model.encoder.layers:
layers.append(layer)

elif model_cls == Mistral3ForConditionalGeneration:
# Collect language model layers
for layer in model.model.language_model.layers:
layers.append(layer)
# Collect vision model layers
for layer in model.model.vision_tower.transformer.layers:
layers.append(layer)

elif model_cls == Llama4ForConditionalGeneration:
# Collect language model layers
for layer in model.language_model.model.layers:
layers.append(layer)
# Collect vision model layers
for layer in model.vision_model.model.layers:
layers.append(layer)
elif model_cls.__name__ == "NemotronHForCausalLM":
# NemotronH models use backbone.layers instead of model.layers
layers.extend(model.backbone.layers)
else:
LAYER_MAP = {
Gemma3ForConditionalGeneration: ["language_model.layers", "vision_tower.vision_model.encoder.layers"],
Qwen2_5_VLForConditionalGeneration: ["language_model.layers", "visual.blocks"],
Qwen2VLForConditionalGeneration: ["language_model.layers", "visual.blocks"],
SmolVLMForConditionalGeneration: ["model.text_model.layers", "model.vision_model.encoder.layers"],
LlavaForConditionalGeneration: ["model.language_model.layers", "vision_tower.vision_model.encoder.layers"],
LlavaNextForConditionalGeneration: ["model.language_model.layers", "vision_tower.vision_model.encoder.layers"],
LlavaNextVideoForConditionalGeneration: ["model.language_model.layers", "vision_tower.vision_model.encoder.layers"],
LlavaOnevisionForConditionalGeneration: ["model.language_model.layers", "vision_tower.vision_model.encoder.layers"],
Mistral3ForConditionalGeneration: ["model.language_model.layers", "model.vision_tower.transformer.layers"],
Llama4ForConditionalGeneration: ["language_model.model.layers", "vision_model.model.layers"],
"NemotronHForCausalLM": ["backbone.layers"],
}
def _reduce_attr(model, attr):
parts = attr.split(".")
return reduce(getattr, parts, model)

if model_cls in LAYER_MAP:
for attr in LAYER_MAP[model_cls]:
layers.extend(_reduce_attr(model, attr))
elif model_cls.__name__ == "NemotronHForCausalLM":
for attr in LAYER_MAP[model_cls.__name__]:
layers.extend(_reduce_attr(model, attr))
elif hasattr(model, "model"):
# Default case for all other models (assumed to be a causal LM)
layers.extend(model.model.layers)

elif hasattr(model, "layers"):
layers.extend(model.layers)
else:
# Use heuristic to find the largest ModuleList in the model
logger.warning(f"Unknown model type: {model_cls}. Using heuristic to find transformer layers.")
largest_module_list = _find_largest_module_list(model)
if largest_module_list is not None:
layers.extend(largest_module_list)
logger.info(f"Successfully extracted {len(largest_module_list)} layers using heuristic")
else:
# If no ModuleList is found either, log the model structure and raise
logger.debug("Unrecognized model structure: %s", model)
raise ValueError(f"Unknown model type: {model_cls} and no ModuleList found in model structure")
return layers
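The dotted paths in `LAYER_MAP` are resolved with `reduce(getattr, ...)`; presumably `reduce` comes from `functools`, since the import is outside the visible hunk. A standalone sketch of that resolution, using made-up classes:

```python
from functools import reduce

class _Blocks:
    layers = ["block0", "block1"]

class _TextModel:
    text_model = _Blocks()

class _ToyModel:
    model = _TextModel()

# "model.text_model.layers" walks: _ToyModel -> .model -> .text_model -> .layers
resolved = reduce(getattr, "model.text_model.layers".split("."), _ToyModel())
assert resolved == ["block0", "block1"]
```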


48 changes: 30 additions & 18 deletions nemo_automodel/recipes/llm/train_ft.py
@@ -14,6 +14,7 @@

from __future__ import annotations

import inspect
import logging
import pathlib
import time
@@ -267,6 +268,32 @@ def build_loss_fn(cfg_loss):
"""
return cfg_loss.instantiate()

def _build_tokenizer(cfg_model, cfg_ds):
# if tokenizer is not provided, use the model config to instantiate it
if "tokenizer" not in cfg_ds and cfg_model.get("pretrained_model_name_or_path", None) is not None:
logging.info("Using model config to instantiate tokenizer")
trust_remote_code = getattr(cfg_model, "trust_remote_code", False)
tokenizer = AutoTokenizer.from_pretrained(
cfg_model.pretrained_model_name_or_path, trust_remote_code=trust_remote_code
)
elif cfg_ds.get("tokenizer", None) is None:
tokenizer = None
elif "_target_" not in cfg_ds.tokenizer:
tokenizer = AutoTokenizer.from_pretrained(**cfg_ds.tokenizer.to_dict())
else:
tokenizer = cfg_ds.tokenizer.instantiate()

# Finally, check if the dataset target accepts a tokenizer parameter
kwargs = {}
if tokenizer is not None and callable(cfg_ds._target_):
try:
sig = inspect.signature(cfg_ds._target_)
if 'tokenizer' in sig.parameters:
kwargs["tokenizer"] = tokenizer
except (ValueError, TypeError):
# If we can't get the signature, skip adding tokenizer
pass
return kwargs
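The `inspect.signature` probe at the end of `_build_tokenizer` only forwards the tokenizer when the dataset target can accept it. A self-contained sketch of that check, with hypothetical dataset factories invented for the example:

```python
import inspect

def dataset_with_tokenizer(path, tokenizer=None):
    ...

def dataset_without_tokenizer(path):
    ...

def accepts_tokenizer(target):
    try:
        return "tokenizer" in inspect.signature(target).parameters
    except (ValueError, TypeError):
        # Builtins and some C-implemented callables expose no retrievable signature
        return False

assert accepts_tokenizer(dataset_with_tokenizer) is True
assert accepts_tokenizer(dataset_without_tokenizer) is False
```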

def build_dataloader(
cfg_ds, cfg_dl, cfg_model, cfg_ps, device_mesh, seed, local_batch_size
@@ -295,24 +322,9 @@ def build_dataloader(
"num_replicas": device_mesh["dp"].size(),
"rank": device_mesh["dp"].get_local_rank(),
}
# if tokenizer is not provided, use the model config to instantiate it
if "tokenizer" not in cfg_ds and cfg_model.get("pretrained_model_name_or_path", None) is not None:
logging.info("Using model config to instantiate tokenizer")
trust_remote_code = getattr(cfg_model, "trust_remote_code", False)
tokenizer = AutoTokenizer.from_pretrained(
cfg_model.pretrained_model_name_or_path, trust_remote_code=trust_remote_code
)
elif cfg_ds.get("tokenizer", None) is None:
tokenizer = None
elif "_target_" not in cfg_ds.tokenizer:
tokenizer = AutoTokenizer.from_pretrained(**cfg_ds.tokenizer.to_dict())
else:
tokenizer = cfg_ds.tokenizer.instantiate()

with StatefulRNG(seed=seed, ranked=True):
kwargs = {}
if tokenizer is not None:
kwargs["tokenizer"] = tokenizer
kwargs = _build_tokenizer(cfg_model, cfg_ds)
ds = cfg_ds.instantiate(**kwargs)
# Apply packing if configured
if getattr(cfg_ps, "packed_sequence_size", 0) > 0:
@@ -808,10 +820,10 @@ def _forward_backward_step(

local_loss = calculate_loss(
self.loss_fn,
logits=out.logits,
logits=getattr(out, "logits", out),
labels=labels,
model=model,
hidden_states=out.hidden_states[-1] if "hidden_states" in out else None,
hidden_states=out.hidden_states[-1] if getattr(out, "hidden_states", None) is not None else None,
num_label_tokens=num_label_tokens,
)
loss_buffer.append(local_loss.clone().detach())
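The `getattr` fallbacks make the loss path tolerant of models that return either a HuggingFace-style output object or a bare logits tensor. A rough sketch of the dispatch; `FakeOutput` is illustrative only, not a type used in the repository:

```python
from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class FakeOutput:
    logits: torch.Tensor
    hidden_states: Optional[tuple] = None

raw = torch.randn(2, 4, 8)        # model returned a plain tensor
wrapped = FakeOutput(logits=raw)  # model returned an output object

for out in (raw, wrapped):
    logits = getattr(out, "logits", out)      # the tensor either way
    hs = getattr(out, "hidden_states", None)  # None unless the model exposes them
    last_hidden = hs[-1] if hs is not None else None
    assert logits.shape == (2, 4, 8) and last_hidden is None
```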