|
| 1 | +Writing a metric loading script |
| 2 | +============================================= |
| 3 | + |
| 4 | +If you want to use your own metric, or if you would like to share a new metric with the community, for instance on the `HuggingFace Hub <https://huggingface.co/metrics>`__, then you can define a new metric loading script. |
| 5 | + |
| 6 | +This chapter explains how metrics are loaded and how you can write a metric loading script from scratch or adapt an existing one. |
| 7 | + |
| 8 | +.. note:: |
| 9 | + |
| 10 | + You can start from the `template for a metric loading script <https://github.com/huggingface/nlp/blob/master/templates/new_metric_script.py>`__ when writing a new metric loading script. You can find this template in the ``templates`` folder of the GitHub repository. |
| 11 | + |
| 12 | + |
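| 13 | +To create a new metric loading script, one mostly needs to specify two methods in a :class:`nlp.Metric` class, plus an optional third, :func:`nlp.Metric._download_and_prepare`, when the metric needs to download data files: |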
| 14 | + |
| 15 | +- :func:`nlp.Metric._info` which is in charge of specifying the metric metadata as a :obj:`nlp.MetricInfo` dataclass and in particular the :class:`nlp.Features` which define the types of the predictions and the references, |
| 16 | +- :func:`nlp.Metric._compute` which is in charge of computing the actual score(s), given some predictions and references. |
| 17 | + |
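
How these methods fit together can be sketched with a toy stand-in. The classes ``ToyMetricInfo``, ``ToyMetric`` and ``ExactMatch`` below are hypothetical and written only for illustration; they are not the ``nlp`` implementation, and they cover only ``_info`` and ``_compute``.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class ToyMetricInfo:
    """Hypothetical stand-in for nlp.MetricInfo (two fields only)."""
    description: str
    features: Dict[str, type] = field(default_factory=dict)


class ToyMetric:
    """Hypothetical stand-in for nlp.Metric: buffers inputs, then delegates."""

    def __init__(self, config_name: str = "default"):
        self.config_name = config_name
        self.info = self._info()  # metadata comes from the subclass
        self._predictions: List[Any] = []
        self._references: List[Any] = []

    def _info(self) -> ToyMetricInfo:
        raise NotImplementedError

    def _compute(self, predictions, references) -> Dict[str, float]:
        raise NotImplementedError

    def add_batch(self, predictions, references) -> None:
        # Accumulate examples until compute() is called.
        self._predictions.extend(predictions)
        self._references.extend(references)

    def compute(self) -> Dict[str, float]:
        return self._compute(self._predictions, self._references)


class ExactMatch(ToyMetric):
    """Camel-case class name for a metric that would be called 'exact_match'."""

    def _info(self) -> ToyMetricInfo:
        return ToyMetricInfo(
            description="Fraction of predictions equal to their reference.",
            features={"predictions": str, "references": str},
        )

    def _compute(self, predictions, references) -> Dict[str, float]:
        matches = sum(p == r for p, r in zip(predictions, references))
        return {"exact_match": matches / len(predictions)}


metric = ExactMatch()
metric.add_batch(predictions=["a", "b"], references=["a", "c"])
result = metric.compute()
```

The point of the sketch is the split of responsibilities: the subclass only declares metadata and a scoring function, while the base class owns the buffering of examples.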
| 18 | +.. note:: |
| 19 | + |
| 20 | + Note on naming: the metric class should be camel case, while the metric name is its snake case equivalent (e.g. :obj:`class Rouge(nlp.Metric)` for the metric ``rouge``). |
| 21 | + |
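
As a rough illustration of the naming convention, the helper below converts a camel-case class name to its snake-case metric name. ``camel_to_snake`` is a hypothetical function written for this example; it is not claimed to be the conversion used inside the library.

```python
import re


def camel_to_snake(class_name: str) -> str:
    """Insert '_' before every capital letter except the first, then lowercase."""
    return re.sub(r"(?<!^)(?=[A-Z])", "_", class_name).lower()


rouge_name = camel_to_snake("Rouge")
squad_name = camel_to_snake("SquadV2")
```

With this simple rule an all-caps class name would be split letter by letter, which is one reason to write acronyms as ``Xnli`` rather than ``XNLI``.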
| 22 | + |
| 23 | +Adding metric metadata |
| 24 | +---------------------------------- |
| 25 | + |
| 26 | +The :func:`nlp.Metric._info` method is in charge of specifying the metric metadata as a :obj:`nlp.MetricInfo` dataclass and in particular the :class:`nlp.Features` which define the types of the predictions and the references. :class:`nlp.MetricInfo` has a predefined set of attributes and cannot be extended. The full list of attributes can be found in the package reference. |
| 27 | + |
| 28 | +The most important attributes to specify are: |
| 29 | + |
| 30 | +- :attr:`nlp.MetricInfo.features`: a :class:`nlp.Features` instance defining the names and the types of the predictions and references, |
| 31 | +- :attr:`nlp.MetricInfo.description`: a :obj:`str` describing the metric, |
| 32 | +- :attr:`nlp.MetricInfo.citation`: a :obj:`str` containing the citation for the metric in BibTeX format, for inclusion in communications citing the metric, |
| 33 | +- :attr:`nlp.MetricInfo.homepage`: a :obj:`str` containing a URL to the original homepage of the metric, |
| 34 | +- :attr:`nlp.MetricInfo.format`: an optional :obj:`str` specifying the format of the predictions and the references passed to ``_compute``. It can be set to "numpy", "torch", "tensorflow" or "pandas". |
| 35 | + |
| 36 | +Here is for instance the :func:`nlp.Metric._info` for the Sacrebleu metric, which is taken from the `sacrebleu metric loading script <https://github.com/huggingface/nlp/tree/master/metrics/sacrebleu/sacrebleu.py>`__: |
| 37 | + |
| 38 | +.. code-block:: |
| 39 | +
|
| 40 | + def _info(self): |
| 41 | +     return nlp.MetricInfo( |
| 42 | +         description=_DESCRIPTION, |
| 43 | +         citation=_CITATION, |
| 44 | +         homepage="https://github.com/mjpost/sacreBLEU", |
| 45 | +         inputs_description=_KWARGS_DESCRIPTION, |
| 46 | +         features=nlp.Features({ |
| 47 | +             'predictions': nlp.Value('string'), |
| 48 | +             'references': nlp.Sequence(nlp.Value('string')), |
| 49 | +         }), |
| 50 | +         codebase_urls=["https://github.com/mjpost/sacreBLEU"], |
| 51 | +         reference_urls=["https://github.com/mjpost/sacreBLEU", |
| 52 | +                         "https://en.wikipedia.org/wiki/BLEU", |
| 53 | +                         "https://towardsdatascience.com/evaluating-text-output-in-nlp-bleu-at-your-own-risk-e8609665a213"] |
| 54 | +     ) |
| 55 | +
|
| 56 | +
|
| 57 | +The :class:`nlp.Features` define the type of the predictions and the references and can define arbitrary nested objects with fields of various types. More details on the available ``features`` can be found in the guide on features :doc:`features` and in the package reference on :class:`nlp.Features`. Many examples of features can also be found in the various `metric scripts provided on the GitHub repository <https://github.com/huggingface/nlp/tree/master/metrics>`__ and even in `dataset scripts provided on the GitHub repository <https://github.com/huggingface/nlp/tree/master/datasets>`__ or directly inspected on the `🤗nlp viewer <https://huggingface.co/nlp/viewer>`__. |
| 58 | + |
| 59 | +Here are the features of the Sacrebleu metric for instance, which are taken from the `sacrebleu metric loading script <https://github.com/huggingface/nlp/tree/master/metrics/sacrebleu/sacrebleu.py>`__: |
| 60 | + |
| 61 | +.. code-block:: |
| 62 | +
|
| 63 | + nlp.Features({ |
| 64 | +     'predictions': nlp.Value('string'), |
| 65 | +     'references': nlp.Sequence(nlp.Value('string')), |
| 66 | + }), |
| 67 | +
|
| 68 | +We can see that each prediction is a string, and each reference is a sequence of strings. |
| 69 | +Indeed, we can use the metric in the following way: |
| 70 | + |
| 71 | +.. code-block:: |
| 72 | +
|
| 73 | + >>> import nlp |
| 74 | +
|
| 75 | + >>> metric = nlp.load_metric('./metrics/sacrebleu') |
| 76 | + >>> reference_batch = [['The dog bit the man.', 'The dog had bit the man.'], |
| 77 | + ... ['It was not unexpected.', 'No one was surprised.'], |
| 78 | + ... ['The man bit him first.', 'The man had bitten the dog.']] |
| 79 | + >>> sys_batch = ['The dog bit the man.', "It wasn't surprising.", 'The man had just bitten him.'] |
| 80 | + >>> metric.add_batch(predictions=sys_batch, references=reference_batch) |
| 81 | + >>> score = metric.compute() |
| 82 | +
|
| 83 | +
|
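
To make the shape implied by these features concrete, here is a minimal sketch of a check that a batch has one string per prediction and a sequence of strings per reference. ``check_batch`` is a hypothetical helper for illustration; the library performs its own encoding of the inputs according to the declared features.

```python
from typing import Sequence


def check_batch(predictions: Sequence[str], references: Sequence[Sequence[str]]) -> None:
    """Raise if a batch does not match 'string' predictions and 'sequence of string' references."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    for prediction, refs in zip(predictions, references):
        if not isinstance(prediction, str):
            raise TypeError(f"prediction must be a str, got {type(prediction).__name__}")
        # A bare string is itself a sequence, so reject it explicitly.
        if isinstance(refs, str) or not all(isinstance(r, str) for r in refs):
            raise TypeError("each reference must be a sequence of str")


# One prediction with two acceptable references: matches the schema.
check_batch(["The dog bit the man."], [["The dog bit the man.", "The dog had bit the man."]])
```

Passing ``references=["The dog bit the man."]`` (a flat list of strings) would fail this check, because each reference entry is expected to be a sequence.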
| 84 | +Downloading data files |
| 85 | +------------------------------------------------- |
| 86 | + |
| 87 | +The :func:`nlp.Metric._download_and_prepare` method is in charge of downloading the data files, or retrieving them locally, if needed. |
| 88 | + |
| 89 | +This method **takes as input** a :class:`nlp.DownloadManager`, a utility which can be used to download files (or to retrieve them from the local filesystem if they are local files or are already in the cache). |
| 90 | + |
| 91 | +Let's have a look at a simple example of a :func:`nlp.Metric._download_and_prepare` method. We'll take the example of the `bleurt metric loading script <https://github.com/huggingface/nlp/tree/master/metrics/bleurt/bleurt.py>`__: |
| 92 | + |
| 93 | +.. code-block:: |
| 94 | +
|
| 95 | + def _download_and_prepare(self, dl_manager): |
| 96 | +
|
| 97 | +     # check that config name specifies a valid BLEURT model |
| 98 | +     if self.config_name not in CHECKPOINT_URLS.keys(): |
| 99 | +         raise KeyError(f"{self.config_name} model not found. You should supply the name of a model checkpoint for bleurt in {CHECKPOINT_URLS.keys()}") |
| 100 | +
|
| 101 | +     # download the model checkpoint specified by self.config_name and set up the scorer |
| 102 | +     model_path = dl_manager.download_and_extract(CHECKPOINT_URLS[self.config_name]) |
| 103 | +     self.scorer = score.BleurtScorer(os.path.join(model_path, self.config_name)) |
| 104 | +
|
| 105 | +As you can see, this method downloads a model checkpoint depending on the configuration name of the metric. The checkpoint URL is passed to the :func:`nlp.DownloadManager.download_and_extract` method, which takes care of downloading the file or retrieving it from the local file system. :func:`nlp.DownloadManager.download_and_extract` can take as input a single URL/path or a list or dictionary of URLs/paths, and returns an object of the same structure (here just one path, but it could be a list or a dictionary of paths) containing the paths to the local versions of the requested files. This method also takes care of extracting compressed tar, gzip and zip archives. |
| 106 | + |
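
The "same structure in, same structure out" behaviour described above can be illustrated with a small sketch. Both ``map_structure`` and ``fake_download`` are hypothetical names invented for this example, the URLs are placeholders, and nothing is actually downloaded; this is not the ``DownloadManager`` implementation.

```python
def map_structure(fn, urls):
    """Apply fn to every leaf of a str / list / dict structure, keeping its shape."""
    if isinstance(urls, dict):
        return {key: map_structure(fn, value) for key, value in urls.items()}
    if isinstance(urls, (list, tuple)):
        return [map_structure(fn, value) for value in urls]
    return fn(urls)


def fake_download(url: str) -> str:
    """Pretend to download: map a URL to a path in an imaginary cache directory."""
    return "/cache/" + url.rsplit("/", 1)[-1]


single = map_structure(fake_download, "https://example.com/model.zip")
nested = map_structure(
    fake_download,
    {"train": "https://example.com/train.csv", "extra": ["https://example.com/a.txt"]},
)
```

A single URL yields a single path, and a dictionary of URLs yields a dictionary of paths with the same keys, which is the contract a metric script can rely on.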
| 107 | +:func:`nlp.DownloadManager.download_and_extract` can download files from a large set of origins but if your data files are hosted on a special access server, it's also possible to provide a callable which will take care of the downloading process to the ``DownloadManager`` using :func:`nlp.DownloadManager.download_custom`. |
| 108 | + |
| 109 | +.. note:: |
| 110 | + |
| 111 | + In addition to :func:`nlp.DownloadManager.download_and_extract` and :func:`nlp.DownloadManager.download_custom`, the :class:`nlp.DownloadManager` class also provides more fine-grained control over the download and extraction process through several methods including: :func:`nlp.DownloadManager.download`, :func:`nlp.DownloadManager.extract` and :func:`nlp.DownloadManager.iter_archive`. Please refer to the package reference on :class:`nlp.DownloadManager` for details on these methods. |
| 112 | + |
| 113 | + |
| 114 | +Computing the scores |
| 115 | +------------------------------------------------- |
| 116 | + |
| 117 | +The :func:`nlp.Metric._compute` method is in charge of computing the metric scores given predictions and references that are in the format specified in the ``features`` set in :func:`nlp.Metric._info`. |
| 118 | + |
| 119 | +Here again, let's take the simple example of the `xnli metric loading script <https://github.com/huggingface/nlp/tree/master/metrics/xnli/xnli.py>`__: |
| 120 | + |
| 121 | +.. code-block:: |
| 122 | +
|
| 123 | + def simple_accuracy(preds, labels): |
| 124 | +     return (preds == labels).mean() |
| 125 | +
|
| 126 | + class Xnli(nlp.Metric): |
| 127 | +     def _info(self): |
| 128 | +         return nlp.MetricInfo( |
| 129 | +             description=_DESCRIPTION, |
| 130 | +             citation=_CITATION, |
| 131 | +             inputs_description=_KWARGS_DESCRIPTION, |
| 132 | +             features=nlp.Features({ |
| 133 | +                 'predictions': nlp.Value('int64' if self.config_name != 'sts-b' else 'float32'), |
| 134 | +                 'references': nlp.Value('int64' if self.config_name != 'sts-b' else 'float32'), |
| 135 | +             }), |
| 136 | +             codebase_urls=[], |
| 137 | +             reference_urls=[], |
| 138 | +             format='numpy' |
| 139 | +         ) |
| 140 | +
|
| 141 | +     def _compute(self, predictions, references): |
| 142 | +         return {"accuracy": simple_accuracy(predictions, references)} |
| 143 | +
|
| 144 | +Here, the accuracy is computed by the ``simple_accuracy`` function, which compares the predictions and the labels element-wise and averages the result with the numpy ``.mean()`` method. |
| 145 | + |
| 146 | +The ``predictions`` and ``references`` objects passed to ``_compute`` are sequences of integers or floats, and the sequences are formatted as numpy arrays since the format specified in the :obj:`nlp.MetricInfo` object is set to "numpy". |
| 147 | + |
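
For readers less familiar with numpy, the element-wise comparison followed by a mean can be written out in plain Python. ``simple_accuracy_py`` is a hypothetical equivalent written for illustration and operates on ordinary lists rather than numpy arrays.

```python
def simple_accuracy_py(preds, labels):
    """Fraction of positions where the prediction equals the label."""
    if len(preds) != len(labels):
        raise ValueError("preds and labels must have the same length")
    # True counts as 1 and False as 0, so the sum is the number of matches.
    return sum(p == l for p, l in zip(preds, labels)) / len(preds)


# Three of the four predictions match their label.
accuracy = simple_accuracy_py([1, 0, 2, 1], [1, 0, 2, 0])
```

On numpy arrays, ``(preds == labels)`` produces the same sequence of booleans in one vectorised step, and ``.mean()`` performs the sum and division.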
| 148 | +Specifying several metric configurations |
| 149 | +------------------------------------------------- |
| 150 | + |
| 151 | +Sometimes you want to provide several ways of computing the scores. |
| 152 | + |
| 153 | +It is possible to have different configurations for a metric. The configuration name is stored in the :obj:`nlp.Metric.config_name` attribute. The configuration name can be specified by the user when instantiating a metric: |
| 154 | + |
| 155 | +.. code-block:: |
| 156 | +
|
| 157 | + >>> from nlp import load_metric |
| 158 | + >>> metric = load_metric('bleurt', name='bleurt-base-128') |
| 159 | + >>> metric = load_metric('bleurt', name='bleurt-base-512') |
| 160 | +
|
| 161 | +Here, depending on the configuration name, a different checkpoint will be downloaded and used to compute the BLEURT score. |
| 162 | + |
| 163 | +You can access :obj:`nlp.Metric.config_name` from inside :func:`nlp.Metric._info`, :func:`nlp.Metric._download_and_prepare` and :func:`nlp.Metric._compute`. |
| 164 | + |
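
The way a configuration name selects a resource can be sketched as a simple dictionary lookup with an explicit error for unknown names. The ``EXAMPLE_CHECKPOINT_URLS`` mapping and the ``resolve_checkpoint`` function are hypothetical, and the URLs are placeholders rather than real checkpoint locations.

```python
# Placeholder mapping from configuration name to checkpoint URL (not real URLs).
EXAMPLE_CHECKPOINT_URLS = {
    "bleurt-base-128": "https://example.com/checkpoints/bleurt-base-128.zip",
    "bleurt-base-512": "https://example.com/checkpoints/bleurt-base-512.zip",
}


def resolve_checkpoint(config_name: str, checkpoint_urls=EXAMPLE_CHECKPOINT_URLS) -> str:
    """Return the URL for a configuration name, or raise a descriptive KeyError."""
    if config_name not in checkpoint_urls:
        raise KeyError(
            f"{config_name} model not found. "
            f"Available configurations: {sorted(checkpoint_urls)}"
        )
    return checkpoint_urls[config_name]


url_128 = resolve_checkpoint("bleurt-base-128")
```

Failing early with the list of valid names gives the user of the metric an actionable message when a configuration name is misspelled.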
| 165 | +Testing the metric loading script |
| 166 | +------------------------------------------------- |
| 167 | + |
| 168 | +Once you're finished with creating or adapting a metric loading script, you can try it locally by giving the path to the metric loading script: |
| 169 | + |
| 170 | +.. code-block:: |
| 171 | +
|
| 172 | + >>> from nlp import load_metric |
| 173 | + >>> metric = load_metric('PATH/TO/MY/SCRIPT.py') |
| 174 | +
|
| 175 | +If your metric has several configurations you can use the arguments of :func:`nlp.load_metric` accordingly: |
| 176 | + |
| 177 | +.. code-block:: |
| 178 | +
|
| 179 | + >>> from nlp import load_metric |
| 180 | + >>> metric = load_metric('PATH/TO/MY/SCRIPT.py', 'my_configuration') |
| 181 | +
|
| 182 | +
|