Conversation

@kamillobinski
Contributor

Some models within the MLX ecosystem, especially custom or older forks, define load_weights(weights) without a strict keyword argument. With recent updates, utils.load_model() passes strict=True unconditionally, which raises a TypeError when loading such models.

This PR inspects the load_weights signature at runtime using Python's inspect module and passes the strict parameter only if the method explicitly accepts it.

This allows:

  • full compatibility with legacy/custom models,
  • smooth usage of updated utils across projects,
  • forward compatibility with future models that may redefine the method.

No functionality is altered for models that already support strict; this simply makes loading more robust for models from various sources.
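
As a rough illustration of the signature check described above, here is a minimal sketch. It is not the PR's actual diff: the helper names accepts_strict and load_weights_compat and the two stub model classes are hypothetical and exist only to show the technique with the standard-library inspect module.

```python
import inspect


def accepts_strict(load_weights) -> bool:
    """Return True if the callable can take a `strict` keyword argument."""
    params = inspect.signature(load_weights).parameters
    if "strict" in params:
        return True
    # A **kwargs catch-all would also accept strict=...
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


def load_weights_compat(model, weights, strict=True):
    """Call model.load_weights, forwarding `strict` only when it is supported."""
    if accepts_strict(model.load_weights):
        return model.load_weights(weights, strict=strict)
    return model.load_weights(weights)


# Hypothetical stand-ins for real models, used only to exercise the helper.
class LegacyModel:
    def load_weights(self, weights):
        return ("legacy", weights)


class CurrentModel:
    def load_weights(self, weights, strict=True):
        return ("current", weights, strict)


legacy_result = load_weights_compat(LegacyModel(), [1])
current_result = load_weights_compat(CurrentModel(), [1], strict=False)
```

One trade-off of this approach: a model whose load_weights lacks strict is loaded without any weight validation, so a mismatch between weights and parameters can go unnoticed.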

Prevents TypeError when downstream models omit 'strict' kwarg. Compatible with both legacy and current MLX model APIs. Fixes silent breakages during migration or fork usage.
@Blaizzy
Owner

Blaizzy commented Apr 25, 2025

Hey Kamil

Thanks!

Could you share an example of a model that fails with strict=True?

@Blaizzy
Owner

Blaizzy commented Apr 25, 2025

Because strict here is meant to check that all model weights match the model's parameters properly

https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary/mlx.nn.Module.load_weights.html

@kamillobinski
Contributor Author

Hey! You are right, strict=True is valid per nn.Module.load_weights() and most models do support it. That said, some models override load_weights without carrying over the strict parameter. One example is:

# mlx-audio/tts/models/sesame/model.py

def load_weights(self, weights):
    self.model.load_weights(weights)

Calling utils.load_model(..., strict=True) for mlx-community/csm-1b-fp16 with that override raises:

TypeError: Model.load_weights() got an unexpected keyword argument 'strict'
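
The failure can be reproduced without MLX using plain Python classes. In this sketch, Inner and Model are hypothetical stand-ins that only mirror the shape of the override shown above, and try_strict_load is an illustrative helper, not part of any library.

```python
# Hypothetical stand-ins; these are not the real mlx-audio classes.
class Inner:
    def load_weights(self, weights, strict=True):
        return len(weights)


class Model:
    def __init__(self):
        self.model = Inner()

    # The override drops the `strict` parameter, as in the snippet above.
    def load_weights(self, weights):
        self.model.load_weights(weights)


def try_strict_load(model, weights):
    """Return the TypeError message from a strict=True call, or None on success."""
    try:
        model.load_weights(weights, strict=True)
    except TypeError as exc:
        return str(exc)
    return None


message = try_strict_load(Model(), [])
print(message)
```

The exact wording of the message depends on the Python version, but it names the unexpected 'strict' keyword argument.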

This patch guards that call using inspect to check if strict is accepted. It avoids runtime crashes and ensures downstream compatibility. A small, safe fallback.

Let me know if you'd prefer a different fix.

@lucasnewman
Collaborator

I think a quick tweak to the Sesame implementation of load_weights() that passes through the strict parameter would be best. It was just an oversight from when I originally added it, and we definitely want the parameter validation in general.

@kamillobinski
Contributor Author

kamillobinski commented Apr 25, 2025

Thanks Lucas, that's fair. Just pushed an update that adds strict to the Sesame model's load_weights implementation.
I kept the inspect fallback in utils.load_model as a precaution, but feel free to remove it, or let me know and I will revert those lines cleanly.
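
For reference, a pass-through override along the lines discussed might look like the sketch below. The merged change itself is not shown in this thread, so treat this as an assumption: Inner is a hypothetical stub whose strict check only loosely imitates weight validation, and Model mirrors the wrapper shape from the earlier snippet.

```python
# Hypothetical stub: strict mode rejects weights whose keys differ from the
# expected parameter names. This only loosely imitates real validation.
class Inner:
    def __init__(self, expected_keys):
        self.expected_keys = set(expected_keys)
        self.loaded = {}

    def load_weights(self, weights, strict=True):
        provided = dict(weights)
        if strict and set(provided) != self.expected_keys:
            raise ValueError("weights do not match model parameters")
        self.loaded = {k: v for k, v in provided.items() if k in self.expected_keys}
        return self


class Model:
    def __init__(self, expected_keys):
        self.model = Inner(expected_keys)

    def load_weights(self, weights, strict=True):
        # Forward `strict` so validation still happens in the wrapped model.
        return self.model.load_weights(weights, strict=strict)


model = Model(["a", "b"])
model.load_weights([("a", 1), ("b", 2)], strict=True)
```

With the parameter forwarded, a strict=True call from a loader reaches the wrapped model unchanged, and a partial set of weights is rejected rather than silently accepted.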

Appreciate the quick feedback!

@Blaizzy
Owner

Blaizzy commented Apr 26, 2025

Thanks @kamillobinski and @lucasnewman!

@Blaizzy Blaizzy merged commit 8652b35 into Blaizzy:main Apr 26, 2025
1 check failed
3 participants