jaxlib/gpu/triton.cc (4 changes: 2 additions & 2 deletions)

@@ -45,8 +45,8 @@ namespace jax::JAX_GPU_NAMESPACE {

 NB_MODULE(_triton, m) {
   nb::class_<Kernel>(m, "TritonKernel")
-      .def(nb::init<std::string, uint32_t, uint32_t, std::string, std::string,
-                    int, uint32_t, uint32_t, uint32_t>());
+      .def(nb::init<std::string, uint32_t, uint32_t, uint32_t, std::string,
+                    std::string, int>());

   nb::class_<KernelCall::Parameter>(m, "TritonParameter");
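Note (not part of the diff): the rebound constructor above now takes num_ctas as its third argument, between num_warps and shared_mem_bytes. A minimal C++ sketch of a call with the new parameter order, assuming it sits inside jaxlib's GPU build where JAX_GPU_NAMESPACE is defined; the kernel name, warp count, CTA count, and compute capability are illustrative placeholders, not values from this PR.

#include <string>
#include <utility>

#include "jaxlib/gpu/triton_kernels.h"

// Sketch only: placeholder values, hypothetical helper.
jax::JAX_GPU_NAMESPACE::Kernel MakeExampleKernel(std::string ptx,
                                                 std::string ttir) {
  return jax::JAX_GPU_NAMESPACE::Kernel(
      /*kernel_name=*/"add_kernel",
      /*num_warps=*/4,
      /*num_ctas=*/2,             // new parameter: CTAs per cluster
      /*shared_mem_bytes=*/0,
      /*ptx=*/std::move(ptx),
      /*ttir=*/std::move(ttir),
      /*compute_capability=*/90);
}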
jaxlib/gpu/triton.proto (6 changes: 3 additions & 3 deletions)

@@ -5,13 +5,13 @@ package jax_triton;
 message TritonKernel {
   string kernel_name = 1;  // Kernel function name within module.
   uint32 num_warps = 2;
+  optional uint32 num_ctas = 10;
   uint32 shared_mem_bytes = 3;
   string ptx = 4;
   string ttir = 5;
   uint32 compute_capability = 6;
-  uint32 cluster_dim_0 = 7;
-  uint32 cluster_dim_1 = 8;
-  uint32 cluster_dim_2 = 9;
+
+  reserved 7, 8, 9;  // cluster_dim_0, cluster_dim_1, cluster_dim_2
 }

 message TritonKernelCall {
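Two details in this proto change carry the compatibility story: optional on num_ctas gives explicit field presence in proto3, so a reader can tell "never written by an old producer" apart from "written as 0", and reserved 7, 8, 9 keeps the retired cluster_dim_* field numbers from ever being reused with a different meaning. A hedged C++ sketch of the presence check, assuming the generated header follows the usual protoc naming (jaxlib/gpu/triton.pb.h); the helper name is illustrative.

#include <cstdint>
#include <string>

#include "jaxlib/gpu/triton.pb.h"  // assumed generated-header path

// Returns the CTA count recorded in a serialized TritonKernel, falling back
// to 1 for payloads written before num_ctas existed.
uint32_t NumCtasOrDefault(const std::string& serialized) {
  jax_triton::TritonKernel proto;
  if (!proto.ParseFromString(serialized)) return 1;  // treat unparseable input as default
  // proto3 `optional` generates has_num_ctas(); old payloads report false.
  return proto.has_num_ctas() ? proto.num_ctas() : 1;
}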
jaxlib/gpu/triton_kernels.cc (37 changes: 17 additions & 20 deletions)

@@ -315,17 +315,16 @@ class ModuleImage {
       ABSL_GUARDED_BY(mutex_);
 };

-Kernel::Kernel(std::string kernel_name, uint32_t num_warps,
+Kernel::Kernel(std::string kernel_name, uint32_t num_warps, uint32_t num_ctas,
                uint32_t shared_mem_bytes, std::string ptx, std::string ttir,
-               int compute_capability, uint32_t cluster_dim_0,
-               uint32_t cluster_dim_1, uint32_t cluster_dim_2)
+               int compute_capability)
     : kernel_name_(std::move(kernel_name)),
       block_dim_x_(num_warps * kNumThreadsPerWarp),
+      num_ctas_(num_ctas),
       shared_mem_bytes_(shared_mem_bytes),
       ptx_(std::move(ptx)),
       ttir_(std::move(ttir)),
-      compute_capability_(compute_capability),
-      cluster_dims_{cluster_dim_0, cluster_dim_1, cluster_dim_2} {}
+      compute_capability_(compute_capability) {}

 absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3],
                             void** params) {
@@ -362,26 +361,24 @@ absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3],

   JAX_ASSIGN_OR_RETURN(gpuFunction_t kernel,
                        module_image_->GetFunctionForContext(context));
-  const uint32_t cluster_size =
-      cluster_dims_[0] * cluster_dims_[1] * cluster_dims_[2];
-  if (cluster_size <= 1) {
+  if (num_ctas_ == 1) {
     return JAX_AS_STATUS(gpuLaunchKernel(
         kernel, grid[0], grid[1], grid[2], block_dim_x_,
         /*blockDimY=*/1, /*blockDimZ=*/1, shared_mem_bytes_, stream, params,
         /*extra=*/nullptr));
   }
   CUlaunchAttribute launch_attrs[2];
   launch_attrs[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
-  launch_attrs[0].value.clusterDim.x = cluster_dims_[0];
-  launch_attrs[0].value.clusterDim.y = cluster_dims_[1];
-  launch_attrs[0].value.clusterDim.z = cluster_dims_[2];
+  launch_attrs[0].value.clusterDim.x = num_ctas_;
+  launch_attrs[0].value.clusterDim.y = 1;
+  launch_attrs[0].value.clusterDim.z = 1;
   launch_attrs[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
   launch_attrs[1].value.clusterSchedulingPolicyPreference =
       CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
   CUlaunchConfig launch_config = {
-      /*gridDimX=*/grid[0] * cluster_dims_[0],
-      /*gridDimY=*/grid[1] * cluster_dims_[1],
-      /*gridDimZ=*/grid[2] * cluster_dims_[2],
+      /*gridDimX=*/grid[0] * num_ctas_,
+      /*gridDimY=*/grid[1],
+      /*gridDimZ=*/grid[2],
       /*blockDimX=*/block_dim_x_,
       /*blockDimY=*/1,
       /*blockDimZ=*/1,
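For readers not familiar with the CUDA 12 driver API used above: CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION groups thread blocks (CTAs) into clusters, and gridDim still counts CTAs rather than clusters, which is why grid[0] is multiplied by num_ctas_. A standalone sketch of the same launch pattern follows; the function name and parameters are illustrative, not JAX's own code.

#include <cstdint>

#include <cuda.h>

// Sketch (CUDA 12+ driver API): launch `f` so that each logical program
// occupies a cluster of `num_ctas` thread blocks along x.
CUresult LaunchWithClusters(CUfunction f, CUstream stream, uint32_t grid[3],
                            uint32_t block_dim_x, uint32_t shared_mem_bytes,
                            uint32_t num_ctas, void** params) {
  CUlaunchAttribute attrs[2];
  attrs[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
  attrs[0].value.clusterDim.x = num_ctas;  // CTAs per cluster along x
  attrs[0].value.clusterDim.y = 1;
  attrs[0].value.clusterDim.z = 1;
  attrs[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
  attrs[1].value.clusterSchedulingPolicyPreference =
      CU_CLUSTER_SCHEDULING_POLICY_SPREAD;

  CUlaunchConfig config = {};
  config.gridDimX = grid[0] * num_ctas;  // gridDim counts CTAs, not clusters
  config.gridDimY = grid[1];
  config.gridDimZ = grid[2];
  config.blockDimX = block_dim_x;
  config.blockDimY = 1;
  config.blockDimZ = 1;
  config.sharedMemBytes = shared_mem_bytes;
  config.hStream = stream;
  config.attrs = attrs;
  config.numAttrs = 2;
  return cuLaunchKernelEx(&config, f, params, /*extra=*/nullptr);
}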
@@ -396,23 +393,23 @@ absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3],
 }

 /*static*/ Kernel Kernel::FromProto(const jax_triton::TritonKernel& proto) {
-  return Kernel(proto.kernel_name(), proto.num_warps(),
+  // Use 1 as the default value if not specified in already serialized kernels.
+  int num_ctas = proto.has_num_ctas() ? proto.num_ctas() : 1;
+
+  return Kernel(proto.kernel_name(), proto.num_warps(), num_ctas,
                 proto.shared_mem_bytes(), proto.ptx(), proto.ttir(),
-                proto.compute_capability(), proto.cluster_dim_0(),
-                proto.cluster_dim_1(), proto.cluster_dim_2());
+                proto.compute_capability());
 }

 jax_triton::TritonKernel Kernel::ToProto() const {
   jax_triton::TritonKernel proto;
   proto.set_kernel_name(kernel_name_);
   proto.set_num_warps(block_dim_x_ / kNumThreadsPerWarp);
+  proto.set_num_ctas(num_ctas_);
   proto.set_shared_mem_bytes(shared_mem_bytes_);
   proto.set_ptx(ptx_);
   proto.set_ttir(ttir_);
   proto.set_compute_capability(compute_capability_);
-  proto.set_cluster_dim_0(cluster_dims_[0]);
-  proto.set_cluster_dim_1(cluster_dims_[1]);
-  proto.set_cluster_dim_2(cluster_dims_[2]);
   return proto;
 }
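As a consequence of the FromProto fallback above, kernels serialized by an older jaxlib (which never wrote num_ctas) deserialize with a CTA count of 1 and go through the plain gpuLaunchKernel path. A hedged round-trip sketch using the ToProto/FromProto pair shown in this file, assuming both are publicly accessible and that the generated proto header lives at the usual path; the placeholder field values are illustrative.

#include <cassert>

#include "jaxlib/gpu/triton.pb.h"  // assumed generated-header path
#include "jaxlib/gpu/triton_kernels.h"

void RoundTripExample(const jax::JAX_GPU_NAMESPACE::Kernel& kernel) {
  // New writers record num_ctas explicitly.
  jax_triton::TritonKernel proto = kernel.ToProto();
  assert(proto.has_num_ctas());

  // A proto without the field set (old writer) still restores cleanly;
  // FromProto falls back to num_ctas == 1.
  jax_triton::TritonKernel legacy;
  legacy.set_kernel_name("legacy_kernel");
  legacy.set_num_warps(4);
  jax::JAX_GPU_NAMESPACE::Kernel restored =
      jax::JAX_GPU_NAMESPACE::Kernel::FromProto(legacy);
  (void)restored;
}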
jaxlib/gpu/triton_kernels.h (12 changes: 5 additions & 7 deletions)

@@ -38,10 +38,9 @@ class ModuleImage;

 class Kernel {
  public:
-  Kernel(std::string kernel_name, uint32_t num_warps, uint32_t shared_mem_bytes,
-         std::string ptx, std::string ttir, int compute_capability,
-         uint32_t cluster_dim_0, uint32_t cluster_dim_1,
-         uint32_t cluster_dim_2);
+  Kernel(std::string kernel_name, uint32_t num_warps, uint32_t num_ctas,
+         uint32_t shared_mem_bytes, std::string ptx, std::string ttir,
+         int compute_capability);

   absl::Status Launch(gpuStream_t stream, uint32_t grid[3], void** params);
@@ -54,11 +53,11 @@ class Kernel {
  private:
   std::string kernel_name_;
   uint32_t block_dim_x_;
+  uint32_t num_ctas_;
   uint32_t shared_mem_bytes_;
   std::string ptx_;
   std::string ttir_;
   int compute_capability_;
-  uint32_t cluster_dims_[3];

   ModuleImage* module_image_ = nullptr;
 };
@@ -107,8 +106,7 @@ class AutotunedKernelCall {

   AutotunedKernelCall(
       std::string name, std::vector<Config> configs,
-      std::vector<std::tuple<size_t,
-                             size_t, size_t>> input_output_aliases);
+      std::vector<std::tuple<size_t, size_t, size_t>> input_output_aliases);

   static absl::StatusOr<KernelCall> Autotune(AutotunedKernelCall kernel_call,
                                              gpuStream_t stream,