4 changes: 2 additions & 2 deletions .devcontainer/post-attach.sh
@@ -4,8 +4,8 @@ set -e
# If running in Codespaces, check for necessary secrets and print error if missing
if [ -v CODESPACES ]; then
echo "🔐 Running in Codespaces - injecting secrets from Codespaces settings..."
-if [ ! -v COPILOT_TOKEN ]; then
-echo "⚠️ Running in Codespaces - please add COPILOT_TOKEN to your Codespaces secrets"
+if [ ! -v AI_API_TOKEN ]; then
+echo "⚠️ Running in Codespaces - please add AI_API_TOKEN to your Codespaces secrets"
fi
if [ ! -v GITHUB_PERSONAL_ACCESS_TOKEN ]; then
echo "⚠️ Running in Codespaces - please add GITHUB_PERSONAL_ACCESS_TOKEN to your Codespaces secrets"
2 changes: 1 addition & 1 deletion .github/workflows/smoketest.yaml
@@ -52,7 +52,7 @@ jobs:

- name: Run tests
env:
-COPILOT_TOKEN: ${{ secrets.COPILOT_TOKEN }}
+AI_API_ENDPOINT: ${{ secrets.AI_API_ENDPOINT }}
GITHUB_AUTH_HEADER: "Bearer ${{ secrets.GITHUB_TOKEN }}"

run: |
4 changes: 2 additions & 2 deletions README.md
@@ -36,15 +36,15 @@ Python >= 3.9 or Docker

## Configuration

-Provide a GitHub token for an account that is entitled to use GitHub Copilot via the `COPILOT_TOKEN` environment variable. Further configuration is use case dependent, i.e. pending which MCP servers you'd like to use in your taskflows.
+Provide a GitHub token for an account that is entitled to use [GitHub Models](https://models.github.ai) via the `AI_API_TOKEN` environment variable. Further configuration depends on your use case, i.e. on which MCP servers you'd like to use in your taskflows.

You can set persisting environment variables via an `.env` file in the project root.

Example:

```sh
# Tokens
-COPILOT_TOKEN=<your_github_token>
+AI_API_TOKEN=<your_github_token>
# MCP configs
GITHUB_PERSONAL_ACCESS_TOKEN=<your_github_token>
CODEQL_DBS_BASE_PATH="/app/my_data/codeql_databases"
4 changes: 2 additions & 2 deletions docker/run.sh
@@ -11,7 +11,7 @@
#
# git clone https://github.com/GitHubSecurityLab/seclab-taskflow-agent.git
# cd seclab-taskflow-agent/src
-# export COPILOT_TOKEN=<My GitHub PAT>
+# export AI_API_TOKEN=<My GitHub PAT>
# export GITHUB_AUTH_HEADER=<My GitHub PAT>
# sudo -E ../docker/run.sh -p seclab_taskflow_agent.personalities.assistant 'explain modems to me please'

@@ -23,5 +23,5 @@ docker run -i \
--mount type=bind,src="$PWD",dst=/app \
-e DATA_DIR=/app/data \
-e GITHUB_PERSONAL_ACCESS_TOKEN="$GITHUB_PERSONAL_ACCESS_TOKEN" \
- -e COPILOT_TOKEN="$COPILOT_TOKEN" \
+ -e AI_API_TOKEN="$AI_API_TOKEN" \
"ghcr.io/githubsecuritylab/seclab-taskflow-agent" "$@"
4 changes: 2 additions & 2 deletions src/seclab_taskflow_agent/__main__.py
@@ -31,7 +31,7 @@
from .render_utils import render_model_output, flush_async_output
from .env_utils import TmpEnv
from .agent import TaskAgent
-from .capi import list_tool_call_models
+from .capi import list_tool_call_models, get_AI_token
from .available_tools import AvailableTools

load_dotenv(find_dotenv(usecwd=True))
@@ -686,7 +686,7 @@ async def _deploy_task_agents(resolved_agents, prompt):
p, t, l, cli_globals, user_prompt, help_msg = parse_prompt_args(available_tools)

if l:
-tool_models = list_tool_call_models(os.getenv('COPILOT_TOKEN'))
+tool_models = list_tool_call_models(get_AI_token())
for model in tool_models:
print(model)
sys.exit(0)
11 changes: 6 additions & 5 deletions src/seclab_taskflow_agent/agent.py
@@ -15,18 +15,19 @@
from agents.run import RunHooks
from agents import Agent, Runner, AgentHooks, RunHooks, result, function_tool, Tool, RunContextWrapper, TContext, OpenAIChatCompletionsModel, set_default_openai_client, set_default_openai_api, set_tracing_disabled

-from .capi import COPILOT_INTEGRATION_ID, AI_API_ENDPOINT, AI_API_ENDPOINT_ENUM
+from .capi import COPILOT_INTEGRATION_ID, get_AI_endpoint, get_AI_token, AI_API_ENDPOINT_ENUM

# grab our secrets from .env, this must be in .gitignore
load_dotenv(find_dotenv(usecwd=True))

-match urlparse(AI_API_ENDPOINT).netloc:
+api_endpoint = get_AI_endpoint()
+match urlparse(api_endpoint).netloc:
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
default_model = 'gpt-4o'
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
default_model = 'openai/gpt-4o'
case _:
-raise ValueError(f"Unsupported Model Endpoint: {AI_API_ENDPOINT}")
+raise ValueError(f"Unsupported Model Endpoint: {api_endpoint}")

DEFAULT_MODEL = os.getenv('COPILOT_DEFAULT_MODEL', default=default_model)

@@ -148,8 +149,8 @@ def __init__(self,
model_settings: ModelSettings | None = None,
run_hooks: TaskRunHooks | None = None,
agent_hooks: TaskAgentHooks | None = None):
-client = AsyncOpenAI(base_url=AI_API_ENDPOINT,
-api_key=os.getenv('COPILOT_TOKEN'),
+client = AsyncOpenAI(base_url=api_endpoint,
+api_key=get_AI_token(),
default_headers={'Copilot-Integration-Id': COPILOT_INTEGRATION_ID})
set_default_openai_client(client)
# CAPI does not yet support the Responses API: https://github.com/github/copilot-api/issues/11185
41 changes: 29 additions & 12 deletions src/seclab_taskflow_agent/capi.py
@@ -9,33 +9,49 @@
from strenum import StrEnum
from urllib.parse import urlparse

-# you can also set https://api.githubcopilot.com if you prefer
-# but beware that your taskflows need to reference the correct model id
-# since different APIs use their own id schema, use -l with your desired
-# endpoint to retrieve the correct id names to use for your taskflow
-AI_API_ENDPOINT = os.getenv('AI_API_ENDPOINT', default='https://models.github.ai/inference')

# Enumeration of currently supported API endpoints.
class AI_API_ENDPOINT_ENUM(StrEnum):
AI_API_MODELS_GITHUB = 'models.github.ai'
AI_API_GITHUBCOPILOT = 'api.githubcopilot.com'

COPILOT_INTEGRATION_ID = 'vscode-chat'

+# you can also set https://api.githubcopilot.com if you prefer
+# but beware that your taskflows need to reference the correct model id
+# since different APIs use their own id schema, use -l with your desired
+# endpoint to retrieve the correct id names to use for your taskflow
+def get_AI_endpoint():
+return os.getenv('AI_API_ENDPOINT', default='https://models.github.ai/inference')

+def get_AI_token():
+"""
+Get the token for the AI API from the environment.
+The environment variable can be named either AI_API_TOKEN
+or COPILOT_TOKEN.
+"""
+token = os.getenv('AI_API_TOKEN')
+if token:
+return token
+token = os.getenv('COPILOT_TOKEN')
+if token:
+return token
+raise RuntimeError("AI_API_TOKEN environment variable is not set.")

# assume we are >= python 3.9 for our type hints
def list_capi_models(token: str) -> dict[str, dict]:
"""Retrieve a dictionary of available CAPI models"""
models = {}
try:
-netloc = urlparse(AI_API_ENDPOINT).netloc
+api_endpoint = get_AI_endpoint()
+netloc = urlparse(api_endpoint).netloc
match netloc:
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
models_catalog = 'models'
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
models_catalog = 'catalog/models'
case _:
-raise ValueError(f"Unsupported Model Endpoint: {AI_API_ENDPOINT}")
-r = httpx.get(httpx.URL(AI_API_ENDPOINT).join(models_catalog),
+raise ValueError(f"Unsupported Model Endpoint: {api_endpoint}")
+r = httpx.get(httpx.URL(api_endpoint).join(models_catalog),
headers={
'Accept': 'application/json',
'Authorization': f'Bearer {token}',
@@ -49,7 +65,7 @@ def list_capi_models(token: str) -> dict[str, dict]:
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
models_list = r.json()
case _:
-raise ValueError(f"Unsupported Model Endpoint: {AI_API_ENDPOINT}")
+raise ValueError(f"Unsupported Model Endpoint: {api_endpoint}")
for model in models_list:
models[model.get('id')] = dict(model)
except httpx.RequestError as e:
@@ -61,7 +77,8 @@ def supports_tool_calls(model: str, models: dict) -> bool:
return models

def supports_tool_calls(model: str, models: dict) -> bool:
-match urlparse(AI_API_ENDPOINT).netloc:
+api_endpoint = get_AI_endpoint()
+match urlparse(api_endpoint).netloc:
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
return models.get(model, {}).\
get('capabilities', {}).\
@@ -71,7 +88,7 @@ def supports_tool_calls(model: str, models: dict) -> bool:
return 'tool-calling' in models.get(model, {}).\
get('capabilities', [])
case _:
-raise ValueError(f"Unsupported Model Endpoint: {AI_API_ENDPOINT}")
+raise ValueError(f"Unsupported Model Endpoint: {api_endpoint}")

def list_tool_call_models(token: str) -> dict[str, dict]:
models = list_capi_models(token)
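For orientation only (not part of the diff): the new helpers are meant to be used together roughly as follows — a minimal sketch, assuming the package is importable and `AI_API_TOKEN` (or the legacy `COPILOT_TOKEN`) is exported.

```python
# Illustrative sketch, not part of the PR.
# Assumes AI_API_TOKEN (or COPILOT_TOKEN) is set in the environment.
from seclab_taskflow_agent.capi import (
    get_AI_endpoint,
    get_AI_token,
    list_tool_call_models,
)

endpoint = get_AI_endpoint()   # defaults to https://models.github.ai/inference
token = get_AI_token()         # reads AI_API_TOKEN, falls back to COPILOT_TOKEN
for model_id in list_tool_call_models(token):
    print(f"{model_id} supports tool calls via {endpoint}")
```

This is the same listing the CLI's `-l` flag prints via `list_tool_call_models(get_AI_token())` in `__main__.py` above.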
53 changes: 21 additions & 32 deletions tests/test_api_endpoint_config.py
@@ -8,51 +8,40 @@
import pytest
import os
from urllib.parse import urlparse
+from seclab_taskflow_agent.capi import get_AI_endpoint, AI_API_ENDPOINT_ENUM

class TestAPIEndpoint:
"""Test API endpoint configuration."""

-@staticmethod
-def _reload_capi_module():
-"""Helper method to reload the capi module."""
-import importlib
-import seclab_taskflow_agent.capi
-importlib.reload(seclab_taskflow_agent.capi)


def test_default_api_endpoint(self):
"""Test that default API endpoint is set to models.github.ai/inference."""
-import seclab_taskflow_agent.capi
-# When no env var is set, it should default to models.github.ai/inference
-# Note: We can't easily test this without manipulating the environment
-# so we'll just import and verify the constant exists
-endpoint = seclab_taskflow_agent.capi.AI_API_ENDPOINT
-assert endpoint is not None
-assert isinstance(endpoint, str)
-assert urlparse(endpoint).netloc == seclab_taskflow_agent.capi.AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB

+try:
+# Save original env
+original_env = os.environ.pop('AI_API_ENDPOINT', None)
+endpoint = get_AI_endpoint()
+assert endpoint is not None
+assert isinstance(endpoint, str)
+assert urlparse(endpoint).netloc == AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB
+finally:
+# Restore original env
+if original_env:
+os.environ['AI_API_ENDPOINT'] = original_env

def test_api_endpoint_env_override(self):
"""Test that AI_API_ENDPOINT can be overridden by environment variable."""
-# Save original env
-original_env = os.environ.get('AI_API_ENDPOINT')

try:
-# Set custom endpoint
-test_endpoint = 'https://test.example.com'
+# Save original env
+original_env = os.environ.pop('AI_API_ENDPOINT', None)
+# Set different endpoint
+test_endpoint = 'https://api.githubcopilot.com'
os.environ['AI_API_ENDPOINT'] = test_endpoint

-# Reload the module to pick up the new env var
-self._reload_capi_module()

-import seclab_taskflow_agent.capi
-assert seclab_taskflow_agent.capi.AI_API_ENDPOINT == test_endpoint

+assert get_AI_endpoint() == test_endpoint
finally:
# Restore original env
-if original_env is None:
-os.environ.pop('AI_API_ENDPOINT', None)
-else:
+if original_env:
os.environ['AI_API_ENDPOINT'] = original_env
-# Reload again to restore original state
-self._reload_capi_module()

if __name__ == '__main__':
pytest.main([__file__, '-v'])
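A side note on the test design (an editorial suggestion, not part of this PR): pytest's built-in `monkeypatch` fixture could replace the manual save/restore of `AI_API_ENDPOINT`, since it undoes environment changes automatically after each test. A minimal sketch:

```python
# Hypothetical alternative to the manual try/finally env handling above;
# pytest's monkeypatch fixture restores the environment after each test.
from urllib.parse import urlparse

from seclab_taskflow_agent.capi import get_AI_endpoint, AI_API_ENDPOINT_ENUM


def test_default_endpoint(monkeypatch):
    # Remove any override so the built-in default applies.
    monkeypatch.delenv('AI_API_ENDPOINT', raising=False)
    assert urlparse(get_AI_endpoint()).netloc == AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB


def test_endpoint_override(monkeypatch):
    monkeypatch.setenv('AI_API_ENDPOINT', 'https://api.githubcopilot.com')
    assert get_AI_endpoint() == 'https://api.githubcopilot.com'
```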