 from database_interactions import QueryVectorDB
 from utilities import format_citations, my_cprint, normalize_chat_text
 from constants import rag_string
+from pathlib import Path
 
 class MessageType(Enum):
     QUESTION = auto()
@@ -111,20 +112,25 @@ def eject_model(self):
 
     def _start_listening_thread(self):
         import threading
-        self.listener_thread = threading.Thread(target=self._listen_for_response, daemon=True)
+
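+        # stop and join any earlier listener thread before starting a new one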
+        if hasattr(self, "_stop_listener_event"):
+            self._stop_listener_event.set()
+        if getattr(self, "listener_thread", None) and self.listener_thread.is_alive():
+            self.listener_thread.join()
+
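+        # a fresh Event is handed to the new listener so it can be told to stop later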
+        self._stop_listener_event = threading.Event()
+        self.listener_thread = threading.Thread(
+            target=self._listen_for_response,
+            args=(self._stop_listener_event,),
+            daemon=True,
+        )
         self.listener_thread.start()
 
-    def _listen_for_response(self):
-        """
-        Listens every second for messages coming through the pipe from the child process. When a message is received, the
-        message type determines which signal is emitted.
-        """
-        while True:
+    def _listen_for_response(self, stop_event):
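+        # poll the pipe for child-process messages until stop_event is set or the pipe goes away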
+        while not stop_event.is_set():
             if not self.model_pipe or not isinstance(self.model_pipe, PipeConnection):
                 break
-
             try:
-                # checks every second for messages from "_local_model_process" that's being run in the child process
                 if self.model_pipe.poll(timeout=1):
                     message = self.model_pipe.recv()
                     if message.type in [MessageType.RESPONSE, MessageType.PARTIAL_RESPONSE]:
@@ -141,12 +147,11 @@ def _listen_for_response(self):
                         self.signals.token_count_signal.emit(message.payload)
                 else:
                     time.sleep(0.1)
-            except (BrokenPipeError, EOFError, OSError) as e:
+            except (BrokenPipeError, EOFError, OSError):
                 break
             except Exception as e:
                 logging.warning(f"Unexpected error in _listen_for_response: {str(e)}")
                 break
-
         self.cleanup_listener_resources()
 
     def cleanup_listener_resources(self):
@@ -197,9 +202,17 @@ def _local_model_process(conn, model_name): # child process for local model's ge
         user_question_token_count = len(model_instance.tokenizer.encode(user_question))
 
         full_response = ""
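+        # buffer streamed chunks so fewer, larger PARTIAL_RESPONSE messages cross the pipe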
+        buffer = ""
         for partial_response in module_chat.generate_response(model_instance, augmented_query):
             full_response += partial_response
-            conn.send(PipeMessage(MessageType.PARTIAL_RESPONSE, partial_response))
+            buffer += partial_response
+
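+            # flush once at least 50 characters have accumulated or the buffer contains a newline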
+            if len(buffer) >= 50 or '\n' in buffer:
+                conn.send(PipeMessage(MessageType.PARTIAL_RESPONSE, buffer))
+                buffer = ""
+
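+        # send whatever is still buffered once generation finishes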
+        if buffer:
+            conn.send(PipeMessage(MessageType.PARTIAL_RESPONSE, buffer))
 
         response_token_count = len(model_instance.tokenizer.encode(full_response))
         remaining_tokens = model_instance.max_length - (prepend_token_count + user_question_token_count + context_token_count + response_token_count)
@@ -216,7 +229,9 @@ def _local_model_process(conn, model_name): # child process for local model's ge
 
         conn.send(PipeMessage(MessageType.TOKEN_COUNTS, token_count_string))
 
-        with open('chat_history.txt', 'w', encoding='utf-8') as f:
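+        # write chat_history.txt next to this script rather than the current working directory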
+        script_dir = Path(__file__).resolve().parent
+        with open(script_dir / 'chat_history.txt', 'w', encoding='utf-8') as f:
+
             normalized_response = normalize_chat_text(full_response)
             f.write(normalized_response)
             citations = format_citations(metadata_list)