-
Notifications
You must be signed in to change notification settings - Fork 6.2k
[Core] fix variant-identification. #9253
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
6b379a9
f155ec7
3f36e59
91253e8
dd5941e
564b8b4
fdd0435
d5cad9e
c0b1ceb
247dd93
b024a6d
fdfdc5f
dcf1852
3a71ad9
ab91852
aa631c5
453bfa5
11e4b71
dbdf0f9
671038a
57382f2
ea5ecdb
a510a9b
f583dad
dc0255a
f2ab3de
10baa9d
25ac01f
bac62ac
b6794ed
fcb4e39
4c0c5d2
0b1c2a6
8ad6b23
1190f7d
59cfefb
d72f5c1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -50,7 +50,6 @@ | |||
DEPRECATED_REVISION_ARGS, | ||||
BaseOutput, | ||||
PushToHubMixin, | ||||
deprecate, | ||||
is_accelerate_available, | ||||
is_accelerate_version, | ||||
is_torch_npu_available, | ||||
|
@@ -722,6 +721,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P | |||
) | ||||
else: | ||||
cached_folder = pretrained_model_name_or_path | ||||
filenames = [] | ||||
for _, _, files in os.walk(cached_folder): | ||||
for file in files: | ||||
filenames.append(os.path.basename(file)) | ||||
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) | ||||
if len(variant_filenames) == 0 and variant is not None: | ||||
error_message = ( | ||||
f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." | ||||
f" Available ones are: {model_filenames}." | ||||
) | ||||
raise ValueError(error_message) | ||||
|
||||
config_dict = cls.load_config(cached_folder) | ||||
|
||||
|
@@ -1239,6 +1250,15 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: | |||
model_info_call_error = e # save error to reraise it if model is not cached locally | ||||
|
||||
if not local_files_only: | ||||
filenames = {sibling.rfilename for sibling in info.siblings} | ||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) | ||||
if len(variant_filenames) == 0 and variant is not None: | ||||
error_message = ( | ||||
f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." | ||||
f" Available ones are: {model_filenames}." | ||||
) | ||||
raise ValueError(error_message) | ||||
|
||||
config_file = hf_hub_download( | ||||
pretrained_model_name, | ||||
cls.config_name, | ||||
|
@@ -1255,9 +1275,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: | |||
# retrieve all folder_names that contain relevant files | ||||
folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"] | ||||
|
||||
filenames = {sibling.rfilename for sibling in info.siblings} | ||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) | ||||
|
||||
Comment on lines
-1270
to
-1272
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was moved up to raise error earlier in code. |
||||
diffusers_module = importlib.import_module(__name__.split(".")[0]) | ||||
pipelines = getattr(diffusers_module, "pipelines") | ||||
|
||||
|
@@ -1279,15 +1296,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: | |||
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'." | ||||
) | ||||
|
||||
if len(variant_filenames) == 0 and variant is not None: | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let's not remove this error in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not an error, though. It's a deprecation. Do we exactly want to keep it that way? If so, we will have to remove it anyway because the deprecation is supposed to expire after "0.24.0" version. Instead, we are erroring out now from
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah got it. I think this should be resolved now. WDYT about catching these errors without having to download the actual files and leveraging This could live in a future PR.
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
deprecation_message = ( | ||||
f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." | ||||
f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`" | ||||
"if such variant modeling files are not available. Doing so will lead to an error in v0.24.0 as defaulting to non-variant" | ||||
"modeling files is deprecated." | ||||
) | ||||
deprecate("no variant default", "0.24.0", deprecation_message, standard_warn=False) | ||||
|
||||
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
# remove ignored filenames | ||||
model_filenames = set(model_filenames) - set(ignore_filenames) | ||||
variant_filenames = set(variant_filenames) - set(ignore_filenames) | ||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -655,7 +655,7 @@ def test_local_save_load_index(self): | |
out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="np").images | ||
|
||
with tempfile.TemporaryDirectory() as tmpdirname: | ||
pipe.save_pretrained(tmpdirname) | ||
pipe.save_pretrained(tmpdirname, variant=variant, safe_serialization=use_safe) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should have been serialized with |
||
pipe_2 = StableDiffusionPipeline.from_pretrained( | ||
tmpdirname, safe_serialization=use_safe, variant=variant | ||
) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1859,6 +1859,74 @@ def callback_increase_guidance(pipe, i, t, callback_kwargs): | |
# accounts for models that modify the number of inference steps based on strength | ||
assert pipe.guidance_scale == (inputs["guidance_scale"] + pipe.num_timesteps) | ||
|
||
def test_serialization_with_variants(self):
    """Check that saving a pipeline with a variant writes variant-tagged weight files.

    The pipeline is built from the dummy components, saved with `variant="fp16"` and
    `safe_serialization=False`, and then every saved sub-folder that belongs to an `nn.Module`
    component must (a) be a directory listed in `model_index.json` and (b) contain at least one
    file whose second dot-separated name segment starts with the variant.
    """
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    model_components = [
        component_name for component_name, component in pipe.components.items() if isinstance(component, nn.Module)
    ]
    variant = "fp16"

    def carries_variant(filename):
        # Guard against dot-less file names, for which indexing segment 1 would raise IndexError.
        segments = filename.split(".")
        return len(segments) > 1 and segments[1].startswith(variant)

    with tempfile.TemporaryDirectory() as tmpdir:
        pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)

        with open(f"{tmpdir}/model_index.json", "r") as f:
            config = json.load(f)

        for subfolder in os.listdir(tmpdir):
            # `os.listdir` yields bare entry names; resolve them against `tmpdir` so the
            # file/directory checks look at the saved entry and not at the current working directory.
            folder_path = os.path.join(tmpdir, subfolder)
            if not os.path.isfile(folder_path) and subfolder in model_components:
                is_folder = os.path.isdir(folder_path) and subfolder in config
                assert is_folder and any(carries_variant(p) for p in os.listdir(folder_path))
|
||
def test_loading_with_variants(self):
    """Round-trip the pipeline through a variant save/load and compare module parameters.

    The pipeline is saved with `variant="fp16"` (non-safetensors), reloaded with the same variant,
    and each `nn.Module` component's parameters are compared pairwise with `torch.equal`.
    Parameter pairs where both sides report NaNs are skipped.
    """
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    variant = "fp16"

    def is_nan(tensor):
        # 0-dim tensors are unwrapped with `.item()`; otherwise reduce with `.any()`.
        nan_mask = torch.isnan(tensor)
        return nan_mask.item() if tensor.ndimension() == 0 else nan_mask.any()

    def module_components(pipeline):
        # Keep only the components that are torch modules, keyed by component name.
        return {name: module for name, module in pipeline.components.items() if isinstance(module, nn.Module)}

    with tempfile.TemporaryDirectory() as tmpdir:
        pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)
        pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, variant=variant)

    model_components_pipe = module_components(pipe)
    model_components_pipe_loaded = module_components(pipe_loaded)

    for component_name, pipe_component in model_components_pipe.items():
        pipe_loaded_component = model_components_pipe_loaded[component_name]
        for p1, p2 in zip(pipe_component.parameters(), pipe_loaded_component.parameters()):
            # nan check for luminanext (mps).
            if is_nan(p1) and is_nan(p2):
                continue
            self.assertTrue(torch.equal(p1, p2))
|
||
def test_loading_with_incorrect_variants_raises_error(self): | ||
components = self.get_dummy_components() | ||
pipe = self.pipeline_class(**components) | ||
variant = "fp16" | ||
|
||
with tempfile.TemporaryDirectory() as tmpdir: | ||
# Don't save with variants. | ||
pipe.save_pretrained(tmpdir, safe_serialization=False) | ||
|
||
with self.assertRaises(ValueError) as error: | ||
_ = self.pipeline_class.from_pretrained(tmpdir, variant=variant) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This would have failed with the fixes from this PR rightfully complaining:
We didn't have it because we never tested it. But we should be all good now. |
||
|
||
assert f"You are trying to load the model files of the `variant={variant}`" in str(error.exception) | ||
|
||
def test_StableDiffusionMixin_component(self): | ||
"""Any pipeline that have LDMFuncMixin should have vae and unet components.""" | ||
if not issubclass(self.pipeline_class, StableDiffusionMixin): | ||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think maybe we should just update the
_identify_model_variants
function using variant_compatible_siblings
and it is still not able to load variants with shared checkpoints from pipeline level
i.e. we should be able to load the fp16 variant in the transformer folder too but it is currently not
you get
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @DN6 @a-r-r-o-w here too