From 58c968829d5fc003fd5d08cb1717f89c8b68d7d5 Mon Sep 17 00:00:00 2001 From: Ilya Tikhonovskiy Date: Mon, 8 Dec 2025 05:50:52 -0800 Subject: [PATCH] Integrate Triton up to 8d445186 https://github.com/openxla/triton/tree/triton_integrate_branch-1.15 PiperOrigin-RevId: 841726853 --- jaxlib/gpu/triton.cc | 4 ++-- jaxlib/gpu/triton.proto | 6 +++--- jaxlib/gpu/triton_kernels.cc | 37 +++++++++++++++++------------------- jaxlib/gpu/triton_kernels.h | 12 +++++------- 4 files changed, 27 insertions(+), 32 deletions(-) diff --git a/jaxlib/gpu/triton.cc b/jaxlib/gpu/triton.cc index 42c58eb613a2..a1bb10ed510f 100644 --- a/jaxlib/gpu/triton.cc +++ b/jaxlib/gpu/triton.cc @@ -45,8 +45,8 @@ namespace jax::JAX_GPU_NAMESPACE { NB_MODULE(_triton, m) { nb::class_(m, "TritonKernel") - .def(nb::init()); + .def(nb::init()); nb::class_(m, "TritonParameter"); diff --git a/jaxlib/gpu/triton.proto b/jaxlib/gpu/triton.proto index 786b07afbdbe..4559833c1098 100644 --- a/jaxlib/gpu/triton.proto +++ b/jaxlib/gpu/triton.proto @@ -5,13 +5,13 @@ package jax_triton; message TritonKernel { string kernel_name = 1; // Kernel function name within module. uint32 num_warps = 2; + optional uint32 num_ctas = 10; uint32 shared_mem_bytes = 3; string ptx = 4; string ttir = 5; uint32 compute_capability = 6; - uint32 cluster_dim_0 = 7; - uint32 cluster_dim_1 = 8; - uint32 cluster_dim_2 = 9; + + reserved 7, 8, 9; // cluster_dim_0, cluster_dim_1, cluster_dim_2 } message TritonKernelCall { diff --git a/jaxlib/gpu/triton_kernels.cc b/jaxlib/gpu/triton_kernels.cc index 1961ada1bf76..0ad86f522d9d 100644 --- a/jaxlib/gpu/triton_kernels.cc +++ b/jaxlib/gpu/triton_kernels.cc @@ -315,17 +315,16 @@ class ModuleImage { ABSL_GUARDED_BY(mutex_); }; -Kernel::Kernel(std::string kernel_name, uint32_t num_warps, +Kernel::Kernel(std::string kernel_name, uint32_t num_warps, uint32_t num_ctas, uint32_t shared_mem_bytes, std::string ptx, std::string ttir, - int compute_capability, uint32_t cluster_dim_0, - uint32_t cluster_dim_1, uint32_t cluster_dim_2) + int compute_capability) : kernel_name_(std::move(kernel_name)), block_dim_x_(num_warps * kNumThreadsPerWarp), + num_ctas_(num_ctas), shared_mem_bytes_(shared_mem_bytes), ptx_(std::move(ptx)), ttir_(std::move(ttir)), - compute_capability_(compute_capability), - cluster_dims_{cluster_dim_0, cluster_dim_1, cluster_dim_2} {} + compute_capability_(compute_capability) {} absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3], void** params) { @@ -362,9 +361,7 @@ absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3], JAX_ASSIGN_OR_RETURN(gpuFunction_t kernel, module_image_->GetFunctionForContext(context)); - const uint32_t cluster_size = - cluster_dims_[0] * cluster_dims_[1] * cluster_dims_[2]; - if (cluster_size <= 1) { + if (num_ctas_ == 1) { return JAX_AS_STATUS(gpuLaunchKernel( kernel, grid[0], grid[1], grid[2], block_dim_x_, /*blockDimY=*/1, /*blockDimZ=*/1, shared_mem_bytes_, stream, params, @@ -372,16 +369,16 @@ absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3], } CUlaunchAttribute launch_attrs[2]; launch_attrs[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; - launch_attrs[0].value.clusterDim.x = cluster_dims_[0]; - launch_attrs[0].value.clusterDim.y = cluster_dims_[1]; - launch_attrs[0].value.clusterDim.z = cluster_dims_[2]; + launch_attrs[0].value.clusterDim.x = num_ctas_; + launch_attrs[0].value.clusterDim.y = 1; + launch_attrs[0].value.clusterDim.z = 1; launch_attrs[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; launch_attrs[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_SPREAD; CUlaunchConfig launch_config = { - /*gridDimX=*/grid[0] * cluster_dims_[0], - /*gridDimY=*/grid[1] * cluster_dims_[1], - /*gridDimZ=*/grid[2] * cluster_dims_[2], + /*gridDimX=*/grid[0] * num_ctas_, + /*gridDimY=*/grid[1], + /*gridDimZ=*/grid[2], /*blockDimX=*/block_dim_x_, /*blockDimY=*/1, /*blockDimZ=*/1, @@ -396,23 +393,23 @@ absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3], } /*static*/ Kernel Kernel::FromProto(const jax_triton::TritonKernel& proto) { - return Kernel(proto.kernel_name(), proto.num_warps(), + // Use 1 as default value if not specified in already serialized kernels. + int num_ctas = proto.has_num_ctas() ? proto.num_ctas() : 1; + + return Kernel(proto.kernel_name(), proto.num_warps(), num_ctas, proto.shared_mem_bytes(), proto.ptx(), proto.ttir(), - proto.compute_capability(), proto.cluster_dim_0(), - proto.cluster_dim_1(), proto.cluster_dim_2()); + proto.compute_capability()); } jax_triton::TritonKernel Kernel::ToProto() const { jax_triton::TritonKernel proto; proto.set_kernel_name(kernel_name_); proto.set_num_warps(block_dim_x_ / kNumThreadsPerWarp); + proto.set_num_ctas(num_ctas_); proto.set_shared_mem_bytes(shared_mem_bytes_); proto.set_ptx(ptx_); proto.set_ttir(ttir_); proto.set_compute_capability(compute_capability_); - proto.set_cluster_dim_0(cluster_dims_[0]); - proto.set_cluster_dim_1(cluster_dims_[1]); - proto.set_cluster_dim_2(cluster_dims_[2]); return proto; } diff --git a/jaxlib/gpu/triton_kernels.h b/jaxlib/gpu/triton_kernels.h index 3ab3e9143fb8..08320a104183 100644 --- a/jaxlib/gpu/triton_kernels.h +++ b/jaxlib/gpu/triton_kernels.h @@ -38,10 +38,9 @@ class ModuleImage; class Kernel { public: - Kernel(std::string kernel_name, uint32_t num_warps, uint32_t shared_mem_bytes, - std::string ptx, std::string ttir, int compute_capability, - uint32_t cluster_dim_0, uint32_t cluster_dim_1, - uint32_t cluster_dim_2); + Kernel(std::string kernel_name, uint32_t num_warps, uint32_t num_ctas, + uint32_t shared_mem_bytes, std::string ptx, std::string ttir, + int compute_capability); absl::Status Launch(gpuStream_t stream, uint32_t grid[3], void** params); @@ -54,11 +53,11 @@ class Kernel { private: std::string kernel_name_; uint32_t block_dim_x_; + uint32_t num_ctas_; uint32_t shared_mem_bytes_; std::string ptx_; std::string ttir_; int compute_capability_; - uint32_t cluster_dims_[3]; ModuleImage* module_image_ = nullptr; }; @@ -107,8 +106,7 @@ class AutotunedKernelCall { AutotunedKernelCall( std::string name, std::vector configs, - std::vector> input_output_aliases); + std::vector> input_output_aliases); static absl::StatusOr Autotune(AutotunedKernelCall kernel_call, gpuStream_t stream,