Skip to content

Commit 5cdfbf0

Browse files
loislo authored and Google-ML-Automation committed
Integrate Triton up to 8d445186
https://github.com/openxla/triton/tree/triton_integrate_branch-1.15
PiperOrigin-RevId: 841726853
1 parent fac6550 commit 5cdfbf0

File tree

4 files changed: 27 additions (+27) and 32 deletions (-32).

jaxlib/gpu/triton.cc

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -45,8 +45,8 @@ namespace jax::JAX_GPU_NAMESPACE {
4545

4646
NB_MODULE(_triton, m) {
4747
nb::class_<Kernel>(m, "TritonKernel")
48-
.def(nb::init<std::string, uint32_t, uint32_t, std::string, std::string,
49-
int, uint32_t, uint32_t, uint32_t>());
48+
.def(nb::init<std::string, uint32_t, uint32_t, uint32_t, std::string,
49+
std::string, int>());
5050

5151
nb::class_<KernelCall::Parameter>(m, "TritonParameter");
5252

jaxlib/gpu/triton.proto

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -5,13 +5,13 @@ package jax_triton;
55
message TritonKernel {
66
string kernel_name = 1; // Kernel function name within module.
77
uint32 num_warps = 2;
8+
optional uint32 num_ctas = 10;
89
uint32 shared_mem_bytes = 3;
910
string ptx = 4;
1011
string ttir = 5;
1112
uint32 compute_capability = 6;
12-
uint32 cluster_dim_0 = 7;
13-
uint32 cluster_dim_1 = 8;
14-
uint32 cluster_dim_2 = 9;
13+
14+
reserved 7, 8, 9; // cluster_dim_0, cluster_dim_1, cluster_dim_2
1515
}
1616

1717
message TritonKernelCall {

jaxlib/gpu/triton_kernels.cc

Lines changed: 17 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -315,17 +315,16 @@ class ModuleImage {
315315
ABSL_GUARDED_BY(mutex_);
316316
};
317317

318-
Kernel::Kernel(std::string kernel_name, uint32_t num_warps,
318+
Kernel::Kernel(std::string kernel_name, uint32_t num_warps, uint32_t num_ctas,
319319
uint32_t shared_mem_bytes, std::string ptx, std::string ttir,
320-
int compute_capability, uint32_t cluster_dim_0,
321-
uint32_t cluster_dim_1, uint32_t cluster_dim_2)
320+
int compute_capability)
322321
: kernel_name_(std::move(kernel_name)),
323322
block_dim_x_(num_warps * kNumThreadsPerWarp),
323+
num_ctas_(num_ctas),
324324
shared_mem_bytes_(shared_mem_bytes),
325325
ptx_(std::move(ptx)),
326326
ttir_(std::move(ttir)),
327-
compute_capability_(compute_capability),
328-
cluster_dims_{cluster_dim_0, cluster_dim_1, cluster_dim_2} {}
327+
compute_capability_(compute_capability) {}
329328

330329
absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3],
331330
void** params) {
@@ -362,26 +361,24 @@ absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3],
362361

363362
JAX_ASSIGN_OR_RETURN(gpuFunction_t kernel,
364363
module_image_->GetFunctionForContext(context));
365-
const uint32_t cluster_size =
366-
cluster_dims_[0] * cluster_dims_[1] * cluster_dims_[2];
367-
if (cluster_size <= 1) {
364+
if (num_ctas_ == 1) {
368365
return JAX_AS_STATUS(gpuLaunchKernel(
369366
kernel, grid[0], grid[1], grid[2], block_dim_x_,
370367
/*blockDimY=*/1, /*blockDimZ=*/1, shared_mem_bytes_, stream, params,
371368
/*extra=*/nullptr));
372369
}
373370
CUlaunchAttribute launch_attrs[2];
374371
launch_attrs[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
375-
launch_attrs[0].value.clusterDim.x = cluster_dims_[0];
376-
launch_attrs[0].value.clusterDim.y = cluster_dims_[1];
377-
launch_attrs[0].value.clusterDim.z = cluster_dims_[2];
372+
launch_attrs[0].value.clusterDim.x = num_ctas_;
373+
launch_attrs[0].value.clusterDim.y = 1;
374+
launch_attrs[0].value.clusterDim.z = 1;
378375
launch_attrs[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
379376
launch_attrs[1].value.clusterSchedulingPolicyPreference =
380377
CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
381378
CUlaunchConfig launch_config = {
382-
/*gridDimX=*/grid[0] * cluster_dims_[0],
383-
/*gridDimY=*/grid[1] * cluster_dims_[1],
384-
/*gridDimZ=*/grid[2] * cluster_dims_[2],
379+
/*gridDimX=*/grid[0] * num_ctas_,
380+
/*gridDimY=*/grid[1],
381+
/*gridDimZ=*/grid[2],
385382
/*blockDimX=*/block_dim_x_,
386383
/*blockDimY=*/1,
387384
/*blockDimZ=*/1,
@@ -396,23 +393,23 @@ absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3],
396393
}
397394

398395
/*static*/ Kernel Kernel::FromProto(const jax_triton::TritonKernel& proto) {
399-
return Kernel(proto.kernel_name(), proto.num_warps(),
396+
// Use 1 as default value if not specified in already serialized kernels.
397+
int num_ctas = proto.has_num_ctas() ? proto.num_ctas() : 1;
398+
399+
return Kernel(proto.kernel_name(), proto.num_warps(), num_ctas,
400400
proto.shared_mem_bytes(), proto.ptx(), proto.ttir(),
401-
proto.compute_capability(), proto.cluster_dim_0(),
402-
proto.cluster_dim_1(), proto.cluster_dim_2());
401+
proto.compute_capability());
403402
}
404403

405404
jax_triton::TritonKernel Kernel::ToProto() const {
406405
jax_triton::TritonKernel proto;
407406
proto.set_kernel_name(kernel_name_);
408407
proto.set_num_warps(block_dim_x_ / kNumThreadsPerWarp);
408+
proto.set_num_ctas(num_ctas_);
409409
proto.set_shared_mem_bytes(shared_mem_bytes_);
410410
proto.set_ptx(ptx_);
411411
proto.set_ttir(ttir_);
412412
proto.set_compute_capability(compute_capability_);
413-
proto.set_cluster_dim_0(cluster_dims_[0]);
414-
proto.set_cluster_dim_1(cluster_dims_[1]);
415-
proto.set_cluster_dim_2(cluster_dims_[2]);
416413
return proto;
417414
}
418415

jaxlib/gpu/triton_kernels.h

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -38,10 +38,9 @@ class ModuleImage;
3838

3939
class Kernel {
4040
public:
41-
Kernel(std::string kernel_name, uint32_t num_warps, uint32_t shared_mem_bytes,
42-
std::string ptx, std::string ttir, int compute_capability,
43-
uint32_t cluster_dim_0, uint32_t cluster_dim_1,
44-
uint32_t cluster_dim_2);
41+
Kernel(std::string kernel_name, uint32_t num_warps, uint32_t num_ctas,
42+
uint32_t shared_mem_bytes, std::string ptx, std::string ttir,
43+
int compute_capability);
4544

4645
absl::Status Launch(gpuStream_t stream, uint32_t grid[3], void** params);
4746

@@ -54,11 +53,11 @@ class Kernel {
5453
private:
5554
std::string kernel_name_;
5655
uint32_t block_dim_x_;
56+
uint32_t num_ctas_;
5757
uint32_t shared_mem_bytes_;
5858
std::string ptx_;
5959
std::string ttir_;
6060
int compute_capability_;
61-
uint32_t cluster_dims_[3];
6261

6362
ModuleImage* module_image_ = nullptr;
6463
};
@@ -107,8 +106,7 @@ class AutotunedKernelCall {
107106

108107
AutotunedKernelCall(
109108
std::string name, std::vector<Config> configs,
110-
std::vector<std::tuple<size_t,
111-
size_t, size_t>> input_output_aliases);
109+
std::vector<std::tuple<size_t, size_t, size_t>> input_output_aliases);
112110

113111
static absl::StatusOr<KernelCall> Autotune(AutotunedKernelCall kernel_call,
114112
gpuStream_t stream,

0 commit comments

Comments
 (0)