-
Notifications
You must be signed in to change notification settings - Fork 19.7k
Add adaptive pooling (1D, 2D, 3D) support across JAX, TensorFlow, and PyTorch backends #21820
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
f99cc63
f830e93
9938ef1
323a1ab
df57227
5343b71
4cc8ac0
12edcb4
248773f
53a5dc9
2727a24
2a94421
edcf848
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1464,3 +1464,368 @@ def _pair(x): | |
| # ---- reshape -> (N, C*kH*kW, L) ---- | ||
| _, CKK, oH, oW = patches.shape | ||
| return patches.reshape(N, CKK, oH * oW) | ||
|
|
||
|
|
||
| def get_static_window_sizes(input_dim, output_dim): | ||
| """Calculate small and big window sizes for adaptive pooling.""" | ||
| small_window = math.ceil(input_dim / output_dim) | ||
| big_window = small_window + 1 | ||
| return small_window, big_window | ||
|
|
||
|
|
||
| def compute_static_gather_indices(input_dim, output_size, big_window): | ||
| """Compute gather indices for Two-Pool Gather method.""" | ||
| window_starts = jnp.floor( | ||
| (jnp.arange(output_size) * input_dim) / output_size | ||
| ).astype(jnp.int32) | ||
|
|
||
| window_ends = jnp.ceil( | ||
| (jnp.arange(1, output_size + 1) * input_dim) / output_size | ||
| ).astype(jnp.int32) | ||
|
|
||
| window_sizes = window_ends - window_starts | ||
| is_big_window = window_sizes == big_window | ||
|
|
||
| small_window = big_window - 1 | ||
| small_pool_len = input_dim - small_window + 1 | ||
|
|
||
| small_indices = window_starts | ||
| big_indices = window_starts + small_pool_len | ||
|
|
||
| gather_indices = jnp.where(is_big_window, big_indices, small_indices) | ||
| return gather_indices.astype(jnp.int32) | ||
|
|
||
|
|
||
| # ---------- 1D Adaptive Pooling ---------- | ||
| def adaptive_avg_pool1d(inputs, output_size, data_format="channels_first"): | ||
| """Adaptive Average Pooling 1D using Two-Pool Gather method.""" | ||
| if isinstance(output_size, int): | ||
| output_size = (output_size,) | ||
|
|
||
| if data_format == "channels_first": | ||
| inputs = jnp.transpose(inputs, (0, 2, 1)) # NCL -> NLC | ||
|
|
||
| n, l, c = inputs.shape | ||
| out_l = output_size[0] | ||
|
|
||
| small_l, big_l = get_static_window_sizes(l, out_l) | ||
| gather_l = compute_static_gather_indices(l, out_l, big_l) | ||
|
|
||
| small_pool_l = lax.reduce_window( | ||
| inputs, 0.0, lax.add, (1, small_l, 1), (1, 1, 1), "valid" | ||
| ) | ||
| small_pool_l = small_pool_l / small_l | ||
|
|
||
| big_pool_l = lax.reduce_window( | ||
| inputs, 0.0, lax.add, (1, big_l, 1), (1, 1, 1), "valid" | ||
| ) | ||
| big_pool_l = big_pool_l / big_l | ||
|
|
||
| combined_l = jnp.concatenate([small_pool_l, big_pool_l], axis=1) | ||
| pooled_l = jnp.take(combined_l, gather_l, axis=1) | ||
|
|
||
| if data_format == "channels_first": | ||
| pooled_l = jnp.transpose(pooled_l, (0, 2, 1)) # NLC -> NCL | ||
|
|
||
| return pooled_l | ||
|
|
||
|
|
||
| def adaptive_max_pool1d(inputs, output_size, data_format="channels_first"): | ||
| """Adaptive Max Pooling 1D using Two-Pool Gather method.""" | ||
| if isinstance(output_size, int): | ||
| output_size = (output_size,) | ||
|
|
||
| if data_format == "channels_first": | ||
| inputs = jnp.transpose(inputs, (0, 2, 1)) # NCL -> NLC | ||
|
|
||
| n, l, c = inputs.shape | ||
| out_l = output_size[0] | ||
|
|
||
| small_l, big_l = get_static_window_sizes(l, out_l) | ||
| gather_l = compute_static_gather_indices(l, out_l, big_l) | ||
|
|
||
| small_pool_l = lax.reduce_window( | ||
| inputs, -jnp.inf, lax.max, (1, small_l, 1), (1, 1, 1), "valid" | ||
| ) | ||
| big_pool_l = lax.reduce_window( | ||
| inputs, -jnp.inf, lax.max, (1, big_l, 1), (1, 1, 1), "valid" | ||
| ) | ||
|
|
||
| combined_l = jnp.concatenate([small_pool_l, big_pool_l], axis=1) | ||
| pooled_l = jnp.take(combined_l, gather_l, axis=1) | ||
|
|
||
| if data_format == "channels_first": | ||
| pooled_l = jnp.transpose(pooled_l, (0, 2, 1)) # NLC -> NCL | ||
|
|
||
| return pooled_l | ||
|
|
||
|
|
||
| # ---------- 2D Adaptive Pooling ---------- | ||
| def adaptive_avg_pool2d(inputs, output_size, data_format="channels_first"): | ||
| """Adaptive Average Pooling 2D using Two-Pool Gather method.""" | ||
| if isinstance(output_size, int): | ||
| output_size = (output_size, output_size) | ||
|
|
||
| if data_format == "channels_first": | ||
| inputs = jnp.transpose(inputs, (0, 2, 3, 1)) # NCHW -> NHWC | ||
|
|
||
| n, h, w, c = inputs.shape | ||
| out_h, out_w = output_size | ||
|
|
||
| small_h, big_h = get_static_window_sizes(h, out_h) | ||
| gather_h = compute_static_gather_indices(h, out_h, big_h) | ||
|
|
||
| small_w, big_w = get_static_window_sizes(w, out_w) | ||
| gather_w = compute_static_gather_indices(w, out_w, big_w) | ||
|
|
||
| small_pool_h = lax.reduce_window( | ||
| inputs, 0.0, lax.add, (1, small_h, 1, 1), (1, 1, 1, 1), "valid" | ||
| ) | ||
| small_pool_h = small_pool_h / small_h | ||
|
|
||
| big_pool_h = lax.reduce_window( | ||
| inputs, 0.0, lax.add, (1, big_h, 1, 1), (1, 1, 1, 1), "valid" | ||
| ) | ||
| big_pool_h = big_pool_h / big_h | ||
|
|
||
| combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=1) | ||
| pooled_h = jnp.take(combined_h, gather_h, axis=1) | ||
|
|
||
| small_pool_w = lax.reduce_window( | ||
| pooled_h, 0.0, lax.add, (1, 1, small_w, 1), (1, 1, 1, 1), "valid" | ||
| ) | ||
| small_pool_w = small_pool_w / small_w | ||
|
|
||
| big_pool_w = lax.reduce_window( | ||
| pooled_h, 0.0, lax.add, (1, 1, big_w, 1), (1, 1, 1, 1), "valid" | ||
| ) | ||
| big_pool_w = big_pool_w / big_w | ||
|
|
||
| combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=2) | ||
| pooled_w = jnp.take(combined_w, gather_w, axis=2) | ||
|
|
||
| if data_format == "channels_first": | ||
| pooled_w = jnp.transpose(pooled_w, (0, 3, 1, 2)) # NHWC -> NCHW | ||
|
|
||
| return pooled_w | ||
|
|
||
|
|
||
| def adaptive_max_pool2d(inputs, output_size, data_format="channels_first"): | ||
| """Adaptive Max Pooling 2D using Two-Pool Gather method.""" | ||
| if isinstance(output_size, int): | ||
| output_size = (output_size, output_size) | ||
|
|
||
| if data_format == "channels_first": | ||
| inputs = jnp.transpose(inputs, (0, 2, 3, 1)) # NCHW -> NHWC | ||
|
|
||
| n, h, w, c = inputs.shape | ||
| out_h, out_w = output_size | ||
|
|
||
| small_h, big_h = get_static_window_sizes(h, out_h) | ||
| gather_h = compute_static_gather_indices(h, out_h, big_h) | ||
|
|
||
| small_w, big_w = get_static_window_sizes(w, out_w) | ||
| gather_w = compute_static_gather_indices(w, out_w, big_w) | ||
|
|
||
| small_pool_h = lax.reduce_window( | ||
| inputs, -jnp.inf, lax.max, (1, small_h, 1, 1), (1, 1, 1, 1), "valid" | ||
| ) | ||
| big_pool_h = lax.reduce_window( | ||
| inputs, -jnp.inf, lax.max, (1, big_h, 1, 1), (1, 1, 1, 1), "valid" | ||
| ) | ||
|
|
||
| combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=1) | ||
| pooled_h = jnp.take(combined_h, gather_h, axis=1) | ||
|
|
||
| small_pool_w = lax.reduce_window( | ||
| pooled_h, -jnp.inf, lax.max, (1, 1, small_w, 1), (1, 1, 1, 1), "valid" | ||
| ) | ||
| big_pool_w = lax.reduce_window( | ||
| pooled_h, -jnp.inf, lax.max, (1, 1, big_w, 1), (1, 1, 1, 1), "valid" | ||
| ) | ||
|
|
||
| combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=2) | ||
| pooled_w = jnp.take(combined_w, gather_w, axis=2) | ||
|
|
||
| if data_format == "channels_first": | ||
| pooled_w = jnp.transpose(pooled_w, (0, 3, 1, 2)) # NHWC -> NCHW | ||
|
|
||
| return pooled_w | ||
|
|
||
|
|
||
| # ---------- 3D Adaptive Pooling ---------- | ||
| def adaptive_avg_pool3d(inputs, output_size, data_format="channels_first"): | ||
| """Adaptive Average Pooling 3D using Two-Pool Gather method.""" | ||
| if isinstance(output_size, int): | ||
| output_size = (output_size, output_size, output_size) | ||
|
|
||
| if data_format == "channels_first": | ||
| inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1)) # NCDHW -> NDHWC | ||
|
|
||
| n, d, h, w, c = inputs.shape | ||
| out_d, out_h, out_w = output_size | ||
|
|
||
| small_d, big_d = get_static_window_sizes(d, out_d) | ||
| gather_d = compute_static_gather_indices(d, out_d, big_d) | ||
|
|
||
| small_h, big_h = get_static_window_sizes(h, out_h) | ||
| gather_h = compute_static_gather_indices(h, out_h, big_h) | ||
|
|
||
| small_w, big_w = get_static_window_sizes(w, out_w) | ||
| gather_w = compute_static_gather_indices(w, out_w, big_w) | ||
|
|
||
| small_pool_d = lax.reduce_window( | ||
| inputs, 0.0, lax.add, (1, small_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid" | ||
| ) | ||
| small_pool_d = small_pool_d / small_d | ||
|
|
||
| big_pool_d = lax.reduce_window( | ||
| inputs, 0.0, lax.add, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid" | ||
| ) | ||
| big_pool_d = big_pool_d / big_d | ||
|
|
||
| combined_d = jnp.concatenate([small_pool_d, big_pool_d], axis=1) | ||
| pooled_d = jnp.take(combined_d, gather_d, axis=1) | ||
|
|
||
| small_pool_h = lax.reduce_window( | ||
| pooled_d, 0.0, lax.add, (1, 1, small_h, 1, 1), (1, 1, 1, 1, 1), "valid" | ||
| ) | ||
| small_pool_h = small_pool_h / small_h | ||
|
|
||
| big_pool_h = lax.reduce_window( | ||
| pooled_d, 0.0, lax.add, (1, 1, big_h, 1, 1), (1, 1, 1, 1, 1), "valid" | ||
| ) | ||
| big_pool_h = big_pool_h / big_h | ||
|
|
||
| combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=2) | ||
| pooled_h = jnp.take(combined_h, gather_h, axis=2) | ||
|
|
||
| small_pool_w = lax.reduce_window( | ||
| pooled_h, 0.0, lax.add, (1, 1, 1, small_w, 1), (1, 1, 1, 1, 1), "valid" | ||
| ) | ||
| small_pool_w = small_pool_w / small_w | ||
|
|
||
| big_pool_w = lax.reduce_window( | ||
| pooled_h, 0.0, lax.add, (1, 1, 1, big_w, 1), (1, 1, 1, 1, 1), "valid" | ||
| ) | ||
| big_pool_w = big_pool_w / big_w | ||
|
|
||
| combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=3) | ||
| pooled_w = jnp.take(combined_w, gather_w, axis=3) | ||
|
|
||
| if data_format == "channels_first": | ||
| pooled_w = jnp.transpose(pooled_w, (0, 4, 1, 2, 3)) # NDHWC -> NCDHW | ||
|
|
||
| return pooled_w | ||
|
|
||
|
|
||
| def adaptive_max_pool3d(inputs, output_size, data_format="channels_first"): | ||
| """Adaptive Max Pooling 3D using Two-Pool Gather method.""" | ||
| if isinstance(output_size, int): | ||
| output_size = (output_size, output_size, output_size) | ||
|
|
||
| if data_format == "channels_first": | ||
| inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1)) # NCDHW -> NDHWC | ||
|
|
||
| n, d, h, w, c = inputs.shape | ||
| out_d, out_h, out_w = output_size | ||
|
|
||
| small_d, big_d = get_static_window_sizes(d, out_d) | ||
| gather_d = compute_static_gather_indices(d, out_d, big_d) | ||
|
|
||
| small_h, big_h = get_static_window_sizes(h, out_h) | ||
| gather_h = compute_static_gather_indices(h, out_h, big_h) | ||
|
|
||
| small_w, big_w = get_static_window_sizes(w, out_w) | ||
| gather_w = compute_static_gather_indices(w, out_w, big_w) | ||
|
|
||
| small_pool_d = lax.reduce_window( | ||
| inputs, | ||
| -jnp.inf, | ||
| lax.max, | ||
| (1, small_d, 1, 1, 1), | ||
| (1, 1, 1, 1, 1), | ||
| "valid", | ||
| ) | ||
| big_pool_d = lax.reduce_window( | ||
| inputs, -jnp.inf, lax.max, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid" | ||
| ) | ||
|
|
||
| combined_d = jnp.concatenate([small_pool_d, big_pool_d], axis=1) | ||
| pooled_d = jnp.take(combined_d, gather_d, axis=1) | ||
|
|
||
| small_pool_h = lax.reduce_window( | ||
| pooled_d, | ||
| -jnp.inf, | ||
| lax.max, | ||
| (1, 1, small_h, 1, 1), | ||
| (1, 1, 1, 1, 1), | ||
| "valid", | ||
| ) | ||
| big_pool_h = lax.reduce_window( | ||
| pooled_d, | ||
| -jnp.inf, | ||
| lax.max, | ||
| (1, 1, big_h, 1, 1), | ||
| (1, 1, 1, 1, 1), | ||
| "valid", | ||
| ) | ||
|
|
||
| combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=2) | ||
| pooled_h = jnp.take(combined_h, gather_h, axis=2) | ||
|
|
||
| small_pool_w = lax.reduce_window( | ||
| pooled_h, | ||
| -jnp.inf, | ||
| lax.max, | ||
| (1, 1, 1, small_w, 1), | ||
| (1, 1, 1, 1, 1), | ||
| "valid", | ||
| ) | ||
| big_pool_w = lax.reduce_window( | ||
| pooled_h, | ||
| -jnp.inf, | ||
| lax.max, | ||
| (1, 1, 1, big_w, 1), | ||
| (1, 1, 1, 1, 1), | ||
| "valid", | ||
| ) | ||
|
|
||
| combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=3) | ||
| pooled_w = jnp.take(combined_w, gather_w, axis=3) | ||
|
|
||
| if data_format == "channels_first": | ||
| pooled_w = jnp.transpose(pooled_w, (0, 4, 1, 2, 3)) # NDHWC -> NCDHW | ||
|
|
||
| return pooled_w | ||
|
|
||
|
|
||
| # ---------- Dispatcher ---------- | ||
| def adaptive_avg_pool(inputs, output_size, data_format="channels_first"): | ||
| """Dispatcher for adaptive average pooling (1D, 2D, or 3D).""" | ||
| ndims = inputs.ndim - 2 | ||
| if ndims == 1: | ||
| return adaptive_avg_pool1d(inputs, output_size, data_format) | ||
| elif ndims == 2: | ||
| return adaptive_avg_pool2d(inputs, output_size, data_format) | ||
| elif ndims == 3: | ||
| return adaptive_avg_pool3d(inputs, output_size, data_format) | ||
| else: | ||
| raise ValueError( | ||
| "adaptive_avg_pool supports 1D, 2D, or 3D inputs only." | ||
| ) | ||
|
|
||
|
|
||
| def adaptive_max_pool(inputs, output_size, data_format="channels_first"): | ||
| """Dispatcher for adaptive max pooling (1D, 2D, or 3D).""" | ||
| ndims = inputs.ndim - 2 | ||
| if ndims == 1: | ||
| return adaptive_max_pool1d(inputs, output_size, data_format) | ||
| elif ndims == 2: | ||
| return adaptive_max_pool2d(inputs, output_size, data_format) | ||
| elif ndims == 3: | ||
| return adaptive_max_pool3d(inputs, output_size, data_format) | ||
| else: | ||
| raise ValueError( | ||
| "adaptive_max_pool supports 1D, 2D, or 3D inputs only." | ||
| ) | ||
|
Comment on lines
+1499
to
+1831
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The implementations for Here are a couple of suggestions:
Comment on lines
+1500
to
+1831
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The implementations for 1D, 2D, and 3D adaptive pooling for both Consider refactoring this by creating a generalized helper function. This function could handle the pooling logic for a single dimension and could be parameterized for average vs. max pooling. For example, you could have a helper: Then, the 2D and 3D functions can be implemented by composing this helper function for each spatial dimension. This would greatly reduce code duplication and improve maintainability. |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The variable names
n,l,c, andout_lare quite short. According to the Keras API design guidelines, it's preferred to use fully spelled-out names to improve readability, with a few common exceptions likedimandnum.1 Consider using more descriptive names likebatch_size,length,channels, andoutput_length. This comment also applies to the other adaptive pooling functions in this file.For example:
n, l, c = inputs.shape->batch_size, length, channels = inputs.shapeout_l = output_size[0]->output_length = output_size[0]Style Guide References
Footnotes
The style guide recommends using fully spelled-out names for variables and arguments to improve clarity, e.g.,
attention_scoresinstead ofattn_scores. Short names are acceptable only for very common terms likedimornum. ↩