From d3430b360b7b841fe862c4b4f76a1decaf80a8d2 Mon Sep 17 00:00:00 2001
From: Marc Nuri
Date: Mon, 1 Dec 2025 14:38:30 +0100
Subject: [PATCH] fix(kubernetes)!: Provider updated for each watcher callback

The kubernetes.Provider watch-related functionality is now fully restarted
(reset) for each watch notification that affects its state, instead of only
invalidating the cached managers.

Signed-off-by: Marc Nuri
---
 pkg/kubernetes/provider_kubeconfig.go     |  56 ++++---
 pkg/kubernetes/provider_single.go         |  71 +++++----
 pkg/kubernetes/provider_watch_test.go     | 180 ++++++++++++++++++++++
 pkg/kubernetes/watcher/cluster.go         |   9 +-
 pkg/kubernetes/watcher/cluster_test.go    |  45 ++----
 pkg/kubernetes/watcher/kubeconfig.go      |  11 +-
 pkg/kubernetes/watcher/kubeconfig_test.go |  34 ++--
 pkg/kubernetes/watcher/watcher.go         |   2 +-
 pkg/mcp/mcp_watch_test.go                 |  28 ++++
 9 files changed, 326 insertions(+), 110 deletions(-)
 create mode 100644 pkg/kubernetes/provider_watch_test.go

diff --git a/pkg/kubernetes/provider_kubeconfig.go b/pkg/kubernetes/provider_kubeconfig.go
index 77d0cd24..7f66a4af 100644
--- a/pkg/kubernetes/provider_kubeconfig.go
+++ b/pkg/kubernetes/provider_kubeconfig.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"reflect"
 
 	"github.com/containers/kubernetes-mcp-server/pkg/config"
 	"github.com/containers/kubernetes-mcp-server/pkg/kubernetes/watcher"
@@ -18,6 +19,7 @@ const KubeConfigTargetParameterName = "context"
 // Kubernetes clusters using different contexts from a kubeconfig file.
 // It lazily initializes managers for each context as they are requested.
 type kubeConfigClusterProvider struct {
+	staticConfig        *config.StaticConfig
 	defaultContext      string
 	managers            map[string]*Manager
 	kubeconfigWatcher   *watcher.Kubeconfig
@@ -35,20 +37,28 @@ func init() {
 // Internally, it leverages a KubeconfigManager for each context, initializing them
 // lazily when requested.
func newKubeConfigClusterProvider(cfg *config.StaticConfig) (Provider, error) { - m, err := NewKubeconfigManager(cfg, "") + ret := &kubeConfigClusterProvider{staticConfig: cfg} + if err := ret.reset(); err != nil { + return nil, err + } + return ret, nil +} + +func (p *kubeConfigClusterProvider) reset() error { + m, err := NewKubeconfigManager(p.staticConfig, "") if err != nil { if errors.Is(err, ErrorKubeconfigInClusterNotAllowed) { - return nil, fmt.Errorf("kubeconfig ClusterProviderStrategy is invalid for in-cluster deployments: %v", err) + return fmt.Errorf("kubeconfig ClusterProviderStrategy is invalid for in-cluster deployments: %v", err) } - return nil, err + return err } rawConfig, err := m.accessControlClientset.clientCmdConfig.RawConfig() if err != nil { - return nil, err + return err } - allClusterManagers := map[string]*Manager{ + p.managers = map[string]*Manager{ rawConfig.CurrentContext: m, // we already initialized a manager for the default context, let's use it } @@ -56,16 +66,15 @@ func newKubeConfigClusterProvider(cfg *config.StaticConfig) (Provider, error) { if name == rawConfig.CurrentContext { continue // already initialized this, don't want to set it to nil } - - allClusterManagers[name] = nil + p.managers[name] = nil } - return &kubeConfigClusterProvider{ - defaultContext: rawConfig.CurrentContext, - managers: allClusterManagers, - kubeconfigWatcher: watcher.NewKubeconfig(m.accessControlClientset.clientCmdConfig), - clusterStateWatcher: watcher.NewClusterState(m.accessControlClientset.DiscoveryClient()), - }, nil + p.Close() + p.kubeconfigWatcher = watcher.NewKubeconfig(m.accessControlClientset.clientCmdConfig) + p.clusterStateWatcher = watcher.NewClusterState(m.accessControlClientset.DiscoveryClient()) + p.defaultContext = rawConfig.CurrentContext + + return nil } func (p *kubeConfigClusterProvider) managerForContext(context string) (*Manager, error) { @@ -124,20 +133,21 @@ func (p *kubeConfigClusterProvider) GetDefaultTarget() string { } func (p *kubeConfigClusterProvider) WatchTargets(reload McpReload) { - reloadWithCacheInvalidate := func() error { - // Invalidate all cached managers to force reloading on next access - for contextName := range p.managers { - if m := p.managers[contextName]; m != nil { - m.Invalidate() - } + reloadWithReset := func() error { + if err := p.reset(); err != nil { + return err } + p.WatchTargets(reload) return reload() } - p.kubeconfigWatcher.Watch(reloadWithCacheInvalidate) - p.clusterStateWatcher.Watch(reloadWithCacheInvalidate) + p.kubeconfigWatcher.Watch(reloadWithReset) + p.clusterStateWatcher.Watch(reload) } func (p *kubeConfigClusterProvider) Close() { - _ = p.kubeconfigWatcher.Close() - _ = p.clusterStateWatcher.Close() + for _, w := range []watcher.Watcher{p.kubeconfigWatcher, p.clusterStateWatcher} { + if !reflect.ValueOf(w).IsNil() { + w.Close() + } + } } diff --git a/pkg/kubernetes/provider_single.go b/pkg/kubernetes/provider_single.go index 965acacf..0996e487 100644 --- a/pkg/kubernetes/provider_single.go +++ b/pkg/kubernetes/provider_single.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "reflect" "github.com/containers/kubernetes-mcp-server/pkg/config" "github.com/containers/kubernetes-mcp-server/pkg/kubernetes/watcher" @@ -14,6 +15,7 @@ import ( // Kubernetes cluster. Used for in-cluster deployments or when multi-cluster // support is disabled. 
type singleClusterProvider struct { + staticConfig *config.StaticConfig strategy string manager *Manager kubeconfigWatcher *watcher.Kubeconfig @@ -32,31 +34,41 @@ func init() { // Otherwise, it uses a KubeconfigManager. func newSingleClusterProvider(strategy string) ProviderFactory { return func(cfg *config.StaticConfig) (Provider, error) { - if cfg != nil && cfg.KubeConfig != "" && strategy == config.ClusterProviderInCluster { - return nil, fmt.Errorf("kubeconfig file %s cannot be used with the in-cluster ClusterProviderStrategy", cfg.KubeConfig) + ret := &singleClusterProvider{ + staticConfig: cfg, + strategy: strategy, } - - var m *Manager - var err error - if strategy == config.ClusterProviderInCluster || IsInCluster(cfg) { - m, err = NewInClusterManager(cfg) - } else { - m, err = NewKubeconfigManager(cfg, "") - } - if err != nil { - if errors.Is(err, ErrorInClusterNotInCluster) { - return nil, fmt.Errorf("server must be deployed in cluster for the %s ClusterProviderStrategy: %v", strategy, err) - } + if err := ret.reset(); err != nil { return nil, err } + return ret, nil + } +} + +func (p *singleClusterProvider) reset() error { + if p.staticConfig != nil && p.staticConfig.KubeConfig != "" && p.strategy == config.ClusterProviderInCluster { + return fmt.Errorf("kubeconfig file %s cannot be used with the in-cluster ClusterProviderStrategy", + p.staticConfig.KubeConfig) + } - return &singleClusterProvider{ - manager: m, - strategy: strategy, - kubeconfigWatcher: watcher.NewKubeconfig(m.accessControlClientset.clientCmdConfig), - clusterStateWatcher: watcher.NewClusterState(m.accessControlClientset.DiscoveryClient()), - }, nil + var err error + if p.strategy == config.ClusterProviderInCluster || IsInCluster(p.staticConfig) { + p.manager, err = NewInClusterManager(p.staticConfig) + } else { + p.manager, err = NewKubeconfigManager(p.staticConfig, "") } + if err != nil { + if errors.Is(err, ErrorInClusterNotInCluster) { + return fmt.Errorf("server must be deployed in cluster for the %s ClusterProviderStrategy: %v", + p.strategy, err) + } + return err + } + + p.Close() + p.kubeconfigWatcher = watcher.NewKubeconfig(p.manager.accessControlClientset.clientCmdConfig) + p.clusterStateWatcher = watcher.NewClusterState(p.manager.accessControlClientset.DiscoveryClient()) + return nil } func (p *singleClusterProvider) IsOpenShift(ctx context.Context) bool { @@ -91,16 +103,21 @@ func (p *singleClusterProvider) GetTargetParameterName() string { } func (p *singleClusterProvider) WatchTargets(reload McpReload) { - reloadWithCacheInvalidate := func() error { - // Invalidate all cached managers to force reloading on next access - p.manager.Invalidate() + reloadWithReset := func() error { + if err := p.reset(); err != nil { + return err + } + p.WatchTargets(reload) return reload() } - p.kubeconfigWatcher.Watch(reloadWithCacheInvalidate) - p.clusterStateWatcher.Watch(reloadWithCacheInvalidate) + p.kubeconfigWatcher.Watch(reloadWithReset) + p.clusterStateWatcher.Watch(reload) } func (p *singleClusterProvider) Close() { - _ = p.kubeconfigWatcher.Close() - _ = p.clusterStateWatcher.Close() + for _, w := range []watcher.Watcher{p.kubeconfigWatcher, p.clusterStateWatcher} { + if !reflect.ValueOf(w).IsNil() { + w.Close() + } + } } diff --git a/pkg/kubernetes/provider_watch_test.go b/pkg/kubernetes/provider_watch_test.go new file mode 100644 index 00000000..95e6caa2 --- /dev/null +++ b/pkg/kubernetes/provider_watch_test.go @@ -0,0 +1,180 @@ +package kubernetes + +import ( + "errors" + "fmt" + "reflect" + "testing" + 
"time" + + "github.com/containers/kubernetes-mcp-server/internal/test" + "github.com/containers/kubernetes-mcp-server/pkg/config" + "github.com/stretchr/testify/suite" + "k8s.io/client-go/tools/clientcmd" + clientcmdapi "k8s.io/client-go/tools/clientcmd/api" +) + +type ProviderWatchTargetsTestSuite struct { + suite.Suite + mockServer *test.MockServer + discoveryClientHandler *test.DiscoveryClientHandler + kubeconfig *clientcmdapi.Config + staticConfig *config.StaticConfig +} + +func (s *ProviderWatchTargetsTestSuite) SetupTest() { + s.mockServer = test.NewMockServer() + s.discoveryClientHandler = &test.DiscoveryClientHandler{} + s.mockServer.Handle(s.discoveryClientHandler) + + s.T().Setenv("CLUSTER_STATE_POLL_INTERVAL_MS", "100") + s.T().Setenv("CLUSTER_STATE_DEBOUNCE_WINDOW_MS", "50") + + // Add multiple fake contexts to allow testing of context changes + s.kubeconfig = s.mockServer.Kubeconfig() + for i := 0; i < 10; i++ { + name := fmt.Sprintf("context-%d", i) + s.kubeconfig.Contexts[name] = clientcmdapi.NewContext() + s.kubeconfig.Contexts[name].Cluster = s.kubeconfig.Contexts[s.kubeconfig.CurrentContext].Cluster + s.kubeconfig.Contexts[name].AuthInfo = s.kubeconfig.Contexts[s.kubeconfig.CurrentContext].AuthInfo + } + + s.staticConfig = &config.StaticConfig{KubeConfig: test.KubeconfigFile(s.T(), s.kubeconfig)} +} + +func (s *ProviderWatchTargetsTestSuite) TearDownTest() { + s.mockServer.Close() +} + +func (s *ProviderWatchTargetsTestSuite) TestClusterStateChanges() { + testCases := []func() (Provider, error){ + func() (Provider, error) { return newKubeConfigClusterProvider(s.staticConfig) }, + func() (Provider, error) { + return newSingleClusterProvider(config.ClusterProviderDisabled)(s.staticConfig) + }, + } + for _, tc := range testCases { + provider, err := tc() + s.Require().NoError(err, "Expected no error from provider creation") + + s.Run("With provider "+reflect.TypeOf(provider).String(), func() { + callback, waitForCallback := CallbackWaiter() + provider.WatchTargets(callback) + s.Run("Reloads provider on cluster changes", func() { + s.discoveryClientHandler.Groups = append(s.discoveryClientHandler.Groups, `{"name":"alex.example.com","versions":[{"groupVersion":"alex.example.com/v1","version":"v1"}],"preferredVersion":{"groupVersion":"alex.example.com/v1","version":"v1"}}`) + + s.Require().NoError(waitForCallback(5 * time.Second)) + // Provider-wise the watcher.ClusterState which triggers the callback has no effect. + // We might consider removing it at some point? 
(20251202) + }) + }) + } +} + +func (s *ProviderWatchTargetsTestSuite) TestKubeConfigClusterProvider() { + provider, err := newKubeConfigClusterProvider(s.staticConfig) + s.Require().NoError(err, "Expected no error from provider creation") + + callback, waitForCallback := CallbackWaiter() + provider.WatchTargets(callback) + + s.Run("KubeConfigClusterProvider updates targets (reset) on kubeconfig change", func() { + s.kubeconfig.CurrentContext = "context-1" + s.Require().NoError(clientcmd.WriteToFile(*s.kubeconfig, s.staticConfig.KubeConfig)) + s.Require().NoError(waitForCallback(5 * time.Second)) + + s.Run("Replaces default target with new context", func() { + s.Equal("context-1", provider.GetDefaultTarget(), "Expected default target context to be updated") + }) + s.Run("Adds new context to targets", func() { + targets, err := provider.GetTargets(s.T().Context()) + s.Require().NoError(err, "Expected no error from GetTargets") + s.Contains(targets, "context-1") + }) + s.Run("Has derived Kubernetes for new context", func() { + k, err := provider.GetDerivedKubernetes(s.T().Context(), "context-1") + s.Require().NoError(err, "Expected no error from GetDerivedKubernetes for context-1") + s.NotNil(k, "Expected Kubernetes from GetDerivedKubernetes for context-1") + s.Run("Derived Kubernetes points to correct context", func() { + cfg, err := k.AccessControlClientset().ToRawKubeConfigLoader().RawConfig() + s.Require().NoError(err, "Expected no error from ToRawKubeConfigLoader") + s.Equal("context-1", cfg.CurrentContext, "Expected Kubernetes to point to changed-context") + }) + }) + + s.Run("Keeps watching for further changes", func() { + s.kubeconfig.CurrentContext = "context-2" + s.Require().NoError(clientcmd.WriteToFile(*s.kubeconfig, s.staticConfig.KubeConfig)) + s.Require().NoError(waitForCallback(5 * time.Second)) + + s.Run("Replaces default target with new context", func() { + s.Equal("context-2", provider.GetDefaultTarget(), "Expected default target context to be updated") + }) + }) + }) +} + +func (s *ProviderWatchTargetsTestSuite) TestSingleClusterProvider() { + provider, err := newSingleClusterProvider(config.ClusterProviderDisabled)(s.staticConfig) + s.Require().NoError(err, "Expected no error from provider creation") + + callback, waitForCallback := CallbackWaiter() + provider.WatchTargets(callback) + + s.Run("SingleClusterProvider reloads/resets on kubeconfig change", func() { + s.kubeconfig.CurrentContext = "context-1" + s.Require().NoError(clientcmd.WriteToFile(*s.kubeconfig, s.staticConfig.KubeConfig)) + s.Require().NoError(waitForCallback(5 * time.Second)) + + s.Run("Derived Kubernetes points to updated context", func() { + k, err := provider.GetDerivedKubernetes(s.T().Context(), "") + s.Require().NoError(err, "Expected no error from GetDerivedKubernetes for context-1") + s.NotNil(k, "Expected Kubernetes from GetDerivedKubernetes for context-1") + s.Run("Derived Kubernetes points to correct context", func() { + cfg, err := k.AccessControlClientset().ToRawKubeConfigLoader().RawConfig() + s.Require().NoError(err, "Expected no error from ToRawKubeConfigLoader") + s.Equal("context-1", cfg.CurrentContext, "Expected Kubernetes to point to changed-context") + }) + }) + + s.Run("Keeps watching for further changes", func() { + s.kubeconfig.CurrentContext = "context-2" + s.Require().NoError(clientcmd.WriteToFile(*s.kubeconfig, s.staticConfig.KubeConfig)) + s.Require().NoError(waitForCallback(5 * time.Second)) + + s.Run("Derived Kubernetes points to updated context", func() { + k, err := 
provider.GetDerivedKubernetes(s.T().Context(), "") + s.Require().NoError(err, "Expected no error from GetDerivedKubernetes for context-2") + s.NotNil(k, "Expected Kubernetes from GetDerivedKubernetes for context-2") + cfg, err := k.AccessControlClientset().ToRawKubeConfigLoader().RawConfig() + s.Require().NoError(err, "Expected no error from ToRawKubeConfigLoader") + s.Equal("context-2", cfg.CurrentContext, "Expected Kubernetes to point to changed-context") + }) + }) + }) +} + +// CallbackWaiter returns a callback and wait function that can be used multiple times. +func CallbackWaiter() (callback func() error, waitFunc func(timeout time.Duration) error) { + signal := make(chan struct{}, 1) + callback = func() error { + select { + case signal <- struct{}{}: + default: + } + return nil + } + waitFunc = func(timeout time.Duration) error { + select { + case <-signal: + case <-time.After(timeout): + return errors.New("timeout waiting for callback") + } + return nil + } + return +} + +func TestProviderWatchTargetsTestSuite(t *testing.T) { + suite.Run(t, new(ProviderWatchTargetsTestSuite)) +} diff --git a/pkg/kubernetes/watcher/cluster.go b/pkg/kubernetes/watcher/cluster.go index 1f07bb13..09f1db15 100644 --- a/pkg/kubernetes/watcher/cluster.go +++ b/pkg/kubernetes/watcher/cluster.go @@ -137,7 +137,7 @@ func (w *ClusterState) Watch(onChange func() error) { } // Close stops the cluster state watcher -func (w *ClusterState) Close() error { +func (w *ClusterState) Close() { w.mu.Lock() defer w.mu.Unlock() @@ -146,17 +146,17 @@ func (w *ClusterState) Close() error { } if w.stopCh == nil || w.stoppedCh == nil { - return nil + return // Already closed } if !w.started { - return nil + return } select { case <-w.stopCh: // Already closed or stopped - return nil + return default: close(w.stopCh) w.mu.Unlock() @@ -167,7 +167,6 @@ func (w *ClusterState) Close() error { w.stopCh = make(chan struct{}) w.stoppedCh = make(chan struct{}) } - return nil } func (w *ClusterState) captureState() clusterState { diff --git a/pkg/kubernetes/watcher/cluster_test.go b/pkg/kubernetes/watcher/cluster_test.go index 42e234e7..12cec4b1 100644 --- a/pkg/kubernetes/watcher/cluster_test.go +++ b/pkg/kubernetes/watcher/cluster_test.go @@ -173,7 +173,7 @@ func (s *ClusterStateTestSuite) TestWatch() { go func() { watcher.Watch(onChange) }() - defer func() { _ = watcher.Close() }() + s.T().Cleanup(watcher.Close) // Wait for the watcher to capture initial state s.waitForWatcherState(watcher) @@ -210,7 +210,7 @@ func (s *ClusterStateTestSuite) TestWatch() { go func() { watcher.Watch(onChange) }() - s.T().Cleanup(func() { _ = watcher.Close() }) + s.T().Cleanup(watcher.Close) // Wait for initial state capture s.waitForWatcherState(watcher) @@ -244,7 +244,7 @@ func (s *ClusterStateTestSuite) TestWatch() { go func() { watcher.Watch(onChange) }() - s.T().Cleanup(func() { _ = watcher.Close() }) + s.T().Cleanup(watcher.Close) // Wait for the watcher to capture initial state s.waitForWatcherState(watcher) @@ -277,7 +277,7 @@ func (s *ClusterStateTestSuite) TestWatch() { go func() { watcher.Watch(onChange) }() - s.T().Cleanup(func() { _ = watcher.Close() }) + s.T().Cleanup(watcher.Close) // Wait for the watcher to start and capture initial state s.waitForWatcherState(watcher) @@ -317,11 +317,8 @@ func (s *ClusterStateTestSuite) TestClose() { // Wait for the watcher to start s.waitForWatcherState(watcher) - err := watcher.Close() + watcher.Close() - s.Run("returns no error", func() { - s.NoError(err) - }) s.Run("stops polling", func() { 
beforeCount := callCount.Load() // Wait longer than poll interval to verify no more polling @@ -343,14 +340,9 @@ func (s *ClusterStateTestSuite) TestClose() { onChange := func() error { return nil } watcher.Watch(onChange) - err1 := watcher.Close() - err2 := watcher.Close() - - s.Run("first close succeeds", func() { - s.NoError(err1) - }) - s.Run("second close succeeds", func() { - s.NoError(err2) + s.NotPanics(func() { + watcher.Close() + watcher.Close() }) }) @@ -390,11 +382,12 @@ func (s *ClusterStateTestSuite) TestClose() { }), "timeout waiting for debounce timer to start") // Close the watcher before debounce window expires - err := watcher.Close() - s.NoError(err, "close should succeed") + watcher.Close() - // Verify onChange was not called (debounce timer was stopped) - s.Equal(int32(0), callCount.Load(), "onChange should not be called because debounce timer was stopped") + s.Run("debounce timer is stopped", func() { + // Verify onChange was not called (debounce timer was stopped) + s.Equal(int32(0), callCount.Load(), "onChange should not be called because debounce timer was stopped") + }) }) s.Run("handles close with nil channels", func() { @@ -403,11 +396,7 @@ func (s *ClusterStateTestSuite) TestClose() { stoppedCh: nil, } - err := watcher.Close() - - s.Run("returns no error", func() { - s.NoError(err) - }) + s.NotPanics(watcher.Close) }) s.Run("handles close on unstarted watcher", func() { @@ -420,11 +409,7 @@ func (s *ClusterStateTestSuite) TestClose() { // Close the stoppedCh channel since the goroutine never started close(watcher.stoppedCh) - err := watcher.Close() - - s.Run("returns no error", func() { - s.NoError(err) - }) + s.NotPanics(watcher.Close) }) } diff --git a/pkg/kubernetes/watcher/kubeconfig.go b/pkg/kubernetes/watcher/kubeconfig.go index 25d5803a..c7d06afe 100644 --- a/pkg/kubernetes/watcher/kubeconfig.go +++ b/pkg/kubernetes/watcher/kubeconfig.go @@ -7,7 +7,7 @@ import ( type Kubeconfig struct { clientcmd.ClientConfig - close func() error + close func() } var _ Watcher = (*Kubeconfig)(nil) @@ -46,14 +46,13 @@ func (w *Kubeconfig) Watch(onChange func() error) { } }() if w.close != nil { - _ = w.close() + w.close() } - w.close = watcher.Close + w.close = func() { _ = watcher.Close() } } -func (w *Kubeconfig) Close() error { +func (w *Kubeconfig) Close() { if w.close != nil { - return w.close() + w.close() } - return nil } diff --git a/pkg/kubernetes/watcher/kubeconfig_test.go b/pkg/kubernetes/watcher/kubeconfig_test.go index f0807822..6f3b47dd 100644 --- a/pkg/kubernetes/watcher/kubeconfig_test.go +++ b/pkg/kubernetes/watcher/kubeconfig_test.go @@ -47,7 +47,7 @@ func (s *KubeconfigTestSuite) TestNewKubeconfig() { func (s *KubeconfigTestSuite) TestWatch() { s.Run("triggers onChange callback on file modification", func() { watcher := NewKubeconfig(s.clientConfig) - defer func() { _ = watcher.Close() }() + s.T().Cleanup(watcher.Close) var changeDetected atomic.Bool onChange := func() error { @@ -91,7 +91,7 @@ func (s *KubeconfigTestSuite) TestWatch() { s.Run("handles multiple file changes", func() { watcher := NewKubeconfig(s.clientConfig) - defer func() { _ = watcher.Close() }() + s.T().Cleanup(watcher.Close) var callCount atomic.Int32 onChange := func() error { @@ -120,7 +120,7 @@ func (s *KubeconfigTestSuite) TestWatch() { s.Run("handles onChange callback errors gracefully", func() { watcher := NewKubeconfig(s.clientConfig) - defer func() { _ = watcher.Close() }() + s.T().Cleanup(watcher.Close) var errorReturned atomic.Bool onChange := func() error { @@ -146,7 
+146,7 @@ func (s *KubeconfigTestSuite) TestWatch() { s.Run("replaces previous watcher on subsequent Watch calls", func() { watcher := NewKubeconfig(s.clientConfig) - defer func() { _ = watcher.Close() }() + s.T().Cleanup(watcher.Close) var secondWatcherActive atomic.Bool @@ -187,9 +187,9 @@ func (s *KubeconfigTestSuite) TestClose() { close: nil, } - err := watcher.Close() - - s.NoError(err) + s.NotPanics(func() { + watcher.Close() + }) }) s.Run("closes watcher successfully", func() { @@ -202,9 +202,9 @@ func (s *KubeconfigTestSuite) TestClose() { return watcher.close != nil }), "timeout waiting for watcher to be ready") - err := watcher.Close() - - s.NoError(err) + s.NotPanics(func() { + watcher.Close() + }) }) s.Run("stops triggering onChange after close", func() { @@ -221,8 +221,7 @@ func (s *KubeconfigTestSuite) TestClose() { return watcher.close != nil }), "timeout waiting for watcher to be ready") - err := watcher.Close() - s.NoError(err) + watcher.Close() countAfterClose := callCount.Load() @@ -231,7 +230,7 @@ func (s *KubeconfigTestSuite) TestClose() { // Wait a reasonable amount of time to verify no callbacks are triggered // Using WaitForCondition with a condition that should NOT become true - err = test.WaitForCondition(50*time.Millisecond, func() bool { + err := test.WaitForCondition(50*time.Millisecond, func() bool { return callCount.Load() > countAfterClose }) // We expect this to timeout (return error) because no callbacks should be triggered @@ -249,11 +248,10 @@ func (s *KubeconfigTestSuite) TestClose() { return watcher.close != nil }), "timeout waiting for watcher to be ready") - err1 := watcher.Close() - err2 := watcher.Close() - - s.NoError(err1, "first close should succeed") - s.NoError(err2, "second close should succeed") + s.NotPanics(func() { + watcher.Close() + watcher.Close() + }) }) } diff --git a/pkg/kubernetes/watcher/watcher.go b/pkg/kubernetes/watcher/watcher.go index a94071e1..ec135890 100644 --- a/pkg/kubernetes/watcher/watcher.go +++ b/pkg/kubernetes/watcher/watcher.go @@ -2,5 +2,5 @@ package watcher type Watcher interface { Watch(onChange func() error) - Close() error + Close() } diff --git a/pkg/mcp/mcp_watch_test.go b/pkg/mcp/mcp_watch_test.go index b4f092f0..1e977ea8 100644 --- a/pkg/mcp/mcp_watch_test.go +++ b/pkg/mcp/mcp_watch_test.go @@ -1,6 +1,7 @@ package mcp import ( + "fmt" "os" "testing" "time" @@ -45,6 +46,19 @@ func (s *WatchKubeConfigSuite) TestNotifiesToolsChange() { s.Equal("notifications/tools/list_changed", notification.Method, "WatchKubeConfig did not notify tools change") } +func (s *WatchKubeConfigSuite) TestNotifiesToolsChangeMultipleTimes() { + // Given + s.InitMcpClient() + // When + for i := 0; i < 3; i++ { + s.WriteKubeconfig() + notification := s.WaitForNotification(5 * time.Second) + // Then + s.NotNil(notification, "WatchKubeConfig did not notify on iteration %d", i) + s.Equalf("notifications/tools/list_changed", notification.Method, "WatchKubeConfig did not notify tools change on iteration %d", i) + } +} + func (s *WatchKubeConfigSuite) TestClearsNoLongerAvailableTools() { s.mockServer.Handle(&test.InOpenShiftHandler{}) s.InitMcpClient() @@ -124,6 +138,20 @@ func (s *WatchClusterStateSuite) TestNotifiesToolsChangeOnAPIGroupAddition() { s.Equal("notifications/tools/list_changed", notification.Method, "cluster state watcher did not notify tools change") } +func (s *WatchClusterStateSuite) TestNotifiesToolsChangeMultipleTimes() { + // Given - Initialize with basic API groups + s.InitMcpClient() + + // When - Add multiple API 
groups to simulate cluster state changes + for i := 0; i < 3; i++ { + name := fmt.Sprintf("custom-%d", i) + s.AddAPIGroup(`{"name":"` + name + `.example.com","versions":[{"groupVersion":"` + name + `.example.com/v1","version":"v1"}],"preferredVersion":{"groupVersion":"` + name + `.example.com/v1","version":"v1"}}`) + notification := s.WaitForNotification(10 * time.Second) + s.NotNil(notification, "cluster state watcher did not notify on iteration %d", i) + s.Equalf("notifications/tools/list_changed", notification.Method, "cluster state watcher did not notify tools change on iteration %d", i) + } +} + func (s *WatchClusterStateSuite) TestDetectsOpenShiftClusterStateChange() { s.InitMcpClient()