Skip to content

Commit 8652b35

Browse files
Handle strict argument in load_weights for better model compatibility (#94)
1 parent 8c1dc4a commit 8652b35

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

mlx_audio/tts/models/sesame/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,8 +344,8 @@ def sanitize(self, weights):
344344

345345
return sanitized_weights
346346

347-
def load_weights(self, weights):
348-
self.model.load_weights(weights)
347+
def load_weights(self, weights, strict: bool = True):
348+
self.model.load_weights(weights, strict=strict)
349349

350350
def generate(
351351
self,

mlx_audio/tts/utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import glob
22
import importlib
3+
import inspect
34
import logging
45
import shutil
56
from pathlib import Path
@@ -191,7 +192,13 @@ def get_class_predicate(p, m):
191192
class_predicate=get_class_predicate,
192193
)
193194

194-
model.load_weights(list(weights.items()), strict=strict)
195+
load_weights_sig = inspect.signature(model.load_weights)
196+
197+
kwargs = {"weights": list(weights.items())}
198+
if "strict" in load_weights_sig.parameters:
199+
kwargs["strict"] = strict
200+
201+
model.load_weights(**kwargs)
195202

196203
if not lazy:
197204
mx.eval(model.parameters())

0 commit comments

Comments
 (0)