
Commit 97dce4f

committed
updates and clean up on metrics class
1 parent 42ab0df commit 97dce4f

File tree

4 files changed (+149, -6 lines)

docs/source/using_metrics.rst

Lines changed: 120 additions & 1 deletion
@@ -49,4 +49,123 @@ Adding model predictions and references can be done using either one of the :fun
 
 The model predictions and references can be provided in a wide number of formats (python lists, numpy arrays, pytorch tensors, tensorflow tensors), the metric object will take care of converting them to a suitable format for temporary storage and computation (as well as bringing them back to cpu and detaching them from gradients for PyTorch tensors).
 
-The exact format of the inputs is specific to each metric script and can be read in the
+The exact format of the inputs is specific to each metric script and can be found in :obj:`nlp.Metric.features`, :obj:`nlp.Metric.inputs_description` and the string representation of the :class:`nlp.Metric` object:
+
+.. code-block::
+
+    >>> import nlp
+
+    >>> metric = nlp.load_metric('./metrics/sacrebleu')
+
+    >>> print(metric)
+    Metric(name: "sacrebleu", features: {'predictions': Value(dtype='string', id='sequence'), 'references': Sequence(feature=Value(dtype='string', id='sequence'), length=-1, id='references')}, usage: """
+    Produces BLEU scores along with its sufficient statistics
+    from a source against one or more references.
+
+    Args:
+        predictions: The system stream (a sequence of segments)
+        references: A list of one or more reference streams (each a sequence of segments)
+        smooth: The smoothing method to use
+        smooth_value: For 'floor' smoothing, the floor to use
+        force: Ignore data that looks already tokenized
+        lowercase: Lowercase the data
+        tokenize: The tokenizer to use
+    Returns:
+        'score': BLEU score,
+        'counts': Counts,
+        'totals': Totals,
+        'precisions': Precisions,
+        'bp': Brevity penalty,
+        'sys_len': predictions length,
+        'ref_len': reference length,
+    """)
+
+    >>> print(metric.features)
+    {'predictions': Value(dtype='string', id='sequence'), 'references': Sequence(feature=Value(dtype='string', id='sequence'), length=-1, id='references')}
+
+    >>> print(metric.inputs_description)
+
+    Produces BLEU scores along with its sufficient statistics
+    from a source against one or more references.
+
+    Args:
+        predictions: The system stream (a sequence of segments)
+        references: A list of one or more reference streams (each a sequence of segments)
+        smooth: The smoothing method to use
+        smooth_value: For 'floor' smoothing, the floor to use
+        force: Ignore data that looks already tokenized
+        lowercase: Lowercase the data
+        tokenize: The tokenizer to use
+    Returns:
+        'score': BLEU score,
+        'counts': Counts,
+        'totals': Totals,
+        'precisions': Precisions,
+        'bp': Brevity penalty,
+        'sys_len': predictions length,
+        'ref_len': reference length,
+
+Here we can see that the ``sacrebleu`` metric expects a sequence of segments as predictions and a list of one or several sequences of segments as references.
+
+You can find more information on the segments in the description, homepage and publication of ``sacrebleu``, which can be accessed via the respective attributes on the metric:
+
+.. code-block::
+    >>> print(metric.description)
+    SacreBLEU provides hassle-free computation of shareable, comparable, and reproducible BLEU scores.
+    Inspired by Rico Sennrich's `multi-bleu-detok.perl`, it produces the official WMT scores but works with plain text.
+    It also knows all the standard test sets and handles downloading, processing, and tokenization for you.
+
+    See the [README.md] file at https://github.com/mjpost/sacreBLEU for more information.
+
+    >>> print(metric.homepage)
+    https://github.com/mjpost/sacreBLEU
+    >>> print(metric.citation)
+    @inproceedings{post-2018-call,
+        title = "A Call for Clarity in Reporting {BLEU} Scores",
+        author = "Post, Matt",
+        booktitle = "Proceedings of the Third Conference on Machine Translation: Research Papers",
+        month = oct,
+        year = "2018",
+        address = "Belgium, Brussels",
+        publisher = "Association for Computational Linguistics",
+        url = "https://www.aclweb.org/anthology/W18-6319",
+        pages = "186--191",
+    }
+
+Let's use ``sacrebleu`` with the official quick-start example on its homepage at https://github.com/mjpost/sacreBLEU:
+
+.. code-block::
+
+    >>> reference_batch = [['The dog bit the man.', 'The dog had bit the man.'],
+    ...                    ['It was not unexpected.', 'No one was surprised.'],
+    ...                    ['The man bit him first.', 'The man had bitten the dog.']]
+    >>> sys_batch = ['The dog bit the man.', "It wasn't surprising.", 'The man had just bitten him.']
+    >>> score = metric.add_batch(predictions=sys_batch, references=reference_batch)
+    >>> print(metric)
+    Metric(name: "sacrebleu", features: {'predictions': Value(dtype='string', id='sequence'), 'references': Sequence(feature=Value(dtype='string', id='sequence'), length=-1, id='references')}, usage: """
+    Produces BLEU scores along with its sufficient statistics
+    from a source against one or more references.
+
+    Args:
+        predictions: The system stream (a sequence of segments)
+        references: A list of one or more reference streams (each a sequence of segments)
+        smooth: The smoothing method to use
+        smooth_value: For 'floor' smoothing, the floor to use
+        force: Ignore data that looks already tokenized
+        lowercase: Lowercase the data
+        tokenize: The tokenizer to use
+    Returns:
+        'score': BLEU score,
+        'counts': Counts,
+        'totals': Totals,
+        'precisions': Precisions,
+        'bp': Brevity penalty,
+        'sys_len': predictions length,
+        'ref_len': reference length,
+    """, stored examples: 3)
+
+We have stored three evaluation examples in our metric; now let's compute the score.
+
+Computing the metric scores
+-----------------------------------------
+
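As a quick orientation for this doc hunk, here is a minimal end-to-end sketch built only from the calls shown above (`nlp.load_metric`, `metric.add_batch`); the final `metric.compute()` step is only hinted at by the truncated "Computing the metric scores" heading, so its exact signature and return keys are assumptions here (the 'score' key follows the Returns listing in the metric's usage string):

    import nlp

    # Load the sacrebleu metric script from a local path, as in the doc example.
    metric = nlp.load_metric('./metrics/sacrebleu')

    # Official sacreBLEU quick-start data (https://github.com/mjpost/sacreBLEU).
    reference_batch = [['The dog bit the man.', 'The dog had bit the man.'],
                       ['It was not unexpected.', 'No one was surprised.'],
                       ['The man bit him first.', 'The man had bitten the dog.']]
    sys_batch = ['The dog bit the man.', "It wasn't surprising.", 'The man had just bitten him.']

    # Stage the three examples; repr(metric) now reports "stored examples: 3".
    metric.add_batch(predictions=sys_batch, references=reference_batch)

    # Assumed final step (covered by the next, truncated section): finalize and score.
    results = metric.compute()
    print(results['score'])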

src/nlp/arrow_writer.py

Lines changed: 4 additions & 0 deletions
@@ -169,6 +169,10 @@ def __init__(
         self.current_rows = []
         self.pa_writer: Optional[pa.RecordBatchStreamWriter] = None
 
+    def __len__(self):
+        """ Return the number of written and staged examples """
+        return self._num_examples + len(self.current_rows)
+
     def _build_writer(self, inferred_schema: pa.Schema):
         inferred_features = Features.from_arrow_schema(inferred_schema)
         if self._features is not None:
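The new `ArrowWriter.__len__` counts examples already flushed to the Arrow stream (`self._num_examples`) plus rows still staged in `self.current_rows`; it is what the updated `Metric.__repr__` (see src/nlp/metric.py below) uses for its "stored examples" field. A hypothetical illustration through the metric API, assuming the local sacrebleu script from the docs above (the `writer` attribute is internal and only poked here to show the count):

    import nlp

    metric = nlp.load_metric('./metrics/sacrebleu')
    metric.add(prediction='The dog bit the man.',
               reference=['The dog bit the man.', 'The dog had bit the man.'])
    metric.add(prediction="It wasn't surprising.",
               reference=['It was not unexpected.', 'No one was surprised.'])

    # written + staged examples; the same number appears as "stored examples" in repr(metric)
    print(len(metric.writer))  # 2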

src/nlp/load.py

Lines changed: 8 additions & 2 deletions
@@ -200,6 +200,7 @@ def get_imports(file_path: str):
 def prepare_module(
     path: str,
     download_config: Optional[DownloadConfig] = None,
+    download_mode: Optional[GenerateMode] = None,
     dataset: bool = True,
     force_local_path: Optional[str] = None,
     **download_kwargs,
@@ -335,6 +336,9 @@ def prepare_module(
     lock_path = local_path + ".lock"
     with FileLock(lock_path):
         # Create main dataset/metrics folder if needed
+        if download_mode == GenerateMode.FORCE_REDOWNLOAD and os.path.exists(main_folder_path):
+            shutil.rmtree(main_folder_path)
+
         if not os.path.exists(main_folder_path):
             logger.info(f"Creating main folder for {module_type} {file_path} at {main_folder_path}")
             os.makedirs(main_folder_path, exist_ok=True)
@@ -428,6 +432,7 @@ def load_metric(
     experiment_id: Optional[str] = None,
     keep_in_memory: bool = False,
     download_config: Optional[DownloadConfig] = None,
+    download_mode: Optional[GenerateMode] = None,
     **metric_init_kwargs,
 ) -> Metric:
     r"""Load a `nlp.Metric`.
@@ -446,12 +451,13 @@ def load_metric(
         cache_dir (Optional str): path to store the temporary predictions and references (default to `~/.nlp/`)
         keep_in_memory (bool): Weither to store the temporary results in memory (defaults to False)
         download_config (Optional ``nlp.DownloadConfig``: specific download configuration parameters.
+        download_mode (Optional `nlp.GenerateMode`): select the download/generate mode - Defaults to REUSE_DATASET_IF_EXISTS
         experiment_id (``str``): A specific experiment id. This is used if several distributed evaluations share the same file system.
             This is useful to compute metrics in distributed setups (in particular non-additive metrics like F1).
 
     Returns: `nlp.Metric`.
     """
-    module_path, hash = prepare_module(path, download_config=download_config, dataset=False)
+    module_path, hash = prepare_module(path, download_config=download_config, download_mode=download_mode, dataset=False)
     metric_cls = import_main_class(module_path, dataset=False)
     metric = metric_cls(
         config_name=config_name,
@@ -538,7 +544,7 @@ def load_dataset(
     """
     ignore_verifications = ignore_verifications or save_infos
     # Download/copy dataset processing script
-    module_path, hash = prepare_module(path, download_config=download_config, dataset=True)
+    module_path, hash = prepare_module(path, download_config=download_config, download_mode=download_mode, dataset=True)
 
     # Get dataset builder class from the processing script
     builder_cls = import_main_class(module_path, dataset=True)
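A hedged usage sketch of the new `download_mode` plumbing: per this diff, `GenerateMode.FORCE_REDOWNLOAD` makes `prepare_module` remove the cached script folder before recreating it, so the processing script is fetched again. The docstring above refers to `nlp.GenerateMode`; whether it is exposed at the package root exactly like this is an assumption:

    import nlp

    # Re-download the metric script instead of reusing the cached copy
    # (the default mode is REUSE_DATASET_IF_EXISTS).
    metric = nlp.load_metric(
        'sacrebleu',
        download_mode=nlp.GenerateMode.FORCE_REDOWNLOAD,
    )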

src/nlp/metric.py

Lines changed: 17 additions & 3 deletions
@@ -183,7 +183,9 @@ def __init__(
         self.filelocks = None
 
     def __repr__(self):
-        return f'Metric(name: "{self.name}", features: {self.features}, usage: """{self.inputs_description}""")'
+        return (f'Metric(name: "{self.name}", features: {self.features}, '
+                f'usage: """{self.inputs_description}""", '
+                f'stored examples: {0 if self.writer is None else len(self.writer)})')
 
     def _build_data_dir(self):
         """Path of this metric in cache_dir:
@@ -344,15 +346,27 @@ def add_batch(self, *, predictions=None, references=None):
         batch = self.info.features.encode_batch(batch)
         if self.writer is None:
             self._init_writer()
-        self.writer.write_batch(batch)
+        try:
+            self.writer.write_batch(batch)
+        except pa.ArrowInvalid:
+            raise ValueError(f"Predictions and/or references don't match the expected format.\n"
+                             f"Expected format: {self.features},\n"
+                             f"Input predictions: {predictions},\n"
+                             f"Input references: {references}")
 
     def add(self, *, prediction=None, reference=None):
         """Add one prediction and reference for the metric's stack."""
         example = {"predictions": prediction, "references": reference}
         example = self.info.features.encode_example(example)
         if self.writer is None:
             self._init_writer()
-        self.writer.write(example)
+        try:
+            self.writer.write(example)
+        except pa.ArrowInvalid:
+            raise ValueError(f"Prediction and/or reference don't match the expected format.\n"
+                             f"Expected format: {self.features},\n"
+                             f"Input predictions: {prediction},\n"
+                             f"Input references: {reference}")
 
     def _init_writer(self):
         if self.keep_in_memory:
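Sketch of the new error behaviour in `Metric.add` / `Metric.add_batch`: a write that pyarrow rejects with `ArrowInvalid` is now re-raised as a `ValueError` that echoes the expected feature schema and the offending inputs. The mis-shaped reference below is a hypothetical trigger, not something taken from the commit, and might also be rejected earlier during feature encoding:

    import nlp

    metric = nlp.load_metric('./metrics/sacrebleu')
    try:
        # sacrebleu expects `reference` to be a list of strings; a bare string
        # can make the Arrow write fail, which now surfaces as a ValueError.
        metric.add(prediction='The dog bit the man.', reference='The dog bit the man.')
    except ValueError as err:
        print(err)  # message includes "Expected format: ..." with the metric's features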

0 commit comments
