Skip to content
This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit b5ded8f

Browse files
[v0.8.x][BUGFIX] avoid using dict for attention cell parameter creation (#1051)
* avoid using dict for attention cell parameter creation * fix * Update test_sequence_sampler.py * Update cache.py
1 parent 88fb92e commit b5ded8f

File tree

3 files changed

+6
-10
lines changed

3 files changed

+6
-10
lines changed

src/gluonnlp/model/attention_cell.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,8 @@ def __init__(self, base_cell, query_units, key_units, value_units, num_heads, us
209209
self._base_cell = base_cell
210210
self._num_heads = num_heads
211211
self._use_bias = use_bias
212-
units = {'query': query_units, 'key': key_units, 'value': value_units}
213-
for name, unit in units.items():
212+
units = [('query', query_units), ('key', key_units), ('value', value_units)]
213+
for name, unit in units:
214214
if unit % self._num_heads != 0:
215215
raise ValueError(
216216
'In MultiHeadAttetion, the {name}_units should be divided exactly'

src/gluonnlp/model/train/cache.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,17 +73,13 @@ def __init__(self, lm_model, vocab_size, window, theta, lambdas, **kwargs):
7373
with self.name_scope():
7474
self.lm_model = lm_model
7575

76-
def save_parameters(self, filename, deduplicate=False): # pylint: disable=arguments-differ
76+
def save_parameters(self, filename): # pylint: disable=arguments-differ
7777
"""Save parameters to file.
7878
7979
filename : str
8080
Path to file.
81-
deduplicate : bool, default False
82-
If True, save shared parameters only once. Otherwise, if a Block
83-
contains multiple sub-blocks that share parameters, each of the
84-
shared parameters will be separately saved for every sub-block.
8581
"""
86-
self.lm_model.save_parameters(filename, deduplicate=deduplicate)
82+
self.lm_model.save_parameters(filename)
8783

8884
def load_parameters(self, filename, ctx=mx.cpu()): # pylint: disable=arguments-differ
8985
"""Load parameters from file.

tests/unittest/test_sequence_sampler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ def context_free_distribution(step_input, states):
5353
true_dist = dist.softmax().asnumpy()
5454
assert_allclose(true_dist, np.array(emp_dist), atol=0.01, rtol=0.1)
5555

56-
# temporarily disabled model.HybridBeamSearchSampler test
57-
# due to https://github.com/dmlc/gluon-nlp/issues/706
56+
57+
@pytest.mark.skip(reason='https://github.com/dmlc/gluon-nlp/issues/1020')
5858
@pytest.mark.seed(1)
5959
@pytest.mark.parametrize('hybridize', [False, True])
6060
@pytest.mark.parametrize('sampler_cls', [model.BeamSearchSampler, model.HybridBeamSearchSampler])

0 commit comments

Comments
 (0)