Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ Added
* `@wajdikhattel`_: Add multinary candidates.
(`#455 <https://github.com/HazyResearch/fonduer/issues/455>`_)
(`#456 <https://github.com/HazyResearch/fonduer/pull/456>`_)
* `@HiromuHota`_: Add ``nullables`` to :func:`candidate_subclass()` to allow NULL mention in a candidate.
(`#496 <https://github.com/HazyResearch/fonduer/issues/496>`_)
(`#497 <https://github.com/HazyResearch/fonduer/pull/497>`_)

Changed
^^^^^^^
Expand Down
15 changes: 10 additions & 5 deletions src/fonduer/candidates/candidates.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,11 @@ def apply( # type: ignore
enumerate(
# a list of mentions for each mention subclass within a doc
getattr(doc, mention.__tablename__ + "s")
+ ([None] if nullable else [])
)
for mention, nullable in zip(
candidate_class.mentions, candidate_class.nullables
)
for mention in candidate_class.mentions
]
)
# Get a set of stable_ids of candidates.
Expand All @@ -286,15 +289,16 @@ def apply( # type: ignore

# TODO: Make this work for higher-order relations
if self.arities[i] == 2:
ai, a = (cand[0][0], cand[0][1].context)
bi, b = (cand[1][0], cand[1][1].context)
ai, a = (cand[0][0], cand[0][1].context if cand[0][1] else None)
bi, b = (cand[1][0], cand[1][1].context if cand[1][1] else None)

# Check for self-joins, "nested" joins (joins from context to
# its subcontext), and flipped duplicate "symmetric" relations
if not self.self_relations and a == b:
logger.debug(f"Skipping self-joined candidate {cand}")
continue
if not self.nested_relations and (a in b or b in a):
# Skip the check if either is None as None is not iterable.
if not self.nested_relations and (a and b) and (a in b or b in a):
logger.debug(f"Skipping nested candidate {cand}")
continue
if not self.symmetric_relations and ai > bi:
Expand All @@ -306,7 +310,8 @@ def apply( # type: ignore
candidate_args[arg_name] = cand[j][1]

stable_ids = tuple(
cand[j][1].context.get_stable_id() for j in range(self.arities[i])
cand[j][1].context.get_stable_id() if cand[j][1] else None
for j in range(self.arities[i])
)
# Skip if this (temporary) candidate is used by this candidate class.
if (
Expand Down
18 changes: 16 additions & 2 deletions src/fonduer/candidates/models/candidate.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def candidate_subclass(
table_name: Optional[str] = None,
cardinality: Optional[int] = None,
values: Optional[List[Any]] = None,
nullables: Optional[List[bool]] = None,
) -> Type[Candidate]:
"""Create new relation.

Expand All @@ -95,6 +96,10 @@ def candidate_subclass(
:param cardinality: The cardinality of the variable corresponding to the
Candidate. By default is 2 i.e. is a binary value, e.g. is or is not
a true mention.
:param values: A list of values a candidate can take as their label.
:param nullables: The number of nullables must match that of args.
If nullables[i]==True, a mention for ith mention subclass can be NULL.
If nullables=``None`` (by default), no mention can be NULL.
"""
if table_name is None:
table_name = camel_to_under(class_name)
Expand Down Expand Up @@ -124,6 +129,12 @@ def candidate_subclass(
elif cardinality is not None:
values = list(range(cardinality))

if nullables:
if len(nullables) != len(args):
raise ValueError("The number of nullables must match that of args.")
else:
nullables = [False] * len(args)

class_spec = (args, table_name, cardinality, values)
if class_name in candidate_subclasses:
if class_spec == candidate_subclasses[class_name][1]:
Expand Down Expand Up @@ -153,6 +164,7 @@ def candidate_subclass(
# Helper method to get argument names
"__argnames__": [_.__tablename__ for _ in args],
"mentions": args,
"nullables": nullables,
}
class_attribs["document_id"] = Column(
Integer, ForeignKey("document.id", ondelete="CASCADE")
Expand All @@ -166,10 +178,12 @@ def candidate_subclass(
# Create named arguments, i.e. the entity mentions comprising the
# relation mention.
unique_args = []
for arg in args:
for arg, nullable in zip(args, nullables):
# Primary arguments are constituent Contexts, and their ids
class_attribs[arg.__tablename__ + "_id"] = Column(
Integer, ForeignKey(arg.__tablename__ + ".id", ondelete="CASCADE")
Integer,
ForeignKey(arg.__tablename__ + ".id", ondelete="CASCADE"),
nullable=nullable,
)
class_attribs[arg.__tablename__] = relationship(
arg.__name__,
Expand Down
2 changes: 1 addition & 1 deletion src/fonduer/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_set_of_stable_ids(
set_of_stable_ids.update(
set(
[
tuple(m.context.get_stable_id() for m in c)
tuple(m.context.get_stable_id() for m in c) if c else None
for c in getattr(doc, candidate_class.__tablename__ + "s")
]
)
Expand Down
31 changes: 31 additions & 0 deletions tests/candidates/test_candidates.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,3 +541,34 @@ def test_pickle_subclasses():
pickle.loads(pickle.dumps(part))
pickle.loads(pickle.dumps(temp))
pickle.loads(pickle.dumps(parttemp))


def test_candidate_with_nullable_mentions():
"""Test if mentions can be NULL."""
docs_path = "tests/data/html/112823.html"
pdf_path = "tests/data/pdf/112823.pdf"
doc = parse_doc(docs_path, "112823", pdf_path)

# Mention Extraction
MentionTemp = mention_subclass("MentionTemp")
temp_ngrams = MentionNgramsTemp(n_max=2)
mention_extractor_udf = MentionExtractorUDF(
[MentionTemp],
[temp_ngrams],
[temp_matcher],
)
doc = mention_extractor_udf.apply(doc)

assert len(doc.mention_temps) == 23

# Candidate Extraction
CandidateTemp = candidate_subclass("CandidateTemp", [MentionTemp], nullables=[True])
candidate_extractor_udf = CandidateExtractorUDF(
[CandidateTemp], [None], False, False, True
)

doc = candidate_extractor_udf.apply(doc, split=0)
# The number of extracted candidates should be that of mentions + 1 (NULL)
assert len(doc.candidate_temps) == len(doc.mention_temps) + 1
# Extracted candidates should include one with NULL mention.
assert None in [c[0] for c in doc.candidate_temps]