@@ -346,6 +346,46 @@ def forward(x):
346
346
expected = forward (input_data )
347
347
self .assertTrue (mx .allclose (expected , out ))
348
348
349
+ def test_export_control_flow (self ):
350
+
351
+ def fun (x , y ):
352
+ if y .shape [0 ] <= 2 :
353
+ return x + y
354
+ else :
355
+ return x + 2 * y
356
+
357
+ for y in (mx .array ([1 , 2 , 3 ]), mx .array ([1 , 2 ])):
358
+ for shapeless in (True , False ):
359
+ with self .subTest (y = y , shapeless = shapeless ):
360
+ x = mx .array (1 )
361
+ export_path = os .path .join (self .test_dir , "control_flow.mlxfn" )
362
+ mx .export_function (export_path , fun , x , y , shapeless = shapeless )
363
+
364
+ imported_fn = mx .import_function (export_path )
365
+ self .assertTrue (mx .array_equal (imported_fn (x , y )[0 ], fun (x , y )))
366
+
367
+ def test_export_quantized_model (self ):
368
+ for shapeless in (True , False ):
369
+ with self .subTest (shapeless = shapeless ):
370
+ model = nn .Sequential (
371
+ nn .Linear (1024 , 512 ), nn .ReLU (), nn .Linear (512 , 1024 )
372
+ )
373
+ model .eval ()
374
+ mx .eval (model .parameters ())
375
+ input_data = mx .ones (shape = (512 , 1024 ))
376
+ nn .quantize (model )
377
+ self .assertTrue (isinstance (model .layers [0 ], nn .QuantizedLinear ))
378
+ self .assertTrue (isinstance (model .layers [2 ], nn .QuantizedLinear ))
379
+ mx .eval (model .parameters ())
380
+
381
+ export_path = os .path .join (self .test_dir , "quantized_linear.mlxfn" )
382
+ mx .export_function (export_path , model , input_data , shapeless = shapeless )
383
+
384
+ imported_fn = mx .import_function (export_path )
385
+ self .assertTrue (
386
+ mx .array_equal (imported_fn (input_data )[0 ], model (input_data ))
387
+ )
388
+
349
389
350
390
if __name__ == "__main__" :
351
391
mlx_tests .MLXTestRunner ()
0 commit comments