package api import ( "encoding/json" "net/http" "net/http/httptest" "testing" ) func TestCheckWebSocketOrigin_DevelopmentAllowsOrigin(t *testing.T) { t.Setenv("PRODUCTION", "false") req := httptest.NewRequest("GET", "http://localhost/ws", nil) req.Host = "localhost:8080" req.Header.Set("Origin", "http://example.com") if !checkWebSocketOrigin(req) { t.Fatal("expected development mode to allow origin") } } func TestCheckWebSocketOrigin_ProductionSameHostAllowed(t *testing.T) { t.Setenv("PRODUCTION", "true") t.Setenv("ALLOWED_ORIGINS", "") req := httptest.NewRequest("GET", "http://localhost/ws", nil) req.Host = "localhost:8080" req.Header.Set("Origin", "http://localhost:8080") if !checkWebSocketOrigin(req) { t.Fatal("expected same-host origin to be allowed") } } func TestRespondErrorWithCode_IncludesCodeField(t *testing.T) { s := &Manager{} rr := httptest.NewRecorder() s.respondErrorWithCode(rr, http.StatusBadRequest, "UPLOAD_SESSION_EXPIRED", "Upload session expired.") if rr.Code != http.StatusBadRequest { t.Fatalf("status = %d, want %d", rr.Code, http.StatusBadRequest) } var payload map[string]string if err := json.Unmarshal(rr.Body.Bytes(), &payload); err != nil { t.Fatalf("failed to decode response: %v", err) } if payload["code"] != "UPLOAD_SESSION_EXPIRED" { t.Fatalf("unexpected code: %q", payload["code"]) } if payload["error"] == "" { t.Fatal("expected non-empty error message") } }