Skip to content

Commit c6821fa

Browse files
committed
removing download_shards param
1 parent a2a5121 commit c6821fa

File tree

2 files changed

+22
-28
lines changed

2 files changed

+22
-28
lines changed

include/huggingface_hub.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,15 +107,13 @@ hf_hub_download(const std::string &repo_id, const std::string &filename,
107107
* "~/.cache/huggingface/hub".
108108
* @param force_download If true, forces the download even if the file already
109109
* exists in the cache.
110-
* @param download_shards If true, download model shards using regex to find
111-
* the amount of shards the model is splitted into.
112110
* @return A DownloadResult structure containing the success status and the path
113111
* of the downloaded file.
114112
*/
115113
struct DownloadResult hf_hub_download_with_shards(
116114
const std::string &repo_id, const std::string &filename,
117115
const std::string &cache_dir = "~/.cache/huggingface/hub",
118-
bool force_download = false, bool download_shards = true);
116+
bool force_download = false);
119117

120118
#endif // HUGGINGFACE_HUB_H
121119
} // namespace huggingface_hub

src/huggingface_hub.cpp

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -377,37 +377,33 @@ struct DownloadResult hf_hub_download(const std::string &repo_id,
377377
struct DownloadResult hf_hub_download_with_shards(const std::string &repo_id,
378378
const std::string &filename,
379379
const std::string &cache_dir,
380-
bool force_download,
381-
bool download_shards) {
380+
bool force_download) {
382381

383-
if (download_shards) {
382+
std::regex pattern(R"(-([0-9]+)-of-([0-9]+)\.gguf)");
383+
std::smatch match;
384384

385-
std::regex pattern(R"(-([0-9]+)-of-([0-9]+)\.gguf)");
386-
std::smatch match;
385+
if (std::regex_search(filename, match, pattern)) {
386+
int total_shards = std::stoi(match[2]);
387+
std::string base_name = filename.substr(0, match.position(0));
387388

388-
if (std::regex_search(filename, match, pattern)) {
389-
int total_shards = std::stoi(match[2]);
390-
std::string base_name = filename.substr(0, match.position(0));
389+
// Download shards
390+
for (int i = 1; i <= total_shards; ++i) {
391+
char shard_file[256];
392+
snprintf(shard_file, sizeof(shard_file), "%s-%05d-of-%05d.gguf",
393+
base_name.c_str(), i, total_shards);
394+
auto aux_res =
395+
hf_hub_download(repo_id, shard_file, cache_dir, force_download);
391396

392-
// Download shards
393-
for (int i = 1; i <= total_shards; ++i) {
394-
char shard_file[256];
395-
snprintf(shard_file, sizeof(shard_file), "%s-%05d-of-%05d.gguf",
396-
base_name.c_str(), i, total_shards);
397-
auto aux_res =
398-
hf_hub_download(repo_id, shard_file, cache_dir, force_download);
399-
400-
if (!aux_res.success) {
401-
return aux_res;
402-
}
397+
if (!aux_res.success) {
398+
return aux_res;
403399
}
404-
405-
// Return first shard
406-
char first_shard[256];
407-
snprintf(first_shard, sizeof(first_shard), "%s-00001-of-%05d.gguf",
408-
base_name.c_str(), total_shards);
409-
return hf_hub_download(repo_id, first_shard, cache_dir, false);
410400
}
401+
402+
// Return first shard
403+
char first_shard[256];
404+
snprintf(first_shard, sizeof(first_shard), "%s-00001-of-%05d.gguf",
405+
base_name.c_str(), total_shards);
406+
return hf_hub_download(repo_id, first_shard, cache_dir, false);
411407
}
412408

413409
return hf_hub_download(repo_id, filename, cache_dir, force_download);

0 commit comments

Comments
 (0)