@@ -22,24 +22,28 @@ def add_options(cls, parser):
2222 "Transform/InsertMaskBeforePlaceholdersTransform"
2323 )
2424 group .add (
25- "--response_pattern" ,
26- "-response_pattern" ,
27- type = str ,
25+ "--response_patterns" ,
26+ "-response_patterns" ,
2827 help = "Response patten to locate the end of the prompt" ,
29- default = "Response : ⦅newline⦆" ,
28+ default = ["Response : ⦅newline⦆" ],
29+ nargs = "+" ,
3030 )
3131
3232 def _parse_opts (self ):
33- self .response_pattern = self .opts .response_pattern
33+ self .response_patterns = self .opts .response_patterns
3434
3535 def apply (self , example , is_train = False , stats = None , ** kwargs ):
3636 _src = " " .join (example ["src" ])
37- if len (_src .split (self .response_pattern )) != 2 :
37+ response = None
38+ for _pattern in self .response_patterns :
39+ if len (_src .split (_pattern )) == 2 :
40+ prompt , response = _src .split (_pattern )
41+ response = DefaultTokens .MASK_BEFORE .join ([_pattern , response ])
42+ if response is not None :
43+ _src = "" .join ([prompt , response ])
44+ example ["src" ] = _src .split (" " )
45+ example ["tgt" ] = _src .split (" " )
46+ else :
3847 logger .info ("The mask_before could not be inserted" )
3948 return example
40- prompt , response = _src .split (self .response_pattern )
41- response = DefaultTokens .MASK_BEFORE .join ([self .response_pattern , response ])
42- _src = "" .join ([prompt , response ])
43- example ["src" ] = _src .split (" " )
44- example ["tgt" ] = _src .split (" " )
4549 return example
0 commit comments