Skip to content

Commit 398d229

Browse files
authored
Add a text translation example (#1283)
1 parent d9fee4f commit 398d229

File tree

1 file changed

+120
-0
lines changed

1 file changed

+120
-0
lines changed

examples/text_translation.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
import numpy as np
2+
import torch
3+
from datasets import load_dataset
4+
from torchtext.data.metrics import bleu_score
5+
from transformers import AutoTokenizer, T5ForConditionalGeneration
6+
7+
from mmengine.evaluator import BaseMetric
8+
from mmengine.model import BaseModel
9+
from mmengine.runner import Runner
10+
11+
tokenizer = AutoTokenizer.from_pretrained('t5-small')
12+
13+
14+
class MMT5ForTranslation(BaseModel):
    """Adapt a HuggingFace ``T5ForConditionalGeneration`` to mmengine.

    mmengine's ``Runner`` drives training/validation through ``forward``
    with an explicit ``mode`` argument, so this wrapper translates that
    protocol onto the underlying HuggingFace model.

    Args:
        model: A ``T5ForConditionalGeneration`` instance.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, label, input_ids, attention_mask, mode):
        """Run the wrapped model in training or prediction mode.

        Args:
            label: Target token ids; ignored positions are set to ``-100``
                (see ``collate_fn``).
            input_ids: Source token ids.
            attention_mask: Attention mask matching ``input_ids``.
            mode: ``'loss'`` during training, ``'predict'`` during
                evaluation.

        Returns:
            ``{'loss': Tensor}`` in ``'loss'`` mode, or a
            ``(generated_ids, label)`` tuple in ``'predict'`` mode.

        Raises:
            RuntimeError: If ``mode`` is not one of the supported values.
                (The original code silently returned ``None`` here, which
                surfaces later as a confusing downstream error.)
        """
        if mode == 'loss':
            output = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=label)
            return {'loss': output.loss}
        elif mode == 'predict':
            output = self.model.generate(input_ids)
            return output, label
        raise RuntimeError(f'Unsupported forward mode: {mode!r}')
30+
31+
32+
def post_process(preds, labels):
    """Decode token-id tensors into whitespace-split token lists.

    Prepares generated ids and reference labels for ``bleu_score``:
    predictions become ``list[list[str]]`` and references become
    ``list[list[list[str]]]`` (one reference per sample). Positions in
    ``labels`` holding the loss-ignore value ``-100`` are restored to the
    pad token id so the tokenizer can decode them.
    """
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    restored = torch.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_refs = tokenizer.batch_decode(restored, skip_special_tokens=True)
    tokenized_preds = [sentence.split() for sentence in decoded_preds]
    tokenized_refs = [[sentence.split()] for sentence in decoded_refs]
    return tokenized_preds, tokenized_refs
39+
40+
41+
class Accuracy(BaseMetric):
    """Per-batch BLEU score and mean generated length.

    NOTE(review): despite the name, this metric reports BLEU (via
    ``torchtext``) and average generation length, not accuracy. The name
    is kept because the Runner config references it by type.
    """

    def process(self, data_batch, data_samples):
        """Record BLEU and mean non-pad generation length for one batch."""
        generated, references = data_samples
        pred_tokens, ref_tokens = post_process(generated, references)
        bleu = bleu_score(pred_tokens, ref_tokens)
        # Length of each generated sequence, excluding padding tokens.
        lengths = torch.tensor(
            [torch.count_nonzero(seq != tokenizer.pad_token_id)
             for seq in generated],
            dtype=torch.float64)
        self.results.append({
            'gen_len': torch.mean(lengths).item(),
            'bleu': bleu,
        })

    def compute_metrics(self, results):
        """Average the per-batch statistics collected by ``process``."""
        gen_lens = [entry['gen_len'] for entry in results]
        bleus = [entry['bleu'] for entry in results]
        return dict(
            gen_len=np.mean(gen_lens),
            bleu_score=np.mean(bleus),
        )
64+
65+
66+
def collate_fn(data):
    """Collate raw opus_books samples into model-ready tensors.

    Prepends the T5 task prefix to every English source sentence,
    tokenizes sources and French targets with dynamic (longest-in-batch)
    padding, and masks padding positions in the labels with ``-100`` so
    they do not contribute to the loss.

    Returns:
        dict with ``label``, ``input_ids`` and ``attention_mask`` tensors.
    """
    task_prefix = 'translate English to French: '
    sources = [task_prefix + sample['translation']['en'] for sample in data]
    targets = [sample['translation']['fr'] for sample in data]

    encoded = tokenizer(
        sources,
        padding='longest',
        return_tensors='pt',
    )
    target_ids = tokenizer(
        targets,
        padding='longest',
        return_tensors='pt',
    ).input_ids
    # Padding must be ignored by the loss function.
    target_ids[target_ids == tokenizer.pad_token_id] = -100

    return dict(
        label=target_ids,
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask)
87+
88+
89+
def main():
    """Fine-tune ``t5-small`` on English→French opus_books with mmengine."""
    t5 = T5ForConditionalGeneration.from_pretrained('t5-small')

    # 80/20 train/validation split of the en-fr subset.
    books = load_dataset('opus_books', 'en-fr')
    split = books['train'].train_test_split(test_size=0.2)

    train_dataloader = dict(
        batch_size=16,
        dataset=split['train'],
        sampler=dict(type='DefaultSampler', shuffle=True),
        collate_fn=collate_fn)
    val_dataloader = dict(
        batch_size=32,
        dataset=split['test'],
        sampler=dict(type='DefaultSampler', shuffle=False),
        collate_fn=collate_fn)

    runner = Runner(
        model=MMT5ForTranslation(t5),
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        optim_wrapper=dict(optimizer=dict(type=torch.optim.Adam, lr=2e-5)),
        train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
        val_cfg=dict(),
        work_dir='t5_work_dir',
        val_evaluator=dict(type=Accuracy))
    runner.train()
117+
118+
119+
# Script entry point.
if __name__ == '__main__':
    main()

0 commit comments

Comments
 (0)