diff --git a/apis/elbv2/v1alpha1/targetgroupbinding_types.go b/apis/elbv2/v1alpha1/targetgroupbinding_types.go index 4af605afa7..a0f47a568d 100644 --- a/apis/elbv2/v1alpha1/targetgroupbinding_types.go +++ b/apis/elbv2/v1alpha1/targetgroupbinding_types.go @@ -128,6 +128,14 @@ type TargetGroupBindingSpec struct { // networking provides the networking setup for ELBV2 LoadBalancer to access targets in TargetGroup. // +optional Networking *TargetGroupBindingNetworking `json:"networking,omitempty"` + + // IAM Role ARN to assume when calling AWS APIs. Useful if the target group is in a different AWS account + // +optional + IamRoleArnToAssume string `json:"iamRoleArnToAssume,omitempty"` + + // IAM Role ARN to assume when calling AWS APIs. Needed to assume a role in another account and prevent the confused deputy problem. https://docs.aws.amazon.com/IAM/latest/UserGuide/confused-deputy.html + // +optional + AssumeRoleExternalId string `json:"assumeRoleExternalId,omitempty"` } // TargetGroupBindingStatus defines the observed state of TargetGroupBinding diff --git a/apis/elbv2/v1beta1/targetgroupbinding_types.go b/apis/elbv2/v1beta1/targetgroupbinding_types.go index 7a273a1d43..4e5b109dbd 100644 --- a/apis/elbv2/v1beta1/targetgroupbinding_types.go +++ b/apis/elbv2/v1beta1/targetgroupbinding_types.go @@ -160,11 +160,11 @@ type TargetGroupBindingSpec struct { // IAM Role ARN to assume when calling AWS APIs. Useful if the target group is in a different AWS account // +optional - IamRoleArnToAssume string `json:"-"` // `json:"iamRoleArnToAssume,omitempty"` + IamRoleArnToAssume string `json:"iamRoleArnToAssume,omitempty"` // IAM Role ARN to assume when calling AWS APIs. Needed to assume a role in another account and prevent the confused deputy problem. https://docs.aws.amazon.com/IAM/latest/UserGuide/confused-deputy.html // +optional - AssumeRoleExternalId string `json:"-"` // `json:"assumeRoleExternalId,omitempty"` + AssumeRoleExternalId string `json:"assumeRoleExternalId,omitempty"` } // TargetGroupBindingStatus defines the observed state of TargetGroupBinding diff --git a/config/crd/bases/elbv2.k8s.aws_targetgroupbindings.yaml b/config/crd/bases/elbv2.k8s.aws_targetgroupbindings.yaml index f9af43e6ff..f5385c977d 100644 --- a/config/crd/bases/elbv2.k8s.aws_targetgroupbindings.yaml +++ b/config/crd/bases/elbv2.k8s.aws_targetgroupbindings.yaml @@ -65,6 +65,15 @@ spec: spec: description: TargetGroupBindingSpec defines the desired state of TargetGroupBinding properties: + assumeRoleExternalId: + description: IAM Role ARN to assume when calling AWS APIs. Needed + to assume a role in another account and prevent the confused deputy + problem. https://docs.aws.amazon.com/IAM/latest/UserGuide/confused-deputy.html + type: string + iamRoleArnToAssume: + description: IAM Role ARN to assume when calling AWS APIs. Useful + if the target group is in a different AWS account + type: string multiClusterTargetGroup: description: MultiClusterTargetGroup Denotes if the TargetGroup is shared among multiple clusters @@ -242,6 +251,15 @@ spec: spec: description: TargetGroupBindingSpec defines the desired state of TargetGroupBinding properties: + assumeRoleExternalId: + description: IAM Role ARN to assume when calling AWS APIs. Needed + to assume a role in another account and prevent the confused deputy + problem. https://docs.aws.amazon.com/IAM/latest/UserGuide/confused-deputy.html + type: string + iamRoleArnToAssume: + description: IAM Role ARN to assume when calling AWS APIs. Useful + if the target group is in a different AWS account + type: string ipAddressType: description: ipAddressType specifies whether the target group is of type IPv4 or IPv6. If unspecified, it will be automatically inferred. diff --git a/docs/guide/targetgroupbinding/spec.md b/docs/guide/targetgroupbinding/spec.md index f2b9b80c55..96ee1963a6 100644 --- a/docs/guide/targetgroupbinding/spec.md +++ b/docs/guide/targetgroupbinding/spec.md @@ -55,15 +55,6 @@ Kubernetes meta/v1.ObjectMeta
annotations - - - - - -
alb.ingress.kubernetes.io/IamRoleArnToAssume
string
(Optional) In case the target group is in a differet AWS account, you put here the role that needs to be assumed in order to manipulate the target group. -
alb.ingress.kubernetes.io/AssumeRoleExternalId
string
(Optional) The external ID for the assume role operation. Optional, but recommended. It helps you to prevent the confused deputy problem. -
-
Refer to the Kubernetes API documentation for the other fields of the metadata field. diff --git a/docs/guide/targetgroupbinding/targetgroupbinding.md b/docs/guide/targetgroupbinding/targetgroupbinding.md index 0f3a055e74..e7d3ccf4aa 100644 --- a/docs/guide/targetgroupbinding/targetgroupbinding.md +++ b/docs/guide/targetgroupbinding/targetgroupbinding.md @@ -112,10 +112,108 @@ spec: ### AssumeRole Sometimes the AWS LoadBalancer controller needs to manipulate target groups from different AWS accounts. -The way to do that is assuming a role from such account. There are annotations that can help you with that: +The way to do that is assuming a role from such an account. The following spec fields help you with that. -* `alb.ingress.kubernetes.io/IamRoleArnToAssume`: the ARN that you need to assume -* `alb.ingress.kubernetes.io/AssumeRoleExternalId`: the external ID for the assume role operation. Optional, but recommended. It helps you to prevent the confused deputy problem ( https://docs.aws.amazon.com/IAM/latest/UserGuide/confused-deputy.html ) +* `iamRoleArnToAssume`: the ARN that you need to assume +* `assumeRoleExternalId`: the external ID for the assume role operation. Optional, but recommended. It helps you to prevent the confused deputy problem ( https://docs.aws.amazon.com/IAM/latest/UserGuide/confused-deputy.html ) + + +```yaml +apiVersion: elbv2.k8s.aws/v1beta1 +kind: TargetGroupBinding +metadata: + name: peered-tg + namespace: nlb-game-2048-1 +spec: + assumeRoleExternalId: very-secret-string-2 + iamRoleArnToAssume: arn:aws:iam::155642222660:role/tg-management-role + networking: + ingress: + - from: + - securityGroup: + groupID: sg-0b6a41a2fd959623f + ports: + - port: 80 + protocol: TCP + serviceRef: + name: service-2048 + port: 80 + targetGroupARN: arn:aws:elasticloadbalancing:us-west-2:155642222660:targetgroup/peered-tg/6a4ecf7bfae473c1 +``` + +In the following examples, we will refer to Cluster Owner (CO) and Target Group Owner (TGO) accounts. + +First, in the TGO account creates a role that will allow the AWS LBC in the CO account to assume it. +For improved security, we only allow the AWS LBC role in CO account to assume the role. + +```json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Sid": "", + "Effect": "Allow", + "Principal": { + "AWS": "arn:aws:iam::565768096483:role/eksctl-awslbc-loadtest-addon-iamserviceaccoun-Role1-13RdJCMqV6p2" + }, + "Action": "sts:AssumeRole", + "Condition": { + "StringEquals": { + "sts:ExternalId": "very-secret-string" + } + } + } + ] +} +``` + +Next, still in the TGO account we need to add the following permissions to the Role created in the first step. + +```json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Sid": "VisualEditor0", + "Effect": "Allow", + "Action": [ + "elasticloadbalancing:RegisterTargets", + "elasticloadbalancing:DeregisterTargets" + ], + "Resource": [ + "arn:aws:elasticloadbalancing:us-west-2:155642222660:targetgroup/tg1/*", + "arn:aws:elasticloadbalancing:us-west-2:155642222660:targetgroup/tg2/*" + // add more here // + ] + }, + { + "Sid": "VisualEditor1", + "Effect": "Allow", + "Action": [ + "elasticloadbalancing:DescribeTargetGroups", + "elasticloadbalancing:DescribeTargetHealth" + ], + "Resource": "*" + } + ] +} +``` + + +Next, in the CO account, we need to allow the AWS LBC to perform the AssumeRole call. +By default, this permission is not a part of the standard IAM policy that is vended with the LBC installation scripts. +For improved security, it is possible to scope the AssumeRole permissions down to only roles that you know ahead of time the +LBC will need to Assume. + +```json + { + "Effect": "Allow", + "Action": [ + "sts:AssumeRole" + ], + "Resource": "*" + } +``` ## Sample YAML @@ -125,10 +223,9 @@ apiVersion: elbv2.k8s.aws/v1beta1 kind: TargetGroupBinding metadata: name: my-tgb - annotations: - alb.ingress.kubernetes.io/IamRoleArnToAssume: "arn:aws:iam::999999999999:role/alb-controller-policy-to-assume" - alb.ingress.kubernetes.io/AssumeRoleExternalId: "some-magic-string" spec: + iamRoleArnToAssume: "arn:aws:iam::999999999999:role/alb-controller-policy-to-assume" + assumeRoleExternalId: "some-magic-string" ... ``` diff --git a/go.sum b/go.sum index 292ee96ce6..6676d7940c 100644 --- a/go.sum +++ b/go.sum @@ -60,6 +60,7 @@ github.com/aws/aws-sdk-go-v2/service/ec2 v1.173.0 h1:ta62lid9JkIpKZtZZXSj6rP2AqY github.com/aws/aws-sdk-go-v2/service/ec2 v1.173.0/go.mod h1:o6QDjdVKpP5EF0dp/VlvqckzuSDATr1rLdHt3A5m0YY= github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2 v1.43.1 h1:L9Wt9zgtoYKIlaeFTy+EztGjL4oaXBBGtVXA+jaeYko= github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2 v1.43.1/go.mod h1:yxzLdxt7bVGvIOPYIKFtiaJCJnx2ChlIIvlhW4QgI6M= +github.com/aws/aws-sdk-go-v2/service/iam v1.36.3/go.mod h1:HSvujsK8xeEHMIB18oMXjSfqaN9cVqpo/MtHJIksQRk= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 h1:dT3MqvGhSoaIhRseqw2I0yH81l7wiR2vjs57O51EAm8= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3/go.mod h1:GlAeCkHwugxdHaueRr4nhPuY+WW+gR8UjlcqzPr1SPI= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 h1:HGErhhrxZlQ044RiM+WdoZxp0p+EGM62y3L6pwA4olE= diff --git a/helm/aws-load-balancer-controller/crds/crds.yaml b/helm/aws-load-balancer-controller/crds/crds.yaml index b72e687892..ea31acc114 100644 --- a/helm/aws-load-balancer-controller/crds/crds.yaml +++ b/helm/aws-load-balancer-controller/crds/crds.yaml @@ -317,6 +317,15 @@ spec: spec: description: TargetGroupBindingSpec defines the desired state of TargetGroupBinding properties: + assumeRoleExternalId: + description: IAM Role ARN to assume when calling AWS APIs. Needed + to assume a role in another account and prevent the confused deputy + problem. https://docs.aws.amazon.com/IAM/latest/UserGuide/confused-deputy.html + type: string + iamRoleArnToAssume: + description: IAM Role ARN to assume when calling AWS APIs. Useful + if the target group is in a different AWS account + type: string multiClusterTargetGroup: description: MultiClusterTargetGroup Denotes if the TargetGroup is shared among multiple clusters @@ -494,6 +503,15 @@ spec: spec: description: TargetGroupBindingSpec defines the desired state of TargetGroupBinding properties: + assumeRoleExternalId: + description: IAM Role ARN to assume when calling AWS APIs. Needed + to assume a role in another account and prevent the confused deputy + problem. https://docs.aws.amazon.com/IAM/latest/UserGuide/confused-deputy.html + type: string + iamRoleArnToAssume: + description: IAM Role ARN to assume when calling AWS APIs. Useful + if the target group is in a different AWS account + type: string ipAddressType: description: ipAddressType specifies whether the target group is of type IPv4 or IPv6. If unspecified, it will be automatically inferred. diff --git a/main.go b/main.go index de2a3edc9d..c4641fb109 100644 --- a/main.go +++ b/main.go @@ -90,7 +90,7 @@ func main() { awsMetricsCollector = awsmetrics.NewCollector(metrics.Registry) } - cloud, err := aws.NewCloud(controllerCFG.AWSConfig, awsMetricsCollector, ctrl.Log, nil) + cloud, err := aws.NewCloud(controllerCFG.AWSConfig, controllerCFG.ClusterName, awsMetricsCollector, ctrl.Log, nil) if err != nil { setupLog.Error(err, "unable to initialize AWS cloud") os.Exit(1) diff --git a/pkg/aws/aws_config.go b/pkg/aws/aws_config.go new file mode 100644 index 0000000000..4172f34daf --- /dev/null +++ b/pkg/aws/aws_config.go @@ -0,0 +1,81 @@ +package aws + +import ( + "context" + "github.com/aws/aws-sdk-go-v2/aws" + awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware" + "github.com/aws/aws-sdk-go-v2/aws/ratelimit" + "github.com/aws/aws-sdk-go-v2/aws/retry" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" + smithymiddleware "github.com/aws/smithy-go/middleware" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/throttle" + awsmetrics "sigs.k8s.io/aws-load-balancer-controller/pkg/metrics/aws" + "sigs.k8s.io/aws-load-balancer-controller/pkg/version" +) + +const ( + userAgent = "elbv2.k8s.aws" +) + +func NewAWSConfigGenerator(cfg CloudConfig, ec2IMDSEndpointMode imds.EndpointModeState, metricsCollector *awsmetrics.Collector) AWSConfigGenerator { + return &awsConfigGeneratorImpl{ + cfg: cfg, + ec2IMDSEndpointMode: ec2IMDSEndpointMode, + metricsCollector: metricsCollector, + } + +} + +// AWSConfigGenerator is responsible for generating an aws config based on the running environment +type AWSConfigGenerator interface { + GenerateAWSConfig(optFns ...func(*config.LoadOptions) error) (aws.Config, error) +} + +type awsConfigGeneratorImpl struct { + cfg CloudConfig + ec2IMDSEndpointMode imds.EndpointModeState + metricsCollector *awsmetrics.Collector +} + +func (gen *awsConfigGeneratorImpl) GenerateAWSConfig(optFns ...func(*config.LoadOptions) error) (aws.Config, error) { + + defaultOpts := []func(*config.LoadOptions) error{ + config.WithRegion(gen.cfg.Region), + config.WithRetryer(func() aws.Retryer { + return retry.NewStandard(func(o *retry.StandardOptions) { + o.RateLimiter = ratelimit.None + o.MaxAttempts = gen.cfg.MaxRetries + }) + }), + config.WithEC2IMDSEndpointMode(gen.ec2IMDSEndpointMode), + config.WithAPIOptions([]func(stack *smithymiddleware.Stack) error{ + awsmiddleware.AddUserAgentKeyValue(userAgent, version.GitVersion), + }), + } + + defaultOpts = append(defaultOpts, optFns...) + + awsConfig, err := config.LoadDefaultConfig(context.TODO(), + defaultOpts..., + ) + + if err != nil { + return aws.Config{}, err + } + + if gen.cfg.ThrottleConfig != nil { + throttler := throttle.NewThrottler(gen.cfg.ThrottleConfig) + awsConfig.APIOptions = append(awsConfig.APIOptions, func(stack *smithymiddleware.Stack) error { + return throttle.WithSDKRequestThrottleMiddleware(throttler)(stack) + }) + } + + if gen.metricsCollector != nil { + awsConfig.APIOptions = awsmetrics.WithSDKMetricCollector(gen.metricsCollector, awsConfig.APIOptions) + } + + return awsConfig, nil +} + +var _ AWSConfigGenerator = &awsConfigGeneratorImpl{} diff --git a/pkg/aws/cloud.go b/pkg/aws/cloud.go index 1e9d50f878..bd733a657d 100644 --- a/pkg/aws/cloud.go +++ b/pkg/aws/cloud.go @@ -3,23 +3,18 @@ package aws import ( "context" "fmt" - "log" + "k8s.io/apimachinery/pkg/util/cache" "net" "os" "strings" + "sync" + "time" - awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware" - "github.com/aws/aws-sdk-go-v2/aws/ratelimit" - "github.com/aws/aws-sdk-go-v2/aws/retry" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/aws/aws-sdk-go-v2/service/sts" - smithymiddleware "github.com/aws/smithy-go/middleware" - "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/throttle" - "sigs.k8s.io/aws-load-balancer-controller/pkg/version" - "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" "github.com/aws/aws-sdk-go-v2/service/ec2" @@ -32,10 +27,12 @@ import ( aws_metrics "sigs.k8s.io/aws-load-balancer-controller/pkg/metrics/aws" ) -const userAgent = "elbv2.k8s.aws" +const ( + cacheTTLBufferTime = 30 * time.Second +) // NewCloud constructs new Cloud implementation. -func NewCloud(cfg CloudConfig, metricsCollector *aws_metrics.Collector, logger logr.Logger, awsClientsProvider provider.AWSClientsProvider) (services.Cloud, error) { +func NewCloud(cfg CloudConfig, clusterName string, metricsCollector *aws_metrics.Collector, logger logr.Logger, awsClientsProvider provider.AWSClientsProvider) (services.Cloud, error) { hasIPv4 := true addrs, err := net.InterfaceAddrs() if err == nil { @@ -76,29 +73,11 @@ func NewCloud(cfg CloudConfig, metricsCollector *aws_metrics.Collector, logger l } cfg.Region = region } - awsConfig, err := config.LoadDefaultConfig(context.TODO(), - config.WithRegion(cfg.Region), - config.WithRetryer(func() aws.Retryer { - return retry.NewStandard(func(o *retry.StandardOptions) { - o.RateLimiter = ratelimit.None - o.MaxAttempts = cfg.MaxRetries - }) - }), - config.WithEC2IMDSEndpointMode(ec2IMDSEndpointMode), - config.WithAPIOptions([]func(stack *smithymiddleware.Stack) error{ - awsmiddleware.AddUserAgentKeyValue(userAgent, version.GitVersion), - }), - ) - - if cfg.ThrottleConfig != nil { - throttler := throttle.NewThrottler(cfg.ThrottleConfig) - awsConfig.APIOptions = append(awsConfig.APIOptions, func(stack *smithymiddleware.Stack) error { - return throttle.WithSDKRequestThrottleMiddleware(throttler)(stack) - }) - } - if metricsCollector != nil { - awsConfig.APIOptions = aws_metrics.WithSDKMetricCollector(metricsCollector, awsConfig.APIOptions) + awsConfigGenerator := NewAWSConfigGenerator(cfg, ec2IMDSEndpointMode, metricsCollector) + awsConfig, err := awsConfigGenerator.GenerateAWSConfig() + if err != nil { + return nil, errors.Wrap(err, "Unable to generate AWS config") } if awsClientsProvider == nil { @@ -119,6 +98,7 @@ func NewCloud(cfg CloudConfig, metricsCollector *aws_metrics.Collector, logger l thisObj := &defaultCloud{ cfg: cfg, + clusterName: clusterName, ec2: ec2Service, acm: services.NewACM(awsClientsProvider), wafv2: services.NewWAFv2(awsClientsProvider), @@ -126,7 +106,10 @@ func NewCloud(cfg CloudConfig, metricsCollector *aws_metrics.Collector, logger l shield: services.NewShield(awsClientsProvider), rgt: services.NewRGT(awsClientsProvider), - assumeRoleElbV2: make(map[string]services.ELBV2), + awsConfigGenerator: awsConfigGenerator, + + assumeRoleElbV2Cache: cache.NewExpiring(), + awsClientsProvider: awsClientsProvider, logger: logger, } @@ -220,77 +203,64 @@ type defaultCloud struct { shield services.Shield rgt services.RGT - assumeRoleElbV2 map[string]services.ELBV2 + clusterName string + + awsConfigGenerator AWSConfigGenerator + + // A cache holding elbv2 clients that are assuming a role. + assumeRoleElbV2Cache *cache.Expiring + // assumeRoleElbV2CacheMutex protects assumeRoleElbV2Cache + assumeRoleElbV2CacheMutex sync.RWMutex + awsClientsProvider provider.AWSClientsProvider logger logr.Logger } -// returns ELBV2 client for the given assumeRoleArn, or the default ELBV2 client if assumeRoleArn is empty -func (c *defaultCloud) GetAssumedRoleELBV2(ctx context.Context, assumeRoleArn string, externalId string) services.ELBV2 { - +// GetAssumedRoleELBV2 returns ELBV2 client for the given assumeRoleArn, or the default ELBV2 client if assumeRoleArn is empty +func (c *defaultCloud) GetAssumedRoleELBV2(ctx context.Context, assumeRoleArn string, externalId string) (services.ELBV2, error) { if assumeRoleArn == "" { - return c.elbv2 + return c.elbv2, nil } - assumedRoleELBV2, exists := c.assumeRoleElbV2[assumeRoleArn] + c.assumeRoleElbV2CacheMutex.RLock() + assumedRoleELBV2, exists := c.assumeRoleElbV2Cache.Get(assumeRoleArn) + c.assumeRoleElbV2CacheMutex.RUnlock() + if exists { - return assumedRoleELBV2 + return assumedRoleELBV2.(services.ELBV2), nil } - c.logger.Info("awsCloud", "method", "GetAssumedRoleELBV2", "AssumeRoleArn", assumeRoleArn, "externalId", externalId) + c.logger.Info("Constructing new elbv2 client", "AssumeRoleArn", assumeRoleArn, "externalId", externalId) - //////////////// - existingAwsConfig, _ := c.awsClientsProvider.GetAWSConfig(ctx, "GetAWSConfigForIAMRoleImpersonation") + stsClient, err := c.awsClientsProvider.GetSTSClient(ctx, "AssumeRole") + if err != nil { + // This should never happen, but let's be forward-looking. + return nil, err + } - sourceAccount := sts.NewFromConfig(*existingAwsConfig) - response, err := sourceAccount.AssumeRole(ctx, &sts.AssumeRoleInput{ + response, err := stsClient.AssumeRole(ctx, &sts.AssumeRoleInput{ RoleArn: aws.String(assumeRoleArn), - RoleSessionName: aws.String("aws-load-balancer-controller"), + RoleSessionName: aws.String(generateAssumeRoleSessionName(c.clusterName)), ExternalId: aws.String(externalId), }) if err != nil { - log.Fatalf("Unable to assume target role, %v. Attempting to use default client", err) - return c.elbv2 + c.logger.Error(err, "Unable to assume target role", "roleArn", assumeRoleArn) + return nil, err } assumedRoleCreds := response.Credentials newCreds := credentials.NewStaticCredentialsProvider(*assumedRoleCreds.AccessKeyId, *assumedRoleCreds.SecretAccessKey, *assumedRoleCreds.SessionToken) - newAwsConfig, err := config.LoadDefaultConfig(ctx, config.WithRegion(c.cfg.Region), config.WithCredentialsProvider(newCreds)) + newAwsConfig, err := c.awsConfigGenerator.GenerateAWSConfig(config.WithCredentialsProvider(newCreds)) if err != nil { - log.Fatalf("Unable to load static credentials for service client config, %v. Attempting to use default client", err) - return c.elbv2 + c.logger.Error(err, "Create new service client config service client config", "roleArn", assumeRoleArn) + return nil, err } - existingAwsConfig.Credentials = newAwsConfig.Credentials // response.Credentials - - // // var assumedRoleCreds *stsTypes.Credentials = response.Credentials - - // // Create config with target service client, using assumed role - // cfg, err = config.LoadDefaultConfig(ctx, config.WithRegion(region), config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(*assumedRoleCreds.AccessKeyId, *assumedRoleCreds.SecretAccessKey, *assumedRoleCreds.SessionToken))) - // if err != nil { - // log.Fatalf("unable to load static credentials for service client config, %v", err) - // } - - // //////////////// - // appCreds := stscreds.NewAssumeRoleProvider(client, assumeRoleArn) - // value, err := appCreds.Retrieve(context.TODO()) - // if err != nil { - // // handle error - // } - // ///////// - - // ///////////// OLD - // creds := stscreds.NewCredentials(c.session, assumeRoleArn, func(p *stscreds.AssumeRoleProvider) { - // p.ExternalID = &externalId - // }) - // ////////////// - - // c.awsConfig.Credentials = creds - // // newObj := services.NewELBV2(c.session, c, c.awsCFG) - // newObj := services.NewELBV2(*c.awsConfig, c.endpointsResolver, c) - - newObj := services.NewELBV2(c.awsClientsProvider, c) - c.assumeRoleElbV2[assumeRoleArn] = newObj + cacheTTL := assumedRoleCreds.Expiration.Sub(time.Now()) + elbv2WithAssumedRole := services.NewELBV2FromStaticClient(c.awsClientsProvider.GenerateNewELBv2Client(newAwsConfig), c) - return newObj + c.assumeRoleElbV2CacheMutex.Lock() + defer c.assumeRoleElbV2CacheMutex.Unlock() + c.assumeRoleElbV2Cache.Set(assumeRoleArn, elbv2WithAssumedRole, cacheTTL-cacheTTLBufferTime) + return elbv2WithAssumedRole, nil } func (c *defaultCloud) EC2() services.EC2 { diff --git a/pkg/aws/cloud_util.go b/pkg/aws/cloud_util.go new file mode 100644 index 0000000000..6da13b8f0b --- /dev/null +++ b/pkg/aws/cloud_util.go @@ -0,0 +1,25 @@ +package aws + +import ( + "fmt" + "regexp" +) + +const ( + sessionNamePrefix = "AWS-LBC-" + maxSessionNameLength = 2047 +) + +var illegalValuesInSessionName = regexp.MustCompile(`[^a-zA-Z0-9=,.@\-_]+`) + +func generateAssumeRoleSessionName(clusterName string) string { + safeClusterName := illegalValuesInSessionName.ReplaceAllString(clusterName, "") + + sessionName := fmt.Sprintf("%s%s", sessionNamePrefix, safeClusterName) + + if len(sessionName) > maxSessionNameLength { + return sessionName[:maxSessionNameLength] + } + + return sessionName +} diff --git a/pkg/aws/cloud_util_test.go b/pkg/aws/cloud_util_test.go new file mode 100644 index 0000000000..83e37b612b --- /dev/null +++ b/pkg/aws/cloud_util_test.go @@ -0,0 +1,42 @@ +package aws + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestUpdateTrackedTargets(t *testing.T) { + testCases := []struct { + name string + clusterName string + expectedSessionName string + }{ + { + name: "no mods", + clusterName: "my-cluster-name", + expectedSessionName: "AWS-LBC-my-cluster-name", + }, + { + name: "mix lower and upper case", + clusterName: "My-ClUsTeR-name", + expectedSessionName: "AWS-LBC-My-ClUsTeR-name", + }, + { + name: "with legal characters", + clusterName: "my_cluster-name=foo,something@here.", + expectedSessionName: "AWS-LBC-my_cluster-name=foo,something@here.", + }, + { + name: "with illegal characters", + clusterName: "my&*&*cluster()!(&name", + expectedSessionName: "AWS-LBC-myclustername", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := generateAssumeRoleSessionName(tc.clusterName) + assert.Equal(t, tc.expectedSessionName, result) + }) + } +} diff --git a/pkg/aws/provider/default_aws_clients_provider.go b/pkg/aws/provider/default_aws_clients_provider.go index 41cd780554..1d1a2b713e 100644 --- a/pkg/aws/provider/default_aws_clients_provider.go +++ b/pkg/aws/provider/default_aws_clients_provider.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" "github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi" "github.com/aws/aws-sdk-go-v2/service/shield" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/aws-sdk-go-v2/service/wafregional" "github.com/aws/aws-sdk-go-v2/service/wafv2" "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints" @@ -21,11 +22,13 @@ type defaultAWSClientsProvider struct { wafRegionClient *wafregional.Client shieldClient *shield.Client rgtClient *resourcegroupstaggingapi.Client + stsClient *sts.Client - awsConfig *aws.Config + // used for dynamic creation of ELBv2 client + elbv2CustomEndpoint *string } -func NewDefaultAWSClientsProvider(cfg aws.Config, endpointsResolver *endpoints.Resolver) (*defaultAWSClientsProvider, error) { +func NewDefaultAWSClientsProvider(cfg aws.Config, endpointsResolver *endpoints.Resolver) (AWSClientsProvider, error) { ec2CustomEndpoint := endpointsResolver.EndpointFor(ec2.ServiceID) elbv2CustomEndpoint := endpointsResolver.EndpointFor(elasticloadbalancingv2.ServiceID) acmCustomEndpoint := endpointsResolver.EndpointFor(acm.ServiceID) @@ -33,17 +36,16 @@ func NewDefaultAWSClientsProvider(cfg aws.Config, endpointsResolver *endpoints.R wafregionalCustomEndpoint := endpointsResolver.EndpointFor(wafregional.ServiceID) shieldCustomEndpoint := endpointsResolver.EndpointFor(shield.ServiceID) rgtCustomEndpoint := endpointsResolver.EndpointFor(resourcegroupstaggingapi.ServiceID) + stsCustomEndpoint := endpointsResolver.EndpointFor(sts.ServiceID) ec2Client := ec2.NewFromConfig(cfg, func(o *ec2.Options) { if ec2CustomEndpoint != nil { o.BaseEndpoint = ec2CustomEndpoint } }) - elbv2Client := elasticloadbalancingv2.NewFromConfig(cfg, func(o *elasticloadbalancingv2.Options) { - if elbv2CustomEndpoint != nil { - o.BaseEndpoint = elbv2CustomEndpoint - } - }) + + elbv2Client := generateNewELBv2ClientHelper(cfg, elbv2CustomEndpoint) + acmClient := acm.NewFromConfig(cfg, func(o *acm.Options) { if acmCustomEndpoint != nil { o.BaseEndpoint = acmCustomEndpoint @@ -68,6 +70,12 @@ func NewDefaultAWSClientsProvider(cfg aws.Config, endpointsResolver *endpoints.R } }) + stsClient := sts.NewFromConfig(cfg, func(o *sts.Options) { + if stsCustomEndpoint != nil { + o.BaseEndpoint = stsCustomEndpoint + } + }) + return &defaultAWSClientsProvider{ ec2Client: ec2Client, elbv2Client: elbv2Client, @@ -76,8 +84,9 @@ func NewDefaultAWSClientsProvider(cfg aws.Config, endpointsResolver *endpoints.R wafRegionClient: wafregionalClient, shieldClient: shieldClient, rgtClient: rgtClient, + stsClient: stsClient, - awsConfig: &cfg, + elbv2CustomEndpoint: elbv2CustomEndpoint, }, nil } @@ -112,6 +121,18 @@ func (p *defaultAWSClientsProvider) GetRGTClient(ctx context.Context, operationN return p.rgtClient, nil } -func (p *defaultAWSClientsProvider) GetAWSConfig(ctx context.Context, operationName string) (*aws.Config, error) { - return p.awsConfig, nil +func (p *defaultAWSClientsProvider) GetSTSClient(ctx context.Context, operationName string) (*sts.Client, error) { + return p.stsClient, nil +} + +func (p *defaultAWSClientsProvider) GenerateNewELBv2Client(cfg aws.Config) *elasticloadbalancingv2.Client { + return generateNewELBv2ClientHelper(cfg, p.elbv2CustomEndpoint) +} + +func generateNewELBv2ClientHelper(cfg aws.Config, elbv2CustomEndpoint *string) *elasticloadbalancingv2.Client { + return elasticloadbalancingv2.NewFromConfig(cfg, func(o *elasticloadbalancingv2.Options) { + if elbv2CustomEndpoint != nil { + o.BaseEndpoint = elbv2CustomEndpoint + } + }) } diff --git a/pkg/aws/provider/provider.go b/pkg/aws/provider/provider.go index 2cdff45745..66bb168286 100644 --- a/pkg/aws/provider/provider.go +++ b/pkg/aws/provider/provider.go @@ -8,6 +8,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" "github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi" "github.com/aws/aws-sdk-go-v2/service/shield" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/aws-sdk-go-v2/service/wafregional" "github.com/aws/aws-sdk-go-v2/service/wafv2" ) @@ -20,6 +21,6 @@ type AWSClientsProvider interface { GetWAFRegionClient(ctx context.Context, operationName string) (*wafregional.Client, error) GetShieldClient(ctx context.Context, operationName string) (*shield.Client, error) GetRGTClient(ctx context.Context, operationName string) (*resourcegroupstaggingapi.Client, error) - - GetAWSConfig(ctx context.Context, operationName string) (*aws.Config, error) + GetSTSClient(ctx context.Context, operationName string) (*sts.Client, error) + GenerateNewELBv2Client(cfg aws.Config) *elasticloadbalancingv2.Client } diff --git a/pkg/aws/services/cloudInterface.go b/pkg/aws/services/cloudInterface.go index 3e0ff558ec..8b11eaeb16 100644 --- a/pkg/aws/services/cloudInterface.go +++ b/pkg/aws/services/cloudInterface.go @@ -30,5 +30,5 @@ type Cloud interface { // VpcID for the LoadBalancer resources. VpcID() string - GetAssumedRoleELBV2(ctx context.Context, assumeRoleArn string, externalId string) ELBV2 + GetAssumedRoleELBV2(ctx context.Context, assumeRoleArn string, externalId string) (ELBV2, error) } diff --git a/pkg/aws/services/elbv2.go b/pkg/aws/services/elbv2.go index 877cdd5f3a..293f61f3b2 100644 --- a/pkg/aws/services/elbv2.go +++ b/pkg/aws/services/elbv2.go @@ -60,7 +60,7 @@ type ELBV2 interface { ModifyListenerAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyListenerAttributesInput) (*elasticloadbalancingv2.ModifyListenerAttributesOutput, error) ModifyCapacityReservationWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyCapacityReservationInput) (*elasticloadbalancingv2.ModifyCapacityReservationOutput, error) DescribeCapacityReservationWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeCapacityReservationInput) (*elasticloadbalancingv2.DescribeCapacityReservationOutput, error) - AssumeRole(ctx context.Context, assumeRoleArn string, externalId string) ELBV2 + AssumeRole(ctx context.Context, assumeRoleArn string, externalId string) (ELBV2, error) } func NewELBV2(awsClientsProvider provider.AWSClientsProvider, cloud Cloud) ELBV2 { @@ -70,21 +70,29 @@ func NewELBV2(awsClientsProvider provider.AWSClientsProvider, cloud Cloud) ELBV2 } } +func NewELBV2FromStaticClient(staticELBClient *elasticloadbalancingv2.Client, cloud Cloud) ELBV2 { + return &elbv2Client{ + staticELBClient: staticELBClient, + cloud: cloud, + } +} + // default implementation for ELBV2. type elbv2Client struct { awsClientsProvider provider.AWSClientsProvider + staticELBClient *elasticloadbalancingv2.Client cloud Cloud } -func (c *elbv2Client) AssumeRole(ctx context.Context, assumeRoleArn string, externalId string) ELBV2 { +func (c *elbv2Client) AssumeRole(ctx context.Context, assumeRoleArn string, externalId string) (ELBV2, error) { if assumeRoleArn == "" { - return c + return c, nil } return c.cloud.GetAssumedRoleELBV2(ctx, assumeRoleArn, externalId) } func (c *elbv2Client) AddListenerCertificatesWithContext(ctx context.Context, input *elasticloadbalancingv2.AddListenerCertificatesInput) (*elasticloadbalancingv2.AddListenerCertificatesOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "AddListenerCertificates") + client, err := c.getClient(ctx, "AddListenerCertificates") if err != nil { return nil, err } @@ -92,7 +100,7 @@ func (c *elbv2Client) AddListenerCertificatesWithContext(ctx context.Context, in } func (c *elbv2Client) RemoveListenerCertificatesWithContext(ctx context.Context, input *elasticloadbalancingv2.RemoveListenerCertificatesInput) (*elasticloadbalancingv2.RemoveListenerCertificatesOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "RemoveListenerCertificates") + client, err := c.getClient(ctx, "RemoveListenerCertificates") if err != nil { return nil, err } @@ -100,7 +108,7 @@ func (c *elbv2Client) RemoveListenerCertificatesWithContext(ctx context.Context, } func (c *elbv2Client) DescribeListenersWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeListenersInput) (*elasticloadbalancingv2.DescribeListenersOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeListeners") + client, err := c.getClient(ctx, "DescribeListeners") if err != nil { return nil, err } @@ -108,7 +116,7 @@ func (c *elbv2Client) DescribeListenersWithContext(ctx context.Context, input *e } func (c *elbv2Client) DescribeRulesWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeRulesInput) (*elasticloadbalancingv2.DescribeRulesOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeRules") + client, err := c.getClient(ctx, "DescribeRules") if err != nil { return nil, err } @@ -116,7 +124,7 @@ func (c *elbv2Client) DescribeRulesWithContext(ctx context.Context, input *elast } func (c *elbv2Client) RegisterTargetsWithContext(ctx context.Context, input *elasticloadbalancingv2.RegisterTargetsInput) (*elasticloadbalancingv2.RegisterTargetsOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "RegisterTargets") + client, err := c.getClient(ctx, "RegisterTargets") if err != nil { return nil, err } @@ -124,7 +132,7 @@ func (c *elbv2Client) RegisterTargetsWithContext(ctx context.Context, input *ela } func (c *elbv2Client) DeregisterTargetsWithContext(ctx context.Context, input *elasticloadbalancingv2.DeregisterTargetsInput) (*elasticloadbalancingv2.DeregisterTargetsOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DeregisterTargets") + client, err := c.getClient(ctx, "DeregisterTargets") if err != nil { return nil, err } @@ -132,7 +140,7 @@ func (c *elbv2Client) DeregisterTargetsWithContext(ctx context.Context, input *e } func (c *elbv2Client) DescribeTrustStoresWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeTrustStoresInput) (*elasticloadbalancingv2.DescribeTrustStoresOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeTrustStores") + client, err := c.getClient(ctx, "DescribeTrustStores") if err != nil { return nil, err } @@ -140,7 +148,7 @@ func (c *elbv2Client) DescribeTrustStoresWithContext(ctx context.Context, input } func (c *elbv2Client) ModifyRuleWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyRuleInput) (*elasticloadbalancingv2.ModifyRuleOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "ModifyRule") + client, err := c.getClient(ctx, "ModifyRule") if err != nil { return nil, err } @@ -148,7 +156,7 @@ func (c *elbv2Client) ModifyRuleWithContext(ctx context.Context, input *elasticl } func (c *elbv2Client) DeleteRuleWithContext(ctx context.Context, input *elasticloadbalancingv2.DeleteRuleInput) (*elasticloadbalancingv2.DeleteRuleOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DeleteRule") + client, err := c.getClient(ctx, "DeleteRule") if err != nil { return nil, err } @@ -156,7 +164,7 @@ func (c *elbv2Client) DeleteRuleWithContext(ctx context.Context, input *elasticl } func (c *elbv2Client) CreateRuleWithContext(ctx context.Context, input *elasticloadbalancingv2.CreateRuleInput) (*elasticloadbalancingv2.CreateRuleOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "CreateRule") + client, err := c.getClient(ctx, "CreateRule") if err != nil { return nil, err } @@ -164,7 +172,7 @@ func (c *elbv2Client) CreateRuleWithContext(ctx context.Context, input *elasticl } func (c *elbv2Client) WaitUntilLoadBalancerAvailableWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeLoadBalancersInput) error { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeLoadBalancers") + client, err := c.getClient(ctx, "DescribeLoadBalancers") if err != nil { return err } @@ -174,7 +182,7 @@ func (c *elbv2Client) WaitUntilLoadBalancerAvailableWithContext(ctx context.Cont } func (c *elbv2Client) DescribeLoadBalancersWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeLoadBalancersInput) (*elasticloadbalancingv2.DescribeLoadBalancersOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeLoadBalancers") + client, err := c.getClient(ctx, "DescribeLoadBalancers") if err != nil { return nil, err } @@ -182,7 +190,7 @@ func (c *elbv2Client) DescribeLoadBalancersWithContext(ctx context.Context, inpu } func (c *elbv2Client) DescribeTargetHealthWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeTargetHealthInput) (*elasticloadbalancingv2.DescribeTargetHealthOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeTargetHealth") + client, err := c.getClient(ctx, "DescribeTargetHealth") if err != nil { return nil, err } @@ -190,7 +198,7 @@ func (c *elbv2Client) DescribeTargetHealthWithContext(ctx context.Context, input } func (c *elbv2Client) DescribeTargetGroupsWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeTargetGroupsInput) (*elasticloadbalancingv2.DescribeTargetGroupsOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeTargetGroups") + client, err := c.getClient(ctx, "DescribeTargetGroups") if err != nil { return nil, err } @@ -198,7 +206,7 @@ func (c *elbv2Client) DescribeTargetGroupsWithContext(ctx context.Context, input } func (c *elbv2Client) DeleteTargetGroupWithContext(ctx context.Context, input *elasticloadbalancingv2.DeleteTargetGroupInput) (*elasticloadbalancingv2.DeleteTargetGroupOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DeleteTargetGroup") + client, err := c.getClient(ctx, "DeleteTargetGroup") if err != nil { return nil, err } @@ -206,7 +214,7 @@ func (c *elbv2Client) DeleteTargetGroupWithContext(ctx context.Context, input *e } func (c *elbv2Client) ModifyTargetGroupWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyTargetGroupInput) (*elasticloadbalancingv2.ModifyTargetGroupOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "ModifyTargetGroup") + client, err := c.getClient(ctx, "ModifyTargetGroup") if err != nil { return nil, err } @@ -214,7 +222,7 @@ func (c *elbv2Client) ModifyTargetGroupWithContext(ctx context.Context, input *e } func (c *elbv2Client) CreateTargetGroupWithContext(ctx context.Context, input *elasticloadbalancingv2.CreateTargetGroupInput) (*elasticloadbalancingv2.CreateTargetGroupOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "CreateTargetGroup") + client, err := c.getClient(ctx, "CreateTargetGroup") if err != nil { return nil, err } @@ -222,7 +230,7 @@ func (c *elbv2Client) CreateTargetGroupWithContext(ctx context.Context, input *e } func (c *elbv2Client) DescribeTargetGroupAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeTargetGroupAttributesInput) (*elasticloadbalancingv2.DescribeTargetGroupAttributesOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeTargetGroupAttributes") + client, err := c.getClient(ctx, "DescribeTargetGroupAttributes") if err != nil { return nil, err } @@ -230,7 +238,7 @@ func (c *elbv2Client) DescribeTargetGroupAttributesWithContext(ctx context.Conte } func (c *elbv2Client) ModifyTargetGroupAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyTargetGroupAttributesInput) (*elasticloadbalancingv2.ModifyTargetGroupAttributesOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "ModifyTargetGroupAttributes") + client, err := c.getClient(ctx, "ModifyTargetGroupAttributes") if err != nil { return nil, err } @@ -238,7 +246,7 @@ func (c *elbv2Client) ModifyTargetGroupAttributesWithContext(ctx context.Context } func (c *elbv2Client) SetSecurityGroupsWithContext(ctx context.Context, input *elasticloadbalancingv2.SetSecurityGroupsInput) (*elasticloadbalancingv2.SetSecurityGroupsOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "SetSecurityGroups") + client, err := c.getClient(ctx, "SetSecurityGroups") if err != nil { return nil, err } @@ -246,7 +254,7 @@ func (c *elbv2Client) SetSecurityGroupsWithContext(ctx context.Context, input *e } func (c *elbv2Client) SetSubnetsWithContext(ctx context.Context, input *elasticloadbalancingv2.SetSubnetsInput) (*elasticloadbalancingv2.SetSubnetsOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "SetSubnets") + client, err := c.getClient(ctx, "SetSubnets") if err != nil { return nil, err } @@ -254,7 +262,7 @@ func (c *elbv2Client) SetSubnetsWithContext(ctx context.Context, input *elasticl } func (c *elbv2Client) SetIpAddressTypeWithContext(ctx context.Context, input *elasticloadbalancingv2.SetIpAddressTypeInput) (*elasticloadbalancingv2.SetIpAddressTypeOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "SetIpAddressType") + client, err := c.getClient(ctx, "SetIpAddressType") if err != nil { return nil, err } @@ -262,7 +270,7 @@ func (c *elbv2Client) SetIpAddressTypeWithContext(ctx context.Context, input *el } func (c *elbv2Client) DeleteLoadBalancerWithContext(ctx context.Context, input *elasticloadbalancingv2.DeleteLoadBalancerInput) (*elasticloadbalancingv2.DeleteLoadBalancerOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DeleteLoadBalancer") + client, err := c.getClient(ctx, "DeleteLoadBalancer") if err != nil { return nil, err } @@ -270,7 +278,7 @@ func (c *elbv2Client) DeleteLoadBalancerWithContext(ctx context.Context, input * } func (c *elbv2Client) CreateLoadBalancerWithContext(ctx context.Context, input *elasticloadbalancingv2.CreateLoadBalancerInput) (*elasticloadbalancingv2.CreateLoadBalancerOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "CreateLoadBalancer") + client, err := c.getClient(ctx, "CreateLoadBalancer") if err != nil { return nil, err } @@ -278,7 +286,7 @@ func (c *elbv2Client) CreateLoadBalancerWithContext(ctx context.Context, input * } func (c *elbv2Client) DescribeLoadBalancerAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeLoadBalancerAttributesInput) (*elasticloadbalancingv2.DescribeLoadBalancerAttributesOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeLoadBalancerAttributes") + client, err := c.getClient(ctx, "DescribeLoadBalancerAttributes") if err != nil { return nil, err } @@ -286,7 +294,7 @@ func (c *elbv2Client) DescribeLoadBalancerAttributesWithContext(ctx context.Cont } func (c *elbv2Client) ModifyLoadBalancerAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyLoadBalancerAttributesInput) (*elasticloadbalancingv2.ModifyLoadBalancerAttributesOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "ModifyLoadBalancerAttributes") + client, err := c.getClient(ctx, "ModifyLoadBalancerAttributes") if err != nil { return nil, err } @@ -294,7 +302,7 @@ func (c *elbv2Client) ModifyLoadBalancerAttributesWithContext(ctx context.Contex } func (c *elbv2Client) ModifyListenerWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyListenerInput) (*elasticloadbalancingv2.ModifyListenerOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "ModifyListener") + client, err := c.getClient(ctx, "ModifyListener") if err != nil { return nil, err } @@ -302,7 +310,7 @@ func (c *elbv2Client) ModifyListenerWithContext(ctx context.Context, input *elas } func (c *elbv2Client) DeleteListenerWithContext(ctx context.Context, input *elasticloadbalancingv2.DeleteListenerInput) (*elasticloadbalancingv2.DeleteListenerOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DeleteListener") + client, err := c.getClient(ctx, "DeleteListener") if err != nil { return nil, err } @@ -310,7 +318,7 @@ func (c *elbv2Client) DeleteListenerWithContext(ctx context.Context, input *elas } func (c *elbv2Client) CreateListenerWithContext(ctx context.Context, input *elasticloadbalancingv2.CreateListenerInput) (*elasticloadbalancingv2.CreateListenerOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "CreateListener") + client, err := c.getClient(ctx, "CreateListener") if err != nil { return nil, err } @@ -318,7 +326,7 @@ func (c *elbv2Client) CreateListenerWithContext(ctx context.Context, input *elas } func (c *elbv2Client) DescribeTagsWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeTagsInput) (*elasticloadbalancingv2.DescribeTagsOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeTags") + client, err := c.getClient(ctx, "DescribeTags") if err != nil { return nil, err } @@ -326,7 +334,7 @@ func (c *elbv2Client) DescribeTagsWithContext(ctx context.Context, input *elasti } func (c *elbv2Client) AddTagsWithContext(ctx context.Context, input *elasticloadbalancingv2.AddTagsInput) (*elasticloadbalancingv2.AddTagsOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "AddTags") + client, err := c.getClient(ctx, "AddTags") if err != nil { return nil, err } @@ -334,7 +342,7 @@ func (c *elbv2Client) AddTagsWithContext(ctx context.Context, input *elasticload } func (c *elbv2Client) RemoveTagsWithContext(ctx context.Context, input *elasticloadbalancingv2.RemoveTagsInput) (*elasticloadbalancingv2.RemoveTagsOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "RemoveTags") + client, err := c.getClient(ctx, "RemoveTags") if err != nil { return nil, err } @@ -345,7 +353,7 @@ func (c *elbv2Client) DescribeLoadBalancersAsList(ctx context.Context, input *el var result []types.LoadBalancer var client *elasticloadbalancingv2.Client var err error - client, err = c.awsClientsProvider.GetELBv2Client(ctx, "DescribeLoadBalancers") + client, err = c.getClient(ctx, "DescribeLoadBalancers") if err != nil { return nil, err } @@ -364,7 +372,7 @@ func (c *elbv2Client) DescribeTargetGroupsAsList(ctx context.Context, input *ela var result []types.TargetGroup var client *elasticloadbalancingv2.Client var err error - client, err = c.awsClientsProvider.GetELBv2Client(ctx, "DescribeTargetGroups") + client, err = c.getClient(ctx, "DescribeTargetGroups") if err != nil { return nil, err } @@ -383,7 +391,7 @@ func (c *elbv2Client) DescribeListenersAsList(ctx context.Context, input *elasti var result []types.Listener var client *elasticloadbalancingv2.Client var err error - client, err = c.awsClientsProvider.GetELBv2Client(ctx, "DescribeListeners") + client, err = c.getClient(ctx, "DescribeListeners") if err != nil { return nil, err } @@ -402,7 +410,7 @@ func (c *elbv2Client) DescribeListenerCertificatesAsList(ctx context.Context, in var result []types.Certificate var client *elasticloadbalancingv2.Client var err error - client, err = c.awsClientsProvider.GetELBv2Client(ctx, "DescribeListenerCertificates") + client, err = c.getClient(ctx, "DescribeListenerCertificates") if err != nil { return nil, err } @@ -421,7 +429,7 @@ func (c *elbv2Client) DescribeRulesAsList(ctx context.Context, input *elasticloa var result []types.Rule var client *elasticloadbalancingv2.Client var err error - client, err = c.awsClientsProvider.GetELBv2Client(ctx, "DescribeRules") + client, err = c.getClient(ctx, "DescribeRules") if err != nil { return nil, err } @@ -437,7 +445,7 @@ func (c *elbv2Client) DescribeRulesAsList(ctx context.Context, input *elasticloa } func (c *elbv2Client) DescribeListenerAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeListenerAttributesInput) (*elasticloadbalancingv2.DescribeListenerAttributesOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeListenerAttributes") + client, err := c.getClient(ctx, "DescribeListenerAttributes") if err != nil { return nil, err } @@ -445,7 +453,7 @@ func (c *elbv2Client) DescribeListenerAttributesWithContext(ctx context.Context, } func (c *elbv2Client) ModifyListenerAttributesWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyListenerAttributesInput) (*elasticloadbalancingv2.ModifyListenerAttributesOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "ModifyListenerAttributes") + client, err := c.getClient(ctx, "ModifyListenerAttributes") if err != nil { return nil, err } @@ -453,7 +461,7 @@ func (c *elbv2Client) ModifyListenerAttributesWithContext(ctx context.Context, i } func (c *elbv2Client) ModifyCapacityReservationWithContext(ctx context.Context, input *elasticloadbalancingv2.ModifyCapacityReservationInput) (*elasticloadbalancingv2.ModifyCapacityReservationOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "ModifyCapacityReservation") + client, err := c.getClient(ctx, "ModifyCapacityReservation") if err != nil { return nil, err } @@ -461,9 +469,16 @@ func (c *elbv2Client) ModifyCapacityReservationWithContext(ctx context.Context, } func (c *elbv2Client) DescribeCapacityReservationWithContext(ctx context.Context, input *elasticloadbalancingv2.DescribeCapacityReservationInput) (*elasticloadbalancingv2.DescribeCapacityReservationOutput, error) { - client, err := c.awsClientsProvider.GetELBv2Client(ctx, "DescribeCapacityReservation") + client, err := c.getClient(ctx, "DescribeCapacityReservation") if err != nil { return nil, err } return client.DescribeCapacityReservation(ctx, input) } + +func (c *elbv2Client) getClient(ctx context.Context, operation string) (*elasticloadbalancingv2.Client, error) { + if c.staticELBClient != nil { + return c.staticELBClient, nil + } + return c.awsClientsProvider.GetELBv2Client(ctx, operation) +} diff --git a/pkg/aws/services/elbv2_mocks.go b/pkg/aws/services/elbv2_mocks.go index 8d5b540a7e..1f427dd566 100644 --- a/pkg/aws/services/elbv2_mocks.go +++ b/pkg/aws/services/elbv2_mocks.go @@ -67,11 +67,12 @@ func (mr *MockELBV2MockRecorder) AddTagsWithContext(arg0, arg1 interface{}) *gom } // AssumeRole mocks base method. -func (m *MockELBV2) AssumeRole(arg0 context.Context, arg1, arg2 string) ELBV2 { +func (m *MockELBV2) AssumeRole(arg0 context.Context, arg1, arg2 string) (ELBV2, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AssumeRole", arg0, arg1, arg2) ret0, _ := ret[0].(ELBV2) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // AssumeRole indicates an expected call of AssumeRole. diff --git a/pkg/targetgroupbinding/resource_manager.go b/pkg/targetgroupbinding/resource_manager.go index 666d25b1c4..be3dfd81d2 100644 --- a/pkg/targetgroupbinding/resource_manager.go +++ b/pkg/targetgroupbinding/resource_manager.go @@ -96,7 +96,6 @@ func (m *defaultResourceManager) Reconcile(ctx context.Context, tgb *elbv2api.Ta var oldCheckPoint string var isDeferred bool var err error - AnnotationsToFields(tgb) if *tgb.Spec.TargetType == elbv2api.TargetTypeIP { newCheckPoint, oldCheckPoint, isDeferred, err = m.reconcileWithIPTargetType(ctx, tgb) @@ -116,7 +115,6 @@ func (m *defaultResourceManager) Reconcile(ctx context.Context, tgb *elbv2api.Ta } func (m *defaultResourceManager) Cleanup(ctx context.Context, tgb *elbv2api.TargetGroupBinding) error { - AnnotationsToFields(tgb) if err := m.cleanupTargets(ctx, tgb); err != nil { return err } @@ -546,42 +544,45 @@ func (m *defaultResourceManager) deregisterTargets(ctx context.Context, tgb *elb func (m *defaultResourceManager) registerPodEndpoints(ctx context.Context, tgb *elbv2api.TargetGroupBinding, endpoints []backend.PodEndpoint) error { vpcID := m.vpcID - // Target group is in a different VPC from the cluster's VPC if tgb.Spec.VpcID != "" && tgb.Spec.VpcID != m.vpcID { vpcID = tgb.Spec.VpcID m.logger.Info(fmt.Sprintf( "registering endpoints using the targetGroup's vpcID %s which is different from the cluster's vpcID %s", tgb.Spec.VpcID, m.vpcID)) + } - if tgb.Spec.IamRoleArnToAssume != "" { - // since we need to assume a role for this TGB, - // it is from a different account - // so the packets will need to leave the VPC and therefore - // target.AvailabilityZone = awssdk.String("all") must be set - // or else nothing will work - sdkTargets := make([]elbv2types.TargetDescription, 0, len(endpoints)) - for _, endpoint := range endpoints { - target := elbv2types.TargetDescription{ - Id: awssdk.String(endpoint.IP), - Port: awssdk.Int32(endpoint.Port), - } - target.AvailabilityZone = awssdk.String("all") - sdkTargets = append(sdkTargets, target) - } - return m.targetsManager.RegisterTargets(ctx, tgb, sdkTargets) + var overrideAzFn func(addr netip.Addr) bool + if tgb.Spec.IamRoleArnToAssume != "" { + // If we're interacting with another account, then we should always be setting "all" AZ to allow this + // target to get registered by the ELB API. + overrideAzFn = func(_ netip.Addr) bool { + return true + } + } else { + vpcInfo, err := m.vpcInfoProvider.FetchVPCInfo(ctx, vpcID) + if err != nil { + return err + } + var vpcRawCIDRs []string + vpcRawCIDRs = append(vpcRawCIDRs, vpcInfo.AssociatedIPv4CIDRs()...) + vpcRawCIDRs = append(vpcRawCIDRs, vpcInfo.AssociatedIPv6CIDRs()...) + vpcCIDRs, err := networking.ParseCIDRs(vpcRawCIDRs) + if err != nil { + return err + } + // If the pod ip resides out of all the VPC CIDRs, then the only way to force the ELB API is to use "all" AZ. + overrideAzFn = func(addr netip.Addr) bool { + return !networking.IsIPWithinCIDRs(addr, vpcCIDRs) } } - vpcInfo, err := m.vpcInfoProvider.FetchVPCInfo(ctx, vpcID) - if err != nil { - return err - } - var vpcRawCIDRs []string - vpcRawCIDRs = append(vpcRawCIDRs, vpcInfo.AssociatedIPv4CIDRs()...) - vpcRawCIDRs = append(vpcRawCIDRs, vpcInfo.AssociatedIPv6CIDRs()...) - vpcCIDRs, err := networking.ParseCIDRs(vpcRawCIDRs) + + sdkTargets, err := m.prepareRegistrationCall(endpoints, overrideAzFn) if err != nil { return err } + return m.targetsManager.RegisterTargets(ctx, tgb, sdkTargets) +} +func (m *defaultResourceManager) prepareRegistrationCall(endpoints []backend.PodEndpoint, doAzOverride func(addr netip.Addr) bool) ([]elbv2types.TargetDescription, error) { sdkTargets := make([]elbv2types.TargetDescription, 0, len(endpoints)) for _, endpoint := range endpoints { target := elbv2types.TargetDescription{ @@ -590,14 +591,14 @@ func (m *defaultResourceManager) registerPodEndpoints(ctx context.Context, tgb * } podIP, err := netip.ParseAddr(endpoint.IP) if err != nil { - return err + return sdkTargets, err } - if !networking.IsIPWithinCIDRs(podIP, vpcCIDRs) { + if doAzOverride(podIP) { target.AvailabilityZone = awssdk.String("all") } sdkTargets = append(sdkTargets, target) } - return m.targetsManager.RegisterTargets(ctx, tgb, sdkTargets) + return sdkTargets, nil } func (m *defaultResourceManager) registerNodePortEndpoints(ctx context.Context, tgb *elbv2api.TargetGroupBinding, endpoints []backend.NodePortEndpoint) error { diff --git a/pkg/targetgroupbinding/targets_manager.go b/pkg/targetgroupbinding/targets_manager.go index b00f95a9b5..13e6117a80 100644 --- a/pkg/targetgroupbinding/targets_manager.go +++ b/pkg/targetgroupbinding/targets_manager.go @@ -89,7 +89,13 @@ func (m *cachedTargetsManager) RegisterTargets(ctx context.Context, tgb *elbv2ap m.logger.Info("registering targets", "arn", tgARN, "targets", targetsChunk) - _, err := m.elbv2Client.AssumeRole(ctx, tgb.Spec.IamRoleArnToAssume, tgb.Spec.AssumeRoleExternalId).RegisterTargetsWithContext(ctx, req) + + clientToUse, err := m.elbv2Client.AssumeRole(ctx, tgb.Spec.IamRoleArnToAssume, tgb.Spec.AssumeRoleExternalId) + if err != nil { + return err + } + + _, err = clientToUse.RegisterTargetsWithContext(ctx, req) if err != nil { return err } @@ -111,7 +117,11 @@ func (m *cachedTargetsManager) DeregisterTargets(ctx context.Context, tgb *elbv2 m.logger.Info("deRegistering targets", "arn", tgARN, "targets", targetsChunk) - _, err := m.elbv2Client.AssumeRole(ctx, tgb.Spec.IamRoleArnToAssume, tgb.Spec.AssumeRoleExternalId).DeregisterTargetsWithContext(ctx, req) + clientToUse, err := m.elbv2Client.AssumeRole(ctx, tgb.Spec.IamRoleArnToAssume, tgb.Spec.AssumeRoleExternalId) + if err != nil { + return err + } + _, err = clientToUse.DeregisterTargetsWithContext(ctx, req) if err != nil { return err } @@ -199,7 +209,11 @@ func (m *cachedTargetsManager) listTargetsFromAWS(ctx context.Context, tgb *elbv TargetGroupArn: aws.String(tgARN), Targets: pointerizeTargetDescriptions(targets), } - resp, err := m.elbv2Client.AssumeRole(ctx, tgb.Spec.IamRoleArnToAssume, tgb.Spec.AssumeRoleExternalId).DescribeTargetHealthWithContext(ctx, req) + clientToUse, err := m.elbv2Client.AssumeRole(ctx, tgb.Spec.IamRoleArnToAssume, tgb.Spec.AssumeRoleExternalId) + if err != nil { + return nil, err + } + resp, err := clientToUse.DescribeTargetHealthWithContext(ctx, req) if err != nil { return nil, err } diff --git a/pkg/targetgroupbinding/targets_manager_test.go b/pkg/targetgroupbinding/targets_manager_test.go index f4f476fe03..5b85a6b7c4 100644 --- a/pkg/targetgroupbinding/targets_manager_test.go +++ b/pkg/targetgroupbinding/targets_manager_test.go @@ -280,7 +280,7 @@ func Test_cachedTargetsManager_RegisterTargets(t *testing.T) { ctx := context.Background() for _, call := range tt.fields.registerTargetsWithContextCalls { elbv2Client.EXPECT().RegisterTargetsWithContext(gomock.Any(), call.req).Return(call.resp, call.err) - elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client) + elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client, nil) } targetsCache := cache.NewExpiring() @@ -526,7 +526,7 @@ func Test_cachedTargetsManager_DeregisterTargets(t *testing.T) { ctx := context.Background() for _, call := range tt.fields.deregisterTargetsWithContextCalls { elbv2Client.EXPECT().DeregisterTargetsWithContext(gomock.Any(), call.req).Return(call.resp, call.err) - elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client) + elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client, nil) } targetsCache := cache.NewExpiring() @@ -792,7 +792,7 @@ func Test_cachedTargetsManager_ListTargets(t *testing.T) { elbv2Client := services.NewMockELBV2(ctrl) for _, call := range tt.fields.describeTargetHealthWithContextCalls { elbv2Client.EXPECT().DescribeTargetHealthWithContext(gomock.Any(), call.req).Return(call.resp, call.err) - elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client) + elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client, nil) } targetsCache := cache.NewExpiring() targetsCacheTTL := 1 * time.Minute @@ -1203,7 +1203,7 @@ func Test_cachedTargetsManager_refreshUnhealthyTargets(t *testing.T) { elbv2Client := services.NewMockELBV2(ctrl) for _, call := range tt.fields.describeTargetHealthWithContextCalls { elbv2Client.EXPECT().DescribeTargetHealthWithContext(gomock.Any(), call.req).Return(call.resp, call.err) - elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client) + elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client, nil) } m := &cachedTargetsManager{ elbv2Client: elbv2Client, @@ -1367,7 +1367,7 @@ func Test_cachedTargetsManager_listTargetsFromAWS(t *testing.T) { elbv2Client := services.NewMockELBV2(ctrl) for _, call := range tt.fields.describeTargetHealthWithContextCalls { elbv2Client.EXPECT().DescribeTargetHealthWithContext(gomock.Any(), call.req).Return(call.resp, call.err) - elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client) + elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client, nil) } m := &cachedTargetsManager{ diff --git a/pkg/targetgroupbinding/utils.go b/pkg/targetgroupbinding/utils.go index 013c94081d..057e08a77a 100644 --- a/pkg/targetgroupbinding/utils.go +++ b/pkg/targetgroupbinding/utils.go @@ -25,27 +25,8 @@ const ( // Index Key for "ServiceReference" index. IndexKeyServiceRefName = "spec.serviceRef.name" - - // Annotation for IAM Role ARN to assume when calling AWS APIs. - AnnotationIamRoleArnToAssume = "alb.ingress.kubernetes.io/IamRoleArnToAssume" - - // Annotation for IAM Role External ID to use when calling AWS APIs. - AnnotationAssumeRoleExternalId = "alb.ingress.kubernetes.io/AssumeRoleExternalId" ) -// AnnotationsToFields converts annotations to fields. Currently it's tgb.Spec.IamRoleArnToAssume and tgb.Spec.AssumeRoleExternalId -func AnnotationsToFields(tgb *elbv2api.TargetGroupBinding) { - for key, value := range tgb.Annotations { - if key == AnnotationIamRoleArnToAssume { - tgb.Spec.IamRoleArnToAssume = value - } else { - if key == AnnotationAssumeRoleExternalId { - tgb.Spec.AssumeRoleExternalId = value - } - } - } -} - // BuildTargetHealthPodConditionType constructs the condition type for TargetHealth pod condition. func BuildTargetHealthPodConditionType(tgb *elbv2api.TargetGroupBinding) corev1.PodConditionType { return corev1.PodConditionType(fmt.Sprintf("%s/%s", TargetHealthPodConditionTypePrefix, tgb.Name)) diff --git a/test/framework/framework.go b/test/framework/framework.go index 9830647bb4..eb99b67b8d 100644 --- a/test/framework/framework.go +++ b/test/framework/framework.go @@ -63,7 +63,7 @@ func InitFramework() (*Framework, error) { VpcID: globalOptions.AWSVPCID, MaxRetries: 3, ThrottleConfig: throttle.NewDefaultServiceOperationsThrottleConfig(), - }, nil, logger, nil) + }, "clusterName", nil, logger, nil) if err != nil { return nil, err } diff --git a/webhooks/elbv2/targetgroup_helper.go b/webhooks/elbv2/targetgroup_helper.go new file mode 100644 index 0000000000..dc847dc6fe --- /dev/null +++ b/webhooks/elbv2/targetgroup_helper.go @@ -0,0 +1,45 @@ +package elbv2 + +import ( + "context" + elbv2sdk "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" + "github.com/pkg/errors" + elbv2api "sigs.k8s.io/aws-load-balancer-controller/apis/elbv2/v1beta1" + "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" +) + +// getTargetGroupFromAWS returns the AWS target group corresponding to the arn +func getTargetGroupFromAWS(ctx context.Context, elbv2Client services.ELBV2, tgb *elbv2api.TargetGroupBinding) (*elbv2types.TargetGroup, error) { + tgARN := tgb.Spec.TargetGroupARN + req := &elbv2sdk.DescribeTargetGroupsInput{ + TargetGroupArns: []string{tgARN}, + } + return getTargetGroupHelper(ctx, elbv2Client, tgb, tgARN, req) +} + +// getTargetGroupsByNameFromAWS returns the AWS target group corresponding to the name +func getTargetGroupsByNameFromAWS(ctx context.Context, elbv2Client services.ELBV2, tgb *elbv2api.TargetGroupBinding) (*elbv2types.TargetGroup, error) { + req := &elbv2sdk.DescribeTargetGroupsInput{ + Names: []string{tgb.Spec.TargetGroupName}, + } + + return getTargetGroupHelper(ctx, elbv2Client, tgb, tgb.Spec.TargetGroupName, req) +} + +func getTargetGroupHelper(ctx context.Context, elbv2Client services.ELBV2, tgb *elbv2api.TargetGroupBinding, tgIdentifier string, req *elbv2sdk.DescribeTargetGroupsInput) (*elbv2types.TargetGroup, error) { + clientToUse, err := elbv2Client.AssumeRole(ctx, tgb.Spec.IamRoleArnToAssume, tgb.Spec.AssumeRoleExternalId) + + if err != nil { + return nil, err + } + + tgList, err := clientToUse.DescribeTargetGroupsAsList(ctx, req) + if err != nil { + return nil, err + } + if len(tgList) != 1 { + return nil, errors.Errorf("expecting a single targetGroup with query [%s] but got %v", tgIdentifier, len(tgList)) + } + return &tgList[0], nil +} diff --git a/webhooks/elbv2/targetgroupbinding_mutator.go b/webhooks/elbv2/targetgroupbinding_mutator.go index 7818ce39d3..aed2f6dd65 100644 --- a/webhooks/elbv2/targetgroupbinding_mutator.go +++ b/webhooks/elbv2/targetgroupbinding_mutator.go @@ -6,14 +6,11 @@ import ( elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" awssdk "github.com/aws/aws-sdk-go-v2/aws" - elbv2sdk "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" - "github.com/go-logr/logr" "github.com/pkg/errors" "k8s.io/apimachinery/pkg/runtime" elbv2api "sigs.k8s.io/aws-load-balancer-controller/apis/elbv2/v1beta1" "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" - "sigs.k8s.io/aws-load-balancer-controller/pkg/targetgroupbinding" "sigs.k8s.io/aws-load-balancer-controller/pkg/webhook" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" @@ -48,7 +45,6 @@ func (m *targetGroupBindingMutator) MutateCreate(ctx context.Context, obj runtim if err := m.getArnFromNameIfNeeded(ctx, tgb); err != nil { return nil, err } - targetgroupbinding.AnnotationsToFields(tgb) if err := m.defaultingTargetType(ctx, tgb); err != nil { return nil, err } @@ -63,7 +59,7 @@ func (m *targetGroupBindingMutator) MutateCreate(ctx context.Context, obj runtim func (m *targetGroupBindingMutator) getArnFromNameIfNeeded(ctx context.Context, tgb *elbv2api.TargetGroupBinding) error { if tgb.Spec.TargetGroupARN == "" && tgb.Spec.TargetGroupName != "" { - tgObj, err := m.getTargetGroupsByNameFromAWS(ctx, tgb.Spec.TargetGroupName) + tgObj, err := getTargetGroupsByNameFromAWS(ctx, m.elbv2Client, tgb) if err != nil { return err } @@ -123,7 +119,7 @@ func (m *targetGroupBindingMutator) defaultingVpcID(ctx context.Context, tgb *el } func (m *targetGroupBindingMutator) obtainSDKTargetTypeFromAWS(ctx context.Context, tgb *elbv2api.TargetGroupBinding) (string, error) { - targetGroup, err := m.getTargetGroupFromAWS(ctx, tgb) + targetGroup, err := getTargetGroupFromAWS(ctx, m.elbv2Client, tgb) if err != nil { return "", err } @@ -132,7 +128,7 @@ func (m *targetGroupBindingMutator) obtainSDKTargetTypeFromAWS(ctx context.Conte // getTargetGroupIPAddressTypeFromAWS returns the target group IP address type of AWS target group func (m *targetGroupBindingMutator) getTargetGroupIPAddressTypeFromAWS(ctx context.Context, tgb *elbv2api.TargetGroupBinding) (elbv2api.TargetGroupIPAddressType, error) { - targetGroup, err := m.getTargetGroupFromAWS(ctx, tgb) + targetGroup, err := getTargetGroupFromAWS(ctx, m.elbv2Client, tgb) if err != nil { return "", err } @@ -148,37 +144,8 @@ func (m *targetGroupBindingMutator) getTargetGroupIPAddressTypeFromAWS(ctx conte return ipAddressType, nil } -func (m *targetGroupBindingMutator) getTargetGroupFromAWS(ctx context.Context, tgb *elbv2api.TargetGroupBinding) (*elbv2types.TargetGroup, error) { - tgARN := tgb.Spec.TargetGroupARN - req := &elbv2sdk.DescribeTargetGroupsInput{ - TargetGroupArns: []string{tgARN}, - } - tgList, err := m.elbv2Client.AssumeRole(ctx, tgb.Spec.IamRoleArnToAssume, tgb.Spec.AssumeRoleExternalId).DescribeTargetGroupsAsList(ctx, req) - if err != nil { - return nil, err - } - if len(tgList) != 1 { - return nil, errors.Errorf("expecting a single targetGroup but got %v", len(tgList)) - } - return &tgList[0], nil -} - -func (m *targetGroupBindingMutator) getTargetGroupsByNameFromAWS(ctx context.Context, tgName string) (*elbv2types.TargetGroup, error) { - req := &elbv2sdk.DescribeTargetGroupsInput{ - Names: []string{tgName}, - } - tgList, err := m.elbv2Client.DescribeTargetGroupsAsList(ctx, req) - if err != nil { - return nil, err - } - if len(tgList) != 1 { - return nil, errors.Errorf("expecting a single targetGroup with name [%s] but got %v", tgName, len(tgList)) - } - return &tgList[0], nil -} - func (m *targetGroupBindingMutator) getVpcIDFromAWS(ctx context.Context, tgb *elbv2api.TargetGroupBinding) (string, error) { - targetGroup, err := m.getTargetGroupFromAWS(ctx, tgb) + targetGroup, err := getTargetGroupFromAWS(ctx, m.elbv2Client, tgb) if err != nil { return "", err } diff --git a/webhooks/elbv2/targetgroupbinding_mutator_test.go b/webhooks/elbv2/targetgroupbinding_mutator_test.go index 032440ad14..0348d21de0 100644 --- a/webhooks/elbv2/targetgroupbinding_mutator_test.go +++ b/webhooks/elbv2/targetgroupbinding_mutator_test.go @@ -308,7 +308,7 @@ func Test_targetGroupBindingMutator_MutateCreate(t *testing.T) { ctx := context.Background() for _, call := range tt.fields.describeTargetGroupsAsListCalls { elbv2Client.EXPECT().DescribeTargetGroupsAsList(gomock.Any(), call.req).Return(call.resp, call.err).AnyTimes() - elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client).AnyTimes() + elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client, nil).AnyTimes() } m := &targetGroupBindingMutator{ @@ -415,7 +415,7 @@ func Test_targetGroupBindingMutator_obtainSDKTargetTypeFromAWS(t *testing.T) { elbv2Client := services.NewMockELBV2(ctrl) for _, call := range tt.fields.describeTargetGroupsAsListCalls { elbv2Client.EXPECT().DescribeTargetGroupsAsList(gomock.Any(), call.req).Return(call.resp, call.err) - elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client).AnyTimes() + elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client, nil).AnyTimes() } m := &targetGroupBindingMutator{ @@ -545,7 +545,7 @@ func Test_targetGroupBindingMutator_getIPAddressTypeFromAWS(t *testing.T) { elbv2Client := services.NewMockELBV2(ctrl) for _, call := range tt.fields.describeTargetGroupsAsListCalls { elbv2Client.EXPECT().DescribeTargetGroupsAsList(gomock.Any(), call.req).Return(call.resp, call.err) - elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client).AnyTimes() + elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client, nil).AnyTimes() } m := &targetGroupBindingMutator{ @@ -631,7 +631,7 @@ func Test_targetGroupBindingMutator_obtainSDKVpcIDFromAWS(t *testing.T) { elbv2Client := services.NewMockELBV2(ctrl) for _, call := range tt.fields.describeTargetGroupsAsListCalls { elbv2Client.EXPECT().DescribeTargetGroupsAsList(gomock.Any(), call.req).Return(call.resp, call.err) - elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client).AnyTimes() + elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client, nil).AnyTimes() } m := &targetGroupBindingMutator{ diff --git a/webhooks/elbv2/targetgroupbinding_validator.go b/webhooks/elbv2/targetgroupbinding_validator.go index bb6ee770a5..2e3e63bfee 100644 --- a/webhooks/elbv2/targetgroupbinding_validator.go +++ b/webhooks/elbv2/targetgroupbinding_validator.go @@ -9,14 +9,12 @@ import ( elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" awssdk "github.com/aws/aws-sdk-go-v2/aws" - elbv2sdk "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" "github.com/go-logr/logr" "github.com/pkg/errors" "k8s.io/apimachinery/pkg/runtime" elbv2api "sigs.k8s.io/aws-load-balancer-controller/apis/elbv2/v1beta1" "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services" "sigs.k8s.io/aws-load-balancer-controller/pkg/k8s" - "sigs.k8s.io/aws-load-balancer-controller/pkg/targetgroupbinding" "sigs.k8s.io/aws-load-balancer-controller/pkg/webhook" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" @@ -56,7 +54,6 @@ func (v *targetGroupBindingValidator) Prototype(_ admission.Request) (runtime.Ob func (v *targetGroupBindingValidator) ValidateCreate(ctx context.Context, obj runtime.Object) error { tgb := obj.(*elbv2api.TargetGroupBinding) - targetgroupbinding.AnnotationsToFields(tgb) if err := v.checkRequiredFields(ctx, tgb); err != nil { return err } @@ -72,13 +69,15 @@ func (v *targetGroupBindingValidator) ValidateCreate(ctx context.Context, obj ru if err := v.checkTargetGroupVpcID(ctx, tgb); err != nil { return err } + if err := v.checkAssumeRoleConfig(tgb); err != nil { + return err + } return nil } func (v *targetGroupBindingValidator) ValidateUpdate(ctx context.Context, obj runtime.Object, oldObj runtime.Object) error { tgb := obj.(*elbv2api.TargetGroupBinding) oldTgb := oldObj.(*elbv2api.TargetGroupBinding) - targetgroupbinding.AnnotationsToFields(tgb) if err := v.checkRequiredFields(ctx, tgb); err != nil { return err } @@ -88,6 +87,9 @@ func (v *targetGroupBindingValidator) ValidateUpdate(ctx context.Context, obj ru if err := v.checkNodeSelector(tgb); err != nil { return err } + if err := v.checkAssumeRoleConfig(tgb); err != nil { + return err + } return nil } @@ -113,7 +115,7 @@ func (v *targetGroupBindingValidator) checkRequiredFields(ctx context.Context, t By changing the object here I guarantee as early as possible that that assumption is true. */ - tgObj, err := v.getTargetGroupsByNameFromAWS(ctx, tgb.Spec.TargetGroupName) + tgObj, err := getTargetGroupsByNameFromAWS(ctx, v.elbv2Client, tgb) if err != nil { return fmt.Errorf("searching TargetGroup with name %s: %w", tgb.Spec.TargetGroupName, err) } @@ -215,7 +217,7 @@ func (v *targetGroupBindingValidator) checkTargetGroupVpcID(ctx context.Context, // getTargetGroupIPAddressTypeFromAWS returns the target group IP address type of AWS target group func (v *targetGroupBindingValidator) getTargetGroupIPAddressTypeFromAWS(ctx context.Context, tgb *elbv2api.TargetGroupBinding) (elbv2api.TargetGroupIPAddressType, error) { - targetGroup, err := v.getTargetGroupFromAWS(ctx, tgb) + targetGroup, err := getTargetGroupFromAWS(ctx, v.elbv2Client, tgb) if err != nil { return "", err } @@ -231,43 +233,25 @@ func (v *targetGroupBindingValidator) getTargetGroupIPAddressTypeFromAWS(ctx con return ipAddressType, nil } -// getTargetGroupFromAWS returns the AWS target group corresponding to the ARN -func (v *targetGroupBindingValidator) getTargetGroupFromAWS(ctx context.Context, tgb *elbv2api.TargetGroupBinding) (*elbv2types.TargetGroup, error) { - tgARN := tgb.Spec.TargetGroupARN - req := &elbv2sdk.DescribeTargetGroupsInput{ - TargetGroupArns: []string{tgARN}, - } - tgList, err := v.elbv2Client.AssumeRole(ctx, tgb.Spec.IamRoleArnToAssume, tgb.Spec.AssumeRoleExternalId).DescribeTargetGroupsAsList(ctx, req) - if err != nil { - return nil, err - } - if len(tgList) != 1 { - return nil, errors.Errorf("expecting a single targetGroup but got %v", len(tgList)) - } - return &tgList[0], nil -} - func (v *targetGroupBindingValidator) getVpcIDFromAWS(ctx context.Context, tgb *elbv2api.TargetGroupBinding) (string, error) { - targetGroup, err := v.getTargetGroupFromAWS(ctx, tgb) + targetGroup, err := getTargetGroupFromAWS(ctx, v.elbv2Client, tgb) if err != nil { return "", err } return awssdk.ToString(targetGroup.VpcId), nil } -// getTargetGroupFromAWS returns the AWS target group corresponding to the tgName -func (v *targetGroupBindingValidator) getTargetGroupsByNameFromAWS(ctx context.Context, tgName string) (*elbv2types.TargetGroup, error) { - req := &elbv2sdk.DescribeTargetGroupsInput{ - Names: []string{tgName}, - } - tgList, err := v.elbv2Client.DescribeTargetGroupsAsList(ctx, req) - if err != nil { - return nil, err +// checkAssumeRoleConfig various checks for using cross account target group bindings. +func (v *targetGroupBindingValidator) checkAssumeRoleConfig(tgb *elbv2api.TargetGroupBinding) error { + if tgb.Spec.IamRoleArnToAssume == "" { + return nil } - if len(tgList) != 1 { - return nil, errors.Errorf("expecting a single targetGroup with name [%s] but got %v", tgName, len(tgList)) + + if tgb.Spec.TargetType != nil && *tgb.Spec.TargetType == elbv2api.TargetTypeInstance { + return errors.New("Unable to use instance target type while using assume role") } - return &tgList[0], nil + + return nil } // +kubebuilder:webhook:path=/validate-elbv2-k8s-aws-v1beta1-targetgroupbinding,mutating=false,failurePolicy=fail,groups=elbv2.k8s.aws,resources=targetgroupbindings,verbs=create;update,versions=v1beta1,name=vtargetgroupbinding.elbv2.k8s.aws,sideEffects=None,webhookVersions=v1,admissionReviewVersions=v1beta1 diff --git a/webhooks/elbv2/targetgroupbinding_validator_test.go b/webhooks/elbv2/targetgroupbinding_validator_test.go index 4889604942..33fe3607a4 100644 --- a/webhooks/elbv2/targetgroupbinding_validator_test.go +++ b/webhooks/elbv2/targetgroupbinding_validator_test.go @@ -333,7 +333,7 @@ func Test_targetGroupBindingValidator_ValidateCreate(t *testing.T) { elbv2Client := services.NewMockELBV2(ctrl) for _, call := range tt.fields.describeTargetGroupsAsListCalls { elbv2Client.EXPECT().DescribeTargetGroupsAsList(gomock.Any(), call.req).Return(call.resp, call.err) - elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client).AnyTimes() + elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client, nil).AnyTimes() } v := &targetGroupBindingValidator{ k8sClient: k8sClient, @@ -1395,7 +1395,7 @@ func Test_targetGroupBindingValidator_checkTargetGroupVpcID(t *testing.T) { elbv2Client := services.NewMockELBV2(ctrl) for _, call := range tt.fields.describeTargetGroupsAsListCalls { elbv2Client.EXPECT().DescribeTargetGroupsAsList(gomock.Any(), call.req).Return(call.resp, call.err) - elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client).AnyTimes() + elbv2Client.EXPECT().AssumeRole(ctx, gomock.Any(), gomock.Any()).Return(elbv2Client, nil).AnyTimes() } v := &targetGroupBindingValidator{ k8sClient: k8sClient, @@ -1412,6 +1412,67 @@ func Test_targetGroupBindingValidator_checkTargetGroupVpcID(t *testing.T) { } } +func TestCheckAssumeRoleConfig(t *testing.T) { + instance := elbv2api.TargetTypeInstance + ip := elbv2api.TargetTypeIP + testCases := []struct { + name string + tgb *elbv2api.TargetGroupBinding + err error + }{ + { + name: "ip no assume role", + tgb: &elbv2api.TargetGroupBinding{ + Spec: elbv2api.TargetGroupBindingSpec{ + TargetType: &ip, + }, + }, + }, + { + name: "instance no assume role", + tgb: &elbv2api.TargetGroupBinding{ + Spec: elbv2api.TargetGroupBindingSpec{ + TargetType: &instance, + }, + }, + }, + { + name: "ip with assume role", + tgb: &elbv2api.TargetGroupBinding{ + Spec: elbv2api.TargetGroupBindingSpec{ + TargetType: &ip, + IamRoleArnToAssume: "foo", + }, + }, + }, + { + name: "instance with assume role", + tgb: &elbv2api.TargetGroupBinding{ + Spec: elbv2api.TargetGroupBindingSpec{ + TargetType: &instance, + IamRoleArnToAssume: "foo", + }, + }, + err: errors.New("Unable to use instance target type while using assume role"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + v := &targetGroupBindingValidator{ + logger: logr.New(&log.NullLogSink{}), + } + + err := v.checkAssumeRoleConfig(tc.tgb) + if tc.err == nil { + assert.Nil(t, err) + } else { + assert.EqualError(t, err, tc.err.Error()) + } + }) + } +} + func generateRandomString(n int, addChars ...rune) string { const letters = "0123456789abcdef"