2 changes: 2 additions & 0 deletions keras/src/backend/jax/__init__.py
@@ -25,6 +25,8 @@
from keras.src.backend.jax.core import shape
from keras.src.backend.jax.core import stop_gradient
from keras.src.backend.jax.core import vectorized_map
from keras.src.backend.jax.nn import adaptive_avg_pool
from keras.src.backend.jax.nn import adaptive_max_pool
from keras.src.backend.jax.rnn import cudnn_ok
from keras.src.backend.jax.rnn import gru
from keras.src.backend.jax.rnn import lstm
365 changes: 365 additions & 0 deletions keras/src/backend/jax/nn.py
@@ -1464,3 +1464,368 @@ def _pair(x):
# ---- reshape -> (N, C*kH*kW, L) ----
_, CKK, oH, oW = patches.shape
return patches.reshape(N, CKK, oH * oW)


def get_static_window_sizes(input_dim, output_dim):
"""Calculate small and big window sizes for adaptive pooling."""
small_window = math.ceil(input_dim / output_dim)
big_window = small_window + 1
return small_window, big_window


def compute_static_gather_indices(input_dim, output_size, big_window):
"""Compute gather indices for Two-Pool Gather method."""
window_starts = jnp.floor(
(jnp.arange(output_size) * input_dim) / output_size
).astype(jnp.int32)

window_ends = jnp.ceil(
(jnp.arange(1, output_size + 1) * input_dim) / output_size
).astype(jnp.int32)

window_sizes = window_ends - window_starts
is_big_window = window_sizes == big_window

small_window = big_window - 1
small_pool_len = input_dim - small_window + 1

small_indices = window_starts
big_indices = window_starts + small_pool_len

gather_indices = jnp.where(is_big_window, big_indices, small_indices)
return gather_indices.astype(jnp.int32)


# ---------- 1D Adaptive Pooling ----------
def adaptive_avg_pool1d(inputs, output_size, data_format="channels_first"):
"""Adaptive Average Pooling 1D using Two-Pool Gather method."""
if isinstance(output_size, int):
output_size = (output_size,)

if data_format == "channels_first":
inputs = jnp.transpose(inputs, (0, 2, 1)) # NCL -> NLC

n, l, c = inputs.shape
out_l = output_size[0]
Comment on lines +1508 to +1509
Contributor (severity: medium)

The variable names n, l, c, and out_l are quite short. According to the Keras API design guidelines, fully spelled-out names are preferred to improve readability, with a few common exceptions like dim and num [1]. Consider using more descriptive names like batch_size, length, channels, and output_length. This comment also applies to the other adaptive pooling functions in this file.

For example:

    n, l, c = inputs.shape  ->  batch_size, length, channels = inputs.shape
    out_l = output_size[0]  ->  output_length = output_size[0]

Style Guide References

1. The style guide recommends using fully spelled-out names for variables and arguments to improve clarity, e.g., attention_scores instead of attn_scores. Short names are acceptable only for very common terms like dim or num.
small_l, big_l = get_static_window_sizes(l, out_l)
gather_l = compute_static_gather_indices(l, out_l, big_l)

small_pool_l = lax.reduce_window(
inputs, 0.0, lax.add, (1, small_l, 1), (1, 1, 1), "valid"
)
small_pool_l = small_pool_l / small_l

big_pool_l = lax.reduce_window(
inputs, 0.0, lax.add, (1, big_l, 1), (1, 1, 1), "valid"
)
big_pool_l = big_pool_l / big_l

combined_l = jnp.concatenate([small_pool_l, big_pool_l], axis=1)
pooled_l = jnp.take(combined_l, gather_l, axis=1)

if data_format == "channels_first":
pooled_l = jnp.transpose(pooled_l, (0, 2, 1)) # NLC -> NCL

return pooled_l


def adaptive_max_pool1d(inputs, output_size, data_format="channels_first"):
"""Adaptive Max Pooling 1D using Two-Pool Gather method."""
if isinstance(output_size, int):
output_size = (output_size,)

if data_format == "channels_first":
inputs = jnp.transpose(inputs, (0, 2, 1)) # NCL -> NLC

n, l, c = inputs.shape
out_l = output_size[0]

small_l, big_l = get_static_window_sizes(l, out_l)
gather_l = compute_static_gather_indices(l, out_l, big_l)

small_pool_l = lax.reduce_window(
inputs, -jnp.inf, lax.max, (1, small_l, 1), (1, 1, 1), "valid"
)
big_pool_l = lax.reduce_window(
inputs, -jnp.inf, lax.max, (1, big_l, 1), (1, 1, 1), "valid"
)

combined_l = jnp.concatenate([small_pool_l, big_pool_l], axis=1)
pooled_l = jnp.take(combined_l, gather_l, axis=1)

if data_format == "channels_first":
pooled_l = jnp.transpose(pooled_l, (0, 2, 1)) # NLC -> NCL

return pooled_l


# ---------- 2D Adaptive Pooling ----------
def adaptive_avg_pool2d(inputs, output_size, data_format="channels_first"):
"""Adaptive Average Pooling 2D using Two-Pool Gather method."""
if isinstance(output_size, int):
output_size = (output_size, output_size)

if data_format == "channels_first":
inputs = jnp.transpose(inputs, (0, 2, 3, 1)) # NCHW -> NHWC

n, h, w, c = inputs.shape
out_h, out_w = output_size

small_h, big_h = get_static_window_sizes(h, out_h)
gather_h = compute_static_gather_indices(h, out_h, big_h)

small_w, big_w = get_static_window_sizes(w, out_w)
gather_w = compute_static_gather_indices(w, out_w, big_w)

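    # Adaptive pooling is separable: pool the height axis first, then pool
    # the width axis of that intermediate result.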
small_pool_h = lax.reduce_window(
inputs, 0.0, lax.add, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
)
small_pool_h = small_pool_h / small_h

big_pool_h = lax.reduce_window(
inputs, 0.0, lax.add, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
)
big_pool_h = big_pool_h / big_h

combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=1)
pooled_h = jnp.take(combined_h, gather_h, axis=1)

small_pool_w = lax.reduce_window(
pooled_h, 0.0, lax.add, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
)
small_pool_w = small_pool_w / small_w

big_pool_w = lax.reduce_window(
pooled_h, 0.0, lax.add, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
)
big_pool_w = big_pool_w / big_w

combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=2)
pooled_w = jnp.take(combined_w, gather_w, axis=2)

if data_format == "channels_first":
pooled_w = jnp.transpose(pooled_w, (0, 3, 1, 2)) # NHWC -> NCHW

return pooled_w


def adaptive_max_pool2d(inputs, output_size, data_format="channels_first"):
"""Adaptive Max Pooling 2D using Two-Pool Gather method."""
if isinstance(output_size, int):
output_size = (output_size, output_size)

if data_format == "channels_first":
inputs = jnp.transpose(inputs, (0, 2, 3, 1)) # NCHW -> NHWC

n, h, w, c = inputs.shape
out_h, out_w = output_size

small_h, big_h = get_static_window_sizes(h, out_h)
gather_h = compute_static_gather_indices(h, out_h, big_h)

small_w, big_w = get_static_window_sizes(w, out_w)
gather_w = compute_static_gather_indices(w, out_w, big_w)

small_pool_h = lax.reduce_window(
inputs, -jnp.inf, lax.max, (1, small_h, 1, 1), (1, 1, 1, 1), "valid"
)
big_pool_h = lax.reduce_window(
inputs, -jnp.inf, lax.max, (1, big_h, 1, 1), (1, 1, 1, 1), "valid"
)

combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=1)
pooled_h = jnp.take(combined_h, gather_h, axis=1)

small_pool_w = lax.reduce_window(
pooled_h, -jnp.inf, lax.max, (1, 1, small_w, 1), (1, 1, 1, 1), "valid"
)
big_pool_w = lax.reduce_window(
pooled_h, -jnp.inf, lax.max, (1, 1, big_w, 1), (1, 1, 1, 1), "valid"
)

combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=2)
pooled_w = jnp.take(combined_w, gather_w, axis=2)

if data_format == "channels_first":
pooled_w = jnp.transpose(pooled_w, (0, 3, 1, 2)) # NHWC -> NCHW

return pooled_w


# ---------- 3D Adaptive Pooling ----------
def adaptive_avg_pool3d(inputs, output_size, data_format="channels_first"):
"""Adaptive Average Pooling 3D using Two-Pool Gather method."""
if isinstance(output_size, int):
output_size = (output_size, output_size, output_size)

if data_format == "channels_first":
inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1)) # NCDHW -> NDHWC

n, d, h, w, c = inputs.shape
out_d, out_h, out_w = output_size

small_d, big_d = get_static_window_sizes(d, out_d)
gather_d = compute_static_gather_indices(d, out_d, big_d)

small_h, big_h = get_static_window_sizes(h, out_h)
gather_h = compute_static_gather_indices(h, out_h, big_h)

small_w, big_w = get_static_window_sizes(w, out_w)
gather_w = compute_static_gather_indices(w, out_w, big_w)

small_pool_d = lax.reduce_window(
inputs, 0.0, lax.add, (1, small_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
)
small_pool_d = small_pool_d / small_d

big_pool_d = lax.reduce_window(
inputs, 0.0, lax.add, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
)
big_pool_d = big_pool_d / big_d

combined_d = jnp.concatenate([small_pool_d, big_pool_d], axis=1)
pooled_d = jnp.take(combined_d, gather_d, axis=1)

small_pool_h = lax.reduce_window(
pooled_d, 0.0, lax.add, (1, 1, small_h, 1, 1), (1, 1, 1, 1, 1), "valid"
)
small_pool_h = small_pool_h / small_h

big_pool_h = lax.reduce_window(
pooled_d, 0.0, lax.add, (1, 1, big_h, 1, 1), (1, 1, 1, 1, 1), "valid"
)
big_pool_h = big_pool_h / big_h

combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=2)
pooled_h = jnp.take(combined_h, gather_h, axis=2)

small_pool_w = lax.reduce_window(
pooled_h, 0.0, lax.add, (1, 1, 1, small_w, 1), (1, 1, 1, 1, 1), "valid"
)
small_pool_w = small_pool_w / small_w

big_pool_w = lax.reduce_window(
pooled_h, 0.0, lax.add, (1, 1, 1, big_w, 1), (1, 1, 1, 1, 1), "valid"
)
big_pool_w = big_pool_w / big_w

combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=3)
pooled_w = jnp.take(combined_w, gather_w, axis=3)

if data_format == "channels_first":
pooled_w = jnp.transpose(pooled_w, (0, 4, 1, 2, 3)) # NDHWC -> NCDHW

return pooled_w


def adaptive_max_pool3d(inputs, output_size, data_format="channels_first"):
"""Adaptive Max Pooling 3D using Two-Pool Gather method."""
if isinstance(output_size, int):
output_size = (output_size, output_size, output_size)

if data_format == "channels_first":
inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1)) # NCDHW -> NDHWC

n, d, h, w, c = inputs.shape
out_d, out_h, out_w = output_size

small_d, big_d = get_static_window_sizes(d, out_d)
gather_d = compute_static_gather_indices(d, out_d, big_d)

small_h, big_h = get_static_window_sizes(h, out_h)
gather_h = compute_static_gather_indices(h, out_h, big_h)

small_w, big_w = get_static_window_sizes(w, out_w)
gather_w = compute_static_gather_indices(w, out_w, big_w)

small_pool_d = lax.reduce_window(
inputs,
-jnp.inf,
lax.max,
(1, small_d, 1, 1, 1),
(1, 1, 1, 1, 1),
"valid",
)
big_pool_d = lax.reduce_window(
inputs, -jnp.inf, lax.max, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid"
)

combined_d = jnp.concatenate([small_pool_d, big_pool_d], axis=1)
pooled_d = jnp.take(combined_d, gather_d, axis=1)

small_pool_h = lax.reduce_window(
pooled_d,
-jnp.inf,
lax.max,
(1, 1, small_h, 1, 1),
(1, 1, 1, 1, 1),
"valid",
)
big_pool_h = lax.reduce_window(
pooled_d,
-jnp.inf,
lax.max,
(1, 1, big_h, 1, 1),
(1, 1, 1, 1, 1),
"valid",
)

combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=2)
pooled_h = jnp.take(combined_h, gather_h, axis=2)

small_pool_w = lax.reduce_window(
pooled_h,
-jnp.inf,
lax.max,
(1, 1, 1, small_w, 1),
(1, 1, 1, 1, 1),
"valid",
)
big_pool_w = lax.reduce_window(
pooled_h,
-jnp.inf,
lax.max,
(1, 1, 1, big_w, 1),
(1, 1, 1, 1, 1),
"valid",
)

combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=3)
pooled_w = jnp.take(combined_w, gather_w, axis=3)

if data_format == "channels_first":
pooled_w = jnp.transpose(pooled_w, (0, 4, 1, 2, 3)) # NDHWC -> NCDHW

return pooled_w


# ---------- Dispatcher ----------
def adaptive_avg_pool(inputs, output_size, data_format="channels_first"):
"""Dispatcher for adaptive average pooling (1D, 2D, or 3D)."""
ndims = inputs.ndim - 2
if ndims == 1:
return adaptive_avg_pool1d(inputs, output_size, data_format)
elif ndims == 2:
return adaptive_avg_pool2d(inputs, output_size, data_format)
elif ndims == 3:
return adaptive_avg_pool3d(inputs, output_size, data_format)
else:
raise ValueError(
"adaptive_avg_pool supports 1D, 2D, or 3D inputs only."
)


def adaptive_max_pool(inputs, output_size, data_format="channels_first"):
"""Dispatcher for adaptive max pooling (1D, 2D, or 3D)."""
ndims = inputs.ndim - 2
if ndims == 1:
return adaptive_max_pool1d(inputs, output_size, data_format)
elif ndims == 2:
return adaptive_max_pool2d(inputs, output_size, data_format)
elif ndims == 3:
return adaptive_max_pool3d(inputs, output_size, data_format)
else:
raise ValueError(
"adaptive_max_pool supports 1D, 2D, or 3D inputs only."
)
Comment on lines +1499 to +1831
Contributor (severity: medium)

The implementations of adaptive_avg_pool{1,2,3}d and adaptive_max_pool{1,2,3}d are very similar, leading to significant code duplication. To improve maintainability, consider refactoring this code.

Two suggestions:

  1. Create a helper function for each dimension (e.g., _adaptive_pool1d) that takes the pooling type ('avg' or 'max') as an argument. This would halve the number of functions; a minimal sketch follows this list.
  2. A more advanced refactoring would be to create a single generic n-dimensional pooling function that iterates over the spatial dimensions, further consolidating the 1D, 2D, and 3D logic into one place.
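
A minimal sketch of suggestion 1, reusing the PR's get_static_window_sizes and compute_static_gather_indices helpers and the module's existing jnp/lax imports; the name _adaptive_pool1d and its signature are illustrative, not part of this PR:

    def _adaptive_pool1d(inputs, output_size, data_format, pool_type):
        """Shared 1D adaptive pooling; pool_type is 'avg' or 'max'."""
        if isinstance(output_size, int):
            output_size = (output_size,)

        if data_format == "channels_first":
            inputs = jnp.transpose(inputs, (0, 2, 1))  # NCL -> NLC

        length = inputs.shape[1]
        output_length = output_size[0]

        small_l, big_l = get_static_window_sizes(length, output_length)
        gather_l = compute_static_gather_indices(length, output_length, big_l)

        # Select the reduction and its identity element per pooling type.
        if pool_type == "avg":
            init_val, reduce_fn = 0.0, lax.add
        else:
            init_val, reduce_fn = -jnp.inf, lax.max

        def pool(window):
            out = lax.reduce_window(
                inputs, init_val, reduce_fn, (1, window, 1), (1, 1, 1), "valid"
            )
            # Sum pools become mean pools after dividing by the window size.
            return out / window if pool_type == "avg" else out

        combined = jnp.concatenate([pool(small_l), pool(big_l)], axis=1)
        pooled = jnp.take(combined, gather_l, axis=1)

        if data_format == "channels_first":
            pooled = jnp.transpose(pooled, (0, 2, 1))  # NLC -> NCL
        return pooled

adaptive_avg_pool1d and adaptive_max_pool1d would then become one-line wrappers that pass pool_type="avg" or pool_type="max".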

Comment on lines +1500 to +1831
Contributor (severity: high)

The implementations for 1D, 2D, and 3D adaptive pooling, for both avg and max, contain a significant amount of duplicated code. This makes the code harder to read and maintain.

Consider refactoring by creating a generalized helper function that handles the pooling logic for a single dimension and is parameterized for average vs. max pooling, for example:

    _adaptive_pool_1d_single_dim(inputs, axis, output_dim, reduce_fn, init_val, normalize=False)

The 2D and 3D functions can then be implemented by composing this helper over each spatial dimension, which would greatly reduce code duplication and improve maintainability. A sketch follows.
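
A minimal sketch of that helper, again reusing the PR's existing utilities; the body is illustrative, not part of this PR:

    def _adaptive_pool_1d_single_dim(
        inputs, axis, output_dim, reduce_fn, init_val, normalize=False
    ):
        """Adaptively pool one spatial axis via the Two-Pool Gather method."""
        input_dim = inputs.shape[axis]
        small, big = get_static_window_sizes(input_dim, output_dim)
        gather = compute_static_gather_indices(input_dim, output_dim, big)

        def pool(window):
            # Build an n-dimensional window that only spans `axis`.
            window_dims = [1] * inputs.ndim
            window_dims[axis] = window
            out = lax.reduce_window(
                inputs,
                init_val,
                reduce_fn,
                tuple(window_dims),
                (1,) * inputs.ndim,
                "valid",
            )
            # Divide sum pools by the window size to get average pools.
            return out / window if normalize else out

        combined = jnp.concatenate([pool(small), pool(big)], axis=axis)
        return jnp.take(combined, gather, axis=axis)

With channels-last inputs, adaptive_avg_pool2d would then reduce to the transposes plus two calls: axis=1 with (lax.add, 0.0, normalize=True) for height, then axis=2 on that result for width; max pooling passes (lax.max, -jnp.inf, normalize=False).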

16 changes: 16 additions & 0 deletions keras/src/backend/numpy/nn.py
@@ -1237,3 +1237,19 @@ def _pair(x):

# ---- reshape -> (N, C*kH*kW, L) ----
return patches.reshape(N, C * k[0] * k[1], -1)


def adaptive_max_pool(inputs, output_size, data_format=None):
    """Adaptive max pooling - not yet supported on the NumPy backend."""
    raise NotImplementedError(
        "Adaptive pooling is not implemented for NumPy. "
        "Use the JAX, Torch or TensorFlow backend."
    )


def adaptive_avg_pool(inputs, output_size, data_format=None):
    """Adaptive average pooling - not yet supported on the NumPy backend."""
    raise NotImplementedError(
        "Adaptive pooling is not implemented for NumPy. "
        "Use the JAX, Torch or TensorFlow backend."
    )