Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 14 additions & 32 deletions amaranth/back/rtlil.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ def emit_cell_wires(self):
wire = self.emit_driven_wire(_nir.Value(nets))
self.instance_wires[cell_idx, name] = wire
continue # Instances use one wire per output, not per cell.
elif isinstance(cell, (_nir.PriorityMatch, _nir.Matches)):
elif isinstance(cell, _nir.Match):
continue # Inlined into assignment lists.
elif isinstance(cell, (_nir.SyncPrint, _nir.AsyncPrint, _nir.SyncProperty,
_nir.AsyncProperty, _nir.Memory, _nir.SyncWritePort)):
Expand Down Expand Up @@ -737,41 +737,25 @@ def emit_assignments(case, cond):
search_cond = assign.cond
while True:
if search_cond == cond:
# We have found the PriorityMatch cell that we should enter.
# We have found the Match cell that we should enter.
break
if search_cond == _nir.Net.from_const(1):
# If this isn't nested condition, go back to parent invocation.
return
# Grab the PriorityMatch cell that is on the next level of nesting.
priority_cell_idx = search_cond.cell
priority_cell = self.netlist.cells[priority_cell_idx]
assert isinstance(priority_cell, _nir.PriorityMatch)
search_cond = priority_cell.en
# We assume that:
# 1. PriorityMatch inputs can only be Match cell outputs, or constant 1.
# 2. All Match cells driving a given PriorityMatch cell test the same value.
# Grab the tested value from a random Match cell.
test = _nir.Value()
for net in priority_cell.inputs:
if net != _nir.Net.from_const(1):
matches_cell = self.netlist.cells[net.cell]
assert isinstance(matches_cell, _nir.Matches)
test = matches_cell.value
break
# Grab the Match cell that is on the next level of nesting.
match_cell_idx = search_cond.cell
match_cell = self.netlist.cells[match_cell_idx]
assert isinstance(match_cell, _nir.Match)
search_cond = match_cell.en
# Now emit cases for all PriorityMatch inputs, in sequence. Consume as many
# assignments as possible along the way.
switch = case.switch(self.sigspec(test))
for bit, net in enumerate(priority_cell.inputs):
subcond = _nir.Net.from_cell(priority_cell_idx, bit)
if net == _nir.Net.from_const(1):
switch = case.switch(self.sigspec(match_cell.value))
for bit, pattern_list in enumerate(match_cell.patterns):
subcond = _nir.Net.from_cell(match_cell_idx, bit)
if pattern_list == ("-" * len(match_cell.value),):
emit_assignments(switch.default(), subcond)
else:
# Validate the above assumptions.
matches_cell = self.netlist.cells[net.cell]
assert isinstance(matches_cell, _nir.Matches)
assert test == matches_cell.value
patterns = matches_cell.patterns
emit_assignments(switch.case(patterns), subcond)
emit_assignments(switch.case(pattern_list), subcond)

lhs = _nir.Value(_nir.Net.from_cell(cell_idx, bit) for bit in range(len(cell.default)))
proc = self.builder.process(src_loc=cell.src_loc)
Expand Down Expand Up @@ -1235,10 +1219,8 @@ def emit_cells(self):
cell = self.netlist.cells[cell_idx]
if isinstance(cell, _nir.Top):
pass
elif isinstance(cell, _nir.Matches):
pass # Matches is only referenced from PriorityMatch cells and inlined there
elif isinstance(cell, _nir.PriorityMatch):
pass # PriorityMatch is only referenced from AssignmentList cells and inlined there
elif isinstance(cell, _nir.Match):
pass # Match is only referenced from AssignmentList cells and inlined there
elif isinstance(cell, _nir.AssignmentList):
self.emit_assignment_list(cell_idx, cell)
elif isinstance(cell, _nir.Operator):
Expand Down
85 changes: 31 additions & 54 deletions amaranth/hdl/_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,8 +707,7 @@ def __init__(self, netlist: _nir.Netlist, design: Design, *, all_undef_to_ff=Fal
self.drivers = _ast.SignalDict()
self.io_ports: dict[_ast.IOPort, int] = {}
self.rhs_cache: dict[int, tuple[_nir.Value, bool, _ast.Value]] = {}
self.matches_cache = {}
self.priority_match_cache = {}
self.match_cache = {}
self.fragment_module_idx: dict[Fragment, int] = {}

# Collected for driver conflict diagnostics only.
Expand Down Expand Up @@ -787,24 +786,14 @@ def emit_operator(self, module_idx: int, operator: str, *inputs: _nir.Value, src
op = _nir.Operator(module_idx, operator=operator, inputs=inputs, src_loc=src_loc)
return self.netlist.add_value_cell(op.width, op)

def emit_matches(self, module_idx: int, value: _nir.Value, patterns, *, src_loc):
key = module_idx, value, patterns, src_loc
def emit_match(self, module_idx: int, en: _nir.Net, value: _nir.Value, patterns, *, src_loc):
key = module_idx, en, value, patterns, src_loc
try:
return self.matches_cache[key]
return self.match_cache[key]
except KeyError:
cell = _nir.Matches(module_idx, value=value, patterns=patterns, src_loc=src_loc)
net, = self.netlist.add_value_cell(1, cell)
self.matches_cache[key] = net
return net

def emit_priority_match(self, module_idx: int, en: _nir.Net, inputs: _nir.Value, *, src_loc):
key = module_idx, en, inputs, src_loc
try:
return self.priority_match_cache[key]
except KeyError:
cell = _nir.PriorityMatch(module_idx, en=en, inputs=inputs, src_loc=src_loc)
res = self.netlist.add_value_cell(len(inputs), cell)
self.priority_match_cache[key] = res
cell = _nir.Match(module_idx, en=en, value=value, patterns=patterns, src_loc=src_loc)
res = self.netlist.add_value_cell(len(patterns), cell)
self.match_cache[key] = res
return res

def unify_shapes_bitwise(self,
Expand Down Expand Up @@ -956,17 +945,16 @@ def emit_rhs(self, module_idx: int, value: _ast.Value) -> tuple[_nir.Value, bool
result = self.emit_operator(module_idx, 'm', test, operand_a, operand_b,
src_loc=value.src_loc)
else:
conds = []
elems = []
for patterns, elem, in value.cases:
if patterns is not None:
net = self.emit_matches(module_idx, test, patterns, src_loc=value.src_loc)
conds.append(net)
patterns = []
for pattern_list, elem, in value.cases:
if pattern_list is not None:
patterns.append(pattern_list)
else:
conds.append(_nir.Net.from_const(1))
patterns.append(("-" * len(test),))
elems.append(self.emit_rhs(module_idx, elem))
conds = self.emit_priority_match(module_idx, _nir.Net.from_const(1),
_nir.Value(conds), src_loc=value.src_loc)
conds = self.emit_match(module_idx, _nir.Net.from_const(1), test, tuple(patterns),
src_loc=value.src_loc)
shape = _ast.Shape._unify(
_ast.Shape(len(value), signed)
for value, signed in elems
Expand Down Expand Up @@ -1056,14 +1044,10 @@ def emit_assign(self, module_idx: int, cd: "_cd.ClockDomain | None", lhs: _ast.V
offset, _signed = self.emit_rhs(module_idx, lhs.offset)
width = len(lhs.value)
num_cases = min((width + lhs.stride - 1) // lhs.stride, 1 << len(offset))
conds = []
patterns = []
for case_index in range(num_cases):
subcond = self.emit_matches(module_idx, offset,
(to_binary(case_index, len(offset)),),
src_loc=lhs.src_loc)
conds.append(subcond)
conds = self.emit_priority_match(module_idx, cond, _nir.Value(conds),
src_loc=lhs.src_loc)
patterns.append((to_binary(case_index, len(offset)),))
conds = self.emit_match(module_idx, cond, offset, tuple(patterns), src_loc=lhs.src_loc)
for idx, subcond in enumerate(conds):
start = lhs_start + idx * lhs.stride
if start >= width:
Expand All @@ -1075,17 +1059,15 @@ def emit_assign(self, module_idx: int, cd: "_cd.ClockDomain | None", lhs: _ast.V
self.emit_assign(module_idx, cd, lhs.value, start, subrhs, subcond, src_loc=src_loc)
elif isinstance(lhs, _ast.SwitchValue):
test, _signed = self.emit_rhs(module_idx, lhs.test)
conds = []
patterns = []
elems = []
for patterns, elem in lhs.cases:
if patterns is not None:
net = self.emit_matches(module_idx, test, patterns, src_loc=lhs.src_loc)
conds.append(net)
for pattern_list, elem in lhs.cases:
if pattern_list is not None:
patterns.append(pattern_list)
else:
conds.append(_nir.Net.from_const(1))
patterns.append(("-" * len(test),))
elems.append(elem)
conds = self.emit_priority_match(module_idx, cond, _nir.Value(conds),
src_loc=lhs.src_loc)
conds = self.emit_match(module_idx, cond, test, tuple(patterns), src_loc=lhs.src_loc)
for subcond, val in zip(conds, elems):
self.emit_assign(module_idx, cd, val, lhs_start, rhs[:len(val)], subcond, src_loc=src_loc)
elif isinstance(lhs, _ast.Operator):
Expand Down Expand Up @@ -1166,17 +1148,15 @@ def emit_stmt(self, module_idx: int, fragment: _ir.Fragment, domain: str,
self.netlist.add_cell(cell)
elif isinstance(stmt, _ast.Switch):
test, _signed = self.emit_rhs(module_idx, stmt.test)
conds = []
patterns = []
case_stmts = []
for patterns, stmts, case_src_loc in stmt.cases:
if patterns is not None:
net = self.emit_matches(module_idx, test, patterns, src_loc=case_src_loc)
conds.append(net)
for pattern_list, stmts, case_src_loc in stmt.cases:
if pattern_list is not None:
patterns.append(pattern_list)
else:
conds.append(_nir.Net.from_const(1))
patterns.append(("-" * len(test),))
case_stmts.append(stmts)
conds = self.emit_priority_match(module_idx, cond, _nir.Value(conds),
src_loc=stmt.src_loc)
conds = self.emit_match(module_idx, cond, test, tuple(patterns), src_loc=stmt.src_loc)
for subcond, substmts in zip(conds, case_stmts):
for substmt in substmts:
self.emit_stmt(module_idx, fragment, domain, substmt, subcond)
Expand Down Expand Up @@ -1430,13 +1410,10 @@ def emit_drivers(self):
driver.domain.rst is not None and
not driver.domain.async_reset and
not driver.signal.reset_less):
cond = self.emit_matches(driver.module_idx,
cond, = self.emit_match(driver.module_idx, _nir.Net.from_const(1),
self.emit_signal(driver.domain.rst),
("1",),
(("1",),),
src_loc=driver.domain.rst.src_loc)
cond, = self.emit_priority_match(driver.module_idx, _nir.Net.from_const(1),
_nir.Value(cond),
src_loc=driver.domain.rst.src_loc)
init = _nir.Value.from_const(driver.signal.init, len(driver.signal))
driver.assignments.append(_nir.Assignment(cond=cond, start=0,
value=init, src_loc=driver.signal.src_loc))
Expand Down
75 changes: 22 additions & 53 deletions amaranth/hdl/_nir.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# Computation cells
"Operator", "Part",
# Decision tree cells
"Matches", "PriorityMatch", "Assignment", "AssignmentList",
"Match", "Assignment", "AssignmentList",
# Storage cells
"FlipFlop", "Memory", "SyncWritePort", "AsyncReadPort", "SyncReadPort",
# Print cells
Expand Down Expand Up @@ -768,79 +768,48 @@ def comb_edges_to(self, bit):
yield (net, self.src_loc)


class Matches(Cell):
"""A combinational cell performing a comparison like ``Value.matches``
(or, equivalently, a case condition).

Attributes
----------

value: Value
patterns: tuple of str, each str contains '0', '1', '-'
"""
def __init__(self, module_idx, *, value, patterns, src_loc):
super().__init__(module_idx, src_loc=src_loc)

for pattern in patterns:
assert len(pattern) == len(value)
self.value = Value(value)
self.patterns = tuple(patterns)

def input_nets(self):
return set(self.value)

def output_nets(self, self_idx: int):
return {Net.from_cell(self_idx, 0)}

def resolve_nets(self, netlist: Netlist):
self.value = netlist.resolve_value(self.value)

def __repr__(self):
patterns = " ".join(self.patterns)
return f"(matches {self.value} {patterns})"

def comb_edges_to(self, bit):
for net in self.value:
yield (net, self.src_loc)


class PriorityMatch(Cell):
class Match(Cell):
"""Used to represent a single switch on the control plane of processes.

The output is the same length as ``inputs``. If ``en`` is ``0``, the output
is all-0. Otherwise, output keeps the lowest-numbered ``1`` bit in the input
(if any) and masks all other bits to ``0``.

Note: the RTLIL backend requires all bits of ``inputs`` to be driven
by a ``Match`` cell within the same module.
The output is the same length as ``patterns``. If ``en`` is ``0``, the output
is all-0. Otherwise, the ``value`` is matched against all pattern sets
in ``patterns``. The output has a ``1`` bit for the first pattern set that
matches ``value``, and ``0`` for all other bits. If no pattern set matches
the value, the output is all-``0``.

Attributes
----------
en: Net
inputs: Value
value: Value
patterns: tuple of tuple of str, each str contains '0', '1', '-'
"""
def __init__(self, module_idx, *, en, inputs, src_loc):
def __init__(self, module_idx, *, en, value, patterns, src_loc):
super().__init__(module_idx, src_loc=src_loc)

for pattern_list in patterns:
for pattern in pattern_list:
assert len(pattern) == len(value)
self.en = Net.ensure(en)
self.inputs = Value(inputs)
self.value = Value(value)
self.patterns = patterns

def input_nets(self):
return set(self.inputs) | {self.en}
return set(self.value) | {self.en}

def output_nets(self, self_idx: int):
return {Net.from_cell(self_idx, bit) for bit in range(len(self.inputs))}
return {Net.from_cell(self_idx, bit) for bit in range(len(self.patterns))}

def resolve_nets(self, netlist: Netlist):
self.en = netlist.resolve_net(self.en)
self.inputs = netlist.resolve_value(self.inputs)
self.value = netlist.resolve_value(self.value)

def __repr__(self):
return f"(priority_match {self.en} {self.inputs})"
patterns = " ".join("{" + " ".join(pattern_list) + "}" if len(pattern_list) != 1 else pattern_list[0] for pattern_list in self.patterns)
return f"(match {self.en} {self.value} {patterns})"

def comb_edges_to(self, bit):
yield (self.en, self.src_loc)
for net in self.inputs[:bit + 1]:
for net in self.value:
yield (net, self.src_loc)


Expand Down Expand Up @@ -883,7 +852,7 @@ class AssignmentList(Cell):
then executing each assignment in sequence.

Note: the RTLIL backend requires all ``cond`` inputs of assignments to be driven
by a ``PriorityMatch`` cell within the same module.
by a ``Match`` cell within the same module.

Attributes
----------
Expand Down
4 changes: 2 additions & 2 deletions tests/test_back_rtlil.py
Original file line number Diff line number Diff line change
Expand Up @@ -1413,7 +1413,7 @@ def test_trivial(self):
wire width 4 output 1 \out
process $1
assign \out [3:0] 4'0000
switch { }
switch \sel [3:0]
case
assign \out [3:0] 4'0001
end
Expand Down Expand Up @@ -1838,7 +1838,7 @@ def test_assert_simple(self):
cell $check $3
parameter \FORMAT ""
parameter \ARGS_WIDTH 0
parameter signed \PRIORITY 32'11111111111111111111111111111100
parameter signed \PRIORITY 32'11111111111111111111111111111101
parameter \TRG_ENABLE 0
parameter \TRG_WIDTH 0
parameter \TRG_POLARITY 0
Expand Down
Loading
Loading