Commit dfd43fe

rmitsch and svlandeg authored
Detect label inconsistency in SpanTask (#183)
* Check for label consistency in span tasks.
* Relax label comparison to sub-/superset one.
* Fix set comparison.
* Update spacy_llm/tests/tasks/examples/ner_inconsistent.yml
  Co-authored-by: Sofie Van Landeghem <[email protected]>
* Expand error message.
* Fix tests.
* Change exception to warning. Only discard invalid labels and examples containing only invalid labels.
* Fix test config error.
* Fix test config error.
* Incorporate feedback.
* Rename self._examples to self._prompt_examples.
* Fix non-REL test failures.
* Fix REL issue.

Co-authored-by: Sofie Van Landeghem <[email protected]>
1 parent 514db77 commit dfd43fe
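Note: the upshot of the change is that an inconsistent label in a few-shot example no longer raises an exception. Undeclared labels are stripped from the prompt examples, any example left without valid labels is discarded, and a warning lists the offending labels. A minimal sketch of that filtering rule, using plain dicts and a lowercasing normalizer as stand-ins for spacy-llm's SpanExample class and configured normalizer (filter_prompt_examples is a hypothetical name, not the library's API):

# Illustrative sketch only; the dict-based example format and
# filter_prompt_examples are hypothetical stand-ins, not spacy-llm's API.
import warnings
from typing import Dict, List

def filter_prompt_examples(examples: List[Dict], labels: List[str]) -> List[Dict]:
    declared = {label.lower() for label in labels}  # stand-in for the normalizer
    seen = {key for ex in examples for key in ex["entities"]}
    unspecified = sorted(key for key in seen if key.lower() not in declared)
    if unspecified:
        warnings.warn(
            f"Examples contain labels that are not specified in the task "
            f"configuration: {unspecified}."
        )
    # Strip undeclared labels from each example ...
    filtered = [
        {
            "text": ex["text"],
            "entities": {
                label: ents
                for label, ents in ex["entities"].items()
                if label.lower() in declared
            },
        }
        for ex in examples
    ]
    # ... and discard any example with no valid labels left.
    return [ex for ex in filtered if ex["entities"]]

examples = [
    {"text": "spaCy is a great tool", "entities": {"TECH": ["spaCy"]}},
    {
        "text": "Jack and Jill went up the hill and spaCy is a great tool.",
        "entities": {"PERSON": ["Jack", "Jill"], "TECH": ["spaCy"]},
    },
]
# The first example is discarded (only the undeclared TECH label);
# the second survives with TECH stripped.
print(filter_prompt_examples(examples, ["PERSON", "LOCATION"]))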

File tree

7 files changed (+121, -6 lines)


spacy_llm/tasks/rel.py

Lines changed: 1 addition & 2 deletions

@@ -3,7 +3,7 @@
 import jinja2
 from pydantic import BaseModel, Field, ValidationError, validator
 from spacy.language import Language
-from spacy.tokens import Doc, Span
+from spacy.tokens import Doc
 from spacy.training import Example
 from wasabi import msg

@@ -45,7 +45,6 @@ def _preannotate(doc: Union[Doc, RELExample]) -> str:
     text = doc.text

     for i, ent in enumerate(doc.ents):
-        assert isinstance(ent, Span)
         end = ent.end_char
         before, after = text[: end + offset], text[end + offset :]
spacy_llm/tasks/span.py

Lines changed: 43 additions & 0 deletions

@@ -1,3 +1,4 @@
+import warnings
 from typing import Callable, Dict, Iterable, List, Optional, Tuple, Type

 import jinja2
@@ -43,6 +44,48 @@ def __init__(
         self._case_sensitive_matching = case_sensitive_matching
         self._single_match = single_match

+        if self._prompt_examples:
+            self._prompt_examples = self._check_label_consistency()
+
+    def _check_label_consistency(self) -> List[SpanExample]:
+        """Checks consistency of labels between examples and defined labels. Emits warning on inconsistency.
+        RETURNS (List[SpanExample]): List of SpanExamples with valid labels.
+        """
+        assert self._prompt_examples
+        example_labels = {
+            self._normalizer(key): key
+            for example in self._prompt_examples
+            for key in example.entities
+        }
+        unspecified_labels = {
+            example_labels[key]
+            for key in (set(example_labels.keys()) - set(self._label_dict.keys()))
+        }
+        if not set(example_labels.keys()) <= set(self._label_dict.keys()):
+            warnings.warn(
+                f"Examples contain labels that are not specified in the task configuration. The latter contains the "
+                f"following labels: {sorted(list(set(self._label_dict.values())))}. Labels in examples missing from "
+                f"the task configuration: {sorted(list(unspecified_labels))}. Please ensure your label specification "
+                f"and example labels are consistent."
+            )
+
+        # Return examples without non-declared labels. If an example only has undeclared labels, it is discarded.
+        return [
+            example
+            for example in [
+                SpanExample(
+                    text=example.text,
+                    entities={
+                        label: entities
+                        for label, entities in example.entities.items()
+                        if self._normalizer(label) in self._label_dict
+                    },
+                )
+                for example in self._prompt_examples
+            ]
+            if len(example.entities)
+        ]
+
     @property
     def labels(self) -> Tuple[str, ...]:
         return tuple(self._label_dict.values())
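The consistency check relies on Python's set subset operator rather than strict equality (the "relax label comparison to sub-/superset one" item in the commit message): A <= B holds when every element of A also occurs in B, so examples may legitimately use only a subset of the configured labels. A quick illustration:

# Subset semantics of the check in _check_label_consistency: only labels
# that appear in examples but not in the configuration trigger the warning.
assert {"person"} <= {"person", "location"}               # consistent, no warning
assert not {"person", "tech"} <= {"person", "location"}   # inconsistent, warns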
spacy_llm/tests/tasks/examples/ner_inconsistent.yml

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+- text: Jack and Jill went up the hill.
+  entities:
+    PERSON:
+      - Jack
+      - Jill
+    LOCATION:
+      - hill
+- text: spaCy is a great tool
+  entities:
+    TECH:
+      - spaCy
+- text: Jack and Jill went up the hill and spaCy is a great tool.
+  entities:
+    PERSON:
+      - Jack
+      - Jill
+    LOCATION:
+      - hill
+    TECH:
+      - spaCy
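With the task configured for PERSON and LOCATION only, as in the test below, the second example carries nothing but the undeclared TECH label and is discarded outright, while the third survives with its TECH entries stripped; the new test asserts exactly this outcome.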

spacy_llm/tests/tasks/test_ner.py

Lines changed: 54 additions & 1 deletion

@@ -1,4 +1,5 @@
 import json
+import re
 from pathlib import Path

 import pytest
@@ -200,7 +201,7 @@ def test_ner_config(cfg_string, request):
     labels = split_labels(labels)
     task = pipe.task
     assert isinstance(task, Labeled)
-    assert task.labels == tuple(labels)
+    assert sorted(task.labels) == sorted(tuple(labels))
     assert pipe.labels == task.labels
     assert nlp.pipe_labels["llm"] == list(task.labels)

@@ -827,3 +828,55 @@ def test_ner_to_disk(noop_config, tmp_path: Path):
     nlp2.from_disk(path)

     assert task1._label_dict == task2._label_dict == labels
+
+
+def test_label_inconsistency():
+    """Test whether inconsistency between specified labels and labels in examples is detected."""
+    cfg = f"""
+[nlp]
+lang = "en"
+pipeline = ["llm"]
+
+[components]
+
+[components.llm]
+factory = "llm"
+
+[components.llm.task]
+@llm_tasks = "spacy.NER.v2"
+labels = ["PERSON", "LOCATION"]
+
+[components.llm.task.examples]
+@misc = "spacy.FewShotReader.v1"
+path = {str((Path(__file__).parent / "examples" / "ner_inconsistent.yml"))}
+
+[components.llm.model]
+@llm_models = "test.NoOpModel.v1"
+"""
+
+    config = Config().from_str(cfg)
+    with pytest.warns(
+        UserWarning,
+        match=re.escape(
+            "Examples contain labels that are not specified in the task configuration. The latter contains the "
+            "following labels: ['LOCATION', 'PERSON']. Labels in examples missing from the task configuration: "
+            "['TECH']. Please ensure your label specification and example labels are consistent."
+        ),
+    ):
+        nlp = assemble_from_config(config)
+
+    prompt_examples = nlp.get_pipe("llm")._task._prompt_examples
+    assert len(prompt_examples) == 2
+    assert prompt_examples[0].text == "Jack and Jill went up the hill."
+    assert prompt_examples[0].entities == {
+        "LOCATION": ["hill"],
+        "PERSON": ["Jack", "Jill"],
+    }
+    assert (
+        prompt_examples[1].text
+        == "Jack and Jill went up the hill and spaCy is a great tool."
+    )
+    assert prompt_examples[1].entities == {
+        "LOCATION": ["hill"],
+        "PERSON": ["Jack", "Jill"],
+    }
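The test drives the check through the assemble_from_config helper; in regular use the same UserWarning is emitted when a pipeline is assembled from a config file. A sketch, assuming a config on disk equivalent to the one embedded in the test (the file name is hypothetical):

# Assembling a pipeline whose few-shot examples use labels missing from the
# task config emits the warning above; "config.cfg" is a hypothetical path.
from spacy_llm.util import assemble

nlp = assemble("config.cfg")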

spacy_llm/tests/tasks/test_rel.py

Lines changed: 1 addition & 1 deletion

@@ -133,7 +133,7 @@ def test_rel_config(cfg_string, request: FixtureRequest):

 @pytest.mark.external
 @pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available")
-@pytest.mark.parametrize("cfg_string", ["zeroshot_cfg_string", "fewshot_cfg_string"])
+@pytest.mark.parametrize("cfg_string", ["fewshot_cfg_string"])  # "zeroshot_cfg_string",
 def test_rel_predict(task, cfg_string, request):
     """Use OpenAI to get REL results.
     Note that this test may fail randomly, as the LLM's output is unguaranteed to be consistent/predictable

spacy_llm/tests/tasks/test_spancat.py

Lines changed: 1 addition & 1 deletion

@@ -92,7 +92,7 @@ def test_spancat_config(cfg_string, request):
     labels = split_labels(labels)
     task = pipe.task
     assert isinstance(task, Labeled)
-    assert task.labels == tuple(labels)
+    assert sorted(task.labels) == sorted(tuple(labels))
     assert pipe.labels == task.labels
     assert nlp.pipe_labels["llm"] == list(task.labels)

spacy_llm/tests/tasks/test_textcat.py

Lines changed: 1 addition & 1 deletion

@@ -209,7 +209,7 @@ def test_textcat_config(task, cfg_string, request):
     labels = split_labels(labels)
     task = pipe.task
     assert isinstance(task, Labeled)
-    assert task.labels == tuple(labels)
+    assert sorted(task.labels) == sorted(tuple(labels))
     assert pipe.labels == task.labels
     assert nlp.pipe_labels["llm"] == list(task.labels)
