Skip to content

Commit 8b25ce6

Browse files
authored
Add tests for export including control flow models and quantized models (#2430)
* Add tests for export, including control flow export and quantized model export. * Skip quantization related test for CUDA backend.
1 parent da5912e commit 8b25ce6

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

python/tests/cuda_skip.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,4 +74,5 @@
7474
"TestQuantized.test_small_matrix",
7575
"TestQuantized.test_throw",
7676
"TestQuantized.test_vjp_scales_biases",
77+
"TestExportImport.test_export_quantized_model",
7778
}

python/tests/test_export_import.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,46 @@ def forward(x):
346346
expected = forward(input_data)
347347
self.assertTrue(mx.allclose(expected, out))
348348

349+
def test_export_control_flow(self):
350+
351+
def fun(x, y):
352+
if y.shape[0] <= 2:
353+
return x + y
354+
else:
355+
return x + 2 * y
356+
357+
for y in (mx.array([1, 2, 3]), mx.array([1, 2])):
358+
for shapeless in (True, False):
359+
with self.subTest(y=y, shapeless=shapeless):
360+
x = mx.array(1)
361+
export_path = os.path.join(self.test_dir, "control_flow.mlxfn")
362+
mx.export_function(export_path, fun, x, y, shapeless=shapeless)
363+
364+
imported_fn = mx.import_function(export_path)
365+
self.assertTrue(mx.array_equal(imported_fn(x, y)[0], fun(x, y)))
366+
367+
def test_export_quantized_model(self):
368+
for shapeless in (True, False):
369+
with self.subTest(shapeless=shapeless):
370+
model = nn.Sequential(
371+
nn.Linear(1024, 512), nn.ReLU(), nn.Linear(512, 1024)
372+
)
373+
model.eval()
374+
mx.eval(model.parameters())
375+
input_data = mx.ones(shape=(512, 1024))
376+
nn.quantize(model)
377+
self.assertTrue(isinstance(model.layers[0], nn.QuantizedLinear))
378+
self.assertTrue(isinstance(model.layers[2], nn.QuantizedLinear))
379+
mx.eval(model.parameters())
380+
381+
export_path = os.path.join(self.test_dir, "quantized_linear.mlxfn")
382+
mx.export_function(export_path, model, input_data, shapeless=shapeless)
383+
384+
imported_fn = mx.import_function(export_path)
385+
self.assertTrue(
386+
mx.array_equal(imported_fn(input_data)[0], model(input_data))
387+
)
388+
349389

350390
if __name__ == "__main__":
351391
mlx_tests.MLXTestRunner()

0 commit comments

Comments
 (0)