@@ -1618,7 +1618,8 @@ def forward(
    return_intermediates = False,
    cache: Intermediates | None = None,
    value_residual = None,
-   additional_key_values: tuple[Tensor, Tensor] | None = None
+   additional_key_values: tuple[Tensor, Tensor] | None = None,
+   additional_key_value_mask = None,
):
    b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals, is_multi_latent_attn = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals, self.use_latent_kv

@@ -1791,15 +1792,22 @@ def forward(
    # maybe append additional key / values

    if exists(additional_key_values):
+       seq_len = k.shape[-2]

        added_k, added_v = additional_key_values
-       added_kv_len = added_k.shape[-2]

        k = cat((added_k, k), dim = -2)
        v = cat((added_v, v), dim = -2)

-       if exists(input_mask):
-           input_mask = pad_at_dim(input_mask, (added_kv_len, 0), dim = -1, value = True)
+       if (exists(input_mask) or exists(additional_key_value_mask)):
+
+           if not exists(additional_key_value_mask):
+               added_kv_len = added_k.shape[-2]
+               input_mask = pad_at_dim(input_mask, (added_kv_len, 0), dim = -1, value = True)
+           elif not exists(input_mask):
+               input_mask = pad_at_dim(additional_key_value_mask, (0, seq_len), dim = -1, value = True)
+           else:
+               input_mask = cat((additional_key_value_mask, input_mask), dim = -1)

    # determine masking

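The three branches above cover every combination of masks once extra key / values are prepended: pad the existing `input_mask` over the added positions when no `additional_key_value_mask` is given, pad the `additional_key_value_mask` over the original sequence when no `input_mask` is given, and concatenate the two when both are present. A minimal standalone sketch of that logic, assuming `(batch, seq)` boolean masks and simplified stand-ins for the library's `exists` / `pad_at_dim` / `cat` helpers:

```python
import torch
import torch.nn.functional as F

def exists(v):
    return v is not None

def pad_at_dim(t, pad, dim = -1, value = False):
    # pad is (left, right) along `dim`; simplified stand-in for the library helper
    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value = value)

def combine_masks(input_mask, additional_key_value_mask, added_kv_len, seq_len):
    # mirrors the three branches added in this commit
    if not exists(input_mask) and not exists(additional_key_value_mask):
        return None
    if not exists(additional_key_value_mask):
        # extend the sequence mask so all added kv positions stay attendable
        return pad_at_dim(input_mask, (added_kv_len, 0), dim = -1, value = True)
    if not exists(input_mask):
        # extend the added-kv mask so all original positions stay attendable
        return pad_at_dim(additional_key_value_mask, (0, seq_len), dim = -1, value = True)
    return torch.cat((additional_key_value_mask, input_mask), dim = -1)

# example: batch of 2, sequence length 4, 3 prepended key / value positions
input_mask = torch.ones(2, 4, dtype = torch.bool)
added_mask = torch.tensor([[True, True, False], [True, False, False]])

combined = combine_masks(input_mask, added_mask, added_kv_len = 3, seq_len = 4)
print(combined.shape)  # torch.Size([2, 7])
```

In all three cases the resulting mask length matches the concatenated `(added_k, k)` length, so the downstream masking code needs no changes.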
@@ -2426,6 +2434,7 @@ def forward(
    attn_bias = None,
    deep_embeds_and_ids: tuple[nn.Parameter, Tensor] | None = None,
    self_attn_additional_kv: list[tuple[Tensor, Tensor]] | None = None,
+   additional_kv_mask = None,
    condition = None,
    in_attn_cond = None, # https://arxiv.org/abs/2105.04090
    layers_execute_order: tuple[int, ...] | None = None
@@ -2666,7 +2675,7 @@ def forward(
    # forward depending on layer type

    if layer_type == 'a':
-       out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, additional_key_values = next(iter_self_attn_kv, None), prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
+       out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, additional_key_values = next(iter_self_attn_kv, None), additional_key_value_mask = additional_kv_mask, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
    elif layer_type == 'c':
        out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
    elif layer_type == 'f':
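With the mask threaded through `AttentionLayers.forward`, callers can now mask out individual positions of the prepended key / values per batch element. A hypothetical usage sketch, assuming a `Decoder` from x-transformers and per-layer key / value tensors of shape `(batch, heads, kv_len, dim_head)`; the shapes and `kv_len` here are illustrative assumptions, not taken from the diff (depending on where the library concatenates, the added tensors may instead need a pre-head-split layout):

```python
import torch
from x_transformers import Decoder

attn_layers = Decoder(dim = 512, depth = 2, heads = 8)

x = torch.randn(2, 1024, 512)

# one (k, v) pair per self-attention layer, prepended to that layer's keys / values
kv_len = 32
additional_kv = [
    (torch.randn(2, 8, kv_len, 64), torch.randn(2, 8, kv_len, 64))
    for _ in range(2)
]

# boolean mask over only the added key / value positions (True = attend)
additional_kv_mask = torch.ones(2, kv_len, dtype = torch.bool)
additional_kv_mask[1, kv_len // 2:] = False  # second sample ignores half of the added kv

out = attn_layers(
    x,
    self_attn_additional_kv = additional_kv,
    additional_kv_mask = additional_kv_mask
)
```

Per the diff, omitting `additional_kv_mask` leaves all added positions attendable, matching the previous behavior.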