@@ -3010,8 +3010,11 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
30103010 SYCL_CHECK (CHECK_TRY_ERROR (
30113011 stream->memset (dev_cur_src1_row.get (), 0 , sizeof (int ))));
30123012
3013+ const unsigned int max_work_group_size = ggml_sycl_info ().work_group_size (ctx.device );
3014+ assert (work_group_size % (WARP_SIZE * WARP_SIZE) == 0 );
3015+
30133016 {
3014- sycl::range<3 > block_dims (1 , 1 , std::min ((unsigned int )ne10, 768u ));
3017+ sycl::range<3 > block_dims (1 , 1 , std::min ((unsigned int )ne10, max_work_group_size ));
30153018 sycl::range<3 > grid_dims (1 , n_ids, ids->ne [1 ]);
30163019 stream->submit ([&](sycl::handler &cgh) {
30173020 sycl::local_accessor<int , 0 > src1_row_acc (cgh);
@@ -3056,7 +3059,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
30563059 ggml_sycl_mul_mat (ctx, &src0_row, &src1_row, &dst_row);
30573060
30583061 {
3059- sycl::range<3 > block_dims (1 , 1 , std::min ((unsigned int )ne0, 768u ));
3062+ sycl::range<3 > block_dims (1 , 1 , std::min ((unsigned int )ne0, max_work_group_size ));
30603063 sycl::range<3 > grid_dims (1 , 1 , num_src1_rows);
30613064 stream->submit ([&](sycl::handler &cgh) {
30623065 const char *__restrict dst_contiguous_get =
0 commit comments