Skip to content

Commit 8dfbe7f

Browse files
New Feature: Kerngraph Tagged Loop Reordering Support (#141)
* Supports grouping of p-isa lists to enable flexible loop reordering through <reorderable></reorderable> tags in Comments. * Support for loop reordering for mod and digit decomp kernels. * Added support for optimal strategy selection based on kerngen/kernel_optimization/loop_order_config.json
1 parent a06a537 commit 8dfbe7f

File tree

14 files changed

+270
-1458
lines changed

14 files changed

+270
-1458
lines changed
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
# Copyright (C) 2025 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
# These contents may have been developed with support from one or more Intel-operated
4+
# generative artificial intelligence solutions.
5+
6+
"""
7+
Loop order lookup functionality for encrypted computing kernels.
8+
9+
This module provides functionality to determine primary and secondary loop orders
10+
based on scheme, kernel type, polynomial order, and RNS parameters.
11+
"""
12+
13+
import json
14+
from pathlib import Path
15+
16+
LOOP_ORDER_CONFIG = str(Path(__file__).parent.parent.absolute() / "kernel_optimization/loop_order_config.json")
17+
18+
19+
def _parse_range(range_str: str) -> tuple[int, int]:
20+
"""
21+
Parse a range string like '1-5' or '3' into min, max values.
22+
23+
Args:
24+
range_str (str): Range string like '1-5' or '3'
25+
26+
Returns:
27+
Tuple[int, int]: (min_value, max_value) inclusive
28+
"""
29+
if "-" in range_str:
30+
min_val, max_val = range_str.split("-")
31+
return int(min_val), int(max_val)
32+
else:
33+
val = int(range_str)
34+
return val, val
35+
36+
37+
def _value_in_range(value: int, range_str: str) -> bool:
38+
"""
39+
Check if a value falls within a range string.
40+
41+
Args:
42+
value (int): Value to check
43+
range_str (str): Range string like '1-5' or '3'
44+
45+
Returns:
46+
bool: True if value is in range
47+
"""
48+
min_val, max_val = _parse_range(range_str)
49+
return min_val <= value <= max_val
50+
51+
52+
def get_loop_order(
53+
scheme: str,
54+
kernel: str,
55+
polyorder: int,
56+
max_rns: int,
57+
) -> tuple[str, str]:
58+
"""
59+
Get primary and secondary loop order based on configuration.
60+
61+
Args:
62+
scheme (str): Encryption scheme ('bgv', 'ckks')
63+
kernel (str): Kernel type ('add', 'mul', 'muli', 'copy', 'sub',
64+
'square', 'ntt', 'intt', 'mod', 'modup', 'relin',
65+
'rotate', 'rescale')
66+
polyorder (int): Polynomial order (16384, 32768, 65536)
67+
max_rns (int): Maximum RNS value
68+
config_file (str, optional): Path to configuration file.
69+
Defaults to loop_order_config.json in same directory.
70+
71+
Returns:
72+
Tuple[str, str]: Primary and secondary loop order
73+
74+
Raises:
75+
FileNotFoundError: If configuration file is not found
76+
KeyError: If the specified parameters are not found in configuration
77+
ValueError: If parameters are invalid
78+
"""
79+
# Validate inputs
80+
valid_schemes = {"bgv", "ckks"}
81+
valid_kernels = {"add", "mul", "muli", "copy", "sub", "square", "ntt", "intt", "mod", "modup", "relin", "rotate", "rescale"}
82+
valid_polyorders = {16384, 32768, 65536}
83+
84+
scheme = scheme.lower()
85+
kernel = kernel.lower()
86+
87+
if scheme not in valid_schemes:
88+
raise ValueError(f"Invalid scheme '{scheme}'. Must be one of {valid_schemes}")
89+
90+
if kernel not in valid_kernels:
91+
raise ValueError(f"Invalid kernel '{kernel}'. Must be one of {valid_kernels}")
92+
93+
if polyorder not in valid_polyorders:
94+
raise ValueError(f"Invalid polyorder '{polyorder}'. Must be one of {valid_polyorders}")
95+
96+
if max_rns < 1:
97+
raise ValueError(f"Invalid RNS value: max_rns={max_rns}")
98+
99+
try:
100+
with open(LOOP_ORDER_CONFIG, encoding="utf-8") as f:
101+
config = json.load(f)
102+
except FileNotFoundError as e:
103+
raise FileNotFoundError(f"Configuration file not found: {LOOP_ORDER_CONFIG}") from e
104+
except json.JSONDecodeError as e:
105+
raise ValueError(f"Invalid JSON in configuration file: {e}") from e
106+
107+
# Lookup configuration with range support
108+
try:
109+
scheme_config = config[scheme]
110+
kernel_config = scheme_config[kernel]
111+
polyorder_config = kernel_config[str(polyorder)]
112+
113+
# Find matching max_rns range
114+
loop_order = None
115+
for max_rns_range, order_config in polyorder_config.items():
116+
if _value_in_range(max_rns, max_rns_range):
117+
loop_order = order_config
118+
break
119+
120+
if loop_order is None:
121+
raise KeyError(f"max_rns={max_rns}")
122+
123+
return tuple(loop_order)
124+
125+
except KeyError as e:
126+
raise KeyError(
127+
f"Configuration not found for scheme='{scheme}', kernel='{kernel}', "
128+
f"polyorder={polyorder}, max_rns={max_rns}. "
129+
f"Missing key: {e}"
130+
) from e
131+
132+
133+
def list_available_configurations(config_file: str | None = None) -> dict:
134+
"""
135+
List all available configurations in the config file.
136+
137+
Args:
138+
config_file (str, optional): Path to configuration file.
139+
140+
Returns:
141+
dict: The complete configuration structure
142+
"""
143+
144+
with open(LOOP_ORDER_CONFIG, encoding="utf-8") as f:
145+
return json.load(f)

p-isa_tools/kerngen/kernel_optimization/loops.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import re
99

1010
from const.options import LoopKey
11-
from high_parser.pisa_operations import Comment, PIsaOp
11+
from high_parser.pisa_operations import NTT, BinaryOp, Comment, PIsaOp
1212

1313

1414
class PIsaOpGroup:
@@ -53,6 +53,54 @@ def remove_comments(pisa_list: list[PIsaOp]) -> list[PIsaOp]:
5353
return [pisa for pisa in pisa_list if not isinstance(pisa, Comment)]
5454

5555

56+
def reuse_rns_label(pisa: PIsaOp, current_rns: int) -> PIsaOp:
57+
"""Helper function to remove RNS terms from modsw output"""
58+
if isinstance(pisa, BinaryOp):
59+
if "x_" in pisa.input0 or "y_" in pisa.input0 or "outtmp_" in pisa.input0:
60+
pisa.input0 = re.sub(
61+
"(x|y|outtmp)_([0-9]+)_[0-9]+_",
62+
r"\1_\2_" + f"{current_rns-1}_",
63+
pisa.input0,
64+
)
65+
if "x_" in pisa.input1 or "y_" in pisa.input1 or "outtmp_" in pisa.input1:
66+
pisa.input1 = re.sub(
67+
"(x|y|outtmp)_([0-9]+)_[0-9]+_",
68+
r"\1_\2_" + f"{current_rns-1}_",
69+
pisa.input1,
70+
)
71+
if "x_" in pisa.output or "y_" in pisa.output or "outtmp_" in pisa.output:
72+
pisa.output = re.sub(
73+
"(x|y|outtmp)_([0-9]+)_[0-9]+_",
74+
r"\1_\2_" + f"{current_rns-1}_",
75+
pisa.output,
76+
)
77+
if isinstance(pisa, NTT):
78+
if "x_" in pisa.input0 and "x_" in pisa.input1 or "outtmp_" in pisa.input0 and "outtmp_" in pisa.input1:
79+
pisa.input0 = re.sub(
80+
"(x|outtmp|y)_([0-9]+)_[0-9]+_",
81+
r"\1_\2_" + f"{current_rns-1}_",
82+
pisa.input0,
83+
)
84+
pisa.input1 = re.sub(
85+
"(x|outtmp|y)_([0-9]+)_[0-9]+_",
86+
r"\1_\2_" + f"{current_rns-1}_",
87+
pisa.input1,
88+
)
89+
if "x_" in pisa.output0 and "x_" in pisa.output1 or "outtmp_" in pisa.output0 and "outtmp_" in pisa.output1:
90+
pisa.output0 = re.sub(
91+
"(x|outtmp|y)_([0-9]+)_[0-9]+_",
92+
r"\1_\2_" + f"{current_rns-1}_",
93+
pisa.output0,
94+
)
95+
pisa.output1 = re.sub(
96+
"(x|outtmp|y)_([0-9]+)_[0-9]+_",
97+
r"\1_\2_" + f"{current_rns-1}_",
98+
pisa.output1,
99+
)
100+
101+
return pisa
102+
103+
56104
def split_by_reorderable(pisa_list: list[PIsaOp]) -> list[PIsaOpGroup]:
57105
"""Split a list of PIsaOp instructions into reorderable and non-reorderable groups.
58106
@@ -95,7 +143,6 @@ def split_by_reorderable(pisa_list: list[PIsaOp]) -> list[PIsaOpGroup]:
95143
if no_reoderable_group:
96144
for group in groups:
97145
group.is_reorderable = True
98-
99146
return groups
100147

101148

p-isa_tools/kerngen/kerngraph.py

Lines changed: 68 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@
3434

3535
from const.options import LoopKey
3636
from high_parser.config import Config
37-
from kernel_optimization.loops import loop_interchange, split_by_reorderable
37+
from kernel_optimization.loop_ordering_lookup import get_loop_order
38+
from kernel_optimization.loops import loop_interchange, reuse_rns_label, split_by_reorderable
3839
from kernel_parser.parser import KernelParser
3940
from pisa_generators.basic import mixed_to_pisa_ops
4041

@@ -49,8 +50,8 @@ def parse_args():
4950
"--target",
5051
nargs="*",
5152
default=[],
52-
# Composition high ops such are ntt, mod, and relin are not currently supported
53-
choices=["add", "sub", "mul", "muli", "copy", "ntt", "intt", "mod", "relin"],
53+
# Composition high ops such are ntt, mod, etc.
54+
choices=["add", "sub", "mul", "muli", "copy", "ntt", "intt", "mod", "relin", "rotate", "rescale"],
5455
help="List of high_op names",
5556
)
5657
parser.add_argument(
@@ -69,11 +70,55 @@ def parse_args():
6970
choices=list(LoopKey) + [None],
7071
help="Secondary key for loop interchange (default: None, Options: RNS, PART)",
7172
)
73+
parser.add_argument(
74+
"--optimal",
75+
action="store_true",
76+
help="Use optimal primary and secondary loop order based on kernel configuration (overrides -p and -s)",
77+
)
7278
parsed_args = parser.parse_args()
7379
# verify that primary and secondary keys are not the same
74-
if parsed_args.primary == parsed_args.secondary:
80+
if not parsed_args.optimal and parsed_args.primary == parsed_args.secondary:
7581
raise ValueError("Primary and secondary keys cannot be the same.")
76-
return parser.parse_args()
82+
return parsed_args
83+
84+
85+
def get_optimal_loop_order(kernel, debug=False):
86+
"""
87+
Get optimal loop order for a kernel based on its properties.
88+
89+
Args:
90+
kernel: Parsed kernel object
91+
debug (bool): Enable debug output
92+
93+
Returns:
94+
Tuple[LoopKey, LoopKey]: Primary and secondary loop keys, or (None, None) if not found
95+
"""
96+
try:
97+
# Extract kernel properties
98+
scheme = getattr(kernel.context, "scheme", "bgv").lower()
99+
kernel_name = str(kernel).split("(")[0].lower()
100+
polyorder = getattr(kernel.context, "poly_order", 16384)
101+
max_rns = getattr(kernel.context, "max_rns", 3)
102+
# Get optimal loop order from configuration
103+
primary_str, secondary_str = get_loop_order(scheme, kernel_name, polyorder, max_rns)
104+
# Map string values to LoopKey enum
105+
loop_key_mapping = {"part": LoopKey.PART, "rns": LoopKey.RNS, "null": None}
106+
107+
primary_key = loop_key_mapping.get(primary_str)
108+
secondary_key = loop_key_mapping.get(secondary_str)
109+
110+
if debug:
111+
print(
112+
"# Optimal loop order for"
113+
+ f" {scheme}.{kernel_name}: primary={primary_str} ({primary_key}), secondary={secondary_str} ({secondary_key})"
114+
)
115+
116+
return primary_key, secondary_key
117+
118+
except ValueError as e:
119+
if debug:
120+
print(f"# Warning: Could not determine optimal loop order for kernel {kernel}: {e}")
121+
return None, None
77122

78123

79124
def parse_kernels(input_lines, debug=False):
@@ -92,17 +137,24 @@ def parse_kernels(input_lines, debug=False):
92137

93138
def process_kernel_with_reordering(kernel, args):
94139
"""Process a kernel with reordering optimization."""
140+
# Determine loop order
141+
if args.optimal:
142+
primary_key, secondary_key = get_optimal_loop_order(kernel, args.debug)
143+
else:
144+
primary_key = args.primary
145+
secondary_key = args.secondary
146+
95147
groups = split_by_reorderable(kernel.to_pisa())
96148
processed_kernel = []
97149
for group in groups:
98150
if group.is_reorderable:
99-
processed_kernel.append(
100-
loop_interchange(
101-
group.pisa_list,
102-
primary_key=args.primary,
103-
secondary_key=args.secondary,
104-
)
105-
)
151+
interchanged_pisa = loop_interchange(group.pisa_list, primary_key=primary_key, secondary_key=secondary_key)
152+
153+
if ("mod" in args.target) and (primary_key is not None and secondary_key is not None):
154+
for pisa in mixed_to_pisa_ops(interchanged_pisa):
155+
processed_kernel.append(reuse_rns_label(pisa, kernel.context.current_rns))
156+
else:
157+
processed_kernel.append(interchanged_pisa)
106158
else:
107159
processed_kernel.append(group.pisa_list)
108160

@@ -127,7 +179,10 @@ def main(args):
127179
return
128180

129181
if args.debug:
130-
print(f"# Reordered targets {args.target} with primary key {args.primary} and secondary key {args.secondary}")
182+
if args.optimal:
183+
print(f"# Using optimal loop order configuration for targets {args.target}")
184+
else:
185+
print(f"# Reordered targets {args.target} with primary key {args.primary} and secondary key {args.secondary}")
131186

132187
for kernel in valid_kernels:
133188
if should_apply_reordering(kernel, args.target):

p-isa_tools/kerngen/pisa_generators/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def get_pisa_op(num):
247247
ls: list[pisa_op] = []
248248
for digit, op in get_pisa_op(self.input1.digits):
249249
input0_tmp = Polys.from_polys(self.input0)
250-
input0_tmp.name += f"_{digit}"
250+
input0_tmp.name += f"_tmp{digit}"
251251

252252
# mul/mac for 0-current_rns
253253
ls.extend(

p-isa_tools/kerngen/pisa_generators/decomp.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import high_parser.pisa_operations as pisa_op
99
from high_parser import HighOp, Immediate, KernelContext, Polys
10-
from high_parser.pisa_operations import PIsaOp
10+
from high_parser.pisa_operations import Comment, PIsaOp
1111

1212
from .basic import Muli, mixed_to_pisa_ops
1313
from .ntt import INTT, NTT
@@ -65,7 +65,7 @@ def to_pisa(self) -> list[PIsaOp]:
6565
)
6666

6767
output_tmp = Polys.from_polys(self.output)
68-
output_tmp.name += f"_{input_rns_index}"
68+
output_tmp.name += f"_tmp{input_rns_index}"
6969
output_split = Polys.from_polys(self.output)
7070
output_split.rns = self.context.current_rns
7171
# ntt for 0-current_rns
@@ -80,5 +80,7 @@ def to_pisa(self) -> list[PIsaOp]:
8080
return mixed_to_pisa_ops(
8181
INTT(self.context, rns_poly, self.input0),
8282
Muli(self.context, rns_poly, rns_poly, one),
83+
Comment("<reorderable>"),
8384
ls,
85+
Comment("</reorderable>"),
8486
)

p-isa_tools/kerngen/pisa_generators/mod.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def generate_mod_stages() -> list[Stage]:
147147
input_remaining_rns,
148148
last_q,
149149
),
150+
Comment("<reorderable>"),
150151
Muli(
151152
self.context,
152153
temp_input_remaining_rns,

p-isa_tools/kerngen/pisa_generators/rescale.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def to_pisa(self) -> list[PIsaOp]:
7979
input_remaining_rns,
8080
last_q,
8181
),
82+
Comment("<reorderable>"),
8283
Muli(self.context, temp_input_remaining_rns, temp_input_remaining_rns, r2),
8384
NTT(self.context, temp_input_remaining_rns, temp_input_remaining_rns),
8485
Sub(
@@ -88,6 +89,6 @@ def to_pisa(self) -> list[PIsaOp]:
8889
temp_input_remaining_rns,
8990
),
9091
Muli(self.context, self.output, temp_input_remaining_rns, iq),
91-
Comment("End of Rescale kernel."),
92+
Comment("End of Rescale kernel </reorderable>"),
9293
]
9394
)

0 commit comments

Comments
 (0)