diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py
index e7725be6d23..1b7f947c1f3 100644
--- a/fastdeploy/model_executor/layers/linear.py
+++ b/fastdeploy/model_executor/layers/linear.py
@@ -360,10 +360,6 @@ def __init__(
         self.output_sizes = output_sizes
 
     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
-        weight_need_transpose = getattr(param, "weight_need_transpose", False)
-        if weight_need_transpose:
-            loaded_weight = get_tensor(loaded_weight).transpose([1, 0])
-
         assert loaded_shard_id in ["q_a", "kv_a"]
         if not param._is_initialized():
             param.initialize()
@@ -389,7 +385,6 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N
         else:
             loaded_weight = loaded_weight.cast(param.dtype)
         # (bukejiyu) After this fix, the early H2D copy for non-GPU devices is no longer needed and can be safely removed.
-        loaded_weight = get_tensor(loaded_weight)
         h2d_copy(param, loaded_weight)
 
 
@@ -962,7 +957,10 @@ def __init__(
         self.num_heads_per_partition = divide(num_attention_heads, self.nranks)
         self.local_rank = fd_config.parallel_config.tensor_parallel_rank
         self.fd_config = fd_config
-        self.kv_b_proj = kv_b_proj
+        if self.fd_config.load_config.load_choices == "default_v1":
+            self.kv_b_proj = kv_b_proj
+        else:
+            self.kv_b_proj = None
 
         self.weight_dtype = self._helper.get_default_dtype()
 
diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py
index 3b42e0294e6..5debb0b790d 100644
--- a/fastdeploy/model_executor/utils.py
+++ b/fastdeploy/model_executor/utils.py
@@ -141,7 +141,7 @@ def process_weight_transpose(layer, weight_name):
                 is_bias=False,
             )
             if layer.fd_config.load_config.dynamic_load_weight or layer.fd_config.model_config.enable_cache:
-                free_tensor(weight)
+                free_tensor(weight, clear_memory=False)
             setattr(layer, weight_name, weight_tmp)
             return
 
@@ -150,7 +150,7 @@ def process_weight_transpose(layer, weight_name):
     elif len(weight.shape) == 3:
         weight_transpose = weight.transpose([0, 2, 1])
     weight_tmp.copy_(weight_transpose, False)
-    free_tensor(weight)
+    free_tensor(weight, clear_memory=False)
     setattr(layer, weight_name, weight_tmp)
 
 
@@ -260,11 +260,13 @@ def process_final_after_loading(model, fd_config: FDConfig):
             sublayer.process_weights_after_loading()
 
 
-def free_tensor(tensor):
+def free_tensor(tensor, clear_memory=True):
     if hasattr(tensor, "tensor_track"):
         tensor.tensor_track = None
     tensor.value().get_tensor()._clear()
     del tensor
+    if clear_memory:
+        paddle.device.cuda.empty_cache()
 
 
 def fd_cast(weight, param):