@@ -598,25 +598,18 @@ defmodule EXLA.MLIR.Value do
598598 transform ,
599599 typespec
600600 ) do
601- a_typespec = get_typespec ( a )
602- b_typespect = get_typespec ( b )
603601
604- a_shape = a_typespec . shape
605- b_shape = b_typespect . shape
606-
607- # a_rank = tuple_size(a_shape)
608- # b_rank = tuple_size(b_shape)
609-
610- # if a_rank == 3 and b_rank == 2 do
611- # Convert {2, 3} → {2, 3, 1}
612602 new_b_shape = { 2 , 3 , 1 }
613-
614603 new_b_typespec = % { typespec | shape: new_b_shape }
615-
616604 b = reshape ( b , new_b_typespec )
617605
606+ expected_output_shape = { 2 , 3 , 1 }
607+ new_typespec = % { typespec | shape: expected_output_shape }
608+ result_types = typespecs_to_mlir_types ( [ new_typespec ] )
618609
619- result_types = typespecs_to_mlir_types ( [ typespec ] )
610+
611+
612+ # result_types = typespecs_to_mlir_types([typespec])
620613
621614 complex? = Nx.Type . complex? ( typespec . type )
622615
@@ -634,12 +627,13 @@ defmodule EXLA.MLIR.Value do
634627 transpose_a: transpose_a
635628 ]
636629
637- # a_shape = get_shape(a)
638- # b_shape = get_shape(b)
639630
640631
641632 op ( func , "stablehlo.triangular_solve" , [ a , b ] , result_types , attributes: attributes )
642633 |> one! ( )
634+ |> reshape ( % { typespec | shape: { 2 , 3 } } )
635+
636+
643637 end
644638
645639 def dynamic_update_slice ( % Value { function: func } = operand , updates , starts , typespec ) do
0 commit comments