@@ -315,17 +315,16 @@ class ModuleImage {
315315 ABSL_GUARDED_BY (mutex_);
316316};
317317
318- Kernel::Kernel (std::string kernel_name, uint32_t num_warps,
318+ Kernel::Kernel (std::string kernel_name, uint32_t num_warps, uint32_t num_ctas,
319319 uint32_t shared_mem_bytes, std::string ptx, std::string ttir,
320- int compute_capability, uint32_t cluster_dim_0,
321- uint32_t cluster_dim_1, uint32_t cluster_dim_2)
320+ int compute_capability)
322321 : kernel_name_(std::move(kernel_name)),
323322 block_dim_x_(num_warps * kNumThreadsPerWarp ),
323+ num_ctas_(num_ctas),
324324 shared_mem_bytes_(shared_mem_bytes),
325325 ptx_(std::move(ptx)),
326326 ttir_(std::move(ttir)),
327- compute_capability_(compute_capability),
328- cluster_dims_{cluster_dim_0, cluster_dim_1, cluster_dim_2} {}
327+ compute_capability_(compute_capability) {}
329328
330329absl::Status Kernel::Launch (gpuStream_t stream, uint32_t grid[3 ],
331330 void ** params) {
@@ -362,26 +361,24 @@ absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3],
362361
363362 JAX_ASSIGN_OR_RETURN (gpuFunction_t kernel,
364363 module_image_->GetFunctionForContext (context));
365- const uint32_t cluster_size =
366- cluster_dims_[0 ] * cluster_dims_[1 ] * cluster_dims_[2 ];
367- if (cluster_size <= 1 ) {
364+ if (num_ctas_ == 1 ) {
368365 return JAX_AS_STATUS (gpuLaunchKernel (
369366 kernel, grid[0 ], grid[1 ], grid[2 ], block_dim_x_,
370367 /* blockDimY=*/ 1 , /* blockDimZ=*/ 1 , shared_mem_bytes_, stream, params,
371368 /* extra=*/ nullptr ));
372369 }
373370 CUlaunchAttribute launch_attrs[2 ];
374371 launch_attrs[0 ].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
375- launch_attrs[0 ].value .clusterDim .x = cluster_dims_[ 0 ] ;
376- launch_attrs[0 ].value .clusterDim .y = cluster_dims_[ 1 ] ;
377- launch_attrs[0 ].value .clusterDim .z = cluster_dims_[ 2 ] ;
372+ launch_attrs[0 ].value .clusterDim .x = num_ctas_ ;
373+ launch_attrs[0 ].value .clusterDim .y = 1 ;
374+ launch_attrs[0 ].value .clusterDim .z = 1 ;
378375 launch_attrs[1 ].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
379376 launch_attrs[1 ].value .clusterSchedulingPolicyPreference =
380377 CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
381378 CUlaunchConfig launch_config = {
382- /* gridDimX=*/ grid[0 ] * cluster_dims_[ 0 ] ,
383- /* gridDimY=*/ grid[1 ] * cluster_dims_[ 1 ] ,
384- /* gridDimZ=*/ grid[2 ] * cluster_dims_[ 2 ] ,
379+ /* gridDimX=*/ grid[0 ] * num_ctas_ ,
380+ /* gridDimY=*/ grid[1 ],
381+ /* gridDimZ=*/ grid[2 ],
385382 /* blockDimX=*/ block_dim_x_,
386383 /* blockDimY=*/ 1 ,
387384 /* blockDimZ=*/ 1 ,
@@ -396,23 +393,23 @@ absl::Status Kernel::Launch(gpuStream_t stream, uint32_t grid[3],
396393}
397394
398395/* static*/ Kernel Kernel::FromProto (const jax_triton::TritonKernel& proto) {
399- return Kernel (proto.kernel_name (), proto.num_warps (),
396+ // Use 1 as default value if not specified in already serialized kernels.
397+ int num_ctas = proto.has_num_ctas () ? proto.num_ctas () : 1 ;
398+
399+ return Kernel (proto.kernel_name (), proto.num_warps (), num_ctas,
400400 proto.shared_mem_bytes (), proto.ptx (), proto.ttir (),
401- proto.compute_capability (), proto.cluster_dim_0 (),
402- proto.cluster_dim_1 (), proto.cluster_dim_2 ());
401+ proto.compute_capability ());
403402}
404403
405404jax_triton::TritonKernel Kernel::ToProto () const {
406405 jax_triton::TritonKernel proto;
407406 proto.set_kernel_name (kernel_name_);
408407 proto.set_num_warps (block_dim_x_ / kNumThreadsPerWarp );
408+ proto.set_num_ctas (num_ctas_);
409409 proto.set_shared_mem_bytes (shared_mem_bytes_);
410410 proto.set_ptx (ptx_);
411411 proto.set_ttir (ttir_);
412412 proto.set_compute_capability (compute_capability_);
413- proto.set_cluster_dim_0 (cluster_dims_[0 ]);
414- proto.set_cluster_dim_1 (cluster_dims_[1 ]);
415- proto.set_cluster_dim_2 (cluster_dims_[2 ]);
416413 return proto;
417414}
418415
0 commit comments