Skip to content

Lightning stalls with 2 GPUs on 1 node with SLURM (and apptainer) #19883

@sorenwacker

Description

@sorenwacker

Bug description

When using 2 GPUs on a single node, or multiple GPUs across multiple nodes, the training does not start, although the job keeps running. I use SLURM and deploy the environment in an Apptainer container. Is there a specific cluster/SLURM configuration required to make this work?

What version are you seeing the problem on?

v2.2

How to reproduce the bug

script.py

# Lightning implementation

import os
import torch
import torchvision as tv
import lightning as L
from torch.utils.data import DataLoader


class CIFAR10DataModule(L.LightningDataModule):
    """CIFAR-10 data module: train split for fitting, held-out split for validation and testing.

    Args:
        batch_size: Number of samples per batch in every dataloader.
        num_workers: Number of worker processes used by every dataloader.
    """

    def __init__(self, batch_size: int = 64, num_workers: int = 1):
        super().__init__()
        # Dataset root is taken from DATASETS_ROOT, falling back to ./data
        self.data_dir = os.getenv("DATASETS_ROOT", "./data")
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Both datasets are created lazily in setup()
        self.train_dataset = None
        self.val_dataset = None

    def prepare_data(self):
        """Download both CIFAR-10 splits into ``data_dir`` (no dataset state is kept here)."""
        # Ensure CIFAR10 is downloaded once
        tv.datasets.CIFAR10(root=self.data_dir, train=True, download=True)
        tv.datasets.CIFAR10(root=self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Create the datasets required for ``stage``.

        Args:
            stage: ``None`` builds both splits; ``'fit'`` builds both;
                ``'train'`` builds only the train split; ``'validate'`` and
                ``'test'`` build only the held-out (``train=False``) split.
        """
        # Apply transformations
        transform = tv.transforms.Compose([
            tv.transforms.ToTensor(),
            tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # Setup training and validation datasets
        if stage in (None, 'fit', 'train'):
            self.train_dataset = tv.datasets.CIFAR10(root=self.data_dir, train=True, transform=transform)
        if stage in (None, 'fit', 'validate', 'test'):
            self.val_dataset = tv.datasets.CIFAR10(root=self.data_dir, train=False, transform=transform)

    def train_dataloader(self):
        """Return a shuffled dataloader over the train split."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def _held_out_dataloader(self, stage):
        """Return an unshuffled dataloader over the held-out split, running ``setup(stage)`` first if needed."""
        if self.val_dataset is None:
            self.setup(stage)
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

    def val_dataloader(self):
        """Return the dataloader used for validation."""
        return self._held_out_dataloader('validate')

    def test_dataloader(self):
        """Return the dataloader used for testing.

        The held-out (``train=False``) split doubles as the test set, which is
        what ``setup('test')`` populates.
        """
        return self._held_out_dataloader('test')


class CIFAR10Model(L.LightningModule):
    """ResNet-18 classifier trained from scratch with cross-entropy loss and plain SGD."""

    def __init__(self):
        super().__init__()
        # Use the new weights parameter with None to signify no pretrained weights
        # NOTE(review): torchvision's resnet18 defaults to a 1000-way classifier
        # head while CIFAR-10 has 10 classes — consider num_classes=10.
        self.model = tv.models.resnet18(weights=None)

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.model(x)

    def _batch_loss(self, batch):
        """Return the cross-entropy loss for one ``(inputs, labels)`` batch."""
        inputs, labels = batch
        outputs = self(inputs)
        return torch.nn.functional.cross_entropy(outputs, labels)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one batch."""
        loss = self._batch_loss(batch)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        loss = self._batch_loss(batch)
        # sync_dist=True reduces the logged value across processes
        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=False, sync_dist=True)

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss for one batch (needed for ``Trainer.test``)."""
        loss = self._batch_loss(batch)
        self.log('test_loss', loss, on_step=True, on_epoch=True, prog_bar=False, sync_dist=True)

    def configure_optimizers(self):
        """Return an SGD optimizer over all parameters with learning rate 0.001."""
        return torch.optim.SGD(self.parameters(), lr=0.001)

def main():
    """Fit the CIFAR-10 model under DDP on the SLURM-described cluster, then run the test loop.

    The node count and the number of tasks per node are read from the
    ``SLURM_NNODES`` and ``SLURM_NTASKS_PER_NODE`` environment variables
    (both default to 1) and passed to the Trainer as ``num_nodes`` and
    ``devices``.
    """
    # Data and model
    data_module = CIFAR10DataModule()
    model = CIFAR10Model()

    # Cluster layout as exported by SLURM
    num_nodes = int(os.getenv('SLURM_NNODES', '1'))
    tasks_per_node = int(os.getenv('SLURM_NTASKS_PER_NODE', '1'))

    # Echo the detected layout for debugging
    print(f"SLURM_NNODES: {num_nodes}")
    print(f"SLURM_NTASKS_PER_NODE: {tasks_per_node}")

    # One device per SLURM task on each node
    trainer_kwargs = {
        'max_epochs': 5,
        'devices': tasks_per_node,
        'accelerator': 'gpu',
        'num_nodes': num_nodes,
        'strategy': 'ddp',
        'enable_progress_bar': True,
    }
    trainer = L.Trainer(**trainer_kwargs)

    # Train, then evaluate on the test set and report the metrics
    trainer.fit(model, datamodule=data_module)
    print(trainer.test(datamodule=data_module))

# Script entry point: run training and evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()

sbatch submission script

#!/bin/sh
# Submit script.py as a SLURM job running inside an Apptainer container.
# script.py reads SLURM_NNODES and SLURM_NTASKS_PER_NODE to size the Trainer,
# so --nodes / --ntasks-per-node below determine num_nodes / devices.
#SBATCH --job-name=cifar-lit-2GPU
...
#SBATCH --time=0:30:00
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=2
#SBATCH --cpus-per-task=20
#SBATCH --gres=gpu:2
#SBATCH --mem=4GB
#SBATCH --output=slurm-%x-%j.out
#SBATCH --error=slurm-%x-%j.err

# Start measuring execution time
start_time=$(date +%s)

# NOTE(review): Apptainer itself reads APPTAINER_HOME as the default for its
# --home option — confirm that overriding the container home directory is
# intended here, or rename these two variables.
export APPTAINER_HOME=/tudelft.net/staff-umbrella/reit/apptainer
export APPTAINER_NAME=pytorch2.2.1-cuda12.1.sif
container_image="$APPTAINER_HOME/$APPTAINER_NAME"

# Abort early, with an explicit message, if the container image is missing
if [ ! -f "$container_image" ]; then
    echo "Container image not found: $container_image" >&2
    exit 1
fi

# Load CUDA that is compatible to container libraries
module use /opt/insy/modulefiles
module load cuda/12.1

# Run script: srun starts one process per SLURM task.
# Tip: export NCCL_DEBUG=INFO before this line to get NCCL diagnostics when
# distributed start-up hangs.
srun apptainer exec \
    --nv \
    --env-file ~/.env \
    -B /home/:/home/ \
    -B /tudelft.net/:/tudelft.net/ \
    "$container_image" \
    python script.py
# Remember the training exit status so the timing report below cannot mask a failure
srun_status=$?

# End measuring execution time
end_time=$(date +%s)

elapsed_time=$((end_time - start_time))
echo "Elapsed time: $elapsed_time seconds"

# Propagate the training result as the job's exit status
exit "$srun_status"

Error messages and logs

==> slurm-cifar-lit-2GPU-10065210.err <==
HPU available: False, using: 0 HPUs
/opt/conda/envs/__apptainer__/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
You are using a CUDA device ('NVIDIA A40') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------
==> slurm-cifar-lit-2GPU-10065210.out <==
SLURM_NNODES: 1
SLURM_NTASKS_PER_NODE: 2
SLURM_NNODES: 1
SLURM_NTASKS_PER_NODE: 2
Files already downloaded and verified
Files already downloaded and verified

Environment

name: __apptainer__
channels:
  - pytorch
  - anaconda
  - nvidia
  - conda-forge
dependencies:
  - _libgcc_mutex=0.1
  - _openmp_mutex=4.5
  - accelerate=0.30.1
  - aiohttp=3.9.5
  - aiosignal=1.3.1
  - alsa-lib=1.2.11
  - annotated-types=0.6.0
  - anyio=4.3.0
  - aom=3.9.0
  - appdirs=1.4.4
  - argon2-cffi=23.1.0
  - argon2-cffi-bindings=21.2.0
  - arrow=1.3.0
  - asttokens=2.4.1
  - async-lru=2.0.4
  - attr=2.5.1
  - attrs=23.2.0
  - aws-c-auth=0.7.20
  - aws-c-cal=0.6.12
  - aws-c-common=0.9.17
  - aws-c-compression=0.2.18
  - aws-c-event-stream=0.4.2
  - aws-c-http=0.8.1
  - aws-c-io=0.14.8
  - aws-c-mqtt=0.10.4
  - aws-c-s3=0.5.8
  - aws-c-sdkutils=0.1.16
  - aws-checksums=0.1.18
  - aws-crt-cpp=0.26.8
  - aws-sdk-cpp=1.11.267
  - babel=2.14.0
  - beautifulsoup4=4.12.3
  - blas=2.116
  - blas-devel=3.9.0
  - bleach=6.1.0
  - bokeh=3.4.1
  - brotli=1.1.0
  - brotli-bin=1.1.0
  - brotli-python=1.1.0
  - bzip2=1.0.8
  - c-ares=1.28.1
  - ca-certificates=2024.2.2
  - cached-property=1.5.2
  - cached_property=1.5.2
  - cairo=1.18.0
  - catalogue=2.0.10
  - certifi=2024.2.2
  - cffi=1.16.0
  - charset-normalizer=3.3.2
  - click=8.1.7
  - cloudpathlib=0.16.0
  - cloudpickle=3.0.0
  - colorama=0.4.6
  - comm=0.2.2
  - confection=0.1.4
  - contourpy=1.2.1
  - cuda-cudart=12.1.105
  - cuda-cudart_linux-64=12.1.105
  - cuda-cupti=12.1.105
  - cuda-libraries=12.1.0
  - cuda-nvrtc=12.1.105
  - cuda-nvtx=12.1.105
  - cuda-opencl=12.1.105
  - cuda-runtime=12.1.0
  - cuda-version=12.1
  - cycler=0.12.1
  - cymem=2.0.8
  - cython-blis=0.7.10
  - cytoolz=0.12.3
  - dask=2024.5.0
  - dask-core=2024.5.0
  - dask-expr=1.1.0
  - datasets=2.19.1
  - dav1d=1.2.1
  - dbus=1.13.6
  - debugpy=1.8.1
  - decorator=5.1.1
  - deepspeed=0.14.0
  - defusedxml=0.7.1
  - diffusers=0.27.2
  - dill=0.3.8
  - distributed=2024.5.0
  - docker-pycreds=0.4.0
  - double-conversion=3.3.0
  - entrypoints=0.4
  - exceptiongroup=1.2.0
  - executing=2.0.1
  - expat=2.6.2
  - fastai=2.7.15
  - fastcore=1.5.35
  - fastdownload=0.0.7
  - fastprogress=1.0.3
  - ffmpeg=6.1.1
  - filelock=3.14.0
  - font-ttf-dejavu-sans-mono=2.37
  - font-ttf-inconsolata=3.000
  - font-ttf-source-code-pro=2.038
  - font-ttf-ubuntu=0.83
  - fontconfig=2.14.2
  - fonts-conda-ecosystem=1
  - fonts-conda-forge=1
  - fonttools=4.51.0
  - fqdn=1.5.1
  - freeglut=3.2.2
  - freetype=2.12.1
  - fribidi=1.0.10
  - frozenlist=1.4.1
  - fsspec=2024.3.1
  - gettext=0.22.5
  - gettext-tools=0.22.5
  - gflags=2.2.2
  - gitdb=4.0.11
  - gitpython=3.1.43
  - glib=2.80.2
  - glib-tools=2.80.2
  - glog=0.7.0
  - gmp=6.3.0
  - gmpy2=2.1.5
  - gnutls=3.7.9
  - graphite2=1.3.13
  - gst-plugins-base=1.22.9
  - gstreamer=1.22.9
  - h11=0.14.0
  - h2=4.1.0
  - harfbuzz=8.4.0
  - hdf5=1.14.3
  - hjson-py=3.1.0
  - hpack=4.0.0
  - httpcore=1.0.5
  - httpx=0.27.0
  - huggingface_hub=0.23.0
  - hyperframe=6.0.1
  - icu=73.2
  - idna=3.7
  - imath=3.1.11
  - importlib-metadata=7.1.0
  - importlib_metadata=7.1.0
  - importlib_resources=6.4.0
  - ipykernel=6.29.3
  - ipython=8.24.0
  - ipywidgets=8.1.2
  - isoduration=20.11.0
  - jasper=4.2.4
  - jedi=0.19.1
  - jinja2=3.1.4
  - joblib=1.4.2
  - json5=0.9.25
  - jsonpointer=2.4
  - jsonschema=4.22.0
  - jsonschema-specifications=2023.12.1
  - jsonschema-with-format-nongpl=4.22.0
  - jupyter-lsp=2.2.5
  - jupyter_client=8.6.1
  - jupyter_core=5.7.2
  - jupyter_events=0.10.0
  - jupyter_server=2.14.0
  - jupyter_server_terminals=0.5.3
  - jupyterlab=4.2.0
  - jupyterlab_pygments=0.3.0
  - jupyterlab_server=2.27.1
  - jupyterlab_widgets=3.0.10
  - keyutils=1.6.1
  - kiwisolver=1.4.5
  - krb5=1.21.2
  - lame=3.100
  - langcodes=3.4.0
  - language-data=1.2.0
  - lcms2=2.16
  - ld_impl_linux-64=2.40
  - lerc=4.0.0
  - libabseil=20240116.2
  - libaec=1.1.3
  - libaio=0.3.113
  - libarrow=16.0.0
  - libarrow-acero=16.0.0
  - libarrow-dataset=16.0.0
  - libarrow-substrait=16.0.0
  - libasprintf=0.22.5
  - libasprintf-devel=0.22.5
  - libass=0.17.1
  - libblas=3.9.0
  - libbrotlicommon=1.1.0
  - libbrotlidec=1.1.0
  - libbrotlienc=1.1.0
  - libcap=2.69
  - libcblas=3.9.0
  - libclang-cpp18.1=18.1.5
  - libclang13=18.1.5
  - libcrc32c=1.1.2
  - libcublas=12.1.0.26
  - libcufft=11.0.2.4
  - libcufile=1.6.1.9
  - libcups=2.3.3
  - libcurand=10.3.2.106
  - libcurl=8.7.1
  - libcusolver=11.4.4.55
  - libcusparse=12.0.2.55
  - libdeflate=1.20
  - libdrm=2.4.120
  - libedit=3.1.20191231
  - libev=4.33
  - libevent=2.1.12
  - libexpat=2.6.2
  - libffi=3.4.2
  - libflac=1.4.3
  - libgcc-ng=13.2.0
  - libgcrypt=1.10.3
  - libgettextpo=0.22.5
  - libgettextpo-devel=0.22.5
  - libgfortran-ng=13.2.0
  - libgfortran5=13.2.0
  - libglib=2.80.2
  - libglu=9.0.0
  - libgoogle-cloud=2.23.0
  - libgoogle-cloud-storage=2.23.0
  - libgpg-error=1.49
  - libgrpc=1.62.2
  - libhwloc=2.10.0
  - libiconv=1.17
  - libidn2=2.3.7
  - libjpeg-turbo=3.0.0
  - liblapack=3.9.0
  - liblapacke=3.9.0
  - libllvm18=18.1.5
  - libnghttp2=1.58.0
  - libnpp=12.0.2.50
  - libnsl=2.0.1
  - libnvjitlink=12.1.105
  - libnvjpeg=12.1.1.14
  - libogg=1.3.4
  - libopencv=4.9.0
  - libopenvino=2024.0.0
  - libopenvino-auto-batch-plugin=2024.0.0
  - libopenvino-auto-plugin=2024.0.0
  - libopenvino-hetero-plugin=2024.0.0
  - libopenvino-intel-cpu-plugin=2024.0.0
  - libopenvino-intel-gpu-plugin=2024.0.0
  - libopenvino-ir-frontend=2024.0.0
  - libopenvino-onnx-frontend=2024.0.0
  - libopenvino-paddle-frontend=2024.0.0
  - libopenvino-pytorch-frontend=2024.0.0
  - libopenvino-tensorflow-frontend=2024.0.0
  - libopenvino-tensorflow-lite-frontend=2024.0.0
  - libopus=1.3.1
  - libparquet=16.0.0
  - libpciaccess=0.18
  - libpng=1.6.43
  - libpq=16.3
  - libprotobuf=4.25.3
  - libre2-11=2023.09.01
  - libsndfile=1.2.2
  - libsodium=1.0.18
  - libsqlite=3.45.3
  - libssh2=1.11.0
  - libstdcxx-ng=13.2.0
  - libsystemd0=255
  - libtasn1=4.19.0
  - libthrift=0.19.0
  - libtiff=4.6.0
  - libunistring=0.9.10
  - libutf8proc=2.8.0
  - libuuid=2.38.1
  - libva=2.21.0
  - libvorbis=1.3.7
  - libvpx=1.14.0
  - libwebp-base=1.4.0
  - libxcb=1.15
  - libxcrypt=4.4.36
  - libxkbcommon=1.7.0
  - libxml2=2.12.7
  - libzlib=1.2.13
  - lightning=2.2.4
  - lightning-utilities=0.11.2
  - llvm-openmp=15.0.7
  - locket=1.0.0
  - lz4=4.3.3
  - lz4-c=1.9.4
  - marisa-trie=1.1.0
  - markdown-it-py=3.0.0
  - markupsafe=2.1.5
  - matplotlib-base=3.8.4
  - matplotlib-inline=0.1.7
  - mdurl=0.1.2
  - mistune=3.0.2
  - mkl=2022.1.0
  - mkl-devel=2022.1.0
  - mkl-include=2022.1.0
  - mpc=1.3.1
  - mpfr=4.2.1
  - mpg123=1.32.6
  - mpmath=1.3.0
  - msgpack-python=1.0.8
  - multidict=6.0.5
  - multiprocess=0.70.16
  - munkres=1.1.4
  - murmurhash=1.0.10
  - mysql-common=8.3.0
  - mysql-libs=8.3.0
  - nbclient=0.10.0
  - nbconvert-core=7.16.4
  - nbformat=5.10.4
  - ncurses=6.5
  - nest-asyncio=1.6.0
  - nettle=3.9.1
  - networkx=3.3
  - notebook-shim=0.2.4
  - numpy=1.26.4
  - ocl-icd=2.3.2
  - onnx=1.16.0
  - opencv=4.9.0
  - openexr=3.2.2
  - openh264=2.4.1
  - openjpeg=2.5.2
  - openssl=3.3.0
  - orc=2.0.0
  - overrides=7.7.0
  - p11-kit=0.24.1
  - packaging=24.0
  - pandas=2.2.2
  - pandocfilters=1.5.0
  - parso=0.8.4
  - partd=1.4.2
  - pathtools=0.1.2
  - pathy=0.10.1
  - patsy=0.5.6
  - pcre2=10.43
  - pexpect=4.9.0
  - pickleshare=0.7.5
  - pillow=10.3.0
  - pip=24.0
  - pixman=0.43.2
  - pkgutil-resolve-name=1.3.10
  - platformdirs=4.2.1
  - plotly=5.22.0
  - preshed=3.0.9
  - pretty_errors=1.2.25
  - prometheus_client=0.20.0
  - prompt-toolkit=3.0.42
  - protobuf=4.25.3
  - psutil=5.9.8
  - pthread-stubs=0.4
  - ptyprocess=0.7.0
  - pugixml=1.14
  - pulseaudio-client=17.0
  - pure_eval=0.2.2
  - py-cpuinfo=9.0.0
  - py-opencv=4.9.0
  - pyarrow=16.0.0
  - pyarrow-core=16.0.0
  - pyarrow-hotfix=0.6
  - pycparser=2.22
  - pydantic=2.7.1
  - pydantic-core=2.18.2
  - pygments=2.18.0
  - pynvml=11.5.0
  - pyparsing=3.1.2
  - pysocks=1.7.1
  - python=3.11.9
  - python-dateutil=2.9.0
  - python-fastjsonschema=2.19.1
  - python-json-logger=2.0.7
  - python-tzdata=2024.1
  - python-xxhash=3.4.1
  - python_abi=3.11
  - pytorch=2.2.1
  - pytorch-cuda=12.1
  - pytorch-lightning=2.2.2
  - pytorch-mutex=1.0
  - pytz=2024.1
  - pyyaml=6.0.1
  - pyzmq=26.0.3
  - qt6-main=6.6.3
  - re2=2023.09.01
  - readline=8.2
  - referencing=0.35.1
  - regex=2024.5.10
  - requests=2.31.0
  - rfc3339-validator=0.1.4
  - rfc3986-validator=0.1.1
  - rich=13.7.1
  - rpds-py=0.18.1
  - s2n=1.4.13
  - safetensors=0.4.3
  - scikit-learn=1.4.2
  - scipy=1.13.0
  - seaborn=0.13.2
  - seaborn-base=0.13.2
  - send2trash=1.8.3
  - sentry-sdk=2.1.1
  - setproctitle=1.3.3
  - setuptools=69.5.1
  - shellingham=1.5.4
  - six=1.16.0
  - smart_open=6.4.0
  - smmap=5.0.0
  - snappy=1.2.0
  - sniffio=1.3.1
  - sortedcontainers=2.4.0
  - soupsieve=2.5
  - spacy=3.7.3
  - spacy-legacy=3.0.12
  - spacy-loggers=1.0.5
  - srsly=2.4.8
  - stack_data=0.6.2
  - statsmodels=0.14.1
  - svt-av1=2.0.0
  - sympy=1.12
  - tbb=2021.12.0
  - tblib=3.0.0
  - tenacity=8.3.0
  - terminado=0.18.1
  - thinc=8.2.3
  - threadpoolctl=3.5.0
  - tinycss2=1.3.0
  - tk=8.6.13
  - tokenizers=0.19.1
  - tomli=2.0.1
  - toolz=0.12.1
  - torchaudio=2.2.1
  - torchmetrics=1.4.0
  - torchtriton=2.2.0
  - torchvision=0.17.1
  - tornado=6.4
  - tqdm=4.66.4
  - traitlets=5.14.3
  - transformers=4.40.2
  - typer=0.9.4
  - types-python-dateutil=2.9.0.20240316
  - typing-extensions=4.11.0
  - typing_extensions=4.11.0
  - typing_utils=0.1.0
  - tzdata=2024a
  - uri-template=1.3.0
  - urllib3=2.2.1
  - wandb=0.16.5
  - wasabi=1.1.2
  - wcwidth=0.2.13
  - weasel=0.3.4
  - webcolors=1.13
  - webencodings=0.5.1
  - websocket-client=1.8.0
  - wheel=0.43.0
  - widgetsnbextension=4.0.10
  - x264=1!164.3095
  - x265=3.5
  - xcb-util=0.4.0
  - xcb-util-cursor=0.1.4
  - xcb-util-image=0.4.0
  - xcb-util-keysyms=0.4.0
  - xcb-util-renderutil=0.3.9
  - xcb-util-wm=0.4.1
  - xkeyboard-config=2.41
  - xorg-fixesproto=5.0
  - xorg-inputproto=2.3.2
  - xorg-kbproto=1.0.7
  - xorg-libice=1.1.1
  - xorg-libsm=1.2.4
  - xorg-libx11=1.8.9
  - xorg-libxau=1.0.11
  - xorg-libxdmcp=1.1.3
  - xorg-libxext=1.3.4
  - xorg-libxfixes=5.0.3
  - xorg-libxi=1.7.10
  - xorg-libxrender=0.9.11
  - xorg-renderproto=0.11.1
  - xorg-xextproto=7.3.0
  - xorg-xproto=7.0.31
  - xxhash=0.8.2
  - xyzservices=2024.4.0
  - xz=5.2.6
  - yaml=0.2.5
  - yarl=1.9.4
  - zeromq=4.3.5
  - zict=3.0.0
  - zipp=3.17.0
  - zlib=1.2.13
  - zstd=1.5.6
  - pip:
      - mpi4py==3.1.6
      - watermark==2.4.3
prefix: /opt/conda/envs/__apptainer__

cc @lantiga @justusschock

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions