|
16 | 16 | from openfl.databases import PersistentTensorDB, TensorDB |
17 | 17 | from openfl.interface.aggregation_functions import SecureWeightedAverage, WeightedAverage |
18 | 18 | from openfl.pipelines import NoCompressionPipeline, TensorCodec |
19 | | -from openfl.protocols import base_pb2, utils |
| 19 | +from openfl.protocols import utils |
20 | 20 | from openfl.protocols.base_pb2 import NamedTensor |
21 | 21 | from openfl.utilities import TaskResultKey, TensorKey, change_tags |
22 | 22 |
|
@@ -83,39 +83,13 @@ def __init__( |
83 | 83 | single_col_cert_common_name=None, |
84 | 84 | compression_pipeline=None, |
85 | 85 | db_store_rounds=1, |
86 | | - initial_tensor_dict=None, |
87 | 86 | log_memory_usage=False, |
88 | 87 | write_logs=False, |
89 | 88 | callbacks: Optional[List] = [], |
90 | 89 | persist_checkpoint=True, |
91 | 90 | persistent_db_path=None, |
92 | 91 | secure_aggregation=False, |
93 | 92 | ): |
94 | | - """Initializes the Aggregator. |
95 | | -
|
96 | | - Args: |
97 | | - aggregator_uuid (int): Aggregation ID. |
98 | | - federation_uuid (str): Federation ID. |
99 | | - authorized_cols (list of str): The list of IDs of enrolled |
100 | | - collaborators. |
101 | | - init_state_path (str): The location of the initial weight file. |
102 | | - best_state_path (str): The file location to store the weight of |
103 | | - the best model. |
104 | | - last_state_path (str): The file location to store the latest |
105 | | - weight. |
106 | | - assigner: Assigner object. |
107 | | - straggler_handling_policy (optional): Straggler handling policy. |
108 | | - rounds_to_train (int, optional): Number of rounds to train. |
109 | | - Defaults to 256. |
110 | | - single_col_cert_common_name (str, optional): Common name for single |
111 | | - collaborator certificate. Defaults to None. |
112 | | - compression_pipeline (optional): Compression pipeline. Defaults to |
113 | | - NoCompressionPipeline. |
114 | | - db_store_rounds (int, optional): Rounds to store in TensorDB. |
115 | | - Defaults to 1. |
116 | | - initial_tensor_dict (dict, optional): Initial tensor dictionary. |
117 | | - callbacks: List of callbacks to be used during the experiment. |
118 | | - """ |
119 | 93 | self.round_number = 0 |
120 | 94 | self.next_model_round_number = 0 |
121 | 95 |
|
@@ -205,20 +179,10 @@ def __init__( |
205 | 179 | last_state_path=self.last_state_path, |
206 | 180 | ) |
207 | 181 |
|
208 | | - if initial_tensor_dict: |
209 | | - self._load_initial_tensors_from_dict(initial_tensor_dict) |
210 | | - self.model = utils.construct_model_proto( |
211 | | - tensor_dict=initial_tensor_dict, |
212 | | - round_number=0, |
213 | | - tensor_pipe=self.compression_pipeline, |
214 | | - ) |
215 | | - else: |
216 | | - if self.connector: |
217 | | - # The model definition will be handled by the respective framework |
218 | | - self.model = {} |
219 | | - else: |
220 | | - self.model: base_pb2.ModelProto = utils.load_proto(self.init_state_path) |
221 | | - self._load_initial_tensors() # keys are TensorKeys |
| 182 | + self.model = {} |
| 183 | + if not self.connector: |
| 184 | + self.model = utils.load_proto(self.init_state_path) |
| 185 | + self._load_initial_tensors() # keys are TensorKeys |
222 | 186 |
|
223 | 187 | self._secure_aggregation_enabled = secure_aggregation |
224 | 188 | if self._secure_aggregation_enabled: |
@@ -337,23 +301,6 @@ def _load_initial_tensors(self): |
337 | 301 | self.tensor_db.cache_tensor(tensor_key_dict) |
338 | 302 | logger.debug("This is the initial tensor_db: %s", self.tensor_db) |
339 | 303 |
|
340 | | - def _load_initial_tensors_from_dict(self, tensor_dict): |
341 | | - """Load all of the tensors required to begin federated learning. |
342 | | -
|
343 | | - Required tensors are: \ |
344 | | - 1. Initial model. |
345 | | -
|
346 | | - Returns: |
347 | | - None |
348 | | - """ |
349 | | - tensor_key_dict = { |
350 | | - TensorKey(k, self.uuid, self.round_number, False, ("model",)): v |
351 | | - for k, v in tensor_dict.items() |
352 | | - } |
353 | | - # all initial model tensors are loaded here |
354 | | - self.tensor_db.cache_tensor(tensor_key_dict) |
355 | | - logger.debug("This is the initial tensor_db: %s", self.tensor_db) |
356 | | - |
357 | 304 | def _save_model(self, round_number, file_path): |
358 | 305 | """Save the best or latest model. |
359 | 306 |
|
@@ -485,10 +432,7 @@ def get_tasks(self, collaborator_name): |
485 | 432 |
|
486 | 433 | # first, if it is time to quit, inform the collaborator |
487 | 434 | if self._time_to_quit(): |
488 | | - logger.info( |
489 | | - "Sending signal to collaborator %s to shutdown...", |
490 | | - collaborator_name, |
491 | | - ) |
| 435 | + logger.info("Sending signal to collaborator %s to shutdown...", collaborator_name) |
492 | 436 | self.quit_job_sent_to.append(collaborator_name) |
493 | 437 |
|
494 | 438 | tasks = None |
@@ -853,7 +797,6 @@ def _compute_validation_related_task_metrics(self, task_name) -> dict: |
853 | 797 | ) |
854 | 798 | # Leave out straggler for the round even if they've partially |
855 | 799 | # completed given tasks |
856 | | - collaborators_for_task = [] |
857 | 800 | collaborators_for_task = [ |
858 | 801 | c for c in all_collaborators_for_task if c in self.collaborators_done |
859 | 802 | ] |
@@ -925,11 +868,7 @@ def _compute_validation_related_task_metrics(self, task_name) -> dict: |
925 | 868 | f"{agg_results:f}" |
926 | 869 | ) |
927 | 870 | self._save_model(round_number, self.best_state_path) |
928 | | - else: |
929 | | - logger.info( |
930 | | - f"Round {round_number}: best score observed {agg_results:f} " |
931 | | - "(model not saved in evaluation mode)" |
932 | | - ) |
| 871 | + |
933 | 872 | if "trained" in tags: |
934 | 873 | self._prepare_trained(tensor_name, origin, round_number, report, agg_results) |
935 | 874 |
|
|
0 commit comments