1313IMAGENET_STD_NEUTRAL = [1 , 1 , 1 ]
1414
1515
16+ #
17+ # Image manipulation util functions
18+ #
19+
1620def load_image (img_path , target_shape = None ):
1721 if not os .path .exists (img_path ):
1822 raise Exception (f'Path does not exist: { img_path } ' )
19- img = cv .imread (img_path )[:, :, ::- 1 ]. astype ( np . float32 ) # [:, :, ::-1] converts rgb into bgr (opencv contraint ...)
20- img /= 255.0 # get to [0, 1] range
21- if target_shape is not None :
23+ img = cv .imread (img_path )[:, :, ::- 1 ] # [:, :, ::-1] converts BGR (opencv format ...) into RGB
24+
25+ if target_shape is not None : # resize section
2226 if isinstance (target_shape , int ) and target_shape != - 1 : # scalar -> implicitly setting the height
23- ratio = target_shape / img .shape [0 ]
24- width = int (img .shape [1 ] * ratio )
25- img = cv .resize (img , (width , target_shape ), interpolation = cv .INTER_CUBIC )
27+ current_height , current_width = img .shape [:2 ]
28+ new_height = target_shape
29+ new_width = int (current_width * (new_height / current_height ))
30+ img = cv .resize (img , (new_width , new_height ), interpolation = cv .INTER_CUBIC )
2631 else : # set both dimensions to target shape
2732 img = cv .resize (img , (target_shape [1 ], target_shape [0 ]), interpolation = cv .INTER_CUBIC )
33+
34+ # this need to go after resizing - otherwise cv.resize will push values outside of [0,1] range
35+ img = img .astype (np .float32 ) # convert from uint8 to float32
36+ img /= 255.0 # get to [0, 1] range
2837 return img
2938
3039
3140def prepare_img (img_path , target_shape , device ):
3241 img = load_image (img_path , target_shape = target_shape )
3342
34- # normalize using ImageNet's mean and std (VGG was trained on images normalized this way)
35- # [0, 255] range works much better than [0, 1] range (VGG was again trained that way)
43+ # normalize using ImageNet's mean
44+ # [0, 255] range works much better than [0, 1] range
3645 transform = transforms .Compose ([
3746 transforms .ToTensor (),
3847 transforms .Lambda (lambda x : x .mul (255 )),
@@ -44,24 +53,14 @@ def prepare_img(img_path, target_shape, device):
4453 return img
4554
4655
47- def get_uint8_range (x ):
48- if isinstance (x , np .ndarray ):
49- x -= np .min (x )
50- x /= np .max (x )
51- x *= 255
52- return x
53- else :
54- raise ValueError (f'Expected numpy array got { type (x )} ' )
55-
56-
5756def save_image (img , img_path ):
5857 if len (img .shape ) == 2 :
5958 img = np .stack ((img ,) * 3 , axis = - 1 )
6059 cv .imwrite (img_path , img [:, :, ::- 1 ]) # [:, :, ::-1] converts rgb into bgr (opencv contraint...)
6160
6261
6362def generate_out_img_name (config ):
64- prefix = config ['content_img_name' ].split ('.' )[0 ] + '_' + config ['style_img_name' ].split ('.' )[0 ]
63+ prefix = os . path . basename ( config ['content_img_name' ]) .split ('.' )[0 ] + '_' + os . path . basename ( config ['style_img_name' ]) .split ('.' )[0 ]
6564 # called from the reconstruction script
6665 if 'reconstruct_script' in config :
6766 suffix = f'_o_{ config ["optimizer" ]} _h_{ str (config ["height" ])} _m_{ config ["model" ]} { config ["img_format" ][1 ]} '
@@ -72,7 +71,7 @@ def generate_out_img_name(config):
7271
7372def save_and_maybe_display (optimizing_img , dump_path , config , img_id , num_of_iterations , should_display = False ):
7473 saving_freq = config ['saving_freq' ]
75- out_img = optimizing_img .squeeze (axis = 0 ).to ('cpu' ).numpy ()
74+ out_img = optimizing_img .squeeze (axis = 0 ).to ('cpu' ).detach (). numpy ()
7675 out_img = np .moveaxis (out_img , 0 , 2 ) # swap channel from 1st to 3rd position: ch, _, _ -> _, _, chr
7776
7877 # for saving_freq == -1 save only the final result (otherwise save with frequency saving_freq and save the last pic)
@@ -83,11 +82,27 @@ def save_and_maybe_display(optimizing_img, dump_path, config, img_id, num_of_ite
8382 dump_img += np .array (IMAGENET_MEAN_255 ).reshape ((1 , 1 , 3 ))
8483 dump_img = np .clip (dump_img , 0 , 255 ).astype ('uint8' )
8584 cv .imwrite (os .path .join (dump_path , out_img_name ), dump_img [:, :, ::- 1 ])
85+
8686 if should_display :
8787 plt .imshow (np .uint8 (get_uint8_range (out_img )))
8888 plt .show ()
8989
9090
91+ def get_uint8_range (x ):
92+ if isinstance (x , np .ndarray ):
93+ x -= np .min (x )
94+ x /= np .max (x )
95+ x *= 255
96+ return x
97+ else :
98+ raise ValueError (f'Expected numpy array got { type (x )} ' )
99+
100+
101+ #
102+ # End of image manipulation util functions
103+ #
104+
105+
91106# initially it takes some time for PyTorch to download the models into local cache
92107def prepare_model (model , device ):
93108 # we are not tuning model weights -> we are only tuning optimizing_img's pixels! (that's why requires_grad=False)
0 commit comments