|
22 | 22 | from PySide6.QtWidgets import ( |
23 | 23 | QMainWindow, QWidget, QVBoxLayout, QTextEdit, |
24 | 24 | QLineEdit, QMessageBox, QPushButton, QLabel, |
25 | | - QHBoxLayout, QSizePolicy, QComboBox, QApplication |
| 25 | + QHBoxLayout, QSizePolicy, QComboBox, QApplication, QSpinBox |
26 | 26 | ) |
27 | 27 | from PySide6.QtCore import QThread, Signal, Qt, QTimer, QObject |
28 | 28 | from PySide6.QtGui import QTextCursor, QPixmap |
|
38 | 38 | from module_kokoro import KokoroTTS |
39 | 39 | from utilities import normalize_chat_text |
40 | 40 |
|
| 41 | + |
41 | 42 | class GenerationWorker(QThread): |
42 | 43 | token_signal = Signal(str) |
43 | 44 | finished_signal = Signal() |
@@ -130,6 +131,17 @@ def __init__(self, parent=None): |
130 | 131 | self.eject_button.setEnabled(False) |
131 | 132 | model_layout.addWidget(self.eject_button) |
132 | 133 |
|
| 134 | + self.context_label = QLabel("Contexts:") |
| 135 | + self.context_label.setFixedHeight(30) |
| 136 | + |
| 137 | + self.context_spin = QSpinBox() |
| 138 | + self.context_spin.setRange(1, 10) # allow 1-20 |
| 139 | + self.context_spin.setValue(5) # default 5 |
| 140 | + self.context_spin.setFixedHeight(30) |
| 141 | + |
| 142 | + model_layout.addWidget(self.context_label) |
| 143 | + model_layout.addWidget(self.context_spin) |
| 144 | + |
133 | 145 | self.layout.addLayout(model_layout) |
134 | 146 |
|
135 | 147 | self.chat_display = QTextEdit() |
@@ -226,6 +238,16 @@ def __init__(self, parent=None): |
226 | 238 | self.tts_worker = None |
227 | 239 | self.is_speaking = False |
228 | 240 |
|
| 241 | + def _ensure_model(self) -> None: |
| 242 | + """ |
| 243 | + Download or resume-download the model if *model.bin* is missing. |
| 244 | + (Keeps everything else that is already in the cache.) |
| 245 | + """ |
| 246 | + model_dir = Path(self.model_dir) |
| 247 | + if not (model_dir / "model.bin").exists(): |
| 248 | + print("model.bin missing – redownloading just that file …") |
| 249 | + self._download_model() |
| 250 | + |
229 | 251 | def eject_model(self): |
230 | 252 | if self.generator: |
231 | 253 | del self.generator |
@@ -294,6 +316,7 @@ def on_model_downloaded(self, model_name, model_type): |
294 | 316 | self._load_model() |
295 | 317 |
|
296 | 318 | def _load_model(self): |
| 319 | + self._ensure_model() |
297 | 320 | physical_cores = max(1, psutil.cpu_count(logical=False) - 1) |
298 | 321 | device = "cuda" if torch.cuda.is_available() else "cpu" |
299 | 322 |
|
@@ -361,6 +384,7 @@ def send_message(self): |
361 | 384 | self.chat_display.clear() |
362 | 385 |
|
363 | 386 | try: |
| 387 | + k_value = self.context_spin.value() |
364 | 388 | contexts, metadata = self.vector_db.search(user_message, k=5, score_threshold=0.9) |
365 | 389 | if not contexts: |
366 | 390 | QMessageBox.warning(self, "No Contexts Found", "No relevant contexts were found for your query.") |
|
0 commit comments