diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md new file mode 100644 index 0000000..84ba610 --- /dev/null +++ b/ARCHITECTURE.md @@ -0,0 +1,404 @@ +# Repository Architecture + +This document describes the overall architecture and package organization of the StackState Backup CLI. + +## Design Philosophy + +The codebase follows several key principles: + +1. **Layered Architecture**: Dependencies flow from higher layers (commands) to lower layers (foundation utilities) +2. **Self-Documenting Structure**: Directory hierarchy makes dependency rules and module purposes explicit +3. **Clean Separation**: Domain logic, infrastructure, and presentation are clearly separated +4. **Testability**: Lower layers can be tested independently without external dependencies +5. **Reusability**: Shared functionality is extracted into appropriate packages + +## Repository Structure + +``` +stackstate-backup-cli/ +├── cmd/ # Command-line interface (Layer 4) +│ ├── root.go # Root command and global flags +│ ├── version/ # Version information command +│ ├── elasticsearch/ # Elasticsearch backup/restore commands +│ └── stackgraph/ # Stackgraph backup/restore commands +│ +├── internal/ # Internal packages (Layers 0-3) +│ ├── foundation/ # Layer 0: Core utilities +│ │ ├── config/ # Configuration management +│ │ ├── logger/ # Structured logging +│ │ └── output/ # Output formatting +│ │ +│ ├── clients/ # Layer 1: Service clients +│ │ ├── k8s/ # Kubernetes client +│ │ ├── elasticsearch/ # Elasticsearch client +│ │ └── s3/ # S3/Minio client +│ │ +│ ├── orchestration/ # Layer 2: Workflows +│ │ ├── portforward/ # Port-forwarding orchestration +│ │ └── scale/ # Deployment scaling workflows +│ │ +│ ├── app/ # Layer 3: Dependency Container +│ │ └── app.go # Application context and dependency injection +│ │ +│ └── scripts/ # Embedded bash scripts +│ +├── main.go # Application entry point +├── ARCHITECTURE.md # This file +└── README.md # User documentation +``` + +## Architectural Layers + +### Layer 4: Commands (`cmd/`) + +**Purpose**: User-facing CLI commands and application entry points + +**Characteristics**: +- Implements the Cobra command structure +- Handles user input validation and flag parsing +- Delegates to orchestration and client layers via app context +- Minimal business logic (thin command layer) +- Formats output for end users + +**Key Packages**: +- `cmd/elasticsearch/`: Elasticsearch snapshot/restore commands (configure, list-snapshots, list-indices, restore-snapshot) +- `cmd/stackgraph/`: Stackgraph backup/restore commands (list, restore) +- `cmd/version/`: Version information + +**Dependency Rules**: +- ✅ Can import: `internal/app/*` (preferred), all other `internal/` packages +- ❌ Should not: Create clients directly, contain business logic + +### Layer 3: Dependency Container (`internal/app/`) + +**Purpose**: Centralized dependency initialization and injection + +**Characteristics**: +- Creates and wires all application dependencies +- Provides single entry point for dependency creation +- Eliminates boilerplate from commands +- Improves testability through centralized mocking + +**Key Components**: +- `Context`: Struct holding all dependencies (K8s client, S3 client, ES client, config, logger, formatter) +- `NewContext()`: Factory function creating production dependencies from global flags + +**Usage Pattern**: +```go +// In command files +appCtx, err := app.NewContext(globalFlags) +if err != nil { + return err +} + +// All dependencies available via appCtx +appCtx.K8sClient +appCtx.S3Client 
+appCtx.ESClient +appCtx.Config +appCtx.Logger +appCtx.Formatter +``` + +**Dependency Rules**: +- ✅ Can import: All `internal/` packages +- ✅ Used by: `cmd/` layer only +- ❌ Should not: Contain business logic or orchestration + +### Layer 2: Orchestration (`internal/orchestration/`) + +**Purpose**: High-level workflows that coordinate multiple services + +**Characteristics**: +- Composes multiple clients to implement complex workflows +- Handles sequencing and error recovery +- Provides logging and user feedback +- Stateless operations + +**Key Packages**: +- `portforward/`: Manages Kubernetes port-forwarding lifecycle +- `scale/`: Deployment scaling workflows with detailed logging + +**Dependency Rules**: +- ✅ Can import: `internal/foundation/*`, `internal/clients/*` +- ❌ Cannot import: Other `internal/orchestration/*` (to prevent circular dependencies) + +### Layer 1: Clients (`internal/clients/`) + +**Purpose**: Wrappers for external service APIs + +**Characteristics**: +- Thin abstraction over external APIs +- Handles connection and authentication +- Translates between external formats and internal types +- No business logic or orchestration + +**Key Packages**: +- `k8s/`: Kubernetes API operations (Jobs, Pods, Deployments, ConfigMaps, Secrets, Logs) +- `elasticsearch/`: Elasticsearch HTTP API (snapshots, indices, datastreams) +- `s3/`: S3/Minio operations (client creation, object filtering) + +**Dependency Rules**: +- ✅ Can import: `internal/foundation/*`, standard library, external SDKs +- ❌ Cannot import: `internal/orchestration/*`, other `internal/clients/*` + +### Layer 0: Foundation (`internal/foundation/`) + +**Purpose**: Core utilities with no internal dependencies + +**Characteristics**: +- Pure utility functions +- No external service dependencies +- Broadly reusable across the application +- Well-tested and stable + +**Key Packages**: +- `config/`: Configuration loading from ConfigMaps, Secrets, environment, and flags +- `logger/`: Structured logging with levels (Debug, Info, Warning, Error, Success) +- `output/`: Output formatting (tables, JSON, YAML, messages) + +**Dependency Rules**: +- ✅ Can import: Standard library, external utility libraries +- ❌ Cannot import: Any `internal/` packages + +## Data Flow + +### Typical Command Execution Flow + +``` +1. User invokes CLI command + └─> cmd/elasticsearch/restore-snapshot.go + │ +2. Parse flags and validate input + └─> Cobra command receives global flags + │ +3. Create application context with dependencies + └─> app.NewContext(globalFlags) + ├─> internal/clients/k8s/ (K8s client) + ├─> internal/foundation/config/ (Load from ConfigMap/Secret) + ├─> internal/clients/s3/ (S3/Minio client) + ├─> internal/clients/elasticsearch/ (ES client) + ├─> internal/foundation/logger/ (Logger) + └─> internal/foundation/output/ (Formatter) + │ +4. Execute business logic with injected dependencies + └─> runRestore(appCtx) + ├─> internal/orchestration/scale/ (Scale down) + ├─> internal/orchestration/portforward/ (Port-forward) + ├─> internal/clients/elasticsearch/ (Restore snapshot) + └─> internal/orchestration/scale/ (Scale up) + │ +5. Format and display results + └─> appCtx.Formatter.PrintTable() or PrintJSON() +``` + +## Key Design Patterns + +### 1. Dependency Injection Pattern + +All dependencies are created once and injected via `app.Context`: + +```go +// Before (repeated in every command) +func runList(globalFlags *config.CLIGlobalFlags) error { + k8sClient, _ := k8s.NewClient(...) + cfg, _ := config.LoadConfig(...) 
+ s3Client, _ := s3.NewClient(...) + log := logger.New(...) + formatter := output.NewFormatter(...) + // ... use dependencies +} + +// After (centralized creation) +func runList(appCtx *app.Context) error { + // All dependencies available immediately + appCtx.K8sClient + appCtx.Config + appCtx.S3Client + appCtx.Logger + appCtx.Formatter +} +``` + +**Benefits**: +- Eliminates boilerplate from commands (30-50 lines per command) +- Centralized dependency creation makes testing easier +- Single source of truth for dependency wiring +- Commands are thinner and more focused on business logic + +### 2. Configuration Precedence + +Configuration is loaded with the following precedence (highest to lowest): + +1. **CLI Flags**: Explicit user input +2. **Environment Variables**: Runtime configuration +3. **Kubernetes Secret**: Sensitive credentials (overrides ConfigMap) +4. **Kubernetes ConfigMap**: Base configuration +5. **Defaults**: Fallback values + +Implementation: `internal/foundation/config/config.go` + +### 3. Client Factory Pattern + +Clients are created with a consistent factory pattern: + +```go +// Example from internal/clients/elasticsearch/client.go +func NewClient(endpoint string) (*Client, error) { + // Initialization logic + return &Client{...}, nil +} +``` + +### 4. Port-Forward Lifecycle + +Services running in Kubernetes are accessed via automatic port-forwarding: + +```go +// Example from internal/orchestration/portforward/portforward.go +pf, err := SetupPortForward(k8sClient, namespace, service, localPort, remotePort, log) +defer close(pf.StopChan) // Automatic cleanup +``` + +### 5. Scale Down/Up Pattern + +Deployments are scaled down before restore operations and scaled up afterward: + +```go +// Example usage +scaledDeployments, _ := scale.ScaleDown(k8sClient, namespace, selector, log) +defer scale.ScaleUp(k8sClient, namespace, scaledDeployments, log) +``` + +### 6. Structured Logging + +All operations use structured logging with consistent levels: + +```go +log.Infof("Starting operation...") +log.Debugf("Detail: %v", detail) +log.Warningf("Non-fatal issue: %v", warning) +log.Errorf("Operation failed: %v", err) +log.Successf("Operation completed successfully") +``` + +## Testing Strategy + +### Unit Tests +- **Location**: Same directory as source (e.g., `config_test.go`) +- **Focus**: Business logic, parsing, validation +- **Mocking**: Use interfaces for external dependencies + +### Integration Tests +- **Location**: `cmd/*/` directories +- **Focus**: Command execution with mocked Kubernetes +- **Tools**: `fake.NewSimpleClientset()` from `k8s.io/client-go` + +### End-to-End Tests +- **Status**: Not yet implemented +- **Future**: Use `kind` or `k3s` for local Kubernetes cluster testing + +## Extending the Codebase + +### Adding a New Command + +1. Create command file in `cmd//` +2. Implement Cobra command structure +3. Use existing clients or create new ones in `internal/clients/` +4. Implement workflow in `internal/orchestration/` if needed +5. Add tests following existing patterns + +### Adding a New Client + +1. Create package in `internal/clients//` +2. Implement client factory: `NewClient(...) (*Client, error)` +3. Only import `internal/foundation/*` packages +4. Add methods for each API operation +5. Write unit tests with mocked HTTP/API calls + +### Adding a New Orchestration Workflow + +1. Create package in `internal/orchestration//` +2. Import required clients from `internal/clients/*` +3. Import utilities from `internal/foundation/*` +4. Keep workflows stateless +5. 
Add comprehensive logging + +## Common Pitfalls to Avoid + +### ❌ Don't: Import Clients from Other Clients + +```go +// BAD: internal/clients/elasticsearch/backup.go +import "github.com/.../internal/clients/k8s" // Violates layer rules +``` + +**Fix**: Move the orchestration logic to `internal/orchestration/` + +### ❌ Don't: Put Business Logic in Commands + +```go +// BAD: cmd/elasticsearch/restore.go +func runRestore() { + // 200 lines of business logic here +} +``` + +**Fix**: Extract logic to orchestration or client packages + +### ❌ Don't: Import Foundation Packages from Each Other + +```go +// BAD: internal/foundation/config/loader.go +import "github.com/.../internal/foundation/output" +``` + +**Fix**: Foundation packages should be independent + +### ❌ Don't: Hard-code Configuration + +```go +// BAD +endpoint := "http://localhost:9200" +``` + +**Fix**: Use configuration management: `config.Elasticsearch.Service.Name` + +### ❌ Don't: Create Clients Directly in Commands + +```go +// BAD: cmd/elasticsearch/list-snapshots.go +func runListSnapshots(globalFlags *config.CLIGlobalFlags) error { + k8sClient, _ := k8s.NewClient(globalFlags.Kubeconfig, globalFlags.Debug) + esClient, _ := elasticsearch.NewClient("http://localhost:9200") + // ... use clients +} +``` + +**Fix**: Use `app.Context` for dependency injection: +```go +// GOOD +func runListSnapshots(appCtx *app.Context) error { + // Dependencies already created + appCtx.K8sClient + appCtx.ESClient +} +``` + +## Automated Enforcement + +Verify architectural rules with these commands: + +```bash +# Verify foundation/ has no internal/ imports +go list -f '{{.ImportPath}}: {{join .Imports "\n"}}' ./internal/foundation/... | \ + grep 'stackvista.*internal' + +# Verify clients/ only imports foundation/ +go list -f '{{.ImportPath}}: {{join .Imports "\n"}}' ./internal/clients/... | \ + grep 'stackvista.*internal' | grep -v foundation + +# Verify orchestration/ doesn't import other orchestration/ +go list -f '{{.ImportPath}}: {{join .Imports "\n"}}' ./internal/orchestration/... | \ + grep 'stackvista.*orchestration' +``` diff --git a/README.md b/README.md index 6430add..7f9a2e9 100644 --- a/README.md +++ b/README.md @@ -6,8 +6,11 @@ A command-line tool for managing backups and restores for SUSE Observability pla This CLI tool replaces the legacy Bash-based backup/restore scripts with a single Go binary that can be run from an operator host. It uses Kubernetes port-forwarding to connect to services and automatically discovers configuration from ConfigMaps and Secrets. -**Current Support:** Elasticsearch snapshots and restores -**Planned:** VictoriaMetrics, ClickHouse, StackGraph, Configuration backups +**Current Support:** +- Elasticsearch snapshots and restores +- Stackgraph backups and restores + +**Planned:** VictoriaMetrics, ClickHouse, Configuration backups ## Installation @@ -75,17 +78,45 @@ sts-backup elasticsearch list-snapshots --namespace #### restore-snapshot -Restore Elasticsearch snapshot. +Restore Elasticsearch snapshot. Automatically scales down affected deployments before restore and scales them back up afterward. 
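+For example, a non-interactive restore could look like this (the namespace and snapshot name are illustrative; use `list-snapshots` to find available snapshots):
+
+```bash
+sts-backup elasticsearch restore-snapshot --namespace sts --snapshot-name sts-backup-20250128 --yes
+```
+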
```bash sts-backup elasticsearch restore-snapshot --namespace --snapshot-name [flags] ``` **Flags:** -- `--snapshot-name` - Name of snapshot to restore (required) -- `--drop-all-indices` - Delete all existing indices before restore +- `--snapshot-name, -s` - Name of snapshot to restore (required) +- `--drop-all-indices, -r` - Delete all existing STS indices before restore - `--yes` - Skip confirmation prompt +### stackgraph + +Manage Stackgraph backups and restores. + +#### list + +List available Stackgraph backups from S3/Minio. + +```bash +sts-backup stackgraph list --namespace +``` + +#### restore + +Restore Stackgraph from a backup archive. Automatically scales down affected deployments before restore and scales them back up afterward. + +```bash +sts-backup stackgraph restore --namespace [--archive | --latest] [flags] +``` + +**Flags:** +- `--archive` - Specific archive name to restore (e.g., sts-backup-20210216-0300.graph) +- `--latest` - Restore from the most recent backup +- `--force` - Force delete existing data during restore +- `--background` - Run restore job in background without waiting for completion + +**Note**: Either `--archive` or `--latest` must be specified (mutually exclusive). + ## Configuration The CLI uses configuration from Kubernetes ConfigMaps and Secrets with the following precedence: @@ -149,29 +180,51 @@ kubectl create secret generic suse-observability-backup-config \ -n ``` -See [internal/config/testdata/validConfigMapConfig.yaml](internal/config/testdata/validConfigMapConfig.yaml) for a complete example. +See [internal/foundation/config/testdata/validConfigMapConfig.yaml](internal/foundation/config/testdata/validConfigMapConfig.yaml) for a complete example. ## Project Structure ``` . -├── cmd/ # CLI commands -│ ├── root.go # Root command and flag definitions +├── cmd/ # CLI commands (Layer 4) +│ ├── root.go # Root command and global flags │ ├── version/ # Version command -│ └── elasticsearch/ # Elasticsearch subcommands -│ ├── configure.go # Configure snapshot repository -│ ├── list-indices.go # List indices -│ ├── list-snapshots.go # List snapshots -│ └── restore-snapshot.go # Restore snapshot -├── internal/ # Internal packages -│ ├── config/ # Configuration loading and validation -│ ├── elasticsearch/ # Elasticsearch client -│ ├── k8s/ # Kubernetes client utilities -│ ├── logger/ # Structured logging -│ └── output/ # Output formatting (table, JSON) -└── main.go # Entry point +│ ├── elasticsearch/ # Elasticsearch subcommands +│ │ ├── configure.go # Configure snapshot repository +│ │ ├── list-indices.go # List indices +│ │ ├── list-snapshots.go # List snapshots +│ │ └── restore-snapshot.go # Restore snapshot +│ └── stackgraph/ # Stackgraph subcommands +│ ├── list.go # List backups +│ └── restore.go # Restore backup +├── internal/ # Internal packages (Layers 0-3) +│ ├── foundation/ # Layer 0: Core utilities +│ │ ├── config/ # Configuration management +│ │ ├── logger/ # Structured logging +│ │ └── output/ # Output formatting +│ ├── clients/ # Layer 1: Service clients +│ │ ├── k8s/ # Kubernetes client +│ │ ├── elasticsearch/ # Elasticsearch client +│ │ └── s3/ # S3/Minio client +│ ├── orchestration/ # Layer 2: Workflows +│ │ ├── portforward/ # Port-forwarding lifecycle +│ │ └── scale/ # Deployment scaling +│ ├── app/ # Layer 3: Dependency container +│ │ └── app.go # Application context and DI +│ └── scripts/ # Embedded bash scripts +├── main.go # Entry point +└── ARCHITECTURE.md # Detailed architecture documentation ``` +### Key Architectural Features + +- 
**Layered Architecture**: Clear separation between commands (Layer 4), dependency injection (Layer 3), workflows (Layer 2), clients (Layer 1), and utilities (Layer 0) +- **Dependency Injection**: Centralized dependency creation via `internal/app/` eliminates boilerplate from commands +- **Testability**: All layers use interfaces for external dependencies, enabling comprehensive unit testing +- **Clean Commands**: Commands are thin (50-100 lines) and focused on business logic + +See [ARCHITECTURE.md](ARCHITECTURE.md) for detailed information about the layered architecture and design patterns. + ## CI/CD This project uses GitHub Actions and GoReleaser for automated releases: diff --git a/cmd/elasticsearch/configure.go b/cmd/elasticsearch/configure.go index bc43e2b..52a81c1 100644 --- a/cmd/elasticsearch/configure.go +++ b/cmd/elasticsearch/configure.go @@ -5,20 +5,23 @@ import ( "os" "github.com/spf13/cobra" - "github.com/stackvista/stackstate-backup-cli/cmd/portforward" - "github.com/stackvista/stackstate-backup-cli/internal/config" - "github.com/stackvista/stackstate-backup-cli/internal/elasticsearch" - "github.com/stackvista/stackstate-backup-cli/internal/k8s" - "github.com/stackvista/stackstate-backup-cli/internal/logger" + "github.com/stackvista/stackstate-backup-cli/internal/app" + "github.com/stackvista/stackstate-backup-cli/internal/foundation/config" + "github.com/stackvista/stackstate-backup-cli/internal/orchestration/portforward" ) -func configureCmd(cliCtx *config.Context) *cobra.Command { +func configureCmd(globalFlags *config.CLIGlobalFlags) *cobra.Command { return &cobra.Command{ Use: "configure", Short: "Configure Elasticsearch snapshot repository and SLM policy", Long: `Configure Elasticsearch snapshot repository and Snapshot Lifecycle Management (SLM) policy for automated backups.`, Run: func(_ *cobra.Command, _ []string) { - if err := runConfigure(cliCtx); err != nil { + appCtx, err := app.NewContext(globalFlags) + if err != nil { + _, _ = fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + if err := runConfigure(appCtx); err != nil { _, _ = fmt.Fprintf(os.Stderr, "error: %v\n", err) os.Exit(1) } @@ -26,49 +29,28 @@ func configureCmd(cliCtx *config.Context) *cobra.Command { } } -func runConfigure(cliCtx *config.Context) error { - // Create logger - log := logger.New(cliCtx.Config.Quiet, cliCtx.Config.Debug) - - // Create Kubernetes client - k8sClient, err := k8s.NewClient(cliCtx.Config.Kubeconfig, cliCtx.Config.Debug) - if err != nil { - return fmt.Errorf("failed to create Kubernetes client: %w", err) - } - - // Load configuration - cfg, err := config.LoadConfig(k8sClient.Clientset(), cliCtx.Config.Namespace, cliCtx.Config.ConfigMapName, cliCtx.Config.SecretName) - if err != nil { - return fmt.Errorf("failed to load configuration: %w", err) - } - +func runConfigure(appCtx *app.Context) error { // Validate required configuration - if cfg.Elasticsearch.SnapshotRepository.AccessKey == "" || cfg.Elasticsearch.SnapshotRepository.SecretKey == "" { + if appCtx.Config.Elasticsearch.SnapshotRepository.AccessKey == "" || appCtx.Config.Elasticsearch.SnapshotRepository.SecretKey == "" { return fmt.Errorf("accessKey and secretKey are required in the secret configuration") } // Setup port-forward to Elasticsearch - serviceName := cfg.Elasticsearch.Service.Name - localPort := cfg.Elasticsearch.Service.LocalPortForwardPort - remotePort := cfg.Elasticsearch.Service.Port + serviceName := appCtx.Config.Elasticsearch.Service.Name + localPort := 
appCtx.Config.Elasticsearch.Service.LocalPortForwardPort + remotePort := appCtx.Config.Elasticsearch.Service.Port - pf, err := portforward.SetupPortForward(k8sClient, cliCtx.Config.Namespace, serviceName, localPort, remotePort, log) + pf, err := portforward.SetupPortForward(appCtx.K8sClient, appCtx.Namespace, serviceName, localPort, remotePort, appCtx.Logger) if err != nil { return err } defer close(pf.StopChan) - // Create Elasticsearch client - esClient, err := elasticsearch.NewClient(fmt.Sprintf("http://localhost:%d", pf.LocalPort)) - if err != nil { - return fmt.Errorf("failed to create Elasticsearch client: %w", err) - } - // Configure snapshot repository - repo := cfg.Elasticsearch.SnapshotRepository - log.Infof("Configuring snapshot repository '%s' (bucket: %s)...", repo.Name, repo.Bucket) + repo := appCtx.Config.Elasticsearch.SnapshotRepository + appCtx.Logger.Infof("Configuring snapshot repository '%s' (bucket: %s)...", repo.Name, repo.Bucket) - err = esClient.ConfigureSnapshotRepository( + err = appCtx.ESClient.ConfigureSnapshotRepository( repo.Name, repo.Bucket, repo.Endpoint, @@ -80,13 +62,13 @@ func runConfigure(cliCtx *config.Context) error { return fmt.Errorf("failed to configure snapshot repository: %w", err) } - log.Successf("Snapshot repository configured successfully") + appCtx.Logger.Successf("Snapshot repository configured successfully") // Configure SLM policy - slm := cfg.Elasticsearch.SLM - log.Infof("Configuring SLM policy '%s'...", slm.Name) + slm := appCtx.Config.Elasticsearch.SLM + appCtx.Logger.Infof("Configuring SLM policy '%s'...", slm.Name) - err = esClient.ConfigureSLMPolicy( + err = appCtx.ESClient.ConfigureSLMPolicy( slm.Name, slm.Schedule, slm.SnapshotTemplateName, @@ -100,9 +82,9 @@ func runConfigure(cliCtx *config.Context) error { return fmt.Errorf("failed to configure SLM policy: %w", err) } - log.Successf("SLM policy configured successfully") - log.Println() - log.Successf("Configuration completed successfully") + appCtx.Logger.Successf("SLM policy configured successfully") + appCtx.Logger.Println() + appCtx.Logger.Successf("Configuration completed successfully") return nil } diff --git a/cmd/elasticsearch/configure_test.go b/cmd/elasticsearch/configure_test.go index b79b303..ddcd4db 100644 --- a/cmd/elasticsearch/configure_test.go +++ b/cmd/elasticsearch/configure_test.go @@ -5,8 +5,8 @@ import ( "fmt" "testing" - "github.com/stackvista/stackstate-backup-cli/internal/config" - "github.com/stackvista/stackstate-backup-cli/internal/elasticsearch" + "github.com/stackvista/stackstate-backup-cli/internal/clients/elasticsearch" + "github.com/stackvista/stackstate-backup-cli/internal/foundation/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" corev1 "k8s.io/api/core/v1" @@ -92,12 +92,12 @@ func (m *mockESClientForConfigure) RolloverDatastream(_ string) error { // TestConfigureCmd_Unit tests the command structure func TestConfigureCmd_Unit(t *testing.T) { - cliCtx := config.NewContext() - cliCtx.Config.Namespace = testNamespace - cliCtx.Config.ConfigMapName = testConfigMapName - cliCtx.Config.SecretName = testSecretName + flags := config.NewCLIGlobalFlags() + flags.Namespace = testNamespace + flags.ConfigMapName = testConfigMapName + flags.SecretName = testSecretName - cmd := configureCmd(cliCtx) + cmd := configureCmd(flags) // Test command metadata assert.Equal(t, "configure", cmd.Use) @@ -152,7 +152,7 @@ elasticsearch: retentionExpireAfter: 30d retentionMinCount: 5 retentionMaxCount: 50 -`, +` + 
minimalMinioStackgraphConfig, secretData: "", expectError: false, }, @@ -187,12 +187,15 @@ elasticsearch: retentionExpireAfter: 30d retentionMinCount: 5 retentionMaxCount: 50 -`, +` + minimalMinioStackgraphConfig, secretData: ` elasticsearch: snapshotRepository: accessKey: secret-key secretKey: secret-value +minio: + accessKey: secret-minio-key + secretKey: secret-minio-value `, expectError: false, }, @@ -200,7 +203,7 @@ elasticsearch: for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - fakeClient := fake.NewSimpleClientset() + fakeClient := fake.NewClientset() // Create ConfigMap cm := &corev1.ConfigMap{ diff --git a/cmd/elasticsearch/elasticsearch.go b/cmd/elasticsearch/elasticsearch.go index 9fc8c52..abd609c 100644 --- a/cmd/elasticsearch/elasticsearch.go +++ b/cmd/elasticsearch/elasticsearch.go @@ -2,19 +2,19 @@ package elasticsearch import ( "github.com/spf13/cobra" - "github.com/stackvista/stackstate-backup-cli/internal/config" + "github.com/stackvista/stackstate-backup-cli/internal/foundation/config" ) -func Cmd(cliCtx *config.Context) *cobra.Command { +func Cmd(globalFlags *config.CLIGlobalFlags) *cobra.Command { cmd := &cobra.Command{ Use: "elasticsearch", Short: "Elasticsearch backup and restore operations", } - cmd.AddCommand(listSnapshotsCmd(cliCtx)) - cmd.AddCommand(listIndicesCmd(cliCtx)) - cmd.AddCommand(restoreCmd(cliCtx)) - cmd.AddCommand(configureCmd(cliCtx)) + cmd.AddCommand(listSnapshotsCmd(globalFlags)) + cmd.AddCommand(listIndicesCmd(globalFlags)) + cmd.AddCommand(restoreCmd(globalFlags)) + cmd.AddCommand(configureCmd(globalFlags)) return cmd } diff --git a/cmd/elasticsearch/list-indices.go b/cmd/elasticsearch/list-indices.go index 4a4f2c2..2367705 100644 --- a/cmd/elasticsearch/list-indices.go +++ b/cmd/elasticsearch/list-indices.go @@ -5,20 +5,23 @@ import ( "os" "github.com/spf13/cobra" - "github.com/stackvista/stackstate-backup-cli/cmd/portforward" - "github.com/stackvista/stackstate-backup-cli/internal/config" - "github.com/stackvista/stackstate-backup-cli/internal/elasticsearch" - "github.com/stackvista/stackstate-backup-cli/internal/k8s" - "github.com/stackvista/stackstate-backup-cli/internal/logger" - "github.com/stackvista/stackstate-backup-cli/internal/output" + "github.com/stackvista/stackstate-backup-cli/internal/app" + "github.com/stackvista/stackstate-backup-cli/internal/foundation/config" + "github.com/stackvista/stackstate-backup-cli/internal/foundation/output" + "github.com/stackvista/stackstate-backup-cli/internal/orchestration/portforward" ) -func listIndicesCmd(cliCtx *config.Context) *cobra.Command { +func listIndicesCmd(globalFlags *config.CLIGlobalFlags) *cobra.Command { return &cobra.Command{ Use: "list-indices", Short: "List Elasticsearch indices", Run: func(_ *cobra.Command, _ []string) { - if err := runListIndices(cliCtx); err != nil { + appCtx, err := app.NewContext(globalFlags) + if err != nil { + _, _ = fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + if err := runListIndices(appCtx); err != nil { _, _ = fmt.Fprintf(os.Stderr, "error: %v\n", err) os.Exit(1) } @@ -26,52 +29,28 @@ func listIndicesCmd(cliCtx *config.Context) *cobra.Command { } } -func runListIndices(cliCtx *config.Context) error { - // Create logger - log := logger.New(cliCtx.Config.Quiet, cliCtx.Config.Debug) - - // Create Kubernetes client - k8sClient, err := k8s.NewClient(cliCtx.Config.Kubeconfig, cliCtx.Config.Debug) - if err != nil { - return fmt.Errorf("failed to create Kubernetes client: %w", err) - } - - // Load configuration - cfg, err := 
config.LoadConfig(k8sClient.Clientset(), cliCtx.Config.Namespace, cliCtx.Config.ConfigMapName, cliCtx.Config.SecretName) - if err != nil { - return fmt.Errorf("failed to load configuration: %w", err) - } - +func runListIndices(appCtx *app.Context) error { // Setup port-forward to Elasticsearch - serviceName := cfg.Elasticsearch.Service.Name - localPort := cfg.Elasticsearch.Service.LocalPortForwardPort - remotePort := cfg.Elasticsearch.Service.Port + serviceName := appCtx.Config.Elasticsearch.Service.Name + localPort := appCtx.Config.Elasticsearch.Service.LocalPortForwardPort + remotePort := appCtx.Config.Elasticsearch.Service.Port - pf, err := portforward.SetupPortForward(k8sClient, cliCtx.Config.Namespace, serviceName, localPort, remotePort, log) + pf, err := portforward.SetupPortForward(appCtx.K8sClient, appCtx.Namespace, serviceName, localPort, remotePort, appCtx.Logger) if err != nil { return err } defer close(pf.StopChan) - // Create Elasticsearch client - esClient, err := elasticsearch.NewClient(fmt.Sprintf("http://localhost:%d", pf.LocalPort)) - if err != nil { - return fmt.Errorf("failed to create Elasticsearch client: %w", err) - } - // List indices with cat API - log.Infof("Fetching Elasticsearch indices...") + appCtx.Logger.Infof("Fetching Elasticsearch indices...") - indices, err := esClient.ListIndicesDetailed() + indices, err := appCtx.ESClient.ListIndicesDetailed() if err != nil { return fmt.Errorf("failed to list indices: %w", err) } - // Format and print indices - formatter := output.NewFormatter(cliCtx.Config.OutputFormat) - if len(indices) == 0 { - formatter.PrintMessage("No indices found") + appCtx.Formatter.PrintMessage("No indices found") return nil } @@ -97,5 +76,5 @@ func runListIndices(cliCtx *config.Context) error { table.Rows = append(table.Rows, row) } - return formatter.PrintTable(table) + return appCtx.Formatter.PrintTable(table) } diff --git a/cmd/elasticsearch/list-snapshots.go b/cmd/elasticsearch/list-snapshots.go index 7bc932f..5db5c54 100644 --- a/cmd/elasticsearch/list-snapshots.go +++ b/cmd/elasticsearch/list-snapshots.go @@ -5,20 +5,23 @@ import ( "os" "github.com/spf13/cobra" - "github.com/stackvista/stackstate-backup-cli/cmd/portforward" - "github.com/stackvista/stackstate-backup-cli/internal/config" - "github.com/stackvista/stackstate-backup-cli/internal/elasticsearch" - "github.com/stackvista/stackstate-backup-cli/internal/k8s" - "github.com/stackvista/stackstate-backup-cli/internal/logger" - "github.com/stackvista/stackstate-backup-cli/internal/output" + "github.com/stackvista/stackstate-backup-cli/internal/app" + "github.com/stackvista/stackstate-backup-cli/internal/foundation/config" + "github.com/stackvista/stackstate-backup-cli/internal/foundation/output" + "github.com/stackvista/stackstate-backup-cli/internal/orchestration/portforward" ) -func listSnapshotsCmd(cliCtx *config.Context) *cobra.Command { +func listSnapshotsCmd(globalFlags *config.CLIGlobalFlags) *cobra.Command { return &cobra.Command{ Use: "list-snapshots", Short: "List available Elasticsearch snapshots", Run: func(_ *cobra.Command, _ []string) { - if err := runListSnapshots(cliCtx); err != nil { + appCtx, err := app.NewContext(globalFlags) + if err != nil { + _, _ = fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + if err := runListSnapshots(appCtx); err != nil { _, _ = fmt.Fprintf(os.Stderr, "error: %v\n", err) os.Exit(1) } @@ -26,53 +29,29 @@ func listSnapshotsCmd(cliCtx *config.Context) *cobra.Command { } } -func runListSnapshots(cliCtx *config.Context) error { - 
// Create logger - log := logger.New(cliCtx.Config.Quiet, cliCtx.Config.Debug) - - // Create Kubernetes client - k8sClient, err := k8s.NewClient(cliCtx.Config.Kubeconfig, cliCtx.Config.Debug) - if err != nil { - return fmt.Errorf("failed to create Kubernetes client: %w", err) - } - - // Load configuration - cfg, err := config.LoadConfig(k8sClient.Clientset(), cliCtx.Config.Namespace, cliCtx.Config.ConfigMapName, cliCtx.Config.SecretName) - if err != nil { - return fmt.Errorf("failed to load configuration: %w", err) - } - +func runListSnapshots(appCtx *app.Context) error { // Setup port-forward to Elasticsearch - serviceName := cfg.Elasticsearch.Service.Name - localPort := cfg.Elasticsearch.Service.LocalPortForwardPort - remotePort := cfg.Elasticsearch.Service.Port + serviceName := appCtx.Config.Elasticsearch.Service.Name + localPort := appCtx.Config.Elasticsearch.Service.LocalPortForwardPort + remotePort := appCtx.Config.Elasticsearch.Service.Port - pf, err := portforward.SetupPortForward(k8sClient, cliCtx.Config.Namespace, serviceName, localPort, remotePort, log) + pf, err := portforward.SetupPortForward(appCtx.K8sClient, appCtx.Namespace, serviceName, localPort, remotePort, appCtx.Logger) if err != nil { return err } defer close(pf.StopChan) - // Create Elasticsearch client - esClient, err := elasticsearch.NewClient(fmt.Sprintf("http://localhost:%d", pf.LocalPort)) - if err != nil { - return fmt.Errorf("failed to create Elasticsearch client: %w", err) - } - // List snapshots - repository := cfg.Elasticsearch.Restore.Repository - log.Infof("Fetching snapshots from repository '%s'...", repository) + repository := appCtx.Config.Elasticsearch.Restore.Repository + appCtx.Logger.Infof("Fetching snapshots from repository '%s'...", repository) - snapshots, err := esClient.ListSnapshots(repository) + snapshots, err := appCtx.ESClient.ListSnapshots(repository) if err != nil { return fmt.Errorf("failed to list snapshots: %w", err) } - // Format and print snapshots - formatter := output.NewFormatter(cliCtx.Config.OutputFormat) - if len(snapshots) == 0 { - formatter.PrintMessage("No snapshots found") + appCtx.Formatter.PrintMessage("No snapshots found") return nil } @@ -97,5 +76,5 @@ func runListSnapshots(cliCtx *config.Context) error { table.Rows = append(table.Rows, row) } - return formatter.PrintTable(table) + return appCtx.Formatter.PrintTable(table) } diff --git a/cmd/elasticsearch/list_indices_test.go b/cmd/elasticsearch/list_indices_test.go index 4257553..6b0bd21 100644 --- a/cmd/elasticsearch/list_indices_test.go +++ b/cmd/elasticsearch/list_indices_test.go @@ -5,8 +5,8 @@ import ( "fmt" "testing" - "github.com/stackvista/stackstate-backup-cli/internal/config" - "github.com/stackvista/stackstate-backup-cli/internal/elasticsearch" + "github.com/stackvista/stackstate-backup-cli/internal/clients/elasticsearch" + "github.com/stackvista/stackstate-backup-cli/internal/foundation/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" corev1 "k8s.io/api/core/v1" @@ -71,12 +71,12 @@ func (m *mockESClientForIndices) RolloverDatastream(_ string) error { // TestListIndicesCmd_Unit tests the command structure func TestListIndicesCmd_Unit(t *testing.T) { - cliCtx := config.NewContext() - cliCtx.Config.Namespace = testNamespace - cliCtx.Config.ConfigMapName = testConfigMapName - cliCtx.Config.OutputFormat = "table" + flags := config.NewCLIGlobalFlags() + flags.Namespace = testNamespace + flags.ConfigMapName = testConfigMapName + flags.OutputFormat = "table" - cmd := 
listIndicesCmd(cliCtx) + cmd := listIndicesCmd(flags) // Test command metadata assert.Equal(t, "list-indices", cmd.Use) @@ -91,7 +91,7 @@ func TestListIndicesCmd_Integration(t *testing.T) { } // Create fake Kubernetes client - fakeClient := fake.NewSimpleClientset() + fakeClient := fake.NewClientset() // Create ConfigMap with valid config cm := &corev1.ConfigMap{ @@ -129,7 +129,7 @@ elasticsearch: retentionExpireAfter: 30d retentionMinCount: 5 retentionMaxCount: 50 -`, +` + minimalMinioStackgraphConfig, }, } _, err := fakeClient.CoreV1().ConfigMaps(testNamespace).Create( diff --git a/cmd/elasticsearch/list_snapshots_test.go b/cmd/elasticsearch/list_snapshots_test.go index e3b3ac4..21ca38e 100644 --- a/cmd/elasticsearch/list_snapshots_test.go +++ b/cmd/elasticsearch/list_snapshots_test.go @@ -6,8 +6,8 @@ import ( "testing" "time" - "github.com/stackvista/stackstate-backup-cli/internal/config" - "github.com/stackvista/stackstate-backup-cli/internal/elasticsearch" + "github.com/stackvista/stackstate-backup-cli/internal/clients/elasticsearch" + "github.com/stackvista/stackstate-backup-cli/internal/foundation/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" corev1 "k8s.io/api/core/v1" @@ -21,6 +21,36 @@ const ( testSecretName = "backup-secret" ) +// minimalMinioStackgraphConfig provides the required Minio and Stackgraph configuration for tests +const minimalMinioStackgraphConfig = ` +minio: + service: + name: minio + port: 9000 + localPortForwardPort: 9000 + accessKey: minioadmin + secretKey: minioadmin +stackgraph: + bucket: stackgraph-bucket + multipartArchive: true + restore: + scaleDownLabelSelector: "app=stackgraph" + loggingConfigConfigMap: logging-config + zookeeperQuorum: "zookeeper:2181" + job: + image: backup:latest + waitImage: wait:latest + resources: + limits: + cpu: "2" + memory: "4Gi" + requests: + cpu: "1" + memory: "2Gi" + pvc: + size: "10Gi" +` + // mockESClient is a simple mock for testing commands type mockESClient struct { snapshots []elasticsearch.Snapshot @@ -79,7 +109,7 @@ func TestListSnapshotsCmd_Integration(t *testing.T) { } // Create fake Kubernetes client - fakeClient := fake.NewSimpleClientset() + fakeClient := fake.NewClientset() // Create ConfigMap with valid config cm := &corev1.ConfigMap{ @@ -117,7 +147,7 @@ elasticsearch: retentionExpireAfter: 30d retentionMinCount: 5 retentionMaxCount: 50 -`, +` + minimalMinioStackgraphConfig, }, } _, err := fakeClient.CoreV1().ConfigMaps(testNamespace).Create( @@ -135,12 +165,12 @@ elasticsearch: // TestListSnapshotsCmd_Unit demonstrates a unit-style test // This test focuses on the command structure and basic behavior func TestListSnapshotsCmd_Unit(t *testing.T) { - cliCtx := config.NewContext() - cliCtx.Config.Namespace = testNamespace - cliCtx.Config.ConfigMapName = testConfigMapName - cliCtx.Config.OutputFormat = "table" + flags := config.NewCLIGlobalFlags() + flags.Namespace = testNamespace + flags.ConfigMapName = testConfigMapName + flags.OutputFormat = "table" - cmd := listSnapshotsCmd(cliCtx) + cmd := listSnapshotsCmd(flags) // Test command metadata assert.Equal(t, "list-snapshots", cmd.Use) diff --git a/cmd/elasticsearch/restore-snapshot.go b/cmd/elasticsearch/restore-snapshot.go index a857616..30b8cb1 100644 --- a/cmd/elasticsearch/restore-snapshot.go +++ b/cmd/elasticsearch/restore-snapshot.go @@ -8,11 +8,12 @@ import ( "time" "github.com/spf13/cobra" - "github.com/stackvista/stackstate-backup-cli/cmd/portforward" - "github.com/stackvista/stackstate-backup-cli/internal/config" - 
"github.com/stackvista/stackstate-backup-cli/internal/elasticsearch" - "github.com/stackvista/stackstate-backup-cli/internal/k8s" - "github.com/stackvista/stackstate-backup-cli/internal/logger" + "github.com/stackvista/stackstate-backup-cli/internal/app" + "github.com/stackvista/stackstate-backup-cli/internal/clients/elasticsearch" + "github.com/stackvista/stackstate-backup-cli/internal/foundation/config" + "github.com/stackvista/stackstate-backup-cli/internal/foundation/logger" + "github.com/stackvista/stackstate-backup-cli/internal/orchestration/portforward" + "github.com/stackvista/stackstate-backup-cli/internal/orchestration/scale" ) const ( @@ -29,13 +30,18 @@ var ( skipConfirmation bool ) -func restoreCmd(cliCtx *config.Context) *cobra.Command { +func restoreCmd(globalFlags *config.CLIGlobalFlags) *cobra.Command { cmd := &cobra.Command{ Use: "restore-snapshot", Short: "Restore Elasticsearch from a snapshot", Long: `Restore Elasticsearch indices from a snapshot. Can optionally delete existing indices before restore.`, Run: func(_ *cobra.Command, _ []string) { - if err := runRestore(cliCtx); err != nil { + appCtx, err := app.NewContext(globalFlags) + if err != nil { + _, _ = fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + if err := runRestore(appCtx); err != nil { _, _ = fmt.Fprintf(os.Stderr, "error: %v\n", err) os.Exit(1) } @@ -48,24 +54,9 @@ func restoreCmd(cliCtx *config.Context) *cobra.Command { return cmd } -func runRestore(cliCtx *config.Context) error { - // Create logger - log := logger.New(cliCtx.Config.Quiet, cliCtx.Config.Debug) - - // Create Kubernetes client - k8sClient, err := k8s.NewClient(cliCtx.Config.Kubeconfig, cliCtx.Config.Debug) - if err != nil { - return fmt.Errorf("failed to create Kubernetes client: %w", err) - } - - // Load configuration - cfg, err := config.LoadConfig(k8sClient.Clientset(), cliCtx.Config.Namespace, cliCtx.Config.ConfigMapName, cliCtx.Config.SecretName) - if err != nil { - return fmt.Errorf("failed to load configuration: %w", err) - } - +func runRestore(appCtx *app.Context) error { // Scale down deployments before restore - scaledDeployments, err := scaleDownDeployments(k8sClient, cliCtx.Config.Namespace, cfg.Elasticsearch.Restore.ScaleDownLabelSelector, log) + scaledDeployments, err := scale.ScaleDown(appCtx.K8sClient, appCtx.Namespace, appCtx.Config.Elasticsearch.Restore.ScaleDownLabelSelector, appCtx.Logger) if err != nil { return err } @@ -73,83 +64,71 @@ func runRestore(cliCtx *config.Context) error { // Ensure deployments are scaled back up on exit (even if restore fails) defer func() { if len(scaledDeployments) > 0 { - log.Println() - log.Infof("Scaling up deployments back to original replica counts...") - if err := k8sClient.ScaleUpDeployments(cliCtx.Config.Namespace, scaledDeployments); err != nil { - log.Warningf("Failed to scale up deployments: %v", err) - } else { - log.Successf("Scaled up %d deployment(s) successfully:", len(scaledDeployments)) - for _, dep := range scaledDeployments { - log.Infof(" - %s (replicas: 0 -> %d)", dep.Name, dep.Replicas) - } + appCtx.Logger.Println() + if err := scale.ScaleUpFromAnnotations(appCtx.K8sClient, appCtx.Namespace, appCtx.Config.Elasticsearch.Restore.ScaleDownLabelSelector, appCtx.Logger); err != nil { + appCtx.Logger.Warningf("Failed to scale up deployments: %v", err) } } }() // Setup port-forward to Elasticsearch - serviceName := cfg.Elasticsearch.Service.Name - localPort := cfg.Elasticsearch.Service.LocalPortForwardPort - remotePort := cfg.Elasticsearch.Service.Port + 
serviceName := appCtx.Config.Elasticsearch.Service.Name + localPort := appCtx.Config.Elasticsearch.Service.LocalPortForwardPort + remotePort := appCtx.Config.Elasticsearch.Service.Port - pf, err := portforward.SetupPortForward(k8sClient, cliCtx.Config.Namespace, serviceName, localPort, remotePort, log) + pf, err := portforward.SetupPortForward(appCtx.K8sClient, appCtx.Namespace, serviceName, localPort, remotePort, appCtx.Logger) if err != nil { return err } defer close(pf.StopChan) - // Create Elasticsearch client - esClient, err := elasticsearch.NewClient(fmt.Sprintf("http://localhost:%d", pf.LocalPort)) - if err != nil { - return fmt.Errorf("failed to create Elasticsearch client: %w", err) - } - - repository := cfg.Elasticsearch.Restore.Repository + repository := appCtx.Config.Elasticsearch.Restore.Repository // Get all indices and filter for STS indices - log.Infof("Fetching current Elasticsearch indices...") - allIndices, err := esClient.ListIndices("*") + appCtx.Logger.Infof("Fetching current Elasticsearch indices...") + allIndices, err := appCtx.ESClient.ListIndices("*") if err != nil { return fmt.Errorf("failed to list indices: %w", err) } - stsIndices := filterSTSIndices(allIndices, cfg.Elasticsearch.Restore.IndexPrefix, cfg.Elasticsearch.Restore.DatastreamIndexPrefix) + stsIndices := filterSTSIndices(allIndices, appCtx.Config.Elasticsearch.Restore.IndexPrefix, appCtx.Config.Elasticsearch.Restore.DatastreamIndexPrefix) if dropAllIndices { - log.Println() - if err := deleteIndices(esClient, stsIndices, cfg, log, skipConfirmation); err != nil { + appCtx.Logger.Println() + if err := deleteIndices(appCtx.ESClient, stsIndices, appCtx.Config, appCtx.Logger, skipConfirmation); err != nil { return err } } // Restore snapshot - log.Println() - log.Infof("Restoring snapshot '%s' from repository '%s'", snapshotName, repository) + appCtx.Logger.Println() + appCtx.Logger.Infof("Restoring snapshot '%s' from repository '%s'", snapshotName, repository) // Get snapshot details to show indices - snapshot, err := esClient.GetSnapshot(repository, snapshotName) + snapshot, err := appCtx.ESClient.GetSnapshot(repository, snapshotName) if err != nil { return fmt.Errorf("failed to get snapshot details: %w", err) } - log.Debugf("Indices pattern: %s", cfg.Elasticsearch.Restore.IndicesPattern) + appCtx.Logger.Debugf("Indices pattern: %s", appCtx.Config.Elasticsearch.Restore.IndicesPattern) if len(snapshot.Indices) == 0 { - log.Warningf("Snapshot contains no indices") + appCtx.Logger.Warningf("Snapshot contains no indices") } else { - log.Infof("Snapshot contains %d index(es)", len(snapshot.Indices)) + appCtx.Logger.Infof("Snapshot contains %d index(es)", len(snapshot.Indices)) for _, index := range snapshot.Indices { - log.Debugf(" - %s", index) + appCtx.Logger.Debugf(" - %s", index) } } - log.Infof("Starting restore - this may take several minutes...") + appCtx.Logger.Infof("Starting restore - this may take several minutes...") - if err := esClient.RestoreSnapshot(repository, snapshotName, cfg.Elasticsearch.Restore.IndicesPattern, true); err != nil { + if err := appCtx.ESClient.RestoreSnapshot(repository, snapshotName, appCtx.Config.Elasticsearch.Restore.IndicesPattern, true); err != nil { return fmt.Errorf("failed to restore snapshot: %w", err) } - log.Println() - log.Successf("Restore completed successfully") + appCtx.Logger.Println() + appCtx.Logger.Successf("Restore completed successfully") return nil } @@ -190,7 +169,7 @@ func hasDatastreamIndices(indices []string, datastreamPrefix string) bool { } // 
deleteIndexWithVerification deletes an index and verifies it's gone -func deleteIndexWithVerification(esClient *elasticsearch.Client, index string, log *logger.Logger) error { +func deleteIndexWithVerification(esClient elasticsearch.Interface, index string, log *logger.Logger) error { log.Infof(" Deleting index: %s", index) if err := esClient.DeleteIndex(index); err != nil { return fmt.Errorf("failed to delete index %s: %w", index, err) @@ -214,29 +193,8 @@ func deleteIndexWithVerification(esClient *elasticsearch.Client, index string, l return nil } -// scaleDownDeployments scales down deployments matching the label selector -func scaleDownDeployments(k8sClient *k8s.Client, namespace, labelSelector string, log *logger.Logger) ([]k8s.DeploymentScale, error) { - log.Infof("Scaling down deployments (selector: %s)...", labelSelector) - - scaledDeployments, err := k8sClient.ScaleDownDeployments(namespace, labelSelector) - if err != nil { - return nil, fmt.Errorf("failed to scale down deployments: %w", err) - } - - if len(scaledDeployments) == 0 { - log.Infof("No deployments found to scale down") - } else { - log.Successf("Scaled down %d deployment(s):", len(scaledDeployments)) - for _, dep := range scaledDeployments { - log.Infof(" - %s (replicas: %d -> 0)", dep.Name, dep.Replicas) - } - } - - return scaledDeployments, nil -} - // deleteIndices handles the deletion of all STS indices including datastream rollover -func deleteIndices(esClient *elasticsearch.Client, stsIndices []string, cfg *config.Config, log *logger.Logger, skipConfirm bool) error { +func deleteIndices(esClient elasticsearch.Interface, stsIndices []string, cfg *config.Config, log *logger.Logger, skipConfirm bool) error { if len(stsIndices) == 0 { log.Infof("No STS indices found to delete") return nil diff --git a/cmd/elasticsearch/restore_snapshot_test.go b/cmd/elasticsearch/restore_snapshot_test.go index 761cc86..49b5018 100644 --- a/cmd/elasticsearch/restore_snapshot_test.go +++ b/cmd/elasticsearch/restore_snapshot_test.go @@ -5,8 +5,8 @@ import ( "testing" "time" - "github.com/stackvista/stackstate-backup-cli/internal/config" - "github.com/stackvista/stackstate-backup-cli/internal/elasticsearch" + "github.com/stackvista/stackstate-backup-cli/internal/clients/elasticsearch" + "github.com/stackvista/stackstate-backup-cli/internal/foundation/config" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -94,8 +94,10 @@ func (m *mockESClientForRestore) ConfigureSLMPolicy(_, _, _, _, _, _ string, _, // TestRestoreCmd_Unit tests the command structure func TestRestoreCmd_Unit(t *testing.T) { - cliCtx := config.NewContext() - cmd := restoreCmd(cliCtx) + flags := config.NewCLIGlobalFlags() + flags.Namespace = testNamespace + flags.ConfigMapName = testConfigMapName + cmd := restoreCmd(flags) // Test command metadata assert.Equal(t, "restore-snapshot", cmd.Use) diff --git a/cmd/root.go b/cmd/root.go index 1893ac5..258e85e 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -5,35 +5,40 @@ import ( "github.com/spf13/cobra" "github.com/stackvista/stackstate-backup-cli/cmd/elasticsearch" + "github.com/stackvista/stackstate-backup-cli/cmd/stackgraph" "github.com/stackvista/stackstate-backup-cli/cmd/version" - "github.com/stackvista/stackstate-backup-cli/internal/config" + "github.com/stackvista/stackstate-backup-cli/internal/foundation/config" ) var ( - cliCtx *config.Context + flags *config.CLIGlobalFlags ) // addBackupConfigFlags adds configuration flags needed for backup/restore operations // to commands that interact 
with data services (Elasticsearch, etc.) func addBackupConfigFlags(cmd *cobra.Command) { - cmd.PersistentFlags().StringVar(&cliCtx.Config.Namespace, "namespace", "", "Kubernetes namespace (required)") - cmd.PersistentFlags().StringVar(&cliCtx.Config.Kubeconfig, "kubeconfig", "", "Path to kubeconfig file (default: ~/.kube/config)") - cmd.PersistentFlags().BoolVar(&cliCtx.Config.Debug, "debug", false, "Enable debug output") - cmd.PersistentFlags().BoolVarP(&cliCtx.Config.Quiet, "quiet", "q", false, "Suppress operational messages (only show errors and data output)") - cmd.PersistentFlags().StringVar(&cliCtx.Config.ConfigMapName, "configmap", "suse-observability-backup-config", "ConfigMap name containing backup configuration") - cmd.PersistentFlags().StringVar(&cliCtx.Config.SecretName, "secret", "suse-observability-backup-config", "Secret name containing backup configuration") - cmd.PersistentFlags().StringVarP(&cliCtx.Config.OutputFormat, "output", "o", "table", "Output format (table, json)") + cmd.PersistentFlags().StringVarP(&flags.Namespace, "namespace", "n", "", "Kubernetes namespace (required)") + cmd.PersistentFlags().StringVar(&flags.Kubeconfig, "kubeconfig", "", "Path to kubeconfig file (default: ~/.kube/config)") + cmd.PersistentFlags().BoolVar(&flags.Debug, "debug", false, "Enable debug output") + cmd.PersistentFlags().BoolVarP(&flags.Quiet, "quiet", "q", false, "Suppress operational messages (only show errors and data output)") + cmd.PersistentFlags().StringVar(&flags.ConfigMapName, "configmap", "suse-observability-backup-config", "ConfigMap name containing backup configuration") + cmd.PersistentFlags().StringVar(&flags.SecretName, "secret", "suse-observability-backup-config", "Secret name containing backup configuration") + cmd.PersistentFlags().StringVarP(&flags.OutputFormat, "output", "o", "table", "Output format (table, json)") _ = cmd.MarkPersistentFlagRequired("namespace") } func init() { - cliCtx = config.NewContext() + flags = config.NewCLIGlobalFlags() // Add backup config flags to commands that need them - esCmd := elasticsearch.Cmd(cliCtx) + esCmd := elasticsearch.Cmd(flags) addBackupConfigFlags(esCmd) rootCmd.AddCommand(esCmd) + stackgraphCmd := stackgraph.Cmd(flags) + addBackupConfigFlags(stackgraphCmd) + rootCmd.AddCommand(stackgraphCmd) + // Add commands that don't need backup config flags rootCmd.AddCommand(version.Cmd()) } diff --git a/cmd/stackgraph/check_and_finalize.go b/cmd/stackgraph/check_and_finalize.go new file mode 100644 index 0000000..03a2637 --- /dev/null +++ b/cmd/stackgraph/check_and_finalize.go @@ -0,0 +1,144 @@ +package stackgraph + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" + "github.com/stackvista/stackstate-backup-cli/internal/app" + "github.com/stackvista/stackstate-backup-cli/internal/foundation/config" + "github.com/stackvista/stackstate-backup-cli/internal/orchestration/scale" + batchv1 "k8s.io/api/batch/v1" +) + +// Check and finalize command flags +var ( + checkJobName string + waitForJob bool +) + +func checkAndFinalizeCmd(globalFlags *config.CLIGlobalFlags) *cobra.Command { + cmd := &cobra.Command{ + Use: "check-and-finalize", + Short: "Check and finalize a Stackgraph restore job", + Long: `Check the status of a background Stackgraph restore job and clean up resources. + +This command is useful when a restore job was started with --background flag or was interrupted (Ctrl+C). +It will check the job status, print logs if it failed, and clean up the job and PVC resources. 
+ +Examples: + # Check job status without waiting + sts-backup stackgraph check-and-finalize --job stackgraph-restore-20250128t143000 -n my-namespace + + # Wait for job completion and cleanup + sts-backup stackgraph check-and-finalize --job stackgraph-restore-20250128t143000 --wait -n my-namespace`, + Run: func(_ *cobra.Command, _ []string) { + appCtx, err := app.NewContext(globalFlags) + if err != nil { + _, _ = fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + if err := runCheckAndFinalize(appCtx); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + }, + } + + cmd.Flags().StringVarP(&checkJobName, "job", "j", "", "Stackgraph restore job name (required)") + cmd.Flags().BoolVarP(&waitForJob, "wait", "w", false, "Wait for job to complete before cleanup") + _ = cmd.MarkFlagRequired("job") + + return cmd +} + +func runCheckAndFinalize(appCtx *app.Context) error { + // Get job + appCtx.Logger.Infof("Checking status of job: %s", checkJobName) + job, err := appCtx.K8sClient.GetJob(appCtx.Namespace, checkJobName) + if err != nil { + return fmt.Errorf("failed to get job '%s': %w (job may not exist or has been deleted)", checkJobName, err) + } + + // Check if job is already complete + completed, succeeded := isJobComplete(job) + + if completed { + // Job already finished - print status and cleanup + return handleCompletedJob(appCtx, checkJobName, succeeded) + } + + // Job still running + if waitForJob { + // Wait for completion, then cleanup + return waitAndFinalize(appCtx, checkJobName) + } + + // Not waiting - just print status + printRunningJobStatus(appCtx.Logger, checkJobName, appCtx.Namespace, job.Status.Active) + return nil +} + +// isJobComplete checks if job is in a terminal state +func isJobComplete(job *batchv1.Job) (completed bool, succeeded bool) { + if job.Status.Succeeded > 0 { + return true, true + } + if job.Status.Failed > 0 { + return true, false + } + return false, false +} + +// handleCompletedJob handles a job that's already complete +func handleCompletedJob(appCtx *app.Context, jobName string, succeeded bool) error { + appCtx.Logger.Println() + if succeeded { + appCtx.Logger.Successf("Job completed successfully: %s", jobName) + appCtx.Logger.Println() + + // Scale up deployments that were scaled down before restore + scaleDownLabelSelector := appCtx.Config.Stackgraph.Restore.ScaleDownLabelSelector + if err := scale.ScaleUpFromAnnotations(appCtx.K8sClient, appCtx.Namespace, scaleDownLabelSelector, appCtx.Logger); err != nil { + appCtx.Logger.Warningf("Failed to scale up deployments: %v", err) + } + } else { + appCtx.Logger.Errorf("Job failed: %s", jobName) + appCtx.Logger.Println() + appCtx.Logger.Infof("Fetching logs...") + appCtx.Logger.Println() + if err := printJobLogs(appCtx.K8sClient, appCtx.Namespace, jobName, appCtx.Logger); err != nil { + appCtx.Logger.Warningf("Failed to fetch logs: %v", err) + } + } + + // Cleanup resources + appCtx.Logger.Println() + return cleanupRestoreResources(appCtx.K8sClient, appCtx.Namespace, jobName, appCtx.Logger) +} + +// waitAndFinalize waits for job completion and then cleans up +func waitAndFinalize(appCtx *app.Context, jobName string) error { + printWaitingMessage(appCtx.Logger, jobName, appCtx.Namespace) + + if err := waitForJobCompletion(appCtx.K8sClient, appCtx.Namespace, jobName, appCtx.Logger); err != nil { + appCtx.Logger.Errorf("Job failed: %v", err) + // Still cleanup even if failed + appCtx.Logger.Println() + _ = cleanupRestoreResources(appCtx.K8sClient, appCtx.Namespace, jobName, 
appCtx.Logger) + return err + } + + appCtx.Logger.Println() + appCtx.Logger.Successf("Job completed successfully: %s", jobName) + appCtx.Logger.Println() + + // Scale up deployments that were scaled down before restore + scaleDownLabelSelector := appCtx.Config.Stackgraph.Restore.ScaleDownLabelSelector + if err := scale.ScaleUpFromAnnotations(appCtx.K8sClient, appCtx.Namespace, scaleDownLabelSelector, appCtx.Logger); err != nil { + appCtx.Logger.Warningf("Failed to scale up deployments: %v", err) + } + + appCtx.Logger.Println() + return cleanupRestoreResources(appCtx.K8sClient, appCtx.Namespace, jobName, appCtx.Logger) +} diff --git a/cmd/stackgraph/list.go b/cmd/stackgraph/list.go new file mode 100644 index 0000000..79faf7f --- /dev/null +++ b/cmd/stackgraph/list.go @@ -0,0 +1,109 @@ +package stackgraph + +import ( + "context" + "fmt" + "os" + "sort" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/spf13/cobra" + "github.com/stackvista/stackstate-backup-cli/internal/app" + s3client "github.com/stackvista/stackstate-backup-cli/internal/clients/s3" + "github.com/stackvista/stackstate-backup-cli/internal/foundation/config" + "github.com/stackvista/stackstate-backup-cli/internal/foundation/output" + "github.com/stackvista/stackstate-backup-cli/internal/orchestration/portforward" +) + +func listCmd(globalFlags *config.CLIGlobalFlags) *cobra.Command { + return &cobra.Command{ + Use: "list", + Short: "List available Stackgraph backups from S3/Minio", + Run: func(_ *cobra.Command, _ []string) { + appCtx, err := app.NewContext(globalFlags) + if err != nil { + _, _ = fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + if err := runList(appCtx); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + }, + } +} + +func runList(appCtx *app.Context) error { + // Setup port-forward to Minio + serviceName := appCtx.Config.Minio.Service.Name + localPort := appCtx.Config.Minio.Service.LocalPortForwardPort + remotePort := appCtx.Config.Minio.Service.Port + + pf, err := portforward.SetupPortForward(appCtx.K8sClient, appCtx.Namespace, serviceName, localPort, remotePort, appCtx.Logger) + if err != nil { + return err + } + defer close(pf.StopChan) + + // List objects in bucket + bucket := appCtx.Config.Stackgraph.Bucket + prefix := appCtx.Config.Stackgraph.S3Prefix + multipartArchive := appCtx.Config.Stackgraph.MultipartArchive + + appCtx.Logger.Infof("Listing Stackgraph backups in bucket '%s'...", bucket) + + input := &s3.ListObjectsV2Input{ + Bucket: aws.String(bucket), + Prefix: aws.String(prefix), + } + + result, err := appCtx.S3Client.ListObjectsV2(context.Background(), input) + if err != nil { + return fmt.Errorf("failed to list S3 objects: %w", err) + } + + // Filter objects based on whether the archive is split or not + filteredObjects := s3client.FilterBackupObjects(result.Contents, multipartArchive) + + // Sort by LastModified time (most recent first) + sort.Slice(filteredObjects, func(i, j int) bool { + return filteredObjects[i].LastModified.After(filteredObjects[j].LastModified) + }) + + if len(filteredObjects) == 0 { + appCtx.Formatter.PrintMessage("No backups found") + return nil + } + + table := output.Table{ + Headers: []string{"NAME", "LAST MODIFIED", "SIZE"}, + Rows: make([][]string, 0, len(filteredObjects)), + } + + for _, obj := range filteredObjects { + row := []string{ + obj.Key, + obj.LastModified.Format("2006-01-02 15:04:05 MST"), + formatBytes(obj.Size), + } + table.Rows = append(table.Rows, row) + } + 
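+	// Render the backup listing using the configured output formatter.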
+ return appCtx.Formatter.PrintTable(table) +} + +// formatBytes formats bytes to human-readable format without spaces (e.g., "624MiB") +func formatBytes(bytes int64) string { + const unit = 1024 + if bytes < unit { + return fmt.Sprintf("%dB", bytes) + } + div, exp := int64(unit), 0 + for n := bytes / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + units := []string{"KiB", "MiB", "GiB", "TiB", "PiB"} + return fmt.Sprintf("%.0f%s", float64(bytes)/float64(div), units[exp]) +} diff --git a/cmd/stackgraph/restore.go b/cmd/stackgraph/restore.go new file mode 100644 index 0000000..9a301c5 --- /dev/null +++ b/cmd/stackgraph/restore.go @@ -0,0 +1,543 @@ +package stackgraph + +import ( + "bufio" + "context" + "fmt" + "os" + "sort" + "strconv" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/spf13/cobra" + "github.com/stackvista/stackstate-backup-cli/internal/app" + "github.com/stackvista/stackstate-backup-cli/internal/clients/k8s" + s3client "github.com/stackvista/stackstate-backup-cli/internal/clients/s3" + "github.com/stackvista/stackstate-backup-cli/internal/foundation/config" + "github.com/stackvista/stackstate-backup-cli/internal/foundation/logger" + "github.com/stackvista/stackstate-backup-cli/internal/orchestration/portforward" + "github.com/stackvista/stackstate-backup-cli/internal/orchestration/scale" + "github.com/stackvista/stackstate-backup-cli/internal/scripts" + corev1 "k8s.io/api/core/v1" +) + +const ( + jobNameTemplate = "stackgraph-restore" + minioKeysSecretName = "suse-observability-backup-cli-minio-keys" //nolint:gosec // This is a Kubernetes secret name, not a credential + restoreScriptsConfigMap = "suse-observability-backup-cli-restore-scripts" + defaultJobCompletionTimeout = 30 * time.Minute + defaultJobStatusCheckInterval = 10 * time.Second + configMapDefaultFileMode = 0755 + purgeStackgraphDataFlag = "-force" +) + +// Restore command flags +var ( + archiveName string + useLatest bool + background bool + skipConfirmation bool +) + +func restoreCmd(globalFlags *config.CLIGlobalFlags) *cobra.Command { + cmd := &cobra.Command{ + Use: "restore", + Short: "Restore Stackgraph from a backup archive", + Long: `Restore Stackgraph data from a backup archive stored in S3/Minio. 
Can use --latest or --archive to specify which backup to restore.`, + Run: func(_ *cobra.Command, _ []string) { + appCtx, err := app.NewContext(globalFlags) + if err != nil { + _, _ = fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + if err := runRestore(appCtx); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + }, + } + + cmd.Flags().StringVar(&archiveName, "archive", "", "Specific archive name to restore (e.g., sts-backup-20210216-0300.graph)") + cmd.Flags().BoolVar(&useLatest, "latest", false, "Restore from the most recent backup") + cmd.Flags().BoolVar(&background, "background", false, "Run restore job in background without waiting for completion") + cmd.Flags().BoolVarP(&skipConfirmation, "yes", "y", false, "Skip confirmation prompt") + cmd.MarkFlagsMutuallyExclusive("archive", "latest") + cmd.MarkFlagsOneRequired("archive", "latest") + + return cmd +} + +func runRestore(appCtx *app.Context) error { + // Determine which archive to restore + backupFile := archiveName + if useLatest { + appCtx.Logger.Infof("Finding latest backup...") + latest, err := getLatestBackup(appCtx.K8sClient, appCtx.Namespace, appCtx.Config, appCtx.Logger) + if err != nil { + return err + } + backupFile = latest + appCtx.Logger.Infof("Using latest backup: %s", backupFile) + } + + // Warn user and ask for confirmation + if !skipConfirmation { + appCtx.Logger.Println() + appCtx.Logger.Warningf("WARNING: Restoring from backup will PURGE all existing Stackgraph data!") + appCtx.Logger.Warningf("This operation cannot be undone.") + appCtx.Logger.Println() + appCtx.Logger.Infof("Backup to restore: %s", backupFile) + appCtx.Logger.Infof("Namespace: %s", appCtx.Namespace) + appCtx.Logger.Println() + + if !promptForConfirmation() { + return fmt.Errorf("restore operation cancelled by user") + } + } + + // Scale down deployments before restore + appCtx.Logger.Println() + scaleDownLabelSelector := appCtx.Config.Stackgraph.Restore.ScaleDownLabelSelector + scaledDeployments, err := scale.ScaleDown(appCtx.K8sClient, appCtx.Namespace, scaleDownLabelSelector, appCtx.Logger) + if err != nil { + return err + } + + // Ensure deployments are scaled back up on exit (even if restore fails) + defer func() { + if len(scaledDeployments) > 0 && !background { + appCtx.Logger.Println() + if err := scale.ScaleUpFromAnnotations(appCtx.K8sClient, appCtx.Namespace, scaleDownLabelSelector, appCtx.Logger); err != nil { + appCtx.Logger.Warningf("Failed to scale up deployments: %v", err) + } + } + }() + + // Setup Kubernetes resources for restore job + appCtx.Logger.Println() + if err := ensureRestoreResources(appCtx.K8sClient, appCtx.Namespace, appCtx.Config, appCtx.Logger); err != nil { + return err + } + + // Create restore job + appCtx.Logger.Println() + appCtx.Logger.Infof("Creating restore job for backup: %s", backupFile) + + jobName := fmt.Sprintf("%s-%s", jobNameTemplate, time.Now().Format("20060102t150405")) + + if err = createRestoreJob(appCtx.K8sClient, appCtx.Namespace, jobName, backupFile, appCtx.Config); err != nil { + return fmt.Errorf("failed to create restore job: %w", err) + } + + appCtx.Logger.Successf("Restore job created: %s", jobName) + + if background { + printRunningJobStatus(appCtx.Logger, jobName, appCtx.Namespace, 0) + return nil + } + + return waitAndCleanupRestoreJob(appCtx.K8sClient, appCtx.Namespace, jobName, appCtx.Logger) +} + +// ensureRestoreResources ensures that required Kubernetes resources exist for the restore job +func ensureRestoreResources(k8sClient 
*k8s.Client, namespace string, config *config.Config, log *logger.Logger) error { + // Ensure backup scripts ConfigMap exists + log.Infof("Ensuring backup scripts ConfigMap exists...") + + scriptNames, err := scripts.ListScripts() + if err != nil { + return fmt.Errorf("failed to list embedded scripts: %w", err) + } + + scriptsData := make(map[string]string) + for _, scriptName := range scriptNames { + scriptContent, err := scripts.GetScript(scriptName) + if err != nil { + return fmt.Errorf("failed to get script %s: %w", scriptName, err) + } + scriptsData[scriptName] = string(scriptContent) + } + + configMapLabels := k8s.MergeLabels(config.Kubernetes.CommonLabels, map[string]string{}) + if _, err := k8sClient.EnsureConfigMap(namespace, restoreScriptsConfigMap, scriptsData, configMapLabels); err != nil { + return fmt.Errorf("failed to ensure backup scripts ConfigMap: %w", err) + } + log.Successf("Backup scripts ConfigMap ready") + + // Ensure Minio keys secret exists + log.Infof("Ensuring Minio keys secret exists...") + + secretData := map[string][]byte{ + "accesskey": []byte(config.Minio.AccessKey), + "secretkey": []byte(config.Minio.SecretKey), + } + + secretLabels := k8s.MergeLabels(config.Kubernetes.CommonLabels, map[string]string{}) + if _, err := k8sClient.EnsureSecret(namespace, minioKeysSecretName, secretData, secretLabels); err != nil { + return fmt.Errorf("failed to ensure Minio keys secret: %w", err) + } + log.Successf("Minio keys secret ready") + + return nil +} + +// printWaitingMessage prints a waiting message with instructions for interruption +func printWaitingMessage(log *logger.Logger, jobName, namespace string) { + log.Println() + log.Infof("Waiting for restore job to complete (this may take several minutes)...") + log.Println() + log.Infof("You can safely interrupt this command with Ctrl+C.") + log.Infof("To check status, scale up the required deployments, and clean up later, run:") + log.Infof(" sts-backup stackgraph check-and-finalize --job %s --wait -n %s", jobName, namespace) +} + +// printRunningJobStatus prints status and instructions for a running job +func printRunningJobStatus(log *logger.Logger, jobName, namespace string, activePods int32) { + log.Println() + log.Infof("Job is running in background: %s", jobName) + if activePods > 0 { + log.Infof(" Active pods: %d", activePods) + } + log.Println() + log.Infof("Monitoring commands:") + log.Infof(" kubectl logs --follow job/%s -n %s", jobName, namespace) + log.Infof(" kubectl get job %s -n %s", jobName, namespace) + log.Println() + log.Infof("To wait for completion, scale up the necessary deployments, and clean up, run:") + log.Infof(" sts-backup stackgraph check-and-finalize --job %s --wait -n %s", jobName, namespace) +} + +// cleanupRestoreResources cleans up job and PVC resources +func cleanupRestoreResources(k8sClient *k8s.Client, namespace, jobName string, log *logger.Logger) error { + log.Infof("Cleaning up job and PVC...") + + // Delete job + if err := k8sClient.DeleteJob(namespace, jobName); err != nil { + log.Warningf("Failed to delete job: %v", err) + } else { + log.Successf("Job deleted: %s", jobName) + } + + // Delete PVC (same name as job) + if err := k8sClient.DeletePVC(namespace, jobName); err != nil { + log.Warningf("Failed to delete PVC: %v", err) + } else { + log.Successf("PVC deleted: %s", jobName) + } + + return nil +} + +// waitAndCleanupRestoreJob waits for job completion and cleans up resources +func waitAndCleanupRestoreJob(k8sClient *k8s.Client, namespace, jobName string, log *logger.Logger)
error { + printWaitingMessage(log, jobName, namespace) + + if err := waitForJobCompletion(k8sClient, namespace, jobName, log); err != nil { + log.Errorf("Job failed: %v", err) + log.Println() + log.Infof("Cleanup commands:") + log.Infof(" kubectl delete job,pvc %s -n %s", jobName, namespace) + return err + } + + log.Println() + log.Successf("Restore completed successfully") + + // Cleanup job and PVC using shared function + log.Println() + return cleanupRestoreResources(k8sClient, namespace, jobName, log) +} + +// getLatestBackup retrieves the most recent backup from S3 +func getLatestBackup(k8sClient *k8s.Client, namespace string, config *config.Config, log *logger.Logger) (string, error) { + // Setup port-forward to Minio + serviceName := config.Minio.Service.Name + localPort := config.Minio.Service.LocalPortForwardPort + remotePort := config.Minio.Service.Port + + pf, err := portforward.SetupPortForward(k8sClient, namespace, serviceName, localPort, remotePort, log) + if err != nil { + return "", err + } + defer close(pf.StopChan) + + // Create S3 client + endpoint := fmt.Sprintf("http://localhost:%d", pf.LocalPort) + s3Client, err := s3client.NewClient(endpoint, config.Minio.AccessKey, config.Minio.SecretKey) + if err != nil { + return "", err + } + + // List objects in bucket + bucket := config.Stackgraph.Bucket + prefix := config.Stackgraph.S3Prefix + multipartArchive := config.Stackgraph.MultipartArchive + + input := &s3.ListObjectsV2Input{ + Bucket: aws.String(bucket), + Prefix: aws.String(prefix), + } + + result, err := s3Client.ListObjectsV2(context.Background(), input) + if err != nil { + return "", fmt.Errorf("failed to list S3 objects: %w", err) + } + + // Filter objects based on whether the archive is split or not + filteredObjects := s3client.FilterBackupObjects(result.Contents, multipartArchive) + + if len(filteredObjects) == 0 { + return "", fmt.Errorf("no backups found in bucket %s", bucket) + } + + // Sort by LastModified time (most recent first) + sort.Slice(filteredObjects, func(i, j int) bool { + return filteredObjects[i].LastModified.After(filteredObjects[j].LastModified) + }) + + return filteredObjects[0].Key, nil +} + +// buildPVCSpec builds a PVCSpec from configuration +func buildPVCSpec(name string, config *config.Config, labels map[string]string) k8s.PVCSpec { + pvcConfig := config.Stackgraph.Restore.PVC + + // Convert string access modes to k8s types + accessModes := []corev1.PersistentVolumeAccessMode{corev1.ReadWriteOnce} // default + if len(pvcConfig.AccessModes) > 0 { + accessModes = make([]corev1.PersistentVolumeAccessMode, 0, len(pvcConfig.AccessModes)) + for _, mode := range pvcConfig.AccessModes { + accessModes = append(accessModes, corev1.PersistentVolumeAccessMode(mode)) + } + } + + // Handle storage class (nil if not set) + var storageClass *string + if pvcConfig.StorageClassName != "" { + storageClass = &pvcConfig.StorageClassName + } + + return k8s.PVCSpec{ + Name: name, + Labels: labels, + StorageSize: pvcConfig.Size, + AccessModes: accessModes, + StorageClass: storageClass, + } +} + +// createRestoreJob creates a Kubernetes Job and PVC for restoring from backup +func createRestoreJob(k8sClient *k8s.Client, namespace, jobName, backupFile string, config *config.Config) error { + defaultMode := int32(configMapDefaultFileMode) + + // Merge common labels with resource-specific labels + pvcLabels := k8s.MergeLabels(config.Kubernetes.CommonLabels, map[string]string{}) + jobLabels := k8s.MergeLabels(config.Kubernetes.CommonLabels, 
config.Stackgraph.Restore.Job.Labels) + + // Create PVC first + pvcSpec := buildPVCSpec(jobName, config, pvcLabels) + pvc, err := k8sClient.CreatePVC(namespace, pvcSpec) + if err != nil { + return fmt.Errorf("failed to create PVC: %w", err) + } + + // Build job spec using configuration + spec := k8s.BackupJobSpec{ + Name: jobName, + Labels: jobLabels, + ImagePullSecrets: k8s.ConvertImagePullSecrets(config.Stackgraph.Restore.Job.ImagePullSecrets), + SecurityContext: k8s.ConvertPodSecurityContext(&config.Stackgraph.Restore.Job.SecurityContext), + NodeSelector: config.Stackgraph.Restore.Job.NodeSelector, + Tolerations: k8s.ConvertTolerations(config.Stackgraph.Restore.Job.Tolerations), + Affinity: k8s.ConvertAffinity(config.Stackgraph.Restore.Job.Affinity), + ContainerSecurityContext: k8s.ConvertSecurityContext(config.Stackgraph.Restore.Job.ContainerSecurityContext), + Image: config.Stackgraph.Restore.Job.Image, + Command: []string{"/backup-restore-scripts/restore-stackgraph-backup.sh"}, + Env: buildRestoreEnvVars(backupFile, config), + Resources: k8s.ConvertResources(config.Stackgraph.Restore.Job.Resources), + VolumeMounts: buildRestoreVolumeMounts(), + InitContainers: buildRestoreInitContainers(config), + Volumes: buildRestoreVolumes(jobName, config, defaultMode), + } + + // Create job + _, err = k8sClient.CreateBackupJob(namespace, spec) + if err != nil { + // Cleanup PVC if job creation fails + _ = k8sClient.DeletePVC(namespace, pvc.Name) + return fmt.Errorf("failed to create job: %w", err) + } + + return nil +} + +// buildRestoreEnvVars constructs environment variables for the restore job +func buildRestoreEnvVars(backupFile string, config *config.Config) []corev1.EnvVar { + return []corev1.EnvVar{ + {Name: "BACKUP_FILE", Value: backupFile}, + {Name: "FORCE_DELETE", Value: purgeStackgraphDataFlag}, + {Name: "BACKUP_STACKGRAPH_BUCKET_NAME", Value: config.Stackgraph.Bucket}, + {Name: "BACKUP_STACKGRAPH_S3_PREFIX", Value: config.Stackgraph.S3Prefix}, + {Name: "BACKUP_STACKGRAPH_MULTIPART_ARCHIVE", Value: strconv.FormatBool(config.Stackgraph.MultipartArchive)}, + {Name: "MINIO_ENDPOINT", Value: fmt.Sprintf("%s:%d", config.Minio.Service.Name, config.Minio.Service.Port)}, + {Name: "ZOOKEEPER_QUORUM", Value: config.Stackgraph.Restore.ZookeeperQuorum}, + } +} + +// buildRestoreVolumeMounts constructs volume mounts for the restore job container +func buildRestoreVolumeMounts() []corev1.VolumeMount { + return []corev1.VolumeMount{ + {Name: "backup-log", MountPath: "/opt/docker/etc_log"}, + {Name: "backup-restore-scripts", MountPath: "/backup-restore-scripts"}, + {Name: "minio-keys", MountPath: "/aws-keys"}, + {Name: "tmp-data", MountPath: "/tmp-data"}, + } +} + +// buildRestoreInitContainers constructs init containers for the restore job +func buildRestoreInitContainers(config *config.Config) []corev1.Container { + return []corev1.Container{ + { + Name: "wait", + Image: config.Stackgraph.Restore.Job.WaitImage, + ImagePullPolicy: corev1.PullIfNotPresent, + Command: []string{ + "sh", + "-c", + fmt.Sprintf("/entrypoint -c %s:%d -t 300", config.Minio.Service.Name, config.Minio.Service.Port), + }, + }, + } +} + +// buildRestoreVolumes constructs volumes for the restore job pod +func buildRestoreVolumes(jobName string, config *config.Config, defaultMode int32) []corev1.Volume { + return []corev1.Volume{ + { + Name: "backup-log", + VolumeSource: corev1.VolumeSource{ + ConfigMap: &corev1.ConfigMapVolumeSource{ + LocalObjectReference: corev1.LocalObjectReference{ + Name: 
config.Stackgraph.Restore.LoggingConfigConfigMapName, + }, + }, + }, + }, + { + Name: "backup-restore-scripts", + VolumeSource: corev1.VolumeSource{ + ConfigMap: &corev1.ConfigMapVolumeSource{ + LocalObjectReference: corev1.LocalObjectReference{ + Name: restoreScriptsConfigMap, + }, + DefaultMode: &defaultMode, + }, + }, + }, + { + Name: "minio-keys", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: minioKeysSecretName, + }, + }, + }, + { + Name: "tmp-data", + VolumeSource: corev1.VolumeSource{ + PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{ + ClaimName: jobName, + }, + }, + }, + } +} + +// waitForJobCompletion waits for a Kubernetes job to complete +func waitForJobCompletion(k8sClient *k8s.Client, namespace, jobName string, log *logger.Logger) error { + timeout := time.After(defaultJobCompletionTimeout) + ticker := time.NewTicker(defaultJobStatusCheckInterval) + defer ticker.Stop() + + for { + select { + case <-timeout: + return fmt.Errorf("timeout waiting for job to complete") + case <-ticker.C: + job, err := k8sClient.GetJob(namespace, jobName) + if err != nil { + return fmt.Errorf("failed to get job status: %w", err) + } + + if job.Status.Succeeded > 0 { + return nil + } + + if job.Status.Failed > 0 { + // Get and print logs from failed job + log.Println() + log.Errorf("Job failed. Fetching logs...") + log.Println() + if err := printJobLogs(k8sClient, namespace, jobName, log); err != nil { + log.Warningf("Failed to fetch job logs: %v", err) + } + return fmt.Errorf("job failed") + } + + log.Debugf("Job status: Active=%d, Succeeded=%d, Failed=%d", + job.Status.Active, job.Status.Succeeded, job.Status.Failed) + } + } +} + +// printJobLogs retrieves and prints logs from all containers in a job's pods +func printJobLogs(k8sClient *k8s.Client, namespace, jobName string, log *logger.Logger) error { + // Get logs from all pods in the job + allPodLogs, err := k8sClient.GetJobLogs(namespace, jobName) + if err != nil { + return err + } + + // Print logs from each pod + for _, podLogs := range allPodLogs { + log.Infof("=== Logs from pod: %s ===", podLogs.PodName) + log.Println() + + // Print logs from each container + for _, containerLog := range podLogs.ContainerLogs { + containerType := "container" + if containerLog.IsInit { + containerType = "init container" + } + + log.Infof("--- Logs from %s: %s ---", containerType, containerLog.Name) + + // Print the actual logs + if containerLog.Logs != "" { + fmt.Println(containerLog.Logs) + } else { + log.Infof("(no logs)") + } + log.Println() + } + } + + return nil +} + +// promptForConfirmation prompts the user for confirmation and returns true if they confirm +func promptForConfirmation() bool { + reader := bufio.NewReader(os.Stdin) + fmt.Print("Do you want to continue? 
(yes/no): ") + + response, err := reader.ReadString('\n') + if err != nil { + return false + } + + response = strings.TrimSpace(strings.ToLower(response)) + return response == "yes" || response == "y" +} diff --git a/cmd/stackgraph/stackgraph.go b/cmd/stackgraph/stackgraph.go new file mode 100644 index 0000000..76702f5 --- /dev/null +++ b/cmd/stackgraph/stackgraph.go @@ -0,0 +1,19 @@ +package stackgraph + +import ( + "github.com/spf13/cobra" + "github.com/stackvista/stackstate-backup-cli/internal/foundation/config" +) + +func Cmd(globalFlags *config.CLIGlobalFlags) *cobra.Command { + cmd := &cobra.Command{ + Use: "stackgraph", + Short: "Stackgraph backup and restore operations", + } + + cmd.AddCommand(listCmd(globalFlags)) + cmd.AddCommand(restoreCmd(globalFlags)) + cmd.AddCommand(checkAndFinalizeCmd(globalFlags)) + + return cmd +} diff --git a/go.mod b/go.mod index f66215c..5f12358 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,10 @@ go 1.25.2 require ( dario.cat/mergo v1.0.2 + github.com/aws/aws-sdk-go-v2 v1.39.3 + github.com/aws/aws-sdk-go-v2/config v1.31.13 + github.com/aws/aws-sdk-go-v2/credentials v1.18.17 + github.com/aws/aws-sdk-go-v2/service/s3 v1.88.5 github.com/elastic/go-elasticsearch/v8 v8.19.0 github.com/go-playground/validator/v10 v10.28.0 github.com/spf13/cobra v1.10.1 @@ -15,6 +19,20 @@ require ( ) require ( + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.2 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.10 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.10 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.29.7 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.38.7 // indirect + github.com/aws/smithy-go v1.23.1 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/elastic/elastic-transport-go/v8 v8.7.0 // indirect github.com/emicklei/go-restful/v3 v3.12.2 // indirect diff --git a/go.sum b/go.sum index 6447f57..cfd0103 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,42 @@ dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= +github.com/aws/aws-sdk-go-v2 v1.39.3 h1:h7xSsanJ4EQJXG5iuW4UqgP7qBopLpj84mpkNx3wPjM= +github.com/aws/aws-sdk-go-v2 v1.39.3/go.mod h1:yWSxrnioGUZ4WVv9TgMrNUeLV3PFESn/v+6T/Su8gnM= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.2 h1:t9yYsydLYNBk9cJ73rgPhPWqOh/52fcWDQB5b1JsKSY= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.2/go.mod h1:IusfVNTmiSN3t4rhxWFaBAqn+mcNdwKtPcV16eYdgko= +github.com/aws/aws-sdk-go-v2/config v1.31.13 h1:wcqQB3B0PgRPUF5ZE/QL1JVOyB0mbPevHFoAMpemR9k= +github.com/aws/aws-sdk-go-v2/config v1.31.13/go.mod 
h1:ySB5D5ybwqGbT6c3GszZ+u+3KvrlYCUQNo62+hkKOFk= +github.com/aws/aws-sdk-go-v2/credentials v1.18.17 h1:skpEwzN/+H8cdrrtT8y+rvWJGiWWv0DeNAe+4VTf+Vs= +github.com/aws/aws-sdk-go-v2/credentials v1.18.17/go.mod h1:Ed+nXsaYa5uBINovJhcAWkALvXw2ZLk36opcuiSZfJM= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10 h1:UuGVOX48oP4vgQ36oiKmW9RuSeT8jlgQgBFQD+HUiHY= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10/go.mod h1:vM/Ini41PzvudT4YkQyE/+WiQJiQ6jzeDyU8pQKwCac= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.10 h1:mj/bdWleWEh81DtpdHKkw41IrS+r3uw1J/VQtbwYYp8= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.10/go.mod h1:7+oEMxAZWP8gZCyjcm9VicI0M61Sx4DJtcGfKYv2yKQ= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.10 h1:wh+/mn57yhUrFtLIxyFPh2RgxgQz/u+Yrf7hiHGHqKY= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.10/go.mod h1:7zirD+ryp5gitJJ2m1BBux56ai8RIRDykXZrJSp540w= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.10 h1:FHw90xCTsofzk6vjU808TSuDtDfOOKPNdz5Weyc3tUI= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.10/go.mod h1:n8jdIE/8F3UYkg8O4IGkQpn2qUmapg/1K1yl29/uf/c= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2 h1:xtuxji5CS0JknaXoACOunXOYOQzgfTvGAc9s2QdCJA4= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2/go.mod h1:zxwi0DIR0rcRcgdbl7E2MSOvxDyyXGBlScvBkARFaLQ= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.1 h1:ne+eepnDB2Wh5lHKzELgEncIqeVlQ1rSF9fEa4r5I+A= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.1/go.mod h1:u0Jkg0L+dcG1ozUq21uFElmpbmjBnhHR5DELHIme4wg= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10 h1:DRND0dkCKtJzCj4Xl4OpVbXZgfttY5q712H9Zj7qc/0= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10/go.mod h1:tGGNmJKOTernmR2+VJ0fCzQRurcPZj9ut60Zu5Fi6us= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.10 h1:DA+Hl5adieRyFvE7pCvBWm3VOZTRexGVkXw33SUqNoY= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.10/go.mod h1:L+A89dH3/gr8L4ecrdzuXUYd1znoko6myzndVGZx/DA= +github.com/aws/aws-sdk-go-v2/service/s3 v1.88.5 h1:FlGScxzCGNzT+2AvHT1ZGMvxTwAMa6gsooFb1pO/AiM= +github.com/aws/aws-sdk-go-v2/service/s3 v1.88.5/go.mod h1:N/iojY+8bW3MYol9NUMuKimpSbPEur75cuI1SmtonFM= +github.com/aws/aws-sdk-go-v2/service/sso v1.29.7 h1:fspVFg6qMx0svs40YgRmE7LZXh9VRZvTT35PfdQR6FM= +github.com/aws/aws-sdk-go-v2/service/sso v1.29.7/go.mod h1:BQTKL3uMECaLaUV3Zc2L4Qybv8C6BIXjuu1dOPyxTQs= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2 h1:scVnW+NLXasGOhy7HhkdT9AGb6kjgW7fJ5xYkUaqHs0= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2/go.mod h1:FRNCY3zTEWZXBKm2h5UBUPvCVDOecTad9KhynDyGBc0= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.7 h1:VEO5dqFkMsl8QZ2yHsFDJAIZLAkEbaYDB+xdKi0Feic= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.7/go.mod h1:L1xxV3zAdB+qVrVW/pBIrIAnHFWHo6FBbFe4xOGsG/o= +github.com/aws/smithy-go v1.23.1 h1:sLvcH6dfAFwGkHLZ7dGiYF7aK6mg4CgKA/iDKjLDt9M= +github.com/aws/smithy-go v1.23.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git 
a/internal/app/app.go b/internal/app/app.go new file mode 100644 index 0000000..1d746ce --- /dev/null +++ b/internal/app/app.go @@ -0,0 +1,64 @@ +package app + +import ( + "fmt" + "os" + + "github.com/stackvista/stackstate-backup-cli/internal/clients/elasticsearch" + "github.com/stackvista/stackstate-backup-cli/internal/clients/k8s" + "github.com/stackvista/stackstate-backup-cli/internal/clients/s3" + "github.com/stackvista/stackstate-backup-cli/internal/foundation/config" + "github.com/stackvista/stackstate-backup-cli/internal/foundation/logger" + "github.com/stackvista/stackstate-backup-cli/internal/foundation/output" +) + +// Context holds all dependencies for CLI commands +type Context struct { + K8sClient *k8s.Client + Namespace string + S3Client s3.Interface + ESClient elasticsearch.Interface + Config *config.Config + Logger *logger.Logger + Formatter *output.Formatter +} + +// NewContext creates production dependencies +func NewContext(flags *config.CLIGlobalFlags) (*Context, error) { + k8sClient, err := k8s.NewClient(flags.Kubeconfig, flags.Debug) + if err != nil { + return nil, fmt.Errorf("failed to create Kubernetes client: %w", err) + } + + // Load configuration + cfg, err := config.LoadConfig(k8sClient.Clientset(), flags.Namespace, flags.ConfigMapName, flags.SecretName) + if err != nil { + return nil, fmt.Errorf("failed to load configuration: %w", err) + } + + // Create S3 client + endpoint := fmt.Sprintf("http://localhost:%d", cfg.Minio.Service.LocalPortForwardPort) + s3Client, err := s3.NewClient(endpoint, cfg.Minio.AccessKey, cfg.Minio.SecretKey) + if err != nil { + return nil, err + } + + // Create Elasticsearch client + esClient, err := elasticsearch.NewClient(fmt.Sprintf("http://localhost:%d", cfg.Elasticsearch.Service.LocalPortForwardPort)) + if err != nil { + return nil, fmt.Errorf("failed to create Elasticsearch client: %w", err) + } + + // Create output formatter + formatter := output.NewFormatter(os.Stdout, flags.OutputFormat) + + return &Context{ + K8sClient: k8sClient, + Namespace: flags.Namespace, + Config: cfg, + S3Client: s3Client, + ESClient: esClient, + Logger: logger.New(flags.Quiet, flags.Debug), + Formatter: formatter, + }, nil +} diff --git a/internal/elasticsearch/client.go b/internal/clients/elasticsearch/client.go similarity index 100% rename from internal/elasticsearch/client.go rename to internal/clients/elasticsearch/client.go diff --git a/internal/elasticsearch/client_test.go b/internal/clients/elasticsearch/client_test.go similarity index 100% rename from internal/elasticsearch/client_test.go rename to internal/clients/elasticsearch/client_test.go diff --git a/internal/elasticsearch/interface.go b/internal/clients/elasticsearch/interface.go similarity index 100% rename from internal/elasticsearch/interface.go rename to internal/clients/elasticsearch/interface.go diff --git a/internal/k8s/client.go b/internal/clients/k8s/client.go similarity index 71% rename from internal/k8s/client.go rename to internal/clients/k8s/client.go index 6e72ebf..7eaf677 100644 --- a/internal/k8s/client.go +++ b/internal/clients/k8s/client.go @@ -104,18 +104,18 @@ func (c *Client) PortForwardService(namespace, serviceName string, localPort, re func (c *Client) PortForwardPod(namespace, podName string, localPort, remotePort int) (chan struct{}, chan struct{}, error) { path := fmt.Sprintf("/api/v1/namespaces/%s/pods/%s/portforward", namespace, podName) hostIP := c.restConfig.Host - url, err := url.Parse(hostIP) + pfURL, err := url.Parse(hostIP) if err != nil { return nil,
nil, fmt.Errorf("failed to parse host: %w", err) } - url.Path = path + pfURL.Path = path transport, upgrader, err := spdy.RoundTripperFor(c.restConfig) if err != nil { return nil, nil, fmt.Errorf("failed to create round tripper: %w", err) } - dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, http.MethodPost, url) + dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, http.MethodPost, pfURL) stopChan := make(chan struct{}, 1) readyChan := make(chan struct{}) @@ -138,7 +138,7 @@ func (c *Client) PortForwardPod(namespace, podName string, localPort, remotePort go func() { if err := fw.ForwardPorts(); err != nil { if c.debug { - fmt.Fprintf(os.Stderr, "Port forward error: %v\n", err) + _, _ = fmt.Fprintf(os.Stderr, "Port forward error: %v\n", err) } } }() @@ -146,6 +146,11 @@ func (c *Client) PortForwardPod(namespace, podName string, localPort, remotePort return stopChan, readyChan, nil } +const ( + // PreRestoreReplicasAnnotation is the annotation key used to store original replica counts + PreRestoreReplicasAnnotation = "stackstate.com/pre-restore-replicas" +) + // DeploymentScale holds the name and original replica count of a deployment type DeploymentScale struct { Name string @@ -186,6 +191,12 @@ func (c *Client) ScaleDownDeployments(namespace, labelSelector string) ([]Deploy // Scale to 0 if not already at 0 if originalReplicas > 0 { + // Add annotation with original replica count + if deployment.Annotations == nil { + deployment.Annotations = make(map[string]string) + } + deployment.Annotations[PreRestoreReplicasAnnotation] = fmt.Sprintf("%d", originalReplicas) + replicas := int32(0) deployment.Spec.Replicas = &replicas @@ -199,25 +210,60 @@ func (c *Client) ScaleDownDeployments(namespace, labelSelector string) ([]Deploy return scaledDeployments, nil } -// ScaleUpDeployments restores deployments to their original replica counts -func (c *Client) ScaleUpDeployments(namespace string, deploymentScales []DeploymentScale) error { +// ScaleUpDeploymentsFromAnnotations scales up deployments that have the pre-restore-replicas annotation +// Returns a list of deployments that were scaled up with their replica counts +func (c *Client) ScaleUpDeploymentsFromAnnotations(namespace, labelSelector string) ([]DeploymentScale, error) { ctx := context.Background() - for _, scale := range deploymentScales { - deployment, err := c.clientset.AppsV1().Deployments(namespace).Get(ctx, scale.Name, metav1.GetOptions{}) - if err != nil { - return fmt.Errorf("failed to get deployment %s: %w", scale.Name, err) + // List deployments matching the label selector + deployments, err := c.clientset.AppsV1().Deployments(namespace).List(ctx, metav1.ListOptions{ + LabelSelector: labelSelector, + }) + if err != nil { + return nil, fmt.Errorf("failed to list deployments: %w", err) + } + + if len(deployments.Items) == 0 { + return []DeploymentScale{}, nil + } + + var scaledDeployments []DeploymentScale + + // Scale up each deployment that has the annotation + for _, deployment := range deployments.Items { + if deployment.Annotations == nil { + continue + } + + replicasStr, exists := deployment.Annotations[PreRestoreReplicasAnnotation] + if !exists { + continue + } + + var originalReplicas int32 + if _, err := fmt.Sscanf(replicasStr, "%d", &originalReplicas); err != nil { + return scaledDeployments, fmt.Errorf("failed to parse replicas annotation for deployment %s: %w", deployment.Name, err) } - deployment.Spec.Replicas = &scale.Replicas + // Scale up to original replica count + 
deployment.Spec.Replicas = &originalReplicas - _, err = c.clientset.AppsV1().Deployments(namespace).Update(ctx, deployment, metav1.UpdateOptions{}) + // Remove the annotation + delete(deployment.Annotations, PreRestoreReplicasAnnotation) + + _, err := c.clientset.AppsV1().Deployments(namespace).Update(ctx, &deployment, metav1.UpdateOptions{}) if err != nil { - return fmt.Errorf("failed to scale up deployment %s: %w", scale.Name, err) + return scaledDeployments, fmt.Errorf("failed to scale up deployment %s: %w", deployment.Name, err) } + + // Record scaled deployment + scaledDeployments = append(scaledDeployments, DeploymentScale{ + Name: deployment.Name, + Replicas: originalReplicas, + }) } - return nil + return scaledDeployments, nil } // NewTestClient creates a k8s Client for testing with a fake clientset. diff --git a/internal/clients/k8s/client_test.go b/internal/clients/k8s/client_test.go new file mode 100644 index 0000000..f9f66f4 --- /dev/null +++ b/internal/clients/k8s/client_test.go @@ -0,0 +1,464 @@ +package k8s + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + appsv1 "k8s.io/api/apps/v1" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" +) + +func TestClient_ScaleDownDeployments(t *testing.T) { + tests := []struct { + name string + namespace string + labelSelector string + deployments []appsv1.Deployment + expectedScales []DeploymentScale + expectError bool + }{ + { + name: "scale down multiple deployments", + namespace: "test-ns", + labelSelector: "app=test", + deployments: []appsv1.Deployment{ + createDeployment("deploy1", "test-ns", map[string]string{"app": "test"}, 3), + createDeployment("deploy2", "test-ns", map[string]string{"app": "test"}, 5), + }, + expectedScales: []DeploymentScale{ + {Name: "deploy1", Replicas: 3}, + {Name: "deploy2", Replicas: 5}, + }, + expectError: false, + }, + { + name: "scale down deployment with zero replicas", + namespace: "test-ns", + labelSelector: "app=test", + deployments: []appsv1.Deployment{ + createDeployment("deploy1", "test-ns", map[string]string{"app": "test"}, 0), + }, + expectedScales: []DeploymentScale{ + {Name: "deploy1", Replicas: 0}, + }, + expectError: false, + }, + { + name: "no deployments matching selector", + namespace: "test-ns", + labelSelector: "app=nonexistent", + deployments: []appsv1.Deployment{}, + expectedScales: []DeploymentScale{}, + expectError: false, + }, + { + name: "deployments with different labels not selected", + namespace: "test-ns", + labelSelector: "app=test", + deployments: []appsv1.Deployment{ + createDeployment("deploy1", "test-ns", map[string]string{"app": "test"}, 3), + createDeployment("deploy2", "test-ns", map[string]string{"app": "other"}, 2), + }, + expectedScales: []DeploymentScale{ + {Name: "deploy1", Replicas: 3}, + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create fake clientset with test deployments + fakeClient := fake.NewSimpleClientset() + for _, deploy := range tt.deployments { + _, err := fakeClient.AppsV1().Deployments(tt.namespace).Create( + context.Background(), &deploy, metav1.CreateOptions{}, + ) + require.NoError(t, err) + } + + // Create our client wrapper + client := &Client{ + clientset: fakeClient, + } + + // Execute scale down + scales, err := client.ScaleDownDeployments(tt.namespace, tt.labelSelector) + + // Assertions + if tt.expectError { + assert.Error(t, err) + 
return + } + + require.NoError(t, err) + assert.Equal(t, len(tt.expectedScales), len(scales)) + + // Verify each scaled deployment + for i, expectedScale := range tt.expectedScales { + assert.Equal(t, expectedScale.Name, scales[i].Name) + assert.Equal(t, expectedScale.Replicas, scales[i].Replicas) + + // Verify the deployment was actually scaled to 0 + deploy, err := fakeClient.AppsV1().Deployments(tt.namespace).Get( + context.Background(), expectedScale.Name, metav1.GetOptions{}, + ) + require.NoError(t, err) + if expectedScale.Replicas > 0 { + assert.Equal(t, int32(0), *deploy.Spec.Replicas, "deployment should be scaled to 0") + // Verify annotation was added with original replica count + assert.Equal(t, fmt.Sprintf("%d", expectedScale.Replicas), deploy.Annotations[PreRestoreReplicasAnnotation], "annotation should be added with original replica count") + } + } + }) + } +} + +//nolint:funlen // Table-driven test with comprehensive test cases +func TestClient_ScaleUpDeploymentsFromAnnotations(t *testing.T) { + tests := []struct { + name string + namespace string + labelSelector string + deployments []appsv1.Deployment + expectedScales []DeploymentScale + expectError bool + errorContains string + }{ + { + name: "scale up multiple deployments from annotations", + namespace: "test-ns", + labelSelector: "app=test", + deployments: []appsv1.Deployment{ + func() appsv1.Deployment { + d := createDeployment("deploy1", "test-ns", map[string]string{"app": "test"}, 0) + d.Annotations = map[string]string{PreRestoreReplicasAnnotation: "3"} + return d + }(), + func() appsv1.Deployment { + d := createDeployment("deploy2", "test-ns", map[string]string{"app": "test"}, 0) + d.Annotations = map[string]string{PreRestoreReplicasAnnotation: "5"} + return d + }(), + }, + expectedScales: []DeploymentScale{ + {Name: "deploy1", Replicas: 3}, + {Name: "deploy2", Replicas: 5}, + }, + expectError: false, + }, + { + name: "no deployments with annotations", + namespace: "test-ns", + labelSelector: "app=test", + deployments: []appsv1.Deployment{ + createDeployment("deploy1", "test-ns", map[string]string{"app": "test"}, 0), + }, + expectedScales: []DeploymentScale{}, + expectError: false, + }, + { + name: "mixed deployments with and without annotations", + namespace: "test-ns", + labelSelector: "app=test", + deployments: []appsv1.Deployment{ + func() appsv1.Deployment { + d := createDeployment("deploy1", "test-ns", map[string]string{"app": "test"}, 0) + d.Annotations = map[string]string{PreRestoreReplicasAnnotation: "3"} + return d + }(), + createDeployment("deploy2", "test-ns", map[string]string{"app": "test"}, 0), + }, + expectedScales: []DeploymentScale{ + {Name: "deploy1", Replicas: 3}, + }, + expectError: false, + }, + { + name: "invalid annotation value", + namespace: "test-ns", + labelSelector: "app=test", + deployments: []appsv1.Deployment{ + func() appsv1.Deployment { + d := createDeployment("deploy1", "test-ns", map[string]string{"app": "test"}, 0) + d.Annotations = map[string]string{PreRestoreReplicasAnnotation: "invalid"} + return d + }(), + }, + expectedScales: []DeploymentScale{}, + expectError: true, + errorContains: "failed to parse replicas annotation", + }, + { + name: "scale to zero replicas", + namespace: "test-ns", + labelSelector: "app=test", + deployments: []appsv1.Deployment{ + func() appsv1.Deployment { + d := createDeployment("deploy1", "test-ns", map[string]string{"app": "test"}, 0) + d.Annotations = map[string]string{PreRestoreReplicasAnnotation: "0"} + return d + }(), + }, + expectedScales: 
[]DeploymentScale{ + {Name: "deploy1", Replicas: 0}, + }, + expectError: false, + }, + { + name: "no deployments matching selector", + namespace: "test-ns", + labelSelector: "app=test", + deployments: []appsv1.Deployment{}, + expectedScales: []DeploymentScale{}, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create fake clientset with test deployments + fakeClient := fake.NewSimpleClientset() + for _, deploy := range tt.deployments { + _, err := fakeClient.AppsV1().Deployments(tt.namespace).Create( + context.Background(), &deploy, metav1.CreateOptions{}, + ) + require.NoError(t, err) + } + + // Create our client wrapper + client := &Client{ + clientset: fakeClient, + } + + // Execute scale up from annotations + scales, err := client.ScaleUpDeploymentsFromAnnotations(tt.namespace, tt.labelSelector) + + // Assertions + if tt.expectError { + assert.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + return + } + + require.NoError(t, err) + assert.Equal(t, len(tt.expectedScales), len(scales)) + + // Verify each scaled deployment + for i, expectedScale := range tt.expectedScales { + assert.Equal(t, expectedScale.Name, scales[i].Name) + assert.Equal(t, expectedScale.Replicas, scales[i].Replicas) + + // Verify the deployment was actually scaled to expected replicas + deploy, err := fakeClient.AppsV1().Deployments(tt.namespace).Get( + context.Background(), expectedScale.Name, metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, expectedScale.Replicas, *deploy.Spec.Replicas, "deployment should be scaled to expected replicas") + + // Verify annotation was removed + _, exists := deploy.Annotations[PreRestoreReplicasAnnotation] + assert.False(t, exists, "annotation should be removed after scale up") + } + }) + } +} + +func TestClient_ScaleDownThenScaleUpFromAnnotations(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + + // Create deployments with different replica counts + deploy1 := createDeployment("deploy1", "test-ns", map[string]string{"app": "test"}, 3) + deploy2 := createDeployment("deploy2", "test-ns", map[string]string{"app": "test"}, 5) + + _, err := fakeClient.AppsV1().Deployments("test-ns").Create( + context.Background(), &deploy1, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + _, err = fakeClient.AppsV1().Deployments("test-ns").Create( + context.Background(), &deploy2, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + client := &Client{ + clientset: fakeClient, + } + + // Scale down + scaledDown, err := client.ScaleDownDeployments("test-ns", "app=test") + require.NoError(t, err) + assert.Len(t, scaledDown, 2) + + // Verify deployments are scaled to 0 + deploy1After, err := fakeClient.AppsV1().Deployments("test-ns").Get( + context.Background(), "deploy1", metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, int32(0), *deploy1After.Spec.Replicas) + + // Verify annotations were added + assert.Equal(t, "3", deploy1After.Annotations[PreRestoreReplicasAnnotation]) + + deploy2After, err := fakeClient.AppsV1().Deployments("test-ns").Get( + context.Background(), "deploy2", metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, "5", deploy2After.Annotations[PreRestoreReplicasAnnotation]) + + // Scale up from annotations + scaledUp, err := client.ScaleUpDeploymentsFromAnnotations("test-ns", "app=test") + require.NoError(t, err) + assert.Len(t, scaledUp, 2) + + // Verify deployments are scaled back to original replicas + 
deploy1Final, err := fakeClient.AppsV1().Deployments("test-ns").Get( + context.Background(), "deploy1", metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, int32(3), *deploy1Final.Spec.Replicas) + + // Verify annotations were removed + _, exists := deploy1Final.Annotations[PreRestoreReplicasAnnotation] + assert.False(t, exists) + + deploy2Final, err := fakeClient.AppsV1().Deployments("test-ns").Get( + context.Background(), "deploy2", metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, int32(5), *deploy2Final.Spec.Replicas) + + _, exists = deploy2Final.Annotations[PreRestoreReplicasAnnotation] + assert.False(t, exists) +} + +func TestClient_Clientset(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := &Client{ + clientset: fakeClient, + } + + clientset := client.Clientset() + assert.NotNil(t, clientset) + assert.Equal(t, fakeClient, clientset) +} + +func TestClient_PortForwardService_ServiceNotFound(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := &Client{ + clientset: fakeClient, + } + + _, _, err := client.PortForwardService("test-ns", "nonexistent-svc", 8080, 9200) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to get service") +} + +func TestClient_PortForwardService_NoPodsFound(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + + // Create a service without any matching pods + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-svc", + Namespace: "test-ns", + }, + Spec: corev1.ServiceSpec{ + Selector: map[string]string{"app": "test"}, + }, + } + _, err := fakeClient.CoreV1().Services("test-ns").Create( + context.Background(), svc, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + client := &Client{ + clientset: fakeClient, + } + + _, _, err = client.PortForwardService("test-ns", "test-svc", 8080, 9200) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no pods found for service") +} + +func TestClient_PortForwardService_NoRunningPods(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + + // Create a service + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-svc", + Namespace: "test-ns", + }, + Spec: corev1.ServiceSpec{ + Selector: map[string]string{"app": "test"}, + }, + } + _, err := fakeClient.CoreV1().Services("test-ns").Create( + context.Background(), svc, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Create a pod in Pending state + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pod", + Namespace: "test-ns", + Labels: map[string]string{"app": "test"}, + }, + Status: corev1.PodStatus{ + Phase: corev1.PodPending, + }, + } + _, err = fakeClient.CoreV1().Pods("test-ns").Create( + context.Background(), pod, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + client := &Client{ + clientset: fakeClient, + } + + _, _, err = client.PortForwardService("test-ns", "test-svc", 8080, 9200) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no running pods found for service") +} + +// Helper function to create a deployment for testing +// +//nolint:unparam // namespace parameter is always "test-ns" in current tests, but kept for flexibility +func createDeployment(name, namespace string, labels map[string]string, replicas int32) appsv1.Deployment { + return appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + Labels: labels, + }, + Spec: appsv1.DeploymentSpec{ + Replicas: &replicas, + Selector: &metav1.LabelSelector{ + MatchLabels: labels, + }, + 
Template: corev1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Labels: labels, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "test-container", + Image: "test:latest", + }, + }, + }, + }, + }, + } +} diff --git a/internal/clients/k8s/configmap.go b/internal/clients/k8s/configmap.go new file mode 100644 index 0000000..5ebda8d --- /dev/null +++ b/internal/clients/k8s/configmap.go @@ -0,0 +1,88 @@ +package k8s + +import ( + "context" + "fmt" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// CreateConfigMap creates a ConfigMap with the provided data +func (c *Client) CreateConfigMap(namespace, name string, data map[string]string, labels map[string]string) (*corev1.ConfigMap, error) { + ctx := context.Background() + + if len(data) == 0 { + return nil, fmt.Errorf("no data provided for ConfigMap") + } + + // Create default labels if none provided + if labels == nil { + labels = make(map[string]string) + } + + configMap := &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Labels: labels, + }, + Data: data, + } + + created, err := c.clientset.CoreV1().ConfigMaps(namespace).Create(ctx, configMap, metav1.CreateOptions{}) + if err != nil { + return nil, fmt.Errorf("failed to create ConfigMap: %w", err) + } + + return created, nil +} + +// GetConfigMap retrieves a ConfigMap by name +func (c *Client) GetConfigMap(namespace, name string) (*corev1.ConfigMap, error) { + ctx := context.Background() + return c.clientset.CoreV1().ConfigMaps(namespace).Get(ctx, name, metav1.GetOptions{}) +} + +// UpdateConfigMap updates an existing ConfigMap with new data +func (c *Client) UpdateConfigMap(namespace, name string, data map[string]string) (*corev1.ConfigMap, error) { + ctx := context.Background() + + if len(data) == 0 { + return nil, fmt.Errorf("no data provided for ConfigMap update") + } + + // Get existing ConfigMap + existing, err := c.GetConfigMap(namespace, name) + if err != nil { + return nil, fmt.Errorf("failed to get existing ConfigMap: %w", err) + } + + // Update data + existing.Data = data + + updated, err := c.clientset.CoreV1().ConfigMaps(namespace).Update(ctx, existing, metav1.UpdateOptions{}) + if err != nil { + return nil, fmt.Errorf("failed to update ConfigMap: %w", err) + } + + return updated, nil +} + +// DeleteConfigMap deletes a ConfigMap by name +func (c *Client) DeleteConfigMap(namespace, name string) error { + ctx := context.Background() + return c.clientset.CoreV1().ConfigMaps(namespace).Delete(ctx, name, metav1.DeleteOptions{}) +} + +// EnsureConfigMap ensures a ConfigMap exists with the provided data, creating or updating it as needed +func (c *Client) EnsureConfigMap(namespace, name string, data map[string]string, labels map[string]string) (*corev1.ConfigMap, error) { + // Try to get existing ConfigMap + _, err := c.GetConfigMap(namespace, name) + if err == nil { + // ConfigMap exists, update it + return c.UpdateConfigMap(namespace, name, data) + } + + // ConfigMap doesn't exist, create it + return c.CreateConfigMap(namespace, name, data, labels) +} diff --git a/internal/clients/k8s/configmap_test.go b/internal/clients/k8s/configmap_test.go new file mode 100644 index 0000000..7f32a28 --- /dev/null +++ b/internal/clients/k8s/configmap_test.go @@ -0,0 +1,242 @@ +package k8s + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + 
"k8s.io/client-go/kubernetes/fake" +) + +// TestClient_CreateConfigMap tests ConfigMap creation +func TestClient_CreateConfigMap(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := &Client{clientset: fakeClient} + + data := map[string]string{ + "key1": "value1", + "key2": "value2", + } + labels := map[string]string{ + "app": "test", + } + + cm, err := client.CreateConfigMap("test-ns", "test-cm", data, labels) + + require.NoError(t, err) + assert.NotNil(t, cm) + assert.Equal(t, "test-cm", cm.Name) + assert.Equal(t, "test-ns", cm.Namespace) + assert.Equal(t, data, cm.Data) + assert.Equal(t, labels, cm.Labels) + + // Verify it was created in the fake clientset + createdCM, err := fakeClient.CoreV1().ConfigMaps("test-ns").Get( + context.Background(), "test-cm", metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, data, createdCM.Data) +} + +// TestClient_GetConfigMap tests ConfigMap retrieval +func TestClient_GetConfigMap(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := &Client{clientset: fakeClient} + + // Create a ConfigMap first + cm := &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cm", + Namespace: "test-ns", + Labels: map[string]string{"app": "test"}, + }, + Data: map[string]string{"key": "value"}, + } + _, err := fakeClient.CoreV1().ConfigMaps("test-ns").Create( + context.Background(), cm, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Get the ConfigMap + retrievedCM, err := client.GetConfigMap("test-ns", "test-cm") + + require.NoError(t, err) + assert.NotNil(t, retrievedCM) + assert.Equal(t, "test-cm", retrievedCM.Name) + assert.Equal(t, map[string]string{"key": "value"}, retrievedCM.Data) +} + +// TestClient_GetConfigMap_NotFound tests error when ConfigMap doesn't exist +func TestClient_GetConfigMap_NotFound(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := &Client{clientset: fakeClient} + + _, err := client.GetConfigMap("test-ns", "nonexistent-cm") + + assert.Error(t, err) +} + +// TestClient_UpdateConfigMap tests ConfigMap update +func TestClient_UpdateConfigMap(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := &Client{clientset: fakeClient} + + // Create a ConfigMap first + cm := &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cm", + Namespace: "test-ns", + }, + Data: map[string]string{"key": "oldvalue"}, + } + _, err := fakeClient.CoreV1().ConfigMaps("test-ns").Create( + context.Background(), cm, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Update the ConfigMap + newData := map[string]string{"key": "newvalue", "key2": "value2"} + updatedCM, err := client.UpdateConfigMap("test-ns", "test-cm", newData) + + require.NoError(t, err) + assert.NotNil(t, updatedCM) + assert.Equal(t, newData, updatedCM.Data) + + // Verify the update in fake clientset + retrievedCM, err := fakeClient.CoreV1().ConfigMaps("test-ns").Get( + context.Background(), "test-cm", metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, newData, retrievedCM.Data) +} + +// TestClient_UpdateConfigMap_NotFound tests error when ConfigMap doesn't exist +func TestClient_UpdateConfigMap_NotFound(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := &Client{clientset: fakeClient} + + _, err := client.UpdateConfigMap("test-ns", "nonexistent-cm", map[string]string{"key": "value"}) + + assert.Error(t, err) +} + +// TestClient_DeleteConfigMap tests ConfigMap deletion +func TestClient_DeleteConfigMap(t *testing.T) { + fakeClient := 
fake.NewSimpleClientset() + client := &Client{clientset: fakeClient} + + // Create a ConfigMap first + cm := &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cm", + Namespace: "test-ns", + }, + Data: map[string]string{"key": "value"}, + } + _, err := fakeClient.CoreV1().ConfigMaps("test-ns").Create( + context.Background(), cm, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Delete the ConfigMap + err = client.DeleteConfigMap("test-ns", "test-cm") + + require.NoError(t, err) + + // Verify it was deleted + _, err = fakeClient.CoreV1().ConfigMaps("test-ns").Get( + context.Background(), "test-cm", metav1.GetOptions{}, + ) + assert.Error(t, err) +} + +// TestClient_EnsureConfigMap_Create tests EnsureConfigMap when ConfigMap doesn't exist +func TestClient_EnsureConfigMap_Create(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := &Client{clientset: fakeClient} + + data := map[string]string{"key": "value"} + labels := map[string]string{"app": "test"} + + cm, err := client.EnsureConfigMap("test-ns", "test-cm", data, labels) + + require.NoError(t, err) + assert.NotNil(t, cm) + assert.Equal(t, "test-cm", cm.Name) + assert.Equal(t, data, cm.Data) + assert.Equal(t, labels, cm.Labels) + + // Verify it was created + createdCM, err := fakeClient.CoreV1().ConfigMaps("test-ns").Get( + context.Background(), "test-cm", metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, data, createdCM.Data) +} + +// TestClient_EnsureConfigMap_Update tests EnsureConfigMap when ConfigMap exists +func TestClient_EnsureConfigMap_Update(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := &Client{clientset: fakeClient} + + // Create existing ConfigMap + existingCM := &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cm", + Namespace: "test-ns", + }, + Data: map[string]string{"key": "oldvalue"}, + } + _, err := fakeClient.CoreV1().ConfigMaps("test-ns").Create( + context.Background(), existingCM, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Ensure with new data + newData := map[string]string{"key": "newvalue", "key2": "value2"} + labels := map[string]string{"app": "updated"} + + cm, err := client.EnsureConfigMap("test-ns", "test-cm", newData, labels) + + require.NoError(t, err) + assert.NotNil(t, cm) + assert.Equal(t, newData, cm.Data) + + // Verify it was updated + updatedCM, err := fakeClient.CoreV1().ConfigMaps("test-ns").Get( + context.Background(), "test-cm", metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, newData, updatedCM.Data) +} + +// TestClient_EnsureConfigMap_NoChange tests EnsureConfigMap when data matches +func TestClient_EnsureConfigMap_NoChange(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := &Client{clientset: fakeClient} + + data := map[string]string{"key": "value"} + + // Create existing ConfigMap + existingCM := &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-cm", + Namespace: "test-ns", + }, + Data: data, + } + _, err := fakeClient.CoreV1().ConfigMaps("test-ns").Create( + context.Background(), existingCM, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Ensure with same data + cm, err := client.EnsureConfigMap("test-ns", "test-cm", data, nil) + + require.NoError(t, err) + assert.NotNil(t, cm) + assert.Equal(t, data, cm.Data) +} diff --git a/internal/clients/k8s/convert.go b/internal/clients/k8s/convert.go new file mode 100644 index 0000000..f6d22ef --- /dev/null +++ b/internal/clients/k8s/convert.go @@ -0,0 +1,257 @@ +package 
k8s + +import ( + "github.com/stackvista/stackstate-backup-cli/internal/foundation/config" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// ConvertResources converts config.ResourceRequirements to k8s native ResourceRequirements +func ConvertResources(r config.ResourceRequirements) corev1.ResourceRequirements { + result := corev1.ResourceRequirements{ + Limits: corev1.ResourceList{}, + Requests: corev1.ResourceList{}, + } + + // Parse limits + if r.Limits.CPU != "" { + result.Limits[corev1.ResourceCPU] = resource.MustParse(r.Limits.CPU) + } + if r.Limits.Memory != "" { + result.Limits[corev1.ResourceMemory] = resource.MustParse(r.Limits.Memory) + } + if r.Limits.EphemeralStorage != "" { + result.Limits[corev1.ResourceEphemeralStorage] = resource.MustParse(r.Limits.EphemeralStorage) + } + + // Parse requests + if r.Requests.CPU != "" { + result.Requests[corev1.ResourceCPU] = resource.MustParse(r.Requests.CPU) + } + if r.Requests.Memory != "" { + result.Requests[corev1.ResourceMemory] = resource.MustParse(r.Requests.Memory) + } + if r.Requests.EphemeralStorage != "" { + result.Requests[corev1.ResourceEphemeralStorage] = resource.MustParse(r.Requests.EphemeralStorage) + } + + return result +} + +// ConvertImagePullSecrets converts config.LocalObjectRef slice to k8s native LocalObjectReference slice +func ConvertImagePullSecrets(refs []config.LocalObjectRef) []corev1.LocalObjectReference { + if len(refs) == 0 { + return nil + } + result := make([]corev1.LocalObjectReference, len(refs)) + for i, ref := range refs { + result[i] = corev1.LocalObjectReference{ + Name: ref.Name, + } + } + return result +} + +// ConvertPodSecurityContext converts config.PodSecurityContext to k8s native PodSecurityContext +func ConvertPodSecurityContext(sc *config.PodSecurityContext) *corev1.PodSecurityContext { + if sc == nil { + return nil + } + return &corev1.PodSecurityContext{ + FSGroup: sc.FSGroup, + RunAsGroup: sc.RunAsGroup, + RunAsNonRoot: sc.RunAsNonRoot, + RunAsUser: sc.RunAsUser, + } +} + +// ConvertSecurityContext converts config.SecurityContext to k8s native SecurityContext +func ConvertSecurityContext(sc *config.SecurityContext) *corev1.SecurityContext { + if sc == nil { + return nil + } + return &corev1.SecurityContext{ + AllowPrivilegeEscalation: sc.AllowPrivilegeEscalation, + RunAsNonRoot: sc.RunAsNonRoot, + RunAsUser: sc.RunAsUser, + } +} + +// ConvertTolerations converts config.Toleration slice to k8s native Toleration slice +func ConvertTolerations(tolerations []config.Toleration) []corev1.Toleration { + result := make([]corev1.Toleration, len(tolerations)) + for i, t := range tolerations { + result[i] = corev1.Toleration{ + Key: t.Key, + Operator: corev1.TolerationOperator(t.Operator), + Value: t.Value, + Effect: corev1.TaintEffect(t.Effect), + } + } + return result +} + +// ConvertAffinity converts config.Affinity to k8s native Affinity +func ConvertAffinity(affinity *config.Affinity) *corev1.Affinity { + if affinity == nil { + return nil + } + + result := &corev1.Affinity{} + + if affinity.NodeAffinity != nil { + result.NodeAffinity = ConvertNodeAffinity(affinity.NodeAffinity) + } + + if affinity.PodAffinity != nil { + result.PodAffinity = ConvertPodAffinity(affinity.PodAffinity) + } + + if affinity.PodAntiAffinity != nil { + result.PodAntiAffinity = ConvertPodAntiAffinity(affinity.PodAntiAffinity) + } + + return result +} + +// ConvertNodeAffinity converts config.NodeAffinity to k8s native NodeAffinity +func 
ConvertNodeAffinity(na *config.NodeAffinity) *corev1.NodeAffinity { + if na == nil { + return nil + } + + result := &corev1.NodeAffinity{} + + if na.RequiredDuringSchedulingIgnoredDuringExecution != nil { + result.RequiredDuringSchedulingIgnoredDuringExecution = &corev1.NodeSelector{ + NodeSelectorTerms: make([]corev1.NodeSelectorTerm, len(na.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms)), + } + for i, term := range na.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms { + result.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[i] = ConvertNodeSelectorTerm(term) + } + } + + if len(na.PreferredDuringSchedulingIgnoredDuringExecution) > 0 { + result.PreferredDuringSchedulingIgnoredDuringExecution = make([]corev1.PreferredSchedulingTerm, len(na.PreferredDuringSchedulingIgnoredDuringExecution)) + for i, term := range na.PreferredDuringSchedulingIgnoredDuringExecution { + result.PreferredDuringSchedulingIgnoredDuringExecution[i] = corev1.PreferredSchedulingTerm{ + Weight: term.Weight, + Preference: ConvertNodeSelectorTerm(term.Preference), + } + } + } + + return result +} + +// ConvertNodeSelectorTerm converts config.NodeSelectorTerm to k8s native NodeSelectorTerm +func ConvertNodeSelectorTerm(term config.NodeSelectorTerm) corev1.NodeSelectorTerm { + result := corev1.NodeSelectorTerm{} + + if len(term.MatchExpressions) > 0 { + result.MatchExpressions = make([]corev1.NodeSelectorRequirement, len(term.MatchExpressions)) + for i, expr := range term.MatchExpressions { + result.MatchExpressions[i] = corev1.NodeSelectorRequirement{ + Key: expr.Key, + Operator: corev1.NodeSelectorOperator(expr.Operator), + Values: expr.Values, + } + } + } + + if len(term.MatchFields) > 0 { + result.MatchFields = make([]corev1.NodeSelectorRequirement, len(term.MatchFields)) + for i, field := range term.MatchFields { + result.MatchFields[i] = corev1.NodeSelectorRequirement{ + Key: field.Key, + Operator: corev1.NodeSelectorOperator(field.Operator), + Values: field.Values, + } + } + } + + return result +} + +// ConvertPodAffinity converts config.PodAffinity to k8s native PodAffinity +func ConvertPodAffinity(pa *config.PodAffinity) *corev1.PodAffinity { + if pa == nil { + return nil + } + + result := &corev1.PodAffinity{} + + if len(pa.RequiredDuringSchedulingIgnoredDuringExecution) > 0 { + result.RequiredDuringSchedulingIgnoredDuringExecution = make([]corev1.PodAffinityTerm, len(pa.RequiredDuringSchedulingIgnoredDuringExecution)) + for i, term := range pa.RequiredDuringSchedulingIgnoredDuringExecution { + result.RequiredDuringSchedulingIgnoredDuringExecution[i] = ConvertPodAffinityTerm(term) + } + } + + if len(pa.PreferredDuringSchedulingIgnoredDuringExecution) > 0 { + result.PreferredDuringSchedulingIgnoredDuringExecution = make([]corev1.WeightedPodAffinityTerm, len(pa.PreferredDuringSchedulingIgnoredDuringExecution)) + for i, term := range pa.PreferredDuringSchedulingIgnoredDuringExecution { + result.PreferredDuringSchedulingIgnoredDuringExecution[i] = corev1.WeightedPodAffinityTerm{ + Weight: term.Weight, + PodAffinityTerm: ConvertPodAffinityTerm(term.PodAffinityTerm), + } + } + } + + return result +} + +// ConvertPodAntiAffinity converts config.PodAntiAffinity to k8s native PodAntiAffinity +func ConvertPodAntiAffinity(paa *config.PodAntiAffinity) *corev1.PodAntiAffinity { + if paa == nil { + return nil + } + + result := &corev1.PodAntiAffinity{} + + if len(paa.RequiredDuringSchedulingIgnoredDuringExecution) > 0 { + result.RequiredDuringSchedulingIgnoredDuringExecution = 
make([]corev1.PodAffinityTerm, len(paa.RequiredDuringSchedulingIgnoredDuringExecution)) + for i, term := range paa.RequiredDuringSchedulingIgnoredDuringExecution { + result.RequiredDuringSchedulingIgnoredDuringExecution[i] = ConvertPodAffinityTerm(term) + } + } + + if len(paa.PreferredDuringSchedulingIgnoredDuringExecution) > 0 { + result.PreferredDuringSchedulingIgnoredDuringExecution = make([]corev1.WeightedPodAffinityTerm, len(paa.PreferredDuringSchedulingIgnoredDuringExecution)) + for i, term := range paa.PreferredDuringSchedulingIgnoredDuringExecution { + result.PreferredDuringSchedulingIgnoredDuringExecution[i] = corev1.WeightedPodAffinityTerm{ + Weight: term.Weight, + PodAffinityTerm: ConvertPodAffinityTerm(term.PodAffinityTerm), + } + } + } + + return result +} + +// ConvertPodAffinityTerm converts config.PodAffinityTerm to k8s native PodAffinityTerm +func ConvertPodAffinityTerm(term config.PodAffinityTerm) corev1.PodAffinityTerm { + result := corev1.PodAffinityTerm{ + Namespaces: term.Namespaces, + TopologyKey: term.TopologyKey, + } + + if term.LabelSelector != nil { + result.LabelSelector = &metav1.LabelSelector{ + MatchLabels: term.LabelSelector.MatchLabels, + } + if len(term.LabelSelector.MatchExpressions) > 0 { + result.LabelSelector.MatchExpressions = make([]metav1.LabelSelectorRequirement, len(term.LabelSelector.MatchExpressions)) + for i, expr := range term.LabelSelector.MatchExpressions { + result.LabelSelector.MatchExpressions[i] = metav1.LabelSelectorRequirement{ + Key: expr.Key, + Operator: metav1.LabelSelectorOperator(expr.Operator), + Values: expr.Values, + } + } + } + } + + return result +} diff --git a/internal/clients/k8s/convert_test.go b/internal/clients/k8s/convert_test.go new file mode 100644 index 0000000..b972a22 --- /dev/null +++ b/internal/clients/k8s/convert_test.go @@ -0,0 +1,760 @@ +package k8s + +import ( + "testing" + + "github.com/stackvista/stackstate-backup-cli/internal/foundation/config" + "github.com/stretchr/testify/assert" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// TestConvertResources tests conversion of resource requirements +func TestConvertResources(t *testing.T) { + tests := []struct { + name string + input config.ResourceRequirements + validate func(*testing.T, corev1.ResourceRequirements) + }{ + { + name: "empty resources", + input: config.ResourceRequirements{ + Limits: config.ResourceList{}, + Requests: config.ResourceList{}, + }, + validate: func(t *testing.T, result corev1.ResourceRequirements) { + assert.Empty(t, result.Limits) + assert.Empty(t, result.Requests) + }, + }, + { + name: "only CPU and memory limits", + input: config.ResourceRequirements{ + Limits: config.ResourceList{ + CPU: "1", + Memory: "2Gi", + }, + }, + validate: func(t *testing.T, result corev1.ResourceRequirements) { + assert.Len(t, result.Limits, 2) + assert.Contains(t, result.Limits, corev1.ResourceCPU) + assert.Contains(t, result.Limits, corev1.ResourceMemory) + assert.Empty(t, result.Requests) + }, + }, + { + name: "only requests", + input: config.ResourceRequirements{ + Requests: config.ResourceList{ + CPU: "500m", + Memory: "1Gi", + }, + }, + validate: func(t *testing.T, result corev1.ResourceRequirements) { + assert.Empty(t, result.Limits) + assert.Len(t, result.Requests, 2) + assert.Contains(t, result.Requests, corev1.ResourceCPU) + assert.Contains(t, result.Requests, corev1.ResourceMemory) + }, + }, + { + name: "full resources with ephemeral storage", + input: config.ResourceRequirements{ + Limits: 
config.ResourceList{ + CPU: "2", + Memory: "4Gi", + EphemeralStorage: "10Gi", + }, + Requests: config.ResourceList{ + CPU: "1", + Memory: "2Gi", + EphemeralStorage: "5Gi", + }, + }, + validate: func(t *testing.T, result corev1.ResourceRequirements) { + assert.Len(t, result.Limits, 3) + assert.Len(t, result.Requests, 3) + assert.Contains(t, result.Limits, corev1.ResourceCPU) + assert.Contains(t, result.Limits, corev1.ResourceMemory) + assert.Contains(t, result.Limits, corev1.ResourceEphemeralStorage) + assert.Contains(t, result.Requests, corev1.ResourceCPU) + assert.Contains(t, result.Requests, corev1.ResourceMemory) + assert.Contains(t, result.Requests, corev1.ResourceEphemeralStorage) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ConvertResources(tt.input) + tt.validate(t, result) + }) + } +} + +// TestConvertImagePullSecrets tests conversion of image pull secrets +func TestConvertImagePullSecrets(t *testing.T) { + tests := []struct { + name string + input []config.LocalObjectRef + expected []corev1.LocalObjectReference + }{ + { + name: "empty slice", + input: []config.LocalObjectRef{}, + expected: nil, + }, + { + name: "nil slice", + input: nil, + expected: nil, + }, + { + name: "single secret", + input: []config.LocalObjectRef{ + {Name: "registry-secret"}, + }, + expected: []corev1.LocalObjectReference{ + {Name: "registry-secret"}, + }, + }, + { + name: "multiple secrets", + input: []config.LocalObjectRef{ + {Name: "docker-secret"}, + {Name: "gcr-secret"}, + {Name: "ecr-secret"}, + }, + expected: []corev1.LocalObjectReference{ + {Name: "docker-secret"}, + {Name: "gcr-secret"}, + {Name: "ecr-secret"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ConvertImagePullSecrets(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestConvertPodSecurityContext tests conversion of pod security context +func TestConvertPodSecurityContext(t *testing.T) { + tests := []struct { + name string + input *config.PodSecurityContext + expected *corev1.PodSecurityContext + }{ + { + name: "nil context", + input: nil, + expected: nil, + }, + { + name: "empty context", + input: &config.PodSecurityContext{}, + expected: &corev1.PodSecurityContext{ + FSGroup: nil, + RunAsGroup: nil, + RunAsNonRoot: nil, + RunAsUser: nil, + }, + }, + { + name: "full context", + input: &config.PodSecurityContext{ + FSGroup: int64Ptr(3000), + RunAsGroup: int64Ptr(2000), + RunAsNonRoot: boolPtr(true), + RunAsUser: int64Ptr(1000), + }, + expected: &corev1.PodSecurityContext{ + FSGroup: int64Ptr(3000), + RunAsGroup: int64Ptr(2000), + RunAsNonRoot: boolPtr(true), + RunAsUser: int64Ptr(1000), + }, + }, + { + name: "partial context", + input: &config.PodSecurityContext{ + RunAsUser: int64Ptr(1000), + RunAsNonRoot: boolPtr(true), + }, + expected: &corev1.PodSecurityContext{ + RunAsUser: int64Ptr(1000), + RunAsNonRoot: boolPtr(true), + FSGroup: nil, + RunAsGroup: nil, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ConvertPodSecurityContext(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestConvertSecurityContext tests conversion of container security context +func TestConvertSecurityContext(t *testing.T) { + tests := []struct { + name string + input *config.SecurityContext + expected *corev1.SecurityContext + }{ + { + name: "nil context", + input: nil, + expected: nil, + }, + { + name: "empty context", + input: &config.SecurityContext{}, + expected: 
&corev1.SecurityContext{ + AllowPrivilegeEscalation: nil, + RunAsNonRoot: nil, + RunAsUser: nil, + }, + }, + { + name: "full context", + input: &config.SecurityContext{ + AllowPrivilegeEscalation: boolPtr(false), + RunAsNonRoot: boolPtr(true), + RunAsUser: int64Ptr(1000), + }, + expected: &corev1.SecurityContext{ + AllowPrivilegeEscalation: boolPtr(false), + RunAsNonRoot: boolPtr(true), + RunAsUser: int64Ptr(1000), + }, + }, + { + name: "only privilege escalation", + input: &config.SecurityContext{ + AllowPrivilegeEscalation: boolPtr(false), + }, + expected: &corev1.SecurityContext{ + AllowPrivilegeEscalation: boolPtr(false), + RunAsNonRoot: nil, + RunAsUser: nil, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ConvertSecurityContext(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestConvertTolerations tests conversion of tolerations +func TestConvertTolerations(t *testing.T) { + tests := []struct { + name string + input []config.Toleration + expected []corev1.Toleration + }{ + { + name: "empty slice", + input: []config.Toleration{}, + expected: []corev1.Toleration{}, + }, + { + name: "single toleration", + input: []config.Toleration{ + { + Key: "key1", + Operator: "Equal", + Value: "value1", + Effect: "NoSchedule", + }, + }, + expected: []corev1.Toleration{ + { + Key: "key1", + Operator: corev1.TolerationOpEqual, + Value: "value1", + Effect: corev1.TaintEffectNoSchedule, + }, + }, + }, + { + name: "multiple tolerations", + input: []config.Toleration{ + { + Key: "node.kubernetes.io/not-ready", + Operator: "Exists", + Effect: "NoExecute", + }, + { + Key: "node.kubernetes.io/unreachable", + Operator: "Exists", + Effect: "NoExecute", + }, + { + Key: "disktype", + Operator: "Equal", + Value: "ssd", + Effect: "PreferNoSchedule", + }, + }, + expected: []corev1.Toleration{ + { + Key: "node.kubernetes.io/not-ready", + Operator: corev1.TolerationOpExists, + Effect: corev1.TaintEffectNoExecute, + }, + { + Key: "node.kubernetes.io/unreachable", + Operator: corev1.TolerationOpExists, + Effect: corev1.TaintEffectNoExecute, + }, + { + Key: "disktype", + Operator: corev1.TolerationOpEqual, + Value: "ssd", + Effect: corev1.TaintEffectPreferNoSchedule, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ConvertTolerations(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestConvertAffinity tests conversion of affinity +func TestConvertAffinity(t *testing.T) { + tests := []struct { + name string + input *config.Affinity + validate func(*testing.T, *corev1.Affinity) + }{ + { + name: "nil affinity", + input: nil, + validate: func(t *testing.T, result *corev1.Affinity) { + assert.Nil(t, result) + }, + }, + { + name: "empty affinity", + input: &config.Affinity{}, + validate: func(t *testing.T, result *corev1.Affinity) { + assert.NotNil(t, result) + assert.Nil(t, result.NodeAffinity) + assert.Nil(t, result.PodAffinity) + assert.Nil(t, result.PodAntiAffinity) + }, + }, + { + name: "only node affinity", + input: &config.Affinity{ + NodeAffinity: &config.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &config.NodeSelector{ + NodeSelectorTerms: []config.NodeSelectorTerm{ + { + MatchExpressions: []config.NodeSelectorRequirement{ + { + Key: "disktype", + Operator: "In", + Values: []string{"ssd"}, + }, + }, + }, + }, + }, + }, + }, + validate: func(t *testing.T, result *corev1.Affinity) { + assert.NotNil(t, result) + assert.NotNil(t, result.NodeAffinity) + assert.Nil(t, 
result.PodAffinity) + assert.Nil(t, result.PodAntiAffinity) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ConvertAffinity(tt.input) + tt.validate(t, result) + }) + } +} + +// TestConvertNodeAffinity tests conversion of node affinity +func TestConvertNodeAffinity(t *testing.T) { + tests := []struct { + name string + input *config.NodeAffinity + validate func(*testing.T, *corev1.NodeAffinity) + }{ + { + name: "nil node affinity", + input: nil, + validate: func(t *testing.T, result *corev1.NodeAffinity) { + assert.Nil(t, result) + }, + }, + { + name: "empty node affinity", + input: &config.NodeAffinity{}, + validate: func(t *testing.T, result *corev1.NodeAffinity) { + assert.NotNil(t, result) + assert.Nil(t, result.RequiredDuringSchedulingIgnoredDuringExecution) + assert.Nil(t, result.PreferredDuringSchedulingIgnoredDuringExecution) + }, + }, + { + name: "required node selector", + input: &config.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &config.NodeSelector{ + NodeSelectorTerms: []config.NodeSelectorTerm{ + { + MatchExpressions: []config.NodeSelectorRequirement{ + { + Key: "kubernetes.io/os", + Operator: "In", + Values: []string{"linux"}, + }, + }, + }, + }, + }, + }, + validate: func(t *testing.T, result *corev1.NodeAffinity) { + assert.NotNil(t, result) + assert.NotNil(t, result.RequiredDuringSchedulingIgnoredDuringExecution) + assert.Len(t, result.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms, 1) + assert.Len(t, result.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions, 1) + assert.Equal(t, "kubernetes.io/os", result.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0].Key) + }, + }, + { + name: "preferred node selector", + input: &config.NodeAffinity{ + PreferredDuringSchedulingIgnoredDuringExecution: []config.PreferredSchedulingTerm{ + { + Weight: 1, + Preference: config.NodeSelectorTerm{ + MatchExpressions: []config.NodeSelectorRequirement{ + { + Key: "zone", + Operator: "In", + Values: []string{"us-west-1a"}, + }, + }, + }, + }, + }, + }, + validate: func(t *testing.T, result *corev1.NodeAffinity) { + assert.NotNil(t, result) + assert.Len(t, result.PreferredDuringSchedulingIgnoredDuringExecution, 1) + assert.Equal(t, int32(1), result.PreferredDuringSchedulingIgnoredDuringExecution[0].Weight) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ConvertNodeAffinity(tt.input) + tt.validate(t, result) + }) + } +} + +// TestConvertNodeSelectorTerm tests conversion of node selector term +func TestConvertNodeSelectorTerm(t *testing.T) { + tests := []struct { + name string + input config.NodeSelectorTerm + validate func(*testing.T, corev1.NodeSelectorTerm) + }{ + { + name: "empty term", + input: config.NodeSelectorTerm{}, + validate: func(t *testing.T, result corev1.NodeSelectorTerm) { + assert.Empty(t, result.MatchExpressions) + assert.Empty(t, result.MatchFields) + }, + }, + { + name: "match expressions", + input: config.NodeSelectorTerm{ + MatchExpressions: []config.NodeSelectorRequirement{ + { + Key: "disktype", + Operator: "In", + Values: []string{"ssd", "nvme"}, + }, + { + Key: "kubernetes.io/arch", + Operator: "NotIn", + Values: []string{"arm"}, + }, + }, + }, + validate: func(t *testing.T, result corev1.NodeSelectorTerm) { + assert.Len(t, result.MatchExpressions, 2) + assert.Equal(t, "disktype", result.MatchExpressions[0].Key) + assert.Equal(t, corev1.NodeSelectorOpIn, 
result.MatchExpressions[0].Operator) + assert.Equal(t, []string{"ssd", "nvme"}, result.MatchExpressions[0].Values) + }, + }, + { + name: "match fields", + input: config.NodeSelectorTerm{ + MatchFields: []config.NodeSelectorRequirement{ + { + Key: "metadata.name", + Operator: "In", + Values: []string{"node-1", "node-2"}, + }, + }, + }, + validate: func(t *testing.T, result corev1.NodeSelectorTerm) { + assert.Len(t, result.MatchFields, 1) + assert.Equal(t, "metadata.name", result.MatchFields[0].Key) + assert.Equal(t, corev1.NodeSelectorOpIn, result.MatchFields[0].Operator) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ConvertNodeSelectorTerm(tt.input) + tt.validate(t, result) + }) + } +} + +// TestConvertPodAffinity tests conversion of pod affinity +func TestConvertPodAffinity(t *testing.T) { + tests := []struct { + name string + input *config.PodAffinity + validate func(*testing.T, *corev1.PodAffinity) + }{ + { + name: "nil pod affinity", + input: nil, + validate: func(t *testing.T, result *corev1.PodAffinity) { + assert.Nil(t, result) + }, + }, + { + name: "empty pod affinity", + input: &config.PodAffinity{}, + validate: func(t *testing.T, result *corev1.PodAffinity) { + assert.NotNil(t, result) + assert.Empty(t, result.RequiredDuringSchedulingIgnoredDuringExecution) + assert.Empty(t, result.PreferredDuringSchedulingIgnoredDuringExecution) + }, + }, + { + name: "required pod affinity", + input: &config.PodAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: []config.PodAffinityTerm{ + { + TopologyKey: "kubernetes.io/hostname", + LabelSelector: &config.LabelSelector{ + MatchLabels: map[string]string{"app": "cache"}, + }, + }, + }, + }, + validate: func(t *testing.T, result *corev1.PodAffinity) { + assert.NotNil(t, result) + assert.Len(t, result.RequiredDuringSchedulingIgnoredDuringExecution, 1) + assert.Equal(t, "kubernetes.io/hostname", result.RequiredDuringSchedulingIgnoredDuringExecution[0].TopologyKey) + }, + }, + { + name: "preferred pod affinity", + input: &config.PodAffinity{ + PreferredDuringSchedulingIgnoredDuringExecution: []config.WeightedPodAffinityTerm{ + { + Weight: 100, + PodAffinityTerm: config.PodAffinityTerm{ + TopologyKey: "zone", + LabelSelector: &config.LabelSelector{ + MatchLabels: map[string]string{"app": "web"}, + }, + }, + }, + }, + }, + validate: func(t *testing.T, result *corev1.PodAffinity) { + assert.NotNil(t, result) + assert.Len(t, result.PreferredDuringSchedulingIgnoredDuringExecution, 1) + assert.Equal(t, int32(100), result.PreferredDuringSchedulingIgnoredDuringExecution[0].Weight) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ConvertPodAffinity(tt.input) + tt.validate(t, result) + }) + } +} + +// TestConvertPodAntiAffinity tests conversion of pod anti-affinity +func TestConvertPodAntiAffinity(t *testing.T) { + tests := []struct { + name string + input *config.PodAntiAffinity + validate func(*testing.T, *corev1.PodAntiAffinity) + }{ + { + name: "nil pod anti-affinity", + input: nil, + validate: func(t *testing.T, result *corev1.PodAntiAffinity) { + assert.Nil(t, result) + }, + }, + { + name: "empty pod anti-affinity", + input: &config.PodAntiAffinity{}, + validate: func(t *testing.T, result *corev1.PodAntiAffinity) { + assert.NotNil(t, result) + assert.Empty(t, result.RequiredDuringSchedulingIgnoredDuringExecution) + assert.Empty(t, result.PreferredDuringSchedulingIgnoredDuringExecution) + }, + }, + { + name: "required pod anti-affinity", + input: 
&config.PodAntiAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: []config.PodAffinityTerm{ + { + TopologyKey: "kubernetes.io/hostname", + LabelSelector: &config.LabelSelector{ + MatchLabels: map[string]string{"app": "backup"}, + }, + }, + }, + }, + validate: func(t *testing.T, result *corev1.PodAntiAffinity) { + assert.NotNil(t, result) + assert.Len(t, result.RequiredDuringSchedulingIgnoredDuringExecution, 1) + assert.Equal(t, "kubernetes.io/hostname", result.RequiredDuringSchedulingIgnoredDuringExecution[0].TopologyKey) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ConvertPodAntiAffinity(tt.input) + tt.validate(t, result) + }) + } +} + +// TestConvertPodAffinityTerm tests conversion of pod affinity term +func TestConvertPodAffinityTerm(t *testing.T) { + tests := []struct { + name string + input config.PodAffinityTerm + validate func(*testing.T, corev1.PodAffinityTerm) + }{ + { + name: "minimal term", + input: config.PodAffinityTerm{ + TopologyKey: "kubernetes.io/hostname", + }, + validate: func(t *testing.T, result corev1.PodAffinityTerm) { + assert.Equal(t, "kubernetes.io/hostname", result.TopologyKey) + assert.Nil(t, result.LabelSelector) + assert.Empty(t, result.Namespaces) + }, + }, + { + name: "with label selector", + input: config.PodAffinityTerm{ + TopologyKey: "zone", + LabelSelector: &config.LabelSelector{ + MatchLabels: map[string]string{ + "app": "database", + "env": "prod", + }, + }, + }, + validate: func(t *testing.T, result corev1.PodAffinityTerm) { + assert.Equal(t, "zone", result.TopologyKey) + assert.NotNil(t, result.LabelSelector) + assert.Equal(t, map[string]string{"app": "database", "env": "prod"}, result.LabelSelector.MatchLabels) + }, + }, + { + name: "with label selector and match expressions", + input: config.PodAffinityTerm{ + TopologyKey: "region", + LabelSelector: &config.LabelSelector{ + MatchLabels: map[string]string{"app": "web"}, + MatchExpressions: []config.LabelSelectorRequirement{ + { + Key: "tier", + Operator: "In", + Values: []string{"frontend", "backend"}, + }, + }, + }, + }, + validate: func(t *testing.T, result corev1.PodAffinityTerm) { + assert.Equal(t, "region", result.TopologyKey) + assert.NotNil(t, result.LabelSelector) + assert.Len(t, result.LabelSelector.MatchExpressions, 1) + assert.Equal(t, "tier", result.LabelSelector.MatchExpressions[0].Key) + assert.Equal(t, metav1.LabelSelectorOpIn, result.LabelSelector.MatchExpressions[0].Operator) + assert.Equal(t, []string{"frontend", "backend"}, result.LabelSelector.MatchExpressions[0].Values) + }, + }, + { + name: "with namespaces", + input: config.PodAffinityTerm{ + TopologyKey: "kubernetes.io/hostname", + Namespaces: []string{"default", "kube-system"}, + LabelSelector: &config.LabelSelector{ + MatchLabels: map[string]string{"app": "monitoring"}, + }, + }, + validate: func(t *testing.T, result corev1.PodAffinityTerm) { + assert.Equal(t, "kubernetes.io/hostname", result.TopologyKey) + assert.Equal(t, []string{"default", "kube-system"}, result.Namespaces) + assert.NotNil(t, result.LabelSelector) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ConvertPodAffinityTerm(tt.input) + tt.validate(t, result) + }) + } +} + +// Helper functions for test pointers +func boolPtr(b bool) *bool { + return &b +} + +func int64Ptr(i int64) *int64 { + return &i +} diff --git a/internal/k8s/interface.go b/internal/clients/k8s/interface.go similarity index 87% rename from internal/k8s/interface.go rename to 
internal/clients/k8s/interface.go index 679d9af..13cfee8 100644 --- a/internal/k8s/interface.go +++ b/internal/clients/k8s/interface.go @@ -14,7 +14,7 @@ type Interface interface { // Deployment scaling operations ScaleDownDeployments(namespace, labelSelector string) ([]DeploymentScale, error) - ScaleUpDeployments(namespace string, deployments []DeploymentScale) error + ScaleUpDeploymentsFromAnnotations(namespace, labelSelector string) ([]DeploymentScale, error) } // Ensure *Client implements Interface diff --git a/internal/clients/k8s/job.go b/internal/clients/k8s/job.go new file mode 100644 index 0000000..e976f1a --- /dev/null +++ b/internal/clients/k8s/job.go @@ -0,0 +1,194 @@ +package k8s + +import ( + "context" + "fmt" + + batchv1 "k8s.io/api/batch/v1" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +const ( + // defaultJobTTLSeconds is the time-to-live for completed/failed jobs (10 minutes) + defaultJobTTLSeconds = 600 +) + +// BackupJobSpec contains all parameters needed to create a backup/restore job +type BackupJobSpec struct { + // Job metadata + Name string + Labels map[string]string + + // Pod spec parameters (using native k8s types) + ImagePullSecrets []corev1.LocalObjectReference + SecurityContext *corev1.PodSecurityContext + NodeSelector map[string]string + Tolerations []corev1.Toleration + Affinity *corev1.Affinity + ContainerSecurityContext *corev1.SecurityContext + + // Container spec + Image string + Command []string + Env []corev1.EnvVar + Resources corev1.ResourceRequirements + VolumeMounts []corev1.VolumeMount + InitContainers []corev1.Container + + // Volumes + Volumes []corev1.Volume + + // PVC spec (optional - only for jobs that need persistent storage) + // If nil, no PVC will be created + PVCSpec *PVCSpec +} + +// PVCSpec contains parameters for creating a PersistentVolumeClaim +// NOTE: Some backup types (e.g., Stackgraph) require a PVC for temporary storage, +// while others (e.g., Elasticsearch, VictoriaMetrics, Clickhouse) may not. +// Set this to nil if the job doesn't require persistent storage. 
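+//
+// Illustrative sketch only — the names, labels, and size below are placeholders, not part of this change:
+//
+//	spec := PVCSpec{
+//		Name:        "stackgraph-restore-data",
+//		Labels:      map[string]string{"app": "backup"},
+//		StorageSize: "10Gi",
+//		AccessModes: []corev1.PersistentVolumeAccessMode{corev1.ReadWriteOnce},
+//		// StorageClass left nil to use the cluster default storage class
+//	}
+//	pvc, err := client.CreatePVC(namespace, spec)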
+type PVCSpec struct { + Name string + Labels map[string]string + StorageSize string // e.g., "10Gi" + AccessModes []corev1.PersistentVolumeAccessMode + StorageClass *string // nil for default storage class +} + +// CreatePVC creates a PersistentVolumeClaim +func (c *Client) CreatePVC(namespace string, spec PVCSpec) (*corev1.PersistentVolumeClaim, error) { + ctx := context.Background() + + pvc := &corev1.PersistentVolumeClaim{ + ObjectMeta: metav1.ObjectMeta{ + Name: spec.Name, + Labels: spec.Labels, + }, + Spec: corev1.PersistentVolumeClaimSpec{ + AccessModes: spec.AccessModes, + Resources: corev1.VolumeResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceStorage: resource.MustParse(spec.StorageSize), + }, + }, + }, + } + + if spec.StorageClass != nil { + pvc.Spec.StorageClassName = spec.StorageClass + } + + createdPVC, err := c.clientset.CoreV1().PersistentVolumeClaims(namespace).Create(ctx, pvc, metav1.CreateOptions{}) + if err != nil { + return nil, fmt.Errorf("failed to create PVC: %w", err) + } + + return createdPVC, nil +} + +// CreateBackupJob creates a Kubernetes Job for backup/restore operations +// Note: PVC must be created separately if needed using CreatePVC +// Returns the created Job and any error +func (c *Client) CreateBackupJob(namespace string, spec BackupJobSpec) (*batchv1.Job, error) { + ctx := context.Background() + + // Build Job spec + job := &batchv1.Job{ + ObjectMeta: metav1.ObjectMeta{ + Name: spec.Name, + Labels: spec.Labels, + }, + Spec: batchv1.JobSpec{ + BackoffLimit: ptr(int32(1)), + TTLSecondsAfterFinished: ptr(int32(defaultJobTTLSeconds)), + Template: corev1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Labels: spec.Labels, + }, + Spec: corev1.PodSpec{ + RestartPolicy: corev1.RestartPolicyNever, + InitContainers: spec.InitContainers, + Containers: []corev1.Container{ + { + Name: "restore", + Image: spec.Image, + ImagePullPolicy: corev1.PullIfNotPresent, + Command: spec.Command, + Env: spec.Env, + Resources: spec.Resources, + VolumeMounts: spec.VolumeMounts, + }, + }, + Volumes: spec.Volumes, + }, + }, + }, + } + + // Apply security context if provided + if spec.SecurityContext != nil { + job.Spec.Template.Spec.SecurityContext = spec.SecurityContext + } + + // Apply container security context if provided + if spec.ContainerSecurityContext != nil { + job.Spec.Template.Spec.Containers[0].SecurityContext = spec.ContainerSecurityContext + } + + // Apply node selector if provided + if len(spec.NodeSelector) > 0 { + job.Spec.Template.Spec.NodeSelector = spec.NodeSelector + } + + // Apply tolerations if provided + if len(spec.Tolerations) > 0 { + job.Spec.Template.Spec.Tolerations = spec.Tolerations + } + + // Apply affinity if provided + if spec.Affinity != nil { + job.Spec.Template.Spec.Affinity = spec.Affinity + } + + // Apply image pull secrets if provided + if len(spec.ImagePullSecrets) > 0 { + job.Spec.Template.Spec.ImagePullSecrets = spec.ImagePullSecrets + } + + // Create Job + createdJob, err := c.clientset.BatchV1().Jobs(namespace).Create(ctx, job, metav1.CreateOptions{}) + if err != nil { + return nil, fmt.Errorf("failed to create job: %w", err) + } + + return createdJob, nil +} + +// DeleteJob deletes a Kubernetes Job with Background propagation policy +// This ensures child pods are automatically cleaned up when the job is deleted +func (c *Client) DeleteJob(namespace, name string) error { + ctx := context.Background() + propagationPolicy := metav1.DeletePropagationBackground + return 
c.clientset.BatchV1().Jobs(namespace).Delete(ctx, name, metav1.DeleteOptions{ + PropagationPolicy: &propagationPolicy, + }) +} + +// DeletePVC deletes a PersistentVolumeClaim +func (c *Client) DeletePVC(namespace, name string) error { + ctx := context.Background() + return c.clientset.CoreV1().PersistentVolumeClaims(namespace).Delete(ctx, name, metav1.DeleteOptions{}) +} + +// GetJob retrieves a Kubernetes Job +func (c *Client) GetJob(namespace, name string) (*batchv1.Job, error) { + ctx := context.Background() + return c.clientset.BatchV1().Jobs(namespace).Get(ctx, name, metav1.GetOptions{}) +} + +// ptr returns a pointer to the provided value +func ptr[T any](v T) *T { + return &v +} diff --git a/internal/clients/k8s/job_test.go b/internal/clients/k8s/job_test.go new file mode 100644 index 0000000..7c9c1bf --- /dev/null +++ b/internal/clients/k8s/job_test.go @@ -0,0 +1,432 @@ +package k8s + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + batchv1 "k8s.io/api/batch/v1" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" +) + +// TestClient_CreatePVC tests PersistentVolumeClaim creation +func TestClient_CreatePVC(t *testing.T) { + tests := []struct { + name string + namespace string + spec PVCSpec + expectError bool + validateFunc func(*testing.T, *corev1.PersistentVolumeClaim) + }{ + { + name: "create PVC with default storage class", + namespace: "test-ns", + spec: PVCSpec{ + Name: "test-pvc", + Labels: map[string]string{"app": "backup"}, + StorageSize: "10Gi", + AccessModes: []corev1.PersistentVolumeAccessMode{corev1.ReadWriteOnce}, + }, + expectError: false, + validateFunc: func(t *testing.T, pvc *corev1.PersistentVolumeClaim) { + assert.Equal(t, "test-pvc", pvc.Name) + assert.Equal(t, map[string]string{"app": "backup"}, pvc.Labels) + assert.Equal(t, []corev1.PersistentVolumeAccessMode{corev1.ReadWriteOnce}, pvc.Spec.AccessModes) + assert.Nil(t, pvc.Spec.StorageClassName) + }, + }, + { + name: "create PVC with custom storage class", + namespace: "test-ns", + spec: PVCSpec{ + Name: "test-pvc-custom", + Labels: map[string]string{"type": "restore"}, + StorageSize: "20Gi", + AccessModes: []corev1.PersistentVolumeAccessMode{corev1.ReadWriteMany}, + StorageClass: ptr("fast-ssd"), + }, + expectError: false, + validateFunc: func(t *testing.T, pvc *corev1.PersistentVolumeClaim) { + assert.Equal(t, "test-pvc-custom", pvc.Name) + assert.NotNil(t, pvc.Spec.StorageClassName) + assert.Equal(t, "fast-ssd", *pvc.Spec.StorageClassName) + }, + }, + { + name: "create PVC with multiple access modes", + namespace: "test-ns", + spec: PVCSpec{ + Name: "multi-mode-pvc", + StorageSize: "5Gi", + AccessModes: []corev1.PersistentVolumeAccessMode{ + corev1.ReadWriteOnce, + corev1.ReadOnlyMany, + }, + }, + expectError: false, + validateFunc: func(t *testing.T, pvc *corev1.PersistentVolumeClaim) { + assert.Len(t, pvc.Spec.AccessModes, 2) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := &Client{clientset: fakeClient} + + pvc, err := client.CreatePVC(tt.namespace, tt.spec) + + if tt.expectError { + assert.Error(t, err) + assert.Nil(t, pvc) + } else { + require.NoError(t, err) + assert.NotNil(t, pvc) + if tt.validateFunc != nil { + tt.validateFunc(t, pvc) + } + + // Verify PVC was actually created in fake client + createdPVC, err := 
fakeClient.CoreV1().PersistentVolumeClaims(tt.namespace).Get( + context.Background(), tt.spec.Name, metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, tt.spec.Name, createdPVC.Name) + } + }) + } +} + +// TestClient_CreateBackupJob tests Job creation for backup/restore operations +// +//nolint:funlen +func TestClient_CreateBackupJob(t *testing.T) { + tests := []struct { + name string + namespace string + spec BackupJobSpec + expectError bool + validateFunc func(*testing.T, *batchv1.Job) + }{ + { + name: "create minimal job", + namespace: "test-ns", + spec: BackupJobSpec{ + Name: "backup-job", + Labels: map[string]string{"app": "backup"}, + Image: "backup:latest", + Command: []string{ + "/bin/sh", "-c", "echo 'backup complete'", + }, + }, + expectError: false, + validateFunc: func(t *testing.T, job *batchv1.Job) { + assert.Equal(t, "backup-job", job.Name) + assert.Equal(t, map[string]string{"app": "backup"}, job.Labels) + assert.Equal(t, int32(1), *job.Spec.BackoffLimit) + assert.Equal(t, int32(600), *job.Spec.TTLSecondsAfterFinished) + assert.Equal(t, corev1.RestartPolicyNever, job.Spec.Template.Spec.RestartPolicy) + }, + }, + { + name: "create job with environment variables", + namespace: "test-ns", + spec: BackupJobSpec{ + Name: "restore-job", + Image: "restore:v1", + Command: []string{"/restore.sh"}, + Env: []corev1.EnvVar{ + {Name: "BACKUP_NAME", Value: "snapshot-123"}, + {Name: "LOG_LEVEL", Value: "debug"}, + }, + }, + expectError: false, + validateFunc: func(t *testing.T, job *batchv1.Job) { + assert.Len(t, job.Spec.Template.Spec.Containers, 1) + assert.Len(t, job.Spec.Template.Spec.Containers[0].Env, 2) + assert.Equal(t, "BACKUP_NAME", job.Spec.Template.Spec.Containers[0].Env[0].Name) + }, + }, + { + name: "create job with resource requirements", + namespace: "test-ns", + spec: BackupJobSpec{ + Name: "resource-job", + Image: "backup:latest", + Command: []string{"/backup.sh"}, + Resources: corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("1"), + corev1.ResourceMemory: resource.MustParse("2Gi"), + }, + Limits: corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("2"), + corev1.ResourceMemory: resource.MustParse("4Gi"), + }, + }, + }, + expectError: false, + validateFunc: func(t *testing.T, job *batchv1.Job) { + resources := job.Spec.Template.Spec.Containers[0].Resources + assert.NotNil(t, resources.Requests) + assert.NotNil(t, resources.Limits) + }, + }, + { + name: "create job with init containers", + namespace: "test-ns", + spec: BackupJobSpec{ + Name: "init-job", + Image: "main:latest", + Command: []string{"/main.sh"}, + InitContainers: []corev1.Container{ + { + Name: "wait-for-deps", + Image: "wait:latest", + Command: []string{"/wait.sh"}, + }, + }, + }, + expectError: false, + validateFunc: func(t *testing.T, job *batchv1.Job) { + assert.Len(t, job.Spec.Template.Spec.InitContainers, 1) + assert.Equal(t, "wait-for-deps", job.Spec.Template.Spec.InitContainers[0].Name) + }, + }, + { + name: "create job with volumes and mounts", + namespace: "test-ns", + spec: BackupJobSpec{ + Name: "volume-job", + Image: "backup:latest", + Command: []string{"/backup.sh"}, + VolumeMounts: []corev1.VolumeMount{ + {Name: "data", MountPath: "/data"}, + {Name: "config", MountPath: "/config"}, + }, + Volumes: []corev1.Volume{ + { + Name: "data", + VolumeSource: corev1.VolumeSource{ + PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{ + ClaimName: "data-pvc", + }, + }, + }, + { + Name: "config", + VolumeSource: 
corev1.VolumeSource{ + ConfigMap: &corev1.ConfigMapVolumeSource{ + LocalObjectReference: corev1.LocalObjectReference{ + Name: "backup-config", + }, + }, + }, + }, + }, + }, + expectError: false, + validateFunc: func(t *testing.T, job *batchv1.Job) { + assert.Len(t, job.Spec.Template.Spec.Volumes, 2) + assert.Len(t, job.Spec.Template.Spec.Containers[0].VolumeMounts, 2) + }, + }, + { + name: "create job with security context", + namespace: "test-ns", + spec: BackupJobSpec{ + Name: "secure-job", + Image: "backup:latest", + Command: []string{"/backup.sh"}, + SecurityContext: &corev1.PodSecurityContext{ + RunAsUser: ptr(int64(1000)), + RunAsGroup: ptr(int64(2000)), + FSGroup: ptr(int64(3000)), + }, + ContainerSecurityContext: &corev1.SecurityContext{ + AllowPrivilegeEscalation: ptr(false), + ReadOnlyRootFilesystem: ptr(true), + }, + }, + expectError: false, + validateFunc: func(t *testing.T, job *batchv1.Job) { + assert.NotNil(t, job.Spec.Template.Spec.SecurityContext) + assert.Equal(t, int64(1000), *job.Spec.Template.Spec.SecurityContext.RunAsUser) + assert.NotNil(t, job.Spec.Template.Spec.Containers[0].SecurityContext) + assert.False(t, *job.Spec.Template.Spec.Containers[0].SecurityContext.AllowPrivilegeEscalation) + }, + }, + { + name: "create job with node selector and tolerations", + namespace: "test-ns", + spec: BackupJobSpec{ + Name: "scheduled-job", + Image: "backup:latest", + Command: []string{"/backup.sh"}, + NodeSelector: map[string]string{ + "disktype": "ssd", + "zone": "us-west-1a", + }, + Tolerations: []corev1.Toleration{ + { + Key: "key1", + Operator: corev1.TolerationOpEqual, + Value: "value1", + Effect: corev1.TaintEffectNoSchedule, + }, + }, + }, + expectError: false, + validateFunc: func(t *testing.T, job *batchv1.Job) { + assert.Len(t, job.Spec.Template.Spec.NodeSelector, 2) + assert.Len(t, job.Spec.Template.Spec.Tolerations, 1) + assert.Equal(t, "key1", job.Spec.Template.Spec.Tolerations[0].Key) + }, + }, + { + name: "create job with image pull secrets", + namespace: "test-ns", + spec: BackupJobSpec{ + Name: "private-image-job", + Image: "private-registry.com/backup:latest", + Command: []string{"/backup.sh"}, + ImagePullSecrets: []corev1.LocalObjectReference{ + {Name: "registry-secret"}, + }, + }, + expectError: false, + validateFunc: func(t *testing.T, job *batchv1.Job) { + assert.Len(t, job.Spec.Template.Spec.ImagePullSecrets, 1) + assert.Equal(t, "registry-secret", job.Spec.Template.Spec.ImagePullSecrets[0].Name) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := &Client{clientset: fakeClient} + + job, err := client.CreateBackupJob(tt.namespace, tt.spec) + + if tt.expectError { + assert.Error(t, err) + assert.Nil(t, job) + } else { + require.NoError(t, err) + assert.NotNil(t, job) + if tt.validateFunc != nil { + tt.validateFunc(t, job) + } + + // Verify job was actually created in fake client + createdJob, err := fakeClient.BatchV1().Jobs(tt.namespace).Get( + context.Background(), tt.spec.Name, metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, tt.spec.Name, createdJob.Name) + } + }) + } +} + +// TestClient_DeleteJob tests Job deletion +func TestClient_DeleteJob(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := &Client{clientset: fakeClient} + + // Create a job first + job := &batchv1.Job{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-job", + Namespace: "test-ns", + }, + } + _, err := fakeClient.BatchV1().Jobs("test-ns").Create( + 
context.Background(), job, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Delete the job + err = client.DeleteJob("test-ns", "test-job") + assert.NoError(t, err) + + // Verify job was deleted + _, err = fakeClient.BatchV1().Jobs("test-ns").Get( + context.Background(), "test-job", metav1.GetOptions{}, + ) + assert.Error(t, err) +} + +// TestClient_DeletePVC tests PVC deletion +func TestClient_DeletePVC(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := &Client{clientset: fakeClient} + + // Create a PVC first + pvc := &corev1.PersistentVolumeClaim{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pvc", + Namespace: "test-ns", + }, + } + _, err := fakeClient.CoreV1().PersistentVolumeClaims("test-ns").Create( + context.Background(), pvc, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Delete the PVC + err = client.DeletePVC("test-ns", "test-pvc") + assert.NoError(t, err) + + // Verify PVC was deleted + _, err = fakeClient.CoreV1().PersistentVolumeClaims("test-ns").Get( + context.Background(), "test-pvc", metav1.GetOptions{}, + ) + assert.Error(t, err) +} + +// TestClient_GetJob tests Job retrieval +func TestClient_GetJob(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := &Client{clientset: fakeClient} + + // Create a job + job := &batchv1.Job{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-job", + Namespace: "test-ns", + Labels: map[string]string{"app": "backup"}, + }, + } + _, err := fakeClient.BatchV1().Jobs("test-ns").Create( + context.Background(), job, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Get the job + retrievedJob, err := client.GetJob("test-ns", "test-job") + assert.NoError(t, err) + assert.NotNil(t, retrievedJob) + assert.Equal(t, "test-job", retrievedJob.Name) + assert.Equal(t, map[string]string{"app": "backup"}, retrievedJob.Labels) +} + +// TestClient_GetJob_NotFound tests error case when job doesn't exist +func TestClient_GetJob_NotFound(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := &Client{clientset: fakeClient} + + _, err := client.GetJob("test-ns", "nonexistent-job") + assert.Error(t, err) +} + +// TestClient_DefaultJobTTL tests the default TTL constant +func TestClient_DefaultJobTTL(t *testing.T) { + assert.Equal(t, int32(600), int32(defaultJobTTLSeconds)) +} diff --git a/internal/clients/k8s/logs.go b/internal/clients/k8s/logs.go new file mode 100644 index 0000000..610bd10 --- /dev/null +++ b/internal/clients/k8s/logs.go @@ -0,0 +1,123 @@ +package k8s + +import ( + "bytes" + "context" + "fmt" + "io" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// ContainerLog holds logs from a single container +type ContainerLog struct { + Name string // Container name + IsInit bool // True if this is an init container + Logs string // Container logs +} + +// PodLogs holds logs from all containers in a pod +type PodLogs struct { + PodName string + ContainerLogs []ContainerLog +} + +// GetJobLogs retrieves logs from all pods belonging to a job +// Returns a slice of PodLogs, one for each pod in the job +func (c *Client) GetJobLogs(namespace, jobName string) ([]PodLogs, error) { + ctx := context.Background() + + // Find pods for this job + podList, err := c.clientset.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{ + LabelSelector: fmt.Sprintf("job-name=%s", jobName), + }) + if err != nil { + return nil, fmt.Errorf("failed to list pods: %w", err) + } + + if len(podList.Items) == 0 { + return nil, fmt.Errorf("no pods found for job %s", 
jobName) + } + + // Collect logs from each pod + var allPodLogs []PodLogs + for _, pod := range podList.Items { + podLogs, err := c.GetPodLogs(namespace, pod.Name) + if err != nil { + return nil, fmt.Errorf("failed to get logs for pod %s: %w", pod.Name, err) + } + allPodLogs = append(allPodLogs, podLogs) + } + + return allPodLogs, nil +} + +// GetPodLogs retrieves logs from all containers in a specific pod +func (c *Client) GetPodLogs(namespace, podName string) (PodLogs, error) { + ctx := context.Background() + + // Get pod to access its container list + pod, err := c.clientset.CoreV1().Pods(namespace).Get(ctx, podName, metav1.GetOptions{}) + if err != nil { + return PodLogs{}, fmt.Errorf("failed to get pod: %w", err) + } + + result := PodLogs{ + PodName: podName, + ContainerLogs: make([]ContainerLog, 0), + } + + // Get logs from init containers + for _, container := range pod.Spec.InitContainers { + logs, err := c.getContainerLogs(namespace, podName, container.Name) + if err != nil { + // Don't fail entirely if one container's logs are unavailable + logs = fmt.Sprintf("Error fetching logs: %v", err) + } + result.ContainerLogs = append(result.ContainerLogs, ContainerLog{ + Name: container.Name, + IsInit: true, + Logs: logs, + }) + } + + // Get logs from main containers + for _, container := range pod.Spec.Containers { + logs, err := c.getContainerLogs(namespace, podName, container.Name) + if err != nil { + // Don't fail entirely if one container's logs are unavailable + logs = fmt.Sprintf("Error fetching logs: %v", err) + } + result.ContainerLogs = append(result.ContainerLogs, ContainerLog{ + Name: container.Name, + IsInit: false, + Logs: logs, + }) + } + + return result, nil +} + +// getContainerLogs retrieves logs from a specific container in a pod +func (c *Client) getContainerLogs(namespace, podName, containerName string) (string, error) { + ctx := context.Background() + + req := c.clientset.CoreV1().Pods(namespace).GetLogs(podName, &corev1.PodLogOptions{ + Container: containerName, + }) + + podLogs, err := req.Stream(ctx) + if err != nil { + return "", err + } + defer podLogs.Close() + + buf := new(bytes.Buffer) + _, err = io.Copy(buf, podLogs) + if err != nil { + return "", err + } + + return buf.String(), nil +} diff --git a/internal/clients/k8s/secret.go b/internal/clients/k8s/secret.go new file mode 100644 index 0000000..0428e24 --- /dev/null +++ b/internal/clients/k8s/secret.go @@ -0,0 +1,89 @@ +package k8s + +import ( + "context" + "fmt" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// CreateSecret creates a Secret with the provided data +func (c *Client) CreateSecret(namespace, name string, data map[string][]byte, labels map[string]string) (*corev1.Secret, error) { + ctx := context.Background() + + if len(data) == 0 { + return nil, fmt.Errorf("no data provided for Secret") + } + + // Create default labels if none provided + if labels == nil { + labels = make(map[string]string) + } + + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Labels: labels, + }, + Data: data, + Type: corev1.SecretTypeOpaque, + } + + created, err := c.clientset.CoreV1().Secrets(namespace).Create(ctx, secret, metav1.CreateOptions{}) + if err != nil { + return nil, fmt.Errorf("failed to create Secret: %w", err) + } + + return created, nil +} + +// GetSecret retrieves a Secret by name +func (c *Client) GetSecret(namespace, name string) (*corev1.Secret, error) { + ctx := context.Background() + return 
c.clientset.CoreV1().Secrets(namespace).Get(ctx, name, metav1.GetOptions{}) +} + +// UpdateSecret updates an existing Secret with new data +func (c *Client) UpdateSecret(namespace, name string, data map[string][]byte) (*corev1.Secret, error) { + ctx := context.Background() + + if len(data) == 0 { + return nil, fmt.Errorf("no data provided for Secret update") + } + + // Get existing Secret + existing, err := c.GetSecret(namespace, name) + if err != nil { + return nil, fmt.Errorf("failed to get existing Secret: %w", err) + } + + // Update data + existing.Data = data + + updated, err := c.clientset.CoreV1().Secrets(namespace).Update(ctx, existing, metav1.UpdateOptions{}) + if err != nil { + return nil, fmt.Errorf("failed to update Secret: %w", err) + } + + return updated, nil +} + +// DeleteSecret deletes a Secret by name +func (c *Client) DeleteSecret(namespace, name string) error { + ctx := context.Background() + return c.clientset.CoreV1().Secrets(namespace).Delete(ctx, name, metav1.DeleteOptions{}) +} + +// EnsureSecret ensures a Secret exists with the provided data, creating or updating it as needed +func (c *Client) EnsureSecret(namespace, name string, data map[string][]byte, labels map[string]string) (*corev1.Secret, error) { + // Try to get existing Secret + _, err := c.GetSecret(namespace, name) + if err == nil { + // Secret exists, update it + return c.UpdateSecret(namespace, name, data) + } + + // Secret doesn't exist, create it + return c.CreateSecret(namespace, name, data, labels) +} diff --git a/internal/clients/k8s/secret_test.go b/internal/clients/k8s/secret_test.go new file mode 100644 index 0000000..732cba6 --- /dev/null +++ b/internal/clients/k8s/secret_test.go @@ -0,0 +1,296 @@ +package k8s + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" +) + +// TestClient_CreateSecret tests Secret creation +func TestClient_CreateSecret(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := &Client{clientset: fakeClient} + + data := map[string][]byte{ + "accesskey": []byte("test-access-key"), + "secretkey": []byte("test-secret-key"), + } + labels := map[string]string{ + "app": "backup", + } + + secret, err := client.CreateSecret("test-ns", "test-secret", data, labels) + + require.NoError(t, err) + assert.NotNil(t, secret) + assert.Equal(t, "test-secret", secret.Name) + assert.Equal(t, "test-ns", secret.Namespace) + assert.Equal(t, data, secret.Data) + assert.Equal(t, labels, secret.Labels) + assert.Equal(t, corev1.SecretTypeOpaque, secret.Type) + + // Verify it was created in the fake clientset + createdSecret, err := fakeClient.CoreV1().Secrets("test-ns").Get( + context.Background(), "test-secret", metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, data, createdSecret.Data) +} + +// TestClient_GetSecret tests Secret retrieval +func TestClient_GetSecret(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := &Client{clientset: fakeClient} + + // Create a Secret first + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-secret", + Namespace: "test-ns", + Labels: map[string]string{"app": "test"}, + }, + Data: map[string][]byte{ + "key": []byte("value"), + }, + Type: corev1.SecretTypeOpaque, + } + _, err := fakeClient.CoreV1().Secrets("test-ns").Create( + context.Background(), secret, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + 
// Get the Secret + retrievedSecret, err := client.GetSecret("test-ns", "test-secret") + + require.NoError(t, err) + assert.NotNil(t, retrievedSecret) + assert.Equal(t, "test-secret", retrievedSecret.Name) + assert.Equal(t, map[string][]byte{"key": []byte("value")}, retrievedSecret.Data) +} + +// TestClient_GetSecret_NotFound tests error when Secret doesn't exist +func TestClient_GetSecret_NotFound(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := &Client{clientset: fakeClient} + + _, err := client.GetSecret("test-ns", "nonexistent-secret") + + assert.Error(t, err) +} + +// TestClient_UpdateSecret tests Secret update +func TestClient_UpdateSecret(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := &Client{clientset: fakeClient} + + // Create a Secret first + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-secret", + Namespace: "test-ns", + }, + Data: map[string][]byte{ + "key": []byte("oldvalue"), + }, + Type: corev1.SecretTypeOpaque, + } + _, err := fakeClient.CoreV1().Secrets("test-ns").Create( + context.Background(), secret, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Update the Secret + newData := map[string][]byte{ + "key": []byte("newvalue"), + "key2": []byte("value2"), + } + updatedSecret, err := client.UpdateSecret("test-ns", "test-secret", newData) + + require.NoError(t, err) + assert.NotNil(t, updatedSecret) + assert.Equal(t, newData, updatedSecret.Data) + + // Verify the update in fake clientset + retrievedSecret, err := fakeClient.CoreV1().Secrets("test-ns").Get( + context.Background(), "test-secret", metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, newData, retrievedSecret.Data) +} + +// TestClient_UpdateSecret_NotFound tests error when Secret doesn't exist +func TestClient_UpdateSecret_NotFound(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := &Client{clientset: fakeClient} + + _, err := client.UpdateSecret("test-ns", "nonexistent-secret", map[string][]byte{"key": []byte("value")}) + + assert.Error(t, err) +} + +// TestClient_DeleteSecret tests Secret deletion +func TestClient_DeleteSecret(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := &Client{clientset: fakeClient} + + // Create a Secret first + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-secret", + Namespace: "test-ns", + }, + Data: map[string][]byte{ + "key": []byte("value"), + }, + Type: corev1.SecretTypeOpaque, + } + _, err := fakeClient.CoreV1().Secrets("test-ns").Create( + context.Background(), secret, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Delete the Secret + err = client.DeleteSecret("test-ns", "test-secret") + + require.NoError(t, err) + + // Verify it was deleted + _, err = fakeClient.CoreV1().Secrets("test-ns").Get( + context.Background(), "test-secret", metav1.GetOptions{}, + ) + assert.Error(t, err) +} + +// TestClient_EnsureSecret_Create tests EnsureSecret when Secret doesn't exist +func TestClient_EnsureSecret_Create(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := &Client{clientset: fakeClient} + + data := map[string][]byte{ + "accesskey": []byte("access"), + "secretkey": []byte("secret"), + } + labels := map[string]string{"app": "test"} + + secret, err := client.EnsureSecret("test-ns", "test-secret", data, labels) + + require.NoError(t, err) + assert.NotNil(t, secret) + assert.Equal(t, "test-secret", secret.Name) + assert.Equal(t, data, secret.Data) + assert.Equal(t, labels, secret.Labels) + + // 
Verify it was created + createdSecret, err := fakeClient.CoreV1().Secrets("test-ns").Get( + context.Background(), "test-secret", metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, data, createdSecret.Data) +} + +// TestClient_EnsureSecret_Update tests EnsureSecret when Secret exists +func TestClient_EnsureSecret_Update(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := &Client{clientset: fakeClient} + + // Create existing Secret + existingSecret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-secret", + Namespace: "test-ns", + }, + Data: map[string][]byte{ + "key": []byte("oldvalue"), + }, + Type: corev1.SecretTypeOpaque, + } + _, err := fakeClient.CoreV1().Secrets("test-ns").Create( + context.Background(), existingSecret, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Ensure with new data + newData := map[string][]byte{ + "key": []byte("newvalue"), + "key2": []byte("value2"), + } + labels := map[string]string{"app": "updated"} + + secret, err := client.EnsureSecret("test-ns", "test-secret", newData, labels) + + require.NoError(t, err) + assert.NotNil(t, secret) + assert.Equal(t, newData, secret.Data) + + // Verify it was updated + updatedSecret, err := fakeClient.CoreV1().Secrets("test-ns").Get( + context.Background(), "test-secret", metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, newData, updatedSecret.Data) +} + +// TestClient_EnsureSecret_NoChange tests EnsureSecret when data matches +func TestClient_EnsureSecret_NoChange(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := &Client{clientset: fakeClient} + + data := map[string][]byte{ + "key": []byte("value"), + } + + // Create existing Secret + existingSecret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-secret", + Namespace: "test-ns", + }, + Data: map[string][]byte{ + "key": []byte("value"), + }, + Type: corev1.SecretTypeOpaque, + } + _, err := fakeClient.CoreV1().Secrets("test-ns").Create( + context.Background(), existingSecret, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Ensure with same data + secret, err := client.EnsureSecret("test-ns", "test-secret", data, nil) + + require.NoError(t, err) + assert.NotNil(t, secret) + assert.Equal(t, data, secret.Data) +} + +// TestClient_Secret_SensitiveData tests that secret data is handled correctly +func TestClient_Secret_SensitiveData(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := &Client{clientset: fakeClient} + + // Test with various sensitive data types + data := map[string][]byte{ + "password": []byte("super-secret-password"), + "api-key": []byte("api-key-12345"), + "certificate": []byte("-----BEGIN CERTIFICATE-----\nMIIC..."), + "private-key": []byte("-----BEGIN PRIVATE KEY-----\nMIIE..."), + "token": []byte("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."), + "empty-value": []byte(""), + "special-chars": []byte("!@#$%^&*(){}[]|\\:;\"'<>,.?/~`"), + } + + secret, err := client.CreateSecret("test-ns", "sensitive-secret", data, nil) + + require.NoError(t, err) + assert.NotNil(t, secret) + + // Verify all data was stored correctly + for key, value := range data { + assert.Equal(t, value, secret.Data[key], "Data for key %s should match", key) + } +} diff --git a/internal/clients/k8s/utils.go b/internal/clients/k8s/utils.go new file mode 100644 index 0000000..9230c62 --- /dev/null +++ b/internal/clients/k8s/utils.go @@ -0,0 +1,20 @@ +package k8s + +// MergeLabels merges commonLabels with resource-specific labels. 
+// Resource-specific labels take precedence over commonLabels. +// Returns a new map with the merged labels. +func MergeLabels(commonLabels, resourceLabels map[string]string) map[string]string { + merged := make(map[string]string) + + // Add common labels first + for k, v := range commonLabels { + merged[k] = v + } + + // Override with resource-specific labels + for k, v := range resourceLabels { + merged[k] = v + } + + return merged +} diff --git a/internal/clients/k8s/utils_test.go b/internal/clients/k8s/utils_test.go new file mode 100644 index 0000000..eeba9c3 --- /dev/null +++ b/internal/clients/k8s/utils_test.go @@ -0,0 +1,236 @@ +package k8s + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestMergeLabels_BothEmpty tests merging with both maps empty +func TestMergeLabels_BothEmpty(t *testing.T) { + result := MergeLabels(map[string]string{}, map[string]string{}) + assert.Empty(t, result) +} + +// TestMergeLabels_BothNil tests merging with both maps nil +func TestMergeLabels_BothNil(t *testing.T) { + result := MergeLabels(nil, nil) + assert.NotNil(t, result) + assert.Empty(t, result) +} + +// TestMergeLabels_OnlyCommonLabels tests merging with only common labels +func TestMergeLabels_OnlyCommonLabels(t *testing.T) { + common := map[string]string{ + "app": "backup", + "version": "1.0", + "env": "prod", + } + + result := MergeLabels(common, nil) + + assert.Len(t, result, 3) + assert.Equal(t, "backup", result["app"]) + assert.Equal(t, "1.0", result["version"]) + assert.Equal(t, "prod", result["env"]) +} + +// TestMergeLabels_OnlyResourceLabels tests merging with only resource labels +func TestMergeLabels_OnlyResourceLabels(t *testing.T) { + resource := map[string]string{ + "component": "elasticsearch", + "tier": "backend", + } + + result := MergeLabels(nil, resource) + + assert.Len(t, result, 2) + assert.Equal(t, "elasticsearch", result["component"]) + assert.Equal(t, "backend", result["tier"]) +} + +// TestMergeLabels_NoOverlap tests merging with no overlapping keys +func TestMergeLabels_NoOverlap(t *testing.T) { + common := map[string]string{ + "app": "backup", + "version": "1.0", + } + resource := map[string]string{ + "component": "elasticsearch", + "tier": "backend", + } + + result := MergeLabels(common, resource) + + assert.Len(t, result, 4) + assert.Equal(t, "backup", result["app"]) + assert.Equal(t, "1.0", result["version"]) + assert.Equal(t, "elasticsearch", result["component"]) + assert.Equal(t, "backend", result["tier"]) +} + +// TestMergeLabels_WithOverlap tests that resource labels override common labels +func TestMergeLabels_WithOverlap(t *testing.T) { + common := map[string]string{ + "app": "backup", + "version": "1.0", + "env": "prod", + } + resource := map[string]string{ + "version": "2.0", // Override + "env": "staging", // Override + "tier": "backend", // New label + } + + result := MergeLabels(common, resource) + + assert.Len(t, result, 4) + assert.Equal(t, "backup", result["app"]) // From common + assert.Equal(t, "2.0", result["version"]) // Overridden by resource + assert.Equal(t, "staging", result["env"]) // Overridden by resource + assert.Equal(t, "backend", result["tier"]) // From resource +} + +// TestMergeLabels_OverrideAllCommon tests resource labels completely override common labels +func TestMergeLabels_OverrideAllCommon(t *testing.T) { + common := map[string]string{ + "app": "backup", + "env": "prod", + } + resource := map[string]string{ + "app": "restore", + "env": "dev", + } + + result := MergeLabels(common, resource) + + 
assert.Len(t, result, 2) + assert.Equal(t, "restore", result["app"]) // Overridden + assert.Equal(t, "dev", result["env"]) // Overridden +} + +// TestMergeLabels_EmptyStrings tests handling of empty string values +func TestMergeLabels_EmptyStrings(t *testing.T) { + common := map[string]string{ + "app": "backup", + "version": "", + } + resource := map[string]string{ + "env": "", + } + + result := MergeLabels(common, resource) + + assert.Len(t, result, 3) + assert.Equal(t, "backup", result["app"]) + assert.Equal(t, "", result["version"]) + assert.Equal(t, "", result["env"]) +} + +// TestMergeLabels_SpecialCharacters tests handling of special characters in keys and values +func TestMergeLabels_SpecialCharacters(t *testing.T) { + common := map[string]string{ + "app.kubernetes.io/name": "backup", + "app.kubernetes.io/version": "1.0", + "app.kubernetes.io/managed-by": "helm", + } + resource := map[string]string{ + "app.kubernetes.io/version": "2.0", // Override + "app.kubernetes.io/component": "elasticsearch", + } + + result := MergeLabels(common, resource) + + assert.Len(t, result, 4) + assert.Equal(t, "backup", result["app.kubernetes.io/name"]) + assert.Equal(t, "2.0", result["app.kubernetes.io/version"]) // Overridden + assert.Equal(t, "helm", result["app.kubernetes.io/managed-by"]) + assert.Equal(t, "elasticsearch", result["app.kubernetes.io/component"]) +} + +// TestMergeLabels_DoesNotModifyInput tests that input maps are not modified +func TestMergeLabels_DoesNotModifyInput(t *testing.T) { + common := map[string]string{ + "app": "backup", + "env": "prod", + } + resource := map[string]string{ + "version": "1.0", + } + + // Keep copies for comparison + commonCopy := make(map[string]string) + for k, v := range common { + commonCopy[k] = v + } + resourceCopy := make(map[string]string) + for k, v := range resource { + resourceCopy[k] = v + } + + result := MergeLabels(common, resource) + + // Verify input maps weren't modified + assert.Equal(t, commonCopy, common) + assert.Equal(t, resourceCopy, resource) + + // Verify result is independent + result["new-key"] = "new-value" + assert.NotContains(t, common, "new-key") + assert.NotContains(t, resource, "new-key") +} + +// TestMergeLabels_KubernetesLabels tests realistic Kubernetes label scenarios +func TestMergeLabels_KubernetesLabels(t *testing.T) { + tests := []struct { + name string + commonLabels map[string]string + resourceLabels map[string]string + expected map[string]string + }{ + { + name: "backup job labels", + commonLabels: map[string]string{ + "app.kubernetes.io/name": "suse-observability", + "app.kubernetes.io/managed-by": "helm", + "helm.sh/chart": "suse-observability-1.0.0", + }, + resourceLabels: map[string]string{ + "app.kubernetes.io/component": "backup", + "backup.stackstate.com/type": "elasticsearch", + }, + expected: map[string]string{ + "app.kubernetes.io/name": "suse-observability", + "app.kubernetes.io/managed-by": "helm", + "helm.sh/chart": "suse-observability-1.0.0", + "app.kubernetes.io/component": "backup", + "backup.stackstate.com/type": "elasticsearch", + }, + }, + { + name: "restore job labels with override", + commonLabels: map[string]string{ + "app": "backup-tool", + "app.kubernetes.io/version": "1.0", + "app.kubernetes.io/managed-by": "operator", + }, + resourceLabels: map[string]string{ + "app.kubernetes.io/version": "2.0", // Override version + "app.kubernetes.io/component": "restore", + }, + expected: map[string]string{ + "app": "backup-tool", + "app.kubernetes.io/version": "2.0", // Overridden + 
"app.kubernetes.io/managed-by": "operator", + "app.kubernetes.io/component": "restore", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := MergeLabels(tt.commonLabels, tt.resourceLabels) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/internal/clients/s3/client.go b/internal/clients/s3/client.go new file mode 100644 index 0000000..14ec89c --- /dev/null +++ b/internal/clients/s3/client.go @@ -0,0 +1,39 @@ +package s3 + +import ( + "context" + "fmt" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +const ( + // DefaultRegion is the default region for Minio S3 clients + DefaultRegion = "minio" +) + +// NewClient creates a new S3 client configured for Minio +// Returns an Interface that wraps the underlying AWS S3 client +func NewClient(endpoint, accessKey, secretKey string) (Interface, error) { + awsCfg, err := config.LoadDefaultConfig(context.Background(), + config.WithRegion(DefaultRegion), + config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider( + accessKey, + secretKey, + "", + )), + ) + if err != nil { + return nil, fmt.Errorf("failed to create AWS config: %w", err) + } + + s3Client := s3.NewFromConfig(awsCfg, func(o *s3.Options) { + o.BaseEndpoint = aws.String(endpoint) + o.UsePathStyle = true + }) + + return &Client{client: s3Client}, nil +} diff --git a/internal/clients/s3/client_test.go b/internal/clients/s3/client_test.go new file mode 100644 index 0000000..2535699 --- /dev/null +++ b/internal/clients/s3/client_test.go @@ -0,0 +1,243 @@ +package s3 + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestNewClient tests S3 client creation with various configurations +func TestNewClient(t *testing.T) { + tests := []struct { + name string + endpoint string + accessKey string + secretKey string + expectError bool + }{ + { + name: "valid minio configuration", + endpoint: "http://minio:9000", + accessKey: "minioadmin", + secretKey: "minioadmin", + expectError: false, + }, + { + name: "valid configuration with IP", + endpoint: "http://192.168.1.100:9000", + accessKey: "test-access-key", + secretKey: "test-secret-key", + expectError: false, + }, + { + name: "valid configuration with https", + endpoint: "https://s3.example.com", + accessKey: "access123", + secretKey: "secret123", + expectError: false, + }, + { + name: "valid configuration with localhost", + endpoint: "http://localhost:9000", + accessKey: "local-access", + secretKey: "local-secret", + expectError: false, + }, + { + name: "empty endpoint", + endpoint: "", + accessKey: "access", + secretKey: "secret", + expectError: false, // AWS SDK allows empty endpoint (uses default) + }, + { + name: "empty credentials", + endpoint: "http://minio:9000", + accessKey: "", + secretKey: "", + expectError: false, // Client creation succeeds, but operations will fail + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client, err := NewClient(tt.endpoint, tt.accessKey, tt.secretKey) + + if tt.expectError { + assert.Error(t, err) + assert.Nil(t, client) + } else { + require.NoError(t, err) + assert.NotNil(t, client) + } + }) + } +} + +// TestNewClient_ClientConfiguration tests that the client is configured correctly +func TestNewClient_ClientConfiguration(t *testing.T) { + endpoint := "http://test-minio:9000" + accessKey := "test-access" + secretKey := 
"test-secret" + + client, err := NewClient(endpoint, accessKey, secretKey) + + require.NoError(t, err) + require.NotNil(t, client) + + // Verify client was created (we can't easily inspect internal config without integration tests) + // But we can verify the client is not nil and has the expected type + assert.IsType(t, client, client) +} + +// TestDefaultRegion tests the default region constant +func TestDefaultRegion(t *testing.T) { + assert.Equal(t, "minio", DefaultRegion, "Default region should be 'minio' for Minio compatibility") +} + +// TestNewClient_Integration demonstrates integration test pattern +// This test is skipped by default and requires a real Minio instance +func TestNewClient_Integration(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + // This would require a real Minio instance running + // Uncomment and configure when integration testing is needed + t.Skip("integration test requires real Minio instance") + + // Example integration test pattern: + // client, err := NewClient("http://localhost:9000", "minioadmin", "minioadmin") + // require.NoError(t, err) + // + // ctx := context.Background() + // _, err = client.ListBuckets(ctx, &s3.ListBucketsInput{}) + // assert.NoError(t, err, "Should be able to list buckets with valid client") +} + +// TestNewClient_CredentialFormats tests various credential format edge cases +func TestNewClient_CredentialFormats(t *testing.T) { + tests := []struct { + name string + accessKey string + secretKey string + }{ + { + name: "alphanumeric credentials", + accessKey: "abc123XYZ", + secretKey: "xyz789ABC", + }, + { + name: "credentials with special characters", + accessKey: "access+key/with=chars", + secretKey: "secret-key_with.chars", + }, + { + name: "very long credentials", + accessKey: "very-long-access-key-with-many-characters-0123456789", + secretKey: "very-long-secret-key-with-many-characters-9876543210", + }, + { + name: "credentials with spaces (edge case)", + accessKey: "access key with spaces", + secretKey: "secret key with spaces", + }, + { + name: "single character credentials", + accessKey: "a", + secretKey: "s", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client, err := NewClient("http://minio:9000", tt.accessKey, tt.secretKey) + + // Client creation should succeed regardless of credential format + // The credentials will be validated when actual S3 operations are performed + assert.NoError(t, err) + assert.NotNil(t, client) + }) + } +} + +// TestNewClient_EndpointFormats tests various endpoint format edge cases +func TestNewClient_EndpointFormats(t *testing.T) { + tests := []struct { + name string + endpoint string + }{ + { + name: "endpoint with http scheme", + endpoint: "http://minio.example.com:9000", + }, + { + name: "endpoint with https scheme", + endpoint: "https://s3.amazonaws.com", + }, + { + name: "endpoint without port", + endpoint: "http://minio.local", + }, + { + name: "endpoint with non-standard port", + endpoint: "http://minio:8080", + }, + { + name: "endpoint with path", + endpoint: "http://minio:9000/path/to/s3", + }, + { + name: "endpoint as IP address", + endpoint: "http://127.0.0.1:9000", + }, + { + name: "endpoint as IPv6 address", + endpoint: "http://[::1]:9000", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client, err := NewClient(tt.endpoint, "access", "secret") + + // Client creation should succeed for all valid endpoint formats + assert.NoError(t, err) + assert.NotNil(t, client) + 
})
+	}
+}
+
+// TestNewClient_ConcurrentCreation tests that client creation is safe for concurrent use
+func TestNewClient_ConcurrentCreation(t *testing.T) {
+	const numGoroutines = 10
+
+	done := make(chan bool, numGoroutines)
+	errors := make(chan error, numGoroutines)
+
+	for i := 0; i < numGoroutines; i++ {
+		go func() {
+			_, err := NewClient("http://minio:9000", "access", "secret")
+			if err != nil {
+				errors <- err
+			}
+			done <- true
+		}()
+	}
+
+	// Wait for all goroutines to complete
+	for i := 0; i < numGoroutines; i++ {
+		<-done
+	}
+	close(errors)
+
+	// Check that no errors occurred
+	errorCount := 0
+	for err := range errors {
+		t.Errorf("Concurrent client creation failed: %v", err)
+		errorCount++
+	}
+
+	assert.Equal(t, 0, errorCount, "All concurrent client creations should succeed")
+}
diff --git a/internal/clients/s3/filter.go b/internal/clients/s3/filter.go
new file mode 100644
index 0000000..32cafac
--- /dev/null
+++ b/internal/clients/s3/filter.go
@@ -0,0 +1,142 @@
+package s3
+
+import (
+	"strings"
+	"time"
+
+	"github.com/aws/aws-sdk-go-v2/aws"
+	s3types "github.com/aws/aws-sdk-go-v2/service/s3/types"
+)
+
+const (
+	multipartArchiveSuffixLength = 2
+)
+
+// Object represents a simplified S3 object with key metadata
+type Object struct {
+	Key          string
+	LastModified time.Time
+	Size         int64
+}
+
+// FilterBackupObjects filters S3 objects based on whether the backup archive is split into parts.
+// When multipartArchive is false, it filters out multipart archives (files ending with .digits).
+// When multipartArchive is true, it groups multipart archives by base name and sums their sizes.
+func FilterBackupObjects(objects []s3types.Object, multipartArchive bool) []Object {
+	if !multipartArchive {
+		return filterNonMultipart(objects)
+	}
+	return aggregateMultipart(objects)
+}
+
+// filterNonMultipart filters out multipart archives (files ending with .digits)
+func filterNonMultipart(objects []s3types.Object) []Object {
+	var filteredObjects []Object
+
+	for _, obj := range objects {
+		key := aws.ToString(obj.Key)
+
+		// Skip if it ends with .digits (multipart archive)
+		if strings.Contains(key, ".") {
+			parts := strings.Split(key, ".")
+			lastPart := parts[len(parts)-1]
+			isDigits := true
+			for _, c := range lastPart {
+				if c < '0' || c > '9' {
+					isDigits = false
+					break
+				}
+			}
+			if isDigits && len(lastPart) > 0 {
+				continue
+			}
+		}
+
+		filteredObjects = append(filteredObjects, Object{
+			Key:          key,
+			LastModified: aws.ToTime(obj.LastModified),
+			Size:         aws.ToInt64(obj.Size),
+		})
+	}
+
+	return filteredObjects
+}
+
+// aggregateMultipart groups multipart archives by base name and sums their sizes
+func aggregateMultipart(objects []s3types.Object) []Object {
+	// Map to group objects by base name
+	archiveMap := make(map[string]*Object)
+
+	for _, obj := range objects {
+		key := aws.ToString(obj.Key)
+
+		// Check if this is a multipart file (ends with .NN where NN are digits)
+		baseName, isMultipart := getBaseName(key)
+		if !isMultipart {
+			// Not a multipart file, include as-is
+			archiveMap[key] = &Object{
+				Key:          key,
+				LastModified: aws.ToTime(obj.LastModified),
+				Size:         aws.ToInt64(obj.Size),
+			}
+			continue
+		}
+
+		// Group multipart files by base name
+		if existing, exists := archiveMap[baseName]; exists {
+			// Add size to existing entry
+			existing.Size += aws.ToInt64(obj.Size)
+			// Keep the most recent LastModified time
+			if aws.ToTime(obj.LastModified).After(existing.LastModified) {
+				existing.LastModified = aws.ToTime(obj.LastModified)
+			}
+		} else {
+			// Create new entry
+			archiveMap[baseName] = 
&Object{ + Key: baseName, + LastModified: aws.ToTime(obj.LastModified), + Size: aws.ToInt64(obj.Size), + } + } + } + + // Convert map to slice + var filteredObjects []Object + for _, obj := range archiveMap { + filteredObjects = append(filteredObjects, *obj) + } + + return filteredObjects +} + +// getBaseName extracts the base name from a multipart archive filename +// Returns (baseName, isMultipart) +// Example: "backup.graph.00" -> ("backup.graph", true) +// +// "backup.graph" -> ("backup.graph", false) +func getBaseName(key string) (string, bool) { + if !strings.Contains(key, ".") { + return key, false + } + + parts := strings.Split(key, ".") + lastPart := parts[len(parts)-1] + + // Check if last part is all digits (2 digits for part numbers like .00, .01, etc.) + if len(lastPart) == multipartArchiveSuffixLength { + isDigits := true + for _, c := range lastPart { + if c < '0' || c > '9' { + isDigits = false + break + } + } + if isDigits { + // Remove the .NN suffix to get base name + baseName := strings.Join(parts[:len(parts)-1], ".") + return baseName, true + } + } + + return key, false +} diff --git a/internal/clients/s3/filter_test.go b/internal/clients/s3/filter_test.go new file mode 100644 index 0000000..2cf9b86 --- /dev/null +++ b/internal/clients/s3/filter_test.go @@ -0,0 +1,366 @@ +package s3 + +import ( + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" + "github.com/stretchr/testify/assert" +) + +// TestFilterBackupObjects_SingleFileMode tests filtering when it is not multipartArchive +func TestFilterBackupObjects_SingleFileMode(t *testing.T) { + tests := []struct { + name string + objects []s3types.Object + expectedCount int + expectedKeys []string + }{ + { + name: "filters out multipart archives with numeric extensions", + objects: []s3types.Object{ + {Key: aws.String("backup-2024-01-01.tar.gz"), Size: aws.Int64(1000)}, + {Key: aws.String("backup-2024-01-02.00"), Size: aws.Int64(1000)}, + {Key: aws.String("backup-2024-01-02.01"), Size: aws.Int64(1000)}, + {Key: aws.String("backup-2024-01-02.02"), Size: aws.Int64(1000)}, + {Key: aws.String("backup-2024-01-03.tar"), Size: aws.Int64(1000)}, + }, + expectedCount: 2, + expectedKeys: []string{"backup-2024-01-01.tar.gz", "backup-2024-01-03.tar"}, + }, + { + name: "includes files without extensions", + objects: []s3types.Object{ + {Key: aws.String("backup-no-extension"), Size: aws.Int64(1000)}, + {Key: aws.String("backup-with-ext.tar"), Size: aws.Int64(2000)}, + }, + expectedCount: 2, + expectedKeys: []string{"backup-no-extension", "backup-with-ext.tar"}, + }, + { + name: "includes files with non-numeric extensions", + objects: []s3types.Object{ + {Key: aws.String("backup.tar.gz"), Size: aws.Int64(1000)}, + {Key: aws.String("backup.log"), Size: aws.Int64(100)}, + {Key: aws.String("backup.00"), Size: aws.Int64(1000)}, + {Key: aws.String("backup.abc"), Size: aws.Int64(1000)}, + }, + expectedCount: 3, + expectedKeys: []string{"backup.tar.gz", "backup.log", "backup.abc"}, + }, + { + name: "handles edge case with single digit extension", + objects: []s3types.Object{ + {Key: aws.String("backup.0"), Size: aws.Int64(1000)}, + {Key: aws.String("backup.1"), Size: aws.Int64(1000)}, + {Key: aws.String("backup.tar"), Size: aws.Int64(2000)}, + }, + expectedCount: 1, + expectedKeys: []string{"backup.tar"}, + }, + { + name: "handles multiple dots in filename", + objects: []s3types.Object{ + {Key: aws.String("backup.2024.01.01.tar.gz"), Size: aws.Int64(1000)}, + {Key: 
aws.String("backup.2024.01.02.00"), Size: aws.Int64(1000)}, + {Key: aws.String("backup.2024.01.03.999"), Size: aws.Int64(1000)}, + }, + expectedCount: 1, + expectedKeys: []string{"backup.2024.01.01.tar.gz"}, + }, + { + name: "handles empty object list", + objects: []s3types.Object{}, + expectedCount: 0, + expectedKeys: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := FilterBackupObjects(tt.objects, false) + + assert.Equal(t, tt.expectedCount, len(result)) + + resultKeys := make([]string, len(result)) + for i, obj := range result { + resultKeys[i] = obj.Key + } + + assert.Equal(t, tt.expectedKeys, resultKeys) + }) + } +} + +// TestFilterBackupObjects_MultipartMode tests filtering when it is multipartArchive +func TestFilterBackupObjects_MultipartMode(t *testing.T) { + tests := []struct { + name string + objects []s3types.Object + multipartArchive bool + expectedCount int + expectedKeys []string + }{ + { + name: "groups multipart archives and sums their sizes", + objects: []s3types.Object{ + {Key: aws.String("backup-2024-01-01.00"), Size: aws.Int64(500000000)}, + {Key: aws.String("backup-2024-01-01.01"), Size: aws.Int64(500000000)}, + {Key: aws.String("backup-2024-01-01.02"), Size: aws.Int64(300000000)}, + {Key: aws.String("backup-2024-01-02.00"), Size: aws.Int64(500000000)}, + {Key: aws.String("backup-2024-01-02.01"), Size: aws.Int64(400000000)}, + }, + multipartArchive: true, + expectedCount: 2, + expectedKeys: []string{"backup-2024-01-01", "backup-2024-01-02"}, + }, + { + name: "includes both multipart and single files", + objects: []s3types.Object{ + {Key: aws.String("backup.tar.gz"), Size: aws.Int64(1000)}, + {Key: aws.String("backup-split.00"), Size: aws.Int64(500000000)}, + {Key: aws.String("backup-split.01"), Size: aws.Int64(500000000)}, + {Key: aws.String("backup-single"), Size: aws.Int64(1000000)}, + }, + multipartArchive: true, + expectedCount: 3, + expectedKeys: []string{"backup.tar.gz", "backup-split", "backup-single"}, + }, + { + name: "handles different split size values", + objects: []s3types.Object{ + {Key: aws.String("backup-1G.00"), Size: aws.Int64(1000000000)}, + {Key: aws.String("backup-1G.01"), Size: aws.Int64(500000000)}, + }, + multipartArchive: true, + expectedCount: 1, + expectedKeys: []string{"backup-1G"}, + }, + { + name: "handles empty object list", + objects: []s3types.Object{}, + multipartArchive: true, + expectedCount: 0, + expectedKeys: []string{}, + }, + { + name: "handles objects with .00 in middle of filename", + objects: []s3types.Object{ + {Key: aws.String("backup.00.tar"), Size: aws.Int64(1000)}, + {Key: aws.String("backup-final.00"), Size: aws.Int64(500000000)}, + }, + multipartArchive: true, + expectedCount: 2, + expectedKeys: []string{"backup.00.tar", "backup-final"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := FilterBackupObjects(tt.objects, tt.multipartArchive) + + assert.Equal(t, tt.expectedCount, len(result)) + + resultKeys := make([]string, len(result)) + for i, obj := range result { + resultKeys[i] = obj.Key + } + + assert.ElementsMatch(t, tt.expectedKeys, resultKeys) + }) + } +} + +// TestFilterBackupObjects_ObjectMetadata tests that metadata is preserved correctly +func TestFilterBackupObjects_ObjectMetadata(t *testing.T) { + now := time.Now() + yesterday := now.Add(-24 * time.Hour) + + objects := []s3types.Object{ + { + Key: aws.String("backup-2024-01-01.tar.gz"), + Size: aws.Int64(1234567890), + LastModified: aws.Time(now), + }, + { + Key: 
aws.String("backup-2024-01-02.00"), + Size: aws.Int64(100000000), + LastModified: aws.Time(yesterday), + }, + { + Key: aws.String("backup-2024-01-02.01"), + Size: aws.Int64(50000000), + LastModified: aws.Time(yesterday.Add(1 * time.Minute)), // Slightly later + }, + } + + // Test single file mode + result := FilterBackupObjects(objects, false) + assert.Equal(t, 1, len(result)) + assert.Equal(t, "backup-2024-01-01.tar.gz", result[0].Key) + assert.Equal(t, int64(1234567890), result[0].Size) + assert.Equal(t, now.Unix(), result[0].LastModified.Unix()) + + // Test multipart mode - should group parts and sum sizes + result = FilterBackupObjects(objects, true) + assert.Equal(t, 2, len(result)) // tar.gz file + grouped multipart + + // Find the multipart archive result + var multipartResult *Object + var singleResult *Object + for i := range result { + switch result[i].Key { + case "backup-2024-01-02": + multipartResult = &result[i] + case "backup-2024-01-01.tar.gz": + singleResult = &result[i] + } + } + + assert.NotNil(t, multipartResult, "Should find grouped multipart archive") + assert.NotNil(t, singleResult, "Should find single file") + + // Verify multipart archive has summed size + assert.Equal(t, "backup-2024-01-02", multipartResult.Key) + assert.Equal(t, int64(150000000), multipartResult.Size) // 100M + 50M + assert.Equal(t, yesterday.Add(1*time.Minute).Unix(), multipartResult.LastModified.Unix()) // Most recent timestamp + + // Verify single file + assert.Equal(t, "backup-2024-01-01.tar.gz", singleResult.Key) + assert.Equal(t, int64(1234567890), singleResult.Size) +} + +// TestFilterBackupObjects_EdgeCases tests edge cases and boundary conditions +func TestFilterBackupObjects_EdgeCases(t *testing.T) { + tests := []struct { + name string + objects []s3types.Object + multipartArchive bool + expectedCount int + }{ + { + name: "file ending with just a dot", + objects: []s3types.Object{ + {Key: aws.String("backup."), Size: aws.Int64(1000)}, + }, + multipartArchive: false, + expectedCount: 1, // Empty extension, should be included + }, + { + name: "file with mixed alphanumeric extension", + objects: []s3types.Object{ + {Key: aws.String("backup.123abc"), Size: aws.Int64(1000)}, + }, + multipartArchive: false, + expectedCount: 1, // Contains non-digits, should be included + }, + { + name: "very long numeric extension", + objects: []s3types.Object{ + {Key: aws.String("backup.00000000000000000001"), Size: aws.Int64(1000)}, + }, + multipartArchive: false, + expectedCount: 0, // All digits, should be filtered + }, + { + name: "multipart with .00 but it is not multipartArchive", + objects: []s3types.Object{ + {Key: aws.String("backup.00"), Size: aws.Int64(1000)}, + }, + multipartArchive: false, + expectedCount: 0, // Numeric extension in single file mode, should be filtered + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := FilterBackupObjects(tt.objects, tt.multipartArchive) + assert.Equal(t, tt.expectedCount, len(result)) + }) + } +} + +// TestFilterBackupObjects_RealWorldScenarios tests realistic backup scenarios +func TestFilterBackupObjects_RealWorldScenarios(t *testing.T) { + tests := []struct { + name string + scenario string + objects []s3types.Object + multipartArchive bool + expectedCount int + }{ + { + name: "stackgraph backups without splitting", + scenario: "Single large archive files", + objects: []s3types.Object{ + {Key: aws.String("stackgraph-backup-2024-01-01.tar.gz"), Size: aws.Int64(10000000)}, + {Key: 
aws.String("stackgraph-backup-2024-01-02.tar.gz"), Size: aws.Int64(12000000)}, + {Key: aws.String("stackgraph-backup-2024-01-03.tar.gz"), Size: aws.Int64(11000000)}, + }, + multipartArchive: false, + expectedCount: 3, + }, + { + name: "stackgraph backups with 500M splitting", + scenario: "Split archives at 500M", + objects: []s3types.Object{ + {Key: aws.String("stackgraph-backup-2024-01-01.00"), Size: aws.Int64(500000000)}, + {Key: aws.String("stackgraph-backup-2024-01-01.01"), Size: aws.Int64(500000000)}, + {Key: aws.String("stackgraph-backup-2024-01-01.02"), Size: aws.Int64(200000000)}, + {Key: aws.String("stackgraph-backup-2024-01-02.00"), Size: aws.Int64(500000000)}, + {Key: aws.String("stackgraph-backup-2024-01-02.01"), Size: aws.Int64(300000000)}, + }, + multipartArchive: true, + expectedCount: 2, // Two grouped multipart archives + }, + { + name: "mixed backup types in same bucket", + scenario: "Single and split backups mixed", + objects: []s3types.Object{ + {Key: aws.String("backup-old.tar.gz"), Size: aws.Int64(1000000)}, + {Key: aws.String("backup-new-split.00"), Size: aws.Int64(500000000)}, + {Key: aws.String("backup-new-split.01"), Size: aws.Int64(400000000)}, + {Key: aws.String("backup-old-split.00"), Size: aws.Int64(1000)}, + {Key: aws.String("backup-old-split.01"), Size: aws.Int64(1000)}, + }, + multipartArchive: true, + expectedCount: 3, // One single file + two grouped multipart archives + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := FilterBackupObjects(tt.objects, tt.multipartArchive) + assert.Equal(t, tt.expectedCount, len(result), "Scenario: %s", tt.scenario) + }) + } +} + +// TestFilterBackupObjects_SizeSummation tests that sizes are correctly summed for multipart archives +func TestFilterBackupObjects_SizeSummation(t *testing.T) { + objects := []s3types.Object{ + {Key: aws.String("sts-backup-20251028-1546.graph.00"), Size: aws.Int64(104857600)}, + {Key: aws.String("sts-backup-20251028-1546.graph.01"), Size: aws.Int64(6885342)}, + {Key: aws.String("sts-backup-20251029-0300.graph.00"), Size: aws.Int64(104857600)}, + {Key: aws.String("sts-backup-20251029-0300.graph.01"), Size: aws.Int64(4348555)}, + {Key: aws.String("sts-backup-20251029-0924.graph.00"), Size: aws.Int64(104857600)}, + {Key: aws.String("sts-backup-20251029-0924.graph.01"), Size: aws.Int64(6567239)}, + } + + result := FilterBackupObjects(objects, true) + + // Should have 3 grouped archives + assert.Equal(t, 3, len(result)) + + // Create a map for easier lookup + sizeMap := make(map[string]int64) + for _, obj := range result { + sizeMap[obj.Key] = obj.Size + } + + // Verify sizes are correctly summed + assert.Equal(t, int64(111742942), sizeMap["sts-backup-20251028-1546.graph"]) // 104857600 + 6885342 + assert.Equal(t, int64(109206155), sizeMap["sts-backup-20251029-0300.graph"]) // 104857600 + 4348555 + assert.Equal(t, int64(111424839), sizeMap["sts-backup-20251029-0924.graph"]) // 104857600 + 6567239 +} diff --git a/internal/clients/s3/interface.go b/internal/clients/s3/interface.go new file mode 100644 index 0000000..5403eef --- /dev/null +++ b/internal/clients/s3/interface.go @@ -0,0 +1,44 @@ +package s3 + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +// Interface defines the contract for S3 operations +// This interface allows mocking S3 operations in tests +type Interface interface { + ListObjectsV2(ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options)) (*s3.ListObjectsV2Output, error) + GetObject(ctx 
context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) + PutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error) + DeleteObject(ctx context.Context, params *s3.DeleteObjectInput, optFns ...func(*s3.Options)) (*s3.DeleteObjectOutput, error) +} + +// Client wraps the AWS S3 client to implement our interface +type Client struct { + client *s3.Client +} + +// Ensure Client implements Interface at compile time +var _ Interface = (*Client)(nil) + +// ListObjectsV2 lists objects in an S3 bucket +func (c *Client) ListObjectsV2(ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options)) (*s3.ListObjectsV2Output, error) { + return c.client.ListObjectsV2(ctx, params, optFns...) +} + +// GetObject retrieves an object from S3 +func (c *Client) GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) { + return c.client.GetObject(ctx, params, optFns...) +} + +// PutObject uploads an object to S3 +func (c *Client) PutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error) { + return c.client.PutObject(ctx, params, optFns...) +} + +// DeleteObject deletes an object from S3 +func (c *Client) DeleteObject(ctx context.Context, params *s3.DeleteObjectInput, optFns ...func(*s3.Options)) (*s3.DeleteObjectOutput, error) { + return c.client.DeleteObject(ctx, params, optFns...) +} diff --git a/internal/clients/s3/interface_test.go b/internal/clients/s3/interface_test.go new file mode 100644 index 0000000..053354d --- /dev/null +++ b/internal/clients/s3/interface_test.go @@ -0,0 +1,171 @@ +package s3 + +import ( + "context" + "testing" + + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/stretchr/testify/assert" +) + +// TestClientImplementsInterface verifies that Client implements Interface at compile time +func TestClientImplementsInterface(_ *testing.T) { + var _ Interface = (*Client)(nil) +} + +// TestInterfaceContract verifies that Client correctly wraps AWS S3 client methods +func TestInterfaceContract(t *testing.T) { + // Create a client + client, err := NewClient("http://test-minio:9000", "test-access", "test-secret") + assert.NoError(t, err) + assert.NotNil(t, client) + + // Verify it's specifically a *Client that implements Interface + _, ok := client.(*Client) + assert.True(t, ok, "NewClient should return a *Client implementing Interface") +} + +// TestClientMethods verifies that all interface methods are implemented +// Note: These tests don't call real S3 - they just verify the methods exist +func TestClientMethods(t *testing.T) { + client, err := NewClient("http://test-minio:9000", "test-access", "test-secret") + assert.NoError(t, err) + require := assert.New(t) + + ctx := context.Background() + + // Test that methods exist and can be called (will fail without real S3, but that's expected) + // We're just verifying the interface contract + + t.Run("ListObjectsV2 method exists", func(_ *testing.T) { + require.NotPanics(func() { + _, _ = client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{}) + }) + }) + + t.Run("GetObject method exists", func(_ *testing.T) { + require.NotPanics(func() { + _, _ = client.GetObject(ctx, &s3.GetObjectInput{}) + }) + }) + + t.Run("PutObject method exists", func(_ *testing.T) { + require.NotPanics(func() { + _, _ = client.PutObject(ctx, &s3.PutObjectInput{}) + }) + }) + + t.Run("DeleteObject method exists", func(_ *testing.T) { + 
require.NotPanics(func() { + _, _ = client.DeleteObject(ctx, &s3.DeleteObjectInput{}) + }) + }) +} + +// MockS3Client is a mock implementation of the S3 Interface for testing +type MockS3Client struct { + ListObjectsV2Func func(ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options)) (*s3.ListObjectsV2Output, error) + GetObjectFunc func(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) + PutObjectFunc func(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error) + DeleteObjectFunc func(ctx context.Context, params *s3.DeleteObjectInput, optFns ...func(*s3.Options)) (*s3.DeleteObjectOutput, error) +} + +// Ensure MockS3Client implements Interface +var _ Interface = (*MockS3Client)(nil) + +// ListObjectsV2 delegates to the mock function +func (m *MockS3Client) ListObjectsV2(ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options)) (*s3.ListObjectsV2Output, error) { + if m.ListObjectsV2Func != nil { + return m.ListObjectsV2Func(ctx, params, optFns...) + } + return &s3.ListObjectsV2Output{}, nil +} + +// GetObject delegates to the mock function +func (m *MockS3Client) GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) { + if m.GetObjectFunc != nil { + return m.GetObjectFunc(ctx, params, optFns...) + } + return &s3.GetObjectOutput{}, nil +} + +// PutObject delegates to the mock function +func (m *MockS3Client) PutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error) { + if m.PutObjectFunc != nil { + return m.PutObjectFunc(ctx, params, optFns...) + } + return &s3.PutObjectOutput{}, nil +} + +// DeleteObject delegates to the mock function +func (m *MockS3Client) DeleteObject(ctx context.Context, params *s3.DeleteObjectInput, optFns ...func(*s3.Options)) (*s3.DeleteObjectOutput, error) { + if m.DeleteObjectFunc != nil { + return m.DeleteObjectFunc(ctx, params, optFns...) 
+ } + return &s3.DeleteObjectOutput{}, nil +} + +// TestMockS3Client verifies that our mock implementation works correctly +func TestMockS3Client(t *testing.T) { + ctx := context.Background() + + t.Run("mock with custom function", func(t *testing.T) { + called := false + mock := &MockS3Client{ + ListObjectsV2Func: func(_ context.Context, _ *s3.ListObjectsV2Input, _ ...func(*s3.Options)) (*s3.ListObjectsV2Output, error) { + called = true + return &s3.ListObjectsV2Output{}, nil + }, + } + + _, err := mock.ListObjectsV2(ctx, &s3.ListObjectsV2Input{}) + assert.NoError(t, err) + assert.True(t, called, "Mock function should have been called") + }) + + t.Run("mock with default behavior", func(t *testing.T) { + mock := &MockS3Client{} + + // Should not panic and return empty output + output, err := mock.ListObjectsV2(ctx, &s3.ListObjectsV2Input{}) + assert.NoError(t, err) + assert.NotNil(t, output) + }) + + t.Run("mock implements interface", func(t *testing.T) { + var client Interface = &MockS3Client{} + assert.NotNil(t, client) + }) +} + +// TestInterfaceUsage demonstrates how the interface enables dependency injection +func TestInterfaceUsage(t *testing.T) { + // Example function that accepts the interface + listBucketContents := func(client Interface, bucket string) error { + _, err := client.ListObjectsV2(context.Background(), &s3.ListObjectsV2Input{ + Bucket: &bucket, + }) + return err + } + + t.Run("with real client", func(t *testing.T) { + client, err := NewClient("http://test:9000", "access", "secret") + assert.NoError(t, err) + + // Function can accept real client (will fail without real S3, but that's ok) + _ = listBucketContents(client, "test-bucket") + }) + + t.Run("with mock client", func(t *testing.T) { + mock := &MockS3Client{ + ListObjectsV2Func: func(_ context.Context, _ *s3.ListObjectsV2Input, _ ...func(*s3.Options)) (*s3.ListObjectsV2Output, error) { + // Simulate successful response + return &s3.ListObjectsV2Output{}, nil + }, + } + + // Function can accept mock client for testing + err := listBucketContents(mock, "test-bucket") + assert.NoError(t, err) + }) +} diff --git a/internal/config/config.go b/internal/config/config.go deleted file mode 100644 index d196400..0000000 --- a/internal/config/config.go +++ /dev/null @@ -1,139 +0,0 @@ -// Package config provides configuration management for the backup CLI tool. -// It supports loading configuration from Kubernetes ConfigMaps and Secrets -// with a merge strategy that allows ConfigMap to be overridden by Secret. 
-package config - -import ( - "context" - "fmt" - - "dario.cat/mergo" - "github.com/go-playground/validator/v10" - "gopkg.in/yaml.v3" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/client-go/kubernetes" -) - -// Config represents the merged configuration from ConfigMap and Secret -type Config struct { - Elasticsearch ElasticsearchConfig `yaml:"elasticsearch" validate:"required"` -} - -// ElasticsearchConfig holds Elasticsearch-specific configuration -type ElasticsearchConfig struct { - Service ServiceConfig `yaml:"service" validate:"required"` - Restore RestoreConfig `yaml:"restore" validate:"required"` - SnapshotRepository SnapshotRepositoryConfig `yaml:"snapshotRepository" validate:"required"` - SLM SLMConfig `yaml:"slm" validate:"required"` -} - -// RestoreConfig holds restore-specific configuration -type RestoreConfig struct { - ScaleDownLabelSelector string `yaml:"scaleDownLabelSelector" validate:"required"` - IndexPrefix string `yaml:"indexPrefix" validate:"required"` - DatastreamIndexPrefix string `yaml:"datastreamIndexPrefix" validate:"required"` - DatastreamName string `yaml:"datastreamName" validate:"required"` - IndicesPattern string `yaml:"indicesPattern" validate:"required"` - Repository string `yaml:"repository" validate:"required"` -} - -// SnapshotRepositoryConfig holds snapshot repository configuration -type SnapshotRepositoryConfig struct { - Name string `yaml:"name" validate:"required"` - Bucket string `yaml:"bucket" validate:"required"` - Endpoint string `yaml:"endpoint" validate:"required"` - BasePath string `yaml:"basepath"` - AccessKey string `yaml:"accessKey" validate:"required"` // From secret - SecretKey string `yaml:"secretKey" validate:"required"` // From secret -} - -// SLMConfig holds Snapshot Lifecycle Management configuration -type SLMConfig struct { - Name string `yaml:"name" validate:"required"` - Schedule string `yaml:"schedule" validate:"required"` - SnapshotTemplateName string `yaml:"snapshotTemplateName" validate:"required"` - Repository string `yaml:"repository" validate:"required"` - Indices string `yaml:"indices" validate:"required"` - RetentionExpireAfter string `yaml:"retentionExpireAfter" validate:"required"` - RetentionMinCount int `yaml:"retentionMinCount" validate:"required,min=1"` - RetentionMaxCount int `yaml:"retentionMaxCount" validate:"required,min=1"` -} - -// ServiceConfig holds service connection details -type ServiceConfig struct { - Name string `yaml:"name" validate:"required"` - Port int `yaml:"port" validate:"required,min=1,max=65535"` - LocalPortForwardPort int `yaml:"localPortForwardPort" validate:"required,min=1,max=65535"` -} - -// LoadConfig loads and merges configuration from ConfigMap and Secret -// ConfigMap provides base configuration, Secret overrides it -// All required fields must be present after merging, validated with validator -func LoadConfig(clientset kubernetes.Interface, namespace, configMapName, secretName string) (*Config, error) { - ctx := context.Background() - config := &Config{} - - // Load ConfigMap if it exists - if configMapName != "" { - cm, err := clientset.CoreV1().ConfigMaps(namespace).Get(ctx, configMapName, metav1.GetOptions{}) - if err != nil { - return nil, fmt.Errorf("failed to get ConfigMap '%s': %w", configMapName, err) - } - - if configData, ok := cm.Data["config"]; ok { - if err := yaml.Unmarshal([]byte(configData), config); err != nil { - return nil, fmt.Errorf("failed to parse ConfigMap config: %w", err) - } - } else { - return nil, fmt.Errorf("ConfigMap '%s' does not contain 
'config' key", configMapName) - } - } - - // Load Secret if it exists (overrides ConfigMap) - if secretName != "" { - secret, err := clientset.CoreV1().Secrets(namespace).Get(ctx, secretName, metav1.GetOptions{}) - if err != nil { - // Secret is optional - only used for overrides - fmt.Printf("Warningf: Secret '%s' not found, using ConfigMap only\n", secretName) - } else { - if configData, ok := secret.Data["config"]; ok { - var secretConfig Config - if err := yaml.Unmarshal(configData, &secretConfig); err != nil { - return nil, fmt.Errorf("failed to parse Secret config: %w", err) - } - // Merge Secret config into base config (non-zero values override) - if err := mergo.Merge(config, secretConfig, mergo.WithOverride); err != nil { - return nil, fmt.Errorf("failed to merge Secret config: %w", err) - } - } - } - } - - // Validate the merged configuration - validate := validator.New() - if err := validate.Struct(config); err != nil { - return nil, fmt.Errorf("configuration validation failed: %w", err) - } - - return config, nil -} - -type Context struct { - Config *CLIConfig -} - -type CLIConfig struct { - Namespace string - Kubeconfig string - Debug bool - Quiet bool - ConfigMapName string - SecretName string - OutputFormat string // table, json -} - -func NewContext() *Context { - return &Context{ - Config: &CLIConfig{}, - } -} diff --git a/internal/foundation/config/config.go b/internal/foundation/config/config.go new file mode 100644 index 0000000..d083340 --- /dev/null +++ b/internal/foundation/config/config.go @@ -0,0 +1,302 @@ +// Package config provides configuration management for the backup CLI tool. +// It supports loading configuration from Kubernetes ConfigMaps and Secrets +// with a merge strategy that allows ConfigMap to be overridden by Secret. 
+package config + +import ( + "context" + "fmt" + + "dario.cat/mergo" + "github.com/go-playground/validator/v10" + "gopkg.in/yaml.v3" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes" +) + +// Config represents the merged configuration from ConfigMap and Secret +type Config struct { + Kubernetes KubernetesConfig `yaml:"kubernetes"` + Elasticsearch ElasticsearchConfig `yaml:"elasticsearch" validate:"required"` + Minio MinioConfig `yaml:"minio" validate:"required"` + Stackgraph StackgraphConfig `yaml:"stackgraph" validate:"required"` +} + +// KubernetesConfig holds Kubernetes-wide configuration +type KubernetesConfig struct { + CommonLabels map[string]string `yaml:"commonLabels"` +} + +// ElasticsearchConfig holds Elasticsearch-specific configuration +type ElasticsearchConfig struct { + Service ServiceConfig `yaml:"service" validate:"required"` + Restore RestoreConfig `yaml:"restore" validate:"required"` + SnapshotRepository SnapshotRepositoryConfig `yaml:"snapshotRepository" validate:"required"` + SLM SLMConfig `yaml:"slm" validate:"required"` +} + +// RestoreConfig holds restore-specific configuration +type RestoreConfig struct { + ScaleDownLabelSelector string `yaml:"scaleDownLabelSelector" validate:"required"` + IndexPrefix string `yaml:"indexPrefix" validate:"required"` + DatastreamIndexPrefix string `yaml:"datastreamIndexPrefix" validate:"required"` + DatastreamName string `yaml:"datastreamName" validate:"required"` + IndicesPattern string `yaml:"indicesPattern" validate:"required"` + Repository string `yaml:"repository" validate:"required"` +} + +// SnapshotRepositoryConfig holds snapshot repository configuration +type SnapshotRepositoryConfig struct { + Name string `yaml:"name" validate:"required"` + Bucket string `yaml:"bucket" validate:"required"` + Endpoint string `yaml:"endpoint" validate:"required"` + BasePath string `yaml:"basepath"` + AccessKey string `yaml:"accessKey" validate:"required"` // From secret + SecretKey string `yaml:"secretKey" validate:"required"` // From secret +} + +// SLMConfig holds Snapshot Lifecycle Management configuration +type SLMConfig struct { + Name string `yaml:"name" validate:"required"` + Schedule string `yaml:"schedule" validate:"required"` + SnapshotTemplateName string `yaml:"snapshotTemplateName" validate:"required"` + Repository string `yaml:"repository" validate:"required"` + Indices string `yaml:"indices" validate:"required"` + RetentionExpireAfter string `yaml:"retentionExpireAfter" validate:"required"` + RetentionMinCount int `yaml:"retentionMinCount" validate:"required,min=1"` + RetentionMaxCount int `yaml:"retentionMaxCount" validate:"required,min=1"` +} + +// ServiceConfig holds service connection details +type ServiceConfig struct { + Name string `yaml:"name" validate:"required"` + Port int `yaml:"port" validate:"required,min=1,max=65535"` + LocalPortForwardPort int `yaml:"localPortForwardPort" validate:"required,min=1,max=65535"` +} + +// MinioConfig holds Minio-specific configuration +type MinioConfig struct { + Service ServiceConfig `yaml:"service" validate:"required"` + AccessKey string `yaml:"accessKey" validate:"required"` // From secret + SecretKey string `yaml:"secretKey" validate:"required"` // From secret +} + +// StackgraphConfig holds Stackgraph backup-specific configuration +type StackgraphConfig struct { + Bucket string `yaml:"bucket" validate:"required"` + S3Prefix string `yaml:"s3Prefix"` + MultipartArchive bool `yaml:"multipartArchive" validate:"boolean"` + Restore StackgraphRestoreConfig 
`yaml:"restore" validate:"required"` +} + +// StackgraphRestoreConfig holds Stackgraph restore-specific configuration +type StackgraphRestoreConfig struct { + ScaleDownLabelSelector string `yaml:"scaleDownLabelSelector" validate:"required"` + LoggingConfigConfigMapName string `yaml:"loggingConfigConfigMap" validate:"required"` + ZookeeperQuorum string `yaml:"zookeeperQuorum" validate:"required"` + Job JobConfig `yaml:"job" validate:"required"` + PVC PVCConfig `yaml:"pvc" validate:"required"` +} + +// PVCConfig holds PersistentVolumeClaim configuration +type PVCConfig struct { + Size string `yaml:"size" validate:"required"` + AccessModes []string `yaml:"accessModes"` + StorageClassName string `yaml:"storageClassName"` +} + +// JobConfig holds Kubernetes Job configuration that can be applied to backup/restore jobs +type JobConfig struct { + Labels map[string]string `yaml:"labels"` + ImagePullSecrets []LocalObjectRef `yaml:"imagePullSecrets"` + SecurityContext PodSecurityContext `yaml:"securityContext"` + NodeSelector map[string]string `yaml:"nodeSelector"` + Tolerations []Toleration `yaml:"tolerations"` + Affinity *Affinity `yaml:"affinity"` + Resources ResourceRequirements `yaml:"resources" validate:"required"` + ContainerSecurityContext *SecurityContext `yaml:"containerSecurityContext"` + Image string `yaml:"image" validate:"required"` + WaitImage string `yaml:"waitImage" validate:"required"` +} + +// LocalObjectRef represents a reference to a local object by name +type LocalObjectRef struct { + Name string `yaml:"name" validate:"required"` +} + +// PodSecurityContext holds pod-level security context settings +type PodSecurityContext struct { + FSGroup *int64 `yaml:"fsGroup"` + RunAsGroup *int64 `yaml:"runAsGroup"` + RunAsNonRoot *bool `yaml:"runAsNonRoot"` + RunAsUser *int64 `yaml:"runAsUser"` +} + +// SecurityContext holds container-level security context settings +type SecurityContext struct { + AllowPrivilegeEscalation *bool `yaml:"allowPrivilegeEscalation"` + RunAsNonRoot *bool `yaml:"runAsNonRoot"` + RunAsUser *int64 `yaml:"runAsUser"` +} + +// Toleration represents a pod toleration +type Toleration struct { + Key string `yaml:"key"` + Operator string `yaml:"operator"` + Value string `yaml:"value"` + Effect string `yaml:"effect"` +} + +// Affinity represents pod affinity and anti-affinity settings +type Affinity struct { + NodeAffinity *NodeAffinity `yaml:"nodeAffinity"` + PodAffinity *PodAffinity `yaml:"podAffinity"` + PodAntiAffinity *PodAntiAffinity `yaml:"podAntiAffinity"` +} + +// NodeAffinity represents node affinity scheduling rules +type NodeAffinity struct { + RequiredDuringSchedulingIgnoredDuringExecution *NodeSelector `yaml:"requiredDuringSchedulingIgnoredDuringExecution"` + PreferredDuringSchedulingIgnoredDuringExecution []PreferredSchedulingTerm `yaml:"preferredDuringSchedulingIgnoredDuringExecution"` +} + +// NodeSelector represents node selector requirements +type NodeSelector struct { + NodeSelectorTerms []NodeSelectorTerm `yaml:"nodeSelectorTerms"` +} + +// NodeSelectorTerm represents node selector term +type NodeSelectorTerm struct { + MatchExpressions []NodeSelectorRequirement `yaml:"matchExpressions"` + MatchFields []NodeSelectorRequirement `yaml:"matchFields"` +} + +// NodeSelectorRequirement represents a node selector requirement +type NodeSelectorRequirement struct { + Key string `yaml:"key"` + Operator string `yaml:"operator"` + Values []string `yaml:"values"` +} + +// PreferredSchedulingTerm represents a preferred scheduling term +type PreferredSchedulingTerm 
struct {
+	Weight     int32            `yaml:"weight"`
+	Preference NodeSelectorTerm `yaml:"preference"`
+}
+
+// PodAffinity represents pod affinity scheduling rules
+type PodAffinity struct {
+	RequiredDuringSchedulingIgnoredDuringExecution  []PodAffinityTerm         `yaml:"requiredDuringSchedulingIgnoredDuringExecution"`
+	PreferredDuringSchedulingIgnoredDuringExecution []WeightedPodAffinityTerm `yaml:"preferredDuringSchedulingIgnoredDuringExecution"`
+}
+
+// PodAntiAffinity represents pod anti-affinity scheduling rules
+type PodAntiAffinity struct {
+	RequiredDuringSchedulingIgnoredDuringExecution  []PodAffinityTerm         `yaml:"requiredDuringSchedulingIgnoredDuringExecution"`
+	PreferredDuringSchedulingIgnoredDuringExecution []WeightedPodAffinityTerm `yaml:"preferredDuringSchedulingIgnoredDuringExecution"`
+}
+
+// PodAffinityTerm represents pod affinity term
+type PodAffinityTerm struct {
+	LabelSelector *LabelSelector `yaml:"labelSelector"`
+	Namespaces    []string       `yaml:"namespaces"`
+	TopologyKey   string         `yaml:"topologyKey"`
+}
+
+// WeightedPodAffinityTerm represents weighted pod affinity term
+type WeightedPodAffinityTerm struct {
+	Weight          int32           `yaml:"weight"`
+	PodAffinityTerm PodAffinityTerm `yaml:"podAffinityTerm"`
+}
+
+// LabelSelector represents a label selector
+type LabelSelector struct {
+	MatchLabels      map[string]string          `yaml:"matchLabels"`
+	MatchExpressions []LabelSelectorRequirement `yaml:"matchExpressions"`
+}
+
+// LabelSelectorRequirement represents a label selector requirement
+type LabelSelectorRequirement struct {
+	Key      string   `yaml:"key"`
+	Operator string   `yaml:"operator"`
+	Values   []string `yaml:"values"`
+}
+
+// ResourceRequirements holds resource limits and requests
+type ResourceRequirements struct {
+	Limits   ResourceList `yaml:"limits" validate:"required"`
+	Requests ResourceList `yaml:"requests" validate:"required"`
+}
+
+// ResourceList holds resource quantities
+type ResourceList struct {
+	CPU              string `yaml:"cpu" validate:"required"`
+	Memory           string `yaml:"memory" validate:"required"`
+	EphemeralStorage string `yaml:"ephemeralStorage"`
+}
+
+// LoadConfig loads and merges configuration from ConfigMap and Secret
+// ConfigMap provides base configuration, Secret overrides it
+// All required fields must be present after merging, validated with validator
+func LoadConfig(clientset kubernetes.Interface, namespace, configMapName, secretName string) (*Config, error) {
+	ctx := context.Background()
+	config := &Config{}
+
+	// Load ConfigMap if it exists
+	if configMapName != "" {
+		cm, err := clientset.CoreV1().ConfigMaps(namespace).Get(ctx, configMapName, metav1.GetOptions{})
+		if err != nil {
+			return nil, fmt.Errorf("failed to get ConfigMap '%s': %w", configMapName, err)
+		}
+
+		if configData, ok := cm.Data["config"]; ok {
+			if err := yaml.Unmarshal([]byte(configData), config); err != nil {
+				return nil, fmt.Errorf("failed to parse ConfigMap config: %w", err)
+			}
+		} else {
+			return nil, fmt.Errorf("ConfigMap '%s' does not contain 'config' key", configMapName)
+		}
+	}
+
+	// Load Secret if it exists (overrides ConfigMap)
+	if secretName != "" {
+		secret, err := clientset.CoreV1().Secrets(namespace).Get(ctx, secretName, metav1.GetOptions{})
+		if err != nil {
+			// Secret is optional - only used for overrides
+			fmt.Printf("Warning: Secret '%s' not found, using ConfigMap only\n", secretName)
+		} else {
+			if configData, ok := secret.Data["config"]; ok {
+				var secretConfig Config
+				if err := yaml.Unmarshal(configData, &secretConfig); err != nil {
+					return nil, fmt.Errorf("failed to parse 
Secret config: %w", err) + } + // Merge Secret config into base config (non-zero values override) + if err := mergo.Merge(config, secretConfig, mergo.WithOverride); err != nil { + return nil, fmt.Errorf("failed to merge Secret config: %w", err) + } + } + } + } + + // Validate the merged configuration + validate := validator.New() + if err := validate.Struct(config); err != nil { + return nil, fmt.Errorf("configuration validation failed: %w", err) + } + + return config, nil +} + +type CLIGlobalFlags struct { + Namespace string + Kubeconfig string + Debug bool + Quiet bool + ConfigMapName string + SecretName string + OutputFormat string // table, json +} + +func NewCLIGlobalFlags() *CLIGlobalFlags { + return &CLIGlobalFlags{} +} diff --git a/internal/config/config_test.go b/internal/foundation/config/config_test.go similarity index 91% rename from internal/config/config_test.go rename to internal/foundation/config/config_test.go index 544a1d3..2cd91c8 100644 --- a/internal/config/config_test.go +++ b/internal/foundation/config/config_test.go @@ -30,7 +30,7 @@ func loadTestData(t *testing.T, filename string) string { } func TestLoadConfig_FromConfigMapOnly(t *testing.T) { - fakeClient := fake.NewSimpleClientset() + fakeClient := fake.NewClientset() validConfigYAML := loadTestData(t, "validConfigMapOnly.yaml") // Create ConfigMap @@ -62,7 +62,7 @@ func TestLoadConfig_FromConfigMapOnly(t *testing.T) { } func TestLoadConfig_CompleteConfiguration(t *testing.T) { - fakeClient := fake.NewSimpleClientset() + fakeClient := fake.NewClientset() validConfigYAML := loadTestData(t, "validConfigMapConfig.yaml") secretOverrideYAML := loadTestData(t, "validSecretConfig.yaml") @@ -137,7 +137,7 @@ func TestLoadConfig_CompleteConfiguration(t *testing.T) { } func TestLoadConfig_WithSecretOverride(t *testing.T) { - fakeClient := fake.NewSimpleClientset() + fakeClient := fake.NewClientset() validConfigYAML := loadTestData(t, "validConfigMapOnly.yaml") secretOverrideYAML := loadTestData(t, "validSecretConfig.yaml") @@ -184,7 +184,7 @@ func TestLoadConfig_WithSecretOverride(t *testing.T) { } func TestLoadConfig_ConfigMapNotFound(t *testing.T) { - fakeClient := fake.NewSimpleClientset() + fakeClient := fake.NewClientset() // Try to load non-existent ConfigMap config, err := LoadConfig(fakeClient, "test-ns", "nonexistent", "") @@ -196,7 +196,7 @@ func TestLoadConfig_ConfigMapNotFound(t *testing.T) { } func TestLoadConfig_ConfigMapMissingConfigKey(t *testing.T) { - fakeClient := fake.NewSimpleClientset() + fakeClient := fake.NewClientset() validConfigYAML := loadTestData(t, "validConfigMapOnly.yaml") // Create ConfigMap without 'config' key @@ -224,7 +224,7 @@ func TestLoadConfig_ConfigMapMissingConfigKey(t *testing.T) { } func TestLoadConfig_InvalidYAML(t *testing.T) { - fakeClient := fake.NewSimpleClientset() + fakeClient := fake.NewClientset() // Create ConfigMap with invalid YAML cm := &corev1.ConfigMap{ @@ -251,7 +251,7 @@ func TestLoadConfig_InvalidYAML(t *testing.T) { } func TestLoadConfig_ValidationFails(t *testing.T) { - fakeClient := fake.NewSimpleClientset() + fakeClient := fake.NewClientset() // Create ConfigMap with invalid config (missing required fields) cm := &corev1.ConfigMap{ @@ -278,7 +278,7 @@ func TestLoadConfig_ValidationFails(t *testing.T) { } func TestLoadConfig_SecretNotFoundWarning(t *testing.T) { - fakeClient := fake.NewSimpleClientset() + fakeClient := fake.NewClientset() validConfigYAML := loadTestData(t, "validConfigMapOnly.yaml") // Create only ConfigMap @@ -306,7 +306,7 @@ func 
TestLoadConfig_SecretNotFoundWarning(t *testing.T) { } func TestLoadConfig_EmptyConfigMapName(t *testing.T) { - fakeClient := fake.NewSimpleClientset() + fakeClient := fake.NewClientset() // Try to load with empty ConfigMap name config, err := LoadConfig(fakeClient, "test-ns", "", "") @@ -316,33 +316,6 @@ func TestLoadConfig_EmptyConfigMapName(t *testing.T) { assert.Nil(t, config) } -func TestNewContext(t *testing.T) { - ctx := NewContext() - - assert.NotNil(t, ctx) - assert.NotNil(t, ctx.Config) - assert.Equal(t, "", ctx.Config.Namespace) - assert.Equal(t, "", ctx.Config.Kubeconfig) - assert.False(t, ctx.Config.Debug) - assert.False(t, ctx.Config.Quiet) - assert.Equal(t, "", ctx.Config.ConfigMapName) - assert.Equal(t, "", ctx.Config.SecretName) - assert.Equal(t, "", ctx.Config.OutputFormat) -} - -func TestCLIConfig_Defaults(t *testing.T) { - config := &CLIConfig{} - - // Verify zero values - assert.Equal(t, "", config.Namespace) - assert.Equal(t, "", config.Kubeconfig) - assert.False(t, config.Debug) - assert.False(t, config.Quiet) - assert.Equal(t, "", config.ConfigMapName) - assert.Equal(t, "", config.SecretName) - assert.Equal(t, "", config.OutputFormat) -} - //nolint:funlen func TestConfig_StructValidation(t *testing.T) { tests := []struct { @@ -385,6 +358,42 @@ func TestConfig_StructValidation(t *testing.T) { RetentionMaxCount: 10, }, }, + Minio: MinioConfig{ + Service: ServiceConfig{ + Name: "minio", + Port: 9000, + LocalPortForwardPort: 9000, + }, + AccessKey: "minioadmin", + SecretKey: "minioadmin", + }, + Stackgraph: StackgraphConfig{ + Bucket: "stackgraph-bucket", + S3Prefix: "", + MultipartArchive: true, + Restore: StackgraphRestoreConfig{ + ScaleDownLabelSelector: "app=stackgraph", + LoggingConfigConfigMapName: "logging-config", + ZookeeperQuorum: "zookeeper:2181", + Job: JobConfig{ + Image: "backup:latest", + WaitImage: "wait:latest", + Resources: ResourceRequirements{ + Limits: ResourceList{ + CPU: "2", + Memory: "4Gi", + }, + Requests: ResourceList{ + CPU: "1", + Memory: "2Gi", + }, + }, + }, + PVC: PVCConfig{ + Size: "10Gi", + }, + }, + }, }, expectError: false, }, diff --git a/internal/config/testdata/validConfigMapConfig.yaml b/internal/foundation/config/testdata/validConfigMapConfig.yaml similarity index 68% rename from internal/config/testdata/validConfigMapConfig.yaml rename to internal/foundation/config/testdata/validConfigMapConfig.yaml index f538995..ebff615 100644 --- a/internal/config/testdata/validConfigMapConfig.yaml +++ b/internal/foundation/config/testdata/validConfigMapConfig.yaml @@ -60,3 +60,49 @@ elasticsearch: datastreamName: sts_k8s_logs # Pattern for indices to restore from snapshot (comma-separated glob patterns) indicesPattern: sts*,.ds-sts_k8s_logs* + +# Minio configuration for S3-compatible storage +minio: + # Minio service connection details + service: + name: suse-observability-minio + port: 9000 + localPortForwardPort: 9000 + # Access credentials (typically from Kubernetes secret) + accessKey: minioadmin + secretKey: minioadmin + +# Stackgraph backup configuration +stackgraph: + # S3 bucket for stackgraph backups + bucket: sts-stackgraph-backup + # S3 prefix path for backups + s3Prefix: "" + # Archive split to multiple parts + multipartArchive: true + # Restore configuration + restore: + # Label selector for deployments to scale down during restore + scaleDownLabelSelector: "observability.suse.com/scalable-during-stackgraph-restore=true" + # ConfigMap containing logging configuration + loggingConfigConfigMap: suse-observability-logging + # 
Zookeeper quorum connection string + zookeeperQuorum: "suse-observability-zookeeper:2181" + # Job configuration + job: + labels: + app: stackgraph-restore + image: quay.io/stackstate/stackstate-backup:latest + waitImage: quay.io/stackstate/wait:latest + resources: + limits: + cpu: "2" + memory: "4Gi" + requests: + cpu: "1" + memory: "2Gi" + # PVC configuration for restore jobs + pvc: + size: "10Gi" + accessModes: + - ReadWriteOnce diff --git a/internal/config/testdata/validConfigMapOnly.yaml b/internal/foundation/config/testdata/validConfigMapOnly.yaml similarity index 77% rename from internal/config/testdata/validConfigMapOnly.yaml rename to internal/foundation/config/testdata/validConfigMapOnly.yaml index a7b3415..df61a1a 100644 --- a/internal/config/testdata/validConfigMapOnly.yaml +++ b/internal/foundation/config/testdata/validConfigMapOnly.yaml @@ -66,4 +66,39 @@ elasticsearch: # Name of the datastream (used for rollover operations) datastreamName: sts_k8s_logs # Pattern for indices to restore from snapshot (comma-separated glob patterns) - indicesPattern: sts*,.ds-sts_k8s_logs* \ No newline at end of file + indicesPattern: sts*,.ds-sts_k8s_logs* + +# Minio configuration for S3-compatible storage +minio: + service: + name: suse-observability-minio + port: 9000 + localPortForwardPort: 9000 + accessKey: minioadmin + secretKey: minioadmin + +# Stackgraph backup configuration +stackgraph: + bucket: sts-stackgraph-backup + s3Prefix: "" + multipartArchive: true + restore: + scaleDownLabelSelector: "observability.suse.com/scalable-during-stackgraph-restore=true" + loggingConfigConfigMap: suse-observability-logging + zookeeperQuorum: "suse-observability-zookeeper:2181" + job: + labels: + app: stackgraph-restore + image: quay.io/stackstate/stackstate-backup:latest + waitImage: quay.io/stackstate/wait:latest + resources: + limits: + cpu: "2" + memory: "4Gi" + requests: + cpu: "1" + memory: "2Gi" + pvc: + size: "10Gi" + accessModes: + - ReadWriteOnce diff --git a/internal/config/testdata/validSecretConfig.yaml b/internal/foundation/config/testdata/validSecretConfig.yaml similarity index 77% rename from internal/config/testdata/validSecretConfig.yaml rename to internal/foundation/config/testdata/validSecretConfig.yaml index a41c8c0..e004e52 100644 --- a/internal/config/testdata/validSecretConfig.yaml +++ b/internal/foundation/config/testdata/validSecretConfig.yaml @@ -13,3 +13,9 @@ elasticsearch: # S3/Minio secret key (overrides ConfigMap value if present) # Keep this value secure - it should never be committed to ConfigMaps secretKey: secret-secret-key + +minio: + # Minio access key (overrides ConfigMap value if present) + accessKey: secret-minio-access-key + # Minio secret key (overrides ConfigMap value if present) + secretKey: secret-minio-secret-key diff --git a/internal/logger/logger.go b/internal/foundation/logger/logger.go similarity index 100% rename from internal/logger/logger.go rename to internal/foundation/logger/logger.go diff --git a/internal/logger/logger_test.go b/internal/foundation/logger/logger_test.go similarity index 100% rename from internal/logger/logger_test.go rename to internal/foundation/logger/logger_test.go diff --git a/internal/output/formatter.go b/internal/foundation/output/formatter.go similarity index 88% rename from internal/output/formatter.go rename to internal/foundation/output/formatter.go index c8a136c..02ebc10 100644 --- a/internal/output/formatter.go +++ b/internal/foundation/output/formatter.go @@ -4,7 +4,6 @@ import ( "encoding/json" "fmt" "io" - "os" 
"strings" "text/tabwriter" ) @@ -28,13 +27,13 @@ type Formatter struct { // NewFormatter creates a new output formatter // Defaults to table format if invalid format provided -func NewFormatter(format string) *Formatter { +func NewFormatter(wr io.Writer, format string) *Formatter { f := Format(format) if f != FormatTable && f != FormatJSON { f = FormatTable } return &Formatter{ - writer: os.Stdout, + writer: wr, format: f, } } @@ -49,7 +48,7 @@ type Table struct { func (f *Formatter) PrintTable(table Table) error { if len(table.Rows) == 0 { if f.format == FormatTable { - fmt.Fprintln(f.writer, "No data found") + _, _ = fmt.Fprintln(f.writer, "No data found") } else { // For JSON, output empty array return f.printJSON([]map[string]string{}) @@ -72,11 +71,11 @@ func (f *Formatter) printTable(table Table) error { w := tabwriter.NewWriter(f.writer, 0, 0, tabwriterPadding, ' ', 0) // Print header - fmt.Fprintln(w, strings.Join(table.Headers, "\t")) + _, _ = fmt.Fprintln(w, strings.Join(table.Headers, "\t")) // Print rows for _, row := range table.Rows { - fmt.Fprintln(w, strings.Join(row, "\t")) + _, _ = fmt.Fprintln(w, strings.Join(row, "\t")) } return w.Flush() @@ -107,13 +106,13 @@ func tableToMaps(table Table) []map[string]string { // PrintMessage prints a simple message (only in table format, ignored in JSON) func (f *Formatter) PrintMessage(message string) { if f.format == FormatTable { - fmt.Fprintln(f.writer, message) + _, _ = fmt.Fprintln(f.writer, message) } } // PrintError prints an error message (only in table format, ignored in JSON) func (f *Formatter) PrintError(err error) { if f.format == FormatTable { - fmt.Fprintf(f.writer, "Errorf: %v\n", err) + _, _ = fmt.Fprintf(f.writer, "Errorf: %v\n", err) } } diff --git a/internal/output/formatter_test.go b/internal/foundation/output/formatter_test.go similarity index 99% rename from internal/output/formatter_test.go rename to internal/foundation/output/formatter_test.go index ab2e869..da816f3 100644 --- a/internal/output/formatter_test.go +++ b/internal/foundation/output/formatter_test.go @@ -3,6 +3,7 @@ package output import ( "bytes" "encoding/json" + "os" "strings" "testing" @@ -40,7 +41,7 @@ func TestNewFormatter(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - formatter := NewFormatter(tt.format) + formatter := NewFormatter(os.Stdout, tt.format) assert.NotNil(t, formatter) assert.Equal(t, tt.expectedFormat, formatter.format) assert.NotNil(t, formatter.writer) diff --git a/internal/k8s/client_test.go b/internal/k8s/client_test.go deleted file mode 100644 index 6868624..0000000 --- a/internal/k8s/client_test.go +++ /dev/null @@ -1,328 +0,0 @@ -package k8s - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - appsv1 "k8s.io/api/apps/v1" - corev1 "k8s.io/api/core/v1" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/client-go/kubernetes/fake" -) - -func TestClient_ScaleDownDeployments(t *testing.T) { - tests := []struct { - name string - namespace string - labelSelector string - deployments []appsv1.Deployment - expectedScales []DeploymentScale - expectError bool - }{ - { - name: "scale down multiple deployments", - namespace: "test-ns", - labelSelector: "app=test", - deployments: []appsv1.Deployment{ - createDeployment("deploy1", "test-ns", map[string]string{"app": "test"}, 3), - createDeployment("deploy2", "test-ns", map[string]string{"app": "test"}, 5), - }, - expectedScales: []DeploymentScale{ - {Name: "deploy1", Replicas: 3}, 
- {Name: "deploy2", Replicas: 5}, - }, - expectError: false, - }, - { - name: "scale down deployment with zero replicas", - namespace: "test-ns", - labelSelector: "app=test", - deployments: []appsv1.Deployment{ - createDeployment("deploy1", "test-ns", map[string]string{"app": "test"}, 0), - }, - expectedScales: []DeploymentScale{ - {Name: "deploy1", Replicas: 0}, - }, - expectError: false, - }, - { - name: "no deployments matching selector", - namespace: "test-ns", - labelSelector: "app=nonexistent", - deployments: []appsv1.Deployment{}, - expectedScales: []DeploymentScale{}, - expectError: false, - }, - { - name: "deployments with different labels not selected", - namespace: "test-ns", - labelSelector: "app=test", - deployments: []appsv1.Deployment{ - createDeployment("deploy1", "test-ns", map[string]string{"app": "test"}, 3), - createDeployment("deploy2", "test-ns", map[string]string{"app": "other"}, 2), - }, - expectedScales: []DeploymentScale{ - {Name: "deploy1", Replicas: 3}, - }, - expectError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Create fake clientset with test deployments - fakeClient := fake.NewSimpleClientset() - for _, deploy := range tt.deployments { - _, err := fakeClient.AppsV1().Deployments(tt.namespace).Create( - context.Background(), &deploy, metav1.CreateOptions{}, - ) - require.NoError(t, err) - } - - // Create our client wrapper - client := &Client{ - clientset: fakeClient, - } - - // Execute scale down - scales, err := client.ScaleDownDeployments(tt.namespace, tt.labelSelector) - - // Assertions - if tt.expectError { - assert.Error(t, err) - return - } - - require.NoError(t, err) - assert.Equal(t, len(tt.expectedScales), len(scales)) - - // Verify each scaled deployment - for i, expectedScale := range tt.expectedScales { - assert.Equal(t, expectedScale.Name, scales[i].Name) - assert.Equal(t, expectedScale.Replicas, scales[i].Replicas) - - // Verify the deployment was actually scaled to 0 - deploy, err := fakeClient.AppsV1().Deployments(tt.namespace).Get( - context.Background(), expectedScale.Name, metav1.GetOptions{}, - ) - require.NoError(t, err) - if expectedScale.Replicas > 0 { - assert.Equal(t, int32(0), *deploy.Spec.Replicas, "deployment should be scaled to 0") - } - } - }) - } -} - -func TestClient_ScaleUpDeployments(t *testing.T) { - tests := []struct { - name string - namespace string - initialReplicas int32 - scaleToReplicas int32 - deploymentName string - expectError bool - }{ - { - name: "scale up from zero to three", - namespace: "test-ns", - initialReplicas: 0, - scaleToReplicas: 3, - deploymentName: "test-deploy", - expectError: false, - }, - { - name: "scale up from two to five", - namespace: "test-ns", - initialReplicas: 2, - scaleToReplicas: 5, - deploymentName: "test-deploy", - expectError: false, - }, - { - name: "restore to zero replicas", - namespace: "test-ns", - initialReplicas: 3, - scaleToReplicas: 0, - deploymentName: "test-deploy", - expectError: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Create fake clientset with deployment at initial scale - fakeClient := fake.NewSimpleClientset() - deploy := createDeployment(tt.deploymentName, tt.namespace, map[string]string{"app": "test"}, tt.initialReplicas) - _, err := fakeClient.AppsV1().Deployments(tt.namespace).Create( - context.Background(), &deploy, metav1.CreateOptions{}, - ) - require.NoError(t, err) - - // Create our client wrapper - client := &Client{ - clientset: fakeClient, - } - - // Execute 
scale up - scales := []DeploymentScale{ - {Name: tt.deploymentName, Replicas: tt.scaleToReplicas}, - } - err = client.ScaleUpDeployments(tt.namespace, scales) - - // Assertions - if tt.expectError { - assert.Error(t, err) - return - } - - require.NoError(t, err) - - // Verify the deployment was scaled to expected replicas - updatedDeploy, err := fakeClient.AppsV1().Deployments(tt.namespace).Get( - context.Background(), tt.deploymentName, metav1.GetOptions{}, - ) - require.NoError(t, err) - assert.Equal(t, tt.scaleToReplicas, *updatedDeploy.Spec.Replicas) - }) - } -} - -func TestClient_ScaleUpDeployments_NonExistent(t *testing.T) { - fakeClient := fake.NewSimpleClientset() - client := &Client{ - clientset: fakeClient, - } - - scales := []DeploymentScale{ - {Name: "nonexistent-deploy", Replicas: 3}, - } - err := client.ScaleUpDeployments("test-ns", scales) - - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to get deployment") -} - -func TestClient_Clientset(t *testing.T) { - fakeClient := fake.NewSimpleClientset() - client := &Client{ - clientset: fakeClient, - } - - clientset := client.Clientset() - assert.NotNil(t, clientset) - assert.Equal(t, fakeClient, clientset) -} - -func TestClient_PortForwardService_ServiceNotFound(t *testing.T) { - fakeClient := fake.NewSimpleClientset() - client := &Client{ - clientset: fakeClient, - } - - _, _, err := client.PortForwardService("test-ns", "nonexistent-svc", 8080, 9200) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to get service") -} - -func TestClient_PortForwardService_NoPodsFound(t *testing.T) { - fakeClient := fake.NewSimpleClientset() - - // Create a service without any matching pods - svc := &corev1.Service{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-svc", - Namespace: "test-ns", - }, - Spec: corev1.ServiceSpec{ - Selector: map[string]string{"app": "test"}, - }, - } - _, err := fakeClient.CoreV1().Services("test-ns").Create( - context.Background(), svc, metav1.CreateOptions{}, - ) - require.NoError(t, err) - - client := &Client{ - clientset: fakeClient, - } - - _, _, err = client.PortForwardService("test-ns", "test-svc", 8080, 9200) - assert.Error(t, err) - assert.Contains(t, err.Error(), "no pods found for service") -} - -func TestClient_PortForwardService_NoRunningPods(t *testing.T) { - fakeClient := fake.NewSimpleClientset() - - // Create a service - svc := &corev1.Service{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-svc", - Namespace: "test-ns", - }, - Spec: corev1.ServiceSpec{ - Selector: map[string]string{"app": "test"}, - }, - } - _, err := fakeClient.CoreV1().Services("test-ns").Create( - context.Background(), svc, metav1.CreateOptions{}, - ) - require.NoError(t, err) - - // Create a pod in Pending state - pod := &corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test-pod", - Namespace: "test-ns", - Labels: map[string]string{"app": "test"}, - }, - Status: corev1.PodStatus{ - Phase: corev1.PodPending, - }, - } - _, err = fakeClient.CoreV1().Pods("test-ns").Create( - context.Background(), pod, metav1.CreateOptions{}, - ) - require.NoError(t, err) - - client := &Client{ - clientset: fakeClient, - } - - _, _, err = client.PortForwardService("test-ns", "test-svc", 8080, 9200) - assert.Error(t, err) - assert.Contains(t, err.Error(), "no running pods found for service") -} - -// Helper function to create a deployment for testing -func createDeployment(name, namespace string, labels map[string]string, replicas int32) appsv1.Deployment { - return appsv1.Deployment{ - ObjectMeta: metav1.ObjectMeta{ 
- Name: name, - Namespace: namespace, - Labels: labels, - }, - Spec: appsv1.DeploymentSpec{ - Replicas: &replicas, - Selector: &metav1.LabelSelector{ - MatchLabels: labels, - }, - Template: corev1.PodTemplateSpec{ - ObjectMeta: metav1.ObjectMeta{ - Labels: labels, - }, - Spec: corev1.PodSpec{ - Containers: []corev1.Container{ - { - Name: "test-container", - Image: "test:latest", - }, - }, - }, - }, - }, - } -} diff --git a/cmd/portforward/portforward.go b/internal/orchestration/portforward/portforward.go similarity index 88% rename from cmd/portforward/portforward.go rename to internal/orchestration/portforward/portforward.go index 19830bb..032361c 100644 --- a/cmd/portforward/portforward.go +++ b/internal/orchestration/portforward/portforward.go @@ -3,8 +3,8 @@ package portforward import ( "fmt" - "github.com/stackvista/stackstate-backup-cli/internal/k8s" - "github.com/stackvista/stackstate-backup-cli/internal/logger" + "github.com/stackvista/stackstate-backup-cli/internal/clients/k8s" + "github.com/stackvista/stackstate-backup-cli/internal/foundation/logger" ) // Conn contains the channels needed to manage a port-forward connection diff --git a/cmd/portforward/portforward_test.go b/internal/orchestration/portforward/portforward_test.go similarity index 95% rename from cmd/portforward/portforward_test.go rename to internal/orchestration/portforward/portforward_test.go index f6f10ae..c710dd1 100644 --- a/cmd/portforward/portforward_test.go +++ b/internal/orchestration/portforward/portforward_test.go @@ -7,8 +7,8 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes/fake" - "github.com/stackvista/stackstate-backup-cli/internal/k8s" - "github.com/stackvista/stackstate-backup-cli/internal/logger" + "github.com/stackvista/stackstate-backup-cli/internal/clients/k8s" + "github.com/stackvista/stackstate-backup-cli/internal/foundation/logger" ) func TestSetupPortForward_ServiceNotFound(t *testing.T) { diff --git a/internal/orchestration/scale/scale.go b/internal/orchestration/scale/scale.go new file mode 100644 index 0000000..4a07ae4 --- /dev/null +++ b/internal/orchestration/scale/scale.go @@ -0,0 +1,115 @@ +package scale + +import ( + "context" + "fmt" + "time" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/stackvista/stackstate-backup-cli/internal/clients/k8s" + "github.com/stackvista/stackstate-backup-cli/internal/foundation/logger" +) + +const ( + podTerminationCheckInterval = 2 * time.Second + podTerminationTimeout = 120 * time.Second +) + +// ScaleDown scales down deployments matching the label selector and logs the results. +// It waits for all pods to terminate before returning. 
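As an illustration of how a restore command might drive this scale-down / restore / scale-up cycle, here is a minimal sketch that uses only the signatures and test helpers introduced in this patch; the namespace, label selector, and logger arguments are taken from the test data and tests in this diff and are illustrative only:

```go
package main

import (
	"fmt"

	"k8s.io/client-go/kubernetes/fake"

	"github.com/stackvista/stackstate-backup-cli/internal/clients/k8s"
	"github.com/stackvista/stackstate-backup-cli/internal/foundation/logger"
	"github.com/stackvista/stackstate-backup-cli/internal/orchestration/scale"
)

func main() {
	// A fake clientset keeps the sketch runnable without a cluster; a real
	// command would use the client provided by the app context instead.
	k8sClient := k8s.NewTestClient(fake.NewSimpleClientset())
	log := logger.New(true, false) // constructed as in the tests below (quiet logger)

	namespace := "suse-observability"
	selector := "observability.suse.com/scalable-during-stackgraph-restore=true"

	// Scale down; previous replica counts are recorded in annotations.
	scaled, err := scale.ScaleDown(k8sClient, namespace, selector, log)
	if err != nil {
		panic(err)
	}
	fmt.Printf("scaled down %d deployment(s)\n", len(scaled))

	// ... run the restore Job here ...

	// Restore the recorded replica counts and drop the annotations.
	if err := scale.ScaleUpFromAnnotations(k8sClient, namespace, selector, log); err != nil {
		panic(err)
	}
}
```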
+// +//nolint:revive // Package name "scale" with function "ScaleDown" is intentionally verbose for clarity +func ScaleDown(k8sClient *k8s.Client, namespace, labelSelector string, log *logger.Logger) ([]k8s.DeploymentScale, error) { + log.Infof("Scaling down deployments (selector: %s)...", labelSelector) + + scaledDeployments, err := k8sClient.ScaleDownDeployments(namespace, labelSelector) + if err != nil { + return nil, fmt.Errorf("failed to scale down deployments: %w", err) + } + + if len(scaledDeployments) == 0 { + log.Infof("No deployments found to scale down") + return scaledDeployments, nil + } + + log.Successf("Scaled down %d deployment(s):", len(scaledDeployments)) + for _, dep := range scaledDeployments { + log.Infof(" - %s (replicas: %d -> 0)", dep.Name, dep.Replicas) + } + + // Wait for pods to terminate + if err := waitForPodsToTerminate(k8sClient, namespace, labelSelector, log); err != nil { + return scaledDeployments, fmt.Errorf("failed waiting for pods to terminate: %w", err) + } + + return scaledDeployments, nil +} + +// waitForPodsToTerminate polls for pod termination until all pods matching the label selector are gone +func waitForPodsToTerminate(k8sClient *k8s.Client, namespace, labelSelector string, log *logger.Logger) error { + ctx := context.Background() + ticker := time.NewTicker(podTerminationCheckInterval) + defer ticker.Stop() + + timeout := time.After(podTerminationTimeout) + + log.Infof("Waiting for pods to terminate...") + + for { + select { + case <-timeout: + return fmt.Errorf("timeout waiting for pods to terminate after %v", podTerminationTimeout) + case <-ticker.C: + podList, err := k8sClient.Clientset().CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{ + LabelSelector: labelSelector, + }) + if err != nil { + return fmt.Errorf("failed to list pods: %w", err) + } + + // Filter out pods in terminal states (Succeeded, Failed) or being deleted + activePods := 0 + for _, pod := range podList.Items { + if pod.Status.Phase != corev1.PodSucceeded && + pod.Status.Phase != corev1.PodFailed && + pod.DeletionTimestamp == nil { + activePods++ + } + } + + if activePods == 0 && len(podList.Items) == 0 { + log.Successf("All pods have terminated") + return nil + } + + log.Infof("Waiting for %d pod(s) to terminate...", activePods+len(podList.Items)-activePods) + } + } +} + +// ScaleUpFromAnnotations scales up deployments that have pre-restore-replicas annotations +// This is used to scale up deployments after a background restore job completes +// +//nolint:revive // Package name "scale" with function "ScaleUpFromAnnotations" is intentionally verbose for clarity +func ScaleUpFromAnnotations(k8sClient *k8s.Client, namespace, labelSelector string, log *logger.Logger) error { + log.Infof("Scaling up deployments from annotations (selector: %s)...", labelSelector) + + scaledDeployments, err := k8sClient.ScaleUpDeploymentsFromAnnotations(namespace, labelSelector) + if err != nil { + return fmt.Errorf("failed to scale up deployments from annotations: %w", err) + } + + if len(scaledDeployments) == 0 { + log.Infof("No deployments found with pre-restore annotations to scale up") + return nil + } + + log.Successf("Scaled up %d deployment(s) successfully:", len(scaledDeployments)) + for _, dep := range scaledDeployments { + log.Infof(" - %s (replicas: 0 -> %d)", dep.Name, dep.Replicas) + } + + return nil +} diff --git a/internal/orchestration/scale/scale_test.go b/internal/orchestration/scale/scale_test.go new file mode 100644 index 0000000..9a76bf1 --- /dev/null +++ 
b/internal/orchestration/scale/scale_test.go @@ -0,0 +1,587 @@ +package scale + +import ( + "context" + "testing" + "time" + + "github.com/stackvista/stackstate-backup-cli/internal/clients/k8s" + "github.com/stackvista/stackstate-backup-cli/internal/foundation/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + appsv1 "k8s.io/api/apps/v1" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" +) + +// TestScaleDown_Success tests successful scale down with immediate pod termination +func TestScaleDown_Success(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := k8s.NewTestClient(fakeClient) + log := logger.New(true, false) // quiet mode for tests + + // Create test deployments + deploy1 := createDeployment("deploy1", map[string]string{"app": "test"}, 3) + deploy2 := createDeployment("deploy2", map[string]string{"app": "test"}, 5) + + _, err := fakeClient.AppsV1().Deployments("test-ns").Create( + context.Background(), &deploy1, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + _, err = fakeClient.AppsV1().Deployments("test-ns").Create( + context.Background(), &deploy2, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Create pods that match the label selector + pod1 := createPod("pod1", map[string]string{"app": "test"}, corev1.PodRunning) + pod2 := createPod("pod2", map[string]string{"app": "test"}, corev1.PodRunning) + + _, err = fakeClient.CoreV1().Pods("test-ns").Create( + context.Background(), &pod1, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + _, err = fakeClient.CoreV1().Pods("test-ns").Create( + context.Background(), &pod2, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Delete pods immediately (simulating fast termination) + go func() { + time.Sleep(100 * time.Millisecond) + _ = fakeClient.CoreV1().Pods("test-ns").Delete( + context.Background(), "pod1", metav1.DeleteOptions{}, + ) + _ = fakeClient.CoreV1().Pods("test-ns").Delete( + context.Background(), "pod2", metav1.DeleteOptions{}, + ) + }() + + // Execute scale down + scaledDeployments, err := ScaleDown(client, "test-ns", "app=test", log) + + // Assertions + require.NoError(t, err) + assert.Len(t, scaledDeployments, 2) + assert.Equal(t, "deploy1", scaledDeployments[0].Name) + assert.Equal(t, int32(3), scaledDeployments[0].Replicas) + assert.Equal(t, "deploy2", scaledDeployments[1].Name) + assert.Equal(t, int32(5), scaledDeployments[1].Replicas) + + // Verify deployments were scaled to 0 + deploy1After, err := fakeClient.AppsV1().Deployments("test-ns").Get( + context.Background(), "deploy1", metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, int32(0), *deploy1After.Spec.Replicas) +} + +// TestScaleDown_NoDeployments tests scale down when no deployments match the selector +func TestScaleDown_NoDeployments(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := k8s.NewTestClient(fakeClient) + log := logger.New(true, false) + + // Create deployment with different labels + deploy := createDeployment("other-deploy", map[string]string{"app": "other"}, 2) + _, err := fakeClient.AppsV1().Deployments("test-ns").Create( + context.Background(), &deploy, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Execute scale down with non-matching selector + scaledDeployments, err := ScaleDown(client, "test-ns", "app=test", log) + + // Assertions + require.NoError(t, err) + assert.Empty(t, scaledDeployments) + + // Verify the other deployment was not 
scaled + deployAfter, err := fakeClient.AppsV1().Deployments("test-ns").Get( + context.Background(), "other-deploy", metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, int32(2), *deployAfter.Spec.Replicas) +} + +// TestScaleDown_PodsInTerminalState tests that pods in terminal states are ignored +func TestScaleDown_PodsInTerminalState(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := k8s.NewTestClient(fakeClient) + log := logger.New(true, false) + + // Create deployment + deploy := createDeployment("deploy1", map[string]string{"app": "test"}, 3) + _, err := fakeClient.AppsV1().Deployments("test-ns").Create( + context.Background(), &deploy, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Create pods in terminal states (should be ignored by waitForPodsToTerminate) + pod1 := createPod("pod1", map[string]string{"app": "test"}, corev1.PodSucceeded) + pod2 := createPod("pod2", map[string]string{"app": "test"}, corev1.PodFailed) + + _, err = fakeClient.CoreV1().Pods("test-ns").Create( + context.Background(), &pod1, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + _, err = fakeClient.CoreV1().Pods("test-ns").Create( + context.Background(), &pod2, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Delete terminal pods (simulating cleanup) + go func() { + time.Sleep(100 * time.Millisecond) + _ = fakeClient.CoreV1().Pods("test-ns").Delete( + context.Background(), "pod1", metav1.DeleteOptions{}, + ) + _ = fakeClient.CoreV1().Pods("test-ns").Delete( + context.Background(), "pod2", metav1.DeleteOptions{}, + ) + }() + + // Execute scale down + scaledDeployments, err := ScaleDown(client, "test-ns", "app=test", log) + + // Assertions - should succeed quickly since pods are in terminal states + require.NoError(t, err) + assert.Len(t, scaledDeployments, 1) +} + +// TestScaleDown_PodsBeingDeleted tests that pods with DeletionTimestamp are handled correctly +func TestScaleDown_PodsBeingDeleted(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := k8s.NewTestClient(fakeClient) + log := logger.New(true, false) + + // Create deployment + deploy := createDeployment("deploy1", map[string]string{"app": "test"}, 2) + _, err := fakeClient.AppsV1().Deployments("test-ns").Create( + context.Background(), &deploy, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Create pod with DeletionTimestamp set (being deleted) + now := metav1.Now() + pod := createPod("pod1", map[string]string{"app": "test"}, corev1.PodRunning) + pod.DeletionTimestamp = &now + + _, err = fakeClient.CoreV1().Pods("test-ns").Create( + context.Background(), &pod, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Delete pod after a short delay + go func() { + time.Sleep(100 * time.Millisecond) + _ = fakeClient.CoreV1().Pods("test-ns").Delete( + context.Background(), "pod1", metav1.DeleteOptions{}, + ) + }() + + // Execute scale down + scaledDeployments, err := ScaleDown(client, "test-ns", "app=test", log) + + // Assertions + require.NoError(t, err) + assert.Len(t, scaledDeployments, 1) +} + +// TestScaleDown_K8sError tests error handling when K8s API fails +func TestScaleDown_K8sError(t *testing.T) { + // Note: This test demonstrates the pattern, but fake clientset doesn't easily simulate errors + // In a real scenario with mocks, we would inject errors + fakeClient := fake.NewSimpleClientset() + client := k8s.NewTestClient(fakeClient) + log := logger.New(true, false) + + // Attempt to scale in non-existent namespace (will return empty list, not 
error with fake client) + scaledDeployments, err := ScaleDown(client, "nonexistent-ns", "app=test", log) + + // With fake client, this succeeds with empty results + require.NoError(t, err) + assert.Empty(t, scaledDeployments) +} + +// TestScaleDown_IntegrationWithScaleUpFromAnnotations tests the full cycle +func TestScaleDown_IntegrationWithScaleUpFromAnnotations(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := k8s.NewTestClient(fakeClient) + log := logger.New(true, false) + + // Create deployments + deploy1 := createDeployment("deploy1", map[string]string{"app": "test"}, 3) + deploy2 := createDeployment("deploy2", map[string]string{"app": "test"}, 5) + + _, err := fakeClient.AppsV1().Deployments("test-ns").Create( + context.Background(), &deploy1, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + _, err = fakeClient.AppsV1().Deployments("test-ns").Create( + context.Background(), &deploy2, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Create pods (will be removed to simulate termination) + pod1 := createPod("pod1", map[string]string{"app": "test"}, corev1.PodRunning) + _, err = fakeClient.CoreV1().Pods("test-ns").Create( + context.Background(), &pod1, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Delete pod after delay + go func() { + time.Sleep(100 * time.Millisecond) + _ = fakeClient.CoreV1().Pods("test-ns").Delete( + context.Background(), "pod1", metav1.DeleteOptions{}, + ) + }() + + // Scale down + scaledDeployments, err := ScaleDown(client, "test-ns", "app=test", log) + require.NoError(t, err) + assert.Len(t, scaledDeployments, 2) + + // Verify scaled to 0 + deploy1After, err := fakeClient.AppsV1().Deployments("test-ns").Get( + context.Background(), "deploy1", metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, int32(0), *deploy1After.Spec.Replicas) + + // Verify annotations were added + assert.Equal(t, "3", deploy1After.Annotations[k8s.PreRestoreReplicasAnnotation]) + deploy2After, err := fakeClient.AppsV1().Deployments("test-ns").Get( + context.Background(), "deploy2", metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, "5", deploy2After.Annotations[k8s.PreRestoreReplicasAnnotation]) + + // Scale back up using annotations + err = ScaleUpFromAnnotations(client, "test-ns", "app=test", log) + require.NoError(t, err) + + // Verify scaled back to original values + deploy1Final, err := fakeClient.AppsV1().Deployments("test-ns").Get( + context.Background(), "deploy1", metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, int32(3), *deploy1Final.Spec.Replicas) + + // Verify annotation was removed + _, exists := deploy1Final.Annotations[k8s.PreRestoreReplicasAnnotation] + assert.False(t, exists) + + deploy2Final, err := fakeClient.AppsV1().Deployments("test-ns").Get( + context.Background(), "deploy2", metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, int32(5), *deploy2Final.Spec.Replicas) + + // Verify annotation was removed + _, exists = deploy2Final.Annotations[k8s.PreRestoreReplicasAnnotation] + assert.False(t, exists) +} + +// TestScaleUpFromAnnotations_Success tests successful scale up from annotations +func TestScaleUpFromAnnotations_Success(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := k8s.NewTestClient(fakeClient) + log := logger.New(true, false) + + // Create deployments at scale 0 with annotations + deploy1 := createDeployment("deploy1", map[string]string{"app": "test"}, 0) + deploy1.Annotations = map[string]string{ + 
k8s.PreRestoreReplicasAnnotation: "3", + } + deploy2 := createDeployment("deploy2", map[string]string{"app": "test"}, 0) + deploy2.Annotations = map[string]string{ + k8s.PreRestoreReplicasAnnotation: "5", + } + + _, err := fakeClient.AppsV1().Deployments("test-ns").Create( + context.Background(), &deploy1, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + _, err = fakeClient.AppsV1().Deployments("test-ns").Create( + context.Background(), &deploy2, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Scale up from annotations + err = ScaleUpFromAnnotations(client, "test-ns", "app=test", log) + + // Assertions + require.NoError(t, err) + + // Verify deployments were scaled up + deploy1After, err := fakeClient.AppsV1().Deployments("test-ns").Get( + context.Background(), "deploy1", metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, int32(3), *deploy1After.Spec.Replicas) + + // Verify annotation was removed + _, exists := deploy1After.Annotations[k8s.PreRestoreReplicasAnnotation] + assert.False(t, exists) + + deploy2After, err := fakeClient.AppsV1().Deployments("test-ns").Get( + context.Background(), "deploy2", metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, int32(5), *deploy2After.Spec.Replicas) + + // Verify annotation was removed + _, exists = deploy2After.Annotations[k8s.PreRestoreReplicasAnnotation] + assert.False(t, exists) +} + +// TestScaleUpFromAnnotations_NoAnnotations tests scale up when deployments have no annotations +func TestScaleUpFromAnnotations_NoAnnotations(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := k8s.NewTestClient(fakeClient) + log := logger.New(true, false) + + // Create deployments without annotations + deploy1 := createDeployment("deploy1", map[string]string{"app": "test"}, 0) + deploy2 := createDeployment("deploy2", map[string]string{"app": "test"}, 2) + + _, err := fakeClient.AppsV1().Deployments("test-ns").Create( + context.Background(), &deploy1, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + _, err = fakeClient.AppsV1().Deployments("test-ns").Create( + context.Background(), &deploy2, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Scale up from annotations + err = ScaleUpFromAnnotations(client, "test-ns", "app=test", log) + + // Assertions - should succeed (no-op) + require.NoError(t, err) + + // Verify deployments remain unchanged + deploy1After, err := fakeClient.AppsV1().Deployments("test-ns").Get( + context.Background(), "deploy1", metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, int32(0), *deploy1After.Spec.Replicas) + + deploy2After, err := fakeClient.AppsV1().Deployments("test-ns").Get( + context.Background(), "deploy2", metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, int32(2), *deploy2After.Spec.Replicas) +} + +// TestScaleUpFromAnnotations_MixedDeployments tests scale up with some annotated and some not +func TestScaleUpFromAnnotations_MixedDeployments(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := k8s.NewTestClient(fakeClient) + log := logger.New(true, false) + + // Create deployment with annotation + deploy1 := createDeployment("deploy1", map[string]string{"app": "test"}, 0) + deploy1.Annotations = map[string]string{ + k8s.PreRestoreReplicasAnnotation: "3", + } + + // Create deployment without annotation + deploy2 := createDeployment("deploy2", map[string]string{"app": "test"}, 0) + + _, err := fakeClient.AppsV1().Deployments("test-ns").Create( + context.Background(), &deploy1, 
metav1.CreateOptions{}, + ) + require.NoError(t, err) + + _, err = fakeClient.AppsV1().Deployments("test-ns").Create( + context.Background(), &deploy2, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Scale up from annotations + err = ScaleUpFromAnnotations(client, "test-ns", "app=test", log) + + // Assertions + require.NoError(t, err) + + // Verify only annotated deployment was scaled up + deploy1After, err := fakeClient.AppsV1().Deployments("test-ns").Get( + context.Background(), "deploy1", metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, int32(3), *deploy1After.Spec.Replicas) + + // Verify annotation was removed from deploy1 + _, exists := deploy1After.Annotations[k8s.PreRestoreReplicasAnnotation] + assert.False(t, exists) + + // Verify deploy2 remains unchanged + deploy2After, err := fakeClient.AppsV1().Deployments("test-ns").Get( + context.Background(), "deploy2", metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, int32(0), *deploy2After.Spec.Replicas) +} + +// TestScaleUpFromAnnotations_InvalidAnnotationValue tests error handling for invalid annotation +func TestScaleUpFromAnnotations_InvalidAnnotationValue(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := k8s.NewTestClient(fakeClient) + log := logger.New(true, false) + + // Create deployment with invalid annotation value + deploy1 := createDeployment("deploy1", map[string]string{"app": "test"}, 0) + deploy1.Annotations = map[string]string{ + k8s.PreRestoreReplicasAnnotation: "invalid", + } + + _, err := fakeClient.AppsV1().Deployments("test-ns").Create( + context.Background(), &deploy1, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Scale up from annotations + err = ScaleUpFromAnnotations(client, "test-ns", "app=test", log) + + // Should return error + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse replicas annotation") + + // Verify deployment was not scaled + deploy1After, err := fakeClient.AppsV1().Deployments("test-ns").Get( + context.Background(), "deploy1", metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, int32(0), *deploy1After.Spec.Replicas) +} + +// TestScaleUpFromAnnotations_EmptySelector tests scale up with selector matching no deployments +func TestScaleUpFromAnnotations_EmptySelector(t *testing.T) { + fakeClient := fake.NewSimpleClientset() + client := k8s.NewTestClient(fakeClient) + log := logger.New(true, false) + + // Create deployment with different labels + deploy1 := createDeployment("deploy1", map[string]string{"app": "other"}, 0) + deploy1.Annotations = map[string]string{ + k8s.PreRestoreReplicasAnnotation: "3", + } + + _, err := fakeClient.AppsV1().Deployments("test-ns").Create( + context.Background(), &deploy1, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Scale up from annotations with non-matching selector + err = ScaleUpFromAnnotations(client, "test-ns", "app=test", log) + + // Should succeed (no-op) + require.NoError(t, err) + + // Verify deployment was not scaled + deploy1After, err := fakeClient.AppsV1().Deployments("test-ns").Get( + context.Background(), "deploy1", metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, int32(0), *deploy1After.Spec.Replicas) + + // Verify annotation still exists + assert.Equal(t, "3", deploy1After.Annotations[k8s.PreRestoreReplicasAnnotation]) +} + +// TestScaleUpFromAnnotations_ZeroReplicas tests scale up with annotation value "0" +func TestScaleUpFromAnnotations_ZeroReplicas(t *testing.T) { + fakeClient := 
fake.NewSimpleClientset() + client := k8s.NewTestClient(fakeClient) + log := logger.New(true, false) + + // Create deployment with annotation value "0" + deploy1 := createDeployment("deploy1", map[string]string{"app": "test"}, 0) + deploy1.Annotations = map[string]string{ + k8s.PreRestoreReplicasAnnotation: "0", + } + + _, err := fakeClient.AppsV1().Deployments("test-ns").Create( + context.Background(), &deploy1, metav1.CreateOptions{}, + ) + require.NoError(t, err) + + // Scale up from annotations + err = ScaleUpFromAnnotations(client, "test-ns", "app=test", log) + + // Should succeed + require.NoError(t, err) + + // Verify deployment remains at 0 replicas + deploy1After, err := fakeClient.AppsV1().Deployments("test-ns").Get( + context.Background(), "deploy1", metav1.GetOptions{}, + ) + require.NoError(t, err) + assert.Equal(t, int32(0), *deploy1After.Spec.Replicas) + + // Verify annotation was removed + _, exists := deploy1After.Annotations[k8s.PreRestoreReplicasAnnotation] + assert.False(t, exists) +} + +// Helper function to create a deployment for testing +func createDeployment(name string, labels map[string]string, replicas int32) appsv1.Deployment { + return appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: "test-ns", + Labels: labels, + }, + Spec: appsv1.DeploymentSpec{ + Replicas: &replicas, + Selector: &metav1.LabelSelector{ + MatchLabels: labels, + }, + Template: corev1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Labels: labels, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "test-container", + Image: "test:latest", + }, + }, + }, + }, + }, + } +} + +// Helper function to create a pod for testing +func createPod(name string, labels map[string]string, phase corev1.PodPhase) corev1.Pod { + return corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: "test-ns", + Labels: labels, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "test-container", + Image: "test:latest", + }, + }, + }, + Status: corev1.PodStatus{ + Phase: phase, + }, + } +} diff --git a/internal/scripts/scripts.go b/internal/scripts/scripts.go new file mode 100644 index 0000000..1835e79 --- /dev/null +++ b/internal/scripts/scripts.go @@ -0,0 +1,48 @@ +// Package scripts provides embedded backup/restore scripts that are compiled into the binary. +// These scripts are used by Kubernetes Jobs for backup and restore operations. 
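To make the "used by Kubernetes Jobs" part concrete, here is a hedged sketch of how a restore command could hand one of these embedded scripts to the cluster by writing it into a ConfigMap and mounting that ConfigMap into the Job. The object names, mount path, and Job wiring below are illustrative assumptions, not taken from this patch; only `GetScript` and the image from the test data are:

```go
package main

import (
	"context"

	batchv1 "k8s.io/api/batch/v1"
	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/client-go/kubernetes/fake"

	"github.com/stackvista/stackstate-backup-cli/internal/scripts"
)

func main() {
	script, err := scripts.GetScript("restore-stackgraph-backup.sh")
	if err != nil {
		panic(err)
	}

	clientset := fake.NewSimpleClientset() // stand-in for a real clientset
	ctx := context.Background()
	ns := "suse-observability"

	// Ship the embedded script to the cluster as a ConfigMap.
	cm := &corev1.ConfigMap{
		ObjectMeta: metav1.ObjectMeta{Name: "stackgraph-restore-scripts", Namespace: ns},
		Data:       map[string]string{"restore-stackgraph-backup.sh": string(script)},
	}
	if _, err := clientset.CoreV1().ConfigMaps(ns).Create(ctx, cm, metav1.CreateOptions{}); err != nil {
		panic(err)
	}

	// Reference the ConfigMap from the Job's pod spec and execute the script.
	mode := int32(0o755)
	job := &batchv1.Job{
		ObjectMeta: metav1.ObjectMeta{Name: "stackgraph-restore", Namespace: ns},
		Spec: batchv1.JobSpec{
			Template: corev1.PodTemplateSpec{
				Spec: corev1.PodSpec{
					RestartPolicy: corev1.RestartPolicyNever,
					Volumes: []corev1.Volume{{
						Name: "restore-scripts",
						VolumeSource: corev1.VolumeSource{ConfigMap: &corev1.ConfigMapVolumeSource{
							LocalObjectReference: corev1.LocalObjectReference{Name: cm.Name},
							DefaultMode:          &mode,
						}},
					}},
					Containers: []corev1.Container{{
						Name:         "restore",
						Image:        "quay.io/stackstate/stackstate-backup:latest",
						Command:      []string{"/scripts/restore-stackgraph-backup.sh"},
						VolumeMounts: []corev1.VolumeMount{{Name: "restore-scripts", MountPath: "/scripts"}},
					}},
				},
			},
		},
	}
	if _, err := clientset.BatchV1().Jobs(ns).Create(ctx, job, metav1.CreateOptions{}); err != nil {
		panic(err)
	}
}
```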
+package scripts
+
+import (
+	"embed"
+	"fmt"
+	"io/fs"
+)
+
+// Embed all scripts from the scripts directory that sits alongside this package
+// Note: embed paths are relative to this source file and cannot reference parent directories,
+// so the scripts are kept in the scripts/ subdirectory next to this file and embedded from there
+//
+//go:embed all:scripts
+var embeddedScripts embed.FS
+
+// GetScript retrieves an embedded script by filename
+func GetScript(filename string) ([]byte, error) {
+	path := fmt.Sprintf("scripts/%s", filename)
+	data, err := embeddedScripts.ReadFile(path)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read embedded script %s: %w", filename, err)
+	}
+	return data, nil
+}
+
+// ListScripts returns a list of all embedded script filenames
+func ListScripts() ([]string, error) {
+	entries, err := embeddedScripts.ReadDir("scripts")
+	if err != nil {
+		return nil, fmt.Errorf("failed to list embedded scripts: %w", err)
+	}
+
+	var scripts []string
+	for _, entry := range entries {
+		if !entry.IsDir() {
+			scripts = append(scripts, entry.Name())
+		}
+	}
+	return scripts, nil
+}
+
+// GetScriptsFS returns the embedded filesystem containing all scripts
+// This can be used to access scripts without extracting them individually
+func GetScriptsFS() (fs.FS, error) {
+	return fs.Sub(embeddedScripts, "scripts")
+}
diff --git a/internal/scripts/scripts/restore-stackgraph-backup.sh b/internal/scripts/scripts/restore-stackgraph-backup.sh
new file mode 100644
index 0000000..2e89c30
--- /dev/null
+++ b/internal/scripts/scripts/restore-stackgraph-backup.sh
@@ -0,0 +1,33 @@
+#!/usr/bin/env bash
+set -Eeuo pipefail
+
+TMP_DIR=/tmp-data
+
+export AWS_ACCESS_KEY_ID
+AWS_ACCESS_KEY_ID="$(cat /aws-keys/accesskey)"
+export AWS_SECRET_ACCESS_KEY
+AWS_SECRET_ACCESS_KEY="$(cat /aws-keys/secretkey)"
+
+echo "=== Downloading StackGraph backup \"${BACKUP_FILE}\" from bucket \"${BACKUP_STACKGRAPH_BUCKET_NAME}\"..."
+
+if [ "${BACKUP_STACKGRAPH_MULTIPART_ARCHIVE:-false}" == "false" ]; then
+  sts-toolbox aws s3 cp --endpoint "http://${MINIO_ENDPOINT}" --region minio "s3://${BACKUP_STACKGRAPH_BUCKET_NAME}/${BACKUP_STACKGRAPH_S3_PREFIX}${BACKUP_FILE}" "${TMP_DIR}/${BACKUP_FILE}"
+else
+  # If the given snapshot filename is one of the multiparts, strip the part suffix
+  # sts-backup-20240222-0730.graph.00 -> sts-backup-20240222-0730.graph
+  BACKUP_FILE="${BACKUP_FILE/%.[0-9]*/}"
+  rm -f "${TMP_DIR}/${BACKUP_FILE}".*
+  sts-toolbox aws s3 ls --endpoint "http://${MINIO_ENDPOINT}" --region minio --bucket "${BACKUP_STACKGRAPH_BUCKET_NAME}" --prefix "${BACKUP_STACKGRAPH_S3_PREFIX}${BACKUP_FILE}" | while read -r backup_file
+  do
+    sts-toolbox aws s3 cp --endpoint "http://${MINIO_ENDPOINT}" --region minio "s3://${BACKUP_STACKGRAPH_BUCKET_NAME}/${BACKUP_STACKGRAPH_S3_PREFIX}${backup_file}" "${TMP_DIR}/${backup_file}"
+  done
+  # Concatenate the multipart archive
+  find ${TMP_DIR} -name "${BACKUP_FILE}.*" | sort | while read -r multipart
+  do
+    cat "${multipart}" >> "${TMP_DIR}/${BACKUP_FILE}"
+  done
+fi
+
+echo "=== Importing StackGraph data from \"${BACKUP_FILE}\"..."
+/opt/docker/bin/stackstate-server -Dlogback.configurationFile=/opt/docker/etc_log/logback.xml -import "${TMP_DIR}/${BACKUP_FILE}" "${FORCE_DELETE}" +echo "===" diff --git a/internal/scripts/scripts_test.go b/internal/scripts/scripts_test.go new file mode 100644 index 0000000..b5375b6 --- /dev/null +++ b/internal/scripts/scripts_test.go @@ -0,0 +1,349 @@ +package scripts + +import ( + "io/fs" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestGetScript tests retrieving embedded scripts +func TestGetScript(t *testing.T) { + tests := []struct { + name string + filename string + expectError bool + validate func(*testing.T, []byte) + }{ + { + name: "retrieve existing script", + filename: "restore-stackgraph-backup.sh", + expectError: false, + validate: func(t *testing.T, data []byte) { + assert.NotEmpty(t, data, "Script content should not be empty") + assert.Greater(t, len(data), 100, "Script should have substantial content") + // Verify it's a shell script + assert.Contains(t, string(data), "#!/", "Script should have shebang") + }, + }, + { + name: "nonexistent script", + filename: "nonexistent-script.sh", + expectError: true, + validate: nil, + }, + { + name: "empty filename", + filename: "", + expectError: true, + validate: nil, + }, + { + name: "filename with path traversal attempt", + filename: "../../../etc/passwd", + expectError: true, + validate: nil, + }, + { + name: "filename with absolute path", + filename: "/etc/passwd", + expectError: true, + validate: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := GetScript(tt.filename) + + if tt.expectError { + assert.Error(t, err, "Should return error for invalid filename") + assert.Nil(t, data, "Data should be nil on error") + } else { + require.NoError(t, err, "Should not return error for valid filename") + assert.NotNil(t, data, "Data should not be nil") + if tt.validate != nil { + tt.validate(t, data) + } + } + }) + } +} + +// TestGetScript_ContentValidation tests that scripts have expected content +func TestGetScript_ContentValidation(t *testing.T) { + // Get the restore script + data, err := GetScript("restore-stackgraph-backup.sh") + require.NoError(t, err) + require.NotEmpty(t, data) + + content := string(data) + + // Verify it's a valid shell script + assert.Contains(t, content, "#!/", "Script should have shebang") + + // Verify it contains expected backup/restore commands + // (adjust these assertions based on actual script content) + assert.NotContains(t, content, "<<<<<<", "Script should not contain merge conflict markers") + assert.NotContains(t, content, ">>>>>>", "Script should not contain merge conflict markers") +} + +// TestGetScript_SameContentOnMultipleCalls verifies deterministic behavior +func TestGetScript_SameContentOnMultipleCalls(t *testing.T) { + filename := "restore-stackgraph-backup.sh" + + // Get script multiple times + data1, err1 := GetScript(filename) + require.NoError(t, err1) + + data2, err2 := GetScript(filename) + require.NoError(t, err2) + + data3, err3 := GetScript(filename) + require.NoError(t, err3) + + // All calls should return identical data + assert.Equal(t, data1, data2, "Multiple calls should return same content") + assert.Equal(t, data2, data3, "Multiple calls should return same content") +} + +// TestListScripts tests listing all embedded scripts +func TestListScripts(t *testing.T) { + scripts, err := ListScripts() + + require.NoError(t, err, "ListScripts should not return error") + 
assert.NotEmpty(t, scripts, "Should have at least one script") + + // Verify expected script is in the list + assert.Contains(t, scripts, "restore-stackgraph-backup.sh", "Should contain restore script") + + // Verify no duplicates + seen := make(map[string]bool) + for _, script := range scripts { + assert.False(t, seen[script], "Script list should not contain duplicates: %s", script) + seen[script] = true + } + + // Verify all entries are files (not directories) + for _, script := range scripts { + assert.NotEmpty(t, script, "Script filename should not be empty") + // File extension check (scripts should be .sh files) + if len(script) > 3 { + // Most should be shell scripts, but we won't enforce it strictly + t.Logf("Found script: %s", script) + } + } +} + +// TestListScripts_Consistency tests that ListScripts is deterministic +func TestListScripts_Consistency(t *testing.T) { + scripts1, err1 := ListScripts() + require.NoError(t, err1) + + scripts2, err2 := ListScripts() + require.NoError(t, err2) + + scripts3, err3 := ListScripts() + require.NoError(t, err3) + + // All calls should return same number of scripts + assert.Equal(t, len(scripts1), len(scripts2), "Multiple calls should return same number of scripts") + assert.Equal(t, len(scripts2), len(scripts3), "Multiple calls should return same number of scripts") + + // Convert to maps for easier comparison + toMap := func(scripts []string) map[string]bool { + m := make(map[string]bool) + for _, s := range scripts { + m[s] = true + } + return m + } + + map1 := toMap(scripts1) + map2 := toMap(scripts2) + map3 := toMap(scripts3) + + assert.Equal(t, map1, map2, "Multiple calls should return same scripts") + assert.Equal(t, map2, map3, "Multiple calls should return same scripts") +} + +// TestListScripts_VerifyScriptsExist tests that listed scripts can be retrieved +func TestListScripts_VerifyScriptsExist(t *testing.T) { + scripts, err := ListScripts() + require.NoError(t, err) + require.NotEmpty(t, scripts) + + // Verify each listed script can be retrieved + for _, script := range scripts { + t.Run(script, func(t *testing.T) { + data, err := GetScript(script) + assert.NoError(t, err, "Should be able to retrieve listed script: %s", script) + assert.NotEmpty(t, data, "Retrieved script should not be empty: %s", script) + }) + } +} + +// TestGetScriptsFS tests getting the embedded filesystem +func TestGetScriptsFS(t *testing.T) { + scriptsFS, err := GetScriptsFS() + + require.NoError(t, err, "GetScriptsFS should not return error") + assert.NotNil(t, scriptsFS, "FS should not be nil") + + // Verify we can read files from the FS + entries, err := fs.ReadDir(scriptsFS, ".") + require.NoError(t, err, "Should be able to read directory from FS") + assert.NotEmpty(t, entries, "FS should contain files") + + // Verify expected file exists + found := false + for _, entry := range entries { + if entry.Name() == "restore-stackgraph-backup.sh" { + found = true + assert.False(t, entry.IsDir(), "restore-stackgraph-backup.sh should be a file") + } + } + assert.True(t, found, "Should find restore-stackgraph-backup.sh in FS") +} + +// TestGetScriptsFS_ReadFile tests reading files from the embedded FS +func TestGetScriptsFS_ReadFile(t *testing.T) { + scriptsFS, err := GetScriptsFS() + require.NoError(t, err) + + // Read a known script file + data, err := fs.ReadFile(scriptsFS, "restore-stackgraph-backup.sh") + require.NoError(t, err, "Should be able to read file from FS") + assert.NotEmpty(t, data, "File content should not be empty") + + // Compare with GetScript 
result + directData, err := GetScript("restore-stackgraph-backup.sh") + require.NoError(t, err) + + assert.Equal(t, directData, data, "FS read should match GetScript result") +} + +// TestGetScriptsFS_Walk tests walking the embedded FS +func TestGetScriptsFS_Walk(t *testing.T) { + scriptsFS, err := GetScriptsFS() + require.NoError(t, err) + + fileCount := 0 + err = fs.WalkDir(scriptsFS, ".", func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if !d.IsDir() { + fileCount++ + t.Logf("Found file in FS: %s", path) + } + return nil + }) + + require.NoError(t, err, "Walking FS should not error") + assert.Greater(t, fileCount, 0, "Should find at least one file when walking FS") +} + +// TestGetScriptsFS_Consistency tests that GetScriptsFS returns consistent results +func TestGetScriptsFS_Consistency(t *testing.T) { + fs1, err1 := GetScriptsFS() + require.NoError(t, err1) + + fs2, err2 := GetScriptsFS() + require.NoError(t, err2) + + // Read same file from both FS instances + data1, err := fs.ReadFile(fs1, "restore-stackgraph-backup.sh") + require.NoError(t, err) + + data2, err := fs.ReadFile(fs2, "restore-stackgraph-backup.sh") + require.NoError(t, err) + + assert.Equal(t, data1, data2, "Multiple FS instances should provide same file content") +} + +// TestEmbeddedScriptsNotNil verifies the embedded filesystem is initialized +func TestEmbeddedScriptsNotNil(t *testing.T) { + // This test verifies that the embed.FS is properly initialized + // by attempting to read from it + _, err := embeddedScripts.ReadDir("scripts") + assert.NoError(t, err, "Embedded scripts filesystem should be initialized") +} + +// TestErrorMessages tests that error messages are descriptive +func TestErrorMessages(t *testing.T) { + tests := []struct { + name string + filename string + contains string + }{ + { + name: "nonexistent file error message", + filename: "does-not-exist.sh", + contains: "does-not-exist.sh", + }, + { + name: "invalid path error message", + filename: "../../../etc/passwd", + contains: "passwd", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := GetScript(tt.filename) + require.Error(t, err) + assert.Contains(t, err.Error(), tt.contains, "Error message should mention the problematic filename") + }) + } +} + +// TestConcurrentAccess tests that concurrent access to scripts is safe +func TestConcurrentAccess(t *testing.T) { + const numGoroutines = 10 + const numIterations = 5 + + done := make(chan bool, numGoroutines) + errors := make(chan error, numGoroutines*numIterations) + + for i := 0; i < numGoroutines; i++ { + go func() { + for j := 0; j < numIterations; j++ { + // Test GetScript + _, err := GetScript("restore-stackgraph-backup.sh") + if err != nil { + errors <- err + } + + // Test ListScripts + _, err = ListScripts() + if err != nil { + errors <- err + } + + // Test GetScriptsFS + _, err = GetScriptsFS() + if err != nil { + errors <- err + } + } + done <- true + }() + } + + // Wait for all goroutines + for i := 0; i < numGoroutines; i++ { + <-done + } + close(errors) + + // Check for errors + errorCount := 0 + for err := range errors { + t.Errorf("Concurrent access error: %v", err) + errorCount++ + } + + assert.Equal(t, 0, errorCount, "Concurrent access should be safe") +}
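Stepping back to the configuration loading added earlier in this diff: the "Secret overrides ConfigMap" rule in `LoadConfig` relies on mergo's non-zero-value override semantics. The sketch below demonstrates that behaviour on the `MinioConfig` type; the mergo import path is an assumption here (the module may pin `github.com/imdario/mergo` instead):

```go
package main

import (
	"fmt"

	"dario.cat/mergo" // assumed import path; the repository may use github.com/imdario/mergo

	"github.com/stackvista/stackstate-backup-cli/internal/foundation/config"
)

func main() {
	// Base values, as they would arrive from the ConfigMap's "config" key.
	base := config.MinioConfig{
		Service:   config.ServiceConfig{Name: "suse-observability-minio", Port: 9000, LocalPortForwardPort: 9000},
		AccessKey: "minioadmin",
		SecretKey: "minioadmin",
	}

	// Override values, as they would arrive from the Secret's "config" key.
	// Zero-valued fields (here: Service) leave the base untouched.
	override := config.MinioConfig{
		AccessKey: "secret-minio-access-key",
		SecretKey: "secret-minio-secret-key",
	}

	if err := mergo.Merge(&base, override, mergo.WithOverride); err != nil {
		panic(err)
	}

	// AccessKey/SecretKey now hold the Secret values; Service still holds the ConfigMap values.
	fmt.Printf("%+v\n", base)
}
```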