@@ -228,6 +228,8 @@ struct vk_device_struct {
228228 vk_pipeline pipeline_repeat_f32;
229229 vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
230230 vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
231+ vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
232+ vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
231233 vk_pipeline pipeline_norm_f32;
232234 vk_pipeline pipeline_group_norm_f32;
233235 vk_pipeline pipeline_rms_norm_f32;
@@ -1965,6 +1967,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
19651967 ggml_vk_create_pipeline (device, device->pipeline_contig_cpy_f32_f16 , " contig_cpy_f32_f16" , contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {512 , 1 , 1 }, {}, 1 );
19661968 ggml_vk_create_pipeline (device, device->pipeline_contig_cpy_f16_f16 , " contig_cpy_f16_f16" , contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {512 , 1 , 1 }, {}, 1 );
19671969
1970+ ggml_vk_create_pipeline (device, device->pipeline_cpy_f32_quant [GGML_TYPE_Q4_0], " cpy_f32_q4_0" , cpy_f32_q4_0_len, cpy_f32_q4_0_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q4_0), 1 , 1 }, {}, 1 );
1971+ ggml_vk_create_pipeline (device, device->pipeline_cpy_f32_quant [GGML_TYPE_Q4_1], " cpy_f32_q4_1" , cpy_f32_q4_1_len, cpy_f32_q4_1_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q4_1), 1 , 1 }, {}, 1 );
1972+ ggml_vk_create_pipeline (device, device->pipeline_cpy_f32_quant [GGML_TYPE_Q5_0], " cpy_f32_q5_0" , cpy_f32_q5_0_len, cpy_f32_q5_0_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q5_0), 1 , 1 }, {}, 1 );
1973+ ggml_vk_create_pipeline (device, device->pipeline_cpy_f32_quant [GGML_TYPE_Q5_1], " cpy_f32_q5_1" , cpy_f32_q5_1_len, cpy_f32_q5_1_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q5_1), 1 , 1 }, {}, 1 );
1974+ ggml_vk_create_pipeline (device, device->pipeline_cpy_f32_quant [GGML_TYPE_Q8_0], " cpy_f32_q8_0" , cpy_f32_q8_0_len, cpy_f32_q8_0_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q8_0), 1 , 1 }, {}, 1 );
1975+ ggml_vk_create_pipeline (device, device->pipeline_cpy_f32_quant [GGML_TYPE_IQ4_NL], " cpy_f32_iq4_nl" , cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_IQ4_NL), 1 , 1 }, {}, 1 );
1976+
1977+ ggml_vk_create_pipeline (device, device->pipeline_cpy_quant_f32 [GGML_TYPE_Q4_0], " cpy_q4_0_f32" , cpy_q4_0_f32_len, cpy_q4_0_f32_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q4_0), 1 , 1 }, {}, 1 );
1978+ ggml_vk_create_pipeline (device, device->pipeline_cpy_quant_f32 [GGML_TYPE_Q4_1], " cpy_q4_1_f32" , cpy_q4_1_f32_len, cpy_q4_1_f32_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q4_1), 1 , 1 }, {}, 1 );
1979+ ggml_vk_create_pipeline (device, device->pipeline_cpy_quant_f32 [GGML_TYPE_Q5_0], " cpy_q5_0_f32" , cpy_q5_0_f32_len, cpy_q5_0_f32_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q5_0), 1 , 1 }, {}, 1 );
1980+ ggml_vk_create_pipeline (device, device->pipeline_cpy_quant_f32 [GGML_TYPE_Q5_1], " cpy_q5_1_f32" , cpy_q5_1_f32_len, cpy_q5_1_f32_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q5_1), 1 , 1 }, {}, 1 );
1981+ ggml_vk_create_pipeline (device, device->pipeline_cpy_quant_f32 [GGML_TYPE_Q8_0], " cpy_q8_0_f32" , cpy_q8_0_f32_len, cpy_q8_0_f32_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_Q8_0), 1 , 1 }, {}, 1 );
1982+ ggml_vk_create_pipeline (device, device->pipeline_cpy_quant_f32 [GGML_TYPE_IQ4_NL], " cpy_iq4_nl_f32" , cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {(uint32_t )ggml_blck_size (GGML_TYPE_IQ4_NL), 1 , 1 }, {}, 1 );
1983+
19681984 ggml_vk_create_pipeline (device, device->pipeline_add_f32 , " add_f32" , add_f32_len, add_f32_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {0 }, 1 );
19691985 ggml_vk_create_pipeline (device, device->pipeline_add_f32_norepeat , " add_f32_norepeat" , add_f32_len, add_f32_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {1 }, 1 );
19701986 ggml_vk_create_pipeline (device, device->pipeline_add_f16_f32_f16 , " add_f16_f32_f16" , add_f16_f32_f16_len, add_f16_f32_f16_data, " main" , 3 , sizeof (vk_op_binary_push_constants), {512 , 1 , 1 }, {0 }, 1 );
@@ -3689,6 +3705,33 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
36893705 return ctx->device ->pipeline_cpy_f16_f16 ;
36903706 }
36913707 }
3708+ if (src->type == GGML_TYPE_F32) {
3709+ switch (to) {
3710+ case GGML_TYPE_Q4_0:
3711+ case GGML_TYPE_Q4_1:
3712+ case GGML_TYPE_Q5_0:
3713+ case GGML_TYPE_Q5_1:
3714+ case GGML_TYPE_Q8_0:
3715+ case GGML_TYPE_IQ4_NL:
3716+ return ctx->device ->pipeline_cpy_f32_quant [to];
3717+ default :
3718+ break ;
3719+ }
3720+ }
3721+
3722+ if (to == GGML_TYPE_F32) {
3723+ switch (src->type ) {
3724+ case GGML_TYPE_Q4_0:
3725+ case GGML_TYPE_Q4_1:
3726+ case GGML_TYPE_Q5_0:
3727+ case GGML_TYPE_Q5_1:
3728+ case GGML_TYPE_Q8_0:
3729+ case GGML_TYPE_IQ4_NL:
3730+ return ctx->device ->pipeline_cpy_quant_f32 [src->type ];
3731+ default :
3732+ break ;
3733+ }
3734+ }
36923735
36933736 std::cerr << " Missing CPY op for types: " << ggml_type_name (src->type ) << " " << ggml_type_name (to) << std::endl;
36943737 GGML_ABORT (" fatal error" );
@@ -5160,7 +5203,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
51605203 }
51615204 std::cerr << " ), (" << dst << " , name=" << dst->name << " , type=" << dst->type << " , ne0=" << dst->ne [0 ] << " , ne1=" << dst->ne [1 ] << " , ne2=" << dst->ne [2 ] << " , ne3=" << dst->ne [3 ] << " , nb0=" << dst->nb [0 ] << " , nb1=" << dst->nb [1 ] << " , nb2=" << dst->nb [2 ] << " , nb3=" << dst->nb [3 ];
51625205 std::cerr << " ), " << ggml_op_name (op) << " , " << (dryrun ? " dryrun" : " " ) << " )" );
5163- GGML_ASSERT (op == GGML_OP_GET_ROWS || (!ggml_is_quantized (src0->type ) && (src1 == nullptr || !ggml_is_quantized (src1->type )))); // NOLINT
5206+ GGML_ASSERT (op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized (src0->type ) && (src1 == nullptr || !ggml_is_quantized (src1->type )))); // NOLINT
51645207 GGML_ASSERT (ggml_vk_op_supports_incontiguous (op) || ggml_vk_dim01_contiguous (src0)); // NOLINT
51655208 GGML_ASSERT (dst->buffer != nullptr );
51665209 const uint64_t ne00 = src0->ne [0 ];
@@ -7905,12 +7948,36 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
79057948 {
79067949 ggml_type src0_type = op->src [0 ]->type ;
79077950 ggml_type src1_type = op->src [1 ] != nullptr ? op->src [1 ]->type : src0_type;
7908- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
7909- return true ;
7951+
7952+ if (src0_type == GGML_TYPE_F32) {
7953+ switch (src1_type) {
7954+ case GGML_TYPE_F32:
7955+ case GGML_TYPE_F16:
7956+ case GGML_TYPE_Q4_0:
7957+ case GGML_TYPE_Q4_1:
7958+ case GGML_TYPE_Q5_0:
7959+ case GGML_TYPE_Q5_1:
7960+ case GGML_TYPE_Q8_0:
7961+ case GGML_TYPE_IQ4_NL:
7962+ return true ;
7963+ default :
7964+ break ;
7965+ }
79107966 }
7911- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
7912- return true ;
7967+ if (src1_type == GGML_TYPE_F32) {
7968+ switch (src0_type) {
7969+ case GGML_TYPE_Q4_0:
7970+ case GGML_TYPE_Q4_1:
7971+ case GGML_TYPE_Q5_0:
7972+ case GGML_TYPE_Q5_1:
7973+ case GGML_TYPE_Q8_0:
7974+ case GGML_TYPE_IQ4_NL:
7975+ return true ;
7976+ default :
7977+ break ;
7978+ }
79137979 }
7980+
79147981 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
79157982 return true ;
79167983 }
0 commit comments