Commit 2b09423

Author: ssjia
[ez][ET-VK] Small fix for choose_qparams_affine_impl
It seems that new arguments were recently appended to the `choose_qparams_affine` schema. This breaks newly exported models because, at runtime, the output arg can no longer be found at the expected index. Fix by reading the output argument from the last entry of the args vector instead of continuing to increment the args index.

Update the quantize/dequantize ops as well, since ops in the `quantized_decomposed` namespace seem subject to change in the future. Note that it would be good to do this for all operators in the Vulkan backend in a later refactor.

Differential Revision: [D88887463](https://our.internmc.facebook.com/intern/diff/D88887463/)

[ghstack-poisoned]
1 parent a0a6278 commit 2b09423
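
The failure mode described in the commit message can be illustrated with a minimal sketch. Everything here is hypothetical: `ValueRef` is a string stand-in so the selected entry is visible, and `output_by_position` and `output_by_last` are illustrative helpers, not functions from the Vulkan backend. The sketch assumes only what the message states, namely that the output is the final entry of the args vector.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for the backend's ValueRef; a label makes the
// selected entry easy to inspect.
using ValueRef = std::string;

// Positional lookup: assumes the output sits right after the inputs the
// implementation knows about. Breaks if the schema gains extra arguments.
ValueRef output_by_position(
    const std::vector<ValueRef>& args,
    std::size_t num_known_inputs) {
  return args.at(num_known_inputs);
}

// Last-entry lookup: assumes only that the output is the final argument,
// so arguments appended before it do not shift the result.
ValueRef output_by_last(const std::vector<ValueRef>& args) {
  return args.at(args.size() - 1);
}
```

With an args list of `{"input", "scale_dtype", "out"}` both helpers return `"out"`. Once a hypothetical `"extra_arg"` is appended before the output, the positional lookup returns `"extra_arg"` while the last-entry lookup still returns `"out"`.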

2 files changed: 10 additions, 6 deletions


backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp

Lines changed: 4 additions & 2 deletions
@@ -158,7 +158,8 @@ bool can_use_choose_qparams_per_row(
 void choose_qparams_affine_impl(
     ComputeGraph& graph,
     const std::vector<ValueRef>& args) {
-  int arg_idx = 0;
+  size_t arg_idx = 0;
+  size_t last_arg_idx = args.size() - 1;
   const ValueRef input = args[arg_idx++];
   const ValueRef mapping_type = args[arg_idx++];
   (void)mapping_type;
@@ -170,7 +171,8 @@ void choose_qparams_affine_impl(
   (void)eps;
   const ValueRef scale_dtype = args[arg_idx++];
   const ValueRef zero_point_dtype = args[arg_idx++];
-  const ValueRef out_tuple_ref = args[arg_idx++];
+
+  const ValueRef out_tuple_ref = args[last_arg_idx];
 
   // Suppress unused variable warnings
   (void)target_dtype;

backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.cpp

Lines changed: 6 additions & 4 deletions
@@ -369,7 +369,8 @@ void add_unpack_4w4c_and_dequantize_node(
 void quantize_per_tensor_impl(
     ComputeGraph& graph,
     const std::vector<ValueRef>& args) {
-  int32_t arg_idx = 0;
+  size_t arg_idx = 0;
+  size_t last_arg_idx = args.size() - 1;
   const ValueRef fp_input = args[arg_idx++];
   const ValueRef scale = args[arg_idx++];
   const ValueRef zero_point = args[arg_idx++];
@@ -380,7 +381,7 @@ void quantize_per_tensor_impl(
   const ValueRef dtype = args[arg_idx++];
   (void)dtype;
 
-  const ValueRef int8_output = args[arg_idx++];
+  const ValueRef int8_output = args[last_arg_idx];
 
   VK_CHECK_COND(
       graph.estimate_memory_layout_of(int8_output) == utils::kPackedInt8_4W4C);
@@ -392,7 +393,8 @@ void quantize_per_tensor_impl(
 void dequantize_per_tensor_impl(
     ComputeGraph& graph,
     const std::vector<ValueRef>& args) {
-  int32_t arg_idx = 0;
+  size_t arg_idx = 0;
+  size_t last_arg_idx = args.size() - 1;
   const ValueRef int8_input = args[arg_idx++];
   const ValueRef scale = args[arg_idx++];
   const ValueRef zero_point = args[arg_idx++];
@@ -405,7 +407,7 @@ void dequantize_per_tensor_impl(
   const ValueRef output_dtype = args[arg_idx++];
   (void)output_dtype;
 
-  const ValueRef fp_output = args[arg_idx++];
+  const ValueRef fp_output = args[last_arg_idx];
 
   VK_CHECK_COND(
       graph.estimate_memory_layout_of(int8_input) == utils::kPackedInt8_4W4C);
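
One property of the `args.size() - 1` pattern used in these hunks is worth keeping in mind for the later refactor the commit message mentions: `size()` returns an unsigned type, so the subtraction wraps around to a very large value when the vector is empty. The sketch below shows one possible guard. `last_arg` is a hypothetical helper and `ValueRef` is an `int` stand-in. Neither is taken from the backend, and whether an empty args vector can ever reach these functions is not something the diff establishes.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for the backend's ValueRef.
using ValueRef = int;

// Hypothetical helper: returns the final entry of args and fails loudly on
// an empty vector instead of indexing with a wrapped-around size_t.
ValueRef last_arg(const std::vector<ValueRef>& args) {
  if (args.empty()) {
    throw std::invalid_argument("expected at least an output argument");
  }
  return args.back();
}
```

Calling `last_arg` on `{10, 20, 30}` yields `30`, while an empty vector raises `std::invalid_argument` rather than reading out of bounds.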
