keras==3.0.5 - MultiHeadAttention(... return_attention_scores=True) does not work #19303

@lbortolotti

With keras==3.0.5, the following snippet:

import keras
from keras import layers

mha = layers.MultiHeadAttention(num_heads=2, key_dim=2)

x = keras.Input(shape=(3, 5))
y = keras.Input(shape=(3, 5))

out_attn, out_attn_scores = mha(x, y, return_attention_scores=True)

Throws:

Traceback (most recent call last):
  File "mha_repro.py", line 10, in <module>
    out_attn, out_attn_scores = mha(x, y, return_attention_scores=True)
  File "venv\lib\site-packages\keras\src\backend\common\keras_tensor.py", line 121, in __iter__
    raise NotImplementedError(
NotImplementedError: Iterating over a symbolic KerasTensor is not supported.

This appears to happen because return_attention_scores has no effect on the symbolic call: only a single KerasTensor is returned, so the tuple unpacking tries to iterate over it and fails. The mha.call() method itself looks right; it is __call__ that does not seem to construct the output tuple correctly.
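
To illustrate why the traceback ends in `__iter__`, here is a minimal plain-Python sketch of the mechanism. It does not use Keras: `FakeSymbolicTensor`, `buggy_call`, `fixed_call` and `try_unpack` are hypothetical names made up for this example, and the assumption that the flag is simply ignored is my reading of the behaviour, not something confirmed from the Keras source.

```python
class FakeSymbolicTensor:
    """Hypothetical stand-in for a symbolic tensor; not the Keras class."""

    def __init__(self, shape):
        self.shape = shape

    def __iter__(self):
        # Same message as in the traceback above.
        raise NotImplementedError(
            "Iterating over a symbolic KerasTensor is not supported."
        )


def buggy_call(return_attention_scores=False):
    # Assumed behaviour: the flag is ignored and a single tensor comes back.
    return FakeSymbolicTensor(shape=(None, 3, 5))


def fixed_call(return_attention_scores=False):
    # What the caller expects: a 2-tuple when the flag is set.
    output = FakeSymbolicTensor(shape=(None, 3, 5))
    if return_attention_scores:
        # (batch, num_heads, query_len, key_len) for num_heads=2, length 3.
        scores = FakeSymbolicTensor(shape=(None, 2, 3, 3))
        return output, scores
    return output


def try_unpack(fn):
    # `a, b = obj` calls iter(obj), which is where the error is raised.
    try:
        out, scores = fn(return_attention_scores=True)
    except NotImplementedError as exc:
        return f"failed: {exc}"
    return f"ok: {out.shape}, {scores.shape}"


print(try_unpack(buggy_call))
print(try_unpack(fixed_call))
```

If this reading is right, the symbolic code path would need to return two outputs whenever return_attention_scores=True, as `fixed_call` does in the sketch.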
