@@ -22,21 +22,22 @@ def __init__(self, parent=None):
2222 self .model_radiobuttons .setExclusive (True )
2323 self .stretch_factors = {
2424 'BAAI' : 4 ,
25- # 'NovaSearch': 3,
2625 'intfloat' : 4 ,
27- # 'Alibaba-NLP': 2,
28- # 'Google': 2,
2926 'IBM' : 3 ,
3027 'infly' : 3 ,
3128 'Snowflake' : 3 ,
32- 'Qwen' : 4
29+ 'Qwen' : 4 ,
30+ 'Google' : 2 , # Added if you use it
3331 }
3432
3533 models_dir = Path ('Models' )
3634 if not models_dir .exists ():
3735 models_dir .mkdir (parents = True )
3836
37+ # Use consistent lowercase 'vector'
3938 vector_models_dir = models_dir / "vector"
39+ if not vector_models_dir .exists ():
40+ vector_models_dir .mkdir (parents = True )
4041
4142 existing_vector_directories = {d .name for d in vector_models_dir .iterdir () if d .is_dir ()}
4243
@@ -63,7 +64,7 @@ def add_centered_widget(grid, widget, row, col):
6364 group_layout .setVerticalSpacing (0 )
6465 group_layout .setHorizontalSpacing (0 )
6566 group_box .setLayout (group_layout )
66- group_layout .setContentsMargins (0 , 10 , 0 , 0 ) # left, top, right, bottom
67+ group_layout .setContentsMargins (0 , 10 , 0 , 0 )
6768
6869 size_policy = group_box .sizePolicy ()
6970 size_policy .setVerticalStretch (self .stretch_factors .get (vendor , 1 ))
@@ -87,7 +88,7 @@ def add_centered_widget(grid, widget, row, col):
8788 row = grid .rowCount ()
8889
8990 radiobutton = QRadioButton ()
90- radiobutton .setToolTip (TOOLTIPS [ "VECTOR_MODEL_SELECT" ] )
91+ radiobutton .setToolTip (TOOLTIPS . get ( "VECTOR_MODEL_SELECT" , "" ) )
9192 self .model_radiobuttons .addButton (radiobutton , row_counter )
9293 add_centered_widget (grid , radiobutton , row , 0 )
9394
@@ -96,45 +97,50 @@ def add_centered_widget(grid, widget, row, col):
9697 model_name_label .setText (f'<a style="color: #00bf9e" href="https://huggingface.co/{ model ["repo_id" ]} ">{ model ["name" ]} </a>' )
9798 model_name_label .setOpenExternalLinks (False )
9899 model_name_label .linkActivated .connect (self .open_link )
99- model_name_label .setToolTip (TOOLTIPS [ "VECTOR_MODEL_NAME" ] )
100+ model_name_label .setToolTip (TOOLTIPS . get ( "VECTOR_MODEL_NAME" , "" ) )
100101 add_centered_widget (grid , model_name_label , row , 1 )
101102
102103 precision_label = QLabel (str (model .get ('precision' , 'N/A' )))
103- precision_label .setToolTip (TOOLTIPS [ "VECTOR_MODEL_PRECISION" ] )
104+ precision_label .setToolTip (TOOLTIPS . get ( "VECTOR_MODEL_PRECISION" , "" ) )
104105 add_centered_widget (grid , precision_label , row , 2 )
105106
106107 parameters_label = QLabel (str (model .get ('parameters' , 'N/A' )))
107108 parameters_label .setToolTip (TOOLTIPS .get ("VECTOR_MODEL_PARAMETERS" , "" ))
108109 add_centered_widget (grid , parameters_label , row , 3 )
109110
110111 dimensions_label = QLabel (str (model ['dimensions' ]))
111- dimensions_label .setToolTip (TOOLTIPS [ "VECTOR_MODEL_DIMENSIONS" ] )
112+ dimensions_label .setToolTip (TOOLTIPS . get ( "VECTOR_MODEL_DIMENSIONS" , "" ) )
112113 add_centered_widget (grid , dimensions_label , row , 4 )
113114
114115 max_sequence_label = QLabel (str (model ['max_sequence' ]))
115- max_sequence_label .setToolTip (TOOLTIPS [ "VECTOR_MODEL_MAX_SEQUENCE" ] )
116+ max_sequence_label .setToolTip (TOOLTIPS . get ( "VECTOR_MODEL_MAX_SEQUENCE" , "" ) )
116117 add_centered_widget (grid , max_sequence_label , row , 5 )
117118
118119 size_label = QLabel (str (model ['size_mb' ]))
119- size_label .setToolTip (TOOLTIPS [ "VECTOR_MODEL_SIZE" ] )
120+ size_label .setToolTip (TOOLTIPS . get ( "VECTOR_MODEL_SIZE" , "" ) )
120121 add_centered_widget (grid , size_label , row , 6 )
121122
122- expected_dir_name = ModelDownloader (model_info , model ['type' ]).get_model_directory_name ()
123+ # Use cache_dir if it exists, otherwise generate from repo_id
124+ if 'cache_dir' in model :
125+ expected_dir_name = model ['cache_dir' ]
126+ else :
127+ expected_dir_name = ModelDownloader (model_info , model ['type' ]).get_model_directory_name ()
128+
123129 is_downloaded = expected_dir_name in existing_vector_directories
124130 downloaded_label = QLabel ('Yes' if is_downloaded else 'No' )
125- downloaded_label .setToolTip (TOOLTIPS [ "VECTOR_MODEL_DOWNLOADED" ] )
131+ downloaded_label .setToolTip (TOOLTIPS . get ( "VECTOR_MODEL_DOWNLOADED" , "" ) )
126132 add_centered_widget (grid , downloaded_label , row , 7 )
127133 radiobutton .setEnabled (not is_downloaded )
128134
129- self .downloaded_labels [f"{ vendor } /{ model ['name' ]} " ] = (downloaded_label , model_info )
135+ self .downloaded_labels [f"{ vendor } /{ model ['name' ]} " ] = (downloaded_label , model_info , radiobutton )
130136
131137 row_counter += 1
132138
133139 for vendor , group_box in self .group_boxes .items ():
134140 self .main_layout .addWidget (group_box )
135141
136142 self .download_button = QPushButton ('Download Selected Model' )
137- self .download_button .setToolTip (TOOLTIPS [ "DOWNLOAD_MODEL" ] )
143+ self .download_button .setToolTip (TOOLTIPS . get ( "DOWNLOAD_MODEL" , "" ) )
138144 self .download_button .clicked .connect (self .initiate_model_download )
139145 self .main_layout .addWidget (self .download_button )
140146
@@ -143,28 +149,31 @@ def add_centered_widget(grid, widget, row, col):
143149 def initiate_model_download (self ):
144150 selected_id = self .model_radiobuttons .checkedId ()
145151 if selected_id != - 1 :
146- _ , ( _ , model_info ) = list (self .downloaded_labels .items ())[selected_id - 1 ]
152+ _ , model_info , _ = list (self .downloaded_labels .values ())[selected_id - 1 ]
147153 model_downloader = ModelDownloader (model_info , model_info ['type' ])
148154
149155 download_thread = threading .Thread (target = lambda : model_downloader .download ())
150156 download_thread .start ()
151157
152158 def update_model_downloaded_status (self , model_name , model_type ):
153159 models_dir = Path ('Models' )
154- vector_models_dir = models_dir / "Vector"
160+ # Use consistent lowercase 'vector'
161+ vector_models_dir = models_dir / "vector"
155162
156163 existing_vector_directories = {d .name for d in vector_models_dir .iterdir () if d .is_dir ()}
157164
158165 for vendor , models in VECTOR_MODELS .items ():
159166 for model in models :
160- if model ['cache_dir' ] == model_name :
161- downloaded_label , _ = self .downloaded_labels .get (f"{ vendor } /{ model ['name' ]} " , (None , None ))
162- if downloaded_label :
167+ # Check both cache_dir and the generated directory name
168+ cache_dir = model .get ('cache_dir' , '' )
169+ generated_dir = model ['repo_id' ].replace ('/' , '--' )
170+
171+ if cache_dir == model_name or generated_dir == model_name :
172+ key = f"{ vendor } /{ model ['name' ]} "
173+ if key in self .downloaded_labels :
174+ downloaded_label , _ , radiobutton = self .downloaded_labels [key ]
163175 downloaded_label .setText ('Yes' )
164- for button in self .model_radiobuttons .buttons ():
165- if button .text () == model ['name' ]:
166- button .setEnabled (False )
167- break
176+ radiobutton .setEnabled (False )
168177 self .refresh_gui ()
169178 return
170179
0 commit comments