Skip to content

Commit 29ba214

Browse files
authored
Add Clone() method for arm/policy.ClientOptions (Azure#20288)
* fix empty policy copy problem for arm/runtime.NewPipeline * add Copy() method for arm/policy.ClientOptions * rename Copy() to Clone() and fix ci failure
1 parent 15aa35d commit 29ba214

File tree

4 files changed

+88
-2
lines changed

4 files changed

+88
-2
lines changed

sdk/azcore/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
## 1.3.2 (Unreleased)
44

55
### Features Added
6+
* Add `Clone()` method for `arm/policy.ClientOptions`.
67

78
### Breaking Changes
89

910
### Bugs Fixed
1011
* ARM's RP registration policy will no longer swallow unrecognized errors.
1112
* Fixed an issue in `runtime.NewPollerFromResumeToken()` when resuming a `Poller` with a custom `PollingHandler`.
13+
* Fixed wrong policy copy in `arm/runtime.NewPipeline()`.
1214

1315
### Other Changes
1416

sdk/azcore/arm/policy/policy.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,40 @@ type ClientOptions struct {
4747
// DisableRPRegistration disables the auto-RP registration policy. Defaults to false.
4848
DisableRPRegistration bool
4949
}
50+
51+
// Clone return a deep copy of the current options.
52+
func (o *ClientOptions) Clone() *ClientOptions {
53+
if o == nil {
54+
return nil
55+
}
56+
copiedOptions := *o
57+
copiedOptions.Cloud.Services = copyMap(copiedOptions.Cloud.Services)
58+
copiedOptions.Logging.AllowedHeaders = copyArray(copiedOptions.Logging.AllowedHeaders)
59+
copiedOptions.Logging.AllowedQueryParams = copyArray(copiedOptions.Logging.AllowedQueryParams)
60+
copiedOptions.Retry.StatusCodes = copyArray(copiedOptions.Retry.StatusCodes)
61+
copiedOptions.PerRetryPolicies = copyArray(copiedOptions.PerRetryPolicies)
62+
copiedOptions.PerCallPolicies = copyArray(copiedOptions.PerCallPolicies)
63+
return &copiedOptions
64+
}
65+
66+
// copyMap return a new map with all the key value pair in the src map
67+
func copyMap[K comparable, V any](src map[K]V) map[K]V {
68+
if src == nil {
69+
return nil
70+
}
71+
copiedMap := make(map[K]V)
72+
for k, v := range src {
73+
copiedMap[k] = v
74+
}
75+
return copiedMap
76+
}
77+
78+
// copyMap return a new array with all the elements in the src array
79+
func copyArray[T any](src []T) []T {
80+
if src == nil {
81+
return nil
82+
}
83+
copiedArray := make([]T, len(src))
84+
copy(copiedArray, src)
85+
return copiedArray
86+
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
//go:build go1.18
2+
// +build go1.18
3+
4+
// Copyright (c) Microsoft Corporation. All rights reserved.
5+
// Licensed under the MIT License.
6+
7+
package policy
8+
9+
import (
10+
"fmt"
11+
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
12+
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
13+
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
14+
"github.com/stretchr/testify/require"
15+
"testing"
16+
)
17+
18+
func TestClientOptions_Copy(t *testing.T) {
19+
var option *ClientOptions
20+
require.Nil(t, option.Clone())
21+
22+
option = &ClientOptions{ClientOptions: policy.ClientOptions{
23+
Cloud: cloud.AzurePublic,
24+
Logging: policy.LogOptions{
25+
AllowedHeaders: []string{"test1", "test2"},
26+
AllowedQueryParams: []string{"test1", "test2"},
27+
},
28+
Retry: policy.RetryOptions{StatusCodes: []int{1, 2}},
29+
PerRetryPolicies: []policy.Policy{runtime.NewLogPolicy(nil)},
30+
PerCallPolicies: []policy.Policy{runtime.NewLogPolicy(nil)},
31+
}}
32+
copiedOption := option.Clone()
33+
require.Equal(t, option.APIVersion, copiedOption.APIVersion)
34+
require.NotEqual(t, fmt.Sprintf("%p", &option.APIVersion), fmt.Sprintf("%p", &copiedOption.APIVersion))
35+
require.Equal(t, option.Cloud.Services, copiedOption.Cloud.Services)
36+
require.NotEqual(t, fmt.Sprintf("%p", option.Cloud.Services), fmt.Sprintf("%p", copiedOption.Cloud.Services))
37+
require.Equal(t, option.Logging.AllowedHeaders, copiedOption.Logging.AllowedHeaders)
38+
require.NotEqual(t, fmt.Sprintf("%p", option.Logging.AllowedHeaders), fmt.Sprintf("%p", copiedOption.Logging.AllowedHeaders))
39+
require.Equal(t, option.Logging.AllowedQueryParams, copiedOption.Logging.AllowedQueryParams)
40+
require.NotEqual(t, fmt.Sprintf("%p", option.Logging.AllowedQueryParams), fmt.Sprintf("%p", copiedOption.Logging.AllowedQueryParams))
41+
require.Equal(t, option.Retry.StatusCodes, copiedOption.Retry.StatusCodes)
42+
require.NotEqual(t, fmt.Sprintf("%p", option.Retry.StatusCodes), fmt.Sprintf("%p", copiedOption.Retry.StatusCodes))
43+
require.Equal(t, option.PerRetryPolicies, copiedOption.PerRetryPolicies)
44+
require.NotEqual(t, fmt.Sprintf("%p", option.PerRetryPolicies), fmt.Sprintf("%p", copiedOption.PerRetryPolicies))
45+
require.Equal(t, option.PerCallPolicies, copiedOption.PerCallPolicies)
46+
require.NotEqual(t, fmt.Sprintf("%p", option.PerCallPolicies), fmt.Sprintf("%p", copiedOption.PerCallPolicies))
47+
}

sdk/azcore/arm/runtime/pipeline.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ func NewPipeline(module, version string, cred azcore.TokenCredential, plOpts azr
2929
return azruntime.Pipeline{}, err
3030
}
3131
authPolicy := NewBearerTokenPolicy(cred, &armpolicy.BearerTokenOptions{Scopes: []string{conf.Audience + "/.default"}})
32-
perRetry := make([]azpolicy.Policy, 0, len(plOpts.PerRetry)+1)
32+
perRetry := make([]azpolicy.Policy, len(plOpts.PerRetry), len(plOpts.PerRetry)+1)
3333
copy(perRetry, plOpts.PerRetry)
3434
plOpts.PerRetry = append(perRetry, authPolicy)
3535
if !options.DisableRPRegistration {
@@ -38,7 +38,7 @@ func NewPipeline(module, version string, cred azcore.TokenCredential, plOpts azr
3838
if err != nil {
3939
return azruntime.Pipeline{}, err
4040
}
41-
perCall := make([]azpolicy.Policy, 0, len(plOpts.PerCall)+1)
41+
perCall := make([]azpolicy.Policy, len(plOpts.PerCall), len(plOpts.PerCall)+1)
4242
copy(perCall, plOpts.PerCall)
4343
plOpts.PerCall = append(perCall, regPolicy)
4444
}

0 commit comments

Comments
 (0)