We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 3f561aa commit 92c6d2bCopy full SHA for 92c6d2b
golf_federated/client/process/config/model/torchmodel.py
@@ -201,6 +201,8 @@ def train(self) -> None:
201
Model training.
202
203
"""
204
+
205
+ self.global_net = deepcopy(self.model)
206
207
self.loss = self.loss.to(self.process_unit)
208
self.model = self.model.to(self.process_unit)
@@ -246,7 +248,6 @@ def train(self) -> None:
246
248
epoch, self.train_epoch, training_loss / training_total))
247
249
self.loss = self.loss.to('cpu')
250
self.model = self.model.to('cpu')
- self.global_net = deepcopy(self.model)
251
del support_x, support_y, query_y, query_x
252
gc.collect()
253
torch.cuda.empty_cache()
0 commit comments