@@ -12,7 +12,7 @@ class TransformerAttentionEncoder(nn.Module):
 
     def __init__(self, vocab_size, hidden_size=512, embedding_size=None,
                  num_layers=6, num_heads=8, inner_linear=2048, inner_groups=1, prenormalized=False,
-                 mask_symbol=PAD, layer_norm=True, weight_norm=False, dropout=0, embedder=None):
+                 mask_symbol=PAD, batch_first=True, layer_norm=True, weight_norm=False, dropout=0, embedder=None):
 
         super(TransformerAttentionEncoder, self).__init__()
         embedding_size = embedding_size or hidden_size
@@ -21,7 +21,7 @@ def __init__(self, vocab_size, hidden_size=512, embedding_size=None,
                 torch.empty(embedding_size, hidden_size))
             nn.init.kaiming_uniform_(self.input_projection, a=math.sqrt(5))
         self.hidden_size = hidden_size
-        self.batch_first = True
+        self.batch_first = batch_first
         self.mask_symbol = mask_symbol
         self.embedder = embedder or nn.Embedding(
             vocab_size, embedding_size, padding_idx=PAD)
@@ -37,6 +37,7 @@ def __init__(self, vocab_size, hidden_size=512, embedding_size=None,
                          inner_groups=inner_groups,
                          layer_norm=layer_norm,
                          weight_norm=weight_norm,
+                         batch_first=batch_first,
                          dropout=dropout)
             for _ in range(num_layers)
         ])
@@ -51,7 +52,7 @@ def forward(self, inputs, hidden=None):
         x = self.embedder(inputs).mul_(self.scale_embedding)
         if hasattr(self, 'input_projection'):
             x = x @ self.input_projection
-        x.add_(positional_embedding(x))
+        x.add_(positional_embedding(x, batch_first=self.batch_first))
         x = self.dropout(x)
 
         for block in self.blocks:
@@ -61,13 +62,13 @@ def forward(self, inputs, hidden=None):
         if hasattr(self, 'lnorm'):
             x = self.lnorm(x)
 
-        return State(outputs=x, mask=padding_mask, batch_first=True)
+        return State(outputs=x, mask=padding_mask, batch_first=self.batch_first)
 
 
 class TransformerAttentionDecoder(nn.Module):
 
-    def __init__(self, vocab_size, hidden_size=512, embedding_size=None, num_layers=6,
-                 num_heads=8, dropout=0, inner_linear=2048, inner_groups=1, prenormalized=False, stateful=None, state_dim=None,
+    def __init__(self, vocab_size, hidden_size=512, embedding_size=None, num_layers=6, num_heads=8,
+                 batch_first=True, dropout=0, inner_linear=2048, inner_groups=1, prenormalized=False, stateful=None, state_dim=None,
                  mask_symbol=PAD, tie_embedding=True, layer_norm=True, weight_norm=False, embedder=None, classifier=True):
 
         super(TransformerAttentionDecoder, self).__init__()
@@ -76,7 +77,7 @@ def __init__(self, vocab_size, hidden_size=512, embedding_size=None, num_layers=
             self.input_projection = nn.Parameter(
                 torch.empty(embedding_size, hidden_size))
             nn.init.kaiming_uniform_(self.input_projection, a=math.sqrt(5))
-        self.batch_first = True
+        self.batch_first = batch_first
         self.mask_symbol = mask_symbol
         self.embedder = embedder or nn.Embedding(
             vocab_size, embedding_size, padding_idx=PAD)
@@ -94,6 +95,7 @@ def __init__(self, vocab_size, hidden_size=512, embedding_size=None, num_layers=
                          layer_norm=layer_norm,
                          weight_norm=weight_norm,
                          dropout=dropout,
+                         batch_first=batch_first,
                          stateful=stateful,
                          state_dim=state_dim)
             for _ in range(num_layers)
@@ -125,7 +127,9 @@ def forward(self, inputs, state, get_attention=False):
             time_step = self.time_step
         else:
             block_state = state.inputs
-            time_step = 0 if block_state is None else block_state[0].size(1)
+            time_dim = 1 if self.batch_first else 0
+            time_step = 0 if block_state is None else \
+                block_state[0][0].size(time_dim)
 
         if block_state is None:
             block_state = [None] * len(self.blocks)
@@ -137,7 +141,8 @@ def forward(self, inputs, state, get_attention=False):
         x = self.embedder(inputs).mul_(self.scale_embedding)
         if hasattr(self, 'input_projection'):
             x = x @ self.input_projection
-        x.add_(positional_embedding(x, offset=time_step))
+        x.add_(positional_embedding(
+            x, batch_first=self.batch_first, offset=time_step))
         x = self.dropout(x)
 
         attention_scores = []
@@ -173,7 +178,7 @@ class Transformer(Seq2Seq):
 
     def __init__(self, vocab_size, hidden_size=512, embedding_size=None, num_layers=6, num_heads=8,
                  inner_linear=2048, inner_groups=1, dropout=0.1, prenormalized=False, tie_embedding=True,
-                 encoder=None, decoder=None, layer_norm=True, weight_norm=False, stateful=None):
+                 encoder=None, decoder=None, layer_norm=True, weight_norm=False, batch_first=True, stateful=None):
         super(Transformer, self).__init__()
         embedding_size = embedding_size or hidden_size
         # keeping encoder, decoder None will result with default configuration
@@ -192,6 +197,7 @@ def __init__(self, vocab_size, hidden_size=512, embedding_size=None, num_layers=
         encoder.setdefault('inner_linear', inner_linear)
         encoder.setdefault('inner_groups', inner_groups)
         encoder.setdefault('prenormalized', prenormalized)
+        encoder.setdefault('batch_first', batch_first)
 
         decoder.setdefault('embedding_size', embedding_size)
         decoder.setdefault('hidden_size', hidden_size)
@@ -204,6 +210,7 @@ def __init__(self, vocab_size, hidden_size=512, embedding_size=None, num_layers=
         decoder.setdefault('dropout', dropout)
         decoder.setdefault('inner_linear', inner_linear)
         decoder.setdefault('inner_groups', inner_groups)
+        decoder.setdefault('batch_first', batch_first)
         decoder.setdefault('prenormalized', prenormalized)
         decoder.setdefault('stateful', stateful)
 
@@ -214,7 +221,7 @@ def __init__(self, vocab_size, hidden_size=512, embedding_size=None, num_layers=
             decoder.setdefault('embedder', embedder)
             decoder['classifier'] = False
 
-        self.batch_first = True
+        self.batch_first = batch_first
         self.encoder = TransformerAttentionEncoder(**encoder)
         self.decoder = TransformerAttentionDecoder(**decoder)
 
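The hunks above thread `batch_first` into `positional_embedding(...)`, but that helper is defined elsewhere in the repository and is not part of this diff. Purely as an illustration, a minimal sketch of a sinusoidal helper that honours the flag could look like the following; the name and the `batch_first`/`offset` arguments are taken from the call sites above, while the sinusoidal formulation and broadcasting layout are assumptions, not the repository's actual implementation.

```python
import math
import torch


def positional_embedding(x, batch_first=True, offset=0):
    # Sinusoidal position encodings (as in "Attention Is All You Need"),
    # returned in a shape that broadcasts against x. Assumes an even
    # number of channels. `batch_first` only decides which axis is time:
    # (batch, time, channels) when True, (time, batch, channels) when False.
    time_dim = 1 if batch_first else 0
    seq_len, channels = x.size(time_dim), x.size(-1)
    positions = torch.arange(offset, offset + seq_len,
                             dtype=x.dtype, device=x.device)
    inv_freq = torch.exp(
        torch.arange(0, channels, 2, dtype=x.dtype, device=x.device)
        * (-math.log(10000.0) / channels))
    angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)  # (time, channels/2)
    pe = torch.zeros(seq_len, channels, dtype=x.dtype, device=x.device)
    pe[:, 0::2] = torch.sin(angles)
    pe[:, 1::2] = torch.cos(angles)
    # (1, time, channels) for batch-first inputs, (time, 1, channels) otherwise,
    # so x.add_(positional_embedding(x, ...)) broadcasts over the batch axis.
    return pe.unsqueeze(0) if batch_first else pe.unsqueeze(1)
```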
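For a sense of what the flag enables end to end: with `batch_first=False` the encoder is meant to consume time-major `(time, batch)` token tensors and keep its outputs time-major. A hypothetical usage sketch, assuming only the constructor arguments visible in this diff and the `State(outputs=..., batch_first=...)` value returned by the encoder's `forward`; the vocabulary size and tensor shapes below are placeholders:

```python
import torch

# Hypothetical values; not taken from the repository's configs.
encoder = TransformerAttentionEncoder(vocab_size=32000,
                                      hidden_size=512,
                                      num_layers=6,
                                      num_heads=8,
                                      batch_first=False)

# Time-major input: (time, batch) instead of (batch, time).
tokens = torch.randint(1, 32000, (20, 4))

state = encoder(tokens)
# With batch_first=False the encoder outputs are expected to stay
# time-major, i.e. state.outputs has shape (20, 4, 512).
```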