diff --git a/cmd/root.go b/cmd/root.go index f091ceb..294f87b 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -66,7 +66,7 @@ var rootCmd = &cobra.Command{ ) logger.Logger.Info(). - Msg("starting SteamCache2 on port 80") + Msg("SteamCache2 listening on port 80") sc.Run() diff --git a/steamcache/steamcache.go b/steamcache/steamcache.go index fca1ef8..b501fc6 100644 --- a/steamcache/steamcache.go +++ b/steamcache/steamcache.go @@ -179,131 +179,151 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - tstart := time.Now() + if strings.HasPrefix(r.URL.String(), "/depot/") { - cacheKey := strings.ReplaceAll(r.URL.String()[1:], "\\", "/") // replace all backslashes with forward slashes shouldn't be necessary but just in case - if cacheKey == "" { - requestsTotal.WithLabelValues(r.Method, "400").Inc() - logger.Logger.Warn().Str("url", r.URL.String()).Msg("Invalid URL") - http.Error(w, "Invalid URL", http.StatusBadRequest) - return - } + tstart := time.Now() - w.Header().Add("X-LanCache-Processed-By", "SteamCache2") // SteamPrefill uses this header to determine if the request was processed by the cache maybe steam uses it too + cacheKey := strings.ReplaceAll(r.URL.String()[1:], "\\", "/") // replace all backslashes with forward slashes shouldn't be necessary but just in case + if cacheKey == "" { + requestsTotal.WithLabelValues(r.Method, "400").Inc() + logger.Logger.Warn().Str("url", r.URL.String()).Msg("Invalid URL") + http.Error(w, "Invalid URL", http.StatusBadRequest) + return + } - data, err := sc.vfs.Get(cacheKey) - if err == nil { - sc.hits.Add(cachestate.CacheStateHit) - w.Header().Add("X-LanCache-Status", "HIT") + w.Header().Add("X-LanCache-Processed-By", "SteamCache2") // SteamPrefill uses this header to determine if the request was processed by the cache maybe steam uses it too + + data, err := sc.vfs.Get(cacheKey) + if err == nil { + sc.hits.Add(cachestate.CacheStateHit) + w.Header().Add("X-LanCache-Status", "HIT") + requestsTotal.WithLabelValues(r.Method, "200").Inc() + cacheHitRate.Set(sc.hits.Avg()) + + w.Write(data) + + logger.Logger.Info(). + Str("key", cacheKey). + Str("host", r.Host). + Str("status", "HIT"). + Int64("size", int64(len(data))). + Dur("duration", time.Since(tstart)). + Msg("request") + return + } + + var req *http.Request + if sc.upstream != "" { // if an upstream server is configured, proxy the request to the upstream server + ur, err := url.JoinPath(sc.upstream, r.URL.String()) + if err != nil { + requestsTotal.WithLabelValues(r.Method, "500").Inc() + logger.Logger.Error().Err(err).Str("upstream", sc.upstream).Msg("Failed to join URL path") + http.Error(w, "Failed to join URL path", http.StatusInternalServerError) + return + } + + req, err = http.NewRequest(http.MethodGet, ur, nil) + if err != nil { + requestsTotal.WithLabelValues(r.Method, "500").Inc() + logger.Logger.Error().Err(err).Str("upstream", sc.upstream).Msg("Failed to create request") + http.Error(w, "Failed to create request", http.StatusInternalServerError) + return + } + req.Host = r.Host + } else { // if no upstream server is configured, proxy the request to the host specified in the request + host := r.Host + if r.Header.Get("X-Sls-Https") == "enable" { + host = "https://" + host + } else { + host = "http://" + host + } + + ur, err := url.JoinPath(host, r.URL.String()) + if err != nil { + requestsTotal.WithLabelValues(r.Method, "500").Inc() + logger.Logger.Error().Err(err).Str("host", host).Msg("Failed to join URL path") + http.Error(w, "Failed to join URL path", http.StatusInternalServerError) + return + } + + req, err = http.NewRequest(http.MethodGet, ur, nil) + if err != nil { + requestsTotal.WithLabelValues(r.Method, "500").Inc() + logger.Logger.Error().Err(err).Str("host", host).Msg("Failed to create request") + http.Error(w, "Failed to create request", http.StatusInternalServerError) + return + } + } + + // Copy headers from the original request to the new request + for key, values := range r.Header { + for _, value := range values { + req.Header.Add(key, value) + } + } + + // req.Header.Add("X-Sls-Https", r.Header.Get("X-Sls-Https")) + // req.Header.Add("User-Agent", r.Header.Get("User-Agent")) + + // Retry logic + backoffSchedule := []time.Duration{1 * time.Second, 3 * time.Second, 10 * time.Second} + var resp *http.Response + for i, backoff := range backoffSchedule { + resp, err = http.DefaultClient.Do(req) + if err == nil && resp.StatusCode == http.StatusOK { + break + } + if i < len(backoffSchedule)-1 { + time.Sleep(backoff) + } + } + if err != nil || resp.StatusCode != http.StatusOK { + requestsTotal.WithLabelValues(r.Method, "500").Inc() + logger.Logger.Error().Err(err).Str("url", req.URL.String()).Msg("Failed to fetch the requested URL") + http.Error(w, "Failed to fetch the requested URL", http.StatusInternalServerError) + return + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + requestsTotal.WithLabelValues(r.Method, "500").Inc() + logger.Logger.Error().Err(err).Str("url", req.URL.String()).Msg("Failed to read response body") + http.Error(w, "Failed to read response body", http.StatusInternalServerError) + return + } + + sc.vfs.Set(cacheKey, body) + sc.hits.Add(cachestate.CacheStateMiss) + w.Header().Add("X-LanCache-Status", "MISS") requestsTotal.WithLabelValues(r.Method, "200").Inc() cacheHitRate.Set(sc.hits.Avg()) - w.Write(data) + w.Write(body) logger.Logger.Info(). Str("key", cacheKey). Str("host", r.Host). - Str("status", "HIT"). - Int64("size", int64(len(data))). + Str("status", "MISS"). + Int64("size", int64(len(body))). Dur("duration", time.Since(tstart)). Msg("request") return } - var req *http.Request - if sc.upstream != "" { // if an upstream server is configured, proxy the request to the upstream server - ur, err := url.JoinPath(sc.upstream, r.URL.String()) - if err != nil { - requestsTotal.WithLabelValues(r.Method, "500").Inc() - logger.Logger.Error().Err(err).Str("upstream", sc.upstream).Msg("Failed to join URL path") - http.Error(w, "Failed to join URL path", http.StatusInternalServerError) - return - } - - req, err = http.NewRequest(http.MethodGet, ur, nil) - if err != nil { - requestsTotal.WithLabelValues(r.Method, "500").Inc() - logger.Logger.Error().Err(err).Str("upstream", sc.upstream).Msg("Failed to create request") - http.Error(w, "Failed to create request", http.StatusInternalServerError) - return - } - req.Host = r.Host - } else { // if no upstream server is configured, proxy the request to the host specified in the request - host := r.Host - if r.Header.Get("X-Sls-Https") == "enable" { - host = "https://" + host - } else { - host = "http://" + host - } - - ur, err := url.JoinPath(host, r.URL.String()) - if err != nil { - requestsTotal.WithLabelValues(r.Method, "500").Inc() - logger.Logger.Error().Err(err).Str("host", host).Msg("Failed to join URL path") - http.Error(w, "Failed to join URL path", http.StatusInternalServerError) - return - } - - req, err = http.NewRequest(http.MethodGet, ur, nil) - if err != nil { - requestsTotal.WithLabelValues(r.Method, "500").Inc() - logger.Logger.Error().Err(err).Str("host", host).Msg("Failed to create request") - http.Error(w, "Failed to create request", http.StatusInternalServerError) - return - } - } - - // Copy headers from the original request to the new request - for key, values := range r.Header { - for _, value := range values { - req.Header.Add(key, value) - } - } - - // req.Header.Add("X-Sls-Https", r.Header.Get("X-Sls-Https")) - // req.Header.Add("User-Agent", r.Header.Get("User-Agent")) - - // Retry logic - backoffSchedule := []time.Duration{1 * time.Second, 3 * time.Second, 10 * time.Second} - var resp *http.Response - for i, backoff := range backoffSchedule { - resp, err = http.DefaultClient.Do(req) - if err == nil && resp.StatusCode == http.StatusOK { - break - } - if i < len(backoffSchedule)-1 { - time.Sleep(backoff) - } - } - if err != nil || resp.StatusCode != http.StatusOK { - requestsTotal.WithLabelValues(r.Method, "500").Inc() - logger.Logger.Error().Err(err).Str("url", req.URL.String()).Msg("Failed to fetch the requested URL") - http.Error(w, "Failed to fetch the requested URL", http.StatusInternalServerError) - return - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - requestsTotal.WithLabelValues(r.Method, "500").Inc() - logger.Logger.Error().Err(err).Str("url", req.URL.String()).Msg("Failed to read response body") - http.Error(w, "Failed to read response body", http.StatusInternalServerError) + if r.URL.Path == "/favicon.ico" { + w.WriteHeader(http.StatusNoContent) return } - sc.vfs.Set(cacheKey, body) - sc.hits.Add(cachestate.CacheStateMiss) - w.Header().Add("X-LanCache-Status", "MISS") - requestsTotal.WithLabelValues(r.Method, "200").Inc() - cacheHitRate.Set(sc.hits.Avg()) + if r.URL.Path == "/robots.txt" { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write([]byte("User-agent: *\nDisallow: /\n")) + return + } - w.Write(body) - - logger.Logger.Info(). - Str("key", cacheKey). - Str("host", r.Host). - Str("status", "MISS"). - Int64("size", int64(len(body))). - Dur("duration", time.Since(tstart)). - Msg("request") + requestsTotal.WithLabelValues(r.Method, "404").Inc() + logger.Logger.Warn().Str("url", r.URL.String()).Msg("Not found") + http.Error(w, "Not found", http.StatusNotFound) }