From 9ca8fa4a5ebb3f79ad9552e30bca69a292ee97d3 Mon Sep 17 00:00:00 2001 From: Justin Harms Date: Tue, 2 Sep 2025 06:50:42 -0500 Subject: [PATCH] Add concurrency limits and configuration options for SteamCache - Introduced maxConcurrentRequests and maxRequestsPerClient fields in the Config struct to manage request limits. - Updated the SteamCache implementation to utilize these new configuration options for controlling concurrent requests. - Enhanced the ServeHTTP method to enforce global and per-client rate limiting using semaphores. - Modified the root command to accept new flags for configuring concurrency limits via command-line arguments. - Updated tests to reflect changes in the SteamCache initialization and request handling logic. --- cmd/root.go | 19 ++++ config/config.go | 20 +++- go.mod | 1 + go.sum | 2 + steamcache/steamcache.go | 184 ++++++++++++++++++++++++++++++---- steamcache/steamcache_test.go | 6 +- 6 files changed, 203 insertions(+), 29 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index d8ca81e..c2282ae 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -19,6 +19,9 @@ var ( logLevel string logFormat string + + maxConcurrentRequests int64 + maxRequestsPerClient int64 ) var rootCmd = &cobra.Command{ @@ -94,6 +97,17 @@ var rootCmd = &cobra.Command{ Str("config_path", configPath). Msg("Configuration loaded successfully") + // Use command-line flags if provided, otherwise use config values + finalMaxConcurrentRequests := cfg.MaxConcurrentRequests + if maxConcurrentRequests > 0 { + finalMaxConcurrentRequests = maxConcurrentRequests + } + + finalMaxRequestsPerClient := cfg.MaxRequestsPerClient + if maxRequestsPerClient > 0 { + finalMaxRequestsPerClient = maxRequestsPerClient + } + sc := steamcache.New( cfg.ListenAddress, cfg.Cache.Memory.Size, @@ -102,6 +116,8 @@ var rootCmd = &cobra.Command{ cfg.Upstream, cfg.Cache.Memory.GCAlgorithm, cfg.Cache.Disk.GCAlgorithm, + finalMaxConcurrentRequests, + finalMaxRequestsPerClient, ) logger.Logger.Info(). @@ -128,4 +144,7 @@ func init() { rootCmd.Flags().StringVarP(&logLevel, "log-level", "l", "info", "Logging level: debug, info, error") rootCmd.Flags().StringVarP(&logFormat, "log-format", "f", "console", "Logging format: json, console") + + rootCmd.Flags().Int64Var(&maxConcurrentRequests, "max-concurrent-requests", 0, "Maximum concurrent requests (0 = use config file value)") + rootCmd.Flags().Int64Var(&maxRequestsPerClient, "max-requests-per-client", 0, "Maximum concurrent requests per client IP (0 = use config file value)") } diff --git a/config/config.go b/config/config.go index b2efc62..62501fe 100644 --- a/config/config.go +++ b/config/config.go @@ -11,6 +11,10 @@ type Config struct { // Server configuration ListenAddress string `yaml:"listen_address" default:":80"` + // Concurrency limits + MaxConcurrentRequests int64 `yaml:"max_concurrent_requests" default:"200"` + MaxRequestsPerClient int64 `yaml:"max_requests_per_client" default:"5"` + // Cache configuration Cache CacheConfig `yaml:"cache"` @@ -65,6 +69,12 @@ func LoadConfig(configPath string) (*Config, error) { if config.ListenAddress == "" { config.ListenAddress = ":80" } + if config.MaxConcurrentRequests == 0 { + config.MaxConcurrentRequests = 50 + } + if config.MaxRequestsPerClient == 0 { + config.MaxRequestsPerClient = 3 + } if config.Cache.Memory.Size == "" { config.Cache.Memory.Size = "0" } @@ -88,16 +98,18 @@ func SaveDefaultConfig(configPath string) error { } defaultConfig := Config{ - ListenAddress: ":80", + ListenAddress: ":80", + MaxConcurrentRequests: 50, // Reduced for home user (less concurrent load) + MaxRequestsPerClient: 3, // Reduced for home user (more conservative per client) Cache: CacheConfig{ Memory: MemoryConfig{ - Size: "1GB", + Size: "1GB", // Recommended for systems that can spare 1GB RAM for caching GCAlgorithm: "lru", }, Disk: DiskConfig{ - Size: "10GB", + Size: "1TB", // Large HDD cache for home user Path: "./disk", - GCAlgorithm: "hybrid", + GCAlgorithm: "lru", // Better for gaming patterns (keeps recently played games) }, }, Upstream: "", diff --git a/go.mod b/go.mod index 50da213..bcfbf30 100644 --- a/go.mod +++ b/go.mod @@ -15,5 +15,6 @@ require ( github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect github.com/spf13/pflag v1.0.5 // indirect + golang.org/x/sync v0.16.0 // indirect golang.org/x/sys v0.12.0 // indirect ) diff --git a/go.sum b/go.sum index 069cc6e..51aad9f 100644 --- a/go.sum +++ b/go.sum @@ -21,6 +21,8 @@ github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= diff --git a/steamcache/steamcache.go b/steamcache/steamcache.go index a673de4..77526d6 100644 --- a/steamcache/steamcache.go +++ b/steamcache/steamcache.go @@ -22,6 +22,7 @@ import ( "time" "github.com/docker/go-units" + "golang.org/x/sync/semaphore" ) // generateURLHash creates a SHA256 hash of the entire URL path for cache key @@ -56,12 +57,17 @@ var hopByHopHeaders = map[string]struct{}{ "Server": {}, } -var ( - // Request coalescing structures - coalescedRequests = make(map[string]*coalescedRequest) - coalescedRequestsMu sync.RWMutex +// Constants for limits +const ( + defaultMaxConcurrentRequests = int64(200) // Max total concurrent requests + defaultMaxRequestsPerClient = int64(5) // Max concurrent requests per IP ) +type clientLimiter struct { + semaphore *semaphore.Weighted + lastSeen time.Time +} + type coalescedRequest struct { responseChan chan *http.Response errorChan chan error @@ -107,25 +113,84 @@ func (cr *coalescedRequest) complete(resp *http.Response, err error) { } // getOrCreateCoalescedRequest gets an existing coalesced request or creates a new one -func getOrCreateCoalescedRequest(cacheKey string) (*coalescedRequest, bool) { - coalescedRequestsMu.Lock() - defer coalescedRequestsMu.Unlock() +func (sc *SteamCache) getOrCreateCoalescedRequest(cacheKey string) (*coalescedRequest, bool) { + sc.coalescedRequestsMu.Lock() + defer sc.coalescedRequestsMu.Unlock() - if cr, exists := coalescedRequests[cacheKey]; exists { + if cr, exists := sc.coalescedRequests[cacheKey]; exists { cr.addWaiter() return cr, false } cr := newCoalescedRequest() - coalescedRequests[cacheKey] = cr + sc.coalescedRequests[cacheKey] = cr return cr, true } // removeCoalescedRequest removes a completed coalesced request -func removeCoalescedRequest(cacheKey string) { - coalescedRequestsMu.Lock() - defer coalescedRequestsMu.Unlock() - delete(coalescedRequests, cacheKey) +func (sc *SteamCache) removeCoalescedRequest(cacheKey string) { + sc.coalescedRequestsMu.Lock() + defer sc.coalescedRequestsMu.Unlock() + delete(sc.coalescedRequests, cacheKey) +} + +// getClientIP extracts the client IP address from the request +func getClientIP(r *http.Request) string { + // Check for forwarded headers first (common in proxy setups) + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + // X-Forwarded-For can contain multiple IPs, take the first one + if idx := strings.Index(xff, ","); idx > 0 { + return strings.TrimSpace(xff[:idx]) + } + return strings.TrimSpace(xff) + } + + if xri := r.Header.Get("X-Real-IP"); xri != "" { + return strings.TrimSpace(xri) + } + + // Fall back to RemoteAddr + if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil { + return host + } + + return r.RemoteAddr +} + +// getOrCreateClientLimiter gets or creates a rate limiter for a client IP +func (sc *SteamCache) getOrCreateClientLimiter(clientIP string) *clientLimiter { + sc.clientRequestsMu.Lock() + defer sc.clientRequestsMu.Unlock() + + limiter, exists := sc.clientRequests[clientIP] + if !exists || time.Since(limiter.lastSeen) > 5*time.Minute { + // Create new limiter or refresh existing one + limiter = &clientLimiter{ + semaphore: semaphore.NewWeighted(sc.maxRequestsPerClient), + lastSeen: time.Now(), + } + sc.clientRequests[clientIP] = limiter + } else { + limiter.lastSeen = time.Now() + } + + return limiter +} + +// cleanupOldClientLimiters removes old client limiters to prevent memory leaks +func (sc *SteamCache) cleanupOldClientLimiters() { + for { + time.Sleep(10 * time.Minute) // Clean up every 10 minutes + + sc.clientRequestsMu.Lock() + now := time.Now() + for ip, limiter := range sc.clientRequests { + if now.Sub(limiter.lastSeen) > 30*time.Minute { + delete(sc.clientRequests, ip) + } + } + sc.clientRequestsMu.Unlock() + } } type SteamCache struct { @@ -144,9 +209,22 @@ type SteamCache struct { client *http.Client cancel context.CancelFunc wg sync.WaitGroup + + // Request coalescing structures + coalescedRequests map[string]*coalescedRequest + coalescedRequestsMu sync.RWMutex + + // Concurrency control + maxConcurrentRequests int64 + requestSemaphore *semaphore.Weighted + + // Per-client rate limiting + clientRequests map[string]*clientLimiter + clientRequestsMu sync.RWMutex + maxRequestsPerClient int64 } -func New(address string, memorySize string, diskSize string, diskPath, upstream, memoryGC, diskGC string) *SteamCache { +func New(address string, memorySize string, diskSize string, diskPath, upstream, memoryGC, diskGC string, maxConcurrentRequests int64, maxRequestsPerClient int64) *SteamCache { memorysize, err := units.FromHumanSize(memorySize) if err != nil { panic(err) @@ -238,6 +316,13 @@ func New(address string, memorySize string, diskSize string, diskPath, upstream, ReadHeaderTimeout: 10 * time.Second, // New, for header attacks MaxHeaderBytes: 1 << 20, // 1MB, optional }, + + // Initialize concurrency control fields + coalescedRequests: make(map[string]*coalescedRequest), + maxConcurrentRequests: maxConcurrentRequests, + requestSemaphore: semaphore.NewWeighted(maxConcurrentRequests), + clientRequests: make(map[string]*clientLimiter), + maxRequestsPerClient: maxRequestsPerClient, } // Log GC algorithm configuration @@ -272,6 +357,13 @@ func (sc *SteamCache) Run() { ctx, cancel := context.WithCancel(context.Background()) sc.cancel = cancel + // Start cleanup goroutine for old client limiters + sc.wg.Add(1) + go func() { + defer sc.wg.Done() + sc.cleanupOldClientLimiters() + }() + sc.wg.Add(1) go func() { defer sc.wg.Done() @@ -295,18 +387,49 @@ func (sc *SteamCache) Shutdown() { } func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Apply global concurrency limit first + if err := sc.requestSemaphore.Acquire(context.Background(), 1); err != nil { + logger.Logger.Warn().Str("client_ip", getClientIP(r)).Msg("Server at capacity, rejecting request") + http.Error(w, "Server busy, please try again later", http.StatusServiceUnavailable) + return + } + defer sc.requestSemaphore.Release(1) + + // Apply per-client rate limiting + clientIP := getClientIP(r) + clientLimiter := sc.getOrCreateClientLimiter(clientIP) + + if err := clientLimiter.semaphore.Acquire(context.Background(), 1); err != nil { + logger.Logger.Warn(). + Str("client_ip", clientIP). + Int("max_per_client", int(sc.maxRequestsPerClient)). + Msg("Client exceeded concurrent request limit") + http.Error(w, "Too many concurrent requests from this client", http.StatusTooManyRequests) + return + } + defer clientLimiter.semaphore.Release(1) + if r.Method != http.MethodGet { - logger.Logger.Warn().Str("method", r.Method).Msg("Only GET method is supported") + logger.Logger.Warn(). + Str("method", r.Method). + Str("client_ip", clientIP). + Msg("Only GET method is supported") http.Error(w, "Only GET method is supported", http.StatusMethodNotAllowed) return } if r.URL.Path == "/" { + logger.Logger.Debug(). + Str("client_ip", clientIP). + Msg("Health check request") w.WriteHeader(http.StatusOK) // this is used by steamcache2's upstream verification at startup return } if r.URL.String() == "/lancache-heartbeat" { + logger.Logger.Debug(). + Str("client_ip", clientIP). + Msg("LanCache heartbeat request") w.Header().Add("X-LanCache-Processed-By", "SteamCache2") w.WriteHeader(http.StatusNoContent) w.Write(nil) @@ -358,19 +481,21 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { logger.Logger.Info(). Str("key", cacheKey). Str("host", r.Host). + Str("client_ip", clientIP). Str("status", "HIT"). Dur("duration", time.Since(tstart)). - Msg("request") + Msg("cache request") return } } // Check for coalesced request (another client already downloading this) - coalescedReq, isNew := getOrCreateCoalescedRequest(cacheKey) + coalescedReq, isNew := sc.getOrCreateCoalescedRequest(cacheKey) if !isNew { // Wait for the existing download to complete logger.Logger.Debug(). Str("key", cacheKey). + Str("client_ip", clientIP). Int("waiting_clients", coalescedReq.waitingCount). Msg("Joining coalesced request") @@ -402,21 +527,26 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { logger.Logger.Info(). Str("key", cacheKey). Str("host", r.Host). + Str("client_ip", clientIP). Str("status", "HIT-COALESCED"). Dur("duration", time.Since(tstart)). - Msg("request") + Msg("cache request") return case err := <-coalescedReq.errorChan: - logger.Logger.Error().Err(err).Str("key", cacheKey).Msg("Coalesced request failed") + logger.Logger.Error(). + Err(err). + Str("key", cacheKey). + Str("client_ip", clientIP). + Msg("Coalesced request failed") http.Error(w, "Upstream request failed", http.StatusInternalServerError) return } } // Remove coalesced request when done - defer removeCoalescedRequest(cacheKey) + defer sc.removeCoalescedRequest(cacheKey) var req *http.Request if sc.upstream != "" { // if an upstream server is configured, proxy the request to the upstream server @@ -590,25 +720,35 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { logger.Logger.Info(). Str("key", cacheKey). Str("host", r.Host). + Str("client_ip", clientIP). Str("status", "MISS"). Dur("duration", time.Since(tstart)). - Msg("request") + Msg("cache request") return } if r.URL.Path == "/favicon.ico" { + logger.Logger.Debug(). + Str("client_ip", clientIP). + Msg("Favicon request") w.WriteHeader(http.StatusNoContent) return } if r.URL.Path == "/robots.txt" { + logger.Logger.Debug(). + Str("client_ip", clientIP). + Msg("Robots.txt request") w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) w.Write([]byte("User-agent: *\nDisallow: /\n")) return } - logger.Logger.Warn().Str("url", r.URL.String()).Msg("Not found") + logger.Logger.Warn(). + Str("url", r.URL.String()). + Str("client_ip", clientIP). + Msg("Request not found") http.Error(w, "Not found", http.StatusNotFound) } diff --git a/steamcache/steamcache_test.go b/steamcache/steamcache_test.go index dc5e017..aab0074 100644 --- a/steamcache/steamcache_test.go +++ b/steamcache/steamcache_test.go @@ -14,7 +14,7 @@ func TestCaching(t *testing.T) { os.WriteFile(filepath.Join(td, "key2"), []byte("value2"), 0644) - sc := New("localhost:8080", "1G", "1G", td, "", "lru", "lru") + sc := New("localhost:8080", "1G", "1G", td, "", "lru", "lru", 200, 5) w, err := sc.vfs.Create("key", 5) if err != nil { @@ -85,7 +85,7 @@ func TestCaching(t *testing.T) { } func TestCacheMissAndHit(t *testing.T) { - sc := New("localhost:8080", "0", "1G", t.TempDir(), "", "lru", "lru") + sc := New("localhost:8080", "0", "1G", t.TempDir(), "", "lru", "lru", 200, 5) key := "testkey" value := []byte("testvalue") @@ -166,7 +166,7 @@ func TestURLHashing(t *testing.T) { // Removed hash calculation tests since we switched to lightweight validation func TestSteamKeySharding(t *testing.T) { - sc := New("localhost:8080", "0", "1G", t.TempDir(), "", "lru", "lru") + sc := New("localhost:8080", "0", "1G", t.TempDir(), "", "lru", "lru", 200, 5) // Test with a Steam-style key that should trigger sharding steamKey := "steam/0016cfc5019b8baa6026aa1cce93e685d6e06c6e"