Skip to content

Commit 10704e1

Browse files
[Test] Fix W2V-Conformer integration test (huggingface#17303)
* [Test] Fix W2V-Conformer integration test
* correct w2v2
* up
1 parent 28a0811 commit 10704e1

File tree

3 files changed

+8
-10
lines changed

3 files changed

+8
-10
lines changed

src/transformers/models/wav2vec2/modeling_wav2vec2.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1414,7 +1414,6 @@ def forward(
14141414
>>> from transformers import AutoFeatureExtractor, Wav2Vec2ForPreTraining
14151415
>>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
14161416
>>> from datasets import load_dataset
1417-
>>> import soundfile as sf
14181417
14191418
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
14201419
>>> model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")

src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1442,7 +1442,7 @@ def compute_contrastive_logits(
14421442

14431443
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
14441444
@replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
1445-
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2-base->wav2vec2-conformer-rel-pos-large,wav2vec2->wav2vec2_conformer
1445+
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,wav2vec2_conformer-base->wav2vec2-conformer-rel-pos-large
14461446
def forward(
14471447
self,
14481448
input_values: Optional[torch.Tensor],
@@ -1470,14 +1470,9 @@ def forward(
14701470
>>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
14711471
>>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import _compute_mask_indices
14721472
>>> from datasets import load_dataset
1473-
>>> import soundfile as sf
1474-
1475-
>>> feature_extractor = AutoFeatureExtractor.from_pretrained(
1476-
... "facebook/wav2vec2_conformer-conformer-rel-pos-large"
1477-
... )
1478-
>>> model = Wav2Vec2ConformerForPreTraining.from_pretrained(
1479-
... "facebook/wav2vec2_conformer-conformer-rel-pos-large"
1480-
... )
1473+
1474+
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
1475+
>>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
14811476
14821477
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
14831478
>>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1

tests/models/wav2vec2_conformer/test_modeling_wav2vec2_conformer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,10 @@ def _mock_init_weights(self, module):
581581
module.weight_v.data.fill_(3)
582582
if hasattr(module, "bias") and module.bias is not None:
583583
module.bias.data.fill_(3)
584+
if hasattr(module, "pos_bias_u") and module.pos_bias_u is not None:
585+
module.pos_bias_u.data.fill_(3)
586+
if hasattr(module, "pos_bias_v") and module.pos_bias_v is not None:
587+
module.pos_bias_v.data.fill_(3)
584588
if hasattr(module, "codevectors") and module.codevectors is not None:
585589
module.codevectors.data.fill_(3)
586590
if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:

0 commit comments

Comments (0)