@@ -707,8 +707,7 @@ def __init__(self, netlist: _nir.Netlist, design: Design, *, all_undef_to_ff=Fal
707707 self .drivers = _ast .SignalDict ()
708708 self .io_ports : dict [_ast .IOPort , int ] = {}
709709 self .rhs_cache : dict [int , tuple [_nir .Value , bool , _ast .Value ]] = {}
710- self .matches_cache = {}
711- self .priority_match_cache = {}
710+ self .match_cache = {}
712711 self .fragment_module_idx : dict [Fragment , int ] = {}
713712
714713 # Collected for driver conflict diagnostics only.
@@ -787,24 +786,14 @@ def emit_operator(self, module_idx: int, operator: str, *inputs: _nir.Value, src
787786 op = _nir .Operator (module_idx , operator = operator , inputs = inputs , src_loc = src_loc )
788787 return self .netlist .add_value_cell (op .width , op )
789788
790- def emit_matches (self , module_idx : int , value : _nir .Value , patterns , * , src_loc ):
791- key = module_idx , value , patterns , src_loc
789+ def emit_match (self , module_idx : int , en : _nir . Net , value : _nir .Value , patterns , * , src_loc ):
790+ key = module_idx , en , value , patterns , src_loc
792791 try :
793- return self .matches_cache [key ]
792+ return self .match_cache [key ]
794793 except KeyError :
795- cell = _nir .Matches (module_idx , value = value , patterns = patterns , src_loc = src_loc )
796- net , = self .netlist .add_value_cell (1 , cell )
797- self .matches_cache [key ] = net
798- return net
799-
800- def emit_priority_match (self , module_idx : int , en : _nir .Net , inputs : _nir .Value , * , src_loc ):
801- key = module_idx , en , inputs , src_loc
802- try :
803- return self .priority_match_cache [key ]
804- except KeyError :
805- cell = _nir .PriorityMatch (module_idx , en = en , inputs = inputs , src_loc = src_loc )
806- res = self .netlist .add_value_cell (len (inputs ), cell )
807- self .priority_match_cache [key ] = res
794+ cell = _nir .Match (module_idx , en = en , value = value , patterns = patterns , src_loc = src_loc )
795+ res = self .netlist .add_value_cell (len (patterns ), cell )
796+ self .match_cache [key ] = res
808797 return res
809798
810799 def unify_shapes_bitwise (self ,
@@ -956,17 +945,16 @@ def emit_rhs(self, module_idx: int, value: _ast.Value) -> tuple[_nir.Value, bool
956945 result = self .emit_operator (module_idx , 'm' , test , operand_a , operand_b ,
957946 src_loc = value .src_loc )
958947 else :
959- conds = []
960948 elems = []
961- for patterns , elem , in value . cases :
962- if patterns is not None :
963- net = self . emit_matches ( module_idx , test , patterns , src_loc = value . src_loc )
964- conds .append (net )
949+ patterns = []
950+ for pattern_list , elem , in value . cases :
951+ if pattern_list is not None :
952+ patterns .append (pattern_list )
965953 else :
966- conds .append (_nir . Net . from_const ( 1 ))
954+ patterns .append (( "-" * len ( test ), ))
967955 elems .append (self .emit_rhs (module_idx , elem ))
968- conds = self .emit_priority_match (module_idx , _nir .Net .from_const (1 ),
969- _nir . Value ( conds ), src_loc = value .src_loc )
956+ conds = self .emit_match (module_idx , _nir .Net .from_const (1 ), test , tuple ( patterns ),
957+ src_loc = value .src_loc )
970958 shape = _ast .Shape ._unify (
971959 _ast .Shape (len (value ), signed )
972960 for value , signed in elems
@@ -1056,14 +1044,10 @@ def emit_assign(self, module_idx: int, cd: "_cd.ClockDomain | None", lhs: _ast.V
10561044 offset , _signed = self .emit_rhs (module_idx , lhs .offset )
10571045 width = len (lhs .value )
10581046 num_cases = min ((width + lhs .stride - 1 ) // lhs .stride , 1 << len (offset ))
1059- conds = []
1047+ patterns = []
10601048 for case_index in range (num_cases ):
1061- subcond = self .emit_matches (module_idx , offset ,
1062- (to_binary (case_index , len (offset )),),
1063- src_loc = lhs .src_loc )
1064- conds .append (subcond )
1065- conds = self .emit_priority_match (module_idx , cond , _nir .Value (conds ),
1066- src_loc = lhs .src_loc )
1049+ patterns .append ((to_binary (case_index , len (offset )),))
1050+ conds = self .emit_match (module_idx , cond , offset , tuple (patterns ), src_loc = lhs .src_loc )
10671051 for idx , subcond in enumerate (conds ):
10681052 start = lhs_start + idx * lhs .stride
10691053 if start >= width :
@@ -1075,17 +1059,15 @@ def emit_assign(self, module_idx: int, cd: "_cd.ClockDomain | None", lhs: _ast.V
10751059 self .emit_assign (module_idx , cd , lhs .value , start , subrhs , subcond , src_loc = src_loc )
10761060 elif isinstance (lhs , _ast .SwitchValue ):
10771061 test , _signed = self .emit_rhs (module_idx , lhs .test )
1078- conds = []
1062+ patterns = []
10791063 elems = []
1080- for patterns , elem in lhs .cases :
1081- if patterns is not None :
1082- net = self .emit_matches (module_idx , test , patterns , src_loc = lhs .src_loc )
1083- conds .append (net )
1064+ for pattern_list , elem in lhs .cases :
1065+ if pattern_list is not None :
1066+ patterns .append (pattern_list )
10841067 else :
1085- conds .append (_nir . Net . from_const ( 1 ))
1068+ patterns .append (( "-" * len ( test ), ))
10861069 elems .append (elem )
1087- conds = self .emit_priority_match (module_idx , cond , _nir .Value (conds ),
1088- src_loc = lhs .src_loc )
1070+ conds = self .emit_match (module_idx , cond , test , tuple (patterns ), src_loc = lhs .src_loc )
10891071 for subcond , val in zip (conds , elems ):
10901072 self .emit_assign (module_idx , cd , val , lhs_start , rhs [:len (val )], subcond , src_loc = src_loc )
10911073 elif isinstance (lhs , _ast .Operator ):
@@ -1166,17 +1148,15 @@ def emit_stmt(self, module_idx: int, fragment: _ir.Fragment, domain: str,
11661148 self .netlist .add_cell (cell )
11671149 elif isinstance (stmt , _ast .Switch ):
11681150 test , _signed = self .emit_rhs (module_idx , stmt .test )
1169- conds = []
1151+ patterns = []
11701152 case_stmts = []
1171- for patterns , stmts , case_src_loc in stmt .cases :
1172- if patterns is not None :
1173- net = self .emit_matches (module_idx , test , patterns , src_loc = case_src_loc )
1174- conds .append (net )
1153+ for pattern_list , stmts , case_src_loc in stmt .cases :
1154+ if pattern_list is not None :
1155+ patterns .append (pattern_list )
11751156 else :
1176- conds .append (_nir . Net . from_const ( 1 ))
1157+ patterns .append (( "-" * len ( test ), ))
11771158 case_stmts .append (stmts )
1178- conds = self .emit_priority_match (module_idx , cond , _nir .Value (conds ),
1179- src_loc = stmt .src_loc )
1159+ conds = self .emit_match (module_idx , cond , test , tuple (patterns ), src_loc = stmt .src_loc )
11801160 for subcond , substmts in zip (conds , case_stmts ):
11811161 for substmt in substmts :
11821162 self .emit_stmt (module_idx , fragment , domain , substmt , subcond )
@@ -1430,13 +1410,10 @@ def emit_drivers(self):
14301410 driver .domain .rst is not None and
14311411 not driver .domain .async_reset and
14321412 not driver .signal .reset_less ):
1433- cond = self .emit_matches (driver .module_idx ,
1413+ cond , = self .emit_match (driver .module_idx , _nir . Net . from_const ( 1 ) ,
14341414 self .emit_signal (driver .domain .rst ),
1435- ("1" ,),
1415+ (( "1" ,) ,),
14361416 src_loc = driver .domain .rst .src_loc )
1437- cond , = self .emit_priority_match (driver .module_idx , _nir .Net .from_const (1 ),
1438- _nir .Value (cond ),
1439- src_loc = driver .domain .rst .src_loc )
14401417 init = _nir .Value .from_const (driver .signal .init , len (driver .signal ))
14411418 driver .assignments .append (_nir .Assignment (cond = cond , start = 0 ,
14421419 value = init , src_loc = driver .signal .src_loc ))
0 commit comments