Description
System Info
- `transformers` version: 4.36.2
- Platform: macOS-10.16-x86_64-i386-64bit
- Python version: 3.10.13
- Huggingface_hub version: 0.20.2
- Safetensors version: 0.4.1
- Accelerate version: 0.26.1
- Accelerate config:
  - compute_environment: LOCAL_MACHINE
  - distributed_type: NO
  - mixed_precision: fp16
  - use_cpu: False
  - debug: False
  - num_processes: 1
  - machine_rank: 0
  - num_machines: 1
  - rdzv_backend: static
  - same_network: True
  - main_training_function: main
  - downcast_bf16: no
  - tpu_use_cluster: False
  - tpu_use_sudo: False
  - tpu_env: []
- PyTorch version (GPU?): 2.1.2 (False)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: does not matter
- Using distributed or parallel set-up in script?: does not matter
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
```python
import os
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

from transformers import TrainingArguments, Trainer
from transformers.modeling_outputs import BaseModelOutput


# Dummy dataset with random features and binary labels
class DummyDataset(Dataset):
    def __init__(self, size=100):
        self.size = size
        self.data = torch.rand(size, 10)  # Random data
        self.labels = torch.randint(0, 2, (size,))  # Binary labels

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return {'input_ids': self.data[idx], 'labels': self.labels[idx]}


@dataclass
class DummyModelOutput(BaseModelOutput):
    loss: torch.Tensor = None
    logits: torch.Tensor = None


# Dummy model: a single linear layer
class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 2)

    def forward(self, input_ids, labels=None) -> DummyModelOutput:
        logits = self.linear(input_ids)
        loss = F.cross_entropy(logits, labels) if labels is not None else None
        return DummyModelOutput(loss=loss, logits=logits)


if __name__ == '__main__':
    # Using wandb because it logs system metrics (e.g. process count) periodically
    os.environ["WANDB_PROJECT"] = "dummy_project"

    # Create dataset and model instances
    dataset = DummyDataset(size=1000)
    model = DummyModel()

    persistent_workers = False  # set to True to enable persistent workers

    # Training arguments
    training_args = TrainingArguments(
        output_dir="./test_trainer",
        run_name=f'dataloader_persistent_workers={persistent_workers}',
        num_train_epochs=20,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        dataloader_num_workers=8,
        dataloader_persistent_workers=persistent_workers,
        logging_strategy="no",
        evaluation_strategy="epoch",
    )

    # Initialize the trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        eval_dataset=dataset,
    )

    # Train the model (evaluation runs every epoch)
    trainer.train()
```
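As a side note, the growth in worker processes can also be checked without wandb. Below is a rough sketch for illustration only (it assumes `psutil` is installed, which the reproduction script itself does not require) of a `TrainerCallback` that prints the number of child processes after each evaluation:

```python
import psutil
from transformers import TrainerCallback


class WorkerCountCallback(TrainerCallback):
    """Prints how many child processes (dataloader workers) exist after each evaluation."""

    def on_evaluate(self, args, state, control, **kwargs):
        n_children = len(psutil.Process().children(recursive=True))
        print(f"epoch {state.epoch}: {n_children} child processes")
```

Passing `callbacks=[WorkerCountCallback()]` to the `Trainer` above should show the count growing after every epoch when `dataloader_persistent_workers=True`, and staying flat otherwise.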
Expected behavior
Since `get_eval_dataloader` is called on every `evaluate` call, with `dataloader_persistent_workers=True` the previous worker processes are never terminated. This leads to a fork bomb that exhausts system resources and causes instability or crashes.
As the plots generated with the reproduction script show (see the wandb system metrics section):
- With persistent data loader workers, training speeds up (mainly because the training loader does not recreate all worker processes at every epoch), but the evaluation loaders cause the fork bomb.
- Without persistent data loader workers, training is slower, but the number of processes stays constant.
Having the persistent dataloader option is valuable, but the eval loader logic needs to be fixed so that the loader is created once and reused, since the eval dataset does not change in the middle of training. A sketch of a possible workaround follows.
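For illustration only, here is a minimal sketch of such a workaround on the user side (the subclass name and the `_cached_eval_dataloader` attribute are mine, not part of `transformers`), assuming a single, fixed eval dataset:

```python
from torch.utils.data import DataLoader
from transformers import Trainer


class CachedEvalTrainer(Trainer):
    """Trainer that builds the eval dataloader once and reuses it on every evaluate() call."""

    def get_eval_dataloader(self, eval_dataset=None) -> DataLoader:
        # Only cache the default eval dataset; an explicitly passed eval_dataset still gets a fresh loader.
        if eval_dataset is None and getattr(self, "_cached_eval_dataloader", None) is not None:
            return self._cached_eval_dataloader
        dataloader = super().get_eval_dataloader(eval_dataset)
        if eval_dataset is None:
            self._cached_eval_dataloader = dataloader
        return dataloader
```

Using `CachedEvalTrainer` in place of `Trainer` in the reproduction script should keep the number of worker processes bounded even with `dataloader_persistent_workers=True`; the proper fix would of course live inside `Trainer` itself.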