 
 
 class APRInputLayer:
-    def __call__(self, apr_arr, parts_arr):
+    def __call__(self, apr_arr, parts_arr, dtype=np.float32):
 
         batch_size = len(apr_arr)
         #assert parts_arr.shape[0] == batch_size
@@ -32,7 +32,7 @@ def __call__(self, apr_arr, parts_arr):
             npart = apr.total_number_particles()
             npartmax = max(npart, npartmax)
 
-        x = np.empty((batch_size, nch, npartmax), dtype=np.float32)
+        x = np.empty((batch_size, nch, npartmax), dtype=dtype)
 
         for i in range(len(parts_arr)):
             #for j in range(nch):
@@ -57,7 +57,7 @@ def forward(ctx, intensities, weights, bias, aprs, level_deltas):
 
         ctx.save_for_backward(intensities, weights, bias, torch.from_numpy(np.copy(dlevel)))
 
-        output = np.zeros(shape=(intensities.shape[0], weights.shape[0], intensities.shape[2]), dtype=np.float32)
+        output = np.zeros(shape=(intensities.shape[0], weights.shape[0], intensities.shape[2]), dtype=intensities.data.numpy().dtype)
 
         aprnn.convolve(aprs, intensities.data.numpy(), weights.data.numpy(), bias.data.numpy(), output, dlevel)
 
@@ -71,7 +71,7 @@ def backward(ctx, grad_output):
 
         dlevel = level_deltas.data.numpy()
 
-        d_input = np.zeros(input_features.shape, dtype=np.float32)
+        d_input = np.zeros(input_features.shape, dtype=input_features.data.numpy().dtype)
         d_weights = np.empty(weights.shape, dtype=np.float32)
         d_bias = np.empty(bias.shape, dtype=np.float32)
 
@@ -126,7 +126,7 @@ def forward(ctx, intensities, weights, bias, aprs, level_deltas):
 
         ctx.save_for_backward(intensities, weights, bias, torch.from_numpy(np.copy(dlevel)))
 
-        output = np.zeros(shape=(intensities.shape[0], weights.shape[0], intensities.shape[2]), dtype=np.float32)
+        output = np.zeros(shape=(intensities.shape[0], weights.shape[0], intensities.shape[2]), dtype=intensities.data.numpy().dtype)
 
         aprnn.convolve3x3(aprs, intensities.data.numpy(), weights.data.numpy(), bias.data.numpy(), output, dlevel)
 
@@ -140,7 +140,7 @@ def backward(ctx, grad_output):
 
         dlevel = level_deltas.data.numpy()
 
-        d_input = np.zeros(input_features.shape, dtype=np.float32)
+        d_input = np.zeros(input_features.shape, dtype=input_features.data.numpy().dtype)
         d_weights = np.empty(weights.shape, dtype=np.float32)
         d_bias = np.empty(bias.shape, dtype=np.float32)
 
@@ -185,7 +185,7 @@ def forward(ctx, intensities, weights, bias, aprs, level_deltas):
 
         ctx.save_for_backward(intensities, weights, bias, torch.from_numpy(np.copy(dlevel)))
 
-        output = np.zeros(shape=(intensities.shape[0], weights.shape[0], intensities.shape[2]), dtype=np.float32)
+        output = np.zeros(shape=(intensities.shape[0], weights.shape[0], intensities.shape[2]), dtype=intensities.data.numpy().dtype)
 
         aprnn.convolve1x1(aprs, intensities.data.numpy(), weights.data.numpy(), bias.data.numpy(), output, dlevel)
 
@@ -198,12 +198,13 @@ def backward(ctx, grad_output):
         aprs = ctx.apr
 
         dlevel = level_deltas.data.numpy()
+        np_input = input_features.data.numpy()
 
-        d_input = np.zeros(input_features.shape, dtype=np.float32)
+        d_input = np.zeros(input_features.shape, dtype=np_input.dtype)
         d_weights = np.empty(weights.shape, dtype=np.float32)
         d_bias = np.empty(bias.shape, dtype=np.float32)
 
-        aprnn.convolve1x1_backward(aprs, grad_output.data.numpy(), input_features.data.numpy(), weights.data.numpy(),
+        aprnn.convolve1x1_backward(aprs, grad_output.data.numpy(), np_input, weights.data.numpy(),
                                    d_input, d_weights, d_bias, dlevel)
 
         return torch.from_numpy(d_input), torch.from_numpy(d_weights), torch.from_numpy(d_bias), None, None
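Note on the dtype changes above: each forward/backward now allocates its NumPy output and gradient buffers with the dtype of the incoming tensor instead of a hard-coded np.float32, so non-float32 particle data keeps its precision through the aprnn calls. A minimal sketch of the pattern in isolation; the helper name and out_channels argument are illustrative, not part of this file:

    import numpy as np

    def allocate_output_like(intensities, out_channels):
        # Convert once; .data.numpy() shares memory with the CPU tensor.
        np_input = intensities.data.numpy()
        # The buffer inherits the input dtype, so torch.from_numpy() on the
        # filled result returns a tensor of the same dtype as the input.
        return np.zeros((np_input.shape[0], out_channels, np_input.shape[2]),
                        dtype=np_input.dtype)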
@@ -234,7 +235,7 @@ def forward(self, input_features, apr, level_deltas):
 
 class APRMaxPoolFunction(Function):
     @staticmethod
-    def forward(ctx, intensities, apr, level_deltas):
+    def forward(ctx, intensities, apr, level_deltas, inc_dlvl):
 
         dlevel = level_deltas.data.numpy()
 
@@ -246,13 +247,14 @@ def forward(ctx, intensities, apr, level_deltas):
             npart = aprnn.number_particles_after_pool(apr[i], dlevel[i])
             npartmax = max(npartmax, npart)
 
-        output = -(np.finfo(np.float32).max / 2) * np.ones(shape=(intensities.shape[0], intensities.shape[1], npartmax), dtype=np.float32)
+        output = -(np.finfo(np.float32).max / 2) * np.ones(shape=(intensities.shape[0], intensities.shape[1], npartmax), dtype=intensities.data.numpy().dtype)
         index_arr = -np.ones(output.shape, dtype=np.int64)
 
         aprnn.max_pool(apr, intensities.data.numpy(), output, dlevel, index_arr)
 
-        for i in range(level_deltas.shape[0]):
-            level_deltas[i] += 1
+        if inc_dlvl:
+            for i in range(level_deltas.shape[0]):
+                level_deltas[i] += 1
 
         ctx.max_indices = index_arr
 
@@ -262,17 +264,19 @@ def forward(ctx, intensities, apr, level_deltas):
     def backward(ctx, grad_output):
 
         max_indices = ctx.max_indices
-        grad_input = np.zeros(ctx.input_shape, dtype=np.float32)
+        grad_input = np.zeros(ctx.input_shape, dtype=grad_output.data.numpy().dtype)
 
         aprnn.max_pool_backward(grad_output.data.numpy(), grad_input, max_indices)
 
-        return torch.from_numpy(grad_input), None, None
+        return torch.from_numpy(grad_input), None, None, None
 
 
 class APRMaxPool(nn.Module):
-    def __init__(self):
+    def __init__(self, increment_level_delta=True):
         super(APRMaxPool, self).__init__()
 
+        self.increment_level_delta = increment_level_delta
+
     def forward(self, input_features, apr, level_deltas):
 
-        return APRMaxPoolFunction.apply(input_features, apr, level_deltas)
+        return APRMaxPoolFunction.apply(input_features, apr, level_deltas, self.increment_level_delta)
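A hedged usage sketch of the new arguments; apr_list, parts_list, features and level_deltas below are placeholders for objects produced elsewhere, and the return value of APRInputLayer is not shown in this diff:

    input_layer = APRInputLayer()
    # Buffers are now allocated in the requested dtype instead of always float32.
    inputs = input_layer(apr_list, parts_list, dtype=np.float64)

    # With increment_level_delta=False the pool leaves level_deltas untouched,
    # so the caller decides when (or whether) to advance the deltas.
    pool = APRMaxPool(increment_level_delta=False)
    # pooled = pool(features, apr_list, level_deltas)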