@@ -1010,14 +1010,15 @@ def beam_search_v2(self, src_word, beam_size=4, max_len=None, alpha=0.6):
10101010 """
10111011
10121012 def expand_to_beam_size (tensor , beam_size ):
1013- tensor = paddle .reshape (tensor ,
1014- [tensor .shape [0 ], 1 ] + tensor .shape [1 :])
1013+ tensor = paddle .unsqueeze (tensor , axis = 1 )
10151014 tile_dims = [1 ] * len (tensor .shape )
10161015 tile_dims [1 ] = beam_size
10171016 return paddle .tile (tensor , tile_dims )
10181017
10191018 def merge_beam_dim (tensor ):
1020- return paddle .reshape (tensor , [- 1 ] + tensor .shape [2 :])
1019+ shape = tensor .shape
1020+ return paddle .reshape (tensor ,
1021+ [shape [0 ] * shape [1 ]] + list (shape [2 :]))
10211022
10221023 # run encoder
10231024 src_max_len = paddle .shape (src_word )[- 1 ]
@@ -1045,23 +1046,26 @@ def merge_beam_dim(tensor):
10451046
10461047 ### initialize states of beam search ###
10471048 ## init for the alive ##
1048- initial_log_probs = paddle .to_tensor (
1049+ initial_log_probs = paddle .assign (
10491050 np .array (
10501051 [[0. ] + [- inf ] * (beam_size - 1 )], dtype = "float32" ))
10511052 alive_log_probs = paddle .tile (initial_log_probs , [batch_size , 1 ])
1052- # (batch_size, beam_size, 1)
1053- alive_seq = paddle .to_tensor (
1054- np .tile (np .array ([[[self .bos_id ]]]), (batch_size , beam_size , 1 )),
1055- dtype = src_word .dtype )
1053+
1054+ alive_seq = paddle .tile (
1055+ paddle .cast (
1056+ paddle .assign (np .array ([[[self .bos_id ]]])), src_word .dtype ),
1057+ [batch_size , beam_size , 1 ])
10561058
10571059 ## init for the finished ##
1058- finished_scores = paddle .to_tensor (
1060+ finished_scores = paddle .assign (
10591061 np .array (
10601062 [[- inf ] * beam_size ], dtype = "float32" ))
10611063 finished_scores = paddle .tile (finished_scores , [batch_size , 1 ])
1062- finished_seq = paddle .to_tensor (
1063- np .tile (np .array ([[[self .bos_id ]]]), (batch_size , beam_size , 1 )),
1064- dtype = src_word .dtype )
1064+
1065+ finished_seq = paddle .tile (
1066+ paddle .cast (
1067+ paddle .assign (np .array ([[[self .bos_id ]]])), src_word .dtype ),
1068+ [batch_size , beam_size , 1 ])
10651069 finished_flags = paddle .zeros_like (finished_scores )
10661070
10671071 ### initialize inputs and states of transformer decoder ###
@@ -1076,7 +1080,7 @@ def merge_beam_dim(tensor):
10761080 ## init states (caches) for transformer, need to be updated according to selected beam
10771081 caches = self .transformer .decoder .gen_cache (enc_output , do_zip = False )
10781082
1079- def update_states (caches , topk_coordinates , beam_size ):
1083+ def update_states (caches , topk_coordinates , beam_size , batch_size ):
10801084 new_caches = []
10811085 for cache in caches :
10821086 k = gather_2d (
@@ -1107,9 +1111,11 @@ def gather_2d(tensor_nd,
11071111 beam_size ,
11081112 batch_size ,
11091113 need_unmerge = False ):
1114+
11101115 new_tensor_nd = paddle .reshape (
1111- tensor_nd , shape = [batch_size , beam_size ] +
1112- tensor_nd .shape [1 :]) if need_unmerge else tensor_nd
1116+ tensor_nd ,
1117+ shape = [batch_size , beam_size ] +
1118+ list (tensor_nd .shape [1 :])) if need_unmerge else tensor_nd
11131119 topk_seq = paddle .gather_nd (new_tensor_nd , topk_coordinates )
11141120 return merge_beam_dim (topk_seq ) if need_unmerge else topk_seq
11151121
@@ -1162,11 +1168,15 @@ def grow_topk(i, logits, alive_seq, alive_log_probs, states):
11621168 topk_seq = gather_2d (alive_seq , topk_coordinates , beam_size ,
11631169 batch_size )
11641170 topk_seq = paddle .concat (
1165- [topk_seq , paddle .reshape (topk_ids , topk_ids .shape + [1 ])],
1171+ [
1172+ topk_seq , paddle .reshape (topk_ids ,
1173+ list (topk_ids .shape [:]) + [1 ])
1174+ ],
11661175 axis = 2 )
1167- states = update_states (states , topk_coordinates , beam_size )
1176+ states = update_states (states , topk_coordinates , beam_size ,
1177+ batch_size )
11681178 eos = paddle .full (
1169- shape = topk_ids .shape ,
1179+ shape = paddle .shape ( topk_ids ) ,
11701180 dtype = alive_seq .dtype ,
11711181 fill_value = self .eos_id )
11721182 topk_finished = paddle .cast (paddle .equal (topk_ids , eos ), "float32" )
@@ -1192,7 +1202,8 @@ def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished,
11921202
11931203 alive_log_probs = gather_2d (curr_log_probs , topk_coordinates ,
11941204 beam_size , batch_size )
1195- states = update_states (states , topk_coordinates , beam_size * 2 )
1205+ states = update_states (states , topk_coordinates , beam_size * 2 ,
1206+ batch_size )
11961207
11971208 return alive_seq , alive_log_probs , states
11981209
@@ -1234,7 +1245,9 @@ def grow_finished(finished_seq, finished_scores, finished_flags,
12341245 def inner_loop (i , trg_word , alive_seq , alive_log_probs , finished_seq ,
12351246 finished_scores , finished_flags , caches ):
12361247 trg_pos = paddle .full (
1237- shape = trg_word .shape , dtype = alive_seq .dtype , fill_value = i )
1248+ shape = paddle .shape (trg_word ),
1249+ dtype = alive_seq .dtype ,
1250+ fill_value = i )
12381251 trg_emb = self .trg_word_embedding (trg_word )
12391252 trg_pos_emb = self .trg_pos_embedding (trg_pos )
12401253 trg_emb = trg_emb + trg_pos_emb
@@ -1271,13 +1284,19 @@ def is_not_finish(i, trg_word, alive_seq, alive_log_probs, finished_seq,
12711284 finished_seq , finished_scores , finished_flags , caches
12721285 ])
12731286
1274- finished_flags = paddle .any (paddle .cast (
1275- finished_flags , dtype = 'bool' ),
1276- axis = 1 ,
1277- keepdim = True ).tile ([1 , beam_size ])
1278- finished_seq = paddle .where (
1279- finished_flags .unsqueeze (- 1 ).tile ([1 , 1 , alive_seq .shape [- 1 ]]),
1280- finished_seq , alive_seq )
1281- finished_scores = paddle .where (finished_flags , finished_scores ,
1282- alive_log_probs )
1287+ # (gongenlei) `paddle.where` doesn't support broadcast, so we need to use `paddle.unsqueeze`
1288+ # and `paddle.tile` to make condition.shape same as X.shape. But when converting dygraph
1289+ # to static graph, `paddle.tile` will raise error.
1290+ finished_flags = paddle .cast (finished_flags , dtype = finished_seq .dtype )
1291+ neg_finished_flags = 1 - finished_flags
1292+ finished_seq = paddle .multiply (
1293+ finished_seq , finished_flags .unsqueeze (- 1 )) + paddle .multiply (
1294+ alive_seq , neg_finished_flags .unsqueeze (- 1 ))
1295+ finished_scores = paddle .multiply (
1296+ finished_scores ,
1297+ paddle .cast (
1298+ finished_flags , dtype = finished_scores .dtype )) + paddle .multiply (
1299+ alive_log_probs ,
1300+ paddle .cast (
1301+ neg_finished_flags , dtype = alive_log_probs .dtype ))
12831302 return finished_seq , finished_scores
0 commit comments