Skip to content

Commit 83e5c05

Browse files
Fix CI for metrics
1 parent af818af commit 83e5c05

File tree

4 files changed

+6
-3
lines changed

4 files changed

+6
-3
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@
197197
"jiwer",
198198
"langdetect",
199199
"mauve-text",
200-
"nltk",
200+
"nltk<3.8.2",
201201
"rouge_score",
202202
"sacrebleu",
203203
"sacremoses",

tests/test_inspect.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def test_inspect_dataset(path, tmp_path):
2929
@pytest.mark.filterwarnings("ignore:metric_module_factory is deprecated:FutureWarning")
3030
@pytest.mark.parametrize("path", ["accuracy"])
3131
def test_inspect_metric(path, tmp_path):
32-
inspect_metric(path, tmp_path, trust_remote_code=True)
32+
inspect_metric(path, tmp_path, trust_remote_code=True, revision="2.21")
3333
script_name = path + ".py"
3434
assert script_name in os.listdir(tmp_path)
3535
assert "__pycache__" not in os.listdir(tmp_path)

tests/test_load.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,7 @@ def test_GithubMetricModuleFactory_with_internal_import(self):
452452
# "squad_v2" requires additional imports (internal)
453453
factory = GithubMetricModuleFactory(
454454
"squad_v2",
455+
revision="2.21",
455456
download_config=self.download_config,
456457
dynamic_modules_path=self.dynamic_modules_path,
457458
trust_remote_code=True,
@@ -464,6 +465,7 @@ def test_GithubMetricModuleFactory_with_external_import(self):
464465
# "bleu" requires additional imports (external from github)
465466
factory = GithubMetricModuleFactory(
466467
"bleu",
468+
revision="2.21",
467469
download_config=self.download_config,
468470
dynamic_modules_path=self.dynamic_modules_path,
469471
trust_remote_code=True,

tests/test_metric.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from datasets.features import Features, Sequence, Value
1111
from datasets.metric import Metric, MetricInfo
1212

13-
from .utils import require_tf, require_torch
13+
from .utils import require_numpy1_on_windows, require_tf, require_torch
1414

1515

1616
class DummyMetric(Metric):
@@ -433,6 +433,7 @@ def test_input_numpy(self):
433433
self.assertDictEqual(expected_results, metric.compute())
434434
del metric
435435

436+
@require_numpy1_on_windows
436437
@require_torch
437438
def test_input_torch(self):
438439
import torch

0 commit comments

Comments
 (0)