130 changes: 123 additions & 7 deletions convert_hf_to_gguf.py
@@ -3328,7 +3328,13 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
@ModelBase.register("InternVisionModel")
class InternVisionModel(MmprojModel):
def set_gguf_parameters(self):
assert self.hparams_vision is not None
if isinstance(self.hparams_vision['image_size'], list):
self.hparams_vision['image_size'] = self.hparams_vision['image_size'][0]
if isinstance(self.hparams_vision['patch_size'], list):
self.hparams_vision['patch_size'] = self.hparams_vision['patch_size'][0]
super().set_gguf_parameters()

hparams = self.hparams
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.INTERNVL)
self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
@@ -3352,14 +3358,30 @@ def tensor_force_quant(self, name, new_name, bid, n_dims):
return gguf.GGMLQuantizationType.F32
return False

def _mapping_interns1_name(self, name):
names_map = {
"model.multi_modal_projector.layer_norm.bias": "mlp1.0.bias",
"model.multi_modal_projector.layer_norm.weight": "mlp1.0.weight",
"model.multi_modal_projector.linear_1.bias": "mlp1.1.bias",
"model.multi_modal_projector.linear_1.weight": "mlp1.1.weight",
"model.multi_modal_projector.linear_2.bias": "mlp1.3.bias",
"model.multi_modal_projector.linear_2.weight": "mlp1.3.weight",
}
if name in names_map:
name = names_map[name]
return name
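
For clarity, a standalone sketch of the Intern-S1 renaming above: the multi-modal projector tensors collapse onto the `mlp1.*` names the existing InternVL path already handles, and every other tensor passes through unchanged (the vision-tower name in the second assert is only an illustrative example, not taken from the checkpoint):

```python
# Illustrative sketch of _mapping_interns1_name; not the converter's exact code path.
INTERNS1_PROJECTOR_MAP = {
    "model.multi_modal_projector.layer_norm.bias":   "mlp1.0.bias",
    "model.multi_modal_projector.layer_norm.weight": "mlp1.0.weight",
    "model.multi_modal_projector.linear_1.bias":     "mlp1.1.bias",
    "model.multi_modal_projector.linear_1.weight":   "mlp1.1.weight",
    "model.multi_modal_projector.linear_2.bias":     "mlp1.3.bias",
    "model.multi_modal_projector.linear_2.weight":   "mlp1.3.weight",
}

def map_interns1_name(name: str) -> str:
    # Projector tensors are renamed; everything else falls through untouched.
    return INTERNS1_PROJECTOR_MAP.get(name, name)

assert map_interns1_name("model.multi_modal_projector.linear_2.weight") == "mlp1.3.weight"
assert map_interns1_name("model.vision_tower.embeddings.patch_embedding.weight") \
    == "model.vision_tower.embeddings.patch_embedding.weight"
```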

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
if name.startswith("vision_model") or name.startswith("mlp"):
vision_prefix = ['vision_model', 'mlp', 'model.vision_tower', 'model.multi_modal_projector']
# deal with intern-s1 special case
name = self._mapping_interns1_name(name)
if any([name.startswith(prefix) for prefix in vision_prefix]):
# process visual tensors
# correct name
if name.startswith("vision_model"):
name = "vision_tower." + name
if (".ls" in name or "position_embedding" in name) and not name.endswith(".weight"):
if (".ls" in name or ".lambda_" in name or "position_embedding" in name) and not name.endswith(".weight"):
name += ".weight"
# split QKV tensors if needed
if ".qkv." in name:
@@ -3445,6 +3467,10 @@ def set_gguf_parameters(self):

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# process the experts separately
name = name.replace("language_model.", "") # InternVL
if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector"):
# skip visual tensors
return []
if name.find("experts") != -1:
n_experts = self.hparams["num_experts"]
assert bid is not None
@@ -3498,6 +3524,85 @@ class Qwen3Model(Qwen2Model):
class Qwen3MoeModel(Qwen2MoeModel):
model_arch = gguf.MODEL_ARCH.QWEN3MOE

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
hparams = ModelBase.load_hparams(self.dir_model)
self.origin_hf_arch = hparams.get('architectures', [None])[0]

def set_vocab(self):
# deal with intern-s1
if self.origin_hf_arch == 'InternS1ForConditionalGeneration':
self._set_vocab_interns1()
return

try:
self._set_vocab_sentencepiece()
except FileNotFoundError:
self._set_vocab_gpt2()
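
The constructor re-reads the original HF hyperparameters so that `set_vocab` can tell an Intern-S1 export apart from a plain Qwen3-MoE checkpoint. A minimal sketch of that check, assuming the usual `config.json` layout (illustrative only, not the converter's `load_hparams`):

```python
import json
from pathlib import Path

def is_interns1_checkpoint(dir_model: Path) -> bool:
    # Intern-S1 exports advertise "InternS1ForConditionalGeneration" in config.json,
    # while plain Qwen3-MoE checkpoints report "Qwen3MoeForCausalLM".
    with open(dir_model / "config.json", encoding="utf-8") as f:
        architectures = json.load(f).get("architectures", [None])
    return architectures[0] == "InternS1ForConditionalGeneration"
```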

def _set_vocab_interns1(self):
tokens: list[str] = []
toktypes: list[int] = []

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
vocab = getattr(tokenizer, 'vocab', tokenizer.get_vocab())
vocab_size = self.hparams.get("vocab_size", len(vocab))
assert max(vocab.values()) < vocab_size

tokpre = self.get_vocab_base_pre(tokenizer)

reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab.items()}
added_vocab = tokenizer.get_added_vocab()

added_tokens_decoder = tokenizer.added_tokens_decoder

for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.UNUSED)
else:
token: str = reverse_vocab[i]
if token in added_vocab:
# The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
# To avoid unexpected issues - we make sure to normalize non-normalized tokens
if not added_tokens_decoder[i].normalized:
previous_token = token
token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
if previous_token != token:
logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")

if added_tokens_decoder[i].special or self.does_token_look_special(token):
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.USER_DEFINED)
else:
toktypes.append(gguf.TokenType.NORMAL)
tokens.append(token)

self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
special_tokens_map_file = self.dir_model / 'special_tokens_map.json'
additional_special_tokens = []
if special_tokens_map_file.is_file():
with open(special_tokens_map_file, encoding = 'utf-8') as f:
additional_special_tokens = json.load(f).get('additional_special_tokens', [])
tokenizer_cfg_file = self.dir_model / 'special_tokens_map.json'
if tokenizer_cfg_file.is_file():
with open(tokenizer_cfg_file, encoding = 'utf-8') as f:
added_tokens_decoder = json.load(f).get('added_tokens_decoder', {})
token2ids_map = {data['content'] : int(token) for token, data in added_tokens_decoder.items() if data['special']}
for token in additional_special_tokens:
if token in token2ids_map:
special_vocab._set_special_token(token, token2ids_map[token])
special_vocab._set_special_token('eos', 151645)
special_vocab._set_special_token("bos", 151643)
special_vocab.add_to_gguf(self.gguf_writer)
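
A compact sketch of the special-token wiring at the end of `_set_vocab_interns1`: the strings listed under `additional_special_tokens` are resolved to ids through an `added_tokens_decoder` table before being registered, with the Qwen-style `eos`/`bos` ids set explicitly afterwards. File names and layout below follow the common HF export format; treat this as an illustration rather than the exact code above.

```python
import json
from pathlib import Path

def resolve_additional_special_tokens(dir_model: Path) -> dict[str, int]:
    """Map each additional special token string to its id, when it is marked special."""
    extra: list[str] = []
    smap = dir_model / "special_tokens_map.json"
    if smap.is_file():
        with open(smap, encoding="utf-8") as f:
            extra = json.load(f).get("additional_special_tokens", [])

    # In most HF exports the added_tokens_decoder table lives in tokenizer_config.json;
    # the diff above reads it from special_tokens_map.json instead.
    cfg = dir_model / "tokenizer_config.json"
    token_to_id: dict[str, int] = {}
    if cfg.is_file():
        with open(cfg, encoding="utf-8") as f:
            decoder = json.load(f).get("added_tokens_decoder", {})
        token_to_id = {d["content"]: int(i) for i, d in decoder.items() if d.get("special")}

    return {tok: token_to_id[tok] for tok in extra if tok in token_to_id}
```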


@ModelBase.register("GPT2LMHeadModel")
class GPT2Model(TextModel):
@@ -7997,15 +8102,13 @@ def repack_mxfp4(self, new_name: str, blocks: Tensor, scales: Tensor):
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
blocks0: Tensor = torch.zeros(1)
blocks1: Tensor = torch.zeros(1)
found_mxfp4_tensors = False
# we assume that tensors are loaded in the correct order
for name, data_torch in self.get_tensors():
if "mlp.experts.down_proj_blocks" in name:
blocks0 = data_torch
elif "mlp.experts.down_proj_scales" in name:
new_name = self.map_tensor_name(name.replace("_scales", ".weight"))
self.repack_mxfp4(new_name, blocks0, data_torch)
found_mxfp4_tensors = True
elif "mlp.experts.gate_up_proj_blocks" in name:
blocks0, blocks1 = data_torch[:, ::2, :, :], data_torch[:, 1::2, :, :]
elif "mlp.experts.gate_up_proj_scales" in name:
@@ -8014,9 +8117,6 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
new_name_up = self.map_tensor_name(name.replace("gate_up_proj_scales", "up_proj.weight"))
self.repack_mxfp4(new_name_gate, blocks0, scales0)
self.repack_mxfp4(new_name_up, blocks1, scales1)
found_mxfp4_tensors = True
if not found_mxfp4_tensors:
raise ValueError("No MXFP4 tensors found in the model. Please make sure you are using MXFP4 model.")
return []
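
`generate_extra_tensors` pairs each `*_blocks` tensor with the `*_scales` tensor that follows it, de-interleaving the fused gate/up projection along dim 1 before repacking. A simplified sketch of that pairing, under the same load-order assumption the comment above states (the names and the 4-D block slicing come from the diff; the rest is illustrative):

```python
def pair_mxfp4_expert_tensors(tensors):
    """Yield (name, blocks, scales) triples, assuming each blocks tensor precedes its scales."""
    down_blocks = gate_blocks = up_blocks = None
    for name, t in tensors:
        if name.endswith("down_proj_blocks"):
            down_blocks = t
        elif name.endswith("down_proj_scales"):
            yield name.replace("_scales", ".weight"), down_blocks, t
        elif name.endswith("gate_up_proj_blocks"):
            # gate and up rows are interleaved along dim 1 of the fused tensor
            gate_blocks, up_blocks = t[:, ::2, :, :], t[:, 1::2, :, :]
        elif name.endswith("gate_up_proj_scales"):
            gate_scales, up_scales = t[:, ::2], t[:, 1::2]
            yield name.replace("gate_up_proj_scales", "gate_proj.weight"), gate_blocks, gate_scales
            yield name.replace("gate_up_proj_scales", "up_proj.weight"), up_blocks, up_scales
```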

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
@@ -8029,7 +8129,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
if "down_proj" in name:
if name.endswith("_bias"):
name = name.replace("down_proj_bias", "down_proj.bias")
elif "_blocks" not in name and "_scales" not in name:
logger.warning(f"{name} is not in MXFP4, performance may be degraded")
name = name.replace("down_proj", "down_proj.weight")
data_torch = data_torch.transpose(-1, -2)
else:
# otherwise, it should already be repacked to ggml MXFP4 format
return []

# split the gate_up into gate and up
@@ -8042,7 +8147,18 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
(self.map_tensor_name(name_gate), gate_proj_bias),
(self.map_tensor_name(name_up), up_proj_bias)
]
elif "_blocks" not in name and "_scales" not in name:
logger.warning(f"{name} is not in MXFP4, performance may be degraded")
name_up = name.replace("gate_up_proj", "up_proj.weight")
name_gate = name.replace("gate_up_proj", "gate_proj.weight")
data_torch = data_torch.transpose(-1, -2)
gate_proj_weight, up_proj_weight = data_torch[:, ::2, :], data_torch[:, 1::2, :]
return [
(self.map_tensor_name(name_gate), gate_proj_weight),
(self.map_tensor_name(name_up), up_proj_weight)
]
else:
# otherwise, it should already be repacked to ggml MXFP4 format
return []

return [(self.map_tensor_name(name), data_torch)]
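
The new `elif` branches above add a fallback for checkpoints whose expert weights were exported as plain (e.g. bf16) tensors instead of MXFP4 `_blocks`/`_scales` pairs. A minimal sketch of the reshaping on the fused gate/up weight, with hypothetical dimension sizes; only the transpose and the even/odd split mirror the diff:

```python
import torch

n_expert, n_embd, n_ff = 4, 64, 128                # hypothetical sizes, for illustration only
gate_up = torch.randn(n_expert, n_embd, 2 * n_ff)  # fused gate/up weight from a bf16 export

gate_up = gate_up.transpose(-1, -2)                # -> (n_expert, 2*n_ff, n_embd)
gate = gate_up[:, ::2, :]                          # even rows -> gate_proj.weight
up   = gate_up[:, 1::2, :]                         # odd rows  -> up_proj.weight
assert gate.shape == up.shape == (n_expert, n_ff, n_embd)
```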
1 change: 1 addition & 0 deletions ggml/CMakeLists.txt
@@ -176,6 +176,7 @@ option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM"
option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
option(GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 "ggml: enable rocWMMA FlashAttention on GFX12" OFF)
option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON)
option(GGML_HIP_EXPORT_METRICS "ggml: enable kernel perf metrics output" OFF)
option(GGML_MUSA_GRAPHS "ggml: use MUSA graph, experimental, unstable" OFF)
option(GGML_MUSA_MUDNN_COPY "ggml: enable muDNN for accelerated copy" OFF)
option(GGML_VULKAN "ggml: use Vulkan" OFF)
82 changes: 42 additions & 40 deletions ggml/cmake/ggml-config.cmake.in
@@ -106,7 +106,7 @@ if(NOT TARGET ggml::ggml)

find_library(GGML_LIBRARY ggml
REQUIRED
HINTS ${GGML_LIB_DIR} ${GGML_BACKEND_DIR}
HINTS ${GGML_LIB_DIR}
NO_CMAKE_FIND_ROOT_PATH)

add_library(ggml::ggml UNKNOWN IMPORTED)
@@ -125,54 +125,56 @@ if(NOT TARGET ggml::ggml)
IMPORTED_LOCATION "${GGML_BASE_LIBRARY}")

set(_ggml_all_targets "")
foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS})
string(REPLACE "-" "_" _ggml_backend_pfx "${_ggml_backend}")
string(TOUPPER "${_ggml_backend_pfx}" _ggml_backend_pfx)

find_library(${_ggml_backend_pfx}_LIBRARY ${_ggml_backend}
REQUIRED
HINTS ${GGML_LIB_DIR}
NO_CMAKE_FIND_ROOT_PATH)

message(STATUS "Found ${${_ggml_backend_pfx}_LIBRARY}")

add_library(ggml::${_ggml_backend} UNKNOWN IMPORTED)
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${GGML_INCLUDE_DIR}"
IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
IMPORTED_LOCATION "${${_ggml_backend_pfx}_LIBRARY}"
INTERFACE_COMPILE_FEATURES c_std_90
POSITION_INDEPENDENT_CODE ON)

string(REGEX MATCH "^ggml-cpu" is_cpu_variant "${_ggml_backend}")
if(is_cpu_variant)
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml-base")
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_LINK_LIBRARIES "${GGML_CPU_INTERFACE_LINK_LIBRARIES}")
if (NOT GGML_BACKEND_DL)
foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS})
string(REPLACE "-" "_" _ggml_backend_pfx "${_ggml_backend}")
string(TOUPPER "${_ggml_backend_pfx}" _ggml_backend_pfx)

if(GGML_CPU_INTERFACE_LINK_OPTIONS)
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_LINK_OPTIONS "${GGML_CPU_INTERFACE_LINK_OPTIONS}")
endif()
find_library(${_ggml_backend_pfx}_LIBRARY ${_ggml_backend}
REQUIRED
HINTS ${GGML_LIB_DIR}
NO_CMAKE_FIND_ROOT_PATH)

message(STATUS "Found ${${_ggml_backend_pfx}_LIBRARY}")

else()
list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml-base")
add_library(ggml::${_ggml_backend} UNKNOWN IMPORTED)
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_LINK_LIBRARIES "${${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES}")
INTERFACE_INCLUDE_DIRECTORIES "${GGML_INCLUDE_DIR}"
IMPORTED_LINK_INTERFACE_LANGUAGES "CXX"
IMPORTED_LOCATION "${${_ggml_backend_pfx}_LIBRARY}"
INTERFACE_COMPILE_FEATURES c_std_90
POSITION_INDEPENDENT_CODE ON)

string(REGEX MATCH "^ggml-cpu" is_cpu_variant "${_ggml_backend}")
if(is_cpu_variant)
list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml-base")
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_LINK_LIBRARIES "${GGML_CPU_INTERFACE_LINK_LIBRARIES}")

if(GGML_CPU_INTERFACE_LINK_OPTIONS)
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_LINK_OPTIONS "${GGML_CPU_INTERFACE_LINK_OPTIONS}")
endif()

if(${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS)
else()
list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml-base")
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_LINK_OPTIONS "${${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS}")
INTERFACE_LINK_LIBRARIES "${${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES}")

if(${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS)
set_target_properties(ggml::${_ggml_backend}
PROPERTIES
INTERFACE_LINK_OPTIONS "${${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS}")
endif()
endif()
endif()

list(APPEND _ggml_all_targets ggml::${_ggml_backend})
endforeach()
list(APPEND _ggml_all_targets ggml::${_ggml_backend})
endforeach()
endif()

list(APPEND GGML_INTERFACE_LINK_LIBRARIES ggml::ggml-base "${_ggml_all_targets}")
set_target_properties(ggml::ggml
12 changes: 10 additions & 2 deletions ggml/src/ggml-cuda/common.cuh
@@ -233,9 +233,13 @@ typedef float2 dfloat2;
#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)

#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#define NEW_MMA_AVAILABLE
#define TURING_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING

#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#define AMPERE_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE

#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#define CP_ASYNC_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
@@ -303,10 +307,14 @@ static bool amd_mfma_available(const int cc) {
}

// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
static bool new_mma_available(const int cc) {
static bool turing_mma_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
}

static bool ampere_mma_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
}

static bool cp_async_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
}
12 changes: 6 additions & 6 deletions ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -418,7 +418,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
float * const __restrict__ KQ_max,
float * const __restrict__ KQ_rowsum,
const int kb0) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
typedef fattn_mma_f16_config<DKQ, DV> c;

#ifdef CP_ASYNC_AVAILABLE
@@ -776,7 +776,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}

template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
@@ -800,7 +800,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const int jt,
const int kb0_start,
const int kb0_stop) {
#ifdef NEW_MMA_AVAILABLE
#ifdef TURING_MMA_AVAILABLE
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.

typedef fattn_mma_f16_config<DKQ, DV> c;
@@ -1196,7 +1196,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask);
GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
#endif // TURING_MMA_AVAILABLE
}

template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
@@ -1223,7 +1223,7 @@
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
#if defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)

// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
@@ -1354,7 +1354,7 @@ static __global__ void flash_attn_ext_f16(
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
}

template <int DKQ, int DV, int ncols1, int ncols2>