5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,11 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [1.13.1]
### Added
- Added chrF metric
([Popovic 2015: chrF: character n-gram F-score for automatic MT evaluation](http://www.statmt.org/wmt15/pdf/WMT49.pdf)) to Sockeye.
`sockeye.evaluate` now accepts `bleu` and `chrf` as values for `--metrics`.

## [1.13.0]
### Fixed
2 changes: 1 addition & 1 deletion sockeye/__init__.py
@@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '1.13.0'
__version__ = '1.13.1'
6 changes: 5 additions & 1 deletion sockeye/arguments.py
@@ -881,9 +881,13 @@ def add_evaluate_args(params):
eval_params.add_argument('--quiet', '-q',
action="store_true",
help="Do not print logging information.")
eval_params.add_argument('--metrics',
nargs='+',
default=[C.BLEU, C.CHRF],
help='List of metrics to compute. Default: %(default)s.')
eval_params.add_argument('--sentence', '-s',
action="store_true",
help="Show sentence-BLEU. Default: %(default)s.")
help="Show sentence-level metrics. Default: %(default)s.")
eval_params.add_argument('--offset',
type=float,
default=0.01,
4 changes: 4 additions & 0 deletions sockeye/checkpoint_decoder.py
@@ -24,6 +24,7 @@

import sockeye.output_handler
from . import evaluate
from . import chrf
from . import constants as C
from . import data_io
from . import inference
@@ -148,4 +149,7 @@ def decode_and_evaluate(self,
return {C.BLEU_VAL: evaluate.raw_corpus_bleu(hypotheses=translations,
references=self.target_sentences,
offset=0.01),
C.CHRF_VAL: chrf.corpus_chrf(hypotheses=translations,
references=self.target_sentences,
trim_whitespaces=True),
C.AVG_TIME: avg_time}
137 changes: 137 additions & 0 deletions sockeye/chrf.py
@@ -0,0 +1,137 @@
# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not
# use this file except in compliance with the License. A copy of the License
# is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is distributed on
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

"""
Computes chrF scores as described in
'chrF: character n-gram F-score for automatic MT evaluation' by Maja Popović.
[http://www.statmt.org/wmt15/pdf/WMT49.pdf]
"""

import re
from collections import Counter
from typing import Iterable, Tuple

import numpy as np

ORDER = 6
BETA = 3.0
TRIM_WS = True


def extract_ngrams(s: str, n: int) -> Counter:
    """
    Returns a Counter of all character n-grams of order n in string s.
    """
    return Counter(s[i:i + n] for i in range(len(s) - n + 1))
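For intuition, a hypothetical example (not in the PR) of the counts this produces:

    # extract_ngrams("banana", 2) -> Counter({'an': 2, 'na': 2, 'ba': 1})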


def delete_whitespace(text: str) -> str:
    """
    Removes all whitespace from text.
    """
    return re.sub(r"\s+", "", text)


def get_sentence_statistics(hypothesis: str,
                            reference: str,
                            order: int = ORDER,
                            trim_whitespaces: bool = TRIM_WS) -> np.ndarray:
hypothesis = delete_whitespace(hypothesis) if trim_whitespaces else hypothesis
reference = delete_whitespace(reference) if trim_whitespaces else reference
statistics = np.zeros((order * 3))
for i in range(order):
n = i + 1
hypothesis_ngrams = extract_ngrams(hypothesis, n)
reference_ngrams = extract_ngrams(reference, n)
common_ngrams = hypothesis_ngrams & reference_ngrams
statistics[3 * i + 0] = sum(hypothesis_ngrams.values())
statistics[3 * i + 1] = sum(reference_ngrams.values())
statistics[3 * i + 2] = sum(common_ngrams.values())
return statistics
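The returned vector packs three slots per n-gram order: hypothesis n-gram count, reference n-gram count, and matched n-gram count. A worked toy case (hypothetical, not from the tests):

    # get_sentence_statistics("abc", "abd", order=2)
    # n=1: hyp {a, b, c} -> 3, ref {a, b, d} -> 3, common {a, b} -> 2
    # n=2: hyp {ab, bc} -> 2, ref {ab, bd} -> 2, common {ab} -> 1
    # result: array([3., 3., 2., 2., 2., 1.])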


def get_corpus_statistics(hypotheses: Iterable[str],
                          references: Iterable[str],
                          order: int = ORDER,
                          trim_whitespaces: bool = TRIM_WS) -> np.ndarray:
corpus_statistics = np.zeros((order * 3))
for hypothesis, reference in zip(hypotheses, references):
statistics = get_sentence_statistics(hypothesis, reference, order=order, trim_whitespaces=trim_whitespaces)
corpus_statistics += statistics
return corpus_statistics


def _avg_precision_and_recall(statistics: np.ndarray, order: int) -> Tuple[float, float]:
avg_precision = 0.0
avg_recall = 0.0
effective_order = 0
for i in range(order):
hypotheses_ngrams = statistics[3 * i + 0]
references_ngrams = statistics[3 * i + 1]
common_ngrams = statistics[3 * i + 2]
if hypotheses_ngrams > 0 and references_ngrams > 0:
avg_precision += common_ngrams / hypotheses_ngrams
avg_recall += common_ngrams / references_ngrams
effective_order += 1
if effective_order == 0:
return 0.0, 0.0
avg_precision /= effective_order
avg_recall /= effective_order
return avg_precision, avg_recall


def _chrf(avg_precision: float, avg_recall: float, beta: float = BETA) -> float:
if avg_precision + avg_recall == 0:
return 0.0
beta_square = beta ** 2
return (1 + beta_square) * (avg_precision * avg_recall) / ((beta_square * avg_precision) + avg_recall)
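This is the standard F-beta combination from Popović (2015): with macro-averaged character n-gram precision P and recall R (averaged over the orders that actually occur, via _avg_precision_and_recall above),

    chrF_beta = (1 + beta^2) * P * R / (beta^2 * P + R)

so the default BETA = 3 weights recall three times as heavily as precision.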


def corpus_chrf(hypotheses: Iterable[str],
references: Iterable[str],
order: int = ORDER,
trim_whitespaces: bool = TRIM_WS,
beta: float = BETA) -> float:
"""
    Computes chrF on a corpus.

    :param hypotheses: Stream of hypotheses.
    :param references: Stream of references.
    :param order: Maximum n-gram order.
    :param trim_whitespaces: Whether to trim whitespace from hypothesis and reference strings.
    :param beta: Defines importance of recall w.r.t. precision. If beta=1, same importance.
    :return: chrF score.
"""
corpus_statistics = get_corpus_statistics(hypotheses, references, order=order, trim_whitespaces=trim_whitespaces)
avg_precision, avg_recall = _avg_precision_and_recall(corpus_statistics, order)
return _chrf(avg_precision, avg_recall, beta=beta)


def sentence_chrf(hypothesis: str,
reference: str,
order: int = ORDER,
trim_whitespaces: bool = TRIM_WS,
beta: float = BETA) -> float:
"""
    Computes chrF on a single sentence pair.

    :param hypothesis: Hypothesis string.
    :param reference: Reference string.
    :param order: Maximum n-gram order.
    :param trim_whitespaces: Whether to trim whitespace from hypothesis and reference strings.
    :param beta: Defines importance of recall w.r.t. precision. If beta=1, same importance.
    :return: chrF score.
"""
statistics = get_sentence_statistics(hypothesis, reference, order=order, trim_whitespaces=trim_whitespaces)
avg_precision, avg_recall = _avg_precision_and_recall(statistics, order)
return _chrf(avg_precision, avg_recall, beta=beta)
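A minimal usage sketch of the new module (hypothetical strings, assuming sockeye is importable; not part of the test suite):

    from sockeye.chrf import corpus_chrf, sentence_chrf

    hyps = ["the cat sat on the mat", "a dog barks"]
    refs = ["the cat sat on the mat", "the dog barked"]

    # Identical strings score exactly 1.0; for strings shorter than ORDER
    # characters, effective_order shrinks so the score stays well defined.
    assert sentence_chrf(hyps[0], refs[0]) == 1.0

    # The corpus score sums n-gram statistics before averaging, so it is
    # generally not the mean of the per-sentence scores.
    print("corpus chrF: %.4f" % corpus_chrf(hyps, refs))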
2 changes: 2 additions & 0 deletions sockeye/constants.py
@@ -274,7 +274,9 @@
ACCURACY = 'accuracy'
PERPLEXITY = 'perplexity'
BLEU = 'bleu'
CHRF = 'chrf'
BLEU_VAL = BLEU + "-val"
CHRF_VAL = CHRF + "-val"
AVG_TIME = "avg-sec-per-sent-val"
METRICS = [PERPLEXITY, ACCURACY, BLEU]
METRIC_MAXIMIZE = {ACCURACY: True, BLEU: True, PERPLEXITY: False}
28 changes: 22 additions & 6 deletions sockeye/evaluate.py
@@ -22,6 +22,8 @@
from contrib import sacrebleu
from sockeye.log import setup_main_logger, log_sockeye_version
from . import arguments
from . import chrf
from . import constants as C
from . import data_io
from . import utils

@@ -41,8 +43,8 @@ def raw_corpus_bleu(hypotheses: Iterable[str], references: Iterable[str], offset


def main():
params = argparse.ArgumentParser(description='Evaluate translations by calculating 4-BLEU '
'score with respect to a reference set.')
params = argparse.ArgumentParser(description='Evaluate translations by calculating metrics with '
'respect to a reference set.')
arguments.add_evaluate_args(params)
args = params.parse_args()

@@ -65,12 +67,26 @@ def main():
len(references)))

if not args.sentence:
bleu = raw_corpus_bleu(hypotheses, references, args.offset)
print(bleu, file=sys.stdout)
scores = []
for metric in args.metrics:
if metric == C.BLEU:
bleu_score = raw_corpus_bleu(hypotheses, references, args.offset)
scores.append("%.6f" % bleu_score)
elif metric == C.CHRF:
chrf_score = chrf.corpus_chrf(hypotheses, references, trim_whitespaces=True)
scores.append("%.6f" % chrf_score)
print("\t".join(scores), file=sys.stdout)
else:
for h, r in zip(hypotheses, references):
bleu = raw_corpus_bleu(h, r, args.offset)
print(bleu, file=sys.stdout)
scores = []
for metric in args.metrics:
if metric == C.BLEU:
bleu = raw_corpus_bleu(h, r, args.offset)
scores.append("%.6f" % bleu)
elif metric == C.CHRF:
                    chrf_score = chrf.sentence_chrf(h, r, trim_whitespaces=True)
scores.append("%.6f" % chrf_score)
print("\t".join(scores), file=sys.stdout)


if __name__ == '__main__':
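With these changes the CLI prints one tab-separated value per requested metric. A hypothetical invocation, mirroring the sys.argv-patching pattern used in the tests below (hyps.txt and refs.txt are placeholder files, one sentence per line):

    import sys
    from unittest.mock import patch

    import sockeye.evaluate

    # Prints the scores tab-separated to stdout, in --metrics order.
    cli = "sockeye-evaluate --hypotheses hyps.txt --references refs.txt --metrics bleu chrf"
    with patch.object(sys, "argv", cli.split()):
        sockeye.evaluate.main()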
2 changes: 1 addition & 1 deletion sockeye/utils.py
@@ -402,7 +402,7 @@ def get_gpu_memory_usage(ctx: List[mx.context.Context]) -> Optional[Dict[int, Tu


def log_gpu_memory_usage(memory_data: Dict[int, Tuple[int, int]]):
log_str = " ".join("GPU %d: %d/%d MB (%.2f%%)" %(k, v[0], v[1], v[0] * 100.0/v[1]) for k, v in memory_data.items())
log_str = " ".join("GPU %d: %d/%d MB (%.2f%%)" % (k, v[0], v[1], v[0] * 100.0/v[1]) for k, v in memory_data.items())
logger.info(log_str)


29 changes: 16 additions & 13 deletions test/common.py
@@ -31,6 +31,7 @@
import sockeye.utils

from sockeye.evaluate import raw_corpus_bleu
from sockeye.chrf import corpus_chrf


def gaussian_vector(shape, return_symbol=False):
@@ -149,7 +150,7 @@ def tmp_digits_dataset(prefix: str,

_TRANSLATE_PARAMS_RESTRICT = "--restrict-lexicon {json}"

_EVAL_PARAMS_COMMON = "--hypotheses {hypotheses} --references {references}"
_EVAL_PARAMS_COMMON = "--hypotheses {hypotheses} --references {references} --metrics {metrics}"


def run_train_translate(train_params: str,
Expand All @@ -161,7 +162,7 @@ def run_train_translate(train_params: str,
dev_target_path: str,
max_seq_len: int = 10,
restrict_lexicon: bool = False,
work_dir: Optional[str] = None) -> Tuple[float, float, float]:
work_dir: Optional[str] = None) -> Tuple[float, float, float, float]:
"""
    Train a model and translate a dev set. Report validation perplexity, BLEU, and chrF.

@@ -175,7 +176,7 @@
:param max_seq_len: The maximum sequence length.
:param restrict_lexicon: Additional translation run with top-k lexicon-based vocabulary restriction.
:param work_dir: The directory to store the model and other outputs in.
:return: A tuple containing perplexity and bleu scores for standard and reduced vocab decoding.
    :return: A tuple containing perplexity, BLEU scores for standard and reduced vocab decoding, and the chrF score.
"""
with TemporaryDirectory(dir=work_dir, prefix="test_train_translate.") as work_dir:
# Train model
@@ -243,7 +244,6 @@ def run_train_translate(train_params: str,
with patch.object(sys, "argv", params.split()):
sockeye.translate.main()


# test averaging
points = sockeye.average.find_checkpoints(model_path=model_path,
size=1,
@@ -257,20 +257,23 @@
metrics = sockeye.utils.read_metrics_file(path=os.path.join(model_path, C.METRICS_NAME))
perplexity = metrics[-1][C.PERPLEXITY + '-val']

# Measure BLEU
bleu = raw_corpus_bleu(hypotheses=open(out_path, "r").readlines(),
references=open(dev_target_path, "r").readlines(),
offset=0.01)
hypotheses = open(out_path, "r").readlines()
references = open(dev_target_path, "r").readlines()

# compute metrics
bleu = raw_corpus_bleu(hypotheses=hypotheses, references=references, offset=0.01)
chrf = corpus_chrf(hypotheses=hypotheses, references=references)

bleu_restrict = None
if restrict_lexicon:
bleu_restrict = raw_corpus_bleu(hypotheses=open(out_restrict_path, "r").readlines(),
references=open(dev_target_path, "r").readlines(),
offset=0.01)
            bleu_restrict = raw_corpus_bleu(hypotheses=open(out_restrict_path, "r").readlines(),
                                            references=references,
                                            offset=0.01)

        # Run the evaluate CLI (BLEU and chrF)
eval_params = "{} {} ".format(sockeye.evaluate.__file__,
_EVAL_PARAMS_COMMON.format(hypotheses=out_path, references=dev_target_path), )
_EVAL_PARAMS_COMMON.format(hypotheses=out_path,
references=dev_target_path,
metrics="bleu chrf"), )
with patch.object(sys, "argv", eval_params.split()):
sockeye.evaluate.main()

return perplexity, bleu, bleu_restrict
return perplexity, bleu, bleu_restrict, chrf
1 change: 1 addition & 0 deletions test/integration/test_seq_copy_int.py
@@ -81,6 +81,7 @@
"--beam-size 2",
True)]


@pytest.mark.parametrize("train_params, translate_params, restrict_lexicon", ENCODER_DECODER_SETTINGS)
def test_seq_copy(train_params: str, translate_params: str, restrict_lexicon: bool):
"""Task: copy short sequences of digits"""