diff --git a/dreadnode_cli/agent/cli.py b/dreadnode_cli/agent/cli.py index 637a2b6..1c9c97d 100644 --- a/dreadnode_cli/agent/cli.py +++ b/dreadnode_cli/agent/cli.py @@ -19,6 +19,7 @@ format_agent, format_agent_versions, format_run, + format_run_groups, format_runs, format_strike_models, format_strikes, @@ -68,7 +69,9 @@ def ensure_profile(agent_config: AgentConfig, *, user_config: UserConfig | None ): print() raise Exception( - "Agent link does not match the current server profile. Use [bold]dreadnode agent switch[/] or [bold]dreadnode profile switch[/]." + f"Current agent link ([yellow]{agent_config.active_link.profile}[/]) does not match " + f"the current server profile ([magenta]{user_config.active_profile_name}[/]). " + "Use [bold]dreadnode agent switch[/] or [bold]dreadnode profile switch[/]." ) switch_profile(agent_config.active_link.profile) @@ -248,7 +251,14 @@ def push( if agent_config.links and not agent_config.has_link_to_profile(user_config.active_profile_name): print(f":link: Linking as a fresh agent to the current profile [magenta]{user_config.active_profile_name}[/]") + print() new = True + elif agent_config.active and agent_config.active_link.profile != user_config.active_profile_name: + raise Exception( + f"Current agent link ([yellow]{agent_config.active_link.profile}[/]) does not match " + f"the current server profile ([magenta]{user_config.active_profile_name}[/]). " + "Use [bold]dreadnode agent switch[/] or [bold]dreadnode profile switch[/]." + ) server_config = user_config.get_server_config() @@ -372,6 +382,7 @@ def deploy( ] = None, strike: t.Annotated[str | None, typer.Option("--strike", "-s", help="The strike to use for this run")] = None, watch: t.Annotated[bool, typer.Option("--watch", "-w", help="Watch the run status")] = True, + group: t.Annotated[str | None, typer.Option("--group", "-g", help="Group to associate this run with")] = None, ) -> None: agent_config = AgentConfig.read(directory) ensure_profile(agent_config) @@ -421,7 +432,7 @@ def deploy( ) run = client.start_strike_run( - agent.latest_version.id, strike=strike, model=model, user_model=user_model, context=context + agent.latest_version.id, strike=strike, model=model, user_model=user_model, group=group, context=context ) agent_config.add_run(run.id).write(directory) formatted = format_run(run, server_url=server_config.url) @@ -615,6 +626,14 @@ def switch( print(f":exclamation: '{agent_or_profile}' not found, use [bold]dreadnode agent links[/]") +@cli.command(help="List strike run groups") +@pretty_cli +def run_groups() -> None: + client = api.create_client() + groups = client.list_strike_run_groups() + print(format_run_groups(groups)) + + @cli.command(help="Clone a github repository", no_args_is_help=True) @pretty_cli def clone( diff --git a/dreadnode_cli/agent/format.py b/dreadnode_cli/agent/format.py index 54ac073..a1a0f2e 100644 --- a/dreadnode_cli/agent/format.py +++ b/dreadnode_cli/agent/format.py @@ -267,9 +267,17 @@ def format_run( table.add_column("Property", style="dim") table.add_column("Value") + table.add_row("key", run.key) table.add_row("status", Text(run.status, style=get_status_style(run.status))) table.add_row("strike", f"[magenta]{run.strike_name}[/] ([dim]{run.strike_key}[/])") table.add_row("type", run.strike_type) + table.add_row("group", Text(run.group_key or "-", style="blue" if run.group_key else "")) + + if server_url != "": + table.add_row("", "") + table.add_row( + "url", Text(f"{server_url.rstrip('/')}/strikes/agents/{run.agent_key}/runs/{run.id}", style="cyan") + ) if run.agent_name: agent_name = f"[bold magenta]{run.agent_name}[/] [[dim]{run.agent_key}[/]]" @@ -280,10 +288,6 @@ def format_run( table.add_row("model", run.model.replace(USER_MODEL_PREFIX, "") if run.model else "") table.add_row("agent", f"{agent_name} ([dim]rev[/] [yellow]{run.agent_revision}[/])") table.add_row("image", Text(run.agent_version.container.image, style="cyan")) - if server_url != "": - table.add_row( - "run url", Text(f"{server_url.rstrip('/')}/strikes/agents/{run.agent_key}/runs/{run.id}", style="cyan") - ) table.add_row("notes", run.agent_version.notes or "-") table.add_row("", "") @@ -314,21 +318,41 @@ def format_run( def format_runs(runs: list[api.Client.StrikeRunSummaryResponse]) -> RenderableType: table = Table(box=box.ROUNDED) - table.add_column("id", style="dim") + table.add_column("key", style="dim") table.add_column("agent") table.add_column("status") table.add_column("model") + table.add_column("group") table.add_column("started") table.add_column("duration") for run in runs: table.add_row( - str(run.id), + run.key, f"[bold magenta]{run.agent_key}[/] [dim]:[/] [yellow]{run.agent_revision}[/]", Text(run.status, style="bold " + get_status_style(run.status)), Text(run.model.replace(USER_MODEL_PREFIX, "") if run.model else "-"), + Text(run.group_key or "-", style="blue" if run.group_key else "dim"), format_time(run.start), Text(format_duration(run.start, run.end), style="bold cyan"), ) return table + + +def format_run_groups(groups: list[api.Client.StrikeRunGroupResponse]) -> RenderableType: + table = Table(box=box.ROUNDED) + table.add_column("Name", style="bold cyan") + table.add_column("description") + table.add_column("runs", style="yellow") + table.add_column("created", style="dim") + + for group in groups: + table.add_row( + group.key, + group.description or "-", + str(group.run_count), + group.created_at.astimezone().strftime("%c"), + ) + + return table diff --git a/dreadnode_cli/agent/tests/test_config.py b/dreadnode_cli/agent/tests/test_config.py index 6cd9e20..0b5d475 100644 --- a/dreadnode_cli/agent/tests/test_config.py +++ b/dreadnode_cli/agent/tests/test_config.py @@ -121,7 +121,7 @@ def test_ensure_profile() -> None: agent_config.add_link("test-main", UUID("00000000-0000-0000-0000-000000000000"), "main") agent_config.active = "test-other" with patch("rich.prompt.Prompt.ask", return_value="n"): - with pytest.raises(Exception, match="Agent link does not match the current server profile"): + with pytest.raises(Exception, match="Current agent link"): ensure_profile(agent_config, user_config=user_config) # We should switch if the user agrees diff --git a/dreadnode_cli/api.py b/dreadnode_cli/api.py index 6be6d6b..484af1b 100644 --- a/dreadnode_cli/api.py +++ b/dreadnode_cli/api.py @@ -293,6 +293,16 @@ class Container(BaseModel): env: dict[str, str] name: str | None + class StrikeMetricPoint(BaseModel): + timestamp: datetime + value: float + metadata: dict[str, t.Any] + + class StrikeMetric(BaseModel): + type: str + description: str | None + points: "list[Client.StrikeMetricPoint]" + class StrikeAgentVersion(BaseModel): id: UUID created_at: datetime @@ -350,6 +360,7 @@ class StrikeRunZone(_StrikeRunZone): container_logs: dict[str, str] outputs: list["Client.StrikeRunOutput"] inferences: list[dict[str, t.Any]] + metrics: dict[str, "Client.StrikeMetric"] class StrikeRunContext(BaseModel): environment: dict[str, str] | None = None @@ -358,6 +369,7 @@ class StrikeRunContext(BaseModel): class _StrikeRun(BaseModel): id: UUID + key: str strike_id: UUID strike_key: str strike_name: str @@ -373,6 +385,9 @@ class _StrikeRun(BaseModel): status: "Client.StrikeRunStatus" start: datetime | None end: datetime | None + group_id: UUID | None + group_key: str | None + group_name: str | None def is_running(self) -> bool: return self.status in ["pending", "deploying", "running"] @@ -388,6 +403,15 @@ class UserModel(BaseModel): generator_id: str api_key: str + class StrikeRunGroupResponse(BaseModel): + id: UUID + key: str + name: str + description: str | None + created_at: datetime + updated_at: datetime + run_count: int + def get_strike(self, strike: str) -> StrikeResponse: response = self.request("GET", f"/api/strikes/{strike}") return self.StrikeResponse(**response.json()) @@ -448,6 +472,7 @@ def start_strike_run( user_model: UserModel | None = None, context: StrikeRunContext | None = None, strike: UUID | str | None = None, + group: UUID | str | None = None, ) -> StrikeRunResponse: response = self.request( "POST", @@ -457,6 +482,7 @@ def start_strike_run( "model": model, "user_model": user_model.model_dump(mode="json") if user_model else None, "strike": str(strike) if strike else None, + "group": str(group) if group else None, "context": context.model_dump(mode="json") if context else None, }, ) @@ -472,6 +498,10 @@ def list_strike_runs(self, *, strike_id: UUID | str | None = None) -> list[Strik ) return [self.StrikeRunSummaryResponse(**run) for run in response.json()] + def list_strike_run_groups(self) -> list[StrikeRunGroupResponse]: + response = self.request("GET", "/api/strikes/groups") + return [self.StrikeRunGroupResponse(**group) for group in response.json()] + def create_client(*, profile: str | None = None) -> Client: """Create an authenticated API client using stored configuration data."""