Skip to content

Commit f2a583c

Browse files
sijunhewj-Mcatw5688414LemonNoelJunnYu
authored
Support multi-label setup in Text Classification Taskflow (#3968)
* [NewFeatures]add ci & pypi workflow (#3578) * add basic ci & pypi workflow * add Makefile file to enable more custom scripts * update makefile & workflow * add manifest.in file * save requirements * improve makefile * add test running command * complete simple CI test script * complete simple CI test script * complete first version of workflow * failure at first * remove CI workflow * update workflows * add dev-dependency * complete format & lint * add paddle dependency * complete format & lint & test command * upgrae workflow files * use pytest to do test * update need keywords * change the description * update lint script * upgrade workflow * update parameterized * update makefile * fix lint * update test command in workflow * fix __init__ lint * ignore all __init__.py file * fix script workflow * udpate flake8 config * update flake config * update flake8 config * remove redefinition of unused Co-authored-by: Sijun He <[email protected]> * Add multi recall of semantic search for pipelines (#3864) * Add multi recall of semantic search for pipelines * Update multi recall semantic search README.md * remove unused imports * remove unused imports * Update __init__.py * remove unused imports * restore __init__.py * skip retriever __init__.py * [trainer] fix bug when batch size=1 (#3960) * [PPdiffusers] Release ppdiffusers 0.6.3 (#3963) * release 0.6.3 * release 0.6.3 * style * code style * fix test * remove commented code Co-authored-by: 骑马小猫 <[email protected]> Co-authored-by: w5688414 <[email protected]> Co-authored-by: Noel <[email protected]> Co-authored-by: yujun <[email protected]>
1 parent c40e9de commit f2a583c

File tree

3 files changed

+152
-69
lines changed

3 files changed

+152
-69
lines changed

paddlenlp/taskflow/taskflow.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,11 @@
355355
"models": {
356356
"multi_class": {
357357
"task_class": TextClassificationTask,
358-
"task_flag": "text_classification-text_classification",
358+
"task_flag": "text_classification-multi_class",
359+
},
360+
"multi_label": {
361+
"task_class": TextClassificationTask,
362+
"task_flag": "text_classification-multi_label",
359363
},
360364
},
361365
"default": {"model": "multi_class"},

paddlenlp/taskflow/text_classification.py

Lines changed: 95 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -15,46 +15,70 @@
1515

1616
from typing import Any, Dict, List, Union
1717

18-
from paddlenlp.data import DataCollatorWithPadding
19-
from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer
20-
2118
import numpy as np
2219
import paddle
2320
import paddle.nn.functional as F
24-
from .utils import static_mode_guard, dygraph_mode_guard
21+
from scipy.special import expit as np_sigmoid
22+
from scipy.special import softmax as np_softmax
23+
24+
from paddlenlp.data import DataCollatorWithPadding
25+
from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer
26+
2527
from .task import Task
28+
from .utils import dygraph_mode_guard, static_mode_guard
2629

2730
usage = r"""
2831
from paddlenlp import Taskflow
29-
id2label = {
30-
0: "negative",
31-
1: "positive"
32-
}
3332
text_cls = Taskflow(
3433
"text_classification",
3534
model="multi_class",
3635
task_path=<local_saved_model>,
37-
id2label=id2label
36+
id2label={0: "negative", 1: "positive"}
3837
)
3938
text_cls('房间依然很整洁,相当不错')
4039
'''
4140
[
42-
{'text': '房间依然很整洁,相当不错',
43-
'label': 'positive',
44-
'score': 0.80}
41+
{
42+
'text': '房间依然很整洁,相当不错',
43+
'predictions': [{
44+
'label': 'positive',
45+
'score': 0.80
46+
}]
47+
}
4548
]
4649
'''
47-
48-
text_cls(['房间依然很整洁,相当不错',
49-
'味道不咋地,很一般'])
50+
text_cls = Taskflow(
51+
"text_classification",
52+
model="multi_label",
53+
task_path=<local_saved_model>,
54+
id2label={ 0: "体育", 1: "经济", 2: "娱乐"}
55+
)
56+
text_cls(['这是一条体育娱乐新闻的例子',
57+
'这是一条经济新闻'])
5058
'''
5159
[
52-
{'text': '房间依然很整洁,相当不错',
53-
'label': 'positive',
54-
'score': 0.90},
55-
{'text': '味道不咋地,很一般',
56-
'label': 'negative',
57-
'score': 0.88},
60+
{
61+
'text': '这是一条体育娱乐新闻的例子',
62+
'predictions': [
63+
{
64+
'label': '体育',
65+
'score': 0.80
66+
},
67+
{
68+
'label': '娱乐',
69+
'score': 0.90
70+
}
71+
]
72+
},
73+
{
74+
'text': '这是一条经济新闻',
75+
'predictions': [
76+
{
77+
'label': '经济',
78+
'score': 0.80
79+
}
80+
]
81+
}
5882
]
5983
"""
6084

@@ -73,18 +97,29 @@ class TextClassificationTask(Task):
7397
7498
Args:
7599
task (string): The name of task.
76-
model (string): Mode of the classification, only support `multi_class` for now
100+
model (string): Mode of the classification. Supports ["multi_class", "multi_label"]
77101
task_path (string): The local file path to the model path or a pre-trained model
78102
id2label (dict): The dictionary to map the predictions from class ids to class names
79103
is_static_model (bool): Whether the model is a static model
104+
multilabel_threshold (float): The probability threshold used for the multi_label setup. Only effective if model = "multi_label". Defaults to 0.5
80105
kwargs (dict, optional): Additional keyword arguments passed along to the specific task.
81106
"""
82107

83-
def __init__(self, task: str, model: str, id2label: Dict[int, str], is_static_model: bool = False, **kwargs):
108+
def __init__(
109+
self,
110+
task: str,
111+
model: str,
112+
id2label: Dict[int, str],
113+
is_static_model: bool = False,
114+
multilabel_threshold: float = 0.5,
115+
**kwargs
116+
):
84117
super().__init__(task=task, model=model, is_static_model=is_static_model, **kwargs)
85118
self.id2label = id2label
86119
self.is_static_model = is_static_model
87120
self._construct_tokenizer(self._task_path)
121+
self.multilabel_threshold = multilabel_threshold
122+
88123
if self.is_static_model:
89124
self._get_inference_model()
90125
else:
@@ -135,40 +170,58 @@ def _run_model(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
135170
"""
136171
Run the task model from the outputs of the `_tokenize` function.
137172
"""
138-
# TODO: support multi_label, hierachical classification
139-
model_outputs = []
173+
# TODO: support hierarchical classification
174+
outputs = {}
175+
outputs["text"] = inputs["text"]
176+
outputs["batch_logits"] = []
140177
if self.is_static_model:
141178
with static_mode_guard():
142179
for batch in inputs["batches"]:
143180
for i, input_name in enumerate(self.predictor.get_input_names()):
144181
self.input_handles[i].copy_from_cpu(batch[input_name])
145182
self.predictor.run()
146183
logits = self.output_handle[0].copy_to_cpu().tolist()
147-
pred_indices = np.argmax(logits, axis=-1)
148-
probs = softmax(logits, axis=-1)
149-
for prob, pred_index in zip(probs, pred_indices):
150-
model_outputs.append({"label": pred_index, "score": prob[pred_index]})
184+
outputs["batch_logits"].append(logits)
151185
else:
152186
with dygraph_mode_guard():
153187
for batch in inputs["batches"]:
154188
logits = self._model(**batch)
155-
probs = F.softmax(logits, axis=-1).tolist()
156-
pred_indices = paddle.argmax(logits, axis=-1).tolist()
157-
for prob, pred_index in zip(probs, pred_indices):
158-
model_outputs.append({"label": pred_index, "score": prob[pred_index]})
159-
outputs = {}
160-
outputs["text"] = inputs["text"]
161-
outputs["model_outputs"] = model_outputs
189+
outputs["batch_logits"].append(logits)
162190
return outputs
163191

164192
def _postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
165193
"""
166-
The model output is tag ids, this function will convert the model output to raw text.
194+
This function converts the model logits output to class score and predictions
167195
"""
168-
# TODO: support multi_label, hierachical classification
196+
# TODO: support hierarchical classification
169197
postprocessed_outputs = []
170-
for i, model_output in enumerate(inputs["model_outputs"]):
171-
model_output["label"] = self.id2label[model_output["label"]]
172-
model_output["text"] = inputs["text"][i]
173-
postprocessed_outputs.append(model_output)
198+
for logits in inputs["batch_logits"]:
199+
if self.model == "multi_class":
200+
if isinstance(logits, paddle.Tensor): # dygraph
201+
scores = F.softmax(logits, axis=-1).numpy()
202+
labels = paddle.argmax(logits, axis=-1).numpy()
203+
else: # static graph
204+
scores = np_softmax(logits, axis=-1)
205+
labels = np.argmax(logits, axis=-1)
206+
for score, label in zip(scores, labels):
207+
postprocessed_output = {}
208+
postprocessed_output["predictions"] = [{"label": self.id2label[label], "score": score[label]}]
209+
postprocessed_outputs.append(postprocessed_output)
210+
else: # multi_label
211+
if isinstance(logits, paddle.Tensor): # dygraph
212+
scores = F.sigmoid(logits).numpy()
213+
else: # static graph
214+
scores = np_sigmoid(logits)
215+
for score in scores:
216+
postprocessed_output = {}
217+
postprocessed_output["predictions"] = []
218+
for i, class_score in enumerate(score):
219+
if class_score > self.multilabel_threshold:
220+
postprocessed_output["predictions"].append(
221+
{"label": self.id2label[i], "score": class_score}
222+
)
223+
postprocessed_outputs.append(postprocessed_output)
224+
225+
for i, postprocessed_output in enumerate(postprocessed_outputs):
226+
postprocessed_output["text"] = inputs["text"][i]
174227
return postprocessed_outputs

tests/taskflow/test_text_classification.py

Lines changed: 52 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,47 +14,57 @@
1414

1515
import os
1616
import unittest
17+
from tempfile import TemporaryDirectory
1718

1819
import paddle
1920
from parameterized import parameterized
20-
from tempfile import TemporaryDirectory
21+
2122
from paddlenlp.taskflow import Taskflow
2223
from paddlenlp.taskflow.text_classification import TextClassificationTask
2324
from paddlenlp.transformers import AutoTokenizer, ErnieForSequenceClassification
2425

2526

2627
class TestTextClassificationTask(unittest.TestCase):
27-
def setUp(self):
28-
self.temp_dir = TemporaryDirectory()
29-
self.dygraph_model_path = os.path.join(self.temp_dir.name, "dygraph")
28+
@classmethod
29+
def setUpClass(cls):
30+
cls.temp_dir = TemporaryDirectory()
31+
cls.dygraph_model_path = os.path.join(cls.temp_dir.name, "dygraph")
3032
model = ErnieForSequenceClassification.from_pretrained("__internal_testing__/ernie", num_classes=2)
3133
tokenizer = AutoTokenizer.from_pretrained("__internal_testing__/ernie")
32-
model.save_pretrained(self.dygraph_model_path)
33-
tokenizer.save_pretrained(self.dygraph_model_path)
34+
model.save_pretrained(cls.dygraph_model_path)
35+
tokenizer.save_pretrained(cls.dygraph_model_path)
3436

3537
# export to static
36-
self.static_model_path = os.path.join(self.temp_dir.name, "static")
38+
cls.static_model_path = os.path.join(cls.temp_dir.name, "static")
3739
input_spec = [
3840
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids"),
3941
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="token_type_ids"),
4042
]
4143
static_model = paddle.jit.to_static(model, input_spec=input_spec)
42-
paddle.jit.save(static_model, self.static_model_path)
43-
tokenizer.save_pretrained(self.static_model_path)
44-
45-
def tearDown(self):
46-
self.temp_dir.cleanup()
47-
48-
@parameterized.expand([(1,), (2,)])
49-
def test_text_classification_task(self, batch_size):
44+
paddle.jit.save(static_model, cls.static_model_path)
45+
tokenizer.save_pretrained(cls.static_model_path)
46+
47+
@classmethod
48+
def tearDownClass(cls):
49+
cls.temp_dir.cleanup()
50+
51+
@parameterized.expand(
52+
[
53+
(1, "multi_class"),
54+
(2, "multi_class"),
55+
(1, "multi_label"),
56+
(2, "multi_label"),
57+
]
58+
)
59+
def test_classification_task(self, batch_size, model):
5060
# input_text is a tuple to simulate the args passed from Taskflow to TextClassificationTask
5161
input_text = (["百度", "深度学习框架", "飞桨", "PaddleNLP"],)
5262
id2label = {
5363
0: "negative",
5464
1: "positive",
5565
}
5666
dygraph_taskflow = TextClassificationTask(
57-
model="multi_class",
67+
model=model,
5868
task="text_classification",
5969
task_path=self.dygraph_model_path,
6070
id2label=id2label,
@@ -63,10 +73,11 @@ def test_text_classification_task(self, batch_size):
6373
)
6474

6575
dygraph_results = dygraph_taskflow(input_text)
76+
6677
self.assertEqual(len(dygraph_results), len(input_text[0]))
6778

6879
static_taskflow = TextClassificationTask(
69-
model="multi_class",
80+
model=model,
7081
task="text_classification",
7182
is_static_model=True,
7283
task_path=self.static_model_path,
@@ -79,18 +90,29 @@ def test_text_classification_task(self, batch_size):
7990
self.assertEqual(len(static_results), len(input_text[0]))
8091

8192
for dygraph_result, static_result in zip(dygraph_results, static_results):
82-
self.assertEqual(dygraph_result["label"], static_result["label"])
83-
self.assertAlmostEqual(dygraph_result["score"], static_result["score"], delta=1e-6)
84-
85-
@parameterized.expand([(1,), (2,)])
86-
def test_taskflow(self, batch_size):
93+
for dygraph_pred, static_pred in zip(dygraph_result["predictions"], static_result["predictions"]):
94+
self.assertEqual(dygraph_pred["label"], static_pred["label"])
95+
self.assertAlmostEqual(dygraph_pred["score"], static_pred["score"], delta=1e-6)
96+
# if multi_label, all predictions should be greater than the threshold
97+
if model == "multi_label":
98+
self.assertGreater(dygraph_pred["score"], dygraph_taskflow.multilabel_threshold)
99+
100+
@parameterized.expand(
101+
[
102+
(1, "multi_class"),
103+
(2, "multi_class"),
104+
(1, "multi_label"),
105+
(2, "multi_label"),
106+
]
107+
)
108+
def test_taskflow(self, batch_size, model):
87109
input_text = ["百度", "深度学习框架", "飞桨", "PaddleNLP"]
88110
id2label = {
89111
0: "negative",
90112
1: "positive",
91113
}
92114
dygraph_taskflow = Taskflow(
93-
model="multi_class",
115+
model=model,
94116
task="text_classification",
95117
task_path=self.dygraph_model_path,
96118
id2label=id2label,
@@ -101,7 +123,7 @@ def test_taskflow(self, batch_size):
101123
self.assertEqual(len(dygraph_results), len(input_text))
102124

103125
static_taskflow = Taskflow(
104-
model="multi_class",
126+
model=model,
105127
task="text_classification",
106128
is_static_model=True,
107129
task_path=self.static_model_path,
@@ -113,5 +135,9 @@ def test_taskflow(self, batch_size):
113135
self.assertEqual(len(static_results), len(input_text))
114136

115137
for dygraph_result, static_result in zip(dygraph_results, static_results):
116-
self.assertEqual(dygraph_result["label"], static_result["label"])
117-
self.assertAlmostEqual(dygraph_result["score"], static_result["score"], delta=1e-6)
138+
for dygraph_pred, static_pred in zip(dygraph_result["predictions"], static_result["predictions"]):
139+
self.assertEqual(dygraph_pred["label"], static_pred["label"])
140+
self.assertAlmostEqual(dygraph_pred["score"], static_pred["score"], delta=1e-6)
141+
# if multi_label, all predictions should be greater than the threshold
142+
if model == "multi_label":
143+
self.assertGreater(dygraph_pred["score"], dygraph_taskflow.task_instance.multilabel_threshold)

0 commit comments

Comments
 (0)