From fa3d9d3078826810f49a24b2f2b7b4aab96dda64 Mon Sep 17 00:00:00 2001 From: Akarshan Date: Sun, 31 Aug 2025 09:16:37 +0530 Subject: [PATCH 1/3] Test CUDA conv2D type conversion fix --- .github/workflows/menlo-build.yml | 5 ++--- ggml/src/ggml-cuda/conv2d.cu | 13 +++++++++++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/.github/workflows/menlo-build.yml b/.github/workflows/menlo-build.yml index 07f1c75c10d..8a62b1bcc57 100644 --- a/.github/workflows/menlo-build.yml +++ b/.github/workflows/menlo-build.yml @@ -56,7 +56,6 @@ jobs: build-and-test: runs-on: ${{ matrix.runs-on }} - needs: [create-draft-release] timeout-minutes: 270 strategy: fail-fast: false @@ -285,7 +284,7 @@ jobs: uses: actions/checkout@v3 with: submodules: recursive - + - name: Replace our Makefile run: | cat menlo/Makefile | tee Makefile @@ -635,4 +634,4 @@ jobs: upload_url: ${{ needs.create-draft-release.outputs.upload_url }} asset_path: /tmp/cudart-llama-bin-win-cu11.7-x64.tar.gz asset_name: cudart-llama-bin-win-cu11.7-x64.tar.gz - asset_content_type: application/gzip \ No newline at end of file + asset_content_type: application/gzip diff --git a/ggml/src/ggml-cuda/conv2d.cu b/ggml/src/ggml-cuda/conv2d.cu index bcb70762ee0..084bc29cf77 100644 --- a/ggml/src/ggml-cuda/conv2d.cu +++ b/ggml/src/ggml-cuda/conv2d.cu @@ -17,6 +17,15 @@ struct kernel_bounds { int64_t x_min, x_max; }; +template +__device__ __forceinline__ float to_float(const T& val) { + if constexpr (std::is_same_v) { + return __half2float(val); + } else { + return val; // Assumes T is float + } +} + __device__ __forceinline__ int64_t max64(int64_t a, int64_t b) { return (a > b) ? a : b; } @@ -94,8 +103,8 @@ static __global__ void conv2d_kernel(const float * __restrict__ input, const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X); const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)]; - const float kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)]; - acc += (input_val * kernel_val); + const T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)]; + acc += (input_val * to_float(kernel_val)); } } } From 7393cdb2c7131e1f2896433240412e3ed5b109da Mon Sep 17 00:00:00 2001 From: Minh141120 Date: Sun, 31 Aug 2025 10:59:11 +0700 Subject: [PATCH 2/3] chore: temporarily add trigger on pull request dev --- .github/workflows/menlo-build.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/menlo-build.yml b/.github/workflows/menlo-build.yml index 8a62b1bcc57..27968fd602f 100644 --- a/.github/workflows/menlo-build.yml +++ b/.github/workflows/menlo-build.yml @@ -25,6 +25,9 @@ on: ] workflow_dispatch: + pull_request: + branches: ["dev"] + env: VULKAN_VERSION: 1.3.261.1 From eafb220a8a30d7a653fd5424e9266954024bd3ee Mon Sep 17 00:00:00 2001 From: Akarshan Date: Sun, 31 Aug 2025 15:36:04 +0530 Subject: [PATCH 3/3] test 2 with reviewed comments --- ggml/src/ggml-cuda/conv2d.cu | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d.cu b/ggml/src/ggml-cuda/conv2d.cu index 084bc29cf77..21f12422373 100644 --- a/ggml/src/ggml-cuda/conv2d.cu +++ b/ggml/src/ggml-cuda/conv2d.cu @@ -1,4 +1,5 @@ #include "conv2d.cuh" +#include "convert.cuh" struct conv_params { const int64_t IW, IH; @@ -17,15 +18,6 @@ struct kernel_bounds { int64_t x_min, x_max; }; -template -__device__ __forceinline__ float to_float(const T& val) { - if constexpr (std::is_same_v) { - return __half2float(val); - } else { - return val; // Assumes T is float - } -} - __device__ __forceinline__ int64_t max64(int64_t a, int64_t b) { return (a > b) ? a : b; } @@ -104,7 +96,7 @@ static __global__ void conv2d_kernel(const float * __restrict__ input, const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)]; const T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)]; - acc += (input_val * to_float(kernel_val)); + acc += (input_val * ggml_cuda_cast(kernel_val)); } } }