In our paper we only showed results on causal language models, which use causally masked (decoder) self-attention.
If you'd like to use ALiBi for seq2seq models such as translation models, speech models, or T5, or for masked language models such as BERT, some modifications are required.
Encoder-Attention
Encoder-Attention is the non-masked self-attention performed in the encoder of seq2seq models such as translation models or T5. It is also the kind of attention used in masked language models such as BERT.
We can't naively copy-paste the ALiBi code into these models because it won't work. We use a trick to quickly calculate the bias matrix for causal language modeling, but that bias matrix is only correct on and below the main diagonal (since that's all that's used in causal language modeling).
```python
import torch

maxpos = args.tokens_per_sample
attn_heads = args.encoder_attention_heads

# Absolute distance between every query position and every key position.
context_position = torch.arange(maxpos)[:, None].cuda()
memory_position = torch.arange(maxpos)[None, :].cuda()
relative_position = memory_position - context_position
relative_position = torch.abs(relative_position).unsqueeze(0).expand(attn_heads, -1, -1)
```
This code correctly generates the full bias matrix. Note that the bias matrix is symmetric around the diagonal, since it computes the absolute distance between the query and key (so all distances are positive).
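As a quick sanity check (a toy example, not part of the original snippet), here is the distance matrix the code above produces for maxpos = 4, before it is expanded across heads and multiplied by the slopes:

```python
import torch

context_position = torch.arange(4)[:, None]
memory_position = torch.arange(4)[None, :]
print(torch.abs(memory_position - context_position))
# tensor([[0, 1, 2, 3],
#         [1, 0, 1, 2],
#         [2, 1, 0, 1],
#         [3, 2, 1, 0]])
```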
We're also going to need the code for generating the ALiBi slopes:
```python
import math

def get_slopes(n):
    def get_slopes_power_of_2(n):
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        ratio = start
        return [start * ratio ** i for i in range(n)]

    # In the paper, we only train models that have 2^a heads for some a. This function has
    # some good properties that only occur when the input is a power of 2. To maintain that
    # even when the number of heads is not a power of 2, we use this workaround.
    if math.log2(n).is_integer():
        return get_slopes_power_of_2(n)
    else:
        closest_power_of_2 = 2 ** math.floor(math.log2(n))
        return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2 * closest_power_of_2)[0::2][:n - closest_power_of_2]
```
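For example, for an 8-head model get_slopes returns the geometric sequence used in the paper, with each head's slope halving relative to the previous one:

```python
print(get_slopes(8))
# [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
```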
There are 3 options for implementing encoder-attention ALiBi:
- Symmetric: In this option, the bias we assign to query/key pairs that are +N or -N tokens apart will be the same.
```python
# Scale the symmetric distance matrix by the (negative) per-head slopes.
self.slopes = torch.Tensor(get_slopes(attn_heads)).cuda() * -1
self.alibi = self.slopes.unsqueeze(1).unsqueeze(1) * relative_position
self.alibi = self.alibi.view(1, attn_heads, maxpos, maxpos)
```
Now just pass self.alibi to the attention function and add it after the query*key computation. In fairseq, for example, the query*key computation is done as `attn_weights = torch.bmm(q, k.transpose(1, 2))`; to add the ALiBi values afterwards, use the snippet below (a fuller, self-contained sketch of this option also appears right after this list):
```python
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights += alibi[:, :, :tgt_len, :src_len].to(attn_weights)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
```
- Nonsymmetric: Here we make the model nonsymmetric by using the same ALiBi bias as in the symmetric option, but this time we let the first half of the heads look only left and the second half look only right. We do this by adding a mask to the attention.
Note: This code hasn't been fully tested yet and might contain bugs.
```python
# Masks: the first half of the heads cannot attend to the right (those positions
# are set to -inf), the second half cannot attend to the left.
self._future_mask_right = torch.triu(utils.fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1).unsqueeze(0).repeat(attn_heads // 2, 1, 1)
self._future_mask_left = torch.tril(utils.fill_with_neg_inf(torch.zeros([maxpos, maxpos])), -1).unsqueeze(0).repeat(attn_heads // 2, 1, 1)
self.nonsym_mask = torch.cat((self._future_mask_right, self._future_mask_left), dim=0).unsqueeze(0).cuda()

# Same symmetric bias as before, but computed for attn_heads // 2 slopes and then
# repeated so that each left-looking head shares a slope with a right-looking head.
self.slopes = torch.Tensor(get_slopes(attn_heads // 2)).cuda() * -1
context_position = torch.arange(maxpos)[:, None].cuda()
memory_position = torch.arange(maxpos)[None, :].cuda()
relative_position = memory_position - context_position
relative_position = torch.abs(relative_position).unsqueeze(0).expand(attn_heads // 2, -1, -1)
self.alibi = self.slopes.unsqueeze(1).unsqueeze(1) * relative_position
self.alibi = self.alibi.view(1, attn_heads // 2, maxpos, maxpos)
self.alibi = self.alibi.repeat(1, 2, 1, 1).cuda()
```
Again, as before, add self.alibi to the attention weights, but this time also add the nonsym_mask tensor (in fairseq: `attn_weights += nonsym_mask[:, :, :tgt_len, :src_len].to(attn_weights)`).
- Nonsymmetric with no mask: In this approach, we don't use any masking, but instead we make the positioning non-symmetric by using different ALiBi slopes depending on whether the key is to the left or right of the query. Here, we use learned slopes but you can also do this with non-learned slopes.
Note: I haven't tested this code so it might contain bugs!
```python
# Raw (pre-sigmoid) learned slope parameters for the left and right directions.
slopes_left = torch.nn.parameter.Parameter(torch.Tensor(attn_heads))
nn.init.normal_(slopes_left, -2, 1)
slopes_right = torch.nn.parameter.Parameter(torch.Tensor(attn_heads))
nn.init.normal_(slopes_right, -2, 1)

# Map each slope into (-1, 0) so larger distances always get a more negative bias.
slopes_left = -torch.sigmoid(slopes_left)
slopes_right = -torch.sigmoid(slopes_right)

context_position = torch.arange(maxpos)[:, None]
memory_position = torch.arange(maxpos)[None, :]
relative_position = memory_position - context_position
relative_position = torch.abs(relative_position).unsqueeze(0).expand(attn_heads, -1, -1)

# Keys to the right of the query (upper triangle) use the right slopes,
# keys to the left (lower triangle) use the left slopes.
alibi_left = slopes_left.unsqueeze(1).unsqueeze(1) * relative_position
alibi_right = slopes_right.unsqueeze(1).unsqueeze(1) * relative_position
self.alibi = torch.triu(alibi_right) + torch.tril(alibi_left)
```
- Check out the variation on option 3 from the LittleBird paper.
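For reference, here is a minimal, framework-free sketch of the symmetric option in plain PyTorch. The helper names `build_symmetric_alibi` and `encoder_self_attention`, and the toy shapes, are mine and not from this repo; this is only meant to show where the bias enters a standard multi-head attention computation, not to be a drop-in implementation.

```python
import math
import torch
import torch.nn.functional as F


def get_slopes(n):
    # Same slope function as above, repeated here so the sketch is self-contained.
    def get_slopes_power_of_2(n):
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        ratio = start
        return [start * ratio ** i for i in range(n)]

    if math.log2(n).is_integer():
        return get_slopes_power_of_2(n)
    closest_power_of_2 = 2 ** math.floor(math.log2(n))
    return (get_slopes_power_of_2(closest_power_of_2)
            + get_slopes(2 * closest_power_of_2)[0::2][:n - closest_power_of_2])


def build_symmetric_alibi(attn_heads, maxpos):
    # (1, heads, maxpos, maxpos) bias: -slope * |i - j| for every query/key pair.
    context_position = torch.arange(maxpos)[:, None]
    memory_position = torch.arange(maxpos)[None, :]
    relative_position = torch.abs(memory_position - context_position)
    relative_position = relative_position.unsqueeze(0).expand(attn_heads, -1, -1)
    slopes = torch.tensor(get_slopes(attn_heads)) * -1
    alibi = slopes.unsqueeze(1).unsqueeze(1) * relative_position
    return alibi.view(1, attn_heads, maxpos, maxpos)


def encoder_self_attention(q, k, v, alibi):
    # q, k, v: (batch, heads, seq_len, head_dim); alibi: (1, heads, maxpos, maxpos).
    seq_len, head_dim = q.shape[-2], q.shape[-1]
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
    scores = scores + alibi[:, :, :seq_len, :seq_len].to(scores)  # add the bias before the softmax
    return torch.matmul(F.softmax(scores, dim=-1), v)


# Toy usage
heads, maxpos, head_dim = 8, 512, 64
alibi = build_symmetric_alibi(heads, maxpos)
q = torch.randn(2, heads, 128, head_dim)
k = torch.randn(2, heads, 128, head_dim)
v = torch.randn(2, heads, 128, head_dim)
out = encoder_self_attention(q, k, v, alibi)  # (2, 8, 128, 64)
```

The only difference from the fairseq snippet above is that this sketch keeps the (batch, heads, tgt_len, src_len) layout throughout instead of reshaping to (batch * heads, tgt_len, src_len).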
Cross-Attention
For translation models and models like T5, you will also need to implement cross-attention, which is the attention from the decoder to the encoder. The T5 model uses no positional information in cross-attention, and I would recommend doing the same thing.
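Concretely (a minimal sketch with illustrative names, not code from this repo), that just means computing the decoder-to-encoder attention without adding any ALiBi term to the scores:

```python
import math
import torch
import torch.nn.functional as F

def cross_attention(q, k, v):
    # q: (batch, heads, tgt_len, head_dim) comes from the decoder;
    # k, v: (batch, heads, src_len, head_dim) come from the encoder.
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    # No ALiBi bias (and no other positional term) is added here.
    return torch.matmul(F.softmax(scores, dim=-1), v)
```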
Implementations
NEW: lucidrains has implemented some of the above ideas in the x-transformers repo; see lucidrains/x-transformers#88.