Skip to content

Commit f567b91

Browse files
authored
fix(server): asset cors middleware (#1810)
1 parent df2d665 commit f567b91

File tree

3 files changed

+29
-3
lines changed

3 files changed

+29
-3
lines changed

server/internal/app/internal/middleware/assets_cors_middleware.go

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package middleware
22

33
import (
44
"net/http"
5+
"net/url"
6+
"strings"
57

68
"github.com/labstack/echo/v4"
79
"github.com/reearth/reearth/server/internal/usecase/gateway"
@@ -23,8 +25,13 @@ func AssetsCORSMiddleware(domainChecker gateway.DomainChecker, allowedOrigins []
2325
}
2426
}
2527
if allowedOrigin == "" {
28+
domain, err := extractDomain(origin)
29+
if err != nil {
30+
log.Errorfc(c.Request().Context(), "[AssetsCORSMiddleware] extract domain err: %v", err)
31+
return next(c)
32+
}
2633
domainResp, err := domainChecker.CheckDomain(c.Request().Context(), gateway.DomainCheckRequest{
27-
Domain: origin,
34+
Domain: domain,
2835
})
2936
if err != nil {
3037
log.Errorfc(c.Request().Context(), "[AssetsCORSMiddleware] domain checker check domain err: %v", err)
@@ -46,3 +53,17 @@ func AssetsCORSMiddleware(domainChecker gateway.DomainChecker, allowedOrigins []
4653
}
4754
}
4855
}
56+
57+
func extractDomain(raw string) (string, error) {
58+
u, err := url.Parse(raw)
59+
if err != nil {
60+
return "", err
61+
}
62+
63+
host := u.Host
64+
if strings.Contains(host, ":") {
65+
host = strings.Split(host, ":")[0]
66+
}
67+
68+
return host, nil
69+
}

server/internal/app/internal/middleware/assets_cors_middleware_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ func TestAssetsCORSMiddleware(t *testing.T) {
7777

7878
mockChecker := &mockDomainChecker{
7979
checkFunc: func(ctx context.Context, req gateway.DomainCheckRequest) (*gateway.DomainCheckResponse, error) {
80-
if req.Domain == "https://custom.com" {
80+
if req.Domain == "custom.com" {
8181
return &gateway.DomainCheckResponse{Allowed: true}, nil
8282
}
8383
return &gateway.DomainCheckResponse{Allowed: false}, nil
@@ -468,7 +468,7 @@ func TestAssetsCORSMiddleware_EdgeCases(t *testing.T) {
468468

469469
mockChecker := &mockDomainChecker{
470470
checkFunc: func(ctx context.Context, req gateway.DomainCheckRequest) (*gateway.DomainCheckResponse, error) {
471-
if req.Domain == "https://example.com/" {
471+
if req.Domain == "example.com" {
472472
return &gateway.DomainCheckResponse{Allowed: true}, nil
473473
}
474474
return &gateway.DomainCheckResponse{Allowed: false}, nil

server/internal/infrastructure/domain/http_checker.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,18 @@ func (h *HTTPDomainChecker) CheckDomain(ctx context.Context, req gateway.DomainC
4545
}
4646

4747
resp, err := h.client.Do(httpReq)
48+
4849
if err != nil {
4950
return nil, rerror.ErrInternalBy(fmt.Errorf("domain check request failed: %w", err))
5051
}
5152
defer func() {
5253
_ = resp.Body.Close()
5354
}()
5455

56+
if resp.StatusCode == http.StatusNotFound {
57+
return &gateway.DomainCheckResponse{Allowed: false}, nil
58+
}
59+
5560
if resp.StatusCode != http.StatusOK {
5661
return nil, rerror.ErrInternalBy(fmt.Errorf("domain check returned status %d", resp.StatusCode))
5762
}

0 commit comments

Comments
 (0)