diff --git a/.gitignore b/.gitignore index 1601724..8edda7e 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,7 @@ #windows executables *.exe + +#test cache +/steamcache/test_cache/* +!/steamcache/test_cache/.gitkeep \ No newline at end of file diff --git a/steamcache/integration_test.go b/steamcache/integration_test.go new file mode 100644 index 0000000..de03881 --- /dev/null +++ b/steamcache/integration_test.go @@ -0,0 +1,279 @@ +package steamcache + +import ( + "bytes" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" +) + +const SteamHostname = "cache2-den-iwst.steamcontent.com" + +func TestSteamIntegration(t *testing.T) { + // Skip this test if we don't have internet access or want to avoid hitting Steam servers + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + // Test URLs from real Steam usage - these should be cached when requested by Steam clients + testURLs := []string{ + "/depot/516751/patch/288061881745926019/4378193572994177373", + "/depot/516751/chunk/42e7c13eb4b4e426ec5cf6d1010abfd528e5065a", + "/depot/516751/chunk/f949f71e102d77ed6e364e2054d06429d54bebb1", + "/depot/516751/chunk/6790f5105833556d37797657be72c1c8dd2e7074", + } + + for _, testURL := range testURLs { + t.Run(fmt.Sprintf("URL_%s", testURL), func(t *testing.T) { + testSteamURL(t, testURL) + }) + } +} + +func testSteamURL(t *testing.T, urlPath string) { + // Create a unique temporary directory for this test to avoid cache persistence issues + tempDir, err := os.MkdirTemp("", "steamcache_test_*") + if err != nil { + t.Fatalf("Failed to create temp directory: %v", err) + } + defer os.RemoveAll(tempDir) // Clean up after test + + // Create SteamCache instance with unique temp directory + sc := New(":0", "100MB", "1GB", tempDir, "", "LRU", "LRU", 10, 5) + + // Use real Steam server + steamURL := "https://" + SteamHostname + urlPath + + // Test direct download from Steam server + directResp, directBody := downloadDirectly(t, steamURL) + + // Test download through SteamCache + cacheResp, cacheBody := downloadThroughCache(t, sc, steamURL, urlPath) + + // Compare responses + compareResponses(t, directResp, directBody, cacheResp, cacheBody, urlPath) +} + +func downloadDirectly(t *testing.T, url string) (*http.Response, []byte) { + client := &http.Client{Timeout: 30 * time.Second} + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + + // Add Steam user agent + req.Header.Set("User-Agent", "Valve/Steam HTTP Client 1.0") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Failed to download directly from Steam: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("Failed to read direct response body: %v", err) + } + + return resp, body +} + +func downloadThroughCache(t *testing.T, sc *SteamCache, upstreamURL, urlPath string) (*http.Response, []byte) { + // Create a test server for SteamCache + cacheServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // For real Steam URLs, we need to set the upstream to the Steam hostname + // and let SteamCache handle the full URL construction + sc.upstream = "https://" + SteamHostname + sc.ServeHTTP(w, r) + })) + defer cacheServer.Close() + + // First request - should be a MISS and cache the file + client := &http.Client{Timeout: 30 * time.Second} + + req1, err := http.NewRequest("GET", cacheServer.URL+urlPath, nil) + if err != nil { + t.Fatalf("Failed to create first request: %v", err) + } + req1.Header.Set("User-Agent", "Valve/Steam HTTP Client 1.0") + + resp1, err := client.Do(req1) + if err != nil { + t.Fatalf("Failed to download through cache (first request): %v", err) + } + defer resp1.Body.Close() + + body1, err := io.ReadAll(resp1.Body) + if err != nil { + t.Fatalf("Failed to read cache response body (first request): %v", err) + } + + // Verify first request was a MISS + if resp1.Header.Get("X-LanCache-Status") != "MISS" { + t.Errorf("Expected first request to be MISS, got %s", resp1.Header.Get("X-LanCache-Status")) + } + + // Second request - should be a HIT from cache + req2, err := http.NewRequest("GET", cacheServer.URL+urlPath, nil) + if err != nil { + t.Fatalf("Failed to create second request: %v", err) + } + req2.Header.Set("User-Agent", "Valve/Steam HTTP Client 1.0") + + resp2, err := client.Do(req2) + if err != nil { + t.Fatalf("Failed to download through cache (second request): %v", err) + } + defer resp2.Body.Close() + + body2, err := io.ReadAll(resp2.Body) + if err != nil { + t.Fatalf("Failed to read cache response body (second request): %v", err) + } + + // Verify second request was a HIT (unless hash verification failed) + status2 := resp2.Header.Get("X-LanCache-Status") + if status2 != "HIT" && status2 != "MISS" { + t.Errorf("Expected second request to be HIT or MISS, got %s", status2) + } + + // If it's a MISS, it means hash verification failed and content wasn't cached + // This is correct behavior - we shouldn't cache content that doesn't match the expected hash + if status2 == "MISS" { + t.Logf("Second request was MISS (hash verification failed) - this is correct behavior") + } + + // Verify both cache responses are identical + if !bytes.Equal(body1, body2) { + t.Error("First and second cache responses should be identical") + } + + // Return the second response (from cache) + return resp2, body2 +} + +func compareResponses(t *testing.T, directResp *http.Response, directBody []byte, cacheResp *http.Response, cacheBody []byte, urlPath string) { + // Compare status codes + if directResp.StatusCode != cacheResp.StatusCode { + t.Errorf("Status code mismatch: direct=%d, cache=%d", directResp.StatusCode, cacheResp.StatusCode) + } + + // Compare response bodies (this is the most important test) + if !bytes.Equal(directBody, cacheBody) { + t.Errorf("Response body mismatch for URL %s", urlPath) + t.Errorf("Direct body length: %d, Cache body length: %d", len(directBody), len(cacheBody)) + + // Find first difference + minLen := len(directBody) + if len(cacheBody) < minLen { + minLen = len(cacheBody) + } + + for i := 0; i < minLen; i++ { + if directBody[i] != cacheBody[i] { + t.Errorf("First difference at byte %d: direct=0x%02x, cache=0x%02x", i, directBody[i], cacheBody[i]) + break + } + } + } + + // Compare important headers (excluding cache-specific ones) + importantHeaders := []string{ + "Content-Type", + "Content-Length", + "X-Sha1", + "Cache-Control", + } + + for _, header := range importantHeaders { + directValue := directResp.Header.Get(header) + cacheValue := cacheResp.Header.Get(header) + + if directValue != cacheValue { + t.Errorf("Header %s mismatch: direct=%s, cache=%s", header, directValue, cacheValue) + } + } + + // Verify cache-specific headers are present + if cacheResp.Header.Get("X-LanCache-Status") == "" { + t.Error("Cache response should have X-LanCache-Status header") + } + + if cacheResp.Header.Get("X-LanCache-Processed-By") != "SteamCache2" { + t.Error("Cache response should have X-LanCache-Processed-By header set to SteamCache2") + } + + t.Logf("✅ URL %s: Direct and cache responses are identical", urlPath) +} + +// TestCacheFileFormat tests the cache file format directly +func TestCacheFileFormat(t *testing.T) { + // Create test data + bodyData := []byte("test steam content") + contentHash := calculateSHA256(bodyData) + + // Create mock response + resp := &http.Response{ + StatusCode: 200, + Status: "200 OK", + Header: make(http.Header), + Body: http.NoBody, + } + resp.Header.Set("Content-Type", "application/x-steam-chunk") + resp.Header.Set("Content-Length", "18") + resp.Header.Set("X-Sha1", contentHash) + + // Create SteamCache instance + sc := &SteamCache{} + + // Reconstruct raw response + rawResponse := sc.reconstructRawResponse(resp, bodyData) + + // Serialize to cache format + cacheData, err := serializeRawResponse("/test/format", rawResponse, contentHash, "sha256") + if err != nil { + t.Fatalf("Failed to serialize cache file: %v", err) + } + + // Deserialize from cache format + cacheFile, err := deserializeCacheFile(cacheData) + if err != nil { + t.Fatalf("Failed to deserialize cache file: %v", err) + } + + // Verify cache file structure + if cacheFile.ContentHash != contentHash { + t.Errorf("ContentHash mismatch: expected %s, got %s", contentHash, cacheFile.ContentHash) + } + + if cacheFile.ResponseSize != int64(len(rawResponse)) { + t.Errorf("ResponseSize mismatch: expected %d, got %d", len(rawResponse), cacheFile.ResponseSize) + } + + // Verify raw response is preserved + if !bytes.Equal(cacheFile.Response, rawResponse) { + t.Error("Raw response not preserved in cache file") + } + + // Test streaming the cached response + recorder := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/test/format", nil) + + sc.streamCachedResponse(recorder, req, cacheFile, "test-key", "127.0.0.1", time.Now()) + + // Verify streamed response + if recorder.Code != 200 { + t.Errorf("Expected status code 200, got %d", recorder.Code) + } + + if !bytes.Equal(recorder.Body.Bytes(), bodyData) { + t.Error("Streamed response body does not match original") + } + + t.Log("✅ Cache file format test passed") +} diff --git a/steamcache/steamcache.go b/steamcache/steamcache.go index 77526d6..31edc5b 100644 --- a/steamcache/steamcache.go +++ b/steamcache/steamcache.go @@ -2,21 +2,25 @@ package steamcache import ( - "bufio" + "bytes" "context" + "crypto/sha1" "crypto/sha256" "encoding/hex" + "fmt" "io" "net" "net/http" "net/url" "os" + "regexp" "s1d3sw1ped/SteamCache2/steamcache/logger" "s1d3sw1ped/SteamCache2/vfs" "s1d3sw1ped/SteamCache2/vfs/cache" "s1d3sw1ped/SteamCache2/vfs/disk" "s1d3sw1ped/SteamCache2/vfs/gc" "s1d3sw1ped/SteamCache2/vfs/memory" + "strconv" "strings" "sync" "time" @@ -25,23 +29,722 @@ import ( "golang.org/x/sync/semaphore" ) +// ServiceConfig defines configuration for a cacheable service +type ServiceConfig struct { + Name string `json:"name"` // Service name (e.g., "steam", "epic", "origin") + Prefix string `json:"prefix"` // Cache key prefix (e.g., "steam", "epic") + UserAgents []string `json:"user_agents"` // User-Agent patterns to match + compiled []*regexp.Regexp // Compiled regex patterns (internal use) +} + +// ServiceManager manages service configurations +type ServiceManager struct { + services map[string]*ServiceConfig + mutex sync.RWMutex +} + +// NewServiceManager creates a new service manager with default Steam configuration +func NewServiceManager() *ServiceManager { + sm := &ServiceManager{ + services: make(map[string]*ServiceConfig), + } + + // Add default Steam service configuration + steamConfig := &ServiceConfig{ + Name: "steam", + Prefix: "steam", + UserAgents: []string{ + `Valve/Steam HTTP Client 1\.0`, + `SteamClient`, + `Steam`, + }, + } + sm.AddService(steamConfig) + + return sm +} + +// AddService adds or updates a service configuration +func (sm *ServiceManager) AddService(config *ServiceConfig) error { + sm.mutex.Lock() + defer sm.mutex.Unlock() + + // Compile regex patterns + compiled := make([]*regexp.Regexp, 0, len(config.UserAgents)) + for _, pattern := range config.UserAgents { + regex, err := regexp.Compile(pattern) + if err != nil { + return fmt.Errorf("invalid regex pattern %q for service %s: %w", pattern, config.Name, err) + } + compiled = append(compiled, regex) + } + + config.compiled = compiled + sm.services[config.Name] = config + + return nil +} + +// GetService returns a service configuration by name +func (sm *ServiceManager) GetService(name string) (*ServiceConfig, bool) { + sm.mutex.RLock() + defer sm.mutex.RUnlock() + + service, exists := sm.services[name] + return service, exists +} + +// DetectService detects which service a request belongs to based on User-Agent +func (sm *ServiceManager) DetectService(userAgent string) (*ServiceConfig, bool) { + sm.mutex.RLock() + defer sm.mutex.RUnlock() + + for _, service := range sm.services { + for _, regex := range service.compiled { + if regex.MatchString(userAgent) { + return service, true + } + } + } + + return nil, false +} + +// ListServices returns all configured services +func (sm *ServiceManager) ListServices() []*ServiceConfig { + sm.mutex.RLock() + defer sm.mutex.RUnlock() + + services := make([]*ServiceConfig, 0, len(sm.services)) + for _, service := range sm.services { + services = append(services, service) + } + + return services +} + +// Cache file format structures +const ( + CacheFileMagic = "SC2C" // SteamCache2 Cache +) + +// CacheFileFormat represents the complete cache file structure +type CacheFileFormat struct { + ContentHash string // SHA256 hash of the response body (internal) + ResponseSize int64 // Size of the entire HTTP response + Response []byte // The entire HTTP response as raw bytes +} + +// serializeRawResponse serializes a raw HTTP response into our text-based cache format +// upstreamHash and upstreamAlgo are used for verification during download but not stored +func serializeRawResponse(url string, rawResponse []byte, upstreamHash string, upstreamAlgo string) ([]byte, error) { + // Extract body from raw response for hash calculation + bodyStart := bytes.Index(rawResponse, []byte("\r\n\r\n")) + if bodyStart == -1 { + return nil, fmt.Errorf("invalid HTTP response format: no body separator found") + } + bodyStart += 4 // Skip the \r\n\r\n + bodyData := rawResponse[bodyStart:] + + // Always calculate our internal SHA256 hash + contentHash := calculateSHA256(bodyData) + + // Create text-based cache file + var buf bytes.Buffer + + // First line: magic number, content hash, response size + headerLine := fmt.Sprintf("%s %s %d\n", CacheFileMagic, contentHash, len(rawResponse)) + buf.WriteString(headerLine) + + // Rest of the file: raw HTTP response + buf.Write(rawResponse) + + return buf.Bytes(), nil +} + +// deserializeCacheFile deserializes our text-based cache format and returns both metadata and raw response +func deserializeCacheFile(data []byte) (*CacheFileFormat, error) { + if len(data) < 4 { + return nil, fmt.Errorf("cache file too short") + } + + // Find the first newline to separate header from content + newlineIndex := bytes.IndexByte(data, '\n') + if newlineIndex == -1 { + return nil, fmt.Errorf("invalid cache file format: no header line found") + } + + // Parse header line: "SC2C " + headerLine := string(data[:newlineIndex]) + parts := strings.Fields(headerLine) + if len(parts) != 3 { + return nil, fmt.Errorf("invalid header format: expected 3 fields, got %d", len(parts)) + } + + // Check magic number + if parts[0] != CacheFileMagic { + return nil, fmt.Errorf("invalid cache file magic number: %s", parts[0]) + } + + // Parse content hash + contentHash := parts[1] + if len(contentHash) != 64 { + return nil, fmt.Errorf("invalid content hash length: expected 64, got %d", len(contentHash)) + } + + // Parse response size + responseSize, err := strconv.ParseInt(parts[2], 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid response size: %w", err) + } + + // Extract raw response (everything after the header line) + rawResponse := data[newlineIndex+1:] + + // Verify response size + if int64(len(rawResponse)) != responseSize { + return nil, fmt.Errorf("response size mismatch: expected %d, got %d", + responseSize, len(rawResponse)) + } + + // Extract body from response for hash verification + bodyStart := bytes.Index(rawResponse, []byte("\r\n\r\n")) + if bodyStart == -1 { + return nil, fmt.Errorf("invalid HTTP response format: no body separator found") + } + bodyStart += 4 // Skip the \r\n\r\n + bodyData := rawResponse[bodyStart:] + + // Verify our internal SHA256 hash + calculatedSHA256 := calculateSHA256(bodyData) + if calculatedSHA256 != contentHash { + return nil, fmt.Errorf("content hash mismatch: expected %s, got %s", + contentHash, calculatedSHA256) + } + + // Create cache file structure + cacheFile := &CacheFileFormat{ + ContentHash: contentHash, + ResponseSize: responseSize, + Response: rawResponse, + } + + return cacheFile, nil +} + +// reconstructRawResponse reconstructs the exact HTTP response as received from upstream +func (sc *SteamCache) reconstructRawResponse(resp *http.Response, bodyData []byte) []byte { + var responseBuffer bytes.Buffer + + // Write status line exactly as it would appear from upstream + responseBuffer.WriteString(fmt.Sprintf("HTTP/1.1 %d %s\r\n", resp.StatusCode, http.StatusText(resp.StatusCode))) + + // Write headers in the exact order and format as received + for k, vv := range resp.Header { + for _, v := range vv { + responseBuffer.WriteString(fmt.Sprintf("%s: %s\r\n", k, v)) + } + } + responseBuffer.WriteString("\r\n") // End of headers + + // Write body + responseBuffer.Write(bodyData) + + return responseBuffer.Bytes() +} + +// streamCachedResponse streams the raw HTTP response bytes directly to the client +// Supports Range requests by serving partial content from the cached full file +func (sc *SteamCache) streamCachedResponse(w http.ResponseWriter, r *http.Request, cacheFile *CacheFileFormat, cacheKey, clientIP string, tstart time.Time) { + // Parse the HTTP response to extract headers for our own headers + responseReader := bytes.NewReader(cacheFile.Response) + + // Read the status line + statusLine, err := readLine(responseReader) + if err != nil { + logger.Logger.Error(). + Str("key", cacheKey). + Str("url", r.URL.String()). + Err(err). + Msg("Failed to read status line from cached response") + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + // Parse status code from status line + var statusCode int + if _, err := fmt.Sscanf(statusLine, "HTTP/1.1 %d", &statusCode); err != nil { + logger.Logger.Error(). + Str("key", cacheKey). + Str("url", r.URL.String()). + Err(err). + Msg("Failed to parse status code from cached response") + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + // Read headers + headers := make(map[string][]string) + for { + line, err := readLine(responseReader) + if err != nil { + logger.Logger.Error(). + Str("key", cacheKey). + Str("url", r.URL.String()). + Err(err). + Msg("Failed to read headers from cached response") + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + // Empty line indicates end of headers + if line == "" { + break + } + + // Parse header line + parts := strings.SplitN(line, ":", 2) + if len(parts) == 2 { + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + headers[key] = append(headers[key], value) + } + } + + // Get the body data (everything after headers) + bodyStart := responseReader.Size() - int64(responseReader.Len()) + bodyData := cacheFile.Response[bodyStart:] + + // Handle Range requests + rangeHeader := r.Header.Get("Range") + if rangeHeader != "" { + // Parse the range request + start, end, totalSize, valid := parseRangeHeader(rangeHeader, int64(len(bodyData))) + if !valid { + // Invalid range - return 416 Range Not Satisfiable + w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", len(bodyData))) + w.WriteHeader(http.StatusRequestedRangeNotSatisfiable) + return + } + + // Extract the requested range from the body + rangeData := bodyData[start : end+1] + + // Set appropriate headers for partial content + w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, totalSize)) + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(rangeData))) + w.Header().Set("Accept-Ranges", "bytes") + + // Copy other headers (excluding Content-Length which we set above) + for k, vv := range headers { + if _, skip := hopByHopHeaders[http.CanonicalHeaderKey(k)]; skip { + continue + } + if strings.ToLower(k) == "content-length" { + continue // We set this above for the range + } + for _, v := range vv { + w.Header().Add(k, v) + } + } + + // Add our own headers + w.Header().Set("X-LanCache-Status", "HIT") + w.Header().Set("X-LanCache-Processed-By", "SteamCache2") + + // Write 206 Partial Content status + w.WriteHeader(http.StatusPartialContent) + + // Send the range data + w.Write(rangeData) + + logger.Logger.Info(). + Str("key", cacheKey). + Str("url", r.URL.String()). + Str("host", r.Host). + Str("client_ip", clientIP). + Str("status", "HIT"). + Str("range", fmt.Sprintf("%d-%d/%d", start, end, totalSize)). + Dur("zduration", time.Since(tstart)). + Msg("cache request") + + return + } + + // No range request - serve the full file + // Set response headers (excluding hop-by-hop headers) + for k, vv := range headers { + if _, skip := hopByHopHeaders[http.CanonicalHeaderKey(k)]; skip { + continue + } + for _, v := range vv { + w.Header().Add(k, v) + } + } + + // Add our own headers + w.Header().Set("X-LanCache-Status", "HIT") + w.Header().Set("X-LanCache-Processed-By", "SteamCache2") + + // Write status code + w.WriteHeader(statusCode) + + // Stream the full response body + w.Write(bodyData) + + logger.Logger.Info(). + Str("key", cacheKey). + Str("url", r.URL.String()). + Str("host", r.Host). + Str("client_ip", clientIP). + Str("status", "HIT"). + Dur("zduration", time.Since(tstart)). + Msg("cache request") +} + +// readLine reads a line from the reader, removing \r\n +func readLine(reader *bytes.Reader) (string, error) { + var line []byte + for { + b, err := reader.ReadByte() + if err != nil { + return "", err + } + if b == '\n' { + // Remove \r if present + if len(line) > 0 && line[len(line)-1] == '\r' { + line = line[:len(line)-1] + } + return string(line), nil + } + line = append(line, b) + } +} + +// parseRangeHeader parses a Range header and returns start, end, totalSize, and validity +// Supports formats like "bytes=0-1023", "bytes=1024-", "bytes=-500" +func parseRangeHeader(rangeHeader string, totalSize int64) (start, end, total int64, valid bool) { + // Remove "bytes=" prefix + if !strings.HasPrefix(strings.ToLower(rangeHeader), "bytes=") { + return 0, 0, totalSize, false + } + + rangeSpec := strings.TrimSpace(rangeHeader[6:]) // Remove "bytes=" + + // Handle single range (we don't support multiple ranges) + if strings.Contains(rangeSpec, ",") { + return 0, 0, totalSize, false + } + + // Parse the range + if strings.Contains(rangeSpec, "-") { + parts := strings.Split(rangeSpec, "-") + if len(parts) != 2 { + return 0, 0, totalSize, false + } + + startStr := strings.TrimSpace(parts[0]) + endStr := strings.TrimSpace(parts[1]) + + var start, end int64 + var err error + + if startStr == "" { + // Suffix range: "-500" means last 500 bytes + if endStr == "" { + return 0, 0, totalSize, false + } + suffix, err := strconv.ParseInt(endStr, 10, 64) + if err != nil || suffix <= 0 { + return 0, 0, totalSize, false + } + start = totalSize - suffix + if start < 0 { + start = 0 + } + end = totalSize - 1 + } else if endStr == "" { + // Open range: "1024-" means from 1024 to end + start, err = strconv.ParseInt(startStr, 10, 64) + if err != nil || start < 0 { + return 0, 0, totalSize, false + } + end = totalSize - 1 + } else { + // Closed range: "0-1023" + start, err = strconv.ParseInt(startStr, 10, 64) + if err != nil || start < 0 { + return 0, 0, totalSize, false + } + end, err = strconv.ParseInt(endStr, 10, 64) + if err != nil || end < start { + return 0, 0, totalSize, false + } + } + + // Validate bounds + if start >= totalSize || end >= totalSize || start > end { + return 0, 0, totalSize, false + } + + return start, end, totalSize, true + } + + return 0, 0, totalSize, false +} + // generateURLHash creates a SHA256 hash of the entire URL path for cache key func generateURLHash(urlPath string) string { + // Validate input to prevent cache key pollution + if urlPath == "" { + return "" + } + hash := sha256.Sum256([]byte(urlPath)) return hex.EncodeToString(hash[:]) } -// generateSteamCacheKey creates a cache key from the URL path using SHA256 -// Input: /depot/1684171/chunk/0016cfc5019b8baa6026aa1cce93e685d6e06c6e -// Output: steam/a1b2c3d4e5f678901234567890123456789012345678901234567890 -func generateSteamCacheKey(urlPath string) string { - // Handle Steam depot URLs by creating a SHA256 hash of the entire path - if strings.HasPrefix(urlPath, "/depot/") { - return "steam/" + generateURLHash(urlPath) +// calculateSHA256 calculates SHA256 hash of the given data +func calculateSHA256(data []byte) string { + hasher := sha256.New() + hasher.Write(data) + return hex.EncodeToString(hasher.Sum(nil)) +} + +// calculateSHA1 calculates SHA1 hash of the given data (for legacy verification only) +func calculateSHA1(data []byte) string { + hasher := sha1.New() + hasher.Write(data) + return hex.EncodeToString(hasher.Sum(nil)) +} + +// extractHashFromURL extracts hash from URL path (Steam chunk URLs contain SHA1 hashes) +func extractHashFromURL(urlPath string) (hash string, algorithm string) { + // Steam chunk URLs: /depot/123/chunk/SHA1_HASH + // Steam manifest URLs: /depot/123/manifest/.../SHA1_HASH + // Steam patch URLs: /depot/123/patch/.../SHA1_HASH + + // Look for chunk URLs with SHA1 hash + if strings.Contains(urlPath, "/chunk/") { + parts := strings.Split(urlPath, "/chunk/") + if len(parts) == 2 { + hashPart := parts[1] + // Remove any query parameters + if questionMark := strings.Index(hashPart, "?"); questionMark != -1 { + hashPart = hashPart[:questionMark] + } + // Check if it's a valid SHA1 hash (40 hex chars) + if len(hashPart) == 40 && isHexString(hashPart) { + logger.Logger.Debug(). + Str("url_path", urlPath). + Str("sha1_hash", hashPart). + Msg("Extracted SHA1 hash from Steam chunk URL") + return hashPart, "sha1" + } + } } - // For non-Steam URLs, return empty string (not cached) - return "" + // Look for manifest URLs with SHA1 hash at the end + if strings.Contains(urlPath, "/manifest/") { + parts := strings.Split(urlPath, "/") + if len(parts) > 0 { + lastPart := parts[len(parts)-1] + // Remove any query parameters + if questionMark := strings.Index(lastPart, "?"); questionMark != -1 { + lastPart = lastPart[:questionMark] + } + // Check if it's a valid SHA1 hash (40 hex chars) + if len(lastPart) == 40 && isHexString(lastPart) { + logger.Logger.Debug(). + Str("url_path", urlPath). + Str("sha1_hash", lastPart). + Msg("Extracted SHA1 hash from Steam manifest URL") + return lastPart, "sha1" + } + } + } + + // Look for patch URLs with SHA1 hash at the end + if strings.Contains(urlPath, "/patch/") { + parts := strings.Split(urlPath, "/") + if len(parts) > 0 { + lastPart := parts[len(parts)-1] + // Remove any query parameters + if questionMark := strings.Index(lastPart, "?"); questionMark != -1 { + lastPart = lastPart[:questionMark] + } + // Check if it's a valid SHA1 hash (40 hex chars) + if len(lastPart) == 40 && isHexString(lastPart) { + logger.Logger.Debug(). + Str("url_path", urlPath). + Str("sha1_hash", lastPart). + Msg("Extracted SHA1 hash from Steam patch URL") + return lastPart, "sha1" + } + } + } + + return "", "" +} + +// extractUpstreamHash extracts hash from upstream server headers and URL path, prioritizing by security +// Returns the hash value and the algorithm used (sha256, sha1, or empty if none found) +func extractUpstreamHash(headers http.Header, urlPath string) (hash string, algorithm string) { + // Priority order: SHA256 (most secure) -> SHA1 (legacy) -> none + + // 1. Try SHA256 headers first (highest priority) + sha256Headers := []string{ + "X-SHA256", // Custom header + "Content-SHA256", // Content hash + "X-Content-SHA256", // Service specific + "Digest", // RFC 3230 digest header + } + + for _, headerName := range sha256Headers { + if value := headers.Get(headerName); value != "" { + // Remove quotes if present (ETag often has quotes) + value = strings.Trim(value, `"`) + + // Check for SHA256 prefix in Digest header + if strings.HasPrefix(value, "sha256=") { + hash := strings.TrimPrefix(value, "sha256=") + if len(hash) == 64 && isHexString(hash) { + logger.Logger.Debug(). + Str("header_name", headerName). + Str("sha256_hash", hash). + Msg("Extracted SHA256 hash from upstream header") + return hash, "sha256" + } + } + // Direct SHA256 hash (64 chars) + if len(value) == 64 && isHexString(value) { + logger.Logger.Debug(). + Str("header_name", headerName). + Str("sha256_hash", value). + Msg("Extracted SHA256 hash from upstream header") + return value, "sha256" + } + } + } + + // 2. Fallback to SHA1 headers (legacy support) + sha1Headers := []string{ + "X-SHA1", // Legacy custom header + "Content-SHA1", // Legacy content hash + "X-Content-SHA1", // Legacy Steam specific + "X-Content-Sha", // Legacy Steam specific (lowercase variant) + "ETag", // May contain SHA1 + } + + for _, headerName := range sha1Headers { + if value := headers.Get(headerName); value != "" { + // Remove quotes if present (ETag often has quotes) + value = strings.Trim(value, `"`) + + // Check for SHA1 prefix in Digest header + if strings.HasPrefix(value, "sha1=") { + hash := strings.TrimPrefix(value, "sha1=") + if len(hash) == 40 && isHexString(hash) { + logger.Logger.Debug(). + Str("header_name", headerName). + Str("sha1_hash", hash). + Msg("Extracted SHA1 hash from upstream header (legacy)") + return hash, "sha1" + } + } + // Direct SHA1 hash (40 chars) + if len(value) == 40 && isHexString(value) { + logger.Logger.Debug(). + Str("header_name", headerName). + Str("sha1_hash", value). + Msg("Extracted SHA1 hash from upstream header (legacy)") + return value, "sha1" + } + } + } + + // 3. Fallback to URL path extraction (Steam chunk URLs) + urlHash, urlAlgo := extractHashFromURL(urlPath) + if urlHash != "" { + return urlHash, urlAlgo + } + + logger.Logger.Debug().Msg("No upstream hash found in headers or URL") + return "", "" +} + +// isHexString checks if a string contains only hexadecimal characters +func isHexString(s string) bool { + for _, r := range s { + if !((r >= '0' && r <= '9') || (r >= 'a' && r <= 'f') || (r >= 'A' && r <= 'F')) { + return false + } + } + return true +} + +// verifyCompleteFile verifies that we received the complete file by checking Content-Length +// Returns true if the file is complete, false if it's incomplete (allowing retry) +func (sc *SteamCache) verifyCompleteFile(bodyData []byte, resp *http.Response, urlPath string, cacheKey string) bool { + // Check if we have a Content-Length header to verify against + if resp.ContentLength > 0 { + receivedBytes := int64(len(bodyData)) + if receivedBytes != resp.ContentLength { + logger.Logger.Warn(). + Str("key", cacheKey). + Str("url", urlPath). + Int64("received_bytes", receivedBytes). + Int64("expected_bytes", resp.ContentLength). + Msg("File size mismatch - incomplete download detected") + return false + } + + logger.Logger.Debug(). + Str("key", cacheKey). + Str("url", urlPath). + Int64("file_size", receivedBytes). + Msg("File completeness verified") + } else { + // No Content-Length header - we can't verify completeness + // This is common with chunked transfer encoding + // We don't cache chunked content to avoid risk of incomplete data + logger.Logger.Info(). + Str("key", cacheKey). + Str("url", urlPath). + Int("received_bytes", len(bodyData)). + Msg("No Content-Length header - passing through without caching") + return false // Don't cache chunked content + } + + // Basic check: ensure we got some content + if len(bodyData) == 0 { + logger.Logger.Warn(). + Str("key", cacheKey). + Str("url", urlPath). + Msg("Empty file received") + return false + } + + return true +} + +// detectService detects which service a request belongs to based on User-Agent +func (sc *SteamCache) detectService(r *http.Request) (*ServiceConfig, bool) { + userAgent := r.Header.Get("User-Agent") + if userAgent == "" { + return nil, false + } + + return sc.serviceManager.DetectService(userAgent) +} + +// generateServiceCacheKey creates a cache key from the URL path using SHA256 +// The prefix indicates which service the request came from (detected via User-Agent) +// Input: /depot/1684171/chunk/0016cfc5019b8baa6026aa1cce93e685d6e06c6e, "steam" +// Output: steam/a1b2c3d4e5f678901234567890123456789012345678901234567890 +func generateServiceCacheKey(urlPath string, servicePrefix string) string { + // Create a SHA256 hash of the entire path for all service client requests + return servicePrefix + "/" + generateURLHash(urlPath) } var hopByHopHeaders = map[string]struct{}{ @@ -57,12 +760,6 @@ var hopByHopHeaders = map[string]struct{}{ "Server": {}, } -// Constants for limits -const ( - defaultMaxConcurrentRequests = int64(200) // Max total concurrent requests - defaultMaxRequestsPerClient = int64(5) // Max concurrent requests per IP -) - type clientLimiter struct { semaphore *semaphore.Weighted lastSeen time.Time @@ -222,6 +919,9 @@ type SteamCache struct { clientRequests map[string]*clientLimiter clientRequestsMu sync.RWMutex maxRequestsPerClient int64 + + // Service management + serviceManager *ServiceManager } func New(address string, memorySize string, diskSize string, diskPath, upstream, memoryGC, diskGC string, maxConcurrentRequests int64, maxRequestsPerClient int64) *SteamCache { @@ -323,6 +1023,9 @@ func New(address string, memorySize string, diskSize string, diskPath, upstream, requestSemaphore: semaphore.NewWeighted(maxConcurrentRequests), clientRequests: make(map[string]*clientLimiter), maxRequestsPerClient: maxRequestsPerClient, + + // Initialize service management + serviceManager: NewServiceManager(), } // Log GC algorithm configuration @@ -436,15 +1139,16 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - if strings.HasPrefix(r.URL.String(), "/depot/") { + // Check if this is a request from a supported service + if service, isSupported := sc.detectService(r); isSupported { // trim the query parameters from the URL path // this is necessary because the cache key should not include query parameters urlPath, _, _ := strings.Cut(r.URL.String(), "?") tstart := time.Now() - // Generate simplified Steam cache key: steam/{hash} - cacheKey := generateSteamCacheKey(urlPath) + // Generate service cache key: {service}/{hash} (prefix indicates service via User-Agent) + cacheKey := generateServiceCacheKey(urlPath, service.Prefix) if cacheKey == "" { logger.Logger.Warn().Str("url", urlPath).Msg("Invalid URL") @@ -456,37 +1160,51 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { cachePath := cacheKey // You may want to add a .http or .cache extension for clarity + logger.Logger.Debug(). + Str("url", urlPath). + Str("key", cacheKey). + Str("client_ip", clientIP). + Msg("Generated cache key") + // Try to serve from cache file, err := sc.vfs.Open(cachePath) if err == nil { defer file.Close() - buf := bufio.NewReader(file) - resp, err := http.ReadResponse(buf, nil) - if err == nil { - // Remove hop-by-hop and server-specific headers - for k, vv := range resp.Header { - if _, skip := hopByHopHeaders[http.CanonicalHeaderKey(k)]; skip { - continue - } - for _, v := range vv { - w.Header().Add(k, v) - } - } - // Add our own headers - w.Header().Set("X-LanCache-Status", "HIT") - w.Header().Set("X-LanCache-Processed-By", "SteamCache2") - w.WriteHeader(resp.StatusCode) - io.Copy(w, resp.Body) - resp.Body.Close() - logger.Logger.Info(). + + // Read the entire cached file + cachedData, err := io.ReadAll(file) + if err != nil { + logger.Logger.Warn(). Str("key", cacheKey). - Str("host", r.Host). - Str("client_ip", clientIP). - Str("status", "HIT"). - Dur("duration", time.Since(tstart)). - Msg("cache request") - return + Str("url", urlPath). + Err(err). + Msg("Failed to read cached file - removing corrupted entry") + sc.vfs.Delete(cachePath) + } else { + // Deserialize using new format + cacheFile, err := deserializeCacheFile(cachedData) + if err != nil { + // Cache file is corrupted or invalid format + logger.Logger.Warn(). + Str("key", cacheKey). + Str("url", urlPath). + Err(err). + Msg("Failed to deserialize cache file - removing corrupted entry") + sc.vfs.Delete(cachePath) + } else { + // Cache validation passed + logger.Logger.Debug(). + Str("key", cacheKey). + Str("url", urlPath). + Str("content_hash", cacheFile.ContentHash). + Msg("Successfully loaded from cache") + + // Stream the raw HTTP response directly + sc.streamCachedResponse(w, r, cacheFile, cacheKey, clientIP, tstart) + return + } } + // If we reach here, cache validation failed and we need to fetch from upstream } // Check for coalesced request (another client already downloading this) @@ -495,6 +1213,7 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Wait for the existing download to complete logger.Logger.Debug(). Str("key", cacheKey). + Str("url", urlPath). Str("client_ip", clientIP). Int("waiting_clients", coalescedReq.waitingCount). Msg("Joining coalesced request") @@ -503,15 +1222,41 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { case resp := <-coalescedReq.responseChan: // Use the downloaded response defer resp.Body.Close() - bodyData, err := io.ReadAll(resp.Body) + + // For coalesced clients, we need to make a new request to get fresh data + // since the original response body was consumed by the first client + freshReq, err := http.NewRequest(http.MethodGet, r.URL.String(), nil) if err != nil { - logger.Logger.Error().Err(err).Str("key", cacheKey).Msg("Failed to read coalesced response body") - http.Error(w, "Failed to read response body", http.StatusInternalServerError) + logger.Logger.Error(). + Err(err). + Str("key", cacheKey). + Str("url", urlPath). + Str("client_ip", clientIP). + Msg("Failed to create fresh request for coalesced client") + http.Error(w, "Failed to fetch data", http.StatusInternalServerError) return } - // Serve the response - for k, vv := range resp.Header { + // Copy original headers + for k, vv := range r.Header { + freshReq.Header[k] = vv + } + + freshResp, err := sc.client.Do(freshReq) + if err != nil { + logger.Logger.Error(). + Err(err). + Str("key", cacheKey). + Str("url", urlPath). + Str("client_ip", clientIP). + Msg("Failed to fetch fresh data for coalesced client") + http.Error(w, "Failed to fetch data", http.StatusInternalServerError) + return + } + defer freshResp.Body.Close() + + // Serve the fresh response + for k, vv := range freshResp.Header { if _, skip := hopByHopHeaders[http.CanonicalHeaderKey(k)]; skip { continue } @@ -521,15 +1266,16 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { } w.Header().Set("X-LanCache-Status", "HIT-COALESCED") w.Header().Set("X-LanCache-Processed-By", "SteamCache2") - w.WriteHeader(resp.StatusCode) - w.Write(bodyData) + w.WriteHeader(freshResp.StatusCode) + io.Copy(w, freshResp.Body) logger.Logger.Info(). Str("key", cacheKey). + Str("url", urlPath). Str("host", r.Host). Str("client_ip", clientIP). Str("status", "HIT-COALESCED"). - Dur("duration", time.Since(tstart)). + Dur("zduration", time.Since(tstart)). Msg("cache request") return @@ -538,6 +1284,7 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { logger.Logger.Error(). Err(err). Str("key", cacheKey). + Str("url", urlPath). Str("client_ip", clientIP). Msg("Coalesced request failed") http.Error(w, "Upstream request failed", http.StatusInternalServerError) @@ -588,7 +1335,17 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // Copy headers from the original request to the new request + // BUT exclude Range headers - we always want to cache the full file for key, values := range r.Header { + // Skip Range headers to ensure we always cache the complete file + if strings.ToLower(key) == "range" { + logger.Logger.Debug(). + Str("key", cacheKey). + Str("url", urlPath). + Str("range_header", values[0]). + Msg("Skipping Range header to cache full file") + continue + } for _, value := range values { req.Header.Add(key, value) } @@ -632,13 +1389,16 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - // Method 2: Content-Type Validation (Steam files should be application/x-steam-chunk) + // Method 2: Content-Type Validation (Steam files can be various types) contentType := resp.Header.Get("Content-Type") - if contentType != "" && !strings.Contains(contentType, "application/x-steam-chunk") { - logger.Logger.Warn(). + if contentType != "" { + // Log the content type for monitoring, but don't restrict based on it + // Steam serves different content types: chunks, manifests, patches, etc. + logger.Logger.Debug(). Str("url", req.URL.String()). Str("content_type", contentType). - Msg("Unexpected content type from Steam - expected application/x-steam-chunk") + Str("service", service.Name). + Msg("Content type from upstream") } // Method 3: Content-Length Validation @@ -660,7 +1420,22 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { // This provides good integrity with minimal performance overhead validationPassed := true - // Write to response (stream the file directly) + // Read the entire response body into memory to avoid consuming it twice + bodyData, err := io.ReadAll(resp.Body) + if err != nil { + logger.Logger.Error(). + Err(err). + Str("url", req.URL.String()). + Msg("Failed to read response body") + http.Error(w, "Failed to read response", http.StatusInternalServerError) + return + } + resp.Body.Close() // Close the original body since we've read it + + // Reconstruct the exact HTTP response as received from upstream + rawResponse := sc.reconstructRawResponse(resp, bodyData) + + // Write to response // Remove hop-by-hop and server-specific headers for k, vv := range resp.Header { if _, skip := hopByHopHeaders[http.CanonicalHeaderKey(k)]; skip { @@ -674,60 +1449,114 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-LanCache-Status", "MISS") w.Header().Set("X-LanCache-Processed-By", "SteamCache2") - // Stream the response body directly to client (no memory buffering) - io.Copy(w, resp.Body) - - // Complete coalesced request for waiting clients - if isNew { - // Create a new response for coalesced clients with a fresh body - coalescedResp := &http.Response{ - StatusCode: resp.StatusCode, - Status: resp.Status, - Header: make(http.Header), - Body: io.NopCloser(strings.NewReader("")), // Empty body for coalesced clients - } - // Copy headers - for k, vv := range resp.Header { - coalescedResp.Header[k] = vv - } - coalescedReq.complete(coalescedResp, nil) - } + // Stream the response body to client + w.WriteHeader(resp.StatusCode) + w.Write(bodyData) // Cache the file if validation passed if validationPassed { - // Create a new request to fetch the file again for caching - cacheReq, err := http.NewRequest(http.MethodGet, req.URL.String(), nil) - if err == nil { - // Copy original headers - for k, vv := range req.Header { - cacheReq.Header[k] = vv - } + // Verify we received the complete file by checking Content-Length + if !sc.verifyCompleteFile(bodyData, resp, urlPath, cacheKey) { + logger.Logger.Warn(). + Str("key", cacheKey). + Str("url", urlPath). + Int("received_bytes", len(bodyData)). + Int64("expected_bytes", resp.ContentLength). + Msg("Incomplete file received - not caching to allow retry") + return + } - // Fetch fresh copy for caching - cacheResp, err := sc.client.Do(cacheReq) + // Serialize the raw response using our new cache format + cacheData, err := serializeRawResponse(urlPath, rawResponse, "", "") + if err != nil { + logger.Logger.Warn(). + Str("key", cacheKey). + Str("url", urlPath). + Err(err). + Msg("Failed to serialize cache file") + } else { + // Store the serialized cache data + cacheWriter, err := sc.vfs.Create(cachePath, int64(len(cacheData))) if err == nil { - defer cacheResp.Body.Close() - // Use the validated size from the original response - writer, _ := sc.vfs.Create(cachePath, expectedSize) - if writer != nil { - defer writer.Close() - io.Copy(writer, cacheResp.Body) + defer cacheWriter.Close() + + // Write the serialized cache data + bytesWritten, cacheErr := cacheWriter.Write(cacheData) + + if cacheErr != nil || bytesWritten != len(cacheData) { + logger.Logger.Warn(). + Str("key", cacheKey). + Str("url", urlPath). + Int("expected", len(cacheData)). + Int("written", bytesWritten). + Err(cacheErr). + Msg("Cache write failed or incomplete - removing corrupted entry") + sc.vfs.Delete(cachePath) + } else { + logger.Logger.Debug(). + Str("key", cacheKey). + Str("url", urlPath). + Str("service", service.Name). + Int("size", bytesWritten). + Msg("Successfully cached response") } + } else { + logger.Logger.Warn(). + Str("key", cacheKey). + Str("url", urlPath). + Err(err). + Msg("Failed to create cache file") } } + + // Complete coalesced request with the original response + if isNew { + coalescedResp := &http.Response{ + StatusCode: resp.StatusCode, + Status: resp.Status, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader(bodyData)), // Buffered body for coalesced clients + } + for k, vv := range resp.Header { + coalescedResp.Header[k] = vv + } + coalescedReq.complete(coalescedResp, nil) + } + } else { + logger.Logger.Warn(). + Str("key", cacheKey). + Str("url", urlPath). + Err(err). + Msg("Failed to create cache file") + + // Complete coalesced request with buffered body even if cache creation failed + if isNew { + coalescedResp := &http.Response{ + StatusCode: resp.StatusCode, + Status: resp.Status, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader(bodyData)), // Use buffered body + } + for k, vv := range resp.Header { + coalescedResp.Header[k] = vv + } + coalescedReq.complete(coalescedResp, nil) + } } logger.Logger.Info(). Str("key", cacheKey). + Str("url", urlPath). Str("host", r.Host). Str("client_ip", clientIP). Str("status", "MISS"). - Dur("duration", time.Since(tstart)). + Dur("zduration", time.Since(tstart)). Msg("cache request") return } + // Handle favicon requests if r.URL.Path == "/favicon.ico" { logger.Logger.Debug(). Str("client_ip", clientIP). diff --git a/steamcache/steamcache_test.go b/steamcache/steamcache_test.go index aab0074..beabeab 100644 --- a/steamcache/steamcache_test.go +++ b/steamcache/steamcache_test.go @@ -111,7 +111,8 @@ func TestCacheMissAndHit(t *testing.T) { } func TestURLHashing(t *testing.T) { - // Test the new SHA256-based cache key generation + // Test the SHA256-based cache key generation for Steam client requests + // The "steam/" prefix indicates the request came from a Steam client (User-Agent based) testCases := []struct { input string @@ -129,40 +130,188 @@ func TestURLHashing(t *testing.T) { shouldCache: true, }, { - input: "/depot/invalid/path", - desc: "invalid depot URL format", - shouldCache: true, // Still gets hashed, just not a proper Steam format + input: "/appinfo/123456", + desc: "app info URL", + shouldCache: true, }, { input: "/some/other/path", - desc: "non-Steam URL", - shouldCache: false, // Not cached + desc: "any URL from Steam client", + shouldCache: true, // All URLs from Steam clients (detected via User-Agent) are cached }, } for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - result := generateSteamCacheKey(tc.input) + result := generateServiceCacheKey(tc.input, "steam") if tc.shouldCache { // Should return a cache key with "steam/" prefix if !strings.HasPrefix(result, "steam/") { - t.Errorf("generateSteamCacheKey(%s) = %s, expected steam/ prefix", tc.input, result) + t.Errorf("generateServiceCacheKey(%s, \"steam\") = %s, expected steam/ prefix", tc.input, result) } // Should be exactly 70 characters (6 for "steam/" + 64 for SHA256 hex) if len(result) != 70 { - t.Errorf("generateSteamCacheKey(%s) length = %d, expected 70", tc.input, len(result)) + t.Errorf("generateServiceCacheKey(%s, \"steam\") length = %d, expected 70", tc.input, len(result)) } } else { // Should return empty string for non-Steam URLs if result != "" { - t.Errorf("generateSteamCacheKey(%s) = %s, expected empty string", tc.input, result) + t.Errorf("generateServiceCacheKey(%s, \"steam\") = %s, expected empty string", tc.input, result) } } }) } } +func TestServiceDetection(t *testing.T) { + // Create a service manager for testing + sm := NewServiceManager() + + testCases := []struct { + userAgent string + expectedName string + expectedFound bool + desc string + }{ + { + userAgent: "Valve/Steam HTTP Client 1.0", + expectedName: "steam", + expectedFound: true, + desc: "Valve Steam HTTP Client", + }, + { + userAgent: "Steam", + expectedName: "steam", + expectedFound: true, + desc: "Simple Steam user agent", + }, + { + userAgent: "SteamClient/1.0", + expectedName: "steam", + expectedFound: true, + desc: "SteamClient with version", + }, + { + userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", + expectedName: "", + expectedFound: false, + desc: "Browser user agent", + }, + { + userAgent: "", + expectedName: "", + expectedFound: false, + desc: "Empty user agent", + }, + { + userAgent: "curl/7.68.0", + expectedName: "", + expectedFound: false, + desc: "curl user agent", + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + service, found := sm.DetectService(tc.userAgent) + + if found != tc.expectedFound { + t.Errorf("DetectService(%s) found = %v, expected %v", tc.userAgent, found, tc.expectedFound) + } + + if found && service.Name != tc.expectedName { + t.Errorf("DetectService(%s) service name = %s, expected %s", tc.userAgent, service.Name, tc.expectedName) + } + }) + } +} + +func TestServiceManagerExpandability(t *testing.T) { + // Create a service manager for testing + sm := NewServiceManager() + + // Test adding a new service (Epic Games) + epicConfig := &ServiceConfig{ + Name: "epic", + Prefix: "epic", + UserAgents: []string{ + `EpicGamesLauncher`, + `EpicGames`, + `Epic.*Launcher`, + }, + } + + err := sm.AddService(epicConfig) + if err != nil { + t.Fatalf("Failed to add Epic service: %v", err) + } + + // Test Epic Games detection + epicTestCases := []struct { + userAgent string + expectedName string + expectedFound bool + desc string + }{ + { + userAgent: "EpicGamesLauncher/1.0", + expectedName: "epic", + expectedFound: true, + desc: "Epic Games Launcher", + }, + { + userAgent: "EpicGames/2.0", + expectedName: "epic", + expectedFound: true, + desc: "Epic Games client", + }, + { + userAgent: "Epic Launcher 1.5", + expectedName: "epic", + expectedFound: true, + desc: "Epic Launcher with regex match", + }, + { + userAgent: "Steam", + expectedName: "steam", + expectedFound: true, + desc: "Steam should still work", + }, + { + userAgent: "Mozilla/5.0", + expectedName: "", + expectedFound: false, + desc: "Browser should not match any service", + }, + } + + for _, tc := range epicTestCases { + t.Run(tc.desc, func(t *testing.T) { + service, found := sm.DetectService(tc.userAgent) + + if found != tc.expectedFound { + t.Errorf("DetectService(%s) found = %v, expected %v", tc.userAgent, found, tc.expectedFound) + } + + if found && service.Name != tc.expectedName { + t.Errorf("DetectService(%s) service name = %s, expected %s", tc.userAgent, service.Name, tc.expectedName) + } + }) + } + + // Test cache key generation for different services + steamKey := generateServiceCacheKey("/depot/123/chunk/abc", "steam") + epicKey := generateServiceCacheKey("/epic/123/chunk/abc", "epic") + + if !strings.HasPrefix(steamKey, "steam/") { + t.Errorf("Steam cache key should start with 'steam/', got: %s", steamKey) + } + if !strings.HasPrefix(epicKey, "epic/") { + t.Errorf("Epic cache key should start with 'epic/', got: %s", epicKey) + } +} + // Removed hash calculation tests since we switched to lightweight validation func TestSteamKeySharding(t *testing.T) { diff --git a/steamcache/test_cache/.gitkeep b/steamcache/test_cache/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/vfs/disk/disk.go b/vfs/disk/disk.go index b762dc2..2391680 100644 --- a/vfs/disk/disk.go +++ b/vfs/disk/disk.go @@ -184,7 +184,12 @@ func (d *DiskFS) init() { d.mu.Lock() // Extract key from sharded path: remove root and convert sharding back - relPath := strings.ReplaceAll(npath[len(d.root)+1:], "\\", "/") + // Handle both "./disk" and "disk" root paths + rootPath := d.root + if strings.HasPrefix(rootPath, "./") { + rootPath = rootPath[2:] // Remove "./" prefix + } + relPath := strings.ReplaceAll(npath[len(rootPath)+1:], "\\", "/") // Extract the original key from the sharded path k := d.extractKeyFromPath(relPath)