Skip to content

Commit 1d56ec5

Browse files
committed
feat: cors middleware supports custom settings
1 parent d0557f8 commit 1d56ec5

File tree

2 files changed

+359
-11
lines changed

2 files changed

+359
-11
lines changed

pkg/gin/middleware/cors.go

Lines changed: 110 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,115 @@ import (
77
"github.com/gin-gonic/gin"
88
)
99

10+
type CoresConfig = cors.Config
11+
12+
// CoresOption set coresOptions.
13+
type CoresOption func(*coresOptions)
14+
15+
type coresOptions struct {
16+
newCoresConfig *CoresConfig // if nil, use default config under fields.
17+
18+
allowOrigins []string
19+
allowMethods []string
20+
allowHeaders []string
21+
exposeHeaders []string
22+
maxAge time.Duration
23+
allowWildcard bool
24+
allowCredentials bool
25+
}
26+
27+
func (o *coresOptions) apply(opts ...CoresOption) {
28+
for _, opt := range opts {
29+
opt(o)
30+
}
31+
}
32+
33+
func defaultCoreOptions() *coresOptions {
34+
return &coresOptions{
35+
allowOrigins: []string{"*"},
36+
allowMethods: []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"},
37+
allowHeaders: []string{"Origin", "Authorization", "Content-Type", "Accept", "X-Requested-With", "X-CSRF-Token"},
38+
exposeHeaders: []string{"Content-Length", "text/plain", "Authorization", "Content-Type"},
39+
allowCredentials: true,
40+
allowWildcard: true,
41+
maxAge: 12 * time.Hour,
42+
}
43+
}
44+
45+
// WithNewConfig set cors config, if nil, use default config under fields.
46+
func WithNewConfig(config *CoresConfig) CoresOption {
47+
return func(o *coresOptions) {
48+
o.newCoresConfig = config
49+
}
50+
}
51+
52+
// WithAllowOrigins set allowOrigins, e.g. "https://yourdomain.com", "https://*.subdomain.com"
53+
func WithAllowOrigins(allowOrigins ...string) CoresOption {
54+
return func(o *coresOptions) {
55+
o.allowOrigins = allowOrigins
56+
}
57+
}
58+
59+
// WithAllowMethods set allowMethods, e.g. "GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"
60+
func WithAllowMethods(allowMethods ...string) CoresOption {
61+
return func(o *coresOptions) {
62+
o.allowMethods = allowMethods
63+
}
64+
}
65+
66+
// WithAllowHeaders set allowHeaders, e.g. "Origin", "Authorization", "Content-Type", "Accept"
67+
func WithAllowHeaders(allowHeaders ...string) CoresOption {
68+
return func(o *coresOptions) {
69+
o.allowHeaders = allowHeaders
70+
}
71+
}
72+
73+
// WithExposeHeaders set exposeHeaders
74+
func WithExposeHeaders(exposeHeaders ...string) CoresOption {
75+
return func(o *coresOptions) {
76+
o.exposeHeaders = exposeHeaders
77+
}
78+
}
79+
80+
// WithMaxAge set maxAge
81+
func WithMaxAge(maxAge time.Duration) CoresOption {
82+
return func(o *coresOptions) {
83+
o.maxAge = maxAge
84+
}
85+
}
86+
87+
// WithAllowCredentials set allowCredentials
88+
func WithAllowCredentials(allowCredentials bool) CoresOption {
89+
return func(o *coresOptions) {
90+
o.allowCredentials = allowCredentials
91+
}
92+
}
93+
94+
// WithAllowWildcard set allowWildcard
95+
func WithAllowWildcard(allowWildcard bool) CoresOption {
96+
return func(o *coresOptions) {
97+
o.allowWildcard = allowWildcard
98+
}
99+
}
100+
10101
// Cors cross domain
11-
func Cors() gin.HandlerFunc {
12-
return cors.New(
13-
cors.Config{
14-
AllowOrigins: []string{"*"},
15-
AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"},
16-
AllowHeaders: []string{"Origin", "Authorization", "Content-Type", "Accept"},
17-
ExposeHeaders: []string{"Content-Length", "text/plain", "Authorization", "Content-Type"},
18-
AllowCredentials: true,
19-
MaxAge: 12 * time.Hour,
20-
},
21-
)
102+
func Cors(opts ...CoresOption) gin.HandlerFunc {
103+
o := defaultCoreOptions()
104+
o.apply(opts...)
105+
106+
var corsConfig cors.Config
107+
if o.newCoresConfig != nil {
108+
corsConfig = *o.newCoresConfig
109+
} else {
110+
corsConfig = cors.Config{}
111+
corsConfig.AllowOrigins = o.allowOrigins
112+
corsConfig.AllowMethods = o.allowMethods
113+
corsConfig.AllowHeaders = o.allowHeaders
114+
corsConfig.ExposeHeaders = o.exposeHeaders
115+
corsConfig.AllowCredentials = o.allowCredentials
116+
corsConfig.AllowWildcard = o.allowWildcard
117+
corsConfig.MaxAge = o.maxAge
118+
}
119+
120+
return cors.New(corsConfig)
22121
}

pkg/gin/middleware/cors_test.go

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
package middleware
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
"time"
8+
9+
"github.com/gin-gonic/gin"
10+
"github.com/stretchr/testify/assert"
11+
)
12+
13+
func TestDefaultCoreOptions(t *testing.T) {
14+
opts := defaultCoreOptions()
15+
16+
assert.NotNil(t, opts)
17+
assert.Equal(t, []string{"*"}, opts.allowOrigins)
18+
assert.Equal(t, []string{"GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"}, opts.allowMethods)
19+
assert.Equal(t, []string{"Origin", "Authorization", "Content-Type", "Accept", "X-Requested-With", "X-CSRF-Token"}, opts.allowHeaders)
20+
assert.Equal(t, []string{"Content-Length", "text/plain", "Authorization", "Content-Type"}, opts.exposeHeaders)
21+
assert.True(t, opts.allowCredentials)
22+
assert.True(t, opts.allowWildcard)
23+
assert.Equal(t, 12*time.Hour, opts.maxAge)
24+
}
25+
26+
func TestCorsWithDefaultOptions(t *testing.T) {
27+
router := gin.New()
28+
router.Use(Cors())
29+
router.GET("/test", func(c *gin.Context) {
30+
c.String(200, "test")
31+
})
32+
33+
w := httptest.NewRecorder()
34+
req, _ := http.NewRequest("GET", "/test", nil)
35+
router.ServeHTTP(w, req)
36+
37+
assert.Equal(t, 200, w.Code)
38+
assert.Equal(t, "test", w.Body.String())
39+
}
40+
41+
func TestCorsWithCustomConfig(t *testing.T) {
42+
customConfig := &CoresConfig{
43+
AllowOrigins: []string{"https://example.com"},
44+
AllowMethods: []string{"GET", "POST"},
45+
AllowHeaders: []string{"Authorization"},
46+
ExposeHeaders: []string{"Content-Length"},
47+
AllowCredentials: false,
48+
MaxAge: 1 * time.Hour,
49+
AllowWildcard: true,
50+
}
51+
52+
router := gin.New()
53+
router.Use(Cors(WithNewConfig(customConfig)))
54+
router.OPTIONS("/test", func(c *gin.Context) {
55+
c.Status(200)
56+
})
57+
58+
w := httptest.NewRecorder()
59+
req, _ := http.NewRequest("OPTIONS", "/test", nil)
60+
req.Header.Set("Origin", "https://example.com")
61+
req.Header.Set("Access-Control-Request-Method", "GET")
62+
router.ServeHTTP(w, req)
63+
64+
// CORS middleware returns 204 for preflight requests
65+
assert.Equal(t, 204, w.Code)
66+
}
67+
68+
func TestCorsWithAllowOrigins(t *testing.T) {
69+
router := gin.New()
70+
router.Use(Cors(WithAllowOrigins("https://example.com", "https://test.com")))
71+
router.OPTIONS("/test", func(c *gin.Context) {
72+
c.Status(200)
73+
})
74+
75+
w := httptest.NewRecorder()
76+
req, _ := http.NewRequest("OPTIONS", "/test", nil)
77+
req.Header.Set("Origin", "https://example.com")
78+
req.Header.Set("Access-Control-Request-Method", "GET")
79+
router.ServeHTTP(w, req)
80+
81+
// CORS middleware returns 204 for preflight requests
82+
assert.Equal(t, 204, w.Code)
83+
}
84+
85+
func TestCorsWithAllowMethods(t *testing.T) {
86+
router := gin.New()
87+
router.Use(Cors(WithAllowMethods("GET", "POST")))
88+
router.OPTIONS("/test", func(c *gin.Context) {
89+
c.Status(200)
90+
})
91+
92+
w := httptest.NewRecorder()
93+
req, _ := http.NewRequest("OPTIONS", "/test", nil)
94+
router.ServeHTTP(w, req)
95+
96+
assert.Equal(t, 200, w.Code)
97+
}
98+
99+
func TestCorsWithAllowHeaders(t *testing.T) {
100+
router := gin.New()
101+
router.Use(Cors(WithAllowHeaders("X-Custom-Header", "X-Another-Header")))
102+
router.OPTIONS("/test", func(c *gin.Context) {
103+
c.Status(200)
104+
})
105+
106+
w := httptest.NewRecorder()
107+
req, _ := http.NewRequest("OPTIONS", "/test", nil)
108+
router.ServeHTTP(w, req)
109+
110+
assert.Equal(t, 200, w.Code)
111+
}
112+
113+
func TestCorsWithExposeHeaders(t *testing.T) {
114+
router := gin.New()
115+
router.Use(Cors(WithExposeHeaders("X-Custom-Header", "X-Another-Header")))
116+
router.GET("/test", func(c *gin.Context) {
117+
c.String(200, "test")
118+
})
119+
120+
w := httptest.NewRecorder()
121+
req, _ := http.NewRequest("GET", "/test", nil)
122+
router.ServeHTTP(w, req)
123+
124+
assert.Equal(t, 200, w.Code)
125+
}
126+
127+
func TestCorsWithMaxAge(t *testing.T) {
128+
router := gin.New()
129+
router.Use(Cors(WithMaxAge(30 * time.Minute)))
130+
router.OPTIONS("/test", func(c *gin.Context) {
131+
c.Status(200)
132+
})
133+
134+
w := httptest.NewRecorder()
135+
req, _ := http.NewRequest("OPTIONS", "/test", nil)
136+
router.ServeHTTP(w, req)
137+
138+
assert.Equal(t, 200, w.Code)
139+
}
140+
141+
func TestCorsWithAllowCredentials(t *testing.T) {
142+
router := gin.New()
143+
router.Use(Cors(WithAllowCredentials(false)))
144+
router.GET("/test", func(c *gin.Context) {
145+
c.String(200, "test")
146+
})
147+
148+
w := httptest.NewRecorder()
149+
req, _ := http.NewRequest("GET", "/test", nil)
150+
router.ServeHTTP(w, req)
151+
152+
assert.Equal(t, 200, w.Code)
153+
}
154+
155+
func TestCorsWithAllowWildcard(t *testing.T) {
156+
router := gin.New()
157+
router.Use(Cors(WithAllowCredentials(false)))
158+
router.GET("/test", func(c *gin.Context) {
159+
c.String(200, "test")
160+
})
161+
162+
w := httptest.NewRecorder()
163+
req, _ := http.NewRequest("GET", "/test", nil)
164+
router.ServeHTTP(w, req)
165+
166+
assert.Equal(t, 200, w.Code)
167+
}
168+
169+
func TestCorsWithMultipleOptions(t *testing.T) {
170+
router := gin.New()
171+
router.Use(Cors(
172+
WithAllowOrigins("https://example.com"),
173+
WithAllowMethods("GET", "POST"),
174+
WithAllowHeaders("X-Custom-Header"),
175+
WithExposeHeaders("X-Exposed-Header"),
176+
WithMaxAge(30*time.Minute),
177+
WithAllowCredentials(false),
178+
))
179+
router.OPTIONS("/test", func(c *gin.Context) {
180+
c.Status(200)
181+
})
182+
183+
w := httptest.NewRecorder()
184+
req, _ := http.NewRequest("OPTIONS", "/test", nil)
185+
req.Header.Set("Origin", "https://example.com")
186+
req.Header.Set("Access-Control-Request-Method", "GET")
187+
router.ServeHTTP(w, req)
188+
189+
// CORS middleware returns 204 for preflight requests
190+
assert.Equal(t, 204, w.Code)
191+
}
192+
193+
func TestCorsWithNilConfig(t *testing.T) {
194+
router := gin.New()
195+
router.Use(Cors(WithNewConfig(nil)))
196+
router.GET("/test", func(c *gin.Context) {
197+
c.String(200, "test")
198+
})
199+
200+
w := httptest.NewRecorder()
201+
req, _ := http.NewRequest("GET", "/test", nil)
202+
router.ServeHTTP(w, req)
203+
204+
assert.Equal(t, 200, w.Code)
205+
}
206+
207+
func TestCorsOptionsApply(t *testing.T) {
208+
opts := &coresOptions{}
209+
opt1 := WithAllowOrigins("https://example.com")
210+
opt2 := WithAllowMethods("GET")
211+
212+
opts.apply(opt1, opt2)
213+
214+
assert.Equal(t, []string{"https://example.com"}, opts.allowOrigins)
215+
assert.Equal(t, []string{"GET"}, opts.allowMethods)
216+
}
217+
218+
func TestCorsPreflightRequest(t *testing.T) {
219+
router := gin.New()
220+
router.Use(Cors(WithAllowOrigins("https://example.com")))
221+
router.OPTIONS("/test", func(c *gin.Context) {
222+
c.Status(200)
223+
})
224+
225+
w := httptest.NewRecorder()
226+
req, _ := http.NewRequest("OPTIONS", "/test", nil)
227+
req.Header.Set("Origin", "https://example.com")
228+
req.Header.Set("Access-Control-Request-Method", "GET")
229+
router.ServeHTTP(w, req)
230+
231+
// CORS middleware returns 204 for preflight requests
232+
assert.Equal(t, 204, w.Code)
233+
}
234+
235+
func TestCorsActualRequest(t *testing.T) {
236+
router := gin.New()
237+
router.Use(Cors(WithAllowOrigins("https://example.com")))
238+
router.GET("/test", func(c *gin.Context) {
239+
c.String(200, "test")
240+
})
241+
242+
w := httptest.NewRecorder()
243+
req, _ := http.NewRequest("GET", "/test", nil)
244+
req.Header.Set("Origin", "https://example.com")
245+
router.ServeHTTP(w, req)
246+
247+
assert.Equal(t, 200, w.Code)
248+
assert.Equal(t, "test", w.Body.String())
249+
}

0 commit comments

Comments
 (0)