Commit 4e0a3b5

Author: Elad Hoffer (committed)
Message: duplicates
1 parent 7e82735 · commit 4e0a3b5

File tree: 2 files changed (+10, -5 lines)


main.py

Lines changed: 3 additions & 0 deletions
@@ -101,6 +101,8 @@
                     help='fixed sequence length')
 parser.add_argument('--chunk-batch', default=1, type=int,
                     help='chunk batch size for multiple passes (training) -- used to fit large batches in memory')
+parser.add_argument('--duplicates', default=1, type=int,
+                    help='number of duplicates over single example')
 parser.add_argument('--seed', default=123, type=int,
                     help='random seed (default: 123)')

@@ -205,6 +207,7 @@ def main(args):
         keep_checkpoints=args.keep_checkpoints,
         max_tokens=args.max_tokens,
         chunk_batch=args.chunk_batch,
+        duplicates=args.duplicates,
         distributed=args.distributed,
         local_rank=args.local_rank,
         device_ids=args.device_ids,
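
As a quick illustration, a hypothetical invocation of the new option could look like this (only flags visible in this diff are shown; any other required arguments are omitted):

python main.py --chunk-batch 4 --duplicates 2 --seed 123

With these values, every batch is split into 4 chunks and the whole chunk sequence is repeated twice before the optimizer step, so each example contributes two backward passes.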

seq2seq/tools/trainer.py

Lines changed: 7 additions & 5 deletions
@@ -51,14 +51,14 @@ def forward(self, module_inputs, target):
         return loss, nll
 
 
-def _chunk_tuple(seq_tuple, num_chunks, batch_first=True):
+def _chunk_tuple(seq_tuple, num_chunks, duplicates=1, batch_first=True):
     if num_chunks == 1:
-        return [seq_tuple]
+        return [seq_tuple] * duplicates
     seq, length = seq_tuple
     batch_dim = 0 if batch_first else 1
     chunked_length = [l.tolist()
                       for l in torch.tensor(length).chunk(num_chunks)]
-    return zip(seq.chunk(num_chunks, dim=batch_dim), chunked_length)
+    return list(zip(seq.chunk(num_chunks, dim=batch_dim), chunked_length)) * duplicates
 
 
 def _batch_max_tokens(src_tuple, target_tuple, max_tokens, batch_first=True, log=True):
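
Note that wrapping the zip in list(...) is what makes the repetition work: a zip iterator cannot be multiplied by an integer, while a list can, and the multiplication repeats the whole chunk sequence "duplicates" times. A minimal, self-contained sketch of the new behavior (the helper is copied from the hunk above; the tensor sizes and lengths are invented for illustration):

import torch

def _chunk_tuple(seq_tuple, num_chunks, duplicates=1, batch_first=True):
    # copy of the updated helper from the diff above
    if num_chunks == 1:
        return [seq_tuple] * duplicates
    seq, length = seq_tuple
    batch_dim = 0 if batch_first else 1
    chunked_length = [l.tolist()
                      for l in torch.tensor(length).chunk(num_chunks)]
    return list(zip(seq.chunk(num_chunks, dim=batch_dim), chunked_length)) * duplicates

# invented example: batch of 4 sequences, max length 5, batch-first layout
seq = torch.arange(20).view(4, 5)
lengths = [5, 4, 3, 5]
chunks = _chunk_tuple((seq, lengths), num_chunks=2, duplicates=2)
assert len(chunks) == 4  # 2 chunks, each repeated twice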
@@ -107,6 +107,7 @@ def __init__(self, model, regime=None,
                  embedding_grad_clip=None,
                  max_tokens=None,
                  chunk_batch=1,
+                 duplicates=1,
                  save_info={},
                  save_path='.',
                  checkpoint_filename='checkpoint%s.pth',
@@ -132,6 +133,7 @@ def __init__(self, model, regime=None,
         self.dtype = dtype
         self.max_tokens = max_tokens
         self.chunk_batch = chunk_batch
+        self.duplicates = duplicates
         self.print_freq = print_freq
         self.eval_freq = eval_freq
         self.perplexity = float('inf')
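
For callers, only the new keyword needs to be threaded through. A hedged construction sketch using just the parameters visible in this diff (the Trainer constructor takes many more arguments, elided here):

trainer = Trainer(model,
                  max_tokens=args.max_tokens,
                  chunk_batch=args.chunk_batch,
                  duplicates=args.duplicates)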
@@ -171,8 +173,8 @@ def iterate(self, src_tuple_batch, target_tuple_batch, training=True, chunk_batc
         self.optimizer.zero_grad()
 
         repacked_inputs = []
-        for src_tuple, target_tuple in zip(_chunk_tuple(src_tuple_batch, chunk_batch, self.batch_first),
-                                           _chunk_tuple(target_tuple_batch, chunk_batch, self.batch_first)):
+        for src_tuple, target_tuple in zip(_chunk_tuple(src_tuple_batch, chunk_batch, self.duplicates, self.batch_first),
+                                           _chunk_tuple(target_tuple_batch, chunk_batch, self.duplicates, self.batch_first)):
             # limit number of tokens to avoid gpu overload
             if training and self.max_tokens is not None:
                 src_tuple, target_tuple = _batch_max_tokens(
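
Taken together, the training step now behaves roughly like this sketch (simplified; compute_loss is a placeholder for the real forward and criterion logic, which is not part of this diff):

self.optimizer.zero_grad()
for src_tuple, target_tuple in zip(
        _chunk_tuple(src_tuple_batch, chunk_batch, self.duplicates, self.batch_first),
        _chunk_tuple(target_tuple_batch, chunk_batch, self.duplicates, self.batch_first)):
    loss = compute_loss(src_tuple, target_tuple)  # placeholder for the model pass
    loss.backward()  # gradients accumulate over chunks and duplicates
self.optimizer.step()

Since backward() runs once per (chunk, duplicate) pair before a single step, setting duplicates > 1 scales each example's contribution to the accumulated gradient.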
