Files
steamcache2/steamcache/steamcache.go
Justin Harms 694c223b00 Add integration tests and service management for SteamCache
- Introduced integration tests for SteamCache to validate caching behavior with real Steam URLs.
- Implemented a ServiceManager to manage service configurations, allowing for dynamic detection of services based on User-Agent.
- Updated cache key generation to include service prefixes, enhancing cache organization and retrieval.
- Enhanced the caching logic to support multiple services, starting with Steam and Epic Games.
- Improved .gitignore to exclude test cache files while retaining necessary structure.
2025-09-21 20:07:18 -05:00

1584 lines
45 KiB
Go

// steamcache/steamcache.go
package steamcache
import (
"bytes"
"context"
"crypto/sha1"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"regexp"
"s1d3sw1ped/SteamCache2/steamcache/logger"
"s1d3sw1ped/SteamCache2/vfs"
"s1d3sw1ped/SteamCache2/vfs/cache"
"s1d3sw1ped/SteamCache2/vfs/disk"
"s1d3sw1ped/SteamCache2/vfs/gc"
"s1d3sw1ped/SteamCache2/vfs/memory"
"strconv"
"strings"
"sync"
"time"
"github.com/docker/go-units"
"golang.org/x/sync/semaphore"
)
// ServiceConfig defines configuration for a cacheable service.
//
// A request is attributed to a service when its User-Agent matches any of the
// UserAgents patterns. Patterns use Go regexp syntax and match anywhere in the
// User-Agent unless anchored. Prefix becomes the first path component of that
// request's cache key ("<Prefix>/<hash>").
type ServiceConfig struct {
	Name       string           `json:"name"`        // Service name (e.g., "steam", "epic", "origin")
	Prefix     string           `json:"prefix"`      // Cache key prefix (e.g., "steam", "epic")
	UserAgents []string         `json:"user_agents"` // User-Agent patterns to match
	compiled   []*regexp.Regexp // Compiled UserAgents, filled in by ServiceManager.AddService (internal use)
}
// ServiceManager manages service configurations: it stores the registered
// ServiceConfigs and matches User-Agents against them. Every method takes
// mutex before touching services.
type ServiceManager struct {
	services map[string]*ServiceConfig // keyed by ServiceConfig.Name
	mutex    sync.RWMutex              // guards services
}
// NewServiceManager returns a ServiceManager pre-populated with the default
// "steam" service. That service matches any User-Agent containing
// "Valve/Steam HTTP Client 1.0", "SteamClient" or "Steam"; the broad "Steam"
// pattern already covers the other two.
func NewServiceManager() *ServiceManager {
	sm := &ServiceManager{
		services: make(map[string]*ServiceConfig),
	}

	steamConfig := &ServiceConfig{
		Name:   "steam",
		Prefix: "steam",
		UserAgents: []string{
			`Valve/Steam HTTP Client 1\.0`,
			`SteamClient`,
			`Steam`,
		},
	}
	// The default patterns are constants, so a compile failure is a
	// programming error. Surface it instead of silently starting with no
	// services configured.
	if err := sm.AddService(steamConfig); err != nil {
		panic(err)
	}
	return sm
}
// AddService registers config under config.Name, replacing any service
// already registered with that name.
//
// Every UserAgents entry is compiled as a Go regular expression. If any entry
// fails to compile, an error is returned and the manager is left unchanged. A
// nil config is rejected with an error.
//
// config is stored by reference, and its compiled patterns are set in place,
// so callers should not mutate it after registration.
func (sm *ServiceManager) AddService(config *ServiceConfig) error {
	if config == nil {
		return fmt.Errorf("service config must not be nil")
	}

	// Compile before taking the write lock, so that pattern compilation does
	// not block concurrent DetectService calls on the request path.
	compiled := make([]*regexp.Regexp, 0, len(config.UserAgents))
	for _, pattern := range config.UserAgents {
		regex, err := regexp.Compile(pattern)
		if err != nil {
			return fmt.Errorf("invalid regex pattern %q for service %s: %w", pattern, config.Name, err)
		}
		compiled = append(compiled, regex)
	}

	sm.mutex.Lock()
	defer sm.mutex.Unlock()
	// Assign under the lock: DetectService reads compiled under RLock.
	config.compiled = compiled
	sm.services[config.Name] = config
	return nil
}
// GetService looks up a registered service by name. The boolean result
// reports whether a service with that name exists.
func (sm *ServiceManager) GetService(name string) (*ServiceConfig, bool) {
	sm.mutex.RLock()
	cfg, found := sm.services[name]
	sm.mutex.RUnlock()
	return cfg, found
}
// DetectService returns the registered service whose User-Agent patterns
// match userAgent, and whether any matched.
//
// When several services match, the one with the lexicographically smallest
// registered name wins. This keeps detection, and therefore the cache-key
// prefix, independent of Go's randomised map iteration order.
func (sm *ServiceManager) DetectService(userAgent string) (*ServiceConfig, bool) {
	sm.mutex.RLock()
	defer sm.mutex.RUnlock()

	var (
		best     *ServiceConfig
		bestName string
	)
	for name, service := range sm.services {
		if best != nil && name >= bestName {
			continue // cannot beat the current winner
		}
		for _, regex := range service.compiled {
			if regex.MatchString(userAgent) {
				best, bestName = service, name
				break
			}
		}
	}
	return best, best != nil
}
// ListServices returns every registered service, in no particular order.
// The slice is freshly allocated; the *ServiceConfig values are shared with
// the manager.
func (sm *ServiceManager) ListServices() []*ServiceConfig {
	sm.mutex.RLock()
	defer sm.mutex.RUnlock()

	out := make([]*ServiceConfig, len(sm.services))
	i := 0
	for _, cfg := range sm.services {
		out[i] = cfg
		i++
	}
	return out
}
// Cache file format structures
const (
	// CacheFileMagic is the first whitespace-separated field of every cache
	// file's header line. serializeRawResponse writes it and
	// deserializeCacheFile rejects files that do not start with it.
	CacheFileMagic = "SC2C" // SteamCache2 Cache
)
// CacheFileFormat represents the complete cache file structure, as decoded by
// deserializeCacheFile.
//
// On disk, a cache file is one text header line,
// "SC2C <sha256-hex-of-body> <len(Response)>\n", followed by Response
// verbatim.
type CacheFileFormat struct {
	ContentHash  string // SHA256 hash of the response body (internal), lowercase hex
	ResponseSize int64  // Size of the entire HTTP response
	Response     []byte // The entire HTTP response as raw bytes
}
// serializeRawResponse serializes a raw HTTP response into our text-based cache format
// upstreamHash and upstreamAlgo are used for verification during download but not stored
func serializeRawResponse(url string, rawResponse []byte, upstreamHash string, upstreamAlgo string) ([]byte, error) {
// Extract body from raw response for hash calculation
bodyStart := bytes.Index(rawResponse, []byte("\r\n\r\n"))
if bodyStart == -1 {
return nil, fmt.Errorf("invalid HTTP response format: no body separator found")
}
bodyStart += 4 // Skip the \r\n\r\n
bodyData := rawResponse[bodyStart:]
// Always calculate our internal SHA256 hash
contentHash := calculateSHA256(bodyData)
// Create text-based cache file
var buf bytes.Buffer
// First line: magic number, content hash, response size
headerLine := fmt.Sprintf("%s %s %d\n", CacheFileMagic, contentHash, len(rawResponse))
buf.WriteString(headerLine)
// Rest of the file: raw HTTP response
buf.Write(rawResponse)
return buf.Bytes(), nil
}
// deserializeCacheFile parses and validates a cache file produced in the
// "SC2C <hash> <size>\n<raw response>" format.
//
// It checks, in order: the minimum length, the presence of a header line, the
// field count, the magic, the hash length (64), the declared size against the
// actual raw-response length, the presence of a "\r\n\r\n" body separator,
// and finally the SHA-256 of the body against the stored hash. The returned
// Response aliases data; it is not a copy.
func deserializeCacheFile(data []byte) (*CacheFileFormat, error) {
	if len(data) < 4 {
		return nil, fmt.Errorf("cache file too short")
	}

	header, rawResponse, found := bytes.Cut(data, []byte{'\n'})
	if !found {
		return nil, fmt.Errorf("invalid cache file format: no header line found")
	}

	fields := strings.Fields(string(header))
	if len(fields) != 3 {
		return nil, fmt.Errorf("invalid header format: expected 3 fields, got %d", len(fields))
	}
	magic, contentHash, sizeField := fields[0], fields[1], fields[2]

	if magic != CacheFileMagic {
		return nil, fmt.Errorf("invalid cache file magic number: %s", magic)
	}
	if len(contentHash) != 64 {
		return nil, fmt.Errorf("invalid content hash length: expected 64, got %d", len(contentHash))
	}

	responseSize, err := strconv.ParseInt(sizeField, 10, 64)
	if err != nil {
		return nil, fmt.Errorf("invalid response size: %w", err)
	}
	if int64(len(rawResponse)) != responseSize {
		return nil, fmt.Errorf("response size mismatch: expected %d, got %d",
			responseSize, len(rawResponse))
	}

	_, body, found := bytes.Cut(rawResponse, []byte("\r\n\r\n"))
	if !found {
		return nil, fmt.Errorf("invalid HTTP response format: no body separator found")
	}
	if actual := calculateSHA256(body); actual != contentHash {
		return nil, fmt.Errorf("content hash mismatch: expected %s, got %s",
			contentHash, actual)
	}

	return &CacheFileFormat{
		ContentHash:  contentHash,
		ResponseSize: responseSize,
		Response:     rawResponse,
	}, nil
}
// reconstructRawResponse serialises resp's status and headers plus bodyData
// into HTTP/1.x wire format for storage in the cache.
//
// The status line is always "HTTP/1.1 <code> <reason>", the form that
// streamCachedResponse parses back.
//
// Headers are emitted by http.Header.Write: sorted by key, one "Key: value"
// line per value, with CR/LF inside values replaced by spaces so they cannot
// break the framing. The upstream wire order is not recoverable, because
// http.Header is a map; sorting at least makes the output deterministic.
func (sc *SteamCache) reconstructRawResponse(resp *http.Response, bodyData []byte) []byte {
	var responseBuffer bytes.Buffer
	// Bodies can be large; reserve room for the body plus a typical header
	// block up front instead of growing (and copying) repeatedly.
	responseBuffer.Grow(len(bodyData) + 1024)

	fmt.Fprintf(&responseBuffer, "HTTP/1.1 %d %s\r\n", resp.StatusCode, http.StatusText(resp.StatusCode))
	// Writes to a bytes.Buffer cannot fail, so the error is deliberately ignored.
	_ = resp.Header.Write(&responseBuffer)
	responseBuffer.WriteString("\r\n") // End of headers
	responseBuffer.Write(bodyData)
	return responseBuffer.Bytes()
}
// streamCachedResponse replays a cached upstream response to the client.
//
// cacheFile.Response holds a raw response made of these parts:
//   - an "HTTP/1.1 <code> ..." status line;
//   - "Key: value" header lines;
//   - an empty line, then the body.
//
// It is presumably produced by reconstructRawResponse. Hop-by-hop headers are
// not replayed, and X-LanCache-Status / X-LanCache-Processed-By are set.
//
// A request carrying a Range header gets 206 Partial Content sliced from the
// cached full body, or 416 if parseRangeHeader rejects the range. A cached
// entry that cannot be parsed yields 500.
func (sc *SteamCache) streamCachedResponse(w http.ResponseWriter, r *http.Request, cacheFile *CacheFileFormat, cacheKey, clientIP string, tstart time.Time) {
	responseReader := bytes.NewReader(cacheFile.Response)

	// fail logs why the cached entry could not be replayed and answers 500.
	fail := func(err error, msg string) {
		logger.Logger.Error().
			Str("key", cacheKey).
			Str("url", r.URL.String()).
			Err(err).
			Msg(msg)
		http.Error(w, "Internal server error", http.StatusInternalServerError)
	}

	statusLine, err := readLine(responseReader)
	if err != nil {
		fail(err, "Failed to read status line from cached response")
		return
	}
	var statusCode int
	if _, err := fmt.Sscanf(statusLine, "HTTP/1.1 %d", &statusCode); err != nil {
		fail(err, "Failed to parse status code from cached response")
		return
	}

	// Header lines run until the first empty line; lines without a colon are
	// ignored.
	headers := make(map[string][]string)
	for {
		line, err := readLine(responseReader)
		if err != nil {
			fail(err, "Failed to read headers from cached response")
			return
		}
		if line == "" {
			break
		}
		if key, value, ok := strings.Cut(line, ":"); ok {
			key = strings.TrimSpace(key)
			headers[key] = append(headers[key], strings.TrimSpace(value))
		}
	}

	// Whatever the reader has not consumed yet is the body.
	bodyStart := responseReader.Size() - int64(responseReader.Len())
	bodyData := cacheFile.Response[bodyStart:]

	// replayHeaders copies the cached headers (minus hop-by-hop ones) and our
	// own markers onto the response. Callers run it BEFORE setting any
	// response-specific headers, so those Set calls replace cached values
	// instead of being duplicated by them.
	replayHeaders := func() {
		for k, vv := range headers {
			if _, skip := hopByHopHeaders[http.CanonicalHeaderKey(k)]; skip {
				continue
			}
			for _, v := range vv {
				w.Header().Add(k, v)
			}
		}
		w.Header().Set("X-LanCache-Status", "HIT")
		w.Header().Set("X-LanCache-Processed-By", "SteamCache2")
	}

	if rangeHeader := r.Header.Get("Range"); rangeHeader != "" {
		start, end, totalSize, valid := parseRangeHeader(rangeHeader, int64(len(bodyData)))
		if !valid {
			w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", len(bodyData)))
			w.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
			return
		}
		rangeData := bodyData[start : end+1]

		replayHeaders()
		// Override any cached Content-Length / Content-Range / Accept-Ranges.
		w.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, totalSize))
		w.Header().Set("Content-Length", strconv.Itoa(len(rangeData)))
		w.Header().Set("Accept-Ranges", "bytes")
		w.WriteHeader(http.StatusPartialContent)
		// A failed write means the client went away; there is nothing to recover.
		_, _ = w.Write(rangeData)

		logger.Logger.Info().
			Str("key", cacheKey).
			Str("url", r.URL.String()).
			Str("host", r.Host).
			Str("client_ip", clientIP).
			Str("status", "HIT").
			Str("range", fmt.Sprintf("%d-%d/%d", start, end, totalSize)).
			Dur("zduration", time.Since(tstart)).
			Msg("cache request")
		return
	}

	// No Range header: replay the full cached response.
	replayHeaders()
	w.WriteHeader(statusCode)
	_, _ = w.Write(bodyData)

	logger.Logger.Info().
		Str("key", cacheKey).
		Str("url", r.URL.String()).
		Str("host", r.Host).
		Str("client_ip", clientIP).
		Str("status", "HIT").
		Dur("zduration", time.Since(tstart)).
		Msg("cache request")
}
// readLine consumes bytes from reader up to and including the next '\n'. It
// returns the line without that newline and without one trailing '\r'.
//
// If the reader runs out before a '\n' is seen, the partial line is discarded
// and "" is returned together with the read error.
func readLine(reader *bytes.Reader) (string, error) {
	buf := make([]byte, 0, 64)
	for {
		c, err := reader.ReadByte()
		switch {
		case err != nil:
			return "", err
		case c == '\n':
			return string(bytes.TrimSuffix(buf, []byte{'\r'})), nil
		default:
			buf = append(buf, c)
		}
	}
}
// parseRangeHeader parses a single-range "bytes=" Range header value against a
// representation of totalSize bytes (RFC 7233 §2.1).
//
// Supported forms:
//   - "bytes=0-1023": a closed range. A last-byte-pos at or beyond the end is
//     clamped to totalSize-1, as RFC 7233 requires.
//   - "bytes=1024-": from offset 1024 to the end.
//   - "bytes=-500": the final 500 bytes, or the whole body if it is shorter.
//
// It returns the inclusive [start, end] offsets, totalSize unchanged, and
// whether the range is usable. valid is false for any of these:
//   - a multi-range request;
//   - a unit other than "bytes";
//   - a non-digit position, such as a signed number;
//   - first > last, or a zero-length suffix;
//   - a range starting at or past totalSize.
func parseRangeHeader(rangeHeader string, totalSize int64) (start, end, total int64, valid bool) {
	const unit = "bytes="
	if !strings.HasPrefix(strings.ToLower(rangeHeader), unit) {
		return 0, 0, totalSize, false
	}
	spec := strings.TrimSpace(rangeHeader[len(unit):])

	// Multiple ranges are not supported.
	if strings.Contains(spec, ",") {
		return 0, 0, totalSize, false
	}

	startStr, endStr, found := strings.Cut(spec, "-")
	if !found || strings.Contains(endStr, "-") {
		return 0, 0, totalSize, false
	}
	startStr = strings.TrimSpace(startStr)
	endStr = strings.TrimSpace(endStr)

	// parsePos accepts only plain decimal digits (RFC 7233: 1*DIGIT), so
	// signed values such as "+5" are rejected.
	parsePos := func(s string) (int64, bool) {
		if s == "" || strings.Trim(s, "0123456789") != "" {
			return 0, false
		}
		n, err := strconv.ParseInt(s, 10, 64)
		return n, err == nil
	}

	switch {
	case startStr == "":
		// Suffix range "-N": the last N bytes.
		suffix, ok := parsePos(endStr)
		if !ok || suffix == 0 {
			return 0, 0, totalSize, false
		}
		start = totalSize - suffix
		if start < 0 {
			start = 0
		}
		end = totalSize - 1
	case endStr == "":
		// Open range "N-": from N to the end.
		first, ok := parsePos(startStr)
		if !ok {
			return 0, 0, totalSize, false
		}
		start, end = first, totalSize-1
	default:
		// Closed range "N-M".
		first, okFirst := parsePos(startStr)
		last, okLast := parsePos(endStr)
		if !okFirst || !okLast || last < first {
			return 0, 0, totalSize, false
		}
		if last >= totalSize {
			last = totalSize - 1
		}
		start, end = first, last
	}

	// Unsatisfiable: the range starts past the data (covers empty bodies too).
	if start >= totalSize || start > end {
		return 0, 0, totalSize, false
	}
	return start, end, totalSize, true
}
// generateURLHash returns the lowercase hex SHA-256 digest of the whole
// urlPath, or "" when urlPath is empty.
func generateURLHash(urlPath string) string {
	if len(urlPath) == 0 {
		return ""
	}
	digest := sha256.Sum256([]byte(urlPath))
	return hex.EncodeToString(digest[:])
}
// calculateSHA256 returns the lowercase hex SHA-256 digest of data.
func calculateSHA256(data []byte) string {
	sum := sha256.Sum256(data)
	return hex.EncodeToString(sum[:])
}
// calculateSHA1 returns the lowercase hex SHA-1 digest of data. SHA-1 is kept
// only for verifying legacy upstream hashes, not for anything security-relevant.
func calculateSHA1(data []byte) string {
	sum := sha1.Sum(data)
	return hex.EncodeToString(sum[:])
}
// extractHashFromURL recovers a SHA-1 hash embedded in a Steam content URL
// path. It returns (hash, "sha1") on success and ("", "") otherwise.
//
// Recognised shapes, checked in this order:
//   - /depot/123/chunk/<sha1>: everything after a single "/chunk/" marker
//     must be the hash.
//   - /depot/123/manifest/.../<sha1>: the final path segment.
//   - /depot/123/patch/.../<sha1>: the final path segment.
//
// Any "?query" suffix is ignored. A hash is exactly 40 hex characters.
func extractHashFromURL(urlPath string) (hash string, algorithm string) {
	// withoutQuery drops everything from the first '?' onward.
	withoutQuery := func(s string) string {
		s, _, _ = strings.Cut(s, "?")
		return s
	}
	looksLikeSHA1 := func(s string) bool {
		return len(s) == 40 && isHexString(s)
	}

	// Chunk URLs: the marker must occur exactly once.
	if _, rest, found := strings.Cut(urlPath, "/chunk/"); found && !strings.Contains(rest, "/chunk/") {
		if candidate := withoutQuery(rest); looksLikeSHA1(candidate) {
			logger.Logger.Debug().
				Str("url_path", urlPath).
				Str("sha1_hash", candidate).
				Msg("Extracted SHA1 hash from Steam chunk URL")
			return candidate, "sha1"
		}
	}

	// Manifest and patch URLs carry the hash as the last path segment.
	lastSegment := withoutQuery(urlPath[strings.LastIndex(urlPath, "/")+1:])
	for _, kind := range []struct{ marker, msg string }{
		{"/manifest/", "Extracted SHA1 hash from Steam manifest URL"},
		{"/patch/", "Extracted SHA1 hash from Steam patch URL"},
	} {
		if strings.Contains(urlPath, kind.marker) && looksLikeSHA1(lastSegment) {
			logger.Logger.Debug().
				Str("url_path", urlPath).
				Str("sha1_hash", lastSegment).
				Msg(kind.msg)
			return lastSegment, "sha1"
		}
	}

	return "", ""
}
// extractUpstreamHash looks for an upstream-supplied content digest,
// preferring SHA-256 over SHA-1. It returns the digest with its algorithm name
// ("sha256" or "sha1"), or ("", "") if none is found.
//
// Lookup order:
//  1. SHA-256 headers X-SHA256, Content-SHA256, X-Content-SHA256 and Digest:
//     a 64-char hex value, optionally prefixed "sha256=".
//  2. SHA-1 headers X-SHA1, Content-SHA1, X-Content-SHA1, X-Content-Sha and
//     ETag: a 40-char hex value, optionally prefixed "sha1=".
//  3. The URL path, via extractHashFromURL.
//
// Surrounding double quotes (as on ETags) are stripped before matching.
//
// NOTE(review): two assumptions here need confirming against the upstreams
// actually in use:
//   - Standard Digest headers look like "SHA-256=<base64>" (RFC 3230/5843),
//     which this hex-only matching never accepts.
//   - A 40-hex-char ETag is assumed, not known, to be the body's SHA-1.
func extractUpstreamHash(headers http.Header, urlPath string) (hash string, algorithm string) {
	type hashSource struct {
		algorithm string   // value returned as the algorithm
		prefix    string   // optional "<algo>=" prefix on the header value
		hexLen    int      // digest length in hex characters
		names     []string // header names, checked in order
		logField  string   // structured-log field for the digest
		logMsg    string   // debug message on success
	}
	sources := []hashSource{
		{
			algorithm: "sha256",
			prefix:    "sha256=",
			hexLen:    64,
			names:     []string{"X-SHA256", "Content-SHA256", "X-Content-SHA256", "Digest"},
			logField:  "sha256_hash",
			logMsg:    "Extracted SHA256 hash from upstream header",
		},
		{
			algorithm: "sha1",
			prefix:    "sha1=",
			hexLen:    40,
			names:     []string{"X-SHA1", "Content-SHA1", "X-Content-SHA1", "X-Content-Sha", "ETag"},
			logField:  "sha1_hash",
			logMsg:    "Extracted SHA1 hash from upstream header (legacy)",
		},
	}

	for _, src := range sources {
		for _, name := range src.names {
			raw := headers.Get(name)
			if raw == "" {
				continue
			}
			// The prefix is optional: TrimPrefix leaves un-prefixed values alone.
			candidate := strings.TrimPrefix(strings.Trim(raw, `"`), src.prefix)
			if len(candidate) != src.hexLen || !isHexString(candidate) {
				continue
			}
			logger.Logger.Debug().
				Str("header_name", name).
				Str(src.logField, candidate).
				Msg(src.logMsg)
			return candidate, src.algorithm
		}
	}

	if urlHash, urlAlgo := extractHashFromURL(urlPath); urlHash != "" {
		return urlHash, urlAlgo
	}

	logger.Logger.Debug().Msg("No upstream hash found in headers or URL")
	return "", ""
}
// isHexString reports whether every character of s is an ASCII hex digit
// (0-9, a-f, A-F). The empty string counts as hex; callers check the length
// separately.
func isHexString(s string) bool {
	notHex := func(r rune) bool {
		isDigit := '0' <= r && r <= '9'
		isLower := 'a' <= r && r <= 'f'
		isUpper := 'A' <= r && r <= 'F'
		return !isDigit && !isLower && !isUpper
	}
	return strings.IndexFunc(s, notHex) < 0
}
// verifyCompleteFile reports whether bodyData is a complete, cacheable copy of
// resp's body.
//
// It returns false, so the caller can retry or pass the response through
// uncached, when:
//   - the length is unknown (resp.ContentLength < 0, e.g. chunked transfer
//     encoding), because completeness cannot be proven;
//   - the declared length is zero, because empty bodies are not cached;
//   - the number of received bytes differs from Content-Length.
func (sc *SteamCache) verifyCompleteFile(bodyData []byte, resp *http.Response, urlPath string, cacheKey string) bool {
	receivedBytes := int64(len(bodyData))

	switch {
	case resp.ContentLength < 0:
		// No usable Content-Length: we don't cache what we can't verify.
		logger.Logger.Info().
			Str("key", cacheKey).
			Str("url", urlPath).
			Int("received_bytes", len(bodyData)).
			Msg("No Content-Length header - passing through without caching")
		return false
	case resp.ContentLength == 0:
		// The header is present but declares an empty body; report it as
		// such rather than as a missing header.
		logger.Logger.Warn().
			Str("key", cacheKey).
			Str("url", urlPath).
			Msg("Empty file received")
		return false
	case receivedBytes != resp.ContentLength:
		logger.Logger.Warn().
			Str("key", cacheKey).
			Str("url", urlPath).
			Int64("received_bytes", receivedBytes).
			Int64("expected_bytes", resp.ContentLength).
			Msg("File size mismatch - incomplete download detected")
		return false
	}

	logger.Logger.Debug().
		Str("key", cacheKey).
		Str("url", urlPath).
		Int64("file_size", receivedBytes).
		Msg("File completeness verified")
	return true
}
// detectService attributes r to a configured service by its User-Agent
// header. A request without a User-Agent is never attributed.
func (sc *SteamCache) detectService(r *http.Request) (*ServiceConfig, bool) {
	if ua := r.Header.Get("User-Agent"); ua != "" {
		return sc.serviceManager.DetectService(ua)
	}
	return nil, false
}
// generateServiceCacheKey builds the cache key "<servicePrefix>/<sha256-hex>",
// where the hash covers the whole urlPath. The prefix identifies the service
// the request was attributed to via its User-Agent.
//
// An empty urlPath yields "" rather than a bare "<servicePrefix>/", so that
// callers checking for an empty key can actually reject it.
//
// Example: ("/depot/1684171/chunk/0016cfc5019b8baa6026aa1cce93e685d6e06c6e", "steam")
// yields "steam/" followed by 64 lowercase hex characters.
func generateServiceCacheKey(urlPath string, servicePrefix string) string {
	urlHash := generateURLHash(urlPath)
	if urlHash == "" {
		return ""
	}
	return servicePrefix + "/" + urlHash
}
// hopByHopHeaders lists headers that are not replayed from cached responses.
//
// Keys MUST be in canonical MIME form, because lookups canonicalise the name
// first with http.CanonicalHeaderKey. That is why TE appears as "Te": a
// literal "TE" key would never match.
//
// Besides the classic hop-by-hop set (RFC 2616 §13.5.1 / RFC 7230 §6.1), Date
// and Server are stripped too, presumably so replayed responses don't carry
// stale upstream values.
var hopByHopHeaders = map[string]struct{}{
	"Connection":          {},
	"Keep-Alive":          {},
	"Proxy-Authenticate":  {},
	"Proxy-Authorization": {},
	"Te":                  {},
	"Trailer":             {},
	"Transfer-Encoding":   {},
	"Upgrade":             {},
	"Date":                {},
	"Server":              {},
}
// clientLimiter bounds the number of concurrent requests from one client IP.
type clientLimiter struct {
	semaphore *semaphore.Weighted // sized to SteamCache.maxRequestsPerClient
	lastSeen  time.Time           // last request from this client; guarded by SteamCache.clientRequestsMu
}
// coalescedRequest lets concurrent requests for the same cache key share one
// upstream fetch. Its creator counts as the first waiter; later requesters
// join via addWaiter and wait on the channels, which complete fills at most
// once.
//
// NOTE(review): ServeHTTP reads waitingCount for logging without holding mu,
// which is a data race under -race. Confirm and lock there if needed.
type coalescedRequest struct {
	responseChan chan *http.Response // receives the fetched response (buffer of 1)
	errorChan    chan error          // receives the fetch error (buffer of 1)
	waitingCount int                 // requesters sharing this fetch; guarded by mu
	done         bool                // set once complete has run; guarded by mu
	mu           sync.Mutex
}
// newCoalescedRequest returns a coalescedRequest for a fresh upstream fetch,
// counting its creator as the first waiter. Both result channels have a
// buffer of one, so a result can be stored before anyone is receiving.
func newCoalescedRequest() *coalescedRequest {
	cr := new(coalescedRequest)
	cr.responseChan = make(chan *http.Response, 1)
	cr.errorChan = make(chan error, 1)
	cr.waitingCount = 1
	cr.done = false
	return cr
}
// addWaiter records one more requester sharing this fetch.
func (cr *coalescedRequest) addWaiter() {
	cr.mu.Lock()
	cr.waitingCount++
	cr.mu.Unlock()
}
// complete publishes the outcome of the shared fetch; only the first call has
// any effect. A non-nil err goes to errorChan, otherwise resp goes to
// responseChan.
//
// Sends never block: if the one-slot buffer is already full, the value is
// dropped.
//
// NOTE(review): a single buffered value means at most one waiter can receive
// the result. Any other waiter needs a separate wake-up path (timeout,
// re-check); confirm this in ServeHTTP.
func (cr *coalescedRequest) complete(resp *http.Response, err error) {
	cr.mu.Lock()
	defer cr.mu.Unlock()

	if cr.done {
		return
	}
	cr.done = true

	if err != nil {
		select {
		case cr.errorChan <- err:
		default:
		}
		return
	}

	select {
	case cr.responseChan <- resp:
	default:
	}
}
// getOrCreateCoalescedRequest returns the in-flight request for cacheKey,
// registering the caller as an extra waiter, or creates and registers a new
// one.
//
// The boolean is true only when a new request was created; in that case the
// caller is the one expected to perform the download.
func (sc *SteamCache) getOrCreateCoalescedRequest(cacheKey string) (*coalescedRequest, bool) {
	sc.coalescedRequestsMu.Lock()
	defer sc.coalescedRequestsMu.Unlock()

	existing, found := sc.coalescedRequests[cacheKey]
	if !found {
		created := newCoalescedRequest()
		sc.coalescedRequests[cacheKey] = created
		return created, true
	}
	existing.addWaiter()
	return existing, false
}
// removeCoalescedRequest forgets the in-flight entry for cacheKey, so the next
// request for that key starts a new fetch. Removing an absent key is a no-op.
func (sc *SteamCache) removeCoalescedRequest(cacheKey string) {
	sc.coalescedRequestsMu.Lock()
	delete(sc.coalescedRequests, cacheKey)
	sc.coalescedRequestsMu.Unlock()
}
// getClientIP returns the address used to identify a client for logging and
// per-client rate limiting. Candidates are tried in this order:
//  1. the first X-Forwarded-For entry;
//  2. X-Real-IP;
//  3. the host part of r.RemoteAddr, or RemoteAddr verbatim if it has no port.
//
// An empty or whitespace-only candidate is skipped, so a malformed header
// such as ", 1.2.3.4" or "   " cannot produce an empty or garbage key.
//
// NOTE(review): both forwarding headers are client-controlled unless a
// trusted proxy overwrites them, so a direct client can pick its own
// rate-limit key. Confirm that deployments always sit behind such a proxy.
func getClientIP(r *http.Request) string {
	// X-Forwarded-For is "client, proxy1, proxy2"; the originating client is first.
	firstHop, _, _ := strings.Cut(r.Header.Get("X-Forwarded-For"), ",")
	if ip := strings.TrimSpace(firstHop); ip != "" {
		return ip
	}

	if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}

	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
// getOrCreateClientLimiter returns the concurrency limiter for clientIP and
// marks the client as seen now.
//
// A limiter that has been idle for more than 5 minutes is not reused: it is
// replaced by a fresh one with a new semaphore of size maxRequestsPerClient.
func (sc *SteamCache) getOrCreateClientLimiter(clientIP string) *clientLimiter {
	sc.clientRequestsMu.Lock()
	defer sc.clientRequestsMu.Unlock()

	now := time.Now()
	if existing, ok := sc.clientRequests[clientIP]; ok && now.Sub(existing.lastSeen) <= 5*time.Minute {
		existing.lastSeen = now
		return existing
	}

	fresh := &clientLimiter{
		semaphore: semaphore.NewWeighted(sc.maxRequestsPerClient),
		lastSeen:  now,
	}
	sc.clientRequests[clientIP] = fresh
	return fresh
}
// cleanupOldClientLimiters bounds the size of sc.clientRequests. Every 10
// minutes it drops the limiters of clients not seen for more than 30 minutes.
//
// It has no stop signal and never returns, so waiting on a goroutine that
// runs it (e.g. via a sync.WaitGroup) blocks forever.
func (sc *SteamCache) cleanupOldClientLimiters() {
	const (
		sweepInterval = 10 * time.Minute
		maxIdle       = 30 * time.Minute
	)
	for {
		time.Sleep(sweepInterval)

		sc.clientRequestsMu.Lock()
		cutoff := time.Now().Add(-maxIdle)
		for ip, limiter := range sc.clientRequests {
			if limiter.lastSeen.Before(cutoff) {
				delete(sc.clientRequests, ip)
			}
		}
		sc.clientRequestsMu.Unlock()
	}
}
// SteamCache is the HTTP cache server. It serves requests from clients of the
// configured services (Steam by default, see NewServiceManager) out of a
// tiered vfs.VFS. It coalesces concurrent fetches of the same cache key and
// limits concurrency both globally and per client IP.
type SteamCache struct {
	address  string             // listen address (also used as server.Addr)
	upstream string             // optional upstream URL, health-checked by Run
	vfs      vfs.VFS            // tiered cache assembled in New
	memory   *memory.MemoryFS   // memory tier; nil when disabled
	disk     *disk.DiskFS       // disk tier; nil when disabled
	memorygc *gc.GCFS           // GC wrapper around memory; nil when disabled
	diskgc   *gc.GCFS           // GC wrapper around disk; nil when disabled
	server   *http.Server       // HTTP server; its Handler is set to sc by Run
	client   *http.Client       // client for upstream requests
	cancel   context.CancelFunc // cancels Run's context; set by Run
	wg       sync.WaitGroup     // goroutines started by Run
	// Request coalescing structures
	coalescedRequests   map[string]*coalescedRequest // in-flight fetches keyed by cache key
	coalescedRequestsMu sync.RWMutex                 // guards coalescedRequests
	// Concurrency control
	maxConcurrentRequests int64
	requestSemaphore      *semaphore.Weighted // global in-flight limit of size maxConcurrentRequests
	// Per-client rate limiting
	clientRequests       map[string]*clientLimiter // limiters keyed by client IP (see getClientIP)
	clientRequestsMu     sync.RWMutex              // guards clientRequests and each limiter's lastSeen
	maxRequestsPerClient int64                     // per-client concurrent request limit
	// Service management
	serviceManager *ServiceManager // maps User-Agents to services and cache-key prefixes
}
// New builds a SteamCache that will listen on address and is backed by a
// memory tier, a disk tier, or both.
//
// Parameters:
//   - memorySize, diskSize: human-readable sizes parsed with
//     units.FromHumanSize (e.g. "1GB"). Zero disables that tier, and at least
//     one tier must be enabled.
//   - memoryGC, diskGC: the GC algorithm for each tier; empty means LRU.
//   - upstream: when non-empty, health-checked by Run.
//   - maxConcurrentRequests: bounds in-flight requests globally.
//   - maxRequestsPerClient: bounds in-flight requests per client IP.
//
// New panics if a size cannot be parsed, and exits the process if both tiers
// are disabled.
func New(address string, memorySize string, diskSize string, diskPath, upstream, memoryGC, diskGC string, maxConcurrentRequests int64, maxRequestsPerClient int64) *SteamCache {
	memorysize, err := units.FromHumanSize(memorySize)
	if err != nil {
		panic(err)
	}
	disksize, err := units.FromHumanSize(diskSize)
	if err != nil {
		panic(err)
	}

	// Resolve each tier's GC algorithm once (empty means LRU). The GC
	// wrappers, the startup over-capacity sweep and the log lines below must
	// all agree on the algorithm actually in use.
	memoryGCAlgo := gc.GCAlgorithm(memoryGC)
	if memoryGCAlgo == "" {
		memoryGCAlgo = gc.LRU
	}
	diskGCAlgo := gc.GCAlgorithm(diskGC)
	if diskGCAlgo == "" {
		diskGCAlgo = gc.LRU
	}

	c := cache.New()

	var m *memory.MemoryFS
	var mgc *gc.GCFS
	if memorysize > 0 {
		m = memory.New(memorysize)
		mgc = gc.New(m, memoryGCAlgo)
	}

	var d *disk.DiskFS
	var dgc *gc.GCFS
	if disksize > 0 {
		d = disk.New(diskPath, disksize)
		dgc = gc.New(d, diskGCAlgo)
	}

	// Wire the tiers according to which sizes are non-zero.
	switch {
	case disksize == 0 && memorysize != 0:
		// Memory-only mode.
		c.SetSlow(mgc)
	case disksize != 0 && memorysize == 0:
		// Disk-only mode.
		c.SetSlow(dgc)
	case disksize != 0 && memorysize != 0:
		// Memory is the fast tier in front of the disk (slow) tier.
		c.SetFast(mgc)
		c.SetSlow(dgc)
	default:
		// Neither tier is enabled, so there is nowhere to cache anything.
		logger.Logger.Error().Bool("memory", false).Bool("disk", false).Msg("configuration invalid :( exiting")
		os.Exit(1)
	}

	transport := &http.Transport{
		MaxIdleConns:        200,
		MaxIdleConnsPerHost: 50,
		IdleConnTimeout:     120 * time.Second,
		DialContext: (&net.Dialer{
			Timeout:   30 * time.Second,
			KeepAlive: 30 * time.Second,
		}).DialContext,
		TLSHandshakeTimeout:   15 * time.Second,
		ResponseHeaderTimeout: 30 * time.Second,
		ExpectContinueTimeout: 5 * time.Second,
		DisableCompression:    true, // Steam doesn't use compression; don't request gzip transparently
		ForceAttemptHTTP2:     true, // Enable HTTP/2 if available
	}

	client := &http.Client{
		Transport: transport,
		// Client.Timeout covers the whole exchange, including reading the body,
		// so it also caps how long one upstream download may take.
		Timeout: 120 * time.Second,
	}

	sc := &SteamCache{
		upstream: upstream,
		address:  address,
		vfs:      c,
		memory:   m,
		disk:     d,
		memorygc: mgc,
		diskgc:   dgc,
		client:   client,
		server: &http.Server{
			Addr:              address,
			ReadTimeout:       30 * time.Second,
			WriteTimeout:      60 * time.Second,  // also bounds how long one response body may take to send
			IdleTimeout:       120 * time.Second, // keep-alive idle limit
			ReadHeaderTimeout: 10 * time.Second,  // limits slow-header (slowloris-style) clients
			MaxHeaderBytes:    1 << 20,           // 1 MiB
		},
		// Request coalescing and concurrency control.
		coalescedRequests:     make(map[string]*coalescedRequest),
		maxConcurrentRequests: maxConcurrentRequests,
		// NOTE(review): with a size below 1, ServeHTTP's
		// Acquire(context.Background(), 1) can never succeed and would block
		// forever. Confirm that callers always pass values >= 1.
		requestSemaphore:     semaphore.NewWeighted(maxConcurrentRequests),
		clientRequests:       make(map[string]*clientLimiter),
		maxRequestsPerClient: maxRequestsPerClient,
		// Service management.
		serviceManager: NewServiceManager(),
	}

	if m != nil {
		logger.Logger.Info().Str("memory_gc", string(memoryGCAlgo)).Msg("Memory cache GC algorithm configured")
	}
	if d != nil {
		logger.Logger.Info().Str("disk_gc", string(diskGCAlgo)).Msg("Disk cache GC algorithm configured")

		// The disk tier may already be over capacity right after construction
		// (presumably content left from an earlier run with a larger limit).
		// Reclaim the excess now, using the resolved algorithm: the raw diskGC
		// may be empty.
		if d.Size() > d.Capacity() {
			gcHandler := gc.GetGCAlgorithm(diskGCAlgo)
			gcHandler(d, uint(d.Size()-d.Capacity()))
		}
	}

	return sc
}
// Run starts the cache server and blocks until Shutdown is called. In order it:
//  1. verifies the upstream (when configured);
//  2. starts the HTTP server and a janitor for idle per-client limiters;
//  3. waits for shutdown, then drains the server gracefully for up to 30
//     seconds before waiting for its goroutines to finish.
//
// The process exits with status 1 if the upstream check fails or the listener
// cannot be started.
func (sc *SteamCache) Run() {
	if sc.upstream != "" {
		resp, err := sc.client.Get(sc.upstream)
		if err != nil {
			// resp is nil whenever err != nil, so it must not be dereferenced here.
			logger.Logger.Error().Err(err).Str("upstream", sc.upstream).Msg("Failed to connect to upstream server")
			os.Exit(1)
		}
		resp.Body.Close()
		if resp.StatusCode != http.StatusOK {
			logger.Logger.Error().Int("status_code", resp.StatusCode).Str("upstream", sc.upstream).Msg("Failed to connect to upstream server")
			os.Exit(1)
		}
	}

	sc.server.Handler = sc

	ctx, cancel := context.WithCancel(context.Background())
	sc.cancel = cancel

	// Janitor for per-client limiters. The policy matches
	// cleanupOldClientLimiters (every 10 minutes, drop limiters idle > 30
	// minutes), but this loop stops when ctx is cancelled, so the wg.Wait
	// calls here and in Shutdown can return.
	sc.wg.Add(1)
	go func() {
		defer sc.wg.Done()
		ticker := time.NewTicker(10 * time.Minute)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				now := time.Now()
				sc.clientRequestsMu.Lock()
				for ip, limiter := range sc.clientRequests {
					if now.Sub(limiter.lastSeen) > 30*time.Minute {
						delete(sc.clientRequests, ip)
					}
				}
				sc.clientRequestsMu.Unlock()
			}
		}
	}()

	sc.wg.Add(1)
	go func() {
		defer sc.wg.Done()
		err := sc.server.ListenAndServe()
		if err != nil && err != http.ErrServerClosed {
			logger.Logger.Error().Err(err).Msg("Failed to start SteamCache2")
			os.Exit(1)
		}
	}()

	<-ctx.Done()

	// ctx is already cancelled here. Passing it to Shutdown would abandon the
	// graceful drain immediately, so give in-flight requests a bounded grace
	// period on a fresh context instead.
	shutdownCtx, cancelShutdown := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancelShutdown()
	if err := sc.server.Shutdown(shutdownCtx); err != nil {
		logger.Logger.Warn().Err(err).Msg("Graceful shutdown did not complete")
	}

	sc.wg.Wait()
}
// Shutdown cancels the context created by Run, if Run has started. It then
// waits for the goroutines registered on sc.wg to finish.
func (sc *SteamCache) Shutdown() {
	if cancel := sc.cancel; cancel != nil {
		cancel()
	}
	sc.wg.Wait()
}
func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Apply global concurrency limit first
if err := sc.requestSemaphore.Acquire(context.Background(), 1); err != nil {
logger.Logger.Warn().Str("client_ip", getClientIP(r)).Msg("Server at capacity, rejecting request")
http.Error(w, "Server busy, please try again later", http.StatusServiceUnavailable)
return
}
defer sc.requestSemaphore.Release(1)
// Apply per-client rate limiting
clientIP := getClientIP(r)
clientLimiter := sc.getOrCreateClientLimiter(clientIP)
if err := clientLimiter.semaphore.Acquire(context.Background(), 1); err != nil {
logger.Logger.Warn().
Str("client_ip", clientIP).
Int("max_per_client", int(sc.maxRequestsPerClient)).
Msg("Client exceeded concurrent request limit")
http.Error(w, "Too many concurrent requests from this client", http.StatusTooManyRequests)
return
}
defer clientLimiter.semaphore.Release(1)
if r.Method != http.MethodGet {
logger.Logger.Warn().
Str("method", r.Method).
Str("client_ip", clientIP).
Msg("Only GET method is supported")
http.Error(w, "Only GET method is supported", http.StatusMethodNotAllowed)
return
}
if r.URL.Path == "/" {
logger.Logger.Debug().
Str("client_ip", clientIP).
Msg("Health check request")
w.WriteHeader(http.StatusOK) // this is used by steamcache2's upstream verification at startup
return
}
if r.URL.String() == "/lancache-heartbeat" {
logger.Logger.Debug().
Str("client_ip", clientIP).
Msg("LanCache heartbeat request")
w.Header().Add("X-LanCache-Processed-By", "SteamCache2")
w.WriteHeader(http.StatusNoContent)
w.Write(nil)
return
}
// Check if this is a request from a supported service
if service, isSupported := sc.detectService(r); isSupported {
// trim the query parameters from the URL path
// this is necessary because the cache key should not include query parameters
urlPath, _, _ := strings.Cut(r.URL.String(), "?")
tstart := time.Now()
// Generate service cache key: {service}/{hash} (prefix indicates service via User-Agent)
cacheKey := generateServiceCacheKey(urlPath, service.Prefix)
if cacheKey == "" {
logger.Logger.Warn().Str("url", urlPath).Msg("Invalid URL")
http.Error(w, "Invalid URL", http.StatusBadRequest)
return
}
w.Header().Add("X-LanCache-Processed-By", "SteamCache2") // SteamPrefill uses this header to determine if the request was processed by the cache maybe steam uses it too
cachePath := cacheKey // You may want to add a .http or .cache extension for clarity
logger.Logger.Debug().
Str("url", urlPath).
Str("key", cacheKey).
Str("client_ip", clientIP).
Msg("Generated cache key")
// Try to serve from cache
file, err := sc.vfs.Open(cachePath)
if err == nil {
defer file.Close()
// Read the entire cached file
cachedData, err := io.ReadAll(file)
if err != nil {
logger.Logger.Warn().
Str("key", cacheKey).
Str("url", urlPath).
Err(err).
Msg("Failed to read cached file - removing corrupted entry")
sc.vfs.Delete(cachePath)
} else {
// Deserialize using new format
cacheFile, err := deserializeCacheFile(cachedData)
if err != nil {
// Cache file is corrupted or invalid format
logger.Logger.Warn().
Str("key", cacheKey).
Str("url", urlPath).
Err(err).
Msg("Failed to deserialize cache file - removing corrupted entry")
sc.vfs.Delete(cachePath)
} else {
// Cache validation passed
logger.Logger.Debug().
Str("key", cacheKey).
Str("url", urlPath).
Str("content_hash", cacheFile.ContentHash).
Msg("Successfully loaded from cache")
// Stream the raw HTTP response directly
sc.streamCachedResponse(w, r, cacheFile, cacheKey, clientIP, tstart)
return
}
}
// If we reach here, cache validation failed and we need to fetch from upstream
}
// Check for coalesced request (another client already downloading this)
coalescedReq, isNew := sc.getOrCreateCoalescedRequest(cacheKey)
if !isNew {
// Wait for the existing download to complete
logger.Logger.Debug().
Str("key", cacheKey).
Str("url", urlPath).
Str("client_ip", clientIP).
Int("waiting_clients", coalescedReq.waitingCount).
Msg("Joining coalesced request")
select {
case resp := <-coalescedReq.responseChan:
// Use the downloaded response
defer resp.Body.Close()
// For coalesced clients, we need to make a new request to get fresh data
// since the original response body was consumed by the first client
freshReq, err := http.NewRequest(http.MethodGet, r.URL.String(), nil)
if err != nil {
logger.Logger.Error().
Err(err).
Str("key", cacheKey).
Str("url", urlPath).
Str("client_ip", clientIP).
Msg("Failed to create fresh request for coalesced client")
http.Error(w, "Failed to fetch data", http.StatusInternalServerError)
return
}
// Copy original headers
for k, vv := range r.Header {
freshReq.Header[k] = vv
}
freshResp, err := sc.client.Do(freshReq)
if err != nil {
logger.Logger.Error().
Err(err).
Str("key", cacheKey).
Str("url", urlPath).
Str("client_ip", clientIP).
Msg("Failed to fetch fresh data for coalesced client")
http.Error(w, "Failed to fetch data", http.StatusInternalServerError)
return
}
defer freshResp.Body.Close()
// Serve the fresh response
for k, vv := range freshResp.Header {
if _, skip := hopByHopHeaders[http.CanonicalHeaderKey(k)]; skip {
continue
}
for _, v := range vv {
w.Header().Add(k, v)
}
}
w.Header().Set("X-LanCache-Status", "HIT-COALESCED")
w.Header().Set("X-LanCache-Processed-By", "SteamCache2")
w.WriteHeader(freshResp.StatusCode)
io.Copy(w, freshResp.Body)
logger.Logger.Info().
Str("key", cacheKey).
Str("url", urlPath).
Str("host", r.Host).
Str("client_ip", clientIP).
Str("status", "HIT-COALESCED").
Dur("zduration", time.Since(tstart)).
Msg("cache request")
return
case err := <-coalescedReq.errorChan:
logger.Logger.Error().
Err(err).
Str("key", cacheKey).
Str("url", urlPath).
Str("client_ip", clientIP).
Msg("Coalesced request failed")
http.Error(w, "Upstream request failed", http.StatusInternalServerError)
return
}
}
// Remove coalesced request when done
defer sc.removeCoalescedRequest(cacheKey)
var req *http.Request
if sc.upstream != "" { // if an upstream server is configured, proxy the request to the upstream server
ur, err := url.JoinPath(sc.upstream, urlPath)
if err != nil {
logger.Logger.Error().Err(err).Str("upstream", sc.upstream).Msg("Failed to join URL path")
http.Error(w, "Failed to join URL path", http.StatusInternalServerError)
return
}
req, err = http.NewRequest(http.MethodGet, ur, nil)
if err != nil {
logger.Logger.Error().Err(err).Str("upstream", sc.upstream).Msg("Failed to create request")
http.Error(w, "Failed to create request", http.StatusInternalServerError)
return
}
req.Host = r.Host
} else { // if no upstream server is configured, proxy the request to the host specified in the request
host := r.Host
if r.Header.Get("X-Sls-Https") == "enable" {
host = "https://" + host
} else {
host = "http://" + host
}
ur, err := url.JoinPath(host, urlPath)
if err != nil {
logger.Logger.Error().Err(err).Str("host", host).Msg("Failed to join URL path")
http.Error(w, "Failed to join URL path", http.StatusInternalServerError)
return
}
req, err = http.NewRequest(http.MethodGet, ur, nil)
if err != nil {
logger.Logger.Error().Err(err).Str("host", host).Msg("Failed to create request")
http.Error(w, "Failed to create request", http.StatusInternalServerError)
return
}
}
// Copy headers from the original request to the new request
// BUT exclude Range headers - we always want to cache the full file
for key, values := range r.Header {
// Skip Range headers to ensure we always cache the complete file
if strings.ToLower(key) == "range" {
logger.Logger.Debug().
Str("key", cacheKey).
Str("url", urlPath).
Str("range_header", values[0]).
Msg("Skipping Range header to cache full file")
continue
}
for _, value := range values {
req.Header.Add(key, value)
}
}
// Retry logic
backoffSchedule := []time.Duration{1 * time.Second, 3 * time.Second, 10 * time.Second}
var resp *http.Response
for i, backoff := range backoffSchedule {
resp, err = sc.client.Do(req)
if err == nil && resp.StatusCode == http.StatusOK {
break
}
if i < len(backoffSchedule)-1 {
time.Sleep(backoff)
}
}
if err != nil || resp.StatusCode != http.StatusOK {
logger.Logger.Error().Err(err).Str("url", req.URL.String()).Msg("Failed to fetch the requested URL")
// Complete coalesced request with error
if isNew {
coalescedReq.complete(nil, err)
}
http.Error(w, "Failed to fetch the requested URL", http.StatusInternalServerError)
return
}
defer resp.Body.Close()
// Fast path: Flexible lightweight validation for all files
// Multiple validation layers ensure data integrity without blocking legitimate Steam content
// Method 1: HTTP Status Validation
if resp.StatusCode != http.StatusOK {
logger.Logger.Error().
Str("url", req.URL.String()).
Int("status_code", resp.StatusCode).
Msg("Steam returned non-OK status")
http.Error(w, "Upstream server error", http.StatusBadGateway)
return
}
// Method 2: Content-Type Validation (Steam files can be various types)
contentType := resp.Header.Get("Content-Type")
if contentType != "" {
// Log the content type for monitoring, but don't restrict based on it
// Steam serves different content types: chunks, manifests, patches, etc.
logger.Logger.Debug().
Str("url", req.URL.String()).
Str("content_type", contentType).
Str("service", service.Name).
Msg("Content type from upstream")
}
// Method 3: Content-Length Validation
expectedSize := resp.ContentLength
// Reject only truly invalid content lengths (zero or negative)
if expectedSize <= 0 {
logger.Logger.Error().
Str("url", req.URL.String()).
Int64("content_length", expectedSize).
Msg("Invalid content length, rejecting file")
http.Error(w, "Invalid content length", http.StatusBadGateway)
return
}
// Content length is valid - no size restrictions to keep logs clean
// Lightweight validation passed - trust the Content-Length and HTTP status
// This provides good integrity with minimal performance overhead
validationPassed := true
// Read the entire response body into memory to avoid consuming it twice
bodyData, err := io.ReadAll(resp.Body)
if err != nil {
logger.Logger.Error().
Err(err).
Str("url", req.URL.String()).
Msg("Failed to read response body")
http.Error(w, "Failed to read response", http.StatusInternalServerError)
return
}
resp.Body.Close() // Close the original body since we've read it
// Reconstruct the exact HTTP response as received from upstream
rawResponse := sc.reconstructRawResponse(resp, bodyData)
// Write to response
// Remove hop-by-hop and server-specific headers
for k, vv := range resp.Header {
if _, skip := hopByHopHeaders[http.CanonicalHeaderKey(k)]; skip {
continue
}
for _, v := range vv {
w.Header().Add(k, v)
}
}
// Add our own headers
w.Header().Set("X-LanCache-Status", "MISS")
w.Header().Set("X-LanCache-Processed-By", "SteamCache2")
// Stream the response body to client
w.WriteHeader(resp.StatusCode)
w.Write(bodyData)
// Cache the file if validation passed
if validationPassed {
// Verify we received the complete file by checking Content-Length
if !sc.verifyCompleteFile(bodyData, resp, urlPath, cacheKey) {
logger.Logger.Warn().
Str("key", cacheKey).
Str("url", urlPath).
Int("received_bytes", len(bodyData)).
Int64("expected_bytes", resp.ContentLength).
Msg("Incomplete file received - not caching to allow retry")
return
}
// Serialize the raw response using our new cache format
cacheData, err := serializeRawResponse(urlPath, rawResponse, "", "")
if err != nil {
logger.Logger.Warn().
Str("key", cacheKey).
Str("url", urlPath).
Err(err).
Msg("Failed to serialize cache file")
} else {
// Store the serialized cache data
cacheWriter, err := sc.vfs.Create(cachePath, int64(len(cacheData)))
if err == nil {
defer cacheWriter.Close()
// Write the serialized cache data
bytesWritten, cacheErr := cacheWriter.Write(cacheData)
if cacheErr != nil || bytesWritten != len(cacheData) {
logger.Logger.Warn().
Str("key", cacheKey).
Str("url", urlPath).
Int("expected", len(cacheData)).
Int("written", bytesWritten).
Err(cacheErr).
Msg("Cache write failed or incomplete - removing corrupted entry")
sc.vfs.Delete(cachePath)
} else {
logger.Logger.Debug().
Str("key", cacheKey).
Str("url", urlPath).
Str("service", service.Name).
Int("size", bytesWritten).
Msg("Successfully cached response")
}
} else {
logger.Logger.Warn().
Str("key", cacheKey).
Str("url", urlPath).
Err(err).
Msg("Failed to create cache file")
}
}
// Complete coalesced request with the original response
if isNew {
coalescedResp := &http.Response{
StatusCode: resp.StatusCode,
Status: resp.Status,
Header: make(http.Header),
Body: io.NopCloser(bytes.NewReader(bodyData)), // Buffered body for coalesced clients
}
for k, vv := range resp.Header {
coalescedResp.Header[k] = vv
}
coalescedReq.complete(coalescedResp, nil)
}
} else {
logger.Logger.Warn().
Str("key", cacheKey).
Str("url", urlPath).
Err(err).
Msg("Failed to create cache file")
// Complete coalesced request with buffered body even if cache creation failed
if isNew {
coalescedResp := &http.Response{
StatusCode: resp.StatusCode,
Status: resp.Status,
Header: make(http.Header),
Body: io.NopCloser(bytes.NewReader(bodyData)), // Use buffered body
}
for k, vv := range resp.Header {
coalescedResp.Header[k] = vv
}
coalescedReq.complete(coalescedResp, nil)
}
}
logger.Logger.Info().
Str("key", cacheKey).
Str("url", urlPath).
Str("host", r.Host).
Str("client_ip", clientIP).
Str("status", "MISS").
Dur("zduration", time.Since(tstart)).
Msg("cache request")
return
}
// Handle favicon requests
if r.URL.Path == "/favicon.ico" {
logger.Logger.Debug().
Str("client_ip", clientIP).
Msg("Favicon request")
w.WriteHeader(http.StatusNoContent)
return
}
if r.URL.Path == "/robots.txt" {
logger.Logger.Debug().
Str("client_ip", clientIP).
Msg("Robots.txt request")
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte("User-agent: *\nDisallow: /\n"))
return
}
logger.Logger.Warn().
Str("url", r.URL.String()).
Str("client_ip", clientIP).
Msg("Request not found")
http.Error(w, "Not found", http.StatusNotFound)
}