Skip to content

Commit 7fcecd3

Browse files
Enable rest protocol between Aggregator and Collaborator (#1500)
- add new AggregatorClientInterface to allow switching b/w grpc and rest - endhance existing AggregatorGRPCClient to start using AggregatorClientInterface - added new transport package for rest with AggregatorRESTClient implementing AggregatorClientInterface - added streaming api support with custom content-type - added various connection flag for streaming request - send additional header key "Sender" for better request logging at server side - aligned Rest and gRPC client for most of the init params - added AggregatorRESTServer and necesary changes in aggregator cli and federated/plan get_server method - added transport_protocol settings in defaults/network.yaml, defaulted the same to 'grpc' - reduced cyclomatic complexity of Rest Server - fixed protobuf streaming issue for v1/task/results API - added more detailed logging for task progression and metadata for each api calls - pinned Flask version to latest stable 3.1.0 - addressing review comments - 13th-May - added ping api and `collaborato` constructor hint for `AggregatorClientInterface` - added send_message_to_server in client and AggregatorClientInterface, Rest Server is already at parity - changed base uri for REST server to 'experimental/v1', adjusted the client and tests accordingly - fixed issue related to mTLS in REST server/client - disabled TLS 1.2 in both server/client rebased 21st.May.1 Signed-off-by: Shailesh Pant <shailesh.pant@intel.com>
1 parent 901e962 commit 7fcecd3

File tree

16 files changed

+2081
-30
lines changed

16 files changed

+2081
-30
lines changed

openfl-workspace/workspace/plan/defaults/network.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ settings:
77
client_reconnect_interval : 5
88
require_client_auth : True
99
cert_folder : cert
10-
enable_atomic_connections : False
10+
enable_atomic_connections : False
11+
transport_protocol : grpc

openfl/component/collaborator/collaborator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from openfl.databases import TensorDB
1515
from openfl.pipelines import NoCompressionPipeline, TensorCodec
1616
from openfl.protocols import utils
17-
from openfl.transport.grpc.aggregator_client import AggregatorGRPCClient
17+
from openfl.transport.grpc.aggregator_client import AggregatorClientInterface
1818
from openfl.utilities import TensorKey
1919

2020
logger = logging.getLogger(__name__)
@@ -64,7 +64,7 @@ def __init__(
6464
collaborator_name,
6565
aggregator_uuid,
6666
federation_uuid,
67-
client: AggregatorGRPCClient,
67+
client: AggregatorClientInterface,
6868
task_runner,
6969
task_config,
7070
opt_treatment="RESET",

openfl/federated/plan/plan.py

Lines changed: 65 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,12 @@
1515

1616
from openfl.interface.aggregation_functions import AggregationFunction, WeightedAverage
1717
from openfl.interface.cli_helper import WORKSPACE
18-
from openfl.transport import AggregatorGRPCClient, AggregatorGRPCServer
18+
from openfl.transport import (
19+
AggregatorGRPCClient,
20+
AggregatorGRPCServer,
21+
AggregatorRESTClient,
22+
AggregatorRESTServer,
23+
)
1924
from openfl.utilities.utils import getfqdn_env
2025

2126
SETTINGS = "settings"
@@ -542,8 +547,6 @@ def get_collaborator(
542547
else:
543548
defaults[SETTINGS]["client"] = self.get_client(
544549
collaborator_name,
545-
self.aggregator_uuid,
546-
self.federation_uuid,
547550
root_certificate,
548551
private_key,
549552
certificate,
@@ -557,13 +560,11 @@ def get_collaborator(
557560
def get_client(
558561
self,
559562
collaborator_name,
560-
aggregator_uuid,
561-
federation_uuid,
562563
root_certificate=None,
563564
private_key=None,
564565
certificate=None,
565566
):
566-
"""Get gRPC client for the specified collaborator.
567+
"""Get gRPC or REST client for the specified collaborator.
567568
568569
Args:
569570
collaborator_name (str): Name of the collaborator.
@@ -577,8 +578,38 @@ def get_client(
577578
Defaults to None.
578579
579580
Returns:
580-
AggregatorGRPCClient: gRPC client for the specified collaborator.
581+
AggregatorGRPCClient or AggregatorRESTClient: gRPC or REST client for the collaborator.
581582
"""
583+
client_args = self.get_client_args(
584+
collaborator_name,
585+
root_certificate,
586+
private_key,
587+
certificate,
588+
)
589+
network_cfg = self.config["network"][SETTINGS]
590+
protocol = network_cfg.get("transport_protocol", "grpc").lower()
591+
592+
if self.client_ is None:
593+
self.client_ = self._get_client(protocol, **client_args)
594+
595+
return self.client_
596+
597+
def _get_client(self, protocol, **kwargs):
598+
if protocol == "rest":
599+
client = AggregatorRESTClient(**kwargs)
600+
elif protocol == "grpc":
601+
client = AggregatorGRPCClient(**kwargs)
602+
else:
603+
raise ValueError(f"Unsupported transport_protocol '{protocol}'")
604+
return client
605+
606+
def get_client_args(
607+
self,
608+
collaborator_name,
609+
root_certificate=None,
610+
private_key=None,
611+
certificate=None,
612+
):
582613
common_name = collaborator_name
583614
if not root_certificate or not private_key or not certificate:
584615
root_certificate = "cert/cert_chain.crt"
@@ -593,14 +624,10 @@ def get_client(
593624
client_args["certificate"] = certificate
594625
client_args["private_key"] = private_key
595626

596-
client_args["aggregator_uuid"] = aggregator_uuid
597-
client_args["federation_uuid"] = federation_uuid
627+
client_args["aggregator_uuid"] = self.aggregator_uuid
628+
client_args["federation_uuid"] = self.federation_uuid
598629
client_args["collaborator_name"] = collaborator_name
599-
600-
if self.client_ is None:
601-
self.client_ = AggregatorGRPCClient(**client_args)
602-
603-
return self.client_
630+
return client_args
604631

605632
def get_server(
606633
self,
@@ -609,7 +636,7 @@ def get_server(
609636
certificate=None,
610637
**kwargs,
611638
):
612-
"""Get gRPC server of the aggregator instance.
639+
"""Get gRPC or REST server of the aggregator instance.
613640
614641
Args:
615642
root_certificate (str, optional): Root certificate for the server.
@@ -621,8 +648,29 @@ def get_server(
621648
**kwargs: Additional keyword arguments.
622649
623650
Returns:
624-
AggregatorGRPCServer: gRPC server of the aggregator instance.
651+
Aggregator Server: returns either gRPC or REST server of the aggregator instance.
625652
"""
653+
server_args = self.get_server_args(root_certificate, private_key, certificate, kwargs)
654+
655+
server_args["aggregator"] = self.get_aggregator()
656+
network_cfg = self.config["network"][SETTINGS]
657+
protocol = network_cfg.get("transport_protocol", "grpc").lower()
658+
659+
if self.server_ is None:
660+
self.server_ = self._get_server(protocol, **server_args)
661+
662+
return self.server_
663+
664+
def _get_server(self, protocol, **kwargs):
665+
if protocol == "rest":
666+
server = AggregatorRESTServer(**kwargs)
667+
elif protocol == "grpc":
668+
server = AggregatorGRPCServer(**kwargs)
669+
else:
670+
raise ValueError(f"Unsupported transport_protocol '{protocol}'")
671+
return server
672+
673+
def get_server_args(self, root_certificate, private_key, certificate, kwargs):
626674
common_name = self.config["network"][SETTINGS]["agg_addr"].lower()
627675

628676
if not root_certificate or not private_key or not certificate:
@@ -638,13 +686,7 @@ def get_server(
638686
server_args["root_certificate"] = root_certificate
639687
server_args["certificate"] = certificate
640688
server_args["private_key"] = private_key
641-
642-
server_args["aggregator"] = self.get_aggregator()
643-
644-
if self.server_ is None:
645-
self.server_ = AggregatorGRPCServer(**server_args)
646-
647-
return self.server_
689+
return server_args
648690

649691
def save_model_to_state_file(self, tensor_dict, round_number, output_path):
650692
"""Save model weights to a protobuf state file.

openfl/interface/aggregator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ def start_(plan, authorized_cols, task_group):
9292
logger.info(f"Setting aggregator to assign: {task_group} task_group")
9393

9494
logger.info("🧿 Starting the Aggregator Service.")
95-
96-
parsed_plan.get_server().serve()
95+
server = parsed_plan.get_server()
96+
server.serve()
9797

9898

9999
@aggregator.command(name="generate-cert-request")
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright 2025 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""AggregatorClientInterface module."""
5+
6+
from abc import ABC, abstractmethod
7+
from typing import Any, List, Tuple
8+
9+
10+
class AggregatorClientInterface(ABC):
11+
@abstractmethod
12+
def ping(self):
13+
"""
14+
Ping the aggregator to check connectivity.
15+
"""
16+
pass
17+
18+
@abstractmethod
19+
def get_tasks(self) -> Tuple[List[Any], int, int, bool]:
20+
"""
21+
Retrieves tasks for the given collaborator client.
22+
Returns a tuple: (tasks, round_number, sleep_time, time_to_quit)
23+
"""
24+
pass
25+
26+
@abstractmethod
27+
def get_aggregated_tensor(
28+
self,
29+
tensor_name: str,
30+
round_number: int,
31+
report: bool,
32+
tags: List[str],
33+
require_lossless: bool,
34+
) -> Any:
35+
"""
36+
Retrieves the aggregated tensor.
37+
"""
38+
pass
39+
40+
@abstractmethod
41+
def send_local_task_results(
42+
self,
43+
round_number: int,
44+
task_name: str,
45+
data_size: int,
46+
named_tensors: List[Any],
47+
) -> Any:
48+
"""
49+
Sends local task results.
50+
Parameters:
51+
collaborator_name: Name of the collaborator.
52+
round_number: The current round.
53+
task_name: Name of the task.
54+
data_size: Size of the data.
55+
named_tensors: A list of tensors (or named tensor objects).
56+
Returns a SendLocalTaskResultsResponse.
57+
"""
58+
pass
59+
60+
@abstractmethod
61+
def send_message_to_server(self, openfl_message: Any, collaborator_name: str) -> Any:
62+
"""
63+
Forwards a converted message from the local client to the OpenFL server and returns the
64+
response.
65+
Args:
66+
openfl_message: The converted message to be sent to the OpenFL server (InteropMessage
67+
proto).
68+
collaborator_name: The name of the collaborator.
69+
Returns:
70+
The response from the OpenFL server (InteropMessage proto).
71+
"""
72+
pass

openfl/transport/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
# Copyright 2020-2024 Intel Corporation
1+
# Copyright 2025 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
33

44

55
from openfl.transport.grpc import AggregatorGRPCClient, AggregatorGRPCServer
6+
from openfl.transport.rest import AggregatorRESTClient, AggregatorRESTServer

openfl/transport/grpc/aggregator_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import grpc
1212

1313
from openfl.protocols import aggregator_pb2, aggregator_pb2_grpc, utils
14+
from openfl.protocols.aggregator_client_interface import AggregatorClientInterface
1415
from openfl.transport.grpc.common import create_header, create_insecure_channel, create_tls_channel
1516

1617
logger = logging.getLogger(__name__)
@@ -165,9 +166,11 @@ def wrapper(self, *args, **kwargs):
165166
return wrapper
166167

167168

168-
class AggregatorGRPCClient:
169+
class AggregatorGRPCClient(AggregatorClientInterface):
169170
"""Collaborator-side gRPC client that talks to the aggregator.
170171
172+
This class implements a gRPC client for communicating with an aggregator.
173+
171174
Attributes:
172175
agg_addr (str): Aggregator address.
173176
agg_port (int): Aggregator port.

openfl/transport/rest/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Copyright 2020-2025 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
5+
from openfl.transport.rest.aggregator_client import AggregatorRESTClient
6+
from openfl.transport.rest.aggregator_server import AggregatorRESTServer

0 commit comments

Comments
 (0)