
Commit fe1bc8a

Cortex_M backend: Fuse clamp + hardswish decomposition (#16016)
Adds quantization and fusion of clamp. This is in turn used to decompose the hardswish operator in two passes: one clamping the dynamic range before quantization, and one decomposing the remainder of the operation into a minimum and a mul op.

The tests in this patch expose an issue in the runtime dim_order check, as it cannot differentiate between channels_last and channels_first for tensors with C=1 or H=W=1. The check is therefore removed and tracked as a TBD instead. Additionally fixes per_tensor quantization for conv2d.

---------

Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
1 parent 2ed6d6a commit fe1bc8a
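
For reference, the identity the two passes build on can be checked with plain PyTorch (an illustrative sketch, not the backend implementation):

import torch

# hardswish(x) = x * relu6(x + 3) / 6 = x * (clamp(x, -3, 3) + 3) / 6
x = torch.linspace(-6, 6, steps=25)
reference = torch.nn.functional.hardswish(x)
decomposed = x * (x.clamp(-3, 3) + 3) / 6
assert torch.allclose(reference, decomposed)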

11 files changed: +326 −16 lines changed


backends/cortex_m/ops/cortex_m_ops_common.h

Lines changed: 1 addition & 5 deletions
@@ -69,11 +69,7 @@ inline void validate_cmsis_nn_tensor_requirements(
       "Output must have the same sizes as inputs");
   }

-  // Dim order consistency
-  ET_CHECK_MSG(
-      executorch::runtime::tensors_have_same_dim_order(input1, input2, output),
-      "Tensors must have same dimension order");
+  // TBD (#16032): Validate dim_order
   // TBD: Validate memory alignment (CMSIS-NN requirement)
 }
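
The removed check cannot work for the shapes mentioned in the commit message: with C=1 (or H=W=1), the channels_first and channels_last layouts place the elements identically in memory, so no dim_order can be verified reliably. A small illustration (not part of the patch):

import torch

x = torch.randn(1, 1, 4, 4)
nchw = x.contiguous()
nhwc = x.contiguous(memory_format=torch.channels_last)
print(nchw.stride())  # (16, 16, 4, 1)
print(nhwc.stride())  # (16, 1, 4, 1) - differs only in the size-1 channel dim
# Reading both storages linearly shows the physical layouts are identical.
assert torch.equal(nchw.as_strided((16,), (1,)), nhwc.as_strided((16,), (1,)))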

backends/cortex_m/passes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -4,7 +4,9 @@
 # LICENSE file in the root directory of this source tree.

 from .activation_fusion_pass import ActivationFusionPass  # noqa
+from .clamp_hardswish_pass import ClampHardswishPass  # noqa
 from .convert_to_cortex_m_pass import ConvertToCortexMPass  # noqa
+from .decompose_hardswish_pass import DecomposeHardswishPass  # noqa
 from .quantized_op_fusion_pass import QuantizedOpFusionPass  # noqa
 from .replace_quant_nodes_pass import ReplaceQuantNodesPass  # noqa
 from .cortex_m_pass_manager import CortexMPassManager  # noqa # usort: skip

backends/cortex_m/passes/activation_fusion_pass.py

Lines changed: 19 additions & 8 deletions
@@ -8,6 +8,7 @@

 import executorch.backends.cortex_m.ops.operators  # noqa: F401
 from executorch.backends.arm._passes.quant_args import QuantArgs
+from executorch.backends.cortex_m.passes.passes_utils import quantize_val

 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass
@@ -33,16 +34,14 @@ class ActivationFusionPass(ExportPass):
         exir_ops.edge.aten.relu.default,
         exir_ops.edge.aten.hardtanh.default,
         exir_ops.edge.aten.hardsigmoid.default,
+        exir_ops.edge.aten.clamp.default,
     }

     FUSE_OPS = {
         exir_ops.edge.aten.linear.default,
         exir_ops.edge.aten.convolution.default,
     }

-    def _quantize(self, val, scale, zp, qmin, qmax):
-        return min(max(round(val / scale + zp), qmin), qmax)
-
     def _get_validated_qparams(self, node, input_node):

         if "input_qparams" not in input_node.meta or "output_qparams" not in node.meta:
@@ -65,14 +64,26 @@ def _get_validated_qparams(self, node, input_node):

         match node.target:
             case exir_ops.edge.aten.relu.default:
-                quantized_min_val = self._quantize(0, scale, zp, qmin, qmax)
+                quantized_min_val = quantize_val(0, scale, zp, qmin, qmax)
                 quantized_max_val = qmax
             case exir_ops.edge.aten.hardtanh.default:
-                quantized_min_val = self._quantize(node.args[1], scale, zp, qmin, qmax)
-                quantized_max_val = self._quantize(node.args[2], scale, zp, qmin, qmax)
+                quantized_min_val = quantize_val(node.args[1], scale, zp, qmin, qmax)
+                quantized_max_val = quantize_val(node.args[2], scale, zp, qmin, qmax)
             case exir_ops.edge.aten.hardsigmoid.default:
-                quantized_min_val = self._quantize(0, scale, zp, qmin, qmax)
-                quantized_max_val = self._quantize(1, scale, zp, qmin, qmax)
+                quantized_min_val = quantize_val(0, scale, zp, qmin, qmax)
+                quantized_max_val = quantize_val(1, scale, zp, qmin, qmax)
+            case exir_ops.edge.aten.clamp.default:
+                quantized_min_val = (
+                    quantize_val(node.args[1], scale, zp, qmin, qmax)
+                    if node.args[1] is not None
+                    else qmin
+                )
+                # Last arg is removed if none, so check length of args here
+                quantized_max_val = (
+                    quantize_val(node.args[2], scale, zp, qmin, qmax)
+                    if len(node.args) == 3
+                    else qmax
+                )
             case _:
                 raise RuntimeError("Unexpected target {node.target}.")
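
For the new clamp case, the bounds are mapped into the quantized domain with the quantize_val helper (added to passes_utils.py below). A short illustrative run with made-up int8 qparams:

# Same body as the quantize_val helper added in passes_utils.py
def quantize_val(val, scale, zp, qmin, qmax):
    return min(max(round(val / scale + zp), qmin), qmax)

# aten.clamp(x, min, max): a trailing None max is dropped from the args,
# which is why the pass checks len(node.args) == 3 instead of reading args[2].
scale, zp, qmin, qmax = 0.1, 0, -128, 127        # illustrative values
print(quantize_val(-3, scale, zp, qmin, qmax))   # -30, quantized lower bound
print(quantize_val(6, scale, zp, qmin, qmax))    # 60, quantized upper bound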

backends/cortex_m/passes/clamp_hardswish_pass.py

Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict
+
+import torch
+
+from executorch.exir.dialects.edge._ops import EdgeOpOverload
+from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
+from torch.fx.node import Argument
+
+
+class ClampHardswishPass(ExportPass):
+    """
+    Adds a clamp operation before hardswish to ensure the input is in the range [-3, inf).
+
+    By doing this before quantization, the output range of the preceding op is minimized,
+    potentially improving accuracy.
+    """
+
+    def call_operator(
+        self,
+        op: EdgeOpOverload,
+        args: tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
+        if op == torch.ops.aten.hardswish.default:
+            clamped_args = (args[0], -3)
+            clamped_input = super().call_operator(
+                torch.ops.aten.clamp.default, clamped_args, {}, meta
+            )
+            args = (clamped_input,)
+
+        return super().call_operator(op, args, kwargs, meta)
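
The inserted lower clamp is value-preserving because hardswish is identically zero below -3; only the range seen by the quantizer narrows. A quick check (illustrative, outside the pass itself):

import torch

x = torch.linspace(-10, 10, steps=41)
assert torch.allclose(
    torch.nn.functional.hardswish(x),
    torch.nn.functional.hardswish(x.clamp(min=-3)),
)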

backends/cortex_m/passes/convert_to_cortex_m_pass.py

Lines changed: 3 additions & 1 deletion
@@ -10,6 +10,7 @@

 import torch
 import torch.fx
+from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
 from executorch.backends.cortex_m.passes.passes_utils import quantize_multiplier_aot

 from executorch.backends.transforms.utils import (
@@ -137,7 +138,8 @@ def _get_convolution_replacement(self, node) -> int:
         input_zero_point = node.meta["input_qparams"][0].zp
         weight_scales = node.meta["input_qparams"][1].scale
         if not isinstance(weight_scales, list):
-            weight_scales = [weight_scales] * weight.data.shape[0]
+            weight_tensor = get_first_fake_tensor(weight)
+            weight_scales = [weight_scales] * weight_tensor.shape[0]

         output_qparams = node.meta["output_qparams"][0]
         output_scale = output_qparams.scale
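
The per_tensor fix broadcasts the single weight scale across output channels, taking the channel count from the node's fake tensor instead of weight.data (which presumably is not always available on the node here). A simplified sketch with assumed shapes and values:

# Hedged sketch: per-tensor quantization yields one weight scale, but the
# CMSIS-NN lowering expects one scale per output channel.
weight_shape = (8, 3, 3, 3)     # assumed (out_channels, in_channels, kH, kW)
weight_scales = 0.02            # assumed single per-tensor scale
if not isinstance(weight_scales, list):
    weight_scales = [weight_scales] * weight_shape[0]
print(weight_scales)            # eight identical per-channel scales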

backends/cortex_m/passes/cortex_m_pass_manager.py

Lines changed: 5 additions & 1 deletion
@@ -12,7 +12,9 @@
 )
 from executorch.backends.cortex_m.passes import (
     ActivationFusionPass,
+    ClampHardswishPass,
     ConvertToCortexMPass,
+    DecomposeHardswishPass,
     QuantizedOpFusionPass,
     ReplaceQuantNodesPass,
 )
@@ -31,14 +33,16 @@ class CortexMPassManager(PassManager):
         FoldAndAnnotateQParamsPass,
         ReplaceScalarWithTensorArgPass,
         ReplaceQuantNodesPass,
-        QuantizedOpFusionPass,
         ActivationFusionPass,
+        DecomposeHardswishPass,
+        QuantizedOpFusionPass,
         ConvertToCortexMPass,
     ]

     pass_list_transform_for_annotation: list[ExportPass] = [
         ScalarsToAttributePass,
         ReplaceScalarWithTensorArgPass,
+        ClampHardswishPass,
     ]

     def __init__(self, exported_program, passes=None):
backends/cortex_m/passes/decompose_hardswish_pass.py

Lines changed: 127 additions & 0 deletions

@@ -0,0 +1,127 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import logging
+
+import executorch.backends.cortex_m.ops.operators  # noqa: F401
+
+import torch
+from executorch.backends.arm._passes.quant_args import QuantArgs
+
+from executorch.backends.cortex_m.passes.passes_utils import quantize_val
+
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass
+from torch.fx import GraphModule, Node
+from torch.fx.passes.infra.pass_manager import PassResult
+
+logger = logging.getLogger(__name__)
+
+
+class DecomposeHardswishPass(ExportPass):
+    """
+    Decomposes hardswish as
+
+        hardswish(x) = x * (clamp(x, -3, 3) + 3) / 6
+
+    where the add and division are implemented by modifying the quantization parameters,
+    similar to hardsigmoid in the activation_fusion_pass. Note that this pass assumes
+    that the output range of the preceding op is already clamped to [-3, inf) during
+    quantization by the clamp_hardswish_pass, removing the need for the negative clamp.
+    """
+
+    TARGETS = {
+        exir_ops.edge.aten.hardswish.default,
+    }
+
+    FUSE_OPS = {
+        exir_ops.edge.aten.linear.default,
+        exir_ops.edge.aten.convolution.default,
+    }
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+        modified = False
+        nodes_to_erase: list[Node] = []
+
+        for node in list(graph_module.graph.nodes):
+            if node.op != "call_function" or node.target not in self.TARGETS:
+                continue
+
+            input_node = node.args[0]
+            if (
+                input_node.op != "call_function"
+                or input_node.target not in self.FUSE_OPS
+            ):
+                logger.warning(
+                    f"Cannot fuse activation {node.name} as input node {input_node.name} is not a supported fused activation op."
+                )
+                continue
+            if len(input_node.users.values()) > 1:
+                logger.warning(
+                    f"Cannot fuse activation {node.name} as input node {input_node.name} has multiple users."
+                )
+                continue
+
+            input_quant_dict = input_node.meta.get("output_qparams", [None])[
+                0
+            ]._asdict()
+            scale = input_quant_dict["scale"]
+            zero_point = input_quant_dict["zp"]
+            qmin = input_quant_dict["qmin"]
+            qmax = input_quant_dict["qmax"]
+
+            # Create min node
+            with graph_module.graph.inserting_after(input_node):
+                clamp_node = graph_module.graph.create_node(
+                    "call_function",
+                    target=exir_ops.edge.aten.minimum.default,
+                    args=(
+                        input_node,
+                        torch.tensor(
+                            quantize_val(3, scale, zero_point, qmin, qmax),
+                            dtype=torch.int8,
+                        ),
+                    ),
+                    kwargs={},
+                )
+                clamp_node.meta = input_node.meta.copy()
+
+            # Create mul node
+            with graph_module.graph.inserting_after(clamp_node):
+                mul_node = graph_module.graph.create_node(
+                    "call_function",
+                    target=exir_ops.edge.aten.mul.Tensor,
+                    args=(input_node, clamp_node),
+                    kwargs={},
+                )
+                mul_node.meta = node.meta.copy()
+
+            mul_quant_dict = node.meta["input_qparams"][0]._asdict()
+
+            mul_quant_dict_shifted = mul_quant_dict.copy()
+            mul_quant_dict_shifted["zp"] = mul_quant_dict_shifted["zp"] - round(
+                3 / (mul_quant_dict_shifted["scale"])
+            )
+
+            output_quant_dict = node.meta["output_qparams"][0]._asdict()
+            output_quant_dict["scale"] = output_quant_dict["scale"] * 6
+
+            node.meta["input_qparams"][0] = QuantArgs(**mul_quant_dict)
+            mul_node.meta["input_qparams"][1] = QuantArgs(**mul_quant_dict_shifted)
+            mul_node.meta["output_qparams"][0] = QuantArgs(**output_quant_dict)
+
+            node.replace_all_uses_with(mul_node)
+            nodes_to_erase.append(node)
+            modified = True
+
+        for node in nodes_to_erase:
+            graph_module.graph.erase_node(node)
+
+        if modified:
+            graph_module.graph.eliminate_dead_code()
+            graph_module.recompile()
+
+        return PassResult(graph_module, modified)
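
The "+ 3" and "/ 6" never appear as graph ops; they are folded into the quantization parameters, as a small numeric check illustrates (scale and zero-point values are made up):

scale, zp = 0.25, 0                    # assumed qparams of the mul's second input
shifted_zp = zp - round(3 / scale)     # mirrors mul_quant_dict_shifted["zp"]

# Dequantizing with the shifted zero point adds the "+ 3" for free:
q = 10
assert (q - shifted_zp) * scale == (q - zp) * scale + 3

# Multiplying the output scale by 6 makes the stored integers represent y / 6,
# folding the division into the requantization of the mul output.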

backends/cortex_m/passes/passes_utils.py

Lines changed: 4 additions & 0 deletions
@@ -17,6 +17,10 @@
 SHIFT_INT8 = 20


+def quantize_val(val, scale, zp, qmin, qmax):
+    return min(max(round(val / scale + zp), qmin), qmax)
+
+
 def dequantize_per_tensor_cmsis(
     qtensor: torch.Tensor, zero_point: int, multiplier: int, shift: int
 ) -> torch.Tensor:

backends/cortex_m/quantizer/operator_configs.py

Lines changed: 6 additions & 0 deletions
@@ -19,6 +19,8 @@
 BINARY_OP_PATTERNS = [
     [torch.ops.aten.add.Tensor],
     [torch.ops.aten.mul.Tensor],
+    [torch.ops.aten.hardswish.default],
+    [torch.ops.aten.hardswish_.default],
 ]

 LINEAR_OP_PATTERNS = [
@@ -29,6 +31,8 @@
     [torch.ops.aten.linear.default, torch.ops.aten.hardtanh_.default],
     [torch.ops.aten.linear.default, torch.ops.aten.hardsigmoid.default],
     [torch.ops.aten.linear.default, torch.ops.aten.hardsigmoid_.default],
+    [torch.ops.aten.linear.default, torch.ops.aten.clamp.default],
+    [torch.ops.aten.linear.default, torch.ops.aten.clamp_.default],
 ]

 CONV_OP_PATTERNS = [
@@ -39,6 +43,8 @@
     [torch.ops.aten.conv2d.default, torch.ops.aten.hardtanh_.default],
     [torch.ops.aten.conv2d.default, torch.ops.aten.hardsigmoid.default],
     [torch.ops.aten.conv2d.default, torch.ops.aten.hardsigmoid_.default],
+    [torch.ops.aten.conv2d.default, torch.ops.aten.clamp.default],
+    [torch.ops.aten.conv2d.default, torch.ops.aten.clamp_.default],
 ]

 # ----------------- OPERATOR CONFIG PRESETS -----------------
