Skip to content

Commit d2aadb3

Browse files
committed
Change exception to warning. Only discard invalid labels and examples containing only invalid labels.
1 parent 41e969f commit d2aadb3

File tree

3 files changed

+47
-9
lines changed

3 files changed

+47
-9
lines changed

spacy_llm/tasks/util/span.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type
23

34
import jinja2
@@ -36,12 +37,14 @@ def __init__(
3637
self._case_sensitive_matching = case_sensitive_matching
3738
self._single_match = single_match
3839

39-
self._check_label_consistency()
40+
if self._examples:
41+
self._examples = self._check_label_consistency()
4042

41-
def _check_label_consistency(self) -> None:
42-
"""Checks consistency of labels between examples and defined labels."""
43-
if not self._examples:
44-
return
43+
def _check_label_consistency(self) -> List[SpanExample]:
44+
"""Checks consistency of labels between examples and defined labels. Emits warning on inconsistency.
45+
RETURNS ():
46+
"""
47+
assert self._examples
4548
example_labels = {
4649
self._normalizer(key): key
4750
for example in self._examples
@@ -52,13 +55,30 @@ def _check_label_consistency(self) -> None:
5255
for key in (set(example_labels.keys()) - set(self._label_dict.keys()))
5356
}
5457
if not set(example_labels.keys()) <= set(self._label_dict.keys()):
55-
raise ValueError(
58+
warnings.warn(
5659
f"Examples contain labels that are not specified in the task configuration. The latter contains the "
5760
f"following labels: {sorted(list(set(self._label_dict.values())))}. Labels in examples missing from "
5861
f"the task configuration: {sorted(list(unspecified_labels))}. Please ensure your label specification "
5962
f"and example labels are consistent."
6063
)
6164

65+
# Return examples without non-declared labels. If an example only has undeclared labels, it is discarded.
66+
return [
67+
example
68+
for example in [
69+
SpanExample(
70+
text=example.text,
71+
entities={
72+
label: entities
73+
for label, entities in example.entities.items()
74+
if self._normalizer(label) in self._label_dict
75+
},
76+
)
77+
for example in self._examples
78+
]
79+
if len(example.entities)
80+
]
81+
6282
@property
6383
def labels(self) -> Tuple[str, ...]:
6484
return tuple(self._label_dict.values())

spacy_llm/tests/tasks/examples/ner_inconsistent.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,14 @@
77
- hill
88
- text: spaCy is a great tool
99
entities:
10+
TECH:
11+
- spaCy
12+
- text: Jack and Jill went up the hill and spaCy is a great tool.
13+
entities:
14+
PERSON:
15+
- Jack
16+
- Jill
17+
LOCATION:
18+
- hill
1019
TECH:
1120
- spaCy

spacy_llm/tests/tasks/test_ner.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -851,12 +851,21 @@ def test_label_inconsistency():
851851
"""
852852

853853
config = Config().from_str(cfg)
854-
with pytest.raises(
855-
ValueError,
854+
with pytest.warns(
855+
UserWarning,
856856
match=re.escape(
857857
"Examples contain labels that are not specified in the task configuration. The latter contains the "
858858
"following labels: ['LOCATION', 'PERSON']. Labels in examples missing from the task configuration: "
859859
"['TECH']. Please ensure your label specification and example labels are consistent."
860860
),
861861
):
862-
assemble_from_config(config)
862+
nlp = assemble_from_config(config)
863+
864+
examples = nlp.get_pipe("llm")._task._examples
865+
assert len(examples) == 2
866+
assert examples[0].text == "Jack and Jill went up the hill."
867+
assert examples[0].entities == {"LOCATION": ["hill"], "PERSON": ["Jack", "Jill"]}
868+
assert (
869+
examples[1].text == "Jack and Jill went up the hill and spaCy is a great tool."
870+
)
871+
assert examples[1].entities == {"LOCATION": ["hill"], "PERSON": ["Jack", "Jill"]}

0 commit comments

Comments
 (0)