Skip to content

Commit bbfa56a

Browse files
author
Hiromu Hota
committed
Merge remote-tracking branch 'github/fix/425' into fix/425
2 parents c55295d + d489d15 commit bbfa56a

File tree

5 files changed

+38
-3
lines changed

5 files changed

+38
-3
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ Fixed
99
^^^^^
1010
* `@kaikun213`_: Fix bug in table range difference calculations.
1111
(`#420 <https://github.com/HazyResearch/fonduer/pull/420>`_)
12+
* `@HiromuHota`_: mention_extractor.apply with clear=True now works even if it's not the first run.
13+
(`#424 <https://github.com/HazyResearch/fonduer/pull/424>`_)
1214
* `@HiromuHota`_: Fix :func:`get_horz_ngrams` and :func:`get_vert_ngrams` so that they
1315
work even when the input mention is tabular.
1416
(`#425 <https://github.com/HazyResearch/fonduer/issues/425>`_)

src/fonduer/candidates/mentions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from fonduer.candidates.models.temporary_context import TemporaryContext
2020
from fonduer.parser.models import Context, Document, Sentence
2121
from fonduer.utils.udf import UDF, UDFRunner
22+
from fonduer.utils.utils import get_dict_of_stable_id
2223

2324
logger = logging.getLogger(__name__)
2425

@@ -566,7 +567,7 @@ def apply(self, doc: Document, **kwargs: Any) -> Document:
566567
:param doc: A document to process.
567568
"""
568569
# Get a dict of stable_id of contexts.
569-
dict_of_stable_id: Dict[str, Context] = {}
570+
dict_of_stable_id: Dict[str, Context] = get_dict_of_stable_id(doc)
570571

571572
# Iterate over each mention class
572573
for i, mention_class in enumerate(self.mention_classes):

src/fonduer/utils/data_model_utils/visual.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def _get_direction_ngrams(
271271
yield ngram
272272
else:
273273
for ts in ngrams_space.apply(sentence):
274-
if ( # True if visually aligned AND not from itsself.
274+
if ( # True if visually aligned AND not from itself.
275275
bbox_direction_aligned(bbox_from_span(ts), bbox_from_span(span))
276276
and ts not in span
277277
and span not in ts

src/fonduer/utils/utils.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from builtins import range
33
from typing import TYPE_CHECKING, Dict, Iterator, List, Set, Tuple, Type, Union
44

5-
from fonduer.parser.models import Document, Sentence
5+
from fonduer.parser.models import Context, Document, Sentence
66

77
if TYPE_CHECKING: # to prevent circular imports
88
from fonduer.candidates.models import Candidate
@@ -64,3 +64,29 @@ def get_set_of_stable_ids(
6464
)
6565
)
6666
return set_of_stable_ids
67+
68+
69+
def get_dict_of_stable_id(doc: Document) -> Dict[str, Context]:
70+
"""Return a mapping of a stable_id to its context."""
71+
return {
72+
doc.stable_id: doc,
73+
**{
74+
c.stable_id: c
75+
for a in [
76+
"sentences",
77+
"paragraphs",
78+
"captions",
79+
"cells",
80+
"tables",
81+
"sections",
82+
"figures",
83+
]
84+
for c in getattr(doc, a)
85+
},
86+
**{
87+
c.stable_id: c
88+
for s in doc.sentences
89+
for a in ["spans", "implicit_spans"]
90+
for c in getattr(s, a)
91+
},
92+
}

tests/e2e/test_incremental.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,12 @@ def test_incremental():
9292
assert session.query(Part).count() == 11
9393
assert session.query(Temp).count() == 8
9494

95+
# Test if clear=True works
96+
mention_extractor.apply(docs, parallelism=PARALLEL, clear=True)
97+
98+
assert session.query(Part).count() == 11
99+
assert session.query(Temp).count() == 8
100+
95101
# Candidate Extraction
96102
PartTemp = candidate_subclass("PartTemp", [Part, Temp])
97103

0 commit comments

Comments
 (0)