From dcdbec6650e0bb2d21ad7c59d8ca41d32b59ed7c Mon Sep 17 00:00:00 2001
From: Indrajit Bhosale
Date: Tue, 18 Feb 2025 12:10:47 -0800
Subject: [PATCH 01/12] Commit for Draft Changes

---
 src/grpc/grpc_server.cc | 733 ++++++++++++++++++++--------------------
 src/grpc/grpc_server.h  |   8 +-
 2 files changed, 378 insertions(+), 363 deletions(-)

diff --git a/src/grpc/grpc_server.cc b/src/grpc/grpc_server.cc
index 74ec443ae6..836e32d11e 100644
--- a/src/grpc/grpc_server.cc
+++ b/src/grpc/grpc_server.cc
@@ -1,4 +1,4 @@
-// Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -79,6 +79,57 @@ namespace {
 // are deemed to be not performance critical.
 //=========================================================================
 
+template <typename RequestType, typename ResponseType>
+class CommonCallbackData {
+ public:
+  using CallbackFunc =
+      std::function<void(RequestType&, ResponseType*, ::grpc::Status*)>;
+
+  CommonCallbackData(
+      const std::string& name,
+      inference::GRPCInferenceService::CallbackService* service,
+      const CallbackFunc& callback,
+      const std::pair<std::string, std::string>& restricted_kv)
+      : name_(name), service_(service), callback_(callback),
+        restricted_kv_(restricted_kv)
+  {
+  }
+
+  void operator()(RequestType* request)
+  {
+    ResponseType response;
+    ::grpc::Status status;
+
+    if (ExecutePrecondition(request)) {
+      callback_(*request, &response, &status);
+    } else {
+      status = ::grpc::Status(
+          ::grpc::StatusCode::UNAVAILABLE,
+          std::string("This protocol is restricted, expecting header '") +
+              restricted_kv_.first + "'");
+    }
+
+    request->request()->Complete(status);
+    delete this;
+  }
+
+ private:
+  // Returns true when no restricted header is configured, or when the
+  // configured header/value pair is present in the client metadata.
+  bool ExecutePrecondition(RequestType* request)
+  {
+    if (!restricted_kv_.first.empty()) {
+      const auto& metadata = request->context()->client_metadata();
+      const auto it = metadata.find(restricted_kv_.first);
+      return (it != metadata.end()) && (it->second == restricted_kv_.second);
+    }
+    return true;
+  }
+
+  const std::string name_;
+  inference::GRPCInferenceService::CallbackService* service_;
+  CallbackFunc callback_;
+  std::pair<std::string, std::string> restricted_kv_;
+};
+
 template <typename ResponderType, typename RequestType, typename ResponseType>
 class CommonCallData : public ICallData {
  public:
@@ -264,7 +315,8 @@ class CommonHandler : public HandlerBase {
       TraceManager* trace_manager,
       inference::GRPCInferenceService::AsyncService* service,
       ::grpc::health::v1::Health::AsyncService* health_service,
+      inference::GRPCInferenceService::CallbackService*
+          non_inference_callback_service,
       ::grpc::ServerCompletionQueue* cq,
       const RestrictedFeatures& restricted_keys, const uint64_t response_delay);
 
   // Descriptive name of the handler.
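For orientation: in gRPC's stock callback API, the generated CallbackService exposes one virtual method per RPC, and a unary call completes when the handler finishes a ServerUnaryReactor; the CommonCallbackData wrapper above is this patch's own layer over that mechanism. Below is a minimal, self-contained sketch of a directly overridden ServerReady handler with the same restricted-header precondition. The class name, member names, and the generated header path are illustrative assumptions, not code from this patch.

#include <string>

#include <grpcpp/grpcpp.h>

#include "grpc_service.grpc.pb.h"  // assumed name of the generated service header
#include "triton/core/tritonserver.h"

// Hypothetical standalone service, for illustration only.
class ServerReadyCallbackService final
    : public inference::GRPCInferenceService::CallbackService {
 public:
  ServerReadyCallbackService(
      TRITONSERVER_Server* server, const std::string& restricted_key,
      const std::string& restricted_value)
      : server_(server), restricted_key_(restricted_key),
        restricted_value_(restricted_value)
  {
  }

  ::grpc::ServerUnaryReactor* ServerReady(
      ::grpc::CallbackServerContext* context,
      const inference::ServerReadyRequest* /*request*/,
      inference::ServerReadyResponse* response) override
  {
    ::grpc::ServerUnaryReactor* reactor = context->DefaultReactor();

    // Same precondition as CommonCallbackData::ExecutePrecondition():
    // reject the call unless the configured restricted header matches.
    if (!restricted_key_.empty()) {
      const auto& metadata = context->client_metadata();
      const auto it = metadata.find(restricted_key_);
      if ((it == metadata.end()) || (it->second != restricted_value_)) {
        reactor->Finish(::grpc::Status(
            ::grpc::StatusCode::UNAVAILABLE,
            "This protocol is restricted, expecting header '" +
                restricted_key_ + "'"));
        return reactor;
      }
    }

    bool ready = false;
    TRITONSERVER_Error* err = TRITONSERVER_ServerIsReady(server_, &ready);
    response->set_ready((err == nullptr) && ready);
    reactor->Finish(
        (err == nullptr)
            ? ::grpc::Status::OK
            : ::grpc::Status(
                  ::grpc::StatusCode::INTERNAL,
                  TRITONSERVER_ErrorMessage(err)));
    if (err != nullptr) {
      TRITONSERVER_ErrorDelete(err);
    }
    return reactor;
  }

 private:
  TRITONSERVER_Server* server_;
  const std::string restricted_key_;
  const std::string restricted_value_;
};

Such a service is registered with builder.RegisterService(&service), exactly as this patch registers non_inference_callback_service_. Because the reactor model needs no completion-queue polling thread for these RPCs, the header changes at the end of this patch comment out common_cq_ and the CommonHandler member.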
@@ -315,6 +367,9 @@ class CommonHandler : public HandlerBase { inference::GRPCInferenceService::AsyncService* service_; ::grpc::health::v1::Health::AsyncService* health_service_; + inference::GRPCInferenceService::CallbackService* + non_inference_callback_service_; + ::grpc::ServerCompletionQueue* cq_; std::unique_ptr thread_; RestrictedFeatures restricted_keys_{}; @@ -333,7 +388,8 @@ CommonHandler::CommonHandler( const uint64_t response_delay = 0) : name_(name), tritonserver_(tritonserver), shm_manager_(shm_manager), trace_manager_(trace_manager), service_(service), - health_service_(health_service), cq_(cq), + health_service_(health_service), + non_inference_callback_service_(non_inference_callback_service), cq_(cq), restricted_keys_(restricted_keys), response_delay_(response_delay) { } @@ -464,23 +520,18 @@ CommonHandler::RegisterServerLive() false /* async */, cq_, restricted_kv, response_delay_); } +// This change leverages the callback API, simplifying the handling of the +// ServerReady request by directly using the non_inference_callback_service_. void CommonHandler::RegisterServerReady() { - auto OnRegisterServerReady = - [this]( - ::grpc::ServerContext* ctx, inference::ServerReadyRequest* request, - ::grpc::ServerAsyncResponseWriter* - responder, - void* tag) { - this->service_->RequestServerReady( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteServerReady = [this]( - inference::ServerReadyRequest& request, - inference::ServerReadyResponse* response, - ::grpc::Status* status) { + // Define a lambda function 'callback' that takes a ServerReadyRequest, + // a ServerReadyResponse, and a grpc::Status. This function performs + // the same logic as the original OnExecuteServerReady function. + auto callback = [this]( + inference::ServerReadyRequest& request, + inference::ServerReadyResponse* response, + ::grpc::Status* status) { bool ready = false; TRITONSERVER_Error* err = TRITONSERVER_ServerIsReady(tritonserver_.get(), &ready); @@ -493,33 +544,25 @@ CommonHandler::RegisterServerReady() const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::HEALTH); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::ServerReadyRequest, inference::ServerReadyResponse>( - "ServerReady", 0, OnRegisterServerReady, OnExecuteServerReady, - false /* async */, cq_, restricted_kv, response_delay_); + + // Use non_inference_callback_service_->ServerReady to register the callback. + // This replaces the use of CommonCallData. + non_inference_callback_service_->ServerReady( + // Create a new CommonCallbackData object with the callback function + // and register it with the non_inference_callback_service_. 
+      new CommonCallbackData<
+          inference::ServerReadyRequest, inference::ServerReadyResponse>(
+          "ServerReady", non_inference_callback_service_, callback,
+          restricted_kv));
 }
 
 void
 CommonHandler::RegisterHealthCheck()
 {
-  auto OnRegisterHealthCheck =
-      [this](
-          ::grpc::ServerContext* ctx,
-          ::grpc::health::v1::HealthCheckRequest* request,
-          ::grpc::ServerAsyncResponseWriter<
-              ::grpc::health::v1::HealthCheckResponse>* responder,
-          void* tag) {
-        this->health_service_->RequestCheck(
-            ctx, request, responder, this->cq_, this->cq_, tag);
-      };
-
-  auto OnExecuteHealthCheck = [this](
-                                  ::grpc::health::v1::HealthCheckRequest&
-                                      request,
-                                  ::grpc::health::v1::HealthCheckResponse*
-                                      response,
-                                  ::grpc::Status* status) {
+  auto callback = [this](
+                      ::grpc::health::v1::HealthCheckRequest& request,
+                      ::grpc::health::v1::HealthCheckResponse* response,
+                      ::grpc::Status* status) {
     bool live = false;
     TRITONSERVER_Error* err =
         TRITONSERVER_ServerIsReady(tritonserver_.get(), &live);
@@ -540,32 +583,21 @@ CommonHandler::RegisterHealthCheck()
   const std::pair<std::string, std::string>& restricted_kv =
       restricted_keys_.Get(RestrictedCategory::HEALTH);
-  new CommonCallData<
-      ::grpc::ServerAsyncResponseWriter<
-          ::grpc::health::v1::HealthCheckResponse>,
-      ::grpc::health::v1::HealthCheckRequest,
-      ::grpc::health::v1::HealthCheckResponse>(
-      "Check", 0, OnRegisterHealthCheck, OnExecuteHealthCheck,
-      false /* async */, cq_, restricted_kv, response_delay_);
+
+  non_inference_callback_service_->Check(
+      new CommonCallbackData<
+          ::grpc::health::v1::HealthCheckRequest,
+          ::grpc::health::v1::HealthCheckResponse>(
+          "Check", non_inference_callback_service_, callback, restricted_kv));
 }
 
 void
 CommonHandler::RegisterModelReady()
 {
-  auto OnRegisterModelReady =
-      [this](
-          ::grpc::ServerContext* ctx, inference::ModelReadyRequest* request,
-          ::grpc::ServerAsyncResponseWriter<inference::ModelReadyResponse>*
-              responder,
-          void* tag) {
-        this->service_->RequestModelReady(
-            ctx, request, responder, this->cq_, this->cq_, tag);
-      };
-
-  auto OnExecuteModelReady = [this](
-                                 inference::ModelReadyRequest& request,
-                                 inference::ModelReadyResponse* response,
-                                 ::grpc::Status* status) {
+  auto callback = [this](
+                      inference::ModelReadyRequest& request,
+                      inference::ModelReadyResponse* response,
+                      ::grpc::Status* status) {
     bool is_ready = false;
     int64_t requested_model_version;
     auto err =
@@ -581,335 +613,314 @@ CommonHandler::RegisterModelReady()
     GrpcStatusUtil::Create(status, err);
     TRITONSERVER_ErrorDelete(err);
   };
-
   const std::pair<std::string, std::string>& restricted_kv =
       restricted_keys_.Get(RestrictedCategory::HEALTH);
-  new CommonCallData<
-      ::grpc::ServerAsyncResponseWriter<inference::ModelReadyResponse>,
-      inference::ModelReadyRequest, inference::ModelReadyResponse>(
-      "ModelReady", 0, OnRegisterModelReady, OnExecuteModelReady,
-      false /* async */, cq_, restricted_kv, response_delay_);
+  non_inference_callback_service_->ModelReady(
+      new CommonCallbackData<
+          inference::ModelReadyRequest, inference::ModelReadyResponse>(
+          "ModelReady", non_inference_callback_service_, callback,
+          restricted_kv));
 }
 
 void
 CommonHandler::RegisterServerMetadata()
 {
-  auto OnRegisterServerMetadata =
-      [this](
-          ::grpc::ServerContext* ctx, inference::ServerMetadataRequest* request,
-          ::grpc::ServerAsyncResponseWriter<inference::ServerMetadataResponse>*
-              responder,
-          void* tag) {
-        this->service_->RequestServerMetadata(
-            ctx, request, responder, this->cq_, this->cq_, tag);
-      };
-
-  auto OnExecuteServerMetadata =
-      [this](
-          inference::ServerMetadataRequest& request,
-          inference::ServerMetadataResponse* response, ::grpc::Status* status) {
-
TRITONSERVER_Message* server_metadata_message = nullptr; - TRITONSERVER_Error* err = TRITONSERVER_ServerMetadata( - tritonserver_.get(), &server_metadata_message); - GOTO_IF_ERR(err, earlyexit); - - const char* buffer; - size_t byte_size; - err = TRITONSERVER_MessageSerializeToJson( - server_metadata_message, &buffer, &byte_size); - GOTO_IF_ERR(err, earlyexit); - - { - triton::common::TritonJson::Value server_metadata_json; - err = server_metadata_json.Parse(buffer, byte_size); - GOTO_IF_ERR(err, earlyexit); - + // Define a lambda function 'callback' that takes a ServerMetadataRequest, + // a ServerMetadataResponse, and a grpc::Status. This function performs + // the same logic as the original OnExecuteServerMetadata function. + auto callback = [this]( + inference::ServerMetadataRequest& request, + inference::ServerMetadataResponse* response, + ::grpc::Status* status) { + TRITONSERVER_Message* server_metadata_message = nullptr; + TRITONSERVER_Error* err = TRITONSERVER_ServerMetadata( + tritonserver_.get(), &server_metadata_message); + if (err == nullptr) { + const char* buffer; + size_t byte_size; + err = TRITONSERVER_MessageSerializeToJson( + server_metadata_message, &buffer, &byte_size); + if (err == nullptr) { + triton::common::TritonJson::Value server_metadata_json; + err = server_metadata_json.Parse(buffer, byte_size); + if (err == nullptr) { const char* name; size_t namelen; err = server_metadata_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); - - const char* version; - size_t versionlen; - err = server_metadata_json.MemberAsString( - "version", &version, &versionlen); - GOTO_IF_ERR(err, earlyexit); - - response->set_name(std::string(name, namelen)); - response->set_version(std::string(version, versionlen)); - - if (server_metadata_json.Find("extensions")) { - triton::common::TritonJson::Value extensions_json; - err = server_metadata_json.MemberAsArray( - "extensions", &extensions_json); - GOTO_IF_ERR(err, earlyexit); - - for (size_t idx = 0; idx < extensions_json.ArraySize(); ++idx) { - const char* ext; - size_t extlen; - err = extensions_json.IndexAsString(idx, &ext, &extlen); - GOTO_IF_ERR(err, earlyexit); - response->add_extensions(std::string(ext, extlen)); + if (err == nullptr) { + const char* version; + size_t versionlen; + err = server_metadata_json.MemberAsString( + "version", &version, &versionlen); + if (err == nullptr) { + response->set_name(std::string(name, namelen)); + response->set_version(std::string(version, versionlen)); + + if (server_metadata_json.Find("extensions")) { + triton::common::TritonJson::Value extensions_json; + err = server_metadata_json.MemberAsArray( + "extensions", &extensions_json); + if (err == nullptr) { + for (size_t idx = 0; idx < extensions_json.ArraySize(); + ++idx) { + const char* ext; + size_t extlen; + err = extensions_json.IndexAsString(idx, &ext, &extlen); + if (err == nullptr) { + response->add_extensions(std::string(ext, extlen)); + } + } + } + } } } - TRITONSERVER_MessageDelete(server_metadata_message); } + } + TRITONSERVER_MessageDelete(server_metadata_message); + } - earlyexit: - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; + GrpcStatusUtil::Create(status, err); + TRITONSERVER_ErrorDelete(err); + }; const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::METADATA); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::ServerMetadataRequest, inference::ServerMetadataResponse>( - "ServerMetadata", 0, OnRegisterServerMetadata, 
OnExecuteServerMetadata, - false /* async */, cq_, restricted_kv, response_delay_); + + // Use non_inference_callback_service_->ServerMetadata to register the + // callback. This replaces the use of CommonCallData. + non_inference_callback_service_->ServerMetadata( + // Create a new CommonCallbackData object with the callback function + // and register it with the non_inference_callback_service_. + new CommonCallbackData< + inference::ServerMetadataRequest, inference::ServerMetadataResponse>( + "ServerMetadata", non_inference_callback_service_, callback, + restricted_kv)); } void CommonHandler::RegisterModelMetadata() { - auto OnRegisterModelMetadata = - [this]( - ::grpc::ServerContext* ctx, inference::ModelMetadataRequest* request, - ::grpc::ServerAsyncResponseWriter* - responder, - void* tag) { - this->service_->RequestModelMetadata( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteModelMetadata = [this]( - inference::ModelMetadataRequest& request, - inference::ModelMetadataResponse* response, - ::grpc::Status* status) { + // Define a lambda function 'callback' that takes a ModelMetadataRequest, + // a ModelMetadataResponse, and a grpc::Status. This function performs + // the same logic as the original OnExecuteModelMetadata function. + auto callback = [this]( + inference::ModelMetadataRequest& request, + inference::ModelMetadataResponse* response, + ::grpc::Status* status) { int64_t requested_model_version; auto err = GetModelVersionFromString(request.version(), &requested_model_version); GOTO_IF_ERR(err, earlyexit); - { - TRITONSERVER_Message* model_metadata_message = nullptr; - err = TRITONSERVER_ServerModelMetadata( - tritonserver_.get(), request.name().c_str(), requested_model_version, - &model_metadata_message); - GOTO_IF_ERR(err, earlyexit); + TRITONSERVER_Message* model_metadata_message = nullptr; + err = TRITONSERVER_ServerModelMetadata( + tritonserver_.get(), request.name().c_str(), requested_model_version, + &model_metadata_message); + GOTO_IF_ERR(err, earlyexit); - const char* buffer; - size_t byte_size; - err = TRITONSERVER_MessageSerializeToJson( - model_metadata_message, &buffer, &byte_size); - GOTO_IF_ERR(err, earlyexit); + const char* buffer; + size_t byte_size; + err = TRITONSERVER_MessageSerializeToJson( + model_metadata_message, &buffer, &byte_size); + GOTO_IF_ERR(err, earlyexit); - triton::common::TritonJson::Value model_metadata_json; - err = model_metadata_json.Parse(buffer, byte_size); - GOTO_IF_ERR(err, earlyexit); + triton::common::TritonJson::Value model_metadata_json; + err = model_metadata_json.Parse(buffer, byte_size); + GOTO_IF_ERR(err, earlyexit); - const char* name; - size_t namelen; - err = model_metadata_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); + const char* name; + size_t namelen; + err = model_metadata_json.MemberAsString("name", &name, &namelen); + GOTO_IF_ERR(err, earlyexit); - response->set_name(std::string(name, namelen)); + response->set_name(std::string(name, namelen)); - if (model_metadata_json.Find("versions")) { - triton::common::TritonJson::Value versions_json; - err = model_metadata_json.MemberAsArray("versions", &versions_json); - GOTO_IF_ERR(err, earlyexit); + if (model_metadata_json.Find("versions")) { + triton::common::TritonJson::Value versions_json; + err = model_metadata_json.MemberAsArray("versions", &versions_json); + GOTO_IF_ERR(err, earlyexit); - for (size_t idx = 0; idx < versions_json.ArraySize(); ++idx) { - const char* version; - size_t versionlen; - err = 
versions_json.IndexAsString(idx, &version, &versionlen); - GOTO_IF_ERR(err, earlyexit); - response->add_versions(std::string(version, versionlen)); - } + for (size_t idx = 0; idx < versions_json.ArraySize(); ++idx) { + const char* version; + size_t versionlen; + err = versions_json.IndexAsString(idx, &version, &versionlen); + GOTO_IF_ERR(err, earlyexit); + response->add_versions(std::string(version, versionlen)); } + } + + const char* platform; + size_t platformlen; + err = + model_metadata_json.MemberAsString("platform", &platform, &platformlen); + GOTO_IF_ERR(err, earlyexit); + response->set_platform(std::string(platform, platformlen)); - const char* platform; - size_t platformlen; - err = model_metadata_json.MemberAsString( - "platform", &platform, &platformlen); + if (model_metadata_json.Find("inputs")) { + triton::common::TritonJson::Value inputs_json; + err = model_metadata_json.MemberAsArray("inputs", &inputs_json); GOTO_IF_ERR(err, earlyexit); - response->set_platform(std::string(platform, platformlen)); - if (model_metadata_json.Find("inputs")) { - triton::common::TritonJson::Value inputs_json; - err = model_metadata_json.MemberAsArray("inputs", &inputs_json); + for (size_t idx = 0; idx < inputs_json.ArraySize(); ++idx) { + triton::common::TritonJson::Value io_json; + err = inputs_json.IndexAsObject(idx, &io_json); GOTO_IF_ERR(err, earlyexit); - for (size_t idx = 0; idx < inputs_json.ArraySize(); ++idx) { - triton::common::TritonJson::Value io_json; - err = inputs_json.IndexAsObject(idx, &io_json); - GOTO_IF_ERR(err, earlyexit); + inference::ModelMetadataResponse::TensorMetadata* io = + response->add_inputs(); - inference::ModelMetadataResponse::TensorMetadata* io = - response->add_inputs(); + const char* name; + size_t namelen; + err = io_json.MemberAsString("name", &name, &namelen); + GOTO_IF_ERR(err, earlyexit); - const char* name; - size_t namelen; - err = io_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); + const char* datatype; + size_t datatypelen; + err = io_json.MemberAsString("datatype", &datatype, &datatypelen); + GOTO_IF_ERR(err, earlyexit); - const char* datatype; - size_t datatypelen; - err = io_json.MemberAsString("datatype", &datatype, &datatypelen); - GOTO_IF_ERR(err, earlyexit); + io->set_name(std::string(name, namelen)); + io->set_datatype(std::string(datatype, datatypelen)); - io->set_name(std::string(name, namelen)); - io->set_datatype(std::string(datatype, datatypelen)); + if (io_json.Find("shape")) { + triton::common::TritonJson::Value shape_json; + err = io_json.MemberAsArray("shape", &shape_json); + GOTO_IF_ERR(err, earlyexit); - if (io_json.Find("shape")) { - triton::common::TritonJson::Value shape_json; - err = io_json.MemberAsArray("shape", &shape_json); + for (size_t sidx = 0; sidx < shape_json.ArraySize(); ++sidx) { + int64_t d; + err = shape_json.IndexAsInt(sidx, &d); GOTO_IF_ERR(err, earlyexit); - - for (size_t sidx = 0; sidx < shape_json.ArraySize(); ++sidx) { - int64_t d; - err = shape_json.IndexAsInt(sidx, &d); - GOTO_IF_ERR(err, earlyexit); - - io->add_shape(d); - } + io->add_shape(d); } } } + } + + if (model_metadata_json.Find("outputs")) { + triton::common::TritonJson::Value outputs_json; + err = model_metadata_json.MemberAsArray("outputs", &outputs_json); + GOTO_IF_ERR(err, earlyexit); - if (model_metadata_json.Find("outputs")) { - triton::common::TritonJson::Value outputs_json; - err = model_metadata_json.MemberAsArray("outputs", &outputs_json); + for (size_t idx = 0; idx < outputs_json.ArraySize(); ++idx) { + 
triton::common::TritonJson::Value io_json; + err = outputs_json.IndexAsObject(idx, &io_json); GOTO_IF_ERR(err, earlyexit); - for (size_t idx = 0; idx < outputs_json.ArraySize(); ++idx) { - triton::common::TritonJson::Value io_json; - err = outputs_json.IndexAsObject(idx, &io_json); - GOTO_IF_ERR(err, earlyexit); + inference::ModelMetadataResponse::TensorMetadata* io = + response->add_outputs(); + + const char* name; + size_t namelen; + err = io_json.MemberAsString("name", &name, &namelen); + GOTO_IF_ERR(err, earlyexit); - inference::ModelMetadataResponse::TensorMetadata* io = - response->add_outputs(); + const char* datatype; + size_t datatypelen; + err = io_json.MemberAsString("datatype", &datatype, &datatypelen); + GOTO_IF_ERR(err, earlyexit); - const char* name; - size_t namelen; - err = io_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); + io->set_name(std::string(name, namelen)); + io->set_datatype(std::string(datatype, datatypelen)); - const char* datatype; - size_t datatypelen; - err = io_json.MemberAsString("datatype", &datatype, &datatypelen); + if (io_json.Find("shape")) { + triton::common::TritonJson::Value shape_json; + err = io_json.MemberAsArray("shape", &shape_json); GOTO_IF_ERR(err, earlyexit); - io->set_name(std::string(name, namelen)); - io->set_datatype(std::string(datatype, datatypelen)); - - if (io_json.Find("shape")) { - triton::common::TritonJson::Value shape_json; - err = io_json.MemberAsArray("shape", &shape_json); + for (size_t sidx = 0; sidx < shape_json.ArraySize(); ++sidx) { + int64_t d; + err = shape_json.IndexAsInt(sidx, &d); GOTO_IF_ERR(err, earlyexit); - - for (size_t sidx = 0; sidx < shape_json.ArraySize(); ++sidx) { - int64_t d; - err = shape_json.IndexAsInt(sidx, &d); - GOTO_IF_ERR(err, earlyexit); - - io->add_shape(d); - } + io->add_shape(d); } } } - - TRITONSERVER_MessageDelete(model_metadata_message); } earlyexit: + TRITONSERVER_MessageDelete(model_metadata_message); GrpcStatusUtil::Create(status, err); TRITONSERVER_ErrorDelete(err); }; const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::METADATA); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::ModelMetadataRequest, inference::ModelMetadataResponse>( - "ModelMetadata", 0, OnRegisterModelMetadata, OnExecuteModelMetadata, - false /* async */, cq_, restricted_kv, response_delay_); + + // Use non_inference_callback_service_->ModelMetadata to register the + // callback. This replaces the use of CommonCallData. + non_inference_callback_service_->ModelMetadata( + // Create a new CommonCallbackData object with the callback function + // and register it with the non_inference_callback_service_. + new CommonCallbackData< + inference::ModelMetadataRequest, inference::ModelMetadataResponse>( + "ModelMetadata", non_inference_callback_service_, callback, + restricted_kv)); } void CommonHandler::RegisterModelConfig() { - auto OnRegisterModelConfig = - [this]( - ::grpc::ServerContext* ctx, inference::ModelConfigRequest* request, - ::grpc::ServerAsyncResponseWriter* - responder, - void* tag) { - this->service_->RequestModelConfig( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteModelConfig = [this]( - inference::ModelConfigRequest& request, - inference::ModelConfigResponse* response, - ::grpc::Status* status) { + // Define a lambda function 'callback' that takes a ModelConfigRequest, + // a ModelConfigResponse, and a grpc::Status. 
This function performs
+  // the same logic as the original OnExecuteModelConfig function.
+  auto callback = [this](
+                      inference::ModelConfigRequest& request,
+                      inference::ModelConfigResponse* response,
+                      ::grpc::Status* status) {
     int64_t requested_model_version;
     auto err =
         GetModelVersionFromString(request.version(), &requested_model_version);
-    if (err == nullptr) {
-      TRITONSERVER_Message* model_config_message = nullptr;
-      err = TRITONSERVER_ServerModelConfig(
-          tritonserver_.get(), request.name().c_str(), requested_model_version,
-          1 /* config_version */, &model_config_message);
-      if (err == nullptr) {
-        const char* buffer;
-        size_t byte_size;
-        err = TRITONSERVER_MessageSerializeToJson(
-            model_config_message, &buffer, &byte_size);
-        if (err == nullptr) {
-          ::google::protobuf::util::JsonStringToMessage(
-              ::google::protobuf::stringpiece_internal::StringPiece(
-                  buffer, (int)byte_size),
-              response->mutable_config());
-        }
-        TRITONSERVER_MessageDelete(model_config_message);
-      }
-    }
+    // Declared ahead of the first GOTO_IF_ERR so the jump to 'earlyexit'
+    // does not bypass any initialization.
+    TRITONSERVER_Message* model_config_message = nullptr;
+    const char* buffer;
+    size_t byte_size;
+    GOTO_IF_ERR(err, earlyexit);
+
+    err = TRITONSERVER_ServerModelConfig(
+        tritonserver_.get(), request.name().c_str(), requested_model_version,
+        1 /* config_version */, &model_config_message);
+    GOTO_IF_ERR(err, earlyexit);
+
+    err = TRITONSERVER_MessageSerializeToJson(
+        model_config_message, &buffer, &byte_size);
+    GOTO_IF_ERR(err, earlyexit);
+
+    ::google::protobuf::util::JsonStringToMessage(
+        ::google::protobuf::stringpiece_internal::StringPiece(
+            buffer, static_cast<int>(byte_size)),
+        response->mutable_config());
+
+  earlyexit:
+    TRITONSERVER_MessageDelete(model_config_message);
     GrpcStatusUtil::Create(status, err);
     TRITONSERVER_ErrorDelete(err);
   };
 
   const std::pair<std::string, std::string>& restricted_kv =
       restricted_keys_.Get(RestrictedCategory::MODEL_CONFIG);
-  new CommonCallData<
-      ::grpc::ServerAsyncResponseWriter<inference::ModelConfigResponse>,
-      inference::ModelConfigRequest, inference::ModelConfigResponse>(
-      "ModelConfig", 0, OnRegisterModelConfig, OnExecuteModelConfig,
-      false /* async */, cq_, restricted_kv, response_delay_);
+
+  // Use non_inference_callback_service_->ModelConfig to register the callback.
+  // This replaces the use of CommonCallData.
+  non_inference_callback_service_->ModelConfig(
+      // Create a new CommonCallbackData object with the callback function
+      // and register it with the non_inference_callback_service_.
+      new CommonCallbackData<
+          inference::ModelConfigRequest, inference::ModelConfigResponse>(
+          "ModelConfig", non_inference_callback_service_, callback,
+          restricted_kv));
 }
 
 void
 CommonHandler::RegisterModelStatistics()
 {
-  auto OnRegisterModelStatistics =
-      [this](
-          ::grpc::ServerContext* ctx,
-          inference::ModelStatisticsRequest* request,
-          ::grpc::ServerAsyncResponseWriter<inference::ModelStatisticsResponse>*
-              responder,
-          void* tag) {
-        this->service_->RequestModelStatistics(
-            ctx, request, responder, this->cq_, this->cq_, tag);
-      };
-
-  auto OnExecuteModelStatistics = [this](
-                                      inference::ModelStatisticsRequest&
-                                          request,
-                                      inference::ModelStatisticsResponse*
-                                          response,
-                                      ::grpc::Status* status) {
+  // Define a lambda function 'callback' that takes a ModelStatisticsRequest,
+  // a ModelStatisticsResponse, and a grpc::Status. This function performs
+  // the same logic as the original OnExecuteModelStatistics function.
+ auto callback = [this]( + inference::ModelStatisticsRequest& request, + inference::ModelStatisticsResponse* response, + ::grpc::Status* status) { #ifdef TRITON_ENABLE_STATS triton::common::TritonJson::Value model_stats_json; @@ -918,24 +929,22 @@ CommonHandler::RegisterModelStatistics() GetModelVersionFromString(request.version(), &requested_model_version); GOTO_IF_ERR(err, earlyexit); - { - TRITONSERVER_Message* model_stats_message = nullptr; - err = TRITONSERVER_ServerModelStatistics( - tritonserver_.get(), request.name().c_str(), requested_model_version, - &model_stats_message); - GOTO_IF_ERR(err, earlyexit); + TRITONSERVER_Message* model_stats_message = nullptr; + err = TRITONSERVER_ServerModelStatistics( + tritonserver_.get(), request.name().c_str(), requested_model_version, + &model_stats_message); + GOTO_IF_ERR(err, earlyexit); - const char* buffer; - size_t byte_size; - err = TRITONSERVER_MessageSerializeToJson( - model_stats_message, &buffer, &byte_size); - GOTO_IF_ERR(err, earlyexit); + const char* buffer; + size_t byte_size; + err = TRITONSERVER_MessageSerializeToJson( + model_stats_message, &buffer, &byte_size); + GOTO_IF_ERR(err, earlyexit); - err = model_stats_json.Parse(buffer, byte_size); - GOTO_IF_ERR(err, earlyexit); + err = model_stats_json.Parse(buffer, byte_size); + GOTO_IF_ERR(err, earlyexit); - TRITONSERVER_MessageDelete(model_stats_message); - } + TRITONSERVER_MessageDelete(model_stats_message); if (model_stats_json.Find("model_stats")) { triton::common::TritonJson::Value stats_json; @@ -1133,11 +1142,17 @@ CommonHandler::RegisterModelStatistics() const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::STATISTICS); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::ModelStatisticsRequest, inference::ModelStatisticsResponse>( - "ModelStatistics", 0, OnRegisterModelStatistics, OnExecuteModelStatistics, - false /* async */, cq_, restricted_kv, response_delay_); + + // Use non_inference_callback_service_->ModelStatistics to register the + // callback. This replaces the use of CommonCallData. + non_inference_callback_service_->ModelStatistics( + // Create a new CommonCallbackData object with the callback function + // and register it with the non_inference_callback_service_. + new CommonCallbackData< + inference::ModelStatisticsRequest, + inference::ModelStatisticsResponse>( + "ModelStatistics", non_inference_callback_service_, callback, + restricted_kv)); } template @@ -1163,20 +1178,13 @@ CommonHandler::SetStatisticsDuration( void CommonHandler::RegisterTrace() { - auto OnRegisterTrace = - [this]( - ::grpc::ServerContext* ctx, inference::TraceSettingRequest* request, - ::grpc::ServerAsyncResponseWriter* - responder, - void* tag) { - this->service_->RequestTraceSetting( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteTrace = [this]( - inference::TraceSettingRequest& request, - inference::TraceSettingResponse* response, - ::grpc::Status* status) { + // Define a lambda function 'callback' that takes a TraceSettingRequest, + // a TraceSettingResponse, and a grpc::Status. This function performs + // the same logic as the original OnExecuteTrace function. 
+ auto callback = [this]( + inference::TraceSettingRequest& request, + inference::TraceSettingResponse* response, + ::grpc::Status* status) { #ifdef TRITON_ENABLE_TRACING TRITONSERVER_Error* err = nullptr; TRITONSERVER_InferenceTraceLevel level = TRITONSERVER_TRACE_LEVEL_DISABLED; @@ -1447,30 +1455,28 @@ CommonHandler::RegisterTrace() const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::TRACE); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::TraceSettingRequest, inference::TraceSettingResponse>( - "Trace", 0, OnRegisterTrace, OnExecuteTrace, false /* async */, cq_, - restricted_kv, response_delay_); + + // Use non_inference_callback_service_->TraceSetting to register the callback. + // This replaces the use of CommonCallData. + non_inference_callback_service_->TraceSetting( + // Create a new CommonCallbackData object with the callback function + // and register it with the non_inference_callback_service_. + new CommonCallbackData< + inference::TraceSettingRequest, inference::TraceSettingResponse>( + "TraceSetting", non_inference_callback_service_, callback, + restricted_kv)); } void CommonHandler::RegisterLogging() { - auto OnRegisterLogging = - [this]( - ::grpc::ServerContext* ctx, inference::LogSettingsRequest* request, - ::grpc::ServerAsyncResponseWriter* - responder, - void* tag) { - this->service_->RequestLogSettings( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteLogging = [this]( - inference::LogSettingsRequest& request, - inference::LogSettingsResponse* response, - ::grpc::Status* status) { + // Define a lambda function 'callback' that takes a LogSettingsRequest, + // a LogSettingsResponse, and a grpc::Status. This function performs + // the same logic as the original OnExecuteLogging function. + auto callback = [this]( + inference::LogSettingsRequest& request, + inference::LogSettingsResponse* response, + ::grpc::Status* status) { #ifdef TRITON_ENABLE_LOGGING TRITONSERVER_Error* err = nullptr; @@ -1634,11 +1640,16 @@ CommonHandler::RegisterLogging() const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::LOGGING); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::LogSettingsRequest, inference::LogSettingsResponse>( - "Logging", 0, OnRegisterLogging, OnExecuteLogging, false /* async */, cq_, - restricted_kv, response_delay_); + + // Use non_inference_callback_service_->LogSettings to register the callback. + // This replaces the use of CommonCallData. + non_inference_callback_service_->LogSettings( + // Create a new CommonCallbackData object with the callback function + // and register it with the non_inference_callback_service_. 
+ new CommonCallbackData< + inference::LogSettingsRequest, inference::LogSettingsResponse>( + "LogSettings", non_inference_callback_service_, callback, + restricted_kv)); } void @@ -2285,6 +2296,7 @@ Server::Server( builder_.SetMaxMessageSize(MAX_GRPC_MESSAGE_SIZE); builder_.RegisterService(&service_); builder_.RegisterService(&health_service_); + builder_.RegisterService(&non_inference_callback_service_); builder_.AddChannelArgument( GRPC_ARG_ALLOW_REUSEPORT, options.socket_.reuse_port_); @@ -2383,8 +2395,8 @@ Server::Server( // A common Handler for other non-inference requests common_handler_.reset(new CommonHandler( "CommonHandler", tritonserver_, shm_manager_, trace_manager_, &service_, - &health_service_, common_cq_.get(), options.restricted_protocols_, - response_delay)); + &health_service_, &non_inference_callback_service_, common_cq_.get(), + options.restricted_protocols_, response_delay)); // [FIXME] "register" logic is different for infer // Handler for model inference requests. @@ -2546,6 +2558,7 @@ Server::Start() (std::string("Socket '") + server_addr_ + "' already in use ").c_str()); } + // Remove this common_handler_->Start(); for (auto& model_infer_handler : model_infer_handlers_) { model_infer_handler->Start(); diff --git a/src/grpc/grpc_server.h b/src/grpc/grpc_server.h index 89d8dc7388..2a7a5ff0ba 100644 --- a/src/grpc/grpc_server.h +++ b/src/grpc/grpc_server.h @@ -1,4 +1,4 @@ -// Copyright 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2019-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -139,14 +139,16 @@ class Server { inference::GRPCInferenceService::AsyncService service_; ::grpc::health::v1::Health::AsyncService health_service_; + inference::GRPCInferenceService::CallbackService + non_inference_callback_service_; std::unique_ptr<::grpc::Server> server_; - std::unique_ptr<::grpc::ServerCompletionQueue> common_cq_; + // std::unique_ptr<::grpc::ServerCompletionQueue> common_cq_; std::unique_ptr<::grpc::ServerCompletionQueue> model_infer_cq_; std::unique_ptr<::grpc::ServerCompletionQueue> model_stream_infer_cq_; - std::unique_ptr common_handler_; + // std::unique_ptr common_handler_; std::vector> model_infer_handlers_; std::vector> model_stream_infer_handlers_; From 9c17ed6aba0b29b6f8c600955388b28fbea6e3a3 Mon Sep 17 00:00:00 2001 From: Indrajit Bhosale Date: Wed, 19 Feb 2025 08:02:42 -0800 Subject: [PATCH 02/12] Convert all Non Inference RPCs --- src/grpc/grpc_server.cc | 862 +++++++++++++++++++--------------------- 1 file changed, 411 insertions(+), 451 deletions(-) diff --git a/src/grpc/grpc_server.cc b/src/grpc/grpc_server.cc index 836e32d11e..a15912eb41 100644 --- a/src/grpc/grpc_server.cc +++ b/src/grpc/grpc_server.cc @@ -1655,115 +1655,103 @@ CommonHandler::RegisterLogging() void CommonHandler::RegisterSystemSharedMemoryStatus() { - auto OnRegisterSystemSharedMemoryStatus = - [this]( - ::grpc::ServerContext* ctx, - inference::SystemSharedMemoryStatusRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::SystemSharedMemoryStatusResponse>* responder, - void* tag) { - this->service_->RequestSystemSharedMemoryStatus( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteSystemSharedMemoryStatus = - [this]( - inference::SystemSharedMemoryStatusRequest& request, - inference::SystemSharedMemoryStatusResponse* response, - ::grpc::Status* status) { 
- triton::common::TritonJson::Value shm_status_json( - triton::common::TritonJson::ValueType::ARRAY); - TRITONSERVER_Error* err = shm_manager_->GetStatus( - request.name(), TRITONSERVER_MEMORY_CPU, &shm_status_json); - GOTO_IF_ERR(err, earlyexit); + // Define a lambda function 'callback' that takes a + // SystemSharedMemoryStatusRequest, a SystemSharedMemoryStatusResponse, and a + // grpc::Status. This function performs the same logic as the original + // OnExecuteSystemSharedMemoryStatus function. + auto callback = [this]( + inference::SystemSharedMemoryStatusRequest& request, + inference::SystemSharedMemoryStatusResponse* response, + ::grpc::Status* status) { + triton::common::TritonJson::Value shm_status_json( + triton::common::TritonJson::ValueType::ARRAY); + TRITONSERVER_Error* err = shm_manager_->GetStatus( + request.name(), TRITONSERVER_MEMORY_CPU, &shm_status_json); + GOTO_IF_ERR(err, earlyexit); - for (size_t idx = 0; idx < shm_status_json.ArraySize(); ++idx) { - triton::common::TritonJson::Value shm_region_json; - err = shm_status_json.IndexAsObject(idx, &shm_region_json); - GOTO_IF_ERR(err, earlyexit); + for (size_t idx = 0; idx < shm_status_json.ArraySize(); ++idx) { + triton::common::TritonJson::Value shm_region_json; + err = shm_status_json.IndexAsObject(idx, &shm_region_json); + GOTO_IF_ERR(err, earlyexit); - const char* name; - size_t namelen; - err = shm_region_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); + const char* name; + size_t namelen; + err = shm_region_json.MemberAsString("name", &name, &namelen); + GOTO_IF_ERR(err, earlyexit); - const char* key; - size_t keylen; - err = shm_region_json.MemberAsString("key", &key, &keylen); - GOTO_IF_ERR(err, earlyexit); + const char* key; + size_t keylen; + err = shm_region_json.MemberAsString("key", &key, &keylen); + GOTO_IF_ERR(err, earlyexit); - uint64_t offset; - err = shm_region_json.MemberAsUInt("offset", &offset); - GOTO_IF_ERR(err, earlyexit); + uint64_t offset; + err = shm_region_json.MemberAsUInt("offset", &offset); + GOTO_IF_ERR(err, earlyexit); - uint64_t byte_size; - err = shm_region_json.MemberAsUInt("byte_size", &byte_size); - GOTO_IF_ERR(err, earlyexit); + uint64_t byte_size; + err = shm_region_json.MemberAsUInt("byte_size", &byte_size); + GOTO_IF_ERR(err, earlyexit); - inference::SystemSharedMemoryStatusResponse::RegionStatus - region_status; - region_status.set_name(std::string(name, namelen)); - region_status.set_key(std::string(key, keylen)); - region_status.set_offset(offset); - region_status.set_byte_size(byte_size); + inference::SystemSharedMemoryStatusResponse::RegionStatus region_status; + region_status.set_name(std::string(name, namelen)); + region_status.set_key(std::string(key, keylen)); + region_status.set_offset(offset); + region_status.set_byte_size(byte_size); - (*response->mutable_regions())[name] = region_status; - } + (*response->mutable_regions())[name] = region_status; + } - earlyexit: - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; + earlyexit: + GrpcStatusUtil::Create(status, err); + TRITONSERVER_ErrorDelete(err); + }; const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - inference::SystemSharedMemoryStatusResponse>, - inference::SystemSharedMemoryStatusRequest, - inference::SystemSharedMemoryStatusResponse>( - "SystemSharedMemoryStatus", 0, OnRegisterSystemSharedMemoryStatus, - OnExecuteSystemSharedMemoryStatus, false /* async */, cq_, 
restricted_kv, - response_delay_); + + // Use non_inference_callback_service_->SystemSharedMemoryStatus to register + // the callback. This replaces the use of CommonCallData. + non_inference_callback_service_->SystemSharedMemoryStatus( + // Create a new CommonCallbackData object with the callback function + // and register it with the non_inference_callback_service_. + new CommonCallbackData< + inference::SystemSharedMemoryStatusRequest, + inference::SystemSharedMemoryStatusResponse>( + "SystemSharedMemoryStatus", non_inference_callback_service_, callback, + restricted_kv)); } void CommonHandler::RegisterSystemSharedMemoryRegister() { - auto OnRegisterSystemSharedMemoryRegister = - [this]( - ::grpc::ServerContext* ctx, - inference::SystemSharedMemoryRegisterRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::SystemSharedMemoryRegisterResponse>* responder, - void* tag) { - this->service_->RequestSystemSharedMemoryRegister( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteSystemSharedMemoryRegister = - [this]( - inference::SystemSharedMemoryRegisterRequest& request, - inference::SystemSharedMemoryRegisterResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = shm_manager_->RegisterSystemSharedMemory( - request.name(), request.key(), request.offset(), - request.byte_size()); + // Define a lambda function 'callback' that takes a + // SystemSharedMemoryRegisterRequest, a SystemSharedMemoryRegisterResponse, + // and a grpc::Status. This function performs the same logic as the original + // OnExecuteSystemSharedMemoryRegister function. + auto callback = [this]( + inference::SystemSharedMemoryRegisterRequest& request, + inference::SystemSharedMemoryRegisterResponse* response, + ::grpc::Status* status) { + TRITONSERVER_Error* err = shm_manager_->RegisterSystemSharedMemory( + request.name(), request.key(), request.offset(), request.byte_size()); - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; + GrpcStatusUtil::Create(status, err); + TRITONSERVER_ErrorDelete(err); + }; const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - inference::SystemSharedMemoryRegisterResponse>, - inference::SystemSharedMemoryRegisterRequest, - inference::SystemSharedMemoryRegisterResponse>( - "SystemSharedMemoryRegister", 0, OnRegisterSystemSharedMemoryRegister, - OnExecuteSystemSharedMemoryRegister, false /* async */, cq_, - restricted_kv, response_delay_); + + // Use non_inference_callback_service_->SystemSharedMemoryRegister to register + // the callback. This replaces the use of CommonCallData. + non_inference_callback_service_->SystemSharedMemoryRegister( + // Create a new CommonCallbackData object with the callback function + // and register it with the non_inference_callback_service_. 
+ new CommonCallbackData< + inference::SystemSharedMemoryRegisterRequest, + inference::SystemSharedMemoryRegisterResponse>( + "SystemSharedMemoryRegister", non_inference_callback_service_, + callback, restricted_kv)); } void @@ -1812,447 +1800,419 @@ CommonHandler::RegisterSystemSharedMemoryUnregister() void CommonHandler::RegisterCudaSharedMemoryStatus() { - auto OnRegisterCudaSharedMemoryStatus = - [this]( - ::grpc::ServerContext* ctx, - inference::CudaSharedMemoryStatusRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::CudaSharedMemoryStatusResponse>* responder, - void* tag) { - this->service_->RequestCudaSharedMemoryStatus( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - auto OnExecuteCudaSharedMemoryStatus = - [this]( - inference::CudaSharedMemoryStatusRequest& request, - inference::CudaSharedMemoryStatusResponse* response, - ::grpc::Status* status) { - triton::common::TritonJson::Value shm_status_json( - triton::common::TritonJson::ValueType::ARRAY); - TRITONSERVER_Error* err = shm_manager_->GetStatus( - request.name(), TRITONSERVER_MEMORY_GPU, &shm_status_json); - GOTO_IF_ERR(err, earlyexit); - - for (size_t idx = 0; idx < shm_status_json.ArraySize(); ++idx) { - triton::common::TritonJson::Value shm_region_json; - err = shm_status_json.IndexAsObject(idx, &shm_region_json); - GOTO_IF_ERR(err, earlyexit); + // Define a lambda function 'callback' that takes a + // CudaSharedMemoryStatusRequest, a CudaSharedMemoryStatusResponse, and a + // grpc::Status. This function performs the same logic as the original + // OnExecuteCudaSharedMemoryStatus function. + auto callback = [this]( + inference::CudaSharedMemoryStatusRequest& request, + inference::CudaSharedMemoryStatusResponse* response, + ::grpc::Status* status) { + triton::common::TritonJson::Value shm_status_json( + triton::common::TritonJson::ValueType::ARRAY); + TRITONSERVER_Error* err = shm_manager_->GetStatus( + request.name(), TRITONSERVER_MEMORY_GPU, &shm_status_json); + GOTO_IF_ERR(err, earlyexit); - const char* name; - size_t namelen; - err = shm_region_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); + for (size_t idx = 0; idx < shm_status_json.ArraySize(); ++idx) { + triton::common::TritonJson::Value shm_region_json; + err = shm_status_json.IndexAsObject(idx, &shm_region_json); + GOTO_IF_ERR(err, earlyexit); - uint64_t device_id; - err = shm_region_json.MemberAsUInt("device_id", &device_id); - GOTO_IF_ERR(err, earlyexit); + const char* name; + size_t namelen; + err = shm_region_json.MemberAsString("name", &name, &namelen); + GOTO_IF_ERR(err, earlyexit); - uint64_t byte_size; - err = shm_region_json.MemberAsUInt("byte_size", &byte_size); - GOTO_IF_ERR(err, earlyexit); + uint64_t device_id; + err = shm_region_json.MemberAsUInt("device_id", &device_id); + GOTO_IF_ERR(err, earlyexit); + uint64_t byte_size; + err = shm_region_json.MemberAsUInt("byte_size", &byte_size); + GOTO_IF_ERR(err, earlyexit); - inference::CudaSharedMemoryStatusResponse::RegionStatus region_status; - region_status.set_name(std::string(name, namelen)); - region_status.set_device_id(device_id); - region_status.set_byte_size(byte_size); + inference::CudaSharedMemoryStatusResponse::RegionStatus region_status; + region_status.set_name(std::string(name, namelen)); + region_status.set_device_id(device_id); + region_status.set_byte_size(byte_size); - (*response->mutable_regions())[name] = region_status; - } - earlyexit: - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; + 
(*response->mutable_regions())[name] = region_status; + } + earlyexit: + GrpcStatusUtil::Create(status, err); + TRITONSERVER_ErrorDelete(err); + }; const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - inference::CudaSharedMemoryStatusResponse>, - inference::CudaSharedMemoryStatusRequest, - inference::CudaSharedMemoryStatusResponse>( - "CudaSharedMemoryStatus", 0, OnRegisterCudaSharedMemoryStatus, - OnExecuteCudaSharedMemoryStatus, false /* async */, cq_, restricted_kv, - response_delay_); + + // Use non_inference_callback_service_->CudaSharedMemoryStatus to register the + // callback. This replaces the use of CommonCallData. + non_inference_callback_service_->CudaSharedMemoryStatus( + // Create a new CommonCallbackData object with the callback function + // and register it with the non_inference_callback_service_. + new CommonCallbackData< + inference::CudaSharedMemoryStatusRequest, + inference::CudaSharedMemoryStatusResponse>( + "CudaSharedMemoryStatus", non_inference_callback_service_, callback, + restricted_kv)); } void CommonHandler::RegisterCudaSharedMemoryRegister() { - auto OnRegisterCudaSharedMemoryRegister = - [this]( - ::grpc::ServerContext* ctx, - inference::CudaSharedMemoryRegisterRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::CudaSharedMemoryRegisterResponse>* responder, - void* tag) { - this->service_->RequestCudaSharedMemoryRegister( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteCudaSharedMemoryRegister = - [this]( - inference::CudaSharedMemoryRegisterRequest& request, - inference::CudaSharedMemoryRegisterResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; + // Define a lambda function 'callback' that takes a + // CudaSharedMemoryRegisterRequest, a CudaSharedMemoryRegisterResponse, and a + // grpc::Status. This function performs the same logic as the original + // OnExecuteCudaSharedMemoryRegister function. 
+ auto callback = [this]( + inference::CudaSharedMemoryRegisterRequest& request, + inference::CudaSharedMemoryRegisterResponse* response, + ::grpc::Status* status) { + TRITONSERVER_Error* err = nullptr; #ifdef TRITON_ENABLE_GPU - err = shm_manager_->RegisterCUDASharedMemory( - request.name(), - reinterpret_cast( - request.raw_handle().c_str()), - request.byte_size(), request.device_id()); + err = shm_manager_->RegisterCUDASharedMemory( + request.name(), + reinterpret_cast( + request.raw_handle().c_str()), + request.byte_size(), request.device_id()); #else - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - std::string( - "failed to register CUDA shared memory region: '" + - request.name() + "', GPUs not supported") - .c_str()); + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string( + "failed to register CUDA shared memory region: '" + request.name() + + "', GPUs not supported") + .c_str()); #endif // TRITON_ENABLE_GPU - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; + GrpcStatusUtil::Create(status, err); + TRITONSERVER_ErrorDelete(err); + }; const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - inference::CudaSharedMemoryRegisterResponse>, - inference::CudaSharedMemoryRegisterRequest, - inference::CudaSharedMemoryRegisterResponse>( - "CudaSharedMemoryRegister", 0, OnRegisterCudaSharedMemoryRegister, - OnExecuteCudaSharedMemoryRegister, false /* async */, cq_, restricted_kv, - response_delay_); + + // Use non_inference_callback_service_->CudaSharedMemoryRegister to register + // the callback. This replaces the use of CommonCallData. + non_inference_callback_service_->CudaSharedMemoryRegister( + // Create a new CommonCallbackData object with the callback function + // and register it with the non_inference_callback_service_. + new CommonCallbackData< + inference::CudaSharedMemoryRegisterRequest, + inference::CudaSharedMemoryRegisterResponse>( + "CudaSharedMemoryRegister", non_inference_callback_service_, callback, + restricted_kv)); } void CommonHandler::RegisterCudaSharedMemoryUnregister() { - auto OnRegisterCudaSharedMemoryUnregister = - [this]( - ::grpc::ServerContext* ctx, - inference::CudaSharedMemoryUnregisterRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::CudaSharedMemoryUnregisterResponse>* responder, - void* tag) { - this->service_->RequestCudaSharedMemoryUnregister( - ctx, request, responder, this->cq_, this->cq_, tag); - }; + // Define a lambda function 'callback' that takes a + // CudaSharedMemoryUnregisterRequest, a CudaSharedMemoryUnregisterResponse, + // and a grpc::Status. This function performs the same logic as the original + // OnExecuteCudaSharedMemoryUnregister function. 
+ auto callback = [this]( + inference::CudaSharedMemoryUnregisterRequest& request, + inference::CudaSharedMemoryUnregisterResponse* response, + ::grpc::Status* status) { + TRITONSERVER_Error* err = nullptr; + if (request.name().empty()) { + err = shm_manager_->UnregisterAll(TRITONSERVER_MEMORY_GPU); + } else { + err = shm_manager_->Unregister(request.name(), TRITONSERVER_MEMORY_GPU); + } - auto OnExecuteCudaSharedMemoryUnregister = - [this]( - inference::CudaSharedMemoryUnregisterRequest& request, - inference::CudaSharedMemoryUnregisterResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; - if (request.name().empty()) { - err = shm_manager_->UnregisterAll(TRITONSERVER_MEMORY_GPU); - } else { - err = - shm_manager_->Unregister(request.name(), TRITONSERVER_MEMORY_GPU); - } + GrpcStatusUtil::Create(status, err); + TRITONSERVER_ErrorDelete(err); + }; - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; const std::pair& restricted_kv = restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - inference::CudaSharedMemoryUnregisterResponse>, - inference::CudaSharedMemoryUnregisterRequest, - inference::CudaSharedMemoryUnregisterResponse>( - "CudaSharedMemoryUnregister", 0, OnRegisterCudaSharedMemoryUnregister, - OnExecuteCudaSharedMemoryUnregister, false /* async */, cq_, - restricted_kv, response_delay_); + // Use non_inference_callback_service_->CudaSharedMemoryUnregister to register + // the callback. This replaces the use of CommonCallData. + non_inference_callback_service_->CudaSharedMemoryUnregister( + // Create a new CommonCallbackData object with the callback function + // and register it with the non_inference_callback_service_. + new CommonCallbackData< + inference::CudaSharedMemoryUnregisterRequest, + inference::CudaSharedMemoryUnregisterResponse>( + "CudaSharedMemoryUnregister", non_inference_callback_service_, + callback, restricted_kv)); } void CommonHandler::RegisterRepositoryIndex() { - auto OnRegisterRepositoryIndex = - [this]( - ::grpc::ServerContext* ctx, - inference::RepositoryIndexRequest* request, - ::grpc::ServerAsyncResponseWriter* - responder, - void* tag) { - this->service_->RequestRepositoryIndex( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteRepositoryIndex = - [this]( - inference::RepositoryIndexRequest& request, - inference::RepositoryIndexResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; - if (request.repository_name().empty()) { - uint32_t flags = 0; - if (request.ready()) { - flags |= TRITONSERVER_INDEX_FLAG_READY; - } - - TRITONSERVER_Message* model_index_message = nullptr; - err = TRITONSERVER_ServerModelIndex( - tritonserver_.get(), flags, &model_index_message); - GOTO_IF_ERR(err, earlyexit); + // Define a lambda function 'callback' that takes a RepositoryIndexRequest, + // a RepositoryIndexResponse, and a grpc::Status. This function performs + // the same logic as the original OnExecuteRepositoryIndex function. 
+ auto callback = [this]( + inference::RepositoryIndexRequest& request, + inference::RepositoryIndexResponse* response, + ::grpc::Status* status) { + TRITONSERVER_Error* err = nullptr; + if (request.repository_name().empty()) { + uint32_t flags = 0; + if (request.ready()) { + flags |= TRITONSERVER_INDEX_FLAG_READY; + } - const char* buffer; - size_t byte_size; - err = TRITONSERVER_MessageSerializeToJson( - model_index_message, &buffer, &byte_size); - GOTO_IF_ERR(err, earlyexit); + TRITONSERVER_Message* model_index_message = nullptr; + err = TRITONSERVER_ServerModelIndex( + tritonserver_.get(), flags, &model_index_message); + GOTO_IF_ERR(err, earlyexit); - triton::common::TritonJson::Value model_index_json; - err = model_index_json.Parse(buffer, byte_size); - GOTO_IF_ERR(err, earlyexit); + const char* buffer; + size_t byte_size; + err = TRITONSERVER_MessageSerializeToJson( + model_index_message, &buffer, &byte_size); + GOTO_IF_ERR(err, earlyexit); - err = model_index_json.AssertType( - triton::common::TritonJson::ValueType::ARRAY); - GOTO_IF_ERR(err, earlyexit); + triton::common::TritonJson::Value model_index_json; + err = model_index_json.Parse(buffer, byte_size); + GOTO_IF_ERR(err, earlyexit); - for (size_t idx = 0; idx < model_index_json.ArraySize(); ++idx) { - triton::common::TritonJson::Value index_json; - err = model_index_json.IndexAsObject(idx, &index_json); - GOTO_IF_ERR(err, earlyexit); + err = model_index_json.AssertType( + triton::common::TritonJson::ValueType::ARRAY); + GOTO_IF_ERR(err, earlyexit); - auto model_index = response->add_models(); + for (size_t idx = 0; idx < model_index_json.ArraySize(); ++idx) { + triton::common::TritonJson::Value index_json; + err = model_index_json.IndexAsObject(idx, &index_json); + GOTO_IF_ERR(err, earlyexit); - const char* name; - size_t namelen; - err = index_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); - model_index->set_name(std::string(name, namelen)); + auto model_index = response->add_models(); - if (index_json.Find("version")) { - const char* version; - size_t versionlen; - err = index_json.MemberAsString("version", &version, &versionlen); - GOTO_IF_ERR(err, earlyexit); - model_index->set_version(std::string(version, versionlen)); - } - if (index_json.Find("state")) { - const char* state; - size_t statelen; - err = index_json.MemberAsString("state", &state, &statelen); - GOTO_IF_ERR(err, earlyexit); - model_index->set_state(std::string(state, statelen)); - } - if (index_json.Find("reason")) { - const char* reason; - size_t reasonlen; - err = index_json.MemberAsString("reason", &reason, &reasonlen); - GOTO_IF_ERR(err, earlyexit); - model_index->set_reason(std::string(reason, reasonlen)); - } - } + const char* name; + size_t namelen; + err = index_json.MemberAsString("name", &name, &namelen); + GOTO_IF_ERR(err, earlyexit); + model_index->set_name(std::string(name, namelen)); - TRITONSERVER_MessageDelete(model_index_message); - } else { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_UNSUPPORTED, - "'repository_name' specification is not supported"); + if (index_json.Find("version")) { + const char* version; + size_t versionlen; + err = index_json.MemberAsString("version", &version, &versionlen); + GOTO_IF_ERR(err, earlyexit); + model_index->set_version(std::string(version, versionlen)); + } + if (index_json.Find("state")) { + const char* state; + size_t statelen; + err = index_json.MemberAsString("state", &state, &statelen); + GOTO_IF_ERR(err, earlyexit); + model_index->set_state(std::string(state, 
statelen));
+        }
+        if (index_json.Find("reason")) {
+          const char* reason;
+          size_t reasonlen;
+          err = index_json.MemberAsString("reason", &reason, &reasonlen);
+          GOTO_IF_ERR(err, earlyexit);
+          model_index->set_reason(std::string(reason, reasonlen));
+        }
+      }
-      earlyexit:
-        GrpcStatusUtil::Create(status, err);
-        TRITONSERVER_ErrorDelete(err);
-      };
+      TRITONSERVER_MessageDelete(model_index_message);
+    } else {
+      err = TRITONSERVER_ErrorNew(
+          TRITONSERVER_ERROR_UNSUPPORTED,
+          "'repository_name' specification is not supported");
+    }
+
+  earlyexit:
+    GrpcStatusUtil::Create(status, err);
+    TRITONSERVER_ErrorDelete(err);
+  };
   const std::pair<std::string, std::string>& restricted_kv =
       restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY);
-  new CommonCallData<
-      ::grpc::ServerAsyncResponseWriter<inference::RepositoryIndexResponse>,
-      inference::RepositoryIndexRequest, inference::RepositoryIndexResponse>(
-      "RepositoryIndex", 0, OnRegisterRepositoryIndex, OnExecuteRepositoryIndex,
-      false /* async */, cq_, restricted_kv, response_delay_);
+
+  // Use non_inference_callback_service_->RepositoryIndex to register the
+  // callback. This replaces the use of CommonCallData.
+  non_inference_callback_service_->RepositoryIndex(
+      // Create a new CommonCallbackData object with the callback function
+      // and register it with the non_inference_callback_service_.
+      new CommonCallbackData<
+          inference::RepositoryIndexRequest,
+          inference::RepositoryIndexResponse>(
+          "RepositoryIndex", non_inference_callback_service_, callback,
+          restricted_kv));
 }
 
 void
 CommonHandler::RegisterRepositoryModelLoad()
 {
-  auto OnRegisterRepositoryModelLoad =
-      [this](
-          ::grpc::ServerContext* ctx,
-          inference::RepositoryModelLoadRequest* request,
-          ::grpc::ServerAsyncResponseWriter<
-              inference::RepositoryModelLoadResponse>* responder,
-          void* tag) {
-        this->service_->RequestRepositoryModelLoad(
-            ctx, request, responder, this->cq_, this->cq_, tag);
-      };
-
-  auto OnExecuteRepositoryModelLoad =
-      [this](
-          inference::RepositoryModelLoadRequest& request,
-          inference::RepositoryModelLoadResponse* response,
-          ::grpc::Status* status) {
-        TRITONSERVER_Error* err = nullptr;
-        if (request.repository_name().empty()) {
-          std::vector<TRITONSERVER_Parameter*> params;
-          // WAR for the const-ness check
-          std::vector<const TRITONSERVER_Parameter*> const_params;
-          for (const auto& param_proto : request.parameters()) {
-            if (param_proto.first == "config") {
-              if (param_proto.second.parameter_choice_case() !=
-                  inference::ModelRepositoryParameter::ParameterChoiceCase::
-                      kStringParam) {
-                err = TRITONSERVER_ErrorNew(
-                    TRITONSERVER_ERROR_INVALID_ARG,
-                    (std::string("invalid value type for load parameter '") +
-                     param_proto.first + "', expected string_param.")
-                        .c_str());
-                break;
-              } else {
-                auto param = TRITONSERVER_ParameterNew(
-                    param_proto.first.c_str(), TRITONSERVER_PARAMETER_STRING,
-                    param_proto.second.string_param().c_str());
-                if (param != nullptr) {
-                  params.emplace_back(param);
-                  const_params.emplace_back(param);
-                } else {
-                  err = TRITONSERVER_ErrorNew(
-                      TRITONSERVER_ERROR_INTERNAL,
-                      "unexpected error on creating Triton parameter");
-                  break;
-                }
-              }
-            } else if (param_proto.first.rfind("file:", 0) == 0) {
-              if (param_proto.second.parameter_choice_case() !=
-                  inference::ModelRepositoryParameter::ParameterChoiceCase::
-                      kBytesParam) {
-                err = TRITONSERVER_ErrorNew(
-                    TRITONSERVER_ERROR_INVALID_ARG,
-                    (std::string("invalid value type for load parameter '") +
-                     param_proto.first + "', expected bytes_param.")
-                        .c_str());
-                break;
-              } else {
-                auto param = TRITONSERVER_ParameterBytesNew(
-                    param_proto.first.c_str(),
-                    param_proto.second.bytes_param().data(),
-                    param_proto.second.bytes_param().length());
-                if (param != nullptr) {
-                  params.emplace_back(param);
-                  const_params.emplace_back(param);
-                } else {
-                  err = TRITONSERVER_ErrorNew(
-                      TRITONSERVER_ERROR_INTERNAL,
-                      "unexpected error on creating Triton parameter");
-                  break;
-                }
-              }
+  // Define a lambda function 'callback' that takes a
+  // RepositoryModelLoadRequest, a RepositoryModelLoadResponse, and a
+  // grpc::Status. This function performs the same logic as the original
+  // OnExecuteRepositoryModelLoad function.
+  auto callback = [this](
+                      inference::RepositoryModelLoadRequest& request,
+                      inference::RepositoryModelLoadResponse* response,
+                      ::grpc::Status* status) {
+    TRITONSERVER_Error* err = nullptr;
+    if (request.repository_name().empty()) {
+      std::vector<TRITONSERVER_Parameter*> params;
+      // WAR for the const-ness check
+      std::vector<const TRITONSERVER_Parameter*> const_params;
+      for (const auto& param_proto : request.parameters()) {
+        if (param_proto.first == "config") {
+          if (param_proto.second.parameter_choice_case() !=
+              inference::ModelRepositoryParameter::ParameterChoiceCase::
+                  kStringParam) {
+            err = TRITONSERVER_ErrorNew(
+                TRITONSERVER_ERROR_INVALID_ARG,
+                (std::string("invalid value type for load parameter '") +
+                 param_proto.first + "', expected string_param.")
+                    .c_str());
+            break;
+          } else {
+            auto param = TRITONSERVER_ParameterNew(
+                param_proto.first.c_str(), TRITONSERVER_PARAMETER_STRING,
+                param_proto.second.string_param().c_str());
+            if (param != nullptr) {
+              params.emplace_back(param);
+              const_params.emplace_back(param);
             } else {
               err = TRITONSERVER_ErrorNew(
-                  TRITONSERVER_ERROR_INVALID_ARG,
-                  (std::string("unrecognized load parameter '") +
-                   param_proto.first + "'.")
-                      .c_str());
+                  TRITONSERVER_ERROR_INTERNAL,
+                  "unexpected error on creating Triton parameter");
               break;
             }
           }
-          if (err == nullptr) {
-            err = TRITONSERVER_ServerLoadModelWithParameters(
-                tritonserver_.get(), request.model_name().c_str(),
-                const_params.data(), const_params.size());
-          }
-          // Assumes no further 'params' access after load API returns
-          for (auto& param : params) {
-            TRITONSERVER_ParameterDelete(param);
+        } else if (param_proto.first.rfind("file:", 0) == 0) {
+          if (param_proto.second.parameter_choice_case() !=
+              inference::ModelRepositoryParameter::ParameterChoiceCase::
+                  kBytesParam) {
+            err = TRITONSERVER_ErrorNew(
+                TRITONSERVER_ERROR_INVALID_ARG,
+                (std::string("invalid value type for load parameter '") +
+                 param_proto.first + "', expected bytes_param.")
+                    .c_str());
+            break;
+          } else {
+            auto param = TRITONSERVER_ParameterBytesNew(
+                param_proto.first.c_str(),
+                param_proto.second.bytes_param().data(),
+                param_proto.second.bytes_param().length());
+            if (param != nullptr) {
+              params.emplace_back(param);
+              const_params.emplace_back(param);
+            } else {
+              err = TRITONSERVER_ErrorNew(
+                  TRITONSERVER_ERROR_INTERNAL,
+                  "unexpected error on creating Triton parameter");
+              break;
+            }
           }
         } else {
           err = TRITONSERVER_ErrorNew(
-              TRITONSERVER_ERROR_UNSUPPORTED,
-              "'repository_name' specification is not supported");
+              TRITONSERVER_ERROR_INVALID_ARG,
+              (std::string("unrecognized load parameter '") +
+               param_proto.first + "'.")
+                  .c_str());
+          break;
         }
+      }
+      if (err == nullptr) {
+        err = TRITONSERVER_ServerLoadModelWithParameters(
+            tritonserver_.get(), request.model_name().c_str(),
+            const_params.data(), const_params.size());
+      }
+      // Assumes no further 'params' access after load API returns
+      for (auto& param : params) {
+        TRITONSERVER_ParameterDelete(param);
+      }
+    } else {
+      err = TRITONSERVER_ErrorNew(
+          TRITONSERVER_ERROR_UNSUPPORTED,
+          "'repository_name' specification is not supported");
+    }
-        GrpcStatusUtil::Create(status, err);
-        TRITONSERVER_ErrorDelete(err);
-      };
+    GrpcStatusUtil::Create(status, err);
+    TRITONSERVER_ErrorDelete(err);
+  };
   const std::pair<std::string, std::string>& restricted_kv =
       restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY);
-  new CommonCallData<
-      ::grpc::ServerAsyncResponseWriter<inference::RepositoryModelLoadResponse>,
-      inference::RepositoryModelLoadRequest,
-      inference::RepositoryModelLoadResponse>(
-      "RepositoryModelLoad", 0, OnRegisterRepositoryModelLoad,
-      OnExecuteRepositoryModelLoad, true /* async */, cq_, restricted_kv,
-      response_delay_);
+
+  // Use non_inference_callback_service_->RepositoryModelLoad to register the
+  // callback. This replaces the use of CommonCallData.
+  non_inference_callback_service_->RepositoryModelLoad(
+      // Create a new CommonCallbackData object with the callback function
+      // and register it with the non_inference_callback_service_.
+      new CommonCallbackData<
+          inference::RepositoryModelLoadRequest,
+          inference::RepositoryModelLoadResponse>(
+          "RepositoryModelLoad", non_inference_callback_service_, callback,
+          restricted_kv));
 }
 
 void
 CommonHandler::RegisterRepositoryModelUnload()
 {
-  auto OnRegisterRepositoryModelUnload =
-      [this](
-          ::grpc::ServerContext* ctx,
-          inference::RepositoryModelUnloadRequest* request,
-          ::grpc::ServerAsyncResponseWriter<
-              inference::RepositoryModelUnloadResponse>* responder,
-          void* tag) {
-        this->service_->RequestRepositoryModelUnload(
-            ctx, request, responder, this->cq_, this->cq_, tag);
-      };
-
-  auto OnExecuteRepositoryModelUnload =
-      [this](
-          inference::RepositoryModelUnloadRequest& request,
-          inference::RepositoryModelUnloadResponse* response,
-          ::grpc::Status* status) {
-        TRITONSERVER_Error* err = nullptr;
-        if (request.repository_name().empty()) {
-          // Check if the dependent models should be removed
-          bool unload_dependents = false;
-          for (auto param : request.parameters()) {
-            if (param.first.compare("unload_dependents") == 0) {
-              const auto& unload_param = param.second;
-              if (unload_param.parameter_choice_case() !=
-                  inference::ModelRepositoryParameter::ParameterChoiceCase::
-                      kBoolParam) {
-                err = TRITONSERVER_ErrorNew(
-                    TRITONSERVER_ERROR_INVALID_ARG,
-                    "invalid value type for 'unload_dependents' parameter, "
-                    "expected "
-                    "bool_param.");
-              }
-              unload_dependents = unload_param.bool_param();
-              break;
-            }
-          }
-          if (err == nullptr) {
-            if (unload_dependents) {
-              err = TRITONSERVER_ServerUnloadModelAndDependents(
-                  tritonserver_.get(), request.model_name().c_str());
+  // Define a lambda function 'callback' that takes a
+  // RepositoryModelUnloadRequest, a RepositoryModelUnloadResponse, and a
+  // grpc::Status. This function performs the same logic as the original
+  // OnExecuteRepositoryModelUnload function.
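+  // Illustrative sketch (not part of this change): a hypothetical client call
+  // against this handler, including the optional 'unload_dependents' parameter
+  // it recognizes. 'stub', 'ctx', and the model name are placeholders for a
+  // synchronous stub generated from grpc_service.proto.
+  //
+  //   inference::RepositoryModelUnloadRequest req;
+  //   req.set_model_name("densenet");
+  //   inference::ModelRepositoryParameter p;
+  //   p.set_bool_param(true);  // also unload models that depend on this one
+  //   (*req.mutable_parameters())["unload_dependents"] = p;
+  //   inference::RepositoryModelUnloadResponse resp;
+  //   ::grpc::Status s = stub->RepositoryModelUnload(&ctx, req, &resp);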
+  auto callback = [this](
+                      inference::RepositoryModelUnloadRequest& request,
+                      inference::RepositoryModelUnloadResponse* response,
+                      ::grpc::Status* status) {
+    TRITONSERVER_Error* err = nullptr;
+    if (request.repository_name().empty()) {
+      // Check if the dependent models should be removed
+      bool unload_dependents = false;
+      for (auto param : request.parameters()) {
+        if (param.first.compare("unload_dependents") == 0) {
+          const auto& unload_param = param.second;
+          if (unload_param.parameter_choice_case() !=
+              inference::ModelRepositoryParameter::ParameterChoiceCase::
+                  kBoolParam) {
+            err = TRITONSERVER_ErrorNew(
+                TRITONSERVER_ERROR_INVALID_ARG,
+                "invalid value type for 'unload_dependents' parameter, "
+                "expected "
+                "bool_param.");
          }
+          unload_dependents = unload_param.bool_param();
+          break;
+        }
+      }
+      if (err == nullptr) {
+        if (unload_dependents) {
+          err = TRITONSERVER_ServerUnloadModelAndDependents(
+              tritonserver_.get(), request.model_name().c_str());
        } else {
-          err = TRITONSERVER_ErrorNew(
-              TRITONSERVER_ERROR_UNSUPPORTED,
-              "'repository_name' specification is not supported");
+          err = TRITONSERVER_ServerUnloadModel(
+              tritonserver_.get(), request.model_name().c_str());
        }
+      }
+    } else {
+      err = TRITONSERVER_ErrorNew(
+          TRITONSERVER_ERROR_UNSUPPORTED,
+          "'repository_name' specification is not supported");
+    }
-        GrpcStatusUtil::Create(status, err);
-        TRITONSERVER_ErrorDelete(err);
-      };
+    GrpcStatusUtil::Create(status, err);
+    TRITONSERVER_ErrorDelete(err);
+  };
   const std::pair<std::string, std::string>& restricted_kv =
       restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY);
-  new CommonCallData<
-      ::grpc::ServerAsyncResponseWriter<
-          inference::RepositoryModelUnloadResponse>,
-      inference::RepositoryModelUnloadRequest,
-      inference::RepositoryModelUnloadResponse>(
-      "RepositoryModelUnload", 0, OnRegisterRepositoryModelUnload,
-      OnExecuteRepositoryModelUnload, true /* async */, cq_, restricted_kv,
-      response_delay_);
+
+  // Use non_inference_callback_service_->RepositoryModelUnload to register the
+  // callback. This replaces the use of CommonCallData.
+  non_inference_callback_service_->RepositoryModelUnload(
+      // Create a new CommonCallbackData object with the callback function
+      // and register it with the non_inference_callback_service_.
+      new CommonCallbackData<
+          inference::RepositoryModelUnloadRequest,
+          inference::RepositoryModelUnloadResponse>(
+          "RepositoryModelUnload", non_inference_callback_service_, callback,
+          restricted_kv));
 }
 
 }  // namespace

From e52cde3e5d23b706412c38f1432a138d8eeda36a Mon Sep 17 00:00:00 2001
From: Indrajit Bhosale
Date: Mon, 24 Feb 2025 11:03:13 -0800
Subject: [PATCH 03/12] Non Inference Code Added

---
 src/grpc/grpc_handler.h |   11 +-
 src/grpc/grpc_server.cc | 2033 ++++++++++++++-------------------------
 src/grpc/grpc_server.h  |    2 +-
 3 files changed, 731 insertions(+), 1315 deletions(-)

diff --git a/src/grpc/grpc_handler.h b/src/grpc/grpc_handler.h
index 4f1bcdfac0..ce9a30667e 100644
--- a/src/grpc/grpc_handler.h
+++ b/src/grpc/grpc_handler.h
@@ -1,4 +1,4 @@
-// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -25,14 +25,23 @@
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once +#include + #include +#include "grpc_service.grpc.pb.h" + namespace triton { namespace server { namespace grpc { class HandlerBase { public: virtual ~HandlerBase() = default; virtual void Start() = 0; virtual void Stop() = 0; + virtual inference::GRPCInferenceService::CallbackService* + GetUnifiedCallbackService() + { + return nullptr; + } }; class ICallData { diff --git a/src/grpc/grpc_server.cc b/src/grpc/grpc_server.cc index a15912eb41..ab644e3a07 100644 --- a/src/grpc/grpc_server.cc +++ b/src/grpc/grpc_server.cc @@ -79,560 +79,201 @@ namespace { // are deemed to be not performance critical. //========================================================================= -template -class CommonCallbackData { +// Define your unified callback service that implements several non-inference +// RPCs. +class UnifiedCallbackService + : public inference::GRPCInferenceService::CallbackService { public: - using CallbackFunc = - std::function; - - CommonCallbackData( - const std::string& name, - inference::GRPCInferenceService::CallbackService* service, - const CallbackFunc& callback, - const std::pair& restricted_kv) - : name_(name), service_(service), callback_(callback), - restricted_kv_(restricted_kv) - { - } - - void operator()(RequestType* request) - { - ResponseType response; - ::grpc::Status status; - - if (ExecutePrecondition()) { - callback_(*request, &response, &status); - } else { - status = ::grpc::Status( - ::grpc::StatusCode::UNAVAILABLE, - std::string("This protocol is restricted, expecting header '") + - restricted_kv_.first + "'"); - } - - request->request()->Complete(status); - delete this; - } - - private: - bool ExecutePrecondition() - { - if (!restricted_kv_.first.empty()) { - const auto& metadata = request->context()->client_metadata(); - const auto it = metadata.find(restricted_kv_.first); - return (it != metadata.end()) && (it->second == restricted_kv_.second); - } - return true; - } - - const std::string name_; - inference::GRPCInferenceService::CallbackService* service_; - CallbackFunc callback_; - std::pair restricted_kv_; -}; - -template -class CommonCallData : public ICallData { - public: - using StandardRegisterFunc = std::function; - using StandardCallbackFunc = - std::function; - - CommonCallData( - const std::string& name, const uint64_t id, - const StandardRegisterFunc OnRegister, - const StandardCallbackFunc OnExecute, const bool async, - ::grpc::ServerCompletionQueue* cq, - const std::pair& restricted_kv, - const uint64_t& response_delay = 0) - : name_(name), id_(id), OnRegister_(OnRegister), OnExecute_(OnExecute), - async_(async), cq_(cq), responder_(&ctx_), step_(Steps::START), - restricted_kv_(restricted_kv), response_delay_(response_delay) - { - OnRegister_(&ctx_, &request_, &responder_, this); - LOG_VERBOSE(1) << "Ready for RPC '" << name_ << "', " << id_; - } - - ~CommonCallData() + UnifiedCallbackService( + const std::shared_ptr& server, + const std::shared_ptr& shm_manager, + const std::pair& restrictedKV) + : tritonserver_(server), shm_manager_(shm_manager), + restricted_kv_(restrictedKV) { - if (async_thread_.joinable()) { - async_thread_.join(); - } - } - - bool Process(bool ok) override; - - std::string Name() override { return name_; } - - uint64_t Id() override { return id_; } - - private: - void Execute(); - void AddToCompletionQueue(); - void WriteResponse(); - bool ExecutePrecondition(); - - const std::string name_; - const uint64_t id_; - const StandardRegisterFunc OnRegister_; - const StandardCallbackFunc OnExecute_; - const bool 
async_; - ::grpc::ServerCompletionQueue* cq_; - - ::grpc::ServerContext ctx_; - ::grpc::Alarm alarm_; - - ResponderType responder_; - RequestType request_; - ResponseType response_; - ::grpc::Status status_; - - std::thread async_thread_; - - Steps step_; - - std::pair restricted_kv_{"", ""}; - - const uint64_t response_delay_; -}; - -template -bool -CommonCallData::Process(bool rpc_ok) -{ - LOG_VERBOSE(1) << "Process for " << name_ << ", rpc_ok=" << rpc_ok << ", " - << id_ << " step " << step_; - - // If RPC failed on a new request then the server is shutting down - // and so we should do nothing (including not registering for a new - // request). If RPC failed on a non-START step then there is nothing - // we can do since we one execute one step. - const bool shutdown = (!rpc_ok && (step_ == Steps::START)); - if (shutdown) { - if (async_thread_.joinable()) { - async_thread_.join(); - } - step_ = Steps::FINISH; - } - - if (step_ == Steps::START) { - // Start a new request to replace this one... - if (!shutdown) { - new CommonCallData( - name_, id_ + 1, OnRegister_, OnExecute_, async_, cq_, restricted_kv_, - response_delay_); - } - - if (!async_) { - // For synchronous calls, execute and write response - // here. - Execute(); - WriteResponse(); - } else { - // For asynchronous calls, delegate the execution to another - // thread. - step_ = Steps::ISSUED; - async_thread_ = std::thread(&CommonCallData::Execute, this); - } - } else if (step_ == Steps::WRITEREADY) { - // Will only come here for asynchronous mode. - WriteResponse(); - } else if (step_ == Steps::COMPLETE) { - step_ = Steps::FINISH; - } - - return step_ != Steps::FINISH; -} - -template -void -CommonCallData::Execute() -{ - if (ExecutePrecondition()) { - OnExecute_(request_, &response_, &status_); - } else { - status_ = ::grpc::Status( - ::grpc::StatusCode::UNAVAILABLE, - std::string("This protocol is restricted, expecting header '") + - restricted_kv_.first + "'"); - } - step_ = Steps::WRITEREADY; - - if (async_) { - // For asynchronous operation, need to add itself onto the completion - // queue so that the response can be written once the object is - // taken up next for execution. - AddToCompletionQueue(); } -} -template -bool -CommonCallData::ExecutePrecondition() -{ - if (!restricted_kv_.first.empty()) { - const auto& metadata = ctx_.client_metadata(); - const auto it = metadata.find(restricted_kv_.first); - return (it != metadata.end()) && (it->second == restricted_kv_.second); - } - return true; -} - -template -void -CommonCallData::AddToCompletionQueue() -{ - alarm_.Set(cq_, gpr_now(gpr_clock_type::GPR_CLOCK_REALTIME), this); -} - -template -void -CommonCallData::WriteResponse() -{ - if (response_delay_ != 0) { - // Will delay the write of the response by the specified time. - // This can be used to test the flow where there are other - // responses available to be written. - LOG_VERBOSE(1) << "Delaying the write of the response by " - << response_delay_ << " seconds"; - std::this_thread::sleep_for(std::chrono::seconds(response_delay_)); - } - step_ = Steps::COMPLETE; - responder_.Finish(response_, status_, this); -} - -// -// CommonHandler -// -// A common handler for all non-inference requests. 
-// -class CommonHandler : public HandlerBase { - public: - CommonHandler( - const std::string& name, - const std::shared_ptr& tritonserver, - const std::shared_ptr& shm_manager, - TraceManager* trace_manager, - inference::GRPCInferenceService::AsyncService* service, - ::grpc::health::v1::Health::AsyncService* health_service, - inference::GRPCInferenceService::CallbackService* - non_inference_callback_service, - const RestrictedFeatures& restricted_keys, const uint64_t response_delay); - - // Descriptive name of of the handler. - const std::string& Name() const { return name_; } - - // Start handling requests. - void Start() override; - - // Stop handling requests. - void Stop() override; - - private: - void SetUpAllRequests(); - - // [FIXME] turn into generated code - void RegisterServerLive(); - void RegisterServerReady(); - void RegisterHealthCheck(); - void RegisterModelReady(); - void RegisterServerMetadata(); - void RegisterModelMetadata(); - void RegisterModelConfig(); - void RegisterModelStatistics(); - void RegisterTrace(); - void RegisterLogging(); - void RegisterSystemSharedMemoryStatus(); - void RegisterSystemSharedMemoryRegister(); - void RegisterSystemSharedMemoryUnregister(); - void RegisterCudaSharedMemoryStatus(); - void RegisterCudaSharedMemoryRegister(); - void RegisterCudaSharedMemoryUnregister(); - void RegisterRepositoryIndex(); - void RegisterRepositoryModelLoad(); - void RegisterRepositoryModelUnload(); - - // Set count and cumulative duration for 'RegisterModelStatistics()' template TRITONSERVER_Error* SetStatisticsDuration( triton::common::TritonJson::Value& statistics_json, const std::string& statistics_name, - PBTYPE* mutable_statistics_duration_protobuf) const; - - const std::string name_; - std::shared_ptr tritonserver_; - - std::shared_ptr shm_manager_; - TraceManager* trace_manager_; - - inference::GRPCInferenceService::AsyncService* service_; - ::grpc::health::v1::Health::AsyncService* health_service_; - inference::GRPCInferenceService::CallbackService* - non_inference_callback_service_; - - ::grpc::ServerCompletionQueue* cq_; - std::unique_ptr thread_; - RestrictedFeatures restricted_keys_{}; - const uint64_t response_delay_ = 0; -}; + PBTYPE* mutable_statistics_duration_protobuf) + { + triton::common::TritonJson::Value statistics_duration_json; + RETURN_IF_ERR(statistics_json.MemberAsObject( + statistics_name.c_str(), &statistics_duration_json)); + + uint64_t value; + RETURN_IF_ERR(statistics_duration_json.MemberAsUInt("count", &value)); + mutable_statistics_duration_protobuf->set_count(value); + RETURN_IF_ERR(statistics_duration_json.MemberAsUInt("ns", &value)); + mutable_statistics_duration_protobuf->set_ns(value); + return nullptr; + } -CommonHandler::CommonHandler( - const std::string& name, - const std::shared_ptr& tritonserver, - const std::shared_ptr& shm_manager, - TraceManager* trace_manager, - inference::GRPCInferenceService::AsyncService* service, - ::grpc::health::v1::Health::AsyncService* health_service, - ::grpc::ServerCompletionQueue* cq, - const RestrictedFeatures& restricted_keys, - const uint64_t response_delay = 0) - : name_(name), tritonserver_(tritonserver), shm_manager_(shm_manager), - trace_manager_(trace_manager), service_(service), - health_service_(health_service), - non_inference_callback_service_(non_inference_callback_service), cq_(cq), - restricted_keys_(restricted_keys), response_delay_(response_delay) -{ -} + // Example RPC method: ServerLive + ::grpc::ServerUnaryReactor* ServerLive( + ::grpc::CallbackServerContext* 
context, + const inference::ServerLiveRequest* request, + inference::ServerLiveResponse* response) override + { + auto* reactor = context->DefaultReactor(); -void -CommonHandler::Start() -{ - // Use a barrier to make sure we don't return until thread has - // started. - auto barrier = std::make_shared(2); - - thread_.reset(new std::thread([this, barrier] { - SetUpAllRequests(); - barrier->Wait(); - - void* tag; - bool ok; - - while (cq_->Next(&tag, &ok)) { - ICallData* call_data = static_cast(tag); - if (!call_data->Process(ok)) { - LOG_VERBOSE(1) << "Done for " << call_data->Name() << ", " - << call_data->Id(); - delete call_data; + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; } } - })); - - barrier->Wait(); - LOG_VERBOSE(1) << "Thread started for " << Name(); -} - -void -CommonHandler::Stop() -{ - if (thread_->joinable()) { - thread_->join(); - } - - LOG_VERBOSE(1) << "Thread exited for " << Name(); -} -void -CommonHandler::SetUpAllRequests() -{ - // Define all the RPCs to be handled by this handler below - // - // Within each of the Register function, the format of RPC specification is: - // 1. A OnRegister function: This will be called when the - // server is ready to receive the requests for this RPC. - // 2. A OnExecute function: This will be called when the - // to process the request. - // 3. Create a CommonCallData object with the above callback - // functions - - // health (GRPC standard) - RegisterHealthCheck(); - // health (Triton) - RegisterServerLive(); - RegisterServerReady(); - RegisterModelReady(); - - // Metadata - RegisterServerMetadata(); - RegisterModelMetadata(); - - // model config - RegisterModelConfig(); - - // shared memory - // system.. - RegisterSystemSharedMemoryStatus(); - RegisterSystemSharedMemoryRegister(); - RegisterSystemSharedMemoryUnregister(); - // cuda.. - RegisterCudaSharedMemoryStatus(); - RegisterCudaSharedMemoryRegister(); - RegisterCudaSharedMemoryUnregister(); - - // model repository - RegisterRepositoryIndex(); - RegisterRepositoryModelLoad(); - RegisterRepositoryModelUnload(); - - // statistics - RegisterModelStatistics(); - - // trace - RegisterTrace(); - - // logging - RegisterLogging(); -} - -void -CommonHandler::RegisterServerLive() -{ - auto OnRegisterServerLive = - [this]( - ::grpc::ServerContext* ctx, inference::ServerLiveRequest* request, - ::grpc::ServerAsyncResponseWriter* - responder, - void* tag) { - this->service_->RequestServerLive( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteServerLive = [this]( - inference::ServerLiveRequest& request, - inference::ServerLiveResponse* response, - ::grpc::Status* status) { + // Business logic for ServerLive. 
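+    // The DefaultReactor returned above is owned by the gRPC runtime for the
+    // duration of this call; the handler only needs to call Finish() on it
+    // exactly once (as done below), after which the library sends the
+    // response and reclaims the reactor.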
bool live = false; TRITONSERVER_Error* err = TRITONSERVER_ServerIsLive(tritonserver_.get(), &live); - response->set_live((err == nullptr) && live); - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::HEALTH); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter, - inference::ServerLiveRequest, inference::ServerLiveResponse>( - "ServerLive", 0, OnRegisterServerLive, OnExecuteServerLive, - false /* async */, cq_, restricted_kv, response_delay_); -} + reactor->Finish(status); + return reactor; + } -// This change leverages the callback API, simplifying the handling of the -// ServerReady request by directly using the non_inference_callback_service_. -void -CommonHandler::RegisterServerReady() -{ - // Define a lambda function 'callback' that takes a ServerReadyRequest, - // a ServerReadyResponse, and a grpc::Status. This function performs - // the same logic as the original OnExecuteServerReady function. - auto callback = [this]( - inference::ServerReadyRequest& request, - inference::ServerReadyResponse* response, - ::grpc::Status* status) { + ::grpc::ServerUnaryReactor* ServerReady( + ::grpc::CallbackServerContext* context, + const inference::ServerReadyRequest* request, + inference::ServerReadyResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Business logic for ServerReady. bool ready = false; TRITONSERVER_Error* err = TRITONSERVER_ServerIsReady(tritonserver_.get(), &ready); - response->set_ready((err == nullptr) && ready); - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::HEALTH); - - // Use non_inference_callback_service_->ServerReady to register the callback. - // This replaces the use of CommonCallData. - non_inference_callback_service_->ServerReady( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. - new CommonCallbackData< - inference::ServerReadyRequest, inference::ServerReadyResponse>( - "ServerReady", non_inference_callback_service_, callback, - restricted_kv)); -} + reactor->Finish(status); + return reactor; + } -void -CommonHandler::RegisterHealthCheck() -{ - auto callback = [this]( - ::grpc::health::v1::HealthCheckRequest& request, - ::grpc::health::v1::HealthCheckResponse* response, - ::grpc::Status* status) { - bool live = false; - TRITONSERVER_Error* err = - TRITONSERVER_ServerIsReady(tritonserver_.get(), &live); + // ::grpc::ServerUnaryReactor* Check( + // ::grpc::CallbackServerContext* context, + // const ::grpc::health::v1::HealthCheckRequest* request, + // ::grpc::health::v1::HealthCheckResponse* response) override { + // auto* reactor = context->DefaultReactor(); + + // // (Optionally) Check client metadata for restricted access. 
+ // if (!restricted_kv_.first.empty()) { + // const auto& metadata = context->client_metadata(); + // auto it = metadata.find(restricted_kv_.first); + // if (it == metadata.end() || it->second != restricted_kv_.second) { + // reactor->Finish(::grpc::Status(::grpc::StatusCode::UNAVAILABLE, + // "Missing or mismatched restricted header")); + // return reactor; + // } + // } + + // // Business logic for HealthCheck. + // bool live = false; + // TRITONSERVER_Error* err = TRITONSERVER_ServerIsReady(tritonserver_.get(), + // &live); + + // auto serving_status = + // ::grpc::health::v1::HealthCheckResponse_ServingStatus_UNKNOWN; if (err == + // nullptr) { + // serving_status = live + // ? ::grpc::health::v1::HealthCheckResponse_ServingStatus_SERVING + // : + // ::grpc::health::v1::HealthCheckResponse_ServingStatus_NOT_SERVING; + // } + // response->set_status(serving_status); + + // ::grpc::Status status; + // GrpcStatusUtil::Create(&status, err); + // TRITONSERVER_ErrorDelete(err); + // reactor->Finish(status); + // return reactor; + // } + + ::grpc::ServerUnaryReactor* ModelReady( + ::grpc::CallbackServerContext* context, + const inference::ModelReadyRequest* request, + inference::ModelReadyResponse* response) override + { + auto* reactor = context->DefaultReactor(); - auto serving_status = - ::grpc::health::v1::HealthCheckResponse_ServingStatus_UNKNOWN; - if (err == nullptr) { - serving_status = - live ? ::grpc::health::v1::HealthCheckResponse_ServingStatus_SERVING - : ::grpc::health::v1:: - HealthCheckResponse_ServingStatus_NOT_SERVING; + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } } - response->set_status(serving_status); - - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::HEALTH); - non_inference_callback_service_->Check( - new CommonCallbackData< - ::grpc::health::v1::HealthCheckRequest, - ::grpc::health::v1::HealthCheckResponse>( - "Check", non_inference_callback_service_, callback, restricted_kv)); -} - -void -CommonHandler::RegisterModelReady() -{ - auto callback = [this]( - ::grpc::health::v1::HealthCheckRequest& request, - ::grpc::health::v1::HealthCheckResponse* response, - ::grpc::Status* status) { + // Business logic for ModelReady. 
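+    // Illustrative sketch (not part of this change): a hypothetical
+    // synchronous client call against this handler; 'stub', 'ctx', and the
+    // model name are placeholders for a stub generated from
+    // grpc_service.proto.
+    //
+    //   inference::ModelReadyRequest req;
+    //   req.set_name("densenet");
+    //   req.set_version("1");  // an empty version string lets the server choose
+    //   inference::ModelReadyResponse resp;
+    //   ::grpc::Status s = stub->ModelReady(&ctx, req, &resp);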
bool is_ready = false; int64_t requested_model_version; - auto err = - GetModelVersionFromString(request.version(), &requested_model_version); + TRITONSERVER_Error* err = + GetModelVersionFromString(request->version(), &requested_model_version); if (err == nullptr) { err = TRITONSERVER_ServerModelIsReady( - tritonserver_.get(), request.name().c_str(), requested_model_version, + tritonserver_.get(), request->name().c_str(), requested_model_version, &is_ready); } response->set_ready(is_ready); - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); - }; - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::HEALTH); - non_inference_callback_service_->ModelReady( - new CommonCallbackData< - ::grpc::health::v1::HealthCheckRequest, - ::grpc::health::v1::HealthCheckResponse>( - "ModelReady", non_inference_callback_service_, callback, - restricted_kv)); -} + reactor->Finish(status); + return reactor; + } -void -CommonHandler::RegisterServerMetadata() -{ - // Define a lambda function 'callback' that takes a ServerMetadataRequest, - // a ServerMetadataResponse, and a grpc::Status. This function performs - // the same logic as the original OnExecuteServerMetadata function. - auto callback = [this]( - inference::ServerMetadataRequest& request, - inference::ServerMetadataResponse* response, - ::grpc::Status* status) { + ::grpc::ServerUnaryReactor* ServerMetadata( + ::grpc::CallbackServerContext* context, + const inference::ServerMetadataRequest* request, + inference::ServerMetadataResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Business logic for ServerMetadata. TRITONSERVER_Message* server_metadata_message = nullptr; TRITONSERVER_Error* err = TRITONSERVER_ServerMetadata( tritonserver_.get(), &server_metadata_message); @@ -680,271 +321,271 @@ CommonHandler::RegisterServerMetadata() TRITONSERVER_MessageDelete(server_metadata_message); } - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::METADATA); - - // Use non_inference_callback_service_->ServerMetadata to register the - // callback. This replaces the use of CommonCallData. - non_inference_callback_service_->ServerMetadata( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. - new CommonCallbackData< - inference::ServerMetadataRequest, inference::ServerMetadataResponse>( - "ServerMetadata", non_inference_callback_service_, callback, - restricted_kv)); -} - -void -CommonHandler::RegisterModelMetadata() -{ - // Define a lambda function 'callback' that takes a ModelMetadataRequest, - // a ModelMetadataResponse, and a grpc::Status. This function performs - // the same logic as the original OnExecuteModelMetadata function. 
- auto callback = [this]( - inference::ModelMetadataRequest& request, - inference::ModelMetadataResponse* response, - ::grpc::Status* status) { - int64_t requested_model_version; - auto err = - GetModelVersionFromString(request.version(), &requested_model_version); - GOTO_IF_ERR(err, earlyexit); + reactor->Finish(status); + return reactor; + } - TRITONSERVER_Message* model_metadata_message = nullptr; - err = TRITONSERVER_ServerModelMetadata( - tritonserver_.get(), request.name().c_str(), requested_model_version, - &model_metadata_message); - GOTO_IF_ERR(err, earlyexit); + ::grpc::ServerUnaryReactor* ModelMetadata( + ::grpc::CallbackServerContext* context, + const inference::ModelMetadataRequest* request, + inference::ModelMetadataResponse* response) override + { + auto* reactor = context->DefaultReactor(); - const char* buffer; - size_t byte_size; - err = TRITONSERVER_MessageSerializeToJson( - model_metadata_message, &buffer, &byte_size); - GOTO_IF_ERR(err, earlyexit); + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } - triton::common::TritonJson::Value model_metadata_json; - err = model_metadata_json.Parse(buffer, byte_size); + // Core business logic - kept same as original + int64_t requested_model_version; + auto err = + GetModelVersionFromString(request->version(), &requested_model_version); GOTO_IF_ERR(err, earlyexit); + { + TRITONSERVER_Message* model_metadata_message = nullptr; + err = TRITONSERVER_ServerModelMetadata( + tritonserver_.get(), request->name().c_str(), requested_model_version, + &model_metadata_message); + GOTO_IF_ERR(err, earlyexit); - const char* name; - size_t namelen; - err = model_metadata_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); + const char* buffer; + size_t byte_size; + err = TRITONSERVER_MessageSerializeToJson( + model_metadata_message, &buffer, &byte_size); + GOTO_IF_ERR(err, earlyexit); - response->set_name(std::string(name, namelen)); + triton::common::TritonJson::Value model_metadata_json; + err = model_metadata_json.Parse(buffer, byte_size); + GOTO_IF_ERR(err, earlyexit); - if (model_metadata_json.Find("versions")) { - triton::common::TritonJson::Value versions_json; - err = model_metadata_json.MemberAsArray("versions", &versions_json); + const char* name; + size_t namelen; + err = model_metadata_json.MemberAsString("name", &name, &namelen); GOTO_IF_ERR(err, earlyexit); - for (size_t idx = 0; idx < versions_json.ArraySize(); ++idx) { - const char* version; - size_t versionlen; - err = versions_json.IndexAsString(idx, &version, &versionlen); + response->set_name(std::string(name, namelen)); + + if (model_metadata_json.Find("versions")) { + triton::common::TritonJson::Value versions_json; + err = model_metadata_json.MemberAsArray("versions", &versions_json); GOTO_IF_ERR(err, earlyexit); - response->add_versions(std::string(version, versionlen)); - } - } - const char* platform; - size_t platformlen; - err = - model_metadata_json.MemberAsString("platform", &platform, &platformlen); - GOTO_IF_ERR(err, earlyexit); - response->set_platform(std::string(platform, platformlen)); + for (size_t idx = 0; idx < versions_json.ArraySize(); ++idx) { + const char* version; + 
size_t versionlen; + err = versions_json.IndexAsString(idx, &version, &versionlen); + GOTO_IF_ERR(err, earlyexit); + response->add_versions(std::string(version, versionlen)); + } + } - if (model_metadata_json.Find("inputs")) { - triton::common::TritonJson::Value inputs_json; - err = model_metadata_json.MemberAsArray("inputs", &inputs_json); + const char* platform; + size_t platformlen; + err = model_metadata_json.MemberAsString( + "platform", &platform, &platformlen); GOTO_IF_ERR(err, earlyexit); + response->set_platform(std::string(platform, platformlen)); - for (size_t idx = 0; idx < inputs_json.ArraySize(); ++idx) { - triton::common::TritonJson::Value io_json; - err = inputs_json.IndexAsObject(idx, &io_json); + if (model_metadata_json.Find("inputs")) { + triton::common::TritonJson::Value inputs_json; + err = model_metadata_json.MemberAsArray("inputs", &inputs_json); GOTO_IF_ERR(err, earlyexit); - inference::ModelMetadataResponse::TensorMetadata* io = - response->add_inputs(); - - const char* name; - size_t namelen; - err = io_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); + for (size_t idx = 0; idx < inputs_json.ArraySize(); ++idx) { + triton::common::TritonJson::Value io_json; + err = inputs_json.IndexAsObject(idx, &io_json); + GOTO_IF_ERR(err, earlyexit); - const char* datatype; - size_t datatypelen; - err = io_json.MemberAsString("datatype", &datatype, &datatypelen); - GOTO_IF_ERR(err, earlyexit); + inference::ModelMetadataResponse::TensorMetadata* io = + response->add_inputs(); - io->set_name(std::string(name, namelen)); - io->set_datatype(std::string(datatype, datatypelen)); + const char* name; + size_t namelen; + err = io_json.MemberAsString("name", &name, &namelen); + GOTO_IF_ERR(err, earlyexit); - if (io_json.Find("shape")) { - triton::common::TritonJson::Value shape_json; - err = io_json.MemberAsArray("shape", &shape_json); + const char* datatype; + size_t datatypelen; + err = io_json.MemberAsString("datatype", &datatype, &datatypelen); GOTO_IF_ERR(err, earlyexit); - for (size_t sidx = 0; sidx < shape_json.ArraySize(); ++sidx) { - int64_t d; - err = shape_json.IndexAsInt(sidx, &d); + io->set_name(std::string(name, namelen)); + io->set_datatype(std::string(datatype, datatypelen)); + + if (io_json.Find("shape")) { + triton::common::TritonJson::Value shape_json; + err = io_json.MemberAsArray("shape", &shape_json); GOTO_IF_ERR(err, earlyexit); - io->add_shape(d); + + for (size_t sidx = 0; sidx < shape_json.ArraySize(); ++sidx) { + int64_t d; + err = shape_json.IndexAsInt(sidx, &d); + GOTO_IF_ERR(err, earlyexit); + io->add_shape(d); + } } } } - } - - if (model_metadata_json.Find("outputs")) { - triton::common::TritonJson::Value outputs_json; - err = model_metadata_json.MemberAsArray("outputs", &outputs_json); - GOTO_IF_ERR(err, earlyexit); - for (size_t idx = 0; idx < outputs_json.ArraySize(); ++idx) { - triton::common::TritonJson::Value io_json; - err = outputs_json.IndexAsObject(idx, &io_json); + if (model_metadata_json.Find("outputs")) { + triton::common::TritonJson::Value outputs_json; + err = model_metadata_json.MemberAsArray("outputs", &outputs_json); GOTO_IF_ERR(err, earlyexit); - inference::ModelMetadataResponse::TensorMetadata* io = - response->add_outputs(); - - const char* name; - size_t namelen; - err = io_json.MemberAsString("name", &name, &namelen); - GOTO_IF_ERR(err, earlyexit); + for (size_t idx = 0; idx < outputs_json.ArraySize(); ++idx) { + triton::common::TritonJson::Value io_json; + err = outputs_json.IndexAsObject(idx, &io_json); + 
GOTO_IF_ERR(err, earlyexit); - const char* datatype; - size_t datatypelen; - err = io_json.MemberAsString("datatype", &datatype, &datatypelen); - GOTO_IF_ERR(err, earlyexit); + inference::ModelMetadataResponse::TensorMetadata* io = + response->add_outputs(); - io->set_name(std::string(name, namelen)); - io->set_datatype(std::string(datatype, datatypelen)); + const char* name; + size_t namelen; + err = io_json.MemberAsString("name", &name, &namelen); + GOTO_IF_ERR(err, earlyexit); - if (io_json.Find("shape")) { - triton::common::TritonJson::Value shape_json; - err = io_json.MemberAsArray("shape", &shape_json); + const char* datatype; + size_t datatypelen; + err = io_json.MemberAsString("datatype", &datatype, &datatypelen); GOTO_IF_ERR(err, earlyexit); - for (size_t sidx = 0; sidx < shape_json.ArraySize(); ++sidx) { - int64_t d; - err = shape_json.IndexAsInt(sidx, &d); + io->set_name(std::string(name, namelen)); + io->set_datatype(std::string(datatype, datatypelen)); + + if (io_json.Find("shape")) { + triton::common::TritonJson::Value shape_json; + err = io_json.MemberAsArray("shape", &shape_json); GOTO_IF_ERR(err, earlyexit); - io->add_shape(d); + + for (size_t sidx = 0; sidx < shape_json.ArraySize(); ++sidx) { + int64_t d; + err = shape_json.IndexAsInt(sidx, &d); + GOTO_IF_ERR(err, earlyexit); + io->add_shape(d); + } } } } + TRITONSERVER_MessageDelete(model_metadata_message); } - earlyexit: - TRITONSERVER_MessageDelete(model_metadata_message); - GrpcStatusUtil::Create(status, err); + + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::METADATA); - - // Use non_inference_callback_service_->ModelMetadata to register the - // callback. This replaces the use of CommonCallData. - non_inference_callback_service_->ModelMetadata( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. - new CommonCallbackData< - inference::ModelMetadataRequest, inference::ModelMetadataResponse>( - "ModelMetadata", non_inference_callback_service_, callback, - restricted_kv)); -} + reactor->Finish(status); + return reactor; + } -void -CommonHandler::RegisterModelConfig() -{ - // Define a lambda function 'callback' that takes a ModelConfigRequest, - // a ModelConfigResponse, and a grpc::Status. This function performs - // the same logic as the original OnExecuteModelConfig function. - auto callback = [this]( - inference::ModelConfigRequest& request, - inference::ModelConfigResponse* response, - ::grpc::Status* status) { + ::grpc::ServerUnaryReactor* ModelConfig( + ::grpc::CallbackServerContext* context, + const inference::ModelConfigRequest* request, + inference::ModelConfigResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. 
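+    // The restricted pair is populated when the server is started with a
+    // restricted-feature configuration (for example via Triton's
+    // --grpc-restricted-protocol option); when set, every call in that
+    // category must carry the matching metadata header and value.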
+ if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Core business logic int64_t requested_model_version; auto err = - GetModelVersionFromString(request.version(), &requested_model_version); - GOTO_IF_ERR(err, earlyexit); - - TRITONSERVER_Message* model_config_message = nullptr; - err = TRITONSERVER_ServerModelConfig( - tritonserver_.get(), request.name().c_str(), requested_model_version, - 1 /* config_version */, &model_config_message); - GOTO_IF_ERR(err, earlyexit); + GetModelVersionFromString(request->version(), &requested_model_version); + if (err == nullptr) { + TRITONSERVER_Message* model_config_message = nullptr; + err = TRITONSERVER_ServerModelConfig( + tritonserver_.get(), request->name().c_str(), requested_model_version, + 1 /* config_version */, &model_config_message); + if (err == nullptr) { + const char* buffer; + size_t byte_size; + err = TRITONSERVER_MessageSerializeToJson( + model_config_message, &buffer, &byte_size); + if (err == nullptr) { + ::google::protobuf::util::JsonStringToMessage( + ::google::protobuf::stringpiece_internal::StringPiece( + buffer, (int)byte_size), + response->mutable_config()); + } + TRITONSERVER_MessageDelete(model_config_message); + } + } - const char* buffer; - size_t byte_size; - err = TRITONSERVER_MessageSerializeToJson( - model_config_message, &buffer, &byte_size); - GOTO_IF_ERR(err, earlyexit); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } - ::google::protobuf::util::JsonStringToMessage( - ::google::protobuf::stringpiece_internal::StringPiece( - buffer, static_cast(byte_size)), - response->mutable_config()); + // Other RPC methods (e.g., ServerReady, HealthCheck) would be implemented + // similarly. + ::grpc::ServerUnaryReactor* ModelStatistics( + ::grpc::CallbackServerContext* context, + const inference::ModelStatisticsRequest* request, + inference::ModelStatisticsResponse* response) override + { + auto* reactor = context->DefaultReactor(); - earlyexit: - TRITONSERVER_MessageDelete(model_config_message); - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::MODEL_CONFIG); - - // Use non_inference_callback_service_->ModelConfig to register the callback. - // This replaces the use of CommonCallData. - non_inference_callback_service_->ModelConfig( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. - new CommonCallbackData< - inference::ModelConfigRequest, inference::ModelConfigResponse>( - "ModelConfig", non_inference_callback_service_, callback, - restricted_kv)); -} + // (Optionally) Check client metadata for restricted access. 
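+    // Illustrative sketch (not part of this change): a hypothetical client
+    // call for this RPC ('stub' and 'ctx' are placeholders); the server must
+    // be built with TRITON_ENABLE_STATS for the body below to succeed.
+    //
+    //   inference::ModelStatisticsRequest req;
+    //   req.set_name("densenet");  // an empty name reports all models
+    //   inference::ModelStatisticsResponse resp;
+    //   ::grpc::Status s = stub->ModelStatistics(&ctx, req, &resp);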
+ if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } -void -CommonHandler::RegisterModelStatistics() -{ - // Define a lambda function 'callback' that takes a ModelStatisticsRequest, - // a ModelStatisticsResponse, and a grpc::Status. This function performs - // the same logic as the original OnExecuteModelStatistics function. - auto callback = [this]( - inference::ModelStatisticsRequest& request, - inference::ModelStatisticsResponse* response, - ::grpc::Status* status) { + // Core business logic - kept same as original #ifdef TRITON_ENABLE_STATS triton::common::TritonJson::Value model_stats_json; int64_t requested_model_version; auto err = - GetModelVersionFromString(request.version(), &requested_model_version); - GOTO_IF_ERR(err, earlyexit); - - TRITONSERVER_Message* model_stats_message = nullptr; - err = TRITONSERVER_ServerModelStatistics( - tritonserver_.get(), request.name().c_str(), requested_model_version, - &model_stats_message); + GetModelVersionFromString(request->version(), &requested_model_version); GOTO_IF_ERR(err, earlyexit); + { + TRITONSERVER_Message* model_stats_message = nullptr; + err = TRITONSERVER_ServerModelStatistics( + tritonserver_.get(), request->name().c_str(), requested_model_version, + &model_stats_message); + GOTO_IF_ERR(err, earlyexit); - const char* buffer; - size_t byte_size; - err = TRITONSERVER_MessageSerializeToJson( - model_stats_message, &buffer, &byte_size); - GOTO_IF_ERR(err, earlyexit); + const char* buffer; + size_t byte_size; + err = TRITONSERVER_MessageSerializeToJson( + model_stats_message, &buffer, &byte_size); + GOTO_IF_ERR(err, earlyexit); - err = model_stats_json.Parse(buffer, byte_size); - GOTO_IF_ERR(err, earlyexit); + err = model_stats_json.Parse(buffer, byte_size); + GOTO_IF_ERR(err, earlyexit); - TRITONSERVER_MessageDelete(model_stats_message); + TRITONSERVER_MessageDelete(model_stats_message); + } if (model_stats_json.Find("model_stats")) { triton::common::TritonJson::Value stats_json; @@ -1129,62 +770,42 @@ CommonHandler::RegisterModelStatistics() } earlyexit: - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); #else auto err = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNAVAILABLE, "the server does not support model statistics"); - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); #endif - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::STATISTICS); - - // Use non_inference_callback_service_->ModelStatistics to register the - // callback. This replaces the use of CommonCallData. - non_inference_callback_service_->ModelStatistics( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. 
- new CommonCallbackData< - inference::ModelStatisticsRequest, - inference::ModelStatisticsResponse>( - "ModelStatistics", non_inference_callback_service_, callback, - restricted_kv)); -} -template -TRITONSERVER_Error* -CommonHandler::SetStatisticsDuration( - triton::common::TritonJson::Value& statistics_json, - const std::string& statistics_name, - PBTYPE* mutable_statistics_duration_protobuf) const -{ - triton::common::TritonJson::Value statistics_duration_json; - RETURN_IF_ERR(statistics_json.MemberAsObject( - statistics_name.c_str(), &statistics_duration_json)); + reactor->Finish(status); + return reactor; + } - uint64_t value; - RETURN_IF_ERR(statistics_duration_json.MemberAsUInt("count", &value)); - mutable_statistics_duration_protobuf->set_count(value); - RETURN_IF_ERR(statistics_duration_json.MemberAsUInt("ns", &value)); - mutable_statistics_duration_protobuf->set_ns(value); + ::grpc::ServerUnaryReactor* TraceSetting( + ::grpc::CallbackServerContext* context, + const inference::TraceSettingRequest* request, + inference::TraceSettingResponse* response) override + { + auto* reactor = context->DefaultReactor(); - return nullptr; -} + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } -void -CommonHandler::RegisterTrace() -{ - // Define a lambda function 'callback' that takes a TraceSettingRequest, - // a TraceSettingResponse, and a grpc::Status. This function performs - // the same logic as the original OnExecuteTrace function. 
- auto callback = [this]( - inference::TraceSettingRequest& request, - inference::TraceSettingResponse* response, - ::grpc::Status* status) { + // Core business logic - kept same as original #ifdef TRITON_ENABLE_TRACING TRITONSERVER_Error* err = nullptr; TRITONSERVER_InferenceTraceLevel level = TRITONSERVER_TRACE_LEVEL_DISABLED; @@ -1195,29 +816,29 @@ CommonHandler::RegisterTrace() InferenceTraceMode trace_mode; TraceConfigMap config_map; - if (!request.model_name().empty()) { + if (!request->model_name().empty()) { bool ready = false; - GOTO_IF_ERR( - TRITONSERVER_ServerModelIsReady( - tritonserver_.get(), request.model_name().c_str(), - -1 /* model version */, &ready), - earlyexit); + err = TRITONSERVER_ServerModelIsReady( + tritonserver_.get(), request->model_name().c_str(), + -1 /* model version */, &ready); + GOTO_IF_ERR(err, earlyexit); if (!ready) { err = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG, - (std::string("Request for unknown model : ") + request.model_name()) + (std::string("Request for unknown model : ") + + request->model_name()) .c_str()); GOTO_IF_ERR(err, earlyexit); } } // Update trace setting - if (!request.settings().empty()) { + if (!request->settings().empty()) { TraceManager::NewSetting new_setting; { static std::string setting_name = "trace_file"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { err = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNSUPPORTED, "trace file location can not be updated through network " @@ -1227,8 +848,8 @@ CommonHandler::RegisterTrace() } { static std::string setting_name = "trace_level"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { if (it->second.value().size() == 0) { new_setting.clear_level_ = true; } else { @@ -1258,8 +879,8 @@ CommonHandler::RegisterTrace() } { static std::string setting_name = "trace_rate"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { if (it->second.value().size() == 0) { new_setting.clear_rate_ = true; } else if (it->second.value().size() == 1) { @@ -1298,8 +919,8 @@ CommonHandler::RegisterTrace() } { static std::string setting_name = "trace_count"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { if (it->second.value().size() == 0) { new_setting.clear_count_ = true; } else if (it->second.value().size() == 1) { @@ -1347,8 +968,8 @@ CommonHandler::RegisterTrace() } { static std::string setting_name = "log_frequency"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { if (it->second.value().size() == 0) { new_setting.clear_log_frequency_ = true; } else if (it->second.value().size() == 1) { @@ -1386,16 +1007,16 @@ CommonHandler::RegisterTrace() } } - err = - trace_manager_->UpdateTraceSetting(request.model_name(), new_setting); + err = trace_manager_->UpdateTraceSetting( + request->model_name(), new_setting); GOTO_IF_ERR(err, earlyexit); } - // Get current trace setting, this is needed even if the setting - // has been updated above as 
some values may not be provided in the request. + // Get current trace setting trace_manager_->GetTraceSetting( - request.model_name(), &level, &rate, &count, &log_frequency, &filepath, + request->model_name(), &level, &rate, &count, &log_frequency, &filepath, &trace_mode, &config_map); + // level { inference::TraceSettingResponse::SettingValue level_setting; @@ -1411,6 +1032,7 @@ CommonHandler::RegisterTrace() } (*response->mutable_settings())["trace_level"] = level_setting; } + (*response->mutable_settings())["trace_rate"].add_value( std::to_string(rate)); (*response->mutable_settings())["trace_count"].add_value( @@ -1442,52 +1064,53 @@ CommonHandler::RegisterTrace() } } } + earlyexit: - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); #else auto err = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNAVAILABLE, "the server does not support trace"); - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); #endif - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::TRACE); - - // Use non_inference_callback_service_->TraceSetting to register the callback. - // This replaces the use of CommonCallData. - non_inference_callback_service_->TraceSetting( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. - new CommonCallbackData< - inference::TraceSettingRequest, inference::TraceSettingResponse>( - "TraceSetting", non_inference_callback_service_, callback, - restricted_kv)); -} -void -CommonHandler::RegisterLogging() -{ - // Define a lambda function 'callback' that takes a LogSettingsRequest, - // a LogSettingsResponse, and a grpc::Status. This function performs - // the same logic as the original OnExecuteLogging function. - auto callback = [this]( - inference::LogSettingsRequest& request, - inference::LogSettingsResponse* response, - ::grpc::Status* status) { + reactor->Finish(status); + return reactor; + } + + ::grpc::ServerUnaryReactor* LogSettings( + ::grpc::CallbackServerContext* context, + const inference::LogSettingsRequest* request, + inference::LogSettingsResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. 
+ if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + // Core business logic - kept same as original #ifdef TRITON_ENABLE_LOGGING TRITONSERVER_Error* err = nullptr; // Update log settings // Server and Core repos do not have the same Logger object // Each update must be applied to both server and core repo versions - if (!request.settings().empty()) { + if (!request->settings().empty()) { { static std::string setting_name = "log_file"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { err = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNSUPPORTED, "log file location can not be updated through network protocol"); @@ -1496,8 +1119,8 @@ CommonHandler::RegisterLogging() } { static std::string setting_name = "log_info"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { const auto& log_param = it->second; if (log_param.parameter_choice_case() != inference::LogSettingsRequest_SettingValue::ParameterChoiceCase:: @@ -1516,8 +1139,8 @@ CommonHandler::RegisterLogging() } { static std::string setting_name = "log_warning"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { const auto& log_param = it->second; if (log_param.parameter_choice_case() != inference::LogSettingsRequest_SettingValue::ParameterChoiceCase:: @@ -1536,8 +1159,8 @@ CommonHandler::RegisterLogging() } { static std::string setting_name = "log_error"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { const auto& log_param = it->second; if (log_param.parameter_choice_case() != inference::LogSettingsRequest_SettingValue::ParameterChoiceCase:: @@ -1556,8 +1179,8 @@ CommonHandler::RegisterLogging() } { static std::string setting_name = "log_verbose_level"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { const auto& log_param = it->second; if (log_param.parameter_choice_case() != inference::LogSettingsRequest_SettingValue::ParameterChoiceCase:: @@ -1576,8 +1199,8 @@ CommonHandler::RegisterLogging() } { static std::string setting_name = "log_format"; - auto it = request.settings().find(setting_name); - if (it != request.settings().end()) { + auto it = request->settings().find(setting_name); + if (it != request->settings().end()) { const auto& log_param = it->second; if (log_param.parameter_choice_case() != inference::LogSettingsRequest_SettingValue::ParameterChoiceCase:: @@ -1616,6 +1239,7 @@ CommonHandler::RegisterLogging() } GOTO_IF_ERR(err, earlyexit); } + (*response->mutable_settings())["log_file"].set_string_param(LOG_FILE); (*response->mutable_settings())["log_info"].set_bool_param(LOG_INFO_IS_ON); (*response->mutable_settings())["log_warning"].set_bool_param( @@ -1626,47 +1250,79 @@ 
CommonHandler::RegisterLogging() LOG_VERBOSE_LEVEL); (*response->mutable_settings())["log_format"].set_string_param( LOG_FORMAT_STRING); + earlyexit: - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); #else auto err = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNAVAILABLE, "the server does not support dynamic logging"); - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); #endif - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::LOGGING); - - // Use non_inference_callback_service_->LogSettings to register the callback. - // This replaces the use of CommonCallData. - non_inference_callback_service_->LogSettings( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. - new CommonCallbackData< - inference::LogSettingsRequest, inference::LogSettingsResponse>( - "LogSettings", non_inference_callback_service_, callback, - restricted_kv)); -} -void -CommonHandler::RegisterSystemSharedMemoryStatus() -{ - // Define a lambda function 'callback' that takes a - // SystemSharedMemoryStatusRequest, a SystemSharedMemoryStatusResponse, and a - // grpc::Status. This function performs the same logic as the original - // OnExecuteSystemSharedMemoryStatus function. - auto callback = [this]( - inference::SystemSharedMemoryStatusRequest& request, - inference::SystemSharedMemoryStatusResponse* response, - ::grpc::Status* status) { + reactor->Finish(status); + return reactor; + } + + ::grpc::ServerUnaryReactor* SystemSharedMemoryRegister( + ::grpc::CallbackServerContext* context, + const inference::SystemSharedMemoryRegisterRequest* request, + inference::SystemSharedMemoryRegisterResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Core business logic - kept same as original + TRITONSERVER_Error* err = shm_manager_->RegisterSystemSharedMemory( + request->name(), request->key(), request->offset(), + request->byte_size()); + + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } + + ::grpc::ServerUnaryReactor* SystemSharedMemoryStatus( + ::grpc::CallbackServerContext* context, + const inference::SystemSharedMemoryStatusRequest* request, + inference::SystemSharedMemoryStatusResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. 
+ if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Core business logic - kept same as original triton::common::TritonJson::Value shm_status_json( triton::common::TritonJson::ValueType::ARRAY); TRITONSERVER_Error* err = shm_manager_->GetStatus( - request.name(), TRITONSERVER_MEMORY_CPU, &shm_status_json); + request->name(), TRITONSERVER_MEMORY_CPU, &shm_status_json); GOTO_IF_ERR(err, earlyexit); for (size_t idx = 0; idx < shm_status_json.ArraySize(); ++idx) { @@ -1702,116 +1358,80 @@ CommonHandler::RegisterSystemSharedMemoryStatus() } earlyexit: - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - - // Use non_inference_callback_service_->SystemSharedMemoryStatus to register - // the callback. This replaces the use of CommonCallData. - non_inference_callback_service_->SystemSharedMemoryStatus( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. - new CommonCallbackData< - inference::SystemSharedMemoryStatusRequest, - inference::SystemSharedMemoryStatusResponse>( - "SystemSharedMemoryStatus", non_inference_callback_service_, callback, - restricted_kv)); -} + reactor->Finish(status); + return reactor; + } -void -CommonHandler::RegisterSystemSharedMemoryRegister() -{ - // Define a lambda function 'callback' that takes a - // SystemSharedMemoryRegisterRequest, a SystemSharedMemoryRegisterResponse, - // and a grpc::Status. This function performs the same logic as the original - // OnExecuteSystemSharedMemoryRegister function. - auto callback = [this]( - inference::SystemSharedMemoryRegisterRequest& request, - inference::SystemSharedMemoryRegisterResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = shm_manager_->RegisterSystemSharedMemory( - request.name(), request.key(), request.offset(), request.byte_size()); + ::grpc::ServerUnaryReactor* CudaSharedMemoryRegister( + ::grpc::CallbackServerContext* context, + const inference::CudaSharedMemoryRegisterRequest* request, + inference::CudaSharedMemoryRegisterResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. 
+ if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Core business logic + TRITONSERVER_Error* err = nullptr; +#ifdef TRITON_ENABLE_GPU + err = shm_manager_->RegisterCUDASharedMemory( + request->name(), + reinterpret_cast( + request->raw_handle().c_str()), + request->byte_size(), request->device_id()); +#else + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string( + "failed to register CUDA shared memory region: '" + + request->name() + "', GPUs not supported") + .c_str()); +#endif // TRITON_ENABLE_GPU - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - - // Use non_inference_callback_service_->SystemSharedMemoryRegister to register - // the callback. This replaces the use of CommonCallData. - non_inference_callback_service_->SystemSharedMemoryRegister( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. - new CommonCallbackData< - inference::SystemSharedMemoryRegisterRequest, - inference::SystemSharedMemoryRegisterResponse>( - "SystemSharedMemoryRegister", non_inference_callback_service_, - callback, restricted_kv)); -} + reactor->Finish(status); + return reactor; + } -void -CommonHandler::RegisterSystemSharedMemoryUnregister() -{ - auto OnRegisterSystemSharedMemoryUnregister = - [this]( - ::grpc::ServerContext* ctx, - inference::SystemSharedMemoryUnregisterRequest* request, - ::grpc::ServerAsyncResponseWriter< - inference::SystemSharedMemoryUnregisterResponse>* responder, - void* tag) { - this->service_->RequestSystemSharedMemoryUnregister( - ctx, request, responder, this->cq_, this->cq_, tag); - }; - - auto OnExecuteSystemSharedMemoryUnregister = - [this]( - inference::SystemSharedMemoryUnregisterRequest& request, - inference::SystemSharedMemoryUnregisterResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; - if (request.name().empty()) { - err = shm_manager_->UnregisterAll(TRITONSERVER_MEMORY_CPU); - } else { - err = - shm_manager_->Unregister(request.name(), TRITONSERVER_MEMORY_CPU); - } + ::grpc::ServerUnaryReactor* CudaSharedMemoryStatus( + ::grpc::CallbackServerContext* context, + const inference::CudaSharedMemoryStatusRequest* request, + inference::CudaSharedMemoryStatusResponse* response) override + { + auto* reactor = context->DefaultReactor(); - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - new CommonCallData< - ::grpc::ServerAsyncResponseWriter< - inference::SystemSharedMemoryUnregisterResponse>, - inference::SystemSharedMemoryUnregisterRequest, - inference::SystemSharedMemoryUnregisterResponse>( - "SystemSharedMemoryUnregister", 0, OnRegisterSystemSharedMemoryUnregister, - OnExecuteSystemSharedMemoryUnregister, false /* async */, cq_, - restricted_kv, response_delay_); -} + // (Optionally) Check client metadata for restricted access. 
+ if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } -void -CommonHandler::RegisterCudaSharedMemoryStatus() -{ - // Define a lambda function 'callback' that takes a - // CudaSharedMemoryStatusRequest, a CudaSharedMemoryStatusResponse, and a - // grpc::Status. This function performs the same logic as the original - // OnExecuteCudaSharedMemoryStatus function. - auto callback = [this]( - inference::CudaSharedMemoryStatusRequest& request, - inference::CudaSharedMemoryStatusResponse* response, - ::grpc::Status* status) { + // Core business logic - kept same as original triton::common::TritonJson::Value shm_status_json( triton::common::TritonJson::ValueType::ARRAY); TRITONSERVER_Error* err = shm_manager_->GetStatus( - request.name(), TRITONSERVER_MEMORY_GPU, &shm_status_json); + request->name(), TRITONSERVER_MEMORY_GPU, &shm_status_json); GOTO_IF_ERR(err, earlyexit); for (size_t idx = 0; idx < shm_status_json.ArraySize(); ++idx) { @@ -1839,380 +1459,163 @@ CommonHandler::RegisterCudaSharedMemoryStatus() (*response->mutable_regions())[name] = region_status; } + earlyexit: - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - - // Use non_inference_callback_service_->CudaSharedMemoryStatus to register the - // callback. This replaces the use of CommonCallData. - non_inference_callback_service_->CudaSharedMemoryStatus( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. - new CommonCallbackData< - inference::CudaSharedMemoryStatusRequest, - inference::CudaSharedMemoryStatusResponse>( - "CudaSharedMemoryStatus", non_inference_callback_service_, callback, - restricted_kv)); -} + reactor->Finish(status); + return reactor; + } -void -CommonHandler::RegisterCudaSharedMemoryRegister() -{ - // Define a lambda function 'callback' that takes a - // CudaSharedMemoryRegisterRequest, a CudaSharedMemoryRegisterResponse, and a - // grpc::Status. This function performs the same logic as the original - // OnExecuteCudaSharedMemoryRegister function. 
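//
// The restricted-access check repeated in each reactor above expects the
// client to attach a matching metadata pair. A hedged client-side sketch,
// assuming a synchronous stub 'stub' and an illustrative restricted
// key/value of "my-restricted-key"/"my-secret" (neither name comes from this
// patch; gRPC requires custom metadata keys to be lowercase):
//
//   ::grpc::ClientContext ctx;
//   ctx.AddMetadata("my-restricted-key", "my-secret");
//   inference::CudaSharedMemoryStatusRequest req;
//   inference::CudaSharedMemoryStatusResponse resp;
//   ::grpc::Status s = stub->CudaSharedMemoryStatus(&ctx, req, &resp);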
- auto callback = [this]( - inference::CudaSharedMemoryRegisterRequest& request, - inference::CudaSharedMemoryRegisterResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; -#ifdef TRITON_ENABLE_GPU - err = shm_manager_->RegisterCUDASharedMemory( - request.name(), - reinterpret_cast( - request.raw_handle().c_str()), - request.byte_size(), request.device_id()); -#else - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - std::string( - "failed to register CUDA shared memory region: '" + request.name() + - "', GPUs not supported") - .c_str()); -#endif // TRITON_ENABLE_GPU + ::grpc::ServerUnaryReactor* SystemSharedMemoryUnregister( + ::grpc::CallbackServerContext* context, + const inference::SystemSharedMemoryUnregisterRequest* request, + inference::SystemSharedMemoryUnregisterResponse* response) override + { + auto* reactor = context->DefaultReactor(); - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - - // Use non_inference_callback_service_->CudaSharedMemoryRegister to register - // the callback. This replaces the use of CommonCallData. - non_inference_callback_service_->CudaSharedMemoryRegister( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. - new CommonCallbackData< - inference::CudaSharedMemoryRegisterRequest, - inference::CudaSharedMemoryRegisterResponse>( - "CudaSharedMemoryRegister", non_inference_callback_service_, callback, - restricted_kv)); -} + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } -void -CommonHandler::RegisterCudaSharedMemoryUnregister() -{ - // Define a lambda function 'callback' that takes a - // CudaSharedMemoryUnregisterRequest, a CudaSharedMemoryUnregisterResponse, - // and a grpc::Status. This function performs the same logic as the original - // OnExecuteCudaSharedMemoryUnregister function. - auto callback = [this]( - inference::CudaSharedMemoryUnregisterRequest& request, - inference::CudaSharedMemoryUnregisterResponse* response, - ::grpc::Status* status) { + // Core business logic - kept same as original TRITONSERVER_Error* err = nullptr; - if (request.name().empty()) { - err = shm_manager_->UnregisterAll(TRITONSERVER_MEMORY_GPU); + if (request->name().empty()) { + err = shm_manager_->UnregisterAll(TRITONSERVER_MEMORY_CPU); } else { - err = shm_manager_->Unregister(request.name(), TRITONSERVER_MEMORY_GPU); + err = shm_manager_->Unregister(request->name(), TRITONSERVER_MEMORY_CPU); } - GrpcStatusUtil::Create(status, err); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - - // Use non_inference_callback_service_->CudaSharedMemoryUnregister to register - // the callback. This replaces the use of CommonCallData. - non_inference_callback_service_->CudaSharedMemoryUnregister( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. 
- new CommonCallbackData< - inference::CudaSharedMemoryUnregisterRequest, - inference::CudaSharedMemoryUnregisterResponse>( - "CudaSharedMemoryUnregister", non_inference_callback_service_, - callback, restricted_kv)); -} + reactor->Finish(status); + return reactor; + } -void -CommonHandler::RegisterRepositoryIndex() -{ - // Define a lambda function 'callback' that takes a RepositoryIndexRequest, - // a RepositoryIndexResponse, and a grpc::Status. This function performs - // the same logic as the original OnExecuteRepositoryIndex function. - auto callback = [this]( - inference::RepositoryIndexRequest& request, - inference::RepositoryIndexResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; - if (request.repository_name().empty()) { - uint32_t flags = 0; - if (request.ready()) { - flags |= TRITONSERVER_INDEX_FLAG_READY; - } + // Add here + ::grpc::ServerUnaryReactor* CudaSharedMemoryUnregister( + ::grpc::CallbackServerContext* context, + const inference::CudaSharedMemoryUnregisterRequest* request, + inference::CudaSharedMemoryUnregisterResponse* response) override + { + auto* reactor = context->DefaultReactor(); - TRITONSERVER_Message* model_index_message = nullptr; - err = TRITONSERVER_ServerModelIndex( - tritonserver_.get(), flags, &model_index_message); - GOTO_IF_ERR(err, earlyexit); + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } - const char* buffer; - size_t byte_size; - err = TRITONSERVER_MessageSerializeToJson( - model_index_message, &buffer, &byte_size); - GOTO_IF_ERR(err, earlyexit); + // Core business logic - kept same as original + TRITONSERVER_Error* err = nullptr; + if (request->name().empty()) { + err = shm_manager_->UnregisterAll(TRITONSERVER_MEMORY_GPU); + } else { + err = shm_manager_->Unregister(request->name(), TRITONSERVER_MEMORY_GPU); + } - triton::common::TritonJson::Value model_index_json; - err = model_index_json.Parse(buffer, byte_size); - GOTO_IF_ERR(err, earlyexit); + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } - err = model_index_json.AssertType( - triton::common::TritonJson::ValueType::ARRAY); - GOTO_IF_ERR(err, earlyexit); + private: + std::shared_ptr tritonserver_; + std::shared_ptr shm_manager_; + std::pair restricted_kv_; +}; - for (size_t idx = 0; idx < model_index_json.ArraySize(); ++idx) { - triton::common::TritonJson::Value index_json; - err = model_index_json.IndexAsObject(idx, &index_json); - GOTO_IF_ERR(err, earlyexit); +// +// CommonHandler +// +// A common handler for all non-inference requests. 
+//
+class CommonHandler : public HandlerBase {
+ public:
+  CommonHandler(
+      const std::string& name,
+      const std::shared_ptr<TRITONSERVER_Server>& tritonserver,
+      const std::shared_ptr<SharedMemoryManager>& shm_manager,
+      TraceManager* trace_manager,
+      inference::GRPCInferenceService::AsyncService* service,
+      ::grpc::health::v1::Health::AsyncService* health_service,
+      inference::GRPCInferenceService::CallbackService*
+          non_inference_callback_service,
+      const RestrictedFeatures& restricted_keys, const uint64_t response_delay);

-      auto model_index = response->add_models();
+  // Implement pure virtual functions
+  void Start() override {}  // No-op for callback implementation
+  void Stop() override {}   // No-op for callback implementation

-      const char* name;
-      size_t namelen;
-      err = index_json.MemberAsString("name", &name, &namelen);
-      GOTO_IF_ERR(err, earlyexit);
-      model_index->set_name(std::string(name, namelen));
+  // Descriptive name of the handler.
+  const std::string& Name() const { return name_; }

-      if (index_json.Find("version")) {
-        const char* version;
-        size_t versionlen;
-        err = index_json.MemberAsString("version", &version, &versionlen);
-        GOTO_IF_ERR(err, earlyexit);
-        model_index->set_version(std::string(version, versionlen));
-      }
-      if (index_json.Find("state")) {
-        const char* state;
-        size_t statelen;
-        err = index_json.MemberAsString("state", &state, &statelen);
-        GOTO_IF_ERR(err, earlyexit);
-        model_index->set_state(std::string(state, statelen));
-      }
-      if (index_json.Find("reason")) {
-        const char* reason;
-        size_t reasonlen;
-        err = index_json.MemberAsString("reason", &reason, &reasonlen);
-        GOTO_IF_ERR(err, earlyexit);
-        model_index->set_reason(std::string(reason, reasonlen));
-      }
-    }
+  void CreateUnifiedCallbackService();

-    TRITONSERVER_MessageDelete(model_index_message);
-  } else {
-    err = TRITONSERVER_ErrorNew(
-        TRITONSERVER_ERROR_UNSUPPORTED,
-        "'repository_name' specification is not supported");
-  }
+  // Add a new public method to return the non_inference_callback_service_
+  inference::GRPCInferenceService::CallbackService* GetUnifiedCallbackService()
+  {
+    return non_inference_callback_service_;
+  }

-  earlyexit:
-    GrpcStatusUtil::Create(status, err);
-    TRITONSERVER_ErrorDelete(err);
-  };
-
-  const std::pair<std::string, std::string>& restricted_kv =
-      restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY);
-
-  // Use non_inference_callback_service_->RepositoryIndex to register the
-  // callback. This replaces the use of CommonCallData.
-  non_inference_callback_service_->RepositoryIndex(
-      // Create a new CommonCallbackData object with the callback function
-      // and register it with the non_inference_callback_service_.
- new CommonCallbackData< - inference::RepositoryIndexRequest, - inference::RepositoryIndexResponse>( - "RepositoryIndex", non_inference_callback_service_, callback, - restricted_kv)); -} + private: + const std::string name_; + std::shared_ptr tritonserver_; + std::shared_ptr shm_manager_; + TraceManager* trace_manager_; + inference::GRPCInferenceService::AsyncService* service_; + ::grpc::health::v1::Health::AsyncService* health_service_; + inference::GRPCInferenceService::CallbackService* + non_inference_callback_service_; + std::unique_ptr thread_; + RestrictedFeatures restricted_keys_; + const uint64_t response_delay_; +}; -void -CommonHandler::RegisterRepositoryModelLoad() +CommonHandler::CommonHandler( + const std::string& name, + const std::shared_ptr& tritonserver, + const std::shared_ptr& shm_manager, + TraceManager* trace_manager, + inference::GRPCInferenceService::AsyncService* service, + ::grpc::health::v1::Health::AsyncService* health_service, + inference::GRPCInferenceService::CallbackService* + non_inference_callback_service, + const RestrictedFeatures& restricted_keys, const uint64_t response_delay) + : name_(name), tritonserver_(tritonserver), shm_manager_(shm_manager), + trace_manager_(trace_manager), service_(service), + health_service_(health_service), + non_inference_callback_service_(non_inference_callback_service), + restricted_keys_(restricted_keys), response_delay_(response_delay) { - // Define a lambda function 'callback' that takes a - // RepositoryModelLoadRequest, a RepositoryModelLoadResponse, and a - // grpc::Status. This function performs the same logic as the original - // OnExecuteRepositoryModelLoad function. - auto callback = [this]( - inference::RepositoryModelLoadRequest& request, - inference::RepositoryModelLoadResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; - if (request.repository_name().empty()) { - std::vector params; - // WAR for the const-ness check - std::vector const_params; - for (const auto& param_proto : request.parameters()) { - if (param_proto.first == "config") { - if (param_proto.second.parameter_choice_case() != - inference::ModelRepositoryParameter::ParameterChoiceCase:: - kStringParam) { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - (std::string("invalid value type for load parameter '") + - param_proto.first + "', expected string_param.") - .c_str()); - break; - } else { - auto param = TRITONSERVER_ParameterNew( - param_proto.first.c_str(), TRITONSERVER_PARAMETER_STRING, - param_proto.second.string_param().c_str()); - if (param != nullptr) { - params.emplace_back(param); - const_params.emplace_back(param); - } else { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - "unexpected error on creating Triton parameter"); - break; - } - } - } else if (param_proto.first.rfind("file:", 0) == 0) { - if (param_proto.second.parameter_choice_case() != - inference::ModelRepositoryParameter::ParameterChoiceCase:: - kBytesParam) { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - (std::string("invalid value type for load parameter '") + - param_proto.first + "', expected bytes_param.") - .c_str()); - break; - } else { - auto param = TRITONSERVER_ParameterBytesNew( - param_proto.first.c_str(), - param_proto.second.bytes_param().data(), - param_proto.second.bytes_param().length()); - if (param != nullptr) { - params.emplace_back(param); - const_params.emplace_back(param); - } else { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - "unexpected error 
on creating Triton parameter"); - break; - } - } - } else { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - (std::string("unrecognized load parameter '") + - param_proto.first + "'.") - .c_str()); - break; - } - } - if (err == nullptr) { - err = TRITONSERVER_ServerLoadModelWithParameters( - tritonserver_.get(), request.model_name().c_str(), - const_params.data(), const_params.size()); - } - // Assumes no further 'params' access after load API returns - for (auto& param : params) { - TRITONSERVER_ParameterDelete(param); - } - } else { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_UNSUPPORTED, - "'repository_name' specification is not supported"); - } - - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); - - // Use non_inference_callback_service_->RepositoryModelLoad to register the - // callback. This replaces the use of CommonCallData. - non_inference_callback_service_->RepositoryModelLoad( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. - new CommonCallbackData< - inference::RepositoryModelLoadRequest, - inference::RepositoryModelLoadResponse>( - "RepositoryModelLoad", non_inference_callback_service_, callback, - restricted_kv)); + CreateUnifiedCallbackService(); } void -CommonHandler::RegisterRepositoryModelUnload() +CommonHandler::CreateUnifiedCallbackService() { - // Define a lambda function 'callback' that takes a - // RepositoryModelUnloadRequest, a RepositoryModelUnloadResponse, and a - // grpc::Status. This function performs the same logic as the original - // OnExecuteRepositoryModelUnload function. - auto callback = [this]( - inference::RepositoryModelUnloadRequest& request, - inference::RepositoryModelUnloadResponse* response, - ::grpc::Status* status) { - TRITONSERVER_Error* err = nullptr; - if (request.repository_name().empty()) { - // Check if the dependent models should be removed - bool unload_dependents = false; - for (auto param : request.parameters()) { - if (param.first.compare("unload_dependents") == 0) { - const auto& unload_param = param.second; - if (unload_param.parameter_choice_case() != - inference::ModelRepositoryParameter::ParameterChoiceCase:: - kBoolParam) { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - "invalid value type for 'unload_dependents' parameter, " - "expected " - "bool_param."); - } - unload_dependents = unload_param.bool_param(); - break; - } - } - if (err == nullptr) { - if (unload_dependents) { - err = TRITONSERVER_ServerUnloadModelAndDependents( - tritonserver_.get(), request.model_name().c_str()); - } else { - err = TRITONSERVER_ServerUnloadModel( - tritonserver_.get(), request.model_name().c_str()); - } - } - } else { - err = TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_UNSUPPORTED, - "'repository_name' specification is not supported"); - } - - GrpcStatusUtil::Create(status, err); - TRITONSERVER_ErrorDelete(err); - }; - - const std::pair& restricted_kv = - restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); - - // Use non_inference_callback_service_->RepositoryModelUnload to register the - // callback. This replaces the use of CommonCallData. - non_inference_callback_service_->RepositoryModelUnload( - // Create a new CommonCallbackData object with the callback function - // and register it with the non_inference_callback_service_. 
- new CommonCallbackData< - inference::RepositoryModelUnloadRequest, - inference::RepositoryModelUnloadResponse>( - "RepositoryModelUnload", non_inference_callback_service_, callback, - restricted_kv)); + const auto& restrictedKV = restricted_keys_.Get(RestrictedCategory::HEALTH); + // Create a single unified callback service instance. + non_inference_callback_service_ = + new UnifiedCallbackService(tritonserver_, shm_manager_, restrictedKV); } } // namespace @@ -2254,9 +1657,8 @@ Server::Server( builder_.AddListeningPort(server_addr_, credentials, &bound_port_); builder_.SetMaxMessageSize(MAX_GRPC_MESSAGE_SIZE); - builder_.RegisterService(&service_); - builder_.RegisterService(&health_service_); - builder_.RegisterService(&non_inference_callback_service_); + // builder_.RegisterService(&service_); + // builder_.RegisterService(&health_service_); builder_.AddChannelArgument( GRPC_ARG_ALLOW_REUSEPORT, options.socket_.reuse_port_); @@ -2342,7 +1744,6 @@ Server::Server( LOG_TABLE_VERBOSE(1, table_printer); } - common_cq_ = builder_.AddCompletionQueue(); model_infer_cq_ = builder_.AddCompletionQueue(); model_stream_infer_cq_ = builder_.AddCompletionQueue(); @@ -2355,8 +1756,15 @@ Server::Server( // A common Handler for other non-inference requests common_handler_.reset(new CommonHandler( "CommonHandler", tritonserver_, shm_manager_, trace_manager_, &service_, - &health_service_, &non_inference_callback_service_, common_cq_.get(), + &health_service_, &non_inference_callback_service_, options.restricted_protocols_, response_delay)); + // Use common_handler_ and register + // builder_.RegisterService(non_inference_callback_service_); here Cast to + // CommonHandler to access the method + auto* handler = dynamic_cast(common_handler_.get()); + if (handler != nullptr) { + builder_.RegisterService(handler->GetUnifiedCallbackService()); + } // [FIXME] "register" logic is different for infer // Handler for model inference requests. @@ -2519,7 +1927,6 @@ Server::Start() } // Remove this - common_handler_->Start(); for (auto& model_infer_handler : model_infer_handlers_) { model_infer_handler->Start(); } @@ -2543,13 +1950,13 @@ Server::Stop() // Always shutdown the completion queue after the server. server_->Shutdown(); - common_cq_->Shutdown(); + // common_cq_->Shutdown(); model_infer_cq_->Shutdown(); model_stream_infer_cq_->Shutdown(); // Must stop all handlers explicitly to wait for all the handler // threads to join since they are referencing completion queue, etc. 
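//
// A note on the Server::Server() hunks above: unlike AsyncService, a
// callback service needs no completion queue; the service object is simply
// registered before BuildAndStart(). A minimal standalone sketch, with
// constructor arguments assumed rather than taken from this patch:
//
//   ::grpc::ServerBuilder builder;
//   builder.AddListeningPort(
//       "0.0.0.0:8001", ::grpc::InsecureServerCredentials());
//   UnifiedCallbackService service(tritonserver, shm_manager, restricted_kv);
//   builder.RegisterService(&service);  // no AddCompletionQueue() needed
//   std::unique_ptr<::grpc::Server> server = builder.BuildAndStart();
//   server->Wait();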
-  common_handler_->Stop();
+  // common_handler_->Stop();
   for (auto& model_infer_handler : model_infer_handlers_) {
     model_infer_handler->Stop();
   }
diff --git a/src/grpc/grpc_server.h b/src/grpc/grpc_server.h
index 2a7a5ff0ba..89203d5d0b 100644
--- a/src/grpc/grpc_server.h
+++ b/src/grpc/grpc_server.h
@@ -148,7 +148,7 @@ class Server {
   std::unique_ptr<::grpc::ServerCompletionQueue> model_infer_cq_;
   std::unique_ptr<::grpc::ServerCompletionQueue> model_stream_infer_cq_;

-  // std::unique_ptr<HandlerBase> common_handler_;
+  std::unique_ptr<HandlerBase> common_handler_;
   std::vector<std::unique_ptr<HandlerBase>> model_infer_handlers_;
   std::vector<std::unique_ptr<HandlerBase>> model_stream_infer_handlers_;

From c3c7c90f6c1ab710bdea78f6ccda2dbc8e120db0 Mon Sep 17 00:00:00 2001
From: Indrajit Bhosale
Date: Wed, 26 Feb 2025 10:04:49 -0800
Subject: [PATCH 04/12] Non Inference Migrated

---
 src/grpc/CMakeLists.txt |  3 ++-
 src/grpc/grpc_handler.h |  3 ++-
 src/grpc/grpc_server.cc | 13 +++++++------
 src/grpc/grpc_server.h  |  3 ++-
 4 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/src/grpc/CMakeLists.txt b/src/grpc/CMakeLists.txt
index 0cd027a30a..1b0544c37c 100644
--- a/src/grpc/CMakeLists.txt
+++ b/src/grpc/CMakeLists.txt
@@ -1,4 +1,4 @@
-# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions
@@ -67,6 +67,7 @@ target_link_libraries(
     triton-common-json      # from repo-common
     grpc-health-library     # from repo-common
     grpc-service-library    # from repo-common
+    grpccallback-service-library  # from repo-common
     triton-core-serverapi   # from repo-core
     triton-core-serverstub  # from repo-core
     gRPC::grpc++
diff --git a/src/grpc/grpc_handler.h b/src/grpc/grpc_handler.h
index ce9a30667e..ad5f551b70 100644
--- a/src/grpc/grpc_handler.h
+++ b/src/grpc/grpc_handler.h
@@ -30,6 +30,7 @@
 #include

 #include "grpc_service.grpc.pb.h"
+#include "grpccallback_service.grpc.pb.h"

 namespace triton { namespace server { namespace grpc {

 class HandlerBase {
@@ -37,7 +38,7 @@ class HandlerBase {
   virtual ~HandlerBase() = default;
   virtual void Start() = 0;
   virtual void Stop() = 0;
-  virtual inference::GRPCInferenceService::CallbackService*
+  virtual inference::GRPCInferenceServiceCallback::CallbackService*
   GetUnifiedCallbackService()
   {
     return nullptr;
diff --git a/src/grpc/grpc_server.cc b/src/grpc/grpc_server.cc
index ab644e3a07..1922801ac6 100644
--- a/src/grpc/grpc_server.cc
+++ b/src/grpc/grpc_server.cc
@@ -82,7 +82,7 @@ namespace {

 // Define your unified callback
class UnifiedCallbackService - : public inference::GRPCInferenceService::CallbackService { + : public inference::GRPCInferenceServiceCallback::CallbackService { public: UnifiedCallbackService( const std::shared_ptr& server, @@ -1557,7 +1557,7 @@ class CommonHandler : public HandlerBase { TraceManager* trace_manager, inference::GRPCInferenceService::AsyncService* service, ::grpc::health::v1::Health::AsyncService* health_service, - inference::GRPCInferenceService::CallbackService* + inference::GRPCInferenceServiceCallback::CallbackService* non_inference_callback_service, const RestrictedFeatures& restricted_keys, const uint64_t response_delay); @@ -1571,7 +1571,8 @@ class CommonHandler : public HandlerBase { void CreateUnifiedCallbackService(); // Add a new public method to return the non_inference_callback_service_ - inference::GRPCInferenceService::CallbackService* GetUnifiedCallbackService() + inference::GRPCInferenceServiceCallback::CallbackService* + GetUnifiedCallbackService() { return non_inference_callback_service_; } @@ -1583,7 +1584,7 @@ class CommonHandler : public HandlerBase { TraceManager* trace_manager_; inference::GRPCInferenceService::AsyncService* service_; ::grpc::health::v1::Health::AsyncService* health_service_; - inference::GRPCInferenceService::CallbackService* + inference::GRPCInferenceServiceCallback::CallbackService* non_inference_callback_service_; std::unique_ptr thread_; RestrictedFeatures restricted_keys_; @@ -1597,7 +1598,7 @@ CommonHandler::CommonHandler( TraceManager* trace_manager, inference::GRPCInferenceService::AsyncService* service, ::grpc::health::v1::Health::AsyncService* health_service, - inference::GRPCInferenceService::CallbackService* + inference::GRPCInferenceServiceCallback::CallbackService* non_inference_callback_service, const RestrictedFeatures& restricted_keys, const uint64_t response_delay) : name_(name), tritonserver_(tritonserver), shm_manager_(shm_manager), @@ -1657,7 +1658,7 @@ Server::Server( builder_.AddListeningPort(server_addr_, credentials, &bound_port_); builder_.SetMaxMessageSize(MAX_GRPC_MESSAGE_SIZE); - // builder_.RegisterService(&service_); + builder_.RegisterService(&service_); // builder_.RegisterService(&health_service_); builder_.AddChannelArgument( GRPC_ARG_ALLOW_REUSEPORT, options.socket_.reuse_port_); diff --git a/src/grpc/grpc_server.h b/src/grpc/grpc_server.h index 89203d5d0b..eceb7f9b85 100644 --- a/src/grpc/grpc_server.h +++ b/src/grpc/grpc_server.h @@ -36,6 +36,7 @@ #include "grpc_handler.h" #include "grpc_service.grpc.pb.h" #include "grpc_utils.h" +#include "grpccallback_service.grpc.pb.h" #include "health.grpc.pb.h" #include "infer_handler.h" #include "stream_infer_handler.h" @@ -139,7 +140,7 @@ class Server { inference::GRPCInferenceService::AsyncService service_; ::grpc::health::v1::Health::AsyncService health_service_; - inference::GRPCInferenceService::CallbackService + inference::GRPCInferenceServiceCallback::CallbackService non_inference_callback_service_; std::unique_ptr<::grpc::Server> server_; From ad4a2a281ac07e1c31039025ed3af3f6013f802a Mon Sep 17 00:00:00 2001 From: Indrajit Bhosale Date: Mon, 10 Mar 2025 19:15:47 -0700 Subject: [PATCH 05/12] Add missing RPCs --- src/grpc/grpc_server.cc | 325 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 321 insertions(+), 4 deletions(-) diff --git a/src/grpc/grpc_server.cc b/src/grpc/grpc_server.cc index 1922801ac6..88039b86db 100644 --- a/src/grpc/grpc_server.cc +++ b/src/grpc/grpc_server.cc @@ -82,7 +82,8 @@ namespace { // Define your unified callback 
service that implements several non-inference // RPCs. class UnifiedCallbackService - : public inference::GRPCInferenceServiceCallback::CallbackService { + : public inference::GRPCInferenceServiceCallback::CallbackService, + public ::grpc::health::v1::Health::CallbackService { public: UnifiedCallbackService( const std::shared_ptr& server, @@ -1537,6 +1538,323 @@ class UnifiedCallbackService return reactor; } + ::grpc::ServerUnaryReactor* RepositoryIndex( + ::grpc::CallbackServerContext* context, + const inference::RepositoryIndexRequest* request, + inference::RepositoryIndexResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Core business logic + TRITONSERVER_Error* err = nullptr; + if (request->repository_name().empty()) { + uint32_t flags = 0; + if (request->ready()) { + flags |= TRITONSERVER_INDEX_FLAG_READY; + } + + TRITONSERVER_Message* model_index_message = nullptr; + err = TRITONSERVER_ServerModelIndex( + tritonserver_.get(), flags, &model_index_message); + if (err == nullptr) { + const char* buffer; + size_t byte_size; + err = TRITONSERVER_MessageSerializeToJson( + model_index_message, &buffer, &byte_size); + if (err == nullptr) { + triton::common::TritonJson::Value model_index_json; + err = model_index_json.Parse(buffer, byte_size); + if (err == nullptr) { + err = model_index_json.AssertType( + triton::common::TritonJson::ValueType::ARRAY); + if (err == nullptr) { + for (size_t idx = 0; idx < model_index_json.ArraySize(); ++idx) { + triton::common::TritonJson::Value index_json; + err = model_index_json.IndexAsObject(idx, &index_json); + if (err != nullptr) { + break; + } + + auto model_index = response->add_models(); + + const char* name; + size_t namelen; + err = index_json.MemberAsString("name", &name, &namelen); + if (err != nullptr) { + break; + } + model_index->set_name(std::string(name, namelen)); + + if (index_json.Find("version")) { + const char* version; + size_t versionlen; + err = index_json.MemberAsString( + "version", &version, &versionlen); + if (err != nullptr) { + break; + } + model_index->set_version(std::string(version, versionlen)); + } + if (index_json.Find("state")) { + const char* state; + size_t statelen; + err = index_json.MemberAsString("state", &state, &statelen); + if (err != nullptr) { + break; + } + model_index->set_state(std::string(state, statelen)); + } + if (index_json.Find("reason")) { + const char* reason; + size_t reasonlen; + err = + index_json.MemberAsString("reason", &reason, &reasonlen); + if (err != nullptr) { + break; + } + model_index->set_reason(std::string(reason, reasonlen)); + } + } + } + } + } + TRITONSERVER_MessageDelete(model_index_message); + } + } else { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "'repository_name' specification is not supported"); + } + + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } + + ::grpc::ServerUnaryReactor* RepositoryModelLoad( + ::grpc::CallbackServerContext* context, + const inference::RepositoryModelLoadRequest* request, + 
inference::RepositoryModelLoadResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Core business logic + TRITONSERVER_Error* err = nullptr; + if (request->repository_name().empty()) { + std::vector params; + // WAR for the const-ness check + std::vector const_params; + + for (const auto& param_proto : request->parameters()) { + if (param_proto.first == "config") { + if (param_proto.second.parameter_choice_case() != + inference::ModelRepositoryParameter::ParameterChoiceCase:: + kStringParam) { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("invalid value type for load parameter '") + + param_proto.first + "', expected string_param.") + .c_str()); + break; + } else { + auto param = TRITONSERVER_ParameterNew( + param_proto.first.c_str(), TRITONSERVER_PARAMETER_STRING, + param_proto.second.string_param().c_str()); + if (param != nullptr) { + params.emplace_back(param); + const_params.emplace_back(param); + } else { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "unexpected error on creating Triton parameter"); + break; + } + } + } else if (param_proto.first.rfind("file:", 0) == 0) { + if (param_proto.second.parameter_choice_case() != + inference::ModelRepositoryParameter::ParameterChoiceCase:: + kBytesParam) { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("invalid value type for load parameter '") + + param_proto.first + "', expected bytes_param.") + .c_str()); + break; + } else { + auto param = TRITONSERVER_ParameterBytesNew( + param_proto.first.c_str(), + param_proto.second.bytes_param().data(), + param_proto.second.bytes_param().length()); + if (param != nullptr) { + params.emplace_back(param); + const_params.emplace_back(param); + } else { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "unexpected error on creating Triton parameter"); + break; + } + } + } else { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("unrecognized load parameter '") + + param_proto.first + "'.") + .c_str()); + break; + } + } + + if (err == nullptr) { + err = TRITONSERVER_ServerLoadModelWithParameters( + tritonserver_.get(), request->model_name().c_str(), + const_params.data(), const_params.size()); + } + + // Assumes no further 'params' access after load API returns + for (auto& param : params) { + TRITONSERVER_ParameterDelete(param); + } + } else { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "'repository_name' specification is not supported"); + } + + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } + + ::grpc::ServerUnaryReactor* RepositoryModelUnload( + ::grpc::CallbackServerContext* context, + const inference::RepositoryModelUnloadRequest* request, + inference::RepositoryModelUnloadResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. 
+ if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Core business logic + TRITONSERVER_Error* err = nullptr; + if (request->repository_name().empty()) { + // Check if the dependent models should be removed + bool unload_dependents = false; + for (const auto& param : request->parameters()) { + if (param.first.compare("unload_dependents") == 0) { + const auto& unload_param = param.second; + if (unload_param.parameter_choice_case() != + inference::ModelRepositoryParameter::ParameterChoiceCase:: + kBoolParam) { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + "invalid value type for 'unload_dependents' parameter, " + "expected bool_param."); + } + unload_dependents = unload_param.bool_param(); + break; + } + } + + if (err == nullptr) { + if (unload_dependents) { + err = TRITONSERVER_ServerUnloadModelAndDependents( + tritonserver_.get(), request->model_name().c_str()); + } else { + err = TRITONSERVER_ServerUnloadModel( + tritonserver_.get(), request->model_name().c_str()); + } + } + } else { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "'repository_name' specification is not supported"); + } + + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } + + ::grpc::ServerUnaryReactor* Check( + ::grpc::CallbackServerContext* context, + const ::grpc::health::v1::HealthCheckRequest* request, + ::grpc::health::v1::HealthCheckResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // (Optionally) Check client metadata for restricted access. 
+ if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Check if server is ready + bool ready = false; + TRITONSERVER_Error* err = + TRITONSERVER_ServerIsReady(tritonserver_.get(), &ready); + + // Set health status based on server readiness + if (err == nullptr && ready) { + response->set_status(::grpc::health::v1::HealthCheckResponse::SERVING); + } else { + response->set_status( + ::grpc::health::v1::HealthCheckResponse::NOT_SERVING); + } + + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } + private: std::shared_ptr tritonserver_; std::shared_ptr shm_manager_; @@ -1759,11 +2077,10 @@ Server::Server( "CommonHandler", tritonserver_, shm_manager_, trace_manager_, &service_, &health_service_, &non_inference_callback_service_, options.restricted_protocols_, response_delay)); - // Use common_handler_ and register - // builder_.RegisterService(non_inference_callback_service_); here Cast to - // CommonHandler to access the method + // Use common_handler_ and register services auto* handler = dynamic_cast(common_handler_.get()); if (handler != nullptr) { + // Register the unified service directly without casting builder_.RegisterService(handler->GetUnifiedCallbackService()); } From b896688e04a4a6fb18485784d834b19762954d5f Mon Sep 17 00:00:00 2001 From: Indrajit Bhosale Date: Mon, 10 Mar 2025 19:50:41 -0700 Subject: [PATCH 06/12] Health API fixed --- src/grpc/grpc_handler.h | 7 +++ src/grpc/grpc_server.cc | 95 +++++++++++++++++++++++++++++++++++------ src/grpc/grpc_server.h | 1 + 3 files changed, 89 insertions(+), 14 deletions(-) diff --git a/src/grpc/grpc_handler.h b/src/grpc/grpc_handler.h index ad5f551b70..405a78d737 100644 --- a/src/grpc/grpc_handler.h +++ b/src/grpc/grpc_handler.h @@ -31,6 +31,7 @@ #include "grpc_service.grpc.pb.h" #include "grpccallback_service.grpc.pb.h" +#include "health.grpc.pb.h" namespace triton { namespace server { namespace grpc { class HandlerBase { @@ -43,6 +44,12 @@ class HandlerBase { { return nullptr; } + + virtual ::grpc::health::v1::Health::CallbackService* + GetHealthCallbackService() + { + return nullptr; + } }; class ICallData { diff --git a/src/grpc/grpc_server.cc b/src/grpc/grpc_server.cc index 88039b86db..357f63d3ac 100644 --- a/src/grpc/grpc_server.cc +++ b/src/grpc/grpc_server.cc @@ -79,8 +79,61 @@ namespace { // are deemed to be not performance critical. //========================================================================= -// Define your unified callback service that implements several non-inference -// RPCs. 
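//
// The dedicated health service introduced below serves the standard
// grpc.health.v1 protocol. A hedged sketch of how a client checks overall
// server health, assuming 'channel' is an existing ::grpc::Channel to this
// server:
//
//   auto stub = ::grpc::health::v1::Health::NewStub(channel);
//   ::grpc::ClientContext ctx;
//   ::grpc::health::v1::HealthCheckRequest req;  // empty service name asks
//                                                // about the whole server
//   ::grpc::health::v1::HealthCheckResponse resp;
//   ::grpc::Status s = stub->Check(&ctx, req, &resp);
//   bool serving = s.ok() &&
//       resp.status() == ::grpc::health::v1::HealthCheckResponse::SERVING;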
+// Define a dedicated health service that implements the health check RPC +class HealthCallbackService + : public ::grpc::health::v1::Health::CallbackService { + public: + HealthCallbackService( + const std::shared_ptr& server, + const std::pair& restrictedKV) + : tritonserver_(server), restricted_kv_(restrictedKV) + { + } + + ::grpc::ServerUnaryReactor* Check( + ::grpc::CallbackServerContext* context, + const ::grpc::health::v1::HealthCheckRequest* request, + ::grpc::health::v1::HealthCheckResponse* response) override + { + auto* reactor = context->DefaultReactor(); + + // Check restricted access if configured + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + auto it = metadata.find(restricted_kv_.first); + if (it == metadata.end() || it->second != restricted_kv_.second) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, + "Missing or mismatched restricted header")); + return reactor; + } + } + + // Check if server is ready + bool ready = false; + TRITONSERVER_Error* err = + TRITONSERVER_ServerIsReady(tritonserver_.get(), &ready); + + // Set health status based on server readiness + if (err == nullptr && ready) { + response->set_status(::grpc::health::v1::HealthCheckResponse::SERVING); + } else { + response->set_status( + ::grpc::health::v1::HealthCheckResponse::NOT_SERVING); + } + + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + reactor->Finish(status); + return reactor; + } + + private: + std::shared_ptr tritonserver_; + std::pair restricted_kv_; +}; + class UnifiedCallbackService : public inference::GRPCInferenceServiceCallback::CallbackService, public ::grpc::health::v1::Health::CallbackService { @@ -1886,15 +1939,20 @@ class CommonHandler : public HandlerBase { // Descriptive name of of the handler. const std::string& Name() const { return name_; } - void CreateUnifiedCallbackService(); + void CreateCallbackServices(); - // Add a new public method to return the non_inference_callback_service_ + // Add methods to return the callback services inference::GRPCInferenceServiceCallback::CallbackService* GetUnifiedCallbackService() { return non_inference_callback_service_; } + ::grpc::health::v1::Health::CallbackService* GetHealthCallbackService() + { + return health_callback_service_; + } + private: const std::string name_; std::shared_ptr tritonserver_; @@ -1904,6 +1962,7 @@ class CommonHandler : public HandlerBase { ::grpc::health::v1::Health::AsyncService* health_service_; inference::GRPCInferenceServiceCallback::CallbackService* non_inference_callback_service_; + ::grpc::health::v1::Health::CallbackService* health_callback_service_; std::unique_ptr thread_; RestrictedFeatures restricted_keys_; const uint64_t response_delay_; @@ -1923,18 +1982,26 @@ CommonHandler::CommonHandler( trace_manager_(trace_manager), service_(service), health_service_(health_service), non_inference_callback_service_(non_inference_callback_service), - restricted_keys_(restricted_keys), response_delay_(response_delay) + health_callback_service_(nullptr), restricted_keys_(restricted_keys), + response_delay_(response_delay) { - CreateUnifiedCallbackService(); + CreateCallbackServices(); } void -CommonHandler::CreateUnifiedCallbackService() +CommonHandler::CreateCallbackServices() { - const auto& restrictedKV = restricted_keys_.Get(RestrictedCategory::HEALTH); - // Create a single unified callback service instance. 
- non_inference_callback_service_ = - new UnifiedCallbackService(tritonserver_, shm_manager_, restrictedKV); + // Create the unified callback service for non-inference operations + const auto& inference_restrictedKV = + restricted_keys_.Get(RestrictedCategory::INFERENCE); + non_inference_callback_service_ = new UnifiedCallbackService( + tritonserver_, shm_manager_, inference_restrictedKV); + + // Create the health callback service + const auto& health_restrictedKV = + restricted_keys_.Get(RestrictedCategory::HEALTH); + health_callback_service_ = + new HealthCallbackService(tritonserver_, health_restrictedKV); } } // namespace @@ -1977,7 +2044,6 @@ Server::Server( builder_.AddListeningPort(server_addr_, credentials, &bound_port_); builder_.SetMaxMessageSize(MAX_GRPC_MESSAGE_SIZE); builder_.RegisterService(&service_); - // builder_.RegisterService(&health_service_); builder_.AddChannelArgument( GRPC_ARG_ALLOW_REUSEPORT, options.socket_.reuse_port_); @@ -2075,13 +2141,14 @@ Server::Server( // A common Handler for other non-inference requests common_handler_.reset(new CommonHandler( "CommonHandler", tritonserver_, shm_manager_, trace_manager_, &service_, - &health_service_, &non_inference_callback_service_, + &health_service_, nullptr /* non_inference_callback_service */, options.restricted_protocols_, response_delay)); // Use common_handler_ and register services auto* handler = dynamic_cast(common_handler_.get()); if (handler != nullptr) { - // Register the unified service directly without casting + // Register both the unified service and health service builder_.RegisterService(handler->GetUnifiedCallbackService()); + builder_.RegisterService(handler->GetHealthCallbackService()); } // [FIXME] "register" logic is different for infer diff --git a/src/grpc/grpc_server.h b/src/grpc/grpc_server.h index eceb7f9b85..5777059fcf 100644 --- a/src/grpc/grpc_server.h +++ b/src/grpc/grpc_server.h @@ -143,6 +143,7 @@ class Server { inference::GRPCInferenceServiceCallback::CallbackService non_inference_callback_service_; + ::grpc::health::v1::Health::CallbackService* health_callback_service_; std::unique_ptr<::grpc::Server> server_; // std::unique_ptr<::grpc::ServerCompletionQueue> common_cq_; From 636b933df784d1041ed9eba1a8571ad967e858f7 Mon Sep 17 00:00:00 2001 From: Indrajit Bhosale Date: Tue, 11 Mar 2025 11:36:07 -0700 Subject: [PATCH 07/12] Test Script for new Service Names --- test_grpc_callbacks.sh | 148 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100755 test_grpc_callbacks.sh diff --git a/test_grpc_callbacks.sh b/test_grpc_callbacks.sh new file mode 100755 index 0000000000..8058af5cdd --- /dev/null +++ b/test_grpc_callbacks.sh @@ -0,0 +1,148 @@ +#!/bin/bash +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Note: Before running this script, start Triton server in explicit model control mode: +# tritonserver --model-repository=/path/to/model/repository --model-control-mode=explicit + +# Default server URL +SERVER_URL=${1:-"localhost:8001"} +PROTO_PATH="/mnt/builddir/triton-server/_deps/repo-common-src/protobuf" +PROTO_FILE="${PROTO_PATH}/grpccallback_service.proto" +HEALTH_PROTO="${PROTO_PATH}/health.proto" + +# Colors for output +GREEN='\033[0;32m' +RED='\033[0;31m' +NC='\033[0m' # No Color +BOLD='\033[1m' + +# Function to print test results +print_result() { + local test_name=$1 + local result=$2 + if [ $result -eq 0 ]; then + echo -e "${test_name}: ${GREEN}PASSED${NC}" + else + echo -e "${test_name}: ${RED}FAILED${NC}" + fi +} + +echo -e "\n${BOLD}Testing gRPC Callback RPCs against ${SERVER_URL}${NC}\n" + +# Test Health Check +echo -e "\n${BOLD}Testing Health Check:${NC}" +grpcurl -proto ${HEALTH_PROTO} \ + --import-path ${PROTO_PATH} \ + -plaintext ${SERVER_URL} \ + grpc.health.v1.Health/Check +print_result "Health Check" $? + +# Test Repository Index +echo -e "\n${BOLD}Testing Repository Index:${NC}" +grpcurl -proto ${PROTO_FILE} \ + --import-path ${PROTO_PATH} \ + -plaintext ${SERVER_URL} \ + inference.GRPCInferenceServiceCallback/RepositoryIndex +print_result "Repository Index" $? + +# Test Model Load +echo -e "\n${BOLD}Testing Model Load:${NC}" +grpcurl -proto ${PROTO_FILE} \ + --import-path ${PROTO_PATH} \ + -plaintext -d '{"model_name": "simple"}' \ + ${SERVER_URL} \ + inference.GRPCInferenceServiceCallback/RepositoryModelLoad +print_result "Model Load" $? + +# Wait for model to load +sleep 2 + +# Test Model Unload +echo -e "\n${BOLD}Testing Model Unload:${NC}" +grpcurl -proto ${PROTO_FILE} \ + --import-path ${PROTO_PATH} \ + -plaintext -d '{"model_name": "simple"}' \ + ${SERVER_URL} \ + inference.GRPCInferenceServiceCallback/RepositoryModelUnload +print_result "Model Unload" $? + +# Test Server Live +echo -e "\n${BOLD}Testing Server Live:${NC}" +grpcurl -proto ${PROTO_FILE} \ + --import-path ${PROTO_PATH} \ + -plaintext ${SERVER_URL} \ + inference.GRPCInferenceServiceCallback/ServerLive +print_result "Server Live" $? + +# Test Server Ready +echo -e "\n${BOLD}Testing Server Ready:${NC}" +grpcurl -proto ${PROTO_FILE} \ + --import-path ${PROTO_PATH} \ + -plaintext ${SERVER_URL} \ + inference.GRPCInferenceServiceCallback/ServerReady +print_result "Server Ready" $? + +# Load model again before testing Model Ready +echo -e "\n${BOLD}Loading model for Model Ready test:${NC}" +grpcurl -proto ${PROTO_FILE} \ + --import-path ${PROTO_PATH} \ + -plaintext -d '{"model_name": "simple"}' \ + ${SERVER_URL} \ + inference.GRPCInferenceServiceCallback/RepositoryModelLoad +print_result "Model Load" $? 
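The fixed `sleep 2` waits used below can race with slow model loads. A hedged alternative under the same assumptions as the script (grpcurl available, and the callback ModelReady RPC returning a boolean `ready` field like the standard KServe RPC) is to poll readiness with a timeout:

# Sketch: poll ModelReady until it reports ready or a timeout expires.
wait_for_model_ready() {
    local model_name=$1
    local timeout_s=${2:-10}
    local waited=0
    while [ ${waited} -lt ${timeout_s} ]; do
        if grpcurl -proto ${PROTO_FILE} \
            --import-path ${PROTO_PATH} \
            -plaintext -d "{\"name\": \"${model_name}\"}" \
            ${SERVER_URL} \
            inference.GRPCInferenceServiceCallback/ModelReady \
            | grep -q '"ready": true'; then
            return 0
        fi
        sleep 1
        waited=$((waited + 1))
    done
    return 1
}

# Usage: wait_for_model_ready "simple" 10 || echo "model never became ready"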
+ +# Wait for model to load +sleep 2 + +# Test Model Ready +echo -e "\n${BOLD}Testing Model Ready:${NC}" +grpcurl -proto ${PROTO_FILE} \ + --import-path ${PROTO_PATH} \ + -plaintext -d '{"name": "simple"}' \ + ${SERVER_URL} \ + inference.GRPCInferenceServiceCallback/ModelReady +print_result "Model Ready" $? + +# Test Server Metadata +echo -e "\n${BOLD}Testing Server Metadata:${NC}" +grpcurl -proto ${PROTO_FILE} \ + --import-path ${PROTO_PATH} \ + -plaintext ${SERVER_URL} \ + inference.GRPCInferenceServiceCallback/ServerMetadata +print_result "Server Metadata" $? + +# Test Model Metadata +echo -e "\n${BOLD}Testing Model Metadata:${NC}" +grpcurl -proto ${PROTO_FILE} \ + --import-path ${PROTO_PATH} \ + -plaintext -d '{"name": "simple"}' \ + ${SERVER_URL} \ + inference.GRPCInferenceServiceCallback/ModelMetadata +print_result "Model Metadata" $? + +echo -e "\n${BOLD}Test Summary:${NC}" +echo "----------------------------------------" \ No newline at end of file From bb626ad2e22bfb1c18c23209e8d279159dec0fe9 Mon Sep 17 00:00:00 2001 From: Indrajit Bhosale Date: Tue, 11 Mar 2025 18:15:48 -0700 Subject: [PATCH 08/12] Fix RestrictedFeatures --- src/grpc/grpc_server.cc | 255 +++++++++++++++++----------------------- 1 file changed, 105 insertions(+), 150 deletions(-) diff --git a/src/grpc/grpc_server.cc b/src/grpc/grpc_server.cc index 357f63d3ac..54dc98dfc9 100644 --- a/src/grpc/grpc_server.cc +++ b/src/grpc/grpc_server.cc @@ -85,8 +85,8 @@ class HealthCallbackService public: HealthCallbackService( const std::shared_ptr& server, - const std::pair& restrictedKV) - : tritonserver_(server), restricted_kv_(restrictedKV) + RestrictedFeatures& restricted_keys_) + : tritonserver_(server), restricted_keys_(restricted_keys_) { } @@ -98,10 +98,12 @@ class HealthCallbackService auto* reactor = context->DefaultReactor(); // Check restricted access if configured - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::HEALTH); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -131,19 +133,18 @@ class HealthCallbackService private: std::shared_ptr tritonserver_; - std::pair restricted_kv_; + RestrictedFeatures restricted_keys_; }; class UnifiedCallbackService - : public inference::GRPCInferenceServiceCallback::CallbackService, - public ::grpc::health::v1::Health::CallbackService { + : public inference::GRPCInferenceServiceCallback::CallbackService { public: UnifiedCallbackService( const std::shared_ptr& server, const std::shared_ptr& shm_manager, - const std::pair& restrictedKV) + RestrictedFeatures& restricted_keys_) : tritonserver_(server), shm_manager_(shm_manager), - restricted_kv_(restrictedKV) + restricted_keys_(restricted_keys_) { } @@ -174,10 +175,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. 
- if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::HEALTH); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -206,10 +209,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::HEALTH); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -230,45 +235,6 @@ class UnifiedCallbackService return reactor; } - // ::grpc::ServerUnaryReactor* Check( - // ::grpc::CallbackServerContext* context, - // const ::grpc::health::v1::HealthCheckRequest* request, - // ::grpc::health::v1::HealthCheckResponse* response) override { - // auto* reactor = context->DefaultReactor(); - - // // (Optionally) Check client metadata for restricted access. - // if (!restricted_kv_.first.empty()) { - // const auto& metadata = context->client_metadata(); - // auto it = metadata.find(restricted_kv_.first); - // if (it == metadata.end() || it->second != restricted_kv_.second) { - // reactor->Finish(::grpc::Status(::grpc::StatusCode::UNAVAILABLE, - // "Missing or mismatched restricted header")); - // return reactor; - // } - // } - - // // Business logic for HealthCheck. - // bool live = false; - // TRITONSERVER_Error* err = TRITONSERVER_ServerIsReady(tritonserver_.get(), - // &live); - - // auto serving_status = - // ::grpc::health::v1::HealthCheckResponse_ServingStatus_UNKNOWN; if (err == - // nullptr) { - // serving_status = live - // ? ::grpc::health::v1::HealthCheckResponse_ServingStatus_SERVING - // : - // ::grpc::health::v1::HealthCheckResponse_ServingStatus_NOT_SERVING; - // } - // response->set_status(serving_status); - - // ::grpc::Status status; - // GrpcStatusUtil::Create(&status, err); - // TRITONSERVER_ErrorDelete(err); - // reactor->Finish(status); - // return reactor; - // } - ::grpc::ServerUnaryReactor* ModelReady( ::grpc::CallbackServerContext* context, const inference::ModelReadyRequest* request, @@ -277,10 +243,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. 
- if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::HEALTH); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -316,10 +284,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::METADATA); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -390,10 +360,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::METADATA); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -551,10 +523,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::MODEL_CONFIG); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -603,10 +577,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. 
- if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::STATISTICS); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -848,10 +824,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::TRACE); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -1143,10 +1121,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::LOGGING); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -1330,10 +1310,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -1361,10 +1343,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. 
- if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -1427,10 +1411,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -1470,10 +1456,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -1530,10 +1518,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -1565,10 +1555,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. 
- if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -1599,10 +1591,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -1706,10 +1700,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -1815,10 +1811,12 @@ class UnifiedCallbackService auto* reactor = context->DefaultReactor(); // (Optionally) Check client metadata for restricted access. - if (!restricted_kv_.first.empty()) { + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); + if (!restricted_kv.first.empty()) { const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { + auto it = metadata.find(restricted_kv.first); + if (it == metadata.end() || it->second != restricted_kv.second) { reactor->Finish(::grpc::Status( ::grpc::StatusCode::UNAVAILABLE, "Missing or mismatched restricted header")); @@ -1869,49 +1867,10 @@ class UnifiedCallbackService return reactor; } - ::grpc::ServerUnaryReactor* Check( - ::grpc::CallbackServerContext* context, - const ::grpc::health::v1::HealthCheckRequest* request, - ::grpc::health::v1::HealthCheckResponse* response) override - { - auto* reactor = context->DefaultReactor(); - - // (Optionally) Check client metadata for restricted access. 
- if (!restricted_kv_.first.empty()) { - const auto& metadata = context->client_metadata(); - auto it = metadata.find(restricted_kv_.first); - if (it == metadata.end() || it->second != restricted_kv_.second) { - reactor->Finish(::grpc::Status( - ::grpc::StatusCode::UNAVAILABLE, - "Missing or mismatched restricted header")); - return reactor; - } - } - - // Check if server is ready - bool ready = false; - TRITONSERVER_Error* err = - TRITONSERVER_ServerIsReady(tritonserver_.get(), &ready); - - // Set health status based on server readiness - if (err == nullptr && ready) { - response->set_status(::grpc::health::v1::HealthCheckResponse::SERVING); - } else { - response->set_status( - ::grpc::health::v1::HealthCheckResponse::NOT_SERVING); - } - - ::grpc::Status status; - GrpcStatusUtil::Create(&status, err); - TRITONSERVER_ErrorDelete(err); - reactor->Finish(status); - return reactor; - } - private: std::shared_ptr tritonserver_; std::shared_ptr shm_manager_; - std::pair restricted_kv_; + RestrictedFeatures restricted_keys_; }; // @@ -1992,16 +1951,12 @@ void CommonHandler::CreateCallbackServices() { // Create the unified callback service for non-inference operations - const auto& inference_restrictedKV = - restricted_keys_.Get(RestrictedCategory::INFERENCE); - non_inference_callback_service_ = new UnifiedCallbackService( - tritonserver_, shm_manager_, inference_restrictedKV); + non_inference_callback_service_ = + new UnifiedCallbackService(tritonserver_, shm_manager_, restricted_keys_); // Create the health callback service - const auto& health_restrictedKV = - restricted_keys_.Get(RestrictedCategory::HEALTH); health_callback_service_ = - new HealthCallbackService(tritonserver_, health_restrictedKV); + new HealthCallbackService(tritonserver_, restricted_keys_); } } // namespace From 5c97cc644f7b8804556ff5b4627a26ed26cf92ac Mon Sep 17 00:00:00 2001 From: Indrajit Bhosale Date: Sun, 30 Mar 2025 16:07:55 -0700 Subject: [PATCH 09/12] First Commit for Model Infer --- src/grpc/grpc_server.cc | 44 ++++++- src/grpc/infer_handler.cc | 261 +++++++++++++++++++++++++++++++++++++- src/grpc/infer_handler.h | 81 +++++++++++- 3 files changed, 377 insertions(+), 9 deletions(-) diff --git a/src/grpc/grpc_server.cc b/src/grpc/grpc_server.cc index 54dc98dfc9..76a4f4a06a 100644 --- a/src/grpc/grpc_server.cc +++ b/src/grpc/grpc_server.cc @@ -141,10 +141,17 @@ class UnifiedCallbackService public: UnifiedCallbackService( const std::shared_ptr& server, + TraceManager* trace_manager, const std::shared_ptr& shm_manager, - RestrictedFeatures& restricted_keys_) + grpc_compression_level compression_level, + RestrictedFeatures& restricted_keys_, + const std::string& forward_header_pattern) : tritonserver_(server), shm_manager_(shm_manager), - restricted_keys_(restricted_keys_) + trace_manager_(trace_manager), restricted_keys_(restricted_keys_), + model_infer_handler_( + "ModelInferCallbackHandler", tritonserver_, trace_manager_, + shm_manager_, compression_level, restricted_keys_, + forward_header_pattern) { } @@ -166,6 +173,24 @@ class UnifiedCallbackService return nullptr; } + ::grpc::ServerUnaryReactor* ModelInfer( + ::grpc::CallbackServerContext* context, + const inference::ModelInferRequest* request, + inference::ModelInferResponse* response) override + { + // 1. Create reactor for this RPC - This is incorrect for callback API. + // The reactor is obtained from the context, but we don't need it here + // directly. The handler function will obtain and manage it. + + // 2. 
Process the request and start inference by calling the member handler.
+    // The handler obtains the reactor from the context, manages its
+    // lifetime, and returns it, so no reactor is created here.
+    return model_infer_handler_.HandleModelInfer(context, request, response);
+
+    // 3. Return the reactor to gRPC - handled by returning the handler's
+    // result above.
+  }
+
   // Example RPC method: ServerLive
   ::grpc::ServerUnaryReactor* ServerLive(
       ::grpc::CallbackServerContext* context,
@@ -1870,7 +1895,9 @@ class UnifiedCallbackService
  private:
   std::shared_ptr<TRITONSERVER_Server> tritonserver_;
   std::shared_ptr<SharedMemoryManager> shm_manager_;
+  TraceManager* trace_manager_;
   RestrictedFeatures restricted_keys_;
+  ModelInferCallbackHandler model_infer_handler_;
 };
 
 //
@@ -1951,9 +1978,16 @@ void
 CommonHandler::CreateCallbackServices()
 {
   // Create the unified callback service for non-inference operations
-  non_inference_callback_service_ =
-      new UnifiedCallbackService(tritonserver_, shm_manager_, restricted_keys_);
-
+  // Pass all required arguments to the UnifiedCallbackService constructor
+  non_inference_callback_service_ = new UnifiedCallbackService(
+      "CommonHandler", tritonserver_,
+      trace_manager_,  // Pass the trace manager from CommonHandler
+      shm_manager_,
+      grpc_compression_level::GRPC_COMPRESS_LEVEL_NONE,  // Provide a default
+                                                         // compression level
+      restricted_keys_,
+      ""  // Provide an empty default for forward_header_pattern
+  );
   // Create the health callback service
   health_callback_service_ =
       new HealthCallbackService(tritonserver_, restricted_keys_);
diff --git a/src/grpc/infer_handler.cc b/src/grpc/infer_handler.cc
index 9c7eef48bb..c754af50da 100644
--- a/src/grpc/infer_handler.cc
+++ b/src/grpc/infer_handler.cc
@@ -1,4 +1,4 @@
-// Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -649,6 +649,265 @@ InferRequestComplete( } } +ModelInferCallbackHandler::ModelInferCallbackHandler( + const std::string& name, + const std::shared_ptr& tritonserver, + TraceManager* trace_manager, + const std::shared_ptr& shm_manager, + grpc_compression_level compression_level, + RestrictedFeatures& restricted_keys, + const std::string& forward_header_pattern) + : name_(name), tritonserver_(tritonserver), trace_manager_(trace_manager), + shm_manager_(shm_manager), compression_level_(compression_level), + restricted_kv_(restricted_keys.Get(RestrictedCategory::INFERENCE)), + header_forward_pattern_(forward_header_pattern), + header_forward_regex_(forward_header_pattern) +{ + FAIL_IF_ERR( + TRITONSERVER_ResponseAllocatorNew( + &allocator_, InferResponseAlloc, InferResponseFree, + InferResponseStart), + "creating inference response allocator"); + FAIL_IF_ERR( + TRITONSERVER_ResponseAllocatorSetQueryFunction( + allocator_, OutputBufferQuery), + "setting allocator's query function"); + FAIL_IF_ERR( + TRITONSERVER_ResponseAllocatorSetBufferAttributesFunction( + allocator_, OutputBufferAttributes), + "setting allocator's output buffer attributes function"); +} + +ModelInferCallbackHandler::~ModelInferCallbackHandler() +{ + LOG_TRITONSERVER_ERROR( + TRITONSERVER_ResponseAllocatorDelete(allocator_), + "deleting response allocator"); +} + +::grpc::ServerUnaryReactor* +ModelInferCallbackHandler::HandleModelInfer( + ::grpc::CallbackServerContext* context, + const inference::ModelInferRequest* request, + inference::ModelInferResponse* response) +{ + auto* reactor = context->DefaultReactor(); + + // Check preconditions + if (!ExecutePrecondition(context)) { + reactor->Finish(::grpc::Status( + ::grpc::StatusCode::UNAVAILABLE, "This protocol is restricted")); + return reactor; + } + + // Create callback state + auto callback_state = std::make_unique( + response, reactor, context, tritonserver_); + + // Execute the request + Execute(context, request, response, reactor, callback_state); + + return reactor; +} + +void +ModelInferCallbackHandler::InferResponseComplete( + TRITONSERVER_InferenceResponse* response, const uint32_t flags, void* userp) +{ + std::unique_ptr callback_state( + static_cast(userp)); + + if (response != nullptr) { + // Dereference callback_state->response_ to pass a reference + TRITONSERVER_Error* err = InferResponseCompleteCommon( + callback_state->tritonserver_.get(), + response, // Pass the TRITONSERVER_InferenceResponse* + *(callback_state + ->response_), // Pass the inference::ModelInferResponse& + callback_state->alloc_payload_); // Pass the AllocPayload<...> + + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + + callback_state->reactor_->Finish(status); + } else { + callback_state->reactor_->Finish( + ::grpc::Status(::grpc::StatusCode::INTERNAL, "null response")); + } + + TRITONSERVER_InferenceResponseDelete( + response); // Delete the TRITONSERVER_InferenceResponse +} + +bool +ModelInferCallbackHandler::ExecutePrecondition( + ::grpc::CallbackServerContext* context) +{ + if (!restricted_kv_.first.empty()) { + const auto& metadata = context->client_metadata(); + const auto it = metadata.find(restricted_kv_.first); + return (it != metadata.end()) && (it->second == restricted_kv_.second); + } + return true; +} + +// Implement the new private helper function +TRITONSERVER_Error* 
+ModelInferCallbackHandler::ForwardHeadersAsParametersCallback( + TRITONSERVER_InferenceRequest* irequest, + const ::grpc::CallbackServerContext* context) +{ + TRITONSERVER_Error* err = nullptr; + // Use the members stored in *this* specific handler instance + if (!header_forward_pattern_.empty()) { + const auto& metadata = + context->client_metadata(); // Use the passed context + for (const auto& pair : metadata) { + // Need to convert grpc::string_ref to std::string for RE2/Triton API + std::string key_str(pair.first.data(), pair.first.length()); + std::string value_str(pair.second.data(), pair.second.length()); + + // Use the regex member stored in *this* handler instance + if (RE2::PartialMatch(key_str, header_forward_regex_)) { + err = TRITONSERVER_InferenceRequestSetStringParameter( + irequest, key_str.c_str(), value_str.c_str()); + if (err != nullptr) { + break; // Exit loop on error + } + } + } + } + return err; +} + +void +ModelInferCallbackHandler::Execute( + ::grpc::CallbackServerContext* context, + const inference::ModelInferRequest* request, + inference::ModelInferResponse* response, + ::grpc::ServerUnaryReactor* reactor, + std::unique_ptr& callback_state) +{ + TRITONSERVER_Error* err = nullptr; + int64_t requested_model_version; + if (err == nullptr) { + err = GetModelVersionFromString( + request->model_version(), &requested_model_version); + } + + // Check if model has decoupled transaction policy + if (err == nullptr) { + uint32_t txn_flags; + err = TRITONSERVER_ServerModelTransactionProperties( + tritonserver_.get(), request->model_name().c_str(), + requested_model_version, &txn_flags, nullptr /* voidp */); + if ((err == nullptr) && (txn_flags & TRITONSERVER_TXN_DECOUPLED) != 0) { + err = TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "ModelInfer RPC doesn't support models with decoupled " + "transaction policy"); + } + } + + // Create the inference request + TRITONSERVER_InferenceRequest* irequest = nullptr; + if (err == nullptr) { + err = TRITONSERVER_InferenceRequestNew( + &irequest, tritonserver_.get(), request->model_name().c_str(), + requested_model_version); + } + + // Set metadata and parameters + if (err == nullptr) { + // Create a local StateParameters object for the call + StateParameters state_params; + // Pass the local object as the third argument + err = SetInferenceRequestMetadata(irequest, *request, state_params); + } + + // Call the correct private helper function here + if (err == nullptr) { + err = ForwardHeadersAsParametersCallback(irequest, context); + } + + // Handle input tensors and shared memory + if (err == nullptr) { + err = InferGRPCToInput( + tritonserver_, shm_manager_, *request, + &callback_state->serialized_data_, irequest, + &callback_state->shm_regions_info_); + } + + // Set up allocator payload + if (err == nullptr) { + err = InferAllocatorPayload( + tritonserver_, shm_manager_, *request, + std::move(callback_state->serialized_data_), response, + &callback_state->alloc_payload_, &callback_state->shm_regions_info_); + } + + // Set response callback + if (err == nullptr) { + err = TRITONSERVER_InferenceRequestSetResponseCallback( + irequest, allocator_, &callback_state->alloc_payload_, + InferResponseComplete, callback_state.get()); + } + + // Get request ID for logging + const char* request_id = ""; + if (irequest != nullptr) { + LOG_TRITONSERVER_ERROR( + TRITONSERVER_InferenceRequestId(irequest, &request_id), + "unable to retrieve request ID string"); + } + + if (!strncmp(request_id, "", 1)) { + request_id = ""; + } + + // 
Set up tracing if enabled + TRITONSERVER_InferenceTrace* triton_trace = nullptr; +#ifdef TRITON_ENABLE_TRACING + if (err == nullptr && trace_manager_ != nullptr) { + GrpcServerCarrier carrier(context); + auto start_options = + trace_manager_->GetTraceStartOptions(carrier, request->model_name()); + callback_state->trace_ = + std::move(trace_manager_->SampleTrace(start_options)); + if (callback_state->trace_ != nullptr) { + triton_trace = callback_state->trace_->trace_; + } + } +#endif // TRITON_ENABLE_TRACING + + // Issue async inference request + if (err == nullptr) { + err = TRITONSERVER_ServerInferAsync( + tritonserver_.get(), irequest, triton_trace); + } + + // Handle errors or complete successfully + if (err == nullptr) { + // Success case - callback_state ownership transferred to callback + callback_state.release(); + } else { + // Error case + LOG_VERBOSE(1) << "[request id: " << request_id << "] " + << "Infer failed: " << TRITONSERVER_ErrorMessage(err); + + ::grpc::Status status; + GrpcStatusUtil::Create(&status, err); // Use existing utility + TRITONSERVER_ErrorDelete(err); + + if (irequest != nullptr) { + TRITONSERVER_InferenceRequestDelete(irequest); + } + + // Complete RPC with error + reactor->Finish(status); + } +} //=========================================================================== // The following section contains the handling mechanism for ModelInfer RPC. // This implementation is tuned towards performance and reducing latency. diff --git a/src/grpc/infer_handler.h b/src/grpc/infer_handler.h index 86428a514e..4fb9c60e4f 100644 --- a/src/grpc/infer_handler.h +++ b/src/grpc/infer_handler.h @@ -1,4 +1,4 @@ -// Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -34,13 +34,13 @@ #include #include +#include "../restricted_features.h" #include "../tracer.h" #include "grpc_handler.h" #include "grpc_service.grpc.pb.h" #include "grpc_utils.h" #include "triton/common/logging.h" #include "triton/core/tritonserver.h" - // Unique IDs are only needed when debugging. They only appear in // verbose logging. #ifndef NDEBUG @@ -938,7 +938,7 @@ class InferHandlerState { // FIXME: Is there a better way to put task on the // completion queue rather than using alarm object? // The alarm object will add a new task to the back of the - // completion queue when it expires or when it’s cancelled. + // completion queue when it expires or when it's cancelled. 
state->alarm_.Set( cq_, gpr_now(gpr_clock_type::GPR_CLOCK_REALTIME), state); } @@ -1532,6 +1532,81 @@ InferHandler:: return err; } +class ModelInferCallbackHandler { + public: + ModelInferCallbackHandler( + const std::string& name, + const std::shared_ptr& tritonserver, + TraceManager* trace_manager, + const std::shared_ptr& shm_manager, + grpc_compression_level compression_level, + RestrictedFeatures& restricted_keys, + const std::string& forward_header_pattern); + + ~ModelInferCallbackHandler(); + + ::grpc::ServerUnaryReactor* HandleModelInfer( + ::grpc::CallbackServerContext* context, + const inference::ModelInferRequest* request, + inference::ModelInferResponse* response); + + private: + // Define CallbackState first, before any methods that use it + struct CallbackState { + CallbackState( + inference::ModelInferResponse* response, + ::grpc::ServerUnaryReactor* reactor, + ::grpc::CallbackServerContext* context, + const std::shared_ptr& tritonserver) + : response_(response), reactor_(reactor), context_(context), + tritonserver_(tritonserver) + { + } + + inference::ModelInferResponse* response_; + ::grpc::ServerUnaryReactor* reactor_; + ::grpc::CallbackServerContext* context_; + std::shared_ptr tritonserver_; + + // Request resources + AllocPayload alloc_payload_; + std::list serialized_data_; + std::vector> + shm_regions_info_; + +#ifdef TRITON_ENABLE_TRACING + std::shared_ptr trace_; +#endif // TRITON_ENABLE_TRACING + }; + + TRITONSERVER_Error* ForwardHeadersAsParametersCallback( + TRITONSERVER_InferenceRequest* irequest, + const ::grpc::CallbackServerContext* context); + // Now Execute can use CallbackState + void Execute( + ::grpc::CallbackServerContext* context, + const inference::ModelInferRequest* request, + inference::ModelInferResponse* response, + ::grpc::ServerUnaryReactor* reactor, + std::unique_ptr& callback_state); + + static void InferResponseComplete( + TRITONSERVER_InferenceResponse* response, const uint32_t flags, + void* userp); + + bool ExecutePrecondition(::grpc::CallbackServerContext* context); + + const std::string name_; + std::shared_ptr tritonserver_; + TraceManager* trace_manager_; + std::shared_ptr shm_manager_; + TRITONSERVER_ResponseAllocator* allocator_; + + grpc_compression_level compression_level_; + const std::pair restricted_kv_; + const std::string header_forward_pattern_; + re2::RE2 header_forward_regex_; +}; // // ModelInferHandler // From 205e5f05d5c422ee3a63d66abf2dad489ff36e26 Mon Sep 17 00:00:00 2001 From: Indrajit Bhosale Date: Mon, 31 Mar 2025 11:25:13 -0700 Subject: [PATCH 10/12] Add response queue --- src/grpc/grpc_server.cc | 1 + src/grpc/infer_handler.cc | 8 +++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/grpc/grpc_server.cc b/src/grpc/grpc_server.cc index 0907e42e1d..0099ff3761 100644 --- a/src/grpc/grpc_server.cc +++ b/src/grpc/grpc_server.cc @@ -140,6 +140,7 @@ class UnifiedCallbackService : public inference::GRPCInferenceServiceCallback::CallbackService { public: UnifiedCallbackService( + const std::string& name, const std::shared_ptr& server, TraceManager* trace_manager, const std::shared_ptr& shm_manager, diff --git a/src/grpc/infer_handler.cc b/src/grpc/infer_handler.cc index c754af50da..a099e4ca30 100644 --- a/src/grpc/infer_handler.cc +++ b/src/grpc/infer_handler.cc @@ -839,11 +839,17 @@ ModelInferCallbackHandler::Execute( &callback_state->shm_regions_info_); } + auto response_queue = + std::make_shared>(1); + inference::ModelInferResponse* response_ptr = + response_queue->GetNonDecoupledResponse(); + 
*response_ptr = *response; + // Set up allocator payload if (err == nullptr) { err = InferAllocatorPayload( tritonserver_, shm_manager_, *request, - std::move(callback_state->serialized_data_), response, + std::move(callback_state->serialized_data_), response_queue, &callback_state->alloc_payload_, &callback_state->shm_regions_info_); } From 3e80da101b57c38d5295ac37c31f69ff3951134d Mon Sep 17 00:00:00 2001 From: Indrajit Bhosale Date: Wed, 9 Apr 2025 08:42:02 -0700 Subject: [PATCH 11/12] Working Callback ModelInfer --- src/grpc/infer_handler.cc | 304 ++++++++++++++++++++++++++++--------- src/grpc/infer_handler.h | 308 +++++++++++++++++++++++++++++++++++++- 2 files changed, 543 insertions(+), 69 deletions(-) diff --git a/src/grpc/infer_handler.cc b/src/grpc/infer_handler.cc index 2cda3754f9..16e241634e 100644 --- a/src/grpc/infer_handler.cc +++ b/src/grpc/infer_handler.cc @@ -113,6 +113,29 @@ InferResponseAlloc( buffer_userp, actual_memory_type, actual_memory_type_id); } +// Make sure to keep InferResponseAllocCallback and OutputBufferQuery logic in +// sync +TRITONSERVER_Error* +InferResponseAllocCallback( + TRITONSERVER_ResponseAllocator* allocator, const char* tensor_name, + size_t byte_size, TRITONSERVER_MemoryType preferred_memory_type, + int64_t preferred_memory_type_id, void* userp, void** buffer, + void** buffer_userp, TRITONSERVER_MemoryType* actual_memory_type, + int64_t* actual_memory_type_id) +{ + AllocPayloadCallback* payload = + reinterpret_cast*>( + userp); + + // ModelInfer RPC expects exactly one response per request. Hence, + // Get pointer directly from the modified payload instead of the queue. + inference::ModelInferResponse* response = payload->response_ptr_; + return ResponseAllocatorHelper( + allocator, tensor_name, byte_size, preferred_memory_type, + preferred_memory_type_id, response, payload->shm_map_, buffer, + buffer_userp, actual_memory_type, actual_memory_type_id); +} + // Make sure to keep InferResponseAlloc and OutputBufferQuery logic in sync TRITONSERVER_Error* OutputBufferQuery( @@ -120,8 +143,9 @@ OutputBufferQuery( const char* tensor_name, size_t* byte_size, TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id) { - AllocPayload* payload = - reinterpret_cast*>(userp); + AllocPayloadCallback* payload = + reinterpret_cast*>( + userp); return OutputBufferQueryHelper( allocator, tensor_name, byte_size, payload->shm_map_, memory_type, @@ -136,8 +160,9 @@ OutputBufferAttributes( TRITONSERVER_BufferAttributes* buffer_attributes, void* userp, void* buffer_userp) { - AllocPayload* payload = - reinterpret_cast*>(userp); + AllocPayloadCallback* payload = + reinterpret_cast*>( + userp); return OutputBufferAttributesHelper( allocator, tensor_name, payload->shm_map_, buffer_attributes); @@ -191,12 +216,12 @@ InferGRPCToInputHelper( TRITONSERVER_Error* InferResponseStart(TRITONSERVER_ResponseAllocator* allocator, void* userp) { - AllocPayload* payload = - reinterpret_cast*>(userp); + // AllocPayload* payload = + // reinterpret_cast*>(userp); // ModelInfer RPC expects exactly one response per request. Hence, always call // GetNonDecoupledResponse() to create one response object on response start. 
- payload->response_queue_->GetNonDecoupledResponse(); + // payload->response_queue_->GetNonDecoupledResponse(); return nullptr; // success } @@ -639,7 +664,7 @@ void InferRequestComplete( TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp) { - LOG_VERBOSE(1) << "ModelInferHandler::InferRequestComplete"; + LOG_VERBOSE(1) << "ModelInferHandler::InferRequestComplete!"; RequestReleasePayload* request_release_payload = static_cast(userp); @@ -665,7 +690,7 @@ ModelInferCallbackHandler::ModelInferCallbackHandler( { FAIL_IF_ERR( TRITONSERVER_ResponseAllocatorNew( - &allocator_, InferResponseAlloc, InferResponseFree, + &allocator_, InferResponseAllocCallback, InferResponseFree, InferResponseStart), "creating inference response allocator"); FAIL_IF_ERR( @@ -685,6 +710,28 @@ ModelInferCallbackHandler::~ModelInferCallbackHandler() "deleting response allocator"); } +/** + * @brief Handles gRPC ModelInfer requests using the callback API pattern + * + * Request flow path: + * 1. Client creates and sends ModelInferRequest via gRPC + * 2. gRPC framework deserializes the protobuf message + * 3. gRPC calls this handler based on service registration + * 4. This function creates a callback state and reactor to manage async + * lifecycle + * 5. The Execute method initiates processing with proper ownership transfer + * + * Memory management: + * - CallbackState manages lifecycle of request/response objects + * - Ownership transfers to completion callbacks for async cleanup + * - Response memory allocation handled through allocator_ + * - Shared memory regions tracked and released after completion + * + * @param context The gRPC server context for this request + * @param request The deserialized ModelInferRequest from client + * @param response Output parameter for the ModelInferResponse to client + * @return ::grpc::ServerUnaryReactor* Reactor that signals request completion + */ ::grpc::ServerUnaryReactor* ModelInferCallbackHandler::HandleModelInfer( ::grpc::CallbackServerContext* context, @@ -714,30 +761,76 @@ void ModelInferCallbackHandler::InferResponseComplete( TRITONSERVER_InferenceResponse* response, const uint32_t flags, void* userp) { + LOG_VERBOSE(1) << "[InferResponseComplete START] Received userp " + "(CallbackState*) address: " + << userp; std::unique_ptr callback_state( static_cast(userp)); - + LOG_VERBOSE(1) << "[InferResponseComplete] CallbackState unique_ptr now owns " + "state at address: " + << callback_state.get(); if (response != nullptr) { - // Dereference callback_state->response_ to pass a reference - TRITONSERVER_Error* err = InferResponseCompleteCommon( - callback_state->tritonserver_.get(), - response, // Pass the TRITONSERVER_InferenceResponse* - *(callback_state - ->response_), // Pass the inference::ModelInferResponse& - callback_state->alloc_payload_); // Pass the AllocPayload<...> + // Use the pre-allocated response directly from the callback state + ::grpc::Status status = ::grpc::Status::OK; + + // Get the response from the payload's response queue as a fallback + LOG_VERBOSE(1) + << "[InferResponseComplete] Attempting to retrieve response pointer " + "directly from callback_state->response_ which points to: " + << callback_state->response_; + inference::ModelInferResponse* grpc_response = callback_state->response_; + + // If not available in callback state, try to get from response queue + if (grpc_response == nullptr) { + LOG_VERBOSE(1) + << "[InferResponseComplete] >>> Fallback Triggered! 
grpc_response " + "from state was NULL, attempting fallback from queue."; + grpc_response = callback_state->alloc_payload_.response_ptr_; + } - ::grpc::Status status; - GrpcStatusUtil::Create(&status, err); - TRITONSERVER_ErrorDelete(err); + if (grpc_response != nullptr) { + // Process the response + LOG_VERBOSE(1) + << "InferResponseComplete: Checking response object at address: " + << grpc_response; + TRITONSERVER_Error* err = InferResponseCompleteCommonCallback( + callback_state->tritonserver_.get(), response, *grpc_response, + callback_state->alloc_payload_); + + if (err != nullptr) { + GrpcStatusUtil::Create(&status, err); + TRITONSERVER_ErrorDelete(err); + } + } else { + status = ::grpc::Status( + ::grpc::StatusCode::INTERNAL, + "response object not found in callback"); + } - callback_state->reactor_->Finish(status); + // For callback API, we complete the RPC by finishing the reactor + // Only finish the reactor when we get the final response or on error + if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) || !status.ok()) { + callback_state->reactor_->Finish(status); + } } else { + // Handle null response case callback_state->reactor_->Finish( ::grpc::Status(::grpc::StatusCode::INTERNAL, "null response")); } - TRITONSERVER_InferenceResponseDelete( - response); // Delete the TRITONSERVER_InferenceResponse +#ifdef TRITON_ENABLE_TRACING + if (callback_state->trace_ != nullptr) { + callback_state->trace_timestamps_.emplace_back(std::make_pair( + "INFER_RESPONSE_COMPLETE", TraceManager::CaptureTimestamp())); + } +#endif // TRITON_ENABLE_TRACING + + // Always delete the TRITONSERVER_InferenceResponse + if (response != nullptr) { + LOG_TRITONSERVER_ERROR( + TRITONSERVER_InferenceResponseDelete(response), + "deleting inference response"); + } } bool @@ -790,19 +883,23 @@ ModelInferCallbackHandler::Execute( std::unique_ptr& callback_state) { TRITONSERVER_Error* err = nullptr; + TRITONSERVER_InferenceRequest* irequest = nullptr; + LOG_VERBOSE(1) << "[Execute START] Incoming response object address: " + << response; + // --- Step 1: Receive & Validate --- int64_t requested_model_version; - if (err == nullptr) { - err = GetModelVersionFromString( - request->model_version(), &requested_model_version); - } + err = GetModelVersionFromString( + request->model_version(), &requested_model_version); - // Check if model has decoupled transaction policy + // Check if model has decoupled transaction policy (not supported by this RPC) if (err == nullptr) { uint32_t txn_flags; + // Query model properties err = TRITONSERVER_ServerModelTransactionProperties( tritonserver_.get(), request->model_name().c_str(), requested_model_version, &txn_flags, nullptr /* voidp */); if ((err == nullptr) && (txn_flags & TRITONSERVER_TXN_DECOUPLED) != 0) { + // Set error if decoupled err = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_UNSUPPORTED, "ModelInfer RPC doesn't support models with decoupled " @@ -810,72 +907,96 @@ ModelInferCallbackHandler::Execute( } } - // Create the inference request - TRITONSERVER_InferenceRequest* irequest = nullptr; + // --- Step 2: Prepare Triton Request Object --- if (err == nullptr) { + // Create the core Triton request object err = TRITONSERVER_InferenceRequestNew( &irequest, tritonserver_.get(), request->model_name().c_str(), requested_model_version); } - // Set metadata and parameters + // Populate request metadata (ID, sequence flags, priority, params, etc.) 
if (err == nullptr) { - // Create a local StateParameters object for the call - StateParameters state_params; - // Pass the local object as the third argument + StateParameters state_params; // Temporary params for this call scope err = SetInferenceRequestMetadata(irequest, *request, state_params); } - // Call the correct private helper function here + // Forward relevant gRPC headers as Triton parameters if (err == nullptr) { err = ForwardHeadersAsParametersCallback(irequest, context); } - // Handle input tensors and shared memory + // --- Step 3: Process Input Tensors --- if (err == nullptr) { + // Parse inputs from request, handle shared memory (if any), + // serialize string data, and add data pointers/attributes to irequest. + // Serialized data stored in callback_state->serialized_data_ + // SHM info stored in callback_state->shm_regions_info_ err = InferGRPCToInput( tritonserver_, shm_manager_, *request, &callback_state->serialized_data_, irequest, &callback_state->shm_regions_info_); } - auto response_queue = - std::make_shared>(1); - inference::ModelInferResponse* response_ptr = - response_queue->GetNonDecoupledResponse(); - *response_ptr = *response; + // --- Step 4: Prepare for Response Handling (Callback Specific) --- + std::shared_ptr> response_queue = + nullptr; + if (err == nullptr) { + // Use the externally provided response object directly. + // Store the external response pointer in the state for later access. + callback_state->response_ = response; + LOG_VERBOSE(1) << "[Execute] Stored response object address in " + "callback_state->response_: " + << callback_state->response_; + // Clear the externally provided response object directly. + response->Clear(); // Ensure it's empty before Triton writes to it + } - // Set up allocator payload + // Prepare the allocator payload: info needed by allocation callback later. + // Moves serialized input data into the payload. References the + // response_queue. if (err == nullptr) { - err = InferAllocatorPayload( + err = InferAllocatorPayloadCallback( tritonserver_, shm_manager_, *request, - std::move(callback_state->serialized_data_), response_queue, + std::move(callback_state->serialized_data_), callback_state->response_, &callback_state->alloc_payload_, &callback_state->shm_regions_info_); } - // Set response callback + // --- Step 5: Setup Automatic Cleanup Payloads & Register Callbacks --- + // Create payload for request release callback (manages irequest lifetime) + auto request_release_payload = std::make_unique( + std::shared_ptr( + irequest, [](TRITONSERVER_InferenceRequest* r) { + // Custom deleter: Ensures delete is called via shared_ptr lifecycle + if (r != nullptr) { + LOG_TRITONSERVER_ERROR( + TRITONSERVER_InferenceRequestDelete(r), + "deleting inference request via shared_ptr custom deleter"); + } + })); + + // Register the release callback (cleans up request_release_payload & + // irequest) if (err == nullptr) { + err = TRITONSERVER_InferenceRequestSetReleaseCallback( + irequest, InferRequestComplete, request_release_payload.get()); + } + + // Register the response callback (processes result, finishes RPC, cleans up + // callback_state) + if (err == nullptr) { + // Note: Passing callback_state.get() transfers potential ownership to the + // callback mechanism upon success (see step 7). 
 
+  // Register the response callback (processes the result, finishes the RPC,
+  // cleans up callback_state)
   if (err == nullptr) {
+    // Note: passing callback_state.get() here; ownership of callback_state
+    // transfers to the callback mechanism on success (see Step 7).
     err = TRITONSERVER_InferenceRequestSetResponseCallback(
         irequest, allocator_, &callback_state->alloc_payload_,
         InferResponseComplete, callback_state.get());
   }
 
-  // Get request ID for logging
-  const char* request_id = "";
-  if (irequest != nullptr) {
-    LOG_TRITONSERVER_ERROR(
-        TRITONSERVER_InferenceRequestId(irequest, &request_id),
-        "unable to retrieve request ID string");
-  }
-
-  if (!strncmp(request_id, "", 1)) {
-    request_id = "<id_unknown>";
-  }
-
-  // Set up tracing if enabled
+  // --- Optional: Setup Tracing ---
   TRITONSERVER_InferenceTrace* triton_trace = nullptr;
 #ifdef TRITON_ENABLE_TRACING
   if (err == nullptr && trace_manager_ != nullptr) {
+    // Set up and start tracing if configured
     GrpcServerCarrier carrier(context);
     auto start_options =
         trace_manager_->GetTraceStartOptions(carrier, request->model_name());
@@ -887,31 +1008,75 @@
   }
 #endif  // TRITON_ENABLE_TRACING
 
-  // Issue async inference request
+  // Get the request ID for logging, handling a null irequest if an error
+  // occurred early
+  const char* request_id_cstr = "";
+  std::string request_id = "";
+  if (irequest != nullptr) {
+    auto id_err = TRITONSERVER_InferenceRequestId(irequest, &request_id_cstr);
+    if (id_err == nullptr && request_id_cstr != nullptr &&
+        strlen(request_id_cstr) > 0) {
+      request_id = request_id_cstr;
+    }
+    TRITONSERVER_ErrorDelete(id_err);  // Delete the error from ID retrieval,
+                                       // if any
+  }
+
+  // --- Step 6: Start Asynchronous Inference ---
   if (err == nullptr) {
     err = TRITONSERVER_ServerInferAsync(
         tritonserver_.get(), irequest, triton_trace);
   }
 
-  // Handle errors or complete successfully
+  // --- Step 7/8: Handle Outcome (Success or Error) ---
   if (err == nullptr) {
+    // --- Success Path ---
+    // Inference was successfully submitted to the Triton core. Release
+    // ownership of the payloads to the callback mechanism; the callbacks
+    // (InferResponseComplete, InferRequestComplete) are now responsible
+    // for cleanup.
+    LOG_VERBOSE(1) << "[Execute SUCCESS] Releasing ownership of callback_state "
                       "at address: "
+                   << callback_state.get();
     callback_state.release();
+    request_release_payload.release();
+    // Execute() returns here; the gRPC call completes when the response
+    // callback invokes reactor->Finish().
+    LOG_VERBOSE(1) << "[request id: " << request_id << "] "
+                   << "Async inference submitted successfully.";
   } else {
-    // Error case
+    // --- Error Path ---
+    // An error occurred during setup, before the request was submitted to
+    // Triton.
     LOG_VERBOSE(1) << "[request id: " << request_id << "] "
-                   << "Infer failed: " << TRITONSERVER_ErrorMessage(err);
+                   << "Setup failed before submitting inference: "
+                   << TRITONSERVER_ErrorMessage(err);
+    // Create the gRPC status from the Triton error
     ::grpc::Status status;
-    GrpcStatusUtil::Create(&status, err);  // Use existing utility
-    TRITONSERVER_ErrorDelete(err);
+    GrpcStatusUtil::Create(&status, err);
+    TRITONSERVER_ErrorDelete(err);  // Delete the primary Triton error
 
     if (irequest != nullptr) {
-      TRITONSERVER_InferenceRequestDelete(irequest);
+      // The release callback won't run, but request_release_payload still
+      // owns irequest through its shared_ptr. Resetting the payload runs
+      // the custom deleter exactly once; deleting irequest directly here
+      // as well would free the request twice.
+      request_release_payload.reset();
+      irequest = nullptr;
     }
+    // callback_state (and request_release_payload, if not already reset)
+    // clean up automatically when they go out of scope, since .release()
+    // was never called on this path.
 
-    // Complete RPC with error
+    // Immediately finish the gRPC call with the error status
     reactor->Finish(status);
+    // Execute() returns here.
   }
 }
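The success/error split above follows a common hand-off idiom: resources live in unique_ptrs during setup, ownership is release()d to the asynchronous callback only after submission succeeds, and on any failure the unique_ptrs clean up automatically. A toy sketch under those assumptions (State and SubmitAsync are hypothetical):

#include <memory>

struct State {
  int request_id;
};

// Stand-in for submitting work to an async engine. Returns true when the
// engine's completion callback has taken responsibility for the state; in
// this toy version the "callback" cleanup happens inline.
bool SubmitAsync(State* state)
{
  delete state;
  return true;
}

void Execute()
{
  auto state = std::make_unique<State>();
  state->request_id = 1;
  if (SubmitAsync(state.get())) {
    state.release();  // success: ownership handed to the callback side
    return;
  }
  // failure: the unique_ptr deletes the state here, so nothing leaks and
  // nothing is freed twice
}

int main()
{
  Execute();
  return 0;
}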
 
 //===========================================================================
 
@@ -1084,13 +1249,16 @@ ResponseAllocatorHelper(
   *actual_memory_type = preferred_memory_type;
   *actual_memory_type_id = preferred_memory_type_id;
 
+  LOG_VERBOSE(1) << "AllocatorHelper: Modifying response object at address: "
+                 << response;
   // We add an output contents even if the 'byte_size' == 0 because we
   // expect to have a contents for every output.
   inference::ModelInferResponse::InferOutputTensor* output_tensor =
       response->add_outputs();
   output_tensor->set_name(tensor_name);
   std::string* raw_output = response->add_raw_output_contents();
-
+  LOG_VERBOSE(1) << "AllocatorHelper: After add_outputs for " << tensor_name
+                 << ", response->outputs_size() = " << response->outputs_size();
   if (byte_size > 0) {
     const auto& pr = shm_map.find(tensor_name);
     if (pr != shm_map.end()) {
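To make the allocator/payload relationship concrete: the allocation callback receives the payload as userp and can return a buffer that lives directly inside the gRPC response protobuf, which is what ResponseAllocatorHelper does above. A minimal sketch using this codebase's types (the wiring via TRITONSERVER_ResponseAllocatorNew is omitted and only the CPU path is shown; SketchResponseAlloc is a hypothetical name):

TRITONSERVER_Error*
SketchResponseAlloc(
    TRITONSERVER_ResponseAllocator* allocator, const char* tensor_name,
    size_t byte_size, TRITONSERVER_MemoryType preferred_memory_type,
    int64_t preferred_memory_type_id, void* userp, void** buffer,
    void** buffer_userp, TRITONSERVER_MemoryType* actual_memory_type,
    int64_t* actual_memory_type_id)
{
  auto* payload =
      reinterpret_cast<AllocPayloadCallback<inference::ModelInferResponse>*>(
          userp);
  inference::ModelInferResponse* response = payload->response_ptr_;

  // Always create the output entry, even when byte_size == 0, so every
  // output has a matching raw contents slot.
  auto* output_tensor = response->add_outputs();
  output_tensor->set_name(tensor_name);
  std::string* raw_output = response->add_raw_output_contents();

  *buffer = nullptr;
  *buffer_userp = nullptr;
  *actual_memory_type = TRITONSERVER_MEMORY_CPU;
  *actual_memory_type_id = 0;

  if (byte_size > 0) {
    // The protobuf string owns the bytes; Triton writes straight into it.
    raw_output->resize(byte_size);
    *buffer = static_cast<void*>(&((*raw_output)[0]));
  }
  return nullptr;  // success
}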
diff --git a/src/grpc/infer_handler.h b/src/grpc/infer_handler.h
index 6c873875dd..1b2da1caf1 100644
--- a/src/grpc/infer_handler.h
+++ b/src/grpc/infer_handler.h
@@ -353,6 +353,46 @@ struct AllocPayload {
   std::list<std::string> serialized_data_;
 };
 
+//
+// AllocPayloadCallback
+//
+// Simple structure that carries the userp payload needed for
+// allocation specifically for the Callback API, holding a direct
+// pointer to the gRPC response object.
+//
+template <typename ResponseType>
+struct AllocPayloadCallback {
+  using ClassificationMap = std::unordered_map<std::string, uint32_t>;
+
+  // The constructor initializes the response pointer to null.
+  explicit AllocPayloadCallback()
+      : response_ptr_(nullptr), response_alloc_count_(0)
+  {
+  }
+
+  // The default destructor is sufficient; it does nothing with
+  // response_ptr_ because ownership lies with the gRPC framework or the
+  // CallbackState unique_ptr.
+  ~AllocPayloadCallback() = default;
+
+  // Direct pointer to the gRPC response object managed externally
+  // (by the gRPC reactor or CallbackState).
+  ResponseType* response_ptr_;
+
+  // Counter for allocations related to this payload.
+  uint32_t response_alloc_count_;
+
+  // Map of shared-memory information for output tensors.
+  TensorShmMap shm_map_;
+
+  // Map of classification parameters for output tensors.
+  ClassificationMap classification_map_;
+
+  // Used to extend the lifetime of serialized input data (e.g., for BYTES
+  // tensors) needed during the allocation phase (though the data originates
+  // from the request).
+  std::list<std::string> serialized_data_;
+};
+
 template <typename ResponseType>
 TRITONSERVER_Error*
 InferAllocatorPayload(
@@ -430,6 +470,83 @@ InferAllocatorPayload(
   return nullptr;  // Success
 }
 
+template <typename ResponseType>
+TRITONSERVER_Error*
+InferAllocatorPayloadCallback(
+    const std::shared_ptr<TRITONSERVER_Server>& tritonserver,
+    const std::shared_ptr<SharedMemoryManager>& shm_manager,
+    const inference::ModelInferRequest& request,
+    std::list<std::string>&& serialized_data,
+    inference::ModelInferResponse* response_ptr,
+    AllocPayloadCallback<ResponseType>* alloc_payload,
+    std::vector<std::shared_ptr<const SharedMemoryManager::SharedMemoryInfo>>*
+        shm_regions_info)
+{
+  alloc_payload->response_ptr_ = response_ptr;
+  alloc_payload->shm_map_.clear();
+  alloc_payload->classification_map_.clear();
+  alloc_payload->serialized_data_ = std::move(serialized_data);
+
+  // If any of the outputs use shared memory, then we must calculate
+  // the memory address for that output and store it in the allocator
+  // payload so that it is available when the allocation callback is
+  // invoked.
+  for (const auto& io : request.outputs()) {
+    std::string region_name;
+    int64_t offset;
+    size_t byte_size;
+    bool has_shared_memory;
+    RETURN_IF_ERR(ParseSharedMemoryParams<
+                  inference::ModelInferRequest::InferRequestedOutputTensor>(
+        io, &has_shared_memory, &region_name, &offset, &byte_size));
+
+    bool has_classification;
+    uint32_t classification_count;
+    RETURN_IF_ERR(ParseClassificationParams(
+        io, &has_classification, &classification_count));
+
+    if (has_shared_memory && has_classification) {
+      return TRITONSERVER_ErrorNew(
+          TRITONSERVER_ERROR_INVALID_ARG,
+          "output can't set both 'shared_memory_region' and "
+          "'classification'");
+    }
+
+    if (has_shared_memory) {
+      void* base;
+      TRITONSERVER_MemoryType memory_type;
+      int64_t memory_type_id;
+      std::shared_ptr<const SharedMemoryManager::SharedMemoryInfo> shm_info =
+          nullptr;
+      RETURN_IF_ERR(shm_manager->GetMemoryInfo(
+          region_name, offset, byte_size, &base, &memory_type, &memory_type_id,
+          &shm_info));
+      shm_regions_info->emplace_back(shm_info);
+
+      if (memory_type == TRITONSERVER_MEMORY_GPU) {
+#ifdef TRITON_ENABLE_GPU
+        char* cuda_handle;
+        RETURN_IF_ERR(shm_manager->GetCUDAHandle(
+            region_name, reinterpret_cast<cudaIpcMemHandle_t**>(&cuda_handle)));
+        alloc_payload->shm_map_.emplace(
+            io.name(),
+            ShmInfo(base, byte_size, memory_type, memory_type_id, cuda_handle));
+#endif
+      } else {
+        alloc_payload->shm_map_.emplace(
+            io.name(), ShmInfo(
+                           base, byte_size, memory_type, memory_type_id,
+                           nullptr /* cuda_ipc_handle */));
+      }
+    } else if (has_classification) {
+      alloc_payload->classification_map_.emplace(
+          io.name(), classification_count);
+    }
+  }
+
+  return nullptr;  // Success
+}
+
 TRITONSERVER_Error* InferGRPCToInputHelper(
     const std::string& input_name, const std::string& model_name,
     const TRITONSERVER_DataType tensor_dt,
     const TRITONSERVER_DataType input_dt,
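For context, the 'classification' parameter parsed above comes from the KServe classification extension: a client asks for top-k classification on an output instead of the raw tensor data. A sketch of building such a request message, with hypothetical model and tensor names (the generated proto header path may differ):

#include "grpc_service.pb.h"  // generated Triton/KServe protos; path may vary

inference::ModelInferRequest MakeClassificationRequest()
{
  inference::ModelInferRequest request;
  request.set_model_name("my_classifier");

  auto* output = request.add_outputs();
  output->set_name("probabilities");
  // Ask for the top-3 classes. The server then replaces the raw tensor
  // with a BYTES tensor of classification strings.
  (*output->mutable_parameters())["classification"].set_int64_param(3);
  return request;
}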
@@ -694,6 +811,195 @@ InferResponseCompleteCommon(
   return nullptr;  // success
 }
 
+template <typename ResponseType>
+TRITONSERVER_Error*
+InferResponseCompleteCommonCallback(
+    TRITONSERVER_Server* server, TRITONSERVER_InferenceResponse* iresponse,
+    inference::ModelInferResponse& response,
+    const AllocPayloadCallback<ResponseType>& alloc_payload)
+{
+  RETURN_IF_ERR(TRITONSERVER_InferenceResponseError(iresponse));
+
+  const char *model_name, *id;
+  int64_t model_version;
+  RETURN_IF_ERR(TRITONSERVER_InferenceResponseModel(
+      iresponse, &model_name, &model_version));
+  RETURN_IF_ERR(TRITONSERVER_InferenceResponseId(iresponse, &id));
+
+  response.set_id(id);
+  response.set_model_name(model_name);
+  response.set_model_version(std::to_string(model_version));
+
+  // Propagate response parameters.
+  uint32_t parameter_count;
+  RETURN_IF_ERR(TRITONSERVER_InferenceResponseParameterCount(
+      iresponse, &parameter_count));
+  for (uint32_t pidx = 0; pidx < parameter_count; ++pidx) {
+    const char* name;
+    TRITONSERVER_ParameterType type;
+    const void* vvalue;
+    RETURN_IF_ERR(TRITONSERVER_InferenceResponseParameter(
+        iresponse, pidx, &name, &type, &vvalue));
+    inference::InferParameter& param = (*response.mutable_parameters())[name];
+    switch (type) {
+      case TRITONSERVER_PARAMETER_BOOL:
+        param.set_bool_param(*(reinterpret_cast<const bool*>(vvalue)));
+        break;
+      case TRITONSERVER_PARAMETER_INT:
+        param.set_int64_param(*(reinterpret_cast<const int64_t*>(vvalue)));
+        break;
+      case TRITONSERVER_PARAMETER_STRING:
+        param.set_string_param(reinterpret_cast<const char*>(vvalue));
+        break;
+      case TRITONSERVER_PARAMETER_DOUBLE:
+        param.set_double_param(*(reinterpret_cast<const double*>(vvalue)));
+        break;
+      case TRITONSERVER_PARAMETER_BYTES:
+        return TRITONSERVER_ErrorNew(
+            TRITONSERVER_ERROR_UNSUPPORTED,
+            "Response parameter of type 'TRITONSERVER_PARAMETER_BYTES' is not "
+            "currently supported");
+        break;
+    }
+  }
+
+  // Go through each response output and transfer information to the
+  // corresponding GRPC response output.
+  uint32_t output_count;
+  RETURN_IF_ERR(
+      TRITONSERVER_InferenceResponseOutputCount(iresponse, &output_count));
+  if (output_count != (uint32_t)response.outputs_size()) {
+    return TRITONSERVER_ErrorNew(
+        TRITONSERVER_ERROR_INTERNAL, "response output count mismatch");
+  }
+
+  for (uint32_t output_idx = 0; output_idx < output_count; ++output_idx) {
+    const char* cname;
+    TRITONSERVER_DataType datatype;
+    const int64_t* shape;
+    uint64_t dim_count;
+    const void* base;
+    size_t byte_size;
+    TRITONSERVER_MemoryType memory_type;
+    int64_t memory_type_id;
+    void* userp;
+
+    RETURN_IF_ERR(TRITONSERVER_InferenceResponseOutput(
+        iresponse, output_idx, &cname, &datatype, &shape, &dim_count, &base,
+        &byte_size, &memory_type, &memory_type_id, &userp));
+
+    const std::string name(cname);
+
+    // There are usually very few outputs so fastest just to look for
+    // the one we want... could create a map for cases where there are
+    // a large number of outputs. Or rely on order to be same...
+    inference::ModelInferResponse::InferOutputTensor* output = nullptr;
+    for (auto& io : *(response.mutable_outputs())) {
+      if (io.name() == name) {
+        output = &io;
+        break;
+      }
+    }
+
+    if (output == nullptr) {
+      return TRITONSERVER_ErrorNew(
+          TRITONSERVER_ERROR_INTERNAL,
+          "unable to find expected response output");
+    }
+
+    // If this output was requested as classification then remove the
+    // raw output from the response and instead return classification
+    // results as a string tensor
+    const auto itr = alloc_payload.classification_map_.find(name);
+    if (itr == alloc_payload.classification_map_.end()) {
+      // Not classification...
+      output->set_datatype(TRITONSERVER_DataTypeString(datatype));
+      for (size_t idx = 0; idx < dim_count; idx++) {
+        output->add_shape(shape[idx]);
+      }
+    } else {
+      // Classification
+      const uint32_t classification_count = itr->second;
+
+      // For classification need to determine the batch size, if any,
+      // because need to use that to break up the response for each
+      // batch entry.
+      uint32_t batch_size = 0;
+
+      uint32_t batch_flags;
+      RETURN_IF_ERR(TRITONSERVER_ServerModelBatchProperties(
+          server, model_name, model_version, &batch_flags,
+          nullptr /* voidp */));
+      if ((dim_count > 0) &&
+          ((batch_flags & TRITONSERVER_BATCH_FIRST_DIM) != 0)) {
+        batch_size = shape[0];
+      }
+
+      // Determine the batch1 byte size of the tensor... needed when
+      // the response tensor batch-size > 1 so that we know how to
+      // stride through the tensor data.
+      size_t batch1_element_count = 1;
+      for (size_t idx = ((batch_size == 0) ? 0 : 1); idx < dim_count; idx++) {
+        batch1_element_count *= shape[idx];
+      }
+
+      const size_t batch1_byte_size =
+          batch1_element_count * TRITONSERVER_DataTypeByteSize(datatype);
+
+      // Create the classification contents
+      std::string serialized;
+
+      size_t class_offset = 0;
+      for (uint32_t bs = 0; bs < std::max((uint32_t)1, batch_size); ++bs) {
+        std::vector<std::string> class_strs;
+        RETURN_IF_ERR(TopkClassifications(
+            iresponse, output_idx,
+            reinterpret_cast<const char*>(base) + class_offset,
+            ((class_offset + batch1_byte_size) > byte_size) ? 0
+                                                            : batch1_byte_size,
+            datatype, classification_count, &class_strs));
+
+        // Serialize for binary representation...
+        for (const auto& str : class_strs) {
+          uint32_t len = str.size();
+          serialized.append(reinterpret_cast<const char*>(&len), sizeof(len));
+          if (len > 0) {
+            serialized.append(str);
+          }
+        }
+
+        class_offset += batch1_byte_size;
+      }
+
+      // Update the output with new datatype, shape and contents.
+      output->set_datatype(
+          TRITONSERVER_DataTypeString(TRITONSERVER_TYPE_BYTES));
+
+      if (batch_size > 0) {
+        output->add_shape(batch_size);
+      }
+      output->add_shape(
+          std::min(classification_count, (uint32_t)batch1_element_count));
+
+      (*response.mutable_raw_output_contents())[output_idx] =
+          std::move(serialized);
+    }
+  }
+
+  // Make sure the response doesn't exceed GRPC limits.
+  if (response.ByteSizeLong() > MAX_GRPC_MESSAGE_SIZE) {
+    return TRITONSERVER_ErrorNew(
+        TRITONSERVER_ERROR_INVALID_ARG,
+        std::string(
+            "Response has byte size " +
+            std::to_string(response.ByteSizeLong()) +
+            " which exceeds gRPC's byte size limit " +
+            std::to_string(INT_MAX) + ".")
+            .c_str());
+  }
+
+  return nullptr;  // success
+}
+
 //
 // InferHandlerState
 //
@@ -1633,7 +1939,7 @@ class ModelInferCallbackHandler {
   std::shared_ptr<TRITONSERVER_Server> tritonserver_;
 
   // Request resources
-  AllocPayload<inference::ModelInferResponse> alloc_payload_;
+  AllocPayloadCallback<inference::ModelInferResponse> alloc_payload_;
   std::list<std::string> serialized_data_;
   std::vector<std::shared_ptr<const SharedMemoryManager::SharedMemoryInfo>>
       shm_regions_info_;
 
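The serialization loop above writes Triton's BYTES wire format: each element is a 4-byte length (the server's in-memory uint32 representation, little-endian on the usual platforms) followed by that many bytes. A standalone sketch of splitting such a buffer back into strings on the client side:

#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

std::vector<std::string> SplitBytesTensor(const std::string& raw)
{
  std::vector<std::string> elements;
  size_t offset = 0;
  while (offset + sizeof(uint32_t) <= raw.size()) {
    uint32_t len = 0;
    std::memcpy(&len, raw.data() + offset, sizeof(len));
    offset += sizeof(len);
    if (offset + len > raw.size()) {
      break;  // truncated or malformed entry; stop rather than overread
    }
    elements.emplace_back(raw.substr(offset, len));
    offset += len;
  }
  return elements;
}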
From bc214a24562f9feb682793f1233fb1f004de603a Mon Sep 17 00:00:00 2001
From: Indrajit Bhosale
Date: Fri, 11 Apr 2025 09:44:43 -0700
Subject: [PATCH 12/12] Cleanup ModelInfer

---
 src/grpc/infer_handler.cc | 5 -----
 src/grpc/infer_handler.h  | 4 ++++
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/grpc/infer_handler.cc b/src/grpc/infer_handler.cc
index 16e241634e..e44413fa39 100644
--- a/src/grpc/infer_handler.cc
+++ b/src/grpc/infer_handler.cc
@@ -938,9 +938,6 @@ ModelInferCallbackHandler::Execute(
         &callback_state->shm_regions_info_);
   }
 
-  // --- Step 4: Prepare for Response Handling (Callback Specific) ---
-  std::shared_ptr<ResponseQueue<inference::ModelInferResponse>> response_queue =
-      nullptr;
   if (err == nullptr) {
     // Use the externally provided response object directly.
     // Store the external response pointer in the state for later access.
@@ -953,8 +950,6 @@
   }
 
   // Prepare the allocator payload: info needed by allocation callback later.
-  // Moves serialized input data into the payload. References the
-  // response_queue.
   if (err == nullptr) {
     err = InferAllocatorPayloadCallback(
         tritonserver_, shm_manager_, *request,
diff --git a/src/grpc/infer_handler.h b/src/grpc/infer_handler.h
index 1b2da1caf1..dbcfb7dc53 100644
--- a/src/grpc/infer_handler.h
+++ b/src/grpc/infer_handler.h
@@ -811,6 +811,10 @@ InferResponseCompleteCommon(
   return nullptr;  // success
 }
 
+// Common function to populate the gRPC ModelInferResponse protobuf from the
+// TRITONSERVER_InferenceResponse C structure. Handles metadata, parameters,
+// output tensor data transfer, and classification formatting. Used by the
+// callback API path.
 template <typename ResponseType>
 TRITONSERVER_Error*
 InferResponseCompleteCommonCallback(
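Taken together, the series moves ModelInfer from the completion-queue service onto gRPC's callback API. Stripped of Triton specifics, the server-side shape is a CallbackService override that returns a reactor and finishes it when the work is done. A sketch under those assumptions (the generated-header path may vary, and real code would call Finish() from the inference completion callback rather than inline):

#include <grpcpp/grpcpp.h>

#include "grpc_service.grpc.pb.h"  // generated service stubs; path may vary

class SketchInferenceService final
    : public inference::GRPCInferenceService::CallbackService {
 public:
  ::grpc::ServerUnaryReactor* ModelInfer(
      ::grpc::CallbackServerContext* context,
      const inference::ModelInferRequest* request,
      inference::ModelInferResponse* response) override
  {
    auto* reactor = context->DefaultReactor();
    // Kick off asynchronous work that fills *response; whoever completes
    // that work must call reactor->Finish(status) exactly once. Finished
    // inline here only for brevity.
    reactor->Finish(::grpc::Status::OK);
    return reactor;
  }
};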