|
11 | 11 |
|
12 | 12 | def toy_model_sign(**kwargs):
|
13 | 13 | img = tf.keras.layers.Input(shape=(224, 224, 3))
|
14 |
| - x = lq.layers.QuantConv2D(256, kernel_size=3, strides=1, padding="same", pad_values=1, input_quantizer="ste_sign", kernel_quantizer="ste_sign", kernel_constraint="weight_clip")(img) |
| 14 | + x = lq.layers.QuantConv2D( |
| 15 | + 256, |
| 16 | + kernel_size=3, |
| 17 | + strides=1, |
| 18 | + padding="same", |
| 19 | + pad_values=1, |
| 20 | + input_quantizer="ste_sign", |
| 21 | + kernel_quantizer="ste_sign", |
| 22 | + kernel_constraint="weight_clip", |
| 23 | + )(img) |
15 | 24 | x = lq.quantizers.SteSign()(x)
|
16 |
| - return tf.keras.Model(inputs = img, outputs = x) |
| 25 | + return tf.keras.Model(inputs=img, outputs=x) |
| 26 | + |
17 | 27 |
|
18 | 28 | def quant(x):
|
19 | 29 | return tf.quantization.fake_quant_with_min_max_vars(x, -3.0, 3.0)
|
20 | 30 |
|
| 31 | + |
21 | 32 | def toy_model_int8_sign(**kwargs):
|
22 | 33 | img = tf.keras.layers.Input(shape=(224, 224, 3))
|
23 | 34 | x = quant(img)
|
24 |
| - x = lq.layers.QuantConv2D(256, kernel_size=3, strides=1, padding="same", pad_values=1, input_quantizer="ste_sign", kernel_quantizer="ste_sign", kernel_constraint="weight_clip")(img) |
| 35 | + x = lq.layers.QuantConv2D( |
| 36 | + 256, |
| 37 | + kernel_size=3, |
| 38 | + strides=1, |
| 39 | + padding="same", |
| 40 | + pad_values=1, |
| 41 | + input_quantizer="ste_sign", |
| 42 | + kernel_quantizer="ste_sign", |
| 43 | + kernel_constraint="weight_clip", |
| 44 | + )(img) |
25 | 45 | x = lq.quantizers.SteSign()(x)
|
26 | 46 | x = quant(x)
|
27 |
| - return tf.keras.Model(inputs = img, outputs = x) |
| 47 | + return tf.keras.Model(inputs=img, outputs=x) |
28 | 48 |
|
29 | 49 |
|
30 | 50 | @pytest.mark.parametrize("model_cls", [toy_model_sign, toy_model_int8_sign])
|
31 | 51 | @pytest.mark.parametrize("inference_input_type", [tf.int8, tf.float32])
|
32 | 52 | @pytest.mark.parametrize("inference_output_type", [tf.int8, tf.float32])
|
33 |
| -def test_strip_lcedequantize_ops(model_cls, inference_input_type, inference_output_type): |
| 53 | +def test_strip_lcedequantize_ops( |
| 54 | + model_cls, inference_input_type, inference_output_type |
| 55 | +): |
34 | 56 | model_lce = convert_keras_model(
|
35 | 57 | model_cls(),
|
36 | 58 | inference_input_type=inference_input_type,
|
37 | 59 | inference_output_type=inference_output_type,
|
38 |
| - experimental_default_int8_range=(-6.0, 6.0) if model_cls == toy_model_sign else None, |
39 |
| - experimental_enable_bitpacked_activations=True |
| 60 | + experimental_default_int8_range=(-6.0, 6.0) |
| 61 | + if model_cls == toy_model_sign |
| 62 | + else None, |
| 63 | + experimental_enable_bitpacked_activations=True, |
40 | 64 | )
|
41 | 65 | model_lce = strip_lcedequantize_ops(model_lce)
|
42 | 66 | interpreter = Interpreter(model_lce)
|
|
0 commit comments