Commit a029bcd

Bit of cleanup
1 parent 361d211

4 files changed: +144, -119 lines


exllamav2/architecture.py

Lines changed: 11 additions & 1 deletion
@@ -101,7 +101,17 @@ class RopeStyle(IntEnum):
 
 class ExLlamaV2ArchParams:
 
-    def __init__(self, arch_string, read_config):
+    def __init__(self, arch_string: str, read_config: dict):
+        """
+        Get architecture definition from model config. If the architecture isn't recognized, defaults to Llama
+        architecture.
+
+        :param arch_string:
+            Architecture string from config.json
+
+        :param read_config:
+            config.json as Python dict
+        """
 
         self.arch_string = arch_string
         arch_recognized = False
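
The new docstring makes the intended construction pattern explicit. A minimal usage sketch, assuming a Hugging Face-style config.json on disk; the file path and the "architectures" lookup below are illustrative assumptions, not part of this commit:

import json

from exllamav2.architecture import ExLlamaV2ArchParams

# Load the model's config.json as a plain Python dict (path is hypothetical)
with open("/models/my-model/config.json") as f:
    read_config = json.load(f)

# HF-style configs name the model class under "architectures"; the fallback
# value here is only for illustration
arch_string = read_config.get("architectures", ["LlamaForCausalLM"])[0]

# Per the docstring, an unrecognized arch_string falls back to the Llama
# architecture definition rather than raising
arch_params = ExLlamaV2ArchParams(arch_string, read_config)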

exllamav2/attn.py

Lines changed: 26 additions & 23 deletions
@@ -9,16 +9,13 @@
 from exllamav2.linear import ExLlamaV2Linear
 from exllamav2.cache import ExLlamaV2CacheBase
 from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
-from exllamav2.compat import safe_move_tensor
 from exllamav2.lora import ExLlamaV2Lora
 from exllamav2.architecture import RopeStyle
 from exllamav2.tensor_p import BROADCAST_KV, BROADCAST_Q
 import math
-# from exllamav2.util import list_live_tensors, set_snapshot, diff_snapshot, print_vram_usage_peak
 import torch.nn.functional as F
 import inspect
 import os
-# from line_profiler import profile
 
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
@@ -30,6 +27,9 @@
 has_flash_attn_with_paged = False
 has_flash_attn_with_window = False
 has_flash_attn_with_softcap = False
+has_xformers = False
+has_lower_right_sdpa = False
+
 if 'EXLLAMA_NO_FLASH_ATTN' not in os.environ:
 
     try:
@@ -63,8 +63,6 @@
     except NameError:
         pass
 
-
-has_xformers = False
 if 'EXLLAMA_NO_XFORMERS' not in os.environ:
 
     try:
@@ -75,8 +73,6 @@
     except ModuleNotFoundError:
         pass
 
-
-has_lower_right_sdpa = False
 if 'EXLLAMA_NO_SDPA' not in os.environ:
     try:
         from torch.nn.attention.bias import causal_lower_right
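
These three hunks are one change: the has_xformers and has_lower_right_sdpa flags move up to join the other capability flags, so every flag is declared before any probing runs. This is the usual optional-dependency pattern: declare all flags False up front, then let env-gated try/import blocks flip them on. A minimal sketch of the pattern in isolation (the module and flag below are illustrative, not the exact probes attn.py performs):

import os

# Declare every capability flag before probing, so downstream code can
# read them unconditionally no matter which probes ran
has_flash_attn = False

if 'EXLLAMA_NO_FLASH_ATTN' not in os.environ:
    try:
        import flash_attn  # optional dependency; may be absent
        has_flash_attn = True
    except ModuleNotFoundError:
        pass  # flag stays False; callers fall back to another backend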
@@ -86,6 +82,9 @@
 
 
 def assert_paged_attn():
+    """
+    Raise an exception if paged attention is not available.
+    """
     global has_flash_attn_with_paged
     assert has_flash_attn_with_paged, \
         "Paged attention required Flash Attention 2.5.7 or later"
@@ -128,14 +127,15 @@ class ExLlamaV2Attention(ExLlamaV2Module):
     from exllamav2.attn_params import Params
     from exllamav2.attn_params import PagedParams
 
-    def __init__(self,
-                 model: ExLlamaV2,
-                 key: str,
-                 layer_idx: int,
-                 has_norm: bool = True,
-                 has_residual: bool = True,
-                 sliding_window: int = 0):
-
+    def __init__(
+        self,
+        model: ExLlamaV2,
+        key: str,
+        layer_idx: int,
+        has_norm: bool = True,
+        has_residual: bool = True,
+        sliding_window: int = 0
+    ):
         super().__init__(model, key)
 
         cfg = self.model.config
@@ -1312,14 +1312,17 @@ def forward_tp_old(
         return hidden_states
 
 
-    def forward_torch(self,
-                      hidden_states: torch.Tensor,
-                      cache: ExLlamaV2CacheBase | None = None,
-                      attn_params: ExLlamaV2Attention.Params | None = None,
-                      past_len: int | None = None,
-                      intermediates: bool = False,
-                      loras: list[ExLlamaV2Lora] | None = None,
-                      **kwargs) -> torch.Tensor | dict:
+    def forward_torch(
+        self,
+        hidden_states: torch.Tensor,
+        cache: ExLlamaV2CacheBase | None = None,
+        attn_params: ExLlamaV2Attention.Params | None = None,
+        past_len: int | None = None,
+        intermediates: bool = False,
+        loras: list[ExLlamaV2Lora] | None = None,
+        **kwargs
+    ) -> torch.Tensor | dict:
+
         global has_flash_attn
         global has_xformers
 
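Both signature hunks apply the same style cleanup: parameters move off the def line onto their own lines at one indent level, with the closing parenthesis (and any return annotation) on its own line. A before/after sketch of the convention on a made-up function:

# Before: continuation lines aligned to the opening parenthesis
def resize(width: int,
           height: int) -> None:
    ...

# After: one parameter per line, closing parenthesis on its own line
def resize(
    width: int,
    height: int
) -> None:
    ...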

exllamav2/attn_params.py

Lines changed: 7 additions & 7 deletions
@@ -19,13 +19,13 @@ class Params:
     paged: bool
 
     def __init__(
-            self,
-            batch_size: int,
-            seq_len: int | None = None,
-            past_len: int | list[int] | None = None,
-            input_mask: torch.Tensor | None = None,
-            position_offsets: torch.Tensor | None = None,
-            paged=False
+        self,
+        batch_size: int,
+        seq_len: int | None = None,
+        past_len: int | list[int] | None = None,
+        input_mask: torch.Tensor | None = None,
+        position_offsets: torch.Tensor | None = None,
+        paged=False
     ):
 
         self.batch_size = batch_size
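
The reformatted signature also shows what a Params instance captures: the shape of one attention call. A minimal construction sketch using only the parameters from the signature above, with illustrative values:

from exllamav2.attn_params import Params

# One sequence of 128 new tokens appended after 512 cached tokens,
# no input mask or position offsets, regular (non-paged) attention
params = Params(
    batch_size=1,
    seq_len=128,
    past_len=512,
    input_mask=None,
    position_offsets=None,
    paged=False
)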
