Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 86 additions & 17 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,30 +224,99 @@ static handle_model_result common_params_handle_model(
if (model.hf_file.empty()) {
if (model.path.empty()) {
auto auto_detected = common_get_hf_file(model.hf_repo, bearer_token, offline);
if (auto_detected.repo.empty() || auto_detected.ggufFile.empty()) {
if (auto_detected.repo.empty()) {
exit(1); // built without CURL, error message already printed
}

model.hf_repo = auto_detected.repo;
model.hf_file = auto_detected.ggufFile;
if (!auto_detected.mmprojFile.empty()) {
result.found_mmproj = true;
result.mmproj.hf_repo = model.hf_repo;
result.mmproj.hf_file = auto_detected.mmprojFile;

// Handle safetensors format
if (auto_detected.is_safetensors) {
LOG_INF("%s: detected safetensors format for %s\n", __func__, model.hf_repo.c_str());

// Create a directory for the safetensors files
std::string dir_name = model.hf_repo;
string_replace_all(dir_name, "/", "_");
model.path = fs_get_cache_directory() + "/" + dir_name;

// Create directory if it doesn't exist
std::filesystem::create_directories(model.path);

// Download required files: config.json, tokenizer.json, tokenizer_config.json, and .safetensors files
std::string model_endpoint = get_model_endpoint();
std::vector<std::pair<std::string, std::string>> files_to_download;

// Required config files
files_to_download.push_back({
model_endpoint + model.hf_repo + "/resolve/main/config.json",
model.path + "/config.json"
});
files_to_download.push_back({
model_endpoint + model.hf_repo + "/resolve/main/tokenizer.json",
model.path + "/tokenizer.json"
});
files_to_download.push_back({
model_endpoint + model.hf_repo + "/resolve/main/tokenizer_config.json",
model.path + "/tokenizer_config.json"
});

// Safetensors files
for (const auto & st_file : auto_detected.safetensors_files) {
files_to_download.push_back({
model_endpoint + model.hf_repo + "/resolve/main/" + st_file,
model.path + "/" + st_file
});
}

// Download all files
LOG_INF("%s: downloading %zu files for safetensors model...\n", __func__, files_to_download.size());
for (const auto & [url, path] : files_to_download) {
bool ok = common_download_file_single(url, path, bearer_token, offline);
if (!ok) {
LOG_ERR("error: failed to download file from %s\n", url.c_str());
exit(1);
}
}

LOG_INF("%s: safetensors model downloaded to %s\n", __func__, model.path.c_str());
} else {
// Handle GGUF format (existing logic)
if (auto_detected.ggufFile.empty()) {
exit(1); // no GGUF file found
}
model.hf_file = auto_detected.ggufFile;
if (!auto_detected.mmprojFile.empty()) {
result.found_mmproj = true;
result.mmproj.hf_repo = model.hf_repo;
result.mmproj.hf_file = auto_detected.mmprojFile;
}

std::string model_endpoint = get_model_endpoint();
model.url = model_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file;
// make sure model path is present (for caching purposes)
if (model.path.empty()) {
// this is to avoid different repo having same file name, or same file name in different subdirs
std::string filename = model.hf_repo + "_" + model.hf_file;
// to make sure we don't have any slashes in the filename
string_replace_all(filename, "/", "_");
model.path = fs_get_cache_file(filename);
}
}
} else {
model.hf_file = model.path;
}
}

std::string model_endpoint = get_model_endpoint();
model.url = model_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file;
// make sure model path is present (for caching purposes)
if (model.path.empty()) {
// this is to avoid different repo having same file name, or same file name in different subdirs
std::string filename = model.hf_repo + "_" + model.hf_file;
// to make sure we don't have any slashes in the filename
string_replace_all(filename, "/", "_");
model.path = fs_get_cache_file(filename);
} else {
// User specified hf_file explicitly - use GGUF download path
std::string model_endpoint = get_model_endpoint();
model.url = model_endpoint + model.hf_repo + "/resolve/main/" + model.hf_file;
// make sure model path is present (for caching purposes)
if (model.path.empty()) {
// this is to avoid different repo having same file name, or same file name in different subdirs
std::string filename = model.hf_repo + "_" + model.hf_file;
// to make sure we don't have any slashes in the filename
string_replace_all(filename, "/", "_");
model.path = fs_get_cache_file(filename);
}
}

} else if (!model.url.empty()) {
Expand Down
89 changes: 83 additions & 6 deletions common/download.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -715,10 +715,10 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string

#if defined(LLAMA_USE_CURL) || defined(LLAMA_USE_HTTPLIB)

static bool common_download_file_single(const std::string & url,
const std::string & path,
const std::string & bearer_token,
bool offline) {
bool common_download_file_single(const std::string & url,
const std::string & path,
const std::string & bearer_token,
bool offline) {
if (!offline) {
return common_download_file_single_online(url, path, bearer_token);
}
Expand Down Expand Up @@ -897,16 +897,93 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, cons
}
} else if (res_code == 401) {
throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
} else if (res_code == 400) {
// 400 typically means "not a GGUF repo" - we'll check for safetensors below
LOG_INF("%s: manifest endpoint returned 400 (not a GGUF repo), will check for safetensors...\n", __func__);
} else {
throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str()));
}

// check response
if (ggufFile.empty()) {
throw std::runtime_error("error: model does not have ggufFile");
// No GGUF found - try to detect safetensors format
LOG_INF("%s: no GGUF file found, checking for safetensors format...\n", __func__);

// Query HF API to list files in the repo
std::string files_url = get_model_endpoint() + "api/models/" + hf_repo + "/tree/main";

common_remote_params files_params;
files_params.headers = headers;

long files_res_code = 0;
std::string files_res_str;

if (!offline) {
try {
auto files_res = common_remote_get_content(files_url, files_params);
files_res_code = files_res.first;
files_res_str = std::string(files_res.second.data(), files_res.second.size());
} catch (const std::exception & e) {
throw std::runtime_error("error: model does not have ggufFile and failed to check for safetensors: " + std::string(e.what()));
}
} else {
throw std::runtime_error("error: model does not have ggufFile (offline mode, cannot check for safetensors)");
}

if (files_res_code != 200) {
throw std::runtime_error("error: model does not have ggufFile");
}

// Parse the files list
std::vector<std::string> safetensors_files;
bool has_config = false;
bool has_tokenizer = false;

try {
auto files_json = json::parse(files_res_str);

for (const auto & file : files_json) {
if (file.contains("path")) {
std::string path = file["path"].get<std::string>();

if (path == "config.json") {
has_config = true;
} else if (path == "tokenizer.json") {
has_tokenizer = true;
} else {
// Check for .safetensors extension
const std::string suffix = ".safetensors";
if (path.size() >= suffix.size() &&
path.compare(path.size() - suffix.size(), suffix.size(), suffix) == 0) {
safetensors_files.push_back(path);
}
}
}
}
} catch (const std::exception & e) {
throw std::runtime_error("error: model does not have ggufFile and failed to parse file list: " + std::string(e.what()));
}

// Check if we have the required safetensors files
if (!has_config || !has_tokenizer || safetensors_files.empty()) {
throw std::runtime_error("error: model does not have ggufFile or valid safetensors format");
}

LOG_INF("%s: detected safetensors format with %zu tensor files\n", __func__, safetensors_files.size());

common_hf_file_res result;
result.repo = hf_repo;
result.is_safetensors = true;
result.safetensors_files = safetensors_files;
return result;
}

return { hf_repo, ggufFile, mmprojFile };
common_hf_file_res result;
result.repo = hf_repo;
result.ggufFile = ggufFile;
result.mmprojFile = mmprojFile;
result.is_safetensors = false;
return result;
}

//
Expand Down
12 changes: 12 additions & 0 deletions common/download.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <string>
#include <vector>

struct common_params_model;

Expand All @@ -23,6 +24,10 @@ struct common_hf_file_res {
std::string repo; // repo name with ":tag" removed
std::string ggufFile;
std::string mmprojFile;

// Safetensors support
bool is_safetensors = false; // true if model is in safetensors format
std::vector<std::string> safetensors_files; // list of .safetensors files to download
};

/**
Expand All @@ -41,6 +46,13 @@ common_hf_file_res common_get_hf_file(
const std::string & bearer_token,
bool offline);

// Download a single file from `url` to local `path`, with no GGUF-format
// validation of the result (unlike common_download_model).
// - `bearer_token`: when non-empty, used to authenticate the request
//   (presumably sent as an HTTP Authorization header — confirm in download.cpp)
// - `offline`: when false the file is fetched from the network
//   (common_download_file_single_online); when true the network is skipped —
//   NOTE(review): offline branch not visible here, likely serves only an
//   already-cached file; verify against the definition in download.cpp.
// Returns true on success, false on failure (callers treat false as a fatal
// download error).
bool common_download_file_single(
    const std::string & url,
    const std::string & path,
    const std::string & bearer_token,
    bool offline);

// returns true if download succeeded
bool common_download_model(
const common_params_model & model,
Expand Down
5 changes: 5 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ add_library(llama
llama-quant.cpp
llama-sampling.cpp
llama-vocab.cpp
llama-safetensors.cpp
llama-hf-config.cpp
llama-safetensors-loader.cpp
llama-safetensors-types.cpp
llama-model-from-safetensors.cpp
unicode-data.cpp
unicode.cpp
unicode.h
Expand Down
Loading
Loading