@@ -1342,8 +1342,6 @@ defmodule EMLX.Backend do
13421342 end
13431343
13441344 defp window_op ( op , out , tensor , window_shape , opts ) do
1345- # TODO: window dilations can be implemented after we support internal padding
1346- # in Nx.pad (we should have pad_internal as a shared defp)
13471345 tensor_rank = tuple_size ( tensor . shape )
13481346
13491347 axes =
@@ -1356,6 +1354,17 @@ defmodule EMLX.Backend do
13561354 { low_pad , high_pad } = Enum . unzip ( opts [ :padding ] )
13571355 { device , _ } = t_mx = from_nx ( tensor )
13581356
1357+ window_dilations = opts [ :window_dilations ] || List . duplicate ( 1 , tuple_size ( window_shape ) )
1358+ interior_padding_config = Enum . map ( window_dilations , & { 0 , 0 , & 1 - 1 } )
1359+
1360+ window =
1361+ 1
1362+ |> EMLX . scalar_tensor ( :bool , device )
1363+ |> EMLX . broadcast_to ( window_shape )
1364+ |> interior_padding_mlx ( 0 , interior_padding_config )
1365+
1366+ window_shape = EMLX . shape ( window )
1367+
13591368 { _device , pad_mx } =
13601369 case op do
13611370 :sum ->
@@ -1375,6 +1384,8 @@ defmodule EMLX.Backend do
13751384
13761385 padded_mx
13771386 |> sliding_window_view ( EMLX . shape ( padded_mx ) , window_shape , opts [ :strides ] )
1387+ |> EMLX . broadcast_to ( window_shape )
1388+ |> EMLX . where ( window , & 1 )
13781389 |> then ( & apply ( EMLX , op , [ & 1 , axes , false ] ) )
13791390 |> to_nx ( out )
13801391 end
0 commit comments