Skip to content

Commit 8716a4a

Browse files
authored
(Fix #11) devolearn: upgrade generative model
1 parent 01353e9 commit 8716a4a

File tree

2 files changed

+16
-15
lines changed

2 files changed

+16
-15
lines changed

devolearn/embryo_generator_model/embryo_generator_model.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222

2323

24+
2425
"""
2526
GAN to generate images of embryos
2627
"""
@@ -42,19 +43,21 @@ def __init__(self, ngf, nz, nc):
4243
nn.BatchNorm2d(ngf * 2),
4344
nn.ReLU(True),
4445
# state size. (ngf*2) x 16 x 16
45-
nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
46-
nn.BatchNorm2d(ngf),
46+
nn.ConvTranspose2d( ngf * 2, ngf*2, 4, 2, 1, bias=False),
47+
nn.BatchNorm2d(ngf*2),
4748
nn.ReLU(True),
48-
# state size. (ngf) x 32 x 32
49-
50-
nn.ConvTranspose2d( ngf, ngf, 4, 2, 1, bias=False), ## added custom stuff here
49+
# state size. (ngf*2) x 32 x 32
50+
nn.ConvTranspose2d( ngf*2, ngf, 4, 2, 1, bias=False), ## added custom stuff here
5151
nn.BatchNorm2d(ngf),
5252
nn.ReLU(True),
5353
# state size. (ngf) x 64 x 64
54-
55-
nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
54+
nn.ConvTranspose2d( ngf, 10, 4, 2, 1, bias=False), ## added custom stuff here
55+
nn.BatchNorm2d(10),
56+
nn.ReLU(True),
57+
# state size. 10 x 128 x 128
58+
nn.ConvTranspose2d( 10, nc, 4, 2, 1, bias=False),
5659
nn.Tanh()
57-
# state size. (nc) x 128 x 128
60+
# state size. (nc) x 256 x 256
5861
)
5962

6063
def forward(self, input):
@@ -68,7 +71,6 @@ def __init__(self, mode = "cpu"):
6871
ngf = size of output image of the GAN
6972
nz = size of latent space noise (latent vector)
7073
nc = number of color channels of the output image
71-
7274
Do not tweak these unless you're changing the Generator() with a new model with a different architecture.
7375
7476
"""
@@ -77,10 +79,11 @@ def __init__(self, mode = "cpu"):
7779
self.nz = 128
7880
self.nc = 1
7981
self.generator= Generator(self.ngf, self.nz, self.nc)
80-
self.model_url = "https://github.com/DevoLearn/devolearn/raw/master/devolearn/embryo_generator_model/embryo_generator.pt"
81-
self.model_name = "embryo_generator.pt"
82+
self.model_url = "https://raw.githubusercontent.com/Mainakdeb/devolearn/master/devolearn/embryo_generator_model/embryo_generator.pth"
83+
self.model_name = "embryo_generator.pth"
8284
self.model_dir = os.path.dirname(__file__)
8385
# print("at : ", os.path.dirname(__file__))
86+
print("Searching here.. ",self.model_dir + "/" + self.model_name)
8487

8588
try:
8689
# print("model already downloaded, loading model...")
@@ -113,7 +116,6 @@ def generate(self, image_size = (700,500)):
113116
}
114117
The native size of the GAN's output is 128*128, and then it resizes the
115118
generated image to the desired size.
116-
117119
"""
118120
with torch.no_grad():
119121
noise = torch.randn([1,128,1,1])
@@ -139,7 +141,6 @@ def generate_n_images(self, n = 3, foldername = "generated_images", image_size =
139141
}
140142
141143
This is an extension of the generator.generate() function for generating multiple images at once and saving them into a folder.
142-
143144
"""
144145

145146
if os.path.isdir(foldername) == False:
@@ -148,7 +149,7 @@ def generate_n_images(self, n = 3, foldername = "generated_images", image_size =
148149

149150
for i in tqdm(range(n), desc = "generating images :"):
150151
filename = foldername + "/" + str(i) + ".jpg"
151-
gen_image = self.generate() ## 2d numpy array
152+
gen_image = self.generate() ## 2d numpy arreay
152153
cv2.imwrite(filename, gen_image)
153154

154-
print ("Saved ", n, " images in", foldername)
155+
print ("Saved ", n, " images in", foldername)

0 commit comments

Comments
 (0)