Commit 6926ad5

Refactor llm generation code to improve modularity and readability
The `get_llm` factory method has been moved into the `Coder` class so that language-model construction (API key, model name, token limits, temperature, streaming callbacks) happens in one place instead of being repeated in every CLI command. The `RichLiveCallbackHandler` class has been moved to `helpers.py`, a more natural home for it, and the `DEFAULT_MAX_TOKENS`, `PRECISE_TEMPERATURE`, and `CREATIVE_TEMPERATURE` constants now live in `coder.py` alongside the code that uses them.
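
In practice, each CLI command's inline `ChatOpenAI(...)` construction collapses into a single call. The `commit` command, for example (both snippets taken from the diff below), goes from:

    llm = ChatOpenAI(
        model=model_name,
        openai_api_key=config["openai_api_key"],
        temperature=PRECISE_TEMPERATURE,
        max_tokens=DEFAULT_MAX_TOKENS,
        verbose=verbose,
    )

to:

    llm = Coder.get_llm(model_name, verbose, 350)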
1 parent 3907bc3 commit 6926ad5

File tree

3 files changed: +73 -72 lines changed


aicodebot/cli.py

Lines changed: 33 additions & 72 deletions
@@ -1,11 +1,9 @@
 from aicodebot import version as aicodebot_version
-from aicodebot.coder import Coder
+from aicodebot.coder import CREATIVE_TEMPERATURE, DEFAULT_MAX_TOKENS, Coder
 from aicodebot.config import get_config_file, read_config
-from aicodebot.helpers import exec_and_get_output, logger
+from aicodebot.helpers import RichLiveCallbackHandler, exec_and_get_output, logger
 from aicodebot.prompts import DEFAULT_PERSONALITY, PERSONALITIES, generate_files_context, get_prompt
-from langchain.callbacks.base import BaseCallbackHandler
 from langchain.chains import LLMChain
-from langchain.chat_models import ChatOpenAI
 from langchain.memory import ConversationTokenBufferMemory
 from openai.api_resources import engine
 from pathlib import Path
@@ -17,9 +15,6 @@
 
 # ----------------------------- Default settings ----------------------------- #
 
-DEFAULT_MAX_TOKENS = 512
-PRECISE_TEMPERATURE = 0.1
-CREATIVE_TEMPERATURE = 0.7
 DEFAULT_SPINNER = "point"
 
 # ----------------------- Setup for rich console output ---------------------- #
@@ -54,7 +49,7 @@ def cli():
 @click.option("-t", "--response-token-size", type=int, default=350)
 def alignment(response_token_size, verbose):
     """Get a message about Heart-Centered AI Alignment ❤ + 🤖."""
-    config = setup_config()
+    setup_config()
 
     # Load the prompt
     prompt = get_prompt("alignment")
@@ -64,14 +59,13 @@ def alignment(response_token_size, verbose):
     model_name = Coder.get_llm_model_name(Coder.get_token_length(prompt.template))
 
     with Live(Markdown(""), auto_refresh=True) as live:
-        llm = ChatOpenAI(
-            model=model_name,
+        llm = Coder.get_llm(
+            model_name,
+            verbose,
+            response_token_size,
             temperature=CREATIVE_TEMPERATURE,
-            openai_api_key=config["openai_api_key"],
-            max_tokens=response_token_size,
-            verbose=verbose,
             streaming=True,
-            callbacks=[RichLiveCallbackHandler(live)],
+            callbacks=[RichLiveCallbackHandler(live, bot_style)],
         )
 
         # Set up the chain
@@ -87,7 +81,7 @@ def alignment(response_token_size, verbose):
 @click.option("--skip-pre-commit", is_flag=True, help="Skip running pre-commit (otherwise run it if it is found).")
 def commit(verbose, response_token_size, yes, skip_pre_commit):
     """Generate a commit message based on your changes."""
-    config = setup_config()
+    setup_config()
 
     # Check if pre-commit is installed and .pre-commit-config.yaml exists
     if not skip_pre_commit and Path(".pre-commit-config.yaml").exists():
@@ -127,14 +121,7 @@ def commit(verbose, response_token_size, yes, skip_pre_commit):
             f"The diff is too large to generate a commit message ({request_token_size} tokens). 😢"
         )
 
-    # Set up the language model
-    llm = ChatOpenAI(
-        model=model_name,
-        openai_api_key=config["openai_api_key"],
-        temperature=PRECISE_TEMPERATURE,
-        max_tokens=DEFAULT_MAX_TOKENS,
-        verbose=verbose,
-    )
+    llm = Coder.get_llm(model_name, verbose, 350)
 
     # Set up the chain
     chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
@@ -264,7 +251,7 @@ def write_config_file(config_data):
 @click.option("-v", "--verbose", count=True)
 def debug(command, verbose):
     """Run a command and debug the output."""
-    config = setup_config()
+    setup_config()
 
     # Run the command and capture its output
     command_str = " ".join(command)
@@ -296,14 +283,11 @@ def debug(command, verbose):
         raise click.ClickException(f"The output is too large to debug ({request_token_size} tokens). 😢")
 
     with Live(Markdown(""), auto_refresh=True) as live:
-        llm = ChatOpenAI(
-            model=model_name,
-            temperature=PRECISE_TEMPERATURE,
-            openai_api_key=config["openai_api_key"],
-            max_tokens=DEFAULT_MAX_TOKENS,
-            verbose=verbose,
+        llm = Coder.get_llm(
+            model_name,
+            verbose,
             streaming=True,
-            callbacks=[RichLiveCallbackHandler(live)],
+            callbacks=[RichLiveCallbackHandler(live, bot_style)],
         )
 
         # Set up the chain
@@ -315,10 +299,10 @@ def debug(command, verbose):
 
 @cli.command()
 @click.option("-v", "--verbose", count=True)
-@click.option("-t", "--response-token-size", type=int, default=350)
+@click.option("-t", "--response-token-size", type=int, default=250)
 def fun_fact(verbose, response_token_size):
     """Get a fun fact about programming and artificial intelligence."""
-    config = setup_config()
+    setup_config()
 
     # Load the prompt
     prompt = get_prompt("fun_fact")
@@ -328,16 +312,14 @@ def fun_fact(verbose, response_token_size):
     model_name = Coder.get_llm_model_name(Coder.get_token_length(prompt.template))
 
     with Live(Markdown(""), auto_refresh=True) as live:
-        llm = ChatOpenAI(
-            model=model_name,
-            temperature=PRECISE_TEMPERATURE,
-            max_tokens=response_token_size,
-            openai_api_key=config["openai_api_key"],
-            verbose=verbose,
+        llm = Coder.get_llm(
+            model_name,
+            verbose,
+            response_token_size=response_token_size,
+            temperature=CREATIVE_TEMPERATURE,
             streaming=True,
-            callbacks=[RichLiveCallbackHandler(live)],
+            callbacks=[RichLiveCallbackHandler(live, bot_style)],
        )
-
         # Set up the chain
         chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
 
@@ -350,7 +332,7 @@ def fun_fact(verbose, response_token_size):
 @click.option("-v", "--verbose", count=True)
 def review(commit, verbose):
     """Do a code review, with [un]staged changes, or a specified commit."""
-    config = setup_config()
+    setup_config()
 
     diff_context = Coder.git_diff_context(commit)
     if not diff_context:
@@ -369,14 +351,12 @@ def review(commit, verbose):
         raise click.ClickException(f"The diff is too large to review ({request_token_size} tokens). 😢")
 
     with Live(Markdown(""), auto_refresh=True) as live:
-        llm = ChatOpenAI(
-            model=model_name,
-            temperature=PRECISE_TEMPERATURE,
-            openai_api_key=config["openai_api_key"],
-            max_tokens=response_token_size,
-            verbose=verbose,
+        llm = Coder.get_llm(
+            model_name,
+            verbose,
+            response_token_size=response_token_size,
             streaming=True,
-            callbacks=[RichLiveCallbackHandler(live)],
+            callbacks=[RichLiveCallbackHandler(live, bot_style)],
         )
 
         # Set up the chain
@@ -398,7 +378,7 @@ def sidekick(request, verbose, response_token_size, files):
 
     console.print("This is an experimental feature. Play with it, but don't count on it.", style=warning_style)
 
-    config = setup_config()
+    setup_config()
 
     # Pull in context. Right now it's just the contents of files that we passed in.
     # Soon, we could add vector embeddings of:
@@ -416,14 +396,7 @@ def sidekick(request, verbose, response_token_size, files):
             f"The file context you supplied is too large ({request_token_size} tokens). 😢 Try again with less files."
         )
 
-    llm = ChatOpenAI(
-        model=model_name,
-        openai_api_key=config["openai_api_key"],
-        temperature=PRECISE_TEMPERATURE,
-        max_tokens=response_token_size,
-        verbose=verbose,
-        streaming=True,
-    )
+    llm = Coder.get_llm(model_name, verbose, response_token_size, streaming=True)
 
     # Open the temporary file in the user's editor
     editor = Path(os.getenv("EDITOR", "/usr/bin/vim")).name
@@ -452,9 +425,8 @@ def sidekick(request, verbose, response_token_size, files):
         human_input = click.edit(human_input[:-2])
 
     with Live(Markdown(""), auto_refresh=True) as live:
-        callback = RichLiveCallbackHandler(live)
-        callback.buffer = []
-        llm.callbacks = [callback]
+        callback = RichLiveCallbackHandler(live, bot_style)
+        llm.callbacks = [callback]  # a fresh callback handler for each question
         chain.run({"task": human_input, "context": context})
 
     if request:
@@ -477,16 +449,5 @@ def setup_config():
     return existing_config
 
 
-class RichLiveCallbackHandler(BaseCallbackHandler):
-    buffer = []
-
-    def __init__(self, live):
-        self.live = live
-
-    def on_llm_new_token(self, token, **kwargs):
-        self.buffer.append(token)
-        self.live.update(Markdown("".join(self.buffer), style=bot_style))
-
-
 if __name__ == "__main__":
     cli()

aicodebot/coder.py

Lines changed: 27 additions & 0 deletions
@@ -1,9 +1,14 @@
 from aicodebot.config import read_config
 from aicodebot.helpers import exec_and_get_output, logger
+from langchain.chat_models import ChatOpenAI
 from openai.api_resources import engine
 from pathlib import Path
 import fnmatch, openai, tiktoken
 
+DEFAULT_MAX_TOKENS = 512
+PRECISE_TEMPERATURE = 0.1
+CREATIVE_TEMPERATURE = 0.7
+
 
 class Coder:
     """
@@ -37,6 +42,28 @@ def generate_directory_structure(cls, path, ignore_patterns=None, use_gitignore=
 
         return structure
 
+    @staticmethod
+    def get_llm(
+        model_name,
+        verbose=False,
+        response_token_size=DEFAULT_MAX_TOKENS,
+        temperature=PRECISE_TEMPERATURE,
+        live=None,
+        streaming=False,
+        callbacks=None,
+    ):
+        config = read_config()
+
+        return ChatOpenAI(
+            openai_api_key=config["openai_api_key"],
+            model=model_name,
+            max_tokens=response_token_size,
+            verbose=verbose,
+            temperature=temperature,
+            streaming=streaming,
+            callbacks=callbacks,
+        )
+
     @staticmethod
     def get_llm_model_name(token_size=0):
         model_options = {
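
The new factory centralizes the OpenAI API key lookup and the model parameters. A minimal usage sketch, assuming a configured aicodebot config file with an `openai_api_key`; the `"gpt-3.5-turbo"` model name is a placeholder here (in `cli.py` it comes from `Coder.get_llm_model_name()`):

    from aicodebot.coder import CREATIVE_TEMPERATURE, PRECISE_TEMPERATURE, Coder

    # Deterministic output; PRECISE_TEMPERATURE is also the default
    review_llm = Coder.get_llm("gpt-3.5-turbo", verbose=False, temperature=PRECISE_TEMPERATURE)

    # Looser, more creative output with a smaller response budget
    fun_llm = Coder.get_llm("gpt-3.5-turbo", verbose=False, response_token_size=250, temperature=CREATIVE_TEMPERATURE)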

aicodebot/helpers.py

Lines changed: 13 additions & 0 deletions
@@ -1,5 +1,7 @@
+from langchain.callbacks.base import BaseCallbackHandler
 from loguru import logger
 from pathlib import Path
+from rich.markdown import Markdown
 import os, subprocess, sys
 
 # ---------------------------------------------------------------------------- #
@@ -36,3 +38,14 @@ def exec_and_get_output(command):
     if result.returncode != 0:
         raise Exception(f"Command '{' '.join(command)}' failed with error:\n{result.stderr}")  # noqa: TRY002
     return result.stdout
+
+
+class RichLiveCallbackHandler(BaseCallbackHandler):
+    def __init__(self, live, style):
+        self.buffer = []
+        self.live = live
+        self.style = style
+
+    def on_llm_new_token(self, token, **kwargs):
+        self.buffer.append(token)
+        self.live.update(Markdown("".join(self.buffer), style=self.style))
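
`RichLiveCallbackHandler` now takes the display style as a constructor argument instead of relying on a module-level `bot_style` in `cli.py`. A streaming sketch following the pattern used in `cli.py`; the `"green"` style and the model name are placeholders (the CLI passes its own `bot_style` and a model chosen by `Coder.get_llm_model_name()`):

    from aicodebot.coder import Coder
    from aicodebot.helpers import RichLiveCallbackHandler
    from rich.live import Live
    from rich.markdown import Markdown

    with Live(Markdown(""), auto_refresh=True) as live:
        llm = Coder.get_llm(
            "gpt-3.5-turbo",  # placeholder model name
            streaming=True,
            callbacks=[RichLiveCallbackHandler(live, "green")],
        )
        # Each streamed token is appended to the handler's buffer and re-rendered
        # in the live display as Markdown.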
