@@ -42,16 +42,18 @@ enum class qkv_fomat_t

enum class key_mask_mode_t
{
-    direct    = 0,
-    left_pad  = 1,
-    right_pad = 2
+    direct_2d_pad = 0,
+    left_pad      = 1,
+    right_pad     = 2,
+    direct_3d_pad = 3
};

struct multi_head_attention_parameters
{
    int64_t batch_size;
    int64_t q_sequence_length;
    int64_t kv_sequence_length;
+    int64_t total_sequence_length;
    int64_t hidden_size;
    int64_t hidden_size_v;
    int64_t head_size;
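
Note on the hunk above: the annotated copy of the updated enum below is a reading aid only, not part of the change. The comments are inferred from the rank and size checks that `check_key_padding_mask` performs in the next hunk; the left_pad/right_pad reading of the two 1D sizes follows the usual ONNX Runtime `key_padding_mask` convention and should be treated as an assumption here.

// Annotated copy of key_mask_mode_t (comments are not part of the diff)
enum class key_mask_mode_t
{
    direct_2d_pad = 0, // 2D mask: (batch, kv_sequence_length) or (batch, total_sequence_length)
    left_pad      = 1, // 1D mask of size 3 * batch + 2
    right_pad     = 2, // 1D mask of size batch
    direct_3d_pad = 3  // 3D mask: (batch, kv_sequence_length, total_sequence_length)
};
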
@@ -207,6 +209,59 @@ struct parse_multi_head_attention : op_parser<parse_multi_head_attention>
        }
    }

+    void check_key_padding_mask(const std::vector<instruction_ref>& args,
+                                multi_head_attention_parameters& params) const
+    {
+        if(args.size() > 4)
+        {
+            // key_padding_mask is provided as the fifth input (args[4])
+            auto key_pad_mask_shape = args[4]->get_shape();
+            auto key_pad_lens       = key_pad_mask_shape.lens();
+            auto key_pad_len_size   = key_pad_lens.size();
+
+            if(key_pad_len_size > 3 or key_pad_len_size < 1)
+                MIGRAPHX_THROW("MultiHeadAttention: Key_pad_mask must be either a 1D, 2D or 3D shape tensor");
+
+            if(key_pad_len_size == 1)
+            {
+                auto key_pad_size = key_pad_lens.at(0);
+                if(key_pad_size != params.batch_size and
+                   key_pad_size != (3 * params.batch_size + 2))
+                    MIGRAPHX_THROW("MultiHeadAttention: Key Padding Mask must be either batch or 3 x batch + 2 for 1D key pads");
+
+                if(key_pad_size == params.batch_size)
+                {
+                    params.key_pad_mode = key_mask_mode_t::right_pad;
+                }
+                else
+                {
+                    params.key_pad_mode = key_mask_mode_t::left_pad;
+                }
+            }
+            else if(key_pad_len_size == 2)
+            {
+                auto key_pad_batch   = key_pad_lens.at(0);
+                auto key_pad_seq_len = key_pad_lens.at(1);
+
+                if(key_pad_batch != params.batch_size or
+                   (key_pad_seq_len != params.kv_sequence_length and
+                    key_pad_seq_len != params.total_sequence_length))
+                {
+                    MIGRAPHX_THROW("MultiHeadAttention: 2D Keypad mask must be either (batch, kv_sequence_length) or (batch, total_sequence_length)");
+                }
+                params.key_pad_mode = key_mask_mode_t::direct_2d_pad;
+            }
+            else // key_pad_len_size == 3 here
+            {
+                auto key_pad_batch         = key_pad_lens.at(0);
+                auto key_pad_seq_len       = key_pad_lens.at(1);
+                auto key_pad_total_seq_len = key_pad_lens.at(2);
+                if(key_pad_batch != params.batch_size or
+                   key_pad_seq_len != params.kv_sequence_length or
+                   key_pad_total_seq_len != params.total_sequence_length)
+                {
+                    MIGRAPHX_THROW("MultiHeadAttention: 3D Keypad mask must be (batch, kv_sequence_length, total_sequence_length)");
+                }
+                params.key_pad_mode = key_mask_mode_t::direct_3d_pad;
+            }
+        }
+    }
+
    void check_bias(const std::vector<instruction_ref>& args,
                    multi_head_attention_parameters& params) const
    {
@@ -237,6 +292,7 @@ struct parse_multi_head_attention : op_parser<parse_multi_head_attention>
        // This must be used first to extract hidden size, batch, etc
        check_query_dim(args, params);
        check_bias(args, params);
+        check_key_padding_mask(args, params);
    }

std::tuple<instruction_ref, instruction_ref, instruction_ref>
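
To make the new classification easy to eyeball, here is a minimal standalone sketch in plain C++ (no MIGraphX types; `classify_key_padding_mask` and the example shapes are hypothetical, written only for illustration) that mirrors the rank and size checks added in `check_key_padding_mask` and prints which `key_mask_mode_t` each example mask shape maps to.

// Standalone illustration of the key_padding_mask shape classification above.
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

enum class key_mask_mode_t
{
    direct_2d_pad = 0,
    left_pad      = 1,
    right_pad     = 2,
    direct_3d_pad = 3
};

// lens: dimensions of the key_padding_mask tensor
key_mask_mode_t classify_key_padding_mask(const std::vector<int64_t>& lens,
                                          int64_t batch_size,
                                          int64_t kv_sequence_length,
                                          int64_t total_sequence_length)
{
    if(lens.size() < 1 or lens.size() > 3)
        throw std::runtime_error("key_padding_mask must be 1D, 2D or 3D");

    if(lens.size() == 1)
    {
        if(lens[0] == batch_size)
            return key_mask_mode_t::right_pad; // 1D mask of size batch
        if(lens[0] == 3 * batch_size + 2)
            return key_mask_mode_t::left_pad;  // 1D mask of size 3 * batch + 2
        throw std::runtime_error("1D mask must be (batch) or (3 * batch + 2)");
    }
    if(lens.size() == 2)
    {
        if(lens[0] != batch_size or
           (lens[1] != kv_sequence_length and lens[1] != total_sequence_length))
            throw std::runtime_error("2D mask must be (batch, kv_seq_len) or (batch, total_seq_len)");
        return key_mask_mode_t::direct_2d_pad; // mask given directly per position
    }
    if(lens[0] != batch_size or lens[1] != kv_sequence_length or
       lens[2] != total_sequence_length)
        throw std::runtime_error("3D mask must be (batch, kv_seq_len, total_seq_len)");
    return key_mask_mode_t::direct_3d_pad;
}

int main()
{
    const int64_t batch = 2, kv_len = 8, total_len = 8;
    // (batch)                    -> right_pad
    // (3 * batch + 2)            -> left_pad
    // (batch, kv_len)            -> direct_2d_pad
    // (batch, kv_len, total_len) -> direct_3d_pad
    std::cout << static_cast<int>(classify_key_padding_mask({batch}, batch, kv_len, total_len)) << '\n';
    std::cout << static_cast<int>(classify_key_padding_mask({3 * batch + 2}, batch, kv_len, total_len)) << '\n';
    std::cout << static_cast<int>(classify_key_padding_mask({batch, kv_len}, batch, kv_len, total_len)) << '\n';
    std::cout << static_cast<int>(classify_key_padding_mask({batch, kv_len, total_len}, batch, kv_len, total_len)) << '\n';
    return 0;
}

Compiled with a C++17 compiler, the sketch prints 2, 1, 0, 3, i.e. right_pad, left_pad, direct_2d_pad and direct_3d_pad for the four example shapes.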