Commit 514db77

bduras, svlandeg, rmitsch authored

feat: create examples on init (#163)

* feat: add NER example gathering from initialization
* feat: add example gathering from init
* feat: add examples from init for REL task
* feat: add examples from init for textcat
* style: consistent naming, updated docstrings
* move example types within files
* rename gather_examples to add_prompt_examples
* black formatting
* rename to prompt_examples to avoid confusion
* refactor initialize loop to avoid having to put all examples in a list
* type fixes
* import fixes
* allow -1 to infer prompt examples from all examples
* black
* fix template rendering
* use set of labels to avoid duplicates internally
* update readme
* sort labels to obtain the same prompts across runs
* fixes
* add functionality to LemmaTask
* fixes
* Update spacy_llm/tasks/lemma.py (Co-authored-by: Sofie Van Landeghem <[email protected]>)
* Fix test config error.
* Fix test config error.
* Update docstring (to restart tests).
* Update readme.
* Rename infer_prompt_examples to n_prompt_examples.

Co-authored-by: svlandeg <[email protected]>
Co-authored-by: Raphael Mitsch <[email protected]>
Co-authored-by: Sofie Van Landeghem <[email protected]>
1 parent 6e1a7f7 commit 514db77

File tree

16 files changed (+382, -147 lines)


README.md

Lines changed: 42 additions & 6 deletions
@@ -305,6 +305,39 @@ Moreover, the task may define an optional [`scorer` method](https://spacy.io/api
It should accept an iterable of `Example`s as input and return a score dictionary.
If the `scorer` method is defined, `spacy-llm` will call it to evaluate the component.

#### Providing examples for few-shot prompts

All built-in tasks support few-shot prompts, i.e. including examples in a prompt. Examples can be supplied in two ways:
(1) as a separate file containing only examples, or (2) by initializing `llm` with a `get_examples()` callback (like any
other spaCy pipeline component).

##### (1) Few-shot example file

A file containing examples for few-shot prompting can be configured like this:

```ini
[components.llm.task]
@llm_tasks = "spacy.NER.v2"
labels = PERSON,ORGANISATION,LOCATION
[components.llm.task.examples]
@misc = "spacy.FewShotReader.v1"
path = "ner_examples.yml"
```

The supplied file has to conform to the format expected by the required task (see the task documentation further down).
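For the NER task configured above, such a file could look roughly like this. This is a hedged sketch inferred from the `SpanExample` fields introduced in this commit (`text` plus an `entities` mapping from label to entity texts), not a file from the repo; consult the task documentation for the authoritative format:

```yaml
- text: "Jack and Jill went up the hill."
  entities:
    PERSON:
      - Jack
      - Jill
- text: "Jill fetched a pail of water in Boston."
  entities:
    PERSON:
      - Jill
    LOCATION:
      - Boston
```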

##### (2) Initializing the `llm` component with a `get_examples()` callback

Alternatively, you can initialize your `nlp` pipeline by providing a `get_examples` callback for
[`nlp.initialize`](https://spacy.io/api/language#initialize) and setting `n_prompt_examples` to a positive number to
automatically fetch a few examples for few-shot learning. Set `n_prompt_examples` to `-1` to use all examples as
part of the few-shot learning prompt.

```ini
[initialize.components.llm]
n_prompt_examples = 3
```
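The effect of `n_prompt_examples` can be sketched in plain Python. This is an illustrative stand-alone sketch of the selection logic added to the tasks' `initialize` methods in this commit; `collect_prompt_examples` is a hypothetical name, not part of the spacy-llm API:

```python
def collect_prompt_examples(examples, n_prompt_examples=0):
    """Keep up to n_prompt_examples items for the few-shot prompt.

    0 keeps none (zero-shot), a positive number caps the count,
    and -1 keeps every example -- mirroring the initialize() loop.
    """
    prompt_examples = []
    for eg in examples:
        if n_prompt_examples < 0 or len(prompt_examples) < n_prompt_examples:
            prompt_examples.append(eg)
    return prompt_examples
```

With `n_prompt_examples = 3`, only the first three examples yielded by `get_examples()` end up in the prompt.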
#### <kbd>function</kbd> `task.generate_prompts`

Takes a collection of documents, and returns a collection of "prompts", which can be of type `Any`.
@@ -389,7 +422,10 @@ labels = PERSON,ORGANISATION,LOCATION
 path = "ner_examples.yml"
 ```
 
-If you don't have specific examples to provide to the LLM, you can write definitions for each label and provide them via the `label_definitions` argument. This lets you tell the LLM exactly what you're looking for rather than relying on the LLM to interpret its task given just the label name. Label descriptions are freeform so you can write whatever you want here, but through some experiments a brief description along with some examples and counter examples seems to work quite well.
+You can also write definitions for each label and provide them via the `label_definitions` argument. This lets you tell the LLM exactly what you're looking for rather than relying on the LLM to interpret its task given just the label name. Label descriptions are freeform so you can write whatever you want here, but through some experiments a brief description along with some examples and counter examples seems to work quite well.
 
 ```ini
 [components.llm.task]
@@ -627,15 +663,11 @@ labels = ["LivesIn", "Visits"]
 To perform few-shot learning, you can write down a few examples in a separate file, and provide these to be injected into the prompt to the LLM.
 The default reader `spacy.FewShotReader.v1` supports `.yml`, `.yaml`, `.json` and `.jsonl`.
 
-```json
+```jsonl
 {"text": "Laura bought a house in Boston with her husband Mark.", "ents": [{"start_char": 0, "end_char": 5, "label": "PERSON"}, {"start_char": 24, "end_char": 30, "label": "GPE"}, {"start_char": 48, "end_char": 52, "label": "PERSON"}], "relations": [{"dep": 0, "dest": 1, "relation": "LivesIn"}, {"dep": 2, "dest": 1, "relation": "LivesIn"}]}
 {"text": "Michael travelled through South America by bike.", "ents": [{"start_char": 0, "end_char": 7, "label": "PERSON"}, {"start_char": 26, "end_char": 39, "label": "LOC"}], "relations": [{"dep": 0, "dest": 1, "relation": "Visits"}]}
 ```
 
-Note: the REL task relies on pre-extracted entities to make its prediction.
-Hence, you'll need to add a component that populates `doc.ents` with recognized
-spans to your spaCy pipeline and put it _before_ the REL component.
-
 ```ini
 [components.llm.task]
 @llm_tasks = "spacy.REL.v1"
@@ -645,6 +677,10 @@ labels = ["LivesIn", "Visits"]
 path = "rel_examples.jsonl"
 ```
 
+Note: the REL task relies on pre-extracted entities to make its prediction.
+Hence, you'll need to add a component that populates `doc.ents` with recognized
+spans to your spaCy pipeline and put it _before_ the REL component.
+
 #### spacy.Lemma.v1
 
 The `Lemma.v1` task lemmatizes the provided text and updates the `lemma_` attribute in the doc's tokens accordingly.

spacy_llm/pipeline/llm.py

Lines changed: 1 addition & 1 deletion
@@ -287,7 +287,7 @@ def to_disk(
         if isinstance(self._model, Serializable):
             serialize["model"] = lambda p: self._model.to_disk(p, exclude=exclude)  # type: ignore[attr-defined]
 
-        return util.to_disk(path, serialize, exclude)
+        util.to_disk(path, serialize, exclude)
 
     def from_disk(
         self, path: Path, *, exclude: Tuple[str] = cast(Tuple[str], tuple())

spacy_llm/tasks/lemma.py

Lines changed: 22 additions & 8 deletions
@@ -1,6 +1,7 @@
 from typing import Any, Callable, Dict, Iterable, List, Optional, Type
 
 import jinja2
+from pydantic import BaseModel
 from spacy.language import Language
 from spacy.scorer import Scorer
 from spacy.tokens import Doc
@@ -10,12 +11,15 @@
 from ..ty import ExamplesConfigType
 from .templates import read_template
 from .util import SerializableTask
-from .util.examples import LemmaExample
-from .util.serialization import ExampleType
 
 _DEFAULT_LEMMA_TEMPLATE_V1 = read_template("lemma")
 
 
+class LemmaExample(BaseModel):
+    text: str
+    lemmas: List[Dict[str, str]]
+
+
 @registry.llm_tasks("spacy.Lemma.v1")
 def make_lemma_task(
     template: str = _DEFAULT_LEMMA_TEMPLATE_V1,
@@ -29,11 +33,11 @@ def make_lemma_task(
     passed, then zero-shot learning will be used.
     """
     raw_examples = examples() if callable(examples) else examples
-    span_examples = (
+    lemma_examples = (
         [LemmaExample(**eg) for eg in raw_examples] if raw_examples else None
     )
 
-    return LemmaTask(template=template, examples=span_examples)
+    return LemmaTask(template=template, examples=lemma_examples)
 
 
 class LemmaTask(SerializableTask[LemmaExample]):
@@ -50,28 +54,33 @@ def __init__(
         passed, then zero-shot learning will be used.
         """
         self._template = template
-        self._examples = examples
+        self._prompt_examples = examples or []
 
     def initialize(
         self,
         get_examples: Callable[[], Iterable["Example"]],
         nlp: Language,
+        n_prompt_examples: int = 0,
         **kwargs: Any,
     ) -> None:
         """Nothing to initialize for the LEMMA task.
         get_examples (Callable[[], Iterable["Example"]]): Callable that provides examples
             for initialization.
         nlp (Language): Language instance.
-        labels (List[str]): Optional list of labels.
+        n_prompt_examples (int): How many prompt examples to infer from the provided Example objects.
+            0 by default. Takes all examples if set to -1.
         """
+        for eg in get_examples():
+            if n_prompt_examples < 0 or len(self._prompt_examples) < n_prompt_examples:
+                self._prompt_examples.append(self._create_prompt_example(eg))
 
     def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[str]:
         environment = jinja2.Environment()
         _template = environment.from_string(self._template)
         for doc in docs:
             prompt = _template.render(
                 text=doc.text,
-                examples=self._examples,
+                examples=self._prompt_examples,
             )
             yield prompt
 
@@ -114,5 +123,10 @@ def _cfg_keys(self) -> List[str]:
         return ["_template"]
 
     @property
-    def _Example(self) -> Type[ExampleType]:
+    def _Example(self) -> Type[LemmaExample]:
         return LemmaExample
+
+    def _create_prompt_example(self, example: Example) -> LemmaExample:
+        """Create a lemma prompt example from a spaCy example."""
+        lemma_dict = [{t.text: t.lemma_} for t in example.reference]
+        return LemmaExample(text=example.reference.text, lemmas=lemma_dict)
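The payload built by `_create_prompt_example` can be sketched without spaCy. Here a plain dict stands in for the pydantic `LemmaExample`, and `make_lemma_prompt_example` is a hypothetical helper that takes `(token, lemma)` pairs instead of a spaCy `Example`:

```python
def make_lemma_prompt_example(token_lemma_pairs):
    """Build a LemmaExample-shaped payload: the text plus one
    single-entry {token: lemma} dict per token."""
    return {
        "text": " ".join(text for text, _ in token_lemma_pairs),
        "lemmas": [{text: lemma} for text, lemma in token_lemma_pairs],
    }
```

Joining tokens with spaces only approximates `example.reference.text`, which preserves the document's original whitespace.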

spacy_llm/tasks/ner.py

Lines changed: 28 additions & 14 deletions
@@ -1,3 +1,4 @@
+from collections import defaultdict
 from typing import Any, Callable, Dict, Iterable, List, Optional, Union
 
 from spacy.language import Language
@@ -10,8 +11,8 @@
 from ..registry import registry
 from ..ty import ExamplesConfigType
 from ..util import split_labels
+from .span import SpanExample, SpanTask
 from .templates import read_template
-from .util import SpanExample, SpanTask
 
 _DEFAULT_NER_TEMPLATE_V1 = read_template("ner")
 _DEFAULT_NER_TEMPLATE_V2 = read_template("ner.v2")
@@ -52,7 +53,7 @@ def make_ner_task(
     return NERTask(
         labels=labels_list,
         template=_DEFAULT_NER_TEMPLATE_V1,
-        examples=span_examples,
+        prompt_examples=span_examples,
         normalizer=normalizer,
         alignment_mode=alignment_mode,
         case_sensitive_matching=case_sensitive_matching,
@@ -98,7 +99,7 @@ def make_ner_task_v2(
         labels=labels_list,
         template=template,
         label_definitions=label_definitions,
-        examples=span_examples,
+        prompt_examples=span_examples,
         normalizer=normalizer,
         alignment_mode=alignment_mode,
         case_sensitive_matching=case_sensitive_matching,
@@ -112,7 +113,7 @@ def __init__(
         labels: List[str] = [],
         template: str = _DEFAULT_NER_TEMPLATE_V2,
         label_definitions: Optional[Dict[str, str]] = None,
-        examples: Optional[List[SpanExample]] = None,
+        prompt_examples: Optional[List[SpanExample]] = None,
         normalizer: Optional[Callable[[str], str]] = None,
         alignment_mode: Literal["strict", "contract", "expand"] = "contract",
         case_sensitive_matching: bool = False,
@@ -140,7 +141,7 @@ def __init__(
             labels=labels,
             template=template,
             label_definitions=label_definitions,
-            examples=examples,
+            prompt_examples=prompt_examples,
             normalizer=normalizer,
             alignment_mode=alignment_mode,
             case_sensitive_matching=case_sensitive_matching,
@@ -152,6 +153,7 @@ def initialize(
         get_examples: Callable[[], Iterable["Example"]],
         nlp: Language,
         labels: List[str] = [],
+        n_prompt_examples: int = 0,
         **kwargs: Any,
     ) -> None:
         """Initialize the NER task, by auto-discovering labels.
@@ -166,22 +168,26 @@
             for initialization.
         nlp (Language): Language instance.
         labels (List[str]): Optional list of labels.
+        n_prompt_examples (int): How many prompt examples to infer from the Example objects.
+            0 by default. Takes all examples if set to -1.
         """
-
-        examples = get_examples()
-
         if not labels:
             labels = list(self._label_dict.values())
+        infer_labels = not labels
 
-        if not labels:
-            label_set = set()
+        if infer_labels:
+            labels = []
 
-            for eg in examples:
+        for eg in get_examples():
+            if infer_labels:
                 for ent in eg.reference.ents:
-                    label_set.add(ent.label_)
-            labels = list(label_set)
+                    labels.append(ent.label_)
+            if n_prompt_examples < 0 or len(self._prompt_examples) < n_prompt_examples:
+                self._prompt_examples.append(self._create_prompt_example(eg))
 
-        self._label_dict = {self._normalizer(label): label for label in labels}
+        self._label_dict = {
+            self._normalizer(label): label for label in sorted(set(labels))
+        }
 
     def assign_spans(
         self,
@@ -196,3 +202,11 @@ def scorer(
         examples: Iterable[Example],
     ) -> Dict[str, Any]:
         return get_ner_prf(examples)
+
+    def _create_prompt_example(self, example: Example) -> SpanExample:
+        """Create an NER prompt example from a spaCy example."""
+        entities = defaultdict(list)
+        for ent in example.reference.ents:
+            entities[ent.label_].append(ent.text)
+
+        return SpanExample(text=example.reference.text, entities=entities)
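Two ideas in the NER changes, grouping entity texts by label (as in `_create_prompt_example`) and sorting the deduplicated labels so the same prompt is produced across runs, can be sketched in isolation. `group_entities` is a hypothetical helper operating on plain `(text, label)` pairs rather than spaCy objects:

```python
from collections import defaultdict


def group_entities(ents):
    """Group entity texts by label and collect a deterministic label list.

    ents: iterable of (text, label) pairs.
    Returns ({label: [texts]}, sorted unique labels).
    """
    entities = defaultdict(list)
    labels = []
    for text, label in ents:
        entities[label].append(text)
        labels.append(label)
    # sorted(set(...)) deduplicates and fixes the order, mirroring the
    # _label_dict comprehension in NERTask.initialize.
    return dict(entities), sorted(set(labels))
```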
