1616import tiktoken
1717
1818class OpenAIAPI :
19- def __init__ (self , model ):
19+ def __init__ (self , model , api_key ):
2020 self .model = model
21- self .api_key = os .getenv ("OPENAI_API_KEY" )
22- if not self .api_key :
23- raise ValueError ("OPENAI_API_KEY environment variable is not set" )
21+ self .api_key = api_key
2422 self .openai = OpenAI (api_key = self .api_key )
2523 self .token_count = 0
2624 if "4.1" in model :
@@ -97,14 +95,12 @@ def generate(self, prompt):
9795
9896
9997class GeminiAPI :
100- def __init__ (self , model ):
98+ def __init__ (self , model , api_key ):
10199 if model == "gpt-4.1" :
102100 model = "gemini-2.5-pro-preview-05-06"
103101 self .model = model
104- self .api_key = os . getenv ( "GEMINI_API_KEY" )
102+ self .api_key = api_key
105103 self .base_url = "https://generativelanguage.googleapis.com/v1beta/openai/"
106- if not self .api_key :
107- raise ValueError ("GEMINI_API_KEY environment variable is not set" )
108104 self .openai = OpenAI (api_key = self .api_key , base_url = self .base_url )
109105 self .token_count = 0
110106 self .max_tokens = 65536
@@ -164,10 +160,11 @@ class NemoAgent:
164160 WRITE_RETRY_DELAY = 1 # second
165161
166162 def __init__ (
167- self , task : str , model : str = "gpt-4.1" , provider : str = "openai" , tests : bool = True
163+ self , task : str , api_key : str , model : str = "gpt-4.1" , provider : str = "openai" , tests : bool = True
168164 ):
169165 self .task = task
170166 self .model = model
167+ self .api_key = api_key
171168 self .provider = provider
172169 self .setup_logging ()
173170 self .project_name = self .generate_project_name ()
@@ -187,9 +184,9 @@ def count_tokens(self, text):
187184
188185 def setup_llm (self ):
189186 if self .provider == "openai" :
190- return OpenAIAPI (self .model )
187+ return OpenAIAPI (self .model , self . api_key )
191188 elif self .provider == "gemini" :
192- return GeminiAPI (self .model )
189+ return GeminiAPI (self .model , self . api_key )
193190 else :
194191 raise ValueError (f"Unsupported provider: { self .provider } " )
195192
@@ -925,6 +922,11 @@ def cli(
925922 raise ValueError ("OPENAI_API_KEY environment variable is not set" )
926923 elif provider == "gemini" and not os .getenv ("GEMINI_API_KEY" ):
927924 raise ValueError ("GEMINI_API_KEY environment variable is not set" )
925+
926+ if provider == "gemini" :
927+ api_key = os .getenv ("GEMINI_API_KEY" )
928+ else :
929+ api_key = os .getenv ("OPENAI_API_KEY" )
928930
929931 # Read task from file if provided
930932 if file :
@@ -938,7 +940,7 @@ def cli(
938940 elif not task :
939941 task = click .prompt ("Please enter your task" )
940942
941- nemo_agent = NemoAgent (task = task , model = model , provider = provider , tests = tests )
943+ nemo_agent = NemoAgent (task = task , model = model , provider = provider , tests = tests , api_key = api_key )
942944
943945 # Ingest docs if provided
944946 if docs :
0 commit comments