138 changes: 112 additions & 26 deletions sacrebleu.py
@@ -673,25 +673,22 @@ def download_test_set(test_set, langpair=None):
BLEU = namedtuple('BLEU', 'score, ngram1, ngram2, ngram3, ngram4, bp, sys_len, ref_len')


def compute_bleu(instream, refstreams, smooth=0., force=False, lc=False, tokenize=False) -> BLEU:
def compute_bleu(instream, refstreams, smooth=0., force=False, lc=False, tokenize=False, bootstrap_trials=1) -> BLEU:
"""Produces the BLEU scores along with its sufficient statistics from a source against one or more references.

:param instream: the input stream, one segment per line
:param refstreams: a list of reference streams
:param bootstrap_trials: number of trials for bootstrap resampling (default: 1)
:return: a BLEU object containing everything you'd want
"""

fhs = [sys.stdin] + refstreams

sys_len = 0
ref_len = 0

correct = defaultdict(int)
total = defaultdict(int)

# look for already-tokenized sentences
tokenized_count = 0

# Pre-compute segment-level data for BLEU computation.
segmentdata = defaultdict(list)
for sentno, lines in enumerate(zip(*fhs)):
if lc:
lines = [x.lower() for x in lines]
@@ -706,35 +703,122 @@ def compute_bleu(instream, refstreams, smooth=0., force=False, lc=False, tokeniz
sys.exit(1)

output, *refs = [tokenizers[tokenize](x.rstrip()) for x in lines]

sys_ngrams = extract_ngrams(output)
ref_ngrams, closest_diff, closest_len = ref_stats(output, refs)

sys_len += len(output.split())
ref_len += closest_len
local_correct = defaultdict(int)
local_total = defaultdict(int)

sys_ngrams = extract_ngrams(output)
for ngram in sys_ngrams.keys():
n = len(ngram.split())

total[n] += sys_ngrams[ngram]
correct[n] += min(sys_ngrams[ngram], ref_ngrams.get(ngram, 0))

if sum(total) == 0:
logging.error('No input?')
sys.exit(1)
local_total[n] += sys_ngrams[ngram]
local_correct[n] += min(sys_ngrams[ngram], ref_ngrams.get(ngram, 0))

segmentdata[sentno].append(len(output.split())) # 0: output_len
segmentdata[sentno].append(closest_diff) # 1: closest_diff (unused)
segmentdata[sentno].append(closest_len) # 2: closest_len
segmentdata[sentno].append(local_total) # 3: local_total
segmentdata[sentno].append(local_correct) # 4: local_correct
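
The numeric indices into each segmentdata entry (0 through 4, mirrored in the comments above) are easy to get out of sync with the reads inside the trial loop below. A namedtuple would make those lookups self-documenting; a minimal sketch, where SegmentStats is a hypothetical name that is not part of this PR:

from collections import namedtuple

# Hypothetical per-segment record; field order mirrors the list indices
# used in this PR (0: output_len ... 4: local_correct).
SegmentStats = namedtuple(
    'SegmentStats',
    'output_len closest_diff closest_len local_total local_correct')

Reads such as segmentdata[sentno][0] would then become segmentdata[sentno].output_len.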

# Based on pre-computed segment-level data, compute BLEU score for input.
#
# This requires seeding the RNG to get reproducible results. For now,
# we simply freeze the seed value as 12345. This can later be changed
# so that it is configurable. If so, the random seed needs to become
# part of the sacreBLEU signature for future reference.
from random import seed, randrange
seed(12345)

# Number of segments in the test set.
set_size = len(segmentdata)

trial_runs = []
for trial_run in range(bootstrap_trials):
sys_len = 0
ref_len = 0

correct = defaultdict(int)
total = defaultdict(int)

# The first trial always uses the unmodified test set, so
# bootstrap_trials=1 reduces to a single ordinary run.
if trial_run == 0:
input_data = segmentdata.keys()

# Subsequent trials draw segment indices with replacement. Note that
# randrange's upper bound is exclusive, so randrange(set_size) can
# return every index from 0 to set_size - 1.
else:
input_data = (randrange(set_size) for _ in range(set_size))
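
To isolate the resampling scheme used above: each trial after the first draws set_size segment indices uniformly with replacement, and the frozen seed keeps runs reproducible. A self-contained sketch of the same idea using only the standard library (the function name is illustrative):

from random import Random

def bootstrap_indices(set_size, trials, rng_seed=12345):
    """Yield one list of segment indices per trial: the first is the
    identity sample, later ones are drawn with replacement."""
    rng = Random(rng_seed)  # a local RNG avoids reseeding the global one
    yield list(range(set_size))  # trial 0: the unmodified test set
    for _ in range(trials - 1):
        yield [rng.randrange(set_size) for _ in range(set_size)]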

precisions = [0, 0, 0, 0, 0]
# Compute BLEU score for current trial, based on pre-computed data.
for sentno in input_data:
output_len = segmentdata[sentno][0]
closest_diff = segmentdata[sentno][1]
closest_len = segmentdata[sentno][2]
local_total = segmentdata[sentno][3]
local_correct = segmentdata[sentno][4]

for n in range(1, 5):
precisions[n] = max(smooth, 100. * correct[n] / total[n] if total.get(n) > 0 else 0)
sys_len += output_len
ref_len += closest_len

brevity_penalty = 1.0
if sys_len < ref_len:
brevity_penalty = math.exp(1 - ref_len / sys_len)
for n in local_total.keys():
total[n] += local_total[n]
correct[n] += local_correct[n]

bleu = 1. * brevity_penalty * math.exp(sum(map(my_log, precisions[1:])) / 4)
if sum(total) == 0:
logging.error('No input?')
sys.exit(1)

return BLEU._make([bleu, precisions[1], precisions[2], precisions[3], precisions[4], brevity_penalty, sys_len, ref_len])
precisions = [0, 0, 0, 0, 0]

for n in range(1, 5):
precisions[n] = max(smooth, 100. * correct[n] / total[n] if total.get(n, 0) > 0 else 0)

brevity_penalty = 1.0
if sys_len < ref_len:
brevity_penalty = math.exp(1 - ref_len / sys_len)

bleu = 1. * brevity_penalty * math.exp(sum(map(my_log, precisions[1:])) / 4)
trial_runs.append([bleu, precisions[1], precisions[2], precisions[3], precisions[4], brevity_penalty, sys_len, ref_len])
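
Each trial's score is standard corpus BLEU assembled from sufficient statistics: four modified n-gram precisions combined by geometric mean and scaled by a brevity penalty when the output is shorter than the references. The same arithmetic as one standalone function, assuming correct and total map n-gram order to clipped and total counts as in this file (a sketch, not the PR's code):

import math

def bleu_from_stats(correct, total, sys_len, ref_len, smooth=0.):
    # Modified n-gram precisions for n = 1..4, in percent.
    precisions = [max(smooth, 100. * correct[n] / total[n])
                  if total.get(n, 0) > 0 else 0.
                  for n in range(1, 5)]
    # Brevity penalty for outputs shorter than the reference length.
    bp = math.exp(1 - ref_len / sys_len) if sys_len < ref_len else 1.0
    # Geometric mean of the precisions; guard log(0) when smooth == 0.
    logs = [math.log(p) if p > 0 else -9999999999. for p in precisions]
    return bp * math.exp(sum(logs) / 4)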

# Compute average BLEU score and component values.
avgBleu = [
sum(x[0] for x in trial_runs) / len(trial_runs), # bleu
sum(x[1] for x in trial_runs) / len(trial_runs), # precisions[1]
sum(x[2] for x in trial_runs) / len(trial_runs), # precisions[2]
sum(x[3] for x in trial_runs) / len(trial_runs), # precisions[3]
sum(x[4] for x in trial_runs) / len(trial_runs), # precisions[4]
sum(x[5] for x in trial_runs) / len(trial_runs), # brevity_penalty
int(sum(x[6] for x in trial_runs) / len(trial_runs)), # sys_len
int(sum(x[7] for x in trial_runs) / len(trial_runs)), # ref_len
]

if bootstrap_trials > 1:
print('Bootstrap trials: n={0}'.format(bootstrap_trials))
allBleuScores = [x[0] for x in trial_runs]
try:
from numpy import mean, std

# Compute a 95% confidence interval around the BLEU score mean.
# The standard deviation of the bootstrap scores is itself an
# estimate of the standard error of corpus BLEU, so it is not
# divided by sqrt(bootstrap_trials) again.
xbar = mean(allBleuScores)
s = std(allBleuScores)
z = 1.96
confidenceInterval = z * s

except ImportError:
logging.error('Could not import numpy for confidence interval computation')
xbar = sum(allBleuScores) / len(allBleuScores)
confidenceInterval = None

finally:
if confidenceInterval is not None:
print('Mean BLEU score: {0:.2f} +/- {1:.2f}'.format(xbar, confidenceInterval))
else:
print('Mean BLEU score: {0:.2f}'.format(xbar))
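
The interval printed above is a normal approximation built from the spread of the bootstrap scores. A percentile interval avoids the normality assumption and needs no numpy; a sketch, assuming scores holds one BLEU value per trial (illustrative, not part of this PR):

def percentile_ci(scores, alpha=0.05):
    # Empirical (1 - alpha) interval: sort the bootstrap scores and read
    # off the alpha/2 and 1 - alpha/2 quantiles.
    ranked = sorted(scores)
    lower = ranked[int(len(ranked) * alpha / 2)]
    upper = ranked[min(len(ranked) - 1, int(len(ranked) * (1 - alpha / 2)))]
    return lower, upper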

return BLEU._make(avgBleu)


def main():
@@ -766,6 +850,8 @@ def main():
help='Suppress informative output.')
arg_parser.add_argument('--encoding', '-e', type=str, default='utf-8',
help='Open text files with specified encoding (default: %(default)s)')
arg_parser.add_argument('--bootstrap-trials', '-b', type=int, default=1,
help='Compute BLEU based on bootstrap resampling with n trials (default: %(default)d)')
arg_parser.add_argument('-V', '--version', action='version',
version='%(prog)s {}'.format(VERSION))
args = arg_parser.parse_args()
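
With the flag wired through to compute_bleu, a resampled run might look like the following (an illustrative invocation; the test set and language pair are placeholders):

cat output.detok.de | ./sacrebleu.py -t wmt17 -l en-de --bootstrap-trials 1000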
@@ -820,7 +906,7 @@ def main():
logging.warn('You should also pass "--tok zh" when scoring Chinese...')

bleu = compute_bleu(sys.stdin, refs, smooth=args.smooth, force=args.force,
lc=args.lc, tokenize=args.tokenize)
lc=args.lc, tokenize=args.tokenize, bootstrap_trials=args.bootstrap_trials)

version_str = build_signature(args, len(refs))
