Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock
Original file line number Diff line number Diff line change
Expand Up @@ -3574,6 +3574,7 @@ dependencies = [
"gimli 0.31.1",
"itertools",
"libc",
"libloading 0.9.0",
"measureme",
"object 0.37.3",
"rustc-demangle",
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_codegen_llvm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ bitflags = "2.4.1"
gimli = "0.31"
itertools = "0.12"
libc = "0.2"
libloading = "0.9.0"
measureme = "12.0.1"
object = { version = "0.37.0", default-features = false, features = ["std", "read"] }
rustc-demangle = "0.1.21"
Expand Down
27 changes: 15 additions & 12 deletions compiler/rustc_codegen_llvm/src/back/lto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -528,31 +528,34 @@ fn thin_lto(
}
}

fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
fn enable_autodiff_settings(cgcx: &CodegenContext<LlvmCodegenBackend>, ad: &[config::AutoDiff]) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe call this from LlvmCodegenBackend::init() instead? There you have access to the full Session.

use std::sync::Mutex;
let enzyme: &'static Mutex<llvm::EnzymeWrapper> = llvm::EnzymeWrapper::init(cgcx);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe return a MutexGuard from EnzymeWrapper::init()?


for val in ad {
// We intentionally don't use a wildcard, to not forget handling anything new.
match val {
config::AutoDiff::PrintPerf => {
llvm::set_print_perf(true);
enzyme.lock().unwrap().set_print_perf(true);
}
config::AutoDiff::PrintAA => {
llvm::set_print_activity(true);
enzyme.lock().unwrap().set_print_activity(true);
}
config::AutoDiff::PrintTA => {
llvm::set_print_type(true);
enzyme.lock().unwrap().set_print_type(true);
}
config::AutoDiff::PrintTAFn(fun) => {
llvm::set_print_type(true); // Enable general type printing
llvm::set_print_type_fun(&fun); // Set specific function to analyze
enzyme.lock().unwrap().set_print_type(true); // Enable general type printing
enzyme.lock().unwrap().set_print_type_fun(&fun); // Set specific function to analyze
}
config::AutoDiff::Inline => {
llvm::set_inline(true);
enzyme.lock().unwrap().set_inline(true);
}
config::AutoDiff::LooseTypes => {
llvm::set_loose_types(true);
enzyme.lock().unwrap().set_loose_types(true);
}
config::AutoDiff::PrintSteps => {
llvm::set_print(true);
enzyme.lock().unwrap().set_print(true);
}
// We handle this in the PassWrapper.cpp
config::AutoDiff::PrintPasses => {}
Expand All @@ -571,9 +574,9 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
}
}
// This helps with handling enums for now.
llvm::set_strict_aliasing(false);
enzyme.lock().unwrap().set_strict_aliasing(false);
// FIXME(ZuseZ4): Test this, since it was added a long time ago.
llvm::set_rust_rules(true);
enzyme.lock().unwrap().set_rust_rules(true);
}

pub(crate) fn run_pass_manager(
Expand Down Expand Up @@ -609,7 +612,7 @@ pub(crate) fn run_pass_manager(
};

if enable_ad {
enable_autodiff_settings(&config.autodiff);
enable_autodiff_settings(&cgcx, &config.autodiff);
}

unsafe {
Expand Down
11 changes: 9 additions & 2 deletions compiler/rustc_codegen_llvm/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::ffi::{CStr, CString};
use std::io::{self, Write};
use std::path::{Path, PathBuf};
use std::ptr::null_mut;
use std::sync::Arc;
use std::sync::{Arc, Mutex};
use std::{fs, slice, str};

use libc::{c_char, c_int, c_void, size_t};
Expand Down Expand Up @@ -726,6 +726,13 @@ pub(crate) unsafe fn llvm_optimize(

let llvm_plugins = config.llvm_plugins.join(",");

let enzyme_fn = if consider_ad {
let wrapper: &'static Mutex<llvm::EnzymeWrapper> = llvm::EnzymeWrapper::init(cgcx);
wrapper.lock().unwrap().registerEnzymeAndPassPipeline
} else {
std::ptr::null()
};

let result = unsafe {
llvm::LLVMRustOptimize(
module.module_llvm.llmod(),
Expand All @@ -745,7 +752,7 @@ pub(crate) unsafe fn llvm_optimize(
vectorize_loop,
config.no_builtins,
config.emit_lifetime_markers,
run_enzyme,
enzyme_fn,
print_before_enzyme,
print_after_enzyme,
print_passes,
Expand Down
Loading
Loading