-
Notifications
You must be signed in to change notification settings - Fork 213
Description
Describe the bug
Unable to load the weights for Wav2Vec2ForSequenceClassification model. Despite meticulously creating a C# class structure that mirrors the official Hugging Face PyTorch implementation, we consistently receive a System.ArgumentException: 'Mismatched state_dict sizes...' error upon calling model.load().
To Reproduce
-
Get the Model: The model is a fine-tuned Wav2Vec2 for voicemail detection from Bland AI, available on Hugging Face: blandai/wav2vec-vm-finetune. The model's config.json confirms its architecture is Wav2Vec2ForSequenceClassification.
-
C# Model Definition: Create the following C# classes, which are designed to be a direct replica of the Hugging Face Wav2Vec2ForSequenceClassification source code.
// File: Wav2Vec2ForSequenceClassificationSharp.cs
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;
using static TorchSharp.torch.nn;
/// <summary>
/// TorchSharp port of Hugging Face's Wav2Vec2ForSequenceClassification head:
/// a wav2vec2 acoustic backbone followed by a frame-wise projector and a
/// linear classifier over the time-pooled features.
/// Field names ("wav2vec2", "projector", "classifier") are chosen to line up
/// with the key prefixes expected in the saved state_dict.
/// </summary>
public class Wav2Vec2ForSequenceClassificationSharp : Module<Tensor, Tensor>
{
    private readonly Wav2Vec2Model wav2vec2; // acoustic encoder backbone
    private readonly Linear projector;       // encoder_embed_dim -> classifier_proj_size
    private readonly Linear classifier;      // classifier_proj_size -> num_labels

    /// <summary>
    /// Builds the backbone through torchaudio's factory (no auxiliary output
    /// head) plus the two linear layers, then registers all submodules so
    /// their parameters participate in load()/state_dict().
    /// </summary>
    public Wav2Vec2ForSequenceClassificationSharp(
        torchaudio.models.FeatureExtractorNormMode extractor_mode,
        long[][] extractor_conv_layer_config, bool extractor_conv_bias,
        int encoder_embed_dim, double encoder_projection_dropout,
        int encoder_pos_conv_kernel, int encoder_pos_conv_groups,
        int encoder_num_layers, int encoder_num_heads,
        double encoder_attention_dropout, int encoder_ff_interm_features,
        double encoder_ff_interm_dropout, double encoder_dropout,
        bool encoder_layer_norm_first, double encoder_layer_drop,
        int classifier_proj_size, int num_labels
    ) : base(nameof(Wav2Vec2ForSequenceClassificationSharp))
    {
        wav2vec2 = torchaudio.models.wav2vec2_model(
            extractor_mode, extractor_conv_layer_config, extractor_conv_bias,
            encoder_embed_dim, encoder_projection_dropout, encoder_pos_conv_kernel,
            encoder_pos_conv_groups, encoder_num_layers, encoder_num_heads,
            encoder_attention_dropout, encoder_ff_interm_features, encoder_ff_interm_dropout,
            encoder_dropout, encoder_layer_norm_first, encoder_layer_drop,
            aux_num_out: null);

        projector = Linear(encoder_embed_dim, classifier_proj_size);
        classifier = Linear(classifier_proj_size, num_labels);

        // NOTE: relies on reflection over the fields above; the field names
        // become the state_dict key prefixes.
        RegisterComponents();
    }

    /// <summary>
    /// Runs the backbone, projects every frame, mean-pools across the time
    /// axis (dim 1), and returns per-class logits.
    /// </summary>
    public override Tensor forward(Tensor input)
    {
        var (features, _) = wav2vec2.forward(input);
        var pooled = projector.forward(features).mean(new long[] { 1 });
        return classifier.forward(pooled);
    }
}
- C# Loading Code: Attempt to load the converted .bin file into an instance of the model.
// File: Program.cs
using TorchSharp;
using static TorchSharp.torch;
public class Program
{
    static void Main(string[] args)
    {
        Console.WriteLine("Attempting to load Wav2Vec2ForSequenceClassification model...");

        // Hyper-parameters transcribed from the model's config.json
        // (blandai/wav2vec-vm-finetune). Each conv entry: { dim, kernel, stride }.
        var extractor_mode = torchaudio.models.FeatureExtractorNormMode.layer_norm;
        var convLayerConfig = new long[][] {
            new long[] { 512, 10, 5 }, new long[] { 512, 3, 2 }, new long[] { 512, 3, 2 },
            new long[] { 512, 3, 2 }, new long[] { 512, 3, 2 }, new long[] { 512, 2, 2 },
            new long[] { 512, 2, 2 }
        };
        var convBias = true;
        var embedDim = 1024;
        var projectionDropout = 0.1;
        var posConvKernel = 128;
        var posConvGroups = 16;
        var numEncoderLayers = 24;
        var numHeads = 16;
        var attentionDropout = 0.1;
        var ffIntermFeatures = 4096;
        var ffIntermDropout = 0.0;
        var encoderDropout = 0.1;
        var layerNormFirst = true;
        var layerDrop = 0.1;
        var classifierProjSize = 256;
        var numLabels = 2;

        try
        {
            var model = new Wav2Vec2ForSequenceClassificationSharp(
                extractor_mode, convLayerConfig, convBias,
                embedDim, projectionDropout, posConvKernel,
                posConvGroups, numEncoderLayers, numHeads,
                attentionDropout, ffIntermFeatures,
                ffIntermDropout, encoderDropout,
                layerNormFirst, layerDrop,
                classifierProjSize, numLabels
            );

            var modelPath = "path/to/your/converted_model.bin";

            // With strict: true any missing/mismatched key raises; this call
            // is where the reported ArgumentException surfaces.
            model.load(modelPath, strict: true);
            Console.WriteLine("Model loaded successfully!");
        }
        catch (Exception ex)
        {
            Console.WriteLine("\nERROR: Failed to load model.");
            Console.WriteLine(ex.ToString());
        }
    }
}
Environment
- TorchSharp Version: TorchSharp-cpu v0.101.5
- .NET Version: .NET 9.0
- Operating System: Windows 10 x64
- CPU/GPU: CPU-only
Thanks for your contribution! ⭐