Skip to content

Failed to load Wav2Vec2ForSequenceClassification model #1487

@abdofallah

Description

@abdofallah

Describe the bug
Unable to load the weights of a Wav2Vec2ForSequenceClassification model. Although we created a C# class structure that mirrors the official Hugging Face PyTorch implementation, calling model.load() consistently throws System.ArgumentException: 'Mismatched state_dict sizes...'.

To Reproduce

  1. Get the Model: The model is a fine-tuned Wav2Vec2 for voicemail detection from Bland AI, available on Hugging Face: blandai/wav2vec-vm-finetune. The model's config.json confirms its architecture is Wav2Vec2ForSequenceClassification.

  2. C# Model Definition: Create the following C# classes, which are designed to be a direct replica of the Hugging Face Wav2Vec2ForSequenceClassification source code.

// File: Wav2Vec2ForSequenceClassificationSharp.cs
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;

/// <summary>
/// C# counterpart of Hugging Face's <c>Wav2Vec2ForSequenceClassification</c>. It runs a
/// wav2vec 2.0 backbone, then a linear projector, then mean pooling over dimension 1, then a
/// linear classifier.
/// </summary>
/// <remarks>
/// The backbone is built with TorchSharp's <c>torchaudio.models.wav2vec2_model</c>. Only the
/// top-level state_dict prefixes ("wav2vec2.", "projector.", "classifier.") are controlled by
/// the field names in this class. Every key below those prefixes follows the module layout of
/// that torchaudio port, not the Hugging Face layout.
/// NOTE(review): a Hugging Face checkpoint presumably needs two things before
/// <c>load(..., strict: true)</c> can succeed:
/// (1) its keys remapped, as torchaudio's <c>import_huggingface_model</c> does in Python;
/// (2) any entries with no counterpart here dropped.
/// Confirm by diffing <c>state_dict().Keys</c> against the keys written to the converted file.
/// </remarks>
public class Wav2Vec2ForSequenceClassificationSharp : Module<Tensor, Tensor>
{
    // Direct children. RegisterComponents() uses these field names as the top-level
    // state_dict prefixes, so they must stay "wav2vec2", "projector" and "classifier".
    private readonly Wav2Vec2Model wav2vec2;
    private readonly Linear projector;
    private readonly Linear classifier;

    /// <summary>
    /// Builds the backbone and the classification head.
    /// </summary>
    /// <remarks>
    /// All <c>extractor_*</c> and <c>encoder_*</c> arguments are forwarded unchanged, in order,
    /// to <c>torchaudio.models.wav2vec2_model</c>. The backbone is built without an auxiliary
    /// output head (<c>aux_num_out: null</c>).
    /// </remarks>
    /// <param name="encoder_embed_dim">Backbone output feature size; also the projector's input size.</param>
    /// <param name="classifier_proj_size">Output size of the projector and input size of the classifier.</param>
    /// <param name="num_labels">Number of output logits produced by the classifier.</param>
    public Wav2Vec2ForSequenceClassificationSharp(
        torchaudio.models.FeatureExtractorNormMode extractor_mode,
        long[][] extractor_conv_layer_config, bool extractor_conv_bias,
        int encoder_embed_dim, double encoder_projection_dropout,
        int encoder_pos_conv_kernel, int encoder_pos_conv_groups,
        int encoder_num_layers, int encoder_num_heads,
        double encoder_attention_dropout, int encoder_ff_interm_features,
        double encoder_ff_interm_dropout, double encoder_dropout,
        bool encoder_layer_norm_first, double encoder_layer_drop,
        int classifier_proj_size, int num_labels
        ) : base(nameof(Wav2Vec2ForSequenceClassificationSharp))
    {
        this.wav2vec2 = torchaudio.models.wav2vec2_model(
            extractor_mode, extractor_conv_layer_config, extractor_conv_bias,
            encoder_embed_dim, encoder_projection_dropout, encoder_pos_conv_kernel,
            encoder_pos_conv_groups, encoder_num_layers, encoder_num_heads,
            encoder_attention_dropout, encoder_ff_interm_features, encoder_ff_interm_dropout,
            encoder_dropout, encoder_layer_norm_first, encoder_layer_drop,
            aux_num_out: null
        );

        this.projector = Linear(encoder_embed_dim, classifier_proj_size);
        this.classifier = Linear(classifier_proj_size, num_labels);

        RegisterComponents();
    }

    /// <summary>
    /// Computes classification logits for a batch of raw waveforms.
    /// </summary>
    /// <param name="input">
    /// Waveforms passed straight to the backbone. Assumed to be (batch, samples); TODO confirm.
    /// </param>
    /// <returns>
    /// Logits from the classifier. The shape is (batch, num_labels), assuming the backbone
    /// returns (batch, frames, encoder_embed_dim).
    /// </returns>
    public override Tensor forward(Tensor input)
    {
        // Every tensor created inside this scope is released when the method returns.
        // Only the logits are handed back to the caller's scope.
        using var scope = NewDisposeScope();

        // The second element of the backbone's output tuple is not used here.
        var (hidden_states, _) = this.wav2vec2.forward(input);
        var projected_states = this.projector.forward(hidden_states);

        // Mean-pool over dimension 1 (the frame axis), as the HF head does without an attention mask.
        var pooled_output = projected_states.mean(new long[] { 1 });
        var logits = this.classifier.forward(pooled_output);

        return logits.MoveToOuterDisposeScope();
    }
}
  3. C# Loading Code: Attempt to load the converted .bin file into an instance of the model.
// File: Program.cs
using TorchSharp;
using static TorchSharp.torch;

public class Program
{
    /// <summary>
    /// Builds the C# Wav2Vec2ForSequenceClassification replica from hard-coded config values
    /// and loads converted weights into it with strict key matching.
    /// </summary>
    /// <param name="args">
    /// Optional. If present, args[0] is the path to the converted weights file. Otherwise the
    /// placeholder path below is used.
    /// </param>
    static void Main(string[] args)
    {
        Console.WriteLine("Attempting to load Wav2Vec2ForSequenceClassification model...");

        // Model configuration, transcribed from the checkpoint's config.json.
        var extractor_mode = torchaudio.models.FeatureExtractorNormMode.layer_norm;
        var extractor_conv_layer_config = new long[][] {
            new long[] { 512, 10, 5 }, new long[] { 512, 3, 2 }, new long[] { 512, 3, 2 },
            new long[] { 512, 3, 2 }, new long[] { 512, 3, 2 }, new long[] { 512, 2, 2 },
            new long[] { 512, 2, 2 }
        };
        var extractor_conv_bias = true;
        var encoder_embed_dim = 1024;
        var encoder_projection_dropout = 0.1;
        var encoder_pos_conv_kernel = 128;
        var encoder_pos_conv_groups = 16;
        var encoder_num_layers = 24;
        var encoder_num_heads = 16;
        var encoder_attention_dropout = 0.1;
        var encoder_ff_interm_features = 4096;
        var encoder_ff_interm_dropout = 0.0;
        var encoder_dropout = 0.1;
        var encoder_layer_norm_first = true;
        var encoder_layer_drop = 0.1;
        var classifier_proj_size = 256;
        var num_labels = 2;

        // Allow the weights path to be supplied on the command line instead of editing the source.
        var modelPath = args.Length > 0 ? args[0] : "path/to/your/converted_model.bin";

        // Fail fast with a clear message, rather than a generic exception from load().
        if (!System.IO.File.Exists(modelPath))
        {
            Console.WriteLine($"\nERROR: Model file not found: {modelPath}");
            Environment.ExitCode = 1;
            return;
        }

        try
        {
            // Module is IDisposable and owns native tensors, so dispose it deterministically.
            using var model = new Wav2Vec2ForSequenceClassificationSharp(
                extractor_mode, extractor_conv_layer_config, extractor_conv_bias,
                encoder_embed_dim, encoder_projection_dropout, encoder_pos_conv_kernel,
                encoder_pos_conv_groups, encoder_num_layers, encoder_num_heads,
                encoder_attention_dropout, encoder_ff_interm_features,
                encoder_ff_interm_dropout, encoder_dropout,
                encoder_layer_norm_first, encoder_layer_drop,
                classifier_proj_size, num_labels
            );

            // This is where the reported exception is thrown. strict: true requires the
            // file's entries to match this model's state_dict.
            model.load(modelPath, strict: true);

            // Leave training mode so the dropout probabilities configured above do not
            // apply at inference time.
            model.eval();

            Console.WriteLine("Model loaded successfully!");
        }
        catch (Exception ex)
        {
            Console.WriteLine("\nERROR: Failed to load model.");
            Console.WriteLine(ex.ToString());
            Environment.ExitCode = 1;
        }
    }
}

Environment

  • TorchSharp Version: TorchSharp-cpu v0.101.5 (or your version)
  • .NET Version: .NET 9.0
  • Operating System: Windows 10 x64
  • CPU/GPU: CPU-only

Thanks for your contribution! ⭐

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions