2929#include < iomanip>
3030#include < regex>
3131#include < sstream>
32- #include < sys/ioctl.h>
33- #include < unistd.h>
3432
3533#include < curl/curl.h>
34+ #include < sys/ioctl.h>
3635#include < sys/stat.h>
3736#include < sys/types.h>
37+ #include < unistd.h>
3838
3939#include " huggingface_hub.h"
4040
@@ -136,82 +136,78 @@ size_t write_file_data(void *ptr, size_t size, size_t nmemb, void *stream) {
136136 return size * nmemb;
137137}
138138
139- // Function to extract SHA256 from Git LFS metadata
140- std::string extract_SHA256 (const std::string &response) {
141- std::istringstream stream (response);
142- std::string line;
143-
144- while (std::getline (stream, line)) {
145- if (line.find (" oid sha256:" ) != std::string::npos) {
146- auto result =
147- line.substr (line.find (" sha256:" ) + 7 ); // Extract after "sha256:"
148- stream.str (" " ); // Clear the stream
149- return result;
150- }
151- }
152- return " " ; // Return empty if not found
153- }
139+ // Extract metadata from JSON response
140+ FileMetadata extract_metadata (const std::string &json) {
141+ FileMetadata metadata;
154142
155- uint64_t extract_file_size (const std::string &response) {
156- std::istringstream stream (response);
157- std::string line;
158- uint64_t size = 0 ;
159-
160- while (std::getline (stream, line)) {
161- if (line.find (" size " ) != std::string::npos) {
162- std::string sizeStr = line.substr (5 ); // Extract after "size "
163- size = std::stoull (sizeStr);
164- return size;
165- }
166- }
167- return 0 ; // Return empty if not found
168- }
143+ std::smatch match;
169144
170- std::string extract_commit (const std::string &response) {
171- std::istringstream stream (response);
172- std::string line;
145+ // Extract "type"
146+ if (std::regex_search (json, match,
147+ std::regex (R"( \"type\"\s*:\s*\"([^"]+)\")" )))
148+ metadata.type = match[1 ];
149+
150+ // Extract "oid" (top-level one)
151+ if (std::regex_search (json, match,
152+ std::regex (R"( \"oid\"\s*:\s*\"([a-f0-9]{40})\")" )))
153+ metadata.oid = match[1 ];
154+
155+ // Extract "size"
156+ if (std::regex_search (json, match, std::regex (R"( \"size\"\s*:\s*(\d+))" )))
157+ metadata.size = std::stoull (match[1 ]);
158+
159+ // Extract "lfs" SHA-256 hash
160+ if (std::regex_search (
161+ json, match,
162+ std::regex (
163+ R"( \"lfs\"\s*:\s*\{[^}]*\"oid\"\s*:\s*\"([a-f0-9]{64})\")" )))
164+ metadata.sha256 = match[1 ];
165+
166+ // Extract "commit" ID
167+ if (std::regex_search (
168+ json, match,
169+ std::regex (
170+ R"( \"lastCommit\"\s*:\s*\{[^}]*\"id\"\s*:\s*\"([a-f0-9]{40})\")" )))
171+ metadata.commit = match[1 ];
173172
174- while (std::getline (stream, line)) {
175- if (line.find (" x-repo-commit:" ) != std::string::npos) {
176- auto result = line.substr (line.find (" :" ) + 2 ); // Extract after ": "
177- result.erase (result.find_last_not_of (" \n\r\t " ) +
178- 1 ); // Trim trailing whitespace
179- return result;
180- }
181- }
182- return " " ; // Return empty if not found
173+ return metadata;
183174}
184175
185176std::variant<struct FileMetadata , std::string>
186177get_model_metadata_from_hf (const std::string &repo, const std::string &file) {
187- struct FileMetadata metadata;
188- std::string url = " https://huggingface.co/" + repo + " /raw/main/" + file;
189- std::string response, headers;
190-
191178 CURL *curl = curl_easy_init ();
192179 if (!curl) {
193180 return " Failed to initialize CURL" ;
194181 }
195182
183+ std::string response, headers;
184+
185+ std::string url =
186+ " https://huggingface.co/api/models/" + repo + " /paths-info/main" ;
187+ const std::string body = " {\" paths\" : [\" " + file + " \" ], \" expand\" : true}" ;
188+
189+ struct curl_slist *http_headers = NULL ;
190+ http_headers =
191+ curl_slist_append (http_headers, " Content-Type: application/json" );
192+
196193 curl_easy_setopt (curl, CURLOPT_URL, url.c_str ());
194+ curl_easy_setopt (curl, CURLOPT_HTTPHEADER, http_headers);
195+ curl_easy_setopt (curl, CURLOPT_POSTFIELDS, body.c_str ());
197196 curl_easy_setopt (curl, CURLOPT_WRITEFUNCTION, write_string_data);
198197 curl_easy_setopt (curl, CURLOPT_WRITEDATA, &response);
199198 curl_easy_setopt (curl, CURLOPT_FOLLOWLOCATION, 1L );
200199 curl_easy_setopt (curl, CURLOPT_FAILONERROR, 1L );
201200 curl_easy_setopt (curl, CURLOPT_HEADERDATA, &headers);
202201
203202 CURLcode res = curl_easy_perform (curl);
203+ curl_slist_free_all (http_headers);
204204 curl_easy_cleanup (curl);
205205
206206 if (res != CURLE_OK) {
207207 return " CURL request failed: " + std::string (curl_easy_strerror (res));
208208 }
209209
210- metadata.sha256 = extract_SHA256 (response);
211- metadata.size = extract_file_size (response);
212- metadata.commit = extract_commit (headers);
213-
214- return metadata;
210+ return extract_metadata (response);
215211}
216212
217213int get_terminal_width () {
@@ -348,14 +344,24 @@ struct DownloadResult hf_hub_download(const std::string &repo_id,
348344 log_info (" Downloading " + filename + " from " + repo_id);
349345
350346 struct FileMetadata metadata = std::get<struct FileMetadata >(metadata_result);
351- log_debug (" SHA256: " + metadata.sha256 );
352347 log_debug (" Commit: " + metadata.commit );
348+ log_debug (" Blob ID: " + metadata.oid );
353349 log_debug (" Size: " + std::to_string (metadata.size ) + " bytes" );
350+ log_debug (" SHA256: " + metadata.sha256 );
351+
352+ std::filesystem::path blob_file_path;
353+ std::filesystem::path blob_incomplete_file_path;
354+
355+ if (metadata.sha256 .empty ()) {
356+ blob_file_path = cache_model_dir + " blobs/" + metadata.oid ;
357+ blob_incomplete_file_path =
358+ cache_model_dir + " blobs/" + metadata.oid + " .incomplete" ;
359+ } else {
360+ blob_file_path = cache_model_dir + " blobs/" + metadata.sha256 ;
361+ blob_incomplete_file_path =
362+ cache_model_dir + " blobs/" + metadata.sha256 + " .incomplete" ;
363+ }
354364
355- std::filesystem::path blob_file_path (cache_model_dir + " blobs/" +
356- metadata.sha256 );
357- std::filesystem::path blob_incomplete_file_path (
358- cache_model_dir + " blobs/" + metadata.sha256 + " .incomplete" );
359365 std::filesystem::path snapshot_file_path (cache_model_dir + " snapshots/" +
360366 metadata.commit + " /" + filename);
361367 std::filesystem::path refs_file_path (cache_model_dir + " refs/main" );
@@ -417,18 +423,19 @@ struct DownloadResult hf_hub_download_with_shards(const std::string &repo_id,
417423 const std::string &cache_dir,
418424 bool force_download) {
419425
420- std::regex pattern (R"( -([0-9]+)-of-([0-9]+)\.gguf )" );
426+ std::regex pattern (R"( -([0-9]+)-of-([0-9]+)\.(\w+) )" );
421427 std::smatch match;
422428
423429 if (std::regex_search (filename, match, pattern)) {
424430 int total_shards = std::stoi (match[2 ]);
425431 std::string base_name = filename.substr (0 , match.position (0 ));
432+ std::string extension = match[3 ];
426433
427434 // Download shards
428435 for (int i = 1 ; i <= total_shards; ++i) {
429436 char shard_file[512 ];
430- snprintf (shard_file, sizeof (shard_file), " %s-%05d-of-%05d.gguf " ,
431- base_name.c_str (), i, total_shards);
437+ snprintf (shard_file, sizeof (shard_file), " %s-%05d-of-%05d.%s " ,
438+ base_name.c_str (), i, total_shards, extension. c_str () );
432439 auto aux_res =
433440 hf_hub_download (repo_id, shard_file, cache_dir, force_download);
434441
@@ -439,8 +446,8 @@ struct DownloadResult hf_hub_download_with_shards(const std::string &repo_id,
439446
440447 // Return first shard
441448 char first_shard[512 ];
442- snprintf (first_shard, sizeof (first_shard), " %s-00001-of-%05d.gguf " ,
443- base_name.c_str (), total_shards);
449+ snprintf (first_shard, sizeof (first_shard), " %s-00001-of-%05d.%s " ,
450+ base_name.c_str (), total_shards, extension. c_str () );
444451 return hf_hub_download (repo_id, first_shard, cache_dir, false );
445452 }
446453
0 commit comments