diff --git a/modules.py b/modules.py
index dab4bcfbe..161a2d3ed 100644
--- a/modules.py
+++ b/modules.py
@@ -103,7 +103,7 @@ class Up(nn.Module):
     def __init__(self, in_channels, out_channels, emb_dim=256):
         super().__init__()
 
-        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
+        self.up = nn.ConvTranspose2d(in_channels, out_channels, 2, stride=2)
         self.conv = nn.Sequential(
             DoubleConv(in_channels, in_channels, residual=True),
             DoubleConv(in_channels, out_channels, in_channels // 2),
@@ -140,12 +140,11 @@ def __init__(self, c_in=3, c_out=3, time_dim=256, device="cuda"):
 
         self.bot1 = DoubleConv(256, 512)
         self.bot2 = DoubleConv(512, 512)
-        self.bot3 = DoubleConv(512, 256)
 
-        self.up1 = Up(512, 128)
-        self.sa4 = SelfAttention(128, 16)
-        self.up2 = Up(256, 64)
-        self.sa5 = SelfAttention(64, 32)
+        self.up1 = Up(512, 256)
+        self.sa4 = SelfAttention(256, 16)
+        self.up2 = Up(256, 128)
+        self.sa5 = SelfAttention(128, 32)
         self.up3 = Up(128, 64)
         self.sa6 = SelfAttention(64, 64)
         self.outc = nn.Conv2d(64, c_out, kernel_size=1)
@@ -174,7 +173,6 @@ def forward(self, x, t):
 
         x4 = self.bot1(x4)
         x4 = self.bot2(x4)
-        x4 = self.bot3(x4)
 
         x = self.up1(x4, x3, t)
         x = self.sa4(x)
@@ -201,12 +199,11 @@ def __init__(self, c_in=3, c_out=3, time_dim=256, num_classes=None, device="cuda"):
 
         self.bot1 = DoubleConv(256, 512)
         self.bot2 = DoubleConv(512, 512)
-        self.bot3 = DoubleConv(512, 256)
 
-        self.up1 = Up(512, 128)
-        self.sa4 = SelfAttention(128, 16)
-        self.up2 = Up(256, 64)
-        self.sa5 = SelfAttention(64, 32)
+        self.up1 = Up(512, 256)
+        self.sa4 = SelfAttention(256, 16)
+        self.up2 = Up(256, 128)
+        self.sa5 = SelfAttention(128, 32)
         self.up3 = Up(128, 64)
         self.sa6 = SelfAttention(64, 64)
         self.outc = nn.Conv2d(64, c_out, kernel_size=1)
@@ -241,7 +238,6 @@ def forward(self, x, t, y):
 
         x4 = self.bot1(x4)
         x4 = self.bot2(x4)
-        x4 = self.bot3(x4)
 
         x = self.up1(x4, x3, t)
         x = self.sa4(x)