From d7aef58d7c77ce882a13cfcc5f5e8415be339906 Mon Sep 17 00:00:00 2001 From: David Grayson Date: Tue, 24 Jan 2023 20:20:06 -0800 Subject: [PATCH 1/3] Fix unet upsampling dimensions: the channel dimensions for the unet were doing some weird things (e.g. decreasing by 4x then back to 2x). I believe this is more consistent with the original unet implementation. --- modules.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/modules.py b/modules.py index dab4bcfbe..17d53a35d 100644 --- a/modules.py +++ b/modules.py @@ -103,7 +103,7 @@ class Up(nn.Module): def __init__(self, in_channels, out_channels, emb_dim=256): super().__init__() - self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) + self.up = nn.ConvTranspose2d(in_channels, out_channels, 2, stride=2) self.conv = nn.Sequential( DoubleConv(in_channels, in_channels, residual=True), DoubleConv(in_channels, out_channels, in_channels // 2), @@ -201,7 +201,6 @@ def __init__(self, c_in=3, c_out=3, time_dim=256, num_classes=None, device="cuda self.bot1 = DoubleConv(256, 512) self.bot2 = DoubleConv(512, 512) - self.bot3 = DoubleConv(512, 256) self.up1 = Up(512, 128) self.sa4 = SelfAttention(128, 16) From ffa70685a0c5b227c594fa4ca7b228352b727877 Mon Sep 17 00:00:00 2001 From: David Grayson Date: Tue, 24 Jan 2023 20:22:16 -0800 Subject: [PATCH 2/3] Fix unet upsampling dimensions --- modules.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/modules.py b/modules.py index 17d53a35d..69cd3505d 100644 --- a/modules.py +++ b/modules.py @@ -140,12 +140,11 @@ def __init__(self, c_in=3, c_out=3, time_dim=256, device="cuda"): self.bot1 = DoubleConv(256, 512) self.bot2 = DoubleConv(512, 512) - self.bot3 = DoubleConv(512, 256) - self.up1 = Up(512, 128) - self.sa4 = SelfAttention(128, 16) - self.up2 = Up(256, 64) - self.sa5 = SelfAttention(64, 32) + self.up1 = Up(512, 256) + self.sa4 = SelfAttention(256, 16) + self.up2 = Up(256, 128) + self.sa5 = SelfAttention(128, 32) 
self.up3 = Up(128, 64) self.sa6 = SelfAttention(64, 64) self.outc = nn.Conv2d(64, c_out, kernel_size=1) From 8f3e4d4392f6591ea983e89d8d1a3a6ec49964dc Mon Sep 17 00:00:00 2001 From: David Grayson Date: Tue, 24 Jan 2023 20:28:57 -0800 Subject: [PATCH 3/3] Fix upsampling on conditional model --- modules.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/modules.py b/modules.py index 69cd3505d..161a2d3ed 100644 --- a/modules.py +++ b/modules.py @@ -173,7 +173,6 @@ def forward(self, x, t): x4 = self.bot1(x4) x4 = self.bot2(x4) - x4 = self.bot3(x4) x = self.up1(x4, x3, t) x = self.sa4(x) @@ -201,10 +200,10 @@ def __init__(self, c_in=3, c_out=3, time_dim=256, num_classes=None, device="cuda self.bot1 = DoubleConv(256, 512) self.bot2 = DoubleConv(512, 512) - self.up1 = Up(512, 128) - self.sa4 = SelfAttention(128, 16) - self.up2 = Up(256, 64) - self.sa5 = SelfAttention(64, 32) + self.up1 = Up(512, 256) + self.sa4 = SelfAttention(256, 16) + self.up2 = Up(256, 128) + self.sa5 = SelfAttention(128, 32) self.up3 = Up(128, 64) self.sa6 = SelfAttention(64, 64) self.outc = nn.Conv2d(64, c_out, kernel_size=1) @@ -239,7 +238,6 @@ def forward(self, x, t, y): x4 = self.bot1(x4) x4 = self.bot2(x4) - x4 = self.bot3(x4) x = self.up1(x4, x3, t) x = self.sa4(x)