Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions spacy_llm/pipeline/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,15 @@ def labels(self) -> Tuple[str, ...]:
labels = self._task.labels
return labels

def add_label(self, label: str, label_definition: Optional[str] = None) -> int:
    """Add a label to the wrapped task, optionally with a definition.

    label (str): Label to add.
    label_definition (Optional[str]): Optional definition for the label,
        forwarded unchanged to the task.
    RETURNS (int): The value returned by the task's own `add_label`.
    RAISES (ValueError): If the task is not a `LabeledTask`.
    """
    if not isinstance(self._task, LabeledTask):
        raise ValueError("The task of this LLM component does not have labels.")
    # Always forward the definition; the task decides what to do with None.
    return self._task.add_label(label, label_definition)

def clear(self) -> None:
    """Reset the labels of the wrapped task.

    RAISES (ValueError): If the task is not a `LabeledTask`.
    """
    task = self._task
    if isinstance(task, LabeledTask):
        return task.clear()
    raise ValueError("The task of this LLM component does not have labels.")

@property
def task(self) -> LLMTask:
Expand Down
13 changes: 12 additions & 1 deletion spacy_llm/tasks/builtin_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,14 +312,25 @@ def _extract_labels_from_example(self, example: Example) -> List[str]:
def labels(self) -> Tuple[str, ...]:
    """Return the values of the internal label mapping as a tuple."""
    registered = self._label_dict.values()
    return tuple(registered)

def add_label(self, label: str, label_definition: Optional[str] = None) -> int:
    """Add a label to the task, optionally together with a definition.

    label (str): Label to add.
    label_definition (Optional[str]): Definition to store for the label.
        Nothing is stored when this is None.
    RETURNS (int): 0 if the label was already present, otherwise 1.
    RAISES (ValueError): If `label` is not a string.
    """
    if not isinstance(label, str):
        raise ValueError(Errors.E187)
    # NOTE(review): the duplicate check compares the raw label against
    # `self.labels`, while the mapping below is keyed by the normalized
    # label -- confirm that this asymmetry is intended.
    if label in self.labels:
        return 0
    self._label_dict[self._normalizer(label)] = label
    if label_definition is not None:
        # The definitions mapping is created lazily on the first definition.
        if self._label_definitions is None:
            self._label_definitions = {}
        self._label_definitions[label] = label_definition
    return 1

def clear(self) -> None:
    """Reset all labels."""
    # Drop the label mapping and any stored definitions in one step.
    self._label_dict, self._label_definitions = {}, None

@property
def normalizer(self) -> Callable[[str], str]:
    """Return the callable stored as this task's label normalizer."""
    normalize_fn = self._normalizer
    return normalize_fn
Expand Down
34 changes: 34 additions & 0 deletions spacy_llm/tests/tasks/test_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,7 +992,41 @@ def test_add_label():
doc = nlp(text)
assert len(doc.ents) == 0

for label, definition in [
("PERSON", "Every person with the name Jack"),
("LOCATION", None),
]:
llm.add_label(label, definition)
doc = nlp(text)
assert len(doc.ents) == 2


@pytest.mark.external
@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available")
def test_clear_label():
    """Check that `clear()` removes the labels previously added to the component."""
    nlp = spacy.blank("en")
    llm = nlp.add_pipe(
        "llm",
        config={
            "task": {
                "@llm_tasks": "spacy.NER.v3",
            },
            "model": {
                "@llm_models": "spacy.GPT-3-5.v1",
            },
        },
    )

    nlp.initialize()
    text = "Jack and Jill visited France."

    # Baseline before any label is added; this is the same label-free state
    # as after `clear()` below, so the same expectation applies.
    doc = nlp(text)
    assert len(doc.ents) == 0

    for label in ["PERSON", "LOCATION"]:
        llm.add_label(label)
    doc = nlp(text)
    assert len(doc.ents) == 3

    llm.clear()

    doc = nlp(text)
    assert len(doc.ents) == 0
5 changes: 4 additions & 1 deletion spacy_llm/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,10 @@ class LabeledTask(Protocol):
def labels(self) -> Tuple[str, ...]:
    """Protocol stub: labels exposed by the task, as a tuple of strings."""
    ...

def add_label(self, label: str, label_definition: Optional[str] = None) -> int:
    """Protocol stub: add a label, optionally with a definition.

    label (str): Label to add.
    label_definition (Optional[str]): Optional definition for the label.
    RETURNS (int): Implementation-defined integer result.
    """
    ...

def clear(self) -> None:
    """Protocol stub: reset the task's labels."""
    ...


Expand Down