Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions integration_tests/test_save_img.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os

import numpy as np
import pytest

from keras.utils import img_to_array
from keras.utils import load_img
from keras.utils import save_img


@pytest.mark.parametrize(
    "shape, name",
    [
        ((50, 50, 3), "rgb.jpg"),
        ((50, 50, 4), "rgba.jpg"),
    ],
)
def test_save_jpg(tmp_path, shape, name):
    """Saving as JPG round-trips: RGB input is kept, RGBA input is
    converted to RGB (JPEG has no alpha channel)."""
    target = tmp_path / name
    pixels = np.random.randint(0, 256, size=shape, dtype=np.uint8)
    save_img(target, pixels, file_format="jpg")
    assert os.path.exists(target)

    # Check that the image was saved correctly and converted to RGB if needed.
    round_trip = img_to_array(load_img(target))
    assert round_trip.shape == (50, 50, 3)
15 changes: 14 additions & 1 deletion keras/src/layers/preprocessing/image_preprocessing/resizing.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,27 @@ def __init__(
self.width_axis = -2

def transform_images(self, images, transformation=None, training=True):
# Compute effective crop flag:
# only crop if aspect ratios differ and flag is True
input_height, input_width = transformation
source_aspect_ratio = input_width / input_height
target_aspect_ratio = self.width / self.height
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

To prevent a potential ZeroDivisionError if input_height or self.height is zero, it's safer to add a small epsilon to the denominators. This is a good defensive practice, especially since these values could be dynamic tensors.

Suggested change
source_aspect_ratio = input_width / input_height
target_aspect_ratio = self.width / self.height
source_aspect_ratio = input_width / (input_height + self.backend.epsilon())
target_aspect_ratio = self.width / (self.height + self.backend.epsilon())

# Use a small epsilon for floating-point comparison
aspect_ratios_match = (
abs(source_aspect_ratio - target_aspect_ratio) < 1e-6
)
effective_crop_to_aspect_ratio = (
self.crop_to_aspect_ratio and not aspect_ratios_match
)

size = (self.height, self.width)
resized = self.backend.image.resize(
images,
size=size,
interpolation=self.interpolation,
antialias=self.antialias,
data_format=self.data_format,
crop_to_aspect_ratio=self.crop_to_aspect_ratio,
crop_to_aspect_ratio=effective_crop_to_aspect_ratio,
pad_to_aspect_ratio=self.pad_to_aspect_ratio,
fill_mode=self.fill_mode,
fill_value=self.fill_value,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,31 @@ def test_crop_to_aspect_ratio(self, data_format):
ref_out = ref_out.transpose(0, 3, 1, 2)
self.assertAllClose(ref_out, out)

@parameterized.parameters([("channels_first",), ("channels_last",)])
def test_crop_to_aspect_ratio_no_op_when_aspects_match(self, data_format):
    """`crop_to_aspect_ratio=True` must behave exactly like `False` when the
    source and target aspect ratios already agree (4:4 -> 2:2, both 1:1),
    i.e. no cropping should take place."""
    image = np.arange(16, dtype="float32").reshape((1, 4, 4, 1))
    if data_format == "channels_first":
        image = image.transpose(0, 3, 1, 2)

    # Same layer config apart from the crop flag.
    def resize(crop):
        layer = layers.Resizing(
            height=2,
            width=2,
            interpolation="nearest",
            data_format=data_format,
            crop_to_aspect_ratio=crop,
        )
        return layer(image)

    # Outputs should be identical when aspect ratios match.
    self.assertAllClose(resize(False), resize(True))

@parameterized.parameters([("channels_first",), ("channels_last",)])
def test_unbatched_image(self, data_format):
img = np.reshape(np.arange(0, 16), (4, 4, 1)).astype("float32")
Expand Down
7 changes: 5 additions & 2 deletions keras/src/utils/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,13 @@ def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs):
**kwargs: Additional keyword arguments passed to `PIL.Image.save()`.
"""
data_format = backend.standardize_data_format(data_format)
# Normalize jpg → jpeg
if file_format is not None and file_format.lower() == "jpg":
file_format = "jpeg"
img = array_to_img(x, data_format=data_format, scale=scale)
if img.mode == "RGBA" and (file_format == "jpg" or file_format == "jpeg"):
if img.mode == "RGBA" and file_format == "jpeg":
warnings.warn(
"The JPG format does not support RGBA images, converting to RGB."
"The JPEG format does not support RGBA images, converting to RGB."
)
img = img.convert("RGB")
img.save(path, format=file_format, **kwargs)
Expand Down
Loading