Skip to content

Commit 047560f

Browse files
authored
Fix Keras3 Issues in TF 2.16.1 for 3.0 new API (#1669)
Signed-off-by: zehao-intel <[email protected]>
1 parent 62aa85d commit 047560f

File tree

8 files changed

+435
-223
lines changed

8 files changed

+435
-223
lines changed

neural_compressor/tensorflow/keras/layers/conv2d.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,11 @@
2323

2424
from neural_compressor.tensorflow.utils import version1_gte_version2
2525

26-
if version1_gte_version2(tf.__version__, "2.13.0"):
26+
if version1_gte_version2(tf.__version__, "2.16.1"):
27+
from keras.src.layers.convolutional.base_conv import BaseConv # pylint: disable=E0401
28+
29+
Conv = BaseConv
30+
elif version1_gte_version2(tf.__version__, "2.13.0"):
2731
from keras.src.layers.convolutional.base_conv import Conv # pylint: disable=E0401
2832
else:
2933
from keras.layers.convolutional.base_conv import Conv # pylint: disable=E0401

neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py

Lines changed: 191 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -23,117 +23,202 @@
2323

2424
from neural_compressor.tensorflow.utils import version1_gte_version2
2525

26-
if version1_gte_version2(tf.__version__, "2.13.0"):
26+
if version1_gte_version2(tf.__version__, "2.16.1"):
27+
from keras.src import ops
28+
from keras.src.layers.convolutional.base_depthwise_conv import BaseDepthwiseConv # pylint: disable=E0401
29+
elif version1_gte_version2(tf.__version__, "2.13.0"):
2730
from keras.src.layers.convolutional.base_depthwise_conv import DepthwiseConv # pylint: disable=E0401
2831
from keras.src.utils import conv_utils, tf_utils # pylint: disable=E0401
2932
else:
3033
from keras.layers.convolutional.base_depthwise_conv import DepthwiseConv # pylint: disable=E0401
3134
from keras.utils import conv_utils, tf_utils # pylint: disable=E0401
3235

36+
if version1_gte_version2(tf.__version__, "2.16.1"):
3337

34-
class QDepthwiseConv2D(DepthwiseConv):
35-
def __init__(
36-
self,
37-
kernel_size,
38-
min_value,
39-
max_value,
40-
strides=(1, 1),
41-
padding="valid",
42-
depth_multiplier=1,
43-
data_format=None,
44-
dilation_rate=(1, 1),
45-
activation=None,
46-
use_bias=True,
47-
depthwise_initializer="glorot_uniform",
48-
bias_initializer="zeros",
49-
depthwise_regularizer=None,
50-
bias_regularizer=None,
51-
activity_regularizer=None,
52-
depthwise_constraint=None,
53-
bias_constraint=None,
54-
**kwargs
55-
):
56-
super().__init__(
57-
2,
58-
kernel_size=kernel_size,
59-
strides=strides,
60-
padding=padding,
61-
depth_multiplier=depth_multiplier,
62-
data_format=data_format,
63-
dilation_rate=dilation_rate,
64-
activation=activation,
65-
use_bias=use_bias,
66-
depthwise_initializer=depthwise_initializer,
67-
bias_initializer=bias_initializer,
68-
depthwise_regularizer=depthwise_regularizer,
69-
bias_regularizer=bias_regularizer,
70-
activity_regularizer=activity_regularizer,
71-
depthwise_constraint=depthwise_constraint,
72-
bias_constraint=bias_constraint,
38+
class QDepthwiseConv2D(BaseDepthwiseConv):
39+
def __init__(
40+
self,
41+
kernel_size,
42+
min_value,
43+
max_value,
44+
strides=(1, 1),
45+
padding="valid",
46+
depth_multiplier=1,
47+
data_format=None,
48+
dilation_rate=(1, 1),
49+
activation=None,
50+
use_bias=True,
51+
depthwise_initializer="glorot_uniform",
52+
bias_initializer="zeros",
53+
depthwise_regularizer=None,
54+
bias_regularizer=None,
55+
activity_regularizer=None,
56+
depthwise_constraint=None,
57+
bias_constraint=None,
7358
**kwargs
74-
)
75-
self.min_value = json.loads(min_value)
76-
self.max_value = json.loads(max_value)
77-
78-
def call(self, inputs):
79-
# add the Q/DQ here
80-
kernel, _, _ = quantization.quantize(
81-
self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED"
82-
)
83-
kernel = quantization.dequantize(
84-
kernel,
85-
self.min_value,
86-
self.max_value,
87-
axis=3,
88-
mode="SCALED",
89-
)
90-
outputs = tf.keras.backend.depthwise_conv2d(
91-
inputs,
92-
kernel,
93-
strides=self.strides,
94-
padding=self.padding,
95-
data_format=self.data_format,
96-
dilation_rate=self.dilation_rate,
97-
)
98-
99-
if self.use_bias:
100-
outputs = tf.keras.backend.bias_add(outputs, self.bias, data_format=self.data_format)
101-
102-
if self.activation is not None:
103-
return self.activation(outputs)
104-
105-
return outputs
106-
107-
@classmethod
108-
def from_config(cls, config):
109-
return cls(**config)
110-
111-
@tf_utils.shape_type_conversion
112-
def compute_output_shape(self, input_shape):
113-
if self.data_format == "channels_first":
114-
rows = input_shape[2]
115-
cols = input_shape[3]
116-
out_filters = input_shape[1] * self.depth_multiplier
117-
elif self.data_format == "channels_last":
118-
rows = input_shape[1]
119-
cols = input_shape[2]
120-
out_filters = input_shape[3] * self.depth_multiplier
121-
122-
rows = conv_utils.conv_output_length(
123-
rows,
124-
self.kernel_size[0],
125-
self.padding,
126-
self.strides[0],
127-
self.dilation_rate[0],
128-
)
129-
cols = conv_utils.conv_output_length(
130-
cols,
131-
self.kernel_size[1],
132-
self.padding,
133-
self.strides[1],
134-
self.dilation_rate[1],
135-
)
136-
if self.data_format == "channels_first":
137-
return (input_shape[0], out_filters, rows, cols)
138-
elif self.data_format == "channels_last":
139-
return (input_shape[0], rows, cols, out_filters)
59+
):
60+
super().__init__(
61+
2,
62+
kernel_size=kernel_size,
63+
strides=strides,
64+
padding=padding,
65+
depth_multiplier=depth_multiplier,
66+
data_format=data_format,
67+
dilation_rate=dilation_rate,
68+
activation=activation,
69+
use_bias=use_bias,
70+
depthwise_initializer=depthwise_initializer,
71+
bias_initializer=bias_initializer,
72+
depthwise_regularizer=depthwise_regularizer,
73+
bias_regularizer=bias_regularizer,
74+
activity_regularizer=activity_regularizer,
75+
depthwise_constraint=depthwise_constraint,
76+
bias_constraint=bias_constraint,
77+
**kwargs
78+
)
79+
self.min_value = json.loads(min_value)
80+
self.max_value = json.loads(max_value)
81+
82+
def call(self, inputs):
83+
# add the Q/DQ here
84+
kernel, _, _ = quantization.quantize(
85+
self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED"
86+
)
87+
kernel = quantization.dequantize(
88+
kernel,
89+
self.min_value,
90+
self.max_value,
91+
axis=3,
92+
mode="SCALED",
93+
)
94+
95+
input_channel = self._get_input_channel(inputs.shape)
96+
outputs = ops.depthwise_conv(
97+
inputs,
98+
self.kernel,
99+
strides=self.strides,
100+
padding=self.padding,
101+
dilation_rate=self.dilation_rate,
102+
data_format=self.data_format,
103+
)
104+
105+
if self.use_bias:
106+
if self.data_format == "channels_last":
107+
bias_shape = (1,) * (self.rank + 1) + (self.depth_multiplier * input_channel,)
108+
else:
109+
bias_shape = (1, self.depth_multiplier * input_channel) + (1,) * self.rank
110+
bias = ops.reshape(self.bias, bias_shape)
111+
outputs += bias
112+
113+
if self.activation is not None:
114+
return self.activation(outputs)
115+
return outputs
116+
117+
else:
118+
119+
class QDepthwiseConv2D(DepthwiseConv):
120+
def __init__(
121+
self,
122+
kernel_size,
123+
min_value,
124+
max_value,
125+
strides=(1, 1),
126+
padding="valid",
127+
depth_multiplier=1,
128+
data_format=None,
129+
dilation_rate=(1, 1),
130+
activation=None,
131+
use_bias=True,
132+
depthwise_initializer="glorot_uniform",
133+
bias_initializer="zeros",
134+
depthwise_regularizer=None,
135+
bias_regularizer=None,
136+
activity_regularizer=None,
137+
depthwise_constraint=None,
138+
bias_constraint=None,
139+
**kwargs
140+
):
141+
super().__init__(
142+
2,
143+
kernel_size=kernel_size,
144+
strides=strides,
145+
padding=padding,
146+
depth_multiplier=depth_multiplier,
147+
data_format=data_format,
148+
dilation_rate=dilation_rate,
149+
activation=activation,
150+
use_bias=use_bias,
151+
depthwise_initializer=depthwise_initializer,
152+
bias_initializer=bias_initializer,
153+
depthwise_regularizer=depthwise_regularizer,
154+
bias_regularizer=bias_regularizer,
155+
activity_regularizer=activity_regularizer,
156+
depthwise_constraint=depthwise_constraint,
157+
bias_constraint=bias_constraint,
158+
**kwargs
159+
)
160+
self.min_value = json.loads(min_value)
161+
self.max_value = json.loads(max_value)
162+
163+
def call(self, inputs):
164+
# add the Q/DQ here
165+
kernel, _, _ = quantization.quantize(
166+
self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED"
167+
)
168+
kernel = quantization.dequantize(
169+
kernel,
170+
self.min_value,
171+
self.max_value,
172+
axis=3,
173+
mode="SCALED",
174+
)
175+
outputs = tf.keras.backend.depthwise_conv2d(
176+
inputs,
177+
kernel,
178+
strides=self.strides,
179+
padding=self.padding,
180+
data_format=self.data_format,
181+
dilation_rate=self.dilation_rate,
182+
)
183+
184+
if self.use_bias:
185+
outputs = tf.keras.backend.bias_add(outputs, self.bias, data_format=self.data_format)
186+
187+
if self.activation is not None:
188+
return self.activation(outputs)
189+
190+
return outputs
191+
192+
@classmethod
193+
def from_config(cls, config):
194+
return cls(**config)
195+
196+
@tf_utils.shape_type_conversion
197+
def compute_output_shape(self, input_shape):
198+
if self.data_format == "channels_first":
199+
rows = input_shape[2]
200+
cols = input_shape[3]
201+
out_filters = input_shape[1] * self.depth_multiplier
202+
elif self.data_format == "channels_last":
203+
rows = input_shape[1]
204+
cols = input_shape[2]
205+
out_filters = input_shape[3] * self.depth_multiplier
206+
207+
rows = conv_utils.conv_output_length(
208+
rows,
209+
self.kernel_size[0],
210+
self.padding,
211+
self.strides[0],
212+
self.dilation_rate[0],
213+
)
214+
cols = conv_utils.conv_output_length(
215+
cols,
216+
self.kernel_size[1],
217+
self.padding,
218+
self.strides[1],
219+
self.dilation_rate[1],
220+
)
221+
if self.data_format == "channels_first":
222+
return (input_shape[0], out_filters, rows, cols)
223+
elif self.data_format == "channels_last":
224+
return (input_shape[0], rows, cols, out_filters)

neural_compressor/tensorflow/keras/layers/quantizer.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ def call(self, inputs):
3838
self.max_value = tf.math.reduce_max(inputs, axis=self.axis)
3939
return inputs
4040

41+
def compute_output_shape(self, input_shape):
42+
input_shape = tf.TensorShape(input_shape).as_list()
43+
return input_shape
44+
4145
@classmethod
4246
def from_config(cls, config):
4347
return cls(**config)
@@ -87,6 +91,10 @@ def call(self, inputs):
8791
)
8892
return outputs
8993

94+
def compute_output_shape(self, input_shape):
95+
input_shape = tf.TensorShape(input_shape).as_list()
96+
return input_shape
97+
9098
def get_config(self):
9199
return {
92100
"min_range": self.min_range,
@@ -122,6 +130,10 @@ def call(self, inputs):
122130
axis=self.axis,
123131
)
124132

133+
def compute_output_shape(self, input_shape):
134+
input_shape = tf.TensorShape(input_shape).as_list()
135+
return input_shape
136+
125137
def get_config(self):
126138
return {
127139
"min_range": self.min_range,

0 commit comments

Comments
 (0)