Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func main() {
ctrl.SetLogger(appLogger)
klog.SetLoggerWithOptions(appLogger, klog.ContextualLogger(true))

cloud, err := aws.NewCloud(controllerCFG.AWSConfig, metrics.Registry, ctrl.Log)
cloud, err := aws.NewCloud(controllerCFG.AWSConfig, metrics.Registry, ctrl.Log, nil)
if err != nil {
setupLog.Error(err, "unable to initialize AWS cloud")
os.Exit(1)
Expand Down
25 changes: 16 additions & 9 deletions pkg/aws/cloud.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/prometheus/client_golang/prometheus"
amerrors "k8s.io/apimachinery/pkg/util/errors"
epresolver "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services"
)

Expand Down Expand Up @@ -59,7 +60,7 @@ type Cloud interface {
}

// NewCloud constructs new Cloud implementation.
func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger logr.Logger) (Cloud, error) {
func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger logr.Logger, awsClientsProvider provider.AWSClientsProvider) (Cloud, error) {
hasIPv4 := true
addrs, err := net.InterfaceAddrs()
if err == nil {
Expand Down Expand Up @@ -129,7 +130,14 @@ func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger l
awsConfig.APIOptions = metrics.WithSDKMetricCollector(metricsCollector, awsConfig.APIOptions)
}

ec2Service := services.NewEC2(awsConfig, endpointsResolver)
if awsClientsProvider == nil {
var err error
awsClientsProvider, err = provider.NewDefaultAWSClientsProvider(awsConfig, endpointsResolver)
if err != nil {
return nil, errors.Wrap(err, "failed to create aws clients provider")
}
}
ec2Service := services.NewEC2(awsClientsProvider)

vpcID, err := getVpcID(cfg, ec2Service, ec2Metadata, logger)
if err != nil {
Expand All @@ -139,17 +147,16 @@ func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger l
return &defaultCloud{
cfg: cfg,
ec2: ec2Service,
elbv2: services.NewELBV2(awsConfig, endpointsResolver),
acm: services.NewACM(awsConfig, endpointsResolver),
wafv2: services.NewWAFv2(awsConfig, endpointsResolver),
wafRegional: services.NewWAFRegional(awsConfig, endpointsResolver, cfg.Region),
shield: services.NewShield(awsConfig, endpointsResolver), //done
rgt: services.NewRGT(awsConfig, endpointsResolver),
elbv2: services.NewELBV2(awsClientsProvider),
acm: services.NewACM(awsClientsProvider),
wafv2: services.NewWAFv2(awsClientsProvider),
wafRegional: services.NewWAFRegional(awsClientsProvider, cfg.Region),
shield: services.NewShield(awsClientsProvider),
rgt: services.NewRGT(awsClientsProvider),
}, nil
}

func getVpcID(cfg CloudConfig, ec2Service services.EC2, ec2Metadata services.EC2Metadata, logger logr.Logger) (string, error) {

if cfg.VpcID != "" {
logger.V(1).Info("vpcid is specified using flag --aws-vpc-id, controller will use the value", "vpc: ", cfg.VpcID)
return cfg.VpcID, nil
Expand Down
109 changes: 109 additions & 0 deletions pkg/aws/provider/default_aws_clients_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package provider

import (
"context"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/acm"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2"
"github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi"
"github.com/aws/aws-sdk-go-v2/service/shield"
"github.com/aws/aws-sdk-go-v2/service/wafregional"
"github.com/aws/aws-sdk-go-v2/service/wafv2"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints"
)

type defaultAWSClientsProvider struct {
ec2Client *ec2.Client
elbv2Client *elasticloadbalancingv2.Client
acmClient *acm.Client
wafv2Client *wafv2.Client
wafRegionClient *wafregional.Client
shieldClient *shield.Client
rgtClient *resourcegroupstaggingapi.Client
}

func NewDefaultAWSClientsProvider(cfg aws.Config, endpointsResolver *endpoints.Resolver) (*defaultAWSClientsProvider, error) {
ec2CustomEndpoint := endpointsResolver.EndpointFor(ec2.ServiceID)
elbv2CustomEndpoint := endpointsResolver.EndpointFor(elasticloadbalancingv2.ServiceID)
acmCustomEndpoint := endpointsResolver.EndpointFor(acm.ServiceID)
wafv2CustomEndpoint := endpointsResolver.EndpointFor(wafv2.ServiceID)
wafregionalCustomEndpoint := endpointsResolver.EndpointFor(wafregional.ServiceID)
shieldCustomEndpoint := endpointsResolver.EndpointFor(shield.ServiceID)
rgtCustomEndpoint := endpointsResolver.EndpointFor(resourcegroupstaggingapi.ServiceID)

ec2Client := ec2.NewFromConfig(cfg, func(o *ec2.Options) {
if ec2CustomEndpoint != nil {
o.BaseEndpoint = ec2CustomEndpoint
}
})
elbv2Client := elasticloadbalancingv2.NewFromConfig(cfg, func(o *elasticloadbalancingv2.Options) {
if elbv2CustomEndpoint != nil {
o.BaseEndpoint = elbv2CustomEndpoint
}
})
acmClient := acm.NewFromConfig(cfg, func(o *acm.Options) {
if acmCustomEndpoint != nil {
o.BaseEndpoint = acmCustomEndpoint
}
})
wafv2Client := wafv2.NewFromConfig(cfg, func(o *wafv2.Options) {
if wafv2CustomEndpoint != nil {
o.BaseEndpoint = wafv2CustomEndpoint
}
})
wafregionalClient := wafregional.NewFromConfig(cfg, func(o *wafregional.Options) {
o.Region = cfg.Region
o.BaseEndpoint = wafregionalCustomEndpoint
})
sheildClient := shield.NewFromConfig(cfg, func(o *shield.Options) {
o.Region = "us-east-1"
o.BaseEndpoint = shieldCustomEndpoint
})
rgtClient := resourcegroupstaggingapi.NewFromConfig(cfg, func(o *resourcegroupstaggingapi.Options) {
if rgtCustomEndpoint != nil {
o.BaseEndpoint = rgtCustomEndpoint
}
})

return &defaultAWSClientsProvider{
ec2Client: ec2Client,
elbv2Client: elbv2Client,
acmClient: acmClient,
wafv2Client: wafv2Client,
wafRegionClient: wafregionalClient,
shieldClient: sheildClient,
rgtClient: rgtClient,
}, nil
}

// DO NOT REMOVE operationName as parameter, this is on purpose
// to retain the default behavior for OSS controller to use the default client for each aws service
// for our internal controller, we will choose different client based on operationName
func (p *defaultAWSClientsProvider) GetEC2Client(ctx context.Context, operationName string) (*ec2.Client, error) {
return p.ec2Client, nil
}

func (p *defaultAWSClientsProvider) GetELBv2Client(ctx context.Context, operationName string) (*elasticloadbalancingv2.Client, error) {
return p.elbv2Client, nil
}

func (p *defaultAWSClientsProvider) GetACMClient(ctx context.Context, operationName string) (*acm.Client, error) {
return p.acmClient, nil
}

func (p *defaultAWSClientsProvider) GetWAFv2Client(ctx context.Context, operationName string) (*wafv2.Client, error) {
return p.wafv2Client, nil
}

func (p *defaultAWSClientsProvider) GetWAFRegionClient(ctx context.Context, operationName string) (*wafregional.Client, error) {
return p.wafRegionClient, nil
}

func (p *defaultAWSClientsProvider) GetShieldClient(ctx context.Context, operationName string) (*shield.Client, error) {
return p.shieldClient, nil
}

func (p *defaultAWSClientsProvider) GetRGTClient(ctx context.Context, operationName string) (*resourcegroupstaggingapi.Client, error) {
return p.rgtClient, nil
}
22 changes: 22 additions & 0 deletions pkg/aws/provider/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package provider

import (
"context"
"github.com/aws/aws-sdk-go-v2/service/acm"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2"
"github.com/aws/aws-sdk-go-v2/service/resourcegroupstaggingapi"
"github.com/aws/aws-sdk-go-v2/service/shield"
"github.com/aws/aws-sdk-go-v2/service/wafregional"
"github.com/aws/aws-sdk-go-v2/service/wafv2"
)

type AWSClientsProvider interface {
GetEC2Client(ctx context.Context, operationName string) (*ec2.Client, error)
GetELBv2Client(ctx context.Context, operationName string) (*elasticloadbalancingv2.Client, error)
GetACMClient(ctx context.Context, operationName string) (*acm.Client, error)
GetWAFv2Client(ctx context.Context, operationName string) (*wafv2.Client, error)
GetWAFRegionClient(ctx context.Context, operationName string) (*wafregional.Client, error)
GetShieldClient(ctx context.Context, operationName string) (*shield.Client, error)
GetRGTClient(ctx context.Context, operationName string) (*resourcegroupstaggingapi.Client, error)
}
26 changes: 14 additions & 12 deletions pkg/aws/services/acm.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@ package services

import (
"context"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/acm"
"github.com/aws/aws-sdk-go-v2/service/acm/types"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints"
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider"
)

type ACM interface {
Expand All @@ -15,24 +14,23 @@ type ACM interface {
}

// NewACM constructs new ACM implementation.
func NewACM(cfg aws.Config, endpointsResolver *endpoints.Resolver) ACM {
customEndpoint := endpointsResolver.EndpointFor(acm.ServiceID)
func NewACM(awsClientsProvider provider.AWSClientsProvider) ACM {
return &acmClient{
acmClient: acm.NewFromConfig(cfg, func(o *acm.Options) {
if customEndpoint != nil {
o.BaseEndpoint = customEndpoint
}
}),
awsClientsProvider: awsClientsProvider,
}
}

type acmClient struct {
acmClient *acm.Client
awsClientsProvider provider.AWSClientsProvider
}

func (c *acmClient) ListCertificatesAsList(ctx context.Context, input *acm.ListCertificatesInput) ([]types.CertificateSummary, error) {
var result []types.CertificateSummary
paginator := acm.NewListCertificatesPaginator(c.acmClient, input)
client, err := c.awsClientsProvider.GetACMClient(ctx, "ListCertificates")
if err != nil {
return nil, err
}
paginator := acm.NewListCertificatesPaginator(client, input)
for paginator.HasMorePages() {
output, err := paginator.NextPage(ctx)
if err != nil {
Expand All @@ -44,5 +42,9 @@ func (c *acmClient) ListCertificatesAsList(ctx context.Context, input *acm.ListC
}

func (c *acmClient) DescribeCertificateWithContext(ctx context.Context, input *acm.DescribeCertificateInput) (*acm.DescribeCertificateOutput, error) {
return c.acmClient.DescribeCertificate(ctx, input)
client, err := c.awsClientsProvider.GetACMClient(ctx, "DescribeCertificate")
if err != nil {
return nil, err
}
return client.DescribeCertificate(ctx, input)
}
Loading
Loading