
Commit 79d9a37

fix tokenization logic (#4565)
Confirmed that the PR fixes the performance issue. Force-merging to bypass the known CI bug.
1 parent d6d597d commit 79d9a37

File tree

1 file changed: +64 −37 lines changed


model_zoo/ernie-3.0/utils.py

Lines changed: 64 additions & 37 deletions
@@ -59,11 +59,14 @@ def prepare_train_features(examples, tokenizer, args, dynamic_max_length: Option
         max_length = get_dynamic_max_length(
             examples=tokenized_examples, default_max_length=args.max_seq_length, dynamic_max_length=dynamic_max_length
         )
+        # always pad to max_length
+        tokenized_examples = tokenizer(
+            questions, contexts, stride=args.doc_stride, max_length=max_length, padding="max_length", truncation=True
+        )
     else:
-        max_length = args.max_seq_length
-    tokenized_examples = tokenizer(
-        questions, contexts, stride=args.doc_stride, max_length=max_length, padding="max_length", truncation=True
-    )
+        tokenized_examples = tokenizer(
+            questions, contexts, stride=args.doc_stride, max_length=args.max_seq_length, truncation=True
+        )
 
     # Since one example might give us several features if it has a long context, we need a map from a feature to
     # its corresponding example. This key gives us just that.
@@ -140,11 +143,14 @@ def prepare_validation_features(examples, tokenizer, args, dynamic_max_length: O
         max_length = get_dynamic_max_length(
             examples=tokenized_examples, default_max_length=args.max_seq_length, dynamic_max_length=dynamic_max_length
         )
+        # always pad to max_length
+        tokenized_examples = tokenizer(
+            questions, contexts, stride=args.doc_stride, max_length=max_length, padding="max_length", truncation=True
+        )
     else:
-        max_length = args.max_seq_length
-    tokenized_examples = tokenizer(
-        questions, contexts, stride=args.doc_stride, max_length=max_length, padding="max_length", truncation=True
-    )
+        tokenized_examples = tokenizer(
+            questions, contexts, stride=args.doc_stride, max_length=args.max_seq_length, truncation=True
+        )
     # Since one example might give us several features if it has a long context, we need a map from a feature to
     # its corresponding example. This key gives us just that.
     sample_mapping = tokenized_examples.pop("overflow_to_sample")
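
Both hunks above re-tokenize with a max_length chosen by get_dynamic_max_length, a helper the diff calls but does not show. As a rough sketch of its apparent contract, inferred only from the call sites here (the body below is an assumption, not the repository's actual implementation): given a trial tokenization and a list of candidate lengths, return the smallest candidate that fits the longest sequence, falling back to the default.

    from typing import List

    def get_dynamic_max_length(examples, default_max_length: int, dynamic_max_length: List[int]) -> int:
        # Hypothetical reconstruction -- the real helper lives elsewhere in utils.py.
        ids = examples["input_ids"]
        # The call sites pass both single examples (a flat list of ids) and
        # batches (a list of lists), so normalize to a batch first.
        batch = ids if ids and isinstance(ids[0], list) else [ids]
        longest = max(len(seq) for seq in batch)
        # Smallest candidate bucket that fits; otherwise the default cap.
        for candidate in sorted(dynamic_max_length):
            if longest <= candidate:
                return candidate
        return default_max_length

Under this reading, padding="max_length" in the dynamic branch is what makes every feature in a batch share the chosen bucket length, which is why the added comment says "always pad to max_length".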
@@ -315,9 +321,10 @@ def seq_convert_example(
             max_length = get_dynamic_max_length(
                 examples=temp_example, default_max_length=max_seq_length, dynamic_max_length=dynamic_max_length
             )
+            # always pad to max_length
+            example = tokenizer(example["sentence"], max_length=max_length, padding="max_length", truncation=True)
         else:
-            max_length = max_seq_length
-        example = tokenizer(example["sentence"], max_length=max_length, padding="max_length", truncation=True)
+            example = tokenizer(example["sentence"], max_length=max_seq_length, truncation=True)
     elif "sentence1" in example:
         if dynamic_max_length is not None:
             temp_example = tokenizer(
@@ -329,15 +336,21 @@ def seq_convert_example(
             max_length = get_dynamic_max_length(
                 examples=temp_example, default_max_length=max_seq_length, dynamic_max_length=dynamic_max_length
             )
+            example = tokenizer(
+                example["sentence1"],
+                text_pair=example["sentence2"],
+                max_length=max_length,
+                padding="max_length",
+                truncation=True,
+            )
         else:
-            max_length = max_seq_length
-        example = tokenizer(
-            example["sentence1"],
-            text_pair=example["sentence2"],
-            max_length=max_length,
-            padding="max_length",
-            truncation=True,
-        )
+            example = tokenizer(
+                example["sentence1"],
+                text_pair=example["sentence2"],
+                max_length=max_seq_length,
+                truncation=True,
+            )
+
     if not is_test:
         if "token_type_ids" in example:
             return {"input_ids": example["input_ids"], "token_type_ids": example["token_type_ids"], "labels": label}
@@ -369,16 +382,23 @@ def token_convert_example(
         max_length = get_dynamic_max_length(
             examples=tokenized_input, default_max_length=max_seq_length, dynamic_max_length=dynamic_max_length
         )
+        # always pad to max_length
+        tokenized_input = tokenizer(
+            example,
+            is_split_into_words=True,
+            max_length=max_length,
+            padding="max_length",
+            truncation=True,
+            return_length=return_length,
+        )
     else:
-        max_length = max_seq_length
-    tokenized_input = tokenizer(
-        example,
-        is_split_into_words=True,
-        max_length=max_length,
-        padding="max_length",
-        truncation=True,
-        return_length=return_length,
-    )
+        tokenized_input = tokenizer(
+            example,
+            is_split_into_words=True,
+            max_length=max_seq_length,
+            truncation=True,
+            return_length=return_length,
+        )
 
     # -2 for [CLS] and [SEP]
     if len(tokenized_input["input_ids"]) - 2 < len(labels):
@@ -406,17 +426,24 @@ def token_convert_example(
         max_length = get_dynamic_max_length(
             examples=tokenized_input, default_max_length=max_seq_length, dynamic_max_length=dynamic_max_length
         )
+        # always pad to max_length
+        tokenized_input = tokenizer(
+            example["tokens"],
+            max_length=max_length,
+            padding="max_length",
+            truncation=True,
+            is_split_into_words=True,
+            return_length=return_length,
+        )
     else:
-        max_length = max_seq_length
-
-    tokenized_input = tokenizer(
-        example["tokens"],
-        max_length=max_length,
-        padding="max_length",
-        truncation=True,
-        is_split_into_words=True,
-        return_length=return_length,
-    )
+        tokenized_input = tokenizer(
+            example["tokens"],
+            max_length=max_seq_length,
+            truncation=True,
+            is_split_into_words=True,
+            return_length=return_length,
+        )
+
     label_ids = example["ner_tags"]
     if len(tokenized_input["input_ids"]) - 2 < len(label_ids):
         label_ids = label_ids[: len(tokenized_input["input_ids"]) - 2]
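
The pattern repeated in every hunk: when dynamic_max_length is not given, the old code still padded every example out to max_seq_length, while the new code only truncates, so short inputs yield short feature vectors. That difference is the likely source of the performance win the commit message refers to. A minimal sketch of the behavior change, assuming paddlenlp is installed and using the ernie-3.0-medium-zh checkpoint purely for illustration:

    from paddlenlp.transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("ernie-3.0-medium-zh")
    sentence = "这是一个很短的句子。"

    # Old behavior: every example is padded out to max_length, so downstream
    # compute scales with max_seq_length regardless of the actual input size.
    old = tokenizer(sentence, max_length=128, padding="max_length", truncation=True)

    # New behavior: truncate only; short inputs stay short, and any padding
    # can be applied later, per batch, at the longest length actually present.
    new = tokenizer(sentence, max_length=128, truncation=True)

    print(len(old["input_ids"]))  # 128
    print(len(new["input_ids"]))  # just the real tokens, far fewer than 128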
