@@ -80,15 +80,30 @@ def parse_qrnn_layer(keras_layer, input_names, input_shapes, data_reader):
     layer['weight_quantizer'] = get_quantizer_from_config(keras_layer, 'kernel')
     layer['recurrent_weight_quantizer'] = get_quantizer_from_config(keras_layer, 'recurrent')
     layer['bias_quantizer'] = get_quantizer_from_config(keras_layer, 'bias')
+    layer['accum_quantizer'] = get_quantizer_from_config(keras_layer, 'state')
+
+    if not isinstance(keras_layer['config']['activation'], str):
+        activation = get_activation_quantizer(keras_layer, input_names)
+
+        assert activation['class_name'] != 'HardActivation', 'Hard activation not supported'
+
+        layer['activation'] = activation['activation']
+        layer['activation_quantizer'] = activation['activation_quantizer']
+
+    if keras_layer['class_name'] in ['QLSTM', 'QGRU'] and not isinstance(keras_layer['config']['recurrent_activation'], str):
+        recurrent_activation = get_activation_quantizer(keras_layer, input_names, activation_name='recurrent_activation')
+
+        assert recurrent_activation['class_name'] != 'HardActivation', 'Hard activation not supported'
+
+        layer['recurrent_activation'] = recurrent_activation['recurrent_activation']
+        layer['recurrent_activation_config'] = recurrent_activation
 
     return layer, output_shape
 
 
-@keras_handler('QActivation')
-def parse_qactivation_layer(keras_layer, input_names, input_shapes, data_reader):
+def get_activation_quantizer(keras_layer, input_names, activation_name='activation'):
     from qkeras.quantizers import get_quantizer
 
-    assert keras_layer['class_name'] == 'QActivation'
     supported_activations = [
         'quantized_relu',
         'quantized_tanh',
@@ -102,7 +117,7 @@ def parse_qactivation_layer(keras_layer, input_names, input_shapes, data_reader):
 
     layer = parse_default_keras_layer(keras_layer, input_names)
 
-    activation_config = keras_layer['config']['activation']
+    activation_config = keras_layer['config'][activation_name]
     quantizer_obj = get_quantizer(activation_config)
     activation_config = {}
     # some activations are classes
@@ -136,7 +151,7 @@ def parse_qactivation_layer(keras_layer, input_names, input_shapes, data_reader):
         layer['threshold'] = activation_config.get('config', {}).get('threshold', 0.33)
         if layer['threshold'] is None:
             layer['threshold'] = 0.33  # the default ternary tanh threshold for QKeras
-        layer['activation'] = 'ternary_tanh'
+        layer[activation_name] = 'ternary_tanh'
     elif (
         activation_config['class_name'] == 'quantized_sigmoid'
         and not activation_config['config'].get('use_real_sigmoid', False)
@@ -149,16 +164,27 @@ def parse_qactivation_layer(keras_layer, input_names, input_shapes, data_reader):
         # Quartus seems to have trouble if the width is 1.
         layer['slope_prec'] = FixedPrecisionType(width=2, integer=0, signed=False)
         layer['shift_prec'] = FixedPrecisionType(width=2, integer=0, signed=False)
-        layer['activation'] = activation_config['class_name'].replace('quantized_', 'hard_')
+        layer[activation_name] = activation_config['class_name'].replace('quantized_', 'hard_')
     elif activation_config['class_name'] == 'quantized_relu' and activation_config['config']['negative_slope'] != 0:
         layer['class_name'] = 'LeakyReLU'
-        layer['activation'] = activation_config['class_name'].replace('quantized_', 'leaky_')
+        layer[activation_name] = activation_config['class_name'].replace('quantized_', 'leaky_')
         layer['activ_param'] = activation_config['config']['negative_slope']
     else:
         layer['class_name'] = 'Activation'
-        layer['activation'] = activation_config['class_name'].replace('quantized_', '')
+        layer[activation_name] = activation_config['class_name'].replace('quantized_', '')
+
+    layer[f'{activation_name}_quantizer'] = activation_config
+
+    return layer
+
+
+@keras_handler('QActivation')
+def parse_qactivation_layer(keras_layer, input_names, input_shapes, data_reader):
+
+    assert keras_layer['class_name'] == 'QActivation'
+
+    layer = get_activation_quantizer(keras_layer, input_names)
 
-    layer['activation_quantizer'] = activation_config
     return layer, [shape for shape in input_shapes[0]]
 
 
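For reference, a minimal sketch of the two activation encodings the updated parse_qrnn_layer now distinguishes. The QLSTM configs below are hypothetical (field names mirror the keys the diff inspects, values are made up for illustration); the real dicts come from the Keras JSON reader:

    # Hypothetical serialized QLSTM layers; only the fields the new code inspects are shown.
    qlstm_plain = {
        'class_name': 'QLSTM',
        'config': {
            'activation': 'tanh',  # plain string: the quantizer branches are skipped
            'recurrent_activation': 'sigmoid',
        },
    }
    qlstm_quantized = {
        'class_name': 'QLSTM',
        'config': {
            # serialized quantizer dicts: the new branches call get_activation_quantizer()
            # once per activation and copy its results into the layer dict
            'activation': {'class_name': 'quantized_tanh', 'config': {'bits': 8}},
            'recurrent_activation': {'class_name': 'quantized_sigmoid', 'config': {'bits': 8}},
        },
    }

    # Same check as the new code in parse_qrnn_layer
    for cfg in (qlstm_plain, qlstm_quantized):
        has_quantized_activation = not isinstance(cfg['config']['activation'], str)
        print(cfg['class_name'], 'uses a quantized activation:', has_quantized_activation)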