1414from download_model import model_downloaded_signal
1515from constants import TOOLTIPS
1616
17+
class CreateDatabaseProcess:
    """Thin handle around a child process that builds one vector database."""

    def __init__(self, database_name, parent=None):
        # ``parent`` is accepted for call-site compatibility but is not stored.
        self.process = None
        self.database_name = database_name

    def start(self):
        """Spawn the child process that runs ``create_vector_db_in_process``."""
        worker = multiprocessing.Process(
            target=create_vector_db_in_process,
            args=(self.database_name,),
        )
        self.process = worker
        worker.start()

    def wait(self):
        """Block until the child process exits; do nothing if never started."""
        if not self.process:
            return
        self.process.join()

    def is_alive(self):
        """Return whether the child process is running (False if never started)."""
        return self.process.is_alive() if self.process else False
35+
36+
3137class CreateDatabaseThread (QThread ):
3238 creationComplete = Signal ()
33- def __init__ (self , database_name , model_name , parent = None ):
39+ validationFailed = Signal (str )
40+
def __init__(self, database_name, model_name, skip_ocr, parent=None):
    """Store the job parameters; the worker process is created later in run().

    Args:
        database_name: Name of the database to create.
        model_name: Display name of the embedding model (used for logging).
        skip_ocr: Forwarded to the precondition check as ``skip_ocr``.
        parent: Optional Qt parent object.
    """
    super().__init__(parent)
    self.process = None
    self.skip_ocr = skip_ocr
    self.model_name = model_name
    self.database_name = database_name
47+
def run(self):
    """Validate preconditions, then build the database in a child process.

    Emits ``validationFailed`` with the failure message and returns early
    when the precondition check does not pass. Otherwise runs the creation
    process to completion, emits ``creationComplete``, and afterwards
    updates the config file and triggers an incremental backup.
    """
    base_dir = Path(__file__).resolve().parent
    passed, reason = check_preconditions_for_db_creation(
        base_dir,
        self.database_name,
        skip_ocr=self.skip_ocr,
    )
    if not passed:
        self.validationFailed.emit(reason)
        return

    worker = multiprocessing.Process(
        target=create_vector_db_in_process,
        args=(self.database_name,),
    )
    self.process = worker
    worker.start()
    worker.join()

    my_cprint(f"{self.model_name} removed from memory.", "red")
    self.creationComplete.emit()
    # Short pause between signalling completion and the follow-up
    # config update / backup; order is kept exactly as before.
    time.sleep(0.2)
    self.update_config_with_database_name()
    backup_database_incremental(self.database_name)
67+
4768 def update_config_with_database_name (self ):
4869 config_path = Path (__file__ ).resolve ().parent / "config.yaml"
4970 if config_path .exists ():
@@ -61,10 +82,14 @@ def update_config_with_database_name(self):
6182 }
6283 with open (config_path , 'w' , encoding = 'utf-8' ) as file :
6384 yaml .safe_dump (config , file , allow_unicode = True )
85+
86+
class CustomFileSystemModel(QFileSystemModel):
    """File-system model whose filter is set to ``QDir.Files``."""

    def __init__(self, parent=None):
        super().__init__(parent)
        files_only = QDir.Files
        self.setFilter(files_only)
91+
92+
6893class DatabasesTab (QWidget ):
6994 def __init__ (self ):
7095 super ().__init__ ()
@@ -103,16 +128,23 @@ def __init__(self):
103128 self .layout .addLayout (grid_layout_top_buttons )
104129 self .layout .addLayout (hbox2 )
105130 self .sync_combobox_with_config ()
131+
def _validation_failed(self, message: str):
    """Slot for the worker thread's ``validationFailed`` signal.

    Shows the validation message to the user and then restores the
    controls that were disabled when database creation was requested.

    Args:
        message: Human-readable reason the precondition check failed.
    """
    QMessageBox.warning(self, "Validation Failed", message)
    # No ``_reenable_widgets`` method is defined in the visible class;
    # reuse the slot that already re-enables the controls and clears
    # the thread reference.
    self.reenable_create_db_button()
135+
def refresh_model_combobox(self, index):
    """Repopulate the model combobox, keeping the current selection if possible.

    ``index`` is accepted for signal compatibility and is not used.
    """
    previous_selection = self.model_combobox.currentText()
    self.populate_model_combobox()
    position = self.model_combobox.findText(previous_selection)
    if position < 0:
        # The previously selected entry no longer exists; leave as populated.
        return
    self.model_combobox.setCurrentIndex(position)
142+
def update_model_combobox(self, model_name, model_type):
    """Refresh the combobox when the reported model type is ``"vector"``.

    ``model_name`` is accepted for signal compatibility and is not used.
    """
    if model_type != "vector":
        return
    self.populate_model_combobox()
    self.sync_combobox_with_config()
147+
116148 def populate_model_combobox (self ):
117149 self .model_combobox .clear ()
118150 self .model_combobox .addItem ("Select a model" , None )
@@ -125,6 +157,7 @@ def populate_model_combobox(self):
125157 display_name = folder .name
126158 full_path = str (folder )
127159 self .model_combobox .addItem (display_name , full_path )
160+
128161 def sync_combobox_with_config (self ):
129162 config_path = Path (__file__ ).resolve ().parent / "config.yaml"
130163 if config_path .exists ():
@@ -141,6 +174,7 @@ def sync_combobox_with_config(self):
141174 self .model_combobox .setCurrentIndex (0 )
142175 else :
143176 self .model_combobox .setCurrentIndex (0 )
177+
144178 def on_model_selected (self , index ):
145179 selected_path = self .model_combobox .itemData (index )
146180 config_path = Path (__file__ ).resolve ().parent / "config.yaml"
@@ -165,6 +199,7 @@ def on_model_selected(self, index):
165199 config_data .pop ("EMBEDDING_MODEL_DIMENSIONS" , None )
166200 with open (config_path , 'w' , encoding = 'utf-8' ) as file :
167201 yaml .safe_dump (config_data , file , allow_unicode = True )
202+
168203 def create_group_box (self , title , directory_name ):
169204 group_box = QGroupBox (title )
170205 layout = QVBoxLayout ()
@@ -174,11 +209,13 @@ def create_group_box(self, title, directory_name):
174209 self .layout .addWidget (group_box )
175210 group_box .toggled .connect (lambda checked , gb = group_box : self .toggle_group_box (gb , checked ))
176211 return group_box
212+
177213 def _refresh_docs_model (self ):
178214 if hasattr (self .docs_model , 'refresh' ):
179215 self .docs_model .refresh ()
180216 elif hasattr (self .docs_model , 'reindex' ):
181217 self .docs_model .reindex ()
218+
182219 def setup_directory_view (self , directory_name ):
183220 tree_view = QTreeView ()
184221 model = CustomFileSystemModel ()
@@ -201,57 +238,71 @@ def setup_directory_view(self, directory_name):
201238 self .docs_refresh .setInterval (500 )
202239 self .docs_refresh .timeout .connect (self ._refresh_docs_model )
203240 return tree_view
241+
def on_double_click(self, index):
    """Open the file behind the double-clicked item of the sending view."""
    view = self.sender()
    open_file(view.model().filePath(index))
247+
def on_context_menu(self, point):
    """Show a context menu with a "Delete File" entry at ``point``.

    ``point`` is in the sending view's viewport coordinates and is mapped
    to global coordinates before the menu is executed.
    """
    view = self.sender()

    menu = QMenu(self)
    remove_entry = QAction("Delete File", self)
    menu.addAction(remove_entry)
    remove_entry.triggered.connect(lambda: self.on_delete_file(view))

    global_pos = view.viewport().mapToGlobal(point)
    menu.exec_(global_pos)
255+
def on_delete_file(self, tree_view):
    """Delete every file currently selected in ``tree_view``.

    Only column-0 indexes are considered, so a row selected across several
    columns is handled once.
    """
    selection = tree_view.selectedIndexes()
    fs_model = tree_view.model()
    for entry in (idx for idx in selection if idx.column() == 0):
        delete_file(fs_model.filePath(entry))
263+
def on_create_db_clicked(self):
    """Start database creation in a background thread.

    Requires a model to be selected (combobox index 0 is the placeholder).
    Disables the input controls, optionally asks whether PDFs should be
    checked for OCR, and launches ``CreateDatabaseThread``. The controls
    are re-enabled by the thread's ``creationComplete`` signal or by its
    ``validationFailed`` signal.
    """
    if self.model_combobox.currentIndex() == 0:
        QMessageBox.warning(self, "No Model Selected", "Please select a model before creating a database.")
        return

    self.create_db_button.setDisabled(True)
    self.choose_docs_button.setDisabled(True)
    self.model_combobox.setDisabled(True)
    self.database_name_input.setDisabled(True)

    database_name = self.database_name_input.text().strip()
    model_name = self.model_combobox.currentText()

    docs_dir = Path(__file__).resolve().parent / "Docs_for_DB"
    # Path.iterdir() raises if the folder is missing, which would abort
    # this slot with the controls above still disabled. Treat a missing
    # folder as "no PDFs" and continue to the worker thread.
    # NOTE(review): the thread's precondition check presumably reports
    # a missing or empty folder - confirm.
    has_pdfs = docs_dir.is_dir() and any(
        p.suffix.lower() == ".pdf" for p in docs_dir.iterdir() if p.is_file()
    )
    skip_ocr = False
    if has_pdfs:
        reply = QMessageBox.question(
            self,
            "OCR Check",
            "PDF files detected. Do you want to check if any of the PDFs need OCR? "
            "If there are a lot of PDFs, it is time-consuming but strongly recommended.",
            QMessageBox.Yes | QMessageBox.No,
            QMessageBox.Yes,
        )
        # Answering "No" means the OCR check is skipped.
        skip_ocr = (reply == QMessageBox.No)

    self.create_database_thread = CreateDatabaseThread(database_name, model_name, skip_ocr, parent=self)

    self.create_database_thread.creationComplete.connect(self.reenable_create_db_button)
    self.create_database_thread.validationFailed.connect(self._validation_failed)
    self.create_database_thread.start()
293+
def reenable_create_db_button(self):
    """Re-enable the creation controls and drop the worker thread reference."""
    controls = (
        self.create_db_button,
        self.choose_docs_button,
        self.model_combobox,
        self.database_name_input,
    )
    for control in controls:
        control.setDisabled(False)

    self.create_database_thread = None
    gc.collect()
301+
def toggle_group_box(self, group_box, checked):
    """Record the group box's stretch (1 if checked, else 0) and re-apply stretches."""
    if checked:
        stretch = 1
    else:
        stretch = 0
    self.groups[group_box] = stretch
    self.adjust_stretch()
305+
def adjust_stretch(self):
    """Apply each group's stored stretch factor, using 0 for unchecked groups."""
    for box in self.groups:
        factor = self.groups[box] if box.isChecked() else 0
        self.layout.setStretchFactor(box, factor)
0 commit comments