From ef50fa4b4e65914abe8c1a6e2a86a90e78e91452 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 11 Jul 2025 17:53:13 +0100 Subject: [PATCH 1/2] Add unit tests for internal test utilities --- internal/test/checks/checks_test.go | 17 +++++++ internal/test/failer_test.go | 20 ++++++++ internal/test/helpers_test.go | 50 ++++++++++++++++++++ internal/test/server_test.go | 73 +++++++++++++++++++++++++++++ 4 files changed, 160 insertions(+) create mode 100644 internal/test/checks/checks_test.go create mode 100644 internal/test/failer_test.go create mode 100644 internal/test/helpers_test.go create mode 100644 internal/test/server_test.go diff --git a/internal/test/checks/checks_test.go b/internal/test/checks/checks_test.go new file mode 100644 index 000000000..c27db68a0 --- /dev/null +++ b/internal/test/checks/checks_test.go @@ -0,0 +1,17 @@ +package checks + +import ( + "errors" + "testing" +) + +func TestChecksSuccessPaths(t *testing.T) { + NoError(t, nil) + NoErrorF(t, nil) + HasError(t, errors.New("err")) + target := errors.New("x") + ErrorIs(t, target, target) + ErrorIsF(t, target, target, "msg") + ErrorIsNot(t, errors.New("y"), target) + ErrorIsNotf(t, errors.New("y"), target, "msg") +} diff --git a/internal/test/failer_test.go b/internal/test/failer_test.go new file mode 100644 index 000000000..be95b7fab --- /dev/null +++ b/internal/test/failer_test.go @@ -0,0 +1,20 @@ +package test + +import "testing" + +func TestFailingErrorBuffer(t *testing.T) { + buf := &FailingErrorBuffer{} + n, err := buf.Write([]byte("test")) + if err != ErrTestErrorAccumulatorWriteFailed { + t.Fatalf("expected %v, got %v", ErrTestErrorAccumulatorWriteFailed, err) + } + if n != 0 { + t.Fatalf("expected n=0, got %d", n) + } + if buf.Len() != 0 { + t.Fatalf("expected Len 0, got %d", buf.Len()) + } + if len(buf.Bytes()) != 0 { + t.Fatalf("expected empty bytes") + } +} diff --git a/internal/test/helpers_test.go b/internal/test/helpers_test.go new file mode 100644 index 000000000..c8623b142 --- /dev/null +++ b/internal/test/helpers_test.go @@ -0,0 +1,50 @@ +package test + +import ( + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" +) + +func TestCreateTestFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "file.txt") + CreateTestFile(t, path) + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("failed to read created file: %v", err) + } + if string(data) != "hello" { + t.Fatalf("unexpected file contents: %q", string(data)) + } +} + +func TestTokenRoundTripperAddsHeader(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "Bearer "+GetTestToken() { + t.Fatalf("authorization header not set") + } + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + client := srv.Client() + client.Transport = &TokenRoundTripper{Token: GetTestToken(), Fallback: client.Transport} + + req, err := http.NewRequest(http.MethodGet, srv.URL, nil) + if err != nil { + t.Fatalf("request error: %v", err) + } + resp, err := client.Do(req) + if err != nil { + t.Fatalf("client request error: %v", err) + } + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } +} diff --git a/internal/test/server_test.go b/internal/test/server_test.go new file mode 100644 index 000000000..367ee53c5 --- /dev/null +++ b/internal/test/server_test.go @@ -0,0 +1,73 @@ +package test + +import ( + "io" + "net/http" + "testing" +) + +func TestGetTestToken(t *testing.T) { + if GetTestToken() != testAPI { + t.Fatalf("unexpected token") + } +} + +func TestNewTestServer(t *testing.T) { + ts := NewTestServer() + if ts == nil || ts.handlers == nil { + t.Fatalf("server not properly initialized") + } + if len(ts.handlers) != 0 { + t.Fatalf("expected no handlers initially") + } +} + +func TestRegisterHandlerTransformsPath(t *testing.T) { + ts := NewTestServer() + h := func(w http.ResponseWriter, r *http.Request) {} + ts.RegisterHandler("/foo/*", h) + if ts.handlers["/foo/.*"] == nil { + t.Fatalf("handler not registered with transformed path") + } +} + +func TestOpenAITestServer(t *testing.T) { + ts := NewTestServer() + ts.RegisterHandler("/v1/test/*", func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "ok") + }) + srv := ts.OpenAITestServer() + srv.Start() + defer srv.Close() + + base := srv.Client().Transport + client := &http.Client{Transport: &TokenRoundTripper{Token: GetTestToken(), Fallback: base}} + resp, err := client.Get(srv.URL + "/v1/test/123") + if err != nil { + t.Fatalf("request error: %v", err) + } + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + if resp.StatusCode != http.StatusOK || string(body) != "ok" { + t.Fatalf("unexpected response: %d %q", resp.StatusCode, string(body)) + } + + // unregistered path + resp, err = client.Get(srv.URL + "/unknown") + if err != nil { + t.Fatalf("request error: %v", err) + } + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("expected 404, got %d", resp.StatusCode) + } + + // missing token should return unauthorized + clientNoToken := srv.Client() + resp, err = clientNoToken.Get(srv.URL + "/v1/test/123") + if err != nil { + t.Fatalf("request error: %v", err) + } + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", resp.StatusCode) + } +} From 3b9c5f25affd63aabb1caa780b6540a6b6bdee80 Mon Sep 17 00:00:00 2001 From: Alex Baranov <677093+sashabaranov@users.noreply.github.com> Date: Fri, 11 Jul 2025 18:24:09 +0100 Subject: [PATCH 2/2] Fix lint issues in internal tests --- internal/test/checks/checks_test.go | 18 +++++++++------- internal/test/failer_test.go | 8 +++++-- internal/test/helpers_test.go | 14 +++++++----- internal/test/server.go | 12 +++++++++++ internal/test/server_test.go | 33 +++++++++++++++++------------ 5 files changed, 57 insertions(+), 28 deletions(-) diff --git a/internal/test/checks/checks_test.go b/internal/test/checks/checks_test.go index c27db68a0..0677054df 100644 --- a/internal/test/checks/checks_test.go +++ b/internal/test/checks/checks_test.go @@ -1,17 +1,19 @@ -package checks +package checks_test import ( "errors" "testing" + + "github.com/sashabaranov/go-openai/internal/test/checks" ) func TestChecksSuccessPaths(t *testing.T) { - NoError(t, nil) - NoErrorF(t, nil) - HasError(t, errors.New("err")) + checks.NoError(t, nil) + checks.NoErrorF(t, nil) + checks.HasError(t, errors.New("err")) target := errors.New("x") - ErrorIs(t, target, target) - ErrorIsF(t, target, target, "msg") - ErrorIsNot(t, errors.New("y"), target) - ErrorIsNotf(t, errors.New("y"), target, "msg") + checks.ErrorIs(t, target, target) + checks.ErrorIsF(t, target, target, "msg") + checks.ErrorIsNot(t, errors.New("y"), target) + checks.ErrorIsNotf(t, errors.New("y"), target, "msg") } diff --git a/internal/test/failer_test.go b/internal/test/failer_test.go index be95b7fab..fb1f4bf06 100644 --- a/internal/test/failer_test.go +++ b/internal/test/failer_test.go @@ -1,11 +1,15 @@ +//nolint:testpackage // need access to unexported fields and types for testing package test -import "testing" +import ( + "errors" + "testing" +) func TestFailingErrorBuffer(t *testing.T) { buf := &FailingErrorBuffer{} n, err := buf.Write([]byte("test")) - if err != ErrTestErrorAccumulatorWriteFailed { + if !errors.Is(err, ErrTestErrorAccumulatorWriteFailed) { t.Fatalf("expected %v, got %v", ErrTestErrorAccumulatorWriteFailed, err) } if n != 0 { diff --git a/internal/test/helpers_test.go b/internal/test/helpers_test.go index c8623b142..aa177679b 100644 --- a/internal/test/helpers_test.go +++ b/internal/test/helpers_test.go @@ -1,4 +1,4 @@ -package test +package test_test import ( "io" @@ -7,12 +7,14 @@ import ( "os" "path/filepath" "testing" + + internaltest "github.com/sashabaranov/go-openai/internal/test" ) func TestCreateTestFile(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "file.txt") - CreateTestFile(t, path) + internaltest.CreateTestFile(t, path) data, err := os.ReadFile(path) if err != nil { t.Fatalf("failed to read created file: %v", err) @@ -24,7 +26,7 @@ func TestCreateTestFile(t *testing.T) { func TestTokenRoundTripperAddsHeader(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Authorization") != "Bearer "+GetTestToken() { + if r.Header.Get("Authorization") != "Bearer "+internaltest.GetTestToken() { t.Fatalf("authorization header not set") } w.WriteHeader(http.StatusOK) @@ -32,7 +34,7 @@ func TestTokenRoundTripperAddsHeader(t *testing.T) { defer srv.Close() client := srv.Client() - client.Transport = &TokenRoundTripper{Token: GetTestToken(), Fallback: client.Transport} + client.Transport = &internaltest.TokenRoundTripper{Token: internaltest.GetTestToken(), Fallback: client.Transport} req, err := http.NewRequest(http.MethodGet, srv.URL, nil) if err != nil { @@ -42,7 +44,9 @@ func TestTokenRoundTripperAddsHeader(t *testing.T) { if err != nil { t.Fatalf("client request error: %v", err) } - io.Copy(io.Discard, resp.Body) + if _, err = io.Copy(io.Discard, resp.Body); err != nil { + t.Fatalf("read body: %v", err) + } resp.Body.Close() if resp.StatusCode != http.StatusOK { t.Fatalf("unexpected status: %d", resp.StatusCode) diff --git a/internal/test/server.go b/internal/test/server.go index 127d4c16f..d32c3e4cb 100644 --- a/internal/test/server.go +++ b/internal/test/server.go @@ -23,6 +23,18 @@ func NewTestServer() *ServerTest { return &ServerTest{handlers: make(map[string]handler)} } +// HandlerCount returns the number of registered handlers. +func (ts *ServerTest) HandlerCount() int { + return len(ts.handlers) +} + +// HasHandler checks if a handler was registered for the given path. +func (ts *ServerTest) HasHandler(path string) bool { + path = strings.ReplaceAll(path, "*", ".*") + _, ok := ts.handlers[path] + return ok +} + func (ts *ServerTest) RegisterHandler(path string, handler handler) { // to make the registered paths friendlier to a regex match in the route handler // in OpenAITestServer diff --git a/internal/test/server_test.go b/internal/test/server_test.go index 367ee53c5..f8ce731d1 100644 --- a/internal/test/server_test.go +++ b/internal/test/server_test.go @@ -1,53 +1,60 @@ -package test +package test_test import ( "io" "net/http" "testing" + + internaltest "github.com/sashabaranov/go-openai/internal/test" ) func TestGetTestToken(t *testing.T) { - if GetTestToken() != testAPI { + if internaltest.GetTestToken() != "this-is-my-secure-token-do-not-steal!!" { t.Fatalf("unexpected token") } } func TestNewTestServer(t *testing.T) { - ts := NewTestServer() - if ts == nil || ts.handlers == nil { + ts := internaltest.NewTestServer() + if ts == nil { t.Fatalf("server not properly initialized") } - if len(ts.handlers) != 0 { + if ts.HandlerCount() != 0 { t.Fatalf("expected no handlers initially") } } func TestRegisterHandlerTransformsPath(t *testing.T) { - ts := NewTestServer() - h := func(w http.ResponseWriter, r *http.Request) {} + ts := internaltest.NewTestServer() + h := func(_ http.ResponseWriter, _ *http.Request) {} ts.RegisterHandler("/foo/*", h) - if ts.handlers["/foo/.*"] == nil { + if !ts.HasHandler("/foo/*") { t.Fatalf("handler not registered with transformed path") } } func TestOpenAITestServer(t *testing.T) { - ts := NewTestServer() - ts.RegisterHandler("/v1/test/*", func(w http.ResponseWriter, r *http.Request) { - io.WriteString(w, "ok") + ts := internaltest.NewTestServer() + ts.RegisterHandler("/v1/test/*", func(w http.ResponseWriter, _ *http.Request) { + if _, err := io.WriteString(w, "ok"); err != nil { + t.Fatalf("write: %v", err) + } }) srv := ts.OpenAITestServer() srv.Start() defer srv.Close() base := srv.Client().Transport - client := &http.Client{Transport: &TokenRoundTripper{Token: GetTestToken(), Fallback: base}} + client := &http.Client{Transport: &internaltest.TokenRoundTripper{Token: internaltest.GetTestToken(), Fallback: base}} resp, err := client.Get(srv.URL + "/v1/test/123") if err != nil { t.Fatalf("request error: %v", err) } - body, _ := io.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) resp.Body.Close() + if err != nil { + t.Fatalf("read response body: %v", err) + } if resp.StatusCode != http.StatusOK || string(body) != "ok" { t.Fatalf("unexpected response: %d %q", resp.StatusCode, string(body)) }