Skip to content

Commit aac4457

Browse files
committed
Add tests
1 parent e221237 commit aac4457

File tree

2 files changed

+20
-5
lines changed

2 files changed

+20
-5
lines changed

pytensor/graph/rewriting/db.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def register(
3535
rewriter: Union["RewriteDatabase", RewritesType],
3636
*tags: str,
3737
use_db_name_as_tag=True,
38+
overwrite_existing=False,
3839
):
3940
"""Register a new rewriter to the database.
4041
@@ -56,7 +57,8 @@ def register(
5657
``local_remove_all_assert``. Setting `use_db_name_as_tag` to
5758
``False`` removes that behavior. This means that only the rewrite's name
5859
and/or its tags will enable it.
59-
60+
overwrite_existing:
61+
Overwrite the existing rewriter with a new one having the same name
6062
"""
6163
if not isinstance(
6264
rewriter,
@@ -72,8 +74,12 @@ def register(
7274

7375
rewriter.name = name
7476

77+
# if tag collides with name
78+
if name in self.__db__ and name not in self._names:
79+
raise ValueError(f"The tag '{name}' is already present in the database.")
80+
7581
if name in self.__db__ or rewriter.name in self.__db__:
76-
if "overwrite_existing" in tags:
82+
if overwrite_existing:
7783
old_rewriter = self.__db__[name].pop()
7884
self._names.remove(name)
7985
self.__db__[old_rewriter.__class__.__name__].remove(old_rewriter)
@@ -82,15 +88,15 @@ def register(
8288
f"The tag '{name}' is already present in the database."
8389
)
8490
else:
85-
if "overwrite_existing" in tags:
91+
if overwrite_existing:
8692
raise ValueError(
8793
f"The tag '{name}' does not exist in the database. Cannot be overwritten."
8894
)
8995

9096
self.__db__[name] = OrderedSet([rewriter])
9197
self._names.add(name)
9298
self.__db__[rewriter.__class__.__name__].add(rewriter)
93-
if "overwrite_existing" not in tags:
99+
if not overwrite_existing:
94100
self.add_tags(name, *tags)
95101

96102
def add_tags(self, name, *tags):

tests/graph/rewriting/test_db.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ def apply(self, fgraph):
1717
pass
1818

1919

20+
class NewTestRewriter(GraphRewriter):
21+
name = "bleh"
22+
23+
def apply(self, fgraph):
24+
pass
25+
26+
2027
class TestDB:
2128
def test_register(self):
2229
db = RewriteDatabase()
@@ -31,7 +38,9 @@ def test_register(self):
3138
assert "c" in db
3239

3340
with pytest.raises(ValueError, match=r"The tag.*"):
34-
db.register("c", TestRewriter()) # name taken
41+
db.register("c", NewTestRewriter()) # name taken
42+
43+
db.register("c", NewTestRewriter(), overwrite_existing=True)
3544

3645
with pytest.raises(ValueError, match=r"The tag.*"):
3746
db.register("z", TestRewriter()) # name collides with tag

0 commit comments

Comments
 (0)