Skip to content

Commit f88c7b2

Browse files
authored
Prevent data races in internal/recording (Azure#19610)
1 parent afe4ecb commit f88c7b2

File tree

3 files changed

+64
-26
lines changed

3 files changed

+64
-26
lines changed

sdk/internal/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
### Bugs Fixed
1010

1111
### Other Changes
12+
* Prevented data races in `recording` ([#18763](https://github.com/Azure/azure-sdk-for-go/issues/18763))
1213

1314
## 1.1.1 (2022-11-09)
1415

sdk/internal/recording/recording.go

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"runtime"
2525
"strconv"
2626
"strings"
27+
"sync"
2728
"testing"
2829
"time"
2930

@@ -513,7 +514,27 @@ type recordedTest struct {
513514
variables map[string]interface{}
514515
}
515516

516-
var testSuite = map[string]recordedTest{}
517+
// testMap maps test names to metadata
518+
type testMap struct {
519+
m *sync.Map
520+
}
521+
522+
// Load returns the named test's metadata, if it has been stored
523+
func (t *testMap) Load(name string) (recordedTest, bool) {
524+
var rt recordedTest
525+
v, ok := t.m.Load(name)
526+
if ok {
527+
rt = v.(recordedTest)
528+
}
529+
return rt, ok
530+
}
531+
532+
// Store sets metadata for the named test
533+
func (t *testMap) Store(name string, data recordedTest) {
534+
t.m.Store(name, data)
535+
}
536+
537+
var testSuite = testMap{&sync.Map{}}
517538

518539
var client = http.Client{
519540
Transport: &http.Transport{
@@ -675,7 +696,7 @@ func Start(t *testing.T, pathToRecordings string, options *RecordingOptions) err
675696
return nil
676697
}
677698

678-
if testStruct, ok := testSuite[t.Name()]; ok {
699+
if testStruct, ok := testSuite.Load(t.Name()); ok {
679700
if testStruct.liveOnly {
680701
// test should only be run live, don't want to generate recording
681702
return nil
@@ -729,16 +750,16 @@ func Start(t *testing.T, pathToRecordings string, options *RecordingOptions) err
729750
}
730751
}
731752

732-
if val, ok := testSuite[t.Name()]; ok {
753+
if val, ok := testSuite.Load(t.Name()); ok {
733754
val.recordingId = recId
734755
val.variables = m
735-
testSuite[t.Name()] = val
756+
testSuite.Store(t.Name(), val)
736757
} else {
737-
testSuite[t.Name()] = recordedTest{
758+
testSuite.Store(t.Name(), recordedTest{
738759
recordingId: recId,
739760
liveOnly: false,
740761
variables: m,
741-
}
762+
})
742763
}
743764
return nil
744765
}
@@ -752,7 +773,7 @@ func Stop(t *testing.T, options *RecordingOptions) error {
752773
return nil
753774
}
754775

755-
if testStruct, ok := testSuite[t.Name()]; ok {
776+
if testStruct, ok := testSuite.Load(t.Name()); ok {
756777
if testStruct.liveOnly {
757778
// test should only be run live, don't want to generate recording
758779
return nil
@@ -776,7 +797,7 @@ func Stop(t *testing.T, options *RecordingOptions) error {
776797

777798
var recTest recordedTest
778799
var ok bool
779-
if recTest, ok = testSuite[t.Name()]; !ok {
800+
if recTest, ok = testSuite.Load(t.Name()); !ok {
780801
return errors.New("Recording ID was never set. Did you call StartRecording?")
781802
}
782803
req.Header.Set(IDHeader, recTest.recordingId)
@@ -803,11 +824,11 @@ func GetEnvVariable(varName string, recordedValue string) string {
803824
}
804825

805826
func LiveOnly(t *testing.T) {
806-
if val, ok := testSuite[t.Name()]; ok {
827+
if val, ok := testSuite.Load(t.Name()); ok {
807828
val.liveOnly = true
808-
testSuite[t.Name()] = val
829+
testSuite.Store(t.Name(), val)
809830
} else {
810-
testSuite[t.Name()] = recordedTest{liveOnly: true}
831+
testSuite.Store(t.Name(), recordedTest{liveOnly: true})
811832
}
812833
if GetRecordMode() == PlaybackMode {
813834
t.Skip("Live Test Only")
@@ -823,7 +844,7 @@ func Sleep(duration time.Duration) {
823844
}
824845

825846
func GetRecordingId(t *testing.T) string {
826-
if val, ok := testSuite[t.Name()]; ok {
847+
if val, ok := testSuite.Load(t.Name()); ok {
827848
return val.recordingId
828849
} else {
829850
return ""
@@ -890,15 +911,15 @@ func GetHTTPClient(t *testing.T) (*http.Client, error) {
890911
}
891912

892913
func IsLiveOnly(t *testing.T) bool {
893-
if s, ok := testSuite[t.Name()]; ok {
914+
if s, ok := testSuite.Load(t.Name()); ok {
894915
return s.liveOnly
895916
}
896917
return false
897918
}
898919

899920
// GetVariables returns access to the variables stored by the test proxy for a specific test
900921
func GetVariables(t *testing.T) map[string]interface{} {
901-
if s, ok := testSuite[t.Name()]; ok {
922+
if s, ok := testSuite.Load(t.Name()); ok {
902923
return s.variables
903924
}
904925
return nil

sdk/internal/recording/recording_test.go

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,15 @@ func (s *recordingTests) TearDownSuite() {
378378
}
379379
}
380380

381+
func TestGetEnvVariable(t *testing.T) {
382+
require.Equal(t, GetEnvVariable("Nonexistentevnvar", "somefakevalue"), "somefakevalue")
383+
temp := recordMode
384+
recordMode = RecordingMode
385+
t.Setenv("TEST_VARIABLE", "expected")
386+
require.Equal(t, "expected", GetEnvVariable("TEST_VARIABLE", "unexpected"))
387+
recordMode = temp
388+
}
389+
381390
func TestRecordingOptions(t *testing.T) {
382391
r := RecordingOptions{
383392
UseHTTPS: true,
@@ -386,18 +395,6 @@ func TestRecordingOptions(t *testing.T) {
386395

387396
r.UseHTTPS = false
388397
require.Equal(t, r.baseURL(), "http://localhost:5000")
389-
390-
require.Equal(t, GetEnvVariable("Nonexistentevnvar", "somefakevalue"), "somefakevalue")
391-
temp := recordMode
392-
recordMode = RecordingMode
393-
require.NotEqual(t, GetEnvVariable("PROXY_CERT", "fake/path/to/proxycert"), "fake/path/to/proxycert")
394-
recordMode = temp
395-
396-
r.UseHTTPS = false
397-
require.Equal(t, r.baseURL(), "http://localhost:5000")
398-
399-
r.UseHTTPS = true
400-
require.Equal(t, r.baseURL(), "https://localhost:5001")
401398
}
402399

403400
var packagePath = "sdk/internal/recording/testdata"
@@ -677,3 +674,22 @@ func TestVariables(t *testing.T) {
677674
require.NoError(t, err)
678675
}()
679676
}
677+
678+
func TestRace(t *testing.T) {
679+
temp := recordMode
680+
recordMode = LiveMode
681+
t.Cleanup(func() { recordMode = temp })
682+
for i := 0; i < 4; i++ {
683+
t.Run("", func(t *testing.T) {
684+
t.Parallel()
685+
err := Start(t, "", nil)
686+
require.NoError(t, err)
687+
GetRecordingId(t)
688+
GetVariables(t)
689+
IsLiveOnly(t)
690+
err = Stop(t, nil)
691+
require.NoError(t, err)
692+
LiveOnly(t)
693+
})
694+
}
695+
}

0 commit comments

Comments
 (0)