
Commit 2be720d

Sync llama: use FA + max. GPU layers by default
1 parent d4f87d3 commit 2be720d

File tree

5 files changed: +78 -27 lines changed


llama_cpp/llama.py

Lines changed: 23 additions & 7 deletions
@@ -79,8 +79,10 @@ def __init__(
         n_threads_batch: Optional[int] = None,
         rope_scaling_type: Optional[
             int
-        ] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
+        ] = llama_cpp.llama_rope_scaling_type.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
         pooling_type: int = llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED,
+        attention_type: Optional[int] = llama_cpp.llama_attention_type.LLAMA_ATTENTION_TYPE_UNSPECIFIED,
+        flash_attn_type: Optional[int] = llama_cpp.llama_flash_attn_type.LLAMA_FLASH_ATTN_TYPE_AUTO,
         rope_freq_base: float = 0.0,
         rope_freq_scale: float = 0.0,
         yarn_ext_factor: float = -1.0,
@@ -91,7 +93,6 @@ def __init__(
         logits_all: bool = False,
         embedding: bool = False,
         offload_kqv: bool = True,
-        flash_attn: bool = False,
         op_offload: Optional[bool] = None,
         swa_full: Optional[bool] = None,
         kv_unified: Optional[bool] = None,
@@ -164,6 +165,8 @@ def __init__(
             n_threads_batch: Number of threads to use for batch processing
             rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggml-org/llama.cpp/pull/2054
             pooling_type: Pooling type, from `enum llama_pooling_type`.
+            attention_type: attention type to use for embeddings
+            flash_attn_type: when to enable Flash Attention
             rope_freq_base: RoPE base frequency, 0 = from model
             rope_freq_scale: RoPE frequency scaling factor, 0 = from model
             yarn_ext_factor: YaRN extrapolation mix factor, negative = from model
@@ -174,7 +177,6 @@ def __init__(
             logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs.
             embedding: Embedding mode only.
             offload_kqv: Offload K, Q, V to GPU.
-            flash_attn: Use flash attention.
             op_offload: whether to offload host tensor operations to device
             swa_full: whether to use full-size SWA cache
             kv_unified: use single unified KV buffer for the KV cache of all sequences
@@ -318,9 +320,23 @@ def __init__(
         self.context_params.rope_scaling_type = (
             rope_scaling_type
             if rope_scaling_type is not None
-            else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
+            else llama_cpp.llama_rope_scaling_type.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
+        )
+        self.context_params.pooling_type = (
+            pooling_type
+            if pooling_type is not None
+            else llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED
+        )
+        self.context_params.attention_type = (
+            attention_type
+            if attention_type is not None
+            else llama_cpp.llama_attention_type.LLAMA_ATTENTION_TYPE_UNSPECIFIED
+        )
+        self.context_params.flash_attn_type = (
+            flash_attn_type
+            if flash_attn_type is not None
+            else llama_cpp.llama_flash_attn_type.LLAMA_FLASH_ATTN_TYPE_AUTO
         )
-        self.context_params.pooling_type = pooling_type
         self.context_params.rope_freq_base = (
             rope_freq_base if rope_freq_base != 0.0 else 0
         )
@@ -343,7 +359,6 @@ def __init__(
         self._logits_all = logits_all if draft_model is None else True
         self.context_params.embeddings = embedding # TODO: Rename to embeddings
         self.context_params.offload_kqv = offload_kqv
-        self.context_params.flash_attn = flash_attn
 
         if op_offload is not None:
             self.context_params.op_offload = op_offload
@@ -2201,6 +2216,8 @@ def __getstate__(self):
            n_threads_batch=self.context_params.n_threads_batch,
            rope_scaling_type=self.context_params.rope_scaling_type,
            pooling_type=self.context_params.pooling_type,
+           attention_type=self.context_params.attention_type,
+           flash_attn_type=self.context_params.flash_attn_type,
            rope_freq_base=self.context_params.rope_freq_base,
            rope_freq_scale=self.context_params.rope_freq_scale,
            yarn_ext_factor=self.context_params.yarn_ext_factor,
@@ -2211,7 +2228,6 @@ def __getstate__(self):
            logits_all=self._logits_all,
            embedding=self.context_params.embeddings,
            offload_kqv=self.context_params.offload_kqv,
-           flash_attn=self.context_params.flash_attn,
            op_offload=self.context_params.op_offload,
            swa_full=self.context_params.swa_full,
            kv_unified= self.context_params.kv_unified,
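
Taken together, the llama.py changes replace the old flash_attn boolean with an enum-valued flash_attn_type (defaulting to AUTO) and expose attention_type alongside pooling_type. Below is a minimal usage sketch, not part of the commit: the model path is a placeholder and the remaining constructor arguments are assumed to keep their defaults.

import llama_cpp
from llama_cpp import Llama

# Hypothetical GGUF path, for illustration only.
llm = Llama(
    model_path="./models/example-7b-q4_k_m.gguf",
    n_gpu_layers=-1,  # offload as many layers as possible to the GPU
    # flash_attn=True/False is gone; the enum below replaces it.
    # AUTO (the new default) lets llama.cpp decide; DISABLED/ENABLED force it.
    flash_attn_type=llama_cpp.llama_flash_attn_type.LLAMA_FLASH_ATTN_TYPE_ENABLED,
    attention_type=llama_cpp.llama_attention_type.LLAMA_ATTENTION_TYPE_UNSPECIFIED,
)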

llama_cpp/llama_cpp.py

Lines changed: 38 additions & 13 deletions
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import ctypes
+import enum
 import os
 import pathlib
 
@@ -451,12 +452,13 @@
 # LLAMA_ROPE_SCALING_TYPE_LONGROPE = 3,
 # LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_YARN,
 # };
-LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED = -1
-LLAMA_ROPE_SCALING_TYPE_NONE = 0
-LLAMA_ROPE_SCALING_TYPE_LINEAR = 1
-LLAMA_ROPE_SCALING_TYPE_YARN = 2
-LLAMA_ROPE_SCALING_TYPE_LONGROPE = 3
-LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_YARN
+class llama_rope_scaling_type(enum.IntEnum):
+    LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED = -1
+    LLAMA_ROPE_SCALING_TYPE_NONE = 0
+    LLAMA_ROPE_SCALING_TYPE_LINEAR = 1
+    LLAMA_ROPE_SCALING_TYPE_YARN = 2
+    LLAMA_ROPE_SCALING_TYPE_LONGROPE = 3
+    LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_YARN
 
 # enum llama_pooling_type {
 # LLAMA_POOLING_TYPE_UNSPECIFIED = -1,
@@ -478,10 +480,33 @@
 # LLAMA_ATTENTION_TYPE_CAUSAL = 0,
 # LLAMA_ATTENTION_TYPE_NON_CAUSAL = 1,
 # };
-LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1
-LLAMA_ATTENTION_TYPE_CAUSAL = 0
-LLAMA_ATTENTION_TYPE_NON_CAUSAL = 1
+class llama_attention_type(enum.IntEnum):
+    LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1
+    LLAMA_ATTENTION_TYPE_CAUSAL = 0
+    LLAMA_ATTENTION_TYPE_NON_CAUSAL = 1
+
+# enum llama_flash_attn_type {
+# LLAMA_FLASH_ATTN_TYPE_AUTO = -1,
+# LLAMA_FLASH_ATTN_TYPE_DISABLED = 0,
+# LLAMA_FLASH_ATTN_TYPE_ENABLED = 1,
+# };
+class llama_flash_attn_type(enum.IntEnum):
+    LLAMA_FLASH_ATTN_TYPE_AUTO = -1
+    LLAMA_FLASH_ATTN_TYPE_DISABLED = 0
+    LLAMA_FLASH_ATTN_TYPE_ENABLED = 1
 
+# LLAMA_API const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type);
+@ctypes_function(
+    "llama_flash_attn_type_name",
+    [ctypes.c_int],
+    ctypes.c_char_p,
+)
+def llama_flash_attn_type_name(
+    flash_attn_type: llama_flash_attn_type, /
+) -> bytes:
+    """
+    Gets the name of a llama_flash_attn_type.
+    """
 
 # enum llama_split_mode {
 # LLAMA_SPLIT_MODE_NONE = 0, // single GPU
@@ -793,6 +818,7 @@ class llama_model_params(ctypes.Structure):
 # enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
 # enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
 # enum llama_attention_type attention_type; // attention type to use for embeddings
+# enum llama_flash_attn_type flash_attn_type; // when to enable Flash Attention
 
 # // ref: https://github.com/ggml-org/llama.cpp/pull/2054
 # float rope_freq_base; // RoPE base frequency, 0 = from model
@@ -818,7 +844,6 @@ class llama_model_params(ctypes.Structure):
 # // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
 # bool embeddings; // if true, extract embeddings (together with logits)
 # bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU
-# bool flash_attn; // use flash attention [EXPERIMENTAL]
 # bool no_perf; // measure performance timings
 # bool op_offload; // offload host tensor operations to device
 # bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
@@ -841,6 +866,7 @@ class llama_context_params(ctypes.Structure):
         rope_scaling_type (int): RoPE scaling type, from `enum llama_rope_scaling_type`
         pooling_type (int): whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
         attention_type (int): attention type to use for embeddings
+        flash_attn_type (int): when to enable Flash Attention
         rope_freq_base (float): RoPE base frequency, 0 = from model
         rope_freq_scale (float): RoPE frequency scaling factor, 0 = from model
         yarn_ext_factor (float): YaRN extrapolation mix factor, negative = from model
@@ -857,7 +883,6 @@ class llama_context_params(ctypes.Structure):
         abort_callback_data (ctypes.ctypes.c_void_p): data for abort_callback
         embeddings (bool): if true, extract embeddings (together with logits)
         offload_kqv (bool): whether to offload the KQV ops (including the KV cache) to GPU
-        flash_attn (bool): whether to use flash attention
         no_perf (bool): whether to measure performance timings
         op_offload(bool): whether to offload host tensor operations to device
         swa_full(bool): whether to use full-size SWA cache
@@ -874,6 +899,7 @@ class llama_context_params(ctypes.Structure):
     rope_scaling_type: int
     pooling_type: int
    attention_type: int
+    flash_attn_type: int
     rope_freq_base: float
     rope_freq_scale: float
     yarn_ext_factor: float
@@ -890,7 +916,6 @@ class llama_context_params(ctypes.Structure):
     abort_callback_data: ctypes.c_void_p
     embeddings: bool
     offload_kqv: bool
-    flash_attn: bool
     no_perf: bool
     op_offload:bool
     swa_full:bool
@@ -906,6 +931,7 @@ class llama_context_params(ctypes.Structure):
         ("rope_scaling_type", ctypes.c_int),
         ("pooling_type", ctypes.c_int),
         ("attention_type", ctypes.c_int),
+        ("flash_attn_type", ctypes.c_int),
         ("rope_freq_base", ctypes.c_float),
         ("rope_freq_scale", ctypes.c_float),
         ("yarn_ext_factor", ctypes.c_float),
@@ -922,7 +948,6 @@ class llama_context_params(ctypes.Structure):
         ("abort_callback_data", ctypes.c_void_p),
         ("embeddings", ctypes.c_bool),
         ("offload_kqv", ctypes.c_bool),
-        ("flash_attn", ctypes.c_bool),
         ("no_perf", ctypes.c_bool),
         ("op_offload", ctypes.c_bool),
         ("swa_full", ctypes.c_bool),

llama_cpp/server/model.py

Lines changed: 3 additions & 1 deletion
@@ -281,6 +281,9 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama:
         n_threads=settings.n_threads,
         n_threads_batch=settings.n_threads_batch,
         rope_scaling_type=settings.rope_scaling_type,
+        pooling_type=settings.pooling_type,
+        attention_type=settings.attention_type,
+        flash_attn_type=settings.flash_attn_type,
         rope_freq_base=settings.rope_freq_base,
         rope_freq_scale=settings.rope_freq_scale,
         yarn_ext_factor=settings.yarn_ext_factor,
@@ -292,7 +295,6 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama:
         logits_all=settings.logits_all,
         embedding=settings.embedding,
         offload_kqv=settings.offload_kqv,
-        flash_attn=settings.flash_attn,
         op_offload=settings.op_offload,
         swa_full=settings.swa_full,
         kv_unified=settings.kv_unified,

llama_cpp/server/settings.py

Lines changed: 14 additions & 4 deletions
@@ -84,7 +84,20 @@ class ModelSettings(BaseSettings):
         description="The number of threads to use when batch processing. Use -1 for max cpu threads",
     )
     rope_scaling_type: int = Field(
-        default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
+        default=llama_cpp.llama_rope_scaling_type.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
+        description="RoPE scaling type, from `enum llama_rope_scaling_type",
+    )
+    pooling_type: int = Field(
+        default=llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED,
+        description="whether to pool (sum) embedding results by sequence id",
+    )
+    attention_type: int = Field(
+        default=llama_cpp.llama_attention_type.LLAMA_ATTENTION_TYPE_UNSPECIFIED,
+        description="attention type to use for embeddings",
+    )
+    flash_attn_type: int = Field(
+        default=llama_cpp.llama_flash_attn_type.LLAMA_FLASH_ATTN_TYPE_AUTO,
+        description="when to enable Flash Attention",
     )
     rope_freq_base: float = Field(default=0.0, description="RoPE base frequency")
     rope_freq_scale: float = Field(
@@ -103,9 +116,6 @@ class ModelSettings(BaseSettings):
     offload_kqv: bool = Field(
         default=True, description="Whether to offload kqv to the GPU."
     )
-    flash_attn: bool = Field(
-        default=False, description="Whether to use flash attention."
-    )
     op_offload: bool = Field(
         default=True, description="Whether to offload host tensor operations to device"
     )
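
On the server side, the removed flash_attn setting is superseded by flash_attn_type, pooling_type, and attention_type fields, which load_llama_from_model_settings (llama_cpp/server/model.py above) now forwards to the Llama constructor. A sketch of programmatic use, assuming ModelSettings still takes the model path via its model field; the path is a placeholder and all other fields keep their defaults:

import llama_cpp
from llama_cpp.server.model import load_llama_from_model_settings
from llama_cpp.server.settings import ModelSettings

# Hypothetical GGUF path; only the fields touched by this commit are spelled out.
settings = ModelSettings(
    model="./models/example-7b-q4_k_m.gguf",
    flash_attn_type=llama_cpp.llama_flash_attn_type.LLAMA_FLASH_ATTN_TYPE_AUTO,
    pooling_type=llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED,
    attention_type=llama_cpp.llama_attention_type.LLAMA_ATTENTION_TYPE_UNSPECIFIED,
)
llama = load_llama_from_model_settings(settings)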

tests/test_llama.py

Lines changed: 0 additions & 2 deletions
@@ -129,7 +129,6 @@ def test_real_llama(llama_cpp_model_path):
         n_threads=multiprocessing.cpu_count(),
         n_threads_batch=multiprocessing.cpu_count(),
         logits_all=False,
-        flash_attn=True,
         swa_full=True,
         kv_unified=True,
     )
@@ -234,7 +233,6 @@ def test_real_llama_embeddings(llama_cpp_model_path):
         n_threads=multiprocessing.cpu_count(),
         n_threads_batch=multiprocessing.cpu_count(),
         logits_all=False,
-        flash_attn=True,
         swa_full=True,
         kv_unified=True,
         embedding=True