Skip to content

Commit cfd77a3

Browse files
authored
Bug fix on commit change (#3)
hf_hub_download returns a non existing file path if a file has been previous download and now the repo is on a new commit
1 parent bc92a03 commit cfd77a3

File tree

2 files changed

+79
-66
lines changed

2 files changed

+79
-66
lines changed

include/huggingface_hub.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ get_model_metadata_from_hf(const std::string &repo, const std::string &file);
9393
struct DownloadResult
9494
hf_hub_download(const std::string &repo_id, const std::string &filename,
9595
const std::string &cache_dir = "~/.cache/huggingface/hub",
96-
bool force_download = false);
96+
bool force_download = false, bool verbose = false);
9797

9898
/**
9999
* @brief Download a file from Hugging Face Hub.

src/huggingface_hub.cpp

Lines changed: 78 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,16 @@ namespace huggingface_hub {
4141
volatile sig_atomic_t stop_download = 0;
4242
void handle_sigint(int) { stop_download = 1; }
4343

44+
bool log_verbose = false;
45+
46+
void log_debug(const std::string &message) {
47+
if (!log_verbose) {
48+
return;
49+
}
50+
fprintf(stderr, "[DEBUG] %s\n", message.c_str());
51+
fflush(stderr);
52+
}
53+
4454
void log_info(const std::string &message) {
4555
fprintf(stderr, "[INFO] %s\n", message.c_str());
4656
fflush(stderr);
@@ -246,11 +256,55 @@ int progress_callback(void *userdata, curl_off_t total, curl_off_t now,
246256
return 0; // Continue downloading
247257
}
248258

259+
CURLcode perform_download(std::string url,
260+
std::string blob_incomplete_file_path,
261+
bool force_download, struct FileMetadata metadata) {
262+
CURL *curl = curl_easy_init();
263+
if (!curl) {
264+
return CURLE_FAILED_INIT;
265+
}
266+
267+
std::ofstream file(blob_incomplete_file_path,
268+
std::ios::binary | std::ios::app);
269+
270+
if (!file.is_open()) {
271+
log_error("Error: failed to open file stream!");
272+
return CURLE_FAILED_INIT;
273+
}
274+
275+
curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); // Set URL
276+
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); // Follow redirects
277+
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION,
278+
write_file_data); // Write data to file
279+
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &file); // File stream
280+
curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); // Enable progress callback
281+
curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION,
282+
progress_callback); // Progress callback
283+
curl_easy_setopt(curl, CURLOPT_PROGRESSDATA, &metadata);
284+
285+
// Resume download if file exists
286+
long existing_size = get_file_size(blob_incomplete_file_path);
287+
if (existing_size > 0 && !force_download) {
288+
curl_easy_setopt(curl, CURLOPT_RESUME_FROM_LARGE,
289+
(curl_off_t)existing_size);
290+
log_info("Resuming download from " + std::to_string(existing_size) +
291+
" bytes...");
292+
}
293+
294+
fprintf(stderr, "\n"); // New line after progress bar
295+
CURLcode res = curl_easy_perform(curl);
296+
fprintf(stderr, "\n"); // New line after progress bar
297+
curl_easy_cleanup(curl);
298+
file.close();
299+
return res;
300+
}
301+
249302
struct DownloadResult hf_hub_download(const std::string &repo_id,
250303
const std::string &filename,
251304
const std::string &cache_dir,
252-
bool force_download) {
305+
bool force_download, bool verbose) {
253306
signal(SIGINT, handle_sigint);
307+
log_verbose = verbose;
254308

255309
struct DownloadResult result;
256310
result.success = true;
@@ -265,13 +319,13 @@ struct DownloadResult hf_hub_download(const std::string &repo_id,
265319

266320
// 2. Create Cache Dir Struct
267321
std::string cache_model_dir = create_cache_system(cache_dir, repo_id);
268-
log_info("Cache directory: " + cache_model_dir);
322+
log_debug("Cache directory: " + cache_model_dir);
269323
log_info("Downloading " + filename + " from " + repo_id);
270324

271325
struct FileMetadata metadata = std::get<struct FileMetadata>(metadata_result);
272-
log_info("SHA256: " + metadata.sha256);
273-
log_info("Commit: " + metadata.commit);
274-
log_info("Size: " + std::to_string(metadata.size) + " bytes");
326+
log_debug("SHA256: " + metadata.sha256);
327+
log_debug("Commit: " + metadata.commit);
328+
log_debug("Size: " + std::to_string(metadata.size) + " bytes");
275329

276330
std::filesystem::path blob_file_path(cache_model_dir + "blobs/" +
277331
metadata.sha256);
@@ -283,8 +337,9 @@ struct DownloadResult hf_hub_download(const std::string &repo_id,
283337

284338
result.path = snapshot_file_path;
285339

286-
if (std::filesystem::exists(blob_file_path) && !force_download) {
287-
log_info("Blob file exists. Skipping download...");
340+
if (std::filesystem::exists(snapshot_file_path) &&
341+
std::filesystem::exists(blob_file_path) && !force_download) {
342+
log_info("Snapshot file exists. Skipping download...");
288343
return result;
289344
}
290345

@@ -300,77 +355,35 @@ struct DownloadResult hf_hub_download(const std::string &repo_id,
300355
}
301356

302357
// 3. Download the file
303-
CURL *curl = curl_easy_init();
304-
if (!curl) {
305-
result.success = false;
306-
return result;
307-
}
308-
309358
std::string url =
310359
"https://huggingface.co/" + repo_id + "/resolve/main/" + filename;
311-
312-
std::ofstream file(blob_incomplete_file_path,
313-
std::ios::binary | std::ios::app);
314-
315-
if (!file.is_open()) {
316-
log_error("Failed to open file: " + blob_incomplete_file_path.string());
317-
curl_easy_cleanup(curl);
318-
result.success = false;
319-
return result;
320-
}
321-
322-
curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); // Set URL
323-
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); // Follow redirects
324-
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION,
325-
write_file_data); // Write data to file
326-
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &file); // File stream
327-
curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); // Enable progress callback
328-
curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION,
329-
progress_callback); // Progress callback
330-
curl_easy_setopt(curl, CURLOPT_PROGRESSDATA, &metadata);
331-
332-
// Resume download if file exists
333-
long existing_size = get_file_size(blob_incomplete_file_path);
334-
if (existing_size > 0 && !force_download) {
335-
curl_easy_setopt(curl, CURLOPT_RESUME_FROM_LARGE,
336-
(curl_off_t)existing_size);
337-
log_info("Resuming download from " + std::to_string(existing_size) +
338-
" bytes...");
339-
}
340-
341-
fprintf(stderr, "\n");
342-
CURLcode res = curl_easy_perform(curl);
343-
fprintf(stderr, "\n");
344-
345360
std::filesystem::create_directories(snapshot_file_path.parent_path());
346361

347-
if (stop_download) {
348-
log_info("Download interrupted. Exiting...");
349-
file.close();
350-
curl_easy_cleanup(curl);
351-
result.success = false;
352-
return result;
353-
} else if (res != CURLE_OK) {
354-
log_error("CURL request failed: " + std::string(curl_easy_strerror(res)));
355-
file.close();
356-
curl_easy_cleanup(curl);
357-
result.success = false;
358-
return result;
362+
if (!std::filesystem::exists(blob_file_path) || force_download) {
363+
CURLcode res = perform_download(url, blob_incomplete_file_path,
364+
force_download, metadata);
365+
result.success = res == CURLE_OK;
366+
367+
if (stop_download) {
368+
log_info("Download interrupted. Exiting...");
369+
return result;
370+
} else if (!result.success) {
371+
log_error("CURL request failed: " + std::string(curl_easy_strerror(res)));
372+
return result;
373+
} else {
374+
std::filesystem::rename(blob_incomplete_file_path, blob_file_path);
375+
}
359376
}
360377

361378
if (std::filesystem::exists(snapshot_file_path)) {
362-
log_info("Snapshot file exists. Deleting...");
379+
log_debug("Snapshot file exists. Deleting...");
363380
std::filesystem::remove(snapshot_file_path);
364381
}
365-
std::filesystem::rename(blob_incomplete_file_path, blob_file_path);
366382
std::filesystem::create_symlink(blob_file_path, snapshot_file_path);
367383

368-
file.close();
369-
curl_easy_cleanup(curl);
370-
371384
log_info("Downloaded to: " + snapshot_file_path.string());
372385

373-
result.success = res == CURLE_OK;
386+
result.success = true;
374387
return result;
375388
}
376389

0 commit comments

Comments
 (0)