[Flax .from_pretrained] Raise a warning if model weights are not in float32 #16762
Conversation
The documentation is not available anymore as the PR was closed or merged.
LGTM!
```python
# dictionary of key: bool pairs marking whether each parameter is *not* in jnp.float32
param_dtypes = jax.tree_map(lambda x: x.dtype != jnp.float32, state)
# extract keys of parameters not in jnp.float32
downcast_params = [k for k in param_dtypes if param_dtypes[k]]
```
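To make the hunk above concrete, here is a toy, self-contained example; the `state` dict below is a hypothetical stand-in for the flattened model parameters and is not part of the PR:

```python
import jax
import jax.numpy as jnp

# hypothetical stand-in for the flattened `state` of model parameters
state = {
    ("encoder", "kernel"): jnp.ones((2, 2), dtype=jnp.float16),
    ("decoder", "kernel"): jnp.ones((2, 2), dtype=jnp.float32),
}

# True for every parameter that is *not* stored in jnp.float32
param_dtypes = jax.tree_map(lambda x: x.dtype != jnp.float32, state)
downcast_params = [k for k in param_dtypes if param_dtypes[k]]
print(downcast_params)  # [('encoder', 'kernel')]
```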
(nit) maybe name it as half_precision_params
I've modified the code to generate two separate lists: one for fp16 params and another for bf16 params. The warning message then specifies the erroneous dtype of the loaded model weights. I think this is more informative than simply saying the weights are in a dtype other than fp32.
As an example, the warning is produced when loading a set of PyTorch float16 Bart model weights into a FlaxBartForCausalLM model:

```python
from transformers import FlaxBartForCausalLM

model = FlaxBartForCausalLM.from_pretrained('sanchit-gandhi/tiny-random-bart-fp16', from_pt=True)
```
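For reference, a minimal sketch of how the two lists used in the hunk below could be derived; the helper name `split_half_precision_params` is hypothetical and this is an assumption about the approach rather than the merged implementation:

```python
import jax
import jax.numpy as jnp

def split_half_precision_params(state):
    # map every flattened parameter key to its dtype
    param_dtypes = jax.tree_map(lambda x: x.dtype, state)
    # collect the keys of parameters loaded in half precision, split by dtype
    fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16]
    bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16]
    return fp16_params, bf16_params
```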
```python
if len(fp16_params) > 0:
    logger.warning(
        f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from "
        f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n"
        "You should probably UPCAST the model weights to float32 if this was not intended. "
        "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
    )

if len(bf16_params) > 0:
    logger.warning(
        f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from "
        f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n"
        "You should probably UPCAST the model weights to float32 if this was not intended. "
        "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
    )
```
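For readers who hit one of these warnings, a short usage sketch of acting on it with `to_fp32`, reusing the tiny fp16 checkpoint from the example above (illustrative, not part of the PR diff):

```python
from transformers import FlaxBartForCausalLM

# loading an fp16 PyTorch checkpoint into a Flax model triggers the new warning
model = FlaxBartForCausalLM.from_pretrained("sanchit-gandhi/tiny-random-bart-fp16", from_pt=True)

# upcast the parameters to float32, as the warning recommends
model.params = model.to_fp32(model.params)
```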
Nice!
Sorry, this is a super nitty question, but I just wanted to ask to make sure we're all on the same page for best practice! Shouldn't one ideally merge their own PRs rather than the reviewer?
Aah, yes! One should merge their own PRs; I rushed this one a bit.
… float32 (huggingface#16762)
* [Flax] Raise a warning if model weights are not in float32
* apply suggestions and few small changes
* reorder wording for better readability
The Flax .from_pretrained method respects the dtype of the model weights from which it is loaded. For model weights stored in bfloat16/float16, Flax models are instantiated with parameter weights in bfloat16/float16 respectively (see #16736). The general assumption is that all Flax model weights are in float32. Loading and storing model weights in a lower precision (bfloat16/float16) is likely to lead to undesirable behaviour and model instabilities. This PR adds a warning to the .from_pretrained method should any of the model weights not be in float32, and advises the user to upcast the weights to float32 prior to use.
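To illustrate the behaviour described above, the dtypes of the loaded parameters can be inspected directly; a sketch reusing the tiny fp16 checkpoint from the conversation (the printed result is an expectation, assuming the checkpoint is stored in float16):

```python
import jax
from transformers import FlaxBartForCausalLM

model = FlaxBartForCausalLM.from_pretrained("sanchit-gandhi/tiny-random-bart-fp16", from_pt=True)

# collect the distinct dtypes of all parameter leaves
dtypes = {str(p.dtype) for p in jax.tree_util.tree_leaves(model.params)}
print(dtypes)  # expected: {'float16'} rather than {'float32'}
```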