9 | 9 | from exllamav2.linear import ExLlamaV2Linear |
10 | 10 | from exllamav2.cache import ExLlamaV2CacheBase |
11 | 11 | from exllamav2.ext import exllamav2_ext as ext_c, none_tensor |
12 | | -from exllamav2.compat import safe_move_tensor |
13 | 12 | from exllamav2.lora import ExLlamaV2Lora |
14 | 13 | from exllamav2.architecture import RopeStyle |
15 | 14 | from exllamav2.tensor_p import BROADCAST_KV, BROADCAST_Q |
16 | 15 | import math |
17 | | -# from exllamav2.util import list_live_tensors, set_snapshot, diff_snapshot, print_vram_usage_peak |
18 | 16 | import torch.nn.functional as F |
19 | 17 | import inspect |
20 | 18 | import os |
21 | | -# from line_profiler import profile |
22 | 19 |
|
23 | 20 | from typing import TYPE_CHECKING |
24 | 21 | if TYPE_CHECKING: |
|
30 | 27 | has_flash_attn_with_paged = False |
31 | 28 | has_flash_attn_with_window = False |
32 | 29 | has_flash_attn_with_softcap = False |
| 30 | +has_xformers = False |
| 31 | +has_lower_right_sdpa = False |
| 32 | + |
33 | 33 | if 'EXLLAMA_NO_FLASH_ATTN' not in os.environ: |
34 | 34 |
|
35 | 35 | try: |
|
63 | 63 | except NameError: |
64 | 64 | pass |
65 | 65 |
|
66 | | - |
67 | | -has_xformers = False |
68 | 66 | if 'EXLLAMA_NO_XFORMERS' not in os.environ: |
69 | 67 |
|
70 | 68 | try: |
|
75 | 73 | except ModuleNotFoundError: |
76 | 74 | pass |
77 | 75 |
|
78 | | - |
79 | | -has_lower_right_sdpa = False |
80 | 76 | if 'EXLLAMA_NO_SDPA' not in os.environ: |
81 | 77 | try: |
82 | 78 | from torch.nn.attention.bias import causal_lower_right |
|
86 | 82 |
|
87 | 83 |
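The three backends above are gated on environment variables checked at import time with `'EXLLAMA_NO_...' not in os.environ`, so any value opts out. A minimal sketch of disabling them (the variable names come from the diff; importing the top-level `exllamav2` package after setting them is an assumption about import order):

    import os

    # Any value works; the code only tests for the key's presence (assumption: set before import).
    os.environ["EXLLAMA_NO_FLASH_ATTN"] = "1"   # skip flash-attn detection
    os.environ["EXLLAMA_NO_XFORMERS"] = "1"     # skip xformers detection
    os.environ["EXLLAMA_NO_SDPA"] = "1"         # skip lower-right SDPA bias detection

    import exllamav2  # the detection flags above stay False for this process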
|
88 | 84 | def assert_paged_attn(): |
| 85 | + """ |
| 86 | + Raise an exception if paged attention is not available. |
| 87 | + """ |
89 | 88 | global has_flash_attn_with_paged |
90 | 89 | assert has_flash_attn_with_paged, \ |
91 | 90 |         "Paged attention requires Flash Attention 2.5.7 or later"
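A hedged usage sketch for the helper above (the `exllamav2.attn` module path is assumed; per the diff, the function only asserts on `has_flash_attn_with_paged`):

    from exllamav2.attn import assert_paged_attn  # module path assumed

    # Raises AssertionError unless flash-attn 2.5.7+ with paged support was detected at import.
    assert_paged_attn()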
@@ -128,14 +127,15 @@ class ExLlamaV2Attention(ExLlamaV2Module): |
128 | 127 | from exllamav2.attn_params import Params |
129 | 128 | from exllamav2.attn_params import PagedParams |
130 | 129 |
|
131 | | - def __init__(self, |
132 | | - model: ExLlamaV2, |
133 | | - key: str, |
134 | | - layer_idx: int, |
135 | | - has_norm: bool = True, |
136 | | - has_residual: bool = True, |
137 | | - sliding_window: int = 0): |
138 | | - |
| 130 | + def __init__( |
| 131 | + self, |
| 132 | + model: ExLlamaV2, |
| 133 | + key: str, |
| 134 | + layer_idx: int, |
| 135 | + has_norm: bool = True, |
| 136 | + has_residual: bool = True, |
| 137 | + sliding_window: int = 0 |
| 138 | + ): |
139 | 139 | super().__init__(model, key) |
140 | 140 |
|
141 | 141 | cfg = self.model.config |
@@ -1312,14 +1312,17 @@ def forward_tp_old( |
1312 | 1312 | return hidden_states |
1313 | 1313 |
|
1314 | 1314 |
|
1315 | | - def forward_torch(self, |
1316 | | - hidden_states: torch.Tensor, |
1317 | | - cache: ExLlamaV2CacheBase | None = None, |
1318 | | - attn_params: ExLlamaV2Attention.Params | None = None, |
1319 | | - past_len: int | None = None, |
1320 | | - intermediates: bool = False, |
1321 | | - loras: list[ExLlamaV2Lora] | None = None, |
1322 | | - **kwargs) -> torch.Tensor | dict: |
| 1315 | + def forward_torch( |
| 1316 | + self, |
| 1317 | + hidden_states: torch.Tensor, |
| 1318 | + cache: ExLlamaV2CacheBase | None = None, |
| 1319 | + attn_params: ExLlamaV2Attention.Params | None = None, |
| 1320 | + past_len: int | None = None, |
| 1321 | + intermediates: bool = False, |
| 1322 | + loras: list[ExLlamaV2Lora] | None = None, |
| 1323 | + **kwargs |
| 1324 | + ) -> torch.Tensor | dict: |
| 1325 | + |
1323 | 1326 | global has_flash_attn |
1324 | 1327 | global has_xformers |
1325 | 1328 |
|
|