This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit 5c87d6d

Make GPT2Model a HybridBlock
1 parent e09281c commit 5c87d6d
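
For context: a Gluon Block runs imperatively and may freely read concrete tensor shapes, while a HybridBlock implements hybrid_forward(self, F, ...), where F is mx.nd before hybridize() and mx.sym afterwards. Symbols carry no concrete shape, so every data.shape access in the old forward methods has to be replaced with shape-agnostic operators such as F.contrib.arange_like and F.broadcast_like. A minimal sketch of the pattern this commit applies throughout (CausalMask is a hypothetical standalone example, not code from this repository):

import mxnet as mx
from mxnet.gluon import HybridBlock

class CausalMask(HybridBlock):
    """Hypothetical example: build a (seq_len, seq_len) causal mask
    without ever reading data.shape, so the block stays hybridizable."""
    def hybrid_forward(self, F, data):  # data: (batch_size, seq_len, units)
        pos = F.contrib.arange_like(data, axis=1)  # [0, 1, ..., seq_len - 1]
        # mask[i, j] = 1 iff position j is visible from position i
        return F.broadcast_lesser_equal(pos.reshape((1, -1)),
                                        pos.reshape((-1, 1)))

mask_block = CausalMask()
mask_block.hybridize()                    # traced into a symbolic graph on first call
print(mask_block(mx.nd.ones((2, 4, 8))))  # 4 x 4 lower-triangular matrix of ones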

1 file changed, 41 insertions(+), 33 deletions(-)

scripts/text_generation/model/gpt.py

@@ -22,7 +22,7 @@
 import os
 
 import mxnet as mx
-from mxnet.gluon import Block, HybridBlock, nn
+from mxnet.gluon import HybridBlock, nn
 from mxnet.gluon.model_zoo import model_store
 import numpy as np
 
@@ -32,7 +32,7 @@
 from gluonnlp.model.utils import _load_pretrained_params, _load_vocab
 
 
-class GPT2SelfAttentionLayer(Block):
+class GPT2SelfAttentionLayer(HybridBlock):
     """Self-attention layer used in OpenAI GPT-2.
 
     Parameters
@@ -88,49 +88,54 @@ def __init__(self, units, num_heads, dropout=0.0,
                                       bias_initializer=bias_initializer,
                                       prefix='out_proj_')
 
-    def forward(self, data, states=None): # pylint: disable=arguments-differ
-        batch_size = data.shape[0]
-        seq_len = data.shape[1]
+    def hybrid_forward(self, F, data, states=None): # pylint: disable=arguments-differ
         # Generate mask
         if states is not None:
             prev_key, prev_value = states
-            prev_len = prev_key.shape[2]
+
+            prev_len_range = F.contrib.arange_like(prev_key, axis=2)
+            data_len_range = F.contrib.arange_like(data, axis=1)
+            prev_len = F.broadcast_add(F.slice_axis(prev_len_range, axis=0, begin=-1, end=None),
                                        F.ones((1, )))
+
+            data_pos = F.broadcast_add(F.contrib.arange_like(data, axis=1), prev_len)
+            all_pos = F.contrib.arange_like(F.concat(prev_len_range, data_len_range, dim=0))
         else:
             prev_key, prev_value = None, None
-            prev_len = 0
-        data_pos = mx.nd.arange(prev_len, prev_len + seq_len, ctx=data.context, dtype=data.dtype)
-        all_pos = mx.nd.arange(seq_len + prev_len, ctx=data.context, dtype=data.dtype)
-        mask = mx.nd.broadcast_lesser_equal(all_pos.reshape((1, -1)), data_pos.reshape((-1, 1)))
-        mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=0), axis=0,
-                                    size=batch_size * self._num_heads)
+            data_pos = F.contrib.arange_like(data, axis=1)
+            all_pos = data_pos
+
+        mask = F.broadcast_lesser_equal(all_pos.reshape((1, -1)), data_pos.reshape((-1, 1)))
+        mask = F.broadcast_like(F.expand_dims(mask, axis=0), data, lhs_axes=(0, ), rhs_axes=(0, ))
+        mask = F.concat(*[mask] * self._num_heads, dim=0)
 
         # Multi-head attention
         qkv = self._multi_head_qkv_proj(data)  # Shape (batch_size, seq_len, 3 * units)
-        qkv = mx.nd.swapaxes(qkv, 1, 2)  # Shape (batch_size, 3 * units, seq_len)
+        qkv = F.swapaxes(qkv, 1, 2)  # Shape (batch_size, 3 * units, seq_len)
 
         # Each has shape (batch_size, units, seq_len)
-        query, key, value = mx.nd.split(qkv, num_outputs=3, axis=1)
+        query, key, value = F.split(qkv, num_outputs=3, axis=1)
         # Map each to have shape (batch_size * num_head, ele_units, seq_len)
         query = query.reshape(shape=(0, -4, self._num_heads, -1, 0)).reshape(
             shape=(-1, 0, 0), reverse=True)
         key = key.reshape(shape=(0, -4, self._num_heads, -1, 0)).reshape(
             shape=(-1, 0, 0), reverse=True)
         value = value.reshape(shape=(0, -4, self._num_heads, -1, 0)).reshape(
             shape=(-1, 0, 0), reverse=True)
-        query = mx.nd.swapaxes(query, 1, 2)
-        key = mx.nd.swapaxes(key, 1, 2)
-        value = mx.nd.swapaxes(value, 1, 2)
+        query = F.swapaxes(query, 1, 2)
+        key = F.swapaxes(key, 1, 2)
+        value = F.swapaxes(value, 1, 2)
         if prev_key is not None:
-            key = mx.nd.concat(prev_key.reshape((-1, 0, 0), reverse=True),
-                               key, dim=1)  # Shape (batch_size * num_heads, all_len, ele_units)
+            # Shape (batch_size * num_heads, all_len, ele_units)
+            key = F.concat(prev_key.reshape((-1, 0, 0), reverse=True), key, dim=1)
         if prev_value is not None:
-            value = mx.nd.concat(prev_value.reshape((-1, 0, 0), reverse=True),
-                                 value, dim=1)
+            value = F.concat(prev_value.reshape((-1, 0, 0), reverse=True),
                              value, dim=1)
 
         # Shape (batch_size * num_heads, all_len, ele_units)
         out, _ = self._base_attn_cell(query, key, value, mask)
-        out = mx.nd.transpose(out.reshape((-1, self._num_heads, 0, 0), reverse=True),
-                              axes=(0, 2, 1, 3)).reshape((0, 0, -1))
+        out = F.transpose(out.reshape((-1, self._num_heads, 0, 0), reverse=True),
                           axes=(0, 2, 1, 3)).reshape((0, 0, -1))
         out = self._out_proj(out)
         return out, [key.reshape((-1, self._num_heads, 0, 0), reverse=True),
                      value.reshape((-1, self._num_heads, 0, 0), reverse=True)]
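
The least obvious substitution in the hunk above is the recovery of the cache length: prev_key.shape[2] does not exist on a symbol, so the code builds F.contrib.arange_like(prev_key, axis=2) and takes its last element plus one, yielding the length as a 1-element tensor rather than a Python int. all_pos is rebuilt the same way: concatenating the two ranges gives a tensor of length prev_len + seq_len, and a fresh arange_like over it enumerates positions 0 .. prev_len + seq_len - 1. A standalone illustration of the trick, with shapes invented for the demo:

import mxnet as mx

# Hypothetical cached keys: (batch_size, num_heads, prev_len, ele_units)
prev_key = mx.nd.ones((2, 12, 5, 64))
data = mx.nd.ones((2, 3, 768))  # (batch_size, seq_len, units)

prev_len_range = mx.nd.contrib.arange_like(prev_key, axis=2)  # [0. 1. 2. 3. 4.]
data_len_range = mx.nd.contrib.arange_like(data, axis=1)      # [0. 1. 2.]

# Cache length as a 1-element tensor: last index of the range, plus one
prev_len = mx.nd.broadcast_add(
    mx.nd.slice_axis(prev_len_range, axis=0, begin=-1, end=None),
    mx.nd.ones((1,)))                                         # [5.]

# Positions of the new tokens, and of all tokens (cache + new)
data_pos = mx.nd.broadcast_add(data_len_range, prev_len)      # [5. 6. 7.]
all_pos = mx.nd.contrib.arange_like(
    mx.nd.concat(prev_len_range, data_len_range, dim=0))      # [0. 1. ... 7.]
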
@@ -186,7 +191,7 @@ def hybrid_forward(self, F, data): # pylint: disable=arguments-differ
         return out
 
 
-class GPT2Model(Block):
+class GPT2Model(HybridBlock):
     """Generic Model for GPT-2.
 
     Parameters
@@ -223,7 +228,7 @@ def __init__(self, units, vocab_size, max_length, num_layers, num_heads, dropout
                                        weight_initializer=mx.init.Normal(0.02))
         self._logits_proj = nn.Dense(units=vocab_size, in_units=units, use_bias=False,
                                      flatten=False, params=self._embed.params)
-        self._self_attention_layers = nn.Sequential()
+        self._self_attention_layers = nn.HybridSequential()
         self._ffn_layers = nn.HybridSequential()
         self._attn_ln = nn.HybridSequential()
         self._ffn_ln = nn.HybridSequential()
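
Note the container swap in this hunk: hybridize() compiles a block into one symbolic graph only if the block and all of its children are HybridBlocks, and nn.Sequential is a plain Block container. Since GPT2SelfAttentionLayer is now hybrid, its container becomes nn.HybridSequential as well. A trivial standalone sketch of the distinction:

import mxnet as mx
from mxnet.gluon import nn

net = nn.HybridSequential()
net.add(nn.Dense(4, in_units=8), nn.Dense(2, in_units=4))
net.initialize()
net.hybridize()  # fine: the container and every child are HybridBlocks
print(net(mx.nd.ones((1, 8))))
# With nn.Sequential, hybridize() would leave the container itself
# imperative and only hybridize the children individually.
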
@@ -237,7 +242,7 @@ def __init__(self, units, vocab_size, max_length, num_layers, num_heads, dropout
             self._ffn_ln.add(nn.LayerNorm(prefix='ffn_ln{}_'.format(i)))
         self._final_ln = nn.LayerNorm(prefix='final_ln{}_'.format(i))
 
-    def forward(self, data, states=None): # pylint: disable=arguments-differ
+    def hybrid_forward(self, F, data, states=None): # pylint: disable=arguments-differ
         """
 
         Parameters
@@ -253,15 +258,18 @@ def forward(self, data, states=None): # pylint: disable=arguments-differ
         new_states : list of NDArray
         """
         new_states = []
-        batch_size, seq_len = data.shape[0], data.shape[1]
         if states is not None:
-            prev_len = states[0].shape[1]
+            prev_key, _ = states
+            prev_len_range = F.contrib.arange_like(prev_key, axis=2)
+            prev_len = F.broadcast_add(F.slice_axis(prev_len_range, axis=0, begin=-1, end=None),
                                        F.ones((1, )))
+            data_pos = F.broadcast_add(F.contrib.arange_like(data, axis=1), prev_len)
         else:
-            prev_len = 0
-        assert seq_len + prev_len <= self._max_length
-        data_pos = mx.nd.arange(prev_len, prev_len + seq_len, ctx=data.context, dtype=np.float32)
-        data_pos = mx.nd.broadcast_axes(mx.nd.expand_dims(data_pos, axis=0),
-                                        axis=0, size=batch_size)
+            data_pos = F.contrib.arange_like(data, axis=1)
+        if F is mx.nd:
+            assert data.shape[1] + prev_key.shape[2] <= self._max_length
+        data_pos = F.broadcast_like(F.expand_dims(data_pos, axis=0), data,
                                     lhs_axes=(0, ), rhs_axes=(0, ))
         out = self._embed(data) + self._pos_embed(data_pos)
         for i in range(self._num_layers):
             attn_layer = self._self_attention_layers[i]
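
Two remarks on this last hunk: the batch axis is now recovered with F.broadcast_like instead of a shape lookup, and the max-length check sits behind the F is mx.nd guard because an assertion can only run imperatively; a symbol has no concrete shape to inspect. With both classes hybrid, the model can be compiled end to end. A hypothetical usage sketch (hyperparameters are illustrative GPT-2 117M-style values; the real script constructs the model through model_store and _load_pretrained_params):

import mxnet as mx

model = GPT2Model(units=768, vocab_size=50257, max_length=1024,
                  num_layers=12, num_heads=12, dropout=0.0)
model.initialize(mx.init.Normal(0.02))
model.hybridize()

# First step: no cached states yet
logits, states = model(mx.nd.array([[464, 3290, 318]]))  # three token ids
# Incremental decoding: one new token plus the cached key/value states
logits, states = model(mx.nd.array([[2061]]), states)
print(logits.shape)  # expected (1, 1, 50257): per-position vocabulary logits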
