Skip to content

Commit 3e46b09

Browse files
authored
Fixed bug in runtime.SetMultipartFormData (Azure#19302)
Slices of io.ReadSeekCloser was improperly handled.
1 parent 46e8351 commit 3e46b09

File tree

3 files changed

+42
-6
lines changed

3 files changed

+42
-6
lines changed

sdk/azcore/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
### Breaking Changes
88

99
### Bugs Fixed
10+
* Fixed an issue in `runtime.SetMultipartFormData` to properly handle slices of `io.ReadSeekCloser`.
1011

1112
### Other Changes
1213

sdk/azcore/runtime/request.go

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -119,16 +119,30 @@ func MarshalAsXML(req *policy.Request, v interface{}) error {
119119
func SetMultipartFormData(req *policy.Request, formData map[string]interface{}) error {
120120
body := bytes.Buffer{}
121121
writer := multipart.NewWriter(&body)
122+
123+
writeContent := func(fieldname, filename string, src io.Reader) error {
124+
fd, err := writer.CreateFormFile(fieldname, filename)
125+
if err != nil {
126+
return err
127+
}
128+
// copy the data to the form file
129+
if _, err = io.Copy(fd, src); err != nil {
130+
return err
131+
}
132+
return nil
133+
}
134+
122135
for k, v := range formData {
123136
if rsc, ok := v.(io.ReadSeekCloser); ok {
124-
// this is the body to upload, the key is its file name
125-
fd, err := writer.CreateFormFile(k, k)
126-
if err != nil {
137+
if err := writeContent(k, k, rsc); err != nil {
127138
return err
128139
}
129-
// copy the data to the form file
130-
if _, err = io.Copy(fd, rsc); err != nil {
131-
return err
140+
continue
141+
} else if rscs, ok := v.([]io.ReadSeekCloser); ok {
142+
for _, rsc := range rscs {
143+
if err := writeContent(k, k, rsc); err != nil {
144+
return err
145+
}
132146
}
133147
continue
134148
}

sdk/azcore/runtime/request_test.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323

2424
"github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported"
2525
"github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared"
26+
"github.com/stretchr/testify/require"
2627
)
2728

2829
type testJSON struct {
@@ -635,6 +636,11 @@ func TestSetMultipartFormData(t *testing.T) {
635636
"string": "value",
636637
"int": 1,
637638
"data": exported.NopCloser(strings.NewReader("some data")),
639+
"datum": []io.ReadSeekCloser{
640+
exported.NopCloser(strings.NewReader("first part")),
641+
exported.NopCloser(strings.NewReader("second part")),
642+
exported.NopCloser(strings.NewReader("third part")),
643+
},
638644
})
639645
if err != nil {
640646
t.Fatal(err)
@@ -647,6 +653,7 @@ func TestSetMultipartFormData(t *testing.T) {
647653
t.Fatalf("unexpected media type %s", mt)
648654
}
649655
reader := multipart.NewReader(req.Raw().Body, params["boundary"])
656+
var datum []io.ReadSeekCloser
650657
for {
651658
part, err := reader.NextPart()
652659
if err == io.EOF {
@@ -682,8 +689,22 @@ func TestSetMultipartFormData(t *testing.T) {
682689
if tr := string(dataPart[:9]); tr != "some data" {
683690
t.Fatalf("unexpected value %s", tr)
684691
}
692+
case "datum":
693+
content, err := io.ReadAll(part)
694+
require.NoError(t, err)
695+
datum = append(datum, exported.NopCloser(bytes.NewReader(content)))
685696
default:
686697
t.Fatalf("unexpected part %s", fn)
687698
}
688699
}
700+
require.Len(t, datum, 3)
701+
first, err := io.ReadAll(datum[0])
702+
require.NoError(t, err)
703+
second, err := io.ReadAll(datum[1])
704+
require.NoError(t, err)
705+
third, err := io.ReadAll(datum[2])
706+
require.NoError(t, err)
707+
require.Equal(t, "first part", string(first))
708+
require.Equal(t, "second part", string(second))
709+
require.Equal(t, "third part", string(third))
689710
}

0 commit comments

Comments
 (0)