File tree Expand file tree Collapse file tree 2 files changed +19
-2
lines changed
src/transformers/pipelines Expand file tree Collapse file tree 2 files changed +19
-2
lines changed Original file line number Diff line number Diff line change @@ -693,7 +693,7 @@ def predict(self, X):
693
693
Reference to the object in charge of parsing supplied pipeline parameters.
694
694
device (`int`, *optional*, defaults to -1):
695
695
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run the model on
696
- the associated CUDA device id.
696
+ the associated CUDA device id. You can pass native `torch.device` too.
697
697
binary_output (`bool`, *optional*, defaults to `False`):
698
698
Flag indicating if the output the pipeline should happen in a binary format (i.e., pickle) or as raw text.
699
699
"""
@@ -750,7 +750,10 @@ def __init__(
750
750
self .feature_extractor = feature_extractor
751
751
self .modelcard = modelcard
752
752
self .framework = framework
753
- self .device = device if framework == "tf" else torch .device ("cpu" if device < 0 else f"cuda:{ device } " )
753
+ if is_torch_available () and isinstance (device , torch .device ):
754
+ self .device = device
755
+ else :
756
+ self .device = device if framework == "tf" else torch .device ("cpu" if device < 0 else f"cuda:{ device } " )
754
757
self .binary_output = binary_output
755
758
756
759
# Special handling
Original file line number Diff line number Diff line change @@ -39,6 +39,20 @@ def test_small_model_pt(self):
39
39
outputs = text_classifier ("This is great !" )
40
40
self .assertEqual (nested_simplify (outputs ), [{"label" : "LABEL_0" , "score" : 0.504 }])
41
41
42
+ @require_torch
43
+ def test_accepts_torch_device (self ):
44
+ import torch
45
+
46
+ text_classifier = pipeline (
47
+ task = "text-classification" ,
48
+ model = "hf-internal-testing/tiny-random-distilbert" ,
49
+ framework = "pt" ,
50
+ device = torch .device ("cpu" ),
51
+ )
52
+
53
+ outputs = text_classifier ("This is great !" )
54
+ self .assertEqual (nested_simplify (outputs ), [{"label" : "LABEL_0" , "score" : 0.504 }])
55
+
42
56
@require_tf
43
57
def test_small_model_tf (self ):
44
58
text_classifier = pipeline (
You can’t perform that action at this time.
0 commit comments