
Commit 65b536f

Merge branch 'master' into xsn/oai_moe
2 parents 44bdb75 + 03d4698 commit 65b536f

39 files changed: +1,321 / -321 lines
Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+name: Check Pre-Tokenizer Hashes
+
+on:
+  push:
+    paths:
+      - 'convert_hf_to_gguf.py'
+      - 'convert_hf_to_gguf_update.py'
+  pull_request:
+    paths:
+      - 'convert_hf_to_gguf.py'
+      - 'convert_hf_to_gguf_update.py'
+
+jobs:
+  pre-tokenizer-hashes:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+
+      - name: Install Python dependencies
+        run: |
+          python3 -m venv .venv
+          .venv/bin/pip install -r requirements/requirements-convert_hf_to_gguf_update.txt
+
+      - name: Update pre-tokenizer hashes
+        run: |
+          cp convert_hf_to_gguf.py /tmp
+          .venv/bin/python convert_hf_to_gguf_update.py --check-missing
+
+      - name: Check if committed pre-tokenizer hashes matches generated version
+        run: |
+          if ! diff -q convert_hf_to_gguf.py /tmp/convert_hf_to_gguf.py; then
+            echo "Model pre-tokenizer hashes (in convert_hf_to_gguf.py) do not match generated hashes (from convert_hf_to_gguf_update.py)."
+            echo "To fix: run ./convert_hf_to_gguf_update.py and commit the updated convert_hf_to_gguf.py along with your changes"
+            echo "Differences found:"
+            diff convert_hf_to_gguf.py /tmp/convert_hf_to_gguf.py || true
+            exit 1
+          fi
+          echo "Model pre-tokenizer hashes are up to date."

common/chat.cpp

Lines changed: 3 additions & 5 deletions
@@ -1667,7 +1667,7 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
         "|<function name=\"([^\"]+)\">" // match 5 (function name again)
     );
 
-    if (auto res = builder.try_find_regex(open_regex)) {
+    while (auto res = builder.try_find_regex(open_regex)) {
         const auto & block_start = res->groups[1];
         std::string block_end = block_start.empty() ? "" : "```";
 
@@ -1689,7 +1689,6 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
                 builder.consume_literal(block_end);
                 builder.consume_spaces();
             }
-            builder.add_content(builder.consume_rest());
         } else {
            throw common_chat_msg_partial_exception("failed to parse tool call");
        }
@@ -1714,11 +1713,10 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
                builder.consume_spaces();
            }
        }
-        builder.add_content(builder.consume_rest());
     }
-    } else {
-        builder.add_content(builder.consume_rest());
     }
+
+    builder.add_content(builder.consume_rest());
 }
 
 static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
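The change above turns the top-level `if` into a `while` and hoists the final consume_rest out of the individual branches, so a message containing several consecutive tool-call blocks is parsed completely instead of stopping after the first one. A rough Python sketch of the same control-flow shape; the regex and helper names here are illustrative stand-ins, not the real parser:

import re

# Hypothetical stand-in for the C++ parser: find every tool-call block in a loop,
# then treat whatever follows the last one as plain content (a single "consume_rest").
TOOL_CALL = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)

def parse_message(text: str) -> tuple[list[str], str]:
    tool_calls = []
    pos = 0
    while (m := TOOL_CALL.search(text, pos)):  # with a plain `if`, only the first call would be parsed
        tool_calls.append(m.group(1).strip())
        pos = m.end()
    content = text[pos:]                       # one trailing consume_rest, as in the new code
    return tool_calls, content

calls, rest = parse_message('<tool_call>{"name":"a"}</tool_call><tool_call>{"name":"b"}</tool_call> done')
assert len(calls) == 2 and rest.strip() == "done"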

convert_hf_to_gguf.py

Lines changed: 98 additions & 8 deletions
@@ -684,6 +684,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664":
             # ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
             res = "hunyuan"
+        if chkhsh == "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6":
+            # ref: https://huggingface.co/tencent/Hunyuan-4B-Instruct
+            res = "hunyuan-dense"
         if chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6":
             # ref: https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base
             res = "falcon-h1"
@@ -699,6 +702,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890":
             # ref: https://huggingface.co/moonshotai/Kimi-K2-Base
             res = "kimi-k2"
+        if chkhsh == "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c":
+            # ref: https://huggingface.co/Qwen/Qwen3-Embedding-0.6B
+            res = "qwen2"
         if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
             # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
             res = "llama-bpe"
@@ -7553,11 +7559,6 @@ def set_gguf_parameters(self):
 class HunYuanMoEModel(TextModel):
     model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        # For handling tied embeddings
-        self._tok_embd = None
-
     def set_vocab(self):
         from transformers import AutoTokenizer
         tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
@@ -7651,9 +7652,6 @@ def set_gguf_parameters(self):
     _experts: list[dict[str, Tensor]] | None = None
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        if name == "model.embed_tokens.weight":
-            self._tok_embd = data_torch.clone()
-
         if name == "lm_head.weight":
             if self.hparams.get("tie_word_embeddings", False):
                 logger.info("Skipping tied output layer 'lm_head.weight'")
@@ -7698,6 +7696,98 @@ def prepare_tensors(self):
             raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@ModelBase.register("HunYuanDenseV1ForCausalLM")
+class HunYuanModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.HUNYUAN_DENSE
+
+    def set_vocab(self):
+        if (self.dir_model / "tokenizer.json").is_file():
+            self._set_vocab_gpt2()
+        else:
+            from transformers import AutoTokenizer
+            tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
+
+            # 1. Get the pre-tokenizer identifier hash
+            tokpre = self.get_vocab_base_pre(tokenizer)
+
+            # 2. Reverse-engineer the merges list from mergeable_ranks
+            merges = []
+            vocab = {}
+            mergeable_ranks = tokenizer.mergeable_ranks
+            for token, rank in mergeable_ranks.items():
+                vocab[QwenModel.token_bytes_to_string(token)] = rank
+                if len(token) == 1:
+                    continue
+                merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
+                if len(merged) == 2:
+                    merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))
+
+            # 3. Generate the tokens and toktypes lists
+            vocab_size = self.hparams["vocab_size"]
+            assert tokenizer.vocab_size == vocab_size
+            special_tokens = tokenizer.special_tokens
+            reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
+            tokens: list[str] = []
+            toktypes: list[int] = []
+            for i in range(vocab_size):
+                if i not in reverse_vocab:
+                    tokens.append(f"[PAD{i}]")
+                    toktypes.append(gguf.TokenType.UNUSED)
+                else:
+                    token = reverse_vocab[i]
+                    tokens.append(token)
+                    if i in special_tokens.values():
+                        toktypes.append(gguf.TokenType.CONTROL)
+                    else:
+                        toktypes.append(gguf.TokenType.NORMAL)
+
+            # 4. Write all vocab-related fields to the GGUF writer
+            self.gguf_writer.add_tokenizer_model("gpt2")
+            self.gguf_writer.add_tokenizer_pre(tokpre)
+            self.gguf_writer.add_token_list(tokens)
+            self.gguf_writer.add_token_types(toktypes)
+            self.gguf_writer.add_token_merges(merges)
+
+            # 5. Add special tokens and chat templates
+            special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
+            special_vocab.add_to_gguf(self.gguf_writer)
+        # FIX for BOS token: Overwrite incorrect id read from config.json
+        if self.hparams['hidden_size'] == 4096:
+            self.gguf_writer.add_bos_token_id(127958) # only for 7b dense, fix <|bos|> token
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+
+        # Rope
+        rope_scaling = hparams.get("rope_scaling", {})
+        if rope_scaling.get("type") == "dynamic":
+            # HunYuan uses NTK Aware Alpha based scaling. Original implementation: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+            # 1000 corresponds to a usable context length of 256k (https://github.com/Tencent-Hunyuan/Hunyuan-A13B/blob/main/report/Hunyuan_A13B_Technical_Report.pdf)
+            alpha = rope_scaling.get("alpha", 50)
+            base = hparams.get("rope_theta", 10000.0)
+            dim = hparams["head_dim"]
+            scaled_base = base * (alpha ** (dim / (dim - 2)))
+            self.gguf_writer.add_rope_freq_base(scaled_base)
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+            self.gguf_writer.add_rope_scaling_factor(1)
+            # There is no consistent way to calculate ctx from alpha, and the config is incorrectly set to 32k
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024) # 256k context length
+            self.gguf_writer.add_context_length(256 * 1024) # 256k context length
+
+            # if any of our assumptions about the values are wrong, something has changed and this may need to be updated
+            assert base == 10000.0 and self.hparams["max_position_embeddings"] in [32 * 1024, 256 * 1024], \
+                "HunYuan dynamic RoPE scaling assumptions changed, please update the logic or context length manually"
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name == "lm_head.weight":
+            if self.hparams.get("tie_word_embeddings", False):
+                logger.info("Skipping tied output layer 'lm_head.weight'")
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 @ModelBase.register("SmolLM3ForCausalLM")
 class SmolLM3Model(LlamaModel):
     model_arch = gguf.MODEL_ARCH.SMOLLM3
convert_hf_to_gguf_update.py

Lines changed: 17 additions & 6 deletions
@@ -59,6 +59,10 @@ class TOKENIZER_TYPE(IntEnum):
     "--full", action="store_true",
     help="download full list of models - make sure you have access to all of them",
 )
+parser.add_argument(
+    "--check-missing", action="store_true",
+    help="only check for missing pre-tokenizer hashes",
+)
 parser.add_argument(
     "hf_token",
     help="optional HF token",
@@ -70,6 +74,10 @@ class TOKENIZER_TYPE(IntEnum):
 if hf_token is None:
     logger.warning("HF token not found. You can provide it as an argument or set it in ~/.cache/huggingface/token")
 
+if args.check_missing and args.full:
+    logger.warning("Downloading full list of models requested, ignoring --check-missing!")
+    args.check_missing = False
+
 # TODO: this string has to exercise as much pre-tokenizer functionality as possible
 # will be updated with time - contributions welcome
 CHK_TXT = '\n \n\n \n\n\n \t \t\t \t\n \n \n \n \n🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български \'\'\'\'\'\'```````\"\"\"\"......!!!!!!?????? I\'ve been \'told he\'s there, \'RE you sure? \'M not sure I\'ll make it, \'D you like some tea? We\'Ve a\'lL'
@@ -140,12 +148,14 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
     {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
     {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
+    {"name": "hunyuan-dense", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-4B-Instruct", "chkhsh": "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6"},
     # falcon-h1 series uses 4 different tokenizers across model sizes (0.5b - 34b), hence we need to define 4 different hashes
     {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base", "chkhsh": "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6"},
     {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-1B-Base", "chkhsh": "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86"},
     {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-7B-Base", "chkhsh": "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896"},
     {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"},
     {"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890"},
+    {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3-Embedding-0.6B", "chkhsh": "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c"},
 ]
 
 
@@ -220,12 +230,13 @@ def get_existing_models(convert_py):
 all_models = models.copy()
 models = [model for model in all_models if model["name"] not in existing_models]
 
-logging.info(f"Downloading {len(models)} models...")
-for model in models:
-    try:
-        download_model(model)
-    except Exception as e:
-        logger.error(f"Failed to download model {model['name']}. Error: {e}")
+if not args.check_missing:
+    logging.info(f"Downloading {len(models)} models...")
+    for model in models:
+        try:
+            download_model(model)
+        except Exception as e:
+            logger.error(f"Failed to download model {model['name']}. Error: {e}")
 
 
 # generate the source code for the convert_hf_to_gguf.py:get_vocab_base_pre() function:
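The source-generation step above is what makes the CI check cheap: the chkhsh values are stored directly in the models table, so the if-chain in get_vocab_base_pre() can be regenerated without downloading anything when --check-missing is set. A minimal sketch of that generation idea, simplified and not the real script's exact output format; the two entries are copied from the table in this diff:

# Simplified regeneration of the get_vocab_base_pre() body from the models table.
models = [
    {"name": "hunyuan-dense", "repo": "https://huggingface.co/tencent/Hunyuan-4B-Instruct",
     "chkhsh": "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6"},
    {"name": "qwen2", "repo": "https://huggingface.co/Qwen/Qwen3-Embedding-0.6B",
     "chkhsh": "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c"},
]

src = ""
for model in models:
    src += f'        if chkhsh == "{model["chkhsh"]}":\n'
    src += f'            # ref: {model["repo"]}\n'
    src += f'            res = "{model["name"]}"\n'
print(src)  # a block like this is spliced into convert_hf_to_gguf.py and diffed by the CI job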

ggml/src/ggml-cuda/fattn.cu

Lines changed: 3 additions & 2 deletions
@@ -326,8 +326,9 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
     const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
-    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies &&
-        (Q->ne[3] > 1 || cc < GGML_CUDA_CC_ADA_LOVELACE) && !mma_needs_data_conversion;
+    const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192);
+    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
+        (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
     if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
         if (prec == GGML_PREC_DEFAULT) {
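Read with ggml's tensor layout in mind (Q->ne[1] is the number of tokens, Q->ne[2] and K->ne[2] the query and KV head counts, Q->ne[3] the number of sequences, K->ne[1] the KV-cache length), the new condition keeps the MMA path for single-token decoding on Ada-class GPUs only when there are multiple sequences, or when the GQA ratio is high and the KV cache is long. A small Python restatement of just that predicate, with made-up head counts and cache lengths for illustration:

# Hedged restatement of the new mma_faster_for_rtx4000 condition; the example
# numbers below are assumptions, not taken from a specific model.
def mma_faster_for_rtx4000(n_seq: int, n_head_q: int, n_head_kv: int, n_kv: int) -> bool:
    # Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192)
    return n_seq > 1 or (n_head_q > 4 * n_head_kv and n_kv >= 8192)

print(mma_faster_for_rtx4000(n_seq=1, n_head_q=64, n_head_kv=8, n_kv=16384))  # True
print(mma_faster_for_rtx4000(n_seq=1, n_head_q=32, n_head_kv=8, n_kv=16384))  # False, GQA ratio is only 4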

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 13 additions & 4 deletions
@@ -1852,6 +1852,9 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
     ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
     ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
 
+    bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
+    bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
+
     // Handle src0
     src0_ptr = (const cuda_t *) src0->data;
 
@@ -1870,6 +1873,8 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
         s11 = ne10;
         s12 = ne11*s11;
         s13 = ne12*s12;
+
+        is_src1_cont_2 = true;
     }
 
     // Setup destination buffer
@@ -1918,15 +1923,19 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
-    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+    if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
+        // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
+        const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
+        const int64_t smb = ne12 == 1 ? s13       : s12;
+
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
-                       src1_ptr, cu_data_type_b, s11,       s12,       // strideB
-                beta,     dst_t, cu_data_type,   ne0,       ne1*ne0,   // strideC
+                alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma,     // strideA
+                       src1_ptr, cu_data_type_b, s11,       smb,     // strideB
+                beta,     dst_t, cu_data_type,   ne0,       ne1*ne0, // strideC
                 ne12*ne13,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
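The new sma/smb selection matters for permuted inputs: after a [0, 2, 1, 3] permutation with ne02 == 1, nb02 no longer measures the distance between the matrices that are actually batched along dim 3, so the batch stride has to come from nb03. A small Python toy with made-up sizes (strides in elements, so nb00 == 1) showing the difference:

# Toy illustration of the sma fix; all sizes are made up for illustration.
# Parent tensor: contiguous, ggml dims ne = [4, 1, 3, 2].
ne = [4, 1, 3, 2]
nb = [1, ne[0], ne[0] * ne[1], ne[0] * ne[1] * ne[2]]    # [1, 4, 4, 12]

# A [0, 2, 1, 3] permutation swaps dims 1 and 2 of the view:
perm_ne = [ne[0], ne[2], ne[1], ne[3]]                   # [4, 3, 1, 2] -> ne02 == 1
perm_nb = [nb[0], nb[2], nb[1], nb[3]]                   # [1, 4, 4, 12]

stride_old = perm_nb[2] // perm_nb[0]                    # 4: the old nb02/nb00, wrong for this layout
stride_new = perm_nb[3] // perm_nb[0] if perm_ne[2] == 1 else stride_old  # 12: spacing of the two 4x3 matrices
print(stride_old, stride_new)                            # 4 12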
