Conversation

Collaborator

@laxmareddyp commented Sep 5, 2025

@divyashreepathihalli @mattdangerw @abheesht17 Could you please review this and share your feedback on the quality of the code, which was generated through a script.

I estimate that 80-85% of the code matches the expected implementation; the backbone files import successfully, and it is possible to instantiate a backbone model. There were still some errors, which might be alleviated with a stronger model.

The converter and weight conversion scripts are still in development. Generating a workable solution is complex because it requires giving the model a comprehensive understanding of the entire architectural layout so that it can handle the intricate dependencies among the model's layers and weights.

Checklist

  • I have added all the necessary unit tests for my change.
  • I have verified that my change does not break existing code and works with all backends (TensorFlow, JAX, and PyTorch).
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have followed the Keras Hub Model contribution guidelines in making these changes.
  • I have followed the Keras Hub API design guidelines in making these changes.
  • I have signed the Contributor License Agreement.

@divyashreepathihalli
Collaborator

/gemini review

@gemini-code-assist bot left a comment


Code Review

This pull request introduces a new GPT-OSS model, seemingly generated by a script. While the overall structure aligns with the repository's conventions, there are several critical issues that need to be addressed. The weight conversion scripts in both utils/transformers and tools/checkpoint_conversion are non-functional due to incorrect layer and weight names, likely a result of being copied from another model's implementation without proper adaptation. Additionally, there's a critical code duplication issue with the GptOssCausalLMPreprocessor class being defined in two separate files. Several style guide violations were also found, including the use of type hints in function signatures, incorrect docstring formatting, and missing docstrings or examples in several new components. These issues must be resolved to ensure the model is correct, maintainable, and usable.

Comment on lines 12 to 46
@keras_hub_export("keras_hub.models.GptOssCausalLMPreprocessor")
class GptOssCausalLMPreprocessor(CausalLMPreprocessor):
    """GPT-OSS Causal LM preprocessor.

    This class is responsible for preprocessing the inputs for the GPT-OSS
    Causal LM model. It tokenizes the input text and creates the attention
    mask.

    Args:
        tokenizer: A `keras_hub.models.GptOssTokenizer` instance.
        sequence_length: The maximum sequence length.
        add_start_token: Whether to add a start token to the input.
        add_end_token: Whether to add an end token to the input.
    """

    def __init__(
        self,
        tokenizer: GptOssTokenizer,
        sequence_length: int,
        add_start_token: bool = True,
        add_end_token: bool = False,
        **kwargs,
    ):
        super().__init__(
            tokenizer=tokenizer,
            sequence_length=sequence_length,
            add_start_token=add_start_token,
            add_end_token=add_end_token,
            **kwargs,
        )

    def get_config(self):
        config = super().get_config()
        return config


critical

This GptOssCausalLMPreprocessor class is also defined in gpt_oss_causal_lm_preprocessor.py. To avoid code duplication and adhere to the project structure where each component has its own file, please remove this class definition and instead import it from the other file.1

You should add the following import at the top of the file:

from keras_hub.src.models.gpt_oss.gpt_oss_causal_lm_preprocessor import GptOssCausalLMPreprocessor

Style Guide References

Footnotes

  1. A preprocessor should be in its own <model_name>_preprocessor.py file.
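
To make the fix concrete, the causal LM module would then reference the single shared definition instead of redefining it; a minimal sketch, assuming the task class follows the usual KerasHub `backbone_cls`/`preprocessor_cls` wiring (not taken verbatim from this PR):

from keras_hub.src.models.causal_lm import CausalLM
from keras_hub.src.models.gpt_oss.gpt_oss_backbone import GptOssBackbone
from keras_hub.src.models.gpt_oss.gpt_oss_causal_lm_preprocessor import (
    GptOssCausalLMPreprocessor,
)


class GptOssCausalLM(CausalLM):
    # Reuse the imported preprocessor rather than duplicating its definition here.
    backbone_cls = GptOssBackbone
    preprocessor_cls = GptOssCausalLMPreprocessor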

Comment on lines 32 to 175
loader.port_weight(
    keras_variable=decoder_layer._self_attention_layer.key_dense.kernel,
    hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight",
    hook_fn=transpose_and_reshape,
)
if backbone.use_bias:
    loader.port_weight(
        keras_variable=decoder_layer._self_attention_layer.key_dense.bias,
        hf_weight_key=f"model.layers.{i}.self_attn.k_proj.bias",
    )
## Value
loader.port_weight(
    keras_variable=decoder_layer._self_attention_layer.value_dense.kernel,
    hf_weight_key=f"model.layers.{i}.self_attn.v_proj.weight",
    hook_fn=transpose_and_reshape,
)
if backbone.use_bias:
    loader.port_weight(
        keras_variable=decoder_layer._self_attention_layer.value_dense.bias,
        hf_weight_key=f"model.layers.{i}.self_attn.v_proj.bias",
    )
## Output
loader.port_weight(
    keras_variable=decoder_layer._self_attention_layer.output_dense.kernel,
    hf_weight_key=f"model.layers.{i}.self_attn.o_proj.weight",
    hook_fn=transpose_and_reshape,
)
if backbone.use_bias:
    loader.port_weight(
        keras_variable=decoder_layer._self_attention_layer.output_dense.bias,
        hf_weight_key=f"model.layers.{i}.self_attn.o_proj.bias",
    )
## Sinks (unique to GptOssAttention)
loader.port_weight(
    keras_variable=decoder_layer._self_attention_layer.sinks,
    hf_weight_key=f"model.layers.{i}.self_attn.sinks",
)

# MoE layers (GptOssMLP)
# Router gate (GptOssTopKRouter)
loader.port_weight(
    keras_variable=decoder_layer._sparse_moe_block._sparse_feedforward_gate_dense.kernel,
    hf_weight_key=f"model.layers.{i}.mlp.router.weight",
    hook_fn=transpose_and_reshape,
)
loader.port_weight(
    keras_variable=decoder_layer._sparse_moe_block._sparse_feedforward_gate_dense.bias,
    hf_weight_key=f"model.layers.{i}.mlp.router.bias",
)

hf_gate_up_proj = loader.get_tensor(
    f"model.layers.{i}.mlp.experts.gate_up_proj"
)
hf_gate_up_proj_bias = loader.get_tensor(
    f"model.layers.{i}.mlp.experts.gate_up_proj_bias"
)
hf_down_proj = loader.get_tensor(
    f"model.layers.{i}.mlp.experts.down_proj"
)
hf_down_proj_bias = loader.get_tensor(
    f"model.layers.{i}.mlp.experts.down_proj_bias"
)

gate_kernels = hf_gate_up_proj[:, :, ::2]
intermediate_kernels = hf_gate_up_proj[:, :, 1::2]
output_kernels = hf_down_proj

gate_biases = hf_gate_up_proj_bias[:, ::2]
intermediate_biases = hf_gate_up_proj_bias[:, 1::2]
output_biases = hf_down_proj_bias

# Assign batched weights to expert_bank variables
expert_bank = decoder_layer._sparse_moe_block.expert_bank

expert_bank._expert_feedforward_gate_kernel.assign(gate_kernels)
expert_bank._expert_feedforward_gate_bias.assign(gate_biases)

expert_bank._expert_feedforward_intermediate_kernel.assign(
    intermediate_kernels
)
expert_bank._expert_feedforward_intermediate_bias.assign(
    intermediate_biases
)

expert_bank._expert_feedforward_output_kernel.assign(output_kernels)
expert_bank._expert_feedforward_output_bias.assign(output_biases)

# Feedforward layernorm (GptOssRMSNorm)
loader.port_weight(
    keras_variable=decoder_layer._feedforward_layernorm.scale,
    hf_weight_key=f"model.layers.{i}.post_attention_layernorm.weight",
)

# Final normalization layer (GptOssRMSNorm)
loader.port_weight(
    keras_variable=backbone.get_layer("sequence_output_layernorm").scale,
    hf_weight_key="model.norm.weight",
)

return backbone


critical

The convert_weights function appears to be incorrect and will fail during execution. The layer and weight names used do not match the implementation in GptOssBackbone and its sublayers. It seems this script might have been copied from another model's converter without being fully adapted.

Here are some of the mismatches:

  • decoder_layer._self_attention_layernorm should be decoder_layer._input_layernorm.
  • decoder_layer._sparse_moe_block should be decoder_layer._mlp_block.
  • The expert weight loading logic is incorrect. GptOssExperts uses batched weights, but the converter seems to assume a different structure (e.g., expert_bank).

Please review and correct the entire function to match the GptOss model architecture.
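
For example, the input-layernorm port inside the per-layer loop would need to target the actual attribute names (a sketch based only on the bullets above; `backbone`, `loader`, and the exact attribute names should still be verified against the backbone code):

for i in range(backbone.num_layers):
    decoder_layer = backbone.get_layer(f"transformer_layer_{i}")
    loader.port_weight(
        # `_input_layernorm`, not `_self_attention_layernorm`.
        keras_variable=decoder_layer._input_layernorm.scale,
        hf_weight_key=f"model.layers.{i}.input_layernorm.weight",
    )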

Comment on lines 83 to 208
keras_layer.pre_attention_norm.gamma.assign(
    hf_layer.input_layernorm.weight.detach().cpu().numpy()
)

# Attention
# Q, K, V, O projections
keras_layer.attention.query_dense.kernel.assign(
    hf_layer.self_attn.q_proj.weight.T.detach().cpu().numpy()
)
if hf_layer.self_attn.q_proj.bias is not None:
    keras_layer.attention.query_dense.bias.assign(
        hf_layer.self_attn.q_proj.bias.detach().cpu().numpy()
    )

keras_layer.attention.key_dense.kernel.assign(
    hf_layer.self_attn.k_proj.weight.T.detach().cpu().numpy()
)
if hf_layer.self_attn.k_proj.bias is not None:
    keras_layer.attention.key_dense.bias.assign(
        hf_layer.self_attn.k_proj.bias.detach().cpu().numpy()
    )

keras_layer.attention.value_dense.kernel.assign(
    hf_layer.self_attn.v_proj.weight.T.detach().cpu().numpy()
)
if hf_layer.self_attn.v_proj.bias is not None:
    keras_layer.attention.value_dense.bias.assign(
        hf_layer.self_attn.v_proj.bias.detach().cpu().numpy()
    )

keras_layer.attention.output_dense.kernel.assign(
    hf_layer.self_attn.o_proj.weight.T.detach().cpu().numpy()
)
if hf_layer.self_attn.o_proj.bias is not None:
    keras_layer.attention.output_dense.bias.assign(
        hf_layer.self_attn.o_proj.bias.detach().cpu().numpy()
    )

# Sinks
keras_layer.attention.sinks.assign(
    hf_layer.self_attn.sinks.detach().cpu().numpy()
)

# Post-Attention Layer Norm
keras_layer.pre_mlp_norm.gamma.assign(
    hf_layer.post_attention_layernorm.weight.detach().cpu().numpy()
)

# MoE MLP
# Router
keras_layer.moe_mlp.router.kernel.assign(
    hf_layer.mlp.router.weight.T.detach().cpu().numpy()
)
keras_layer.moe_mlp.router.bias.assign(
    hf_layer.mlp.router.bias.detach().cpu().numpy()
)

# Experts
num_experts = hf_model.config.num_local_experts
for j in range(num_experts):
    hf_expert_gate_up_proj = hf_layer.mlp.experts.gate_up_proj[
        j
    ]  # (hidden_size, 2 * expert_dim)
    hf_expert_gate_up_proj_bias = (
        hf_layer.mlp.experts.gate_up_proj_bias[j]
    )  # (2 * expert_dim)

    # Split gate_up_proj into gate and up based on
    # PyTorch forward logic (::2, 1::2)
    hf_gate_proj_weight = hf_expert_gate_up_proj[
        :, ::2
    ]  # (hidden_size, expert_dim)
    hf_up_proj_weight = hf_expert_gate_up_proj[
        :, 1::2
    ]  # (hidden_size, expert_dim)

    hf_gate_proj_bias = hf_expert_gate_up_proj_bias[::2]  # (expert_dim)
    hf_up_proj_bias = hf_expert_gate_up_proj_bias[1::2]  # (expert_dim)

    keras_layer.moe_mlp.experts[j].gate_dense.kernel.assign(
        hf_gate_proj_weight.T.detach().cpu().numpy()
    )
    keras_layer.moe_mlp.experts[j].gate_dense.bias.assign(
        hf_gate_proj_bias.detach().cpu().numpy()
    )

    keras_layer.moe_mlp.experts[j].up_dense.kernel.assign(
        hf_up_proj_weight.T.detach().cpu().numpy()
    )
    keras_layer.moe_mlp.experts[j].up_dense.bias.assign(
        hf_up_proj_bias.detach().cpu().numpy()
    )

    keras_layer.moe_mlp.experts[j].down_dense.kernel.assign(
        hf_layer.mlp.experts.down_proj[j].T.detach().cpu().numpy()
    )
    keras_layer.moe_mlp.experts[j].down_dense.bias.assign(
        hf_layer.mlp.experts.down_proj_bias[j].detach().cpu().numpy()
    )
print("Weights converted successfully.")


critical

The convert_weights function is incorrect and will not work with the GptOssBackbone implementation. There are several mismatches in layer and weight names:

  • keras_hub_backbone.transformer_layers[-1].layer_norm.gamma (line 100): The final layer norm is at keras_hub_backbone.layer_norm, and its weight is named scale, not gamma.
  • keras_layer.pre_attention_norm.gamma (line 109): This should be keras_layer._input_layernorm.scale.
  • keras_layer.pre_mlp_norm.gamma (line 153): This should be keras_layer._post_attention_layernorm.scale.
  • The MoE and expert weight conversion logic (lines 157-207) is incorrect. It assumes a list of expert layers (keras_layer.moe_mlp.experts[j]), but GptOssMLP contains a single GptOssExperts layer with batched weights.

Please rewrite this function to correctly map the Hugging Face checkpoint weights to the KerasHub model's architecture.
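
To make the batched-weights point concrete, the de-interleaving can be done once on the full stacked Hugging Face tensors rather than per expert; a rough NumPy sketch (the array shapes are assumptions inferred from the comments in the code above, not the PR's actual API):

import numpy as np

def split_gate_up(gate_up_proj, gate_up_proj_bias):
    # gate_up_proj: (num_experts, hidden_size, 2 * expert_dim), with gate and up
    # columns interleaved; gate_up_proj_bias: (num_experts, 2 * expert_dim).
    gate_kernels = gate_up_proj[:, :, ::2]   # even columns -> gate projection
    up_kernels = gate_up_proj[:, :, 1::2]    # odd columns -> up/intermediate projection
    gate_biases = gate_up_proj_bias[:, ::2]
    up_biases = gate_up_proj_bias[:, 1::2]
    return gate_kernels, up_kernels, gate_biases, up_biases

The resulting stacked arrays would then be assigned directly to the batched variables of the single GptOssExperts layer, instead of looping over per-expert sublayers.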

Comment on lines 38 to 69
vocabulary_size (int): The size of the token vocabulary.
num_layers (int): The number of transformer layers.
num_query_heads (int): The number of query attention heads for
each transformer.
hidden_dim (int): The size of the transformer encoding and pooling
layers.
intermediate_dim (int): The output dimension of the first Dense layer
in a three-layer feedforward network for each transformer.
num_key_value_heads (int): The number of key and value attention heads
for each transformer.
num_experts (int): The total number of experts in the MoE layer.
top_k (int, optional): The number of experts to select per token.
Defaults to `2`.
rope_max_wavelength (int, optional): The maximum angular wavelength of
the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
rope_scaling_factor (float, optional): The scaling factor for
calculation of rotary embedding. Defaults to `1.0`.
layer_norm_epsilon (float, optional): Epsilon for the layer
normalization layers in the transformer decoder. Defaults to `1e-6`.
sliding_window (int, optional): The sliding window for the attention
layers. This controls the maximum cache size for the
attention layers in each transformer decoder. Only `sliding_window`
number of tokens are saved in the cache and used to generate the
next token. Defaults to `4096`.
dropout (float, optional): Dropout rate for attention probabilities.
Defaults to `0`.
use_bias (bool, optional): Whether to include bias terms in the dense
projections within the attention mechanism. Defaults to `False`.
dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
for model computations and weights. Note that some computations,
such as softmax and layer normalization, will always be done at
float32 precision regardless of dtype.


high

The docstring format for arguments violates the style guide.1 The format should be arg_name: type. description instead of arg_name (type): description. Please update all arguments in this docstring.

For example, vocabulary_size (int): should be vocabulary_size: int..

Style Guide References

Footnotes

  1. Type information is provided in the docstring Args section using the format arg_name: type. description.

```
"""

def __init__(


high

This __init__ method is missing the initial call to super().__init__() before layer creation. According to the style guide, the super constructor should be called at the beginning of __init__ to handle dtype and other base layer setup.1 Please add super().__init__(dtype=dtype, **kwargs) at the beginning of the method.

    def __init__(
        self,
        vocabulary_size,
        num_layers,
        num_query_heads,
        hidden_dim,
        intermediate_dim,
        num_key_value_heads,
        num_experts,
        top_k=2,
        rope_max_wavelength=10000,
        rope_scaling_factor=1.0,
        layer_norm_epsilon=1e-6,
        sliding_window=4096,
        dropout=0,
        use_bias=False,
        dtype=None,
        **kwargs,
    ):
        super().__init__(dtype=dtype, **kwargs)

Style Guide References

Footnotes

  1. The style guide for backbone models shows super().__init__ being called at the start of the __init__ method to set up the base Backbone class.

Comment on lines 118 to 125
def __init__(
    self,
    tokenizer: GptOssTokenizer,
    sequence_length: int,
    add_start_token: bool = True,
    add_end_token: bool = False,
    **kwargs,
):


high

The style guide specifies that type hints should not be used in function signatures.1 Please remove the type hints for tokenizer, sequence_length, add_start_token, and add_end_token.

The type information should be in the Args section of the docstring, which is already correctly done.

Suggested change
def __init__(
    self,
    tokenizer: GptOssTokenizer,
    sequence_length: int,
    add_start_token: bool = True,
    add_end_token: bool = False,
    **kwargs,
):
def __init__(
    self,
    tokenizer,
    sequence_length,
    add_start_token=True,
    add_end_token=False,
    **kwargs,
):

Style Guide References

Footnotes

  1. KerasHub does not use type hints in function signatures or __init__ methods. Default values are okay.

Comment on lines 10 to 33
class CachedGptOssAttention(keras.layers.Layer):
"""A cached attention layer for GPT-OSS with sink tokens and sliding window.

This layer implements the attention mechanism for the GPT-OSS model,
including grouped query attention (GQA), rotary positional embeddings (RoPE),
and specific handling for "sink" tokens, which are added to the attention
logits before softmax. It also supports caching for efficient generation.

Args:
num_query_heads: Number of attention heads for queries.
num_key_value_heads: Number of attention heads for keys and values.
If `num_query_heads != num_key_value_heads`, grouped query attention
is used.
rope_max_wavelength: The maximum wavelength for the rotary embedding.
rope_scaling_factor: Scaling factor for rotary embeddings.
kernel_initializer: Initializer for the dense layer kernels.
sliding_window: The size of the sliding window for attention.
Tokens outside this window are masked. This parameter is used for
configuration but the actual masking should be handled by the
`attention_mask` input.
dropout: Dropout rate for attention probabilities.
use_bias: Whether to include bias terms in the dense projections.
**kwargs: Additional keyword arguments passed to the base Layer class.
"""


medium

The docstring for this public class is missing an Examples section, which is required by the style guide.1 Please add a usage example to demonstrate how to instantiate and use this layer.

Style Guide References

Footnotes

  1. Public classes should have thorough documentation including usage examples.
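
A sketch of what such a section could look like, assuming the constructor accepts exactly the arguments documented above (the import path and argument values are illustrative, not verified against the PR):

# Hypothetical import path, for illustration only.
from keras_hub.src.models.gpt_oss.gpt_oss_attention import CachedGptOssAttention

layer = CachedGptOssAttention(
    num_query_heads=8,
    num_key_value_heads=2,
    sliding_window=4096,
    dropout=0.0,
)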

Comment on lines 17 to 32
class GptOssExperts(keras.layers.Layer):
"""Batched feed-forward experts for GPT-OSS (pure keras.ops).

This layer implements the expert network for the Mixture-of-Experts (MoE)
block in GPT-OSS. It computes the output for all experts and then
applies the routing weights to combine their contributions.

Args:
num_experts: Integer, total number of experts.
hidden_dim: Integer, the hidden dimension of the model.
intermediate_dim: Integer, the intermediate dimension of the expert.
alpha: Float, scaling factor for the GLU activation.
limit: Float, clamping limit for gate and up projections.
kernel_initializer: Initializer for the dense layer kernels.
**kwargs: Additional keyword arguments passed to the base Layer class.
"""


medium

The docstrings for public classes GptOssExperts, GptOssTopKRouter, GptOssMLP, and GptOssTransformerDecoder are missing an Examples section, which is required by the style guide.1 Please add a usage example to each class docstring.

Style Guide References

Footnotes

  1. All public classes, methods, and functions should have Google-style docstrings, including comprehensive examples showing usage patterns.
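
As background for such an example, the clamped-GLU expert computation the Args describe (alpha scaling, clamping limit, routing weights applied to the expert outputs) can be sketched in plain NumPy; this is an assumption-laden illustration of the technique, not the PR's keras.ops implementation:

import numpy as np

def batched_experts_forward(x, gate_w, up_w, down_w, routing_weights, alpha=1.702, limit=7.0):
    # x: (num_tokens, hidden_dim)
    # gate_w, up_w: (num_experts, hidden_dim, intermediate_dim)
    # down_w: (num_experts, intermediate_dim, hidden_dim)
    # routing_weights: (num_tokens, num_experts), zero outside each token's top-k experts
    gate = np.einsum("th,ehm->etm", x, gate_w)
    up = np.einsum("th,ehm->etm", x, up_w)
    gate = np.clip(gate, None, limit)                   # clamp gate from above
    up = np.clip(up, -limit, limit)                     # clamp up symmetrically
    glu = gate * (1.0 / (1.0 + np.exp(-alpha * gate)))  # gate * sigmoid(alpha * gate)
    expert_out = np.einsum("etm,emh->eth", (up + 1.0) * glu, down_w)
    # Combine the per-expert outputs weighted by their routing probabilities.
    return np.einsum("eth,te->th", expert_out, routing_weights)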

Comment on lines 5 to 15
class GptOssLayerNormalization(keras.layers.Layer):
"""A normalization layer for GPT-OSS that implements RMS normalization.

This layer applies Root Mean Square (RMS) normalization, which is a common
normalization technique used in models like Llama and GPT-OSS. It normalizes
the input by its root mean square, then scales it by a learnable weight.

Args:
epsilon: A small float number to prevent division by zero.
**kwargs: Additional keyword arguments passed to the base Layer class.
"""


medium

The docstring for this public class is missing an Examples section, which is required by the style guide.1 Please add a usage example to demonstrate how to instantiate and use this layer.

Style Guide References

Footnotes

  1. All public classes, methods, and functions should have Google-style docstrings, including comprehensive examples showing usage patterns.
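
For reference, the computation described above fits in a few lines; a minimal NumPy sketch of RMS normalization (not the PR's Keras implementation):

import numpy as np

def rms_norm(x, scale, epsilon=1e-6):
    # Normalize by the root mean square over the last axis, then apply a learned scale.
    mean_square = np.mean(np.square(x), axis=-1, keepdims=True)
    return x / np.sqrt(mean_square + epsilon) * scale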

"keras_hub.models.GptOssTokenizer",
]
)
class GptOssTokenizer(SentencePieceTokenizer):


medium

This public class is missing a docstring, which is required by the style guide.1 Please add a docstring explaining what the class does, its arguments, and include a usage example.

Style Guide References

Footnotes

  1. All public classes, methods, and functions must have Google-style docstrings.
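
A sketch of the kind of usage example such a docstring typically carries, assuming the class keeps the standard `SentencePieceTokenizer` constructor (the import path and proto filename are placeholders):

from keras_hub.src.models.gpt_oss.gpt_oss_tokenizer import GptOssTokenizer

tokenizer = GptOssTokenizer(proto="gpt_oss_vocabulary.spm")  # placeholder proto path
token_ids = tokenizer("The quick brown fox jumped.")
text = tokenizer.detokenize(token_ids)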

@laxmareddyp
Collaborator Author

/gemini review

@gemini-code-assist bot left a comment


Code Review

This pull request introduces the GPT_OSS model, including its backbone, causal language model task, preprocessor, and tokenizer. The overall structure aligns well with the project's conventions. However, there are several areas that need attention.

Most critically, the Hugging Face weight conversion script in convert_gpt_oss.py is incorrect and incompatible with the implemented model architecture, particularly for the Mixture-of-Experts (MoE) layers. This will prevent loading pre-trained weights correctly. I've provided a corrected implementation for the convert_weights function.

Additionally, there are multiple instances of incorrect docstring formatting across the new files, which violates the project's style guide. I've pointed these out with suggestions for correction. A minor but important clarification is also needed in a test file comment to improve maintainability.

Comment on lines +44 to +159
for i in range(backbone.num_layers):
    decoder_layer = backbone.get_layer(f"transformer_layer_{i}")

    # Input layernorm
    loader.port_weight(
        keras_variable=decoder_layer._self_attention_layernorm.scale,
        hf_weight_key=f"model.layers.{i}.input_layernorm.weight",
    )

    # Attention layers
    attention_layer = decoder_layer._self_attention_layer
    # Query
    loader.port_weight(
        keras_variable=attention_layer.query_dense.kernel,
        hf_weight_key=f"model.layers.{i}.self_attn.q_proj.weight",
        hook_fn=transpose_and_reshape,
    )
    # Key
    loader.port_weight(
        keras_variable=attention_layer.key_dense.kernel,
        hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight",
        hook_fn=transpose_and_reshape,
    )
    # Value
    loader.port_weight(
        keras_variable=attention_layer.value_dense.kernel,
        hf_weight_key=f"model.layers.{i}.self_attn.v_proj.weight",
        hook_fn=transpose_and_reshape,
    )
    # Output
    loader.port_weight(
        keras_variable=attention_layer.output_dense.kernel,
        hf_weight_key=f"model.layers.{i}.self_attn.o_proj.weight",
        hook_fn=transpose_and_reshape,
    )
    # Sinks
    loader.port_weight(
        keras_variable=attention_layer.sinks,
        hf_weight_key=f"model.layers.{i}.self_attn.sinks",
    )

    # MoE layers
    moe_block = decoder_layer._sparse_moe_block
    # Router gate
    loader.port_weight(
        keras_variable=moe_block._sparse_feedforward_gate_dense.kernel,
        hf_weight_key=f"model.layers.{i}.mlp.router.weight",
        hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
    )
    loader.port_weight(
        keras_variable=moe_block._sparse_feedforward_gate_dense.bias,
        hf_weight_key=f"model.layers.{i}.mlp.router.bias",
    )

    # Batched experts
    gate_up_proj = loader.get_tensor(
        f"model.layers.{i}.mlp.experts.gate_up_proj"
    )
    gate_up_proj_bias = loader.get_tensor(
        f"model.layers.{i}.mlp.experts.gate_up_proj_bias"
    )
    down_proj = loader.get_tensor(f"model.layers.{i}.mlp.experts.down_proj")
    down_proj_bias = loader.get_tensor(
        f"model.layers.{i}.mlp.experts.down_proj_bias"
    )

    # De-interleave gate and up projections
    gate_proj_kernel = gate_up_proj[:, :, ::2]
    up_proj_kernel = gate_up_proj[:, :, 1::2]
    gate_proj_bias = gate_up_proj_bias[:, ::2]
    up_proj_bias = gate_up_proj_bias[:, 1::2]

    # Assign batched weights to expert_bank
    expert_bank = moe_block.expert_bank
    expert_bank._expert_feedforward_gate_dense.kernel.assign(
        gate_proj_kernel
    )
    expert_bank._expert_feedforward_gate_dense.bias.assign(gate_proj_bias)
    expert_bank._expert_feedforward_intermediate_dense.kernel.assign(
        up_proj_kernel
    )
    expert_bank._expert_feedforward_intermediate_dense.bias.assign(
        up_proj_bias
    )
    expert_bank._expert_feedforward_output_dense.kernel.assign(down_proj)
    expert_bank._expert_feedforward_output_dense.bias.assign(down_proj_bias)

    # Feedforward layernorm
    loader.port_weight(
        keras_variable=decoder_layer._feedforward_layernorm.scale,
        hf_weight_key=f"model.layers.{i}.post_attention_layernorm.weight",
    )

# Final normalization layer
loader.port_weight(
    keras_variable=backbone.get_layer("sequence_output_layernorm").scale,
    hf_weight_key="model.norm.weight",
)

return backbone


critical

This weight conversion function is incorrect and will fail. It seems to be based on a different model's architecture (likely Mixtral) and is not compatible with the GptOss implementation in this PR.

Specifically:

  1. It references non-existent attributes like _sparse_feedforward_gate_dense and expert_bank.
  2. It uses incorrect layer names (e.g., _self_attention_layernorm instead of input_layernorm).
  3. The logic for handling MoE expert weights is wrong for this model's implementation.

I've provided a corrected version that should work with the GptOss model structure.

def convert_weights(backbone, loader, transformers_config):
    """Convert Gpt-Oss weights."""
    # Embeddings
    loader.port_weight(
        keras_variable=backbone.get_layer("token_embedding").embeddings,
        hf_weight_key="model.embed_tokens.weight",
    )
    loader.port_weight(
        keras_variable=backbone.get_layer("token_embedding").reverse_embeddings,
        hf_weight_key="lm_head.weight",
        hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
    )

    def transpose_and_reshape(x, shape):
        return np.reshape(np.transpose(x), shape)

    for i in range(backbone.num_layers):
        decoder_layer = backbone.get_layer(f"transformer_layer_{i}")

        # Input layernorm
        loader.port_weight(
            keras_variable=decoder_layer.input_layernorm.scale,
            hf_weight_key=f"model.layers.{i}.input_layernorm.weight",
        )

        # Attention layers
        attention_layer = decoder_layer.self_attention_layer
        # Query
        loader.port_weight(
            keras_variable=attention_layer.query_dense.kernel,
            hf_weight_key=f"model.layers.{i}.self_attn.q_proj.weight",
            hook_fn=transpose_and_reshape,
        )
        # Key
        loader.port_weight(
            keras_variable=attention_layer.key_dense.kernel,
            hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight",
            hook_fn=transpose_and_reshape,
        )
        # Value
        loader.port_weight(
            keras_variable=attention_layer.value_dense.kernel,
            hf_weight_key=f"model.layers.{i}.self_attn.v_proj.weight",
            hook_fn=transpose_and_reshape,
        )
        # Output
        loader.port_weight(
            keras_variable=attention_layer.output_dense.kernel,
            hf_weight_key=f"model.layers.{i}.self_attn.o_proj.weight",
            hook_fn=transpose_and_reshape,
        )
        # Sinks
        loader.port_weight(
            keras_variable=attention_layer.sinks,
            hf_weight_key=f"model.layers.{i}.self_attn.sinks",
        )

        # MoE layers
        moe_block = decoder_layer.sparse_moe_block
        # Router gate
        loader.port_weight(
            keras_variable=moe_block.router.router_dense.kernel,
            hf_weight_key=f"model.layers.{i}.mlp.router.weight",
            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
        )
        loader.port_weight(
            keras_variable=moe_block.router.router_dense.bias,
            hf_weight_key=f"model.layers.{i}.mlp.router.bias",
        )

        # Batched experts
        experts_layer = moe_block.experts
        loader.port_weight(
            keras_variable=experts_layer.gate_up_proj,
            hf_weight_key=f"model.layers.{i}.mlp.experts.gate_up_proj",
            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
                hf_tensor, (0, 2, 1)
            ),
        )
        loader.port_weight(
            keras_variable=experts_layer.gate_up_proj_bias,
            hf_weight_key=f"model.layers.{i}.mlp.experts.gate_up_proj_bias",
        )
        loader.port_weight(
            keras_variable=experts_layer.down_proj,
            hf_weight_key=f"model.layers.{i}.mlp.experts.down_proj",
            hook_fn=lambda hf_tensor, keras_shape: np.transpose(
                hf_tensor, (0, 2, 1)
            ),
        )
        loader.port_weight(
            keras_variable=experts_layer.down_proj_bias,
            hf_weight_key=f"model.layers.{i}.mlp.experts.down_proj_bias",
        )

        # Feedforward layernorm
        loader.port_weight(
            keras_variable=decoder_layer.post_attention_layernorm.scale,
            hf_weight_key=f"model.layers.{i}.post_attention_layernorm.weight",
        )

    # Final normalization layer
    loader.port_weight(
        keras_variable=backbone.get_layer("sequence_output_layernorm").scale,
        hf_weight_key="model.norm.weight",
    )

    return backbone

Comment on lines +32 to +45
Args:
num_query_heads (int): The number of query attention heads.
num_key_value_heads (int): The number of key and value attention
heads.
rope_max_wavelength (int, optional): The maximum wavelength for the
rotary position embedding. Defaults to 10000.
rope_scaling_factor (float, optional): The scaling factor for the
rotary position embedding. Defaults to 1.0.
kernel_initializer (str, optional): The initializer for the kernel
weights. Defaults to "glorot_uniform".
sliding_window (int, optional): The size of the sliding window.
Defaults to 4096.
dropout (float, optional): The dropout rate. Defaults to 0.
"""


medium

The docstring format for arguments does not follow the style guide. The format should be arg_name: type. instead of arg_name (type):. 1

Suggested change
Args:
num_query_heads (int): The number of query attention heads.
num_key_value_heads (int): The number of key and value attention
heads.
rope_max_wavelength (int, optional): The maximum wavelength for the
rotary position embedding. Defaults to 10000.
rope_scaling_factor (float, optional): The scaling factor for the
rotary position embedding. Defaults to 1.0.
kernel_initializer (str, optional): The initializer for the kernel
weights. Defaults to "glorot_uniform".
sliding_window (int, optional): The size of the sliding window.
Defaults to 4096.
dropout (float, optional): The dropout rate. Defaults to 0.
"""
Args:
num_query_heads: int. The number of query attention heads.
num_key_value_heads: int. The number of key and value attention
heads.
rope_max_wavelength: int, optional. The maximum wavelength for the
rotary position embedding. Defaults to 10000.
rope_scaling_factor: float, optional. The scaling factor for the
rotary position embedding. Defaults to 1.0.
kernel_initializer: str, optional. The initializer for the kernel
weights. Defaults to "glorot_uniform".
sliding_window: int, optional. The size of the sliding window.
Defaults to 4096.
dropout: float, optional. The dropout rate. Defaults to 0.

Style Guide References

Footnotes

  1. Type information should be provided in the docstring Args section using the format arg_name: type. description.

Comment on lines +49 to +77
Args:
vocabulary_size (int): The size of the token vocabulary.
num_layers (int): The number of transformer layers.
num_query_heads (int): The number of query attention heads for
each transformer.
hidden_dim (int): The size of the transformer encoding and pooling
layers.
intermediate_dim (int): The output dimension of the first Dense layer
in a three-layer feedforward network for each transformer.
num_key_value_heads (int): The number of key and value attention heads
for each transformer.
num_experts (int): The number of experts for the MoE layers.
top_k (int, optional): The number of experts to use for each token.
Defaults to `2`.
rope_max_wavelength (int, optional): The maximum angular wavelength of
the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
rope_scaling_factor (float, optional): The scaling factor for
calculation of roatary embedding. Defaults to `1.0`.
layer_norm_epsilon (float, optional): Epsilon for the layer
normalization layers in the transformer decoder. Defaults to `1e-6`.
sliding_window (int, optional): The sliding window for the attention
layers. This controls the maximum cache size for the attention
layers in each transformer decoder. Only `sliding_window` number
of tokens are saved in the cache and used to generate the next
token. Defaults to `4096`.
dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
for model computations and weights. Note that some computations,
such as softmax and layer normalization, will always be done at
float32 precision regardless of dtype.


medium

The docstring format for arguments does not follow the style guide. The format should be arg_name: type. instead of arg_name (type):. 1

    Args:
        vocabulary_size: int. The size of the token vocabulary.
        num_layers: int. The number of transformer layers.
        num_query_heads: int. The number of query attention heads for
            each transformer.
        hidden_dim: int. The size of the transformer encoding and pooling
            layers.
        intermediate_dim: int. The output dimension of the first Dense layer
            in a three-layer feedforward network for each transformer.
        num_key_value_heads: int. The number of key and value attention heads
            for each transformer.
        num_experts: int. The number of experts for the MoE layers.
        top_k: int, optional. The number of experts to use for each token.
            Defaults to `2`.
        rope_max_wavelength: int, optional. The maximum angular wavelength of
            the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
        rope_scaling_factor: float, optional. The scaling factor for
            calculation of rotary embedding. Defaults to `1.0`.
        layer_norm_epsilon: float, optional. Epsilon for the layer
            normalization layers in the transformer decoder. Defaults to `1e-6`.
        sliding_window: int, optional. The sliding window for the attention
            layers. This controls the maximum cache size for the attention
            layers in each transformer decoder. Only `sliding_window` number
            of tokens are saved in the cache and used to generate the next
            token. Defaults to `4096`.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for model computations and weights. Note that some computations,
            such as softmax and layer normalization, will always be done at
            float32 precision regardless of dtype.

Style Guide References

Footnotes

  1. Type information should be provided in the docstring Args section using the format arg_name: type. description.

model = GptOssBackbone(**self.init_kwargs)
# Calculated based on the model architecture:
# - Token embedding: vocabulary_size * hidden_dim
# - Output projection: hidden_dim * vocabulary_size


medium

This comment is slightly misleading. The backbone itself doesn't have an output projection layer. The ReversibleEmbedding layer contains weights for reverse projection, which are used by the GptOssCausalLM task but not directly by the backbone's forward pass. Clarifying this will improve maintainability.

Suggested change
# - Output projection: hidden_dim * vocabulary_size
# - Reverse embedding projection: hidden_dim * vocabulary_size
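
As a concrete illustration with hypothetical numbers (not this PR's test config): with `vocabulary_size=1000` and `hidden_dim=64`, the token embedding contributes 1000 * 64 = 64,000 parameters, and the reverse embedding projection another 64,000 when the two are counted separately.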

Comment on lines +69 to +73
Args:
token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
cache: a dense float Tensor, the cache of key and value.
cache_update_index: int, or int Tensor. The index of current inputs
in the whole sequence.


medium

The docstring format for arguments does not follow the style guide. The format should be arg_name: type. description. 1

Suggested change
Args:
token_ids: a dense int Tensor with shape `(batch_size, max_length)`.
cache: a dense float Tensor, the cache of key and value.
cache_update_index: int, or int Tensor. The index of current inputs
in the whole sequence.
Args:
token_ids: tensor. A dense int Tensor with shape `(batch_size, max_length)`.
cache: tensor. A dense float Tensor, the cache of key and value.
cache_update_index: int or int Tensor. The index of current inputs
in the whole sequence.

Style Guide References

Footnotes

  1. Type information should be provided in the docstring Args section using the format arg_name: type. description.

Comment on lines +37 to +47
Args:
num_experts (int): The total number of experts.
hidden_dim (int): The hidden size of the model.
intermediate_dim (int): The intermediate size of the feed-forward
network.
kernel_initializer (str, optional): The initializer for the kernel
weights. Defaults to "glorot_uniform".
alpha (float, optional): The alpha parameter for the custom GLU
activation. Defaults to 1.702.
limit (float, optional): The clamping limit for gate and up
projections. Defaults to 7.0.


medium

The docstring format for arguments does not follow the style guide. The format should be arg_name: type. instead of arg_name (type):. 1

Suggested change
Args:
num_experts (int): The total number of experts.
hidden_dim (int): The hidden size of the model.
intermediate_dim (int): The intermediate size of the feed-forward
network.
kernel_initializer (str, optional): The initializer for the kernel
weights. Defaults to "glorot_uniform".
alpha (float, optional): The alpha parameter for the custom GLU
activation. Defaults to 1.702.
limit (float, optional): The clamping limit for gate and up
projections. Defaults to 7.0.
Args:
num_experts: int. The total number of experts.
hidden_dim: int. The hidden size of the model.
intermediate_dim: int. The intermediate size of the feed-forward
network.
kernel_initializer: str, optional. The initializer for the kernel
weights. Defaults to "glorot_uniform".
alpha: float, optional. The alpha parameter for the custom GLU
activation. Defaults to 1.702.
limit: float, optional. The clamping limit for gate and up
projections. Defaults to 7.0.

Style Guide References

Footnotes

  1. Type information should be provided in the docstring Args section using the format arg_name: type. description.

Comment on lines +124 to +128
Args:
num_experts (int): The total number of experts.
top_k (int): The number of experts to route each token to.
kernel_initializer (str, optional): The initializer for the kernel
weights. Defaults to "glorot_uniform".


medium

The docstring format for arguments does not follow the style guide. The format should be arg_name: type. instead of arg_name (type):. 1

Suggested change
Args:
num_experts (int): The total number of experts.
top_k (int): The number of experts to route each token to.
kernel_initializer (str, optional): The initializer for the kernel
weights. Defaults to "glorot_uniform".
Args:
num_experts: int. The total number of experts.
top_k: int. The number of experts to route each token to.
kernel_initializer: str, optional. The initializer for the kernel
weights. Defaults to "glorot_uniform".

Style Guide References

Footnotes

  1. Type information should be provided in the docstring Args section using the format arg_name: type. description.
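
For context, one common formulation of this kind of router applies a softmax over only the selected top-k logits and scatters the probabilities back to the full expert axis; a NumPy sketch of that technique follows (whether GPT-OSS normalizes before or after the top-k selection should be checked against the reference implementation):

import numpy as np

def top_k_routing_weights(router_logits, top_k=2):
    # router_logits: (num_tokens, num_experts)
    top_idx = np.argsort(router_logits, axis=-1)[:, -top_k:]
    top_logits = np.take_along_axis(router_logits, top_idx, axis=-1)
    top_probs = np.exp(top_logits - top_logits.max(axis=-1, keepdims=True))
    top_probs = top_probs / top_probs.sum(axis=-1, keepdims=True)  # softmax over selected experts
    weights = np.zeros_like(router_logits)
    np.put_along_axis(weights, top_idx, top_probs, axis=-1)        # scatter back to all experts
    return weights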

Comment on lines +184 to +192
Args:
hidden_dim (int): The hidden size of the model.
intermediate_dim (int): The intermediate size of the feed-forward
network.
num_experts (int): The total number of experts.
top_k (int, optional): The number of experts to route each token to.
Defaults to 2.
kernel_initializer (str, optional): The initializer for the kernel
weights. Defaults to "glorot_uniform".


medium

The docstring format for arguments does not follow the style guide. The format should be arg_name: type. instead of arg_name (type):. 1

Suggested change
Args:
hidden_dim (int): The hidden size of the model.
intermediate_dim (int): The intermediate size of the feed-forward
network.
num_experts (int): The total number of experts.
top_k (int, optional): The number of experts to route each token to.
Defaults to 2.
kernel_initializer (str, optional): The initializer for the kernel
weights. Defaults to "glorot_uniform".
Args:
hidden_dim: int. The hidden size of the model.
intermediate_dim: int. The intermediate size of the feed-forward
network.
num_experts: int. The total number of experts.
top_k: int, optional. The number of experts to route each token to.
Defaults to 2.
kernel_initializer: str, optional. The initializer for the kernel
weights. Defaults to "glorot_uniform".

Style Guide References

Footnotes

  1. Type information should be provided in the docstring Args section using the format arg_name: type. description.

Comment on lines +267 to +288
Args:
intermediate_dim (int): The intermediate size of the feed-forward
network.
num_query_heads (int): The number of query attention heads.
num_key_value_heads (int): The number of key and value attention
heads.
num_experts (int): The total number of experts in the MoE layer.
top_k (int, optional): The number of experts to route each token to.
Defaults to 2.
output_router_logits (bool, optional): If True, the router logits will
be returned by the layer. Defaults to False.
rope_max_wavelength (int, optional): The maximum wavelength for the
rotary position embedding. Defaults to 10000.
rope_scaling_factor (float, optional): The scaling factor for the
rotary position embedding. Defaults to 1.0.
layer_norm_epsilon (float, optional): The epsilon for layer
normalization. Defaults to 1e-6.
kernel_initializer (str, optional): The initializer for the kernel
weights. Defaults to "glorot_uniform".
sliding_window (int, optional): The size of the sliding window for
attention. Defaults to 4096.
dropout (float, optional): The dropout rate. Defaults to 0.


medium

The docstring format for arguments does not follow the style guide. The format should be arg_name: type. instead of arg_name (type):. 1

    Args:
        intermediate_dim: int. The intermediate size of the feed-forward
            network.
        num_query_heads: int. The number of query attention heads.
        num_key_value_heads: int. The number of key and value attention
            heads.
        num_experts: int. The total number of experts in the MoE layer.
        top_k: int, optional. The number of experts to route each token to.
            Defaults to 2.
        output_router_logits: bool, optional. If True, the router logits will
            be returned by the layer. Defaults to False.
        rope_max_wavelength: int, optional. The maximum wavelength for the
            rotary position embedding. Defaults to 10000.
        rope_scaling_factor: float, optional. The scaling factor for the
            rotary position embedding. Defaults to 1.0.
        layer_norm_epsilon: float, optional. The epsilon for layer
            normalization. Defaults to 1e-6.
        kernel_initializer: str, optional. The initializer for the kernel
            weights. Defaults to "glorot_uniform".
        sliding_window: int, optional. The size of the sliding window for
            attention. Defaults to 4096.
        dropout: float, optional. The dropout rate. Defaults to 0.

Style Guide References

Footnotes

  1. Type information should be provided in the docstring Args section using the format arg_name: type. description.

@laxmareddyp
Collaborator Author

/gemini review

@gemini-code-assist bot left a comment


Code Review

This pull request introduces a new GptOss model, including its backbone, causal language model, preprocessor, and associated tests and conversion scripts. The code is script-generated, and this review focuses on its quality, correctness, and adherence to the project's style guide.

The model definition and tests are generally well-structured. However, there are a few issues in the GptOssBackbone implementation regarding initialization and configuration that need to be addressed. The most critical issue lies in the Hugging Face conversion script (convert_gpt_oss.py), which appears to contain significant copy-paste errors from another model's converter. This will prevent correct weight loading and needs a substantial revision. My detailed comments provide specific suggestions to fix these issues.

Comment on lines +44 to +159
for i in range(backbone.num_layers):
    decoder_layer = backbone.get_layer(f"transformer_layer_{i}")

    # Input layernorm
    loader.port_weight(
        keras_variable=decoder_layer._self_attention_layernorm.scale,
        hf_weight_key=f"model.layers.{i}.input_layernorm.weight",
    )

    # Attention layers
    attention_layer = decoder_layer._self_attention_layer
    # Query
    loader.port_weight(
        keras_variable=attention_layer.query_dense.kernel,
        hf_weight_key=f"model.layers.{i}.self_attn.q_proj.weight",
        hook_fn=transpose_and_reshape,
    )
    # Key
    loader.port_weight(
        keras_variable=attention_layer.key_dense.kernel,
        hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight",
        hook_fn=transpose_and_reshape,
    )
    # Value
    loader.port_weight(
        keras_variable=attention_layer.value_dense.kernel,
        hf_weight_key=f"model.layers.{i}.self_attn.v_proj.weight",
        hook_fn=transpose_and_reshape,
    )
    # Output
    loader.port_weight(
        keras_variable=attention_layer.output_dense.kernel,
        hf_weight_key=f"model.layers.{i}.self_attn.o_proj.weight",
        hook_fn=transpose_and_reshape,
    )
    # Sinks
    loader.port_weight(
        keras_variable=attention_layer.sinks,
        hf_weight_key=f"model.layers.{i}.self_attn.sinks",
    )

    # MoE layers
    moe_block = decoder_layer._sparse_moe_block
    # Router gate
    loader.port_weight(
        keras_variable=moe_block._sparse_feedforward_gate_dense.kernel,
        hf_weight_key=f"model.layers.{i}.mlp.router.weight",
        hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
    )
    loader.port_weight(
        keras_variable=moe_block._sparse_feedforward_gate_dense.bias,
        hf_weight_key=f"model.layers.{i}.mlp.router.bias",
    )

    # Batched experts
    gate_up_proj = loader.get_tensor(
        f"model.layers.{i}.mlp.experts.gate_up_proj"
    )
    gate_up_proj_bias = loader.get_tensor(
        f"model.layers.{i}.mlp.experts.gate_up_proj_bias"
    )
    down_proj = loader.get_tensor(f"model.layers.{i}.mlp.experts.down_proj")
    down_proj_bias = loader.get_tensor(
        f"model.layers.{i}.mlp.experts.down_proj_bias"
    )

    # De-interleave gate and up projections
    gate_proj_kernel = gate_up_proj[:, :, ::2]
    up_proj_kernel = gate_up_proj[:, :, 1::2]
    gate_proj_bias = gate_up_proj_bias[:, ::2]
    up_proj_bias = gate_up_proj_bias[:, 1::2]

    # Assign batched weights to expert_bank
    expert_bank = moe_block.expert_bank
    expert_bank._expert_feedforward_gate_dense.kernel.assign(
        gate_proj_kernel
    )
    expert_bank._expert_feedforward_gate_dense.bias.assign(gate_proj_bias)
    expert_bank._expert_feedforward_intermediate_dense.kernel.assign(
        up_proj_kernel
    )
    expert_bank._expert_feedforward_intermediate_dense.bias.assign(
        up_proj_bias
    )
    expert_bank._expert_feedforward_output_dense.kernel.assign(down_proj)
    expert_bank._expert_feedforward_output_dense.bias.assign(down_proj_bias)

    # Feedforward layernorm
    loader.port_weight(
        keras_variable=decoder_layer._feedforward_layernorm.scale,
        hf_weight_key=f"model.layers.{i}.post_attention_layernorm.weight",
    )

# Final normalization layer
loader.port_weight(
    keras_variable=backbone.get_layer("sequence_output_layernorm").scale,
    hf_weight_key="model.norm.weight",
)

return backbone


critical

The convert_weights function contains several errors, likely from being copied from another model's conversion script. The layer names are incorrect, and the logic for handling Mixture-of-Experts (MoE) weights is flawed.

Specifically:

  • Private attributes like _self_attention_layernorm are accessed, but the correct public names are input_layernorm, self_attention_layer, etc.
  • The router weight access is incorrect (moe_block._sparse_feedforward_gate_dense.kernel should be moe_block.router.router_dense.kernel).
  • The expert weight conversion logic incorrectly assumes a moe_block.expert_bank structure and performs unnecessary de-interleaving. The GptOssExperts layer expects batched weights to be assigned directly.

I've provided a corrected implementation below that addresses these issues. Note that this assumes the Hugging Face model has gemma-like MoE weight names, which will need to be verified.

def convert_weights(backbone, loader, transformers_config):
    """Convert Gpt-Oss weights."""
    # Embeddings
    loader.port_weight(
        keras_variable=backbone.token_embedding.embeddings,
        hf_weight_key="model.embed_tokens.weight",
    )
    loader.port_weight(
        keras_variable=backbone.token_embedding.reverse_embeddings,
        hf_weight_key="lm_head.weight",
        hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
    )

    def transpose_and_reshape(x, shape):
        return np.reshape(np.transpose(x), shape)

    for i in range(backbone.num_layers):
        decoder_layer = backbone.get_layer(f"transformer_layer_{i}")

        # Input layernorm
        loader.port_weight(
            keras_variable=decoder_layer.input_layernorm.scale,
            hf_weight_key=f"model.layers.{i}.input_layernorm.weight",
        )

        # Attention layers
        attention_layer = decoder_layer.self_attention_layer
        # Query
        loader.port_weight(
            keras_variable=attention_layer.query_dense.kernel,
            hf_weight_key=f"model.layers.{i}.self_attn.q_proj.weight",
            hook_fn=transpose_and_reshape,
        )
        # Key
        loader.port_weight(
            keras_variable=attention_layer.key_dense.kernel,
            hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight",
            hook_fn=transpose_and_reshape,
        )
        # Value
        loader.port_weight(
            keras_variable=attention_layer.value_dense.kernel,
            hf_weight_key=f"model.layers.{i}.self_attn.v_proj.weight",
            hook_fn=transpose_and_reshape,
        )
        # Output
        loader.port_weight(
            keras_variable=attention_layer.output_dense.kernel,
            hf_weight_key=f"model.layers.{i}.self_attn.o_proj.weight",
            hook_fn=transpose_and_reshape,
        )
        # Sinks
        loader.port_weight(
            keras_variable=attention_layer.sinks,
            hf_weight_key=f"model.layers.{i}.self_attn.sinks",
        )

        # MoE layers
        moe_block = decoder_layer.sparse_moe_block
        # Router gate
        loader.port_weight(
            keras_variable=moe_block.router.router_dense.kernel,
            hf_weight_key=f"model.layers.{i}.block_sparse_moe.gate.weight",
            hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)),
        )

        # Experts
        experts_layer = moe_block.experts
        hf_expert_weights = [
            loader.get_tensor(f"model.layers.{i}.block_sparse_moe.experts.{j}.gate_up_proj.weight")
            for j in range(backbone.num_experts)
        ]
        experts_layer.gate_up_proj.assign(np.stack(hf_expert_weights, axis=0))

        hf_expert_biases = [
            loader.get_tensor(f"model.layers.{i}.block_sparse_moe.experts.{j}.gate_up_proj.bias")
            for j in range(backbone.num_experts)
        ]
        experts_layer.gate_up_proj_bias.assign(np.stack(hf_expert_biases, axis=0))

        hf_down_proj_weights = [
            loader.get_tensor(f"model.layers.{i}.block_sparse_moe.experts.{j}.down_proj.weight")
            for j in range(backbone.num_experts)
        ]
        experts_layer.down_proj.assign(np.stack(hf_down_proj_weights, axis=0))

        hf_down_proj_biases = [
            loader.get_tensor(f"model.layers.{i}.block_sparse_moe.experts.{j}.down_proj.bias")
            for j in range(backbone.num_experts)
        ]
        experts_layer.down_proj_bias.assign(np.stack(hf_down_proj_biases, axis=0))

        # Post-attention layernorm
        loader.port_weight(
            keras_variable=decoder_layer.post_attention_layernorm.scale,
            hf_weight_key=f"model.layers.{i}.post_attention_layernorm.weight",
        )

    # Final normalization layer
    loader.port_weight(
        keras_variable=backbone.layer_norm.scale,
        hf_weight_key="model.norm.weight",
    )

    return backbone

Comment on lines +171 to +172
for transformer_layer in self.transformer_layers:
    x = transformer_layer(x, decoder_padding_mask=padding_mask_input)


high

The loop for processing transformer layers doesn't account for the GptOssTransformerDecoder returning a tuple when output_router_logits is True. This will cause a runtime error if the backbone is initialized with output_router_logits=True, since the tuple returned by one layer would be passed as the input to the next. The backbone should handle this tuple output, for example by discarding the extra router_logits if they are not used.

        for transformer_layer in self.transformer_layers:
            # The decoder layer might return a tuple of (sequence_output, router_logits).
            # We are only interested in the sequence_output here.
            output = transformer_layer(x, decoder_padding_mask=padding_mask_input)
            x = output[0] if isinstance(output, tuple) else output

Comment on lines +110 to +128
def __init__(
    self,
    vocabulary_size,
    num_layers,
    num_query_heads,
    hidden_dim,
    intermediate_dim,
    num_key_value_heads,
    num_experts,
    top_k=2,
    rope_max_wavelength=10000,
    rope_scaling_factor=1.0,
    layer_norm_epsilon=1e-6,
    sliding_window=4096,
    dropout=0,
    dtype=None,
    output_router_logits=False,
    **kwargs,
):


medium

The __init__ method should call super().__init__(dtype=dtype, **kwargs) at the beginning. This is important for properly initializing the model, including setting the dtype_policy. The current implementation defines layers before the super call, which can lead to unexpected behavior if dtype is not explicitly passed during layer instantiation.1

Suggested change
def __init__(
self,
vocabulary_size,
num_layers,
num_query_heads,
hidden_dim,
intermediate_dim,
num_key_value_heads,
num_experts,
top_k=2,
rope_max_wavelength=10000,
rope_scaling_factor=1.0,
layer_norm_epsilon=1e-6,
sliding_window=4096,
dropout=0,
dtype=None,
output_router_logits=False,
**kwargs,
):
def __init__(
self,
vocabulary_size,
num_layers,
num_query_heads,
hidden_dim,
intermediate_dim,
num_key_value_heads,
num_experts,
top_k=2,
rope_max_wavelength=10000,
rope_scaling_factor=1.0,
layer_norm_epsilon=1e-6,
sliding_window=4096,
dropout=0,
dtype=None,
output_router_logits=False,
**kwargs,
):
super().__init__(dtype=dtype, **kwargs)

Style Guide References

Footnotes

  1. The style guide example for a backbone's __init__ method shows that super().__init__ should be called at the beginning to initialize the model before defining layers.

Comment on lines +184 to +197
# === Config ===
self.vocabulary_size = vocabulary_size
self.num_layers = num_layers
self.num_query_heads = num_query_heads
self.hidden_dim = hidden_dim
self.intermediate_dim = intermediate_dim
self.num_key_value_heads = num_key_value_heads
self.num_experts = num_experts
self.top_k = top_k
self.rope_max_wavelength = rope_max_wavelength
self.rope_scaling_factor = rope_scaling_factor
self.sliding_window = sliding_window
self.layer_norm_epsilon = layer_norm_epsilon
self.dropout = dropout


medium

The output_router_logits parameter from __init__ is not stored as a class attribute and is missing from the get_config method. This will cause an error when saving and loading the model if output_router_logits is set to True.

        # === Config ===
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.num_query_heads = num_query_heads
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.num_key_value_heads = num_key_value_heads
        self.num_experts = num_experts
        self.top_k = top_k
        self.rope_max_wavelength = rope_max_wavelength
        self.rope_scaling_factor = rope_scaling_factor
        self.sliding_window = sliding_window
        self.layer_norm_epsilon = layer_norm_epsilon
        self.dropout = dropout
        self.output_router_logits = output_router_logits

Comment on lines +215 to +216
"dropout": self.dropout,
}


medium

The output_router_logits parameter is missing from the configuration dictionary returned by get_config. It should be added to ensure proper model serialization.

                "dropout": self.dropout,
                "output_router_logits": self.output_router_logits,
            }

@mattdangerw
Member

@laxmareddyp does it work? Looks like it probably doesn't even output valid code yet, right? Going by the test output.

The overall code style looks ok, so this will come down to how accurate the code is.

I would view scripts like this as a developer tool rather than an automated workflow. So maybe the thing to do here is to try getting this code working end to end. Then you can have some first-hand experience of how useful this was at saving time. If there are tons of hard-to-find errors in the code, this might be slower than doing it yourself. If it's pretty accurate, maybe this is saving time and worth putting forward as a tool for contributors to use.
