@@ -147,10 +147,10 @@ def forward(self, x, y):
147147 stablehlo = self .run_func_get_stablehlo (M (), input_args )
148148 self .assertEqual (stablehlo .count ("stablehlo.composite \" test.sdpa\" " ), 2 )
149149 self .assertTrue (
150- '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0 }'
150+ '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl }'
151151 in stablehlo )
152152 self .assertTrue (
153- '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl }'
153+ '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl_0 }'
154154 in stablehlo )
155155
156156 def test_composite_builder_sdpa_pattern (self ):
@@ -175,10 +175,10 @@ def forward(self, x, y):
175175 stablehlo = self .run_func_get_stablehlo (M (), input_args )
176176 self .assertEqual (stablehlo .count ("stablehlo.composite \" test.sdpa\" " ), 2 )
177177 self .assertTrue (
178- '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0 }'
178+ '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl }'
179179 in stablehlo )
180180 self .assertTrue (
181- '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl }'
181+ '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl_0 }'
182182 in stablehlo )
183183
184184 def test_composite_builder_export_sdpa_pattern (self ):
@@ -208,10 +208,10 @@ def forward(self, x, y):
208208 stablehlo = stablehlo_gm .get_stablehlo_text ()
209209 self .assertEqual (stablehlo .count ("stablehlo.composite \" test.sdpa\" " ), 2 )
210210 self .assertTrue (
211- '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0 }'
211+ '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl }'
212212 in stablehlo )
213213 self .assertTrue (
214- '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl }'
214+ '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl_0 }'
215215 in stablehlo )
216216 if has_tf_package ():
217217 self .assertTrue (os .path .exists (os .path .join (tmp_path , 'saved_model.pb' )))
@@ -240,10 +240,10 @@ def forward(self, x, y):
240240 stablehlo = stablehlo_gm .get_stablehlo_text ()
241241 self .assertEqual (stablehlo .count ("stablehlo.composite \" test.sdpa\" " ), 2 )
242242 self .assertTrue (
243- '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl_0 }'
243+ '{composite_attributes = {scale = 2.500000e-01 : f32}, decomposition = @test.sdpa.impl }'
244244 in stablehlo )
245245 self .assertTrue (
246- '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl }'
246+ '{composite_attributes = {scale = 2 : i64}, decomposition = @test.sdpa.impl_0 }'
247247 in stablehlo )
248248 if has_tf_package ():
249249 self .assertTrue (os .path .exists (os .path .join (tmp_path , 'saved_model.pb' )))
0 commit comments