
Commit a52de62

Update model selection logic to support GPT-4 🚀
This commit updates the model selection logic in the AI code bot. Previously, the bot was hardcoded to use gpt-3.5-turbo. Now it checks whether the provided OpenAI API key supports GPT-4 and uses it if available, falling back to gpt-3.5-turbo otherwise.

Model selection is also now dynamic, based on the token size of the request: the bot picks the smallest model whose context window fits the request, escalating to the larger-context variants (gpt-3.5-turbo-16k, gpt-4-32k) only when needed. This lets the bot handle larger contexts while still using the most capable model available.

Additionally, a new environment variable `GPT_4_SUPPORTED` has been added to the `.aicodebot.template` file to store the GPT-4 support status of the API key; it is set during setup, when the API key is validated. This update enhances the bot's capabilities and prepares it for future improvements in the OpenAI models.
1 parent 70ac379 commit a52de62
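
In code form, the selection rule described above amounts to walking the available model family from smallest context window to largest and taking the first fit. The sketch below is illustrative only: the `pick_model` helper is hypothetical, while the model names and window sizes come from this commit (the real implementation is the `get_llm_model` helper in the aicodebot/cli.py diff below).

def pick_model(token_size: int, gpt_4_supported: bool) -> str:
    # Each family is ordered smallest context window first, so larger
    # (slower, pricier) variants are chosen only when the request needs them.
    gpt_4_family = [("gpt-4", 8192), ("gpt-4-32k", 32768)]
    gpt_3_5_family = [("gpt-3.5-turbo", 4096), ("gpt-3.5-turbo-16k", 16384)]
    family = gpt_4_family if gpt_4_supported else gpt_3_5_family
    for name, context_window in family:
        if token_size <= context_window:
            return name
    raise ValueError(f"A {token_size}-token request is too large for any available model")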

2 files changed (+82 −52 lines)

aicodebot/.aicodebot.template

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
 # Obtain your API key from https://platform.openai.com/account/api-keys
 # sk-...
 OPENAI_API_KEY=
+GPT_4_SUPPORTED=false

aicodebot/cli.py

Lines changed: 81 additions & 52 deletions
@@ -4,6 +4,7 @@
 from langchain.chains import LLMChain
 from langchain.chat_models import ChatOpenAI
 from langchain.prompts import load_prompt
+from openai.api_resources import engine
 from pathlib import Path
 from rich.console import Console
 from rich.style import Style
@@ -13,7 +14,6 @@
 
 DEFAULT_MAX_TOKENS = 1024
 DEFAULT_TEMPERATURE = 0.1
-DEFAULT_MODEL = "gpt-3.5-turbo"  # Can't wait to use GPT-4, as the results are much better. On the waitlist.
 DEFAULT_SPINNER = "point"
 
 # ----------------------- Setup for rich console output ---------------------- #
@@ -50,9 +50,8 @@ def alignment(verbose):
     prompt = load_prompt(Path(__file__).parent / "prompts" / "alignment.yaml")
 
     # Set up the language model
-    llm = ChatOpenAI(
-        model=DEFAULT_MODEL, temperature=DEFAULT_TEMPERATURE, max_tokens=DEFAULT_MAX_TOKENS, verbose=verbose
-    )
+    model = get_llm_model(get_token_length(prompt.template))
+    llm = ChatOpenAI(model=model, temperature=DEFAULT_TEMPERATURE, max_tokens=DEFAULT_MAX_TOKENS, verbose=verbose)
 
     # Set up the chain
     chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
@@ -100,18 +99,10 @@ def commit(verbose, response_token_size, yes, skip_pre_commit):
         sys.exit(0)
 
     # Check the size of the diff context and adjust accordingly
-    prompt_token_size = get_token_length(diff_context) + get_token_length(prompt.template)
+    request_token_size = get_token_length(diff_context) + get_token_length(prompt.template)
+    model = get_llm_model(request_token_size)
     if verbose:
-        console.print(f"Diff context token size: {prompt_token_size}")
-
-    if prompt_token_size + response_token_size > 16_000:
-        # Bigger models coming soon
-        console.print("The diff context is too large to review. 😞")
-        sys.exit(1)
-    elif prompt_token_size + response_token_size > 3_500:  # It's actually 4k, but we want a buffer
-        model = "gpt-3.5-turbo-16k"  # supports 16k tokens but is a bit slower and more expensive
-    else:
-        model = DEFAULT_MODEL  # gpt-3.5-turbo supports 4k tokens
+        console.print(f"Diff context token size: {request_token_size}, using model: {model}")
 
     # Set up the language model
     llm = ChatOpenAI(model=model, temperature=DEFAULT_TEMPERATURE, max_tokens=DEFAULT_MAX_TOKENS, verbose=verbose)
@@ -149,39 +140,40 @@ def debug(command, verbose):
     """Run a command and debug the output."""
     setup_environment()
 
-    # Load the prompt
-    prompt = load_prompt(Path(__file__).parent / "prompts" / "debug.yaml")
-
-    # Set up the language model
-    llm = ChatOpenAI(
-        model=DEFAULT_MODEL, temperature=DEFAULT_TEMPERATURE, max_tokens=DEFAULT_MAX_TOKENS, verbose=verbose
-    )
-
-    # Set up the chain
-    chat_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
-    command_str = " ".join(command)
-
     # Run the command and capture its output
+    command_str = " ".join(command)
     console.print(f"Executing the command:\n{command_str}")
     process = subprocess.run(command_str, shell=True, capture_output=True, text=True)  # noqa: S602
 
     # Print the output of the command
     output = f"Standard Output:\n{process.stdout}\nStandard Error:\n{process.stderr}"
     console.print(output)
 
-    # Print a message about the exit status
+    # If it succeeded, exit
     if process.returncode == 0:
         console.print("✅ The command completed successfully.")
-    else:
-        console.print(f"The command exited with status {process.returncode}.")
+        sys.exit(0)
 
     # If the command failed, send its output to ChatGPT for analysis
-    if process.returncode != 0:
-        error_output = process.stderr
-        with console.status("Debugging", spinner=DEFAULT_SPINNER):
-            response = chat_chain.run(error_output)
-        console.print(response, style=bot_style)
-        sys.exit(process.returncode)
+    error_output = process.stderr
+
+    console.print(f"The command exited with status {process.returncode}.")
+
+    # Load the prompt
+    prompt = load_prompt(Path(__file__).parent / "prompts" / "debug.yaml")
+
+    # Set up the language model
+    model = get_llm_model(get_token_length(error_output) + get_token_length(prompt.template))
+    llm = ChatOpenAI(model=model, temperature=DEFAULT_TEMPERATURE, max_tokens=DEFAULT_MAX_TOKENS, verbose=verbose)
+
+    # Set up the chain
+    chat_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
+
+    with console.status("Debugging", spinner=DEFAULT_SPINNER):
+        response = chat_chain.run(error_output)
+    console.print(response, style=bot_style)
+
+    sys.exit(process.returncode)
 
 
 @cli.command()
@@ -194,7 +186,8 @@ def fun_fact(verbose):
     prompt = load_prompt(Path(__file__).parent / "prompts" / "fun_fact.yaml")
 
     # Set up the language model
-    llm = ChatOpenAI(model=DEFAULT_MODEL, temperature=0.9, max_tokens=250, verbose=verbose)
+    model = get_llm_model(get_token_length(prompt.template))
+    llm = ChatOpenAI(model=model, temperature=0.9, max_tokens=250, verbose=verbose)
 
     # Set up the chain
     chat_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
@@ -223,21 +216,13 @@ def review(commit, verbose):
 
     # Check the size of the diff context and adjust accordingly
     response_token_size = DEFAULT_MAX_TOKENS / 2
-    prompt_token_size = get_token_length(diff_context) + get_token_length(prompt.template)
+    request_token_size = get_token_length(diff_context) + get_token_length(prompt.template)
+    model = get_llm_model(request_token_size)
     if verbose:
-        console.print(f"Prompt token size: {prompt_token_size}")
-
-    if prompt_token_size + response_token_size > 16_000:
-        # Bigger models coming soon
-        console.print("The diff context is too large to review. 😞")
-        sys.exit(1)
-    elif prompt_token_size + response_token_size > 3_500:  # It's actually 4k, but we want a buffer
-        model = "gpt-3.5-turbo-16k"  # supports 16k tokens but is a bit slower and more expensive
-    else:
-        model = DEFAULT_MODEL  # gpt-3.5-turbo supports 4k tokens
+        console.print(f"Diff context token size: {request_token_size}, using model: {model}")
 
     # Set up the language model
-    llm = ChatOpenAI(model=model, temperature=DEFAULT_TEMPERATURE, max_tokens=DEFAULT_MAX_TOKENS, verbose=verbose)
+    llm = ChatOpenAI(model=model, temperature=DEFAULT_TEMPERATURE, max_tokens=response_token_size, verbose=verbose)
 
     # Set up the chain
     chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)
@@ -247,7 +232,9 @@ def review(commit, verbose):
     console.print(response, style=bot_style)
 
 
-# ------------------------------- End Commands ------------------------------- #
+# ---------------------------------------------------------------------------- #
+#                                Helper functions                               #
+# ---------------------------------------------------------------------------- #
 
 
 def setup_environment():
@@ -273,17 +260,32 @@ def setup_environment():
         if click.confirm(f"Create the {config_file} file for you?"):
             api_key = click.prompt("Please enter your OpenAI API key")
 
-            # Copy .env.template to .env and insert the API key
+            # Validate the API key and check if it supports GPT-4
+            openai.api_key = api_key
+            try:
+                click.echo("Validating the API key, and checking if GPT-4 is supported...")
+                engines = engine.Engine.list()
+                gpt_4_supported = "true" if "gpt-4" in [engine.id for engine in engines.data] else "false"
+                if gpt_4_supported == "true":
+                    click.echo("✅ The API key is valid and supports GPT-4.")
+                else:
+                    click.echo("✅ The API key is valid, but does not support GPT-4. GPT-3.5 will be used instead.")
+            except Exception as e:
+                raise click.ClickException(f"Failed to validate the API key: {str(e)}") from e
+
+            # Copy .env.template to .env and insert the API key and gpt_4_supported
             template_file = Path(__file__).parent / ".aicodebot.template"
             with Path.open(template_file, "r") as template, Path.open(config_file, "w") as env:
                 for line in template:
                     if line.startswith("OPENAI_API_KEY="):
                         env.write(f"OPENAI_API_KEY={api_key}\n")
+                    elif line.startswith("GPT_4_SUPPORTED="):
+                        env.write(f"GPT_4_SUPPORTED={gpt_4_supported}\n")
                     else:
                         env.write(line)
 
             console.print(
-                f"[bold green]Created {config_file} with your OpenAI API key.[/bold green] "
+                f"[bold green]Created {config_file} with your OpenAI API key and GPT-4 support status.[/bold green] "
                 "Now, please re-run aicodebot and let's get started!"
             )
             sys.exit(0)
@@ -293,5 +295,32 @@ def setup_environment():
     )
 
 
+def get_llm_model(token_size):
+    # https://platform.openai.com/docs/models/gpt-3-5
+    # We want to use GPT-4, if it is available for this OPENAI_API_KEY, otherwise GPT-3.5
+    # We also want to use the largest model that supports the token size we need
+    model_options = {
+        "gpt-4": 8192,
+        "gpt-4-32k": 32768,
+        "gpt-3.5-turbo": 4096,
+        "gpt-3.5-turbo-16k": 16384,
+    }
+    gpt_4_supported = os.getenv("GPT_4_SUPPORTED") == "true"
+    if gpt_4_supported:
+        if token_size <= model_options["gpt-4"]:
+            return "gpt-4"
+        elif token_size <= model_options["gpt-4-32k"]:
+            return "gpt-4-32k"
+        else:
+            raise click.ClickException("🛑 The context is too large for the model. 😞")
+    else:
+        if token_size <= model_options["gpt-3.5-turbo"]:  # noqa: PLR5501
+            return "gpt-3.5-turbo"
+        elif token_size <= model_options["gpt-3.5-turbo-16k"]:
+            return "gpt-3.5-turbo-16k"
+        else:
+            raise click.ClickException("🛑 The context is too large for the model. 😞")
+
+
 if __name__ == "__main__":
     cli()
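
As a quick sanity check, here is how `get_llm_model` behaves for a few request sizes (a hypothetical snippet; it assumes the function above is in scope and, as the code does, reads `GPT_4_SUPPORTED` from the environment):

import os

# With GPT-4 access, requests stay on gpt-4 until they exceed its
# 8192-token window, then escalate to gpt-4-32k.
os.environ["GPT_4_SUPPORTED"] = "true"
assert get_llm_model(2_000) == "gpt-4"
assert get_llm_model(10_000) == "gpt-4-32k"

# Without GPT-4 access, the same pattern applies to the 3.5 family.
os.environ["GPT_4_SUPPORTED"] = "false"
assert get_llm_model(2_000) == "gpt-3.5-turbo"
assert get_llm_model(10_000) == "gpt-3.5-turbo-16k"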
