Is a batch size greater than 1 still unsupported? Judging from the code, the `forward` function of `ChildSumTreeLSTM` appears to process only a single tree per call:
```python
def forward(self, tree, inputs):
    for idx in range(tree.num_children):
        self.forward(tree.children[idx], inputs)
    if tree.num_children == 0:
        child_c = inputs[0].detach().new(1, self.mem_dim).fill_(0.).requires_grad_()
        child_h = inputs[0].detach().new(1, self.mem_dim).fill_(0.).requires_grad_()
    else:
        child_c, child_h = zip(* map(lambda x: x.state, tree.children))
        child_c, child_h = torch.cat(child_c, dim=0), torch.cat(child_h, dim=0)
    tree.state = self.node_forward(inputs[tree.idx], child_c, child_h)
    return tree.state
```
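To make the question concrete, here is a rough sketch of the scheduling step that batching across trees would need. It is not taken from the repository. It groups the nodes of several trees by height, so that all nodes of the same height could in principle go through one batched `node_forward` call. The `Tree` class below is a hypothetical stand-in that only mirrors the `idx`, `children`, and `num_children` fields used above, and `schedule_by_height` is a made-up helper name. The batched `node_forward` itself, which would have to handle a varying number of children per node, is not shown.

```python
from collections import defaultdict


class Tree:
    """Hypothetical stand-in for the repository's tree class (only the fields used here)."""

    def __init__(self, idx, children=()):
        self.idx = idx
        self.children = list(children)
        self.num_children = len(self.children)


def node_height(tree, heights):
    """Record the height of every node (leaves have height 0); return the root's height."""
    if tree.num_children == 0:
        h = 0
    else:
        h = 1 + max(node_height(child, heights) for child in tree.children)
    heights[id(tree)] = h
    return h


def schedule_by_height(trees):
    """Group the nodes of several trees into levels, one level per height.

    Level k holds (tree_index, node) pairs for all nodes of height k.
    Every child of a node in level k lies in a lower level, so its state
    would already be computed when level k is processed.
    """
    levels = defaultdict(list)
    for tree_index, root in enumerate(trees):
        heights = {}
        node_height(root, heights)
        stack = [root]
        while stack:
            node = stack.pop()
            levels[heights[id(node)]].append((tree_index, node))
            stack.extend(node.children)
    return [levels[k] for k in range(len(levels))]


# Two small trees: a root with two leaves, and a root with one leaf.
t1 = Tree(2, [Tree(0), Tree(1)])
t2 = Tree(1, [Tree(0)])
levels = schedule_by_height([t1, t2])
# levels[0] holds the three leaves, levels[1] holds the two roots.
```

With such a schedule, the per-node recursion in `forward` could be replaced by a loop over levels. Whether this is worth the added complexity compared with simply looping over the trees of a mini-batch and accumulating the loss is a separate question.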