Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 145 additions & 0 deletions p-isa_tools/kerngen/kernel_optimization/loop_ordering_lookup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# These contents may have been developed with support from one or more Intel-operated
# generative artificial intelligence solutions.

"""
Loop order lookup functionality for encrypted computing kernels.

This module provides functionality to determine primary and secondary loop orders
based on scheme, kernel type, polynomial order, and RNS parameters.
"""

import json
from pathlib import Path

# Path (as a string) to the JSON file holding the loop-order lookup table.
# It is built by going two directories up from this module and then back down
# into "kernel_optimization/loop_order_config.json".
LOOP_ORDER_CONFIG = str(Path(__file__).parent.parent.absolute() / "kernel_optimization/loop_order_config.json")


def _parse_range(range_str: str) -> tuple[int, int]:
"""
Parse a range string like '1-5' or '3' into min, max values.

Args:
range_str (str): Range string like '1-5' or '3'

Returns:
Tuple[int, int]: (min_value, max_value) inclusive
"""
if "-" in range_str:
min_val, max_val = range_str.split("-")
return int(min_val), int(max_val)
else:
val = int(range_str)
return val, val


def _value_in_range(value: int, range_str: str) -> bool:
    """
    Tell whether a value lies inside a range specification.

    Args:
        value (int): Value to test
        range_str (str): Range specification such as '1-5' or '3'

    Returns:
        bool: True when value is between the parsed bounds, both inclusive
    """
    lower, upper = _parse_range(range_str)
    return lower <= value and value <= upper


def get_loop_order(
scheme: str,
kernel: str,
polyorder: int,
max_rns: int,
) -> tuple[str, str]:
"""
Get primary and secondary loop order based on configuration.

Args:
scheme (str): Encryption scheme ('bgv', 'ckks')
kernel (str): Kernel type ('add', 'mul', 'muli', 'copy', 'sub',
'square', 'ntt', 'intt', 'mod', 'modup', 'relin',
'rotate', 'rescale')
polyorder (int): Polynomial order (16384, 32768, 65536)
max_rns (int): Maximum RNS value
config_file (str, optional): Path to configuration file.
Defaults to loop_order_config.json in same directory.

Returns:
Tuple[str, str]: Primary and secondary loop order

Raises:
FileNotFoundError: If configuration file is not found
KeyError: If the specified parameters are not found in configuration
ValueError: If parameters are invalid
"""
# Validate inputs
valid_schemes = {"bgv", "ckks"}
valid_kernels = {"add", "mul", "muli", "copy", "sub", "square", "ntt", "intt", "mod", "modup", "relin", "rotate", "rescale"}
valid_polyorders = {16384, 32768, 65536}

scheme = scheme.lower()
kernel = kernel.lower()

if scheme not in valid_schemes:
raise ValueError(f"Invalid scheme '{scheme}'. Must be one of {valid_schemes}")

if kernel not in valid_kernels:
raise ValueError(f"Invalid kernel '{kernel}'. Must be one of {valid_kernels}")

if polyorder not in valid_polyorders:
raise ValueError(f"Invalid polyorder '{polyorder}'. Must be one of {valid_polyorders}")

if max_rns < 1:
raise ValueError(f"Invalid RNS value: max_rns={max_rns}")

try:
with open(LOOP_ORDER_CONFIG, encoding="utf-8") as f:
config = json.load(f)
except FileNotFoundError as e:
raise FileNotFoundError(f"Configuration file not found: {LOOP_ORDER_CONFIG}") from e
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in configuration file: {e}") from e

# Lookup configuration with range support
try:
scheme_config = config[scheme]
kernel_config = scheme_config[kernel]
polyorder_config = kernel_config[str(polyorder)]

# Find matching max_rns range
loop_order = None
for max_rns_range, order_config in polyorder_config.items():
if _value_in_range(max_rns, max_rns_range):
loop_order = order_config
break

if loop_order is None:
raise KeyError(f"max_rns={max_rns}")

return tuple(loop_order)

except KeyError as e:
raise KeyError(
f"Configuration not found for scheme='{scheme}', kernel='{kernel}', "
f"polyorder={polyorder}, max_rns={max_rns}. "
f"Missing key: {e}"
) from e


def list_available_configurations(config_file: str | None = None) -> dict:
"""
List all available configurations in the config file.

Args:
config_file (str, optional): Path to configuration file.

Returns:
dict: The complete configuration structure
"""

with open(LOOP_ORDER_CONFIG, encoding="utf-8") as f:
return json.load(f)
51 changes: 49 additions & 2 deletions p-isa_tools/kerngen/kernel_optimization/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import re

from const.options import LoopKey
from high_parser.pisa_operations import Comment, PIsaOp
from high_parser.pisa_operations import NTT, BinaryOp, Comment, PIsaOp


class PIsaOpGroup:
Expand Down Expand Up @@ -53,6 +53,54 @@ def remove_comments(pisa_list: list[PIsaOp]) -> list[PIsaOp]:
return [pisa for pisa in pisa_list if not isinstance(pisa, Comment)]


def reuse_rns_label(pisa: PIsaOp, current_rns: int) -> PIsaOp:
    """Helper function to remove RNS terms from modsw output.

    Operand labels containing ``x_<digits>_<digits>_``, ``y_<digits>_<digits>_``
    or ``outtmp_<digits>_<digits>_`` have their second numeric field replaced
    by ``current_rns - 1``. The instruction is modified in place and returned.

    Args:
        pisa (PIsaOp): Instruction whose operand labels may be rewritten
        current_rns (int): Current RNS count; labels are rewritten to
            ``current_rns - 1``

    Returns:
        PIsaOp: The same instruction object that was passed in
    """
    replacement = r"\1_\2_" + f"{current_rns-1}_"

    if isinstance(pisa, BinaryOp):
        binary_pattern = "(x|y|outtmp)_([0-9]+)_[0-9]+_"
        # Each operand is considered on its own.
        for field in ("input0", "input1", "output"):
            label = getattr(pisa, field)
            if "x_" in label or "y_" in label or "outtmp_" in label:
                setattr(pisa, field, re.sub(binary_pattern, replacement, label))

    if isinstance(pisa, NTT):
        ntt_pattern = "(x|outtmp|y)_([0-9]+)_[0-9]+_"
        # Operands are handled in pairs: both members must carry the same
        # prefix ('x_' or 'outtmp_') before either one is rewritten.
        for first, second in (("input0", "input1"), ("output0", "output1")):
            label_a = getattr(pisa, first)
            label_b = getattr(pisa, second)
            both_x = "x_" in label_a and "x_" in label_b
            if both_x or ("outtmp_" in label_a and "outtmp_" in label_b):
                setattr(pisa, first, re.sub(ntt_pattern, replacement, label_a))
                setattr(pisa, second, re.sub(ntt_pattern, replacement, label_b))

    return pisa


def split_by_reorderable(pisa_list: list[PIsaOp]) -> list[PIsaOpGroup]:
"""Split a list of PIsaOp instructions into reorderable and non-reorderable groups.

Expand Down Expand Up @@ -95,7 +143,6 @@ def split_by_reorderable(pisa_list: list[PIsaOp]) -> list[PIsaOpGroup]:
if no_reoderable_group:
for group in groups:
group.is_reorderable = True

return groups


Expand Down
81 changes: 68 additions & 13 deletions p-isa_tools/kerngen/kerngraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@

from const.options import LoopKey
from high_parser.config import Config
from kernel_optimization.loops import loop_interchange, split_by_reorderable
from kernel_optimization.loop_ordering_lookup import get_loop_order
from kernel_optimization.loops import loop_interchange, reuse_rns_label, split_by_reorderable
from kernel_parser.parser import KernelParser
from pisa_generators.basic import mixed_to_pisa_ops

Expand All @@ -49,8 +50,8 @@ def parse_args():
"--target",
nargs="*",
default=[],
# Composition high ops such are ntt, mod, and relin are not currently supported
choices=["add", "sub", "mul", "muli", "copy", "ntt", "intt", "mod", "relin"],
# Composition high ops such as ntt, mod, etc.
choices=["add", "sub", "mul", "muli", "copy", "ntt", "intt", "mod", "relin", "rotate", "rescale"],
help="List of high_op names",
)
parser.add_argument(
Expand All @@ -69,11 +70,55 @@ def parse_args():
choices=list(LoopKey) + [None],
help="Secondary key for loop interchange (default: None, Options: RNS, PART)",
)
parser.add_argument(
"--optimal",
action="store_true",
help="Use optimal primary and secondary loop order based on kernel configuration (overrides -p and -s)",
)
parsed_args = parser.parse_args()
# verify that primary and secondary keys are not the same
if parsed_args.primary == parsed_args.secondary:
if not parsed_args.optimal and parsed_args.primary == parsed_args.secondary:
raise ValueError("Primary and secondary keys cannot be the same.")
return parser.parse_args()
return parsed_args


def get_optimal_loop_order(kernel, debug=False):
    """
    Get optimal loop order for a kernel based on its properties.

    The scheme, polynomial order and max RNS are read from ``kernel.context``
    (falling back to 'bgv', 16384 and 3 respectively when the attribute is
    absent); the kernel name is the text of ``str(kernel)`` before the first
    '('. These are looked up with ``get_loop_order`` and the resulting strings
    ('part', 'rns', 'null') are mapped to ``LoopKey`` members or None.

    Args:
        kernel: Parsed kernel object
        debug (bool): Enable debug output

    Returns:
        Tuple[LoopKey, LoopKey]: Primary and secondary loop keys, or (None, None) if not found
    """
    try:
        # Extract kernel properties
        scheme = getattr(kernel.context, "scheme", "bgv").lower()
        kernel_name = str(kernel).split("(")[0].lower()
        polyorder = getattr(kernel.context, "poly_order", 16384)
        max_rns = getattr(kernel.context, "max_rns", 3)
        # Get optimal loop order from configuration
        primary_str, secondary_str = get_loop_order(scheme, kernel_name, polyorder, max_rns)
        # Map string values to LoopKey enum
        loop_key_mapping = {"part": LoopKey.PART, "rns": LoopKey.RNS, "null": None}

        primary_key = loop_key_mapping.get(primary_str)
        secondary_key = loop_key_mapping.get(secondary_str)

        if debug:
            print(
                "# Optimal loop order for"
                + f" {scheme}.{kernel_name}: primary={primary_str} ({primary_key}), secondary={secondary_str} ({secondary_key})"
            )

        return primary_key, secondary_key

    # get_loop_order reports a missing configuration entry with KeyError and
    # invalid parameters with ValueError; both mean "no optimal order known".
    # A missing configuration file (FileNotFoundError) is left to propagate.
    except (ValueError, KeyError) as e:
        if debug:
            print(f"# Warning: Could not determine optimal loop order for kernel {kernel}: {e}")
        return None, None


def parse_kernels(input_lines, debug=False):
Expand All @@ -92,17 +137,24 @@ def parse_kernels(input_lines, debug=False):

def process_kernel_with_reordering(kernel, args):
"""Process a kernel with reordering optimization."""
# Determine loop order
if args.optimal:
primary_key, secondary_key = get_optimal_loop_order(kernel, args.debug)
else:
primary_key = args.primary
secondary_key = args.secondary

groups = split_by_reorderable(kernel.to_pisa())
processed_kernel = []
for group in groups:
if group.is_reorderable:
processed_kernel.append(
loop_interchange(
group.pisa_list,
primary_key=args.primary,
secondary_key=args.secondary,
)
)
interchanged_pisa = loop_interchange(group.pisa_list, primary_key=primary_key, secondary_key=secondary_key)

if ("mod" in args.target) and (primary_key is not None and secondary_key is not None):
for pisa in mixed_to_pisa_ops(interchanged_pisa):
processed_kernel.append(reuse_rns_label(pisa, kernel.context.current_rns))
else:
processed_kernel.append(interchanged_pisa)
else:
processed_kernel.append(group.pisa_list)

Expand All @@ -127,7 +179,10 @@ def main(args):
return

if args.debug:
print(f"# Reordered targets {args.target} with primary key {args.primary} and secondary key {args.secondary}")
if args.optimal:
print(f"# Using optimal loop order configuration for targets {args.target}")
else:
print(f"# Reordered targets {args.target} with primary key {args.primary} and secondary key {args.secondary}")

for kernel in valid_kernels:
if should_apply_reordering(kernel, args.target):
Expand Down
2 changes: 1 addition & 1 deletion p-isa_tools/kerngen/pisa_generators/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def get_pisa_op(num):
ls: list[pisa_op] = []
for digit, op in get_pisa_op(self.input1.digits):
input0_tmp = Polys.from_polys(self.input0)
input0_tmp.name += f"_{digit}"
input0_tmp.name += f"_tmp{digit}"

# mul/mac for 0-current_rns
ls.extend(
Expand Down
6 changes: 4 additions & 2 deletions p-isa_tools/kerngen/pisa_generators/decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import high_parser.pisa_operations as pisa_op
from high_parser import HighOp, Immediate, KernelContext, Polys
from high_parser.pisa_operations import PIsaOp
from high_parser.pisa_operations import Comment, PIsaOp

from .basic import Muli, mixed_to_pisa_ops
from .ntt import INTT, NTT
Expand Down Expand Up @@ -65,7 +65,7 @@ def to_pisa(self) -> list[PIsaOp]:
)

output_tmp = Polys.from_polys(self.output)
output_tmp.name += f"_{input_rns_index}"
output_tmp.name += f"_tmp{input_rns_index}"
output_split = Polys.from_polys(self.output)
output_split.rns = self.context.current_rns
# ntt for 0-current_rns
Expand All @@ -80,5 +80,7 @@ def to_pisa(self) -> list[PIsaOp]:
return mixed_to_pisa_ops(
INTT(self.context, rns_poly, self.input0),
Muli(self.context, rns_poly, rns_poly, one),
Comment("<reorderable>"),
ls,
Comment("</reorderable>"),
)
1 change: 1 addition & 0 deletions p-isa_tools/kerngen/pisa_generators/mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def generate_mod_stages() -> list[Stage]:
input_remaining_rns,
last_q,
),
Comment("<reorderable>"),
Muli(
self.context,
temp_input_remaining_rns,
Expand Down
3 changes: 2 additions & 1 deletion p-isa_tools/kerngen/pisa_generators/rescale.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def to_pisa(self) -> list[PIsaOp]:
input_remaining_rns,
last_q,
),
Comment("<reorderable>"),
Muli(self.context, temp_input_remaining_rns, temp_input_remaining_rns, r2),
NTT(self.context, temp_input_remaining_rns, temp_input_remaining_rns),
Sub(
Expand All @@ -88,6 +89,6 @@ def to_pisa(self) -> list[PIsaOp]:
temp_input_remaining_rns,
),
Muli(self.context, self.output, temp_input_remaining_rns, iq),
Comment("End of Rescale kernel."),
Comment("End of Rescale kernel </reorderable>"),
]
)
Loading