Skip to content
Merged
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ dependencies = [
"SQLAlchemy==2.0.41",
"sse-starlette==2.4.1",
"starlette==0.49.1",
"strenum==0.4.15",
"tqdm==4.67.1",
"typer==0.16.0",
"types-requests==2.32.4.20250611",
Expand Down
12 changes: 6 additions & 6 deletions src/seclab_taskflow_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,18 @@
from agents.run import RunHooks
from agents import Agent, Runner, AgentHooks, RunHooks, result, function_tool, Tool, RunContextWrapper, TContext, OpenAIChatCompletionsModel, set_default_openai_client, set_default_openai_api, set_tracing_disabled

from .capi import COPILOT_INTEGRATION_ID, COPILOT_API_ENDPOINT
from .capi import COPILOT_INTEGRATION_ID, AI_API_ENDPOINT, AI_API_ENDPOINT_ENUM

# grab our secrets from .env, this must be in .gitignore
load_dotenv(find_dotenv(usecwd=True))

match urlparse(COPILOT_API_ENDPOINT).netloc:
case 'api.githubcopilot.com':
match urlparse(AI_API_ENDPOINT).netloc:
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
default_model = 'gpt-4o'
case 'models.github.ai':
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
default_model = 'openai/gpt-4o'
case _:
raise ValueError(f"Unsupported Model Endpoint: {COPILOT_API_ENDPOINT}")
raise ValueError(f"Unsupported Model Endpoint: {AI_API_ENDPOINT}")

DEFAULT_MODEL = os.getenv('COPILOT_DEFAULT_MODEL', default=default_model)

Expand Down Expand Up @@ -148,7 +148,7 @@ def __init__(self,
model_settings: ModelSettings | None = None,
run_hooks: TaskRunHooks | None = None,
agent_hooks: TaskAgentHooks | None = None):
client = AsyncOpenAI(base_url=COPILOT_API_ENDPOINT,
client = AsyncOpenAI(base_url=AI_API_ENDPOINT,
api_key=os.getenv('COPILOT_TOKEN'),
default_headers={'Copilot-Integration-Id': COPILOT_INTEGRATION_ID})
set_default_openai_client(client)
Expand Down
40 changes: 25 additions & 15 deletions src/seclab_taskflow_agent/capi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,40 +6,50 @@
import json
import logging
import os
from strenum import StrEnum
from urllib.parse import urlparse

# you can also set https://models.github.ai/inference if you prefer
# you can also set https://api.githubcopilot.com if you prefer
# but beware that your taskflows need to reference the correct model id
# since the Modeld API uses it's own id schema, use -l with your desired
# since different APIs use their own id schema, use -l with your desired
# endpoint to retrieve the correct id names to use for your taskflow
COPILOT_API_ENDPOINT = os.getenv('COPILOT_API_ENDPOINT', default='https://api.githubcopilot.com')
AI_API_ENDPOINT = os.getenv('AI_API_ENDPOINT', default='https://models.github.ai/inference')

# Enumeration of currently support API endpoints.
class AI_API_ENDPOINT_ENUM(StrEnum):
AI_API_MODELS_GITHUB = 'models.github.ai'
AI_API_GITHUBCOPILOT = 'api.githubcopilot.com'

COPILOT_INTEGRATION_ID = 'vscode-chat'

# assume we are >= python 3.9 for our type hints
def list_capi_models(token: str) -> dict[str, dict]:
"""Retrieve a dictionary of available CAPI models"""
models = {}
try:
match urlparse(COPILOT_API_ENDPOINT).netloc:
case 'api.githubcopilot.com':
netloc = urlparse(AI_API_ENDPOINT).netloc
match netloc:
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
models_catalog = 'models'
case 'models.github.ai':
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
models_catalog = 'catalog/models'
case _:
raise ValueError(f"Unsupported Model Endpoint: {COPILOT_API_ENDPOINT}")
r = httpx.get(httpx.URL(COPILOT_API_ENDPOINT).join(models_catalog),
raise ValueError(f"Unsupported Model Endpoint: {AI_API_ENDPOINT}")
r = httpx.get(httpx.URL(AI_API_ENDPOINT).join(models_catalog),
headers={
'Accept': 'application/json',
'Authorization': f'Bearer {token}',
'Copilot-Integration-Id': COPILOT_INTEGRATION_ID
})
r.raise_for_status()
# CAPI vs Models API
match urlparse(COPILOT_API_ENDPOINT).netloc:
case 'api.githubcopilot.com':
match netloc:
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
models_list = r.json().get('data', [])
case 'models.github.ai':
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
models_list = r.json()
case _:
raise ValueError(f"Unsupported Model Endpoint: {AI_API_ENDPOINT}")
for model in models_list:
models[model.get('id')] = dict(model)
except httpx.RequestError as e:
Expand All @@ -51,17 +61,17 @@ def list_capi_models(token: str) -> dict[str, dict]:
return models

def supports_tool_calls(model: str, models: dict) -> bool:
match urlparse(COPILOT_API_ENDPOINT).netloc:
case 'api.githubcopilot.com':
match urlparse(AI_API_ENDPOINT).netloc:
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
return models.get(model, {}).\
get('capabilities', {}).\
get('supports', {}).\
get('tool_calls', False)
case 'models.github.ai':
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
return 'tool-calling' in models.get(model, {}).\
get('capabilities', [])
case _:
raise ValueError(f"Unsupported Model Endpoint: {COPILOT_API_ENDPOINT}")
raise ValueError(f"Unsupported Model Endpoint: {AI_API_ENDPOINT}")

def list_tool_call_models(token: str) -> dict[str, dict]:
models = list_capi_models(token)
Expand Down
58 changes: 58 additions & 0 deletions tests/test_api_endpoint_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# SPDX-FileCopyrightText: 2025 GitHub
# SPDX-License-Identifier: MIT

"""
Test API endpoint configuration.
"""

import pytest
import os
from urllib.parse import urlparse

class TestAPIEndpoint:
"""Test API endpoint configuration."""

@staticmethod
def _reload_capi_module():
"""Helper method to reload the capi module."""
import importlib
import seclab_taskflow_agent.capi
importlib.reload(seclab_taskflow_agent.capi)

def test_default_api_endpoint(self):
"""Test that default API endpoint is set to models.github.ai/inference."""
import seclab_taskflow_agent.capi
# When no env var is set, it should default to models.github.ai/inference
# Note: We can't easily test this without manipulating the environment
# so we'll just import and verify the constant exists
endpoint = seclab_taskflow_agent.capi.AI_API_ENDPOINT
assert endpoint is not None
assert isinstance(endpoint, str)
assert urlparse(endpoint).netloc == seclab_taskflow_agent.capi.AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB

def test_api_endpoint_env_override(self):
"""Test that AI_API_ENDPOINT can be overridden by environment variable."""
# Save original env
original_env = os.environ.get('AI_API_ENDPOINT')

try:
# Set custom endpoint
test_endpoint = 'https://test.example.com'
os.environ['AI_API_ENDPOINT'] = test_endpoint

# Reload the module to pick up the new env var
self._reload_capi_module()

import seclab_taskflow_agent.capi
assert seclab_taskflow_agent.capi.AI_API_ENDPOINT == test_endpoint
finally:
# Restore original env
if original_env is None:
os.environ.pop('AI_API_ENDPOINT', None)
else:
os.environ['AI_API_ENDPOINT'] = original_env
# Reload again to restore original state
self._reload_capi_module()

if __name__ == '__main__':
pytest.main([__file__, '-v'])
73 changes: 73 additions & 0 deletions tests/test_cli_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# SPDX-FileCopyrightText: 2025 GitHub
# SPDX-License-Identifier: MIT

"""
Test CLI global variable parsing.
"""

import pytest
from seclab_taskflow_agent.available_tools import AvailableTools

class TestCliGlobals:
"""Test CLI global variable parsing."""

def test_parse_single_global(self):
"""Test parsing a single global variable from command line."""
from seclab_taskflow_agent.__main__ import parse_prompt_args
available_tools = AvailableTools()

p, t, l, cli_globals, user_prompt, _ = parse_prompt_args(
available_tools, "-t example -g fruit=apples")

assert t == "example"
assert cli_globals == {"fruit": "apples"}
assert p is None
assert l is False

def test_parse_multiple_globals(self):
"""Test parsing multiple global variables from command line."""
from seclab_taskflow_agent.__main__ import parse_prompt_args
available_tools = AvailableTools()

p, t, l, cli_globals, user_prompt, _ = parse_prompt_args(
available_tools, "-t example -g fruit=apples -g color=red")

assert t == "example"
assert cli_globals == {"fruit": "apples", "color": "red"}
assert p is None
assert l is False

def test_parse_global_with_spaces(self):
"""Test parsing global variables with spaces in values."""
from seclab_taskflow_agent.__main__ import parse_prompt_args
available_tools = AvailableTools()

p, t, l, cli_globals, user_prompt, _ = parse_prompt_args(
available_tools, "-t example -g message=hello world")

assert t == "example"
# "world" becomes part of the prompt, not the value
assert cli_globals == {"message": "hello"}
assert "world" in user_prompt

def test_parse_global_with_equals_in_value(self):
"""Test parsing global variables with equals sign in value."""
from seclab_taskflow_agent.__main__ import parse_prompt_args
available_tools = AvailableTools()

p, t, l, cli_globals, user_prompt, _ = parse_prompt_args(
available_tools, "-t example -g equation=x=5")

assert t == "example"
assert cli_globals == {"equation": "x=5"}

def test_globals_in_taskflow_file(self):
"""Test that globals can be read from taskflow file."""
available_tools = AvailableTools()

taskflow = available_tools.get_taskflow("tests.data.test_globals_taskflow")
assert 'globals' in taskflow
assert taskflow['globals']['test_var'] == 'default_value'

if __name__ == '__main__':
pytest.main([__file__, '-v'])
64 changes: 0 additions & 64 deletions tests/test_yaml_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
"""

import pytest
import tempfile
from pathlib import Path
import yaml
from seclab_taskflow_agent.available_tools import AvailableTools

class TestYamlParser:
Expand Down Expand Up @@ -43,66 +40,5 @@ def test_parse_example_taskflows(self):
assert len(example_task_flow['taskflow']) == 4 # 4 tasks in taskflow
assert example_task_flow['taskflow'][0]['task']['max_steps'] == 20

class TestCliGlobals:
"""Test CLI global variable parsing."""

def test_parse_single_global(self):
"""Test parsing a single global variable from command line."""
from seclab_taskflow_agent.__main__ import parse_prompt_args
available_tools = AvailableTools()

p, t, l, cli_globals, user_prompt, _ = parse_prompt_args(
available_tools, "-t example -g fruit=apples")

assert t == "example"
assert cli_globals == {"fruit": "apples"}
assert p is None
assert l is False

def test_parse_multiple_globals(self):
"""Test parsing multiple global variables from command line."""
from seclab_taskflow_agent.__main__ import parse_prompt_args
available_tools = AvailableTools()

p, t, l, cli_globals, user_prompt, _ = parse_prompt_args(
available_tools, "-t example -g fruit=apples -g color=red")

assert t == "example"
assert cli_globals == {"fruit": "apples", "color": "red"}
assert p is None
assert l is False

def test_parse_global_with_spaces(self):
"""Test parsing global variables with spaces in values."""
from seclab_taskflow_agent.__main__ import parse_prompt_args
available_tools = AvailableTools()

p, t, l, cli_globals, user_prompt, _ = parse_prompt_args(
available_tools, "-t example -g message=hello world")

assert t == "example"
# "world" becomes part of the prompt, not the value
assert cli_globals == {"message": "hello"}
assert "world" in user_prompt

def test_parse_global_with_equals_in_value(self):
"""Test parsing global variables with equals sign in value."""
from seclab_taskflow_agent.__main__ import parse_prompt_args
available_tools = AvailableTools()

p, t, l, cli_globals, user_prompt, _ = parse_prompt_args(
available_tools, "-t example -g equation=x=5")

assert t == "example"
assert cli_globals == {"equation": "x=5"}

def test_globals_in_taskflow_file(self):
"""Test that globals can be read from taskflow file."""
available_tools = AvailableTools()

taskflow = available_tools.get_taskflow("tests.data.test_globals_taskflow")
assert 'globals' in taskflow
assert taskflow['globals']['test_var'] == 'default_value'

if __name__ == '__main__':
pytest.main([__file__, '-v'])