diff --git a/pkg/backend/endpoint_resolver.go b/pkg/backend/endpoint_resolver.go index 507ac16ca2..7d39cbec58 100644 --- a/pkg/backend/endpoint_resolver.go +++ b/pkg/backend/endpoint_resolver.go @@ -3,6 +3,7 @@ package backend import ( "context" "fmt" + "net/netip" awssdk "github.com/aws/aws-sdk-go-v2/aws" "github.com/go-logr/logr" @@ -13,6 +14,7 @@ import ( "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/intstr" "sigs.k8s.io/aws-load-balancer-controller/pkg/k8s" + "sigs.k8s.io/aws-load-balancer-controller/pkg/networking" "sigs.k8s.io/controller-runtime/pkg/client" ) @@ -72,7 +74,7 @@ func (r *defaultEndpointResolver) ResolvePodEndpoints(ctx context.Context, svcKe if err != nil { return nil, false, err } - return r.resolvePodEndpointsWithEndpointsData(ctx, svcKey, svcPort, endpointsDataList, resolveOpts.PodReadinessGates) + return r.resolvePodEndpointsWithEndpointsData(ctx, svcKey, svcPort, endpointsDataList, resolveOpts.PodReadinessGates, resolveOpts.cidrs) } func (r *defaultEndpointResolver) ResolveNodePortEndpoints(ctx context.Context, svcKey types.NamespacedName, port intstr.IntOrString, opts ...EndpointResolveOption) ([]NodePortEndpoint, error) { @@ -140,7 +142,7 @@ func (r *defaultEndpointResolver) computeServiceEndpointsData(ctx context.Contex return endpointsDataList, nil } -func (r *defaultEndpointResolver) resolvePodEndpointsWithEndpointsData(ctx context.Context, svcKey types.NamespacedName, svcPort corev1.ServicePort, endpointsDataList []EndpointsData, podReadinessGates []corev1.PodConditionType) ([]PodEndpoint, bool, error) { +func (r *defaultEndpointResolver) resolvePodEndpointsWithEndpointsData(ctx context.Context, svcKey types.NamespacedName, svcPort corev1.ServicePort, endpointsDataList []EndpointsData, podReadinessGates []corev1.PodConditionType, cidrs []netip.Prefix) ([]PodEndpoint, bool, error) { var readyPodEndpoints []PodEndpoint var unknownPodEndpoints []PodEndpoint containsPotentialReadyEndpoints := false @@ -171,6 +173,19 @@ func (r *defaultEndpointResolver) resolvePodEndpointsWithEndpointsData(ctx conte continue } + if len(cidrs) > 0 { + ip, err := netip.ParseAddr(epAddr) + if err != nil { + return nil, false, fmt.Errorf("parse ip addr: %w", err) + } + if !networking.IsIPWithinCIDRs(ip, cidrs) { + // this condition should never hit as long as cidrs are configured properly. if hit, then look at the cidr configured + // and make sure podIPs are within the range passed in. + r.logger.Error(fmt.Errorf("ip from endpoints being filtered"), fmt.Sprintf("unexpected condition hit for %s and ip: %s and cidrs: %s", svcKey.Name, epAddr, cidrs)) + continue + } + } + podEndpoint := buildPodEndpoint(pod, epAddr, epPort) // Recommendation from Kubernetes is to consider unknown ready status as ready (ready == nil) if ep.Conditions.Ready == nil || *ep.Conditions.Ready { diff --git a/pkg/backend/endpoint_types.go b/pkg/backend/endpoint_types.go index 4fb677fe94..60cce3a2c4 100644 --- a/pkg/backend/endpoint_types.go +++ b/pkg/backend/endpoint_types.go @@ -2,6 +2,8 @@ package backend import ( "fmt" + "net/netip" + corev1 "k8s.io/api/core/v1" discv1 "k8s.io/api/discovery/v1" "k8s.io/apimachinery/pkg/labels" @@ -54,6 +56,9 @@ type EndpointResolveOptions struct { // [Pod Endpoint] if pod readinessGates is defined, then pods from unready addresses with any of these readinessGates and containersReady condition will be included as well. // By default, no readinessGate is specified. PodReadinessGates []corev1.PodConditionType + + // cidrs will be used to filter out the list of IPs returned by the resolver + cidrs []netip.Prefix } func (opts *EndpointResolveOptions) ApplyOptions(options []EndpointResolveOption) { @@ -78,6 +83,14 @@ func WithPodReadinessGate(cond corev1.PodConditionType) EndpointResolveOption { } } +// WithCIDRRanges is an option that appends cidrs into EndpointResolveOptions to filter +// out the set of IPs to register +func WithCIDRRanges(cidrs []netip.Prefix) EndpointResolveOption { + return func(opts *EndpointResolveOptions) { + opts.cidrs = cidrs + } +} + // defaultEndpointResolveOptions returns the default value for EndpointResolveOptions. func defaultEndpointResolveOptions() EndpointResolveOptions { return EndpointResolveOptions{ diff --git a/pkg/targetgroupbinding/networking_manager.go b/pkg/targetgroupbinding/networking_manager.go index 21fa4b741f..7b10448891 100644 --- a/pkg/targetgroupbinding/networking_manager.go +++ b/pkg/targetgroupbinding/networking_manager.go @@ -202,7 +202,7 @@ func (m *defaultNetworkingManager) reconcileWithIngressPermissionsPerSG(ctx cont computedForAllTGBs := m.consolidateIngressPermissionsPerSGByTGB(ctx, tgbsWithNetworking) aggregatedIngressPermissionsPerSG := m.computeAggregatedIngressPermissionsPerSG(ctx) - permissionSelector := labels.SelectorFromSet(labels.Set{tgbNetworkingIPPermissionLabelKey: tgbNetworkingIPPermissionLabelValue}) + permissionSelector := labels.SelectorFromSet(labels.Set{tgbNetworkingIPPermissionLabelKey: m.clusterName}) var sgReconciliationErrors []error for sgID, permissions := range aggregatedIngressPermissionsPerSG { if err := m.sgReconciler.ReconcileIngress(ctx, sgID, permissions, @@ -421,7 +421,7 @@ func (m *defaultNetworkingManager) computePermissionsForPeerPort(ctx context.Con }) } - permissionLabels := map[string]string{tgbNetworkingIPPermissionLabelKey: tgbNetworkingIPPermissionLabelValue} + permissionLabels := map[string]string{tgbNetworkingIPPermissionLabelKey: m.clusterName} if peer.SecurityGroup != nil { groupID := peer.SecurityGroup.GroupID permissions := make([]networking.IPPermissionInfo, 0, len(sdkFromToPortPairs)) @@ -484,7 +484,7 @@ func (m *defaultNetworkingManager) gcIngressPermissionsFromUnusedEndpointSGs(ctx usedEndpointSGs := sets.StringKeySet(ingressPermissionsPerSG) unusedEndpointSGs := endpointSGs.Difference(usedEndpointSGs) - permissionSelector := labels.SelectorFromSet(labels.Set{tgbNetworkingIPPermissionLabelKey: tgbNetworkingIPPermissionLabelValue}) + permissionSelector := labels.SelectorFromSet(labels.Set{tgbNetworkingIPPermissionLabelKey: m.clusterName}) for sgID := range unusedEndpointSGs { err := m.sgReconciler.ReconcileIngress(ctx, sgID, nil, networking.WithPermissionSelector(permissionSelector)) diff --git a/pkg/targetgroupbinding/networking_manager_test.go b/pkg/targetgroupbinding/networking_manager_test.go index 74482a8439..fad4a601c4 100644 --- a/pkg/targetgroupbinding/networking_manager_test.go +++ b/pkg/targetgroupbinding/networking_manager_test.go @@ -17,6 +17,8 @@ import ( "sigs.k8s.io/aws-load-balancer-controller/pkg/networking" ) +const testClusterName = "test-123" + func Test_defaultNetworkingManager_computeIngressPermissionsForTGBNetworking(t *testing.T) { port8080 := intstr.FromInt(8080) port8443 := intstr.FromInt(8443) @@ -60,12 +62,12 @@ func Test_defaultNetworkingManager_computeIngressPermissionsForTGBNetworking(t * ToPort: awssdk.Int32(8080), UserIdGroupPairs: []ec2types.UserIdGroupPair{ { - Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=shared"), + Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=test-123"), GroupId: awssdk.String("sg-abcdefg"), }, }, }, - Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: tgbNetworkingIPPermissionLabelValue}, + Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: testClusterName}, }, }, }, @@ -107,12 +109,12 @@ func Test_defaultNetworkingManager_computeIngressPermissionsForTGBNetworking(t * ToPort: awssdk.Int32(8080), UserIdGroupPairs: []ec2types.UserIdGroupPair{ { - Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=shared"), + Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=test-123"), GroupId: awssdk.String("sg-abcdefg"), }, }, }, - Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: tgbNetworkingIPPermissionLabelValue}, + Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: testClusterName}, }, { Permission: ec2types.IpPermission{ @@ -121,12 +123,12 @@ func Test_defaultNetworkingManager_computeIngressPermissionsForTGBNetworking(t * ToPort: awssdk.Int32(8443), UserIdGroupPairs: []ec2types.UserIdGroupPair{ { - Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=shared"), + Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=test-123"), GroupId: awssdk.String("sg-abcdefg"), }, }, }, - Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: tgbNetworkingIPPermissionLabelValue}, + Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: testClusterName}, }, { Permission: ec2types.IpPermission{ @@ -135,12 +137,12 @@ func Test_defaultNetworkingManager_computeIngressPermissionsForTGBNetworking(t * ToPort: awssdk.Int32(8080), IpRanges: []ec2types.IpRange{ { - Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=shared"), + Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=test-123"), CidrIp: awssdk.String("192.168.1.1/16"), }, }, }, - Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: tgbNetworkingIPPermissionLabelValue}, + Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: testClusterName}, }, { Permission: ec2types.IpPermission{ @@ -149,12 +151,12 @@ func Test_defaultNetworkingManager_computeIngressPermissionsForTGBNetworking(t * ToPort: awssdk.Int32(8443), IpRanges: []ec2types.IpRange{ { - Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=shared"), + Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=test-123"), CidrIp: awssdk.String("192.168.1.1/16"), }, }, }, - Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: tgbNetworkingIPPermissionLabelValue}, + Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: testClusterName}, }, }, }, @@ -202,12 +204,12 @@ func Test_defaultNetworkingManager_computeIngressPermissionsForTGBNetworking(t * ToPort: awssdk.Int32(8080), UserIdGroupPairs: []ec2types.UserIdGroupPair{ { - Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=shared"), + Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=test-123"), GroupId: awssdk.String("sg-abcdefg"), }, }, }, - Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: tgbNetworkingIPPermissionLabelValue}, + Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: testClusterName}, }, { Permission: ec2types.IpPermission{ @@ -216,19 +218,21 @@ func Test_defaultNetworkingManager_computeIngressPermissionsForTGBNetworking(t * ToPort: awssdk.Int32(8443), IpRanges: []ec2types.IpRange{ { - Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=shared"), + Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=test-123"), CidrIp: awssdk.String("192.168.1.1/16"), }, }, }, - Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: tgbNetworkingIPPermissionLabelValue}, + Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: testClusterName}, }, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - m := &defaultNetworkingManager{} + m := &defaultNetworkingManager{ + clusterName: testClusterName, + } got, err := m.computeIngressPermissionsForTGBNetworking(context.Background(), tt.args.tgbNetworking, tt.args.pods) if tt.wantErr != nil { assert.EqualError(t, err, tt.wantErr.Error()) @@ -277,12 +281,12 @@ func Test_defaultNetworkingManager_computePermissionsForPeerPort(t *testing.T) { ToPort: awssdk.Int32(8080), UserIdGroupPairs: []ec2types.UserIdGroupPair{ { - Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=shared"), + Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=test-123"), GroupId: awssdk.String("sg-abcdefg"), }, }, }, - Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: tgbNetworkingIPPermissionLabelValue}, + Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: testClusterName}, }, }, }, @@ -309,11 +313,11 @@ func Test_defaultNetworkingManager_computePermissionsForPeerPort(t *testing.T) { IpRanges: []ec2types.IpRange{ { CidrIp: awssdk.String("192.168.1.1/16"), - Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=shared"), + Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=test-123"), }, }, }, - Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: tgbNetworkingIPPermissionLabelValue}, + Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: testClusterName}, }, }, }, @@ -340,11 +344,11 @@ func Test_defaultNetworkingManager_computePermissionsForPeerPort(t *testing.T) { Ipv6Ranges: []ec2types.Ipv6Range{ { CidrIpv6: awssdk.String("2002::1234:abcd:ffff:c0a8:101/64"), - Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=shared"), + Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=test-123"), }, }, }, - Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: tgbNetworkingIPPermissionLabelValue}, + Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: testClusterName}, }, }, }, @@ -389,12 +393,12 @@ func Test_defaultNetworkingManager_computePermissionsForPeerPort(t *testing.T) { ToPort: awssdk.Int32(80), UserIdGroupPairs: []ec2types.UserIdGroupPair{ { - Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=shared"), + Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=test-123"), GroupId: awssdk.String("sg-abcdefg"), }, }, }, - Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: tgbNetworkingIPPermissionLabelValue}, + Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: testClusterName}, }, { Permission: ec2types.IpPermission{ @@ -403,12 +407,12 @@ func Test_defaultNetworkingManager_computePermissionsForPeerPort(t *testing.T) { ToPort: awssdk.Int32(8080), UserIdGroupPairs: []ec2types.UserIdGroupPair{ { - Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=shared"), + Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=test-123"), GroupId: awssdk.String("sg-abcdefg"), }, }, }, - Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: tgbNetworkingIPPermissionLabelValue}, + Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: testClusterName}, }, }, }, @@ -434,12 +438,12 @@ func Test_defaultNetworkingManager_computePermissionsForPeerPort(t *testing.T) { ToPort: awssdk.Int32(8080), UserIdGroupPairs: []ec2types.UserIdGroupPair{ { - Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=shared"), + Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=test-123"), GroupId: awssdk.String("sg-abcdefg"), }, }, }, - Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: tgbNetworkingIPPermissionLabelValue}, + Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: testClusterName}, }, }, }, @@ -464,19 +468,21 @@ func Test_defaultNetworkingManager_computePermissionsForPeerPort(t *testing.T) { ToPort: awssdk.Int32(65535), UserIdGroupPairs: []ec2types.UserIdGroupPair{ { - Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=shared"), + Description: awssdk.String("elbv2.k8s.aws/targetGroupBinding=test-123"), GroupId: awssdk.String("sg-abcdefg"), }, }, }, - Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: tgbNetworkingIPPermissionLabelValue}, + Labels: map[string]string{tgbNetworkingIPPermissionLabelKey: testClusterName}, }, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - m := &defaultNetworkingManager{} + m := &defaultNetworkingManager{ + clusterName: testClusterName, + } got, err := m.computePermissionsForPeerPort(context.Background(), tt.args.peer, tt.args.port, tt.args.pods) if tt.wantErr != nil { assert.EqualError(t, err, tt.wantErr.Error()) diff --git a/pkg/targetgroupbinding/resource_manager.go b/pkg/targetgroupbinding/resource_manager.go index 9bef6a70f3..5113ec2437 100644 --- a/pkg/targetgroupbinding/resource_manager.go +++ b/pkg/targetgroupbinding/resource_manager.go @@ -6,6 +6,7 @@ import ( elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" "github.com/aws/smithy-go" "net/netip" + "strings" "time" "k8s.io/client-go/tools/record" @@ -134,6 +135,12 @@ func (m *defaultResourceManager) reconcileWithIPTargetType(ctx context.Context, var containsPotentialReadyEndpoints bool var err error + cidrs, err := parseCIDR(tgb) + if err != nil { + return "", "", false, fmt.Errorf("parse cidrs: %w", err) + } + resolveOpts = append(resolveOpts, backend.WithCIDRRanges(cidrs)) + endpoints, containsPotentialReadyEndpoints, err = m.endpointResolver.ResolvePodEndpoints(ctx, svcKey, tgb.Spec.ServiceRef.Port, resolveOpts...) if err != nil { @@ -159,7 +166,7 @@ func (m *defaultResourceManager) reconcileWithIPTargetType(ctx context.Context, tgARN := tgb.Spec.TargetGroupARN vpcID := tgb.Spec.VpcID - targets, err := m.targetsManager.ListTargets(ctx, tgARN) + targets, err := m.targetsManager.ListTargets(ctx, tgARN, cidrs) if err != nil { return "", "", false, err } @@ -264,7 +271,7 @@ func (m *defaultResourceManager) reconcileWithInstanceTargetType(ctx context.Con } tgARN := tgb.Spec.TargetGroupARN - targets, err := m.targetsManager.ListTargets(ctx, tgARN) + targets, err := m.targetsManager.ListTargets(ctx, tgARN, []netip.Prefix{}) if err != nil { return "", "", false, err } @@ -300,7 +307,11 @@ func (m *defaultResourceManager) reconcileWithInstanceTargetType(ctx context.Con } func (m *defaultResourceManager) cleanupTargets(ctx context.Context, tgb *elbv2api.TargetGroupBinding) error { - targets, err := m.targetsManager.ListTargets(ctx, tgb.Spec.TargetGroupARN) + cidrs, err := parseCIDR(tgb) + if err != nil { + return fmt.Errorf("parse cidrs: %w", err) + } + targets, err := m.targetsManager.ListTargets(ctx, tgb.Spec.TargetGroupARN, cidrs) if err != nil { if isELBV2TargetGroupNotFoundError(err) { return nil @@ -659,3 +670,17 @@ func isELBV2TargetGroupARNInvalidError(err error) bool { } return false } + +func parseCIDR(tgb *elbv2api.TargetGroupBinding) ([]netip.Prefix, error) { + var cidrs []netip.Prefix + // only triggered if the object has an annotation named `annotation` + if annotation := tgb.Annotations["filter-cidrs"]; annotation != "" { + s := strings.Split(annotation, ",") + c, err := networking.ParseCIDRs(s) + if err != nil { + return cidrs, err + } + cidrs = c + } + return cidrs, nil +} diff --git a/pkg/targetgroupbinding/targets_manager.go b/pkg/targetgroupbinding/targets_manager.go index b32eb350e0..3353c75e1b 100644 --- a/pkg/targetgroupbinding/targets_manager.go +++ b/pkg/targetgroupbinding/targets_manager.go @@ -2,14 +2,18 @@ package targetgroupbinding import ( "context" + "fmt" "github.com/aws/aws-sdk-go-v2/aws" elbv2sdk "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" + "net/netip" + "sync" + "time" + "github.com/go-logr/logr" "k8s.io/apimachinery/pkg/util/cache" "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" - "sync" - "time" + "sigs.k8s.io/aws-load-balancer-controller/pkg/networking" ) const ( @@ -27,7 +31,7 @@ type TargetsManager interface { DeregisterTargets(ctx context.Context, tgARN string, targets []elbv2types.TargetDescription) error // List Targets from TargetGroup. - ListTargets(ctx context.Context, tgARN string) ([]TargetInfo, error) + ListTargets(ctx context.Context, tgARN string, cidrs []netip.Prefix) ([]TargetInfo, error) } // NewCachedTargetsManager constructs new cachedTargetsManager @@ -118,7 +122,7 @@ func (m *cachedTargetsManager) DeregisterTargets(ctx context.Context, tgARN stri return nil } -func (m *cachedTargetsManager) ListTargets(ctx context.Context, tgARN string) ([]TargetInfo, error) { +func (m *cachedTargetsManager) ListTargets(ctx context.Context, tgARN string, cidrs []netip.Prefix) ([]TargetInfo, error) { m.targetsCacheMutex.Lock() defer m.targetsCacheMutex.Unlock() @@ -126,7 +130,7 @@ func (m *cachedTargetsManager) ListTargets(ctx context.Context, tgARN string) ([ targetsCacheItem := rawTargetsCacheItem.(*targetsCacheItem) targetsCacheItem.mutex.Lock() defer targetsCacheItem.mutex.Unlock() - refreshedTargets, err := m.refreshUnhealthyTargets(ctx, tgARN, targetsCacheItem.targets) + refreshedTargets, err := m.refreshUnhealthyTargets(ctx, tgARN, targetsCacheItem.targets, cidrs) if err != nil { return nil, err } @@ -134,7 +138,7 @@ func (m *cachedTargetsManager) ListTargets(ctx context.Context, tgARN string) ([ return cloneTargetInfoSlice(refreshedTargets), nil } - refreshedTargets, err := m.refreshAllTargets(ctx, tgARN) + refreshedTargets, err := m.refreshAllTargets(ctx, tgARN, cidrs) if err != nil { return nil, err } @@ -147,8 +151,8 @@ func (m *cachedTargetsManager) ListTargets(ctx context.Context, tgARN string) ([ } // refreshAllTargets will refresh all targets for targetGroup. -func (m *cachedTargetsManager) refreshAllTargets(ctx context.Context, tgARN string) ([]TargetInfo, error) { - targets, err := m.listTargetsFromAWS(ctx, tgARN, nil) +func (m *cachedTargetsManager) refreshAllTargets(ctx context.Context, tgARN string, cidrs []netip.Prefix) ([]TargetInfo, error) { + targets, err := m.listTargetsFromAWS(ctx, tgARN, nil, cidrs) if err != nil { return nil, err } @@ -158,7 +162,7 @@ func (m *cachedTargetsManager) refreshAllTargets(ctx context.Context, tgARN stri // refreshUnhealthyTargets will refresh targets that are not in healthy status for targetGroup. // To save API calls, we don't refresh targets that are already healthy since once a target turns healthy, we'll unblock it's readinessProbe. // we can do nothing from controller perspective when a healthy target becomes unhealthy. -func (m *cachedTargetsManager) refreshUnhealthyTargets(ctx context.Context, tgARN string, cachedTargets []TargetInfo) ([]TargetInfo, error) { +func (m *cachedTargetsManager) refreshUnhealthyTargets(ctx context.Context, tgARN string, cachedTargets []TargetInfo, cidrs []netip.Prefix) ([]TargetInfo, error) { var refreshedTargets []TargetInfo var unhealthyTargets []elbv2types.TargetDescription for _, cachedTarget := range cachedTargets { @@ -172,7 +176,7 @@ func (m *cachedTargetsManager) refreshUnhealthyTargets(ctx context.Context, tgAR return refreshedTargets, nil } - refreshedUnhealthyTargets, err := m.listTargetsFromAWS(ctx, tgARN, unhealthyTargets) + refreshedUnhealthyTargets, err := m.listTargetsFromAWS(ctx, tgARN, unhealthyTargets, cidrs) if err != nil { return nil, err } @@ -188,7 +192,7 @@ func (m *cachedTargetsManager) refreshUnhealthyTargets(ctx context.Context, tgAR // listTargetsFromAWS will list targets for TargetGroup using ELBV2API. // if specified targets is non-empty, only these targets will be listed. // otherwise, all targets for targetGroup will be listed. -func (m *cachedTargetsManager) listTargetsFromAWS(ctx context.Context, tgARN string, targets []elbv2types.TargetDescription) ([]TargetInfo, error) { +func (m *cachedTargetsManager) listTargetsFromAWS(ctx context.Context, tgARN string, targets []elbv2types.TargetDescription, cidrs []netip.Prefix) ([]TargetInfo, error) { req := &elbv2sdk.DescribeTargetHealthInput{ TargetGroupArn: aws.String(tgARN), Targets: pointerizeTargetDescriptions(targets), @@ -200,6 +204,15 @@ func (m *cachedTargetsManager) listTargetsFromAWS(ctx context.Context, tgARN str listedTargets := make([]TargetInfo, 0, len(resp.TargetHealthDescriptions)) for _, elem := range resp.TargetHealthDescriptions { + if len(cidrs) > 0 { + ip, err := netip.ParseAddr(*elem.Target.Id) + if err != nil { + return nil, fmt.Errorf("parse ip addr: %w", err) + } + if !networking.IsIPWithinCIDRs(ip, cidrs) { + continue + } + } listedTargets = append(listedTargets, TargetInfo{ Target: *elem.Target, TargetHealth: elem.TargetHealth, diff --git a/pkg/targetgroupbinding/targets_manager_test.go b/pkg/targetgroupbinding/targets_manager_test.go index 1a291ffd1e..a493762e6f 100644 --- a/pkg/targetgroupbinding/targets_manager_test.go +++ b/pkg/targetgroupbinding/targets_manager_test.go @@ -2,6 +2,11 @@ package targetgroupbinding import ( "context" + "net/netip" + "sync" + "testing" + "time" + awssdk "github.com/aws/aws-sdk-go-v2/aws" elbv2sdk "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" @@ -10,9 +15,6 @@ import ( "k8s.io/apimachinery/pkg/util/cache" "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" "sigs.k8s.io/controller-runtime/pkg/log" - "sync" - "testing" - "time" ) func Test_cachedTargetsManager_RegisterTargets(t *testing.T) { @@ -791,7 +793,7 @@ func Test_cachedTargetsManager_ListTargets(t *testing.T) { } ctx := context.Background() - got, err := m.ListTargets(ctx, tt.args.tgARN) + got, err := m.ListTargets(ctx, tt.args.tgARN, []netip.Prefix{}) if tt.wantErr != nil { assert.EqualError(t, err, tt.wantErr.Error()) } else { @@ -1189,7 +1191,7 @@ func Test_cachedTargetsManager_refreshUnhealthyTargets(t *testing.T) { elbv2Client: elbv2Client, } ctx := context.Background() - got, err := m.refreshUnhealthyTargets(ctx, tt.args.tgARN, tt.args.cachedTargets) + got, err := m.refreshUnhealthyTargets(ctx, tt.args.tgARN, tt.args.cachedTargets, []netip.Prefix{}) if tt.wantErr != nil { assert.EqualError(t, err, tt.wantErr.Error()) } else { @@ -1352,7 +1354,7 @@ func Test_cachedTargetsManager_listTargetsFromAWS(t *testing.T) { elbv2Client: elbv2Client, } ctx := context.Background() - got, err := m.listTargetsFromAWS(ctx, tt.args.tgARN, tt.args.targets) + got, err := m.listTargetsFromAWS(ctx, tt.args.tgARN, tt.args.targets, []netip.Prefix{}) if tt.wantErr != nil { assert.EqualError(t, err, tt.wantErr.Error()) } else {