feat: REL task #114
Merged
Commits (34 in total; this view shows the changes from 4 of them):

- b3507f4 chore: ignore notebooks (bdura)
- b50b6e4 feat: add rel task (bdura)
- c40880a fix: rel template (bdura)
- 41f3bd2 docs: add rel example (bdura)
- e84f920 feat: clean and simplify rel.py (bdura)
- 91f24ec docs: mention Doc._.rel in docstrings (bdura)
- 77ee80c feat: update rel (bdura)
- d817db9 docs: update rel example (bdura)
- 0a97ba5 docs: document REL task in readme (bdura)
- 7cf79ca Update usage_examples/rel_openai/README.md (bdura)
- 9b76090 docs: use en_core_web_md instead of sm (bdura)
- aa09a91 test: add usage test (bdura)
- 8061dd1 fix: mypy issue (bdura)
- dba9855 Merge remote-tracking branch 'origin/main' into feat/rel (bdura)
- 06fb413 test: rel configuration (bdura)
- 116d216 ci: install en_core_web_md (bdura)
- 1493cc5 fix: test_rel_config assert (bdura)
- 3a257f7 test: rel prediction (bdura)
- fd39ded Merge branch 'main' into feat/rel (rmitsch)
- 4f1113a Adjust tests w.r.t. new format. (rmitsch)
- bf7dd06 ci: install en_core_web_md in external tests (bdura)
- fbf1f70 docs: update description of NER-capable (bdura)
- 6c6fcf1 fix: check custom rel extension at every call (bdura)
- a6b225d Merge remote-tracking branch 'origin/main' into feat/rel (bdura)
- 757a638 feat: update rel example to use util.assemble (bdura)
- 3f96586 chore: move the [initialize] section to the end of the rel example co… (bdura)
- f72fede feat: assemble accepts a config or a path (bdura)
- b7e4488 use assemble in test_rel (bdura)
- 2709af5 feat: remove useless type hint (bdura)
- ed7d4a9 fix: add initialization (bdura)
- cfb2289 fix: cleaner util (bdura)
- 23dfc9c Update util.py (adrianeboyd)
- 2a8a22a Update README.md (rmitsch)
- e4724ef Merge remote-tracking branch 'origin/main' into feat/rel (bdura)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -88,6 +88,7 @@ target/ | |
|
|
||
| # Jupyter Notebook | ||
| .ipynb_checkpoints | ||
| *.ipynb | ||
|
|
||
| # IPython | ||
| profile_default/ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,13 @@ | ||
|  from .ner import NERTask | ||
|  from .noop import NoopTask | ||
| +from .rel import RELTask | ||
|  from .spancat import SpanCatTask | ||
|  from .textcat import TextCatTask | ||
|
|
||
| -__all__ = ["NoopTask", "NERTask", "TextCatTask", "SpanCatTask"] | ||
| +__all__ = [ | ||
| +    "NoopTask", | ||
| +    "NERTask", | ||
| +    "TextCatTask", | ||
| +    "SpanCatTask", | ||
| +    "RELTask", | ||
| +] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,102 @@ | ||
| from typing import Callable, Dict, Iterable, List, Optional | ||
|
|
||
| import jinja2 | ||
| from pydantic import BaseModel, ConstrainedInt, ValidationError | ||
| from spacy.tokens import Doc | ||
| from wasabi import msg | ||
|
|
||
| from ..registry import lowercase_normalizer, registry | ||
| from .templates import read_template | ||
|
|
||
|
|
||
| class EntityId(ConstrainedInt): | ||
| @classmethod | ||
| def __get_validators__(cls): | ||
| yield cls.clean_ent | ||
| for val in super().__get_validators__(): | ||
| yield val | ||
|
|
||
| @classmethod | ||
| def clean_ent(cls, value): | ||
| if isinstance(value, str): | ||
| value = value.strip("ENT") | ||
| return value | ||
|
|
||
|
|
||
| class RelationItem(BaseModel): | ||
| dep: EntityId | ||
| dest: EntityId | ||
| relation: str | ||
|
|
||
|
|
||
| class RELExample(BaseModel): | ||
| text: str | ||
| relations: List[RelationItem] | ||
|
|
||
|
|
||
| def _preannotate(doc: Doc) -> str: | ||
| offset = 0 | ||
|
|
||
| text = doc.text | ||
|
|
||
| for i, ent in enumerate(doc.ents): | ||
| end = ent.end_char | ||
| before, after = text[: end + offset], text[end + offset :] | ||
|
|
||
| annotation = f"[ENT{i}:{ent.label_}]" | ||
| offset += len(annotation) | ||
|
|
||
| text = f"{before}{annotation}{after}" | ||
|
|
||
| return text | ||
|
|
||
|
|
||
| @registry.llm_tasks("spacy.REL.v1") | ||
| class RELTask: | ||
| _TEMPLATE_STR = read_template("rel") | ||
|
|
||
| def __init__( | ||
| self, | ||
| labels: str, | ||
| label_definitions: Optional[Dict[str, str]] = None, | ||
| examples: Optional[Callable[[], Iterable[Dict]]] = None, | ||
| normalizer: Optional[Callable[[str], str]] = None, | ||
| ): | ||
|
|
||
| if not Doc.has_extension("rel"): | ||
| Doc.set_extension("rel", default=[]) | ||
|
|
||
| self._normalizer = normalizer if normalizer else lowercase_normalizer() | ||
| self._label_dict = { | ||
| self._normalizer(label): label for label in labels.split(",") | ||
| } | ||
| self._label_definitions = label_definitions | ||
| self._examples = [RELExample(**eg) for eg in examples()] if examples else None | ||
|
|
||
| def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[str]: | ||
| environment = jinja2.Environment() | ||
| _template = environment.from_string(self._TEMPLATE_STR) | ||
| for doc in docs: | ||
| prompt = _template.render( | ||
| text=_preannotate(doc), | ||
| labels=list(self._label_dict.values()), | ||
| label_definitions=self._label_definitions, | ||
| examples=self._examples, | ||
| ) | ||
| yield prompt | ||
|
|
||
| def _format_response(self, response: str) -> Iterable[RelationItem]: | ||
| """Parse raw string response into a structured format""" | ||
| for line in response.strip().split("\n"): | ||
| try: | ||
| yield RelationItem.parse_raw(line) | ||
| except ValidationError: | ||
| msg.warn("Validation issue", line) | ||
|
|
||
| def parse_responses( | ||
| self, docs: Iterable[Doc], responses: Iterable[str] | ||
| ) -> Iterable[Doc]: | ||
| for doc, prompt_response in zip(docs, responses): | ||
| rels = list(self._format_response(prompt_response)) | ||
| doc._.rel = rels | ||
| yield doc | ||
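
Review note: the parsing step in `_format_response` above can be illustrated without pydantic. The following is a minimal, stdlib-only sketch, assuming the model returns one JSON object per line as the prompt requests. `Relation`, `_entity_id` and `parse_relation_lines` are hypothetical names used for illustration only; the PR itself relies on `RelationItem.parse_raw` and the `EntityId` validator.

```python
import json
from typing import Iterator, NamedTuple


class Relation(NamedTuple):
    """Hypothetical stand-in for the PR's pydantic RelationItem."""

    dep: int
    dest: int
    relation: str


def _entity_id(value) -> int:
    # Accept both a bare id (0) and the annotated form ("ENT0").
    if isinstance(value, str):
        value = value.strip()
        if value.startswith("ENT"):
            value = value[len("ENT"):]
    return int(value)


def parse_relation_lines(response: str) -> Iterator[Relation]:
    """Yield one Relation per valid JSON line and skip malformed lines."""
    for line in response.strip().split("\n"):
        try:
            raw = json.loads(line)
            yield Relation(
                _entity_id(raw["dep"]),
                _entity_id(raw["dest"]),
                str(raw["relation"]),
            )
        except (ValueError, KeyError, TypeError):
            # Invalid JSON, a missing key, or a non-integer id: skip the line.
            # (The PR logs a warning via wasabi at this point instead.)
            continue
```

Unlike `str.strip("ENT")`, which removes any of the characters `E`, `N`, `T` from both ends of the string, this sketch removes only a leading `ENT` prefix. Whether that stricter behaviour is wanted in the task itself is a design choice for the maintainers.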
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| The text below contains pre-extracted entities, denoted in the following format within the text: | ||
| {# whitespace #} | ||
| <entity text>[ENT<entity id>:<entity label>] | ||
| {# whitespace #} | ||
| From the text below, extract the following relations between entities: | ||
| {# whitespace #} | ||
| {# whitespace #} | ||
| {%- for label in labels -%} | ||
| {{ label }} | ||
| {# whitespace #} | ||
| {%- endfor -%} | ||
| {# whitespace #} | ||
| The extraction has to use the following format, with one line for each detected relation: | ||
| {# whitespace #} | ||
| {"dep": <entity id>, "dest": <entity id>, "relation": <relation label>} | ||
| {# whitespace #} | ||
| Make sure that only relevant relations are listed, and that each line is a valid JSON object. | ||
| {# whitespace #} | ||
| {%- if label_definitions -%} | ||
| Below are definitions of each label to help aid you in what kinds of relationship to extract for each label. | ||
| Assume these definitions are written by an expert and follow them closely. | ||
| {# whitespace #} | ||
| {# whitespace #} | ||
| {%- for label, definition in label_definitions.items() -%} | ||
| {{ label }}: {{ definition }} | ||
| {# whitespace #} | ||
| {%- endfor -%} | ||
| {# whitespace #} | ||
| {# whitespace #} | ||
| {%- endif -%} | ||
| {%- if examples -%} | ||
| Below are some examples (only use these as a guide): | ||
| {# whitespace #} | ||
| {# whitespace #} | ||
| {%- for example in examples -%} | ||
| Text: | ||
| ''' | ||
| {{ example.text }} | ||
| ''' | ||
| {# whitespace #} | ||
| {%- for item in example.relations -%} | ||
| {# whitespace #} | ||
| {{ item.json() }} | ||
| {%- endfor -%} | ||
| {# whitespace #} | ||
| {# whitespace #} | ||
| {# whitespace #} | ||
| {%- endfor -%} | ||
| {# whitespace #} | ||
| {# whitespace #} | ||
| {%- endif -%} | ||
| Here is the text that needs labeling: | ||
| {# whitespace #} | ||
| Text: | ||
| ''' | ||
| {{ text }} | ||
| ''' |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| import spacy | ||
| from spacy.tokens import Span | ||
|
|
||
| from spacy_llm.tasks.rel import _preannotate | ||
|
|
||
|
|
||
| def test_text_preannotation(): | ||
| nlp = spacy.load("blank:en") | ||
|
|
||
| doc = nlp("This is a test") | ||
| doc.ents = [Span(doc, start=3, end=4, label="test")] | ||
|
|
||
| assert _preannotate(doc) == "This is a test[ENT0:test]" |
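
Review note: the test above covers a single entity at the end of the text. The offset bookkeeping in `_preannotate` matters most when several entities are annotated, because each inserted marker shifts the character positions of everything after it. Below is a minimal sketch of the same logic on plain strings, assuming entities arrive as `(end_char, label)` pairs in document order, as `doc.ents` would provide them. `preannotate_text` is a hypothetical helper and not part of the PR.

```python
from typing import List, Tuple


def preannotate_text(text: str, ents: List[Tuple[int, str]]) -> str:
    """Insert an [ENT<i>:<label>] marker after each entity.

    `ents` holds (end_char, label) pairs relative to the ORIGINAL text,
    sorted by position.
    """
    offset = 0  # number of characters inserted so far
    for i, (end, label) in enumerate(ents):
        annotation = f"[ENT{i}:{label}]"
        cut = end + offset  # entity end position in the already-annotated text
        text = f"{text[:cut]}{annotation}{text[cut:]}"
        offset += len(annotation)
    return text


sentence = "Laura just bought an apartment in Boston."
entities = [
    (len("Laura"), "PERSON"),
    (sentence.index("Boston") + len("Boston"), "GPE"),
]
print(preannotate_text(sentence, entities))
```

A test along these lines, run through the real `_preannotate` with two `Span` objects, might be a useful addition to the suite.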
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,38 @@ | ||
| # Relation extraction using LLMs | ||
|
|
||
| This example shows how you can use a model from OpenAI for relation extraction in | ||
| zero- or few-shot settings. | ||
|
||
|
|
||
| Here, we use the pretrained [`en_core_web_sm` model](https://spacy.io/models/en#en_core_web_sm) | ||
| to perform Named Entity Recognition (NER) using a fast and properly evaluated pipeline. | ||
| Then, we leverage the OpenAI API to detect the relations between the extracted entities. | ||
| In this example, we focus on two simple relations: `LivesIn` and `Visits`. | ||
|
|
||
| First, create a new API key from | ||
| [openai.com](https://platform.openai.com/account/api-keys) or fetch an existing | ||
| one. Record the secret key and make sure it is available as an environment | ||
| variable: | ||
|
|
||
| ```sh | ||
| export OPENAI_API_KEY="sk-..." | ||
| export OPENAI_API_ORG="org-..." | ||
| ``` | ||
|
|
||
| Then, you can run the pipeline on a sample text via: | ||
|
|
||
| ```sh | ||
| python run_rel_openai_pipeline.py [TEXT] [PATH TO CONFIG] | ||
| ``` | ||
|
|
||
| For example: | ||
|
|
||
| ```sh | ||
| python run_rel_openai_pipeline.py \ | ||
| "Laura just bought an apartment in Boston." \ | ||
| ./openai_rel_zeroshot.cfg | ||
| ``` | ||
|
|
||
| You can also include examples to perform few-shot annotation. To do so, use the | ||
| `openai_rel_fewshot.cfg` file instead. You can find the few-shot examples in | ||
| the `examples.jsonl` file. Feel free to change and update it to your liking. | ||
| We also support other file formats, including `.json`, `.yml` and `.yaml`. | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| {"text": "Laura bought a house in Boston with her husband Mark.", "relations": [{"dep": 0, "dest": 1, "relation": "LivesIn"}, {"dep": 2, "dest": 1, "relation": "Visits"}]} | ||
| {"text": "Laura travelled through South America by bike.", "relations": [{"dep": 0, "dest": 1, "relation": "Visits"}]} |
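
Review note: each line of the few-shot file is a JSON object with a `text` and a list of `relations`, where `dep` and `dest` are entity indices. The sketch below shows one way to sanity-check such a file before using it, assuming the labels from the config are passed in as a set. `check_examples` is a hypothetical helper, and checking relations against the configured labels is an extra step that the PR does not perform (it only validates the structure through `RELExample`).

```python
import json
from typing import List, Set


def check_examples(lines: List[str], labels: Set[str]) -> List[str]:
    """Return a list of human-readable problems found in JSONL example lines."""
    problems = []
    for n, line in enumerate(lines, start=1):
        example = json.loads(line)
        for rel in example["relations"]:
            if rel["relation"] not in labels:
                problems.append(f"line {n}: unknown relation {rel['relation']!r}")
            if rel["dep"] == rel["dest"]:
                problems.append(f"line {n}: relation links entity {rel['dep']} to itself")
    return problems
```

An empty result means every relation uses a known label and links two different entities. It does not confirm that the indices match the entities that the NER component will actually find in the text.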
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| [paths] | ||
| examples = null | ||
|
|
||
| [nlp] | ||
| lang = "en" | ||
| pipeline = ["ner", "llm_rel"] | ||
|
|
||
| [components] | ||
|
|
||
| [components.ner] | ||
| source = "en_core_web_sm" | ||
|
|
||
| [components.llm_rel] | ||
| factory = "llm" | ||
|
|
||
| [components.llm_rel.task] | ||
| @llm_tasks = "spacy.REL.v1" | ||
| labels = LivesIn,Visits | ||
|
|
||
| [components.llm_rel.task.examples] | ||
| @misc = "spacy.FewShotReader.v1" | ||
| path = ${paths.examples} | ||
|
|
||
| [components.llm_rel.backend] | ||
| @llm_backends = "spacy.REST.v1" | ||
| api = "OpenAI" |
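
Review note: `labels = LivesIn,Visits` is passed to `RELTask` as a single string and split on commas. The sketch below mirrors the dictionary comprehension from `RELTask.__init__`, assuming the default normalizer strips whitespace and lowercases (that behaviour is an assumption here, since `lowercase_normalizer` is not shown in this diff). `build_label_dict` is a hypothetical name.

```python
from typing import Callable, Dict


def _lowercase(s: str) -> str:
    # Assumed behaviour of the default normalizer: strip, then lowercase.
    return s.strip().lower()


def build_label_dict(
    labels: str, normalizer: Callable[[str], str] = _lowercase
) -> Dict[str, str]:
    """Map normalized label -> label as written in the config."""
    return {normalizer(label): label for label in labels.split(",")}


print(build_label_dict("LivesIn,Visits"))
print(build_label_dict("LivesIn, Visits"))
```

Under that assumption, the second call keeps the leading space in the stored value (`" Visits"`), because only the key is normalized. That would explain why the config writes the labels without spaces after the comma.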
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| [paths] | ||
| examples = null | ||
|
|
||
| [nlp] | ||
| lang = "en" | ||
| pipeline = ["ner", "llm_rel"] | ||
|
|
||
| [components] | ||
|
|
||
| [components.ner] | ||
| source = "en_core_web_sm" | ||
|
|
||
| [components.llm_rel] | ||
| factory = "llm" | ||
|
|
||
| [components.llm_rel.task] | ||
| @llm_tasks = "spacy.REL.v1" | ||
| labels = LivesIn,Visits | ||
|
|
||
| [components.llm_rel.backend] | ||
| @llm_backends = "spacy.REST.v1" | ||
| api = "OpenAI" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| import os | ||
| from pathlib import Path | ||
|
|
||
| import typer | ||
| from spacy import util | ||
| from wasabi import msg | ||
|
|
||
| Arg = typer.Argument | ||
| Opt = typer.Option | ||
|
|
||
|
|
||
| def run_pipeline( | ||
| # fmt: off | ||
| text: str = Arg("", help="Text to perform relation extraction on."), | ||
| config_path: Path = Arg(..., help="Path to the configuration file to use."), | ||
| verbose: bool = Opt(False, "--verbose", "-v", help="Show extra information."), | ||
| # fmt: on | ||
| ): | ||
| if not os.getenv("OPENAI_API_KEY", None): | ||
| msg.fail( | ||
| "OPENAI_API_KEY env variable was not found. " | ||
| "Set it by running 'export OPENAI_API_KEY=...' and try again.", | ||
| exits=1, | ||
| ) | ||
|
|
||
| msg.text(f"Loading config from {config_path}", show=verbose) | ||
| config = util.load_config(config_path) | ||
| # Reload config with dynamic path for examples, if available in config. | ||
| if "examples" in config.get("paths", {}): | ||
| config = util.load_config( | ||
| config_path, | ||
| overrides={"paths.examples": str(Path(__file__).parent / "examples.jsonl")}, | ||
| ) | ||
|
|
||
| nlp = util.load_model_from_config(config, auto_fill=True) | ||
| doc = nlp(text) | ||
|
|
||
| msg.text(f"Text: {doc.text}") | ||
| msg.text(f"Entities: {[(ent.text, ent.label_) for ent in doc.ents]}") | ||
|
|
||
| msg.text("Relations:") | ||
| for r in doc._.rel: | ||
| msg.text(f" - {doc.ents[r.dep]} [{r.relation}] {doc.ents[r.dest]}") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| typer.run(run_pipeline) |
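
Review note: the final loop indexes `doc.ents` with ids that come straight from the model response, so an id that does not correspond to an extracted entity would raise an `IndexError`. The following is a minimal sketch of a defensive variant, using plain strings and tuples in place of spaCy objects. `format_relations` is a hypothetical helper and not part of the PR.

```python
from typing import List, Sequence, Tuple


def format_relations(
    ents: Sequence[str], rels: Sequence[Tuple[int, int, str]]
) -> List[str]:
    """Format (dep, dest, relation) triples, skipping out-of-range entity ids."""
    lines = []
    for dep, dest, relation in rels:
        if not (0 <= dep < len(ents) and 0 <= dest < len(ents)):
            # The model referred to an entity that was not extracted.
            continue
        lines.append(f"  - {ents[dep]} [{relation}] {ents[dest]}")
    return lines


for line in format_relations(["Laura", "Boston"], [(0, 1, "LivesIn"), (0, 5, "Visits")]):
    print(line)
```

Whether invalid ids should be dropped silently, logged, or filtered earlier in `parse_responses` is left open here.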