@@ -97,15 +97,13 @@ def test_send_task_results(collaborator_mock, tensor_key):
9797 """Test that send_task_results works correctly."""
9898 task_name = 'task_name'
9999 tensor_key = tensor_key ._replace (report = True )
100- tensor_dict = {tensor_key : 0 }
100+ tensor_dict = {}
101101 round_number = 0
102102 data_size = - 1
103- collaborator_mock .nparray_to_named_tensor = mock .Mock (return_value = None )
104103 collaborator_mock .client .send_local_task_results = mock .Mock ()
105104 collaborator_mock .send_task_results (tensor_dict , round_number , task_name )
106-
107105 collaborator_mock .client .send_local_task_results .assert_called_with (
108- round_number , task_name , data_size , [None ])
106+ round_number , task_name , data_size , [])
109107
110108
111109def test_send_task_results_train (collaborator_mock ):
@@ -138,95 +136,15 @@ def test_send_task_results_valid(collaborator_mock):
138136 round_number , task_name , data_size , [])
139137
140138
141- def test_named_tensor_to_nparray_without_tags (collaborator_mock , named_tensor ):
142- """Test that named_tensor_to_nparray works correctly for tensor without tags."""
143- nparray = collaborator_mock .named_tensor_to_nparray (named_tensor )
144-
145- assert named_tensor .data_bytes == nparray
146-
147-
148- @pytest .mark .parametrize ('tag' , ['compressed' , 'lossy_compressed' ])
149- def test_named_tensor_to_nparray_compressed_tag (collaborator_mock , named_tensor , tag ):
150- """Test that named_tensor_to_nparray works correctly for tensor with tags."""
151- named_tensor .tags .append (tag )
152- nparray = collaborator_mock .named_tensor_to_nparray (named_tensor )
153-
154- assert isinstance (nparray , numpy .ndarray )
155-
156-
157- def test_nparray_to_named_tensor (collaborator_mock , tensor_key , named_tensor ):
158- """Test that nparray_to_named_tensor works correctly."""
159- named_tensor .tags .append ('compressed' )
160- nparray = collaborator_mock .named_tensor_to_nparray (named_tensor )
161- tensor = collaborator_mock .nparray_to_named_tensor (tensor_key , nparray )
162- assert tensor .data_bytes == named_tensor .data_bytes
163- assert tensor .lossless is True
164-
165-
166- def test_nparray_to_named_tensor_trained (collaborator_mock , tensor_key_trained , named_tensor ):
167- """Test that nparray_to_named_tensor works correctly for trained tensor."""
168- named_tensor .tags .append ('compressed' )
169- collaborator_mock .use_delta_updates = True
170- nparray = collaborator_mock .named_tensor_to_nparray (named_tensor )
171- collaborator_mock .tensor_db .get_tensor_from_cache = mock .Mock (
172- return_value = nparray )
173- tensor = collaborator_mock .nparray_to_named_tensor (tensor_key_trained , nparray )
174- assert len (tensor .data_bytes ) == 32
175- assert tensor .lossless is False
176- assert 'delta' in tensor .tags
177-
178-
179- @pytest .mark .parametrize ('require_lossless' , [True , False ])
180- def test_get_aggregated_tensor_from_aggregator (collaborator_mock , tensor_key ,
181- named_tensor , require_lossless ):
182- """Test that get_aggregated_tensor works correctly."""
183- collaborator_mock .client .get_aggregated_tensor = mock .Mock (return_value = named_tensor )
184- nparray = collaborator_mock .get_aggregated_tensor_from_aggregator (tensor_key , require_lossless )
185-
186- collaborator_mock .client .get_aggregated_tensor .assert_called_with (
187- tensor_key .tensor_name , tensor_key .round_number ,
188- tensor_key .report , tensor_key .tags , require_lossless )
189- assert nparray == named_tensor .data_bytes
190-
191-
192- def test_get_data_for_tensorkey_from_db (collaborator_mock , tensor_key ):
193- """Test that get_data_for_tensorkey works correctly for data form db."""
194- expected_nparray = 'some_data'
195- collaborator_mock .tensor_db .get_tensor_from_cache = mock .Mock (
196- return_value = 'some_data' )
197- nparray = collaborator_mock .get_data_for_tensorkey (tensor_key )
198-
199- assert nparray == expected_nparray
200-
201-
202- def test_get_data_for_tensorkey (collaborator_mock , tensor_key ):
203- """Test that get_data_for_tensorkey works correctly if data is not in db."""
204- collaborator_mock .tensor_db .get_tensor_from_cache = mock .Mock (
205- return_value = None )
206- collaborator_mock .get_aggregated_tensor_from_aggregator = mock .Mock ()
207- collaborator_mock .get_data_for_tensorkey (tensor_key )
208- collaborator_mock .get_aggregated_tensor_from_aggregator .assert_called_with (
209- tensor_key , require_lossless = True )
210-
211-
212- def test_get_data_for_tensorkey_locally (collaborator_mock , tensor_key ):
213- """Test that get_data_for_tensorkey works correctly if found tensor locally."""
214- tensor_key = tensor_key ._replace (round_number = 1 )
215- nparray = numpy .array ([0 , 1 , 2 , 3 , 4 ])
216- collaborator_mock .tensor_db .get_tensor_from_cache = mock .Mock (
217- side_effect = [None , nparray ])
218- ret = collaborator_mock .get_data_for_tensorkey (tensor_key )
219-
220- assert numpy .array_equal (ret , nparray )
221-
222-
223- def test_get_data_for_tensorkey_dependencies (collaborator_mock , tensor_key ):
224- """Test that get_data_for_tensorkey works correctly if additional dependencies."""
225- tensor_key = tensor_key ._replace (round_number = 1 )
226- collaborator_mock .tensor_db .get_tensor_from_cache = mock .Mock (
227- return_value = None )
228- collaborator_mock .tensor_codec .find_dependencies = mock .Mock (return_value = [tensor_key ])
229- collaborator_mock .get_aggregated_tensor_from_aggregator = mock .Mock ()
230- collaborator_mock .get_data_for_tensorkey (tensor_key )
231- collaborator_mock .get_aggregated_tensor_from_aggregator .assert_called_with (
232- tensor_key , require_lossless = True )
139+ def test_fetch_tensors_from_aggregator (collaborator_mock , tensor_key , named_tensor ):
140+ """Test that fetch_tensors_from_aggregator works correctly."""
141+ # Simulate tensor not in cache
142+ collaborator_mock .tensor_db .get_tensor_from_cache .return_value = None
143+ collaborator_mock .client .get_aggregated_tensors = mock .Mock (return_value = [named_tensor ])
144+ collaborator_mock .tensor_db .cache_tensor = mock .Mock ()
145+ # Patch utils.deserialize_tensor to avoid side effects
146+ with mock .patch ("openfl.protocols.utils.deserialize_tensor" , return_value = (tensor_key , "nparray" )):
147+ collaborator_mock .fetch_tensors_from_aggregator ([tensor_key ])
148+ collaborator_mock .client .get_aggregated_tensors .assert_called_with (
149+ [tensor_key ], require_lossless = True )
150+ collaborator_mock .tensor_db .cache_tensor .assert_called ()
0 commit comments