Skip to content

Commit 9e7da67

Browse files
committed
Add a neon::relu.
1 parent c86e7e2 commit 9e7da67

File tree

3 files changed

+86
-0
lines changed

3 files changed

+86
-0
lines changed

paddle/math/BaseMatrix.cu

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License. */
1717
#include <cmath>
1818
#include "BaseMatrix.h"
1919
#include "MathFunctions.h"
20+
#include "NEONFunctions.h"
2021
#include "SIMDFunctions.h"
2122
#include "hl_matrix_apply.cuh"
2223
#include "hl_matrix_base.cuh"
@@ -666,6 +667,13 @@ void BaseMatrixT<T>::relu(BaseMatrixT& b) {
666667
applyBinary(binary::Relu<T>(), b);
667668
}
668669

670+
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
671+
template <>
672+
void BaseMatrixT<float>::relu(BaseMatrixT& b) {
673+
neon::relu(data_, b.data_, height_ * width_);
674+
}
675+
#endif
676+
669677
DEFINE_MATRIX_BINARY_OP(ReluDerivative, a *= (b > 0.0f ? 1.0f : 0.0f));
670678
template <class T>
671679
void BaseMatrixT<T>::reluDerivative(BaseMatrixT& b) {

paddle/math/NEONFunctions.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
16+
17+
#include "NEONFunctions.h"
18+
#include <arm_neon.h>
19+
20+
namespace paddle {
21+
namespace neon {
22+
23+
// Element-wise ReLU: b[i] = a[i] > 0.0f ? a[i] : 0.0f for i in [0, len).
//
//   a   - input buffer, read for indices [0, len)
//   b   - output buffer, written for indices [0, len)
//   len - number of elements; a non-positive value writes nothing
//
// The vector path uses compare + bitwise-select rather than vmaxq_f32.
// vmaxq_f32 returns NaN whenever an input lane is NaN, while the scalar tail
// below maps NaN to 0, so with vmaxq_f32 the result for a NaN input depended
// on whether the element fell into the 16-wide body or the tail.
void relu(const float* a, float* b, int len) {
  const float32x4_t zero = vdupq_n_f32(0.f);

  // Main body: 16 floats (four 128-bit registers) per iteration.
  const int blocks = len / 16;
  for (int k = 0; k < blocks; k++, a += 16, b += 16) {
    // Load the whole 16-float group before storing anything, as the
    // original loop did.
    const float32x4_t ma0 = vld1q_f32(a);
    const float32x4_t ma1 = vld1q_f32(a + 4);
    const float32x4_t ma2 = vld1q_f32(a + 8);
    const float32x4_t ma3 = vld1q_f32(a + 12);

    // Lane mask is all-ones where a > 0 (false for NaN); select a there,
    // zero everywhere else.
    const float32x4_t mb0 = vbslq_f32(vcgtq_f32(ma0, zero), ma0, zero);
    const float32x4_t mb1 = vbslq_f32(vcgtq_f32(ma1, zero), ma1, zero);
    const float32x4_t mb2 = vbslq_f32(vcgtq_f32(ma2, zero), ma2, zero);
    const float32x4_t mb3 = vbslq_f32(vcgtq_f32(ma3, zero), ma3, zero);

    vst1q_f32(b, mb0);
    vst1q_f32(b + 4, mb1);
    vst1q_f32(b + 8, mb2);
    vst1q_f32(b + 12, mb3);
  }

  // Scalar tail: the remaining len % 16 elements (a and b already point
  // past the vectorized part).
  const int tail = len % 16;
  for (int i = 0; i < tail; i++) {
    b[i] = a[i] > 0.0f ? a[i] : 0.0f;
  }
}
51+
52+
} // namespace neon
53+
} // namespace paddle
54+
55+
#endif

paddle/math/NEONFunctions.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
namespace paddle {
18+
namespace neon {
19+
20+
// Element-wise ReLU over `len` floats: reads a[0..len) and writes the
// rectified (clamped below at zero) values to b[0..len).
// The definition in NEONFunctions.cpp is compiled only when __ARM_NEON__ or
// __ARM_NEON is defined.
void relu(const float* a, float* b, int len);
21+
22+
} // namespace neon
23+
} // namespace paddle

0 commit comments

Comments
 (0)