55from onmt .opts import translate_opts
66from onmt .constants import CorpusTask
77from onmt .inputters .dynamic_iterator import build_dynamic_dataset_iter
8- from onmt .transforms import get_transforms_cls , make_transforms , TransformPipe
8+ from onmt .transforms import get_transforms_cls
99
1010
1111class ScoringPreparator :
@@ -19,16 +19,12 @@ def __init__(self, vocabs, opt):
1919 if self .opt .dump_preds is not None :
2020 if not os .path .exists (self .opt .dump_preds ):
2121 os .makedirs (self .opt .dump_preds )
22- self .transforms = opt .transforms
23- transforms_cls = get_transforms_cls (self .transforms )
24- transforms = make_transforms (self .opt , transforms_cls , self .vocabs )
25- self .transform = TransformPipe .build_from (transforms .values ())
22+ self .transforms = None
23+ self .transforms_cls = None
2624
2725 def warm_up (self , transforms ):
2826 self .transforms = transforms
29- transforms_cls = get_transforms_cls (self .transforms )
30- transforms = make_transforms (self .opt , transforms_cls , self .vocabs )
31- self .transform = TransformPipe .build_from (transforms .values ())
27+ self .transforms_cls = get_transforms_cls (transforms )
3228
3329 def translate (self , model , gpu_rank , step ):
3430 """Compute and save the sentences predicted by the
@@ -84,7 +80,7 @@ def translate(self, model, gpu_rank, step):
8480
8581 # Reinstantiate the validation iterator
8682
87- transforms_cls = get_transforms_cls (model_opt ._all_transform )
83+ # transforms_cls = get_transforms_cls(model_opt._all_transform)
8884 model_opt .num_workers = 0
8985 model_opt .tgt = None
9086
@@ -100,7 +96,7 @@ def translate(self, model, gpu_rank, step):
10096
10197 valid_iter = build_dynamic_dataset_iter (
10298 model_opt ,
103- transforms_cls ,
99+ self . transforms_cls ,
104100 translator .vocabs ,
105101 task = CorpusTask .VALID ,
106102 tgt = "" , # This force to clear the target side (needed when using tgt_file_prefix)
@@ -125,12 +121,11 @@ def translate(self, model, gpu_rank, step):
125121
126122 # Flatten predictions
127123 preds = [x .lstrip () for sublist in preds for x in sublist ]
128-
129124 # Save results
130125 if len (preds ) > 0 and self .opt .scoring_debug :
131126 path = os .path .join (self .opt .dump_preds , f"preds.valid_step_{ step } .txt" )
132127 with open (path , "a" ) as file :
133- for i in range (len (preds )):
128+ for i in range (len (raw_srcs )):
134129 file .write ("SOURCE: {}\n " .format (raw_srcs [i ]))
135130 file .write ("REF: {}\n " .format (raw_refs [i ]))
136131 file .write ("PRED: {}\n \n " .format (preds [i ]))
0 commit comments