Skip to content

Commit e50143f

Browse files
committed
reformatted strip_lcedequantize_test.py using black code style
1 parent 46c0071 commit e50143f

File tree

1 file changed

+31
-7
lines changed

1 file changed

+31
-7
lines changed

larq_compute_engine/tests/strip_lcedequantize_test.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,32 +11,56 @@
1111

1212
def toy_model_sign(**kwargs):
1313
img = tf.keras.layers.Input(shape=(224, 224, 3))
14-
x = lq.layers.QuantConv2D(256, kernel_size=3, strides=1, padding="same", pad_values=1, input_quantizer="ste_sign", kernel_quantizer="ste_sign", kernel_constraint="weight_clip")(img)
14+
x = lq.layers.QuantConv2D(
15+
256,
16+
kernel_size=3,
17+
strides=1,
18+
padding="same",
19+
pad_values=1,
20+
input_quantizer="ste_sign",
21+
kernel_quantizer="ste_sign",
22+
kernel_constraint="weight_clip",
23+
)(img)
1524
x = lq.quantizers.SteSign()(x)
16-
return tf.keras.Model(inputs = img, outputs = x)
25+
return tf.keras.Model(inputs=img, outputs=x)
26+
1727

1828
def quant(x):
1929
return tf.quantization.fake_quant_with_min_max_vars(x, -3.0, 3.0)
2030

31+
2132
def toy_model_int8_sign(**kwargs):
2233
img = tf.keras.layers.Input(shape=(224, 224, 3))
2334
x = quant(img)
24-
x = lq.layers.QuantConv2D(256, kernel_size=3, strides=1, padding="same", pad_values=1, input_quantizer="ste_sign", kernel_quantizer="ste_sign", kernel_constraint="weight_clip")(img)
35+
x = lq.layers.QuantConv2D(
36+
256,
37+
kernel_size=3,
38+
strides=1,
39+
padding="same",
40+
pad_values=1,
41+
input_quantizer="ste_sign",
42+
kernel_quantizer="ste_sign",
43+
kernel_constraint="weight_clip",
44+
)(img)
2545
x = lq.quantizers.SteSign()(x)
2646
x = quant(x)
27-
return tf.keras.Model(inputs = img, outputs = x)
47+
return tf.keras.Model(inputs=img, outputs=x)
2848

2949

3050
@pytest.mark.parametrize("model_cls", [toy_model_sign, toy_model_int8_sign])
3151
@pytest.mark.parametrize("inference_input_type", [tf.int8, tf.float32])
3252
@pytest.mark.parametrize("inference_output_type", [tf.int8, tf.float32])
33-
def test_strip_lcedequantize_ops(model_cls, inference_input_type, inference_output_type):
53+
def test_strip_lcedequantize_ops(
54+
model_cls, inference_input_type, inference_output_type
55+
):
3456
model_lce = convert_keras_model(
3557
model_cls(),
3658
inference_input_type=inference_input_type,
3759
inference_output_type=inference_output_type,
38-
experimental_default_int8_range=(-6.0, 6.0) if model_cls == toy_model_sign else None,
39-
experimental_enable_bitpacked_activations=True
60+
experimental_default_int8_range=(-6.0, 6.0)
61+
if model_cls == toy_model_sign
62+
else None,
63+
experimental_enable_bitpacked_activations=True,
4064
)
4165
model_lce = strip_lcedequantize_ops(model_lce)
4266
interpreter = Interpreter(model_lce)

0 commit comments

Comments
 (0)