Skip to content

Commit 409fdf7

Browse files
authored
feat: new syrk passes + lowering (#1908)
* feat: new syrk passes + lowering [skip ci] * test: raising to syrk * feat: more passes * chore: run formatting * fix: dont accidentally raise after fallback lowering * chore: bump versions * chore: bump versions
1 parent d635c44 commit 409fdf7

File tree

4 files changed

+67
-5
lines changed

4 files changed

+67
-5
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>", "Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>", "Paul Berg <paul@plutojl.org>", "Avik Pal <avikpal@mit.edu>", "Mosè Giordano <mose@gnu.org>"]
4-
version = "0.2.180"
4+
version = "0.2.181"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -105,7 +105,7 @@ PythonCall = "0.9.25"
105105
Random = "1.10"
106106
Random123 = "1.7"
107107
ReactantCore = "0.1.16"
108-
Reactant_jll = "0.0.265"
108+
Reactant_jll = "0.0.267"
109109
ScopedValues = "1.3.0"
110110
Scratch = "1.2"
111111
Sockets = "1.10"

src/Compiler.jl

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -703,6 +703,8 @@ function optimization_passes(
703703
recognize_comms::Bool=true,
704704
lower_comms::Bool=true,
705705
backend::String="gpu",
706+
is_sharded::Bool=false,
707+
raise_shlo_to_blas_lapack::Bool=true,
706708
)
707709
(; max_constant_threshold) = compile_options
708710

@@ -909,8 +911,19 @@ function optimization_passes(
909911
"transpose_symmetric_simplify",
910912
"divide_negated_operands_simplify",
911913
"multiply_negated_operands_simplify",
914+
"transpose_syrk_to_syrk",
915+
"fuse_mul_into_syrk",
916+
"fuse_add_into_syrk",
917+
"factor_scalars_in_dot_general",
912918
]
913919

920+
if !is_sharded
921+
# these passes don't have optimized sharding implementations
922+
if raise_shlo_to_blas_lapack
923+
append!(transform_passes_list, ["dot_general_to_syrk"])
924+
end
925+
end
926+
914927
if !compile_options.disable_auto_batching_passes
915928
append!(
916929
transform_passes_list,
@@ -1693,10 +1706,10 @@ function compile_mlir!(
16931706
end
16941707

16951708
opt_passes = optimization_passes(
1696-
compile_options; sroa=true, recognize_comms, lower_comms, backend
1709+
compile_options; sroa=true, recognize_comms, lower_comms, backend, is_sharded
16971710
)
16981711
opt_passes2 = optimization_passes(
1699-
compile_options; sroa=false, recognize_comms, lower_comms, backend
1712+
compile_options; sroa=false, recognize_comms, lower_comms, backend, is_sharded
17001713
)
17011714

17021715
raise_passes = if raise isa String
@@ -1718,6 +1731,7 @@ function compile_mlir!(
17181731
recognize_comms,
17191732
lower_comms,
17201733
backend,
1734+
is_sharded,
17211735
)
17221736
result = result * "," * opt_passes3
17231737
end
@@ -1728,6 +1742,8 @@ function compile_mlir!(
17281742

17291743
blas_int_width = sizeof(BlasInt) * 8
17301744
lower_enzymexla_linalg_pass = "lower-enzymexla-linalg{backend=$backend \
1745+
blas_int_width=$blas_int_width},\
1746+
lower-enzymexla-blas{backend=$backend \
17311747
blas_int_width=$blas_int_width},\
17321748
lower-enzymexla-lapack{backend=$backend \
17331749
blas_int_width=$blas_int_width}"
@@ -2012,6 +2028,8 @@ function compile_mlir!(
20122028
recognize_comms,
20132029
lower_comms,
20142030
backend,
2031+
is_sharded,
2032+
raise_shlo_to_blas_lapack=false,
20152033
),
20162034
"post_op_transpose_reshape",
20172035
)
@@ -2154,7 +2172,15 @@ function compile_mlir!(
21542172
run_pass_pipeline!(
21552173
mod,
21562174
join(
2157-
[opt_passes, "canonicalize", "cse", "canonicalize", opt_passes2],
2175+
[
2176+
opt_passes,
2177+
"canonicalize",
2178+
"cse",
2179+
"canonicalize",
2180+
opt_passes2,
2181+
lower_enzymexla_linalg_pass,
2182+
jit,
2183+
],
21582184
",",
21592185
),
21602186
"mid_pad_opts",

src/stdlibs/LinearAlgebra.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,21 @@ function __init__()
4646
(BLAS.@blasfunc(dgesvj_), :enzymexla_lapack_dgesvj_),
4747
(BLAS.@blasfunc(cgesvj_), :enzymexla_lapack_cgesvj_),
4848
(BLAS.@blasfunc(zgesvj_), :enzymexla_lapack_zgesvj_),
49+
# syrk
50+
(BLAS.@blasfunc(ssyrk_), :enzymexla_blas_ssyrk_),
51+
(BLAS.@blasfunc(dsyrk_), :enzymexla_blas_dsyrk_),
52+
(BLAS.@blasfunc(csyrk_), :enzymexla_blas_csyrk_),
53+
(BLAS.@blasfunc(zsyrk_), :enzymexla_blas_zsyrk_),
54+
# trmm
55+
(BLAS.@blasfunc(strmm_), :enzymexla_blas_strmm_),
56+
(BLAS.@blasfunc(dtrmm_), :enzymexla_blas_dtrmm_),
57+
(BLAS.@blasfunc(ctrmm_), :enzymexla_blas_ctrmm_),
58+
(BLAS.@blasfunc(ztrmm_), :enzymexla_blas_ztrmm_),
59+
# symm
60+
(BLAS.@blasfunc(ssymm_), :enzymexla_blas_ssymm_),
61+
(BLAS.@blasfunc(dsymm_), :enzymexla_blas_dsymm_),
62+
(BLAS.@blasfunc(csymm_), :enzymexla_blas_csymm_),
63+
(BLAS.@blasfunc(zsymm_), :enzymexla_blas_zsymm_),
4964
]
5065
sym = Libdl.dlsym(libblastrampoline_handle, cname)
5166
@ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(

test/integration/linear_algebra.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,3 +723,24 @@ end
723723
@jit LinearAlgebra.normalize!(x_ra)
724724
@test x_ra x
725725
end
726+
727+
raise_to_syrk(x, y) = 3 .* (x * transpose(x)) .+ 5 .* y
728+
raise_to_syrk2(x, y) = 3 .* (transpose(x) * x) .+ 5 .* y
729+
730+
@testset "syrk optimizations" begin
731+
@testset for elty in (Float32, Float64, ComplexF32, ComplexF64)
732+
x = Reactant.TestUtils.construct_test_array(elty, 4, 5)
733+
y1 = Reactant.TestUtils.construct_test_array(elty, 4, 4)
734+
y2 = Reactant.TestUtils.construct_test_array(elty, 5, 5)
735+
x_ra = Reactant.to_rarray(x)
736+
737+
@testset for (fn, y) in ((raise_to_syrk, y1), (raise_to_syrk2, y2))
738+
y_ra = Reactant.to_rarray(y)
739+
740+
hlo = @code_hlo optimize = :before_jit fn(x_ra, y_ra)
741+
@test occursin("enzymexla.blas.syrk", repr(hlo))
742+
743+
@test @jit(fn(x_ra, y_ra)) fn(x, y)
744+
end
745+
end
746+
end

0 commit comments

Comments
 (0)