@@ -51,14 +51,14 @@ def forward(self, module_inputs, target):
         return loss, nll
 
 
-def _chunk_tuple(seq_tuple, num_chunks, batch_first=True):
+def _chunk_tuple(seq_tuple, num_chunks, duplicates=1, batch_first=True):
     if num_chunks == 1:
-        return [seq_tuple]
+        return [seq_tuple] * duplicates
     seq, length = seq_tuple
     batch_dim = 0 if batch_first else 1
     chunked_length = [l.tolist()
                       for l in torch.tensor(length).chunk(num_chunks)]
-    return zip(seq.chunk(num_chunks, dim=batch_dim), chunked_length)
+    return list(zip(seq.chunk(num_chunks, dim=batch_dim), chunked_length)) * duplicates
 
 
 def _batch_max_tokens(src_tuple, target_tuple, max_tokens, batch_first=True, log=True):
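
For reference, a minimal standalone sketch (not part of the commit; toy tensors assumed) of the updated _chunk_tuple, showing what the new duplicates argument does: the list of chunked (seq, length) pairs is repeated duplicates times, so downstream code iterates over each chunk more than once. Note that zip(...) must now be wrapped in list(...), since a bare zip object cannot be multiplied.

import torch

def _chunk_tuple(seq_tuple, num_chunks, duplicates=1, batch_first=True):
    # Split a (sequence tensor, lengths) pair into num_chunks sub-batches,
    # then repeat the resulting list `duplicates` times.
    if num_chunks == 1:
        return [seq_tuple] * duplicates
    seq, length = seq_tuple
    batch_dim = 0 if batch_first else 1
    chunked_length = [l.tolist()
                      for l in torch.tensor(length).chunk(num_chunks)]
    return list(zip(seq.chunk(num_chunks, dim=batch_dim), chunked_length)) * duplicates

seq = torch.arange(12).view(4, 3)   # 4 sequences (batch_first), length 3
lengths = [3, 3, 2, 1]
chunks = _chunk_tuple((seq, lengths), num_chunks=2, duplicates=2)
print(len(chunks))                  # 4: two chunks, each repeated twice
print([c[1] for c in chunks])       # [[3, 3], [2, 1], [3, 3], [2, 1]]
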
@@ -107,6 +107,7 @@ def __init__(self, model, regime=None,
                  embedding_grad_clip=None,
                  max_tokens=None,
                  chunk_batch=1,
+                 duplicates=1,
                  save_info={},
                  save_path='.',
                  checkpoint_filename='checkpoint%s.pth',
@@ -132,6 +133,7 @@ def __init__(self, model, regime=None,
         self.dtype = dtype
         self.max_tokens = max_tokens
         self.chunk_batch = chunk_batch
+        self.duplicates = duplicates
         self.print_freq = print_freq
         self.eval_freq = eval_freq
         self.perplexity = float('inf')
@@ -171,8 +173,8 @@ def iterate(self, src_tuple_batch, target_tuple_batch, training=True, chunk_batc
             self.optimizer.zero_grad()
 
         repacked_inputs = []
-        for src_tuple, target_tuple in zip(_chunk_tuple(src_tuple_batch, chunk_batch, self.batch_first),
-                                           _chunk_tuple(target_tuple_batch, chunk_batch, self.batch_first)):
+        for src_tuple, target_tuple in zip(_chunk_tuple(src_tuple_batch, chunk_batch, self.duplicates, self.batch_first),
+                                           _chunk_tuple(target_tuple_batch, chunk_batch, self.duplicates, self.batch_first)):
             # limit number of tokens to avoid gpu overload
             if training and self.max_tokens is not None:
                 src_tuple, target_tuple = _batch_max_tokens(
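
Net effect in iterate(): with duplicates > 1, the source and target chunk lists are repeated in lockstep, so each (src_tuple, target_tuple) sub-batch pair is processed duplicates times per call, presumably accumulating gradients over the repeated chunks (the optimizer step itself is outside this hunk).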