
Commit 0ddc6d2

Refactor AI chat model for live streaming and precise responses
This commit introduces several changes to the AI chat model in the `aicodebot/cli.py` file. The main changes include:

- The introduction of two new temperature settings: `PRECISE_TEMPERATURE` and `CREATIVE_TEMPERATURE`. These settings allow for more control over the randomness of the AI's responses.
- The use of the `rich.live.Live` class for live streaming of the AI's responses. This is implemented in the `alignment`, `debug`, `fun_fact`, `review`, and `sidekick` commands.
- The creation of a `RichLiveCallbackHandler` class that updates the live stream with each new token generated by the AI.
- Adjustments to the `max_tokens` parameter in the `fun_fact` and `sidekick` commands.

These changes aim to improve the user experience by providing real-time feedback from the AI and allowing for more precise responses. 🤖💬
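In outline, the streaming pattern this commit wires into each command boils down to the following minimal sketch, condensed from the diff below. The model name and prompt here are placeholders, an `OPENAI_API_KEY` in the environment is assumed, and the token buffer is made per-instance (see the note at the end of the diff):

```python
# Minimal sketch of the live-streaming pattern (condensed from the diff below).
# Assumes OPENAI_API_KEY is set; the model name and prompt are placeholders.
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from rich.live import Live
from rich.markdown import Markdown


class RichLiveCallbackHandler(BaseCallbackHandler):
    def __init__(self, live):
        self.live = live
        self.buffer = []  # per-instance buffer, unlike the class attribute in the diff

    def on_llm_new_token(self, token, **kwargs):
        # Each streamed token re-renders the accumulated Markdown in place.
        self.buffer.append(token)
        self.live.update(Markdown("".join(self.buffer)))


prompt = PromptTemplate.from_template("Tell me a fun fact about {topic}.")

with Live(Markdown(""), auto_refresh=True) as live:
    llm = ChatOpenAI(
        model="gpt-3.5-turbo",  # the real code picks a model via get_llm_model()
        temperature=0,  # PRECISE_TEMPERATURE
        streaming=True,
        callbacks=[RichLiveCallbackHandler(live)],
    )
    chain = LLMChain(llm=llm, prompt=prompt)
    chain.run("programming")
```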
1 parent 172bd5c commit 0ddc6d2

File tree

1 file changed: +69 −31 lines changed


aicodebot/cli.py

Lines changed: 69 additions & 31 deletions
@@ -2,29 +2,32 @@
 from aicodebot.agents import get_agent
 from aicodebot.helpers import exec_and_get_output, get_token_length, git_diff_context
 from dotenv import load_dotenv
+from langchain.callbacks.base import BaseCallbackHandler
 from langchain.chains import LLMChain
 from langchain.chat_models import ChatOpenAI
 from langchain.prompts import load_prompt
 from openai.api_resources import engine
 from pathlib import Path
 from rich.console import Console
+from rich.live import Live
 from rich.markdown import Markdown
 from rich.style import Style
 import click, datetime, openai, os, random, subprocess, sys, tempfile, webbrowser
 
 # ----------------------------- Default settings ----------------------------- #
 
 DEFAULT_MAX_TOKENS = 512
-DEFAULT_TEMPERATURE = 0.1
+PRECISE_TEMPERATURE = 0
+CREATIVE_TEMPERATURE = 0.7
 DEFAULT_SPINNER = "point"
 
 # ----------------------- Setup for rich console output ---------------------- #
+
 console = Console()
 bot_style = Style(color="#30D5C8")
 error_style = Style(color="#FF0000")
 warning_style = Style(color="#FFA500")
 
-
 # -------------------------- Top level command group ------------------------- #
 
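A note on the new constants above: `PRECISE_TEMPERATURE = 0` makes sampling as deterministic as the API allows, which suits mechanical tasks like commit messages and code review, while `CREATIVE_TEMPERATURE = 0.7` leaves room for varied phrasing in the `alignment` message. Interestingly, `fun_fact` drops from its old `temperature=0.9` to `PRECISE_TEMPERATURE`, relying on the randomly chosen year rather than sampling temperature for variety.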
@@ -55,14 +58,21 @@ def alignment(verbose):
 
     # Set up the language model
    model = get_llm_model(get_token_length(prompt.template))
-    llm = ChatOpenAI(model=model, temperature=DEFAULT_TEMPERATURE, max_tokens=DEFAULT_MAX_TOKENS, verbose=verbose)
 
-    # Set up the chain
-    chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
+    with Live(Markdown(""), auto_refresh=True) as live:
+        llm = ChatOpenAI(
+            model=model,
+            temperature=CREATIVE_TEMPERATURE,
+            max_tokens=DEFAULT_MAX_TOKENS,
+            verbose=verbose,
+            streaming=True,
+            callbacks=[RichLiveCallbackHandler(live)],
+        )
 
-    with console.status("Generating an inspirational message", spinner=DEFAULT_SPINNER):
-        response = chain.run({})
-        console.print(Markdown(response), style=bot_style)
+        # Set up the chain
+        chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
+
+        chain.run({})
 
 
 @cli.command()
@@ -109,13 +119,13 @@ def commit(verbose, response_token_size, yes, skip_pre_commit):
         console.print(f"Diff context token size: {request_token_size}, using model: {model}")
 
     # Set up the language model
-    llm = ChatOpenAI(model=model, temperature=DEFAULT_TEMPERATURE, max_tokens=DEFAULT_MAX_TOKENS, verbose=verbose)
+    llm = ChatOpenAI(model=model, temperature=PRECISE_TEMPERATURE, max_tokens=DEFAULT_MAX_TOKENS, verbose=verbose)
 
     # Set up the chain
     chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
 
     console.print("The following files will be committed:\n" + files)
-    with console.status("Generating the commit message", spinner=DEFAULT_SPINNER):
+    with console.status("Examining the diff and generating the commit message", spinner=DEFAULT_SPINNER):
         response = chain.run(diff_context)
 
     # Write the commit message to a temporary file
@@ -168,14 +178,20 @@ def debug(command, verbose):
 
     # Set up the language model
     model = get_llm_model(get_token_length(error_output) + get_token_length(prompt.template))
-    llm = ChatOpenAI(model=model, temperature=DEFAULT_TEMPERATURE, max_tokens=DEFAULT_MAX_TOKENS, verbose=verbose)
 
-    # Set up the chain
-    chat_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
+    with Live(Markdown(""), auto_refresh=True) as live:
+        llm = ChatOpenAI(
+            model=model,
+            temperature=PRECISE_TEMPERATURE,
+            max_tokens=DEFAULT_MAX_TOKENS,
+            verbose=verbose,
+            streaming=True,
+            callbacks=[RichLiveCallbackHandler(live)],
+        )
 
-    with console.status("Debugging", spinner=DEFAULT_SPINNER):
-        response = chat_chain.run(error_output)
-        console.print(Markdown(response), style=bot_style)
+        # Set up the chain
+        chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
+        chain.run(error_output)
 
     sys.exit(process.returncode)
 
@@ -191,16 +207,22 @@ def fun_fact(verbose):
 
     # Set up the language model
     model = get_llm_model(get_token_length(prompt.template))
-    llm = ChatOpenAI(model=model, temperature=0.9, max_tokens=250, verbose=verbose)
 
-    # Set up the chain
-    chat_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
+    with Live(Markdown(""), auto_refresh=True) as live:
+        llm = ChatOpenAI(
+            model=model,
+            temperature=PRECISE_TEMPERATURE,
+            max_tokens=DEFAULT_MAX_TOKENS / 2,
+            verbose=verbose,
+            streaming=True,
+            callbacks=[RichLiveCallbackHandler(live)],
+        )
+
+        # Set up the chain
+        chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
 
-    with console.status("Fetching a fun fact", spinner=DEFAULT_SPINNER):
-        # Select a random year so that we get a different answer each time
         year = random.randint(1942, datetime.datetime.utcnow().year)
-        response = chat_chain.run(f"programming and artificial intelligence in the year {year}")
-        console.print(Markdown(response), style=bot_style)
+        chain.run(f"programming and artificial intelligence in the year {year}")
 
 
 @cli.command
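One small nit in the `fun_fact` hunk above: `DEFAULT_MAX_TOKENS / 2` is true division in Python 3, so `max_tokens` is passed as the float `256.0` rather than the int `256`; `DEFAULT_MAX_TOKENS // 2` would keep it an integer, which is what the API expects for this parameter.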
@@ -225,15 +247,20 @@ def review(commit, verbose):
     if verbose:
         console.print(f"Diff context token size: {request_token_size}, using model: {model}")
 
-    # Set up the language model
-    llm = ChatOpenAI(model=model, temperature=DEFAULT_TEMPERATURE, max_tokens=response_token_size, verbose=verbose)
+    with Live(Markdown(""), auto_refresh=True) as live:
+        llm = ChatOpenAI(
+            model=model,
+            temperature=PRECISE_TEMPERATURE,
+            max_tokens=response_token_size,
+            verbose=verbose,
+            streaming=True,
+            callbacks=[RichLiveCallbackHandler(live)],
+        )
 
-    # Set up the chain
-    chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
+        # Set up the chain
+        chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
 
-    with console.status("Reviewing code", spinner=DEFAULT_SPINNER):
-        response = chain.run(diff_context)
-        console.print(Markdown(response), style=bot_style)
+        chain.run(diff_context)
 
 
 @cli.command
@@ -254,7 +281,7 @@ def sidekick(task, verbose):
     setup_environment()
 
     model = get_llm_model()
-    llm = ChatOpenAI(model=model, temperature=DEFAULT_TEMPERATURE, max_tokens=3500, verbose=verbose)
+    llm = ChatOpenAI(model=model, temperature=PRECISE_TEMPERATURE, max_tokens=2000, verbose=verbose)
 
     agent = get_agent("sidekick", llm, verbose)
 
@@ -358,5 +385,16 @@ def get_llm_model(token_size=0):
     raise click.ClickException("🛑 The context is too large to for the Model. 😞")
 
 
+class RichLiveCallbackHandler(BaseCallbackHandler):
+    buffer = []
+
+    def __init__(self, live):
+        self.live = live
+
+    def on_llm_new_token(self, token, **kwargs):
+        self.buffer.append(token)
+        self.live.update(Markdown("".join(self.buffer)))
+
+
 if __name__ == "__main__":
     cli()
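A caveat on the new handler: `buffer = []` is declared as a class attribute, so every `RichLiveCallbackHandler` instance shares the same list, and a second streamed response in the same process would render appended onto the first. Moving the buffer into `__init__` makes it per-instance; a minimal sketch of that variant:

```python
class RichLiveCallbackHandler(BaseCallbackHandler):
    def __init__(self, live):
        self.live = live
        self.buffer = []  # per-instance, so each streamed response starts empty

    def on_llm_new_token(self, token, **kwargs):
        self.buffer.append(token)
        self.live.update(Markdown("".join(self.buffer)))
```

As the commit constructs a fresh handler per command invocation, the sharing is harmless today, but the per-instance form is the safer default.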
