tensorrt_llm/_torch/attention_backend/interface.py (2 changes: 1 addition & 1 deletion)
@@ -393,7 +393,7 @@ def create_rope_const_params(self, interleave: bool = True):

         if self.scale_type == RotaryScalingType.yarn:
             rope_inv_freq = None
-            rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_yarn(
+            _, rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_yarn(
                 self.max_positions,
                 self.dim,
                 self.theta,
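
Note on the interface.py change: `create_sinusoidal_positions_yarn` now returns a `(inv_freq, cos_sin)` pair instead of only the fused cos/sin table, so this call site unpacks the tuple and discards the inverse frequencies it does not need. A minimal sketch of the two call patterns, with illustrative argument values rather than a real model configuration (the positional order follows the call added in tensorrt_llm/layers/attention.py below):

```python
from tensorrt_llm.functional import RopeEmbeddingUtils

# Illustrative YaRN settings, not a real model config.
inv_freq, cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_yarn(
    4096,     # number of positions
    64,       # rotary dim (e.g. the MLA qk_rope_head_dim)
    10000.0,  # base / theta
    4.0,      # scaling factor
    4096,     # original_max_position_embeddings
    32.0,     # beta_fast
    1.0,      # beta_slow
    1.0,      # mscale
    1.0,      # mscale_all_dim
    False,    # duplicate_data, as passed by the new YaRN branch in attention.py
)

# The torch attention backend only needs the fused table and drops inv_freq:
_, rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_yarn(
    4096, 64, 10000.0, 4.0, 4096, 32.0, 1.0, 1.0, 1.0, True)
```
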
tensorrt_llm/functional.py (19 changes: 11 additions & 8 deletions)
@@ -4772,6 +4772,7 @@ def create_sinusoidal_positions_yarn(
             beta_slow: int = 1,
             mscale: float = 1.0,
             mscale_all_dim: float = 1.0,
+            duplicate_data: bool = True,
             dtype=np.float32):

         # Copy from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/main/modeling_deepseek.py
@@ -4829,23 +4830,25 @@ def yarn_linear_ramp_mask(min, max, dim):
         inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high,
                                                     dim // 2).astype(dtype)
         inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
-        t = np.arange(num_pos, dtype=dtype)
-
-        freqs = np.outer(t, inv_freq)
+        sinusoid_inp = np.expand_dims(np.einsum("i , j -> i j",
+                                                np.arange(num_pos, dtype=dtype),
+                                                inv_freq,
+                                                dtype=dtype),
+                                      axis=-1)

         _mscale = float(
             yarn_get_mscale(scaling_factor, mscale) /
             yarn_get_mscale(scaling_factor, mscale_all_dim))

-        emb = np.concatenate((freqs, freqs), axis=-1)
+        if duplicate_data:
+            emb = np.concatenate((sinusoid_inp, sinusoid_inp), axis=-2)
+        else:
+            emb = sinusoid_inp

         concat = np.concatenate((np.cos(emb) * _mscale, np.sin(emb) * _mscale),
                                 axis=-1)

-        concat = concat.reshape((num_pos, 2, dim))
-        concat = np.transpose(concat, (0, 2, 1))
-
-        return concat.reshape((1, -1)).astype(dtype)
+        return inv_freq, concat.reshape((1, -1)).astype(dtype)

     @staticmethod
     def rotate_every_two(tensor: Tensor) -> Tensor:
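
The functional.py rewrite changes how the cos/sin table is laid out: the old path built a `(num_pos, 2 * dim)` block of cos values followed by sin values and then needed the `reshape((num_pos, 2, dim))` / `transpose` pass to interleave them, while the new `einsum` + `expand_dims` path keeps a trailing axis so each frequency's cos and sin land next to each other directly, and `duplicate_data=False` skips repeating the angles for the second rotated half (the YaRN branch added in layers/attention.py below passes `False`). A standalone NumPy sketch of just this layout logic, with toy sizes and the mscale factor omitted; it is not the library function itself:

```python
import numpy as np

num_pos, dim = 4, 8                         # toy sizes, not a real config
inv_freq = 1.0 / (10000.0**(np.arange(0, dim, 2, dtype=np.float32) / dim))

# (num_pos, dim // 2, 1): rotation angle per position and per frequency.
sinusoid_inp = np.expand_dims(
    np.einsum("i,j->ij", np.arange(num_pos, dtype=np.float32), inv_freq),
    axis=-1)

for duplicate_data in (True, False):
    if duplicate_data:
        # (num_pos, dim, 1): each angle repeated for both rotated halves.
        emb = np.concatenate((sinusoid_inp, sinusoid_inp), axis=-2)
    else:
        # (num_pos, dim // 2, 1): a single copy of each frequency's angle.
        emb = sinusoid_inp

    # Trailing axis of size 2 -> cos and sin are already interleaved per
    # frequency, so no reshape/transpose step is needed before flattening.
    concat = np.concatenate((np.cos(emb), np.sin(emb)), axis=-1)
    print(duplicate_data, concat.shape, concat.reshape((1, -1)).shape)
```
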
tensorrt_llm/layers/attention.py (30 changes: 29 additions & 1 deletion)
@@ -672,6 +672,34 @@ def create_attention_const_params(model_cls, config):
                           is_buffer=True))
             model_cls.short_mscale = short_mscale
             model_cls.long_mscale = long_mscale
+        elif rotary_embedding_scale_type == RotaryScalingType.yarn:
+            beta_fast = rotary_embedding_scaling.get("beta_fast", 32.0)
+            beta_slow = rotary_embedding_scaling.get("beta_slow", 1.0)
+            mscale = rotary_embedding_scaling.get("mscale", 1.0)
+            mscale_all_dim = rotary_embedding_scaling.get("mscale_all_dim", 0.0)
+            original_max_position_embeddings = rotary_embedding_scaling.get(
+                "original_max_position_embeddings", 4096)
+            rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_yarn(
+                max_position_embeddings, rotary_embedding_dim,
+                rotary_embedding_base, rotary_embedding_scale,
+                original_max_position_embeddings, beta_fast, beta_slow, mscale,
+                mscale_all_dim, False)
+
+            embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions(
+                max_position_embeddings,
+                rotary_embedding_dim,
+            )
+            model_cls.register_parameter(
+                'embed_positions',
+                Parameter(embed_positions, dtype='float32', is_buffer=True))
+            model_cls.register_parameter(
+                'rotary_inv_freq',
+                Parameter(rotary_inv_freq, dtype='float32', is_buffer=True))
+            model_cls.register_parameter(
+                'embed_positions_for_gpt_attention',
+                Parameter(embed_positions_for_gpt_attention,
+                          dtype='float32',
+                          is_buffer=True))
         else:

             def register_rope_params(rotary_base, names_to_register):
@@ -2048,7 +2076,7 @@ def yarn_get_mscale(scale=1, mscale=1):
             mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
             self.q_scaling = 1.0 / (mscale * mscale)

-            embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_yarn(
+            _, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_yarn(
                 self.max_position_embeddings, self.qk_rope_head_dim,
                 self.rotary_embedding_base, self.rotary_scaling["factor"],
                 rotary_embedding_origin_max_position, rotary_embedding_beta_fast,
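
Both attention.py hunks touch YaRN's magnitude correction: the new branch reads `mscale` / `mscale_all_dim` out of the rope-scaling config, and the MLA path's context shows them folded into `q_scaling = 1.0 / (mscale * mscale)` via the `yarn_get_mscale` helper named in the hunk header. A small sketch of that computation, using the standard `yarn_get_mscale` formula from the DeepSeek-V2 modeling code that functional.py cites (the factor of 40 is a hypothetical scaling setting, not taken from this PR):

```python
import math


def yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    # YaRN magnitude-scaling term, as in the DeepSeek-V2 reference code that
    # create_sinusoidal_positions_yarn is copied from.
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


scaling_factor, mscale_all_dim = 40.0, 1.0    # hypothetical YaRN settings
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
q_scaling = 1.0 / (mscale * mscale)           # mirrors self.q_scaling above
print(round(mscale, 3), round(q_scaling, 3))  # 1.369 0.534

# The cos/sin table itself is scaled by the ratio of the two mscale terms
# (see _mscale in functional.py), so it collapses to 1.0 whenever
# mscale == mscale_all_dim:
table_scale = yarn_get_mscale(scaling_factor, 1.0) / yarn_get_mscale(
    scaling_factor, mscale_all_dim)
print(table_scale)  # 1.0
```
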
tests/unittest/_torch/test_attention_mla.py (2 changes: 1 addition & 1 deletion)
@@ -492,7 +492,7 @@ def _run_test_for_backend(backend_name, num_heads, num_kv_heads, num_layers,
             rope_config.rope_scaling['beta_slow'],
             rope_config.rope_scaling['mscale'],
             rope_config.rope_scaling['mscale_all_dim'],
-        ),
+        )[1],
         dtype=torch.float32,
         device=device,
     ).reshape(rope_config.max_position_embeddings, -1, 2).transpose(-2, -1)
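
The test change mirrors the new return type: indexing `[1]` picks the flat cos/sin buffer out of the `(inv_freq, cos_sin)` tuple before the existing reshape/transpose rebuilds the per-position layout. A quick shape sketch with illustrative sizes (the real test uses `rope_config.max_position_embeddings` and the MLA rope head dim):

```python
import torch

num_pos, rope_dim = 8, 64                     # illustrative sizes only

# Stand-in for create_sinusoidal_positions_yarn(...)[1], which is a flat
# (1, num_pos * rope_dim * 2) float32 buffer when duplicate_data stays True.
flat = torch.zeros(1, num_pos * rope_dim * 2)

# Rebuild per-position [cos; sin] rows exactly as the test does.
cos_sin = flat.reshape(num_pos, -1, 2).transpose(-2, -1)
print(cos_sin.shape)                          # torch.Size([8, 2, 64])
```
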