@@ -224,6 +224,47 @@ def __init__(
224224 mask = torch .empty (n_ctx , n_ctx ).fill_ (- np .inf ).triu_ (1 )
225225 self .register_buffer ("mask" , mask , persistent = False )
226226
227+ # Optimisation: pre-compute and register the mask in CUDA if available
228+ if torch .cuda .is_available ():
229+ self .register_buffer ("mask_cuda" , mask .cuda (), persistent = False )
230+
231+
def forward(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
    """
    Decode vocabulary logits from text tokens and encoded audio.

    Args:
        tokens: (n_batch, n_token) integer token ids
        audio_features: (n_batch, n_audio_ctx, n_audio_state) encoder output

    Returns:
        logits: (n_batch, n_token, n_vocab)
    """
    n_token = tokens.shape[-1]

    # Token embedding plus the learned positional embedding, truncated to
    # the current sequence length.
    x = self.token_embedding(tokens) + self.positional_embedding[:n_token]

    # Keep the audio features on the same device as the computation. The
    # previous version called .cuda() whenever torch.cuda.is_available(),
    # which says nothing about where *this* module's parameters live and
    # broke CPU inference on CUDA-capable machines (it also required a
    # redundant duplicate "mask_cuda" buffer).
    audio_features = audio_features.to(x.device)

    # Cross-attend to the audio features through the decoder blocks.
    # NOTE(review): the causal mask registered in __init__ is presumably
    # consumed inside the blocks' attention layers — confirm against the
    # block implementation. The previous version instead added
    # mask[:n_token, :n_token] to the output logits, which is a shape
    # error whenever n_vocab != n_token (an (n_token, n_token) tensor
    # cannot broadcast against (n_batch, n_token, n_vocab)), and masking
    # vocabulary logits with a causal attention mask is semantically
    # wrong in any case.
    for block in self.blocks:
        x = block(x, audio_features)

    x = self.ln(x)

    # Tied output projection: reuse the token-embedding matrix as the
    # output weight.
    logits = x @ self.token_embedding.weight.T
    return logits
266+
267+
227268 def forward (self , x : Tensor , xa : Tensor , kv_cache : Optional [dict ] = None ):
228269 """
229270 x : torch.LongTensor, shape = (batch_size, <= n_ctx)
@@ -342,4 +383,4 @@ def install_hooks(layer: nn.Module):
# Bind the functional implementations (defined elsewhere in this module)
# as attributes, so callers can invoke them as methods of the model:
# model.detect_language(...), model.transcribe(...), model.decode(...).
detect_language = detect_language_function
transcribe = transcribe_function
decode = decode_function
0 commit comments