Skip to content

Commit 7070161

Browse files
committed
implemented window dilations on EMLX.window_op()
1 parent 0df7a32 commit 7070161

File tree

2 files changed

+6
-12
lines changed

2 files changed

+6
-12
lines changed

lib/emlx/backend.ex

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1355,17 +1355,19 @@ defmodule EMLX.Backend do
13551355
{device, _} = t_mx = from_nx(tensor)
13561356

13571357
window_dilations = opts[:window_dilations] || List.duplicate(1, tuple_size(window_shape))
1358-
interior_padding_config = Enum.map(window_dilations, &{0, 0, &1 - 1})
1358+
interior_padding_config = Enum.map(window_dilations, &(&1 - 1))
1359+
1360+
{_device, zero_mx} = EMLX.scalar_tensor(0, :bool, device)
13591361

13601362
window =
13611363
1
13621364
|> EMLX.scalar_tensor(:bool, device)
13631365
|> EMLX.broadcast_to(window_shape)
1364-
|> interior_padding_mlx(0, interior_padding_config)
1366+
|> interior_padding_mlx(zero_mx, interior_padding_config)
13651367

13661368
window_shape = EMLX.shape(window)
13671369

1368-
{_device, pad_mx} =
1370+
{device, pad_mx} =
13691371
case op do
13701372
:sum ->
13711373
EMLX.scalar_tensor(0, to_mlx_type(out.type), device)
@@ -1384,8 +1386,7 @@ defmodule EMLX.Backend do
13841386

13851387
padded_mx
13861388
|> sliding_window_view(EMLX.shape(padded_mx), window_shape, opts[:strides])
1387-
|> EMLX.broadcast_to(window_shape)
1388-
|> EMLX.where(window, &1)
1389+
|> then(&EMLX.where(window, &1, {device, pad_mx}))
13891390
|> then(&apply(EMLX, op, [&1, axes, false]))
13901391
|> to_nx(out)
13911392
end

test/emlx/nx_doctest_test.exs

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,6 @@ defmodule EMLX.Nx.DoctestTest do
4747

4848
@to_be_fixed [
4949
:moduledoc,
50-
# window_* do not support window_dilations yet
51-
window_sum: 3,
52-
window_max: 3,
53-
window_min: 3,
54-
window_product: 3,
55-
window_mean: 3,
56-
# missing support for inner padding
5750
# MLX sorts NaNs lowest, Nx sorts them highest
5851
argmin: 2,
5952
argmax: 2,

0 commit comments

Comments
 (0)