From 93d26de9b696d45c0e62d71927c16a72b1dd0f2e Mon Sep 17 00:00:00 2001 From: "Rojas Chaves, Jose" Date: Fri, 14 Nov 2025 23:12:12 +0000 Subject: [PATCH 1/2] New unit tests for he_as and kernel splitter --- .../hec-assembler-tools/he_prep.py | 9 +- .../{ => test_common}/test_dinst/__init__.py | 0 .../test_dinst/test_dinstruction.py | 0 .../test_dinst/test_dkeygen.py | 0 .../test_dinst/test_dload.py | 0 .../test_dinst/test_dstore.py | 0 .../{ => test_common}/test_dinst/test_init.py | 0 .../test_prep/test_kernel_splitter.py | 476 ++++++++++++++++++ .../tests/unit_tests/test_he_as.py | 247 +++++++++ .../tests/unit_tests/test_he_prep.py | 361 ++++++++++--- 10 files changed, 1023 insertions(+), 70 deletions(-) rename assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/{ => test_common}/test_dinst/__init__.py (100%) rename assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/{ => test_common}/test_dinst/test_dinstruction.py (100%) rename assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/{ => test_common}/test_dinst/test_dkeygen.py (100%) rename assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/{ => test_common}/test_dinst/test_dload.py (100%) rename assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/{ => test_common}/test_dinst/test_dstore.py (100%) rename assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/{ => test_common}/test_dinst/test_init.py (100%) create mode 100644 assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_prep/test_kernel_splitter.py create mode 100644 assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_as.py diff --git a/assembler_tools/hec-assembler-tools/he_prep.py b/assembler_tools/hec-assembler-tools/he_prep.py index 8c2fd548..02c72c5e 100644 --- a/assembler_tools/hec-assembler-tools/he_prep.py +++ b/assembler_tools/hec-assembler-tools/he_prep.py @@ -74,6 +74,9 @@ def main(args): - interchange """ + strategy = 
getattr(args, "strategy", "largest_first") + interchange = getattr(args, "interchange", False) + GlobalConfig.debugVerbose = args.verbose # used for timings @@ -97,7 +100,11 @@ def main(args): if args.verbose > 0: print("Assigning register banks to variables...") preprocessor.assign_register_banks_to_vars( - hec_mem_model, insts_listing, use_bank0=False, strategy=args.strategy, interchange=args.interchange + hec_mem_model, + insts_listing, + use_bank0=False, + strategy=strategy, + interchange=interchange, ) # Determine output file name diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_dinst/__init__.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_common/test_dinst/__init__.py similarity index 100% rename from assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_dinst/__init__.py rename to assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_common/test_dinst/__init__.py diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_dinst/test_dinstruction.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_common/test_dinst/test_dinstruction.py similarity index 100% rename from assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_dinst/test_dinstruction.py rename to assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_common/test_dinst/test_dinstruction.py diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_dinst/test_dkeygen.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_common/test_dinst/test_dkeygen.py similarity index 100% rename from assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_dinst/test_dkeygen.py rename to assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_common/test_dinst/test_dkeygen.py diff --git 
a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_dinst/test_dload.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_common/test_dinst/test_dload.py similarity index 100% rename from assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_dinst/test_dload.py rename to assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_common/test_dinst/test_dload.py diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_dinst/test_dstore.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_common/test_dinst/test_dstore.py similarity index 100% rename from assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_dinst/test_dstore.py rename to assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_common/test_dinst/test_dstore.py diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_dinst/test_init.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_common/test_dinst/test_init.py similarity index 100% rename from assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_dinst/test_init.py rename to assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_common/test_dinst/test_init.py diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_prep/test_kernel_splitter.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_prep/test_kernel_splitter.py new file mode 100644 index 00000000..f4c26dcf --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_assembler/test_prep/test_kernel_splitter.py @@ -0,0 +1,476 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + +""" +@brief Unit 
tests for kernel_splitter module. +""" + +import json +from unittest import mock + +import networkx as nx +import pytest +from assembler.common.dinst import DLoad, DStore +from assembler.instructions.xinst import NTT +from assembler.memory_model.variable import Variable +from assembler.stages.prep.kernel_splitter import KernelSplitter + + +@pytest.fixture +def splitter(): + """Provide a fresh KernelSplitter instance for each test.""" + return KernelSplitter() + + +@pytest.fixture +def sample_mem_file(tmp_path): + """Create a sample .mem file for testing.""" + mem_file = tmp_path / "test.mem" + mem_file.write_text( + """# Header line 1 +# Header line 2 +dload, poly, 0, var1 +dload, poly, 1, var2 +dload, ones, 2, ones_1 +dload, ntt_auxiliary_table, 0, ntt_auxiliary_table_0 +dload, twid, 4, twid_0 +dload, poly, 5, var3 +dload, poly, 6, var4 +dload, poly, 10, var5 +dload, poly, 11, var6 +dload, poly, 12, var7 +dload, poly, 12, var8 +dload, poly, 12, var9 +# After 12 vars - these are treated as kernel I/O +dload, poly, 7, input_var +dstore, output_var, 8 +dstore, result_var, 9 +""" + ) + return mem_file + + +class TestKernelSplitterInit: + """Tests for KernelSplitter initialization.""" + + def test_init_creates_empty_collections(self, splitter): + assert len(splitter.commons) == 0 + assert len(splitter.inputs) == 0 + assert len(splitter.outputs) == 0 + assert len(splitter._ext_vars) == 0 + + +class TestLoadMemFile: + """Tests for load_mem_file method.""" + + def test_load_valid_mem_file(self, splitter, sample_mem_file): + dinstrs = splitter.load_mem_file(sample_mem_file) + assert len(dinstrs) > 0 + assert any(isinstance(d, DLoad) for d in dinstrs) + assert any(isinstance(d, DStore) for d in dinstrs) + + def test_load_nonexistent_file_raises_error(self, splitter): + with pytest.raises(FileNotFoundError, match=".mem file not found"): + splitter.load_mem_file("nonexistent.mem") + + def test_load_identifies_commons(self, splitter, sample_mem_file): + 
splitter.load_mem_file(sample_mem_file) + assert "ones_1" in splitter.commons + assert "ntt_auxiliary_table_0" in splitter.commons + assert "twid_0" in splitter.commons + + def test_load_identifies_inputs_outputs(self, splitter, sample_mem_file): + splitter.load_mem_file(sample_mem_file) + assert "input_var" in splitter.inputs + assert "output_var" in splitter.outputs + assert "result_var" in splitter.outputs + + def test_load_invalid_line_raises_error(self, splitter, tmp_path): + bad_file = tmp_path / "bad.mem" + bad_file.write_text("invalid line format\n") + with pytest.raises(RuntimeError, match="No valid instruction found"): + splitter.load_mem_file(bad_file) + + +class TestCommonVariableDetection: + """Tests for common variable identification.""" + + def test_is_common_var_detects_common_prefix(self, splitter): + assert splitter._is_common_var("common_var") + assert splitter._is_common_var("ntt_table") + assert splitter._is_common_var("intt_data") + assert splitter._is_common_var("twid_factor") + assert splitter._is_common_var("ones_vector") + + def test_is_common_var_rejects_non_common(self, splitter): + assert not splitter._is_common_var("var1") + assert not splitter._is_common_var("input_data") + assert not splitter._is_common_var("") + + def test_register_common_var_adds_to_commons(self, splitter): + splitter._register_common_var("common_test") + assert "common_test" in splitter.commons + + def test_register_common_var_ignores_non_common(self, splitter): + splitter._register_common_var("regular_var") + assert "regular_var" not in splitter.commons + + +class TestBuildDependencyGraph: + """Tests for build_instrs_dependency_graph method.""" + + def test_build_empty_graph(self, splitter): + graph = splitter.build_instrs_dependency_graph([]) + assert graph.number_of_nodes() == 0 + assert graph.number_of_edges() == 0 + + def test_build_graph_creates_nodes(self, splitter): + mock_inst1 = mock.Mock(sources=[], dests=[]) + mock_inst2 = mock.Mock(sources=[], 
dests=[]) + graph = splitter.build_instrs_dependency_graph([mock_inst1, mock_inst2]) + assert graph.number_of_nodes() == 2 + assert 0 in graph.nodes + assert 1 in graph.nodes + + def test_build_graph_creates_edges_for_dependencies(self, splitter): + var1 = mock.Mock(name="var1") + var2 = mock.Mock(name="var2") + + inst1 = mock.Mock(sources=[], dests=[var1]) + inst2 = mock.Mock(sources=[var1], dests=[var2]) + inst3 = mock.Mock(sources=[var2], dests=[]) + + graph = splitter.build_instrs_dependency_graph([inst1, inst2, inst3]) + + assert graph.has_edge(0, 1) # inst1 -> inst2 + assert graph.has_edge(1, 2) # inst2 -> inst3 + + def test_build_graph_handles_ntt_special_case(self, splitter): + # NTT requires actual Variable objects for dest/src + dest1 = Variable("dest1", 0) + dest2 = Variable("dest2", 1) + src1 = Variable("src1", 0) + src2 = Variable("src2", 1) + src3 = Variable("src3", 2) + stage = 0 + res = 1 + + ntt1 = NTT(0, 2, [dest1, dest2], [src1, src2, src3], stage, res, 6, 6, "") + ntt2 = NTT(1, 2, [dest1, dest2], [src1, src2, src3], stage, res, 6, 6, "") + + graph = splitter.build_instrs_dependency_graph([ntt1, ntt2]) + + assert graph.has_edge(0, 1) + assert graph[0][1]["weight"] == 5 # Special weight for NTT chains + + def test_build_graph_tracks_external_vars(self, splitter): + splitter._inputs.add("input1") + var = mock.Mock() + var.name = "input1" # Set as attribute, not Mock + inst = mock.Mock(sources=[var], dests=[]) + + splitter.build_instrs_dependency_graph([inst]) + + assert 0 in splitter._ext_vars + assert "input1" in splitter._ext_vars[0] + + +class TestGetExternalMemUsage: + """Tests for _get_external_mem_usage method.""" + + def test_empty_instruction_set_returns_zero(self, splitter): + mem_usage, ext_vars = splitter._get_external_mem_usage(set()) + assert mem_usage == 0 + assert len(ext_vars) == 0 + + def test_calculates_external_var_usage(self, splitter): + splitter._ext_vars[0] = {"var1", "var2"} + splitter._ext_vars[1] = {"var2", "var3"} + + 
mem_usage, ext_vars = splitter._get_external_mem_usage({0, 1}) + + assert mem_usage == 3 # var1, var2, var3 + assert ext_vars == {"var1", "var2", "var3"} + + def test_includes_commons_in_calculation(self, splitter): + splitter._commons = {"common1", "common2"} + splitter._ext_vars[0] = {"var1"} + + mem_usage, ext_vars = splitter._get_external_mem_usage({0}) + + assert mem_usage == 3 # var1 + 2 commons + assert "common1" not in ext_vars # Commons not in external vars + assert "common2" not in ext_vars + + def test_custom_var_size_map(self, splitter): + splitter._ext_vars[0] = {"large_var"} + var_size_map = {"large_var": 10} + + mem_usage, _ = splitter._get_external_mem_usage({0}, var_size_map) + + assert mem_usage == 10 + + +class TestGetIsolatedInstrsSplits: + """Tests for get_isolated_instrs_splits method.""" + + def test_empty_graph_returns_empty_splits(self, splitter): + graph = nx.DiGraph() + splits, externals = splitter.get_isolated_instrs_splits(graph, 100, 100) + assert splits == [] + assert externals == [] + + def test_single_component_under_limits(self, splitter): + graph = nx.DiGraph() + graph.add_nodes_from([0, 1, 2]) + graph.add_edge(0, 1) + graph.add_edge(1, 2) + + splitter._ext_vars = {0: {"var1"}, 1: {"var2"}, 2: {"var3"}} + + splits, externals = splitter.get_isolated_instrs_splits(graph, 10, 10) + + assert len(splits) == 1 + assert splits[0] == {0, 1, 2} + + def test_multiple_isolated_components(self, splitter): + graph = nx.DiGraph() + graph.add_nodes_from([0, 1, 2, 3]) + graph.add_edge(0, 1) # Component 1 + graph.add_edge(2, 3) # Component 2 + + splitter._ext_vars = {i: {f"var{i}"} for i in range(4)} + + splits, externals = splitter.get_isolated_instrs_splits(graph, 10, 10) + + assert len(splits) >= 1 + + def test_exceeds_instruction_limit_returns_none(self, splitter): + graph = nx.DiGraph() + graph.add_nodes_from(range(10)) + for i in range(9): + graph.add_edge(i, i + 1) + + splitter._ext_vars = {i: {f"var{i}"} for i in range(10)} + + splits, 
externals = splitter.get_isolated_instrs_splits(graph, 5, 100) + + assert splits is None + assert externals is None + + +class TestCommunityDetection: + """Tests for get_community_instrs_splits method.""" + + def test_detects_communities_in_graph(self, splitter): + graph = nx.DiGraph() + # Create two weakly connected components + graph.add_edges_from([(0, 1), (1, 2)]) + graph.add_edges_from([(3, 4), (4, 5)]) + + splitter._ext_vars = {i: {f"var{i}"} for i in range(6)} + + result = splitter.get_community_instrs_splits(graph, 10, 10) + + assert result is not None + splits, externals, out_refs = result + assert len(splits) >= 1 + + +class TestSplitMemInfo: + """Tests for split_mem_info method.""" + + def test_splits_mem_file_into_multiple_files(self, splitter, tmp_path): + mem_file = tmp_path / "test.mem" + mem_file.write_text("dload poly 0 var1\ndstore var1 1\n") + + dinstr1 = DLoad(["dload", "poly", "0", "var1"], "") + dinstr2 = DStore(["dstore", "var1", "1"], "") + + externals = [{"var1"}] + + splitter.split_mem_info(mem_file, [dinstr1, dinstr2], externals) + + expected_file = tmp_path / "test_0.mem" + assert expected_file.exists() + + def test_writes_dependency_maps(self, splitter, tmp_path): + mem_file = tmp_path / "test.mem" + mem_file.write_text("dload poly 0 var1\n") + + dinstr = DLoad(["dload", "poly", "0", "var1"], "") + externals = [{"var1"}, {"var1"}] + + splitter.split_mem_info(mem_file, [dinstr], externals) + + dep_file = tmp_path / "test_deps_0.json" + assert dep_file.exists() + + with dep_file.open() as f: + deps = json.load(f) + assert isinstance(deps, dict) + + +class TestRenameVarsInSplits: + """Tests for rename_vars_in_splits method.""" + + def test_renames_boundary_variables(self, splitter): + var1 = Variable("var1", 1) + var2 = Variable("var2", 0) + + inst1 = mock.Mock(sources=[], dests=[var1]) + inst2 = mock.Mock(sources=[var1], dests=[var2]) + + instr_sets = [{0}, {1}] + out_refs = [{("var1", 0): {1: {1}}}] + + new_outs = 
splitter.rename_vars_in_splits([inst1, inst2], instr_sets, out_refs) + + assert len(new_outs) == 1 + # Verify variable was renamed + renamed_var = inst1.dests[0] + assert renamed_var.name.startswith("var1_dep_0_0") + + def test_handles_empty_out_refs(self, splitter): + inst1 = mock.Mock(sources=[], dests=[]) + instr_sets = [{0}] + out_refs = [] + + new_outs = splitter.rename_vars_in_splits([inst1], instr_sets, out_refs) + + assert new_outs == [] + + +class TestPrepareInstructionSplits: + """Tests for prepare_instruction_splits method.""" + + def test_full_splitting_pipeline(self, splitter, sample_mem_file, monkeypatch): + # Mock args + args = mock.Mock( + mem_file=str(sample_mem_file), + output_file_name="output.pisa", + split_inst_limit=10, + split_vars_limit=10, + ) + + # Create mock instructions + var1 = Variable("var1", 1) + inst1 = mock.Mock(sources=[], dests=[var1]) + inst1.to_pisa_format = mock.Mock(return_value="inst1") + inst2 = mock.Mock(sources=[var1], dests=[]) + inst2.to_pisa_format = mock.Mock(return_value="inst2") + + insts_listing = [inst1, inst2] + + # Mock verbose output + monkeypatch.setattr("assembler.common.config.GlobalConfig.debugVerbose", 0) + + result = splitter.prepare_instruction_splits(args, insts_listing) + + assert isinstance(result, list) + assert len(result) > 0 + for split_insts, output_path in result: + assert isinstance(split_insts, list) + assert isinstance(output_path, str) + + +class TestCommunityGraphBuilding: + """Tests for community graph building helpers.""" + + def test_map_instr_to_community(self, splitter): + communities = [{0, 1}, {2, 3}, {4}] + mapping = splitter._map_instr_to_community(communities) + + assert mapping[0] == 0 + assert mapping[1] == 0 + assert mapping[2] == 1 + assert mapping[3] == 1 + assert mapping[4] == 2 + + def test_build_community_graph_creates_edges(self, splitter): + base_graph = nx.DiGraph() + # Add all nodes that will be referenced + base_graph.add_nodes_from([0, 1, 2, 3]) + 
base_graph.add_edge(0, 2, vars={"shared_var"}) + + communities = [{0, 1}, {2, 3}] + instr_to_set = {0: 0, 1: 0, 2: 1, 3: 1} + + comm_graph = splitter._build_community_graph(base_graph, communities, instr_to_set) + + assert comm_graph.has_edge(0, 1) + assert "var_refs" in comm_graph[0][1] + + def test_condense_removes_cycles(self, splitter): + graph = nx.DiGraph() + graph.add_edges_from([(0, 1), (1, 2), (2, 0)]) # Cycle + + communities = [{0}, {1}, {2}] + + new_sets, condensed = splitter._condense_community_graph(graph, communities) + + assert nx.is_directed_acyclic_graph(condensed) + + +class TestMergingStrategies: + """Tests for in-generation and cross-generation merging.""" + + def test_rank_pairs_by_common_io(self, splitter): + graph = nx.DiGraph() + graph.add_edges_from( + [ + (10, 0, {"weight": 5}), + (10, 1, {"weight": 3}), + (0, 20, {"weight": 2}), + (1, 20, {"weight": 4}), + ] + ) + + generation = [0, 1] + ranked = splitter._rank_generation_pairs_by_common_io(graph, generation) + + assert (0, 1) in ranked + assert ranked[(0, 1)] > 0 # Should have common pred/succ + + def test_rebuild_generation_pairs_after_merge(self, splitter): + cluster_sets = {0: {0, 1}, 2: {2, 3}} + ranked_pairs = {(0, 2): 5.0, (1, 2): 3.0, (0, 3): 2.0} + + rebuilt = splitter._rebuild_generation_pairs(cluster_sets, ranked_pairs) + + assert (0, 2) in rebuilt + # Score should aggregate individual member scores + assert rebuilt[(0, 2)] == 5.0 + 3.0 + 2.0 + + +class TestEdgeCases: + """Tests for edge cases and error conditions.""" + + def test_handles_instructions_without_sources_or_dests(self, splitter): + inst = mock.Mock(sources=[], dests=[]) + graph = splitter.build_instrs_dependency_graph([inst]) + assert graph.number_of_nodes() == 1 + + def test_handles_variables_without_names(self, splitter): + var_no_name = mock.Mock(name=None) + inst = mock.Mock(sources=[var_no_name], dests=[]) + graph = splitter.build_instrs_dependency_graph([inst]) + assert graph.number_of_edges() == 0 + + def 
test_split_with_no_available_splits_raises_error(self, splitter, sample_mem_file): + args = mock.Mock( + mem_file=str(sample_mem_file), + output_file_name="output.pisa", + split_inst_limit=1, # Impossible limit + split_vars_limit=1, + ) + + var1 = Variable("var1", 0) + inst1 = mock.Mock(sources=[], dests=[var1]) + inst2 = mock.Mock(sources=[var1], dests=[]) + + with pytest.raises(RuntimeError, match="Final instruction splits do not cover all instructions."): + splitter.prepare_instruction_splits(args, [inst1, inst2]) diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_as.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_as.py new file mode 100644 index 00000000..9bc6cfe2 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_as.py @@ -0,0 +1,247 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +# These contents may have been developed with support from one or more Intel-operated +# generative artificial intelligence solutions + +import os +import sys +from unittest import mock + +import he_as +import pytest + + +def test_parse_args_parses_all_flags(monkeypatch): + test_args = [ + "prog", + "kernel.tw", + "--isa_spec", + "isa.json", + "--mem_spec", + "mem.json", + "--input_mem_file", + "custom.mem", + "--output_dir", + "out", + "--output_prefix", + "pref", + "--spad_size", + "32", + "--hbm_size", + "64", + "--no_hbm", + "--repl_policy", + he_as.constants.Constants.REPLACEMENT_POLICIES[0], + "--use_xinstfetch", + "--suppress_comments", + "-vv", + ] + monkeypatch.setattr(sys, "argv", test_args) + args = he_as.parse_args() + + assert args.input_file == "kernel.tw" + assert args.isa_spec_file == "isa.json" + assert args.mem_spec_file == "mem.json" + assert args.input_mem_file == "custom.mem" + assert args.output_dir == "out" + assert args.output_prefix == "pref" + assert args.spad_size == 32 + assert args.hbm_size == 64 + assert args.has_hbm is False + assert 
args.repl_policy == he_as.constants.Constants.REPLACEMENT_POLICIES[0] + assert args.use_xinstfetch is True + assert args.suppress_comments is True + assert args.debug_verbose == 2 + + +def test_run_config_derives_mem_file(tmp_path): + input_file = tmp_path / "kernel.tw" + input_file.write_text("") + config = he_as.AssemblerRunConfig(input_file=str(input_file)) + + assert config.input_file == str(input_file) + assert config.input_mem_file == str(input_file.with_suffix(".mem")) + assert config.output_dir == os.path.dirname(os.path.realpath(str(input_file))) + assert config.input_prefix == "kernel" + + +def test_main_invokes_assembler_and_creates_outputs(tmp_path, monkeypatch): + input_file = tmp_path / "kernel.tw" + input_file.write_text("dummy") + mem_file = tmp_path / "kernel.mem" + mem_file.write_text("mem") + output_dir = tmp_path / "outputs" + config = he_as.AssemblerRunConfig( + input_file=str(input_file), + input_mem_file=str(mem_file), + output_dir=str(output_dir), + output_prefix="result", + ) + + asm_mock = mock.Mock(return_value=(1, 2, 3, 4, 5)) + monkeypatch.setattr(he_as, "asmisaAssemble", asm_mock) + he_as.main(config, verbose=False) + + assert asm_mock.called + copied_config = asm_mock.call_args.args[0] + assert copied_config is not config + for ext in ("minst", "cinst", "xinst"): + assert (tmp_path / f"outputs/result.{ext}").is_file() + + +def test_run_config_requires_input_file(): + """Test that AssemblerRunConfig raises TypeError when input_file is missing.""" + with pytest.raises(TypeError, match="Expected value for configuration `input_file`"): + he_as.AssemblerRunConfig() + + +def test_run_config_defaults(tmp_path): + """Test that AssemblerRunConfig sets sensible defaults.""" + input_file = tmp_path / "kernel.tw" + input_file.write_text("") + config = he_as.AssemblerRunConfig(input_file=str(input_file)) + + assert config.has_hbm is True + assert config.hbm_size == he_as.AssemblerRunConfig.DEFAULT_HBM_SIZE_KB + assert config.spad_size == 
he_as.AssemblerRunConfig.DEFAULT_SPAD_SIZE_KB + assert config.repl_policy == he_as.AssemblerRunConfig.DEFAULT_REPL_POLICY + + +def test_run_config_custom_output_dir(tmp_path): + """Test that custom output_dir is respected.""" + input_file = tmp_path / "kernel.tw" + input_file.write_text("") + custom_dir = tmp_path / "custom_output" + + config = he_as.AssemblerRunConfig( + input_file=str(input_file), + output_dir=str(custom_dir), + ) + + assert config.output_dir == str(custom_dir) + + +def test_main_creates_output_directory(tmp_path, monkeypatch): + """Test that main creates output directory if it doesn't exist.""" + input_file = tmp_path / "kernel.tw" + input_file.write_text("dummy") + mem_file = tmp_path / "kernel.mem" + mem_file.write_text("mem") + output_dir = tmp_path / "new_outputs" + + config = he_as.AssemblerRunConfig( + input_file=str(input_file), + input_mem_file=str(mem_file), + output_dir=str(output_dir), + ) + + asm_mock = mock.Mock(return_value=(1, 2, 3, 4, 5)) + monkeypatch.setattr(he_as, "asmisaAssemble", asm_mock) + + assert not output_dir.exists() + he_as.main(config, verbose=False) + assert output_dir.exists() + + +def test_main_output_not_writable_raises_exception(tmp_path, monkeypatch): + """Test that main raises exception when output location is not writable.""" + input_file = tmp_path / "kernel.tw" + input_file.write_text("dummy") + mem_file = tmp_path / "kernel.mem" + mem_file.write_text("mem") + output_dir = tmp_path / "outputs" + output_dir.mkdir() + os.chmod(output_dir, 0o444) # Read-only + + config = he_as.AssemblerRunConfig( + input_file=str(input_file), + input_mem_file=str(mem_file), + output_dir=str(output_dir), + ) + + with pytest.raises(Exception, match="Failed to write to output location"): + he_as.main(config, verbose=False) + + +def test_main_sets_global_config(tmp_path, monkeypatch): + """Test that main correctly sets GlobalConfig values.""" + input_file = tmp_path / "kernel.tw" + input_file.write_text("dummy") + mem_file = 
tmp_path / "kernel.mem" + mem_file.write_text("mem") + + config = he_as.AssemblerRunConfig( + input_file=str(input_file), + input_mem_file=str(mem_file), + has_hbm=False, + use_xinstfetch=True, + suppress_comments=True, + debug_verbose=2, + ) + + asm_mock = mock.Mock(return_value=(1, 2, 3, 4, 5)) + monkeypatch.setattr(he_as, "asmisaAssemble", asm_mock) + + he_as.main(config, verbose=False) + + assert he_as.GlobalConfig.hasHBM is False + assert he_as.GlobalConfig.useXInstFetch is True + assert he_as.GlobalConfig.suppress_comments is True + assert he_as.GlobalConfig.debugVerbose == 2 + + +def test_main_verbose_output(tmp_path, monkeypatch, capsys): + """Test that main prints verbose output when enabled.""" + input_file = tmp_path / "kernel.tw" + input_file.write_text("dummy") + mem_file = tmp_path / "kernel.mem" + mem_file.write_text("mem") + + config = he_as.AssemblerRunConfig( + input_file=str(input_file), + input_mem_file=str(mem_file), + ) + + asm_mock = mock.Mock(return_value=(10, 2, 5, 1.5, 2.5)) + monkeypatch.setattr(he_as, "asmisaAssemble", asm_mock) + + he_as.main(config, verbose=True) + + captured = capsys.readouterr() + assert "Output:" in captured.out + assert "Total XInstructions: 10" in captured.out + assert "Deps time: 1.5" in captured.out + assert "Scheduling time: 2.5" in captured.out + assert "Minimum idle cycles: 5" in captured.out + assert "Minimum nops required: 2" in captured.out + + +def test_config_as_dict(tmp_path): + """Test that AssemblerRunConfig.as_dict returns all config values.""" + input_file = tmp_path / "kernel.tw" + input_file.write_text("") + + config = he_as.AssemblerRunConfig( + input_file=str(input_file), + has_hbm=False, + hbm_size=128, + ) + + config_dict = config.as_dict() + assert "input_file" in config_dict + assert "has_hbm" in config_dict + assert config_dict["has_hbm"] is False + assert config_dict["hbm_size"] == 128 + + +def test_config_str_representation(tmp_path): + """Test that AssemblerRunConfig has a string 
representation.""" + input_file = tmp_path / "kernel.tw" + input_file.write_text("") + + config = he_as.AssemblerRunConfig(input_file=str(input_file)) + config_str = str(config) + + assert "input_file" in config_str + assert str(input_file) in config_str diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py index a0387c30..501ef44f 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py @@ -8,6 +8,7 @@ @brief Unit tests for he_prep module. """ +import io import os import pathlib import sys @@ -17,6 +18,22 @@ import pytest +def _make_args(**overrides): + defaults = { + "input_file_name": "", + "output_file_name": "", + "mem_file": "", + "verbose": 0, + "split_on": False, + "split_inst_limit": float("inf"), + "split_vars_limit": float("inf"), + "strategy": "largest_first", + "interchange": False, + } + defaults.update(overrides) + return mock.Mock(**defaults) + + def test_main_assigns_and_saves(monkeypatch, tmp_path): """ @brief Test that the main function assigns register banks, processes instructions, and saves the output. @@ -41,16 +58,9 @@ def test_main_assigns_and_saves(monkeypatch, tmp_path): monkeypatch.setattr(he_prep.preprocessor, "assign_register_banks_to_vars", mock.Mock()) he_prep.main( - mock.Mock( - **{ - "input_file_name": str(input_file), - "output_file_name": str(output_file), - "mem_file": "", - "verbose": 0, - "split_on": False, - "split_inst_limit": float("inf"), - "split_vars_limit": float("inf"), - } + _make_args( + input_file_name=str(input_file), + output_file_name=str(output_file), ) ) # Output file should contain the instruction @@ -62,19 +72,7 @@ def test_main_no_input_file(): @brief Test that main raises an error when no input file is provided. 
""" with pytest.raises(FileNotFoundError): - he_prep.main( - mock.Mock( - **{ - "input_file_name": "", - "output_file_name": "", - "mem_file": "", - "verbose": 0, - "split_on": False, - "split_inst_limit": float("inf"), - "split_vars_limit": float("inf"), - } - ) - ) # Should raise an error due to missing input file + he_prep.main(_make_args()) # Should raise an error due to missing input file def test_main_no_output_file(): @@ -82,19 +80,7 @@ def test_main_no_output_file(): @brief Test that main raises an error when no output file is provided. """ with pytest.raises(FileNotFoundError): - he_prep.main( - mock.Mock( - **{ - "input_file_name": "input.csv", - "output_file_name": "", - "mem_file": "", - "verbose": 0, - "split_on": False, - "split_inst_limit": float("inf"), - "split_vars_limit": float("inf"), - } - ) - ) # Should raise an error due to missing output file + he_prep.main(_make_args(input_file_name="input.csv")) # Should raise an error due to missing output file def test_main_no_instructions(monkeypatch): @@ -119,16 +105,9 @@ def test_main_no_instructions(monkeypatch): monkeypatch.setattr(he_prep.preprocessor, "assign_register_banks_to_vars", mock.Mock()) he_prep.main( - mock.Mock( - **{ - "input_file_name": input_file, - "output_file_name": output_file, - "mem_file": "", - "verbose": 0, - "split_on": False, - "split_inst_limit": float("inf"), - "split_vars_limit": float("inf"), - } + _make_args( + input_file_name=input_file, + output_file_name=output_file, ) ) @@ -146,16 +125,9 @@ def test_main_invalid_input_file(tmp_path): with pytest.raises(FileNotFoundError): he_prep.main( - mock.Mock( - **{ - "input_file_name": str(input_file), - "output_file_name": str(output_file), - "mem_file": "", - "verbose": 0, - "split_on": False, - "split_inst_limit": float("inf"), - "split_vars_limit": float("inf"), - } + _make_args( + input_file_name=str(input_file), + output_file_name=str(output_file), ) ) # Should raise an error due to missing input file @@ -176,20 +148,49 
@@ def test_main_invalid_output_file(tmp_path): with pytest.raises(PermissionError): he_prep.main( - mock.Mock( - **{ - "input_file_name": str(input_file), - "output_file_name": str(output_file), - "mem_file": "", - "verbose": 0, - "split_on": False, - "split_inst_limit": float("inf"), - "split_vars_limit": float("inf"), - } + _make_args( + input_file_name=str(input_file), + output_file_name=str(output_file), ) ) # Should raise an error due to permission issues +def test_main_respects_coloring_strategy(monkeypatch, tmp_path): + """ + @brief Test that main respects the coloring strategy and interchange options. + + @details This test verifies that the assigned register banks strategy and interchange options + are correctly passed to the assign_register_banks_to_vars function. + """ + input_file = tmp_path / "input.csv" + input_file.write_text("dummy") + output_file = tmp_path / "output.csv" + dummy_model = object() + dummy_insts = [mock.Mock(to_pisa_format=mock.Mock(return_value="inst"))] + monkeypatch.setattr(he_prep, "MemoryModel", mock.Mock(return_value=dummy_model)) + monkeypatch.setattr( + he_prep.preprocessor, + "preprocess_pisa_kernel_listing", + mock.Mock(return_value=dummy_insts), + ) + assign_mock = mock.Mock() + monkeypatch.setattr(he_prep.preprocessor, "assign_register_banks_to_vars", assign_mock) + + he_prep.main( + _make_args( + input_file_name=str(input_file), + output_file_name=str(output_file), + strategy="smallest_last", + interchange=True, + ) + ) + + assign_mock.assert_called_once() + _, kwargs = assign_mock.call_args + assert kwargs["strategy"] == "smallest_last" + assert kwargs["interchange"] is True + + def test_parse_args(): """ @brief Test that parse_args returns the expected arguments. 
@@ -202,13 +203,235 @@ def test_parse_args(): "isa.json", "--mem_spec", "mem.json", - "--verbose", + "--mem_file", + "kernel.mem", + "--split_vars_limit", + "10", + "--split_inst_limit", + "5", + "--strategy", + "smallest_last", + "--interchange", + "-vv", ] with mock.patch.object(sys, "argv", test_args): args = he_prep.parse_args() - assert args.output_file_name == "output.csv" assert args.input_file_name == "input.csv" + assert args.output_file_name == "output.csv" assert args.isa_spec_file == "isa.json" assert args.mem_spec_file == "mem.json" - assert args.verbose == 1 + assert args.mem_file == "kernel.mem" + assert args.split_vars_limit == 10.0 + assert args.split_inst_limit == 5.0 + assert args.split_on is True + assert args.strategy == "smallest_last" + assert args.interchange is True + assert args.verbose == 2 + + +def test_save_pisa_listing(): + """ + @brief Test that save_pisa_listing writes instructions in correct format. + """ + mock_inst1 = mock.Mock(to_pisa_format=mock.Mock(return_value="instruction1")) + mock_inst2 = mock.Mock(to_pisa_format=mock.Mock(return_value="instruction2")) + mock_inst3 = mock.Mock(to_pisa_format=mock.Mock(return_value="")) # Empty line should be skipped + + output = io.StringIO() + he_prep.save_pisa_listing(output, [mock_inst1, mock_inst2, mock_inst3]) + + result = output.getvalue() + assert "instruction1\n" in result + assert "instruction2\n" in result + assert result.count("\n") == 2 # Only 2 instructions written + + +def test_main_derives_default_output_filename(monkeypatch, tmp_path): + """ + @brief Test that main derives output filename from input when not provided. 
+ """ + input_file = tmp_path / "kernel.pisa" + input_file.write_text("dummy") + + dummy_model = object() + dummy_insts = [mock.Mock(to_pisa_format=mock.Mock(return_value="inst"))] + + monkeypatch.setattr(he_prep, "MemoryModel", mock.Mock(return_value=dummy_model)) + monkeypatch.setattr( + he_prep.preprocessor, + "preprocess_pisa_kernel_listing", + mock.Mock(return_value=dummy_insts), + ) + monkeypatch.setattr(he_prep.preprocessor, "assign_register_banks_to_vars", mock.Mock()) + + he_prep.main( + _make_args( + input_file_name=str(input_file), + output_file_name="", # Not provided + ) + ) + + # Should create kernel.tw.pisa + expected_output = tmp_path / "kernel.tw.pisa" + assert expected_output.exists() + + +def test_main_with_kernel_splitting(monkeypatch, tmp_path): + """ + @brief Test that main handles kernel splitting when split_on is True. + """ + input_file = tmp_path / "kernel.pisa" + input_file.write_text("dummy") + output_file = tmp_path / "output.pisa" + + dummy_model = object() + dummy_insts = [mock.Mock(to_pisa_format=mock.Mock(return_value="inst"))] + + monkeypatch.setattr(he_prep, "MemoryModel", mock.Mock(return_value=dummy_model)) + monkeypatch.setattr( + he_prep.preprocessor, + "preprocess_pisa_kernel_listing", + mock.Mock(return_value=dummy_insts), + ) + monkeypatch.setattr(he_prep.preprocessor, "assign_register_banks_to_vars", mock.Mock()) + + # Mock KernelSplitter + mock_splitter = mock.Mock() + split_file1 = tmp_path / "split1.pisa" + split_file2 = tmp_path / "split2.pisa" + mock_splitter.prepare_instruction_splits.return_value = [ + (dummy_insts, str(split_file1)), + (dummy_insts, str(split_file2)), + ] + monkeypatch.setattr(he_prep, "KernelSplitter", mock.Mock(return_value=mock_splitter)) + + he_prep.main( + _make_args( + input_file_name=str(input_file), + output_file_name=str(output_file), + split_on=True, + split_inst_limit=10.0, + split_vars_limit=5.0, + ) + ) + + # Verify splitter was called + 
mock_splitter.prepare_instruction_splits.assert_called_once() + # Verify both split files were created + assert split_file1.exists() + assert split_file2.exists() + + +def test_main_verbose_output(monkeypatch, tmp_path, capsys): + """ + @brief Test that main prints verbose output when verbose flag is set. + """ + input_file = tmp_path / "kernel.pisa" + input_file.write_text("dummy") + output_file = tmp_path / "output.pisa" + + dummy_model = object() + dummy_insts = [mock.Mock(to_pisa_format=mock.Mock(return_value="inst"))] * 5 + + monkeypatch.setattr(he_prep, "MemoryModel", mock.Mock(return_value=dummy_model)) + monkeypatch.setattr( + he_prep.preprocessor, + "preprocess_pisa_kernel_listing", + mock.Mock(return_value=dummy_insts), + ) + monkeypatch.setattr(he_prep.preprocessor, "assign_register_banks_to_vars", mock.Mock()) + + he_prep.main( + _make_args( + input_file_name=str(input_file), + output_file_name=str(output_file), + verbose=1, + ) + ) + + captured = capsys.readouterr() + assert "Assigning register banks to variables..." in captured.out + assert "Instructions in input: 5" in captured.out + assert "Saving..." in captured.out + assert "Output:" in captured.out + assert "Instructions in output: 5" in captured.out + assert "Generation time:" in captured.out + + +def test_main_defaults_strategy_and_interchange(monkeypatch, tmp_path): + """ + @brief Test that main uses default values for strategy and interchange when not provided. 
+ """ + input_file = tmp_path / "kernel.pisa" + input_file.write_text("dummy") + output_file = tmp_path / "output.pisa" + + dummy_model = object() + dummy_insts = [mock.Mock(to_pisa_format=mock.Mock(return_value="inst"))] + + monkeypatch.setattr(he_prep, "MemoryModel", mock.Mock(return_value=dummy_model)) + monkeypatch.setattr( + he_prep.preprocessor, + "preprocess_pisa_kernel_listing", + mock.Mock(return_value=dummy_insts), + ) + assign_mock = mock.Mock() + monkeypatch.setattr(he_prep.preprocessor, "assign_register_banks_to_vars", assign_mock) + + # Create args without strategy/interchange using argparse.Namespace + # This properly raises AttributeError for missing attributes + args = type( + "Args", + (), + { + "input_file_name": str(input_file), + "output_file_name": str(output_file), + "mem_file": "", + "verbose": 0, + "split_on": False, + }, + )() + + he_prep.main(args) + + assign_mock.assert_called_once() + _, kwargs = assign_mock.call_args + assert kwargs["strategy"] == "largest_first" # Default + assert kwargs["interchange"] is False # Default + + +def test_parse_args_defaults(): + """ + @brief Test that parse_args sets correct defaults when optional args not provided. + """ + test_args = ["prog", "input.csv"] + with mock.patch.object(sys, "argv", test_args): + args = he_prep.parse_args() + + assert args.input_file_name == "input.csv" + assert args.output_file_name is None + assert args.isa_spec_file == "" + assert args.mem_spec_file == "" + assert args.mem_file == "" + assert args.split_vars_limit == float("inf") + assert args.split_inst_limit == float("inf") + assert args.split_on is False + assert args.strategy == "largest_first" + assert args.interchange is False + assert args.verbose == 0 + + +def test_parse_args_split_on_without_mem_file_fails(): + """ + @brief Test that parse_args raises assertion error when split_on but no mem_file. 
+ """ + test_args = [ + "prog", + "input.csv", + "--split_inst_limit", + "10", + ] + with mock.patch.object(sys, "argv", test_args): + with pytest.raises(AssertionError, match="--mem_file must be specified"): + he_prep.parse_args() From 1c47251c8c2696ca9b1647f4236623dcc65463ce Mon Sep 17 00:00:00 2001 From: "Rojas Chaves, Jose" Date: Fri, 14 Nov 2025 23:37:27 +0000 Subject: [PATCH 2/2] Revert changes in he_prep --- .../hec-assembler-tools/he_prep.py | 9 +--- .../tests/unit_tests/test_he_prep.py | 42 ------------------- 2 files changed, 1 insertion(+), 50 deletions(-) diff --git a/assembler_tools/hec-assembler-tools/he_prep.py b/assembler_tools/hec-assembler-tools/he_prep.py index 02c72c5e..8c2fd548 100644 --- a/assembler_tools/hec-assembler-tools/he_prep.py +++ b/assembler_tools/hec-assembler-tools/he_prep.py @@ -74,9 +74,6 @@ def main(args): - interchange """ - strategy = getattr(args, "strategy", "largest_first") - interchange = getattr(args, "interchange", False) - GlobalConfig.debugVerbose = args.verbose # used for timings @@ -100,11 +97,7 @@ def main(args): if args.verbose > 0: print("Assigning register banks to variables...") preprocessor.assign_register_banks_to_vars( - hec_mem_model, - insts_listing, - use_bank0=False, - strategy=strategy, - interchange=interchange, + hec_mem_model, insts_listing, use_bank0=False, strategy=args.strategy, interchange=args.interchange ) # Determine output file name diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py index 501ef44f..55ced0d4 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py @@ -359,48 +359,6 @@ def test_main_verbose_output(monkeypatch, tmp_path, capsys): assert "Generation time:" in captured.out -def test_main_defaults_strategy_and_interchange(monkeypatch, tmp_path): - """ - @brief Test that main uses 
default values for strategy and interchange when not provided. - """ - input_file = tmp_path / "kernel.pisa" - input_file.write_text("dummy") - output_file = tmp_path / "output.pisa" - - dummy_model = object() - dummy_insts = [mock.Mock(to_pisa_format=mock.Mock(return_value="inst"))] - - monkeypatch.setattr(he_prep, "MemoryModel", mock.Mock(return_value=dummy_model)) - monkeypatch.setattr( - he_prep.preprocessor, - "preprocess_pisa_kernel_listing", - mock.Mock(return_value=dummy_insts), - ) - assign_mock = mock.Mock() - monkeypatch.setattr(he_prep.preprocessor, "assign_register_banks_to_vars", assign_mock) - - # Create args without strategy/interchange using argparse.Namespace - # This properly raises AttributeError for missing attributes - args = type( - "Args", - (), - { - "input_file_name": str(input_file), - "output_file_name": str(output_file), - "mem_file": "", - "verbose": 0, - "split_on": False, - }, - )() - - he_prep.main(args) - - assign_mock.assert_called_once() - _, kwargs = assign_mock.call_args - assert kwargs["strategy"] == "largest_first" # Default - assert kwargs["interchange"] is False # Default - - def test_parse_args_defaults(): """ @brief Test that parse_args sets correct defaults when optional args not provided.