diff --git a/fastdeploy/model_executor/model_loader/default_loader.py b/fastdeploy/model_executor/model_loader/default_loader.py
index 7be3dca6a8a..69db999918a 100644
--- a/fastdeploy/model_executor/model_loader/default_loader.py
+++ b/fastdeploy/model_executor/model_loader/default_loader.py
@@ -94,3 +94,31 @@ def load_model(self, fd_config: FDConfig) -> nn.Layer:
         # TODO(gongshaotian): Now, only support safetensor
         self.load_weights(model, fd_config, architectures)
         return model
+
+    def load_rl_model(self, fd_config: FDConfig) -> nn.Layer:
+        """Build the RL variant of the model; load weights only for the "normal" strategy."""
+        # TODO(gaoziyuan): optimize
+        original_architectures = fd_config.model_config.architectures[0]
+        logger.info(f"Starting to load model {original_architectures}")
+
+        import fastdeploy.rl  # noqa
+
+        if fd_config.speculative_config.model_type != "mtp":
+            model_architectures = original_architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
+        else:
+            model_architectures = original_architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
+
+        model_architectures += "RL"
+        context = paddle.LazyGuard()
+
+        with context:
+            model_cls = ModelRegistry.get_class(model_architectures)
+            model = model_cls(fd_config)
+
+        model.eval()
+
+        if fd_config.load_config.load_strategy == "normal":
+            # The "normal" strategy loads weights here; the registry key without the "RL" suffix is used.
+            self.load_weights(model, fd_config, original_architectures)
+        # Other strategies (ipc/ipc_snapshot) receive weights later, so no set_state_dict here.
+        return model
diff --git a/fastdeploy/model_executor/model_loader/default_loader_v1.py b/fastdeploy/model_executor/model_loader/default_loader_v1.py
index 9164e61af2e..b0708c7d062 100644
--- a/fastdeploy/model_executor/model_loader/default_loader_v1.py
+++ b/fastdeploy/model_executor/model_loader/default_loader_v1.py
@@ -77,3 +77,31 @@ def load_model(self, fd_config: FDConfig) -> nn.Layer:
             return model
         self.load_weights(model, fd_config)
         return model
+
+    def load_rl_model(self, fd_config: FDConfig) -> nn.Layer:
+        """Build the RL variant of the model; load weights only for the "normal" strategy."""
+        # TODO(gaoziyuan): optimize
+        original_architectures = fd_config.model_config.architectures[0]
+        logger.info(f"Starting to load model {original_architectures}")
+
+        import fastdeploy.rl  # noqa
+
+        if fd_config.speculative_config.model_type != "mtp":
+            model_architectures = original_architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
+        else:
+            model_architectures = original_architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
+
+        model_architectures += "RL"
+        context = paddle.LazyGuard()
+
+        with context:
+            model_cls = ModelRegistry.get_class(model_architectures)
+            model = model_cls(fd_config)
+
+        model.eval()
+
+        if fd_config.load_config.load_strategy == "normal":
+            # NOTE(review): load_model above calls load_weights with 2 args — confirm this 3-arg form is valid here.
+            self.load_weights(model, fd_config, original_architectures)
+        # Other strategies (ipc/ipc_snapshot) receive weights later, so no set_state_dict here.
+        return model
diff --git a/fastdeploy/rl/dynamic_weight_manager.py b/fastdeploy/rl/dynamic_weight_manager.py
index e3eea99a603..1cb0d66e96a 100644
--- a/fastdeploy/rl/dynamic_weight_manager.py
+++ b/fastdeploy/rl/dynamic_weight_manager.py
@@ -84,6 +84,7 @@ def update_parameters(self, pid: int = 0) -> None:
         strategy_handlers = {
             "ipc_snapshot": self._update_ipc_snapshot,
             "ipc": self._update_ipc,
+            "normal": self._normal_load_weight,
         }
 
         if handler := strategy_handlers.get(self.load_config.load_strategy):
@@ -97,6 +98,14 @@ def update_parameters(self, pid: int = 0) -> None:
         # step 5: recapture CUDAGraph
         # step 6: update weight status signal
 
+    def _normal_load_weight(self):
+        """Reload weights via the model loader ("normal" strategy, used for RL mock)."""
+        from fastdeploy.model_executor.model_loader import get_model_loader
+
+        model_loader = get_model_loader(load_config=self.fd_config.load_config)
+        state_dict = model_loader.load_rl_model(fd_config=self.fd_config).state_dict()
+        self._update_model_from_state(state_dict, "raw")
+
     def _update_ipc_snapshot(self):
         """Update using IPC snapshot strategy for elastic recovery."""
         model_path = os.path.join(