|
5 | 5 | from pathlib import Path |
6 | 6 |
|
7 | 7 | import openai |
8 | | -from openai.api_resources.model import Model |
9 | | -from openai.error import APIError |
| 8 | +from openai import APIError |
10 | 9 |
|
11 | 10 | from binaryninja.function import Function |
12 | 11 | from binaryninja.lowlevelil import LowLevelILFunction |
|
18 | 17 |
|
19 | 18 | from . query import Query |
20 | 19 | from . c import Pseudo_C |
| 20 | +from . exceptions import NoAPIKeyException |
21 | 21 |
|
22 | 22 |
|
23 | 23 | class Agent: |
@@ -45,7 +45,7 @@ def __init__(self, |
45 | 45 | path_to_api_key: Optional[Path]=None) -> None: |
46 | 46 |
|
47 | 47 | # Read the API key from the environment variable. |
48 | | - openai.api_key = self.read_api_key(path_to_api_key) |
| 48 | + self.client = openai.OpenAI(api_key=self.read_api_key(filename=path_to_api_key)) |
49 | 49 |
|
50 | 50 | assert bv is not None, 'BinaryView is None. Check how you called this function.' |
51 | 51 | # Set instance attributes. |
@@ -87,12 +87,12 @@ def read_api_key(self, filename: Optional[Path]=None) -> str: |
87 | 87 | except FileNotFoundError: |
88 | 88 | log.log_error(f'Could not find API key file at {filename}.') |
89 | 89 |
|
90 | | - raise APIError('No API key found. Refer to the documentation to add the ' |
| 90 | + raise NoAPIKeyException('No API key found. Refer to the documentation to add the ' |
91 | 91 | 'API key.') |
92 | 92 |
|
93 | 93 | def is_valid_model(self, model: str) -> bool: |
94 | 94 | '''Checks if the model is valid by querying the OpenAI API.''' |
95 | | - models: list[Model] = openai.Model.list().data |
| 95 | + models: list = self.client.models.list().data |
96 | 96 | return model in [m.id for m in models] |
97 | 97 |
|
98 | 98 | def get_model(self) -> str: |
@@ -206,7 +206,8 @@ def rename_variable(self, response: str) -> None: |
206 | 206 |
|
207 | 207 | def send_query(self, query: str, callback: Optional[Callable]=None) -> None: |
208 | 208 | '''Sends a query to the engine and prints the response.''' |
209 | | - query = Query(query_string=query, |
| 209 | + query = Query(client=self.client, |
| 210 | + query_string=query, |
210 | 211 | model=self.model, |
211 | 212 | max_token_count=self.get_token_count(), |
212 | 213 | callback_function=callback) |
|
0 commit comments