Skip to content

Commit 9d794b9

Browse files
YasushiMiyatasenwu
authored andcommitted
Add multi-thread support for Parser._add, Labeler._add and Featurizer._add
1 parent 5ab8d4f commit 9d794b9

File tree

5 files changed

+30
-19
lines changed

5 files changed

+30
-19
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ Changed
3636
* `@YasushiMiyata`_: Changed :class:`UDFRunner`'s and :class:`UDF`'s data commit process as follows:
3737
(`#545 <https://github.com/HazyResearch/fonduer/pull/545>`_)
3838

39-
* Removed ``add`` process in :func:`_apply` in :class:`UDFRunner`.
40-
* Added ``add`` and ``commit`` of ``y`` to :class:`UDF`.
39+
* Removed ``add`` process on single-thread in :func:`_apply` in :class:`UDFRunner`.
40+
* Added ``UDFRunner._add`` of ``y`` on multi-threads to :class:`Parser`, :class:`Labeler` and :class:`Featurizer`.
4141
* Removed ``y`` of document parsed result from ``out_queue`` in :class:`UDF`.
4242

4343
Fixed

src/fonduer/features/featurizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,11 +245,11 @@ def get_keys(self) -> List[FeatureKey]:
245245
"""
246246
return list(get_sparse_matrix_keys(self.session, FeatureKey))
247247

248-
def _add(self, records_list: List[List[Dict[str, Any]]]) -> None:
248+
def _add(self, session: Session, records_list: List[List[Dict[str, Any]]]) -> None:
249249
# Make a flat list of all records from the list of list of records.
250250
# This helps reduce the number of queries needed to update.
251251
all_records = list(itertools.chain.from_iterable(records_list))
252-
batch_upsert_records(self.session, Feature, all_records)
252+
batch_upsert_records(session, Feature, all_records)
253253

254254
def clear(self, train: bool = False, split: int = 0) -> None: # type: ignore
255255
"""Delete Features of each class from the database.

src/fonduer/parser/parser.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,11 @@ def apply( # type: ignore
129129
progress_bar=progress_bar,
130130
)
131131

132-
def _add(self, doc: Union[Document, None]) -> None:
132+
def _add(self, session: Session, doc: Union[Document, None]) -> None:
133133
# Persist the object if no error happens during parsing.
134134
if doc:
135-
self.session.add(doc)
136-
self.session.commit()
135+
session.add(doc)
136+
session.commit()
137137

138138
def clear(self) -> None: # type: ignore
139139
"""Clear all of the ``Context`` objects in the database."""
@@ -156,6 +156,12 @@ def get_documents(self) -> List[Document]:
156156
157157
:return: A list of all ``Documents`` in the database ordered by name.
158158
"""
159+
# return (
160+
# self.session.query(Document, Sentence)
161+
# .join(Sentence, Document.id == Sentence.document_id)
162+
# .all()
163+
# )
164+
# return self.session.query(Sentence).order_by(Sentence.name).all()
159165
return self.session.query(Document).order_by(Document.name).all()
160166

161167

src/fonduer/supervision/labeler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,9 +306,9 @@ def drop_keys(
306306

307307
drop_keys(self.session, LabelKey, key_map)
308308

309-
def _add(self, records_list: List[List[Dict[str, Any]]]) -> None:
309+
def _add(self, session: Session, records_list: List[List[Dict[str, Any]]]) -> None:
310310
for records in records_list:
311-
batch_upsert_records(self.session, self.table, records)
311+
batch_upsert_records(session, self.table, records)
312312

313313
def clear( # type: ignore
314314
self,

src/fonduer/utils/udf.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Any, Collection, Dict, List, Optional, Set, Type, Union
77

88
from sqlalchemy import inspect
9-
from sqlalchemy.orm import Session
9+
from sqlalchemy.orm import Session, scoped_session, sessionmaker
1010

1111
from fonduer.meta import Meta, new_sessionmaker
1212
from fonduer.parser.models.document import Document
@@ -94,7 +94,7 @@ def _after_apply(self, **kwargs: Any) -> None:
9494
"""Execute this method by a single process after apply."""
9595
pass
9696

97-
def _add(self, instance: Any) -> None:
97+
def _add(self, session: Session, instance: Any) -> None:
9898
pass
9999

100100
def _apply(
@@ -114,9 +114,13 @@ def _apply(
114114
# Clear the last documents parsed by the last run
115115
self.last_docs = set()
116116

117+
# Create DB session factory for insert data on each UDF (#545)
118+
session_factory = new_sessionmaker()
117119
# Create UDF Processes
118120
for i in range(parallelism):
119121
udf = self.udf_class(
122+
session_factory=session_factory,
123+
runner=self,
120124
in_queue=in_queue,
121125
out_queue=out_queue,
122126
worker_id=i,
@@ -164,8 +168,6 @@ def in_thread_func() -> None:
164168
# Flush the processes
165169
self.udfs = []
166170

167-
self.session.commit()
168-
169171

170172
class UDF(Process):
171173
"""UDF class."""
@@ -174,6 +176,8 @@ class UDF(Process):
174176

175177
def __init__(
176178
self,
179+
session_factory: sessionmaker = None,
180+
runner: UDFRunner = None,
177181
in_queue: Optional[Queue] = None,
178182
out_queue: Optional[Queue] = None,
179183
worker_id: int = 0,
@@ -187,6 +191,8 @@ def __init__(
187191
"""
188192
super().__init__()
189193
self.daemon = True
194+
self.session_factory = session_factory
195+
self.runner = runner
190196
self.in_queue = in_queue
191197
self.out_queue = out_queue
192198
self.worker_id = worker_id
@@ -201,9 +207,9 @@ def run(self) -> None:
201207
multiprocess setting The basic routine is: get from JoinableQueue,
202208
apply, put / add outputs, loop
203209
"""
204-
# Each UDF starts its own Engine
205-
# See SQLalchemy, using connection pools with multiprocessing.
206-
Session = new_sessionmaker()
210+
# Each UDF get thread local (scoped) session from connection pools
211+
# See SQLalchemy, using scoped sesion with multiprocessing.
212+
Session = scoped_session(self.session_factory)
207213
session = Session()
208214
while True:
209215
doc = self.in_queue.get() # block until an item is available
@@ -214,12 +220,11 @@ def run(self) -> None:
214220
if not inspect(doc).transient:
215221
doc = session.merge(doc, load=False)
216222
y = self.apply(doc, **self.apply_kwargs)
217-
if y:
218-
session.add(y)
219-
session.commit()
223+
self.runner._add(session, y)
220224
self.out_queue.put(doc.name)
221225
session.commit()
222226
session.close()
227+
Session.remove()
223228

224229
def apply(
225230
self, doc: Document, **kwargs: Any

0 commit comments

Comments
 (0)