Skip to content
This repository was archived by the owner on Sep 30, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions dreadnode_cli/agent/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
format_agent,
format_agent_versions,
format_run,
format_run_groups,
format_runs,
format_strike_models,
format_strikes,
Expand Down Expand Up @@ -68,7 +69,9 @@ def ensure_profile(agent_config: AgentConfig, *, user_config: UserConfig | None
):
print()
raise Exception(
"Agent link does not match the current server profile. Use [bold]dreadnode agent switch[/] or [bold]dreadnode profile switch[/]."
f"Current agent link ([yellow]{agent_config.active_link.profile}[/]) does not match "
f"the current server profile ([magenta]{user_config.active_profile_name}[/]). "
"Use [bold]dreadnode agent switch[/] or [bold]dreadnode profile switch[/]."
)

switch_profile(agent_config.active_link.profile)
Expand Down Expand Up @@ -248,7 +251,14 @@ def push(

if agent_config.links and not agent_config.has_link_to_profile(user_config.active_profile_name):
print(f":link: Linking as a fresh agent to the current profile [magenta]{user_config.active_profile_name}[/]")
print()
new = True
elif agent_config.active and agent_config.active_link.profile != user_config.active_profile_name:
raise Exception(
f"Current agent link ([yellow]{agent_config.active_link.profile}[/]) does not match "
f"the current server profile ([magenta]{user_config.active_profile_name}[/]). "
"Use [bold]dreadnode agent switch[/] or [bold]dreadnode profile switch[/]."
)

server_config = user_config.get_server_config()

Expand Down Expand Up @@ -372,6 +382,7 @@ def deploy(
] = None,
strike: t.Annotated[str | None, typer.Option("--strike", "-s", help="The strike to use for this run")] = None,
watch: t.Annotated[bool, typer.Option("--watch", "-w", help="Watch the run status")] = True,
group: t.Annotated[str | None, typer.Option("--group", "-g", help="Group to associate this run with")] = None,
) -> None:
agent_config = AgentConfig.read(directory)
ensure_profile(agent_config)
Expand Down Expand Up @@ -421,7 +432,7 @@ def deploy(
)

run = client.start_strike_run(
agent.latest_version.id, strike=strike, model=model, user_model=user_model, context=context
agent.latest_version.id, strike=strike, model=model, user_model=user_model, group=group, context=context
)
agent_config.add_run(run.id).write(directory)
formatted = format_run(run, server_url=server_config.url)
Expand Down Expand Up @@ -615,6 +626,14 @@ def switch(
print(f":exclamation: '{agent_or_profile}' not found, use [bold]dreadnode agent links[/]")


@cli.command(help="List strike run groups")
@pretty_cli
def run_groups() -> None:
client = api.create_client()
groups = client.list_strike_run_groups()
print(format_run_groups(groups))


@cli.command(help="Clone a github repository", no_args_is_help=True)
@pretty_cli
def clone(
Expand Down
36 changes: 30 additions & 6 deletions dreadnode_cli/agent/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,9 +267,17 @@ def format_run(
table.add_column("Property", style="dim")
table.add_column("Value")

table.add_row("key", run.key)
table.add_row("status", Text(run.status, style=get_status_style(run.status)))
table.add_row("strike", f"[magenta]{run.strike_name}[/] ([dim]{run.strike_key}[/])")
table.add_row("type", run.strike_type)
table.add_row("group", Text(run.group_key or "-", style="blue" if run.group_key else ""))

if server_url != "":
table.add_row("", "")
table.add_row(
"url", Text(f"{server_url.rstrip('/')}/strikes/agents/{run.agent_key}/runs/{run.id}", style="cyan")
)

if run.agent_name:
agent_name = f"[bold magenta]{run.agent_name}[/] [[dim]{run.agent_key}[/]]"
Expand All @@ -280,10 +288,6 @@ def format_run(
table.add_row("model", run.model.replace(USER_MODEL_PREFIX, "") if run.model else "<default>")
table.add_row("agent", f"{agent_name} ([dim]rev[/] [yellow]{run.agent_revision}[/])")
table.add_row("image", Text(run.agent_version.container.image, style="cyan"))
if server_url != "":
table.add_row(
"run url", Text(f"{server_url.rstrip('/')}/strikes/agents/{run.agent_key}/runs/{run.id}", style="cyan")
)
table.add_row("notes", run.agent_version.notes or "-")

table.add_row("", "")
Expand Down Expand Up @@ -314,21 +318,41 @@ def format_run(

def format_runs(runs: list[api.Client.StrikeRunSummaryResponse]) -> RenderableType:
table = Table(box=box.ROUNDED)
table.add_column("id", style="dim")
table.add_column("key", style="dim")
table.add_column("agent")
table.add_column("status")
table.add_column("model")
table.add_column("group")
table.add_column("started")
table.add_column("duration")

for run in runs:
table.add_row(
str(run.id),
run.key,
f"[bold magenta]{run.agent_key}[/] [dim]:[/] [yellow]{run.agent_revision}[/]",
Text(run.status, style="bold " + get_status_style(run.status)),
Text(run.model.replace(USER_MODEL_PREFIX, "") if run.model else "-"),
Text(run.group_key or "-", style="blue" if run.group_key else "dim"),
format_time(run.start),
Text(format_duration(run.start, run.end), style="bold cyan"),
)

return table


def format_run_groups(groups: list[api.Client.StrikeRunGroupResponse]) -> RenderableType:
table = Table(box=box.ROUNDED)
table.add_column("Name", style="bold cyan")
table.add_column("description")
table.add_column("runs", style="yellow")
table.add_column("created", style="dim")

for group in groups:
table.add_row(
group.key,
group.description or "-",
str(group.run_count),
group.created_at.astimezone().strftime("%c"),
)

return table
2 changes: 1 addition & 1 deletion dreadnode_cli/agent/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def test_ensure_profile() -> None:
agent_config.add_link("test-main", UUID("00000000-0000-0000-0000-000000000000"), "main")
agent_config.active = "test-other"
with patch("rich.prompt.Prompt.ask", return_value="n"):
with pytest.raises(Exception, match="Agent link does not match the current server profile"):
with pytest.raises(Exception, match="Current agent link"):
ensure_profile(agent_config, user_config=user_config)

# We should switch if the user agrees
Expand Down
30 changes: 30 additions & 0 deletions dreadnode_cli/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,16 @@ class Container(BaseModel):
env: dict[str, str]
name: str | None

class StrikeMetricPoint(BaseModel):
timestamp: datetime
value: float
metadata: dict[str, t.Any]

class StrikeMetric(BaseModel):
type: str
description: str | None
points: "list[Client.StrikeMetricPoint]"

class StrikeAgentVersion(BaseModel):
id: UUID
created_at: datetime
Expand Down Expand Up @@ -350,6 +360,7 @@ class StrikeRunZone(_StrikeRunZone):
container_logs: dict[str, str]
outputs: list["Client.StrikeRunOutput"]
inferences: list[dict[str, t.Any]]
metrics: dict[str, "Client.StrikeMetric"]

class StrikeRunContext(BaseModel):
environment: dict[str, str] | None = None
Expand All @@ -358,6 +369,7 @@ class StrikeRunContext(BaseModel):

class _StrikeRun(BaseModel):
id: UUID
key: str
strike_id: UUID
strike_key: str
strike_name: str
Expand All @@ -373,6 +385,9 @@ class _StrikeRun(BaseModel):
status: "Client.StrikeRunStatus"
start: datetime | None
end: datetime | None
group_id: UUID | None
group_key: str | None
group_name: str | None

def is_running(self) -> bool:
return self.status in ["pending", "deploying", "running"]
Expand All @@ -388,6 +403,15 @@ class UserModel(BaseModel):
generator_id: str
api_key: str

class StrikeRunGroupResponse(BaseModel):
id: UUID
key: str
name: str
description: str | None
created_at: datetime
updated_at: datetime
run_count: int

def get_strike(self, strike: str) -> StrikeResponse:
response = self.request("GET", f"/api/strikes/{strike}")
return self.StrikeResponse(**response.json())
Expand Down Expand Up @@ -448,6 +472,7 @@ def start_strike_run(
user_model: UserModel | None = None,
context: StrikeRunContext | None = None,
strike: UUID | str | None = None,
group: UUID | str | None = None,
) -> StrikeRunResponse:
response = self.request(
"POST",
Expand All @@ -457,6 +482,7 @@ def start_strike_run(
"model": model,
"user_model": user_model.model_dump(mode="json") if user_model else None,
"strike": str(strike) if strike else None,
"group": str(group) if group else None,
"context": context.model_dump(mode="json") if context else None,
},
)
Expand All @@ -472,6 +498,10 @@ def list_strike_runs(self, *, strike_id: UUID | str | None = None) -> list[Strik
)
return [self.StrikeRunSummaryResponse(**run) for run in response.json()]

def list_strike_run_groups(self) -> list[StrikeRunGroupResponse]:
response = self.request("GET", "/api/strikes/groups")
return [self.StrikeRunGroupResponse(**group) for group in response.json()]


def create_client(*, profile: str | None = None) -> Client:
"""Create an authenticated API client using stored configuration data."""
Expand Down
Loading