Skip to content

Commit 1719b37

Browse files
committed
Add support for absolute paths, fix detach problem, refactor
1 parent 872696f commit 1719b37

File tree

3 files changed

+42
-25
lines changed

3 files changed

+42
-25
lines changed

neural_style_transfer.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,16 +141,18 @@ def closure():
141141
# sorted so that the ones on the top are more likely to be changed than the ones on the bottom
142142
#
143143
parser = argparse.ArgumentParser()
144-
parser.add_argument("--content_img_name", type=str, help="content image name", default='tubingen.png')
145-
parser.add_argument("--style_img_name", type=str, help="style image name", default='kandinsky.jpg')
144+
parser.add_argument("--content_img_name", type=str, help="content image name", default='figures.jpg')
145+
parser.add_argument("--style_img_name", type=str, help="style image name", default='vg_starry_night.jpg')
146146
parser.add_argument("--height", type=int, help="height of content and style images", default=400)
147+
147148
parser.add_argument("--content_weight", type=float, help="weight factor for content loss", default=1e5)
148149
parser.add_argument("--style_weight", type=float, help="weight factor for style loss", default=3e4)
149150
parser.add_argument("--tv_weight", type=float, help="weight factor for total variation loss", default=1e0)
150-
parser.add_argument("--saving_freq", type=int, help="saving frequency for intermediate images (-1 means only final)", default=-1)
151+
151152
parser.add_argument("--optimizer", type=str, choices=['lbfgs', 'adam'], default='lbfgs')
152-
parser.add_argument("--init_method", type=str, choices=['random', 'content', 'style'], default='content')
153153
parser.add_argument("--model", type=str, choices=['vgg16', 'vgg19'], default='vgg19')
154+
parser.add_argument("--init_method", type=str, choices=['random', 'content', 'style'], default='content')
155+
parser.add_argument("--saving_freq", type=int, help="saving frequency for intermediate images (-1 means only final)", default=-1)
154156
args = parser.parse_args()
155157

156158
# some values of weights that worked for figures.jpg, vg_starry_night.jpg (starting point for finding good images)

reconstruct_image_from_representation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def reconstruct_image_from_representation(config):
4444
should_reconstruct_content = config['should_reconstruct_content']
4545
should_visualize_representation = config['should_visualize_representation']
4646
dump_path = os.path.join(config['output_img_dir'], ('c' if should_reconstruct_content else 's') + '_reconstruction_' + config['optimizer'])
47-
dump_path = os.path.join(dump_path, config['content_img_name'].split('.')[0] if should_reconstruct_content else config['style_img_name'].split('.')[0])
47+
dump_path = os.path.join(dump_path, os.path.basename(config['content_img_name']).split('.')[0] if should_reconstruct_content else os.path.basename(config['style_img_name']).split('.')[0])
4848
os.makedirs(dump_path, exist_ok=True)
4949

5050
content_img_path = os.path.join(config['content_images_dir'], config['content_img_name'])

utils/utils.py

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,35 @@
1313
IMAGENET_STD_NEUTRAL = [1, 1, 1]
1414

1515

16+
#
17+
# Image manipulation util functions
18+
#
19+
1620
def load_image(img_path, target_shape=None):
1721
if not os.path.exists(img_path):
1822
raise Exception(f'Path does not exist: {img_path}')
19-
img = cv.imread(img_path)[:, :, ::-1].astype(np.float32) # [:, :, ::-1] converts rgb into bgr (opencv contraint...)
20-
img /= 255.0 # get to [0, 1] range
21-
if target_shape is not None:
23+
img = cv.imread(img_path)[:, :, ::-1] # [:, :, ::-1] converts BGR (opencv format...) into RGB
24+
25+
if target_shape is not None: # resize section
2226
if isinstance(target_shape, int) and target_shape != -1: # scalar -> implicitly setting the height
23-
ratio = target_shape / img.shape[0]
24-
width = int(img.shape[1] * ratio)
25-
img = cv.resize(img, (width, target_shape), interpolation=cv.INTER_CUBIC)
27+
current_height, current_width = img.shape[:2]
28+
new_height = target_shape
29+
new_width = int(current_width * (new_height / current_height))
30+
img = cv.resize(img, (new_width, new_height), interpolation=cv.INTER_CUBIC)
2631
else: # set both dimensions to target shape
2732
img = cv.resize(img, (target_shape[1], target_shape[0]), interpolation=cv.INTER_CUBIC)
33+
34+
# this needs to go after resizing - otherwise cv.resize will push values outside of [0,1] range
35+
img = img.astype(np.float32) # convert from uint8 to float32
36+
img /= 255.0 # get to [0, 1] range
2837
return img
2938

3039

3140
def prepare_img(img_path, target_shape, device):
3241
img = load_image(img_path, target_shape=target_shape)
3342

34-
# normalize using ImageNet's mean and std (VGG was trained on images normalized this way)
35-
# [0, 255] range works much better than [0, 1] range (VGG was again trained that way)
43+
# normalize using ImageNet's mean
44+
# [0, 255] range works much better than [0, 1] range
3645
transform = transforms.Compose([
3746
transforms.ToTensor(),
3847
transforms.Lambda(lambda x: x.mul(255)),
@@ -44,24 +53,14 @@ def prepare_img(img_path, target_shape, device):
4453
return img
4554

4655

47-
def get_uint8_range(x):
48-
if isinstance(x, np.ndarray):
49-
x -= np.min(x)
50-
x /= np.max(x)
51-
x *= 255
52-
return x
53-
else:
54-
raise ValueError(f'Expected numpy array got {type(x)}')
55-
56-
5756
def save_image(img, img_path):
5857
if len(img.shape) == 2:
5958
img = np.stack((img,) * 3, axis=-1)
6059
cv.imwrite(img_path, img[:, :, ::-1]) # [:, :, ::-1] converts rgb into bgr (opencv constraint...)
6160

6261

6362
def generate_out_img_name(config):
64-
prefix = config['content_img_name'].split('.')[0] + '_' + config['style_img_name'].split('.')[0]
63+
prefix = os.path.basename(config['content_img_name']).split('.')[0] + '_' + os.path.basename(config['style_img_name']).split('.')[0]
6564
# called from the reconstruction script
6665
if 'reconstruct_script' in config:
6766
suffix = f'_o_{config["optimizer"]}_h_{str(config["height"])}_m_{config["model"]}{config["img_format"][1]}'
@@ -72,7 +71,7 @@ def generate_out_img_name(config):
7271

7372
def save_and_maybe_display(optimizing_img, dump_path, config, img_id, num_of_iterations, should_display=False):
7473
saving_freq = config['saving_freq']
75-
out_img = optimizing_img.squeeze(axis=0).to('cpu').numpy()
74+
out_img = optimizing_img.squeeze(axis=0).to('cpu').detach().numpy()
7675
out_img = np.moveaxis(out_img, 0, 2) # swap channel from 1st to 3rd position: ch, _, _ -> _, _, ch
7776

7877
# for saving_freq == -1 save only the final result (otherwise save with frequency saving_freq and save the last pic)
@@ -83,11 +82,27 @@ def save_and_maybe_display(optimizing_img, dump_path, config, img_id, num_of_ite
8382
dump_img += np.array(IMAGENET_MEAN_255).reshape((1, 1, 3))
8483
dump_img = np.clip(dump_img, 0, 255).astype('uint8')
8584
cv.imwrite(os.path.join(dump_path, out_img_name), dump_img[:, :, ::-1])
85+
8686
if should_display:
8787
plt.imshow(np.uint8(get_uint8_range(out_img)))
8888
plt.show()
8989

9090

91+
def get_uint8_range(x):
92+
if isinstance(x, np.ndarray):
93+
x -= np.min(x)
94+
x /= np.max(x)
95+
x *= 255
96+
return x
97+
else:
98+
raise ValueError(f'Expected numpy array got {type(x)}')
99+
100+
101+
#
102+
# End of image manipulation util functions
103+
#
104+
105+
91106
# initially it takes some time for PyTorch to download the models into local cache
92107
def prepare_model(model, device):
93108
# we are not tuning model weights -> we are only tuning optimizing_img's pixels! (that's why requires_grad=False)

0 commit comments

Comments
 (0)