
dataloader_persistent_workers=True causes fork-bomb due to repeated creation of eval_dataloader #28469

@naba89

Description

System Info

  • transformers version: 4.36.2
  • Platform: macOS-10.16-x86_64-i386-64bit
  • Python version: 3.10.13
  • Huggingface_hub version: 0.20.2
  • Safetensors version: 0.4.1
  • Accelerate version: 0.26.1
  • Accelerate config: - compute_environment: LOCAL_MACHINE
    - distributed_type: NO
    - mixed_precision: fp16
    - use_cpu: False
    - debug: False
    - num_processes: 1
    - machine_rank: 0
    - num_machines: 1
    - rdzv_backend: static
    - same_network: True
    - main_training_function: main
    - downcast_bf16: no
    - tpu_use_cluster: False
    - tpu_use_sudo: False
    - tpu_env: []
  • PyTorch version (GPU?): 2.1.2 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: does not matter
  • Using distributed or parallel set-up in script?: does not matter

Who can help?

@muellerzr @pacman100

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import os
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from transformers import TrainingArguments, Trainer
from transformers.modeling_outputs import BaseModelOutput


# Dummy Dataset
class DummyDataset(Dataset):
    def __init__(self, size=100):
        self.size = size
        self.data = torch.rand(size, 10)  # Random data
        self.labels = torch.randint(0, 2, (size,))  # Binary labels

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return {'input_ids': self.data[idx], 'labels': self.labels[idx]}


@dataclass
class DummyModelOutput(BaseModelOutput):
    loss: torch.Tensor = None
    logits: torch.Tensor = None


# Dummy Model
class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 2)

    def forward(self, input_ids, labels=None) -> DummyModelOutput:
        outputs = self.linear(input_ids)
        loss = F.cross_entropy(outputs, labels)
        return DummyModelOutput(loss=loss, logits=outputs)


if __name__ == '__main__':

    # using wandb, because it logs system metrics periodically
    os.environ["WANDB_PROJECT"] = "dummy_project"

    # Create dataset and model instances
    dataset = DummyDataset(size=1000)
    model = DummyModel()
    
    persistent_workers = False    # set to True to enable persistent workers

    # Training arguments
    training_args = TrainingArguments(
        output_dir="./test_trainer",
        run_name=f'dataloader_persistent_workers={persistent_workers}',
        num_train_epochs=20,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        dataloader_num_workers=8,
        dataloader_persistent_workers=persistent_workers,
        logging_strategy="no",
        evaluation_strategy="epoch",
    )

    # Initialize the custom trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        eval_dataset=dataset,
    )

    # Train the model
    trainer.train()

Expected behavior

Since get_eval_dataloader is called on every evaluate call, with dataloader_persistent_workers=True the worker processes of the previous evaluation dataloaders are never shut down. This leads to a fork-bomb that exhausts system resources and causes instability or a crash.
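For reference, the mechanism can be reproduced with plain PyTorch, independent of Trainer. The sketch below assumes the old loaders stay referenced (here via an explicit list), which mimics what repeatedly creating eval dataloaders amounts to:

# Standalone sketch of the leak mechanism (not the Trainer code path):
# with persistent_workers=True, the worker processes live as long as the
# DataLoader object itself, so building a fresh loader for every evaluation
# adds num_workers processes that are only reclaimed when the old loader
# is garbage collected.
import multiprocessing as mp

import torch
from torch.utils.data import DataLoader, TensorDataset

if __name__ == '__main__':
    dataset = TensorDataset(torch.rand(64, 10))
    loaders = []
    for evaluation in range(3):
        loader = DataLoader(dataset, num_workers=4, persistent_workers=True)
        next(iter(loader))       # spawns 4 workers that stay alive with the loader
        loaders.append(loader)   # lingering reference, as with the recreated eval loaders
        print(f"evaluation {evaluation}: {len(mp.active_children())} live child processes")

The printed count grows by num_workers on each iteration instead of staying constant.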

As you can see in the plots below, generated with the reproduction script (in the wandb system metrics section):

  • With persistent dataloader workers, training is faster (mainly because the training loader does not recreate its worker processes at every epoch), but the evaluation loaders cause the fork-bomb.
  • Without persistent dataloader workers, training is slower, but the number of processes stays constant.

[Plot: wandb system metrics showing process count and training speed with and without persistent dataloader workers]

Having the persistent dataloader option is valuable, but the eval loader logic needs to be fixed: create the eval dataloader once and reuse it, since the eval dataset does not change in the middle of training. A possible user-side workaround is sketched below.
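Until this is fixed in Trainer, one workaround is to cache the default eval dataloader in a Trainer subclass. This is only a sketch; CachedEvalTrainer and _cached_eval_dataloader are names made up here, not part of the library.

# Sketch of a user-side workaround: cache the default eval dataloader so its
# persistent workers are created once and reused across evaluate() calls.
# CachedEvalTrainer and _cached_eval_dataloader are hypothetical names.
from transformers import Trainer


class CachedEvalTrainer(Trainer):
    def get_eval_dataloader(self, eval_dataset=None):
        # Fall back to the stock behavior when an explicit dataset is passed.
        if eval_dataset is not None:
            return super().get_eval_dataloader(eval_dataset)
        if getattr(self, "_cached_eval_dataloader", None) is None:
            self._cached_eval_dataloader = super().get_eval_dataloader()
        return self._cached_eval_dataloader

Dropping this subclass into the reproduction script in place of Trainer should keep the process count flat while retaining the speedup from persistent training-loader workers.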

This option was added in #27058 and #27189.
