This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Memory requirements to replicate on Pems-Bay #88

@steve3nto

Description

Thanks to the authors for the great work and for sharing the code!

I am interested in replicating the results on Pems-Bay before trying the model on a custom dataset of similar size.
I set accelerator="gpu" in the Trainer and run the suggested command from the README:

python train.py spacetimeformer pems-bay --batch_size 32 --warmup_steps 1000 --d_model 200 --d_ff 700 --enc_layers 5 --dec_layers 6 --dropout_emb .1 --dropout_ff .3 --run_name pems-bay-spatiotemporal --base_lr 1e-3 --l2_coeff 1e-3 --loss mae --data_path /path/to/pems_bay/ --d_qk 30 --d_v 30 --n_heads 10 --patience 10 --decay_factor .8

This gives an out-of-memory (OOM) error on my GPU, which has more than 20 GB of memory.
How much GPU memory is required? Did you use multiple GPUs to train spacetimeformer on Pems-Bay?

I also tried setting accelerator="dp", but it gives this error:

Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=gloo
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

Traceback (most recent call last):
  File "train.py", line 873, in <module>
    main(args)
  File "train.py", line 851, in main
    trainer.fit(forecaster, datamodule=data_module)
  File "/home/simoscopi/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 771, in fit
    self._call_and_handle_interrupt(
  File "/home/simoscopi/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 722, in _call_and_handle_interrupt
    return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
  File "/home/simoscopi/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 93, in launch
    return function(*args, **kwargs)
  File "/home/simoscopi/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 812, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/home/simoscopi/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1218, in _run
    self.strategy.setup(self)
  File "/home/simoscopi/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp.py", line 172, in setup
    self.configure_ddp()
  File "/home/simoscopi/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp.py", line 294, in configure_ddp
    self.model = self._setup_model(LightningDistributedModule(self.model))
  File "/home/simoscopi/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp.py", line 178, in _setup_model
    return DistributedDataParallel(module=model, device_ids=device_ids, **self._ddp_kwargs)
  File "/home/simoscopi/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 650, in __init__
    self._ddp_init_helper(parameters, expect_sparse_gradient, param_to_name_mapping)
  File "/home/simoscopi/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 738, in _ddp_init_helper
    self._passing_sync_batchnorm_handle(self.module)
  File "/home/simoscopi/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1637, in _passing_sync_batchnorm_handle
    self._log_and_throw(
  File "/home/simoscopi/miniconda3/envs/spacetimeformer/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 674, in _log_and_throw
    raise err_type(err_msg)
ValueError: SyncBatchNorm layers only work with GPU modules
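For reference, the Trainer wiring I am attempting looks roughly like the sketch below. This uses the standard Lightning Trainer flags as I understand them; the exact wiring inside train.py may differ. My understanding is that in recent PyTorch Lightning versions "dp"/"ddp" are strategies rather than accelerators, which may be why passing accelerator="dp" fell back to a CPU/gloo backend and tripped the SyncBatchNorm check above.

```
# Hedged sketch: in recent PyTorch Lightning versions the distribution
# mode is passed as `strategy`, while `accelerator` selects the device.
# The flag names are standard Lightning Trainer arguments; how train.py
# actually forwards them may differ.
import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="gpu",   # run on CUDA devices ("dp" is not a device type)
    devices=2,           # number of GPUs to use
    strategy="ddp",      # one process per GPU; "dp" is single-process
)
```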

I hope someone can help me run spacetimeformer in multi-GPU mode.

If that is not possible, do you have any suggestions for reducing memory requirements while maintaining good performance?
So far I have only been able to fit the model on a single GPU by setting batch_size=4, but I fear this would lead to a poor model fit.
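To frame the question more concretely: would gradient accumulation recover an effective batch of 32 from batch_size=4? Lightning exposes this as the Trainer's accumulate_grad_batches flag, and my understanding of the equivalence (exact for a mean loss with equal-sized micro-batches, modulo batch-dependent layers such as BatchNorm) is sketched below in plain Python. The linear model and numbers are illustrative only.

```python
# Sketch (pure Python, no GPU): averaging gradients over k micro-batches
# reproduces the full-batch gradient of a mean loss, so batch_size=4 with
# accumulate_grad_batches=8 should behave like batch_size=32 for the
# optimizer step (BatchNorm-style statistics aside).

def grad_mse(w, xs, ys):
    """Gradient of the mean squared error of y ~ w*x over a batch."""
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n

xs = [float(i) for i in range(32)]
ys = [3.0 * x + 1.0 for x in xs]
w = 0.5

# One full batch of 32.
full = grad_mse(w, xs, ys)

# Eight micro-batches of 4, then average (what accumulation does).
micro = [grad_mse(w, xs[i:i + 4], ys[i:i + 4]) for i in range(0, 32, 4)]
accumulated = sum(micro) / len(micro)

assert abs(full - accumulated) < 1e-9
```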

Thank you!
