Skip to content

Commit 648cdab

Browse files
[ILUVATAR_GPU] Add logic to apply patches to python files in install script && Fix the segment fault that occurred after linking with the NCCL library. (#1762)
1 parent c4f3116 commit 648cdab

15 files changed

+545
-109
lines changed

backends/iluvatar_gpu/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ cmake_minimum_required(VERSION 3.10)
1515
set(PROJ_NAME "paddle-iluvatar-gpu")
1616
project(${PROJ_NAME} CXX C CUDA)
1717

18-
set(PLUGIN_VERSION "0.0.1")
1918
set(TARGET_NAME ${PROJ_NAME})
2019

2120
set(CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake")
@@ -35,6 +34,7 @@ include(external/xxhash)
3534
include(external/zlib)
3635
include(external/protobuf)
3736

37+
set(PLUGIN_VERSION ${PADDLE_VERSION})
3838
set(PROTO_FILE "${PADDLE_SOURCE_DIR}/paddle/phi/core/external_error.proto")
3939
get_filename_component(PROTO_WE "${PROTO_FILE}" NAME_WE)
4040

@@ -253,7 +253,6 @@ target_link_libraries(
253253
protobuf
254254
external_error_proto
255255
cuinfer
256-
# May cause a segment fault when the program exits
257256
nccl)
258257

259258
include_directories(BEFORE ${PADDLE_SOURCE_DIR})

backends/iluvatar_gpu/build_paddle.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ fi
2424
BUILD_TEST=${BUILD_TEST:-1}
2525
COREX_ARCH=${COREX_ARCH:-ivcore11}
2626
export CMAKE_CUDA_ARCHITECTURES=${COREX_ARCH}
27-
export PADDLE_VERSION=${PADDLE_VERSION:-3.0.0}
2827

2928
CURRENT_DIR=$(pwd)
3029
PADDLE_SOURCE_DIR="${CURRENT_DIR}/../../Paddle"

backends/iluvatar_gpu/install_paddle.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ PYTHON_PATH=$(which python3)
2020
PYTHON_DIST_PATH=${TARGET_DIR}/lib/python3/dist-packages
2121

2222
PKG_DIR="build_pip"
23+
PKGCPU_NAME="paddlepaddle"
2324
PKG_NAME="paddle_iluvatar_gpu"
2425

2526
if [[ ! -d ${PKG_DIR} ]]; then
@@ -43,6 +44,8 @@ if [[ "${TARGET_DIR}" != "" ]]; then
4344
rm -rf ./tmp
4445
echo "Paddle installed in ${PYTHON_DIST_PATH}; please add it to your PYTHONPATH."
4546
else
47+
${PYTHON_PATH} -m pip uninstall ${PKGCPU_NAME} -y
48+
${PYTHON_PATH} -m pip install --pre paddlepaddle -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/
4649
${PYTHON_PATH} -m pip uninstall ${PKG_NAME} -y
4750
${PYTHON_PATH} -m pip install ${PKG_DIR}/${latest_pkg} || exit
4851
fi
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "glog/logging.h"
16+
#include "paddle/common/flags.h"
17+
#include "paddle/phi/backends/gpu/gpu_context.h"
18+
#include "paddle/phi/backends/gpu/gpu_primitives.h"
19+
#include "paddle/phi/core/kernel_registry.h"
20+
#include "paddle/phi/kernels/c_embedding_kernel.h"
21+
#include "paddle/phi/kernels/funcs/eigen/common.h"
22+
#include "paddle/phi/kernels/funcs/embedding_grad.h"
23+
24+
COMMON_DECLARE_int64(embedding_deterministic);
25+
26+
namespace phi {
27+
28+
// Launch configuration used by the kernels in this file: threads per block
// and the upper bound on the number of blocks in one launch.
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;

// Returns how many blocks of kNumCUDAThreads threads are needed to cover N
// work items, capped at kNumMaximumNumBlocks.
//
// N is taken as a 64-bit value so that element counts above INT_MAX are not
// narrowed (and wrapped) at the call site, and so that the rounding term
// `N + kNumCUDAThreads - 1` cannot overflow for such counts. The result is at
// most kNumMaximumNumBlocks, so it always fits in an int.
static inline int NumBlocks(const int64_t N) {
  const int64_t blocks_needed = (N + kNumCUDAThreads - 1) / kNumCUDAThreads;
  return static_cast<int>(
      std::min<int64_t>(blocks_needed, kNumMaximumNumBlocks));
}
35+
36+
template <typename T, typename IndexT>
37+
__global__ void CEmbeddingGrad(T* table,
38+
const T* output,
39+
const IndexT* ids,
40+
const int rows,
41+
const int columns,
42+
const int64_t N,
43+
const int64_t start_idx,
44+
const int64_t end_idx,
45+
const int64_t limit) {
46+
CUDA_KERNEL_LOOP(i, limit) {
47+
size_t row = i / columns;
48+
size_t col = i % columns;
49+
auto id = ids[row];
50+
if (id >= start_idx && id < end_idx) {
51+
auto real_idx = id - start_idx;
52+
phi::CudaAtomicAdd(&table[real_idx * columns + col], output[i]);
53+
}
54+
}
55+
}
56+
57+
template <typename T, typename Context>
58+
void CEmbeddingGradKernel(const Context& dev_ctx,
59+
const DenseTensor& w,
60+
const DenseTensor& ids,
61+
const DenseTensor& out_grad,
62+
int64_t start_index,
63+
DenseTensor* w_grad) {
64+
int N = w_grad->dims()[0];
65+
int D = w_grad->dims()[1];
66+
int K = ids.numel();
67+
68+
auto limit = K * D;
69+
int blocks = NumBlocks(limit);
70+
int threads = kNumCUDAThreads;
71+
72+
const T* d_output = out_grad.data<T>();
73+
T* d_table = dev_ctx.template Alloc<T>(w_grad);
74+
75+
auto t = EigenVector<T>::Flatten(*w_grad);
76+
t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));
77+
78+
const auto& index_type = ids.dtype();
79+
if (FLAGS_embedding_deterministic == 1) {
80+
if (index_type == phi::DataType::INT32) {
81+
phi::funcs::LaunchEmbeddingGradDeterministicKernel<T, int32_t>(
82+
dev_ctx,
83+
ids.data<int32_t>(),
84+
d_output,
85+
d_table,
86+
N,
87+
D,
88+
K,
89+
start_index);
90+
return;
91+
} else if (index_type == phi::DataType::INT64) {
92+
phi::funcs::LaunchEmbeddingGradDeterministicKernel<T, int64_t>(
93+
dev_ctx,
94+
ids.data<int64_t>(),
95+
d_output,
96+
d_table,
97+
N,
98+
D,
99+
K,
100+
start_index);
101+
return;
102+
}
103+
} else {
104+
if (FLAGS_embedding_deterministic > 1) {
105+
VLOG(2) << "Run grad kernel of embedding with single thread.";
106+
blocks = 1;
107+
}
108+
const int64_t end_idx = start_index + N;
109+
if (index_type == phi::DataType::INT32) {
110+
CEmbeddingGrad<T, int32_t>
111+
<<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
112+
d_output,
113+
ids.data<int32_t>(),
114+
K,
115+
D,
116+
N,
117+
start_index,
118+
end_idx,
119+
limit);
120+
return;
121+
} else if (index_type == phi::DataType::INT64) {
122+
CEmbeddingGrad<T, int64_t>
123+
<<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
124+
d_output,
125+
ids.data<int64_t>(),
126+
K,
127+
D,
128+
N,
129+
start_index,
130+
end_idx,
131+
limit);
132+
return;
133+
}
134+
}
135+
PADDLE_THROW(common::errors::InvalidArgument(
136+
"The data type of Input(Ids) must be int32 or int64."));
137+
}
138+
139+
} // namespace phi
140+
141+
// Registers phi::CEmbeddingGradKernel as the `c_embedding_grad` kernel for the
// `iluvatar_gpu` backend with ALL_LAYOUT, for the element types float,
// bfloat16, float16 and complex<float>. The trailing `{}` is the (empty) body
// the registration macro expects.
PD_REGISTER_PLUGIN_KERNEL(c_embedding_grad,
142+
iluvatar_gpu,
143+
ALL_LAYOUT,
144+
phi::CEmbeddingGradKernel,
145+
float,
146+
phi::dtype::bfloat16,
147+
phi::dtype::float16,
148+
phi::dtype::complex<float>) {}

backends/iluvatar_gpu/kernels/cuda_kernels/c_embedding_grad_kernel_register.cc

Lines changed: 0 additions & 25 deletions
This file was deleted.
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/backends/gpu/gpu_context.h"
16+
#include "paddle/phi/core/kernel_registry.h"
17+
#include "paddle/phi/kernels/c_embedding_kernel.h"
18+
19+
namespace phi {
20+
21+
// Launch configuration used by the kernel in this file: threads per block and
// the upper bound on the number of blocks in one launch.
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;

// Returns how many blocks of kNumCUDAThreads threads are needed to cover N
// work items, capped at kNumMaximumNumBlocks.
//
// N is taken as a 64-bit value so that a size_t element count above INT_MAX
// is not narrowed (and wrapped) at the call site, and so that the rounding
// term `N + kNumCUDAThreads - 1` cannot overflow for such counts. The result
// is at most kNumMaximumNumBlocks, so it always fits in an int.
static inline int NumBlocks(const int64_t N) {
  const int64_t blocks_needed = (N + kNumCUDAThreads - 1) / kNumCUDAThreads;
  return static_cast<int>(
      std::min<int64_t>(blocks_needed, kNumMaximumNumBlocks));
}
28+
29+
template <typename T, typename IndexT>
30+
__global__ void CEmbedding(T* out,
31+
const T* table,
32+
const IndexT* ids,
33+
const int rows,
34+
const int columns,
35+
const int64_t N,
36+
const int64_t start_idx,
37+
const int64_t end_idx,
38+
const int64_t limit,
39+
const int64_t vocab_size) {
40+
CUDA_KERNEL_LOOP(i, limit) {
41+
size_t row = i / columns;
42+
size_t col = i % columns;
43+
auto id = ids[row];
44+
45+
PADDLE_ENFORCE(
46+
id >= 0 && (vocab_size < 0 || id < vocab_size),
47+
"The index is out of bounds, "
48+
"please check whether the dimensions of index and "
49+
"input meet the requirements. It should "
50+
"be less than [%d] and greater than or equal to 0, but received [%d]",
51+
vocab_size,
52+
id);
53+
if (id >= start_idx && id < end_idx) {
54+
auto real_idx = id - start_idx;
55+
out[i] = table[real_idx * columns + col];
56+
} else {
57+
out[i] = static_cast<T>(0);
58+
}
59+
}
60+
}
61+
62+
template <typename T, typename Context>
63+
void CEmbeddingKernel(const Context& dev_ctx,
64+
const DenseTensor& w,
65+
const DenseTensor& ids,
66+
int64_t start_index,
67+
int64_t vocab_size,
68+
DenseTensor* out) {
69+
size_t N = w.dims()[0];
70+
size_t D = w.dims()[1];
71+
size_t K = ids.numel();
72+
73+
const int64_t end_idx = start_index + N;
74+
75+
auto* table = w.data<T>();
76+
auto* output = dev_ctx.template Alloc<T>(out);
77+
78+
auto limit = K * D;
79+
int blocks = NumBlocks(limit);
80+
int threads = kNumCUDAThreads;
81+
82+
const auto& index_type = ids.dtype();
83+
if (index_type == phi::DataType::INT32) {
84+
CEmbedding<T, int32_t>
85+
<<<blocks, threads, 0, dev_ctx.stream()>>>(output,
86+
table,
87+
ids.data<int32_t>(),
88+
K,
89+
D,
90+
N,
91+
start_index,
92+
end_idx,
93+
limit,
94+
vocab_size);
95+
96+
} else if (index_type == phi::DataType::INT64) {
97+
CEmbedding<T, int64_t>
98+
<<<blocks, threads, 0, dev_ctx.stream()>>>(output,
99+
table,
100+
ids.data<int64_t>(),
101+
K,
102+
D,
103+
N,
104+
start_index,
105+
end_idx,
106+
limit,
107+
vocab_size);
108+
} else {
109+
PADDLE_THROW(common::errors::Unavailable(
110+
"GPU c_embedding ids only support int32 or int64."));
111+
}
112+
}
113+
} // namespace phi
114+
115+
// Registers phi::CEmbeddingKernel as the `c_embedding` kernel for the
// `iluvatar_gpu` backend with ALL_LAYOUT, for the element types float,
// bfloat16, float16 and complex<float>. The trailing `{}` is the (empty) body
// the registration macro expects.
PD_REGISTER_PLUGIN_KERNEL(c_embedding,
116+
iluvatar_gpu,
117+
ALL_LAYOUT,
118+
phi::CEmbeddingKernel,
119+
float,
120+
phi::dtype::bfloat16,
121+
phi::dtype::float16,
122+
phi::dtype::complex<float>) {}

backends/iluvatar_gpu/kernels/cuda_kernels/c_embedding_kernel_register.cc

Lines changed: 0 additions & 25 deletions
This file was deleted.

backends/iluvatar_gpu/kernels/cuda_kernels/flash_attn_kernel.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ void FlashAttnUnpaddedBaseKernel(
152152
flashAttnInfo.softmax_scale = std::sqrt(1.f / head_size);
153153
flashAttnInfo.dropout_prob = is_test ? 0.0f : dropout;
154154
flashAttnInfo.is_causal = causal;
155+
flashAttnInfo.causal_mode = 1;
155156
// flashAttnInfo.is_alibi = use_alibi;
156157
// flashAttnInfo.alibi_mode = alibi_mode;
157158
flashAttnInfo.return_softmax_lse = true;

0 commit comments

Comments
 (0)