Skip to content

Commit aa54a2e

Browse files
author
Hiromu Hota
committed
Add nullables to candidate_subclass() (fix #496)
1 parent 8a489a6 commit aa54a2e

File tree

5 files changed

+61
-8
lines changed

5 files changed

+61
-8
lines changed

CHANGELOG.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ Added
1818
* `@wajdikhattel`_: Add multinary candidates.
1919
(`#455 <https://github.com/HazyResearch/fonduer/issues/455>`_)
2020
(`#456 <https://github.com/HazyResearch/fonduer/pull/456>`_)
21+
* `@HiromuHota`_: Add ``nullables`` to :func:`candidate_subclass()` to allow a NULL mention in a candidate.
22+
(`#496 <https://github.com/HazyResearch/fonduer/issues/496>`_)
23+
(`#497 <https://github.com/HazyResearch/fonduer/pull/497>`_)
2124

2225
Changed
2326
^^^^^^^

src/fonduer/candidates/candidates.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -265,8 +265,11 @@ def apply( # type: ignore
265265
enumerate(
266266
# a list of mentions for each mention subclass within a doc
267267
getattr(doc, mention.__tablename__ + "s")
268+
+ ([None] if nullable else [])
269+
)
270+
for mention, nullable in zip(
271+
candidate_class.mentions, candidate_class.nullables
268272
)
269-
for mention in candidate_class.mentions
270273
]
271274
)
272275
# Get a set of stable_ids of candidates.
@@ -286,15 +289,16 @@ def apply( # type: ignore
286289

287290
# TODO: Make this work for higher-order relations
288291
if self.arities[i] == 2:
289-
ai, a = (cand[0][0], cand[0][1].context)
290-
bi, b = (cand[1][0], cand[1][1].context)
292+
ai, a = (cand[0][0], cand[0][1].context if cand[0][1] else None)
293+
bi, b = (cand[1][0], cand[1][1].context if cand[1][1] else None)
291294

292295
# Check for self-joins, "nested" joins (joins from context to
293296
# its subcontext), and flipped duplicate "symmetric" relations
294297
if not self.self_relations and a == b:
295298
logger.debug(f"Skipping self-joined candidate {cand}")
296299
continue
297-
if not self.nested_relations and (a in b or b in a):
300+
# Skip the check if either is None as None is not iterable.
301+
if not self.nested_relations and (a and b) and (a in b or b in a):
298302
logger.debug(f"Skipping nested candidate {cand}")
299303
continue
300304
if not self.symmetric_relations and ai > bi:
@@ -306,7 +310,8 @@ def apply( # type: ignore
306310
candidate_args[arg_name] = cand[j][1]
307311

308312
stable_ids = tuple(
309-
cand[j][1].context.get_stable_id() for j in range(self.arities[i])
313+
cand[j][1].context.get_stable_id() if cand[j][1] else None
314+
for j in range(self.arities[i])
310315
)
311316
# Skip if this (temporary) candidate is used by this candidate class.
312317
if (

src/fonduer/candidates/models/candidate.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def candidate_subclass(
7474
table_name: Optional[str] = None,
7575
cardinality: Optional[int] = None,
7676
values: Optional[List[Any]] = None,
77+
nullables: Optional[List[bool]] = None,
7778
) -> Type[Candidate]:
7879
"""Create new relation.
7980
@@ -95,6 +96,10 @@ def candidate_subclass(
9596
:param cardinality: The cardinality of the variable corresponding to the
9697
Candidate. By default is 2 i.e. is a binary value, e.g. is or is not
9798
a true mention.
99+
:param values: A list of values a candidate can take as its label.
100+
:param nullables: A list of booleans, one per element of args; the number of nullables must match that of args.
101+
If nullables[i] is True, the mention for the i-th mention subclass can be NULL.
102+
If nullables=``None`` (by default), no mention can be NULL.
98103
"""
99104
if table_name is None:
100105
table_name = camel_to_under(class_name)
@@ -124,6 +129,12 @@ def candidate_subclass(
124129
elif cardinality is not None:
125130
values = list(range(cardinality))
126131

132+
if nullables:
133+
if len(nullables) != len(args):
134+
raise ValueError("The number of nullables must match that of args.")
135+
else:
136+
nullables = [False] * len(args)
137+
127138
class_spec = (args, table_name, cardinality, values)
128139
if class_name in candidate_subclasses:
129140
if class_spec == candidate_subclasses[class_name][1]:
@@ -153,6 +164,7 @@ def candidate_subclass(
153164
# Helper method to get argument names
154165
"__argnames__": [_.__tablename__ for _ in args],
155166
"mentions": args,
167+
"nullables": nullables,
156168
}
157169
class_attribs["document_id"] = Column(
158170
Integer, ForeignKey("document.id", ondelete="CASCADE")
@@ -166,10 +178,12 @@ def candidate_subclass(
166178
# Create named arguments, i.e. the entity mentions comprising the
167179
# relation mention.
168180
unique_args = []
169-
for arg in args:
181+
for arg, nullable in zip(args, nullables):
170182
# Primary arguments are constituent Contexts, and their ids
171183
class_attribs[arg.__tablename__ + "_id"] = Column(
172-
Integer, ForeignKey(arg.__tablename__ + ".id", ondelete="CASCADE")
184+
Integer,
185+
ForeignKey(arg.__tablename__ + ".id", ondelete="CASCADE"),
186+
nullable=nullable,
173187
)
174188
class_attribs[arg.__tablename__] = relationship(
175189
arg.__name__,

src/fonduer/utils/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def get_set_of_stable_ids(
6161
set_of_stable_ids.update(
6262
set(
6363
[
64-
tuple(m.context.get_stable_id() for m in c)
64+
tuple(m.context.get_stable_id() for m in c) if c else None
6565
for c in getattr(doc, candidate_class.__tablename__ + "s")
6666
]
6767
)

tests/candidates/test_candidates.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,3 +541,34 @@ def test_pickle_subclasses():
541541
pickle.loads(pickle.dumps(part))
542542
pickle.loads(pickle.dumps(temp))
543543
pickle.loads(pickle.dumps(parttemp))
544+
545+
546+
def test_candidate_with_nullable_mentions():
547+
"""Test if mentions can be NULL."""
548+
docs_path = "tests/data/html/112823.html"
549+
pdf_path = "tests/data/pdf/112823.pdf"
550+
doc = parse_doc(docs_path, "112823", pdf_path)
551+
552+
# Mention Extraction
553+
MentionTemp = mention_subclass("MentionTemp")
554+
temp_ngrams = MentionNgramsTemp(n_max=2)
555+
mention_extractor_udf = MentionExtractorUDF(
556+
[MentionTemp],
557+
[temp_ngrams],
558+
[temp_matcher],
559+
)
560+
doc = mention_extractor_udf.apply(doc)
561+
562+
assert len(doc.mention_temps) == 23
563+
564+
# Candidate Extraction
565+
CandidateTemp = candidate_subclass("CandidateTemp", [MentionTemp], nullables=[True])
566+
candidate_extractor_udf = CandidateExtractorUDF(
567+
[CandidateTemp], [None], False, False, True
568+
)
569+
570+
doc = candidate_extractor_udf.apply(doc, split=0)
571+
# The number of extracted candidates should be that of mentions + 1 (NULL)
572+
assert len(doc.candidate_temps) == len(doc.mention_temps) + 1
573+
# Extracted candidates should include one with NULL mention.
574+
assert None in [c[0] for c in doc.candidate_temps]

0 commit comments

Comments
 (0)