@@ -417,8 +417,8 @@ def forward(
417417 mapping: Optional[torch.Tensor] = None,
418418 comm_dict: Optional[dict[str, torch.Tensor]] = None,
419419 ):
420- parrallel_mode = comm_dict is not None
421- if not parrallel_mode :
420+ parallel_mode = comm_dict is not None
421+ if not parallel_mode :
422422 assert mapping is not None
423423 nframes, nloc, nnei = nlist.shape
424424 nall = extended_coord.view(nframes, -1).shape[1] // 3
@@ -492,7 +492,7 @@ def forward(
492492 cosine_ij = torch.matmul(normalized_diff_i, normalized_diff_j) * (1 - 1e-6)
493493 angle_input = cosine_ij.unsqueeze(-1) / (torch.pi**0.5)
494494
495- if not parrallel_mode and self.use_loc_mapping:
495+ if not parallel_mode and self.use_loc_mapping:
496496 assert mapping is not None
497497 # convert nlist from nall to nloc index
498498 nlist = torch.gather(
@@ -534,15 +534,15 @@ def forward(
534534 angle_ebd = self.angle_embd(angle_input)
535535
536536 # nb x nall x n_dim
537- if not parrallel_mode :
537+ if not parallel_mode :
538538 assert mapping is not None
539539 mapping = (
540540 mapping.view(nframes, nall).unsqueeze(-1).expand(-1, -1, self.n_dim)
541541 )
542542 for idx, ll in enumerate(self.layers):
543543 # node_ebd: nb x nloc x n_dim
544- # node_ebd_ext: nb x nall x n_dim [OR] nb x nloc x n_dim when not parrallel_mode
545- if not parrallel_mode :
544+ # node_ebd_ext: nb x nall x n_dim [OR] nb x nloc x n_dim when not parallel_mode
545+ if not parallel_mode :
546546 assert mapping is not None
547547 node_ebd_ext = (
548548 torch.gather(node_ebd, 1, mapping)
0 commit comments