@@ -104,6 +104,18 @@ def __init__(self, task, model, **kwargs):
         )
         self._pypinyin = pypinyin
         self._max_seq_length = 128
+        self._batchify_fn = lambda samples, fn=Tuple(
+            Pad(axis=0, pad_val=self._tokenizer.pad_token_id),  # input
+            Pad(axis=0, pad_val=self._tokenizer.pad_token_type_id),  # segment
+            Pad(axis=0, pad_val=self._pinyin_vocab.token_to_idx[self._pinyin_vocab.pad_token]),  # pinyin
+            Stack(axis=0, dtype='int64'),  # length
+        ): [data for data in fn(samples)]
+        self._num_workers = self.kwargs[
+            'num_workers'] if 'num_workers' in self.kwargs else 0
+        self._batch_size = self.kwargs[
+            'batch_size'] if 'batch_size' in self.kwargs else 1
+        self._lazy_load = self.kwargs[
+            'lazy_load'] if 'lazy_load' in self.kwargs else False
 
     def _construct_input_spec(self):
         """
@@ -141,61 +153,83 @@ def _construct_tokenizer(self, model):
 
     def _preprocess(self, inputs, padding=True, add_special_tokens=True):
         inputs = self._check_input_text(inputs)
-        batch_size = self.kwargs[
-            'batch_size'] if 'batch_size' in self.kwargs else 1
-        trans_func = self._convert_example
-
-        batchify_fn = lambda samples, fn=Tuple(
-            Pad(axis=0, pad_val=self._tokenizer.pad_token_id),  # input
-            Pad(axis=0, pad_val=self._tokenizer.pad_token_type_id),  # segment
-            Pad(axis=0, pad_val=self._pinyin_vocab.token_to_idx[self._pinyin_vocab.pad_token]),  # pinyin
-            Stack(axis=0, dtype='int64'),  # length
-        ): [data for data in fn(samples)]
-
         examples = []
         texts = []
         for text in inputs:
             if not (isinstance(text, str) and len(text) > 0):
                 continue
             example = {"source": text.strip()}
-            input_ids, token_type_ids, pinyin_ids, length = trans_func(example)
+            input_ids, token_type_ids, pinyin_ids, length = self._convert_example(
+                example)
             examples.append((input_ids, token_type_ids, pinyin_ids, length))
             texts.append(example["source"])
 
         batch_examples = [
-            examples[idx:idx + batch_size]
-            for idx in range(0, len(examples), batch_size)
+            examples[idx:idx + self._batch_size]
+            for idx in range(0, len(examples), self._batch_size)
         ]
         batch_texts = [
-            texts[idx:idx + batch_size]
-            for idx in range(0, len(examples), batch_size)
+            texts[idx:idx + self._batch_size]
+            for idx in range(0, len(examples), self._batch_size)
         ]
         outputs = {}
         outputs['batch_examples'] = batch_examples
         outputs['batch_texts'] = batch_texts
-        self.batchify_fn = batchify_fn
+        if not self._static_mode:
+
+            def read(inputs):
+                for text in inputs:
+                    example = {"source": text.strip()}
+                    input_ids, token_type_ids, pinyin_ids, length = self._convert_example(
+                        example)
+                    yield input_ids, token_type_ids, pinyin_ids, length
+
+            infer_ds = load_dataset(read, inputs=inputs, lazy=self._lazy_load)
+            outputs['data_loader'] = paddle.io.DataLoader(
+                infer_ds,
+                collate_fn=self._batchify_fn,
+                num_workers=self._num_workers,
+                batch_size=self._batch_size,
+                shuffle=False,
+                return_list=True)
+
         return outputs
 
     def _run_model(self, inputs):
         """
         Run the task model from the outputs of the `_tokenize` function.
         """
         results = []
-        with static_mode_guard():
-            for examples in inputs['batch_examples']:
-                token_ids, token_type_ids, pinyin_ids, lengths = self.batchify_fn(
-                    examples)
-                self.input_handles[0].copy_from_cpu(token_ids)
-                self.input_handles[1].copy_from_cpu(pinyin_ids)
-                self.predictor.run()
-                det_preds = self.output_handle[0].copy_to_cpu()
-                char_preds = self.output_handle[1].copy_to_cpu()
-
-                batch_result = []
-                for i in range(len(lengths)):
-                    batch_result.append(
-                        (det_preds[i], char_preds[i], lengths[i]))
-                results.append(batch_result)
+        if not self._static_mode:  # dygraph mode: run the model eagerly over the DataLoader
+            with dygraph_mode_guard():
+                for examples in inputs['data_loader']:
+                    token_ids, token_type_ids, pinyin_ids, lengths = examples
+                    det_preds, char_preds = self._model(token_ids, pinyin_ids)
+                    det_preds = det_preds.numpy()
+                    char_preds = char_preds.numpy()
+                    lengths = lengths.numpy()
+
+                    batch_result = []
+                    for i in range(len(lengths)):
+                        batch_result.append(
+                            (det_preds[i], char_preds[i], lengths[i]))
+                    results.append(batch_result)
+        else:  # static mode: run the exported inference predictor
+            with static_mode_guard():
+                for examples in inputs['batch_examples']:
+                    token_ids, token_type_ids, pinyin_ids, lengths = self._batchify_fn(
+                        examples)
+                    self.input_handles[0].copy_from_cpu(token_ids)
+                    self.input_handles[1].copy_from_cpu(pinyin_ids)
+                    self.predictor.run()
+                    det_preds = self.output_handle[0].copy_to_cpu()
+                    char_preds = self.output_handle[1].copy_to_cpu()
+
+                    batch_result = []
+                    for i in range(len(lengths)):
+                        batch_result.append(
+                            (det_preds[i], char_preds[i], lengths[i]))
+                    results.append(batch_result)
         inputs['batch_results'] = results
         return inputs
 
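
End-to-end, the kwargs consumed in `__init__` (`batch_size`, `num_workers`, `lazy_load`) arrive through the Taskflow constructor. A usage sketch, assuming the task stays registered under PaddleNLP's "text_correction" name; the kwarg values and the example sentence are arbitrary:

    from paddlenlp import Taskflow

    corrector = Taskflow("text_correction", batch_size=2, num_workers=0)
    print(corrector("遇到逆竟时，我们必须勇于面对"))
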
@@ -232,7 +266,7 @@ def _postprocess(self, inputs):
 
     def _convert_example(self, example):
         source = example["source"]
-        words = self._tokenizer.tokenize(text=source)
+        words = list(source)  # character-level tokens instead of wordpieces
         if len(words) > self._max_seq_length - 2:
             words = words[:self._max_seq_length - 2]
         length = len(words)
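
Replacing `self._tokenizer.tokenize(text=source)` with `list(source)` keeps the tokens aligned one-to-one with the per-character pinyin ids: a wordpiece tokenizer merges ASCII runs, so its output length can differ from the character count. A quick illustration with a made-up input:

    source = "天气晴朗, temp 20度"
    chars = list(source)  # exactly one token per character, len(chars) == len(source)
    # a wordpiece tokenizer would merge "temp" and "20" into single tokens,
    # so its length would no longer match the per-character pinyin sequence
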
@@ -269,64 +303,22 @@ def _convert_example(self, example):
     def _parse_decode(self, words, corr_preds, det_preds, lengths):
         UNK = self._tokenizer.unk_token
         UNK_id = self._tokenizer.convert_tokens_to_ids(UNK)
-        tokens = self._tokenizer.tokenize(words)
-        if len(tokens) > self._max_seq_length - 2:
-            tokens = tokens[:self._max_seq_length - 2]
+
         corr_pred = corr_preds[1:1 + lengths].tolist()
         det_pred = det_preds[1:1 + lengths].tolist()
         words = list(words)
+        rest_words = []  # characters truncated beyond max_seq_length, re-appended at the end
         if len(words) > self._max_seq_length - 2:
+            rest_words = words[self._max_seq_length - 2:]
             words = words[:self._max_seq_length - 2]
 
-        assert len(tokens) == len(
-            corr_pred
-        ), "The number of tokens should be equal to the number of labels {}: {}: {}".format(
-            len(tokens), len(corr_pred), tokens)
         pred_result = ""
-
-        align_offset = 0
-        # Need to be aligned
-        if len(words) != len(tokens):
-            first_unk_flag = True
-            for j, word in enumerate(words):
-                if word.isspace():
-                    tokens.insert(j + 1, word)
-                    corr_pred.insert(j + 1, UNK_id)
-                    det_pred.insert(j + 1, 0)  # No error
-                elif tokens[j] != word:
-                    if self._tokenizer.convert_tokens_to_ids(word) == UNK_id:
-                        if first_unk_flag:
-                            first_unk_flag = False
-                            corr_pred[j] = UNK_id
-                            det_pred[j] = 0
-                        else:
-                            tokens.insert(j, UNK)
-                            corr_pred.insert(j, UNK_id)
-                            det_pred.insert(j, 0)  # No error
-                        continue
-                    elif tokens[j] == UNK:
-                        # Remove rest unk
-                        k = 0
-                        while k + j < len(tokens) and tokens[k + j] == UNK:
-                            k += 1
-                        tokens = tokens[:j] + tokens[j + k:]
-                        corr_pred = corr_pred[:j] + corr_pred[j + k:]
-                        det_pred = det_pred[:j] + det_pred[j + k:]
-                    else:
-                        # Maybe English, number, or suffix
-                        token = tokens[j].lstrip("##")
-                        corr_pred = corr_pred[:j] + [UNK_id] * len(
-                            token) + corr_pred[j + 1:]
-                        det_pred = det_pred[:j] + [0] * len(token) + det_pred[
-                            j + 1:]
-                        tokens = tokens[:j] + list(token) + tokens[j + 1:]
-                    first_unk_flag = True
-
         for j, word in enumerate(words):
             candidates = self._tokenizer.convert_ids_to_tokens(corr_pred[j])
-            if det_pred[j] == 0 or candidates == UNK or candidates == '[PAD]':
+            if not is_chinese_char(ord(word)) or det_pred[
+                    j] == 0 or candidates == UNK or candidates == '[PAD]':
                 pred_result += word
             else:
                 pred_result += candidates.lstrip("##")
-
+        pred_result += ''.join(rest_words)
         return pred_result
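
The new condition depends on an `is_chinese_char` helper imported elsewhere in this file: only Chinese characters are eligible for correction, everything else is copied through unchanged. A rough sketch of what such a check typically covers (the exact Unicode ranges used by PaddleNLP's helper may differ):

    def is_chinese_char_sketch(cp: int) -> bool:
        # approximate test against the main CJK ideograph blocks
        return (0x4E00 <= cp <= 0x9FFF        # CJK Unified Ideographs
                or 0x3400 <= cp <= 0x4DBF     # Extension A
                or 0x20000 <= cp <= 0x2A6DF   # Extension B
                or 0xF900 <= cp <= 0xFAFF)    # Compatibility Ideographs

    assert is_chinese_char_sketch(ord("语")) and not is_chinese_char_sketch(ord("a"))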