
Commit 518a5a6

xiaomengy authored and facebook-github-bot committed
[fix] Fix test_nucleus_sampling after pytorch multinomial fix (#684)
Summary: Pull Request resolved: #684
Fix test_nucleus_sampling after pytorch multinomial fix
Reviewed By: vedanuj
Differential Revision: D24997596
fbshipit-source-id: 27b98d1289a36d151abf60e3a2ad0fb150da4139
1 parent 8ff3f56 commit 518a5a6

File tree: 1 file changed (+21 lines, -13 lines)


tests/utils/test_text.py

Lines changed: 21 additions & 13 deletions
@@ -10,6 +10,7 @@
 from mmf.utils.configuration import Configuration
 from mmf.utils.env import setup_imports
 from mmf.utils.general import get_mmf_root
+from packaging.version import LegacyVersion
 from tests.test_utils import dummy_args
 from tests.utils.test_model import TestDecoderModel

@@ -152,19 +153,26 @@ def test_nucleus_sampling(self):
         tokens = model(sample_list)["captions"]

         # these are expected tokens for sum_threshold = 0.5
-        expected_tokens = [
-            1.0,
-            29.0,
-            11.0,
-            11.0,
-            39.0,
-            10.0,
-            31.0,
-            4.0,
-            19.0,
-            39.0,
-            2.0,
-        ]
+
+        # Because of a bug fix in https://github.com/pytorch/pytorch/pull/47386
+        # the torch.Tensor.multinomial will generate a different random sequence.
+        # TODO: Remove this hack after OSS uses a later version of PyTorch.
+        if LegacyVersion(torch.__version__) > LegacyVersion("1.7.1"):
+            expected_tokens = [1.0, 23.0, 38.0, 30.0, 5.0, 11.0, 2.0]
+        else:
+            expected_tokens = [
+                1.0,
+                29.0,
+                11.0,
+                11.0,
+                39.0,
+                10.0,
+                31.0,
+                4.0,
+                19.0,
+                39.0,
+                2.0,
+            ]

         self.assertEqual(tokens[0].tolist(), expected_tokens)
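
For context on why a change to multinomial shifts the expected captions: nucleus (top-p) sampling keeps the smallest set of highest-probability tokens whose cumulative mass exceeds a threshold (the sum_threshold = 0.5 mentioned in the test comment) and then draws from that set with torch.multinomial, so any change to multinomial's draw sequence under a fixed seed changes the sampled token ids. The sketch below is illustrative only, not MMF's decoder implementation; the function name nucleus_sample and the use of raw logits are assumptions made for the example.

import torch


def nucleus_sample(logits: torch.Tensor, sum_threshold: float = 0.5) -> int:
    # Convert logits to probabilities and sort them from most to least likely.
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_ids = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)

    # Keep the top tokens until the cumulative mass passes the threshold;
    # always keep at least one token.
    cutoff = int((cumulative < sum_threshold).sum().item()) + 1
    kept_probs = sorted_probs[:cutoff]

    # torch.multinomial renormalizes the kept weights and draws one index.
    # The fix in pytorch/pytorch#47386 changed the sequence of draws that
    # multinomial produces for a fixed seed, which is why the test above
    # pins different expected token ids per PyTorch version.
    choice = torch.multinomial(kept_probs, num_samples=1)
    return int(sorted_ids[choice].item())


# Deterministic for a fixed seed on a given PyTorch build, but the exact
# value can differ across builds that do or do not include the fix.
torch.manual_seed(1234)
print(nucleus_sample(torch.randn(40)))

Running the snippet twice on the same PyTorch build gives identical output for a fixed seed, but the output can differ across builds with and without pytorch/pytorch#47386, which is exactly what the version gate in the test accounts for.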
