Commit 2287f4d

JanFSchulte, jmitrevs, and pre-commit-ci[bot] authored
Initial values for the hidden/cell state for LSTM and GRU models in Pytorch (#1120)
* allow initial values for the hidden/cell state to be passed for LSTM and GRU models
* initial state rnns for oneAPI
* fix data types in quartus
* more type updates
* update types for lstm init state oneAPI
* fix pytorch_order for GRU, recurrent bias for simpleNN, oneAPI
* fix simplernn in oneAPI
* snapshot that compiles but fails pytests
* fix order of indices for pytorch simple RNN oneAPI
* [pre-commit.ci] auto fixes from pre-commit hooks
* trigger pre-commit
* trigger pre-commit
* trigger pre-commit
* fix simple-rnn config for Keras; make test names unique
* remove unused base config, update style
* style comments from Vladimir

---------

Co-authored-by: Jovan Mitrevski <jmitrevs@fnal.gov>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 18ccc61 commit 2287f4d
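For context, the usage pattern this commit enables — an RNN whose initial hidden/cell states are passed in as extra model inputs — looks as follows in PyTorch. This is a minimal sketch with made-up layer sizes, not code from the repository; before this change the converter dropped the extra inputs and warned that the initial state was assumed to be all zeros (see the recurrent.py diff below).

import torch
import torch.nn as nn


class LSTMWithInitState(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)

    def forward(self, x, h0, c0):
        # PyTorch takes the initial states as a single (h0, c0) tuple
        out, _ = self.lstm(x, (h0, c0))
        return out


# x: (batch, seq, features); h0/c0: (num_layers, batch, hidden_size)
model = LSTMWithInitState()
out = model(torch.randn(1, 4, 8), torch.zeros(1, 1, 16), torch.zeros(1, 1, 16))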

File tree

11 files changed: +755 −80 lines

hls4ml/backends/oneapi/passes/recurrent_templates.py

Lines changed: 54 additions & 7 deletions
@@ -92,10 +92,14 @@
     using activation_recr = nnet::activation::{recurrent_activation}<x_T, y_T, config_T>;

     static const unsigned reuse_factor = {reuse};
+    static const unsigned pytorch_order = {pytorch};
     static const bool store_weights_in_bram = false;
 }};\n'''

 gru_function_template = 'nnet::gru<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {wr}, {b}, {br});'
+gru_function_initial_state_template = (
+    'nnet::gru_init_state<{input_t}, {h_t}, {output_t}, {config}>({input}, {init_state}, {output}, {w}, {wr}, {b}, {br});'
+)
 gru_task_sequence_template = 'task_sequence<nnet::gru_stream<{input_pipe}, {output_pipe}, {config}>> {name};'
 gru_stream_function_template = '{name}.async({w}, {wr}, {b}, {br});'

@@ -120,6 +124,7 @@ def format(self, node):
         params['config_mult_h'] = f'config{node.index}_h_mult'
         params['act_t'] = '{}_config{}'.format(node.get_attr('activation'), str(node.index) + '_act')
         params['act_recurrent_t'] = '{}_config{}'.format(node.get_attr('recurrent_activation'), str(node.index) + '_rec_act')
+        params['pytorch'] = 'true' if node.get_attr('pytorch', False) else 'false'
         gru_config = self.gru_template.format(**params)

         # Activation is on candidate hidden state, dimensionality (1, n_units)
@@ -163,15 +168,23 @@ def format(self, node):
 class GRUFunctionTemplate(FunctionCallTemplate):
     def __init__(self):
         super().__init__(GRU, include_header=recurrent_include_list)
-        self.template = gru_function_template

     def format(self, node):
         params = self._default_function_params(node)
+        if params['pass_initial_states'] == 'true':
+            params['h_t'] = node.get_input_variable(node.inputs[1]).type.name
+            params['init_state'] = node.get_input_variable(node.inputs[1]).name
         params['w'] = node.get_weights('weight').name
         params['b'] = node.get_weights('bias').name
         params['wr'] = node.get_weights('recurrent_weight').name
         params['br'] = node.get_weights('recurrent_bias').name
-        return self.template.format(**params)
+
+        if params['pass_initial_states'] == 'true':
+            template = gru_function_initial_state_template
+        else:
+            template = gru_function_template
+
+        return template.format(**params)


 class GRUTaskSequenceTemplate(TaskSequenceTemplate):
@@ -235,6 +248,10 @@ def format(self, node):
 }};\n"""

 lstm_function_template = 'nnet::lstm<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
+lstm_function_initial_state_template = (
+    'nnet::lstm_init_state<{input_t}, {h_t}, {hc_t}, {output_t}, {config}>'
+    '({input}, {init_state}, {init_cell}, {output}, {weights});'
+)


 class LSTMConfigTemplate(LayerConfigTemplate):
@@ -275,11 +292,16 @@ def format(self, node):
 class LSTMFunctionTemplate(FunctionCallTemplate):
     def __init__(self):
         super().__init__(LSTM, include_header=recurrent_include_list)
-        self.template = lstm_function_template

     def format(self, node):
         params = self._default_function_params(node)

+        if params['pass_initial_states'] == 'true':
+            params['h_t'] = node.get_input_variable(node.inputs[1]).type.name
+            params['init_state'] = node.get_input_variable(node.inputs[1]).name
+            params['init_cell'] = node.get_input_variable(node.inputs[2]).name
+            params['hc_t'] = node.get_input_variable(node.inputs[2]).type.name
+
         types = ['i', 'f', 'c', 'o']
         params['weights'] = ''
         for t in types:
@@ -289,13 +311,18 @@ def format(self, node):
         for t in types:
             params['weights'] += 'bias_{}_{}{}'.format(t, str(node.index), ',' if t != 'o' else '')

-        return self.template.format(**params)
+        if params['pass_initial_states'] == 'true':
+            template = lstm_function_initial_state_template
+        else:
+            template = lstm_function_template
+
+        return template.format(**params)


 ################################################
 # SimpleRNN Template
 ################################################
-simple_rnn_config_template = """struct config{index} : nnet::simpleRNN_config {{
+simple_rnn_config_template = """struct config{index} : nnet::simple_rnn_config {{
     static const unsigned n_in = {n_in};
     static const unsigned n_out = {n_out};
     static const unsigned n_outputs = {n_outputs};
@@ -306,6 +333,7 @@ def format(self, node):
     typedef {weight_t.name} weight_t;
     typedef {bias_t.name} bias_t;
    typedef {recurrent_weight_t.name} recurrent_weight_t;
+    typedef {recurrent_bias_t.name} recurrent_bias_t;

     typedef {act_t} ACT_CONFIG_T;
     template<class x_T, class y_T, class config_T>
@@ -320,6 +348,10 @@ def format(self, node):
 }};\n"""

 simple_rnn_function_template = 'nnet::simple_rnn<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
+simple_rnn_pytorch_function_template = (
+    'nnet::simple_rnn_pytorch<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
+)
+simple_rnn_pytorch_function_initial_state_template = 'nnet::simple_rnn_pytorch_init_state<{input_t}, {h_t}, {output_t}, {config}>({input}, {init_state}, {output}, {weights});'  # noqa E501


 class SimpleRNNConfigTemplate(LayerConfigTemplate):
@@ -341,6 +373,9 @@ def format(self, node):
         )
         simple_rnn_params['recurrent_activation'] = 'relu'

+        # In Keras there is no recurrent bias, so put a placeholder
+        simple_rnn_params.setdefault('recurrent_bias_t', simple_rnn_params['bias_t'])
+
         simple_rnn_config = self.template.format(**simple_rnn_params)

         act_params = self._default_config_params(node)
@@ -365,5 +400,17 @@ def __init__(self):

     def format(self, node):
         params = self._default_function_params(node)
-        params['weights'] = 'w{0}, wr{0}, b{0}'.format(str(node.index))
-        return self.template.format(**params)
+        if params['pass_initial_states'] == 'true':
+            params['h_t'] = node.get_input_variable(node.inputs[1]).type.name
+            params['init_state'] = node.get_input_variable(node.inputs[1]).name
+
+        if node.get_attr('pytorch', False):
+            if params['pass_initial_states'] == 'true':
+                template = simple_rnn_pytorch_function_initial_state_template
+            else:
+                template = simple_rnn_pytorch_function_template
+            params['weights'] = 'w{0}, wr{0}, b{0}, br{0}'.format(str(node.index))
+        else:
+            template = simple_rnn_function_template
+            params['weights'] = 'w{0}, wr{0}, b{0}'.format(str(node.index))
+        return template.format(**params)
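To make the dispatch above concrete, here is a small standalone sketch of how the new initial-state template expands. The parameter values are invented for illustration and do not come from a real model.

gru_function_initial_state_template = (
    'nnet::gru_init_state<{input_t}, {h_t}, {output_t}, {config}>({input}, {init_state}, {output}, {w}, {wr}, {b}, {br});'
)

# Hypothetical variable/type names, as the backend might fill them in
params = {
    'input_t': 'input_t', 'h_t': 'h_init_t', 'output_t': 'result_t', 'config': 'config2',
    'input': 'layer1_out', 'init_state': 'h_init', 'output': 'layer2_out',
    'w': 'w2', 'wr': 'wr2', 'b': 'b2', 'br': 'br2',
}
print(gru_function_initial_state_template.format(**params))
# nnet::gru_init_state<input_t, h_init_t, result_t, config2>(layer1_out, h_init, layer2_out, w2, wr2, b2, br2);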

hls4ml/backends/quartus/passes/recurrent_templates.py

Lines changed: 42 additions & 8 deletions
@@ -71,6 +71,9 @@
 }};\n'''

 gru_function_template = 'nnet::gru<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {wr}, {b}, {br});'
+gru_function_initial_state_template = (
+    'nnet::gru<{input_t}, {input2_t}, {output_t}, {config}>({input}, {input2}, {output}, {w}, {wr}, {b}, {br});'
+)


 class GRUConfigTemplate(LayerConfigTemplate):
@@ -137,15 +140,23 @@ def format(self, node):
 class GRUFunctionTemplate(FunctionCallTemplate):
     def __init__(self):
         super().__init__(GRU, include_header=recurrent_include_list)
-        self.template = gru_function_template

     def format(self, node):
         params = self._default_function_params(node)
+        if params['pass_initial_states'] == 'true':
+            params['input2_t'] = node.get_input_variable(node.inputs[1]).type.name
+            params['input2'] = node.get_input_variable(node.inputs[1]).name
         params['w'] = node.get_weights('weight').name
         params['b'] = node.get_weights('bias').name
         params['wr'] = node.get_weights('recurrent_weight').name
         params['br'] = node.get_weights('recurrent_bias').name
-        return self.template.format(**params)
+
+        if params['pass_initial_states'] == 'true':
+            template = gru_function_initial_state_template
+        else:
+            template = gru_function_template
+
+        return template.format(**params)


 ################################################
@@ -174,6 +185,9 @@ def format(self, node):
 }};\n"""

 lstm_function_template = 'nnet::lstm<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
+lstm_function_initial_state_template = (
+    'nnet::lstm<{input_t}, {input2_t}, {input3_t}, {output_t}, {config}>({input}, {input2}, {input3}, {output}, {weights});'
+)


 class LSTMConfigTemplate(LayerConfigTemplate):
@@ -214,11 +228,16 @@ def format(self, node):
 class LSTMFunctionTemplate(FunctionCallTemplate):
     def __init__(self):
         super().__init__(LSTM, include_header=recurrent_include_list)
-        self.template = lstm_function_template

     def format(self, node):
         params = self._default_function_params(node)

+        if params['pass_initial_states'] == 'true':
+            params['input2_t'] = node.get_input_variable(node.inputs[1]).type.name
+            params['input2'] = node.get_input_variable(node.inputs[1]).name
+            params['input3'] = node.get_input_variable(node.inputs[2]).name
+            params['input3_t'] = node.get_input_variable(node.inputs[2]).type.name
+
         types = ['i', 'f', 'c', 'o']
         params['weights'] = ''
         for t in types:
@@ -228,13 +247,18 @@ def format(self, node):
         for t in types:
             params['weights'] += 'bias_{}_{}{}'.format(t, str(node.index), ',' if t != 'o' else '')

-        return self.template.format(**params)
+        if params['pass_initial_states'] == 'true':
+            template = lstm_function_initial_state_template
+        else:
+            template = lstm_function_template
+
+        return template.format(**params)


 ################################################
 # SimpleRNN Template
 ################################################
-simple_rnn_config_template = """struct config{index} : nnet::simpleRNN_config {{
+simple_rnn_config_template = """struct config{index} : nnet::simple_rnn_config {{
     static const unsigned n_in = {n_in};
     static const unsigned n_out = {n_out};
     static const unsigned n_outputs = {n_outputs};
@@ -261,6 +285,9 @@ def format(self, node):
 simple_rnn_pytorch_function_template = (
     'nnet::simple_rnn_pytorch<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
 )
+simple_rnn_pytorch_function_initial_state_template = (
+    'nnet::simple_rnn_pytorch<{input_t}, {input2_t}, {output_t}, {config}>({input}, {input2}, {output}, {weights});'
+)


 class SimpleRNNConfigTemplate(LayerConfigTemplate):
@@ -302,13 +329,20 @@ def format(self, node):
 class SimpleRNNFunctionTemplate(FunctionCallTemplate):
     def __init__(self):
         super().__init__(SimpleRNN, include_header=recurrent_include_list)
-        self.template = simple_rnn_function_template

     def format(self, node):
         params = self._default_function_params(node)
+        if params['pass_initial_states'] == 'true':
+            params['input2_t'] = node.get_input_variable(node.inputs[1]).type.name
+            params['input2'] = node.get_input_variable(node.inputs[1]).name
+
         if node.get_attr('pytorch', False):
-            self.template = simple_rnn_pytorch_function_template
+            if params['pass_initial_states'] == 'true':
+                template = simple_rnn_pytorch_function_initial_state_template
+            else:
+                template = simple_rnn_pytorch_function_template
             params['weights'] = 'w{0}, wr{0}, b{0}, br{0}'.format(str(node.index))
         else:
+            template = simple_rnn_function_template
             params['weights'] = 'w{0}, wr{0}, b{0}'.format(str(node.index))
-        return self.template.format(**params)
+        return template.format(**params)
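Note the design difference from the oneAPI backend: the Quartus templates keep the plain nnet::gru / nnet::lstm names and select an overload with extra template parameters, rather than a dedicated *_init_state entry point. The pass_initial_states value arrives as the string 'true'/'false' (hence the string comparison in the diff), so the selection reduces to a sketch like the following; the helper itself is illustrative, not repository code.

gru_function_template = 'nnet::gru<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {wr}, {b}, {br});'
gru_function_initial_state_template = (
    'nnet::gru<{input_t}, {input2_t}, {output_t}, {config}>({input}, {input2}, {output}, {w}, {wr}, {b}, {br});'
)


def select_gru_template(pass_initial_states: str) -> str:
    # pass_initial_states is already rendered as 'true' or 'false'
    if pass_initial_states == 'true':
        return gru_function_initial_state_template
    return gru_function_template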

hls4ml/backends/vivado/passes/recurrent_templates.py

Lines changed: 18 additions & 2 deletions
@@ -87,6 +87,8 @@
 }};\n"""

 recr_function_template = 'nnet::{recr_type}_stack<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {wr}, {b}, {br});'
+recr_function_template_initial_states_lstm = 'nnet::{recr_type}_stack<{input_t}, {input2_t}, {input3_t}, {output_t}, {config}>({input}, {input2}, {input3}, {output}, {w}, {wr}, {b}, {br});'  # noqa: E501
+recr_function_template_initial_states_gru = 'nnet::{recr_type}_stack<{input_t}, {input2_t}, {output_t}, {config}>({input}, {input2}, {output}, {w}, {wr}, {b}, {br});'  # noqa: E501

 recr_include_list = ['nnet_utils/nnet_recurrent.h']

@@ -208,10 +210,16 @@ def format(self, node):
 class RecurrentFunctionTemplate(FunctionCallTemplate):
     def __init__(self):
         super().__init__((LSTM, GRU), include_header=recr_include_list)
-        self.template = recr_function_template

     def format(self, node):
         params = self._default_function_params(node)
+        if params['pass_initial_states'] == 'true':
+            params['input2_t'] = node.get_input_variable(node.inputs[1]).type.name
+            params['input2'] = node.get_input_variable(node.inputs[1]).name
+            if node.class_name == 'LSTM':
+                params['input3'] = node.get_input_variable(node.inputs[2]).name
+                params['input3_t'] = node.get_input_variable(node.inputs[2]).type.name
+
         params['w'] = node.get_weights('weight').name
         params['b'] = node.get_weights('bias').name
         params['wr'] = node.get_weights('recurrent_weight').name
@@ -220,4 +228,12 @@ def format(self, node):
         params['recurrent_activation'] = node.get_attr('recurrent_activation')
         params['recr_type'] = node.class_name.lower()

-        return self.template.format(**params)
+        if params['pass_initial_states'] == 'true':
+            if node.class_name == 'LSTM':
+                template = recr_function_template_initial_states_lstm
+            else:
+                template = recr_function_template_initial_states_gru
+        else:
+            template = recr_function_template
+
+        return template.format(**params)
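Because the Vivado backend serves LSTM and GRU from a single RecurrentFunctionTemplate, the choice also keys on node.class_name: an LSTM carries both an initial hidden state (inputs[1]) and an initial cell state (inputs[2]), while a GRU carries only the hidden state. A condensed, illustrative sketch of that selection, using the three module-level templates defined in the diff above:

def select_recr_template(class_name: str, pass_initial_states: str) -> str:
    # Default path: no initial states were passed to the layer
    if pass_initial_states != 'true':
        return recr_function_template
    # LSTM needs the wider signature with both h0 and c0
    if class_name == 'LSTM':
        return recr_function_template_initial_states_lstm
    return recr_function_template_initial_states_gru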

hls4ml/converters/pytorch/recurrent.py

Lines changed: 11 additions & 11 deletions
@@ -1,5 +1,3 @@
-import warnings
-
 import numpy as np

 from hls4ml.converters.pytorch_to_hls import pytorch_handler
@@ -15,14 +13,13 @@ def parse_rnn_layer(operation, layer_name, input_names, input_shapes, node, clas

     layer["name"] = layer_name

-    layer['inputs'] = [input_names[0]]
-    if len(input_names) > 1:
-        warnings.warn(
-            'hls4ml disregards the initial value of the hidden state passed to the model, assuming that it is all zeros',
-            stacklevel=2,
-        )
+    layer['inputs'] = input_names
+    if 'IOType' in config.keys():
+        if len(input_names) > 1 and config['IOType'] == 'io_stream':
+            raise Exception('Passing initial values for the hidden state is not support for io_stream input type.')
+
     layer['class_name'] = operation
-    if operation == "RNN":
+    if operation == 'RNN':
         layer['class_name'] = 'SimpleRNN'

     layer['return_sequences'] = False  # parameter does not exist in pytorch
@@ -31,7 +28,7 @@ def parse_rnn_layer(operation, layer_name, input_names, input_shapes, node, clas
     if layer['class_name'] == 'SimpleRNN':
         layer['activation'] = class_object.nonlinearity  # Default is tanh, can also be ReLU in pytorch
     else:
-        layer['activation'] = "tanh"  # GRU and LSTM are hard-coded to use tanh in pytorch
+        layer['activation'] = 'tanh'  # GRU and LSTM are hard-coded to use tanh in pytorch

     if layer['class_name'] == 'GRU' or layer['class_name'] == 'LSTM':
         layer['recurrent_activation'] = 'sigmoid'  # GRU and LSTM are hard-coded to use sigmoid in pytorch
@@ -51,7 +48,6 @@ def parse_rnn_layer(operation, layer_name, input_names, input_shapes, node, clas

     if class_object.bidirectional:
         raise Exception('hls4ml does not support birectional RNNs')
-
     if class_object.dropout > 0:
         raise Exception('hls4ml does not support RNNs with dropout')

@@ -70,5 +66,9 @@ def parse_rnn_layer(operation, layer_name, input_names, input_shapes, node, clas
     output_shape = [input_shapes[0][0], layer['n_out']]

     layer['pytorch'] = True  # need to switch some behaviors to match pytorch implementations
+    if len(input_names) == 1:
+        layer['pass_initial_states'] = False
+    else:
+        layer['pass_initial_states'] = True

     return layer, output_shape
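Condensed, the converter-side behavior added here: all traced inputs are now kept on the layer, initial states are rejected for io_stream, and a pass_initial_states flag tells the backend templates which function signature to emit. As a standalone sketch (hypothetical helper, not repository code; config is the hls4ml configuration dict):

def rnn_input_flags(input_names, config):
    # Initial states are only supported with io_parallel; io_stream is rejected up front
    if len(input_names) > 1 and config.get('IOType') == 'io_stream':
        raise Exception('Initial hidden-state values are not supported for io_stream')
    return {
        'inputs': list(input_names),
        # True whenever the call site passed h0 (and c0 for LSTM)
        'pass_initial_states': len(input_names) > 1,
    }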

hls4ml/converters/pytorch_to_hls.py

Lines changed: 11 additions & 3 deletions
@@ -225,9 +225,17 @@ def parse_pytorch_model(config, verbose=True):
             # parse info from class object
             input_names = [inputs_map.get(str(i), str(i)) for i in node.args]
             if pytorch_class in ["RNN", "GRU", "LSTM"]:
-                # we currently don't support the passing of the initial value of the hidden state to RNN models
-                input_names = [inputs_map.get(str(node.args[0]), str(node.args[0]))]
-                input_shapes = [output_shapes[str(node.args[0])]]
+                input_shapes = []
+                input_names = []
+                for arg in node.args:
+                    if isinstance(arg, tuple):
+                        for input in arg:
+                            input_shapes.append(output_shapes[str(input)])
+                            input_names.append(inputs_map.get(str(input), str(input)))
+                    else:
+                        input_shapes.append(output_shapes[str(arg)])
+                        input_names.append(inputs_map.get(str(arg), str(arg)))
+
             # if a 'getitem' is the input to a node, step back in the graph to find the real source of the input
             elif "getitem" in node.args[0].name:
                 for tmp_node in traced_model.graph.nodes:
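The crux of this change: torch.fx traces a call like lstm(x, (h0, c0)) with the state tuple as a single entry in node.args, so the RNN branch now flattens tuples before looking up names and shapes. The same logic as a standalone sketch (hypothetical helper; output_shapes and inputs_map are the dicts used by parse_pytorch_model):

def flatten_rnn_args(node_args, output_shapes, inputs_map):
    input_names, input_shapes = [], []
    for arg in node_args:
        # unpack (h0, c0)-style tuples into individual tensors
        for item in (arg if isinstance(arg, tuple) else (arg,)):
            input_names.append(inputs_map.get(str(item), str(item)))
            input_shapes.append(output_shapes[str(item)])
    return input_names, input_shapes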

0 commit comments