@@ -85,16 +85,37 @@ std::filesystem::path expand_user_home(const std::string &path) {
8585 return std::filesystem::path (path);
8686}
8787
88- std::string create_cache_system (const std::string &cache_dir,
89- const std::string &repo_id) {
90- std::string model_folder = std::string (" models/" + repo_id);
/**
 * @brief Builds the cache folder name for a repository.
 *
 * Prefixes the repository id with "models/" and then replaces every '/'
 * with "--", e.g. "org/model" becomes "models--org--model".
 *
 * @param repo_id Repository identifier, typically "<owner>/<name>".
 * @return The flattened folder name (contains no '/').
 */
std::string get_model_repo_path(const std::string &repo_id) {
  const std::string prefixed = "models/" + repo_id;

  std::string flattened;
  for (const char ch : prefixed) {
    if (ch == '/') {
      flattened += "--";
    } else {
      flattened += ch;
    }
  }
  return flattened;
}
99+
/**
 * @brief Looks for a file in any snapshot directory of a cached repo.
 *
 * Each directory directly under @p snapshot_dir is treated as one snapshot
 * version. The first version that contains @p filename wins; iteration order
 * is whatever the filesystem reports, so with several matching versions the
 * choice is unspecified.
 *
 * Filesystem errors (missing or unreadable @p snapshot_dir, entries that are
 * not directories) are treated as "not found" rather than thrown.
 *
 * @param snapshot_dir Path of the repo's "snapshots" directory.
 * @param filename     File to look for, relative to a snapshot version
 *                     directory (may contain sub-directories).
 * @return Full path of the file found, or an empty string if none exists.
 */
std::string find_outdated_file(const std::string &snapshot_dir,
                               const std::string &filename) {
  // An empty name would resolve to the version directory itself.
  if (filename.empty()) {
    return "";
  }

  std::error_code iter_ec;
  std::filesystem::directory_iterator version_it(snapshot_dir, iter_ec);
  const std::filesystem::directory_iterator end_it;

  while (!iter_ec && version_it != end_it) {
    std::error_code probe_ec;
    // Skip stray regular files; only directories are snapshot versions.
    if (version_it->is_directory(probe_ec)) {
      const std::filesystem::path candidate = version_it->path() / filename;
      if (std::filesystem::exists(candidate, probe_ec)) {
        return candidate.string();
      }
    }
    version_it.increment(iter_ec);
  }
  return "";
}
114+
115+ std::string create_cache_system (const std::string &cache_dir,
116+ const std::string &repo_id) {
117+ std::string model_folder = get_model_repo_path (repo_id);
118+
98119 std::string expanded_cache_dir = expand_user_home (cache_dir);
99120
100121 std::string model_cache_path = expanded_cache_dir + " /" + model_folder + " /" ;
@@ -112,7 +133,7 @@ std::string create_cache_system(const std::string &cache_dir,
112133
113134size_t write_string_data (void *ptr, size_t size, size_t nmemb, void *stream) {
114135 if (!stream) {
115- log_error (" Error: stream is null!" );
136+ log_error (" Stream is null!" );
116137 return 0 ;
117138 }
118139 std::string *out = static_cast <std::string *>(stream);
@@ -122,12 +143,12 @@ size_t write_string_data(void *ptr, size_t size, size_t nmemb, void *stream) {
122143
123144size_t write_file_data (void *ptr, size_t size, size_t nmemb, void *stream) {
124145 if (!stream) {
125- log_error (" Error: stream is null!" );
146+ log_error (" Stream is null!" );
126147 return 0 ;
127148 }
128149 std::ofstream *out = static_cast <std::ofstream *>(stream);
129150 if (!out->is_open ()) {
130- log_error (" Error: output file stream is not open!" );
151+ log_error (" Output file stream is not open!" );
131152 return 0 ;
132153 }
133154 out->write (static_cast <char *>(ptr), size * nmemb);
@@ -161,16 +182,68 @@ FileMetadata extract_metadata(const std::string &json) {
161182 R"( \"lfs\"\s*:\s*\{[^}]*\"oid\"\s*:\s*\"([a-f0-9]{64})\")" )))
162183 metadata.sha256 = match[1 ];
163184
164- // Extract "commit" ID
165- if (std::regex_search (
166- json, match,
167- std::regex (
168- R"( \"lastCommit\"\s*:\s*\{[^}]*\"id\"\s*:\s*\"([a-f0-9]{40})\")" )))
169- metadata.commit = match[1 ];
170-
171185 return metadata;
172186}
173187
188+ std::string get_file_path (const std::string &cache_dir,
189+ const std::string &repo_id, const std::string &file) {
190+ std::string model_folder = get_model_repo_path (repo_id);
191+
192+ std::filesystem::path expanded_cache_dir = expand_user_home (cache_dir);
193+ std::filesystem::path refs_file_path =
194+ expanded_cache_dir / model_folder / " refs" / " main" ;
195+
196+ if (!std::filesystem::exists (refs_file_path)) {
197+ log_debug (" refs file does not exist" );
198+ return " " ;
199+ }
200+ std::ifstream refs_file (refs_file_path);
201+ std::string commit;
202+ refs_file >> commit;
203+ refs_file.close ();
204+ std::filesystem::path snapshot_file_path =
205+ expanded_cache_dir / model_folder / " snapshots" / commit / file;
206+ if (std::filesystem::exists (snapshot_file_path)) {
207+ return snapshot_file_path.string ();
208+ } else {
209+ return " " ; // File does not exist
210+ }
211+ }
212+
213+ std::variant<std::string, CURLcode> get_model_commit (const std::string &repo) {
214+ CURL *curl = curl_easy_init ();
215+ if (!curl) {
216+ return CURLE_FAILED_INIT;
217+ }
218+
219+ std::string url =
220+ " https://huggingface.co/api/models/" + repo + " /revision/main" ;
221+ std::string response;
222+
223+ curl_easy_setopt (curl, CURLOPT_URL, url.c_str ());
224+ curl_easy_setopt (curl, CURLOPT_HTTPHEADER, NULL );
225+ curl_easy_setopt (curl, CURLOPT_WRITEFUNCTION, write_string_data);
226+ curl_easy_setopt (curl, CURLOPT_WRITEDATA, &response);
227+ curl_easy_setopt (curl, CURLOPT_FOLLOWLOCATION, 1L );
228+ curl_easy_setopt (curl, CURLOPT_FAILONERROR, 1L );
229+ curl_easy_setopt (curl, CURLOPT_HEADER, 0L );
230+
231+ CURLcode res = curl_easy_perform (curl);
232+ curl_easy_cleanup (curl);
233+ if (res != CURLE_OK) {
234+ return res;
235+ }
236+
237+ std::smatch match;
238+ std::regex pattern (" \" sha\"\\ s*:\\ s*\" ([a-fA-F0-9]{40})\" " );
239+
240+ if (std::regex_search (response, match, pattern) && match.size () > 1 ) {
241+ return match[1 ];
242+ } else {
243+ return std::string (); // Return empty string if not found
244+ }
245+ }
246+
174247std::variant<struct FileMetadata , CURLcode>
175248get_model_metadata_from_hf (const std::string &repo, const std::string &file) {
176249 CURL *curl = curl_easy_init ();
@@ -205,6 +278,10 @@ get_model_metadata_from_hf(const std::string &repo, const std::string &file) {
205278 return res;
206279 }
207280
281+ if (response.empty () || response == " []" ) {
282+ return CURLE_REMOTE_FILE_NOT_FOUND;
283+ }
284+
208285 return extract_metadata (response);
209286}
210287
@@ -287,7 +364,6 @@ CURLcode perform_download(std::string url,
287364 std::ios::binary | std::ios::app);
288365
289366 if (!file.is_open ()) {
290- log_error (" Error: failed to open file stream!" );
291367 return CURLE_FAILED_INIT;
292368 }
293369
@@ -328,51 +404,70 @@ struct DownloadResult hf_hub_download(const std::string &repo_id,
328404 struct DownloadResult result;
329405 result.success = true ;
330406
331- // 1. Check that model exists on Hugging Face
332- auto metadata_result = get_model_metadata_from_hf (repo_id, filename);
333- if (std::holds_alternative<CURLcode>(metadata_result)) {
334- CURLcode err = std::get<CURLcode>(metadata_result);
407+ log_info (" Downloading " + filename + " from " + repo_id);
335408
336- std::string refs_main_path = " models/" + repo_id;
337- size_t pos = 0 ;
338- while ((pos = refs_main_path.find (" /" , pos)) != std::string::npos) {
339- refs_main_path.replace (pos, 1 , " --" );
340- pos += 2 ;
409+ // Check repo (accessibility and version)
410+ auto commit_result = get_model_commit (repo_id);
411+
412+ if (std::holds_alternative<CURLcode>(commit_result)) {
413+ CURLcode err = std::get<CURLcode>(commit_result);
414+
415+ std::string file_path = get_file_path (cache_dir, repo_id, filename);
416+ if (!file_path.empty ()) {
417+ log_info (" Using cached file." );
418+ result.path = file_path;
419+ result.success = true ;
420+ return result;
341421 }
342422
343- std::filesystem::path cache_model_dir =
344- expand_user_home (" ~/.cache/huggingface/hub/" + refs_main_path + " /" );
345- std::filesystem::path refs_file_path = cache_model_dir / " refs/main" ;
346-
347- if (std::filesystem::exists (refs_file_path)) {
348- std::ifstream refs_file (refs_file_path);
349- std::string commit;
350- refs_file >> commit;
351- refs_file.close ();
352-
353- std::filesystem::path snapshot_file_path =
354- cache_model_dir / " snapshots" / commit / filename;
355- if (std::filesystem::exists (snapshot_file_path)) {
356- log_info (" Snapshot file exists. Skipping download..." );
357- result.success = true ;
358- result.path = snapshot_file_path;
359- return result;
360- }
423+ std::string model_path = get_model_repo_path (repo_id);
424+ std::string snapshot_path =
425+ expand_user_home (cache_dir + " /" + model_path + " /snapshots" );
426+ if (!std::filesystem::exists (snapshot_path)) {
427+ log_info (snapshot_path);
428+ log_error (" Repo not found (locally nor online): " + repo_id);
429+ result.success = false ;
430+ return result;
361431 }
362432
363- log_error (" CURL metadata request failed: " +
433+ std::string outdated_file = find_outdated_file (snapshot_path, filename);
434+ if (!outdated_file.empty ()) {
435+ log_info (" Using outdated cached file " + outdated_file);
436+ result.path = outdated_file;
437+ result.success = true ;
438+ return result;
439+ }
440+
441+ log_error (" Error getting model: " + std::string (curl_easy_strerror (err)));
442+ result.success = false ;
443+ return result;
444+ }
445+
446+ std::string latest_commit = std::get<std::string>(commit_result);
447+ if (latest_commit.empty ()) {
448+ log_error (" Failed to retrieve the latest commit for repository: " +
449+ repo_id);
450+ result.success = false ;
451+ return result;
452+ }
453+
454+ // Check file accessibility
455+ auto metadata_result = get_model_metadata_from_hf (repo_id, filename);
456+
457+ if (std::holds_alternative<CURLcode>(metadata_result)) {
458+ CURLcode err = std::get<CURLcode>(metadata_result);
459+ log_error (" Error getting metadata: " +
364460 std::string (curl_easy_strerror (err)));
365461 result.success = false ;
366462 return result;
367463 }
368464
369- // 2. Create Cache Dir Struct
465+ // Create Cache Dir Struct
370466 std::string cache_model_dir = create_cache_system (cache_dir, repo_id);
371467 log_debug (" Cache directory: " + cache_model_dir);
372- log_info (" Downloading " + filename + " from " + repo_id);
373468
374469 struct FileMetadata metadata = std::get<struct FileMetadata >(metadata_result);
375- log_debug (" Commit: " + metadata. commit );
470+ log_debug (" Commit: " + latest_commit );
376471 log_debug (" Blob ID: " + metadata.oid );
377472 log_debug (" Size: " + std::to_string (metadata.size ) + " bytes" );
378473 log_debug (" SHA256: " + metadata.sha256 );
@@ -391,29 +486,22 @@ struct DownloadResult hf_hub_download(const std::string &repo_id,
391486 }
392487
393488 std::filesystem::path snapshot_file_path (cache_model_dir + " snapshots/" +
394- metadata. commit + " /" + filename);
489+ latest_commit + " /" + filename);
395490 std::filesystem::path refs_file_path (cache_model_dir + " refs/main" );
396491
397492 result.path = snapshot_file_path;
398493
494+ std::ofstream refs_file (refs_file_path);
495+ refs_file << latest_commit << std::endl;
496+ refs_file.close ();
497+
399498 if (std::filesystem::exists (snapshot_file_path) &&
400499 std::filesystem::exists (blob_file_path) && !force_download) {
401- log_info (" Snapshot file exists. Skipping download.. ." );
500+ log_info (" Snapshot file exists. Using cached file ." );
402501 return result;
403502 }
404503
405- if (std::filesystem::exists (refs_file_path)) {
406- std::ifstream refs_file (refs_file_path);
407- std::string commit;
408- refs_file >> commit;
409- refs_file.close ();
410- } else {
411- std::ofstream refs_file (refs_file_path);
412- refs_file << metadata.commit ;
413- refs_file.close ();
414- }
415-
416- // 3. Download the file
504+ // Download the file
417505 std::string url =
418506 " https://huggingface.co/" + repo_id + " /resolve/main/" + filename;
419507 std::filesystem::create_directories (snapshot_file_path.parent_path ());
0 commit comments