Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion cgp/genome.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,11 @@ def change_address_gene_of_output_node(self, new_address: int, output_node_idx:
self.dna = dna

def set_expression_for_output(
self, dna_insert: List[int], hidden_start_node: int = 0, output_node_idx: int = 0
self,
dna_insert: List[int],
target_expression: str,
hidden_start_node: int = 0,
output_node_idx: int = 0,
):
"""Set an expression for one output node

Expand All @@ -370,6 +374,9 @@ def set_expression_for_output(
----------
dna_insert: List[int]
dna segment to be inserted at the first hidden nodes.
target_expression: str
Expression the output node should compile to. Numbers must be written as float.
Defaults to None.
hidden_start_node: int
Index of the hidden node, where the insert starts.
Relative to the first hidden node.
Expand All @@ -388,6 +395,21 @@ def set_expression_for_output(
self.change_address_gene_of_output_node(
new_address=last_inserted_node, output_node_idx=output_node_idx
)
try:
import sympy

if target_expression is not None:
if self._n_outputs > 1:
output_as_sympy = CartesianGraph(self).to_sympy()[output_node_idx]
else:
output_as_sympy = CartesianGraph(self).to_sympy()

target_expression_as_sympy = sympy.parse_expr(target_expression)
if not output_as_sympy == target_expression_as_sympy:
raise ValueError("expression of output and target expression do not match")

except ModuleNotFoundError:
raise Warning("Sympy not available, can not compare written output to target")

def reorder(self, rng: np.random.RandomState) -> None:
"""Reorder the genome
Expand Down
28 changes: 23 additions & 5 deletions test/test_genome.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,17 +835,35 @@ def test_set_expression_for_output(genome_params, rng):
genome = cgp.Genome(**genome_params)
genome.randomize(rng)

new_dna = [0, 0, 1]
genome.set_expression_for_output(new_dna)

x_0 = sympy.symbols("x_0")
x_1 = sympy.symbols("x_1")

new_dna = [0, 0, 1]
genome.set_expression_for_output(new_dna, target_expression="x_0 + x_1")
assert CartesianGraph(genome).to_sympy() == x_0 + x_1

new_dna = [1, 0, 1]
genome.set_expression_for_output(new_dna)
genome.set_expression_for_output(dna_insert=new_dna, target_expression="x_0 - x_1")
assert CartesianGraph(genome).to_sympy() == x_0 - x_1

new_dna = [0, 0, 1, 2, 0, 0, 1, 0, 0, 0, 2, 3] # x_0+x_1; 1.0; 0; x_0+x_1 + 1.0
genome.set_expression_for_output(new_dna)
genome.set_expression_for_output(dna_insert=new_dna, target_expression="x_0 + x_1 + 1.0")
assert CartesianGraph(genome).to_sympy() == x_0 + x_1 + 1.0

with pytest.raises(ValueError):
# setting an int in the str causes an error
genome.set_expression_for_output(dna_insert=new_dna, target_expression="x_0 + x_1 + 1")
genome.set_expression_for_output(dna_insert=new_dna, target_expression="x_0 + x_1 8 1.0")

genome2_params = {
"n_inputs": 2,
"n_outputs": 2,
"primitives": (cgp.Add, cgp.Sub, cgp.ConstantFloat),
}
genome2 = cgp.Genome(**genome2_params)
genome2.randomize(rng)

genome2.set_expression_for_output(
new_dna, output_node_idx=1, target_expression="x_0 + x_1 + 1.0"
)
assert CartesianGraph(genome2).to_sympy()[1] == x_0 + x_1 + 1.0