from __future__ import annotations

import ctypes
+import enum
import os
import pathlib

# LLAMA_ROPE_SCALING_TYPE_LONGROPE = 3,
# LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_YARN,
# };
-LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED = -1
-LLAMA_ROPE_SCALING_TYPE_NONE = 0
-LLAMA_ROPE_SCALING_TYPE_LINEAR = 1
-LLAMA_ROPE_SCALING_TYPE_YARN = 2
-LLAMA_ROPE_SCALING_TYPE_LONGROPE = 3
-LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_YARN
+class llama_rope_scaling_type(enum.IntEnum):
+    LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED = -1
+    LLAMA_ROPE_SCALING_TYPE_NONE = 0
+    LLAMA_ROPE_SCALING_TYPE_LINEAR = 1
+    LLAMA_ROPE_SCALING_TYPE_YARN = 2
+    LLAMA_ROPE_SCALING_TYPE_LONGROPE = 3
+    LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_YARN

# enum llama_pooling_type {
# LLAMA_POOLING_TYPE_UNSPECIFIED = -1,
# LLAMA_ATTENTION_TYPE_CAUSAL = 0,
# LLAMA_ATTENTION_TYPE_NON_CAUSAL = 1,
# };
-LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1
-LLAMA_ATTENTION_TYPE_CAUSAL = 0
-LLAMA_ATTENTION_TYPE_NON_CAUSAL = 1
+class llama_attention_type(enum.IntEnum):
+    LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1
+    LLAMA_ATTENTION_TYPE_CAUSAL = 0
+    LLAMA_ATTENTION_TYPE_NON_CAUSAL = 1
+
+# enum llama_flash_attn_type {
+# LLAMA_FLASH_ATTN_TYPE_AUTO = -1,
+# LLAMA_FLASH_ATTN_TYPE_DISABLED = 0,
+# LLAMA_FLASH_ATTN_TYPE_ENABLED = 1,
+# };
+class llama_flash_attn_type(enum.IntEnum):
+    LLAMA_FLASH_ATTN_TYPE_AUTO = -1
+    LLAMA_FLASH_ATTN_TYPE_DISABLED = 0
+    LLAMA_FLASH_ATTN_TYPE_ENABLED = 1

+# LLAMA_API const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type);
+@ctypes_function(
+    "llama_flash_attn_type_name",
+    [ctypes.c_int],
+    ctypes.c_char_p,
+)
+def llama_flash_attn_type_name(
+    flash_attn_type: llama_flash_attn_type, /
+) -> bytes:
+    """
+    Gets the name of a llama_flash_attn_type.
+    """

# enum llama_split_mode {
# LLAMA_SPLIT_MODE_NONE = 0, // single GPU
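
For reference, a minimal usage sketch of the additions above, assuming the bindings are importable as llama_cpp.llama_cpp and the loaded libllama exports llama_flash_attn_type_name; the exact name string is defined by llama.cpp:

import llama_cpp.llama_cpp as llama_cpp

# IntEnum members are plain ints, so they keep the same numeric values as the
# removed module-level constants and can be stored in ctypes.c_int struct fields.
assert llama_cpp.llama_flash_attn_type.LLAMA_FLASH_ATTN_TYPE_AUTO == -1
assert int(llama_cpp.llama_rope_scaling_type.LLAMA_ROPE_SCALING_TYPE_YARN) == 2

# The new binding returns the C name of a flash-attention mode as bytes.
name = llama_cpp.llama_flash_attn_type_name(
    llama_cpp.llama_flash_attn_type.LLAMA_FLASH_ATTN_TYPE_ENABLED
)
print(name.decode("utf-8"))

Because the numeric values are unchanged, code that already stores these constants in llama_context_params fields keeps working; only the lookup path moves from module scope to the enum classes.
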
@@ -793,6 +818,7 @@ class llama_model_params(ctypes.Structure):
# enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
# enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
# enum llama_attention_type attention_type; // attention type to use for embeddings
+# enum llama_flash_attn_type flash_attn_type; // when to enable Flash Attention

# // ref: https://github.com/ggml-org/llama.cpp/pull/2054
# float rope_freq_base; // RoPE base frequency, 0 = from model
@@ -818,7 +844,6 @@ class llama_model_params(ctypes.Structure):
# // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
# bool embeddings; // if true, extract embeddings (together with logits)
# bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU
-# bool flash_attn; // use flash attention [EXPERIMENTAL]
# bool no_perf; // measure performance timings
# bool op_offload; // offload host tensor operations to device
# bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
@@ -841,6 +866,7 @@ class llama_context_params(ctypes.Structure):
        rope_scaling_type (int): RoPE scaling type, from `enum llama_rope_scaling_type`
        pooling_type (int): whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
        attention_type (int): attention type to use for embeddings
+        flash_attn_type (int): when to enable Flash Attention
        rope_freq_base (float): RoPE base frequency, 0 = from model
        rope_freq_scale (float): RoPE frequency scaling factor, 0 = from model
        yarn_ext_factor (float): YaRN extrapolation mix factor, negative = from model
@@ -857,7 +883,6 @@ class llama_context_params(ctypes.Structure):
        abort_callback_data (ctypes.c_void_p): data for abort_callback
        embeddings (bool): if true, extract embeddings (together with logits)
        offload_kqv (bool): whether to offload the KQV ops (including the KV cache) to GPU
-        flash_attn (bool): whether to use flash attention
        no_perf (bool): whether to measure performance timings
        op_offload (bool): whether to offload host tensor operations to device
        swa_full (bool): whether to use full-size SWA cache
@@ -874,6 +899,7 @@ class llama_context_params(ctypes.Structure):
    rope_scaling_type: int
    pooling_type: int
    attention_type: int
+    flash_attn_type: int
    rope_freq_base: float
    rope_freq_scale: float
    yarn_ext_factor: float
@@ -890,7 +916,6 @@ class llama_context_params(ctypes.Structure):
    abort_callback_data: ctypes.c_void_p
    embeddings: bool
    offload_kqv: bool
-    flash_attn: bool
    no_perf: bool
    op_offload: bool
    swa_full: bool
@@ -906,6 +931,7 @@ class llama_context_params(ctypes.Structure):
        ("rope_scaling_type", ctypes.c_int),
        ("pooling_type", ctypes.c_int),
        ("attention_type", ctypes.c_int),
+        ("flash_attn_type", ctypes.c_int),
        ("rope_freq_base", ctypes.c_float),
        ("rope_freq_scale", ctypes.c_float),
        ("yarn_ext_factor", ctypes.c_float),
@@ -922,7 +948,6 @@ class llama_context_params(ctypes.Structure):
        ("abort_callback_data", ctypes.c_void_p),
        ("embeddings", ctypes.c_bool),
        ("offload_kqv", ctypes.c_bool),
-        ("flash_attn", ctypes.c_bool),
        ("no_perf", ctypes.c_bool),
        ("op_offload", ctypes.c_bool),
        ("swa_full", ctypes.c_bool),
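
Since the boolean flash_attn field is removed from llama_context_params, callers select a mode through the new flash_attn_type field instead. A brief migration sketch, assuming the llama_context_default_params() binding already provided by this module:

import llama_cpp.llama_cpp as llama_cpp

params = llama_cpp.llama_context_default_params()
# Previously: params.flash_attn = True
params.flash_attn_type = llama_cpp.llama_flash_attn_type.LLAMA_FLASH_ATTN_TYPE_ENABLED
# Or leave the field untouched to keep whatever default llama_context_default_params() returns.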