@@ -92,111 +92,112 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight
9292 torch::Tensor X, // input
9393 int64_t type, int64_t row) {
9494 int col = X.sizes ()[1 ];
95+ int vecs = X.sizes ()[0 ];
9596 const int padded = (col + 512 - 1 ) / 512 * 512 ;
9697 const at::cuda::OptionalCUDAGuard device_guard (device_of (X));
9798 auto options = torch::TensorOptions ().dtype (X.dtype ()).device (W.device ());
98- at::Tensor Y = torch::empty ({1 , row}, options);
99+ at::Tensor Y = torch::empty ({vecs , row}, options);
99100 cudaStream_t stream = at::cuda::getCurrentCUDAStream ().stream ();
100101 options = torch::TensorOptions ().dtype (torch::kInt32 ).device (W.device ());
101- at::Tensor quant_X = torch::empty ({1 , padded / 32 * 9 }, options);
102+ at::Tensor quant_X = torch::empty ({vecs , padded / 32 * 9 }, options);
102103 VLLM_DISPATCH_FLOATING_TYPES (X.scalar_type (), " ggml_mul_mat_vec_a8" , [&] {
103- quantize_row_q8_1_cuda<scalar_t >(( scalar_t *)X. data_ptr (),
104- (void *)quant_X.data_ptr (), col, 1 , stream);
104+ quantize_row_q8_1_cuda<scalar_t >(
105+ ( scalar_t *)X. data_ptr (), (void *)quant_X.data_ptr (), col, vecs , stream);
105106 switch (type) {
106107 case 2 :
107108 mul_mat_vec_q4_0_q8_1_cuda<scalar_t >(
108109 (void *)W.data_ptr (), (void *)quant_X.data_ptr (),
109- (scalar_t *)Y.data_ptr (), col, row, stream);
110+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
110111 break ;
111112 case 3 :
112113 mul_mat_vec_q4_1_q8_1_cuda<scalar_t >(
113114 (void *)W.data_ptr (), (void *)quant_X.data_ptr (),
114- (scalar_t *)Y.data_ptr (), col, row, stream);
115+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
115116 break ;
116117 case 6 :
117118 mul_mat_vec_q5_0_q8_1_cuda<scalar_t >(
118119 (void *)W.data_ptr (), (void *)quant_X.data_ptr (),
119- (scalar_t *)Y.data_ptr (), col, row, stream);
120+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
120121 break ;
121122 case 7 :
122123 mul_mat_vec_q5_1_q8_1_cuda<scalar_t >(
123124 (void *)W.data_ptr (), (void *)quant_X.data_ptr (),
124- (scalar_t *)Y.data_ptr (), col, row, stream);
125+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
125126 break ;
126127 case 8 :
127128 mul_mat_vec_q8_0_q8_1_cuda<scalar_t >(
128129 (void *)W.data_ptr (), (void *)quant_X.data_ptr (),
129- (scalar_t *)Y.data_ptr (), col, row, stream);
130+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
130131 break ;
131132 case 10 :
132133 mul_mat_vec_q2_K_q8_1_cuda<scalar_t >(
133134 (void *)W.data_ptr (), (void *)quant_X.data_ptr (),
134- (scalar_t *)Y.data_ptr (), col, row, stream);
135+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
135136 break ;
136137 case 11 :
137138 mul_mat_vec_q3_K_q8_1_cuda<scalar_t >(
138139 (void *)W.data_ptr (), (void *)quant_X.data_ptr (),
139- (scalar_t *)Y.data_ptr (), col, row, stream);
140+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
140141 break ;
141142 case 12 :
142143 mul_mat_vec_q4_K_q8_1_cuda<scalar_t >(
143144 (void *)W.data_ptr (), (void *)quant_X.data_ptr (),
144- (scalar_t *)Y.data_ptr (), col, row, stream);
145+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
145146 break ;
146147 case 13 :
147148 mul_mat_vec_q5_K_q8_1_cuda<scalar_t >(
148149 (void *)W.data_ptr (), (void *)quant_X.data_ptr (),
149- (scalar_t *)Y.data_ptr (), col, row, stream);
150+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
150151 break ;
151152 case 14 :
152153 mul_mat_vec_q6_K_q8_1_cuda<scalar_t >(
153154 (void *)W.data_ptr (), (void *)quant_X.data_ptr (),
154- (scalar_t *)Y.data_ptr (), col, row, stream);
155+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
155156 break ;
156157 case 16 :
157158 mul_mat_vec_iq2_xxs_q8_1_cuda<scalar_t >(
158159 (void *)W.data_ptr (), (void *)quant_X.data_ptr (),
159- (scalar_t *)Y.data_ptr (), col, row, stream);
160+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
160161 break ;
161162 case 17 :
162163 mul_mat_vec_iq2_xs_q8_1_cuda<scalar_t >(
163164 (void *)W.data_ptr (), (void *)quant_X.data_ptr (),
164- (scalar_t *)Y.data_ptr (), col, row, stream);
165+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
165166 break ;
166167 case 18 :
167168 mul_mat_vec_iq3_xxs_q8_1_cuda<scalar_t >(
168169 (void *)W.data_ptr (), (void *)quant_X.data_ptr (),
169- (scalar_t *)Y.data_ptr (), col, row, stream);
170+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
170171 break ;
171172 case 19 :
172173 mul_mat_vec_iq1_s_q8_1_cuda<scalar_t >(
173174 (void *)W.data_ptr (), (void *)quant_X.data_ptr (),
174- (scalar_t *)Y.data_ptr (), col, row, stream);
175+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
175176 break ;
176177 case 20 :
177178 mul_mat_vec_iq4_nl_q8_1_cuda<scalar_t >(
178179 (void *)W.data_ptr (), (void *)quant_X.data_ptr (),
179- (scalar_t *)Y.data_ptr (), col, row, stream);
180+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
180181 break ;
181182 case 21 :
182183 mul_mat_vec_iq3_s_q8_1_cuda<scalar_t >(
183184 (void *)W.data_ptr (), (void *)quant_X.data_ptr (),
184- (scalar_t *)Y.data_ptr (), col, row, stream);
185+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
185186 break ;
186187 case 22 :
187188 mul_mat_vec_iq2_s_q8_1_cuda<scalar_t >(
188189 (void *)W.data_ptr (), (void *)quant_X.data_ptr (),
189- (scalar_t *)Y.data_ptr (), col, row, stream);
190+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
190191 break ;
191192 case 23 :
192193 mul_mat_vec_iq4_xs_q8_1_cuda<scalar_t >(
193194 (void *)W.data_ptr (), (void *)quant_X.data_ptr (),
194- (scalar_t *)Y.data_ptr (), col, row, stream);
195+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
195196 break ;
196197 case 29 :
197198 mul_mat_vec_iq1_m_q8_1_cuda<scalar_t >(
198199 (void *)W.data_ptr (), (void *)quant_X.data_ptr (),
199- (scalar_t *)Y.data_ptr (), col, row, stream);
200+ (scalar_t *)Y.data_ptr (), col, row, vecs, stream);
200201 break ;
201202 }
202203 });
0 commit comments