Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions controllers/aga/globalaccelerator_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,6 @@ const (
requeueMessage = "Monitoring provisioning state"
statusUpdateRequeueTime = 1 * time.Minute

// Status reason constants
EndpointLoadFailed = "EndpointLoadFailed"

// Metric stage constants
MetricStageFetchGlobalAccelerator = "fetch_globalAccelerator"
MetricStageAddFinalizers = "add_finalizers"
Expand All @@ -90,7 +87,7 @@ const (
func NewGlobalAcceleratorReconciler(k8sClient client.Client, eventRecorder record.EventRecorder, finalizerManager k8s.FinalizerManager, config config.ControllerConfig, cloud services.Cloud, logger logr.Logger, metricsCollector lbcmetrics.MetricCollector, reconcileCounters *metricsutil.ReconcileCounters) *globalAcceleratorReconciler {

// Create tracking provider
trackingProvider := tracking.NewDefaultProvider(agaTagPrefix, config.ClusterName, tracking.WithRegion(config.AWSConfig.Region))
trackingProvider := tracking.NewDefaultProvider(agaTagPrefix, config.ClusterName, tracking.WithRegion(cloud.Region()))

// Create model builder
agaModelBuilder := aga.NewDefaultModelBuilder(
Expand All @@ -99,7 +96,7 @@ func NewGlobalAcceleratorReconciler(k8sClient client.Client, eventRecorder recor
trackingProvider,
config.FeatureGates,
config.ClusterName,
config.AWSConfig.Region,
cloud.Region(),
config.DefaultTags,
config.ExternalManagedTags,
logger.WithName("aga-model-builder"),
Expand Down Expand Up @@ -272,11 +269,11 @@ func (r *globalAcceleratorReconciler) buildModel(ctx context.Context, ga *agaapi
func (r *globalAcceleratorReconciler) reconcileGlobalAcceleratorResources(ctx context.Context, ga *agaapi.GlobalAccelerator) error {
r.logger.Info("Reconciling GlobalAccelerator resources", "globalAccelerator", k8s.NamespacedName(ga))

// Get all endpoints from GA
endpoints := aga.GetAllEndpointsFromGA(ga)
// Get all desired endpoints from GA
endpoints := aga.GetAllDesiredEndpointsFromGA(ga)

// Track referenced endpoints
r.referenceTracker.UpdateReferencesForGA(ga, endpoints)
r.referenceTracker.UpdateDesiredEndpointReferencesForGA(ga, endpoints)

// Update resource watches with the endpointResourcesManager
r.endpointResourcesManager.MonitorEndpointResources(ga, endpoints)
Expand All @@ -285,10 +282,10 @@ func (r *globalAcceleratorReconciler) reconcileGlobalAcceleratorResources(ctx co
_, fatalErrors := r.endpointLoader.LoadEndpoints(ctx, ga, endpoints)
if len(fatalErrors) > 0 {
err := fmt.Errorf("failed to load endpoints: %v", fatalErrors[0])
r.logger.Error(err, "Fatal error loading endpoints")

r.eventRecorder.Event(ga, corev1.EventTypeWarning, k8s.GlobalAcceleratorEventReasonFailedEndpointLoad, fmt.Sprintf("Failed to reconcile due to %v", err))
r.logger.Error(err, fmt.Sprintf("fatal error loading endpoints for %v", k8s.NamespacedName(ga)))
// Handle other endpoint loading errors
if statusErr := r.statusUpdater.UpdateStatusFailure(ctx, ga, EndpointLoadFailed, err.Error()); statusErr != nil {
if statusErr := r.statusUpdater.UpdateStatusFailure(ctx, ga, agadeploy.EndpointLoadFailed, err.Error()); statusErr != nil {
r.logger.Error(statusErr, "Failed to update GlobalAccelerator status after endpoint load failure")
}
return err
Expand All @@ -302,6 +299,8 @@ func (r *globalAcceleratorReconciler) reconcileGlobalAcceleratorResources(ctx co
}
r.metricsCollector.ObserveControllerReconcileLatency(controllerName, MetricStageBuildModel, buildModelFn)
if err != nil {
r.eventRecorder.Event(ga, corev1.EventTypeWarning, k8s.GatewayEventReasonFailedBuildModel, fmt.Sprintf("Failed to build model: %v", err))
r.logger.Error(err, fmt.Sprintf("Failed to build model for: %v", k8s.NamespacedName(ga)))
// Update status to indicate model building failure
if statusErr := r.statusUpdater.UpdateStatusFailure(ctx, ga, agadeploy.ModelBuildFailed, fmt.Sprintf("Failed to build model: %v", err)); statusErr != nil {
r.logger.Error(statusErr, "Failed to update GlobalAccelerator status after model build failure")
Expand All @@ -316,7 +315,7 @@ func (r *globalAcceleratorReconciler) reconcileGlobalAcceleratorResources(ctx co
r.metricsCollector.ObserveControllerReconcileLatency(controllerName, MetricStageDeployStack, deployStackFn)
if err != nil {
r.eventRecorder.Event(ga, corev1.EventTypeWarning, k8s.GlobalAcceleratorEventReasonFailedDeploy, fmt.Sprintf("Failed to deploy stack due to %v", err))

r.logger.Error(err, fmt.Sprintf("Failed to deploy stack for: %v", k8s.NamespacedName(ga)))
// Update status to indicate deployment failure
if statusErr := r.statusUpdater.UpdateStatusFailure(ctx, ga, agadeploy.DeploymentFailed, fmt.Sprintf("Failed to deploy stack: %v", err)); statusErr != nil {
r.logger.Error(statusErr, "Failed to update GlobalAccelerator status after deployment failure")
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ require (
github.com/golang/mock v1.6.0
github.com/google/go-cmp v0.7.0
github.com/google/uuid v1.6.0
github.com/hashicorp/golang-lru v1.0.2
github.com/onsi/ginkgo/v2 v2.23.3
github.com/onsi/gomega v1.37.0
github.com/pkg/errors v0.9.1
Expand Down Expand Up @@ -104,7 +105,6 @@ require (
github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/hashicorp/golang-lru v1.0.2 // indirect
github.com/huandu/xstrings v1.5.0 // indirect
github.com/imkira/go-interpol v1.1.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
Expand Down
4 changes: 2 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ func main() {
}

// Setup GlobalAccelerator controller only if enabled
if aga.IsAGAControllerEnabled(controllerCFG.FeatureGates, controllerCFG.AWSConfig.Region) {
if aga.IsAGAControllerEnabled(controllerCFG.FeatureGates, cloud.Region()) {
agaReconciler := agacontroller.NewGlobalAcceleratorReconciler(mgr.GetClient(), mgr.GetEventRecorderFor("globalAccelerator"),
finalizerManager, controllerCFG, cloud, ctrl.Log.WithName("controllers").WithName("globalAccelerator"), lbcMetricsCollector, reconcileCounters)
if err := agaReconciler.SetupWithManager(ctx, mgr, clientSet); err != nil {
Expand Down Expand Up @@ -442,7 +442,7 @@ func main() {
networkingwebhook.NewIngressValidator(mgr.GetClient(), controllerCFG.IngressConfig, ctrl.Log, lbcMetricsCollector).SetupWithManager(mgr)

// Setup GlobalAccelerator validator only if enabled
if aga.IsAGAControllerEnabled(controllerCFG.FeatureGates, controllerCFG.AWSConfig.Region) {
if aga.IsAGAControllerEnabled(controllerCFG.FeatureGates, cloud.Region()) {
agawebhook.NewGlobalAcceleratorValidator(ctrl.Log, lbcMetricsCollector).SetupWithManager(mgr)
}
//+kubebuilder:scaffold:builder
Expand Down
4 changes: 2 additions & 2 deletions pkg/aga/endpoint_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ type EndpointReference struct {
Endpoint *agaapi.GlobalAcceleratorEndpoint
}

// GetAllEndpointsFromGA extracts all endpoint references from a GlobalAccelerator resource
func GetAllEndpointsFromGA(ga *agaapi.GlobalAccelerator) []EndpointReference {
// GetAllDesiredEndpointsFromGA extracts all endpoint references from a GlobalAccelerator resource
func GetAllDesiredEndpointsFromGA(ga *agaapi.GlobalAccelerator) []EndpointReference {
if ga == nil || ga.Spec.Listeners == nil {
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/aga/endpoint_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ func TestGetAllEndpointsFromGA(t *testing.T) {
}
}

result := GetAllEndpointsFromGA(tt.ga)
result := GetAllDesiredEndpointsFromGA(tt.ga)

// Compare lengths
assert.Equal(t, len(tt.expected), len(result))
Expand Down
250 changes: 250 additions & 0 deletions pkg/aga/model_build_endpoint_group.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
package aga

import (
"context"
"fmt"
awssdk "github.com/aws/aws-sdk-go-v2/aws"
agaapi "sigs.k8s.io/aws-load-balancer-controller/apis/aga/v1beta1"
agamodel "sigs.k8s.io/aws-load-balancer-controller/pkg/model/aga"
"sigs.k8s.io/aws-load-balancer-controller/pkg/model/core"
)

// endpointGroupBuilder builds EndpointGroup model resources
type endpointGroupBuilder interface {
// Build builds all endpoint groups for all listeners
Build(ctx context.Context, stack core.Stack, listeners []*agamodel.Listener, listenerConfigs []agaapi.GlobalAcceleratorListener) ([]*agamodel.EndpointGroup, error)

// buildEndpointGroupsForListener builds endpoint groups for a specific listener
buildEndpointGroupsForListener(ctx context.Context, stack core.Stack, listener *agamodel.Listener, endpointGroups []agaapi.GlobalAcceleratorEndpointGroup, listenerIndex int) ([]*agamodel.EndpointGroup, error)
}

// NewEndpointGroupBuilder constructs new endpointGroupBuilder
func NewEndpointGroupBuilder(clusterRegion string) endpointGroupBuilder {
return &defaultEndpointGroupBuilder{
clusterRegion: clusterRegion,
}
}

var _ endpointGroupBuilder = &defaultEndpointGroupBuilder{}

type defaultEndpointGroupBuilder struct {
clusterRegion string
}

// Build builds EndpointGroup model resources
func (b *defaultEndpointGroupBuilder) Build(ctx context.Context, stack core.Stack, listeners []*agamodel.Listener, listenerConfigs []agaapi.GlobalAcceleratorListener) ([]*agamodel.EndpointGroup, error) {
if listeners == nil || len(listeners) == 0 {
return nil, nil
}

var result []*agamodel.EndpointGroup

// Create a map of all listener port ranges
listenerPortRanges := make(map[string][]agamodel.PortRange) // Maps listener ID to its port ranges
for _, listener := range listeners {
listenerPortRanges[listener.ID()] = listener.Spec.PortRanges
}

for i, listener := range listeners {
listenerConfig := listenerConfigs[i]
if listenerConfig.EndpointGroups == nil {
continue
}

listenerEndpointGroups, err := b.buildEndpointGroupsForListener(ctx, stack, listener, *listenerConfig.EndpointGroups, i)
if err != nil {
return nil, err
}
result = append(result, listenerEndpointGroups...)
}

// Validate endpoint ports in all port overrides across all listeners
if err := b.validateEndpointPortOverridesCrossListeners(result, listenerPortRanges); err != nil {
return nil, err
}

return result, nil
}

// validateEndpointPortOverridesCrossListeners performs validations for endpoint port overrides across all listeners
func (b *defaultEndpointGroupBuilder) validateEndpointPortOverridesCrossListeners(endpointGroups []*agamodel.EndpointGroup, listenerPortRanges map[string][]agamodel.PortRange) error {
// Track endpoint port usage across all endpoint groups
endpointPortUsage := make(map[int32]string) // Maps endpoint port to listener ID

// Check all endpoint groups for port overrides
for _, endpointGroup := range endpointGroups {
listenerID := endpointGroup.Listener.ID()

for _, portOverride := range endpointGroup.Spec.PortOverrides {
endpointPort := portOverride.EndpointPort

// Rule 1: Check if endpoint port is within any listener's port range
if err := b.validateEndpointPortOverridesWithinListener(endpointPort, listenerPortRanges); err != nil {
return err
}

// Rule 2: Check for duplicate endpoint port usage across listeners
if existingListenerID, exists := endpointPortUsage[endpointPort]; exists && existingListenerID != listenerID {
return fmt.Errorf("duplicate endpoint port %d: the same endpoint port cannot be used in port overrides from different listeners (used in %s and %s)",
endpointPort, existingListenerID, listenerID)
}

// Register this endpoint port usage
endpointPortUsage[endpointPort] = listenerID
}
}

return nil
}

// validateEndpointPortOverridesWithinListener checks if an endpoint port is within any listener's port range
func (b *defaultEndpointGroupBuilder) validateEndpointPortOverridesWithinListener(endpointPort int32, listenerPortRanges map[string][]agamodel.PortRange) error {
for listenerID, portRanges := range listenerPortRanges {
if IsPortInRanges(endpointPort, portRanges) {
// Find the specific port range for the error message
for _, portRange := range portRanges {
if endpointPort >= portRange.FromPort && endpointPort <= portRange.ToPort {
return fmt.Errorf("endpoint port %d conflicts with listener %s port range %d-%d: endpoint port cannot be included in any listener port range",
endpointPort, listenerID, portRange.FromPort, portRange.ToPort)
}
}
}
}
return nil
}

// buildEndpointGroupsForListener builds EndpointGroup models for a specific listener
func (b *defaultEndpointGroupBuilder) buildEndpointGroupsForListener(ctx context.Context, stack core.Stack, listener *agamodel.Listener, endpointGroups []agaapi.GlobalAcceleratorEndpointGroup, listenerIndex int) ([]*agamodel.EndpointGroup, error) {
var result []*agamodel.EndpointGroup

for i, endpointGroup := range endpointGroups {
spec, err := b.buildEndpointGroupSpec(ctx, listener, endpointGroup)
if err != nil {
return nil, err
}

resourceID := fmt.Sprintf("EndpointGroup-%d-%d", listenerIndex, i)
endpointGroupModel := agamodel.NewEndpointGroup(stack, resourceID, spec, listener)
result = append(result, endpointGroupModel)
}

return result, nil
}

// buildEndpointGroupSpec builds the EndpointGroupSpec for a single EndpointGroup model resource
func (b *defaultEndpointGroupBuilder) buildEndpointGroupSpec(ctx context.Context, listener *agamodel.Listener, endpointGroup agaapi.GlobalAcceleratorEndpointGroup) (agamodel.EndpointGroupSpec, error) {
region, err := b.determineRegion(endpointGroup)
if err != nil {
return agamodel.EndpointGroupSpec{}, err
}

// Handle trafficDialPercentage
trafficDialPercentage := endpointGroup.TrafficDialPercentage

portOverrides, err := b.buildPortOverrides(ctx, listener, endpointGroup)
if err != nil {
return agamodel.EndpointGroupSpec{}, err
}

return agamodel.EndpointGroupSpec{
ListenerARN: listener.ListenerARN(),
Region: region,
TrafficDialPercentage: trafficDialPercentage,
PortOverrides: portOverrides,
}, nil
}

// validateListenerPortOverrideWithinListenerPortRanges ensures all listener ports used in port overrides are
// contained within the listener's port ranges
func (b *defaultEndpointGroupBuilder) validateListenerPortOverrideWithinListenerPortRanges(listener *agamodel.Listener, portOverrides []agamodel.PortOverride) error {
if len(portOverrides) == 0 {
return nil
}

for _, portOverride := range portOverrides {
listenerPort := portOverride.ListenerPort
if !IsPortInRanges(listenerPort, listener.Spec.PortRanges) {
return fmt.Errorf("port override listener port %d is not within any listener port ranges - this will cause AWS Global Accelerator to reject the configuration", listenerPort)
}
}
return nil
}

// determineRegion determines the region for the endpoint group
func (b *defaultEndpointGroupBuilder) determineRegion(endpointGroup agaapi.GlobalAcceleratorEndpointGroup) (string, error) {
// Use explicit region from endpoint group if specified
if endpointGroup.Region != nil && awssdk.ToString(endpointGroup.Region) != "" {
return awssdk.ToString(endpointGroup.Region), nil
}

// Default to cluster region if available
if b.clusterRegion != "" {
return b.clusterRegion, nil
}
return "", fmt.Errorf("region is required for endpoint group but neither specified in the endpoint group nor available from cluster configuration")
}

// buildPortOverrides builds the port overrides for the endpoint group
func (b *defaultEndpointGroupBuilder) buildPortOverrides(_ context.Context, listener *agamodel.Listener, endpointGroup agaapi.GlobalAcceleratorEndpointGroup) ([]agamodel.PortOverride, error) {
if endpointGroup.PortOverrides == nil {
return nil, nil
}

var portOverrides []agamodel.PortOverride
for _, po := range *endpointGroup.PortOverrides {
portOverrides = append(portOverrides, agamodel.PortOverride{
ListenerPort: po.ListenerPort,
EndpointPort: po.EndpointPort,
})
}

// Validate all port override rules
if err := b.validatePortOverrides(listener, portOverrides); err != nil {
return []agamodel.PortOverride{}, err
}

return portOverrides, nil
}

// validateNoDuplicatePorts checks both listener and endpoint ports for duplicates in a single pass
func (b *defaultEndpointGroupBuilder) validateNoDuplicatePorts(portOverrides []agamodel.PortOverride) error {
if len(portOverrides) <= 1 {
return nil
}

listenerPorts := make(map[int32]bool)
endpointPorts := make(map[int32]bool)

for _, portOverride := range portOverrides {
// Check for duplicate listener ports
listenerPort := portOverride.ListenerPort
if listenerPorts[listenerPort] {
return fmt.Errorf("duplicate listener port %d in port overrides: each listener port can only be used once in port overrides for an endpoint group", listenerPort)
}
listenerPorts[listenerPort] = true

// Check for duplicate endpoint ports
endpointPort := portOverride.EndpointPort
if endpointPorts[endpointPort] {
return fmt.Errorf("duplicate endpoint port %d in port overrides: each endpoint port can only be used once in port overrides for an endpoint group", endpointPort)
}
endpointPorts[endpointPort] = true
}

return nil
}

// validatePortOverrides is a wrapper function that runs all port override validation rules
func (b *defaultEndpointGroupBuilder) validatePortOverrides(listener *agamodel.Listener, portOverrides []agamodel.PortOverride) error {
// Validate listener port overrides against listener port ranges
if err := b.validateListenerPortOverrideWithinListenerPortRanges(listener, portOverrides); err != nil {
return err
}

// Check for duplicate listener and endpoint ports within this endpoint group's port overrides
if err := b.validateNoDuplicatePorts(portOverrides); err != nil {
return err
}

return nil
}
Loading
Loading