
Commit 4dd085e

take care of variable length key / value sequences from vlm

1 parent e3cd421 commit 4dd085e

3 files changed: +18 -7 lines changed


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "x-transformers"
-version = "2.6.0"
+version = "2.6.1"
 description = "X-Transformers"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

tests/test_x_transformers.py

Lines changed: 3 additions & 1 deletion
@@ -1232,4 +1232,6 @@ def test_external_key_values():
         (torch.randn(3, 8, 32, 16), torch.randn(3, 8, 32, 16)),
     ]

-    logits = model(seq, self_attn_additional_kv = key_values)
+    additional_kv_mask = torch.randint(0, 2, (3, 32)).bool()
+
+    logits = model(seq, self_attn_additional_kv = key_values, additional_kv_mask = additional_kv_mask)

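For orientation, here is a sketch of how the updated test exercises the new keyword argument end to end. The model hyperparameters below are illustrative guesses, not the actual setup of test_external_key_values; only the tensor shapes and the two keyword arguments are taken from the diff above.

import torch
from x_transformers import TransformerWrapper, Decoder

# hypothetical config for illustration; dims chosen so heads * dim_head matches the kv shapes below
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 64,
    attn_layers = Decoder(
        dim = 128,
        depth = 2,          # one (k, v) tuple per self-attention layer
        heads = 8,
        attn_dim_head = 16
    )
)

seq = torch.randint(0, 256, (3, 64))

# per-layer key / values from elsewhere (e.g. a vision-language model),
# shaped (batch, heads, added_kv_len, dim_head)
key_values = [
    (torch.randn(3, 8, 32, 16), torch.randn(3, 8, 32, 16)),
    (torch.randn(3, 8, 32, 16), torch.randn(3, 8, 32, 16)),
]

# boolean mask over the 32 prepended positions, (batch, added_kv_len),
# so variable length external sequences can be padded and masked out
additional_kv_mask = torch.randint(0, 2, (3, 32)).bool()

logits = model(seq, self_attn_additional_kv = key_values, additional_kv_mask = additional_kv_mask)
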
x_transformers/x_transformers.py

Lines changed: 14 additions & 5 deletions
@@ -1618,7 +1618,8 @@ def forward(
         return_intermediates = False,
         cache: Intermediates | None = None,
         value_residual = None,
-        additional_key_values: tuple[Tensor, Tensor] | None = None
+        additional_key_values: tuple[Tensor, Tensor] | None = None,
+        additional_key_value_mask = None,
     ):
         b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals, is_multi_latent_attn = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals, self.use_latent_kv

@@ -1791,15 +1792,22 @@ def forward(
         # maybe append additional key / values

         if exists(additional_key_values):
+            seq_len = k.shape[-2]

             added_k, added_v = additional_key_values
-            added_kv_len = added_k.shape[-2]

             k = cat((added_k, k), dim = -2)
             v = cat((added_v, v), dim = -2)

-            if exists(input_mask):
-                input_mask = pad_at_dim(input_mask, (added_kv_len, 0), dim = -1, value = True)
+            if (exists(input_mask) or exists(additional_key_value_mask)):
+
+                if not exists(additional_key_value_mask):
+                    added_kv_len = added_k.shape[-2]
+                    input_mask = pad_at_dim(input_mask, (added_kv_len, 0), dim = -1, value = True)
+                elif not exists(input_mask):
+                    input_mask = pad_at_dim(additional_key_value_mask, (0, seq_len), dim = -1, value = True)
+                else:
+                    input_mask = cat((additional_key_value_mask, input_mask), dim = -1)

         # determine masking

@@ -2426,6 +2434,7 @@ def forward(
         attn_bias = None,
         deep_embeds_and_ids: tuple[nn.Parameter, Tensor] | None = None,
         self_attn_additional_kv: list[tuple[Tensor, Tensor]] | None = None,
+        additional_kv_mask = None,
         condition = None,
         in_attn_cond = None, # https://arxiv.org/abs/2105.04090
         layers_execute_order: tuple[int, ...] | None = None

@@ -2666,7 +2675,7 @@
             # forward depending on layer type

             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, additional_key_values = next(iter_self_attn_kv, None), prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, pos = pos, rotary_pos_emb = rotary_pos_emb, additional_key_values = next(iter_self_attn_kv, None), additional_key_value_mask = additional_kv_mask, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
                 out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, **cross_attn_rotary_pos_emb, return_intermediates = True)
             elif layer_type == 'f':

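The new branching reduces to three cases: only the regular input mask exists, only the mask for the externally supplied key / values exists, or both do. Below is a standalone sketch of that combination logic, with exists and pad_at_dim re-implemented minimally here rather than imported from the library.

import torch
import torch.nn.functional as F

def exists(v):
    return v is not None

def pad_at_dim(t, pad, value = False):
    # minimal stand-in: pads only the last dimension by (left, right) amounts
    return F.pad(t, pad, value = value)

def combine_masks(input_mask, additional_key_value_mask, added_kv_len, seq_len):
    # mirrors the three cases in the hunk above:
    # 1. only the regular input mask exists  -> prepend all-True over the added kv positions
    # 2. only the additional kv mask exists  -> append all-True over the original sequence
    # 3. both exist                          -> concatenate, added kv mask first
    if not (exists(input_mask) or exists(additional_key_value_mask)):
        return None
    if not exists(additional_key_value_mask):
        return pad_at_dim(input_mask, (added_kv_len, 0), value = True)
    if not exists(input_mask):
        return pad_at_dim(additional_key_value_mask, (0, seq_len), value = True)
    return torch.cat((additional_key_value_mask, input_mask), dim = -1)

# e.g. batch of 3, 32 external kv positions prepended to a 64-token sequence
kv_mask  = torch.randint(0, 2, (3, 32)).bool()
seq_mask = torch.ones(3, 64).bool()
assert combine_masks(seq_mask, kv_mask, 32, 64).shape == (3, 96)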