@@ -16,12 +16,14 @@
import re
from enum import Enum
from pathlib import Path
from random import randint
from typing import Dict, Iterator, List, Optional, Union

from datasets import VerificationMode, load_dataset
from torch.utils.data import Dataset

from nemo_automodel.components.datasets.llm.formatting_utils import (
NoContextLeftError,
_add_pad_token,
_has_chat_template,
format_chat_template,
@@ -152,20 +154,21 @@ def __init__(
*,
split: Optional[str] = None,
answer_only_loss_mask: bool = True,
-seq_length: Optional[int] = None,
+max_seq_length: Optional[int] = None,
start_of_turn_token: Optional[str] = None,
) -> None:
"""
Initialize the dataset.

Args:
-path_or_dataset_id: The path or dataset id of the dataset.
-column_mapping: The mapping of the columns.
-tokenizer: The tokenizer to use.
-split: The split of the dataset to load.
-answer_only_loss_mask: Whether to compute the loss mask only on the answer tokens.
-seq_length: The sequence length to use for padding.
-start_of_turn_token: The token to use to indicate the start of a turn.
+path_or_dataset_id (str, list[str]): The path or dataset id of the dataset.
+column_mapping (dict): The mapping of the columns.
+tokenizer (Tokenizer): The tokenizer to use.
+split (str, optional): The split of the dataset to load.
+answer_only_loss_mask (bool, optional): Whether to compute the loss mask only on the answer tokens.
+max_seq_length (int, optional): If set, each example longer than this
+length is truncated; shorter sequences are left as is.
+start_of_turn_token (str, optional): The token to use to indicate the start of a turn.
"""

if _has_chat_template(tokenizer):
@@ -207,7 +210,7 @@ def __init__(

self.answer_only_loss_mask = answer_only_loss_mask
self.start_of_turn_token = start_of_turn_token
-self.seq_length = seq_length
+self.max_seq_length = max_seq_length

def __len__(self) -> int: # noqa: D401
"""
@@ -234,11 +237,26 @@ def __getitem__(self, idx): # noqa: D401
Raises:
RuntimeError: If streaming is enabled.
"""
-row = self.dataset[idx]
-mapped = {dest: row[src] for dest, src in self.column_mapping.items() if src in row}
-mapped = self._apply_tokenizer(mapped)
-assert _check_all_values_equal_length(mapped), "All values must be of the same length"
-return mapped
+# Try the current idx, then randomly resampled ones, to find a sample that fits max_seq_length
+total = len(self.dataset)
+cur_idx = idx
+last_error: Optional[Exception] = None
+max_attempts = min(64, total)
+for _ in range(max_attempts):
+row = self.dataset[cur_idx]
+mapped = {dest: row[src] for dest, src in self.column_mapping.items() if src in row}
+try:
+mapped = self._apply_tokenizer(mapped)
+assert _check_all_values_equal_length(mapped), "All values must be of the same length"
+return mapped
+except NoContextLeftError as e:
+last_error = e
+cur_idx = randint(0, total - 1)  # randint is inclusive on both ends
+continue
+# If we exhausted attempts, re-raise the last error for visibility
+if last_error is not None:
+raise last_error
+raise RuntimeError("Failed to retrieve a valid sample")

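For orientation, a hedged usage sketch of the resample-on-overflow behavior above, as seen from the caller's side. The module path, class name, dataset id, and column names are assumptions for illustration (the file header is not visible in this diff); only the constructor arguments mirror the docstring above.

```python
# Hedged sketch: module path, class name, and column mapping are assumptions,
# not confirmed by this diff.
from transformers import AutoTokenizer

from nemo_automodel.components.datasets.llm.column_mapped_text_instruction_dataset import (
    ColumnMappedTextInstructionDataset,  # assumed class name
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # arbitrary example checkpoint
ds = ColumnMappedTextInstructionDataset(
    "rajpurkar/squad",  # placeholder dataset id
    column_mapping={"context": "context", "question": "question", "answer": "answers"},
    tokenizer=tokenizer,
    split="train",
    max_seq_length=512,
)
# If row 0 cannot fit 512 tokens even with an empty context, __getitem__
# retries up to 64 random rows before surfacing NoContextLeftError.
sample = ds[0]
```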
def _apply_tokenizer(self, sample: Dict[str, str]) -> Dict[str, List[int]]:
"""
@@ -272,7 +290,7 @@ def _apply_tokenizer(self, sample: Dict[str, str]) -> Dict[str, List[int]]:
answer,
eos_token_id,
pad_token_id,
-seq_length=self.seq_length,
+max_seq_length=self.max_seq_length,
start_of_turn_token=self.start_of_turn_token,
)
else:
@@ -283,6 +301,6 @@ def _apply_tokenizer(self, sample: Dict[str, str]) -> Dict[str, List[int]]:
answer,
eos_token_id,
pad_token_id,
-seq_length=self.seq_length,
+max_seq_length=self.max_seq_length,
answer_only_loss_mask=self.answer_only_loss_mask,
)
122 changes: 105 additions & 17 deletions nemo_automodel/components/datasets/llm/formatting_utils.py
@@ -18,6 +18,10 @@
from transformers import PreTrainedTokenizer


class NoContextLeftError(RuntimeError):
"""Raised when context must be fully removed to satisfy max_seq_length."""


def _pad_to_seq_length(sample, pad_token_id, seq_length):
"""Pad a sample to a specific sequence length."""
n = seq_length - len(sample)
@@ -53,7 +57,7 @@ def _has_chat_template(tokenizer: "PreTrainedTokenizer") -> bool:
)


-def _package_tokenized_example(has_chat_template, input_ids, eos_token_id, pad_token_id, seq_length, context_len):
+def _package_tokenized_example(has_chat_template, input_ids, eos_token_id, pad_token_id, max_seq_length, context_len):
"""
Package a tokenized example with proper masking and padding.

@@ -62,7 +66,7 @@ def _package_tokenized_example(has_chat_template, input_ids, eos_token_id, pad_t
input_ids: The tokenized input ids.
eos_token_id: The end-of-sequence token id.
pad_token_id: The padding token id.
-seq_length: Optional sequence length for padding.
+max_seq_length: Optional maximum sequence length (no padding applied).
context_len: Length of the context/prompt (to mask in labels).

Returns:
@@ -88,11 +92,7 @@ def _package_tokenized_example(has_chat_template, input_ids, eos_token_id, pad_t
assert input_ids[-1] != eos_token_id, f"input_ids[-1]={input_ids[-1]} == eos_token_id={eos_token_id}"
assert len(input_ids) == len(labels), f"len(input_ids)={len(input_ids)} != len(labels)={len(labels)}"

-if isinstance(seq_length, int):
-input_ids = _pad_to_seq_length(input_ids, pad_token_id, seq_length)
-labels = _pad_to_seq_length(labels, -100, seq_length)
-
-# the attention mask can also be extended in the collator with zeros.
+# No padding is applied here; we only extend the attention mask to match the labels' length.
attention_mask += [0] * (len(labels) - len(attention_mask))
return {
"input_ids": input_ids,
@@ -106,13 +106,91 @@ def _package_tokenized_example(has_chat_template, input_ids, eos_token_id, pad_t
}

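To make the mask bookkeeping concrete, a toy illustration of the length alignment performed above; the values are made up.

```python
# Toy illustration only: right-pad the attention mask with zeros so its
# length matches labels, exactly as the line above does.
attention_mask = [1, 1, 1]
labels = [-100, -100, 42, 7]
attention_mask += [0] * (len(labels) - len(attention_mask))
assert attention_mask == [1, 1, 1, 0]
```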

def _truncate_prompt_to_fit_plain(
tokenizer: "PreTrainedTokenizer",
prompt: str,
answer: str,
eos_token_id: int,
pad_token_id: int,
max_seq_length: Optional[int],
answer_only_loss_mask: bool,
) -> str:
"""Iteratively remove leading context words until packaged length fits.

Splits on spaces to avoid mid-word truncation. Raises NoContextLeftError
if the prompt must be fully removed to satisfy the constraint.
"""
if not isinstance(max_seq_length, int):
return prompt

current_prompt = prompt
while True:
context_len = len(tokenizer(current_prompt)["input_ids"]) if answer_only_loss_mask else 0
full_ids = tokenizer(current_prompt + answer)["input_ids"]
packaged = _package_tokenized_example(False, full_ids, eos_token_id, pad_token_id, max_seq_length, context_len)
if len(packaged["labels"]) <= max_seq_length:
return current_prompt
# remove up to the first space from the left
cut = current_prompt.find(" ")
if cut == -1:
raise NoContextLeftError("Context fully removed but sequence still exceeds max_seq_length")
current_prompt = current_prompt[cut + 1 :]
# Skip any additional spaces
while current_prompt.startswith(" "):
current_prompt = current_prompt[1:]
if not current_prompt:
raise NoContextLeftError("No context left after truncation")

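The word-boundary left-trim used by this helper can be exercised in isolation; a minimal sketch mirroring the `find(" ")` logic above.

```python
# Minimal sketch of the left-trim step: drop one leading word per call.
def drop_leading_word(text: str) -> str:
    cut = text.find(" ")
    if cut == -1:
        # mirrors NoContextLeftError in the helper above
        raise RuntimeError("context fully removed")
    return text[cut + 1 :].lstrip(" ")

s = "The quick brown fox"
s = drop_leading_word(s)  # "quick brown fox"
s = drop_leading_word(s)  # "brown fox"
```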

def _truncate_prompt_to_fit_chat(
tokenizer: "PreTrainedTokenizer",
prompt: str,
answer: str,
eos_token_id: int,
pad_token_id: int,
max_seq_length: Optional[int],
start_of_turn_token: Optional[str],
) -> str:
"""Iteratively remove leading context words for chat-template path."""
if not isinstance(max_seq_length, int):
return prompt

current_prompt = prompt
while True:
messages = [
{"role": "user", "content": current_prompt},
{"role": "assistant", "content": answer},
]
input_ids = tokenizer.apply_chat_template(messages)
if isinstance(start_of_turn_token, str):
start_of_turn_token_id = tokenizer(start_of_turn_token, add_special_tokens=False)["input_ids"][0]
first_start_of_turn_token_id = input_ids.index(start_of_turn_token_id)
response_start = input_ids.index(start_of_turn_token_id, first_start_of_turn_token_id + 1)
else:
response_start = 0
packaged = _package_tokenized_example(
True, input_ids, eos_token_id, pad_token_id, max_seq_length, response_start
)
if len(packaged["labels"]) <= max_seq_length:
return current_prompt
# remove up to the first space from the left
cut = current_prompt.find(" ")
if cut == -1:
raise NoContextLeftError("Context fully removed but sequence still exceeds max_seq_length")
current_prompt = current_prompt[cut + 1 :]
while current_prompt.startswith(" "):
current_prompt = current_prompt[1:]
if not current_prompt:
raise NoContextLeftError("No context left after truncation")

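The response-start computation in this helper hinges on the second occurrence of the start-of-turn token id; a toy sketch of that index arithmetic with fabricated token ids.

```python
# Toy sketch with fabricated token ids: the assistant turn begins at the
# second occurrence of the start-of-turn token id.
start_of_turn_token_id = 5
input_ids = [5, 11, 12, 13, 5, 21, 22, 2]  # user turn, then assistant turn

first = input_ids.index(start_of_turn_token_id)
response_start = input_ids.index(start_of_turn_token_id, first + 1)
assert response_start == 4  # tokens before this point are masked in labels
```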

def format_prompt_completion(
tokenizer: "PreTrainedTokenizer",
prompt: str,
answer: str,
eos_token_id: int,
pad_token_id: int,
-seq_length: Optional[int] = None,
+max_seq_length: Optional[int] = None,
answer_only_loss_mask: bool = True,
) -> Dict[str, List[int]]:
"""
@@ -124,23 +202,28 @@ def format_prompt_completion(
answer: The answer string.
eos_token_id: The end-of-sequence token id.
pad_token_id: The padding token id.
-seq_length: Optional sequence length for padding.
+max_seq_length: Optional maximum sequence length. If the packaged
+sequence exceeds this length, context is removed from the left at
+space boundaries. No padding is applied.

Returns:
A dictionary with the formatted example.
"""
-full_text = prompt + answer
+# Optionally truncate the prompt to fit the requested maximum length
+truncated_prompt = _truncate_prompt_to_fit_plain(
+tokenizer, prompt, answer, eos_token_id, pad_token_id, max_seq_length, answer_only_loss_mask
+)

# Tokenize separately to locate answer start
if answer_only_loss_mask:
-prompt_ids = tokenizer(prompt)["input_ids"]
+prompt_ids = tokenizer(truncated_prompt)["input_ids"]
len_prompt_ids = len(prompt_ids)
else:
len_prompt_ids = 0
# Tokenize full text
-input_ids = tokenizer(full_text)["input_ids"]
+input_ids = tokenizer(truncated_prompt + answer)["input_ids"]

-return _package_tokenized_example(False, input_ids, eos_token_id, pad_token_id, seq_length, len_prompt_ids)
+return _package_tokenized_example(False, input_ids, eos_token_id, pad_token_id, max_seq_length, len_prompt_ids)

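A hedged usage sketch of the plain prompt-completion path; the checkpoint is an arbitrary choice, and the eos/pad wiring mirrors the call sites in this diff rather than a documented recipe.

```python
from transformers import AutoTokenizer

from nemo_automodel.components.datasets.llm.formatting_utils import (
    NoContextLeftError,
    format_prompt_completion,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # arbitrary example checkpoint
eos = tokenizer.eos_token_id
pad = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else eos

try:
    example = format_prompt_completion(
        tokenizer,
        prompt="Context: Paris is the capital of France. Question: What is the capital of France?",
        answer=" Paris",
        eos_token_id=eos,
        pad_token_id=pad,
        max_seq_length=64,  # prompt is left-trimmed at spaces until this fits
    )
    assert len(example["input_ids"]) <= 64  # no padding is applied
except NoContextLeftError:
    # Raised when even an empty context cannot satisfy the budget.
    pass
```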

def format_chat_template(
@@ -149,7 +232,7 @@ def format_chat_template(
answer: str,
eos_token_id: int,
pad_token_id: int,
-seq_length: Optional[int] = None,
+max_seq_length: Optional[int] = None,
start_of_turn_token: Optional[str] = None,
) -> Dict[str, List[int]]:
"""
@@ -161,14 +244,19 @@ def format_chat_template(
answer: The answer string.
eos_token_id: The end-of-sequence token id.
pad_token_id: The padding token id.
-seq_length: Optional sequence length for padding.
+max_seq_length: Optional maximum sequence length. If the packaged
+sequence exceeds this length, context is removed from the left at
+space boundaries. No padding is applied.
start_of_turn_token: The start of turn token string.

Returns:
A dictionary with the formatted example.
"""
truncated_prompt = _truncate_prompt_to_fit_chat(
tokenizer, prompt, answer, eos_token_id, pad_token_id, max_seq_length, start_of_turn_token
)
formatted_text = [
{"role": "user", "content": prompt},
{"role": "user", "content": truncated_prompt},
{"role": "assistant", "content": answer},
]
input_ids = tokenizer.apply_chat_template(formatted_text)
@@ -180,4 +268,4 @@ def format_chat_template(
response_start = input_ids.index(start_of_turn_token_id, first_start_of_turn_token_id + 1)
else:
response_start = 0
-return _package_tokenized_example(True, input_ids, eos_token_id, pad_token_id, seq_length, response_start)
+return _package_tokenized_example(True, input_ids, eos_token_id, pad_token_id, max_seq_length, response_start)
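And a hedged counterpart for the chat-template path; the checkpoint and start-of-turn token below are placeholders for whatever template family is actually in use.

```python
from transformers import AutoTokenizer

from nemo_automodel.components.datasets.llm.formatting_utils import format_chat_template

# Placeholder checkpoint: any tokenizer that ships a chat template should do.
tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")
example = format_chat_template(
    tok,
    prompt="What is the capital of France?",
    answer="Paris.",
    eos_token_id=tok.eos_token_id,
    pad_token_id=tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id,
    max_seq_length=256,
    start_of_turn_token="<|im_start|>",  # assumption for ChatML-style templates
)
```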
17 changes: 7 additions & 10 deletions nemo_automodel/components/datasets/llm/squad.py
@@ -34,7 +34,7 @@ def _formatting_prompts_func(example, tokenizer, eos_token_id, pad_token_id, seq
answer=answer,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
-seq_length=seq_length,
+max_seq_length=seq_length,
)


@@ -52,17 +52,16 @@ def _formatting_prompts_func_with_chat_template(
answer=answer,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
-seq_length=seq_length,
+max_seq_length=seq_length,
start_of_turn_token=start_of_turn_token,
)


def make_squad_dataset(
tokenizer,
-seq_length=None,
+max_seq_length=None,
limit_dataset_samples=None,
start_of_turn_token=None,
-fp8=False,
split="train",
dataset_name="squad",
):
@@ -79,15 +78,13 @@
tokenizer: A Hugging Face tokenizer with attributes
`eos_token_id`, optional `bos_id`, optional `eos_id`, and
optionally `chat_template`/`apply_chat_template`.
-seq_length (int, optional): If set, pad/truncate each example to this
-length.
+max_seq_length (int, optional): If set, each example longer than this
+length is truncated; shorter sequences are left as is.
limit_dataset_samples (int, optional): If set, limit the number of
examples loaded from the split.
start_of_turn_token (str or None): If using a chat template, the
token that marks the start of each turn. Used to compute the
response offset for `labels`.
-fp8 (bool): Flag for future use (e.g., mixed precision). Currently
-unused.
split (str): Which split of the dataset to load (e.g. 'train',
'validation').
dataset_name (str): Identifier for the Hugging Face dataset
@@ -116,10 +113,10 @@ def make_squad_dataset(
pad_token_id = _add_pad_token(tokenizer) or eos_token_id

if chat_template is None:
-fmt_fn = lambda x: _formatting_prompts_func(x, tokenizer, eos_token_id, pad_token_id, seq_length)
+fmt_fn = lambda x: _formatting_prompts_func(x, tokenizer, eos_token_id, pad_token_id, max_seq_length)
else:
fmt_fn = lambda x: _formatting_prompts_func_with_chat_template(
-x, tokenizer, eos_token_id, pad_token_id, seq_length, start_of_turn_token
+x, tokenizer, eos_token_id, pad_token_id, max_seq_length, start_of_turn_token
) # noqa: E731

# map the dataset
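Putting it together for SQuAD, a hedged sketch of the updated entry point; the checkpoint is arbitrary, and the argument names follow the signature shown in this diff.

```python
from transformers import AutoTokenizer

from nemo_automodel.components.datasets.llm.squad import make_squad_dataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # arbitrary example checkpoint
ds = make_squad_dataset(
    tokenizer,
    max_seq_length=512,         # renamed from seq_length; truncates, never pads
    limit_dataset_samples=100,  # keep the example small
    split="train",
)
```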
@@ -30,7 +30,7 @@ TRANSFORMERS_OFFLINE=1 python -m torch.distributed.run --nproc_per_node=2 --nnod
--dataset.dataset_name /home/TestData/lite/hf_cache/squad/ \
--validation_dataset.dataset_name /home/TestData/lite/hf_cache/squad/ \
--dataset.limit_dataset_samples 1000 \
---dataset.seq_length 512 \
+--dataset.max_seq_length 512 \
--validation_dataset.seq_length 512 \
--step_scheduler.ckpt_every_steps 10 \
--checkpoint.enabled true \
2 changes: 1 addition & 1 deletion tests/functional_tests/hf_dcp/L2_DCP_PP2_Checkpoint.sh
@@ -30,7 +30,7 @@ TRANSFORMERS_OFFLINE=1 python -m torch.distributed.run --nproc_per_node=2 --nnod
--dataset.dataset_name /home/TestData/lite/hf_cache/squad/ \
--validation_dataset.dataset_name /home/TestData/lite/hf_cache/squad/ \
--dataset.limit_dataset_samples 1000 \
---dataset.seq_length 512 \
+--dataset.max_seq_length 512 \
--validation_dataset.seq_length 512 \
--step_scheduler.ckpt_every_steps 10 \
--checkpoint.enabled true \