Skip to content

Commit 17331a4

Browse files
authored
Add precisoin bits to recurrent pytorch pytest (#1215)
* change default precision to fixed<32,16> throughout * added rounding to the input values plus small fixes
1 parent 9a7529a commit 17331a4

File tree

1 file changed

+34
-6
lines changed

1 file changed

+34
-6
lines changed

test/pytest/test_recurrent_pytorch.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,18 @@ def test_gru(backend, io_type):
3838
model.eval()
3939

4040
X_input = torch.randn(1, 1, 10)
41+
X_input = np.round(X_input * 2**16) * 2**-16 # make it exact ap_fixed<32,16>
4142
h0 = torch.randn(1, 1, 20)
43+
h0 = np.round(h0 * 2**16) * 2**-16
4244

4345
pytorch_prediction = model(torch.Tensor(X_input), torch.Tensor(h0)).detach().numpy()
4446

4547
config = config_from_pytorch_model(
46-
model, [(None, 1, 10), (None, 1, 20)], channels_last_conversion="off", transpose_outputs=False
48+
model,
49+
[(None, 1, 10), (None, 1, 20)],
50+
channels_last_conversion="off",
51+
transpose_outputs=False,
52+
default_precision='fixed<32,16>',
4753
)
4854
output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_gru_{backend}_{io_type}')
4955

@@ -63,10 +69,13 @@ def test_gru_stream(backend, io_type):
6369
model.eval()
6470

6571
X_input = torch.randn(1, 1, 10)
72+
X_input = np.round(X_input * 2**16) * 2**-16 # make it exact ap_fixed<32,16>
6673

6774
pytorch_prediction = model(torch.Tensor(X_input)).detach().numpy()
6875

69-
config = config_from_pytorch_model(model, (None, 1, 10), channels_last_conversion="off", transpose_outputs=False)
76+
config = config_from_pytorch_model(
77+
model, (None, 1, 10), channels_last_conversion="off", transpose_outputs=False, default_precision='fixed<32,16>'
78+
)
7079
output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_gru_{backend}_{io_type}')
7180

7281
hls_model = convert_from_pytorch_model(model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type)
@@ -105,13 +114,20 @@ def test_lstm(backend, io_type):
105114
model.eval()
106115

107116
X_input = torch.randn(1, 1, 10)
117+
X_input = np.round(X_input * 2**16) * 2**-16 # make it exact ap_fixed<32,16>
108118
h0 = torch.randn(1, 1, 20)
119+
h0 = np.round(h0 * 2**16) * 2**-16
109120
c0 = torch.randn(1, 1, 20)
121+
c0 = np.round(c0 * 2**16) * 2**-16
110122

111123
pytorch_prediction = model(torch.Tensor(X_input), torch.Tensor(h0), torch.tensor(c0)).detach().numpy()
112124

113125
config = config_from_pytorch_model(
114-
model, [(None, 1, 10), (None, 1, 20), (None, 1, 20)], channels_last_conversion="off", transpose_outputs=False
126+
model,
127+
[(None, 1, 10), (None, 1, 20), (None, 1, 20)],
128+
channels_last_conversion="off",
129+
transpose_outputs=False,
130+
default_precision='fixed<32,16>',
115131
)
116132
output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_lstm_{backend}_{io_type}')
117133

@@ -140,10 +156,13 @@ def test_lstm_stream(backend, io_type):
140156
model.eval()
141157

142158
X_input = torch.randn(1, 1, 10)
159+
X_input = np.round(X_input * 2**16) * 2**-16 # make it exact ap_fixed<32,16>
143160

144161
pytorch_prediction = model(torch.Tensor(X_input)).detach().numpy()
145162

146-
config = config_from_pytorch_model(model, [(None, 1, 10)], channels_last_conversion="off", transpose_outputs=False)
163+
config = config_from_pytorch_model(
164+
model, [(None, 1, 10)], channels_last_conversion="off", transpose_outputs=False, default_precision='fixed<32,16>'
165+
)
147166
output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_lstm_{backend}_{io_type}')
148167

149168
hls_model = convert_from_pytorch_model(
@@ -179,17 +198,26 @@ def test_rnn(backend, io_type):
179198
model.eval()
180199

181200
X_input = torch.randn(1, 1, 10)
201+
X_input = np.round(X_input * 2**16) * 2**-16 # make it exact ap_fixed<32,16>
182202
h0 = torch.zeros(1, 1, 20)
183203

184204
pytorch_prediction = model(torch.Tensor(X_input), torch.Tensor(h0)).detach().numpy()
185205

186206
config = config_from_pytorch_model(
187-
model, [(1, 10), (1, 20)], channels_last_conversion="off", transpose_outputs=False
207+
model,
208+
[(1, 10), (1, 20)],
209+
channels_last_conversion="off",
210+
transpose_outputs=False,
211+
default_precision='fixed<32,16>',
188212
)
189213
output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_rnn_{backend}_{io_type}')
190214

191215
hls_model = convert_from_pytorch_model(
192-
model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type
216+
model,
217+
hls_config=config,
218+
output_dir=output_dir,
219+
backend=backend,
220+
io_type=io_type,
193221
)
194222

195223
hls_model.compile()

0 commit comments

Comments
 (0)