diff --git a/internal/test/checks/checks_test.go b/internal/test/checks/checks_test.go new file mode 100644 index 000000000..0677054df --- /dev/null +++ b/internal/test/checks/checks_test.go @@ -0,0 +1,19 @@ +package checks_test + +import ( + "errors" + "testing" + + "github.com/sashabaranov/go-openai/internal/test/checks" +) + +func TestChecksSuccessPaths(t *testing.T) { + checks.NoError(t, nil) + checks.NoErrorF(t, nil) + checks.HasError(t, errors.New("err")) + target := errors.New("x") + checks.ErrorIs(t, target, target) + checks.ErrorIsF(t, target, target, "msg") + checks.ErrorIsNot(t, errors.New("y"), target) + checks.ErrorIsNotf(t, errors.New("y"), target, "msg") +} diff --git a/internal/test/failer_test.go b/internal/test/failer_test.go new file mode 100644 index 000000000..fb1f4bf06 --- /dev/null +++ b/internal/test/failer_test.go @@ -0,0 +1,24 @@ +//nolint:testpackage // need access to unexported fields and types for testing +package test + +import ( + "errors" + "testing" +) + +func TestFailingErrorBuffer(t *testing.T) { + buf := &FailingErrorBuffer{} + n, err := buf.Write([]byte("test")) + if !errors.Is(err, ErrTestErrorAccumulatorWriteFailed) { + t.Fatalf("expected %v, got %v", ErrTestErrorAccumulatorWriteFailed, err) + } + if n != 0 { + t.Fatalf("expected n=0, got %d", n) + } + if buf.Len() != 0 { + t.Fatalf("expected Len 0, got %d", buf.Len()) + } + if len(buf.Bytes()) != 0 { + t.Fatalf("expected empty bytes") + } +} diff --git a/internal/test/helpers_test.go b/internal/test/helpers_test.go new file mode 100644 index 000000000..aa177679b --- /dev/null +++ b/internal/test/helpers_test.go @@ -0,0 +1,54 @@ +package test_test + +import ( + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + internaltest "github.com/sashabaranov/go-openai/internal/test" +) + +func TestCreateTestFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "file.txt") + internaltest.CreateTestFile(t, path) + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("failed to read created file: %v", err) + } + if string(data) != "hello" { + t.Fatalf("unexpected file contents: %q", string(data)) + } +} + +func TestTokenRoundTripperAddsHeader(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "Bearer "+internaltest.GetTestToken() { + t.Fatalf("authorization header not set") + } + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + client := srv.Client() + client.Transport = &internaltest.TokenRoundTripper{Token: internaltest.GetTestToken(), Fallback: client.Transport} + + req, err := http.NewRequest(http.MethodGet, srv.URL, nil) + if err != nil { + t.Fatalf("request error: %v", err) + } + resp, err := client.Do(req) + if err != nil { + t.Fatalf("client request error: %v", err) + } + if _, err = io.Copy(io.Discard, resp.Body); err != nil { + t.Fatalf("read body: %v", err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %d", resp.StatusCode) + } +} diff --git a/internal/test/server.go b/internal/test/server.go index 127d4c16f..d32c3e4cb 100644 --- a/internal/test/server.go +++ b/internal/test/server.go @@ -23,6 +23,18 @@ func NewTestServer() *ServerTest { return &ServerTest{handlers: make(map[string]handler)} } +// HandlerCount returns the number of registered handlers. +func (ts *ServerTest) HandlerCount() int { + return len(ts.handlers) +} + +// HasHandler checks if a handler was registered for the given path. +func (ts *ServerTest) HasHandler(path string) bool { + path = strings.ReplaceAll(path, "*", ".*") + _, ok := ts.handlers[path] + return ok +} + func (ts *ServerTest) RegisterHandler(path string, handler handler) { // to make the registered paths friendlier to a regex match in the route handler // in OpenAITestServer diff --git a/internal/test/server_test.go b/internal/test/server_test.go new file mode 100644 index 000000000..f8ce731d1 --- /dev/null +++ b/internal/test/server_test.go @@ -0,0 +1,80 @@ +package test_test + +import ( + "io" + "net/http" + "testing" + + internaltest "github.com/sashabaranov/go-openai/internal/test" +) + +func TestGetTestToken(t *testing.T) { + if internaltest.GetTestToken() != "this-is-my-secure-token-do-not-steal!!" { + t.Fatalf("unexpected token") + } +} + +func TestNewTestServer(t *testing.T) { + ts := internaltest.NewTestServer() + if ts == nil { + t.Fatalf("server not properly initialized") + } + if ts.HandlerCount() != 0 { + t.Fatalf("expected no handlers initially") + } +} + +func TestRegisterHandlerTransformsPath(t *testing.T) { + ts := internaltest.NewTestServer() + h := func(_ http.ResponseWriter, _ *http.Request) {} + ts.RegisterHandler("/foo/*", h) + if !ts.HasHandler("/foo/*") { + t.Fatalf("handler not registered with transformed path") + } +} + +func TestOpenAITestServer(t *testing.T) { + ts := internaltest.NewTestServer() + ts.RegisterHandler("/v1/test/*", func(w http.ResponseWriter, _ *http.Request) { + if _, err := io.WriteString(w, "ok"); err != nil { + t.Fatalf("write: %v", err) + } + }) + srv := ts.OpenAITestServer() + srv.Start() + defer srv.Close() + + base := srv.Client().Transport + client := &http.Client{Transport: &internaltest.TokenRoundTripper{Token: internaltest.GetTestToken(), Fallback: base}} + resp, err := client.Get(srv.URL + "/v1/test/123") + if err != nil { + t.Fatalf("request error: %v", err) + } + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + t.Fatalf("read response body: %v", err) + } + if resp.StatusCode != http.StatusOK || string(body) != "ok" { + t.Fatalf("unexpected response: %d %q", resp.StatusCode, string(body)) + } + + // unregistered path + resp, err = client.Get(srv.URL + "/unknown") + if err != nil { + t.Fatalf("request error: %v", err) + } + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("expected 404, got %d", resp.StatusCode) + } + + // missing token should return unauthorized + clientNoToken := srv.Client() + resp, err = clientNoToken.Get(srv.URL + "/v1/test/123") + if err != nil { + t.Fatalf("request error: %v", err) + } + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", resp.StatusCode) + } +}