Skip to content

Commit f0e3f04

Browse files
committed
Simplify if nesting
1 parent 904a840 commit f0e3f04

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

src/transformers/pipelines/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -510,8 +510,8 @@ def pipeline(
510510
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
511511
when running `huggingface-cli login` (stored in `~/.huggingface`).
512512
device (`int` or `str` or `torch.device`):
513-
Sent directly as `model_kwargs` (just a simpler shortcut). Defines the device (*e.g.*, `"cpu"`, `"cuda:1"`,
514-
`"mps"`, or a GPU ordinal rank like `1`) on which this pipeline will be allocated.
513+
Defines the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank like `1`) on which this
514+
pipeline will be allocated.
515515
device_map (`str` or `Dict[str, Union[int, str, torch.device]`, *optional*):
516516
Sent directly as `model_kwargs` (just a simpler shortcut). When `accelerate` library is present, set
517517
`device_map="auto"` to compute the most optimized `device_map` automatically. [More

src/transformers/pipelines/base.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -760,14 +760,17 @@ def __init__(
760760
self.feature_extractor = feature_extractor
761761
self.modelcard = modelcard
762762
self.framework = framework
763-
if is_torch_available() and isinstance(device, torch.device):
764-
self.device = device
763+
if is_torch_available() and self.framework == "pt":
764+
if isinstance(device, torch.device):
765+
self.device = device
766+
elif type(device) == str:
767+
self.device = torch.device(device)
768+
elif device < 0:
769+
self.device = torch.device("cpu")
770+
else:
771+
self.device = torch.device("cuda:{device}")
765772
else:
766-
self.device = (
767-
device
768-
if framework == "tf"
769-
else torch.device(device if type(device) == str else "cpu" if device < 0 else f"cuda:{device}")
770-
)
773+
self.device = device
771774
self.binary_output = binary_output
772775

773776
# Special handling

0 commit comments

Comments
 (0)