Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion cgp/genome.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,11 @@ def change_address_gene_of_output_node(self, new_address: int, output_node_idx:
self.dna = dna

def set_expression_for_output(
self, dna_insert: List[int], hidden_start_node: int = 0, output_node_idx: int = 0
self,
dna_insert: List[int],
target_expression: str,
hidden_start_node: int = 0,
output_node_idx: int = 0,
):
"""Set an expression for one output node

Expand All @@ -370,6 +374,8 @@ def set_expression_for_output(
----------
dna_insert: List[int]
dna segment to be inserted at the first hidden nodes.
target_expression: str
Expression the output node should compile to. Numbers must be written as float.
hidden_start_node: int
Index of the hidden node, where the insert starts.
Relative to the first hidden node.
Expand All @@ -388,6 +394,22 @@ def set_expression_for_output(
self.change_address_gene_of_output_node(
new_address=last_inserted_node, output_node_idx=output_node_idx
)
try:
import sympy

except ModuleNotFoundError:
raise ModuleNotFoundError(
"Can not check output expression. No module named 'sympy' (extra requirement)"
)

if self._n_outputs > 1:
output_as_sympy = CartesianGraph(self).to_sympy()[output_node_idx]
else:
output_as_sympy = CartesianGraph(self).to_sympy()

target_expression_as_sympy = sympy.parse_expr(target_expression)
if not output_as_sympy == target_expression_as_sympy:
raise ValueError("expression of output and target expression do not match")

def reorder(self, rng: np.random.RandomState) -> None:
"""Reorder the genome
Expand Down
28 changes: 23 additions & 5 deletions test/test_genome.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,17 +835,35 @@ def test_set_expression_for_output(genome_params, rng):
genome = cgp.Genome(**genome_params)
genome.randomize(rng)

new_dna = [0, 0, 1]
genome.set_expression_for_output(new_dna)

x_0 = sympy.symbols("x_0")
x_1 = sympy.symbols("x_1")

new_dna = [0, 0, 1]
genome.set_expression_for_output(new_dna, target_expression="x_0 + x_1")
assert CartesianGraph(genome).to_sympy() == x_0 + x_1

new_dna = [1, 0, 1]
genome.set_expression_for_output(new_dna)
genome.set_expression_for_output(dna_insert=new_dna, target_expression="x_0 - x_1")
assert CartesianGraph(genome).to_sympy() == x_0 - x_1

new_dna = [0, 0, 1, 2, 0, 0, 1, 0, 0, 0, 2, 3] # x_0+x_1; 1.0; 0; x_0+x_1 + 1.0
genome.set_expression_for_output(new_dna)
genome.set_expression_for_output(dna_insert=new_dna, target_expression="x_0 + x_1 + 1.0")
assert CartesianGraph(genome).to_sympy() == x_0 + x_1 + 1.0

with pytest.raises(ValueError):
# setting an int in the str causes an error
genome.set_expression_for_output(dna_insert=new_dna, target_expression="x_0 + x_1 + 1")
genome.set_expression_for_output(dna_insert=new_dna, target_expression="x_0 + x_1 * 1.0")

genome2_params = {
"n_inputs": 2,
"n_outputs": 2,
"primitives": (cgp.Add, cgp.Sub, cgp.ConstantFloat),
}
genome2 = cgp.Genome(**genome2_params)
genome2.randomize(rng)

genome2.set_expression_for_output(
new_dna, output_node_idx=1, target_expression="x_0 + x_1 + 1.0"
)
assert CartesianGraph(genome2).to_sympy()[1] == x_0 + x_1 + 1.0