// steamcache/steamcache.go package steamcache import ( "context" "crypto/sha1" "encoding/hex" "fmt" "io" "net" "net/http" "net/url" "os" "path/filepath" "regexp" "s1d3sw1ped/SteamCache2/steamcache/logger" "s1d3sw1ped/SteamCache2/vfs" "s1d3sw1ped/SteamCache2/vfs/cache" "s1d3sw1ped/SteamCache2/vfs/disk" "s1d3sw1ped/SteamCache2/vfs/gc" "s1d3sw1ped/SteamCache2/vfs/memory" "sort" "strings" "sync" "time" "github.com/docker/go-units" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promhttp" ) // min returns the minimum of two integers func min(a, b int) int { if a < b { return a } return b } var ( requestsTotal = promauto.NewCounterVec( prometheus.CounterOpts{ Name: "http_requests_total", Help: "Total number of HTTP requests", }, []string{"method", "status"}, ) cacheStatusTotal = promauto.NewCounterVec( prometheus.CounterOpts{ Name: "cache_status_total", Help: "Total cache status counts", }, []string{"status"}, ) responseTime = promauto.NewHistogramVec( prometheus.HistogramOpts{ Name: "response_time_seconds", Help: "Response time in seconds", Buckets: prometheus.DefBuckets, }, []string{"cache_status"}, ) ) // hashVerificationTotal tracks hash verification attempts var hashVerificationTotal = promauto.NewCounterVec( prometheus.CounterOpts{ Name: "hash_verification_total", Help: "Total hash verification attempts", }, []string{"result"}, ) // extractHashFromFilename extracts a hash from a filename if present // Steam depot files often have hashes in their names like: filename_hash.ext func extractHashFromFilename(filename string) (string, bool) { // Common patterns for Steam depot files with hashes patterns := []*regexp.Regexp{ regexp.MustCompile(`^([a-fA-F0-9]{40})$`), // Standalone SHA1 hash (40 hex chars) regexp.MustCompile(`^([a-fA-F0-9]{40})\.`), // SHA1 hash with extension } for _, pattern := range patterns { if matches := pattern.FindStringSubmatch(filename); len(matches) > 1 { return strings.ToLower(matches[1]), true } } // Debug: log when we don't find a hash pattern if strings.Contains(filename, "manifest") { logger.Logger.Debug(). Str("filename", filename). Msg("No hash pattern found in manifest filename") } return "", false } // calculateFileHash calculates the SHA1 hash of the given data func calculateFileHash(data []byte) string { hash := sha1.Sum(data) return hex.EncodeToString(hash[:]) } // calculateResponseHash calculates the SHA1 hash of the full HTTP response func calculateResponseHash(resp *http.Response, bodyData []byte) string { hash := sha1.New() // Include status line statusLine := fmt.Sprintf("HTTP/1.1 %d %s\n", resp.StatusCode, resp.Status) hash.Write([]byte(statusLine)) // Include headers (sorted for consistency) headers := make([]string, 0, len(resp.Header)) for key, values := range resp.Header { for _, value := range values { headers = append(headers, fmt.Sprintf("%s: %s\n", key, value)) } } sort.Strings(headers) for _, header := range headers { hash.Write([]byte(header)) } // Include empty line between headers and body hash.Write([]byte("\n")) // Include body hash.Write(bodyData) return hex.EncodeToString(hash.Sum(nil)) } // verifyFileHash verifies that the file content matches the expected hash func verifyFileHash(data []byte, expectedHash string) bool { actualHash := calculateFileHash(data) return strings.EqualFold(actualHash, expectedHash) } // verifyResponseHash verifies that the full HTTP response matches the expected hash func verifyResponseHash(resp *http.Response, bodyData []byte, expectedHash string) bool { actualHash := calculateResponseHash(resp, bodyData) return strings.EqualFold(actualHash, expectedHash) } type SteamCache struct { address string upstream string vfs vfs.VFS memory *memory.MemoryFS disk *disk.DiskFS memorygc *gc.GCFS diskgc *gc.GCFS server *http.Server client *http.Client cancel context.CancelFunc wg sync.WaitGroup } func New(address string, memorySize string, diskSize string, diskPath, upstream, memoryGC, diskGC string) *SteamCache { memorysize, err := units.FromHumanSize(memorySize) if err != nil { panic(err) } disksize, err := units.FromHumanSize(diskSize) if err != nil { panic(err) } c := cache.New( gc.AdaptivePromotionDeciderFunc, ) var m *memory.MemoryFS var mgc *gc.GCFS if memorysize > 0 { m = memory.New(memorysize) memoryGCAlgo := gc.GCAlgorithm(memoryGC) if memoryGCAlgo == "" { memoryGCAlgo = gc.LRU // default to LRU } mgc = gc.New(m, gc.GetGCAlgorithm(memoryGCAlgo)) } var d *disk.DiskFS var dgc *gc.GCFS if disksize > 0 { d = disk.New(diskPath, disksize) diskGCAlgo := gc.GCAlgorithm(diskGC) if diskGCAlgo == "" { diskGCAlgo = gc.LRU // default to LRU } dgc = gc.New(d, gc.GetGCAlgorithm(diskGCAlgo)) } // configure the cache to match the specified mode (memory only, disk only, or memory and disk) based on the provided sizes if disksize == 0 && memorysize != 0 { //memory only mode - no disk c.SetSlow(mgc) } else if disksize != 0 && memorysize == 0 { // disk only mode c.SetSlow(dgc) } else if disksize != 0 && memorysize != 0 { // memory and disk mode c.SetFast(mgc) c.SetSlow(dgc) } else { // no memory or disk isn't a valid configuration logger.Logger.Error().Bool("memory", false).Bool("disk", false).Msg("configuration invalid :( exiting") os.Exit(1) } transport := &http.Transport{ MaxIdleConns: 200, // Increased from 100 MaxIdleConnsPerHost: 50, // Increased from 10 IdleConnTimeout: 120 * time.Second, // Increased from 90s DialContext: (&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, }).DialContext, TLSHandshakeTimeout: 15 * time.Second, // Increased from 10s ResponseHeaderTimeout: 30 * time.Second, // Increased from 10s ExpectContinueTimeout: 5 * time.Second, // Increased from 1s DisableCompression: true, // Steam doesn't use compression ForceAttemptHTTP2: true, // Enable HTTP/2 if available } client := &http.Client{ Transport: transport, Timeout: 120 * time.Second, // Increased from 60s } sc := &SteamCache{ upstream: upstream, address: address, vfs: c, memory: m, disk: d, memorygc: mgc, diskgc: dgc, client: client, server: &http.Server{ Addr: address, ReadTimeout: 30 * time.Second, // Increased WriteTimeout: 60 * time.Second, // Increased IdleTimeout: 120 * time.Second, // Good for keep-alive ReadHeaderTimeout: 10 * time.Second, // New, for header attacks MaxHeaderBytes: 1 << 20, // 1MB, optional }, } // Log GC algorithm configuration if m != nil { logger.Logger.Info().Str("memory_gc", memoryGC).Msg("Memory cache GC algorithm configured") } if d != nil { logger.Logger.Info().Str("disk_gc", diskGC).Msg("Disk cache GC algorithm configured") } if d != nil { if d.Size() > d.Capacity() { gcHandler := gc.GetGCAlgorithm(gc.GCAlgorithm(diskGC)) gcHandler(d, uint(d.Size()-d.Capacity())) } } return sc } func (sc *SteamCache) Run() { if sc.upstream != "" { resp, err := sc.client.Get(sc.upstream) if err != nil || resp.StatusCode != http.StatusOK { logger.Logger.Error().Err(err).Int("status_code", resp.StatusCode).Str("upstream", sc.upstream).Msg("Failed to connect to upstream server") os.Exit(1) } resp.Body.Close() } sc.server.Handler = sc ctx, cancel := context.WithCancel(context.Background()) sc.cancel = cancel sc.wg.Add(1) go func() { defer sc.wg.Done() err := sc.server.ListenAndServe() if err != nil && err != http.ErrServerClosed { logger.Logger.Error().Err(err).Msg("Failed to start SteamCache2") os.Exit(1) } }() <-ctx.Done() sc.server.Shutdown(ctx) sc.wg.Wait() } func (sc *SteamCache) Shutdown() { if sc.cancel != nil { sc.cancel() } sc.wg.Wait() } func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { requestsTotal.WithLabelValues(r.Method, "405").Inc() logger.Logger.Warn().Str("method", r.Method).Msg("Only GET method is supported") http.Error(w, "Only GET method is supported", http.StatusMethodNotAllowed) return } if r.URL.Path == "/" { w.WriteHeader(http.StatusOK) // this is used by steamcache2's upstream verification at startup return } if r.URL.String() == "/lancache-heartbeat" { w.Header().Add("X-LanCache-Processed-By", "SteamCache2") w.WriteHeader(http.StatusNoContent) w.Write(nil) return } if r.URL.Path == "/metrics" { promhttp.Handler().ServeHTTP(w, r) return } if strings.HasPrefix(r.URL.String(), "/depot/") { // trim the query parameters from the URL path // this is necessary because the cache key should not include query parameters path := strings.Split(r.URL.String(), "?")[0] tstart := time.Now() cacheKey := strings.ReplaceAll(path[1:], "\\", "/") // replace all backslashes with forward slashes shouldn't be necessary but just in case if cacheKey == "" { requestsTotal.WithLabelValues(r.Method, "400").Inc() logger.Logger.Warn().Str("url", path).Msg("Invalid URL") http.Error(w, "Invalid URL", http.StatusBadRequest) return } w.Header().Add("X-LanCache-Processed-By", "SteamCache2") // SteamPrefill uses this header to determine if the request was processed by the cache maybe steam uses it too reader, err := sc.vfs.Open(cacheKey) if err == nil { defer reader.Close() w.Header().Add("X-LanCache-Status", "HIT") io.Copy(w, reader) logger.Logger.Info(). Str("key", cacheKey). Str("host", r.Host). Str("status", "HIT"). Dur("duration", time.Since(tstart)). Msg("request") requestsTotal.WithLabelValues(r.Method, "200").Inc() cacheStatusTotal.WithLabelValues("HIT").Inc() responseTime.WithLabelValues("HIT").Observe(time.Since(tstart).Seconds()) return } var req *http.Request if sc.upstream != "" { // if an upstream server is configured, proxy the request to the upstream server ur, err := url.JoinPath(sc.upstream, path) if err != nil { requestsTotal.WithLabelValues(r.Method, "500").Inc() logger.Logger.Error().Err(err).Str("upstream", sc.upstream).Msg("Failed to join URL path") http.Error(w, "Failed to join URL path", http.StatusInternalServerError) return } req, err = http.NewRequest(http.MethodGet, ur, nil) if err != nil { requestsTotal.WithLabelValues(r.Method, "500").Inc() logger.Logger.Error().Err(err).Str("upstream", sc.upstream).Msg("Failed to create request") http.Error(w, "Failed to create request", http.StatusInternalServerError) return } req.Host = r.Host } else { // if no upstream server is configured, proxy the request to the host specified in the request host := r.Host if r.Header.Get("X-Sls-Https") == "enable" { host = "https://" + host } else { host = "http://" + host } ur, err := url.JoinPath(host, path) if err != nil { requestsTotal.WithLabelValues(r.Method, "500").Inc() logger.Logger.Error().Err(err).Str("host", host).Msg("Failed to join URL path") http.Error(w, "Failed to join URL path", http.StatusInternalServerError) return } req, err = http.NewRequest(http.MethodGet, ur, nil) if err != nil { requestsTotal.WithLabelValues(r.Method, "500").Inc() logger.Logger.Error().Err(err).Str("host", host).Msg("Failed to create request") http.Error(w, "Failed to create request", http.StatusInternalServerError) return } } // Copy headers from the original request to the new request for key, values := range r.Header { for _, value := range values { req.Header.Add(key, value) } } // Retry logic backoffSchedule := []time.Duration{1 * time.Second, 3 * time.Second, 10 * time.Second} var resp *http.Response for i, backoff := range backoffSchedule { resp, err = sc.client.Do(req) if err == nil && resp.StatusCode == http.StatusOK { break } if i < len(backoffSchedule)-1 { time.Sleep(backoff) } } if err != nil || resp.StatusCode != http.StatusOK { requestsTotal.WithLabelValues(r.Method, "500 upstream host "+r.Host).Inc() logger.Logger.Error().Err(err).Str("url", req.URL.String()).Msg("Failed to fetch the requested URL") http.Error(w, "Failed to fetch the requested URL", http.StatusInternalServerError) return } defer resp.Body.Close() size := resp.ContentLength // Read the entire response body into memory for hash verification bodyData, err := io.ReadAll(resp.Body) if err != nil { requestsTotal.WithLabelValues(r.Method, "500").Inc() logger.Logger.Error().Err(err).Str("url", req.URL.String()).Msg("Failed to read response body") http.Error(w, "Failed to read response body", http.StatusInternalServerError) return } // Extract filename from cache key for hash verification filename := filepath.Base(cacheKey) expectedHash, hasHash := extractHashFromFilename(filename) // Hash verification using Steam's X-Content-Sha header and content length verification hashVerified := true if hasHash { // Get the hash from Steam's X-Content-Sha header steamHash := resp.Header.Get("X-Content-Sha") // Verify using Steam's hash if strings.EqualFold(steamHash, expectedHash) { hashVerificationTotal.WithLabelValues("success").Inc() } else { hashVerificationTotal.WithLabelValues("failed").Inc() logger.Logger.Error(). Str("key", cacheKey). Str("expected_hash", expectedHash). Str("steam_hash", steamHash). Int("content_length", len(bodyData)). Msg("Steam hash verification failed - Steam's hash doesn't match filename") hashVerified = false } } else { hashVerificationTotal.WithLabelValues("no_hash").Inc() } // Always verify content length as an additional safety check if resp.ContentLength > 0 && int64(len(bodyData)) != resp.ContentLength { hashVerificationTotal.WithLabelValues("content_length_failed").Inc() logger.Logger.Error(). Str("key", cacheKey). Int("actual_content_length", len(bodyData)). Int64("expected_content_length", resp.ContentLength). Msg("Content length verification failed") hashVerified = false } else if resp.ContentLength > 0 { hashVerificationTotal.WithLabelValues("content_length_success").Inc() } // Write to response (always serve the file) w.Header().Add("X-LanCache-Status", "MISS") w.Write(bodyData) // Only cache the file if hash verification passed (or no hash was present) if hashVerified { writer, _ := sc.vfs.Create(cacheKey, size) if writer != nil { defer writer.Close() writer.Write(bodyData) } } else { logger.Logger.Warn(). Str("key", cacheKey). Msg("File served but not cached due to hash verification failure") } logger.Logger.Info(). Str("key", cacheKey). Str("host", r.Host). Str("status", "MISS"). Dur("duration", time.Since(tstart)). Msg("request") requestsTotal.WithLabelValues(r.Method, "200").Inc() cacheStatusTotal.WithLabelValues("MISS").Inc() responseTime.WithLabelValues("MISS").Observe(time.Since(tstart).Seconds()) return } if r.URL.Path == "/favicon.ico" { w.WriteHeader(http.StatusNoContent) return } if r.URL.Path == "/robots.txt" { w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) w.Write([]byte("User-agent: *\nDisallow: /\n")) return } requestsTotal.WithLabelValues(r.Method, "404").Inc() logger.Logger.Warn().Str("url", r.URL.String()).Msg("Not found") http.Error(w, "Not found", http.StatusNotFound) }