1818# are both < 2048 tokens.
1919
2020
21- def rotaryembeddings (dim : int , maxseqlen = 2048 , base = 10000 ):
21+ def rotaryembeddings (dim : int , maxseqlen = 2048 , base = 10000 , device = None ):
2222 inv_freq = 1.0 / (base ** (torch .arange (0 , dim , 2 ).float () / dim ))
2323 tmax = torch .arange (maxseqlen , device = inv_freq .device )
2424 rope = torch .outer (tmax , inv_freq ).float ()
2525 # rope is now matrix [maxseqlen, dim/2]
2626 rope = torch .polar (torch .ones_like (rope ), rope )
2727 rope = torch .cat ((rope , rope ), dim = 1 )
28- return rope
28+ if device is not None :
29+ rope = rope .to (device )
30+ cos = rope [:, : rope .size (1 ) // 2 ].real .contiguous ().half ()
31+ sin = rope [:, : rope .size (1 ) // 2 ].imag .contiguous ().half ()
32+ return rope , cos , sin
2933
3034
3135def rotate_half (x ):
@@ -369,12 +373,8 @@ def __init__(
369373 self .rotary_dim = self .dim_per_head
370374 else :
371375 self .rotary_dim = rotary_dim
372- self .rope = rotaryembeddings (self .rotary_dim , base = rotary_theta )
373- self .cos = (
374- self .rope [:, : self .rope .size (1 ) // 2 ].real .contiguous ().half ()
375- )
376- self .sin = (
377- self .rope [:, : self .rope .size (1 ) // 2 ].imag .contiguous ().half ()
376+ self .rope , self .cos , self .sin = rotaryembeddings (
377+ self .rotary_dim , base = rotary_theta
378378 )
379379 self .rotary_interleave = rotary_interleave
380380 self .rotary_theta = rotary_theta
@@ -465,11 +465,13 @@ def forward(
465465 ):
466466 if self .max_relative_positions == - 1 : # Rotary Embeddings
467467 if seqlen > self .rope .size (0 ):
468- self .rope = rotaryembeddings (
468+
469+ self .rope , _ , _ = rotaryembeddings (
469470 self .rotary_dim ,
470471 maxseqlen = (seqlen + 2048 ),
471472 base = self .rotary_theta ,
472- ).to (self .rope .device )
473+ device = self .rope .device ,
474+ )
473475 rope = self .rope [start_pos : start_pos + seqlen ]
474476 query , key = apply_rotary_emb (
475477 query , key , rope , interleave = self .rotary_interleave
@@ -486,23 +488,6 @@ def forward(
486488 self .layer_cache [1 ]["values" ] = value
487489
488490 else :
489- if self .max_relative_positions == - 1 : # Rotary Embeddings
490- if seqlen > self .rope .size (0 ):
491- self .rope = rotaryembeddings (
492- self .rotary_dim ,
493- maxseqlen = (seqlen + 2048 ),
494- base = self .rotary_theta ,
495- ).to (self .rope .device )
496- self .cos = (
497- self .rope [:, : self .rope .size (1 ) // 2 ]
498- .real .contiguous ()
499- .half ()
500- )
501- self .sin = (
502- self .rope [:, : self .rope .size (1 ) // 2 ]
503- .imag .contiguous ()
504- .half ()
505- )
506491 if start_pos >= self .layer_cache [1 ]["keys" ].size (2 ):
507492 self .layer_cache [1 ]["keys" ] = torch .cat (
508493 [
@@ -528,6 +513,20 @@ def forward(
528513 ],
529514 dim = - 2 ,
530515 )
516+ if (
517+ self .max_relative_positions == - 1
518+ and start_pos + 32 >= self .rope .size (0 )
519+ ):
520+ # Resize rotary embeddings.
521+ # We take a margin of 32 tokens as the kv_cache
522+ # is incremented by 32 tokens every 32 tokens.
523+ self .rope , self .cos , self .sin = rotaryembeddings (
524+ self .rotary_dim ,
525+ maxseqlen = (start_pos + 2048 ),
526+ base = self .rotary_theta ,
527+ device = self .rope .device ,
528+ )
529+
531530 if sliding_window > 0 and key .size (2 ) > sliding_window :
532531 self .layer_cache [1 ]["keys" ] = self .layer_cache [1 ]["keys" ][
533532 :, :, 1 :, :
@@ -593,12 +592,14 @@ def forward(
593592 start_pos = 0
594593 seqlen = query .size (2 )
595594 if seqlen > self .rope .size (0 ):
596- self .rope = rotaryembeddings (
595+ # Resize rotary embeddings.
596+ self .rope , self .cos , self .sin = rotaryembeddings (
597597 self .rotary_dim ,
598598 maxseqlen = (seqlen + 2048 ),
599599 base = self .rotary_theta ,
600- ).to (self .rope .device )
601- rope = self .rope [start_pos : start_pos + seqlen ].to (query .device )
600+ device = query .device ,
601+ )
602+ rope = self .rope [start_pos : start_pos + seqlen ]
602603 query , key = apply_rotary_emb (
603604 query , key , rope , interleave = self .rotary_interleave
604605 )
0 commit comments