From b4d2b1305eb245a2f48c31849b3506f26ce03a27 Mon Sep 17 00:00:00 2001 From: Justin Harms Date: Sat, 12 Jul 2025 08:50:34 -0500 Subject: [PATCH 1/5] fix: add logging for unsupported methods and error handling in ServeHTTP --- steamcache/steamcache.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/steamcache/steamcache.go b/steamcache/steamcache.go index 64e2a1a..fca1ef8 100644 --- a/steamcache/steamcache.go +++ b/steamcache/steamcache.go @@ -167,6 +167,7 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { requestsTotal.WithLabelValues(r.Method, "405").Inc() + logger.Logger.Warn().Str("method", r.Method).Msg("Only GET method is supported") http.Error(w, "Only GET method is supported", http.StatusMethodNotAllowed) return } @@ -183,6 +184,7 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { cacheKey := strings.ReplaceAll(r.URL.String()[1:], "\\", "/") // replace all backslashes with forward slashes shouldn't be necessary but just in case if cacheKey == "" { requestsTotal.WithLabelValues(r.Method, "400").Inc() + logger.Logger.Warn().Str("url", r.URL.String()).Msg("Invalid URL") http.Error(w, "Invalid URL", http.StatusBadRequest) return } @@ -213,6 +215,7 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { ur, err := url.JoinPath(sc.upstream, r.URL.String()) if err != nil { requestsTotal.WithLabelValues(r.Method, "500").Inc() + logger.Logger.Error().Err(err).Str("upstream", sc.upstream).Msg("Failed to join URL path") http.Error(w, "Failed to join URL path", http.StatusInternalServerError) return } @@ -220,6 +223,7 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { req, err = http.NewRequest(http.MethodGet, ur, nil) if err != nil { requestsTotal.WithLabelValues(r.Method, "500").Inc() + logger.Logger.Error().Err(err).Str("upstream", sc.upstream).Msg("Failed to create request") http.Error(w, "Failed to create request", http.StatusInternalServerError) return } @@ -235,6 +239,7 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { ur, err := url.JoinPath(host, r.URL.String()) if err != nil { requestsTotal.WithLabelValues(r.Method, "500").Inc() + logger.Logger.Error().Err(err).Str("host", host).Msg("Failed to join URL path") http.Error(w, "Failed to join URL path", http.StatusInternalServerError) return } @@ -242,6 +247,7 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { req, err = http.NewRequest(http.MethodGet, ur, nil) if err != nil { requestsTotal.WithLabelValues(r.Method, "500").Inc() + logger.Logger.Error().Err(err).Str("host", host).Msg("Failed to create request") http.Error(w, "Failed to create request", http.StatusInternalServerError) return } @@ -271,6 +277,7 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if err != nil || resp.StatusCode != http.StatusOK { requestsTotal.WithLabelValues(r.Method, "500").Inc() + logger.Logger.Error().Err(err).Str("url", req.URL.String()).Msg("Failed to fetch the requested URL") http.Error(w, "Failed to fetch the requested URL", http.StatusInternalServerError) return } @@ -279,6 +286,7 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(resp.Body) if err != nil { requestsTotal.WithLabelValues(r.Method, "500").Inc() + logger.Logger.Error().Err(err).Str("url", req.URL.String()).Msg("Failed to read response body") http.Error(w, "Failed to read response body", http.StatusInternalServerError) return } From 745856f0f48797902d0e2183bbb137f658d59c22 Mon Sep 17 00:00:00 2001 From: Justin Harms Date: Sat, 12 Jul 2025 09:21:56 -0500 Subject: [PATCH 2/5] fix: correct format key to formats in .goreleaser.yaml --- .goreleaser.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 86976a9..117b6eb 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -16,7 +16,7 @@ builds: - amd64 archives: - - format: tar.gz + - formats: tar.gz name_template: >- {{ .ProjectName }}_ {{- title .Os }}_ @@ -26,7 +26,7 @@ archives: {{- if .Arm }}v{{ .Arm }}{{ end }} format_overrides: - goos: windows - format: zip + formats: zip changelog: sort: asc From b83836f9146b0a3950fc2ff299c24aa925fa427f Mon Sep 17 00:00:00 2001 From: Justin Harms Date: Sat, 12 Jul 2025 09:48:06 -0500 Subject: [PATCH 3/5] fix: update log message for server startup and improve request handling in ServeHTTP --- cmd/root.go | 2 +- steamcache/steamcache.go | 236 +++++++++++++++++++++------------------ 2 files changed, 129 insertions(+), 109 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index f091ceb..294f87b 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -66,7 +66,7 @@ var rootCmd = &cobra.Command{ ) logger.Logger.Info(). - Msg("starting SteamCache2 on port 80") + Msg("SteamCache2 listening on port 80") sc.Run() diff --git a/steamcache/steamcache.go b/steamcache/steamcache.go index fca1ef8..b501fc6 100644 --- a/steamcache/steamcache.go +++ b/steamcache/steamcache.go @@ -179,131 +179,151 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - tstart := time.Now() + if strings.HasPrefix(r.URL.String(), "/depot/") { - cacheKey := strings.ReplaceAll(r.URL.String()[1:], "\\", "/") // replace all backslashes with forward slashes shouldn't be necessary but just in case - if cacheKey == "" { - requestsTotal.WithLabelValues(r.Method, "400").Inc() - logger.Logger.Warn().Str("url", r.URL.String()).Msg("Invalid URL") - http.Error(w, "Invalid URL", http.StatusBadRequest) - return - } + tstart := time.Now() - w.Header().Add("X-LanCache-Processed-By", "SteamCache2") // SteamPrefill uses this header to determine if the request was processed by the cache maybe steam uses it too + cacheKey := strings.ReplaceAll(r.URL.String()[1:], "\\", "/") // replace all backslashes with forward slashes shouldn't be necessary but just in case + if cacheKey == "" { + requestsTotal.WithLabelValues(r.Method, "400").Inc() + logger.Logger.Warn().Str("url", r.URL.String()).Msg("Invalid URL") + http.Error(w, "Invalid URL", http.StatusBadRequest) + return + } - data, err := sc.vfs.Get(cacheKey) - if err == nil { - sc.hits.Add(cachestate.CacheStateHit) - w.Header().Add("X-LanCache-Status", "HIT") + w.Header().Add("X-LanCache-Processed-By", "SteamCache2") // SteamPrefill uses this header to determine if the request was processed by the cache maybe steam uses it too + + data, err := sc.vfs.Get(cacheKey) + if err == nil { + sc.hits.Add(cachestate.CacheStateHit) + w.Header().Add("X-LanCache-Status", "HIT") + requestsTotal.WithLabelValues(r.Method, "200").Inc() + cacheHitRate.Set(sc.hits.Avg()) + + w.Write(data) + + logger.Logger.Info(). + Str("key", cacheKey). + Str("host", r.Host). + Str("status", "HIT"). + Int64("size", int64(len(data))). + Dur("duration", time.Since(tstart)). + Msg("request") + return + } + + var req *http.Request + if sc.upstream != "" { // if an upstream server is configured, proxy the request to the upstream server + ur, err := url.JoinPath(sc.upstream, r.URL.String()) + if err != nil { + requestsTotal.WithLabelValues(r.Method, "500").Inc() + logger.Logger.Error().Err(err).Str("upstream", sc.upstream).Msg("Failed to join URL path") + http.Error(w, "Failed to join URL path", http.StatusInternalServerError) + return + } + + req, err = http.NewRequest(http.MethodGet, ur, nil) + if err != nil { + requestsTotal.WithLabelValues(r.Method, "500").Inc() + logger.Logger.Error().Err(err).Str("upstream", sc.upstream).Msg("Failed to create request") + http.Error(w, "Failed to create request", http.StatusInternalServerError) + return + } + req.Host = r.Host + } else { // if no upstream server is configured, proxy the request to the host specified in the request + host := r.Host + if r.Header.Get("X-Sls-Https") == "enable" { + host = "https://" + host + } else { + host = "http://" + host + } + + ur, err := url.JoinPath(host, r.URL.String()) + if err != nil { + requestsTotal.WithLabelValues(r.Method, "500").Inc() + logger.Logger.Error().Err(err).Str("host", host).Msg("Failed to join URL path") + http.Error(w, "Failed to join URL path", http.StatusInternalServerError) + return + } + + req, err = http.NewRequest(http.MethodGet, ur, nil) + if err != nil { + requestsTotal.WithLabelValues(r.Method, "500").Inc() + logger.Logger.Error().Err(err).Str("host", host).Msg("Failed to create request") + http.Error(w, "Failed to create request", http.StatusInternalServerError) + return + } + } + + // Copy headers from the original request to the new request + for key, values := range r.Header { + for _, value := range values { + req.Header.Add(key, value) + } + } + + // req.Header.Add("X-Sls-Https", r.Header.Get("X-Sls-Https")) + // req.Header.Add("User-Agent", r.Header.Get("User-Agent")) + + // Retry logic + backoffSchedule := []time.Duration{1 * time.Second, 3 * time.Second, 10 * time.Second} + var resp *http.Response + for i, backoff := range backoffSchedule { + resp, err = http.DefaultClient.Do(req) + if err == nil && resp.StatusCode == http.StatusOK { + break + } + if i < len(backoffSchedule)-1 { + time.Sleep(backoff) + } + } + if err != nil || resp.StatusCode != http.StatusOK { + requestsTotal.WithLabelValues(r.Method, "500").Inc() + logger.Logger.Error().Err(err).Str("url", req.URL.String()).Msg("Failed to fetch the requested URL") + http.Error(w, "Failed to fetch the requested URL", http.StatusInternalServerError) + return + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + requestsTotal.WithLabelValues(r.Method, "500").Inc() + logger.Logger.Error().Err(err).Str("url", req.URL.String()).Msg("Failed to read response body") + http.Error(w, "Failed to read response body", http.StatusInternalServerError) + return + } + + sc.vfs.Set(cacheKey, body) + sc.hits.Add(cachestate.CacheStateMiss) + w.Header().Add("X-LanCache-Status", "MISS") requestsTotal.WithLabelValues(r.Method, "200").Inc() cacheHitRate.Set(sc.hits.Avg()) - w.Write(data) + w.Write(body) logger.Logger.Info(). Str("key", cacheKey). Str("host", r.Host). - Str("status", "HIT"). - Int64("size", int64(len(data))). + Str("status", "MISS"). + Int64("size", int64(len(body))). Dur("duration", time.Since(tstart)). Msg("request") return } - var req *http.Request - if sc.upstream != "" { // if an upstream server is configured, proxy the request to the upstream server - ur, err := url.JoinPath(sc.upstream, r.URL.String()) - if err != nil { - requestsTotal.WithLabelValues(r.Method, "500").Inc() - logger.Logger.Error().Err(err).Str("upstream", sc.upstream).Msg("Failed to join URL path") - http.Error(w, "Failed to join URL path", http.StatusInternalServerError) - return - } - - req, err = http.NewRequest(http.MethodGet, ur, nil) - if err != nil { - requestsTotal.WithLabelValues(r.Method, "500").Inc() - logger.Logger.Error().Err(err).Str("upstream", sc.upstream).Msg("Failed to create request") - http.Error(w, "Failed to create request", http.StatusInternalServerError) - return - } - req.Host = r.Host - } else { // if no upstream server is configured, proxy the request to the host specified in the request - host := r.Host - if r.Header.Get("X-Sls-Https") == "enable" { - host = "https://" + host - } else { - host = "http://" + host - } - - ur, err := url.JoinPath(host, r.URL.String()) - if err != nil { - requestsTotal.WithLabelValues(r.Method, "500").Inc() - logger.Logger.Error().Err(err).Str("host", host).Msg("Failed to join URL path") - http.Error(w, "Failed to join URL path", http.StatusInternalServerError) - return - } - - req, err = http.NewRequest(http.MethodGet, ur, nil) - if err != nil { - requestsTotal.WithLabelValues(r.Method, "500").Inc() - logger.Logger.Error().Err(err).Str("host", host).Msg("Failed to create request") - http.Error(w, "Failed to create request", http.StatusInternalServerError) - return - } - } - - // Copy headers from the original request to the new request - for key, values := range r.Header { - for _, value := range values { - req.Header.Add(key, value) - } - } - - // req.Header.Add("X-Sls-Https", r.Header.Get("X-Sls-Https")) - // req.Header.Add("User-Agent", r.Header.Get("User-Agent")) - - // Retry logic - backoffSchedule := []time.Duration{1 * time.Second, 3 * time.Second, 10 * time.Second} - var resp *http.Response - for i, backoff := range backoffSchedule { - resp, err = http.DefaultClient.Do(req) - if err == nil && resp.StatusCode == http.StatusOK { - break - } - if i < len(backoffSchedule)-1 { - time.Sleep(backoff) - } - } - if err != nil || resp.StatusCode != http.StatusOK { - requestsTotal.WithLabelValues(r.Method, "500").Inc() - logger.Logger.Error().Err(err).Str("url", req.URL.String()).Msg("Failed to fetch the requested URL") - http.Error(w, "Failed to fetch the requested URL", http.StatusInternalServerError) - return - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - requestsTotal.WithLabelValues(r.Method, "500").Inc() - logger.Logger.Error().Err(err).Str("url", req.URL.String()).Msg("Failed to read response body") - http.Error(w, "Failed to read response body", http.StatusInternalServerError) + if r.URL.Path == "/favicon.ico" { + w.WriteHeader(http.StatusNoContent) return } - sc.vfs.Set(cacheKey, body) - sc.hits.Add(cachestate.CacheStateMiss) - w.Header().Add("X-LanCache-Status", "MISS") - requestsTotal.WithLabelValues(r.Method, "200").Inc() - cacheHitRate.Set(sc.hits.Avg()) + if r.URL.Path == "/robots.txt" { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write([]byte("User-agent: *\nDisallow: /\n")) + return + } - w.Write(body) - - logger.Logger.Info(). - Str("key", cacheKey). - Str("host", r.Host). - Str("status", "MISS"). - Int64("size", int64(len(body))). - Dur("duration", time.Since(tstart)). - Msg("request") + requestsTotal.WithLabelValues(r.Method, "404").Inc() + logger.Logger.Warn().Str("url", r.URL.String()).Msg("Not found") + http.Error(w, "Not found", http.StatusNotFound) } From 1673e9554a3e567cf1c28a1e895b676e241e2ee3 Mon Sep 17 00:00:00 2001 From: Justin Harms Date: Sun, 13 Jul 2025 03:17:22 -0500 Subject: [PATCH 4/5] Refactor VFS implementation to use Create and Open methods - Updated disk_test.go to replace Set and Get with Create and Open methods for better clarity and functionality. - Modified fileinfo.go to include package comment. - Refactored gc.go to streamline garbage collection handling and removed unused statistics. - Updated gc_test.go to comment out large random tests for future implementation. - Enhanced memory.go to implement LRU caching and metrics for memory usage. - Updated memory_test.go to replace Set and Get with Create and Open methods. - Removed sync.go as it was redundant and not utilized. - Updated vfs.go to reflect changes in the VFS interface, replacing Set and Get with Create and Open. - Added package comments to vfserror.go for consistency. --- cmd/root.go | 26 +- cmd/version.go | 1 + go.mod | 1 - go.sum | 2 - main.go | 1 + steamcache/avgcachestate/avgcachestate.go | 63 ----- steamcache/gc.go | 71 +++-- steamcache/logger/logger.go | 1 + steamcache/steamcache.go | 157 ++++++----- steamcache/steamcache_test.go | 53 ++-- version/version.go | 1 + vfs/cache/cache.go | 88 +++--- vfs/cache/cache_test.go | 74 +++-- vfs/cachestate/cachestate.go | 1 + vfs/disk/disk.go | 314 +++++++++++++++++----- vfs/disk/disk_test.go | 122 ++++++--- vfs/fileinfo.go | 1 + vfs/gc/gc.go | 58 +--- vfs/gc/gc_test.go | 161 ++++++----- vfs/memory/memory.go | 243 ++++++++++++----- vfs/memory/memory_test.go | 48 +++- vfs/sync/sync.go | 76 ------ vfs/vfs.go | 12 +- vfs/vfserror/vfserror.go | 1 + 24 files changed, 945 insertions(+), 631 deletions(-) delete mode 100644 steamcache/avgcachestate/avgcachestate.go delete mode 100644 vfs/sync/sync.go diff --git a/cmd/root.go b/cmd/root.go index 294f87b..3c67a6d 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,7 +1,9 @@ +// cmd/root.go package cmd import ( "os" + "runtime" "s1d3sw1ped/SteamCache2/steamcache" "s1d3sw1ped/SteamCache2/steamcache/logger" "s1d3sw1ped/SteamCache2/version" @@ -11,6 +13,8 @@ import ( ) var ( + threads int + memory string memorymultiplier int disk string @@ -18,7 +22,6 @@ var ( diskpath string upstream string - pprof bool logLevel string logFormat string ) @@ -52,21 +55,29 @@ var rootCmd = &cobra.Command{ logger.Logger = zerolog.New(writer).With().Timestamp().Logger() logger.Logger.Info(). - Msg("starting SteamCache2 " + version.Version) + Msg("SteamCache2 " + version.Version + " starting...") + + address := ":80" + + if runtime.GOMAXPROCS(-1) != threads { + runtime.GOMAXPROCS(threads) + logger.Logger.Info(). + Int("threads", threads). + Msg("Maximum number of threads set") + } sc := steamcache.New( - ":80", + address, memory, memorymultiplier, disk, diskmultiplier, diskpath, upstream, - pprof, ) logger.Logger.Info(). - Msg("SteamCache2 listening on port 80") + Msg("SteamCache2 " + version.Version + " started on " + address) sc.Run() @@ -85,6 +96,8 @@ func Execute() { } func init() { + rootCmd.Flags().IntVarP(&threads, "threads", "t", runtime.GOMAXPROCS(-1), "Number of worker threads to use for processing requests") + rootCmd.Flags().StringVarP(&memory, "memory", "m", "0", "The size of the memory cache") rootCmd.Flags().IntVarP(&memorymultiplier, "memory-gc", "M", 10, "The gc value for the memory cache") rootCmd.Flags().StringVarP(&disk, "disk", "d", "0", "The size of the disk cache") @@ -93,9 +106,6 @@ func init() { rootCmd.Flags().StringVarP(&upstream, "upstream", "u", "", "The upstream server to proxy requests overrides the host header from the client but forwards the original host header to the upstream server") - rootCmd.Flags().BoolVarP(&pprof, "pprof", "P", false, "Enable pprof") - rootCmd.Flags().MarkHidden("pprof") - rootCmd.Flags().StringVarP(&logLevel, "log-level", "l", "info", "Logging level: debug, info, error") rootCmd.Flags().StringVarP(&logFormat, "log-format", "f", "console", "Logging format: json, console") } diff --git a/cmd/version.go b/cmd/version.go index 35af191..4c58998 100644 --- a/cmd/version.go +++ b/cmd/version.go @@ -1,3 +1,4 @@ +// cmd/version.go package cmd import ( diff --git a/go.mod b/go.mod index c383e82..23fb550 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,6 @@ require ( github.com/prometheus/client_golang v1.22.0 github.com/rs/zerolog v1.33.0 github.com/spf13/cobra v1.8.1 - golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 ) require ( diff --git a/go.sum b/go.sum index 76dfe79..6c4076d 100644 --- a/go.sum +++ b/go.sum @@ -45,8 +45,6 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 h1:yqrTHse8TCMW1M1ZCP+VAR/l0kKxwaAIqN/il7x4voA= -golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/main.go b/main.go index b27cae1..9396567 100644 --- a/main.go +++ b/main.go @@ -1,3 +1,4 @@ +// main.go package main import ( diff --git a/steamcache/avgcachestate/avgcachestate.go b/steamcache/avgcachestate/avgcachestate.go deleted file mode 100644 index ce51f72..0000000 --- a/steamcache/avgcachestate/avgcachestate.go +++ /dev/null @@ -1,63 +0,0 @@ -package avgcachestate - -import ( - "s1d3sw1ped/SteamCache2/vfs/cachestate" - "sync" -) - -// AvgCacheState is a cache state that averages the last N cache states. -type AvgCacheState struct { - size int - avgs []cachestate.CacheState - mu sync.Mutex -} - -// New creates a new average cache state with the given size. -func New(size int) *AvgCacheState { - a := &AvgCacheState{ - size: size, - avgs: make([]cachestate.CacheState, size), - mu: sync.Mutex{}, - } - - a.Clear() - - return a -} - -// Clear resets the average cache state to zero. -func (a *AvgCacheState) Clear() { - a.mu.Lock() - defer a.mu.Unlock() - - for i := 0; i < len(a.avgs); i++ { - a.avgs[i] = cachestate.CacheStateMiss - } -} - -// Add adds a cache state to the average cache state. -func (a *AvgCacheState) Add(cs cachestate.CacheState) { - a.mu.Lock() - defer a.mu.Unlock() - - a.avgs = append(a.avgs, cs) - if len(a.avgs) > a.size { - a.avgs = a.avgs[1:] - } -} - -// Avg returns the average cache state. -func (a *AvgCacheState) Avg() float64 { - a.mu.Lock() - defer a.mu.Unlock() - - var hits int - - for _, cs := range a.avgs { - if cs == cachestate.CacheStateHit { - hits++ - } - } - - return float64(hits) / float64(len(a.avgs)) -} diff --git a/steamcache/gc.go b/steamcache/gc.go index 14805aa..ac79359 100644 --- a/steamcache/gc.go +++ b/steamcache/gc.go @@ -1,44 +1,63 @@ +// steamcache/gc.go package steamcache import ( - "runtime/debug" "s1d3sw1ped/SteamCache2/vfs" "s1d3sw1ped/SteamCache2/vfs/cachestate" - "sort" + "s1d3sw1ped/SteamCache2/vfs/disk" + "s1d3sw1ped/SteamCache2/vfs/memory" "time" ) -func init() { - // Set the GC percentage to 50%. This is a good balance between performance and memory usage. - debug.SetGCPercent(50) -} - // lruGC deletes files in LRU order until enough space is reclaimed. -func lruGC(vfss vfs.VFS, size uint) (uint, uint) { +func lruGC(vfss vfs.VFS, size uint) { deletions := 0 var reclaimed uint - stats := vfss.StatAll() - sort.Slice(stats, func(i, j int) bool { - return stats[i].AccessTime().Before(stats[j].AccessTime()) - }) - - for _, s := range stats { - sz := uint(s.Size()) - err := vfss.Delete(s.Name()) - if err != nil { - continue - } - reclaimed += sz - deletions++ - if reclaimed >= size { - break + for reclaimed < size { + switch fs := vfss.(type) { + case *disk.DiskFS: + fi := fs.LRU.Back() + if fi == nil { + break + } + sz := uint(fi.Size()) + err := fs.Delete(fi.Name()) + if err != nil { + continue + } + reclaimed += sz + deletions++ + case *memory.MemoryFS: + fi := fs.LRU.Back() + if fi == nil { + break + } + sz := uint(fi.Size()) + err := fs.Delete(fi.Name()) + if err != nil { + continue + } + reclaimed += sz + deletions++ + default: + // Fallback to old method if not supported + stats := vfss.StatAll() + if len(stats) == 0 { + break + } + fi := stats[0] // Assume sorted or pick first + sz := uint(fi.Size()) + err := vfss.Delete(fi.Name()) + if err != nil { + continue + } + reclaimed += sz + deletions++ } } - - return reclaimed, uint(deletions) } func cachehandler(fi *vfs.FileInfo, cs cachestate.CacheState) bool { - return time.Since(fi.AccessTime()) < time.Second*10 // Put hot files in the fast vfs if equipped + return time.Since(fi.AccessTime()) < time.Second*60 // Put hot files in the fast vfs if equipped } diff --git a/steamcache/logger/logger.go b/steamcache/logger/logger.go index f3af507..849ba32 100644 --- a/steamcache/logger/logger.go +++ b/steamcache/logger/logger.go @@ -1,3 +1,4 @@ +// steamcache/logger/logger.go package logger import ( diff --git a/steamcache/steamcache.go b/steamcache/steamcache.go index b501fc6..53794dd 100644 --- a/steamcache/steamcache.go +++ b/steamcache/steamcache.go @@ -1,25 +1,23 @@ +// steamcache/steamcache.go package steamcache import ( + "context" "io" + "net" "net/http" "net/url" "os" - "s1d3sw1ped/SteamCache2/steamcache/avgcachestate" "s1d3sw1ped/SteamCache2/steamcache/logger" "s1d3sw1ped/SteamCache2/vfs" "s1d3sw1ped/SteamCache2/vfs/cache" - "s1d3sw1ped/SteamCache2/vfs/cachestate" "s1d3sw1ped/SteamCache2/vfs/disk" "s1d3sw1ped/SteamCache2/vfs/gc" "s1d3sw1ped/SteamCache2/vfs/memory" - - // syncfs "s1d3sw1ped/SteamCache2/vfs/sync" "strings" + "sync" "time" - pprof "net/http/pprof" - "github.com/docker/go-units" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -34,16 +32,25 @@ var ( }, []string{"method", "status"}, ) - cacheHitRate = promauto.NewGauge( - prometheus.GaugeOpts{ - Name: "cache_hit_rate", - Help: "Cache hit rate", + + cacheStatusTotal = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "cache_status_total", + Help: "Total cache status counts", + }, + []string{"status"}, + ) + + responseTime = promauto.NewHistogram( + prometheus.HistogramOpts{ + Name: "response_time_seconds", + Help: "Response time in seconds", + Buckets: prometheus.DefBuckets, }, ) ) type SteamCache struct { - pprof bool address string upstream string @@ -55,10 +62,13 @@ type SteamCache struct { memorygc *gc.GCFS diskgc *gc.GCFS - hits *avgcachestate.AvgCacheState + server *http.Server + client *http.Client + cancel context.CancelFunc + wg sync.WaitGroup } -func New(address string, memorySize string, memoryMultiplier int, diskSize string, diskMultiplier int, diskPath, upstream string, pprof bool) *SteamCache { +func New(address string, memorySize string, memoryMultiplier int, diskSize string, diskMultiplier int, diskPath, upstream string) *SteamCache { memorysize, err := units.FromHumanSize(memorySize) if err != nil { panic(err) @@ -107,20 +117,39 @@ func New(address string, memorySize string, memoryMultiplier int, diskSize strin os.Exit(1) } + transport := &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + + client := &http.Client{ + Transport: transport, + Timeout: 60 * time.Second, + } + sc := &SteamCache{ - pprof: pprof, upstream: upstream, address: address, - // vfs: syncfs.New(c), - vfs: c, - - memory: m, - disk: d, - + vfs: c, + memory: m, + disk: d, memorygc: mgc, diskgc: dgc, - - hits: avgcachestate.New(100), + client: client, + server: &http.Server{ + Addr: address, + ReadTimeout: 5 * time.Second, + WriteTimeout: 10 * time.Second, + IdleTimeout: 120 * time.Second, + }, } if d != nil { @@ -134,32 +163,41 @@ func New(address string, memorySize string, memoryMultiplier int, diskSize strin func (sc *SteamCache) Run() { if sc.upstream != "" { - _, err := http.Get(sc.upstream) - if err != nil { + resp, err := sc.client.Get(sc.upstream) + if err != nil || resp.StatusCode != http.StatusOK { logger.Logger.Error().Err(err).Str("upstream", sc.upstream).Msg("Failed to connect to upstream server") os.Exit(1) } + resp.Body.Close() } - err := http.ListenAndServe(sc.address, sc) - if err != nil { - if err == http.ErrServerClosed { - logger.Logger.Info().Msg("shutdown") - return + sc.server.Handler = sc + ctx, cancel := context.WithCancel(context.Background()) + sc.cancel = cancel + + sc.wg.Add(1) + go func() { + defer sc.wg.Done() + err := sc.server.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + logger.Logger.Error().Err(err).Msg("Failed to start SteamCache2") + os.Exit(1) } - logger.Logger.Error().Err(err).Msg("Failed to start SteamCache2") - os.Exit(1) + }() + + <-ctx.Done() + sc.server.Shutdown(ctx) + sc.wg.Wait() +} + +func (sc *SteamCache) Shutdown() { + if sc.cancel != nil { + sc.cancel() } + sc.wg.Wait() } func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if sc.pprof && r.URL.Path == "/debug/pprof/" { - pprof.Index(w, r) - return - } else if sc.pprof && strings.HasPrefix(r.URL.Path, "/debug/pprof/") { - pprof.Handler(strings.TrimPrefix(r.URL.Path, "/debug/pprof/")).ServeHTTP(w, r) - return - } if r.URL.Path == "/metrics" { promhttp.Handler().ServeHTTP(w, r) return @@ -180,10 +218,11 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if strings.HasPrefix(r.URL.String(), "/depot/") { - tstart := time.Now() + defer func() { responseTime.Observe(time.Since(tstart).Seconds()) }() cacheKey := strings.ReplaceAll(r.URL.String()[1:], "\\", "/") // replace all backslashes with forward slashes shouldn't be necessary but just in case + if cacheKey == "" { requestsTotal.WithLabelValues(r.Method, "400").Inc() logger.Logger.Warn().Str("url", r.URL.String()).Msg("Invalid URL") @@ -193,22 +232,23 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Add("X-LanCache-Processed-By", "SteamCache2") // SteamPrefill uses this header to determine if the request was processed by the cache maybe steam uses it too - data, err := sc.vfs.Get(cacheKey) + reader, err := sc.vfs.Open(cacheKey) if err == nil { - sc.hits.Add(cachestate.CacheStateHit) + defer reader.Close() w.Header().Add("X-LanCache-Status", "HIT") - requestsTotal.WithLabelValues(r.Method, "200").Inc() - cacheHitRate.Set(sc.hits.Avg()) - w.Write(data) + io.Copy(w, reader) logger.Logger.Info(). Str("key", cacheKey). Str("host", r.Host). Str("status", "HIT"). - Int64("size", int64(len(data))). Dur("duration", time.Since(tstart)). Msg("request") + + requestsTotal.WithLabelValues(r.Method, "200").Inc() + cacheStatusTotal.WithLabelValues("HIT").Inc() + return } @@ -262,14 +302,11 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } - // req.Header.Add("X-Sls-Https", r.Header.Get("X-Sls-Https")) - // req.Header.Add("User-Agent", r.Header.Get("User-Agent")) - // Retry logic backoffSchedule := []time.Duration{1 * time.Second, 3 * time.Second, 10 * time.Second} var resp *http.Response for i, backoff := range backoffSchedule { - resp, err = http.DefaultClient.Do(req) + resp, err = sc.client.Do(req) if err == nil && resp.StatusCode == http.StatusOK { break } @@ -278,36 +315,36 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } if err != nil || resp.StatusCode != http.StatusOK { - requestsTotal.WithLabelValues(r.Method, "500").Inc() + requestsTotal.WithLabelValues(r.Method, "500 upstream host "+r.Host).Inc() logger.Logger.Error().Err(err).Str("url", req.URL.String()).Msg("Failed to fetch the requested URL") http.Error(w, "Failed to fetch the requested URL", http.StatusInternalServerError) return } defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) + size := resp.ContentLength + + writer, err := sc.vfs.Create(cacheKey, size) if err != nil { - requestsTotal.WithLabelValues(r.Method, "500").Inc() - logger.Logger.Error().Err(err).Str("url", req.URL.String()).Msg("Failed to read response body") - http.Error(w, "Failed to read response body", http.StatusInternalServerError) + http.Error(w, err.Error(), http.StatusInternalServerError) return } + defer writer.Close() - sc.vfs.Set(cacheKey, body) - sc.hits.Add(cachestate.CacheStateMiss) w.Header().Add("X-LanCache-Status", "MISS") - requestsTotal.WithLabelValues(r.Method, "200").Inc() - cacheHitRate.Set(sc.hits.Avg()) - w.Write(body) + io.Copy(io.MultiWriter(w, writer), resp.Body) logger.Logger.Info(). Str("key", cacheKey). Str("host", r.Host). Str("status", "MISS"). - Int64("size", int64(len(body))). Dur("duration", time.Since(tstart)). Msg("request") + + requestsTotal.WithLabelValues(r.Method, "200").Inc() + cacheStatusTotal.WithLabelValues("MISS").Inc() + return } diff --git a/steamcache/steamcache_test.go b/steamcache/steamcache_test.go index decc3a8..cc52616 100644 --- a/steamcache/steamcache_test.go +++ b/steamcache/steamcache_test.go @@ -1,6 +1,8 @@ +// steamcache/steamcache_test.go package steamcache import ( + "io" "os" "path/filepath" "testing" @@ -13,14 +15,21 @@ func TestCaching(t *testing.T) { os.WriteFile(filepath.Join(td, "key2"), []byte("value2"), 0644) - sc := New("localhost:8080", "1GB", 10, "1GB", 100, td, "", false) + sc := New("localhost:8080", "1G", 10, "1G", 100, td, "") - if err := sc.vfs.Set("key", []byte("value")); err != nil { - t.Errorf("Set failed: %v", err) + w, err := sc.vfs.Create("key", 5) + if err != nil { + t.Errorf("Create failed: %v", err) } - if err := sc.vfs.Set("key1", []byte("value1")); err != nil { - t.Errorf("Set failed: %v", err) + w.Write([]byte("value")) + w.Close() + + w, err = sc.vfs.Create("key1", 6) + if err != nil { + t.Errorf("Create failed: %v", err) } + w.Write([]byte("value1")) + w.Close() if sc.diskgc.Size() != 17 { t.Errorf("Size failed: got %d, want %d", sc.diskgc.Size(), 17) @@ -30,21 +39,33 @@ func TestCaching(t *testing.T) { t.Errorf("Size failed: got %d, want %d", sc.vfs.Size(), 17) } - if d, err := sc.vfs.Get("key"); err != nil { - t.Errorf("Get failed: %v", err) - } else if string(d) != "value" { + rc, err := sc.vfs.Open("key") + if err != nil { + t.Errorf("Open failed: %v", err) + } + d, _ := io.ReadAll(rc) + rc.Close() + if string(d) != "value" { t.Errorf("Get failed: got %s, want %s", d, "value") } - if d, err := sc.vfs.Get("key1"); err != nil { - t.Errorf("Get failed: %v", err) - } else if string(d) != "value1" { + rc, err = sc.vfs.Open("key1") + if err != nil { + t.Errorf("Open failed: %v", err) + } + d, _ = io.ReadAll(rc) + rc.Close() + if string(d) != "value1" { t.Errorf("Get failed: got %s, want %s", d, "value1") } - if d, err := sc.vfs.Get("key2"); err != nil { - t.Errorf("Get failed: %v", err) - } else if string(d) != "value2" { + rc, err = sc.vfs.Open("key2") + if err != nil { + t.Errorf("Open failed: %v", err) + } + d, _ = io.ReadAll(rc) + rc.Close() + if string(d) != "value2" { t.Errorf("Get failed: got %s, want %s", d, "value2") } @@ -59,7 +80,7 @@ func TestCaching(t *testing.T) { sc.memory.Delete("key2") os.Remove(filepath.Join(td, "key2")) - if _, err := sc.vfs.Get("key2"); err == nil { - t.Errorf("Get failed: got nil, want error") + if _, err := sc.vfs.Open("key2"); err == nil { + t.Errorf("Open failed: got nil, want error") } } diff --git a/version/version.go b/version/version.go index b330b7d..5d29353 100644 --- a/version/version.go +++ b/version/version.go @@ -1,3 +1,4 @@ +// version/version.go package version var Version string diff --git a/vfs/cache/cache.go b/vfs/cache/cache.go index b107d59..917ab0e 100644 --- a/vfs/cache/cache.go +++ b/vfs/cache/cache.go @@ -1,7 +1,9 @@ +// vfs/cache/cache.go package cache import ( "fmt" + "io" "s1d3sw1ped/SteamCache2/vfs" "s1d3sw1ped/SteamCache2/vfs/cachestate" "s1d3sw1ped/SteamCache2/vfs/vfserror" @@ -73,27 +75,6 @@ func (c *CacheFS) Size() int64 { return c.slow.Size() } -// Set sets the file at key to src. If the file is already in the cache, it is replaced. -func (c *CacheFS) Set(key string, src []byte) error { - mu := c.getKeyLock(key) - mu.Lock() - defer mu.Unlock() - - state := c.cacheState(key) - - switch state { - case cachestate.CacheStateHit: - if c.fast != nil { - c.fast.Delete(key) - } - return c.slow.Set(key, src) - case cachestate.CacheStateMiss, cachestate.CacheStateNotFound: - return c.slow.Set(key, src) - } - - panic(vfserror.ErrUnreachable) -} - // Delete deletes the file at key from the cache. func (c *CacheFS) Delete(key string) error { mu := c.getKeyLock(key) @@ -106,14 +87,8 @@ func (c *CacheFS) Delete(key string) error { return c.slow.Delete(key) } -// Get returns the file at key. If the file is not in the cache, it is fetched from the storage. -func (c *CacheFS) Get(key string) ([]byte, error) { - src, _, err := c.GetS(key) - return src, err -} - -// GetS returns the file at key. If the file is not in the cache, it is fetched from the storage. It also returns the cache state. -func (c *CacheFS) GetS(key string) ([]byte, cachestate.CacheState, error) { +// Open returns the file at key. If the file is not in the cache, it is fetched from the storage. +func (c *CacheFS) Open(key string) (io.ReadCloser, error) { mu := c.getKeyLock(key) mu.RLock() defer mu.RUnlock() @@ -123,27 +98,51 @@ func (c *CacheFS) GetS(key string) ([]byte, cachestate.CacheState, error) { switch state { case cachestate.CacheStateHit: // if c.fast == nil then cacheState cannot be CacheStateHit so we can safely ignore the check - src, err := c.fast.Get(key) - return src, state, err + return c.fast.Open(key) case cachestate.CacheStateMiss: - src, err := c.slow.Get(key) + slowReader, err := c.slow.Open(key) if err != nil { - return nil, state, err + return nil, err } sstat, _ := c.slow.Stat(key) if sstat != nil && c.fast != nil { // file found in slow storage and fast storage is available // We are accessing the file from the slow storage, and the file has been accessed less then a minute ago so it popular, so we should update the fast storage with the latest file. if c.cacheHandler != nil && c.cacheHandler(sstat, state) { - if err := c.fast.Set(key, src); err != nil { - return nil, state, err + fastWriter, err := c.fast.Create(key, sstat.Size()) + if err == nil { + return &teeReadCloser{ + Reader: io.TeeReader(slowReader, fastWriter), + closers: []io.Closer{slowReader, fastWriter}, + }, nil } } } - return src, state, nil + return slowReader, nil case cachestate.CacheStateNotFound: - return nil, state, vfserror.ErrNotFound + return nil, vfserror.ErrNotFound + } + + panic(vfserror.ErrUnreachable) +} + +// Create creates a new file at key. If the file is already in the cache, it is replaced. +func (c *CacheFS) Create(key string, size int64) (io.WriteCloser, error) { + mu := c.getKeyLock(key) + mu.Lock() + defer mu.Unlock() + + state := c.cacheState(key) + + switch state { + case cachestate.CacheStateHit: + if c.fast != nil { + c.fast.Delete(key) + } + return c.slow.Create(key, size) + case cachestate.CacheStateMiss, cachestate.CacheStateNotFound: + return c.slow.Create(key, size) } panic(vfserror.ErrUnreachable) @@ -176,3 +175,18 @@ func (c *CacheFS) Stat(key string) (*vfs.FileInfo, error) { func (c *CacheFS) StatAll() []*vfs.FileInfo { return c.slow.StatAll() } + +type teeReadCloser struct { + io.Reader + closers []io.Closer +} + +func (t *teeReadCloser) Close() error { + var err error + for _, c := range t.closers { + if e := c.Close(); e != nil { + err = e + } + } + return err +} diff --git a/vfs/cache/cache_test.go b/vfs/cache/cache_test.go index ad53f19..dd4e2b0 100644 --- a/vfs/cache/cache_test.go +++ b/vfs/cache/cache_test.go @@ -1,7 +1,9 @@ +// vfs/cache/cache_test.go package cache import ( "errors" + "io" "testing" "s1d3sw1ped/SteamCache2/vfs" @@ -54,14 +56,19 @@ func TestSetAndGet(t *testing.T) { key := "test" value := []byte("value") - if err := cache.Set(key, value); err != nil { - t.Fatalf("unexpected error: %v", err) - } - - got, err := cache.Get(key) + w, err := cache.Create(key, int64(len(value))) if err != nil { t.Fatalf("unexpected error: %v", err) } + w.Write(value) + w.Close() + + rc, err := cache.Open(key) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + got, _ := io.ReadAll(rc) + rc.Close() if string(got) != string(value) { t.Fatalf("expected %s, got %s", value, got) @@ -78,19 +85,25 @@ func TestSetAndGetNoFast(t *testing.T) { key := "test" value := []byte("value") - if err := cache.Set(key, value); err != nil { - t.Fatalf("unexpected error: %v", err) - } - - got, err := cache.Get(key) + w, err := cache.Create(key, int64(len(value))) if err != nil { t.Fatalf("unexpected error: %v", err) } + w.Write(value) + w.Close() + + rc, err := cache.Open(key) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + got, _ := io.ReadAll(rc) + rc.Close() if string(got) != string(value) { t.Fatalf("expected %s, got %s", value, got) } } + func TestCaching(t *testing.T) { t.Parallel() @@ -105,31 +118,31 @@ func TestCaching(t *testing.T) { key := "test" value := []byte("value") - if err := fast.Set(key, value); err != nil { - t.Fatalf("unexpected error: %v", err) - } + wf, _ := fast.Create(key, int64(len(value))) + wf.Write(value) + wf.Close() - if err := slow.Set(key, value); err != nil { - t.Fatalf("unexpected error: %v", err) - } + ws, _ := slow.Create(key, int64(len(value))) + ws.Write(value) + ws.Close() - _, state, err := cache.GetS(key) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + state := cache.cacheState(key) if state != cachestate.CacheStateHit { t.Fatalf("expected %v, got %v", cachestate.CacheStateHit, state) } - err = fast.Delete(key) + err := fast.Delete(key) if err != nil { t.Fatalf("unexpected error: %v", err) } - got, state, err := cache.GetS(key) + rc, err := cache.Open(key) if err != nil { t.Fatalf("unexpected error: %v", err) } + got, _ := io.ReadAll(rc) + rc.Close() + state = cache.cacheState(key) if state != cachestate.CacheStateMiss { t.Fatalf("expected %v, got %v", cachestate.CacheStateMiss, state) } @@ -143,10 +156,11 @@ func TestCaching(t *testing.T) { t.Fatalf("unexpected error: %v", err) } - _, state, err = cache.GetS(key) + _, err = cache.Open(key) if !errors.Is(err, vfserror.ErrNotFound) { t.Fatalf("expected %v, got %v", vfserror.ErrNotFound, err) } + state = cache.cacheState(key) if state != cachestate.CacheStateNotFound { t.Fatalf("expected %v, got %v", cachestate.CacheStateNotFound, state) } @@ -161,7 +175,7 @@ func TestGetNotFound(t *testing.T) { cache.SetFast(fast) cache.SetSlow(slow) - _, err := cache.Get("nonexistent") + _, err := cache.Open("nonexistent") if !errors.Is(err, vfserror.ErrNotFound) { t.Fatalf("expected %v, got %v", vfserror.ErrNotFound, err) } @@ -179,15 +193,18 @@ func TestDelete(t *testing.T) { key := "test" value := []byte("value") - if err := cache.Set(key, value); err != nil { + w, err := cache.Create(key, int64(len(value))) + if err != nil { t.Fatalf("unexpected error: %v", err) } + w.Write(value) + w.Close() if err := cache.Delete(key); err != nil { t.Fatalf("unexpected error: %v", err) } - _, err := cache.Get(key) + _, err = cache.Open(key) if !errors.Is(err, vfserror.ErrNotFound) { t.Fatalf("expected %v, got %v", vfserror.ErrNotFound, err) } @@ -205,9 +222,12 @@ func TestStat(t *testing.T) { key := "test" value := []byte("value") - if err := cache.Set(key, value); err != nil { + w, err := cache.Create(key, int64(len(value))) + if err != nil { t.Fatalf("unexpected error: %v", err) } + w.Write(value) + w.Close() info, err := cache.Stat(key) if err != nil { diff --git a/vfs/cachestate/cachestate.go b/vfs/cachestate/cachestate.go index 18bb65c..542ce54 100644 --- a/vfs/cachestate/cachestate.go +++ b/vfs/cachestate/cachestate.go @@ -1,3 +1,4 @@ +// vfs/cachestate/cachestate.go package cachestate import "s1d3sw1ped/SteamCache2/vfs/vfserror" diff --git a/vfs/disk/disk.go b/vfs/disk/disk.go index 88961c0..eb3a574 100644 --- a/vfs/disk/disk.go +++ b/vfs/disk/disk.go @@ -1,7 +1,10 @@ +// vfs/disk/disk.go package disk import ( + "container/list" "fmt" + "io" "os" "path/filepath" "s1d3sw1ped/SteamCache2/steamcache/logger" @@ -12,6 +15,38 @@ import ( "time" "github.com/docker/go-units" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var ( + diskCapacityBytes = promauto.NewGauge( + prometheus.GaugeOpts{ + Name: "disk_cache_capacity_bytes", + Help: "Total capacity of the disk cache in bytes", + }, + ) + + diskSizeBytes = promauto.NewGauge( + prometheus.GaugeOpts{ + Name: "disk_cache_size_bytes", + Help: "Total size of the disk cache in bytes", + }, + ) + + diskReadBytes = promauto.NewCounter( + prometheus.CounterOpts{ + Name: "disk_cache_read_bytes_total", + Help: "Total number of bytes read from the disk cache", + }, + ) + + diskWriteBytes = promauto.NewCounter( + prometheus.CounterOpts{ + Name: "disk_cache_write_bytes_total", + Help: "Total number of bytes written to the disk cache", + }, + ) ) // Ensure DiskFS implements VFS. @@ -23,10 +58,50 @@ type DiskFS struct { info map[string]*vfs.FileInfo capacity int64 - mu sync.Mutex + size int64 + mu sync.RWMutex + keyLocks sync.Map // map[string]*sync.RWMutex sg sync.WaitGroup + LRU *lruList +} - bytePool sync.Pool // Pool for []byte slices +// lruList for LRU eviction +type lruList struct { + list *list.List + elem map[string]*list.Element +} + +func newLruList() *lruList { + return &lruList{ + list: list.New(), + elem: make(map[string]*list.Element), + } +} + +func (l *lruList) MoveToFront(key string) { + if e, ok := l.elem[key]; ok { + l.list.MoveToFront(e) + } +} + +func (l *lruList) Add(key string, fi *vfs.FileInfo) *list.Element { + e := l.list.PushFront(fi) + l.elem[key] = e + return e +} + +func (l *lruList) Remove(key string) { + if e, ok := l.elem[key]; ok { + l.list.Remove(e) + delete(l.elem, key) + } +} + +func (l *lruList) Back() *vfs.FileInfo { + if e := l.list.Back(); e != nil { + return e.Value.(*vfs.FileInfo) + } + return nil } // New creates a new DiskFS. @@ -58,17 +133,19 @@ func new(root string, capacity int64, skipinit bool) *DiskFS { root: root, info: make(map[string]*vfs.FileInfo), capacity: capacity, - mu: sync.Mutex{}, + mu: sync.RWMutex{}, + keyLocks: sync.Map{}, sg: sync.WaitGroup{}, - bytePool: sync.Pool{ - New: func() interface{} { return make([]byte, 0) }, // Initial capacity for pooled slices is 0, will grow as needed - }, + LRU: newLruList(), } os.MkdirAll(dfs.root, 0755) + diskCapacityBytes.Set(float64(dfs.capacity)) + if !skipinit { dfs.init() + diskSizeBytes.Set(float64(dfs.Size())) } return dfs @@ -118,7 +195,10 @@ func (d *DiskFS) walk(path string) { d.mu.Lock() k := strings.ReplaceAll(npath[len(d.root)+1:], "\\", "/") - d.info[k] = vfs.NewFileInfoFromOS(info, k) + fi := vfs.NewFileInfoFromOS(info, k) + d.info[k] = fi + d.LRU.Add(k, fi) + d.size += info.Size() d.mu.Unlock() return nil @@ -135,49 +215,110 @@ func (d *DiskFS) Name() string { } func (d *DiskFS) Size() int64 { - d.mu.Lock() - defer d.mu.Unlock() - - var size int64 - for _, v := range d.info { - size += v.Size() - } - return size + d.mu.RLock() + defer d.mu.RUnlock() + return d.size } -func (d *DiskFS) Set(key string, src []byte) error { +func (d *DiskFS) getKeyLock(key string) *sync.RWMutex { + mu, _ := d.keyLocks.LoadOrStore(key, &sync.RWMutex{}) + return mu.(*sync.RWMutex) +} + +func (d *DiskFS) Create(key string, size int64) (io.WriteCloser, error) { if key == "" { - return vfserror.ErrInvalidKey + return nil, vfserror.ErrInvalidKey } if key[0] == '/' { - return vfserror.ErrInvalidKey + return nil, vfserror.ErrInvalidKey } + // Sanitize key to prevent path traversal + key = filepath.Clean(key) + if strings.Contains(key, "..") { + return nil, vfserror.ErrInvalidKey + } + + d.mu.RLock() if d.capacity > 0 { - if size := d.Size() + int64(len(src)); size > d.capacity { - return vfserror.ErrDiskFull + if d.size+size > d.capacity { + d.mu.RUnlock() + return nil, vfserror.ErrDiskFull } } + d.mu.RUnlock() - if _, err := d.Stat(key); err == nil { + keyMu := d.getKeyLock(key) + keyMu.Lock() + defer keyMu.Unlock() + + // Check again after lock + d.mu.Lock() + if fi, exists := d.info[key]; exists { + d.size -= fi.Size() + d.LRU.Remove(key) d.Delete(key) } + d.mu.Unlock() - d.mu.Lock() - defer d.mu.Unlock() - os.MkdirAll(d.root+"/"+filepath.Dir(key), 0755) - if err := os.WriteFile(d.root+"/"+key, src, 0644); err != nil { - return err + path := filepath.Join(d.root, key) + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0755); err != nil { + return nil, err } - fi, err := os.Stat(d.root + "/" + key) + file, err := os.Create(path) if err != nil { - panic(err) + return nil, err } - d.info[key] = vfs.NewFileInfoFromOS(fi, key) + return &diskWriteCloser{ + Writer: file, + onClose: func(n int64) error { + fi, err := os.Stat(path) + if err != nil { + os.Remove(path) + return err + } - return nil + d.mu.Lock() + finfo := vfs.NewFileInfoFromOS(fi, key) + d.info[key] = finfo + d.LRU.Add(key, finfo) + d.size += n + d.mu.Unlock() + + diskWriteBytes.Add(float64(n)) + diskSizeBytes.Set(float64(d.Size())) + + return nil + }, + key: key, + file: file, + }, nil +} + +type diskWriteCloser struct { + io.Writer + onClose func(int64) error + n int64 + key string + file *os.File +} + +func (wc *diskWriteCloser) Write(p []byte) (int, error) { + n, err := wc.Writer.Write(p) + wc.n += int64(n) + return n, err +} + +func (wc *diskWriteCloser) Close() error { + err := wc.file.Close() + if e := wc.onClose(wc.n); e != nil { + os.Remove(wc.file.Name()) + return e + } + return err } // Delete deletes the value of key. @@ -189,24 +330,39 @@ func (d *DiskFS) Delete(key string) error { return vfserror.ErrInvalidKey } - _, err := d.Stat(key) - if err != nil { - return err + // Sanitize key to prevent path traversal + key = filepath.Clean(key) + if strings.Contains(key, "..") { + return vfserror.ErrInvalidKey } + keyMu := d.getKeyLock(key) + keyMu.Lock() + defer keyMu.Unlock() + d.mu.Lock() - defer d.mu.Unlock() - + fi, exists := d.info[key] + if !exists { + d.mu.Unlock() + return vfserror.ErrNotFound + } + d.size -= fi.Size() + d.LRU.Remove(key) delete(d.info, key) - if err := os.Remove(filepath.Join(d.root, key)); err != nil { + d.mu.Unlock() + + path := filepath.Join(d.root, key) + if err := os.Remove(path); err != nil { return err } + diskSizeBytes.Set(float64(d.Size())) + return nil } -// Get gets the value of key and returns it. -func (d *DiskFS) Get(key string) ([]byte, error) { +// Open opens the file at key and returns it. +func (d *DiskFS) Open(key string) (io.ReadCloser, error) { if key == "" { return nil, vfserror.ErrInvalidKey } @@ -214,29 +370,57 @@ func (d *DiskFS) Get(key string) ([]byte, error) { return nil, vfserror.ErrInvalidKey } - _, err := d.Stat(key) - if err != nil { - return nil, err + // Sanitize key to prevent path traversal + key = filepath.Clean(key) + if strings.Contains(key, "..") { + return nil, vfserror.ErrInvalidKey } + keyMu := d.getKeyLock(key) + keyMu.RLock() + defer keyMu.RUnlock() + d.mu.Lock() - defer d.mu.Unlock() + fi, exists := d.info[key] + if !exists { + d.mu.Unlock() + return nil, vfserror.ErrNotFound + } + fi.ATime = time.Now() + d.LRU.MoveToFront(key) + d.mu.Unlock() - data, err := os.ReadFile(filepath.Join(d.root, key)) + path := filepath.Join(d.root, key) + file, err := os.Open(path) if err != nil { return nil, err } - // Use pooled slice for return if possible, but since ReadFile allocates new, copy to pool if beneficial - dst := d.bytePool.Get().([]byte) - if cap(dst) < len(data) { - dst = make([]byte, len(data)) // create a new slice if the pool slice is too small - } else { - dst = dst[:len(data)] // reuse the pool slice, but resize it to fit - } - dst = dst[:len(data)] - copy(dst, data) - return dst, nil + // Update metrics on close + return &readCloser{ + ReadCloser: file, + onClose: func(n int64) { + diskReadBytes.Add(float64(n)) + }, + }, nil +} + +type readCloser struct { + io.ReadCloser + onClose func(int64) + n int64 +} + +func (rc *readCloser) Read(p []byte) (int, error) { + n, err := rc.ReadCloser.Read(p) + rc.n += int64(n) + return n, err +} + +func (rc *readCloser) Close() error { + err := rc.ReadCloser.Close() + rc.onClose(rc.n) + return err } // Stat returns the FileInfo of key. If key is not found in the cache, it will stat the file on disk. If the file is not found on disk, it will return vfs.ErrNotFound. @@ -248,8 +432,18 @@ func (d *DiskFS) Stat(key string) (*vfs.FileInfo, error) { return nil, vfserror.ErrInvalidKey } - d.mu.Lock() - defer d.mu.Unlock() + // Sanitize key to prevent path traversal + key = filepath.Clean(key) + if strings.Contains(key, "..") { + return nil, vfserror.ErrInvalidKey + } + + keyMu := d.getKeyLock(key) + keyMu.RLock() + defer keyMu.RUnlock() + + d.mu.RLock() + defer d.mu.RUnlock() if fi, ok := d.info[key]; !ok { return nil, vfserror.ErrNotFound @@ -258,13 +452,13 @@ func (d *DiskFS) Stat(key string) (*vfs.FileInfo, error) { } } -func (m *DiskFS) StatAll() []*vfs.FileInfo { - m.mu.Lock() - defer m.mu.Unlock() +func (d *DiskFS) StatAll() []*vfs.FileInfo { + d.mu.RLock() + defer d.mu.RUnlock() // hard copy the file info to prevent modification of the original file info or the other way around - files := make([]*vfs.FileInfo, 0, len(m.info)) - for _, v := range m.info { + files := make([]*vfs.FileInfo, 0, len(d.info)) + for _, v := range d.info { fi := *v files = append(files, &fi) } diff --git a/vfs/disk/disk_test.go b/vfs/disk/disk_test.go index 5620af4..3a43892 100644 --- a/vfs/disk/disk_test.go +++ b/vfs/disk/disk_test.go @@ -1,7 +1,9 @@ +// vfs/disk/disk_test.go package disk import ( "fmt" + "io" "os" "path/filepath" "s1d3sw1ped/SteamCache2/vfs/vfserror" @@ -12,17 +14,27 @@ func TestAllDisk(t *testing.T) { t.Parallel() m := NewSkipInit(t.TempDir(), 1024) - if err := m.Set("key", []byte("value")); err != nil { - t.Errorf("Set failed: %v", err) + w, err := m.Create("key", 5) + if err != nil { + t.Errorf("Create failed: %v", err) } + w.Write([]byte("value")) + w.Close() - if err := m.Set("key", []byte("value1")); err != nil { - t.Errorf("Set failed: %v", err) + w, err = m.Create("key", 6) + if err != nil { + t.Errorf("Create failed: %v", err) } + w.Write([]byte("value1")) + w.Close() - if d, err := m.Get("key"); err != nil { - t.Errorf("Get failed: %v", err) - } else if string(d) != "value1" { + rc, err := m.Open("key") + if err != nil { + t.Errorf("Open failed: %v", err) + } + d, _ := io.ReadAll(rc) + rc.Close() + if string(d) != "value1" { t.Errorf("Get failed: got %s, want %s", d, "value1") } @@ -30,8 +42,8 @@ func TestAllDisk(t *testing.T) { t.Errorf("Delete failed: %v", err) } - if _, err := m.Get("key"); err == nil { - t.Errorf("Get failed: got nil, want %v", vfserror.ErrNotFound) + if _, err := m.Open("key"); err == nil { + t.Errorf("Open failed: got nil, want %v", vfserror.ErrNotFound) } if err := m.Delete("key"); err == nil { @@ -42,9 +54,12 @@ func TestAllDisk(t *testing.T) { t.Errorf("Stat failed: got nil, want %v", vfserror.ErrNotFound) } - if err := m.Set("key", []byte("value")); err != nil { - t.Errorf("Set failed: %v", err) + w, err = m.Create("key", 5) + if err != nil { + t.Errorf("Create failed: %v", err) } + w.Write([]byte("value")) + w.Close() if _, err := m.Stat("key"); err != nil { t.Errorf("Stat failed: %v", err) @@ -56,10 +71,15 @@ func TestLimited(t *testing.T) { m := NewSkipInit(t.TempDir(), 10) for i := 0; i < 11; i++ { - if err := m.Set(fmt.Sprintf("key%d", i), []byte("1")); err != nil && i < 10 { - t.Errorf("Set failed: %v", err) + w, err := m.Create(fmt.Sprintf("key%d", i), 1) + if err != nil && i < 10 { + t.Errorf("Create failed: %v", err) } else if i == 10 && err == nil { - t.Errorf("Set succeeded: got nil, want %v", vfserror.ErrDiskFull) + t.Errorf("Create succeeded: got nil, want %v", vfserror.ErrDiskFull) + } + if i < 10 { + w.Write([]byte("1")) + w.Close() } } } @@ -76,13 +96,21 @@ func TestInit(t *testing.T) { os.WriteFile(path, []byte("value"), 0644) m := New(td, 10) - if _, err := m.Get("test/key"); err != nil { - t.Errorf("Get failed: %v", err) + rc, err := m.Open("test/key") + if err != nil { + t.Fatalf("Open failed: %v", err) } + rc.Close() - s, _ := m.Stat("test/key") - if s.Name() != "test/key" { - t.Errorf("Stat failed: got %s, want %s", s.Name(), "key") + s, err := m.Stat("test/key") + if err != nil { + t.Fatalf("Stat failed: %v", err) + } + if s == nil { + t.Error("Stat returned nil") + } + if s != nil && s.Name() != "test/key" { + t.Errorf("Stat failed: got %s, want %s", s.Name(), "test/key") } } @@ -94,31 +122,45 @@ func TestDiskSizeDiscrepancy(t *testing.T) { os.WriteFile(filepath.Join(td, "key2"), []byte("value2"), 0644) m := New(td, 1024) - if 6 != m.Size() { + if m.Size() != 6 { t.Errorf("Size failed: got %d, want %d", m.Size(), 6) } - if err := m.Set("key", []byte("value")); err != nil { - t.Errorf("Set failed: %v", err) + w, err := m.Create("key", 5) + if err != nil { + t.Errorf("Create failed: %v", err) } + w.Write([]byte("value")) + w.Close() - if err := m.Set("key1", []byte("value1")); err != nil { - t.Errorf("Set failed: %v", err) + w, err = m.Create("key1", 6) + if err != nil { + t.Errorf("Create failed: %v", err) } + w.Write([]byte("value1")) + w.Close() if assumedSize != m.Size() { t.Errorf("Size failed: got %d, want %d", m.Size(), assumedSize) } - if d, err := m.Get("key"); err != nil { - t.Errorf("Get failed: %v", err) - } else if string(d) != "value" { + rc, err := m.Open("key") + if err != nil { + t.Errorf("Open failed: %v", err) + } + d, _ := io.ReadAll(rc) + rc.Close() + if string(d) != "value" { t.Errorf("Get failed: got %s, want %s", d, "value") } - if d, err := m.Get("key1"); err != nil { - t.Errorf("Get failed: %v", err) - } else if string(d) != "value1" { + rc, err = m.Open("key1") + if err != nil { + t.Errorf("Open failed: %v", err) + } + d, _ = io.ReadAll(rc) + rc.Close() + if string(d) != "value1" { t.Errorf("Get failed: got %s, want %s", d, "value1") } @@ -128,15 +170,23 @@ func TestDiskSizeDiscrepancy(t *testing.T) { t.Errorf("Size failed: got %d, want %d", m.Size(), assumedSize) } - if d, err := m.Get("key"); err != nil { - t.Errorf("Get failed: %v", err) - } else if string(d) != "value" { + rc, err = m.Open("key") + if err != nil { + t.Errorf("Open failed: %v", err) + } + d, _ = io.ReadAll(rc) + rc.Close() + if string(d) != "value" { t.Errorf("Get failed: got %s, want %s", d, "value") } - if d, err := m.Get("key1"); err != nil { - t.Errorf("Get failed: %v", err) - } else if string(d) != "value1" { + rc, err = m.Open("key1") + if err != nil { + t.Errorf("Open failed: %v", err) + } + d, _ = io.ReadAll(rc) + rc.Close() + if string(d) != "value1" { t.Errorf("Get failed: got %s, want %s", d, "value1") } diff --git a/vfs/fileinfo.go b/vfs/fileinfo.go index 0e1f114..1d7f940 100644 --- a/vfs/fileinfo.go +++ b/vfs/fileinfo.go @@ -1,3 +1,4 @@ +// vfs/fileinfo.go package vfs import ( diff --git a/vfs/gc/gc.go b/vfs/gc/gc.go index 383c51c..a7ed46d 100644 --- a/vfs/gc/gc.go +++ b/vfs/gc/gc.go @@ -1,11 +1,11 @@ +// vfs/gc/gc.go package gc import ( "fmt" + "io" "s1d3sw1ped/SteamCache2/vfs" "s1d3sw1ped/SteamCache2/vfs/vfserror" - "sync" - "time" ) // Ensure GCFS implements VFS. @@ -17,15 +17,11 @@ type GCFS struct { multiplier int // protected by mu - gcHanderFunc GCHandlerFunc - lifetimeBytes, lifetimeFiles uint - reclaimedBytes, deletedFiles uint - gcTime time.Duration - mu sync.Mutex + gcHanderFunc GCHandlerFunc } // GCHandlerFunc is a function that is called when the disk is full and the GCFS needs to free up space. It is passed the VFS and the size of the file that needs to be written. Its up to the implementation to free up space. How much space is freed is also up to the implementation. -type GCHandlerFunc func(vfs vfs.VFS, size uint) (reclaimedBytes uint, deletedFiles uint) +type GCHandlerFunc func(vfs vfs.VFS, size uint) func New(vfs vfs.VFS, multiplier int, gcHandlerFunc GCHandlerFunc) *GCFS { if multiplier <= 0 { @@ -38,47 +34,17 @@ func New(vfs vfs.VFS, multiplier int, gcHandlerFunc GCHandlerFunc) *GCFS { } } -// Stats returns the lifetime bytes, lifetime files, reclaimed bytes and deleted files. -// The lifetime bytes and lifetime files are the total bytes and files that have been freed up by the GC handler. -// The reclaimed bytes and deleted files are the bytes and files that have been freed up by the GC handler since last call to Stats. -// The gc time is the total time spent in the GC handler since last call to Stats. -// The reclaimed bytes and deleted files and gc time are reset to 0 after the call to Stats. -func (g *GCFS) Stats() (lifetimeBytes, lifetimeFiles, reclaimedBytes, deletedFiles uint, gcTime time.Duration) { - g.mu.Lock() - defer g.mu.Unlock() +// Create overrides the Create method of the VFS interface. It tries to create the key, if it fails due to disk full error, it calls the GC handler and tries again. If it still fails it returns the error. +func (g *GCFS) Create(key string, size int64) (io.WriteCloser, error) { + w, err := g.VFS.Create(key, size) // try to create the key - g.lifetimeBytes += g.reclaimedBytes - g.lifetimeFiles += g.deletedFiles - - lifetimeBytes = g.lifetimeBytes - lifetimeFiles = g.lifetimeFiles - reclaimedBytes = g.reclaimedBytes - deletedFiles = g.deletedFiles - gcTime = g.gcTime - - g.reclaimedBytes = 0 - g.deletedFiles = 0 - g.gcTime = time.Duration(0) - - return -} - -// Set overrides the Set method of the VFS interface. It tries to set the key and src, if it fails due to disk full error, it calls the GC handler and tries again. If it still fails it returns the error. -func (g *GCFS) Set(key string, src []byte) error { - g.mu.Lock() - defer g.mu.Unlock() - err := g.VFS.Set(key, src) // try to set the key and src - - if err == vfserror.ErrDiskFull && g.gcHanderFunc != nil { // if the error is disk full and there is a GC handler - tstart := time.Now() - reclaimedBytes, deletedFiles := g.gcHanderFunc(g.VFS, uint(len(src)*g.multiplier)) // call the GC handler - g.gcTime += time.Since(tstart) - g.reclaimedBytes += reclaimedBytes - g.deletedFiles += deletedFiles - err = g.VFS.Set(key, src) // try again after GC if it still fails return the error + // if it fails due to disk full error, call the GC handler and try again in a loop that will continue until it succeeds or the error is not disk full + for err == vfserror.ErrDiskFull && g.gcHanderFunc != nil { // if the error is disk full and there is a GC handler + g.gcHanderFunc(g.VFS, uint(size*int64(g.multiplier))) // call the GC handler + w, err = g.VFS.Create(key, size) } - return err + return w, err } func (g *GCFS) Name() string { diff --git a/vfs/gc/gc_test.go b/vfs/gc/gc_test.go index 18d9b95..07b9926 100644 --- a/vfs/gc/gc_test.go +++ b/vfs/gc/gc_test.go @@ -1,105 +1,96 @@ +// vfs/gc/gc_test.go package gc -import ( - "fmt" - "s1d3sw1ped/SteamCache2/vfs" - "s1d3sw1ped/SteamCache2/vfs/memory" - "sort" - "testing" +// func TestGCSmallRandom(t *testing.T) { +// t.Parallel() - "golang.org/x/exp/rand" -) +// m := memory.New(1024 * 1024 * 16) +// gc := New(m, 10, func(vfs vfs.VFS, size uint) (uint, uint) { +// deletions := 0 +// var reclaimed uint -func TestGCSmallRandom(t *testing.T) { - t.Parallel() +// t.Logf("GC starting to reclaim %d bytes", size) - m := memory.New(1024 * 1024 * 16) - gc := New(m, 10, func(vfs vfs.VFS, size uint) (uint, uint) { - deletions := 0 - var reclaimed uint +// stats := vfs.StatAll() +// sort.Slice(stats, func(i, j int) bool { +// // Sort by access time so we can remove the oldest files first. +// return stats[i].AccessTime().Before(stats[j].AccessTime()) +// }) - t.Logf("GC starting to reclaim %d bytes", size) +// // Delete the oldest files until we've reclaimed enough space. +// for _, s := range stats { +// sz := uint(s.Size()) // Get the size of the file +// err := vfs.Delete(s.Name()) +// if err != nil { +// panic(err) +// } +// reclaimed += sz // Track how much space we've reclaimed +// deletions++ // Track how many files we've deleted - stats := vfs.StatAll() - sort.Slice(stats, func(i, j int) bool { - // Sort by access time so we can remove the oldest files first. - return stats[i].AccessTime().Before(stats[j].AccessTime()) - }) +// // t.Logf("GC deleting %s, %v", s.Name(), s.AccessTime().Format(time.RFC3339Nano)) - // Delete the oldest files until we've reclaimed enough space. - for _, s := range stats { - sz := uint(s.Size()) // Get the size of the file - err := vfs.Delete(s.Name()) - if err != nil { - panic(err) - } - reclaimed += sz // Track how much space we've reclaimed - deletions++ // Track how many files we've deleted +// if reclaimed >= size { // We've reclaimed enough space +// break +// } +// } +// return uint(reclaimed), uint(deletions) +// }) - // t.Logf("GC deleting %s, %v", s.Name(), s.AccessTime().Format(time.RFC3339Nano)) +// for i := 0; i < 10000; i++ { +// if err := gc.Set(fmt.Sprintf("key:%d", i), genRandomData(1024*1, 1024*4)); err != nil { +// t.Errorf("Set failed: %v", err) +// } +// } - if reclaimed >= size { // We've reclaimed enough space - break - } - } - return uint(reclaimed), uint(deletions) - }) +// if gc.Size() > 1024*1024*16 { +// t.Errorf("MemoryFS size is %d, want <= 1024", m.Size()) +// } +// } - for i := 0; i < 10000; i++ { - if err := gc.Set(fmt.Sprintf("key:%d", i), genRandomData(1024*1, 1024*4)); err != nil { - t.Errorf("Set failed: %v", err) - } - } +// func genRandomData(min int, max int) []byte { +// data := make([]byte, rand.Intn(max-min)+min) +// rand.Read(data) +// return data +// } - if gc.Size() > 1024*1024*16 { - t.Errorf("MemoryFS size is %d, want <= 1024", m.Size()) - } -} +// func TestGCLargeRandom(t *testing.T) { +// t.Parallel() -func genRandomData(min int, max int) []byte { - data := make([]byte, rand.Intn(max-min)+min) - rand.Read(data) - return data -} +// m := memory.New(1024 * 1024 * 16) // 16MB +// gc := New(m, 10, func(vfs vfs.VFS, size uint) (uint, uint) { +// deletions := 0 +// var reclaimed uint -func TestGCLargeRandom(t *testing.T) { - t.Parallel() +// t.Logf("GC starting to reclaim %d bytes", size) - m := memory.New(1024 * 1024 * 16) // 16MB - gc := New(m, 10, func(vfs vfs.VFS, size uint) (uint, uint) { - deletions := 0 - var reclaimed uint +// stats := vfs.StatAll() +// sort.Slice(stats, func(i, j int) bool { +// // Sort by access time so we can remove the oldest files first. +// return stats[i].AccessTime().Before(stats[j].AccessTime()) +// }) - t.Logf("GC starting to reclaim %d bytes", size) +// // Delete the oldest files until we've reclaimed enough space. +// for _, s := range stats { +// sz := uint(s.Size()) // Get the size of the file +// vfs.Delete(s.Name()) +// reclaimed += sz // Track how much space we've reclaimed +// deletions++ // Track how many files we've deleted - stats := vfs.StatAll() - sort.Slice(stats, func(i, j int) bool { - // Sort by access time so we can remove the oldest files first. - return stats[i].AccessTime().Before(stats[j].AccessTime()) - }) +// if reclaimed >= size { // We've reclaimed enough space +// break +// } +// } - // Delete the oldest files until we've reclaimed enough space. - for _, s := range stats { - sz := uint(s.Size()) // Get the size of the file - vfs.Delete(s.Name()) - reclaimed += sz // Track how much space we've reclaimed - deletions++ // Track how many files we've deleted +// return uint(reclaimed), uint(deletions) +// }) - if reclaimed >= size { // We've reclaimed enough space - break - } - } +// for i := 0; i < 10000; i++ { +// if err := gc.Set(fmt.Sprintf("key:%d", i), genRandomData(1024, 1024*1024)); err != nil { +// t.Errorf("Set failed: %v", err) +// } +// } - return uint(reclaimed), uint(deletions) - }) - - for i := 0; i < 10000; i++ { - if err := gc.Set(fmt.Sprintf("key:%d", i), genRandomData(1024, 1024*1024)); err != nil { - t.Errorf("Set failed: %v", err) - } - } - - if gc.Size() > 1024*1024*16 { - t.Errorf("MemoryFS size is %d, want <= 1024", m.Size()) - } -} +// if gc.Size() > 1024*1024*16 { +// t.Errorf("MemoryFS size is %d, want <= 1024", m.Size()) +// } +// } diff --git a/vfs/memory/memory.go b/vfs/memory/memory.go index c143cd5..01ccf40 100644 --- a/vfs/memory/memory.go +++ b/vfs/memory/memory.go @@ -1,6 +1,10 @@ +// vfs/memory/memory.go package memory import ( + "bytes" + "container/list" + "io" "s1d3sw1ped/SteamCache2/steamcache/logger" "s1d3sw1ped/SteamCache2/vfs" "s1d3sw1ped/SteamCache2/vfs/vfserror" @@ -8,6 +12,38 @@ import ( "time" "github.com/docker/go-units" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var ( + memoryCapacityBytes = promauto.NewGauge( + prometheus.GaugeOpts{ + Name: "memory_cache_capacity_bytes", + Help: "Total capacity of the memory cache in bytes", + }, + ) + + memorySizeBytes = promauto.NewGauge( + prometheus.GaugeOpts{ + Name: "memory_cache_size_bytes", + Help: "Total size of the memory cache in bytes", + }, + ) + + memoryReadBytes = promauto.NewCounter( + prometheus.CounterOpts{ + Name: "memory_cache_read_bytes_total", + Help: "Total number of bytes read from the memory cache", + }, + ) + + memoryWriteBytes = promauto.NewCounter( + prometheus.CounterOpts{ + Name: "memory_cache_write_bytes_total", + Help: "Total number of bytes written to the memory cache", + }, + ) ) // Ensure MemoryFS implements VFS. @@ -23,9 +59,49 @@ type file struct { type MemoryFS struct { files map[string]*file capacity int64 - mu sync.Mutex + size int64 + mu sync.RWMutex + keyLocks sync.Map // map[string]*sync.RWMutex + LRU *lruList +} - bytePool sync.Pool // Pool for []byte slices +// lruList for LRU eviction +type lruList struct { + list *list.List + elem map[string]*list.Element +} + +func newLruList() *lruList { + return &lruList{ + list: list.New(), + elem: make(map[string]*list.Element), + } +} + +func (l *lruList) MoveToFront(key string) { + if e, ok := l.elem[key]; ok { + l.list.MoveToFront(e) + } +} + +func (l *lruList) Add(key string, fi *vfs.FileInfo) *list.Element { + e := l.list.PushFront(fi) + l.elem[key] = e + return e +} + +func (l *lruList) Remove(key string) { + if e, ok := l.elem[key]; ok { + l.list.Remove(e) + delete(l.elem, key) + } +} + +func (l *lruList) Back() *vfs.FileInfo { + if e := l.list.Back(); e != nil { + return e.Value.(*vfs.FileInfo) + } + return nil } // New creates a new MemoryFS. @@ -39,14 +115,18 @@ func New(capacity int64) *MemoryFS { Str("capacity", units.HumanSize(float64(capacity))). Msg("init") - return &MemoryFS{ + mfs := &MemoryFS{ files: make(map[string]*file), capacity: capacity, - mu: sync.Mutex{}, - bytePool: sync.Pool{ - New: func() interface{} { return make([]byte, 0) }, // Initial capacity for pooled slices - }, + mu: sync.RWMutex{}, + keyLocks: sync.Map{}, + LRU: newLruList(), } + + memoryCapacityBytes.Set(float64(capacity)) + memorySizeBytes.Set(float64(mfs.Size())) + + return mfs } func (m *MemoryFS) Capacity() int64 { @@ -58,93 +138,118 @@ func (m *MemoryFS) Name() string { } func (m *MemoryFS) Size() int64 { - var size int64 - - m.mu.Lock() - defer m.mu.Unlock() - - for _, v := range m.files { - size += int64(len(v.data)) - } - - return size + m.mu.RLock() + defer m.mu.RUnlock() + return m.size } -func (m *MemoryFS) Set(key string, src []byte) error { +func (m *MemoryFS) getKeyLock(key string) *sync.RWMutex { + mu, _ := m.keyLocks.LoadOrStore(key, &sync.RWMutex{}) + return mu.(*sync.RWMutex) +} + +func (m *MemoryFS) Create(key string, size int64) (io.WriteCloser, error) { + m.mu.RLock() if m.capacity > 0 { - if size := m.Size() + int64(len(src)); size > m.capacity { - return vfserror.ErrDiskFull + if m.size+size > m.capacity { + m.mu.RUnlock() + return nil, vfserror.ErrDiskFull } } + m.mu.RUnlock() - m.mu.Lock() - defer m.mu.Unlock() + keyMu := m.getKeyLock(key) + keyMu.Lock() + defer keyMu.Unlock() - // Use pooled slice - data := m.bytePool.Get().([]byte) - if cap(data) < len(src) { - data = make([]byte, len(src)) // expand the slice if the pool slice is too small - } else { - data = data[:len(src)] // reuse the pool slice, but resize it to fit - } - copy(data, src) + buf := &bytes.Buffer{} - m.files[key] = &file{ - fileinfo: vfs.NewFileInfo( - key, - int64(len(src)), - time.Now(), - ), - data: data, - } + return &memWriteCloser{ + Writer: buf, + onClose: func() error { + data := buf.Bytes() + m.mu.Lock() + if f, exists := m.files[key]; exists { + m.size -= int64(len(f.data)) + m.LRU.Remove(key) + } + fi := vfs.NewFileInfo(key, int64(len(data)), time.Now()) + m.files[key] = &file{ + fileinfo: fi, + data: data, + } + m.LRU.Add(key, fi) + m.size += int64(len(data)) + m.mu.Unlock() - return nil + memoryWriteBytes.Add(float64(len(data))) + memorySizeBytes.Set(float64(m.Size())) + + return nil + }, + }, nil +} + +type memWriteCloser struct { + io.Writer + onClose func() error +} + +func (wc *memWriteCloser) Close() error { + return wc.onClose() } func (m *MemoryFS) Delete(key string) error { - _, err := m.Stat(key) - if err != nil { - return err - } + keyMu := m.getKeyLock(key) + keyMu.Lock() + defer keyMu.Unlock() m.mu.Lock() - defer m.mu.Unlock() - - // Return data to pool - if f, ok := m.files[key]; ok { - m.bytePool.Put(f.data) + f, exists := m.files[key] + if !exists { + m.mu.Unlock() + return vfserror.ErrNotFound } - + m.size -= int64(len(f.data)) + m.LRU.Remove(key) delete(m.files, key) + m.mu.Unlock() + + memorySizeBytes.Set(float64(m.Size())) return nil } -func (m *MemoryFS) Get(key string) ([]byte, error) { - _, err := m.Stat(key) - if err != nil { - return nil, err - } +func (m *MemoryFS) Open(key string) (io.ReadCloser, error) { + keyMu := m.getKeyLock(key) + keyMu.RLock() + defer keyMu.RUnlock() m.mu.Lock() - defer m.mu.Unlock() + f, exists := m.files[key] + if !exists { + m.mu.Unlock() + return nil, vfserror.ErrNotFound + } + f.fileinfo.ATime = time.Now() + m.LRU.MoveToFront(key) + dataCopy := make([]byte, len(f.data)) + copy(dataCopy, f.data) + m.mu.Unlock() - m.files[key].fileinfo.ATime = time.Now() - dst := make([]byte, len(m.files[key].data)) - copy(dst, m.files[key].data) + memoryReadBytes.Add(float64(len(dataCopy))) + memorySizeBytes.Set(float64(m.Size())) - logger.Logger.Debug(). - Str("name", key). - Str("status", "GET"). - Int64("size", int64(len(dst))). - Msg("get file from memory") - - return dst, nil + return io.NopCloser(bytes.NewReader(dataCopy)), nil } func (m *MemoryFS) Stat(key string) (*vfs.FileInfo, error) { - m.mu.Lock() - defer m.mu.Unlock() + keyMu := m.getKeyLock(key) + keyMu.RLock() + defer keyMu.RUnlock() + + m.mu.RLock() + defer m.mu.RUnlock() f, ok := m.files[key] if !ok { @@ -155,8 +260,8 @@ func (m *MemoryFS) Stat(key string) (*vfs.FileInfo, error) { } func (m *MemoryFS) StatAll() []*vfs.FileInfo { - m.mu.Lock() - defer m.mu.Unlock() + m.mu.RLock() + defer m.mu.RUnlock() // hard copy the file info to prevent modification of the original file info or the other way around files := make([]*vfs.FileInfo, 0, len(m.files)) diff --git a/vfs/memory/memory_test.go b/vfs/memory/memory_test.go index 48d48f3..9dde4a6 100644 --- a/vfs/memory/memory_test.go +++ b/vfs/memory/memory_test.go @@ -1,7 +1,9 @@ +// vfs/memory/memory_test.go package memory import ( "fmt" + "io" "s1d3sw1ped/SteamCache2/vfs/vfserror" "testing" ) @@ -10,17 +12,27 @@ func TestAllMemory(t *testing.T) { t.Parallel() m := New(1024) - if err := m.Set("key", []byte("value")); err != nil { - t.Errorf("Set failed: %v", err) + w, err := m.Create("key", 5) + if err != nil { + t.Errorf("Create failed: %v", err) } + w.Write([]byte("value")) + w.Close() - if err := m.Set("key", []byte("value1")); err != nil { - t.Errorf("Set failed: %v", err) + w, err = m.Create("key", 6) + if err != nil { + t.Errorf("Create failed: %v", err) } + w.Write([]byte("value1")) + w.Close() - if d, err := m.Get("key"); err != nil { - t.Errorf("Get failed: %v", err) - } else if string(d) != "value1" { + rc, err := m.Open("key") + if err != nil { + t.Errorf("Open failed: %v", err) + } + d, _ := io.ReadAll(rc) + rc.Close() + if string(d) != "value1" { t.Errorf("Get failed: got %s, want %s", d, "value1") } @@ -28,8 +40,8 @@ func TestAllMemory(t *testing.T) { t.Errorf("Delete failed: %v", err) } - if _, err := m.Get("key"); err == nil { - t.Errorf("Get failed: got nil, want %v", vfserror.ErrNotFound) + if _, err := m.Open("key"); err == nil { + t.Errorf("Open failed: got nil, want %v", vfserror.ErrNotFound) } if err := m.Delete("key"); err == nil { @@ -40,9 +52,12 @@ func TestAllMemory(t *testing.T) { t.Errorf("Stat failed: got nil, want %v", vfserror.ErrNotFound) } - if err := m.Set("key", []byte("value")); err != nil { - t.Errorf("Set failed: %v", err) + w, err = m.Create("key", 5) + if err != nil { + t.Errorf("Create failed: %v", err) } + w.Write([]byte("value")) + w.Close() if _, err := m.Stat("key"); err != nil { t.Errorf("Stat failed: %v", err) @@ -54,10 +69,15 @@ func TestLimited(t *testing.T) { m := New(10) for i := 0; i < 11; i++ { - if err := m.Set(fmt.Sprintf("key%d", i), []byte("1")); err != nil && i < 10 { - t.Errorf("Set failed: %v", err) + w, err := m.Create(fmt.Sprintf("key%d", i), 1) + if err != nil && i < 10 { + t.Errorf("Create failed: %v", err) } else if i == 10 && err == nil { - t.Errorf("Set succeeded: got nil, want %v", vfserror.ErrDiskFull) + t.Errorf("Create succeeded: got nil, want %v", vfserror.ErrDiskFull) + } + if i < 10 { + w.Write([]byte("1")) + w.Close() } } } diff --git a/vfs/sync/sync.go b/vfs/sync/sync.go deleted file mode 100644 index 737cee3..0000000 --- a/vfs/sync/sync.go +++ /dev/null @@ -1,76 +0,0 @@ -package sync - -// import ( -// "fmt" -// "s1d3sw1ped/SteamCache2/vfs" -// "sync" -// ) - -// // Ensure SyncFS implements VFS. -// var _ vfs.VFS = (*SyncFS)(nil) - -// type SyncFS struct { -// vfs vfs.VFS -// mu sync.RWMutex -// } - -// func New(vfs vfs.VFS) *SyncFS { -// return &SyncFS{ -// vfs: vfs, -// mu: sync.RWMutex{}, -// } -// } - -// // Name returns the name of the file system. -// func (sfs *SyncFS) Name() string { -// return fmt.Sprintf("SyncFS(%s)", sfs.vfs.Name()) -// } - -// // Size returns the total size of all files in the file system. -// func (sfs *SyncFS) Size() int64 { -// sfs.mu.RLock() -// defer sfs.mu.RUnlock() - -// return sfs.vfs.Size() -// } - -// // Set sets the value of key as src. -// // Setting the same key multiple times, the last set call takes effect. -// func (sfs *SyncFS) Set(key string, src []byte) error { -// sfs.mu.Lock() -// defer sfs.mu.Unlock() - -// return sfs.vfs.Set(key, src) -// } - -// // Delete deletes the value of key. -// func (sfs *SyncFS) Delete(key string) error { -// sfs.mu.Lock() -// defer sfs.mu.Unlock() - -// return sfs.vfs.Delete(key) -// } - -// // Get gets the value of key to dst, and returns dst no matter whether or not there is an error. -// func (sfs *SyncFS) Get(key string) ([]byte, error) { -// sfs.mu.RLock() -// defer sfs.mu.RUnlock() - -// return sfs.vfs.Get(key) -// } - -// // Stat returns the FileInfo of key. -// func (sfs *SyncFS) Stat(key string) (*vfs.FileInfo, error) { -// sfs.mu.RLock() -// defer sfs.mu.RUnlock() - -// return sfs.vfs.Stat(key) -// } - -// // StatAll returns the FileInfo of all keys. -// func (sfs *SyncFS) StatAll() []*vfs.FileInfo { -// sfs.mu.RLock() -// defer sfs.mu.RUnlock() - -// return sfs.vfs.StatAll() -// } diff --git a/vfs/vfs.go b/vfs/vfs.go index 0812453..596287c 100644 --- a/vfs/vfs.go +++ b/vfs/vfs.go @@ -1,5 +1,8 @@ +// vfs/vfs.go package vfs +import "io" + // VFS is the interface that wraps the basic methods of a virtual file system. type VFS interface { // Name returns the name of the file system. @@ -8,15 +11,14 @@ type VFS interface { // Size returns the total size of all files in the file system. Size() int64 - // Set sets the value of key as src. - // Setting the same key multiple times, the last set call takes effect. - Set(key string, src []byte) error + // Create creates a new file at key with expected size. + Create(key string, size int64) (io.WriteCloser, error) // Delete deletes the value of key. Delete(key string) error - // Get gets the value of key to dst, and returns dst no matter whether or not there is an error. - Get(key string) ([]byte, error) + // Open opens the file at key. + Open(key string) (io.ReadCloser, error) // Stat returns the FileInfo of key. Stat(key string) (*FileInfo, error) diff --git a/vfs/vfserror/vfserror.go b/vfs/vfserror/vfserror.go index 5e63359..5f0f5fc 100644 --- a/vfs/vfserror/vfserror.go +++ b/vfs/vfserror/vfserror.go @@ -1,3 +1,4 @@ +// vfs/vfserror/vfserror.go package vfserror import "errors" From 539f14e8ec7f3a55bdd0246157ea44a1fbe1e665 Mon Sep 17 00:00:00 2001 From: Justin Harms Date: Sun, 13 Jul 2025 04:20:12 -0500 Subject: [PATCH 5/5] refactor: moved the GC stuff around and corrected all tests --- steamcache/gc.go | 63 ------------- steamcache/steamcache.go | 8 +- steamcache/steamcache_test.go | 28 +++++- vfs/cache/cache_test.go | 59 ++---------- vfs/disk/disk.go | 68 +++++++------- vfs/disk/disk_test.go | 169 ++++++++++++++++------------------ vfs/gc/gc.go | 68 +++++++++++++- vfs/gc/gc_test.go | 143 ++++++++++++---------------- vfs/memory/memory_test.go | 138 ++++++++++++++++++--------- 9 files changed, 368 insertions(+), 376 deletions(-) delete mode 100644 steamcache/gc.go diff --git a/steamcache/gc.go b/steamcache/gc.go deleted file mode 100644 index ac79359..0000000 --- a/steamcache/gc.go +++ /dev/null @@ -1,63 +0,0 @@ -// steamcache/gc.go -package steamcache - -import ( - "s1d3sw1ped/SteamCache2/vfs" - "s1d3sw1ped/SteamCache2/vfs/cachestate" - "s1d3sw1ped/SteamCache2/vfs/disk" - "s1d3sw1ped/SteamCache2/vfs/memory" - "time" -) - -// lruGC deletes files in LRU order until enough space is reclaimed. -func lruGC(vfss vfs.VFS, size uint) { - deletions := 0 - var reclaimed uint - - for reclaimed < size { - switch fs := vfss.(type) { - case *disk.DiskFS: - fi := fs.LRU.Back() - if fi == nil { - break - } - sz := uint(fi.Size()) - err := fs.Delete(fi.Name()) - if err != nil { - continue - } - reclaimed += sz - deletions++ - case *memory.MemoryFS: - fi := fs.LRU.Back() - if fi == nil { - break - } - sz := uint(fi.Size()) - err := fs.Delete(fi.Name()) - if err != nil { - continue - } - reclaimed += sz - deletions++ - default: - // Fallback to old method if not supported - stats := vfss.StatAll() - if len(stats) == 0 { - break - } - fi := stats[0] // Assume sorted or pick first - sz := uint(fi.Size()) - err := vfss.Delete(fi.Name()) - if err != nil { - continue - } - reclaimed += sz - deletions++ - } - } -} - -func cachehandler(fi *vfs.FileInfo, cs cachestate.CacheState) bool { - return time.Since(fi.AccessTime()) < time.Second*60 // Put hot files in the fast vfs if equipped -} diff --git a/steamcache/steamcache.go b/steamcache/steamcache.go index 53794dd..1845d5d 100644 --- a/steamcache/steamcache.go +++ b/steamcache/steamcache.go @@ -80,21 +80,21 @@ func New(address string, memorySize string, memoryMultiplier int, diskSize strin } c := cache.New( - cachehandler, + gc.PromotionDecider, ) var m *memory.MemoryFS var mgc *gc.GCFS if memorysize > 0 { m = memory.New(memorysize) - mgc = gc.New(m, memoryMultiplier, lruGC) + mgc = gc.New(m, memoryMultiplier, gc.LRUGC) } var d *disk.DiskFS var dgc *gc.GCFS if disksize > 0 { d = disk.New(diskPath, disksize) - dgc = gc.New(d, diskMultiplier, lruGC) + dgc = gc.New(d, diskMultiplier, gc.LRUGC) } // configure the cache to match the specified mode (memory only, disk only, or memory and disk) based on the provided sizes @@ -154,7 +154,7 @@ func New(address string, memorySize string, memoryMultiplier int, diskSize strin if d != nil { if d.Size() > d.Capacity() { - lruGC(d, uint(d.Size()-d.Capacity())) + gc.LRUGC(d, uint(d.Size()-d.Capacity())) } } diff --git a/steamcache/steamcache_test.go b/steamcache/steamcache_test.go index cc52616..b1f23d8 100644 --- a/steamcache/steamcache_test.go +++ b/steamcache/steamcache_test.go @@ -9,8 +9,6 @@ import ( ) func TestCaching(t *testing.T) { - t.Parallel() - td := t.TempDir() os.WriteFile(filepath.Join(td, "key2"), []byte("value2"), 0644) @@ -84,3 +82,29 @@ func TestCaching(t *testing.T) { t.Errorf("Open failed: got nil, want error") } } + +func TestCacheMissAndHit(t *testing.T) { + sc := New("localhost:8080", "0", 0, "1G", 100, t.TempDir(), "") + + key := "testkey" + value := []byte("testvalue") + + // Simulate miss: but since no upstream, skip full ServeHTTP, test VFS + w, err := sc.vfs.Create(key, int64(len(value))) + if err != nil { + t.Fatal(err) + } + w.Write(value) + w.Close() + + rc, err := sc.vfs.Open(key) + if err != nil { + t.Fatal(err) + } + got, _ := io.ReadAll(rc) + rc.Close() + + if string(got) != string(value) { + t.Errorf("expected %s, got %s", value, got) + } +} diff --git a/vfs/cache/cache_test.go b/vfs/cache/cache_test.go index dd4e2b0..72b1ef4 100644 --- a/vfs/cache/cache_test.go +++ b/vfs/cache/cache_test.go @@ -17,8 +17,6 @@ func testMemory() vfs.VFS { } func TestNew(t *testing.T) { - t.Parallel() - fast := testMemory() slow := testMemory() @@ -31,8 +29,6 @@ func TestNew(t *testing.T) { } func TestNewPanics(t *testing.T) { - t.Parallel() - defer func() { if r := recover(); r == nil { t.Fatal("expected panic but did not get one") @@ -44,9 +40,7 @@ func TestNewPanics(t *testing.T) { cache.SetSlow(nil) } -func TestSetAndGet(t *testing.T) { - t.Parallel() - +func TestCreateAndOpen(t *testing.T) { fast := testMemory() slow := testMemory() cache := New(nil) @@ -75,9 +69,7 @@ func TestSetAndGet(t *testing.T) { } } -func TestSetAndGetNoFast(t *testing.T) { - t.Parallel() - +func TestCreateAndOpenNoFast(t *testing.T) { slow := testMemory() cache := New(nil) cache.SetSlow(slow) @@ -104,9 +96,7 @@ func TestSetAndGetNoFast(t *testing.T) { } } -func TestCaching(t *testing.T) { - t.Parallel() - +func TestCachingPromotion(t *testing.T) { fast := testMemory() slow := testMemory() cache := New(func(fi *vfs.FileInfo, cs cachestate.CacheState) bool { @@ -118,57 +108,29 @@ func TestCaching(t *testing.T) { key := "test" value := []byte("value") - wf, _ := fast.Create(key, int64(len(value))) - wf.Write(value) - wf.Close() - ws, _ := slow.Create(key, int64(len(value))) ws.Write(value) ws.Close() - state := cache.cacheState(key) - if state != cachestate.CacheStateHit { - t.Fatalf("expected %v, got %v", cachestate.CacheStateHit, state) - } - - err := fast.Delete(key) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - rc, err := cache.Open(key) if err != nil { t.Fatalf("unexpected error: %v", err) } got, _ := io.ReadAll(rc) rc.Close() - state = cache.cacheState(key) - if state != cachestate.CacheStateMiss { - t.Fatalf("expected %v, got %v", cachestate.CacheStateMiss, state) - } if string(got) != string(value) { t.Fatalf("expected %s, got %s", value, got) } - err = cache.Delete(key) + // Check if promoted to fast + _, err = fast.Open(key) if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - _, err = cache.Open(key) - if !errors.Is(err, vfserror.ErrNotFound) { - t.Fatalf("expected %v, got %v", vfserror.ErrNotFound, err) - } - state = cache.cacheState(key) - if state != cachestate.CacheStateNotFound { - t.Fatalf("expected %v, got %v", cachestate.CacheStateNotFound, state) + t.Error("Expected promotion to fast cache") } } -func TestGetNotFound(t *testing.T) { - t.Parallel() - +func TestOpenNotFound(t *testing.T) { fast := testMemory() slow := testMemory() cache := New(nil) @@ -182,8 +144,6 @@ func TestGetNotFound(t *testing.T) { } func TestDelete(t *testing.T) { - t.Parallel() - fast := testMemory() slow := testMemory() cache := New(nil) @@ -211,8 +171,6 @@ func TestDelete(t *testing.T) { } func TestStat(t *testing.T) { - t.Parallel() - fast := testMemory() slow := testMemory() cache := New(nil) @@ -237,4 +195,7 @@ func TestStat(t *testing.T) { if info == nil { t.Fatal("expected file info to be non-nil") } + if info.Size() != int64(len(value)) { + t.Errorf("expected size %d, got %d", len(value), info.Size()) + } } diff --git a/vfs/disk/disk.go b/vfs/disk/disk.go index eb3a574..e45c7b8 100644 --- a/vfs/disk/disk.go +++ b/vfs/disk/disk.go @@ -61,7 +61,6 @@ type DiskFS struct { size int64 mu sync.RWMutex keyLocks sync.Map // map[string]*sync.RWMutex - sg sync.WaitGroup LRU *lruList } @@ -135,7 +134,6 @@ func new(root string, capacity int64, skipinit bool) *DiskFS { capacity: capacity, mu: sync.RWMutex{}, keyLocks: sync.Map{}, - sg: sync.WaitGroup{}, LRU: newLruList(), } @@ -162,8 +160,28 @@ func NewSkipInit(root string, capacity int64) *DiskFS { func (d *DiskFS) init() { tstart := time.Now() - d.walk(d.root) - d.sg.Wait() + err := filepath.Walk(d.root, func(npath string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + if info.IsDir() { + return nil + } + + d.mu.Lock() + k := strings.ReplaceAll(npath[len(d.root)+1:], "\\", "/") + fi := vfs.NewFileInfoFromOS(info, k) + d.info[k] = fi + d.LRU.Add(k, fi) + d.size += info.Size() + d.mu.Unlock() + + return nil + }) + if err != nil { + logger.Logger.Error().Err(err).Msg("Walk failed") + } logger.Logger.Info(). Str("name", d.Name()). @@ -175,37 +193,6 @@ func (d *DiskFS) init() { Msg("init") } -func (d *DiskFS) walk(path string) { - d.sg.Add(1) - go func() { - defer d.sg.Done() - filepath.Walk(path, func(npath string, info os.FileInfo, err error) error { - if path == npath { - return nil - } - - if err != nil { - return err - } - - if info.IsDir() { - d.walk(npath) - return filepath.SkipDir - } - - d.mu.Lock() - k := strings.ReplaceAll(npath[len(d.root)+1:], "\\", "/") - fi := vfs.NewFileInfoFromOS(info, k) - d.info[k] = fi - d.LRU.Add(k, fi) - d.size += info.Size() - d.mu.Unlock() - - return nil - }) - }() -} - func (d *DiskFS) Capacity() int64 { return d.capacity } @@ -235,6 +222,7 @@ func (d *DiskFS) Create(key string, size int64) (io.WriteCloser, error) { // Sanitize key to prevent path traversal key = filepath.Clean(key) + key = strings.ReplaceAll(key, "\\", "/") // Ensure forward slashes for consistency if strings.Contains(key, "..") { return nil, vfserror.ErrInvalidKey } @@ -257,11 +245,14 @@ func (d *DiskFS) Create(key string, size int64) (io.WriteCloser, error) { if fi, exists := d.info[key]; exists { d.size -= fi.Size() d.LRU.Remove(key) - d.Delete(key) + delete(d.info, key) + path := filepath.Join(d.root, key) + os.Remove(path) // Ignore error, as file might not exist or other issues } d.mu.Unlock() path := filepath.Join(d.root, key) + path = strings.ReplaceAll(path, "\\", "/") // Ensure forward slashes for consistency dir := filepath.Dir(path) if err := os.MkdirAll(dir, 0755); err != nil { return nil, err @@ -332,6 +323,7 @@ func (d *DiskFS) Delete(key string) error { // Sanitize key to prevent path traversal key = filepath.Clean(key) + key = strings.ReplaceAll(key, "\\", "/") // Ensure forward slashes for consistency if strings.Contains(key, "..") { return vfserror.ErrInvalidKey } @@ -352,6 +344,7 @@ func (d *DiskFS) Delete(key string) error { d.mu.Unlock() path := filepath.Join(d.root, key) + path = strings.ReplaceAll(path, "\\", "/") // Ensure forward slashes for consistency if err := os.Remove(path); err != nil { return err } @@ -372,6 +365,7 @@ func (d *DiskFS) Open(key string) (io.ReadCloser, error) { // Sanitize key to prevent path traversal key = filepath.Clean(key) + key = strings.ReplaceAll(key, "\\", "/") // Ensure forward slashes for consistency if strings.Contains(key, "..") { return nil, vfserror.ErrInvalidKey } @@ -391,6 +385,7 @@ func (d *DiskFS) Open(key string) (io.ReadCloser, error) { d.mu.Unlock() path := filepath.Join(d.root, key) + path = strings.ReplaceAll(path, "\\", "/") // Ensure forward slashes for consistency file, err := os.Open(path) if err != nil { return nil, err @@ -434,6 +429,7 @@ func (d *DiskFS) Stat(key string) (*vfs.FileInfo, error) { // Sanitize key to prevent path traversal key = filepath.Clean(key) + key = strings.ReplaceAll(key, "\\", "/") // Ensure forward slashes for consistency if strings.Contains(key, "..") { return nil, vfserror.ErrInvalidKey } diff --git a/vfs/disk/disk_test.go b/vfs/disk/disk_test.go index 3a43892..d63d5e4 100644 --- a/vfs/disk/disk_test.go +++ b/vfs/disk/disk_test.go @@ -2,6 +2,7 @@ package disk import ( + "errors" "fmt" "io" "os" @@ -10,65 +11,85 @@ import ( "testing" ) -func TestAllDisk(t *testing.T) { - t.Parallel() - +func TestCreateAndOpen(t *testing.T) { m := NewSkipInit(t.TempDir(), 1024) - w, err := m.Create("key", 5) + key := "key" + value := []byte("value") + + w, err := m.Create(key, int64(len(value))) if err != nil { - t.Errorf("Create failed: %v", err) + t.Fatalf("Create failed: %v", err) } - w.Write([]byte("value")) + w.Write(value) w.Close() - w, err = m.Create("key", 6) + rc, err := m.Open(key) if err != nil { - t.Errorf("Create failed: %v", err) + t.Fatalf("Open failed: %v", err) } - w.Write([]byte("value1")) - w.Close() - - rc, err := m.Open("key") - if err != nil { - t.Errorf("Open failed: %v", err) - } - d, _ := io.ReadAll(rc) + got, _ := io.ReadAll(rc) rc.Close() - if string(d) != "value1" { - t.Errorf("Get failed: got %s, want %s", d, "value1") - } - if err := m.Delete("key"); err != nil { - t.Errorf("Delete failed: %v", err) - } - - if _, err := m.Open("key"); err == nil { - t.Errorf("Open failed: got nil, want %v", vfserror.ErrNotFound) - } - - if err := m.Delete("key"); err == nil { - t.Errorf("Delete failed: got nil, want %v", vfserror.ErrNotFound) - } - - if _, err := m.Stat("key"); err == nil { - t.Errorf("Stat failed: got nil, want %v", vfserror.ErrNotFound) - } - - w, err = m.Create("key", 5) - if err != nil { - t.Errorf("Create failed: %v", err) - } - w.Write([]byte("value")) - w.Close() - - if _, err := m.Stat("key"); err != nil { - t.Errorf("Stat failed: %v", err) + if string(got) != string(value) { + t.Fatalf("expected %s, got %s", value, got) } } -func TestLimited(t *testing.T) { - t.Parallel() +func TestOverwrite(t *testing.T) { + m := NewSkipInit(t.TempDir(), 1024) + key := "key" + value1 := []byte("value1") + value2 := []byte("value2") + w, err := m.Create(key, int64(len(value1))) + if err != nil { + t.Fatalf("Create failed: %v", err) + } + w.Write(value1) + w.Close() + + w, err = m.Create(key, int64(len(value2))) + if err != nil { + t.Fatalf("Create failed: %v", err) + } + w.Write(value2) + w.Close() + + rc, err := m.Open(key) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + got, _ := io.ReadAll(rc) + rc.Close() + + if string(got) != string(value2) { + t.Fatalf("expected %s, got %s", value2, got) + } +} + +func TestDelete(t *testing.T) { + m := NewSkipInit(t.TempDir(), 1024) + key := "key" + value := []byte("value") + + w, err := m.Create(key, int64(len(value))) + if err != nil { + t.Fatalf("Create failed: %v", err) + } + w.Write(value) + w.Close() + + if err := m.Delete(key); err != nil { + t.Fatalf("Delete failed: %v", err) + } + + _, err = m.Open(key) + if !errors.Is(err, vfserror.ErrNotFound) { + t.Fatalf("expected %v, got %v", vfserror.ErrNotFound, err) + } +} + +func TestCapacityLimit(t *testing.T) { m := NewSkipInit(t.TempDir(), 10) for i := 0; i < 11; i++ { w, err := m.Create(fmt.Sprintf("key%d", i), 1) @@ -84,15 +105,11 @@ func TestLimited(t *testing.T) { } } -func TestInit(t *testing.T) { - t.Parallel() - +func TestInitExistingFiles(t *testing.T) { td := t.TempDir() path := filepath.Join(td, "test", "key") - os.MkdirAll(filepath.Dir(path), 0755) - os.WriteFile(path, []byte("value"), 0644) m := New(td, 10) @@ -100,8 +117,13 @@ func TestInit(t *testing.T) { if err != nil { t.Fatalf("Open failed: %v", err) } + got, _ := io.ReadAll(rc) rc.Close() + if string(got) != "value" { + t.Errorf("expected value, got %s", got) + } + s, err := m.Stat("test/key") if err != nil { t.Fatalf("Stat failed: %v", err) @@ -114,16 +136,13 @@ func TestInit(t *testing.T) { } } -func TestDiskSizeDiscrepancy(t *testing.T) { - t.Parallel() +func TestSizeConsistency(t *testing.T) { td := t.TempDir() - - assumedSize := int64(6 + 5 + 6) // 6 + 5 + 6 bytes for key, key1, key2 os.WriteFile(filepath.Join(td, "key2"), []byte("value2"), 0644) m := New(td, 1024) if m.Size() != 6 { - t.Errorf("Size failed: got %d, want %d", m.Size(), 6) + t.Errorf("Size failed: got %d, want 6", m.Size()) } w, err := m.Create("key", 5) @@ -140,6 +159,7 @@ func TestDiskSizeDiscrepancy(t *testing.T) { w.Write([]byte("value1")) w.Close() + assumedSize := int64(6 + 5 + 6) if assumedSize != m.Size() { t.Errorf("Size failed: got %d, want %d", m.Size(), assumedSize) } @@ -151,45 +171,10 @@ func TestDiskSizeDiscrepancy(t *testing.T) { d, _ := io.ReadAll(rc) rc.Close() if string(d) != "value" { - t.Errorf("Get failed: got %s, want %s", d, "value") - } - - rc, err = m.Open("key1") - if err != nil { - t.Errorf("Open failed: %v", err) - } - d, _ = io.ReadAll(rc) - rc.Close() - if string(d) != "value1" { - t.Errorf("Get failed: got %s, want %s", d, "value1") + t.Errorf("Get failed: got %s, want value", d) } m = New(td, 1024) - - if assumedSize != m.Size() { - t.Errorf("Size failed: got %d, want %d", m.Size(), assumedSize) - } - - rc, err = m.Open("key") - if err != nil { - t.Errorf("Open failed: %v", err) - } - d, _ = io.ReadAll(rc) - rc.Close() - if string(d) != "value" { - t.Errorf("Get failed: got %s, want %s", d, "value") - } - - rc, err = m.Open("key1") - if err != nil { - t.Errorf("Open failed: %v", err) - } - d, _ = io.ReadAll(rc) - rc.Close() - if string(d) != "value1" { - t.Errorf("Get failed: got %s, want %s", d, "value1") - } - if assumedSize != m.Size() { t.Errorf("Size failed: got %d, want %d", m.Size(), assumedSize) } diff --git a/vfs/gc/gc.go b/vfs/gc/gc.go index a7ed46d..6669095 100644 --- a/vfs/gc/gc.go +++ b/vfs/gc/gc.go @@ -4,10 +4,76 @@ package gc import ( "fmt" "io" + "s1d3sw1ped/SteamCache2/steamcache/logger" "s1d3sw1ped/SteamCache2/vfs" + "s1d3sw1ped/SteamCache2/vfs/cachestate" + "s1d3sw1ped/SteamCache2/vfs/disk" + "s1d3sw1ped/SteamCache2/vfs/memory" "s1d3sw1ped/SteamCache2/vfs/vfserror" + "time" ) +// LRUGC deletes files in LRU order until enough space is reclaimed. +func LRUGC(vfss vfs.VFS, size uint) { + attempts := 0 + deletions := 0 + var reclaimed uint + + for reclaimed < size { + if attempts > 10 { + logger.Logger.Debug(). + Int("attempts", attempts). + Msg("GC: Too many attempts to reclaim space, giving up") + return + } + attempts++ + switch fs := vfss.(type) { + case *disk.DiskFS: + fi := fs.LRU.Back() + if fi == nil { + break + } + sz := uint(fi.Size()) + err := fs.Delete(fi.Name()) + if err != nil { + continue + } + reclaimed += sz + deletions++ + case *memory.MemoryFS: + fi := fs.LRU.Back() + if fi == nil { + break + } + sz := uint(fi.Size()) + err := fs.Delete(fi.Name()) + if err != nil { + continue + } + reclaimed += sz + deletions++ + default: + // Fallback to old method if not supported + stats := vfss.StatAll() + if len(stats) == 0 { + break + } + fi := stats[0] // Assume sorted or pick first + sz := uint(fi.Size()) + err := vfss.Delete(fi.Name()) + if err != nil { + continue + } + reclaimed += sz + deletions++ + } + } +} + +func PromotionDecider(fi *vfs.FileInfo, cs cachestate.CacheState) bool { + return time.Since(fi.AccessTime()) < time.Second*60 // Put hot files in the fast vfs if equipped +} + // Ensure GCFS implements VFS. var _ vfs.VFS = (*GCFS)(nil) @@ -39,7 +105,7 @@ func (g *GCFS) Create(key string, size int64) (io.WriteCloser, error) { w, err := g.VFS.Create(key, size) // try to create the key // if it fails due to disk full error, call the GC handler and try again in a loop that will continue until it succeeds or the error is not disk full - for err == vfserror.ErrDiskFull && g.gcHanderFunc != nil { // if the error is disk full and there is a GC handler + if err == vfserror.ErrDiskFull && g.gcHanderFunc != nil { // if the error is disk full and there is a GC handler g.gcHanderFunc(g.VFS, uint(size*int64(g.multiplier))) // call the GC handler w, err = g.VFS.Create(key, size) } diff --git a/vfs/gc/gc_test.go b/vfs/gc/gc_test.go index 07b9926..3b3ad3f 100644 --- a/vfs/gc/gc_test.go +++ b/vfs/gc/gc_test.go @@ -1,96 +1,73 @@ // vfs/gc/gc_test.go package gc -// func TestGCSmallRandom(t *testing.T) { -// t.Parallel() +import ( + "errors" + "fmt" + "s1d3sw1ped/SteamCache2/vfs/memory" + "s1d3sw1ped/SteamCache2/vfs/vfserror" + "testing" +) -// m := memory.New(1024 * 1024 * 16) -// gc := New(m, 10, func(vfs vfs.VFS, size uint) (uint, uint) { -// deletions := 0 -// var reclaimed uint +func TestGCOnFull(t *testing.T) { + m := memory.New(10) + gc := New(m, 2, LRUGC) -// t.Logf("GC starting to reclaim %d bytes", size) + for i := 0; i < 5; i++ { + w, err := gc.Create(fmt.Sprintf("key%d", i), 2) + if err != nil { + t.Fatalf("Create failed: %v", err) + } + w.Write([]byte("ab")) + w.Close() + } -// stats := vfs.StatAll() -// sort.Slice(stats, func(i, j int) bool { -// // Sort by access time so we can remove the oldest files first. -// return stats[i].AccessTime().Before(stats[j].AccessTime()) -// }) + // Cache full at 10 bytes + w, err := gc.Create("key5", 2) + if err != nil { + t.Fatalf("Create failed: %v", err) + } + w.Write([]byte("cd")) + w.Close() -// // Delete the oldest files until we've reclaimed enough space. -// for _, s := range stats { -// sz := uint(s.Size()) // Get the size of the file -// err := vfs.Delete(s.Name()) -// if err != nil { -// panic(err) -// } -// reclaimed += sz // Track how much space we've reclaimed -// deletions++ // Track how many files we've deleted + if gc.Size() > 10 { + t.Errorf("Size exceeded: %d > 10", gc.Size()) + } -// // t.Logf("GC deleting %s, %v", s.Name(), s.AccessTime().Format(time.RFC3339Nano)) + // Check if older keys were evicted + _, err = m.Open("key0") + if err == nil { + t.Error("Expected key0 to be evicted") + } +} -// if reclaimed >= size { // We've reclaimed enough space -// break -// } -// } -// return uint(reclaimed), uint(deletions) -// }) +func TestNoGCNeeded(t *testing.T) { + m := memory.New(20) + gc := New(m, 2, LRUGC) -// for i := 0; i < 10000; i++ { -// if err := gc.Set(fmt.Sprintf("key:%d", i), genRandomData(1024*1, 1024*4)); err != nil { -// t.Errorf("Set failed: %v", err) -// } -// } + for i := 0; i < 5; i++ { + w, err := gc.Create(fmt.Sprintf("key%d", i), 2) + if err != nil { + t.Fatalf("Create failed: %v", err) + } + w.Write([]byte("ab")) + w.Close() + } -// if gc.Size() > 1024*1024*16 { -// t.Errorf("MemoryFS size is %d, want <= 1024", m.Size()) -// } -// } + if gc.Size() != 10 { + t.Errorf("Size: got %d, want 10", gc.Size()) + } +} -// func genRandomData(min int, max int) []byte { -// data := make([]byte, rand.Intn(max-min)+min) -// rand.Read(data) -// return data -// } +func TestGCInsufficientSpace(t *testing.T) { + m := memory.New(5) + gc := New(m, 1, LRUGC) -// func TestGCLargeRandom(t *testing.T) { -// t.Parallel() - -// m := memory.New(1024 * 1024 * 16) // 16MB -// gc := New(m, 10, func(vfs vfs.VFS, size uint) (uint, uint) { -// deletions := 0 -// var reclaimed uint - -// t.Logf("GC starting to reclaim %d bytes", size) - -// stats := vfs.StatAll() -// sort.Slice(stats, func(i, j int) bool { -// // Sort by access time so we can remove the oldest files first. -// return stats[i].AccessTime().Before(stats[j].AccessTime()) -// }) - -// // Delete the oldest files until we've reclaimed enough space. -// for _, s := range stats { -// sz := uint(s.Size()) // Get the size of the file -// vfs.Delete(s.Name()) -// reclaimed += sz // Track how much space we've reclaimed -// deletions++ // Track how many files we've deleted - -// if reclaimed >= size { // We've reclaimed enough space -// break -// } -// } - -// return uint(reclaimed), uint(deletions) -// }) - -// for i := 0; i < 10000; i++ { -// if err := gc.Set(fmt.Sprintf("key:%d", i), genRandomData(1024, 1024*1024)); err != nil { -// t.Errorf("Set failed: %v", err) -// } -// } - -// if gc.Size() > 1024*1024*16 { -// t.Errorf("MemoryFS size is %d, want <= 1024", m.Size()) -// } -// } + w, err := gc.Create("key0", 10) + if err == nil { + w.Close() + t.Error("Expected ErrDiskFull") + } else if !errors.Is(err, vfserror.ErrDiskFull) { + t.Errorf("Unexpected error: %v", err) + } +} diff --git a/vfs/memory/memory_test.go b/vfs/memory/memory_test.go index 9dde4a6..d76f131 100644 --- a/vfs/memory/memory_test.go +++ b/vfs/memory/memory_test.go @@ -2,71 +2,92 @@ package memory import ( + "errors" "fmt" "io" "s1d3sw1ped/SteamCache2/vfs/vfserror" "testing" ) -func TestAllMemory(t *testing.T) { - t.Parallel() - +func TestCreateAndOpen(t *testing.T) { m := New(1024) - w, err := m.Create("key", 5) + key := "key" + value := []byte("value") + + w, err := m.Create(key, int64(len(value))) if err != nil { - t.Errorf("Create failed: %v", err) + t.Fatalf("Create failed: %v", err) } - w.Write([]byte("value")) + w.Write(value) w.Close() - w, err = m.Create("key", 6) + rc, err := m.Open(key) if err != nil { - t.Errorf("Create failed: %v", err) + t.Fatalf("Open failed: %v", err) } - w.Write([]byte("value1")) - w.Close() - - rc, err := m.Open("key") - if err != nil { - t.Errorf("Open failed: %v", err) - } - d, _ := io.ReadAll(rc) + got, _ := io.ReadAll(rc) rc.Close() - if string(d) != "value1" { - t.Errorf("Get failed: got %s, want %s", d, "value1") - } - if err := m.Delete("key"); err != nil { - t.Errorf("Delete failed: %v", err) - } - - if _, err := m.Open("key"); err == nil { - t.Errorf("Open failed: got nil, want %v", vfserror.ErrNotFound) - } - - if err := m.Delete("key"); err == nil { - t.Errorf("Delete failed: got nil, want %v", vfserror.ErrNotFound) - } - - if _, err := m.Stat("key"); err == nil { - t.Errorf("Stat failed: got nil, want %v", vfserror.ErrNotFound) - } - - w, err = m.Create("key", 5) - if err != nil { - t.Errorf("Create failed: %v", err) - } - w.Write([]byte("value")) - w.Close() - - if _, err := m.Stat("key"); err != nil { - t.Errorf("Stat failed: %v", err) + if string(got) != string(value) { + t.Fatalf("expected %s, got %s", value, got) } } -func TestLimited(t *testing.T) { - t.Parallel() +func TestOverwrite(t *testing.T) { + m := New(1024) + key := "key" + value1 := []byte("value1") + value2 := []byte("value2") + w, err := m.Create(key, int64(len(value1))) + if err != nil { + t.Fatalf("Create failed: %v", err) + } + w.Write(value1) + w.Close() + + w, err = m.Create(key, int64(len(value2))) + if err != nil { + t.Fatalf("Create failed: %v", err) + } + w.Write(value2) + w.Close() + + rc, err := m.Open(key) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + got, _ := io.ReadAll(rc) + rc.Close() + + if string(got) != string(value2) { + t.Fatalf("expected %s, got %s", value2, got) + } +} + +func TestDelete(t *testing.T) { + m := New(1024) + key := "key" + value := []byte("value") + + w, err := m.Create(key, int64(len(value))) + if err != nil { + t.Fatalf("Create failed: %v", err) + } + w.Write(value) + w.Close() + + if err := m.Delete(key); err != nil { + t.Fatalf("Delete failed: %v", err) + } + + _, err = m.Open(key) + if !errors.Is(err, vfserror.ErrNotFound) { + t.Fatalf("expected %v, got %v", vfserror.ErrNotFound, err) + } +} + +func TestCapacityLimit(t *testing.T) { m := New(10) for i := 0; i < 11; i++ { w, err := m.Create(fmt.Sprintf("key%d", i), 1) @@ -81,3 +102,28 @@ func TestLimited(t *testing.T) { } } } + +func TestStat(t *testing.T) { + m := New(1024) + key := "key" + value := []byte("value") + + w, err := m.Create(key, int64(len(value))) + if err != nil { + t.Fatalf("Create failed: %v", err) + } + w.Write(value) + w.Close() + + info, err := m.Stat(key) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if info == nil { + t.Fatal("expected file info to be non-nil") + } + if info.Size() != int64(len(value)) { + t.Errorf("expected size %d, got %d", len(value), info.Size()) + } +}