11import pytest
22
3+ from pytensor .graph .fg import FunctionGraph
34from pytensor .graph .rewriting .basic import GraphRewriter , SequentialGraphRewriter
45from pytensor .graph .rewriting .db import (
56 EquilibriumDB ,
@@ -17,6 +18,13 @@ def apply(self, fgraph):
1718 pass
1819
1920
21+ class NewTestRewriter (GraphRewriter ):
22+ name = "bleh"
23+
24+ def apply (self , fgraph ):
25+ pass
26+
27+
2028class TestDB :
2129 def test_register (self ):
2230 db = RewriteDatabase ()
@@ -31,7 +39,7 @@ def test_register(self):
3139 assert "c" in db
3240
3341 with pytest .raises (ValueError , match = r"The tag.*" ):
34- db .register ("c" , TestRewriter ()) # name taken
42+ db .register ("c" , NewTestRewriter ()) # name taken
3543
3644 with pytest .raises (ValueError , match = r"The tag.*" ):
3745 db .register ("z" , TestRewriter ()) # name collides with tag
@@ -42,6 +50,40 @@ def test_register(self):
4250 with pytest .raises (TypeError , match = r".* is not a valid.*" ):
4351 db .register ("d" , 1 )
4452
53+ def test_overwrite_existing (self ):
54+ class TestOverwrite1 (GraphRewriter ):
55+ def apply (self , fgraph ):
56+ fgraph .counter [0 ] += 1
57+
58+ class TestOverwrite2 (GraphRewriter ):
59+ def apply (self , fgraph ):
60+ fgraph .counter [1 ] += 1
61+
62+ db = SequenceDB ()
63+ fg = FunctionGraph ([], [])
64+ fg .counter = [0 , 0 ]
65+
66+ db .register ("a" , TestRewriter (), "basic" )
67+ rewriter = db .query ("+basic" )
68+ rewriter .rewrite (fg )
69+ assert fg .counter == [0 , 0 ]
70+
71+ with pytest .raises (ValueError , match = r"The tag.*" ):
72+ db .register ("a" , TestOverwrite1 (), "basic" )
73+ rewriter = db .query ("+basic" )
74+ rewriter .rewrite (fg )
75+ assert fg .counter == [0 , 0 ]
76+
77+ db .register ("a" , TestOverwrite1 (), "basic" , overwrite_existing = True )
78+ rewriter = db .query ("+basic" )
79+ rewriter .rewrite (fg )
80+ assert fg .counter == [1 , 0 ]
81+
82+ db .register ("a" , TestOverwrite2 (), "basic" , overwrite_existing = True )
83+ rewriter = db .query ("+basic" )
84+ rewriter .rewrite (fg )
85+ assert fg .counter == [1 , 1 ]
86+
4587 def test_EquilibriumDB (self ):
4688 eq_db = EquilibriumDB ()
4789
0 commit comments