Refactor caching functions and simplify response serialization
All checks were successful
Release Tag / release (push) Successful in 27s

- Updated the `downloadThroughCache` function to remove the upstream URL parameter, streamlining the caching process.
- Modified the `serializeRawResponse` function to eliminate unnecessary parameters, enhancing clarity and usability.
- Adjusted integration tests to align with the new function signatures, ensuring consistent testing of caching behavior.
This commit is contained in:
2025-09-21 22:55:49 -05:00
parent 45ae234694
commit 46495dc3aa
2 changed files with 5 additions and 306 deletions

View File

@@ -52,7 +52,7 @@ func testSteamURL(t *testing.T, urlPath string) {
directResp, directBody := downloadDirectly(t, steamURL)
// Test download through SteamCache
cacheResp, cacheBody := downloadThroughCache(t, sc, steamURL, urlPath)
cacheResp, cacheBody := downloadThroughCache(t, sc, urlPath)
// Compare responses
compareResponses(t, directResp, directBody, cacheResp, cacheBody, urlPath)
@@ -83,7 +83,7 @@ func downloadDirectly(t *testing.T, url string) (*http.Response, []byte) {
return resp, body
}
func downloadThroughCache(t *testing.T, sc *SteamCache, upstreamURL, urlPath string) (*http.Response, []byte) {
func downloadThroughCache(t *testing.T, sc *SteamCache, urlPath string) (*http.Response, []byte) {
// Create a test server for SteamCache
cacheServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// For real Steam URLs, we need to set the upstream to the Steam hostname
@@ -235,7 +235,7 @@ func TestCacheFileFormat(t *testing.T) {
rawResponse := sc.reconstructRawResponse(resp, bodyData)
// Serialize to cache format
cacheData, err := serializeRawResponse("/test/format", rawResponse, contentHash, "sha256")
cacheData, err := serializeRawResponse(rawResponse)
if err != nil {
t.Fatalf("Failed to serialize cache file: %v", err)
}

View File

@@ -4,7 +4,6 @@ package steamcache
import (
"bytes"
"context"
"crypto/sha1"
"crypto/sha256"
"encoding/hex"
"fmt"
@@ -140,7 +139,7 @@ type CacheFileFormat struct {
// serializeRawResponse serializes a raw HTTP response into our text-based cache format
// upstreamHash and upstreamAlgo are used for verification during download but not stored
func serializeRawResponse(url string, rawResponse []byte, upstreamHash string, upstreamAlgo string) ([]byte, error) {
func serializeRawResponse(rawResponse []byte) ([]byte, error) {
// Extract body from raw response for hash calculation
bodyStart := bytes.Index(rawResponse, []byte("\r\n\r\n"))
if bodyStart == -1 {
@@ -514,178 +513,6 @@ func calculateSHA256(data []byte) string {
return hex.EncodeToString(hasher.Sum(nil))
}
// calculateSHA1 calculates SHA1 hash of the given data (for legacy verification only)
func calculateSHA1(data []byte) string {
hasher := sha1.New()
hasher.Write(data)
return hex.EncodeToString(hasher.Sum(nil))
}
// extractHashFromURL extracts hash from URL path (Steam chunk URLs contain SHA1 hashes)
func extractHashFromURL(urlPath string) (hash string, algorithm string) {
// Steam chunk URLs: /depot/123/chunk/SHA1_HASH
// Steam manifest URLs: /depot/123/manifest/.../SHA1_HASH
// Steam patch URLs: /depot/123/patch/.../SHA1_HASH
// Look for chunk URLs with SHA1 hash
if strings.Contains(urlPath, "/chunk/") {
parts := strings.Split(urlPath, "/chunk/")
if len(parts) == 2 {
hashPart := parts[1]
// Remove any query parameters
if questionMark := strings.Index(hashPart, "?"); questionMark != -1 {
hashPart = hashPart[:questionMark]
}
// Check if it's a valid SHA1 hash (40 hex chars)
if len(hashPart) == 40 && isHexString(hashPart) {
logger.Logger.Debug().
Str("url_path", urlPath).
Str("sha1_hash", hashPart).
Msg("Extracted SHA1 hash from Steam chunk URL")
return hashPart, "sha1"
}
}
}
// Look for manifest URLs with SHA1 hash at the end
if strings.Contains(urlPath, "/manifest/") {
parts := strings.Split(urlPath, "/")
if len(parts) > 0 {
lastPart := parts[len(parts)-1]
// Remove any query parameters
if questionMark := strings.Index(lastPart, "?"); questionMark != -1 {
lastPart = lastPart[:questionMark]
}
// Check if it's a valid SHA1 hash (40 hex chars)
if len(lastPart) == 40 && isHexString(lastPart) {
logger.Logger.Debug().
Str("url_path", urlPath).
Str("sha1_hash", lastPart).
Msg("Extracted SHA1 hash from Steam manifest URL")
return lastPart, "sha1"
}
}
}
// Look for patch URLs with SHA1 hash at the end
if strings.Contains(urlPath, "/patch/") {
parts := strings.Split(urlPath, "/")
if len(parts) > 0 {
lastPart := parts[len(parts)-1]
// Remove any query parameters
if questionMark := strings.Index(lastPart, "?"); questionMark != -1 {
lastPart = lastPart[:questionMark]
}
// Check if it's a valid SHA1 hash (40 hex chars)
if len(lastPart) == 40 && isHexString(lastPart) {
logger.Logger.Debug().
Str("url_path", urlPath).
Str("sha1_hash", lastPart).
Msg("Extracted SHA1 hash from Steam patch URL")
return lastPart, "sha1"
}
}
}
return "", ""
}
// extractUpstreamHash extracts hash from upstream server headers and URL path, prioritizing by security
// Returns the hash value and the algorithm used (sha256, sha1, or empty if none found)
func extractUpstreamHash(headers http.Header, urlPath string) (hash string, algorithm string) {
// Priority order: SHA256 (most secure) -> SHA1 (legacy) -> none
// 1. Try SHA256 headers first (highest priority)
sha256Headers := []string{
"X-SHA256", // Custom header
"Content-SHA256", // Content hash
"X-Content-SHA256", // Service specific
"Digest", // RFC 3230 digest header
}
for _, headerName := range sha256Headers {
if value := headers.Get(headerName); value != "" {
// Remove quotes if present (ETag often has quotes)
value = strings.Trim(value, `"`)
// Check for SHA256 prefix in Digest header
if strings.HasPrefix(value, "sha256=") {
hash := strings.TrimPrefix(value, "sha256=")
if len(hash) == 64 && isHexString(hash) {
logger.Logger.Debug().
Str("header_name", headerName).
Str("sha256_hash", hash).
Msg("Extracted SHA256 hash from upstream header")
return hash, "sha256"
}
}
// Direct SHA256 hash (64 chars)
if len(value) == 64 && isHexString(value) {
logger.Logger.Debug().
Str("header_name", headerName).
Str("sha256_hash", value).
Msg("Extracted SHA256 hash from upstream header")
return value, "sha256"
}
}
}
// 2. Fallback to SHA1 headers (legacy support)
sha1Headers := []string{
"X-SHA1", // Legacy custom header
"Content-SHA1", // Legacy content hash
"X-Content-SHA1", // Legacy Steam specific
"X-Content-Sha", // Legacy Steam specific (lowercase variant)
"ETag", // May contain SHA1
}
for _, headerName := range sha1Headers {
if value := headers.Get(headerName); value != "" {
// Remove quotes if present (ETag often has quotes)
value = strings.Trim(value, `"`)
// Check for SHA1 prefix in Digest header
if strings.HasPrefix(value, "sha1=") {
hash := strings.TrimPrefix(value, "sha1=")
if len(hash) == 40 && isHexString(hash) {
logger.Logger.Debug().
Str("header_name", headerName).
Str("sha1_hash", hash).
Msg("Extracted SHA1 hash from upstream header (legacy)")
return hash, "sha1"
}
}
// Direct SHA1 hash (40 chars)
if len(value) == 40 && isHexString(value) {
logger.Logger.Debug().
Str("header_name", headerName).
Str("sha1_hash", value).
Msg("Extracted SHA1 hash from upstream header (legacy)")
return value, "sha1"
}
}
}
// 3. Fallback to URL path extraction (Steam chunk URLs)
urlHash, urlAlgo := extractHashFromURL(urlPath)
if urlHash != "" {
return urlHash, urlAlgo
}
logger.Logger.Debug().Msg("No upstream hash found in headers or URL")
return "", ""
}
// isHexString checks if a string contains only hexadecimal characters
func isHexString(s string) bool {
for _, r := range s {
if !((r >= '0' && r <= '9') || (r >= 'a' && r <= 'f') || (r >= 'A' && r <= 'F')) {
return false
}
}
return true
}
// verifyCompleteFile verifies that we received the complete file by checking Content-Length
// Returns true if the file is complete, false if it's incomplete (allowing retry)
func (sc *SteamCache) verifyCompleteFile(bodyData []byte, resp *http.Response, urlPath string, cacheKey string) bool {
@@ -1547,7 +1374,7 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
// Serialize the raw response using our new cache format
cacheData, err := serializeRawResponse(urlPath, rawResponse, "", "")
cacheData, err := serializeRawResponse(rawResponse)
if err != nil {
logger.Logger.Warn().
Str("key", cacheKey).
@@ -1711,131 +1538,3 @@ func (sc *SteamCache) recordCacheMiss(key string, size int64) {
sc.cacheWarmer.RequestWarming(key, 3, "cache_miss", size, "cache_miss_analyzer")
}
}
// GetAdaptiveStats returns adaptive caching statistics
func (sc *SteamCache) GetAdaptiveStats() map[string]interface{} {
stats := make(map[string]interface{})
// Get current strategy
currentStrategy := sc.adaptiveManager.GetCurrentStrategy()
stats["current_strategy"] = currentStrategy
stats["adaptation_count"] = sc.adaptiveManager.GetAdaptationCount()
// Get dominant pattern (using public method)
// Note: In a real implementation, we'd need a public method to get the dominant pattern
stats["dominant_pattern"] = "unknown" // Placeholder for now
return stats
}
// GetPredictiveStats returns predictive caching statistics
func (sc *SteamCache) GetPredictiveStats() map[string]interface{} {
stats := make(map[string]interface{})
predictiveStats := sc.predictiveManager.GetStats()
stats["prefetch_hits"] = predictiveStats.PrefetchHits
stats["prefetch_misses"] = predictiveStats.PrefetchMisses
stats["prefetch_requests"] = predictiveStats.PrefetchRequests
stats["cache_warm_hits"] = predictiveStats.CacheWarmHits
stats["cache_warm_misses"] = predictiveStats.CacheWarmMisses
return stats
}
// GetWarmingStats returns cache warming statistics
func (sc *SteamCache) GetWarmingStats() map[string]interface{} {
stats := make(map[string]interface{})
warmingStats := sc.cacheWarmer.GetStats()
stats["warm_requests"] = warmingStats.WarmRequests
stats["warm_successes"] = warmingStats.WarmSuccesses
stats["warm_failures"] = warmingStats.WarmFailures
stats["warm_bytes"] = warmingStats.WarmBytes
stats["warm_duration"] = warmingStats.WarmDuration
stats["active_warmers"] = warmingStats.ActiveWarmers
stats["warming_enabled"] = sc.cacheWarmer.IsWarmingEnabled()
return stats
}
// SetWarmingEnabled enables or disables cache warming
func (sc *SteamCache) SetWarmingEnabled(enabled bool) {
sc.cacheWarmer.SetWarmingEnabled(enabled)
}
// WarmPopularContent manually triggers warming of popular content
func (sc *SteamCache) WarmPopularContent(keys []string) {
sc.cacheWarmer.WarmPopularContent(keys, 2)
}
// WarmPredictedContent manually triggers warming of predicted content
func (sc *SteamCache) WarmPredictedContent(keys []string) {
sc.cacheWarmer.WarmPredictedContent(keys, 3)
}
// SetAdaptiveEnabled enables or disables adaptive features
func (sc *SteamCache) SetAdaptiveEnabled(enabled bool) {
sc.adaptiveEnabled = enabled
if !enabled {
// Stop adaptive components when disabled
sc.adaptiveManager.Stop()
sc.predictiveManager.Stop()
sc.cacheWarmer.Stop()
}
}
// IsAdaptiveEnabled returns whether adaptive features are enabled
func (sc *SteamCache) IsAdaptiveEnabled() bool {
return sc.adaptiveEnabled
}
// GetMemoryStats returns memory monitoring statistics
func (sc *SteamCache) GetMemoryStats() map[string]interface{} {
if sc.memoryMonitor == nil {
return map[string]interface{}{"error": "memory monitoring not enabled"}
}
stats := sc.memoryMonitor.GetMemoryStats()
if sc.dynamicCacheMgr != nil {
dynamicStats := sc.dynamicCacheMgr.GetStats()
for k, v := range dynamicStats {
stats["dynamic_"+k] = v
}
}
return stats
}
// GetDynamicCacheStats returns dynamic cache management statistics
func (sc *SteamCache) GetDynamicCacheStats() map[string]interface{} {
if sc.dynamicCacheMgr == nil {
return map[string]interface{}{"error": "dynamic cache management not enabled"}
}
return sc.dynamicCacheMgr.GetStats()
}
// SetMemoryTarget sets the target memory usage for dynamic cache sizing
func (sc *SteamCache) SetMemoryTarget(targetBytes uint64) {
if sc.memoryMonitor != nil {
sc.memoryMonitor.SetTargetMemoryUsage(targetBytes)
}
}
// ForceCacheAdjustment forces an immediate cache size adjustment
func (sc *SteamCache) ForceCacheAdjustment() {
if sc.dynamicCacheMgr != nil {
// This would trigger an immediate adjustment
// Implementation depends on the specific needs
}
}
// GetMemoryFragmentationStats returns memory fragmentation statistics
func (sc *SteamCache) GetMemoryFragmentationStats() map[string]interface{} {
if sc.memory == nil {
return map[string]interface{}{"error": "memory cache not enabled"}
}
return sc.memory.GetFragmentationStats()
}