Skip to content

Commit 0353352

Browse files
committed
add support for aws clients provider
1 parent ebc3c25 commit 0353352

File tree

6 files changed

+127
-28
lines changed

6 files changed

+127
-28
lines changed

main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ func main() {
8181
ctrl.SetLogger(appLogger)
8282
klog.SetLoggerWithOptions(appLogger, klog.ContextualLogger(true))
8383

84-
cloud, err := aws.NewCloud(controllerCFG.AWSConfig, metrics.Registry, ctrl.Log)
84+
cloud, err := aws.NewCloud(controllerCFG.AWSConfig, metrics.Registry, ctrl.Log, nil)
8585
if err != nil {
8686
setupLog.Error(err, "unable to initialize AWS cloud")
8787
os.Exit(1)

pkg/aws/cloud.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"github.com/prometheus/client_golang/prometheus"
2525
amerrors "k8s.io/apimachinery/pkg/util/errors"
2626
epresolver "sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints"
27+
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider"
2728
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/services"
2829
)
2930

@@ -59,7 +60,7 @@ type Cloud interface {
5960
}
6061

6162
// NewCloud constructs new Cloud implementation.
62-
func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger logr.Logger) (Cloud, error) {
63+
func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger logr.Logger, awsClientsProvider provider.AWSClientsProvider) (Cloud, error) {
6364
hasIPv4 := true
6465
addrs, err := net.InterfaceAddrs()
6566
if err == nil {
@@ -129,7 +130,14 @@ func NewCloud(cfg CloudConfig, metricsRegisterer prometheus.Registerer, logger l
129130
awsConfig.APIOptions = metrics.WithSDKMetricCollector(metricsCollector, awsConfig.APIOptions)
130131
}
131132

132-
ec2Service := services.NewEC2(awsConfig, endpointsResolver)
133+
if awsClientsProvider == nil {
134+
var err error
135+
awsClientsProvider, err = NewDefaultAWSClientsProvider(awsConfig, endpointsResolver)
136+
if err != nil {
137+
return nil, errors.Wrap(err, "failed to create aws clients provider")
138+
}
139+
}
140+
ec2Service := services.NewEC2(awsClientsProvider)
133141

134142
vpcID, err := getVpcID(cfg, ec2Service, ec2Metadata, logger)
135143
if err != nil {
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package aws
2+
3+
import (
4+
"context"
5+
"github.com/aws/aws-sdk-go-v2/aws"
6+
"github.com/aws/aws-sdk-go-v2/service/ec2"
7+
"github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2"
8+
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints"
9+
)
10+
11+
type defaultAWSClientsProvider struct {
12+
ec2Client *ec2.Client
13+
elbv2Client *elasticloadbalancingv2.Client
14+
}
15+
16+
func NewDefaultAWSClientsProvider(cfg aws.Config, endpointsResolver *endpoints.Resolver) (*defaultAWSClientsProvider, error) {
17+
customEndpoint := endpointsResolver.EndpointFor(ec2.ServiceID)
18+
ec2Client := ec2.NewFromConfig(cfg, func(o *ec2.Options) {
19+
if customEndpoint != nil {
20+
o.BaseEndpoint = customEndpoint
21+
}
22+
})
23+
return &defaultAWSClientsProvider{
24+
ec2Client: ec2Client,
25+
elbv2Client: nil,
26+
}, nil
27+
}
28+
29+
func (p *defaultAWSClientsProvider) GetEC2Client(ctx context.Context, operationName string) (*ec2.Client, error) {
30+
return p.ec2Client, nil
31+
}

pkg/aws/provider/provider.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package provider
2+
3+
import (
4+
"context"
5+
"github.com/aws/aws-sdk-go-v2/service/ec2"
6+
)
7+
8+
type AWSClientsProvider interface {
9+
GetEC2Client(ctx context.Context, operationName string) (*ec2.Client, error)
10+
}

pkg/aws/services/ec2.go

Lines changed: 74 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@ package services
22

33
import (
44
"context"
5-
"github.com/aws/aws-sdk-go-v2/aws"
65
"github.com/aws/aws-sdk-go-v2/service/ec2"
76
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
8-
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/endpoints"
7+
"sigs.k8s.io/aws-load-balancer-controller/pkg/aws/provider"
98
)
109

1110
type EC2 interface {
@@ -37,28 +36,31 @@ type EC2 interface {
3736
}
3837

3938
// NewEC2 constructs new EC2 implementation.
40-
func NewEC2(cfg aws.Config, endpointsResolver *endpoints.Resolver) EC2 {
41-
customEndpoint := endpointsResolver.EndpointFor(ec2.ServiceID)
39+
func NewEC2(awsClientsProvider provider.AWSClientsProvider) EC2 {
4240
return &ec2Client{
43-
ec2Client: ec2.NewFromConfig(cfg, func(o *ec2.Options) {
44-
if customEndpoint != nil {
45-
o.BaseEndpoint = customEndpoint
46-
}
47-
}),
41+
awsClientsProvider: awsClientsProvider,
4842
}
4943
}
5044

5145
type ec2Client struct {
52-
ec2Client *ec2.Client
46+
awsClientsProvider provider.AWSClientsProvider
5347
}
5448

5549
func (c *ec2Client) DescribeInstancesWithContext(ctx context.Context, input *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error) {
56-
return c.ec2Client.DescribeInstances(ctx, input)
50+
client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeInstances")
51+
if err != nil {
52+
return nil, err
53+
}
54+
return client.DescribeInstances(ctx, input)
5755
}
5856

5957
func (c *ec2Client) DescribeInstancesAsList(ctx context.Context, input *ec2.DescribeInstancesInput) ([]types.Instance, error) {
6058
var result []types.Instance
61-
paginator := ec2.NewDescribeInstancesPaginator(c.ec2Client, input)
59+
client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeInstances")
60+
if err != nil {
61+
return nil, err
62+
}
63+
paginator := ec2.NewDescribeInstancesPaginator(client, input)
6264
for paginator.HasMorePages() {
6365
output, err := paginator.NextPage(ctx)
6466
if err != nil {
@@ -73,7 +75,11 @@ func (c *ec2Client) DescribeInstancesAsList(ctx context.Context, input *ec2.Desc
7375

7476
func (c *ec2Client) DescribeNetworkInterfacesAsList(ctx context.Context, input *ec2.DescribeNetworkInterfacesInput) ([]types.NetworkInterface, error) {
7577
var result []types.NetworkInterface
76-
paginator := ec2.NewDescribeNetworkInterfacesPaginator(c.ec2Client, input)
78+
client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeNetworkInterfaces")
79+
if err != nil {
80+
return nil, err
81+
}
82+
paginator := ec2.NewDescribeNetworkInterfacesPaginator(client, input)
7783
for paginator.HasMorePages() {
7884
output, err := paginator.NextPage(ctx)
7985
if err != nil {
@@ -86,7 +92,11 @@ func (c *ec2Client) DescribeNetworkInterfacesAsList(ctx context.Context, input *
8692

8793
func (c *ec2Client) DescribeSecurityGroupsAsList(ctx context.Context, input *ec2.DescribeSecurityGroupsInput) ([]types.SecurityGroup, error) {
8894
var result []types.SecurityGroup
89-
paginator := ec2.NewDescribeSecurityGroupsPaginator(c.ec2Client, input)
95+
client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeSecurityGroups")
96+
if err != nil {
97+
return nil, err
98+
}
99+
paginator := ec2.NewDescribeSecurityGroupsPaginator(client, input)
90100
for paginator.HasMorePages() {
91101
output, err := paginator.NextPage(ctx)
92102
if err != nil {
@@ -99,7 +109,11 @@ func (c *ec2Client) DescribeSecurityGroupsAsList(ctx context.Context, input *ec2
99109

100110
func (c *ec2Client) DescribeSubnetsAsList(ctx context.Context, input *ec2.DescribeSubnetsInput) ([]types.Subnet, error) {
101111
var result []types.Subnet
102-
paginator := ec2.NewDescribeSubnetsPaginator(c.ec2Client, input)
112+
client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeSubnets")
113+
if err != nil {
114+
return nil, err
115+
}
116+
paginator := ec2.NewDescribeSubnetsPaginator(client, input)
103117
for paginator.HasMorePages() {
104118
output, err := paginator.NextPage(ctx)
105119
if err != nil {
@@ -112,7 +126,11 @@ func (c *ec2Client) DescribeSubnetsAsList(ctx context.Context, input *ec2.Descri
112126

113127
func (c *ec2Client) DescribeVPCsAsList(ctx context.Context, input *ec2.DescribeVpcsInput) ([]types.Vpc, error) {
114128
var result []types.Vpc
115-
paginator := ec2.NewDescribeVpcsPaginator(c.ec2Client, input)
129+
client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeVPCs")
130+
if err != nil {
131+
return nil, err
132+
}
133+
paginator := ec2.NewDescribeVpcsPaginator(client, input)
116134
for paginator.HasMorePages() {
117135
output, err := paginator.NextPage(ctx)
118136
if err != nil {
@@ -124,33 +142,65 @@ func (c *ec2Client) DescribeVPCsAsList(ctx context.Context, input *ec2.DescribeV
124142
}
125143

126144
func (c *ec2Client) CreateTagsWithContext(ctx context.Context, input *ec2.CreateTagsInput) (*ec2.CreateTagsOutput, error) {
127-
return c.ec2Client.CreateTags(ctx, input)
145+
client, err := c.awsClientsProvider.GetEC2Client(ctx, "CreateTags")
146+
if err != nil {
147+
return nil, err
148+
}
149+
return client.CreateTags(ctx, input)
128150
}
129151

130152
func (c *ec2Client) DeleteTagsWithContext(ctx context.Context, input *ec2.DeleteTagsInput) (*ec2.DeleteTagsOutput, error) {
131-
return c.ec2Client.DeleteTags(ctx, input)
153+
client, err := c.awsClientsProvider.GetEC2Client(ctx, "DeleteTags")
154+
if err != nil {
155+
return nil, err
156+
}
157+
return client.DeleteTags(ctx, input)
132158
}
133159

134160
func (c *ec2Client) CreateSecurityGroupWithContext(ctx context.Context, input *ec2.CreateSecurityGroupInput) (*ec2.CreateSecurityGroupOutput, error) {
135-
return c.ec2Client.CreateSecurityGroup(ctx, input)
161+
client, err := c.awsClientsProvider.GetEC2Client(ctx, "CreateSecurityGroup")
162+
if err != nil {
163+
return nil, err
164+
}
165+
return client.CreateSecurityGroup(ctx, input)
136166
}
137167

138168
func (c *ec2Client) DeleteSecurityGroupWithContext(ctx context.Context, input *ec2.DeleteSecurityGroupInput) (*ec2.DeleteSecurityGroupOutput, error) {
139-
return c.ec2Client.DeleteSecurityGroup(ctx, input)
169+
client, err := c.awsClientsProvider.GetEC2Client(ctx, "DeleteSecurityGroup")
170+
if err != nil {
171+
return nil, err
172+
}
173+
return client.DeleteSecurityGroup(ctx, input)
140174
}
141175

142176
func (c *ec2Client) AuthorizeSecurityGroupIngressWithContext(ctx context.Context, input *ec2.AuthorizeSecurityGroupIngressInput) (*ec2.AuthorizeSecurityGroupIngressOutput, error) {
143-
return c.ec2Client.AuthorizeSecurityGroupIngress(ctx, input)
177+
client, err := c.awsClientsProvider.GetEC2Client(ctx, "AuthorizeSecurityGroupIngress")
178+
if err != nil {
179+
return nil, err
180+
}
181+
return client.AuthorizeSecurityGroupIngress(ctx, input)
144182
}
145183

146184
func (c *ec2Client) RevokeSecurityGroupIngressWithContext(ctx context.Context, input *ec2.RevokeSecurityGroupIngressInput) (*ec2.RevokeSecurityGroupIngressOutput, error) {
147-
return c.ec2Client.RevokeSecurityGroupIngress(ctx, input)
185+
client, err := c.awsClientsProvider.GetEC2Client(ctx, "RevokeSecurityGroupIngress")
186+
if err != nil {
187+
return nil, err
188+
}
189+
return client.RevokeSecurityGroupIngress(ctx, input)
148190
}
149191

150192
func (c *ec2Client) DescribeAvailabilityZonesWithContext(ctx context.Context, input *ec2.DescribeAvailabilityZonesInput) (*ec2.DescribeAvailabilityZonesOutput, error) {
151-
return c.ec2Client.DescribeAvailabilityZones(ctx, input)
193+
client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeAvailabilityZones")
194+
if err != nil {
195+
return nil, err
196+
}
197+
return client.DescribeAvailabilityZones(ctx, input)
152198
}
153199

154200
func (c *ec2Client) DescribeVpcsWithContext(ctx context.Context, input *ec2.DescribeVpcsInput) (*ec2.DescribeVpcsOutput, error) {
155-
return c.ec2Client.DescribeVpcs(ctx, input)
201+
client, err := c.awsClientsProvider.GetEC2Client(ctx, "DescribeVpcs")
202+
if err != nil {
203+
return nil, err
204+
}
205+
return client.DescribeVpcs(ctx, input)
156206
}

test/framework/framework.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ func InitFramework() (*Framework, error) {
6262
VpcID: globalOptions.AWSVPCID,
6363
MaxRetries: 3,
6464
ThrottleConfig: throttle.NewDefaultServiceOperationsThrottleConfig(),
65-
}, nil, logger)
65+
}, nil, logger, nil)
6666
if err != nil {
6767
return nil, err
6868
}

0 commit comments

Comments
 (0)