Skip to content

Commit 0e72326

Browse files
authored
Fix generation with large sequences (#2561)
* Fixed a bug in generation when the sequence is longer than 2048 tokens.
* Fixed the dynamic resizing of rotary embeddings.
1 parent 5deb20e commit 0e72326

File tree

1 file changed

+31
-30
lines changed

1 file changed

+31
-30
lines changed

onmt/modules/multi_headed_attn.py

Lines changed: 31 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -18,14 +18,18 @@
1818
# are both < 2048 tokens.
1919

2020

21-
def rotaryembeddings(dim: int, maxseqlen=2048, base=10000):
21+
def rotaryembeddings(dim: int, maxseqlen=2048, base=10000, device=None):
2222
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
2323
tmax = torch.arange(maxseqlen, device=inv_freq.device)
2424
rope = torch.outer(tmax, inv_freq).float()
2525
# rope is now matrix [maxseqlen, dim/2]
2626
rope = torch.polar(torch.ones_like(rope), rope)
2727
rope = torch.cat((rope, rope), dim=1)
28-
return rope
28+
if device is not None:
29+
rope = rope.to(device)
30+
cos = rope[:, : rope.size(1) // 2].real.contiguous().half()
31+
sin = rope[:, : rope.size(1) // 2].imag.contiguous().half()
32+
return rope, cos, sin
2933

3034

3135
def rotate_half(x):
@@ -369,12 +373,8 @@ def __init__(
369373
self.rotary_dim = self.dim_per_head
370374
else:
371375
self.rotary_dim = rotary_dim
372-
self.rope = rotaryembeddings(self.rotary_dim, base=rotary_theta)
373-
self.cos = (
374-
self.rope[:, : self.rope.size(1) // 2].real.contiguous().half()
375-
)
376-
self.sin = (
377-
self.rope[:, : self.rope.size(1) // 2].imag.contiguous().half()
376+
self.rope, self.cos, self.sin = rotaryembeddings(
377+
self.rotary_dim, base=rotary_theta
378378
)
379379
self.rotary_interleave = rotary_interleave
380380
self.rotary_theta = rotary_theta
@@ -465,11 +465,13 @@ def forward(
465465
):
466466
if self.max_relative_positions == -1: # Rotary Embeddings
467467
if seqlen > self.rope.size(0):
468-
self.rope = rotaryembeddings(
468+
469+
self.rope, _, _ = rotaryembeddings(
469470
self.rotary_dim,
470471
maxseqlen=(seqlen + 2048),
471472
base=self.rotary_theta,
472-
).to(self.rope.device)
473+
device=self.rope.device,
474+
)
473475
rope = self.rope[start_pos : start_pos + seqlen]
474476
query, key = apply_rotary_emb(
475477
query, key, rope, interleave=self.rotary_interleave
@@ -486,23 +488,6 @@ def forward(
486488
self.layer_cache[1]["values"] = value
487489

488490
else:
489-
if self.max_relative_positions == -1: # Rotary Embeddings
490-
if seqlen > self.rope.size(0):
491-
self.rope = rotaryembeddings(
492-
self.rotary_dim,
493-
maxseqlen=(seqlen + 2048),
494-
base=self.rotary_theta,
495-
).to(self.rope.device)
496-
self.cos = (
497-
self.rope[:, : self.rope.size(1) // 2]
498-
.real.contiguous()
499-
.half()
500-
)
501-
self.sin = (
502-
self.rope[:, : self.rope.size(1) // 2]
503-
.imag.contiguous()
504-
.half()
505-
)
506491
if start_pos >= self.layer_cache[1]["keys"].size(2):
507492
self.layer_cache[1]["keys"] = torch.cat(
508493
[
@@ -528,6 +513,20 @@ def forward(
528513
],
529514
dim=-2,
530515
)
516+
if (
517+
self.max_relative_positions == -1
518+
and start_pos + 32 >= self.rope.size(0)
519+
):
520+
# Resize rotary embeddings.
521+
# We take a margin of 32 tokens as the kv_cache
522+
# is incremented by 32 tokens every 32 tokens.
523+
self.rope, self.cos, self.sin = rotaryembeddings(
524+
self.rotary_dim,
525+
maxseqlen=(start_pos + 2048),
526+
base=self.rotary_theta,
527+
device=self.rope.device,
528+
)
529+
531530
if sliding_window > 0 and key.size(2) > sliding_window:
532531
self.layer_cache[1]["keys"] = self.layer_cache[1]["keys"][
533532
:, :, 1:, :
@@ -593,12 +592,14 @@ def forward(
593592
start_pos = 0
594593
seqlen = query.size(2)
595594
if seqlen > self.rope.size(0):
596-
self.rope = rotaryembeddings(
595+
# Resize rotary embeddings.
596+
self.rope, self.cos, self.sin = rotaryembeddings(
597597
self.rotary_dim,
598598
maxseqlen=(seqlen + 2048),
599599
base=self.rotary_theta,
600-
).to(self.rope.device)
601-
rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
600+
device=query.device,
601+
)
602+
rope = self.rope[start_pos : start_pos + seqlen]
602603
query, key = apply_rotary_emb(
603604
query, key, rope, interleave=self.rotary_interleave
604605
)

0 commit comments

Comments
 (0)