Skip to content

Commit 6a22614

Browse files
authored
improve doc loading
multiprocess OCR check for pdfs and general streamline checks
1 parent a6af078 commit 6a22614

File tree

3 files changed

+139
-95
lines changed

3 files changed

+139
-95
lines changed

src/database_interactions.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def prepare_encode_kwargs(self):
220220
return encode_kwargs
221221

222222

223-
class InflyEmbedding(BaseEmbeddingModel):
223+
class InflyAndAlibabaEmbedding(BaseEmbeddingModel):
224224
def prepare_kwargs(self):
225225
# 1) inherit all kwargs from the base class
226226
infly_kwargs = super().prepare_kwargs()
@@ -330,7 +330,7 @@ def initialize_vector_model(self, embedding_model_name, config_data):
330330
model = SnowflakeEmbedding(embedding_model_name, model_kwargs, encode_kwargs).create()
331331
elif "alibaba" in embedding_model_name.lower():
332332
logger.debug("Matched Alibaba condition")
333-
model = InflyEmbedding(embedding_model_name, model_kwargs, encode_kwargs).create()
333+
model = InflyAndAlibabaEmbedding(embedding_model_name, model_kwargs, encode_kwargs).create()
334334
elif "400m" in embedding_model_name.lower():
335335
logger.debug("Matched Stella 400m condition")
336336
model = Stella400MEmbedding(embedding_model_name, model_kwargs, encode_kwargs).create()
@@ -342,7 +342,7 @@ def initialize_vector_model(self, embedding_model_name, config_data):
342342
model = BgeCodeEmbedding(embedding_model_name, model_kwargs, encode_kwargs).create()
343343
elif "infly" in embedding_model_name.lower():
344344
logger.debug("Matches infly condition")
345-
model = InflyEmbedding(embedding_model_name, model_kwargs, encode_kwargs).create()
345+
model = InflyAndAlibabaEmbedding(embedding_model_name, model_kwargs, encode_kwargs).create()
346346
else:
347347
logger.debug("No conditions matched - using base model")
348348
model = BaseEmbeddingModel(embedding_model_name, model_kwargs, encode_kwargs).create()
@@ -647,13 +647,13 @@ def initialize_vector_model(self):
647647
if "snowflake" in mp_lower:
648648
embeddings = SnowflakeEmbedding(model_path, model_kwargs, encode_kwargs, is_query=True).create()
649649
elif "alibaba" in mp_lower:
650-
embeddings = InflyEmbedding(model_path, model_kwargs, encode_kwargs, is_query=True).create()
650+
embeddings = InflyAndAlibabaEmbedding(model_path, model_kwargs, encode_kwargs, is_query=True).create()
651651
elif "400m" in mp_lower:
652652
embeddings = Stella400MEmbedding(model_path, model_kwargs, encode_kwargs, is_query=True).create()
653653
elif "stella_en_1.5b_v5" in mp_lower:
654654
embeddings = StellaEmbedding(model_path, model_kwargs, encode_kwargs, is_query=True).create()
655655
elif "infly" in mp_lower:
656-
embeddings = InflyEmbedding(model_path, model_kwargs, encode_kwargs, is_query=True).create()
656+
embeddings = InflyAndAlibabaEmbedding(model_path, model_kwargs, encode_kwargs, is_query=True).create()
657657
elif "bge-code" in mp_lower:
658658
embeddings = BgeCodeEmbedding(model_path, model_kwargs, encode_kwargs, is_query=True).create()
659659
else:

src/gui_tabs_databases.py

Lines changed: 64 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,36 +14,57 @@
1414
from download_model import model_downloaded_signal
1515
from constants import TOOLTIPS
1616

17+
1718
class CreateDatabaseProcess:
1819
def __init__(self, database_name, parent=None):
1920
self.database_name = database_name
2021
self.process = None
22+
2123
def start(self):
2224
self.process = multiprocessing.Process(target=create_vector_db_in_process, args=(self.database_name,))
2325
self.process.start()
26+
2427
def wait(self):
2528
if self.process:
2629
self.process.join()
30+
2731
def is_alive(self):
2832
if self.process:
2933
return self.process.is_alive()
3034
return False
35+
36+
3137
class CreateDatabaseThread(QThread):
3238
creationComplete = Signal()
33-
def __init__(self, database_name, model_name, parent=None):
39+
validationFailed = Signal(str)
40+
41+
def __init__(self, database_name, model_name, skip_ocr, parent=None):
3442
super().__init__(parent)
3543
self.database_name = database_name
3644
self.model_name = model_name
45+
self.skip_ocr = skip_ocr
3746
self.process = None
47+
3848
def run(self):
39-
self.process = multiprocessing.Process(target=create_vector_db_in_process, args=(self.database_name,))
49+
script_dir = Path(__file__).resolve().parent
50+
ok, msg = check_preconditions_for_db_creation(script_dir,
51+
self.database_name,
52+
skip_ocr=self.skip_ocr)
53+
if not ok:
54+
self.validationFailed.emit(msg)
55+
return
56+
57+
self.process = multiprocessing.Process(
58+
target=create_vector_db_in_process,
59+
args=(self.database_name,))
4060
self.process.start()
4161
self.process.join()
4262
my_cprint(f"{self.model_name} removed from memory.", "red")
4363
self.creationComplete.emit()
4464
time.sleep(.2)
4565
self.update_config_with_database_name()
4666
backup_database_incremental(self.database_name)
67+
4768
def update_config_with_database_name(self):
4869
config_path = Path(__file__).resolve().parent / "config.yaml"
4970
if config_path.exists():
@@ -61,10 +82,14 @@ def update_config_with_database_name(self):
6182
}
6283
with open(config_path, 'w', encoding='utf-8') as file:
6384
yaml.safe_dump(config, file, allow_unicode=True)
85+
86+
6487
class CustomFileSystemModel(QFileSystemModel):
6588
def __init__(self, parent=None):
6689
super().__init__(parent)
6790
self.setFilter(QDir.Files)
91+
92+
6893
class DatabasesTab(QWidget):
6994
def __init__(self):
7095
super().__init__()
@@ -103,16 +128,23 @@ def __init__(self):
103128
self.layout.addLayout(grid_layout_top_buttons)
104129
self.layout.addLayout(hbox2)
105130
self.sync_combobox_with_config()
131+
132+
def _validation_failed(self, message: str):
133+
QMessageBox.warning(self, "Validation Failed", message)
134+
self._reenable_widgets()
135+
106136
def refresh_model_combobox(self, index):
107137
current_text = self.model_combobox.currentText()
108138
self.populate_model_combobox()
109139
idx = self.model_combobox.findText(current_text)
110140
if idx >= 0:
111141
self.model_combobox.setCurrentIndex(idx)
142+
112143
def update_model_combobox(self, model_name, model_type):
113144
if model_type == "vector":
114145
self.populate_model_combobox()
115146
self.sync_combobox_with_config()
147+
116148
def populate_model_combobox(self):
117149
self.model_combobox.clear()
118150
self.model_combobox.addItem("Select a model", None)
@@ -125,6 +157,7 @@ def populate_model_combobox(self):
125157
display_name = folder.name
126158
full_path = str(folder)
127159
self.model_combobox.addItem(display_name, full_path)
160+
128161
def sync_combobox_with_config(self):
129162
config_path = Path(__file__).resolve().parent / "config.yaml"
130163
if config_path.exists():
@@ -141,6 +174,7 @@ def sync_combobox_with_config(self):
141174
self.model_combobox.setCurrentIndex(0)
142175
else:
143176
self.model_combobox.setCurrentIndex(0)
177+
144178
def on_model_selected(self, index):
145179
selected_path = self.model_combobox.itemData(index)
146180
config_path = Path(__file__).resolve().parent / "config.yaml"
@@ -165,6 +199,7 @@ def on_model_selected(self, index):
165199
config_data.pop("EMBEDDING_MODEL_DIMENSIONS", None)
166200
with open(config_path, 'w', encoding='utf-8') as file:
167201
yaml.safe_dump(config_data, file, allow_unicode=True)
202+
168203
def create_group_box(self, title, directory_name):
169204
group_box = QGroupBox(title)
170205
layout = QVBoxLayout()
@@ -174,11 +209,13 @@ def create_group_box(self, title, directory_name):
174209
self.layout.addWidget(group_box)
175210
group_box.toggled.connect(lambda checked, gb=group_box: self.toggle_group_box(gb, checked))
176211
return group_box
212+
177213
def _refresh_docs_model(self):
178214
if hasattr(self.docs_model, 'refresh'):
179215
self.docs_model.refresh()
180216
elif hasattr(self.docs_model, 'reindex'):
181217
self.docs_model.reindex()
218+
182219
def setup_directory_view(self, directory_name):
183220
tree_view = QTreeView()
184221
model = CustomFileSystemModel()
@@ -201,57 +238,71 @@ def setup_directory_view(self, directory_name):
201238
self.docs_refresh.setInterval(500)
202239
self.docs_refresh.timeout.connect(self._refresh_docs_model)
203240
return tree_view
241+
204242
def on_double_click(self, index):
205243
tree_view = self.sender()
206244
model = tree_view.model()
207245
file_path = model.filePath(index)
208246
open_file(file_path)
247+
209248
def on_context_menu(self, point):
210249
tree_view = self.sender()
211250
context_menu = QMenu(self)
212251
delete_action = QAction("Delete File", self)
213252
context_menu.addAction(delete_action)
214253
delete_action.triggered.connect(lambda: self.on_delete_file(tree_view))
215254
context_menu.exec_(tree_view.viewport().mapToGlobal(point))
255+
216256
def on_delete_file(self, tree_view):
217257
selected_indexes = tree_view.selectedIndexes()
218258
model = tree_view.model()
219259
for index in selected_indexes:
220260
if index.column() == 0:
221261
file_path = model.filePath(index)
222262
delete_file(file_path)
263+
223264
def on_create_db_clicked(self):
224265
if self.model_combobox.currentIndex() == 0:
225266
QMessageBox.warning(self, "No Model Selected", "Please select a model before creating a database.")
226267
return
268+
227269
self.create_db_button.setDisabled(True)
228270
self.choose_docs_button.setDisabled(True)
229271
self.model_combobox.setDisabled(True)
230272
self.database_name_input.setDisabled(True)
273+
231274
database_name = self.database_name_input.text().strip()
232-
model_name = self.model_combobox.currentText()
233-
script_dir = Path(__file__).resolve().parent
234-
checks_passed, message = check_preconditions_for_db_creation(script_dir, database_name)
235-
if not checks_passed:
236-
self.create_db_button.setDisabled(False)
237-
self.choose_docs_button.setDisabled(False)
238-
self.model_combobox.setDisabled(False)
239-
self.database_name_input.setDisabled(False)
240-
QMessageBox.warning(self, "Validation Failed", message)
241-
return
242-
self.create_database_thread = CreateDatabaseThread(database_name=database_name, model_name=model_name, parent=self)
275+
model_name = self.model_combobox.currentText()
276+
277+
docs_dir = Path(__file__).resolve().parent / "Docs_for_DB"
278+
has_pdfs = any(p.suffix.lower() == ".pdf" for p in docs_dir.iterdir() if p.is_file())
279+
skip_ocr = False
280+
if has_pdfs:
281+
reply = QMessageBox.question(self, "OCR Check",
282+
"PDF files detected. Do you want to check if any of the PDFs need OCR? "
283+
"If there are a lot of PDFs, it is time-consuming but strongly recommended.",
284+
QMessageBox.Yes | QMessageBox.No,
285+
QMessageBox.Yes)
286+
skip_ocr = (reply == QMessageBox.No)
287+
288+
self.create_database_thread = CreateDatabaseThread(database_name, model_name, skip_ocr, parent=self)
289+
243290
self.create_database_thread.creationComplete.connect(self.reenable_create_db_button)
291+
self.create_database_thread.validationFailed.connect(self._validation_failed)
244292
self.create_database_thread.start()
293+
245294
def reenable_create_db_button(self):
246295
self.create_db_button.setDisabled(False)
247296
self.choose_docs_button.setDisabled(False)
248297
self.model_combobox.setDisabled(False)
249298
self.database_name_input.setDisabled(False)
250299
self.create_database_thread = None
251300
gc.collect()
301+
252302
def toggle_group_box(self, group_box, checked):
253303
self.groups[group_box] = 1 if checked else 0
254304
self.adjust_stretch()
305+
255306
def adjust_stretch(self):
256307
for group, stretch in self.groups.items():
257308
self.layout.setStretchFactor(group, stretch if group.isChecked() else 0)

0 commit comments

Comments
 (0)