Skip to content

Commit 383cd04

Browse files
authored
add whisper models and fix bug
1 parent b6c6981 commit 383cd04

File tree

3 files changed

+88
-36
lines changed

3 files changed

+88
-36
lines changed

src/gui.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import multiprocessing
2+
multiprocessing.set_start_method('spawn')
3+
14
import sys
25
import os
36
from pathlib import Path
@@ -7,11 +10,21 @@
710
QApplication, QWidget, QVBoxLayout, QTabWidget,
811
QStyleFactory, QMenuBar
912
)
10-
import multiprocessing
1113
from initialize import main as initialize_system
1214
from gui_tabs import create_tabs
1315
from utilities import list_theme_files, make_theme_changer, load_stylesheet
1416

17+
# Print the current working directory
18+
print(f"Current working directory: {os.getcwd()}")
19+
20+
# Check if we can write to the current directory
21+
try:
22+
with open('test_write.txt', 'w') as f:
23+
f.write("Testing write permissions")
24+
os.remove('test_write.txt')
25+
except Exception as e:
26+
print(f"Cannot write to the current directory: {e}")
27+
1528
logging.basicConfig(filename='gui_log.txt', level=logging.DEBUG,
1629
format='%(asctime)s - %(levelname)s - %(message)s')
1730

@@ -83,7 +96,7 @@ def closeEvent(self, event):
8396
def main():
8497
try:
8598
logging.info("Starting application")
86-
multiprocessing.set_start_method('spawn')
99+
# multiprocessing.set_start_method('spawn')
87100
app = QApplication(sys.argv)
88101
app.setStyleSheet(load_stylesheet('custom_stylesheet_steel_ocean.css'))
89102
ex = DocQA_GUI()

src/gui_tabs_tools_transcribe.py

Lines changed: 20 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,20 @@
11
import threading
2-
from functools import partial
32
from pathlib import Path
4-
53
import yaml
64
from PySide6.QtCore import Qt
75
from PySide6.QtWidgets import (
86
QWidget, QHBoxLayout, QVBoxLayout, QPushButton, QFileDialog, QLabel, QComboBox, QSlider
97
)
10-
118
from module_transcribe import WhisperTranscriber
129
from utilities import my_cprint
10+
from constants import WHISPER_MODELS
1311

1412
class TranscriberToolSettingsTab(QWidget):
1513
CONFIG_FILE = 'config.yaml'
1614

1715
def __init__(self):
1816
super().__init__()
1917
self.selected_audio_file = None
20-
2118
self.create_layout()
2219

2320
def read_config(self):
@@ -26,40 +23,36 @@ def read_config(self):
2623

2724
def create_layout(self):
2825
main_layout = QVBoxLayout()
29-
3026
model_selection_hbox = QHBoxLayout()
31-
model_selection_hbox.addWidget(QLabel("Whisper Model"))
27+
model_selection_hbox.addWidget(QLabel("Model"))
3228
self.model_combo = QComboBox()
33-
34-
self.model_name_mapping = {f"{size} - {precision}": f"ctranslate2-4you/whisper-{size}-ct2-{precision}"
35-
for size in ["large-v2", "medium.en", "small.en"]
36-
for precision in ["float32", "float16"]}
37-
38-
self.model_combo.addItems(list(self.model_name_mapping.keys()))
39-
29+
30+
# Use the WHISPER_MODELS dictionary to populate the combo box
31+
self.model_combo.addItems(WHISPER_MODELS.keys())
32+
4033
model_selection_hbox.addWidget(self.model_combo)
41-
42-
model_selection_hbox.addWidget(QLabel("Speed:"))
43-
34+
model_selection_hbox.addWidget(QLabel("Batch:"))
4435
self.slider_label = QLabel("8")
4536
self.number_slider = QSlider(Qt.Horizontal)
4637
self.number_slider.setMinimum(1)
4738
self.number_slider.setMaximum(150)
4839
self.number_slider.setValue(8)
4940
self.number_slider.valueChanged.connect(self.update_slider_label)
50-
5141
model_selection_hbox.addWidget(self.number_slider)
5242
model_selection_hbox.addWidget(self.slider_label)
53-
43+
44+
model_selection_hbox.setStretchFactor(self.model_combo, 2)
45+
model_selection_hbox.setStretchFactor(self.number_slider, 2)
46+
5447
main_layout.addLayout(model_selection_hbox)
5548

5649
hbox = QHBoxLayout()
5750
self.select_file_button = QPushButton("Select Audio File")
58-
self.select_file_button.clicked.connect(lambda: self.select_audio_file())
51+
self.select_file_button.clicked.connect(self.select_audio_file)
5952
hbox.addWidget(self.select_file_button)
6053

6154
self.transcribe_button = QPushButton("Transcribe")
62-
self.transcribe_button.clicked.connect(lambda: self.start_transcription())
55+
self.transcribe_button.clicked.connect(self.start_transcription)
6356
hbox.addWidget(self.transcribe_button)
6457

6558
main_layout.addLayout(hbox)
@@ -89,17 +82,16 @@ def start_transcription(self):
8982
if not self.selected_audio_file:
9083
print("Please select an audio file.")
9184
return
92-
93-
selected_model = self.model_combo.currentText()
94-
selected_model_identifier = self.model_name_mapping.get(selected_model, selected_model)
95-
96-
selected_compute_type = selected_model.rsplit(' - ', 1)[-1]
9785

86+
selected_model_key = self.model_combo.currentText()
9887
selected_batch_size = int(self.slider_label.text())
99-
88+
10089
def transcription_thread():
101-
transcriber = WhisperTranscriber(model_identifier=selected_model_identifier, batch_size=selected_batch_size, compute_type=selected_compute_type)
90+
transcriber = WhisperTranscriber(
91+
model_key=selected_model_key,
92+
batch_size=selected_batch_size
93+
)
10294
transcriber.start_transcription_process(self.selected_audio_file)
10395
my_cprint("Transcription created and ready to be input into vector database.", 'green')
104-
96+
10597
threading.Thread(target=transcription_thread, daemon=True).start()

src/module_transcribe.py

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,28 @@
1010
from langchain_community.docstore.document import Document
1111

1212
import whisper_s2t
13+
from whisper_s2t.backends.ctranslate2.hf_utils import download_model
1314
from extract_metadata import extract_audio_metadata
15+
from constants import WHISPER_MODELS
1416

1517
warnings.filterwarnings("ignore")
1618

19+
current_directory = Path(__file__).parent
20+
CACHE_DIR = current_directory / "Models" / "whisper"
21+
CACHE_DIR.mkdir(parents=True, exist_ok=True)
22+
1723
class WhisperTranscriber:
18-
def __init__(self, model_identifier="ctranslate2-4you/whisper-mediuim.en-ct2-int8", batch_size=16, compute_type='int8'):
19-
self.model_identifier = model_identifier
24+
def __init__(self, model_key, batch_size):
25+
model_info = WHISPER_MODELS[model_key]
26+
self.model_identifier = model_info['repo_id']
27+
self.compute_type = model_info['precision']
2028
self.batch_size = batch_size
21-
self.compute_type = compute_type
29+
self.cache_dir = str(CACHE_DIR)
30+
31+
script_dir = Path(__file__).parent
32+
self.model_dir = script_dir / "Models" / "whisper"
33+
self.model_dir.mkdir(parents=True, exist_ok=True)
34+
2235
self.model_kwargs = {
2336
'compute_type': self.compute_type,
2437
'asr_options': {
@@ -42,25 +55,59 @@ def __init__(self, model_identifier="ctranslate2-4you/whisper-mediuim.en-ct2-int
4255
"return_no_speech_prob": True,
4356
"word_aligner_model": 'tiny',
4457
},
45-
'model_identifier': model_identifier,
58+
'model_identifier': self.model_identifier,
4659
'backend': 'CTranslate2',
4760
}
4861

62+
if 'large-v3' in self.model_identifier:
63+
self.model_kwargs['n_mels'] = 128
64+
4965
def start_transcription_process(self, audio_file):
5066
self.audio_file = audio_file
5167
process = Process(target=self.transcribe_and_create_document)
5268
process.start()
5369
process.join()
5470

71+
5572
@torch.inference_mode()
5673
def transcribe_and_create_document(self):
5774
audio_file_str = str(self.audio_file)
5875
converted_audio_file = self.convert_to_wav(audio_file_str)
59-
self.model_kwargs['model_identifier'] = self.model_identifier
60-
model = whisper_s2t.load_model(**self.model_kwargs)
76+
77+
try:
78+
downloaded_path = download_model(
79+
size_or_id=self.model_identifier,
80+
cache_dir=str(CACHE_DIR)
81+
)
82+
83+
model_kwargs = self.model_kwargs.copy()
84+
model_kwargs.pop('model_identifier', None)
85+
model_kwargs.pop('cache_dir', None)
86+
87+
model = whisper_s2t.load_model(
88+
model_identifier=downloaded_path,
89+
**model_kwargs
90+
)
91+
92+
except Exception as e:
93+
print(f"Error loading model {self.model_identifier}: {e}")
94+
raise
95+
6196
transcription = self.transcribe(model, [str(converted_audio_file)])
6297
self.create_document_object(transcription, audio_file_str)
6398

99+
script_dir = Path(__file__).parent
100+
converted_audio_file_name = f"{Path(audio_file_str).stem}_converted.wav"
101+
converted_audio_file_full_path = script_dir / converted_audio_file_name
102+
103+
if converted_audio_file_full_path.exists():
104+
try:
105+
converted_audio_file_full_path.unlink()
106+
except Exception as e:
107+
print(f"Error deleting file {converted_audio_file_full_path}: {e}")
108+
else:
109+
print(f"File does not exist: {converted_audio_file_full_path}")
110+
64111
def convert_to_wav(self, audio_file):
65112
output_file = f"{Path(audio_file).stem}_converted.wav"
66113
output_path = Path(__file__).parent / output_file

0 commit comments

Comments
 (0)