
Conversation

sanchit-gandhi (Contributor)

The Flax .from_pretrained method respects the dtype of the checkpoint weights it loads: for model weights stored in bfloat16/float16, Flax models are instantiated with parameters in bfloat16/float16 respectively (see #16736). The general assumption, however, is that all Flax model weights are in float32; loading and storing model weights in a lower precision (bfloat16/float16) is likely to lead to undesirable behaviour and model instabilities. This PR adds a warning to the .from_pretrained method should any of the model weights not be in float32, and advises the user to upcast the weights to float32 prior to use.
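For illustration, a minimal sketch of the recommended upcast, using the to_fp32 helper referenced by the new warning (the workflow below is an assumption on my part, not part of this PR):

from transformers import FlaxBartForCausalLM

# loading a half-precision checkpoint triggers the new warning
model = FlaxBartForCausalLM.from_pretrained('sanchit-gandhi/tiny-random-bart-fp16', from_pt=True)

# upcast all parameters back to float32 before use, as the warning advises
model.params = model.to_fp32(model.params)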

sanchit-gandhi changed the title to [Flax .from_pretrained] Raise a warning if model weights are not in float32 on Apr 13, 2022
HuggingFaceDocBuilderDev commented Apr 13, 2022

The documentation is not available anymore as the PR was closed or merged.

patil-suraj (Contributor) left a comment:

LGTM!

# dictionary of key: bool flagging whether each parameter is *not* in jnp.float32
param_dtypes = jax.tree_map(lambda x: x.dtype != jnp.float32, state)
# extract keys of parameters not in jnp.float32
downcast_params = [k for k in param_dtypes if param_dtypes[k]]
Contributor:

(nit) maybe name it as half_precision_params

sanchit-gandhi (Contributor, Author):

I've modified the code to generate two separate lists: one for fp16 params and another for bf16 params. The warning message then specifies the erroneous dtype of the loaded model weights, which I think is more informative than simply saying the weights are in a dtype other than fp32.
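A minimal sketch of how that split could look (assuming state is the flattened parameter dict used inside from_pretrained; the names here are illustrative, not the exact PR code):

import jax
import jax.numpy as jnp

# map each flattened parameter key to its dtype
param_dtypes = jax.tree_map(lambda x: x.dtype, state)

# collect the keys of parameters loaded in half precision
fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16]
bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16]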

sanchit-gandhi (Contributor, Author) commented Apr 13, 2022

As an example, loading a set of PyTorch float16 Bart model weights into a FlaxBartForCausalLM model produces the following warning:

from transformers import FlaxBartForCausalLM
model = FlaxBartForCausalLM.from_pretrained('sanchit-gandhi/tiny-random-bart-fp16', from_pt=True)
Some of the weights of FlaxBartForCausalLM were initialized in float16 precision from the model checkpoint at sanchit-gandhi/tiny-random-bart-fp16:
[('model', 'decoder', 'embed_positions', 'embedding'), ('model', 'decoder', 'embed_tokens', 'embedding'), ('model', 'decoder', 'layernorm_embedding', 'bias'), ('model', 'decoder', 'layernorm_embedding', 'scale'), ('model', 'decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'bias'), ('model', 'decoder', 'layers', '0', 'encoder_attn', 'k_proj', 'kernel'), ('model', 'decoder', 'layers', '0', 'encoder_attn', 'out_proj', 'bias'), ('model', 'decoder', 'layers', '0', 'encoder_attn', 'out_proj', 'kernel'), ('model', 'decoder', 'layers', '0', 'encoder_attn', 'q_proj', 'bias'), ('model', 'decoder', 'layers', '0', 'encoder_attn', 'q_proj', 'kernel'), ('model', 'decoder', 'layers', '0', 'encoder_attn', 'v_proj', 'bias'), ('model', 'decoder', 'layers', '0', 'encoder_attn', 'v_proj', 'kernel'), ('model', 'decoder', 'layers', '0', 'encoder_attn_layer_norm', 'bias'), ('model', 'decoder', 'layers', '0', 'encoder_attn_layer_norm', 'scale'), ('model', 'decoder', 'layers', '0', 'fc1', 'bias'), ('model', 'decoder', 'layers', '0', 'fc1', 'kernel'), ('model', 'decoder', 'layers', '0', 'fc2', 'bias'), ('model', 'decoder', 'layers', '0', 'fc2', 'kernel'), ('model', 'decoder', 'layers', '0', 'final_layer_norm', 'bias'), ('model', 'decoder', 'layers', '0', 'final_layer_norm', 'scale'), ('model', 'decoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('model', 'decoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('model', 'decoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('model', 'decoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('model', 'decoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('model', 'decoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('model', 'decoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('model', 'decoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('model', 'decoder', 'layers', '0', 'self_attn_layer_norm', 'bias'), ('model', 'decoder', 'layers', '0', 'self_attn_layer_norm', 'scale'), ('model', 'decoder', 'layers', '1', 'encoder_attn', 'k_proj', 'bias'), ('model', 'decoder', 'layers', '1', 'encoder_attn', 'k_proj', 'kernel'), ('model', 'decoder', 'layers', '1', 'encoder_attn', 'out_proj', 'bias'), ('model', 'decoder', 'layers', '1', 'encoder_attn', 'out_proj', 'kernel'), ('model', 'decoder', 'layers', '1', 'encoder_attn', 'q_proj', 'bias'), ('model', 'decoder', 'layers', '1', 'encoder_attn', 'q_proj', 'kernel'), ('model', 'decoder', 'layers', '1', 'encoder_attn', 'v_proj', 'bias'), ('model', 'decoder', 'layers', '1', 'encoder_attn', 'v_proj', 'kernel'), ('model', 'decoder', 'layers', '1', 'encoder_attn_layer_norm', 'bias'), ('model', 'decoder', 'layers', '1', 'encoder_attn_layer_norm', 'scale'), ('model', 'decoder', 'layers', '1', 'fc1', 'bias'), ('model', 'decoder', 'layers', '1', 'fc1', 'kernel'), ('model', 'decoder', 'layers', '1', 'fc2', 'bias'), ('model', 'decoder', 'layers', '1', 'fc2', 'kernel'), ('model', 'decoder', 'layers', '1', 'final_layer_norm', 'bias'), ('model', 'decoder', 'layers', '1', 'final_layer_norm', 'scale'), ('model', 'decoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('model', 'decoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('model', 'decoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('model', 'decoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('model', 'decoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('model', 'decoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('model', 'decoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('model', 'decoder', 
'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('model', 'decoder', 'layers', '1', 'self_attn_layer_norm', 'bias'), ('model', 'decoder', 'layers', '1', 'self_attn_layer_norm', 'scale')]
You should probably UPCAST the model weights to float32 if this was not intended. See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this.
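After upcasting as the warning suggests, a quick sanity check can confirm that no half-precision leaves remain (a sketch, not part of the PR):

import jax
import jax.numpy as jnp

# True only once every parameter leaf has been upcast to float32
assert all(
    leaf.dtype == jnp.float32 for leaf in jax.tree_util.tree_leaves(model.params)
)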

Comment on lines +667 to +682
if len(fp16_params) > 0:
    logger.warning(
        f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from "
        f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n"
        "You should probably UPCAST the model weights to float32 if this was not intended. "
        "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
    )

if len(bf16_params) > 0:
    logger.warning(
        f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from "
        f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n"
        "You should probably UPCAST the model weights to float32 if this was not intended. "
        "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
    )

Contributor:

Nice!

patil-suraj merged commit d8269eb into huggingface:main on Apr 14, 2022
sanchit-gandhi (Contributor, Author) commented Apr 14, 2022

Sorry, this is a super nitty question, but I just wanted to ask to make sure we're all on the same page for best practice! Shouldn't one ideally merge their own PRs, rather than the reviewer doing so?

patil-suraj (Contributor):

Aah, yes! One should merge their own PRs; I rushed this one a bit.

elusenji pushed a commit to elusenji/transformers that referenced this pull request on Jun 12, 2022: … float32 (huggingface#16762)

* [Flax] Raise a warning if model weights are not in float32

* apply suggestions and few small changes

* reorder wording for better readability
sanchit-gandhi deleted the flax-from-pretrained branch on June 25, 2023 at 09:52