Commit 28a0811

Improve mismatched sizes management when loading a pretrained model (huggingface#17257)
- Add `--ignore_mismatched_sizes` argument to classification examples
- Expand the error message when loading a model whose head dimensions are different from expected dimensions
Parent: 1f13ba8
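In practice, the change means a checkpoint whose classification head does not match the requested number of labels can be re-initialized at load time instead of failing. A minimal sketch of the behavior (the checkpoint name and label count here are illustrative, not part of this commit):

```python
from transformers import AutoModelForSequenceClassification

# This SST-2 checkpoint ships with a 2-label head; requesting 5 labels makes
# the head weights mismatch the checkpoint. With ignore_mismatched_sizes=True,
# from_pretrained keeps the backbone weights and freshly initializes the head
# instead of raising a size-mismatch error.
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english",
    num_labels=5,
    ignore_mismatched_sizes=True,
)
```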

File tree

13 files changed (+64, −9 lines)


examples/pytorch/audio-classification/README.md

Lines changed: 9 additions & 7 deletions

````diff
@@ -18,13 +18,13 @@ limitations under the License.
 
 The following examples showcase how to fine-tune `Wav2Vec2` for audio classification using PyTorch.
 
-Speech recognition models that have been pretrained in unsupervised fashion on audio data alone,
-*e.g.* [Wav2Vec2](https://huggingface.co/transformers/main/model_doc/wav2vec2.html),
-[HuBERT](https://huggingface.co/transformers/main/model_doc/hubert.html),
-[XLSR-Wav2Vec2](https://huggingface.co/transformers/main/model_doc/xlsr_wav2vec2.html), have shown to require only
+Speech recognition models that have been pretrained in unsupervised fashion on audio data alone,
+*e.g.* [Wav2Vec2](https://huggingface.co/transformers/main/model_doc/wav2vec2.html),
+[HuBERT](https://huggingface.co/transformers/main/model_doc/hubert.html),
+[XLSR-Wav2Vec2](https://huggingface.co/transformers/main/model_doc/xlsr_wav2vec2.html), have shown to require only
 very little annotated data to yield good performance on speech classification datasets.
 
-## Single-GPU
+## Single-GPU
 
 The following command shows how to fine-tune [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) on the 🗣️ [Keyword Spotting subset](https://huggingface.co/datasets/superb#ks) of the SUPERB dataset.
 
@@ -63,7 +63,9 @@ On a single V100 GPU (16GB), this script should run in ~14 minutes and yield acc
 
 👀 See the results here: [anton-l/wav2vec2-base-ft-keyword-spotting](https://huggingface.co/anton-l/wav2vec2-base-ft-keyword-spotting)
 
-## Multi-GPU
+> If your model classification head dimensions do not fit the number of labels in the dataset, you can specify `--ignore_mismatched_sizes` to adapt it.
+
+## Multi-GPU
 
 The following command shows how to fine-tune [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) for 🌎 **Language Identification** on the [CommonLanguage dataset](https://huggingface.co/datasets/anton-l/common_language).
 
@@ -139,7 +141,7 @@ It has been verified that the script works for the following datasets:
 
 | Dataset | Pretrained Model | # transformer layers | Accuracy on eval | GPU setup | Training time | Fine-tuned Model & Logs |
 |---------|------------------|----------------------|------------------|-----------|---------------|--------------------------|
-| Keyword Spotting | [ntu-spml/distilhubert](https://huggingface.co/ntu-spml/distilhubert) | 2 | 0.9706 | 1 V100 GPU | 11min | [here](https://huggingface.co/anton-l/distilhubert-ft-keyword-spotting) |
+| Keyword Spotting | [ntu-spml/distilhubert](https://huggingface.co/ntu-spml/distilhubert) | 2 | 0.9706 | 1 V100 GPU | 11min | [here](https://huggingface.co/anton-l/distilhubert-ft-keyword-spotting) |
 | Keyword Spotting | [facebook/wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) | 12 | 0.9826 | 1 V100 GPU | 14min | [here](https://huggingface.co/anton-l/wav2vec2-base-ft-keyword-spotting) |
 | Keyword Spotting | [facebook/hubert-base-ls960](https://huggingface.co/facebook/hubert-base-ls960) | 12 | 0.9819 | 1 V100 GPU | 14min | [here](https://huggingface.co/anton-l/hubert-base-ft-keyword-spotting) |
 | Keyword Spotting | [asapp/sew-mid-100k](https://huggingface.co/asapp/sew-mid-100k) | 24 | 0.9757 | 1 V100 GPU | 15min | [here](https://huggingface.co/anton-l/sew-mid-100k-ft-keyword-spotting) |
````

examples/pytorch/audio-classification/run_audio_classification.py

Lines changed: 5 additions & 0 deletions

````diff
@@ -163,6 +163,10 @@ class ModelArguments:
     freeze_feature_extractor: Optional[bool] = field(
         default=None, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
     )
+    ignore_mismatched_sizes: bool = field(
+        default=False,
+        metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
+    )
 
     def __post_init__(self):
         if not self.freeze_feature_extractor and self.freeze_feature_encoder:
@@ -333,6 +337,7 @@ def compute_metrics(eval_pred):
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         use_auth_token=True if model_args.use_auth_token else None,
+        ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
     )
 
     # freeze the convolutional waveform encoder
````
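The Trainer-based scripts surface the new field through `HfArgumentParser`, which turns dataclass fields into command-line flags. A self-contained sketch of that mechanism, simplified from the diff above:

```python
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class ModelArguments:
    ignore_mismatched_sizes: bool = field(
        default=False,
        metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
    )


# HfArgumentParser exposes each dataclass field as a flag, so passing
# --ignore_mismatched_sizes on the command line flips the False default.
parser = HfArgumentParser(ModelArguments)
(model_args,) = parser.parse_args_into_dataclasses(["--ignore_mismatched_sizes"])
print(model_args.ignore_mismatched_sizes)  # True
```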

examples/pytorch/image-classification/README.md

Lines changed: 3 additions & 1 deletion

````diff
@@ -62,9 +62,11 @@ python run_image_classification.py \
 
 Note that you can replace the model and dataset by simply setting the `model_name_or_path` and `dataset_name` arguments respectively, with any model or dataset from the [hub](https://huggingface.co/). For an overview of all possible arguments, we refer to the [docs](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments) of the `TrainingArguments`, which can be passed as flags.
 
+> If your model classification head dimensions do not fit the number of labels in the dataset, you can specify `--ignore_mismatched_sizes` to adapt it.
+
 ### Using your own data
 
-To use your own dataset, there are 2 ways:
+To use your own dataset, there are 2 ways:
 - you can either provide your own folders as `--train_dir` and/or `--validation_dir` arguments
 - you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument.
 
````

examples/pytorch/image-classification/run_image_classification.py

Lines changed: 5 additions & 0 deletions

````diff
@@ -150,6 +150,10 @@ class ModelArguments:
             )
         },
     )
+    ignore_mismatched_sizes: bool = field(
+        default=False,
+        metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
+    )
 
 
 def collate_fn(examples):
@@ -269,6 +273,7 @@ def compute_metrics(p):
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         use_auth_token=True if model_args.use_auth_token else None,
+        ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
     )
     feature_extractor = AutoFeatureExtractor.from_pretrained(
         model_args.feature_extractor_name or model_args.model_name_or_path,
````

examples/pytorch/image-classification/run_image_classification_no_trainer.py

Lines changed: 6 additions & 0 deletions

````diff
@@ -165,6 +165,11 @@ def parse_args():
         action="store_true",
         help="Whether to load in all available experiment trackers from the environment and use them for logging.",
     )
+    parser.add_argument(
+        "--ignore_mismatched_sizes",
+        action="store_true",
+        help="Whether or not to enable to load a pretrained model whose head dimensions are different.",
+    )
     args = parser.parse_args()
 
     # Sanity checks
@@ -278,6 +283,7 @@ def main():
         args.model_name_or_path,
         from_tf=bool(".ckpt" in args.model_name_or_path),
         config=config,
+        ignore_mismatched_sizes=args.ignore_mismatched_sizes,
     )
 
     # Preprocessing the datasets
````
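The `no_trainer` variants take the plain `argparse` route instead of a dataclass: the flag is an ordinary `store_true` switch that defaults to `False`. A standalone sketch of the pattern added above:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--ignore_mismatched_sizes",
    action="store_true",  # absent -> False, present -> True
    help="Whether or not to enable to load a pretrained model whose head dimensions are different.",
)

args = parser.parse_args(["--ignore_mismatched_sizes"])
assert args.ignore_mismatched_sizes is True
```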

examples/pytorch/text-classification/README.md

Lines changed: 3 additions & 1 deletion

````diff
@@ -22,7 +22,7 @@ Based on the script [`run_glue.py`](https://github.com/huggingface/transformers/
 
 Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding
 Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models)
-and can also be used for a dataset hosted on our [hub](https://huggingface.co/datasets) or your own data in a csv or a JSON file
+and can also be used for a dataset hosted on our [hub](https://huggingface.co/datasets) or your own data in a csv or a JSON file
 (the script might need some tweaks in that case, refer to the comments inside for help).
 
 GLUE is made up of a total of 9 different tasks. Here is how to run the script on one of them:
@@ -79,6 +79,8 @@ python run_glue.py \
 --output_dir /tmp/imdb/
 ```
 
+> If your model classification head dimensions do not fit the number of labels in the dataset, you can specify `--ignore_mismatched_sizes` to adapt it.
+
 
 ### Mixed precision training
````

examples/pytorch/text-classification/run_glue.py

Lines changed: 5 additions & 0 deletions

````diff
@@ -196,6 +196,10 @@ class ModelArguments:
             )
         },
     )
+    ignore_mismatched_sizes: bool = field(
+        default=False,
+        metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
+    )
 
 
 def main():
@@ -364,6 +368,7 @@ def main():
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         use_auth_token=True if model_args.use_auth_token else None,
+        ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
     )
 
     # Preprocessing the raw_datasets
````

examples/pytorch/text-classification/run_glue_no_trainer.py

Lines changed: 6 additions & 0 deletions

````diff
@@ -170,6 +170,11 @@ def parse_args():
         action="store_true",
         help="Whether to load in all available experiment trackers from the environment and use them for logging.",
     )
+    parser.add_argument(
+        "--ignore_mismatched_sizes",
+        action="store_true",
+        help="Whether or not to enable to load a pretrained model whose head dimensions are different.",
+    )
     args = parser.parse_args()
 
     # Sanity checks
@@ -288,6 +293,7 @@ def main():
         args.model_name_or_path,
         from_tf=bool(".ckpt" in args.model_name_or_path),
         config=config,
+        ignore_mismatched_sizes=args.ignore_mismatched_sizes,
     )
 
     # Preprocessing the datasets
````

examples/pytorch/text-classification/run_xnli.py

Lines changed: 5 additions & 0 deletions

````diff
@@ -162,6 +162,10 @@ class ModelArguments:
             )
         },
     )
+    ignore_mismatched_sizes: bool = field(
+        default=False,
+        metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
+    )
 
 
 def main():
@@ -291,6 +295,7 @@ def main():
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         use_auth_token=True if model_args.use_auth_token else None,
+        ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
     )
 
     # Preprocessing the datasets
````

examples/pytorch/token-classification/README.md

Lines changed: 2 additions & 0 deletions

````diff
@@ -55,6 +55,8 @@ uses special features of those tokenizers. You can check if your favorite model
 [this table](https://huggingface.co/transformers/index.html#supported-frameworks), if it doesn't you can still use the old version
 of the script.
 
+> If your model classification head dimensions do not fit the number of labels in the dataset, you can specify `--ignore_mismatched_sizes` to adapt it.
+
 ## Old version of the script
 
 You can find the old version of the PyTorch script [here](https://github.com/huggingface/transformers/blob/main/examples/legacy/token-classification/run_ner.py).
````
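The same escape hatch applies to token classification when reusing a checkpoint trained on a different label set. A sketch under assumed inputs (the `dslim/bert-base-NER` checkpoint and the 5-label target scheme are illustrative choices, not taken from the commit):

```python
from transformers import AutoConfig, AutoModelForTokenClassification

# dslim/bert-base-NER was fine-tuned with a 9-label NER head; requesting
# num_labels=5 produces a head-size mismatch, and ignore_mismatched_sizes=True
# swaps in a freshly initialized 5-way classifier while keeping the
# pretrained encoder weights.
config = AutoConfig.from_pretrained("dslim/bert-base-NER", num_labels=5)
model = AutoModelForTokenClassification.from_pretrained(
    "dslim/bert-base-NER",
    config=config,
    ignore_mismatched_sizes=True,
)
print(model.classifier.out_features)  # 5
```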
