Commit 348276b

update multi head attention

Author: Elad Hoffer (committed)
1 parent 4e0a3b5 commit 348276b

6 files changed (+139 / -41 lines)

main.py

Lines changed: 11 additions & 5 deletions
@@ -13,7 +13,7 @@
 from torch.utils.data.distributed import DistributedSampler
 from seq2seq import models, datasets
 from seq2seq.tools.utils.log import setup_logging
-from seq2seq.tools.utils.misc import set_global_seeds
+from seq2seq.tools.utils.misc import set_global_seeds, torch_dtypes
 from seq2seq.tools.config import PAD
 import seq2seq.tools.trainer as trainers

@@ -49,8 +49,10 @@
                     help='trainer used: ' +
                     ' | '.join(trainers.__all__) +
                     ' (default: Seq2SeqTrainer)')
-parser.add_argument('--dtype', default='torch.float',
-                    help='type of tensor - e.g torch.cuda.HalfTensor')
+parser.add_argument('--dtype', default='float',
+                    help='type of tensor: ' +
+                    ' | '.join(torch_dtypes.keys()) +
+                    ' (default: float)')
 parser.add_argument('-j', '--workers', default=8, type=int,
                     help='number of data loading workers (default: 8)')
 parser.add_argument('--epochs', default=100, type=int,
@@ -89,6 +91,8 @@
                     help='maximum grad norm value. negative for off')
 parser.add_argument('--embedding-grad-clip', default=None, type=float,
                     help='maximum embedding grad norm value')
+parser.add_argument('--loss-scale', default=1, type=float,
+                    help='loss scale for mixed precision training.')
 parser.add_argument('--label-smoothing', default=0, type=float,
                     help='label smoothing coefficient - default 0')
 parser.add_argument('--uniform-init', default=None, type=float,
@@ -102,7 +106,7 @@
 parser.add_argument('--chunk-batch', default=1, type=int,
                     help='chunk batch size for multiple passes (training) -- used to fit large batches in memory')
 parser.add_argument('--duplicates', default=1, type=int,
-                    help='number of duplicates over singel example')
+                    help='number of duplicates over singel example')
 parser.add_argument('--seed', default=123, type=int,
                     help='random seed (default: 123)')

@@ -136,6 +140,7 @@ def main(args):
     logging.debug("run arguments: %s", args)
 
     device = args.device
+    dtype = torch_dtypes.get(args.dtype)
     if 'cuda' in args.device:
         main_gpu = 0
         if isinstance(args.device_ids, tuple):
@@ -168,7 +173,7 @@ def main(args):
 
     model = getattr(models, args.model)(**model_config)
 
-    model.to(device)
+    model.to(device, dtype=dtype)
     batch_first = getattr(model, 'batch_first', False)
 
     logging.info(model)
@@ -213,6 +218,7 @@ def main(args):
         device_ids=args.device_ids,
         device=device,
         dtype=args.dtype,
+        loss_scale=args.loss_scale,
         print_freq=args.print_freq,
         save_freq=args.save_freq,
         eval_freq=args.eval_freq)
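
Note: the new --dtype flag selects an entry from torch_dtypes (defined in seq2seq.tools.utils.misc; its contents are not shown in this diff), and --loss-scale is forwarded to the trainer. A rough, hypothetical sketch of what these two pieces usually look like in static mixed-precision training (illustrative names and values only, not the repo's implementation):

import torch

# Hypothetical name -> dtype mapping of the kind '--dtype' implies; the real
# torch_dtypes lives in seq2seq.tools.utils.misc and may differ.
torch_dtypes = {
    'float': torch.float32,
    'half': torch.float16,
    'double': torch.float64,
}

# Static loss scaling: scale the loss before backward so that small fp16
# gradients do not underflow, then unscale the gradients before the optimizer step.
loss_scale = 128.0  # illustrative; the new flag defaults to 1
model = torch.nn.Linear(512, 512).to('cpu', dtype=torch_dtypes['float'])
loss = model(torch.randn(4, 512)).pow(2).mean()
(loss * loss_scale).backward()
for p in model.parameters():
    p.grad.div_(loss_scale)  # unscale before optimizer.step()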

seq2seq/models/modules/attention.py

Lines changed: 67 additions & 8 deletions
@@ -171,28 +171,32 @@ def forward(self, q, k, v):
             mask_q = self.mask_q.unsqueeze(2).expand(b, t_q, t_k)
             mask = mask_q if mask is None else mask | mask_q
         if mask is not None:
-            qk.masked_fill_(mask, -1e9)
+            qk.masked_fill_(mask, float('-inf'))
 
-        sm_qk = F.softmax(qk, dim=2)
+        sm_qk = F.softmax(qk, dim=2,
+                          dtype=torch.float32 if qk.dtype == torch.float16 else qk.dtype)
         sm_qk = self.dropout(sm_qk)
         return torch.bmm(sm_qk, v), sm_qk  # b x t_q x dim_v
 
 
-class MultiHeadAttention(nn.Module):
+class MultiHeadAttentionV2(nn.Module):
     """
     Scaled Dot-Product Attention
     """
 
-    def __init__(self, input_size, output_size, num_heads, weight_norm=False, groups=1, dropout=0, causal=False):
-        super(MultiHeadAttention, self).__init__()
+    def __init__(self, input_size, output_size, num_heads, weight_norm=False, groups=1, dropout=0, causal=False, add_bias_kv=False):
+        super(MultiHeadAttentionV2, self).__init__()
         assert(input_size % num_heads == 0)
         wn_func = wn if weight_norm else lambda x: x
         self.input_size = input_size
         self.output_size = output_size
         self.num_heads = num_heads
-        self.linear_q = wn_func(Linear(input_size, input_size, groups=groups))
-        self.linear_k = wn_func(Linear(input_size, input_size, groups=groups))
-        self.linear_v = wn_func(Linear(input_size, input_size, groups=groups))
+        self.linear_q = wn_func(
+            Linear(input_size, input_size, bias=False, groups=groups))
+        self.linear_k = wn_func(
+            Linear(input_size, input_size, bias=add_bias_kv, groups=groups))
+        self.linear_v = wn_func(
+            Linear(input_size, input_size, bias=add_bias_kv, groups=groups))
         self.linear_out = wn_func(
             Linear(input_size, output_size, groups=groups))
         self.sdp_attention = SDPAttention(dropout=dropout, causal=causal)
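
Note: the mask fill value changes from -1e9 to float('-inf'), and the attention softmax is computed in float32 whenever the logits are float16, presumably to keep the normalization numerically stable under half precision. A minimal standalone illustration of that pattern (not the repo's SDPAttention, just the same two lines in isolation):

import torch
import torch.nn.functional as F

qk = torch.randn(2, 5, 5, dtype=torch.float16)       # attention logits, b x t_q x t_k
mask = torch.zeros(2, 5, 5, dtype=torch.bool)
mask[:, :, -1] = True                                 # e.g. a padded key position
qk = qk.masked_fill(mask, float('-inf'))

sm = F.softmax(qk, dim=2,
               dtype=torch.float32 if qk.dtype == torch.float16 else qk.dtype)
assert sm.dtype == torch.float32                      # upcast softmax
assert torch.allclose(sm.sum(dim=2), torch.ones(2, 5))  # rows still sum to 1
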
@@ -226,3 +230,58 @@ def forward(self, q, k, v):
         output = torch.cat(output, 2)
 
         return self.linear_out(output), attention_scores
+
+
+class MultiHeadAttention(nn.MultiheadAttention):
+    """
+    Scaled Dot-Product Attention
+    """
+
+    def __init__(self, input_size, output_size, num_heads, dropout=0, causal=False, bias=True, add_bias_kv=False, add_zero_attn=False, batch_first=True, groups=None, weight_norm=None):
+        super(MultiHeadAttention, self).__init__(input_size, num_heads, dropout=dropout,
+                                                 bias=bias, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn)
+        assert(input_size % num_heads == 0)
+        assert(input_size == output_size)
+        self.causal = causal
+        self.batch_first = batch_first
+
+    def set_mask_q(self, masked_tq):
+        self.mask_q = masked_tq
+
+    def set_mask_k(self, masked_tk):
+        # applies a mask of b x tk length
+        self.mask_k = masked_tk
+
+    def forward(self, query, key, value, incremental_state=None, need_weights=False, static_kv=False):
+        key_padding_mask = attn_mask = None
+        time_dim = 1 if self.batch_first else 0
+        t_q = query.size(time_dim)
+        t_k = key.size(time_dim)
+        with torch.no_grad():
+            if self.causal and t_q > 1:
+                attn_mask = torch.full((t_q, t_k), float('-inf'),
+                                       device=query.device, dtype=query.dtype).triu_(1)
+            key_padding_mask = self.mask_k
+
+        if self.batch_first:
+            qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
+            kv_same = key.data_ptr() == value.data_ptr()
+            key = key.transpose(0, 1)
+            if kv_same:
+                value = key
+            else:
+                value = value.transpose(0, 1)
+            if qkv_same:
+                query = key
+            else:
+                query = query.transpose(0, 1)
+        elif key_padding_mask is not None:
+            key_padding_mask.t()
+
+
+        attn_output, attn_output_weights = super(
+            MultiHeadAttention, self).forward(query, key, value, key_padding_mask=key_padding_mask, attn_mask=attn_mask,
+                                              incremental_state=incremental_state, need_weights=need_weights, static_kv=static_kv)
+        if self.batch_first:
+            attn_output = attn_output.transpose(0, 1)
+        return attn_output, attn_output_weights
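
Note: the new MultiHeadAttention subclasses torch.nn.MultiheadAttention, which works on (time, batch, channels) tensors, so the wrapper transposes batch-first inputs and builds an additive -inf causal mask; the incremental_state/static_kv arguments it forwards are not part of the stock PyTorch module, which suggests a patched attention implementation is expected. A minimal sketch of the same batch-first + causal-mask pattern against stock nn.MultiheadAttention (illustrative only):

import torch
import torch.nn as nn

b, t, d, heads = 2, 6, 512, 8
mha = nn.MultiheadAttention(d, heads, dropout=0.1)

x = torch.randn(b, t, d)                               # batch-first activations
causal = torch.full((t, t), float('-inf')).triu_(1)    # block attention to future positions
pad_mask = torch.zeros(b, t, dtype=torch.bool)         # True would mark padded keys

q = k = v = x.transpose(0, 1)                          # -> (t, b, d) as nn.MultiheadAttention expects
out, weights = mha(q, k, v, key_padding_mask=pad_mask, attn_mask=causal)
out = out.transpose(0, 1)                              # back to (b, t, d)
assert out.shape == (b, t, d) and weights.shape == (b, t, t)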

seq2seq/models/modules/transformer_blocks.py

Lines changed: 32 additions & 15 deletions
@@ -8,8 +8,11 @@
 from .recurrent import Recurrent
 
 
-def positional_embedding(x, min_timescale=1.0, max_timescale=1.0e4, offset=0):
-    batch, length, channels = list(x.size())
+def positional_embedding(x, min_timescale=1.0, max_timescale=1.0e4, offset=0, batch_first=True):
+    if batch_first:
+        batch, length, channels = list(x.size())
+    else:
+        length, batch, channels = list(x.size())
     assert (channels % 2 == 0)
     num_timescales = channels // 2
     log_timescale_increment = (
@@ -24,8 +27,12 @@ def positional_embedding(x, min_timescale=1.0, max_timescale=1.0e4, offset=0):
     scaled_time = position.unsqueeze(1) * inv_timescales.unsqueeze(0)
     # scaled time is now length x num_timescales
     # length x channels
-    signal = torch.cat([scaled_time.sin(), scaled_time.cos()], 1)
-    return signal.unsqueeze(0).expand(batch, length, channels)
+    signal = torch.cat(
+        [scaled_time.sin(), scaled_time.cos()], 1).to(dtype=x.dtype)
+    if batch_first:
+        return signal.unsqueeze(0).expand(batch, length, channels)
+    else:
+        return signal.unsqueeze(1).expand(length, batch, channels)
 
 
 class AverageNetwork(nn.Module):
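
Note: positional_embedding now supports both layouts and casts the sinusoid table to the input dtype (relevant for fp16 activations). A standalone sketch of the shape contract, with an illustrative helper that mirrors, but is not, the function above:

import math
import torch

def sinusoid(length, channels, offset=0, min_timescale=1.0, max_timescale=1.0e4):
    # length x channels table of sin/cos signals, as in 'Attention Is All You Need'
    num_timescales = channels // 2
    log_inc = math.log(max_timescale / min_timescale) / (num_timescales - 1)
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float32) * -log_inc)
    position = torch.arange(offset, offset + length, dtype=torch.float32)
    scaled_time = position.unsqueeze(1) * inv_timescales.unsqueeze(0)
    return torch.cat([scaled_time.sin(), scaled_time.cos()], 1)

x_bf = torch.zeros(4, 7, 512, dtype=torch.float16)    # (batch, time, channels)
sig = sinusoid(7, 512).to(dtype=x_bf.dtype)           # cast to the activation dtype
pe_bf = sig.unsqueeze(0).expand(4, 7, 512)            # batch_first=True branch
pe_tf = sig.unsqueeze(1).expand(7, 4, 512)            # batch_first=False branch
assert torch.equal(pe_bf.transpose(0, 1), pe_tf)      # same signal, different layout
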
@@ -80,16 +87,18 @@ def forward(self, x, state=None):
 class EncoderBlock(nn.Module):
 
     def __init__(self, hidden_size=512, num_heads=8, inner_linear=2048, inner_groups=1,
-                 layer_norm=True, weight_norm=False, dropout=0):
+                 batch_first=True, layer_norm=True, weight_norm=False, dropout=0):
 
         super(EncoderBlock, self).__init__()
         wn_func = wn if weight_norm else lambda x: x
         if layer_norm:
             self.lnorm1 = nn.LayerNorm(hidden_size)
             self.lnorm2 = nn.LayerNorm(hidden_size)
         self.dropout = nn.Dropout(dropout)
-        self.attention = MultiHeadAttention(
-            hidden_size, hidden_size, num_heads, dropout=dropout, causal=False, groups=inner_groups, weight_norm=weight_norm)
+        self.batch_first = batch_first
+        self.attention = MultiHeadAttention(hidden_size, hidden_size, num_heads,
+                                            dropout=dropout, causal=False, batch_first=batch_first,
+                                            groups=inner_groups, weight_norm=weight_norm)
         self.fc = nn.Sequential(wn_func(Linear(hidden_size, inner_linear, groups=inner_groups)),
                                 nn.ReLU(inplace=True),
                                 nn.Dropout(dropout),
@@ -131,7 +140,7 @@ def forward(self, inputs):
 
 class DecoderBlock(nn.Module):
 
-    def __init__(self, hidden_size=512, num_heads=8, inner_linear=2048, inner_groups=1,
+    def __init__(self, hidden_size=512, num_heads=8, inner_linear=2048, inner_groups=1, batch_first=True,
                  layer_norm=True, weight_norm=False, dropout=0, stateful=None, state_dim=None):
 
         super(DecoderBlock, self).__init__()
super(DecoderBlock, self).__init__()
@@ -143,8 +152,10 @@ def __init__(self, hidden_size=512, num_heads=8, inner_linear=2048, inner_groups
143152
self.dropout = nn.Dropout(dropout)
144153
self.weight_norm = weight_norm
145154
self.stateful = stateful
146-
self.attention = MultiHeadAttention(
147-
hidden_size, hidden_size, num_heads, dropout=dropout, causal=False, groups=inner_groups, weight_norm=weight_norm)
155+
self.batch_first = batch_first
156+
self.attention = MultiHeadAttention(hidden_size, hidden_size, num_heads,
157+
batch_first=batch_first, dropout=dropout,
158+
causal=False, groups=inner_groups, weight_norm=weight_norm)
148159
if stateful is not None:
149160
residual = False
150161
stateful_hidden = hidden_size
@@ -157,13 +168,14 @@ def __init__(self, hidden_size=512, num_heads=8, inner_linear=2048, inner_groups
                 residual = True
             if stateful in ['RNN', 'iRNN', 'LSTM', 'GRU']:
                 self.state_block = Recurrent(stateful, hidden_size, stateful_hidden,
-                                             dropout=dropout, residual=residual, batch_first=True)
+                                             dropout=dropout, residual=residual, batch_first=batch_first)
             else:
                 self.state_block = AverageNetwork(
-                    hidden_size, hidden_size, layer_norm=layer_norm, weight_norm=weight_norm, batch_first=True)
+                    hidden_size, hidden_size, layer_norm=layer_norm, weight_norm=weight_norm, batch_first=batch_first)
         else:
             self.masked_attention = MultiHeadAttention(
-                hidden_size, hidden_size, num_heads, dropout=dropout, causal=True, groups=inner_groups, weight_norm=weight_norm)
+                hidden_size, hidden_size, num_heads, dropout=dropout,
+                batch_first=batch_first, causal=True, groups=inner_groups, weight_norm=weight_norm)
 
         self.fc = nn.Sequential(wn_func(Linear(hidden_size, inner_linear, groups=inner_groups)),
                                 nn.ReLU(inplace=True),
@@ -185,10 +197,15 @@ def forward(self, inputs, context, state=None):
 
         else:  # block_state are past inputs
             if state is None:
                 x_past = x
+                mask_past = self.masked_attention.mask_k
             else:
-                x_past = torch.cat((state, x), 1)
+                time_dim = 1 if self.batch_first else 0
+                x_past, mask_past = state
+                x_past = torch.cat((x_past, x), time_dim)
+                mask_past = torch.cat((mask_past, self.masked_attention.mask_k), time_dim)
+                self.masked_attention.set_mask_k(mask_past)
             x, _ = self.masked_attention(x, x_past, x_past)
-            state = x_past
+            state = (x_past, mask_past)
         if hasattr(self, 'state_proj'):
             x = self.state_proj(x)
         x = self.dropout(x).add(res)
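
Note: the decoder block's incremental state changes from a single tensor of past inputs to a (past_inputs, past_key_mask) pair, so cached padded positions keep being masked as new tokens are appended along the time dimension. A toy illustration of the bookkeeping only (hypothetical shapes, no attention call):

import torch

batch, hidden, time_dim = 2, 8, 1                      # batch-first: time is dim 1
state = None
for step in range(3):
    x = torch.randn(batch, 1, hidden)                  # one new target position per step
    mask_k = torch.zeros(batch, 1, dtype=torch.bool)   # True would mark padding
    if state is None:
        x_past, mask_past = x, mask_k
    else:
        x_past, mask_past = state
        x_past = torch.cat((x_past, x), time_dim)
        mask_past = torch.cat((mask_past, mask_k), time_dim)
    # ...masked self-attention over (x, x_past, x_past) with mask_past as key padding mask...
    state = (x_past, mask_past)

assert state[0].shape == (batch, 3, hidden) and state[1].shape == (batch, 3)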

seq2seq/models/transformer.py

Lines changed: 18 additions & 11 deletions
@@ -12,7 +12,7 @@ class TransformerAttentionEncoder(nn.Module):
 
     def __init__(self, vocab_size, hidden_size=512, embedding_size=None,
                  num_layers=6, num_heads=8, inner_linear=2048, inner_groups=1, prenormalized=False,
-                 mask_symbol=PAD, layer_norm=True, weight_norm=False, dropout=0, embedder=None):
+                 mask_symbol=PAD, batch_first=True, layer_norm=True, weight_norm=False, dropout=0, embedder=None):
 
         super(TransformerAttentionEncoder, self).__init__()
         embedding_size = embedding_size or hidden_size
@@ -21,7 +21,7 @@ def __init__(self, vocab_size, hidden_size=512, embedding_size=None,
             torch.empty(embedding_size, hidden_size))
         nn.init.kaiming_uniform_(self.input_projection, a=math.sqrt(5))
         self.hidden_size = hidden_size
-        self.batch_first = True
+        self.batch_first = batch_first
         self.mask_symbol = mask_symbol
         self.embedder = embedder or nn.Embedding(
             vocab_size, embedding_size, padding_idx=PAD)
@@ -37,6 +37,7 @@ def __init__(self, vocab_size, hidden_size=512, embedding_size=None,
                           inner_groups=inner_groups,
                           layer_norm=layer_norm,
                           weight_norm=weight_norm,
+                          batch_first=batch_first,
                           dropout=dropout)
             for _ in range(num_layers)
         ])
@@ -51,7 +52,7 @@ def forward(self, inputs, hidden=None):
         x = self.embedder(inputs).mul_(self.scale_embedding)
         if hasattr(self, 'input_projection'):
             x = x @ self.input_projection
-        x.add_(positional_embedding(x))
+        x.add_(positional_embedding(x, batch_first=self.batch_first))
         x = self.dropout(x)
 
         for block in self.blocks:
@@ -61,13 +62,13 @@ def forward(self, inputs, hidden=None):
         if hasattr(self, 'lnorm'):
             x = self.lnorm(x)
 
-        return State(outputs=x, mask=padding_mask, batch_first=True)
+        return State(outputs=x, mask=padding_mask, batch_first=self.batch_first)
 
 
 class TransformerAttentionDecoder(nn.Module):
 
-    def __init__(self, vocab_size, hidden_size=512, embedding_size=None, num_layers=6,
-                 num_heads=8, dropout=0, inner_linear=2048, inner_groups=1, prenormalized=False, stateful=None, state_dim=None,
+    def __init__(self, vocab_size, hidden_size=512, embedding_size=None, num_layers=6, num_heads=8,
+                 batch_first=True, dropout=0, inner_linear=2048, inner_groups=1, prenormalized=False, stateful=None, state_dim=None,
                  mask_symbol=PAD, tie_embedding=True, layer_norm=True, weight_norm=False, embedder=None, classifier=True):
 
         super(TransformerAttentionDecoder, self).__init__()
@@ -76,7 +77,7 @@ def __init__(self, vocab_size, hidden_size=512, embedding_size=None, num_layers=
         self.input_projection = nn.Parameter(
             torch.empty(embedding_size, hidden_size))
         nn.init.kaiming_uniform_(self.input_projection, a=math.sqrt(5))
-        self.batch_first = True
+        self.batch_first = batch_first
         self.mask_symbol = mask_symbol
         self.embedder = embedder or nn.Embedding(
             vocab_size, embedding_size, padding_idx=PAD)
@@ -94,6 +95,7 @@ def __init__(self, vocab_size, hidden_size=512, embedding_size=None, num_layers=
                           layer_norm=layer_norm,
                           weight_norm=weight_norm,
                           dropout=dropout,
+                          batch_first=batch_first,
                           stateful=stateful,
                           state_dim=state_dim)
             for _ in range(num_layers)
@@ -125,7 +127,9 @@ def forward(self, inputs, state, get_attention=False):
             time_step = self.time_step
         else:
             block_state = state.inputs
-            time_step = 0 if block_state is None else block_state[0].size(1)
+            time_dim = 1 if self.batch_first else 0
+            time_step = 0 if block_state is None else \
+                block_state[0][0].size(time_dim)
 
         if block_state is None:
             block_state = [None] * len(self.blocks)
@@ -137,7 +141,8 @@ def forward(self, inputs, state, get_attention=False):
         x = self.embedder(inputs).mul_(self.scale_embedding)
         if hasattr(self, 'input_projection'):
             x = x @ self.input_projection
-        x.add_(positional_embedding(x, offset=time_step))
+        x.add_(positional_embedding(
+            x, batch_first=self.batch_first, offset=time_step))
         x = self.dropout(x)
 
         attention_scores = []
@@ -173,7 +178,7 @@ class Transformer(Seq2Seq):
 
     def __init__(self, vocab_size, hidden_size=512, embedding_size=None, num_layers=6, num_heads=8,
                  inner_linear=2048, inner_groups=1, dropout=0.1, prenormalized=False, tie_embedding=True,
-                 encoder=None, decoder=None, layer_norm=True, weight_norm=False, stateful=None):
+                 encoder=None, decoder=None, layer_norm=True, weight_norm=False, batch_first=True, stateful=None):
         super(Transformer, self).__init__()
         embedding_size = embedding_size or hidden_size
         # keeping encoder, decoder None will result with default configuration
@@ -192,6 +197,7 @@ def __init__(self, vocab_size, hidden_size=512, embedding_size=None, num_layers=
         encoder.setdefault('inner_linear', inner_linear)
         encoder.setdefault('inner_groups', inner_groups)
         encoder.setdefault('prenormalized', prenormalized)
+        encoder.setdefault('batch_first', batch_first)
 
         decoder.setdefault('embedding_size', embedding_size)
         decoder.setdefault('hidden_size', hidden_size)
@@ -204,6 +210,7 @@ def __init__(self, vocab_size, hidden_size=512, embedding_size=None, num_layers=
         decoder.setdefault('dropout', dropout)
         decoder.setdefault('inner_linear', inner_linear)
         decoder.setdefault('inner_groups', inner_groups)
+        decoder.setdefault('batch_first', batch_first)
         decoder.setdefault('prenormalized', prenormalized)
         decoder.setdefault('stateful', stateful)
 
@@ -214,7 +221,7 @@ def __init__(self, vocab_size, hidden_size=512, embedding_size=None, num_layers=
         decoder.setdefault('embedder', embedder)
         decoder['classifier'] = False
 
-        self.batch_first = True
+        self.batch_first = batch_first
         self.encoder = TransformerAttentionEncoder(**encoder)
         self.decoder = TransformerAttentionDecoder(**decoder)
 
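Note: batch_first now defaults to True and is propagated into the encoder/decoder configuration dicts via setdefault, so explicit per-side settings still take precedence. Construction sketch only (assumes the seq2seq package from this repo is importable; the forward call is not shown):

import torch
from seq2seq.models.transformer import Transformer

model = Transformer(vocab_size=32000, hidden_size=512, num_layers=6, num_heads=8,
                    batch_first=True,                 # False switches to (time, batch, dim) tensors
                    encoder={'inner_linear': 1024})   # explicit values survive the setdefault calls
model.to(torch.device('cpu'), dtype=torch.float32)    # mirrors main.py's model.to(device, dtype=dtype)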