Skip to content

Commit ae104a2

Browse files
committed
test(exla): triangular_solve pass batched input test
1 parent 80cf775 commit ae104a2

File tree

1 file changed

+9
-15
lines changed

1 file changed

+9
-15
lines changed

exla/lib/exla/mlir/value.ex

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -598,25 +598,18 @@ defmodule EXLA.MLIR.Value do
598598
transform,
599599
typespec
600600
) do
601-
a_typespec = get_typespec(a)
602-
b_typespect = get_typespec(b)
603601

604-
a_shape = a_typespec.shape
605-
b_shape = b_typespect.shape
606-
607-
# a_rank = tuple_size(a_shape)
608-
# b_rank = tuple_size(b_shape)
609-
610-
# if a_rank == 3 and b_rank == 2 do
611-
# Convert {2, 3} → {2, 3, 1}
612602
new_b_shape = {2, 3, 1}
613-
614603
new_b_typespec = %{typespec | shape: new_b_shape}
615-
616604
b = reshape(b, new_b_typespec)
617605

606+
expected_output_shape = {2, 3, 1}
607+
new_typespec = %{typespec | shape: expected_output_shape}
608+
result_types = typespecs_to_mlir_types([new_typespec])
618609

619-
result_types = typespecs_to_mlir_types([typespec])
610+
611+
612+
# result_types = typespecs_to_mlir_types([typespec])
620613

621614
complex? = Nx.Type.complex?(typespec.type)
622615

@@ -634,12 +627,13 @@ defmodule EXLA.MLIR.Value do
634627
transpose_a: transpose_a
635628
]
636629

637-
# a_shape = get_shape(a)
638-
# b_shape = get_shape(b)
639630

640631

641632
op(func, "stablehlo.triangular_solve", [a, b], result_types, attributes: attributes)
642633
|> one!()
634+
|> reshape(%{typespec | shape: {2, 3}})
635+
636+
643637
end
644638

645639
def dynamic_update_slice(%Value{function: func} = operand, updates, starts, typespec) do

0 commit comments

Comments
 (0)