diff --git a/controllers/aga/globalaccelerator_controller.go b/controllers/aga/globalaccelerator_controller.go index a3819039c..0f4ed7f1c 100644 --- a/controllers/aga/globalaccelerator_controller.go +++ b/controllers/aga/globalaccelerator_controller.go @@ -68,9 +68,6 @@ const ( requeueMessage = "Monitoring provisioning state" statusUpdateRequeueTime = 1 * time.Minute - // Status reason constants - EndpointLoadFailed = "EndpointLoadFailed" - // Metric stage constants MetricStageFetchGlobalAccelerator = "fetch_globalAccelerator" MetricStageAddFinalizers = "add_finalizers" @@ -90,7 +87,7 @@ const ( func NewGlobalAcceleratorReconciler(k8sClient client.Client, eventRecorder record.EventRecorder, finalizerManager k8s.FinalizerManager, config config.ControllerConfig, cloud services.Cloud, logger logr.Logger, metricsCollector lbcmetrics.MetricCollector, reconcileCounters *metricsutil.ReconcileCounters) *globalAcceleratorReconciler { // Create tracking provider - trackingProvider := tracking.NewDefaultProvider(agaTagPrefix, config.ClusterName, tracking.WithRegion(config.AWSConfig.Region)) + trackingProvider := tracking.NewDefaultProvider(agaTagPrefix, config.ClusterName, tracking.WithRegion(cloud.Region())) // Create model builder agaModelBuilder := aga.NewDefaultModelBuilder( @@ -99,7 +96,7 @@ func NewGlobalAcceleratorReconciler(k8sClient client.Client, eventRecorder recor trackingProvider, config.FeatureGates, config.ClusterName, - config.AWSConfig.Region, + cloud.Region(), config.DefaultTags, config.ExternalManagedTags, logger.WithName("aga-model-builder"), @@ -272,11 +269,11 @@ func (r *globalAcceleratorReconciler) buildModel(ctx context.Context, ga *agaapi func (r *globalAcceleratorReconciler) reconcileGlobalAcceleratorResources(ctx context.Context, ga *agaapi.GlobalAccelerator) error { r.logger.Info("Reconciling GlobalAccelerator resources", "globalAccelerator", k8s.NamespacedName(ga)) - // Get all endpoints from GA - endpoints := aga.GetAllEndpointsFromGA(ga) + // Get all desired endpoints from GA + endpoints := aga.GetAllDesiredEndpointsFromGA(ga) // Track referenced endpoints - r.referenceTracker.UpdateReferencesForGA(ga, endpoints) + r.referenceTracker.UpdateDesiredEndpointReferencesForGA(ga, endpoints) // Update resource watches with the endpointResourcesManager r.endpointResourcesManager.MonitorEndpointResources(ga, endpoints) @@ -285,10 +282,10 @@ func (r *globalAcceleratorReconciler) reconcileGlobalAcceleratorResources(ctx co _, fatalErrors := r.endpointLoader.LoadEndpoints(ctx, ga, endpoints) if len(fatalErrors) > 0 { err := fmt.Errorf("failed to load endpoints: %v", fatalErrors[0]) - r.logger.Error(err, "Fatal error loading endpoints") - + r.eventRecorder.Event(ga, corev1.EventTypeWarning, k8s.GlobalAcceleratorEventReasonFailedEndpointLoad, fmt.Sprintf("Failed to reconcile due to %v", err)) + r.logger.Error(err, fmt.Sprintf("fatal error loading endpoints for %v", k8s.NamespacedName(ga))) // Handle other endpoint loading errors - if statusErr := r.statusUpdater.UpdateStatusFailure(ctx, ga, EndpointLoadFailed, err.Error()); statusErr != nil { + if statusErr := r.statusUpdater.UpdateStatusFailure(ctx, ga, agadeploy.EndpointLoadFailed, err.Error()); statusErr != nil { r.logger.Error(statusErr, "Failed to update GlobalAccelerator status after endpoint load failure") } return err @@ -302,6 +299,8 @@ func (r *globalAcceleratorReconciler) reconcileGlobalAcceleratorResources(ctx co } r.metricsCollector.ObserveControllerReconcileLatency(controllerName, MetricStageBuildModel, buildModelFn) if err != nil { + r.eventRecorder.Event(ga, corev1.EventTypeWarning, k8s.GatewayEventReasonFailedBuildModel, fmt.Sprintf("Failed to build model: %v", err)) + r.logger.Error(err, fmt.Sprintf("Failed to build model for: %v", k8s.NamespacedName(ga))) // Update status to indicate model building failure if statusErr := r.statusUpdater.UpdateStatusFailure(ctx, ga, agadeploy.ModelBuildFailed, fmt.Sprintf("Failed to build model: %v", err)); statusErr != nil { r.logger.Error(statusErr, "Failed to update GlobalAccelerator status after model build failure") @@ -316,7 +315,7 @@ func (r *globalAcceleratorReconciler) reconcileGlobalAcceleratorResources(ctx co r.metricsCollector.ObserveControllerReconcileLatency(controllerName, MetricStageDeployStack, deployStackFn) if err != nil { r.eventRecorder.Event(ga, corev1.EventTypeWarning, k8s.GlobalAcceleratorEventReasonFailedDeploy, fmt.Sprintf("Failed to deploy stack due to %v", err)) - + r.logger.Error(err, fmt.Sprintf("Failed to deploy stack for: %v", k8s.NamespacedName(ga))) // Update status to indicate deployment failure if statusErr := r.statusUpdater.UpdateStatusFailure(ctx, ga, agadeploy.DeploymentFailed, fmt.Sprintf("Failed to deploy stack: %v", err)); statusErr != nil { r.logger.Error(statusErr, "Failed to update GlobalAccelerator status after deployment failure") diff --git a/go.mod b/go.mod index c0ea412fd..7ea5607cd 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 + github.com/hashicorp/golang-lru v1.0.2 github.com/onsi/ginkgo/v2 v2.23.3 github.com/onsi/gomega v1.37.0 github.com/pkg/errors v0.9.1 @@ -104,7 +105,6 @@ require ( github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect - github.com/hashicorp/golang-lru v1.0.2 // indirect github.com/huandu/xstrings v1.5.0 // indirect github.com/imkira/go-interpol v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/main.go b/main.go index 269f978bb..ce0ec948a 100644 --- a/main.go +++ b/main.go @@ -241,7 +241,7 @@ func main() { } // Setup GlobalAccelerator controller only if enabled - if aga.IsAGAControllerEnabled(controllerCFG.FeatureGates, controllerCFG.AWSConfig.Region) { + if aga.IsAGAControllerEnabled(controllerCFG.FeatureGates, cloud.Region()) { agaReconciler := agacontroller.NewGlobalAcceleratorReconciler(mgr.GetClient(), mgr.GetEventRecorderFor("globalAccelerator"), finalizerManager, controllerCFG, cloud, ctrl.Log.WithName("controllers").WithName("globalAccelerator"), lbcMetricsCollector, reconcileCounters) if err := agaReconciler.SetupWithManager(ctx, mgr, clientSet); err != nil { @@ -442,7 +442,7 @@ func main() { networkingwebhook.NewIngressValidator(mgr.GetClient(), controllerCFG.IngressConfig, ctrl.Log, lbcMetricsCollector).SetupWithManager(mgr) // Setup GlobalAccelerator validator only if enabled - if aga.IsAGAControllerEnabled(controllerCFG.FeatureGates, controllerCFG.AWSConfig.Region) { + if aga.IsAGAControllerEnabled(controllerCFG.FeatureGates, cloud.Region()) { agawebhook.NewGlobalAcceleratorValidator(ctrl.Log, lbcMetricsCollector).SetupWithManager(mgr) } //+kubebuilder:scaffold:builder diff --git a/pkg/aga/endpoint_utils.go b/pkg/aga/endpoint_utils.go index b50ad6d9f..d5c75ca3f 100644 --- a/pkg/aga/endpoint_utils.go +++ b/pkg/aga/endpoint_utils.go @@ -28,8 +28,8 @@ type EndpointReference struct { Endpoint *agaapi.GlobalAcceleratorEndpoint } -// GetAllEndpointsFromGA extracts all endpoint references from a GlobalAccelerator resource -func GetAllEndpointsFromGA(ga *agaapi.GlobalAccelerator) []EndpointReference { +// GetAllDesiredEndpointsFromGA extracts all endpoint references from a GlobalAccelerator resource +func GetAllDesiredEndpointsFromGA(ga *agaapi.GlobalAccelerator) []EndpointReference { if ga == nil || ga.Spec.Listeners == nil { return nil } diff --git a/pkg/aga/endpoint_utils_test.go b/pkg/aga/endpoint_utils_test.go index e8401ace1..bbebe799d 100644 --- a/pkg/aga/endpoint_utils_test.go +++ b/pkg/aga/endpoint_utils_test.go @@ -180,7 +180,7 @@ func TestGetAllEndpointsFromGA(t *testing.T) { } } - result := GetAllEndpointsFromGA(tt.ga) + result := GetAllDesiredEndpointsFromGA(tt.ga) // Compare lengths assert.Equal(t, len(tt.expected), len(result)) diff --git a/pkg/aga/model_build_endpoint_group.go b/pkg/aga/model_build_endpoint_group.go new file mode 100644 index 000000000..90d987b86 --- /dev/null +++ b/pkg/aga/model_build_endpoint_group.go @@ -0,0 +1,250 @@ +package aga + +import ( + "context" + "fmt" + awssdk "github.com/aws/aws-sdk-go-v2/aws" + agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" + agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" + "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" +) + +// endpointGroupBuilder builds EndpointGroup model resources +type endpointGroupBuilder interface { + // Build builds all endpoint groups for all listeners + Build(ctx context.Context, stack core.Stack, listeners []*agamodel.Listener, listenerConfigs []agaapi.GlobalAcceleratorListener) ([]*agamodel.EndpointGroup, error) + + // buildEndpointGroupsForListener builds endpoint groups for a specific listener + buildEndpointGroupsForListener(ctx context.Context, stack core.Stack, listener *agamodel.Listener, endpointGroups []agaapi.GlobalAcceleratorEndpointGroup, listenerIndex int) ([]*agamodel.EndpointGroup, error) +} + +// NewEndpointGroupBuilder constructs new endpointGroupBuilder +func NewEndpointGroupBuilder(clusterRegion string) endpointGroupBuilder { + return &defaultEndpointGroupBuilder{ + clusterRegion: clusterRegion, + } +} + +var _ endpointGroupBuilder = &defaultEndpointGroupBuilder{} + +type defaultEndpointGroupBuilder struct { + clusterRegion string +} + +// Build builds EndpointGroup model resources +func (b *defaultEndpointGroupBuilder) Build(ctx context.Context, stack core.Stack, listeners []*agamodel.Listener, listenerConfigs []agaapi.GlobalAcceleratorListener) ([]*agamodel.EndpointGroup, error) { + if listeners == nil || len(listeners) == 0 { + return nil, nil + } + + var result []*agamodel.EndpointGroup + + // Create a map of all listener port ranges + listenerPortRanges := make(map[string][]agamodel.PortRange) // Maps listener ID to its port ranges + for _, listener := range listeners { + listenerPortRanges[listener.ID()] = listener.Spec.PortRanges + } + + for i, listener := range listeners { + listenerConfig := listenerConfigs[i] + if listenerConfig.EndpointGroups == nil { + continue + } + + listenerEndpointGroups, err := b.buildEndpointGroupsForListener(ctx, stack, listener, *listenerConfig.EndpointGroups, i) + if err != nil { + return nil, err + } + result = append(result, listenerEndpointGroups...) + } + + // Validate endpoint ports in all port overrides across all listeners + if err := b.validateEndpointPortOverridesCrossListeners(result, listenerPortRanges); err != nil { + return nil, err + } + + return result, nil +} + +// validateEndpointPortOverridesCrossListeners performs validations for endpoint port overrides across all listeners +func (b *defaultEndpointGroupBuilder) validateEndpointPortOverridesCrossListeners(endpointGroups []*agamodel.EndpointGroup, listenerPortRanges map[string][]agamodel.PortRange) error { + // Track endpoint port usage across all endpoint groups + endpointPortUsage := make(map[int32]string) // Maps endpoint port to listener ID + + // Check all endpoint groups for port overrides + for _, endpointGroup := range endpointGroups { + listenerID := endpointGroup.Listener.ID() + + for _, portOverride := range endpointGroup.Spec.PortOverrides { + endpointPort := portOverride.EndpointPort + + // Rule 1: Check if endpoint port is within any listener's port range + if err := b.validateEndpointPortOverridesWithinListener(endpointPort, listenerPortRanges); err != nil { + return err + } + + // Rule 2: Check for duplicate endpoint port usage across listeners + if existingListenerID, exists := endpointPortUsage[endpointPort]; exists && existingListenerID != listenerID { + return fmt.Errorf("duplicate endpoint port %d: the same endpoint port cannot be used in port overrides from different listeners (used in %s and %s)", + endpointPort, existingListenerID, listenerID) + } + + // Register this endpoint port usage + endpointPortUsage[endpointPort] = listenerID + } + } + + return nil +} + +// validateEndpointPortOverridesWithinListener checks if an endpoint port is within any listener's port range +func (b *defaultEndpointGroupBuilder) validateEndpointPortOverridesWithinListener(endpointPort int32, listenerPortRanges map[string][]agamodel.PortRange) error { + for listenerID, portRanges := range listenerPortRanges { + if IsPortInRanges(endpointPort, portRanges) { + // Find the specific port range for the error message + for _, portRange := range portRanges { + if endpointPort >= portRange.FromPort && endpointPort <= portRange.ToPort { + return fmt.Errorf("endpoint port %d conflicts with listener %s port range %d-%d: endpoint port cannot be included in any listener port range", + endpointPort, listenerID, portRange.FromPort, portRange.ToPort) + } + } + } + } + return nil +} + +// buildEndpointGroupsForListener builds EndpointGroup models for a specific listener +func (b *defaultEndpointGroupBuilder) buildEndpointGroupsForListener(ctx context.Context, stack core.Stack, listener *agamodel.Listener, endpointGroups []agaapi.GlobalAcceleratorEndpointGroup, listenerIndex int) ([]*agamodel.EndpointGroup, error) { + var result []*agamodel.EndpointGroup + + for i, endpointGroup := range endpointGroups { + spec, err := b.buildEndpointGroupSpec(ctx, listener, endpointGroup) + if err != nil { + return nil, err + } + + resourceID := fmt.Sprintf("EndpointGroup-%d-%d", listenerIndex, i) + endpointGroupModel := agamodel.NewEndpointGroup(stack, resourceID, spec, listener) + result = append(result, endpointGroupModel) + } + + return result, nil +} + +// buildEndpointGroupSpec builds the EndpointGroupSpec for a single EndpointGroup model resource +func (b *defaultEndpointGroupBuilder) buildEndpointGroupSpec(ctx context.Context, listener *agamodel.Listener, endpointGroup agaapi.GlobalAcceleratorEndpointGroup) (agamodel.EndpointGroupSpec, error) { + region, err := b.determineRegion(endpointGroup) + if err != nil { + return agamodel.EndpointGroupSpec{}, err + } + + // Handle trafficDialPercentage + trafficDialPercentage := endpointGroup.TrafficDialPercentage + + portOverrides, err := b.buildPortOverrides(ctx, listener, endpointGroup) + if err != nil { + return agamodel.EndpointGroupSpec{}, err + } + + return agamodel.EndpointGroupSpec{ + ListenerARN: listener.ListenerARN(), + Region: region, + TrafficDialPercentage: trafficDialPercentage, + PortOverrides: portOverrides, + }, nil +} + +// validateListenerPortOverrideWithinListenerPortRanges ensures all listener ports used in port overrides are +// contained within the listener's port ranges +func (b *defaultEndpointGroupBuilder) validateListenerPortOverrideWithinListenerPortRanges(listener *agamodel.Listener, portOverrides []agamodel.PortOverride) error { + if len(portOverrides) == 0 { + return nil + } + + for _, portOverride := range portOverrides { + listenerPort := portOverride.ListenerPort + if !IsPortInRanges(listenerPort, listener.Spec.PortRanges) { + return fmt.Errorf("port override listener port %d is not within any listener port ranges - this will cause AWS Global Accelerator to reject the configuration", listenerPort) + } + } + return nil +} + +// determineRegion determines the region for the endpoint group +func (b *defaultEndpointGroupBuilder) determineRegion(endpointGroup agaapi.GlobalAcceleratorEndpointGroup) (string, error) { + // Use explicit region from endpoint group if specified + if endpointGroup.Region != nil && awssdk.ToString(endpointGroup.Region) != "" { + return awssdk.ToString(endpointGroup.Region), nil + } + + // Default to cluster region if available + if b.clusterRegion != "" { + return b.clusterRegion, nil + } + return "", fmt.Errorf("region is required for endpoint group but neither specified in the endpoint group nor available from cluster configuration") +} + +// buildPortOverrides builds the port overrides for the endpoint group +func (b *defaultEndpointGroupBuilder) buildPortOverrides(_ context.Context, listener *agamodel.Listener, endpointGroup agaapi.GlobalAcceleratorEndpointGroup) ([]agamodel.PortOverride, error) { + if endpointGroup.PortOverrides == nil { + return nil, nil + } + + var portOverrides []agamodel.PortOverride + for _, po := range *endpointGroup.PortOverrides { + portOverrides = append(portOverrides, agamodel.PortOverride{ + ListenerPort: po.ListenerPort, + EndpointPort: po.EndpointPort, + }) + } + + // Validate all port override rules + if err := b.validatePortOverrides(listener, portOverrides); err != nil { + return []agamodel.PortOverride{}, err + } + + return portOverrides, nil +} + +// validateNoDuplicatePorts checks both listener and endpoint ports for duplicates in a single pass +func (b *defaultEndpointGroupBuilder) validateNoDuplicatePorts(portOverrides []agamodel.PortOverride) error { + if len(portOverrides) <= 1 { + return nil + } + + listenerPorts := make(map[int32]bool) + endpointPorts := make(map[int32]bool) + + for _, portOverride := range portOverrides { + // Check for duplicate listener ports + listenerPort := portOverride.ListenerPort + if listenerPorts[listenerPort] { + return fmt.Errorf("duplicate listener port %d in port overrides: each listener port can only be used once in port overrides for an endpoint group", listenerPort) + } + listenerPorts[listenerPort] = true + + // Check for duplicate endpoint ports + endpointPort := portOverride.EndpointPort + if endpointPorts[endpointPort] { + return fmt.Errorf("duplicate endpoint port %d in port overrides: each endpoint port can only be used once in port overrides for an endpoint group", endpointPort) + } + endpointPorts[endpointPort] = true + } + + return nil +} + +// validatePortOverrides is a wrapper function that runs all port override validation rules +func (b *defaultEndpointGroupBuilder) validatePortOverrides(listener *agamodel.Listener, portOverrides []agamodel.PortOverride) error { + // Validate listener port overrides against listener port ranges + if err := b.validateListenerPortOverrideWithinListenerPortRanges(listener, portOverrides); err != nil { + return err + } + + // Check for duplicate listener and endpoint ports within this endpoint group's port overrides + if err := b.validateNoDuplicatePorts(portOverrides); err != nil { + return err + } + + return nil +} diff --git a/pkg/aga/model_build_endpoint_group_test.go b/pkg/aga/model_build_endpoint_group_test.go new file mode 100644 index 000000000..5e51e389b --- /dev/null +++ b/pkg/aga/model_build_endpoint_group_test.go @@ -0,0 +1,865 @@ +package aga + +import ( + "context" + awssdk "github.com/aws/aws-sdk-go-v2/aws" + agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" + "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" + "testing" + + "github.com/stretchr/testify/assert" + agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" +) + +func Test_defaultEndpointGroupBuilder_determineRegion(t *testing.T) { + tests := []struct { + name string + endpointGroup agaapi.GlobalAcceleratorEndpointGroup + clusterRegion string + expectedRegion string + expectError bool + expectErrorString string + }{ + { + name: "region specified in endpoint group", + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + Region: awssdk.String("us-west-2"), + }, + clusterRegion: "us-east-1", + expectedRegion: "us-west-2", + expectError: false, + }, + { + name: "region specified in endpoint group even with empty cluster region", + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + Region: awssdk.String("eu-west-1"), + }, + clusterRegion: "", + expectedRegion: "eu-west-1", + expectError: false, + }, + { + name: "region not specified in endpoint group, use cluster region", + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{}, + clusterRegion: "us-east-1", + expectedRegion: "us-east-1", + expectError: false, + }, + { + name: "neither region specified nor cluster region available", + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{}, + clusterRegion: "", + expectError: true, + expectErrorString: "region is required for endpoint group but neither specified in the endpoint group nor available from cluster configuration", + }, + { + name: "empty region string in endpoint group, fall back to cluster region", + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + Region: awssdk.String(""), + }, + clusterRegion: "ap-southeast-1", + expectedRegion: "ap-southeast-1", + expectError: false, + }, + { + name: "empty region string in endpoint group and no cluster region", + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + Region: awssdk.String(""), + }, + clusterRegion: "", + expectError: true, + expectErrorString: "region is required for endpoint group but neither specified in the endpoint group nor available from cluster configuration", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create endpointGroupBuilder + builder := &defaultEndpointGroupBuilder{ + clusterRegion: tt.clusterRegion, + } + + // Call determineRegion + region, err := builder.determineRegion(tt.endpointGroup) + + // Check if error was expected + if tt.expectError { + assert.Error(t, err) + if tt.expectErrorString != "" { + assert.Contains(t, err.Error(), tt.expectErrorString) + } + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedRegion, region) + } + }) + } +} + +func Test_defaultEndpointGroupBuilder_buildPortOverrides(t *testing.T) { + mockStack := core.NewDefaultStack(core.StackID{Namespace: "test-namespace", Name: "test-name"}) + + // Helper function to create a listener with specific ID and port ranges + createTestListener := func(id string, portRanges []agamodel.PortRange) *agamodel.Listener { + return &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", id), + Spec: agamodel.ListenerSpec{ + PortRanges: portRanges, + }, + } + } + + // Helper function to create port overrides + createPortOverrides := func(overrides ...agaapi.PortOverride) *[]agaapi.PortOverride { + if len(overrides) == 0 { + return nil + } + return &overrides + } + + tests := []struct { + name string + listener *agamodel.Listener + endpointGroup agaapi.GlobalAcceleratorEndpointGroup + want []agamodel.PortOverride + expectErr bool + expectErrMatch string + }{ + { + name: "no port overrides", + listener: createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }), + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + PortOverrides: nil, + }, + want: nil, + expectErr: false, + }, + { + name: "empty port overrides", + listener: createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }), + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + PortOverrides: &[]agaapi.PortOverride{}, + }, + want: nil, + expectErr: false, + }, + { + name: "valid single port override", + listener: createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }), + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + PortOverrides: createPortOverrides( + agaapi.PortOverride{ + ListenerPort: 80, + EndpointPort: 8080, + }, + ), + }, + want: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + }, + expectErr: false, + }, + { + name: "valid multiple port overrides", + listener: createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }), + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + PortOverrides: createPortOverrides( + agaapi.PortOverride{ + ListenerPort: 80, + EndpointPort: 8080, + }, + agaapi.PortOverride{ + ListenerPort: 443, + EndpointPort: 8443, + }, + ), + }, + want: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + { + ListenerPort: 443, + EndpointPort: 8443, + }, + }, + expectErr: false, + }, + { + name: "listener port outside range", + listener: createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }), + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + PortOverrides: createPortOverrides( + agaapi.PortOverride{ + ListenerPort: 443, // Not in listener port range + EndpointPort: 8443, + }, + ), + }, + want: nil, + expectErr: true, + expectErrMatch: "port override listener port 443 is not within any listener port ranges", + }, + { + name: "duplicate listener port", + listener: createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }), + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + PortOverrides: createPortOverrides( + agaapi.PortOverride{ + ListenerPort: 80, + EndpointPort: 8080, + }, + agaapi.PortOverride{ + ListenerPort: 80, // Duplicate listener port + EndpointPort: 9090, + }, + ), + }, + want: nil, + expectErr: true, + expectErrMatch: "duplicate listener port 80 in port overrides", + }, + { + name: "duplicate endpoint port", + listener: createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }), + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + PortOverrides: createPortOverrides( + agaapi.PortOverride{ + ListenerPort: 80, + EndpointPort: 8080, + }, + agaapi.PortOverride{ + ListenerPort: 443, + EndpointPort: 8080, // Duplicate endpoint port + }, + ), + }, + want: nil, + expectErr: true, + expectErrMatch: "duplicate endpoint port 8080 in port overrides", + }, + { + name: "port range check for listener port", + listener: createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 90}, + }), + endpointGroup: agaapi.GlobalAcceleratorEndpointGroup{ + PortOverrides: createPortOverrides( + agaapi.PortOverride{ + ListenerPort: 85, // Within the range + EndpointPort: 8085, + }, + ), + }, + want: []agamodel.PortOverride{ + { + ListenerPort: 85, + EndpointPort: 8085, + }, + }, + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + + // Create endpointGroupBuilder + builder := &defaultEndpointGroupBuilder{ + clusterRegion: "us-west-2", + } + + // Call buildPortOverrides + got, err := builder.buildPortOverrides(ctx, tt.listener, tt.endpointGroup) + + // Check for expected error + if tt.expectErr { + assert.Error(t, err) + if tt.expectErrMatch != "" { + assert.Contains(t, err.Error(), tt.expectErrMatch) + } + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } +} + +func Test_defaultEndpointGroupBuilder_validateEndpointPortOverridesWithinListener(t *testing.T) { + tests := []struct { + name string + endpointPort int32 + listenerPortRanges map[string][]agamodel.PortRange + wantErr bool + expectErrContains string + }{ + { + name: "endpoint port outside all listener port ranges", + endpointPort: 8080, + listenerPortRanges: map[string][]agamodel.PortRange{ + "l-1": { + { + FromPort: 80, + ToPort: 80, + }, + }, + "l-2": { + { + FromPort: 443, + ToPort: 443, + }, + }, + }, + wantErr: false, + }, + { + name: "endpoint port inside a listener port range", + endpointPort: 450, + listenerPortRanges: map[string][]agamodel.PortRange{ + "l-1": { + { + FromPort: 80, + ToPort: 80, + }, + }, + "l-2": { + { + FromPort: 400, + ToPort: 500, // Includes 450 + }, + }, + }, + wantErr: true, + expectErrContains: "endpoint port 450 conflicts with listener l-2 port range 400-500", + }, + { + name: "endpoint port at boundary of listener port range", + endpointPort: 400, // Exactly at FromPort boundary + listenerPortRanges: map[string][]agamodel.PortRange{ + "l-1": { + { + FromPort: 400, + ToPort: 500, + }, + }, + }, + wantErr: true, + expectErrContains: "endpoint port 400 conflicts with listener l-1 port range 400-500", + }, + { + name: "endpoint port at upper boundary of listener port range", + endpointPort: 500, // Exactly at ToPort boundary + listenerPortRanges: map[string][]agamodel.PortRange{ + "l-1": { + { + FromPort: 400, + ToPort: 500, + }, + }, + }, + wantErr: true, + expectErrContains: "endpoint port 500 conflicts with listener l-1 port range 400-500", + }, + { + name: "multiple listener port ranges, endpoint port in one range", + endpointPort: 1024, + listenerPortRanges: map[string][]agamodel.PortRange{ + "l-1": { + { + FromPort: 80, + ToPort: 80, + }, + { + FromPort: 443, + ToPort: 443, + }, + }, + "l-2": { + { + FromPort: 1000, + ToPort: 2000, // Includes 1024 + }, + { + FromPort: 3000, + ToPort: 4000, + }, + }, + }, + wantErr: true, + expectErrContains: "endpoint port 1024 conflicts with listener l-2 port range 1000-2000", + }, + { + name: "empty listener port ranges", + endpointPort: 8080, + listenerPortRanges: map[string][]agamodel.PortRange{}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + builder := &defaultEndpointGroupBuilder{ + clusterRegion: "us-west-2", + } + err := builder.validateEndpointPortOverridesWithinListener(tt.endpointPort, tt.listenerPortRanges) + if tt.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectErrContains) + } else { + assert.NoError(t, err) + } + }) + } +} + +func Test_defaultEndpointGroupBuilder_validateNoDuplicatePorts(t *testing.T) { + tests := []struct { + name string + portOverrides []agamodel.PortOverride + wantErr bool + expectErrContains string + portType string // "listener" or "endpoint" + }{ + { + name: "no port overrides", + portOverrides: []agamodel.PortOverride{}, + wantErr: false, + }, + { + name: "single port override", + portOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + }, + wantErr: false, + }, + { + name: "multiple port overrides with unique ports", + portOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + { + ListenerPort: 443, + EndpointPort: 8443, + }, + { + ListenerPort: 8000, + EndpointPort: 9000, + }, + }, + wantErr: false, + }, + { + name: "multiple port overrides with duplicate listener ports", + portOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + { + ListenerPort: 80, // Duplicate listener port + EndpointPort: 9090, // Different endpoint port + }, + { + ListenerPort: 443, + EndpointPort: 8443, + }, + }, + wantErr: true, + expectErrContains: "duplicate listener port 80 in port overrides", + portType: "listener", + }, + { + name: "multiple duplicate listener ports", + portOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + { + ListenerPort: 443, + EndpointPort: 8443, + }, + { + ListenerPort: 80, // First duplicate + EndpointPort: 9090, + }, + { + ListenerPort: 443, // Second duplicate + EndpointPort: 9443, + }, + }, + wantErr: true, + expectErrContains: "duplicate listener port 80 in port overrides", + portType: "listener", + }, + { + name: "multiple port overrides with duplicate endpoint ports", + portOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + { + ListenerPort: 443, + EndpointPort: 8080, // Duplicate endpoint port + }, + { + ListenerPort: 8000, + EndpointPort: 9000, + }, + }, + wantErr: true, + expectErrContains: "duplicate endpoint port 8080 in port overrides", + portType: "endpoint", + }, + { + name: "multiple duplicate endpoint ports", + portOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + { + ListenerPort: 443, + EndpointPort: 8443, + }, + { + ListenerPort: 8000, + EndpointPort: 8080, // First duplicate + }, + { + ListenerPort: 9000, + EndpointPort: 8443, // Second duplicate + }, + }, + wantErr: true, + expectErrContains: "duplicate endpoint port 8080 in port overrides", + portType: "endpoint", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + builder := &defaultEndpointGroupBuilder{ + clusterRegion: "us-west-2", + } + err := builder.validateNoDuplicatePorts(tt.portOverrides) + if tt.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectErrContains) + } else { + assert.NoError(t, err) + } + }) + } +} + +func Test_defaultEndpointGroupBuilder_validateEndpointPortOverridesCrossListeners(t *testing.T) { + mockStack := core.NewDefaultStack(core.StackID{Namespace: "test-namespace", Name: "test-name"}) + + // Helper function to create a listener with specific ID and port ranges + createTestListener := func(id string, portRanges []agamodel.PortRange) *agamodel.Listener { + return &agamodel.Listener{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::Listener", id), + Spec: agamodel.ListenerSpec{ + PortRanges: portRanges, + }, + } + } + + // Helper function to create an endpoint group with specific listener and port overrides + createTestEndpointGroup := func(id string, region string, listener *agamodel.Listener, portOverrides []agamodel.PortOverride) *agamodel.EndpointGroup { + return &agamodel.EndpointGroup{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::EndpointGroup", id), + Listener: listener, + Spec: agamodel.EndpointGroupSpec{ + Region: region, + PortOverrides: portOverrides, + }, + } + } + + tests := []struct { + name string + endpointGroups []*agamodel.EndpointGroup + listenerPortRanges map[string][]agamodel.PortRange + wantErr bool + expectErrContains string + }{ + { + name: "no endpoint groups", + endpointGroups: []*agamodel.EndpointGroup{}, + listenerPortRanges: map[string][]agamodel.PortRange{}, + wantErr: false, + }, + { + name: "single endpoint group - no conflicts possible", + endpointGroups: func() []*agamodel.EndpointGroup { + listener := createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }) + return []*agamodel.EndpointGroup{ + createTestEndpointGroup("eg-1", "us-west-2", listener, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + }), + } + }(), + listenerPortRanges: map[string][]agamodel.PortRange{ + "listener-1": {{FromPort: 80, ToPort: 80}}, + }, + wantErr: false, + }, + { + name: "multiple endpoint groups, same listener - no conflicts", + endpointGroups: func() []*agamodel.EndpointGroup { + listener := createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }) + return []*agamodel.EndpointGroup{ + createTestEndpointGroup("eg-1", "us-west-2", listener, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + }), + createTestEndpointGroup("eg-2", "eu-west-1", listener, []agamodel.PortOverride{ + {ListenerPort: 443, EndpointPort: 8443}, + }), + } + }(), + listenerPortRanges: map[string][]agamodel.PortRange{ + "listener-1": { + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }, + }, + wantErr: false, + }, + { + name: "multiple endpoint groups, different listeners, no conflicts", + endpointGroups: func() []*agamodel.EndpointGroup { + listener1 := createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }) + listener2 := createTestListener("listener-2", []agamodel.PortRange{ + {FromPort: 443, ToPort: 443}, + }) + return []*agamodel.EndpointGroup{ + createTestEndpointGroup("eg-1", "us-west-2", listener1, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + }), + createTestEndpointGroup("eg-2", "eu-west-1", listener2, []agamodel.PortOverride{ + {ListenerPort: 443, EndpointPort: 9090}, // Different endpoint port + }), + } + }(), + listenerPortRanges: map[string][]agamodel.PortRange{ + "listener-1": {{FromPort: 80, ToPort: 80}}, + "listener-2": {{FromPort: 443, ToPort: 443}}, + }, + wantErr: false, + }, + { + name: "endpoint port in listener port range - conflict", + endpointGroups: func() []*agamodel.EndpointGroup { + listener1 := createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 90}, // Range includes 85 + }) + return []*agamodel.EndpointGroup{ + createTestEndpointGroup("eg-1", "us-west-2", listener1, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 85}, // Endpoint port is within listener port range + }), + } + }(), + listenerPortRanges: map[string][]agamodel.PortRange{ + "listener-1": {{FromPort: 80, ToPort: 90}}, + }, + wantErr: true, + expectErrContains: "endpoint port 85 conflicts with listener listener-1 port range 80-90", + }, + { + name: "duplicate endpoint port across different listeners - conflict", + endpointGroups: func() []*agamodel.EndpointGroup { + listener1 := createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }) + listener2 := createTestListener("listener-2", []agamodel.PortRange{ + {FromPort: 443, ToPort: 443}, + }) + return []*agamodel.EndpointGroup{ + createTestEndpointGroup("eg-1", "us-west-2", listener1, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + }), + createTestEndpointGroup("eg-2", "eu-west-1", listener2, []agamodel.PortOverride{ + {ListenerPort: 443, EndpointPort: 8080}, // Same endpoint port as eg-1 + }), + } + }(), + listenerPortRanges: map[string][]agamodel.PortRange{ + "listener-1": {{FromPort: 80, ToPort: 80}}, + "listener-2": {{FromPort: 443, ToPort: 443}}, + }, + wantErr: true, + expectErrContains: "duplicate endpoint port 8080: the same endpoint port cannot be used in port overrides from different listeners", + }, + { + name: "multiple duplicate endpoint ports across different listeners", + endpointGroups: func() []*agamodel.EndpointGroup { + listener1 := createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }) + listener2 := createTestListener("listener-2", []agamodel.PortRange{ + {FromPort: 8080, ToPort: 8080}, + {FromPort: 8443, ToPort: 8443}, + }) + return []*agamodel.EndpointGroup{ + createTestEndpointGroup("eg-1", "us-west-2", listener1, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 9090}, + {ListenerPort: 443, EndpointPort: 9091}, + }), + createTestEndpointGroup("eg-2", "eu-west-1", listener2, []agamodel.PortOverride{ + {ListenerPort: 8080, EndpointPort: 9090}, // Duplicate with listener-1 + {ListenerPort: 8443, EndpointPort: 7070}, // Unique + }), + } + }(), + listenerPortRanges: map[string][]agamodel.PortRange{ + "listener-1": { + {FromPort: 80, ToPort: 80}, + {FromPort: 443, ToPort: 443}, + }, + "listener-2": { + {FromPort: 8080, ToPort: 8080}, + {FromPort: 8443, ToPort: 8443}, + }, + }, + wantErr: true, + expectErrContains: "duplicate endpoint port 9090: the same endpoint port cannot be used in port overrides from different listeners", + }, + { + name: "multiple endpoint groups with mixed conflicts", + endpointGroups: func() []*agamodel.EndpointGroup { + listener1 := createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }) + listener2 := createTestListener("listener-2", []agamodel.PortRange{ + {FromPort: 443, ToPort: 443}, + }) + listener3 := createTestListener("listener-3", []agamodel.PortRange{ + {FromPort: 9500, ToPort: 9999}, // Range does not include 8080 + }) + return []*agamodel.EndpointGroup{ + createTestEndpointGroup("eg-1", "us-west-2", listener1, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + }), + createTestEndpointGroup("eg-2", "eu-west-1", listener2, []agamodel.PortOverride{ + {ListenerPort: 443, EndpointPort: 8080}, // Duplicate endpoint port with listener1 + }), + createTestEndpointGroup("eg-3", "ap-southeast-1", listener3, []agamodel.PortOverride{ + {ListenerPort: 8080, EndpointPort: 9999}, // Endpoint port outside of any listener range + }), + } + }(), + listenerPortRanges: map[string][]agamodel.PortRange{ + "listener-1": {{FromPort: 80, ToPort: 80}}, + "listener-2": {{FromPort: 443, ToPort: 443}}, + "listener-3": {{FromPort: 9500, ToPort: 9999}}, + }, + wantErr: true, + expectErrContains: "duplicate endpoint port 8080: the same endpoint port cannot be used in port overrides from different listeners", + }, + { + name: "same endpoint port in different regions - still a conflict", + endpointGroups: func() []*agamodel.EndpointGroup { + listener1 := createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }) + listener2 := createTestListener("listener-2", []agamodel.PortRange{ + {FromPort: 443, ToPort: 443}, + }) + return []*agamodel.EndpointGroup{ + createTestEndpointGroup("eg-1", "us-west-2", listener1, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + }), + createTestEndpointGroup("eg-2", "eu-west-1", listener2, []agamodel.PortOverride{ + // Even though in different regions, same port across listeners is still a conflict + {ListenerPort: 443, EndpointPort: 8080}, + }), + } + }(), + listenerPortRanges: map[string][]agamodel.PortRange{ + "listener-1": {{FromPort: 80, ToPort: 80}}, + "listener-2": {{FromPort: 443, ToPort: 443}}, + }, + wantErr: true, + expectErrContains: "duplicate endpoint port 8080: the same endpoint port cannot be used in port overrides from different listeners", + }, + { + name: "endpoint groups with no port overrides - no conflicts", + endpointGroups: func() []*agamodel.EndpointGroup { + listener1 := createTestListener("listener-1", []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }) + listener2 := createTestListener("listener-2", []agamodel.PortRange{ + {FromPort: 443, ToPort: 443}, + }) + return []*agamodel.EndpointGroup{ + createTestEndpointGroup("eg-1", "us-west-2", listener1, nil), // No port overrides + createTestEndpointGroup("eg-2", "eu-west-1", listener2, nil), // No port overrides + } + }(), + listenerPortRanges: map[string][]agamodel.PortRange{ + "listener-1": {{FromPort: 80, ToPort: 80}}, + "listener-2": {{FromPort: 443, ToPort: 443}}, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create endpointGroupBuilder + builder := &defaultEndpointGroupBuilder{ + clusterRegion: "us-west-2", + } + + // Run the validation function + err := builder.validateEndpointPortOverridesCrossListeners(tt.endpointGroups, tt.listenerPortRanges) + + // Check if error was expected + if tt.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectErrContains) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/pkg/aga/model_builder.go b/pkg/aga/model_builder.go index d4938ab29..c3e1caa5b 100644 --- a/pkg/aga/model_builder.go +++ b/pkg/aga/model_builder.go @@ -61,10 +61,10 @@ func (b *defaultModelBuilder) Build(ctx context.Context, ga *agaapi.GlobalAccele // Create fresh builder instances for each reconciliation acceleratorBuilder := NewAcceleratorBuilder(b.trackingProvider, b.clusterName, b.clusterRegion, b.defaultTags, b.externalManagedTags, b.featureGates.Enabled(config.EnableDefaultTagsLowPriority)) + listenerBuilder := NewListenerBuilder() + endpointGroupBuilder := NewEndpointGroupBuilder(b.clusterRegion) // TODO - // endpointGroupBuilder := NewEndpointGroupBuilder() // endpointBuilder := NewEndpointBuilder() - // Build Accelerator accelerator, err := acceleratorBuilder.Build(ctx, stack, ga) if err != nil { @@ -74,17 +74,18 @@ func (b *defaultModelBuilder) Build(ctx context.Context, ga *agaapi.GlobalAccele // Build Listeners if specified var listeners []*agamodel.Listener if ga.Spec.Listeners != nil { - // Create builder for listeners and endpoints - listenerBuilder := NewListenerBuilder() listeners, err = listenerBuilder.Build(ctx, stack, accelerator, *ga.Spec.Listeners) if err != nil { return nil, nil, err } + endpointGroups, err := endpointGroupBuilder.Build(ctx, stack, listeners, *ga.Spec.Listeners) + if err != nil { + return nil, nil, err + } + b.logger.V(1).Info("Listener and endpoint groups built", "listeners", listeners, "endpointGroups", endpointGroups) } - b.logger.V(1).Info("Listeners built", "listeners", listeners) - // TODO: Add other resource builders - // endpointGroups, err := endpointGroupBuilder.Build(ctx, stack, listeners, ga.Spec.Listeners) + // TODO: Add endpoint builder // endpoints, err := endpointBuilder.Build(ctx, stack, endpointGroups, ga.Spec.Listeners) return stack, accelerator, nil diff --git a/pkg/aga/reference_tracker.go b/pkg/aga/reference_tracker.go index 0685f08a2..eb870f93d 100644 --- a/pkg/aga/reference_tracker.go +++ b/pkg/aga/reference_tracker.go @@ -34,8 +34,8 @@ func NewReferenceTracker(logger logr.Logger) *ReferenceTracker { } } -// UpdateReferencesForGA updates the tracking information for a GlobalAccelerator -func (t *ReferenceTracker) UpdateReferencesForGA(ga *agaapi.GlobalAccelerator, endpoints []EndpointReference) { +// UpdateDesiredEndpointReferencesForGA updates the tracking information for a GlobalAccelerator +func (t *ReferenceTracker) UpdateDesiredEndpointReferencesForGA(ga *agaapi.GlobalAccelerator, desiredEndpoints []EndpointReference) { t.mutex.Lock() defer t.mutex.Unlock() @@ -45,7 +45,7 @@ func (t *ReferenceTracker) UpdateReferencesForGA(ga *agaapi.GlobalAccelerator, e currentResources := sets.New[ResourceKey]() // Process each endpoint - for _, endpoint := range endpoints { + for _, endpoint := range desiredEndpoints { resourceKey := endpoint.ToResourceKey() currentResources.Insert(resourceKey) diff --git a/pkg/aga/reference_tracker_test.go b/pkg/aga/reference_tracker_test.go index b7873294e..566666d0c 100644 --- a/pkg/aga/reference_tracker_test.go +++ b/pkg/aga/reference_tracker_test.go @@ -147,9 +147,9 @@ func TestReferenceTracker_UpdateReferencesForGA(t *testing.T) { // Create tracker tracker := NewReferenceTracker(logr.Discard()) - endpoints := GetAllEndpointsFromGA(tt.ga) + endpoints := GetAllDesiredEndpointsFromGA(tt.ga) // Update references - tracker.UpdateReferencesForGA(tt.ga, endpoints) + tracker.UpdateDesiredEndpointReferencesForGA(tt.ga, endpoints) // Check number of tracked resources gaKey := types.NamespacedName{Namespace: tt.ga.Namespace, Name: tt.ga.Name} @@ -212,8 +212,8 @@ func TestReferenceTracker_UpdateReferencesForGA_RemoveStaleReferences(t *testing // Create tracker and add initial references tracker := NewReferenceTracker(logr.Discard()) - endpoints := GetAllEndpointsFromGA(ga) - tracker.UpdateReferencesForGA(ga, endpoints) + endpoints := GetAllDesiredEndpointsFromGA(ga) + tracker.UpdateDesiredEndpointReferencesForGA(ga, endpoints) // Verify initial state service1Key := ResourceKey{ @@ -256,8 +256,8 @@ func TestReferenceTracker_UpdateReferencesForGA_RemoveStaleReferences(t *testing } // Update references with modified GA - endpoints = GetAllEndpointsFromGA(ga) - tracker.UpdateReferencesForGA(ga, endpoints) + endpoints = GetAllDesiredEndpointsFromGA(ga) + tracker.UpdateDesiredEndpointReferencesForGA(ga, endpoints) // Verify that service1 is still referenced, service2 is no longer referenced, and service3 is now referenced assert.True(t, tracker.IsResourceReferenced(service1Key)) @@ -328,10 +328,10 @@ func TestReferenceTracker_RemoveGA(t *testing.T) { // Create tracker and add references from both GAs tracker := NewReferenceTracker(logr.Discard()) - endpoints1 := GetAllEndpointsFromGA(ga1) - endpoints2 := GetAllEndpointsFromGA(ga2) - tracker.UpdateReferencesForGA(ga1, endpoints1) - tracker.UpdateReferencesForGA(ga2, endpoints2) + endpoints1 := GetAllDesiredEndpointsFromGA(ga1) + endpoints2 := GetAllDesiredEndpointsFromGA(ga2) + tracker.UpdateDesiredEndpointReferencesForGA(ga1, endpoints1) + tracker.UpdateDesiredEndpointReferencesForGA(ga2, endpoints2) // Resource keys service1Key := ResourceKey{ @@ -406,8 +406,8 @@ func TestReferenceTracker_IsResourceReferenced(t *testing.T) { // Create tracker and add references tracker := NewReferenceTracker(logr.Discard()) - endpoints := GetAllEndpointsFromGA(ga) - tracker.UpdateReferencesForGA(ga, endpoints) + endpoints := GetAllDesiredEndpointsFromGA(ga) + tracker.UpdateDesiredEndpointReferencesForGA(ga, endpoints) // Resource keys - one that exists and one that doesn't existingResourceKey := ResourceKey{ @@ -479,10 +479,10 @@ func TestReferenceTracker_GetGAsForResource(t *testing.T) { // Create tracker and add references tracker := NewReferenceTracker(logr.Discard()) - endpoints1 := GetAllEndpointsFromGA(ga1) - endpoints2 := GetAllEndpointsFromGA(ga2) - tracker.UpdateReferencesForGA(ga1, endpoints1) - tracker.UpdateReferencesForGA(ga2, endpoints2) + endpoints1 := GetAllDesiredEndpointsFromGA(ga1) + endpoints2 := GetAllDesiredEndpointsFromGA(ga2) + tracker.UpdateDesiredEndpointReferencesForGA(ga1, endpoints1) + tracker.UpdateDesiredEndpointReferencesForGA(ga2, endpoints2) // Resource key for the shared service sharedServiceKey := ResourceKey{ diff --git a/pkg/aga/utils.go b/pkg/aga/utils.go index 1f067e25e..f4e96ee37 100644 --- a/pkg/aga/utils.go +++ b/pkg/aga/utils.go @@ -4,8 +4,19 @@ import ( "strings" "sigs.k8s.io/aws-load-balancer-controller/pkg/config" + agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" ) +// IsPortInRanges checks if a port is within any of the specified port ranges +func IsPortInRanges(port int32, portRanges []agamodel.PortRange) bool { + for _, portRange := range portRanges { + if portRange.FromPort <= port && port <= portRange.ToPort { + return true + } + } + return false +} + // IsAGAControllerEnabled checks if the AGA controller is both enabled via feature gate // and if the region is in a partition that supports Global Accelerator func IsAGAControllerEnabled(featureGates config.FeatureGates, region string) bool { diff --git a/pkg/aga/utils_test.go b/pkg/aga/utils_test.go index 6bfa40c5a..e029259c4 100644 --- a/pkg/aga/utils_test.go +++ b/pkg/aga/utils_test.go @@ -3,7 +3,9 @@ package aga import ( "testing" + "github.com/stretchr/testify/assert" "sigs.k8s.io/aws-load-balancer-controller/pkg/config" + agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" ) func TestIsAGAControllerEnabled(t *testing.T) { @@ -93,3 +95,125 @@ func TestIsAGAControllerEnabled(t *testing.T) { }) } } + +func TestIsPortInRanges(t *testing.T) { + tests := []struct { + name string + port int32 + portRanges []agamodel.PortRange + expected bool + }{ + { + name: "port within single range", + port: 85, + portRanges: []agamodel.PortRange{ + { + FromPort: 80, + ToPort: 100, + }, + }, + expected: true, + }, + { + name: "port at lower boundary", + port: 80, + portRanges: []agamodel.PortRange{ + { + FromPort: 80, + ToPort: 100, + }, + }, + expected: true, + }, + { + name: "port at upper boundary", + port: 100, + portRanges: []agamodel.PortRange{ + { + FromPort: 80, + ToPort: 100, + }, + }, + expected: true, + }, + { + name: "port below range", + port: 79, + portRanges: []agamodel.PortRange{ + { + FromPort: 80, + ToPort: 100, + }, + }, + expected: false, + }, + { + name: "port above range", + port: 101, + portRanges: []agamodel.PortRange{ + { + FromPort: 80, + ToPort: 100, + }, + }, + expected: false, + }, + { + name: "port within one of multiple ranges", + port: 443, + portRanges: []agamodel.PortRange{ + { + FromPort: 80, + ToPort: 100, + }, + { + FromPort: 443, + ToPort: 443, + }, + { + FromPort: 8080, + ToPort: 8090, + }, + }, + expected: true, + }, + { + name: "port not within any of multiple ranges", + port: 8000, + portRanges: []agamodel.PortRange{ + { + FromPort: 80, + ToPort: 100, + }, + { + FromPort: 443, + ToPort: 443, + }, + { + FromPort: 8080, + ToPort: 8090, + }, + }, + expected: false, + }, + { + name: "empty port ranges", + port: 80, + portRanges: []agamodel.PortRange{}, + expected: false, + }, + { + name: "nil port ranges", + port: 80, + portRanges: nil, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsPortInRanges(tt.port, tt.portRanges) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/pkg/aws/services/globalaccelerator.go b/pkg/aws/services/globalaccelerator.go index 364f12d70..30e627307 100644 --- a/pkg/aws/services/globalaccelerator.go +++ b/pkg/aws/services/globalaccelerator.go @@ -41,6 +41,21 @@ type GlobalAccelerator interface { // ListListenersForAccelerator lists all listeners for an accelerator. ListListenersForAcceleratorWithContext(ctx context.Context, input *globalaccelerator.ListListenersInput) (*globalaccelerator.ListListenersOutput, error) + // CreateEndpointGroup creates a new endpoint group. + CreateEndpointGroupWithContext(ctx context.Context, input *globalaccelerator.CreateEndpointGroupInput) (*globalaccelerator.CreateEndpointGroupOutput, error) + + // DescribeEndpointGroup describes an endpoint group. + DescribeEndpointGroupWithContext(ctx context.Context, input *globalaccelerator.DescribeEndpointGroupInput) (*globalaccelerator.DescribeEndpointGroupOutput, error) + + // UpdateEndpointGroup updates an endpoint group. + UpdateEndpointGroupWithContext(ctx context.Context, input *globalaccelerator.UpdateEndpointGroupInput) (*globalaccelerator.UpdateEndpointGroupOutput, error) + + // DeleteEndpointGroup deletes an endpoint group. + DeleteEndpointGroupWithContext(ctx context.Context, input *globalaccelerator.DeleteEndpointGroupInput) (*globalaccelerator.DeleteEndpointGroupOutput, error) + + // wrapper to ListEndpointGroups API, which aggregates paged results into list. + ListEndpointGroupsAsList(ctx context.Context, input *globalaccelerator.ListEndpointGroupsInput) ([]types.EndpointGroup, error) + // TagResource tags a resource. TagResourceWithContext(ctx context.Context, input *globalaccelerator.TagResourceInput) (*globalaccelerator.TagResourceOutput, error) @@ -192,3 +207,52 @@ func (c *defaultGlobalAccelerator) ListListenersAsList(ctx context.Context, inpu } return result, nil } + +func (c *defaultGlobalAccelerator) CreateEndpointGroupWithContext(ctx context.Context, input *globalaccelerator.CreateEndpointGroupInput) (*globalaccelerator.CreateEndpointGroupOutput, error) { + client, err := c.awsClientsProvider.GetGlobalAcceleratorClient(ctx, "CreateEndpointGroup") + if err != nil { + return nil, err + } + return client.CreateEndpointGroup(ctx, input) +} + +func (c *defaultGlobalAccelerator) DescribeEndpointGroupWithContext(ctx context.Context, input *globalaccelerator.DescribeEndpointGroupInput) (*globalaccelerator.DescribeEndpointGroupOutput, error) { + client, err := c.awsClientsProvider.GetGlobalAcceleratorClient(ctx, "DescribeEndpointGroup") + if err != nil { + return nil, err + } + return client.DescribeEndpointGroup(ctx, input) +} + +func (c *defaultGlobalAccelerator) UpdateEndpointGroupWithContext(ctx context.Context, input *globalaccelerator.UpdateEndpointGroupInput) (*globalaccelerator.UpdateEndpointGroupOutput, error) { + client, err := c.awsClientsProvider.GetGlobalAcceleratorClient(ctx, "UpdateEndpointGroup") + if err != nil { + return nil, err + } + return client.UpdateEndpointGroup(ctx, input) +} + +func (c *defaultGlobalAccelerator) DeleteEndpointGroupWithContext(ctx context.Context, input *globalaccelerator.DeleteEndpointGroupInput) (*globalaccelerator.DeleteEndpointGroupOutput, error) { + client, err := c.awsClientsProvider.GetGlobalAcceleratorClient(ctx, "DeleteEndpointGroup") + if err != nil { + return nil, err + } + return client.DeleteEndpointGroup(ctx, input) +} + +func (c *defaultGlobalAccelerator) ListEndpointGroupsAsList(ctx context.Context, input *globalaccelerator.ListEndpointGroupsInput) ([]types.EndpointGroup, error) { + var result []types.EndpointGroup + client, err := c.awsClientsProvider.GetGlobalAcceleratorClient(ctx, "ListEndpointGroups") + if err != nil { + return nil, err + } + paginator := globalaccelerator.NewListEndpointGroupsPaginator(client, input) + for paginator.HasMorePages() { + output, err := paginator.NextPage(ctx) + if err != nil { + return nil, err + } + result = append(result, output.EndpointGroups...) + } + return result, nil +} diff --git a/pkg/aws/services/globalaccelerator_mocks.go b/pkg/aws/services/globalaccelerator_mocks.go index e4989fa97..0bad79b37 100644 --- a/pkg/aws/services/globalaccelerator_mocks.go +++ b/pkg/aws/services/globalaccelerator_mocks.go @@ -51,6 +51,21 @@ func (mr *MockGlobalAcceleratorMockRecorder) CreateAcceleratorWithContext(arg0, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAcceleratorWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).CreateAcceleratorWithContext), arg0, arg1) } +// CreateEndpointGroupWithContext mocks base method. +func (m *MockGlobalAccelerator) CreateEndpointGroupWithContext(arg0 context.Context, arg1 *globalaccelerator.CreateEndpointGroupInput) (*globalaccelerator.CreateEndpointGroupOutput, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateEndpointGroupWithContext", arg0, arg1) + ret0, _ := ret[0].(*globalaccelerator.CreateEndpointGroupOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateEndpointGroupWithContext indicates an expected call of CreateEndpointGroupWithContext. +func (mr *MockGlobalAcceleratorMockRecorder) CreateEndpointGroupWithContext(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateEndpointGroupWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).CreateEndpointGroupWithContext), arg0, arg1) +} + // CreateListenerWithContext mocks base method. func (m *MockGlobalAccelerator) CreateListenerWithContext(arg0 context.Context, arg1 *globalaccelerator.CreateListenerInput) (*globalaccelerator.CreateListenerOutput, error) { m.ctrl.T.Helper() @@ -81,6 +96,21 @@ func (mr *MockGlobalAcceleratorMockRecorder) DeleteAcceleratorWithContext(arg0, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAcceleratorWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).DeleteAcceleratorWithContext), arg0, arg1) } +// DeleteEndpointGroupWithContext mocks base method. +func (m *MockGlobalAccelerator) DeleteEndpointGroupWithContext(arg0 context.Context, arg1 *globalaccelerator.DeleteEndpointGroupInput) (*globalaccelerator.DeleteEndpointGroupOutput, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteEndpointGroupWithContext", arg0, arg1) + ret0, _ := ret[0].(*globalaccelerator.DeleteEndpointGroupOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DeleteEndpointGroupWithContext indicates an expected call of DeleteEndpointGroupWithContext. +func (mr *MockGlobalAcceleratorMockRecorder) DeleteEndpointGroupWithContext(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteEndpointGroupWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).DeleteEndpointGroupWithContext), arg0, arg1) +} + // DeleteListenerWithContext mocks base method. func (m *MockGlobalAccelerator) DeleteListenerWithContext(arg0 context.Context, arg1 *globalaccelerator.DeleteListenerInput) (*globalaccelerator.DeleteListenerOutput, error) { m.ctrl.T.Helper() @@ -111,6 +141,21 @@ func (mr *MockGlobalAcceleratorMockRecorder) DescribeAcceleratorWithContext(arg0 return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeAcceleratorWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).DescribeAcceleratorWithContext), arg0, arg1) } +// DescribeEndpointGroupWithContext mocks base method. +func (m *MockGlobalAccelerator) DescribeEndpointGroupWithContext(arg0 context.Context, arg1 *globalaccelerator.DescribeEndpointGroupInput) (*globalaccelerator.DescribeEndpointGroupOutput, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DescribeEndpointGroupWithContext", arg0, arg1) + ret0, _ := ret[0].(*globalaccelerator.DescribeEndpointGroupOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DescribeEndpointGroupWithContext indicates an expected call of DescribeEndpointGroupWithContext. +func (mr *MockGlobalAcceleratorMockRecorder) DescribeEndpointGroupWithContext(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeEndpointGroupWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).DescribeEndpointGroupWithContext), arg0, arg1) +} + // DescribeListenerWithContext mocks base method. func (m *MockGlobalAccelerator) DescribeListenerWithContext(arg0 context.Context, arg1 *globalaccelerator.DescribeListenerInput) (*globalaccelerator.DescribeListenerOutput, error) { m.ctrl.T.Helper() @@ -141,6 +186,21 @@ func (mr *MockGlobalAcceleratorMockRecorder) ListAcceleratorsAsList(arg0, arg1 i return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAcceleratorsAsList", reflect.TypeOf((*MockGlobalAccelerator)(nil).ListAcceleratorsAsList), arg0, arg1) } +// ListEndpointGroupsAsList mocks base method. +func (m *MockGlobalAccelerator) ListEndpointGroupsAsList(arg0 context.Context, arg1 *globalaccelerator.ListEndpointGroupsInput) ([]types.EndpointGroup, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListEndpointGroupsAsList", arg0, arg1) + ret0, _ := ret[0].([]types.EndpointGroup) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListEndpointGroupsAsList indicates an expected call of ListEndpointGroupsAsList. +func (mr *MockGlobalAcceleratorMockRecorder) ListEndpointGroupsAsList(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListEndpointGroupsAsList", reflect.TypeOf((*MockGlobalAccelerator)(nil).ListEndpointGroupsAsList), arg0, arg1) +} + // ListListenersAsList mocks base method. func (m *MockGlobalAccelerator) ListListenersAsList(arg0 context.Context, arg1 *globalaccelerator.ListListenersInput) ([]types.Listener, error) { m.ctrl.T.Helper() @@ -231,6 +291,21 @@ func (mr *MockGlobalAcceleratorMockRecorder) UpdateAcceleratorWithContext(arg0, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAcceleratorWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).UpdateAcceleratorWithContext), arg0, arg1) } +// UpdateEndpointGroupWithContext mocks base method. +func (m *MockGlobalAccelerator) UpdateEndpointGroupWithContext(arg0 context.Context, arg1 *globalaccelerator.UpdateEndpointGroupInput) (*globalaccelerator.UpdateEndpointGroupOutput, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateEndpointGroupWithContext", arg0, arg1) + ret0, _ := ret[0].(*globalaccelerator.UpdateEndpointGroupOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateEndpointGroupWithContext indicates an expected call of UpdateEndpointGroupWithContext. +func (mr *MockGlobalAcceleratorMockRecorder) UpdateEndpointGroupWithContext(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEndpointGroupWithContext", reflect.TypeOf((*MockGlobalAccelerator)(nil).UpdateEndpointGroupWithContext), arg0, arg1) +} + // UpdateListenerWithContext mocks base method. func (m *MockGlobalAccelerator) UpdateListenerWithContext(arg0 context.Context, arg1 *globalaccelerator.UpdateListenerInput) (*globalaccelerator.UpdateListenerOutput, error) { m.ctrl.T.Helper() diff --git a/pkg/deploy/aga/endpoint_group_manager.go b/pkg/deploy/aga/endpoint_group_manager.go new file mode 100644 index 000000000..263056e48 --- /dev/null +++ b/pkg/deploy/aga/endpoint_group_manager.go @@ -0,0 +1,240 @@ +package aga + +import ( + "context" + "errors" + "fmt" + awssdk "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/globalaccelerator" + agatypes "github.com/aws/aws-sdk-go-v2/service/globalaccelerator/types" + "github.com/go-logr/logr" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" + agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" +) + +// EndpointGroupManager is responsible for managing AWS Global Accelerator endpoint groups. +type EndpointGroupManager interface { + // Create creates an endpoint group. + Create(ctx context.Context, resEndpointGroup *agamodel.EndpointGroup) (agamodel.EndpointGroupStatus, error) + + // Update updates an endpoint group. + Update(ctx context.Context, resEndpointGroup *agamodel.EndpointGroup, sdkEndpointGroup *agatypes.EndpointGroup) (agamodel.EndpointGroupStatus, error) + + // Delete deletes an endpoint group. + Delete(ctx context.Context, endpointGroupARN string) error +} + +// NewDefaultEndpointGroupManager constructs new defaultEndpointGroupManager. +func NewDefaultEndpointGroupManager(gaService services.GlobalAccelerator, logger logr.Logger) *defaultEndpointGroupManager { + return &defaultEndpointGroupManager{ + gaService: gaService, + logger: logger, + } +} + +var _ EndpointGroupManager = &defaultEndpointGroupManager{} + +// defaultEndpointGroupManager is the default implementation for EndpointGroupManager. +type defaultEndpointGroupManager struct { + gaService services.GlobalAccelerator + logger logr.Logger +} + +// buildSDKPortOverrides converts model port overrides to SDK port overrides +func (m *defaultEndpointGroupManager) buildSDKPortOverrides(modelPortOverrides []agamodel.PortOverride) []agatypes.PortOverride { + if len(modelPortOverrides) == 0 { + return nil + } + + portOverrides := make([]agatypes.PortOverride, 0, len(modelPortOverrides)) + for _, po := range modelPortOverrides { + portOverrides = append(portOverrides, agatypes.PortOverride{ + ListenerPort: awssdk.Int32(po.ListenerPort), + EndpointPort: awssdk.Int32(po.EndpointPort), + }) + } + return portOverrides +} + +func (m *defaultEndpointGroupManager) buildSDKCreateEndpointGroupInput(_ context.Context, resEndpointGroup *agamodel.EndpointGroup) (*globalaccelerator.CreateEndpointGroupInput, error) { + // Resolve listener ARN + listenerARN, err := resEndpointGroup.Spec.ListenerARN.Resolve(context.Background()) + if err != nil { + return nil, fmt.Errorf("failed to resolve listener ARN: %w", err) + } + + // Build create input + createInput := &globalaccelerator.CreateEndpointGroupInput{ + ListenerArn: awssdk.String(listenerARN), + EndpointGroupRegion: awssdk.String(resEndpointGroup.Spec.Region), + } + + // Convert TrafficDialPercentage from int32 to float32 if provided + if resEndpointGroup.Spec.TrafficDialPercentage != nil { + createInput.TrafficDialPercentage = awssdk.Float32(float32(*resEndpointGroup.Spec.TrafficDialPercentage)) + } + + // Add port overrides if specified + if len(resEndpointGroup.Spec.PortOverrides) > 0 { + createInput.PortOverrides = m.buildSDKPortOverrides(resEndpointGroup.Spec.PortOverrides) + } + + return createInput, nil +} + +func (m *defaultEndpointGroupManager) Create(ctx context.Context, resEndpointGroup *agamodel.EndpointGroup) (agamodel.EndpointGroupStatus, error) { + // Build create input + createInput, err := m.buildSDKCreateEndpointGroupInput(ctx, resEndpointGroup) + if err != nil { + return agamodel.EndpointGroupStatus{}, err + } + + // Create endpoint group + m.logger.V(1).Info("Creating endpoint group", + "stackID", resEndpointGroup.Stack().StackID(), + "resourceID", resEndpointGroup.ID()) + + createOutput, err := m.gaService.CreateEndpointGroupWithContext(ctx, createInput) + if err != nil { + return agamodel.EndpointGroupStatus{}, fmt.Errorf("failed to create endpoint group: %w", err) + } + + endpointGroup := createOutput.EndpointGroup + m.logger.Info("Successfully created endpoint group", + "stackID", resEndpointGroup.Stack().StackID(), + "resourceID", resEndpointGroup.ID(), + "endpointGroupARN", *endpointGroup.EndpointGroupArn) + + return agamodel.EndpointGroupStatus{ + EndpointGroupARN: *endpointGroup.EndpointGroupArn, + }, nil +} + +func (m *defaultEndpointGroupManager) buildSDKUpdateEndpointGroupInput(_ context.Context, resEndpointGroup *agamodel.EndpointGroup, sdkEndpointGroup *agatypes.EndpointGroup) (*globalaccelerator.UpdateEndpointGroupInput, error) { + // Build update input + updateInput := &globalaccelerator.UpdateEndpointGroupInput{ + EndpointGroupArn: sdkEndpointGroup.EndpointGroupArn, + } + + // Convert TrafficDialPercentage from int32 to float32 if provided + if resEndpointGroup.Spec.TrafficDialPercentage != nil { + updateInput.TrafficDialPercentage = awssdk.Float32(float32(*resEndpointGroup.Spec.TrafficDialPercentage)) + } + + // Add port overrides if specified + if len(resEndpointGroup.Spec.PortOverrides) > 0 { + updateInput.PortOverrides = m.buildSDKPortOverrides(resEndpointGroup.Spec.PortOverrides) + } + + return updateInput, nil +} + +func (m *defaultEndpointGroupManager) Update(ctx context.Context, resEndpointGroup *agamodel.EndpointGroup, sdkEndpointGroup *agatypes.EndpointGroup) (agamodel.EndpointGroupStatus, error) { + // Check if the endpoint group actually needs an update + if !m.isSDKEndpointGroupSettingsDrifted(resEndpointGroup, sdkEndpointGroup) { + m.logger.Info("No drift detected in endpoint group settings, skipping update", + "stackID", resEndpointGroup.Stack().StackID(), + "resourceID", resEndpointGroup.ID(), + "endpointGroupARN", *sdkEndpointGroup.EndpointGroupArn) + + return agamodel.EndpointGroupStatus{ + EndpointGroupARN: *sdkEndpointGroup.EndpointGroupArn, + }, nil + } + + m.logger.Info("Drift detected in endpoint group settings, updating", + "stackID", resEndpointGroup.Stack().StackID(), + "resourceID", resEndpointGroup.ID(), + "endpointGroupARN", *sdkEndpointGroup.EndpointGroupArn) + + // Build update input + updateInput, err := m.buildSDKUpdateEndpointGroupInput(ctx, resEndpointGroup, sdkEndpointGroup) + if err != nil { + return agamodel.EndpointGroupStatus{}, err + } + + // Update endpoint group + updateOutput, err := m.gaService.UpdateEndpointGroupWithContext(ctx, updateInput) + if err != nil { + return agamodel.EndpointGroupStatus{}, fmt.Errorf("failed to update endpoint group: %w", err) + } + + updatedEndpointGroup := updateOutput.EndpointGroup + m.logger.Info("Successfully updated endpoint group", + "stackID", resEndpointGroup.Stack().StackID(), + "resourceID", resEndpointGroup.ID(), + "endpointGroupARN", *updatedEndpointGroup.EndpointGroupArn) + + return agamodel.EndpointGroupStatus{ + EndpointGroupARN: *updatedEndpointGroup.EndpointGroupArn, + }, nil +} + +func (m *defaultEndpointGroupManager) Delete(ctx context.Context, endpointGroupARN string) error { + m.logger.Info("Deleting endpoint group", "endpointGroupARN", endpointGroupARN) + + deleteInput := &globalaccelerator.DeleteEndpointGroupInput{ + EndpointGroupArn: awssdk.String(endpointGroupARN), + } + + if _, err := m.gaService.DeleteEndpointGroupWithContext(ctx, deleteInput); err != nil { + // Check if it's a not found error - the endpoint group might have been already deleted + var apiErr *agatypes.EndpointGroupNotFoundException + if errors.As(err, &apiErr) { + m.logger.Info("Endpoint group already deleted", "endpointGroupARN", endpointGroupARN) + return nil + } + return fmt.Errorf("failed to delete endpoint group: %w", err) + } + + m.logger.Info("Successfully deleted endpoint group", "endpointGroupARN", endpointGroupARN) + return nil +} + +// isSDKEndpointGroupSettingsDrifted checks if the endpoint group configuration has drifted from the desired state +func (m *defaultEndpointGroupManager) isSDKEndpointGroupSettingsDrifted(resEndpointGroup *agamodel.EndpointGroup, sdkEndpointGroup *agatypes.EndpointGroup) bool { + // Cannot change region after creation, so we don't check for region drift + + // Check traffic dial percentage + if resEndpointGroup.Spec.TrafficDialPercentage != nil { + resTrafficDialPercentage := float32(*resEndpointGroup.Spec.TrafficDialPercentage) + sdkTrafficDialPercentage := awssdk.ToFloat32(sdkEndpointGroup.TrafficDialPercentage) + // Use a small epsilon for float comparison to avoid precision issues + const epsilon = 0.001 + if resTrafficDialPercentage < sdkTrafficDialPercentage-epsilon || resTrafficDialPercentage > sdkTrafficDialPercentage+epsilon { + return true + } + } else if sdkEndpointGroup.TrafficDialPercentage != nil { + // Resource has no traffic dial percentage but SDK does + return true + } + + // Check port overrides + if !m.arePortOverridesEqual(resEndpointGroup.Spec.PortOverrides, sdkEndpointGroup.PortOverrides) { + return true + } + + return false +} + +// arePortOverridesEqual compares port overrides from the resource model and SDK +func (m *defaultEndpointGroupManager) arePortOverridesEqual(modelPortOverrides []agamodel.PortOverride, sdkPortOverrides []agatypes.PortOverride) bool { + if len(modelPortOverrides) != len(sdkPortOverrides) { + return false + } + + // Convert to maps for easier comparison + modelMap := make(map[int32]int32) + for _, po := range modelPortOverrides { + modelMap[po.ListenerPort] = po.EndpointPort + } + + // Check if all SDK port overrides match the model + for _, po := range sdkPortOverrides { + if modelEndpointPort, exists := modelMap[awssdk.ToInt32(po.ListenerPort)]; !exists || modelEndpointPort != awssdk.ToInt32(po.EndpointPort) { + return false + } + } + + return true +} diff --git a/pkg/deploy/aga/endpoint_group_manager_mocks.go b/pkg/deploy/aga/endpoint_group_manager_mocks.go new file mode 100644 index 000000000..d3109662c --- /dev/null +++ b/pkg/deploy/aga/endpoint_group_manager_mocks.go @@ -0,0 +1,81 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: sigs.k8s.io/aws-load-balancer-controller/pkg/deploy/aga (interfaces: EndpointGroupManager) + +// Package aga is a generated GoMock package. +package aga + +import ( + context "context" + reflect "reflect" + + types "github.com/aws/aws-sdk-go-v2/service/globalaccelerator/types" + gomock "github.com/golang/mock/gomock" + aga "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" +) + +// MockEndpointGroupManager is a mock of EndpointGroupManager interface. +type MockEndpointGroupManager struct { + ctrl *gomock.Controller + recorder *MockEndpointGroupManagerMockRecorder +} + +// MockEndpointGroupManagerMockRecorder is the mock recorder for MockEndpointGroupManager. +type MockEndpointGroupManagerMockRecorder struct { + mock *MockEndpointGroupManager +} + +// NewMockEndpointGroupManager creates a new mock instance. +func NewMockEndpointGroupManager(ctrl *gomock.Controller) *MockEndpointGroupManager { + mock := &MockEndpointGroupManager{ctrl: ctrl} + mock.recorder = &MockEndpointGroupManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockEndpointGroupManager) EXPECT() *MockEndpointGroupManagerMockRecorder { + return m.recorder +} + +// Create mocks base method. +func (m *MockEndpointGroupManager) Create(arg0 context.Context, arg1 *aga.EndpointGroup) (aga.EndpointGroupStatus, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Create", arg0, arg1) + ret0, _ := ret[0].(aga.EndpointGroupStatus) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Create indicates an expected call of Create. +func (mr *MockEndpointGroupManagerMockRecorder) Create(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockEndpointGroupManager)(nil).Create), arg0, arg1) +} + +// Delete mocks base method. +func (m *MockEndpointGroupManager) Delete(arg0 context.Context, arg1 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Delete", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// Delete indicates an expected call of Delete. +func (mr *MockEndpointGroupManagerMockRecorder) Delete(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockEndpointGroupManager)(nil).Delete), arg0, arg1) +} + +// Update mocks base method. +func (m *MockEndpointGroupManager) Update(arg0 context.Context, arg1 *aga.EndpointGroup, arg2 *types.EndpointGroup) (aga.EndpointGroupStatus, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Update", arg0, arg1, arg2) + ret0, _ := ret[0].(aga.EndpointGroupStatus) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Update indicates an expected call of Update. +func (mr *MockEndpointGroupManagerMockRecorder) Update(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockEndpointGroupManager)(nil).Update), arg0, arg1, arg2) +} diff --git a/pkg/deploy/aga/endpoint_group_manager_test.go b/pkg/deploy/aga/endpoint_group_manager_test.go new file mode 100644 index 000000000..53959afa8 --- /dev/null +++ b/pkg/deploy/aga/endpoint_group_manager_test.go @@ -0,0 +1,480 @@ +package aga + +import ( + "context" + "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + awssdk "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/globalaccelerator" + agatypes "github.com/aws/aws-sdk-go-v2/service/globalaccelerator/types" + "github.com/go-logr/logr" + "github.com/stretchr/testify/assert" + agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" +) + +func Test_defaultEndpointGroupManager_buildSDKPortOverrides(t *testing.T) { + tests := []struct { + name string + modelPortOverrides []agamodel.PortOverride + want []agatypes.PortOverride + }{ + { + name: "nil model port overrides", + modelPortOverrides: nil, + want: nil, + }, + { + name: "empty model port overrides", + modelPortOverrides: []agamodel.PortOverride{}, + want: nil, + }, + { + name: "single port override", + modelPortOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + }, + want: []agatypes.PortOverride{ + { + ListenerPort: aws.Int32(80), + EndpointPort: aws.Int32(8080), + }, + }, + }, + { + name: "multiple port overrides", + modelPortOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + { + ListenerPort: 443, + EndpointPort: 8443, + }, + }, + want: []agatypes.PortOverride{ + { + ListenerPort: aws.Int32(80), + EndpointPort: aws.Int32(8080), + }, + { + ListenerPort: aws.Int32(443), + EndpointPort: aws.Int32(8443), + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logger := logr.Discard() + m := &defaultEndpointGroupManager{ + logger: logger, + } + got := m.buildSDKPortOverrides(tt.modelPortOverrides) + + // Compare nil vs nil + if tt.want == nil && got == nil { + // Both are nil, this is correct + return + } + + // Compare lengths + assert.Equal(t, len(tt.want), len(got)) + + // Compare individual elements + for i, wantPO := range tt.want { + assert.Equal(t, awssdk.ToInt32(wantPO.ListenerPort), awssdk.ToInt32(got[i].ListenerPort)) + assert.Equal(t, awssdk.ToInt32(wantPO.EndpointPort), awssdk.ToInt32(got[i].EndpointPort)) + } + }) + } +} + +func Test_defaultEndpointGroupManager_arePortOverridesEqual(t *testing.T) { + tests := []struct { + name string + modelPortOverrides []agamodel.PortOverride + sdkPortOverrides []agatypes.PortOverride + want bool + }{ + { + name: "both nil - equal", + modelPortOverrides: nil, + sdkPortOverrides: nil, + want: true, + }, + { + name: "one empty one nil - equal", + modelPortOverrides: nil, + sdkPortOverrides: []agatypes.PortOverride{}, + want: true, + }, + { + name: "both empty - equal", + modelPortOverrides: []agamodel.PortOverride{}, + sdkPortOverrides: []agatypes.PortOverride{}, + want: true, + }, + { + name: "model empty, sdk not empty - not equal", + modelPortOverrides: []agamodel.PortOverride{}, + sdkPortOverrides: []agatypes.PortOverride{ + { + ListenerPort: aws.Int32(80), + EndpointPort: aws.Int32(8080), + }, + }, + want: false, + }, + { + name: "model not empty, sdk empty - not equal", + modelPortOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + }, + sdkPortOverrides: []agatypes.PortOverride{}, + want: false, + }, + { + name: "different lengths - not equal", + modelPortOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + }, + sdkPortOverrides: []agatypes.PortOverride{ + { + ListenerPort: aws.Int32(80), + EndpointPort: aws.Int32(8080), + }, + { + ListenerPort: aws.Int32(443), + EndpointPort: aws.Int32(8443), + }, + }, + want: false, + }, + { + name: "same length but different values - not equal", + modelPortOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + }, + sdkPortOverrides: []agatypes.PortOverride{ + { + ListenerPort: aws.Int32(80), + EndpointPort: aws.Int32(9090), // Different endpoint port + }, + }, + want: false, + }, + { + name: "same values, same order - equal", + modelPortOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + { + ListenerPort: 443, + EndpointPort: 8443, + }, + }, + sdkPortOverrides: []agatypes.PortOverride{ + { + ListenerPort: aws.Int32(80), + EndpointPort: aws.Int32(8080), + }, + { + ListenerPort: aws.Int32(443), + EndpointPort: aws.Int32(8443), + }, + }, + want: true, + }, + { + name: "same values, different order - equal (order doesn't matter)", + modelPortOverrides: []agamodel.PortOverride{ + { + ListenerPort: 443, + EndpointPort: 8443, + }, + { + ListenerPort: 80, + EndpointPort: 8080, + }, + }, + sdkPortOverrides: []agatypes.PortOverride{ + { + ListenerPort: aws.Int32(80), + EndpointPort: aws.Int32(8080), + }, + { + ListenerPort: aws.Int32(443), + EndpointPort: aws.Int32(8443), + }, + }, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logger := logr.Discard() + m := &defaultEndpointGroupManager{ + logger: logger, + } + got := m.arePortOverridesEqual(tt.modelPortOverrides, tt.sdkPortOverrides) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_defaultEndpointGroupManager_isSDKEndpointGroupSettingsDrifted(t *testing.T) { + tests := []struct { + name string + resEndpointGroup *agamodel.EndpointGroup + sdkEndpointGroup *agatypes.EndpointGroup + want bool + }{ + { + name: "no drift - all values match", + resEndpointGroup: &agamodel.EndpointGroup{ + Spec: agamodel.EndpointGroupSpec{ + Region: "us-west-2", + TrafficDialPercentage: aws.Int32(100), + PortOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + }, + }, + }, + sdkEndpointGroup: &agatypes.EndpointGroup{ + EndpointGroupRegion: aws.String("us-west-2"), + TrafficDialPercentage: aws.Float32(100.0), + PortOverrides: []agatypes.PortOverride{ + { + ListenerPort: aws.Int32(80), + EndpointPort: aws.Int32(8080), + }, + }, + }, + want: false, + }, + { + name: "traffic dial percentage differs", + resEndpointGroup: &agamodel.EndpointGroup{ + Spec: agamodel.EndpointGroupSpec{ + Region: "us-west-2", + TrafficDialPercentage: aws.Int32(50), + PortOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + }, + }, + }, + sdkEndpointGroup: &agatypes.EndpointGroup{ + EndpointGroupRegion: aws.String("us-west-2"), + TrafficDialPercentage: aws.Float32(100.0), + PortOverrides: []agatypes.PortOverride{ + { + ListenerPort: aws.Int32(80), + EndpointPort: aws.Int32(8080), + }, + }, + }, + want: true, + }, + { + name: "traffic dial percentage small difference within epsilon - no drift", + resEndpointGroup: &agamodel.EndpointGroup{ + Spec: agamodel.EndpointGroupSpec{ + Region: "us-west-2", + TrafficDialPercentage: aws.Int32(100), + PortOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + }, + }, + }, + sdkEndpointGroup: &agatypes.EndpointGroup{ + EndpointGroupRegion: aws.String("us-west-2"), + TrafficDialPercentage: aws.Float32(100.0005), // Small difference within epsilon + PortOverrides: []agatypes.PortOverride{ + { + ListenerPort: aws.Int32(80), + EndpointPort: aws.Int32(8080), + }, + }, + }, + want: false, + }, + { + name: "port overrides differ", + resEndpointGroup: &agamodel.EndpointGroup{ + Spec: agamodel.EndpointGroupSpec{ + Region: "us-west-2", + TrafficDialPercentage: aws.Int32(100), + PortOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + }, + }, + }, + sdkEndpointGroup: &agatypes.EndpointGroup{ + EndpointGroupRegion: aws.String("us-west-2"), + TrafficDialPercentage: aws.Float32(100.0), + PortOverrides: []agatypes.PortOverride{ + { + ListenerPort: aws.Int32(80), + EndpointPort: aws.Int32(9090), // Different endpoint port + }, + }, + }, + want: true, + }, + { + name: "model has no traffic dial percentage, sdk does - drift", + resEndpointGroup: &agamodel.EndpointGroup{ + Spec: agamodel.EndpointGroupSpec{ + Region: "us-west-2", + PortOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + }, + }, + }, + sdkEndpointGroup: &agatypes.EndpointGroup{ + EndpointGroupRegion: aws.String("us-west-2"), + TrafficDialPercentage: aws.Float32(100.0), + PortOverrides: []agatypes.PortOverride{ + { + ListenerPort: aws.Int32(80), + EndpointPort: aws.Int32(8080), + }, + }, + }, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logger := logr.Discard() + m := &defaultEndpointGroupManager{ + logger: logger, + } + got := m.isSDKEndpointGroupSettingsDrifted(tt.resEndpointGroup, tt.sdkEndpointGroup) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_defaultEndpointGroupManager_buildSDKCreateEndpointGroupInput(t *testing.T) { + testListenerARN := "arn:aws:globalaccelerator::123456789012:listener/1234abcd-abcd-1234-abcd-1234abcdefgh/abcdefghi" + mockStack := core.NewDefaultStack(core.StackID{Namespace: "test-namespace", Name: "test-name"}) + + tests := []struct { + name string + resEndpointGroup *agamodel.EndpointGroup + want *globalaccelerator.CreateEndpointGroupInput + wantErr bool + }{ + { + name: "Standard endpoint group with all fields", + resEndpointGroup: &agamodel.EndpointGroup{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::EndpointGroup", "endpoint-group-1"), + Spec: agamodel.EndpointGroupSpec{ + ListenerARN: core.LiteralStringToken(testListenerARN), + Region: "us-west-2", + TrafficDialPercentage: aws.Int32(85), + PortOverrides: []agamodel.PortOverride{ + { + ListenerPort: 80, + EndpointPort: 8080, + }, + { + ListenerPort: 443, + EndpointPort: 8443, + }, + }, + }, + }, + want: &globalaccelerator.CreateEndpointGroupInput{ + ListenerArn: aws.String(testListenerARN), + EndpointGroupRegion: aws.String("us-west-2"), + TrafficDialPercentage: aws.Float32(85.0), + PortOverrides: []agatypes.PortOverride{ + { + ListenerPort: aws.Int32(80), + EndpointPort: aws.Int32(8080), + }, + { + ListenerPort: aws.Int32(443), + EndpointPort: aws.Int32(8443), + }, + }, + }, + wantErr: false, + }, + { + name: "Minimal endpoint group", + resEndpointGroup: &agamodel.EndpointGroup{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::EndpointGroup", "endpoint-group-2"), + Spec: agamodel.EndpointGroupSpec{ + ListenerARN: core.LiteralStringToken(testListenerARN), + Region: "us-west-2", + // No TrafficDialPercentage or PortOverrides + }, + }, + want: &globalaccelerator.CreateEndpointGroupInput{ + ListenerArn: aws.String(testListenerARN), + EndpointGroupRegion: aws.String("us-west-2"), + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create endpoint group manager + m := &defaultEndpointGroupManager{ + gaService: nil, // Not needed for this test + logger: logr.Discard(), + } + + // Call the method being tested + got, err := m.buildSDKCreateEndpointGroupInput(context.Background(), tt.resEndpointGroup) + + // Check if error status matches expected + if (err != nil) != tt.wantErr { + t.Errorf("buildSDKCreateEndpointGroupInput() error = %v, wantErr %v", err, tt.wantErr) + return + } + + // Check if the result matches expected + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/pkg/deploy/aga/endpoint_group_synthesizer.go b/pkg/deploy/aga/endpoint_group_synthesizer.go new file mode 100644 index 000000000..b1e4f31c3 --- /dev/null +++ b/pkg/deploy/aga/endpoint_group_synthesizer.go @@ -0,0 +1,485 @@ +package aga + +import ( + "context" + "k8s.io/apimachinery/pkg/util/sets" + "strings" + + awssdk "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/globalaccelerator" + agatypes "github.com/aws/aws-sdk-go-v2/service/globalaccelerator/types" + "github.com/go-logr/logr" + "github.com/pkg/errors" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" + agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" + "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" +) + +// NewEndpointGroupSynthesizer constructs new EndpointGroupSynthesizer +func NewEndpointGroupSynthesizer( + gaService services.GlobalAccelerator, + endpointGroupManager EndpointGroupManager, + logger logr.Logger, + stack core.Stack) *endpointGroupSynthesizer { + + return &endpointGroupSynthesizer{ + gaService: gaService, + endpointGroupManager: endpointGroupManager, + logger: logger, + stack: stack, + } +} + +// endpointGroupSynthesizer synthesizes AGA EndpointGroup resources +type endpointGroupSynthesizer struct { + gaService services.GlobalAccelerator + endpointGroupManager EndpointGroupManager + logger logr.Logger + stack core.Stack +} + +// EndpointPortConflict describes a conflict between endpoint ports in different groups +type EndpointPortConflict struct { + Port int32 + ConflictingGroups []string // ARNs of endpoint groups with this port + ListenerARNs []string // ARNs of listeners that contain the conflicting endpoint groups +} + +// endpointGroupAndSDKEndpointGroup contains a pair of endpoint group resource and its SDK endpoint group +type resAndSDKGroupPair struct { + resEndpointGroup *agamodel.EndpointGroup + sdkEndpointGroup *agatypes.EndpointGroup +} + +// mapEndpointGroupsByListenerARN maps endpoint groups by their parent listener ARN +func (s *endpointGroupSynthesizer) mapEndpointGroupsByListenerARN(ctx context.Context, resEndpointGroups []*agamodel.EndpointGroup) (map[string][]*agamodel.EndpointGroup, error) { + endpointGroupsByListenerARN := make(map[string][]*agamodel.EndpointGroup) + + for _, eg := range resEndpointGroups { + listenerARN, err := eg.Spec.ListenerARN.Resolve(ctx) + if err != nil { + return nil, errors.Wrapf(err, "failed to resolve listener ARN for endpoint group %s", eg.ID()) + } + endpointGroupsByListenerARN[listenerARN] = append(endpointGroupsByListenerARN[listenerARN], eg) + } + + return endpointGroupsByListenerARN, nil +} + +// areListenersEquivalent checks if two listener ARNs refer to the same underlying listener +// Since AWS Global Accelerator ARNs are consistent across resources, a simple string comparison is sufficient +func areListenersEquivalent(listener1, listener2 string) bool { + return listener1 == listener2 +} + +// EndpointPortInfo represents endpoint port usage information for a specific region +type EndpointPortInfo struct { + // Region where this endpoint port is used + Region string + // Port number + Port int32 + // ListenerARN that uses this port + ListenerARN string + // EndpointGroupARN of the group using this port + EndpointGroupARN string +} + +// detectConflictsWithSDKEndpointGroups identifies endpoint port conflicts between desired +// endpoint groups and existing SDK endpoint groups within the same AWS region. +// +// AWS Global Accelerator enforces a critical constraint: within a single region, endpoint ports +// must be unique across endpoint groups belonging to different listeners. This prevents the +// same destination port from being used by multiple listeners in the same region, which would +// create ambiguity for traffic routing. +// +// The function compares endpoint ports used by desired endpoint groups against those +// already in use by existing SDK endpoint groups. It returns a map of conflicting ports +// to their respective conflicting endpoint group ARNs for resolution. +// +// For more details on port override constraints, see: +// https://docs.aws.amazon.com/global-accelerator/latest/dg/about-endpoint-groups-port-override.html +func (s *endpointGroupSynthesizer) detectConflictsWithSDKEndpointGroups( + ctx context.Context, + resEndpointGroups []*agamodel.EndpointGroup, + sdkEndpointGroups []agatypes.EndpointGroup) (map[int32][]string, error) { + + // Step 1: Collect all desired endpoint ports by region and listener + var desiredPortInfos []EndpointPortInfo + + for _, resGroup := range resEndpointGroups { + // Skip groups with no port overrides + if resGroup.Spec.PortOverrides == nil || len(resGroup.Spec.PortOverrides) == 0 { + continue + } + + // Get listener ARN for this resource group + listenerARN, err := resGroup.Spec.ListenerARN.Resolve(ctx) + if err != nil { + return nil, errors.Wrapf(err, "failed to resolve listener ARN for endpoint group %s", resGroup.ID()) + } + + // Add all endpoint ports this group wants to use + for _, po := range resGroup.Spec.PortOverrides { + desiredPortInfos = append(desiredPortInfos, EndpointPortInfo{ + Region: resGroup.Spec.Region, + Port: po.EndpointPort, + ListenerARN: listenerARN, + }) + } + } + + // No desired port overrides means no conflicts return early + if len(desiredPortInfos) == 0 { + return nil, nil + } + + // Step 2: Collect all SDK endpoint ports by region and listener + var sdkPortInfos []EndpointPortInfo + + for _, sdkGroup := range sdkEndpointGroups { + region := awssdk.ToString(sdkGroup.EndpointGroupRegion) + groupARN := awssdk.ToString(sdkGroup.EndpointGroupArn) + listenerARN := extractListenerARNFromEndpointGroupARN(groupARN) + + // Add all ports for this SDK group + for _, po := range sdkGroup.PortOverrides { + port := awssdk.ToInt32(po.EndpointPort) + sdkPortInfos = append(sdkPortInfos, EndpointPortInfo{ + Region: region, + Port: port, + ListenerARN: listenerARN, + EndpointGroupARN: groupARN, + }) + } + } + + // Step 3: Find conflicts by comparing different listeners using the same port in the same region + conflicts := make(map[int32][]string) // map[port][]conflictingGroupARNs + + for _, desiredInfo := range desiredPortInfos { + for _, sdkInfo := range sdkPortInfos { + // Check if they're in the same region and using the same port + if desiredInfo.Region == sdkInfo.Region && desiredInfo.Port == sdkInfo.Port { + // Check if they're from different listeners + if !areListenersEquivalent(desiredInfo.ListenerARN, sdkInfo.ListenerARN) { + // If different listeners use same port in same region, it's a conflict + conflicts[desiredInfo.Port] = append(conflicts[desiredInfo.Port], sdkInfo.EndpointGroupARN) + + s.logger.V(1).Info("Detected endpoint port conflict", + "endpointPort", desiredInfo.Port, + "region", desiredInfo.Region, + "conflictingSDKGroup", sdkInfo.EndpointGroupARN) + } + } + } + } + + return conflicts, nil +} + +// extractListenerARNFromEndpointGroupARN extracts the listener ARN portion from an endpoint group ARN +// Returns empty string if the format doesn't match expectations +func extractListenerARNFromEndpointGroupARN(groupARN string) string { + // Expected format: arn:aws:globalaccelerator::123456789012:accelerator/abcd/listener/l-1234/endpoint-group/eg-1234 + + // Check for endpoint group pattern + if !strings.Contains(groupARN, "/endpoint-group/") { + return "" + } + + // Split by "endpoint-group" and take first part + parts := strings.Split(groupARN, "/endpoint-group/") + if len(parts) < 2 { + return "" + } + + return parts[0] +} + +// resolveConflictsWithSDKEndpointGroups resolves endpoint port conflicts by updating +// the existing SDK endpoint groups to remove conflicting port overrides +func (s *endpointGroupSynthesizer) resolveConflictsWithSDKEndpointGroups( + ctx context.Context, + conflicts map[int32][]string, + sdkEndpointGroups []agatypes.EndpointGroup) error { + + if len(conflicts) == 0 { + return nil + } + + s.logger.V(1).Info("Detected endpoint port conflicts with existing SDK endpoint groups", + "conflictCount", len(conflicts)) + + // Track which SDK groups need updating + sdkGroupUpdates := make(map[string][]agatypes.PortOverride) + + // For each conflict, we need to remove the conflicting port override from the SDK groups + for port, conflictingGroups := range conflicts { + for _, groupARN := range conflictingGroups { + // Find the SDK group with this ARN + for i, sdkGroup := range sdkEndpointGroups { + if awssdk.ToString(sdkGroup.EndpointGroupArn) == groupARN { + // Create a filtered list of port overrides excluding the conflicting one + updatedPortOverrides := make([]agatypes.PortOverride, 0, len(sdkGroup.PortOverrides)) + + for _, po := range sdkGroup.PortOverrides { + if awssdk.ToInt32(po.EndpointPort) != port { + updatedPortOverrides = append(updatedPortOverrides, po) + } + } + + // Store the updated port overrides for this group + sdkGroupUpdates[groupARN] = updatedPortOverrides + + // Update the SDK endpoint group in our local list to reflect the change + sdkEndpointGroups[i].PortOverrides = updatedPortOverrides + break + } + } + } + } + + // Update all the SDK endpoint groups that need updating + for groupARN, updatedPortOverrides := range sdkGroupUpdates { + s.logger.V(1).Info("Updating existing endpoint group to remove conflicting port overrides", + "endpointGroupARN", groupARN, + "updatedPortOverridesCount", len(updatedPortOverrides)) + + // Create update input + updateInput := &globalaccelerator.UpdateEndpointGroupInput{ + EndpointGroupArn: awssdk.String(groupARN), + PortOverrides: updatedPortOverrides, + } + + // Update the endpoint group + _, err := s.gaService.UpdateEndpointGroupWithContext(ctx, updateInput) + if err != nil { + return errors.Wrapf(err, "failed to update endpoint group %s to resolve port conflicts", groupARN) + } + } + + return nil +} + +// getAllEndpointGroupsInListeners returns all endpoint groups across all listeners +func (s *endpointGroupSynthesizer) getAllEndpointGroupsInListeners(ctx context.Context, listenerARNs []string) ([]agatypes.EndpointGroup, error) { + var allEndpointGroups []agatypes.EndpointGroup + for _, listenerARN := range listenerARNs { + endpointGroups, err := s.gaService.ListEndpointGroupsAsList(ctx, &globalaccelerator.ListEndpointGroupsInput{ + ListenerArn: awssdk.String(listenerARN), + }) + if err != nil { + return nil, errors.Wrapf(err, "failed to list endpoint groups for listener %s", listenerARN) + } + allEndpointGroups = append(allEndpointGroups, endpointGroups...) + } + + return allEndpointGroups, nil +} + +// Synthesize performs the actual synthesis of endpoint group resources +func (s *endpointGroupSynthesizer) Synthesize(ctx context.Context) error { + var resEndpointGroups []*agamodel.EndpointGroup + s.stack.ListResources(&resEndpointGroups) + + // Get listener ARNs from stack + var resListeners []*agamodel.Listener + s.stack.ListResources(&resListeners) + + // Nothing to process. No Listeners and endpoint groups in stack. + // This means we have already deleted all the unneeded listeners and its corresponding endpoint groups during listener synthesis. + if len(resListeners) == 0 { + return nil + } + + listenerARNs := make([]string, 0, len(resListeners)) + for _, resListener := range resListeners { + listenerARN, err := resListener.ListenerARN().Resolve(ctx) + if err != nil { + return errors.Wrapf(err, "failed to resolve listener ARN for resListener %s", resListener.ID()) + } + listenerARNs = append(listenerARNs, listenerARN) + } + + // Group endpoint groups by listener ARN + endpointGroupsByListenerARN, err := s.mapEndpointGroupsByListenerARN(ctx, resEndpointGroups) + if err != nil { + return err + } + + // Only detect conflicts for endpoint port duplicates if there are any desired endpoint groups in our stack + if len(resEndpointGroups) > 0 { + // Get endpoint groups and handle any conflicts before proceeding + if err := s.detectAndResolveEndpointGroupConflicts(ctx, resEndpointGroups, listenerARNs); err != nil { + return err + } + } + + // Process endpoint groups by listener ARN + for _, listenerARN := range listenerARNs { + resEndpointGroups := endpointGroupsByListenerARN[listenerARN] + if err := s.synthesizeEndpointGroupsOnListener(ctx, listenerARN, resEndpointGroups); err != nil { + return err + } + } + + return nil +} + +// PostSynthesize performs cleanup of endpoint group resources +// Currently not needed as deletion happens in synthesizeEndpointGroupsOnListener +func (s *endpointGroupSynthesizer) PostSynthesize(ctx context.Context) error { + return nil +} + +// detectAndResolveEndpointGroupConflicts handles endpoint group conflicts +// It detects conflicts between endpoint groups from different listeners, resolves them if needed +func (s *endpointGroupSynthesizer) detectAndResolveEndpointGroupConflicts(ctx context.Context, resEndpointGroups []*agamodel.EndpointGroup, listenerARNs []string) error { + s.logger.V(1).Info("Detecting and resolving endpoint group conflicts", + "endpointGroupCount", len(resEndpointGroups), + "listenerCount", len(listenerARNs)) + + // Get endpoint groups to check for conflicts with our desired state + allSDKEndpointGroups, err := s.getAllEndpointGroupsInListeners(ctx, listenerARNs) + if err != nil { + s.logger.Error(err, "Failed to get endpoint groups for conflict checking", + "listenerCount", len(listenerARNs)) + return errors.Wrap(err, "failed to get endpoint groups for conflict checking") + } + + // Detect conflicts between our desired endpoint groups and existing ones in AWS + sdkConflicts, err := s.detectConflictsWithSDKEndpointGroups(ctx, resEndpointGroups, allSDKEndpointGroups) + if err != nil { + s.logger.Error(err, "Failed to detect endpoint group conflicts", + "endpointGroupCount", len(resEndpointGroups), + "sdkEndpointGroupCount", len(allSDKEndpointGroups)) + return err + } + + // If conflicts with existing SDK endpoint groups are found, update them to remove conflicts + if len(sdkConflicts) > 0 { + for port, groups := range sdkConflicts { + s.logger.V(1).Info("Port conflict details", + "endpointPort", port, + "conflictingGroupCount", len(groups)) + } + + if err := s.resolveConflictsWithSDKEndpointGroups(ctx, sdkConflicts, allSDKEndpointGroups); err != nil { + s.logger.Error(err, "Failed to resolve endpoint group port conflicts", + "conflictCount", len(sdkConflicts)) + return errors.Wrap(err, "failed to resolve endpoint group port conflicts") + } + + s.logger.Info("Successfully resolved all endpoint group port conflicts", + "conflictCount", len(sdkConflicts)) + } else { + s.logger.V(1).Info("No endpoint group port conflicts detected with external groups") + } + + return nil +} + +// matchResAndSDKEndpointGroups matches resource endpoint groups with SDK endpoint groups using region as the unique key +func matchResAndSDKEndpointGroups(resEndpointGroups []*agamodel.EndpointGroup, sdkEndpointGroups []agatypes.EndpointGroup) ([]resAndSDKGroupPair, []*agamodel.EndpointGroup, []*agatypes.EndpointGroup) { + // Create maps for matching by region (region is the unique key within a listener) + sdkGroupsByRegion := make(map[string]*agatypes.EndpointGroup) + resGroupsByRegion := make(map[string]*agamodel.EndpointGroup) + + // Map resource endpoint groups by region + for _, resGroup := range resEndpointGroups { + region := resGroup.Spec.Region + resGroupsByRegion[region] = resGroup + } + + // Map SDK endpoint groups by region + for _, sdkGroup := range sdkEndpointGroups { + region := awssdk.ToString(sdkGroup.EndpointGroupRegion) + sdkGroupsByRegion[region] = &sdkGroup + } + resGroupRegions := sets.StringKeySet(resGroupsByRegion) + sdkGroupRegions := sets.StringKeySet(sdkGroupsByRegion) + + // Find matches and non-matches + var matchedResAndSDKGroups []resAndSDKGroupPair + var unmatchedResGroups []*agamodel.EndpointGroup + var unmatchedSDKGroups []*agatypes.EndpointGroup + + // Find matched pairs and unmatched resource groups + for _, region := range resGroupRegions.Intersection(sdkGroupRegions).List() { + resGroup := resGroupsByRegion[region] + sdkGroup := sdkGroupsByRegion[region] + matchedResAndSDKGroups = append(matchedResAndSDKGroups, resAndSDKGroupPair{ + resEndpointGroup: resGroup, + sdkEndpointGroup: sdkGroup, + }) + } + for _, region := range resGroupRegions.Difference(sdkGroupRegions).List() { + unmatchedResGroups = append(unmatchedResGroups, resGroupsByRegion[region]) + } + for _, region := range sdkGroupRegions.Difference(resGroupRegions).List() { + unmatchedSDKGroups = append(unmatchedSDKGroups, sdkGroupsByRegion[region]) + } + + return matchedResAndSDKGroups, unmatchedResGroups, unmatchedSDKGroups +} + +// synthesizeEndpointGroupsOnListener processes all endpoint groups for a specific listener +func (s *endpointGroupSynthesizer) synthesizeEndpointGroupsOnListener(ctx context.Context, listenerARN string, resEndpointGroups []*agamodel.EndpointGroup) error { + // Get existing endpoint groups for this listener from AWS + sdkEndpointGroups, err := s.getEndpointGroupsForListener(ctx, listenerARN) + if err != nil { + return errors.Wrapf(err, "failed to list endpoint groups for listener %s", listenerARN) + } + + // Match resource endpoint groups with SDK endpoint groups + matchedEndpointGroups, unmatchedResEndpointGroups, unmatchedSDKEndpointGroups := matchResAndSDKEndpointGroups(resEndpointGroups, sdkEndpointGroups) + + // Handle matched pairs - update them + for _, pair := range matchedEndpointGroups { + s.logger.Info("Updating existing endpoint group", + "endpointGroupArn", *pair.sdkEndpointGroup.EndpointGroupArn, + "region", pair.resEndpointGroup.Spec.Region) + status, err := s.endpointGroupManager.Update(ctx, pair.resEndpointGroup, pair.sdkEndpointGroup) + if err != nil { + return errors.Wrapf(err, "failed to update endpoint group %v", pair.resEndpointGroup.ID()) + } + + // Update the resource with the returned status + pair.resEndpointGroup.SetStatus(status) + } + + // Handle unmatched SDK endpoint groups - delete them + for _, sdkGroup := range unmatchedSDKEndpointGroups { + s.logger.Info("Deleting unneeded endpoint group", + "endpointGroupArn", *sdkGroup.EndpointGroupArn, + "region", *sdkGroup.EndpointGroupRegion) + egARN := awssdk.ToString(sdkGroup.EndpointGroupArn) + + if err := s.endpointGroupManager.Delete(ctx, egARN); err != nil { + return errors.Wrapf(err, "failed to delete unneeded endpoint group: %v", egARN) + } + } + + // Handle unmatched resource endpoint groups - create them + for _, resGroup := range unmatchedResEndpointGroups { + s.logger.Info("Creating new endpoint group", + "region", resGroup.Spec.Region) + status, err := s.endpointGroupManager.Create(ctx, resGroup) + if err != nil { + return errors.Wrapf(err, "failed to create endpoint group %v", resGroup.ID()) + } + + // Update the resource with the returned status + resGroup.SetStatus(status) + } + return nil +} + +// getEndpointGroupsForListener gets all endpoint groups for a specific listener +func (s *endpointGroupSynthesizer) getEndpointGroupsForListener(ctx context.Context, listenerARN string) ([]agatypes.EndpointGroup, error) { + listInput := &globalaccelerator.ListEndpointGroupsInput{ + ListenerArn: awssdk.String(listenerARN), + } + + return s.gaService.ListEndpointGroupsAsList(ctx, listInput) +} diff --git a/pkg/deploy/aga/endpoint_group_synthesizer_test.go b/pkg/deploy/aga/endpoint_group_synthesizer_test.go new file mode 100644 index 000000000..11a54ad81 --- /dev/null +++ b/pkg/deploy/aga/endpoint_group_synthesizer_test.go @@ -0,0 +1,1182 @@ +package aga + +import ( + "context" + "sort" + "testing" + + awssdk "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/globalaccelerator" + agatypes "github.com/aws/aws-sdk-go-v2/service/globalaccelerator/types" + "github.com/go-logr/logr" + "github.com/golang/mock/gomock" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" + agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" + "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" +) + +// Helper function to create a test endpoint group with port overrides +func createEndpointGroupWithPortOverrides(id string, region string, listenerARN string, portOverrides []agamodel.PortOverride) *agamodel.EndpointGroup { + mockStack := core.NewDefaultStack(core.StackID{Namespace: "test-namespace", Name: "test-name"}) + + return &agamodel.EndpointGroup{ + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::EndpointGroup", id), + Spec: agamodel.EndpointGroupSpec{ + ListenerARN: core.LiteralStringToken(listenerARN), + Region: region, + TrafficDialPercentage: awssdk.Int32(100), + PortOverrides: portOverrides, + }, + } +} + +func Test_matchResAndSDKEndpointGroups(t *testing.T) { + mockStack := core.NewDefaultStack(core.StackID{Namespace: "test-namespace", Name: "test-name"}) + testListenerARN := "arn:aws:globalaccelerator::123456789012:listener/1234abcd-abcd-1234-abcd-1234abcdefgh" + + tests := []struct { + name string + resEndpointGroups []*agamodel.EndpointGroup + sdkEndpointGroups []agatypes.EndpointGroup + wantMatchedPairs []struct { + resID string + sdkRegion string + sdkGroupARN string + } + wantUnmatchedResIDs []string + wantUnmatchedSDKRegions []string + }{ + { + name: "empty lists", + resEndpointGroups: []*agamodel.EndpointGroup{}, + sdkEndpointGroups: []agatypes.EndpointGroup{}, + wantMatchedPairs: []struct { + resID string + sdkRegion string + sdkGroupARN string + }{}, + wantUnmatchedResIDs: []string{}, + wantUnmatchedSDKRegions: []string{}, + }, + { + name: "single exact match by region", + resEndpointGroups: []*agamodel.EndpointGroup{ + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::EndpointGroup", "endpoint-group-1"), + Spec: agamodel.EndpointGroupSpec{ + ListenerARN: core.LiteralStringToken(testListenerARN), + Region: "us-west-2", + TrafficDialPercentage: awssdk.Int32(100), + }, + }, + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-1234"), + EndpointGroupRegion: awssdk.String("us-west-2"), + TrafficDialPercentage: awssdk.Float32(100.0), + }, + }, + wantMatchedPairs: []struct { + resID string + sdkRegion string + sdkGroupARN string + }{ + { + resID: "endpoint-group-1", + sdkRegion: "us-west-2", + sdkGroupARN: "arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-1234", + }, + }, + wantUnmatchedResIDs: []string{}, + wantUnmatchedSDKRegions: []string{}, + }, + { + name: "multiple matches by region", + resEndpointGroups: []*agamodel.EndpointGroup{ + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::EndpointGroup", "endpoint-group-1"), + Spec: agamodel.EndpointGroupSpec{ + ListenerARN: core.LiteralStringToken(testListenerARN), + Region: "us-west-2", + }, + }, + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::EndpointGroup", "endpoint-group-2"), + Spec: agamodel.EndpointGroupSpec{ + ListenerARN: core.LiteralStringToken(testListenerARN), + Region: "us-east-1", + }, + }, + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-west"), + EndpointGroupRegion: awssdk.String("us-west-2"), + }, + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-east"), + EndpointGroupRegion: awssdk.String("us-east-1"), + }, + }, + wantMatchedPairs: []struct { + resID string + sdkRegion string + sdkGroupARN string + }{ + { + resID: "endpoint-group-1", + sdkRegion: "us-west-2", + sdkGroupARN: "arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-west", + }, + { + resID: "endpoint-group-2", + sdkRegion: "us-east-1", + sdkGroupARN: "arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-east", + }, + }, + wantUnmatchedResIDs: []string{}, + wantUnmatchedSDKRegions: []string{}, + }, + { + name: "unmatched resource endpoint groups", + resEndpointGroups: []*agamodel.EndpointGroup{ + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::EndpointGroup", "endpoint-group-1"), + Spec: agamodel.EndpointGroupSpec{ + ListenerARN: core.LiteralStringToken(testListenerARN), + Region: "us-west-2", + }, + }, + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::EndpointGroup", "endpoint-group-2"), + Spec: agamodel.EndpointGroupSpec{ + ListenerARN: core.LiteralStringToken(testListenerARN), + Region: "eu-west-1", // No matching SDK endpoint group + }, + }, + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-west"), + EndpointGroupRegion: awssdk.String("us-west-2"), + }, + }, + wantMatchedPairs: []struct { + resID string + sdkRegion string + sdkGroupARN string + }{ + { + resID: "endpoint-group-1", + sdkRegion: "us-west-2", + sdkGroupARN: "arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-west", + }, + }, + wantUnmatchedResIDs: []string{"endpoint-group-2"}, + wantUnmatchedSDKRegions: []string{}, + }, + { + name: "unmatched SDK endpoint groups", + resEndpointGroups: []*agamodel.EndpointGroup{ + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::EndpointGroup", "endpoint-group-1"), + Spec: agamodel.EndpointGroupSpec{ + ListenerARN: core.LiteralStringToken(testListenerARN), + Region: "us-west-2", + }, + }, + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-west"), + EndpointGroupRegion: awssdk.String("us-west-2"), + }, + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-east"), + EndpointGroupRegion: awssdk.String("us-east-1"), // No matching resource endpoint group + }, + }, + wantMatchedPairs: []struct { + resID string + sdkRegion string + sdkGroupARN string + }{ + { + resID: "endpoint-group-1", + sdkRegion: "us-west-2", + sdkGroupARN: "arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-west", + }, + }, + wantUnmatchedResIDs: []string{}, + wantUnmatchedSDKRegions: []string{"us-east-1"}, + }, + { + name: "mixed matches and unmatches", + resEndpointGroups: []*agamodel.EndpointGroup{ + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::EndpointGroup", "endpoint-group-west"), + Spec: agamodel.EndpointGroupSpec{ + ListenerARN: core.LiteralStringToken(testListenerARN), + Region: "us-west-2", + }, + }, + { + ResourceMeta: core.NewResourceMeta(mockStack, "AWS::GlobalAccelerator::EndpointGroup", "endpoint-group-central"), + Spec: agamodel.EndpointGroupSpec{ + ListenerARN: core.LiteralStringToken(testListenerARN), + Region: "us-central-1", // No matching SDK endpoint group + }, + }, + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-west"), + EndpointGroupRegion: awssdk.String("us-west-2"), + }, + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-east"), + EndpointGroupRegion: awssdk.String("us-east-1"), // No matching resource endpoint group + }, + }, + wantMatchedPairs: []struct { + resID string + sdkRegion string + sdkGroupARN string + }{ + { + resID: "endpoint-group-west", + sdkRegion: "us-west-2", + sdkGroupARN: "arn:aws:globalaccelerator::123456789012:endpointgroup/1234abcd-abcd-1234-abcd-1234abcdefgh/eg-west", + }, + }, + wantUnmatchedResIDs: []string{"endpoint-group-central"}, + wantUnmatchedSDKRegions: []string{"us-east-1"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Run the function + matchedPairs, unmatchedResEndpointGroups, unmatchedSDKEndpointGroups := matchResAndSDKEndpointGroups(tt.resEndpointGroups, tt.sdkEndpointGroups) + + // Collect the actual pairs and IDs for verification + var actualMatchedPairs []struct { + resID string + sdkRegion string + sdkGroupARN string + } + + for _, pair := range matchedPairs { + actualMatchedPairs = append(actualMatchedPairs, struct { + resID string + sdkRegion string + sdkGroupARN string + }{ + resID: pair.resEndpointGroup.ID(), + sdkRegion: awssdk.ToString(pair.sdkEndpointGroup.EndpointGroupRegion), + sdkGroupARN: awssdk.ToString(pair.sdkEndpointGroup.EndpointGroupArn), + }) + } + + var actualUnmatchedResIDs []string + for _, eg := range unmatchedResEndpointGroups { + actualUnmatchedResIDs = append(actualUnmatchedResIDs, eg.ID()) + } + + var actualUnmatchedSDKRegions []string + for _, eg := range unmatchedSDKEndpointGroups { + actualUnmatchedSDKRegions = append(actualUnmatchedSDKRegions, awssdk.ToString(eg.EndpointGroupRegion)) + } + + // Sort all slices to ensure consistent comparison + sort.Slice(actualMatchedPairs, func(i, j int) bool { + if actualMatchedPairs[i].resID != actualMatchedPairs[j].resID { + return actualMatchedPairs[i].resID < actualMatchedPairs[j].resID + } + return actualMatchedPairs[i].sdkGroupARN < actualMatchedPairs[j].sdkGroupARN + }) + + sort.Slice(tt.wantMatchedPairs, func(i, j int) bool { + if tt.wantMatchedPairs[i].resID != tt.wantMatchedPairs[j].resID { + return tt.wantMatchedPairs[i].resID < tt.wantMatchedPairs[j].resID + } + return tt.wantMatchedPairs[i].sdkGroupARN < tt.wantMatchedPairs[j].sdkGroupARN + }) + + sort.Strings(actualUnmatchedResIDs) + sort.Strings(tt.wantUnmatchedResIDs) + sort.Strings(actualUnmatchedSDKRegions) + sort.Strings(tt.wantUnmatchedSDKRegions) + + // Verify matched pairs + assert.Equal(t, len(tt.wantMatchedPairs), len(actualMatchedPairs), "matched pairs count") + for i := range tt.wantMatchedPairs { + if i < len(actualMatchedPairs) { + assert.Equal(t, tt.wantMatchedPairs[i].resID, actualMatchedPairs[i].resID, "matched pair resID at index %d", i) + assert.Equal(t, tt.wantMatchedPairs[i].sdkRegion, actualMatchedPairs[i].sdkRegion, "matched pair sdkRegion at index %d", i) + assert.Equal(t, tt.wantMatchedPairs[i].sdkGroupARN, actualMatchedPairs[i].sdkGroupARN, "matched pair sdkGroupARN at index %d", i) + } + } + + // Handle nil vs empty slices + if len(actualUnmatchedResIDs) == 0 && len(tt.wantUnmatchedResIDs) == 0 { + // Both empty, no need to compare + } else { + // Verify unmatched resource endpoint groups + assert.ElementsMatch(t, tt.wantUnmatchedResIDs, actualUnmatchedResIDs, "unmatched resource endpoint groups") + } + + if len(actualUnmatchedSDKRegions) == 0 && len(tt.wantUnmatchedSDKRegions) == 0 { + // Both empty, no need to compare + } else { + // Verify unmatched SDK endpoint groups + assert.ElementsMatch(t, tt.wantUnmatchedSDKRegions, actualUnmatchedSDKRegions, "unmatched SDK endpoint groups") + } + }) + } +} + +// createSDKEndpointGroup creates an SDK endpoint group with port overrides +func createSDKEndpointGroup(arn string, region string, portOverrides []agatypes.PortOverride) agatypes.EndpointGroup { + return agatypes.EndpointGroup{ + EndpointGroupArn: awssdk.String(arn), + EndpointGroupRegion: awssdk.String(region), + PortOverrides: portOverrides, + } +} + +func Test_endpointGroupSynthesizer_getAllEndpointGroupsInListeners(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + mockGA := services.NewMockGlobalAccelerator(mockCtrl) + + // Create synthesizer instance with mock GA service + s := &endpointGroupSynthesizer{ + gaService: mockGA, + logger: logr.Discard(), + } + + // Test cases + tests := []struct { + name string + listenerARNs []string + mockSetup func(mockGA *services.MockGlobalAccelerator) + expectedGroups []agatypes.EndpointGroup + expectError bool + }{ + { + name: "empty listener ARNs list", + listenerARNs: []string{}, + mockSetup: func(mockGA *services.MockGlobalAccelerator) {}, + expectedGroups: []agatypes.EndpointGroup{}, + expectError: false, + }, + { + name: "single listener with no endpoint groups", + listenerARNs: []string{"arn:aws:globalaccelerator::123456789012:listener/listener1"}, + mockSetup: func(mockGA *services.MockGlobalAccelerator) { + mockGA.EXPECT(). + ListEndpointGroupsAsList(gomock.Any(), &globalaccelerator.ListEndpointGroupsInput{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:listener/listener1"), + }). + Return([]agatypes.EndpointGroup{}, nil) + }, + expectedGroups: []agatypes.EndpointGroup{}, + expectError: false, + }, + { + name: "single listener with one endpoint group", + listenerARNs: []string{"arn:aws:globalaccelerator::123456789012:listener/listener1"}, + mockSetup: func(mockGA *services.MockGlobalAccelerator) { + mockGA.EXPECT(). + ListEndpointGroupsAsList(gomock.Any(), &globalaccelerator.ListEndpointGroupsInput{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:listener/listener1"), + }). + Return([]agatypes.EndpointGroup{ + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/eg1"), + EndpointGroupRegion: awssdk.String("us-west-2"), + }, + }, nil) + }, + expectedGroups: []agatypes.EndpointGroup{ + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/eg1"), + EndpointGroupRegion: awssdk.String("us-west-2"), + }, + }, + expectError: false, + }, + { + name: "multiple listeners with endpoint groups", + listenerARNs: []string{ + "arn:aws:globalaccelerator::123456789012:listener/listener1", + "arn:aws:globalaccelerator::123456789012:listener/listener2", + }, + mockSetup: func(mockGA *services.MockGlobalAccelerator) { + // First listener + mockGA.EXPECT(). + ListEndpointGroupsAsList(gomock.Any(), &globalaccelerator.ListEndpointGroupsInput{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:listener/listener1"), + }). + Return([]agatypes.EndpointGroup{ + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/eg1"), + EndpointGroupRegion: awssdk.String("us-west-2"), + }, + }, nil) + + // Second listener + mockGA.EXPECT(). + ListEndpointGroupsAsList(gomock.Any(), &globalaccelerator.ListEndpointGroupsInput{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:listener/listener2"), + }). + Return([]agatypes.EndpointGroup{ + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/eg2"), + EndpointGroupRegion: awssdk.String("eu-west-1"), + }, + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/eg3"), + EndpointGroupRegion: awssdk.String("ap-southeast-1"), + }, + }, nil) + }, + expectedGroups: []agatypes.EndpointGroup{ + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/eg1"), + EndpointGroupRegion: awssdk.String("us-west-2"), + }, + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/eg2"), + EndpointGroupRegion: awssdk.String("eu-west-1"), + }, + { + EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/eg3"), + EndpointGroupRegion: awssdk.String("ap-southeast-1"), + }, + }, + expectError: false, + }, + { + name: "error retrieving endpoint groups", + listenerARNs: []string{"arn:aws:globalaccelerator::123456789012:listener/listener1"}, + mockSetup: func(mockGA *services.MockGlobalAccelerator) { + mockGA.EXPECT(). + ListEndpointGroupsAsList(gomock.Any(), &globalaccelerator.ListEndpointGroupsInput{ + ListenerArn: awssdk.String("arn:aws:globalaccelerator::123456789012:listener/listener1"), + }). + Return(nil, errors.New("API error")) + }, + expectedGroups: nil, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup mock expectations + tt.mockSetup(mockGA) + + // Call the function + ctx := context.Background() + endpointGroups, err := s.getAllEndpointGroupsInListeners(ctx, tt.listenerARNs) + + // Check error + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + // Check endpoint groups + assert.Equal(t, len(tt.expectedGroups), len(endpointGroups)) + + if len(tt.expectedGroups) > 0 { + // Compare each endpoint group + for i, expectedGroup := range tt.expectedGroups { + if i < len(endpointGroups) { + assert.Equal(t, awssdk.ToString(expectedGroup.EndpointGroupArn), + awssdk.ToString(endpointGroups[i].EndpointGroupArn)) + assert.Equal(t, awssdk.ToString(expectedGroup.EndpointGroupRegion), + awssdk.ToString(endpointGroups[i].EndpointGroupRegion)) + } + } + } + }) + } +} + +func Test_endpointGroupSynthesizer_detectConflictsWithSDKEndpointGroups(t *testing.T) { + testListener1ARN := "arn:aws:globalaccelerator::123456789012:accelerator/1234abcd/listener/l-1" + testListener2ARN := "arn:aws:globalaccelerator::123456789012:accelerator/1234abcd/listener/l-2" + tests := []struct { + name string + resEndpointGroups []*agamodel.EndpointGroup + sdkEndpointGroups []agatypes.EndpointGroup + wantConflictCount int + wantConflictPorts []int32 + }{ + { + name: "no endpoint groups", + resEndpointGroups: []*agamodel.EndpointGroup{}, + sdkEndpointGroups: []agatypes.EndpointGroup{}, + wantConflictCount: 0, + wantConflictPorts: []int32{}, + }, + { + name: "no port overrides", + resEndpointGroups: []*agamodel.EndpointGroup{ + createEndpointGroupWithPortOverrides("eg-1", "us-west-2", testListener1ARN, nil), + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + createSDKEndpointGroup("arn:aws:globalaccelerator::123456789012:endpointgroup/eg-sdk", "us-east-1", nil), + }, + wantConflictCount: 0, + wantConflictPorts: []int32{}, + }, + { + name: "no conflicts - different endpoint ports", + resEndpointGroups: []*agamodel.EndpointGroup{ + createEndpointGroupWithPortOverrides("eg-1", "us-west-2", testListener1ARN, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + }), + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + createSDKEndpointGroup( + "arn:aws:globalaccelerator::123456789012:endpointgroup/eg-sdk", + "us-east-1", + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(443), EndpointPort: awssdk.Int32(9090)}, // Different port + }, + ), + }, + wantConflictCount: 0, + wantConflictPorts: []int32{}, + }, + { + name: "no conflicts - same endpoint port but different regions", + resEndpointGroups: []*agamodel.EndpointGroup{ + createEndpointGroupWithPortOverrides("eg-1", "us-west-2", testListener1ARN, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + }), + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + createSDKEndpointGroup( + "arn:aws:globalaccelerator::123456789012:endpointgroup/eg-sdk", + "us-east-1", // Different region + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(443), EndpointPort: awssdk.Int32(8080)}, // Same endpoint port but different region + }, + ), + }, + wantConflictCount: 0, // No conflict because different regions + wantConflictPorts: []int32{}, + }, + { + name: "single conflict - same region", + resEndpointGroups: []*agamodel.EndpointGroup{ + createEndpointGroupWithPortOverrides("eg-1", "us-west-2", testListener1ARN, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + }), + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + createSDKEndpointGroup( + "arn:aws:DIFFERENT:SERVICE:123456789012:endpoint-group/TEST_DIFFERENT_ARN/eg-sdk", + "us-west-2", // Same region + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(443), EndpointPort: awssdk.Int32(8080)}, // Same endpoint port + }, + ), + }, + wantConflictCount: 1, + wantConflictPorts: []int32{8080}, + }, + { + name: "multiple conflicts with same SDK group - same region", + resEndpointGroups: []*agamodel.EndpointGroup{ + createEndpointGroupWithPortOverrides("eg-1", "us-west-2", testListener1ARN, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + {ListenerPort: 443, EndpointPort: 8443}, + }), + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + createSDKEndpointGroup( + "arn:aws:DIFFERENT:SERVICE:123456789012:endpoint-group/DIFFERENT-TYPES/eg-sdk", + "us-west-2", // Same region + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(81), EndpointPort: awssdk.Int32(8080)}, // Same as first port + {ListenerPort: awssdk.Int32(444), EndpointPort: awssdk.Int32(8443)}, // Same as second port + }, + ), + }, + wantConflictCount: 2, + wantConflictPorts: []int32{8080, 8443}, + }, + { + name: "multiple conflicts with same SDK group - different region (no conflict)", + resEndpointGroups: []*agamodel.EndpointGroup{ + createEndpointGroupWithPortOverrides("eg-1", "us-west-2", testListener1ARN, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + {ListenerPort: 443, EndpointPort: 8443}, + }), + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + createSDKEndpointGroup( + "arn:aws:globalaccelerator::123456789012:endpointgroup/eg-sdk", + "us-east-1", // Different region + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(81), EndpointPort: awssdk.Int32(8080)}, // Same endpoint port but different region + {ListenerPort: awssdk.Int32(444), EndpointPort: awssdk.Int32(8443)}, // Same endpoint port but different region + }, + ), + }, + wantConflictCount: 0, // No conflicts because different regions + wantConflictPorts: []int32{}, + }, + { + name: "multiple conflicts with different SDK groups - same regions", + resEndpointGroups: []*agamodel.EndpointGroup{ + createEndpointGroupWithPortOverrides("eg-1", "us-west-2", testListener1ARN, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + {ListenerPort: 443, EndpointPort: 8443}, + }), + createEndpointGroupWithPortOverrides("eg-2", "eu-west-1", testListener2ARN, []agamodel.PortOverride{ + {ListenerPort: 81, EndpointPort: 9090}, + }), + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + createSDKEndpointGroup( + "arn:aws:globalaccelerator::123456789012:endpointgroup/TEST_DIFFERENT_ARN_1/eg-sdk-1", + "us-west-2", // Same region as first resource group + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(82), EndpointPort: awssdk.Int32(8080)}, // Conflicts with eg-1 + }, + ), + createSDKEndpointGroup( + "arn:aws:globalaccelerator::123456789012:endpointgroup/TEST_DIFFERENT_ARN_2/eg-sdk-2", + "eu-west-1", // Same region as second resource group + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(83), EndpointPort: awssdk.Int32(9090)}, // Conflicts with eg-2 + }, + ), + }, + wantConflictCount: 2, + wantConflictPorts: []int32{8080, 9090}, + }, + { + name: "multiple conflicts with different SDK groups - different regions (no conflicts)", + resEndpointGroups: []*agamodel.EndpointGroup{ + createEndpointGroupWithPortOverrides("eg-1", "us-west-2", testListener1ARN, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + }), + createEndpointGroupWithPortOverrides("eg-2", "eu-west-1", testListener2ARN, []agamodel.PortOverride{ + {ListenerPort: 81, EndpointPort: 9090}, + }), + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + createSDKEndpointGroup( + "arn:aws:globalaccelerator::123456789012:endpointgroup/eg-sdk-1", + "us-east-1", // Different region than first resource group + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(82), EndpointPort: awssdk.Int32(8080)}, // Same port but different region + }, + ), + createSDKEndpointGroup( + "arn:aws:globalaccelerator::123456789012:endpointgroup/eg-sdk-2", + "ap-southeast-1", // Different region than second resource group + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(83), EndpointPort: awssdk.Int32(9090)}, // Same port but different region + }, + ), + }, + wantConflictCount: 0, // No conflicts because different regions + wantConflictPorts: []int32{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create synthesizer instance + s := &endpointGroupSynthesizer{ + logger: logr.Discard(), + } + + // Run the function + ctx := context.Background() + conflicts, err := s.detectConflictsWithSDKEndpointGroups(ctx, tt.resEndpointGroups, tt.sdkEndpointGroups) + + // Check errors + assert.NoError(t, err) + + // Verify the number of conflicts + assert.Equal(t, tt.wantConflictCount, len(conflicts), "conflict count should match") + + // Collect conflict ports for checking + var conflictPorts []int32 + for port := range conflicts { + conflictPorts = append(conflictPorts, port) + } + sort.Slice(conflictPorts, func(i, j int) bool { + return conflictPorts[i] < conflictPorts[j] + }) + + // Sort expected ports for comparison + expectedPorts := make([]int32, len(tt.wantConflictPorts)) + copy(expectedPorts, tt.wantConflictPorts) + sort.Slice(expectedPorts, func(i, j int) bool { + return expectedPorts[i] < expectedPorts[j] + }) + + // Check conflict ports + assert.ElementsMatch(t, expectedPorts, conflictPorts, "conflicting ports should match") + }) + } +} + +func Test_endpointGroupSynthesizer_resolveConflictsWithSDKEndpointGroups(t *testing.T) { + // Create a mock GlobalAccelerator service + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + mockGA := services.NewMockGlobalAccelerator(mockCtrl) + + testARN1 := "arn:aws:globalaccelerator::123456789012:endpointgroup/eg-1" + testARN2 := "arn:aws:globalaccelerator::123456789012:endpointgroup/eg-2" + + tests := []struct { + name string + conflicts map[int32][]string + sdkEndpointGroups []agatypes.EndpointGroup + mockSetup func(mockGA *services.MockGlobalAccelerator) + expectError bool + }{ + { + name: "no conflicts", + conflicts: map[int32][]string{}, + sdkEndpointGroups: []agatypes.EndpointGroup{}, + mockSetup: func(mockGA *services.MockGlobalAccelerator) {}, + expectError: false, + }, + { + name: "single conflict with one group", + conflicts: map[int32][]string{ + 8080: {testARN1}, + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + createSDKEndpointGroup( + testARN1, + "us-east-1", + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(80), EndpointPort: awssdk.Int32(8080)}, // Conflicting + {ListenerPort: awssdk.Int32(443), EndpointPort: awssdk.Int32(8443)}, // Not conflicting + }, + ), + }, + mockSetup: func(mockGA *services.MockGlobalAccelerator) { + // Expect an update call with only the non-conflicting port override + mockGA.EXPECT().UpdateEndpointGroupWithContext(gomock.Any(), &globalaccelerator.UpdateEndpointGroupInput{ + EndpointGroupArn: awssdk.String(testARN1), + PortOverrides: []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(443), EndpointPort: awssdk.Int32(8443)}, + }, + }).Return(&globalaccelerator.UpdateEndpointGroupOutput{}, nil) + }, + expectError: false, + }, + { + name: "multiple conflicts with different groups", + conflicts: map[int32][]string{ + 8080: {testARN1}, + 9090: {testARN2}, + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + createSDKEndpointGroup( + testARN1, + "us-east-1", + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(80), EndpointPort: awssdk.Int32(8080)}, // Conflicting + {ListenerPort: awssdk.Int32(443), EndpointPort: awssdk.Int32(8443)}, // Not conflicting + }, + ), + createSDKEndpointGroup( + testARN2, + "eu-west-1", + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(81), EndpointPort: awssdk.Int32(9090)}, // Conflicting + {ListenerPort: awssdk.Int32(444), EndpointPort: awssdk.Int32(9443)}, // Not conflicting + }, + ), + }, + mockSetup: func(mockGA *services.MockGlobalAccelerator) { + // Expect updates for both groups + mockGA.EXPECT().UpdateEndpointGroupWithContext(gomock.Any(), &globalaccelerator.UpdateEndpointGroupInput{ + EndpointGroupArn: awssdk.String(testARN1), + PortOverrides: []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(443), EndpointPort: awssdk.Int32(8443)}, + }, + }).Return(&globalaccelerator.UpdateEndpointGroupOutput{}, nil) + + mockGA.EXPECT().UpdateEndpointGroupWithContext(gomock.Any(), &globalaccelerator.UpdateEndpointGroupInput{ + EndpointGroupArn: awssdk.String(testARN2), + PortOverrides: []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(444), EndpointPort: awssdk.Int32(9443)}, + }, + }).Return(&globalaccelerator.UpdateEndpointGroupOutput{}, nil) + }, + expectError: false, + }, + { + name: "error during update", + conflicts: map[int32][]string{ + 8080: {testARN1}, + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + createSDKEndpointGroup( + testARN1, + "us-east-1", + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(80), EndpointPort: awssdk.Int32(8080)}, // Conflicting + }, + ), + }, + mockSetup: func(mockGA *services.MockGlobalAccelerator) { + // Simulate an error during update + mockGA.EXPECT().UpdateEndpointGroupWithContext(gomock.Any(), gomock.Any()). + Return(nil, errors.New("update failed")) + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up the mock expectations + tt.mockSetup(mockGA) + + // Create synthesizer instance + s := &endpointGroupSynthesizer{ + gaService: mockGA, + logger: logr.Discard(), + } + + // Run the function + ctx := context.Background() + err := s.resolveConflictsWithSDKEndpointGroups(ctx, tt.conflicts, tt.sdkEndpointGroups) + + // Check errors + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +//func Test_endpointGroupSynthesizer_getAllEndpointGroupsInListeners(t *testing.T) { +// // Create a mock GlobalAccelerator service +// mockCtrl := gomock.NewController(t) +// defer mockCtrl.Finish() +// mockGA := services.NewMockGlobalAccelerator(mockCtrl) +// listenerARN1 := "arn:aws:globalaccelerator::123456789012:listener/acc-1/listener-1" +// listenerARN2 := "arn:aws:globalaccelerator::123456789012:listener/acc-1/listener-2" +// listenerARN3 := "arn:aws:globalaccelerator::123456789012:listener/acc-2/listener-3" +// +// endpointGroups1 := []agatypes.EndpointGroup{ +// { +// EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/eg-1"), +// EndpointGroupRegion: awssdk.String("us-east-1"), +// }, +// } +// +// endpointGroups2 := []agatypes.EndpointGroup{ +// { +// EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/eg-2"), +// EndpointGroupRegion: awssdk.String("us-west-2"), +// }, +// } +// +// endpointGroups3 := []agatypes.EndpointGroup{ +// { +// EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/eg-3"), +// EndpointGroupRegion: awssdk.String("eu-west-1"), +// }, +// { +// EndpointGroupArn: awssdk.String("arn:aws:globalaccelerator::123456789012:endpointgroup/eg-4"), +// EndpointGroupRegion: awssdk.String("eu-central-1"), +// }, +// } +// +// mockGA.EXPECT().ListEndpointGroupsAsList(gomock.Any(), &globalaccelerator.ListEndpointGroupsInput{ +// ListenerArn: awssdk.String(listenerARN1), +// }).Return(endpointGroups1, nil) +// +// mockGA.EXPECT().ListEndpointGroupsAsList(gomock.Any(), &globalaccelerator.ListEndpointGroupsInput{ +// ListenerArn: awssdk.String(listenerARN2), +// }).Return(endpointGroups2, nil) +// +// mockGA.EXPECT().ListEndpointGroupsAsList(gomock.Any(), &globalaccelerator.ListEndpointGroupsInput{ +// ListenerArn: awssdk.String(listenerARN3), +// }).Return(endpointGroups3, nil) +// +// // Create synthesizer instance +// s := &endpointGroupSynthesizer{ +// gaService: mockGA, +// logger: logr.Discard(), +// } +// +// // Run the function +// ctx := context.Background() +// allEndpointGroups, err := s.getAllEndpointGroupsInListeners(ctx, nil) +// +// // Check errors +// assert.NoError(t, err) +// +// // Verify that all endpoint groups are returned +// expectedCount := len(endpointGroups1) + len(endpointGroups2) + len(endpointGroups3) +// assert.Equal(t, expectedCount, len(allEndpointGroups), "should return all endpoint groups") +// +// // Check that groups from all listeners are included +// endpointGroupARNs := make(map[string]bool) +// endpointGroupRegions := make(map[string]bool) +// +// // Collect all ARNs and regions +// for _, eg := range allEndpointGroups { +// endpointGroupARNs[awssdk.ToString(eg.EndpointGroupArn)] = true +// endpointGroupRegions[awssdk.ToString(eg.EndpointGroupRegion)] = true +// } +// +// // Check that expected endpoint groups are included +// assert.Contains(t, endpointGroupARNs, awssdk.ToString(endpointGroups1[0].EndpointGroupArn), "should contain endpoint group 1") +// assert.Contains(t, endpointGroupARNs, awssdk.ToString(endpointGroups2[0].EndpointGroupArn), "should contain endpoint group 2") +// assert.Contains(t, endpointGroupARNs, awssdk.ToString(endpointGroups3[0].EndpointGroupArn), "should contain endpoint group 3") +// assert.Contains(t, endpointGroupARNs, awssdk.ToString(endpointGroups3[1].EndpointGroupArn), "should contain endpoint group 4") +// +// // Check that expected regions are included +// assert.Contains(t, endpointGroupRegions, "us-east-1", "should contain us-east-1 region") +// assert.Contains(t, endpointGroupRegions, "us-west-2", "should contain us-west-2 region") +// assert.Contains(t, endpointGroupRegions, "eu-west-1", "should contain eu-west-1 region") +// assert.Contains(t, endpointGroupRegions, "eu-central-1", "should contain eu-central-1 region") +//} + +// Test_endpointGroupSynthesizer_detectConflictsWithSDKEndpointGroups_OwnListener tests that the +// detectConflictsWithSDKEndpointGroups function correctly ignores conflicts with our own listeners +func Test_endpointGroupSynthesizer_detectConflictsWithSDKEndpointGroups_OwnListener(t *testing.T) { + // Define test listeners and endpoint groups + testListener1ARN := "arn:aws:globalaccelerator::123456789012:accelerator/1234abcd/listener/l-1" + testListener2ARN := "arn:aws:globalaccelerator::123456789012:accelerator/5678efgh/listener/l-2" + + // ARNs for SDK endpoint groups - match the structure required by the extractor + sdkGroup1ARN := "arn:aws:globalaccelerator::123456789012:accelerator/1234abcd/listener/l-1/endpoint-group/eg-1" + sdkGroup2ARN := "arn:aws:globalaccelerator::123456789012:accelerator/5678efgh/listener/l-2/endpoint-group/eg-2" + sdkGroup3ARN := "arn:aws:globalaccelerator::123456789012:accelerator/9012ijkl/listener/l-3/endpoint-group/eg-3" + + tests := []struct { + name string + resEndpointGroups []*agamodel.EndpointGroup + sdkEndpointGroups []agatypes.EndpointGroup + wantConflictCount int + wantConflictPorts []int32 + }{ + { + name: "no conflict with own listener - same endpoint port in same region", + resEndpointGroups: []*agamodel.EndpointGroup{ + createEndpointGroupWithPortOverrides("eg-1", "us-west-2", testListener1ARN, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + }), + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + // Same listener ARN pattern as resource group, should be ignored + createSDKEndpointGroup( + sdkGroup1ARN, + "us-west-2", // Same region + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(81), EndpointPort: awssdk.Int32(8080)}, // Same endpoint port + }, + ), + }, + wantConflictCount: 0, // No conflict detected with our own listener + wantConflictPorts: []int32{}, + }, + { + name: "conflict with external listener - same endpoint port in same region", + resEndpointGroups: []*agamodel.EndpointGroup{ + createEndpointGroupWithPortOverrides("eg-1", "us-west-2", testListener1ARN, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + }), + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + // Different listener ARN than resource group, should detect conflict + createSDKEndpointGroup( + sdkGroup2ARN, + "us-west-2", // Same region + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(81), EndpointPort: awssdk.Int32(8080)}, // Same endpoint port + }, + ), + }, + wantConflictCount: 1, + wantConflictPorts: []int32{8080}, + }, + { + name: "multiple listeners in different regions - no conflicts", + resEndpointGroups: []*agamodel.EndpointGroup{ + createEndpointGroupWithPortOverrides("eg-1", "us-west-2", testListener1ARN, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + }), + createEndpointGroupWithPortOverrides("eg-2", "eu-west-1", testListener2ARN, []agamodel.PortOverride{ + {ListenerPort: 443, EndpointPort: 8080}, + }), + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + // Same listener ARN but different region - no conflict + createSDKEndpointGroup( + sdkGroup1ARN, + "us-east-1", + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(81), EndpointPort: awssdk.Int32(8080)}, + }, + ), + // Different listener ARN but different region - no conflict + createSDKEndpointGroup( + sdkGroup3ARN, + "eu-central-1", + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(82), EndpointPort: awssdk.Int32(8080)}, + }, + ), + }, + wantConflictCount: 0, + wantConflictPorts: []int32{}, + }, + { + name: "complex scenario - multiple listeners, regions, with own and external conflicts", + resEndpointGroups: []*agamodel.EndpointGroup{ + // Resource group 1 + createEndpointGroupWithPortOverrides("eg-west-1", "us-west-2", testListener1ARN, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 8080}, + {ListenerPort: 443, EndpointPort: 8443}, + }), + // Resource group 2 - different region + createEndpointGroupWithPortOverrides("eg-eu-1", "eu-west-1", testListener2ARN, []agamodel.PortOverride{ + {ListenerPort: 80, EndpointPort: 9090}, + }), + }, + sdkEndpointGroups: []agatypes.EndpointGroup{ + // Same listener as eg-west-1, should NOT be conflict + createSDKEndpointGroup( + sdkGroup1ARN, + "us-west-2", + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(81), EndpointPort: awssdk.Int32(8080)}, + {ListenerPort: awssdk.Int32(444), EndpointPort: awssdk.Int32(8443)}, + }, + ), + // Same listener as eg-eu-1, should NOT be conflict + createSDKEndpointGroup( + sdkGroup2ARN, + "eu-west-1", + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(82), EndpointPort: awssdk.Int32(9090)}, + }, + ), + // Different listener in us-west-2, SHOULD be conflict with eg-west-1 + createSDKEndpointGroup( + sdkGroup3ARN, + "us-west-2", + []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(83), EndpointPort: awssdk.Int32(8080)}, + }, + ), + }, + wantConflictCount: 1, // Only one conflict (port 8080 in us-west-2 with external listener) + wantConflictPorts: []int32{8080}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create synthesizer instance + s := &endpointGroupSynthesizer{ + logger: logr.Discard(), + } + + // Run the function + ctx := context.Background() + conflicts, err := s.detectConflictsWithSDKEndpointGroups(ctx, tt.resEndpointGroups, tt.sdkEndpointGroups) + + // Check errors + assert.NoError(t, err) + + // Verify the number of conflicts + assert.Equal(t, tt.wantConflictCount, len(conflicts), "conflict count should match") + + // Collect conflict ports for checking + var conflictPorts []int32 + for port := range conflicts { + conflictPorts = append(conflictPorts, port) + } + sort.Slice(conflictPorts, func(i, j int) bool { + return conflictPorts[i] < conflictPorts[j] + }) + + // Sort expected ports for comparison + expectedPorts := make([]int32, len(tt.wantConflictPorts)) + copy(expectedPorts, tt.wantConflictPorts) + sort.Slice(expectedPorts, func(i, j int) bool { + return expectedPorts[i] < expectedPorts[j] + }) + + // Check conflict ports + assert.ElementsMatch(t, expectedPorts, conflictPorts, "conflicting ports should match") + + // If conflicts were detected, verify they're with the right endpoint groups + if len(conflicts) > 0 { + for port, sdkGroupARNs := range conflicts { + for _, sdkGroupARN := range sdkGroupARNs { + // For each test case, only check for listener ARNs that are actually in our resource groups + // Get the list of listener ARNs used in this test's resource groups + // Create a map of all listener ARNs used by our endpoint groups + // In our test setup, we know exactly which endpoint groups use which listener ARNs + // based on how we create them in the test cases + usedListenerARNs := map[string]bool{ + testListener1ARN: false, + testListener2ARN: false, + } + + // For our test cases, we directly know which ARNs are used for each test + switch tt.name { + case "no conflict with own listener - same endpoint port in same region": + usedListenerARNs[testListener1ARN] = true + case "conflict with external listener - same endpoint port in same region": + usedListenerARNs[testListener1ARN] = true + case "multiple listeners with same port - one own, one external": + usedListenerARNs[testListener1ARN] = true + case "multiple listeners in different regions - no conflicts": + usedListenerARNs[testListener1ARN] = true + usedListenerARNs[testListener2ARN] = true + case "complex scenario - multiple listeners, regions, with own and external conflicts": + usedListenerARNs[testListener1ARN] = true + usedListenerARNs[testListener2ARN] = true + } + + // Only check if the conflict is NOT with listeners we're using + if usedListenerARNs[testListener1ARN] { + assert.NotContains(t, sdkGroupARN, testListener1ARN, + "conflict should not be with our own listener (testListener1ARN) for port %d", port) + } + + if usedListenerARNs[testListener2ARN] { + assert.NotContains(t, sdkGroupARN, testListener2ARN, + "conflict should not be with our own listener (testListener2ARN) for port %d", port) + } + } + } + } + }) + } +} diff --git a/pkg/deploy/aga/errors.go b/pkg/deploy/aga/errors.go index bb8bb9e5c..2529133bf 100644 --- a/pkg/deploy/aga/errors.go +++ b/pkg/deploy/aga/errors.go @@ -9,6 +9,9 @@ const ( // DeploymentFailed is the error code when stack deployment fails DeploymentFailed = "DeploymentFailed" + + // Status reason constants + EndpointLoadFailed = "EndpointLoadFailed" ) // AcceleratorNotDisabledError is returned when an accelerator is not ready for deletion diff --git a/pkg/deploy/aga/listener_manager.go b/pkg/deploy/aga/listener_manager.go index b72d676af..f149f7868 100644 --- a/pkg/deploy/aga/listener_manager.go +++ b/pkg/deploy/aga/listener_manager.go @@ -23,13 +23,17 @@ type ListenerManager interface { // Delete deletes a listener. Delete(ctx context.Context, listenerARN string) error + + // ListEndpointGroups lists all endpoint groups for a given listener + ListEndpointGroups(ctx context.Context, listenerARN string) ([]agatypes.EndpointGroup, error) } // NewDefaultListenerManager constructs new defaultListenerManager. -func NewDefaultListenerManager(gaService services.GlobalAccelerator, logger logr.Logger) *defaultListenerManager { +func NewDefaultListenerManager(gaService services.GlobalAccelerator, endpointGroupManager EndpointGroupManager, logger logr.Logger) *defaultListenerManager { return &defaultListenerManager{ - gaService: gaService, - logger: logger, + gaService: gaService, + endpointGroupManager: endpointGroupManager, + logger: logger, } } @@ -37,8 +41,9 @@ var _ ListenerManager = &defaultListenerManager{} // defaultListenerManager is the default implementation for ListenerManager. type defaultListenerManager struct { - gaService services.GlobalAccelerator - logger logr.Logger + gaService services.GlobalAccelerator + endpointGroupManager EndpointGroupManager + logger logr.Logger } // convertPortRangesToSDK converts model port ranges to SDK port ranges @@ -164,11 +169,30 @@ func (m *defaultListenerManager) Update(ctx context.Context, resListener *agamod } func (m *defaultListenerManager) Delete(ctx context.Context, listenerARN string) error { - // TODO: This will be enhanced to check for and delete endpoint groups - // before deleting the listener (when those features are implemented) - m.logger.Info("Deleting listener", "listenerARN", listenerARN) + // Step 1: Delete all endpoint groups associated with this listener + endpointGroups, err := m.ListEndpointGroups(ctx, listenerARN) + if err != nil { + var apiErr *agatypes.ListenerNotFoundException + if errors.As(err, &apiErr) { + m.logger.Info("Listener not found, assuming already deleted", "listenerARN", listenerARN) + return nil + } + return fmt.Errorf("failed to list endpoint groups for listener: %w", err) + } + + for _, endpointGroup := range endpointGroups { + endpointGroupARN := aws.ToString(endpointGroup.EndpointGroupArn) + m.logger.Info("Deleting endpoint group for listener", "endpointGroupARN", endpointGroupARN, "listenerARN", listenerARN) + + if err := m.endpointGroupManager.Delete(ctx, endpointGroupARN); err != nil { + return fmt.Errorf("failed to delete endpoint group %s: %w", endpointGroupARN, err) + } + m.logger.Info("Deleted endpoint group for listener", "endpointGroupARN", endpointGroupARN, "listenerARN", listenerARN) + } + + // Step 2: Delete the listener deleteInput := &globalaccelerator.DeleteListenerInput{ ListenerArn: aws.String(listenerARN), } @@ -187,6 +211,15 @@ func (m *defaultListenerManager) Delete(ctx context.Context, listenerARN string) return nil } +// ListEndpointGroups lists all endpoint groups for a given listener +func (m *defaultListenerManager) ListEndpointGroups(ctx context.Context, listenerARN string) ([]agatypes.EndpointGroup, error) { + listInput := &globalaccelerator.ListEndpointGroupsInput{ + ListenerArn: aws.String(listenerARN), + } + + return m.gaService.ListEndpointGroupsAsList(ctx, listInput) +} + // isSDKListenerSettingsDrifted checks if the listener configuration has drifted from the desired state func (m *defaultListenerManager) isSDKListenerSettingsDrifted(resListener *agamodel.Listener, sdkListener *ListenerResource) bool { // Check if protocol differs diff --git a/pkg/deploy/aga/listener_manager_mocks.go b/pkg/deploy/aga/listener_manager_mocks.go index b6c1d60f6..93699ee3c 100644 --- a/pkg/deploy/aga/listener_manager_mocks.go +++ b/pkg/deploy/aga/listener_manager_mocks.go @@ -8,6 +8,7 @@ import ( context "context" reflect "reflect" + types "github.com/aws/aws-sdk-go-v2/service/globalaccelerator/types" gomock "github.com/golang/mock/gomock" aga0 "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" ) @@ -64,6 +65,21 @@ func (mr *MockListenerManagerMockRecorder) Delete(arg0, arg1 interface{}) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockListenerManager)(nil).Delete), arg0, arg1) } +// ListEndpointGroups mocks base method. +func (m *MockListenerManager) ListEndpointGroups(arg0 context.Context, arg1 string) ([]types.EndpointGroup, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListEndpointGroups", arg0, arg1) + ret0, _ := ret[0].([]types.EndpointGroup) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListEndpointGroups indicates an expected call of ListEndpointGroups. +func (mr *MockListenerManagerMockRecorder) ListEndpointGroups(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListEndpointGroups", reflect.TypeOf((*MockListenerManager)(nil).ListEndpointGroups), arg0, arg1) +} + // Update mocks base method. func (m *MockListenerManager) Update(arg0 context.Context, arg1 *aga0.Listener, arg2 *ListenerResource) (aga0.ListenerStatus, error) { m.ctrl.T.Helper() diff --git a/pkg/deploy/aga/listener_manager_test.go b/pkg/deploy/aga/listener_manager_test.go index a2c56ab17..ca6ea29f7 100644 --- a/pkg/deploy/aga/listener_manager_test.go +++ b/pkg/deploy/aga/listener_manager_test.go @@ -8,7 +8,9 @@ import ( "github.com/aws/aws-sdk-go-v2/service/globalaccelerator" agatypes "github.com/aws/aws-sdk-go-v2/service/globalaccelerator/types" "github.com/go-logr/logr" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" ) @@ -316,9 +318,22 @@ func Test_defaultListenerManager_isSDKListenerSettingsDrifted(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // Create listener manager + // Create mock controller + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + // Create mock GlobalAccelerator service + mockGAService := services.NewMockGlobalAccelerator(ctrl) + + // Set up the mock behavior + mockGAService.EXPECT(). + ListEndpointGroupsAsList(gomock.Any(), gomock.Any()). + Return([]agatypes.EndpointGroup{}, nil). + AnyTimes() + + // Create manager with mock service m := &defaultListenerManager{ - gaService: nil, // Not needed for this test + gaService: mockGAService, logger: logr.Discard(), } diff --git a/pkg/deploy/aga/listener_synthesizer.go b/pkg/deploy/aga/listener_synthesizer.go index 7e003d140..be4e0f8cb 100644 --- a/pkg/deploy/aga/listener_synthesizer.go +++ b/pkg/deploy/aga/listener_synthesizer.go @@ -9,6 +9,7 @@ import ( "github.com/go-logr/logr" "github.com/pkg/errors" "k8s.io/apimachinery/pkg/util/sets" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aga" "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" @@ -97,55 +98,15 @@ func (s *listenerSynthesizer) synthesizeListenersOnAccelerator(ctx context.Conte // Improved operation order to minimize traffic disruption: // 1. Delete only conflicting listeners (that would block updates) - // 2. Update matched listeners - // 3. Delete unneeded (non-conflicting) listeners - // 4. Create new listeners + // 2. Process all port overrides which may block listener updates + // - Remove endpoint port overrides that overlap with any new listener port ranges + // - Remove listener port overrides from existing listener which is outside desired listener port ranges + // 3. Update matched listeners + // 4. Delete unneeded (non-conflicting) listeners + // 5. Create new listeners // STEP 1: Find SDK listeners that have port conflicts with planned updates - var conflictingListeners []*ListenerResource - var nonConflictingListeners []*ListenerResource - - // Track which listeners have port conflicts with our updates - conflictMap := make(map[string][]*ListenerResource) - - // For each update we're planning to do... - for _, pair := range matchedResAndSDKListeners { - var conflicts []*ListenerResource - - // Check against all unmatched SDK listeners for conflicts - for _, sdkListener := range unmatchedSDKListeners { - if s.hasPortRangeConflict(pair.resListener, sdkListener) { - conflicts = append(conflicts, sdkListener) - } - } - - // If there are conflicts, add them to our conflict map - if len(conflicts) > 0 { - conflictMap[pair.resListener.ID()] = conflicts - } - } - - // Build list of conflicting and non-conflicting listeners - listenerIsConflicting := make(map[string]bool) - - // Add all listeners with port conflicts to the conflicting list - for _, conflicts := range conflictMap { - for _, listener := range conflicts { - arn := *listener.Listener.ListenerArn - if !listenerIsConflicting[arn] { - conflictingListeners = append(conflictingListeners, listener) - listenerIsConflicting[arn] = true - } - } - } - - // Sort remaining unmatched listeners into non-conflicting - for _, sdkListener := range unmatchedSDKListeners { - arn := *sdkListener.Listener.ListenerArn - if !listenerIsConflicting[arn] { - nonConflictingListeners = append(nonConflictingListeners, sdkListener) - } - } + conflictingListeners, nonConflictingListeners := s.findConflictingAndNonConflictingListeners(matchedResAndSDKListeners, unmatchedSDKListeners) // STEP 2: Execute operations in correct order @@ -164,6 +125,15 @@ func (s *listenerSynthesizer) synthesizeListenersOnAccelerator(ctx context.Conte } } + // Next, Process all port overrides BEFORE updating listeners + allResListenerPortRanges, allSDKListenersToProcess, updatePortRangesByListener := s.preparePortOverrideProcessing(resListeners, matchedResAndSDKListeners, nonConflictingListeners) + + // Consolidated port override processing + if err := s.ProcessEndpointGroupPortOverrides(ctx, allSDKListenersToProcess, allResListenerPortRanges, updatePortRangesByListener); err != nil { + s.logger.Error(err, "Failed to process endpoint group port overrides") + return err + } + // Next, update existing matched listeners (now conflict-free) for _, pair := range matchedResAndSDKListeners { s.logger.Info("Updating existing listener", @@ -250,9 +220,9 @@ type resAndSDKListenerPair struct { // matchResAndSDKListeners matches resource listeners with SDK listeners using a multi-phase approach. // // The algorithm implements a two-phase matching process: -// 1. First phase (Exact Matching): Matches listeners with identical protocol and port ranges -// 2. Second phase (Similarity Matching): For remaining unmatched listeners, uses a similarity-based -// algorithm to find the best matches based on protocol and port range overlap +// 1. First phase (Exact Matching): Matches listeners with identical protocol and port ranges +// 2. Second phase (Similarity Matching): For remaining unmatched listeners, uses a similarity-based +// algorithm to find the best matches based on protocol and port range overlap // // Returns three groups: // - matchedResAndSDKListeners: pairs of resource and SDK listeners that will be updated @@ -288,7 +258,7 @@ func (s *listenerSynthesizer) matchResAndSDKListeners(resListeners []*agamodel.L // 3. Matches listeners with identical keys (exact protocol and port range matches) // 4. Returns matched pairs and remaining unmatched listeners // -// The key generation ensures that port ranges in different order but with identical +// The key generation ensures that port ranges in different order but with identical // values still match correctly. func (s *listenerSynthesizer) findExactMatches(resListeners []*agamodel.Listener, sdkListeners []*ListenerResource) ( []resAndSDKListenerPair, []*agamodel.Listener, []*ListenerResource) { @@ -468,22 +438,78 @@ func (s *listenerSynthesizer) findSimilarityMatches(resListeners []*agamodel.Lis return matchedPairs, unmatchedResListeners, unmatchedSDKListeners } +// findConflictingAndNonConflictingListeners separates unmatched SDK listeners into those that have +// port conflicts with planned updates and those that don't +func (s *listenerSynthesizer) findConflictingAndNonConflictingListeners( + matchedResAndSDKListeners []resAndSDKListenerPair, + unmatchedSDKListeners []*ListenerResource) ([]*ListenerResource, []*ListenerResource) { + + var conflictingListeners []*ListenerResource + var nonConflictingListeners []*ListenerResource + + // Track which listeners have port conflicts with our updates + conflictMap := make(map[string][]*ListenerResource) + + // For each update we're planning to do... + for _, pair := range matchedResAndSDKListeners { + var conflicts []*ListenerResource + + // Check against all unmatched SDK listeners for conflicts + for _, sdkListener := range unmatchedSDKListeners { + if s.hasPortRangeConflict(pair.resListener, sdkListener) { + conflicts = append(conflicts, sdkListener) + } + } + + // If there are conflicts, add them to our conflict map + if len(conflicts) > 0 { + conflictMap[pair.resListener.ID()] = conflicts + } + } + + // Build list of conflicting and non-conflicting listeners + listenerIsConflicting := make(map[string]bool) + + // Add all listeners with port conflicts to the conflicting list + for _, conflicts := range conflictMap { + for _, listener := range conflicts { + arn := *listener.Listener.ListenerArn + if !listenerIsConflicting[arn] { + conflictingListeners = append(conflictingListeners, listener) + listenerIsConflicting[arn] = true + } + } + } + + // Sort remaining unmatched listeners into non-conflicting + for _, sdkListener := range unmatchedSDKListeners { + arn := *sdkListener.Listener.ListenerArn + if !listenerIsConflicting[arn] { + nonConflictingListeners = append(nonConflictingListeners, sdkListener) + } + } + + return conflictingListeners, nonConflictingListeners +} + +// calculateSimilarityScore calculates how similar two listeners are +// Higher scores indicate better matches // calculateSimilarityScore calculates how similar two listeners are based on their attributes. // // The scoring system uses these components: // // 1. Base Protocol Score: -// - If protocols match: +40 points (significant bonus) -// - If protocols don't match: 0 points (no bonus) +// - If protocols match: +40 points (significant bonus) +// - If protocols don't match: 0 points (no bonus) // // 2. Port Overlap Score: -// - Uses Jaccard similarity: (intersection / union) * 100 -// - Calculates the percentage of common ports between the two listeners -// - Converts port ranges into individual port sets for precise comparison +// - Uses Jaccard similarity: (intersection / union) * 100 +// - Calculates the percentage of common ports between the two listeners +// - Converts port ranges into individual port sets for precise comparison // // 3. Client Affinity Score: -// - If both listeners have client affinity specified and they match: +10 points -// - Otherwise: 0 points (no bonus) +// - If both listeners have client affinity specified and they match: +10 points +// - Otherwise: 0 points (no bonus) // // Note: In the future, we might need to add endpoint matching as well as one of the // score components so that we match the listeners with the most endpoint matches @@ -597,3 +623,250 @@ func (s *listenerSynthesizer) hasPortRangeConflict(resListener *agamodel.Listene func (s *listenerSynthesizer) portRangesToString(portRanges []agamodel.PortRange) string { return ResPortRangesToString(portRanges) } + +// havePortRangesChanged checks if port ranges have changed between resource and SDK listener +func (s *listenerSynthesizer) havePortRangesChanged(resListener *agamodel.Listener, sdkListener *ListenerResource) bool { + if len(resListener.Spec.PortRanges) != len(sdkListener.Listener.PortRanges) { + return true + } + + // Build maps for easy comparison + resPortSet := s.makeResPortSet(resListener.Spec.PortRanges) + sdkPortSet := s.makeSDKPortSet(sdkListener.Listener.PortRanges) + + // If port sets have different sizes, they've changed + if len(resPortSet) != len(sdkPortSet) { + return true + } + + // Check if any port exists in one set but not the other + for port := range resPortSet { + if !sdkPortSet[port] { + return true + } + } + + for port := range sdkPortSet { + if !resPortSet[port] { + return true + } + } + + // Port ranges are the same + return false +} + +// ProcessEndpointGroupPortOverrides handles all port override validations and updates using a two-phase approach +// Phase 1: Collect all endpoint groups and analyze all port overrides for conflicts +// Phase 2: Execute updates for all identified conflicts +// +// The two-phase approach ensures consistent behavior regardless of processing order, since all +// analysis is completed before any modifications are made. +// +// It handles these validations: +// - Remove port overrides with endpoint port that overlap with any desired listener port ranges +// - Remove port overrides with listener port outside desired listener port ranges +func (s *listenerSynthesizer) ProcessEndpointGroupPortOverrides( + ctx context.Context, + listeners []*ListenerResource, + allListenerPortRanges []agamodel.PortRange, + updatePortRangesByListener map[string][]agamodel.PortRange) error { + + s.logger.V(1).Info("Processing all endpoint port overrides before updating listeners") + + // PHASE 1: Collection and Analysis + // Map of endpoint group ARN to its conflict information + type endpointGroupConflicts struct { + endpointGroup agatypes.EndpointGroup + listenerARN string + validPortOverrides []agatypes.PortOverride + invalidPortOverrides []agatypes.PortOverride + } + + // Store all conflicts to be resolved + conflictsByEndpointGroupARN := make(map[string]*endpointGroupConflicts) + + // Process each listener to collect all endpoint groups and analyze port overrides + for _, listener := range listeners { + listenerARN := awssdk.ToString(listener.Listener.ListenerArn) + + // List endpoint groups per listener + endpointGroups, err := s.listenerManager.ListEndpointGroups(ctx, listenerARN) + if err != nil { + return fmt.Errorf("failed to list endpoint groups for listener %s: %w", listenerARN, err) + } + + // Skip if no endpoint groups + if len(endpointGroups) == 0 { + continue + } + + // Get the updated port ranges for this listener if it's being updated + updatedPortRanges := updatePortRangesByListener[listenerARN] + + // Analyze each endpoint group's port overrides + for _, eg := range endpointGroups { + endpointGroupARN := awssdk.ToString(eg.EndpointGroupArn) + + // Skip if no port overrides to check + if eg.PortOverrides == nil || len(eg.PortOverrides) == 0 { + continue + } + + s.logger.V(1).Info("Analyzing endpoint group port overrides for conflicts", + "listenerARN", listenerARN, + "endpointGroupARN", endpointGroupARN) + + // Apply all validation rules and collect valid/invalid port overrides + validPortOverrides, invalidPortOverrides := s.processPortOverridesWithAllRules( + eg.PortOverrides, + allListenerPortRanges, + updatedPortRanges) + + // Only store conflicts if we found invalid overrides + if len(invalidPortOverrides) > 0 { + conflictsByEndpointGroupARN[endpointGroupARN] = &endpointGroupConflicts{ + endpointGroup: eg, + listenerARN: listenerARN, + validPortOverrides: validPortOverrides, + invalidPortOverrides: invalidPortOverrides, + } + } + } + } + + // PHASE 2: Execution - Update all endpoint groups with conflicts + // Process all conflicts + for endpointGroupARN := range conflictsByEndpointGroupARN { + conflictInfo := conflictsByEndpointGroupARN[endpointGroupARN] + + s.logger.V(1).Info("Updating endpoint group to remove conflicting port overrides", + "endpointGroupARN", endpointGroupARN, + "listenerARN", conflictInfo.listenerARN, + "conflictCount", len(conflictInfo.invalidPortOverrides)) + + // Update this endpoint group to remove the invalid port overrides + if err := s.updateEndpointGroupPortOverrides( + ctx, + conflictInfo.endpointGroup, + conflictInfo.validPortOverrides, + conflictInfo.invalidPortOverrides); err != nil { + return fmt.Errorf("failed to update endpoint group %s to remove conflicts: %w", + endpointGroupARN, err) + } + } + + return nil +} + +// processPortOverridesWithAllRules applies all validation rules to port overrides: +// 1. Endpoint ports must not overlap with any listener port ranges (if listener is being updated) +// 2. Listener ports must be within listener port ranges (if listener is being updated) +func (s *listenerSynthesizer) processPortOverridesWithAllRules( + portOverrides []agatypes.PortOverride, + allListenerPortRanges []agamodel.PortRange, + updatedListenerPortRanges []agamodel.PortRange) ([]agatypes.PortOverride, []agatypes.PortOverride) { + + validPortOverrides := make([]agatypes.PortOverride, 0) + invalidPortOverrides := make([]agatypes.PortOverride, 0) + for _, po := range portOverrides { + isValid := true + + // Rule 1: Endpoint port must not overlap with ANY listener port range + if aga.IsPortInRanges(awssdk.ToInt32(po.EndpointPort), allListenerPortRanges) { + isValid = false + s.logger.V(1).Info("Found port override with endpoint port that overlaps with a listener port range", + "endpointPort", awssdk.ToInt32(po.EndpointPort), + "listenerPort", awssdk.ToInt32(po.ListenerPort)) + } + + // Rule 2: If listener is being updated, listener port must be within updated port ranges + if isValid && len(updatedListenerPortRanges) > 0 && !aga.IsPortInRanges(awssdk.ToInt32(po.ListenerPort), updatedListenerPortRanges) { + isValid = false + s.logger.V(1).Info("Found port override with listener port outside updated listener port range", + "listenerPort", awssdk.ToInt32(po.ListenerPort), + "endpointPort", awssdk.ToInt32(po.EndpointPort)) + } + + // Add to appropriate collection based on validation result + if isValid { + validPortOverrides = append(validPortOverrides, po) + } else { + invalidPortOverrides = append(invalidPortOverrides, po) + } + } + + return validPortOverrides, invalidPortOverrides +} + +// preparePortOverrideProcessing collects all the port override processing requirements: +// - all res listener port ranges +// - all SDK listeners to process +// - map of listeners being updated with their new port ranges +func (s *listenerSynthesizer) preparePortOverrideProcessing( + resListeners []*agamodel.Listener, + matchedResAndSDKListeners []resAndSDKListenerPair, + nonConflictingListeners []*ListenerResource) ([]agamodel.PortRange, []*ListenerResource, map[string][]agamodel.PortRange) { + + // Collect all port ranges from resource listeners + var allResListenerPortRanges []agamodel.PortRange + for _, resListener := range resListeners { + allResListenerPortRanges = append(allResListenerPortRanges, resListener.Spec.PortRanges...) + } + + // Extract the SDK listeners from matchedResAndSDKListeners + var allSDKListenersToProcess []*ListenerResource + for _, pair := range matchedResAndSDKListeners { + allSDKListenersToProcess = append(allSDKListenersToProcess, pair.sdkListener) + } + + // Combine with nonConflictingListeners + allSDKListenersToProcess = append(allSDKListenersToProcess, nonConflictingListeners...) + + // Prepare map of listeners being updated with their new port ranges + updatePortRangesByListener := make(map[string][]agamodel.PortRange) + for _, pair := range matchedResAndSDKListeners { + if s.havePortRangesChanged(pair.resListener, pair.sdkListener) { + listenerARN := awssdk.ToString(pair.sdkListener.Listener.ListenerArn) + updatePortRangesByListener[listenerARN] = pair.resListener.Spec.PortRanges + } + } + + return allResListenerPortRanges, allSDKListenersToProcess, updatePortRangesByListener +} + +// updateEndpointGroupPortOverrides updates an endpoint group with valid port overrides +// and logs information about the removed invalid ones +func (s *listenerSynthesizer) updateEndpointGroupPortOverrides( + ctx context.Context, + endpointGroup agatypes.EndpointGroup, + validPortOverrides []agatypes.PortOverride, + invalidPortOverrides []agatypes.PortOverride) error { + + endpointGroupARN := awssdk.ToString(endpointGroup.EndpointGroupArn) + + // For logging purposes, record each removed override + for _, po := range invalidPortOverrides { + s.logger.V(1).Info("Removing port override", + "listenerPort", awssdk.ToInt32(po.ListenerPort), + "endpointPort", awssdk.ToInt32(po.EndpointPort), + "endpointGroupARN", endpointGroupARN) + } + + // Update the endpoint group with only valid port overrides + _, err := s.gaClient.UpdateEndpointGroupWithContext(ctx, &globalaccelerator.UpdateEndpointGroupInput{ + EndpointGroupArn: endpointGroup.EndpointGroupArn, + PortOverrides: validPortOverrides, + }) + + if err != nil { + return fmt.Errorf("failed to update endpoint group %s for port overrides to remove conflicts: %w", endpointGroupARN, err) + } + + s.logger.Info("Successfully updated endpoint group port overrides to remove conflicts", + "endpointGroupARN", endpointGroupARN, + "removedCount", len(invalidPortOverrides), + "remainingCount", len(validPortOverrides)) + + return nil +} diff --git a/pkg/deploy/aga/listener_synthesizer_test.go b/pkg/deploy/aga/listener_synthesizer_test.go index edd01feca..0358c87db 100644 --- a/pkg/deploy/aga/listener_synthesizer_test.go +++ b/pkg/deploy/aga/listener_synthesizer_test.go @@ -1,6 +1,11 @@ package aga import ( + "context" + "github.com/aws/aws-sdk-go-v2/service/globalaccelerator" + "github.com/golang/mock/gomock" + pkgaga "sigs.k8s.io/aws-load-balancer-controller/pkg/aga" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" "sort" "testing" @@ -1765,3 +1770,525 @@ func Test_listenerSynthesizer_generateSDKListenerKey(t *testing.T) { }) } } + +func Test_listenerSynthesizer_processPortOverridesWithAllRules(t *testing.T) { + tests := []struct { + name string + portOverrides []agatypes.PortOverride + allListenerPortRanges []agamodel.PortRange + updatedListenerPortRanges []agamodel.PortRange + wantValidCount int + wantInvalidCount int + wantInvalidPortsEndpoint []int32 + wantInvalidPortsListener []int32 + }{ + { + name: "empty port overrides", + portOverrides: []agatypes.PortOverride{}, + allListenerPortRanges: []agamodel.PortRange{{FromPort: 80, ToPort: 90}}, + updatedListenerPortRanges: []agamodel.PortRange{{FromPort: 80, ToPort: 85}}, + wantValidCount: 0, + wantInvalidCount: 0, + }, + { + name: "all port overrides valid", + portOverrides: []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(80), EndpointPort: awssdk.Int32(8080)}, + {ListenerPort: awssdk.Int32(81), EndpointPort: awssdk.Int32(8081)}, + }, + allListenerPortRanges: []agamodel.PortRange{{FromPort: 80, ToPort: 90}}, + updatedListenerPortRanges: []agamodel.PortRange{{FromPort: 80, ToPort: 85}}, + wantValidCount: 2, + wantInvalidCount: 0, + }, + { + name: "endpoint port overlaps with listener port range", + portOverrides: []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(80), EndpointPort: awssdk.Int32(85)}, // Invalid: endpoint port in listener range + {ListenerPort: awssdk.Int32(81), EndpointPort: awssdk.Int32(8081)}, // Valid + }, + allListenerPortRanges: []agamodel.PortRange{{FromPort: 80, ToPort: 90}}, + updatedListenerPortRanges: []agamodel.PortRange{{FromPort: 80, ToPort: 85}}, + wantValidCount: 1, + wantInvalidCount: 1, + wantInvalidPortsEndpoint: []int32{85}, + }, + { + name: "listener port outside updated range", + portOverrides: []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(80), EndpointPort: awssdk.Int32(8080)}, // Valid + {ListenerPort: awssdk.Int32(87), EndpointPort: awssdk.Int32(8087)}, // Invalid: listener port outside updated range + }, + allListenerPortRanges: []agamodel.PortRange{{FromPort: 80, ToPort: 90}}, + updatedListenerPortRanges: []agamodel.PortRange{{FromPort: 80, ToPort: 85}}, + wantValidCount: 1, + wantInvalidCount: 1, + wantInvalidPortsListener: []int32{87}, + }, + { + name: "multiple invalid conditions", + portOverrides: []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(80), EndpointPort: awssdk.Int32(8080)}, // Valid + {ListenerPort: awssdk.Int32(81), EndpointPort: awssdk.Int32(85)}, // Invalid: endpoint port in listener range + {ListenerPort: awssdk.Int32(87), EndpointPort: awssdk.Int32(8087)}, // Invalid: listener port outside updated range + }, + allListenerPortRanges: []agamodel.PortRange{{FromPort: 80, ToPort: 90}}, + updatedListenerPortRanges: []agamodel.PortRange{{FromPort: 80, ToPort: 85}}, + wantValidCount: 1, + wantInvalidCount: 2, + wantInvalidPortsEndpoint: []int32{85}, + wantInvalidPortsListener: []int32{87}, + }, + { + name: "no updated listener port ranges", + portOverrides: []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(80), EndpointPort: awssdk.Int32(8080)}, // Valid except for endpoint port check + {ListenerPort: awssdk.Int32(81), EndpointPort: awssdk.Int32(85)}, // Invalid: endpoint port in listener range + }, + allListenerPortRanges: []agamodel.PortRange{{FromPort: 80, ToPort: 90}}, + updatedListenerPortRanges: []agamodel.PortRange{}, + wantValidCount: 1, + wantInvalidCount: 1, + wantInvalidPortsEndpoint: []int32{85}, + }, + { + name: "multiple listener port ranges - endpoint port in one range", + portOverrides: []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(80), EndpointPort: awssdk.Int32(8080)}, // Valid + {ListenerPort: awssdk.Int32(443), EndpointPort: awssdk.Int32(443)}, // Invalid: endpoint port in another listener range + }, + allListenerPortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 90}, + {FromPort: 443, ToPort: 443}, + }, + updatedListenerPortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 90}, + {FromPort: 443, ToPort: 443}, + }, + wantValidCount: 1, + wantInvalidCount: 1, + wantInvalidPortsEndpoint: []int32{443}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &listenerSynthesizer{ + logger: logr.Discard(), + } + + validPortOverrides, invalidPortOverrides := s.processPortOverridesWithAllRules( + tt.portOverrides, + tt.allListenerPortRanges, + tt.updatedListenerPortRanges, + ) + + // Verify counts + assert.Equal(t, tt.wantValidCount, len(validPortOverrides), "valid port overrides count") + assert.Equal(t, tt.wantInvalidCount, len(invalidPortOverrides), "invalid port overrides count") + + // If specific invalid endpoint ports were expected, check those + if tt.wantInvalidPortsEndpoint != nil { + var actualInvalidEndpointPorts []int32 + for _, po := range invalidPortOverrides { + // If this port was invalidated because of endpoint port overlap + if pkgaga.IsPortInRanges(awssdk.ToInt32(po.EndpointPort), tt.allListenerPortRanges) { + actualInvalidEndpointPorts = append(actualInvalidEndpointPorts, awssdk.ToInt32(po.EndpointPort)) + } + } + assert.ElementsMatch(t, tt.wantInvalidPortsEndpoint, actualInvalidEndpointPorts, "invalid endpoint ports") + } + + // If specific invalid listener ports were expected, check those + if tt.wantInvalidPortsListener != nil && len(tt.updatedListenerPortRanges) > 0 { + var actualInvalidListenerPorts []int32 + for _, po := range invalidPortOverrides { + // If this port was invalidated because listener port was outside updated ranges + if !pkgaga.IsPortInRanges(awssdk.ToInt32(po.EndpointPort), tt.allListenerPortRanges) && + !pkgaga.IsPortInRanges(awssdk.ToInt32(po.ListenerPort), tt.updatedListenerPortRanges) { + actualInvalidListenerPorts = append(actualInvalidListenerPorts, awssdk.ToInt32(po.ListenerPort)) + } + } + assert.ElementsMatch(t, tt.wantInvalidPortsListener, actualInvalidListenerPorts, "invalid listener ports") + } + }) + } +} + +func TestListenerSynthesizer_ProcessEndpointGroupPortOverrides(t *testing.T) { + tests := []struct { + name string + listeners []*ListenerResource + allListenerPortRanges []agamodel.PortRange + updatePortRangesByListener map[string][]agamodel.PortRange + endpointGroups map[string][]agatypes.EndpointGroup + updateCalls map[string][]agatypes.PortOverride + expectError bool + }{ + { + name: "no endpoint groups - no updates needed", + listeners: []*ListenerResource{ + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:listener1"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}, + }, + }, + }, + }, + allListenerPortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + updatePortRangesByListener: map[string][]agamodel.PortRange{}, + endpointGroups: map[string][]agatypes.EndpointGroup{ + "arn:listener1": {}, // No endpoint groups + }, + updateCalls: map[string][]agatypes.PortOverride{}, + expectError: false, + }, + { + name: "no port overrides - no updates needed", + listeners: []*ListenerResource{ + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:listener1"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(80)}, + }, + }, + }, + }, + allListenerPortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 80}, + }, + updatePortRangesByListener: map[string][]agamodel.PortRange{}, + endpointGroups: map[string][]agatypes.EndpointGroup{ + "arn:listener1": { + { + EndpointGroupArn: awssdk.String("arn:endpointgroup1"), + PortOverrides: []agatypes.PortOverride{}, // No port overrides + }, + }, + }, + updateCalls: map[string][]agatypes.PortOverride{}, + expectError: false, + }, + { + name: "endpoint port overlaps with listener port range - should be removed", + listeners: []*ListenerResource{ + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:listener1"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(90)}, + }, + }, + }, + }, + allListenerPortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 90}, + }, + updatePortRangesByListener: map[string][]agamodel.PortRange{}, + endpointGroups: map[string][]agatypes.EndpointGroup{ + "arn:listener1": { + { + EndpointGroupArn: awssdk.String("arn:endpointgroup1"), + PortOverrides: []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(80), EndpointPort: awssdk.Int32(85)}, // Overlaps (endpoint port in listener range) + {ListenerPort: awssdk.Int32(80), EndpointPort: awssdk.Int32(8080)}, // Valid + }, + }, + }, + }, + updateCalls: map[string][]agatypes.PortOverride{ + "arn:endpointgroup1": { + {ListenerPort: awssdk.Int32(80), EndpointPort: awssdk.Int32(8080)}, // Only the valid one remains + }, + }, + expectError: false, + }, + { + name: "listener port outside updated ranges - should be removed", + listeners: []*ListenerResource{ + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:listener1"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(90)}, + }, + }, + }, + }, + allListenerPortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 90}, + }, + updatePortRangesByListener: map[string][]agamodel.PortRange{ + "arn:listener1": { + {FromPort: 80, ToPort: 85}, // Narrower range than current + }, + }, + endpointGroups: map[string][]agatypes.EndpointGroup{ + "arn:listener1": { + { + EndpointGroupArn: awssdk.String("arn:endpointgroup1"), + PortOverrides: []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(82), EndpointPort: awssdk.Int32(8082)}, // Valid - within updated range + {ListenerPort: awssdk.Int32(88), EndpointPort: awssdk.Int32(8088)}, // Invalid - outside updated range + }, + }, + }, + }, + updateCalls: map[string][]agatypes.PortOverride{ + "arn:endpointgroup1": { + {ListenerPort: awssdk.Int32(82), EndpointPort: awssdk.Int32(8082)}, // Only the valid one remains + }, + }, + expectError: false, + }, + { + name: "multiple listeners with multiple endpoint groups - consolidated processing", + listeners: []*ListenerResource{ + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:listener1"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(90)}, + }, + }, + }, + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:listener2"), + Protocol: agatypes.ProtocolUdp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(53), ToPort: awssdk.Int32(53)}, + }, + }, + }, + }, + allListenerPortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 90}, + {FromPort: 53, ToPort: 53}, + }, + updatePortRangesByListener: map[string][]agamodel.PortRange{ + "arn:listener1": { + {FromPort: 80, ToPort: 85}, // Narrowed range + }, + // No update for listener2 + }, + endpointGroups: map[string][]agatypes.EndpointGroup{ + "arn:listener1": { + { + EndpointGroupArn: awssdk.String("arn:endpointgroup1"), + PortOverrides: []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(82), EndpointPort: awssdk.Int32(82)}, // Invalid - endpoint port overlaps + {ListenerPort: awssdk.Int32(88), EndpointPort: awssdk.Int32(8088)}, // Invalid - listener port outside updated range + {ListenerPort: awssdk.Int32(85), EndpointPort: awssdk.Int32(8085)}, // Valid + }, + }, + }, + "arn:listener2": { + { + EndpointGroupArn: awssdk.String("arn:endpointgroup2"), + PortOverrides: []agatypes.PortOverride{ + {ListenerPort: awssdk.Int32(53), EndpointPort: awssdk.Int32(5353)}, // Valid + {ListenerPort: awssdk.Int32(53), EndpointPort: awssdk.Int32(53)}, // Invalid - endpoint port overlaps + }, + }, + }, + }, + updateCalls: map[string][]agatypes.PortOverride{ + "arn:endpointgroup1": { + {ListenerPort: awssdk.Int32(85), EndpointPort: awssdk.Int32(8085)}, // Only valid one remains + }, + "arn:endpointgroup2": { + {ListenerPort: awssdk.Int32(53), EndpointPort: awssdk.Int32(5353)}, // Only valid one remains + }, + }, + expectError: false, + }, + { + name: "multiple listeners each with both types of port override conflicts", + listeners: []*ListenerResource{ + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:listener1"), + Protocol: agatypes.ProtocolTcp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(80), ToPort: awssdk.Int32(90)}, + {FromPort: awssdk.Int32(443), ToPort: awssdk.Int32(443)}, + }, + }, + }, + { + Listener: &agatypes.Listener{ + ListenerArn: awssdk.String("arn:listener2"), + Protocol: agatypes.ProtocolUdp, + PortRanges: []agatypes.PortRange{ + {FromPort: awssdk.Int32(53), ToPort: awssdk.Int32(53)}, + {FromPort: awssdk.Int32(8000), ToPort: awssdk.Int32(8010)}, + }, + }, + }, + }, + allListenerPortRanges: []agamodel.PortRange{ + {FromPort: 80, ToPort: 90}, + {FromPort: 443, ToPort: 443}, + {FromPort: 53, ToPort: 53}, + {FromPort: 8000, ToPort: 8010}, + }, + updatePortRangesByListener: map[string][]agamodel.PortRange{ + "arn:listener1": { + {FromPort: 80, ToPort: 85}, // Narrowed range + {FromPort: 443, ToPort: 443}, // Unchanged + }, + "arn:listener2": { + {FromPort: 53, ToPort: 53}, // Unchanged + {FromPort: 8000, ToPort: 8005}, // Narrowed range + }, + }, + endpointGroups: map[string][]agatypes.EndpointGroup{ + "arn:listener1": { + { + EndpointGroupArn: awssdk.String("arn:endpointgroup1"), + PortOverrides: []agatypes.PortOverride{ + // Both types of issues + {ListenerPort: awssdk.Int32(82), EndpointPort: awssdk.Int32(82)}, // Invalid - endpoint port overlaps + {ListenerPort: awssdk.Int32(88), EndpointPort: awssdk.Int32(8088)}, // Invalid - listener port outside updated range + {ListenerPort: awssdk.Int32(443), EndpointPort: awssdk.Int32(443)}, // Invalid - endpoint port overlaps + {ListenerPort: awssdk.Int32(85), EndpointPort: awssdk.Int32(8085)}, // Valid + {ListenerPort: awssdk.Int32(443), EndpointPort: awssdk.Int32(8443)}, // Valid + }, + }, + }, + "arn:listener2": { + { + EndpointGroupArn: awssdk.String("arn:endpointgroup2"), + PortOverrides: []agatypes.PortOverride{ + // Both types of issues + {ListenerPort: awssdk.Int32(53), EndpointPort: awssdk.Int32(53)}, // Invalid - endpoint port overlaps + {ListenerPort: awssdk.Int32(8008), EndpointPort: awssdk.Int32(9008)}, // Invalid - listener port outside updated range + {ListenerPort: awssdk.Int32(8000), EndpointPort: awssdk.Int32(8000)}, // Invalid - endpoint port overlaps + {ListenerPort: awssdk.Int32(53), EndpointPort: awssdk.Int32(5353)}, // Valid + {ListenerPort: awssdk.Int32(8005), EndpointPort: awssdk.Int32(9005)}, // Valid + }, + }, + }, + }, + updateCalls: map[string][]agatypes.PortOverride{ + "arn:endpointgroup1": { + // Only valid port overrides remain + {ListenerPort: awssdk.Int32(85), EndpointPort: awssdk.Int32(8085)}, + {ListenerPort: awssdk.Int32(443), EndpointPort: awssdk.Int32(8443)}, + }, + "arn:endpointgroup2": { + // Only valid port overrides remain + {ListenerPort: awssdk.Int32(53), EndpointPort: awssdk.Int32(5353)}, + {ListenerPort: awssdk.Int32(8005), EndpointPort: awssdk.Int32(9005)}, + }, + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create mocks + // Setup controller and mocks + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockGaClient := services.NewMockGlobalAccelerator(ctrl) + + // Create mock listener manager using gomock + mockListenerManager := NewMockListenerManager(ctrl) + + // Setup expectations for mockListenerManager.ListEndpointGroups + for listenerARN, endpointGroups := range tt.endpointGroups { + mockListenerManager.EXPECT(). + ListEndpointGroups(gomock.Any(), listenerARN). + Return(endpointGroups, nil) + } + + // Track which endpoint groups have been updated + updatedEndpointGroups := make(map[string]bool) + + // Setup expectations for mockGaClient.UpdateEndpointGroupWithContext + for _, expectedPortMapping := range tt.updateCalls { + // If we expect an update for this endpoint group, mock the call + if len(expectedPortMapping) > 0 { + mockGaClient.EXPECT(). + UpdateEndpointGroupWithContext(gomock.Any(), gomock.Any()). + AnyTimes(). + DoAndReturn( + func(_ context.Context, input *globalaccelerator.UpdateEndpointGroupInput) (*globalaccelerator.UpdateEndpointGroupOutput, error) { + // Mark this endpoint group as updated + arn := awssdk.ToString(input.EndpointGroupArn) + updatedEndpointGroups[arn] = true + + // Get the expected port mapping for this endpoint group + expectedMapping, exists := tt.updateCalls[arn] + assert.True(t, exists, "Unexpected endpoint group update: %s", arn) + + // Verify port overrides count + assert.Equal(t, len(expectedMapping), len(input.PortOverrides)) + + // Create map of actual port overrides + actualMapping := make(map[int32]int32) + for _, po := range input.PortOverrides { + actualMapping[awssdk.ToInt32(po.ListenerPort)] = awssdk.ToInt32(po.EndpointPort) + } + + // Convert actual port overrides to a map for easier comparison + actualMap := make(map[int32]int32) + for _, po := range input.PortOverrides { + actualMap[awssdk.ToInt32(po.ListenerPort)] = awssdk.ToInt32(po.EndpointPort) + } + + // Create expected map + expectedMap := make(map[int32]int32) + for _, po := range expectedMapping { + expectedMap[awssdk.ToInt32(po.ListenerPort)] = awssdk.ToInt32(po.EndpointPort) + } + + // Verify the mappings match + assert.Equal(t, expectedMap, actualMap) + + return &globalaccelerator.UpdateEndpointGroupOutput{}, nil + }) + } + } + + // Create the synthesizer with the mocks + s := NewListenerSynthesizer(mockGaClient, mockListenerManager, logr.Discard(), nil) + + // Call the method under test + err := s.ProcessEndpointGroupPortOverrides( + context.Background(), + tt.listeners, + tt.allListenerPortRanges, + tt.updatePortRangesByListener, + ) + + // Assert results + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + + // Verify all expected endpoint group updates happened + for endpointGroupARN, expectedMapping := range tt.updateCalls { + if len(expectedMapping) > 0 { + assert.True(t, updatedEndpointGroups[endpointGroupARN], + "Expected endpoint group %s to be updated", endpointGroupARN) + } + } + } + }) + } +} diff --git a/pkg/deploy/aga/stack_deployer.go b/pkg/deploy/aga/stack_deployer.go index 8300882ca..a023758b9 100644 --- a/pkg/deploy/aga/stack_deployer.go +++ b/pkg/deploy/aga/stack_deployer.go @@ -32,25 +32,25 @@ func NewDefaultStackDeployer(cloud services.Cloud, config config.ControllerConfi // Create actual managers agaTaggingManager := NewDefaultTaggingManager(cloud.GlobalAccelerator(), cloud.RGT(), logger) - listenerManager := NewDefaultListenerManager(cloud.GlobalAccelerator(), logger) + endpointGroupManager := NewDefaultEndpointGroupManager(cloud.GlobalAccelerator(), logger) + listenerManager := NewDefaultListenerManager(cloud.GlobalAccelerator(), endpointGroupManager, logger) acceleratorManager := NewDefaultAcceleratorManager(cloud.GlobalAccelerator(), trackingProvider, agaTaggingManager, listenerManager, config.ExternalManagedTags, logger) // TODO: Create other managers when they are implemented - // endpointGroupManager := NewDefaultEndpointGroupManager(cloud.GlobalAccelerator(), trackingProvider, agaTaggingManager, config.ExternalManagedTags, logger) // endpointManager := NewDefaultEndpointManager(cloud.GlobalAccelerator(), logger) return &defaultStackDeployer{ - cloud: cloud, - controllerConfig: config, - trackingProvider: trackingProvider, - featureGates: config.FeatureGates, - logger: logger, - metricsCollector: metricsCollector, - controllerName: controllerName, - agaTaggingManager: agaTaggingManager, - acceleratorManager: acceleratorManager, - listenerManager: listenerManager, + cloud: cloud, + controllerConfig: config, + trackingProvider: trackingProvider, + featureGates: config.FeatureGates, + logger: logger, + metricsCollector: metricsCollector, + controllerName: controllerName, + agaTaggingManager: agaTaggingManager, + acceleratorManager: acceleratorManager, + listenerManager: listenerManager, + endpointGroupManager: endpointGroupManager, // TODO: Set other managers when implemented - // endpointGroupManager: endpointGroupManager, // endpointManager: endpointManager, } } @@ -68,11 +68,11 @@ type defaultStackDeployer struct { controllerName string // Actual managers - agaTaggingManager TaggingManager - acceleratorManager AcceleratorManager - listenerManager ListenerManager + agaTaggingManager TaggingManager + acceleratorManager AcceleratorManager + listenerManager ListenerManager + endpointGroupManager EndpointGroupManager // TODO: Add other managers when implemented - // endpointGroupManager EndpointGroupManager // endpointManager EndpointManager } @@ -92,8 +92,8 @@ func (d *defaultStackDeployer) Deploy(ctx context.Context, stack core.Stack, met synthesizers = append(synthesizers, NewAcceleratorSynthesizer(d.cloud.GlobalAccelerator(), d.trackingProvider, d.agaTaggingManager, d.acceleratorManager, d.logger, d.featureGates, stack), NewListenerSynthesizer(d.cloud.GlobalAccelerator(), d.listenerManager, d.logger, stack), + NewEndpointGroupSynthesizer(d.cloud.GlobalAccelerator(), d.endpointGroupManager, d.logger, stack), // TODO: Add other synthesizers when managers are implemented - // NewEndpointGroupSynthesizer(d.cloud.GlobalAccelerator(), d.trackingProvider, d.agaTaggingManager, d.endpointGroupManager, d.logger, d.featureGates, stack), // NewEndpointSynthesizer(d.cloud.GlobalAccelerator(), d.trackingProvider, d.endpointManager, d.logger, d.featureGates, stack), ) diff --git a/pkg/k8s/events.go b/pkg/k8s/events.go index efad8aa26..42c966f2f 100644 --- a/pkg/k8s/events.go +++ b/pkg/k8s/events.go @@ -52,6 +52,7 @@ const ( GlobalAcceleratorEventReasonFailedUpdateStatus = "FailedUpdateStatus" GlobalAcceleratorEventReasonFailedCleanup = "FailedCleanup" GlobalAcceleratorEventReasonFailedBuildModel = "FailedBuildModel" + GlobalAcceleratorEventReasonFailedEndpointLoad = "FailedEndpointLoad" GlobalAcceleratorEventReasonFailedDeploy = "FailedDeploy" GlobalAcceleratorEventReasonSuccessfullyReconciled = "SuccessfullyReconciled" ) diff --git a/pkg/model/aga/endpoint_group.go b/pkg/model/aga/endpoint_group.go new file mode 100644 index 000000000..dfa5e1ba0 --- /dev/null +++ b/pkg/model/aga/endpoint_group.go @@ -0,0 +1,102 @@ +package aga + +import ( + "context" + "github.com/pkg/errors" + "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" +) + +const ( + // ResourceTypeEndpointGroup is the resource type for Global Accelerator EndpointGroup + ResourceTypeEndpointGroup = "AWS::GlobalAccelerator::EndpointGroup" +) + +var _ core.Resource = &EndpointGroup{} + +// EndpointGroup represents an AWS Global Accelerator EndpointGroup. +type EndpointGroup struct { + core.ResourceMeta `json:"-"` + + // desired state of EndpointGroup + Spec EndpointGroupSpec `json:"spec"` + + // observed state of EndpointGroup + // +optional + Status *EndpointGroupStatus `json:"status,omitempty"` + + // reference to Listener resource + Listener *Listener `json:"-"` +} + +// NewEndpointGroup constructs new EndpointGroup resource. +func NewEndpointGroup(stack core.Stack, id string, spec EndpointGroupSpec, listener *Listener) *EndpointGroup { + endpointGroup := &EndpointGroup{ + ResourceMeta: core.NewResourceMeta(stack, ResourceTypeEndpointGroup, id), + Spec: spec, + Status: nil, + Listener: listener, + } + stack.AddResource(endpointGroup) + endpointGroup.registerDependencies(stack) + return endpointGroup +} + +// SetStatus sets the EndpointGroup's status +func (eg *EndpointGroup) SetStatus(status EndpointGroupStatus) { + eg.Status = &status +} + +// EndpointGroupARN returns The Amazon Resource Name (ARN) of the endpoint group. +func (eg *EndpointGroup) EndpointGroupARN() core.StringToken { + return core.NewResourceFieldStringToken(eg, "status/endpointGroupARN", + func(ctx context.Context, res core.Resource, fieldPath string) (s string, err error) { + endpointGroup := res.(*EndpointGroup) + if endpointGroup.Status == nil { + return "", errors.Errorf("EndpointGroup is not fulfilled yet: %v", endpointGroup.ID()) + } + return endpointGroup.Status.EndpointGroupARN, nil + }, + ) +} + +// register dependencies for EndpointGroup. +func (eg *EndpointGroup) registerDependencies(stack core.Stack) { + // EndpointGroup depends on its Listener + stack.AddDependency(eg, eg.Listener) +} + +// PortOverride defines the port override for Global Accelerator endpoint groups. +type PortOverride struct { + // ListenerPort is the listener port that you want to map to a specific endpoint port. + ListenerPort int32 `json:"listenerPort"` + + // EndpointPort is the endpoint port that you want traffic to be routed to. + EndpointPort int32 `json:"endpointPort"` +} + +// EndpointGroupSpec defines the desired state of EndpointGroup +type EndpointGroupSpec struct { + // ListenerARN is the ARN of the listener for the endpoint group + ListenerARN core.StringToken `json:"listenerARN"` + + // Region is the AWS Region where the endpoint group is located. + Region string `json:"region"` + + // TrafficDialPercentage is the percentage of traffic to send to an AWS Region. + // +optional + TrafficDialPercentage *int32 `json:"trafficDialPercentage,omitempty"` + + // PortOverrides is a list of endpoint port overrides. + // +optional + PortOverrides []PortOverride `json:"portOverrides,omitempty"` + + // EndpointConfigurations is a list of endpoint configurations for the endpoint group. + // +optional + // This field is not implemented in the initial version as it will be part of a separate endpoint builder. +} + +// EndpointGroupStatus defines the observed state of EndpointGroup +type EndpointGroupStatus struct { + // EndpointGroupARN is the Amazon Resource Name (ARN) of the endpoint group. + EndpointGroupARN string `json:"endpointGroupARN"` +} diff --git a/pkg/status/aga/status_updater.go b/pkg/status/aga/status_updater.go index 0fd2f046d..75cbfb827 100644 --- a/pkg/status/aga/status_updater.go +++ b/pkg/status/aga/status_updater.go @@ -188,7 +188,7 @@ func (u *defaultStatusUpdater) UpdateStatusFailure(ctx context.Context, ga *v1be Status: metav1.ConditionFalse, LastTransitionTime: metav1.Now(), Reason: reason, - Message: message, + Message: "Reconciliation failed. See events and controller logs for details", } conditionUpdated := u.updateCondition(&ga.Status.Conditions, failureCondition) diff --git a/pkg/status/aga/status_updater_test.go b/pkg/status/aga/status_updater_test.go index 7ea74b968..440fa4917 100644 --- a/pkg/status/aga/status_updater_test.go +++ b/pkg/status/aga/status_updater_test.go @@ -291,7 +291,7 @@ func Test_defaultStatusUpdater_UpdateStatusFailure(t *testing.T) { }, }, reason: "ProvisioningFailed", - message: "Failed to provision accelerator: validation error", + message: "Reconciliation failed. See events and controller logs for details", validateStatus: func(t *testing.T, ga *v1beta1.GlobalAccelerator) { // Check that observed generation was updated assert.NotNil(t, ga.Status.ObservedGeneration) @@ -303,7 +303,7 @@ func Test_defaultStatusUpdater_UpdateStatusFailure(t *testing.T) { assert.Equal(t, ConditionTypeReady, condition.Type) assert.Equal(t, metav1.ConditionFalse, condition.Status) assert.Equal(t, "ProvisioningFailed", condition.Reason) - assert.Equal(t, "Failed to provision accelerator: validation error", condition.Message) + assert.Equal(t, "Reconciliation failed. See events and controller logs for details", condition.Message) }, }, { @@ -328,7 +328,7 @@ func Test_defaultStatusUpdater_UpdateStatusFailure(t *testing.T) { }, }, reason: "NewError", - message: "New error message", + message: "Reconciliation failed. See events and controller logs for details", validateStatus: func(t *testing.T, ga *v1beta1.GlobalAccelerator) { // Check that observed generation was updated assert.NotNil(t, ga.Status.ObservedGeneration) @@ -340,7 +340,7 @@ func Test_defaultStatusUpdater_UpdateStatusFailure(t *testing.T) { assert.Equal(t, ConditionTypeReady, condition.Type) assert.Equal(t, metav1.ConditionFalse, condition.Status) assert.Equal(t, "NewError", condition.Reason) - assert.Equal(t, "New error message", condition.Message) + assert.Equal(t, "Reconciliation failed. See events and controller logs for details", condition.Message) }, }, { @@ -359,13 +359,13 @@ func Test_defaultStatusUpdater_UpdateStatusFailure(t *testing.T) { Status: metav1.ConditionFalse, LastTransitionTime: metav1.Now(), Reason: "SameError", - Message: "Same error message", + Message: "Reconciliation failed. See events and controller logs for details", }, }, }, }, reason: "SameError", - message: "Same error message", + message: "Reconciliation failed. See events and controller logs for details", validateStatus: func(t *testing.T, ga *v1beta1.GlobalAccelerator) { // Status should be unchanged assert.NotNil(t, ga.Status.ObservedGeneration) diff --git a/scripts/gen_mocks.sh b/scripts/gen_mocks.sh index 413cbec8d..73338c9e7 100755 --- a/scripts/gen_mocks.sh +++ b/scripts/gen_mocks.sh @@ -26,6 +26,7 @@ $MOCKGEN -package=networking -destination=./pkg/networking/vpc_info_provider_moc $MOCKGEN -package=networking -destination=./pkg/networking/backend_sg_provider_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/networking BackendSGProvider $MOCKGEN -package=networking -destination=./pkg/networking/security_group_resolver_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/networking SecurityGroupResolver $MOCKGEN -package=aga -destination=./pkg/deploy/aga/accelerator_manager_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/deploy/aga AcceleratorManager +$MOCKGEN -package=aga -destination=./pkg/deploy/aga/endpoint_group_manager_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/deploy/aga EndpointGroupManager $MOCKGEN -package=aga -destination=./pkg/deploy/aga/listener_manager_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/deploy/aga ListenerManager $MOCKGEN -package=aga -destination=./pkg/deploy/aga/tagging_manager_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/deploy/aga TaggingManager $MOCKGEN -package=certs -destination=./pkg/certs/cert_discovery_mocks.go sigs.k8s.io/aws-load-balancer-controller/pkg/certs CertDiscovery