Commit 723195e

oneAPI RNN State and Activation Quantizers (#1195)
* RNN Accum type to QKeras state quantizer and RNN QActivation parsing
* pre-commit hook
* QKeras recurrent activation
* Cast array size to int
* QRNN fix
1 parent 72badf1 commit 723195e

File tree: 5 files changed (+42, -16 lines)

hls4ml/backends/oneapi/passes/reshaping_templates.py

Lines changed: 2 additions & 2 deletions
@@ -188,7 +188,7 @@ def format(self, node):
         new_shape, perm_strides = node.model.config.backend.permute_config_gen(name, shape, perm)
         return transpose_config_template.format(
             dims=len(shape),
-            N=np.prod(shape),
+            N=int(np.prod(shape)),
             from_shape=', '.join(str(x) for x in shape),
             perm=', '.join(str(x) for x in perm),
             perm_strides=', '.join(str(x) for x in perm_strides),
@@ -251,5 +251,5 @@ def __init__(self):

     def format(self, node):
         params = self._default_function_params(node)
-        params['size'] = np.prod(node.get_output_variable().shape)
+        params['size'] = int(np.prod(node.get_output_variable().shape))
         return self.template.format(**params)
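The commit message only says "Cast array size to int"; a plausible illustration of what the cast changes (a minimal sketch, not part of the commit): np.prod returns a NumPy scalar rather than a plain Python int, and its repr differs between NumPy 1.x and 2.x, so normalizing to int keeps the size rendering as a bare number wherever it later ends up in generated code.

import numpy as np

shape = (2, 3, 4)

size = np.prod(shape)        # NumPy scalar (np.int64), not a Python int
print(type(size))            # <class 'numpy.int64'>
print(repr(size))            # '24' on NumPy 1.x, 'np.int64(24)' on NumPy 2.x

size = int(np.prod(shape))   # plain Python int, always renders as '24'
print(type(size), size)      # <class 'int'> 24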

hls4ml/converters/keras/qkeras.py

Lines changed: 35 additions & 9 deletions
@@ -80,15 +80,30 @@ def parse_qrnn_layer(keras_layer, input_names, input_shapes, data_reader):
     layer['weight_quantizer'] = get_quantizer_from_config(keras_layer, 'kernel')
     layer['recurrent_weight_quantizer'] = get_quantizer_from_config(keras_layer, 'recurrent')
     layer['bias_quantizer'] = get_quantizer_from_config(keras_layer, 'bias')
+    layer['accum_quantizer'] = get_quantizer_from_config(keras_layer, 'state')
+
+    if not isinstance(keras_layer['config']['activation'], str):
+        activation = get_activation_quantizer(keras_layer, input_names)
+
+        assert activation['class_name'] != 'HardActivation', 'Hard activation not supported'
+
+        layer['activation'] = activation['activation']
+        layer['activation_quantizer'] = activation['activation_quantizer']
+
+    if keras_layer['class_name'] in ['QLSTM', 'QGRU'] and not isinstance(keras_layer['config']['recurrent_activation'], str):
+        recurrent_activation = get_activation_quantizer(keras_layer, input_names, activation_name='recurrent_activation')
+
+        assert recurrent_activation['class_name'] != 'HardActivation', 'Hard activation not supported'
+
+        layer['recurrent_activation'] = recurrent_activation['recurrent_activation']
+        layer['recurrent_activation_config'] = recurrent_activation

     return layer, output_shape


-@keras_handler('QActivation')
-def parse_qactivation_layer(keras_layer, input_names, input_shapes, data_reader):
+def get_activation_quantizer(keras_layer, input_names, activation_name='activation'):
     from qkeras.quantizers import get_quantizer

-    assert keras_layer['class_name'] == 'QActivation'
     supported_activations = [
         'quantized_relu',
         'quantized_tanh',
@@ -102,7 +117,7 @@ def parse_qactivation_layer(keras_layer, input_names, input_shapes, data_reader)

     layer = parse_default_keras_layer(keras_layer, input_names)

-    activation_config = keras_layer['config']['activation']
+    activation_config = keras_layer['config'][activation_name]
     quantizer_obj = get_quantizer(activation_config)
     activation_config = {}
     # some activations are classes
@@ -136,7 +151,7 @@ def parse_qactivation_layer(keras_layer, input_names, input_shapes, data_reader)
         layer['threshold'] = activation_config.get('config', {}).get('threshold', 0.33)
         if layer['threshold'] is None:
             layer['threshold'] = 0.33  # the default ternary tanh threshold for QKeras
-        layer['activation'] = 'ternary_tanh'
+        layer[activation_name] = 'ternary_tanh'
     elif (
         activation_config['class_name'] == 'quantized_sigmoid'
         and not activation_config['config'].get('use_real_sigmoid', False)
@@ -149,16 +164,27 @@ def parse_qactivation_layer(keras_layer, input_names, input_shapes, data_reader)
         # Quartus seems to have trouble if the width is 1.
         layer['slope_prec'] = FixedPrecisionType(width=2, integer=0, signed=False)
         layer['shift_prec'] = FixedPrecisionType(width=2, integer=0, signed=False)
-        layer['activation'] = activation_config['class_name'].replace('quantized_', 'hard_')
+        layer[activation_name] = activation_config['class_name'].replace('quantized_', 'hard_')
     elif activation_config['class_name'] == 'quantized_relu' and activation_config['config']['negative_slope'] != 0:
         layer['class_name'] = 'LeakyReLU'
-        layer['activation'] = activation_config['class_name'].replace('quantized_', 'leaky_')
+        layer[activation_name] = activation_config['class_name'].replace('quantized_', 'leaky_')
         layer['activ_param'] = activation_config['config']['negative_slope']
     else:
         layer['class_name'] = 'Activation'
-        layer['activation'] = activation_config['class_name'].replace('quantized_', '')
+        layer[activation_name] = activation_config['class_name'].replace('quantized_', '')
+
+    layer[f'{activation_name}_quantizer'] = activation_config
+
+    return layer
+
+
+@keras_handler('QActivation')
+def parse_qactivation_layer(keras_layer, input_names, input_shapes, data_reader):
+
+    assert keras_layer['class_name'] == 'QActivation'
+
+    layer = get_activation_quantizer(keras_layer, input_names)

-    layer['activation_quantizer'] = activation_config
     return layer, [shape for shape in input_shapes[0]]
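For reference, a hypothetical QKeras layer that would exercise the new parse_qrnn_layer path. This is a sketch assuming the standard QKeras recurrent API (QLSTM, quantized_tanh, quantized_sigmoid, state_quantizer); it is not taken from this commit or its tests.

from qkeras import QLSTM, quantized_bits, quantized_sigmoid, quantized_tanh

layer = QLSTM(
    16,
    activation=quantized_tanh(8),                    # quantized activation -> routed through get_activation_quantizer
    recurrent_activation=quantized_sigmoid(8),       # recurrent activation is only checked for QLSTM/QGRU
    kernel_quantizer=quantized_bits(8, 0, alpha=1),
    recurrent_quantizer=quantized_bits(8, 0, alpha=1),
    bias_quantizer=quantized_bits(8, 0, alpha=1),
    state_quantizer=quantized_bits(16, 6, alpha=1),  # now mapped to layer['accum_quantizer']
)

cfg = layer.get_config()
# With quantized activations the serialized entries are typically quantizer
# configs rather than plain strings like 'tanh', which is what the new
# isinstance(..., str) checks in parse_qrnn_layer key on.
print(type(cfg['activation']), type(cfg['recurrent_activation']))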

hls4ml/converters/keras_to_hls.py

Lines changed: 1 addition & 1 deletion
@@ -209,7 +209,7 @@ def parse_keras_model(model_arch, reader):
         'HGQ>UnaryLUT',
     ]
     # Recurrent layers
-    recurrent_layers = ['SimpleRNN', 'LSTM', 'GRU']
+    recurrent_layers = ['SimpleRNN', 'LSTM', 'GRU', 'QSimpleRNN', 'QLSTM', 'QGRU']
     # All supported layers
     supported_layers = get_supported_keras_layers() + skip_layers
hls4ml/model/types.py

Lines changed: 1 addition & 1 deletion
@@ -437,7 +437,7 @@ def __init__(self, var_name, type_name, precision, data, quantizer=None, **kwarg
         self.data = data
         self.nzeros = -1
         self.shape = list(self.data.shape)
-        self.data_length = np.prod(self.data.shape)
+        self.data_length = int(np.prod(self.data.shape))
         self.nonzeros = np.count_nonzero(self.data)
         self.nzeros = self.data_length - self.nonzeros
         self.min = np.min(self.data)

hls4ml/templates/oneapi/firmware/nnet_utils/nnet_recurrent.h

Lines changed: 3 additions & 3 deletions
@@ -165,7 +165,7 @@ void gru(const data_T &data, res_T &res, const typename CONFIG_T::weight_t &weig
          const typename CONFIG_T::recurrent_weight_t &recurrent_weights, const typename CONFIG_T::bias_t &bias,
          const typename CONFIG_T::recurrent_bias_t &recurrent_bias) {

-    using h_T = array<typename res_T::value_type, CONFIG_T::n_units>;
+    using h_T = array<typename CONFIG_T::accum_t, CONFIG_T::n_units>;
     [[intel::fpga_register]] data_T x;
     [[intel::fpga_register]] h_T h;

@@ -259,7 +259,7 @@ void simple_rnn(const data_T &data, res_T &res, const typename CONFIG_T::weight_
                 const typename CONFIG_T::recurrent_weight_t &rec_kernel, const typename CONFIG_T::bias_t &bias) {

     using in_T = array<typename data_T::value_type, CONFIG_T::n_in>;
-    using h_T = array<typename res_T::value_type, CONFIG_T::n_out>;
+    using h_T = array<typename CONFIG_T::accum_t, CONFIG_T::n_out>;

     [[intel::fpga_register]] h_T hidden_state[CONFIG_T::n_timesteps + 1];
     [[intel::fpga_register]] h_T hidden_state_temp;
@@ -500,7 +500,7 @@ void lstm(const data_T &data, res_T &res, const typename CONFIG_T::weight_i_t &W
     // Note: currently this does not support recurrent bias

     using in_T = array<typename data_T::value_type, CONFIG_T::n_in>;
-    using h_T = array<typename res_T::value_type, CONFIG_T::n_out>;
+    using h_T = array<typename CONFIG_T::accum_t, CONFIG_T::n_out>;

     [[intel::fpga_register]] h_T hidden_state[CONFIG_T::n_timesteps + 1];
     [[intel::fpga_register]] h_T hidden_state_temp;
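Switching the hidden-state element type from res_T::value_type to CONFIG_T::accum_t means the state carried between timesteps is stored at the accumulator precision, which this commit derives from the QKeras state quantizer, rather than at the layer's output precision. A rough NumPy sketch of the behavioral effect for a simple RNN (hypothetical helper names, not hls4ml code; assumes a signed fixed-point format with 'width' total bits and 'integer' integer bits):

import numpy as np

def quantize_fixed(x, width, integer):
    # Naive signed fixed-point rounding, standing in for the accum_t type of the hidden state.
    scale = 2.0 ** (width - integer)
    lo = -(2.0 ** (integer - 1))
    hi = 2.0 ** (integer - 1) - 1.0 / scale
    return np.clip(np.round(x * scale) / scale, lo, hi)

def simple_rnn_ref(x, w_x, w_h, b, state_prec=(16, 6)):
    # x: (n_timesteps, n_in). The hidden state is re-quantized after every
    # timestep, mimicking storage in an array of accum_t elements.
    h = np.zeros(w_h.shape[0])
    for t in range(x.shape[0]):
        h = np.tanh(x[t] @ w_x.T + h @ w_h.T + b)
        h = quantize_fixed(h, *state_prec)
    return h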
