
Commit 1d2a1ab

model : support Rnj-1 (ggml-org#17811)
* add support for rnj1
* refactor gemma3 to support rnj-1
* address review comments
1 parent c8554b6 commit 1d2a1ab


5 files changed: +76 −24 lines


convert_hf_to_gguf.py

Lines changed: 25 additions & 10 deletions
@@ -5825,9 +5825,11 @@ class Gemma3Model(TextModel):
     norm_shift = 1.0 # Gemma3RMSNorm adds 1.0 to the norm value

     def set_vocab(self):
-        self._set_vocab_sentencepiece()
-
-        self.gguf_writer.add_add_space_prefix(False)
+        if (self.dir_model / "tokenizer.model").is_file():
+            self._set_vocab_sentencepiece()
+            self.gguf_writer.add_add_space_prefix(False)
+        else:
+            self._set_vocab_gpt2()

     def set_gguf_parameters(self):
         hparams = self.hparams
@@ -5845,13 +5847,24 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers
         # attn_logit_softcapping is removed in Gemma3
         assert hparams.get("attn_logit_softcapping") is None
-        self.gguf_writer.add_sliding_window(hparams["sliding_window"])
+        if (final_logit_softcap := hparams.get("final_logit_softcapping")):
+            self.gguf_writer.add_final_logit_softcapping(final_logit_softcap)
+        if hparams.get("sliding_window_pattern") != 1:
+            self.gguf_writer.add_sliding_window(hparams["sliding_window"])
         self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
         if hparams.get("rope_scaling") is not None:
-            assert hparams["rope_scaling"]["rope_type"] == "linear"
-            # important: this rope_scaling is only applied for global layers, and not used by 1B model
-            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
-            self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])
+            rope_scaling = hparams["rope_scaling"]
+            if rope_scaling["rope_type"] == "linear":
+                # important: this rope_scaling is only applied for global layers, and not used by 1B model
+                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+                self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+            elif rope_scaling["rope_type"] == "yarn":
+                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+                self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+                self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
+                self.gguf_writer.add_rope_scaling_yarn_ext_factor(rope_scaling["extrapolation_factor"])
+                self.gguf_writer.add_rope_scaling_yarn_beta_fast(rope_scaling["beta_fast"])
+                self.gguf_writer.add_rope_scaling_yarn_beta_slow(rope_scaling["beta_slow"])

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid # unused
@@ -5865,8 +5878,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

         # remove OOV (out-of-vocabulary) rows in token_embd
         if "embed_tokens.weight" in name:
-            vocab = self._create_vocab_sentencepiece()
-            tokens = vocab[0]
+            if (self.dir_model / "tokenizer.model").is_file():
+                tokens = self._create_vocab_sentencepiece()[0]
+            else:
+                tokens = self.get_vocab_base()[0]
             data_torch = data_torch[:len(tokens)]

         # ref code in Gemma3RMSNorm
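
For a checkpoint that ships without a tokenizer.model (the Rnj-1 case), the converter now falls back to the GPT-2/BPE vocab, and the new YARN branch expects the full set of YARN fields under rope_scaling in config.json. Below is a minimal pre-flight sketch of those two checks, assuming a hypothetical local model directory (the path models/rnj-1 is made up; the field names mirror the keys read above, and none of this is part of the converter itself):

```python
import json
from pathlib import Path

# Hypothetical checkpoint location; substitute the model actually being converted.
dir_model = Path("models/rnj-1")

# Mirrors the new set_vocab() fallback: SentencePiece when tokenizer.model is
# present, otherwise the GPT-2/BPE tokenizer shipped as tokenizer.json.
vocab_kind = "sentencepiece" if (dir_model / "tokenizer.model").is_file() else "gpt2"
print("vocab path:", vocab_kind)

# Fields the YARN branch above reads from config.json's "rope_scaling" block.
yarn_fields = (
    "factor",
    "original_max_position_embeddings",
    "extrapolation_factor",
    "beta_fast",
    "beta_slow",
)

config_path = dir_model / "config.json"
if config_path.is_file():
    rope_scaling = json.loads(config_path.read_text()).get("rope_scaling") or {}
    if rope_scaling.get("rope_type") == "yarn":
        missing = [k for k in yarn_fields if k not in rope_scaling]
        print("missing YARN fields:", missing or "none")
```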

src/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ add_library(llama
             models/gemma-embedding.cpp
             models/gemma.cpp
             models/gemma2-iswa.cpp
-            models/gemma3-iswa.cpp
+            models/gemma3.cpp
             models/gemma3n-iswa.cpp
             models/glm4-moe.cpp
             models/glm4.cpp

src/llama-model.cpp

Lines changed: 17 additions & 6 deletions
@@ -1264,18 +1264,25 @@ void llama_model::load_hparams(llama_model_loader & ml) {
            } break;
        case LLM_ARCH_GEMMA3:
            {
-               hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
-               hparams.set_swa_pattern(6);
+               const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
+               if (found_swa && hparams.n_swa > 0) {
+                   hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+                   hparams.set_swa_pattern(6);

-               hparams.rope_freq_base_train_swa = 10000.0f;
-               hparams.rope_freq_scale_train_swa = 1.0f;
+                   hparams.rope_freq_base_train_swa = 10000.0f;
+                   hparams.rope_freq_scale_train_swa = 1.0f;
+               } else {
+                   hparams.swa_type = LLAMA_SWA_TYPE_NONE;
+               }

-               ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
+               hparams.f_final_logit_softcapping = 0.0f;
+               ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

                switch (hparams.n_layer) {
                    case 18: type = LLM_TYPE_270M; break;
                    case 26: type = LLM_TYPE_1B; break;
+                   case 32: type = LLM_TYPE_8B; break; // Rnj-1
                    case 34: type = LLM_TYPE_4B; break;
                    case 48: type = LLM_TYPE_12B; break;
                    case 62: type = LLM_TYPE_27B; break;
@@ -7304,7 +7311,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
            } break;
        case LLM_ARCH_GEMMA3:
            {
-               llm = std::make_unique<llm_build_gemma3_iswa>(*this, params);
+               if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) {
+                   llm = std::make_unique<llm_build_gemma3<true>>(*this, params);
+               } else {
+                   llm = std::make_unique<llm_build_gemma3<false>>(*this, params);
+               }
            } break;
        case LLM_ARCH_GEMMA3N:
            {
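
In load_hparams, set_swa_pattern(6) keeps Gemma3's interleaved attention: five sliding-window layers followed by one full-attention (global) layer. When the GGUF carries no sliding-window key, or it is zero (the Rnj-1 case), SWA is disabled and build_graph picks the non-iSWA variant. The sketch below illustrates that layer layout; it assumes the pattern semantics described here and is illustration, not llama.cpp code:

```python
def swa_layout(n_layer: int, pattern: int) -> list[bool]:
    """True = sliding-window attention, False = full (global) attention.

    Assumed semantics: with pattern p, every p-th layer is global and the
    rest use the sliding window; pattern 0 disables SWA entirely.
    """
    if pattern == 0:
        return [False] * n_layer
    return [(il % pattern) < pattern - 1 for il in range(n_layer)]

# Gemma3-style: 5 local layers, then 1 global, repeating.
print(swa_layout(12, 6))
# Rnj-1-style (no sliding window in the GGUF): all layers global.
print(swa_layout(12, 0))
```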

src/models/gemma3-iswa.cpp renamed to src/models/gemma3.cpp

Lines changed: 30 additions & 5 deletions
@@ -1,6 +1,7 @@
 #include "models.h"

-llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+template <bool iswa>
+llm_build_gemma3<iswa>::llm_build_gemma3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     const int64_t n_embd_head = hparams.n_embd_head_k;

     ggml_tensor * cur;
@@ -17,13 +18,28 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const ll
     ggml_tensor * inp_pos = build_inp_pos();

     // TODO: is causal == true correct? might need some changes
-    auto * inp_attn = build_attn_inp_kv_iswa();
+    using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
+    inp_attn_type * inp_attn = nullptr;
+
+    if constexpr (iswa) {
+        inp_attn = build_attn_inp_kv_iswa();
+    } else {
+        inp_attn = build_attn_inp_kv();
+    }

     ggml_tensor * inp_out_ids = build_inp_out_ids();

     for (int il = 0; il < n_layer; ++il) {
-        const float freq_base_l = model.get_rope_freq_base (cparams, il);
-        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+        float freq_base_l = 0.0f;
+        float freq_scale_l = 0.0f;
+
+        if constexpr (iswa) {
+            freq_base_l = model.get_rope_freq_base (cparams, il);
+            freq_scale_l = model.get_rope_freq_scale(cparams, il);
+        } else {
+            freq_base_l = freq_base;
+            freq_scale_l = freq_scale;
+        }

         // norm
         cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
@@ -102,7 +118,7 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const ll
         cur = build_norm(cur,
                 model.layers[il].ffn_post_norm, NULL,
                 LLM_NORM_RMS, -1);
-        cb(cur, "ffn_post_norm", -1);
+        cb(cur, "ffn_post_norm", il);

         cur = ggml_add(ctx0, cur, sa_out);

@@ -124,8 +140,17 @@ llm_build_gemma3_iswa::llm_build_gemma3_iswa(const llama_model & model, const ll
     // lm_head
     cur = build_lora_mm(model.output, cur);

+    if (hparams.f_final_logit_softcapping) {
+        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+        cur = ggml_tanh(ctx0, cur);
+        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+    }
+
     cb(cur, "result_output", -1);
     res->t_logits = cur;

     ggml_build_forward_expand(gf, cur);
 }
+
+template struct llm_build_gemma3<false>;
+template struct llm_build_gemma3<true>;
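
The block added before result_output applies final-logit soft-capping: scale by 1/cap, tanh, scale back by cap, which bounds every logit to (−cap, +cap) while leaving small logits nearly unchanged. A standalone sketch of that math (the cap value 30.0 is illustrative; the real value comes from the model's final_logit_softcapping hparam):

```python
import math

def softcap(logit: float, cap: float) -> float:
    """Soft-cap a logit exactly as the graph does: scale by 1/cap, tanh, scale by cap."""
    return cap * math.tanh(logit / cap)

# Illustrative cap only; the converted model stores its own value.
CAP = 30.0
for x in (0.5, 10.0, 100.0):
    print(f"{x:6.1f} -> {softcap(x, CAP):7.3f}")  # large logits compress toward +/-CAP
```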

src/models/models.h

Lines changed: 3 additions & 2 deletions
@@ -179,8 +179,9 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
     llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params);
 };

-struct llm_build_gemma3_iswa : public llm_graph_context {
-    llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params);
+template <bool iswa>
+struct llm_build_gemma3 : public llm_graph_context {
+    llm_build_gemma3(const llama_model & model, const llm_graph_params & params);
 };

 struct llm_build_gemma3n_iswa : public llm_graph_context {
