Skip to content

Commit c9c9ee3

Browse files
authored
Allow multiple response patterns in the insert_mask_before_placeholder transform (#2567)
* updated onmt/transforms/insert_mask_before_placeholder.py * updated onmt/tests/test_transform.py
1 parent 211aeec commit c9c9ee3

File tree

2 files changed

+16
-12
lines changed

2 files changed

+16
-12
lines changed

onmt/tests/test_transform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -788,7 +788,7 @@ class TestInsertMaskBeforePlaceholder(unittest.TestCase):
788788
@classmethod
789789
def setUpClass(cls):
790790
cls.base_opts = {
791-
"response_pattern": "Response : ⦅newline⦆",
791+
"response_patterns": ["Response : ⦅newline⦆"],
792792
}
793793

794794
def test_insert_mask_before_placeholder(self):

onmt/transforms/insert_mask_before_placeholder.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,28 @@ def add_options(cls, parser):
2222
"Transform/InsertMaskBeforePlaceholdersTransform"
2323
)
2424
group.add(
25-
"--response_pattern",
26-
"-response_pattern",
27-
type=str,
25+
"--response_patterns",
26+
"-response_patterns",
2827
help="Response patten to locate the end of the prompt",
29-
default="Response : ⦅newline⦆",
28+
default=["Response : ⦅newline⦆"],
29+
nargs="+",
3030
)
3131

3232
def _parse_opts(self):
33-
self.response_pattern = self.opts.response_pattern
33+
self.response_patterns = self.opts.response_patterns
3434

3535
def apply(self, example, is_train=False, stats=None, **kwargs):
3636
_src = " ".join(example["src"])
37-
if len(_src.split(self.response_pattern)) != 2:
37+
response = None
38+
for _pattern in self.response_patterns:
39+
if len(_src.split(_pattern)) == 2:
40+
prompt, response = _src.split(_pattern)
41+
response = DefaultTokens.MASK_BEFORE.join([_pattern, response])
42+
if response is not None:
43+
_src = "".join([prompt, response])
44+
example["src"] = _src.split(" ")
45+
example["tgt"] = _src.split(" ")
46+
else:
3847
logger.info("The mask_before could not be inserted")
3948
return example
40-
prompt, response = _src.split(self.response_pattern)
41-
response = DefaultTokens.MASK_BEFORE.join([self.response_pattern, response])
42-
_src = "".join([prompt, response])
43-
example["src"] = _src.split(" ")
44-
example["tgt"] = _src.split(" ")
4549
return example

0 commit comments

Comments
 (0)