Skip to content

Commit 5bc56e0

Browse files
authored
fix model download status
1 parent de19a94 commit 5bc56e0

File tree

2 files changed

+35
-25
lines changed

2 files changed

+35
-25
lines changed

src/download_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ def cleanup_incomplete_download(self):
5959
def get_model_directory_name(self):
6060
repo_id = self.get_model_url()
6161
if isinstance(repo_id, str):
62-
return repo_id.replace("/", "_")
62+
# Use double dash to match the cache_dir format in constants
63+
return repo_id.replace("/", "--")
6364
return str(repo_id)
6465

6566
def get_model_directory(self):

src/gui_tabs_models.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,22 @@ def __init__(self, parent=None):
2222
self.model_radiobuttons.setExclusive(True)
2323
self.stretch_factors = {
2424
'BAAI': 4,
25-
# 'NovaSearch': 3,
2625
'intfloat': 4,
27-
# 'Alibaba-NLP': 2,
28-
# 'Google': 2,
2926
'IBM': 3,
3027
'infly': 3,
3128
'Snowflake': 3,
32-
'Qwen': 4
29+
'Qwen': 4,
30+
'Google': 2, # Added if you use it
3331
}
3432

3533
models_dir = Path('Models')
3634
if not models_dir.exists():
3735
models_dir.mkdir(parents=True)
3836

37+
# Use consistent lowercase 'vector'
3938
vector_models_dir = models_dir / "vector"
39+
if not vector_models_dir.exists():
40+
vector_models_dir.mkdir(parents=True)
4041

4142
existing_vector_directories = {d.name for d in vector_models_dir.iterdir() if d.is_dir()}
4243

@@ -63,7 +64,7 @@ def add_centered_widget(grid, widget, row, col):
6364
group_layout.setVerticalSpacing(0)
6465
group_layout.setHorizontalSpacing(0)
6566
group_box.setLayout(group_layout)
66-
group_layout.setContentsMargins(0, 10, 0, 0) # left, top, right, bottom
67+
group_layout.setContentsMargins(0, 10, 0, 0)
6768

6869
size_policy = group_box.sizePolicy()
6970
size_policy.setVerticalStretch(self.stretch_factors.get(vendor, 1))
@@ -87,7 +88,7 @@ def add_centered_widget(grid, widget, row, col):
8788
row = grid.rowCount()
8889

8990
radiobutton = QRadioButton()
90-
radiobutton.setToolTip(TOOLTIPS["VECTOR_MODEL_SELECT"])
91+
radiobutton.setToolTip(TOOLTIPS.get("VECTOR_MODEL_SELECT", ""))
9192
self.model_radiobuttons.addButton(radiobutton, row_counter)
9293
add_centered_widget(grid, radiobutton, row, 0)
9394

@@ -96,45 +97,50 @@ def add_centered_widget(grid, widget, row, col):
9697
model_name_label.setText(f'<a style="color: #00bf9e" href="https://huggingface.co/{model["repo_id"]}">{model["name"]}</a>')
9798
model_name_label.setOpenExternalLinks(False)
9899
model_name_label.linkActivated.connect(self.open_link)
99-
model_name_label.setToolTip(TOOLTIPS["VECTOR_MODEL_NAME"])
100+
model_name_label.setToolTip(TOOLTIPS.get("VECTOR_MODEL_NAME", ""))
100101
add_centered_widget(grid, model_name_label, row, 1)
101102

102103
precision_label = QLabel(str(model.get('precision', 'N/A')))
103-
precision_label.setToolTip(TOOLTIPS["VECTOR_MODEL_PRECISION"])
104+
precision_label.setToolTip(TOOLTIPS.get("VECTOR_MODEL_PRECISION", ""))
104105
add_centered_widget(grid, precision_label, row, 2)
105106

106107
parameters_label = QLabel(str(model.get('parameters', 'N/A')))
107108
parameters_label.setToolTip(TOOLTIPS.get("VECTOR_MODEL_PARAMETERS", ""))
108109
add_centered_widget(grid, parameters_label, row, 3)
109110

110111
dimensions_label = QLabel(str(model['dimensions']))
111-
dimensions_label.setToolTip(TOOLTIPS["VECTOR_MODEL_DIMENSIONS"])
112+
dimensions_label.setToolTip(TOOLTIPS.get("VECTOR_MODEL_DIMENSIONS", ""))
112113
add_centered_widget(grid, dimensions_label, row, 4)
113114

114115
max_sequence_label = QLabel(str(model['max_sequence']))
115-
max_sequence_label.setToolTip(TOOLTIPS["VECTOR_MODEL_MAX_SEQUENCE"])
116+
max_sequence_label.setToolTip(TOOLTIPS.get("VECTOR_MODEL_MAX_SEQUENCE", ""))
116117
add_centered_widget(grid, max_sequence_label, row, 5)
117118

118119
size_label = QLabel(str(model['size_mb']))
119-
size_label.setToolTip(TOOLTIPS["VECTOR_MODEL_SIZE"])
120+
size_label.setToolTip(TOOLTIPS.get("VECTOR_MODEL_SIZE", ""))
120121
add_centered_widget(grid, size_label, row, 6)
121122

122-
expected_dir_name = ModelDownloader(model_info, model['type']).get_model_directory_name()
123+
# Use cache_dir if it exists, otherwise generate from repo_id
124+
if 'cache_dir' in model:
125+
expected_dir_name = model['cache_dir']
126+
else:
127+
expected_dir_name = ModelDownloader(model_info, model['type']).get_model_directory_name()
128+
123129
is_downloaded = expected_dir_name in existing_vector_directories
124130
downloaded_label = QLabel('Yes' if is_downloaded else 'No')
125-
downloaded_label.setToolTip(TOOLTIPS["VECTOR_MODEL_DOWNLOADED"])
131+
downloaded_label.setToolTip(TOOLTIPS.get("VECTOR_MODEL_DOWNLOADED", ""))
126132
add_centered_widget(grid, downloaded_label, row, 7)
127133
radiobutton.setEnabled(not is_downloaded)
128134

129-
self.downloaded_labels[f"{vendor}/{model['name']}"] = (downloaded_label, model_info)
135+
self.downloaded_labels[f"{vendor}/{model['name']}"] = (downloaded_label, model_info, radiobutton)
130136

131137
row_counter += 1
132138

133139
for vendor, group_box in self.group_boxes.items():
134140
self.main_layout.addWidget(group_box)
135141

136142
self.download_button = QPushButton('Download Selected Model')
137-
self.download_button.setToolTip(TOOLTIPS["DOWNLOAD_MODEL"])
143+
self.download_button.setToolTip(TOOLTIPS.get("DOWNLOAD_MODEL", ""))
138144
self.download_button.clicked.connect(self.initiate_model_download)
139145
self.main_layout.addWidget(self.download_button)
140146

@@ -143,28 +149,31 @@ def add_centered_widget(grid, widget, row, col):
143149
def initiate_model_download(self):
144150
selected_id = self.model_radiobuttons.checkedId()
145151
if selected_id != -1:
146-
_, (_, model_info) = list(self.downloaded_labels.items())[selected_id - 1]
152+
_, model_info, _ = list(self.downloaded_labels.values())[selected_id - 1]
147153
model_downloader = ModelDownloader(model_info, model_info['type'])
148154

149155
download_thread = threading.Thread(target=lambda: model_downloader.download())
150156
download_thread.start()
151157

152158
def update_model_downloaded_status(self, model_name, model_type):
153159
models_dir = Path('Models')
154-
vector_models_dir = models_dir / "Vector"
160+
# Use consistent lowercase 'vector'
161+
vector_models_dir = models_dir / "vector"
155162

156163
existing_vector_directories = {d.name for d in vector_models_dir.iterdir() if d.is_dir()}
157164

158165
for vendor, models in VECTOR_MODELS.items():
159166
for model in models:
160-
if model['cache_dir'] == model_name:
161-
downloaded_label, _ = self.downloaded_labels.get(f"{vendor}/{model['name']}", (None, None))
162-
if downloaded_label:
167+
# Check both cache_dir and the generated directory name
168+
cache_dir = model.get('cache_dir', '')
169+
generated_dir = model['repo_id'].replace('/', '--')
170+
171+
if cache_dir == model_name or generated_dir == model_name:
172+
key = f"{vendor}/{model['name']}"
173+
if key in self.downloaded_labels:
174+
downloaded_label, _, radiobutton = self.downloaded_labels[key]
163175
downloaded_label.setText('Yes')
164-
for button in self.model_radiobuttons.buttons():
165-
if button.text() == model['name']:
166-
button.setEnabled(False)
167-
break
176+
radiobutton.setEnabled(False)
168177
self.refresh_gui()
169178
return
170179

0 commit comments

Comments
 (0)