Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 38 additions & 10 deletions src/transformers/pipelines/question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,14 @@
from ..data import SquadExample, SquadFeatures, squad_convert_examples_to_features
from ..modelcard import ModelCard
from ..tokenization_utils import PreTrainedTokenizer
from ..utils import PaddingStrategy, add_end_docstrings, is_tf_available, is_torch_available, logging
from ..utils import (
PaddingStrategy,
add_end_docstrings,
is_tf_available,
is_tokenizers_available,
is_torch_available,
logging,
)
from .base import PIPELINE_INIT_ARGS, ArgumentHandler, ChunkPipeline


Expand All @@ -18,6 +25,9 @@
from ..modeling_tf_utils import TFPreTrainedModel
from ..modeling_utils import PreTrainedModel

if is_tokenizers_available():
import tokenizers

if is_tf_available():
import tensorflow as tf

Expand Down Expand Up @@ -180,6 +190,7 @@ def _sanitize_parameters(
max_seq_len=None,
max_question_len=None,
handle_impossible_answer=None,
align_to_words=None,
**kwargs
):
# Set defaults values
Expand Down Expand Up @@ -208,6 +219,8 @@ def _sanitize_parameters(
postprocess_params["max_answer_len"] = max_answer_len
if handle_impossible_answer is not None:
postprocess_params["handle_impossible_answer"] = handle_impossible_answer
if align_to_words is not None:
postprocess_params["align_to_words"] = align_to_words
return preprocess_params, {}, postprocess_params

def __call__(self, *args, **kwargs):
Expand Down Expand Up @@ -243,6 +256,9 @@ def __call__(self, *args, **kwargs):
The maximum length of the question after tokenization. It will be truncated if needed.
handle_impossible_answer (`bool`, *optional*, defaults to `False`):
Whether or not we accept impossible as an answer.
align_to_words (`bool`, *optional*, defaults to `True`):
Attempts to align the answer to real words. Improves quality on space separated langages. Might hurt on
non-space-separated languages (like Japanese or Chinese)

Return:
A `dict` or a list of `dict`: Each result comes as a dictionary with the following keys:
Expand Down Expand Up @@ -386,6 +402,7 @@ def postprocess(
top_k=1,
handle_impossible_answer=False,
max_answer_len=15,
align_to_words=True,
):
min_null_score = 1000000 # large and positive
answers = []
Expand Down Expand Up @@ -464,15 +481,8 @@ def postprocess(
for s, e, score in zip(starts, ends, scores):
s = s - offset
e = e - offset
try:
start_word = enc.token_to_word(s)
end_word = enc.token_to_word(e)
start_index = enc.word_to_chars(start_word, sequence_index=sequence_index)[0]
end_index = enc.word_to_chars(end_word, sequence_index=sequence_index)[1]
except Exception:
# Some tokenizers don't really handle words. Keep to offsets then.
start_index = enc.offsets[s][0]
end_index = enc.offsets[e][1]

start_index, end_index = self.get_indices(enc, s, e, sequence_index, align_to_words)

answers.append(
{
Expand All @@ -490,6 +500,24 @@ def postprocess(
return answers[0]
return answers

def get_indices(
self, enc: "tokenizers.Encoding", s: int, e: int, sequence_index: int, align_to_words: bool
) -> Tuple[int, int]:
if align_to_words:
try:
start_word = enc.token_to_word(s)
end_word = enc.token_to_word(e)
start_index = enc.word_to_chars(start_word, sequence_index=sequence_index)[0]
end_index = enc.word_to_chars(end_word, sequence_index=sequence_index)[1]
except Exception:
# Some tokenizers don't really handle words. Keep to offsets then.
start_index = enc.offsets[s][0]
end_index = enc.offsets[e][1]
else:
start_index = enc.offsets[s][0]
end_index = enc.offsets[e][1]
return start_index, end_index

def decode(
self, start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int, undesired_tokens: np.ndarray
) -> Tuple:
Expand Down
23 changes: 23 additions & 0 deletions tests/pipelines/test_pipelines_question_answering.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,29 @@ def ensure_large_logits_postprocess(

self.assertEqual(nested_simplify(outputs), {"score": 0.028, "start": 0, "end": 11, "answer": "HuggingFace"})

@slow
@require_torch
def test_small_model_japanese(self):
question_answerer = pipeline(
"question-answering",
model="KoichiYasuoka/deberta-base-japanese-aozora-ud-head",
)
output = question_answerer(question="国語", context="全学年にわたって小学校の国語の教科書に挿し絵が用いられている")

# Wrong answer, the whole text is identified as one "word" since the tokenizer does not include
# a pretokenizer
self.assertEqual(
nested_simplify(output),
{"score": 1.0, "start": 0, "end": 30, "answer": "全学年にわたって小学校の国語の教科書に挿し絵が用いられている"},
)

# Disable word alignment
output = question_answerer(question="国語", context="全学年にわたって小学校の国語の教科書に挿し絵が用いられている", align_to_words=False)
self.assertEqual(
nested_simplify(output),
{"score": 1.0, "start": 15, "end": 18, "answer": "教科書"},
)

@slow
@require_torch
def test_small_model_long_context_cls_slow(self):
Expand Down