1515
1616from typing import Any , Dict , List , Union
1717
18- from paddlenlp .data import DataCollatorWithPadding
19- from paddlenlp .transformers import AutoModelForSequenceClassification , AutoTokenizer
20-
2118import numpy as np
2219import paddle
2320import paddle .nn .functional as F
24- from .utils import static_mode_guard , dygraph_mode_guard
21+ from scipy .special import expit as np_sigmoid
22+ from scipy .special import softmax as np_softmax
23+
24+ from paddlenlp .data import DataCollatorWithPadding
25+ from paddlenlp .transformers import AutoModelForSequenceClassification , AutoTokenizer
26+
2527from .task import Task
28+ from .utils import dygraph_mode_guard , static_mode_guard
2629
2730usage = r"""
2831 from paddlenlp import Taskflow
29- id2label = {
30- 0: "negative",
31- 1: "positive"
32- }
3332 text_cls = Taskflow(
3433 "text_classification",
3534 model="multi_class",
3635 task_path=<local_saved_model>,
37- id2label=id2label
36+ id2label={0: "negative", 1: "positive"}
3837 )
3938 text_cls('房间依然很整洁,相当不错')
4039 '''
4140 [
42- {'text': '房间依然很整洁,相当不错',
43- 'label': 'positive',
44- 'score': 0.80}
41+ {
42+ 'text': '房间依然很整洁,相当不错',
43+ 'predictions': [{
44+ 'label': 'positive',
45+ 'score': 0.80
46+ }]
47+ }
4548 ]
4649 '''
47-
48- text_cls(['房间依然很整洁,相当不错',
49- '味道不咋地,很一般'])
50+ text_cls = Taskflow(
51+ "text_classification",
52+ model="multi_label",
53+ task_path=<local_saved_model>,
54+ id2label={ 0: "体育", 1: "经济", 2: "娱乐"}
55+ )
56+ text_cls(['这是一条体育娱乐新闻的例子',
57+ '这是一条经济新闻'])
5058 '''
5159 [
52- {'text': '房间依然很整洁,相当不错',
53- 'label': 'positive',
54- 'score': 0.90},
55- {'text': '味道不咋地,很一般',
56- 'label': 'negative',
57- 'score': 0.88},
60+ {
61+ 'text': '这是一条体育娱乐新闻的例子',
62+ 'predictions': [
63+ {
64+ 'label': '体育',
65+ 'score': 0.80
66+ },
67+ {
68+ 'label': '娱乐',
69+ 'score': 0.90
70+ }
71+ ]
72+ },
73+ {
74+ 'text': '这是一条经济新闻',
75+ 'predictions': [
76+ {
77+ 'label': '经济',
78+ 'score': 0.80
79+ }
80+ ]
81+ }
5882 ]
5983 """
6084
@@ -73,18 +97,29 @@ class TextClassificationTask(Task):
7397
7498 Args:
7599 task (string): The name of task.
76- model (string): Mode of the classification, only support ` multi_class` for now
100+ model (string): Mode of the classification, Supports ["multi_class", "multi_label"]
77101 task_path (string): The local file path to the model path or a pre-trained model
78102 id2label (string): The dictionary to map the predictions from class ids to class names
79103 is_static_model (string): Whether the model is a static model
104+ multilabel_threshold (float): The probability threshold used for the multi_label setup. Only effective if model = "multi_label". Defaults to 0.5
80105 kwargs (dict, optional): Additional keyword arguments passed along to the specific task.
81106 """
82107
83- def __init__ (self , task : str , model : str , id2label : Dict [int , str ], is_static_model : bool = False , ** kwargs ):
108+ def __init__ (
109+ self ,
110+ task : str ,
111+ model : str ,
112+ id2label : Dict [int , str ],
113+ is_static_model : bool = False ,
114+ multilabel_threshold : float = 0.5 ,
115+ ** kwargs
116+ ):
84117 super ().__init__ (task = task , model = model , is_static_model = is_static_model , ** kwargs )
85118 self .id2label = id2label
86119 self .is_static_model = is_static_model
87120 self ._construct_tokenizer (self ._task_path )
121+ self .multilabel_threshold = multilabel_threshold
122+
88123 if self .is_static_model :
89124 self ._get_inference_model ()
90125 else :
def _run_model(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Run inference over the tokenized batches and collect raw logits.

    Args:
        inputs: Output of the tokenize step, holding the original "text" list
            and the tokenized "batches".

    Returns:
        Dict with the passthrough "text" entry plus "batch_logits", one
        logits entry appended per input batch (a nested list in static mode,
        a paddle Tensor in dygraph mode).
    """
    # TODO: support hierarchical classification
    batch_logits = []
    if self.is_static_model:
        # Static graph: feed each batch through the inference predictor.
        with static_mode_guard():
            for batch in inputs["batches"]:
                input_names = self.predictor.get_input_names()
                for idx, name in enumerate(input_names):
                    self.input_handles[idx].copy_from_cpu(batch[name])
                self.predictor.run()
                batch_logits.append(self.output_handle[0].copy_to_cpu().tolist())
    else:
        # Dygraph: call the eager model directly.
        with dygraph_mode_guard():
            for batch in inputs["batches"]:
                batch_logits.append(self._model(**batch))
    return {"text": inputs["text"], "batch_logits": batch_logits}
163191
def _postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Convert raw logits into per-example predictions with scores.

    "multi_class" mode applies softmax and keeps the single argmax label;
    otherwise (multi-label) applies sigmoid and keeps every label whose
    probability exceeds ``self.multilabel_threshold``.

    Args:
        inputs: Output of `_run_model`, holding "text" and "batch_logits"
            (paddle Tensors in dygraph mode, plain nested lists in static mode).

    Returns:
        A list with one dict per input example, each carrying "text" and
        "predictions" (a list of {"label", "score"} dicts).
    """
    # TODO: support hierarchical classification
    results = []
    for logits in inputs["batch_logits"]:
        is_dygraph = isinstance(logits, paddle.Tensor)
        if self.model == "multi_class":
            if is_dygraph:
                probs = F.softmax(logits, axis=-1).numpy()
                winners = paddle.argmax(logits, axis=-1).numpy()
            else:  # static graph hands back plain lists
                probs = np_softmax(logits, axis=-1)
                winners = np.argmax(logits, axis=-1)
            for example_probs, winner in zip(probs, winners):
                results.append(
                    {"predictions": [{"label": self.id2label[winner], "score": example_probs[winner]}]}
                )
        else:  # multi-label: independent per-class probabilities
            probs = F.sigmoid(logits).numpy() if is_dygraph else np_sigmoid(logits)
            for example_probs in probs:
                predictions = [
                    {"label": self.id2label[class_id], "score": prob}
                    for class_id, prob in enumerate(example_probs)
                    if prob > self.multilabel_threshold
                ]
                results.append({"predictions": predictions})

    # Attach the original text back onto each per-example record, in order.
    for idx, record in enumerate(results):
        record["text"] = inputs["text"][idx]
    return results
0 commit comments