Skip to content

Commit 898d53d

Browse files
steltzes, tzelepi, jmitrevs, vloncar
authored
Depthwise 1D and 2D Resource strategy for io_stream (#1079)
* Declare FIFO depth from depthwise to pointwise as const to supress Vitis HLS warnings * Split stream depthwise resource in 3 cases * Clean hls implementation, extend depthwise test * Pass depthwise2d and sepconv2d tests for various filters and rfs * Include tests to all relevant test files * Run pre-commit * Clean code * Fix vivado hls synthesis * Run precommit * Fix depthwise tests to include only 4 cases * Format code * Fix sepconv and depthwise tests to include only 3 cases * Remove unused test options * Restore "valid" on testbench, use constexpr instead of const for sep conv FIFO depth * Reduce number of filters * Move implementation into a different file * Correct assertion comments * Run precommit * Use the dense function pointer in depthwise convolution --------- Co-authored-by: stzelepi <stylianos.tzelepis@cern.ch> Co-authored-by: Jovan Mitrevski <jmitrevs@fnal.gov> Co-authored-by: Vladimir Loncar <vloncar@users.noreply.github.com>
1 parent 2287f4d commit 898d53d

File tree

12 files changed

+442
-146
lines changed

12 files changed

+442
-146
lines changed

hls4ml/backends/vivado/passes/convolution_templates.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,27 @@ def format(self, node):
131131
namespace = params['namespace']
132132

133133
if node.get_attr('strategy').lower() == 'latency':
134-
mult_params['dense_function'] = 'nnet::DenseLatency'
134+
if isinstance(node, DepthwiseConv1D):
135+
mult_params['dense_function'] = 'nnet::DepthwiseDenseLatency'
136+
else:
137+
mult_params['dense_function'] = 'nnet::DenseLatency'
135138
elif node.get_attr('strategy').lower() == 'resource':
136-
if int(mult_params['reuse_factor']) <= int(mult_params['n_in']):
137-
mult_params['dense_function'] = 'nnet::DenseResource_rf_leq_nin'
139+
if isinstance(node, DepthwiseConv1D):
140+
if int(mult_params['reuse_factor']) <= int(mult_params['n_out']):
141+
mult_params['dense_function'] = 'nnet::DepthwiseDenseResource_rf_leq_nout'
142+
else:
143+
if int(mult_params['reuse_factor']) % int(mult_params['n_out']) == 0:
144+
mult_params['dense_function'] = 'nnet::DepthwiseDenseResource_rf_gt_nout_rem0'
145+
else:
146+
mult_params['dense_function'] = 'nnet::DepthwiseDenseResource_rf_gt_nout'
138147
else:
139-
mult_params['dense_function'] = 'nnet::DenseResource_rf_gt_nin_rem0'
140-
# The 3rd case is never used
148+
if int(mult_params['reuse_factor']) <= int(mult_params['n_in']):
149+
mult_params['dense_function'] = 'nnet::DenseResource_rf_leq_nin'
150+
else:
151+
if int(mult_params['reuse_factor']) % int(mult_params['n_in']) == 0:
152+
mult_params['dense_function'] = 'nnet::DenseResource_rf_gt_nin_rem0'
153+
else:
154+
mult_params['dense_function'] = 'nnet::DenseResource_rf_gt_nin'
141155
elif node.get_attr('strategy').lower() == 'resource_unrolled':
142156
mult_params['dense_function'] = f'{namespace}::dense_resource_unrolled_{node.index}'
143157

@@ -262,13 +276,27 @@ def format(self, node):
262276

263277
namespace = params['namespace']
264278
if node.get_attr('strategy').lower() == 'latency':
265-
mult_params['dense_function'] = 'nnet::DenseLatency'
279+
if isinstance(node, DepthwiseConv2D):
280+
mult_params['dense_function'] = 'nnet::DepthwiseDenseLatency'
281+
else:
282+
mult_params['dense_function'] = 'nnet::DenseLatency'
266283
elif node.get_attr('strategy').lower() == 'resource':
267-
if int(mult_params['reuse_factor']) <= int(mult_params['n_in']):
268-
mult_params['dense_function'] = 'nnet::DenseResource_rf_leq_nin'
284+
if isinstance(node, DepthwiseConv2D):
285+
if int(mult_params['reuse_factor']) <= int(mult_params['n_out']):
286+
mult_params['dense_function'] = 'nnet::DepthwiseDenseResource_rf_leq_nout'
287+
else:
288+
if int(mult_params['reuse_factor']) % int(mult_params['n_out']) == 0:
289+
mult_params['dense_function'] = 'nnet::DepthwiseDenseResource_rf_gt_nout_rem0'
290+
else:
291+
mult_params['dense_function'] = 'nnet::DepthwiseDenseResource_rf_gt_nout'
269292
else:
270-
mult_params['dense_function'] = 'nnet::DenseResource_rf_gt_nin_rem0'
271-
# The 3rd case is never used
293+
if int(mult_params['reuse_factor']) <= int(mult_params['n_in']):
294+
mult_params['dense_function'] = 'nnet::DenseResource_rf_leq_nin'
295+
else:
296+
if int(mult_params['reuse_factor']) % int(mult_params['n_in']) == 0:
297+
mult_params['dense_function'] = 'nnet::DenseResource_rf_gt_nin_rem0'
298+
else:
299+
mult_params['dense_function'] = 'nnet::DenseResource_rf_gt_nin'
272300
elif node.get_attr('strategy').lower() == 'resource_unrolled':
273301
mult_params['dense_function'] = f'{namespace}::dense_resource_unrolled_{node.index}'
274302

hls4ml/templates/vitis/nnet_utils/nnet_sepconv1d_stream.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ void separable_conv_1d_cl(hls::stream<data_T> &data, hls::stream<res_T> &res,
8686
#pragma HLS DATAFLOW
8787

8888
hls::stream<dw_res_T> depthwise_res;
89-
unsigned res_depth = CONFIG_T::depthwise_config::out_width;
89+
constexpr unsigned res_depth = CONFIG_T::depthwise_config::out_width;
9090
#pragma HLS STREAM variable=depthwise_res depth=res_depth
9191

9292
depthwise_conv_1d_buffer_cl<data_T, dw_res_T, typename CONFIG_T::depthwise_config>(data, depthwise_res,

hls4ml/templates/vitis/nnet_utils/nnet_sepconv2d_stream.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ void separable_conv_2d_cl(hls::stream<data_T> &data, hls::stream<res_T> &res,
120120
#pragma HLS DATAFLOW
121121

122122
hls::stream<dw_res_T> depthwise_res;
123-
unsigned res_depth = CONFIG_T::depthwise_config::out_height * CONFIG_T::depthwise_config::out_width;
123+
constexpr unsigned res_depth = CONFIG_T::depthwise_config::out_height * CONFIG_T::depthwise_config::out_width;
124124
#pragma HLS STREAM variable=depthwise_res depth=res_depth
125125

126126
depthwise_conv_2d_buffer_cl<data_T, dw_res_T, typename CONFIG_T::depthwise_config>(data, depthwise_res,
Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
#ifndef NNET_DEPTHWISE_PRODUCT_H_
2+
#define NNET_DEPTHWISE_PRODUCT_H_
3+
4+
namespace nnet {
5+
6+
// Depthwise multiply-accumulate for the Latency strategy.
// Unlike a regular dense product, each input element is multiplied by exactly one
// weight (element-wise product, no cross-channel terms); the products are then
// reduced into one accumulator per output channel. The accumulation loop treats
// the input as channel-interleaved: element (ii * n_out + jj) belongs to channel
// jj — so n_in is presumably n_out * kernel_size; TODO confirm against the
// depthwise conv callers.
template <class data_T, class res_T, typename CONFIG_T>
void depthwise_product_latency(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out],
                               typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out],
                               typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) {
    #pragma HLS INLINE

    typename CONFIG_T::accum_t mult[CONFIG_T::n_in];
    typename CONFIG_T::accum_t acc[CONFIG_T::n_out];

    // Use a function_instantiate in case it helps to explicitly optimize unchanging weights/biases
    #pragma HLS function_instantiate variable=weights

    // Whole body is a single pipeline; the initiation interval (and therefore the
    // degree of hardware sharing) is governed by the reuse factor.
    #pragma HLS PIPELINE II=CONFIG_T::reuse_factor

    #pragma HLS ARRAY_PARTITION variable=mult complete

    // Cap the number of multiplier instances; the scheduler time-shares them.
    #pragma HLS ALLOCATION operation instances=mul limit=CONFIG_T::multiplier_limit

    // Do the matrix-multiply: one weight per input element (depthwise).
Product:
    for (int ii = 0; ii < CONFIG_T::n_in; ii++) {
        #pragma HLS UNROLL
        mult[ii] = CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::product(data[ii], weights[ii]);
    }

    // Initialize accumulator with input biases
ResetAccum:
    for (int iacc = 0; iacc < CONFIG_T::n_out; iacc++) {
        #pragma HLS UNROLL
        acc[iacc] = (typename CONFIG_T::accum_t)biases[iacc];
    }

    // Accumulate multiplication result: channel jj sums every n_out-th product.
Accum1:
    for (int ii = 0; ii < CONFIG_T::n_in / CONFIG_T::n_out; ii++) {
    Accum2:
        for (int jj = 0; jj < CONFIG_T::n_out; jj++) {
            int index = ii * CONFIG_T::n_out + jj;
            acc[jj] += mult[index];
        }
    }

    // Cast to "res_t" type
Result:
    for (int ires = 0; ires < CONFIG_T::n_out; ires++) {
        #pragma HLS UNROLL
        res[ires] = cast<data_T, res_T, CONFIG_T>(acc[ires]);
    }
}
55+
56+
// Depthwise multiply-accumulate for the Resource strategy, specialized for the
// case RF <= n_out (selected by the code generator when reuse_factor <= n_chan;
// for depthwise layers n_out is the channel count — TODO confirm).
// Runs the reuse loop for RF cycles; each cycle processes an unrolled block of
// block_factor = ceil(n_in / RF) elements, striding through the input by RF.
template <class data_T, class res_T, typename CONFIG_T>
void depthwise_product_resource_rf_leq_nout(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out],
                                            typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out],
                                            typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) {

    const int nin = CONFIG_T::n_in;
    const int nout = CONFIG_T::n_out;
    const int rufactor = CONFIG_T::reuse_factor;
    const int multfactor = MIN(CONFIG_T::n_in, rufactor);
    const int multiplier_limit = DIV_ROUNDUP(nin, multfactor);
    const int block_factor = DIV_ROUNDUP(nin, rufactor);

    // multiplier_limit == block_factor holds exactly when RF <= n_in.
    assert((multiplier_limit == block_factor) && "This function is correct only for RF <= N_CHAN");

    // Use a function_instantiate in case it helps to explicitly optimize unchanging weights/biases
    #pragma HLS function_instantiate variable=weights,biases
    // Reshape so each cycle can fetch block_factor elements in parallel.
    #pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor
    #pragma HLS ARRAY_RESHAPE variable=data block factor=block_factor

    #pragma HLS ARRAY_PARTITION variable=biases complete

    typename CONFIG_T::accum_t acc[nout];
    #pragma HLS ARRAY_PARTITION variable=acc complete

    // Initialize accumulators with the biases.
InitAccum:
    for (int iacc = 0; iacc < nout; iacc++) {
        #pragma HLS UNROLL
        acc[iacc] = (typename CONFIG_T::accum_t)biases[iacc];
    }

ReuseLoop:
    for (int ir = 0; ir < rufactor; ir++) {
        #pragma HLS PIPELINE II=1 rewind

        int in_index = ir;
        int out_index = ir;

    MultLoop:
        for (int im = 0; im < block_factor; im++) {
            #pragma HLS UNROLL

            acc[out_index] += static_cast<typename CONFIG_T::accum_t>(
                CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::product(data[in_index], weights[in_index]));

            // Channel-interleaved layout: advancing the flat input index by RF
            // advances the channel by RF as well. A single conditional
            // subtraction is enough to wrap because RF <= n_out here.
            in_index += rufactor;
            out_index += rufactor;

            if (out_index >= nout) {
                out_index -= nout;
            }
        }
    }

    // Cast to "res_t" type
Result:
    for (int ires = 0; ires < nout; ires++) {
        #pragma HLS UNROLL
        res[ires] = cast<data_T, res_T, CONFIG_T>(acc[ires]);
    }
}
115+
116+
template <class data_T, class res_T, typename CONFIG_T>
117+
void depthwise_product_resource_rf_gt_nout_rem0(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out],
118+
typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out],
119+
typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) {
120+
121+
const int nin = CONFIG_T::n_in;
122+
const int nout = CONFIG_T::n_out;
123+
const int rufactor = MIN(CONFIG_T::reuse_factor, nin);
124+
const int multfactor = MIN(nin, rufactor);
125+
const int multiplier_limit = DIV_ROUNDUP(nin, multfactor);
126+
const int block_factor = DIV_ROUNDUP(nin, rufactor);
127+
128+
assert((rufactor >= nout && rufactor % nout == 0) &&
129+
"This function is correct only for RF >= N_CHAN && RF % N_CHAN == 0");
130+
131+
#pragma HLS function_instantiate variable=weights,biases
132+
#pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor
133+
#pragma HLS ARRAY_RESHAPE variable=data block factor=block_factor
134+
135+
#pragma HLS ARRAY_PARTITION variable=biases complete
136+
137+
typename CONFIG_T::accum_t acc[nout];
138+
#pragma HLS ARRAY_PARTITION variable=acc complete
139+
140+
InitAccum:
141+
for (int iacc = 0; iacc < nout; iacc++) {
142+
#pragma HLS UNROLL
143+
acc[iacc] = (typename CONFIG_T::accum_t)biases[iacc];
144+
}
145+
146+
int outidx[rufactor];
147+
int outstep = 0;
148+
IndexLoop:
149+
for (int ir = 0; ir < rufactor; ir++) {
150+
outidx[ir] = outstep;
151+
outstep++;
152+
if (outstep == nout) {
153+
outstep = 0;
154+
}
155+
}
156+
157+
int out_index = 0;
158+
159+
ReuseLoop:
160+
for (int ir = 0; ir < rufactor; ir++) {
161+
#pragma HLS PIPELINE II=1 rewind
162+
163+
int in_index = ir;
164+
out_index = outidx[ir];
165+
166+
MultLoop:
167+
for (int im = 0; im < block_factor; im++) {
168+
#pragma HLS UNROLL
169+
170+
acc[out_index] += static_cast<typename CONFIG_T::accum_t>(
171+
CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::product(data[in_index], weights[in_index]));
172+
173+
in_index += rufactor;
174+
}
175+
}
176+
177+
// Cast to "res_t" type
178+
Result:
179+
for (int ires = 0; ires < nout; ires++) {
180+
#pragma HLS UNROLL
181+
res[ires] = cast<data_T, res_T, CONFIG_T>(acc[ires]);
182+
}
183+
}
184+
185+
// Depthwise multiply-accumulate for the Resource strategy, specialized for
// RF > n_out with RF % n_out != 0 (the general remainder case).
// Each reuse iteration starts at input offset ir; inside the unrolled block the
// input index advances by RF, so the output channel advances by
// remainder = RF % n_out each step (wrapped mod n_out).
template <class data_T, class res_T, typename CONFIG_T>
void depthwise_product_resource_gt_nout(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out],
                                        typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out],
                                        typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) {

    const int nin = CONFIG_T::n_in;
    const int nout = CONFIG_T::n_out;
    // RF is capped at n_in: a larger reuse factor cannot buy any more sharing.
    const int rufactor = MIN(CONFIG_T::reuse_factor, nin);
    const int block_factor = DIV_ROUNDUP(nin, rufactor);
    assert((rufactor > nout) && "This function is correct only for RF > N_CHAN");

    // Use a function_instantiate in case it helps to explicitly optimize unchanging weights/biases
    #pragma HLS function_instantiate variable=weights,biases
    // Reshape so each cycle can fetch block_factor elements in parallel.
    #pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor
    #pragma HLS ARRAY_RESHAPE variable=data block factor=block_factor

    #pragma HLS ARRAY_PARTITION variable=biases complete

    typename CONFIG_T::accum_t acc[nout];
    #pragma HLS ARRAY_PARTITION variable=acc complete

    // Initialize accumulators with the biases.
InitAccum:
    for (int iacc = 0; iacc < nout; iacc++) {
        #pragma HLS UNROLL
        acc[iacc] = (typename CONFIG_T::accum_t)biases[iacc];
    }

    // NOTE(review): uses the uncapped CONFIG_T::reuse_factor, not the capped
    // rufactor; the two differ only when reuse_factor > n_in — confirm intended.
    const int remainder = CONFIG_T::reuse_factor % nout;

    // Precompute ir -> (ir % nout) so the pipelined reuse loop avoids a runtime modulo.
    int outidx[rufactor];
    int outstep = 0;
IndexLoop:
    for (int ir = 0; ir < rufactor; ir++) {
        outidx[ir] = outstep;
        outstep++;
        if (outstep == nout) {
            outstep = 0;
        }
    }

ReuseLoop:
    for (int ir = 0; ir < rufactor; ir++) {
        #pragma HLS PIPELINE II=1 rewind

        int in_index = ir;
        int out_index = outidx[ir];

    MultLoop:
        for (int im = 0; im < block_factor; im++) {
            #pragma HLS UNROLL

            // Equivalent to out_index = in_index % nout, computed incrementally.
            acc[out_index] += static_cast<typename CONFIG_T::accum_t>(
                CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::product(data[in_index], weights[in_index]));

            // Advancing the input by RF advances the channel by RF % nout;
            // one conditional subtraction wraps it since remainder < nout.
            in_index += rufactor;
            out_index += remainder;
            if (out_index >= nout) {
                out_index -= nout;
            }
        }
    }

    // Cast to "res_t" type
Result:
    for (int ires = 0; ires < nout; ires++) {
        #pragma HLS UNROLL
        res[ires] = cast<data_T, res_T, CONFIG_T>(acc[ires]);
    }
}
254+
255+
// Kernel wrapper binding the Latency implementation to the
// DepthwiseDenseKernel interface (declared elsewhere in the project), so the
// code generator can select it via the generated `dense_function` parameter.
template <class data_T, class res_T, typename CONFIG_T>
class DepthwiseDenseLatency : public DepthwiseDenseKernel<data_T, res_T, CONFIG_T> {
  public:
    static void dense(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out],
                      typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out],
                      typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) {
        #pragma HLS INLINE
        depthwise_product_latency<data_T, res_T, CONFIG_T>(data, res, weights, biases);
    }
};
265+
266+
// Kernel wrapper for the Resource implementation, RF <= n_out case; selected
// by the code generator via the generated `dense_function` parameter.
template <class data_T, class res_T, typename CONFIG_T>
class DepthwiseDenseResource_rf_leq_nout : public DepthwiseDenseKernel<data_T, res_T, CONFIG_T> {
  public:
    static void dense(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out],
                      typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out],
                      typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) {
        #pragma HLS INLINE
        depthwise_product_resource_rf_leq_nout<data_T, res_T, CONFIG_T>(data, res, weights, biases);
    }
};
276+
277+
// Kernel wrapper for the Resource implementation, RF >= n_out && RF % n_out == 0
// case; selected by the code generator via the generated `dense_function` parameter.
template <class data_T, class res_T, typename CONFIG_T>
class DepthwiseDenseResource_rf_gt_nout_rem0 : public DepthwiseDenseKernel<data_T, res_T, CONFIG_T> {
  public:
    static void dense(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out],
                      typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out],
                      typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) {
        #pragma HLS INLINE
        depthwise_product_resource_rf_gt_nout_rem0<data_T, res_T, CONFIG_T>(data, res, weights, biases);
    }
};
287+
288+
// Kernel wrapper for the Resource implementation, RF > n_out (non-multiple)
// case; selected by the code generator via the generated `dense_function`
// parameter. Note the underlying function omits `rf_` in its name.
template <class data_T, class res_T, typename CONFIG_T>
class DepthwiseDenseResource_rf_gt_nout : public DepthwiseDenseKernel<data_T, res_T, CONFIG_T> {
  public:
    static void dense(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out],
                      typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out],
                      typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) {
        #pragma HLS INLINE
        depthwise_product_resource_gt_nout<data_T, res_T, CONFIG_T>(data, res, weights, biases);
    }
};
298+
299+
} // namespace nnet
300+
#endif

0 commit comments

Comments
 (0)