Skip to content

Commit 77d9e97

Browse files
SS-JIA and ssjia authored
[ET-VK] Enable and test texture IO for quantized convolution ops (#16082)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * __->__ #16082 * #16079 Title says it all! Using texture3d as the input/output storage type may allow for additional optimizations re: bounds checking. Differential Revision: [D88395020](https://our.internmc.facebook.com/intern/diff/D88395020/) --------- Co-authored-by: ssjia <ssjia@devvm1479.ncg0.facebook.com>
1 parent 2d1f181 commit 77d9e97

11 files changed

+97
-25
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_int8_output_tile_store.glslh

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
void store_packed_int8_output_tile(
1818
const Int8OutTile int8_tile,
19-
const Conv2dBlockIndex block_idx,
19+
Conv2dBlockIndex block_idx,
2020
const Conv2dBlockExtents block_extents) {
2121
#ifdef PACKED_INT8_OUTPUT_BUFFER
2222
[[unroll]] for (int m4 = 0; m4 < TILE_M4; m4++) {
@@ -34,8 +34,11 @@ void store_packed_int8_output_tile(
3434
[[unroll]] for (int n4 = 0; n4 < TILE_N4; n4++) {
3535
if (block_idx.data.x + m4 < block_extents.data.x &&
3636
block_idx.data.z + n4 < block_extents.data.z) {
37+
const ivec3 idx_offset = ivec3(m4, 0, n4);
3738
imageStore(
38-
t_packed_int8_output, block_idx.data, int8_tile.data[m4][n4]);
39+
t_packed_int8_output,
40+
block_idx.data + idx_offset,
41+
int8_tile.data[m4][n4]);
3942
}
4043
}
4144
}

backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ conv2d_q8ta_q8csw_q8to:
1414
parameter_names: [IO_STORAGE, WEIGHT_STORAGE]
1515
combos:
1616
- parameter_values: [buffer, texture2d]
17+
- parameter_values: [texture3d, texture2d]
1718
DTYPE:
1819
- VALUE: float
1920
shader_variants:

backends/vulkan/runtime/graph/ops/glsl/conv2d_q8ta_q8csw_q8to_linear_tiled.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ conv2d_q8ta_q8csw_q8to_linear_tiled:
1414
parameter_names: [IO_STORAGE, WEIGHT_STORAGE]
1515
combos:
1616
- parameter_values: [buffer, texture2d]
17+
- parameter_values: [texture3d, texture2d]
1718
DTYPE:
1819
- VALUE: float
1920
shader_variants:

backends/vulkan/runtime/graph/ops/glsl/im2col_packed_int8.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@ im2col_packed_int8:
1010
generate_variant_forall:
1111
STORAGE:
1212
- VALUE: buffer
13+
- VALUE: texture3d
1314
shader_variants:
1415
- NAME: im2col_packed_int8

backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile.glslh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@ void printInt8InputTile(const Int8InputTile tile) {
3131

3232
[[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) {
3333
[[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) {
34-
debugPrintfEXT(" tile[%d][%d] (ivec4): ", m4, k4);
34+
debugPrintfEXT(" tile[%d][%d]:\\n", m4, k4);
3535

3636
// Each ivec4 contains 4 packed integers, each integer contains 4 8-bit
3737
// values
3838
[[unroll]] for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
3939
int packed_int = tile.data[m4][k4][vec_idx];
40-
debugPrintfEXT("packed_int[%d]=%d -> [", vec_idx, packed_int);
40+
debugPrintfEXT(" [", vec_idx, packed_int);
4141

4242
// Extract 4 8-bit values from this packed integer
4343
[[unroll]] for (int byte_idx = 0; byte_idx < 4; ++byte_idx) {
@@ -48,6 +48,7 @@ void printInt8InputTile(const Int8InputTile tile) {
4848
debugPrintfEXT("%d] ", val);
4949
}
5050
}
51+
debugPrintfEXT("(packed=%d)\\n", packed_int);
5152
}
5253
debugPrintfEXT("\\n");
5354
}

backends/vulkan/runtime/graph/ops/glsl/linear_int8_output_tile.glslh

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,17 @@ void initialize(out Int8OutTile tile) {
3535

3636
void printInt8OutTile(const Int8OutTile tile) {
3737
debugPrintfEXT(
38-
"Int8InputTile [TILE_M4=%d][TILE_N4=%d]:\\n", TILE_M4, TILE_N4);
38+
"Int8OutTile [TILE_M4=%d][TILE_N4=%d]:\\n", TILE_M4, TILE_N4);
3939

4040
[[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) {
4141
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
42-
debugPrintfEXT(" tile[%d][%d] (ivec4): ", m4, n4);
42+
debugPrintfEXT(" tile[%d][%d]:\\n", m4, n4);
4343

4444
// Each ivec4 contains 4 packed integers, each integer contains 4 8-bit
4545
// values
4646
[[unroll]] for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
4747
int packed_int = tile.data[m4][n4][vec_idx];
48-
debugPrintfEXT("packed_int[%d]=%d -> [", vec_idx, packed_int);
48+
debugPrintfEXT(" [", vec_idx, packed_int);
4949

5050
// Extract 4 8-bit values from this packed integer
5151
[[unroll]] for (int byte_idx = 0; byte_idx < 4; ++byte_idx) {
@@ -56,6 +56,7 @@ void printInt8OutTile(const Int8OutTile tile) {
5656
debugPrintfEXT("%d] ", val);
5757
}
5858
}
59+
debugPrintfEXT("(packed=%d)\\n", packed_int);
5960
}
6061
debugPrintfEXT("\\n");
6162
}

backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4w4c.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ quantize_and_pack_4w4c:
1414
parameter_names: [OUTPUT_STORAGE, INPUT_STORAGE]
1515
combos:
1616
- parameter_values: [texture3d, texture3d]
17+
- parameter_values: [texture3d, buffer]
1718
- parameter_values: [buffer, texture3d]
1819
- parameter_values: [buffer, buffer]
1920
DTYPE:

backends/vulkan/runtime/graph/ops/glsl/unpack_4w4c_and_dequantize.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ unpack_4w4c_and_dequantize:
1414
parameter_names: [OUTPUT_STORAGE, INPUT_STORAGE]
1515
combos:
1616
- parameter_values: [texture3d, texture3d]
17+
- parameter_values: [buffer, texture3d]
1718
- parameter_values: [texture3d, buffer]
1819
- parameter_values: [buffer, buffer]
1920
DTYPE:

backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1400,7 +1400,7 @@ void static_quantized_conv2d_impl(
14001400
&graph,
14011401
input_im2col_sizes,
14021402
vkapi::kInt8x4,
1403-
utils::kBuffer,
1403+
graph.storage_type_of(packed_int8_input),
14041404
utils::kPackedInt8_4W4C);
14051405

14061406
packed_int8_input_im2col = packed_int8_input_im2col_tensor.vref;
@@ -1492,7 +1492,8 @@ void conv2d_q8ta_q8csw_q8to(
14921492

14931493
void conv2d_q8ta_q8csw_q8to_test(
14941494
ComputeGraph& graph,
1495-
const std::vector<ValueRef>& args) {
1495+
const std::vector<ValueRef>& args,
1496+
utils::StorageType io_storage_type) {
14961497
int32_t idx = 0;
14971498
const ValueRef fp_input = args.at(idx++);
14981499
const ValueRef input_scale = args.at(idx++);
@@ -1514,14 +1515,14 @@ void conv2d_q8ta_q8csw_q8to_test(
15141515
&graph,
15151516
graph.sizes_of(fp_input),
15161517
vkapi::kInt8x4,
1517-
utils::kBuffer,
1518+
io_storage_type,
15181519
utils::kPackedInt8_4W4C);
15191520

15201521
TmpTensor packed_int8_output(
15211522
&graph,
15221523
graph.sizes_of(fp_output),
15231524
vkapi::kInt8x4,
1524-
utils::kBuffer,
1525+
io_storage_type,
15251526
utils::kPackedInt8_4W4C);
15261527

15271528
add_quantize_and_pack_4w4c_node(
@@ -1550,10 +1551,27 @@ void conv2d_q8ta_q8csw_q8to_test(
15501551
graph, packed_int8_output, output_scale, output_zp, fp_output);
15511552
}
15521553

1554+
// Test-only operator entry point, registered below as
// etvk.conv2d_q8ta_q8csw_q8to.test_buffer. Delegates to
// conv2d_q8ta_q8csw_q8to_test with buffer storage selected for the
// intermediate packed int8 input/output tensors.
void conv2d_q8ta_q8csw_q8to_test_buffer(
    ComputeGraph& graph,
    const std::vector<ValueRef>& args) {
  const utils::StorageType io_storage_type = utils::kBuffer;
  conv2d_q8ta_q8csw_q8to_test(graph, args, io_storage_type);
}
1559+
1560+
// Test-only operator entry point, registered below as
// etvk.conv2d_q8ta_q8csw_q8to.test_texture. Delegates to
// conv2d_q8ta_q8csw_q8to_test with texture3d storage selected for the
// intermediate packed int8 input/output tensors.
//
// The storage type must be utils::kTexture3D here: the test harness picks the
// "_texture" operator suffix when it wants texture3d intermediates, so
// forwarding utils::kBuffer would silently re-run the buffer path and leave
// the texture IO shaders untested.
void conv2d_q8ta_q8csw_q8to_test_texture(
    ComputeGraph& graph,
    const std::vector<ValueRef>& args) {
  conv2d_q8ta_q8csw_q8to_test(graph, args, utils::kTexture3D);
}
1565+
15531566
REGISTER_OPERATORS {
15541567
VK_REGISTER_OP(et_vk.conv2d_q8ta_q8csw.default, conv2d_q8ta_q8csw);
15551568
VK_REGISTER_OP(et_vk.conv2d_q8csw.default, conv2d_q8csw);
1556-
VK_REGISTER_OP(etvk.conv2d_q8ta_q8csw_q8to.test, conv2d_q8ta_q8csw_q8to_test);
1569+
VK_REGISTER_OP(
1570+
etvk.conv2d_q8ta_q8csw_q8to.test_texture,
1571+
conv2d_q8ta_q8csw_q8to_test_texture);
1572+
VK_REGISTER_OP(
1573+
etvk.conv2d_q8ta_q8csw_q8to.test_buffer,
1574+
conv2d_q8ta_q8csw_q8to_test_buffer);
15571575
VK_REGISTER_OP(et_vk.conv2d_q8ta_q8csw_q8to.default, conv2d_q8ta_q8csw_q8to);
15581576
VK_REGISTER_OP(
15591577
et_vk.conv2d_q8ta_q8csw_q8to_dw.default, conv2d_q8ta_q8csw_q8to);

backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d.cpp

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
1515

16+
// #define DEBUG_MODE
17+
1618
using namespace executorch::vulkan::prototyping;
1719

1820
using namespace vkcompute;
@@ -23,7 +25,8 @@ static constexpr int64_t kRefDimSizeLimit = 100;
2325
TestCase create_test_case_from_config(
2426
const Conv2dConfig& config,
2527
utils::StorageType storage_type,
26-
vkapi::ScalarType input_dtype) {
28+
vkapi::ScalarType input_dtype,
29+
utils::StorageType interm_storage_type) {
2730
TestCase test_case;
2831

2932
// Create a descriptive name for the test case
@@ -35,8 +38,15 @@ TestCase create_test_case_from_config(
3538
config.test_case_name + "_" + storage_str + "_" + dtype_str;
3639
test_case.set_name(test_name);
3740

41+
std::string operator_suffix = ".test";
42+
if (interm_storage_type == utils::kTexture3D) {
43+
operator_suffix += "_texture";
44+
} else {
45+
operator_suffix += "_buffer";
46+
}
47+
3848
// Set the operator name for the test case
39-
std::string operator_name = "etvk." + config.op_name + ".test";
49+
std::string operator_name = "etvk." + config.op_name + operator_suffix;
4050
test_case.set_operator_name(operator_name);
4151

4252
// Calculate output dimensions
@@ -56,7 +66,12 @@ TestCase create_test_case_from_config(
5666
input_dtype,
5767
storage_type,
5868
io_memory_layout,
59-
DataGenType::RANDOM);
69+
#ifdef DEBUG_MODE
70+
DataGenType::RANDOM
71+
#else
72+
DataGenType::RANDOM
73+
#endif
74+
);
6075

6176
if (debugging()) {
6277
print_valuespec_data(input_tensor, "input_tensor");
@@ -193,8 +208,10 @@ std::vector<TestCase> generate_quantized_conv2d_easy_cases() {
193208
// Generate test cases for each combination
194209
for (const auto& storage_type : storage_types) {
195210
for (const auto& input_dtype : float_types) {
196-
test_cases.push_back(
197-
create_test_case_from_config(config, storage_type, input_dtype));
211+
test_cases.push_back(create_test_case_from_config(
212+
config, storage_type, input_dtype, utils::kBuffer));
213+
test_cases.push_back(create_test_case_from_config(
214+
config, storage_type, input_dtype, utils::kTexture3D));
198215
}
199216
}
200217

@@ -373,8 +390,10 @@ std::vector<TestCase> generate_quantized_conv2d_test_cases() {
373390
if (vkcompute::api::context()
374391
->adapter_ptr()
375392
->supports_int8_dot_product()) {
376-
test_cases.push_back(
377-
create_test_case_from_config(config, storage_type, vkapi::kFloat));
393+
test_cases.push_back(create_test_case_from_config(
394+
config, storage_type, vkapi::kFloat, utils::kBuffer));
395+
test_cases.push_back(create_test_case_from_config(
396+
config, storage_type, vkapi::kFloat, utils::kTexture3D));
378397
}
379398
}
380399
}
@@ -610,7 +629,11 @@ int64_t quantized_conv2d_flop_calculator(const TestCase& test_case) {
610629
int main(int argc, char* argv[]) {
611630
set_debugging(false);
612631
set_print_output(false);
632+
#ifdef DEBUG_MODE
633+
set_print_latencies(true);
634+
#else
613635
set_print_latencies(false);
636+
#endif
614637
set_use_gpu_timestamps(true);
615638

616639
print_performance_header();
@@ -623,11 +646,20 @@ int main(int argc, char* argv[]) {
623646

624647
// Execute test cases using the new framework with custom FLOP calculator
625648
auto results = execute_test_cases(
649+
#ifdef DEBUG_MODE
650+
generate_quantized_conv2d_easy_cases,
651+
#else
626652
generate_quantized_conv2d_test_cases,
653+
#endif
627654
quantized_conv2d_flop_calculator,
628655
"QuantizedConv2dQ8ToQ8To",
656+
#ifdef DEBUG_MODE
657+
0,
658+
1,
659+
#else
629660
3,
630661
10,
662+
#endif
631663
ref_fn);
632664

633665
return 0;

0 commit comments

Comments
 (0)