Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions pytensor/graph/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,16 +616,20 @@ def eval(
"""
from pytensor.compile.function import function

on_unused_input = kwargs.get("on_unused_input", None)

def convert_string_keys_to_variables(inputs_to_values) -> dict["Variable", Any]:
new_input_to_values = {}
for key, value in inputs_to_values.items():
if isinstance(key, str):
matching_vars = get_var_by_name([self], key)
if not matching_vars:
raise ValueError(f"{key} not found in graph")
if on_unused_input in ["raise", None]:
raise ValueError(f"{key} not found in graph")
elif len(matching_vars) > 1:
raise ValueError(f"Found multiple variables with name {key}")
new_input_to_values[matching_vars[0]] = value
else:
new_input_to_values[matching_vars[0]] = value
else:
new_input_to_values[key] = value
return new_input_to_values
Expand Down
9 changes: 9 additions & 0 deletions tests/graph/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,17 @@ def test_eval_with_strings_no_match(self):
def test_eval_kwargs(self):
with pytest.raises(UnusedInputError):
self.w.eval({self.z: 3, self.x: 2.5})
with pytest.warns(
UserWarning,
match="pytensor.function was asked to create a function",
):
self.w.eval({self.z: 3, self.x: 2.5}, on_unused_input="warn")
assert self.w.eval({self.z: 3, self.x: 2.5}, on_unused_input="ignore") == 6.0

# regression test for https://github.com/pymc-devs/pytensor/issues/1084
q = self.x + 1
assert q.eval({"x": 1, "y": 2}, on_unused_input="ignore") == 2.0

@pytest.mark.filterwarnings("error")
def test_eval_unashable_kwargs(self):
y_repl = constant(2.0, dtype="floatX")
Expand Down
Loading