Commit 590ae1a

general improvements
1 parent fd6f71a commit 590ae1a

File tree

2 files changed: +36 -14 lines changed


src/chat_local_model.py

Lines changed: 28 additions & 13 deletions
@@ -13,6 +13,7 @@
 from database_interactions import QueryVectorDB
 from utilities import format_citations, my_cprint, normalize_chat_text
 from constants import rag_string
+from pathlib import Path
 
 class MessageType(Enum):
     QUESTION = auto()
@@ -111,20 +112,25 @@ def eject_model(self):
 
     def _start_listening_thread(self):
         import threading
-        self.listener_thread = threading.Thread(target=self._listen_for_response, daemon=True)
+
+        if hasattr(self, "_stop_listener_event"):
+            self._stop_listener_event.set()
+        if getattr(self, "listener_thread", None) and self.listener_thread.is_alive():
+            self.listener_thread.join()
+
+        self._stop_listener_event = threading.Event()
+        self.listener_thread = threading.Thread(
+            target=self._listen_for_response,
+            args=(self._stop_listener_event,),
+            daemon=True,
+        )
         self.listener_thread.start()
 
-    def _listen_for_response(self):
-        """
-        Listens every second for messages coming through the pipe from the child process. When a message is received, the
-        message type determines which signal is emitted.
-        """
-        while True:
+    def _listen_for_response(self, stop_event):
+        while not stop_event.is_set():
             if not self.model_pipe or not isinstance(self.model_pipe, PipeConnection):
                 break
-
             try:
-                # checks every second for messages from "_local_model_process" that's being run in the child process
                 if self.model_pipe.poll(timeout=1):
                     message = self.model_pipe.recv()
                     if message.type in [MessageType.RESPONSE, MessageType.PARTIAL_RESPONSE]:
@@ -141,12 +147,11 @@ def _listen_for_response(self):
                         self.signals.token_count_signal.emit(message.payload)
                 else:
                     time.sleep(0.1)
-            except (BrokenPipeError, EOFError, OSError) as e:
+            except (BrokenPipeError, EOFError, OSError):
                 break
             except Exception as e:
                 logging.warning(f"Unexpected error in _listen_for_response: {str(e)}")
                 break
-
         self.cleanup_listener_resources()
 
     def cleanup_listener_resources(self):
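
The reworked listener closes a race where a stale thread from a previous model could keep polling the pipe: any existing listener is signalled and joined before a fresh threading.Event and thread are created. Below is a minimal standalone sketch of the same stop-event restart pattern; the Worker class and its queue are illustrative stand-ins, not code from this repo.

import threading
import queue
import time

class Worker:
    """Illustrative stand-in for the chat window and its pipe; not from this repo."""

    def __init__(self):
        self.inbox = queue.Queue()
        self.listener_thread = None

    def start_listening(self):
        # Stop and join any previous listener before starting a new one,
        # mirroring the commit's guard in _start_listening_thread.
        if getattr(self, "_stop_event", None):
            self._stop_event.set()
        if self.listener_thread and self.listener_thread.is_alive():
            self.listener_thread.join()

        self._stop_event = threading.Event()
        self.listener_thread = threading.Thread(
            target=self._listen, args=(self._stop_event,), daemon=True
        )
        self.listener_thread.start()

    def _listen(self, stop_event):
        # Block with a timeout so the loop rechecks stop_event regularly,
        # just as _listen_for_response polls the pipe with timeout=1.
        while not stop_event.is_set():
            try:
                msg = self.inbox.get(timeout=1)
            except queue.Empty:
                continue
            print(f"received: {msg}")

w = Worker()
w.start_listening()
w.inbox.put("hello")
time.sleep(0.1)
w.start_listening()  # safe restart: the old thread is stopped and joined first

Passing the Event into the thread (rather than reading an attribute) also ties each thread to its own stop flag, so a restart cannot accidentally stop the replacement thread.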
@@ -197,9 +202,17 @@ def _local_model_process(conn, model_name):  # child process for local model's ge
     user_question_token_count = len(model_instance.tokenizer.encode(user_question))
 
     full_response = ""
+    buffer = ""
     for partial_response in module_chat.generate_response(model_instance, augmented_query):
         full_response += partial_response
-        conn.send(PipeMessage(MessageType.PARTIAL_RESPONSE, partial_response))
+        buffer += partial_response
+
+        if len(buffer) >= 50 or '\n' in buffer:
+            conn.send(PipeMessage(MessageType.PARTIAL_RESPONSE, buffer))
+            buffer = ""
+
+    if buffer:
+        conn.send(PipeMessage(MessageType.PARTIAL_RESPONSE, buffer))
 
     response_token_count = len(model_instance.tokenizer.encode(full_response))
     remaining_tokens = model_instance.max_length - (prepend_token_count + user_question_token_count + context_token_count + response_token_count)
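
Buffering partial responses before conn.send trades a little latency for far fewer pipe messages, since each send pickles a PipeMessage and wakes the poller on the GUI side. A rough standalone sketch of the same flush-at-50-chars-or-on-newline policy, with token_stream and emit as hypothetical stand-ins:

def stream_in_chunks(token_stream, emit, min_chars=50):
    """Accumulate streamed fragments and flush on size or newline,
    mirroring the buffering added in _local_model_process."""
    buffer = ""
    for fragment in token_stream:
        buffer += fragment
        if len(buffer) >= min_chars or "\n" in buffer:
            emit(buffer)
            buffer = ""
    if buffer:  # flush whatever remains after the stream ends
        emit(buffer)

# Example: fragments arrive one word at a time, chunks leave in batches.
words = (w + " " for w in "a long streamed model response goes here".split())
stream_in_chunks(words, print, min_chars=20)

The final "if buffer" flush matters: without it, any tail shorter than the threshold would never reach the GUI.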
@@ -216,7 +229,9 @@ def _local_model_process(conn, model_name):  # child process for local model's ge
 
     conn.send(PipeMessage(MessageType.TOKEN_COUNTS, token_count_string))
 
-    with open('chat_history.txt', 'w', encoding='utf-8') as f:
+    script_dir = Path(__file__).resolve().parent
+    with open(script_dir / 'chat_history.txt', 'w', encoding='utf-8') as f:
+
         normalized_response = normalize_chat_text(full_response)
         f.write(normalized_response)
     citations = format_citations(metadata_list)
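
Resolving the history file against the script's own directory keeps the output location stable regardless of the process's working directory, which can differ depending on how the app is launched. A quick sketch of the idiom (file name kept from the diff; contents illustrative):

from pathlib import Path

# Resolve the directory containing this script, not the process CWD,
# so 'chat_history.txt' always lands beside the module itself.
script_dir = Path(__file__).resolve().parent
history_path = script_dir / "chat_history.txt"

with open(history_path, "w", encoding="utf-8") as f:
    f.write("example history\n")

print(history_path)  # same path no matter what os.getcwd() returns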

src/module_chat.py

Lines changed: 8 additions & 1 deletion
@@ -6,7 +6,14 @@
 import copy
 from pathlib import Path
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig, StoppingCriteria, StoppingCriteriaList
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    TextIteratorStreamer,
+    BitsAndBytesConfig,
+    StoppingCriteria,
+    StoppingCriteriaList
+)
 import threading
 from abc import ABC, abstractmethod
 import builtins
