Skip to content

Commit 98c254d

Browse files
Merge pull request #104 from kevinbackhouse/ai-token
Add new environment variable named AI_API_TOKEN
2 parents 9359c8b + 7dfd751 commit 98c254d

File tree

8 files changed

+65
-58
lines changed

8 files changed

+65
-58
lines changed

.devcontainer/post-attach.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ set -e
44
# If running in Codespaces, check for necessary secrets and print error if missing
55
if [ -v CODESPACES ]; then
66
echo "🔐 Running in Codespaces - injecting secrets from Codespaces settings..."
7-
if [ ! -v COPILOT_TOKEN ]; then
8-
echo "⚠️ Running in Codespaces - please add COPILOT_TOKEN to your Codespaces secrets"
7+
if [ ! -v AI_API_TOKEN ]; then
8+
echo "⚠️ Running in Codespaces - please add AI_API_TOKEN to your Codespaces secrets"
99
fi
1010
if [ ! -v GITHUB_PERSONAL_ACCESS_TOKEN ]; then
1111
echo "⚠️ Running in Codespaces - please add GITHUB_PERSONAL_ACCESS_TOKEN to your Codespaces secrets"

.github/workflows/smoketest.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ jobs:
5252
5353
- name: Run tests
5454
env:
55-
COPILOT_TOKEN: ${{ secrets.COPILOT_TOKEN }}
55+
AI_API_ENDPOINT: ${{ secrets.AI_API_ENDPOINT }}
5656
GITHUB_AUTH_HEADER: "Bearer ${{ secrets.GITHUB_TOKEN }}"
5757

5858
run: |

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,15 @@ Python >= 3.9 or Docker
3636

3737
## Configuration
3838

39-
Provide a GitHub token for an account that is entitled to use GitHub Copilot via the `COPILOT_TOKEN` environment variable. Further configuration is use case dependent, i.e. pending which MCP servers you'd like to use in your taskflows.
39+
Provide a GitHub token for an account that is entitled to use [GitHub Models](https://models.github.ai) via the `AI_API_ENDPOINT` environment variable. Further configuration is use case dependent, i.e. pending which MCP servers you'd like to use in your taskflows.
4040

4141
You can set persisting environment variables via an `.env` file in the project root.
4242

4343
Example:
4444

4545
```sh
4646
# Tokens
47-
COPILOT_TOKEN=<your_github_token>
47+
AI_API_ENDPOINT=<your_github_token>
4848
# MCP configs
4949
GITHUB_PERSONAL_ACCESS_TOKEN=<your_github_token>
5050
CODEQL_DBS_BASE_PATH="/app/my_data/codeql_databases"

docker/run.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
#
1212
# git clone https://github.com/GitHubSecurityLab/seclab-taskflow-agent.git
1313
# cd seclab-taskflow-agent/src
14-
# export COPILOT_TOKEN=<My GitHub PAT>
14+
# export AI_API_TOKEN=<My GitHub PAT>
1515
# export GITHUB_AUTH_HEADER=<My GitHub PAT>
1616
# sudo -E ../docker/run.sh -p seclab_taskflow_agent.personalities.assistant 'explain modems to me please'
1717

@@ -23,5 +23,5 @@ docker run -i \
2323
--mount type=bind,src="$PWD",dst=/app \
2424
-e DATA_DIR=/app/data \
2525
-e GITHUB_PERSONAL_ACCESS_TOKEN="$GITHUB_PERSONAL_ACCESS_TOKEN" \
26-
-e COPILOT_TOKEN="$COPILOT_TOKEN" \
26+
-e AI_API_TOKEN="$AI_API_TOKEN" \
2727
"ghcr.io/githubsecuritylab/seclab-taskflow-agent" "$@"

src/seclab_taskflow_agent/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from .render_utils import render_model_output, flush_async_output
3232
from .env_utils import TmpEnv
3333
from .agent import TaskAgent
34-
from .capi import list_tool_call_models
34+
from .capi import list_tool_call_models, get_AI_token
3535
from .available_tools import AvailableTools
3636

3737
load_dotenv(find_dotenv(usecwd=True))
@@ -686,7 +686,7 @@ async def _deploy_task_agents(resolved_agents, prompt):
686686
p, t, l, cli_globals, user_prompt, help_msg = parse_prompt_args(available_tools)
687687

688688
if l:
689-
tool_models = list_tool_call_models(os.getenv('COPILOT_TOKEN'))
689+
tool_models = list_tool_call_models(get_AI_token())
690690
for model in tool_models:
691691
print(model)
692692
sys.exit(0)

src/seclab_taskflow_agent/agent.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,19 @@
1515
from agents.run import RunHooks
1616
from agents import Agent, Runner, AgentHooks, RunHooks, result, function_tool, Tool, RunContextWrapper, TContext, OpenAIChatCompletionsModel, set_default_openai_client, set_default_openai_api, set_tracing_disabled
1717

18-
from .capi import COPILOT_INTEGRATION_ID, AI_API_ENDPOINT, AI_API_ENDPOINT_ENUM
18+
from .capi import COPILOT_INTEGRATION_ID, get_AI_endpoint, get_AI_token, AI_API_ENDPOINT_ENUM
1919

2020
# grab our secrets from .env, this must be in .gitignore
2121
load_dotenv(find_dotenv(usecwd=True))
2222

23-
match urlparse(AI_API_ENDPOINT).netloc:
23+
api_endpoint = get_AI_endpoint()
24+
match urlparse(api_endpoint).netloc:
2425
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
2526
default_model = 'gpt-4o'
2627
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
2728
default_model = 'openai/gpt-4o'
2829
case _:
29-
raise ValueError(f"Unsupported Model Endpoint: {AI_API_ENDPOINT}")
30+
raise ValueError(f"Unsupported Model Endpoint: {api_endpoint}")
3031

3132
DEFAULT_MODEL = os.getenv('COPILOT_DEFAULT_MODEL', default=default_model)
3233

@@ -148,8 +149,8 @@ def __init__(self,
148149
model_settings: ModelSettings | None = None,
149150
run_hooks: TaskRunHooks | None = None,
150151
agent_hooks: TaskAgentHooks | None = None):
151-
client = AsyncOpenAI(base_url=AI_API_ENDPOINT,
152-
api_key=os.getenv('COPILOT_TOKEN'),
152+
client = AsyncOpenAI(base_url=api_endpoint,
153+
api_key=get_AI_token(),
153154
default_headers={'Copilot-Integration-Id': COPILOT_INTEGRATION_ID})
154155
set_default_openai_client(client)
155156
# CAPI does not yet support the Responses API: https://github.com/github/copilot-api/issues/11185

src/seclab_taskflow_agent/capi.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,33 +9,49 @@
99
from strenum import StrEnum
1010
from urllib.parse import urlparse
1111

12-
# you can also set https://api.githubcopilot.com if you prefer
13-
# but beware that your taskflows need to reference the correct model id
14-
# since different APIs use their own id schema, use -l with your desired
15-
# endpoint to retrieve the correct id names to use for your taskflow
16-
AI_API_ENDPOINT = os.getenv('AI_API_ENDPOINT', default='https://models.github.ai/inference')
17-
1812
# Enumeration of currently supported API endpoints.
1913
class AI_API_ENDPOINT_ENUM(StrEnum):
2014
AI_API_MODELS_GITHUB = 'models.github.ai'
2115
AI_API_GITHUBCOPILOT = 'api.githubcopilot.com'
2216

2317
COPILOT_INTEGRATION_ID = 'vscode-chat'
2418

19+
# you can also set https://api.githubcopilot.com if you prefer
20+
# but beware that your taskflows need to reference the correct model id
21+
# since different APIs use their own id schema, use -l with your desired
22+
# endpoint to retrieve the correct id names to use for your taskflow
23+
def get_AI_endpoint():
24+
return os.getenv('AI_API_ENDPOINT', default='https://models.github.ai/inference')
25+
26+
def get_AI_token():
27+
"""
28+
Get the token for the AI API from the environment.
29+
The environment variable can be named either AI_API_TOKEN
30+
or COPILOT_TOKEN.
31+
"""
32+
token = os.getenv('AI_API_TOKEN')
33+
if token:
34+
return token
35+
token = os.getenv('COPILOT_TOKEN')
36+
if token:
37+
return token
38+
raise RuntimeError("AI_API_TOKEN environment variable is not set.")
39+
2540
# assume we are >= python 3.9 for our type hints
2641
def list_capi_models(token: str) -> dict[str, dict]:
2742
"""Retrieve a dictionary of available CAPI models"""
2843
models = {}
2944
try:
30-
netloc = urlparse(AI_API_ENDPOINT).netloc
45+
api_endpoint = get_AI_endpoint()
46+
netloc = urlparse(api_endpoint).netloc
3147
match netloc:
3248
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
3349
models_catalog = 'models'
3450
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
3551
models_catalog = 'catalog/models'
3652
case _:
37-
raise ValueError(f"Unsupported Model Endpoint: {AI_API_ENDPOINT}")
38-
r = httpx.get(httpx.URL(AI_API_ENDPOINT).join(models_catalog),
53+
raise ValueError(f"Unsupported Model Endpoint: {api_endpoint}")
54+
r = httpx.get(httpx.URL(api_endpoint).join(models_catalog),
3955
headers={
4056
'Accept': 'application/json',
4157
'Authorization': f'Bearer {token}',
@@ -49,7 +65,7 @@ def list_capi_models(token: str) -> dict[str, dict]:
4965
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
5066
models_list = r.json()
5167
case _:
52-
raise ValueError(f"Unsupported Model Endpoint: {AI_API_ENDPOINT}")
68+
raise ValueError(f"Unsupported Model Endpoint: {api_endpoint}")
5369
for model in models_list:
5470
models[model.get('id')] = dict(model)
5571
except httpx.RequestError as e:
@@ -61,7 +77,8 @@ def list_capi_models(token: str) -> dict[str, dict]:
6177
return models
6278

6379
def supports_tool_calls(model: str, models: dict) -> bool:
64-
match urlparse(AI_API_ENDPOINT).netloc:
80+
api_endpoint = get_AI_endpoint()
81+
match urlparse(api_endpoint).netloc:
6582
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
6683
return models.get(model, {}).\
6784
get('capabilities', {}).\
@@ -71,7 +88,7 @@ def supports_tool_calls(model: str, models: dict) -> bool:
7188
return 'tool-calling' in models.get(model, {}).\
7289
get('capabilities', [])
7390
case _:
74-
raise ValueError(f"Unsupported Model Endpoint: {AI_API_ENDPOINT}")
91+
raise ValueError(f"Unsupported Model Endpoint: {api_endpoint}")
7592

7693
def list_tool_call_models(token: str) -> dict[str, dict]:
7794
models = list_capi_models(token)

tests/test_api_endpoint_config.py

Lines changed: 21 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,51 +8,40 @@
88
import pytest
99
import os
1010
from urllib.parse import urlparse
11+
from seclab_taskflow_agent.capi import get_AI_endpoint, AI_API_ENDPOINT_ENUM
1112

1213
class TestAPIEndpoint:
1314
"""Test API endpoint configuration."""
14-
15-
@staticmethod
16-
def _reload_capi_module():
17-
"""Helper method to reload the capi module."""
18-
import importlib
19-
import seclab_taskflow_agent.capi
20-
importlib.reload(seclab_taskflow_agent.capi)
21-
15+
2216
def test_default_api_endpoint(self):
2317
"""Test that default API endpoint is set to models.github.ai/inference."""
24-
import seclab_taskflow_agent.capi
2518
# When no env var is set, it should default to models.github.ai/inference
26-
# Note: We can't easily test this without manipulating the environment
27-
# so we'll just import and verify the constant exists
28-
endpoint = seclab_taskflow_agent.capi.AI_API_ENDPOINT
29-
assert endpoint is not None
30-
assert isinstance(endpoint, str)
31-
assert urlparse(endpoint).netloc == seclab_taskflow_agent.capi.AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB
32-
19+
try:
20+
# Save original env
21+
original_env = os.environ.pop('AI_API_ENDPOINT', None)
22+
endpoint = get_AI_endpoint()
23+
assert endpoint is not None
24+
assert isinstance(endpoint, str)
25+
assert urlparse(endpoint).netloc == AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB
26+
finally:
27+
# Restore original env
28+
if original_env:
29+
os.environ['AI_API_ENDPOINT'] = original_env
30+
3331
def test_api_endpoint_env_override(self):
3432
"""Test that AI_API_ENDPOINT can be overridden by environment variable."""
35-
# Save original env
36-
original_env = os.environ.get('AI_API_ENDPOINT')
37-
3833
try:
39-
# Set custom endpoint
40-
test_endpoint = 'https://test.example.com'
34+
# Save original env
35+
original_env = os.environ.pop('AI_API_ENDPOINT', None)
36+
# Set different endpoint
37+
test_endpoint = 'https://api.githubcopilot.com'
4138
os.environ['AI_API_ENDPOINT'] = test_endpoint
42-
43-
# Reload the module to pick up the new env var
44-
self._reload_capi_module()
45-
46-
import seclab_taskflow_agent.capi
47-
assert seclab_taskflow_agent.capi.AI_API_ENDPOINT == test_endpoint
39+
40+
assert get_AI_endpoint() == test_endpoint
4841
finally:
4942
# Restore original env
50-
if original_env is None:
51-
os.environ.pop('AI_API_ENDPOINT', None)
52-
else:
43+
if original_env:
5344
os.environ['AI_API_ENDPOINT'] = original_env
54-
# Reload again to restore original state
55-
self._reload_capi_module()
5645

5746
if __name__ == '__main__':
5847
pytest.main([__file__, '-v'])

0 commit comments

Comments
 (0)