feat: REL task #114
Merged
34 commits
b3507f4
chore: ignore notebooks
bdura b50b6e4
feat: add rel task
bdura c40880a
fix: rel template
bdura 41f3bd2
docs: add rel example
bdura e84f920
feat: clean and simplify rel.py
bdura 91f24ec
docs: mention Doc._.rel in docstrings
bdura 77ee80c
feat: update rel
bdura d817db9
docs: update rel example
bdura 0a97ba5
docs: document REL task in readme
bdura 7cf79ca
Update usage_examples/rel_openai/README.md
bdura 9b76090
docs: use en_core_web_md instead of sm
bdura aa09a91
test: add usage test
bdura 8061dd1
fix: mypy issue
bdura dba9855
Merge remote-tracking branch 'origin/main' into feat/rel
bdura 06fb413
test: rel configuration
bdura 116d216
ci: install en_core_web_md
bdura 1493cc5
fix: test_rel_config assert
bdura 3a257f7
test: rel prediction
bdura fd39ded
Merge branch 'main' into feat/rel
rmitsch 4f1113a
Adjust tests w.r.t. new format.
rmitsch bf7dd06
ci: install en_core_web_md in external tests
bdura fbf1f70
docs: update description of NER-capable
bdura 6c6fcf1
fix: check custom rel extension at every call
bdura a6b225d
Merge remote-tracking branch 'origin/main' into feat/rel
bdura 757a638
feat: update rel example to use util.assemble
bdura 3f96586
chore: move the [initialize] section to the end of the rel example co…
bdura f72fede
feat: assemble accepts a config or a path
bdura b7e4488
use assemble in test_rel
bdura 2709af5
feat: remove useless type hint
bdura ed7d4a9
fix: add initialization
bdura cfb2289
fix: cleaner util
bdura 23dfc9c
Update util.py
adrianeboyd 2a8a22a
Update README.md
rmitsch e4724ef
Merge remote-tracking branch 'origin/main' into feat/rel
bdura
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -88,6 +88,7 @@ target/ | |
|
|
||
| # Jupyter Notebook | ||
| .ipynb_checkpoints | ||
| *.ipynb | ||
|
|
||
| # IPython | ||
| profile_default/ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,13 @@ | ||
| from .ner import NERTask | ||
| from .noop import NoopTask | ||
| from .rel import RELTask | ||
| from .spancat import SpanCatTask | ||
| from .textcat import TextCatTask | ||
|
|
||
| __all__ = ["NoopTask", "NERTask", "TextCatTask", "SpanCatTask"] | ||
| __all__ = [ | ||
| "NoopTask", | ||
| "NERTask", | ||
| "TextCatTask", | ||
| "SpanCatTask", | ||
| "RELTask", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,114 @@ | ||
| from typing import Callable, Dict, Iterable, List, Optional, Union | ||
|
|
||
| import jinja2 | ||
| from pydantic import BaseModel, Field, ValidationError, validator | ||
| from spacy.tokens import Doc | ||
| from wasabi import msg | ||
|
|
||
| from ..registry import lowercase_normalizer, registry | ||
| from .templates import read_template | ||
|
|
||
|
|
||
| class RelationItem(BaseModel): | ||
| dep: int | ||
| dest: int | ||
| relation: str | ||
|
|
||
| @validator("dep", "dest", pre=True) | ||
| def clean_ent(cls, value): | ||
| if isinstance(value, str): | ||
| value = value.strip("ENT") | ||
| return value | ||
|
|
||
|
|
||
| class EntityItem(BaseModel): | ||
| start_char: int | ||
| end_char: int | ||
| label_: str = Field(alias="label") | ||
|
|
||
|
|
||
| class RELExample(BaseModel): | ||
| text: str | ||
| ents: List[EntityItem] | ||
| relations: List[RelationItem] | ||
|
|
||
|
|
||
| def _preannotate(doc: Union[Doc, RELExample]) -> str: | ||
| """Creates a text version of the document with annotated entities.""" | ||
| offset = 0 | ||
|
|
||
| text = doc.text | ||
|
|
||
| for i, ent in enumerate(doc.ents): | ||
| end = ent.end_char | ||
| before, after = text[: end + offset], text[end + offset :] | ||
|
|
||
| annotation = f"[ENT{i}:{ent.label_}]" | ||
| offset += len(annotation) | ||
|
|
||
| text = f"{before}{annotation}{after}" | ||
|
|
||
| return text | ||
|
|
||
|
|
||
| @registry.llm_tasks("spacy.REL.v1") | ||
| class RELTask: | ||
| """Simple REL task. Populates a `Doc._.rel` custom attribute.""" | ||
|
|
||
| _TEMPLATE_STR = read_template("rel") | ||
|
|
||
| def __init__( | ||
| self, | ||
| labels: str, | ||
| label_definitions: Optional[Dict[str, str]] = None, | ||
| examples: Optional[Callable[[], Iterable[Dict]]] = None, | ||
| normalizer: Optional[Callable[[str], str]] = None, | ||
| verbose: bool = False, | ||
| ): | ||
|
|
||
| if not Doc.has_extension("rel"): | ||
| Doc.set_extension("rel", default=[]) | ||
|
||
|
|
||
| self._normalizer = normalizer if normalizer else lowercase_normalizer() | ||
| self._label_dict = { | ||
| self._normalizer(label): label for label in labels.split(",") | ||
| } | ||
| self._label_definitions = label_definitions | ||
| self._examples = examples and [RELExample.parse_obj(eg) for eg in examples()] | ||
|
|
||
| self._verbose = verbose | ||
|
|
||
| def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[str]: | ||
| environment = jinja2.Environment() | ||
| _template = environment.from_string(self._TEMPLATE_STR) | ||
| for doc in docs: | ||
| prompt = _template.render( | ||
| text=_preannotate(doc), | ||
| labels=list(self._label_dict.values()), | ||
| label_definitions=self._label_definitions, | ||
| examples=self._examples, | ||
| preannotate=_preannotate, | ||
| ) | ||
| yield prompt | ||
|
|
||
| def _format_response(self, response: str) -> Iterable[RelationItem]: | ||
| """Parse raw string response into a structured format""" | ||
| relations = [] | ||
| for line in response.strip().split("\n"): | ||
| try: | ||
| relations.append(RelationItem.parse_raw(line)) | ||
| except ValidationError: | ||
| msg.warn( | ||
| "Validation issue", | ||
| line, | ||
| show=self._verbose, | ||
| ) | ||
| return relations | ||
|
|
||
| def parse_responses( | ||
| self, docs: Iterable[Doc], responses: Iterable[str] | ||
| ) -> Iterable[Doc]: | ||
| for doc, prompt_response in zip(docs, responses): | ||
| rels = self._format_response(prompt_response) | ||
| doc._.rel = rels | ||
| yield doc | ||
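The effect of `_preannotate` is easiest to see on a concrete sentence. The following is a minimal stdlib-only sketch of the same offset bookkeeping; the `Entity` dataclass and the `preannotate` name are hypothetical stand-ins for spaCy spans and the PR's helper, and the sketch assumes entities arrive sorted by position and do not overlap (as `doc.ents` does).

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Entity:
    # Hypothetical stand-in for a spaCy Span or the PR's EntityItem.
    start_char: int
    end_char: int
    label_: str


def preannotate(text: str, ents: List[Entity]) -> str:
    """Insert an [ENT<i>:<label>] marker right after each entity."""
    offset = 0  # number of characters added by earlier markers
    for i, ent in enumerate(ents):
        cut = ent.end_char + offset  # entity end, shifted by earlier inserts
        annotation = f"[ENT{i}:{ent.label_}]"
        text = f"{text[:cut]}{annotation}{text[cut:]}"
        offset += len(annotation)
    return text


text = "Laura bought a house in Boston with her husband Mark."
ents = [Entity(0, 5, "PERSON"), Entity(24, 30, "GPE"), Entity(48, 52, "PERSON")]
annotated = preannotate(text, ents)
print(annotated)
# Laura[ENT0:PERSON] bought a house in Boston[ENT1:GPE] with her husband Mark[ENT2:PERSON].
```

The running `offset` is what keeps later character positions valid: each marker lengthens the text, so every subsequent `end_char` has to be shifted by the total length of the markers already inserted.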
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| The text below contains pre-extracted entities, denoted in the following format within the text: | ||
| {# whitespace #} | ||
| <entity text>[ENT<entity id>:<entity label>] | ||
| {# whitespace #} | ||
| From the text below, extract the following relations between entities: | ||
| {# whitespace #} | ||
| {# whitespace #} | ||
| {%- for label in labels -%} | ||
| {{ label }} | ||
| {# whitespace #} | ||
| {%- endfor -%} | ||
| {# whitespace #} | ||
| The extraction has to use the following format, with one line for each detected relation: | ||
| {# whitespace #} | ||
| {"dep": <entity id>, "dest": <entity id>, "relation": <relation label>} | ||
| {# whitespace #} | ||
| Make sure that only relevant relations are listed, and that each line is a valid JSON object. | ||
| {# whitespace #} | ||
| {%- if label_definitions -%} | ||
| Below are definitions of each label to help aid you in what kinds of relationship to extract for each label. | ||
| Assume these definitions are written by an expert and follow them closely. | ||
| {# whitespace #} | ||
| {# whitespace #} | ||
| {%- for label, definition in label_definitions.items() -%} | ||
| {{ label }}: {{ definition }} | ||
| {# whitespace #} | ||
| {%- endfor -%} | ||
| {# whitespace #} | ||
| {# whitespace #} | ||
| {%- endif -%} | ||
| {%- if examples -%} | ||
| Below are some examples (only use these as a guide): | ||
| {# whitespace #} | ||
| {# whitespace #} | ||
| {%- for example in examples -%} | ||
| Text: | ||
| ''' | ||
| {{ preannotate(example) }} | ||
| ''' | ||
| {# whitespace #} | ||
| {%- for item in example.relations -%} | ||
| {# whitespace #} | ||
| {{ item.json() }} | ||
| {%- endfor -%} | ||
| {# whitespace #} | ||
| {# whitespace #} | ||
| {# whitespace #} | ||
| {%- endfor -%} | ||
| {# whitespace #} | ||
| {# whitespace #} | ||
| {%- endif -%} | ||
| Here is the text that needs labeling: | ||
| {# whitespace #} | ||
| Text: | ||
| ''' | ||
| {{ text }} | ||
| ''' |
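The template asks the model for one JSON object per line, so the response can be parsed line by line. Below is a rough stdlib-only sketch of that parsing step, assuming the model may return entity ids either as integers or as strings such as `"ENT0"`; the names `clean_ent` and `parse_relations` are illustrative, and the PR itself delegates this work to the pydantic `RelationItem` model rather than to plain dictionaries.

```python
import json
from typing import Dict, List, Union


def clean_ent(value: Union[int, str]) -> int:
    """Accept 0, "0" or "ENT0" and return the integer entity id."""
    if isinstance(value, str):
        value = value.strip()
        if value.startswith("ENT"):
            value = value[len("ENT"):]
    return int(value)


def parse_relations(response: str) -> List[Dict[str, Union[int, str]]]:
    """Parse one JSON object per line, skipping lines that do not validate."""
    relations = []
    for line in response.strip().split("\n"):
        try:
            raw = json.loads(line)
            relations.append(
                {
                    "dep": clean_ent(raw["dep"]),
                    "dest": clean_ent(raw["dest"]),
                    "relation": str(raw["relation"]),
                }
            )
        except (ValueError, KeyError, TypeError):
            # Malformed line (bad JSON, missing key, non-numeric id): skip it.
            continue
    return relations


response = (
    '{"dep": "ENT0", "dest": 1, "relation": "LivesIn"}\n'
    "Sorry, I could not find more relations.\n"
    '{"dep": 2, "dest": "ENT1", "relation": "LivesIn"}'
)
parsed = parse_relations(response)
print(parsed)
```

Skipping invalid lines instead of raising keeps a single chatty or truncated line from discarding the relations that did parse.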
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| {"text": "Laura bought a house in Boston with her husband Mark.", "ents": [{"start_char": 0, "end_char": 5, "label": "PERSON"}, {"start_char": 24, "end_char": 30, "label": "GPE"}, {"start_char": 48, "end_char": 52, "label": "PERSON"}], "relations": [{"dep": 0, "dest": 1, "relation": "LivesIn"}, {"dep": 2, "dest": 1, "relation": "LivesIn"}]} | ||
| {"text": "Michael travelled through South America by bike.", "ents": [{"start_char": 0, "end_char": 7, "label": "PERSON"}, {"start_char": 26, "end_char": 39, "label": "LOC"}], "relations": [{"dep": 0, "dest": 1, "relation": "Visits"}]} |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| from pathlib import Path | ||
|
|
||
| import pytest | ||
| import spacy | ||
| from confection import Config | ||
| from pytest import FixtureRequest | ||
|
|
||
| from spacy_llm.tasks.rel import RelationItem | ||
|
|
||
| from ..compat import has_openai_key | ||
|
|
||
| EXAMPLES_DIR = Path(__file__).parent / "examples" | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def zeroshot_cfg_string(): | ||
| return """ | ||
| [nlp] | ||
| lang = "en" | ||
| pipeline = ["ner", "llm"] | ||
| batch_size = 128 | ||
| [components] | ||
| [components.ner] | ||
| source = "en_core_web_md" | ||
| [components.llm] | ||
| factory = "llm" | ||
| [components.llm.task] | ||
| @llm_tasks = "spacy.REL.v1" | ||
| labels = "LivesIn,Visits" | ||
| [components.llm.backend] | ||
| @llm_backends = "spacy.REST.v1" | ||
| api = "OpenAI" | ||
| """ | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def fewshot_cfg_string(): | ||
| return f""" | ||
| [nlp] | ||
| lang = "en" | ||
| pipeline = ["ner", "llm"] | ||
| batch_size = 128 | ||
| [components] | ||
| [components.ner] | ||
| source = "en_core_web_md" | ||
| [components.llm] | ||
| factory = "llm" | ||
| [components.llm.task] | ||
| @llm_tasks = "spacy.REL.v1" | ||
| labels = "LivesIn,Visits" | ||
| [components.llm.task.examples] | ||
| @misc = "spacy.FewShotReader.v1" | ||
| path = {str(EXAMPLES_DIR / "rel_examples.jsonl")} | ||
| [components.llm.backend] | ||
| @llm_backends = "spacy.REST.v1" | ||
| api = "OpenAI" | ||
| """ | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def task(): | ||
| text = "Joey rents a place in New York City." | ||
| gold_relations = [RelationItem(dep=0, dest=1, relation="LivesIn")] | ||
| return text, gold_relations | ||
|
|
||
|
|
||
| @pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") | ||
| @pytest.mark.parametrize("cfg_string", ["zeroshot_cfg_string", "fewshot_cfg_string"]) | ||
| def test_rel_config(cfg_string, request: FixtureRequest): | ||
| """Simple test to check if the config loads properly given different settings""" | ||
|
|
||
| cfg_string = request.getfixturevalue(cfg_string) | ||
| orig_config = Config().from_str(cfg_string) | ||
| nlp = spacy.util.load_model_from_config(orig_config, auto_fill=True) | ||
| assert nlp.pipe_names == ["ner", "llm"] | ||
|
|
||
|
|
||
| @pytest.mark.external | ||
| @pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") | ||
| @pytest.mark.parametrize("cfg_string", ["zeroshot_cfg_string", "fewshot_cfg_string"]) | ||
| def test_rel_predict(task, cfg_string, request): | ||
| """Use OpenAI to get REL results. | ||
| Note that this test may fail randomly, as the LLM's output is not guaranteed to be consistent or predictable. | ||
| """ | ||
| cfg_string = request.getfixturevalue(cfg_string) | ||
| orig_config = Config().from_str(cfg_string) | ||
| nlp = spacy.util.load_model_from_config(orig_config, auto_fill=True) | ||
|
||
|
|
||
| text, _ = task | ||
| doc = nlp(text) | ||
|
|
||
| assert doc.ents | ||
| assert doc._.rel | ||
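`test_rel_predict` only asserts that some entities and relations were produced, and the `gold_relations` returned by the `task` fixture go unused. If a softer comparison against the gold relations were ever wanted, one possible approach is to score the overlap rather than require equality. The sketch below is purely hypothetical: `relation_overlap` and the tuple representation are assumptions for illustration and do not appear in the PR.

```python
from typing import Iterable, Tuple

# (dep, dest, relation) triples; a plain-tuple stand-in for RelationItem.
Relation = Tuple[int, int, str]


def relation_overlap(
    predicted: Iterable[Relation], gold: Iterable[Relation]
) -> Tuple[float, float]:
    """Return (precision, recall) of predicted relations against gold ones."""
    pred_set, gold_set = set(predicted), set(gold)
    hits = len(pred_set & gold_set)
    precision = hits / len(pred_set) if pred_set else 0.0
    recall = hits / len(gold_set) if gold_set else 0.0
    return precision, recall


gold = [(0, 1, "LivesIn")]
predicted = [(0, 1, "LivesIn"), (1, 0, "Visits")]
precision, recall = relation_overlap(predicted, gold)
print(precision, recall)  # 0.5 1.0
```

A threshold on recall would tolerate spurious extra relations from the model while still catching the case where the expected relation is missing.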
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,38 @@ | ||
| # Relation extraction using LLMs | ||
|
|
||
| This example shows how you can use a model from OpenAI for relation extraction in | ||
| zero- and few-shot settings. | ||
|
|
||
| Here, we use the pretrained [`en_core_web_md` model](https://spacy.io/models/en#en_core_web_md) | ||
| to perform Named Entity Recognition (NER) using a fast and properly evaluated pipeline. | ||
| Then, we leverage the OpenAI API to detect the relations between the extracted entities. | ||
| In this example, we focus on two simple relations: `LivesIn` and `Visits`. | ||
|
|
||
| First, create a new API key from | ||
| [openai.com](https://platform.openai.com/account/api-keys) or fetch an existing | ||
| one. Record the secret key and make sure it is available as an environment | ||
| variable: | ||
|
|
||
| ```sh | ||
| export OPENAI_API_KEY="sk-..." | ||
| export OPENAI_API_ORG="org-..." | ||
| ``` | ||
|
|
||
| Then, you can run the pipeline on a sample text via: | ||
|
|
||
| ```sh | ||
| python run_rel_openai_pipeline.py [TEXT] [PATH TO CONFIG] | ||
| ``` | ||
|
|
||
| For example: | ||
|
|
||
| ```sh | ||
| python run_rel_openai_pipeline.py \ | ||
| "Laura just bought an apartment in Boston." \ | ||
| ./openai_rel_zeroshot.cfg | ||
| ``` | ||
|
|
||
| You can also include examples to perform few-shot annotation. To do so, use the | ||
| `openai_rel_fewshot.cfg` file instead. You can find the few-shot examples in | ||
| the `examples.jsonl` file. Feel free to change and update it to your liking. | ||
| We also support other file formats, including `.json`, `.yml` and `.yaml`. |