From 34c53c1ed8e20d5fdb7407938e8b159f70d8eded Mon Sep 17 00:00:00 2001 From: Eric Curtin Date: Fri, 28 Nov 2025 20:31:51 +0000 Subject: [PATCH] Add safetensors support So we can load these natively just like gguf Signed-off-by: Eric Curtin --- common/arg.cpp | 103 ++- common/download.cpp | 89 +- common/download.h | 12 + src/CMakeLists.txt | 5 + src/llama-hf-config.cpp | 220 +++++ src/llama-hf-config.h | 64 ++ src/llama-model-from-safetensors.cpp | 1234 ++++++++++++++++++++++++++ src/llama-model-from-safetensors.h | 76 ++ src/llama-model.cpp | 110 +++ src/llama-model.h | 13 + src/llama-safetensors-loader.cpp | 275 ++++++ src/llama-safetensors-loader.h | 96 ++ src/llama-safetensors-types.cpp | 171 ++++ src/llama-safetensors-types.h | 29 + src/llama-safetensors.cpp | 398 +++++++++ src/llama-safetensors.h | 139 +++ src/llama-vocab.cpp | 224 +++++ src/llama-vocab.h | 7 + src/llama.cpp | 66 ++ tools/run/run.cpp | 15 +- 20 files changed, 3322 insertions(+), 24 deletions(-) create mode 100644 src/llama-hf-config.cpp create mode 100644 src/llama-hf-config.h create mode 100644 src/llama-model-from-safetensors.cpp create mode 100644 src/llama-model-from-safetensors.h create mode 100644 src/llama-safetensors-loader.cpp create mode 100644 src/llama-safetensors-loader.h create mode 100644 src/llama-safetensors-types.cpp create mode 100644 src/llama-safetensors-types.h create mode 100644 src/llama-safetensors.cpp create mode 100644 src/llama-safetensors.h diff --git a/common/arg.cpp b/common/arg.cpp index 9a874c6b1d0..a11f78c0f0f 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -224,30 +224,99 @@ static handle_model_result common_params_handle_model( if (model.hf_file.empty()) { if (model.path.empty()) { auto auto_detected = common_get_hf_file(model.hf_repo, bearer_token, offline); - if (auto_detected.repo.empty() || auto_detected.ggufFile.empty()) { + if (auto_detected.repo.empty()) { exit(1); // built without CURL, error message already printed } + model.hf_repo = auto_detected.repo; - model.hf_file = auto_detected.ggufFile; - if (!auto_detected.mmprojFile.empty()) { - result.found_mmproj = true; - result.mmproj.hf_repo = model.hf_repo; - result.mmproj.hf_file = auto_detected.mmprojFile; + + // Handle safetensors format + if (auto_detected.is_safetensors) { + LOG_INF("%s: detected safetensors format for %s\n", __func__, model.hf_repo.c_str()); + + // Create a directory for the safetensors files + std::string dir_name = model.hf_repo; + string_replace_all(dir_name, "/", "_"); + model.path = fs_get_cache_directory() + "/" + dir_name; + + // Create directory if it doesn't exist + std::filesystem::create_directories(model.path); + + // Download required files: config.json, tokenizer.json, tokenizer_config.json, and .safetensors files + std::string model_endpoint = get_model_endpoint(); + std::vector> files_to_download; + + // Required config files + files_to_download.push_back({ + model_endpoint + model.hf_repo + "/resolve/main/config.json", + model.path + "/config.json" + }); + files_to_download.push_back({ + model_endpoint + model.hf_repo + "/resolve/main/tokenizer.json", + model.path + "/tokenizer.json" + }); + files_to_download.push_back({ + model_endpoint + model.hf_repo + "/resolve/main/tokenizer_config.json", + model.path + "/tokenizer_config.json" + }); + + // Safetensors files + for (const auto & st_file : auto_detected.safetensors_files) { + files_to_download.push_back({ + model_endpoint + model.hf_repo + "/resolve/main/" + st_file, + model.path + "/" + st_file + }); + } + + // Download all 
files + LOG_INF("%s: downloading %zu files for safetensors model...\n", __func__, files_to_download.size()); + for (const auto & [url, path] : files_to_download) { + bool ok = common_download_file_single(url, path, bearer_token, offline); + if (!ok) { + LOG_ERR("error: failed to download file from %s\n", url.c_str()); + exit(1); + } + } + + LOG_INF("%s: safetensors model downloaded to %s\n", __func__, model.path.c_str()); + } else { + // Handle GGUF format (existing logic) + if (auto_detected.ggufFile.empty()) { + exit(1); // no GGUF file found + } + model.hf_file = auto_detected.ggufFile; + if (!auto_detected.mmprojFile.empty()) { + result.found_mmproj = true; + result.mmproj.hf_repo = model.hf_repo; + result.mmproj.hf_file = auto_detected.mmprojFile; + } + + std::string model_endpoint = get_model_endpoint(); + model.url = model_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file; + // make sure model path is present (for caching purposes) + if (model.path.empty()) { + // this is to avoid different repo having same file name, or same file name in different subdirs + std::string filename = model.hf_repo + "_" + model.hf_file; + // to make sure we don't have any slashes in the filename + string_replace_all(filename, "/", "_"); + model.path = fs_get_cache_file(filename); + } } } else { model.hf_file = model.path; } - } - - std::string model_endpoint = get_model_endpoint(); - model.url = model_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file; - // make sure model path is present (for caching purposes) - if (model.path.empty()) { - // this is to avoid different repo having same file name, or same file name in different subdirs - std::string filename = model.hf_repo + "_" + model.hf_file; - // to make sure we don't have any slashes in the filename - string_replace_all(filename, "/", "_"); - model.path = fs_get_cache_file(filename); + } else { + // User specified hf_file explicitly - use GGUF download path + std::string model_endpoint = get_model_endpoint(); + model.url = model_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file; + // make sure model path is present (for caching purposes) + if (model.path.empty()) { + // this is to avoid different repo having same file name, or same file name in different subdirs + std::string filename = model.hf_repo + "_" + model.hf_file; + // to make sure we don't have any slashes in the filename + string_replace_all(filename, "/", "_"); + model.path = fs_get_cache_file(filename); + } } } else if (!model.url.empty()) { diff --git a/common/download.cpp b/common/download.cpp index eeb32b6a863..ad9ee6456b9 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -715,10 +715,10 @@ std::pair> common_remote_get_content(const std::string #if defined(LLAMA_USE_CURL) || defined(LLAMA_USE_HTTPLIB) -static bool common_download_file_single(const std::string & url, - const std::string & path, - const std::string & bearer_token, - bool offline) { +bool common_download_file_single(const std::string & url, + const std::string & path, + const std::string & bearer_token, + bool offline) { if (!offline) { return common_download_file_single_online(url, path, bearer_token); } @@ -897,16 +897,93 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, cons } } else if (res_code == 401) { throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token"); + } else if (res_code == 400) { + // 400 typically means "not a GGUF repo" - we'll check for 
safetensors below + LOG_INF("%s: manifest endpoint returned 400 (not a GGUF repo), will check for safetensors...\n", __func__); } else { throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str())); } // check response if (ggufFile.empty()) { - throw std::runtime_error("error: model does not have ggufFile"); + // No GGUF found - try to detect safetensors format + LOG_INF("%s: no GGUF file found, checking for safetensors format...\n", __func__); + + // Query HF API to list files in the repo + std::string files_url = get_model_endpoint() + "api/models/" + hf_repo + "/tree/main"; + + common_remote_params files_params; + files_params.headers = headers; + + long files_res_code = 0; + std::string files_res_str; + + if (!offline) { + try { + auto files_res = common_remote_get_content(files_url, files_params); + files_res_code = files_res.first; + files_res_str = std::string(files_res.second.data(), files_res.second.size()); + } catch (const std::exception & e) { + throw std::runtime_error("error: model does not have ggufFile and failed to check for safetensors: " + std::string(e.what())); + } + } else { + throw std::runtime_error("error: model does not have ggufFile (offline mode, cannot check for safetensors)"); + } + + if (files_res_code != 200) { + throw std::runtime_error("error: model does not have ggufFile"); + } + + // Parse the files list + std::vector safetensors_files; + bool has_config = false; + bool has_tokenizer = false; + + try { + auto files_json = json::parse(files_res_str); + + for (const auto & file : files_json) { + if (file.contains("path")) { + std::string path = file["path"].get(); + + if (path == "config.json") { + has_config = true; + } else if (path == "tokenizer.json") { + has_tokenizer = true; + } else { + // Check for .safetensors extension + const std::string suffix = ".safetensors"; + if (path.size() >= suffix.size() && + path.compare(path.size() - suffix.size(), suffix.size(), suffix) == 0) { + safetensors_files.push_back(path); + } + } + } + } + } catch (const std::exception & e) { + throw std::runtime_error("error: model does not have ggufFile and failed to parse file list: " + std::string(e.what())); + } + + // Check if we have the required safetensors files + if (!has_config || !has_tokenizer || safetensors_files.empty()) { + throw std::runtime_error("error: model does not have ggufFile or valid safetensors format"); + } + + LOG_INF("%s: detected safetensors format with %zu tensor files\n", __func__, safetensors_files.size()); + + common_hf_file_res result; + result.repo = hf_repo; + result.is_safetensors = true; + result.safetensors_files = safetensors_files; + return result; } - return { hf_repo, ggufFile, mmprojFile }; + common_hf_file_res result; + result.repo = hf_repo; + result.ggufFile = ggufFile; + result.mmprojFile = mmprojFile; + result.is_safetensors = false; + return result; } // diff --git a/common/download.h b/common/download.h index 45a6bd6bba8..c7a186665c0 100644 --- a/common/download.h +++ b/common/download.h @@ -1,6 +1,7 @@ #pragma once #include +#include struct common_params_model; @@ -23,6 +24,10 @@ struct common_hf_file_res { std::string repo; // repo name with ":tag" removed std::string ggufFile; std::string mmprojFile; + + // Safetensors support + bool is_safetensors = false; // true if model is in safetensors format + std::vector safetensors_files; // list of .safetensors files to download }; /** @@ -41,6 +46,13 @@ common_hf_file_res common_get_hf_file( const std::string & 
bearer_token, bool offline); +// download a single file (no GGUF validation) +bool common_download_file_single( + const std::string & url, + const std::string & path, + const std::string & bearer_token, + bool offline); + // returns true if download succeeded bool common_download_model( const common_params_model & model, diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 67c7807e092..b5ce330e43f 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -32,6 +32,11 @@ add_library(llama llama-quant.cpp llama-sampling.cpp llama-vocab.cpp + llama-safetensors.cpp + llama-hf-config.cpp + llama-safetensors-loader.cpp + llama-safetensors-types.cpp + llama-model-from-safetensors.cpp unicode-data.cpp unicode.cpp unicode.h diff --git a/src/llama-hf-config.cpp b/src/llama-hf-config.cpp new file mode 100644 index 00000000000..8fa58cbaba4 --- /dev/null +++ b/src/llama-hf-config.cpp @@ -0,0 +1,220 @@ +#include "llama-hf-config.h" + +#include +#include "../vendor/nlohmann/json.hpp" + +using json = nlohmann::json; + +bool hf_config::load_from_file(const std::string & config_path) { + std::ifstream f(config_path); + if (!f.is_open()) { + error_msg = "Failed to open config file: " + config_path; + return false; + } + + try { + config = std::make_unique(); + f >> *config; + } catch (const std::exception & e) { + error_msg = std::string("Failed to parse config JSON: ") + e.what(); + return false; + } + + return true; +} + +bool hf_config::load_from_string(const std::string & json_str) { + try { + config = std::make_unique(json::parse(json_str)); + } catch (const std::exception & e) { + error_msg = std::string("Failed to parse config JSON: ") + e.what(); + return false; + } + + return true; +} + +std::string hf_config::get_architecture() const { + if (!config) { + return ""; + } + + // Check for architectures array (most common) + if (config->contains("architectures") && (*config)["architectures"].is_array()) { + const auto & archs = (*config)["architectures"]; + if (!archs.empty() && archs[0].is_string()) { + return archs[0].get(); + } + } + + // Check text_config (for multimodal models) + if (config->contains("text_config") && (*config)["text_config"].is_object()) { + const auto & text_config = (*config)["text_config"]; + if (text_config.contains("architectures") && text_config["architectures"].is_array()) { + const auto & archs = text_config["architectures"]; + if (!archs.empty() && archs[0].is_string()) { + return archs[0].get(); + } + } + } + + // Check for ssm_cfg (Mamba models) + if (config->contains("ssm_cfg") && (*config)["ssm_cfg"].is_object()) { + const auto & ssm_cfg = (*config)["ssm_cfg"]; + if (ssm_cfg.contains("layer") && ssm_cfg["layer"].is_string()) { + return ssm_cfg["layer"].get() + "ForCausalLM"; + } + } + + return ""; +} + +template +bool hf_config::get_value_with_fallback(const std::string & key, T & out) const { + if (!config) { + return false; + } + + // First try root level + if (config->contains(key)) { + try { + out = (*config)[key].get(); + return true; + } catch (const std::exception &) { + return false; + } + } + + // Try text_config (for multimodal models) + if (config->contains("text_config") && (*config)["text_config"].is_object()) { + const auto & text_config = (*config)["text_config"]; + if (text_config.contains(key)) { + try { + out = text_config[key].get(); + return true; + } catch (const std::exception &) { + return false; + } + } + } + + return false; +} + +bool hf_config::get_int(const std::string & key, int64_t & out) const { + return 
get_value_with_fallback(key, out); +} + +bool hf_config::get_float(const std::string & key, double & out) const { + return get_value_with_fallback(key, out); +} + +bool hf_config::get_string(const std::string & key, std::string & out) const { + return get_value_with_fallback(key, out); +} + +bool hf_config::get_bool(const std::string & key, bool & out) const { + return get_value_with_fallback(key, out); +} + +bool hf_config::has_key(const std::string & key) const { + if (!config) { + return false; + } + + if (config->contains(key)) { + return true; + } + + // Check text_config + if (config->contains("text_config") && (*config)["text_config"].is_object()) { + return (*config)["text_config"].contains(key); + } + + return false; +} + +const nlohmann::json * hf_config::get_json() const { + return config.get(); +} + +// Common configuration getters + +int64_t hf_config::get_hidden_size() const { + int64_t val = 0; + // Try multiple possible keys + if (get_int("hidden_size", val)) return val; + if (get_int("d_model", val)) return val; + if (get_int("n_embd", val)) return val; + return 0; +} + +int64_t hf_config::get_num_hidden_layers() const { + int64_t val = 0; + if (get_int("num_hidden_layers", val)) return val; + if (get_int("n_layers", val)) return val; + if (get_int("n_layer", val)) return val; + if (get_int("num_layers", val)) return val; + return 0; +} + +int64_t hf_config::get_num_attention_heads() const { + int64_t val = 0; + if (get_int("num_attention_heads", val)) return val; + if (get_int("n_heads", val)) return val; + if (get_int("n_head", val)) return val; + return 0; +} + +int64_t hf_config::get_num_key_value_heads() const { + int64_t val = 0; + if (get_int("num_key_value_heads", val)) return val; + // If not specified, defaults to num_attention_heads (MHA) + return get_num_attention_heads(); +} + +int64_t hf_config::get_intermediate_size() const { + int64_t val = 0; + if (get_int("intermediate_size", val)) return val; + if (get_int("n_inner", val)) return val; + return 0; +} + +int64_t hf_config::get_vocab_size() const { + int64_t val = 0; + if (get_int("vocab_size", val)) return val; + if (get_int("padded_vocab_size", val)) return val; + return 0; +} + +int64_t hf_config::get_max_position_embeddings() const { + int64_t val = 0; + if (get_int("max_position_embeddings", val)) return val; + if (get_int("n_positions", val)) return val; + if (get_int("n_ctx", val)) return val; + return 0; +} + +double hf_config::get_rms_norm_eps() const { + double val = 0; + if (get_float("rms_norm_eps", val)) return val; + if (get_float("layer_norm_eps", val)) return val; + if (get_float("layer_norm_epsilon", val)) return val; + return 1e-5; // common default +} + +std::string hf_config::get_rope_scaling_type() const { + if (!config) { + return ""; + } + + // Check for rope_scaling object + if (config->contains("rope_scaling") && (*config)["rope_scaling"].is_object()) { + const auto & rope_scaling = (*config)["rope_scaling"]; + if (rope_scaling.contains("type") && rope_scaling["type"].is_string()) { + return rope_scaling["type"].get(); + } + } + + return ""; +} diff --git a/src/llama-hf-config.h b/src/llama-hf-config.h new file mode 100644 index 00000000000..cacb3f602bd --- /dev/null +++ b/src/llama-hf-config.h @@ -0,0 +1,64 @@ +#pragma once + +#include +#include +#include +#include + +#include "../vendor/nlohmann/json.hpp" + +// HuggingFace model configuration +class hf_config { +public: + hf_config() = default; + ~hf_config() = default; + + // Load config from file + bool load_from_file(const 
std::string & config_path); + + // Load config from JSON string + bool load_from_string(const std::string & json_str); + + // Get architecture name (e.g., "LlamaForCausalLM", "MistralForCausalLM") + std::string get_architecture() const; + + // Get a configuration value as integer + bool get_int(const std::string & key, int64_t & out) const; + + // Get a configuration value as float + bool get_float(const std::string & key, double & out) const; + + // Get a configuration value as string + bool get_string(const std::string & key, std::string & out) const; + + // Get a configuration value as bool + bool get_bool(const std::string & key, bool & out) const; + + // Check if a key exists + bool has_key(const std::string & key) const; + + // Get raw JSON object (for advanced users) + const nlohmann::json * get_json() const; + + // Get last error message + const std::string & get_error() const { return error_msg; } + + // Common configuration getters + int64_t get_hidden_size() const; + int64_t get_num_hidden_layers() const; + int64_t get_num_attention_heads() const; + int64_t get_num_key_value_heads() const; + int64_t get_intermediate_size() const; + int64_t get_vocab_size() const; + int64_t get_max_position_embeddings() const; + double get_rms_norm_eps() const; + std::string get_rope_scaling_type() const; + +private: + std::unique_ptr config; + std::string error_msg; + + // Helper to get value, checking nested configs (text_config, vision_config) + template + bool get_value_with_fallback(const std::string & key, T & out) const; +}; diff --git a/src/llama-model-from-safetensors.cpp b/src/llama-model-from-safetensors.cpp new file mode 100644 index 00000000000..2c951245fd8 --- /dev/null +++ b/src/llama-model-from-safetensors.cpp @@ -0,0 +1,1234 @@ +#include "llama-model-from-safetensors.h" + +#include "llama-impl.h" +#include "llama-model.h" +#include "llama-hparams.h" +#include "llama-safetensors-types.h" + +#include "../vendor/nlohmann/json.hpp" + +#include +#include +#include + +// Helper function to apply head permutation for HuggingFace attention weights +// This reverses the HF permutation: reshape(n_head, 2, dim/(n_head*2), *) -> swap(1,2) -> reshape +static void apply_head_permutation( + std::vector & data, + size_t elem_size, + size_t out_dim, + size_t in_dim, + int n_head +) { + // Verify dimensions are compatible + if (out_dim % (n_head * 2) != 0) { + LLAMA_LOG_WARN("%s: out_dim %zu not divisible by n_head*2 (%d), skipping permutation\n", + __func__, out_dim, n_head * 2); + return; + } + + size_t head_dim = out_dim / n_head; // Dimension per head + size_t half_head = head_dim / 2; // Half of head dimension + + std::vector permuted(data.size()); + + // Apply permutation: swap pairs within each head + // Original layout: [n_head, 2, half_head, in_dim] + // Permuted layout: [n_head, half_head, 2, in_dim] + for (size_t h = 0; h < (size_t)n_head; h++) { // For each head + for (size_t hh = 0; hh < half_head; hh++) { // For each half-head element + for (size_t pair = 0; pair < 2; pair++) { // For the pair (0 or 1) + for (size_t i = 0; i < in_dim; i++) { // For each input dimension + // Source: [h, pair, hh, i] in original [n_head, 2, half_head, in_dim] + size_t src_idx = ((h * 2 + pair) * half_head + hh) * in_dim + i; + // Dest: [h, hh, pair, i] in permuted [n_head, half_head, 2, in_dim] + size_t dst_idx = ((h * half_head + hh) * 2 + pair) * in_dim + i; + + memcpy(permuted.data() + dst_idx * elem_size, + data.data() + src_idx * elem_size, + elem_size); + } + } + } + } + + data = 
std::move(permuted); +} + +// Main entry point +llama_model * llama_model_load_from_safetensors( + const char * model_path, + const llama_model_params & params +) { + if (!model_path) { + LLAMA_LOG_ERROR("%s: model_path is null\n", __func__); + return nullptr; + } + + // Determine if path is directory or file + std::string path_str(model_path); + std::filesystem::path path(path_str); + + std::string model_dir; + if (std::filesystem::is_directory(path)) { + model_dir = path_str; + } else if (std::filesystem::is_regular_file(path)) { + model_dir = path.parent_path().string(); + } else { + LLAMA_LOG_ERROR("%s: invalid path: %s\n", __func__, model_path); + return nullptr; + } + + // Create builder and build model + safetensors_model_builder builder(model_dir, params); + llama_model * model = builder.build(); + + if (!model) { + LLAMA_LOG_ERROR("%s: failed to load model: %s\n", __func__, builder.get_error().c_str()); + } + + return model; +} + +// Implementation +safetensors_model_builder::safetensors_model_builder( + const std::string & model_dir, + const llama_model_params & params +) : model_dir(model_dir), params(params) { +} + +safetensors_model_builder::~safetensors_model_builder() { + // Clean up backend buffers from map + for (auto & pair : buffer_map) { + if (pair.second) { + ggml_backend_buffer_free(pair.second); + } + } + buffer_map.clear(); + + // Clean up legacy backend buffer if allocated + if (backend_buffer) { + ggml_backend_buffer_free(backend_buffer); + backend_buffer = nullptr; + } + + // Clean up GGML contexts from map + for (auto & pair : ctx_map) { + if (pair.second) { + ggml_free(pair.second); + } + } + ctx_map.clear(); + + // Clean up legacy GGML contexts + if (ctx_meta) { + ggml_free(ctx_meta); + ctx_meta = nullptr; + } + + if (ctx_data) { + ggml_free(ctx_data); + ctx_data = nullptr; + } +} + +llama_model * safetensors_model_builder::build() { + LLAMA_LOG_INFO("%s: loading model from safetensors: %s\n", __func__, model_dir.c_str()); + + // Step 1: Load config.json + if (!load_config()) { + return nullptr; + } + + // Step 2: Load safetensors files + if (!load_safetensors_files()) { + return nullptr; + } + + // Step 3: Detect architecture + if (!detect_architecture()) { + return nullptr; + } + + // Step 4: Create model structure + if (!create_model_structure()) { + return nullptr; + } + + // Step 4.5: Initialize backend devices + if (!init_devices()) { + return nullptr; + } + + // Step 5: Allocate tensors + if (!allocate_tensors()) { + return nullptr; + } + + // Step 6: Load tensor data + if (!load_tensor_data()) { + return nullptr; + } + + // Step 7: Link tensors to model structure + if (!link_tensors_to_model()) { + return nullptr; + } + + // Step 8: Register buffers with model (transfer ownership) + if (!register_buffers_with_model()) { + return nullptr; + } + + // Step 9: Initialize vocabulary + if (!init_vocabulary()) { + return nullptr; + } + + // Step 10: Finalize + if (!finalize_model()) { + return nullptr; + } + + LLAMA_LOG_INFO("%s: model loaded successfully\n", __func__); + return model; +} + +bool safetensors_model_builder::load_config() { + std::string config_path = model_dir + "/config.json"; + + config = std::make_unique(); + if (!config->load_from_file(config_path)) { + error_msg = "Failed to load config.json: " + config->get_error(); + return false; + } + + LLAMA_LOG_INFO("%s: loaded config.json\n", __func__); + return true; +} + +bool safetensors_model_builder::load_safetensors_files() { + st_loader = std::make_unique(); + + // Try single file first + 
std::string single_file = model_dir + "/model.safetensors"; + if (std::filesystem::exists(single_file)) { + if (st_loader->load_single(single_file)) { + LLAMA_LOG_INFO("%s: loaded single safetensors file\n", __func__); + return true; + } + } + + // Try sharded model + std::string index_file = model_dir + "/model.safetensors.index.json"; + if (std::filesystem::exists(index_file)) { + if (st_loader->load_sharded(index_file, model_dir)) { + LLAMA_LOG_INFO("%s: loaded sharded safetensors files\n", __func__); + return true; + } + } + + error_msg = "No safetensors files found in: " + model_dir; + return false; +} + +bool safetensors_model_builder::detect_architecture() { + std::string hf_arch = config->get_architecture(); + if (hf_arch.empty()) { + error_msg = "Could not detect architecture from config.json"; + return false; + } + + mapper = create_tensor_mapper(hf_arch); + if (!mapper) { + error_msg = "Unsupported architecture: " + hf_arch; + return false; + } + + LLAMA_LOG_INFO("%s: detected architecture: %s\n", __func__, hf_arch.c_str()); + return true; +} + +bool safetensors_model_builder::create_model_structure() { + // Step 1: Allocate llama_model + model = new llama_model(params); + if (!model) { + error_msg = "Failed to allocate llama_model"; + return false; + } + + // Step 2: Set architecture + model->arch = mapper->get_arch(); + if (model->arch == LLM_ARCH_UNKNOWN) { + error_msg = "Unknown architecture"; + delete model; + model = nullptr; + return false; + } + + // Step 3: Initialize hparams from HF config + // Get basic hyperparameters + model->hparams.n_embd = config->get_hidden_size(); + model->hparams.n_layer = config->get_num_hidden_layers(); + + // Get context length + int64_t max_pos = config->get_max_position_embeddings(); + model->hparams.n_ctx_train = max_pos > 0 ? max_pos : 2048; + + // Get attention parameters + uint32_t n_head = config->get_num_attention_heads(); + int64_t n_head_kv_val = config->get_num_key_value_heads(); + uint32_t n_head_kv = (n_head_kv_val > 0) ? 
n_head_kv_val : n_head; // Default to n_head for MHA + + // Fill per-layer arrays with same values (uniform layers) + std::fill(model->hparams.n_head_arr.begin(), model->hparams.n_head_arr.end(), n_head); + std::fill(model->hparams.n_head_kv_arr.begin(), model->hparams.n_head_kv_arr.end(), n_head_kv); + + // Get feed-forward dimension + int64_t n_ff_val = config->get_intermediate_size(); + if (n_ff_val > 0) { + std::fill(model->hparams.n_ff_arr.begin(), model->hparams.n_ff_arr.end(), static_cast(n_ff_val)); + } + + // Calculate head dimensions + if (n_head > 0) { + model->hparams.n_embd_head_k = model->hparams.n_embd / n_head; + model->hparams.n_embd_head_v = model->hparams.n_embd / n_head; + model->hparams.n_rot = model->hparams.n_embd_head_k; // Full rotary + } + + // Get normalization epsilon + double norm_eps = config->get_rms_norm_eps(); + if (norm_eps > 0.0) { + model->hparams.f_norm_rms_eps = static_cast(norm_eps); + } else { + // Try layer_norm_eps as fallback + double layer_norm_eps; + if (config->get_float("layer_norm_eps", layer_norm_eps)) { + model->hparams.f_norm_rms_eps = static_cast(layer_norm_eps); + } else { + model->hparams.f_norm_rms_eps = 1e-5f; // Default + } + } + model->hparams.f_norm_eps = model->hparams.f_norm_rms_eps; + + // Get RoPE parameters + double rope_theta; + if (config->get_float("rope_theta", rope_theta)) { + model->hparams.rope_freq_base_train = static_cast(rope_theta); + } else { + model->hparams.rope_freq_base_train = 10000.0f; // Default + } + + // Check for RoPE scaling + if (config->has_key("rope_scaling")) { + // TODO: Parse rope_scaling dict if present + model->hparams.rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_LINEAR; + } else { + model->hparams.rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; + } + + // Default rope parameters + model->hparams.rope_freq_scale_train = 1.0f; + model->hparams.n_ctx_orig_yarn = model->hparams.n_ctx_train; + + // Set rope type based on architecture + model->hparams.rope_type = llama_model_rope_type(model); + + // Initialize SWA (Sliding Window Attention) layers array - default to no SWA + std::fill(model->hparams.swa_layers.begin(), model->hparams.swa_layers.end(), false); + + // Initialize recurrent layer array - default to no recurrent layers + std::fill(model->hparams.recurrent_layer_arr.begin(), model->hparams.recurrent_layer_arr.end(), false); + + // Step 4: Determine model type based on architecture and size + model->type = LLM_TYPE_UNKNOWN; + + switch (model->arch) { + case LLM_ARCH_LLAMA: + // SmolLM2-135M has 30 layers, which maps to 256M type + switch (model->hparams.n_layer) { + case 30: model->type = LLM_TYPE_256M; break; // SmolLM2-135M + case 16: model->type = LLM_TYPE_1B; break; + case 22: model->type = LLM_TYPE_1B; break; + case 26: model->type = LLM_TYPE_3B; break; + case 28: model->type = LLM_TYPE_3B; break; + case 32: model->type = LLM_TYPE_7B; break; + case 40: model->type = LLM_TYPE_13B; break; + case 48: model->type = LLM_TYPE_34B; break; + case 60: model->type = LLM_TYPE_30B; break; + case 80: model->type = LLM_TYPE_70B; break; + default: model->type = LLM_TYPE_UNKNOWN; + } + break; + + case LLM_ARCH_PHI3: + switch (model->hparams.n_layer) { + case 24: model->type = LLM_TYPE_1_3B; break; + case 32: model->type = LLM_TYPE_3B; break; + case 40: model->type = LLM_TYPE_14B; break; + default: model->type = LLM_TYPE_UNKNOWN; + } + break; + + case LLM_ARCH_QWEN2: + switch (model->hparams.n_layer) { + case 24: model->type = LLM_TYPE_0_5B; break; + case 28: model->type = LLM_TYPE_1_5B; 
break; + case 32: model->type = LLM_TYPE_7B; break; + case 40: model->type = LLM_TYPE_13B; break; + case 80: model->type = LLM_TYPE_70B; break; + default: model->type = LLM_TYPE_UNKNOWN; + } + break; + + case LLM_ARCH_GEMMA: + case LLM_ARCH_GEMMA2: + switch (model->hparams.n_layer) { + case 18: model->type = LLM_TYPE_2B; break; + case 26: model->type = LLM_TYPE_7B; break; + case 42: model->type = LLM_TYPE_9B; break; + case 46: model->type = LLM_TYPE_27B; break; + default: model->type = LLM_TYPE_UNKNOWN; + } + break; + + default: + model->type = LLM_TYPE_UNKNOWN; + } + + // Step 5: Allocate layers vector + model->layers.resize(model->hparams.n_layer); + + // Set model name from config + std::string model_name; + if (config->get_string("_name_or_path", model_name)) { + model->name = model_name; + } else { + model->name = "unknown"; + } + + LLAMA_LOG_INFO("%s: created model structure: arch=%s, layers=%d, type=%s\n", + __func__, + llm_arch_name(model->arch), + model->hparams.n_layer, + model->type_name().c_str()); + + return true; +} + +bool safetensors_model_builder::init_devices() { + LLAMA_LOG_INFO("%s: initializing backend devices\n", __func__); + + const int n_gpu_layers = params.n_gpu_layers; + + // Initialize GPU backends if requested + if (n_gpu_layers > 0) { + LLAMA_LOG_INFO("%s: GPU offloading enabled with %d layers\n", __func__, n_gpu_layers); + + // Get available GPU backends + size_t n_devices = ggml_backend_dev_count(); + for (size_t i = 0; i < n_devices; i++) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + enum ggml_backend_dev_type type = ggml_backend_dev_type(dev); + + // Add GPU/Metal backends to model->devices + if (type == GGML_BACKEND_DEVICE_TYPE_GPU) { + model->devices.push_back(dev); + LLAMA_LOG_INFO("%s: added GPU device: %s\n", __func__, ggml_backend_dev_name(dev)); + } + } + + if (model->devices.empty()) { + LLAMA_LOG_WARN("%s: no GPU backends found, falling back to CPU\n", __func__); + } + } else { + LLAMA_LOG_INFO("%s: GPU offloading disabled (n_gpu_layers=0)\n", __func__); + } + + // Initialize buffer type lists and layer device mappings + try { + model->init_layer_devices(); + } catch (const std::exception & e) { + error_msg = std::string("Failed to initialize layer devices: ") + e.what(); + LLAMA_LOG_ERROR("%s: %s\n", __func__, error_msg.c_str()); + return false; + } + + return true; +} + +// Helper function to parse layer number from tensor name +// Returns -1 for non-layer tensors (embeddings, output, etc.) +static int parse_layer_number(const std::string & name) { + // Look for pattern like "blk.5." or "layers.5." + size_t pos = name.find("blk."); + if (pos == std::string::npos) { + pos = name.find("layers."); + } + + if (pos != std::string::npos) { + size_t start = pos + (name[pos] == 'b' ? 4 : 7); // Skip "blk." or "layers." + size_t end = name.find('.', start); + if (end != std::string::npos) { + std::string layer_str = name.substr(start, end - start); + try { + return std::stoi(layer_str); + } catch (...) 
{ + return -1; + } + } + } + return -1; +} + +bool safetensors_model_builder::allocate_tensors() { + // Step 1: Get list of all tensors from safetensors + std::vector tensor_names = st_loader->get_tensor_names(); + + if (tensor_names.empty()) { + error_msg = "No tensors found in safetensors files"; + return false; + } + + LLAMA_LOG_INFO("%s: found %zu tensors in safetensors\n", __func__, tensor_names.size()); + + // Step 2: Create GGML contexts for tensor metadata (one per buffer type) + // This follows the pattern from GGUF loader in llama-model.cpp + + size_t ctx_size = tensor_names.size() * ggml_tensor_overhead(); + + // Helper lambda to get or create context for a given buffer type + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + struct ggml_init_params params = { + /*.mem_size =*/ ctx_size, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, // Don't allocate data yet, just metadata + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + error_msg = "Failed to initialize GGML context for buffer type"; + return nullptr; + } + + ctx_map.emplace(buft, ctx); + LLAMA_LOG_DEBUG("%s: created GGML context for buffer type: %s\n", + __func__, ggml_backend_buft_name(buft)); + return ctx; + } + return it->second; + }; + + // Helper lambda to determine buffer type for a tensor based on its name + auto get_tensor_buft = [&](const std::string & name) -> ggml_backend_buffer_type_t { + // Parse layer number from tensor name + int layer_idx = parse_layer_number(name); + + // Input layer tensors (token_embd) + if (name.find("token_embd") != std::string::npos) { + return model->get_layer_buft(-1); // -1 = input layer + } + + // Output norm and output tensors + if (name.find("output_norm") != std::string::npos || name == "output.weight") { + return model->get_layer_buft(-2); // -2 = output layer + } + + // Layer tensors - use layer assignment + if (layer_idx >= 0) { + return model->get_layer_buft(layer_idx); + } + + // Default to CPU for other tensors + ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (cpu_dev) { + return ggml_backend_dev_buffer_type(cpu_dev); + } + return nullptr; + }; + + // Step 3: Create tensor metadata for each safetensors tensor in appropriate context + int tensors_created = 0; + std::map tensor_counts; + + for (const std::string & hf_name : tensor_names) { + // Get tensor info from safetensors + const safetensors_tensor_info * info = st_loader->get_tensor_info(hf_name); + if (!info) { + LLAMA_LOG_WARN("%s: could not find tensor info for %s, skipping\n", __func__, hf_name.c_str()); + continue; + } + + // Map HuggingFace tensor name to llama.cpp internal name + std::string internal_name = mapper->map_tensor_name(hf_name); + if (internal_name.empty()) { + LLAMA_LOG_DEBUG("%s: no mapping for tensor %s, skipping\n", __func__, hf_name.c_str()); + continue; + } + + // Convert safetensors dtype to GGML type + ggml_type ggml_type = safetensors_dtype_to_ggml_type(info->dtype); + if (ggml_type == GGML_TYPE_COUNT) { + LLAMA_LOG_WARN("%s: unsupported dtype for tensor %s, skipping\n", __func__, hf_name.c_str()); + continue; + } + + // Determine which buffer type (CPU or GPU) this tensor should use + ggml_backend_buffer_type_t buft = get_tensor_buft(internal_name); + if (!buft) { + error_msg = "Failed to determine buffer type for tensor: " + internal_name; + return false; + } + + // Get or create context for this buffer type + ggml_context * ctx = 
ctx_for_buft(buft); + if (!ctx) { + return false; + } + + // Create tensor in the appropriate context + struct ggml_tensor * tensor = nullptr; + + switch (info->shape.size()) { + case 1: + tensor = ggml_new_tensor_1d(ctx, ggml_type, info->shape[0]); + break; + case 2: + // GGML expects dimensions REVERSED from PyTorch/HuggingFace + tensor = ggml_new_tensor_2d(ctx, ggml_type, info->shape[1], info->shape[0]); + break; + case 3: + tensor = ggml_new_tensor_3d(ctx, ggml_type, info->shape[2], info->shape[1], info->shape[0]); + break; + case 4: + tensor = ggml_new_tensor_4d(ctx, ggml_type, info->shape[3], info->shape[2], info->shape[1], info->shape[0]); + break; + default: + LLAMA_LOG_WARN("%s: tensor %s has unsupported number of dimensions: %zu\n", + __func__, hf_name.c_str(), info->shape.size()); + continue; + } + + if (!tensor) { + error_msg = "Failed to create tensor: " + internal_name; + return false; + } + + // Set tensor name + ggml_set_name(tensor, internal_name.c_str()); + + tensors_created++; + tensor_counts[buft]++; + + // Debug log for key tensors + if (tensors_created <= 10 || tensors_created % 100 == 0) { + int layer_idx = parse_layer_number(internal_name); + LLAMA_LOG_DEBUG("%s: Created %s (layer %d): ne[0]=%" PRId64 ", ne[1]=%" PRId64 " in context for %s\n", + __func__, internal_name.c_str(), layer_idx, + tensor->ne[0], tensor->ne[1], + ggml_backend_buft_name(buft)); + } + + if (tensors_created % 100 == 0) { + LLAMA_LOG_INFO("%s: created %d tensor metadata entries...\n", __func__, tensors_created); + } + } + + LLAMA_LOG_INFO("%s: created %d tensor metadata entries total\n", __func__, tensors_created); + + // Step 4: Allocate backend buffers for each context + LLAMA_LOG_INFO("%s: allocating backend buffers for %zu buffer types\n", __func__, ctx_map.size()); + + for (auto & pair : ctx_map) { + ggml_backend_buffer_type_t buft = pair.first; + ggml_context * ctx = pair.second; + + int count = tensor_counts[buft]; + LLAMA_LOG_INFO("%s: allocating buffer for %s (%d tensors)\n", + __func__, ggml_backend_buft_name(buft), count); + + // Allocate backend buffer for all tensors in this context + ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buffer) { + error_msg = std::string("Failed to allocate backend buffer for ") + ggml_backend_buft_name(buft); + return false; + } + + size_t buffer_size = ggml_backend_buffer_get_size(buffer); + LLAMA_LOG_INFO("%s: allocated %zu bytes for %s\n", + __func__, buffer_size, ggml_backend_buft_name(buft)); + + // Mark buffer as containing weights for scheduler optimization + ggml_backend_buffer_set_usage(buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + + // Store buffer for use during loading (ownership will be transferred later) + buffer_map.emplace(buft, buffer); + } + + LLAMA_LOG_INFO("%s: tensor allocation complete - %d tensors ready\n", __func__, tensors_created); + return true; +} + +bool safetensors_model_builder::load_tensor_data() { + if (ctx_map.empty()) { + error_msg = "Cannot load tensor data: no contexts initialized"; + return false; + } + + if (buffer_map.empty()) { + error_msg = "Cannot load tensor data: no backend buffers allocated"; + return false; + } + + LLAMA_LOG_INFO("%s: loading tensor data from safetensors\n", __func__); + + int tensors_loaded = 0; + int tensors_skipped = 0; + int tensors_failed = 0; + + // Get all safetensors tensor names + std::vector st_tensor_names = st_loader->get_tensor_names(); + + for (const std::string & hf_name : st_tensor_names) { + // Map HF name to internal name + 
std::string internal_name = mapper->map_tensor_name(hf_name);
+ if (internal_name.empty()) {
+ // This tensor doesn't map to anything (might be optional)
+ LLAMA_LOG_DEBUG("%s: no mapping for HF tensor %s, skipping\n", __func__, hf_name.c_str());
+ tensors_skipped++;
+ continue;
+ }
+
+ // Find the tensor across all GGML contexts
+ struct ggml_tensor * tensor = nullptr;
+ for (auto & pair : ctx_map) {
+ tensor = ggml_get_tensor(pair.second, internal_name.c_str());
+ if (tensor) {
+ break;
+ }
+ }
+
+ if (!tensor) {
+ LLAMA_LOG_WARN("%s: tensor %s (HF: %s) not found in any GGML context\n",
+ __func__, internal_name.c_str(), hf_name.c_str());
+ tensors_skipped++;
+ continue;
+ }
+
+ // Verify tensor has allocated data
+ if (!tensor->data) {
+ LLAMA_LOG_ERROR("%s: tensor %s has no data buffer allocated\n", __func__, internal_name.c_str());
+ tensors_failed++;
+ continue;
+ }
+
+ // Get tensor info from safetensors
+ const safetensors_tensor_info * info = st_loader->get_tensor_info(hf_name);
+ if (!info) {
+ LLAMA_LOG_ERROR("%s: could not get info for tensor %s\n", __func__, hf_name.c_str());
+ tensors_failed++;
+ continue;
+ }
+
+ // Don't transpose - match Python converter behavior which only reverses dimensions
+ bool needs_transpose = false;
+
+ // Read data from safetensors into temporary buffer
+ size_t st_data_size = info->size();
+ std::vector<char> temp_buffer(st_data_size);
+
+ if (!st_loader->read_tensor_data(hf_name, temp_buffer.data(), st_data_size)) {
+ LLAMA_LOG_ERROR("%s: failed to read tensor data for %s\n", __func__, hf_name.c_str());
+ tensors_failed++;
+ continue;
+ }
+
+ // Convert types and copy to GGML tensor
+ size_t ggml_data_size = ggml_nbytes(tensor);
+ ggml_type tensor_type = tensor->type;
+
+ // If transposition is needed, we need to create a temporary buffer with transposed data
+ std::vector<char> transposed_buffer;
+ std::vector transposed_shape;
+ const char * source_data = temp_buffer.data();
+ size_t source_size = st_data_size;
+ const int64_t * shape_ptr = reinterpret_cast<const int64_t *>(info->shape.data());
+ size_t shape_size = info->shape.size();
+
+ // For 2D weight tensors: we need to physically transpose the data
+ // because we create the tensor with swapped dimensions [dim1, dim0]
+ // but the safetensors data is in [dim0, dim1] layout
+ if (needs_transpose && info->shape.size() == 2) {
+ size_t dim0 = info->shape[0];
+ size_t dim1 = info->shape[1];
+ size_t elem_size = st_data_size / (dim0 * dim1);
+
+ // Physically transpose the data: [dim0, dim1] -> [dim1, dim0]
+ transposed_buffer.resize(st_data_size);
+ const char * src = temp_buffer.data();
+ char * dst = transposed_buffer.data();
+
+ for (size_t row = 0; row < dim0; row++) {
+ for (size_t col = 0; col < dim1; col++) {
+ size_t src_idx = (row * dim1 + col) * elem_size;
+ size_t dst_idx = (col * dim0 + row) * elem_size;
+ memcpy(dst + dst_idx, src + src_idx, elem_size);
+ }
+ }
+
+ source_data = transposed_buffer.data();
+
+ LLAMA_LOG_DEBUG("%s: Transposed %s from [%zu, %zu] to [%zu, %zu]\n",
+ __func__, internal_name.c_str(), dim0, dim1, dim1, dim0);
+ }
+
+ // Apply head permutation for attention query/key weights
+ // This reverses HuggingFace's permutation that was applied during training
+ if ((internal_name.find("attn_q.weight") != std::string::npos ||
+ internal_name.find("attn_k.weight") != std::string::npos) &&
+ info->shape.size() == 2 && !needs_transpose) {
+
+ // Get n_head
from model hparams + int n_head = model->hparams.n_head(); + if (internal_name.find("attn_k.weight") != std::string::npos) { + // For key weights, use n_head_kv if available + n_head = model->hparams.n_head_kv(); + } + + if (n_head > 0) { + size_t out_dim = info->shape[0]; // Output dimension + size_t in_dim = info->shape[1]; // Input dimension + size_t elem_size = st_data_size / (out_dim * in_dim); + + LLAMA_LOG_DEBUG("%s: Applying head permutation to %s (n_head=%d)\n", + __func__, internal_name.c_str(), n_head); + + // Apply permutation in-place on temp_buffer + apply_head_permutation(temp_buffer, elem_size, out_dim, in_dim, n_head); + + // Update source_data pointer to use the permuted data + source_data = temp_buffer.data(); + } + } + + if (!convert_safetensors_to_ggml( + source_data, source_size, info->dtype, + tensor->data, ggml_data_size, tensor_type, + shape_ptr, shape_size)) { + LLAMA_LOG_ERROR("%s: failed to convert tensor data for %s\n", __func__, hf_name.c_str()); + tensors_failed++; + continue; + } + + // DEBUG: Log first few values of key tensors + if (internal_name == "token_embd.weight" || internal_name == "blk.0.attn_q.weight") { + const float * data_f32 = (const float *)tensor->data; + LLAMA_LOG_INFO("%s: [DEBUG] %s first 8 F32 values: %.6f %.6f %.6f %.6f %.6f %.6f %.6f %.6f\n", + __func__, internal_name.c_str(), + data_f32[0], data_f32[1], data_f32[2], data_f32[3], + data_f32[4], data_f32[5], data_f32[6], data_f32[7]); + } + + tensors_loaded++; + + if (tensors_loaded % 50 == 0) { + LLAMA_LOG_INFO("%s: loaded %d tensors...\n", __func__, tensors_loaded); + } + } + + LLAMA_LOG_INFO("%s: loaded %d tensors, skipped %d, failed %d\n", + __func__, tensors_loaded, tensors_skipped, tensors_failed); + + if (tensors_failed > 0) { + error_msg = "Some tensors failed to load"; + return false; + } + + if (tensors_loaded == 0) { + error_msg = "No tensors were loaded"; + return false; + } + + return true; +} + +bool safetensors_model_builder::link_tensors_to_model() { + if (!model) { + error_msg = "Cannot link tensors: model not created"; + return false; + } + + if (ctx_map.empty()) { + error_msg = "Cannot link tensors: no contexts initialized"; + return false; + } + + LLAMA_LOG_INFO("%s: linking tensors to model structure\n", __func__); + + // Helper lambda to get tensor (returns nullptr if not found, which is ok for optional tensors) + // Search across all GGML contexts (CPU, GPU, etc.) 
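+ // (each buffer type gets its own GGML context in allocate_tensors(), so the lookup
+ //  has to scan ctx_map rather than a single metadata context)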
+ auto get_tensor = [&](const char * name) -> ggml_tensor * { + ggml_tensor * tensor = nullptr; + + // Search all contexts for this tensor + for (auto & pair : ctx_map) { + tensor = ggml_get_tensor(pair.second, name); + if (tensor) { + break; + } + } + + if (tensor) { + // Add to tensors_by_name for n_tensors() to work correctly + model->tensors_by_name.emplace_back(name, tensor); + } + return tensor; + }; + + int tensors_linked = 0; + + // Link input embedding + model->tok_embd = get_tensor("token_embd.weight"); + if (model->tok_embd) { + tensors_linked++; + LLAMA_LOG_INFO("%s: linked token_embd: ne[0]=%" PRId64 ", ne[1]=%" PRId64 ", ne[2]=%" PRId64 ", ne[3]=%" PRId64 "\n", + __func__, model->tok_embd->ne[0], model->tok_embd->ne[1], model->tok_embd->ne[2], model->tok_embd->ne[3]); + } else { + LLAMA_LOG_WARN("%s: token_embd.weight not found\n", __func__); + } + + // Link output norm and output + model->output_norm = get_tensor("output_norm.weight"); + if (model->output_norm) { + tensors_linked++; + } + + model->output = get_tensor("output.weight"); + if (model->output) { + tensors_linked++; + } else { + // output might share with tok_embd + model->output = model->tok_embd; + LLAMA_LOG_DEBUG("%s: output shares with token_embd\n", __func__); + } + + // Link layer tensors based on architecture + switch (model->arch) { + case LLM_ARCH_LLAMA: + { + LLAMA_LOG_INFO("%s: linking Llama layer tensors\n", __func__); + + for (size_t i = 0; i < model->layers.size(); ++i) { + auto & layer = model->layers[i]; + char buf[256]; + + // Attention norm + snprintf(buf, sizeof(buf), "blk.%zu.attn_norm.weight", i); + layer.attn_norm = get_tensor(buf); + if (layer.attn_norm) tensors_linked++; + + // Attention Q, K, V, O + snprintf(buf, sizeof(buf), "blk.%zu.attn_q.weight", i); + layer.wq = get_tensor(buf); + if (layer.wq) tensors_linked++; + + snprintf(buf, sizeof(buf), "blk.%zu.attn_k.weight", i); + layer.wk = get_tensor(buf); + if (layer.wk) tensors_linked++; + + snprintf(buf, sizeof(buf), "blk.%zu.attn_v.weight", i); + layer.wv = get_tensor(buf); + if (layer.wv) tensors_linked++; + + snprintf(buf, sizeof(buf), "blk.%zu.attn_output.weight", i); + layer.wo = get_tensor(buf); + if (layer.wo) tensors_linked++; + + // FFN norm + snprintf(buf, sizeof(buf), "blk.%zu.ffn_norm.weight", i); + layer.ffn_norm = get_tensor(buf); + if (layer.ffn_norm) tensors_linked++; + + // FFN gate, down, up + snprintf(buf, sizeof(buf), "blk.%zu.ffn_gate.weight", i); + layer.ffn_gate = get_tensor(buf); + if (layer.ffn_gate) tensors_linked++; + + snprintf(buf, sizeof(buf), "blk.%zu.ffn_down.weight", i); + layer.ffn_down = get_tensor(buf); + if (layer.ffn_down) tensors_linked++; + + snprintf(buf, sizeof(buf), "blk.%zu.ffn_up.weight", i); + layer.ffn_up = get_tensor(buf); + if (layer.ffn_up) tensors_linked++; + + if (i % 10 == 0 && i > 0) { + LLAMA_LOG_INFO("%s: linked layer %zu/%zu\n", __func__, i, model->layers.size()); + } + } + + LLAMA_LOG_INFO("%s: linked all %zu layers\n", __func__, model->layers.size()); + } + break; + + case LLM_ARCH_PHI3: + case LLM_ARCH_QWEN2: + case LLM_ARCH_GEMMA: + case LLM_ARCH_GEMMA2: + { + // These architectures have similar structure to Llama + // For now, use the same linking pattern + LLAMA_LOG_WARN("%s: using Llama-style linking for %s - may need adjustments\n", + __func__, llm_arch_name(model->arch)); + + for (size_t i = 0; i < model->layers.size(); ++i) { + auto & layer = model->layers[i]; + char buf[256]; + + snprintf(buf, sizeof(buf), "blk.%zu.attn_norm.weight", i); + layer.attn_norm = 
get_tensor(buf); + if (layer.attn_norm) tensors_linked++; + + snprintf(buf, sizeof(buf), "blk.%zu.attn_q.weight", i); + layer.wq = get_tensor(buf); + if (layer.wq) tensors_linked++; + + snprintf(buf, sizeof(buf), "blk.%zu.attn_k.weight", i); + layer.wk = get_tensor(buf); + if (layer.wk) tensors_linked++; + + snprintf(buf, sizeof(buf), "blk.%zu.attn_v.weight", i); + layer.wv = get_tensor(buf); + if (layer.wv) tensors_linked++; + + snprintf(buf, sizeof(buf), "blk.%zu.attn_output.weight", i); + layer.wo = get_tensor(buf); + if (layer.wo) tensors_linked++; + + snprintf(buf, sizeof(buf), "blk.%zu.ffn_norm.weight", i); + layer.ffn_norm = get_tensor(buf); + if (layer.ffn_norm) tensors_linked++; + + snprintf(buf, sizeof(buf), "blk.%zu.ffn_gate.weight", i); + layer.ffn_gate = get_tensor(buf); + if (layer.ffn_gate) tensors_linked++; + + snprintf(buf, sizeof(buf), "blk.%zu.ffn_down.weight", i); + layer.ffn_down = get_tensor(buf); + if (layer.ffn_down) tensors_linked++; + + snprintf(buf, sizeof(buf), "blk.%zu.ffn_up.weight", i); + layer.ffn_up = get_tensor(buf); + if (layer.ffn_up) tensors_linked++; + } + } + break; + + default: + error_msg = "Tensor linking not implemented for this architecture"; + return false; + } + + LLAMA_LOG_INFO("%s: linked %d tensors to model structure\n", __func__, tensors_linked); + + if (tensors_linked == 0) { + error_msg = "No tensors were linked to model - tensor names may not match"; + return false; + } + + return true; +} + +bool safetensors_model_builder::register_buffers_with_model() { + LLAMA_LOG_INFO("%s: registering buffers with model\n", __func__); + + if (ctx_map.empty()) { + error_msg = "Cannot register buffers: no contexts allocated"; + return false; + } + + if (buffer_map.empty()) { + error_msg = "Cannot register buffers: no backend buffers allocated"; + return false; + } + + // Transfer ownership of contexts and buffers to model + // This follows the pattern from GGUF loader in llama-model.cpp line 6688-6703 + for (auto & it : ctx_map) { + ggml_backend_buffer_type_t buft = it.first; + ggml_context * ctx = it.second; + + // Get the buffer for this buffer type + auto buf_it = buffer_map.find(buft); + if (buf_it == buffer_map.end()) { + error_msg = std::string("No buffer found for buffer type: ") + ggml_backend_buft_name(buft); + return false; + } + + ggml_backend_buffer_t buf = buf_it->second; + + // Wrap buffer in unique_ptr and move to vector + std::vector bufs; + bufs.emplace_back(buf); + + // Add context and buffers to model (transfers ownership) + model->add_context_with_buffers(ctx, std::move(bufs)); + + LLAMA_LOG_DEBUG("%s: registered context and buffer for %s\n", + __func__, ggml_backend_buft_name(buft)); + } + + // Clear maps to prevent double-free in destructor + // The model now owns these resources + ctx_map.clear(); + buffer_map.clear(); + + LLAMA_LOG_INFO("%s: successfully registered all buffers with model\n", __func__); + return true; +} + +bool safetensors_model_builder::init_vocabulary() { + LLAMA_LOG_INFO("%s: initializing vocabulary\n", __func__); + + // Check if tokenizer.json exists + std::string tokenizer_path = model_dir + "/tokenizer.json"; + std::string tokenizer_config_path = model_dir + "/tokenizer_config.json"; + + bool has_tokenizer = std::filesystem::exists(tokenizer_path); + bool has_config = std::filesystem::exists(tokenizer_config_path); + + if (!has_tokenizer) { + LLAMA_LOG_ERROR("%s: tokenizer.json not found in %s\n", __func__, model_dir.c_str()); + error_msg = "tokenizer.json not found - cannot load vocabulary"; + return 
false; + } + + LLAMA_LOG_INFO("%s: found tokenizer.json\n", __func__); + if (has_config) { + LLAMA_LOG_INFO("%s: found tokenizer_config.json\n", __func__); + } + + // Load vocabulary from HuggingFace tokenizer format + bool success = model->vocab.load_from_hf_tokenizer( + tokenizer_path, + has_config ? tokenizer_config_path : "" + ); + + if (!success) { + error_msg = "Failed to load vocabulary from tokenizer.json"; + LLAMA_LOG_ERROR("%s: failed to load vocabulary\n", __func__); + return false; + } + + LLAMA_LOG_INFO("%s: vocabulary loaded successfully - %u tokens\n", + __func__, model->vocab.n_tokens()); + + return true; +} + +bool safetensors_model_builder::finalize_model() { + if (!model) { + error_msg = "Cannot finalize: model not created"; + return false; + } + + LLAMA_LOG_INFO("%s: finalizing model\n", __func__); + + // Validate that critical tensors are linked + bool has_tok_embd = (model->tok_embd != nullptr); + bool has_output = (model->output != nullptr); + bool has_output_norm = (model->output_norm != nullptr); + + if (!has_tok_embd) { + LLAMA_LOG_WARN("%s: token embedding tensor not linked\n", __func__); + } + + if (!has_output) { + LLAMA_LOG_WARN("%s: output tensor not linked\n", __func__); + } + + if (!has_output_norm) { + LLAMA_LOG_WARN("%s: output norm tensor not linked\n", __func__); + } + + // Validate layers have critical tensors + int layers_valid = 0; + for (size_t i = 0; i < model->layers.size(); ++i) { + const auto & layer = model->layers[i]; + bool layer_ok = (layer.attn_norm && layer.wq && layer.wk && layer.wv && layer.wo && + layer.ffn_norm && layer.ffn_gate && layer.ffn_down && layer.ffn_up); + if (layer_ok) { + layers_valid++; + } else { + LLAMA_LOG_WARN("%s: layer %zu missing some tensors\n", __func__, i); + } + } + + LLAMA_LOG_INFO("%s: validated %d/%zu layers\n", __func__, layers_valid, model->layers.size()); + + // Log final model info + LLAMA_LOG_INFO("%s: model finalized:\n", __func__); + LLAMA_LOG_INFO("%s: architecture: %s\n", __func__, llm_arch_name(model->arch)); + LLAMA_LOG_INFO("%s: type: %s\n", __func__, model->type_name().c_str()); + LLAMA_LOG_INFO("%s: layers: %zu\n", __func__, model->layers.size()); + LLAMA_LOG_INFO("%s: embedding dim: %d\n", __func__, model->hparams.n_embd); + LLAMA_LOG_INFO("%s: attention heads: %d\n", __func__, model->hparams.n_head()); + LLAMA_LOG_INFO("%s: context length: %d\n", __func__, model->hparams.n_ctx_train); + + // Set model stats (number of elements and bytes) + // These are used for various calculations in the backend + uint64_t n_elements = 0; + size_t n_bytes = 0; + std::vector tensor_names = st_loader->get_tensor_names(); + for (const auto & name : tensor_names) { + const safetensors_tensor_info * info = st_loader->get_tensor_info(name); + if (info) { + n_elements += info->n_elements(); + n_bytes += info->size(); + } + } + model->set_stats(n_elements, n_bytes); + LLAMA_LOG_INFO("%s: model stats: n_elements=%" PRIu64 ", n_bytes=%zu\n", __func__, n_elements, n_bytes); + + // Transfer ownership of context and buffer to the model + // This prevents them from being freed when safetensors_model_builder is destroyed + if (ctx_meta && backend_buffer) { + // Create unique pointer for buffer to manage lifetime + ggml_backend_buffer_ptr buf_ptr(backend_buffer); + + // Add to model's context/buffer list + std::vector bufs; + bufs.push_back(std::move(buf_ptr)); + model->add_context_with_buffers(ctx_meta, std::move(bufs)); + + // Release ownership from builder so destructor doesn't free them + ctx_meta = nullptr; + 
backend_buffer = nullptr; + + LLAMA_LOG_INFO("%s: transferred context and buffer ownership to model\n", __func__); + } + + // Load chat template from tokenizer_config.json if available + std::string tokenizer_config_path = model_dir + "/tokenizer_config.json"; + std::ifstream config_file(tokenizer_config_path); + if (config_file.is_open()) { + try { + nlohmann::json tokenizer_config = nlohmann::json::parse(config_file); + if (tokenizer_config.contains("chat_template") && tokenizer_config["chat_template"].is_string()) { + std::string chat_template = tokenizer_config["chat_template"]; + std::string key = LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE); + model->gguf_kv.emplace(key, chat_template); + LLAMA_LOG_INFO("%s: loaded chat template from tokenizer_config.json\n", __func__); + } + } catch (const std::exception & e) { + LLAMA_LOG_WARN("%s: failed to parse tokenizer_config.json for chat template: %s\n", __func__, e.what()); + } + } + + return true; +} diff --git a/src/llama-model-from-safetensors.h b/src/llama-model-from-safetensors.h new file mode 100644 index 00000000000..7a3b89e4bdb --- /dev/null +++ b/src/llama-model-from-safetensors.h @@ -0,0 +1,76 @@ +#pragma once + +#include "llama.h" +#include "llama-safetensors.h" +#include "llama-hf-config.h" +#include "llama-safetensors-loader.h" + +#include +#include +#include + +// Forward declarations +struct llama_model; +struct llama_model_params; +struct ggml_context; + +// Main entry point for loading a model from safetensors +// model_path can be either: +// - Directory containing model.safetensors + config.json +// - Path to a single .safetensors file (config.json must be in same dir) +llama_model * llama_model_load_from_safetensors( + const char * model_path, + const llama_model_params & params +); + +// Internal implementation class +class safetensors_model_builder { +public: + safetensors_model_builder( + const std::string & model_dir, + const llama_model_params & params + ); + + ~safetensors_model_builder(); + + // Main loading pipeline + llama_model * build(); + + // Get last error message + const std::string & get_error() const { return error_msg; } + +private: + std::string model_dir; + llama_model_params params; + std::string error_msg; + + // Components + std::unique_ptr config; + std::unique_ptr st_loader; + std::unique_ptr mapper; + + // Model being built + llama_model * model = nullptr; + + // GGML contexts and backend buffers (one per buffer type for GPU offloading) + struct ggml_context * ctx_meta = nullptr; // Legacy: now unused, kept for compatibility + struct ggml_context * ctx_data = nullptr; + struct ggml_backend_buffer * backend_buffer = nullptr; // Legacy: now unused + + // Multi-device support: map from buffer type to context and buffer + std::map ctx_map; + std::map buffer_map; + + // Pipeline steps + bool load_config(); + bool load_safetensors_files(); + bool detect_architecture(); + bool create_model_structure(); + bool init_devices(); + bool allocate_tensors(); + bool load_tensor_data(); + bool link_tensors_to_model(); + bool register_buffers_with_model(); + bool init_vocabulary(); + bool finalize_model(); +}; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c2a545531a9..7ba19c069d0 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -463,6 +463,106 @@ llama_model::llama_model(const llama_model_params & params) : params(params), pi llama_model::~llama_model() {} +void llama_model::init_layer_devices() { + // Initialize buffer type lists and layer device mappings + // This is called 
by the safetensors loader after devices have been set up + + ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (cpu_dev == nullptr) { + throw std::runtime_error("no CPU backend found"); + } + + // Build CPU buffer type list (simplified version - just use CPU buffer type) + // buft_list_t is std::vector> + pimpl->cpu_buft_list.clear(); + pimpl->cpu_buft_list.push_back({cpu_dev, ggml_backend_dev_buffer_type(cpu_dev)}); + + // Build GPU buffer type lists for each device + for (auto * dev : devices) { + auto buft_list = pimpl->cpu_buft_list; // Start with CPU as fallback + // Add device-specific buffer type at the front + auto * dev_buft = ggml_backend_dev_buffer_type(dev); + if (dev_buft) { + buft_list.insert(buft_list.begin(), {dev, dev_buft}); + } + pimpl->gpu_buft_list.emplace(dev, std::move(buft_list)); + } + + // Assign input layer to CPU (always keep input on CPU for better performance) + pimpl->dev_input = {cpu_dev, &pimpl->cpu_buft_list}; + + // Calculate which layers to offload to GPU + const int n_layer = hparams.n_layer; + const int n_gpu_layers = params.n_gpu_layers; + const int i_gpu_start = std::max((int)n_layer - n_gpu_layers, (int)0); + const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1); + + pimpl->dev_layer.resize(n_layer); + int n_layers_on_gpu = 0; + int n_layers_on_cpu = 0; + + for (int il = 0; il < n_layer; ++il) { + if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers || devices.empty()) { + // Assign to CPU + pimpl->dev_layer[il] = {cpu_dev, &pimpl->cpu_buft_list}; + n_layers_on_cpu++; + LLAMA_LOG_DEBUG("%s: layer %3d assigned to device %s\n", __func__, il, ggml_backend_dev_name(cpu_dev)); + } else { + // Assign to GPU (use first GPU device for now - could be extended for multi-GPU) + auto * dev = devices[0]; + pimpl->dev_layer[il] = {dev, &pimpl->gpu_buft_list.at(dev)}; + n_layers_on_gpu++; + LLAMA_LOG_DEBUG("%s: layer %3d assigned to device %s\n", __func__, il, ggml_backend_dev_name(dev)); + } + } + + // Assign output layer + if (n_layer < i_gpu_start || (n_layer - i_gpu_start) >= act_gpu_layers || devices.empty()) { + pimpl->dev_output = {cpu_dev, &pimpl->cpu_buft_list}; + LLAMA_LOG_DEBUG("%s: output layer assigned to device %s\n", __func__, ggml_backend_dev_name(cpu_dev)); + } else { + auto * dev = devices[0]; + pimpl->dev_output = {dev, &pimpl->gpu_buft_list.at(dev)}; + LLAMA_LOG_DEBUG("%s: output layer assigned to device %s\n", __func__, ggml_backend_dev_name(dev)); + } + + if (n_layers_on_gpu > 0) { + LLAMA_LOG_INFO("%s: assigned %d layers to GPU (%s), %d layers to CPU\n", + __func__, n_layers_on_gpu, ggml_backend_dev_name(devices[0]), n_layers_on_cpu); + } else { + LLAMA_LOG_INFO("%s: assigned %d layers to CPU\n", __func__, n_layers_on_cpu); + } +} + +ggml_backend_buffer_type_t llama_model::get_layer_buft(int layer_idx) { + // layer_idx: -1 for input/embedding layer, -2 for output layer, >=0 for transformer layers + + if (layer_idx == -1) { + // Input/embedding layer + if (pimpl->dev_input.buft_list && !pimpl->dev_input.buft_list->empty()) { + return pimpl->dev_input.buft_list->at(0).second; + } + } else if (layer_idx == -2) { + // Output layer + if (pimpl->dev_output.buft_list && !pimpl->dev_output.buft_list->empty()) { + return pimpl->dev_output.buft_list->at(0).second; + } + } else if (layer_idx >= 0 && layer_idx < (int)pimpl->dev_layer.size()) { + // Transformer layer + if (pimpl->dev_layer[layer_idx].buft_list && !pimpl->dev_layer[layer_idx].buft_list->empty()) { + return 
pimpl->dev_layer[layer_idx].buft_list->at(0).second; + } + } + + // Fallback to CPU + ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (cpu_dev) { + return ggml_backend_dev_buffer_type(cpu_dev); + } + + return nullptr; +} + void llama_model::load_stats(llama_model_loader & ml) { pimpl->n_elements = ml.n_elements; pimpl->n_bytes = ml.n_bytes; @@ -6647,6 +6747,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) { return true; } +void llama_model::add_context_with_buffers(ggml_context * ctx, std::vector buffers) { + ggml_context_ptr ctx_ptr(ctx); + pimpl->ctxs_bufs.push_back({std::move(ctx_ptr), std::move(buffers)}); +} + +void llama_model::set_stats(uint64_t n_elements, size_t n_bytes) { + pimpl->n_elements = n_elements; + pimpl->n_bytes = n_bytes; +} + std::string llama_model::arch_name() const { return llm_arch_name(arch); } diff --git a/src/llama-model.h b/src/llama-model.h index f8342cf2cb1..7effe397985 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -484,6 +484,19 @@ struct llama_model { void load_vocab (llama_model_loader & ml); bool load_tensors(llama_model_loader & ml); // returns false if cancelled by progress_callback + // Initialize layer device mappings (called by safetensors loader) + void init_layer_devices(); + + // Get buffer type for a tensor based on layer index (called by safetensors loader) + // layer_idx: -1 for input/embedding layer, -2 for output layer, >=0 for transformer layers + ggml_backend_buffer_type_t get_layer_buft(int layer_idx); + + // Add context and buffers to model (called by safetensors loader) + void add_context_with_buffers(ggml_context * ctx, std::vector buffers); + + // Set model stats (called by safetensors loader) + void set_stats(uint64_t n_elements, size_t n_bytes); + std::string arch_name() const; std::string type_name() const; diff --git a/src/llama-safetensors-loader.cpp b/src/llama-safetensors-loader.cpp new file mode 100644 index 00000000000..5bf52469bce --- /dev/null +++ b/src/llama-safetensors-loader.cpp @@ -0,0 +1,275 @@ +#include "llama-safetensors-loader.h" + +#include "llama.h" +#include "llama-impl.h" + +#include +#include +#include + +// Map HuggingFace architecture names to llama.cpp architectures +llm_arch hf_arch_to_llm_arch(const std::string & hf_arch) { + // Llama family + if (hf_arch == "LlamaForCausalLM" || + hf_arch == "LLaMAForCausalLM") { + return LLM_ARCH_LLAMA; + } + // Mistral (uses Llama architecture) + if (hf_arch == "MistralForCausalLM" || + hf_arch == "MixtralForCausalLM") { + return LLM_ARCH_LLAMA; + } + // Phi + if (hf_arch == "PhiForCausalLM" || + hf_arch == "Phi3ForCausalLM") { + return LLM_ARCH_PHI3; + } + // Qwen2 + if (hf_arch == "Qwen2ForCausalLM") { + return LLM_ARCH_QWEN2; + } + // Gemma + if (hf_arch == "GemmaForCausalLM" || + hf_arch == "Gemma2ForCausalLM") { + return LLM_ARCH_GEMMA; + } + + return LLM_ARCH_UNKNOWN; +} + +// Llama/Mistral tensor name mapper +std::string llama_tensor_mapper::map_tensor_name(const std::string & hf_name) const { + // HuggingFace Llama/Mistral tensor naming: + // model.embed_tokens.weight -> token_embd.weight + // model.layers.{N}.self_attn.q_proj.weight -> blk.{N}.attn_q.weight + // model.layers.{N}.self_attn.k_proj.weight -> blk.{N}.attn_k.weight + // model.layers.{N}.self_attn.v_proj.weight -> blk.{N}.attn_v.weight + // model.layers.{N}.self_attn.o_proj.weight -> blk.{N}.attn_output.weight + // model.layers.{N}.mlp.gate_proj.weight -> blk.{N}.ffn_gate.weight + // model.layers.{N}.mlp.up_proj.weight -> 
blk.{N}.ffn_up.weight + // model.layers.{N}.mlp.down_proj.weight -> blk.{N}.ffn_down.weight + // model.layers.{N}.input_layernorm.weight -> blk.{N}.attn_norm.weight + // model.layers.{N}.post_attention_layernorm.weight -> blk.{N}.ffn_norm.weight + // model.norm.weight -> output_norm.weight + // lm_head.weight -> output.weight + + if (hf_name == "model.embed_tokens.weight") { + return "token_embd.weight"; + } + + if (hf_name == "lm_head.weight") { + return "output.weight"; + } + + if (hf_name == "model.norm.weight") { + return "output_norm.weight"; + } + + // Handle layer-specific tensors + std::regex layer_regex(R"(model\.layers\.(\d+)\.(.+))"); + std::smatch match; + if (std::regex_match(hf_name, match, layer_regex)) { + std::string layer_idx = match[1].str(); + std::string rest = match[2].str(); + + std::string mapped_name = "blk." + layer_idx + "."; + + if (rest == "self_attn.q_proj.weight") { + mapped_name += "attn_q.weight"; + } else if (rest == "self_attn.k_proj.weight") { + mapped_name += "attn_k.weight"; + } else if (rest == "self_attn.v_proj.weight") { + mapped_name += "attn_v.weight"; + } else if (rest == "self_attn.o_proj.weight") { + mapped_name += "attn_output.weight"; + } else if (rest == "mlp.gate_proj.weight") { + mapped_name += "ffn_gate.weight"; + } else if (rest == "mlp.up_proj.weight") { + mapped_name += "ffn_up.weight"; + } else if (rest == "mlp.down_proj.weight") { + mapped_name += "ffn_down.weight"; + } else if (rest == "input_layernorm.weight") { + mapped_name += "attn_norm.weight"; + } else if (rest == "post_attention_layernorm.weight") { + mapped_name += "ffn_norm.weight"; + } else { + // Unknown tensor + return ""; + } + + return mapped_name; + } + + // Unknown tensor - skip it + return ""; +} + +std::vector llama_tensor_mapper::get_required_tensors(int n_layers) const { + std::vector required; + + required.push_back("model.embed_tokens.weight"); + required.push_back("model.norm.weight"); + required.push_back("lm_head.weight"); + + for (int i = 0; i < n_layers; i++) { + std::string prefix = "model.layers." 
+ std::to_string(i) + "."; + required.push_back(prefix + "self_attn.q_proj.weight"); + required.push_back(prefix + "self_attn.k_proj.weight"); + required.push_back(prefix + "self_attn.v_proj.weight"); + required.push_back(prefix + "self_attn.o_proj.weight"); + required.push_back(prefix + "mlp.gate_proj.weight"); + required.push_back(prefix + "mlp.up_proj.weight"); + required.push_back(prefix + "mlp.down_proj.weight"); + required.push_back(prefix + "input_layernorm.weight"); + required.push_back(prefix + "post_attention_layernorm.weight"); + } + + return required; +} + +// Phi tensor name mapper +std::string phi_tensor_mapper::map_tensor_name(const std::string & hf_name) const { + // Phi-3 uses similar structure to Llama with some differences + // TODO: Implement Phi-specific mappings + // For now, use Llama mappings as they're similar + llama_tensor_mapper llama_mapper; + return llama_mapper.map_tensor_name(hf_name); +} + +std::vector phi_tensor_mapper::get_required_tensors(int n_layers) const { + // TODO: Implement Phi-specific required tensors + llama_tensor_mapper llama_mapper; + return llama_mapper.get_required_tensors(n_layers); +} + +// Qwen2 tensor name mapper +std::string qwen2_tensor_mapper::map_tensor_name(const std::string & hf_name) const { + // Qwen2 uses similar structure to Llama + // TODO: Implement Qwen2-specific mappings + llama_tensor_mapper llama_mapper; + return llama_mapper.map_tensor_name(hf_name); +} + +std::vector qwen2_tensor_mapper::get_required_tensors(int n_layers) const { + // TODO: Implement Qwen2-specific required tensors + llama_tensor_mapper llama_mapper; + return llama_mapper.get_required_tensors(n_layers); +} + +// Gemma tensor name mapper +std::string gemma_tensor_mapper::map_tensor_name(const std::string & hf_name) const { + // Gemma uses similar structure to Llama with some differences + // TODO: Implement Gemma-specific mappings + llama_tensor_mapper llama_mapper; + return llama_mapper.map_tensor_name(hf_name); +} + +std::vector gemma_tensor_mapper::get_required_tensors(int n_layers) const { + // TODO: Implement Gemma-specific required tensors + llama_tensor_mapper llama_mapper; + return llama_mapper.get_required_tensors(n_layers); +} + +// Factory function +std::unique_ptr create_tensor_mapper(const std::string & hf_arch) { + llm_arch arch = hf_arch_to_llm_arch(hf_arch); + + switch (arch) { + case LLM_ARCH_LLAMA: + return std::make_unique(); + case LLM_ARCH_PHI3: + return std::make_unique(); + case LLM_ARCH_QWEN2: + return std::make_unique(); + case LLM_ARCH_GEMMA: + return std::make_unique(); + default: + return nullptr; + } +} + +// Main loader implementation +bool safetensors_model_loader::load_config(const std::string & model_dir) { + std::string config_path = model_dir + "/config.json"; + if (!config.load_from_file(config_path)) { + error_msg = "Failed to load config.json: " + config.get_error(); + return false; + } + return true; +} + +bool safetensors_model_loader::load_safetensors_files(const std::string & model_dir) { + st_loader = std::make_unique(); + + // Try loading single file first + std::string single_file = model_dir + "/model.safetensors"; + std::ifstream test_single(single_file); + if (test_single.good()) { + test_single.close(); + if (st_loader->load_single(single_file)) { + return true; + } + } + + // Try loading sharded model + std::string index_file = model_dir + "/model.safetensors.index.json"; + std::ifstream test_index(index_file); + if (test_index.good()) { + test_index.close(); + if (st_loader->load_sharded(index_file, 
model_dir)) { + return true; + } + } + + error_msg = "No safetensors files found in directory: " + model_dir; + return false; +} + +bool safetensors_model_loader::create_mapper() { + std::string arch = config.get_architecture(); + if (arch.empty()) { + error_msg = "Failed to detect architecture from config.json"; + return false; + } + + mapper = create_tensor_mapper(arch); + if (!mapper) { + error_msg = "Unsupported architecture: " + arch; + return false; + } + + return true; +} + +llama_model * safetensors_model_loader::load( + const std::string & model_dir, + const llama_model_params & params +) { + (void)params; // Unused for now - reserved for future use + + // Load config.json + if (!load_config(model_dir)) { + return nullptr; + } + + // Load safetensors files + if (!load_safetensors_files(model_dir)) { + return nullptr; + } + + // Create tensor mapper + if (!create_mapper()) { + return nullptr; + } + + // TODO: Actually load the model + // This requires: + // 1. Create llama_model structure + // 2. Allocate memory for tensors + // 3. Load tensor data from safetensors + // 4. Map tensor names and populate model structure + // 5. Initialize model parameters + + error_msg = "Safetensors loading not yet fully implemented - this is a work in progress"; + return nullptr; +} diff --git a/src/llama-safetensors-loader.h b/src/llama-safetensors-loader.h new file mode 100644 index 00000000000..3e92c37027a --- /dev/null +++ b/src/llama-safetensors-loader.h @@ -0,0 +1,96 @@ +#pragma once + +#include "llama-safetensors.h" +#include "llama-hf-config.h" +#include "llama-arch.h" + +#include +#include +#include +#include + +// Forward declarations +struct llama_model; +struct llama_model_params; + +// Maps HuggingFace architecture name to llama.cpp architecture +llm_arch hf_arch_to_llm_arch(const std::string & hf_arch_name); + +// Tensor name mapper - converts HF tensor names to llama.cpp internal names +class safetensors_tensor_mapper { +public: + virtual ~safetensors_tensor_mapper() = default; + + // Map HF tensor name to internal name + // Returns empty string if tensor should be skipped + virtual std::string map_tensor_name(const std::string & hf_name) const = 0; + + // Get the architecture this mapper handles + virtual llm_arch get_arch() const = 0; + + // Get expected tensor names (for validation) + virtual std::vector get_required_tensors(int n_layers) const = 0; +}; + +// Llama/Mistral architecture mapper +class llama_tensor_mapper : public safetensors_tensor_mapper { +public: + std::string map_tensor_name(const std::string & hf_name) const override; + llm_arch get_arch() const override { return LLM_ARCH_LLAMA; } + std::vector get_required_tensors(int n_layers) const override; +}; + +// Phi architecture mapper +class phi_tensor_mapper : public safetensors_tensor_mapper { +public: + std::string map_tensor_name(const std::string & hf_name) const override; + llm_arch get_arch() const override { return LLM_ARCH_PHI3; } + std::vector get_required_tensors(int n_layers) const override; +}; + +// Qwen2 architecture mapper +class qwen2_tensor_mapper : public safetensors_tensor_mapper { +public: + std::string map_tensor_name(const std::string & hf_name) const override; + llm_arch get_arch() const override { return LLM_ARCH_QWEN2; } + std::vector get_required_tensors(int n_layers) const override; +}; + +// Gemma architecture mapper +class gemma_tensor_mapper : public safetensors_tensor_mapper { +public: + std::string map_tensor_name(const std::string & hf_name) const override; + llm_arch get_arch() const 
override { return LLM_ARCH_GEMMA; } + std::vector get_required_tensors(int n_layers) const override; +}; + +// Factory function to create appropriate mapper +std::unique_ptr create_tensor_mapper(const std::string & hf_arch); + +// Main safetensors model loader +class safetensors_model_loader { +public: + safetensors_model_loader() = default; + ~safetensors_model_loader() = default; + + // Load model from safetensors file(s) + // Returns nullptr on error (check get_error()) + llama_model * load( + const std::string & model_dir, + const llama_model_params & params + ); + + // Get last error message + const std::string & get_error() const { return error_msg; } + +private: + std::string error_msg; + + hf_config config; + std::unique_ptr st_loader; + std::unique_ptr mapper; + + bool load_config(const std::string & model_dir); + bool load_safetensors_files(const std::string & model_dir); + bool create_mapper(); +}; diff --git a/src/llama-safetensors-types.cpp b/src/llama-safetensors-types.cpp new file mode 100644 index 00000000000..b7976c03e8b --- /dev/null +++ b/src/llama-safetensors-types.cpp @@ -0,0 +1,171 @@ +#include "llama-safetensors-types.h" + +#include "ggml.h" + +#include +#include + +ggml_type safetensors_dtype_to_ggml_type(safetensors_dtype dtype) { + switch (dtype) { + case safetensors_dtype::F32: return GGML_TYPE_F32; + case safetensors_dtype::F16: return GGML_TYPE_F32; // Convert to F32 for CPU compatibility + case safetensors_dtype::BF16: return GGML_TYPE_F32; // Convert to F32 for CPU compatibility + case safetensors_dtype::I32: return GGML_TYPE_I32; + case safetensors_dtype::I16: return GGML_TYPE_I16; + case safetensors_dtype::I8: return GGML_TYPE_I8; + case safetensors_dtype::U8: return GGML_TYPE_I8; // Map to I8, handle signedness + // Note: GGML doesn't have direct equivalents for all types + case safetensors_dtype::F64: return GGML_TYPE_F32; // Downcast to F32 + case safetensors_dtype::I64: return GGML_TYPE_I32; // Downcast to I32 + case safetensors_dtype::BOOL: return GGML_TYPE_I8; // Map to I8 + default: return GGML_TYPE_COUNT; // Invalid + } +} + +const char * ggml_type_name_safe(ggml_type type) { + if (type < GGML_TYPE_COUNT) { + return ggml_type_name(type); + } + return "INVALID"; +} + +size_t ggml_tensor_size(ggml_type type, const int64_t * shape, int n_dims) { + if (n_dims == 0 || !shape) { + return 0; + } + + int64_t n_elements = 1; + for (int i = 0; i < n_dims; i++) { + n_elements *= shape[i]; + } + + size_t type_size = ggml_type_size(type); + size_t row_size = ggml_row_size(type, shape[0]); + + // For quantized types, use row_size calculation + if (type >= GGML_TYPE_Q4_0 && type < GGML_TYPE_COUNT) { + if (n_dims == 1) { + return row_size; + } + // Calculate total size for multi-dimensional tensors + int64_t n_rows = n_elements / shape[0]; + return row_size * n_rows; + } + + // For standard types + return type_size * n_elements; +} + +bool convert_safetensors_to_ggml( + const void * src_data, + size_t src_size, + safetensors_dtype src_dtype, + void * dst_data, + size_t dst_size, + ggml_type dst_type, + const int64_t * shape, + int n_dims +) { + if (!src_data || !dst_data || !shape || n_dims == 0) { + return false; + } + + int64_t n_elements = 1; + for (int i = 0; i < n_dims; i++) { + n_elements *= shape[i]; + } + + // Direct copy for matching types + if ((src_dtype == safetensors_dtype::F32 && dst_type == GGML_TYPE_F32) || + (src_dtype == safetensors_dtype::F16 && dst_type == GGML_TYPE_F16) || + (src_dtype == safetensors_dtype::BF16 && dst_type == GGML_TYPE_BF16)) 
{ + + if (src_size != dst_size) { + return false; + } + std::memcpy(dst_data, src_data, src_size); + return true; + } + + // Type conversion required + // F64 -> F32 + if (src_dtype == safetensors_dtype::F64 && dst_type == GGML_TYPE_F32) { + const double * src = (const double *)src_data; + float * dst = (float *)dst_data; + for (int64_t i = 0; i < n_elements; i++) { + dst[i] = (float)src[i]; + } + return true; + } + + // F32 -> F16 + if (src_dtype == safetensors_dtype::F32 && dst_type == GGML_TYPE_F16) { + const float * src = (const float *)src_data; + ggml_fp16_t * dst = (ggml_fp16_t *)dst_data; + for (int64_t i = 0; i < n_elements; i++) { + dst[i] = ggml_fp32_to_fp16(src[i]); + } + return true; + } + + // F16 -> F32 + if (src_dtype == safetensors_dtype::F16 && dst_type == GGML_TYPE_F32) { + const ggml_fp16_t * src = (const ggml_fp16_t *)src_data; + float * dst = (float *)dst_data; + for (int64_t i = 0; i < n_elements; i++) { + dst[i] = ggml_fp16_to_fp32(src[i]); + } + return true; + } + + // BF16 -> F32 + if (src_dtype == safetensors_dtype::BF16 && dst_type == GGML_TYPE_F32) { + const uint16_t * src = (const uint16_t *)src_data; + float * dst = (float *)dst_data; + for (int64_t i = 0; i < n_elements; i++) { + // BF16 to F32: shift left 16 bits + uint32_t f32_bits = ((uint32_t)src[i]) << 16; + float f32_value; + memcpy(&f32_value, &f32_bits, sizeof(float)); + dst[i] = f32_value; + } + return true; + } + + // I64 -> I32 + if (src_dtype == safetensors_dtype::I64 && dst_type == GGML_TYPE_I32) { + const int64_t * src = (const int64_t *)src_data; + int32_t * dst = (int32_t *)dst_data; + for (int64_t i = 0; i < n_elements; i++) { + dst[i] = (int32_t)src[i]; + } + return true; + } + + // I32 -> I32, I16 -> I16, I8 -> I8 (direct copy) + if ((src_dtype == safetensors_dtype::I32 && dst_type == GGML_TYPE_I32) || + (src_dtype == safetensors_dtype::I16 && dst_type == GGML_TYPE_I16) || + (src_dtype == safetensors_dtype::I8 && dst_type == GGML_TYPE_I8) || + (src_dtype == safetensors_dtype::U8 && dst_type == GGML_TYPE_I8)) { + + size_t expected_size = n_elements * ggml_type_size(dst_type); + if (src_size < expected_size || dst_size < expected_size) { + return false; + } + std::memcpy(dst_data, src_data, expected_size); + return true; + } + + // BOOL -> I8 + if (src_dtype == safetensors_dtype::BOOL && dst_type == GGML_TYPE_I8) { + const uint8_t * src = (const uint8_t *)src_data; + int8_t * dst = (int8_t *)dst_data; + for (int64_t i = 0; i < n_elements; i++) { + dst[i] = src[i] ? 
1 : 0; + } + return true; + } + + // Unsupported conversion + return false; +} diff --git a/src/llama-safetensors-types.h b/src/llama-safetensors-types.h new file mode 100644 index 00000000000..58c64e2b631 --- /dev/null +++ b/src/llama-safetensors-types.h @@ -0,0 +1,29 @@ +#pragma once + +#include "llama-safetensors.h" +#include "ggml.h" + +#include + +// Convert safetensors dtype to GGML type +ggml_type safetensors_dtype_to_ggml_type(safetensors_dtype dtype); + +// Get GGML type name +const char * ggml_type_name_safe(ggml_type type); + +// Convert safetensors tensor data to GGML format +// dst_data must be pre-allocated with enough space +// Returns true on success +bool convert_safetensors_to_ggml( + const void * src_data, + size_t src_size, + safetensors_dtype src_dtype, + void * dst_data, + size_t dst_size, + ggml_type dst_type, + const int64_t * shape, + int n_dims +); + +// Calculate tensor size in bytes for GGML type +size_t ggml_tensor_size(ggml_type type, const int64_t * shape, int n_dims); diff --git a/src/llama-safetensors.cpp b/src/llama-safetensors.cpp new file mode 100644 index 00000000000..5233a562c25 --- /dev/null +++ b/src/llama-safetensors.cpp @@ -0,0 +1,398 @@ +#include "llama-safetensors.h" + +#include +#include +#include +#include "../vendor/nlohmann/json.hpp" + +using json = nlohmann::json; + +// RAII file handle wrapper +struct file_handle { + FILE * file = nullptr; + + file_handle(const char * filename, const char * mode) { + file = fopen(filename, mode); + } + + ~file_handle() { + if (file) { + fclose(file); + } + } + + operator FILE*() { return file; } + operator bool() const { return file != nullptr; } +}; + +safetensors_dtype safetensors_dtype_from_string(const std::string & dtype_str) { + if (dtype_str == "F64") return safetensors_dtype::F64; + if (dtype_str == "F32") return safetensors_dtype::F32; + if (dtype_str == "F16") return safetensors_dtype::F16; + if (dtype_str == "BF16") return safetensors_dtype::BF16; + if (dtype_str == "I64") return safetensors_dtype::I64; + if (dtype_str == "I32") return safetensors_dtype::I32; + if (dtype_str == "I16") return safetensors_dtype::I16; + if (dtype_str == "I8") return safetensors_dtype::I8; + if (dtype_str == "U8") return safetensors_dtype::U8; + if (dtype_str == "BOOL") return safetensors_dtype::BOOL; + return safetensors_dtype::UNKNOWN; +} + +size_t safetensors_dtype_size(safetensors_dtype dtype) { + switch (dtype) { + case safetensors_dtype::F64: return 8; + case safetensors_dtype::F32: return 4; + case safetensors_dtype::F16: return 2; + case safetensors_dtype::BF16: return 2; + case safetensors_dtype::I64: return 8; + case safetensors_dtype::I32: return 4; + case safetensors_dtype::I16: return 2; + case safetensors_dtype::I8: return 1; + case safetensors_dtype::U8: return 1; + case safetensors_dtype::BOOL: return 1; + default: return 0; + } +} + +const char * safetensors_dtype_name(safetensors_dtype dtype) { + switch (dtype) { + case safetensors_dtype::F64: return "F64"; + case safetensors_dtype::F32: return "F32"; + case safetensors_dtype::F16: return "F16"; + case safetensors_dtype::BF16: return "BF16"; + case safetensors_dtype::I64: return "I64"; + case safetensors_dtype::I32: return "I32"; + case safetensors_dtype::I16: return "I16"; + case safetensors_dtype::I8: return "I8"; + case safetensors_dtype::U8: return "U8"; + case safetensors_dtype::BOOL: return "BOOL"; + default: return "UNKNOWN"; + } +} + +bool safetensors_file::open(const std::string & fname) { + close(); // close any existing file + + filename 
= fname; + file = fopen(filename.c_str(), "rb"); + if (!file) { + error_msg = "Failed to open file: " + filename; + return false; + } + + // Get file size + fseek(file, 0, SEEK_END); + file_size = ftell(file); + fseek(file, 0, SEEK_SET); + + if (file_size < 8) { + error_msg = "File too small to be a valid safetensors file (< 8 bytes)"; + close(); + return false; + } + + return parse_header(); +} + +void safetensors_file::close() { + if (file) { + fclose(file); + file = nullptr; + } + tensors.clear(); + metadata.reset(); + error_msg.clear(); + file_size = 0; + data_start_offset = 0; +} + +bool safetensors_file::parse_header() { + // Read 8-byte header (u64 little-endian) + uint8_t header_bytes[8]; + if (fread(header_bytes, 1, 8, file) != 8) { + error_msg = "Failed to read header"; + return false; + } + + // Parse as little-endian u64 + uint64_t metadata_size = 0; + for (int i = 0; i < 8; i++) { + metadata_size |= (uint64_t)header_bytes[i] << (i * 8); + } + + // Sanity check + if (metadata_size > file_size - 8) { + error_msg = "Invalid metadata size: " + std::to_string(metadata_size); + return false; + } + + if (metadata_size > 100 * 1024 * 1024) { // 100 MB max for metadata + error_msg = "Metadata size too large: " + std::to_string(metadata_size); + return false; + } + + // Read metadata JSON + std::vector metadata_bytes(metadata_size + 1); // +1 for null terminator + if (fread(metadata_bytes.data(), 1, metadata_size, file) != metadata_size) { + error_msg = "Failed to read metadata"; + return false; + } + metadata_bytes[metadata_size] = '\0'; + + // Calculate data start offset (with alignment) + data_start_offset = 8 + metadata_size; + constexpr size_t ALIGNMENT = 8; + if (data_start_offset % ALIGNMENT != 0) { + data_start_offset += ALIGNMENT - (data_start_offset % ALIGNMENT); + } + + // Parse JSON + json j; + try { + j = json::parse(metadata_bytes.data()); + } catch (const std::exception & e) { + error_msg = std::string("Failed to parse metadata JSON: ") + e.what(); + return false; + } + + // Extract tensors + for (auto & [key, value] : j.items()) { + if (key == "__metadata__") { + // Store optional metadata + metadata = std::make_unique(value); + continue; + } + + // Parse tensor info + if (!value.is_object()) { + continue; + } + + safetensors_tensor_info info; + info.name = key; + + try { + // Get dtype + if (!value.contains("dtype") || !value["dtype"].is_string()) { + error_msg = "Missing or invalid dtype for tensor: " + key; + return false; + } + info.dtype = safetensors_dtype_from_string(value["dtype"].get()); + if (info.dtype == safetensors_dtype::UNKNOWN) { + error_msg = "Unknown dtype for tensor: " + key; + return false; + } + + // Get shape + if (!value.contains("shape") || !value["shape"].is_array()) { + error_msg = "Missing or invalid shape for tensor: " + key; + return false; + } + for (auto & dim : value["shape"]) { + if (!dim.is_number_integer()) { + error_msg = "Invalid shape dimension for tensor: " + key; + return false; + } + info.shape.push_back(dim.get()); + } + + // Get data_offsets + if (!value.contains("data_offsets") || !value["data_offsets"].is_array() || + value["data_offsets"].size() != 2) { + error_msg = "Missing or invalid data_offsets for tensor: " + key; + return false; + } + info.offset_start = value["data_offsets"][0].get(); + info.offset_end = value["data_offsets"][1].get(); + + // Validate offsets + if (info.offset_end < info.offset_start) { + error_msg = "Invalid offsets for tensor: " + key; + return false; + } + + // Validate size matches shape and dtype 
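// For reference, the on-disk framing that parse_header() walks through is: the first 8 bytes are a
// little-endian uint64 giving the JSON header length N, the next N bytes are UTF-8 JSON mapping tensor
// names to { dtype, shape, data_offsets }, and the raw tensor byte buffer follows. A minimal standalone
// sketch that dumps the JSON header of a file passed on the command line (no JSON library needed for this):
/*
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main(int argc, char ** argv) {
        if (argc < 2) { std::fprintf(stderr, "usage: %s model.safetensors\n", argv[0]); return 1; }
        FILE * f = std::fopen(argv[1], "rb");
        if (!f) { std::perror("fopen"); return 1; }

        uint8_t len_bytes[8];
        if (std::fread(len_bytes, 1, 8, f) != 8) { std::fclose(f); return 1; }
        uint64_t n = 0;
        for (int i = 0; i < 8; ++i) {
            n |= (uint64_t) len_bytes[i] << (8 * i);   // little-endian u64
        }

        std::vector<char> header(n);
        if (std::fread(header.data(), 1, n, f) != n) { std::fclose(f); return 1; }
        std::fclose(f);

        std::printf("header (%llu bytes): %.*s\n", (unsigned long long) n, (int) n, header.data());
        return 0;
    }
*/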
+ size_t expected_size = info.n_elements() * safetensors_dtype_size(info.dtype); + if (info.size() != expected_size) { + error_msg = "Size mismatch for tensor " + key + ": expected " + + std::to_string(expected_size) + ", got " + std::to_string(info.size()); + return false; + } + + // Validate offset is within file bounds + if (data_start_offset + info.offset_end > file_size) { + error_msg = "Tensor data extends beyond file bounds: " + key; + return false; + } + + tensors[key] = info; + + } catch (const std::exception & e) { + error_msg = std::string("Error parsing tensor ") + key + ": " + e.what(); + return false; + } + } + + return true; +} + +std::vector safetensors_file::get_tensor_names() const { + std::vector names; + names.reserve(tensors.size()); + for (const auto & [name, _] : tensors) { + names.push_back(name); + } + // Sort for consistency + std::sort(names.begin(), names.end()); + return names; +} + +const safetensors_tensor_info * safetensors_file::get_tensor_info(const std::string & name) const { + auto it = tensors.find(name); + if (it != tensors.end()) { + return &it->second; + } + return nullptr; +} + +bool safetensors_file::read_tensor_data(const std::string & name, void * buffer, size_t buffer_size) { + auto it = tensors.find(name); + if (it == tensors.end()) { + error_msg = "Tensor not found: " + name; + return false; + } + + const auto & info = it->second; + if (buffer_size < info.size()) { + error_msg = "Buffer too small for tensor " + name + ": need " + + std::to_string(info.size()) + ", got " + std::to_string(buffer_size); + return false; + } + + if (!file) { + error_msg = "File not open"; + return false; + } + + // Seek to tensor data position + size_t file_offset = data_start_offset + info.offset_start; + if (fseek(file, file_offset, SEEK_SET) != 0) { + error_msg = "Failed to seek to tensor data for: " + name; + return false; + } + + // Read tensor data + if (fread(buffer, 1, info.size(), file) != info.size()) { + error_msg = "Failed to read tensor data for: " + name; + return false; + } + + return true; +} + +const nlohmann::json * safetensors_file::get_metadata() const { + return metadata.get(); +} + +// safetensors_loader implementation + +bool safetensors_loader::load_single(const std::string & filename) { + auto file = std::make_unique(); + if (!file->open(filename)) { + error_msg = file->get_error(); + return false; + } + + // Map all tensors to this file + size_t file_idx = files.size(); + for (const auto & name : file->get_tensor_names()) { + tensor_to_file[name] = file_idx; + } + + files.push_back(std::move(file)); + return true; +} + +bool safetensors_loader::load_sharded(const std::string & index_path, const std::string & base_dir) { + // Read index.json + std::ifstream f(index_path); + if (!f.is_open()) { + error_msg = "Failed to open index file: " + index_path; + return false; + } + + json index_json; + try { + f >> index_json; + } catch (const std::exception & e) { + error_msg = std::string("Failed to parse index JSON: ") + e.what(); + return false; + } + + // Get weight_map + if (!index_json.contains("weight_map") || !index_json["weight_map"].is_object()) { + error_msg = "Index file missing or invalid weight_map"; + return false; + } + + const auto & weight_map = index_json["weight_map"]; + + // Collect unique shard files + std::map shard_file_to_idx; + for (auto & [tensor_name, shard_file] : weight_map.items()) { + if (!shard_file.is_string()) { + continue; + } + std::string shard_path = base_dir + "/" + shard_file.get(); + + // Load shard if not 
already loaded + if (shard_file_to_idx.find(shard_path) == shard_file_to_idx.end()) { + size_t idx = files.size(); + auto file = std::make_unique(); + if (!file->open(shard_path)) { + error_msg = "Failed to load shard " + shard_path + ": " + file->get_error(); + return false; + } + files.push_back(std::move(file)); + shard_file_to_idx[shard_path] = idx; + } + + tensor_to_file[tensor_name] = shard_file_to_idx[shard_path]; + } + + return true; +} + +std::vector safetensors_loader::get_tensor_names() const { + std::vector names; + names.reserve(tensor_to_file.size()); + for (const auto & [name, _] : tensor_to_file) { + names.push_back(name); + } + std::sort(names.begin(), names.end()); + return names; +} + +const safetensors_tensor_info * safetensors_loader::get_tensor_info(const std::string & name) const { + auto it = tensor_to_file.find(name); + if (it == tensor_to_file.end()) { + return nullptr; + } + return files[it->second]->get_tensor_info(name); +} + +bool safetensors_loader::read_tensor_data(const std::string & name, void * buffer, size_t buffer_size) { + auto it = tensor_to_file.find(name); + if (it == tensor_to_file.end()) { + error_msg = "Tensor not found: " + name; + return false; + } + + if (!files[it->second]->read_tensor_data(name, buffer, buffer_size)) { + error_msg = files[it->second]->get_error(); + return false; + } + + return true; +} diff --git a/src/llama-safetensors.h b/src/llama-safetensors.h new file mode 100644 index 00000000000..591f6e719b7 --- /dev/null +++ b/src/llama-safetensors.h @@ -0,0 +1,139 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "../vendor/nlohmann/json.hpp" + +// Safetensors data types +enum class safetensors_dtype { + F64, // float64 + F32, // float32 + F16, // float16 + BF16, // bfloat16 + I64, // int64 + I32, // int32 + I16, // int16 + I8, // int8 + U8, // uint8 + BOOL, // bool + UNKNOWN +}; + +// Convert safetensors dtype string to enum +safetensors_dtype safetensors_dtype_from_string(const std::string & dtype_str); + +// Get size in bytes for a given dtype +size_t safetensors_dtype_size(safetensors_dtype dtype); + +// Get dtype name +const char * safetensors_dtype_name(safetensors_dtype dtype); + +// Information about a single tensor in the safetensors file +struct safetensors_tensor_info { + std::string name; + safetensors_dtype dtype; + std::vector shape; + size_t offset_start; // offset in data buffer (not file position) + size_t offset_end; // end offset in data buffer + + size_t size() const { + return offset_end - offset_start; + } + + int64_t n_elements() const { + int64_t n = 1; + for (auto dim : shape) { + n *= dim; + } + return n; + } +}; + +// Represents a safetensors file (single file or one shard) +class safetensors_file { +public: + safetensors_file() = default; + ~safetensors_file() = default; + + // Open and parse a safetensors file + // Returns true on success, false on error (check get_error()) + bool open(const std::string & filename); + + // Close the file + void close(); + + // Get list of all tensor names + std::vector get_tensor_names() const; + + // Get information about a specific tensor + // Returns nullptr if tensor not found + const safetensors_tensor_info * get_tensor_info(const std::string & name) const; + + // Read tensor data into a pre-allocated buffer + // Buffer must be at least tensor_info->size() bytes + // Returns true on success + bool read_tensor_data(const std::string & name, void * buffer, size_t buffer_size); + + // Get metadata (optional __metadata__ field) + const 
nlohmann::json * get_metadata() const; + + // Get last error message + const std::string & get_error() const { return error_msg; } + + // Get file size + size_t get_file_size() const { return file_size; } + + // Get data buffer offset (where tensor data starts in file) + size_t get_data_offset() const { return data_start_offset; } + +private: + std::string filename; + FILE * file = nullptr; + size_t file_size = 0; + size_t data_start_offset = 0; + std::string error_msg; + + std::map tensors; + std::unique_ptr metadata; + + bool parse_header(); +}; + +// Represents a collection of safetensors files (for sharded models) +class safetensors_loader { +public: + safetensors_loader() = default; + ~safetensors_loader() = default; + + // Load a single safetensors file + bool load_single(const std::string & filename); + + // Load sharded model using index.json + // index_path should be "model.safetensors.index.json" + // base_dir is the directory containing the shard files + bool load_sharded(const std::string & index_path, const std::string & base_dir); + + // Get list of all tensor names across all shards + std::vector get_tensor_names() const; + + // Get information about a specific tensor + const safetensors_tensor_info * get_tensor_info(const std::string & name) const; + + // Read tensor data (handles finding the right shard) + bool read_tensor_data(const std::string & name, void * buffer, size_t buffer_size); + + // Get last error message + const std::string & get_error() const { return error_msg; } + + // Get total number of tensors + size_t get_tensor_count() const { return tensor_to_file.size(); } + +private: + std::vector> files; + std::map tensor_to_file; // maps tensor name to file index + std::string error_msg; +}; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index a73c4c448ba..0df898fdb10 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -4,6 +4,7 @@ #include "gguf.h" #include "llama-impl.h" #include "llama-model-loader.h" +#include "llama-hf-config.h" #include "unicode.h" @@ -14,6 +15,8 @@ #include #include #include +#include +#include #include #include #include @@ -1612,6 +1615,11 @@ struct llama_vocab::impl { void load(llama_model_loader & ml, const LLM_KV & kv); + bool load_from_hf_tokenizer( + const std::string & tokenizer_json_path, + const std::string & tokenizer_config_path + ); + enum llama_vocab_type get_type() const; std::string type_name() const; @@ -2608,6 +2616,216 @@ llama_token_attr llama_vocab::impl::token_get_attr(llama_token id) const { return id_to_token.at(id).attr; } +bool llama_vocab::impl::load_from_hf_tokenizer( + const std::string & tokenizer_json_path, + const std::string & tokenizer_config_path) { + + LLAMA_LOG_INFO("%s: loading HuggingFace tokenizer from %s\n", __func__, tokenizer_json_path.c_str()); + + // Read tokenizer.json file + std::ifstream file(tokenizer_json_path); + if (!file.is_open()) { + LLAMA_LOG_ERROR("%s: failed to open %s\n", __func__, tokenizer_json_path.c_str()); + return false; + } + + nlohmann::json root; + try { + file >> root; + } catch (const std::exception & e) { + LLAMA_LOG_ERROR("%s: failed to parse tokenizer.json: %s\n", __func__, e.what()); + return false; + } + file.close(); + + // Get model section + if (!root.contains("model") || !root["model"].is_object()) { + LLAMA_LOG_ERROR("%s: tokenizer.json missing 'model' section\n", __func__); + return false; + } + + auto model = root["model"]; + + // Determine tokenizer type from model.type + std::string model_type_str; + if (model.contains("type") && 
model["type"].is_string()) { + model_type_str = model["type"].get(); + } + + // Set tokenizer type based on HF type + if (model_type_str == "BPE") { + type = LLAMA_VOCAB_TYPE_BPE; + tokenizer_model = "gpt2"; + } else if (model_type_str == "WordPiece") { + type = LLAMA_VOCAB_TYPE_WPM; + tokenizer_model = "bert"; + } else if (model_type_str == "Unigram") { + type = LLAMA_VOCAB_TYPE_UGM; + tokenizer_model = "t5"; + } else { + // Default to BPE + LLAMA_LOG_WARN("%s: unknown tokenizer type '%s', defaulting to BPE\n", __func__, model_type_str.c_str()); + type = LLAMA_VOCAB_TYPE_BPE; + tokenizer_model = "gpt2"; + } + + // Load vocabulary from model.vocab + if (!model.contains("vocab") || !model["vocab"].is_object()) { + LLAMA_LOG_ERROR("%s: tokenizer.json missing 'model.vocab' section\n", __func__); + return false; + } + + auto vocab_obj = model["vocab"]; + + // Extract vocab - it's a map of token -> id + std::map id_to_token_temp; + for (auto it = vocab_obj.begin(); it != vocab_obj.end(); ++it) { + const std::string & token_str = it.key(); + if (it.value().is_number_integer()) { + llama_token token_id = it.value().get(); + token_to_id[token_str] = token_id; + id_to_token_temp[token_id] = token_str; + } + } + + // Build id_to_token vector + if (!id_to_token_temp.empty()) { + llama_token max_id = id_to_token_temp.rbegin()->first; + id_to_token.resize(max_id + 1); + + for (const auto & pair : id_to_token_temp) { + llama_token id = pair.first; + const std::string & text = pair.second; + + id_to_token[id].text = text; + id_to_token[id].score = 0.0f; // HF doesn't use scores + id_to_token[id].attr = LLAMA_TOKEN_ATTR_NORMAL; + } + } + + LLAMA_LOG_INFO("%s: loaded %zu tokens\n", __func__, token_to_id.size()); + + // Load BPE merges if this is a BPE tokenizer + if (type == LLAMA_VOCAB_TYPE_BPE && model.contains("merges") && model["merges"].is_array()) { + int rank = 0; + for (const auto & merge_item : model["merges"]) { + if (merge_item.is_string()) { + std::string merge_str = merge_item.get(); + size_t pos = merge_str.find(' '); + if (pos != std::string::npos) { + std::string first = merge_str.substr(0, pos); + std::string second = merge_str.substr(pos + 1); + bpe_ranks[std::make_pair(first, second)] = rank++; + } + } + } + LLAMA_LOG_INFO("%s: loaded %zu BPE merges\n", __func__, bpe_ranks.size()); + } + + // Load added_tokens (special tokens) + if (root.contains("added_tokens") && root["added_tokens"].is_array()) { + for (const auto & token_obj : root["added_tokens"]) { + if (!token_obj.is_object()) continue; + + if (token_obj.contains("id") && token_obj.contains("content") && + token_obj["id"].is_number_integer() && token_obj["content"].is_string()) { + + llama_token id = token_obj["id"].get(); + std::string content = token_obj["content"].get(); + bool is_special = token_obj.value("special", false); + + // Update token attributes + if (id < (llama_token)id_to_token.size()) { + if (is_special) { + id_to_token[id].attr = LLAMA_TOKEN_ATTR_CONTROL; + } else { + id_to_token[id].attr = LLAMA_TOKEN_ATTR_USER_DEFINED; + } + } + + // Try to identify common special tokens + if (content == "" || content == "<|begin_of_text|>" || content == "<|startoftext|>") { + special_bos_id = id; + } else if (content == "" || content == "<|end_of_text|>" || content == "<|endoftext|>") { + special_eos_id = id; + } else if (content == "" || content == "<|unknown|>") { + special_unk_id = id; + } else if (content == "" || content == "<|pad|>") { + special_pad_id = id; + } else if (content == "\n") { + linefeed_id = id; + } + } + } 
+ } + + // Try to load special tokens from tokenizer_config.json if provided + if (!tokenizer_config_path.empty() && std::filesystem::exists(tokenizer_config_path)) { + std::ifstream config_file(tokenizer_config_path); + if (config_file.is_open()) { + nlohmann::json config_root; + try { + config_file >> config_root; + + auto get_token_id = [&](const nlohmann::json & val) -> llama_token { + if (val.is_string()) { + std::string token_str = val.get(); + auto it = token_to_id.find(token_str); + if (it != token_to_id.end()) return it->second; + } else if (val.is_object() && val.contains("content") && val["content"].is_string()) { + std::string token_str = val["content"].get(); + auto it = token_to_id.find(token_str); + if (it != token_to_id.end()) return it->second; + } + return LLAMA_TOKEN_NULL; + }; + + if (config_root.contains("bos_token")) { + llama_token bos = get_token_id(config_root["bos_token"]); + if (bos != LLAMA_TOKEN_NULL) special_bos_id = bos; + } + if (config_root.contains("eos_token")) { + llama_token eos = get_token_id(config_root["eos_token"]); + if (eos != LLAMA_TOKEN_NULL) special_eos_id = eos; + } + if (config_root.contains("unk_token")) { + llama_token unk = get_token_id(config_root["unk_token"]); + if (unk != LLAMA_TOKEN_NULL) special_unk_id = unk; + } + if (config_root.contains("pad_token")) { + llama_token pad = get_token_id(config_root["pad_token"]); + if (pad != LLAMA_TOKEN_NULL) special_pad_id = pad; + } + } catch (const std::exception & e) { + LLAMA_LOG_WARN("%s: failed to parse tokenizer_config.json: %s\n", __func__, e.what()); + } + } + } + + // Log special tokens + LLAMA_LOG_INFO("%s: special tokens: bos=%d eos=%d unk=%d pad=%d\n", + __func__, special_bos_id, special_eos_id, special_unk_id, special_pad_id); + + // Initialize tokenizer for this type + init_tokenizer(type); + + // Build cache + // IMPORTANT: Build in a temporary vector first, then swap it in atomically! + // If we push_back directly to cache_token_to_piece, after the first iteration + // the cache becomes non-empty, causing token_to_piece() to try using cache.at() + // for indices that don't exist yet, throwing std::out_of_range("vector"). 
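// A stripped-down sketch of the build-then-swap idiom described in the comment above (hypothetical
// names): the per-token work fills a local vector, and the live cache changes only in the final swap,
// so a lookup made while the table is being built never sees a partially filled cache:
/*
    #include <string>
    #include <vector>

    static std::vector<std::string> cache;               // stays in its old (empty) state during the rebuild

    static std::string compute_piece(int id) {
        return "piece_" + std::to_string(id);             // stand-in for token_to_piece_for_cache()
    }

    void rebuild_cache(int n_tokens) {
        std::vector<std::string> tmp;
        tmp.reserve(n_tokens);
        for (int id = 0; id < n_tokens; ++id) {
            tmp.push_back(compute_piece(id));             // may look things up, but never in `cache`
        }
        cache.swap(tmp);                                  // single switch to the fully built table
    }
*/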
+ std::vector temp_cache; + temp_cache.reserve(id_to_token.size()); + for (llama_token id = 0; id < (llama_token)id_to_token.size(); ++id) { + temp_cache.push_back(token_to_piece_for_cache(id, true)); + } + // Atomically swap in the fully-built cache + cache_token_to_piece.swap(temp_cache); + + LLAMA_LOG_INFO("%s: HuggingFace tokenizer loaded successfully\n", __func__); + return true; +} + void llama_vocab::impl::init_tokenizer(enum llama_vocab_type type) { LLAMA_LOG_DEBUG("%s: initializing tokenizer for type %d\n", __func__, type); @@ -3260,6 +3478,12 @@ void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) { pimpl->load(ml, kv); } +bool llama_vocab::load_from_hf_tokenizer( + const std::string & tokenizer_json_path, + const std::string & tokenizer_config_path) { + return pimpl->load_from_hf_tokenizer(tokenizer_json_path, tokenizer_config_path); +} + std::string llama_vocab::get_tokenizer_model() const { return pimpl->tokenizer_model; } diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 55f8f3923c9..ae8ed0966ea 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -68,6 +68,13 @@ struct llama_vocab { void load(llama_model_loader & ml, const LLM_KV & kv); + // Load vocabulary from HuggingFace tokenizer.json format + // Used by safetensors loader + bool load_from_hf_tokenizer( + const std::string & tokenizer_json_path, + const std::string & tokenizer_config_path + ); + std::string get_tokenizer_model() const; std::string get_tokenizer_pre() const; diff --git a/src/llama.cpp b/src/llama.cpp index ab2e9868af4..a98f99930f9 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -6,6 +6,7 @@ #include "llama-model-loader.h" #include "llama-model-saver.h" #include "llama-model.h" +#include "llama-model-from-safetensors.h" #include "ggml.h" #include "ggml-backend.h" @@ -16,6 +17,7 @@ #include #include #include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -98,6 +100,49 @@ int64_t llama_time_us(void) { return ggml_time_us(); } +// Helper function to detect if a path is a safetensors model +static bool is_safetensors_model(const std::string & path) { + namespace fs = std::filesystem; + + // Check if path is a directory + if (fs::is_directory(path)) { + // Look for config.json and at least one .safetensors file + bool has_config = fs::exists(path + "/config.json"); + bool has_safetensors = false; + + try { + for (const auto & entry : fs::directory_iterator(path)) { + if (entry.is_regular_file()) { + std::string filename = entry.path().filename().string(); + // Check if filename ends with ".safetensors" + const std::string suffix = ".safetensors"; + if (filename.size() >= suffix.size() && + filename.compare(filename.size() - suffix.size(), suffix.size(), suffix) == 0) { + has_safetensors = true; + break; + } + } + } + } catch (...) 
{ + return false; + } + + return has_config && has_safetensors; + } + + // Check if path is a .safetensors file + const std::string suffix = ".safetensors"; + if (path.size() >= suffix.size() && + path.compare(path.size() - suffix.size(), suffix.size(), suffix) == 0) { + // Check if config.json exists in the same directory + fs::path safetensors_path(path); + fs::path config_path = safetensors_path.parent_path() / "config.json"; + return fs::exists(config_path); + } + + return false; +} + // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback static int llama_model_load(const std::string & fname, std::vector & splits, llama_model & model, llama_model_params & params) { // loading time will be recalculated after the first eval, so @@ -163,6 +208,27 @@ static struct llama_model * llama_model_load_from_file_impl( return nullptr; } + // Check if this is a safetensors model + if (is_safetensors_model(path_model)) { + LLAMA_LOG_INFO("%s: detected safetensors model format\n", __func__); + if (!splits.empty()) { + LLAMA_LOG_WARN("%s: safetensors models do not support splits parameter (will be ignored)\n", __func__); + } + + // Load via safetensors loader + try { + llama_model * model = llama_model_load_from_safetensors(path_model.c_str(), params); + if (model) { + LLAMA_LOG_INFO("%s: safetensors model loaded successfully\n", __func__); + } + return model; + } catch (const std::exception & e) { + LLAMA_LOG_ERROR("%s: failed to load safetensors model: %s\n", __func__, e.what()); + return nullptr; + } + } + + // Continue with GGUF loading for non-safetensors models unsigned cur_percentage = 0; if (params.progress_callback == NULL) { params.progress_callback_user_data = &cur_percentage; diff --git a/tools/run/run.cpp b/tools/run/run.cpp index b90a7253c43..38e099e0be4 100644 --- a/tools/run/run.cpp +++ b/tools/run/run.cpp @@ -1113,12 +1113,20 @@ static int check_context_size(const llama_context_ptr & ctx, const llama_batch & static int convert_token_to_string(const llama_vocab * vocab, const llama_token token_id, std::string & piece) { char buf[256]; int n = llama_token_to_piece(vocab, token_id, buf, sizeof(buf), 0, true); + + // DEBUG: Log token ID and decoded length + fprintf(stderr, "[DEBUG] Token ID: %d, decoded length: %d\n", token_id, n); + if (n < 0) { printe("failed to convert token to piece\n"); return 1; } piece = std::string(buf, n); + + // DEBUG: Log decoded piece + fprintf(stderr, "[DEBUG] Decoded piece (len=%d): '%s'\n", (int)piece.size(), piece.c_str()); + return 0; } @@ -1149,7 +1157,12 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str // sample the next token, check is it an end of generation? new_token_id = llama_sampler_sample(llama_data.sampler.get(), llama_data.context.get(), -1); - if (llama_vocab_is_eog(vocab, new_token_id)) { + + // DEBUG: Log sampled token and EOG check + bool is_eog = llama_vocab_is_eog(vocab, new_token_id); + fprintf(stderr, "[DEBUG] Sampled token ID: %d, is_eog: %s\n", new_token_id, is_eog ? "TRUE" : "FALSE"); + + if (is_eog) { break; }
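// Combined with the download logic added in common/arg.cpp, a local safetensors model directory is
// expected to contain config.json, tokenizer.json, optionally tokenizer_config.json, and one or more
// *.safetensors shards. A small standalone sketch of that sanity check (the helper name is hypothetical):
/*
    #include <cstdio>
    #include <filesystem>
    #include <string>

    namespace fs = std::filesystem;

    static bool looks_like_safetensors_dir(const std::string & dir) {
        if (!fs::is_directory(dir))               return false;
        if (!fs::exists(dir + "/config.json"))    return false;
        if (!fs::exists(dir + "/tokenizer.json")) return false;    // required by init_vocabulary()
        for (const auto & entry : fs::directory_iterator(dir)) {
            if (entry.is_regular_file() && entry.path().extension() == ".safetensors") {
                return true;                                        // at least one shard present
            }
        }
        return false;
    }

    int main(int argc, char ** argv) {
        if (argc < 2) { std::fprintf(stderr, "usage: %s <model-dir>\n", argv[0]); return 1; }
        std::printf("%s: %s\n", argv[1],
                    looks_like_safetensors_dir(argv[1]) ? "looks like a safetensors model"
                                                        : "missing required files");
        return 0;
    }
*/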