Commit c82e017

Misc. fixes for PyTorch QA examples (#16958)
1. Fixes evaluation errors that pop up when you train/eval on SQuAD v2: one was newly encountered, and one was previously reported in #15401 ("Running SQuAD 1.0 sample command raises IndexError") but not completely fixed.
2. Removes boolean arguments that don't use store_true. Please don't use type=bool with argparse: any non-empty string is converted to True, which is clearly not the desired behavior (and it creates a lot of confusion). A short argparse sketch follows below.
3. All no-trainer scripts now save metric values in the same way (with the correct eval_ prefix), consistent with the trainer-based versions.
4. Adds the forgotten model.eval() calls in the no-trainer versions. This improved some results, but not all of them; see the F1 scores and the discussion below.
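For context, a minimal sketch of the argparse pitfall behind item 2; the flag names here are made up for illustration and are not from the scripts:

import argparse

parser = argparse.ArgumentParser()
# Anti-pattern removed by this commit: bool("False") is True, so any non-empty string enables the flag.
parser.add_argument("--broken_flag", type=bool, default=False)
# Fix: a real boolean switch that is True only when the flag is present.
parser.add_argument("--fixed_flag", action="store_true")

args = parser.parse_args(["--broken_flag", "False"])
print(args.broken_flag)  # True -- surprising, because bool("False") == True

args = parser.parse_args([])
print(args.fixed_flag)   # False -- only passing --fixed_flag turns it on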
1 parent 49d5bcb commit c82e017

File tree

6 files changed: +105 -16 lines changed

examples/flax/question-answering/utils_qa.py

Lines changed: 12 additions & 3 deletions
@@ -158,7 +158,7 @@ def postprocess_qa_predictions(
                             "end_logit": end_logits[end_index],
                         }
                     )
-        if version_2_with_negative:
+        if version_2_with_negative and min_null_prediction is not None:
             # Add the minimum null prediction
             prelim_predictions.append(min_null_prediction)
             null_score = min_null_prediction["score"]
@@ -167,7 +167,11 @@ def postprocess_qa_predictions(
         predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]
 
         # Add back the minimum null prediction if it was removed because of its low score.
-        if version_2_with_negative and not any(p["offsets"] == (0, 0) for p in predictions):
+        if (
+            version_2_with_negative
+            and min_null_prediction is not None
+            and not any(p["offsets"] == (0, 0) for p in predictions)
+        ):
             predictions.append(min_null_prediction)
 
         # Use the offsets to gather the answer text in the original context.
@@ -350,9 +354,12 @@ def postprocess_qa_predictions_with_beam_search(
                         start_index >= len(offset_mapping)
                         or end_index >= len(offset_mapping)
                         or offset_mapping[start_index] is None
+                        or len(offset_mapping[start_index]) < 2
                         or offset_mapping[end_index] is None
+                        or len(offset_mapping[end_index]) < 2
                     ):
                         continue
+
                     # Don't consider answers with a length negative or > max_answer_length.
                     if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                         continue
@@ -381,7 +388,9 @@ def postprocess_qa_predictions_with_beam_search(
         # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
         # failure.
         if len(predictions) == 0:
-            predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": -2e-6})
+            # Without predictions min_null_score is going to be None and None will cause an exception later
+            min_null_score = -2e-6
+            predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": min_null_score})
 
         # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using
         # the LogSumExp trick).

examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py

Lines changed: 34 additions & 2 deletions
@@ -19,6 +19,7 @@
 # You can also adapt this script on your own question answering task. Pointers for this are left as comments.
 
 import argparse
+import json
 import logging
 import math
 import os
@@ -60,6 +61,29 @@
 logger = logging.getLogger(__name__)
 
 
+def save_prefixed_metrics(results, output_dir, file_name: str = "all_results.json", metric_key_prefix: str = "eval"):
+    """
+    Save results while prefixing metric names.
+
+    Args:
+        results: (:obj:`dict`):
+            A dictionary of results.
+        output_dir: (:obj:`str`):
+            An output directory.
+        file_name: (:obj:`str`, `optional`, defaults to :obj:`all_results.json`):
+            An output file name.
+        metric_key_prefix: (:obj:`str`, `optional`, defaults to :obj:`eval`):
+            A metric name prefix.
+    """
+    # Prefix all keys with metric_key_prefix + '_'
+    for key in list(results.keys()):
+        if not key.startswith(f"{metric_key_prefix}_"):
+            results[f"{metric_key_prefix}_{key}"] = results.pop(key)
+
+    with open(os.path.join(output_dir, file_name), "w") as f:
+        json.dump(results, f, indent=4)
+
+
 def parse_args():
     parser = argparse.ArgumentParser(description="Finetune a transformers model on a Question Answering task")
     parser.add_argument(
@@ -171,8 +195,7 @@ def parse_args():
     )
     parser.add_argument(
         "--version_2_with_negative",
-        type=bool,
-        default=False,
+        action="store_true",
         help="If true, some of the examples do not have an answer.",
     )
     parser.add_argument(
@@ -807,6 +830,9 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
     all_end_top_log_probs = []
     all_end_top_index = []
     all_cls_logits = []
+
+    model.eval()
+
     for step, batch in enumerate(eval_dataloader):
         with torch.no_grad():
             outputs = model(**batch)
@@ -864,6 +890,9 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
     all_end_top_log_probs = []
    all_end_top_index = []
     all_cls_logits = []
+
+    model.eval()
+
     for step, batch in enumerate(predict_dataloader):
         with torch.no_grad():
             outputs = model(**batch)
@@ -938,6 +967,9 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
             if args.push_to_hub:
                 repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
 
+    logger.info(json.dumps(eval_metric, indent=4))
+    save_prefixed_metrics(eval_metric, args.output_dir)
+
 
 if __name__ == "__main__":
     main()
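As an aside, a small usage sketch of the save_prefixed_metrics helper added above; the metric names, values, and temporary directory are illustrative and not taken from the commit:

# Illustrative only: shows how save_prefixed_metrics (added in this commit) renames keys.
import json
import os
import tempfile

def save_prefixed_metrics(results, output_dir, file_name="all_results.json", metric_key_prefix="eval"):
    # Same logic as the helper added by this commit: prefix any key that lacks the prefix.
    for key in list(results.keys()):
        if not key.startswith(f"{metric_key_prefix}_"):
            results[f"{metric_key_prefix}_{key}"] = results.pop(key)
    with open(os.path.join(output_dir, file_name), "w") as f:
        json.dump(results, f, indent=4)

metrics = {"exact": 73.2, "f1": 76.5, "eval_samples": 10570}  # hypothetical values
with tempfile.TemporaryDirectory() as tmp:
    save_prefixed_metrics(metrics, tmp)
    # all_results.json now contains {"eval_samples": 10570, "eval_exact": 73.2, "eval_f1": 76.5}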

examples/pytorch/question-answering/run_qa_no_trainer.py

Lines changed: 33 additions & 4 deletions
@@ -66,6 +66,29 @@
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
 
 
+def save_prefixed_metrics(results, output_dir, file_name: str = "all_results.json", metric_key_prefix: str = "eval"):
+    """
+    Save results while prefixing metric names.
+
+    Args:
+        results: (:obj:`dict`):
+            A dictionary of results.
+        output_dir: (:obj:`str`):
+            An output directory.
+        file_name: (:obj:`str`, `optional`, defaults to :obj:`all_results.json`):
+            An output file name.
+        metric_key_prefix: (:obj:`str`, `optional`, defaults to :obj:`eval`):
+            A metric name prefix.
+    """
+    # Prefix all keys with metric_key_prefix + '_'
+    for key in list(results.keys()):
+        if not key.startswith(f"{metric_key_prefix}_"):
+            results[f"{metric_key_prefix}_{key}"] = results.pop(key)
+
+    with open(os.path.join(output_dir, file_name), "w") as f:
+        json.dump(results, f, indent=4)
+
+
 def parse_args():
     parser = argparse.ArgumentParser(description="Finetune a transformers model on a Question Answering task")
     parser.add_argument(
@@ -194,8 +217,7 @@ def parse_args():
     )
     parser.add_argument(
         "--version_2_with_negative",
-        type=bool,
-        default=False,
+        action="store_true",
         help="If true, some of the examples do not have an answer.",
     )
     parser.add_argument(
@@ -824,6 +846,9 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
 
     all_start_logits = []
     all_end_logits = []
+
+    model.eval()
+
     for step, batch in enumerate(eval_dataloader):
         with torch.no_grad():
             outputs = model(**batch)
@@ -860,6 +885,9 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
 
     all_start_logits = []
     all_end_logits = []
+
+    model.eval()
+
     for step, batch in enumerate(predict_dataloader):
         with torch.no_grad():
             outputs = model(**batch)
@@ -907,8 +935,9 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
             tokenizer.save_pretrained(args.output_dir)
             if args.push_to_hub:
                 repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
-    with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
-        json.dump({"eval_f1": eval_metric["f1"], "eval_exact": eval_metric["exact"]}, f)
+
+    logger.info(json.dumps(eval_metric, indent=4))
+    save_prefixed_metrics(eval_metric, args.output_dir)
 
 
 if __name__ == "__main__":
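And, for item 4 of the commit message, a generic sketch of the evaluation pattern the no-trainer scripts now follow (model, dataloader, and output names are placeholders); torch.no_grad() only disables gradient tracking, while model.eval() is what switches dropout to inference behavior:

import torch

def evaluate(model, eval_dataloader):
    # Sketch of the pattern added by this commit: without model.eval(), dropout
    # stays active during evaluation and metrics can come out noticeably worse.
    model.eval()  # switch dropout / batch-norm layers to inference behavior
    all_start_logits = []
    for batch in eval_dataloader:
        with torch.no_grad():  # no_grad() alone does NOT disable dropout
            outputs = model(**batch)
        all_start_logits.append(outputs.start_logits)
    return all_start_logits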

examples/pytorch/question-answering/utils_qa.py

Lines changed: 12 additions & 3 deletions
@@ -158,7 +158,7 @@ def postprocess_qa_predictions(
                             "end_logit": end_logits[end_index],
                         }
                     )
-        if version_2_with_negative:
+        if version_2_with_negative and min_null_prediction is not None:
             # Add the minimum null prediction
             prelim_predictions.append(min_null_prediction)
             null_score = min_null_prediction["score"]
@@ -167,7 +167,11 @@ def postprocess_qa_predictions(
         predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]
 
         # Add back the minimum null prediction if it was removed because of its low score.
-        if version_2_with_negative and not any(p["offsets"] == (0, 0) for p in predictions):
+        if (
+            version_2_with_negative
+            and min_null_prediction is not None
+            and not any(p["offsets"] == (0, 0) for p in predictions)
+        ):
             predictions.append(min_null_prediction)
 
         # Use the offsets to gather the answer text in the original context.
@@ -350,9 +354,12 @@ def postprocess_qa_predictions_with_beam_search(
                         start_index >= len(offset_mapping)
                         or end_index >= len(offset_mapping)
                         or offset_mapping[start_index] is None
+                        or len(offset_mapping[start_index]) < 2
                         or offset_mapping[end_index] is None
+                        or len(offset_mapping[end_index]) < 2
                     ):
                         continue
+
                     # Don't consider answers with a length negative or > max_answer_length.
                     if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                         continue
@@ -381,7 +388,9 @@ def postprocess_qa_predictions_with_beam_search(
         # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
         # failure.
         if len(predictions) == 0:
-            predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": -2e-6})
+            # Without predictions min_null_score is going to be None and None will cause an exception later
+            min_null_score = -2e-6
+            predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": min_null_score})
 
         # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using
         # the LogSumExp trick).

examples/pytorch/test_accelerate_examples.py

Lines changed: 2 additions & 1 deletion
@@ -200,7 +200,7 @@ def test_run_squad_no_trainer(self):
         testargs = f"""
             run_qa_no_trainer.py
             --model_name_or_path bert-base-uncased
-            --version_2_with_negative=False
+            --version_2_with_negative
             --train_file tests/fixtures/tests_samples/SQUAD/sample.json
             --validation_file tests/fixtures/tests_samples/SQUAD/sample.json
             --output_dir {tmp_dir}
@@ -216,6 +216,7 @@ def test_run_squad_no_trainer(self):
         with patch.object(sys, "argv", testargs):
             run_squad_no_trainer.main()
         result = get_results(tmp_dir)
+        # Because we use --version_2_with_negative the testing script uses SQuAD v2 metrics.
         self.assertGreaterEqual(result["eval_f1"], 30)
         self.assertGreaterEqual(result["eval_exact"], 30)
         self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))

examples/tensorflow/question-answering/utils_qa.py

Lines changed: 12 additions & 3 deletions
@@ -158,7 +158,7 @@ def postprocess_qa_predictions(
                             "end_logit": end_logits[end_index],
                         }
                     )
-        if version_2_with_negative:
+        if version_2_with_negative and min_null_prediction is not None:
             # Add the minimum null prediction
             prelim_predictions.append(min_null_prediction)
             null_score = min_null_prediction["score"]
@@ -167,7 +167,11 @@ def postprocess_qa_predictions(
         predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]
 
         # Add back the minimum null prediction if it was removed because of its low score.
-        if version_2_with_negative and not any(p["offsets"] == (0, 0) for p in predictions):
+        if (
+            version_2_with_negative
+            and min_null_prediction is not None
+            and not any(p["offsets"] == (0, 0) for p in predictions)
+        ):
             predictions.append(min_null_prediction)
 
         # Use the offsets to gather the answer text in the original context.
@@ -350,9 +354,12 @@ def postprocess_qa_predictions_with_beam_search(
                         start_index >= len(offset_mapping)
                         or end_index >= len(offset_mapping)
                         or offset_mapping[start_index] is None
+                        or len(offset_mapping[start_index]) < 2
                         or offset_mapping[end_index] is None
+                        or len(offset_mapping[end_index]) < 2
                     ):
                         continue
+
                     # Don't consider answers with a length negative or > max_answer_length.
                     if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                         continue
@@ -381,7 +388,9 @@ def postprocess_qa_predictions_with_beam_search(
         # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
         # failure.
         if len(predictions) == 0:
-            predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": -2e-6})
+            # Without predictions min_null_score is going to be None and None will cause an exception later
+            min_null_score = -2e-6
+            predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": min_null_score})
 
         # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using
         # the LogSumExp trick).
