
[Flax] Torch fp16 model weights not upcast when loaded in Flax #16736

@sanchit-gandhi

Description


In some scenarios, one may want to load a Flax model directly from pre-trained PyTorch model weights. In this process, the original dtype of the PyTorch weights is maintained when they are loaded into Flax. For models such as bart-large, which has its PyTorch weights stored in fp16 on the Hub, this can result in a Flax model with weights in an undesirable dtype. This is highlighted by the following code snippet, which first loads a FlaxSpeechEncoderDecoderModel entirely from fp32 PyTorch weights, and then again from fp32 encoder weights and fp16 decoder weights:

from transformers import FlaxSpeechEncoderDecoderModel

# fp32 PyTorch weights
encoder_id = 'hf-internal-testing/tiny-random-wav2vec2'
decoder_id = 'hf-internal-testing/tiny-random-bart'

model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id, encoder_from_pt=True, decoder_from_pt=True)
print("-----------From fp32 PyTorch weights-----------")
print(f"Encoder dtype: {model.params['encoder']['masked_spec_embed'].dtype}")
print(f"Decoder dtype: {model.params['decoder']['model']['decoder']['embed_tokens']['embedding'].dtype}")

# same decoder as previously, but with weights downcasted to fp16
decoder_id = 'sanchit-gandhi/tiny-random-bart-fp16'

model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id, encoder_from_pt=True, decoder_from_pt=True)
print("---------From fp32/fp16 PyTorch weights---------")
print(f"Encoder dtype: {model.params['encoder']['masked_spec_embed'].dtype}")
print(f"Decoder dtype: {model.params['decoder']['model']['decoder']['embed_tokens']['embedding'].dtype}")

Output:

-----------From fp32 PyTorch weights-----------
Encoder dtype: float32
Decoder dtype: float32

---------From fp32/fp16 PyTorch weights---------
Encoder dtype: float32
Decoder dtype: float16

Having a model stored in two different dtypes raises issues with training - Optax optimisers expect the model to maintain one uniform dtype. Furthermore, the default assumption is that all Flax model weights are in fp32.
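A quick way to see whether the loaded params tree is mixed-precision is to collect the leaf dtypes (a minimal sketch, assuming the `model` loaded in the snippet above):

import jax

# collect the set of dtypes present in the loaded parameter tree
leaf_dtypes = {leaf.dtype for leaf in jax.tree_util.tree_leaves(model.params)}
print(leaf_dtypes)  # {dtype('float32'), dtype('float16')} in the mixed fp32/fp16 case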

This weight conversion is handled by the general conversion script: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_flax_pytorch_utils.py. Would it be wise to inform the user of the potentially erroneous model dtype in this scenario? If informed, they could then call the to_fp32 method from modeling_flax_utils to upcast the weights to fp32:

def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
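A minimal usage sketch, assuming the mixed-precision model loaded above: the full params tree is passed through to_fp32 and assigned back, after which all leaves are fp32:

# upcast every parameter to fp32 before training
model.params = model.to_fp32(model.params)
print(model.params['decoder']['model']['decoder']['embed_tokens']['embedding'].dtype)  # float32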
