Dear authors,
I want to check if the following distributed code matches the design of SogCLR.
The distributed part of `dynamic_contrastive_loss()` in `builder.py` might be inconsistent with its non-distributed counterpart, because:

- When distributed, `all_gather_layer` only backpropagates through the locally computed encodings (see the sketch after this list).
- Each GPU computes its loss using `logits_ab_aa` and `logits_ba_bb`, so the gradients of the off-diagonal inner products of the encodings are not fully computed. All other GPUs should compute the same part of `logits_ab_aa` so that every gradient term is accounted for, i.e., replace `logits_ab_aa` with the inner products of `hidden_large`.
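
For reference, below is a minimal sketch of a gather layer with exactly this behaviour; the class name `AllGatherWithLocalGrad` is illustrative and may not match the repo's actual `all_gather_layer`, but the backward pattern (returning only the local gradient slice) is the point in question:

```python
import torch
import torch.distributed as dist


class AllGatherWithLocalGrad(torch.autograd.Function):
    """Gather a tensor from every replica; backward returns only the
    gradient slice that belongs to this replica's local input."""

    @staticmethod
    def forward(ctx, tensor):
        gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, tensor)
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grads):
        # Gradient contributions that other replicas compute for this
        # replica's encodings are dropped here, unless they are explicitly
        # all-reduced or every replica computes the full logits matrix.
        return grads[dist.get_rank()]
```

With a layer like this, the rows of the logits matrix that live on other GPUs never send gradient back to the local encoder, which is why I think each replica has to build the full logits matrix itself.
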
I suggest the following implementation for correct distributed behaviour:

    def dynamic_contrastive_loss(self, hidden1, hidden2, index=None, gamma=0.9, distributed=True):
        # Get (normalized) hidden1 and hidden2.
        hidden1, hidden2 = F.normalize(hidden1, p=2, dim=1), F.normalize(hidden2, p=2, dim=1)
        batch_size = hidden1.shape[0]
        # Gather hidden1/hidden2 across replicas and create local labels.
        if distributed:
            hidden1_large = torch.cat(all_gather_layer.apply(hidden1), dim=0)  # why does the original use concat_all_gather() here?
            hidden2_large = torch.cat(all_gather_layer.apply(hidden2), dim=0)
            enlarged_batch_size = hidden1_large.shape[0]
            labels_idx = torch.arange(enlarged_batch_size, dtype=torch.long)
            labels = F.one_hot(labels_idx, enlarged_batch_size * 2).to(self.device)
            batch_size = enlarged_batch_size
        else:
            hidden1_large = hidden1
            hidden2_large = hidden2
            labels = F.one_hot(torch.arange(batch_size, dtype=torch.long), batch_size * 2).to(self.device)
        # Each agent should compute the whole logits matrix, because u_i differs across the rows.
        logits_aa = torch.matmul(hidden1_large, hidden1_large.T)  # (b * world_size, b * world_size)
        logits_bb = torch.matmul(hidden2_large, hidden2_large.T)
        logits_ab = torch.matmul(hidden1_large, hidden2_large.T)
        logits_ba = torch.matmul(hidden2_large, hidden1_large.T)
        # SogCLR
        neg_mask = 1 - labels
        logits_ab_aa = torch.cat([logits_ab, logits_aa], 1)  # neg. pair inner products, (b * world_size, 2 * b * world_size)
        logits_ba_bb = torch.cat([logits_ba, logits_bb], 1)
        neg_logits1 = torch.exp(logits_ab_aa / self.T) * neg_mask  # (B, 2B)
        neg_logits2 = torch.exp(logits_ba_bb / self.T) * neg_mask
        neg_logits1[:, batch_size:].fill_diagonal_(0)  # replaces the role of LARGE_NUM
        neg_logits2[:, batch_size:].fill_diagonal_(0)  # replaces the role of LARGE_NUM
        if distributed:
            index = concat_all_gather(index)
        # u init
        if self.u[index.cpu()].sum() == 0:
            gamma = 1
        u1 = (1 - gamma) * self.u[index.cpu()].cuda() + gamma * torch.sum(neg_logits1, dim=1, keepdim=True) / (2 * (batch_size - 1))
        u2 = (1 - gamma) * self.u[index.cpu()].cuda() + gamma * torch.sum(neg_logits2, dim=1, keepdim=True) / (2 * (batch_size - 1))
        self.u[index.cpu()] = (u1.detach().cpu() + u2.detach().cpu()) / 2
        p_neg_weights1 = (neg_logits1 / u1).detach()
        p_neg_weights2 = (neg_logits2 / u2).detach()

        def softmax_cross_entropy_with_logits(labels, logits, weights):
            expsum_neg_logits = torch.sum(weights * logits, dim=1, keepdim=True) / (2 * (batch_size - 1))
            normalized_logits = logits - expsum_neg_logits
            return -torch.sum(labels * normalized_logits, dim=1)

        loss_a = softmax_cross_entropy_with_logits(labels, logits_ab_aa, p_neg_weights1)
        loss_b = softmax_cross_entropy_with_logits(labels, logits_ba_bb, p_neg_weights2)
        loss = (loss_a + loss_b).mean()
        return loss
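
As a single-GPU sanity check, the function above can be exercised with `distributed=False`. The harness below is only illustrative: `SimCLRStub`, the temperature, and the tensor sizes are assumptions standing in for the real model class in `builder.py`, and a CUDA device is assumed because the method calls `.cuda()`:

```python
# Illustrative smoke test for the suggested loss (single GPU, distributed=False).
import torch
import torch.nn.functional as F  # needed by dynamic_contrastive_loss above


class SimCLRStub:
    """Hypothetical container providing the attributes the method reads."""
    def __init__(self, n_samples, temperature=0.1, device="cuda"):
        self.T = temperature
        self.device = device
        self.u = torch.zeros(n_samples, 1)  # per-sample moving estimate u_i, kept on CPU


# Attach the function above (pasted at module level) as a method of the stub.
SimCLRStub.dynamic_contrastive_loss = dynamic_contrastive_loss

torch.manual_seed(0)
model = SimCLRStub(n_samples=1000)
bsz, dim = 8, 128
h1 = torch.randn(bsz, dim, device="cuda", requires_grad=True)
h2 = torch.randn(bsz, dim, device="cuda", requires_grad=True)
idx = torch.arange(bsz)  # dataset indices of the local batch

loss = model.dynamic_contrastive_loss(h1, h2, index=idx, distributed=False)
loss.backward()
print(loss.item(), h1.grad.abs().sum().item())  # finite loss, gradient reaches the inputs
```

The distributed branch itself still needs `torch.distributed` to be initialized (e.g. via `torchrun`), so the equivalence between the distributed and non-distributed paths is best checked inside an actual multi-GPU run.
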
Thanks!