
TrainingArguments does not support mps device (Mac M1 GPU) #17971

@saattrupdan

System Info

  • transformers version: 4.21.0.dev0
  • Platform: macOS-12.4-arm64-arm-64bit
  • Python version: 3.8.9
  • Huggingface_hub version: 0.8.1
  • PyTorch version (GPU?): 1.12.0 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

@sgugger

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

export TASK_NAME=wnli
python run_glue.py \
  --model_name_or_path bert-base-cased \
  --task_name $TASK_NAME \
  --do_train \
  --do_eval \
  --max_seq_length 128 \
  --per_device_train_batch_size 32 \
  --learning_rate 2e-5 \
  --num_train_epochs 3 \
  --output_dir /tmp/$TASK_NAME/

Expected behavior

When running Trainer.train on a machine with an MPS GPU, training still runs only on the CPU. I expected it to use the MPS GPU. MPS is supported by PyTorch as of its newest version, 1.12.0, and you can check whether the MPS GPU is available with torch.backends.mps.is_available().
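To separate "PyTorch cannot see the GPU" from "the Trainer ignores it", it helps to check what PyTorch itself reports first. The sketch below uses `torch.backends.mps.is_available()` (mentioned above) together with `torch.backends.mps.is_built()`. The helper name `mps_status` and its injectable `torch_module` parameter are hypothetical and exist only so the logic can be exercised without a Mac.

```python
from typing import Any, Optional, Tuple


def mps_status(torch_module: Optional[Any] = None) -> Tuple[bool, str]:
    """Report whether PyTorch can use the MPS backend (hypothetical helper).

    `torch_module` can be injected for testing; by default the real
    `torch` package is imported lazily.
    """
    if torch_module is None:
        import torch as torch_module  # lazy import: only needed when actually run

    mps = torch_module.backends.mps
    if mps.is_available():
        return True, "MPS is available"
    if not mps.is_built():
        # The installed PyTorch binary was compiled without MPS support.
        return False, "PyTorch was not built with MPS support"
    # Built with MPS, but the device cannot be used here (e.g. macOS older than 12.3).
    return False, "MPS is built but not available on this machine"
```

Calling `mps_status()` on the affected machine shows which case applies. If it reports MPS as available while training still runs on the CPU, the problem lies on the Trainer side rather than in the PyTorch install.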

The issue seems to lie in the TrainingArguments._setup_devices method, which does not appear to handle the case device = "mps".
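For illustration, here is a fallback order that covers the missing case: prefer CUDA, then MPS, then CPU. This is not the actual `_setup_devices` code from transformers. The names `pick_device_name`, `pick_device` and the `force_cpu` flag are hypothetical, and the selection order is an assumption. The decision logic is kept as a pure function so it can be checked without any GPU present.

```python
def pick_device_name(cuda_available: bool, mps_available: bool, force_cpu: bool = False) -> str:
    """Choose a device string: CUDA first, then MPS, then CPU (hypothetical logic)."""
    if force_cpu:
        return "cpu"
    if cuda_available:
        return "cuda"
    if mps_available:
        # The branch the issue reports as missing: Apple-silicon GPUs via MPS.
        return "mps"
    return "cpu"


def pick_device():
    """Build a torch.device from what the installed PyTorch reports (hypothetical)."""
    import torch  # lazy import so pick_device_name works without torch installed

    # torch.backends.mps only exists from PyTorch 1.12 onwards, so guard the lookup.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_ok = mps_backend is not None and mps_backend.is_available()
    return torch.device(pick_device_name(torch.cuda.is_available(), mps_ok))
```

A training loop would then move the model and each batch with `.to(pick_device())`. Keeping CUDA ahead of MPS means CUDA machines behave as before.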
