Skip to content

Commit 0df7a32

Browse files
committed
first iteration of window dilations implementation following Torchx steps
1 parent 6ab8207 commit 0df7a32

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

lib/emlx/backend.ex

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1342,8 +1342,6 @@ defmodule EMLX.Backend do
13421342
end
13431343

13441344
defp window_op(op, out, tensor, window_shape, opts) do
1345-
# TODO: window dilations can be implemented after we support internal padding
1346-
# in Nx.pad (we should have pad_internal as a shared defp)
13471345
tensor_rank = tuple_size(tensor.shape)
13481346

13491347
axes =
@@ -1356,6 +1354,17 @@ defmodule EMLX.Backend do
13561354
{low_pad, high_pad} = Enum.unzip(opts[:padding])
13571355
{device, _} = t_mx = from_nx(tensor)
13581356

1357+
window_dilations = opts[:window_dilations] || List.duplicate(1, tuple_size(window_shape))
1358+
interior_padding_config = Enum.map(window_dilations, &{0, 0, &1 - 1})
1359+
1360+
window =
1361+
1
1362+
|> EMLX.scalar_tensor(:bool, device)
1363+
|> EMLX.broadcast_to(window_shape)
1364+
|> interior_padding_mlx(0, interior_padding_config)
1365+
1366+
window_shape = EMLX.shape(window)
1367+
13591368
{_device, pad_mx} =
13601369
case op do
13611370
:sum ->
@@ -1375,6 +1384,8 @@ defmodule EMLX.Backend do
13751384

13761385
padded_mx
13771386
|> sliding_window_view(EMLX.shape(padded_mx), window_shape, opts[:strides])
1387+
|> EMLX.broadcast_to(window_shape)
1388+
|> EMLX.where(window, &1)
13781389
|> then(&apply(EMLX, op, [&1, axes, false]))
13791390
|> to_nx(out)
13801391
end

0 commit comments

Comments
 (0)