@@ -460,14 +460,6 @@ def __init__(
460460 devs = self .model .get_cache_devices () if self .fixed_device is None else [self .fixed_device ]
461461 for device in devs : self .touch_device (device )
462462
463- # Calibration mode
464-
465- self .calibrated = False
466- self .calibrating = False
467- self .calibration_rows = [0 ] * cfg .num_hidden_layers
468- self .calibration_k = {}
469- self .calibration_v = {}
470-
471463
472464 def touch_device (self , device ):
473465
@@ -516,15 +508,9 @@ def get_kv_state(
516508 block_table if block_table is not None else none_tensor ,
517509 # none_tensor,
518510 # none_tensor
519- self .calibration_k [layer_idx ] if self .calibrated else none_tensor ,
520- self .calibration_v [layer_idx ] if self .calibrated else none_tensor ,
521511 self .wbits
522512 )
523513
524- # if self.calibrated:
525- # temp_key_state *= self.calibration_k[layer_idx]
526- # temp_value_state *= self.calibration_v[layer_idx]
527-
528514 return temp_key_state , temp_value_state
529515
530516
@@ -551,10 +537,6 @@ def store_kv_state(
551537 device = self .model .cache_map .get (layer_idx , self .fixed_device )
552538 temp_key_state , temp_value_state = self .temp_tensors [device ]
553539
554- # if self.calibrated:
555- # temp_key_state /= self.calibration_k[layer_idx]
556- # temp_value_state /= self.calibration_v[layer_idx]
557-
558540 ext_c .fp16_to_q_kv (
559541 temp_key_state ,
560542 self .key_states [layer_idx ],
@@ -570,40 +552,9 @@ def store_kv_state(
570552 block_table if block_table is not None else none_tensor ,
571553 # none_tensor,
572554 # none_tensor
573- self .calibration_k [layer_idx ] if self .calibrated else none_tensor ,
574- self .calibration_v [layer_idx ] if self .calibrated else none_tensor ,
575555 self .wbits
576556 )
577557
578- # Collect calibration data
579-
580- if self .calibrating :
581-
582- cfg = self .model .config
583-
584- if layer_idx not in self .calibration_k :
585- self .calibration_k [layer_idx ] = torch .zeros (
586- (cfg .num_key_value_heads , cfg .head_dim ,),
587- dtype = torch .float ,
588- device = temp_key_state .device
589- )
590- self .calibration_v [layer_idx ] = torch .zeros (
591- (cfg .num_key_value_heads , cfg .head_dim ,),
592- dtype = torch .float ,
593- device = temp_key_state .device
594- )
595-
596- b , l , h , d = temp_key_state .shape
597- cal_k = self .calibration_k [layer_idx ]
598- cal_v = self .calibration_v [layer_idx ]
599- cal_k_input = temp_key_state [:, offset :offset + width , :, :].view (b * width , h * d )
600- cal_v_input = temp_value_state [:, offset :offset + width , :, :].view (b * width , h * d )
601- cal_k_sum = torch .norm (cal_k_input , p = 1 , dim = 0 , dtype = torch .float )
602- cal_v_sum = torch .norm (cal_v_input , p = 1 , dim = 0 , dtype = torch .float )
603- cal_k .add_ (cal_k_sum .view (h , d ))
604- cal_v .add_ (cal_v_sum .view (h , d ))
605- self .calibration_rows [layer_idx ] += width
606-
607558
608559 def footprint (self ) -> list [int ]:
609560
@@ -623,57 +574,13 @@ def footprint(self) -> list[int]:
623574
624575
625576 def clone (self ) -> ExLlamaV2Cache_Q4 :
626-
627577 new = ExLlamaV2Cache_Q4 (self .model , self .batch_size , self .max_seq_len , self )
628578 return new
629579
630-
631580 def all_tensors (self ):
632581 return self .key_states + self .value_states + self .key_scales + self .value_scales
633582
634583
635- def calibrate (self ,
636- tokenizer : ExLlamaV2Tokenizer ,
637- num_batches = 8 ,
638- num_samples_per_batch = 256
639- ):
640- """
641- Unfinished
642- """
643-
644- assert self .max_seq_len >= num_samples_per_batch , \
645- f"Cache max_seq_len must be at least { num_samples_per_batch } to calibrate."
646-
647- self .calibrating = True
648- torch .manual_seed (123 )
649-
650- for _ in range (num_batches ):
651-
652- input_ids = torch .randint (
653- low = 0 ,
654- high = tokenizer .get_vocab_size () - 1 ,
655- size = (1 , num_samples_per_batch ),
656- dtype = torch .long
657- )
658-
659- self .reset ()
660- self .model .forward (input_ids , preprocess_only = True , cache = self )
661-
662- self .calibrating = False
663-
664- for i in range (self .model .config .num_hidden_layers ):
665- cal_k = self .calibration_k [i ] / self .calibration_rows [i ] # self.calibration_k[i].mean()
666- cal_v = self .calibration_v [i ] / self .calibration_rows [i ] # self.calibration_v[i].mean()
667- cal_k = cal_k ** (1 / 8 )
668- cal_v = cal_v ** (1 / 8 )
669- cal_k = cal_k .half () * (- 1 )
670- cal_v = cal_v .half () * (- 1 )
671- self .calibration_k [i ] = cal_k
672- self .calibration_v [i ] = cal_v
673- self .calibrating = False
674- # self.calibrated = True
675-
676-
677584class ExLlamaV2Cache_Q4 (ExLlamaV2Cache_Q ):
678585
679586 def __init__ (
0 commit comments