Skip to content

Commit 24e0fd5

Browse files
author
joeljonsson
committed
templated internals for layer operations
1 parent d5f5d47 commit 24e0fd5

File tree

4 files changed

+876
-607
lines changed

4 files changed

+876
-607
lines changed

pyapr/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from .converter import *
2929
from .io import *
3030
from .numerics import *
31-
from .viewer import *
31+
#from .viewer import *
3232

33-
#__all__ = ['data_containers', 'io', 'nn', 'viewer', 'converter', 'numerics']
34-
__all__ = ['data_containers', 'io', 'viewer', 'converter', 'numerics']
33+
__all__ = ['data_containers', 'io', 'nn', 'viewer', 'converter', 'numerics']

pyapr/nn/APRNet.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
class APRInputLayer:
17-
def __call__(self, apr_arr, parts_arr):
17+
def __call__(self, apr_arr, parts_arr, dtype=np.float32):
1818

1919
batch_size = len(apr_arr)
2020
#assert parts_arr.shape[0] == batch_size
@@ -32,7 +32,7 @@ def __call__(self, apr_arr, parts_arr):
3232
npart = apr.total_number_particles()
3333
npartmax = max(npart, npartmax)
3434

35-
x = np.empty((batch_size, nch, npartmax), dtype=np.float32)
35+
x = np.empty((batch_size, nch, npartmax), dtype=dtype)
3636

3737
for i in range(len(parts_arr)):
3838
#for j in range(nch):
@@ -57,7 +57,7 @@ def forward(ctx, intensities, weights, bias, aprs, level_deltas):
5757

5858
ctx.save_for_backward(intensities, weights, bias, torch.from_numpy(np.copy(dlevel)))
5959

60-
output = np.zeros(shape=(intensities.shape[0], weights.shape[0], intensities.shape[2]), dtype=np.float32)
60+
output = np.zeros(shape=(intensities.shape[0], weights.shape[0], intensities.shape[2]), dtype=intensities.data.numpy().dtype)
6161

6262
aprnn.convolve(aprs, intensities.data.numpy(), weights.data.numpy(), bias.data.numpy(), output, dlevel)
6363

@@ -71,7 +71,7 @@ def backward(ctx, grad_output):
7171

7272
dlevel = level_deltas.data.numpy()
7373

74-
d_input = np.zeros(input_features.shape, dtype=np.float32)
74+
d_input = np.zeros(input_features.shape, dtype=input_features.data.numpy().dtype)
7575
d_weights = np.empty(weights.shape, dtype=np.float32)
7676
d_bias = np.empty(bias.shape, dtype=np.float32)
7777

@@ -126,7 +126,7 @@ def forward(ctx, intensities, weights, bias, aprs, level_deltas):
126126

127127
ctx.save_for_backward(intensities, weights, bias, torch.from_numpy(np.copy(dlevel)))
128128

129-
output = np.zeros(shape=(intensities.shape[0], weights.shape[0], intensities.shape[2]), dtype=np.float32)
129+
output = np.zeros(shape=(intensities.shape[0], weights.shape[0], intensities.shape[2]), dtype=intensities.data.numpy().dtype)
130130

131131
aprnn.convolve3x3(aprs, intensities.data.numpy(), weights.data.numpy(), bias.data.numpy(), output, dlevel)
132132

@@ -140,7 +140,7 @@ def backward(ctx, grad_output):
140140

141141
dlevel = level_deltas.data.numpy()
142142

143-
d_input = np.zeros(input_features.shape, dtype=np.float32)
143+
d_input = np.zeros(input_features.shape, dtype=input_features.data.numpy().dtype)
144144
d_weights = np.empty(weights.shape, dtype=np.float32)
145145
d_bias = np.empty(bias.shape, dtype=np.float32)
146146

@@ -185,7 +185,7 @@ def forward(ctx, intensities, weights, bias, aprs, level_deltas):
185185

186186
ctx.save_for_backward(intensities, weights, bias, torch.from_numpy(np.copy(dlevel)))
187187

188-
output = np.zeros(shape=(intensities.shape[0], weights.shape[0], intensities.shape[2]), dtype=np.float32)
188+
output = np.zeros(shape=(intensities.shape[0], weights.shape[0], intensities.shape[2]), dtype=intensities.data.numpy().dtype)
189189

190190
aprnn.convolve1x1(aprs, intensities.data.numpy(), weights.data.numpy(), bias.data.numpy(), output, dlevel)
191191

@@ -198,12 +198,13 @@ def backward(ctx, grad_output):
198198
aprs = ctx.apr
199199

200200
dlevel = level_deltas.data.numpy()
201+
np_input = input_features.data.numpy()
201202

202-
d_input = np.zeros(input_features.shape, dtype=np.float32)
203+
d_input = np.zeros(input_features.shape, dtype=np_input.dtype)
203204
d_weights = np.empty(weights.shape, dtype=np.float32)
204205
d_bias = np.empty(bias.shape, dtype=np.float32)
205206

206-
aprnn.convolve1x1_backward(aprs, grad_output.data.numpy(), input_features.data.numpy(), weights.data.numpy(),
207+
aprnn.convolve1x1_backward(aprs, grad_output.data.numpy(), np_input, weights.data.numpy(),
207208
d_input, d_weights, d_bias, dlevel)
208209

209210
return torch.from_numpy(d_input), torch.from_numpy(d_weights), torch.from_numpy(d_bias), None, None
@@ -234,7 +235,7 @@ def forward(self, input_features, apr, level_deltas):
234235

235236
class APRMaxPoolFunction(Function):
236237
@staticmethod
237-
def forward(ctx, intensities, apr, level_deltas):
238+
def forward(ctx, intensities, apr, level_deltas, inc_dlvl):
238239

239240
dlevel = level_deltas.data.numpy()
240241

@@ -246,13 +247,14 @@ def forward(ctx, intensities, apr, level_deltas):
246247
npart = aprnn.number_particles_after_pool(apr[i], dlevel[i])
247248
npartmax = max(npartmax, npart)
248249

249-
output = -(np.finfo(np.float32).max / 2) * np.ones(shape=(intensities.shape[0], intensities.shape[1], npartmax), dtype=np.float32)
250+
output = -(np.finfo(np.float32).max / 2) * np.ones(shape=(intensities.shape[0], intensities.shape[1], npartmax), dtype=intensities.data.numpy().dtype)
250251
index_arr = -np.ones(output.shape, dtype=np.int64)
251252

252253
aprnn.max_pool(apr, intensities.data.numpy(), output, dlevel, index_arr)
253254

254-
for i in range(level_deltas.shape[0]):
255-
level_deltas[i] += 1
255+
if inc_dlvl:
256+
for i in range(level_deltas.shape[0]):
257+
level_deltas[i] += 1
256258

257259
ctx.max_indices = index_arr
258260

@@ -262,17 +264,19 @@ def forward(ctx, intensities, apr, level_deltas):
262264
def backward(ctx, grad_output):
263265

264266
max_indices = ctx.max_indices
265-
grad_input = np.zeros(ctx.input_shape, dtype=np.float32)
267+
grad_input = np.zeros(ctx.input_shape, dtype=grad_output.data.numpy().dtype)
266268

267269
aprnn.max_pool_backward(grad_output.data.numpy(), grad_input, max_indices)
268270

269-
return torch.from_numpy(grad_input), None, None
271+
return torch.from_numpy(grad_input), None, None, None
270272

271273

272274
class APRMaxPool(nn.Module):
273-
def __init__(self):
275+
def __init__(self, increment_level_delta=True):
274276
super(APRMaxPool, self).__init__()
275277

278+
self.increment_level_delta=increment_level_delta
279+
276280
def forward(self, input_features, apr, level_deltas):
277281

278-
return APRMaxPoolFunction.apply(input_features, apr, level_deltas)
282+
return APRMaxPoolFunction.apply(input_features, apr, level_deltas, self.increment_level_delta)

0 commit comments

Comments
 (0)