diff --git a/Cargo.lock b/Cargo.lock index 9dce64ce66ab6..679e2866c7d15 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3574,6 +3574,7 @@ dependencies = [ "gimli 0.31.1", "itertools", "libc", + "libloading 0.9.0", "measureme", "object 0.37.3", "rustc-demangle", diff --git a/compiler/rustc_codegen_llvm/Cargo.toml b/compiler/rustc_codegen_llvm/Cargo.toml index 67bd1e59bb0c2..5eb65c01b4b8d 100644 --- a/compiler/rustc_codegen_llvm/Cargo.toml +++ b/compiler/rustc_codegen_llvm/Cargo.toml @@ -14,6 +14,7 @@ bitflags = "2.4.1" gimli = "0.31" itertools = "0.12" libc = "0.2" +libloading = "0.9.0" measureme = "12.0.1" object = { version = "0.37.0", default-features = false, features = ["std", "read"] } rustc-demangle = "0.1.21" diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index b820b992105fd..c2477f0f53160 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -528,31 +528,35 @@ fn thin_lto( } } -fn enable_autodiff_settings(ad: &[config::AutoDiff]) { +fn enable_autodiff_settings(cgcx: &CodegenContext, ad: &[config::AutoDiff]) { + // Initialize Enzyme if not already done (idempotent due to OnceLock) + // This ensures it works even if LlvmCodegenBackend::init() didn't run it + let mut enzyme = llvm::EnzymeWrapper::get_or_init(&cgcx.sysroot); + for val in ad { // We intentionally don't use a wildcard, to not forget handling anything new. match val { config::AutoDiff::PrintPerf => { - llvm::set_print_perf(true); + enzyme.set_print_perf(true); } config::AutoDiff::PrintAA => { - llvm::set_print_activity(true); + enzyme.set_print_activity(true); } config::AutoDiff::PrintTA => { - llvm::set_print_type(true); + enzyme.set_print_type(true); } config::AutoDiff::PrintTAFn(fun) => { - llvm::set_print_type(true); // Enable general type printing - llvm::set_print_type_fun(&fun); // Set specific function to analyze + enzyme.set_print_type(true); // Enable general type printing + enzyme.set_print_type_fun(&fun); // Set specific function to analyze } config::AutoDiff::Inline => { - llvm::set_inline(true); + enzyme.set_inline(true); } config::AutoDiff::LooseTypes => { - llvm::set_loose_types(true); + enzyme.set_loose_types(true); } config::AutoDiff::PrintSteps => { - llvm::set_print(true); + enzyme.set_print(true); } // We handle this in the PassWrapper.cpp config::AutoDiff::PrintPasses => {} @@ -571,9 +575,9 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) { } } // This helps with handling enums for now. - llvm::set_strict_aliasing(false); + enzyme.set_strict_aliasing(false); // FIXME(ZuseZ4): Test this, since it was added a long time ago. - llvm::set_rust_rules(true); + enzyme.set_rust_rules(true); } pub(crate) fn run_pass_manager( @@ -609,7 +613,7 @@ pub(crate) fn run_pass_manager( }; if enable_ad { - enable_autodiff_settings(&config.autodiff); + enable_autodiff_settings(&cgcx, &config.autodiff); } unsafe { diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index fde7dd6ef7a85..aa33aaba743d9 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -726,6 +726,14 @@ pub(crate) unsafe fn llvm_optimize( let llvm_plugins = config.llvm_plugins.join(","); + let enzyme_fn = if consider_ad { + // Initialize Enzyme if not already done (idempotent due to OnceLock) + let wrapper = llvm::EnzymeWrapper::get_or_init(&cgcx.sysroot); + wrapper.registerEnzymeAndPassPipeline + } else { + std::ptr::null() + }; + let result = unsafe { llvm::LLVMRustOptimize( module.module_llvm.llmod(), @@ -745,7 +753,7 @@ pub(crate) unsafe fn llvm_optimize( vectorize_loop, config.no_builtins, config.emit_lifetime_markers, - run_enzyme, + enzyme_fn, print_before_enzyme, print_after_enzyme, print_passes, diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index b10a1282f4dd0..7b33f7e1b882e 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -1099,7 +1099,13 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { // vs. copying a struct with mixed types requires different derivative handling. // The TypeTree tells Enzyme exactly what memory layout to expect. if let Some(tt) = tt { - crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt); + crate::typetree::add_tt( + self.cx().llmod, + self.cx().llcx, + memcpy, + tt, + &self.cx().tcx.sess.opts.sysroot, + ); } } diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 4b433e2b63616..7ffd8762114e0 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -375,7 +375,13 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>( ); if !fnc_tree.args.is_empty() || !fnc_tree.ret.0.is_empty() { - crate::typetree::add_tt(cx.llmod, cx.llcx, fn_to_diff, fnc_tree); + crate::typetree::add_tt( + cx.llmod, + cx.llcx, + fn_to_diff, + fnc_tree, + &builder.cx().tcx.sess.opts.sysroot, + ); } let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None); diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index 1b65a133d58c1..e78da3fb83c4d 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -41,7 +41,7 @@ use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; use rustc_middle::ty::TyCtxt; use rustc_middle::util::Providers; use rustc_session::Session; -use rustc_session::config::{OptLevel, OutputFilenames, PrintKind, PrintRequest}; +use rustc_session::config::{AutoDiff, OptLevel, OutputFilenames, PrintKind, PrintRequest}; use rustc_span::Symbol; use rustc_target::spec::{RelocModel, TlsModel}; @@ -240,6 +240,13 @@ impl CodegenBackend for LlvmCodegenBackend { fn init(&self, sess: &Session) { llvm_util::init(sess); // Make sure llvm is inited + + if sess.opts.unstable_opts.autodiff.contains(&AutoDiff::Enable) { + #[cfg(feature = "llvm_enzyme")] + { + drop(llvm::EnzymeWrapper::get_or_init(&sess.opts.sysroot)); + } + } } fn provide(&self, providers: &mut Providers) { diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index e63043b21227f..da000b9a37bec 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -1,6 +1,11 @@ #![expect(dead_code)] use libc::{c_char, c_uint}; +// I am going to delete this declaration. +// I have just added this to avoid conflicting to main branch and to let CI run. +// I will make libloading as optional later, before merging this PR. +#[cfg(not(feature = "llvm_enzyme"))] +use libloading as _; use super::MetadataKindId; use super::ffi::{AttributeKind, BasicBlock, Context, Metadata, Module, Type, Value}; @@ -91,102 +96,364 @@ pub(crate) use self::Enzyme_AD::*; #[cfg(feature = "llvm_enzyme")] pub(crate) mod Enzyme_AD { - use std::ffi::{CString, c_char}; + use std::ffi::{c_char, c_void}; + use std::sync::{Mutex, MutexGuard, OnceLock}; - use libc::c_void; + use rustc_middle::bug; + use rustc_session::config::{Sysroot, host_tuple}; + use rustc_session::filesearch; use super::{CConcreteType, CTypeTreeRef, Context}; - - unsafe extern "C" { - pub(crate) fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8); - pub(crate) fn EnzymeSetCLString(arg1: *mut ::std::os::raw::c_void, arg2: *const c_char); + use crate::llvm::{EnzymeTypeTree, LLVMRustVersionMajor}; + + type EnzymeSetCLBoolFn = unsafe extern "C" fn(*mut c_void, u8); + type EnzymeSetCLStringFn = unsafe extern "C" fn(*mut c_void, *const c_char); + + type EnzymeNewTypeTreeFn = unsafe extern "C" fn() -> CTypeTreeRef; + type EnzymeNewTypeTreeCTFn = unsafe extern "C" fn(CConcreteType, &Context) -> CTypeTreeRef; + type EnzymeNewTypeTreeTRFn = unsafe extern "C" fn(CTypeTreeRef) -> CTypeTreeRef; + type EnzymeFreeTypeTreeFn = unsafe extern "C" fn(CTypeTreeRef); + type EnzymeMergeTypeTreeFn = unsafe extern "C" fn(CTypeTreeRef, CTypeTreeRef) -> bool; + type EnzymeTypeTreeOnlyEqFn = unsafe extern "C" fn(CTypeTreeRef, i64); + type EnzymeTypeTreeData0EqFn = unsafe extern "C" fn(CTypeTreeRef); + type EnzymeTypeTreeShiftIndiciesEqFn = + unsafe extern "C" fn(CTypeTreeRef, *const c_char, i64, i64, u64); + type EnzymeTypeTreeInsertEqFn = + unsafe extern "C" fn(CTypeTreeRef, *const i64, usize, CConcreteType, &Context); + type EnzymeTypeTreeToStringFn = unsafe extern "C" fn(CTypeTreeRef) -> *const c_char; + type EnzymeTypeTreeToStringFreeFn = unsafe extern "C" fn(*const c_char); + + #[allow(non_snake_case)] + pub(crate) struct EnzymeWrapper { + EnzymeNewTypeTree: EnzymeNewTypeTreeFn, + EnzymeNewTypeTreeCT: EnzymeNewTypeTreeCTFn, + EnzymeNewTypeTreeTR: EnzymeNewTypeTreeTRFn, + EnzymeFreeTypeTree: EnzymeFreeTypeTreeFn, + EnzymeMergeTypeTree: EnzymeMergeTypeTreeFn, + EnzymeTypeTreeOnlyEq: EnzymeTypeTreeOnlyEqFn, + EnzymeTypeTreeData0Eq: EnzymeTypeTreeData0EqFn, + EnzymeTypeTreeShiftIndiciesEq: EnzymeTypeTreeShiftIndiciesEqFn, + EnzymeTypeTreeInsertEq: EnzymeTypeTreeInsertEqFn, + EnzymeTypeTreeToString: EnzymeTypeTreeToStringFn, + EnzymeTypeTreeToStringFree: EnzymeTypeTreeToStringFreeFn, + + EnzymePrintPerf: *mut c_void, + EnzymePrintActivity: *mut c_void, + EnzymePrintType: *mut c_void, + EnzymeFunctionToAnalyze: *mut c_void, + EnzymePrint: *mut c_void, + EnzymeStrictAliasing: *mut c_void, + EnzymeInline: *mut c_void, + EnzymeMaxTypeDepth: *mut c_void, + RustTypeRules: *mut c_void, + looseTypeAnalysis: *mut c_void, + + EnzymeSetCLBool: EnzymeSetCLBoolFn, + EnzymeSetCLString: EnzymeSetCLStringFn, + pub registerEnzymeAndPassPipeline: *const c_void, + lib: libloading::Library, + } + + unsafe impl Sync for EnzymeWrapper {} + unsafe impl Send for EnzymeWrapper {} + + fn load_ptr_by_symbol_mut_void( + lib: &libloading::Library, + bytes: &[u8], + ) -> Result<*mut c_void, Box> { + unsafe { + let s: libloading::Symbol<'_, *mut c_void> = lib.get(bytes)?; + // libloading = 0.9.0: try_as_raw_ptr always succeeds and returns Some + let s = s.try_as_raw_ptr().unwrap(); + Ok(s) + } } - // TypeTree functions - unsafe extern "C" { - pub(crate) fn EnzymeNewTypeTree() -> CTypeTreeRef; - pub(crate) fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef; - pub(crate) fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef; - pub(crate) fn EnzymeFreeTypeTree(CTT: CTypeTreeRef); - pub(crate) fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool; - pub(crate) fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64); - pub(crate) fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef); - pub(crate) fn EnzymeTypeTreeShiftIndiciesEq( - arg1: CTypeTreeRef, + // e.g. + // load_ptrs_by_symbols_mut_void(ABC, XYZ); + // => + // let ABC = load_ptr_mut_void(&lib, b"ABC")?; + // let XYZ = load_ptr_mut_void(&lib, b"XYZ")?; + macro_rules! load_ptrs_by_symbols_mut_void { + ($lib:expr, $($name:ident),* $(,)?) => { + $( + #[allow(non_snake_case)] + let $name = load_ptr_by_symbol_mut_void(&$lib, stringify!($name).as_bytes())?; + )* + }; + } + + // e.g. + // load_ptrs_by_symbols_fn(ABC: ABCFn, XYZ: XYZFn); + // => + // let ABC: libloading::Symbol<'_, ABCFn> = unsafe { lib.get(b"ABC")? }; + // let XYZ: libloading::Symbol<'_, XYZFn> = unsafe { lib.get(b"XYZ")? }; + macro_rules! load_ptrs_by_symbols_fn { + ($lib:expr, $($name:ident : $ty:ty),* $(,)?) => { + $( + #[allow(non_snake_case)] + let $name: $ty = *unsafe { $lib.get::<$ty>(stringify!($name).as_bytes())? }; + )* + }; + } + + static ENZYME_INSTANCE: OnceLock> = OnceLock::new(); + + impl EnzymeWrapper { + /// Initialize EnzymeWrapper with the given sysroot if not already initialized. + /// Safe to call multiple times - subsequent calls are no-ops due to OnceLock. + pub(crate) fn get_or_init( + sysroot: &rustc_session::config::Sysroot, + ) -> MutexGuard<'static, Self> { + ENZYME_INSTANCE + .get_or_init(|| { + Self::call_dynamic(sysroot) + .unwrap_or_else(|e| bug!("failed to load Enzyme: {e}")) + .into() + }) + .lock() + .unwrap() + } + + /// Get the EnzymeWrapper instance. Panics if not initialized. + /// Call get_or_init with a sysroot first. + pub(crate) fn get_instance() -> MutexGuard<'static, Self> { + ENZYME_INSTANCE + .get() + .expect("EnzymeWrapper not initialized. Call get_or_init with sysroot first.") + .lock() + .unwrap() + } + + pub(crate) fn new_type_tree(&self) -> CTypeTreeRef { + unsafe { (self.EnzymeNewTypeTree)() } + } + + pub(crate) fn new_type_tree_ct( + &self, + t: CConcreteType, + ctx: &Context, + ) -> *mut EnzymeTypeTree { + unsafe { (self.EnzymeNewTypeTreeCT)(t, ctx) } + } + + pub(crate) fn new_type_tree_tr(&self, tree: CTypeTreeRef) -> CTypeTreeRef { + unsafe { (self.EnzymeNewTypeTreeTR)(tree) } + } + + pub(crate) fn free_type_tree(&self, tree: CTypeTreeRef) { + unsafe { (self.EnzymeFreeTypeTree)(tree) } + } + + pub(crate) fn merge_type_tree(&self, tree1: CTypeTreeRef, tree2: CTypeTreeRef) -> bool { + unsafe { (self.EnzymeMergeTypeTree)(tree1, tree2) } + } + + pub(crate) fn tree_only_eq(&self, tree: CTypeTreeRef, num: i64) { + unsafe { (self.EnzymeTypeTreeOnlyEq)(tree, num) } + } + + pub(crate) fn tree_data0_eq(&self, tree: CTypeTreeRef) { + unsafe { (self.EnzymeTypeTreeData0Eq)(tree) } + } + + pub(crate) fn shift_indicies_eq( + &self, + tree: CTypeTreeRef, data_layout: *const c_char, offset: i64, max_size: i64, add_offset: u64, - ); - pub(crate) fn EnzymeTypeTreeInsertEq( - CTT: CTypeTreeRef, + ) { + unsafe { + (self.EnzymeTypeTreeShiftIndiciesEq)( + tree, + data_layout, + offset, + max_size, + add_offset, + ) + } + } + + pub(crate) fn tree_insert_eq( + &self, + tree: CTypeTreeRef, indices: *const i64, len: usize, ct: CConcreteType, ctx: &Context, - ); - pub(crate) fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char; - pub(crate) fn EnzymeTypeTreeToStringFree(arg1: *const c_char); - } + ) { + unsafe { (self.EnzymeTypeTreeInsertEq)(tree, indices, len, ct, ctx) } + } - unsafe extern "C" { - static mut EnzymePrintPerf: c_void; - static mut EnzymePrintActivity: c_void; - static mut EnzymePrintType: c_void; - static mut EnzymeFunctionToAnalyze: c_void; - static mut EnzymePrint: c_void; - static mut EnzymeStrictAliasing: c_void; - static mut looseTypeAnalysis: c_void; - static mut EnzymeInline: c_void; - static mut RustTypeRules: c_void; - } - pub(crate) fn set_print_perf(print: bool) { - unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintPerf), print as u8); + pub(crate) fn tree_to_string(&self, tree: *mut EnzymeTypeTree) -> *const c_char { + unsafe { (self.EnzymeTypeTreeToString)(tree) } } - } - pub(crate) fn set_print_activity(print: bool) { - unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintActivity), print as u8); + + pub(crate) fn tree_to_string_free(&self, ch: *const c_char) { + unsafe { (self.EnzymeTypeTreeToStringFree)(ch) } } - } - pub(crate) fn set_print_type(print: bool) { - unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8); + + pub(crate) fn get_max_type_depth(&self) -> usize { + unsafe { std::ptr::read::(self.EnzymeMaxTypeDepth as *const u32) as usize } } - } - pub(crate) fn set_print_type_fun(fun_name: &str) { - let c_fun_name = CString::new(fun_name).unwrap(); - unsafe { - EnzymeSetCLString( - std::ptr::addr_of_mut!(EnzymeFunctionToAnalyze), - c_fun_name.as_ptr() as *const c_char, - ); + + pub(crate) fn set_print_perf(&mut self, print: bool) { + unsafe { + (self.EnzymeSetCLBool)(self.EnzymePrintPerf, print as u8); + } } - } - pub(crate) fn set_print(print: bool) { - unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as u8); + + pub(crate) fn set_print_activity(&mut self, print: bool) { + unsafe { + (self.EnzymeSetCLBool)(self.EnzymePrintActivity, print as u8); + } } - } - pub(crate) fn set_strict_aliasing(strict: bool) { - unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeStrictAliasing), strict as u8); + + pub(crate) fn set_print_type(&mut self, print: bool) { + unsafe { + (self.EnzymeSetCLBool)(self.EnzymePrintType, print as u8); + } } - } - pub(crate) fn set_loose_types(loose: bool) { - unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(looseTypeAnalysis), loose as u8); + + pub(crate) fn set_print_type_fun(&mut self, fun_name: &str) { + let c_fun_name = std::ffi::CString::new(fun_name) + .unwrap_or_else(|err| bug!("failed to set_print_type_fun: {err}")); + unsafe { + (self.EnzymeSetCLString)( + self.EnzymeFunctionToAnalyze, + c_fun_name.as_ptr() as *const c_char, + ); + } } - } - pub(crate) fn set_inline(val: bool) { - unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeInline), val as u8); + + pub(crate) fn set_print(&mut self, print: bool) { + unsafe { + (self.EnzymeSetCLBool)(self.EnzymePrint, print as u8); + } } - } - pub(crate) fn set_rust_rules(val: bool) { - unsafe { - EnzymeSetCLBool(std::ptr::addr_of_mut!(RustTypeRules), val as u8); + + pub(crate) fn set_strict_aliasing(&mut self, strict: bool) { + unsafe { + (self.EnzymeSetCLBool)(self.EnzymeStrictAliasing, strict as u8); + } + } + + pub(crate) fn set_loose_types(&mut self, loose: bool) { + unsafe { + (self.EnzymeSetCLBool)(self.looseTypeAnalysis, loose as u8); + } + } + + pub(crate) fn set_inline(&mut self, val: bool) { + unsafe { + (self.EnzymeSetCLBool)(self.EnzymeInline, val as u8); + } + } + + pub(crate) fn set_rust_rules(&mut self, val: bool) { + unsafe { + (self.EnzymeSetCLBool)(self.RustTypeRules, val as u8); + } + } + + #[allow(non_snake_case)] + fn call_dynamic( + sysroot: &rustc_session::config::Sysroot, + ) -> Result> { + let enzyme_path = Self::get_enzyme_path(sysroot)?; + let lib = unsafe { libloading::Library::new(enzyme_path)? }; + + load_ptrs_by_symbols_fn!( + lib, + EnzymeNewTypeTree: EnzymeNewTypeTreeFn, + EnzymeNewTypeTreeCT: EnzymeNewTypeTreeCTFn, + EnzymeNewTypeTreeTR: EnzymeNewTypeTreeTRFn, + EnzymeFreeTypeTree: EnzymeFreeTypeTreeFn, + EnzymeMergeTypeTree: EnzymeMergeTypeTreeFn, + EnzymeTypeTreeOnlyEq: EnzymeTypeTreeOnlyEqFn, + EnzymeTypeTreeData0Eq: EnzymeTypeTreeData0EqFn, + EnzymeTypeTreeShiftIndiciesEq: EnzymeTypeTreeShiftIndiciesEqFn, + EnzymeTypeTreeInsertEq: EnzymeTypeTreeInsertEqFn, + EnzymeTypeTreeToString: EnzymeTypeTreeToStringFn, + EnzymeTypeTreeToStringFree: EnzymeTypeTreeToStringFreeFn, + EnzymeSetCLBool: EnzymeSetCLBoolFn, + EnzymeSetCLString: EnzymeSetCLStringFn, + ); + + load_ptrs_by_symbols_mut_void!( + lib, + registerEnzymeAndPassPipeline, + EnzymePrintPerf, + EnzymePrintActivity, + EnzymePrintType, + EnzymeFunctionToAnalyze, + EnzymePrint, + EnzymeStrictAliasing, + EnzymeInline, + EnzymeMaxTypeDepth, + RustTypeRules, + looseTypeAnalysis, + ); + + Ok(Self { + EnzymeNewTypeTree, + EnzymeNewTypeTreeCT, + EnzymeNewTypeTreeTR, + EnzymeFreeTypeTree, + EnzymeMergeTypeTree, + EnzymeTypeTreeOnlyEq, + EnzymeTypeTreeData0Eq, + EnzymeTypeTreeShiftIndiciesEq, + EnzymeTypeTreeInsertEq, + EnzymeTypeTreeToString, + EnzymeTypeTreeToStringFree, + EnzymePrintPerf, + EnzymePrintActivity, + EnzymePrintType, + EnzymeFunctionToAnalyze, + EnzymePrint, + EnzymeStrictAliasing, + EnzymeInline, + EnzymeMaxTypeDepth, + RustTypeRules, + looseTypeAnalysis, + EnzymeSetCLBool, + EnzymeSetCLString, + registerEnzymeAndPassPipeline, + lib, + }) + } + + fn get_enzyme_path(sysroot: &Sysroot) -> Result { + let llvm_version_major = unsafe { LLVMRustVersionMajor() }; + + let path_buf = sysroot + .all_paths() + .map(|sysroot_path| { + filesearch::make_target_lib_path(sysroot_path, host_tuple()) + .join("lib") + .with_file_name(format!("libEnzyme-{llvm_version_major}")) + .with_extension(std::env::consts::DLL_EXTENSION) + }) + .find(|f| f.exists()) + .ok_or_else(|| { + let candidates = sysroot + .all_paths() + .map(|p| p.join("lib").display().to_string()) + .collect::>() + .join("\n* "); + format!( + "failed to find a `libEnzyme-{llvm_version_major}` folder \ + in the sysroot candidates:\n* {candidates}" + ) + })?; + + Ok(path_buf + .to_str() + .ok_or_else(|| format!("invalid UTF-8 in path: {}", path_buf.display()))? + .to_string()) } } } @@ -198,111 +465,150 @@ pub(crate) use self::Fallback_AD::*; pub(crate) mod Fallback_AD { #![allow(unused_variables)] + use std::ffi::c_void; + use std::sync::Mutex; + use libc::c_char; + use rustc_codegen_ssa::back::write::CodegenContext; + use rustc_codegen_ssa::traits::WriteBackendMethods; - use super::{CConcreteType, CTypeTreeRef, Context}; + use super::{CConcreteType, CTypeTreeRef, Context, EnzymeTypeTree}; - // TypeTree function fallbacks - pub(crate) unsafe fn EnzymeNewTypeTree() -> CTypeTreeRef { - unimplemented!() + pub(crate) struct EnzymeWrapper { + pub registerEnzymeAndPassPipeline: *const c_void, } - pub(crate) unsafe fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef { - unimplemented!() - } + impl EnzymeWrapper { + pub(crate) fn init<'a, B: WriteBackendMethods>( + _cgcx: &'a CodegenContext, + ) -> &'static Mutex { + unimplemented!("Enzyme not available: build with llvm_enzyme feature") + } - pub(crate) unsafe fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef { - unimplemented!() - } + pub(crate) fn get_instance() -> &'static Mutex { + unimplemented!("Enzyme not available: build with llvm_enzyme feature") + } - pub(crate) unsafe fn EnzymeFreeTypeTree(CTT: CTypeTreeRef) { - unimplemented!() - } + pub(crate) fn new_type_tree(&self) -> CTypeTreeRef { + unimplemented!() + } - pub(crate) unsafe fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool { - unimplemented!() - } + pub(crate) fn new_type_tree_ct( + &self, + t: CConcreteType, + ctx: &Context, + ) -> *mut EnzymeTypeTree { + unimplemented!() + } - pub(crate) unsafe fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64) { - unimplemented!() - } + pub(crate) fn new_type_tree_tr(&self, tree: CTypeTreeRef) -> CTypeTreeRef { + unimplemented!() + } - pub(crate) unsafe fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef) { - unimplemented!() - } + pub(crate) fn free_type_tree(&self, tree: CTypeTreeRef) { + unimplemented!() + } - pub(crate) unsafe fn EnzymeTypeTreeShiftIndiciesEq( - arg1: CTypeTreeRef, - data_layout: *const c_char, - offset: i64, - max_size: i64, - add_offset: u64, - ) { - unimplemented!() - } + pub(crate) fn merge_type_tree(&self, tree1: CTypeTreeRef, tree2: CTypeTreeRef) -> bool { + unimplemented!() + } - pub(crate) unsafe fn EnzymeTypeTreeInsertEq( - CTT: CTypeTreeRef, - indices: *const i64, - len: usize, - ct: CConcreteType, - ctx: &Context, - ) { - unimplemented!() - } + pub(crate) fn tree_only_eq(&self, tree: CTypeTreeRef, num: i64) { + unimplemented!() + } - pub(crate) unsafe fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char { - unimplemented!() - } + pub(crate) fn tree_data0_eq(&self, tree: CTypeTreeRef) { + unimplemented!() + } - pub(crate) unsafe fn EnzymeTypeTreeToStringFree(arg1: *const c_char) { - unimplemented!() - } + pub(crate) fn shift_indicies_eq( + &self, + tree: CTypeTreeRef, + data_layout: *const c_char, + offset: i64, + max_size: i64, + add_offset: u64, + ) { + unimplemented!() + } - pub(crate) fn set_inline(val: bool) { - unimplemented!() - } - pub(crate) fn set_print_perf(print: bool) { - unimplemented!() - } - pub(crate) fn set_print_activity(print: bool) { - unimplemented!() - } - pub(crate) fn set_print_type(print: bool) { - unimplemented!() - } - pub(crate) fn set_print_type_fun(fun_name: &str) { - unimplemented!() - } - pub(crate) fn set_print(print: bool) { - unimplemented!() - } - pub(crate) fn set_strict_aliasing(strict: bool) { - unimplemented!() - } - pub(crate) fn set_loose_types(loose: bool) { - unimplemented!() - } - pub(crate) fn set_rust_rules(val: bool) { - unimplemented!() + pub(crate) fn tree_insert_eq( + &self, + tree: CTypeTreeRef, + indices: *const i64, + len: usize, + ct: CConcreteType, + ctx: &Context, + ) { + unimplemented!() + } + + pub(crate) fn tree_to_string(&self, tree: *mut EnzymeTypeTree) -> *const c_char { + unimplemented!() + } + + pub(crate) fn tree_to_string_free(&self, ch: *const c_char) { + unimplemented!() + } + + pub(crate) fn get_max_type_depth(&self) -> usize { + unimplemented!() + } + + pub(crate) fn set_inline(&mut self, val: bool) { + unimplemented!() + } + + pub(crate) fn set_print_perf(&mut self, print: bool) { + unimplemented!() + } + + pub(crate) fn set_print_activity(&mut self, print: bool) { + unimplemented!() + } + + pub(crate) fn set_print_type(&mut self, print: bool) { + unimplemented!() + } + + pub(crate) fn set_print_type_fun(&mut self, fun_name: &str) { + unimplemented!() + } + + pub(crate) fn set_print(&mut self, print: bool) { + unimplemented!() + } + + pub(crate) fn set_strict_aliasing(&mut self, strict: bool) { + unimplemented!() + } + + pub(crate) fn set_loose_types(&mut self, loose: bool) { + unimplemented!() + } + + pub(crate) fn set_rust_rules(&mut self, val: bool) { + unimplemented!() + } } } impl TypeTree { pub(crate) fn new() -> TypeTree { - let inner = unsafe { EnzymeNewTypeTree() }; + let wrapper = EnzymeWrapper::get_instance(); + let inner = wrapper.new_type_tree(); TypeTree { inner } } pub(crate) fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree { - let inner = unsafe { EnzymeNewTypeTreeCT(t, ctx) }; + let wrapper = EnzymeWrapper::get_instance(); + let inner = wrapper.new_type_tree_ct(t, ctx); TypeTree { inner } } pub(crate) fn merge(self, other: Self) -> Self { - unsafe { - EnzymeMergeTypeTree(self.inner, other.inner); - } + let wrapper = EnzymeWrapper::get_instance(); + wrapper.merge_type_tree(self.inner, other.inner); drop(other); self } @@ -316,37 +622,36 @@ impl TypeTree { add_offset: usize, ) -> Self { let layout = std::ffi::CString::new(layout).unwrap(); - - unsafe { - EnzymeTypeTreeShiftIndiciesEq( - self.inner, - layout.as_ptr(), - offset as i64, - max_size as i64, - add_offset as u64, - ); - } + let wrapper = EnzymeWrapper::get_instance(); + wrapper.shift_indicies_eq( + self.inner, + layout.as_ptr(), + offset as i64, + max_size as i64, + add_offset as u64, + ); self } pub(crate) fn insert(&mut self, indices: &[i64], ct: CConcreteType, ctx: &Context) { - unsafe { - EnzymeTypeTreeInsertEq(self.inner, indices.as_ptr(), indices.len(), ct, ctx); - } + let wrapper = EnzymeWrapper::get_instance(); + wrapper.tree_insert_eq(self.inner, indices.as_ptr(), indices.len(), ct, ctx); } } impl Clone for TypeTree { fn clone(&self) -> Self { - let inner = unsafe { EnzymeNewTypeTreeTR(self.inner) }; + let wrapper = EnzymeWrapper::get_instance(); + let inner = wrapper.new_type_tree_tr(self.inner); TypeTree { inner } } } impl std::fmt::Display for TypeTree { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let ptr = unsafe { EnzymeTypeTreeToString(self.inner) }; + let wrapper = EnzymeWrapper::get_instance(); + let ptr = wrapper.tree_to_string(self.inner); let cstr = unsafe { std::ffi::CStr::from_ptr(ptr) }; match cstr.to_str() { Ok(x) => write!(f, "{}", x)?, @@ -354,9 +659,7 @@ impl std::fmt::Display for TypeTree { } // delete C string pointer - unsafe { - EnzymeTypeTreeToStringFree(ptr); - } + wrapper.tree_to_string_free(ptr); Ok(()) } @@ -370,6 +673,7 @@ impl std::fmt::Debug for TypeTree { impl Drop for TypeTree { fn drop(&mut self) { - unsafe { EnzymeFreeTypeTree(self.inner) } + let wrapper = EnzymeWrapper::get_instance(); + wrapper.free_type_tree(self.inner) } } diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index ca64d96c2a33c..be68d52330341 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -2384,7 +2384,7 @@ unsafe extern "C" { LoopVectorize: bool, DisableSimplifyLibCalls: bool, EmitLifetimeMarkers: bool, - RunEnzyme: bool, + RunEnzyme: *const c_void, PrintBeforeEnzyme: bool, PrintAfterEnzyme: bool, PrintPasses: bool, diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs index 7e2635037008e..64605c03de65d 100644 --- a/compiler/rustc_codegen_llvm/src/typetree.rs +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -2,6 +2,7 @@ use rustc_ast::expand::typetree::FncTree; #[cfg(feature = "llvm_enzyme")] use { crate::attributes, + crate::llvm::EnzymeWrapper, rustc_ast::expand::typetree::TypeTree as RustTypeTree, std::ffi::{CString, c_char, c_uint}, }; @@ -62,6 +63,7 @@ pub(crate) fn add_tt<'ll>( llcx: &'ll llvm::Context, fn_def: &'ll Value, tt: FncTree, + sysroot: &rustc_session::config::Sysroot, ) { let inputs = tt.args; let ret_tt: RustTypeTree = tt.ret; @@ -74,10 +76,12 @@ pub(crate) fn add_tt<'ll>( let attr_name = "enzyme_type"; let c_attr_name = CString::new(attr_name).unwrap(); + let enzyme_wrapper = EnzymeWrapper::get_or_init(sysroot); + for (i, input) in inputs.iter().enumerate() { unsafe { let enzyme_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx); - let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner); + let c_str = enzyme_wrapper.tree_to_string(enzyme_tt.inner); let c_str = std::ffi::CStr::from_ptr(c_str); let attr = llvm::LLVMCreateStringAttribute( @@ -89,13 +93,13 @@ pub(crate) fn add_tt<'ll>( ); attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]); - llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()); + enzyme_wrapper.tree_to_string_free(c_str.as_ptr()); } } unsafe { let enzyme_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx); - let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner); + let c_str = enzyme_wrapper.tree_to_string(enzyme_tt.inner); let c_str = std::ffi::CStr::from_ptr(c_str); let ret_attr = llvm::LLVMCreateStringAttribute( @@ -107,7 +111,7 @@ pub(crate) fn add_tt<'ll>( ); attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]); - llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()); + enzyme_wrapper.tree_to_string_free(c_str.as_ptr()); } } @@ -117,6 +121,7 @@ pub(crate) fn add_tt<'ll>( _llcx: &'ll llvm::Context, _fn_def: &'ll Value, _tt: FncTree, + _sysroot: &rustc_session::config::Sysroot, ) { unimplemented!() } diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index fc1edec8de843..c0abc9f2fbb90 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -29,6 +29,7 @@ use rustc_middle::ty::TyCtxt; use rustc_session::Session; use rustc_session::config::{ self, CrateType, Lto, OutFileName, OutputFilenames, OutputType, Passes, SwitchWithOptPath, + Sysroot, }; use rustc_span::source_map::SourceMap; use rustc_span::{FileName, InnerSpan, Span, SpanData, sym}; @@ -346,6 +347,7 @@ pub struct CodegenContext { pub split_debuginfo: rustc_target::spec::SplitDebuginfo, pub split_dwarf_kind: rustc_session::config::SplitDwarfKind, pub pointer_size: Size, + pub sysroot: Sysroot, /// Emitter to use for diagnostics produced during codegen. pub diag_emitter: SharedEmitter, @@ -1316,6 +1318,7 @@ fn start_executing_work( parallel: backend.supports_parallel() && !sess.opts.unstable_opts.no_parallel_backend, pointer_size: tcx.data_layout.pointer_size(), invocation_temp: sess.invocation_temp.clone(), + sysroot: sess.opts.sysroot.clone(), }; // This is the "main loop" of parallel work happening for parallel codegen. diff --git a/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp index 143cc94790890..225b3d3d5a8c1 100644 --- a/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp @@ -540,17 +540,8 @@ struct LLVMRustSanitizerOptions { bool SanitizeKernelAddressRecover; }; -// This symbol won't be available or used when Enzyme is not enabled. -// Always set AugmentPassBuilder to true, since it registers optimizations which -// will improve the performance for Enzyme. -#ifdef ENZYME -extern "C" void registerEnzymeAndPassPipeline(llvm::PassBuilder &PB, - /* augmentPassBuilder */ bool); - -extern "C" { -extern llvm::cl::opt EnzymeFunctionToAnalyze; -} -#endif +extern "C" typedef void (*registerEnzymeAndPassPipelineFn)( + llvm::PassBuilder &PB, bool augment); extern "C" LLVMRustResult LLVMRustOptimize( LLVMModuleRef ModuleRef, LLVMTargetMachineRef TMRef, @@ -559,8 +550,8 @@ extern "C" LLVMRustResult LLVMRustOptimize( bool LintIR, LLVMRustThinLTOBuffer **ThinLTOBufferRef, bool EmitThinLTO, bool EmitThinLTOSummary, bool MergeFunctions, bool UnrollLoops, bool SLPVectorize, bool LoopVectorize, bool DisableSimplifyLibCalls, - bool EmitLifetimeMarkers, bool RunEnzyme, bool PrintBeforeEnzyme, - bool PrintAfterEnzyme, bool PrintPasses, + bool EmitLifetimeMarkers, registerEnzymeAndPassPipelineFn EnzymePtr, + bool PrintBeforeEnzyme, bool PrintAfterEnzyme, bool PrintPasses, LLVMRustSanitizerOptions *SanitizerOptions, const char *PGOGenPath, const char *PGOUsePath, bool InstrumentCoverage, const char *InstrProfileOutput, const char *PGOSampleUsePath, @@ -897,8 +888,8 @@ extern "C" LLVMRustResult LLVMRustOptimize( } // now load "-enzyme" pass: -#ifdef ENZYME - if (RunEnzyme) { + // With dlopen, ENZYME macro may not be defined, so check EnzymePtr directly + if (EnzymePtr) { if (PrintBeforeEnzyme) { // Handle the Rust flag `-Zautodiff=PrintModBefore`. @@ -906,29 +897,19 @@ extern "C" LLVMRustResult LLVMRustOptimize( MPM.addPass(PrintModulePass(outs(), Banner, true, false)); } - registerEnzymeAndPassPipeline(PB, false); + EnzymePtr(PB, false); if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) { std::string ErrMsg = toString(std::move(Err)); LLVMRustSetLastError(ErrMsg.c_str()); return LLVMRustResult::Failure; } - // Check if PrintTAFn was used and add type analysis pass if needed - if (!EnzymeFunctionToAnalyze.empty()) { - if (auto Err = PB.parsePassPipeline(MPM, "print-type-analysis")) { - std::string ErrMsg = toString(std::move(Err)); - LLVMRustSetLastError(ErrMsg.c_str()); - return LLVMRustResult::Failure; - } - } - if (PrintAfterEnzyme) { // Handle the Rust flag `-Zautodiff=PrintModAfter`. std::string Banner = "Module after EnzymeNewPM"; MPM.addPass(PrintModulePass(outs(), Banner, true, false)); } } -#endif if (PrintPasses) { // Print all passes from the PM: std::string Pipeline; diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 8823c83922822..c27cd013fe32e 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -1700,18 +1700,6 @@ extern "C" void LLVMRustSetNoSanitizeHWAddress(LLVMValueRef Global) { GV.setSanitizerMetadata(MD); } -#ifdef ENZYME -extern "C" { -extern llvm::cl::opt EnzymeMaxTypeDepth; -} - -extern "C" size_t LLVMRustEnzymeGetMaxTypeDepth() { return EnzymeMaxTypeDepth; } -#else -extern "C" size_t LLVMRustEnzymeGetMaxTypeDepth() { - return 6; // Default fallback depth -} -#endif - // Statically assert that the fixed metadata kind IDs declared in // `metadata_kind.rs` match the ones actually used by LLVM. #define FIXED_MD_KIND(VARIANT, VALUE) \ diff --git a/src/bootstrap/src/core/build_steps/compile.rs b/src/bootstrap/src/core/build_steps/compile.rs index 6857a40ada81b..7865d456c8dad 100644 --- a/src/bootstrap/src/core/build_steps/compile.rs +++ b/src/bootstrap/src/core/build_steps/compile.rs @@ -1228,19 +1228,6 @@ pub fn rustc_cargo( // . cargo.rustflag("-Zon-broken-pipe=kill"); - // We want to link against registerEnzyme and in the future we want to use additional - // functionality from Enzyme core. For that we need to link against Enzyme. - if builder.config.llvm_enzyme { - let arch = builder.build.host_target; - let enzyme_dir = builder.build.out.join(arch).join("enzyme").join("lib"); - cargo.rustflag("-L").rustflag(enzyme_dir.to_str().expect("Invalid path")); - - if let Some(llvm_config) = builder.llvm_config(builder.config.host_target) { - let llvm_version_major = llvm::get_llvm_version_major(builder, &llvm_config); - cargo.rustflag("-l").rustflag(&format!("Enzyme-{llvm_version_major}")); - } - } - // Building with protected visibility reduces the number of dynamic relocations needed, giving // us a faster startup time. However GNU ld < 2.40 will error if we try to link a shared object // with direct references to protected symbols, so for now we only use protected symbols if diff --git a/typos.toml b/typos.toml index 758239ffe751c..b9d9c6c3522cf 100644 --- a/typos.toml +++ b/typos.toml @@ -50,6 +50,8 @@ unstalled = "unstalled" debug_aranges = "debug_aranges" DNS_ERROR_INVAILD_VIRTUALIZATION_INSTANCE_NAME = "DNS_ERROR_INVAILD_VIRTUALIZATION_INSTANCE_NAME" EnzymeTypeTreeShiftIndiciesEq = "EnzymeTypeTreeShiftIndiciesEq" +EnzymeTypeTreeShiftIndiciesEqFn = "EnzymeTypeTreeShiftIndiciesEqFn" +shift_indicies_eq = "shift_indicies_eq" ERRNO_ACCES = "ERRNO_ACCES" ERROR_DS_FILTER_USES_CONTRUCTED_ATTRS = "ERROR_DS_FILTER_USES_CONTRUCTED_ATTRS" ERROR_DS_NOT_AUTHORITIVE_FOR_DST_NC = "ERROR_DS_NOT_AUTHORITIVE_FOR_DST_NC"