Skip to content

Commit 721fb6e

Browse files
Add check key padding mask
1 parent 1f6d439 commit 721fb6e

File tree

1 file changed

+59
-3
lines changed

1 file changed

+59
-3
lines changed

src/onnx/parse_multi_head_attention.cpp

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,18 @@ enum class qkv_fomat_t
4242

4343
// Layout of the optional key_padding_mask input. The enumerator chosen is
// driven by the mask tensor's rank/shape (see check_key_padding_mask):
//   direct_2d_pad : 2D mask (batch, kv_sequence_length | total_sequence_length)
//   left_pad      : 1D mask (3 * batch + 2) -- packed left-side padding format
//   right_pad     : 1D mask (batch) -- per-batch count of valid (right-padded) tokens
//   direct_3d_pad : 3D mask (batch, kv_sequence_length, total_sequence_length)
enum class key_mask_mode_t
{
    direct_2d_pad = 0,
    left_pad      = 1,
    right_pad     = 2,
    direct_3d_pad = 3
};
4950

5051
struct multi_head_attention_parameters
5152
{
5253
int64_t batch_size;
5354
int64_t q_sequence_length;
5455
int64_t kv_sequence_length;
56+
int64_t total_sequence_length;
5557
int64_t hidden_size;
5658
int64_t hidden_size_v;
5759
int64_t head_size;
@@ -207,6 +209,59 @@ struct parse_multi_head_attention : op_parser<parse_multi_head_attention>
207209
}
208210
}
209211

212+
check_key_padding_mask(const std::vector<instruction_ref>& args,
213+
multi_head_attention_parameters& params) const
214+
{
215+
if(args.size() > 4)
216+
{
217+
auto key_pad_mask_shape = args[3]->get_shape();
218+
auto key_pad_lens = key_pad_mask_shape.lens();
219+
auto key_pad_len_size = key_pad_mask_les.size();
220+
221+
if(key_pad_len_size > 3 or key_pad_len_size < 1)
222+
MIGRAPHX_THROW("MultiHeadAttention: Key_pad_mask must be either 1D, 2D or 3D shape tensor");
223+
224+
if(key_pad_len_size == 1)
225+
{
226+
auto key_pad_shape = key_pad_lens.at(0);
227+
if(key_pad_size != params.batch_size and key_pad_shape != (3* params.batch_size + 2))
228+
MIGRAPHXTHROW("MultiHeadAttention: Key Padding Mask must be either batch or 3 x Batch + 2 for 1D key pads");
229+
230+
if(key_pad_size == params.batch_size)
231+
{
232+
params.key_pad_mode = right_pad;
233+
}
234+
else
235+
{
236+
params.key_pad_mode = left_pad;
237+
}
238+
}
239+
else if(key_pad_len_size == 2)
240+
{
241+
auto key_pad_batch = key_pad_lens.at(0);
242+
auto key_pad_total_seq_len = key_pad_lens.at(1);
243+
244+
if(key_pad_batch != params.batch_size or key_pad_seq_len != params.kv_sequence_length)
245+
{
246+
MIGRAPHX_THROW("MultiHeadAttention: 2D Keypad mask must have either (batch, kv_sequence_length) or (batch, total_sequence_length)")
247+
}
248+
diparams.key_pad_mode = direct_2d;
249+
}
250+
else // key_pad_len_size == 3 here
251+
{
252+
auto key_pad_batch = key_pad_lens.at(0);
253+
auto key_pad_seq_len = key_pad_lens.at(1);
254+
auto key_pad_total_seq_len = key_pad_lens.at(2);
255+
if(key_pad_batch != params.batch_size or key_pad_seq_len != params.kv_sequence_length or key_pad_total_seq_len != params.total_sequence_length)
256+
{
257+
MIGRAPHX_THROW("MultiHeadAttention: 2D Keypad mask must have either (batch, kv_sequence_length) or (batch, total_sequence_length)")
258+
}
259+
params.key_pad_mode = direct_3d_pad;
260+
}
261+
262+
}
263+
}
264+
210265
void check_bias(const std::vector<instruction_ref>& args,
211266
multi_head_attention_parameters& params) const
212267
{
@@ -237,6 +292,7 @@ struct parse_multi_head_attention : op_parser<parse_multi_head_attention>
237292
// This must be used first to extract hidden size, batch, etc
238293
check_query_dim(args, params);
239294
check_bias(args, params);
295+
check_key_padding_mask(args, params);
240296
}
241297

242298
std::tuple<instruction_ref, instruction_ref, instruction_ref>

0 commit comments

Comments
 (0)