Skip to content

Commit 5db9037

Browse files
MasterSkepticistaishaileshpant
authored andcommitted
Update tests
Signed-off-by: Shah, Karan <kbshah1998@outlook.com>
1 parent 53dc8e5 commit 5db9037

File tree

1 file changed

+14
-96
lines changed

1 file changed

+14
-96
lines changed

tests/openfl/component/collaborator/test_collaborator.py

Lines changed: 14 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -97,15 +97,13 @@ def test_send_task_results(collaborator_mock, tensor_key):
9797
"""Test that send_task_results works correctly."""
9898
task_name = 'task_name'
9999
tensor_key = tensor_key._replace(report=True)
100-
tensor_dict = {tensor_key: 0}
100+
tensor_dict = {}
101101
round_number = 0
102102
data_size = -1
103-
collaborator_mock.nparray_to_named_tensor = mock.Mock(return_value=None)
104103
collaborator_mock.client.send_local_task_results = mock.Mock()
105104
collaborator_mock.send_task_results(tensor_dict, round_number, task_name)
106-
107105
collaborator_mock.client.send_local_task_results.assert_called_with(
108-
round_number, task_name, data_size, [None])
106+
round_number, task_name, data_size, [])
109107

110108

111109
def test_send_task_results_train(collaborator_mock):
@@ -138,95 +136,15 @@ def test_send_task_results_valid(collaborator_mock):
138136
round_number, task_name, data_size, [])
139137

140138

141-
def test_named_tensor_to_nparray_without_tags(collaborator_mock, named_tensor):
142-
"""Test that named_tensor_to_nparray works correctly for tensor without tags."""
143-
nparray = collaborator_mock.named_tensor_to_nparray(named_tensor)
144-
145-
assert named_tensor.data_bytes == nparray
146-
147-
148-
@pytest.mark.parametrize('tag', ['compressed', 'lossy_compressed'])
149-
def test_named_tensor_to_nparray_compressed_tag(collaborator_mock, named_tensor, tag):
150-
"""Test that named_tensor_to_nparray works correctly for tensor with tags."""
151-
named_tensor.tags.append(tag)
152-
nparray = collaborator_mock.named_tensor_to_nparray(named_tensor)
153-
154-
assert isinstance(nparray, numpy.ndarray)
155-
156-
157-
def test_nparray_to_named_tensor(collaborator_mock, tensor_key, named_tensor):
158-
"""Test that nparray_to_named_tensor works correctly."""
159-
named_tensor.tags.append('compressed')
160-
nparray = collaborator_mock.named_tensor_to_nparray(named_tensor)
161-
tensor = collaborator_mock.nparray_to_named_tensor(tensor_key, nparray)
162-
assert tensor.data_bytes == named_tensor.data_bytes
163-
assert tensor.lossless is True
164-
165-
166-
def test_nparray_to_named_tensor_trained(collaborator_mock, tensor_key_trained, named_tensor):
167-
"""Test that nparray_to_named_tensor works correctly for trained tensor."""
168-
named_tensor.tags.append('compressed')
169-
collaborator_mock.use_delta_updates = True
170-
nparray = collaborator_mock.named_tensor_to_nparray(named_tensor)
171-
collaborator_mock.tensor_db.get_tensor_from_cache = mock.Mock(
172-
return_value=nparray)
173-
tensor = collaborator_mock.nparray_to_named_tensor(tensor_key_trained, nparray)
174-
assert len(tensor.data_bytes) == 32
175-
assert tensor.lossless is False
176-
assert 'delta' in tensor.tags
177-
178-
179-
@pytest.mark.parametrize('require_lossless', [True, False])
180-
def test_get_aggregated_tensor_from_aggregator(collaborator_mock, tensor_key,
181-
named_tensor, require_lossless):
182-
"""Test that get_aggregated_tensor works correctly."""
183-
collaborator_mock.client.get_aggregated_tensor = mock.Mock(return_value=named_tensor)
184-
nparray = collaborator_mock.get_aggregated_tensor_from_aggregator(tensor_key, require_lossless)
185-
186-
collaborator_mock.client.get_aggregated_tensor.assert_called_with(
187-
tensor_key.tensor_name, tensor_key.round_number,
188-
tensor_key.report, tensor_key.tags, require_lossless)
189-
assert nparray == named_tensor.data_bytes
190-
191-
192-
def test_get_data_for_tensorkey_from_db(collaborator_mock, tensor_key):
193-
"""Test that get_data_for_tensorkey works correctly for data form db."""
194-
expected_nparray = 'some_data'
195-
collaborator_mock.tensor_db.get_tensor_from_cache = mock.Mock(
196-
return_value='some_data')
197-
nparray = collaborator_mock.get_data_for_tensorkey(tensor_key)
198-
199-
assert nparray == expected_nparray
200-
201-
202-
def test_get_data_for_tensorkey(collaborator_mock, tensor_key):
203-
"""Test that get_data_for_tensorkey works correctly if data is not in db."""
204-
collaborator_mock.tensor_db.get_tensor_from_cache = mock.Mock(
205-
return_value=None)
206-
collaborator_mock.get_aggregated_tensor_from_aggregator = mock.Mock()
207-
collaborator_mock.get_data_for_tensorkey(tensor_key)
208-
collaborator_mock.get_aggregated_tensor_from_aggregator.assert_called_with(
209-
tensor_key, require_lossless=True)
210-
211-
212-
def test_get_data_for_tensorkey_locally(collaborator_mock, tensor_key):
213-
"""Test that get_data_for_tensorkey works correctly if found tensor locally."""
214-
tensor_key = tensor_key._replace(round_number=1)
215-
nparray = numpy.array([0, 1, 2, 3, 4])
216-
collaborator_mock.tensor_db.get_tensor_from_cache = mock.Mock(
217-
side_effect=[None, nparray])
218-
ret = collaborator_mock.get_data_for_tensorkey(tensor_key)
219-
220-
assert numpy.array_equal(ret, nparray)
221-
222-
223-
def test_get_data_for_tensorkey_dependencies(collaborator_mock, tensor_key):
224-
"""Test that get_data_for_tensorkey works correctly if additional dependencies."""
225-
tensor_key = tensor_key._replace(round_number=1)
226-
collaborator_mock.tensor_db.get_tensor_from_cache = mock.Mock(
227-
return_value=None)
228-
collaborator_mock.tensor_codec.find_dependencies = mock.Mock(return_value=[tensor_key])
229-
collaborator_mock.get_aggregated_tensor_from_aggregator = mock.Mock()
230-
collaborator_mock.get_data_for_tensorkey(tensor_key)
231-
collaborator_mock.get_aggregated_tensor_from_aggregator.assert_called_with(
232-
tensor_key, require_lossless=True)
139+
def test_fetch_tensors_from_aggregator(collaborator_mock, tensor_key, named_tensor):
140+
"""Test that fetch_tensors_from_aggregator works correctly."""
141+
# Simulate tensor not in cache
142+
collaborator_mock.tensor_db.get_tensor_from_cache.return_value = None
143+
collaborator_mock.client.get_aggregated_tensors = mock.Mock(return_value=[named_tensor])
144+
collaborator_mock.tensor_db.cache_tensor = mock.Mock()
145+
# Patch utils.deserialize_tensor to avoid side effects
146+
with mock.patch("openfl.protocols.utils.deserialize_tensor", return_value=(tensor_key, "nparray")):
147+
collaborator_mock.fetch_tensors_from_aggregator([tensor_key])
148+
collaborator_mock.client.get_aggregated_tensors.assert_called_with(
149+
[tensor_key], require_lossless=True)
150+
collaborator_mock.tensor_db.cache_tensor.assert_called()

0 commit comments

Comments
 (0)