import os

import mxnet as mx
-from mxnet.gluon import Block, HybridBlock, nn
+from mxnet.gluon import HybridBlock, nn
from mxnet.gluon.model_zoo import model_store
import numpy as np

from gluonnlp.model.utils import _load_pretrained_params, _load_vocab


-class GPT2SelfAttentionLayer(Block):
+class GPT2SelfAttentionLayer(HybridBlock):
    """Self-attention layer used in OpenAI GPT-2.

    Parameters
@@ -88,49 +88,54 @@ def __init__(self, units, num_heads, dropout=0.0,
                                      bias_initializer=bias_initializer,
                                      prefix='out_proj_')

-    def forward(self, data, states=None):  # pylint: disable=arguments-differ
-        batch_size = data.shape[0]
-        seq_len = data.shape[1]
+    def hybrid_forward(self, F, data, states=None):  # pylint: disable=arguments-differ
        # Generate mask
        if states is not None:
            prev_key, prev_value = states
-            prev_len = prev_key.shape[2]
+
+            prev_len_range = F.contrib.arange_like(prev_key, axis=2)
+            data_len_range = F.contrib.arange_like(data, axis=1)
+            prev_len = F.broadcast_add(F.slice_axis(prev_len_range, axis=0, begin=-1, end=None),
+                                       F.ones((1, )))
+
+            data_pos = F.broadcast_add(F.contrib.arange_like(data, axis=1), prev_len)
+            all_pos = F.contrib.arange_like(F.concat(prev_len_range, data_len_range, dim=0))
        else:
            prev_key, prev_value = None, None
-            prev_len = 0
-        data_pos = mx.nd.arange(prev_len, prev_len + seq_len, ctx=data.context, dtype=data.dtype)
-        all_pos = mx.nd.arange(seq_len + prev_len, ctx=data.context, dtype=data.dtype)
-        mask = mx.nd.broadcast_lesser_equal(all_pos.reshape((1, -1)), data_pos.reshape((-1, 1)))
-        mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=0), axis=0,
-                                    size=batch_size * self._num_heads)
+            data_pos = F.contrib.arange_like(data, axis=1)
+            all_pos = data_pos
+
+        mask = F.broadcast_lesser_equal(all_pos.reshape((1, -1)), data_pos.reshape((-1, 1)))
+        mask = F.broadcast_like(F.expand_dims(mask, axis=0), data, lhs_axes=(0, ), rhs_axes=(0, ))
+        mask = F.concat(*[mask] * self._num_heads, dim=0)

        # Multi-head attention
        qkv = self._multi_head_qkv_proj(data)  # Shape (batch_size, seq_len, 3 * units)
-        qkv = mx.nd.swapaxes(qkv, 1, 2)  # Shape (batch_size, 3 * units, seq_len)
+        qkv = F.swapaxes(qkv, 1, 2)  # Shape (batch_size, 3 * units, seq_len)

        # Each has shape (batch_size, units, seq_len)
-        query, key, value = mx.nd.split(qkv, num_outputs=3, axis=1)
+        query, key, value = F.split(qkv, num_outputs=3, axis=1)
        # Map each to have shape (batch_size * num_head, ele_units, seq_len)
        query = query.reshape(shape=(0, -4, self._num_heads, -1, 0)).reshape(
            shape=(-1, 0, 0), reverse=True)
        key = key.reshape(shape=(0, -4, self._num_heads, -1, 0)).reshape(
            shape=(-1, 0, 0), reverse=True)
        value = value.reshape(shape=(0, -4, self._num_heads, -1, 0)).reshape(
            shape=(-1, 0, 0), reverse=True)
-        query = mx.nd.swapaxes(query, 1, 2)
-        key = mx.nd.swapaxes(key, 1, 2)
-        value = mx.nd.swapaxes(value, 1, 2)
+        query = F.swapaxes(query, 1, 2)
+        key = F.swapaxes(key, 1, 2)
+        value = F.swapaxes(value, 1, 2)
        if prev_key is not None:
-            key = mx.nd.concat(prev_key.reshape((-1, 0, 0), reverse=True),
-                               key, dim=1)  # Shape (batch_size * num_heads, all_len, ele_units)
+            # Shape (batch_size * num_heads, all_len, ele_units)
+            key = F.concat(prev_key.reshape((-1, 0, 0), reverse=True), key, dim=1)
        if prev_value is not None:
-            value = mx.nd.concat(prev_value.reshape((-1, 0, 0), reverse=True),
-                                 value, dim=1)
+            value = F.concat(prev_value.reshape((-1, 0, 0), reverse=True),
+                             value, dim=1)

        # Shape (batch_size * num_heads, all_len, ele_units)
        out, _ = self._base_attn_cell(query, key, value, mask)
-        out = mx.nd.transpose(out.reshape((-1, self._num_heads, 0, 0), reverse=True),
-                              axes=(0, 2, 1, 3)).reshape((0, 0, -1))
+        out = F.transpose(out.reshape((-1, self._num_heads, 0, 0), reverse=True),
+                          axes=(0, 2, 1, 3)).reshape((0, 0, -1))
        out = self._out_proj(out)
        return out, [key.reshape((-1, self._num_heads, 0, 0), reverse=True),
                     value.reshape((-1, self._num_heads, 0, 0), reverse=True)]
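The shape-free position bookkeeping in the hunk above is what makes the layer hybridizable: `.shape` lookups and `mx.nd.arange` over Python integers are not available on symbols, so token positions and the causal mask are derived with `contrib.arange_like` instead. The following standalone sketch (not part of the diff; the sizes and the imperative `F = mx.nd` are assumptions for illustration, and it needs an MXNet build that ships `contrib.arange_like`) runs the same operators so the resulting mask can be inspected:

```python
import mxnet as mx

F = mx.nd  # the same operators also exist under mx.sym, which is what hybridization uses
batch_size, num_heads, prev_len, seq_len, units = 2, 4, 3, 5, 16

data = F.zeros((batch_size, seq_len, units))                                # current hidden states
prev_key = F.zeros((batch_size, num_heads, prev_len, units // num_heads))   # cached keys

# Ranges derived purely from array shapes -- no .shape calls, so this stays hybridizable.
prev_len_range = F.contrib.arange_like(prev_key, axis=2)            # [0., 1., 2.]
data_len_range = F.contrib.arange_like(data, axis=1)                # [0., 1., 2., 3., 4.]
prev_len_nd = F.broadcast_add(F.slice_axis(prev_len_range, axis=0, begin=-1, end=None),
                              F.ones((1,)))                         # last index + 1 == prev_len
data_pos = F.broadcast_add(data_len_range, prev_len_nd)             # positions 3 .. 7
all_pos = F.contrib.arange_like(F.concat(prev_len_range, data_len_range, dim=0))  # 0 .. 7

# Causal mask: a token at position i may attend to every position j <= i.
mask = F.broadcast_lesser_equal(all_pos.reshape((1, -1)), data_pos.reshape((-1, 1)))
print(mask.asnumpy())  # shape (seq_len, prev_len + seq_len), a lower-triangular band of ones
```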
@@ -186,7 +191,7 @@ def hybrid_forward(self, F, data): # pylint: disable=arguments-differ
        return out


-class GPT2Model(Block):
+class GPT2Model(HybridBlock):
    """Generic Model for GPT-2.

    Parameters
@@ -223,7 +228,7 @@ def __init__(self, units, vocab_size, max_length, num_layers, num_heads, dropout
                                       weight_initializer=mx.init.Normal(0.02))
            self._logits_proj = nn.Dense(units=vocab_size, in_units=units, use_bias=False,
                                         flatten=False, params=self._embed.params)
-            self._self_attention_layers = nn.Sequential()
+            self._self_attention_layers = nn.HybridSequential()
            self._ffn_layers = nn.HybridSequential()
            self._attn_ln = nn.HybridSequential()
            self._ffn_ln = nn.HybridSequential()
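Background on the container swap in this hunk (not part of the diff): a parent `HybridBlock` can only be cast to symbolic mode if every child it calls from `hybrid_forward` is itself hybridizable. `nn.Sequential` is a plain `Block`, while `nn.HybridSequential` is a `HybridBlock` and hybridizes along with its parent. A minimal sketch:

```python
import mxnet as mx
from mxnet.gluon import nn

net = nn.HybridSequential()
with net.name_scope():
    net.add(nn.Dense(16, flatten=False))
    net.add(nn.Dense(16, flatten=False))
net.initialize()
net.hybridize()                    # allowed: every child is itself a HybridBlock
out = net(mx.nd.ones((2, 5, 8)))   # first call traces and caches the symbolic graph
```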
@@ -237,7 +242,7 @@ def __init__(self, units, vocab_size, max_length, num_layers, num_heads, dropout
                self._ffn_ln.add(nn.LayerNorm(prefix='ffn_ln{}_'.format(i)))
            self._final_ln = nn.LayerNorm(prefix='final_ln{}_'.format(i))

-    def forward(self, data, states=None):  # pylint: disable=arguments-differ
+    def hybrid_forward(self, F, data, states=None):  # pylint: disable=arguments-differ
        """

        Parameters
@@ -253,15 +258,18 @@ def forward(self, data, states=None): # pylint: disable=arguments-differ
        new_states : list of NDArray
        """
        new_states = []
-        batch_size, seq_len = data.shape[0], data.shape[1]
        if states is not None:
-            prev_len = states[0].shape[1]
+            prev_key = states[0]
+            prev_len_range = F.contrib.arange_like(prev_key, axis=2)
+            prev_len = F.broadcast_add(F.slice_axis(prev_len_range, axis=0, begin=-1, end=None),
+                                       F.ones((1, )))
+            data_pos = F.broadcast_add(F.contrib.arange_like(data, axis=1), prev_len)
        else:
-            prev_len = 0
-        assert seq_len + prev_len <= self._max_length
-        data_pos = mx.nd.arange(prev_len, prev_len + seq_len, ctx=data.context, dtype=np.float32)
-        data_pos = mx.nd.broadcast_axes(mx.nd.expand_dims(data_pos, axis=0),
-                                        axis=0, size=batch_size)
+            data_pos = F.contrib.arange_like(data, axis=1)
+        if F is mx.nd:
+            assert data.shape[1] + (0 if states is None else prev_key.shape[2]) <= self._max_length
+        data_pos = F.broadcast_like(F.expand_dims(data_pos, axis=0), data,
+                                    lhs_axes=(0, ), rhs_axes=(0, ))
        out = self._embed(data) + self._pos_embed(data_pos)
        for i in range(self._num_layers):
            attn_layer = self._self_attention_layers[i]
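Taken together, the changes replace every shape lookup and `mx.nd.*` call in the forward passes with `F`-level operators (`contrib.arange_like`, `broadcast_like`, `slice_axis`), leaving only the `.shape`-based length check behind the `if F is mx.nd:` guard so it runs imperatively alone. A rough sketch of what this enables; the model-zoo entry point and names below are illustrative assumptions, not taken from the diff:

```python
import mxnet as mx
import gluonnlp as nlp

# Illustrative names only -- adjust to the model zoo entry points in the patched checkout.
model, vocab = nlp.model.get_model('gpt2_117m', dataset_name='openai_webtext',
                                   pretrained=True)
model.hybridize()

token_ids = mx.nd.ones((1, 8))     # a dummy batch of 8 token ids
logits, states = model(token_ids)  # first call traces and caches the symbolic graph
# subsequent calls can feed `states` back in for incremental decoding
```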