diff --git a/codex-rs/apply-patch/src/lib.rs b/codex-rs/apply-patch/src/lib.rs index 28dc14eb02f..fe4fe584dc9 100644 --- a/codex-rs/apply-patch/src/lib.rs +++ b/codex-rs/apply-patch/src/lib.rs @@ -112,7 +112,7 @@ fn classify_shell_name(shell: &str) -> Option { fn classify_shell(shell: &str, flag: &str) -> Option { classify_shell_name(shell).and_then(|name| match name.as_str() { - "bash" | "zsh" | "sh" if flag == "-lc" => Some(ApplyPatchShell::Unix), + "bash" | "zsh" | "sh" if matches!(flag, "-lc" | "-c") => Some(ApplyPatchShell::Unix), "pwsh" | "powershell" if flag.eq_ignore_ascii_case("-command") => { Some(ApplyPatchShell::PowerShell) } @@ -1097,6 +1097,13 @@ mod tests { assert_match(&heredoc_script(""), None); } + #[test] + fn test_heredoc_non_login_shell() { + let script = heredoc_script(""); + let args = strs_to_strings(&["bash", "-c", &script]); + assert_match_args(args, None); + } + #[test] fn test_heredoc_applypatch() { let args = strs_to_strings(&[ diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index c33904e2fde..f988216a6d8 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -110,6 +110,7 @@ use crate::rollout::RolloutRecorder; use crate::rollout::RolloutRecorderParams; use crate::rollout::map_session_init_error; use crate::shell; +use crate::shell_snapshot::ShellSnapshot; use crate::state::ActiveTurn; use crate::state::SessionServices; use crate::state::SessionState; @@ -515,7 +516,6 @@ impl Session { // - load history metadata let rollout_fut = RolloutRecorder::new(&config, rollout_params); - let default_shell = shell::default_user_shell(); let history_meta_fut = crate::message_history::history_metadata(&config); let auth_statuses_fut = compute_auth_statuses( config.mcp_servers.iter(), @@ -577,7 +577,14 @@ impl Session { config.active_profile.clone(), ); + let mut default_shell = shell::default_user_shell(); // Create the mutable state for the Session. + if config.features.enabled(Feature::ShellSnapshot) { + default_shell.shell_snapshot = + ShellSnapshot::try_new(&config.codex_home, &default_shell) + .await + .map(Arc::new); + } let state = SessionState::new(session_configuration.clone()); let services = SessionServices { @@ -586,7 +593,7 @@ impl Session { unified_exec_manager: UnifiedExecSessionManager::default(), notifier: UserNotifier::new(config.notify.clone()), rollout: Mutex::new(Some(rollout_recorder)), - user_shell: default_shell, + user_shell: Arc::new(default_shell), show_raw_agent_reasoning: config.show_raw_agent_reasoning, auth_manager: Arc::clone(&auth_manager), otel_event_manager, @@ -804,14 +811,16 @@ impl Session { ) -> Option { let prev = previous?; - let prev_context = EnvironmentContext::from(prev.as_ref()); - let next_context = EnvironmentContext::from(next); + let shell = self.user_shell(); + let prev_context = EnvironmentContext::from_turn_context(prev.as_ref(), shell.as_ref()); + let next_context = EnvironmentContext::from_turn_context(next, shell.as_ref()); if prev_context.equals_except_shell(&next_context) { return None; } Some(ResponseItem::from(EnvironmentContext::diff( prev.as_ref(), next, + shell.as_ref(), ))) } @@ -1161,6 +1170,7 @@ impl Session { pub(crate) fn build_initial_context(&self, turn_context: &TurnContext) -> Vec { let mut items = Vec::::with_capacity(3); + let shell = self.user_shell(); if let Some(developer_instructions) = turn_context.developer_instructions.as_deref() { items.push(DeveloperInstructions::new(developer_instructions.to_string()).into()); } @@ -1177,7 +1187,7 @@ impl Session { Some(turn_context.cwd.clone()), Some(turn_context.approval_policy), Some(turn_context.sandbox_policy.clone()), - self.user_shell().clone(), + shell.as_ref().clone(), ))); items } @@ -1452,8 +1462,8 @@ impl Session { &self.services.notifier } - pub(crate) fn user_shell(&self) -> &shell::Shell { - &self.services.user_shell + pub(crate) fn user_shell(&self) -> Arc { + Arc::clone(&self.services.user_shell) } fn show_raw_agent_reasoning(&self) -> bool { @@ -2895,7 +2905,7 @@ mod tests { unified_exec_manager: UnifiedExecSessionManager::default(), notifier: UserNotifier::new(None), rollout: Mutex::new(None), - user_shell: default_user_shell(), + user_shell: Arc::new(default_user_shell()), show_raw_agent_reasoning: config.show_raw_agent_reasoning, auth_manager: auth_manager.clone(), otel_event_manager: otel_event_manager.clone(), @@ -2977,7 +2987,7 @@ mod tests { unified_exec_manager: UnifiedExecSessionManager::default(), notifier: UserNotifier::new(None), rollout: Mutex::new(None), - user_shell: default_user_shell(), + user_shell: Arc::new(default_user_shell()), show_raw_agent_reasoning: config.show_raw_agent_reasoning, auth_manager: Arc::clone(&auth_manager), otel_event_manager: otel_event_manager.clone(), diff --git a/codex-rs/core/src/environment_context.rs b/codex-rs/core/src/environment_context.rs index 56e7f6cadb0..54756bda2d2 100644 --- a/codex-rs/core/src/environment_context.rs +++ b/codex-rs/core/src/environment_context.rs @@ -6,7 +6,6 @@ use crate::codex::TurnContext; use crate::protocol::AskForApproval; use crate::protocol::SandboxPolicy; use crate::shell::Shell; -use crate::shell::default_user_shell; use codex_protocol::config_types::SandboxMode; use codex_protocol::models::ContentItem; use codex_protocol::models::ResponseItem; @@ -95,7 +94,7 @@ impl EnvironmentContext { && self.writable_roots == *writable_roots } - pub fn diff(before: &TurnContext, after: &TurnContext) -> Self { + pub fn diff(before: &TurnContext, after: &TurnContext, shell: &Shell) -> Self { let cwd = if before.cwd != after.cwd { Some(after.cwd.clone()) } else { @@ -111,18 +110,15 @@ impl EnvironmentContext { } else { None }; - EnvironmentContext::new(cwd, approval_policy, sandbox_policy, default_user_shell()) + EnvironmentContext::new(cwd, approval_policy, sandbox_policy, shell.clone()) } -} -impl From<&TurnContext> for EnvironmentContext { - fn from(turn_context: &TurnContext) -> Self { + pub fn from_turn_context(turn_context: &TurnContext, shell: &Shell) -> Self { Self::new( Some(turn_context.cwd.clone()), Some(turn_context.approval_policy), Some(turn_context.sandbox_policy.clone()), - // Shell is not configurable from turn to turn - default_user_shell(), + shell.clone(), ) } } @@ -201,6 +197,7 @@ mod tests { Shell { shell_type: ShellType::Bash, shell_path: PathBuf::from("/bin/bash"), + shell_snapshot: None, } } @@ -338,6 +335,7 @@ mod tests { Shell { shell_type: ShellType::Bash, shell_path: "/bin/bash".into(), + shell_snapshot: None, }, ); let context2 = EnvironmentContext::new( @@ -347,6 +345,7 @@ mod tests { Shell { shell_type: ShellType::Zsh, shell_path: "/bin/zsh".into(), + shell_snapshot: None, }, ); diff --git a/codex-rs/core/src/features.rs b/codex-rs/core/src/features.rs index 69442815e70..b89d29863ec 100644 --- a/codex-rs/core/src/features.rs +++ b/codex-rs/core/src/features.rs @@ -60,6 +60,8 @@ pub enum Feature { ParallelToolCalls, /// Experimental skills injection (CLI flag-driven). Skills, + /// Experimental shell snapshotting. + ShellSnapshot, } impl Feature { @@ -353,4 +355,10 @@ pub const FEATURES: &[FeatureSpec] = &[ stage: Stage::Experimental, default_enabled: false, }, + FeatureSpec { + id: Feature::ShellSnapshot, + key: "shell_snapshot", + stage: Stage::Experimental, + default_enabled: false, + }, ]; diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs index 721c6bb43ca..f84c7dbf33b 100644 --- a/codex-rs/core/src/lib.rs +++ b/codex-rs/core/src/lib.rs @@ -73,6 +73,7 @@ mod rollout; pub(crate) mod safety; pub mod seatbelt; pub mod shell; +pub mod shell_snapshot; pub mod skills; pub mod spawn; pub mod terminal; diff --git a/codex-rs/core/src/shell.rs b/codex-rs/core/src/shell.rs index 2338f41cd4f..608d8063239 100644 --- a/codex-rs/core/src/shell.rs +++ b/codex-rs/core/src/shell.rs @@ -1,6 +1,9 @@ use serde::Deserialize; use serde::Serialize; use std::path::PathBuf; +use std::sync::Arc; + +use crate::shell_snapshot::ShellSnapshot; #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] pub enum ShellType { @@ -15,6 +18,8 @@ pub enum ShellType { pub struct Shell { pub(crate) shell_type: ShellType, pub(crate) shell_path: PathBuf, + #[serde(skip_serializing, skip_deserializing, default)] + pub(crate) shell_snapshot: Option>, } impl Shell { @@ -58,6 +63,33 @@ impl Shell { } } } + + pub(crate) fn wrap_command_with_snapshot(&self, command: &[String]) -> Vec { + let Some(snapshot) = &self.shell_snapshot else { + return command.to_vec(); + }; + + if command.is_empty() { + return command.to_vec(); + } + + match self.shell_type { + ShellType::Zsh | ShellType::Bash | ShellType::Sh => { + let mut args = self.derive_exec_args(". \"$0\" && exec \"$@\"", false); + args.push(snapshot.path.to_string_lossy().to_string()); + args.extend_from_slice(command); + args + } + ShellType::PowerShell => { + let mut args = + self.derive_exec_args("param($snapshot) . $snapshot; & @args", false); + args.push(snapshot.path.to_string_lossy().to_string()); + args.extend_from_slice(command); + args + } + ShellType::Cmd => command.to_vec(), + } + } } #[cfg(unix)] @@ -134,6 +166,7 @@ fn get_zsh_shell(path: Option<&PathBuf>) -> Option { shell_path.map(|shell_path| Shell { shell_type: ShellType::Zsh, shell_path, + shell_snapshot: None, }) } @@ -143,6 +176,7 @@ fn get_bash_shell(path: Option<&PathBuf>) -> Option { shell_path.map(|shell_path| Shell { shell_type: ShellType::Bash, shell_path, + shell_snapshot: None, }) } @@ -152,6 +186,7 @@ fn get_sh_shell(path: Option<&PathBuf>) -> Option { shell_path.map(|shell_path| Shell { shell_type: ShellType::Sh, shell_path, + shell_snapshot: None, }) } @@ -167,6 +202,7 @@ fn get_powershell_shell(path: Option<&PathBuf>) -> Option { shell_path.map(|shell_path| Shell { shell_type: ShellType::PowerShell, shell_path, + shell_snapshot: None, }) } @@ -176,6 +212,7 @@ fn get_cmd_shell(path: Option<&PathBuf>) -> Option { shell_path.map(|shell_path| Shell { shell_type: ShellType::Cmd, shell_path, + shell_snapshot: None, }) } @@ -184,11 +221,13 @@ fn ultimate_fallback_shell() -> Shell { Shell { shell_type: ShellType::Cmd, shell_path: PathBuf::from("cmd.exe"), + shell_snapshot: None, } } else { Shell { shell_type: ShellType::Sh, shell_path: PathBuf::from("/bin/sh"), + shell_snapshot: None, } } } @@ -413,6 +452,7 @@ mod tests { let test_bash_shell = Shell { shell_type: ShellType::Bash, shell_path: PathBuf::from("/bin/bash"), + shell_snapshot: None, }; assert_eq!( test_bash_shell.derive_exec_args("echo hello", false), @@ -426,6 +466,7 @@ mod tests { let test_zsh_shell = Shell { shell_type: ShellType::Zsh, shell_path: PathBuf::from("/bin/zsh"), + shell_snapshot: None, }; assert_eq!( test_zsh_shell.derive_exec_args("echo hello", false), @@ -439,6 +480,7 @@ mod tests { let test_powershell_shell = Shell { shell_type: ShellType::PowerShell, shell_path: PathBuf::from("pwsh.exe"), + shell_snapshot: None, }; assert_eq!( test_powershell_shell.derive_exec_args("echo hello", false), @@ -465,6 +507,7 @@ mod tests { Shell { shell_type: ShellType::Zsh, shell_path: PathBuf::from(shell_path), + shell_snapshot: None, } ); } diff --git a/codex-rs/core/src/shell_snapshot.rs b/codex-rs/core/src/shell_snapshot.rs new file mode 100644 index 00000000000..e7c8abb0667 --- /dev/null +++ b/codex-rs/core/src/shell_snapshot.rs @@ -0,0 +1,380 @@ +use std::path::Path; +use std::path::PathBuf; +use std::time::Duration; + +use anyhow::Context; +use anyhow::Result; +use anyhow::anyhow; +use anyhow::bail; +use tokio::fs; +use tokio::process::Command; +use tokio::time::timeout; +use uuid::Uuid; + +use crate::shell::Shell; +use crate::shell::ShellType; +use crate::shell::get_shell; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ShellSnapshot { + pub path: PathBuf, +} + +impl ShellSnapshot { + pub async fn try_new(codex_home: &Path, shell: &Shell) -> Option { + let extension = match shell.shell_type { + ShellType::PowerShell => "ps1", + _ => "sh", + }; + let path = + codex_home + .join("shell_snapshots") + .join(format!("{}.{}", Uuid::new_v4(), extension)); + match write_shell_snapshot(shell.shell_type.clone(), &path).await { + Ok(path) => { + tracing::info!("Shell snapshot successfully created: {}", path.display()); + Some(Self { path }) + } + Err(err) => { + tracing::warn!( + "Failed to create shell snapshot for {}: {err:?}", + shell.name() + ); + None + } + } + } +} + +impl Drop for ShellSnapshot { + fn drop(&mut self) { + if let Err(err) = std::fs::remove_file(&self.path) { + tracing::warn!( + "Failed to delete shell snapshot at {:?}: {err:?}", + self.path + ); + } + } +} + +pub async fn write_shell_snapshot(shell_type: ShellType, output_path: &Path) -> Result { + let shell = get_shell(shell_type.clone(), None) + .with_context(|| format!("No available shell for {shell_type:?}"))?; + + let raw_snapshot = capture_snapshot(&shell).await?; + let snapshot = strip_snapshot_preamble(&raw_snapshot)?; + + if let Some(parent) = output_path.parent() { + let parent_display = parent.display(); + fs::create_dir_all(parent) + .await + .with_context(|| format!("Failed to create snapshot parent {parent_display}"))?; + } + + let snapshot_path = output_path.display(); + fs::write(output_path, snapshot) + .await + .with_context(|| format!("Failed to write snapshot to {snapshot_path}"))?; + + Ok(output_path.to_path_buf()) +} + +async fn capture_snapshot(shell: &Shell) -> Result { + let shell_type = shell.shell_type.clone(); + match shell_type { + ShellType::Zsh => run_shell_script(shell, zsh_snapshot_script()).await, + ShellType::Bash => run_shell_script(shell, bash_snapshot_script()).await, + ShellType::Sh => run_shell_script(shell, sh_snapshot_script()).await, + ShellType::PowerShell => run_shell_script(shell, powershell_snapshot_script()).await, + ShellType::Cmd => bail!("Shell snapshotting is not yet supported for {shell_type:?}"), + } +} + +fn strip_snapshot_preamble(snapshot: &str) -> Result { + let marker = "# Snapshot file"; + let Some(start) = snapshot.find(marker) else { + bail!("Snapshot output missing marker {marker}"); + }; + + Ok(snapshot[start..].to_string()) +} + +async fn run_shell_script(shell: &Shell, script: &str) -> Result { + let args = shell.derive_exec_args(script, true); + let shell_name = shell.name(); + let output = timeout( + Duration::from_secs(10), + Command::new(&args[0]).args(&args[1..]).output(), + ) + .await + .map_err(|_| anyhow!("Snapshot command timed out for {shell_name}"))? + .with_context(|| format!("Failed to execute {shell_name}"))?; + + if !output.status.success() { + let status = output.status; + let stderr = String::from_utf8_lossy(&output.stderr); + bail!("Snapshot command exited with status {status}: {stderr}"); + } + + Ok(String::from_utf8_lossy(&output.stdout).into_owned()) +} + +fn zsh_snapshot_script() -> &'static str { + r##"print '# Snapshot file' +print '# Unset all aliases to avoid conflicts with functions' +print 'unalias -a 2>/dev/null || true' +print '# Functions' +functions +print '' +setopt_count=$(setopt | wc -l | tr -d ' ') +print "# setopts $setopt_count" +setopt | sed 's/^/setopt /' +print '' +alias_count=$(alias -L | wc -l | tr -d ' ') +print "# aliases $alias_count" +alias -L +print '' +export_count=$(export -p | wc -l | tr -d ' ') +print "# exports $export_count" +export -p +"## +} + +fn bash_snapshot_script() -> &'static str { + r##"echo '# Snapshot file' +echo '# Unset all aliases to avoid conflicts with functions' +unalias -a 2>/dev/null || true +echo '# Functions' +declare -f +echo '' +bash_opts=$(set -o | awk '$2=="on"{print $1}') +bash_opt_count=$(printf '%s\n' "$bash_opts" | sed '/^$/d' | wc -l | tr -d ' ') +echo "# setopts $bash_opt_count" +if [ -n "$bash_opts" ]; then + printf 'set -o %s\n' $bash_opts +fi +echo '' +alias_count=$(alias -p | wc -l | tr -d ' ') +echo "# aliases $alias_count" +alias -p +echo '' +export_count=$(export -p | wc -l | tr -d ' ') +echo "# exports $export_count" +export -p +"## +} + +fn sh_snapshot_script() -> &'static str { + r##"echo '# Snapshot file' +echo '# Unset all aliases to avoid conflicts with functions' +unalias -a 2>/dev/null || true +echo '# Functions' +if command -v typeset >/dev/null 2>&1; then + typeset -f +elif command -v declare >/dev/null 2>&1; then + declare -f +fi +echo '' +if set -o >/dev/null 2>&1; then + sh_opts=$(set -o | awk '$2=="on"{print $1}') + sh_opt_count=$(printf '%s\n' "$sh_opts" | sed '/^$/d' | wc -l | tr -d ' ') + echo "# setopts $sh_opt_count" + if [ -n "$sh_opts" ]; then + printf 'set -o %s\n' $sh_opts + fi +else + echo '# setopts 0' +fi +echo '' +if alias >/dev/null 2>&1; then + alias_count=$(alias | wc -l | tr -d ' ') + echo "# aliases $alias_count" + alias + echo '' +else + echo '# aliases 0' +fi +if export -p >/dev/null 2>&1; then + export_count=$(export -p | wc -l | tr -d ' ') + echo "# exports $export_count" + export -p +else + export_count=$(env | wc -l | tr -d ' ') + echo "# exports $export_count" + env | sort | while IFS='=' read -r key value; do + escaped=$(printf "%s" "$value" | sed "s/'/'\"'\"'/g") + printf "export %s='%s'\n" "$key" "$escaped" + done +fi +"## +} + +fn powershell_snapshot_script() -> &'static str { + r##"$ErrorActionPreference = 'Stop' +Write-Output '# Snapshot file' +Write-Output '# Unset all aliases to avoid conflicts with functions' +Write-Output 'Remove-Item Alias:* -ErrorAction SilentlyContinue' +Write-Output '# Functions' +Get-ChildItem Function: | ForEach-Object { + "function {0} {{`n{1}`n}}" -f $_.Name, $_.Definition +} +Write-Output '' +$aliases = Get-Alias +Write-Output ("# aliases " + $aliases.Count) +$aliases | ForEach-Object { + "Set-Alias -Name {0} -Value {1}" -f $_.Name, $_.Definition +} +Write-Output '' +$envVars = Get-ChildItem Env: +Write-Output ("# exports " + $envVars.Count) +$envVars | ForEach-Object { + $escaped = $_.Value -replace "'", "''" + "`$env:{0}='{1}'" -f $_.Name, $escaped +} +"## +} + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + use std::sync::Arc; + use tempfile::tempdir; + + #[cfg(not(target_os = "windows"))] + fn assert_posix_snapshot_sections(snapshot: &str) { + assert!(snapshot.contains("# Snapshot file")); + assert!(snapshot.contains("aliases ")); + assert!(snapshot.contains("exports ")); + assert!( + snapshot.contains("PATH"), + "snapshot should capture a PATH export" + ); + assert!(snapshot.contains("setopts ")); + } + + async fn get_snapshot(shell_type: ShellType) -> Result { + let dir = tempdir()?; + let path = dir.path().join("snapshot.sh"); + write_shell_snapshot(shell_type, &path).await?; + let content = fs::read_to_string(&path).await?; + Ok(content) + } + + #[test] + fn strip_snapshot_preamble_removes_leading_output() { + let snapshot = "noise\n# Snapshot file\nexport PATH=/bin\n"; + let cleaned = strip_snapshot_preamble(snapshot).expect("snapshot marker exists"); + assert_eq!(cleaned, "# Snapshot file\nexport PATH=/bin\n"); + } + + #[test] + fn strip_snapshot_preamble_requires_marker() { + let result = strip_snapshot_preamble("missing header"); + assert!(result.is_err()); + } + + #[cfg(unix)] + #[test] + fn wrap_command_with_snapshot_wraps_bash_shell() { + let snapshot_path = PathBuf::from("/tmp/snapshot.sh"); + let shell = Shell { + shell_type: ShellType::Bash, + shell_path: PathBuf::from("/bin/bash"), + shell_snapshot: Some(Arc::new(ShellSnapshot { + path: snapshot_path.clone(), + })), + }; + let original_command = vec![ + "bash".to_string(), + "-lc".to_string(), + "echo hello".to_string(), + ]; + + let wrapped = shell.wrap_command_with_snapshot(&original_command); + + let mut expected = shell.derive_exec_args(". \"$0\" && exec \"$@\"", false); + expected.push(snapshot_path.to_string_lossy().to_string()); + expected.extend_from_slice(&original_command); + + assert_eq!(wrapped, expected); + } + + #[test] + fn wrap_command_with_snapshot_preserves_cmd_shell() { + let snapshot_path = PathBuf::from("C:\\snapshot.cmd"); + let shell = Shell { + shell_type: ShellType::Cmd, + shell_path: PathBuf::from("cmd"), + shell_snapshot: Some(Arc::new(ShellSnapshot { + path: snapshot_path, + })), + }; + let original_command = vec![ + "cmd".to_string(), + "/c".to_string(), + "echo hello".to_string(), + ]; + + let wrapped = shell.wrap_command_with_snapshot(&original_command); + + assert_eq!(wrapped, original_command); + } + + #[cfg(unix)] + #[tokio::test] + async fn try_new_creates_and_deletes_snapshot_file() -> Result<()> { + let dir = tempdir()?; + let shell = Shell { + shell_type: ShellType::Bash, + shell_path: PathBuf::from("/bin/bash"), + shell_snapshot: None, + }; + + let snapshot = ShellSnapshot::try_new(dir.path(), &shell) + .await + .expect("snapshot should be created"); + let path = snapshot.path.clone(); + assert!(path.exists()); + + drop(snapshot); + + assert!(!path.exists()); + + Ok(()) + } + + #[cfg(target_os = "macos")] + #[tokio::test] + async fn macos_zsh_snapshot_includes_sections() -> Result<()> { + let snapshot = get_snapshot(ShellType::Zsh).await?; + assert_posix_snapshot_sections(&snapshot); + Ok(()) + } + + #[cfg(target_os = "linux")] + #[tokio::test] + async fn linux_bash_snapshot_includes_sections() -> Result<()> { + let snapshot = get_snapshot(ShellType::Bash).await?; + assert_posix_snapshot_sections(&snapshot); + Ok(()) + } + + #[cfg(target_os = "linux")] + #[tokio::test] + async fn linux_sh_snapshot_includes_sections() -> Result<()> { + let snapshot = get_snapshot(ShellType::Sh).await?; + assert_posix_snapshot_sections(&snapshot); + Ok(()) + } + + #[cfg(target_os = "windows")] + #[tokio::test] + async fn windows_powershell_snapshot_includes_sections() -> Result<()> { + let snapshot = get_snapshot(ShellType::PowerShell).await?; + assert!(snapshot.contains("# Snapshot file")); + assert!(snapshot.contains("aliases ")); + assert!(snapshot.contains("exports ")); + Ok(()) + } +} diff --git a/codex-rs/core/src/state/service.rs b/codex-rs/core/src/state/service.rs index a35720a9bf7..7387bcedae0 100644 --- a/codex-rs/core/src/state/service.rs +++ b/codex-rs/core/src/state/service.rs @@ -18,7 +18,7 @@ pub(crate) struct SessionServices { pub(crate) unified_exec_manager: UnifiedExecSessionManager, pub(crate) notifier: UserNotifier, pub(crate) rollout: Mutex>, - pub(crate) user_shell: crate::shell::Shell, + pub(crate) user_shell: Arc, pub(crate) show_raw_agent_reasoning: bool, pub(crate) auth_manager: Arc, pub(crate) models_manager: Arc, diff --git a/codex-rs/core/src/tools/handlers/apply_patch.rs b/codex-rs/core/src/tools/handlers/apply_patch.rs index 4a28619c760..5b8a04b388f 100644 --- a/codex-rs/core/src/tools/handlers/apply_patch.rs +++ b/codex-rs/core/src/tools/handlers/apply_patch.rs @@ -46,7 +46,7 @@ impl ToolHandler for ApplyPatchHandler { ) } - fn is_mutating(&self, _invocation: &ToolInvocation) -> bool { + async fn is_mutating(&self, _invocation: &ToolInvocation) -> bool { true } diff --git a/codex-rs/core/src/tools/handlers/shell.rs b/codex-rs/core/src/tools/handlers/shell.rs index c3ef590e132..c25413568eb 100644 --- a/codex-rs/core/src/tools/handlers/shell.rs +++ b/codex-rs/core/src/tools/handlers/shell.rs @@ -76,7 +76,7 @@ impl ToolHandler for ShellHandler { ) } - fn is_mutating(&self, invocation: &ToolInvocation) -> bool { + async fn is_mutating(&self, invocation: &ToolInvocation) -> bool { match &invocation.payload { ToolPayload::Function { arguments } => { serde_json::from_str::(arguments) @@ -293,18 +293,21 @@ mod tests { let bash_shell = Shell { shell_type: ShellType::Bash, shell_path: PathBuf::from("/bin/bash"), + shell_snapshot: None, }; assert_safe(&bash_shell, "ls -la"); let zsh_shell = Shell { shell_type: ShellType::Zsh, shell_path: PathBuf::from("/bin/zsh"), + shell_snapshot: None, }; assert_safe(&zsh_shell, "ls -la"); let powershell = Shell { shell_type: ShellType::PowerShell, shell_path: PathBuf::from("pwsh.exe"), + shell_snapshot: None, }; assert_safe(&powershell, "ls -Name"); } diff --git a/codex-rs/core/src/tools/handlers/unified_exec.rs b/codex-rs/core/src/tools/handlers/unified_exec.rs index f2500a413ba..8d34e860892 100644 --- a/codex-rs/core/src/tools/handlers/unified_exec.rs +++ b/codex-rs/core/src/tools/handlers/unified_exec.rs @@ -1,12 +1,10 @@ -use std::path::PathBuf; - use crate::function_tool::FunctionCallError; use crate::is_safe_command::is_known_safe_command; use crate::protocol::EventMsg; use crate::protocol::ExecCommandOutputDeltaEvent; use crate::protocol::ExecCommandSource; use crate::protocol::ExecOutputStream; -use crate::shell::default_user_shell; +use crate::shell::Shell; use crate::shell::get_shell_by_model_provided_path; use crate::tools::context::ToolInvocation; use crate::tools::context::ToolOutput; @@ -24,6 +22,8 @@ use crate::unified_exec::UnifiedExecSessionManager; use crate::unified_exec::WriteStdinRequest; use async_trait::async_trait; use serde::Deserialize; +use std::path::PathBuf; +use std::sync::Arc; pub struct UnifiedExecHandler; @@ -34,8 +34,8 @@ struct ExecCommandArgs { workdir: Option, #[serde(default)] shell: Option, - #[serde(default = "default_login")] - login: bool, + #[serde(default)] + login: Option, #[serde(default = "default_exec_yield_time_ms")] yield_time_ms: u64, #[serde(default)] @@ -66,10 +66,6 @@ fn default_write_stdin_yield_time_ms() -> u64 { 250 } -fn default_login() -> bool { - true -} - #[async_trait] impl ToolHandler for UnifiedExecHandler { fn kind(&self) -> ToolKind { @@ -83,7 +79,7 @@ impl ToolHandler for UnifiedExecHandler { ) } - fn is_mutating(&self, invocation: &ToolInvocation) -> bool { + async fn is_mutating(&self, invocation: &ToolInvocation) -> bool { let (ToolPayload::Function { arguments } | ToolPayload::UnifiedExec { arguments }) = &invocation.payload else { @@ -93,7 +89,7 @@ impl ToolHandler for UnifiedExecHandler { let Ok(params) = serde_json::from_str::(arguments) else { return true; }; - let command = get_command(¶ms); + let command = get_command(¶ms, invocation.session.user_shell()); !is_known_safe_command(&command) } @@ -130,9 +126,10 @@ impl ToolHandler for UnifiedExecHandler { })?; let process_id = manager.allocate_process_id().await; - let command = get_command(&args); + let command_for_intercept = get_command(&args, session.user_shell()); let ExecCommandArgs { workdir, + login, yield_time_ms, max_output_tokens, with_escalated_permissions, @@ -159,7 +156,7 @@ impl ToolHandler for UnifiedExecHandler { let cwd = workdir.clone().unwrap_or_else(|| context.turn.cwd.clone()); if let Some(output) = intercept_apply_patch( - &command, + &command_for_intercept, &cwd, Some(yield_time_ms), context.session.as_ref(), @@ -180,6 +177,14 @@ impl ToolHandler for UnifiedExecHandler { &context.call_id, None, ); + let command = if login.is_none() { + context + .session + .user_shell() + .wrap_command_with_snapshot(&command_for_intercept) + } else { + command_for_intercept + }; let emitter = ToolEmitter::unified_exec( &command, cwd.clone(), @@ -255,14 +260,15 @@ impl ToolHandler for UnifiedExecHandler { } } -fn get_command(args: &ExecCommandArgs) -> Vec { - let shell = if let Some(shell_str) = &args.shell { - get_shell_by_model_provided_path(&PathBuf::from(shell_str)) - } else { - default_user_shell() - }; +fn get_command(args: &ExecCommandArgs, session_shell: Arc) -> Vec { + if let Some(shell_str) = &args.shell { + let mut shell = get_shell_by_model_provided_path(&PathBuf::from(shell_str)); + shell.shell_snapshot = None; + return shell.derive_exec_args(&args.cmd, args.login.unwrap_or(true)); + } - shell.derive_exec_args(&args.cmd, args.login) + let use_login_shell = args.login.unwrap_or(session_shell.shell_snapshot.is_none()); + session_shell.derive_exec_args(&args.cmd, use_login_shell) } fn format_response(response: &UnifiedExecResponse) -> String { @@ -297,6 +303,8 @@ fn format_response(response: &UnifiedExecResponse) -> String { #[cfg(test)] mod tests { use super::*; + use crate::shell::default_user_shell; + use std::sync::Arc; #[test] fn test_get_command_uses_default_shell_when_unspecified() { @@ -307,7 +315,7 @@ mod tests { assert!(args.shell.is_none()); - let command = get_command(&args); + let command = get_command(&args, Arc::new(default_user_shell())); assert_eq!(command.len(), 3); assert_eq!(command[2], "echo hello"); @@ -322,7 +330,7 @@ mod tests { assert_eq!(args.shell.as_deref(), Some("/bin/bash")); - let command = get_command(&args); + let command = get_command(&args, Arc::new(default_user_shell())); assert_eq!(command[2], "echo hello"); } @@ -336,7 +344,7 @@ mod tests { assert_eq!(args.shell.as_deref(), Some("powershell")); - let command = get_command(&args); + let command = get_command(&args, Arc::new(default_user_shell())); assert_eq!(command[2], "echo hello"); } @@ -350,7 +358,7 @@ mod tests { assert_eq!(args.shell.as_deref(), Some("cmd")); - let command = get_command(&args); + let command = get_command(&args, Arc::new(default_user_shell())); assert_eq!(command[2], "echo hello"); } diff --git a/codex-rs/core/src/tools/registry.rs b/codex-rs/core/src/tools/registry.rs index f35ff063155..9b33e84b76b 100644 --- a/codex-rs/core/src/tools/registry.rs +++ b/codex-rs/core/src/tools/registry.rs @@ -30,7 +30,7 @@ pub trait ToolHandler: Send + Sync { ) } - fn is_mutating(&self, _invocation: &ToolInvocation) -> bool { + async fn is_mutating(&self, _invocation: &ToolInvocation) -> bool { false } @@ -110,7 +110,7 @@ impl ToolRegistry { let output_cell = &output_cell; let invocation = invocation; async move { - if handler.is_mutating(&invocation) { + if handler.is_mutating(&invocation).await { tracing::trace!("waiting for tool gate"); invocation.turn.tool_call_gate.wait_ready().await; tracing::trace!("tool gate released"); diff --git a/codex-rs/core/tests/suite/mod.rs b/codex-rs/core/tests/suite/mod.rs index 2112cbb7aaa..29cc3ffb191 100644 --- a/codex-rs/core/tests/suite/mod.rs +++ b/codex-rs/core/tests/suite/mod.rs @@ -49,6 +49,7 @@ mod rollout_list_find; mod seatbelt; mod shell_command; mod shell_serialization; +mod shell_snapshot; mod stream_error_allows_next_turn; mod stream_no_completed; mod text_encoding_fix; diff --git a/codex-rs/core/tests/suite/shell_snapshot.rs b/codex-rs/core/tests/suite/shell_snapshot.rs new file mode 100644 index 00000000000..950b3286070 --- /dev/null +++ b/codex-rs/core/tests/suite/shell_snapshot.rs @@ -0,0 +1,222 @@ +use anyhow::Result; +use codex_core::features::Feature; +use codex_core::protocol::AskForApproval; +use codex_core::protocol::EventMsg; +use codex_core::protocol::ExecCommandBeginEvent; +use codex_core::protocol::ExecCommandEndEvent; +use codex_core::protocol::Op; +use codex_core::protocol::SandboxPolicy; +use codex_protocol::config_types::ReasoningSummary; +use codex_protocol::user_input::UserInput; +use core_test_support::responses::ev_assistant_message; +use core_test_support::responses::ev_completed; +use core_test_support::responses::ev_function_call; +use core_test_support::responses::ev_response_created; +use core_test_support::responses::mount_sse_sequence; +use core_test_support::responses::sse; +use core_test_support::test_codex::TestCodexHarness; +use core_test_support::test_codex::test_codex; +use core_test_support::wait_for_event; +use core_test_support::wait_for_event_match; +use pretty_assertions::assert_eq; +use serde_json::json; +use std::path::PathBuf; +use tokio::fs; + +#[derive(Debug)] +struct SnapshotRun { + begin: ExecCommandBeginEvent, + end: ExecCommandEndEvent, + snapshot_path: PathBuf, + snapshot_content: String, + codex_home: PathBuf, +} + +#[allow(clippy::expect_used)] +async fn run_snapshot_command(command: &str) -> Result { + let builder = test_codex().with_config(|config| { + config.use_experimental_unified_exec_tool = true; + config.features.enable(Feature::UnifiedExec); + config.features.enable(Feature::ShellSnapshot); + }); + let harness = TestCodexHarness::with_builder(builder).await?; + let args = json!({ + "cmd": command, + "yield_time_ms": 1000, + }); + let call_id = "shell-snapshot-exec"; + let responses = vec![ + sse(vec![ + ev_response_created("resp-1"), + ev_function_call(call_id, "exec_command", &serde_json::to_string(&args)?), + ev_completed("resp-1"), + ]), + sse(vec![ + ev_response_created("resp-2"), + ev_assistant_message("msg-1", "done"), + ev_completed("resp-2"), + ]), + ]; + mount_sse_sequence(harness.server(), responses).await; + + let test = harness.test(); + let codex = test.codex.clone(); + let codex_home = test.home.path().to_path_buf(); + let session_model = test.session_configured.model.clone(); + let cwd = test.cwd_path().to_path_buf(); + + codex + .submit(Op::UserTurn { + items: vec![UserInput::Text { + text: "run unified exec with shell snapshot".into(), + }], + final_output_json_schema: None, + cwd, + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + let begin = wait_for_event_match(&codex, |ev| match ev { + EventMsg::ExecCommandBegin(ev) if ev.call_id == call_id => Some(ev.clone()), + _ => None, + }) + .await; + + let snapshot_arg = begin + .command + .iter() + .find(|arg| arg.contains("shell_snapshots")) + .expect("command includes shell snapshot path") + .to_owned(); + let snapshot_path = PathBuf::from(&snapshot_arg); + let snapshot_content = fs::read_to_string(&snapshot_path).await?; + + let end = wait_for_event_match(&codex, |ev| match ev { + EventMsg::ExecCommandEnd(ev) if ev.call_id == call_id => Some(ev.clone()), + _ => None, + }) + .await; + + wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; + + Ok(SnapshotRun { + begin, + end, + snapshot_path, + snapshot_content, + codex_home, + }) +} + +fn normalize_newlines(text: &str) -> String { + text.replace("\r\n", "\n") +} + +#[cfg(any(target_os = "linux", target_os = "macos"))] +fn assert_posix_snapshot_sections(snapshot: &str) { + assert!(snapshot.contains("# Snapshot file")); + assert!(snapshot.contains("aliases ")); + assert!(snapshot.contains("exports ")); + assert!(snapshot.contains("setopts ")); + assert!( + snapshot.contains("PATH"), + "snapshot should include PATH exports; snapshot={snapshot:?}" + ); +} + +#[cfg(target_os = "linux")] +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn linux_unified_exec_uses_shell_snapshot() -> Result<()> { + let command = "echo snapshot-linux"; + let run = run_snapshot_command(command).await?; + + let shell_path = run + .begin + .command + .first() + .expect("shell path recorded") + .clone(); + assert_eq!(run.begin.command.get(1).map(String::as_str), Some("-c")); + assert_eq!( + run.begin.command.get(2).map(String::as_str), + Some(". \"$0\" && exec \"$@\"") + ); + assert_eq!(run.begin.command.get(4), Some(&shell_path)); + assert_eq!(run.begin.command.get(5).map(String::as_str), Some("-c")); + assert_eq!(run.begin.command.last(), Some(&command.to_string())); + + assert!(run.snapshot_path.starts_with(&run.codex_home)); + assert_posix_snapshot_sections(&run.snapshot_content); + assert_eq!(normalize_newlines(&run.end.stdout).trim(), "snapshot-linux"); + assert_eq!(run.end.exit_code, 0); + + Ok(()) +} + +#[cfg(target_os = "macos")] +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn macos_unified_exec_uses_shell_snapshot() -> Result<()> { + let command = "echo snapshot-macos"; + let run = run_snapshot_command(command).await?; + + let shell_path = run + .begin + .command + .first() + .expect("shell path recorded") + .clone(); + assert_eq!(run.begin.command.get(1).map(String::as_str), Some("-c")); + assert_eq!( + run.begin.command.get(2).map(String::as_str), + Some(". \"$0\" && exec \"$@\"") + ); + assert_eq!(run.begin.command.get(4), Some(&shell_path)); + assert_eq!(run.begin.command.get(5).map(String::as_str), Some("-c")); + assert_eq!(run.begin.command.last(), Some(&command.to_string())); + + assert!(run.snapshot_path.starts_with(&run.codex_home)); + assert_posix_snapshot_sections(&run.snapshot_content); + assert_eq!(normalize_newlines(&run.end.stdout).trim(), "snapshot-macos"); + assert_eq!(run.end.exit_code, 0); + + Ok(()) +} + +#[cfg(target_os = "windows")] +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn windows_unified_exec_uses_shell_snapshot() -> Result<()> { + let command = "Write-Output snapshot-windows"; + let run = run_snapshot_command(command).await?; + + let snapshot_index = run + .begin + .command + .iter() + .position(|arg| arg.contains("shell_snapshots")) + .expect("snapshot argument exists"); + assert!(run.begin.command.iter().any(|arg| arg == "-NoProfile")); + assert!( + run.begin + .command + .iter() + .any(|arg| arg == "param($snapshot) . $snapshot; & @args") + ); + assert!(snapshot_index > 0); + assert_eq!(run.begin.command.last(), Some(&command.to_string())); + + assert!(run.snapshot_path.starts_with(&run.codex_home)); + assert!(run.snapshot_content.contains("# Snapshot file")); + assert!(run.snapshot_content.contains("# aliases ")); + assert!(run.snapshot_content.contains("# exports ")); + assert_eq!( + normalize_newlines(&run.end.stdout).trim(), + "snapshot-windows" + ); + assert_eq!(run.end.exit_code, 0); + + Ok(()) +}