Skip to content

Commit 0315a96

Browse files
Merge branch 'main' into remove-hardcoded-models
2 parents 780ef2a + 8c6bc9d commit 0315a96

File tree

6 files changed

+163
-85
lines changed

6 files changed

+163
-85
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ dependencies = [
9797
"SQLAlchemy==2.0.41",
9898
"sse-starlette==2.4.1",
9999
"starlette==0.49.1",
100+
"strenum==0.4.15",
100101
"tqdm==4.67.1",
101102
"typer==0.16.0",
102103
"types-requests==2.32.4.20250611",

src/seclab_taskflow_agent/agent.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,18 @@
1515
from agents.run import RunHooks
1616
from agents import Agent, Runner, AgentHooks, RunHooks, result, function_tool, Tool, RunContextWrapper, TContext, OpenAIChatCompletionsModel, set_default_openai_client, set_default_openai_api, set_tracing_disabled
1717

18-
from .capi import COPILOT_INTEGRATION_ID, COPILOT_API_ENDPOINT
18+
from .capi import COPILOT_INTEGRATION_ID, AI_API_ENDPOINT, AI_API_ENDPOINT_ENUM
1919

2020
# grab our secrets from .env, this must be in .gitignore
2121
load_dotenv(find_dotenv(usecwd=True))
2222

23-
match urlparse(COPILOT_API_ENDPOINT).netloc:
24-
case 'api.githubcopilot.com':
23+
match urlparse(AI_API_ENDPOINT).netloc:
24+
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
2525
default_model = 'gpt-4o'
26-
case 'models.github.ai':
26+
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
2727
default_model = 'openai/gpt-4o'
2828
case _:
29-
raise ValueError(f"Unsupported Model Endpoint: {COPILOT_API_ENDPOINT}")
29+
raise ValueError(f"Unsupported Model Endpoint: {AI_API_ENDPOINT}")
3030

3131
DEFAULT_MODEL = os.getenv('COPILOT_DEFAULT_MODEL', default=default_model)
3232

@@ -148,7 +148,7 @@ def __init__(self,
148148
model_settings: ModelSettings | None = None,
149149
run_hooks: TaskRunHooks | None = None,
150150
agent_hooks: TaskAgentHooks | None = None):
151-
client = AsyncOpenAI(base_url=COPILOT_API_ENDPOINT,
151+
client = AsyncOpenAI(base_url=AI_API_ENDPOINT,
152152
api_key=os.getenv('COPILOT_TOKEN'),
153153
default_headers={'Copilot-Integration-Id': COPILOT_INTEGRATION_ID})
154154
set_default_openai_client(client)

src/seclab_taskflow_agent/capi.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,40 +6,50 @@
66
import json
77
import logging
88
import os
9+
from strenum import StrEnum
910
from urllib.parse import urlparse
1011

11-
# you can also set https://models.github.ai/inference if you prefer
12+
# you can also set https://api.githubcopilot.com if you prefer
1213
# but beware that your taskflows need to reference the correct model id
13-
# since the Modeld API uses it's own id schema, use -l with your desired
14+
# since different APIs use their own id schema, use -l with your desired
1415
# endpoint to retrieve the correct id names to use for your taskflow
15-
COPILOT_API_ENDPOINT = os.getenv('COPILOT_API_ENDPOINT', default='https://api.githubcopilot.com')
16+
AI_API_ENDPOINT = os.getenv('AI_API_ENDPOINT', default='https://models.github.ai/inference')
17+
18+
# Enumeration of currently supported API endpoints.
19+
class AI_API_ENDPOINT_ENUM(StrEnum):
20+
AI_API_MODELS_GITHUB = 'models.github.ai'
21+
AI_API_GITHUBCOPILOT = 'api.githubcopilot.com'
22+
1623
COPILOT_INTEGRATION_ID = 'vscode-chat'
1724

1825
# assume we are >= python 3.9 for our type hints
1926
def list_capi_models(token: str) -> dict[str, dict]:
2027
"""Retrieve a dictionary of available CAPI models"""
2128
models = {}
2229
try:
23-
match urlparse(COPILOT_API_ENDPOINT).netloc:
24-
case 'api.githubcopilot.com':
30+
netloc = urlparse(AI_API_ENDPOINT).netloc
31+
match netloc:
32+
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
2533
models_catalog = 'models'
26-
case 'models.github.ai':
34+
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
2735
models_catalog = 'catalog/models'
2836
case _:
29-
raise ValueError(f"Unsupported Model Endpoint: {COPILOT_API_ENDPOINT}")
30-
r = httpx.get(httpx.URL(COPILOT_API_ENDPOINT).join(models_catalog),
37+
raise ValueError(f"Unsupported Model Endpoint: {AI_API_ENDPOINT}")
38+
r = httpx.get(httpx.URL(AI_API_ENDPOINT).join(models_catalog),
3139
headers={
3240
'Accept': 'application/json',
3341
'Authorization': f'Bearer {token}',
3442
'Copilot-Integration-Id': COPILOT_INTEGRATION_ID
3543
})
3644
r.raise_for_status()
3745
# CAPI vs Models API
38-
match urlparse(COPILOT_API_ENDPOINT).netloc:
39-
case 'api.githubcopilot.com':
46+
match netloc:
47+
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
4048
models_list = r.json().get('data', [])
41-
case 'models.github.ai':
49+
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
4250
models_list = r.json()
51+
case _:
52+
raise ValueError(f"Unsupported Model Endpoint: {AI_API_ENDPOINT}")
4353
for model in models_list:
4454
models[model.get('id')] = dict(model)
4555
except httpx.RequestError as e:
@@ -51,17 +61,17 @@ def list_capi_models(token: str) -> dict[str, dict]:
5161
return models
5262

5363
def supports_tool_calls(model: str, models: dict) -> bool:
54-
match urlparse(COPILOT_API_ENDPOINT).netloc:
55-
case 'api.githubcopilot.com':
64+
match urlparse(AI_API_ENDPOINT).netloc:
65+
case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT:
5666
return models.get(model, {}).\
5767
get('capabilities', {}).\
5868
get('supports', {}).\
5969
get('tool_calls', False)
60-
case 'models.github.ai':
70+
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
6171
return 'tool-calling' in models.get(model, {}).\
6272
get('capabilities', [])
6373
case _:
64-
raise ValueError(f"Unsupported Model Endpoint: {COPILOT_API_ENDPOINT}")
74+
raise ValueError(f"Unsupported Model Endpoint: {AI_API_ENDPOINT}")
6575

6676
def list_tool_call_models(token: str) -> dict[str, dict]:
6777
models = list_capi_models(token)

tests/test_api_endpoint_config.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# SPDX-FileCopyrightText: 2025 GitHub
2+
# SPDX-License-Identifier: MIT
3+
4+
"""
5+
Test API endpoint configuration.
6+
"""
7+
8+
import pytest
9+
import os
10+
from urllib.parse import urlparse
11+
12+
class TestAPIEndpoint:
13+
"""Test API endpoint configuration."""
14+
15+
@staticmethod
16+
def _reload_capi_module():
17+
"""Helper method to reload the capi module."""
18+
import importlib
19+
import seclab_taskflow_agent.capi
20+
importlib.reload(seclab_taskflow_agent.capi)
21+
22+
def test_default_api_endpoint(self):
23+
"""Test that default API endpoint is set to models.github.ai/inference."""
24+
import seclab_taskflow_agent.capi
25+
# When no env var is set, it should default to models.github.ai/inference
26+
# Note: We can't easily test this without manipulating the environment
27+
# so we'll just import and verify the constant exists
28+
endpoint = seclab_taskflow_agent.capi.AI_API_ENDPOINT
29+
assert endpoint is not None
30+
assert isinstance(endpoint, str)
31+
assert urlparse(endpoint).netloc == seclab_taskflow_agent.capi.AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB
32+
33+
def test_api_endpoint_env_override(self):
34+
"""Test that AI_API_ENDPOINT can be overridden by environment variable."""
35+
# Save original env
36+
original_env = os.environ.get('AI_API_ENDPOINT')
37+
38+
try:
39+
# Set custom endpoint
40+
test_endpoint = 'https://test.example.com'
41+
os.environ['AI_API_ENDPOINT'] = test_endpoint
42+
43+
# Reload the module to pick up the new env var
44+
self._reload_capi_module()
45+
46+
import seclab_taskflow_agent.capi
47+
assert seclab_taskflow_agent.capi.AI_API_ENDPOINT == test_endpoint
48+
finally:
49+
# Restore original env
50+
if original_env is None:
51+
os.environ.pop('AI_API_ENDPOINT', None)
52+
else:
53+
os.environ['AI_API_ENDPOINT'] = original_env
54+
# Reload again to restore original state
55+
self._reload_capi_module()
56+
57+
if __name__ == '__main__':
58+
pytest.main([__file__, '-v'])

tests/test_cli_parser.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# SPDX-FileCopyrightText: 2025 GitHub
2+
# SPDX-License-Identifier: MIT
3+
4+
"""
5+
Test CLI global variable parsing.
6+
"""
7+
8+
import pytest
9+
from seclab_taskflow_agent.available_tools import AvailableTools
10+
11+
class TestCliGlobals:
12+
"""Test CLI global variable parsing."""
13+
14+
def test_parse_single_global(self):
15+
"""Test parsing a single global variable from command line."""
16+
from seclab_taskflow_agent.__main__ import parse_prompt_args
17+
available_tools = AvailableTools()
18+
19+
p, t, l, cli_globals, user_prompt, _ = parse_prompt_args(
20+
available_tools, "-t example -g fruit=apples")
21+
22+
assert t == "example"
23+
assert cli_globals == {"fruit": "apples"}
24+
assert p is None
25+
assert l is False
26+
27+
def test_parse_multiple_globals(self):
28+
"""Test parsing multiple global variables from command line."""
29+
from seclab_taskflow_agent.__main__ import parse_prompt_args
30+
available_tools = AvailableTools()
31+
32+
p, t, l, cli_globals, user_prompt, _ = parse_prompt_args(
33+
available_tools, "-t example -g fruit=apples -g color=red")
34+
35+
assert t == "example"
36+
assert cli_globals == {"fruit": "apples", "color": "red"}
37+
assert p is None
38+
assert l is False
39+
40+
def test_parse_global_with_spaces(self):
41+
"""Test parsing global variables with spaces in values."""
42+
from seclab_taskflow_agent.__main__ import parse_prompt_args
43+
available_tools = AvailableTools()
44+
45+
p, t, l, cli_globals, user_prompt, _ = parse_prompt_args(
46+
available_tools, "-t example -g message=hello world")
47+
48+
assert t == "example"
49+
# "world" becomes part of the prompt, not the value
50+
assert cli_globals == {"message": "hello"}
51+
assert "world" in user_prompt
52+
53+
def test_parse_global_with_equals_in_value(self):
54+
"""Test parsing global variables with equals sign in value."""
55+
from seclab_taskflow_agent.__main__ import parse_prompt_args
56+
available_tools = AvailableTools()
57+
58+
p, t, l, cli_globals, user_prompt, _ = parse_prompt_args(
59+
available_tools, "-t example -g equation=x=5")
60+
61+
assert t == "example"
62+
assert cli_globals == {"equation": "x=5"}
63+
64+
def test_globals_in_taskflow_file(self):
65+
"""Test that globals can be read from taskflow file."""
66+
available_tools = AvailableTools()
67+
68+
taskflow = available_tools.get_taskflow("tests.data.test_globals_taskflow")
69+
assert 'globals' in taskflow
70+
assert taskflow['globals']['test_var'] == 'default_value'
71+
72+
if __name__ == '__main__':
73+
pytest.main([__file__, '-v'])

tests/test_yaml_parser.py

Lines changed: 0 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@
88
"""
99

1010
import pytest
11-
import tempfile
12-
from pathlib import Path
13-
import yaml
1411
from seclab_taskflow_agent.available_tools import AvailableTools
1512

1613
class TestYamlParser:
@@ -43,66 +40,5 @@ def test_parse_example_taskflows(self):
4340
assert len(example_task_flow['taskflow']) == 4 # 4 tasks in taskflow
4441
assert example_task_flow['taskflow'][0]['task']['max_steps'] == 20
4542

46-
class TestCliGlobals:
47-
"""Test CLI global variable parsing."""
48-
49-
def test_parse_single_global(self):
50-
"""Test parsing a single global variable from command line."""
51-
from seclab_taskflow_agent.__main__ import parse_prompt_args
52-
available_tools = AvailableTools()
53-
54-
p, t, l, cli_globals, user_prompt, _ = parse_prompt_args(
55-
available_tools, "-t example -g fruit=apples")
56-
57-
assert t == "example"
58-
assert cli_globals == {"fruit": "apples"}
59-
assert p is None
60-
assert l is False
61-
62-
def test_parse_multiple_globals(self):
63-
"""Test parsing multiple global variables from command line."""
64-
from seclab_taskflow_agent.__main__ import parse_prompt_args
65-
available_tools = AvailableTools()
66-
67-
p, t, l, cli_globals, user_prompt, _ = parse_prompt_args(
68-
available_tools, "-t example -g fruit=apples -g color=red")
69-
70-
assert t == "example"
71-
assert cli_globals == {"fruit": "apples", "color": "red"}
72-
assert p is None
73-
assert l is False
74-
75-
def test_parse_global_with_spaces(self):
76-
"""Test parsing global variables with spaces in values."""
77-
from seclab_taskflow_agent.__main__ import parse_prompt_args
78-
available_tools = AvailableTools()
79-
80-
p, t, l, cli_globals, user_prompt, _ = parse_prompt_args(
81-
available_tools, "-t example -g message=hello world")
82-
83-
assert t == "example"
84-
# "world" becomes part of the prompt, not the value
85-
assert cli_globals == {"message": "hello"}
86-
assert "world" in user_prompt
87-
88-
def test_parse_global_with_equals_in_value(self):
89-
"""Test parsing global variables with equals sign in value."""
90-
from seclab_taskflow_agent.__main__ import parse_prompt_args
91-
available_tools = AvailableTools()
92-
93-
p, t, l, cli_globals, user_prompt, _ = parse_prompt_args(
94-
available_tools, "-t example -g equation=x=5")
95-
96-
assert t == "example"
97-
assert cli_globals == {"equation": "x=5"}
98-
99-
def test_globals_in_taskflow_file(self):
100-
"""Test that globals can be read from taskflow file."""
101-
available_tools = AvailableTools()
102-
103-
taskflow = available_tools.get_taskflow("tests.data.test_globals_taskflow")
104-
assert 'globals' in taskflow
105-
assert taskflow['globals']['test_var'] == 'default_value'
106-
10743
if __name__ == '__main__':
10844
pytest.main([__file__, '-v'])

0 commit comments

Comments
 (0)