Skip to content

Commit c2477d2

Browse files
committed
feat: add support multary relation
1 parent f3ecd4a commit c2477d2

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

treedlib/features.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
from treedlib.templates import *
22
import lxml.etree as et
33

4-
def compile_relation_feature_generator(dictionaries=None, opts={}):
4+
def compile_relation_feature_generator(dictionaries=None, opts={}, is_multary=False):
55
"""
66
Given optional arguments, returns a generator function which accepts an xml root
77
and two lists of mention indexes, and will generate relation features for this relation
88
99
Optional args are:
1010
* dictionaries: should be a dictionary of lists of phrases, where the key is the dict name
1111
* opts: see defaults above
12+
* is_multary: whether to use multiple mentions or binary mentions
1213
"""
1314
# TODO: put globals into opts
1415
#BASIC_ATTRIBS_REL = ['word', 'lemma', 'pos', 'ner', 'dep_label']
@@ -66,6 +67,8 @@ def compile_relation_feature_generator(dictionaries=None, opts={}):
6667
templates.append(DictionaryIntersect(SeqBetween(), d_name, d))
6768

6869
# return generator function
70+
if is_multary:
71+
return Compile(templates).apply_multary_relation
6972
return Compile(templates).apply_relation
7073

7174
"""

treedlib/templates.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,5 +428,8 @@ def apply_mention(self, root, mention_idxs, dict_sub={}, stopwords=None):
428428
def apply_relation(self, root, mention1_idxs, mention2_idxs, dict_sub={}, stopwords=None):
429429
return self.apply(root, [mention1_idxs, mention2_idxs], dict_sub=dict_sub, stopwords=stopwords)
430430

431+
def apply_multary_relation(self, root, mentions, dict_sub={}, stopwords=None):
432+
return self.apply(root, mentions, dict_sub=dict_sub, stopwords=stopwords)
433+
431434
def __repr__(self):
432435
return '\n'.join(str(op) for op in self._iterops())

0 commit comments

Comments
 (0)