@@ -1355,17 +1355,19 @@ defmodule EMLX.Backend do
13551355 { device , _ } = t_mx = from_nx ( tensor )
13561356
13571357 window_dilations = opts [ :window_dilations ] || List . duplicate ( 1 , tuple_size ( window_shape ) )
1358- interior_padding_config = Enum . map ( window_dilations , & { 0 , 0 , & 1 - 1 } )
1358+ interior_padding_config = Enum . map ( window_dilations , & ( & 1 - 1 ) )
1359+
1360+ { _device , zero_mx } = EMLX . scalar_tensor ( 0 , :bool , device )
13591361
13601362 window =
13611363 1
13621364 |> EMLX . scalar_tensor ( :bool , device )
13631365 |> EMLX . broadcast_to ( window_shape )
1364- |> interior_padding_mlx ( 0 , interior_padding_config )
1366+ |> interior_padding_mlx ( zero_mx , interior_padding_config )
13651367
13661368 window_shape = EMLX . shape ( window )
13671369
1368- { _device , pad_mx } =
1370+ { device , pad_mx } =
13691371 case op do
13701372 :sum ->
13711373 EMLX . scalar_tensor ( 0 , to_mlx_type ( out . type ) , device )
@@ -1384,8 +1386,7 @@ defmodule EMLX.Backend do
13841386
13851387 padded_mx
13861388 |> sliding_window_view ( EMLX . shape ( padded_mx ) , window_shape , opts [ :strides ] )
1387- |> EMLX . broadcast_to ( window_shape )
1388- |> EMLX . where ( window , & 1 )
1389+ |> then ( & EMLX . where ( window , & 1 , { device , pad_mx } ) )
13891390 |> then ( & apply ( EMLX , op , [ & 1 , axes , false ] ) )
13901391 |> to_nx ( out )
13911392 end
0 commit comments