4 changes: 4 additions & 0 deletions crates/cuda_builder/src/lib.rs
@@ -761,6 +761,10 @@ fn invoke_rustc(builder: &CudaBuilder) -> Result<PathBuf, CudaBuilderError> {
        llvm_args.push("--override-libm".to_string());
    }

    if builder.use_constant_memory_space {
        llvm_args.push("--use-constant-memory-space".to_string());
    }

    if let Some(path) = &builder.final_module_path {
        llvm_args.push("--final-module-path".to_string());
        llvm_args.push(path.to_str().unwrap().to_string());
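
For context, this flag is driven from a GPU crate's build.rs. Below is a minimal sketch, assuming the chainable `.use_constant_memory_space(..)` setter that the new diagnostics refer to; the kernel-crate path and PTX output path are placeholders:

use cuda_builder::CudaBuilder;

fn main() {
    CudaBuilder::new("kernels")
        // Forwarded to the codegen backend as `--use-constant-memory-space` (see the
        // hunk above); per the new help text, passing `false` opts out of automatic
        // constant-memory placement.
        .use_constant_memory_space(true)
        .copy_to("resources/kernels.ptx")
        .build()
        .unwrap();
}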
105 changes: 96 additions & 9 deletions crates/rustc_codegen_nvvm/src/context.rs
@@ -44,6 +44,9 @@ use tracing::{debug, trace};
/// <https://docs.nvidia.com/cuda/archive/12.8.1/pdf/CUDA_C_Best_Practices_Guide.pdf>
const CONSTANT_MEMORY_SIZE_LIMIT_BYTES: u64 = 64 * 1024;

/// Warn when cumulative constant memory usage crosses this threshold (80% of the limit)
const CONSTANT_MEMORY_WARNING_THRESHOLD_BYTES: u64 = (CONSTANT_MEMORY_SIZE_LIMIT_BYTES * 80) / 100;

pub(crate) struct CodegenCx<'ll, 'tcx> {
    pub tcx: TyCtxt<'tcx>,

@@ -104,6 +107,9 @@ pub(crate) struct CodegenCx<'ll, 'tcx> {
    pub codegen_args: CodegenArgs,
    // the value of the last call instruction. Needed for return type remapping.
    pub last_call_llfn: Cell<Option<&'ll Value>>,

    /// Tracks cumulative constant memory usage in bytes for compile-time diagnostics
    constant_memory_usage: Cell<u64>,
}

impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
@@ -174,6 +180,7 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
            dbg_cx,
            codegen_args: CodegenArgs::from_session(tcx.sess()),
            last_call_llfn: Cell::new(None),
            constant_memory_usage: Cell::new(0),
        };
        cx.build_intrinsics_map();
        cx
@@ -281,16 +288,96 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
            // static and many small ones, you might want the small ones to all be
            // in constant memory or just the big one depending on your workload.
            let layout = self.layout_of(ty);
            if layout.size.bytes() > CONSTANT_MEMORY_SIZE_LIMIT_BYTES {
                self.tcx.sess.dcx().warn(format!(
                    "static `{instance}` exceeds the constant memory limit; placing in global memory (performance may be reduced)"
                ));
                // Place instance in global memory if it is too big for constant memory.
                AddressSpace(1)
            } else {
                // Place instance in constant memory if it fits.
                AddressSpace(4)
            let size_bytes = layout.size.bytes();
            let current_usage = self.constant_memory_usage.get();
            let new_usage = current_usage + size_bytes;

            // Check if this single static is too large for constant memory
            if size_bytes > CONSTANT_MEMORY_SIZE_LIMIT_BYTES {
                let def_id = instance.def_id();
                let span = self.tcx.def_span(def_id);
                let mut diag = self.tcx.sess.dcx().struct_span_warn(
                    span,
                    format!(
                        "static `{instance}` is {size_bytes} bytes, exceeding the constant memory limit of {} bytes",
                        CONSTANT_MEMORY_SIZE_LIMIT_BYTES
                    ),
                );
                diag.span_label(span, "static exceeds constant memory limit");
                diag.note("placing in global memory (performance may be reduced)");
                diag.help("use `#[cuda_std::address_space(global)]` to explicitly place this static in global memory");
                diag.emit();
                return AddressSpace(1);
            }

            // Check if adding this static would exceed the cumulative limit
            if new_usage > CONSTANT_MEMORY_SIZE_LIMIT_BYTES {
                let def_id = instance.def_id();
                let span = self.tcx.def_span(def_id);
                let mut diag = self.tcx.sess.dcx().struct_span_err(
                    span,
                    format!(
                        "cannot place static `{instance}` ({size_bytes} bytes) in constant memory: \
                         cumulative constant memory usage would be {new_usage} bytes, exceeding the {} byte limit",
                        CONSTANT_MEMORY_SIZE_LIMIT_BYTES
                    ),
                );
                diag.span_label(
                    span,
                    format!(
                        "this static would cause total usage to exceed {} bytes",
                        CONSTANT_MEMORY_SIZE_LIMIT_BYTES
                    ),
                );
                diag.note(format!(
                    "current constant memory usage: {current_usage} bytes"
                ));
                diag.note(format!("static size: {size_bytes} bytes"));
                diag.note(format!("would result in: {new_usage} bytes total"));

                diag.help("move this or other statics to global memory using `#[cuda_std::address_space(global)]`");
                diag.help("reduce the total size of static data");
                diag.help("disable automatic constant memory placement by setting `.use_constant_memory_space(false)` on `CudaBuilder` in build.rs");

                diag.emit();
                self.tcx.sess.dcx().abort_if_errors();
                unreachable!()
            }

            // The static fits in constant memory: update the cumulative usage
            self.constant_memory_usage.set(new_usage);

            // Warn once, on the placement that first crosses the warning threshold
            if new_usage > CONSTANT_MEMORY_WARNING_THRESHOLD_BYTES
                && current_usage <= CONSTANT_MEMORY_WARNING_THRESHOLD_BYTES
            {
                let def_id = instance.def_id();
                let span = self.tcx.def_span(def_id);
                let usage_percent =
                    (new_usage as f64 / CONSTANT_MEMORY_SIZE_LIMIT_BYTES as f64) * 100.0;
                let mut diag = self.tcx.sess.dcx().struct_span_warn(
                    span,
                    format!(
                        "constant memory usage is approaching the limit: {new_usage} / {} bytes ({usage_percent:.1}% used)",
                        CONSTANT_MEMORY_SIZE_LIMIT_BYTES
                    ),
                );
                diag.span_label(
                    span,
                    "this placement brings total usage over 80% of constant memory capacity",
                );
                diag.note(format!(
                    "only {} bytes of constant memory remain",
                    CONSTANT_MEMORY_SIZE_LIMIT_BYTES - new_usage
                ));
                diag.help("to prevent constant memory overflow, consider moving some statics to global memory using `#[cuda_std::address_space(global)]`");
                diag.emit();
            }

            trace!(
                "Placing static `{instance}` ({size_bytes} bytes) in constant memory. Total usage: {new_usage} bytes"
            );
            AddressSpace(4)
            }
        } else {
            AddressSpace::ZERO
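
As a usage note, here is a sketch of how a kernel crate could act on these diagnostics by pinning an oversized static to global memory, assuming the `#[cuda_std::address_space(global)]` attribute named in the help messages; the statics and their sizes are made up:

// lib.rs of a hypothetical kernel crate compiled with constant-memory placement enabled.

// 96 KiB exceeds the 64 KiB constant memory limit, so the backend would warn and fall
// back to global memory; the attribute requests that placement explicitly, which is
// what the warning's help message suggests.
#[cuda_std::address_space(global)]
pub static BIG_TABLE: [u8; 96 * 1024] = [0u8; 96 * 1024];

// 16 KiB fits in constant memory and counts toward the 64 KiB budget (the new 80%
// warning fires once cumulative usage first exceeds 52,428 bytes).
pub static SMALL_TABLE: [u32; 4096] = [0u32; 4096];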