@@ -35,6 +35,7 @@ def register(
3535 rewriter : Union ["RewriteDatabase" , RewritesType ],
3636 * tags : str ,
3737 use_db_name_as_tag = True ,
38+ overwrite_existing = False ,
3839 ):
3940 """Register a new rewriter to the database.
4041
@@ -56,7 +57,8 @@ def register(
5657 ``local_remove_all_assert``. Setting `use_db_name_as_tag` to
5758 ``False`` removes that behavior. This means that only the rewrite's name
5859 and/or its tags will enable it.
59-
60+ overwrite_existing:
61+ Overwrite the existing rewriter with a new one having the same name
6062 """
6163 if not isinstance (
6264 rewriter ,
@@ -72,8 +74,12 @@ def register(
7274
7375 rewriter .name = name
7476
77+ # if tag collides with name
78+ if name in self .__db__ and name not in self ._names :
79+ raise ValueError (f"The tag '{ name } ' is already present in the database." )
80+
7581 if name in self .__db__ or rewriter .name in self .__db__ :
76- if " overwrite_existing" in tags :
82+ if overwrite_existing :
7783 old_rewriter = self .__db__ [name ].pop ()
7884 self ._names .remove (name )
7985 self .__db__ [old_rewriter .__class__ .__name__ ].remove (old_rewriter )
@@ -82,15 +88,15 @@ def register(
8288 f"The tag '{ name } ' is already present in the database."
8389 )
8490 else :
85- if " overwrite_existing" in tags :
91+ if overwrite_existing :
8692 raise ValueError (
8793 f"The tag '{ name } ' does not exist in the database. Cannot be overwritten."
8894 )
8995
9096 self .__db__ [name ] = OrderedSet ([rewriter ])
9197 self ._names .add (name )
9298 self .__db__ [rewriter .__class__ .__name__ ].add (rewriter )
93- if "overwrite_existing" not in tags :
99+ if not overwrite_existing :
94100 self .add_tags (name , * tags )
95101
96102 def add_tags (self , name , * tags ):
0 commit comments