Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions gensim/models/ldamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,6 +961,9 @@ def rho():
for chunk_no, chunk in enumerate(chunks):
reallen += len(chunk) # keep track of how many documents we've processed so far

if reallen > lencorpus:
raise RuntimeError("processed size is bigger than corpus length (don't use infinite iterators)")

if eval_every and ((reallen == lencorpus) or ((chunk_no + 1) % (eval_every * self.numworkers) == 0)):
self.log_perplexity(chunk, total_docs=lencorpus)

Expand Down
31 changes: 31 additions & 0 deletions gensim/test/test_ldamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,37 @@ def testAlphaAuto(self):
# endclass TestLdaMulticore


class TestLdaIterableTraining(unittest.TestCase, basetmtests.TestBaseTopicModel):
"""Class for testing infinite generators,
for case when size of processed documents
goes beyond corpus length."""

class ProblematicIterable:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no benefit to nesting classes here. Please move this new class to module scope.

def __init__(self):
self.bag_of_words = [(0, 2), (3, 1), (6, 1), (100, 2)]
self.cursor = 0

def __iter__(self):
self.cursor = 0
logging.info('TestIterable() __iter__ was called')
return self

def __next__(self):
if self.cursor < 11:
self.cursor += 1
return self.bag_of_words
else:
logging.info('TestIterable() returned StopIteration')
raise StopIteration

def setUp(self) -> None:
self.corpus = self.ProblematicIterable()
self.class_ = ldamodel.LdaModel

def testDoesntTrainBeyondCorpusSize(self):
with self.assertRaises(RuntimeError):
self.model = self.class_(self.corpus, num_topics=2)

if __name__ == '__main__':
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
unittest.main()