Skip to content

Commit 90e993f

Browse files
julien-c authored and sgugger committed
pipeline support for device="mps" (or any other string) (huggingface#18494)
* `pipeline` support for `device="mps"` (or any other string) * Simplify `if` nesting * Update src/transformers/pipelines/base.py Co-authored-by: Sylvain Gugger <[email protected]> * Fix? @sgugger * passing `attr=None` is not the same as not passing `attr` 🤯 Co-authored-by: Sylvain Gugger <[email protected]>
1 parent 065e780 commit 90e993f

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

src/transformers/pipelines/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -422,6 +422,7 @@ def pipeline(
422422
revision: Optional[str] = None,
423423
use_fast: bool = True,
424424
use_auth_token: Optional[Union[str, bool]] = None,
425+
device: Optional[Union[int, str, "torch.device"]] = None,
425426
device_map=None,
426427
torch_dtype=None,
427428
trust_remote_code: Optional[bool] = None,
@@ -508,6 +509,9 @@ def pipeline(
508509
use_auth_token (`str` or *bool*, *optional*):
509510
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
510511
when running `huggingface-cli login` (stored in `~/.huggingface`).
512+
device (`int` or `str` or `torch.device`):
513+
Defines the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank like `1`) on which this
514+
pipeline will be allocated.
511515
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
512516
Sent directly as `model_kwargs` (just a simpler shortcut). When `accelerate` library is present, set
513517
`device_map="auto"` to compute the most optimized `device_map` automatically. [More
@@ -811,4 +815,7 @@ def pipeline(
811815
if feature_extractor is not None:
812816
kwargs["feature_extractor"] = feature_extractor
813817

818+
if device is not None:
819+
kwargs["device"] = device
820+
814821
return pipeline_class(model=model, framework=framework, task=task, **kwargs)

src/transformers/pipelines/base.py

Lines changed: 13 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -704,7 +704,7 @@ def predict(self, X):
704704
Reference to the object in charge of parsing supplied pipeline parameters.
705705
device (`int`, *optional*, defaults to -1):
706706
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run the model on
707-
the associated CUDA device id. You can pass native `torch.device` too.
707+
the associated CUDA device id. You can pass native `torch.device` or a `str` too.
708708
binary_output (`bool`, *optional*, defaults to `False`):
709709
Flag indicating if the output the pipeline should happen in a binary format (i.e., pickle) or as raw text.
710710
"""
@@ -747,7 +747,7 @@ def __init__(
747747
framework: Optional[str] = None,
748748
task: str = "",
749749
args_parser: ArgumentHandler = None,
750-
device: int = -1,
750+
device: Union[int, str, "torch.device"] = -1,
751751
binary_output: bool = False,
752752
**kwargs,
753753
):
@@ -760,14 +760,21 @@ def __init__(
760760
self.feature_extractor = feature_extractor
761761
self.modelcard = modelcard
762762
self.framework = framework
763-
if is_torch_available() and isinstance(device, torch.device):
764-
self.device = device
763+
if is_torch_available() and self.framework == "pt":
764+
if isinstance(device, torch.device):
765+
self.device = device
766+
elif isinstance(device, str):
767+
self.device = torch.device(device)
768+
elif device < 0:
769+
self.device = torch.device("cpu")
770+
else:
771+
self.device = torch.device(f"cuda:{device}")
765772
else:
766-
self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else f"cuda:{device}")
773+
self.device = device
767774
self.binary_output = binary_output
768775

769776
# Special handling
770-
if self.framework == "pt" and self.device.type == "cuda":
777+
if self.framework == "pt" and self.device.type != "cpu":
771778
self.model = self.model.to(self.device)
772779

773780
# Update config with task specific parameters

0 commit comments

Comments (0)