Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions custom_ops/xpu_ops/src/ops/init_signal_layerwise.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "ops/remote_cache_kv_ipc.h"
#include "paddle/extension.h"

// Fall back to plain PD_BUILD_OP (with a "static_op_" name prefix) when this
// Paddle version does not provide the static-graph registration macro.
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

// Shorthand for the layerwise cache-write-complete signal metadata struct.
// NOTE(review): this alias is not referenced anywhere in this file —
// presumably kept for symmetry with open_shm_and_get_meta_signal.cc; confirm
// before removing.
using cache_write_complete_signal_type =
    RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data;

// Materializes the KV signal metadata on the host and stamps the current
// layer id into slot 0, returning the updated host copy.
// NOTE(review): copy_to(..., /*blocking=*/false) requests a non-blocking
// copy; the immediate write through data<int64_t>() assumes the copy has
// already completed (safe if the source tensor lives on CPU) — confirm with
// callers.
paddle::Tensor InitSignalLayerwiseFunc(const paddle::Tensor& kv_signal_metadata,
                                       const int layer_id) {
  paddle::Tensor host_metadata =
      kv_signal_metadata.copy_to(paddle::CPUPlace(), false);
  int64_t* slots = host_metadata.data<int64_t>();
  slots[0] = static_cast<int64_t>(layer_id);
  return host_metadata;
}

// Op kernel entry point: forwards to InitSignalLayerwiseFunc and packs the
// result as the single-output vector expected by the custom-op framework.
std::vector<paddle::Tensor> InitSignalLayerwise(
    const paddle::Tensor& kv_signal_metadata, const int layer_id) {
  std::vector<paddle::Tensor> outputs;
  outputs.emplace_back(InitSignalLayerwiseFunc(kv_signal_metadata, layer_id));
  return outputs;
}

// InferShape: the output metadata tensor keeps the input's shape unchanged;
// the layer index has no effect on shape.
std::vector<std::vector<int64_t>> InitSignalLayerwiseShape(
    const std::vector<int64_t>& kv_signal_metadata_shape, const int layer_id) {
  (void)layer_id;  // shape is independent of the layer index
  std::vector<std::vector<int64_t>> shapes;
  shapes.push_back(kv_signal_metadata_shape);
  return shapes;
}

// InferDtype: the signal metadata is always carried as INT64, regardless of
// the declared input dtype or the layer index.
std::vector<paddle::DataType> InitSignalLayerwiseDtype(
    const paddle::DataType& kv_signal_metadata_dtype, const int layer_id) {
  (void)kv_signal_metadata_dtype;
  (void)layer_id;
  return std::vector<paddle::DataType>{paddle::DataType::INT64};
}

// Register the custom op (static-graph compatible via PD_BUILD_STATIC_OP).
// Input:  kv_signal_metadata      — INT64 signal metadata tensor.
// Output: kv_signal_metadata_out  — host copy with slot 0 set to layer_id.
// Attr:   layer_id                — index of the layer being initialized.
PD_BUILD_STATIC_OP(init_signal_layerwise)
    .Inputs({"kv_signal_metadata"})
    .Outputs({"kv_signal_metadata_out"})
    .Attrs({"layer_id: int"})
    .SetKernelFn(PD_KERNEL(InitSignalLayerwise))
    .SetInferShapeFn(PD_INFER_SHAPE(InitSignalLayerwiseShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(InitSignalLayerwiseDtype));
27 changes: 16 additions & 11 deletions custom_ops/xpu_ops/src/ops/open_shm_and_get_meta_signal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,27 @@
#include "ops/utility/env.h"
#include "paddle/extension.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

XPU_DECLARE_BOOL(fmt_write_cache_completed_signal, false);

using cache_write_complete_signal_type =
RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data;

paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank,
const int device_id,
const bool keep_pd_step_flag) {
cache_write_complete_signal_type kv_signal_metadata;
const char *fmt_write_cache_completed_signal_str =
const char* fmt_write_cache_completed_signal_str =
std::getenv("FLAGS_fmt_write_cache_completed_signal");
if (fmt_write_cache_completed_signal_str &&
(std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 ||
std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) {
kv_signal_metadata =
RemoteCacheKvIpc::open_shm_and_get_complete_signal_meta_data(
rank, keep_pd_step_flag);
rank, device_id, keep_pd_step_flag);
}

auto kv_signal_metadata_out =
Expand All @@ -46,9 +51,9 @@ paddle::Tensor OpenShmAndGetMetaSignalFunc(const int rank,
return kv_signal_metadata_out;
}

void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor,
const paddle::Tensor &seq_lens_this_time_tensor,
const paddle::Tensor &seq_lens_decoder_tensor,
void InitKVSignalPerQuery(const paddle::Tensor& seq_lens_encoder_tensor,
const paddle::Tensor& seq_lens_this_time_tensor,
const paddle::Tensor& seq_lens_decoder_tensor,
const int rank,
const int num_layers) {
if (FLAGS_fmt_write_cache_completed_signal) {
Expand All @@ -68,24 +73,24 @@ void InitKVSignalPerQuery(const paddle::Tensor &seq_lens_encoder_tensor,
}

std::vector<paddle::Tensor> OpenShmAndGetMetaSignal(
const int rank, const bool keep_pd_step_flag) {
return {OpenShmAndGetMetaSignalFunc(rank, keep_pd_step_flag)};
const int rank, const int device_id, const bool keep_pd_step_flag) {
return {OpenShmAndGetMetaSignalFunc(rank, device_id, keep_pd_step_flag)};
}

std::vector<std::vector<int64_t>> OpenShmAndGetMetaSignalShape(
const int rank, const bool keep_pd_step_flag) {
const int rank, const int device_id, const bool keep_pd_step_flag) {
return {{3}};
}

std::vector<paddle::DataType> OpenShmAndGetMetaSignalDtype(
const int rank, const bool keep_pd_step_flag) {
const int rank, const int device_id, const bool keep_pd_step_flag) {
return {paddle::DataType::INT64};
}

PD_BUILD_OP(open_shm_and_get_meta_signal)
PD_BUILD_STATIC_OP(open_shm_and_get_meta_signal)
.Inputs({})
.Outputs({"kv_signal_metadata"})
.Attrs({"rank: int", "keep_pd_step_flag: bool"})
.Attrs({"rank: int", "device_id: int", "keep_pd_step_flag: bool"})
.SetKernelFn(PD_KERNEL(OpenShmAndGetMetaSignal))
.SetInferShapeFn(PD_INFER_SHAPE(OpenShmAndGetMetaSignalShape))
.SetInferDtypeFn(PD_INFER_DTYPE(OpenShmAndGetMetaSignalDtype));
11 changes: 6 additions & 5 deletions custom_ops/xpu_ops/src/ops/remote_cache_kv_ipc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ bool RemoteCacheKvIpc::kv_complete_signal_shmem_opened = false;

RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data
RemoteCacheKvIpc::open_shm_and_get_complete_signal_meta_data(
const int rank_id, const bool keep_pd_step_flag) {
const int rank_id, const int device_id, const bool keep_pd_step_flag) {
if (RemoteCacheKvIpc::kv_complete_signal_shmem_opened) {
if (keep_pd_step_flag) {
return RemoteCacheKvIpc::kv_complete_signal_meta_data;
Expand All @@ -47,12 +47,13 @@ RemoteCacheKvIpc::open_shm_and_get_complete_signal_meta_data(
std::string iflags_server_uuid_env_str(iflags_server_uuid_env_p);
flags_server_uuid = iflags_server_uuid_env_str;
}

std::string step_shm_name =
("splitwise_complete_prefilled_step_" + std::to_string(rank_id) + "_" +
flags_server_uuid);
("splitwise_complete_prefilled_step_" + std::to_string(rank_id) + "." +
std::to_string(device_id));
std::string layer_shm_name =
("splitwise_complete_prefilled_layer_" + std::to_string(rank_id) + "_" +
flags_server_uuid);
("splitwise_complete_prefilled_layer_" + std::to_string(rank_id) + "." +
std::to_string(device_id));
if (const char* use_ep = std::getenv("ENABLE_EP_DP")) {
if (std::strcmp(use_ep, "1") == 0) {
step_shm_name = "splitwise_complete_prefilled_step_tprank0_dprank" +
Expand Down
1 change: 1 addition & 0 deletions custom_ops/xpu_ops/src/ops/remote_cache_kv_ipc.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ struct RemoteCacheKvIpc {

static RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_meta_data
open_shm_and_get_complete_signal_meta_data(const int rank_id,
const int device_id,
const bool keep_pd_step_flag);
static void save_cache_kv_complete_signal_layerwise(void* meta_data);
static void save_cache_kv_complete_signal_layerwise_per_query(
Expand Down
14 changes: 7 additions & 7 deletions custom_ops/xpu_ops/src/ops/share_external_data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,26 @@
#include "xpu/plugin.h"
#include "xpu_multiprocess.h" // NOLINT(build/include_subdir)

std::vector<paddle::Tensor> ShareExternalData(const paddle::Tensor &input,
std::vector<paddle::Tensor> ShareExternalData(const paddle::Tensor& input,
const std::string shm_name,
const std::vector<int> &shape,
const std::vector<int>& shape,
bool use_ipc) {
sharedMemoryInfo info;
int ret = sharedMemoryOpen(shm_name.c_str(), sizeof(shmStruct), &info);
PD_CHECK(ret == 0, "sharedMemoryOpen failed");
volatile shmStruct *shm = static_cast<volatile shmStruct *>(info.addr);
void *data_ptr_addr = nullptr;
volatile shmStruct* shm = static_cast<volatile shmStruct*>(info.addr);
void* data_ptr_addr = nullptr;
if (use_ipc) {
#if XPURT_VERSION_MAJOR == 5
int ret = xpu_ipc_open_memhandle(&data_ptr_addr,
*(XPUIpcMemHandle *)&shm->memHandle,
*(XPUIpcMemHandle*)&shm->memHandle,
0x01); // NOLINT
PD_CHECK(ret == XPU_SUCCESS, "xpu_ipc_open_memhandle failed");
PD_CHECK(ret == XPU_SUCCESS, shm_name, " xpu_ipc_open_memhandle failed");
#elif XPURT_VERSION_MAJOR == 4
PD_THROW("kl2 not support prefix cache");
#endif
} else {
data_ptr_addr = reinterpret_cast<void *>(shm->data_ptr_addr);
data_ptr_addr = reinterpret_cast<void*>(shm->data_ptr_addr);
}

phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
Expand Down
31 changes: 22 additions & 9 deletions fastdeploy/cache_manager/cache_messager.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,20 @@
import numpy as np
import paddle

from fastdeploy.cache_manager.ops import (
get_output_kv_signal,
get_peer_mem_addr,
memory_allocated,
set_data_ipc,
set_device,
)
from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager
from fastdeploy.config import SpeculativeConfig
from fastdeploy.inter_communicator import (
EngineWorkerQueue,
IPCSignal,
shared_memory_exists,
)
from fastdeploy.model_executor.ops.gpu import get_output_kv_signal, set_data_ipc
from fastdeploy.utils import envs, get_logger

logger = get_logger("cache_messager", "cache_messager.log")
Expand Down Expand Up @@ -157,16 +163,20 @@ def __init__(
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
cache_k.append(key_cache)
cache_v.append(val_cache)
cache_k_ptr_list.append(key_cache.data_ptr())
cache_v_ptr_list.append(val_cache.data_ptr())
if paddle.is_compiled_with_xpu():
cache_k_ptr_list.append(get_peer_mem_addr(key_cache.data_ptr()))
cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
else:
cache_k_ptr_list.append(key_cache.data_ptr())
cache_v_ptr_list.append(val_cache.data_ptr())
cache_k_ptr_list = np.array(cache_k_ptr_list)
cache_v_ptr_list = np.array(cache_v_ptr_list)

# 2. initialize the block_bytes
cache_shape = key_cache.shape
max_block_num = cache_shape[0]
block_bytes = math.prod(cache_shape[1:])
if key_cache.dtype == paddle.bfloat16:
if key_cache.dtype == paddle.bfloat16 or key_cache.dtype == paddle.float16:
block_bytes *= 2
logger.info(
f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
Expand Down Expand Up @@ -452,8 +462,12 @@ def __init__(
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
cache_k.append(key_cache)
cache_v.append(val_cache)
cache_k_ptr_list.append(key_cache.data_ptr())
cache_v_ptr_list.append(val_cache.data_ptr())
if paddle.is_compiled_with_xpu():
cache_k_ptr_list.append(get_peer_mem_addr(key_cache.data_ptr()))
cache_v_ptr_list.append(get_peer_mem_addr(val_cache.data_ptr()))
else:
cache_k_ptr_list.append(key_cache.data_ptr())
cache_v_ptr_list.append(val_cache.data_ptr())
cache_k_ptr_list = np.array(cache_k_ptr_list)
cache_v_ptr_list = np.array(cache_v_ptr_list)

Expand Down Expand Up @@ -763,7 +777,7 @@ def _handle_connect_task(self):
def main():
device = args.device_id
rank = args.rank
paddle.set_device(f"gpu:{device}")
set_device(device)
cache_type = args.cache_dtype
speculative_config = SpeculativeConfig(args.speculative_config)
num_extra_layers = speculative_config.num_extra_cache_layer
Expand Down Expand Up @@ -823,7 +837,7 @@ def main():
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in gpu_cache_kvs.items()])
logger.info(f"device :{device}")
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
logger.info(f"done init cache (full) gmem alloc : {memory_allocated}")

if envs.ENABLE_V1_KVCACHE_SCHEDULER:
cache_messager = CacheMessagerV1(
Expand Down Expand Up @@ -875,7 +889,6 @@ def main():
args = parse_args()
rank_id = args.rank + args.local_data_parallel_id * args.mp_num
logger = get_logger("cache_messager", f"cache_messager_rank{rank_id}.log")

logger.info("create cache messager...")
logger.info(f"{args}")
main()
32 changes: 32 additions & 0 deletions fastdeploy/cache_manager/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,27 @@
from fastdeploy.model_executor.ops.gpu import (
cuda_host_alloc,
cuda_host_free,
get_data_ptr_ipc,
get_output_kv_signal,
ipc_sent_key_value_cache_by_remote_ptr,
ipc_sent_key_value_cache_by_remote_ptr_block_sync,
set_data_ipc,
share_external_data,
swap_cache_all_layers,
unset_data_ipc,
)

memory_allocated = paddle.device.cuda.memory_allocated

def get_peer_mem_addr(*args, **kwargs):
raise RuntimeError("CUDA no need of get_peer_mem_addr!")

elif current_platform.is_xpu():
from fastdeploy.model_executor.ops.xpu import (
cuda_host_alloc,
cuda_host_free,
get_output_kv_signal,
get_peer_mem_addr,
set_data_ipc,
share_external_data,
swap_cache_all_layers,
Expand All @@ -25,6 +35,15 @@
unset_data_ipc = None
memory_allocated = paddle.device.xpu.memory_allocated

def get_data_ptr_ipc(*args, **kwargs):
raise RuntimeError("XPU get_data_ptr_ipc UNIMPLENENTED!")

def ipc_sent_key_value_cache_by_remote_ptr(*args, **kwargs):
raise RuntimeError("XPU ipc_sent_key_value_cache_by_remote_ptr UNIMPLENENTED")

def ipc_sent_key_value_cache_by_remote_ptr_block_sync(*args, **kwargs):
raise RuntimeError("XPU No ipc_sent_key_value_cache_by_remote_ptr UNIMPLENENTED")

else:
raise RuntimeError("Prefix cache ops only supported CUDA nor XPU platform ")

Expand All @@ -48,6 +67,13 @@ def share_external_data_(cache, cache_name, cache_shape, use_ipc):
return cache


def get_all_visible_devices():
if current_platform.is_xpu():
return "XPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
else:
return "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"


__all__ = [
"cuda_host_alloc",
"cuda_host_free",
Expand All @@ -57,4 +83,10 @@ def share_external_data_(cache, cache_name, cache_shape, use_ipc):
"unset_data_ipc", # XPU是 None
"set_device",
"memory_allocated",
"get_output_kv_signal",
"get_data_ptr_ipc",
"ipc_sent_key_value_cache_by_remote_ptr",
"ipc_sent_key_value_cache_by_remote_ptr_block_sync",
"get_peer_mem_addr",
"get_all_visible_devices",
]
9 changes: 7 additions & 2 deletions fastdeploy/cache_manager/prefix_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from fastdeploy import envs
from fastdeploy.cache_manager.cache_data import BlockNode, CacheStatus
from fastdeploy.cache_manager.cache_metrics import CacheMetrics
from fastdeploy.cache_manager.ops import get_all_visible_devices
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, PrefixTreeStatus
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.utils import get_logger
Expand Down Expand Up @@ -243,9 +244,11 @@ def launch_cache_manager(
# Run command to launch cache transfer managers
log_dir = envs.FD_LOG_DIR
cache_manager_processes = []
visible_devices = get_all_visible_devices()
for i in range(tensor_parallel_size):
launch_cmd = (
"FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
"FLAGS_allocator_strategy=auto_growth "
+ visible_devices
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
+ f" FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}"
+ f" {sys.executable} {py_path}"
Expand Down Expand Up @@ -328,9 +331,11 @@ def launch_cache_messager(
py_path = os.path.join(current_dir_path, filename)
log_dir = envs.FD_LOG_DIR
cache_messager_processes = []
visible_devices = get_all_visible_devices()
for i in range(tensor_parallel_size):
launch_cmd = (
"FLAGS_allocator_strategy=auto_growth CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
"FLAGS_allocator_strategy=auto_growth "
+ visible_devices
+ " NCCL_MAX_NCHANNELS=1 NCCL_BUFFSIZE=0"
+ f" {sys.executable} {py_path}"
+ f" --device_id {int(device_ids[i])}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import paddle

from fastdeploy.model_executor.ops.gpu import (
from fastdeploy.cache_manager.ops import (
get_data_ptr_ipc,
ipc_sent_key_value_cache_by_remote_ptr,
ipc_sent_key_value_cache_by_remote_ptr_block_sync,
Expand Down
Loading
Loading