Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
b3507f4
chore: ignore notebooks
bdura May 23, 2023
b50b6e4
feat: add rel task
bdura May 23, 2023
c40880a
fix: rel template
bdura May 23, 2023
41f3bd2
docs: add rel example
bdura May 23, 2023
e84f920
feat: clean and simplify rel.py
bdura May 23, 2023
91f24ec
docs: mention Doc._.rel in docstrings
bdura May 23, 2023
77ee80c
feat: update rel
bdura May 23, 2023
d817db9
docs: update rel example
bdura May 23, 2023
0a97ba5
docs: document REL task in readme
bdura May 23, 2023
7cf79ca
Update usage_examples/rel_openai/README.md
bdura May 23, 2023
9b76090
docs: use en_core_web_md instead of sm
bdura May 23, 2023
aa09a91
test: add usage test
bdura May 23, 2023
8061dd1
fix: mypy issue
bdura May 23, 2023
dba9855
Merge remote-tracking branch 'origin/main' into feat/rel
bdura May 23, 2023
06fb413
test: rel configuration
bdura May 23, 2023
116d216
ci: install en_core_web_md
bdura May 23, 2023
1493cc5
fix: test_rel_config assert
bdura May 23, 2023
3a257f7
test: rel prediction
bdura May 23, 2023
fd39ded
Merge branch 'main' into feat/rel
rmitsch May 24, 2023
4f1113a
Adjust tests w.r.t. new format.
rmitsch May 24, 2023
bf7dd06
ci: install en_core_web_md in external tests
bdura May 24, 2023
fbf1f70
docs: update description of NER-capable
bdura May 24, 2023
6c6fcf1
fix: check custom rel extension at every call
bdura May 24, 2023
a6b225d
Merge remote-tracking branch 'origin/main' into feat/rel
bdura May 24, 2023
757a638
feat: update rel example to use util.assemble
bdura May 24, 2023
3f96586
chore: move the [initialize] section to the end of the rel example co…
bdura May 24, 2023
f72fede
feat: assemble accepts a config or a path
bdura May 24, 2023
b7e4488
use assemble in test_rel
bdura May 24, 2023
2709af5
feat: remove useless type hint
bdura May 24, 2023
ed7d4a9
fix: add initialization
bdura May 24, 2023
cfb2289
fix: cleaner util
bdura May 24, 2023
23dfc9c
Update util.py
adrianeboyd May 24, 2023
2a8a22a
Update README.md
rmitsch May 25, 2023
e4724ef
Merge remote-tracking branch 'origin/main' into feat/rel
bdura May 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ target/

# Jupyter Notebook
.ipynb_checkpoints
*.ipynb

# IPython
profile_default/
Expand Down
9 changes: 8 additions & 1 deletion spacy_llm/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from .ner import NERTask
from .noop import NoopTask
from .rel import RELTask
from .spancat import SpanCatTask
from .textcat import TextCatTask

__all__ = ["NoopTask", "NERTask", "TextCatTask", "SpanCatTask"]
__all__ = [
"NoopTask",
"NERTask",
"TextCatTask",
"SpanCatTask",
"RELTask",
]
102 changes: 102 additions & 0 deletions spacy_llm/tasks/rel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from typing import Callable, Dict, Iterable, List, Optional

import jinja2
from pydantic import BaseModel, ConstrainedInt, ValidationError
from spacy.tokens import Doc
from wasabi import msg

from ..registry import lowercase_normalizer, registry
from .templates import read_template


class EntityId(ConstrainedInt):
    """Integer entity id that also accepts the ``ENT<id>`` string notation.

    The prompt marks entities as ``[ENT<id>:<label>]``, so a model may answer
    with ``"ENT0"`` instead of ``0``. ``clean_ent`` runs before the inherited
    integer validators and removes that prefix so both spellings validate.
    """

    @classmethod
    def __get_validators__(cls):
        """Yield ``clean_ent`` first, then every validator of the parent class."""
        yield cls.clean_ent
        yield from super().__get_validators__()

    @classmethod
    def clean_ent(cls, value):
        """Remove a leading ``ENT`` marker from string values.

        Non-string values are returned unchanged and left to the inherited
        validators.
        """
        if isinstance(value, str):
            value = value.strip()
            # ``str.strip("ENT")`` treats its argument as a *set of characters*
            # and would eat any run of E/N/T on either side; only the literal
            # prefix must be removed.
            if value.startswith("ENT"):
                value = value[len("ENT") :]
        return value


class RelationItem(BaseModel):
    """One relation between two entities, parsed from a single JSON line."""

    # Id of the first entity of the relation; "ENT<id>" strings are accepted
    # as well (see EntityId).
    dep: EntityId
    # Id of the second entity of the relation.
    dest: EntityId
    # Relation label as returned in the "relation" field.
    relation: str


class RELExample(BaseModel):
    """A few-shot example: a text together with the relations it contains."""

    # Example text rendered into the prompt.
    text: str
    # Relations expected for this text.
    relations: List[RelationItem]


def _preannotate(doc: Doc) -> str:
offset = 0

text = doc.text

for i, ent in enumerate(doc.ents):
end = ent.end_char
before, after = text[: end + offset], text[end + offset :]

annotation = f"[ENT{i}:{ent.label_}]"
offset += len(annotation)

text = f"{before}{annotation}{after}"

return text


@registry.llm_tasks("spacy.REL.v1")
class RELTask:
    """Relation extraction task.

    Renders a prompt asking for relations between the entities already set on
    a doc, and stores the parsed answer as a list of ``RelationItem`` objects
    in the custom attribute ``Doc._.rel``.
    """

    _TEMPLATE_STR = read_template("rel")

    def __init__(
        self,
        labels: str,
        label_definitions: Optional[Dict[str, str]] = None,
        examples: Optional[Callable[[], Iterable[Dict]]] = None,
        normalizer: Optional[Callable[[str], str]] = None,
    ):
        """Set up the task.

        labels (str): Comma-separated relation labels. Whitespace around each
            label is ignored and empty entries are skipped.
        label_definitions (Optional[Dict[str, str]]): Optional description per
            label, rendered into the prompt.
        examples (Optional[Callable[[], Iterable[Dict]]]): Optional callable
            returning few-shot examples; each one is validated as RELExample.
        normalizer (Optional[Callable[[str], str]]): Optional label
            normalizer; ``lowercase_normalizer()`` is used when omitted.
        """
        self._ensure_extension()

        self._normalizer = normalizer if normalizer else lowercase_normalizer()
        # "LivesIn, Visits" must produce "Visits", not " Visits".
        self._label_dict = {}
        for raw_label in labels.split(","):
            label = raw_label.strip()
            if label:
                self._label_dict[self._normalizer(label)] = label
        self._label_definitions = label_definitions
        self._examples = [RELExample(**eg) for eg in examples()] if examples else None

    @staticmethod
    def _ensure_extension() -> None:
        """Register the ``Doc._.rel`` extension unless it already exists."""
        if not Doc.has_extension("rel"):
            Doc.set_extension("rel", default=[])

    def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[str]:
        """Yield one rendered prompt per doc, in input order.

        The text of each doc is pre-annotated with its entities before it is
        inserted into the template.
        """
        environment = jinja2.Environment()
        _template = environment.from_string(self._TEMPLATE_STR)
        for doc in docs:
            prompt = _template.render(
                text=_preannotate(doc),
                labels=list(self._label_dict.values()),
                label_definitions=self._label_definitions,
                examples=self._examples,
            )
            yield prompt

    def _format_response(self, response: str) -> Iterable[RelationItem]:
        """Parse raw string response into a structured format.

        Every non-empty line is expected to hold one JSON object. Lines that
        fail validation are reported with a warning and skipped.
        """
        for line in response.strip().split("\n"):
            # Blank separator lines are not malformed relations; skip them
            # silently instead of warning about them.
            if not line.strip():
                continue
            try:
                yield RelationItem.parse_raw(line)
            except ValidationError:
                msg.warn("Validation issue", line)

    def parse_responses(
        self, docs: Iterable[Doc], responses: Iterable[str]
    ) -> Iterable[Doc]:
        """Store the relations parsed from each response in ``doc._.rel``.

        Relations whose entity ids do not refer to an entity of the doc are
        dropped with a warning, so consumers can index ``doc.ents`` safely.

        docs (Iterable[Doc]): Docs the prompts were generated from.
        responses (Iterable[str]): Raw responses, aligned with ``docs``.
        RETURNS (Iterable[Doc]): The same docs with ``doc._.rel`` set.
        """
        # The extension may have been removed since the task was created.
        self._ensure_extension()
        for doc, prompt_response in zip(docs, responses):
            n_ents = len(doc.ents)
            rels = []
            for rel in self._format_response(prompt_response):
                if 0 <= rel.dep < n_ents and 0 <= rel.dest < n_ents:
                    rels.append(rel)
                else:
                    msg.warn("Relation refers to an unknown entity id", rel.json())
            doc._.rel = rels
            yield doc
57 changes: 57 additions & 0 deletions spacy_llm/tasks/templates/rel.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
The text below contains pre-extracted entities, denoted in the following format within the text:
{# whitespace #}
<entity text>[ENT<entity id>:<entity label>]
{# whitespace #}
From the text below, extract the following relations between entities:
{# whitespace #}
{# whitespace #}
{%- for label in labels -%}
{{ label }}
{# whitespace #}
{%- endfor -%}
{# whitespace #}
The extraction has to use the following format, with one line for each detected relation:
{# whitespace #}
{"dep": <entity id>, "dest": <entity id>, "relation": <relation label>}
{# whitespace #}
Make sure that only relevant relations are listed, and that each line is a valid JSON object.
{# whitespace #}
{%- if label_definitions -%}
Below are definitions of each label to help aid you in what kinds of relationship to extract for each label.
Assume these definitions are written by an expert and follow them closely.
{# whitespace #}
{# whitespace #}
{%- for label, definition in label_definitions.items() -%}
{{ label }}: {{ definition }}
{# whitespace #}
{%- endfor -%}
{# whitespace #}
{# whitespace #}
{%- endif -%}
{%- if examples -%}
Below are some examples (only use these as a guide):
{# whitespace #}
{# whitespace #}
{%- for example in examples -%}
Text:
'''
{{ example.text }}
'''
{# whitespace #}
{%- for item in example.relations -%}
{# whitespace #}
{{ item.json() }}
{%- endfor -%}
{# whitespace #}
{# whitespace #}
{# whitespace #}
{%- endfor -%}
{# whitespace #}
{# whitespace #}
{%- endif -%}
Here is the text that needs labeling:
{# whitespace #}
Text:
'''
{{ text }}
'''
13 changes: 13 additions & 0 deletions spacy_llm/tests/tasks/test_rel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import spacy
from spacy.tokens import Span

from spacy_llm.tasks.rel import _preannotate


def test_text_preannotation():
    """A single entity gets its ``[ENT<id>:<label>]`` tag appended after it."""
    blank_en = spacy.load("blank:en")
    doc = blank_en("This is a test")

    # Mark the last token ("test") as the only entity of the doc.
    entity = Span(doc, start=3, end=4, label="test")
    doc.ents = [entity]

    expected = "This is a test[ENT0:test]"
    assert _preannotate(doc) == expected
38 changes: 38 additions & 0 deletions usage_examples/rel_openai/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Relation extraction using LLMs

This example shows how you can use a model from OpenAI for relation extraction in
zero- or few-shot settings.

Here, we use the pretrained [`en_core_web_sm` model](https://spacy.io/models/en#en_core_web_sm)
to perform Named Entity Recognition (NER) using a fast and properly evaluated pipeline.
Then, we leverage the OpenAI API to detect the relations between the extracted entities.
In this example, we focus on two simple relations: `LivesIn` and `Visits`.

First, create a new API key from
[openai.com](https://platform.openai.com/account/api-keys) or fetch an existing
one. Record the secret key and make sure it is available as an environment
variable, together with your organization ID:

```sh
export OPENAI_API_KEY="sk-..."
export OPENAI_API_ORG="org-..."
```

Then, you can run the pipeline on a sample text via:

```sh
python run_rel_openai_pipeline.py [TEXT] [PATH TO CONFIG]
```

For example:

```sh
python run_rel_openai_pipeline.py \
"Laura just bought an apartment in Boston." \
./openai_rel_zeroshot.cfg
```

You can also include examples to perform few-shot annotation. To do so, use the
`openai_rel_fewshot.cfg` file instead. You can find the few-shot examples in
the `examples.jsonl` file. Feel free to change and update it to your liking.
We also support other file formats, including `.json`, `.yml` and `.yaml`.
2 changes: 2 additions & 0 deletions usage_examples/rel_openai/examples.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
{"text": "Laura bought a house in Boston with her husband Mark.", "relations": [{"dep": 0, "dest": 1, "relation": "LivesIn"}, {"dep": 2, "dest": 1, "relation": "Visits"}]}
{"text": "Laura travelled through South America by bike.", "relations": [{"dep": 0, "dest": 1, "relation": "Visits"}]}
26 changes: 26 additions & 0 deletions usage_examples/rel_openai/openai_rel_fewshot.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
[paths]
examples = null

[nlp]
lang = "en"
pipeline = ["ner", "llm_rel"]

[components]

[components.ner]
source = "en_core_web_sm"

[components.llm_rel]
factory = "llm"

[components.llm_rel.task]
@llm_tasks = "spacy.REL.v1"
labels = LivesIn,Visits

[components.llm_rel.task.examples]
@misc = "spacy.FewShotReader.v1"
path = ${paths.examples}

[components.llm_rel.backend]
@llm_backends = "spacy.REST.v1"
api = "OpenAI"
22 changes: 22 additions & 0 deletions usage_examples/rel_openai/openai_rel_zeroshot.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
[paths]
examples = null

[nlp]
lang = "en"
pipeline = ["ner", "llm_rel"]

[components]

[components.ner]
source = "en_core_web_sm"

[components.llm_rel]
factory = "llm"

[components.llm_rel.task]
@llm_tasks = "spacy.REL.v1"
labels = LivesIn,Visits

[components.llm_rel.backend]
@llm_backends = "spacy.REST.v1"
api = "OpenAI"
47 changes: 47 additions & 0 deletions usage_examples/rel_openai/run_rel_openai_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import os
from pathlib import Path

import typer
from spacy import util
from wasabi import msg

# Short aliases for the typer parameter factories used in the CLI signature.
Arg = typer.Argument
Opt = typer.Option


def run_pipeline(
    # fmt: off
    text: str = Arg("", help="Text to perform relation extraction on."),
    config_path: Path = Arg(..., help="Path to the configuration file to use."),
    verbose: bool = Opt(False, "--verbose", "-v", help="Show extra information."),
    # fmt: on
):
    """Run the relation extraction pipeline on a text and print the result."""
    if not os.getenv("OPENAI_API_KEY", None):
        msg.fail(
            "OPENAI_API_KEY env variable was not found. "
            "Set it by running 'export OPENAI_API_KEY=...' and try again.",
            exits=1,
        )

    msg.text(f"Loading config from {config_path}", show=verbose)
    config = util.load_config(config_path)
    # Reload config with dynamic path for examples, if available in config.
    if "examples" in config.get("paths", {}):
        config = util.load_config(
            config_path,
            overrides={"paths.examples": str(Path(__file__).parent / "examples.jsonl")},
        )

    nlp = util.load_model_from_config(config, auto_fill=True)
    doc = nlp(text)

    msg.text(f"Text: {doc.text}")
    msg.text(f"Entities: {[(ent.text, ent.label_) for ent in doc.ents]}")

    msg.text("Relations:")
    ents = doc.ents
    for r in doc._.rel:
        # The ids come from the model's answer and may not match any entity;
        # report such relations instead of crashing on the lookup.
        if not (0 <= r.dep < len(ents) and 0 <= r.dest < len(ents)):
            msg.warn(f"Skipping relation with unknown entity id: {r}")
            continue
        msg.text(f" - {ents[r.dep]} [{r.relation}] {ents[r.dest]}")


if __name__ == "__main__":
    # Let typer build the command-line interface from run_pipeline's signature.
    typer.run(run_pipeline)