diff --git a/CLI.md b/CLI.md index 508037c..b9dfb60 100644 --- a/CLI.md +++ b/CLI.md @@ -85,6 +85,9 @@ $ dreadnode agent deploy [OPTIONS] * `-m, --model TEXT`: The inference model to use for this run * `-d, --dir DIRECTORY`: The agent directory [default: .] +* `-e, --env-var TEXT`: Environment vars to override for this run (key=value) +* `-p, --param TEXT`: Define custom parameters for this run (key = value in toml syntax or @filename.toml for multiple values) +* `-c, --command TEXT`: Override the container command for this run. * `-s, --strike TEXT`: The strike to use for this run * `-w, --watch`: Watch the run status [default: True] * `--help`: Show this message and exit. diff --git a/README.md b/README.md index e617bc1..b818aa0 100644 --- a/README.md +++ b/README.md @@ -165,6 +165,18 @@ dreadnode agent push # start a new run using the latest agent version. dreadnode agent deploy +# start a new run using the latest agent version with custom environment variables +dreadnode agent deploy --env-var TEST_ENV=test --env-var ANOTHER_ENV=another_value + +# start a new run using the latest agent version with custom parameters (using toml syntax) +dreadnode agent deploy --param "foo = 'bar'" --param "baz = 123.0" + +# start a new run using the latest agent version with custom parameters from a toml file +dreadnode agent deploy --param @parameters.toml + +# start a new run using the latest agent version and override the container command +dreadnode agent deploy --command "echo 'Hello, world!'" + # show the latest run of the currently active agent dreadnode agent latest diff --git a/dreadnode_cli/agent/cli.py b/dreadnode_cli/agent/cli.py index b42c3a4..637a2b6 100644 --- a/dreadnode_cli/agent/cli.py +++ b/dreadnode_cli/agent/cli.py @@ -4,6 +4,7 @@ import time import typing as t +import toml import typer from rich import box, print from rich.live import Live @@ -318,6 +319,31 @@ def push( print(":tada: Agent pushed. use [bold]dreadnode agent deploy[/] to start a new run.") +def prepare_run_context( + env_vars: list[str] | None, parameters: list[str] | None, command: str | None +) -> Client.StrikeRunContext | None: + if not env_vars and not parameters and not command: + return None + + context = Client.StrikeRunContext() + + if env_vars: + context.environment = {env_var.split("=")[0]: env_var.split("=")[1] for env_var in env_vars} + + if parameters: + context.parameters = {} + for param in parameters: + if param.startswith("@"): + context.parameters.update(toml.load(open(param[1:]))) + else: + context.parameters.update(toml.loads(param)) + + if command: + context.command = command + + return context + + @cli.command(help="Start a new run using the latest active agent version") @pretty_cli def deploy( @@ -328,6 +354,22 @@ def deploy( pathlib.Path, typer.Option("--dir", "-d", help="The agent directory", file_okay=False, resolve_path=True), ] = pathlib.Path("."), + env_vars: t.Annotated[ + list[str] | None, + typer.Option("--env-var", "-e", help="Environment vars to override for this run (key=value)"), + ] = None, + parameters: t.Annotated[ + list[str] | None, + typer.Option( + "--param", + "-p", + help="Define custom parameters for this run (key = value in toml syntax or @filename.toml for multiple values)", + ), + ] = None, + command: t.Annotated[ + str | None, + typer.Option("--command", "-c", help="Override the container command for this run."), + ] = None, strike: t.Annotated[str | None, typer.Option("--strike", "-s", help="The strike to use for this run")] = None, watch: t.Annotated[bool, typer.Option("--watch", "-w", help="Watch the run status")] = True, ) -> None: @@ -346,6 +388,8 @@ def deploy( if strike is None: raise Exception("No strike specified, use -s/--strike or set the strike in the agent config") + context = prepare_run_context(env_vars, parameters, command) + user_models = UserModels.read() user_model: Client.UserModel | None = None @@ -376,7 +420,9 @@ def deploy( f"Model '{model}' is not user-defined nor is it available in strike '{strike_response.name}'" ) - run = client.start_strike_run(agent.latest_version.id, strike=strike, model=model, user_model=user_model) + run = client.start_strike_run( + agent.latest_version.id, strike=strike, model=model, user_model=user_model, context=context + ) agent_config.add_run(run.id).write(directory) formatted = format_run(run, server_url=server_config.url) diff --git a/dreadnode_cli/agent/format.py b/dreadnode_cli/agent/format.py index 9ab9524..54ac073 100644 --- a/dreadnode_cli/agent/format.py +++ b/dreadnode_cli/agent/format.py @@ -291,6 +291,19 @@ def format_run( table.add_row("start", format_time(run.start)) table.add_row("end", format_time(run.end)) + if run.context and (run.context.environment or run.context.parameters or run.context.command): + table.add_row("", "") + if run.context.environment: + table.add_row( + "environment", " ".join(f"[magenta]{k}[/]=[yellow]{v}[/]" for k, v in run.context.environment.items()) + ) + if run.context.parameters: + table.add_row( + "parameters", " ".join(f"[magenta]{k}[/]=[yellow]{v}[/]" for k, v in run.context.parameters.items()) + ) + if run.context.command: + table.add_row("command", f"[bold][red]{run.context.command}[/red][/bold]") + components: list[RenderableType] = [ table, format_zones_verbose(run.zones, include_logs=include_logs) if verbose else format_zones_summary(run.zones), diff --git a/dreadnode_cli/api.py b/dreadnode_cli/api.py index abf9335..6be6d6b 100644 --- a/dreadnode_cli/api.py +++ b/dreadnode_cli/api.py @@ -351,6 +351,11 @@ class StrikeRunZone(_StrikeRunZone): outputs: list["Client.StrikeRunOutput"] inferences: list[dict[str, t.Any]] + class StrikeRunContext(BaseModel): + environment: dict[str, str] | None = None + parameters: dict[str, t.Any] | None = None + command: str | None = None + class _StrikeRun(BaseModel): id: UUID strike_id: UUID @@ -364,6 +369,7 @@ class _StrikeRun(BaseModel): agent_name: str | None = None agent_revision: int agent_version: "Client.StrikeAgentVersion" + context: "Client.StrikeRunContext | None" = None status: "Client.StrikeRunStatus" start: datetime | None end: datetime | None @@ -440,6 +446,7 @@ def start_strike_run( *, model: str | None = None, user_model: UserModel | None = None, + context: StrikeRunContext | None = None, strike: UUID | str | None = None, ) -> StrikeRunResponse: response = self.request( @@ -450,6 +457,7 @@ def start_strike_run( "model": model, "user_model": user_model.model_dump(mode="json") if user_model else None, "strike": str(strike) if strike else None, + "context": context.model_dump(mode="json") if context else None, }, ) return self.StrikeRunResponse(**response.json()) diff --git a/poetry.lock b/poetry.lock index f1c3292..e7c53fc 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1061,6 +1061,17 @@ files = [ {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, ] +[[package]] +name = "toml" +version = "0.10.2" +description = "Python Library for Tom's Obvious, Minimal Language" +optional = false +python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"}, + {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, +] + [[package]] name = "tomli" version = "2.1.0" @@ -1103,6 +1114,17 @@ files = [ [package.dependencies] urllib3 = ">=2" +[[package]] +name = "types-toml" +version = "0.10.8.20240310" +description = "Typing stubs for toml" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-toml-0.10.8.20240310.tar.gz", hash = "sha256:3d41501302972436a6b8b239c850b26689657e25281b48ff0ec06345b8830331"}, + {file = "types_toml-0.10.8.20240310-py3-none-any.whl", hash = "sha256:627b47775d25fa29977d9c70dc0cbab3f314f32c8d8d0c012f2ef5de7aaec05d"}, +] + [[package]] name = "typing-extensions" version = "4.12.2" @@ -1154,4 +1176,4 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "513aed5633093d6a74ed27b7117995eadee3e1c511e7835f0362f4de5246494e" +content-hash = "4dd62ed90a1469dd758b514a8e08d1397999ddb6bdd7eb72c15828a6ea0eca79" diff --git a/pyproject.toml b/pyproject.toml index 97c4138..3a0511c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,8 @@ httpx = "^0.27.2" ruamel-yaml = "^0.18.6" docker = "^7.1.0" pydantic-yaml = "^1.4.0" +toml = "^0.10.2" +types-toml = "^0.10.8.20240310" [tool.pytest.ini_options] asyncio_mode = "auto"