15 changes: 12 additions & 3 deletions examples/flax/question-answering/utils_qa.py
@@ -158,7 +158,7 @@ def postprocess_qa_predictions(
"end_logit": end_logits[end_index],
}
)
if version_2_with_negative:
if version_2_with_negative and min_null_prediction is not None:
# Add the minimum null prediction
prelim_predictions.append(min_null_prediction)
null_score = min_null_prediction["score"]
@@ -167,7 +167,11 @@
predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]

# Add back the minimum null prediction if it was removed because of its low score.
if version_2_with_negative and not any(p["offsets"] == (0, 0) for p in predictions):
if (
version_2_with_negative
and min_null_prediction is not None
and not any(p["offsets"] == (0, 0) for p in predictions)
):
predictions.append(min_null_prediction)

# Use the offsets to gather the answer text in the original context.
@@ -350,9 +354,12 @@ def postprocess_qa_predictions_with_beam_search(
start_index >= len(offset_mapping)
or end_index >= len(offset_mapping)
or offset_mapping[start_index] is None
or len(offset_mapping[start_index]) < 2
or offset_mapping[end_index] is None
or len(offset_mapping[end_index]) < 2
):
continue

# Don't consider answers with a length negative or > max_answer_length.
if end_index < start_index or end_index - start_index + 1 > max_answer_length:
continue
@@ -381,7 +388,9 @@ def postprocess_qa_predictions_with_beam_search(
# In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
# failure.
if len(predictions) == 0:
predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": -2e-6})
# Without predictions min_null_score is going to be None and None will cause an exception later
min_null_score = -2e-6
predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": min_null_score})

# Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using
# the LogSumExp trick).
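The comment in the last hunk mentions computing the softmax of all scores in numpy "using the LogSumExp trick". A minimal numpy sketch of that max-shift stabilization; the function name and the sample scores are illustrative, not the file's actual code:

import numpy as np

def stable_softmax(scores):
    # Shift by the max (the core of the LogSumExp trick) so exp() cannot overflow.
    scores = np.asarray(scores, dtype=np.float64)
    shifted = scores - scores.max()
    exp_scores = np.exp(shifted)
    return exp_scores / exp_scores.sum()

# Illustrative n-best candidate scores
print(stable_softmax([12.3, 10.1, -4.0]))  # probabilities summing to 1.0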
examples/pytorch/question-answering/run_qa_beam_search_no_trainer.py
@@ -19,6 +19,7 @@
# You can also adapt this script on your own question answering task. Pointers for this are left as comments.

import argparse
import json
import logging
import math
import os
@@ -60,6 +61,29 @@
logger = logging.getLogger(__name__)


def save_prefixed_metrics(results, output_dir, file_name: str = "all_results.json", metric_key_prefix: str = "eval"):
"""
Save results while prefixing metric names.

Args:
results: (:obj:`dict`):
A dictionary of results.
output_dir: (:obj:`str`):
An output directory.
file_name: (:obj:`str`, `optional`, defaults to :obj:`all_results.json`):
An output file name.
metric_key_prefix: (:obj:`str`, `optional`, defaults to :obj:`eval`):
A metric name prefix.
"""
# Prefix all keys with metric_key_prefix + '_'
for key in list(results.keys()):
if not key.startswith(f"{metric_key_prefix}_"):
results[f"{metric_key_prefix}_{key}"] = results.pop(key)

with open(os.path.join(output_dir, file_name), "w") as f:
json.dump(results, f, indent=4)


def parse_args():
parser = argparse.ArgumentParser(description="Finetune a transformers model on a Question Answering task")
parser.add_argument(
@@ -171,8 +195,7 @@ def parse_args():
)
parser.add_argument(
"--version_2_with_negative",
type=bool,
default=False,
action="store_true",
help="If true, some of the examples do not have an answer.",
)
parser.add_argument(
@@ -807,6 +830,9 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
all_end_top_log_probs = []
all_end_top_index = []
all_cls_logits = []

model.eval()

for step, batch in enumerate(eval_dataloader):
with torch.no_grad():
outputs = model(**batch)
@@ -864,6 +890,9 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
all_end_top_log_probs = []
all_end_top_index = []
all_cls_logits = []

model.eval()

for step, batch in enumerate(predict_dataloader):
with torch.no_grad():
outputs = model(**batch)
@@ -938,6 +967,9 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)

logger.info(json.dumps(eval_metric, indent=4))
save_prefixed_metrics(eval_metric, args.output_dir)


if __name__ == "__main__":
main()
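A quick usage sketch for the save_prefixed_metrics helper defined in the diff above. The metric values and output directory are made up for illustration, and the helper is assumed to be in scope:

import json
import os

output_dir = "/tmp/qa_run"                  # illustrative path
os.makedirs(output_dir, exist_ok=True)

eval_metric = {"exact": 80.1, "f1": 88.3}   # made-up SQuAD-style metrics
save_prefixed_metrics(eval_metric, output_dir)

with open(os.path.join(output_dir, "all_results.json")) as f:
    print(json.load(f))                     # {'eval_exact': 80.1, 'eval_f1': 88.3}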
37 changes: 33 additions & 4 deletions examples/pytorch/question-answering/run_qa_no_trainer.py
@@ -66,6 +66,29 @@
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


def save_prefixed_metrics(results, output_dir, file_name: str = "all_results.json", metric_key_prefix: str = "eval"):
"""
Save results while prefixing metric names.

Args:
results: (:obj:`dict`):
A dictionary of results.
output_dir: (:obj:`str`):
An output directory.
file_name: (:obj:`str`, `optional`, defaults to :obj:`all_results.json`):
An output file name.
metric_key_prefix: (:obj:`str`, `optional`, defaults to :obj:`eval`):
A metric name prefix.
"""
# Prefix all keys with metric_key_prefix + '_'
for key in list(results.keys()):
if not key.startswith(f"{metric_key_prefix}_"):
results[f"{metric_key_prefix}_{key}"] = results.pop(key)

with open(os.path.join(output_dir, file_name), "w") as f:
json.dump(results, f, indent=4)


def parse_args():
parser = argparse.ArgumentParser(description="Finetune a transformers model on a Question Answering task")
parser.add_argument(
@@ -194,8 +217,7 @@ def parse_args():
)
parser.add_argument(
"--version_2_with_negative",
type=bool,
default=False,
action="store_true",
help="If true, some of the examples do not have an answer.",
)
parser.add_argument(
@@ -824,6 +846,9 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):

all_start_logits = []
all_end_logits = []

model.eval()

for step, batch in enumerate(eval_dataloader):
with torch.no_grad():
outputs = model(**batch)
@@ -860,6 +885,9 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):

all_start_logits = []
all_end_logits = []

model.eval()

for step, batch in enumerate(predict_dataloader):
with torch.no_grad():
outputs = model(**batch)
@@ -907,8 +935,9 @@ def create_and_fill_np_array(start_or_end_logits, dataset, max_len):
tokenizer.save_pretrained(args.output_dir)
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
json.dump({"eval_f1": eval_metric["f1"], "eval_exact": eval_metric["exact"]}, f)

logger.info(json.dumps(eval_metric, indent=4))
save_prefixed_metrics(eval_metric, args.output_dir)


if __name__ == "__main__":
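The scripts above now call model.eval() before each evaluation and prediction loop. eval() switches dropout and batch-norm layers to inference behaviour, while torch.no_grad() only disables gradient tracking, so both are needed for deterministic evaluation. A minimal sketch of the pattern with a stand-in model and dataloader, not the script's real objects:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.1), nn.Linear(8, 2))
eval_dataloader = DataLoader(TensorDataset(torch.randn(32, 8)), batch_size=8)

model.eval()                          # dropout disabled, batch norm uses running stats
all_logits = []
for (batch,) in eval_dataloader:
    with torch.no_grad():             # no autograd bookkeeping during evaluation
        all_logits.append(model(batch))
logits = torch.cat(all_logits)        # (32, 2) tensor of evaluation logits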
15 changes: 12 additions & 3 deletions examples/pytorch/question-answering/utils_qa.py
@@ -158,7 +158,7 @@ def postprocess_qa_predictions(
"end_logit": end_logits[end_index],
}
)
if version_2_with_negative:
if version_2_with_negative and min_null_prediction is not None:
# Add the minimum null prediction
prelim_predictions.append(min_null_prediction)
null_score = min_null_prediction["score"]
@@ -167,7 +167,11 @@
predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]

# Add back the minimum null prediction if it was removed because of its low score.
if version_2_with_negative and not any(p["offsets"] == (0, 0) for p in predictions):
if (
version_2_with_negative
and min_null_prediction is not None
and not any(p["offsets"] == (0, 0) for p in predictions)
):
predictions.append(min_null_prediction)

# Use the offsets to gather the answer text in the original context.
@@ -350,9 +354,12 @@ def postprocess_qa_predictions_with_beam_search(
start_index >= len(offset_mapping)
or end_index >= len(offset_mapping)
or offset_mapping[start_index] is None
or len(offset_mapping[start_index]) < 2
or offset_mapping[end_index] is None
or len(offset_mapping[end_index]) < 2
):
continue

# Don't consider answers with a length negative or > max_answer_length.
if end_index < start_index or end_index - start_index + 1 > max_answer_length:
continue
@@ -381,7 +388,9 @@ def postprocess_qa_predictions_with_beam_search(
# In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
# failure.
if len(predictions) == 0:
predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": -2e-6})
# Without predictions min_null_score is going to be None and None will cause an exception later
min_null_score = -2e-6
predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": min_null_score})

# Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using
# the LogSumExp trick).
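The postprocessing changes above skip any candidate span whose offset_mapping entry is missing or is not a full (start_char, end_char) pair, and only touch min_null_prediction when it was actually set. A self-contained sketch of that span-filtering guard, using illustrative data rather than real tokenizer output:

# Illustrative offset mapping: special tokens map to None and one entry is truncated.
offset_mapping = [None, (0, 5), (6, 11), (12,), None]
max_answer_length = 30

candidates = []
for start_index in range(len(offset_mapping)):
    for end_index in range(start_index, len(offset_mapping)):
        if (
            offset_mapping[start_index] is None
            or len(offset_mapping[start_index]) < 2
            or offset_mapping[end_index] is None
            or len(offset_mapping[end_index]) < 2
        ):
            continue
        if end_index - start_index + 1 > max_answer_length:
            continue
        candidates.append((start_index, end_index))

print(candidates)  # [(1, 1), (1, 2), (2, 2)] -- spans touching None or truncated entries are dropped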
3 changes: 2 additions & 1 deletion examples/pytorch/test_accelerate_examples.py
@@ -200,7 +200,7 @@ def test_run_squad_no_trainer(self):
testargs = f"""
run_qa_no_trainer.py
--model_name_or_path bert-base-uncased
--version_2_with_negative=False
--version_2_with_negative
--train_file tests/fixtures/tests_samples/SQUAD/sample.json
--validation_file tests/fixtures/tests_samples/SQUAD/sample.json
--output_dir {tmp_dir}
@@ -216,6 +216,7 @@ def test_run_squad_no_trainer(self):
with patch.object(sys, "argv", testargs):
run_squad_no_trainer.main()
result = get_results(tmp_dir)
# Because we use --version_2_with_negative the testing script uses SQuAD v2 metrics.
self.assertGreaterEqual(result["eval_f1"], 30)
self.assertGreaterEqual(result["eval_exact"], 30)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, "epoch_0")))
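The test now passes the bare flag because the scripts switched --version_2_with_negative from type=bool to action="store_true": argparse feeds the raw command-line string to bool(), and bool("False") is True, so the old form could never turn the option off. A minimal sketch of the difference:

import argparse

# Old style: type=bool receives the string "False", and bool("False") is True.
buggy = argparse.ArgumentParser()
buggy.add_argument("--version_2_with_negative", type=bool, default=False)
print(buggy.parse_args(["--version_2_with_negative=False"]).version_2_with_negative)  # True

# New style: absent -> False, present as a bare flag -> True.
fixed = argparse.ArgumentParser()
fixed.add_argument("--version_2_with_negative", action="store_true")
print(fixed.parse_args([]).version_2_with_negative)                                   # False
print(fixed.parse_args(["--version_2_with_negative"]).version_2_with_negative)        # True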
15 changes: 12 additions & 3 deletions examples/tensorflow/question-answering/utils_qa.py
@@ -158,7 +158,7 @@ def postprocess_qa_predictions(
"end_logit": end_logits[end_index],
}
)
if version_2_with_negative:
if version_2_with_negative and min_null_prediction is not None:
# Add the minimum null prediction
prelim_predictions.append(min_null_prediction)
null_score = min_null_prediction["score"]
@@ -167,7 +167,11 @@
predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]

# Add back the minimum null prediction if it was removed because of its low score.
if version_2_with_negative and not any(p["offsets"] == (0, 0) for p in predictions):
if (
version_2_with_negative
and min_null_prediction is not None
and not any(p["offsets"] == (0, 0) for p in predictions)
):
predictions.append(min_null_prediction)

# Use the offsets to gather the answer text in the original context.
@@ -350,9 +354,12 @@ def postprocess_qa_predictions_with_beam_search(
start_index >= len(offset_mapping)
or end_index >= len(offset_mapping)
or offset_mapping[start_index] is None
or len(offset_mapping[start_index]) < 2
or offset_mapping[end_index] is None
or len(offset_mapping[end_index]) < 2
):
continue

# Don't consider answers with a length negative or > max_answer_length.
if end_index < start_index or end_index - start_index + 1 > max_answer_length:
continue
@@ -381,7 +388,9 @@ def postprocess_qa_predictions_with_beam_search(
# In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
# failure.
if len(predictions) == 0:
predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": -2e-6})
# Without predictions min_null_score is going to be None and None will cause an exception later
min_null_score = -2e-6
predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": min_null_score})

# Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using
# the LogSumExp trick).
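The fallback above sets min_null_score to -2e-6 rather than leaving it None because Python 3 refuses to order None against a number; a short illustration of the failure mode and the fix:

min_null_score = None
try:
    print(min_null_score > -1.0)
except TypeError as exc:
    print(exc)                    # ordering None against a float raises TypeError

min_null_score = -2e-6            # the fallback value used above
print(min_null_score > -1.0)      # True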