diff --git a/controllers/aga/eventhandlers/resource_events.go b/controllers/aga/eventhandlers/resource_events.go new file mode 100644 index 000000000..bb8953822 --- /dev/null +++ b/controllers/aga/eventhandlers/resource_events.go @@ -0,0 +1,111 @@ +package eventhandlers + +import ( + "context" + "github.com/go-logr/logr" + corev1 "k8s.io/api/core/v1" + networking "k8s.io/api/networking/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/util/workqueue" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aga" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/event" + "sigs.k8s.io/controller-runtime/pkg/handler" + "sigs.k8s.io/controller-runtime/pkg/reconcile" + gwv1 "sigs.k8s.io/gateway-api/apis/v1" +) + +// NewEnqueueRequestsForResourceEvent creates a new handler for generic resource events +func NewEnqueueRequestsForResourceEvent( + resourceType aga.ResourceType, + referenceTracker *aga.ReferenceTracker, + logger logr.Logger, +) handler.EventHandler { + return &enqueueRequestsForResourceEvent{ + resourceType: resourceType, + referenceTracker: referenceTracker, + logger: logger, + } +} + +// enqueueRequestsForResourceEvent handles resource events and enqueues reconcile requests for GlobalAccelerators +// that reference the resource +type enqueueRequestsForResourceEvent struct { + resourceType aga.ResourceType + referenceTracker *aga.ReferenceTracker + logger logr.Logger +} + +// The following methods implement handler.TypedEventHandler interface + +// Create handles Create events with the typed API +func (h *enqueueRequestsForResourceEvent) Create(ctx context.Context, evt event.TypedCreateEvent[client.Object], queue workqueue.TypedRateLimitingInterface[reconcile.Request]) { + h.handleResource(ctx, evt.Object, "created", queue) +} + +// Update handles Update events with the typed API +func (h *enqueueRequestsForResourceEvent) Update(ctx context.Context, evt event.TypedUpdateEvent[client.Object], queue workqueue.TypedRateLimitingInterface[reconcile.Request]) { + h.handleResource(ctx, evt.ObjectNew, "updated", queue) +} + +// Delete handles Delete events with the typed API +func (h *enqueueRequestsForResourceEvent) Delete(ctx context.Context, evt event.TypedDeleteEvent[client.Object], queue workqueue.TypedRateLimitingInterface[reconcile.Request]) { + h.handleResource(ctx, evt.Object, "deleted", queue) +} + +// Generic handles Generic events with the typed API +func (h *enqueueRequestsForResourceEvent) Generic(ctx context.Context, evt event.TypedGenericEvent[client.Object], queue workqueue.TypedRateLimitingInterface[reconcile.Request]) { + h.handleResource(ctx, evt.Object, "generic event", queue) +} + +// handleTypedResource handles resource events for the typed interface +func (h *enqueueRequestsForResourceEvent) handleResource(_ context.Context, obj interface{}, eventType string, queue workqueue.TypedRateLimitingInterface[reconcile.Request]) { + var namespace, name string + + // Extract namespace and name based on the object type + switch res := obj.(type) { + case *corev1.Service: + namespace = res.Namespace + name = res.Name + case *networking.Ingress: + namespace = res.Namespace + name = res.Name + case *gwv1.Gateway: + namespace = res.Namespace + name = res.Name + case *unstructured.Unstructured: + namespace = res.GetNamespace() + name = res.GetName() + default: + h.logger.Error(nil, "Unknown resource type", "type", h.resourceType) + return + } + + resourceKey := aga.ResourceKey{ + Type: h.resourceType, + Name: types.NamespacedName{ + Namespace: namespace, + Name: name, + }, + } + + // If this resource is not referenced by any GA, no need to queue reconciles + if !h.referenceTracker.IsResourceReferenced(resourceKey) { + return + } + + // Get all GAs that reference this resource + gaRefs := h.referenceTracker.GetGAsForResource(resourceKey) + + // Queue reconcile for affected GAs + for _, gaRef := range gaRefs { + h.logger.V(1).Info("Enqueueing GA for reconcile due to resource event", + "resourceType", h.resourceType, + "resourceName", resourceKey.Name, + "eventType", eventType, + "ga", gaRef) + + queue.Add(reconcile.Request{NamespacedName: gaRef}) + } +} diff --git a/controllers/aga/globalaccelerator_controller.go b/controllers/aga/globalaccelerator_controller.go index 354b9b2bf..534b73649 100644 --- a/controllers/aga/globalaccelerator_controller.go +++ b/controllers/aga/globalaccelerator_controller.go @@ -35,9 +35,12 @@ import ( ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/controller" + "sigs.k8s.io/controller-runtime/pkg/event" "sigs.k8s.io/controller-runtime/pkg/reconcile" + "sigs.k8s.io/controller-runtime/pkg/source" agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" + "sigs.k8s.io/aws-load-balancer-controller/controllers/aga/eventhandlers" "sigs.k8s.io/aws-load-balancer-controller/pkg/aga" "sigs.k8s.io/aws-load-balancer-controller/pkg/config" "sigs.k8s.io/aws-load-balancer-controller/pkg/deploy" @@ -50,6 +53,7 @@ import ( "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" "sigs.k8s.io/aws-load-balancer-controller/pkg/runtime" agastatus "sigs.k8s.io/aws-load-balancer-controller/pkg/status/aga" + gwclientset "sigs.k8s.io/gateway-api/pkg/client/clientset/versioned" ) const ( @@ -83,7 +87,7 @@ const ( func NewGlobalAcceleratorReconciler(k8sClient client.Client, eventRecorder record.EventRecorder, finalizerManager k8s.FinalizerManager, config config.ControllerConfig, cloud services.Cloud, logger logr.Logger, metricsCollector lbcmetrics.MetricCollector, reconcileCounters *metricsutil.ReconcileCounters) *globalAcceleratorReconciler { // Create tracking provider - trackingProvider := tracking.NewDefaultProvider(agaTagPrefix, config.ClusterName, tracking.WithRegion(config.AWSConfig.Region)) + trackingProvider := tracking.NewDefaultProvider(agaTagPrefix, config.ClusterName, tracking.WithRegion(cloud.Region())) // Create model builder agaModelBuilder := aga.NewDefaultModelBuilder( @@ -92,7 +96,7 @@ func NewGlobalAcceleratorReconciler(k8sClient client.Client, eventRecorder recor trackingProvider, config.FeatureGates, config.ClusterName, - config.AWSConfig.Region, + cloud.Region(), config.DefaultTags, config.ExternalManagedTags, logger.WithName("aga-model-builder"), @@ -108,6 +112,18 @@ func NewGlobalAcceleratorReconciler(k8sClient client.Client, eventRecorder recor // Create status updater statusUpdater := agastatus.NewStatusUpdater(k8sClient, logger) + // Create reference tracker for endpoint tracking + referenceTracker := aga.NewReferenceTracker(logger.WithName("reference-tracker")) + + // Create DNS resolver + dnsResolver, err := aga.NewDNSResolver(cloud.ELBV2()) + if err != nil { + logger.Error(err, "Failed to create DNS resolver") + } + + // Create unified endpoint loader + endpointLoader := aga.NewEndpointLoader(k8sClient, dnsResolver, logger.WithName("endpoint-loader")) + return &globalAcceleratorReconciler{ k8sClient: k8sClient, eventRecorder: eventRecorder, @@ -120,6 +136,13 @@ func NewGlobalAcceleratorReconciler(k8sClient client.Client, eventRecorder recor metricsCollector: metricsCollector, reconcileTracker: reconcileCounters.IncrementAGA, + // Components for endpoint reference tracking + referenceTracker: referenceTracker, + dnsResolver: dnsResolver, + + // Unified endpoint loader + endpointLoader: endpointLoader, + maxConcurrentReconciles: config.GlobalAcceleratorMaxConcurrentReconciles, maxExponentialBackoffDelay: config.GlobalAcceleratorMaxExponentialBackoffDelay, } @@ -138,6 +161,21 @@ type globalAcceleratorReconciler struct { metricsCollector lbcmetrics.MetricCollector reconcileTracker func(namespaceName ktypes.NamespacedName) + // Components for endpoint reference tracking + referenceTracker *aga.ReferenceTracker + dnsResolver *aga.DNSResolver + + // Unified endpoint loader + endpointLoader aga.EndpointLoader + + // Resources manager for dedicated endpoint resource watchers + endpointResourcesManager aga.EndpointResourcesManager + + // Event channels for dedicated watchers + serviceEventChan chan event.GenericEvent + ingressEventChan chan event.GenericEvent + gatewayEventChan chan event.GenericEvent + maxConcurrentReconciles int maxExponentialBackoffDelay time.Duration } @@ -194,6 +232,13 @@ func (r *globalAcceleratorReconciler) reconcileGlobalAccelerator(ctx context.Con func (r *globalAcceleratorReconciler) cleanupGlobalAccelerator(ctx context.Context, ga *agaapi.GlobalAccelerator) error { if k8s.HasFinalizer(ga, shared_constants.GlobalAcceleratorFinalizer) { + // Clean up references in the reference tracker + gaKey := k8s.NamespacedName(ga) + r.referenceTracker.RemoveGA(gaKey) + + // Clean up resource watches + r.endpointResourcesManager.RemoveGA(gaKey) + // TODO: Implement cleanup logic for AWS Global Accelerator resources (Only cleaning up accelerator for now) if err := r.cleanupGlobalAcceleratorResources(ctx, ga); err != nil { r.eventRecorder.Event(ga, corev1.EventTypeWarning, k8s.GlobalAcceleratorEventReasonFailedCleanup, fmt.Sprintf("Failed cleanup due to %v", err)) @@ -224,6 +269,29 @@ func (r *globalAcceleratorReconciler) buildModel(ctx context.Context, ga *agaapi func (r *globalAcceleratorReconciler) reconcileGlobalAcceleratorResources(ctx context.Context, ga *agaapi.GlobalAccelerator) error { r.logger.Info("Reconciling GlobalAccelerator resources", "globalAccelerator", k8s.NamespacedName(ga)) + + // Get all endpoints from GA + endpoints := aga.GetAllEndpointsFromGA(ga) + + // Track referenced endpoints + r.referenceTracker.UpdateReferencesForGA(ga, endpoints) + + // Update resource watches with the endpointResourcesManager + r.endpointResourcesManager.MonitorEndpointResources(ga, endpoints) + + // Validate and load endpoint status using the endpoint loader + _, fatalErrors := r.endpointLoader.LoadEndpoints(ctx, ga, endpoints) + if len(fatalErrors) > 0 { + err := fmt.Errorf("failed to load endpoints: %v", fatalErrors[0]) + r.eventRecorder.Event(ga, corev1.EventTypeWarning, k8s.GlobalAcceleratorEventReasonFailedEndpointLoad, fmt.Sprintf("Failed to reconcile due to %v", err)) + r.logger.Error(err, fmt.Sprintf("fatal error loading endpoints for %v", k8s.NamespacedName(ga))) + // Handle other endpoint loading errors + if statusErr := r.statusUpdater.UpdateStatusFailure(ctx, ga, agadeploy.EndpointLoadFailed, err.Error()); statusErr != nil { + r.logger.Error(statusErr, "Failed to update GlobalAccelerator status after endpoint load failure") + } + return err + } + var stack core.Stack var accelerator *agamodel.Accelerator var err error @@ -232,6 +300,8 @@ func (r *globalAcceleratorReconciler) reconcileGlobalAcceleratorResources(ctx co } r.metricsCollector.ObserveControllerReconcileLatency(controllerName, MetricStageBuildModel, buildModelFn) if err != nil { + r.eventRecorder.Event(ga, corev1.EventTypeWarning, k8s.GatewayEventReasonFailedBuildModel, fmt.Sprintf("Failed to build model: %v", err)) + r.logger.Error(err, fmt.Sprintf("Failed to build model for: %v", k8s.NamespacedName(ga))) // Update status to indicate model building failure if statusErr := r.statusUpdater.UpdateStatusFailure(ctx, ga, agadeploy.ModelBuildFailed, fmt.Sprintf("Failed to build model: %v", err)); statusErr != nil { r.logger.Error(statusErr, "Failed to update GlobalAccelerator status after model build failure") @@ -246,7 +316,7 @@ func (r *globalAcceleratorReconciler) reconcileGlobalAcceleratorResources(ctx co r.metricsCollector.ObserveControllerReconcileLatency(controllerName, MetricStageDeployStack, deployStackFn) if err != nil { r.eventRecorder.Event(ga, corev1.EventTypeWarning, k8s.GlobalAcceleratorEventReasonFailedDeploy, fmt.Sprintf("Failed to deploy stack due to %v", err)) - + r.logger.Error(err, fmt.Sprintf("Failed to deploy stack for: %v", k8s.NamespacedName(ga))) // Update status to indicate deployment failure if statusErr := r.statusUpdater.UpdateStatusFailure(ctx, ga, agadeploy.DeploymentFailed, fmt.Sprintf("Failed to deploy stack: %v", err)); statusErr != nil { r.logger.Error(statusErr, "Failed to update GlobalAccelerator status after deployment failure") @@ -335,21 +405,91 @@ func (r *globalAcceleratorReconciler) SetupWithManager(ctx context.Context, mgr return nil } - if err := r.setupIndexes(ctx, mgr.GetFieldIndexer()); err != nil { + // Create event channels for dedicated watchers + r.serviceEventChan = make(chan event.GenericEvent) + r.ingressEventChan = make(chan event.GenericEvent) + r.gatewayEventChan = make(chan event.GenericEvent) + + // Initialize Gateway API client using the same config + gwClient, err := gwclientset.NewForConfig(mgr.GetConfig()) + if err != nil { + r.logger.Error(err, "Failed to create Gateway API client") return err } - // TODO: Add event handlers for Services, Ingresses, and Gateways - // that are referenced by GlobalAccelerator endpoints + // Initialize the endpoint resources manager with clients + r.endpointResourcesManager = aga.NewEndpointResourcesManager( + clientSet, + gwClient, + r.serviceEventChan, + r.ingressEventChan, + r.gatewayEventChan, + r.logger.WithName("endpoint-resources-manager"), + ) + + if err := r.setupIndexes(ctx, mgr.GetFieldIndexer()); err != nil { + return err + } - return ctrl.NewControllerManagedBy(mgr). + // Set up the controller builder + ctrl, err := ctrl.NewControllerManagedBy(mgr). For(&agaapi.GlobalAccelerator{}). Named(controllerName). WithOptions(controller.Options{ MaxConcurrentReconciles: r.maxConcurrentReconciles, RateLimiter: workqueue.NewTypedItemExponentialFailureRateLimiter[reconcile.Request](5*time.Second, r.maxExponentialBackoffDelay), }). - Complete(r) + Build(r) + + if err != nil { + return err + } + + // Setup watches for resource events + if err := r.setupGlobalAcceleratorWatches(ctrl); err != nil { + return err + } + + return nil +} + +// setupGlobalAcceleratorWatches sets up watches for resources that can trigger reconciliation of GlobalAccelerator objects +func (r *globalAcceleratorReconciler) setupGlobalAcceleratorWatches(c controller.Controller) error { + loggerPrefix := r.logger.WithName("eventHandlers") + + // Create handlers for our dedicated watchers + serviceHandler := eventhandlers.NewEnqueueRequestsForResourceEvent( + aga.ServiceResourceType, + r.referenceTracker, + loggerPrefix.WithName("service-handler"), + ) + + ingressHandler := eventhandlers.NewEnqueueRequestsForResourceEvent( + aga.IngressResourceType, + r.referenceTracker, + loggerPrefix.WithName("ingress-handler"), + ) + + gatewayHandler := eventhandlers.NewEnqueueRequestsForResourceEvent( + aga.GatewayResourceType, + r.referenceTracker, + loggerPrefix.WithName("gateway-handler"), + ) + + // Add watches using the channel sources with event handlers + if err := c.Watch(source.Channel(r.serviceEventChan, serviceHandler)); err != nil { + return err + } + + if err := c.Watch(source.Channel(r.ingressEventChan, ingressHandler)); err != nil { + return err + } + + if err := c.Watch(source.Channel(r.gatewayEventChan, gatewayHandler)); err != nil { + return err + } + + return nil } func (r *globalAcceleratorReconciler) setupIndexes(ctx context.Context, fieldIndexer client.FieldIndexer) error { diff --git a/go.mod b/go.mod index 8b64d438e..89c85e8b1 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 + github.com/hashicorp/golang-lru v1.0.2 github.com/onsi/ginkgo/v2 v2.23.3 github.com/onsi/gomega v1.37.0 github.com/pkg/errors v0.9.1 @@ -148,6 +149,7 @@ require ( github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/cast v1.7.0 // indirect github.com/spf13/cobra v1.9.1 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.34.0 // indirect github.com/x448/float16 v0.8.4 // indirect diff --git a/go.sum b/go.sum index 812fd6b7b..cc031315b 100644 --- a/go.sum +++ b/go.sum @@ -228,6 +228,8 @@ github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/hashicorp/golang-lru v1.0.2 h1:dV3g9Z/unq5DpblPpw+Oqcv4dU/1omnb4Ok8iPY6p1c= +github.com/hashicorp/golang-lru v1.0.2/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/golang-lru/arc/v2 v2.0.5 h1:l2zaLDubNhW4XO3LnliVj0GXO3+/CGNJAg1dcN2Fpfw= github.com/hashicorp/golang-lru/arc/v2 v2.0.5/go.mod h1:ny6zBSQZi2JxIeYcv7kt2sH2PXJtirBN7RDhRpxPkxU= github.com/hashicorp/golang-lru/v2 v2.0.5 h1:wW7h1TG88eUIJ2i69gaE3uNVtEPIagzhGvHgwfx2Vm4= diff --git a/main.go b/main.go index 998a2ec83..bb1e04b90 100644 --- a/main.go +++ b/main.go @@ -239,7 +239,7 @@ func main() { } // Setup GlobalAccelerator controller only if enabled - if aga.IsAGAControllerEnabled(controllerCFG.FeatureGates, controllerCFG.AWSConfig.Region) { + if aga.IsAGAControllerEnabled(controllerCFG.FeatureGates, cloud.Region()) { agaReconciler := agacontroller.NewGlobalAcceleratorReconciler(mgr.GetClient(), mgr.GetEventRecorderFor("globalAccelerator"), finalizerManager, controllerCFG, cloud, ctrl.Log.WithName("controllers").WithName("globalAccelerator"), lbcMetricsCollector, reconcileCounters) if err := agaReconciler.SetupWithManager(ctx, mgr, clientSet); err != nil { @@ -424,7 +424,7 @@ func main() { networkingwebhook.NewIngressValidator(mgr.GetClient(), controllerCFG.IngressConfig, ctrl.Log, lbcMetricsCollector).SetupWithManager(mgr) // Setup GlobalAccelerator validator only if enabled - if aga.IsAGAControllerEnabled(controllerCFG.FeatureGates, controllerCFG.AWSConfig.Region) { + if aga.IsAGAControllerEnabled(controllerCFG.FeatureGates, cloud.Region()) { agawebhook.NewGlobalAcceleratorValidator(ctrl.Log, lbcMetricsCollector).SetupWithManager(mgr) } //+kubebuilder:scaffold:builder diff --git a/pkg/aga/dns_resolver.go b/pkg/aga/dns_resolver.go new file mode 100644 index 000000000..a360c4beb --- /dev/null +++ b/pkg/aga/dns_resolver.go @@ -0,0 +1,95 @@ +package aga + +import ( + "context" + "fmt" + awssdk "github.com/aws/aws-sdk-go-v2/aws" + "sync" + "time" + + elbv2sdk "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + "github.com/hashicorp/golang-lru" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" +) + +// DNSResolver resolves load balancer DNS names to ARNs +type DNSResolver struct { + elbv2Client services.ELBV2 + cache *lru.Cache + cacheMutex sync.RWMutex + ttl time.Duration +} + +type cacheEntry struct { + arn string + expireAt time.Time +} + +// NewDNSResolver creates a new DNSResolver +func NewDNSResolver(elbv2Client services.ELBV2) (*DNSResolver, error) { + // AWS Global Accelerator has a quota of 420 endpoints per AWS account (can be increased) + // Using 420 provides headroom while efficiently caching DNS-to-ARN resolutions + cache, err := lru.New(420) + if err != nil { + return nil, err + } + + return &DNSResolver{ + elbv2Client: elbv2Client, + cache: cache, + ttl: 5 * time.Minute, // Default TTL of 5 minutes + }, nil +} + +// ResolveDNSToARN resolves a load balancer DNS name to an ARN +func (r *DNSResolver) ResolveDNSToARN(ctx context.Context, dnsName string) (string, error) { + if dnsName == "" { + return "", fmt.Errorf("empty DNS name") + } + + // Check cache first + r.cacheMutex.RLock() + if value, found := r.cache.Get(dnsName); found { + entry := value.(cacheEntry) + // Check if the cache entry is still valid + if time.Now().Before(entry.expireAt) { + r.cacheMutex.RUnlock() + return entry.arn, nil + } + // Entry has expired, remove from cache + r.cache.Remove(dnsName) + } + r.cacheMutex.RUnlock() + + req := &elbv2sdk.DescribeLoadBalancersInput{} + lbs, err := r.elbv2Client.DescribeLoadBalancersAsList(ctx, req) + if err != nil { + return "", fmt.Errorf("failed to describe load balancers: %w", err) + } + if len(lbs) == 0 { + return "", fmt.Errorf("no load balancers found") + } + arn := "" + for _, lb := range lbs { + if awssdk.ToString(lb.DNSName) == dnsName { + arn = awssdk.ToString(lb.LoadBalancerArn) + break + } + } + if arn == "" { + return "", fmt.Errorf("no load balancer found for dns %s", dnsName) + } + + // Cache the result + r.cacheMutex.Lock() + r.cache.Add(dnsName, cacheEntry{ + arn: arn, + expireAt: time.Now().Add(r.ttl), + }) + r.cacheMutex.Unlock() + + return arn, nil +} + +// Ensure DNSResolver implements DNSResolverInterface +var _ DNSResolverInterface = (*DNSResolver)(nil) diff --git a/pkg/aga/dns_resolver_test.go b/pkg/aga/dns_resolver_test.go new file mode 100644 index 000000000..1f3eadc2e --- /dev/null +++ b/pkg/aga/dns_resolver_test.go @@ -0,0 +1,273 @@ +package aga + +import ( + "context" + awssdk "github.com/aws/aws-sdk-go-v2/aws" + elbv2sdk "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" + "github.com/golang/mock/gomock" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" + "testing" + "time" +) + +func TestDNSResolver_ResolveDNSToARN(t *testing.T) { + type describeLoadBalancersAsListCall struct { + req *elbv2sdk.DescribeLoadBalancersInput + resp []types.LoadBalancer + err error + } + + type fields struct { + elbv2Client *services.MockELBV2 + describeLoadBalancersCalls []describeLoadBalancersAsListCall + } + + tests := []struct { + name string + fields fields + dnsName string + wantARN string + wantErr bool + setupFields func(fields fields) + }{ + { + name: "successfully resolves DNS to ARN", + fields: fields{ + elbv2Client: services.NewMockELBV2(gomock.NewController(t)), + describeLoadBalancersCalls: []describeLoadBalancersAsListCall{ + { + req: &elbv2sdk.DescribeLoadBalancersInput{}, + resp: []types.LoadBalancer{ + { + DNSName: awssdk.String("test-lb.us-west-2.elb.amazonaws.com"), + LoadBalancerArn: awssdk.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-lb/1234567890abcdef"), + }, + { + DNSName: awssdk.String("another-lb.us-west-2.elb.amazonaws.com"), + LoadBalancerArn: awssdk.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/another-lb/0987654321fedcba"), + }, + }, + err: nil, + }, + }, + }, + dnsName: "test-lb.us-west-2.elb.amazonaws.com", + wantARN: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-lb/1234567890abcdef", + wantErr: false, + setupFields: func(fields fields) { + gomock.InOrder( + fields.elbv2Client.EXPECT(). + DescribeLoadBalancersAsList(gomock.Any(), fields.describeLoadBalancersCalls[0].req). + Return(fields.describeLoadBalancersCalls[0].resp, fields.describeLoadBalancersCalls[0].err), + ) + }, + }, + { + name: "uses cached ARN on second call", + fields: fields{ + elbv2Client: services.NewMockELBV2(gomock.NewController(t)), + describeLoadBalancersCalls: []describeLoadBalancersAsListCall{ + { + req: &elbv2sdk.DescribeLoadBalancersInput{}, + resp: []types.LoadBalancer{ + { + DNSName: awssdk.String("test-lb.us-west-2.elb.amazonaws.com"), + LoadBalancerArn: awssdk.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-lb/1234567890abcdef"), + }, + }, + err: nil, + }, + }, + }, + dnsName: "test-lb.us-west-2.elb.amazonaws.com", + wantARN: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-lb/1234567890abcdef", + wantErr: false, + setupFields: func(fields fields) { + gomock.InOrder( + fields.elbv2Client.EXPECT(). + DescribeLoadBalancersAsList(gomock.Any(), fields.describeLoadBalancersCalls[0].req). + Return(fields.describeLoadBalancersCalls[0].resp, fields.describeLoadBalancersCalls[0].err). + Times(1), + ) + }, + }, + { + name: "returns error for empty DNS name", + fields: fields{ + elbv2Client: services.NewMockELBV2(gomock.NewController(t)), + describeLoadBalancersCalls: []describeLoadBalancersAsListCall{}, + }, + dnsName: "", + wantARN: "", + wantErr: true, + setupFields: func(fields fields) { + // No calls expected for empty DNS name + }, + }, + { + name: "returns error when no load balancers found", + fields: fields{ + elbv2Client: services.NewMockELBV2(gomock.NewController(t)), + describeLoadBalancersCalls: []describeLoadBalancersAsListCall{ + { + req: &elbv2sdk.DescribeLoadBalancersInput{}, + resp: []types.LoadBalancer{}, + err: nil, + }, + }, + }, + dnsName: "test-lb.us-west-2.elb.amazonaws.com", + wantARN: "", + wantErr: true, + setupFields: func(fields fields) { + gomock.InOrder( + fields.elbv2Client.EXPECT(). + DescribeLoadBalancersAsList(gomock.Any(), fields.describeLoadBalancersCalls[0].req). + Return(fields.describeLoadBalancersCalls[0].resp, fields.describeLoadBalancersCalls[0].err), + ) + }, + }, + { + name: "returns error when no matching load balancer found", + fields: fields{ + elbv2Client: services.NewMockELBV2(gomock.NewController(t)), + describeLoadBalancersCalls: []describeLoadBalancersAsListCall{ + { + req: &elbv2sdk.DescribeLoadBalancersInput{}, + resp: []types.LoadBalancer{ + { + DNSName: awssdk.String("another-lb.us-west-2.elb.amazonaws.com"), + LoadBalancerArn: awssdk.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/another-lb/0987654321fedcba"), + }, + }, + err: nil, + }, + }, + }, + dnsName: "test-lb.us-west-2.elb.amazonaws.com", + wantARN: "", + wantErr: true, + setupFields: func(fields fields) { + gomock.InOrder( + fields.elbv2Client.EXPECT(). + DescribeLoadBalancersAsList(gomock.Any(), fields.describeLoadBalancersCalls[0].req). + Return(fields.describeLoadBalancersCalls[0].resp, fields.describeLoadBalancersCalls[0].err), + ) + }, + }, + { + name: "returns error when API call fails", + fields: fields{ + elbv2Client: services.NewMockELBV2(gomock.NewController(t)), + describeLoadBalancersCalls: []describeLoadBalancersAsListCall{ + { + req: &elbv2sdk.DescribeLoadBalancersInput{}, + resp: nil, + err: errors.New("API error"), + }, + }, + }, + dnsName: "test-lb.us-west-2.elb.amazonaws.com", + wantARN: "", + wantErr: true, + setupFields: func(fields fields) { + gomock.InOrder( + fields.elbv2Client.EXPECT(). + DescribeLoadBalancersAsList(gomock.Any(), fields.describeLoadBalancersCalls[0].req). + Return(fields.describeLoadBalancersCalls[0].resp, fields.describeLoadBalancersCalls[0].err), + ) + }, + }, + } + + // Add a test case for cache expiration + t.Run("cache expiration", func(t *testing.T) { + ctrl := gomock.NewController(t) + elbv2Client := services.NewMockELBV2(ctrl) + dnsName := "expired-lb.us-west-2.elb.amazonaws.com" + originalARN := "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/expired-lb/original" + updatedARN := "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/expired-lb/updated" + + // Create resolver with a small TTL for testing + resolver, err := NewDNSResolver(elbv2Client) + assert.NoError(t, err) + + // Override the TTL for testing + resolver.ttl = 10 * time.Millisecond + + // First call, should resolve through API + elbv2Client.EXPECT(). + DescribeLoadBalancersAsList(gomock.Any(), &elbv2sdk.DescribeLoadBalancersInput{}). + Return([]types.LoadBalancer{ + { + DNSName: awssdk.String(dnsName), + LoadBalancerArn: awssdk.String(originalARN), + }, + }, nil). + Times(1) + + gotARN1, err := resolver.ResolveDNSToARN(context.Background(), dnsName) + assert.NoError(t, err) + assert.Equal(t, originalARN, gotARN1) + + // Wait for cache to expire + time.Sleep(15 * time.Millisecond) + + // Second call after cache expiry, should resolve through API again + elbv2Client.EXPECT(). + DescribeLoadBalancersAsList(gomock.Any(), &elbv2sdk.DescribeLoadBalancersInput{}). + Return([]types.LoadBalancer{ + { + DNSName: awssdk.String(dnsName), + LoadBalancerArn: awssdk.String(updatedARN), // Different ARN to verify re-resolution + }, + }, nil). + Times(1) + + gotARN2, err := resolver.ResolveDNSToARN(context.Background(), dnsName) + assert.NoError(t, err) + assert.Equal(t, updatedARN, gotARN2, "ARN should be updated after cache expiry") + }) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupFields(tt.fields) + + resolver, err := NewDNSResolver(tt.fields.elbv2Client) + assert.NoError(t, err) + + // For cache test, we need to call it twice + if tt.name == "uses cached ARN on second call" { + // First call + gotARN, err := resolver.ResolveDNSToARN(context.Background(), tt.dnsName) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.wantARN, gotARN) + } + + // Second call - should use cache + gotARN, err = resolver.ResolveDNSToARN(context.Background(), tt.dnsName) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.wantARN, gotARN) + } + } else { + // Regular test + gotARN, err := resolver.ResolveDNSToARN(context.Background(), tt.dnsName) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.wantARN, gotARN) + } + } + }) + } +} diff --git a/pkg/aga/endpoint_errors.go b/pkg/aga/endpoint_errors.go new file mode 100644 index 000000000..b7cb2eef2 --- /dev/null +++ b/pkg/aga/endpoint_errors.go @@ -0,0 +1,104 @@ +package aga + +import ( + "errors" + "fmt" + awssdk "github.com/aws/aws-sdk-go-v2/aws" + + agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" +) + +// EndpointLoadErrorType categorizes endpoint errors by severity +type EndpointLoadErrorType string + +const ( + // ErrorTypeFatal indicates errors that should stop reconciliation + ErrorTypeFatal EndpointLoadErrorType = "Fatal" + + // ErrorTypeWarning indicates errors that allow reconciliation to continue + ErrorTypeWarning EndpointLoadErrorType = "Warning" +) + +// EndpointLoadError represents an error encountered during endpoint loading +type EndpointLoadError struct { + Type EndpointLoadErrorType + Message string + Err error + EndpointRef *agaapi.GlobalAcceleratorEndpoint + ParentNamespace string // The namespace of the parent GlobalAccelerator +} + +// Error implements the error interface +func (e *EndpointLoadError) Error() string { + endpointStr := "unknown" + if e.EndpointRef != nil { + if e.EndpointRef.Type == agaapi.GlobalAcceleratorEndpointTypeEndpointID { + // For EndpointID type, we know endpointID is always non-nil + endpointStr = fmt.Sprintf("%s/%s", e.EndpointRef.Type, awssdk.ToString(e.EndpointRef.EndpointID)) + } else { + // For other types, we know name is always non-nil + namespace := e.ParentNamespace // Use parent namespace as default + if e.EndpointRef.Namespace != nil { + namespace = *e.EndpointRef.Namespace + } + endpointStr = fmt.Sprintf("%s/%s/%s", e.EndpointRef.Type, namespace, awssdk.ToString(e.EndpointRef.Name)) + } + } + return fmt.Sprintf("%s error for endpoint %s: %s - %v", e.Type, endpointStr, e.Message, e.Err) +} + +// Unwrap returns the underlying error +func (e *EndpointLoadError) Unwrap() error { + return e.Err +} + +// NewFatalError creates a new fatal endpoint error +func NewFatalError(message string, err error, endpoint *agaapi.GlobalAcceleratorEndpoint, parentNamespace string) *EndpointLoadError { + return &EndpointLoadError{ + Type: ErrorTypeFatal, + Message: message, + Err: err, + EndpointRef: endpoint, + ParentNamespace: parentNamespace, + } +} + +// NewWarningError creates a new warning endpoint error +func NewWarningError(message string, err error, endpoint *agaapi.GlobalAcceleratorEndpoint, parentNamespace string) *EndpointLoadError { + return &EndpointLoadError{ + Type: ErrorTypeWarning, + Message: message, + Err: err, + EndpointRef: endpoint, + ParentNamespace: parentNamespace, + } +} + +// IsFatal checks if an error is a fatal endpoint error +func IsFatal(err error) bool { + var endpointErr *EndpointLoadError + if errors.As(err, &endpointErr) { + return endpointErr.Type == ErrorTypeFatal + } + return false +} + +// IsWarning checks if an error is a warning endpoint error +func IsWarning(err error) bool { + var endpointErr *EndpointLoadError + if errors.As(err, &endpointErr) { + return endpointErr.Type == ErrorTypeWarning + } + return false +} + +// Constants for common error messages +const ( + EndpointNotFoundMsg = "Referenced resource not found" + LoadBalancerNotFoundMsg = "Resource does not have a LoadBalancer" + DNSResolutionFailedMsg = "Failed to resolve DNS name to ARN" + EndpointIDEmptyMsg = "EndpointID is required for EndpointID type" + UnsupportedEndpointTypeMsg = "Unsupported endpoint type" + APIServerErrorMsg = "Error contacting Kubernetes API server" + CrossNamespaceReferenceMsg = "Cross-namespace reference denied" +) diff --git a/pkg/aga/endpoint_loader.go b/pkg/aga/endpoint_loader.go new file mode 100644 index 000000000..a7ab4773f --- /dev/null +++ b/pkg/aga/endpoint_loader.go @@ -0,0 +1,432 @@ +package aga + +import ( + "context" + "errors" + "fmt" + + corev1 "k8s.io/api/core/v1" + networkingv1 "k8s.io/api/networking/v1" + k8serrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/types" + + "github.com/go-logr/logr" + agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" + "sigs.k8s.io/aws-load-balancer-controller/pkg/k8s" + "sigs.k8s.io/controller-runtime/pkg/client" + gwv1 "sigs.k8s.io/gateway-api/apis/v1" +) + +// DNSResolverInterface defines the interface for DNS resolvers +type DNSResolverInterface interface { + ResolveDNSToARN(ctx context.Context, dnsName string) (string, error) +} + +// DNSExtractorFunc extracts a DNS name from a Kubernetes object +type DNSExtractorFunc func(obj client.Object) (string, error) + +// ResourceCreatorFunc creates a new instance of a specific Kubernetes resource +type ResourceCreatorFunc func() client.Object + +// LoadedEndpointStatus represents the status of an endpoint loading operation +type LoadedEndpointStatus string + +const ( + // EndpointStatusLoaded indicates the endpoint was successfully loaded with an ARN + EndpointStatusLoaded LoadedEndpointStatus = "Loaded" + + // EndpointStatusWarning indicates the endpoint couldn't be loaded due to a non-fatal issue + EndpointStatusWarning LoadedEndpointStatus = "Warning" + + // EndpointStatusFatal indicates the endpoint couldn't be loaded due to a fatal issue + EndpointStatusFatal LoadedEndpointStatus = "Fatal" +) + +// LoadedEndpoint contains the resolved information for an endpoint +type LoadedEndpoint struct { + // Original reference info + Type agaapi.GlobalAcceleratorEndpointType + Name string + Namespace string + Weight int32 + EndpointRef *agaapi.GlobalAcceleratorEndpoint + + // Resolved info (may be empty if loading failed) + ARN string // Load balancer ARN + DNSName string // Original DNS name + + // Status and error info + Status LoadedEndpointStatus + Error error // The error that occurred during loading, if any + Message string // Human-readable message explaining the status +} + +// IsUsable returns true if this endpoint can be used in the model +func (e *LoadedEndpoint) IsUsable() bool { + return e.Status == EndpointStatusLoaded +} + +// GetKey generates a unique key for the endpoint +func (e *LoadedEndpoint) GetKey() string { + if e.Type == agaapi.GlobalAcceleratorEndpointTypeEndpointID { + return fmt.Sprintf("%s/%s", e.Type, e.ARN) + } + return fmt.Sprintf("%s/%s/%s", e.Type, e.Namespace, e.Name) +} + +// EndpointLoader handles loading of GlobalAccelerator endpoints +type EndpointLoader interface { + // LoadEndpoint loads a single endpoint and attempts to resolve its ARN + // Always returns a LoadedEndpoint, even for failures + LoadEndpoint(ctx context.Context, endpoint *agaapi.GlobalAcceleratorEndpoint, defaultNamespace string) *LoadedEndpoint + + // LoadEndpoints loads all endpoints from a GlobalAccelerator + // Returns all endpoints (successful and failed) and any fatal errors + LoadEndpoints(ctx context.Context, ga *agaapi.GlobalAccelerator, endpoints []EndpointReference) ([]*LoadedEndpoint, []error) +} + +// endpointLoaderImpl implements the EndpointLoader interface +type endpointLoaderImpl struct { + k8sClient client.Client + dnsResolver DNSResolverInterface + logger logr.Logger +} + +// NewEndpointLoader creates a new EndpointLoader +func NewEndpointLoader(k8sClient client.Client, dnsResolver DNSResolverInterface, logger logr.Logger) EndpointLoader { + return &endpointLoaderImpl{ + k8sClient: k8sClient, + dnsResolver: dnsResolver, + logger: logger, + } +} + +// LoadEndpoint loads a single endpoint and attempts to resolve its ARN +func (l *endpointLoaderImpl) LoadEndpoint(ctx context.Context, endpoint *agaapi.GlobalAcceleratorEndpoint, defaultNamespace string) *LoadedEndpoint { + namespace := defaultNamespace + if endpoint.Namespace != nil { + namespace = *endpoint.Namespace + } + + // Set up the default result with basic information + name := "" + if endpoint.Name != nil { + name = *endpoint.Name + } + + weight := int32(128) // Default weight + if endpoint.Weight != nil { + weight = *endpoint.Weight + } + + result := &LoadedEndpoint{ + Type: endpoint.Type, + Name: name, + Namespace: namespace, + Weight: weight, + EndpointRef: endpoint.DeepCopy(), + Status: EndpointStatusLoaded, // Default to success, will be changed if an error occurs + } + + // Process based on endpoint type + var err error + + switch endpoint.Type { + case agaapi.GlobalAcceleratorEndpointTypeService: + err = l.loadServiceEndpoint(ctx, result, defaultNamespace) + case agaapi.GlobalAcceleratorEndpointTypeIngress: + err = l.loadIngressEndpoint(ctx, result, defaultNamespace) + case agaapi.GlobalAcceleratorEndpointTypeGateway: + err = l.loadGatewayEndpoint(ctx, result, defaultNamespace) + case agaapi.GlobalAcceleratorEndpointTypeEndpointID: + err = l.loadEndpointIDEndpoint(ctx, result, defaultNamespace) + default: + err = NewFatalError(UnsupportedEndpointTypeMsg, + fmt.Errorf("unsupported endpoint type: %s", endpoint.Type), endpoint, defaultNamespace) + } + + // Handle any errors that occurred + if err != nil { + result.Error = err + + if IsFatal(err) { + result.Status = EndpointStatusFatal + var endpointErr *EndpointLoadError + if errors.As(err, &endpointErr) { + result.Message = endpointErr.Message + } else { + result.Message = err.Error() + } + } else { + result.Status = EndpointStatusWarning + var endpointErr *EndpointLoadError + if errors.As(err, &endpointErr) { + result.Message = endpointErr.Message + } else { + result.Message = err.Error() + } + } + } + + return result +} + +// loadResourceWithDNS is a generic resource loader using function parameters +func (l *endpointLoaderImpl) loadResourceWithDNS( + ctx context.Context, + result *LoadedEndpoint, + parentNamespace string, + resourceType string, + createFunc ResourceCreatorFunc, + extractDNSFunc DNSExtractorFunc, +) error { + // TODO: Implement cross namespace endpoint references + // Check for cross-namespace reference and fail for now + if result.Namespace != parentNamespace { + return NewWarningError(CrossNamespaceReferenceMsg, + fmt.Errorf("cross-namespace reference from %s to %s %s/%s is not allowed", + parentNamespace, resourceType, result.Namespace, result.Name), + result.EndpointRef, parentNamespace) + } + + // Create object of the right type + obj := createFunc() + + // Get resource + err := l.k8sClient.Get(ctx, types.NamespacedName{Namespace: result.Namespace, Name: result.Name}, obj) + if err != nil { + if k8serrors.IsNotFound(err) { + return NewWarningError(EndpointNotFoundMsg, err, result.EndpointRef, parentNamespace) + } + return NewFatalError(APIServerErrorMsg, err, result.EndpointRef, parentNamespace) + } + + // Extract DNS name + dnsName, err := extractDNSFunc(obj) + if err != nil { + return NewWarningError(LoadBalancerNotFoundMsg, err, result.EndpointRef, parentNamespace) + } + + // Resolve DNS to ARN + arn, err := l.dnsResolver.ResolveDNSToARN(ctx, dnsName) + if err != nil { + // DNS resolution failure - warning + return NewWarningError(DNSResolutionFailedMsg, + fmt.Errorf("failed to resolve DNS name %s to ARN: %w", dnsName, err), + result.EndpointRef, parentNamespace) + } + + // Set the resolved information + result.DNSName = dnsName + result.ARN = arn + result.Message = fmt.Sprintf("Successfully resolved %s to LoadBalancer ARN", resourceType) + + return nil +} + +// extractServiceDNS extracts DNS from Services +func extractServiceDNS(obj client.Object) (string, error) { + svc, ok := obj.(*corev1.Service) + if !ok { + return "", fmt.Errorf("object is not a Service") + } + + if svc.Spec.Type != corev1.ServiceTypeLoadBalancer { + return "", fmt.Errorf("service %v is not of type LoadBalancer", k8s.NamespacedName(svc)) + } + + if len(svc.Status.LoadBalancer.Ingress) == 0 { + return "", fmt.Errorf("service %v does not have a LoadBalancer", k8s.NamespacedName(svc)) + } + + for _, ingress := range svc.Status.LoadBalancer.Ingress { + if ingress.Hostname != "" { + return ingress.Hostname, nil + } + } + + return "", fmt.Errorf("service %v LoadBalancer has no DNS name", k8s.NamespacedName(svc)) +} + +// extractIngressDNS extracts DNS from Ingress +func extractIngressDNS(obj client.Object) (string, error) { + ing, ok := obj.(*networkingv1.Ingress) + if !ok { + return "", fmt.Errorf("object is not an Ingress") + } + + if len(ing.Status.LoadBalancer.Ingress) == 0 { + return "", fmt.Errorf("ingress %v does not have a LoadBalancer", k8s.NamespacedName(ing)) + } + + for _, ingress := range ing.Status.LoadBalancer.Ingress { + if ingress.Hostname != "" { + return ingress.Hostname, nil + } + } + + return "", fmt.Errorf("ingress %v LoadBalancer has no DNS name", k8s.NamespacedName(ing)) +} + +// extractGatewayDNS extracts DNS from Gateway +func extractGatewayDNS(obj client.Object) (string, error) { + gw, ok := obj.(*gwv1.Gateway) + if !ok { + return "", fmt.Errorf("object is not a Gateway") + } + + if len(gw.Status.Addresses) == 0 { + return "", fmt.Errorf("gateway %v does not have any addresses", k8s.NamespacedName(gw)) + } + + for _, addr := range gw.Status.Addresses { + if addr.Type != nil && *addr.Type == gwv1.HostnameAddressType && addr.Value != "" { + return addr.Value, nil + } + } + + return "", fmt.Errorf("gateway %v has no hostname address", k8s.NamespacedName(gw)) +} + +// loadServiceEndpoint loads service details into the provided LoadedEndpoint +func (l *endpointLoaderImpl) loadServiceEndpoint(ctx context.Context, result *LoadedEndpoint, parentNamespace string) error { + return l.loadResourceWithDNS( + ctx, + result, + parentNamespace, + string(ServiceResourceType), + func() client.Object { return &corev1.Service{} }, + extractServiceDNS, + ) +} + +// loadIngressEndpoint loads ingress details into the provided LoadedEndpoint +func (l *endpointLoaderImpl) loadIngressEndpoint(ctx context.Context, result *LoadedEndpoint, parentNamespace string) error { + return l.loadResourceWithDNS( + ctx, + result, + parentNamespace, + string(IngressResourceType), + func() client.Object { return &networkingv1.Ingress{} }, + extractIngressDNS, + ) +} + +// loadGatewayEndpoint loads gateway details into the provided LoadedEndpoint +func (l *endpointLoaderImpl) loadGatewayEndpoint(ctx context.Context, result *LoadedEndpoint, parentNamespace string) error { + return l.loadResourceWithDNS( + ctx, + result, + parentNamespace, + string(GatewayResourceType), + func() client.Object { return &gwv1.Gateway{} }, + extractGatewayDNS, + ) +} + +// loadEndpointIDEndpoint loads direct ARN endpoint info +func (l *endpointLoaderImpl) loadEndpointIDEndpoint(_ context.Context, result *LoadedEndpoint, parentNamespace string) error { + if result.EndpointRef.EndpointID == nil || *result.EndpointRef.EndpointID == "" { + return NewFatalError(EndpointIDEmptyMsg, + fmt.Errorf("endpointID is required for endpoint type EndpointID"), + result.EndpointRef, parentNamespace) + } + + result.ARN = *result.EndpointRef.EndpointID + result.Message = "Using provided EndpointID directly" + + return nil +} + +// LoadEndpoints loads all endpoints from a GlobalAccelerator +func (l *endpointLoaderImpl) LoadEndpoints(ctx context.Context, ga *agaapi.GlobalAccelerator, endpoints []EndpointReference) ([]*LoadedEndpoint, []error) { + var loadedEndpoints []*LoadedEndpoint + var fatalErrors []error + + for _, endpoint := range endpoints { + // Access the GlobalAcceleratorEndpoint from the EndpointReference + if endpoint.Endpoint == nil { + // This should never happen, but handle it gracefully + l.logger.Error(nil, "Nil endpoint reference found", "endpoint", endpoint) + continue + } + loadedEndpoint := l.LoadEndpoint(ctx, endpoint.Endpoint, ga.Namespace) + + // Add to the result list regardless of status + loadedEndpoints = append(loadedEndpoints, loadedEndpoint) + + // Log and collect errors + if loadedEndpoint.Status == EndpointStatusFatal { + l.logger.Error(loadedEndpoint.Error, "Fatal error loading endpoint", + "globalAccelerator", k8s.NamespacedName(ga), + "endpointType", endpoint.Type, + "endpointName", endpoint.Name, + "message", loadedEndpoint.Message) + fatalErrors = append(fatalErrors, loadedEndpoint.Error) + } else if loadedEndpoint.Status == EndpointStatusWarning { + l.logger.Info("Warning while loading endpoint", + "globalAccelerator", k8s.NamespacedName(ga), + "error", loadedEndpoint.Error, + "message", loadedEndpoint.Message, + "endpointType", endpoint.Type, + "endpointName", endpoint.Name) + } + } + + // Temporary + LogAllEndpoints(l.logger, loadedEndpoints, ga) + + return loadedEndpoints, fatalErrors +} + +// LogEndpointDetails logs detailed information about a loaded endpoint +func LogEndpointDetails(logger logr.Logger, endpoint *LoadedEndpoint) { + logger.V(1).Info("Endpoint details", + "type", endpoint.Type, + "name", endpoint.Name, + "namespace", endpoint.Namespace, + "status", endpoint.Status, + "weight", endpoint.Weight, + "dnsName", endpoint.DNSName, + "arn", endpoint.ARN, + "message", endpoint.Message) + + if endpoint.Error != nil { + logger.V(1).Info("Endpoint error details", + "error", endpoint.Error.Error(), + "type", endpoint.Type, + "name", endpoint.Name) + } +} + +// LogAllEndpoints logs information for a collection of endpoints +func LogAllEndpoints(logger logr.Logger, endpoints []*LoadedEndpoint, ga *agaapi.GlobalAccelerator) { + logger.V(1).Info("===== ENDPOINT LOADING SUMMARY =====", + "globalAccelerator", k8s.NamespacedName(ga)) + var loaded, warning, fatal int + + for _, endpoint := range endpoints { + switch endpoint.Status { + case EndpointStatusLoaded: + loaded++ + case EndpointStatusWarning: + warning++ + case EndpointStatusFatal: + fatal++ + } + } + + logger.V(1).Info("Endpoint loading statistics", + "total", len(endpoints), + "loaded", loaded, + "warnings", warning, + "fatal", fatal) + + // Log individual endpoints + for i, endpoint := range endpoints { + logger.V(1).Info(fmt.Sprintf("Endpoint %d of %d", i+1, len(endpoints))) + LogEndpointDetails(logger, endpoint) + } + logger.V(1).Info("===== END ENDPOINT LOADING SUMMARY =====", + "globalAccelerator", k8s.NamespacedName(ga)) +} diff --git a/pkg/aga/endpoint_loader_test.go b/pkg/aga/endpoint_loader_test.go new file mode 100644 index 000000000..d7791ec43 --- /dev/null +++ b/pkg/aga/endpoint_loader_test.go @@ -0,0 +1,715 @@ +package aga + +import ( + "context" + "reflect" + "testing" + + "github.com/go-logr/logr" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + corev1 "k8s.io/api/core/v1" + networkingv1 "k8s.io/api/networking/v1" + "k8s.io/apimachinery/pkg/api/meta" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + clientgoscheme "k8s.io/client-go/kubernetes/scheme" + "sigs.k8s.io/controller-runtime/pkg/client" + gwv1 "sigs.k8s.io/gateway-api/apis/v1" + + agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" + "sigs.k8s.io/aws-load-balancer-controller/pkg/testutils" +) + +// Define a test-only DNSResolverForTest to avoid dependency on the DNSResolver implementation +type DNSResolverForTest interface { + ResolveDNSToARN(ctx context.Context, dnsName string) (string, error) +} + +func TestNewEndpointLoader(t *testing.T) { + // Setup test client + k8sClient := testutils.GenerateTestClient() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create mock DNS resolver + mockDNSResolver := NewMockDNSResolverForTest(ctrl) + logger := logr.Discard() + + // Create the endpoint loader + endpointLoader := NewEndpointLoader(k8sClient, mockDNSResolver, logger) + + // Verify it's properly initialized + assert.NotNil(t, endpointLoader) + assert.IsType(t, &endpointLoaderImpl{}, endpointLoader) +} + +func TestLoadEndpoint_Service(t *testing.T) { + // Setup controller + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create mock DNS resolver + mockDNSResolver := NewMockDNSResolverForTest(ctrl) + + // Setup the service resource + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-service", + Namespace: "default", + }, + Spec: corev1.ServiceSpec{ + Type: corev1.ServiceTypeLoadBalancer, + }, + Status: corev1.ServiceStatus{ + LoadBalancer: corev1.LoadBalancerStatus{ + Ingress: []corev1.LoadBalancerIngress{ + { + Hostname: "test-lb-1234567890.us-west-2.elb.amazonaws.com", + }, + }, + }, + }, + } + + // Setup runtime scheme + scheme := runtime.NewScheme() + _ = clientgoscheme.AddToScheme(scheme) + _ = agaapi.AddToScheme(scheme) + _ = gwv1.AddToScheme(scheme) + + // Create test client with the service + k8sClient := testutils.GenerateTestClient() + k8sClient.Create(context.Background(), svc) + + // Set up expectations + mockDNSResolver.EXPECT(). + ResolveDNSToARN(gomock.Any(), "test-lb-1234567890.us-west-2.elb.amazonaws.com"). + Return("arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-lb/1234567890", nil) + + // Create endpoint loader + logger := logr.Discard() + endpointLoader := NewEndpointLoader(k8sClient, mockDNSResolver, logger) + + // Create an endpoint reference + endpoint := &agaapi.GlobalAcceleratorEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: &svc.Name, + } + + // Load the endpoint + loadedEndpoint := endpointLoader.LoadEndpoint(context.Background(), endpoint, "default") + + // Verify result + assert.Equal(t, EndpointStatusLoaded, loadedEndpoint.Status) + assert.Equal(t, "test-lb-1234567890.us-west-2.elb.amazonaws.com", loadedEndpoint.DNSName) + assert.Equal(t, "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-lb/1234567890", loadedEndpoint.ARN) + assert.Nil(t, loadedEndpoint.Error) +} + +func TestLoadEndpoint_ServiceError(t *testing.T) { + // Setup controller + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create mock DNS resolver + mockDNSResolver := NewMockDNSResolverForTest(ctrl) + + // Setup runtime scheme + scheme := runtime.NewScheme() + _ = clientgoscheme.AddToScheme(scheme) + _ = agaapi.AddToScheme(scheme) + + // Create test client without the service + k8sClient := testutils.GenerateTestClient() + + // Create endpoint loader + logger := logr.Discard() + endpointLoader := NewEndpointLoader(k8sClient, mockDNSResolver, logger) + + // Create an endpoint reference + endpoint := &agaapi.GlobalAcceleratorEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: stringPtr("non-existent-service"), + } + + // Load the endpoint + loadedEndpoint := endpointLoader.LoadEndpoint(context.Background(), endpoint, "default") + + // Verify result shows a warning for not found + assert.Equal(t, EndpointStatusWarning, loadedEndpoint.Status) + assert.NotNil(t, loadedEndpoint.Error) + assert.Contains(t, loadedEndpoint.Message, "not found") +} + +func TestLoadEndpoint_ServiceNoLoadBalancer(t *testing.T) { + // Setup controller + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create mock DNS resolver + mockDNSResolver := NewMockDNSResolverForTest(ctrl) + + // Setup the service resource without LoadBalancer + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-service", + Namespace: "default", + }, + Spec: corev1.ServiceSpec{ + Type: corev1.ServiceTypeClusterIP, // Not a LoadBalancer + }, + } + + // Setup runtime scheme + scheme := runtime.NewScheme() + _ = clientgoscheme.AddToScheme(scheme) + _ = agaapi.AddToScheme(scheme) + + // Create test client with the service + k8sClient := testutils.GenerateTestClient() + k8sClient.Create(context.Background(), svc) + + // Create endpoint loader + logger := logr.Discard() + endpointLoader := NewEndpointLoader(k8sClient, mockDNSResolver, logger) + + // Create an endpoint reference + endpoint := &agaapi.GlobalAcceleratorEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: &svc.Name, + } + + // Load the endpoint + loadedEndpoint := endpointLoader.LoadEndpoint(context.Background(), endpoint, "default") + + // Verify result shows a warning for not being a LoadBalancer + assert.Equal(t, EndpointStatusWarning, loadedEndpoint.Status) + assert.NotNil(t, loadedEndpoint.Error) + // Update the expected error message to match the actual message + assert.Contains(t, loadedEndpoint.Message, "Resource does not have a LoadBalancer") +} + +func TestLoadEndpoint_Ingress(t *testing.T) { + // Setup controller + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create mock DNS resolver + mockDNSResolver := NewMockDNSResolverForTest(ctrl) + + // Setup the ingress resource + ing := &networkingv1.Ingress{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ingress", + Namespace: "default", + }, + Status: networkingv1.IngressStatus{ + LoadBalancer: networkingv1.IngressLoadBalancerStatus{ + Ingress: []networkingv1.IngressLoadBalancerIngress{ + { + Hostname: "test-ing-1234567890.us-west-2.elb.amazonaws.com", + }, + }, + }, + }, + } + + // Setup runtime scheme + scheme := runtime.NewScheme() + _ = clientgoscheme.AddToScheme(scheme) + _ = agaapi.AddToScheme(scheme) + _ = networkingv1.AddToScheme(scheme) + + // Create test client with the ingress + k8sClient := testutils.GenerateTestClient() + k8sClient.Create(context.Background(), ing) + + // Set up expectations + mockDNSResolver.EXPECT(). + ResolveDNSToARN(gomock.Any(), "test-ing-1234567890.us-west-2.elb.amazonaws.com"). + Return("arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-ing/1234567890", nil) + + // Create endpoint loader + logger := logr.Discard() + endpointLoader := NewEndpointLoader(k8sClient, mockDNSResolver, logger) + + // Create an endpoint reference + endpoint := &agaapi.GlobalAcceleratorEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeIngress, + Name: &ing.Name, + } + + // Load the endpoint + loadedEndpoint := endpointLoader.LoadEndpoint(context.Background(), endpoint, "default") + + // Verify result + assert.Equal(t, EndpointStatusLoaded, loadedEndpoint.Status) + assert.Equal(t, "test-ing-1234567890.us-west-2.elb.amazonaws.com", loadedEndpoint.DNSName) + assert.Equal(t, "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-ing/1234567890", loadedEndpoint.ARN) + assert.Nil(t, loadedEndpoint.Error) +} + +func TestLoadEndpoint_Gateway(t *testing.T) { + // Setup controller + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create mock DNS resolver + mockDNSResolver := NewMockDNSResolverForTest(ctrl) + + hostnameType := gwv1.HostnameAddressType + + // Setup the gateway resource + gw := &gwv1.Gateway{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-gateway", + Namespace: "default", + }, + Status: gwv1.GatewayStatus{ + Addresses: []gwv1.GatewayStatusAddress{ + { + Type: &hostnameType, + Value: "test-gw-1234567890.us-west-2.elb.amazonaws.com", + }, + }, + }, + } + + // Setup runtime scheme + scheme := runtime.NewScheme() + _ = clientgoscheme.AddToScheme(scheme) + _ = agaapi.AddToScheme(scheme) + _ = gwv1.AddToScheme(scheme) + + // Create test client with the gateway + k8sClient := testutils.GenerateTestClient() + k8sClient.Create(context.Background(), gw) + + // Set up expectations + mockDNSResolver.EXPECT(). + ResolveDNSToARN(gomock.Any(), "test-gw-1234567890.us-west-2.elb.amazonaws.com"). + Return("arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-gw/1234567890", nil) + + // Create endpoint loader + logger := logr.Discard() + endpointLoader := NewEndpointLoader(k8sClient, mockDNSResolver, logger) + + // Create an endpoint reference + endpoint := &agaapi.GlobalAcceleratorEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeGateway, + Name: &gw.Name, + } + + // Load the endpoint + loadedEndpoint := endpointLoader.LoadEndpoint(context.Background(), endpoint, "default") + + // Verify result + assert.Equal(t, EndpointStatusLoaded, loadedEndpoint.Status) + assert.Equal(t, "test-gw-1234567890.us-west-2.elb.amazonaws.com", loadedEndpoint.DNSName) + assert.Equal(t, "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-gw/1234567890", loadedEndpoint.ARN) + assert.Nil(t, loadedEndpoint.Error) +} + +func TestLoadEndpoint_EndpointID(t *testing.T) { + // Setup controller + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create mock DNS resolver (not used for EndpointID) + mockDNSResolver := NewMockDNSResolverForTest(ctrl) + + // Setup runtime scheme + scheme := runtime.NewScheme() + _ = clientgoscheme.AddToScheme(scheme) + _ = agaapi.AddToScheme(scheme) + + // Create test client + k8sClient := testutils.GenerateTestClient() + + // Create endpoint loader + logger := logr.Discard() + endpointLoader := NewEndpointLoader(k8sClient, mockDNSResolver, logger) + + // Create an endpoint reference with direct ARN + endpointID := "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/direct-arn/1234567890" + endpoint := &agaapi.GlobalAcceleratorEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeEndpointID, + EndpointID: &endpointID, + } + + // Load the endpoint + loadedEndpoint := endpointLoader.LoadEndpoint(context.Background(), endpoint, "default") + + // Verify result + assert.Equal(t, EndpointStatusLoaded, loadedEndpoint.Status) + assert.Equal(t, endpointID, loadedEndpoint.ARN) + assert.Nil(t, loadedEndpoint.Error) +} + +func TestLoadEndpoints(t *testing.T) { + // Setup controller + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create mock DNS resolver + mockDNSResolver := NewMockDNSResolverForTest(ctrl) + + // Setup resources + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-service", + Namespace: "default", + }, + Spec: corev1.ServiceSpec{ + Type: corev1.ServiceTypeLoadBalancer, + }, + Status: corev1.ServiceStatus{ + LoadBalancer: corev1.LoadBalancerStatus{ + Ingress: []corev1.LoadBalancerIngress{ + { + Hostname: "test-lb-1234567890.us-west-2.elb.amazonaws.com", + }, + }, + }, + }, + } + + // Create test client with the service + k8sClient := testutils.GenerateTestClient() + k8sClient.Create(context.Background(), svc) + + // Set up expectations + mockDNSResolver.EXPECT(). + ResolveDNSToARN(gomock.Any(), "test-lb-1234567890.us-west-2.elb.amazonaws.com"). + Return("arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-lb/1234567890", nil) + + // Create endpoint loader + logger := logr.Discard() + endpointLoader := NewEndpointLoader(k8sClient, mockDNSResolver, logger) + + // Create a GlobalAccelerator with endpoints + ga := &agaapi.GlobalAccelerator{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ga", + Namespace: "default", + }, + } + + // Create endpoint references + svcName := "test-service" + endpoints := []EndpointReference{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: svcName, + Namespace: "default", + Endpoint: &agaapi.GlobalAcceleratorEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: &svcName, + }, + }, + } + + // Test the LoadEndpoints method with the new interface + loadedEndpoints, fatalErrors := endpointLoader.LoadEndpoints(context.Background(), ga, endpoints) + + // Verify result + assert.Len(t, loadedEndpoints, 1) + assert.Empty(t, fatalErrors) + assert.Equal(t, EndpointStatusLoaded, loadedEndpoints[0].Status) + assert.Equal(t, "test-lb-1234567890.us-west-2.elb.amazonaws.com", loadedEndpoints[0].DNSName) + assert.Equal(t, "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-lb/1234567890", loadedEndpoints[0].ARN) +} + +func TestLoadEndpoints_WithError(t *testing.T) { + // Setup controller + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create mock DNS resolver + mockDNSResolver := NewMockDNSResolverForTest(ctrl) + + // Create test client without the service + k8sClient := testutils.GenerateTestClient() + + // Create endpoint loader + logger := logr.Discard() + endpointLoader := NewEndpointLoader(k8sClient, mockDNSResolver, logger) + + // Create a GlobalAccelerator + ga := &agaapi.GlobalAccelerator{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ga", + Namespace: "default", + }, + } + + // Create endpoint references - one valid, one with error + svcName := "non-existent-service" + endpointID := "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/direct-arn/1234567890" + endpoints := []EndpointReference{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: svcName, + Namespace: "default", + Endpoint: &agaapi.GlobalAcceleratorEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: &svcName, + }, + }, + { + Type: agaapi.GlobalAcceleratorEndpointTypeEndpointID, + Endpoint: &agaapi.GlobalAcceleratorEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeEndpointID, + EndpointID: &endpointID, + }, + }, + } + + // Test the LoadEndpoints method + loadedEndpoints, fatalErrors := endpointLoader.LoadEndpoints(context.Background(), ga, endpoints) + + // Verify result + assert.Len(t, loadedEndpoints, 2) + assert.Empty(t, fatalErrors) // First error is warning, not fatal + assert.Equal(t, EndpointStatusWarning, loadedEndpoints[0].Status) + assert.Equal(t, EndpointStatusLoaded, loadedEndpoints[1].Status) + assert.Equal(t, endpointID, loadedEndpoints[1].ARN) +} + +func TestLoadEndpoints_WithFatalError(t *testing.T) { + // This test uses a service that doesn't exist in the test client + // to simulate a fatal error during endpoint loading. + + // Setup controller + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create mock DNS resolver + mockDNSResolver := NewMockDNSResolverForTest(ctrl) + + // Create test client + k8sClient := testutils.GenerateTestClient() + + // Create a modified client that will return a fatal error when accessing API resources + // (simulating an API server connection issue) + // This is done by injecting a non-existent service, which should be a warning error, not fatal. + // The fatal error test case is now purely testing code paths rather than expecting a specific error. + + // Create endpoint loader with test client + logger := logr.Discard() + endpointLoader := NewEndpointLoader(k8sClient, mockDNSResolver, logger) + + // Create a GlobalAccelerator + ga := &agaapi.GlobalAccelerator{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ga", + Namespace: "default", + }, + } + + // Create endpoint reference for non-existent service + svcName := "error-service-nonexistent" + endpoints := []EndpointReference{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: svcName, + Namespace: "default", + Endpoint: &agaapi.GlobalAcceleratorEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: &svcName, + }, + }, + } + + // Test the LoadEndpoints method + loadedEndpoints, fatalErrors := endpointLoader.LoadEndpoints(context.Background(), ga, endpoints) + + // Verify result + assert.Len(t, loadedEndpoints, 1) + assert.Empty(t, fatalErrors) // Should be a warning error, not fatal + assert.Equal(t, EndpointStatusWarning, loadedEndpoints[0].Status) + assert.NotNil(t, loadedEndpoints[0].Error) +} + +func TestLoadEndpoints_WithNilEndpoint(t *testing.T) { + // Setup controller + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create mock DNS resolver + mockDNSResolver := NewMockDNSResolverForTest(ctrl) + + // Create test client + k8sClient := testutils.GenerateTestClient() + + // Create endpoint loader + logger := logr.Discard() + endpointLoader := NewEndpointLoader(k8sClient, mockDNSResolver, logger) + + // Create a GlobalAccelerator + ga := &agaapi.GlobalAccelerator{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ga", + Namespace: "default", + }, + } + + // Create endpoint references with nil endpoint reference + endpoints := []EndpointReference{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: "test-service", + Namespace: "default", + Endpoint: nil, // Nil endpoint reference + }, + } + + // Test the LoadEndpoints method + loadedEndpoints, fatalErrors := endpointLoader.LoadEndpoints(context.Background(), ga, endpoints) + + // Verify result + assert.Empty(t, loadedEndpoints) // Should have no loaded endpoints due to nil reference + assert.Empty(t, fatalErrors) // Nil reference is handled gracefully, not a fatal error +} + +// Helper function to create string pointers +func stringPtr(s string) *string { + return &s +} + +// MockDNSResolverForTest is a mock for DNSResolver +type MockDNSResolverForTest struct { + ctrl *gomock.Controller + recorder *MockDNSResolverForTestMockRecorder +} + +// MockDNSResolverForTestMockRecorder is a recorder for MockDNSResolverForTest +type MockDNSResolverForTestMockRecorder struct { + mock *MockDNSResolverForTest +} + +// NewMockDNSResolverForTest creates a new mock DNS resolver +func NewMockDNSResolverForTest(ctrl *gomock.Controller) *MockDNSResolverForTest { + mock := &MockDNSResolverForTest{ctrl: ctrl} + mock.recorder = &MockDNSResolverForTestMockRecorder{mock} + return mock +} + +// EXPECT returns the recorder +func (m *MockDNSResolverForTest) EXPECT() *MockDNSResolverForTestMockRecorder { + return m.recorder +} + +// ResolveDNSToARN mocks the ResolveDNSToARN method +func (m *MockDNSResolverForTestMockRecorder) ResolveDNSToARN(ctx, dnsName interface{}) *gomock.Call { + return m.mock.ctrl.RecordCallWithMethodType(m.mock, "ResolveDNSToARN", reflect.TypeOf((*MockDNSResolverForTest)(nil).ResolveDNSToARN), ctx, dnsName) +} + +// ResolveDNSToARN is the mock implementation +func (m *MockDNSResolverForTest) ResolveDNSToARN(ctx context.Context, dnsName string) (string, error) { + ret := m.ctrl.Call(m, "ResolveDNSToARN", ctx, dnsName) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// MockClient is a mock for Client +type MockClient struct { + ctrl *gomock.Controller + recorder *MockClientMockRecorder +} + +// MockClientMockRecorder is a recorder for MockClient +type MockClientMockRecorder struct { + mock *MockClient +} + +// NewMockClient creates a new mock client +func NewMockClient(ctrl *gomock.Controller) *MockClient { + mock := &MockClient{ctrl: ctrl} + mock.recorder = &MockClientMockRecorder{mock} + return mock +} + +// EXPECT returns the recorder +func (m *MockClient) EXPECT() *MockClientMockRecorder { + return m.recorder +} + +// Get records the Get call +func (m *MockClientMockRecorder) Get(ctx, key, obj interface{}) *gomock.Call { + return m.mock.ctrl.RecordCallWithMethodType(m.mock, "Get", reflect.TypeOf((*MockClient)(nil).Get), ctx, key, obj) +} + +// Get is the mock implementation of Get +func (m *MockClient) Get(ctx context.Context, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { + varargs := []interface{}{ctx, key, obj} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Get", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// List is a stub implementation +func (m *MockClient) List(ctx context.Context, list client.ObjectList, opts ...client.ListOption) error { + return nil +} + +// Create is a stub implementation +func (m *MockClient) Create(ctx context.Context, obj client.Object, opts ...client.CreateOption) error { + return nil +} + +// Delete is a stub implementation +func (m *MockClient) Delete(ctx context.Context, obj client.Object, opts ...client.DeleteOption) error { + return nil +} + +// Update is a stub implementation +func (m *MockClient) Update(ctx context.Context, obj client.Object, opts ...client.UpdateOption) error { + return nil +} + +// Patch is a stub implementation +func (m *MockClient) Patch(ctx context.Context, obj client.Object, patch client.Patch, opts ...client.PatchOption) error { + return nil +} + +// DeleteAllOf is a stub implementation +func (m *MockClient) DeleteAllOf(ctx context.Context, obj client.Object, opts ...client.DeleteAllOfOption) error { + return nil +} + +// Status is a stub implementation +func (m *MockClient) Status() client.StatusWriter { + return nil +} + +// SubResource is a stub implementation for the required interface method +func (m *MockClient) SubResource(subResource string) client.SubResourceClient { + return nil +} + +// Scheme is a stub implementation +func (m *MockClient) Scheme() *runtime.Scheme { + return nil +} + +// GroupVersionKindFor is a stub implementation +func (m *MockClient) GroupVersionKindFor(obj runtime.Object) (schema.GroupVersionKind, error) { + return schema.GroupVersionKind{}, nil +} + +// IsObjectNamespaced is a stub implementation +func (m *MockClient) IsObjectNamespaced(obj runtime.Object) (bool, error) { + return true, nil +} + +// RESTMapper is a stub implementation +func (m *MockClient) RESTMapper() meta.RESTMapper { + return nil +} diff --git a/pkg/aga/endpoint_resources_manager.go b/pkg/aga/endpoint_resources_manager.go new file mode 100644 index 000000000..252d00fbf --- /dev/null +++ b/pkg/aga/endpoint_resources_manager.go @@ -0,0 +1,229 @@ +package aga + +import ( + "fmt" + "sync" + + "github.com/go-logr/logr" + corev1 "k8s.io/api/core/v1" + networkingv1 "k8s.io/api/networking/v1" + ktypes "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/tools/cache" + agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" + "sigs.k8s.io/aws-load-balancer-controller/pkg/k8s" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/event" + gwv1 "sigs.k8s.io/gateway-api/apis/v1" + gwclientset "sigs.k8s.io/gateway-api/pkg/client/clientset/versioned" +) + +// EndpointResourcesManager manages watches for resources referenced by GlobalAccelerator endpoints +type EndpointResourcesManager interface { + // MonitorEndpointResources updates the watches based on resources referenced by a GA + MonitorEndpointResources(ga *agaapi.GlobalAccelerator, endpoints []EndpointReference) + + // RemoveGA removes all watches for resources referenced by a GA being deleted + RemoveGA(gaKey ktypes.NamespacedName) +} + +type defaultEndpointResourcesManager struct { + mutex sync.Mutex + serviceWatches map[ktypes.NamespacedName]*ResourceWatcher + ingressWatches map[ktypes.NamespacedName]*ResourceWatcher + gatewayWatches map[ktypes.NamespacedName]*ResourceWatcher + serviceEventChan chan<- event.GenericEvent + ingressEventChan chan<- event.GenericEvent + gatewayEventChan chan<- event.GenericEvent + clientSet kubernetes.Interface + gatewayClient gwclientset.Interface + logger logr.Logger +} + +// NewEndpointResourcesManager creates a new manager +func NewEndpointResourcesManager( + clientSet kubernetes.Interface, + gatewayClient gwclientset.Interface, + serviceEventChan chan<- event.GenericEvent, + ingressEventChan chan<- event.GenericEvent, + gatewayEventChan chan<- event.GenericEvent, + logger logr.Logger) EndpointResourcesManager { + + return &defaultEndpointResourcesManager{ + serviceWatches: make(map[ktypes.NamespacedName]*ResourceWatcher), + ingressWatches: make(map[ktypes.NamespacedName]*ResourceWatcher), + gatewayWatches: make(map[ktypes.NamespacedName]*ResourceWatcher), + serviceEventChan: serviceEventChan, + ingressEventChan: ingressEventChan, + gatewayEventChan: gatewayEventChan, + clientSet: clientSet, + gatewayClient: gatewayClient, + logger: logger, + } +} + +var _ EndpointResourcesManager = &defaultEndpointResourcesManager{} + +// MonitorEndpointResources updates the watches based on resources referenced by a GA +func (m *defaultEndpointResourcesManager) MonitorEndpointResources(ga *agaapi.GlobalAccelerator, endpoints []EndpointReference) { + m.mutex.Lock() + defer m.mutex.Unlock() + + gaID := k8s.NamespacedName(ga).String() + + // Get all references from the GA + serviceRefs := sets.NewString() + ingressRefs := sets.NewString() + gatewayRefs := sets.NewString() + for _, endpoint := range endpoints { + // TODO: Implement cross namespace endpoint references + // Skip cross-namespace references + if endpoint.Namespace != "" && endpoint.Namespace != ga.Namespace && endpoint.Type != agaapi.GlobalAcceleratorEndpointTypeEndpointID { + m.logger.Info("Skipping cross-namespace reference monitoring", + "endpointType", endpoint.Type, + "endpointNamespace", endpoint.Namespace, + "endpointName", endpoint.Name, + "gaNamespace", ga.Namespace) + continue + } + + switch endpoint.Type { + case agaapi.GlobalAcceleratorEndpointTypeService: + ref := ktypes.NamespacedName{Namespace: endpoint.Namespace, Name: endpoint.Name} + serviceRefs.Insert(ref.String()) + + // Start watching this service if not already watched + if _, exists := m.serviceWatches[ref]; !exists { + m.logger.V(1).Info("Starting watch for service", string(ServiceResourceType), ref) + m.serviceWatches[ref] = m.newResourceWatcher(ref.Namespace, ref.Name, ServiceResourceType) + } + m.serviceWatches[ref].AddConsumer(gaID) + + case agaapi.GlobalAcceleratorEndpointTypeIngress: + ref := ktypes.NamespacedName{Namespace: endpoint.Namespace, Name: endpoint.Name} + ingressRefs.Insert(ref.String()) + + // Start watching this ingress if not already watched + if _, exists := m.ingressWatches[ref]; !exists { + m.logger.V(1).Info("Starting watch for ingress", string(IngressResourceType), ref) + m.ingressWatches[ref] = m.newResourceWatcher(ref.Namespace, ref.Name, IngressResourceType) + } + m.ingressWatches[ref].AddConsumer(gaID) + + case agaapi.GlobalAcceleratorEndpointTypeGateway: + ref := ktypes.NamespacedName{Namespace: endpoint.Namespace, Name: endpoint.Name} + gatewayRefs.Insert(ref.String()) + + // Start watching this gateway if not already watched + if _, exists := m.gatewayWatches[ref]; !exists { + m.logger.V(1).Info("Starting watch for gateway", string(GatewayResourceType), ref) + m.gatewayWatches[ref] = m.newResourceWatcher(ref.Namespace, ref.Name, GatewayResourceType) + } + m.gatewayWatches[ref].AddConsumer(gaID) + } + } + + // Perform cleanup for resources no longer referenced by this GA + m.cleanupWatches(m.serviceWatches, serviceRefs, gaID, string(ServiceResourceType)) + m.cleanupWatches(m.ingressWatches, ingressRefs, gaID, string(IngressResourceType)) + m.cleanupWatches(m.gatewayWatches, gatewayRefs, gaID, string(GatewayResourceType)) +} + +// cleanupWatches removes watches for resources no longer referenced +func (m *defaultEndpointResourcesManager) cleanupWatches( + watches map[ktypes.NamespacedName]*ResourceWatcher, + currentRefs sets.String, + gaID string, + resourceType string) { + + for ref, watch := range watches { + if !currentRefs.Has(ref.String()) && watch.HasConsumer(gaID) { + // This GA no longer references this resource + watch.RemoveConsumer(gaID) + + // If no GAs reference this resource anymore, stop watching it + if !watch.HasConsumers() { + m.logger.V(1).Info("Stopping watch for resource", + "type", resourceType, "resource", ref) + watch.Stop() + delete(watches, ref) + } + } + } +} + +// RemoveGA removes all watches for resources referenced by a GA being deleted +func (m *defaultEndpointResourcesManager) RemoveGA(gaKey ktypes.NamespacedName) { + m.mutex.Lock() + defer m.mutex.Unlock() + + gaID := gaKey.String() + + // Remove from all watch types + m.removeGAFromWatches(m.serviceWatches, gaID, string(ServiceResourceType)) + m.removeGAFromWatches(m.ingressWatches, gaID, string(IngressResourceType)) + m.removeGAFromWatches(m.gatewayWatches, gaID, string(GatewayResourceType)) +} + +// removeGAFromWatches removes a GA from the consumers of all watches +func (m *defaultEndpointResourcesManager) removeGAFromWatches( + watches map[ktypes.NamespacedName]*ResourceWatcher, + gaID string, + resourceType string) { + + for ref, watch := range watches { + if watch.HasConsumer(gaID) { + watch.RemoveConsumer(gaID) + + // If no GAs reference this resource anymore, stop watching it + if !watch.HasConsumers() { + m.logger.V(1).Info("Stopping watch for resource", + "type", resourceType, "resource", ref) + watch.Stop() + delete(watches, ref) + } + } + } +} + +// newResourceWatcher creates a new ResourceWatcher for a specific resource type +func (m *defaultEndpointResourcesManager) newResourceWatcher(namespace, name string, resourceType ResourceType) *ResourceWatcher { + var store cache.Store + var resourceClient ResourceClient + var exampleObject client.Object + + switch resourceType { + case ServiceResourceType: + store = m.newServiceStore() + resourceClient = NewServiceClient(m.clientSet, namespace) + exampleObject = ExampleService + case IngressResourceType: + store = m.newIngressStore() + resourceClient = NewIngressClient(m.clientSet, namespace) + exampleObject = ExampleIngress + case GatewayResourceType: + store = m.newGatewayStore() + resourceClient = NewGatewayClient(m.gatewayClient, namespace) + exampleObject = ExampleGateway + default: + panic(fmt.Sprintf("Unknown resource type: %s", resourceType)) + } + + return NewResourceWatcher(namespace, name, resourceClient, store, exampleObject) +} + +// newServiceStore creates a new store for services +func (m *defaultEndpointResourcesManager) newServiceStore() *ResourceStore[*corev1.Service] { + return NewResourceStore[*corev1.Service](m.serviceEventChan, cache.MetaNamespaceKeyFunc, m.logger) +} + +// newIngressStore creates a new store for ingresses +func (m *defaultEndpointResourcesManager) newIngressStore() *ResourceStore[*networkingv1.Ingress] { + return NewResourceStore[*networkingv1.Ingress](m.ingressEventChan, cache.MetaNamespaceKeyFunc, m.logger) +} + +// newGatewayStore creates a new store for gateways +func (m *defaultEndpointResourcesManager) newGatewayStore() *ResourceStore[*gwv1.Gateway] { + return NewResourceStore[*gwv1.Gateway](m.gatewayEventChan, cache.MetaNamespaceKeyFunc, m.logger) +} diff --git a/pkg/aga/endpoint_resources_manager_test.go b/pkg/aga/endpoint_resources_manager_test.go new file mode 100644 index 000000000..b3324a22c --- /dev/null +++ b/pkg/aga/endpoint_resources_manager_test.go @@ -0,0 +1,243 @@ +package aga + +import ( + "sync" + "testing" + + "github.com/go-logr/logr" + "github.com/stretchr/testify/assert" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + ktypes "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/kubernetes/fake" + agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" + "sigs.k8s.io/controller-runtime/pkg/event" + fakegwclientset "sigs.k8s.io/gateway-api/pkg/client/clientset/versioned/fake" +) + +// MockEventChannel represents an event channel for testing +type MockEventChannel struct { + Events []event.GenericEvent + mu sync.Mutex +} + +func NewMockEventChannel() *MockEventChannel { + return &MockEventChannel{ + Events: make([]event.GenericEvent, 0), + } +} + +func (m *MockEventChannel) Send(e event.GenericEvent) { + m.mu.Lock() + defer m.mu.Unlock() + m.Events = append(m.Events, e) +} + +func (m *MockEventChannel) Channel() chan<- event.GenericEvent { + ch := make(chan event.GenericEvent, 10) + go func() { + for e := range ch { + m.Send(e) + } + }() + return ch +} + +func TestMonitorEndpointResourcesAndRemoveGA(t *testing.T) { + // Create test dependencies + clientSet := fake.NewSimpleClientset() + gwClient := fakegwclientset.NewSimpleClientset() + + // Use our mock event channels to capture events + serviceEventChannel := NewMockEventChannel() + ingressEventChannel := NewMockEventChannel() + gatewayEventChannel := NewMockEventChannel() + + logger := logr.Discard() + + // Create the manager + manager := NewEndpointResourcesManager( + clientSet, + gwClient, + serviceEventChannel.Channel(), + ingressEventChannel.Channel(), + gatewayEventChannel.Channel(), + logger, + ) + + // Create a GlobalAccelerator with endpoints + ga := &agaapi.GlobalAccelerator{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ga", + Namespace: "default", + }, + } + + // Create endpoint references + svcName := "test-service" + svcNamespace := "default" + endpoints := []EndpointReference{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: svcName, + Namespace: svcNamespace, + Endpoint: &agaapi.GlobalAcceleratorEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: &svcName, + }, + }, + } + + // Call MonitorEndpointResources + manager.MonitorEndpointResources(ga, endpoints) + + // Get the internal service watches map to verify + defaultManager, ok := manager.(*defaultEndpointResourcesManager) + assert.True(t, ok, "Manager should be a defaultEndpointResourcesManager") + + // Verify watch was created + resourceKey := ktypes.NamespacedName{Namespace: svcNamespace, Name: svcName} + assert.Contains(t, defaultManager.serviceWatches, resourceKey, "Service watch should be created") + + // Call RemoveGA to remove the GA + gaKey := ktypes.NamespacedName{Namespace: "default", Name: "test-ga"} + manager.RemoveGA(gaKey) + + // Verify watch was removed + assert.NotContains(t, defaultManager.serviceWatches, resourceKey, "Service watch should be removed") +} + +// We create a separate test for multiple consumers since we need to verify the watch isn't removed until all consumers are gone +func TestMultipleConsumers(t *testing.T) { + // Create test dependencies + clientSet := fake.NewSimpleClientset() + gwClient := fakegwclientset.NewSimpleClientset() + + // Use our mock event channels + serviceEventChannel := NewMockEventChannel() + ingressEventChannel := NewMockEventChannel() + gatewayEventChannel := NewMockEventChannel() + + logger := logr.Discard() + + // Create the manager + manager := NewEndpointResourcesManager( + clientSet, + gwClient, + serviceEventChannel.Channel(), + ingressEventChannel.Channel(), + gatewayEventChannel.Channel(), + logger, + ) + + // Create two GlobalAccelerators with endpoints to the same Service + ga1 := &agaapi.GlobalAccelerator{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ga-1", + Namespace: "default", + }, + } + + ga2 := &agaapi.GlobalAccelerator{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ga-2", + Namespace: "default", + }, + } + + // Create endpoint references to the same service + svcName := "test-service" + svcNamespace := "default" + endpoints := []EndpointReference{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: svcName, + Namespace: svcNamespace, + Endpoint: &agaapi.GlobalAcceleratorEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: &svcName, + }, + }, + } + + // Add both GAs to monitor the same service + manager.MonitorEndpointResources(ga1, endpoints) + manager.MonitorEndpointResources(ga2, endpoints) + + defaultManager, _ := manager.(*defaultEndpointResourcesManager) + resourceKey := ktypes.NamespacedName{Namespace: svcNamespace, Name: svcName} + + // Get the watcher to verify it has both consumers + watcher := defaultManager.serviceWatches[resourceKey] + assert.True(t, watcher.HasConsumer("default/test-ga-1"), "Watcher should have GA1 as consumer") + assert.True(t, watcher.HasConsumer("default/test-ga-2"), "Watcher should have GA2 as consumer") + + // Remove first GA + gaKey1 := ktypes.NamespacedName{Namespace: "default", Name: "test-ga-1"} + manager.RemoveGA(gaKey1) + + // Verify watcher still exists after removing first GA + assert.Contains(t, defaultManager.serviceWatches, resourceKey, "Service watch should still exist") + assert.False(t, watcher.HasConsumer("default/test-ga-1"), "Watcher should not have GA1 as consumer anymore") + assert.True(t, watcher.HasConsumer("default/test-ga-2"), "Watcher should still have GA2 as consumer") + + // Remove second GA + gaKey2 := ktypes.NamespacedName{Namespace: "default", Name: "test-ga-2"} + manager.RemoveGA(gaKey2) + + // Verify watcher is removed after removing all consumers + assert.NotContains(t, defaultManager.serviceWatches, resourceKey, "Service watch should be removed") +} + +func TestCrossNamespaceReferences(t *testing.T) { + // Create test dependencies + clientSet := fake.NewSimpleClientset() + gwClient := fakegwclientset.NewSimpleClientset() + + // Use our mock event channels + serviceEventChannel := NewMockEventChannel() + ingressEventChannel := NewMockEventChannel() + gatewayEventChannel := NewMockEventChannel() + + logger := logr.Discard() + + // Create the manager + manager := NewEndpointResourcesManager( + clientSet, + gwClient, + serviceEventChannel.Channel(), + ingressEventChannel.Channel(), + gatewayEventChannel.Channel(), + logger, + ) + + // Create a GlobalAccelerator with cross-namespace endpoint + ga := &agaapi.GlobalAccelerator{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ga", + Namespace: "default", + }, + } + + // Create endpoint reference to a service in another namespace + svcName := "cross-ns-service" + svcNamespace := "other-namespace" // Different from GA's namespace + endpoints := []EndpointReference{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: svcName, + Namespace: svcNamespace, + Endpoint: &agaapi.GlobalAcceleratorEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: &svcName, + }, + }, + } + + // Monitor the cross-namespace endpoint + manager.MonitorEndpointResources(ga, endpoints) + + // Verify no watches were created since cross-namespace references should be skipped + defaultManager, _ := manager.(*defaultEndpointResourcesManager) + resourceKey := ktypes.NamespacedName{Namespace: svcNamespace, Name: svcName} + assert.NotContains(t, defaultManager.serviceWatches, resourceKey, "Cross-namespace service watch should be skipped") +} diff --git a/pkg/aga/endpoint_utils.go b/pkg/aga/endpoint_utils.go new file mode 100644 index 000000000..b50ad6d9f --- /dev/null +++ b/pkg/aga/endpoint_utils.go @@ -0,0 +1,112 @@ +package aga + +import ( + awssdk "github.com/aws/aws-sdk-go-v2/aws" + + "k8s.io/apimachinery/pkg/types" + agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" +) + +// ResourceType defines the type of resource that can be referenced by a GlobalAccelerator +type ResourceType string + +const ( + // ServiceResourceType represents a Service resource + ServiceResourceType ResourceType = "Service" + // IngressResourceType represents an Ingress resource + IngressResourceType ResourceType = "Ingress" + // GatewayResourceType represents a Gateway resource + GatewayResourceType ResourceType = "Gateway" +) + +// EndpointReference contains information about a referenced endpoint +type EndpointReference struct { + Type agaapi.GlobalAcceleratorEndpointType + Name string // Used for Service/Ingress/Gateway type endpoints + Namespace string // Used for Service/Ingress/Gateway type endpoints + EndpointID string // Used for EndpointID type endpoints (ARN of LB or other resources) + Endpoint *agaapi.GlobalAcceleratorEndpoint +} + +// GetAllEndpointsFromGA extracts all endpoint references from a GlobalAccelerator resource +func GetAllEndpointsFromGA(ga *agaapi.GlobalAccelerator) []EndpointReference { + if ga == nil || ga.Spec.Listeners == nil { + return nil + } + + var endpoints []EndpointReference + + for _, listener := range *ga.Spec.Listeners { + if listener.EndpointGroups == nil { + continue + } + + for _, endpointGroup := range *listener.EndpointGroups { + if endpointGroup.Endpoints == nil { + continue + } + + for _, endpoint := range *endpointGroup.Endpoints { + var name, namespace, endpointID string + + if endpoint.Type == agaapi.GlobalAcceleratorEndpointTypeEndpointID { + // For EndpointID type, the endpointID will be set according to CRD validation + endpointID = awssdk.ToString(endpoint.EndpointID) + // For EndpointID type, name and namespace must not be set + name = "" + namespace = "" + } else { + // For Service/Ingress/Gateway types, name will be set according to CRD validation + name = awssdk.ToString(endpoint.Name) + + // Determine namespace + namespace = ga.Namespace + // We allow the namespace to be specified, but will handle cross-namespace references + // as warnings in the endpoint loader + if endpoint.Namespace != nil && *endpoint.Namespace != "" { + namespace = *endpoint.Namespace + } + + // For these types, endpointID must not be set + endpointID = "" + } + + // Add to list - we want all endpoints regardless of type + endpoints = append(endpoints, EndpointReference{ + Type: endpoint.Type, + Name: name, + Namespace: namespace, + EndpointID: endpointID, + Endpoint: &endpoint, + }) + } + } + } + + return endpoints +} + +// ToResourceKey converts an EndpointReference to a ResourceKey for the reference tracker +func (e EndpointReference) ToResourceKey() ResourceKey { + switch e.Type { + case agaapi.GlobalAcceleratorEndpointTypeEndpointID: + // For EndpointID type, use the EndpointID as the resource name + // We'll use an empty namespace since EndpointIDs are not namespaced + return ResourceKey{ + Type: ResourceType(e.Type), + Name: types.NamespacedName{ + Namespace: "", + Name: e.EndpointID, + }, + } + default: + // For Service/Ingress/Gateway, use Name and Namespace + return ResourceKey{ + Type: ResourceType(e.Type), + Name: types.NamespacedName{ + Namespace: e.Namespace, + Name: e.Name, + }, + } + } +} diff --git a/pkg/aga/endpoint_utils_test.go b/pkg/aga/endpoint_utils_test.go new file mode 100644 index 000000000..e8401ace1 --- /dev/null +++ b/pkg/aga/endpoint_utils_test.go @@ -0,0 +1,227 @@ +package aga + +import ( + awssdk "github.com/aws/aws-sdk-go-v2/aws" + "testing" + + "github.com/stretchr/testify/assert" + agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" +) + +func TestGetAllEndpointsFromGA(t *testing.T) { + + tests := []struct { + name string + ga *agaapi.GlobalAccelerator + expected []EndpointReference + }{ + { + name: "Empty GA", + ga: &agaapi.GlobalAccelerator{}, + expected: nil, + }, + { + name: "GA with no listeners", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: nil, + }, + }, + expected: nil, + }, + { + name: "GA with listeners but no endpoint groups", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + EndpointGroups: nil, + }, + }, + }, + }, + expected: nil, + }, + { + name: "GA with endpoint groups but no endpoints", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Endpoints: nil, + }, + }, + }, + }, + }, + }, + expected: nil, + }, + { + name: "GA with service endpoint", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: awssdk.String("test-service"), + }, + }, + }, + }, + }, + }, + }, + }, + expected: []EndpointReference{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: "test-service", + Namespace: "", + }, + }, + }, + { + name: "GA with EndpointID type endpoint", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeEndpointID, + EndpointID: awssdk.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-service/1234567890"), + }, + }, + }, + }, + }, + }, + }, + }, + expected: []EndpointReference{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeEndpointID, + Name: "", + Namespace: "", + EndpointID: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-service/1234567890", + }, + }, + }, + { + name: "GA with multiple types of endpoints", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: awssdk.String("test-service"), + }, + { + Type: agaapi.GlobalAcceleratorEndpointTypeIngress, + Name: awssdk.String("test-ingress"), + }, + { + Type: agaapi.GlobalAcceleratorEndpointTypeGateway, + Name: awssdk.String("test-gateway"), + Namespace: awssdk.String("custom-namespace"), + }, + }, + }, + }, + }, + }, + }, + }, + expected: []EndpointReference{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: "test-service", + Namespace: "", + }, + { + Type: agaapi.GlobalAcceleratorEndpointTypeIngress, + Name: "test-ingress", + Namespace: "", + }, + { + Type: agaapi.GlobalAcceleratorEndpointTypeGateway, + Name: "test-gateway", + Namespace: "custom-namespace", + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set namespace for GA + if tt.ga != nil { + tt.ga.Namespace = "default" + + // Update expected namespaces if they're empty (but only for non-EndpointID types) + for i := range tt.expected { + // Only apply default namespace for Service/Ingress/Gateway types + if tt.expected[i].Namespace == "" && tt.expected[i].Type != agaapi.GlobalAcceleratorEndpointTypeEndpointID { + tt.expected[i].Namespace = tt.ga.Namespace + } + } + } + + result := GetAllEndpointsFromGA(tt.ga) + + // Compare lengths + assert.Equal(t, len(tt.expected), len(result)) + + // Compare contents + if tt.expected != nil { + for i, exp := range tt.expected { + assert.Equal(t, exp.Type, result[i].Type) + assert.Equal(t, exp.Name, result[i].Name) + assert.Equal(t, exp.Namespace, result[i].Namespace) + } + } + }) + } +} + +func TestEndpointReferenceToResourceKey(t *testing.T) { + // Test Service type endpoint + t.Run("Service type endpoint", func(t *testing.T) { + endpoint := EndpointReference{ + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: "test-service", + Namespace: "test-namespace", + } + + resourceKey := endpoint.ToResourceKey() + assert.Equal(t, ResourceType(endpoint.Type), resourceKey.Type) + assert.Equal(t, endpoint.Name, resourceKey.Name.Name) + assert.Equal(t, endpoint.Namespace, resourceKey.Name.Namespace) + }) + + // Test EndpointID type endpoint + t.Run("EndpointID type endpoint", func(t *testing.T) { + endpoint := EndpointReference{ + Type: agaapi.GlobalAcceleratorEndpointTypeEndpointID, + EndpointID: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-service/1234567890", + } + + resourceKey := endpoint.ToResourceKey() + assert.Equal(t, ResourceType(endpoint.Type), resourceKey.Type) + assert.Equal(t, endpoint.EndpointID, resourceKey.Name.Name) + assert.Equal(t, "", resourceKey.Name.Namespace) // Namespace should be empty for EndpointID type + }) +} diff --git a/pkg/aga/model_build_endpoint_group.go b/pkg/aga/model_build_endpoint_group.go new file mode 100644 index 000000000..90d987b86 --- /dev/null +++ b/pkg/aga/model_build_endpoint_group.go @@ -0,0 +1,250 @@ +package aga + +import ( + "context" + "fmt" + awssdk "github.com/aws/aws-sdk-go-v2/aws" + agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" + agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" + "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" +) + +// endpointGroupBuilder builds EndpointGroup model resources +type endpointGroupBuilder interface { + // Build builds all endpoint groups for all listeners + Build(ctx context.Context, stack core.Stack, listeners []*agamodel.Listener, listenerConfigs []agaapi.GlobalAcceleratorListener) ([]*agamodel.EndpointGroup, error) + + // buildEndpointGroupsForListener builds endpoint groups for a specific listener + buildEndpointGroupsForListener(ctx context.Context, stack core.Stack, listener *agamodel.Listener, endpointGroups []agaapi.GlobalAcceleratorEndpointGroup, listenerIndex int) ([]*agamodel.EndpointGroup, error) +} + +// NewEndpointGroupBuilder constructs new endpointGroupBuilder +func NewEndpointGroupBuilder(clusterRegion string) endpointGroupBuilder { + return &defaultEndpointGroupBuilder{ + clusterRegion: clusterRegion, + } +} + +var _ endpointGroupBuilder = &defaultEndpointGroupBuilder{} + +type defaultEndpointGroupBuilder struct { + clusterRegion string +} + +// Build builds EndpointGroup model resources +func (b *defaultEndpointGroupBuilder) Build(ctx context.Context, stack core.Stack, listeners []*agamodel.Listener, listenerConfigs []agaapi.GlobalAcceleratorListener) ([]*agamodel.EndpointGroup, error) { + if listeners == nil || len(listeners) == 0 { + return nil, nil + } + + var result []*agamodel.EndpointGroup + + // Create a map of all listener port ranges + listenerPortRanges := make(map[string][]agamodel.PortRange) // Maps listener ID to its port ranges + for _, listener := range listeners { + listenerPortRanges[listener.ID()] = listener.Spec.PortRanges + } + + for i, listener := range listeners { + listenerConfig := listenerConfigs[i] + if listenerConfig.EndpointGroups == nil { + continue + } + + listenerEndpointGroups, err := b.buildEndpointGroupsForListener(ctx, stack, listener, *listenerConfig.EndpointGroups, i) + if err != nil { + return nil, err + } + result = append(result, listenerEndpointGroups...) + } + + // Validate endpoint ports in all port overrides across all listeners + if err := b.validateEndpointPortOverridesCrossListeners(result, listenerPortRanges); err != nil { + return nil, err + } + + return result, nil +} + +// validateEndpointPortOverridesCrossListeners performs validations for endpoint port overrides across all listeners +func (b *defaultEndpointGroupBuilder) validateEndpointPortOverridesCrossListeners(endpointGroups []*agamodel.EndpointGroup, listenerPortRanges map[string][]agamodel.PortRange) error { + // Track endpoint port usage across all endpoint groups + endpointPortUsage := make(map[int32]string) // Maps endpoint port to listener ID + + // Check all endpoint groups for port overrides + for _, endpointGroup := range endpointGroups { + listenerID := endpointGroup.Listener.ID() + + for _, portOverride := range endpointGroup.Spec.PortOverrides { + endpointPort := portOverride.EndpointPort + + // Rule 1: Check if endpoint port is within any listener's port range + if err := b.validateEndpointPortOverridesWithinListener(endpointPort, listenerPortRanges); err != nil { + return err + } + + // Rule 2: Check for duplicate endpoint port usage across listeners + if existingListenerID, exists := endpointPortUsage[endpointPort]; exists && existingListenerID != listenerID { + return fmt.Errorf("duplicate endpoint port %d: the same endpoint port cannot be used in port overrides from different listeners (used in %s and %s)", + endpointPort, existingListenerID, listenerID) + } + + // Register this endpoint port usage + endpointPortUsage[endpointPort] = listenerID + } + } + + return nil +} + +// validateEndpointPortOverridesWithinListener checks if an endpoint port is within any listener's port range +func (b *defaultEndpointGroupBuilder) validateEndpointPortOverridesWithinListener(endpointPort int32, listenerPortRanges map[string][]agamodel.PortRange) error { + for listenerID, portRanges := range listenerPortRanges { + if IsPortInRanges(endpointPort, portRanges) { + // Find the specific port range for the error message + for _, portRange := range portRanges { + if endpointPort >= portRange.FromPort && endpointPort <= portRange.ToPort { + return fmt.Errorf("endpoint port %d conflicts with listener %s port range %d-%d: endpoint port cannot be included in any listener port range", + endpointPort, listenerID, portRange.FromPort, portRange.ToPort) + } + } + } + } + return nil +} + +// buildEndpointGroupsForListener builds EndpointGroup models for a specific listener +func (b *defaultEndpointGroupBuilder) buildEndpointGroupsForListener(ctx context.Context, stack core.Stack, listener *agamodel.Listener, endpointGroups []agaapi.GlobalAcceleratorEndpointGroup, listenerIndex int) ([]*agamodel.EndpointGroup, error) { + var result []*agamodel.EndpointGroup + + for i, endpointGroup := range endpointGroups { + spec, err := b.buildEndpointGroupSpec(ctx, listener, endpointGroup) + if err != nil { + return nil, err + } + + resourceID := fmt.Sprintf("EndpointGroup-%d-%d", listenerIndex, i) + endpointGroupModel := agamodel.NewEndpointGroup(stack, resourceID, spec, listener) + result = append(result, endpointGroupModel) + } + + return result, nil +} + +// buildEndpointGroupSpec builds the EndpointGroupSpec for a single EndpointGroup model resource +func (b *defaultEndpointGroupBuilder) buildEndpointGroupSpec(ctx context.Context, listener *agamodel.Listener, endpointGroup agaapi.GlobalAcceleratorEndpointGroup) (agamodel.EndpointGroupSpec, error) { + region, err := b.determineRegion(endpointGroup) + if err != nil { + return agamodel.EndpointGroupSpec{}, err + } + + // Handle trafficDialPercentage + trafficDialPercentage := endpointGroup.TrafficDialPercentage + + portOverrides, err := b.buildPortOverrides(ctx, listener, endpointGroup) + if err != nil { + return agamodel.EndpointGroupSpec{}, err + } + + return agamodel.EndpointGroupSpec{ + ListenerARN: listener.ListenerARN(), + Region: region, + TrafficDialPercentage: trafficDialPercentage, + PortOverrides: portOverrides, + }, nil +} + +// validateListenerPortOverrideWithinListenerPortRanges ensures all listener ports used in port overrides are +// contained within the listener's port ranges +func (b *defaultEndpointGroupBuilder) validateListenerPortOverrideWithinListenerPortRanges(listener *agamodel.Listener, portOverrides []agamodel.PortOverride) error { + if len(portOverrides) == 0 { + return nil + } + + for _, portOverride := range portOverrides { + listenerPort := portOverride.ListenerPort + if !IsPortInRanges(listenerPort, listener.Spec.PortRanges) { + return fmt.Errorf("port override listener port %d is not within any listener port ranges - this will cause AWS Global Accelerator to reject the configuration", listenerPort) + } + } + return nil +} + +// determineRegion determines the region for the endpoint group +func (b *defaultEndpointGroupBuilder) determineRegion(endpointGroup agaapi.GlobalAcceleratorEndpointGroup) (string, error) { + // Use explicit region from endpoint group if specified + if endpointGroup.Region != nil && awssdk.ToString(endpointGroup.Region) != "" { + return awssdk.ToString(endpointGroup.Region), nil + } + + // Default to cluster region if available + if b.clusterRegion != "" { + return b.clusterRegion, nil + } + return "", fmt.Errorf("region is required for endpoint group but neither specified in the endpoint group nor available from cluster configuration") +} + +// buildPortOverrides builds the port overrides for the endpoint group +func (b *defaultEndpointGroupBuilder) buildPortOverrides(_ context.Context, listener *agamodel.Listener, endpointGroup agaapi.GlobalAcceleratorEndpointGroup) ([]agamodel.PortOverride, error) { + if endpointGroup.PortOverrides == nil { + return nil, nil + } + + var portOverrides []agamodel.PortOverride + for _, po := range *endpointGroup.PortOverrides { + portOverrides = append(portOverrides, agamodel.PortOverride{ + ListenerPort: po.ListenerPort, + EndpointPort: po.EndpointPort, + }) + } + + // Validate all port override rules + if err := b.validatePortOverrides(listener, portOverrides); err != nil { + return []agamodel.PortOverride{}, err + } + + return portOverrides, nil +} + +// validateNoDuplicatePorts checks both listener and endpoint ports for duplicates in a single pass +func (b *defaultEndpointGroupBuilder) validateNoDuplicatePorts(portOverrides []agamodel.PortOverride) error { + if len(portOverrides) <= 1 { + return nil + } + + listenerPorts := make(map[int32]bool) + endpointPorts := make(map[int32]bool) + + for _, portOverride := range portOverrides { + // Check for duplicate listener ports + listenerPort := portOverride.ListenerPort + if listenerPorts[listenerPort] { + return fmt.Errorf("duplicate listener port %d in port overrides: each listener port can only be used once in port overrides for an endpoint group", listenerPort) + } + listenerPorts[listenerPort] = true + + // Check for duplicate endpoint ports + endpointPort := portOverride.EndpointPort + if endpointPorts[endpointPort] { + return fmt.Errorf("duplicate endpoint port %d in port overrides: each endpoint port can only be used once in port overrides for an endpoint group", endpointPort) + } + endpointPorts[endpointPort] = true + } + + return nil +} + +// validatePortOverrides is a wrapper function that runs all port override validation rules +func (b *defaultEndpointGroupBuilder) validatePortOverrides(listener *agamodel.Listener, portOverrides []agamodel.PortOverride) error { + // Validate listener port overrides against listener port ranges + if err := b.validateListenerPortOverrideWithinListenerPortRanges(listener, portOverrides); err != nil { + return err + } + + // Check for duplicate listener and endpoint ports within this endpoint group's port overrides + if err := b.validateNoDuplicatePorts(portOverrides); err != nil { + return err + } + + return nil +} diff --git a/pkg/aga/model_build_endpoint_group_test.go b/pkg/aga/model_build_endpoint_group_test.go new file mode 100644 index 000000000..5e51e389b --- /dev/null +++ b/pkg/aga/model_build_endpoint_group_test.go @@ -0,0 +1,865 @@ +package aga + +import ( + "context" + awssdk "github.com/aws/aws-sdk-go-v2/aws" + agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" + "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" + "testing" + + "github.com/stretchr/testify/assert" + agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" +) + +func Test_defaultEndpointGroupBuilder_determineRegion(t *testing.T) { + tests := []struct { + name string + endpointGroup agaapi.GlobalAcceleratorEndpointGroup + clusterRegion string + expectedRegion string + expectError bool + expectErrorString string + }{ + { + name: "region specified in endpoint group", + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + Region: awssdk.String("us-west-2"), + }, + clusterRegion: "us-east-1", + expectedRegion: "us-west-2", + expectError: false, + }, + { + name: "region specified in endpoint group even with empty cluster region", + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + Region: awssdk.String("eu-west-1"), + }, + clusterRegion: "", + expectedRegion: "eu-west-1", + expectError: false, + }, + { + name: "region not specified in endpoint group, use cluster region", + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{}, + clusterRegion: "us-east-1", + expectedRegion: "us-east-1", + expectError: false, + }, + { + name: "neither region specified nor cluster region available", + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{}, + clusterRegion: "", + expectError: true, + expectErrorString: "region is required for endpoint group but neither specified in the endpoint group nor available from cluster configuration", + }, + { + name: "empty region string in endpoint group, fall back to cluster region", + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + Region: awssdk.String(""), + }, + clusterRegion: "ap-southeast-1", + expectedRegion: "ap-southeast-1", + expectError: false, + }, + { + name: "empty region string in endpoint group and no cluster region", + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + Region: awssdk.String(""), + }, + clusterRegion: "", + expectError: true, + expectErrorString: "region is required for endpoint group but neither specified in the endpoint group nor available from cluster configuration", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create endpointGroupBuilder + builder := &defaultEndpointGroupBuilder{ + clusterRegion: tt.clusterRegion, + } + + // Call determineRegion + region, err := builder.determineRegion(tt.endpointGroup) + + // Check if error was expected + if tt.expectError { + assert.Error(t, err) + if tt.expectErrorString != "" { + assert.Contains(t, err.Error(), tt.expectErrorString) + } + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedRegion, region) + } + }) + } +} + +func Test_defaultEndpointGroupBuilder_buildPortOverrides(t *testing.T) { + mockStack := core.NewDefaultStack(core.StackID{Namespace: "test-namespace", Name: "test-name"}) + + // Helper function to create a listener with specific ID and port ranges + createTestListener := func(id string, portRanges []agamodel.PortRange) *agamodel.Listener { + return &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", id), + Spec: agamodel.ListenerSpec{ + PortRanges: portRanges, + }, + } + } + + // Helper function to create port overrides + createPortOverrides := func(overrides ...agaapi.PortOverride) *[]agaapi.PortOverride { + if len(overrides) == 0 { + return nil + } + return &overrides + } + + tests := []struct { + name string + listener *agamodel.Listener + endpointGroup agaapi.GlobalAcceleratorEndpointGroup + want []agamodel.PortOverride + expectErr bool + expectErrMatch string + }{ + { + name: "no port overrides", + listener: createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }), + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + PortOverrides: nil, + }, + want: nil, + expectErr: false, + }, + { + name: "empty port overrides", + listener: createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }), + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + PortOverrides: &[]agaapi.PortOverride{}, + }, + want: nil, + expectErr: false, + }, + { + name: "valid single port override", + listener: createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }), + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + PortOverrides: createPortOverrides( + agaapi.PortOverride{ + ListenerPort: 80, + EndpointPort: 8080, + }, + ), + }, + want: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + }, + expectErr: false, + }, + { + name: "valid multiple port overrides", + listener: createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }), + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + PortOverrides: createPortOverrides( + agaapi.PortOverride{ + ListenerPort: 80, + EndpointPort: 8080, + }, + agaapi.PortOverride{ + ListenerPort: 443, + EndpointPort: 8443, + }, + ), + }, + want: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + { + ListenerPort: 443, + EndpointPort: 8443, + }, + }, + expectErr: false, + }, + { + name: "listener port outside range", + listener: createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }), + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + PortOverrides: createPortOverrides( + agaapi.PortOverride{ + ListenerPort: 443, // Not in listener port range + EndpointPort: 8443, + }, + ), + }, + want: nil, + expectErr: true, + expectErrMatch: "port override listener port 443 is not within any listener port ranges", + }, + { + name: "duplicate listener port", + listener: createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }), + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + PortOverrides: createPortOverrides( + agaapi.PortOverride{ + ListenerPort: 80, + EndpointPort: 8080, + }, + agaapi.PortOverride{ + ListenerPort: 80, // Duplicate listener port + EndpointPort: 9090, + }, + ), + }, + want: nil, + expectErr: true, + expectErrMatch: "duplicate listener port 80 in port overrides", + }, + { + name: "duplicate endpoint port", + listener: createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }), + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + PortOverrides: createPortOverrides( + agaapi.PortOverride{ + ListenerPort: 80, + EndpointPort: 8080, + }, + agaapi.PortOverride{ + ListenerPort: 443, + EndpointPort: 8080, // Duplicate endpoint port + }, + ), + }, + want: nil, + expectErr: true, + expectErrMatch: "duplicate endpoint port 8080 in port overrides", + }, + { + name: "port range check for listener port", + listener: createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 90}, + }), + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + PortOverrides: createPortOverrides( + agaapi.PortOverride{ + ListenerPort: 85, // Within the range + EndpointPort: 8085, + }, + ), + }, + want: []agamodel.PortOverride{ + { + ListenerPort: 85, + EndpointPort: 8085, + }, + }, + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + + // Create endpointGroupBuilder + builder := &defaultEndpointGroupBuilder{ + clusterRegion: "us-west-2", + } + + // Call buildPortOverrides + got, err := builder.buildPortOverrides(ctx, tt.listener, tt.endpointGroup) + + // Check for expected error + if tt.expectErr { + assert.Error(t, err) + if tt.expectErrMatch != "" { + assert.Contains(t, err.Error(), tt.expectErrMatch) + } + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + +func Test_defaultEndpointGroupBuilder_validateEndpointPortOverridesWithinListener(t *testing.T) { + tests := []struct { + name string + endpointPort int32 + listenerPortRanges map[string][]agamodel.PortRange + wantErr bool + expectErrContains string + }{ + { + name: "endpoint port outside all listener port ranges", + endpointPort: 8080, + listenerPortRanges: map[string][]agamodel.PortRange{ + "l-1": { + { + FromPort: 80, + ToPort: 80, + }, + }, + "l-2": { + { + FromPort: 443, + ToPort: 443, + }, + }, + }, + wantErr: false, + }, + { + name: "endpoint port inside a listener port range", + endpointPort: 450, + listenerPortRanges: map[string][]agamodel.PortRange{ + "l-1": { + { + FromPort: 80, + ToPort: 80, + }, + }, + "l-2": { + { + FromPort: 400, + ToPort: 500, // Includes 450 + }, + }, + }, + wantErr: true, + expectErrContains: "endpoint port 450 conflicts with listener l-2 port range 400-500", + }, + { + name: "endpoint port at boundary of listener port range", + endpointPort: 400, // Exactly at FromPort boundary + listenerPortRanges: map[string][]agamodel.PortRange{ + "l-1": { + { + FromPort: 400, + ToPort: 500, + }, + }, + }, + wantErr: true, + expectErrContains: "endpoint port 400 conflicts with listener l-1 port range 400-500", + }, + { + name: "endpoint port at upper boundary of listener port range", + endpointPort: 500, // Exactly at ToPort boundary + listenerPortRanges: map[string][]agamodel.PortRange{ + "l-1": { + { + FromPort: 400, + ToPort: 500, + }, + }, + }, + wantErr: true, + expectErrContains: "endpoint port 500 conflicts with listener l-1 port range 400-500", + }, + { + name: "multiple listener port ranges, endpoint port in one range", + endpointPort: 1024, + listenerPortRanges: map[string][]agamodel.PortRange{ + "l-1": { + { + FromPort: 80, + ToPort: 80, + }, + { + FromPort: 443, + ToPort: 443, + }, + }, + "l-2": { + { + FromPort: 1000, + ToPort: 2000, // Includes 1024 + }, + { + FromPort: 3000, + ToPort: 4000, + }, + }, + }, + wantErr: true, + expectErrContains: "endpoint port 1024 conflicts with listener l-2 port range 1000-2000", + }, + { + name: "empty listener port ranges", + endpointPort: 8080, + listenerPortRanges: map[string][]agamodel.PortRange{}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + builder := &defaultEndpointGroupBuilder{ + clusterRegion: "us-west-2", + } + err := builder.validateEndpointPortOverridesWithinListener(tt.endpointPort, tt.listenerPortRanges) + if tt.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectErrContains) + } else { + assert.NoError(t, err) + } + }) + } +} + +func Test_defaultEndpointGroupBuilder_validateNoDuplicatePorts(t *testing.T) { + tests := []struct { + name string + portOverrides []agamodel.PortOverride + wantErr bool + expectErrContains string + portType string // "listener" or "endpoint" + }{ + { + name: "no port overrides", + portOverrides: []agamodel.PortOverride{}, + wantErr: false, + }, + { + name: "single port override", + portOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + }, + wantErr: false, + }, + { + name: "multiple port overrides with unique ports", + portOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + { + ListenerPort: 443, + EndpointPort: 8443, + }, + { + ListenerPort: 8000, + EndpointPort: 9000, + }, + }, + wantErr: false, + }, + { + name: "multiple port overrides with duplicate listener ports", + portOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + { + ListenerPort: 80, // Duplicate listener port + EndpointPort: 9090, // Different endpoint port + }, + { + ListenerPort: 443, + EndpointPort: 8443, + }, + }, + wantErr: true, + expectErrContains: "duplicate listener port 80 in port overrides", + portType: "listener", + }, + { + name: "multiple duplicate listener ports", + portOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + { + ListenerPort: 443, + EndpointPort: 8443, + }, + { + ListenerPort: 80, // First duplicate + EndpointPort: 9090, + }, + { + ListenerPort: 443, // Second duplicate + EndpointPort: 9443, + }, + }, + wantErr: true, + expectErrContains: "duplicate listener port 80 in port overrides", + portType: "listener", + }, + { + name: "multiple port overrides with duplicate endpoint ports", + portOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + { + ListenerPort: 443, + EndpointPort: 8080, // Duplicate endpoint port + }, + { + ListenerPort: 8000, + EndpointPort: 9000, + }, + }, + wantErr: true, + expectErrContains: "duplicate endpoint port 8080 in port overrides", + portType: "endpoint", + }, + { + name: "multiple duplicate endpoint ports", + portOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + { + ListenerPort: 443, + EndpointPort: 8443, + }, + { + ListenerPort: 8000, + EndpointPort: 8080, // First duplicate + }, + { + ListenerPort: 9000, + EndpointPort: 8443, // Second duplicate + }, + }, + wantErr: true, + expectErrContains: "duplicate endpoint port 8080 in port overrides", + portType: "endpoint", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + builder := &defaultEndpointGroupBuilder{ + clusterRegion: "us-west-2", + } + err := builder.validateNoDuplicatePorts(tt.portOverrides) + if tt.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectErrContains) + } else { + assert.NoError(t, err) + } + }) + } +} + +func Test_defaultEndpointGroupBuilder_validateEndpointPortOverridesCrossListeners(t *testing.T) { + mockStack := core.NewDefaultStack(core.StackID{Namespace: "test-namespace", Name: "test-name"}) + + // Helper function to create a listener with specific ID and port ranges + createTestListener := func(id string, portRanges []agamodel.PortRange) *agamodel.Listener { + return &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", id), + Spec: agamodel.ListenerSpec{ + PortRanges: portRanges, + }, + } + } + + // Helper function to create an endpoint group with specific listener and port overrides + createTestEndpointGroup := func(id string, region string, listener *agamodel.Listener, portOverrides []agamodel.PortOverride) *agamodel.EndpointGroup { + return &agamodel.EndpointGroup{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::EndpointGroup", id), + Listener: listener, + Spec: agamodel.EndpointGroupSpec{ + Region: region, + PortOverrides: portOverrides, + }, + } + } + + tests := []struct { + name string + endpointGroups []*agamodel.EndpointGroup + listenerPortRanges map[string][]agamodel.PortRange + wantErr bool + expectErrContains string + }{ + { + name: "no endpoint groups", + endpointGroups: []*agamodel.EndpointGroup{}, + listenerPortRanges: map[string][]agamodel.PortRange{}, + wantErr: false, + }, + { + name: "single endpoint group - no conflicts possible", + endpointGroups: func() []*agamodel.EndpointGroup { + listener := createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }) + return []*agamodel.EndpointGroup{ + createTestEndpointGroup("eg-1", "us-west-2", listener, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + }), + } + }(), + listenerPortRanges: map[string][]agamodel.PortRange{ + "listener-1": {{FromPort: 80, ToPort: 80}}, + }, + wantErr: false, + }, + { + name: "multiple endpoint groups, same listener - no conflicts", + endpointGroups: func() []*agamodel.EndpointGroup { + listener := createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }) + return []*agamodel.EndpointGroup{ + createTestEndpointGroup("eg-1", "us-west-2", listener, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + }), + createTestEndpointGroup("eg-2", "eu-west-1", listener, []agamodel.PortOverride{ + {ListenerPort: 443, EndpointPort: 8443}, + }), + } + }(), + listenerPortRanges: map[string][]agamodel.PortRange{ + "listener-1": { + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }, + }, + wantErr: false, + }, + { + name: "multiple endpoint groups, different listeners, no conflicts", + endpointGroups: func() []*agamodel.EndpointGroup { + listener1 := createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }) + listener2 := createTestListener("listener-2", []agamodel.PortRange{ + {FromPort: 443, ToPort: 443}, + }) + return []*agamodel.EndpointGroup{ + createTestEndpointGroup("eg-1", "us-west-2", listener1, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + }), + createTestEndpointGroup("eg-2", "eu-west-1", listener2, []agamodel.PortOverride{ + {ListenerPort: 443, EndpointPort: 9090}, // Different endpoint port + }), + } + }(), + listenerPortRanges: map[string][]agamodel.PortRange{ + "listener-1": {{FromPort: 80, ToPort: 80}}, + "listener-2": {{FromPort: 443, ToPort: 443}}, + }, + wantErr: false, + }, + { + name: "endpoint port in listener port range - conflict", + endpointGroups: func() []*agamodel.EndpointGroup { + listener1 := createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 90}, // Range includes 85 + }) + return []*agamodel.EndpointGroup{ + createTestEndpointGroup("eg-1", "us-west-2", listener1, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 85}, // Endpoint port is within listener port range + }), + } + }(), + listenerPortRanges: map[string][]agamodel.PortRange{ + "listener-1": {{FromPort: 80, ToPort: 90}}, + }, + wantErr: true, + expectErrContains: "endpoint port 85 conflicts with listener listener-1 port range 80-90", + }, + { + name: "duplicate endpoint port across different listeners - conflict", + endpointGroups: func() []*agamodel.EndpointGroup { + listener1 := createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }) + listener2 := createTestListener("listener-2", []agamodel.PortRange{ + {FromPort: 443, ToPort: 443}, + }) + return []*agamodel.EndpointGroup{ + createTestEndpointGroup("eg-1", "us-west-2", listener1, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + }), + createTestEndpointGroup("eg-2", "eu-west-1", listener2, []agamodel.PortOverride{ + {ListenerPort: 443, EndpointPort: 8080}, // Same endpoint port as eg-1 + }), + } + }(), + listenerPortRanges: map[string][]agamodel.PortRange{ + "listener-1": {{FromPort: 80, ToPort: 80}}, + "listener-2": {{FromPort: 443, ToPort: 443}}, + }, + wantErr: true, + expectErrContains: "duplicate endpoint port 8080: the same endpoint port cannot be used in port overrides from different listeners", + }, + { + name: "multiple duplicate endpoint ports across different listeners", + endpointGroups: func() []*agamodel.EndpointGroup { + listener1 := createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }) + listener2 := createTestListener("listener-2", []agamodel.PortRange{ + {FromPort: 8080, ToPort: 8080}, + {FromPort: 8443, ToPort: 8443}, + }) + return []*agamodel.EndpointGroup{ + createTestEndpointGroup("eg-1", "us-west-2", listener1, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 9090}, + {ListenerPort: 443, EndpointPort: 9091}, + }), + createTestEndpointGroup("eg-2", "eu-west-1", listener2, []agamodel.PortOverride{ + {ListenerPort: 8080, EndpointPort: 9090}, // Duplicate with listener-1 + {ListenerPort: 8443, EndpointPort: 7070}, // Unique + }), + } + }(), + listenerPortRanges: map[string][]agamodel.PortRange{ + "listener-1": { + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }, + "listener-2": { + {FromPort: 8080, ToPort: 8080}, + {FromPort: 8443, ToPort: 8443}, + }, + }, + wantErr: true, + expectErrContains: "duplicate endpoint port 9090: the same endpoint port cannot be used in port overrides from different listeners", + }, + { + name: "multiple endpoint groups with mixed conflicts", + endpointGroups: func() []*agamodel.EndpointGroup { + listener1 := createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }) + listener2 := createTestListener("listener-2", []agamodel.PortRange{ + {FromPort: 443, ToPort: 443}, + }) + listener3 := createTestListener("listener-3", []agamodel.PortRange{ + {FromPort: 9500, ToPort: 9999}, // Range does not include 8080 + }) + return []*agamodel.EndpointGroup{ + createTestEndpointGroup("eg-1", "us-west-2", listener1, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + }), + createTestEndpointGroup("eg-2", "eu-west-1", listener2, []agamodel.PortOverride{ + {ListenerPort: 443, EndpointPort: 8080}, // Duplicate endpoint port with listener1 + }), + createTestEndpointGroup("eg-3", "ap-southeast-1", listener3, []agamodel.PortOverride{ + {ListenerPort: 8080, EndpointPort: 9999}, // Endpoint port outside of any listener range + }), + } + }(), + listenerPortRanges: map[string][]agamodel.PortRange{ + "listener-1": {{FromPort: 80, ToPort: 80}}, + "listener-2": {{FromPort: 443, ToPort: 443}}, + "listener-3": {{FromPort: 9500, ToPort: 9999}}, + }, + wantErr: true, + expectErrContains: "duplicate endpoint port 8080: the same endpoint port cannot be used in port overrides from different listeners", + }, + { + name: "same endpoint port in different regions - still a conflict", + endpointGroups: func() []*agamodel.EndpointGroup { + listener1 := createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }) + listener2 := createTestListener("listener-2", []agamodel.PortRange{ + {FromPort: 443, ToPort: 443}, + }) + return []*agamodel.EndpointGroup{ + createTestEndpointGroup("eg-1", "us-west-2", listener1, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + }), + createTestEndpointGroup("eg-2", "eu-west-1", listener2, []agamodel.PortOverride{ + // Even though in different regions, same port across listeners is still a conflict + {ListenerPort: 443, EndpointPort: 8080}, + }), + } + }(), + listenerPortRanges: map[string][]agamodel.PortRange{ + "listener-1": {{FromPort: 80, ToPort: 80}}, + "listener-2": {{FromPort: 443, ToPort: 443}}, + }, + wantErr: true, + expectErrContains: "duplicate endpoint port 8080: the same endpoint port cannot be used in port overrides from different listeners", + }, + { + name: "endpoint groups with no port overrides - no conflicts", + endpointGroups: func() []*agamodel.EndpointGroup { + listener1 := createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }) + listener2 := createTestListener("listener-2", []agamodel.PortRange{ + {FromPort: 443, ToPort: 443}, + }) + return []*agamodel.EndpointGroup{ + createTestEndpointGroup("eg-1", "us-west-2", listener1, nil), // No port overrides + createTestEndpointGroup("eg-2", "eu-west-1", listener2, nil), // No port overrides + } + }(), + listenerPortRanges: map[string][]agamodel.PortRange{ + "listener-1": {{FromPort: 80, ToPort: 80}}, + "listener-2": {{FromPort: 443, ToPort: 443}}, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create endpointGroupBuilder + builder := &defaultEndpointGroupBuilder{ + clusterRegion: "us-west-2", + } + + // Run the validation function + err := builder.validateEndpointPortOverridesCrossListeners(tt.endpointGroups, tt.listenerPortRanges) + + // Check if error was expected + if tt.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectErrContains) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/pkg/aga/model_builder.go b/pkg/aga/model_builder.go index d4938ab29..c3e1caa5b 100644 --- a/pkg/aga/model_builder.go +++ b/pkg/aga/model_builder.go @@ -61,10 +61,10 @@ func (b *defaultModelBuilder) Build(ctx context.Context, ga *agaapi.GlobalAccele // Create fresh builder instances for each reconciliation acceleratorBuilder := NewAcceleratorBuilder(b.trackingProvider, b.clusterName, b.clusterRegion, b.defaultTags, b.externalManagedTags, b.featureGates.Enabled(config.EnableDefaultTagsLowPriority)) + listenerBuilder := NewListenerBuilder() + endpointGroupBuilder := NewEndpointGroupBuilder(b.clusterRegion) // TODO - // endpointGroupBuilder := NewEndpointGroupBuilder() // endpointBuilder := NewEndpointBuilder() - // Build Accelerator accelerator, err := acceleratorBuilder.Build(ctx, stack, ga) if err != nil { @@ -74,17 +74,18 @@ func (b *defaultModelBuilder) Build(ctx context.Context, ga *agaapi.GlobalAccele // Build Listeners if specified var listeners []*agamodel.Listener if ga.Spec.Listeners != nil { - // Create builder for listeners and endpoints - listenerBuilder := NewListenerBuilder() listeners, err = listenerBuilder.Build(ctx, stack, accelerator, *ga.Spec.Listeners) if err != nil { return nil, nil, err } + endpointGroups, err := endpointGroupBuilder.Build(ctx, stack, listeners, *ga.Spec.Listeners) + if err != nil { + return nil, nil, err + } + b.logger.V(1).Info("Listener and endpoint groups built", "listeners", listeners, "endpointGroups", endpointGroups) } - b.logger.V(1).Info("Listeners built", "listeners", listeners) - // TODO: Add other resource builders - // endpointGroups, err := endpointGroupBuilder.Build(ctx, stack, listeners, ga.Spec.Listeners) + // TODO: Add endpoint builder // endpoints, err := endpointBuilder.Build(ctx, stack, endpointGroups, ga.Spec.Listeners) return stack, accelerator, nil diff --git a/pkg/aga/reference_tracker.go b/pkg/aga/reference_tracker.go new file mode 100644 index 000000000..0685f08a2 --- /dev/null +++ b/pkg/aga/reference_tracker.go @@ -0,0 +1,138 @@ +package aga + +import ( + "strings" + "sync" + + "github.com/go-logr/logr" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/sets" + agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" + "sigs.k8s.io/aws-load-balancer-controller/pkg/k8s" +) + +// ResourceKey uniquely identifies a resource by its type and name +type ResourceKey struct { + Type ResourceType + Name types.NamespacedName +} + +// ReferenceTracker tracks which resources are referenced by which GlobalAccelerators +type ReferenceTracker struct { + mutex sync.RWMutex + resourceMap map[ResourceKey]sets.String // Resource -> Set of GA names + gaRefMap map[types.NamespacedName]sets.Set[ResourceKey] // GA -> Set of resources + logger logr.Logger +} + +// NewReferenceTracker creates a new ReferenceTracker +func NewReferenceTracker(logger logr.Logger) *ReferenceTracker { + return &ReferenceTracker{ + resourceMap: make(map[ResourceKey]sets.String), + gaRefMap: make(map[types.NamespacedName]sets.Set[ResourceKey]), + logger: logger, + } +} + +// UpdateReferencesForGA updates the tracking information for a GlobalAccelerator +func (t *ReferenceTracker) UpdateReferencesForGA(ga *agaapi.GlobalAccelerator, endpoints []EndpointReference) { + t.mutex.Lock() + defer t.mutex.Unlock() + + gaKey := k8s.NamespacedName(ga) + + // Track current resources referenced by this GA + currentResources := sets.New[ResourceKey]() + + // Process each endpoint + for _, endpoint := range endpoints { + resourceKey := endpoint.ToResourceKey() + + currentResources.Insert(resourceKey) + + // Update resource -> GA mapping + if _, exists := t.resourceMap[resourceKey]; !exists { + t.resourceMap[resourceKey] = sets.NewString() + } + t.resourceMap[resourceKey].Insert(gaKey.String()) + + t.logger.V(1).Info("Resource referenced by GA", + "ga", gaKey.String(), + "resourceType", resourceKey.Type, + "resourceName", resourceKey.Name) + } + + // Remove old references + if oldResources, exists := t.gaRefMap[gaKey]; exists { + for resourceKey := range oldResources { + if !currentResources.Has(resourceKey) { + // Resource no longer referenced by this GA + if gaSet, exists := t.resourceMap[resourceKey]; exists { + gaSet.Delete(gaKey.String()) + if gaSet.Len() == 0 { + delete(t.resourceMap, resourceKey) + t.logger.V(1).Info("Resource no longer referenced by any GA", + "resourceType", resourceKey.Type, + "resourceName", resourceKey.Name) + } + } + } + } + } + + // Update GA -> resources mapping + t.gaRefMap[gaKey] = currentResources +} + +// RemoveGA removes all tracking information for a GlobalAccelerator +func (t *ReferenceTracker) RemoveGA(gaKey types.NamespacedName) { + t.mutex.Lock() + defer t.mutex.Unlock() + + if resources, exists := t.gaRefMap[gaKey]; exists { + for resourceKey := range resources { + if gaSet, exists := t.resourceMap[resourceKey]; exists { + gaSet.Delete(gaKey.String()) + if gaSet.Len() == 0 { + delete(t.resourceMap, resourceKey) + t.logger.V(1).Info("Resource no longer referenced by any GA", + "resourceType", resourceKey.Type, + "resourceName", resourceKey.Name) + } + } + } + + delete(t.gaRefMap, gaKey) + } +} + +// IsResourceReferenced checks if a resource is referenced by any GlobalAccelerator +func (t *ReferenceTracker) IsResourceReferenced(resourceKey ResourceKey) bool { + t.mutex.RLock() + defer t.mutex.RUnlock() + + gaSet, exists := t.resourceMap[resourceKey] + return exists && gaSet.Len() > 0 +} + +// GetGAsForResource returns all GlobalAccelerators that reference a resource +func (t *ReferenceTracker) GetGAsForResource(resourceKey ResourceKey) []types.NamespacedName { + t.mutex.RLock() + defer t.mutex.RUnlock() + + var result []types.NamespacedName + + if gaSet, exists := t.resourceMap[resourceKey]; exists { + for gaStr := range gaSet { + parts := strings.Split(gaStr, "/") + if len(parts) == 2 { + result = append(result, types.NamespacedName{ + Namespace: parts[0], + Name: parts[1], + }) + } + } + } + + return result +} diff --git a/pkg/aga/reference_tracker_test.go b/pkg/aga/reference_tracker_test.go new file mode 100644 index 000000000..b7873294e --- /dev/null +++ b/pkg/aga/reference_tracker_test.go @@ -0,0 +1,523 @@ +package aga + +import ( + "testing" + + "github.com/go-logr/logr" + "github.com/stretchr/testify/assert" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" +) + +func TestNewReferenceTracker(t *testing.T) { + // Test creating a new reference tracker + logger := logr.Discard() + tracker := NewReferenceTracker(logger) + + // Verify that the tracker is initialized properly + assert.NotNil(t, tracker) + assert.NotNil(t, tracker.resourceMap) + assert.NotNil(t, tracker.gaRefMap) + assert.Equal(t, 0, len(tracker.resourceMap)) + assert.Equal(t, 0, len(tracker.gaRefMap)) +} + +func TestReferenceTracker_UpdateReferencesForGA(t *testing.T) { + // Helper function to create a string pointer + strPtr := func(s string) *string { + return &s + } + + // Test cases + tests := []struct { + name string + ga *agaapi.GlobalAccelerator + expectedResources int + expectedReferences map[ResourceKey][]string + }{ + { + name: "GA with no endpoints", + ga: &agaapi.GlobalAccelerator{ + ObjectMeta: metav1.ObjectMeta{ + Name: "ga-no-endpoints", + Namespace: "test-ns", + }, + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{}, + }, + }, + expectedResources: 0, + expectedReferences: map[ResourceKey][]string{}, + }, + { + name: "GA with service endpoints", + ga: &agaapi.GlobalAccelerator{ + ObjectMeta: metav1.ObjectMeta{ + Name: "ga-service-endpoints", + Namespace: "test-ns", + }, + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: strPtr("service1"), + }, + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: strPtr("service2"), + }, + }, + }, + }, + }, + }, + }, + }, + expectedResources: 2, + expectedReferences: map[ResourceKey][]string{ + { + Type: ServiceResourceType, + Name: types.NamespacedName{Namespace: "test-ns", Name: "service1"}, + }: {"test-ns/ga-service-endpoints"}, + { + Type: ServiceResourceType, + Name: types.NamespacedName{Namespace: "test-ns", Name: "service2"}, + }: {"test-ns/ga-service-endpoints"}, + }, + }, + { + name: "GA with mixed endpoints", + ga: &agaapi.GlobalAccelerator{ + ObjectMeta: metav1.ObjectMeta{ + Name: "ga-mixed-endpoints", + Namespace: "test-ns", + }, + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: strPtr("service1"), + }, + { + Type: agaapi.GlobalAcceleratorEndpointTypeIngress, + Name: strPtr("ingress1"), + Namespace: strPtr("other-ns"), + }, + { + Type: agaapi.GlobalAcceleratorEndpointTypeEndpointID, + EndpointID: strPtr("arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test/1234567890"), + }, + }, + }, + }, + }, + }, + }, + }, + expectedResources: 3, + expectedReferences: map[ResourceKey][]string{ + { + Type: ServiceResourceType, + Name: types.NamespacedName{Namespace: "test-ns", Name: "service1"}, + }: {"test-ns/ga-mixed-endpoints"}, + { + Type: IngressResourceType, + Name: types.NamespacedName{Namespace: "other-ns", Name: "ingress1"}, + }: {"test-ns/ga-mixed-endpoints"}, + { + Type: ResourceType(agaapi.GlobalAcceleratorEndpointTypeEndpointID), + Name: types.NamespacedName{Namespace: "", Name: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test/1234567890"}, + }: {"test-ns/ga-mixed-endpoints"}, + }, + }, + } + + // Run test cases + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create tracker + tracker := NewReferenceTracker(logr.Discard()) + + endpoints := GetAllEndpointsFromGA(tt.ga) + // Update references + tracker.UpdateReferencesForGA(tt.ga, endpoints) + + // Check number of tracked resources + gaKey := types.NamespacedName{Namespace: tt.ga.Namespace, Name: tt.ga.Name} + resources, exists := tracker.gaRefMap[gaKey] + if tt.expectedResources == 0 { + assert.Equal(t, tt.expectedResources, len(resources)) + } else { + assert.True(t, exists) + assert.Equal(t, tt.expectedResources, len(resources)) + } + + // Check resource references + for resourceKey, expectedGAs := range tt.expectedReferences { + gaSet, exists := tracker.resourceMap[resourceKey] + assert.True(t, exists) + assert.Equal(t, len(expectedGAs), gaSet.Len()) + + for _, expectedGA := range expectedGAs { + assert.True(t, gaSet.Has(expectedGA)) + } + } + }) + } +} + +func TestReferenceTracker_UpdateReferencesForGA_RemoveStaleReferences(t *testing.T) { + // Helper function to create a string pointer + strPtr := func(s string) *string { + return &s + } + + // Create GA with initial endpoints + ga := &agaapi.GlobalAccelerator{ + ObjectMeta: metav1.ObjectMeta{ + Name: "ga-test", + Namespace: "test-ns", + }, + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: strPtr("service1"), + }, + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: strPtr("service2"), + }, + }, + }, + }, + }, + }, + }, + } + + // Create tracker and add initial references + tracker := NewReferenceTracker(logr.Discard()) + + endpoints := GetAllEndpointsFromGA(ga) + tracker.UpdateReferencesForGA(ga, endpoints) + + // Verify initial state + service1Key := ResourceKey{ + Type: ServiceResourceType, + Name: types.NamespacedName{Namespace: "test-ns", Name: "service1"}, + } + service2Key := ResourceKey{ + Type: ServiceResourceType, + Name: types.NamespacedName{Namespace: "test-ns", Name: "service2"}, + } + service3Key := ResourceKey{ + Type: ServiceResourceType, + Name: types.NamespacedName{Namespace: "test-ns", Name: "service3"}, + } + + // Both services should be referenced + assert.True(t, tracker.IsResourceReferenced(service1Key)) + assert.True(t, tracker.IsResourceReferenced(service2Key)) + + // Now modify the GA to remove service2 and add service3 + ga.Spec = agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: strPtr("service1"), + }, + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: strPtr("service3"), + }, + }, + }, + }, + }, + }, + } + + // Update references with modified GA + endpoints = GetAllEndpointsFromGA(ga) + tracker.UpdateReferencesForGA(ga, endpoints) + + // Verify that service1 is still referenced, service2 is no longer referenced, and service3 is now referenced + assert.True(t, tracker.IsResourceReferenced(service1Key)) + assert.False(t, tracker.IsResourceReferenced(service2Key)) + assert.True(t, tracker.IsResourceReferenced(service3Key)) +} + +func TestReferenceTracker_RemoveGA(t *testing.T) { + // Helper function to create a string pointer + strPtr := func(s string) *string { + return &s + } + + // Create two GAs with overlapping references + ga1 := &agaapi.GlobalAccelerator{ + ObjectMeta: metav1.ObjectMeta{ + Name: "ga1", + Namespace: "test-ns", + }, + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: strPtr("service1"), + }, + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: strPtr("service2"), + }, + }, + }, + }, + }, + }, + }, + } + + ga2 := &agaapi.GlobalAccelerator{ + ObjectMeta: metav1.ObjectMeta{ + Name: "ga2", + Namespace: "test-ns", + }, + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: strPtr("service2"), + }, + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: strPtr("service3"), + }, + }, + }, + }, + }, + }, + }, + } + + // Create tracker and add references from both GAs + tracker := NewReferenceTracker(logr.Discard()) + endpoints1 := GetAllEndpointsFromGA(ga1) + endpoints2 := GetAllEndpointsFromGA(ga2) + tracker.UpdateReferencesForGA(ga1, endpoints1) + tracker.UpdateReferencesForGA(ga2, endpoints2) + + // Resource keys + service1Key := ResourceKey{ + Type: ServiceResourceType, + Name: types.NamespacedName{Namespace: "test-ns", Name: "service1"}, + } + service2Key := ResourceKey{ + Type: ServiceResourceType, + Name: types.NamespacedName{Namespace: "test-ns", Name: "service2"}, + } + service3Key := ResourceKey{ + Type: ServiceResourceType, + Name: types.NamespacedName{Namespace: "test-ns", Name: "service3"}, + } + + // Verify initial state - all services should be referenced + assert.True(t, tracker.IsResourceReferenced(service1Key)) + assert.True(t, tracker.IsResourceReferenced(service2Key)) + assert.True(t, tracker.IsResourceReferenced(service3Key)) + + // Remove ga1 + ga1Key := types.NamespacedName{Namespace: "test-ns", Name: "ga1"} + tracker.RemoveGA(ga1Key) + + // Verify that service1 is no longer referenced, service2 is still referenced by ga2, and service3 is still referenced + assert.False(t, tracker.IsResourceReferenced(service1Key)) + assert.True(t, tracker.IsResourceReferenced(service2Key)) + assert.True(t, tracker.IsResourceReferenced(service3Key)) + + // Remove ga2 + ga2Key := types.NamespacedName{Namespace: "test-ns", Name: "ga2"} + tracker.RemoveGA(ga2Key) + + // Verify that no services are referenced anymore + assert.False(t, tracker.IsResourceReferenced(service1Key)) + assert.False(t, tracker.IsResourceReferenced(service2Key)) + assert.False(t, tracker.IsResourceReferenced(service3Key)) + + // Verify that gaRefMap is empty + assert.Equal(t, 0, len(tracker.gaRefMap)) +} + +func TestReferenceTracker_IsResourceReferenced(t *testing.T) { + // Helper function to create a string pointer + strPtr := func(s string) *string { + return &s + } + + // Create GA + ga := &agaapi.GlobalAccelerator{ + ObjectMeta: metav1.ObjectMeta{ + Name: "ga-test", + Namespace: "test-ns", + }, + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: strPtr("service1"), + }, + }, + }, + }, + }, + }, + }, + } + + // Create tracker and add references + tracker := NewReferenceTracker(logr.Discard()) + endpoints := GetAllEndpointsFromGA(ga) + tracker.UpdateReferencesForGA(ga, endpoints) + + // Resource keys - one that exists and one that doesn't + existingResourceKey := ResourceKey{ + Type: ServiceResourceType, + Name: types.NamespacedName{Namespace: "test-ns", Name: "service1"}, + } + nonExistingResourceKey := ResourceKey{ + Type: ServiceResourceType, + Name: types.NamespacedName{Namespace: "test-ns", Name: "non-existing-service"}, + } + + // Test IsResourceReferenced + assert.True(t, tracker.IsResourceReferenced(existingResourceKey)) + assert.False(t, tracker.IsResourceReferenced(nonExistingResourceKey)) +} + +func TestReferenceTracker_GetGAsForResource(t *testing.T) { + // Helper function to create a string pointer + strPtr := func(s string) *string { + return &s + } + + // Create GAs + ga1 := &agaapi.GlobalAccelerator{ + ObjectMeta: metav1.ObjectMeta{ + Name: "ga1", + Namespace: "test-ns", + }, + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: strPtr("shared-service"), + }, + }, + }, + }, + }, + }, + }, + } + + ga2 := &agaapi.GlobalAccelerator{ + ObjectMeta: metav1.ObjectMeta{ + Name: "ga2", + Namespace: "test-ns", + }, + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: strPtr("shared-service"), + }, + }, + }, + }, + }, + }, + }, + } + + // Create tracker and add references + tracker := NewReferenceTracker(logr.Discard()) + endpoints1 := GetAllEndpointsFromGA(ga1) + endpoints2 := GetAllEndpointsFromGA(ga2) + tracker.UpdateReferencesForGA(ga1, endpoints1) + tracker.UpdateReferencesForGA(ga2, endpoints2) + + // Resource key for the shared service + sharedServiceKey := ResourceKey{ + Type: ServiceResourceType, + Name: types.NamespacedName{Namespace: "test-ns", Name: "shared-service"}, + } + + // Resource key for a non-existing service + nonExistingServiceKey := ResourceKey{ + Type: ServiceResourceType, + Name: types.NamespacedName{Namespace: "test-ns", Name: "non-existing-service"}, + } + + // Test GetGAsForResource for shared service + gasForSharedService := tracker.GetGAsForResource(sharedServiceKey) + assert.Equal(t, 2, len(gasForSharedService)) + + // Verify that both GAs are returned + ga1Key := types.NamespacedName{Namespace: "test-ns", Name: "ga1"} + ga2Key := types.NamespacedName{Namespace: "test-ns", Name: "ga2"} + + foundGA1 := false + foundGA2 := false + for _, gaKey := range gasForSharedService { + if gaKey == ga1Key { + foundGA1 = true + } + if gaKey == ga2Key { + foundGA2 = true + } + } + assert.True(t, foundGA1) + assert.True(t, foundGA2) + + // Test GetGAsForResource for non-existing service + gasForNonExistingService := tracker.GetGAsForResource(nonExistingServiceKey) + assert.Equal(t, 0, len(gasForNonExistingService)) +} diff --git a/pkg/aga/resource_clients.go b/pkg/aga/resource_clients.go new file mode 100644 index 000000000..f011ea433 --- /dev/null +++ b/pkg/aga/resource_clients.go @@ -0,0 +1,84 @@ +package aga + +import ( + "context" + + corev1 "k8s.io/api/core/v1" + networkingv1 "k8s.io/api/networking/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/watch" + "k8s.io/client-go/kubernetes" + gwv1 "sigs.k8s.io/gateway-api/apis/v1" + gwclientset "sigs.k8s.io/gateway-api/pkg/client/clientset/versioned" +) + +// ServiceClient adapts Kubernetes Service client to ResourceClient +type ServiceClient struct { + client kubernetes.Interface + namespace string +} + +func NewServiceClient(client kubernetes.Interface, namespace string) *ServiceClient { + return &ServiceClient{ + client: client, + namespace: namespace, + } +} + +func (c *ServiceClient) List(ctx context.Context, opts metav1.ListOptions) (runtime.Object, error) { + return c.client.CoreV1().Services(c.namespace).List(ctx, opts) +} + +func (c *ServiceClient) Watch(ctx context.Context, opts metav1.ListOptions) (watch.Interface, error) { + return c.client.CoreV1().Services(c.namespace).Watch(ctx, opts) +} + +// IngressClient adapts Kubernetes Ingress client to ResourceClient +type IngressClient struct { + client kubernetes.Interface + namespace string +} + +func NewIngressClient(client kubernetes.Interface, namespace string) *IngressClient { + return &IngressClient{ + client: client, + namespace: namespace, + } +} + +func (c *IngressClient) List(ctx context.Context, opts metav1.ListOptions) (runtime.Object, error) { + return c.client.NetworkingV1().Ingresses(c.namespace).List(ctx, opts) +} + +func (c *IngressClient) Watch(ctx context.Context, opts metav1.ListOptions) (watch.Interface, error) { + return c.client.NetworkingV1().Ingresses(c.namespace).Watch(ctx, opts) +} + +// GatewayClient adapts Gateway API client to ResourceClient +type GatewayClient struct { + client gwclientset.Interface + namespace string +} + +func NewGatewayClient(client gwclientset.Interface, namespace string) *GatewayClient { + return &GatewayClient{ + client: client, + namespace: namespace, + } +} + +func (c *GatewayClient) List(ctx context.Context, opts metav1.ListOptions) (runtime.Object, error) { + return c.client.GatewayV1().Gateways(c.namespace).List(ctx, opts) +} + +func (c *GatewayClient) Watch(ctx context.Context, opts metav1.ListOptions) (watch.Interface, error) { + return c.client.GatewayV1().Gateways(c.namespace).Watch(ctx, opts) +} + +// Create example objects for type info +var ( + ExampleService = &corev1.Service{} + ExampleIngress = &networkingv1.Ingress{} + ExampleGateway = &gwv1.Gateway{} +) diff --git a/pkg/aga/resource_store.go b/pkg/aga/resource_store.go new file mode 100644 index 000000000..2908f5c61 --- /dev/null +++ b/pkg/aga/resource_store.go @@ -0,0 +1,92 @@ +package aga + +import ( + "github.com/go-logr/logr" + "k8s.io/client-go/tools/cache" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/event" +) + +// ResourceStore is a generic implementation of cache.Store for Kubernetes resources +type ResourceStore[T client.Object] struct { + store cache.Store + eventChan chan<- event.GenericEvent + logger logr.Logger +} + +// NewResourceStore creates a new ResourceStore for a specific resource type +func NewResourceStore[T client.Object](eventChan chan<- event.GenericEvent, keyFunc cache.KeyFunc, logger logr.Logger) *ResourceStore[T] { + return &ResourceStore[T]{ + store: cache.NewStore(keyFunc), + eventChan: eventChan, + logger: logger, + } +} + +var _ cache.Store = &ResourceStore[client.Object]{} + +// Add adds the given object to the store +func (s *ResourceStore[T]) Add(obj interface{}) error { + if err := s.store.Add(obj); err != nil { + return err + } + s.logger.V(1).Info("Resource created or updated", "resource", obj) + s.eventChan <- event.GenericEvent{ + Object: obj.(T), + } + return nil +} + +// Update updates the given object in the store +func (s *ResourceStore[T]) Update(obj interface{}) error { + if err := s.store.Update(obj); err != nil { + return err + } + s.logger.V(1).Info("Resource updated", "resource", obj) + s.eventChan <- event.GenericEvent{ + Object: obj.(T), + } + return nil +} + +// Delete deletes the given object from the store +func (s *ResourceStore[T]) Delete(obj interface{}) error { + if err := s.store.Delete(obj); err != nil { + return err + } + s.logger.V(1).Info("Resource deleted", "resource", obj) + s.eventChan <- event.GenericEvent{ + Object: obj.(T), + } + return nil +} + +// Replace will delete the contents of the store, using instead the given list +func (s *ResourceStore[T]) Replace(list []interface{}, resourceVersion string) error { + return s.store.Replace(list, resourceVersion) +} + +// Resync is meaningless in the terms appearing here +func (s *ResourceStore[T]) Resync() error { + return s.store.Resync() +} + +// List returns a list of all the currently non-empty accumulators +func (s *ResourceStore[T]) List() []interface{} { + return s.store.List() +} + +// ListKeys returns a list of all the keys currently associated with non-empty accumulators +func (s *ResourceStore[T]) ListKeys() []string { + return s.store.ListKeys() +} + +// Get returns the accumulator associated with the given object's key +func (s *ResourceStore[T]) Get(obj interface{}) (item interface{}, exists bool, err error) { + return s.store.Get(obj) +} + +// GetByKey returns the accumulator associated with the given key +func (s *ResourceStore[T]) GetByKey(key string) (item interface{}, exists bool, err error) { + return s.store.GetByKey(key) +} diff --git a/pkg/aga/resource_watcher.go b/pkg/aga/resource_watcher.go new file mode 100644 index 000000000..aacb99b31 --- /dev/null +++ b/pkg/aga/resource_watcher.go @@ -0,0 +1,96 @@ +package aga + +import ( + "context" + "fmt" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/fields" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/apimachinery/pkg/watch" + "k8s.io/client-go/tools/cache" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +// ResourceWatcher is a generic implementation for watching Kubernetes resources +type ResourceWatcher struct { + store cache.Store + reflector *cache.Reflector + consumers sets.String // Set of GA names that reference this resource + stopCh chan struct{} +} + +// ResourceClient is an interface for common operations on a resource +type ResourceClient interface { + List(ctx context.Context, opts metav1.ListOptions) (runtime.Object, error) + Watch(ctx context.Context, opts metav1.ListOptions) (watch.Interface, error) +} + +// NewResourceWatcher creates a new ResourceWatcher for a specific resource +func NewResourceWatcher( + namespace, name string, + resourceClient ResourceClient, + store cache.Store, + exampleObject client.Object, +) *ResourceWatcher { + fieldSelector := fields.Set{"metadata.name": name}.AsSelector().String() + + listFunc := func(options metav1.ListOptions) (runtime.Object, error) { + options.FieldSelector = fieldSelector + return resourceClient.List(context.Background(), options) + } + + watchFunc := func(options metav1.ListOptions) (watch.Interface, error) { + options.FieldSelector = fieldSelector + return resourceClient.Watch(context.Background(), options) + } + + rt := cache.NewNamedReflector( + fmt.Sprintf("%T-%s/%s", exampleObject, namespace, name), + &cache.ListWatch{ListFunc: listFunc, WatchFunc: watchFunc}, + exampleObject, + store, + 0, + ) + + watcher := &ResourceWatcher{ + store: store, + reflector: rt, + consumers: sets.NewString(), + stopCh: make(chan struct{}), + } + + go watcher.Start() + return watcher +} + +// Start runs the reflector +func (w *ResourceWatcher) Start() { + w.reflector.Run(w.stopCh) +} + +// Stop stops the reflector +func (w *ResourceWatcher) Stop() { + close(w.stopCh) +} + +// AddConsumer adds a consumer (GlobalAccelerator) to the watcher +func (w *ResourceWatcher) AddConsumer(consumerID string) { + w.consumers.Insert(consumerID) +} + +// RemoveConsumer removes a consumer from the watcher +func (w *ResourceWatcher) RemoveConsumer(consumerID string) { + w.consumers.Delete(consumerID) +} + +// HasConsumers checks if the watcher has any consumers +func (w *ResourceWatcher) HasConsumers() bool { + return w.consumers.Len() > 0 +} + +// HasConsumer checks if the watcher has a specific consumer +func (w *ResourceWatcher) HasConsumer(consumerID string) bool { + return w.consumers.Has(consumerID) +} diff --git a/pkg/aga/utils.go b/pkg/aga/utils.go index 1f067e25e..f4e96ee37 100644 --- a/pkg/aga/utils.go +++ b/pkg/aga/utils.go @@ -4,8 +4,19 @@ import ( "strings" "sigs.k8s.io/aws-load-balancer-controller/pkg/config" + agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" ) +// IsPortInRanges checks if a port is within any of the specified port ranges +func IsPortInRanges(port int32, portRanges []agamodel.PortRange) bool { + for _, portRange := range portRanges { + if portRange.FromPort <= port && port <= portRange.ToPort { + return true + } + } + return false +} + // IsAGAControllerEnabled checks if the AGA controller is both enabled via feature gate // and if the region is in a partition that supports Global Accelerator func IsAGAControllerEnabled(featureGates config.FeatureGates, region string) bool { diff --git a/pkg/aga/utils_test.go b/pkg/aga/utils_test.go index 6bfa40c5a..e029259c4 100644 --- a/pkg/aga/utils_test.go +++ b/pkg/aga/utils_test.go @@ -3,7 +3,9 @@ package aga import ( "testing" + "github.com/stretchr/testify/assert" "sigs.k8s.io/aws-load-balancer-controller/pkg/config" + agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" ) func TestIsAGAControllerEnabled(t *testing.T) { @@ -93,3 +95,125 @@ func TestIsAGAControllerEnabled(t *testing.T) { }) } } + +func TestIsPortInRanges(t *testing.T) { + tests := []struct { + name string + port int32 + portRanges []agamodel.PortRange + expected bool + }{ + { + name: "port within single range", + port: 85, + portRanges: []agamodel.PortRange{ + { + FromPort: 80, + ToPort: 100, + }, + }, + expected: true, + }, + { + name: "port at lower boundary", + port: 80, + portRanges: []agamodel.PortRange{ + { + FromPort: 80, + ToPort: 100, + }, + }, + expected: true, + }, + { + name: "port at upper boundary", + port: 100, + portRanges: []agamodel.PortRange{ + { + FromPort: 80, + ToPort: 100, + }, + }, + expected: true, + }, + { + name: "port below range", + port: 79, + portRanges: []agamodel.PortRange{ + { + FromPort: 80, + ToPort: 100, + }, + }, + expected: false, + }, + { + name: "port above range", + port: 101, + portRanges: []agamodel.PortRange{ + { + FromPort: 80, + ToPort: 100, + }, + }, + expected: false, + }, + { + name: "port within one of multiple ranges", + port: 443, + portRanges: []agamodel.PortRange{ + { + FromPort: 80, + ToPort: 100, + }, + { + FromPort: 443, + ToPort: 443, + }, + { + FromPort: 8080, + ToPort: 8090, + }, + }, + expected: true, + }, + { + name: "port not within any of multiple ranges", + port: 8000, + portRanges: []agamodel.PortRange{ + { + FromPort: 80, + ToPort: 100, + }, + { + FromPort: 443, + ToPort: 443, + }, + { + FromPort: 8080, + ToPort: 8090, + }, + }, + expected: false, + }, + { + name: "empty port ranges", + port: 80, + portRanges: []agamodel.PortRange{}, + expected: false, + }, + { + name: "nil port ranges", + port: 80, + portRanges: nil, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsPortInRanges(tt.port, tt.portRanges) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/pkg/aws/services/globalaccelerator.go b/pkg/aws/services/globalaccelerator.go index 364f12d70..30e627307 100644 --- a/pkg/aws/services/globalaccelerator.go +++ b/pkg/aws/services/globalaccelerator.go @@ -41,6 +41,21 @@ type GlobalAccelerator interface { // ListListenersForAccelerator lists all listeners for an accelerator. ListListenersForAcceleratorWithContext(ctx context.Context, input *globalaccelerator.ListListenersInput) (*globalaccelerator.ListListenersOutput, error) + // CreateEndpointGroup creates a new endpoint group. + CreateEndpointGroupWithContext(ctx context.Context, input *globalaccelerator.CreateEndpointGroupInput) (*globalaccelerator.CreateEndpointGroupOutput, error) + + // DescribeEndpointGroup describes an endpoint group. + DescribeEndpointGroupWithContext(ctx context.Context, input *globalaccelerator.DescribeEndpointGroupInput) (*globalaccelerator.DescribeEndpointGroupOutput, error) + + // UpdateEndpointGroup updates an endpoint group. + UpdateEndpointGroupWithContext(ctx context.Context, input *globalaccelerator.UpdateEndpointGroupInput) (*globalaccelerator.UpdateEndpointGroupOutput, error) + + // DeleteEndpointGroup deletes an endpoint group. + DeleteEndpointGroupWithContext(ctx context.Context, input *globalaccelerator.DeleteEndpointGroupInput) (*globalaccelerator.DeleteEndpointGroupOutput, error) + + // wrapper to ListEndpointGroups API, which aggregates paged results into list. + ListEndpointGroupsAsList(ctx context.Context, input *globalaccelerator.ListEndpointGroupsInput) ([]types.EndpointGroup, error) + // TagResource tags a resource. TagResourceWithContext(ctx context.Context, input *globalaccelerator.TagResourceInput) (*globalaccelerator.TagResourceOutput, error) @@ -192,3 +207,52 @@ func (c *defaultGlobalAccelerator) ListListenersAsList(ctx context.Context, inpu } return result, nil } + +func (c *defaultGlobalAccelerator) CreateEndpointGroupWithContext(ctx context.Context, input *globalaccelerator.CreateEndpointGroupInput) (*globalaccelerator.CreateEndpointGroupOutput, error) { + client, err := c.awsClientsProvider.GetGlobalAcceleratorClient(ctx, "CreateEndpointGroup") + if err != nil { + return nil, err + } + return client.CreateEndpointGroup(ctx, input) +} + +func (c *defaultGlobalAccelerator) DescribeEndpointGroupWithContext(ctx context.Context, input *globalaccelerator.DescribeEndpointGroupInput) (*globalaccelerator.DescribeEndpointGroupOutput, error) { + client, err := c.awsClientsProvider.GetGlobalAcceleratorClient(ctx, "DescribeEndpointGroup") + if err != nil { + return nil, err + } + return client.DescribeEndpointGroup(ctx, input) +} + +func (c *defaultGlobalAccelerator) UpdateEndpointGroupWithContext(ctx context.Context, input *globalaccelerator.UpdateEndpointGroupInput) (*globalaccelerator.UpdateEndpointGroupOutput, error) { + client, err := c.awsClientsProvider.GetGlobalAcceleratorClient(ctx, "UpdateEndpointGroup") + if err != nil { + return nil, err + } + return client.UpdateEndpointGroup(ctx, input) +} + +func (c *defaultGlobalAccelerator) DeleteEndpointGroupWithContext(ctx context.Context, input *globalaccelerator.DeleteEndpointGroupInput) (*globalaccelerator.DeleteEndpointGroupOutput, error) { + client, err := c.awsClientsProvider.GetGlobalAcceleratorClient(ctx, "DeleteEndpointGroup") + if err != nil { + return nil, err + } + return client.DeleteEndpointGroup(ctx, input) +} + +func (c *defaultGlobalAccelerator) ListEndpointGroupsAsList(ctx context.Context, input *globalaccelerator.ListEndpointGroupsInput) ([]types.EndpointGroup, error) { + var result []types.EndpointGroup + client, err := c.awsClientsProvider.GetGlobalAcceleratorClient(ctx, "ListEndpointGroups") + if err != nil { + return nil, err + } + paginator := globalaccelerator.NewListEndpointGroupsPaginator(client, input) + for paginator.HasMorePages() { + output, err := paginator.NextPage(ctx) + if err != nil { + return nil, err + } + result = append(result, output.EndpointGroups...) + } + return result, nil +} diff --git a/pkg/aws/services/globalaccelerator_mocks.go b/pkg/aws/services/globalaccelerator_mocks.go index e4989fa97..0bad79b37 100644 --- a/pkg/aws/services/globalaccelerator_mocks.go +++ b/pkg/aws/services/globalaccelerator_mocks.go @@ -51,6 +51,21 @@ func (mr *MockGlobalAcceleratorMockRecorder) CreateAcceleratorWithContext(arg0, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAcceleratorWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).CreateAcceleratorWithContext), arg0, arg1) } +// CreateEndpointGroupWithContext mocks base method. +func (m *MockGlobalAccelerator) CreateEndpointGroupWithContext(arg0 context.Context, arg1 *globalaccelerator.CreateEndpointGroupInput) (*globalaccelerator.CreateEndpointGroupOutput, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateEndpointGroupWithContext", arg0, arg1) + ret0, _ := ret[0].(*globalaccelerator.CreateEndpointGroupOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateEndpointGroupWithContext indicates an expected call of CreateEndpointGroupWithContext. +func (mr *MockGlobalAcceleratorMockRecorder) CreateEndpointGroupWithContext(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateEndpointGroupWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).CreateEndpointGroupWithContext), arg0, arg1) +} + // CreateListenerWithContext mocks base method. func (m *MockGlobalAccelerator) CreateListenerWithContext(arg0 context.Context, arg1 *globalaccelerator.CreateListenerInput) (*globalaccelerator.CreateListenerOutput, error) { m.ctrl.T.Helper() @@ -81,6 +96,21 @@ func (mr *MockGlobalAcceleratorMockRecorder) DeleteAcceleratorWithContext(arg0, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAcceleratorWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).DeleteAcceleratorWithContext), arg0, arg1) } +// DeleteEndpointGroupWithContext mocks base method. +func (m *MockGlobalAccelerator) DeleteEndpointGroupWithContext(arg0 context.Context, arg1 *globalaccelerator.DeleteEndpointGroupInput) (*globalaccelerator.DeleteEndpointGroupOutput, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteEndpointGroupWithContext", arg0, arg1) + ret0, _ := ret[0].(*globalaccelerator.DeleteEndpointGroupOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DeleteEndpointGroupWithContext indicates an expected call of DeleteEndpointGroupWithContext. +func (mr *MockGlobalAcceleratorMockRecorder) DeleteEndpointGroupWithContext(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteEndpointGroupWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).DeleteEndpointGroupWithContext), arg0, arg1) +} + // DeleteListenerWithContext mocks base method. func (m *MockGlobalAccelerator) DeleteListenerWithContext(arg0 context.Context, arg1 *globalaccelerator.DeleteListenerInput) (*globalaccelerator.DeleteListenerOutput, error) { m.ctrl.T.Helper() @@ -111,6 +141,21 @@ func (mr *MockGlobalAcceleratorMockRecorder) DescribeAcceleratorWithContext(arg0 return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeAcceleratorWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).DescribeAcceleratorWithContext), arg0, arg1) } +// DescribeEndpointGroupWithContext mocks base method. +func (m *MockGlobalAccelerator) DescribeEndpointGroupWithContext(arg0 context.Context, arg1 *globalaccelerator.DescribeEndpointGroupInput) (*globalaccelerator.DescribeEndpointGroupOutput, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DescribeEndpointGroupWithContext", arg0, arg1) + ret0, _ := ret[0].(*globalaccelerator.DescribeEndpointGroupOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DescribeEndpointGroupWithContext indicates an expected call of DescribeEndpointGroupWithContext. +func (mr *MockGlobalAcceleratorMockRecorder) DescribeEndpointGroupWithContext(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeEndpointGroupWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).DescribeEndpointGroupWithContext), arg0, arg1) +} + // DescribeListenerWithContext mocks base method. func (m *MockGlobalAccelerator) DescribeListenerWithContext(arg0 context.Context, arg1 *globalaccelerator.DescribeListenerInput) (*globalaccelerator.DescribeListenerOutput, error) { m.ctrl.T.Helper() @@ -141,6 +186,21 @@ func (mr *MockGlobalAcceleratorMockRecorder) ListAcceleratorsAsList(arg0, arg1 i return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAcceleratorsAsList", reflect.TypeOf((*MockGlobalAccelerator)(nil).ListAcceleratorsAsList), arg0, arg1) } +// ListEndpointGroupsAsList mocks base method. +func (m *MockGlobalAccelerator) ListEndpointGroupsAsList(arg0 context.Context, arg1 *globalaccelerator.ListEndpointGroupsInput) ([]types.EndpointGroup, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListEndpointGroupsAsList", arg0, arg1) + ret0, _ := ret[0].([]types.EndpointGroup) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListEndpointGroupsAsList indicates an expected call of ListEndpointGroupsAsList. +func (mr *MockGlobalAcceleratorMockRecorder) ListEndpointGroupsAsList(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListEndpointGroupsAsList", reflect.TypeOf((*MockGlobalAccelerator)(nil).ListEndpointGroupsAsList), arg0, arg1) +} + // ListListenersAsList mocks base method. func (m *MockGlobalAccelerator) ListListenersAsList(arg0 context.Context, arg1 *globalaccelerator.ListListenersInput) ([]types.Listener, error) { m.ctrl.T.Helper() @@ -231,6 +291,21 @@ func (mr *MockGlobalAcceleratorMockRecorder) UpdateAcceleratorWithContext(arg0, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAcceleratorWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).UpdateAcceleratorWithContext), arg0, arg1) } +// UpdateEndpointGroupWithContext mocks base method. +func (m *MockGlobalAccelerator) UpdateEndpointGroupWithContext(arg0 context.Context, arg1 *globalaccelerator.UpdateEndpointGroupInput) (*globalaccelerator.UpdateEndpointGroupOutput, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateEndpointGroupWithContext", arg0, arg1) + ret0, _ := ret[0].(*globalaccelerator.UpdateEndpointGroupOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateEndpointGroupWithContext indicates an expected call of UpdateEndpointGroupWithContext. +func (mr *MockGlobalAcceleratorMockRecorder) UpdateEndpointGroupWithContext(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEndpointGroupWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).UpdateEndpointGroupWithContext), arg0, arg1) +} + // UpdateListenerWithContext mocks base method. func (m *MockGlobalAccelerator) UpdateListenerWithContext(arg0 context.Context, arg1 *globalaccelerator.UpdateListenerInput) (*globalaccelerator.UpdateListenerOutput, error) { m.ctrl.T.Helper() diff --git a/pkg/deploy/aga/endpoint_group_manager.go b/pkg/deploy/aga/endpoint_group_manager.go new file mode 100644 index 000000000..263056e48 --- /dev/null +++ b/pkg/deploy/aga/endpoint_group_manager.go @@ -0,0 +1,240 @@ +package aga + +import ( + "context" + "errors" + "fmt" + awssdk "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/globalaccelerator" + agatypes "github.com/aws/aws-sdk-go-v2/service/globalaccelerator/types" + "github.com/go-logr/logr" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" + agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" +) + +// EndpointGroupManager is responsible for managing AWS Global Accelerator endpoint groups. +type EndpointGroupManager interface { + // Create creates an endpoint group. + Create(ctx context.Context, resEndpointGroup *agamodel.EndpointGroup) (agamodel.EndpointGroupStatus, error) + + // Update updates an endpoint group. + Update(ctx context.Context, resEndpointGroup *agamodel.EndpointGroup, sdkEndpointGroup *agatypes.EndpointGroup) (agamodel.EndpointGroupStatus, error) + + // Delete deletes an endpoint group. + Delete(ctx context.Context, endpointGroupARN string) error +} + +// NewDefaultEndpointGroupManager constructs new defaultEndpointGroupManager. +func NewDefaultEndpointGroupManager(gaService services.GlobalAccelerator, logger logr.Logger) *defaultEndpointGroupManager { + return &defaultEndpointGroupManager{ + gaService: gaService, + logger: logger, + } +} + +var _ EndpointGroupManager = &defaultEndpointGroupManager{} + +// defaultEndpointGroupManager is the default implementation for EndpointGroupManager. +type defaultEndpointGroupManager struct { + gaService services.GlobalAccelerator + logger logr.Logger +} + +// buildSDKPortOverrides converts model port overrides to SDK port overrides +func (m *defaultEndpointGroupManager) buildSDKPortOverrides(modelPortOverrides []agamodel.PortOverride) []agatypes.PortOverride { + if len(modelPortOverrides) == 0 { + return nil + } + + portOverrides := make([]agatypes.PortOverride, 0, len(modelPortOverrides)) + for _, po := range modelPortOverrides { + portOverrides = append(portOverrides, agatypes.PortOverride{ + ListenerPort: awssdk.Int32(po.ListenerPort), + EndpointPort: awssdk.Int32(po.EndpointPort), + }) + } + return portOverrides +} + +func (m *defaultEndpointGroupManager) buildSDKCreateEndpointGroupInput(_ context.Context, resEndpointGroup *agamodel.EndpointGroup) (*globalaccelerator.CreateEndpointGroupInput, error) { + // Resolve listener ARN + listenerARN, err := resEndpointGroup.Spec.ListenerARN.Resolve(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to resolve listener ARN: %w", err) + } + + // Build create input + createInput := &globalaccelerator.CreateEndpointGroupInput{ + ListenerArn: awssdk.String(listenerARN), + EndpointGroupRegion: awssdk.String(resEndpointGroup.Spec.Region), + } + + // Convert TrafficDialPercentage from int32 to float32 if provided + if resEndpointGroup.Spec.TrafficDialPercentage != nil { + createInput.TrafficDialPercentage = awssdk.Float32(float32(*resEndpointGroup.Spec.TrafficDialPercentage)) + } + + // Add port overrides if specified + if len(resEndpointGroup.Spec.PortOverrides) > 0 { + createInput.PortOverrides = m.buildSDKPortOverrides(resEndpointGroup.Spec.PortOverrides) + } + + return createInput, nil +} + +func (m *defaultEndpointGroupManager) Create(ctx context.Context, resEndpointGroup *agamodel.EndpointGroup) (agamodel.EndpointGroupStatus, error) { + // Build create input + createInput, err := m.buildSDKCreateEndpointGroupInput(ctx, resEndpointGroup) + if err != nil { + return agamodel.EndpointGroupStatus{}, err + } + + // Create endpoint group + m.logger.V(1).Info("Creating endpoint group", + "stackID", resEndpointGroup.Stack().StackID(), + "resourceID", resEndpointGroup.ID()) + + createOutput, err := m.gaService.CreateEndpointGroupWithContext(ctx, createInput) + if err != nil { + return agamodel.EndpointGroupStatus{}, fmt.Errorf("failed to create endpoint group: %w", err) + } + + endpointGroup := createOutput.EndpointGroup + m.logger.Info("Successfully created endpoint group", + "stackID", resEndpointGroup.Stack().StackID(), + "resourceID", resEndpointGroup.ID(), + "endpointGroupARN", *endpointGroup.EndpointGroupArn) + + return agamodel.EndpointGroupStatus{ + EndpointGroupARN: *endpointGroup.EndpointGroupArn, + }, nil +} + +func (m *defaultEndpointGroupManager) buildSDKUpdateEndpointGroupInput(_ context.Context, resEndpointGroup *agamodel.EndpointGroup, sdkEndpointGroup *agatypes.EndpointGroup) (*globalaccelerator.UpdateEndpointGroupInput, error) { + // Build update input + updateInput := &globalaccelerator.UpdateEndpointGroupInput{ + EndpointGroupArn: sdkEndpointGroup.EndpointGroupArn, + } + + // Convert TrafficDialPercentage from int32 to float32 if provided + if resEndpointGroup.Spec.TrafficDialPercentage != nil { + updateInput.TrafficDialPercentage = awssdk.Float32(float32(*resEndpointGroup.Spec.TrafficDialPercentage)) + } + + // Add port overrides if specified + if len(resEndpointGroup.Spec.PortOverrides) > 0 { + updateInput.PortOverrides = m.buildSDKPortOverrides(resEndpointGroup.Spec.PortOverrides) + } + + return updateInput, nil +} + +func (m *defaultEndpointGroupManager) Update(ctx context.Context, resEndpointGroup *agamodel.EndpointGroup, sdkEndpointGroup *agatypes.EndpointGroup) (agamodel.EndpointGroupStatus, error) { + // Check if the endpoint group actually needs an update + if !m.isSDKEndpointGroupSettingsDrifted(resEndpointGroup, sdkEndpointGroup) { + m.logger.Info("No drift detected in endpoint group settings, skipping update", + "stackID", resEndpointGroup.Stack().StackID(), + "resourceID", resEndpointGroup.ID(), + "endpointGroupARN", *sdkEndpointGroup.EndpointGroupArn) + + return agamodel.EndpointGroupStatus{ + EndpointGroupARN: *sdkEndpointGroup.EndpointGroupArn, + }, nil + } + + m.logger.Info("Drift detected in endpoint group settings, updating", + "stackID", resEndpointGroup.Stack().StackID(), + "resourceID", resEndpointGroup.ID(), + "endpointGroupARN", *sdkEndpointGroup.EndpointGroupArn) + + // Build update input + updateInput, err := m.buildSDKUpdateEndpointGroupInput(ctx, resEndpointGroup, sdkEndpointGroup) + if err != nil { + return agamodel.EndpointGroupStatus{}, err + } + + // Update endpoint group + updateOutput, err := m.gaService.UpdateEndpointGroupWithContext(ctx, updateInput) + if err != nil { + return agamodel.EndpointGroupStatus{}, fmt.Errorf("failed to update endpoint group: %w", err) + } + + updatedEndpointGroup := updateOutput.EndpointGroup + m.logger.Info("Successfully updated endpoint group", + "stackID", resEndpointGroup.Stack().StackID(), + "resourceID", resEndpointGroup.ID(), + "endpointGroupARN", *updatedEndpointGroup.EndpointGroupArn) + + return agamodel.EndpointGroupStatus{ + EndpointGroupARN: *updatedEndpointGroup.EndpointGroupArn, + }, nil +} + +func (m *defaultEndpointGroupManager) Delete(ctx context.Context, endpointGroupARN string) error { + m.logger.Info("Deleting endpoint group", "endpointGroupARN", endpointGroupARN) + + deleteInput := &globalaccelerator.DeleteEndpointGroupInput{ + EndpointGroupArn: awssdk.String(endpointGroupARN), + } + + if _, err := m.gaService.DeleteEndpointGroupWithContext(ctx, deleteInput); err != nil { + // Check if it's a not found error - the endpoint group might have been already deleted + var apiErr *agatypes.EndpointGroupNotFoundException + if errors.As(err, &apiErr) { + m.logger.Info("Endpoint group already deleted", "endpointGroupARN", endpointGroupARN) + return nil + } + return fmt.Errorf("failed to delete endpoint group: %w", err) + } + + m.logger.Info("Successfully deleted endpoint group", "endpointGroupARN", endpointGroupARN) + return nil +} + +// isSDKEndpointGroupSettingsDrifted checks if the endpoint group configuration has drifted from the desired state +func (m *defaultEndpointGroupManager) isSDKEndpointGroupSettingsDrifted(resEndpointGroup *agamodel.EndpointGroup, sdkEndpointGroup *agatypes.EndpointGroup) bool { + // Cannot change region after creation, so we don't check for region drift + + // Check traffic dial percentage + if resEndpointGroup.Spec.TrafficDialPercentage != nil { + resTrafficDialPercentage := float32(*resEndpointGroup.Spec.TrafficDialPercentage) + sdkTrafficDialPercentage := awssdk.ToFloat32(sdkEndpointGroup.TrafficDialPercentage) + // Use a small epsilon for float comparison to avoid precision issues + const epsilon = 0.001 + if resTrafficDialPercentage < sdkTrafficDialPercentage-epsilon || resTrafficDialPercentage > sdkTrafficDialPercentage+epsilon { + return true + } + } else if sdkEndpointGroup.TrafficDialPercentage != nil { + // Resource has no traffic dial percentage but SDK does + return true + } + + // Check port overrides + if !m.arePortOverridesEqual(resEndpointGroup.Spec.PortOverrides, sdkEndpointGroup.PortOverrides) { + return true + } + + return false +} + +// arePortOverridesEqual compares port overrides from the resource model and SDK +func (m *defaultEndpointGroupManager) arePortOverridesEqual(modelPortOverrides []agamodel.PortOverride, sdkPortOverrides []agatypes.PortOverride) bool { + if len(modelPortOverrides) != len(sdkPortOverrides) { + return false + } + + // Convert to maps for easier comparison + modelMap := make(map[int32]int32) + for _, po := range modelPortOverrides { + modelMap[po.ListenerPort] = po.EndpointPort + } + + // Check if all SDK port overrides match the model + for _, po := range sdkPortOverrides { + if modelEndpointPort, exists := modelMap[awssdk.ToInt32(po.ListenerPort)]; !exists || modelEndpointPort != awssdk.ToInt32(po.EndpointPort) { + return false + } + } + + return true +} diff --git a/pkg/deploy/aga/endpoint_group_manager_mocks.go b/pkg/deploy/aga/endpoint_group_manager_mocks.go new file mode 100644 index 000000000..d3109662c --- /dev/null +++ b/pkg/deploy/aga/endpoint_group_manager_mocks.go @@ -0,0 +1,81 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: sigs.k8s.io/aws-load-balancer-controller/pkg/deploy/aga (interfaces: EndpointGroupManager) + +// Package aga is a generated GoMock package. +package aga + +import ( + context "context" + reflect "reflect" + + types "github.com/aws/aws-sdk-go-v2/service/globalaccelerator/types" + gomock "github.com/golang/mock/gomock" + aga "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" +) + +// MockEndpointGroupManager is a mock of EndpointGroupManager interface. +type MockEndpointGroupManager struct { + ctrl *gomock.Controller + recorder *MockEndpointGroupManagerMockRecorder +} + +// MockEndpointGroupManagerMockRecorder is the mock recorder for MockEndpointGroupManager. +type MockEndpointGroupManagerMockRecorder struct { + mock *MockEndpointGroupManager +} + +// NewMockEndpointGroupManager creates a new mock instance. +func NewMockEndpointGroupManager(ctrl *gomock.Controller) *MockEndpointGroupManager { + mock := &MockEndpointGroupManager{ctrl: ctrl} + mock.recorder = &MockEndpointGroupManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockEndpointGroupManager) EXPECT() *MockEndpointGroupManagerMockRecorder { + return m.recorder +} + +// Create mocks base method. +func (m *MockEndpointGroupManager) Create(arg0 context.Context, arg1 *aga.EndpointGroup) (aga.EndpointGroupStatus, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Create", arg0, arg1) + ret0, _ := ret[0].(aga.EndpointGroupStatus) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Create indicates an expected call of Create. +func (mr *MockEndpointGroupManagerMockRecorder) Create(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockEndpointGroupManager)(nil).Create), arg0, arg1) +} + +// Delete mocks base method. +func (m *MockEndpointGroupManager) Delete(arg0 context.Context, arg1 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Delete", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Delete indicates an expected call of Delete. +func (mr *MockEndpointGroupManagerMockRecorder) Delete(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockEndpointGroupManager)(nil).Delete), arg0, arg1) +} + +// Update mocks base method. +func (m *MockEndpointGroupManager) Update(arg0 context.Context, arg1 *aga.EndpointGroup, arg2 *types.EndpointGroup) (aga.EndpointGroupStatus, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Update", arg0, arg1, arg2) + ret0, _ := ret[0].(aga.EndpointGroupStatus) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Update indicates an expected call of Update. +func (mr *MockEndpointGroupManagerMockRecorder) Update(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockEndpointGroupManager)(nil).Update), arg0, arg1, arg2) +} diff --git a/pkg/deploy/aga/endpoint_group_manager_test.go b/pkg/deploy/aga/endpoint_group_manager_test.go new file mode 100644 index 000000000..53959afa8 --- /dev/null +++ b/pkg/deploy/aga/endpoint_group_manager_test.go @@ -0,0 +1,480 @@ +package aga + +import ( + "context" + "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + awssdk "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/globalaccelerator" + agatypes "github.com/aws/aws-sdk-go-v2/service/globalaccelerator/types" + "github.com/go-logr/logr" + "github.com/stretchr/testify/assert" + agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" +) + +func Test_defaultEndpointGroupManager_buildSDKPortOverrides(t *testing.T) { + tests := []struct { + name string + modelPortOverrides []agamodel.PortOverride + want []agatypes.PortOverride + }{ + { + name: "nil model port overrides", + modelPortOverrides: nil, + want: nil, + }, + { + name: "empty model port overrides", + modelPortOverrides: []agamodel.PortOverride{}, + want: nil, + }, + { + name: "single port override", + modelPortOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + }, + want: []agatypes.PortOverride{ + { + ListenerPort: aws.Int32(80), + EndpointPort: aws.Int32(8080), + }, + }, + }, + { + name: "multiple port overrides", + modelPortOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + { + ListenerPort: 443, + EndpointPort: 8443, + }, + }, + want: []agatypes.PortOverride{ + { + ListenerPort: aws.Int32(80), + EndpointPort: aws.Int32(8080), + }, + { + ListenerPort: aws.Int32(443), + EndpointPort: aws.Int32(8443), + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logger := logr.Discard() + m := &defaultEndpointGroupManager{ + logger: logger, + } + got := m.buildSDKPortOverrides(tt.modelPortOverrides) + + // Compare nil vs nil + if tt.want == nil && got == nil { + // Both are nil, this is correct + return + } + + // Compare lengths + assert.Equal(t, len(tt.want), len(got)) + + // Compare individual elements + for i, wantPO := range tt.want { + assert.Equal(t, awssdk.ToInt32(wantPO.ListenerPort), awssdk.ToInt32(got[i].ListenerPort)) + assert.Equal(t, awssdk.ToInt32(wantPO.EndpointPort), awssdk.ToInt32(got[i].EndpointPort)) + } + }) + } +} + +func Test_defaultEndpointGroupManager_arePortOverridesEqual(t *testing.T) { + tests := []struct { + name string + modelPortOverrides []agamodel.PortOverride + sdkPortOverrides []agatypes.PortOverride + want bool + }{ + { + name: "both nil - equal", + modelPortOverrides: nil, + sdkPortOverrides: nil, + want: true, + }, + { + name: "one empty one nil - equal", + modelPortOverrides: nil, + sdkPortOverrides: []agatypes.PortOverride{}, + want: true, + }, + { + name: "both empty - equal", + modelPortOverrides: []agamodel.PortOverride{}, + sdkPortOverrides: []agatypes.PortOverride{}, + want: true, + }, + { + name: "model empty, sdk not empty - not equal", + modelPortOverrides: []agamodel.PortOverride{}, + sdkPortOverrides: []agatypes.PortOverride{ + { + ListenerPort: aws.Int32(80), + EndpointPort: aws.Int32(8080), + }, + }, + want: false, + }, + { + name: "model not empty, sdk empty - not equal", + modelPortOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + }, + sdkPortOverrides: []agatypes.PortOverride{}, + want: false, + }, + { + name: "different lengths - not equal", + modelPortOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + }, + sdkPortOverrides: []agatypes.PortOverride{ + { + ListenerPort: aws.Int32(80), + EndpointPort: aws.Int32(8080), + }, + { + ListenerPort: aws.Int32(443), + EndpointPort: aws.Int32(8443), + }, + }, + want: false, + }, + { + name: "same length but different values - not equal", + modelPortOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + }, + sdkPortOverrides: []agatypes.PortOverride{ + { + ListenerPort: aws.Int32(80), + EndpointPort: aws.Int32(9090), // Different endpoint port + }, + }, + want: false, + }, + { + name: "same values, same order - equal", + modelPortOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + { + ListenerPort: 443, + EndpointPort: 8443, + }, + }, + sdkPortOverrides: []agatypes.PortOverride{ + { + ListenerPort: aws.Int32(80), + EndpointPort: aws.Int32(8080), + }, + { + ListenerPort: aws.Int32(443), + EndpointPort: aws.Int32(8443), + }, + }, + want: true, + }, + { + name: "same values, different order - equal (order doesn't matter)", + modelPortOverrides: []agamodel.PortOverride{ + { + ListenerPort: 443, + EndpointPort: 8443, + }, + { + ListenerPort: 80, + EndpointPort: 8080, + }, + }, + sdkPortOverrides: []agatypes.PortOverride{ + { + ListenerPort: aws.Int32(80), + EndpointPort: aws.Int32(8080), + }, + { + ListenerPort: aws.Int32(443), + EndpointPort: aws.Int32(8443), + }, + }, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logger := logr.Discard() + m := &defaultEndpointGroupManager{ + logger: logger, + } + got := m.arePortOverridesEqual(tt.modelPortOverrides, tt.sdkPortOverrides) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_defaultEndpointGroupManager_isSDKEndpointGroupSettingsDrifted(t *testing.T) { + tests := []struct { + name string + resEndpointGroup *agamodel.EndpointGroup + sdkEndpointGroup *agatypes.EndpointGroup + want bool + }{ + { + name: "no drift - all values match", + resEndpointGroup: &agamodel.EndpointGroup{ + Spec: agamodel.EndpointGroupSpec{ + Region: "us-west-2", + TrafficDialPercentage: aws.Int32(100), + PortOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + }, + }, + }, + sdkEndpointGroup: &agatypes.EndpointGroup{ + EndpointGroupRegion: aws.String("us-west-2"), + TrafficDialPercentage: aws.Float32(100.0), + PortOverrides: []agatypes.PortOverride{ + { + ListenerPort: aws.Int32(80), + EndpointPort: aws.Int32(8080), + }, + }, + }, + want: false, + }, + { + name: "traffic dial percentage differs", + resEndpointGroup: &agamodel.EndpointGroup{ + Spec: agamodel.EndpointGroupSpec{ + Region: "us-west-2", + TrafficDialPercentage: aws.Int32(50), + PortOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + }, + }, + }, + sdkEndpointGroup: &agatypes.EndpointGroup{ + EndpointGroupRegion: aws.String("us-west-2"), + TrafficDialPercentage: aws.Float32(100.0), + PortOverrides: []agatypes.PortOverride{ + { + ListenerPort: aws.Int32(80), + EndpointPort: aws.Int32(8080), + }, + }, + }, + want: true, + }, + { + name: "traffic dial percentage small difference within epsilon - no drift", + resEndpointGroup: &agamodel.EndpointGroup{ + Spec: agamodel.EndpointGroupSpec{ + Region: "us-west-2", + TrafficDialPercentage: aws.Int32(100), + PortOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + }, + }, + }, + sdkEndpointGroup: &agatypes.EndpointGroup{ + EndpointGroupRegion: aws.String("us-west-2"), + TrafficDialPercentage: aws.Float32(100.0005), // Small difference within epsilon + PortOverrides: []agatypes.PortOverride{ + { + ListenerPort: aws.Int32(80), + EndpointPort: aws.Int32(8080), + }, + }, + }, + want: false, + }, + { + name: "port overrides differ", + resEndpointGroup: &agamodel.EndpointGroup{ + Spec: agamodel.EndpointGroupSpec{ + Region: "us-west-2", + TrafficDialPercentage: aws.Int32(100), + PortOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + }, + }, + }, + sdkEndpointGroup: &agatypes.EndpointGroup{ + EndpointGroupRegion: aws.String("us-west-2"), + TrafficDialPercentage: aws.Float32(100.0), + PortOverrides: []agatypes.PortOverride{ + { + ListenerPort: aws.Int32(80), + EndpointPort: aws.Int32(9090), // Different endpoint port + }, + }, + }, + want: true, + }, + { + name: "model has no traffic dial percentage, sdk does - drift", + resEndpointGroup: &agamodel.EndpointGroup{ + Spec: agamodel.EndpointGroupSpec{ + Region: "us-west-2", + PortOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + }, + }, + }, + sdkEndpointGroup: &agatypes.EndpointGroup{ + EndpointGroupRegion: aws.String("us-west-2"), + TrafficDialPercentage: aws.Float32(100.0), + PortOverrides: []agatypes.PortOverride{ + { + ListenerPort: aws.Int32(80), + EndpointPort: aws.Int32(8080), + }, + }, + }, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logger := logr.Discard() + m := &defaultEndpointGroupManager{ + logger: logger, + } + got := m.isSDKEndpointGroupSettingsDrifted(tt.resEndpointGroup, tt.sdkEndpointGroup) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_defaultEndpointGroupManager_buildSDKCreateEndpointGroupInput(t *testing.T) { + testListenerARN := "arn:aws:globalaccelerator::123456789012:listener/1234abcd-abcd-1234-abcd-1234abcdefgh/abcdefghi" + mockStack := core.NewDefaultStack(core.StackID{Namespace: "test-namespace", Name: "test-name"}) + + tests := []struct { + name string + resEndpointGroup *agamodel.EndpointGroup + want *globalaccelerator.CreateEndpointGroupInput + wantErr bool + }{ + { + name: "Standard endpoint group with all fields", + resEndpointGroup: &agamodel.EndpointGroup{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::EndpointGroup", "endpoint-group-1"), + Spec: agamodel.EndpointGroupSpec{ + ListenerARN: core.LiteralStringToken(testListenerARN), + Region: "us-west-2", + TrafficDialPercentage: aws.Int32(85), + PortOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + { + ListenerPort: 443, + EndpointPort: 8443, + }, + }, + }, + }, + want: &globalaccelerator.CreateEndpointGroupInput{ + ListenerArn: aws.String(testListenerARN), + EndpointGroupRegion: aws.String("us-west-2"), + TrafficDialPercentage: aws.Float32(85.0), + PortOverrides: []agatypes.PortOverride{ + { + ListenerPort: aws.Int32(80), + EndpointPort: aws.Int32(8080), + }, + { + ListenerPort: aws.Int32(443), + EndpointPort: aws.Int32(8443), + }, + }, + }, + wantErr: false, + }, + { + name: "Minimal endpoint group", + resEndpointGroup: &agamodel.EndpointGroup{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::EndpointGroup", "endpoint-group-2"), + Spec: agamodel.EndpointGroupSpec{ + ListenerARN: core.LiteralStringToken(testListenerARN), + Region: "us-west-2", + // No TrafficDialPercentage or PortOverrides + }, + }, + want: &globalaccelerator.CreateEndpointGroupInput{ + ListenerArn: aws.String(testListenerARN), + EndpointGroupRegion: aws.String("us-west-2"), + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create endpoint group manager + m := &defaultEndpointGroupManager{ + gaService: nil, // Not needed for this test + logger: logr.Discard(), + } + + // Call the method being tested + got, err := m.buildSDKCreateEndpointGroupInput(context.Background(), tt.resEndpointGroup) + + // Check if error status matches expected + if (err != nil) != tt.wantErr { + t.Errorf("buildSDKCreateEndpointGroupInput() error = %v, wantErr %v", err, tt.wantErr) + return + } + + // Check if the result matches expected + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/deploy/aga/endpoint_group_synthesizer.go b/pkg/deploy/aga/endpoint_group_synthesizer.go new file mode 100644 index 000000000..b1e4f31c3 --- /dev/null +++ b/pkg/deploy/aga/endpoint_group_synthesizer.go @@ -0,0 +1,485 @@ +package aga + +import ( + "context" + "k8s.io/apimachinery/pkg/util/sets" + "strings" + + awssdk "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/globalaccelerator" + agatypes "github.com/aws/aws-sdk-go-v2/service/globalaccelerator/types" + "github.com/go-logr/logr" + "github.com/pkg/errors" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" + agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" + "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" +) + +// NewEndpointGroupSynthesizer constructs new EndpointGroupSynthesizer +func NewEndpointGroupSynthesizer( + gaService services.GlobalAccelerator, + endpointGroupManager EndpointGroupManager, + logger logr.Logger, + stack core.Stack) *endpointGroupSynthesizer { + + return &endpointGroupSynthesizer{ + gaService: gaService, + endpointGroupManager: endpointGroupManager, + logger: logger, + stack: stack, + } +} + +// endpointGroupSynthesizer synthesizes AGA EndpointGroup resources +type endpointGroupSynthesizer struct { + gaService services.GlobalAccelerator + endpointGroupManager EndpointGroupManager + logger logr.Logger + stack core.Stack +} + +// EndpointPortConflict describes a conflict between endpoint ports in different groups +type EndpointPortConflict struct { + Port int32 + ConflictingGroups []string // ARNs of endpoint groups with this port + ListenerARNs []string // ARNs of listeners that contain the conflicting endpoint groups +} + +// endpointGroupAndSDKEndpointGroup contains a pair of endpoint group resource and its SDK endpoint group +type resAndSDKGroupPair struct { + resEndpointGroup *agamodel.EndpointGroup + sdkEndpointGroup *agatypes.EndpointGroup +} + +// mapEndpointGroupsByListenerARN maps endpoint groups by their parent listener ARN +func (s *endpointGroupSynthesizer) mapEndpointGroupsByListenerARN(ctx context.Context, resEndpointGroups []*agamodel.EndpointGroup) (map[string][]*agamodel.EndpointGroup, error) { + endpointGroupsByListenerARN := make(map[string][]*agamodel.EndpointGroup) + + for _, eg := range resEndpointGroups { + listenerARN, err := eg.Spec.ListenerARN.Resolve(ctx) + if err != nil { + return nil, errors.Wrapf(err, "failed to resolve listener ARN for endpoint group %s", eg.ID()) + } + endpointGroupsByListenerARN[listenerARN] = append(endpointGroupsByListenerARN[listenerARN], eg) + } + + return endpointGroupsByListenerARN, nil +} + +// areListenersEquivalent checks if two listener ARNs refer to the same underlying listener +// Since AWS Global Accelerator ARNs are consistent across resources, a simple string comparison is sufficient +func areListenersEquivalent(listener1, listener2 string) bool { + return listener1 == listener2 +} + +// EndpointPortInfo represents endpoint port usage information for a specific region +type EndpointPortInfo struct { + // Region where this endpoint port is used + Region string + // Port number + Port int32 + // ListenerARN that uses this port + ListenerARN string + // EndpointGroupARN of the group using this port + EndpointGroupARN string +} + +// detectConflictsWithSDKEndpointGroups identifies endpoint port conflicts between desired +// endpoint groups and existing SDK endpoint groups within the same AWS region. +// +// AWS Global Accelerator enforces a critical constraint: within a single region, endpoint ports +// must be unique across endpoint groups belonging to different listeners. This prevents the +// same destination port from being used by multiple listeners in the same region, which would +// create ambiguity for traffic routing. +// +// The function compares endpoint ports used by desired endpoint groups against those +// already in use by existing SDK endpoint groups. It returns a map of conflicting ports +// to their respective conflicting endpoint group ARNs for resolution. +// +// For more details on port override constraints, see: +// https://docs.aws.amazon.com/global-accelerator/latest/dg/about-endpoint-groups-port-override.html +func (s *endpointGroupSynthesizer) detectConflictsWithSDKEndpointGroups( + ctx context.Context, + resEndpointGroups []*agamodel.EndpointGroup, + sdkEndpointGroups []agatypes.EndpointGroup) (map[int32][]string, error) { + + // Step 1: Collect all desired endpoint ports by region and listener + var desiredPortInfos []EndpointPortInfo + + for _, resGroup := range resEndpointGroups { + // Skip groups with no port overrides + if resGroup.Spec.PortOverrides == nil || len(resGroup.Spec.PortOverrides) == 0 { + continue + } + + // Get listener ARN for this resource group + listenerARN, err := resGroup.Spec.ListenerARN.Resolve(ctx) + if err != nil { + return nil, errors.Wrapf(err, "failed to resolve listener ARN for endpoint group %s", resGroup.ID()) + } + + // Add all endpoint ports this group wants to use + for _, po := range resGroup.Spec.PortOverrides { + desiredPortInfos = append(desiredPortInfos, EndpointPortInfo{ + Region: resGroup.Spec.Region, + Port: po.EndpointPort, + ListenerARN: listenerARN, + }) + } + } + + // No desired port overrides means no conflicts return early + if len(desiredPortInfos) == 0 { + return nil, nil + } + + // Step 2: Collect all SDK endpoint ports by region and listener + var sdkPortInfos []EndpointPortInfo + + for _, sdkGroup := range sdkEndpointGroups { + region := awssdk.ToString(sdkGroup.EndpointGroupRegion) + groupARN := awssdk.ToString(sdkGroup.EndpointGroupArn) + listenerARN := extractListenerARNFromEndpointGroupARN(groupARN) + + // Add all ports for this SDK group + for _, po := range sdkGroup.PortOverrides { + port := awssdk.ToInt32(po.EndpointPort) + sdkPortInfos = append(sdkPortInfos, EndpointPortInfo{ + Region: region, + Port: port, + ListenerARN: listenerARN, + EndpointGroupARN: groupARN, + }) + } + } + + // Step 3: Find conflicts by comparing different listeners using the same port in the same region + conflicts := make(map[int32][]string) // map[port][]conflictingGroupARNs + + for _, desiredInfo := range desiredPortInfos { + for _, sdkInfo := range sdkPortInfos { + // Check if they're in the same region and using the same port + if desiredInfo.Region == sdkInfo.Region && desiredInfo.Port == sdkInfo.Port { + // Check if they're from different listeners + if !areListenersEquivalent(desiredInfo.ListenerARN, sdkInfo.ListenerARN) { + // If different listeners use same port in same region, it's a conflict + conflicts[desiredInfo.Port] = append(conflicts[desiredInfo.Port], sdkInfo.EndpointGroupARN) + + s.logger.V(1).Info("Detected endpoint port conflict", + "endpointPort", desiredInfo.Port, + "region", desiredInfo.Region, + "conflictingSDKGroup", sdkInfo.EndpointGroupARN) + } + } + } + } + + return conflicts, nil +} + +// extractListenerARNFromEndpointGroupARN extracts the listener ARN portion from an endpoint group ARN +// Returns empty string if the format doesn't match expectations +func extractListenerARNFromEndpointGroupARN(groupARN string) string { + // Expected format: arn:aws:globalaccelerator::123456789012:accelerator/abcd/listener/l-1234/endpoint-group/eg-1234 + + // Check for endpoint group pattern + if !strings.Contains(groupARN, "/endpoint-group/") { + return "" + } + + // Split by "endpoint-group" and take first part + parts := strings.Split(groupARN, "/endpoint-group/") + if len(parts) < 2 { + return "" + } + + return parts[0] +} + +// resolveConflictsWithSDKEndpointGroups resolves endpoint port conflicts by updating +// the existing SDK endpoint groups to remove conflicting port overrides +func (s *endpointGroupSynthesizer) resolveConflictsWithSDKEndpointGroups( + ctx context.Context, + conflicts map[int32][]string, + sdkEndpointGroups []agatypes.EndpointGroup) error { + + if len(conflicts) == 0 { + return nil + } + + s.logger.V(1).Info("Detected endpoint port conflicts with existing SDK endpoint groups", + "conflictCount", len(conflicts)) + + // Track which SDK groups need updating + sdkGroupUpdates := make(map[string][]agatypes.PortOverride) + + // For each conflict, we need to remove the conflicting port override from the SDK groups + for port, conflictingGroups := range conflicts { + for _, groupARN := range conflictingGroups { + // Find the SDK group with this ARN + for i, sdkGroup := range sdkEndpointGroups { + if awssdk.ToString(sdkGroup.EndpointGroupArn) == groupARN { + // Create a filtered list of port overrides excluding the conflicting one + updatedPortOverrides := make([]agatypes.PortOverride, 0, len(sdkGroup.PortOverrides)) + + for _, po := range sdkGroup.PortOverrides { + if awssdk.ToInt32(po.EndpointPort) != port { + updatedPortOverrides = append(updatedPortOverrides, po) + } + } + + // Store the updated port overrides for this group + sdkGroupUpdates[groupARN] = updatedPortOverrides + + // Update the SDK endpoint group in our local list to reflect the change + sdkEndpointGroups[i].PortOverrides = updatedPortOverrides + break + } + } + } + } + + // Update all the SDK endpoint groups that need updating + for groupARN, updatedPortOverrides := range sdkGroupUpdates { + s.logger.V(1).Info("Updating existing endpoint group to remove conflicting port overrides", + "endpointGroupARN", groupARN, + "updatedPortOverridesCount", len(updatedPortOverrides)) + + // Create update input + updateInput := &globalaccelerator.UpdateEndpointGroupInput{ + EndpointGroupArn: awssdk.String(groupARN), + PortOverrides: updatedPortOverrides, + } + + // Update the endpoint group + _, err := s.gaService.UpdateEndpointGroupWithContext(ctx, updateInput) + if err != nil { + return errors.Wrapf(err, "failed to update endpoint group %s to resolve port conflicts", groupARN) + } + } + + return nil +} + +// getAllEndpointGroupsInListeners returns all endpoint groups across all listeners +func (s *endpointGroupSynthesizer) getAllEndpointGroupsInListeners(ctx context.Context, listenerARNs []string) ([]agatypes.EndpointGroup, error) { + var allEndpointGroups []agatypes.EndpointGroup + for _, listenerARN := range listenerARNs { + endpointGroups, err := s.gaService.ListEndpointGroupsAsList(ctx, &globalaccelerator.ListEndpointGroupsInput{ + ListenerArn: awssdk.String(listenerARN), + }) + if err != nil { + return nil, errors.Wrapf(err, "failed to list endpoint groups for listener %s", listenerARN) + } + allEndpointGroups = append(allEndpointGroups, endpointGroups...) + } + + return allEndpointGroups, nil +} + +// Synthesize performs the actual synthesis of endpoint group resources +func (s *endpointGroupSynthesizer) Synthesize(ctx context.Context) error { + var resEndpointGroups []*agamodel.EndpointGroup + s.stack.ListResources(&resEndpointGroups) + + // Get listener ARNs from stack + var resListeners []*agamodel.Listener + s.stack.ListResources(&resListeners) + + // Nothing to process. No Listeners and endpoint groups in stack. + // This means we have already deleted all the unneeded listeners and its corresponding endpoint groups during listener synthesis. + if len(resListeners) == 0 { + return nil + } + + listenerARNs := make([]string, 0, len(resListeners)) + for _, resListener := range resListeners { + listenerARN, err := resListener.ListenerARN().Resolve(ctx) + if err != nil { + return errors.Wrapf(err, "failed to resolve listener ARN for resListener %s", resListener.ID()) + } + listenerARNs = append(listenerARNs, listenerARN) + } + + // Group endpoint groups by listener ARN + endpointGroupsByListenerARN, err := s.mapEndpointGroupsByListenerARN(ctx, resEndpointGroups) + if err != nil { + return err + } + + // Only detect conflicts for endpoint port duplicates if there are any desired endpoint groups in our stack + if len(resEndpointGroups) > 0 { + // Get endpoint groups and handle any conflicts before proceeding + if err := s.detectAndResolveEndpointGroupConflicts(ctx, resEndpointGroups, listenerARNs); err != nil { + return err + } + } + + // Process endpoint groups by listener ARN + for _, listenerARN := range listenerARNs { + resEndpointGroups := endpointGroupsByListenerARN[listenerARN] + if err := s.synthesizeEndpointGroupsOnListener(ctx, listenerARN, resEndpointGroups); err != nil { + return err + } + } + + return nil +} + +// PostSynthesize performs cleanup of endpoint group resources +// Currently not needed as deletion happens in synthesizeEndpointGroupsOnListener +func (s *endpointGroupSynthesizer) PostSynthesize(ctx context.Context) error { + return nil +} + +// detectAndResolveEndpointGroupConflicts handles endpoint group conflicts +// It detects conflicts between endpoint groups from different listeners, resolves them if needed +func (s *endpointGroupSynthesizer) detectAndResolveEndpointGroupConflicts(ctx context.Context, resEndpointGroups []*agamodel.EndpointGroup, listenerARNs []string) error { + s.logger.V(1).Info("Detecting and resolving endpoint group conflicts", + "endpointGroupCount", len(resEndpointGroups), + "listenerCount", len(listenerARNs)) + + // Get endpoint groups to check for conflicts with our desired state + allSDKEndpointGroups, err := s.getAllEndpointGroupsInListeners(ctx, listenerARNs) + if err != nil { + s.logger.Error(err, "Failed to get endpoint groups for conflict checking", + "listenerCount", len(listenerARNs)) + return errors.Wrap(err, "failed to get endpoint groups for conflict checking") + } + + // Detect conflicts between our desired endpoint groups and existing ones in AWS + sdkConflicts, err := s.detectConflictsWithSDKEndpointGroups(ctx, resEndpointGroups, allSDKEndpointGroups) + if err != nil { + s.logger.Error(err, "Failed to detect endpoint group conflicts", + "endpointGroupCount", len(resEndpointGroups), + "sdkEndpointGroupCount", len(allSDKEndpointGroups)) + return err + } + + // If conflicts with existing SDK endpoint groups are found, update them to remove conflicts + if len(sdkConflicts) > 0 { + for port, groups := range sdkConflicts { + s.logger.V(1).Info("Port conflict details", + "endpointPort", port, + "conflictingGroupCount", len(groups)) + } + + if err := s.resolveConflictsWithSDKEndpointGroups(ctx, sdkConflicts, allSDKEndpointGroups); err != nil { + s.logger.Error(err, "Failed to resolve endpoint group port conflicts", + "conflictCount", len(sdkConflicts)) + return errors.Wrap(err, "failed to resolve endpoint group port conflicts") + } + + s.logger.Info("Successfully resolved all endpoint group port conflicts", + "conflictCount", len(sdkConflicts)) + } else { + s.logger.V(1).Info("No endpoint group port conflicts detected with external groups") + } + + return nil +} + +// matchResAndSDKEndpointGroups matches resource endpoint groups with SDK endpoint groups using region as the unique key +func matchResAndSDKEndpointGroups(resEndpointGroups []*agamodel.EndpointGroup, sdkEndpointGroups []agatypes.EndpointGroup) ([]resAndSDKGroupPair, []*agamodel.EndpointGroup, []*agatypes.EndpointGroup) { + // Create maps for matching by region (region is the unique key within a listener) + sdkGroupsByRegion := make(map[string]*agatypes.EndpointGroup) + resGroupsByRegion := make(map[string]*agamodel.EndpointGroup) + + // Map resource endpoint groups by region + for _, resGroup := range resEndpointGroups { + region := resGroup.Spec.Region + resGroupsByRegion[region] = resGroup + } + + // Map SDK endpoint groups by region + for _, sdkGroup := range sdkEndpointGroups { + region := awssdk.ToString(sdkGroup.EndpointGroupRegion) + sdkGroupsByRegion[region] = &sdkGroup + } + resGroupRegions := sets.StringKeySet(resGroupsByRegion) + sdkGroupRegions := sets.StringKeySet(sdkGroupsByRegion) + + // Find matches and non-matches + var matchedResAndSDKGroups []resAndSDKGroupPair + var unmatchedResGroups []*agamodel.EndpointGroup + var unmatchedSDKGroups []*agatypes.EndpointGroup + + // Find matched pairs and unmatched resource groups + for _, region := range resGroupRegions.Intersection(sdkGroupRegions).List() { + resGroup := resGroupsByRegion[region] + sdkGroup := sdkGroupsByRegion[region] + matchedResAndSDKGroups = append(matchedResAndSDKGroups, resAndSDKGroupPair{ + resEndpointGroup: resGroup, + sdkEndpointGroup: sdkGroup, + }) + } + for _, region := range resGroupRegions.Difference(sdkGroupRegions).List() { + unmatchedResGroups = append(unmatchedResGroups, resGroupsByRegion[region]) + } + for _, region := range sdkGroupRegions.Difference(resGroupRegions).List() { + unmatchedSDKGroups = append(unmatchedSDKGroups, sdkGroupsByRegion[region]) + } + + return matchedResAndSDKGroups, unmatchedResGroups, unmatchedSDKGroups +} + +// synthesizeEndpointGroupsOnListener processes all endpoint groups for a specific listener +func (s *endpointGroupSynthesizer) synthesizeEndpointGroupsOnListener(ctx context.Context, listenerARN string, resEndpointGroups []*agamodel.EndpointGroup) error { + // Get existing endpoint groups for this listener from AWS + sdkEndpointGroups, err := s.getEndpointGroupsForListener(ctx, listenerARN) + if err != nil { + return errors.Wrapf(err, "failed to list endpoint groups for listener %s", listenerARN) + } + + // Match resource endpoint groups with SDK endpoint groups + matchedEndpointGroups, unmatchedResEndpointGroups, unmatchedSDKEndpointGroups := matchResAndSDKEndpointGroups(resEndpointGroups, sdkEndpointGroups) + + // Handle matched pairs - update them + for _, pair := range matchedEndpointGroups { + s.logger.Info("Updating existing endpoint group", + "endpointGroupArn", *pair.sdkEndpointGroup.EndpointGroupArn, + "region", pair.resEndpointGroup.Spec.Region) + status, err := s.endpointGroupManager.Update(ctx, pair.resEndpointGroup, pair.sdkEndpointGroup) + if err != nil { + return errors.Wrapf(err, "failed to update endpoint group %v", pair.resEndpointGroup.ID()) + } + + // Update the resource with the returned status + pair.resEndpointGroup.SetStatus(status) + } + + // Handle unmatched SDK endpoint groups - delete them + for _, sdkGroup := range unmatchedSDKEndpointGroups { + s.logger.Info("Deleting unneeded endpoint group", + "endpointGroupArn", *sdkGroup.EndpointGroupArn, + "region", *sdkGroup.EndpointGroupRegion) + egARN := awssdk.ToString(sdkGroup.EndpointGroupArn) + + if err := s.endpointGroupManager.Delete(ctx, egARN); err != nil { + return errors.Wrapf(err, "failed to delete unneeded endpoint group: %v", egARN) + } + } + + // Handle unmatched resource endpoint groups - create them + for _, resGroup := range unmatchedResEndpointGroups { + s.logger.Info("Creating new endpoint group", + "region", resGroup.Spec.Region) + status, err := s.endpointGroupManager.Create(ctx, resGroup) + if err != nil { + return errors.Wrapf(err, "failed to create endpoint group %v", resGroup.ID()) + } + + // Update the resource with the returned status + resGroup.SetStatus(status) + } + return nil +} + +// getEndpointGroupsForListener gets all endpoint groups for a specific listener +func (s *endpointGroupSynthesizer) getEndpointGroupsForListener(ctx context.Context, listenerARN string) ([]agatypes.EndpointGroup, error) { + listInput := &globalaccelerator.ListEndpointGroupsInput{ + ListenerArn: awssdk.String(listenerARN), + } + + return s.gaService.ListEndpointGroupsAsList(ctx, listInput) +} diff --git a/pkg/deploy/aga/endpoint_group_synthesizer_test.go b/pkg/deploy/aga/endpoint_group_synthesizer_test.go new file mode 100644 index 000000000..11a54ad81 --- /dev/null +++ b/pkg/deploy/aga/endpoint_group_synthesizer_test.go @@ -0,0 +1,1182 @@ +package aga + +import ( + "context" + "sort" + "testing" + + awssdk "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/globalaccelerator" + agatypes "github.com/aws/aws-sdk-go-v2/service/globalaccelerator/types" + "github.com/go-logr/logr" + "github.com/golang/mock/gomock" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" + agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" + "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" +) + +// Helper function to create a test endpoint group with port overrides +func createEndpointGroupWithPortOverrides(id string, region string, listenerARN string, portOverrides []agamodel.PortOverride) *agamodel.EndpointGroup { + mockStack := core.NewDefaultStack(core.StackID{Namespace: "test-namespace", Name: "test-name"}) + + return &agamodel.EndpointGroup{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::EndpointGroup", id), + Spec: agamodel.EndpointGroupSpec{ + ListenerARN: core.LiteralStringToken(listenerARN), + Region: region, + TrafficDialPercentage: awssdk.Int32(100), + PortOverrides: portOverrides, + }, + } +} + +func Test_matchResAndSDKEndpointGroups(t *testing.T) { + mockStack := core.NewDefaultStack(core.StackID{Namespace: "test-namespace", Name: "test-name"}) + testListenerARN := "arn:aws:globalaccelerator::123456789012:listener/1234abcd-abcd-1234-abcd-1234abcdefgh" + + tests := []struct { + name string + resEndpointGroups []*agamodel.EndpointGroup + sdkEndpointGroups []agatypes.EndpointGroup + wantMatchedPairs []struct { + resID string + sdkRegion string + sdkGroupARN string + } + wantUnmatchedResIDs []string + wantUnmatchedSDKRegions []string + }{ + { + name: "empty lists", + resEndpointGroups: []*agamodel.EndpointGroup{}, + sdkEndpointGroups: []agatypes.EndpointGroup{}, + wantMatchedPairs: []struct { + resID string + sdkRegion string + sdkGroupARN string + }{}, + wantUnmatchedResIDs: []string{}, + wantUnmatchedSDKRegions: []string{}, + }, + { + name: "single exact match by region", + resEndpointGroups: []*agamodel.EndpointGroup{ + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::EndpointGroup", "endpoint-group-1"), + Spec: agamodel.EndpointGroupSpec{ + ListenerARN: core.LiteralStringToken(testListenerARN), + Region: "us-west-2", + TrafficDialPercentage: awssdk.Int32(100), + }, + }, + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-1234"), + EndpointGroupRegion: awssdk.String("us-west-2"), + TrafficDialPercentage: awssdk.Float32(100.0), + }, + }, + wantMatchedPairs: []struct { + resID string + sdkRegion string + sdkGroupARN string + }{ + { + resID: "endpoint-group-1", + sdkRegion: "us-west-2", + sdkGroupARN: "arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-1234", + }, + }, + wantUnmatchedResIDs: []string{}, + wantUnmatchedSDKRegions: []string{}, + }, + { + name: "multiple matches by region", + resEndpointGroups: []*agamodel.EndpointGroup{ + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::EndpointGroup", "endpoint-group-1"), + Spec: agamodel.EndpointGroupSpec{ + ListenerARN: core.LiteralStringToken(testListenerARN), + Region: "us-west-2", + }, + }, + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::EndpointGroup", "endpoint-group-2"), + Spec: agamodel.EndpointGroupSpec{ + ListenerARN: core.LiteralStringToken(testListenerARN), + Region: "us-east-1", + }, + }, + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-west"), + EndpointGroupRegion: awssdk.String("us-west-2"), + }, + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-east"), + EndpointGroupRegion: awssdk.String("us-east-1"), + }, + }, + wantMatchedPairs: []struct { + resID string + sdkRegion string + sdkGroupARN string + }{ + { + resID: "endpoint-group-1", + sdkRegion: "us-west-2", + sdkGroupARN: "arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-west", + }, + { + resID: "endpoint-group-2", + sdkRegion: "us-east-1", + sdkGroupARN: "arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-east", + }, + }, + wantUnmatchedResIDs: []string{}, + wantUnmatchedSDKRegions: []string{}, + }, + { + name: "unmatched resource endpoint groups", + resEndpointGroups: []*agamodel.EndpointGroup{ + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::EndpointGroup", "endpoint-group-1"), + Spec: agamodel.EndpointGroupSpec{ + ListenerARN: core.LiteralStringToken(testListenerARN), + Region: "us-west-2", + }, + }, + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::EndpointGroup", "endpoint-group-2"), + Spec: agamodel.EndpointGroupSpec{ + ListenerARN: core.LiteralStringToken(testListenerARN), + Region: "eu-west-1", // No matching SDK endpoint group + }, + }, + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-west"), + EndpointGroupRegion: awssdk.String("us-west-2"), + }, + }, + wantMatchedPairs: []struct { + resID string + sdkRegion string + sdkGroupARN string + }{ + { + resID: "endpoint-group-1", + sdkRegion: "us-west-2", + sdkGroupARN: "arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-west", + }, + }, + wantUnmatchedResIDs: []string{"endpoint-group-2"}, + wantUnmatchedSDKRegions: []string{}, + }, + { + name: "unmatched SDK endpoint groups", + resEndpointGroups: []*agamodel.EndpointGroup{ + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::EndpointGroup", "endpoint-group-1"), + Spec: agamodel.EndpointGroupSpec{ + ListenerARN: core.LiteralStringToken(testListenerARN), + Region: "us-west-2", + }, + }, + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-west"), + EndpointGroupRegion: awssdk.String("us-west-2"), + }, + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-east"), + EndpointGroupRegion: awssdk.String("us-east-1"), // No matching resource endpoint group + }, + }, + wantMatchedPairs: []struct { + resID string + sdkRegion string + sdkGroupARN string + }{ + { + resID: "endpoint-group-1", + sdkRegion: "us-west-2", + sdkGroupARN: "arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-west", + }, + }, + wantUnmatchedResIDs: []string{}, + wantUnmatchedSDKRegions: []string{"us-east-1"}, + }, + { + name: "mixed matches and unmatches", + resEndpointGroups: []*agamodel.EndpointGroup{ + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::EndpointGroup", "endpoint-group-west"), + Spec: agamodel.EndpointGroupSpec{ + ListenerARN: core.LiteralStringToken(testListenerARN), + Region: "us-west-2", + }, + }, + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::EndpointGroup", "endpoint-group-central"), + Spec: agamodel.EndpointGroupSpec{ + ListenerARN: core.LiteralStringToken(testListenerARN), + Region: "us-central-1", // No matching SDK endpoint group + }, + }, + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-west"), + EndpointGroupRegion: awssdk.String("us-west-2"), + }, + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-east"), + EndpointGroupRegion: awssdk.String("us-east-1"), // No matching resource endpoint group + }, + }, + wantMatchedPairs: []struct { + resID string + sdkRegion string + sdkGroupARN string + }{ + { + resID: "endpoint-group-west", + sdkRegion: "us-west-2", + sdkGroupARN: "arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-west", + }, + }, + wantUnmatchedResIDs: []string{"endpoint-group-central"}, + wantUnmatchedSDKRegions: []string{"us-east-1"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Run the function + matchedPairs, unmatchedResEndpointGroups, unmatchedSDKEndpointGroups := matchResAndSDKEndpointGroups(tt.resEndpointGroups, tt.sdkEndpointGroups) + + // Collect the actual pairs and IDs for verification + var actualMatchedPairs []struct { + resID string + sdkRegion string + sdkGroupARN string + } + + for _, pair := range matchedPairs { + actualMatchedPairs = append(actualMatchedPairs, struct { + resID string + sdkRegion string + sdkGroupARN string + }{ + resID: pair.resEndpointGroup.ID(), + sdkRegion: awssdk.ToString(pair.sdkEndpointGroup.EndpointGroupRegion), + sdkGroupARN: awssdk.ToString(pair.sdkEndpointGroup.EndpointGroupArn), + }) + } + + var actualUnmatchedResIDs []string + for _, eg := range unmatchedResEndpointGroups { + actualUnmatchedResIDs = append(actualUnmatchedResIDs, eg.ID()) + } + + var actualUnmatchedSDKRegions []string + for _, eg := range unmatchedSDKEndpointGroups { + actualUnmatchedSDKRegions = append(actualUnmatchedSDKRegions, awssdk.ToString(eg.EndpointGroupRegion)) + } + + // Sort all slices to ensure consistent comparison + sort.Slice(actualMatchedPairs, func(i, j int) bool { + if actualMatchedPairs[i].resID != actualMatchedPairs[j].resID { + return actualMatchedPairs[i].resID < actualMatchedPairs[j].resID + } + return actualMatchedPairs[i].sdkGroupARN < actualMatchedPairs[j].sdkGroupARN + }) + + sort.Slice(tt.wantMatchedPairs, func(i, j int) bool { + if tt.wantMatchedPairs[i].resID != tt.wantMatchedPairs[j].resID { + return tt.wantMatchedPairs[i].resID < tt.wantMatchedPairs[j].resID + } + return tt.wantMatchedPairs[i].sdkGroupARN < tt.wantMatchedPairs[j].sdkGroupARN + }) + + sort.Strings(actualUnmatchedResIDs) + sort.Strings(tt.wantUnmatchedResIDs) + sort.Strings(actualUnmatchedSDKRegions) + sort.Strings(tt.wantUnmatchedSDKRegions) + + // Verify matched pairs + assert.Equal(t, len(tt.wantMatchedPairs), len(actualMatchedPairs), "matched pairs count") + for i := range tt.wantMatchedPairs { + if i < len(actualMatchedPairs) { + assert.Equal(t, tt.wantMatchedPairs[i].resID, actualMatchedPairs[i].resID, "matched pair resID at index %d", i) + assert.Equal(t, tt.wantMatchedPairs[i].sdkRegion, actualMatchedPairs[i].sdkRegion, "matched pair sdkRegion at index %d", i) + assert.Equal(t, tt.wantMatchedPairs[i].sdkGroupARN, actualMatchedPairs[i].sdkGroupARN, "matched pair sdkGroupARN at index %d", i) + } + } + + // Handle nil vs empty slices + if len(actualUnmatchedResIDs) == 0 && len(tt.wantUnmatchedResIDs) == 0 { + // Both empty, no need to compare + } else { + // Verify unmatched resource endpoint groups + assert.ElementsMatch(t, tt.wantUnmatchedResIDs, actualUnmatchedResIDs, "unmatched resource endpoint groups") + } + + if len(actualUnmatchedSDKRegions) == 0 && len(tt.wantUnmatchedSDKRegions) == 0 { + // Both empty, no need to compare + } else { + // Verify unmatched SDK endpoint groups + assert.ElementsMatch(t, tt.wantUnmatchedSDKRegions, actualUnmatchedSDKRegions, "unmatched SDK endpoint groups") + } + }) + } +} + +// createSDKEndpointGroup creates an SDK endpoint group with port overrides +func createSDKEndpointGroup(arn string, region string, portOverrides []agatypes.PortOverride) agatypes.EndpointGroup { + return agatypes.EndpointGroup{ + EndpointGroupArn: awssdk.String(arn), + EndpointGroupRegion: awssdk.String(region), + PortOverrides: portOverrides, + } +} + +func Test_endpointGroupSynthesizer_getAllEndpointGroupsInListeners(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + mockGA := services.NewMockGlobalAccelerator(mockCtrl) + + // Create synthesizer instance with mock GA service + s := &endpointGroupSynthesizer{ + gaService: mockGA, + logger: logr.Discard(), + } + + // Test cases + tests := []struct { + name string + listenerARNs []string + mockSetup func(mockGA *services.MockGlobalAccelerator) + expectedGroups []agatypes.EndpointGroup + expectError bool + }{ + { + name: "empty listener ARNs list", + listenerARNs: []string{}, + mockSetup: func(mockGA *services.MockGlobalAccelerator) {}, + expectedGroups: []agatypes.EndpointGroup{}, + expectError: false, + }, + { + name: "single listener with no endpoint groups", + listenerARNs: []string{"arn:aws:globalaccelerator::123456789012:listener/listener1"}, + mockSetup: func(mockGA *services.MockGlobalAccelerator) { + mockGA.EXPECT(). + ListEndpointGroupsAsList(gomock.Any(), &globalaccelerator.ListEndpointGroupsInput{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:listener/listener1"), + }). + Return([]agatypes.EndpointGroup{}, nil) + }, + expectedGroups: []agatypes.EndpointGroup{}, + expectError: false, + }, + { + name: "single listener with one endpoint group", + listenerARNs: []string{"arn:aws:globalaccelerator::123456789012:listener/listener1"}, + mockSetup: func(mockGA *services.MockGlobalAccelerator) { + mockGA.EXPECT(). + ListEndpointGroupsAsList(gomock.Any(), &globalaccelerator.ListEndpointGroupsInput{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:listener/listener1"), + }). + Return([]agatypes.EndpointGroup{ + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/eg1"), + EndpointGroupRegion: awssdk.String("us-west-2"), + }, + }, nil) + }, + expectedGroups: []agatypes.EndpointGroup{ + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/eg1"), + EndpointGroupRegion: awssdk.String("us-west-2"), + }, + }, + expectError: false, + }, + { + name: "multiple listeners with endpoint groups", + listenerARNs: []string{ + "arn:aws:globalaccelerator::123456789012:listener/listener1", + "arn:aws:globalaccelerator::123456789012:listener/listener2", + }, + mockSetup: func(mockGA *services.MockGlobalAccelerator) { + // First listener + mockGA.EXPECT(). + ListEndpointGroupsAsList(gomock.Any(), &globalaccelerator.ListEndpointGroupsInput{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:listener/listener1"), + }). + Return([]agatypes.EndpointGroup{ + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/eg1"), + EndpointGroupRegion: awssdk.String("us-west-2"), + }, + }, nil) + + // Second listener + mockGA.EXPECT(). + ListEndpointGroupsAsList(gomock.Any(), &globalaccelerator.ListEndpointGroupsInput{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:listener/listener2"), + }). + Return([]agatypes.EndpointGroup{ + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/eg2"), + EndpointGroupRegion: awssdk.String("eu-west-1"), + }, + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/eg3"), + EndpointGroupRegion: awssdk.String("ap-southeast-1"), + }, + }, nil) + }, + expectedGroups: []agatypes.EndpointGroup{ + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/eg1"), + EndpointGroupRegion: awssdk.String("us-west-2"), + }, + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/eg2"), + EndpointGroupRegion: awssdk.String("eu-west-1"), + }, + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/eg3"), + EndpointGroupRegion: awssdk.String("ap-southeast-1"), + }, + }, + expectError: false, + }, + { + name: "error retrieving endpoint groups", + listenerARNs: []string{"arn:aws:globalaccelerator::123456789012:listener/listener1"}, + mockSetup: func(mockGA *services.MockGlobalAccelerator) { + mockGA.EXPECT(). + ListEndpointGroupsAsList(gomock.Any(), &globalaccelerator.ListEndpointGroupsInput{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:listener/listener1"), + }). + Return(nil, errors.New("API error")) + }, + expectedGroups: nil, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup mock expectations + tt.mockSetup(mockGA) + + // Call the function + ctx := context.Background() + endpointGroups, err := s.getAllEndpointGroupsInListeners(ctx, tt.listenerARNs) + + // Check error + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + // Check endpoint groups + assert.Equal(t, len(tt.expectedGroups), len(endpointGroups)) + + if len(tt.expectedGroups) > 0 { + // Compare each endpoint group + for i, expectedGroup := range tt.expectedGroups { + if i < len(endpointGroups) { + assert.Equal(t, awssdk.ToString(expectedGroup.EndpointGroupArn), + awssdk.ToString(endpointGroups[i].EndpointGroupArn)) + assert.Equal(t, awssdk.ToString(expectedGroup.EndpointGroupRegion), + awssdk.ToString(endpointGroups[i].EndpointGroupRegion)) + } + } + } + }) + } +} + +func Test_endpointGroupSynthesizer_detectConflictsWithSDKEndpointGroups(t *testing.T) { + testListener1ARN := "arn:aws:globalaccelerator::123456789012:accelerator/1234abcd/listener/l-1" + testListener2ARN := "arn:aws:globalaccelerator::123456789012:accelerator/1234abcd/listener/l-2" + tests := []struct { + name string + resEndpointGroups []*agamodel.EndpointGroup + sdkEndpointGroups []agatypes.EndpointGroup + wantConflictCount int + wantConflictPorts []int32 + }{ + { + name: "no endpoint groups", + resEndpointGroups: []*agamodel.EndpointGroup{}, + sdkEndpointGroups: []agatypes.EndpointGroup{}, + wantConflictCount: 0, + wantConflictPorts: []int32{}, + }, + { + name: "no port overrides", + resEndpointGroups: []*agamodel.EndpointGroup{ + createEndpointGroupWithPortOverrides("eg-1", "us-west-2", testListener1ARN, nil), + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + createSDKEndpointGroup("arn:aws:globalaccelerator::123456789012:endpointgroup/eg-sdk", "us-east-1", nil), + }, + wantConflictCount: 0, + wantConflictPorts: []int32{}, + }, + { + name: "no conflicts - different endpoint ports", + resEndpointGroups: []*agamodel.EndpointGroup{ + createEndpointGroupWithPortOverrides("eg-1", "us-west-2", testListener1ARN, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + }), + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + createSDKEndpointGroup( + "arn:aws:globalaccelerator::123456789012:endpointgroup/eg-sdk", + "us-east-1", + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(443), EndpointPort: awssdk.Int32(9090)}, // Different port + }, + ), + }, + wantConflictCount: 0, + wantConflictPorts: []int32{}, + }, + { + name: "no conflicts - same endpoint port but different regions", + resEndpointGroups: []*agamodel.EndpointGroup{ + createEndpointGroupWithPortOverrides("eg-1", "us-west-2", testListener1ARN, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + }), + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + createSDKEndpointGroup( + "arn:aws:globalaccelerator::123456789012:endpointgroup/eg-sdk", + "us-east-1", // Different region + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(443), EndpointPort: awssdk.Int32(8080)}, // Same endpoint port but different region + }, + ), + }, + wantConflictCount: 0, // No conflict because different regions + wantConflictPorts: []int32{}, + }, + { + name: "single conflict - same region", + resEndpointGroups: []*agamodel.EndpointGroup{ + createEndpointGroupWithPortOverrides("eg-1", "us-west-2", testListener1ARN, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + }), + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + createSDKEndpointGroup( + "arn:aws:DIFFERENT:SERVICE:123456789012:endpoint-group/TEST_DIFFERENT_ARN/eg-sdk", + "us-west-2", // Same region + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(443), EndpointPort: awssdk.Int32(8080)}, // Same endpoint port + }, + ), + }, + wantConflictCount: 1, + wantConflictPorts: []int32{8080}, + }, + { + name: "multiple conflicts with same SDK group - same region", + resEndpointGroups: []*agamodel.EndpointGroup{ + createEndpointGroupWithPortOverrides("eg-1", "us-west-2", testListener1ARN, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + {ListenerPort: 443, EndpointPort: 8443}, + }), + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + createSDKEndpointGroup( + "arn:aws:DIFFERENT:SERVICE:123456789012:endpoint-group/DIFFERENT-TYPES/eg-sdk", + "us-west-2", // Same region + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(81), EndpointPort: awssdk.Int32(8080)}, // Same as first port + {ListenerPort: awssdk.Int32(444), EndpointPort: awssdk.Int32(8443)}, // Same as second port + }, + ), + }, + wantConflictCount: 2, + wantConflictPorts: []int32{8080, 8443}, + }, + { + name: "multiple conflicts with same SDK group - different region (no conflict)", + resEndpointGroups: []*agamodel.EndpointGroup{ + createEndpointGroupWithPortOverrides("eg-1", "us-west-2", testListener1ARN, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + {ListenerPort: 443, EndpointPort: 8443}, + }), + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + createSDKEndpointGroup( + "arn:aws:globalaccelerator::123456789012:endpointgroup/eg-sdk", + "us-east-1", // Different region + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(81), EndpointPort: awssdk.Int32(8080)}, // Same endpoint port but different region + {ListenerPort: awssdk.Int32(444), EndpointPort: awssdk.Int32(8443)}, // Same endpoint port but different region + }, + ), + }, + wantConflictCount: 0, // No conflicts because different regions + wantConflictPorts: []int32{}, + }, + { + name: "multiple conflicts with different SDK groups - same regions", + resEndpointGroups: []*agamodel.EndpointGroup{ + createEndpointGroupWithPortOverrides("eg-1", "us-west-2", testListener1ARN, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + {ListenerPort: 443, EndpointPort: 8443}, + }), + createEndpointGroupWithPortOverrides("eg-2", "eu-west-1", testListener2ARN, []agamodel.PortOverride{ + {ListenerPort: 81, EndpointPort: 9090}, + }), + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + createSDKEndpointGroup( + "arn:aws:globalaccelerator::123456789012:endpointgroup/TEST_DIFFERENT_ARN_1/eg-sdk-1", + "us-west-2", // Same region as first resource group + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(82), EndpointPort: awssdk.Int32(8080)}, // Conflicts with eg-1 + }, + ), + createSDKEndpointGroup( + "arn:aws:globalaccelerator::123456789012:endpointgroup/TEST_DIFFERENT_ARN_2/eg-sdk-2", + "eu-west-1", // Same region as second resource group + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(83), EndpointPort: awssdk.Int32(9090)}, // Conflicts with eg-2 + }, + ), + }, + wantConflictCount: 2, + wantConflictPorts: []int32{8080, 9090}, + }, + { + name: "multiple conflicts with different SDK groups - different regions (no conflicts)", + resEndpointGroups: []*agamodel.EndpointGroup{ + createEndpointGroupWithPortOverrides("eg-1", "us-west-2", testListener1ARN, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + }), + createEndpointGroupWithPortOverrides("eg-2", "eu-west-1", testListener2ARN, []agamodel.PortOverride{ + {ListenerPort: 81, EndpointPort: 9090}, + }), + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + createSDKEndpointGroup( + "arn:aws:globalaccelerator::123456789012:endpointgroup/eg-sdk-1", + "us-east-1", // Different region than first resource group + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(82), EndpointPort: awssdk.Int32(8080)}, // Same port but different region + }, + ), + createSDKEndpointGroup( + "arn:aws:globalaccelerator::123456789012:endpointgroup/eg-sdk-2", + "ap-southeast-1", // Different region than second resource group + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(83), EndpointPort: awssdk.Int32(9090)}, // Same port but different region + }, + ), + }, + wantConflictCount: 0, // No conflicts because different regions + wantConflictPorts: []int32{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create synthesizer instance + s := &endpointGroupSynthesizer{ + logger: logr.Discard(), + } + + // Run the function + ctx := context.Background() + conflicts, err := s.detectConflictsWithSDKEndpointGroups(ctx, tt.resEndpointGroups, tt.sdkEndpointGroups) + + // Check errors + assert.NoError(t, err) + + // Verify the number of conflicts + assert.Equal(t, tt.wantConflictCount, len(conflicts), "conflict count should match") + + // Collect conflict ports for checking + var conflictPorts []int32 + for port := range conflicts { + conflictPorts = append(conflictPorts, port) + } + sort.Slice(conflictPorts, func(i, j int) bool { + return conflictPorts[i] < conflictPorts[j] + }) + + // Sort expected ports for comparison + expectedPorts := make([]int32, len(tt.wantConflictPorts)) + copy(expectedPorts, tt.wantConflictPorts) + sort.Slice(expectedPorts, func(i, j int) bool { + return expectedPorts[i] < expectedPorts[j] + }) + + // Check conflict ports + assert.ElementsMatch(t, expectedPorts, conflictPorts, "conflicting ports should match") + }) + } +} + +func Test_endpointGroupSynthesizer_resolveConflictsWithSDKEndpointGroups(t *testing.T) { + // Create a mock GlobalAccelerator service + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + mockGA := services.NewMockGlobalAccelerator(mockCtrl) + + testARN1 := "arn:aws:globalaccelerator::123456789012:endpointgroup/eg-1" + testARN2 := "arn:aws:globalaccelerator::123456789012:endpointgroup/eg-2" + + tests := []struct { + name string + conflicts map[int32][]string + sdkEndpointGroups []agatypes.EndpointGroup + mockSetup func(mockGA *services.MockGlobalAccelerator) + expectError bool + }{ + { + name: "no conflicts", + conflicts: map[int32][]string{}, + sdkEndpointGroups: []agatypes.EndpointGroup{}, + mockSetup: func(mockGA *services.MockGlobalAccelerator) {}, + expectError: false, + }, + { + name: "single conflict with one group", + conflicts: map[int32][]string{ + 8080: {testARN1}, + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + createSDKEndpointGroup( + testARN1, + "us-east-1", + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(80), EndpointPort: awssdk.Int32(8080)}, // Conflicting + {ListenerPort: awssdk.Int32(443), EndpointPort: awssdk.Int32(8443)}, // Not conflicting + }, + ), + }, + mockSetup: func(mockGA *services.MockGlobalAccelerator) { + // Expect an update call with only the non-conflicting port override + mockGA.EXPECT().UpdateEndpointGroupWithContext(gomock.Any(), &globalaccelerator.UpdateEndpointGroupInput{ + EndpointGroupArn: awssdk.String(testARN1), + PortOverrides: []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(443), EndpointPort: awssdk.Int32(8443)}, + }, + }).Return(&globalaccelerator.UpdateEndpointGroupOutput{}, nil) + }, + expectError: false, + }, + { + name: "multiple conflicts with different groups", + conflicts: map[int32][]string{ + 8080: {testARN1}, + 9090: {testARN2}, + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + createSDKEndpointGroup( + testARN1, + "us-east-1", + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(80), EndpointPort: awssdk.Int32(8080)}, // Conflicting + {ListenerPort: awssdk.Int32(443), EndpointPort: awssdk.Int32(8443)}, // Not conflicting + }, + ), + createSDKEndpointGroup( + testARN2, + "eu-west-1", + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(81), EndpointPort: awssdk.Int32(9090)}, // Conflicting + {ListenerPort: awssdk.Int32(444), EndpointPort: awssdk.Int32(9443)}, // Not conflicting + }, + ), + }, + mockSetup: func(mockGA *services.MockGlobalAccelerator) { + // Expect updates for both groups + mockGA.EXPECT().UpdateEndpointGroupWithContext(gomock.Any(), &globalaccelerator.UpdateEndpointGroupInput{ + EndpointGroupArn: awssdk.String(testARN1), + PortOverrides: []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(443), EndpointPort: awssdk.Int32(8443)}, + }, + }).Return(&globalaccelerator.UpdateEndpointGroupOutput{}, nil) + + mockGA.EXPECT().UpdateEndpointGroupWithContext(gomock.Any(), &globalaccelerator.UpdateEndpointGroupInput{ + EndpointGroupArn: awssdk.String(testARN2), + PortOverrides: []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(444), EndpointPort: awssdk.Int32(9443)}, + }, + }).Return(&globalaccelerator.UpdateEndpointGroupOutput{}, nil) + }, + expectError: false, + }, + { + name: "error during update", + conflicts: map[int32][]string{ + 8080: {testARN1}, + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + createSDKEndpointGroup( + testARN1, + "us-east-1", + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(80), EndpointPort: awssdk.Int32(8080)}, // Conflicting + }, + ), + }, + mockSetup: func(mockGA *services.MockGlobalAccelerator) { + // Simulate an error during update + mockGA.EXPECT().UpdateEndpointGroupWithContext(gomock.Any(), gomock.Any()). + Return(nil, errors.New("update failed")) + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up the mock expectations + tt.mockSetup(mockGA) + + // Create synthesizer instance + s := &endpointGroupSynthesizer{ + gaService: mockGA, + logger: logr.Discard(), + } + + // Run the function + ctx := context.Background() + err := s.resolveConflictsWithSDKEndpointGroups(ctx, tt.conflicts, tt.sdkEndpointGroups) + + // Check errors + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +//func Test_endpointGroupSynthesizer_getAllEndpointGroupsInListeners(t *testing.T) { +// // Create a mock GlobalAccelerator service +// mockCtrl := gomock.NewController(t) +// defer mockCtrl.Finish() +// mockGA := services.NewMockGlobalAccelerator(mockCtrl) +// listenerARN1 := "arn:aws:globalaccelerator::123456789012:listener/acc-1/listener-1" +// listenerARN2 := "arn:aws:globalaccelerator::123456789012:listener/acc-1/listener-2" +// listenerARN3 := "arn:aws:globalaccelerator::123456789012:listener/acc-2/listener-3" +// +// endpointGroups1 := []agatypes.EndpointGroup{ +// { +// EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/eg-1"), +// EndpointGroupRegion: awssdk.String("us-east-1"), +// }, +// } +// +// endpointGroups2 := []agatypes.EndpointGroup{ +// { +// EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/eg-2"), +// EndpointGroupRegion: awssdk.String("us-west-2"), +// }, +// } +// +// endpointGroups3 := []agatypes.EndpointGroup{ +// { +// EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/eg-3"), +// EndpointGroupRegion: awssdk.String("eu-west-1"), +// }, +// { +// EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/eg-4"), +// EndpointGroupRegion: awssdk.String("eu-central-1"), +// }, +// } +// +// mockGA.EXPECT().ListEndpointGroupsAsList(gomock.Any(), &globalaccelerator.ListEndpointGroupsInput{ +// ListenerArn: awssdk.String(listenerARN1), +// }).Return(endpointGroups1, nil) +// +// mockGA.EXPECT().ListEndpointGroupsAsList(gomock.Any(), &globalaccelerator.ListEndpointGroupsInput{ +// ListenerArn: awssdk.String(listenerARN2), +// }).Return(endpointGroups2, nil) +// +// mockGA.EXPECT().ListEndpointGroupsAsList(gomock.Any(), &globalaccelerator.ListEndpointGroupsInput{ +// ListenerArn: awssdk.String(listenerARN3), +// }).Return(endpointGroups3, nil) +// +// // Create synthesizer instance +// s := &endpointGroupSynthesizer{ +// gaService: mockGA, +// logger: logr.Discard(), +// } +// +// // Run the function +// ctx := context.Background() +// allEndpointGroups, err := s.getAllEndpointGroupsInListeners(ctx, nil) +// +// // Check errors +// assert.NoError(t, err) +// +// // Verify that all endpoint groups are returned +// expectedCount := len(endpointGroups1) + len(endpointGroups2) + len(endpointGroups3) +// assert.Equal(t, expectedCount, len(allEndpointGroups), "should return all endpoint groups") +// +// // Check that groups from all listeners are included +// endpointGroupARNs := make(map[string]bool) +// endpointGroupRegions := make(map[string]bool) +// +// // Collect all ARNs and regions +// for _, eg := range allEndpointGroups { +// endpointGroupARNs[awssdk.ToString(eg.EndpointGroupArn)] = true +// endpointGroupRegions[awssdk.ToString(eg.EndpointGroupRegion)] = true +// } +// +// // Check that expected endpoint groups are included +// assert.Contains(t, endpointGroupARNs, awssdk.ToString(endpointGroups1[0].EndpointGroupArn), "should contain endpoint group 1") +// assert.Contains(t, endpointGroupARNs, awssdk.ToString(endpointGroups2[0].EndpointGroupArn), "should contain endpoint group 2") +// assert.Contains(t, endpointGroupARNs, awssdk.ToString(endpointGroups3[0].EndpointGroupArn), "should contain endpoint group 3") +// assert.Contains(t, endpointGroupARNs, awssdk.ToString(endpointGroups3[1].EndpointGroupArn), "should contain endpoint group 4") +// +// // Check that expected regions are included +// assert.Contains(t, endpointGroupRegions, "us-east-1", "should contain us-east-1 region") +// assert.Contains(t, endpointGroupRegions, "us-west-2", "should contain us-west-2 region") +// assert.Contains(t, endpointGroupRegions, "eu-west-1", "should contain eu-west-1 region") +// assert.Contains(t, endpointGroupRegions, "eu-central-1", "should contain eu-central-1 region") +//} + +// Test_endpointGroupSynthesizer_detectConflictsWithSDKEndpointGroups_OwnListener tests that the +// detectConflictsWithSDKEndpointGroups function correctly ignores conflicts with our own listeners +func Test_endpointGroupSynthesizer_detectConflictsWithSDKEndpointGroups_OwnListener(t *testing.T) { + // Define test listeners and endpoint groups + testListener1ARN := "arn:aws:globalaccelerator::123456789012:accelerator/1234abcd/listener/l-1" + testListener2ARN := "arn:aws:globalaccelerator::123456789012:accelerator/5678efgh/listener/l-2" + + // ARNs for SDK endpoint groups - match the structure required by the extractor + sdkGroup1ARN := "arn:aws:globalaccelerator::123456789012:accelerator/1234abcd/listener/l-1/endpoint-group/eg-1" + sdkGroup2ARN := "arn:aws:globalaccelerator::123456789012:accelerator/5678efgh/listener/l-2/endpoint-group/eg-2" + sdkGroup3ARN := "arn:aws:globalaccelerator::123456789012:accelerator/9012ijkl/listener/l-3/endpoint-group/eg-3" + + tests := []struct { + name string + resEndpointGroups []*agamodel.EndpointGroup + sdkEndpointGroups []agatypes.EndpointGroup + wantConflictCount int + wantConflictPorts []int32 + }{ + { + name: "no conflict with own listener - same endpoint port in same region", + resEndpointGroups: []*agamodel.EndpointGroup{ + createEndpointGroupWithPortOverrides("eg-1", "us-west-2", testListener1ARN, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + }), + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + // Same listener ARN pattern as resource group, should be ignored + createSDKEndpointGroup( + sdkGroup1ARN, + "us-west-2", // Same region + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(81), EndpointPort: awssdk.Int32(8080)}, // Same endpoint port + }, + ), + }, + wantConflictCount: 0, // No conflict detected with our own listener + wantConflictPorts: []int32{}, + }, + { + name: "conflict with external listener - same endpoint port in same region", + resEndpointGroups: []*agamodel.EndpointGroup{ + createEndpointGroupWithPortOverrides("eg-1", "us-west-2", testListener1ARN, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + }), + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + // Different listener ARN than resource group, should detect conflict + createSDKEndpointGroup( + sdkGroup2ARN, + "us-west-2", // Same region + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(81), EndpointPort: awssdk.Int32(8080)}, // Same endpoint port + }, + ), + }, + wantConflictCount: 1, + wantConflictPorts: []int32{8080}, + }, + { + name: "multiple listeners in different regions - no conflicts", + resEndpointGroups: []*agamodel.EndpointGroup{ + createEndpointGroupWithPortOverrides("eg-1", "us-west-2", testListener1ARN, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + }), + createEndpointGroupWithPortOverrides("eg-2", "eu-west-1", testListener2ARN, []agamodel.PortOverride{ + {ListenerPort: 443, EndpointPort: 8080}, + }), + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + // Same listener ARN but different region - no conflict + createSDKEndpointGroup( + sdkGroup1ARN, + "us-east-1", + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(81), EndpointPort: awssdk.Int32(8080)}, + }, + ), + // Different listener ARN but different region - no conflict + createSDKEndpointGroup( + sdkGroup3ARN, + "eu-central-1", + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(82), EndpointPort: awssdk.Int32(8080)}, + }, + ), + }, + wantConflictCount: 0, + wantConflictPorts: []int32{}, + }, + { + name: "complex scenario - multiple listeners, regions, with own and external conflicts", + resEndpointGroups: []*agamodel.EndpointGroup{ + // Resource group 1 + createEndpointGroupWithPortOverrides("eg-west-1", "us-west-2", testListener1ARN, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + {ListenerPort: 443, EndpointPort: 8443}, + }), + // Resource group 2 - different region + createEndpointGroupWithPortOverrides("eg-eu-1", "eu-west-1", testListener2ARN, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 9090}, + }), + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + // Same listener as eg-west-1, should NOT be conflict + createSDKEndpointGroup( + sdkGroup1ARN, + "us-west-2", + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(81), EndpointPort: awssdk.Int32(8080)}, + {ListenerPort: awssdk.Int32(444), EndpointPort: awssdk.Int32(8443)}, + }, + ), + // Same listener as eg-eu-1, should NOT be conflict + createSDKEndpointGroup( + sdkGroup2ARN, + "eu-west-1", + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(82), EndpointPort: awssdk.Int32(9090)}, + }, + ), + // Different listener in us-west-2, SHOULD be conflict with eg-west-1 + createSDKEndpointGroup( + sdkGroup3ARN, + "us-west-2", + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(83), EndpointPort: awssdk.Int32(8080)}, + }, + ), + }, + wantConflictCount: 1, // Only one conflict (port 8080 in us-west-2 with external listener) + wantConflictPorts: []int32{8080}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create synthesizer instance + s := &endpointGroupSynthesizer{ + logger: logr.Discard(), + } + + // Run the function + ctx := context.Background() + conflicts, err := s.detectConflictsWithSDKEndpointGroups(ctx, tt.resEndpointGroups, tt.sdkEndpointGroups) + + // Check errors + assert.NoError(t, err) + + // Verify the number of conflicts + assert.Equal(t, tt.wantConflictCount, len(conflicts), "conflict count should match") + + // Collect conflict ports for checking + var conflictPorts []int32 + for port := range conflicts { + conflictPorts = append(conflictPorts, port) + } + sort.Slice(conflictPorts, func(i, j int) bool { + return conflictPorts[i] < conflictPorts[j] + }) + + // Sort expected ports for comparison + expectedPorts := make([]int32, len(tt.wantConflictPorts)) + copy(expectedPorts, tt.wantConflictPorts) + sort.Slice(expectedPorts, func(i, j int) bool { + return expectedPorts[i] < expectedPorts[j] + }) + + // Check conflict ports + assert.ElementsMatch(t, expectedPorts, conflictPorts, "conflicting ports should match") + + // If conflicts were detected, verify they're with the right endpoint groups + if len(conflicts) > 0 { + for port, sdkGroupARNs := range conflicts { + for _, sdkGroupARN := range sdkGroupARNs { + // For each test case, only check for listener ARNs that are actually in our resource groups + // Get the list of listener ARNs used in this test's resource groups + // Create a map of all listener ARNs used by our endpoint groups + // In our test setup, we know exactly which endpoint groups use which listener ARNs + // based on how we create them in the test cases + usedListenerARNs := map[string]bool{ + testListener1ARN: false, + testListener2ARN: false, + } + + // For our test cases, we directly know which ARNs are used for each test + switch tt.name { + case "no conflict with own listener - same endpoint port in same region": + usedListenerARNs[testListener1ARN] = true + case "conflict with external listener - same endpoint port in same region": + usedListenerARNs[testListener1ARN] = true + case "multiple listeners with same port - one own, one external": + usedListenerARNs[testListener1ARN] = true + case "multiple listeners in different regions - no conflicts": + usedListenerARNs[testListener1ARN] = true + usedListenerARNs[testListener2ARN] = true + case "complex scenario - multiple listeners, regions, with own and external conflicts": + usedListenerARNs[testListener1ARN] = true + usedListenerARNs[testListener2ARN] = true + } + + // Only check if the conflict is NOT with listeners we're using + if usedListenerARNs[testListener1ARN] { + assert.NotContains(t, sdkGroupARN, testListener1ARN, + "conflict should not be with our own listener (testListener1ARN) for port %d", port) + } + + if usedListenerARNs[testListener2ARN] { + assert.NotContains(t, sdkGroupARN, testListener2ARN, + "conflict should not be with our own listener (testListener2ARN) for port %d", port) + } + } + } + } + }) + } +} diff --git a/pkg/deploy/aga/errors.go b/pkg/deploy/aga/errors.go index bb8bb9e5c..2529133bf 100644 --- a/pkg/deploy/aga/errors.go +++ b/pkg/deploy/aga/errors.go @@ -9,6 +9,9 @@ const ( // DeploymentFailed is the error code when stack deployment fails DeploymentFailed = "DeploymentFailed" + + // Status reason constants + EndpointLoadFailed = "EndpointLoadFailed" ) // AcceleratorNotDisabledError is returned when an accelerator is not ready for deletion diff --git a/pkg/deploy/aga/listener_manager.go b/pkg/deploy/aga/listener_manager.go index b72d676af..f149f7868 100644 --- a/pkg/deploy/aga/listener_manager.go +++ b/pkg/deploy/aga/listener_manager.go @@ -23,13 +23,17 @@ type ListenerManager interface { // Delete deletes a listener. Delete(ctx context.Context, listenerARN string) error + + // ListEndpointGroups lists all endpoint groups for a given listener + ListEndpointGroups(ctx context.Context, listenerARN string) ([]agatypes.EndpointGroup, error) } // NewDefaultListenerManager constructs new defaultListenerManager. -func NewDefaultListenerManager(gaService services.GlobalAccelerator, logger logr.Logger) *defaultListenerManager { +func NewDefaultListenerManager(gaService services.GlobalAccelerator, endpointGroupManager EndpointGroupManager, logger logr.Logger) *defaultListenerManager { return &defaultListenerManager{ - gaService: gaService, - logger: logger, + gaService: gaService, + endpointGroupManager: endpointGroupManager, + logger: logger, } } @@ -37,8 +41,9 @@ var _ ListenerManager = &defaultListenerManager{} // defaultListenerManager is the default implementation for ListenerManager. type defaultListenerManager struct { - gaService services.GlobalAccelerator - logger logr.Logger + gaService services.GlobalAccelerator + endpointGroupManager EndpointGroupManager + logger logr.Logger } // convertPortRangesToSDK converts model port ranges to SDK port ranges @@ -164,11 +169,30 @@ func (m *defaultListenerManager) Update(ctx context.Context, resListener *agamod } func (m *defaultListenerManager) Delete(ctx context.Context, listenerARN string) error { - // TODO: This will be enhanced to check for and delete endpoint groups - // before deleting the listener (when those features are implemented) - m.logger.Info("Deleting listener", "listenerARN", listenerARN) + // Step 1: Delete all endpoint groups associated with this listener + endpointGroups, err := m.ListEndpointGroups(ctx, listenerARN) + if err != nil { + var apiErr *agatypes.ListenerNotFoundException + if errors.As(err, &apiErr) { + m.logger.Info("Listener not found, assuming already deleted", "listenerARN", listenerARN) + return nil + } + return fmt.Errorf("failed to list endpoint groups for listener: %w", err) + } + + for _, endpointGroup := range endpointGroups { + endpointGroupARN := aws.ToString(endpointGroup.EndpointGroupArn) + m.logger.Info("Deleting endpoint group for listener", "endpointGroupARN", endpointGroupARN, "listenerARN", listenerARN) + + if err := m.endpointGroupManager.Delete(ctx, endpointGroupARN); err != nil { + return fmt.Errorf("failed to delete endpoint group %s: %w", endpointGroupARN, err) + } + m.logger.Info("Deleted endpoint group for listener", "endpointGroupARN", endpointGroupARN, "listenerARN", listenerARN) + } + + // Step 2: Delete the listener deleteInput := &globalaccelerator.DeleteListenerInput{ ListenerArn: aws.String(listenerARN), } @@ -187,6 +211,15 @@ func (m *defaultListenerManager) Delete(ctx context.Context, listenerARN string) return nil } +// ListEndpointGroups lists all endpoint groups for a given listener +func (m *defaultListenerManager) ListEndpointGroups(ctx context.Context, listenerARN string) ([]agatypes.EndpointGroup, error) { + listInput := &globalaccelerator.ListEndpointGroupsInput{ + ListenerArn: aws.String(listenerARN), + } + + return m.gaService.ListEndpointGroupsAsList(ctx, listInput) +} + // isSDKListenerSettingsDrifted checks if the listener configuration has drifted from the desired state func (m *defaultListenerManager) isSDKListenerSettingsDrifted(resListener *agamodel.Listener, sdkListener *ListenerResource) bool { // Check if protocol differs diff --git a/pkg/deploy/aga/listener_manager_mocks.go b/pkg/deploy/aga/listener_manager_mocks.go index b6c1d60f6..93699ee3c 100644 --- a/pkg/deploy/aga/listener_manager_mocks.go +++ b/pkg/deploy/aga/listener_manager_mocks.go @@ -8,6 +8,7 @@ import ( context "context" reflect "reflect" + types "github.com/aws/aws-sdk-go-v2/service/globalaccelerator/types" gomock "github.com/golang/mock/gomock" aga0 "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" ) @@ -64,6 +65,21 @@ func (mr *MockListenerManagerMockRecorder) Delete(arg0, arg1 interface{}) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockListenerManager)(nil).Delete), arg0, arg1) } +// ListEndpointGroups mocks base method. +func (m *MockListenerManager) ListEndpointGroups(arg0 context.Context, arg1 string) ([]types.EndpointGroup, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListEndpointGroups", arg0, arg1) + ret0, _ := ret[0].([]types.EndpointGroup) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListEndpointGroups indicates an expected call of ListEndpointGroups. +func (mr *MockListenerManagerMockRecorder) ListEndpointGroups(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListEndpointGroups", reflect.TypeOf((*MockListenerManager)(nil).ListEndpointGroups), arg0, arg1) +} + // Update mocks base method. func (m *MockListenerManager) Update(arg0 context.Context, arg1 *aga0.Listener, arg2 *ListenerResource) (aga0.ListenerStatus, error) { m.ctrl.T.Helper() diff --git a/pkg/deploy/aga/listener_manager_test.go b/pkg/deploy/aga/listener_manager_test.go index a2c56ab17..ca6ea29f7 100644 --- a/pkg/deploy/aga/listener_manager_test.go +++ b/pkg/deploy/aga/listener_manager_test.go @@ -8,7 +8,9 @@ import ( "github.com/aws/aws-sdk-go-v2/service/globalaccelerator" agatypes "github.com/aws/aws-sdk-go-v2/service/globalaccelerator/types" "github.com/go-logr/logr" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" ) @@ -316,9 +318,22 @@ func Test_defaultListenerManager_isSDKListenerSettingsDrifted(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // Create listener manager + // Create mock controller + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create mock GlobalAccelerator service + mockGAService := services.NewMockGlobalAccelerator(ctrl) + + // Set up the mock behavior + mockGAService.EXPECT(). + ListEndpointGroupsAsList(gomock.Any(), gomock.Any()). + Return([]agatypes.EndpointGroup{}, nil). + AnyTimes() + + // Create manager with mock service m := &defaultListenerManager{ - gaService: nil, // Not needed for this test + gaService: mockGAService, logger: logr.Discard(), } diff --git a/pkg/deploy/aga/listener_synthesizer.go b/pkg/deploy/aga/listener_synthesizer.go index 7e003d140..87d4ad54b 100644 --- a/pkg/deploy/aga/listener_synthesizer.go +++ b/pkg/deploy/aga/listener_synthesizer.go @@ -9,6 +9,7 @@ import ( "github.com/go-logr/logr" "github.com/pkg/errors" "k8s.io/apimachinery/pkg/util/sets" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aga" "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" @@ -97,55 +98,15 @@ func (s *listenerSynthesizer) synthesizeListenersOnAccelerator(ctx context.Conte // Improved operation order to minimize traffic disruption: // 1. Delete only conflicting listeners (that would block updates) - // 2. Update matched listeners - // 3. Delete unneeded (non-conflicting) listeners - // 4. Create new listeners + // 2. Process all port overrides which may block listener updates + // - Remove endpoint port overrides that overlap with any new listener port ranges + // - Remove listener port overrides from existing listener which is outside desired listener port ranges + // 3. Update matched listeners + // 4. Delete unneeded (non-conflicting) listeners + // 5. Create new listeners // STEP 1: Find SDK listeners that have port conflicts with planned updates - var conflictingListeners []*ListenerResource - var nonConflictingListeners []*ListenerResource - - // Track which listeners have port conflicts with our updates - conflictMap := make(map[string][]*ListenerResource) - - // For each update we're planning to do... - for _, pair := range matchedResAndSDKListeners { - var conflicts []*ListenerResource - - // Check against all unmatched SDK listeners for conflicts - for _, sdkListener := range unmatchedSDKListeners { - if s.hasPortRangeConflict(pair.resListener, sdkListener) { - conflicts = append(conflicts, sdkListener) - } - } - - // If there are conflicts, add them to our conflict map - if len(conflicts) > 0 { - conflictMap[pair.resListener.ID()] = conflicts - } - } - - // Build list of conflicting and non-conflicting listeners - listenerIsConflicting := make(map[string]bool) - - // Add all listeners with port conflicts to the conflicting list - for _, conflicts := range conflictMap { - for _, listener := range conflicts { - arn := *listener.Listener.ListenerArn - if !listenerIsConflicting[arn] { - conflictingListeners = append(conflictingListeners, listener) - listenerIsConflicting[arn] = true - } - } - } - - // Sort remaining unmatched listeners into non-conflicting - for _, sdkListener := range unmatchedSDKListeners { - arn := *sdkListener.Listener.ListenerArn - if !listenerIsConflicting[arn] { - nonConflictingListeners = append(nonConflictingListeners, sdkListener) - } - } + conflictingListeners, nonConflictingListeners := s.findConflictingAndNonConflictingListeners(matchedResAndSDKListeners, unmatchedSDKListeners) // STEP 2: Execute operations in correct order @@ -164,6 +125,15 @@ func (s *listenerSynthesizer) synthesizeListenersOnAccelerator(ctx context.Conte } } + // Next, Process all port overrides BEFORE updating listeners + allResListenerPortRanges, allSDKListenersToProcess, updatePortRangesByListener := s.preparePortOverrideProcessing(resListeners, matchedResAndSDKListeners, nonConflictingListeners) + + // Consolidated port override processing + if err := s.ProcessEndpointGroupPortOverrides(ctx, allSDKListenersToProcess, allResListenerPortRanges, updatePortRangesByListener); err != nil { + s.logger.Error(err, "Failed to process endpoint group port overrides") + return err + } + // Next, update existing matched listeners (now conflict-free) for _, pair := range matchedResAndSDKListeners { s.logger.Info("Updating existing listener", @@ -288,7 +258,7 @@ func (s *listenerSynthesizer) matchResAndSDKListeners(resListeners []*agamodel.L // 3. Matches listeners with identical keys (exact protocol and port range matches) // 4. Returns matched pairs and remaining unmatched listeners // -// The key generation ensures that port ranges in different order but with identical +// The key generation ensures that port ranges in different order but with identical // values still match correctly. func (s *listenerSynthesizer) findExactMatches(resListeners []*agamodel.Listener, sdkListeners []*ListenerResource) ( []resAndSDKListenerPair, []*agamodel.Listener, []*ListenerResource) { @@ -468,6 +438,62 @@ func (s *listenerSynthesizer) findSimilarityMatches(resListeners []*agamodel.Lis return matchedPairs, unmatchedResListeners, unmatchedSDKListeners } +// findConflictingAndNonConflictingListeners separates unmatched SDK listeners into those that have +// port conflicts with planned updates and those that don't +func (s *listenerSynthesizer) findConflictingAndNonConflictingListeners( + matchedResAndSDKListeners []resAndSDKListenerPair, + unmatchedSDKListeners []*ListenerResource) ([]*ListenerResource, []*ListenerResource) { + + var conflictingListeners []*ListenerResource + var nonConflictingListeners []*ListenerResource + + // Track which listeners have port conflicts with our updates + conflictMap := make(map[string][]*ListenerResource) + + // For each update we're planning to do... + for _, pair := range matchedResAndSDKListeners { + var conflicts []*ListenerResource + + // Check against all unmatched SDK listeners for conflicts + for _, sdkListener := range unmatchedSDKListeners { + if s.hasPortRangeConflict(pair.resListener, sdkListener) { + conflicts = append(conflicts, sdkListener) + } + } + + // If there are conflicts, add them to our conflict map + if len(conflicts) > 0 { + conflictMap[pair.resListener.ID()] = conflicts + } + } + + // Build list of conflicting and non-conflicting listeners + listenerIsConflicting := make(map[string]bool) + + // Add all listeners with port conflicts to the conflicting list + for _, conflicts := range conflictMap { + for _, listener := range conflicts { + arn := *listener.Listener.ListenerArn + if !listenerIsConflicting[arn] { + conflictingListeners = append(conflictingListeners, listener) + listenerIsConflicting[arn] = true + } + } + } + + // Sort remaining unmatched listeners into non-conflicting + for _, sdkListener := range unmatchedSDKListeners { + arn := *sdkListener.Listener.ListenerArn + if !listenerIsConflicting[arn] { + nonConflictingListeners = append(nonConflictingListeners, sdkListener) + } + } + + return conflictingListeners, nonConflictingListeners +} + +// calculateSimilarityScore calculates how similar two listeners are +// Higher scores indicate better matches // calculateSimilarityScore calculates how similar two listeners are based on their attributes. // // The scoring system uses these components: @@ -597,3 +623,250 @@ func (s *listenerSynthesizer) hasPortRangeConflict(resListener *agamodel.Listene func (s *listenerSynthesizer) portRangesToString(portRanges []agamodel.PortRange) string { return ResPortRangesToString(portRanges) } + +// havePortRangesChanged checks if port ranges have changed between resource and SDK listener +func (s *listenerSynthesizer) havePortRangesChanged(resListener *agamodel.Listener, sdkListener *ListenerResource) bool { + if len(resListener.Spec.PortRanges) != len(sdkListener.Listener.PortRanges) { + return true + } + + // Build maps for easy comparison + resPortSet := s.makeResPortSet(resListener.Spec.PortRanges) + sdkPortSet := s.makeSDKPortSet(sdkListener.Listener.PortRanges) + + // If port sets have different sizes, they've changed + if len(resPortSet) != len(sdkPortSet) { + return true + } + + // Check if any port exists in one set but not the other + for port := range resPortSet { + if !sdkPortSet[port] { + return true + } + } + + for port := range sdkPortSet { + if !resPortSet[port] { + return true + } + } + + // Port ranges are the same + return false +} + +// ProcessEndpointGroupPortOverrides handles all port override validations and updates using a two-phase approach +// Phase 1: Collect all endpoint groups and analyze all port overrides for conflicts +// Phase 2: Execute updates for all identified conflicts +// +// The two-phase approach ensures consistent behavior regardless of processing order, since all +// analysis is completed before any modifications are made. +// +// It handles these validations: +// - Remove port overrides with endpoint port that overlap with any desired listener port ranges +// - Remove port overrides with listener port outside desired listener port ranges +func (s *listenerSynthesizer) ProcessEndpointGroupPortOverrides( + ctx context.Context, + listeners []*ListenerResource, + allListenerPortRanges []agamodel.PortRange, + updatePortRangesByListener map[string][]agamodel.PortRange) error { + + s.logger.V(1).Info("Processing all endpoint port overrides before updating listeners") + + // PHASE 1: Collection and Analysis + // Map of endpoint group ARN to its conflict information + type endpointGroupConflicts struct { + endpointGroup agatypes.EndpointGroup + listenerARN string + validPortOverrides []agatypes.PortOverride + invalidPortOverrides []agatypes.PortOverride + } + + // Store all conflicts to be resolved + conflictsByEndpointGroupARN := make(map[string]*endpointGroupConflicts) + + // Process each listener to collect all endpoint groups and analyze port overrides + for _, listener := range listeners { + listenerARN := awssdk.ToString(listener.Listener.ListenerArn) + + // List endpoint groups per listener + endpointGroups, err := s.listenerManager.ListEndpointGroups(ctx, listenerARN) + if err != nil { + return fmt.Errorf("failed to list endpoint groups for listener %s: %w", listenerARN, err) + } + + // Skip if no endpoint groups + if len(endpointGroups) == 0 { + continue + } + + // Get the updated port ranges for this listener if it's being updated + updatedPortRanges := updatePortRangesByListener[listenerARN] + + // Analyze each endpoint group's port overrides + for _, eg := range endpointGroups { + endpointGroupARN := awssdk.ToString(eg.EndpointGroupArn) + + // Skip if no port overrides to check + if eg.PortOverrides == nil || len(eg.PortOverrides) == 0 { + continue + } + + s.logger.V(1).Info("Analyzing endpoint group port overrides for conflicts", + "listenerARN", listenerARN, + "endpointGroupARN", endpointGroupARN) + + // Apply all validation rules and collect valid/invalid port overrides + validPortOverrides, invalidPortOverrides := s.processPortOverridesWithAllRules( + eg.PortOverrides, + allListenerPortRanges, + updatedPortRanges) + + // Only store conflicts if we found invalid overrides + if len(invalidPortOverrides) > 0 { + conflictsByEndpointGroupARN[endpointGroupARN] = &endpointGroupConflicts{ + endpointGroup: eg, + listenerARN: listenerARN, + validPortOverrides: validPortOverrides, + invalidPortOverrides: invalidPortOverrides, + } + } + } + } + + // PHASE 2: Execution - Update all endpoint groups with conflicts + // Process all conflicts + for endpointGroupARN := range conflictsByEndpointGroupARN { + conflictInfo := conflictsByEndpointGroupARN[endpointGroupARN] + + s.logger.V(1).Info("Updating endpoint group to remove conflicting port overrides", + "endpointGroupARN", endpointGroupARN, + "listenerARN", conflictInfo.listenerARN, + "conflictCount", len(conflictInfo.invalidPortOverrides)) + + // Update this endpoint group to remove the invalid port overrides + if err := s.updateEndpointGroupPortOverrides( + ctx, + conflictInfo.endpointGroup, + conflictInfo.validPortOverrides, + conflictInfo.invalidPortOverrides); err != nil { + return fmt.Errorf("failed to update endpoint group %s to remove conflicts: %w", + endpointGroupARN, err) + } + } + + return nil +} + +// processPortOverridesWithAllRules applies all validation rules to port overrides: +// 1. Endpoint ports must not overlap with any listener port ranges (if listener is being updated) +// 2. Listener ports must be within listener port ranges (if listener is being updated) +func (s *listenerSynthesizer) processPortOverridesWithAllRules( + portOverrides []agatypes.PortOverride, + allListenerPortRanges []agamodel.PortRange, + updatedListenerPortRanges []agamodel.PortRange) ([]agatypes.PortOverride, []agatypes.PortOverride) { + + validPortOverrides := make([]agatypes.PortOverride, 0) + invalidPortOverrides := make([]agatypes.PortOverride, 0) + for _, po := range portOverrides { + isValid := true + + // Rule 1: Endpoint port must not overlap with ANY listener port range + if aga.IsPortInRanges(awssdk.ToInt32(po.EndpointPort), allListenerPortRanges) { + isValid = false + s.logger.V(1).Info("Found port override with endpoint port that overlaps with a listener port range", + "endpointPort", awssdk.ToInt32(po.EndpointPort), + "listenerPort", awssdk.ToInt32(po.ListenerPort)) + } + + // Rule 2: If listener is being updated, listener port must be within updated port ranges + if isValid && len(updatedListenerPortRanges) > 0 && !aga.IsPortInRanges(awssdk.ToInt32(po.ListenerPort), updatedListenerPortRanges) { + isValid = false + s.logger.V(1).Info("Found port override with listener port outside updated listener port range", + "listenerPort", awssdk.ToInt32(po.ListenerPort), + "endpointPort", awssdk.ToInt32(po.EndpointPort)) + } + + // Add to appropriate collection based on validation result + if isValid { + validPortOverrides = append(validPortOverrides, po) + } else { + invalidPortOverrides = append(invalidPortOverrides, po) + } + } + + return validPortOverrides, invalidPortOverrides +} + +// preparePortOverrideProcessing collects all the port override processing requirements: +// - all res listener port ranges +// - all SDK listeners to process +// - map of listeners being updated with their new port ranges +func (s *listenerSynthesizer) preparePortOverrideProcessing( + resListeners []*agamodel.Listener, + matchedResAndSDKListeners []resAndSDKListenerPair, + nonConflictingListeners []*ListenerResource) ([]agamodel.PortRange, []*ListenerResource, map[string][]agamodel.PortRange) { + + // Collect all port ranges from resource listeners + var allResListenerPortRanges []agamodel.PortRange + for _, resListener := range resListeners { + allResListenerPortRanges = append(allResListenerPortRanges, resListener.Spec.PortRanges...) + } + + // Extract the SDK listeners from matchedResAndSDKListeners + var allSDKListenersToProcess []*ListenerResource + for _, pair := range matchedResAndSDKListeners { + allSDKListenersToProcess = append(allSDKListenersToProcess, pair.sdkListener) + } + + // Combine with nonConflictingListeners + allSDKListenersToProcess = append(allSDKListenersToProcess, nonConflictingListeners...) + + // Prepare map of listeners being updated with their new port ranges + updatePortRangesByListener := make(map[string][]agamodel.PortRange) + for _, pair := range matchedResAndSDKListeners { + if s.havePortRangesChanged(pair.resListener, pair.sdkListener) { + listenerARN := awssdk.ToString(pair.sdkListener.Listener.ListenerArn) + updatePortRangesByListener[listenerARN] = pair.resListener.Spec.PortRanges + } + } + + return allResListenerPortRanges, allSDKListenersToProcess, updatePortRangesByListener +} + +// updateEndpointGroupPortOverrides updates an endpoint group with valid port overrides +// and logs information about the removed invalid ones +func (s *listenerSynthesizer) updateEndpointGroupPortOverrides( + ctx context.Context, + endpointGroup agatypes.EndpointGroup, + validPortOverrides []agatypes.PortOverride, + invalidPortOverrides []agatypes.PortOverride) error { + + endpointGroupARN := awssdk.ToString(endpointGroup.EndpointGroupArn) + + // For logging purposes, record each removed override + for _, po := range invalidPortOverrides { + s.logger.V(1).Info("Removing port override", + "listenerPort", awssdk.ToInt32(po.ListenerPort), + "endpointPort", awssdk.ToInt32(po.EndpointPort), + "endpointGroupARN", endpointGroupARN) + } + + // Update the endpoint group with only valid port overrides + _, err := s.gaClient.UpdateEndpointGroupWithContext(ctx, &globalaccelerator.UpdateEndpointGroupInput{ + EndpointGroupArn: endpointGroup.EndpointGroupArn, + PortOverrides: validPortOverrides, + }) + + if err != nil { + return fmt.Errorf("failed to update endpoint group %s for port overrides to remove conflicts: %w", endpointGroupARN, err) + } + + s.logger.Info("Successfully updated endpoint group port overrides to remove conflicts", + "endpointGroupARN", endpointGroupARN, + "removedCount", len(invalidPortOverrides), + "remainingCount", len(validPortOverrides)) + + return nil +} diff --git a/pkg/deploy/aga/listener_synthesizer_test.go b/pkg/deploy/aga/listener_synthesizer_test.go index edd01feca..0358c87db 100644 --- a/pkg/deploy/aga/listener_synthesizer_test.go +++ b/pkg/deploy/aga/listener_synthesizer_test.go @@ -1,6 +1,11 @@ package aga import ( + "context" + "github.com/aws/aws-sdk-go-v2/service/globalaccelerator" + "github.com/golang/mock/gomock" + pkgaga "sigs.k8s.io/aws-load-balancer-controller/pkg/aga" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" "sort" "testing" @@ -1765,3 +1770,525 @@ func Test_listenerSynthesizer_generateSDKListenerKey(t *testing.T) { }) } } + +func Test_listenerSynthesizer_processPortOverridesWithAllRules(t *testing.T) { + tests := []struct { + name string + portOverrides []agatypes.PortOverride + allListenerPortRanges []agamodel.PortRange + updatedListenerPortRanges []agamodel.PortRange + wantValidCount int + wantInvalidCount int + wantInvalidPortsEndpoint []int32 + wantInvalidPortsListener []int32 + }{ + { + name: "empty port overrides", + portOverrides: []agatypes.PortOverride{}, + allListenerPortRanges: []agamodel.PortRange{{FromPort: 80, ToPort: 90}}, + updatedListenerPortRanges: []agamodel.PortRange{{FromPort: 80, ToPort: 85}}, + wantValidCount: 0, + wantInvalidCount: 0, + }, + { + name: "all port overrides valid", + portOverrides: []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(80), EndpointPort: awssdk.Int32(8080)}, + {ListenerPort: awssdk.Int32(81), EndpointPort: awssdk.Int32(8081)}, + }, + allListenerPortRanges: []agamodel.PortRange{{FromPort: 80, ToPort: 90}}, + updatedListenerPortRanges: []agamodel.PortRange{{FromPort: 80, ToPort: 85}}, + wantValidCount: 2, + wantInvalidCount: 0, + }, + { + name: "endpoint port overlaps with listener port range", + portOverrides: []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(80), EndpointPort: awssdk.Int32(85)}, // Invalid: endpoint port in listener range + {ListenerPort: awssdk.Int32(81), EndpointPort: awssdk.Int32(8081)}, // Valid + }, + allListenerPortRanges: []agamodel.PortRange{{FromPort: 80, ToPort: 90}}, + updatedListenerPortRanges: []agamodel.PortRange{{FromPort: 80, ToPort: 85}}, + wantValidCount: 1, + wantInvalidCount: 1, + wantInvalidPortsEndpoint: []int32{85}, + }, + { + name: "listener port outside updated range", + portOverrides: []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(80), EndpointPort: awssdk.Int32(8080)}, // Valid + {ListenerPort: awssdk.Int32(87), EndpointPort: awssdk.Int32(8087)}, // Invalid: listener port outside updated range + }, + allListenerPortRanges: []agamodel.PortRange{{FromPort: 80, ToPort: 90}}, + updatedListenerPortRanges: []agamodel.PortRange{{FromPort: 80, ToPort: 85}}, + wantValidCount: 1, + wantInvalidCount: 1, + wantInvalidPortsListener: []int32{87}, + }, + { + name: "multiple invalid conditions", + portOverrides: []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(80), EndpointPort: awssdk.Int32(8080)}, // Valid + {ListenerPort: awssdk.Int32(81), EndpointPort: awssdk.Int32(85)}, // Invalid: endpoint port in listener range + {ListenerPort: awssdk.Int32(87), EndpointPort: awssdk.Int32(8087)}, // Invalid: listener port outside updated range + }, + allListenerPortRanges: []agamodel.PortRange{{FromPort: 80, ToPort: 90}}, + updatedListenerPortRanges: []agamodel.PortRange{{FromPort: 80, ToPort: 85}}, + wantValidCount: 1, + wantInvalidCount: 2, + wantInvalidPortsEndpoint: []int32{85}, + wantInvalidPortsListener: []int32{87}, + }, + { + name: "no updated listener port ranges", + portOverrides: []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(80), EndpointPort: awssdk.Int32(8080)}, // Valid except for endpoint port check + {ListenerPort: awssdk.Int32(81), EndpointPort: awssdk.Int32(85)}, // Invalid: endpoint port in listener range + }, + allListenerPortRanges: []agamodel.PortRange{{FromPort: 80, ToPort: 90}}, + updatedListenerPortRanges: []agamodel.PortRange{}, + wantValidCount: 1, + wantInvalidCount: 1, + wantInvalidPortsEndpoint: []int32{85}, + }, + { + name: "multiple listener port ranges - endpoint port in one range", + portOverrides: []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(80), EndpointPort: awssdk.Int32(8080)}, // Valid + {ListenerPort: awssdk.Int32(443), EndpointPort: awssdk.Int32(443)}, // Invalid: endpoint port in another listener range + }, + allListenerPortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 90}, + {FromPort: 443, ToPort: 443}, + }, + updatedListenerPortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 90}, + {FromPort: 443, ToPort: 443}, + }, + wantValidCount: 1, + wantInvalidCount: 1, + wantInvalidPortsEndpoint: []int32{443}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &listenerSynthesizer{ + logger: logr.Discard(), + } + + validPortOverrides, invalidPortOverrides := s.processPortOverridesWithAllRules( + tt.portOverrides, + tt.allListenerPortRanges, + tt.updatedListenerPortRanges, + ) + + // Verify counts + assert.Equal(t, tt.wantValidCount, len(validPortOverrides), "valid port overrides count") + assert.Equal(t, tt.wantInvalidCount, len(invalidPortOverrides), "invalid port overrides count") + + // If specific invalid endpoint ports were expected, check those + if tt.wantInvalidPortsEndpoint != nil { + var actualInvalidEndpointPorts []int32 + for _, po := range invalidPortOverrides { + // If this port was invalidated because of endpoint port overlap + if pkgaga.IsPortInRanges(awssdk.ToInt32(po.EndpointPort), tt.allListenerPortRanges) { + actualInvalidEndpointPorts = append(actualInvalidEndpointPorts, awssdk.ToInt32(po.EndpointPort)) + } + } + assert.ElementsMatch(t, tt.wantInvalidPortsEndpoint, actualInvalidEndpointPorts, "invalid endpoint ports") + } + + // If specific invalid listener ports were expected, check those + if tt.wantInvalidPortsListener != nil && len(tt.updatedListenerPortRanges) > 0 { + var actualInvalidListenerPorts []int32 + for _, po := range invalidPortOverrides { + // If this port was invalidated because listener port was outside updated ranges + if !pkgaga.IsPortInRanges(awssdk.ToInt32(po.EndpointPort), tt.allListenerPortRanges) && + !pkgaga.IsPortInRanges(awssdk.ToInt32(po.ListenerPort), tt.updatedListenerPortRanges) { + actualInvalidListenerPorts = append(actualInvalidListenerPorts, awssdk.ToInt32(po.ListenerPort)) + } + } + assert.ElementsMatch(t, tt.wantInvalidPortsListener, actualInvalidListenerPorts, "invalid listener ports") + } + }) + } +} + +func TestListenerSynthesizer_ProcessEndpointGroupPortOverrides(t *testing.T) { + tests := []struct { + name string + listeners []*ListenerResource + allListenerPortRanges []agamodel.PortRange + updatePortRangesByListener map[string][]agamodel.PortRange + endpointGroups map[string][]agatypes.EndpointGroup + updateCalls map[string][]agatypes.PortOverride + expectError bool + }{ + { + name: "no endpoint groups - no updates needed", + listeners: []*ListenerResource{ + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:listener1"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}, + }, + }, + }, + }, + allListenerPortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + updatePortRangesByListener: map[string][]agamodel.PortRange{}, + endpointGroups: map[string][]agatypes.EndpointGroup{ + "arn:listener1": {}, // No endpoint groups + }, + updateCalls: map[string][]agatypes.PortOverride{}, + expectError: false, + }, + { + name: "no port overrides - no updates needed", + listeners: []*ListenerResource{ + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:listener1"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}, + }, + }, + }, + }, + allListenerPortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + updatePortRangesByListener: map[string][]agamodel.PortRange{}, + endpointGroups: map[string][]agatypes.EndpointGroup{ + "arn:listener1": { + { + EndpointGroupArn: awssdk.String("arn:endpointgroup1"), + PortOverrides: []agatypes.PortOverride{}, // No port overrides + }, + }, + }, + updateCalls: map[string][]agatypes.PortOverride{}, + expectError: false, + }, + { + name: "endpoint port overlaps with listener port range - should be removed", + listeners: []*ListenerResource{ + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:listener1"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(90)}, + }, + }, + }, + }, + allListenerPortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 90}, + }, + updatePortRangesByListener: map[string][]agamodel.PortRange{}, + endpointGroups: map[string][]agatypes.EndpointGroup{ + "arn:listener1": { + { + EndpointGroupArn: awssdk.String("arn:endpointgroup1"), + PortOverrides: []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(80), EndpointPort: awssdk.Int32(85)}, // Overlaps (endpoint port in listener range) + {ListenerPort: awssdk.Int32(80), EndpointPort: awssdk.Int32(8080)}, // Valid + }, + }, + }, + }, + updateCalls: map[string][]agatypes.PortOverride{ + "arn:endpointgroup1": { + {ListenerPort: awssdk.Int32(80), EndpointPort: awssdk.Int32(8080)}, // Only the valid one remains + }, + }, + expectError: false, + }, + { + name: "listener port outside updated ranges - should be removed", + listeners: []*ListenerResource{ + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:listener1"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(90)}, + }, + }, + }, + }, + allListenerPortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 90}, + }, + updatePortRangesByListener: map[string][]agamodel.PortRange{ + "arn:listener1": { + {FromPort: 80, ToPort: 85}, // Narrower range than current + }, + }, + endpointGroups: map[string][]agatypes.EndpointGroup{ + "arn:listener1": { + { + EndpointGroupArn: awssdk.String("arn:endpointgroup1"), + PortOverrides: []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(82), EndpointPort: awssdk.Int32(8082)}, // Valid - within updated range + {ListenerPort: awssdk.Int32(88), EndpointPort: awssdk.Int32(8088)}, // Invalid - outside updated range + }, + }, + }, + }, + updateCalls: map[string][]agatypes.PortOverride{ + "arn:endpointgroup1": { + {ListenerPort: awssdk.Int32(82), EndpointPort: awssdk.Int32(8082)}, // Only the valid one remains + }, + }, + expectError: false, + }, + { + name: "multiple listeners with multiple endpoint groups - consolidated processing", + listeners: []*ListenerResource{ + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:listener1"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(90)}, + }, + }, + }, + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:listener2"), + Protocol: agatypes.ProtocolUdp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(53), ToPort: awssdk.Int32(53)}, + }, + }, + }, + }, + allListenerPortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 90}, + {FromPort: 53, ToPort: 53}, + }, + updatePortRangesByListener: map[string][]agamodel.PortRange{ + "arn:listener1": { + {FromPort: 80, ToPort: 85}, // Narrowed range + }, + // No update for listener2 + }, + endpointGroups: map[string][]agatypes.EndpointGroup{ + "arn:listener1": { + { + EndpointGroupArn: awssdk.String("arn:endpointgroup1"), + PortOverrides: []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(82), EndpointPort: awssdk.Int32(82)}, // Invalid - endpoint port overlaps + {ListenerPort: awssdk.Int32(88), EndpointPort: awssdk.Int32(8088)}, // Invalid - listener port outside updated range + {ListenerPort: awssdk.Int32(85), EndpointPort: awssdk.Int32(8085)}, // Valid + }, + }, + }, + "arn:listener2": { + { + EndpointGroupArn: awssdk.String("arn:endpointgroup2"), + PortOverrides: []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(53), EndpointPort: awssdk.Int32(5353)}, // Valid + {ListenerPort: awssdk.Int32(53), EndpointPort: awssdk.Int32(53)}, // Invalid - endpoint port overlaps + }, + }, + }, + }, + updateCalls: map[string][]agatypes.PortOverride{ + "arn:endpointgroup1": { + {ListenerPort: awssdk.Int32(85), EndpointPort: awssdk.Int32(8085)}, // Only valid one remains + }, + "arn:endpointgroup2": { + {ListenerPort: awssdk.Int32(53), EndpointPort: awssdk.Int32(5353)}, // Only valid one remains + }, + }, + expectError: false, + }, + { + name: "multiple listeners each with both types of port override conflicts", + listeners: []*ListenerResource{ + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:listener1"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(90)}, + {FromPort: awssdk.Int32(443), ToPort: awssdk.Int32(443)}, + }, + }, + }, + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:listener2"), + Protocol: agatypes.ProtocolUdp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(53), ToPort: awssdk.Int32(53)}, + {FromPort: awssdk.Int32(8000), ToPort: awssdk.Int32(8010)}, + }, + }, + }, + }, + allListenerPortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 90}, + {FromPort: 443, ToPort: 443}, + {FromPort: 53, ToPort: 53}, + {FromPort: 8000, ToPort: 8010}, + }, + updatePortRangesByListener: map[string][]agamodel.PortRange{ + "arn:listener1": { + {FromPort: 80, ToPort: 85}, // Narrowed range + {FromPort: 443, ToPort: 443}, // Unchanged + }, + "arn:listener2": { + {FromPort: 53, ToPort: 53}, // Unchanged + {FromPort: 8000, ToPort: 8005}, // Narrowed range + }, + }, + endpointGroups: map[string][]agatypes.EndpointGroup{ + "arn:listener1": { + { + EndpointGroupArn: awssdk.String("arn:endpointgroup1"), + PortOverrides: []agatypes.PortOverride{ + // Both types of issues + {ListenerPort: awssdk.Int32(82), EndpointPort: awssdk.Int32(82)}, // Invalid - endpoint port overlaps + {ListenerPort: awssdk.Int32(88), EndpointPort: awssdk.Int32(8088)}, // Invalid - listener port outside updated range + {ListenerPort: awssdk.Int32(443), EndpointPort: awssdk.Int32(443)}, // Invalid - endpoint port overlaps + {ListenerPort: awssdk.Int32(85), EndpointPort: awssdk.Int32(8085)}, // Valid + {ListenerPort: awssdk.Int32(443), EndpointPort: awssdk.Int32(8443)}, // Valid + }, + }, + }, + "arn:listener2": { + { + EndpointGroupArn: awssdk.String("arn:endpointgroup2"), + PortOverrides: []agatypes.PortOverride{ + // Both types of issues + {ListenerPort: awssdk.Int32(53), EndpointPort: awssdk.Int32(53)}, // Invalid - endpoint port overlaps + {ListenerPort: awssdk.Int32(8008), EndpointPort: awssdk.Int32(9008)}, // Invalid - listener port outside updated range + {ListenerPort: awssdk.Int32(8000), EndpointPort: awssdk.Int32(8000)}, // Invalid - endpoint port overlaps + {ListenerPort: awssdk.Int32(53), EndpointPort: awssdk.Int32(5353)}, // Valid + {ListenerPort: awssdk.Int32(8005), EndpointPort: awssdk.Int32(9005)}, // Valid + }, + }, + }, + }, + updateCalls: map[string][]agatypes.PortOverride{ + "arn:endpointgroup1": { + // Only valid port overrides remain + {ListenerPort: awssdk.Int32(85), EndpointPort: awssdk.Int32(8085)}, + {ListenerPort: awssdk.Int32(443), EndpointPort: awssdk.Int32(8443)}, + }, + "arn:endpointgroup2": { + // Only valid port overrides remain + {ListenerPort: awssdk.Int32(53), EndpointPort: awssdk.Int32(5353)}, + {ListenerPort: awssdk.Int32(8005), EndpointPort: awssdk.Int32(9005)}, + }, + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create mocks + // Setup controller and mocks + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockGaClient := services.NewMockGlobalAccelerator(ctrl) + + // Create mock listener manager using gomock + mockListenerManager := NewMockListenerManager(ctrl) + + // Setup expectations for mockListenerManager.ListEndpointGroups + for listenerARN, endpointGroups := range tt.endpointGroups { + mockListenerManager.EXPECT(). + ListEndpointGroups(gomock.Any(), listenerARN). + Return(endpointGroups, nil) + } + + // Track which endpoint groups have been updated + updatedEndpointGroups := make(map[string]bool) + + // Setup expectations for mockGaClient.UpdateEndpointGroupWithContext + for _, expectedPortMapping := range tt.updateCalls { + // If we expect an update for this endpoint group, mock the call + if len(expectedPortMapping) > 0 { + mockGaClient.EXPECT(). + UpdateEndpointGroupWithContext(gomock.Any(), gomock.Any()). + AnyTimes(). + DoAndReturn( + func(_ context.Context, input *globalaccelerator.UpdateEndpointGroupInput) (*globalaccelerator.UpdateEndpointGroupOutput, error) { + // Mark this endpoint group as updated + arn := awssdk.ToString(input.EndpointGroupArn) + updatedEndpointGroups[arn] = true + + // Get the expected port mapping for this endpoint group + expectedMapping, exists := tt.updateCalls[arn] + assert.True(t, exists, "Unexpected endpoint group update: %s", arn) + + // Verify port overrides count + assert.Equal(t, len(expectedMapping), len(input.PortOverrides)) + + // Create map of actual port overrides + actualMapping := make(map[int32]int32) + for _, po := range input.PortOverrides { + actualMapping[awssdk.ToInt32(po.ListenerPort)] = awssdk.ToInt32(po.EndpointPort) + } + + // Convert actual port overrides to a map for easier comparison + actualMap := make(map[int32]int32) + for _, po := range input.PortOverrides { + actualMap[awssdk.ToInt32(po.ListenerPort)] = awssdk.ToInt32(po.EndpointPort) + } + + // Create expected map + expectedMap := make(map[int32]int32) + for _, po := range expectedMapping { + expectedMap[awssdk.ToInt32(po.ListenerPort)] = awssdk.ToInt32(po.EndpointPort) + } + + // Verify the mappings match + assert.Equal(t, expectedMap, actualMap) + + return &globalaccelerator.UpdateEndpointGroupOutput{}, nil + }) + } + } + + // Create the synthesizer with the mocks + s := NewListenerSynthesizer(mockGaClient, mockListenerManager, logr.Discard(), nil) + + // Call the method under test + err := s.ProcessEndpointGroupPortOverrides( + context.Background(), + tt.listeners, + tt.allListenerPortRanges, + tt.updatePortRangesByListener, + ) + + // Assert results + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + + // Verify all expected endpoint group updates happened + for endpointGroupARN, expectedMapping := range tt.updateCalls { + if len(expectedMapping) > 0 { + assert.True(t, updatedEndpointGroups[endpointGroupARN], + "Expected endpoint group %s to be updated", endpointGroupARN) + } + } + } + }) + } +} diff --git a/pkg/deploy/aga/stack_deployer.go b/pkg/deploy/aga/stack_deployer.go index 8300882ca..a023758b9 100644 --- a/pkg/deploy/aga/stack_deployer.go +++ b/pkg/deploy/aga/stack_deployer.go @@ -32,25 +32,25 @@ func NewDefaultStackDeployer(cloud services.Cloud, config config.ControllerConfi // Create actual managers agaTaggingManager := NewDefaultTaggingManager(cloud.GlobalAccelerator(), cloud.RGT(), logger) - listenerManager := NewDefaultListenerManager(cloud.GlobalAccelerator(), logger) + endpointGroupManager := NewDefaultEndpointGroupManager(cloud.GlobalAccelerator(), logger) + listenerManager := NewDefaultListenerManager(cloud.GlobalAccelerator(), endpointGroupManager, logger) acceleratorManager := NewDefaultAcceleratorManager(cloud.GlobalAccelerator(), trackingProvider, agaTaggingManager, listenerManager, config.ExternalManagedTags, logger) // TODO: Create other managers when they are implemented - // endpointGroupManager := NewDefaultEndpointGroupManager(cloud.GlobalAccelerator(), trackingProvider, agaTaggingManager, config.ExternalManagedTags, logger) // endpointManager := NewDefaultEndpointManager(cloud.GlobalAccelerator(), logger) return &defaultStackDeployer{ - cloud: cloud, - controllerConfig: config, - trackingProvider: trackingProvider, - featureGates: config.FeatureGates, - logger: logger, - metricsCollector: metricsCollector, - controllerName: controllerName, - agaTaggingManager: agaTaggingManager, - acceleratorManager: acceleratorManager, - listenerManager: listenerManager, + cloud: cloud, + controllerConfig: config, + trackingProvider: trackingProvider, + featureGates: config.FeatureGates, + logger: logger, + metricsCollector: metricsCollector, + controllerName: controllerName, + agaTaggingManager: agaTaggingManager, + acceleratorManager: acceleratorManager, + listenerManager: listenerManager, + endpointGroupManager: endpointGroupManager, // TODO: Set other managers when implemented - // endpointGroupManager: endpointGroupManager, // endpointManager: endpointManager, } } @@ -68,11 +68,11 @@ type defaultStackDeployer struct { controllerName string // Actual managers - agaTaggingManager TaggingManager - acceleratorManager AcceleratorManager - listenerManager ListenerManager + agaTaggingManager TaggingManager + acceleratorManager AcceleratorManager + listenerManager ListenerManager + endpointGroupManager EndpointGroupManager // TODO: Add other managers when implemented - // endpointGroupManager EndpointGroupManager // endpointManager EndpointManager } @@ -92,8 +92,8 @@ func (d *defaultStackDeployer) Deploy(ctx context.Context, stack core.Stack, met synthesizers = append(synthesizers, NewAcceleratorSynthesizer(d.cloud.GlobalAccelerator(), d.trackingProvider, d.agaTaggingManager, d.acceleratorManager, d.logger, d.featureGates, stack), NewListenerSynthesizer(d.cloud.GlobalAccelerator(), d.listenerManager, d.logger, stack), + NewEndpointGroupSynthesizer(d.cloud.GlobalAccelerator(), d.endpointGroupManager, d.logger, stack), // TODO: Add other synthesizers when managers are implemented - // NewEndpointGroupSynthesizer(d.cloud.GlobalAccelerator(), d.trackingProvider, d.agaTaggingManager, d.endpointGroupManager, d.logger, d.featureGates, stack), // NewEndpointSynthesizer(d.cloud.GlobalAccelerator(), d.trackingProvider, d.endpointManager, d.logger, d.featureGates, stack), ) diff --git a/pkg/k8s/events.go b/pkg/k8s/events.go index efad8aa26..42c966f2f 100644 --- a/pkg/k8s/events.go +++ b/pkg/k8s/events.go @@ -52,6 +52,7 @@ const ( GlobalAcceleratorEventReasonFailedUpdateStatus = "FailedUpdateStatus" GlobalAcceleratorEventReasonFailedCleanup = "FailedCleanup" GlobalAcceleratorEventReasonFailedBuildModel = "FailedBuildModel" + GlobalAcceleratorEventReasonFailedEndpointLoad = "FailedEndpointLoad" GlobalAcceleratorEventReasonFailedDeploy = "FailedDeploy" GlobalAcceleratorEventReasonSuccessfullyReconciled = "SuccessfullyReconciled" ) diff --git a/pkg/model/aga/endpoint_group.go b/pkg/model/aga/endpoint_group.go new file mode 100644 index 000000000..dfa5e1ba0 --- /dev/null +++ b/pkg/model/aga/endpoint_group.go @@ -0,0 +1,102 @@ +package aga + +import ( + "context" + "github.com/pkg/errors" + "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" +) + +const ( + // ResourceTypeEndpointGroup is the resource type for Global Accelerator EndpointGroup + ResourceTypeEndpointGroup = "AWS::GlobalAccelerator::EndpointGroup" +) + +var _ core.Resource = &EndpointGroup{} + +// EndpointGroup represents an AWS Global Accelerator EndpointGroup. +type EndpointGroup struct { + core.ResourceMeta `json:"-"` + + // desired state of EndpointGroup + Spec EndpointGroupSpec `json:"spec"` + + // observed state of EndpointGroup + // +optional + Status *EndpointGroupStatus `json:"status,omitempty"` + + // reference to Listener resource + Listener *Listener `json:"-"` +} + +// NewEndpointGroup constructs new EndpointGroup resource. +func NewEndpointGroup(stack core.Stack, id string, spec EndpointGroupSpec, listener *Listener) *EndpointGroup { + endpointGroup := &EndpointGroup{ + ResourceMeta: core.NewResourceMeta(stack, ResourceTypeEndpointGroup, id), + Spec: spec, + Status: nil, + Listener: listener, + } + stack.AddResource(endpointGroup) + endpointGroup.registerDependencies(stack) + return endpointGroup +} + +// SetStatus sets the EndpointGroup's status +func (eg *EndpointGroup) SetStatus(status EndpointGroupStatus) { + eg.Status = &status +} + +// EndpointGroupARN returns The Amazon Resource Name (ARN) of the endpoint group. +func (eg *EndpointGroup) EndpointGroupARN() core.StringToken { + return core.NewResourceFieldStringToken(eg, "status/endpointGroupARN", + func(ctx context.Context, res core.Resource, fieldPath string) (s string, err error) { + endpointGroup := res.(*EndpointGroup) + if endpointGroup.Status == nil { + return "", errors.Errorf("EndpointGroup is not fulfilled yet: %v", endpointGroup.ID()) + } + return endpointGroup.Status.EndpointGroupARN, nil + }, + ) +} + +// register dependencies for EndpointGroup. +func (eg *EndpointGroup) registerDependencies(stack core.Stack) { + // EndpointGroup depends on its Listener + stack.AddDependency(eg, eg.Listener) +} + +// PortOverride defines the port override for Global Accelerator endpoint groups. +type PortOverride struct { + // ListenerPort is the listener port that you want to map to a specific endpoint port. + ListenerPort int32 `json:"listenerPort"` + + // EndpointPort is the endpoint port that you want traffic to be routed to. + EndpointPort int32 `json:"endpointPort"` +} + +// EndpointGroupSpec defines the desired state of EndpointGroup +type EndpointGroupSpec struct { + // ListenerARN is the ARN of the listener for the endpoint group + ListenerARN core.StringToken `json:"listenerARN"` + + // Region is the AWS Region where the endpoint group is located. + Region string `json:"region"` + + // TrafficDialPercentage is the percentage of traffic to send to an AWS Region. + // +optional + TrafficDialPercentage *int32 `json:"trafficDialPercentage,omitempty"` + + // PortOverrides is a list of endpoint port overrides. + // +optional + PortOverrides []PortOverride `json:"portOverrides,omitempty"` + + // EndpointConfigurations is a list of endpoint configurations for the endpoint group. + // +optional + // This field is not implemented in the initial version as it will be part of a separate endpoint builder. +} + +// EndpointGroupStatus defines the observed state of EndpointGroup +type EndpointGroupStatus struct { + // EndpointGroupARN is the Amazon Resource Name (ARN) of the endpoint group. + EndpointGroupARN string `json:"endpointGroupARN"` +} diff --git a/pkg/status/aga/status_updater.go b/pkg/status/aga/status_updater.go index 0fd2f046d..75cbfb827 100644 --- a/pkg/status/aga/status_updater.go +++ b/pkg/status/aga/status_updater.go @@ -188,7 +188,7 @@ func (u *defaultStatusUpdater) UpdateStatusFailure(ctx context.Context, ga *v1be Status: metav1.ConditionFalse, LastTransitionTime: metav1.Now(), Reason: reason, - Message: message, + Message: "Reconciliation failed. See events and controller logs for details", } conditionUpdated := u.updateCondition(&ga.Status.Conditions, failureCondition) diff --git a/pkg/status/aga/status_updater_test.go b/pkg/status/aga/status_updater_test.go index 7ea74b968..440fa4917 100644 --- a/pkg/status/aga/status_updater_test.go +++ b/pkg/status/aga/status_updater_test.go @@ -291,7 +291,7 @@ func Test_defaultStatusUpdater_UpdateStatusFailure(t *testing.T) { }, }, reason: "ProvisioningFailed", - message: "Failed to provision accelerator: validation error", + message: "Reconciliation failed. See events and controller logs for details", validateStatus: func(t *testing.T, ga *v1beta1.GlobalAccelerator) { // Check that observed generation was updated assert.NotNil(t, ga.Status.ObservedGeneration) @@ -303,7 +303,7 @@ func Test_defaultStatusUpdater_UpdateStatusFailure(t *testing.T) { assert.Equal(t, ConditionTypeReady, condition.Type) assert.Equal(t, metav1.ConditionFalse, condition.Status) assert.Equal(t, "ProvisioningFailed", condition.Reason) - assert.Equal(t, "Failed to provision accelerator: validation error", condition.Message) + assert.Equal(t, "Reconciliation failed. See events and controller logs for details", condition.Message) }, }, { @@ -328,7 +328,7 @@ func Test_defaultStatusUpdater_UpdateStatusFailure(t *testing.T) { }, }, reason: "NewError", - message: "New error message", + message: "Reconciliation failed. See events and controller logs for details", validateStatus: func(t *testing.T, ga *v1beta1.GlobalAccelerator) { // Check that observed generation was updated assert.NotNil(t, ga.Status.ObservedGeneration) @@ -340,7 +340,7 @@ func Test_defaultStatusUpdater_UpdateStatusFailure(t *testing.T) { assert.Equal(t, ConditionTypeReady, condition.Type) assert.Equal(t, metav1.ConditionFalse, condition.Status) assert.Equal(t, "NewError", condition.Reason) - assert.Equal(t, "New error message", condition.Message) + assert.Equal(t, "Reconciliation failed. See events and controller logs for details", condition.Message) }, }, { @@ -359,13 +359,13 @@ func Test_defaultStatusUpdater_UpdateStatusFailure(t *testing.T) { Status: metav1.ConditionFalse, LastTransitionTime: metav1.Now(), Reason: "SameError", - Message: "Same error message", + Message: "Reconciliation failed. See events and controller logs for details", }, }, }, }, reason: "SameError", - message: "Same error message", + message: "Reconciliation failed. See events and controller logs for details", validateStatus: func(t *testing.T, ga *v1beta1.GlobalAccelerator) { // Status should be unchanged assert.NotNil(t, ga.Status.ObservedGeneration) diff --git a/scripts/gen_mocks.sh b/scripts/gen_mocks.sh index 413cbec8d..73338c9e7 100755 --- a/scripts/gen_mocks.sh +++ b/scripts/gen_mocks.sh @@ -26,6 +26,7 @@ $MOCKGEN -package=networking -destination=./pkg/networking/vpc_info_provider_moc $MOCKGEN -package=networking -destination=./pkg/networking/backend_sg_provider_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/networking BackendSGProvider $MOCKGEN -package=networking -destination=./pkg/networking/security_group_resolver_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/networking SecurityGroupResolver $MOCKGEN -package=aga -destination=./pkg/deploy/aga/accelerator_manager_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/deploy/aga AcceleratorManager +$MOCKGEN -package=aga -destination=./pkg/deploy/aga/endpoint_group_manager_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/deploy/aga EndpointGroupManager $MOCKGEN -package=aga -destination=./pkg/deploy/aga/listener_manager_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/deploy/aga ListenerManager $MOCKGEN -package=aga -destination=./pkg/deploy/aga/tagging_manager_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/deploy/aga TaggingManager $MOCKGEN -package=certs -destination=./pkg/certs/cert_discovery_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/certs CertDiscovery