179 lines
5.6 KiB
Go
179 lines
5.6 KiB
Go
package tests
|
|
|
|
import (
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/gofiber/fiber/v3"
|
|
|
|
"optoant/config"
|
|
"optoant/handlers"
|
|
)
|
|
|
|
// mockUpstream launches an httptest.Server whose handler answers every
// request with the given status code, a "Content-Type: application/json"
// header, and body written verbatim. The server is shut down automatically
// through t.Cleanup when the calling test finishes.
func mockUpstream(t *testing.T, status int, body string) *httptest.Server {
	t.Helper()

	reply := func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)
		// A failed write is not interesting here: the client side of the
		// test observes any truncated response on its own.
		_, _ = io.WriteString(w, body)
	}

	server := httptest.NewServer(http.HandlerFunc(reply))
	t.Cleanup(server.Close)
	return server
}
|
|
|
|
// newTestApp creates a Fiber app wired to an OpenAI handler pointing at mockURL.
|
|
func newTestApp(cfg *config.Config) *fiber.App {
|
|
app := fiber.New()
|
|
app.All("/v1/*", handlers.OpenAIHandler(cfg, nil))
|
|
return app
|
|
}
|
|
|
|
// TestOpenAIProxy_Success verifies that a valid upstream response is forwarded correctly.
|
|
func TestOpenAIProxy_Success(t *testing.T) {
|
|
upstream := mockUpstream(t, http.StatusOK, `{"choices":[{"message":{"role":"assistant","content":"Hello!"}}]}`)
|
|
|
|
cfg := &config.Config{
|
|
OpenAIBackend: upstream.URL,
|
|
RequestTimeoutSeconds: 5,
|
|
}
|
|
app := newTestApp(cfg)
|
|
|
|
payload := `{"model":"gpt-4","messages":[{"role":"user","content":"Hi"}]}`
|
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(payload))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer test-key")
|
|
|
|
resp, err := app.Test(req, fiber.TestConfig{Timeout: -1})
|
|
if err != nil {
|
|
t.Fatalf("app.Test error: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("expected 200, got %d", resp.StatusCode)
|
|
}
|
|
|
|
var result map[string]interface{}
|
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
t.Fatalf("decode response: %v", err)
|
|
}
|
|
choices, ok := result["choices"].([]interface{})
|
|
if !ok || len(choices) == 0 {
|
|
t.Errorf("expected choices in response, got: %v", result)
|
|
}
|
|
}
|
|
|
|
// TestOpenAIProxy_DefaultModelInjection verifies that OPENAI_MODEL is injected
|
|
// when the request body has no model field.
|
|
func TestOpenAIProxy_DefaultModelInjection(t *testing.T) {
|
|
var capturedBody string
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
body, _ := io.ReadAll(r.Body)
|
|
capturedBody = string(body)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = io.WriteString(w, `{"choices":[{"message":{"role":"assistant","content":"ok"}}]}`)
|
|
}))
|
|
t.Cleanup(upstream.Close)
|
|
|
|
cfg := &config.Config{
|
|
OpenAIBackend: upstream.URL,
|
|
RequestTimeoutSeconds: 5,
|
|
OpenAIModel: "deepseek/deepseek-v4-pro",
|
|
}
|
|
app := newTestApp(cfg)
|
|
|
|
// No model in payload — should be injected
|
|
payload := `{"messages":[{"role":"user","content":"Hi"}]}`
|
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(payload))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := app.Test(req, fiber.TestConfig{Timeout: -1})
|
|
if err != nil {
|
|
t.Fatalf("app.Test error: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
|
}
|
|
|
|
var result map[string]interface{}
|
|
json.Unmarshal([]byte(capturedBody), &result)
|
|
if result["model"] != cfg.OpenAIModel {
|
|
t.Errorf("expected model=%q, got %q", cfg.OpenAIModel, result["model"])
|
|
}
|
|
}
|
|
|
|
// TestOpenAIProxy_DefaultModelInjection_ExistingModel verifies that an existing
|
|
// model in the request body is NOT overridden by OPENAI_MODEL.
|
|
func TestOpenAIProxy_DefaultModelInjection_ExistingModel(t *testing.T) {
|
|
var capturedBody string
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
body, _ := io.ReadAll(r.Body)
|
|
capturedBody = string(body)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = io.WriteString(w, `{"choices":[{"message":{"role":"assistant","content":"ok"}}]}`)
|
|
}))
|
|
t.Cleanup(upstream.Close)
|
|
|
|
cfg := &config.Config{
|
|
OpenAIBackend: upstream.URL,
|
|
RequestTimeoutSeconds: 5,
|
|
OpenAIModel: "deepseek/deepseek-v4-pro",
|
|
}
|
|
app := newTestApp(cfg)
|
|
|
|
// Model already set — should NOT be overridden
|
|
payload := `{"model":"gpt-4","messages":[{"role":"user","content":"Hi"}]}`
|
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(payload))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := app.Test(req, fiber.TestConfig{Timeout: -1})
|
|
if err != nil {
|
|
t.Fatalf("app.Test error: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d", resp.StatusCode)
|
|
}
|
|
|
|
var result map[string]interface{}
|
|
json.Unmarshal([]byte(capturedBody), &result)
|
|
if result["model"] != "gpt-4" {
|
|
t.Errorf("expected model=gpt-4 (preserved), got %q", result["model"])
|
|
}
|
|
}
|
|
|
|
// TestOpenAIProxy_UpstreamError verifies that a 502 is returned when upstream fails.
|
|
func TestOpenAIProxy_UpstreamError(t *testing.T) {
|
|
// Point to an address that should refuse connection
|
|
cfg := &config.Config{
|
|
OpenAIBackend: "http://127.0.0.1:19999", // nothing listening
|
|
RequestTimeoutSeconds: 2,
|
|
}
|
|
app := newTestApp(cfg)
|
|
|
|
payload := `{"model":"gpt-4","messages":[{"role":"user","content":"Hi"}]}`
|
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(payload))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := app.Test(req, fiber.TestConfig{Timeout: -1})
|
|
if err != nil {
|
|
t.Fatalf("app.Test error: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusBadGateway {
|
|
t.Errorf("expected 502, got %d", resp.StatusCode)
|
|
}
|
|
}
|