 import zipfile
 import click
 import requests
-import openai
+from openai import OpenAI
 from anthropic import Anthropic
 import tiktoken
 
@@ -24,7 +24,7 @@ def __init__(self, model):
         self.model = model
         self.base_url = "http://localhost:11434/api"
         self.token_count = 0
-        self.max_tokens = 128000  # Max tokens for Mistral-Nemo
+        self.max_tokens = 131072
 
     def count_tokens(self, text):
         return len(tiktoken.encoding_for_model("gpt-4o").encode(text))
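
A note on count_tokens: it runs tiktoken's gpt-4o encoding even for non-OpenAI backends, so the counts for Ollama, Claude, and Gemini models are approximations rather than exact. A minimal sketch of the same estimation in isolation:

    # Approximate token count via tiktoken's gpt-4o encoding; non-OpenAI
    # tokenizers differ, so treat the result as a rough estimate.
    import tiktoken

    encoding = tiktoken.encoding_for_model("gpt-4o")
    print(len(encoding.encode("Count the tokens in this sentence.")))
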
@@ -69,13 +69,13 @@ def generate(self, prompt):
 
 class OpenAIAPI:
     def __init__(self, model):
-        if model == "mistral-nemo":
+        if model == "qwen2.5-coder:32b":
             model = "gpt-4o"
         self.model = model
         self.api_key = os.getenv("OPENAI_API_KEY")
         if not self.api_key:
             raise ValueError("OPENAI_API_KEY environment variable is not set")
-        openai.api_key = self.api_key
+        self.openai = OpenAI(api_key=self.api_key)
         self.token_count = 0
         self.max_tokens = 128000
         self.max_output_tokens = 16384
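
The switch from the module-level `openai.api_key` to a stored client object tracks the openai>=1.0 SDK, which moved configuration onto client instances. A minimal sketch of the new pattern, independent of this class:

    import os
    from openai import OpenAI

    # openai>=1.0 style: configure a client instance, not the module.
    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": "ping"}],
    )
    print(response.choices[0].message.content)
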
@@ -98,7 +98,7 @@ def generate(self, prompt):
 
             if self.model in self.special_models:
                 # Non-streaming approach
-                response = openai.chat.completions.create(
+                response = self.openai.chat.completions.create(
                     model=self.model,
                     messages=[{"role": "user", "content": prompt}],
                     max_completion_tokens=max_completion_tokens,
@@ -115,7 +115,7 @@ def generate(self, prompt):
                 print()  # Print a newline at the end
             else:
                 # Streaming approach
-                response = openai.chat.completions.create(
+                response = self.openai.chat.completions.create(
                     model=self.model,
                     messages=[{"role": "user", "content": prompt}],
                     max_tokens=max_completion_tokens,
@@ -148,9 +148,77 @@ def generate(self, prompt):
             raise Exception(f"OpenAI API error: {str(e)}")
 
 
+class GeminiAPI:
+    def __init__(self, model):
+        if model == "qwen2.5-coder:32b":
+            model = "gemini-1.5-pro"
+        self.model = model
+        self.api_key = os.getenv("GEMINI_API_KEY")
+        self.base_url = "https://generativelanguage.googleapis.com/v1beta/openai/"
+        if not self.api_key:
+            raise ValueError("GEMINI_API_KEY environment variable is not set")
+        self.openai = OpenAI(api_key=self.api_key, base_url=self.base_url)
+        self.token_count = 0
+        if model == "gemini-1.5-pro":
+            self.max_tokens = 2097152
+        else:
+            self.max_tokens = 1048576
+        self.max_output_tokens = 8192
+
+        print(model)
+
+    def count_tokens(self, text):
+        return len(tiktoken.encoding_for_model("gpt-4o").encode(text))
+
+    def generate(self, prompt):
+        try:
+            full_response = ""
+            prompt_tokens = self.count_tokens(prompt)
+
+            if prompt_tokens >= self.max_tokens:
+                print(f"Warning: Prompt exceeds maximum token limit ({prompt_tokens}/{self.max_tokens})")
+                return "Error: Prompt too long"
+
+            # Use the predefined max output tokens, or adjust if prompt is very long
+            max_completion_tokens = min(self.max_output_tokens, self.max_tokens - prompt_tokens)
+
+            # Streaming approach
+            response = self.openai.chat.completions.create(
+                model=self.model,
+                messages=[{"role": "user", "content": prompt}],
+                max_tokens=max_completion_tokens,
+                stream=True,
+            )
+
+            for chunk in response:
+                if chunk.choices[0].delta.content:
+                    chunk_text = chunk.choices[0].delta.content
+                    full_response += chunk_text
+                    print(chunk_text, end="", flush=True)
+                    if "^^^end^^^" in full_response:
+                        break
+
+            print()  # Print a newline at the end
+
+            # Extract content between markers
+            start_marker = "^^^start^^^"
+            end_marker = "^^^end^^^"
+            start_index = full_response.find(start_marker)
+            end_index = full_response.find(end_marker)
+            if start_index != -1 and end_index != -1:
+                full_response = full_response[start_index + len(start_marker) : end_index].strip()
+
+            self.token_count = self.count_tokens(full_response)
+            print(f"Token count: {self.token_count}")
+
+            return full_response
+        except Exception as e:
+            raise Exception(f"Gemini API error: {str(e)}")
+
+
 class ClaudeAPI:
     def __init__(self, model):
-        if model == "mistral-nemo":
+        if model == "qwen2.5-coder:32b":
             model = "claude-3-5-sonnet-20241022"
         self.model = model
         self.api_key = os.getenv("ANTHROPIC_API_KEY")
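
The new GeminiAPI class reuses the OpenAI SDK against Google's OpenAI-compatible endpoint instead of adding a separate Gemini client. A minimal standalone sketch of that wiring, using the same base URL and default model as the hunk above:

    import os
    from openai import OpenAI

    # Point the OpenAI client at Google's OpenAI-compatible endpoint.
    client = OpenAI(
        api_key=os.environ["GEMINI_API_KEY"],
        base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
    )
    stream = client.chat.completions.create(
        model="gemini-1.5-pro",
        messages=[{"role": "user", "content": "Say hello."}],
        stream=True,
    )
    for chunk in stream:
        if chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="", flush=True)
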
@@ -215,7 +283,7 @@ class NemoAgent:
     WRITE_RETRY_DELAY = 1  # second
 
     def __init__(
-        self, task: str, model: str = "mistral-nemo", provider: str = "ollama"
+        self, task: str, model: str = "qwen2.5-coder:32b", provider: str = "ollama"
     ):
         self.task = task
         self.model = model
@@ -242,6 +310,8 @@ def setup_llm(self):
             return OpenAIAPI(self.model)
         elif self.provider == "claude":
             return ClaudeAPI(self.model)
+        elif self.provider == "gemini":
+            return GeminiAPI(self.model)
         else:
             raise ValueError(f"Unsupported provider: {self.provider}")
 
@@ -923,11 +993,11 @@ def main
     type=click.Path(exists=True),
     help="Path to a markdown file containing the task",
 )
-@click.option("--model", default="mistral-nemo", help="The model to use for the LLM")
+@click.option("--model", default="qwen2.5-coder:32b", help="The model to use for the LLM")
 @click.option(
     "--provider",
     default="ollama",
-    type=click.Choice(["ollama", "openai", "claude"]),
+    type=click.Choice(["ollama", "openai", "claude", "gemini"]),
     help="The LLM provider to use",
 )
 @click.option(
@@ -951,7 +1021,7 @@ def main
 def cli(
     task: str = None,
     file: str = None,
-    model: str = "mistral-nemo",
+    model: str = "qwen2.5-coder:32b",
     provider: str = "ollama",
     zip: str = None,
     docs: str = None,
@@ -970,6 +1040,8 @@ def cli(
         raise ValueError("OPENAI_API_KEY environment variable is not set")
     elif provider == "claude" and not os.getenv("ANTHROPIC_API_KEY"):
         raise ValueError("ANTHROPIC_API_KEY environment variable is not set")
+    elif provider == "gemini" and not os.getenv("GEMINI_API_KEY"):
+        raise ValueError("GEMINI_API_KEY environment variable is not set")
 
     # Read task from file if provided
     if file:
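
With the Click choice, the key check, and the setup_llm branch in place, the Gemini path is selectable end to end. A hedged construction sketch (only the constructor signature appears in this diff; the agent's run entry point is not shown here):

    import os

    os.environ.setdefault("GEMINI_API_KEY", "<your key>")  # required by the guard above
    # The default model "qwen2.5-coder:32b" is remapped to "gemini-1.5-pro"
    # inside GeminiAPI.__init__ when provider="gemini".
    agent = NemoAgent(task="Write a fizzbuzz script", provider="gemini")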