diff --git a/controllers/aga/globalaccelerator_controller.go b/controllers/aga/globalaccelerator_controller.go index 07277cf84..8601b6c45 100644 --- a/controllers/aga/globalaccelerator_controller.go +++ b/controllers/aga/globalaccelerator_controller.go @@ -107,6 +107,7 @@ func NewGlobalAcceleratorReconciler(k8sClient client.Client, eventRecorder recor config.ExternalManagedTags, logger.WithName("aga-model-builder"), metricsCollector, + cloud.ELBV2(), ) // Create stack marshaller diff --git a/controllers/ingress/group_controller.go b/controllers/ingress/group_controller.go index 9e776bb90..efbe35a2e 100644 --- a/controllers/ingress/group_controller.go +++ b/controllers/ingress/group_controller.go @@ -153,7 +153,7 @@ func (r *groupReconciler) reconcile(ctx context.Context, req reconcile.Request) return ctrlerrors.NewErrorWithMetrics(controllerName, "add_group_finalizer_error", err, r.metricsCollector) } - _, lb, frontendNlb, err := r.buildAndDeployModel(ctx, ingGroup) + _, lb, frontendNlb, listenerPorts, err := r.buildAndDeployModel(ctx, ingGroup) if err != nil { return err } @@ -173,7 +173,7 @@ func (r *groupReconciler) reconcile(ctx context.Context, req reconcile.Request) return } } - statusErr = r.updateIngressGroupStatus(ctx, ingGroup, lbDNS, frontendNlbDNS) + statusErr = r.updateIngressGroupStatus(ctx, ingGroup, lbDNS, frontendNlbDNS, listenerPorts) if statusErr != nil { r.recordIngressGroupEvent(ctx, ingGroup, corev1.EventTypeWarning, k8s.IngressEventReasonFailedUpdateStatus, fmt.Sprintf("Failed update status due to %v", statusErr)) @@ -200,25 +200,26 @@ func (r *groupReconciler) reconcile(ctx context.Context, req reconcile.Request) return nil } -func (r *groupReconciler) buildAndDeployModel(ctx context.Context, ingGroup ingress.Group) (core.Stack, *elbv2model.LoadBalancer, *elbv2model.LoadBalancer, error) { +func (r *groupReconciler) buildAndDeployModel(ctx context.Context, ingGroup ingress.Group) (core.Stack, *elbv2model.LoadBalancer, *elbv2model.LoadBalancer, []int32, error) { var stack core.Stack var lb *elbv2model.LoadBalancer var secrets []types.NamespacedName var backendSGRequired bool var err error var frontendNlb *elbv2model.LoadBalancer + var listenerPorts []int32 buildModelFn := func() { - stack, lb, secrets, backendSGRequired, frontendNlb, err = r.modelBuilder.Build(ctx, ingGroup, r.metricsCollector) + stack, lb, secrets, backendSGRequired, frontendNlb, listenerPorts, err = r.modelBuilder.Build(ctx, ingGroup, r.metricsCollector) } r.metricsCollector.ObserveControllerReconcileLatency(controllerName, "build_model", buildModelFn) if err != nil { r.recordIngressGroupEvent(ctx, ingGroup, corev1.EventTypeWarning, k8s.IngressEventReasonFailedBuildModel, fmt.Sprintf("Failed build model due to %v", err)) - return nil, nil, nil, ctrlerrors.NewErrorWithMetrics(controllerName, "build_model_error", err, r.metricsCollector) + return nil, nil, nil, nil, ctrlerrors.NewErrorWithMetrics(controllerName, "build_model_error", err, r.metricsCollector) } stackJSON, err := r.stackMarshaller.Marshal(stack) if err != nil { r.recordIngressGroupEvent(ctx, ingGroup, corev1.EventTypeWarning, k8s.IngressEventReasonFailedBuildModel, fmt.Sprintf("Failed build model due to %v", err)) - return nil, nil, nil, err + return nil, nil, nil, nil, err } r.logger.Info("successfully built model", "model", stackJSON) @@ -229,10 +230,10 @@ func (r *groupReconciler) buildAndDeployModel(ctx context.Context, ingGroup ingr if err != nil { var requeueNeededAfter *ctrlerrors.RequeueNeededAfter if errors.As(err, &requeueNeededAfter) { - return nil, nil, nil, err + return nil, nil, nil, nil, err } r.recordIngressGroupEvent(ctx, ingGroup, corev1.EventTypeWarning, k8s.IngressEventReasonFailedDeployModel, fmt.Sprintf("Failed deploy model due to %v", err)) - return nil, nil, nil, ctrlerrors.NewErrorWithMetrics(controllerName, "deploy_model_error", err, r.metricsCollector) + return nil, nil, nil, nil, ctrlerrors.NewErrorWithMetrics(controllerName, "deploy_model_error", err, r.metricsCollector) } r.logger.Info("successfully deployed model", "ingressGroup", ingGroup.ID) r.secretsManager.MonitorSecrets(ingGroup.ID.String(), secrets) @@ -242,9 +243,9 @@ func (r *groupReconciler) buildAndDeployModel(ctx context.Context, ingGroup ingr inactiveResources = append(inactiveResources, k8s.ToSliceOfNamespacedNames(ingGroup.Members)...) } if err := r.backendSGProvider.Release(ctx, networkingpkg.ResourceTypeIngress, inactiveResources); err != nil { - return nil, nil, nil, ctrlerrors.NewErrorWithMetrics(controllerName, "release_auto_generated_backend_sg_error", err, r.metricsCollector) + return nil, nil, nil, nil, ctrlerrors.NewErrorWithMetrics(controllerName, "release_auto_generated_backend_sg_error", err, r.metricsCollector) } - return stack, lb, frontendNlb, nil + return stack, lb, frontendNlb, listenerPorts, nil } func (r *groupReconciler) recordIngressGroupEvent(_ context.Context, ingGroup ingress.Group, eventType string, reason string, message string) { @@ -253,16 +254,23 @@ func (r *groupReconciler) recordIngressGroupEvent(_ context.Context, ingGroup in } } -func (r *groupReconciler) updateIngressGroupStatus(ctx context.Context, ingGroup ingress.Group, lbDNS string, frontendNLBDNS string) error { +func (r *groupReconciler) updateIngressGroupStatus(ctx context.Context, ingGroup ingress.Group, lbDNS string, frontendNLBDNS string, listenerPorts []int32) error { for _, member := range ingGroup.Members { - if err := r.updateIngressStatus(ctx, lbDNS, frontendNLBDNS, member.Ing); err != nil { + if err := r.updateIngressStatus(ctx, lbDNS, frontendNLBDNS, member.Ing, listenerPorts); err != nil { return err } } return nil } -func (r *groupReconciler) updateIngressStatus(ctx context.Context, lbDNS string, frontendNlbDNS string, ing *networking.Ingress) error { +func (r *groupReconciler) updateIngressStatus(ctx context.Context, lbDNS string, frontendNlbDNS string, ing *networking.Ingress, ports []int32) error { + ingressPorts := make([]networking.IngressPortStatus, len(ports)) + for i, port := range ports { + ingressPorts[i] = networking.IngressPortStatus{ + Port: port, + } + } + ingOld := ing.DeepCopy() if len(ing.Status.LoadBalancer.Ingress) != 1 || ing.Status.LoadBalancer.Ingress[0].IP != "" || @@ -270,8 +278,11 @@ func (r *groupReconciler) updateIngressStatus(ctx context.Context, lbDNS string, ing.Status.LoadBalancer.Ingress = []networking.IngressLoadBalancerIngress{ { Hostname: lbDNS, + Ports: ingressPorts, }, } + } else if len(ports) > 0 { + ing.Status.LoadBalancer.Ingress[0].Ports = ingressPorts } // Ensure frontendNLBDNS is appended if it is not already added @@ -405,21 +416,56 @@ func isIngressStatusEqual(a, b []networking.IngressLoadBalancerIngress) bool { return false } - setA := make(map[string]struct{}, len(a)) - setB := make(map[string]struct{}, len(b)) + hostnameToPortsA := make(map[string]map[int32]struct{}) + hostnameToPortsB := make(map[string]map[int32]struct{}) for _, ingress := range a { - setA[ingress.Hostname] = struct{}{} + if ingress.Hostname == "" { + continue + } + + portSet := make(map[int32]struct{}) + for _, portStatus := range ingress.Ports { + portSet[portStatus.Port] = struct{}{} + } + hostnameToPortsA[ingress.Hostname] = portSet } for _, ingress := range b { - setB[ingress.Hostname] = struct{}{} + if ingress.Hostname == "" { + continue + } + + portSet := make(map[int32]struct{}) + for _, portStatus := range ingress.Ports { + portSet[portStatus.Port] = struct{}{} + } + hostnameToPortsB[ingress.Hostname] = portSet } - for key := range setA { - if _, exists := setB[key]; !exists { + // Check if the maps are equal (same hostnames with same ports) + if len(hostnameToPortsA) != len(hostnameToPortsB) { + return false + } + + // Check if all hostnames in A exist in B with the same ports + for hostname, portsA := range hostnameToPortsA { + portsB, exists := hostnameToPortsB[hostname] + if !exists { + return false // Hostname in A doesn't exist in B + } + + // Check if port sets are equal (same length and same values) + if len(portsA) != len(portsB) { return false } + + // Check if all ports in A exist in B + for port := range portsA { + if _, exists := portsB[port]; !exists { + return false + } + } } return true } diff --git a/controllers/ingress/group_controller_test.go b/controllers/ingress/group_controller_test.go new file mode 100644 index 000000000..0c20800b4 --- /dev/null +++ b/controllers/ingress/group_controller_test.go @@ -0,0 +1,290 @@ +package ingress + +import ( + "github.com/stretchr/testify/assert" + networking "k8s.io/api/networking/v1" + "testing" +) + +func TestIsIngressStatusEqual(t *testing.T) { + testCases := []struct { + name string + statusA []networking.IngressLoadBalancerIngress + statusB []networking.IngressLoadBalancerIngress + expected bool + }{ + { + name: "Empty statuses", + statusA: []networking.IngressLoadBalancerIngress{}, + statusB: []networking.IngressLoadBalancerIngress{}, + expected: true, + }, + { + name: "Different length statuses", + statusA: []networking.IngressLoadBalancerIngress{}, + statusB: []networking.IngressLoadBalancerIngress{ + { + Hostname: "alb.amazonaws.com", + }, + }, + expected: false, + }, + { + name: "Same hostname with no ports", + statusA: []networking.IngressLoadBalancerIngress{ + { + Hostname: "alb.amazonaws.com", + }, + }, + statusB: []networking.IngressLoadBalancerIngress{ + { + Hostname: "alb.amazonaws.com", + }, + }, + expected: true, + }, + { + name: "Different hostnames", + statusA: []networking.IngressLoadBalancerIngress{ + { + Hostname: "alb1.amazonaws.com", + }, + }, + statusB: []networking.IngressLoadBalancerIngress{ + { + Hostname: "alb2.amazonaws.com", + }, + }, + expected: false, + }, + { + name: "Same hostname with same ports", + statusA: []networking.IngressLoadBalancerIngress{ + { + Hostname: "alb.amazonaws.com", + Ports: []networking.IngressPortStatus{ + {Port: 80}, + {Port: 443}, + }, + }, + }, + statusB: []networking.IngressLoadBalancerIngress{ + { + Hostname: "alb.amazonaws.com", + Ports: []networking.IngressPortStatus{ + {Port: 80}, + {Port: 443}, + }, + }, + }, + expected: true, + }, + { + name: "Same hostname with ports in different order", + statusA: []networking.IngressLoadBalancerIngress{ + { + Hostname: "alb.amazonaws.com", + Ports: []networking.IngressPortStatus{ + {Port: 80}, + {Port: 443}, + }, + }, + }, + statusB: []networking.IngressLoadBalancerIngress{ + { + Hostname: "alb.amazonaws.com", + Ports: []networking.IngressPortStatus{ + {Port: 443}, + {Port: 80}, + }, + }, + }, + expected: true, + }, + { + name: "Same hostname with different port counts", + statusA: []networking.IngressLoadBalancerIngress{ + { + Hostname: "alb.amazonaws.com", + Ports: []networking.IngressPortStatus{ + {Port: 80}, + {Port: 443}, + }, + }, + }, + statusB: []networking.IngressLoadBalancerIngress{ + { + Hostname: "alb.amazonaws.com", + Ports: []networking.IngressPortStatus{ + {Port: 80}, + }, + }, + }, + expected: false, + }, + { + name: "Same hostname with different port values", + statusA: []networking.IngressLoadBalancerIngress{ + { + Hostname: "alb.amazonaws.com", + Ports: []networking.IngressPortStatus{ + {Port: 80}, + {Port: 443}, + }, + }, + }, + statusB: []networking.IngressLoadBalancerIngress{ + { + Hostname: "alb.amazonaws.com", + Ports: []networking.IngressPortStatus{ + {Port: 80}, + {Port: 8443}, + }, + }, + }, + expected: false, + }, + { + name: "Multiple entries with matching hostnames and ports", + statusA: []networking.IngressLoadBalancerIngress{ + { + Hostname: "alb.amazonaws.com", + Ports: []networking.IngressPortStatus{ + {Port: 80}, + {Port: 443}, + }, + }, + { + Hostname: "nlb.amazonaws.com", + }, + }, + statusB: []networking.IngressLoadBalancerIngress{ + { + Hostname: "alb.amazonaws.com", + Ports: []networking.IngressPortStatus{ + {Port: 80}, + {Port: 443}, + }, + }, + { + Hostname: "nlb.amazonaws.com", + }, + }, + expected: true, + }, + { + name: "Multiple entries with different port configurations", + statusA: []networking.IngressLoadBalancerIngress{ + { + Hostname: "alb.amazonaws.com", + Ports: []networking.IngressPortStatus{ + {Port: 80}, + {Port: 443}, + }, + }, + { + Hostname: "nlb.amazonaws.com", + Ports: []networking.IngressPortStatus{ + {Port: 80}, + }, + }, + }, + statusB: []networking.IngressLoadBalancerIngress{ + { + Hostname: "alb.amazonaws.com", + Ports: []networking.IngressPortStatus{ + {Port: 80}, + {Port: 443}, + }, + }, + { + Hostname: "nlb.amazonaws.com", + }, + }, + expected: false, + }, + { + name: "One status has empty hostname", + statusA: []networking.IngressLoadBalancerIngress{ + { + Hostname: "alb.amazonaws.com", + Ports: []networking.IngressPortStatus{ + {Port: 80}, + }, + }, + { + Hostname: "", + }, + }, + statusB: []networking.IngressLoadBalancerIngress{ + { + Hostname: "alb.amazonaws.com", + Ports: []networking.IngressPortStatus{ + {Port: 80}, + }, + }, + { + Hostname: "", + }, + }, + expected: true, + }, + { + name: "One status has IP instead of hostname", + statusA: []networking.IngressLoadBalancerIngress{ + { + Hostname: "alb.amazonaws.com", + IP: "", + Ports: []networking.IngressPortStatus{ + {Port: 80}, + }, + }, + }, + statusB: []networking.IngressLoadBalancerIngress{ + { + Hostname: "", + IP: "192.168.1.1", + Ports: []networking.IngressPortStatus{ + {Port: 80}, + }, + }, + }, + expected: false, + }, + { + name: "Different entries order with same ports", + statusA: []networking.IngressLoadBalancerIngress{ + { + Hostname: "alb.amazonaws.com", + Ports: []networking.IngressPortStatus{ + {Port: 80}, + {Port: 443}, + }, + }, + { + Hostname: "nlb.amazonaws.com", + }, + }, + statusB: []networking.IngressLoadBalancerIngress{ + { + Hostname: "nlb.amazonaws.com", + }, + { + Hostname: "alb.amazonaws.com", + Ports: []networking.IngressPortStatus{ + {Port: 443}, + {Port: 80}, + }, + }, + }, + expected: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := isIngressStatusEqual(tc.statusA, tc.statusB) + assert.Equal(t, tc.expected, result) + }) + } +} diff --git a/controllers/service/service_controller.go b/controllers/service/service_controller.go index 364807189..a05406c53 100644 --- a/controllers/service/service_controller.go +++ b/controllers/service/service_controller.go @@ -240,11 +240,17 @@ func (r *serviceReconciler) cleanupLoadBalancerResources(ctx context.Context, sv func (r *serviceReconciler) updateServiceStatus(ctx context.Context, lbDNS string, svc *corev1.Service) error { if len(svc.Status.LoadBalancer.Ingress) != 1 || svc.Status.LoadBalancer.Ingress[0].IP != "" || - svc.Status.LoadBalancer.Ingress[0].Hostname != lbDNS { + svc.Status.LoadBalancer.Ingress[0].Hostname != lbDNS || + r.shouldUpdatePorts(svc) { + svcOld := svc.DeepCopy() + + ports := r.buildPortsForStatus(svc) + svc.Status.LoadBalancer.Ingress = []corev1.LoadBalancerIngress{ { Hostname: lbDNS, + Ports: ports, }, } if err := r.k8sClient.Status().Patch(ctx, svc, client.MergeFrom(svcOld)); err != nil { @@ -254,6 +260,49 @@ func (r *serviceReconciler) updateServiceStatus(ctx context.Context, lbDNS strin return nil } +// shouldUpdatePorts checks if we need to update the port information in the status +func (r *serviceReconciler) shouldUpdatePorts(svc *corev1.Service) bool { + if len(svc.Status.LoadBalancer.Ingress) != 1 { + return true + } + + existingPorts := svc.Status.LoadBalancer.Ingress[0].Ports + expectedPorts := r.buildPortsForStatus(svc) + + if len(existingPorts) != len(expectedPorts) { + return true + } + + // Create maps for easier comparison + existingPortMap := make(map[int32]bool) + for _, port := range existingPorts { + existingPortMap[port.Port] = true + } + + // Check if any expected port is missing + for _, port := range expectedPorts { + if !existingPortMap[port.Port] { + return true + } + } + + return false +} + +// buildPortsForStatus builds the list of port entries for the service status +func (r *serviceReconciler) buildPortsForStatus(svc *corev1.Service) []corev1.PortStatus { + var ports []corev1.PortStatus + + for _, svcPort := range svc.Spec.Ports { + ports = append(ports, corev1.PortStatus{ + Port: svcPort.Port, + Protocol: svcPort.Protocol, + }) + } + + return ports +} + func (r *serviceReconciler) cleanupServiceStatus(ctx context.Context, svc *corev1.Service) error { svcOld := svc.DeepCopy() svc.Status.LoadBalancer = corev1.LoadBalancerStatus{} diff --git a/controllers/service/service_controller_test.go b/controllers/service/service_controller_test.go new file mode 100644 index 000000000..d6d79e73e --- /dev/null +++ b/controllers/service/service_controller_test.go @@ -0,0 +1,281 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/assert" + corev1 "k8s.io/api/core/v1" +) + +func TestBuildPortsForStatus(t *testing.T) { + tests := []struct { + name string + service *corev1.Service + expected []corev1.PortStatus + }{ + { + name: "service with single port", + service: &corev1.Service{ + Spec: corev1.ServiceSpec{ + Ports: []corev1.ServicePort{ + { + Name: "http", + Protocol: corev1.ProtocolTCP, + Port: 80, + }, + }, + }, + }, + expected: []corev1.PortStatus{ + { + Port: 80, + Protocol: corev1.ProtocolTCP, + }, + }, + }, + { + name: "service with multiple ports", + service: &corev1.Service{ + Spec: corev1.ServiceSpec{ + Ports: []corev1.ServicePort{ + { + Name: "http", + Protocol: corev1.ProtocolTCP, + Port: 80, + }, + { + Name: "https", + Protocol: corev1.ProtocolTCP, + Port: 443, + }, + { + Name: "dns", + Protocol: corev1.ProtocolUDP, + Port: 53, + }, + }, + }, + }, + expected: []corev1.PortStatus{ + { + Port: 80, + Protocol: corev1.ProtocolTCP, + }, + { + Port: 443, + Protocol: corev1.ProtocolTCP, + }, + { + Port: 53, + Protocol: corev1.ProtocolUDP, + }, + }, + }, + { + name: "service with no ports", + service: &corev1.Service{ + Spec: corev1.ServiceSpec{ + Ports: []corev1.ServicePort{}, + }, + }, + expected: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + reconciler := &serviceReconciler{} + result := reconciler.buildPortsForStatus(tt.service) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestShouldUpdatePorts(t *testing.T) { + tests := []struct { + name string + service *corev1.Service + expected bool + }{ + { + name: "no existing ingress entry", + service: &corev1.Service{ + Spec: corev1.ServiceSpec{ + Ports: []corev1.ServicePort{ + { + Name: "http", + Protocol: corev1.ProtocolTCP, + Port: 80, + }, + }, + }, + Status: corev1.ServiceStatus{ + LoadBalancer: corev1.LoadBalancerStatus{ + Ingress: []corev1.LoadBalancerIngress{}, + }, + }, + }, + expected: true, + }, + { + name: "different port count", + service: &corev1.Service{ + Spec: corev1.ServiceSpec{ + Ports: []corev1.ServicePort{ + { + Name: "http", + Protocol: corev1.ProtocolTCP, + Port: 80, + }, + { + Name: "https", + Protocol: corev1.ProtocolTCP, + Port: 443, + }, + }, + }, + Status: corev1.ServiceStatus{ + LoadBalancer: corev1.LoadBalancerStatus{ + Ingress: []corev1.LoadBalancerIngress{ + { + Hostname: "test-nlb.elb.amazonaws.com", + Ports: []corev1.PortStatus{ + { + Port: 80, + Protocol: corev1.ProtocolTCP, + }, + }, + }, + }, + }, + }, + }, + expected: true, + }, + { + name: "missing port", + service: &corev1.Service{ + Spec: corev1.ServiceSpec{ + Ports: []corev1.ServicePort{ + { + Name: "http", + Protocol: corev1.ProtocolTCP, + Port: 80, + }, + { + Name: "https", + Protocol: corev1.ProtocolTCP, + Port: 443, + }, + }, + }, + Status: corev1.ServiceStatus{ + LoadBalancer: corev1.LoadBalancerStatus{ + Ingress: []corev1.LoadBalancerIngress{ + { + Hostname: "test-nlb.elb.amazonaws.com", + Ports: []corev1.PortStatus{ + { + Port: 80, + Protocol: corev1.ProtocolTCP, + }, + { + Port: 8080, // Different from spec + Protocol: corev1.ProtocolTCP, + }, + }, + }, + }, + }, + }, + }, + expected: true, + }, + { + name: "matching ports - no update needed", + service: &corev1.Service{ + Spec: corev1.ServiceSpec{ + Ports: []corev1.ServicePort{ + { + Name: "http", + Protocol: corev1.ProtocolTCP, + Port: 80, + }, + { + Name: "https", + Protocol: corev1.ProtocolTCP, + Port: 443, + }, + }, + }, + Status: corev1.ServiceStatus{ + LoadBalancer: corev1.LoadBalancerStatus{ + Ingress: []corev1.LoadBalancerIngress{ + { + Hostname: "test-nlb.elb.amazonaws.com", + Ports: []corev1.PortStatus{ + { + Port: 80, + Protocol: corev1.ProtocolTCP, + }, + { + Port: 443, + Protocol: corev1.ProtocolTCP, + }, + }, + }, + }, + }, + }, + }, + expected: false, + }, + { + name: "matching ports, order changed- no update needed", + service: &corev1.Service{ + Spec: corev1.ServiceSpec{ + Ports: []corev1.ServicePort{ + { + Name: "http", + Protocol: corev1.ProtocolTCP, + Port: 80, + }, + { + Name: "https", + Protocol: corev1.ProtocolTCP, + Port: 443, + }, + }, + }, + Status: corev1.ServiceStatus{ + LoadBalancer: corev1.LoadBalancerStatus{ + Ingress: []corev1.LoadBalancerIngress{ + { + Hostname: "test-nlb.elb.amazonaws.com", + Ports: []corev1.PortStatus{ + { + Port: 443, + Protocol: corev1.ProtocolTCP, + }, + { + Port: 80, + Protocol: corev1.ProtocolTCP, + }, + }, + }, + }, + }, + }, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + reconciler := &serviceReconciler{} + result := reconciler.shouldUpdatePorts(tt.service) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/pkg/aga/endpoint_discovery.go b/pkg/aga/endpoint_discovery.go new file mode 100644 index 000000000..0a3bedbc8 --- /dev/null +++ b/pkg/aga/endpoint_discovery.go @@ -0,0 +1,289 @@ +package aga + +import ( + "context" + "fmt" + awssdk "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" + "github.com/go-logr/logr" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" + "sigs.k8s.io/aws-load-balancer-controller/pkg/shared_utils" + + corev1 "k8s.io/api/core/v1" + networkingv1 "k8s.io/api/networking/v1" + + agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" + "sigs.k8s.io/aws-load-balancer-controller/pkg/annotations" + "sigs.k8s.io/aws-load-balancer-controller/pkg/k8s" + "sigs.k8s.io/controller-runtime/pkg/client" + gwv1 "sigs.k8s.io/gateway-api/apis/v1" +) + +// ProtocolPortInfo contains information about a protocol and its associated ports +type ProtocolPortInfo struct { + Protocol agaapi.GlobalAcceleratorProtocol + Ports []int32 +} + +// EndpointDiscovery is responsible for extracting protocol and port information from different endpoint types +type EndpointDiscovery struct { + client client.Client + annotationParser annotations.Parser + logger logr.Logger + elbv2Client services.ELBV2 +} + +// NewEndpointDiscovery creates a new EndpointDiscovery instance +func NewEndpointDiscovery(client client.Client, logger logr.Logger, elbv2Client services.ELBV2) *EndpointDiscovery { + annotationParser := annotations.NewSuffixAnnotationParser(annotations.AnnotationPrefixIngress) + return &EndpointDiscovery{ + client: client, + annotationParser: annotationParser, + logger: logger, + elbv2Client: elbv2Client, + } +} + +// FetchProtocolPortInfo extracts port and protocol information from a loaded endpoint +// For the auto-discovery scenario, we use the following approach: +// 1. Identify the endpoint type (Service, Ingress, Gateway, or LoadBalancer via EndpointID) +// 2. Extract protocol and port information from the stored K8s resource or AWS API +// 3. For Service endpoints, handle both TCP and UDP protocols based on the Service definition +// 4. For Ingress endpoints, extract ports from the load balancer status +// 5. For Gateway endpoints, map Gateway protocols to GlobalAccelerator protocols +// 6. For LoadBalancer (EndpointID) endpoints, query AWS API to get listener information +func (d *EndpointDiscovery) FetchProtocolPortInfo(ctx context.Context, endpoint *LoadedEndpoint) ([]ProtocolPortInfo, error) { + // For Kubernetes resource types, check if K8s resource is available + if endpoint.Type != agaapi.GlobalAcceleratorEndpointTypeEndpointID && endpoint.K8sResource == nil { + return nil, fmt.Errorf("kubernetes resource not available for endpoint %s/%s", + endpoint.Namespace, endpoint.Name) + } + + // Process based on endpoint type + switch endpoint.Type { + case agaapi.GlobalAcceleratorEndpointTypeService: + return d.fetchServiceProtocolPortInfo(ctx, endpoint) + case agaapi.GlobalAcceleratorEndpointTypeIngress: + return d.fetchIngressProtocolPortInfo(ctx, endpoint) + case agaapi.GlobalAcceleratorEndpointTypeGateway: + return d.fetchGatewayProtocolPortInfo(ctx, endpoint) + case agaapi.GlobalAcceleratorEndpointTypeEndpointID: + // For LoadBalancer ARN endpoints, we query the AWS API directly + // ARN should be already resolved during endpoint loading + if endpoint.ARN == "" { + return nil, fmt.Errorf("endpoint ARN is not available for endpoint with EndpointID type") + } + return d.fetchLoadBalancerProtocolPortInfo(ctx, endpoint) + } + + return nil, fmt.Errorf("auto-discovery not supported for endpoint type %s", endpoint.Type) +} + +// fetchServiceProtocolPortInfo extracts protocol and port information from a Service endpoint +func (d *EndpointDiscovery) fetchServiceProtocolPortInfo(_ context.Context, endpoint *LoadedEndpoint) ([]ProtocolPortInfo, error) { + svc, ok := endpoint.K8sResource.(*corev1.Service) + if !ok { + return nil, fmt.Errorf("expected Service object for endpoint %v but got %T", + k8s.NamespacedName(endpoint.K8sResource), endpoint.K8sResource) + } + + // Get ports from the service status + if len(svc.Status.LoadBalancer.Ingress) > 0 && len(svc.Status.LoadBalancer.Ingress[0].Ports) > 0 { + // Group ports by port number to check for TCP_UDP services (same port number, different protocols) + portMap := make(map[int32][]corev1.PortStatus) + for _, port := range svc.Status.LoadBalancer.Ingress[0].Ports { + key := port.Port + if vals, exists := portMap[key]; exists { + portMap[key] = append(vals, port) + } else { + portMap[key] = []corev1.PortStatus{port} + } + } + + // Check for TCP_UDP services and return error if found + for portNum, portStatuses := range portMap { + if len(portStatuses) > 1 { + // TCP_UDP service case not supported + return nil, fmt.Errorf("auto-discovery does not support TCP_UDP services on the same port %d for endpoint %v", + portNum, k8s.NamespacedName(svc)) + } + } + + // Group ports by protocol + tcpPorts := []int32{} + udpPorts := []int32{} + + for _, port := range svc.Status.LoadBalancer.Ingress[0].Ports { + if port.Protocol == corev1.ProtocolUDP { + udpPorts = append(udpPorts, port.Port) + } else { + tcpPorts = append(tcpPorts, port.Port) + } + } + return createProtocolPortsInfo(tcpPorts, udpPorts), nil + } + + // No ports found in status + return nil, fmt.Errorf("no port information available in service status for endpoint %v", + k8s.NamespacedName(svc)) +} + +// fetchIngressProtocolPortInfo extracts protocol and port information from an Ingress endpoint +// This function uses the listener ports stored in the Ingress status +func (d *EndpointDiscovery) fetchIngressProtocolPortInfo(_ context.Context, endpoint *LoadedEndpoint) ([]ProtocolPortInfo, error) { + ing, ok := endpoint.K8sResource.(*networkingv1.Ingress) + if !ok { + return nil, fmt.Errorf("expected Ingress object for endpoint %v but got %T", + k8s.NamespacedName(endpoint.K8sResource), endpoint.K8sResource) + } + + // Get ports from the ALB entry in status using FindIngressTwoDNSName + var tcpPorts []int32 + albDNS, _ := shared_utils.FindIngressTwoDNSName(ing) + + // Find the entry that corresponds to the ALB DNS + if albDNS != "" { + for _, ingressEntry := range ing.Status.LoadBalancer.Ingress { + if ingressEntry.Hostname == albDNS && len(ingressEntry.Ports) > 0 { + for _, portStatus := range ingressEntry.Ports { + tcpPorts = append(tcpPorts, portStatus.Port) + } + break + } + } + } + + if len(tcpPorts) == 0 { + return nil, fmt.Errorf("no valid ports found for ingress %v", k8s.NamespacedName(ing)) + } + + // Return TCP protocol with discovered ports + return []ProtocolPortInfo{ + {Protocol: agaapi.GlobalAcceleratorProtocolTCP, Ports: tcpPorts}, + }, nil +} + +// fetchGatewayProtocolPortInfo extracts protocol and port information from a Gateway endpoint +func (d *EndpointDiscovery) fetchGatewayProtocolPortInfo(_ context.Context, endpoint *LoadedEndpoint) ([]ProtocolPortInfo, error) { + gw, ok := endpoint.K8sResource.(*gwv1.Gateway) + if !ok { + return nil, fmt.Errorf("expected Gateway object for endpoint %v but got %T", + k8s.NamespacedName(endpoint.K8sResource), endpoint.K8sResource) + } + + tcpPortsMap := make(map[int32]bool) + udpPortsMap := make(map[int32]bool) + + // Process each listener and record ports by protocol + for _, listener := range gw.Spec.Listeners { + switch listener.Protocol { + case gwv1.UDPProtocolType: + udpPortsMap[int32(listener.Port)] = true + default: + // For HTTP, HTTPS, TLS, and other protocols, use TCP + tcpPortsMap[int32(listener.Port)] = true + } + } + + // Convert maps to slices for easier handling + var tcpPorts, udpPorts []int32 + for port := range tcpPortsMap { + tcpPorts = append(tcpPorts, port) + } + for port := range udpPortsMap { + udpPorts = append(udpPorts, port) + } + + return createProtocolPortsInfo(tcpPorts, udpPorts), nil +} + +// fetchLoadBalancerProtocolPortInfo extracts protocol and port information from a LoadBalancer ARN +// This uses the AWS API to retrieve ELBv2 listener information +func (d *EndpointDiscovery) fetchLoadBalancerProtocolPortInfo(ctx context.Context, endpoint *LoadedEndpoint) ([]ProtocolPortInfo, error) { + lbARN := endpoint.ARN + + // Call the AWS API to get listener information + protocolPortsInfo, err := d.getProtocolPortFromELBListener(ctx, lbARN) + if err != nil { + return nil, fmt.Errorf("failed to describe listeners for load balancer ARN %s: %w", lbARN, err) + } + + // No listeners found + if len(protocolPortsInfo) == 0 { + return nil, fmt.Errorf("no listeners found for load balancer ARN %s", lbARN) + } + + var tcpPorts, udpPorts []int32 + for _, info := range protocolPortsInfo { + if info.Protocol == agaapi.GlobalAcceleratorProtocolTCP { + tcpPorts = info.Ports + } else if info.Protocol == agaapi.GlobalAcceleratorProtocolUDP { + udpPorts = info.Ports + } + } + + d.logger.V(1).Info("discovered protocols and ports from AWS load balancer", + "loadBalancerARN", lbARN, + "tcpPorts", tcpPorts, + "udpPorts", udpPorts) + + return protocolPortsInfo, nil +} + +// getProtocolPortFromELBListener get the protocol and port info from ELB listener +func (d *EndpointDiscovery) getProtocolPortFromELBListener(ctx context.Context, lbARN string) ([]ProtocolPortInfo, error) { + input := &elasticloadbalancingv2.DescribeListenersInput{ + LoadBalancerArn: awssdk.String(lbARN), + } + + listeners, err := d.elbv2Client.DescribeListenersAsList(ctx, input) + if err != nil { + return nil, fmt.Errorf("failed to describe listeners for load balancer %s: %w", lbARN, err) + } + + // Group ports by protocol + tcpPorts := []int32{} + udpPorts := []int32{} + + for _, listener := range listeners { + port := awssdk.ToInt32(listener.Port) + listenerProtocol := listener.Protocol + + // Map ELB protocol to GA protocol + switch listenerProtocol { + case elbv2types.ProtocolEnumHttp, elbv2types.ProtocolEnumHttps, elbv2types.ProtocolEnumTcp, elbv2types.ProtocolEnumTls: + // All HTTP, HTTPS, TCP, TLS protocols map to TCP for Global Accelerator + tcpPorts = append(tcpPorts, port) + case elbv2types.ProtocolEnumUdp: + // UDP maps directly to UDP for Global Accelerator + udpPorts = append(udpPorts, port) + default: + // Any other protocols are not supported by Global Accelerator + return nil, fmt.Errorf("listener protocol %s is not supported by Global Accelerator for load balancer %s", + listenerProtocol, lbARN) + } + } + + return createProtocolPortsInfo(tcpPorts, udpPorts), nil +} + +// createProtocolPortsInfo is a helper function that creates ProtocolPortInfo objects from TCP and UDP port lists +func createProtocolPortsInfo(tcpPorts, udpPorts []int32) []ProtocolPortInfo { + var protocolPortsInfo []ProtocolPortInfo + + if len(tcpPorts) > 0 { + protocolPortsInfo = append(protocolPortsInfo, ProtocolPortInfo{ + Protocol: agaapi.GlobalAcceleratorProtocolTCP, + Ports: tcpPorts, + }) + } + if len(udpPorts) > 0 { + protocolPortsInfo = append(protocolPortsInfo, ProtocolPortInfo{ + Protocol: agaapi.GlobalAcceleratorProtocolUDP, + Ports: udpPorts, + }) + } + + return protocolPortsInfo +} diff --git a/pkg/aga/endpoint_discovery_test.go b/pkg/aga/endpoint_discovery_test.go new file mode 100644 index 000000000..2982e4f11 --- /dev/null +++ b/pkg/aga/endpoint_discovery_test.go @@ -0,0 +1,1044 @@ +package aga + +import ( + "context" + awssdk "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" + mock_client "sigs.k8s.io/aws-load-balancer-controller/mocks/controller-runtime/client" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + corev1 "k8s.io/api/core/v1" + networkingv1 "k8s.io/api/networking/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" + "sigs.k8s.io/controller-runtime/pkg/log/zap" + gwv1 "sigs.k8s.io/gateway-api/apis/v1" +) + +func TestFetchEndpointProtocolPortInfo(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mock_client.NewMockClient(ctrl) + ctx := context.TODO() + + gatewayClassName := gwv1.ObjectName("test-class") + httpType := gwv1.HTTPProtocolType + httpsType := gwv1.HTTPSProtocolType + udpType := gwv1.UDPProtocolType + + // Test Service endpoint with ports in status + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-service", + Namespace: "default", + }, + Spec: corev1.ServiceSpec{ + Type: corev1.ServiceTypeLoadBalancer, + Ports: []corev1.ServicePort{ + { + Name: "http", + Protocol: corev1.ProtocolTCP, // This protocol should NOT be used + Port: 9999, // This port should NOT be used + }, + }, + }, + Status: corev1.ServiceStatus{ + LoadBalancer: corev1.LoadBalancerStatus{ + Ingress: []corev1.LoadBalancerIngress{ + { + Hostname: "test-nlb.us-west-2.elb.amazonaws.com", + Ports: []corev1.PortStatus{ + { + Port: 80, + Protocol: corev1.ProtocolTCP, + }, + { + Port: 443, + Protocol: corev1.ProtocolTCP, + }, + }, + }, + }, + }, + }, + } + + t.Run("Service endpoint with TCP ports in status", func(t *testing.T) { + serviceEndpoint := &LoadedEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: "test-service", + Namespace: "default", + Status: EndpointStatusLoaded, + K8sResource: svc, + ARN: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/net/test-nlb/1234567890123456", + } + + mockElbv2Client := services.NewMockELBV2(ctrl) + logger := zap.New() + discovery := NewEndpointDiscovery(mockClient, logger, mockElbv2Client) + protocolPortsInfo, err := discovery.FetchProtocolPortInfo(ctx, serviceEndpoint) + assert.NoError(t, err) + + assert.Len(t, protocolPortsInfo, 1, "Should have one protocol group (TCP)") + assert.Equal(t, agaapi.GlobalAcceleratorProtocolTCP, protocolPortsInfo[0].Protocol, "Protocol should be TCP") + + // Verify both ports are present in the TCP group + portsFound := make(map[int32]bool) + for _, port := range protocolPortsInfo[0].Ports { + portsFound[port] = true + } + assert.True(t, portsFound[80], "Port 80 should be present") + assert.True(t, portsFound[443], "Port 443 should be present") + }) + + // Test Service with multi-protocol ports in status + svcMultiProto := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-service-multi-proto", + Namespace: "default", + }, + Spec: corev1.ServiceSpec{ + Type: corev1.ServiceTypeLoadBalancer, + Ports: []corev1.ServicePort{ + { + Port: 9999, // Should be ignored as we're using status ports + }, + }, + }, + Status: corev1.ServiceStatus{ + LoadBalancer: corev1.LoadBalancerStatus{ + Ingress: []corev1.LoadBalancerIngress{ + { + Hostname: "test-nlb-multi.us-west-2.elb.amazonaws.com", + Ports: []corev1.PortStatus{ + { + Port: 80, + Protocol: corev1.ProtocolTCP, + }, + { + Port: 53, + Protocol: corev1.ProtocolUDP, + }, + }, + }, + }, + }, + }, + } + + t.Run("Service with multi-protocol ports in status", func(t *testing.T) { + serviceMultiProtoEndpoint := &LoadedEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: "test-service-multi-proto", + Namespace: "default", + Status: EndpointStatusLoaded, + K8sResource: svcMultiProto, + ARN: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/net/test-nlb-multi/1234567890123456", + } + + mockElbv2Client := services.NewMockELBV2(ctrl) + logger := zap.New() + discovery := NewEndpointDiscovery(mockClient, logger, mockElbv2Client) + protocolPortsInfo, err := discovery.FetchProtocolPortInfo(ctx, serviceMultiProtoEndpoint) + assert.NoError(t, err) + assert.Len(t, protocolPortsInfo, 2, "Should have two protocol groups (TCP and UDP)") + + // We expect one group for TCP and one group for UDP + tcpPorts := []int32{} + udpPorts := []int32{} + + // Extract the ports by protocol + for _, info := range protocolPortsInfo { + if info.Protocol == agaapi.GlobalAcceleratorProtocolTCP { + tcpPorts = append(tcpPorts, info.Ports...) + } else if info.Protocol == agaapi.GlobalAcceleratorProtocolUDP { + udpPorts = append(udpPorts, info.Ports...) + } + } + + // Verify TCP port + assert.Len(t, tcpPorts, 1, "Should have one TCP port") + assert.Contains(t, tcpPorts, int32(80), "Port 80 should be TCP") + + // Verify UDP port + assert.Len(t, udpPorts, 1, "Should have one UDP port") + assert.Contains(t, udpPorts, int32(53), "Port 53 should be UDP") + }) + + // Test Ingress endpoint + ing := &networkingv1.Ingress{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ingress", + Namespace: "default", + }, + Status: networkingv1.IngressStatus{ + LoadBalancer: networkingv1.IngressLoadBalancerStatus{ + Ingress: []networkingv1.IngressLoadBalancerIngress{ + { + Hostname: "test-alb.us-west-2.elb.amazonaws.com", + Ports: []networkingv1.IngressPortStatus{ + {Port: 80}, + {Port: 443}, + }, + }, + { + Hostname: "test-nlb.amazonaws.com", // Non-ALB entry + }, + }, + }, + }, + } + + t.Run("Ingress endpoint with ALB DNS and ports in status", func(t *testing.T) { + ingressEndpoint := &LoadedEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeIngress, + Name: "test-ingress", + Namespace: "default", + Status: EndpointStatusLoaded, + K8sResource: ing, + ARN: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-alb/1234567890123456", + } + + mockElbv2Client := services.NewMockELBV2(ctrl) + logger := zap.New() + discovery := NewEndpointDiscovery(mockClient, logger, mockElbv2Client) + protocolPortsInfo, err := discovery.FetchProtocolPortInfo(ctx, ingressEndpoint) + assert.NoError(t, err) + assert.Len(t, protocolPortsInfo, 1) // TCP protocol group + assert.Equal(t, agaapi.GlobalAcceleratorProtocolTCP, protocolPortsInfo[0].Protocol) + assert.Len(t, protocolPortsInfo[0].Ports, 2, "Should have two ports in TCP group") + + // Check if both ports are present + ports := protocolPortsInfo[0].Ports + assert.Contains(t, ports, int32(80), "Port 80 should be in ports") + assert.Contains(t, ports, int32(443), "Port 443 should be in ports") + }) + + // Test Gateway endpoint + gw := &gwv1.Gateway{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-gateway", + Namespace: "default", + }, + Spec: gwv1.GatewaySpec{ + GatewayClassName: gatewayClassName, + Listeners: []gwv1.Listener{ + { + Name: "http", + Port: 80, + Protocol: httpType, + }, + { + Name: "https", + Port: 443, + Protocol: httpsType, + }, + { + Name: "udp", + Port: 1433, + Protocol: udpType, + }, + }, + }, + } + + t.Run("Gateway endpoint with mixed protocols", func(t *testing.T) { + gatewayEndpoint := &LoadedEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeGateway, + Name: "test-gateway", + Namespace: "default", + Status: EndpointStatusLoaded, + K8sResource: gw, + ARN: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-gateway-alb/1234567890123456", + } + + mockElbv2Client := services.NewMockELBV2(ctrl) + logger := zap.New() + discovery := NewEndpointDiscovery(mockClient, logger, mockElbv2Client) + protocolPortsInfo, err := discovery.FetchProtocolPortInfo(ctx, gatewayEndpoint) + assert.NoError(t, err) + assert.Len(t, protocolPortsInfo, 2, "Should have two protocol groups (TCP and UDP)") + + // We expect one group for TCP and one group for UDP + tcpPorts := []int32{} + udpPorts := []int32{} + + // Extract the ports by protocol + for _, info := range protocolPortsInfo { + if info.Protocol == agaapi.GlobalAcceleratorProtocolTCP { + tcpPorts = append(tcpPorts, info.Ports...) + } else if info.Protocol == agaapi.GlobalAcceleratorProtocolUDP { + udpPorts = append(udpPorts, info.Ports...) + } + } + + // Verify TCP ports + assert.Len(t, tcpPorts, 2, "Should have two TCP ports") + assert.Contains(t, tcpPorts, int32(80), "Port 80 should be in TCP group") + assert.Contains(t, tcpPorts, int32(443), "Port 443 should be in TCP group") + + // Verify UDP port + assert.Len(t, udpPorts, 1, "Should have one UDP port") + assert.Contains(t, udpPorts, int32(1433), "Port 1433 should be in UDP group") + }) +} + +func TestFetchIngressProtocolPortInfo(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := context.TODO() + + testCases := []struct { + name string + ingressStatus []networkingv1.IngressLoadBalancerIngress + expectedPortCount int + expectedPorts map[int32]bool + expectError bool + errorSubstring string + }{ + { + name: "Ingress with ALB entry and ports in status", + ingressStatus: []networkingv1.IngressLoadBalancerIngress{ + { + Hostname: "test-alb.us-west-2.elb.amazonaws.com", + Ports: []networkingv1.IngressPortStatus{ + {Port: 80}, + {Port: 443}, + }, + }, + }, + expectedPortCount: 2, + expectedPorts: map[int32]bool{80: true, 443: true}, + expectError: false, + }, + { + name: "Ingress with ALB entry but no ports in status", + ingressStatus: []networkingv1.IngressLoadBalancerIngress{ + { + Hostname: "test-alb.us-west-2.elb.amazonaws.com", + // No ports + }, + }, + expectedPortCount: 0, + expectedPorts: map[int32]bool{}, + expectError: true, + errorSubstring: "no valid ports found", + }, + { + name: "Ingress with ALB and NLB entries in status (should use ALB ports)", + ingressStatus: []networkingv1.IngressLoadBalancerIngress{ + { + Hostname: "test-alb.us-west-2.elb.amazonaws.com", + Ports: []networkingv1.IngressPortStatus{ + {Port: 80}, + {Port: 443}, + }, + }, + { + Hostname: "test-nlb.amazonaws.com", // NLB entry, should be ignored for port discovery + }, + }, + expectedPortCount: 2, + expectedPorts: map[int32]bool{80: true, 443: true}, + expectError: false, + }, + { + name: "Ingress with NLB entry first and ALB entry second in status", + ingressStatus: []networkingv1.IngressLoadBalancerIngress{ + { + Hostname: "test-nlb.amazonaws.com", // NLB entry, should be ignored for port discovery + }, + { + Hostname: "test-alb.us-west-2.elb.amazonaws.com", + Ports: []networkingv1.IngressPortStatus{ + {Port: 443}, + {Port: 8443}, + }, + }, + }, + expectedPortCount: 2, + expectedPorts: map[int32]bool{443: true, 8443: true}, + expectError: false, + }, + { + name: "Ingress with no ALB entry in status", + ingressStatus: []networkingv1.IngressLoadBalancerIngress{ + { + Hostname: "test-nlb.amazonaws.com", // Not an ALB entry + }, + }, + expectedPortCount: 0, + expectedPorts: map[int32]bool{}, + expectError: true, + errorSubstring: "no valid ports found", + }, + { + name: "Ingress with empty status", + ingressStatus: []networkingv1.IngressLoadBalancerIngress{}, + expectedPortCount: 0, + expectedPorts: map[int32]bool{}, + expectError: true, + errorSubstring: "no valid ports found", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create Ingress resource with test case status + ingress := &networkingv1.Ingress{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ingress-" + tc.name, + Namespace: "default", + }, + Status: networkingv1.IngressStatus{ + LoadBalancer: networkingv1.IngressLoadBalancerStatus{ + Ingress: tc.ingressStatus, + }, + }, + } + + // Create loaded endpoint with the ingress resource + ingressEndpoint := &LoadedEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeIngress, + Name: "test-ingress-" + tc.name, + Namespace: "default", + Status: EndpointStatusLoaded, + K8sResource: ingress, + } + + // Create mocks + mockClient := mock_client.NewMockClient(ctrl) + mockElbv2Client := services.NewMockELBV2(ctrl) + logger := zap.New() + + // Call the function under test + discovery := NewEndpointDiscovery(mockClient, logger, mockElbv2Client) + protocolPortsInfo, err := discovery.fetchIngressProtocolPortInfo(ctx, ingressEndpoint) + + // Check error expectations + if tc.expectError { + assert.Error(t, err) + if tc.errorSubstring != "" { + assert.Contains(t, err.Error(), tc.errorSubstring) + } + } else { + assert.NoError(t, err) + assert.Len(t, protocolPortsInfo, 1, "Should have one protocol group (TCP)") + assert.Equal(t, agaapi.GlobalAcceleratorProtocolTCP, protocolPortsInfo[0].Protocol, "Protocol should be TCP") + + // Check if all expected ports are in the TCP port group + portsFound := make(map[int32]bool) + for _, port := range protocolPortsInfo[0].Ports { + portsFound[port] = true + } + + // Verify expected ports are found + for port := range tc.expectedPorts { + assert.True(t, portsFound[port], "Port %d not found", port) + } + + // Verify port count matches expected + assert.Equal(t, tc.expectedPortCount, len(protocolPortsInfo[0].Ports), "Port count should match expected") + } + }) + } +} + +func TestFetchServiceProtocolPortInfo(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := context.TODO() + + testCases := []struct { + name string + serviceStatusPorts []corev1.PortStatus + expectedTCPPorts []int32 + expectedUDPPorts []int32 + expectError bool + errorSubstring string + }{ + { + name: "Service ports from status (TCP only)", + serviceStatusPorts: []corev1.PortStatus{ + { + Port: 80, + Protocol: corev1.ProtocolTCP, + }, + { + Port: 443, + Protocol: corev1.ProtocolTCP, + }, + }, + expectedTCPPorts: []int32{80, 443}, + expectedUDPPorts: []int32{}, + expectError: false, + }, + { + name: "Service ports from status (TCP + UDP)", + serviceStatusPorts: []corev1.PortStatus{ + { + Port: 80, + Protocol: corev1.ProtocolTCP, + }, + { + Port: 53, + Protocol: corev1.ProtocolUDP, + }, + }, + expectedTCPPorts: []int32{80}, + expectedUDPPorts: []int32{53}, + expectError: false, + }, + { + name: "Error for TCP_UDP service (same port with different protocols)", + serviceStatusPorts: []corev1.PortStatus{ + { + Port: 53, + Protocol: corev1.ProtocolTCP, + }, + { + Port: 53, + Protocol: corev1.ProtocolUDP, + }, + }, + expectedTCPPorts: []int32{}, + expectedUDPPorts: []int32{}, + expectError: true, + errorSubstring: "auto-discovery does not support TCP_UDP services on the same port 53", + }, + { + name: "Error when status has no ports", + serviceStatusPorts: []corev1.PortStatus{}, + expectedTCPPorts: []int32{}, + expectedUDPPorts: []int32{}, + expectError: true, + errorSubstring: "no port information available", + }, + { + name: "Error when no status entry", + serviceStatusPorts: nil, + expectedTCPPorts: []int32{}, + expectedUDPPorts: []int32{}, + expectError: true, + errorSubstring: "no port information available", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create Service resource with test case ports + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-service-" + tc.name, + Namespace: "default", + }, + Spec: corev1.ServiceSpec{ + Type: corev1.ServiceTypeLoadBalancer, + }, + } + + // Add status ports if provided + if tc.serviceStatusPorts != nil { + svc.Status.LoadBalancer.Ingress = []corev1.LoadBalancerIngress{ + { + Hostname: "test-nlb.us-west-2.elb.amazonaws.com", + Ports: tc.serviceStatusPorts, + }, + } + } + + // Create loaded endpoint with the service resource + serviceEndpoint := &LoadedEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: "test-service-" + tc.name, + Namespace: "default", + Status: EndpointStatusLoaded, + K8sResource: svc, + } + + // Create mocks + mockClient := mock_client.NewMockClient(ctrl) + mockElbv2Client := services.NewMockELBV2(ctrl) + logger := zap.New() + + // Call the function under test + discovery := NewEndpointDiscovery(mockClient, logger, mockElbv2Client) + protocolPortsInfo, err := discovery.fetchServiceProtocolPortInfo(ctx, serviceEndpoint) + + // Check error expectations + if tc.expectError { + assert.Error(t, err) + assert.Nil(t, protocolPortsInfo) + } else { + assert.NoError(t, err) + + // Extract ports by protocol + tcpPorts := []int32{} + udpPorts := []int32{} + + for _, info := range protocolPortsInfo { + if info.Protocol == agaapi.GlobalAcceleratorProtocolTCP { + tcpPorts = append(tcpPorts, info.Ports...) + } else if info.Protocol == agaapi.GlobalAcceleratorProtocolUDP { + udpPorts = append(udpPorts, info.Ports...) + } + } + + // Verify TCP ports + assert.Equal(t, len(tc.expectedTCPPorts), len(tcpPorts), "TCP port count should match expected") + for _, expectedPort := range tc.expectedTCPPorts { + assert.Contains(t, tcpPorts, expectedPort, "Expected TCP port %d not found", expectedPort) + } + + // Verify UDP ports + assert.Equal(t, len(tc.expectedUDPPorts), len(udpPorts), "UDP port count should match expected") + for _, expectedPort := range tc.expectedUDPPorts { + assert.Contains(t, udpPorts, expectedPort, "Expected UDP port %d not found", expectedPort) + } + + // Verify protocol group count + expectedGroupCount := 0 + if len(tc.expectedTCPPorts) > 0 { + expectedGroupCount++ + } + if len(tc.expectedUDPPorts) > 0 { + expectedGroupCount++ + } + assert.Len(t, protocolPortsInfo, expectedGroupCount, "Protocol group count should match expected") + } + }) + } +} + +func TestFetchGatewayProtocolPortInfo(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := context.TODO() + gatewayClassName := gwv1.ObjectName("test-class") + + testCases := []struct { + name string + listeners []gwv1.Listener + expectedProtocolGroups int + expectedPorts map[int32]agaapi.GlobalAcceleratorProtocol + expectError bool + }{ + { + name: "Gateway with mixed protocols (HTTP, HTTPS, UDP)", + listeners: []gwv1.Listener{ + { + Name: "http", + Port: 80, + Protocol: gwv1.HTTPProtocolType, + }, + { + Name: "https", + Port: 443, + Protocol: gwv1.HTTPSProtocolType, + }, + { + Name: "udp", + Port: 1433, + Protocol: gwv1.UDPProtocolType, + }, + }, + // One protocol group for TCP and one for UDP + expectedProtocolGroups: 2, + expectedPorts: map[int32]agaapi.GlobalAcceleratorProtocol{ + 80: agaapi.GlobalAcceleratorProtocolTCP, + 443: agaapi.GlobalAcceleratorProtocolTCP, + 1433: agaapi.GlobalAcceleratorProtocolUDP, + }, + expectError: false, + }, + { + name: "Gateway with HTTP protocol only", + listeners: []gwv1.Listener{ + { + Name: "http-80", + Port: 80, + Protocol: gwv1.HTTPProtocolType, + }, + { + Name: "http-8080", + Port: 8080, + Protocol: gwv1.HTTPProtocolType, + }, + }, + // Only one protocol group for TCP + expectedProtocolGroups: 1, + expectedPorts: map[int32]agaapi.GlobalAcceleratorProtocol{ + 80: agaapi.GlobalAcceleratorProtocolTCP, + 8080: agaapi.GlobalAcceleratorProtocolTCP, + }, + expectError: false, + }, + { + name: "Gateway with UDP protocol only", + listeners: []gwv1.Listener{ + { + Name: "udp-53", + Port: 53, + Protocol: gwv1.UDPProtocolType, + }, + { + Name: "udp-123", + Port: 123, + Protocol: gwv1.UDPProtocolType, + }, + }, + // Only one protocol group for UDP + expectedProtocolGroups: 1, + expectedPorts: map[int32]agaapi.GlobalAcceleratorProtocol{ + 53: agaapi.GlobalAcceleratorProtocolUDP, + 123: agaapi.GlobalAcceleratorProtocolUDP, + }, + expectError: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create Gateway resource with test case listeners + gateway := &gwv1.Gateway{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-gateway-" + tc.name, + Namespace: "default", + }, + Spec: gwv1.GatewaySpec{ + GatewayClassName: gatewayClassName, + Listeners: tc.listeners, + }, + } + + // Create loaded endpoint with the gateway resource + gatewayEndpoint := &LoadedEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeGateway, + Name: "test-gateway-" + tc.name, + Namespace: "default", + Status: EndpointStatusLoaded, + K8sResource: gateway, + } + + // Create mocks + mockClient := mock_client.NewMockClient(ctrl) + mockElbv2Client := services.NewMockELBV2(ctrl) + logger := zap.New() + + // Call the function under test + discovery := NewEndpointDiscovery(mockClient, logger, mockElbv2Client) + protocolPortsInfo, err := discovery.fetchGatewayProtocolPortInfo(ctx, gatewayEndpoint) + + // Check error expectations + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + // Verify the number of protocol groups is as expected + assert.Len(t, protocolPortsInfo, tc.expectedProtocolGroups, "Protocol group count should match expected") + + // Extract ports by protocol + tcpPorts := []int32{} + udpPorts := []int32{} + for _, info := range protocolPortsInfo { + if info.Protocol == agaapi.GlobalAcceleratorProtocolTCP { + tcpPorts = append(tcpPorts, info.Ports...) + } else if info.Protocol == agaapi.GlobalAcceleratorProtocolUDP { + udpPorts = append(udpPorts, info.Ports...) + } + } + + // Count expected TCP and UDP ports + expectedTCPPorts := []int32{} + expectedUDPPorts := []int32{} + for port, protocol := range tc.expectedPorts { + if protocol == agaapi.GlobalAcceleratorProtocolTCP { + expectedTCPPorts = append(expectedTCPPorts, port) + } else if protocol == agaapi.GlobalAcceleratorProtocolUDP { + expectedUDPPorts = append(expectedUDPPorts, port) + } + } + + // Verify port counts by protocol + assert.Len(t, tcpPorts, len(expectedTCPPorts), "TCP port count should match expected") + assert.Len(t, udpPorts, len(expectedUDPPorts), "UDP port count should match expected") + + // Verify each expected TCP port is present + for _, expectedPort := range expectedTCPPorts { + assert.Contains(t, tcpPorts, expectedPort, "Expected TCP port %d not found", expectedPort) + } + + // Verify each expected UDP port is present + for _, expectedPort := range expectedUDPPorts { + assert.Contains(t, udpPorts, expectedPort, "Expected UDP port %d not found", expectedPort) + } + } + }) + } +} + +func TestGetProtocolPortFromELBListener(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := context.TODO() + mockClient := mock_client.NewMockClient(ctrl) + logger := zap.New() + mockElbv2Client := services.NewMockELBV2(ctrl) + + discovery := NewEndpointDiscovery(mockClient, logger, mockElbv2Client) + + albARN := "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-lb/1234567890abcdef" + nlbARN := "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/net/test-lb/1234567890abcdef" + + // Test case 1: HTTP and HTTPS listeners + httpHttpsListeners := []types.Listener{ + { + ListenerArn: awssdk.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:listener/app/test-lb/1234567890abcdef/http"), + Protocol: types.ProtocolEnumHttp, + Port: awssdk.Int32(80), + }, + { + ListenerArn: awssdk.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:listener/app/test-lb/1234567890abcdef/https"), + Protocol: types.ProtocolEnumHttps, + Port: awssdk.Int32(443), + }, + } + + // Test case 2: TCP and TLS listeners + tcpTlsListeners := []types.Listener{ + { + ListenerArn: awssdk.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:listener/net/test-lb/1234567890abcdef/tcp"), + Protocol: types.ProtocolEnumTcp, + Port: awssdk.Int32(80), + }, + { + ListenerArn: awssdk.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:listener/net/test-lb/1234567890abcdef/tls"), + Protocol: types.ProtocolEnumTls, + Port: awssdk.Int32(443), + }, + } + + // Test case 3: UDP listener + udpListeners := []types.Listener{ + { + ListenerArn: awssdk.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:listener/net/test-lb/1234567890abcdef/udp"), + Protocol: types.ProtocolEnumUdp, + Port: awssdk.Int32(53), + }, + } + + // Test case 4: Unsupported protocol + unsupportedListeners := []types.Listener{ + { + ListenerArn: awssdk.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:listener/net/test-lb/1234567890abcdef/tcpudp"), + Protocol: types.ProtocolEnumTcpUdp, // Not supported by Global Accelerator + Port: awssdk.Int32(53), + }, + } + + // Test case 5: Empty listener list + emptyListeners := []types.Listener{} + + testCases := []struct { + name string + lbArn string + listeners []types.Listener + expectError bool + expectedTCPPorts []int32 + expectedUDPPorts []int32 + expectedGroupCount int + }{ + { + name: "HTTP and HTTPS listeners", + lbArn: albARN, + listeners: httpHttpsListeners, + expectError: false, + expectedTCPPorts: []int32{80, 443}, + expectedUDPPorts: []int32{}, + expectedGroupCount: 1, // Only TCP group + }, + { + name: "TCP and TLS listeners", + lbArn: nlbARN, + listeners: tcpTlsListeners, + expectError: false, + expectedTCPPorts: []int32{80, 443}, + expectedUDPPorts: []int32{}, + expectedGroupCount: 1, // Only TCP group + }, + { + name: "UDP listener", + lbArn: nlbARN, + listeners: udpListeners, + expectError: false, + expectedTCPPorts: []int32{}, + expectedUDPPorts: []int32{53}, + expectedGroupCount: 1, // Only UDP group + }, + { + name: "Unsupported protocol", + lbArn: nlbARN, + listeners: unsupportedListeners, + expectError: true, + expectedTCPPorts: []int32{}, + expectedUDPPorts: []int32{}, + expectedGroupCount: 0, + }, + { + name: "Empty listeners", + lbArn: albARN, + listeners: emptyListeners, + expectError: false, + expectedTCPPorts: []int32{}, + expectedUDPPorts: []int32{}, + expectedGroupCount: 0, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Set up mock expectations + mockElbv2Client.EXPECT(). + DescribeListenersAsList(gomock.Any(), gomock.Any()). + Return(tc.listeners, nil) + + // Call the function under test + result, err := discovery.getProtocolPortFromELBListener(ctx, tc.lbArn) + + // Verify the result + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expectedGroupCount, len(result), "Number of protocol groups should match expected") + + // Extract TCP and UDP ports from the result + var tcpPorts, udpPorts []int32 + for _, info := range result { + if info.Protocol == agaapi.GlobalAcceleratorProtocolTCP { + tcpPorts = append(tcpPorts, info.Ports...) + } else if info.Protocol == agaapi.GlobalAcceleratorProtocolUDP { + udpPorts = append(udpPorts, info.Ports...) + } + } + + // Verify TCP ports + assert.ElementsMatch(t, tc.expectedTCPPorts, tcpPorts, "TCP ports should match expected") + + // Verify UDP ports + assert.ElementsMatch(t, tc.expectedUDPPorts, udpPorts, "UDP ports should match expected") + } + }) + } +} + +func TestFetchLoadBalancerProtocolPortInfo(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := context.TODO() + mockClient := mock_client.NewMockClient(ctrl) + logger := zap.New() + mockElbv2Client := services.NewMockELBV2(ctrl) + + discovery := NewEndpointDiscovery(mockClient, logger, mockElbv2Client) + + lbARN := "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-lb/1234567890abcdef" + + // Test with mixed TCP and UDP listeners + mixedListeners := []types.Listener{ + { + ListenerArn: awssdk.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:listener/app/test-lb/1234567890abcdef/http"), + Protocol: types.ProtocolEnumHttp, + Port: awssdk.Int32(80), + }, + { + ListenerArn: awssdk.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:listener/app/test-lb/1234567890abcdef/https"), + Protocol: types.ProtocolEnumHttps, + Port: awssdk.Int32(443), + }, + { + ListenerArn: awssdk.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:listener/net/test-lb/1234567890abcdef/udp"), + Protocol: types.ProtocolEnumUdp, + Port: awssdk.Int32(53), + }, + } + + testCases := []struct { + name string + endpoint *LoadedEndpoint + listeners []types.Listener + expectError bool + expectedTCPPorts []int32 + expectedUDPPorts []int32 + }{ + { + name: "Load balancer with mixed TCP and UDP listeners", + endpoint: &LoadedEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeEndpointID, + Name: "test-endpoint", + Namespace: "default", + ARN: lbARN, + }, + listeners: mixedListeners, + expectError: false, + expectedTCPPorts: []int32{80, 443}, + expectedUDPPorts: []int32{53}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Set up mock expectations if needed + if tc.endpoint.ARN != "" { + mockElbv2Client.EXPECT(). + DescribeListenersAsList(gomock.Any(), gomock.Any()). + Return(tc.listeners, nil) + } + + // Call the function under test + result, err := discovery.fetchLoadBalancerProtocolPortInfo(ctx, tc.endpoint) + + // Verify the result + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + + // Extract TCP and UDP ports from the result for easier comparison + var tcpPorts, udpPorts []int32 + for _, info := range result { + if info.Protocol == agaapi.GlobalAcceleratorProtocolTCP { + tcpPorts = append(tcpPorts, info.Ports...) + } else if info.Protocol == agaapi.GlobalAcceleratorProtocolUDP { + udpPorts = append(udpPorts, info.Ports...) + } + } + + // Verify TCP ports + assert.ElementsMatch(t, tc.expectedTCPPorts, tcpPorts) + + // Verify UDP ports + assert.ElementsMatch(t, tc.expectedUDPPorts, udpPorts) + + // Verify protocol groups count + expectedGroups := 0 + if len(tc.expectedTCPPorts) > 0 { + expectedGroups++ + } + if len(tc.expectedUDPPorts) > 0 { + expectedGroups++ + } + assert.Equal(t, expectedGroups, len(result)) + } + }) + } +} diff --git a/pkg/aga/endpoint_loader.go b/pkg/aga/endpoint_loader.go index d94a5ea17..29543445c 100644 --- a/pkg/aga/endpoint_loader.go +++ b/pkg/aga/endpoint_loader.go @@ -63,6 +63,9 @@ type LoadedEndpoint struct { Status LoadedEndpointStatus Error error // The error that occurred during loading, if any Message string // Human-readable message explaining the status + + // K8s resource reference - used for port and protocol discovery + K8sResource client.Object } // IsUsable returns true if this endpoint can be used in the model @@ -225,6 +228,10 @@ func (l *endpointLoaderImpl) loadResourceWithDNS( result.ARN = arn result.Message = fmt.Sprintf("Successfully resolved %s to LoadBalancer ARN", resourceType) + // Store the K8s resource object in the result as a generalized client.Object + // This is used for port and protocol auto-discovery + result.K8sResource = obj.DeepCopyObject().(client.Object) + return nil } diff --git a/pkg/aga/model_build_listener.go b/pkg/aga/model_build_listener.go index 545720f25..48bee201a 100644 --- a/pkg/aga/model_build_listener.go +++ b/pkg/aga/model_build_listener.go @@ -3,41 +3,71 @@ package aga import ( "context" "fmt" + + "github.com/go-logr/logr" "github.com/pkg/errors" + agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" + "sigs.k8s.io/aws-load-balancer-controller/pkg/k8s" agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" + "sigs.k8s.io/controller-runtime/pkg/client" ) // listenerBuilder builds Listener model resources type listenerBuilder interface { - Build(ctx context.Context, stack core.Stack, accelerator *agamodel.Accelerator, listeners []agaapi.GlobalAcceleratorListener) ([]*agamodel.Listener, error) + Build(ctx context.Context, stack core.Stack, accelerator *agamodel.Accelerator, listeners []agaapi.GlobalAcceleratorListener, ga *agaapi.GlobalAccelerator, loadedEndpoints []*LoadedEndpoint) ([]*agamodel.Listener, []agaapi.GlobalAcceleratorListener, error) } // NewListenerBuilder constructs new listenerBuilder -func NewListenerBuilder() listenerBuilder { - return &defaultListenerBuilder{} +func NewListenerBuilder(k8sClient client.Client, logger logr.Logger, elbv2Client services.ELBV2) listenerBuilder { + + endpointDiscovery := NewEndpointDiscovery(k8sClient, logger, elbv2Client) + + return &defaultListenerBuilder{ + endpointDiscovery: endpointDiscovery, + logger: logger, + } } var _ listenerBuilder = &defaultListenerBuilder{} -type defaultListenerBuilder struct{} +type defaultListenerBuilder struct { + endpointDiscovery *EndpointDiscovery + logger logr.Logger +} // Build builds Listener model resources -func (b *defaultListenerBuilder) Build(ctx context.Context, stack core.Stack, accelerator *agamodel.Accelerator, listeners []agaapi.GlobalAcceleratorListener) ([]*agamodel.Listener, error) { +func (b *defaultListenerBuilder) Build(ctx context.Context, stack core.Stack, accelerator *agamodel.Accelerator, listeners []agaapi.GlobalAcceleratorListener, ga *agaapi.GlobalAccelerator, loadedEndpoints []*LoadedEndpoint) ([]*agamodel.Listener, []agaapi.GlobalAcceleratorListener, error) { if listeners == nil || len(listeners) == 0 { - return nil, nil + return nil, nil, nil + } + + var listenersToProcess []agaapi.GlobalAcceleratorListener + + // Default to using original listeners + listenersToProcess = listeners + + // Apply auto-discovery logic if applicable + canApplyAutoDiscovery := canApplyAutoDiscoveryForGA(ga, loadedEndpoints) + if canApplyAutoDiscovery { + var err error + listenersToProcess, err = b.buildAutoDiscoveryListeners(ctx, listeners[0], loadedEndpoints[0], ga) + if err != nil { + return nil, nil, err + } } var result []*agamodel.Listener - for i, listener := range listeners { + for i, listener := range listenersToProcess { listenerModel, err := buildListener(ctx, stack, accelerator, listener, i) if err != nil { - return nil, err + return nil, nil, err } result = append(result, listenerModel) } - return result, nil + return result, listenersToProcess, nil } // buildListener builds a single Listener model resource @@ -77,9 +107,7 @@ func buildListenerSpec(ctx context.Context, accelerator *agamodel.Accelerator, l // buildListenerProtocol determines the protocol for the listener func buildListenerProtocol(_ context.Context, listener agaapi.GlobalAcceleratorListener) (agamodel.Protocol, error) { if listener.Protocol == nil { - // TODO: Auto-discovery feature - Auto-determine protocol from endpoints if nil - // Return error until auto-discovery feature is implemented - return "", errors.New("listener protocol must be specified (auto-discovery not yet implemented)") + return "", errors.New("listener protocol must be specified ") } switch *listener.Protocol { @@ -95,9 +123,7 @@ func buildListenerProtocol(_ context.Context, listener agaapi.GlobalAcceleratorL // buildListenerPortRanges determines the port ranges for the listener func buildListenerPortRanges(_ context.Context, listener agaapi.GlobalAcceleratorListener) ([]agamodel.PortRange, error) { if listener.PortRanges == nil { - // TODO: Auto-discovery feature - Auto-determine port ranges from endpoints if nil - // Return error until auto-discovery feature is implemented - return []agamodel.PortRange{}, errors.New("listener port ranges must be specified (auto-discovery not yet implemented)") + return []agamodel.PortRange{}, errors.New("listener port ranges must be specified") } var portRanges []agamodel.PortRange @@ -111,6 +137,153 @@ func buildListenerPortRanges(_ context.Context, listener agaapi.GlobalAccelerato return portRanges, nil } +// buildAutoDiscoveryListeners creates listeners based on auto-discovered protocols and ports +// This function is responsible for: +// 1. Fetching protocols and ports information from the loaded endpoint +// 2. Determining which protocols to create listeners for, based on protocol specification in the input +// 3. Creating appropriate listeners for each protocol with their corresponding port ranges +// 4. Consolidating ports into ranges to optimize AWS resources +// +// Returns new listeners with auto-discovered protocols and ports +func (b *defaultListenerBuilder) buildAutoDiscoveryListeners( + ctx context.Context, + templateListener agaapi.GlobalAcceleratorListener, + loadedEndpoint *LoadedEndpoint, + ga *agaapi.GlobalAccelerator) ([]agaapi.GlobalAcceleratorListener, error) { + + // Pre-fetch the protocol information + protocolPortsInfo, discoveryErr := b.endpointDiscovery.FetchProtocolPortInfo(ctx, loadedEndpoint) + if discoveryErr != nil { + b.logger.Error(discoveryErr, "failed to fetch endpoint port info for auto-discovery", + "endpoint", loadedEndpoint.Name, + "accelerator", k8s.NamespacedName(ga)) + return nil, errors.Wrap(discoveryErr, "failed to fetch endpoint port info for auto-discovery") + } + + // Check if we have any protocol information to work with + if len(protocolPortsInfo) == 0 { + err := errors.New("no protocol or port information found for auto-discovery") + b.logger.Error(err, "unable to auto-discover listener configuration", + "endpoint", loadedEndpoint.Name, + "accelerator", k8s.NamespacedName(ga)) + return nil, err + } + + // Determine which protocols to create listeners for + var protocolsToCreate []agaapi.GlobalAcceleratorProtocol + if templateListener.Protocol != nil { + // Explicitly specified protocol - create single listener + protocolsToCreate = []agaapi.GlobalAcceleratorProtocol{*templateListener.Protocol} + } else if len(protocolPortsInfo) > 1 { + // Multiple protocols detected - will always be TCP and UDP only + // Create one listener for each protocol + protocolsToCreate = []agaapi.GlobalAcceleratorProtocol{ + agaapi.GlobalAcceleratorProtocolTCP, + agaapi.GlobalAcceleratorProtocolUDP, + } + } else if len(protocolPortsInfo) == 1 { + // Single protocol - create one listener + protocolsToCreate = []agaapi.GlobalAcceleratorProtocol{protocolPortsInfo[0].Protocol} + } + + // Create new listeners for each protocol detected + listenersToProcess := make([]agaapi.GlobalAcceleratorListener, 0, len(protocolsToCreate)) + + // Create listeners for each protocol + for _, protocol := range protocolsToCreate { + // Get matching ports for this protocol + var matchingPorts []int32 + for _, info := range protocolPortsInfo { + if info.Protocol == protocol { + matchingPorts = append(matchingPorts, info.Ports...) + } + } + + // Consolidate ports into ranges + var portRanges []agamodel.PortRange + if len(matchingPorts) > 0 { + portRanges = consolidatePortRanges(matchingPorts) + } + + // Create new listener with protocol and port ranges + newListener := createNewListener(templateListener, protocol, portRanges) + listenersToProcess = append(listenersToProcess, newListener) + + b.logger.V(1).Info( + "Created auto-discovery listener with port ranges", + "protocol", protocol, + "portCount", len(matchingPorts), + "rangeCount", len(portRanges), + ) + } + + return listenersToProcess, nil +} + +// createNewListener creates a new listener with specified protocol and optional port ranges +func createNewListener(template agaapi.GlobalAcceleratorListener, protocol agaapi.GlobalAcceleratorProtocol, portRanges []agamodel.PortRange) agaapi.GlobalAcceleratorListener { + var newListener agaapi.GlobalAcceleratorListener + + // Set protocol to the specified protocol + newListener.Protocol = &protocol + + // Copy non-pointer fields directly + newListener.ClientAffinity = template.ClientAffinity + + // Set port ranges - prioritize template port ranges if they exist + if template.PortRanges != nil { + portRangesCopy := make([]agaapi.PortRange, len(*template.PortRanges)) + copy(portRangesCopy, *template.PortRanges) + newListener.PortRanges = &portRangesCopy + } else if len(portRanges) > 0 { + apiPortRanges := make([]agaapi.PortRange, 0, len(portRanges)) + for _, pr := range portRanges { + apiPortRanges = append(apiPortRanges, agaapi.PortRange{ + FromPort: pr.FromPort, + ToPort: pr.ToPort, + }) + } + newListener.PortRanges = &apiPortRanges + } + + // Copy EndpointGroups + if template.EndpointGroups != nil { + endpointGroupsCopy := make([]agaapi.GlobalAcceleratorEndpointGroup, len(*template.EndpointGroups)) + for i, eg := range *template.EndpointGroups { + endpointGroupsCopy[i] = eg + + // Create new slice for Endpoints if they exist + if eg.Endpoints != nil { + endpointsCopy := make([]agaapi.GlobalAcceleratorEndpoint, len(*eg.Endpoints)) + copy(endpointsCopy, *eg.Endpoints) + endpointGroupsCopy[i].Endpoints = &endpointsCopy + } + + // Deep copy Region if it exists + if eg.Region != nil { + region := *eg.Region + endpointGroupsCopy[i].Region = ®ion + } + + // Deep copy TrafficDialPercentage if it exists + if eg.TrafficDialPercentage != nil { + trafficDialPct := *eg.TrafficDialPercentage + endpointGroupsCopy[i].TrafficDialPercentage = &trafficDialPct + } + + // Copy port overrides if they exist + if eg.PortOverrides != nil { + portOverridesCopy := make([]agaapi.PortOverride, len(*eg.PortOverrides)) + copy(portOverridesCopy, *eg.PortOverrides) + endpointGroupsCopy[i].PortOverrides = &portOverridesCopy + } + } + newListener.EndpointGroups = &endpointGroupsCopy + } + + return newListener +} + // buildListenerClientAffinity determines the client affinity for the listener func buildListenerClientAffinity(_ context.Context, listener agaapi.GlobalAcceleratorListener) agamodel.ClientAffinity { switch listener.ClientAffinity { diff --git a/pkg/aga/model_build_listener_test.go b/pkg/aga/model_build_listener_test.go index cb287dd6c..97e7f7de2 100644 --- a/pkg/aga/model_build_listener_test.go +++ b/pkg/aga/model_build_listener_test.go @@ -3,13 +3,21 @@ package aga import ( "context" awssdk "github.com/aws/aws-sdk-go-v2/aws" - "testing" - + "github.com/go-logr/logr" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" + corev1 "k8s.io/api/core/v1" + networking "k8s.io/api/networking/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" + "sigs.k8s.io/controller-runtime/pkg/log" + "testing" + agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" + mock_client "sigs.k8s.io/aws-load-balancer-controller/mocks/controller-runtime/client" agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" "sigs.k8s.io/aws-load-balancer-controller/pkg/model/core" + "sigs.k8s.io/controller-runtime/pkg/log/zap" ) func TestDefaultListenerBuilder_Build(t *testing.T) { @@ -18,10 +26,15 @@ func TestDefaultListenerBuilder_Build(t *testing.T) { protocolUDP := agaapi.GlobalAcceleratorProtocolUDP tests := []struct { - name string - listeners []agaapi.GlobalAcceleratorListener - wantListeners int - wantErr bool + name string + listeners []agaapi.GlobalAcceleratorListener + endpoints []*LoadedEndpoint + setupMocks func(*gomock.Controller) (*mock_client.MockClient, services.ELBV2) + wantListeners int + wantProtocol agamodel.Protocol + wantPortCount int + wantErr bool + isAutoDiscovery bool }{ { name: "with nil listeners", @@ -96,6 +109,228 @@ func TestDefaultListenerBuilder_Build(t *testing.T) { wantListeners: 2, wantErr: false, }, + { + name: "with auto-discovery for Ingress endpoint with custom ports", + listeners: []agaapi.GlobalAcceleratorListener{ + { + // Both Protocol and PortRanges are nil for auto-discovery + Protocol: nil, + PortRanges: nil, + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeIngress, + Name: awssdk.String("test-ingress"), + Namespace: awssdk.String("default"), + }, + }, + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + }, + endpoints: []*LoadedEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeIngress, + Name: "test-ingress", + Namespace: "default", + Status: EndpointStatusLoaded, + ARN: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-ingress/1234567890123456", + K8sResource: &networking.Ingress{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ingress", + Namespace: "default", + }, + Status: networking.IngressStatus{ + LoadBalancer: networking.IngressLoadBalancerStatus{ + Ingress: []networking.IngressLoadBalancerIngress{ + { + Hostname: "test-alb.us-west-2.elb.amazonaws.com", + Ports: []networking.IngressPortStatus{ + {Port: 8080}, + {Port: 8443}, + }, + }, + }, + }, + }, + }, + }, + }, + setupMocks: func(ctrl *gomock.Controller) (*mock_client.MockClient, services.ELBV2) { + mockClient := mock_client.NewMockClient(ctrl) + mockElbv2Client := services.NewMockELBV2(ctrl) + // Configure mocks to handle IngressClassParams lookup + mockClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + return mockClient, mockElbv2Client + }, + wantListeners: 1, + wantProtocol: agamodel.ProtocolTCP, + wantPortCount: 2, + wantErr: false, + isAutoDiscovery: true, + }, + { + name: "with auto-discovery for Service endpoint with TCP protocol", + listeners: []agaapi.GlobalAcceleratorListener{ + { + // Both Protocol and PortRanges are nil for auto-discovery + Protocol: nil, + PortRanges: nil, + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: awssdk.String("test-service"), + Namespace: awssdk.String("default"), + }, + }, + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + }, + endpoints: []*LoadedEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: "test-service", + Namespace: "default", + Status: EndpointStatusLoaded, + ARN: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/net/test-service/1234567890123456", + K8sResource: &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-service", + Namespace: "default", + }, + Spec: corev1.ServiceSpec{ + Type: corev1.ServiceTypeLoadBalancer, + Ports: []corev1.ServicePort{ + { + Name: "http", + Protocol: corev1.ProtocolTCP, + Port: 80, + }, + { + Name: "https", + Protocol: corev1.ProtocolTCP, + Port: 443, + }, + }, + }, + Status: corev1.ServiceStatus{ + LoadBalancer: corev1.LoadBalancerStatus{ + Ingress: []corev1.LoadBalancerIngress{ + { + Hostname: "test-nlb.us-west-2.elb.amazonaws.com", + Ports: []corev1.PortStatus{ + { + Port: 80, + Protocol: corev1.ProtocolTCP, + }, + { + Port: 443, + Protocol: corev1.ProtocolTCP, + }, + }, + }, + }, + }, + }, + }, + }, + }, + setupMocks: func(ctrl *gomock.Controller) (*mock_client.MockClient, services.ELBV2) { + mockClient := mock_client.NewMockClient(ctrl) + mockElbv2Client := services.NewMockELBV2(ctrl) + return mockClient, mockElbv2Client + }, + wantListeners: 1, + wantProtocol: agamodel.ProtocolTCP, + wantPortCount: 2, + wantErr: false, + isAutoDiscovery: true, + }, + { + name: "with auto-discovery for Service endpoint with mixed TCP/UDP protocols", + listeners: []agaapi.GlobalAcceleratorListener{ + { + // Both Protocol and PortRanges are nil for auto-discovery + Protocol: nil, + PortRanges: nil, + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: awssdk.String("mixed-service"), + Namespace: awssdk.String("default"), + }, + }, + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + }, + endpoints: []*LoadedEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: "mixed-service", + Namespace: "default", + Status: EndpointStatusLoaded, + ARN: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/net/mixed-service/1234567890123456", + K8sResource: &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "mixed-service", + Namespace: "default", + }, + Spec: corev1.ServiceSpec{ + Type: corev1.ServiceTypeLoadBalancer, + Ports: []corev1.ServicePort{ + { + Name: "http", + Protocol: corev1.ProtocolTCP, + Port: 80, + }, + { + Name: "dns", + Protocol: corev1.ProtocolUDP, + Port: 53, + }, + }, + }, + Status: corev1.ServiceStatus{ + LoadBalancer: corev1.LoadBalancerStatus{ + Ingress: []corev1.LoadBalancerIngress{ + { + Hostname: "mixed-service-nlb.us-west-2.elb.amazonaws.com", + Ports: []corev1.PortStatus{ + { + Port: 80, + Protocol: corev1.ProtocolTCP, + }, + { + Port: 53, + Protocol: corev1.ProtocolUDP, + }, + }, + }, + }, + }, + }, + }, + }, + }, + setupMocks: func(ctrl *gomock.Controller) (*mock_client.MockClient, services.ELBV2) { + mockClient := mock_client.NewMockClient(ctrl) + mockElbv2Client := services.NewMockELBV2(ctrl) + return mockClient, mockElbv2Client + }, + wantListeners: 2, // Should create 2 listeners (one for TCP, one for UDP) + wantErr: false, + isAutoDiscovery: true, + }, } for _, tt := range tests { @@ -105,9 +340,40 @@ func TestDefaultListenerBuilder_Build(t *testing.T) { stack := core.NewDefaultStack(core.StackID{Namespace: "test-ns", Name: "test-name"}) accelerator := createTestAccelerator(stack) + // Create mock GA resource and loaded endpoints for the test + ga := &agaapi.GlobalAccelerator{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ga", + Namespace: "default", + }, + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &tt.listeners, + }, + } + + // Use provided endpoints or an empty array + loadedEndpoints := tt.endpoints + if loadedEndpoints == nil { + loadedEndpoints = []*LoadedEndpoint{} + } + // Create listener builder and build listeners - builder := NewListenerBuilder() - listeners, err := builder.Build(ctx, stack, accelerator, tt.listeners) + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + var mockClient *mock_client.MockClient + var mockElbv2Client services.ELBV2 + + if tt.setupMocks != nil { + mockClient, mockElbv2Client = tt.setupMocks(ctrl) + } else { + mockClient = mock_client.NewMockClient(ctrl) + mockElbv2Client = services.NewMockELBV2(ctrl) + } + + logger := logr.New(&log.NullLogSink{}) + builder := NewListenerBuilder(mockClient, logger, mockElbv2Client) + listeners, _, err := builder.Build(ctx, stack, accelerator, tt.listeners, ga, loadedEndpoints) // Check results if tt.wantErr { @@ -117,7 +383,26 @@ func TestDefaultListenerBuilder_Build(t *testing.T) { if tt.wantListeners == 0 { assert.Nil(t, listeners) } else { + // Verify number of listeners assert.Equal(t, tt.wantListeners, len(listeners)) + + if tt.isAutoDiscovery { + + if tt.wantPortCount > 0 { + // For simple cases, verify port count on the first listener + if len(listeners) > 0 { + // Verify port ranges were auto-discovered correctly + assert.Equal(t, tt.wantPortCount, len(listeners[0].Spec.PortRanges), + "Incorrect number of port ranges") + } + } + + if tt.wantProtocol != "" { + // Verify protocol was set correctly + assert.Equal(t, tt.wantProtocol, listeners[0].Spec.Protocol, + "Protocol was not auto-discovered correctly") + } + } } } }) @@ -475,3 +760,335 @@ func TestBuildListenerClientAffinity(t *testing.T) { }) } } + +// TestAutomaticEndpointDiscovery tests the auto-discovery feature for listeners +func TestAutomaticEndpointDiscovery(t *testing.T) { + // Setup test context + ctx := context.Background() + stack := core.NewDefaultStack(core.StackID{Namespace: "test-ns", Name: "test-name"}) + accelerator := createTestAccelerator(stack) + + // Create a mock global accelerator resource + ga := &agaapi.GlobalAccelerator{ + ObjectMeta: metav1.ObjectMeta{ + Name: "auto-discovery-ga", + Namespace: "default", + }, + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + // Both Protocol and PortRanges are nil + // They should be auto-discovered from the endpoint + Protocol: nil, + PortRanges: nil, + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeIngress, + Name: awssdk.String("test-ingress"), + Namespace: awssdk.String("default"), + }, + }, + }, + }, + ClientAffinity: agaapi.ClientAffinityNone, + }, + }, + }, + } + + // Create a LoadedEndpoint with mock K8s resource for auto-discovery test + endpoint := &LoadedEndpoint{ + Type: agaapi.GlobalAcceleratorEndpointTypeIngress, + Name: "test-ingress", + Namespace: "default", + Status: EndpointStatusLoaded, + ARN: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-ingress/1234567890123456", + K8sResource: &networking.Ingress{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ingress", + Namespace: "default", + }, + Status: networking.IngressStatus{ + LoadBalancer: networking.IngressLoadBalancerStatus{ + Ingress: []networking.IngressLoadBalancerIngress{ + { + Hostname: "test-alb.us-west-2.elb.amazonaws.com", + Ports: []networking.IngressPortStatus{ + {Port: 8080}, + {Port: 8443}, + }, + }, + }, + }, + }, + }, + } + + loadedEndpoints := []*LoadedEndpoint{endpoint} + + // Verify auto-discovery is applicable + canApply := canApplyAutoDiscoveryForGA(ga, loadedEndpoints) + assert.True(t, canApply) + + // Create listener builder and build the listener + ctrl := gomock.NewController(t) + mockClient := mock_client.NewMockClient(ctrl) + mockElbv2Client := services.NewMockELBV2(ctrl) + logger := zap.New() + + // Configure mocks to handle IngressClassParams lookup + mockClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + + builder := NewListenerBuilder(mockClient, logger, mockElbv2Client) + listeners, _, err := builder.Build(ctx, stack, accelerator, *ga.Spec.Listeners, ga, loadedEndpoints) + + // Verify build was successful + assert.NoError(t, err) + assert.Equal(t, 1, len(listeners)) + + // Verify auto-discovered protocol and port ranges + assert.Equal(t, agamodel.ProtocolTCP, listeners[0].Spec.Protocol) + assert.Equal(t, 2, len(listeners[0].Spec.PortRanges)) + + // Sort port ranges for consistent test results + portRanges := listeners[0].Spec.PortRanges + portMap := make(map[int32]bool) + for _, pr := range portRanges { + portMap[pr.FromPort] = true + } + + // Verify the expected ports were discovered from the ingress annotations + assert.True(t, portMap[8080]) + assert.True(t, portMap[8443]) +} + +func TestCreateNewListener(t *testing.T) { + // Protocol references + protocolTCP := agaapi.GlobalAcceleratorProtocolTCP + protocolUDP := agaapi.GlobalAcceleratorProtocolUDP + + // Port ranges + tcpPortRanges := []agamodel.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + { + FromPort: 443, + ToPort: 443, + }, + } + + emptyPortRanges := []agamodel.PortRange{} + + // Create template listener with full data + templateListener := agaapi.GlobalAcceleratorListener{ + Protocol: &protocolTCP, + PortRanges: &[]agaapi.PortRange{ + { + FromPort: 8080, + ToPort: 8080, + }, + }, + ClientAffinity: agaapi.ClientAffinitySourceIP, + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Region: awssdk.String("us-west-2"), + TrafficDialPercentage: awssdk.Int32(100), + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + Name: awssdk.String("test-service"), + Namespace: awssdk.String("default"), + Weight: awssdk.Int32(100), + }, + }, + PortOverrides: &[]agaapi.PortOverride{ + { + ListenerPort: 8080, + EndpointPort: 80, + }, + }, + }, + }, + } + + // Test cases for createNewListener function + testCases := []struct { + name string + template agaapi.GlobalAcceleratorListener + protocol agaapi.GlobalAcceleratorProtocol + portRanges []agamodel.PortRange + expectProtocol agaapi.GlobalAcceleratorProtocol + expectPorts []agaapi.PortRange + expectAffinity agaapi.ClientAffinityType + }{ + { + name: "Create TCP listener with discovered ports", + template: agaapi.GlobalAcceleratorListener{ + ClientAffinity: agaapi.ClientAffinitySourceIP, + // No protocol, port ranges, or endpoint groups + }, + protocol: protocolTCP, + portRanges: tcpPortRanges, + expectProtocol: protocolTCP, + expectPorts: []agaapi.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + { + FromPort: 443, + ToPort: 443, + }, + }, + expectAffinity: agaapi.ClientAffinitySourceIP, + }, + { + name: "Create UDP listener with discovered ports", + template: agaapi.GlobalAcceleratorListener{ + ClientAffinity: agaapi.ClientAffinitySourceIP, + // No protocol, port ranges, or endpoint groups + }, + protocol: protocolUDP, + portRanges: tcpPortRanges, // Reuse same ports for UDP + expectProtocol: protocolUDP, + expectPorts: []agaapi.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + { + FromPort: 443, + ToPort: 443, + }, + }, + expectAffinity: agaapi.ClientAffinitySourceIP, + }, + { + name: "Create listener with empty port ranges", + template: templateListener, + protocol: protocolTCP, + portRanges: emptyPortRanges, + expectProtocol: protocolTCP, + expectPorts: []agaapi.PortRange{ + { + FromPort: 8080, + ToPort: 8080, + }, + }, // Uses template ports + expectAffinity: agaapi.ClientAffinitySourceIP, + }, + { + name: "Create listener from minimal template", + template: agaapi.GlobalAcceleratorListener{ + ClientAffinity: agaapi.ClientAffinityNone, + // No protocol, port ranges, or endpoint groups + }, + protocol: protocolTCP, + portRanges: tcpPortRanges, + expectProtocol: protocolTCP, + expectPorts: []agaapi.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + { + FromPort: 443, + ToPort: 443, + }, + }, + expectAffinity: agaapi.ClientAffinityNone, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Call the function under test + result := createNewListener(tc.template, tc.protocol, tc.portRanges) + + // Verify protocol was set correctly + assert.NotNil(t, result.Protocol) + assert.Equal(t, tc.expectProtocol, *result.Protocol) + + // Verify client affinity was copied + assert.Equal(t, tc.expectAffinity, result.ClientAffinity) + + // Verify port ranges + if tc.expectPorts != nil { + assert.NotNil(t, result.PortRanges) + assert.Equal(t, len(tc.expectPorts), len(*result.PortRanges)) + + // Check each port range + for i, expectedPort := range tc.expectPorts { + found := false + for _, actualPort := range *result.PortRanges { + if actualPort.FromPort == expectedPort.FromPort && actualPort.ToPort == expectedPort.ToPort { + found = true + break + } + } + assert.True(t, found, "Expected port range %d-%d not found", tc.expectPorts[i].FromPort, tc.expectPorts[i].ToPort) + } + } + + // Verify endpoint groups were copied + if tc.template.EndpointGroups != nil { + assert.NotNil(t, result.EndpointGroups) + assert.Equal(t, len(*tc.template.EndpointGroups), len(*result.EndpointGroups)) + + // For first endpoint group, check if fields were properly copied + if len(*tc.template.EndpointGroups) > 0 && len(*result.EndpointGroups) > 0 { + templateEG := (*tc.template.EndpointGroups)[0] + resultEG := (*result.EndpointGroups)[0] + + // Check region + if templateEG.Region != nil { + assert.NotNil(t, resultEG.Region) + assert.Equal(t, *templateEG.Region, *resultEG.Region) + } + + // Check traffic dial percentage + if templateEG.TrafficDialPercentage != nil { + assert.NotNil(t, resultEG.TrafficDialPercentage) + assert.Equal(t, *templateEG.TrafficDialPercentage, *resultEG.TrafficDialPercentage) + } + + // Check endpoints + if templateEG.Endpoints != nil && len(*templateEG.Endpoints) > 0 { + assert.NotNil(t, resultEG.Endpoints) + assert.Equal(t, len(*templateEG.Endpoints), len(*resultEG.Endpoints)) + + // Check first endpoint + if len(*templateEG.Endpoints) > 0 && len(*resultEG.Endpoints) > 0 { + templateEndpoint := (*templateEG.Endpoints)[0] + resultEndpoint := (*resultEG.Endpoints)[0] + + assert.Equal(t, templateEndpoint.Type, resultEndpoint.Type) + + if templateEndpoint.Name != nil { + assert.NotNil(t, resultEndpoint.Name) + assert.Equal(t, *templateEndpoint.Name, *resultEndpoint.Name) + } + + if templateEndpoint.Namespace != nil { + assert.NotNil(t, resultEndpoint.Namespace) + assert.Equal(t, *templateEndpoint.Namespace, *resultEndpoint.Namespace) + } + + assert.Equal(t, templateEndpoint.Weight, resultEndpoint.Weight) + } + } + + // Check port overrides + if templateEG.PortOverrides != nil { + assert.NotNil(t, resultEG.PortOverrides) + assert.Equal(t, len(*templateEG.PortOverrides), len(*resultEG.PortOverrides)) + } + } + } + }) + } +} diff --git a/pkg/aga/model_builder.go b/pkg/aga/model_builder.go index e9b492929..40699b348 100644 --- a/pkg/aga/model_builder.go +++ b/pkg/aga/model_builder.go @@ -2,9 +2,12 @@ package aga import ( "context" + "github.com/go-logr/logr" "k8s.io/client-go/tools/record" + agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" "sigs.k8s.io/aws-load-balancer-controller/pkg/config" "sigs.k8s.io/aws-load-balancer-controller/pkg/deploy/tracking" "sigs.k8s.io/aws-load-balancer-controller/pkg/k8s" @@ -23,7 +26,8 @@ type ModelBuilder interface { // NewDefaultModelBuilder constructs new defaultModelBuilder. func NewDefaultModelBuilder(k8sClient client.Client, eventRecorder record.EventRecorder, trackingProvider tracking.Provider, featureGates config.FeatureGates, - clusterName string, clusterRegion string, defaultTags map[string]string, externalManagedTags []string, logger logr.Logger, metricsCollector lbcmetrics.MetricCollector) *defaultModelBuilder { + clusterName string, clusterRegion string, defaultTags map[string]string, externalManagedTags []string, + logger logr.Logger, metricsCollector lbcmetrics.MetricCollector, elbv2Client services.ELBV2) *defaultModelBuilder { return &defaultModelBuilder{ k8sClient: k8sClient, @@ -36,6 +40,7 @@ func NewDefaultModelBuilder(k8sClient client.Client, eventRecorder record.EventR externalManagedTags: externalManagedTags, logger: logger, metricsCollector: metricsCollector, + elbv2Client: elbv2Client, } } @@ -53,6 +58,7 @@ type defaultModelBuilder struct { externalManagedTags []string logger logr.Logger metricsCollector lbcmetrics.MetricCollector + elbv2Client services.ELBV2 } // Build model stack for a GlobalAccelerator. @@ -61,7 +67,7 @@ func (b *defaultModelBuilder) Build(ctx context.Context, ga *agaapi.GlobalAccele // Create fresh builder instances for each reconciliation acceleratorBuilder := NewAcceleratorBuilder(b.trackingProvider, b.clusterName, b.clusterRegion, b.defaultTags, b.externalManagedTags, b.featureGates.Enabled(config.EnableDefaultTagsLowPriority)) - listenerBuilder := NewListenerBuilder() + listenerBuilder := NewListenerBuilder(b.k8sClient, b.logger, b.elbv2Client) endpointGroupBuilder := NewEndpointGroupBuilder(b.clusterRegion, ga.Namespace, b.logger) // Build Accelerator @@ -72,14 +78,15 @@ func (b *defaultModelBuilder) Build(ctx context.Context, ga *agaapi.GlobalAccele // Build Listeners if specified var listeners []*agamodel.Listener + var processedListeners []agaapi.GlobalAcceleratorListener if ga.Spec.Listeners != nil { - listeners, err = listenerBuilder.Build(ctx, stack, accelerator, *ga.Spec.Listeners) + listeners, processedListeners, err = listenerBuilder.Build(ctx, stack, accelerator, *ga.Spec.Listeners, ga, loadedEndpoints) if err != nil { return nil, nil, err } - // Build endpoint groups with loaded endpoints - _, err := endpointGroupBuilder.Build(ctx, stack, listeners, *ga.Spec.Listeners, loadedEndpoints) + // Build endpoint groups with loaded endpoints - using processedListeners to capture auto-discovery changes + _, err := endpointGroupBuilder.Build(ctx, stack, listeners, processedListeners, loadedEndpoints) if err != nil { return nil, nil, err } diff --git a/pkg/aga/utils.go b/pkg/aga/utils.go index a18323a14..bef20fb93 100644 --- a/pkg/aga/utils.go +++ b/pkg/aga/utils.go @@ -1,8 +1,10 @@ package aga import ( + "sort" "strings" + agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" "sigs.k8s.io/aws-load-balancer-controller/pkg/config" agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" ) @@ -45,3 +47,96 @@ func IsGlobalAcceleratorControllerEnabled(featureGates config.FeatureGates, regi return true } + +// consolidatePortRanges combines consecutive ports into ranges +func consolidatePortRanges(ports []int32) []agamodel.PortRange { + if len(ports) == 0 { + return nil + } + + // Sort ports for efficient range detection using standard library + sort.Slice(ports, func(i, j int) bool { + return ports[i] < ports[j] + }) + + // Consolidate ranges + var result []agamodel.PortRange + + rangeStart := ports[0] + rangeEnd := ports[0] + + for i := 1; i < len(ports); i++ { + // If current port is consecutive to previous, extend the range + if ports[i] == rangeEnd+1 { + rangeEnd = ports[i] + } else if ports[i] > rangeEnd+1 { // Skip duplicates + // Save the current range and start a new one + result = append(result, agamodel.PortRange{ + FromPort: rangeStart, + ToPort: rangeEnd, + }) + rangeStart = ports[i] + rangeEnd = ports[i] + } + } + + // Add the final range + result = append(result, agamodel.PortRange{ + FromPort: rangeStart, + ToPort: rangeEnd, + }) + + return result +} + +// canApplyAutoDiscoveryForGA checks if auto-discovery can be applied for the GlobalAccelerator +// Auto-discovery is only applicable if: +// 1. There's exactly one listener +// 2. The listener has exactly one endpoint group +// 3. The endpoint group has exactly one endpoint +// 4. The protocol or port ranges are not specified (needing discovery) +// 5. The loaded endpoint is usable (successful loading with valid ARN) +func canApplyAutoDiscoveryForGA(ga *agaapi.GlobalAccelerator, loadedEndpoints []*LoadedEndpoint) bool { + // Must have exactly one listener + if ga.Spec.Listeners == nil || len(*ga.Spec.Listeners) != 1 { + return false + } + + listener := (*ga.Spec.Listeners)[0] + + // Must have exactly one endpoint group + if listener.EndpointGroups == nil || len(*listener.EndpointGroups) != 1 { + return false + } + + endpointGroup := (*listener.EndpointGroups)[0] + + // Must have exactly one endpoint + if endpointGroup.Endpoints == nil || len(*endpointGroup.Endpoints) != 1 { + return false + } + + // Auto-discovery is allowed only when protocol and/or port ranges are not specified + needsProtocolDiscovery := listener.Protocol == nil + needsPortRangeDiscovery := listener.PortRanges == nil + + // Must need at least one type of discovery + if !needsProtocolDiscovery && !needsPortRangeDiscovery { + return false + } + + // For auto-discovery, we require exactly one usable endpoint with a valid ARN + if len(loadedEndpoints) != 1 || !loadedEndpoints[0].IsUsable() { + return false + } + + // Check if the endpoint is usable based on its type + loadedEndpoint := loadedEndpoints[0] + if loadedEndpoint.Type == agaapi.GlobalAcceleratorEndpointTypeEndpointID { + // For EndpointID type, we just need a valid ARN + return loadedEndpoint.ARN != "" + } else { + // For other types (Service, Ingress, Gateway), we need a K8s resource + return loadedEndpoint.K8sResource != nil + } +} diff --git a/pkg/aga/utils_test.go b/pkg/aga/utils_test.go index e444808c6..d2761b50b 100644 --- a/pkg/aga/utils_test.go +++ b/pkg/aga/utils_test.go @@ -1,9 +1,12 @@ package aga import ( + awssdk "github.com/aws/aws-sdk-go-v2/aws" + "github.com/stretchr/testify/assert" + corev1 "k8s.io/api/core/v1" "testing" - "github.com/stretchr/testify/assert" + agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1" "sigs.k8s.io/aws-load-balancer-controller/pkg/config" agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga" ) @@ -96,6 +99,468 @@ func TestIsGlobalAcceleratorControllerEnabled(t *testing.T) { } } +func TestCanApplyAutoDiscoveryForGA(t *testing.T) { + protocol := agaapi.GlobalAcceleratorProtocolTCP + portRanges := []agaapi.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + } + + tests := []struct { + name string + ga *agaapi.GlobalAccelerator + loadedEndpoints []*LoadedEndpoint + want bool + }{ + { + name: "No listeners", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: nil, + }, + }, + loadedEndpoints: []*LoadedEndpoint{}, + want: false, + }, + { + name: "Empty listeners array", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{}, + }, + }, + loadedEndpoints: []*LoadedEndpoint{}, + want: false, + }, + { + name: "Multiple listeners", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + {}, {}, + }, + }, + }, + loadedEndpoints: []*LoadedEndpoint{}, + want: false, + }, + { + name: "No endpoint groups", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + EndpointGroups: nil, + }, + }, + }, + }, + loadedEndpoints: []*LoadedEndpoint{}, + want: false, + }, + { + name: "Empty endpoint groups array", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{}, + }, + }, + }, + }, + loadedEndpoints: []*LoadedEndpoint{}, + want: false, + }, + { + name: "Multiple endpoint groups", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + {}, {}, + }, + }, + }, + }, + }, + loadedEndpoints: []*LoadedEndpoint{}, + want: false, + }, + { + name: "No endpoints in endpoint group", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Endpoints: nil, + }, + }, + }, + }, + }, + }, + loadedEndpoints: []*LoadedEndpoint{}, + want: false, + }, + { + name: "Empty endpoints array", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{}, + }, + }, + }, + }, + }, + }, + loadedEndpoints: []*LoadedEndpoint{}, + want: false, + }, + { + name: "Multiple endpoints", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + {}, {}, + }, + }, + }, + }, + }, + }, + }, + loadedEndpoints: []*LoadedEndpoint{ + { + Status: EndpointStatusLoaded, + }, + { + Status: EndpointStatusLoaded, + }, + }, + want: false, + }, + { + name: "Both protocol and port ranges specified", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocol, + PortRanges: &portRanges, + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + {}, + }, + }, + }, + }, + }, + }, + }, + loadedEndpoints: []*LoadedEndpoint{ + { + Status: EndpointStatusLoaded, + }, + }, + want: false, + }, + { + name: "Failed endpoint loading", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocol, + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + {}, + }, + }, + }, + }, + }, + }, + }, + loadedEndpoints: []*LoadedEndpoint{ + { + Status: EndpointStatusWarning, + Error: assert.AnError, + }, + }, + want: false, + }, + { + name: "No loaded endpoints", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocol, + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + {}, + }, + }, + }, + }, + }, + }, + }, + loadedEndpoints: []*LoadedEndpoint{}, + want: false, + }, + { + name: "Multiple loaded endpoints", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocol, + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + {}, {}, + }, + }, + }, + }, + }, + }, + }, + loadedEndpoints: []*LoadedEndpoint{ + { + Status: EndpointStatusLoaded, + }, + { + Status: EndpointStatusLoaded, + }, + }, + want: false, + }, + { + name: "Valid for auto-discovery - protocol specified, port ranges not specified", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + Protocol: &protocol, + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + }, + }, + }, + }, + }, + }, + }, + }, + loadedEndpoints: []*LoadedEndpoint{ + { + Status: EndpointStatusLoaded, + Type: agaapi.GlobalAcceleratorEndpointTypeService, + K8sResource: &corev1.Service{}, // Add K8sResource to make the test pass + }, + }, + want: true, + }, + { + name: "Valid for auto-discovery - port ranges specified, protocol not specified", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + PortRanges: &portRanges, + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeService, + }, + }, + }, + }, + }, + }, + }, + }, + loadedEndpoints: []*LoadedEndpoint{ + { + Status: EndpointStatusLoaded, + Type: agaapi.GlobalAcceleratorEndpointTypeService, + K8sResource: &corev1.Service{}, // Add K8sResource to make the test pass + }, + }, + want: true, + }, + { + name: "Valid for auto-discovery - both protocol and port ranges not specified", + ga: &agaapi.GlobalAccelerator{ + Spec: agaapi.GlobalAcceleratorSpec{ + Listeners: &[]agaapi.GlobalAcceleratorListener{ + { + EndpointGroups: &[]agaapi.GlobalAcceleratorEndpointGroup{ + { + Endpoints: &[]agaapi.GlobalAcceleratorEndpoint{ + { + Type: agaapi.GlobalAcceleratorEndpointTypeEndpointID, + EndpointID: awssdk.String("some-arn"), + Weight: awssdk.Int32(112), + }, + }, + }, + }, + }, + }, + }, + }, + loadedEndpoints: []*LoadedEndpoint{ + { + Status: EndpointStatusLoaded, + Type: agaapi.GlobalAcceleratorEndpointTypeEndpointID, + ARN: "arn:aws:elasticloadbalancing:us-west-2:123456789012:loadbalancer/app/test-endpoint/1234567890123456", // Add ARN to make the test pass + }, + }, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := canApplyAutoDiscoveryForGA(tt.ga, tt.loadedEndpoints) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestConsolidatePortRanges(t *testing.T) { + tests := []struct { + name string + ports []int32 + expected []agamodel.PortRange + }{ + { + name: "empty ports", + ports: []int32{}, + expected: nil, + }, + { + name: "single port", + ports: []int32{80}, + expected: []agamodel.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + }, + }, + { + name: "consecutive ports", + ports: []int32{80, 81, 82}, + expected: []agamodel.PortRange{ + { + FromPort: 80, + ToPort: 82, + }, + }, + }, + { + name: "non-consecutive ports", + ports: []int32{80, 443, 8080}, + expected: []agamodel.PortRange{ + { + FromPort: 80, + ToPort: 80, + }, + { + FromPort: 443, + ToPort: 443, + }, + { + FromPort: 8080, + ToPort: 8080, + }, + }, + }, + { + name: "mixed consecutive and non-consecutive ports", + ports: []int32{80, 81, 443, 8080, 8081, 8082}, + expected: []agamodel.PortRange{ + { + FromPort: 80, + ToPort: 81, + }, + { + FromPort: 443, + ToPort: 443, + }, + { + FromPort: 8080, + ToPort: 8082, + }, + }, + }, + { + name: "unsorted ports", + ports: []int32{443, 80, 8080, 81}, + expected: []agamodel.PortRange{ + { + FromPort: 80, + ToPort: 81, + }, + { + FromPort: 443, + ToPort: 443, + }, + { + FromPort: 8080, + ToPort: 8080, + }, + }, + }, + { + name: "duplicate ports", + ports: []int32{80, 80, 81, 443, 443}, + expected: []agamodel.PortRange{ + { + FromPort: 80, + ToPort: 81, + }, + { + FromPort: 443, + ToPort: 443, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := consolidatePortRanges(tt.ports) + assert.Equal(t, tt.expected, result) + }) + } +} + func TestIsPortInRanges(t *testing.T) { tests := []struct { name string diff --git a/pkg/deploy/aga/endpoint_group_manager.go b/pkg/deploy/aga/endpoint_group_manager.go index c6d2529ab..9356384d0 100644 --- a/pkg/deploy/aga/endpoint_group_manager.go +++ b/pkg/deploy/aga/endpoint_group_manager.go @@ -47,7 +47,7 @@ type defaultEndpointGroupManager struct { // buildSDKPortOverrides converts model port overrides to SDK port overrides func (m *defaultEndpointGroupManager) buildSDKPortOverrides(modelPortOverrides []agamodel.PortOverride) []agatypes.PortOverride { if len(modelPortOverrides) == 0 { - return nil + return []agatypes.PortOverride{} } portOverrides := make([]agatypes.PortOverride, 0, len(modelPortOverrides)) @@ -133,12 +133,12 @@ func (m *defaultEndpointGroupManager) buildSDKUpdateEndpointGroupInput(_ context // Convert TrafficDialPercentage from int32 to float32 if provided if resEndpointGroup.Spec.TrafficDialPercentage != nil { updateInput.TrafficDialPercentage = awssdk.Float32(float32(*resEndpointGroup.Spec.TrafficDialPercentage)) + } else { + updateInput.TrafficDialPercentage = nil } // Add port overrides if specified - if len(resEndpointGroup.Spec.PortOverrides) > 0 { - updateInput.PortOverrides = m.buildSDKPortOverrides(resEndpointGroup.Spec.PortOverrides) - } + updateInput.PortOverrides = m.buildSDKPortOverrides(resEndpointGroup.Spec.PortOverrides) return updateInput, nil } @@ -399,7 +399,7 @@ func (m *defaultEndpointGroupManager) ManageEndpoints( configsToAdd, configsToUpdate, endpointsToRemove, isUpdateRequired := m.detectEndpointDrift(sdkEndpoints, resEndpointConfigs) if len(configsToAdd) == 0 && len(endpointsToRemove) == 0 && !isUpdateRequired { - m.logger.V(1).Info("No drift found for endpoint group", "endpointGroupARN", endpointGroupARN) + m.logger.V(1).Info("No drift found for endpoints", "endpointGroupARN", endpointGroupARN) return nil } diff --git a/pkg/ingress/model_builder.go b/pkg/ingress/model_builder.go index f9ba29e31..2e693ad92 100644 --- a/pkg/ingress/model_builder.go +++ b/pkg/ingress/model_builder.go @@ -7,6 +7,7 @@ import ( "k8s.io/apimachinery/pkg/util/cache" "reflect" "sigs.k8s.io/aws-load-balancer-controller/pkg/shared_utils" + "sort" "strconv" "sync" "time" @@ -46,7 +47,7 @@ const ( // ModelBuilder is responsible for build mode stack for a IngressGroup. type ModelBuilder interface { // build mode stack for a IngressGroup. - Build(ctx context.Context, ingGroup Group, metricsCollector lbcmetrics.MetricCollector) (core.Stack, *elbv2model.LoadBalancer, []types.NamespacedName, bool, *elbv2model.LoadBalancer, error) + Build(ctx context.Context, ingGroup Group, metricsCollector lbcmetrics.MetricCollector) (core.Stack, *elbv2model.LoadBalancer, []types.NamespacedName, bool, *elbv2model.LoadBalancer, []int32, error) } // NewDefaultModelBuilder constructs new defaultModelBuilder. @@ -135,7 +136,7 @@ type defaultModelBuilder struct { } // build mode stack for a IngressGroup. -func (b *defaultModelBuilder) Build(ctx context.Context, ingGroup Group, metricsCollector lbcmetrics.MetricCollector) (core.Stack, *elbv2model.LoadBalancer, []types.NamespacedName, bool, *elbv2model.LoadBalancer, error) { +func (b *defaultModelBuilder) Build(ctx context.Context, ingGroup Group, metricsCollector lbcmetrics.MetricCollector) (core.Stack, *elbv2model.LoadBalancer, []types.NamespacedName, bool, *elbv2model.LoadBalancer, []int32, error) { stack := core.NewDefaultStack(core.StackID(ingGroup.ID)) task := &defaultModelBuildTask{ @@ -193,11 +194,22 @@ func (b *defaultModelBuilder) Build(ctx context.Context, ingGroup Group, metrics localFrontendNlbData: make(map[string]*elbv2model.FrontendNlbTargetGroupState), } if err := task.run(ctx); err != nil { - return nil, nil, nil, false, nil, err + return nil, nil, nil, false, nil, nil, err } + // Extract just the port numbers from listenPortConfigByPort + var listenerPorts []int32 + for port := range task.listenPortConfigByPort { + listenerPorts = append(listenerPorts, port) + } + + // Sort ports for consistency + sort.Slice(listenerPorts, func(i, j int) bool { + return listenerPorts[i] < listenerPorts[j] + }) + _ = elbv2model.NewFrontendNlbTargetGroupDesiredState(task.stack, task.localFrontendNlbData) - return task.stack, task.loadBalancer, task.secretKeys, task.backendSGAllocated, task.frontendNlb, nil + return task.stack, task.loadBalancer, task.secretKeys, task.backendSGAllocated, task.frontendNlb, listenerPorts, nil } // the default model build task @@ -257,6 +269,7 @@ type defaultModelBuildTask struct { localFrontendNlbData map[string]*elbv2model.FrontendNlbTargetGroupState targetGroupNameToArnMapper shared_utils.TargetGroupARNMapper webACLNameToArnMapper *webACLNameToArnMapper + listenPortConfigByPort map[int32]listenPortConfig metricsCollector lbcmetrics.MetricCollector } @@ -298,25 +311,25 @@ func (t *defaultModelBuildTask) run(ctx context.Context) error { } } - listenPortConfigByPort := make(map[int32]listenPortConfig) + t.listenPortConfigByPort = make(map[int32]listenPortConfig) for port, cfgs := range listenPortConfigsByPort { mergedCfg, err := t.mergeListenPortConfigs(ctx, cfgs) if err != nil { return errors.Wrapf(err, "failed to merge listenPort config for port: %v", port) } - listenPortConfigByPort[port] = mergedCfg + t.listenPortConfigByPort[port] = mergedCfg } - lb, err := t.buildLoadBalancer(ctx, listenPortConfigByPort) + lb, err := t.buildLoadBalancer(ctx, t.listenPortConfigByPort) if err != nil { return ctrlerrors.NewErrorWithMetrics(controllerName, "build_load_balancer_error", err, t.metricsCollector) } - t.sslRedirectConfig, err = t.buildSSLRedirectConfig(ctx, listenPortConfigByPort) + t.sslRedirectConfig, err = t.buildSSLRedirectConfig(ctx, t.listenPortConfigByPort) if err != nil { return ctrlerrors.NewErrorWithMetrics(controllerName, "build_ssl_redirct_config_error", err, t.metricsCollector) } - for port, cfg := range listenPortConfigByPort { + for port, cfg := range t.listenPortConfigByPort { ingList := ingListByPort[port] ls, err := t.buildListener(ctx, lb.LoadBalancerARN(), port, cfg, ingList) if err != nil { diff --git a/pkg/ingress/model_builder_test.go b/pkg/ingress/model_builder_test.go index a42be793d..5e1e8c669 100644 --- a/pkg/ingress/model_builder_test.go +++ b/pkg/ingress/model_builder_test.go @@ -4833,7 +4833,7 @@ func Test_defaultModelBuilder_Build(t *testing.T) { b.enableIPTargetType = *tt.enableIPTargetType } - gotStack, _, _, _, _, err := b.Build(context.Background(), tt.args.ingGroup, b.metricsCollector) + gotStack, _, _, _, _, _, err := b.Build(context.Background(), tt.args.ingGroup, b.metricsCollector) if tt.wantErr != "" { assert.EqualError(t, err, tt.wantErr) } else {