Commit 68f978e
Add files via upload
Parent: 32827e5
10 files changed: +718 / -446 lines

src/chat_kobold.py

Lines changed: 132 additions & 0 deletions
import json
from pathlib import Path
import requests
import logging
import sseclient
from PySide6.QtCore import QThread, Signal, QObject
from database_interactions import QueryVectorDB

ROOT_DIRECTORY = Path(__file__).resolve().parent
contexts_output_file_path = ROOT_DIRECTORY / "contexts.txt"
metadata_output_file_path = ROOT_DIRECTORY / "metadata.txt"

# Qt signals shared between the worker, the chat controller, and the GUI.
class KoboldSignals(QObject):
    response_signal = Signal(str)
    error_signal = Signal(str)
    finished_signal = Signal()
    citation_signal = Signal(str)

# Streams tokens from the KoboldCpp SSE endpoint on a background thread.
class KoboldAPIWorker(QThread):
    def __init__(self, url, payload):
        super().__init__()
        self.url = url
        self.payload = payload
        self.signals = KoboldSignals()

    def run(self):
        try:
            response = requests.post(self.url, json=self.payload, stream=True)
            response.raise_for_status()
            client = sseclient.SSEClient(response)
            for event in client.events():
                if event.event == "message":
                    try:
                        data = json.loads(event.data)
                        if 'token' in data:
                            logging.debug(f"Received token: {data['token']}")
                            self.signals.response_signal.emit(data['token'])
                        else:
                            logging.warning(f"Unexpected data format: {data}")
                    except json.JSONDecodeError:
                        logging.error(f"Failed to parse JSON: {event.data}")
                        self.signals.error_signal.emit(f"Failed to parse: {event.data}")
                else:
                    logging.info(f"Received non-message event: {event.event}")
        except Exception as e:
            logging.error(f"Error in API request: {str(e)}")
            self.signals.error_signal.emit(str(e))
        finally:
            self.signals.finished_signal.emit()

# Retrieves context chunks from the vector database and forwards the
# augmented query to KoboldCpp.
class KoboldChat:
    def __init__(self):
        self.signals = KoboldSignals()
        self.api_url = "http://localhost:5001/api/extra/generate/stream"
        self.query_vector_db = None

    def ask_kobold(self, query, chunks_only, selected_database):
        logging.debug(f"ask_kobold called with query: {query}, chunks_only: {chunks_only}, selected_database: {selected_database}")

        if self.query_vector_db is None or self.query_vector_db.selected_database != selected_database:
            logging.debug(f"Initializing QueryVectorDB with database: {selected_database}")
            self.query_vector_db = QueryVectorDB(selected_database)

        contexts, metadata_list = self.query_vector_db.search(query)
        logging.debug(f"Retrieved {len(contexts)} contexts from vector database")

        if chunks_only:
            logging.debug("Chunks only mode, displaying contexts")
            self.display_chunks(contexts, metadata_list)
            self.signals.finished_signal.emit()
            return

        prepend_string = "Only base your answer on the provided context/contexts. If you cannot, please state so."
        augmented_query = f"{prepend_string}\n\n---\n\n" + "\n\n---\n\n".join(contexts) + f"\n\n-----\n\n{query}"
        logging.debug(f"Augmented query: {augmented_query[:100]}...")  # Log first 100 characters of augmented query

        payload = {
            "prompt": augmented_query,
            "max_context_length": 4096,
            "max_length": 512,
            "temperature": 0.1,
            "top_p": 0.9,
            "rep_pen": 1.1
        }

        logging.debug("Creating KoboldAPIWorker")
        self.worker = KoboldAPIWorker(self.api_url, payload)
        self.worker.signals.response_signal.connect(self.on_response_received)
        self.worker.signals.error_signal.connect(self.signals.error_signal.emit)
        self.worker.signals.finished_signal.connect(self.on_response_finished)
        logging.debug("Starting Kobold API worker")
        self.worker.start()

        self.metadata_list = metadata_list  # Store for citation use later
        logging.debug("ask_kobold method completed")

    def on_response_received(self, token):
        logging.debug(f"Response received in KoboldChat: {token}")
        self.signals.response_signal.emit(token)

    def display_chunks(self, contexts, metadata_list):
        formatted_chunks = self.format_chunks(contexts, metadata_list)
        self.signals.response_signal.emit(formatted_chunks)

    def format_chunks(self, contexts, metadata_list):
        formatted_chunks = ""
        for i, (context, metadata) in enumerate(zip(contexts, metadata_list), 1):
            formatted_chunks += f"---------- Context {i} | From File: {metadata.get('file_name', 'Unknown')} ----------\n\n{context}\n\n"
        return formatted_chunks

    def on_response_finished(self):
        self.signals.citation_signal.emit(self.format_citations(self.metadata_list))
        self.signals.finished_signal.emit()

    def format_citations(self, metadata_list):
        return "\n".join([Path(metadata['file_path']).name for metadata in metadata_list])

# Thin QThread wrapper so ask_kobold can be launched off the GUI thread.
class KoboldChatThread(QThread):
    def __init__(self, query, chunks_only, selected_database):
        super().__init__()
        self.query = query
        self.chunks_only = chunks_only
        self.selected_database = selected_database
        self.kobold_chat = KoboldChat()

    def run(self):
        logging.debug("KoboldChatThread started running")
        try:
            self.kobold_chat.ask_kobold(self.query, self.chunks_only, self.selected_database)
        except Exception as e:
            logging.error(f"Error in KoboldChatThread: {str(e)}")
            self.kobold_chat.signals.error_signal.emit(str(e))
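
For review purposes, here is a minimal way the new classes could be driven. This sketch is not part of the commit; it assumes a KoboldCpp server on localhost:5001 and an existing vector database named "docs" that QueryVectorDB can open.

    # Hypothetical driver for chat_kobold.py; "docs" is a placeholder database name.
    import sys
    from PySide6.QtWidgets import QApplication
    from chat_kobold import KoboldChatThread

    app = QApplication(sys.argv)

    thread = KoboldChatThread("What does chapter 3 cover?", chunks_only=False, selected_database="docs")
    signals = thread.kobold_chat.signals
    signals.response_signal.connect(lambda token: print(token, end="", flush=True))
    signals.citation_signal.connect(lambda cites: print(f"\n\nSources:\n{cites}"))
    signals.error_signal.connect(lambda err: print(f"\nError: {err}"))
    signals.finished_signal.connect(app.quit)  # fires after citations are emitted
    thread.start()

    sys.exit(app.exec())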

src/constants.py

Lines changed: 14 additions & 14 deletions
@@ -6,7 +6,7 @@
         'max_sequence': 512,
         'size_mb': 134,
         'repo_id': 'BAAI/bge-small-en-v1.5',
-        'cache_dir': 'BAAI--bge-small-en-v1.5',
+        'cache_dir': 'BAAI_bge-small-en-v1.5',
         'type': 'vector'
     },
     {
@@ -15,7 +15,7 @@
         'max_sequence': 512,
         'size_mb': 438,
         'repo_id': 'BAAI/bge-base-en-v1.5',
-        'cache_dir': 'BAAI--bge-base-en-v1.5',
+        'cache_dir': 'BAAI_bge-base-en-v1.5',
         'type': 'vector'
     },
     {
@@ -24,7 +24,7 @@
         'max_sequence': 512,
         'size_mb': 1340,
         'repo_id': 'BAAI/bge-large-en-v1.5',
-        'cache_dir': 'BAAI--bge-large-en-v1.5',
+        'cache_dir': 'BAAI_bge-large-en-v1.5',
         'type': 'vector'
     },
 ],
@@ -35,7 +35,7 @@
         'max_sequence': 512,
         'size_mb': 439,
         'repo_id': 'hkunlp/instructor-base',
-        'cache_dir': 'hkunlp--instructor-base',
+        'cache_dir': 'hkunlp_instructor-base',
         'type': 'vector'
     },
     {
@@ -44,7 +44,7 @@
         'max_sequence': 512,
         'size_mb': 1340,
         'repo_id': 'hkunlp/instructor-large',
-        'cache_dir': 'hkunlp--instructor-large',
+        'cache_dir': 'hkunlp_instructor-large',
         'type': 'vector'
     },
     {
@@ -53,7 +53,7 @@
         'max_sequence': 512,
         'size_mb': 4960,
         'repo_id': 'hkunlp/instructor-xl',
-        'cache_dir': 'hkunlp--instructor-xl',
+        'cache_dir': 'hkunlp_instructor-xl',
         'type': 'vector'
     },
 ],
@@ -64,7 +64,7 @@
         'max_sequence': 256,
         'size_mb': 120,
         'repo_id': 'sentence-transformers/all-MiniLM-L12-v2',
-        'cache_dir': 'sentence-transformers--all-MiniLM-L12-v2',
+        'cache_dir': 'sentence-transformers_all-MiniLM-L12-v2',
         'type': 'vector'
     },
     {
@@ -73,7 +73,7 @@
         'max_sequence': 384,
         'size_mb': 438,
         'repo_id': 'sentence-transformers/all-mpnet-base-v2',
-        'cache_dir': 'sentence-transformers--all-mpnet-base-v2',
+        'cache_dir': 'sentence-transformers_all-mpnet-base-v2',
         'type': 'vector'
     },
 ],
@@ -84,7 +84,7 @@
         'max_sequence': 512,
         'size_mb': 67,
         'repo_id': 'thenlper/gte-small',
-        'cache_dir': 'thenlper--gte-small',
+        'cache_dir': 'thenlper_gte-small',
         'type': 'vector'
     },
     {
@@ -93,7 +93,7 @@
         'max_sequence': 512,
         'size_mb': 219,
         'repo_id': 'thenlper/gte-base',
-        'cache_dir': 'thenlper--gte-base',
+        'cache_dir': 'thenlper_gte-base',
         'type': 'vector'
     },
     {
@@ -102,7 +102,7 @@
         'max_sequence': 512,
         'size_mb': 670,
         'repo_id': 'thenlper/gte-large',
-        'cache_dir': 'thenlper--gte-large',
+        'cache_dir': 'thenlper_gte-large',
         'type': 'vector'
     },
 ],
@@ -114,21 +114,21 @@
         'precision': 'autoselect',
         'size': '232m',
         'repo_id': 'microsoft/Florence-2-base',
-        'cache_dir': 'microsoft--Florence-2-base',
+        'cache_dir': 'vision',
         'requires_cuda': False
     },
     'Florence-2-large': {
         'precision': 'autoselect',
         'size': '770m',
         'repo_id': 'microsoft/Florence-2-large',
-        'cache_dir': 'microsoft--Florence-2-large',
+        'cache_dir': 'vision',
         'requires_cuda': False
     },
     'Moondream2': {
         'precision': 'float16',
         'size': '2b',
         'repo_id': 'vikhyatk/moondream2',
-        'cache_dir': 'vikhyatk--moondream2',
+        'cache_dir': 'vision',
         'requires_cuda': True
     }
 }
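
The renamed cache_dir values line up with how sentence-transformers names its on-disk cache when given a cache_folder: the repo id with "/" replaced by "_" (the vision models instead share a flat "vision" folder). A standalone sketch of that convention; the expected_cache_dir helper is illustrative, not part of the repo:

    from pathlib import Path

    def expected_cache_dir(repo_id: str) -> str:
        # sentence-transformers caches a model under cache_folder using the
        # repo id with "/" swapped for "_"; the new cache_dir values mirror that.
        return repo_id.replace("/", "_")

    assert expected_cache_dir("BAAI/bge-small-en-v1.5") == "BAAI_bge-small-en-v1.5"
    print(Path.cwd() / "Models" / "vector" / expected_cache_dir("thenlper/gte-large"))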

src/database_interactions.py

Lines changed: 33 additions & 24 deletions
@@ -48,8 +48,9 @@ def load_config(self, root_directory):
             return yaml.safe_load(stream)
 
     @torch.inference_mode()
-    def initialize_vector_model(self, embedding_model_name, config_data):
-        EMBEDDING_MODEL_NAME = config_data.get("EMBEDDING_MODEL_NAME")
+    def initialize_vector_model(self, config_data):
+        model_name = config_data.get("EMBEDDING_MODEL_NAME")
+        cache_folder = Path.cwd() / "Models" / "vector"
         compute_device = config_data['Compute_Device']['database_creation']
         model_kwargs = {"device": compute_device, "trust_remote_code": True}
         encode_kwargs = {'normalize_embeddings': True, 'batch_size': 8}
@@ -58,55 +59,59 @@ def initialize_vector_model(self, embedding_model_name, config_data):
             encode_kwargs['batch_size'] = 2
         else:
             batch_size_mapping = {
-                'sentence-t5-xxl': 1,
-                ('instructor-xl', 'sentence-t5-xl'): 2,
-                'instructor-large': 3,
-                ('jina-embedding-l', 'bge-large', 'gte-large', 'roberta-large'): 4,
-                'jina-embedding-s': 9,
-                ('bge-small', 'gte-small'): 10,
-                ('MiniLM',): 30,
+                'instructor-xl': 2,
+                'bge-large': 4,
+                'instructor-large': 4,
+                'gte-large': 4,
+                'instructor-base': 6,
+                'mpnet': 8,
+                'bge-base': 8,
+                'gte-base': 8,
+                'bge-small': 10,
+                'gte-small': 10,
+                'MiniLM': 30,
             }
 
             for key, value in batch_size_mapping.items():
                 if isinstance(key, tuple):
-                    if any(model_name_part in EMBEDDING_MODEL_NAME for model_name_part in key):
+                    if any(model_name_part in model_name for model_name_part in key):
                         encode_kwargs['batch_size'] = value
                         break
                 else:
-                    if key in EMBEDDING_MODEL_NAME:
+                    if key in model_name:
                         encode_kwargs['batch_size'] = value
                         break
 
-        if "instructor" in embedding_model_name:
+        if "instructor" in model_name:
             encode_kwargs['show_progress_bar'] = True
 
             model = HuggingFaceInstructEmbeddings(
-                model_name=embedding_model_name,
+                model_name=model_name,
                 model_kwargs=model_kwargs,
-                encode_kwargs=encode_kwargs,
+                cache_folder=str(cache_folder)
             )
 
-        elif "bge" in embedding_model_name:
+        elif "bge" in model_name:
             query_instruction = config_data['embedding-models']['bge'].get('query_instruction')
             encode_kwargs['show_progress_bar'] = True
 
             model = HuggingFaceBgeEmbeddings(
-                model_name=embedding_model_name,
+                model_name=model_name,
                 model_kwargs=model_kwargs,
                 query_instruction=query_instruction,
-                encode_kwargs=encode_kwargs
+                cache_folder=str(cache_folder)
             )
 
         else:
             # model_kwargs["trust_remote_code"] = True
             model = HuggingFaceEmbeddings(
-                model_name=embedding_model_name,
+                model_name=model_name,
                 show_progress=True,
                 model_kwargs=model_kwargs,
-                encode_kwargs=encode_kwargs
+                encode_kwargs=encode_kwargs,
+                cache_folder=str(cache_folder)
             )
 
-        model_name = Path(EMBEDDING_MODEL_NAME).name
         my_cprint(f"{model_name} vector model loaded into memory.", "green")
 
         return model, encode_kwargs
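
One side effect worth noting: the flattened mapping contains no tuple keys anymore, so the isinstance(key, tuple) branch above can no longer match, and the lookup reduces to a first-match substring scan. A standalone equivalent, for illustration only:

    def pick_batch_size(model_name: str, mapping: dict, default: int = 8) -> int:
        # First mapping key that appears in the model name wins; otherwise keep
        # the default batch size of 8 set earlier in initialize_vector_model.
        for key, value in mapping.items():
            if key in model_name:
                return value
        return default

    mapping = {'instructor-xl': 2, 'bge-large': 4, 'bge-small': 10, 'MiniLM': 30}
    print(pick_batch_size('BAAI/bge-small-en-v1.5', mapping))  # -> 10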
@@ -224,7 +229,6 @@ def save_documents_to_pickle(self, documents):
     @torch.inference_mode()
     def run(self):
         config_data = self.load_config(self.ROOT_DIRECTORY)
-        EMBEDDING_MODEL_NAME = config_data.get("EMBEDDING_MODEL_NAME")
 
         # create a list to hold langchain "document objects"
         # langchain_core.documents.base.Document
@@ -265,7 +269,7 @@ def run(self):
         self.save_document_structures(texts) # optional for troubleshooting
 
         # initialize vector model
-        embeddings, encode_kwargs = self.initialize_vector_model(EMBEDDING_MODEL_NAME, config_data)
+        embeddings, encode_kwargs = self.initialize_vector_model(config_data)
 
         # create database
         if isinstance(texts, list) and texts:
@@ -303,25 +307,30 @@ def initialize_vector_model(self):
         compute_device = self.config['Compute_Device']['database_query']
         encode_kwargs = {'normalize_embeddings': True, 'batch_size': 1}
 
+        cache_folder = str(Path.cwd() / "Models" / "vector")
+
         if "instructor" in model_path:
             return HuggingFaceInstructEmbeddings(
                 model_name=model_path,
                 model_kwargs={"device": compute_device},
                 encode_kwargs=encode_kwargs,
+                cache_folder=cache_folder
             )
         elif "bge" in model_path:
             query_instruction = self.config['embedding-models']['bge']['query_instruction']
             return HuggingFaceBgeEmbeddings(
                 model_name=model_path,
                 model_kwargs={"device": compute_device},
                 query_instruction=query_instruction,
-                encode_kwargs=encode_kwargs
+                encode_kwargs=encode_kwargs,
+                cache_folder=cache_folder
             )
         else:
             return HuggingFaceEmbeddings(
                 model_name=model_path,
                 model_kwargs={"device": compute_device, "trust_remote_code": True},
-                encode_kwargs=encode_kwargs
+                encode_kwargs=encode_kwargs,
+                cache_folder=cache_folder
             )
 
     def initialize_database(self):
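
Both loaders now pass the same cache folder, so creation-time and query-time code resolve to identical paths under ./Models/vector. A hedged sketch of the bge path in isolation; the langchain_community import location is an assumption (the file's import block is not shown in this diff), and the query_instruction string stands in for the value read from config:

    # Hedged sketch, not repo code: loading a bge model with the shared cache folder.
    from pathlib import Path
    from langchain_community.embeddings import HuggingFaceBgeEmbeddings  # assumed import path

    cache_folder = str(Path.cwd() / "Models" / "vector")
    model = HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-small-en-v1.5",
        model_kwargs={"device": "cpu"},
        query_instruction="Represent this sentence for searching relevant passages:",
        encode_kwargs={"normalize_embeddings": True, "batch_size": 1},
        cache_folder=cache_folder,  # weights land in BAAI_bge-small-en-v1.5, matching constants.py
    )
    print(model.embed_query("test")[:4])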
