@@ -106,6 +106,7 @@ enum rpc_cmd {
106106 RPC_CMD_GET_ALLOC_SIZE,
107107 RPC_CMD_HELLO,
108108 RPC_CMD_DEVICE_COUNT,
109+ RPC_CMD_GRAPH_RECOMPUTE,
109110 RPC_CMD_COUNT,
110111};
111112
@@ -205,10 +206,6 @@ struct rpc_msg_copy_tensor_rsp {
205206 uint8_t result;
206207};
207208
208- struct rpc_msg_graph_compute_rsp {
209- uint8_t result;
210- };
211-
// request payload for RPC_CMD_GET_DEVICE_MEMORY
struct rpc_msg_get_device_memory_req {
    uint32_t device; // remote device index to query
};
@@ -217,6 +214,11 @@ struct rpc_msg_get_device_memory_rsp {
217214 uint64_t free_mem;
218215 uint64_t total_mem;
219216};
217+
// request payload for RPC_CMD_GRAPH_RECOMPUTE: asks the server to re-run the
// graph it stored for this device on the last RPC_CMD_GRAPH_COMPUTE
// (no graph data is re-sent over the wire)
struct rpc_msg_graph_recompute_req {
    uint32_t device; // remote device whose stored graph should be recomputed
};
221+
220222#pragma pack(pop)
221223
222224// RPC data structures
@@ -234,10 +236,35 @@ struct ggml_backend_rpc_buffer_type_context {
234236 size_t max_size;
235237};
236238
239+ struct graph_cache {
240+
241+ bool is_cached (const ggml_cgraph * cgraph) {
242+ if ((int )last_graph.size () != cgraph->n_nodes ) {
243+ return false ;
244+ }
245+ for (int i = 0 ; i < cgraph->n_nodes ; i++) {
246+ if (memcmp (&last_graph[i], cgraph->nodes [i], sizeof (ggml_tensor)) != 0 ) {
247+ return false ;
248+ }
249+ }
250+ return true ;
251+ }
252+
253+ void add (const ggml_cgraph * cgraph) {
254+ last_graph.resize (cgraph->n_nodes );
255+ for (int i = 0 ; i < cgraph->n_nodes ; i++) {
256+ memcpy (&last_graph[i], cgraph->nodes [i], sizeof (ggml_tensor));
257+ }
258+ }
259+
260+ std::vector<ggml_tensor> last_graph;
261+ };
262+
// Client-side state for one RPC backend instance.
struct ggml_backend_rpc_context {
    std::string endpoint; // server address, used to look up the socket
    uint32_t device;      // remote device index this backend targets
    std::string name;     // backend name reported to callers
    graph_cache gc;       // last graph sent; enables RPC_CMD_GRAPH_RECOMPUTE
};
242269
243270struct ggml_backend_rpc_buffer_context {
@@ -815,13 +842,24 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve
815842
816843static enum ggml_status ggml_backend_rpc_graph_compute (ggml_backend_t backend, ggml_cgraph * cgraph) {
817844 ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context ;
818- std::vector<uint8_t > input;
819- serialize_graph (rpc_ctx->device , cgraph, input);
820- rpc_msg_graph_compute_rsp response;
821- auto sock = get_socket (rpc_ctx->endpoint );
822- bool status = send_rpc_cmd (sock, RPC_CMD_GRAPH_COMPUTE, input.data (), input.size (), &response, sizeof (response));
823- RPC_STATUS_ASSERT (status);
824- return (enum ggml_status)response.result ;
845+
846+ GGML_ASSERT (cgraph->n_nodes > 0 );
847+ bool reuse = rpc_ctx->gc .is_cached (cgraph);
848+ if (reuse) {
849+ rpc_msg_graph_recompute_req request;
850+ request.device = rpc_ctx->device ;
851+ auto sock = get_socket (rpc_ctx->endpoint );
852+ bool status = send_rpc_cmd (sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof (request));
853+ RPC_STATUS_ASSERT (status);
854+ } else {
855+ rpc_ctx->gc .add (cgraph);
856+ std::vector<uint8_t > input;
857+ serialize_graph (rpc_ctx->device , cgraph, input);
858+ auto sock = get_socket (rpc_ctx->endpoint );
859+ bool status = send_rpc_cmd (sock, RPC_CMD_GRAPH_COMPUTE, input.data (), input.size ());
860+ RPC_STATUS_ASSERT (status);
861+ }
862+ return GGML_STATUS_SUCCESS;
825863}
826864
827865static ggml_backend_i ggml_backend_rpc_interface = {
@@ -880,7 +918,8 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
880918 ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
881919 /* .endpoint = */ endpoint,
882920 /* .device = */ device,
883- /* .name = */ dev_name
921+ /* .name = */ dev_name,
922+ /* .gc = */ {},
884923 };
885924 auto reg = ggml_backend_rpc_add_server (endpoint);
886925 ggml_backend_t backend = new ggml_backend {
@@ -920,8 +959,9 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device,
920959
921960class rpc_server {
922961public:
923- rpc_server (std::vector<ggml_backend_t > backends, const char * cache_dir)
924- : backends(std::move(backends)), cache_dir(cache_dir) {
962+ rpc_server (std::vector<ggml_backend_t > all_backends, const char * cache_dir)
963+ : backends(std::move(all_backends)), cache_dir(cache_dir) {
964+ stored_graphs.resize (backends.size ());
925965 }
926966 ~rpc_server ();
927967
@@ -936,11 +976,17 @@ class rpc_server {
936976 bool set_tensor_hash (const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
937977 bool get_tensor (const rpc_msg_get_tensor_req & request, std::vector<uint8_t > & response);
938978 bool copy_tensor (const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
939- bool graph_compute (const std::vector<uint8_t > & input, rpc_msg_graph_compute_rsp & response);
979+ bool graph_compute (const std::vector<uint8_t > & input);
980+ bool graph_recompute (const rpc_msg_graph_recompute_req & request);
940981 bool init_tensor (const rpc_msg_init_tensor_req & request);
941982 bool get_alloc_size (const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
942983 bool get_device_memory (const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
943984
    // The last graph computed on one backend, kept so graph_recompute() can
    // re-run it without a new serialization from the client.
    struct stored_graph {
        ggml_context_ptr ctx_ptr; // keeps the deserialized tensors/graph alive
                                  // (graph presumably allocated from this
                                  // context — confirm in graph_compute)
        ggml_cgraph * graph;      // borrowed; nullptr until the first
                                  // successful graph_compute for this device
    };
944990private:
945991 bool get_cached_file (uint64_t hash, std::vector<uint8_t > & data);
946992 ggml_tensor * deserialize_tensor (struct ggml_context * ctx, const rpc_tensor * tensor);
@@ -953,6 +999,8 @@ class rpc_server {
953999 std::vector<ggml_backend_t > backends;
9541000 const char * cache_dir;
9551001 std::unordered_set<ggml_backend_buffer_t > buffers;
1002+ // store the last computed graph for each backend
1003+ std::vector<stored_graph> stored_graphs;
9561004};
9571005
9581006void rpc_server::hello (rpc_msg_hello_rsp & response) {
@@ -1394,7 +1442,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
13941442 return result;
13951443}
13961444
1397- bool rpc_server::graph_compute (const std::vector<uint8_t > & input, rpc_msg_graph_compute_rsp & response ) {
1445+ bool rpc_server::graph_compute (const std::vector<uint8_t > & input) {
13981446 // serialization format:
13991447 // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
14001448 if (input.size () < 2 *sizeof (uint32_t )) {
@@ -1455,7 +1503,24 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
14551503 }
14561504 }
14571505 ggml_status status = ggml_backend_graph_compute (backends[device], graph);
1458- response.result = status;
1506+ GGML_ASSERT (status == GGML_STATUS_SUCCESS && " Unsuccessful graph computations are not supported with RPC" );
1507+ stored_graphs[device].ctx_ptr .swap (ctx_ptr);
1508+ stored_graphs[device].graph = graph;
1509+ return true ;
1510+ }
1511+
1512+ bool rpc_server::graph_recompute (const rpc_msg_graph_recompute_req & request) {
1513+ uint32_t device = request.device ;
1514+ if (device >= backends.size ()) {
1515+ return false ;
1516+ }
1517+ if (stored_graphs[device].graph == nullptr ) {
1518+ return false ;
1519+ }
1520+ ggml_cgraph * graph = stored_graphs[device].graph ;
1521+ LOG_DBG (" [%s] device: %u\n " , __func__, device);
1522+ ggml_status status = ggml_backend_graph_compute (backends[device], graph);
1523+ GGML_ASSERT (status == GGML_STATUS_SUCCESS && " Unsuccessful graph computations are not supported with RPC" );
14591524 return true ;
14601525}
14611526
@@ -1690,11 +1755,17 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
16901755 if (!recv_msg (sockfd, input)) {
16911756 return ;
16921757 }
1693- rpc_msg_graph_compute_rsp response;
1694- if (!server.graph_compute (input, response)) {
1758+ if (!server.graph_compute (input)) {
16951759 return ;
16961760 }
1697- if (!send_msg (sockfd, &response, sizeof (response))) {
1761+ break ;
1762+ }
1763+ case RPC_CMD_GRAPH_RECOMPUTE: {
1764+ rpc_msg_graph_recompute_req request;
1765+ if (!recv_msg (sockfd, &request, sizeof (request))) {
1766+ return ;
1767+ }
1768+ if (!server.graph_recompute (request)) {
16981769 return ;
16991770 }
17001771 break ;
0 commit comments