diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index 7bd3489cbfd..e99bd9ab1dc 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -146,17 +146,17 @@ executorch_generated_lib( deps = [ "//executorch/backends/cadence/generic/kernels:cadence_kernels", "//executorch/backends/cadence/generic/operators:op_requantize", - "//executorch/backends/cadence/generic/operators:im2row_out", - "//executorch/backends/cadence/generic/operators:dequantize_per_tensor", - "//executorch/backends/cadence/generic/operators:quantize_per_tensor", - "//executorch/backends/cadence/generic/operators:quantized_add_out", - "//executorch/backends/cadence/generic/operators:quantized_conv2d_nchw_out", - "//executorch/backends/cadence/generic/operators:quantized_conv2d_nhwc_out", - "//executorch/backends/cadence/generic/operators:quantized_fully_connected_out", - "//executorch/backends/cadence/generic/operators:quantized_layer_norm", - "//executorch/backends/cadence/generic/operators:quantized_linear_out", - "//executorch/backends/cadence/generic/operators:quantized_matmul_out", - "//executorch/backends/cadence/generic/operators:quantized_relu_out", + "//executorch/backends/cadence/generic/operators:op_im2row", + "//executorch/backends/cadence/generic/operators:op_dequantize_per_tensor", + "//executorch/backends/cadence/generic/operators:op_quantize_per_tensor", + "//executorch/backends/cadence/generic/operators:op_quantized_add", + "//executorch/backends/cadence/generic/operators:op_quantized_conv2d", + "//executorch/backends/cadence/generic/operators:op_quantized_conv1d", + "//executorch/backends/cadence/generic/operators:op_quantized_fully_connected", + "//executorch/backends/cadence/generic/operators:op_quantized_layer_norm", + "//executorch/backends/cadence/generic/operators:op_quantized_linear", + "//executorch/backends/cadence/generic/operators:op_quantized_matmul", + "//executorch/backends/cadence/generic/operators:op_quantized_relu", "//executorch/kernels/portable:executorch_all_ops", "//executorch/kernels/portable:operators", ], diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml index d8024c0245a..3ba6f4700b1 100644 --- a/backends/cadence/aot/functions.yaml +++ b/backends/cadence/aot/functions.yaml @@ -349,12 +349,12 @@ - arg_meta: null kernel_name: impl::generic::im2row_per_tensor_out -- func: cadence::quantized_conv2d_nchw.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!) +- func: cadence::quantized_conv2d_nchw.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null kernel_name: impl::generic::quantized_conv2d_nchw_per_tensor_out -- func: cadence::quantized_conv2d_nhwc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!) 
+- func: cadence::quantized_conv2d_nhwc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null kernel_name: impl::generic::quantized_conv2d_nhwc_per_tensor_out diff --git a/backends/cadence/generic/kernels/kernels.cpp b/backends/cadence/generic/kernels/kernels.cpp index 25e25cfa60a..28961d0faf1 100644 --- a/backends/cadence/generic/kernels/kernels.cpp +++ b/backends/cadence/generic/kernels/kernels.cpp @@ -7,6 +7,7 @@ */ #include + #include #include #include diff --git a/backends/cadence/generic/kernels/kernels.h b/backends/cadence/generic/kernels/kernels.h index 60ee42f8855..4b37eeb45d0 100644 --- a/backends/cadence/generic/kernels/kernels.h +++ b/backends/cadence/generic/kernels/kernels.h @@ -6,8 +6,9 @@ * LICENSE file in the root directory of this source tree. */ -#include "inttypes.h" -#include "stddef.h" +#include + +#include namespace impl { namespace generic { diff --git a/backends/cadence/generic/operators/TARGETS b/backends/cadence/generic/operators/TARGETS deleted file mode 100644 index 67f2bab681a..00000000000 --- a/backends/cadence/generic/operators/TARGETS +++ /dev/null @@ -1,5 +0,0 @@ -load("targets.bzl", "define_common_targets") - -oncall("odai_jarvis") - -define_common_targets() diff --git a/backends/cadence/generic/operators/cadence_type_util.h b/backends/cadence/generic/operators/cadence_type_util.h new file mode 100644 index 00000000000..43852277031 --- /dev/null +++ b/backends/cadence/generic/operators/cadence_type_util.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +/** + * @file cadence_type_util.h + * @brief Common type macros for Cadence quantized operators + * + * This header provides utility macros for iterating over supported quantized + * data types in Cadence operators. These macros are used with switch statements + * to dispatch to type-specific implementations. + */ + +/** + * Macro to iterate over standard Cadence quantized types (uint8_t, int8_t) + * + * Usage: + * ET_FORALL_CADENCE_QUANTIZED_TYPES(MACRO) + * + * Where MACRO is defined as: #define MACRO(ctype, name) ... + * - ctype: C++ type (uint8_t or int8_t) + * - name: ExecutorTorch ScalarType name suffix (Byte or Char) + * + * Example: + * #define HANDLE_TYPE(ctype, name) \ + * case ScalarType::name: \ + * return process(tensor); \ + * break; + * + * ScalarType dtype = tensor.scalar_type(); + * switch (dtype) { + * ET_FORALL_CADENCE_QUANTIZED_TYPES(HANDLE_TYPE) + * default: + * ET_CHECK_MSG(false, "Unsupported dtype"); + * } + */ +#define ET_FORALL_CADENCE_QUANTIZED_TYPES(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) + +/** + * Macro to iterate over extended Cadence quantized types including int16_t + * + * Usage: + * ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(MACRO) + * + * Where MACRO is defined as: #define MACRO(ctype, name) ... + * - ctype: C++ type (uint8_t, int8_t, or int16_t) + * - name: ExecutorTorch ScalarType name suffix (Byte, Char, or Short) + * + * This macro includes int16_t support for operators that can handle 16-bit + * quantized values (e.g., quantized_linear, quantized_fully_connected). 
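 *
 * Example (a sketch mirroring the 8-bit example above; HANDLE_TYPE is a
 * placeholder macro used only for illustration):
 *   #define HANDLE_TYPE(ctype, name) \
 *     case ScalarType::name:         \
 *       return process<ctype>(tensor);
 *
 *   switch (tensor.scalar_type()) {
 *     ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(HANDLE_TYPE)
 *     default:
 *       ET_CHECK_MSG(false, "Unsupported dtype");
 *   }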
+ */ +#define ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) diff --git a/backends/cadence/generic/operators/op_quantized_add.cpp b/backends/cadence/generic/operators/op_quantized_add.cpp new file mode 100644 index 00000000000..393a553a253 --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_add.cpp @@ -0,0 +1,216 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include + +namespace impl::generic::native { + +using ::executorch::aten::Scalar; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; +using ::impl::generic::kernels::dequantize; +using ::impl::generic::kernels::quantize; + +DECLARE_POINTWISE_TENSOR_QUANTIZED_BINARY_OP(quantized_add_, +); + +#define DECLARE_POINTWISE_SCALAR_QUANTIZED_BINARY_OP(BINARY_FUNC_NAME, OP) \ + template \ + void BINARY_FUNC_NAME( \ + const Tensor& X, \ + float X_scale, \ + int32_t X_zero_point, \ + const float Y, \ + float out_scale, \ + int32_t out_zero_point, \ + Tensor& out) { \ + const T* __restrict__ X_data = X.const_data_ptr(); \ + T* __restrict__ out_data = out.mutable_data_ptr(); \ + float inv_out_scale = 1.0f / out_scale; \ + for (size_t i = 0, e = X.numel(); i < e; ++i) { \ + float x = dequantize(X_data[i], X_scale, X_zero_point); \ + float z = x OP Y; \ + out_data[i] = quantize(z, inv_out_scale, out_zero_point); \ + } \ + } + +DECLARE_POINTWISE_SCALAR_QUANTIZED_BINARY_OP(quantized_add_Scalar_, +); + +Tensor& quantized_add_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& X, + const Tensor& X_scale_t, + const Tensor& X_zero_point_t, + const Tensor& Y, + const Tensor& Y_scale_t, + const Tensor& Y_zero_point_t, + double out_scale, + int64_t out_zero_point, + Tensor& out) { + float X_scale = X_scale_t.const_data_ptr()[0]; + int32_t X_zero_point = X_zero_point_t.const_data_ptr()[0]; + float Y_scale = Y_scale_t.const_data_ptr()[0]; + int32_t Y_zero_point = Y_zero_point_t.const_data_ptr()[0]; + +#define typed_quantized_add(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_add_( \ + X, \ + X_scale, \ + X_zero_point, \ + Y, \ + Y_scale, \ + Y_zero_point, \ + static_cast(out_scale), \ + static_cast(out_zero_point), \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_add); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } +#undef typed_quantized_add + + return out; +} + +Tensor& quantized_add_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& X, + double X_scale, + int64_t X_zero_point, + const Tensor& Y, + double Y_scale, + int64_t Y_zero_point, + double out_scale, + int64_t out_zero_point, + Tensor& out) { +#define typed_quantized_add(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_add_( \ + X, \ + static_cast(X_scale), \ + static_cast(X_zero_point), \ + Y, \ + static_cast(Y_scale), \ + static_cast(Y_zero_point), \ + static_cast(out_scale), \ + static_cast(out_zero_point), \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_add); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", 
torch::executor::toString(dtype)); + } +#undef typed_quantized_add + return out; +} + +Tensor& quantized_add_asym8sxasym8s_asym8s_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& X, + double X_scale, + int64_t X_zero_point, + const Tensor& Y, + double Y_scale, + int64_t Y_zero_point, + double out_scale, + int64_t out_zero_point, + Tensor& out) { + quantized_add_( + X, + static_cast(X_scale), + static_cast(X_zero_point), + Y, + static_cast(Y_scale), + static_cast(Y_zero_point), + static_cast(out_scale), + static_cast(out_zero_point), + out); + return out; +} + +Tensor& quantized_add_asym8uxasym8u_asym8u_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& X, + double X_scale, + int64_t X_zero_point, + const Tensor& Y, + double Y_scale, + int64_t Y_zero_point, + double out_scale, + int64_t out_zero_point, + Tensor& out) { + quantized_add_( + X, + static_cast(X_scale), + static_cast(X_zero_point), + Y, + static_cast(Y_scale), + static_cast(Y_zero_point), + static_cast(out_scale), + static_cast(out_zero_point), + out); + return out; +} + +Tensor& quantized_add_Scalar_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& X, + const Tensor& X_scale_t, + const Tensor& X_zero_point_t, + const Scalar& Y_scalar, + double out_scale, + int64_t out_zero_point, + Tensor& out) { + float X_scale = X_scale_t.const_data_ptr()[0]; + int32_t X_zero_point = X_zero_point_t.const_data_ptr()[0]; + float Y = static_cast( + ::torch::executor::native::utils::scalar_to(Y_scalar)); +#define typed_quantized_add_Scalar(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_add_Scalar_( \ + X, \ + X_scale, \ + X_zero_point, \ + Y, \ + static_cast(out_scale), \ + static_cast(out_zero_point), \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_add_Scalar) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } +#undef typed_quantized_add_Scalar + return out; +} + +#undef DECLARE_POINTWISE_SCALAR_QUANTIZED_BINARY_OP + +} // namespace impl::generic::native diff --git a/backends/cadence/generic/operators/op_quantized_add.h b/backends/cadence/generic/operators/op_quantized_add.h new file mode 100644 index 00000000000..3f87ddcf5b9 --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_add.h @@ -0,0 +1,78 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& quantized_add_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& X, + const ::executorch::aten::Tensor& X_scale, + const ::executorch::aten::Tensor& X_zero_point, + const ::executorch::aten::Tensor& Y, + const ::executorch::aten::Tensor& Y_scale, + const ::executorch::aten::Tensor& Y_zero_point, + double out_scale, + int64_t out_zero_point, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& quantized_add_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& X, + double X_scale, + int64_t X_zero_point, + const ::executorch::aten::Tensor& Y, + double Y_scale, + int64_t Y_zero_point, + double out_scale, + int64_t out_zero_point, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& quantized_add_Scalar_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& X, + const ::executorch::aten::Tensor& X_scale, + const ::executorch::aten::Tensor& X_zero_point, + const ::executorch::aten::Scalar& Y, + double out_scale, + int64_t out_zero_point, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& quantized_add_asym8sxasym8s_asym8s_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& X, + double X_scale, + int64_t X_zero_point, + const ::executorch::aten::Tensor& Y, + double Y_scale, + int64_t Y_zero_point, + double out_scale, + int64_t out_zero_point, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& quantized_add_asym8uxasym8u_asym8u_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& X, + double X_scale, + int64_t X_zero_point, + const ::executorch::aten::Tensor& Y, + double Y_scale, + int64_t Y_zero_point, + double out_scale, + int64_t out_zero_point, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_conv1d.cpp b/backends/cadence/generic/operators/op_quantized_conv1d.cpp new file mode 100644 index 00000000000..6ae3a6613fb --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_conv1d.cpp @@ -0,0 +1,514 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include + +namespace impl { +namespace generic { +namespace native { + +namespace { +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; +using ::impl::generic::kernels::quantize; + +// This implements a generic 1d conv kernel that operates on raw pointers. +// The quantized version handles both quantized convolutions for 1D inputs. 
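// For reference (assuming the caller sizes `out` with the usual convolution
// arithmetic): with stride s, padding p and dilation d,
//   ow = (w + 2*p - d*(ww - 1) - 1) / s + 1,
// e.g. w = 10, ww = 3, s = 1, p = 0, d = 1 gives ow = 8.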
+// The input is of shape [n x c x w] +// The weight is of shape [oc x wc x ww], where wc == c +// The output is of shape [n x oc x ow] +// The bias is of shape [oc] + +template < + typename IT = float, + typename WT = IT, + typename BT = IT, + typename OT = IT, + bool quantized = false> +__attribute__((noinline)) void conv1d_ncl_core_generic( + // All the arrays + const IT* __restrict__ p_in, + const WT* __restrict__ p_weight, + const BT* __restrict__ p_bias, + OT* __restrict__ p_out, + // The array sizes + int32_t n, + int32_t c, + int32_t w, + int32_t oc, + int32_t wc, + int32_t ww, + int32_t ow, + // Stride + int16_t s, + // Padding + int16_t p, + // Dilation + int16_t d, + // Group for depthwise conv + int16_t groups, + // Optional args that are only relevant for quantized convolution + // input zero point + IT in_zero_point = 0, + // weight zero point + int32_t weight_zero_point = 0, + float bias_scale = 1, + float out_scale = 1, + OT out_zero_point = 0) { + float inv_out_scale = 1. / out_scale; + bool zero_pad_unit_dilation = d == 1 && p == 0; + + // Compute the number of in and out channels per group + const int ocpg = oc / groups; + const int icpg = c / groups; + + // Iterate over all the output batches (i.e., n) + for (int _n = 0; _n < n; ++_n) { + const IT* in_batch = p_in + _n * c * w; + OT* out_batch = p_out + _n * oc * ow; + // Compute separable convolution for each group + for (int _g = 0; _g < groups; ++_g) { + // Identify the input and output channels involved in the computation + // of this group + int sic = _g * icpg; + int soc = _g * ocpg; + // Populate all the output channels in the group + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + OT* out_plane = out_batch + _oc * ow; + const WT* weight_batch = p_weight + _oc * wc * ww; + // We compute one output channel at a time. The computation can be + // thought of as a stencil computation: we iterate over an input of size + // icpg x w, with a stencil of size icpg x ww, to compute an + // output channel of size 1 x ow. + for (int _w = 0, _ow = 0; _ow < ow; _w += s, ++_ow) { + float acc = p_bias[_oc]; + // Below is the stencil computation that performs the hadamard + // product+accumulation of each input channel (contributing to the + // output channel being computed) with the corresponding weight + // channel. + // If the padding is 0, and dilation is 1, then we can remove the + // unnecessary checks, and simplify the code so that it can be + // vectorized by Tensilica compiler. + if (zero_pad_unit_dilation) { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const IT* in_plane = in_batch + _ic * w; + const WT* weight_plane = weight_batch + (_ic - sic) * ww; + for (int _ww = 0; _ww < ww; ++_ww) { + int ioff = _w + _ww; + int woff = _ww; + float lhs = in_plane[ioff] - in_zero_point; + float rhs = + weight_plane[woff] - (quantized ? weight_zero_point : 0); + acc += lhs * rhs; + } + } + } else { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const IT* in_plane = in_batch + _ic * w; + const WT* weight_plane = weight_batch + (_ic - sic) * ww; + for (int _ww = 0; _ww < ww; ++_ww) { + if (((_w + d * _ww - p) >= 0) && ((_w + d * _ww - p) < w)) { + int ioff = _w + d * _ww - p; + int woff = _ww; + float lhs = in_plane[ioff] - in_zero_point; + float rhs = + weight_plane[woff] - (quantized ? 
weight_zero_point : 0); + acc += lhs * rhs; + } + } + } + } + if (quantized) { + float val = bias_scale * acc; + out_plane[_ow] = quantize(val, inv_out_scale, out_zero_point); + } else { + out_plane[_ow] = acc; + } + } + } + } + } +} + +template < + typename IT = float, + typename WT = IT, + typename BT = IT, + typename OT = IT, + bool quantized = false> +__attribute__((noinline)) void conv1d_nlc_core_generic( + // All the arrays + const IT* __restrict__ p_in, + const WT* __restrict__ p_weight, + const BT* __restrict__ p_bias, + OT* __restrict__ p_out, + // The array sizes + int32_t n, + int32_t w, + int32_t c, + int32_t oc, + int32_t ww, + int32_t wc, + int32_t ow, + // Stride + int16_t s, + // Padding + int16_t p, + // Dilation + int16_t d, + // Group for depthwise conv + int16_t groups, + // Optional args that are only relevant for quantized convolution + // input zero point + IT in_zero_point = 0, + // weight zero point + int32_t weight_zero_point = 0, + float bias_scale = 1, + float out_scale = 1, + OT out_zero_point = 0) { + float inv_out_scale = 1. / out_scale; + bool zero_pad_unit_dilation = d == 1 && p == 0; + + // Compute the number of in and out channels per group + const int ocpg = oc / groups; + const int icpg = c / groups; + + // Iterate over all the output batches (i.e., n) + for (int _n = 0; _n < n; ++_n) { + const IT* in_batch = p_in + _n * w * c; + OT* out_batch = p_out + _n * ow * oc; + for (int _w = 0, _ow = 0; _ow < ow; _w += s, ++_ow) { + OT* out_line = out_batch + _ow * oc; + // Compute separable convolution for each group + for (int _g = 0; _g < groups; ++_g) { + // Identify the input and output channels involved in the computation + // of this group + int sic = _g * icpg; + int soc = _g * ocpg; + // Populate all the output channels in the group + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + const WT* weight_batch = p_weight + _oc * ww * wc; + // We compute one output channel at a time. The computation can be + // thought of as a stencil computation: we iterate over an input of + // size w x icpg, with a stencil of size ww x icpg, to + // compute an output channel of size ow x 1. + float acc = p_bias[_oc]; + // Below is the stencil computation that performs the hadamard + // product+accumulation of each input channel (contributing to + // the output channel being computed) with the corresponding + // weight channel. If the padding is 0, and dilation is 1, then + // we can remove the unnecessary checks, and simplify the code + // so that it can be vectorized by Tensilica compiler. + if (zero_pad_unit_dilation) { + for (int _ww = 0; _ww < ww; ++_ww) { + const IT* in_line = in_batch + (_w + _ww) * c; + const WT* weight_line = weight_batch + _ww * wc; + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + float lhs = in_line[_ic] - in_zero_point; + float rhs = weight_line[_ic - sic] - + (quantized ? weight_zero_point : 0); + acc += lhs * rhs; + } + } + } else { + for (int _ww = 0; _ww < ww; ++_ww) { + if (((_w + d * _ww - p) >= 0) && ((_w + d * _ww - p) < w)) { + const IT* in_line = in_batch + (_w + d * _ww - p) * c; + const WT* weight_line = weight_batch + _ww * wc; + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + float lhs = in_line[_ic] - in_zero_point; + float rhs = weight_line[_ic - sic] - + (quantized ? 
weight_zero_point : 0); + acc += lhs * rhs; + } + } + } + } + if (quantized) { + float val = bias_scale * acc; + out_line[_oc] = quantize(val, inv_out_scale, out_zero_point); + } else { + out_line[_oc] = acc; + } + } + } + } + } +} + +void quantized_conv1d_ncl( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + // input = [n, c, w] + const int n = input.size(0); + const int c = input.size(1); + const int w = input.size(2); + // weight = [oc, wc, ww] + const int oc = weight.size(0); + const int wc = weight.size(1); + const int ww = weight.size(2); + // output = [n, oc, ow] + const int ow = out.size(2); + +#define typed_quantized_conv1d_ncl(ctype, dtype) \ + case ScalarType::dtype: { \ + conv1d_ncl_core_generic( \ + input.const_data_ptr(), \ + weight.const_data_ptr(), \ + bias.const_data_ptr(), \ + out.mutable_data_ptr(), \ + n, \ + c, \ + w, \ + oc, \ + wc, \ + ww, \ + ow, \ + stride[0], \ + padding[0], \ + dilation[0], \ + groups, \ + in_zero_point, \ + weight_zero_point, \ + bias_scale, \ + output_scale, \ + (ctype)output_zero_point); \ + break; \ + } + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_conv1d_ncl); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_conv1d_ncl +} + +void quantized_conv1d_nlc( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + // input = [n, w, c] + const int n = input.size(0); + const int w = input.size(1); + const int c = input.size(2); + // weight = [oc, ww, wc] + const int oc = weight.size(0); + const int ww = weight.size(1); + const int wc = weight.size(2); + // output = [n, ow, oc] + const int ow = out.size(1); + +#define typed_quantized_conv1d_nlc(ctype, dtype) \ + case ScalarType::dtype: { \ + conv1d_nlc_core_generic( \ + input.const_data_ptr(), \ + weight.const_data_ptr(), \ + bias.const_data_ptr(), \ + out.mutable_data_ptr(), \ + n, \ + w, \ + c, \ + oc, \ + ww, \ + wc, \ + ow, \ + stride[0], \ + padding[0], \ + dilation[0], \ + groups, \ + in_zero_point, \ + weight_zero_point, \ + bias_scale, \ + output_scale, \ + (ctype)output_zero_point); \ + break; \ + } + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_conv1d_nlc); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_conv1d_nlc +} + +} // namespace + +Tensor& quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv1d_ncl( + input, + weight, + bias, + stride, + padding, + dilation, + static_cast(groups), + 
static_cast(in_zero_point), + static_cast(weight_zero_point), + static_cast(bias_scale), + static_cast(output_scale), + static_cast(output_zero_point), + out); + return out; +} + +Tensor& quantized_conv1d_ncl_asym8uxsym8u_asym8u_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv1d_ncl( + input, + weight, + bias, + stride, + padding, + dilation, + static_cast(groups), + static_cast(in_zero_point), + static_cast(weight_zero_point), + static_cast(bias_scale), + static_cast(output_scale), + static_cast(output_zero_point), + out); + return out; +} + +Tensor& quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv1d_nlc( + input, + weight, + bias, + stride, + padding, + dilation, + static_cast(groups), + static_cast(in_zero_point), + static_cast(weight_zero_point), + static_cast(bias_scale), + static_cast(output_scale), + static_cast(output_zero_point), + out); + return out; +} + +Tensor& quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv1d_nlc( + input, + weight, + bias, + stride, + padding, + dilation, + static_cast(groups), + static_cast(in_zero_point), + static_cast(weight_zero_point), + static_cast(bias_scale), + static_cast(output_scale), + static_cast(output_zero_point), + out); + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_conv1d.h b/backends/cadence/generic/operators/op_quantized_conv1d.h new file mode 100644 index 00000000000..5cb79ab09fa --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_conv1d.h @@ -0,0 +1,96 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +executorch::aten::Tensor& +quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor_out( + executorch::runtime::KernelRuntimeContext& ctx, + const executorch::aten::Tensor& input, + const executorch::aten::Tensor& weight, + const executorch::aten::Tensor& bias, + executorch::aten::IntArrayRef stride, + executorch::aten::IntArrayRef padding, + executorch::aten::IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + executorch::aten::Tensor& out); + +executorch::aten::Tensor& +quantized_conv1d_ncl_asym8uxsym8u_asym8u_per_tensor_out( + executorch::runtime::KernelRuntimeContext& ctx, + const executorch::aten::Tensor& input, + const executorch::aten::Tensor& weight, + const executorch::aten::Tensor& bias, + executorch::aten::IntArrayRef stride, + executorch::aten::IntArrayRef padding, + executorch::aten::IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + executorch::aten::Tensor& out); + +executorch::aten::Tensor& +quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor_out( + executorch::runtime::KernelRuntimeContext& ctx, + const executorch::aten::Tensor& input, + const executorch::aten::Tensor& weight, + const executorch::aten::Tensor& bias, + executorch::aten::IntArrayRef stride, + executorch::aten::IntArrayRef padding, + executorch::aten::IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + executorch::aten::Tensor& out); + +executorch::aten::Tensor& +quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor_out( + executorch::runtime::KernelRuntimeContext& ctx, + const executorch::aten::Tensor& input, + const executorch::aten::Tensor& weight, + const executorch::aten::Tensor& bias, + executorch::aten::IntArrayRef stride, + executorch::aten::IntArrayRef padding, + executorch::aten::IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_conv2d.cpp b/backends/cadence/generic/operators/op_quantized_conv2d.cpp new file mode 100644 index 00000000000..ca701957866 --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_conv2d.cpp @@ -0,0 +1,1051 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; +using ::impl::generic::kernels::quantize; + +/* This implements a generic 2d conv kernel that operates on raw pointers. + * The quantized version handles quantized convolutions for 2D inputs. 
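 * For reference (assuming the caller sizes the output with the usual
 * convolution arithmetic): with strides (s0, s1), padding (p0, p1) and
 * dilation (d0, d1),
 *   oh = (h + 2*p0 - d0*(wh - 1) - 1) / s0 + 1
 *   ow = (w + 2*p1 - d1*(ww - 1) - 1) / s1 + 1
 * e.g. h = w = 8, wh = ww = 3, stride 1, padding 1, dilation 1 gives
 * oh = ow = 8.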
+ * The input is of shape [n x c x h x w] + * The weight is of shape [oc x wc x wh x ww], where wc == c + * The output is of shape [n x oc x oh x ow] + * The bias is of shape [oc] + */ +template < + typename IT = float, + typename WT = IT, + typename BT = IT, + typename OT = IT, + bool quantized = false> +__attribute__((noinline)) void conv2d_nchw_core_generic( + // All the arrays + const IT* __restrict__ p_in, + const WT* __restrict__ p_weight, + const BT* __restrict__ p_bias, + OT* __restrict__ p_out, + // The array sizes + int32_t n, + int32_t c, + int32_t h, + int32_t w, + int32_t oc, + int32_t wc, + int32_t wh, + int32_t ww, + int32_t oh, + int32_t ow, + // Stride + int16_t s0, + int16_t s1, + // Padding + int16_t p0, + int16_t p1, + // Dilation + int16_t d0, + int16_t d1, + // Group for depthwise conv + int16_t groups, + // Optional args that are only relevant for quantized convolution + // input zero point + IT in_zero_point = 0, + // weight zero point + int32_t weight_zero_point = 0, + float bias_scale = 1, + float out_scale = 1, + OT out_zero_point = 0) { + const float inv_out_scale = 1.f / out_scale; + bool zero_pad_unit_dilation = d0 == 1 && d1 == 1 && p0 == 0 && p1 == 0; + + // Compute the number of in and out channels per group + const int ocpg = oc / groups; + const int icpg = c / groups; + + // Iterate over all the output batches (i.e., n) + for (int _n = 0; _n < n; ++_n) { + const IT* in_batch = p_in + _n * c * h * w; + OT* out_batch = p_out + _n * oc * oh * ow; + // Compute separable convolution for each group + for (int _g = 0; _g < groups; ++_g) { + // Identify the input and output channels involved in the computation + // of this group + int sic = _g * icpg; + int soc = _g * ocpg; + // Populate all the output channels in the group + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + OT* out_plane = out_batch + _oc * oh * ow; + const WT* weight_batch = p_weight + _oc * wc * wh * ww; + // We compute one output channel at a time. The computation can be + // thought of as a stencil computation: we iterate over an input of size + // icpg x h x w, with a stencil of size icpg x wh x ww, to compute an + // output channel of size 1 x oh x ow. + for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { + for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { + float acc = p_bias[_oc]; + // Below is the stencil computation that performs the hadamard + // product+accumulation of each input channel (contributing to the + // output channel being computed) with the corresponding weight + // channel. + // If the padding is 0, and dilation is 1, then we can remove the + // unnecessary checks, and simplify the code so that it can be + // vectorized by Tensilica compiler. + if (zero_pad_unit_dilation) { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const IT* in_plane = in_batch + _ic * h * w; + const WT* weight_plane = weight_batch + (_ic - sic) * wh * ww; + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + int ioff = (_h + _wh) * w + (_w + _ww); + int woff = _wh * ww + _ww; + float lhs = in_plane[ioff] - in_zero_point; + float rhs = weight_plane[woff] - + (quantized ? 
weight_zero_point : 0); + acc += lhs * rhs; + } + } + } + } else { + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + const IT* in_plane = in_batch + _ic * h * w; + const WT* weight_plane = weight_batch + (_ic - sic) * wh * ww; + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + if (((_h + d0 * _wh - p0) >= 0) && + ((_h + d0 * _wh - p0) < h) && + ((_w + d1 * _ww - p1) >= 0) && + ((_w + d1 * _ww - p1) < w)) { + int ioff = + (_h + d0 * _wh - p0) * w + (_w + d1 * _ww - p1); + int woff = _wh * ww + _ww; + float lhs = in_plane[ioff] - in_zero_point; + float rhs = weight_plane[woff] - + (quantized ? weight_zero_point : 0); + acc += lhs * rhs; + } + } + } + } + } + if (quantized) { + float val = bias_scale * acc; + out_plane[_oh * ow + _ow] = + quantize(val, inv_out_scale, out_zero_point); + } else { + out_plane[_oh * ow + _ow] = acc; + } + } + } + } + } + } +} + +template < + typename IT = float, + typename WT = IT, + typename BT = IT, + typename OT = IT, + bool quantized = false> +__attribute__((noinline)) void conv2d_nhwc_core_generic( + // All the arrays + const IT* __restrict__ p_in, + const WT* __restrict__ p_weight, + const BT* __restrict__ p_bias, + OT* __restrict__ p_out, + // The array sizes + int32_t n, + int32_t h, + int32_t w, + int32_t c, + int32_t oc, + int32_t wh, + int32_t ww, + int32_t wc, + int32_t oh, + int32_t ow, + // Stride + int16_t s0, + int16_t s1, + // Padding + int16_t p0, + int16_t p1, + // Dilation + int16_t d0, + int16_t d1, + // Group for depthwise conv + int16_t groups, + // Optional args that are only relevant for quantized convolution + // input zero point + IT in_zero_point = 0, + // weight zero point + int32_t weight_zero_point = 0, + float bias_scale = 1, + float out_scale = 1, + OT out_zero_point = 0) { + float inv_out_scale = 1.f / out_scale; + bool zero_pad_unit_dilation = d0 == 1 && d1 == 1 && p0 == 0 && p1 == 0; + + // Compute the number of in and out channels per group + const int ocpg = oc / groups; + const int icpg = c / groups; + + // Iterate over all the output batches (i.e., n) + for (int _n = 0; _n < n; ++_n) { + const IT* in_batch = p_in + _n * h * w * c; + OT* out_batch = p_out + _n * oh * ow * oc; + for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { + for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { + OT* out_line = out_batch + (_oh * ow + _ow) * oc; + // Compute separable convolution for each group + for (int _g = 0; _g < groups; ++_g) { + // Identify the input and output channels involved in the computation + // of this group + int sic = _g * icpg; + int soc = _g * ocpg; + // Populate all the output channels in the group + for (int _oc = soc; _oc < soc + ocpg; ++_oc) { + const WT* weight_batch = p_weight + _oc * wh * ww * wc; + // We compute one output channel at a time. The computation can be + // thought of as a stencil computation: we iterate over an input of + // size h x w x icpg, with a stencil of size wh x ww x icpg, to + // compute an output channel of size oh x ow x 1. + float acc = p_bias[_oc]; + // Below is the stencil computation that performs the hadamard + // product+accumulation of each input channel (contributing to + // the output channel being computed) with the corresponding + // weight channel. If the padding is 0, and dilation is 1, then + // we can remove the unnecessary checks, and simplify the code + // so that it can be vectorized by Tensilica compiler. 
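          // Requantization sketch (assuming kernels::quantize rounds and
          // saturates to the output type OT): the accumulator holds
          //   acc = bias + sum((x - in_zero_point) * (w - weight_zero_point)),
          // and the stored quantized output is roughly
          //   round(bias_scale * acc / out_scale) + out_zero_point.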
+ if (zero_pad_unit_dilation) { + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + const IT* in_line = + in_batch + (_h + _wh) * w * c + (_w + _ww) * c; + const WT* weight_line = + weight_batch + _wh * ww * wc + _ww * wc; + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + float lhs = in_line[_ic] - in_zero_point; + float rhs = weight_line[_ic - sic] - + (quantized ? weight_zero_point : 0); + acc += lhs * rhs; + } + } + } + } else { + for (int _wh = 0; _wh < wh; ++_wh) { + for (int _ww = 0; _ww < ww; ++_ww) { + if (((_h + d0 * _wh - p0) >= 0) && + ((_h + d0 * _wh - p0) < h) && + ((_w + d1 * _ww - p1) >= 0) && + ((_w + d1 * _ww - p1) < w)) { + const IT* in_line = in_batch + + (_h + d0 * _wh - p0) * w * c + (_w + d1 * _ww - p1) * c; + const WT* weight_line = + weight_batch + _wh * ww * wc + _ww * wc; + for (int _ic = sic; _ic < sic + icpg; ++_ic) { + float lhs = in_line[_ic] - in_zero_point; + float rhs = weight_line[_ic - sic] - + (quantized ? weight_zero_point : 0); + acc += lhs * rhs; + } + } + } + } + } + if (quantized) { + float val = bias_scale * acc; + out_line[_oc] = quantize(val, inv_out_scale, out_zero_point); + } else { + out_line[_oc] = acc; + } + } + } + } + } + } +} + +void quantized_conv2d_nchw( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + bool conv1d = input.dim() == 3; + // input = [n, c, h, w] + const int n = input.size(0); + const int c = input.size(1); + const int h = conv1d ? 1 : input.size(2); + const int w = conv1d ? input.size(2) : input.size(3); + // weight = [oc, wc, wh, ww] + const int oc = weight.size(0); + const int wc = weight.size(1); + const int wh = conv1d ? 1 : weight.size(2); + const int ww = conv1d ? weight.size(2) : weight.size(3); + // output = [n, oc, oh, ow] + const int oh = conv1d ? 1 : out.size(2); + const int ow = conv1d ? 
out.size(2) : out.size(3); + + ET_CHECK_MSG( + weight_zero_point >= -128 && weight_zero_point <= 127, + "weight_zero_point %" PRId32 + " must be in range [-128, 127] for int8 cast", + weight_zero_point); + + // Handle W8A16 heterogeneous type (int16_t activations, int8_t weights) + if (out.scalar_type() == ScalarType::Short && + input.scalar_type() == ScalarType::Short && + weight.scalar_type() == ScalarType::Char) { + conv2d_nchw_core_generic( + input.const_data_ptr(), + weight.const_data_ptr(), + bias.const_data_ptr(), + out.mutable_data_ptr(), + n, + c, + h, + w, + oc, + wc, + wh, + ww, + oh, + ow, + stride[0], + stride[1], + padding[0], + padding[1], + dilation[0], + dilation[1], + groups, + static_cast(in_zero_point), + static_cast(weight_zero_point), + bias_scale, + output_scale, + static_cast(output_zero_point)); + return; + } + +#define typed_quantized_conv2d_nchw(ctype, dtype) \ + case ScalarType::dtype: { \ + conv2d_nchw_core_generic( \ + input.const_data_ptr(), \ + weight.const_data_ptr(), \ + bias.const_data_ptr(), \ + out.mutable_data_ptr(), \ + n, \ + c, \ + h, \ + w, \ + oc, \ + wc, \ + wh, \ + ww, \ + oh, \ + ow, \ + stride[0], \ + stride[1], \ + padding[0], \ + padding[1], \ + dilation[0], \ + dilation[1], \ + groups, \ + in_zero_point, \ + weight_zero_point, \ + bias_scale, \ + output_scale, \ + (ctype)output_zero_point); \ + break; \ + } + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(typed_quantized_conv2d_nchw); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_conv2d_nchw +} + +void quantized_conv2d_nhwc( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + // input = [n, h, w, c] + const int n = input.size(0); + const int h = input.size(1); + const int w = input.size(2); + const int c = input.size(3); + // weight = [oc, wh, ww, wc] + const int oc = weight.size(0); + const int wh = weight.size(1); + const int ww = weight.size(2); + const int wc = weight.size(3); + // output = [n, oh, ow, oc] + const int oh = out.size(1); + const int ow = out.size(2); + + // Handle W8A16 heterogeneous type (int16_t activations, int8_t weights) + if (out.scalar_type() == ScalarType::Short && + input.scalar_type() == ScalarType::Short && + weight.scalar_type() == ScalarType::Char) { + conv2d_nhwc_core_generic( + input.const_data_ptr(), + weight.const_data_ptr(), + bias.const_data_ptr(), + out.mutable_data_ptr(), + n, + h, + w, + c, + oc, + wh, + ww, + wc, + oh, + ow, + stride[0], + stride[1], + padding[0], + padding[1], + dilation[0], + dilation[1], + groups, + static_cast(in_zero_point), + static_cast(weight_zero_point), + bias_scale, + output_scale, + static_cast(output_zero_point)); + return; + } + +#define typed_quantized_conv2d_nhwc(ctype, dtype) \ + case ScalarType::dtype: { \ + conv2d_nhwc_core_generic( \ + input.const_data_ptr(), \ + weight.const_data_ptr(), \ + bias.const_data_ptr(), \ + out.mutable_data_ptr(), \ + n, \ + h, \ + w, \ + c, \ + oc, \ + wh, \ + ww, \ + wc, \ + oh, \ + ow, \ + stride[0], \ + stride[1], \ + padding[0], \ + padding[1], \ + dilation[0], \ + dilation[1], \ + groups, \ + in_zero_point, \ + weight_zero_point, \ + bias_scale, \ + output_scale, \ + (ctype)output_zero_point); \ + break; 
\ + } + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(typed_quantized_conv2d_nhwc); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_conv2d_nhwc +} + +Tensor& quantized_conv2d_nchw_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + const Tensor& weight_zero_point, + const Tensor& bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED const Tensor& out_multiplier, + ET_UNUSED const Tensor& out_shift, + Tensor& out) { + const float bias_scale_float = bias_scale.const_data_ptr()[0]; + const int32_t weight_zero_point_int = + weight_zero_point.const_data_ptr()[0]; + quantized_conv2d_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point_int, + bias_scale_float, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nhwc_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + const Tensor& weight_zero_point, + const Tensor& bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED const Tensor& out_multiplier, + ET_UNUSED const Tensor& out_shift, + Tensor& out) { + const float bias_scale_float = bias_scale.const_data_ptr()[0]; + const int32_t weight_zero_point_int = + weight_zero_point.const_data_ptr()[0]; + quantized_conv2d_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point_int, + bias_scale_float, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nchw_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv2d_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nchw_asym8sxsym8s_asym8s_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv2d_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nchw_asym8uxsym8u_asym8u_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t 
groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv2d_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv2d_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv2d_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv2d_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv2d_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nhwc_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t 
output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv2d_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv2d_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nhwc_asym8uxsym8u_asym8u_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv2d_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv2d_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv2d_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + 
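  // Like the typed and depthwise variants above, the dilated NHWC entry point
  // forwards to the generic quantized_conv2d_nhwc helper; the dtype-specific
  // kernel is selected inside it from the tensors' scalar types.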
quantized_conv2d_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +Tensor& quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + ET_UNUSED int64_t out_multiplier, + ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv2d_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_conv2d.h b/backends/cadence/generic/operators/op_quantized_conv2d.h new file mode 100644 index 00000000000..07678b0600c --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_conv2d.h @@ -0,0 +1,326 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +// Quantized Conv2D operators - NCHW layout +::executorch::aten::Tensor& quantized_conv2d_nchw_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + const Tensor& weight_zero_point, + const Tensor& bias_scale, + double output_scale, + int64_t output_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, + Tensor& out); + +::executorch::aten::Tensor& quantized_conv2d_nchw_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out); + +::executorch::aten::Tensor& +quantized_conv2d_nchw_asym8sxsym8s_asym8s_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out); + +::executorch::aten::Tensor& +quantized_conv2d_nchw_asym8uxsym8u_asym8u_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t 
out_shift, + Tensor& out); + +::executorch::aten::Tensor& +quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out); + +::executorch::aten::Tensor& +quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out); + +::executorch::aten::Tensor& +quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out); + +::executorch::aten::Tensor& +quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out); + +// Quantized Conv2D operators - NHWC layout +::executorch::aten::Tensor& quantized_conv2d_nhwc_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + const Tensor& weight_zero_point, + const Tensor& bias_scale, + double output_scale, + int64_t output_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, + Tensor& out); + +::executorch::aten::Tensor& quantized_conv2d_nhwc_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out); + +::executorch::aten::Tensor& +quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out); + +::executorch::aten::Tensor& +quantized_conv2d_nhwc_asym8uxsym8u_asym8u_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + 
IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out); + +::executorch::aten::Tensor& +quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out); + +::executorch::aten::Tensor& +quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out); + +::executorch::aten::Tensor& +quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out); + +::executorch::aten::Tensor& +quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_embedding_byte.cpp b/backends/cadence/generic/operators/op_quantized_embedding_byte.cpp new file mode 100644 index 00000000000..55ca67648ca --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_embedding_byte.cpp @@ -0,0 +1,92 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::optional; +using ::executorch::aten::Scalar; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +#define ET_FORALL_CADENCE_QUANTIZED_TYPES(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) + +Tensor& quantized_embedding_byte_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& weight, + const Tensor& weight_scales, + const optional& weight_zero_points, + const Tensor& indices, + ET_UNUSED bool pruned_weights, + Tensor& out) { + size_t embedding_dim = weight.size(1); + + size_t num_groups = 1; + if (weight_scales.dim() == 2) { + num_groups = weight_scales.size(1); + } + + float* out_data = out.mutable_data_ptr(); + const int64_t* indices_ptr = indices.const_data_ptr(); + + const float* scales = weight_scales.const_data_ptr(); + + ScalarType dtype = weight.scalar_type(); + +#define typed_quantized_embedding_byte(ctype, dtype) \ + case ScalarType::dtype: { \ + ctype zp = 0; \ + if (weight_zero_points.has_value()) { \ + zp = weight_zero_points \ + ->const_data_ptr()[index * num_groups + group]; \ + } \ + const size_t output_group_start_offset = \ + embedding_dim * index + group * embedding_group_size; \ + const ctype* w_group = \ + weight.const_data_ptr() + output_group_start_offset; \ + for (size_t j = 0; j < embedding_group_size; ++j) { \ + float val = ((float)w_group[j] - zp) * scale; \ + *out_data++ = val; \ + } \ + break; \ + } + + size_t embedding_group_size = embedding_dim / num_groups; + for (size_t i = 0, e = indices.numel(); i < e; i++) { + int64_t index = indices_ptr[i]; + for (size_t group = 0; group < num_groups; group++) { + float scale = scales[index * num_groups + group]; + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_embedding_byte) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + } + } + +#undef typed_quantized_embedding_byte + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_embedding_byte.h b/backends/cadence/generic/operators/op_quantized_embedding_byte.h new file mode 100644 index 00000000000..a46bebe09df --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_embedding_byte.h @@ -0,0 +1,30 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& quantized_embedding_byte_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& weight_scales, + const ::executorch::aten::optional<::executorch::aten::Tensor>& + weight_zero_points, + const ::executorch::aten::Tensor& indices, + bool pruned_weights, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_fully_connected.cpp b/backends/cadence/generic/operators/op_quantized_fully_connected.cpp new file mode 100644 index 00000000000..55e29cb7f52 --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_fully_connected.cpp @@ -0,0 +1,178 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::optional; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +Tensor& quantized_fully_connected_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& in, + const Tensor& weight, + const Tensor& bias, + int64_t in_zero_point, + const Tensor& weight_zero_point_t, + const Tensor& out_multiplier, + const Tensor& out_shift, + int64_t out_zero_point, + ET_UNUSED const optional& offset, + Tensor& out) { +#define typed_quantized_linear(ctype, dtype) \ + case ScalarType::dtype: { \ + ::impl::generic::quantized::quantized_linear_( \ + in, \ + weight, \ + bias, \ + in_zero_point, \ + weight_zero_point_t, \ + out_multiplier, \ + out_shift, \ + out_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(typed_quantized_linear); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } +#undef typed_quantized_linear + return out; +} + +Tensor& quantized_fully_connected_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& in, + const Tensor& weight, + const Tensor& bias, + int64_t in_zero_point, + int64_t weight_zero_point, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + ET_UNUSED const optional& offset, + Tensor& out) { +#define typed_quantized_linear(ctype, dtype) \ + case ScalarType::dtype: { \ + ::impl::generic::quantized::quantized_linear_per_tensor_( \ + in, \ + weight, \ + bias, \ + in_zero_point, \ + weight_zero_point, \ + out_multiplier, \ + out_shift, \ + out_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(typed_quantized_linear); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } +#undef typed_quantized_linear + return out; +} + +Tensor& quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& in, + const Tensor& weight, + const Tensor& bias, + int64_t in_zero_point, + int64_t weight_zero_point, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + ET_UNUSED const optional& offset, + Tensor& out) 
{ +#define typed_quantized_linear(ctype, dtype) \ + case ScalarType::dtype: { \ + ::impl::generic::quantized::quantized_linear_per_tensor_( \ + in, \ + weight, \ + bias, \ + in_zero_point, \ + weight_zero_point, \ + out_multiplier, \ + out_shift, \ + out_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(typed_quantized_linear); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } +#undef typed_quantized_linear + return out; +} + +Tensor& quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& in, + const Tensor& weight, + const Tensor& bias, + int64_t in_zero_point, + int64_t weight_zero_point, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + ET_UNUSED const optional& offset, + Tensor& out) { +#define typed_quantized_linear(ctype, dtype) \ + case ScalarType::dtype: { \ + ::impl::generic::quantized::quantized_linear_per_tensor_( \ + in, \ + weight, \ + bias, \ + in_zero_point, \ + weight_zero_point, \ + out_multiplier, \ + out_shift, \ + out_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(typed_quantized_linear); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } +#undef typed_quantized_linear + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_fully_connected.h b/backends/cadence/generic/operators/op_quantized_fully_connected.h new file mode 100644 index 00000000000..a7510fba95f --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_fully_connected.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& quantized_fully_connected_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& in, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + int64_t in_zero_point, + const ::executorch::aten::Tensor& weight_zero_point_t, + const ::executorch::aten::Tensor& out_multiplier, + const ::executorch::aten::Tensor& out_shift, + int64_t out_zero_point, + const ::executorch::aten::optional<::executorch::aten::Tensor>& offset, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& quantized_fully_connected_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& in, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + int64_t in_zero_point, + int64_t weight_zero_point, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + const ::executorch::aten::optional<::executorch::aten::Tensor>& offset, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& +quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& in, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + int64_t in_zero_point, + int64_t weight_zero_point, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + const ::executorch::aten::optional<::executorch::aten::Tensor>& offset, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& +quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& in, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + int64_t in_zero_point, + int64_t weight_zero_point, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + const ::executorch::aten::optional<::executorch::aten::Tensor>& offset, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/quantized_layer_norm.cpp b/backends/cadence/generic/operators/op_quantized_layer_norm.cpp similarity index 59% rename from backends/cadence/generic/operators/quantized_layer_norm.cpp rename to backends/cadence/generic/operators/op_quantized_layer_norm.cpp index 7d8d353bff3..e34ed342d22 100644 --- a/backends/cadence/generic/operators/quantized_layer_norm.cpp +++ b/backends/cadence/generic/operators/op_quantized_layer_norm.cpp @@ -6,12 +6,26 @@ * LICENSE file in the root directory of this source tree. */ -#include -#include +#include #include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace impl { +namespace generic { +namespace native { using ::executorch::aten::IntArrayRef; +using ::executorch::aten::optional; +using ::executorch::aten::Scalar; using ::executorch::aten::ScalarType; using ::executorch::aten::Tensor; using ::executorch::runtime::getLeadingDims; @@ -19,10 +33,6 @@ using ::executorch::runtime::KernelRuntimeContext; using ::impl::generic::kernels::dequantize; using ::impl::generic::kernels::quantize; -namespace impl { -namespace generic { -namespace native { - // Compute quantized layer_norm. The current implementation assumes that the // input is per-tensor quantized. 
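For orientation, a minimal standalone sketch of the per-tensor quantized layer_norm math the templated kernel below implements: dequantize each element of a row, normalize over the last dimension, apply gamma/beta, then requantize. The plain int8/float helpers and names here are illustrative stand-ins, not the Tensor-based kernels in this patch; the in-tree kernels::quantize/dequantize also handle rounding and saturation.

#include <cmath>
#include <cstddef>
#include <cstdint>

namespace layer_norm_sketch {

inline float dq(int8_t v, float scale, int32_t zp) { return (v - zp) * scale; }
inline int8_t q(float v, float inv_scale, int32_t zp) {
  return static_cast<int8_t>(std::lround(v * inv_scale) + zp);
}

// Reference for one row (one of the `leading_dims` 1d vectors).
void layer_norm_row_ref(const int8_t* x, int8_t* y, size_t last_dim,
                        float in_scale, int32_t in_zero_point,
                        const float* gamma, const float* beta, float eps,
                        float output_scale, int32_t output_zero_point) {
  float mean = 0.0f;
  for (size_t j = 0; j < last_dim; ++j) mean += dq(x[j], in_scale, in_zero_point);
  mean /= last_dim;
  float variance = 0.0f;
  for (size_t j = 0; j < last_dim; ++j) {
    float d = dq(x[j], in_scale, in_zero_point) - mean;
    variance += d * d;
  }
  variance /= last_dim;
  const float inv_std = 1.0f / std::sqrt(variance + eps);
  const float output_inv_scale = 1.0f / output_scale;
  for (size_t j = 0; j < last_dim; ++j) {
    // y[j] = (x[j] - mean) / std * gamma[j] + beta[j], then requantize.
    float v = (dq(x[j], in_scale, in_zero_point) - mean) * inv_std * gamma[j] + beta[j];
    y[j] = q(v, output_inv_scale, output_zero_point);
  }
}

}  // namespace layer_norm_sketch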
template @@ -45,7 +55,8 @@ void quantized_layer_norm_per_tensor_( float output_inv_scale = 1.0f / output_scale; size_t last_dim = input.size(input.dim() - 1); - size_t leading_dims = getLeadingDims(input, input.dim() - 1); + size_t leading_dims = + ::executorch::runtime::getLeadingDims(input, input.dim() - 1); // Visualize the input tensor as a set of 1d vectors, and compute the // layer_norm for each vector. @@ -72,7 +83,7 @@ void quantized_layer_norm_per_tensor_( float inv_std = 1.0f / std::sqrt(variance + eps); // y = (x - mean) / std * kGamma + kBeta - for (int j = 0; j < last_dim; ++j) { + for (size_t j = 0; j < last_dim; ++j) { // y[j] = (x[j] - mean) / std * kGamma + kBeta; // Since X is quantized, we dequantize it, compute fp32 result, and // quantize the result to an int8/uint8 value. @@ -100,8 +111,6 @@ void quantized_layer_norm_( // Extract the zero point and scale for input tensor. float input_scale = in_scale.const_data_ptr()[0]; int64_t input_zero_point = in_zero_point.const_data_ptr()[0]; - - // Call other overload quantized_layer_norm_per_tensor_( input, input_scale, @@ -114,88 +123,82 @@ void quantized_layer_norm_( out); } -void quantized_layer_norm_out( - __ET_UNUSED KernelRuntimeContext& ctx, +Tensor& quantized_layer_norm_out( + ET_UNUSED KernelRuntimeContext& ctx, const Tensor& input, const Tensor& in_scale, const Tensor& in_zero_point, - __ET_UNUSED const executorch::aten::IntArrayRef normalized_shape, + ET_UNUSED const IntArrayRef normalized_shape, const Tensor& weight, const Tensor& bias, double eps, double output_scale, int64_t output_zero_point, Tensor& out) { - if (input.scalar_type() == executorch::aten::ScalarType::Byte) { - quantized_layer_norm_( - input, - in_scale, - in_zero_point, - weight, - bias, - eps, - output_scale, - output_zero_point, - out); - } else if (input.scalar_type() == executorch::aten::ScalarType::Char) { - quantized_layer_norm_( - input, - in_scale, - in_zero_point, - weight, - bias, - eps, - output_scale, - output_zero_point, - out); - } else { - ET_CHECK_MSG( - false, - "Unhandled input dtype %hhd", - static_cast(input.scalar_type())); +#define typed_quantized_layer_norm(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_layer_norm_( \ + input, \ + in_scale, \ + in_zero_point, \ + weight, \ + bias, \ + eps, \ + output_scale, \ + output_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_layer_norm) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); } + +#undef typed_quantized_layer_norm + return out; } -void quantized_layer_norm_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, +Tensor& quantized_layer_norm_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, const Tensor& input, double in_scale, int64_t in_zero_point, - __ET_UNUSED const executorch::aten::IntArrayRef normalized_shape, + ET_UNUSED const IntArrayRef normalized_shape, const Tensor& weight, const Tensor& bias, double eps, double output_scale, int64_t output_zero_point, Tensor& out) { - if (input.scalar_type() == executorch::aten::ScalarType::Byte) { - quantized_layer_norm_per_tensor_( - input, - in_scale, - in_zero_point, - weight, - bias, - eps, - output_scale, - output_zero_point, - out); - } else if (input.scalar_type() == executorch::aten::ScalarType::Char) { - quantized_layer_norm_per_tensor_( - input, - in_scale, - in_zero_point, - weight, - bias, - eps, - output_scale, - output_zero_point,
- out); - } else { - ET_CHECK_MSG( - false, - "Unhandled input dtype %hhd", - static_cast(input.scalar_type())); +#define typed_quantized_layer_norm(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_layer_norm_per_tensor_( \ + input, \ + in_scale, \ + in_zero_point, \ + weight, \ + bias, \ + eps, \ + output_scale, \ + output_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_layer_norm) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); } + +#undef typed_quantized_layer_norm + return out; } } // namespace native diff --git a/backends/cadence/generic/operators/op_quantized_layer_norm.h b/backends/cadence/generic/operators/op_quantized_layer_norm.h new file mode 100644 index 00000000000..ed642559248 --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_layer_norm.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& quantized_layer_norm_out( + __ET_UNUSED ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + const ::executorch::aten::Tensor& in_scale, + const ::executorch::aten::Tensor& in_zero_point, + __ET_UNUSED const ::executorch::aten::IntArrayRef normalized_shape, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + double eps, + double output_scale, + int64_t output_zero_point, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& quantized_layer_norm_per_tensor_out( + __ET_UNUSED ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + double in_scale, + int64_t in_zero_point, + __ET_UNUSED const ::executorch::aten::IntArrayRef normalized_shape, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + double eps, + double output_scale, + int64_t output_zero_point, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_linear.cpp b/backends/cadence/generic/operators/op_quantized_linear.cpp new file mode 100644 index 00000000000..87f990a855b --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_linear.cpp @@ -0,0 +1,220 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include +#include +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::optional; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; +using ::executorch::runtime::toString; + +Tensor& quantized_linear_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& src, + const Tensor& weight, + const Tensor& bias, + int64_t src_zero_point, + const Tensor& weight_zero_point_t, + const Tensor& out_multiplier, + const Tensor& out_shift, + int64_t out_zero_point, + ET_UNUSED const optional& offset, + Tensor& out) { +#define typed_quantized_linear(ctype, dtype) \ + case ScalarType::dtype: { \ + ::impl::generic::quantized::quantized_linear_( \ + src, \ + weight, \ + bias, \ + src_zero_point, \ + weight_zero_point_t, \ + out_multiplier, \ + out_shift, \ + out_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + + // Handle W8A16 heterogeneous type (int16_t activations, int8_t weights) + if (dtype == ScalarType::Short && src.scalar_type() == ScalarType::Short && + weight.scalar_type() == ScalarType::Char) { + ::impl::generic::quantized::quantized_linear_( + src, + weight, + bias, + src_zero_point, + weight_zero_point_t, + out_multiplier, + out_shift, + out_zero_point, + out); + return out; + } + + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(typed_quantized_linear); + default: + ET_DCHECK_MSG(false, "Unhandled dtype %s", toString(dtype)); + } +#undef typed_quantized_linear + return out; +} + +Tensor& quantized_linear_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& src, + const Tensor& weight, + const Tensor& bias, + const int64_t src_zero_point, + const int64_t weight_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + const int64_t out_zero_point, + ET_UNUSED const optional& offset, + Tensor& out) { +#define typed_quantized_linear_per_tensor(ctype, dtype) \ + case ScalarType::dtype: { \ + ::impl::generic::quantized::quantized_linear_per_tensor_( \ + src, \ + weight, \ + bias, \ + src_zero_point, \ + weight_zero_point, \ + out_multiplier, \ + out_shift, \ + out_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + // Handle W8A16 heterogeneous type (int16_t activations, int8_t weights) + if (dtype == ScalarType::Short && src.scalar_type() == ScalarType::Short && + weight.scalar_type() == ScalarType::Char) { + ::impl::generic::quantized::quantized_linear_per_tensor_( + src, + weight, + bias, + src_zero_point, + weight_zero_point, + out_multiplier, + out_shift, + out_zero_point, + out); + return out; + } + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16( + typed_quantized_linear_per_tensor); + default: + ET_KERNEL_CHECK_MSG( + ctx, + false, + InvalidArgument, + out, + "Unhandled dtype %s", + toString(dtype)); + } +#undef typed_quantized_linear_per_tensor + return out; +} + +Tensor& quantized_linear_asym8sxasym8s_asym8s_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& src, + const Tensor& weight, + const Tensor& bias, + const int64_t src_zero_point, + const int64_t weight_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + const int64_t out_zero_point, + ET_UNUSED const std::optional& offset, + Tensor& out) { +#define typed_quantized_linear_per_tensor(ctype, dtype) \ + case ScalarType::dtype: { \ + ::impl::generic::quantized::quantized_linear_per_tensor_( \ + src, \ + weight, \ 
+ bias, \ + src_zero_point, \ + weight_zero_point, \ + out_multiplier, \ + out_shift, \ + out_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16( + typed_quantized_linear_per_tensor); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", executorch::runtime::toString(dtype)); + } +#undef typed_quantized_linear_per_tensor + return out; +} + +Tensor& quantized_linear_asym8uxasym8u_asym8u_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& src, + const Tensor& weight, + const Tensor& bias, + const int64_t src_zero_point, + const int64_t weight_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + const int64_t out_zero_point, + ET_UNUSED const std::optional& offset, + Tensor& out) { +#define typed_quantized_linear_per_tensor(ctype, dtype) \ + case ScalarType::dtype: { \ + ::impl::generic::quantized::quantized_linear_per_tensor_( \ + src, \ + weight, \ + bias, \ + src_zero_point, \ + weight_zero_point, \ + out_multiplier, \ + out_shift, \ + out_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16( + typed_quantized_linear_per_tensor); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", executorch::runtime::toString(dtype)); + } +#undef typed_quantized_linear_per_tensor + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_linear.h b/backends/cadence/generic/operators/op_quantized_linear.h new file mode 100644 index 00000000000..b5396cb9701 --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_linear.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& quantized_linear_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& src, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + int64_t src_zero_point, + const ::executorch::aten::Tensor& weight_zero_point_t, + const ::executorch::aten::Tensor& out_multiplier, + const ::executorch::aten::Tensor& out_shift, + int64_t out_zero_point, + const ::executorch::aten::optional<::executorch::aten::Tensor>& offset, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& quantized_linear_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& src, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + const int64_t src_zero_point, + const int64_t weight_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + const int64_t out_zero_point, + const ::executorch::aten::optional<::executorch::aten::Tensor>& offset, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& +quantized_linear_asym8sxasym8s_asym8s_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& src, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + const int64_t src_zero_point, + const int64_t weight_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + const int64_t out_zero_point, + const std::optional<::executorch::aten::Tensor>& offset, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& +quantized_linear_asym8uxasym8u_asym8u_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& src, + const ::executorch::aten::Tensor& weight, + const ::executorch::aten::Tensor& bias, + const int64_t src_zero_point, + const int64_t weight_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + const int64_t out_zero_point, + const std::optional<::executorch::aten::Tensor>& offset, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_matmul.cpp b/backends/cadence/generic/operators/op_quantized_matmul.cpp new file mode 100644 index 00000000000..e3fb0f00fdc --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_matmul.cpp @@ -0,0 +1,286 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +#include +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::optional; +using ::executorch::aten::Scalar; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; +using ::impl::generic::kernels::quantize; + +// The quantized matmul. The quantized matmul accumulates in a wider register, +// whose type is TA. 
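As a quick illustration of the comment above, this is how one output element is formed for int8 inputs, using int32_t to play the role of the wider accumulator TA (the template below defaults TA to float). The function name, pointer layout, and final rounding are illustrative stand-ins; the multiplier/shift-to-scale mapping mirrors the Z_scale line in the kernel below, which writes 2^31 as 1 << 31, and the in-tree quantize helper additionally saturates.

#include <cmath>
#include <cstddef>
#include <cstdint>

// One element of Z = requantize(X * Y): accumulate in int32_t (the wider TA),
// then map the accumulator back to int8_t with the multiplier/shift pair.
inline int8_t qmatmul_one_element(const int8_t* x_row, const int8_t* y_col, size_t n,
                                  int32_t X_zero_point, int32_t Y_zero_point,
                                  int32_t Z_multiplier, int32_t Z_shift,
                                  int32_t Z_zero_point) {
  int32_t acc = 0;  // wider accumulator (plays the role of TA)
  for (size_t k = 0; k < n; ++k) {
    acc += (x_row[k] - X_zero_point) * (y_col[k] - Y_zero_point);
  }
  // Same mapping as the kernel below; 2147483648.0f is 2^31.
  const float Z_scale =
      -static_cast<float>(Z_multiplier) / 2147483648.0f * std::pow(2.0f, Z_shift);
  return static_cast<int8_t>(std::lround(acc * Z_scale) + Z_zero_point);
}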
+template < + typename TZ, + typename TA = float, + bool transposed = false, + typename TX = TZ, + typename TY = TZ> +__attribute__((noinline)) void qmatmul( + TZ* __restrict__ Z, + int32_t Z_multiplier, + int32_t Z_shift, + int32_t Z_zero_point, + const TX* __restrict__ X, + int32_t X_zero_point, + const TY* __restrict__ y, + int32_t Y_zero_point, + size_t m, + size_t n, + size_t p) { + // Compute the Z_scale from Z_multiplier and Z_shift + const float Z_scale = -Z_multiplier * 1.0 / (1 << 31) * pow(2, Z_shift); + for (size_t i = 0; i < m; ++i) { + for (size_t j = 0; j < p; ++j) { + TA sum = 0; + for (size_t k = 0; k < n; ++k) { + if (transposed) { + sum += (X[i * n + k] - X_zero_point) * (y[j * n + k] - Y_zero_point); + } else { + sum += (X[i * n + k] - X_zero_point) * (y[k * p + j] - Y_zero_point); + } + } + Z[i * p + j] = quantize(sum, Z_scale, Z_zero_point); + } + } +} + +Tensor& quantized_matmul_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& X, + int64_t X_zero_point, + const Tensor& Y, + int64_t Y_zero_point, + ET_UNUSED const optional& bias, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + bool transposed, + Tensor& out) { + size_t batch_size = ::executorch::runtime::getLeadingDims(X, X.dim() - 2); + size_t leading_dim = X.size(X.dim() - 2); + size_t out_dim = Y.size(Y.dim() - 1 - transposed); + size_t in_dim = X.size(X.dim() - 1); + + // Handle W8A16 heterogeneous type (int16_t activations, int8_t weights) + if (out.scalar_type() == ScalarType::Short && + X.scalar_type() == ScalarType::Short && + Y.scalar_type() == ScalarType::Char) { + int16_t* __restrict__ out_data = out.mutable_data_ptr(); + const int16_t* __restrict__ X_data = X.const_data_ptr(); + const int8_t* __restrict__ Y_data = Y.const_data_ptr(); + for (size_t i = 0; i < batch_size; ++i) { + const int16_t* x = X_data + i * leading_dim * in_dim; + const int8_t* y = Y_data + i * in_dim * out_dim; + int16_t* z = out_data + i * leading_dim * out_dim; + if (transposed) { + qmatmul( + z, + static_cast(out_multiplier), + static_cast(out_shift), + static_cast(out_zero_point), + x, + static_cast(X_zero_point), + y, + static_cast(Y_zero_point), + leading_dim, + in_dim, + out_dim); + } else { + qmatmul( + z, + static_cast(out_multiplier), + static_cast(out_shift), + static_cast(out_zero_point), + x, + static_cast(X_zero_point), + y, + static_cast(Y_zero_point), + leading_dim, + in_dim, + out_dim); + } + } + return out; + } + +#define typed_quantized_matmul(ctype, dtype) \ + case ScalarType::dtype: { \ + ctype* __restrict__ out_data = out.mutable_data_ptr(); \ + const ctype* __restrict__ X_data = X.const_data_ptr(); \ + const ctype* __restrict__ Y_data = Y.const_data_ptr(); \ + for (size_t i = 0; i < batch_size; ++i) { \ + const ctype* x = X_data + i * leading_dim * in_dim; \ + const ctype* y = Y_data + i * in_dim * out_dim; \ + ctype* z = out_data + i * leading_dim * out_dim; \ + if (transposed) { \ + qmatmul( \ + z, \ + static_cast(out_multiplier), \ + static_cast(out_shift), \ + static_cast(out_zero_point), \ + x, \ + static_cast(X_zero_point), \ + y, \ + static_cast(Y_zero_point), \ + leading_dim, \ + in_dim, \ + out_dim); \ + } else { \ + qmatmul( \ + z, \ + static_cast(out_multiplier), \ + static_cast(out_shift), \ + static_cast(out_zero_point), \ + x, \ + static_cast(X_zero_point), \ + y, \ + static_cast(Y_zero_point), \ + leading_dim, \ + in_dim, \ + out_dim); \ + } \ + } \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + 
ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(typed_quantized_matmul); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } +#undef typed_quantized_matmul + return out; +} + +template +void _typed_quantized_matmul( + const Tensor& X, + int64_t X_zero_point, + const Tensor& Y, + int64_t Y_zero_point, + ET_UNUSED const optional& bias, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + bool transposed, + Tensor& out) { + size_t batch_size = ::executorch::runtime::getLeadingDims(X, X.dim() - 2); + size_t leading_dim = X.size(X.dim() - 2); + size_t out_dim = Y.size(Y.dim() - 1 - transposed); + size_t in_dim = X.size(X.dim() - 1); + + T* __restrict__ out_data = out.mutable_data_ptr(); + const T* __restrict__ X_data = X.const_data_ptr(); + const T* __restrict__ Y_data = Y.const_data_ptr(); + for (size_t i = 0; i < batch_size; ++i) { + const T* x = X_data + i * leading_dim * in_dim; + const T* y = Y_data + i * in_dim * out_dim; + T* z = out_data + i * leading_dim * out_dim; + if (transposed) { + qmatmul( + z, + static_cast(out_multiplier), + static_cast(out_shift), + static_cast(out_zero_point), + x, + static_cast(X_zero_point), + y, + static_cast(Y_zero_point), + leading_dim, + in_dim, + out_dim); + } else { + qmatmul( + z, + static_cast(out_multiplier), + static_cast(out_shift), + static_cast(out_zero_point), + x, + static_cast(X_zero_point), + y, + static_cast(Y_zero_point), + leading_dim, + in_dim, + out_dim); + } + } +} + +Tensor& quantized_matmul_asym8sxasym8s_asym8s_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& X, + int64_t X_zero_point, + const Tensor& Y, + int64_t Y_zero_point, + const optional& bias, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + bool transposed, + Tensor& out) { + _typed_quantized_matmul( + X, + X_zero_point, + Y, + Y_zero_point, + bias, + out_multiplier, + out_shift, + out_zero_point, + transposed, + out); + return out; +} + +Tensor& quantized_matmul_asym8uxasym8u_asym8u_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& X, + int64_t X_zero_point, + const Tensor& Y, + int64_t Y_zero_point, + const optional& bias, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + bool transposed, + Tensor& out) { + _typed_quantized_matmul( + X, + X_zero_point, + Y, + Y_zero_point, + bias, + out_multiplier, + out_shift, + out_zero_point, + transposed, + out); + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_matmul.h b/backends/cadence/generic/operators/op_quantized_matmul.h new file mode 100644 index 00000000000..70775380aac --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_matmul.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::optional; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +Tensor& quantized_matmul_out( + KernelRuntimeContext& ctx, + const Tensor& X, + int64_t X_zero_point, + const Tensor& Y, + int64_t Y_zero_point, + const optional& bias, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + bool transposed, + Tensor& out); + +Tensor& quantized_matmul_asym8sxasym8s_asym8s_out( + KernelRuntimeContext& ctx, + const Tensor& X, + int64_t X_zero_point, + const Tensor& Y, + int64_t Y_zero_point, + const optional& bias, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + bool transposed, + Tensor& out); + +Tensor& quantized_matmul_asym8uxasym8u_asym8u_out( + KernelRuntimeContext& ctx, + const Tensor& X, + int64_t X_zero_point, + const Tensor& Y, + int64_t Y_zero_point, + const optional& bias, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + bool transposed, + Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_mul.cpp b/backends/cadence/generic/operators/op_quantized_mul.cpp new file mode 100644 index 00000000000..89fb2a5250d --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_mul.cpp @@ -0,0 +1,141 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +#include +#include +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::optional; +using ::executorch::aten::Scalar; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; +using ::impl::generic::kernels::dequantize; +using ::impl::generic::kernels::quantize; + +DECLARE_POINTWISE_TENSOR_QUANTIZED_BINARY_OP(quantized_mul_, *); + +Tensor& quantized_mul_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& X, + const Tensor& X_scale_t, + const Tensor& X_zero_point_t, + const Tensor& Y, + const Tensor& Y_scale_t, + const Tensor& Y_zero_point_t, + double out_scale, + int64_t out_zero_point, + Tensor& out) { + float X_scale = X_scale_t.const_data_ptr()[0]; + int32_t X_zero_point = X_zero_point_t.const_data_ptr()[0]; + float Y_scale = Y_scale_t.const_data_ptr()[0]; + int32_t Y_zero_point = Y_zero_point_t.const_data_ptr()[0]; +#define typed_quantized_mul(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_mul_( \ + X, \ + X_scale, \ + X_zero_point, \ + Y, \ + Y_scale, \ + Y_zero_point, \ + static_cast(out_scale), \ + static_cast(out_zero_point), \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_mul) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } +#undef typed_quantized_mul + return out; +} + +// Generate kernels that perform elementwise arithmetic on a quantized tensor, +// and a scalar. 
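To make the macro below concrete, this is the per-element computation it generates for OP = * (the quantized_mul_Scalar_ kernel), reduced to scalars with plain int8/float stand-ins for the Tensor-based version. The helper name is hypothetical, and the in-tree quantize/dequantize helpers additionally handle rounding and saturation.

#include <cmath>
#include <cstdint>

// One element of quantized_mul_Scalar_: dequantize X[i], multiply by the
// already-dequantized scalar Y, then requantize with 1 / out_scale.
inline int8_t mul_scalar_one_element(int8_t x, float X_scale, int32_t X_zero_point,
                                     float Y, float out_scale,
                                     int32_t out_zero_point) {
  const float xf = (x - X_zero_point) * X_scale;  // dequantize(X_data[i], ...)
  const float z = xf * Y;                         // x OP Y with OP = *
  const float inv_out_scale = 1.0f / out_scale;
  return static_cast<int8_t>(std::lround(z * inv_out_scale) + out_zero_point);
}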
+#define DECLARE_POINTWISE_SCALAR_QUANTIZED_BINARY_OP(BINARY_FUNC_NAME, OP) \ + template \ + void BINARY_FUNC_NAME( \ + const Tensor& X, \ + float X_scale, \ + int32_t X_zero_point, \ + const float Y, \ + float out_scale, \ + int32_t out_zero_point, \ + Tensor& out) { \ + const T* __restrict__ X_data = X.const_data_ptr(); \ + T* __restrict__ out_data = out.mutable_data_ptr(); \ + float inv_out_scale = 1.0f / out_scale; \ + for (size_t i = 0, e = X.numel(); i < e; ++i) { \ + float x = dequantize(X_data[i], X_scale, X_zero_point); \ + float z = x OP Y; \ + out_data[i] = quantize(z, inv_out_scale, out_zero_point); \ + } \ + } + +DECLARE_POINTWISE_SCALAR_QUANTIZED_BINARY_OP(quantized_mul_Scalar_, *); + +Tensor& quantized_mul_Scalar_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& X, + const Tensor& X_scale_t, + const Tensor& X_zero_point_t, + const Scalar& Y_scalar, + double out_scale, + int64_t out_zero_point, + Tensor& out) { + float X_scale = X_scale_t.const_data_ptr()[0]; + int32_t X_zero_point = X_zero_point_t.const_data_ptr()[0]; + float Y = static_cast( + ::torch::executor::native::utils::scalar_to(Y_scalar)); + +#define typed_quantized_mul_Scalar(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_mul_Scalar_( \ + X, \ + X_scale, \ + X_zero_point, \ + Y, \ + static_cast(out_scale), \ + static_cast(out_zero_point), \ + out); \ + break; \ + } + + ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_mul_Scalar) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } +#undef typed_quantized_mul_Scalar + return out; +} + +#undef DECLARE_POINTWISE_SCALAR_QUANTIZED_BINARY_OP + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_mul.h b/backends/cadence/generic/operators/op_quantized_mul.h new file mode 100644 index 00000000000..7ca8b2f1db0 --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_mul.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& quantized_mul_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& X, + const ::executorch::aten::Tensor& X_scale_t, + const ::executorch::aten::Tensor& X_zero_point_t, + const ::executorch::aten::Tensor& Y, + const ::executorch::aten::Tensor& Y_scale_t, + const ::executorch::aten::Tensor& Y_zero_point_t, + double out_scale, + int64_t out_zero_point, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& quantized_mul_Scalar_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& X, + const ::executorch::aten::Tensor& X_scale_t, + const ::executorch::aten::Tensor& X_zero_point_t, + const ::executorch::aten::Scalar& Y_scalar, + double out_scale, + int64_t out_zero_point, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_relu.cpp b/backends/cadence/generic/operators/op_quantized_relu.cpp new file mode 100644 index 00000000000..9430951f65b --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_relu.cpp @@ -0,0 +1,184 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +#include +#include +#include + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::optional; +using ::executorch::aten::Scalar; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; +using ::impl::generic::kernels::quantize; + +template +void quantized_relu_per_tensor_out_( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const int64_t in_zero_point, + const int64_t out_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + Tensor& output) { + const T* __restrict__ in = input.const_data_ptr(); + T* __restrict__ out = output.mutable_data_ptr(); + + // Compute the out_scale from out_multiplier and out_shift + const float out_scale = -out_multiplier * 1.0 / (1 << 31) * pow(2, out_shift); + + for (size_t i = 0, e = input.numel(); i < e; ++i) { + const float temp = in[i] > in_zero_point ? 
(in[i] - in_zero_point) : 0; + out[i] = quantize(temp, out_scale, out_zero_point); + } +} + +Tensor& quantized_relu_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const int64_t in_zero_point, + const int64_t out_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + Tensor& output) { +#define typed_quantized_relu(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_relu_per_tensor_out_( \ + ctx, \ + input, \ + in_zero_point, \ + out_zero_point, \ + out_multiplier, \ + out_shift, \ + output); \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_relu) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_relu + return output; +} + +template +void quantized_relu_( + const Tensor& input, + const Tensor& in_zero_point, + const int64_t out_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, + Tensor& output) { + T q_zero_point = in_zero_point.const_data_ptr()[0]; + const T* __restrict__ in = input.const_data_ptr(); + T* __restrict__ out = output.mutable_data_ptr(); + + const int32_t* __restrict__ out_multiplier_data = + out_multiplier.const_data_ptr(); + const int32_t* __restrict__ out_shift_data = + out_shift.const_data_ptr(); + + // Compute the out_scale from out_multiplier and out_shift + const float out_scale = + -out_multiplier_data[0] * 1.0 / (1 << 31) * pow(2, out_shift_data[0]); + + for (size_t i = 0, e = input.numel(); i < e; ++i) { + const T temp = in[i] > q_zero_point ? (in[i] - q_zero_point) : 0; + out[i] = quantize(temp, out_scale, out_zero_point); + } +} + +Tensor& quantized_relu_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& in_zero_point, + const int64_t out_zero_point, + const Tensor& out_multiplier, + const Tensor& out_shift, + Tensor& output) { +#define typed_quantized_relu(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_relu_( \ + input, \ + in_zero_point, \ + out_zero_point, \ + out_multiplier, \ + out_shift, \ + output); \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_relu) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_relu + return output; +} + +Tensor& quantized_relu_asym8s_asym8s_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const int64_t in_zero_point, + const int64_t out_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + Tensor& output) { + quantized_relu_per_tensor_out_( + ctx, + input, + in_zero_point, + out_zero_point, + out_multiplier, + out_shift, + output); + return output; +} + +Tensor& quantized_relu_asym8u_asym8u_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const int64_t in_zero_point, + const int64_t out_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + Tensor& output) { + quantized_relu_per_tensor_out_( + ctx, + input, + in_zero_point, + out_zero_point, + out_multiplier, + out_shift, + output); + return output; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_relu.h b/backends/cadence/generic/operators/op_quantized_relu.h new file mode 100644 index 00000000000..6241b2ddfcf --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_relu.h @@ 
-0,0 +1,56 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& quantized_relu_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + const ::executorch::aten::Tensor& in_zero_point, + const int64_t out_zero_point, + const ::executorch::aten::Tensor& out_multiplier, + const ::executorch::aten::Tensor& out_shift, + ::executorch::aten::Tensor& output); + +::executorch::aten::Tensor& quantized_relu_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + const int64_t in_zero_point, + const int64_t out_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + ::executorch::aten::Tensor& output); + +::executorch::aten::Tensor& quantized_relu_asym8s_asym8s_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + const int64_t in_zero_point, + const int64_t out_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + ::executorch::aten::Tensor& output); + +::executorch::aten::Tensor& quantized_relu_asym8u_asym8u_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + const int64_t in_zero_point, + const int64_t out_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + ::executorch::aten::Tensor& output); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_softmax.cpp b/backends/cadence/generic/operators/op_quantized_softmax.cpp new file mode 100644 index 00000000000..61037f22167 --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_softmax.cpp @@ -0,0 +1,221 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include +#include + +#include +#include +#include +#include + +namespace impl { +namespace generic { +namespace native { +namespace { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; +using ::impl::generic::kernels::dequantize; +using ::impl::generic::kernels::quantize; + +template +void quantized_softmax_per_tensor_( + const Tensor& input, + ET_UNUSED const Tensor& mask, + int64_t dim, + const float in_scale, + const int64_t in_zero_point, + const float out_scale, + const int64_t out_zero_point, + Tensor& out) { + const T* __restrict__ in_data = input.const_data_ptr(); + T* __restrict__ out_data = out.mutable_data_ptr(); + + float out_inv_scale = 1.0f / out_scale; + if (dim < 0) { + dim += input.dim(); + } + const int64_t input_size = input.numel(); + float* x = new float[input_size]; + + torch::executor::apply_over_dim( + [in_data, + out_data, + x, + in_scale, + in_zero_point, + out_inv_scale, + out_zero_point]( + const size_t size, const size_t stride, const size_t base) { + // Dequantize the input tensor + torch::executor::apply_unary_map_fn( + [in_scale, in_zero_point](const float val_in) { + return dequantize( + val_in, in_scale, static_cast(in_zero_point)); + }, + in_data + base, + x + base, + size, + stride); + + // Subtract max(X) from input tensor + float max_in = torch::executor::apply_unary_reduce_fn( + [](const float val_in, float val_accum) { + return std::max(val_in, val_accum); + }, + x + base, + size, + stride); + + // Compute exp(X - max(X)) + torch::executor::apply_unary_map_fn( + [max_in](const float val_in) { return std::exp(val_in - max_in); }, + x + base, + x + base, + size, + stride); + + // Compute sum(exp(X - max(X)) + float temp_sum = torch::executor::apply_unary_reduce_fn( + [](const float val_in, float val_accum) { + return val_accum + val_in; + }, + x + base, + size, + stride); + + // Compute exp(X - max(X)) / sum(exp(X - max(X)) and quantize the + float recip = 1.0 / temp_sum; + torch::executor::apply_unary_map_fn( + [recip, out_inv_scale, out_zero_point](const float val_in) { + float res = val_in * recip; + return quantize( + res, out_inv_scale, static_cast(out_zero_point)); + }, + x + base, + out_data + base, + size, + stride); + }, + input, + dim); + + delete[] x; +} + +// Compute quantized softmax. The current implementation assumes that the +// input is per-tensor quantized. +template +void quantized_softmax_( + const Tensor& input, + const Tensor& mask, + const int64_t dim, + const Tensor& in_scale, + const Tensor& in_zero_point, + const Tensor& out_scale, + const Tensor& out_zero_point, + Tensor& out) { + // Extract the zero point and scale for input tensor. 
+ float input_scale = in_scale.const_data_ptr()[0]; + int64_t input_zero_point = in_zero_point.const_data_ptr()[0]; + float output_scale = out_scale.const_data_ptr()[0]; + int64_t output_zero_point = out_zero_point.const_data_ptr()[0]; + quantized_softmax_per_tensor_( + input, + mask, + dim, + input_scale, + input_zero_point, + output_scale, + output_zero_point, + out); +} + +} // namespace + +Tensor& quantized_softmax_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& mask, + int64_t dim, + const Tensor& in_scale, + const Tensor& in_zero_point, + const Tensor& out_scale, + const Tensor& out_zero_point, + Tensor& out) { +#define typed_quantized_softmax(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_softmax_( \ + input, \ + mask, \ + dim, \ + in_scale, \ + in_zero_point, \ + out_scale, \ + out_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(typed_quantized_softmax) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_softmax + return out; +} + +Tensor& quantized_softmax_per_tensor_out( + ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& mask, + int64_t dim, + double in_scale, + int64_t in_zero_point, + double out_scale, + int64_t out_zero_point, + Tensor& out) { +#define typed_quantized_softmax(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_softmax_per_tensor_( \ + input, \ + mask, \ + dim, \ + in_scale, \ + in_zero_point, \ + out_scale, \ + out_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(typed_quantized_softmax) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_softmax + return out; +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_softmax.h b/backends/cadence/generic/operators/op_quantized_softmax.h new file mode 100644 index 00000000000..485113851a3 --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_softmax.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include + +namespace impl { +namespace generic { +namespace native { + +::executorch::aten::Tensor& quantized_softmax_out( + __ET_UNUSED ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + const ::executorch::aten::Tensor& mask, + int64_t dim, + const ::executorch::aten::Tensor& in_scale, + const ::executorch::aten::Tensor& in_zero_point, + const ::executorch::aten::Tensor& out_scale, + const ::executorch::aten::Tensor& out_zero_point, + ::executorch::aten::Tensor& out); + +::executorch::aten::Tensor& quantized_softmax_per_tensor_out( + __ET_UNUSED ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& input, + const ::executorch::aten::Tensor& mask, + int64_t dim, + double in_scale, + int64_t in_zero_point, + double out_scale, + int64_t out_zero_point, + ::executorch::aten::Tensor& out); + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/operators.h b/backends/cadence/generic/operators/operators.h deleted file mode 100644 index c9d8e8782bd..00000000000 --- a/backends/cadence/generic/operators/operators.h +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#pragma once - -#include -#include -#include -#include - -namespace impl { -namespace generic { -namespace native { -namespace { -using ::executorch::runtime::getLeadingDims; - -#define ET_FORALL_CADENCE_QUANTIZED_TYPES(_) \ - _(uint8_t, Byte) \ - _(int8_t, Char) - -inline __attribute__((always_inline)) void linear_( - const ::executorch::aten::Tensor& input, - const ::executorch::aten::Tensor& weight, - const std::optional<::executorch::aten::Tensor>& bias, - ::executorch::aten::Tensor& output) { - const float* __restrict__ input_data = input.const_data_ptr(); - const float* __restrict__ weight_data = weight.const_data_ptr(); - const float* __restrict__ bias_data = bias.value().const_data_ptr(); - float* __restrict__ output_data = output.mutable_data_ptr(); - - // input comes in shape [batch_size, in_dim] - // weight comes in shape [out_dim, in_dim] - // output comes in empty with shape [batch_size, out_dim] - // Perform matrix multiply (M x N) x (N x P) => M x P - int64_t M = weight.size(0); // = out_dim - int64_t N = weight.size(1); // = in_dim - - // Given an N-dimensional input [d0, d1, d2, ..., d_{N-2}, d_{N-1}], the - // leading dimensions is d0 * d1 * ... * d_{N-2} - int64_t leading_dims = getLeadingDims(input, input.dim() - 1); - - for (int i = 0; i < leading_dims; ++i) { - for (int j = 0; j < M; ++j) { - float sum = bias_data[j]; - for (int k = 0; k < N; ++k) { - sum += input_data[i * N + k] * weight_data[j * N + k]; - } - output_data[i * M + j] = sum; - } - } -} - -} // namespace -} // namespace native -} // namespace generic -} // namespace impl diff --git a/backends/cadence/generic/operators/quantized_add_out.cpp b/backends/cadence/generic/operators/quantized_add_out.cpp deleted file mode 100644 index 14ee62fb944..00000000000 --- a/backends/cadence/generic/operators/quantized_add_out.cpp +++ /dev/null @@ -1,192 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#include -#include -#include - -namespace impl { -namespace generic { -namespace native { - -using ::executorch::aten::Tensor; -using ::executorch::runtime::KernelRuntimeContext; -using ::impl::generic::kernels::dequantize; -using ::impl::generic::kernels::quantize; - -template -void quantized_add_per_tensor_impl( - const Tensor& X, - double X_scale, - int64_t X_zero_point, - const Tensor& Y, - double Y_scale, - int64_t Y_zero_point, - double out_scale, - int64_t out_zero_point, - Tensor& out) { - const T* __restrict__ X_data = X.const_data_ptr(); - const T* __restrict__ Y_data = Y.const_data_ptr(); - T* __restrict__ out_data = out.mutable_data_ptr(); - - ssize_t Y_numel = Y.numel(); - ssize_t X_numel = X.numel(); - ssize_t out_numel = out.numel(); - - float X_scale_f = static_cast(X_scale); - float Y_scale_f = static_cast(Y_scale); - float out_scale_f = static_cast(out_scale); - int32_t X_zero_point_i32 = static_cast(X_zero_point); - int32_t Y_zero_point_i32 = static_cast(Y_zero_point); - int32_t out_zero_point_i32 = static_cast(out_zero_point); - - float inv_out_scale = 1.0f / out_scale_f; - - // Simple case: tensors have the same shape, no broadcasting - if (X_numel == Y_numel && Y_numel == out_numel) { - for (size_t i = 0; i < X_numel; ++i) { - float x = dequantize(X_data[i], X_scale_f, X_zero_point_i32); - float y = dequantize(Y_data[i], Y_scale_f, Y_zero_point_i32); - float z = x + y; - out_data[i] = quantize(z, inv_out_scale, out_zero_point_i32); - } - } - // Y is a scalar tensor - else if (Y_numel == 1) { - float y = dequantize(Y_data[0], Y_scale_f, Y_zero_point_i32); - for (size_t i = 0; i < X_numel; ++i) { - float x = dequantize(X_data[i], X_scale_f, X_zero_point_i32); - float z = x + y; - out_data[i] = quantize(z, inv_out_scale, out_zero_point_i32); - } - } - // X is a scalar tensor - else if (X_numel == 1) { - float x = dequantize(X_data[0], X_scale_f, X_zero_point_i32); - for (size_t i = 0; i < Y_numel; ++i) { - float y = dequantize(Y_data[i], Y_scale_f, Y_zero_point_i32); - float z = x + y; - out_data[i] = quantize(z, inv_out_scale, out_zero_point_i32); - } - } - // General broadcasting case - simplified implementation - else { - for (ssize_t i = 0; i < out_numel; ++i) { - // Simple broadcasting: repeat elements as needed - size_t x_idx = (X_numel == 1) ? 0 : i % X_numel; - size_t y_idx = (Y_numel == 1) ? 
0 : i % Y_numel; - - float x = dequantize(X_data[x_idx], X_scale_f, X_zero_point_i32); - float y = dequantize(Y_data[y_idx], Y_scale_f, Y_zero_point_i32); - float z = x + y; - out_data[i] = quantize(z, inv_out_scale, out_zero_point_i32); - } - } -} - -// Generic quantized add with type dispatch -void quantized_add_per_tensor_out( - KernelRuntimeContext& ctx, - const Tensor& X, - double X_scale, - int64_t X_zero_point, - const Tensor& Y, - double Y_scale, - int64_t Y_zero_point, - double out_scale, - int64_t out_zero_point, - Tensor& out) { - (void)ctx; - - executorch::aten::ScalarType dtype = X.scalar_type(); - switch (dtype) { - case executorch::aten::ScalarType::Byte: - quantized_add_per_tensor_impl( - X, - X_scale, - X_zero_point, - Y, - Y_scale, - Y_zero_point, - out_scale, - out_zero_point, - out); - break; - case executorch::aten::ScalarType::Char: - quantized_add_per_tensor_impl( - X, - X_scale, - X_zero_point, - Y, - Y_scale, - Y_zero_point, - out_scale, - out_zero_point, - out); - break; - default: - ET_CHECK_MSG( - false, "Unhandled input dtype %hhd", static_cast(dtype)); - } -} - -// int8-specific quantized add -void quantized_add_asym8sxasym8s_asym8s_per_tensor_out( - KernelRuntimeContext& ctx, - const Tensor& X, - double X_scale, - int64_t X_zero_point, - const Tensor& Y, - double Y_scale, - int64_t Y_zero_point, - double out_scale, - int64_t out_zero_point, - Tensor& out) { - (void)ctx; - - quantized_add_per_tensor_impl( - X, - X_scale, - X_zero_point, - Y, - Y_scale, - Y_zero_point, - out_scale, - out_zero_point, - out); -} - -// uint8-specific quantized add -void quantized_add_asym8uxasym8u_asym8u_per_tensor_out( - KernelRuntimeContext& ctx, - const Tensor& X, - double X_scale, - int64_t X_zero_point, - const Tensor& Y, - double Y_scale, - int64_t Y_zero_point, - double out_scale, - int64_t out_zero_point, - Tensor& out) { - (void)ctx; - - quantized_add_per_tensor_impl( - X, - X_scale, - X_zero_point, - Y, - Y_scale, - Y_zero_point, - out_scale, - out_zero_point, - out); -} - -} // namespace native -} // namespace generic -} // namespace impl diff --git a/backends/cadence/generic/operators/quantized_conv2d_nchw_out.cpp b/backends/cadence/generic/operators/quantized_conv2d_nchw_out.cpp deleted file mode 100644 index fbb01c82e65..00000000000 --- a/backends/cadence/generic/operators/quantized_conv2d_nchw_out.cpp +++ /dev/null @@ -1,567 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include - -namespace impl { -namespace generic { -namespace native { - -using ::executorch::aten::IntArrayRef; -using ::executorch::aten::ScalarType; -using ::executorch::aten::Tensor; -using ::executorch::runtime::KernelRuntimeContext; - -// This implements a generic 2d conv kernel that operates on raw pointers. -// The version handles both quantized and fp32 convolutions. 
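Stripped of stride, padding, dilation, and groups, the quantized accumulation the stencil performs for a single NCHW output element looks like the sketch below; the layouts and parameter names are assumptions for illustration, and the caller would still apply quantize(bias_scale * acc, 1 / out_scale, out_zero_point) exactly as the kernel does.

#include <cstdint>

// One NCHW output element under unit stride/dilation and zero padding:
// subtract the zero points, multiply-accumulate in float, and leave the
// requantization (bias_scale, out_scale, out_zero_point) to the caller.
float conv_nchw_accumulate_one(
    const int8_t* in,      // one batch, laid out [C, H, W]
    const int8_t* weight,  // one output channel, laid out [C, KH, KW]
    float bias,
    int C, int H, int W, int KH, int KW,
    int oh, int ow,
    int32_t in_zero_point, int32_t weight_zero_point) {
  float acc = bias;
  for (int c = 0; c < C; ++c) {
    for (int kh = 0; kh < KH; ++kh) {
      for (int kw = 0; kw < KW; ++kw) {
        const int ih = oh + kh;
        const int iw = ow + kw;
        const float lhs =
            static_cast<float>(in[(c * H + ih) * W + iw] - in_zero_point);
        const float rhs = static_cast<float>(
            weight[(c * KH + kh) * KW + kw] - weight_zero_point);
        acc += lhs * rhs;
      }
    }
  }
  return acc;  // then: out = quantize(bias_scale * acc, 1 / out_scale, out_zero_point)
}

In the general case the kernel also clips the receptive field against the padded bounds; that bounds check is exactly what the zero_pad_unit_dilation fast path removes so the inner loops stay vectorizable.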
-// The input is of shape [n x c x h x w] -// The weight is of shape [oc x wc x wh x ww], where wc == c -// The output is of shape [n x oc x oh x ow] -// The bias is of shape [oc] -template < - typename IT = float, - typename WT = IT, - typename BT = IT, - typename OT = IT, - bool quantized = false> -__attribute__((noinline)) void conv2d_nchw_core_generic( - // All the arrays - const IT* __restrict__ p_in, - const WT* __restrict__ p_weight, - const BT* __restrict__ p_bias, - OT* __restrict__ p_out, - // The array sizes - int32_t n, - int32_t c, - int32_t h, - int32_t w, - int32_t oc, - int32_t wc, - int32_t wh, - int32_t ww, - int32_t oh, - int32_t ow, - // Stride - int16_t s0, - int16_t s1, - // Padding - int16_t p0, - int16_t p1, - // Dilation - int16_t d0, - int16_t d1, - // Group for depthwise conv - int16_t groups, - // Optional args that are only relevant for quantized convolution - // input zero point - IT in_zero_point = 0, - // weight zero point - int32_t weight_zero_point = 0, - float bias_scale = 1, - float out_scale = 1, - OT out_zero_point = 0) { - float inv_out_scale = 1. / out_scale; - bool zero_pad_unit_dilation = d0 == 1 && d1 == 1 && p0 == 0 && p1 == 0; - - // Compute the number of in and out channels per group - const int ocpg = oc / groups; - const int icpg = c / groups; - - // Iterate over all the output batches (i.e., n) - for (int _n = 0; _n < n; ++_n) { - const IT* in_batch = p_in + _n * c * h * w; - OT* out_batch = p_out + _n * oc * oh * ow; - // Compute separable convolution for each group - for (int _g = 0; _g < groups; ++_g) { - // Identify the input and output channels involved in the computation - // of this group - int sic = _g * icpg; - int soc = _g * ocpg; - // Populate all the output channels in the group - for (int _oc = soc; _oc < soc + ocpg; ++_oc) { - OT* out_plane = out_batch + _oc * oh * ow; - const WT* weight_batch = p_weight + _oc * wc * wh * ww; - // We compute one output channel at a time. The computation can be - // thought of as a stencil computation: we iterate over an input of size - // icpg x h x w, with a stencil of size icpg x wh x ww, to compute an - // output channel of size 1 x oh x ow. - for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { - for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { - float acc = p_bias[_oc]; - // Below is the stencil computation that performs the hadamard - // product+accumulation of each input channel (contributing to the - // output channel being computed) with the corresponding weight - // channel. - // If the padding is 0, and dilation is 1, then we can remove the - // unnecessary checks, and simplify the code so that it can be - // vectorized by Tensilica compiler. - if (zero_pad_unit_dilation) { - for (int _ic = sic; _ic < sic + icpg; ++_ic) { - const IT* in_plane = in_batch + _ic * h * w; - const WT* weight_plane = weight_batch + (_ic - sic) * wh * ww; - for (int _wh = 0; _wh < wh; ++_wh) { - for (int _ww = 0; _ww < ww; ++_ww) { - int ioff = (_h + _wh) * w + (_w + _ww); - int woff = _wh * ww + _ww; - float lhs = in_plane[ioff] - in_zero_point; - float rhs = weight_plane[woff] - - (quantized ? 
weight_zero_point : 0); - acc += lhs * rhs; - } - } - } - } else { - for (int _ic = sic; _ic < sic + icpg; ++_ic) { - const IT* in_plane = in_batch + _ic * h * w; - const WT* weight_plane = weight_batch + (_ic - sic) * wh * ww; - for (int _wh = 0; _wh < wh; ++_wh) { - for (int _ww = 0; _ww < ww; ++_ww) { - if (((_h + d0 * _wh - p0) >= 0) && - ((_h + d0 * _wh - p0) < h) && - ((_w + d1 * _ww - p1) >= 0) && - ((_w + d1 * _ww - p1) < w)) { - int ioff = - (_h + d0 * _wh - p0) * w + (_w + d1 * _ww - p1); - int woff = _wh * ww + _ww; - float lhs = in_plane[ioff] - in_zero_point; - float rhs = weight_plane[woff] - - (quantized ? weight_zero_point : 0); - acc += lhs * rhs; - } - } - } - } - } - if (quantized) { - float val = bias_scale * acc; - out_plane[_oh * ow + _ow] = - ::impl::generic::kernels::quantize( - val, inv_out_scale, out_zero_point); - } else { - out_plane[_oh * ow + _ow] = acc; - } - } - } - } - } - } -} - -// The quantized convolution kernel. in_scale and weight_scale are implicit in -// bias_scale, since it is a product of the two. The kernel will branch to -// quantized::conv1d or quantized::conv2d based on the dimensionality of -// activation tensor. -void quantized_conv2d_nchw( - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int16_t groups, - int32_t in_zero_point, - int32_t weight_zero_point, - float bias_scale, - float output_scale, - int32_t output_zero_point, - Tensor& out) { - bool conv1d = input.dim() == 3; - // input = [n, c, h, w] - const int n = input.size(0); - const int c = input.size(1); - const int h = conv1d ? 1 : input.size(2); - const int w = conv1d ? input.size(2) : input.size(3); - // weight = [oc, wc, wh, ww] - const int oc = weight.size(0); - const int wc = weight.size(1); - const int wh = conv1d ? 1 : weight.size(2); - const int ww = conv1d ? weight.size(2) : weight.size(3); - // output = [n, oc, oh, ow] - const int oh = conv1d ? 1 : out.size(2); - const int ow = conv1d ? 
out.size(2) : out.size(3); - -#define typed_quantized_conv2d_nchw(ctype, dtype) \ - case ScalarType::dtype: { \ - conv2d_nchw_core_generic( \ - input.const_data_ptr(), \ - weight.const_data_ptr(), \ - bias.const_data_ptr(), \ - out.mutable_data_ptr(), \ - n, \ - c, \ - h, \ - w, \ - oc, \ - wc, \ - wh, \ - ww, \ - oh, \ - ow, \ - stride[0], \ - stride[1], \ - padding[0], \ - padding[1], \ - dilation[0], \ - dilation[1], \ - groups, \ - in_zero_point, \ - weight_zero_point, \ - bias_scale, \ - output_scale, \ - (ctype)output_zero_point); \ - break; \ - } - ScalarType dtype = out.scalar_type(); - switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_conv2d_nchw); - default: - ET_DCHECK_MSG( - false, "Unhandled dtype %s", torch::executor::toString(dtype)); - } - -#undef typed_quantized_conv2d_nchw -} - -void quantized_conv2d_nchw_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - const Tensor& weight_zero_point, - const Tensor& bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED const Tensor& out_multiplier, - __ET_UNUSED const Tensor& out_shift, - Tensor& out) { - const float bias_scale_float = bias_scale.const_data_ptr()[0]; - const int32_t weight_zero_point_int = - weight_zero_point.const_data_ptr()[0]; - quantized_conv2d_nchw( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point_int, - bias_scale_float, - output_scale, - output_zero_point, - out); -} - -void quantized_conv2d_nchw_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - bool channel_last, - Tensor& out) { - quantized_conv2d_nchw( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void quantized_conv2d_nchw_asym8sxsym8s_asym8s_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv2d_nchw( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void quantized_conv2d_nchw_asym8uxsym8u_asym8u_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv2d_nchw( - input, - weight, - bias, - stride, - padding, - 
dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv2d_nchw( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv2d_nchw( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv2d_nchw( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv2d_nchw( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv2d_nchw( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void 
quantized_conv1d_ncl_asym8uxsym8u_asym8u_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv2d_nchw( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -} // namespace native -} // namespace generic -} // namespace impl diff --git a/backends/cadence/generic/operators/quantized_conv2d_nhwc_out.cpp b/backends/cadence/generic/operators/quantized_conv2d_nhwc_out.cpp deleted file mode 100644 index eca836dcc94..00000000000 --- a/backends/cadence/generic/operators/quantized_conv2d_nhwc_out.cpp +++ /dev/null @@ -1,554 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include - -namespace impl { -namespace generic { -namespace native { - -using ::executorch::aten::IntArrayRef; -using ::executorch::aten::ScalarType; -using ::executorch::aten::Tensor; -using ::executorch::runtime::KernelRuntimeContext; - -template < - typename IT = float, - typename WT = IT, - typename BT = IT, - typename OT = IT, - bool quantized = false> -__attribute__((noinline)) void conv2d_nhwc_core_generic( - // All the arrays - const IT* __restrict__ p_in, - const WT* __restrict__ p_weight, - const BT* __restrict__ p_bias, - OT* __restrict__ p_out, - // The array sizes - int32_t n, - int32_t h, - int32_t w, - int32_t c, - int32_t oc, - int32_t wh, - int32_t ww, - int32_t wc, - int32_t oh, - int32_t ow, - // Stride - int16_t s0, - int16_t s1, - // Padding - int16_t p0, - int16_t p1, - // Dilation - int16_t d0, - int16_t d1, - // Group for depthwise conv - int16_t groups, - // Optional args that are only relevant for quantized convolution - // input zero point - IT in_zero_point = 0, - // weight zero point - int32_t weight_zero_point = 0, - float bias_scale = 1, - float out_scale = 1, - OT out_zero_point = 0) { - float inv_out_scale = 1. / out_scale; - bool zero_pad_unit_dilation = d0 == 1 && d1 == 1 && p0 == 0 && p1 == 0; - - // Compute the number of in and out channels per group - const int ocpg = oc / groups; - const int icpg = c / groups; - - // Iterate over all the output batches (i.e., n) - for (int _n = 0; _n < n; ++_n) { - const IT* in_batch = p_in + _n * h * w * c; - OT* out_batch = p_out + _n * oh * ow * oc; - for (int _h = 0, _oh = 0; _oh < oh; _h += s0, ++_oh) { - for (int _w = 0, _ow = 0; _ow < ow; _w += s1, ++_ow) { - OT* out_line = out_batch + (_oh * ow + _ow) * oc; - // Compute separable convolution for each group - for (int _g = 0; _g < groups; ++_g) { - // Identify the input and output channels involved in the computation - // of this group - int sic = _g * icpg; - int soc = _g * ocpg; - // Populate all the output channels in the group - for (int _oc = soc; _oc < soc + ocpg; ++_oc) { - const WT* weight_batch = p_weight + _oc * wh * ww * wc; - // We compute one output channel at a time. 
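The NHWC kernel differs from the NCHW one mainly in how a (n, c, h, w) coordinate maps to a flat offset, which is what moves the channel loop innermost. A small sketch of the two address computations, with the dimension names as assumptions rather than anything exported by the patch:

#include <cstddef>

// Flat offsets for the two activation layouts handled by these conv kernels.
inline size_t offset_nchw(size_t n, size_t c, size_t h, size_t w,
                          size_t C, size_t H, size_t W) {
  return ((n * C + c) * H + h) * W + w;  // width innermost
}
inline size_t offset_nhwc(size_t n, size_t c, size_t h, size_t w,
                          size_t C, size_t H, size_t W) {
  return ((n * H + h) * W + w) * C + c;  // channels innermost (contiguous)
}

Keeping channels contiguous per pixel is why the NHWC inner loop can walk in_line[_ic] and weight_line[_ic - sic] with unit stride.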
The computation can be - // thought of as a stencil computation: we iterate over an input of - // size h x w x icpg, with a stencil of size wh x ww x icpg, to - // compute an output channel of size oh x ow x 1. - float acc = p_bias[_oc]; - // Below is the stencil computation that performs the hadamard - // product+accumulation of each input channel (contributing to - // the output channel being computed) with the corresponding - // weight channel. If the padding is 0, and dilation is 1, then - // we can remove the unnecessary checks, and simplify the code - // so that it can be vectorized by Tensilica compiler.x`` - if (zero_pad_unit_dilation) { - for (int _wh = 0; _wh < wh; ++_wh) { - for (int _ww = 0; _ww < ww; ++_ww) { - const IT* in_line = - in_batch + (_h + _wh) * w * c + (_w + _ww) * c; - const WT* weight_line = - weight_batch + _wh * ww * wc + _ww * wc; - for (int _ic = sic; _ic < sic + icpg; ++_ic) { - float lhs = in_line[_ic] - in_zero_point; - float rhs = weight_line[_ic - sic] - - (quantized ? weight_zero_point : 0); - acc += lhs * rhs; - } - } - } - } else { - for (int _wh = 0; _wh < wh; ++_wh) { - for (int _ww = 0; _ww < ww; ++_ww) { - if (((_h + d0 * _wh - p0) >= 0) && - ((_h + d0 * _wh - p0) < h) && - ((_w + d1 * _ww - p1) >= 0) && - ((_w + d1 * _ww - p1 < w))) { - const IT* in_line = in_batch + - (_h + d0 * _wh - p0) * w * c + (_w + d1 * _ww - p1) * c; - const WT* weight_line = - weight_batch + _wh * ww * wc + _ww * wc; - for (int _ic = sic; _ic < sic + icpg; ++_ic) { - float lhs = in_line[_ic] - in_zero_point; - float rhs = weight_line[_ic - sic] - - (quantized ? weight_zero_point : 0); - acc += lhs * rhs; - } - } - } - } - } - if (quantized) { - float val = bias_scale * acc; - out_line[_oc] = ::impl::generic::kernels::quantize( - val, inv_out_scale, out_zero_point); - } else { - out_line[_oc] = acc; - } - } - } - } - } - } -} - -void quantized_conv2d_nhwc( - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int16_t groups, - int32_t in_zero_point, - int32_t weight_zero_point, - float bias_scale, - float output_scale, - int32_t output_zero_point, - Tensor& out) { - bool conv1d = input.dim() == 3; - // input = [n, h, w, c] - const int n = input.size(0); - const int h = conv1d ? 1 : input.size(1); - const int w = conv1d ? input.size(1) : input.size(2); - const int c = conv1d ? input.size(2) : input.size(3); - // weight = [oc, wh, ww, wc] - const int oc = weight.size(0); - const int wh = conv1d ? 1 : weight.size(1); - const int ww = conv1d ? weight.size(1) : weight.size(2); - const int wc = conv1d ? weight.size(2) : weight.size(3); - // output = [n, oh, ow, oc] - const int oh = conv1d ? 1 : out.size(1); - const int ow = conv1d ? 
out.size(1) : out.size(2); - -#define typed_quantized_conv2d_nhwc(ctype, dtype) \ - case ScalarType::dtype: { \ - conv2d_nhwc_core_generic( \ - input.const_data_ptr(), \ - weight.const_data_ptr(), \ - bias.const_data_ptr(), \ - out.mutable_data_ptr(), \ - n, \ - h, \ - w, \ - c, \ - oc, \ - wh, \ - ww, \ - wc, \ - oh, \ - ow, \ - stride[0], \ - stride[1], \ - padding[0], \ - padding[1], \ - dilation[0], \ - dilation[1], \ - groups, \ - in_zero_point, \ - weight_zero_point, \ - bias_scale, \ - output_scale, \ - (ctype)output_zero_point); \ - break; \ - } - ScalarType dtype = out.scalar_type(); - switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_conv2d_nhwc); - default: - ET_DCHECK_MSG( - false, "Unhandled dtype %s", torch::executor::toString(dtype)); - } - -#undef typed_quantized_conv2d_nhwc -} - -void quantized_conv2d_nhwc_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - const Tensor& weight_zero_point, - const Tensor& bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED const Tensor& out_multiplier, - __ET_UNUSED const Tensor& out_shift, - Tensor& out) { - const float bias_scale_float = bias_scale.const_data_ptr()[0]; - const int32_t weight_zero_point_int = - weight_zero_point.const_data_ptr()[0]; - quantized_conv2d_nhwc( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point_int, - bias_scale_float, - output_scale, - output_zero_point, - out); -} - -void quantized_conv2d_nhwc_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - bool channel_last, - Tensor& out) { - quantized_conv2d_nhwc( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv2d_nhwc( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void quantized_conv2d_nhwc_asym8uxsym8u_asym8u_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv2d_nhwc( - input, - weight, - bias, - stride, - padding, - 
dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv2d_nhwc( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv2d_nhwc( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv2d_nhwc( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void quantized_conv1d_nlc_asym8uxsym8u_asym8u_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv2d_nhwc( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv2d_nhwc( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -void 
quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& input, - const Tensor& weight, - const Tensor& bias, - IntArrayRef stride, - IntArrayRef padding, - IntArrayRef dilation, - int64_t groups, - int64_t in_zero_point, - int64_t weight_zero_point, - double bias_scale, - double output_scale, - int64_t output_zero_point, - __ET_UNUSED int64_t out_multiplier, - __ET_UNUSED int64_t out_shift, - Tensor& out) { - quantized_conv2d_nhwc( - input, - weight, - bias, - stride, - padding, - dilation, - groups, - in_zero_point, - weight_zero_point, - bias_scale, - output_scale, - output_zero_point, - out); -} - -} // namespace native -} // namespace generic -} // namespace impl diff --git a/backends/cadence/generic/operators/quantized_fully_connected_out.cpp b/backends/cadence/generic/operators/quantized_fully_connected_out.cpp deleted file mode 100644 index 5a583e3d1bd..00000000000 --- a/backends/cadence/generic/operators/quantized_fully_connected_out.cpp +++ /dev/null @@ -1,171 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ -#include -#include -#include - -namespace impl { -namespace generic { -namespace native { - -using ::executorch::aten::ScalarType; -using ::executorch::aten::Tensor; -using ::executorch::runtime::KernelRuntimeContext; -using std::optional; - -void quantized_fully_connected_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& in, - const Tensor& weight, - const Tensor& bias, - int64_t in_zero_point, - const Tensor& weight_zero_point_t, - const Tensor& out_multiplier, - const Tensor& out_shift, - int64_t out_zero_point, - __ET_UNUSED const optional& offset, - Tensor& out) { -#define typed_quantized_linear(ctype, dtype) \ - case ScalarType::dtype: { \ - quantized_linear_( \ - in, \ - weight, \ - bias, \ - in_zero_point, \ - weight_zero_point_t, \ - out_multiplier, \ - out_shift, \ - out_zero_point, \ - out); \ - break; \ - } - - ScalarType dtype = out.scalar_type(); - switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear); - default: - ET_DCHECK_MSG( - false, "Unhandled dtype %s", torch::executor::toString(dtype)); - } -#undef typed_quantized_linear -} - -void quantized_fully_connected_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& in, - const Tensor& weight, - const Tensor& bias, - int64_t in_zero_point, - int64_t weight_zero_point, - int64_t out_multiplier, - int64_t out_shift, - int64_t out_zero_point, - __ET_UNUSED const optional& offset, - Tensor& out) { -#define typed_quantized_linear(ctype, dtype) \ - case ScalarType::dtype: { \ - quantized_linear_per_tensor_( \ - in, \ - weight, \ - bias, \ - in_zero_point, \ - weight_zero_point, \ - out_multiplier, \ - out_shift, \ - out_zero_point, \ - out); \ - break; \ - } - - ScalarType dtype = out.scalar_type(); - switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear); - default: - ET_DCHECK_MSG( - false, "Unhandled dtype %s", torch::executor::toString(dtype)); - } -#undef typed_quantized_linear -} - -void quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& in, - const Tensor& weight, - const Tensor& bias, - int64_t in_zero_point, - int64_t weight_zero_point, - int64_t out_multiplier, - int64_t out_shift, - int64_t out_zero_point, - 
__ET_UNUSED const optional& offset, - Tensor& out) { -#define typed_quantized_linear(ctype, dtype) \ - case ScalarType::dtype: { \ - quantized_linear_per_tensor_( \ - in, \ - weight, \ - bias, \ - in_zero_point, \ - weight_zero_point, \ - out_multiplier, \ - out_shift, \ - out_zero_point, \ - out); \ - break; \ - } - - ScalarType dtype = out.scalar_type(); - switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear); - default: - ET_DCHECK_MSG( - false, "Unhandled dtype %s", torch::executor::toString(dtype)); - } -#undef typed_quantized_linear -} - -void quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& in, - const Tensor& weight, - const Tensor& bias, - int64_t in_zero_point, - int64_t weight_zero_point, - int64_t out_multiplier, - int64_t out_shift, - int64_t out_zero_point, - __ET_UNUSED const optional& offset, - Tensor& out) { -#define typed_quantized_linear(ctype, dtype) \ - case ScalarType::dtype: { \ - quantized_linear_per_tensor_( \ - in, \ - weight, \ - bias, \ - in_zero_point, \ - weight_zero_point, \ - out_multiplier, \ - out_shift, \ - out_zero_point, \ - out); \ - break; \ - } - - ScalarType dtype = out.scalar_type(); - switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear); - default: - ET_DCHECK_MSG( - false, "Unhandled dtype %s", torch::executor::toString(dtype)); - } -#undef typed_quantized_linear -} - -} // namespace native -} // namespace generic -} // namespace impl diff --git a/backends/cadence/generic/operators/quantized_ops.h b/backends/cadence/generic/operators/quantized_linear.h similarity index 85% rename from backends/cadence/generic/operators/quantized_ops.h rename to backends/cadence/generic/operators/quantized_linear.h index 15b2c858aed..1a7d0390dd4 100644 --- a/backends/cadence/generic/operators/quantized_ops.h +++ b/backends/cadence/generic/operators/quantized_linear.h @@ -8,10 +8,17 @@ #pragma once +#include + #include -#include +#include +#include + +namespace impl::generic::quantized { + +constexpr size_t kTensorDimensionLimit = 16; -template +template inline __attribute__((always_inline)) void quantized_linear_per_tensor_( const ::executorch::aten::Tensor& src, const ::executorch::aten::Tensor& weight, @@ -27,19 +34,17 @@ inline __attribute__((always_inline)) void quantized_linear_per_tensor_( // output comes in empty with shape [leading_dims, out_dim] // Perform matrix multiply (M x N) x (N x P)' => M x P const int64_t leading_dims = - executorch::runtime::getLeadingDims(src, src.dim() - 1); + ::executorch::runtime::getLeadingDims(src, src.dim() - 1); const int64_t out_dim = weight.size(0); // = out_dim const int64_t in_dim = weight.size(1); // = in_dim - const T* __restrict__ in_data = src.const_data_ptr(); - const T* __restrict__ weight_data = weight.const_data_ptr(); + const IT* __restrict__ in_data = src.const_data_ptr(); + const WT* __restrict__ weight_data = weight.const_data_ptr(); const int32_t* __restrict__ bias_data = bias.const_data_ptr(); - T* __restrict__ out_data = out.mutable_data_ptr(); - + IT* __restrict__ out_data = out.mutable_data_ptr(); // Compute the requant_scale from out_multiplier and out_shift const float requant_scale = -out_multiplier * 1.0 / (1 << 31) * pow(2, out_shift); - for (size_t i = 0; i < leading_dims; ++i) { for (size_t j = 0; j < out_dim; ++j) { int32_t sum = bias_data[j]; @@ -49,13 +54,13 @@ inline __attribute__((always_inline)) void quantized_linear_per_tensor_( (int32_t)weight_data[j * in_dim + k] - 
(int32_t)weight_zero_point; sum += x * w; } - out_data[i * out_dim + j] = ::impl::generic::kernels::quantize( + out_data[i * out_dim + j] = ::impl::generic::kernels::quantize( sum, requant_scale, out_zero_point); } } } -template +template inline __attribute__((always_inline)) void quantized_linear_per_tensor_( const ::executorch::aten::Tensor& src, const ::executorch::aten::Tensor& weight, @@ -68,7 +73,7 @@ inline __attribute__((always_inline)) void quantized_linear_per_tensor_( ::executorch::aten::Tensor& out) { // Get the zero_point of weight. int32_t weight_zero_point = weight_zero_point_t.const_data_ptr()[0]; - quantized_linear_per_tensor_( + quantized_linear_per_tensor_( src, weight, bias, @@ -80,7 +85,7 @@ inline __attribute__((always_inline)) void quantized_linear_per_tensor_( out); } -template +template inline __attribute__((always_inline)) void quantized_linear_per_channel_( const ::executorch::aten::Tensor& src, const ::executorch::aten::Tensor& weight, @@ -95,13 +100,13 @@ inline __attribute__((always_inline)) void quantized_linear_per_channel_( // weight comes in shape [out_dim, in_dim] // output comes in empty with shape [leading_dims, out_dim] // Perform matrix multiply (M x N) x (N x P)' => M x P - int64_t leading_dims = - executorch::runtime::getLeadingDims(src, src.dim() - 1); + const int64_t leading_dims = + ::executorch::runtime::getLeadingDims(src, src.dim() - 1); const int64_t out_dim = weight.size(0); // = out_dim const int64_t in_dim = weight.size(1); // = in_dim const T* __restrict__ in_data = src.const_data_ptr(); - const T* __restrict__ weight_data = weight.const_data_ptr(); + const WT* __restrict__ weight_data = weight.const_data_ptr(); const int32_t* __restrict__ bias_data = bias.const_data_ptr(); T* __restrict__ out_data = out.mutable_data_ptr(); const int32_t* __restrict__ out_multiplier_data = @@ -127,7 +132,7 @@ inline __attribute__((always_inline)) void quantized_linear_per_channel_( } } -template +template inline __attribute__((always_inline)) void quantized_linear_( const ::executorch::aten::Tensor& src, const ::executorch::aten::Tensor& weight, @@ -144,7 +149,7 @@ inline __attribute__((always_inline)) void quantized_linear_( out_multiplier.const_data_ptr(); const int32_t* __restrict__ out_shift_data = out_shift.const_data_ptr(); - quantized_linear_per_tensor_( + quantized_linear_per_tensor_( src, weight, bias, @@ -158,7 +163,7 @@ inline __attribute__((always_inline)) void quantized_linear_( } // Use per-channel quantization kernel. - quantized_linear_per_channel_( + quantized_linear_per_channel_( src, weight, bias, @@ -170,7 +175,7 @@ inline __attribute__((always_inline)) void quantized_linear_( out); } -template +template inline __attribute__((always_inline)) void quantized_linear_( const ::executorch::aten::Tensor& src, const ::executorch::aten::Tensor& weight, @@ -183,7 +188,7 @@ inline __attribute__((always_inline)) void quantized_linear_( ::executorch::aten::Tensor& out) { // Get the zero_point of weight. 
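Both the per-tensor linear and fully-connected paths collapse the fixed-point (out_multiplier, out_shift) pair into a single float requantization scale before converting the int32 accumulator to the output dtype. The sketch below mirrors that conversion and requantizes one accumulator, assuming the same sign convention as the code above (a negative Q31 multiplier scaled by 2^shift) and int8 output; the saturation step stands in for whatever kernels::quantize does internally.

#include <cmath>
#include <cstdint>

// Fold the Q31 multiplier and the shift into one float scale, matching the
// expression used in quantized_linear_per_tensor_.
float requant_scale(int32_t out_multiplier, int32_t out_shift) {
  return -static_cast<float>(out_multiplier) / (1u << 31) *
      std::pow(2.0f, static_cast<float>(out_shift));
}

// Requantize one int32 accumulator to int8.
int8_t requantize(int32_t acc, float scale, int32_t out_zero_point) {
  float q = std::round(static_cast<float>(acc) * scale) +
      static_cast<float>(out_zero_point);
  if (q < -128.0f) q = -128.0f;
  if (q > 127.0f) q = 127.0f;
  return static_cast<int8_t>(q);
}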
int32_t weight_zero_point = weight_zero_point_t.const_data_ptr()[0]; - quantized_linear_( + quantized_linear_( src, weight, bias, @@ -194,3 +199,5 @@ inline __attribute__((always_inline)) void quantized_linear_( out_zero_point, out); } + +} // namespace impl::generic::quantized diff --git a/backends/cadence/generic/operators/quantized_linear_out.cpp b/backends/cadence/generic/operators/quantized_linear_out.cpp deleted file mode 100644 index 7289ec19566..00000000000 --- a/backends/cadence/generic/operators/quantized_linear_out.cpp +++ /dev/null @@ -1,233 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include - -namespace impl { -namespace generic { -namespace native { - -using executorch::aten::Tensor; -using executorch::runtime::getLeadingDims; -using executorch::runtime::KernelRuntimeContext; - -template -void inline _typed_quantized_linear( - const Tensor& src, - const Tensor& weight, - const Tensor& bias, - int64_t src_zero_point, - const Tensor& weight_zero_point_t, - const Tensor& out_multiplier, - const Tensor& out_shift, - int64_t out_zero_point, - Tensor& out) { - const T* __restrict__ src_data = src.const_data_ptr(); - const T* __restrict__ weight_data = weight.const_data_ptr(); - const int32_t* __restrict__ bias_data = bias.const_data_ptr(); - T* __restrict__ out_data = out.mutable_data_ptr(); - - int32_t weight_zero_point = weight_zero_point_t.const_data_ptr()[0]; - - // input comes in shape [batch_size, in_dim] - // weight comes in shape [out_dim, in_dim] - // output comes in empty with shape [batch_size, out_dim] - // Perform matrix multiply (M x N) x (N x P) => M x P - const auto M = weight.size(0); // = out_dim - const auto N = weight.size(1); // = in_dim - - // Given an N-dimensional input [d0, d1, d2, ..., d_{N-2}, d_{N-1}], the - // leading dimensions is d0 * d1 * ... 
* d_{N-2} - const auto leading_dims = getLeadingDims(src, src.dim() - 1); - - ET_CHECK_MSG( - out_multiplier.numel() == 1, "out_multiplier should have one element"); - ET_CHECK_MSG( - out_shift.numel() == 1, "out_multiplier should have one element"); - - const int32_t* __restrict__ out_multiplier_data = - out_multiplier.const_data_ptr(); - const int32_t* __restrict__ out_shift_data = - out_shift.const_data_ptr(); - - // Compute the out_scale from out_multiplier and out_shift - const float out_scale = - -out_multiplier_data[0] * 1.0 / (1 << 31) * pow(2, out_shift_data[0]); - - for (int i = 0; i < leading_dims; ++i) { - for (int j = 0; j < M; ++j) { - float sum = bias_data[j]; - for (int k = 0; k < N; ++k) { - sum += (src_data[i * N + k] - src_zero_point) * - (weight_data[j * N + k] - weight_zero_point); - } - out_data[i * M + j] = - kernels::quantize(sum, out_scale, out_zero_point); - } - } -} - -void quantized_linear_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& src, - const Tensor& weight, - const Tensor& bias, - int64_t src_zero_point, - const Tensor& weight_zero_point_t, - const Tensor& out_multiplier, - const Tensor& out_shift, - int64_t out_zero_point, - __ET_UNUSED const std::optional& offset, - Tensor& out) { - // TODO: refactor to use switch case as quantized_linear_per_tensor_out - if (out.scalar_type() == executorch::aten::ScalarType::Byte) { - _typed_quantized_linear( - src, - weight, - bias, - src_zero_point, - weight_zero_point_t, - out_multiplier, - out_shift, - out_zero_point, - out); - } else if (out.scalar_type() == executorch::aten::ScalarType::Char) { - _typed_quantized_linear( - src, - weight, - bias, - src_zero_point, - weight_zero_point_t, - out_multiplier, - out_shift, - out_zero_point, - out); - } else { - ET_CHECK_MSG( - false, - "Unhandled input dtype %hhd", - static_cast(src.scalar_type())); - } -} - -void quantized_linear_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& src, - const Tensor& weight, - const Tensor& bias, - const int64_t src_zero_point, - const int64_t weight_zero_point, - const int64_t out_multiplier, - const int64_t out_shift, - const int64_t out_zero_point, - __ET_UNUSED const std::optional& offset, - Tensor& out) { -#define typed_quantized_linear_per_tensor(ctype, dtype) \ - case executorch::aten::ScalarType::dtype: { \ - quantized_linear_per_tensor_( \ - src, \ - weight, \ - bias, \ - src_zero_point, \ - weight_zero_point, \ - out_multiplier, \ - out_shift, \ - out_zero_point, \ - out); \ - break; \ - } - - executorch::aten::ScalarType dtype = out.scalar_type(); - switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear_per_tensor); - default: - ET_DCHECK_MSG( - false, "Unhandled dtype %s", executorch::runtime::toString(dtype)); - } -#undef typed_quantized_linear_per_tensor -} - -void quantized_linear_asym8sxasym8s_asym8s_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& src, - const Tensor& weight, - const Tensor& bias, - const int64_t src_zero_point, - const int64_t weight_zero_point, - const int64_t out_multiplier, - const int64_t out_shift, - const int64_t out_zero_point, - __ET_UNUSED const std::optional& offset, - Tensor& out) { -#define typed_quantized_linear_per_tensor(ctype, dtype) \ - case executorch::aten::ScalarType::dtype: { \ - quantized_linear_per_tensor_( \ - src, \ - weight, \ - bias, \ - src_zero_point, \ - weight_zero_point, \ - out_multiplier, \ - out_shift, \ - out_zero_point, \ - out); \ - break; \ - } - - executorch::aten::ScalarType dtype 
= out.scalar_type(); - switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear_per_tensor); - default: - ET_DCHECK_MSG( - false, "Unhandled dtype %s", executorch::runtime::toString(dtype)); - } -#undef typed_quantized_linear_per_tensor -} - -void quantized_linear_asym8uxasym8u_asym8u_per_tensor_out( - __ET_UNUSED KernelRuntimeContext& ctx, - const Tensor& src, - const Tensor& weight, - const Tensor& bias, - const int64_t src_zero_point, - const int64_t weight_zero_point, - const int64_t out_multiplier, - const int64_t out_shift, - const int64_t out_zero_point, - __ET_UNUSED const std::optional& offset, - Tensor& out) { -#define typed_quantized_linear_per_tensor(ctype, dtype) \ - case executorch::aten::ScalarType::dtype: { \ - quantized_linear_per_tensor_( \ - src, \ - weight, \ - bias, \ - src_zero_point, \ - weight_zero_point, \ - out_multiplier, \ - out_shift, \ - out_zero_point, \ - out); \ - break; \ - } - - executorch::aten::ScalarType dtype = out.scalar_type(); - switch (dtype) { - ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear_per_tensor); - default: - ET_DCHECK_MSG( - false, "Unhandled dtype %s", executorch::runtime::toString(dtype)); - } -#undef typed_quantized_linear_per_tensor -} - -} // namespace native -} // namespace generic -} // namespace impl diff --git a/backends/cadence/generic/operators/quantized_matmul_out.cpp b/backends/cadence/generic/operators/quantized_matmul_out.cpp deleted file mode 100644 index e983235fc9f..00000000000 --- a/backends/cadence/generic/operators/quantized_matmul_out.cpp +++ /dev/null @@ -1,207 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include - -namespace impl { -namespace generic { -namespace native { - -using executorch::aten::Tensor; -using executorch::runtime::getLeadingDims; -using executorch::runtime::KernelRuntimeContext; - -// The quantized matmul. The quantized matmul accumulates in a wider register, -// whose type is TA. 
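As the comment above notes, the matmul accumulates in a wider register type TA: a single int8 x int8 product already needs about 16 bits, and summing n of them needs roughly 16 + log2(n) bits, so a narrow accumulator would overflow on realistic inner dimensions. Here is a sketch of one zero-point-corrected dot product accumulated in int32; the generic kernel uses float for TA, which trades exactness for simplicity.

#include <cstddef>
#include <cstdint>

// One output element of a quantized matmul: zero-point-corrected int8
// products accumulated in a 32-bit register before requantization.
int32_t quantized_dot(const int8_t* x, int32_t x_zero_point,
                      const int8_t* y, int32_t y_zero_point,
                      size_t n) {
  int32_t acc = 0;
  for (size_t k = 0; k < n; ++k) {
    // Each product can reach roughly 2^16 in magnitude, so n of them need
    // about 16 + log2(n) bits, too wide for an 8- or 16-bit accumulator.
    acc += (static_cast<int32_t>(x[k]) - x_zero_point) *
           (static_cast<int32_t>(y[k]) - y_zero_point);
  }
  return acc;  // requantized afterwards using (Z_multiplier, Z_shift, Z_zero_point)
}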
-template < - typename TZ, - typename TA = float, - bool transposed = false, - typename TX = TZ, - typename TY = TZ> -__attribute__((noinline)) void qmatmul( - TZ* __restrict__ Z, - int32_t Z_multiplier, - int32_t Z_shift, - int32_t Z_zero_point, - const TX* __restrict__ X, - int32_t X_zero_point, - const TY* __restrict__ y, - int32_t Y_zero_point, - size_t m, - size_t n, - size_t p) { - // Compute the Z_scale from Z_multiplier and Z_shift - const float Z_scale = -Z_multiplier * 1.0 / (1 << 31) * pow(2, Z_shift); - for (size_t i = 0; i < m; ++i) { - for (size_t j = 0; j < p; ++j) { - TA sum = 0; - for (size_t k = 0; k < n; ++k) { - if (transposed) { - sum += (X[i * n + k] - X_zero_point) * (y[j * n + k] - Y_zero_point); - } else { - sum += (X[i * n + k] - X_zero_point) * (y[k * p + j] - Y_zero_point); - } - } - Z[i * p + j] = kernels::quantize(sum, Z_scale, Z_zero_point); - } - } -} - -template -void inline _typed_quantized_matmul( - const Tensor& X, - int64_t X_zero_point, - const Tensor& Y, - int64_t Y_zero_point, - const std::optional& bias, - int64_t out_multiplier, - int64_t out_shift, - int64_t out_zero_point, - bool transposed, - Tensor& out) { - size_t batch_size = getLeadingDims(X, X.dim() - 2); - size_t leading_dim = X.size(X.dim() - 2); - size_t out_dim = Y.size(Y.dim() - 1 - transposed); - size_t in_dim = X.size(X.dim() - 1); - - T* __restrict__ out_data = out.mutable_data_ptr(); - const T* __restrict__ X_data = X.const_data_ptr(); - const T* __restrict__ Y_data = Y.const_data_ptr(); - for (size_t i = 0; i < batch_size; ++i) { - const T* x = X_data + i * leading_dim * in_dim; - const T* y = Y_data + i * in_dim * out_dim; - T* z = out_data + i * leading_dim * out_dim; - if (transposed) { - qmatmul( - z, - static_cast(out_multiplier), - static_cast(out_shift), - static_cast(out_zero_point), - x, - static_cast(X_zero_point), - y, - static_cast(Y_zero_point), - leading_dim, - in_dim, - out_dim); - } else { - qmatmul( - z, - static_cast(out_multiplier), - static_cast(out_shift), - static_cast(out_zero_point), - x, - static_cast(X_zero_point), - y, - static_cast(Y_zero_point), - leading_dim, - in_dim, - out_dim); - } - } -} - -void quantized_matmul_out( - KernelRuntimeContext& ctx, - const Tensor& X, - int64_t X_zero_point, - const Tensor& Y, - int64_t Y_zero_point, - const std::optional& bias, - int64_t out_multiplier, - int64_t out_shift, - int64_t out_zero_point, - bool transposed, - Tensor& out) { - if (out.scalar_type() == executorch::aten::ScalarType::Byte) { - _typed_quantized_matmul( - X, - X_zero_point, - Y, - Y_zero_point, - bias, - out_multiplier, - out_shift, - out_zero_point, - transposed, - out); - } else if (out.scalar_type() == executorch::aten::ScalarType::Char) { - _typed_quantized_matmul( - X, - X_zero_point, - Y, - Y_zero_point, - bias, - out_multiplier, - out_shift, - out_zero_point, - transposed, - out); - } else { - ET_CHECK_MSG( - false, - "Unhandled input dtype %hhd", - static_cast(X.scalar_type())); - } -} - -void quantized_matmul_asym8sxasym8s_asym8s_out( - KernelRuntimeContext& ctx, - const Tensor& X, - int64_t X_zero_point, - const Tensor& Y, - int64_t Y_zero_point, - const std::optional& bias, - int64_t out_multiplier, - int64_t out_shift, - int64_t out_zero_point, - bool transposed, - Tensor& out) { - _typed_quantized_matmul( - X, - X_zero_point, - Y, - Y_zero_point, - bias, - out_multiplier, - out_shift, - out_zero_point, - transposed, - out); -} - -void quantized_matmul_asym8uxasym8u_asym8u_out( - KernelRuntimeContext& ctx, - const Tensor& X, - 
int64_t X_zero_point, - const Tensor& Y, - int64_t Y_zero_point, - const std::optional& bias, - int64_t out_multiplier, - int64_t out_shift, - int64_t out_zero_point, - bool transposed, - Tensor& out) { - _typed_quantized_matmul( - X, - X_zero_point, - Y, - Y_zero_point, - bias, - out_multiplier, - out_shift, - out_zero_point, - transposed, - out); -} - -} // namespace native -} // namespace generic -} // namespace impl diff --git a/backends/cadence/generic/operators/quantized_op_macros.h b/backends/cadence/generic/operators/quantized_op_macros.h new file mode 100644 index 00000000000..eda6de2e8d7 --- /dev/null +++ b/backends/cadence/generic/operators/quantized_op_macros.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +// Generate kernels that perform elementwise arithmetic on two quantized +// tensors. The tensors are either the same size, or the second tensor is a +// scalar. +#define DECLARE_POINTWISE_TENSOR_QUANTIZED_BINARY_OP(BINARY_FUNC_NAME, OP) \ + template \ + void BINARY_FUNC_NAME( \ + const ::executorch::aten::Tensor& X, \ + float X_scale, \ + int32_t X_zero_point, \ + const ::executorch::aten::Tensor& Y, \ + float Y_scale, \ + int32_t Y_zero_point, \ + float out_scale, \ + int32_t out_zero_point, \ + ::executorch::aten::Tensor& out) { \ + const T* __restrict__ X_data = X.const_data_ptr(); \ + const T* __restrict__ Y_data = Y.const_data_ptr(); \ + T* __restrict__ out_data = out.mutable_data_ptr(); \ + float inv_out_scale = 1.0f / out_scale; \ + for (size_t i = 0, e = X.numel(); i < e; ++i) { \ + float x = ::impl::generic::kernels::dequantize( \ + X_data[i], X_scale, X_zero_point); \ + float y = ::impl::generic::kernels::dequantize( \ + Y_data[i], Y_scale, Y_zero_point); \ + float z = x OP y; \ + out_data[i] = ::impl::generic::kernels::quantize( \ + z, inv_out_scale, out_zero_point); \ + } \ + } diff --git a/backends/cadence/generic/operators/quantized_relu_out.cpp b/backends/cadence/generic/operators/quantized_relu_out.cpp deleted file mode 100644 index 622fd901084..00000000000 --- a/backends/cadence/generic/operators/quantized_relu_out.cpp +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include -#include - -namespace impl { -namespace generic { -namespace native { - -using executorch::aten::Tensor; -using executorch::runtime::KernelRuntimeContext; - -template -void quantized_relu_( - const Tensor& input, - const Tensor& in_zero_point, - const int64_t out_zero_point, - const Tensor& out_multiplier, - const Tensor& out_shift, - Tensor& output) { - T q_zero_point = in_zero_point.const_data_ptr()[0]; - const T* __restrict__ in = input.const_data_ptr(); - T* __restrict__ out = output.mutable_data_ptr(); - - const int32_t* __restrict__ out_multiplier_data = - out_multiplier.const_data_ptr(); - const int32_t* __restrict__ out_shift_data = - out_shift.const_data_ptr(); - - // Compute the out_scale from out_multiplier and out_shift - const float out_scale = - -out_multiplier_data[0] * 1.0 / (1 << 31) * pow(2, out_shift_data[0]); - - for (size_t i = 0, e = input.numel(); i < e; ++i) { - const T temp = in[i] > q_zero_point ? 
diff --git a/backends/cadence/generic/operators/quantized_relu_out.cpp b/backends/cadence/generic/operators/quantized_relu_out.cpp
deleted file mode 100644
index 622fd901084..00000000000
--- a/backends/cadence/generic/operators/quantized_relu_out.cpp
+++ /dev/null
@@ -1,198 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include
-#include
-#include
-
-namespace impl {
-namespace generic {
-namespace native {
-
-using executorch::aten::Tensor;
-using executorch::runtime::KernelRuntimeContext;
-
-template <typename T>
-void quantized_relu_(
-    const Tensor& input,
-    const Tensor& in_zero_point,
-    const int64_t out_zero_point,
-    const Tensor& out_multiplier,
-    const Tensor& out_shift,
-    Tensor& output) {
-  T q_zero_point = in_zero_point.const_data_ptr<T>()[0];
-  const T* __restrict__ in = input.const_data_ptr<T>();
-  T* __restrict__ out = output.mutable_data_ptr<T>();
-
-  const int32_t* __restrict__ out_multiplier_data =
-      out_multiplier.const_data_ptr<int32_t>();
-  const int32_t* __restrict__ out_shift_data =
-      out_shift.const_data_ptr<int32_t>();
-
-  // Compute the out_scale from out_multiplier and out_shift
-  const float out_scale =
-      -out_multiplier_data[0] * 1.0 / (1 << 31) * pow(2, out_shift_data[0]);
-
-  for (size_t i = 0, e = input.numel(); i < e; ++i) {
-    const T temp = in[i] > q_zero_point ? (in[i] - q_zero_point) : 0;
-    out[i] = kernels::quantize<T>(temp, out_scale, out_zero_point);
-  }
-}
-
-void quantized_relu_out(
-    KernelRuntimeContext& ctx,
-    const Tensor& input,
-    const Tensor& in_zero_point,
-    const int64_t out_zero_point,
-    const Tensor& out_multiplier,
-    const Tensor& out_shift,
-    Tensor& output) {
-  if (input.scalar_type() == executorch::aten::ScalarType::Byte) {
-    quantized_relu_<uint8_t>(
-        input,
-        in_zero_point,
-        out_zero_point,
-        out_multiplier,
-        out_shift,
-        output);
-  } else if (input.scalar_type() == executorch::aten::ScalarType::Char) {
-    quantized_relu_<int8_t>(
-        input,
-        in_zero_point,
-        out_zero_point,
-        out_multiplier,
-        out_shift,
-        output);
-  } else {
-    ET_CHECK_MSG(
-        false,
-        "Unhandled input dtype %hhd",
-        static_cast<int8_t>(input.scalar_type()));
-  }
-}
-
-template <typename T>
-void quantized_relu_per_tensor_out_(
-    __ET_UNUSED KernelRuntimeContext& ctx,
-    const Tensor& input,
-    const int64_t in_zero_point,
-    const int64_t out_zero_point,
-    const int64_t out_multiplier,
-    const int64_t out_shift,
-    Tensor& output) {
-  const T* __restrict__ in = input.const_data_ptr<T>();
-  T* __restrict__ out = output.mutable_data_ptr<T>();
-
-  // Compute the out_scale from out_multiplier and out_shift
-  const float out_scale = -out_multiplier * 1.0 / (1 << 31) * pow(2, out_shift);
-
-  for (size_t i = 0, e = input.numel(); i < e; ++i) {
-    const float temp = in[i] > in_zero_point ? (in[i] - in_zero_point) : 0;
-    out[i] = kernels::quantize<T>(temp, out_scale, out_zero_point);
-  }
-}
-
-void quantized_relu_per_tensor_out(
-    KernelRuntimeContext& ctx,
-    const Tensor& input,
-    const int64_t in_zero_point,
-    const int64_t out_zero_point,
-    const int64_t out_multiplier,
-    const int64_t out_shift,
-    Tensor& output) {
-#define typed_quantized_relu(ctype, dtype)     \
-  case executorch::aten::ScalarType::dtype: {  \
-    quantized_relu_per_tensor_out_<ctype>(     \
-        ctx,                                   \
-        input,                                 \
-        in_zero_point,                         \
-        out_zero_point,                        \
-        out_multiplier,                        \
-        out_shift,                             \
-        output);                               \
-    break;                                     \
-  }
-
-  executorch::aten::ScalarType dtype = input.scalar_type();
-  switch (dtype) {
-    ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_relu)
-    default:
-      ET_DCHECK_MSG(
-          false, "Unhandled dtype %s", torch::executor::toString(dtype));
-  }
-
-#undef typed_quantized_relu
-}
-
-void quantized_relu_asym8s_asym8s_per_tensor_out(
-    KernelRuntimeContext& ctx,
-    const Tensor& input,
-    const int64_t in_zero_point,
-    const int64_t out_zero_point,
-    const int64_t out_multiplier,
-    const int64_t out_shift,
-    Tensor& output) {
-#define typed_quantized_relu(ctype, dtype)     \
-  case executorch::aten::ScalarType::dtype: {  \
-    quantized_relu_per_tensor_out_<ctype>(     \
-        ctx,                                   \
-        input,                                 \
-        in_zero_point,                         \
-        out_zero_point,                        \
-        out_multiplier,                        \
-        out_shift,                             \
-        output);                               \
-    break;                                     \
-  }
-
-  executorch::aten::ScalarType dtype = input.scalar_type();
-  switch (dtype) {
-    ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_relu)
-    default:
-      ET_DCHECK_MSG(
-          false, "Unhandled dtype %s", torch::executor::toString(dtype));
-  }
-
-#undef typed_quantized_relu
-}
-
-void quantized_relu_asym8u_asym8u_per_tensor_out(
-    KernelRuntimeContext& ctx,
-    const Tensor& input,
-    const int64_t in_zero_point,
-    const int64_t out_zero_point,
-    const int64_t out_multiplier,
-    const int64_t out_shift,
-    Tensor& output) {
-#define typed_quantized_relu(ctype, dtype)     \
-  case executorch::aten::ScalarType::dtype: {  \
-    quantized_relu_per_tensor_out_<ctype>(     \
-        ctx,                                   \
-        input,                                 \
-        in_zero_point,                         \
-        out_zero_point,                        \
-        out_multiplier,                        \
-        out_shift,                             \
-        output);                               \
-    break;                                     \
-  }
-
-  executorch::aten::ScalarType dtype = input.scalar_type();
-  switch (dtype) {
-    ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_relu)
-    default:
-      ET_DCHECK_MSG(
-          false, "Unhandled dtype %s", torch::executor::toString(dtype));
-  }
-
-#undef typed_quantized_relu
-}
-
-} // namespace native
-} // namespace generic
-} // namespace impl
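The numerics being moved here are small: clamp against the input zero point, then requantize with the multiplier/shift-derived scale. A standalone sketch, assuming (as the inv_out_scale usage in quantized_op_macros.h suggests) that kernels::quantize applies its scale argument multiplicatively; the values are made up, and op_quantized_relu.cpp itself is not shown in this patch.

```cpp
// Standalone illustration of the per-tensor quantized ReLU arithmetic deleted
// above. The multiplier/shift pair encodes an identity requantization here.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int32_t in_zero_point = -4;
  const int32_t out_zero_point = -4;
  // multiplier = 2^30, shift = 1 encode a requantization factor of 1.0.
  const float out_scale = (1 << 30) / 2147483648.0f * std::pow(2.0f, 1);

  const std::vector<int8_t> in = {-20, -4, 0, 37, 120};
  for (int8_t v : in) {
    // Clamp against the input zero point, as in quantized_relu_per_tensor_out_.
    const float temp = v > in_zero_point ? (v - in_zero_point) : 0;
    // Requantize: round, add the output zero point, clamp to the int8 range.
    float q = std::round(temp * out_scale) + out_zero_point;
    q = std::min(127.0f, std::max(-128.0f, q));
    std::printf("%4d -> %4d\n", v, static_cast<int8_t>(q));
  }
  return 0;
}
```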
diff --git a/backends/cadence/generic/operators/targets.bzl b/backends/cadence/generic/operators/targets.bzl
index 20c4bbd44ea..6990a24db50 100644
--- a/backends/cadence/generic/operators/targets.bzl
+++ b/backends/cadence/generic/operators/targets.bzl
@@ -4,13 +4,38 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 def define_common_targets():
     # Individual operator targets with optimized dependencies
+    # Type utilities for Cadence quantized operators
     runtime.cxx_library(
-        name = "im2row_out",
-        srcs = ["op_im2row.cpp"],
-        exported_headers = ["op_im2row.h"],
+        name = "cadence_type_util",
+        exported_headers = ["cadence_type_util.h"],
+    )
+
+    runtime.cxx_library(
+        name = "quantized_op_macros",
+        exported_headers = ["quantized_op_macros.h"],
+        exported_deps = [
+            ":cadence_type_util",
+            "//executorch/runtime/kernel:kernel_includes",
+        ]
+    )
+
+    runtime.cxx_library(
+        name = "quantized_linear",
+        exported_headers = ["quantized_linear.h"],
+        exported_deps = [
+            "//executorch/runtime/kernel:kernel_includes",
+            "//executorch/backends/cadence/generic/kernels:cadence_kernels",
+        ]
+    )
+
+    runtime.cxx_library(
+        name = "op_dequantize_per_tensor",
+        srcs = ["op_dequantize_per_tensor.cpp"],
+        exported_headers = ["op_dequantize_per_tensor.h"],
         platforms = CXX,
         deps = [
             "//executorch/runtime/kernel:kernel_includes",
+            "//executorch/backends/cadence/generic/kernels:cadence_kernels",
         ],
         visibility = [
             "//executorch/backends/cadence/...",
@@ -19,9 +44,9 @@ def define_common_targets():
     )
 
     runtime.cxx_library(
-        name = "op_requantize",
-        srcs = ["op_requantize.cpp"],
-        exported_headers = ["op_requantize.h"],
+        name = "op_quantize_per_tensor",
+        srcs = ["op_quantize_per_tensor.cpp"],
+        exported_headers = ["op_quantize_per_tensor.h"],
         platforms = CXX,
         deps = [
             "//executorch/runtime/kernel:kernel_includes",
@@ -34,13 +59,14 @@ def define_common_targets():
     )
 
     runtime.cxx_library(
-        name = "dequantize_per_tensor",
-        srcs = ["op_dequantize_per_tensor.cpp"],
-        exported_headers = ["op_dequantize_per_tensor.h"],
+        name = "op_where_scalar",
+        srcs = ["op_where_scalar.cpp"],
+        exported_headers = ["op_where_scalar.h"],
         platforms = CXX,
         deps = [
             "//executorch/runtime/kernel:kernel_includes",
-            "//executorch/backends/cadence/generic/kernels:cadence_kernels",
+            "//executorch/runtime/core/exec_aten:lib",
+            "//executorch/runtime/kernel:kernel_runtime_context",
         ],
         visibility = [
             "//executorch/backends/cadence/...",
@@ -49,13 +75,14 @@ def define_common_targets():
     )
 
     runtime.cxx_library(
-        name = "quantize_per_tensor",
-        srcs = ["op_quantize_per_tensor.cpp"],
-        exported_headers = ["op_quantize_per_tensor.h"],
+        name = "op_rope",
+        srcs = ["op_rope.cpp"],
+        exported_headers = ["op_rope.h"],
         platforms = CXX,
         deps = [
             "//executorch/runtime/kernel:kernel_includes",
-            "//executorch/backends/cadence/generic/kernels:cadence_kernels",
+            "//executorch/runtime/core/exec_aten:lib",
+            "//executorch/runtime/kernel:kernel_runtime_context",
         ],
         visibility = [
             "//executorch/backends/cadence/...",
@@ -64,13 +91,15 @@ def define_common_targets():
     )
 
     runtime.cxx_library(
-        name = "quantized_add_out",
-        srcs = ["quantized_add_out.cpp"],
-        exported_headers = ["operators.h", "quantized_ops.h"],
+        name = "op_linalg_svd",
+        srcs = ["op_linalg_svd.cpp"],
+        headers = ["op_linalg_svd.h"],
         platforms = CXX,
         deps = [
             "//executorch/runtime/kernel:kernel_includes",
-            "//executorch/backends/cadence/generic/kernels:cadence_kernels",
+            "//executorch/runtime/core/exec_aten:lib",
+            "//executorch/runtime/core/exec_aten/util:tensor_util",
+            "//executorch/runtime/kernel:kernel_runtime_context",
         ],
         visibility = [
             "//executorch/backends/cadence/...",
@@ -79,13 +108,14 @@ def define_common_targets():
     )
 
     runtime.cxx_library(
-        name = "quantized_conv2d_nchw_out",
-        srcs = ["quantized_conv2d_nchw_out.cpp"],
-        exported_headers = ["operators.h", "quantized_ops.h"],
+        name = "op_roi_align_box_processor",
+        srcs = ["op_roi_align_box_processor.cpp"],
+        exported_headers = ["op_roi_align_box_processor.h"],
         platforms = CXX,
         deps = [
             "//executorch/runtime/kernel:kernel_includes",
-            "//executorch/backends/cadence/generic/kernels:cadence_kernels",
+            "//executorch/runtime/core/exec_aten:lib",
+            "//executorch/runtime/kernel:kernel_runtime_context",
         ],
         visibility = [
             "//executorch/backends/cadence/...",
@@ -94,13 +124,15 @@ def define_common_targets():
     )
 
     runtime.cxx_library(
-        name = "quantized_conv2d_nhwc_out",
-        srcs = ["quantized_conv2d_nhwc_out.cpp"],
-        exported_headers = ["operators.h", "quantized_ops.h"],
+        name = "op_quantized_add",
+        srcs = ["op_quantized_add.cpp"],
+        exported_headers = ["op_quantized_add.h"],
         platforms = CXX,
         deps = [
-            "//executorch/runtime/kernel:kernel_includes",
             "//executorch/backends/cadence/generic/kernels:cadence_kernels",
+            "//executorch/kernels/portable/cpu:scalar_utils",
+            "//executorch/runtime/kernel:kernel_includes",
+            ":quantized_op_macros",
         ],
         visibility = [
             "//executorch/backends/cadence/...",
@@ -109,13 +141,14 @@ def define_common_targets():
     )
 
     runtime.cxx_library(
-        name = "quantized_fully_connected_out",
-        srcs = ["quantized_fully_connected_out.cpp"],
-        exported_headers = ["operators.h", "quantized_ops.h"],
+        name = "op_quantized_conv1d",
+        srcs = ["op_quantized_conv1d.cpp"],
+        exported_headers = ["op_quantized_conv1d.h"],
        platforms = CXX,
         deps = [
-            "//executorch/runtime/kernel:kernel_includes",
+            ":cadence_type_util",
             "//executorch/backends/cadence/generic/kernels:cadence_kernels",
+            "//executorch/runtime/kernel:kernel_includes",
         ],
         visibility = [
             "//executorch/backends/cadence/...",
@@ -124,13 +157,14 @@ def define_common_targets():
     )
 
     runtime.cxx_library(
-        name = "quantized_layer_norm",
-        srcs = ["quantized_layer_norm.cpp"],
-        exported_headers = ["operators.h", "quantized_ops.h"],
+        name = "op_quantized_conv2d",
+        srcs = ["op_quantized_conv2d.cpp"],
+        exported_headers = ["op_quantized_conv2d.h"],
         platforms = CXX,
         deps = [
-            "//executorch/runtime/kernel:kernel_includes",
+            ":cadence_type_util",
             "//executorch/backends/cadence/generic/kernels:cadence_kernels",
+            "//executorch/runtime/kernel:kernel_includes",
         ],
         visibility = [
             "//executorch/backends/cadence/...",
@@ -139,13 +173,15 @@ def define_common_targets():
     )
 
     runtime.cxx_library(
-        name = "quantized_linear_out",
-        srcs = ["quantized_linear_out.cpp"],
-        exported_headers = ["operators.h", "quantized_ops.h"],
+        name = "op_quantized_fully_connected",
+        srcs = ["op_quantized_fully_connected.cpp"],
+        exported_headers = ["op_quantized_fully_connected.h"],
         platforms = CXX,
         deps = [
-            "//executorch/runtime/kernel:kernel_includes",
             "//executorch/backends/cadence/generic/kernels:cadence_kernels",
+            "//executorch/runtime/kernel:kernel_includes",
+            ":quantized_linear",
+            ":quantized_op_macros",
         ],
         visibility = [
             "//executorch/backends/cadence/...",
@@ -154,13 +190,31 @@ def define_common_targets():
     )
 
     runtime.cxx_library(
-        name = "quantized_matmul_out",
-        srcs = ["quantized_matmul_out.cpp"],
-        exported_headers = ["operators.h", "quantized_ops.h"],
+        name = "op_quantized_layer_norm",
+        srcs = ["op_quantized_layer_norm.cpp"],
+        exported_headers = ["op_quantized_layer_norm.h"],
         platforms = CXX,
         deps = [
+            "//executorch/backends/cadence/generic/kernels:cadence_kernels",
             "//executorch/runtime/kernel:kernel_includes",
+            ":quantized_op_macros",
+        ],
+        visibility = [
+            "//executorch/backends/cadence/...",
+            "@EXECUTORCH_CLIENTS",
+        ],
+    )
+
+    runtime.cxx_library(
+        name = "op_quantized_linear",
+        srcs = ["op_quantized_linear.cpp"],
+        exported_headers = ["op_quantized_linear.h"],
+        platforms = CXX,
+        deps = [
             "//executorch/backends/cadence/generic/kernels:cadence_kernels",
+            "//executorch/runtime/kernel:kernel_includes",
+            ":quantized_op_macros",
+            ":quantized_linear",
         ],
         visibility = [
             "//executorch/backends/cadence/...",
@@ -169,13 +223,30 @@ def define_common_targets():
     )
 
     runtime.cxx_library(
-        name = "quantized_relu_out",
-        srcs = ["quantized_relu_out.cpp"],
-        exported_headers = ["operators.h", "quantized_ops.h"],
+        name = "op_quantized_relu",
+        srcs = ["op_quantized_relu.cpp"],
+        exported_headers = ["op_quantized_relu.h"],
         platforms = CXX,
         deps = [
+            "//executorch/backends/cadence/generic/kernels:cadence_kernels",
             "//executorch/runtime/kernel:kernel_includes",
+            ":quantized_op_macros",
+        ],
+        visibility = [
+            "//executorch/backends/cadence/...",
+            "@EXECUTORCH_CLIENTS",
+        ],
+    )
+
+    runtime.cxx_library(
+        name = "op_quantized_matmul",
+        srcs = ["op_quantized_matmul.cpp"],
+        exported_headers = ["op_quantized_matmul.h"],
+        platforms = CXX,
+        deps = [
            "//executorch/backends/cadence/generic/kernels:cadence_kernels",
+            "//executorch/runtime/kernel:kernel_includes",
+            ":quantized_op_macros",
         ],
         visibility = [
             "//executorch/backends/cadence/...",
@@ -184,14 +255,15 @@ def define_common_targets():
     )
 
     runtime.cxx_library(
-        name = "op_where_scalar",
-        srcs = ["op_where_scalar.cpp"],
-        exported_headers = ["op_where_scalar.h", "operators.h"],
+        name = "op_quantized_mul",
+        srcs = ["op_quantized_mul.cpp"],
+        exported_headers = ["op_quantized_mul.h"],
         platforms = CXX,
         deps = [
+            "//executorch/backends/cadence/generic/kernels:cadence_kernels",
             "//executorch/runtime/kernel:kernel_includes",
-            "//executorch/runtime/core/exec_aten:lib",
-            "//executorch/runtime/kernel:kernel_runtime_context",
+            "//executorch/kernels/portable/cpu:scalar_utils",
+            ":quantized_op_macros",
         ],
         visibility = [
             "//executorch/backends/cadence/...",
@@ -200,14 +272,16 @@ def define_common_targets():
     )
 
     runtime.cxx_library(
-        name = "op_rope",
-        srcs = ["op_rope.cpp"],
-        exported_headers = ["op_rope.h", "operators.h"],
+        name = "op_quantized_softmax",
+        srcs = ["op_quantized_softmax.cpp"],
+        exported_headers = ["op_quantized_softmax.h"],
         platforms = CXX,
         deps = [
+            "//executorch/backends/cadence/generic/kernels:cadence_kernels",
+            "//executorch/kernels/portable/cpu/util:reduce_util",
+            "//executorch/kernels/portable/cpu/util:functional_util",
             "//executorch/runtime/kernel:kernel_includes",
-            "//executorch/runtime/core/exec_aten:lib",
-            "//executorch/runtime/kernel:kernel_runtime_context",
+            ":quantized_op_macros",
         ],
         visibility = [
             "//executorch/backends/cadence/...",
@@ -216,15 +290,13 @@ def define_common_targets():
     )
 
     runtime.cxx_library(
-        name = "op_linalg_svd",
-        srcs = ["op_linalg_svd.cpp"],
-        headers = ["op_linalg_svd.h"],
+        name = "op_quantized_embedding_byte",
+        srcs = ["op_quantized_embedding_byte.cpp"],
+        exported_headers = ["op_quantized_embedding_byte.h"],
         platforms = CXX,
         deps = [
             "//executorch/runtime/kernel:kernel_includes",
-            "//executorch/runtime/core/exec_aten:lib",
-            "//executorch/runtime/core/exec_aten/util:tensor_util",
-            "//executorch/runtime/kernel:kernel_runtime_context",
+            ":quantized_op_macros",
         ],
         visibility = [
             "//executorch/backends/cadence/...",
@@ -233,14 +305,14 @@ def define_common_targets():
     )
 
     runtime.cxx_library(
-        name = "op_roi_align_box_processor",
-        srcs = ["op_roi_align_box_processor.cpp"],
-        exported_headers = ["op_roi_align_box_processor.h", "operators.h"],
+        name = "op_requantize",
+        srcs = ["op_requantize.cpp"],
+        exported_headers = ["op_requantize.h"],
         platforms = CXX,
         deps = [
+            "//executorch/backends/cadence/generic/kernels:cadence_kernels",
             "//executorch/runtime/kernel:kernel_includes",
-            "//executorch/runtime/core/exec_aten:lib",
-            "//executorch/runtime/kernel:kernel_runtime_context",
+            ":quantized_op_macros",
        ],
         visibility = [
             "//executorch/backends/cadence/...",
@@ -315,6 +387,7 @@ def define_common_targets():
         ],
     )
 
+
     runtime.cxx_library(
         name = "op_softmax",
         srcs = ["op_softmax.cpp"],
diff --git a/backends/cadence/hifi/operators/operators.h b/backends/cadence/hifi/operators/operators.h
index f7f5194d91a..90028535848 100644
--- a/backends/cadence/hifi/operators/operators.h
+++ b/backends/cadence/hifi/operators/operators.h
@@ -15,6 +15,11 @@
   _(uint8_t, Byte) \
   _(int8_t, Char)
 
+#define ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(_) \
+  _(uint8_t, Byte)                                      \
+  _(int8_t, Char)                                       \
+  _(int16_t, Short)
+
 namespace impl {
 namespace HiFi {
 namespace native {
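The ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16 macro added to the HiFi header above (and to the vision header below) drives the same switch-based dispatch as the existing two-type macro, now extended to int16_t. A minimal sketch, assuming the HiFi operators.h is on the include path; process_typed and dispatch_example are illustrative names, not part of the patch.

```cpp
// Illustrative dispatch over the extended type list (uint8_t, int8_t, int16_t).
#include <cstddef>
#include <cstdint>

#include <executorch/backends/cadence/hifi/operators/operators.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

using ::executorch::aten::ScalarType;
using ::executorch::aten::Tensor;

template <typename T>
void process_typed(const Tensor& in, Tensor& out) {
  const T* in_data = in.const_data_ptr<T>();
  T* out_data = out.mutable_data_ptr<T>();
  for (size_t i = 0, e = in.numel(); i < e; ++i) {
    out_data[i] = in_data[i]; // placeholder body: copy through
  }
}

void dispatch_example(const Tensor& in, Tensor& out) {
#define HANDLE_TYPE(ctype, name)   \
  case ScalarType::name:           \
    process_typed<ctype>(in, out); \
    break;

  switch (in.scalar_type()) {
    ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(HANDLE_TYPE)
    default:
      ET_CHECK_MSG(false, "Unsupported dtype");
  }
#undef HANDLE_TYPE
}
```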
["op_quantized_embedding_byte.cpp"], + exported_headers = ["op_quantized_embedding_byte.h"], platforms = CXX, deps = [ "//executorch/runtime/kernel:kernel_includes", - "//executorch/runtime/core/exec_aten:lib", - "//executorch/runtime/core/exec_aten/util:tensor_util", - "//executorch/runtime/kernel:kernel_runtime_context", + ":quantized_op_macros", ], visibility = [ "//executorch/backends/cadence/...", @@ -233,14 +305,14 @@ def define_common_targets(): ) runtime.cxx_library( - name = "op_roi_align_box_processor", - srcs = ["op_roi_align_box_processor.cpp"], - exported_headers = ["op_roi_align_box_processor.h", "operators.h"], + name = "op_requantize", + srcs = ["op_requantize.cpp"], + exported_headers = ["op_requantize.h"], platforms = CXX, deps = [ + "//executorch/backends/cadence/generic/kernels:cadence_kernels", "//executorch/runtime/kernel:kernel_includes", - "//executorch/runtime/core/exec_aten:lib", - "//executorch/runtime/kernel:kernel_runtime_context", + ":quantized_op_macros", ], visibility = [ "//executorch/backends/cadence/...", @@ -315,6 +387,7 @@ def define_common_targets(): ], ) + runtime.cxx_library( name = "op_softmax", srcs = ["op_softmax.cpp"], diff --git a/backends/cadence/hifi/operators/operators.h b/backends/cadence/hifi/operators/operators.h index f7f5194d91a..90028535848 100644 --- a/backends/cadence/hifi/operators/operators.h +++ b/backends/cadence/hifi/operators/operators.h @@ -15,6 +15,11 @@ _(uint8_t, Byte) \ _(int8_t, Char) +#define ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) + namespace impl { namespace HiFi { namespace native { diff --git a/backends/cadence/vision/operators/operators.h b/backends/cadence/vision/operators/operators.h index 6842fad41fd..8b5db4161eb 100644 --- a/backends/cadence/vision/operators/operators.h +++ b/backends/cadence/vision/operators/operators.h @@ -23,6 +23,11 @@ using ::executorch::runtime::getLeadingDims; _(uint8_t, Byte) \ _(int8_t, Char) +#define ET_FORALL_CADENCE_QUANTIZED_TYPES_WITH_INT16(_) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) + inline __attribute__((always_inline)) void linear_( const ::executorch::aten::Tensor& input, const ::executorch::aten::Tensor& weight,