fix: redo the whole caching functionality to make it really 420 blaze it fast #3

Merged
s1d3sw1ped merged 5 commits from fix/blazing-sun-speed into main 2025-07-13 05:21:20 -05:00
25 changed files with 1215 additions and 881 deletions

View File

@@ -16,7 +16,7 @@ builds:
- amd64
archives:
- format: tar.gz
- formats: tar.gz
name_template: >-
{{ .ProjectName }}_
{{- title .Os }}_
@@ -26,7 +26,7 @@ archives:
{{- if .Arm }}v{{ .Arm }}{{ end }}
format_overrides:
- goos: windows
format: zip
formats: zip
changelog:
sort: asc

View File

@@ -1,7 +1,9 @@
// cmd/root.go
package cmd
import (
"os"
"runtime"
"s1d3sw1ped/SteamCache2/steamcache"
"s1d3sw1ped/SteamCache2/steamcache/logger"
"s1d3sw1ped/SteamCache2/version"
@@ -11,6 +13,8 @@ import (
)
var (
threads int
memory string
memorymultiplier int
disk string
@@ -18,7 +22,6 @@ var (
diskpath string
upstream string
pprof bool
logLevel string
logFormat string
)
@@ -52,21 +55,29 @@ var rootCmd = &cobra.Command{
logger.Logger = zerolog.New(writer).With().Timestamp().Logger()
logger.Logger.Info().
Msg("starting SteamCache2 " + version.Version)
Msg("SteamCache2 " + version.Version + " starting...")
address := ":80"
if runtime.GOMAXPROCS(-1) != threads {
runtime.GOMAXPROCS(threads)
logger.Logger.Info().
Int("threads", threads).
Msg("Maximum number of threads set")
}
sc := steamcache.New(
":80",
address,
memory,
memorymultiplier,
disk,
diskmultiplier,
diskpath,
upstream,
pprof,
)
logger.Logger.Info().
Msg("starting SteamCache2 on port 80")
Msg("SteamCache2 " + version.Version + " started on " + address)
sc.Run()
@@ -85,6 +96,8 @@ func Execute() {
}
func init() {
rootCmd.Flags().IntVarP(&threads, "threads", "t", runtime.GOMAXPROCS(-1), "Number of worker threads to use for processing requests")
rootCmd.Flags().StringVarP(&memory, "memory", "m", "0", "The size of the memory cache")
rootCmd.Flags().IntVarP(&memorymultiplier, "memory-gc", "M", 10, "The gc value for the memory cache")
rootCmd.Flags().StringVarP(&disk, "disk", "d", "0", "The size of the disk cache")
@@ -93,9 +106,6 @@ func init() {
rootCmd.Flags().StringVarP(&upstream, "upstream", "u", "", "The upstream server to proxy requests overrides the host header from the client but forwards the original host header to the upstream server")
rootCmd.Flags().BoolVarP(&pprof, "pprof", "P", false, "Enable pprof")
rootCmd.Flags().MarkHidden("pprof")
rootCmd.Flags().StringVarP(&logLevel, "log-level", "l", "info", "Logging level: debug, info, error")
rootCmd.Flags().StringVarP(&logFormat, "log-format", "f", "console", "Logging format: json, console")
}

View File

@@ -1,3 +1,4 @@
// cmd/version.go
package cmd
import (

1
go.mod
View File

@@ -7,7 +7,6 @@ require (
github.com/prometheus/client_golang v1.22.0
github.com/rs/zerolog v1.33.0
github.com/spf13/cobra v1.8.1
golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8
)
require (

2
go.sum
View File

@@ -45,8 +45,6 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 h1:yqrTHse8TCMW1M1ZCP+VAR/l0kKxwaAIqN/il7x4voA=
golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@@ -1,3 +1,4 @@
// main.go
package main
import (

View File

@@ -1,63 +0,0 @@
package avgcachestate
import (
"s1d3sw1ped/SteamCache2/vfs/cachestate"
"sync"
)
// AvgCacheState is a cache state that averages the last N cache states.
type AvgCacheState struct {
size int
avgs []cachestate.CacheState
mu sync.Mutex
}
// New creates a new average cache state with the given size.
func New(size int) *AvgCacheState {
a := &AvgCacheState{
size: size,
avgs: make([]cachestate.CacheState, size),
mu: sync.Mutex{},
}
a.Clear()
return a
}
// Clear resets the average cache state to zero.
func (a *AvgCacheState) Clear() {
a.mu.Lock()
defer a.mu.Unlock()
for i := 0; i < len(a.avgs); i++ {
a.avgs[i] = cachestate.CacheStateMiss
}
}
// Add adds a cache state to the average cache state.
func (a *AvgCacheState) Add(cs cachestate.CacheState) {
a.mu.Lock()
defer a.mu.Unlock()
a.avgs = append(a.avgs, cs)
if len(a.avgs) > a.size {
a.avgs = a.avgs[1:]
}
}
// Avg returns the average cache state.
func (a *AvgCacheState) Avg() float64 {
a.mu.Lock()
defer a.mu.Unlock()
var hits int
for _, cs := range a.avgs {
if cs == cachestate.CacheStateHit {
hits++
}
}
return float64(hits) / float64(len(a.avgs))
}

View File

@@ -1,44 +0,0 @@
package steamcache
import (
"runtime/debug"
"s1d3sw1ped/SteamCache2/vfs"
"s1d3sw1ped/SteamCache2/vfs/cachestate"
"sort"
"time"
)
func init() {
// Set the GC percentage to 50%. This is a good balance between performance and memory usage.
debug.SetGCPercent(50)
}
// lruGC deletes files in LRU order until enough space is reclaimed.
func lruGC(vfss vfs.VFS, size uint) (uint, uint) {
deletions := 0
var reclaimed uint
stats := vfss.StatAll()
sort.Slice(stats, func(i, j int) bool {
return stats[i].AccessTime().Before(stats[j].AccessTime())
})
for _, s := range stats {
sz := uint(s.Size())
err := vfss.Delete(s.Name())
if err != nil {
continue
}
reclaimed += sz
deletions++
if reclaimed >= size {
break
}
}
return reclaimed, uint(deletions)
}
func cachehandler(fi *vfs.FileInfo, cs cachestate.CacheState) bool {
return time.Since(fi.AccessTime()) < time.Second*10 // Put hot files in the fast vfs if equipped
}

View File

@@ -1,3 +1,4 @@
// steamcache/logger/logger.go
package logger
import (

View File

@@ -1,25 +1,23 @@
// steamcache/steamcache.go
package steamcache
import (
"context"
"io"
"net"
"net/http"
"net/url"
"os"
"s1d3sw1ped/SteamCache2/steamcache/avgcachestate"
"s1d3sw1ped/SteamCache2/steamcache/logger"
"s1d3sw1ped/SteamCache2/vfs"
"s1d3sw1ped/SteamCache2/vfs/cache"
"s1d3sw1ped/SteamCache2/vfs/cachestate"
"s1d3sw1ped/SteamCache2/vfs/disk"
"s1d3sw1ped/SteamCache2/vfs/gc"
"s1d3sw1ped/SteamCache2/vfs/memory"
// syncfs "s1d3sw1ped/SteamCache2/vfs/sync"
"strings"
"sync"
"time"
pprof "net/http/pprof"
"github.com/docker/go-units"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
@@ -34,16 +32,25 @@ var (
},
[]string{"method", "status"},
)
cacheHitRate = promauto.NewGauge(
prometheus.GaugeOpts{
Name: "cache_hit_rate",
Help: "Cache hit rate",
cacheStatusTotal = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "cache_status_total",
Help: "Total cache status counts",
},
[]string{"status"},
)
responseTime = promauto.NewHistogram(
prometheus.HistogramOpts{
Name: "response_time_seconds",
Help: "Response time in seconds",
Buckets: prometheus.DefBuckets,
},
)
)
type SteamCache struct {
pprof bool
address string
upstream string
@@ -55,10 +62,13 @@ type SteamCache struct {
memorygc *gc.GCFS
diskgc *gc.GCFS
hits *avgcachestate.AvgCacheState
server *http.Server
client *http.Client
cancel context.CancelFunc
wg sync.WaitGroup
}
func New(address string, memorySize string, memoryMultiplier int, diskSize string, diskMultiplier int, diskPath, upstream string, pprof bool) *SteamCache {
func New(address string, memorySize string, memoryMultiplier int, diskSize string, diskMultiplier int, diskPath, upstream string) *SteamCache {
memorysize, err := units.FromHumanSize(memorySize)
if err != nil {
panic(err)
@@ -70,21 +80,21 @@ func New(address string, memorySize string, memoryMultiplier int, diskSize strin
}
c := cache.New(
cachehandler,
gc.PromotionDecider,
)
var m *memory.MemoryFS
var mgc *gc.GCFS
if memorysize > 0 {
m = memory.New(memorysize)
mgc = gc.New(m, memoryMultiplier, lruGC)
mgc = gc.New(m, memoryMultiplier, gc.LRUGC)
}
var d *disk.DiskFS
var dgc *gc.GCFS
if disksize > 0 {
d = disk.New(diskPath, disksize)
dgc = gc.New(d, diskMultiplier, lruGC)
dgc = gc.New(d, diskMultiplier, gc.LRUGC)
}
// configure the cache to match the specified mode (memory only, disk only, or memory and disk) based on the provided sizes
@@ -107,25 +117,44 @@ func New(address string, memorySize string, memoryMultiplier int, diskSize strin
os.Exit(1)
}
transport := &http.Transport{
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
client := &http.Client{
Transport: transport,
Timeout: 60 * time.Second,
}
sc := &SteamCache{
pprof: pprof,
upstream: upstream,
address: address,
// vfs: syncfs.New(c),
vfs: c,
memory: m,
disk: d,
vfs: c,
memory: m,
disk: d,
memorygc: mgc,
diskgc: dgc,
hits: avgcachestate.New(100),
client: client,
server: &http.Server{
Addr: address,
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
IdleTimeout: 120 * time.Second,
},
}
if d != nil {
if d.Size() > d.Capacity() {
lruGC(d, uint(d.Size()-d.Capacity()))
gc.LRUGC(d, uint(d.Size()-d.Capacity()))
}
}
@@ -134,32 +163,41 @@ func New(address string, memorySize string, memoryMultiplier int, diskSize strin
func (sc *SteamCache) Run() {
if sc.upstream != "" {
_, err := http.Get(sc.upstream)
if err != nil {
resp, err := sc.client.Get(sc.upstream)
if err != nil || resp.StatusCode != http.StatusOK {
logger.Logger.Error().Err(err).Str("upstream", sc.upstream).Msg("Failed to connect to upstream server")
os.Exit(1)
}
resp.Body.Close()
}
err := http.ListenAndServe(sc.address, sc)
if err != nil {
if err == http.ErrServerClosed {
logger.Logger.Info().Msg("shutdown")
return
sc.server.Handler = sc
ctx, cancel := context.WithCancel(context.Background())
sc.cancel = cancel
sc.wg.Add(1)
go func() {
defer sc.wg.Done()
err := sc.server.ListenAndServe()
if err != nil && err != http.ErrServerClosed {
logger.Logger.Error().Err(err).Msg("Failed to start SteamCache2")
os.Exit(1)
}
logger.Logger.Error().Err(err).Msg("Failed to start SteamCache2")
os.Exit(1)
}()
<-ctx.Done()
sc.server.Shutdown(ctx)
sc.wg.Wait()
}
func (sc *SteamCache) Shutdown() {
if sc.cancel != nil {
sc.cancel()
}
sc.wg.Wait()
}
func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if sc.pprof && r.URL.Path == "/debug/pprof/" {
pprof.Index(w, r)
return
} else if sc.pprof && strings.HasPrefix(r.URL.Path, "/debug/pprof/") {
pprof.Handler(strings.TrimPrefix(r.URL.Path, "/debug/pprof/")).ServeHTTP(w, r)
return
}
if r.URL.Path == "/metrics" {
promhttp.Handler().ServeHTTP(w, r)
return
@@ -167,6 +205,7 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
requestsTotal.WithLabelValues(r.Method, "405").Inc()
logger.Logger.Warn().Str("method", r.Method).Msg("Only GET method is supported")
http.Error(w, "Only GET method is supported", http.StatusMethodNotAllowed)
return
}
@@ -178,124 +217,150 @@ func (sc *SteamCache) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
tstart := time.Now()
if strings.HasPrefix(r.URL.String(), "/depot/") {
tstart := time.Now()
defer func() { responseTime.Observe(time.Since(tstart).Seconds()) }()
cacheKey := strings.ReplaceAll(r.URL.String()[1:], "\\", "/") // replace all backslashes with forward slashes shouldn't be necessary but just in case
if cacheKey == "" {
requestsTotal.WithLabelValues(r.Method, "400").Inc()
http.Error(w, "Invalid URL", http.StatusBadRequest)
return
}
cacheKey := strings.ReplaceAll(r.URL.String()[1:], "\\", "/") // replace all backslashes with forward slashes shouldn't be necessary but just in case
w.Header().Add("X-LanCache-Processed-By", "SteamCache2") // SteamPrefill uses this header to determine if the request was processed by the cache maybe steam uses it too
if cacheKey == "" {
requestsTotal.WithLabelValues(r.Method, "400").Inc()
logger.Logger.Warn().Str("url", r.URL.String()).Msg("Invalid URL")
http.Error(w, "Invalid URL", http.StatusBadRequest)
return
}
data, err := sc.vfs.Get(cacheKey)
if err == nil {
sc.hits.Add(cachestate.CacheStateHit)
w.Header().Add("X-LanCache-Status", "HIT")
requestsTotal.WithLabelValues(r.Method, "200").Inc()
cacheHitRate.Set(sc.hits.Avg())
w.Header().Add("X-LanCache-Processed-By", "SteamCache2") // SteamPrefill uses this header to determine if the request was processed by the cache maybe steam uses it too
w.Write(data)
reader, err := sc.vfs.Open(cacheKey)
if err == nil {
defer reader.Close()
w.Header().Add("X-LanCache-Status", "HIT")
io.Copy(w, reader)
logger.Logger.Info().
Str("key", cacheKey).
Str("host", r.Host).
Str("status", "HIT").
Dur("duration", time.Since(tstart)).
Msg("request")
requestsTotal.WithLabelValues(r.Method, "200").Inc()
cacheStatusTotal.WithLabelValues("HIT").Inc()
return
}
var req *http.Request
if sc.upstream != "" { // if an upstream server is configured, proxy the request to the upstream server
ur, err := url.JoinPath(sc.upstream, r.URL.String())
if err != nil {
requestsTotal.WithLabelValues(r.Method, "500").Inc()
logger.Logger.Error().Err(err).Str("upstream", sc.upstream).Msg("Failed to join URL path")
http.Error(w, "Failed to join URL path", http.StatusInternalServerError)
return
}
req, err = http.NewRequest(http.MethodGet, ur, nil)
if err != nil {
requestsTotal.WithLabelValues(r.Method, "500").Inc()
logger.Logger.Error().Err(err).Str("upstream", sc.upstream).Msg("Failed to create request")
http.Error(w, "Failed to create request", http.StatusInternalServerError)
return
}
req.Host = r.Host
} else { // if no upstream server is configured, proxy the request to the host specified in the request
host := r.Host
if r.Header.Get("X-Sls-Https") == "enable" {
host = "https://" + host
} else {
host = "http://" + host
}
ur, err := url.JoinPath(host, r.URL.String())
if err != nil {
requestsTotal.WithLabelValues(r.Method, "500").Inc()
logger.Logger.Error().Err(err).Str("host", host).Msg("Failed to join URL path")
http.Error(w, "Failed to join URL path", http.StatusInternalServerError)
return
}
req, err = http.NewRequest(http.MethodGet, ur, nil)
if err != nil {
requestsTotal.WithLabelValues(r.Method, "500").Inc()
logger.Logger.Error().Err(err).Str("host", host).Msg("Failed to create request")
http.Error(w, "Failed to create request", http.StatusInternalServerError)
return
}
}
// Copy headers from the original request to the new request
for key, values := range r.Header {
for _, value := range values {
req.Header.Add(key, value)
}
}
// Retry logic
backoffSchedule := []time.Duration{1 * time.Second, 3 * time.Second, 10 * time.Second}
var resp *http.Response
for i, backoff := range backoffSchedule {
resp, err = sc.client.Do(req)
if err == nil && resp.StatusCode == http.StatusOK {
break
}
if i < len(backoffSchedule)-1 {
time.Sleep(backoff)
}
}
if err != nil || resp.StatusCode != http.StatusOK {
requestsTotal.WithLabelValues(r.Method, "500 upstream host "+r.Host).Inc()
logger.Logger.Error().Err(err).Str("url", req.URL.String()).Msg("Failed to fetch the requested URL")
http.Error(w, "Failed to fetch the requested URL", http.StatusInternalServerError)
return
}
defer resp.Body.Close()
size := resp.ContentLength
writer, err := sc.vfs.Create(cacheKey, size)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer writer.Close()
w.Header().Add("X-LanCache-Status", "MISS")
io.Copy(io.MultiWriter(w, writer), resp.Body)
logger.Logger.Info().
Str("key", cacheKey).
Str("host", r.Host).
Str("status", "HIT").
Int64("size", int64(len(data))).
Str("status", "MISS").
Dur("duration", time.Since(tstart)).
Msg("request")
requestsTotal.WithLabelValues(r.Method, "200").Inc()
cacheStatusTotal.WithLabelValues("MISS").Inc()
return
}
var req *http.Request
if sc.upstream != "" { // if an upstream server is configured, proxy the request to the upstream server
ur, err := url.JoinPath(sc.upstream, r.URL.String())
if err != nil {
requestsTotal.WithLabelValues(r.Method, "500").Inc()
http.Error(w, "Failed to join URL path", http.StatusInternalServerError)
return
}
req, err = http.NewRequest(http.MethodGet, ur, nil)
if err != nil {
requestsTotal.WithLabelValues(r.Method, "500").Inc()
http.Error(w, "Failed to create request", http.StatusInternalServerError)
return
}
req.Host = r.Host
} else { // if no upstream server is configured, proxy the request to the host specified in the request
host := r.Host
if r.Header.Get("X-Sls-Https") == "enable" {
host = "https://" + host
} else {
host = "http://" + host
}
ur, err := url.JoinPath(host, r.URL.String())
if err != nil {
requestsTotal.WithLabelValues(r.Method, "500").Inc()
http.Error(w, "Failed to join URL path", http.StatusInternalServerError)
return
}
req, err = http.NewRequest(http.MethodGet, ur, nil)
if err != nil {
requestsTotal.WithLabelValues(r.Method, "500").Inc()
http.Error(w, "Failed to create request", http.StatusInternalServerError)
return
}
}
// Copy headers from the original request to the new request
for key, values := range r.Header {
for _, value := range values {
req.Header.Add(key, value)
}
}
// req.Header.Add("X-Sls-Https", r.Header.Get("X-Sls-Https"))
// req.Header.Add("User-Agent", r.Header.Get("User-Agent"))
// Retry logic
backoffSchedule := []time.Duration{1 * time.Second, 3 * time.Second, 10 * time.Second}
var resp *http.Response
for i, backoff := range backoffSchedule {
resp, err = http.DefaultClient.Do(req)
if err == nil && resp.StatusCode == http.StatusOK {
break
}
if i < len(backoffSchedule)-1 {
time.Sleep(backoff)
}
}
if err != nil || resp.StatusCode != http.StatusOK {
requestsTotal.WithLabelValues(r.Method, "500").Inc()
http.Error(w, "Failed to fetch the requested URL", http.StatusInternalServerError)
return
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
requestsTotal.WithLabelValues(r.Method, "500").Inc()
http.Error(w, "Failed to read response body", http.StatusInternalServerError)
if r.URL.Path == "/favicon.ico" {
w.WriteHeader(http.StatusNoContent)
return
}
sc.vfs.Set(cacheKey, body)
sc.hits.Add(cachestate.CacheStateMiss)
w.Header().Add("X-LanCache-Status", "MISS")
requestsTotal.WithLabelValues(r.Method, "200").Inc()
cacheHitRate.Set(sc.hits.Avg())
if r.URL.Path == "/robots.txt" {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte("User-agent: *\nDisallow: /\n"))
return
}
w.Write(body)
logger.Logger.Info().
Str("key", cacheKey).
Str("host", r.Host).
Str("status", "MISS").
Int64("size", int64(len(body))).
Dur("duration", time.Since(tstart)).
Msg("request")
requestsTotal.WithLabelValues(r.Method, "404").Inc()
logger.Logger.Warn().Str("url", r.URL.String()).Msg("Not found")
http.Error(w, "Not found", http.StatusNotFound)
}

View File

@@ -1,26 +1,33 @@
// steamcache/steamcache_test.go
package steamcache
import (
"io"
"os"
"path/filepath"
"testing"
)
func TestCaching(t *testing.T) {
t.Parallel()
td := t.TempDir()
os.WriteFile(filepath.Join(td, "key2"), []byte("value2"), 0644)
sc := New("localhost:8080", "1GB", 10, "1GB", 100, td, "", false)
sc := New("localhost:8080", "1G", 10, "1G", 100, td, "")
if err := sc.vfs.Set("key", []byte("value")); err != nil {
t.Errorf("Set failed: %v", err)
w, err := sc.vfs.Create("key", 5)
if err != nil {
t.Errorf("Create failed: %v", err)
}
if err := sc.vfs.Set("key1", []byte("value1")); err != nil {
t.Errorf("Set failed: %v", err)
w.Write([]byte("value"))
w.Close()
w, err = sc.vfs.Create("key1", 6)
if err != nil {
t.Errorf("Create failed: %v", err)
}
w.Write([]byte("value1"))
w.Close()
if sc.diskgc.Size() != 17 {
t.Errorf("Size failed: got %d, want %d", sc.diskgc.Size(), 17)
@@ -30,21 +37,33 @@ func TestCaching(t *testing.T) {
t.Errorf("Size failed: got %d, want %d", sc.vfs.Size(), 17)
}
if d, err := sc.vfs.Get("key"); err != nil {
t.Errorf("Get failed: %v", err)
} else if string(d) != "value" {
rc, err := sc.vfs.Open("key")
if err != nil {
t.Errorf("Open failed: %v", err)
}
d, _ := io.ReadAll(rc)
rc.Close()
if string(d) != "value" {
t.Errorf("Get failed: got %s, want %s", d, "value")
}
if d, err := sc.vfs.Get("key1"); err != nil {
t.Errorf("Get failed: %v", err)
} else if string(d) != "value1" {
rc, err = sc.vfs.Open("key1")
if err != nil {
t.Errorf("Open failed: %v", err)
}
d, _ = io.ReadAll(rc)
rc.Close()
if string(d) != "value1" {
t.Errorf("Get failed: got %s, want %s", d, "value1")
}
if d, err := sc.vfs.Get("key2"); err != nil {
t.Errorf("Get failed: %v", err)
} else if string(d) != "value2" {
rc, err = sc.vfs.Open("key2")
if err != nil {
t.Errorf("Open failed: %v", err)
}
d, _ = io.ReadAll(rc)
rc.Close()
if string(d) != "value2" {
t.Errorf("Get failed: got %s, want %s", d, "value2")
}
@@ -59,7 +78,33 @@ func TestCaching(t *testing.T) {
sc.memory.Delete("key2")
os.Remove(filepath.Join(td, "key2"))
if _, err := sc.vfs.Get("key2"); err == nil {
t.Errorf("Get failed: got nil, want error")
if _, err := sc.vfs.Open("key2"); err == nil {
t.Errorf("Open failed: got nil, want error")
}
}
func TestCacheMissAndHit(t *testing.T) {
sc := New("localhost:8080", "0", 0, "1G", 100, t.TempDir(), "")
key := "testkey"
value := []byte("testvalue")
// Simulate miss: but since no upstream, skip full ServeHTTP, test VFS
w, err := sc.vfs.Create(key, int64(len(value)))
if err != nil {
t.Fatal(err)
}
w.Write(value)
w.Close()
rc, err := sc.vfs.Open(key)
if err != nil {
t.Fatal(err)
}
got, _ := io.ReadAll(rc)
rc.Close()
if string(got) != string(value) {
t.Errorf("expected %s, got %s", value, got)
}
}

View File

@@ -1,3 +1,4 @@
// version/version.go
package version
var Version string

88
vfs/cache/cache.go vendored
View File

@@ -1,7 +1,9 @@
// vfs/cache/cache.go
package cache
import (
"fmt"
"io"
"s1d3sw1ped/SteamCache2/vfs"
"s1d3sw1ped/SteamCache2/vfs/cachestate"
"s1d3sw1ped/SteamCache2/vfs/vfserror"
@@ -73,27 +75,6 @@ func (c *CacheFS) Size() int64 {
return c.slow.Size()
}
// Set sets the file at key to src. If the file is already in the cache, it is replaced.
func (c *CacheFS) Set(key string, src []byte) error {
mu := c.getKeyLock(key)
mu.Lock()
defer mu.Unlock()
state := c.cacheState(key)
switch state {
case cachestate.CacheStateHit:
if c.fast != nil {
c.fast.Delete(key)
}
return c.slow.Set(key, src)
case cachestate.CacheStateMiss, cachestate.CacheStateNotFound:
return c.slow.Set(key, src)
}
panic(vfserror.ErrUnreachable)
}
// Delete deletes the file at key from the cache.
func (c *CacheFS) Delete(key string) error {
mu := c.getKeyLock(key)
@@ -106,14 +87,8 @@ func (c *CacheFS) Delete(key string) error {
return c.slow.Delete(key)
}
// Get returns the file at key. If the file is not in the cache, it is fetched from the storage.
func (c *CacheFS) Get(key string) ([]byte, error) {
src, _, err := c.GetS(key)
return src, err
}
// GetS returns the file at key. If the file is not in the cache, it is fetched from the storage. It also returns the cache state.
func (c *CacheFS) GetS(key string) ([]byte, cachestate.CacheState, error) {
// Open returns the file at key. If the file is not in the cache, it is fetched from the storage.
func (c *CacheFS) Open(key string) (io.ReadCloser, error) {
mu := c.getKeyLock(key)
mu.RLock()
defer mu.RUnlock()
@@ -123,27 +98,51 @@ func (c *CacheFS) GetS(key string) ([]byte, cachestate.CacheState, error) {
switch state {
case cachestate.CacheStateHit:
// if c.fast == nil then cacheState cannot be CacheStateHit so we can safely ignore the check
src, err := c.fast.Get(key)
return src, state, err
return c.fast.Open(key)
case cachestate.CacheStateMiss:
src, err := c.slow.Get(key)
slowReader, err := c.slow.Open(key)
if err != nil {
return nil, state, err
return nil, err
}
sstat, _ := c.slow.Stat(key)
if sstat != nil && c.fast != nil { // file found in slow storage and fast storage is available
// We are accessing the file from the slow storage, and the file has been accessed less then a minute ago so it popular, so we should update the fast storage with the latest file.
if c.cacheHandler != nil && c.cacheHandler(sstat, state) {
if err := c.fast.Set(key, src); err != nil {
return nil, state, err
fastWriter, err := c.fast.Create(key, sstat.Size())
if err == nil {
return &teeReadCloser{
Reader: io.TeeReader(slowReader, fastWriter),
closers: []io.Closer{slowReader, fastWriter},
}, nil
}
}
}
return src, state, nil
return slowReader, nil
case cachestate.CacheStateNotFound:
return nil, state, vfserror.ErrNotFound
return nil, vfserror.ErrNotFound
}
panic(vfserror.ErrUnreachable)
}
// Create creates a new file at key. If the file is already in the cache, it is replaced.
func (c *CacheFS) Create(key string, size int64) (io.WriteCloser, error) {
mu := c.getKeyLock(key)
mu.Lock()
defer mu.Unlock()
state := c.cacheState(key)
switch state {
case cachestate.CacheStateHit:
if c.fast != nil {
c.fast.Delete(key)
}
return c.slow.Create(key, size)
case cachestate.CacheStateMiss, cachestate.CacheStateNotFound:
return c.slow.Create(key, size)
}
panic(vfserror.ErrUnreachable)
@@ -176,3 +175,18 @@ func (c *CacheFS) Stat(key string) (*vfs.FileInfo, error) {
func (c *CacheFS) StatAll() []*vfs.FileInfo {
return c.slow.StatAll()
}
type teeReadCloser struct {
io.Reader
closers []io.Closer
}
func (t *teeReadCloser) Close() error {
var err error
for _, c := range t.closers {
if e := c.Close(); e != nil {
err = e
}
}
return err
}

View File

@@ -1,7 +1,9 @@
// vfs/cache/cache_test.go
package cache
import (
"errors"
"io"
"testing"
"s1d3sw1ped/SteamCache2/vfs"
@@ -15,8 +17,6 @@ func testMemory() vfs.VFS {
}
func TestNew(t *testing.T) {
t.Parallel()
fast := testMemory()
slow := testMemory()
@@ -29,8 +29,6 @@ func TestNew(t *testing.T) {
}
func TestNewPanics(t *testing.T) {
t.Parallel()
defer func() {
if r := recover(); r == nil {
t.Fatal("expected panic but did not get one")
@@ -42,9 +40,7 @@ func TestNewPanics(t *testing.T) {
cache.SetSlow(nil)
}
func TestSetAndGet(t *testing.T) {
t.Parallel()
func TestCreateAndOpen(t *testing.T) {
fast := testMemory()
slow := testMemory()
cache := New(nil)
@@ -54,23 +50,26 @@ func TestSetAndGet(t *testing.T) {
key := "test"
value := []byte("value")
if err := cache.Set(key, value); err != nil {
t.Fatalf("unexpected error: %v", err)
}
got, err := cache.Get(key)
w, err := cache.Create(key, int64(len(value)))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
w.Write(value)
w.Close()
rc, err := cache.Open(key)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
got, _ := io.ReadAll(rc)
rc.Close()
if string(got) != string(value) {
t.Fatalf("expected %s, got %s", value, got)
}
}
func TestSetAndGetNoFast(t *testing.T) {
t.Parallel()
func TestCreateAndOpenNoFast(t *testing.T) {
slow := testMemory()
cache := New(nil)
cache.SetSlow(slow)
@@ -78,22 +77,26 @@ func TestSetAndGetNoFast(t *testing.T) {
key := "test"
value := []byte("value")
if err := cache.Set(key, value); err != nil {
t.Fatalf("unexpected error: %v", err)
}
got, err := cache.Get(key)
w, err := cache.Create(key, int64(len(value)))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
w.Write(value)
w.Close()
rc, err := cache.Open(key)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
got, _ := io.ReadAll(rc)
rc.Close()
if string(got) != string(value) {
t.Fatalf("expected %s, got %s", value, got)
}
}
func TestCaching(t *testing.T) {
t.Parallel()
func TestCachingPromotion(t *testing.T) {
fast := testMemory()
slow := testMemory()
cache := New(func(fi *vfs.FileInfo, cs cachestate.CacheState) bool {
@@ -105,71 +108,42 @@ func TestCaching(t *testing.T) {
key := "test"
value := []byte("value")
if err := fast.Set(key, value); err != nil {
t.Fatalf("unexpected error: %v", err)
}
ws, _ := slow.Create(key, int64(len(value)))
ws.Write(value)
ws.Close()
if err := slow.Set(key, value); err != nil {
t.Fatalf("unexpected error: %v", err)
}
_, state, err := cache.GetS(key)
rc, err := cache.Open(key)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if state != cachestate.CacheStateHit {
t.Fatalf("expected %v, got %v", cachestate.CacheStateHit, state)
}
err = fast.Delete(key)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
got, state, err := cache.GetS(key)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if state != cachestate.CacheStateMiss {
t.Fatalf("expected %v, got %v", cachestate.CacheStateMiss, state)
}
got, _ := io.ReadAll(rc)
rc.Close()
if string(got) != string(value) {
t.Fatalf("expected %s, got %s", value, got)
}
err = cache.Delete(key)
// Check if promoted to fast
_, err = fast.Open(key)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
_, state, err = cache.GetS(key)
if !errors.Is(err, vfserror.ErrNotFound) {
t.Fatalf("expected %v, got %v", vfserror.ErrNotFound, err)
}
if state != cachestate.CacheStateNotFound {
t.Fatalf("expected %v, got %v", cachestate.CacheStateNotFound, state)
t.Error("Expected promotion to fast cache")
}
}
func TestGetNotFound(t *testing.T) {
t.Parallel()
func TestOpenNotFound(t *testing.T) {
fast := testMemory()
slow := testMemory()
cache := New(nil)
cache.SetFast(fast)
cache.SetSlow(slow)
_, err := cache.Get("nonexistent")
_, err := cache.Open("nonexistent")
if !errors.Is(err, vfserror.ErrNotFound) {
t.Fatalf("expected %v, got %v", vfserror.ErrNotFound, err)
}
}
func TestDelete(t *testing.T) {
t.Parallel()
fast := testMemory()
slow := testMemory()
cache := New(nil)
@@ -179,23 +153,24 @@ func TestDelete(t *testing.T) {
key := "test"
value := []byte("value")
if err := cache.Set(key, value); err != nil {
w, err := cache.Create(key, int64(len(value)))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
w.Write(value)
w.Close()
if err := cache.Delete(key); err != nil {
t.Fatalf("unexpected error: %v", err)
}
_, err := cache.Get(key)
_, err = cache.Open(key)
if !errors.Is(err, vfserror.ErrNotFound) {
t.Fatalf("expected %v, got %v", vfserror.ErrNotFound, err)
}
}
func TestStat(t *testing.T) {
t.Parallel()
fast := testMemory()
slow := testMemory()
cache := New(nil)
@@ -205,9 +180,12 @@ func TestStat(t *testing.T) {
key := "test"
value := []byte("value")
if err := cache.Set(key, value); err != nil {
w, err := cache.Create(key, int64(len(value)))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
w.Write(value)
w.Close()
info, err := cache.Stat(key)
if err != nil {
@@ -217,4 +195,7 @@ func TestStat(t *testing.T) {
if info == nil {
t.Fatal("expected file info to be non-nil")
}
if info.Size() != int64(len(value)) {
t.Errorf("expected size %d, got %d", len(value), info.Size())
}
}

View File

@@ -1,3 +1,4 @@
// vfs/cachestate/cachestate.go
package cachestate
import "s1d3sw1ped/SteamCache2/vfs/vfserror"

View File

@@ -1,7 +1,10 @@
// vfs/disk/disk.go
package disk
import (
"container/list"
"fmt"
"io"
"os"
"path/filepath"
"s1d3sw1ped/SteamCache2/steamcache/logger"
@@ -12,6 +15,38 @@ import (
"time"
"github.com/docker/go-units"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
var (
diskCapacityBytes = promauto.NewGauge(
prometheus.GaugeOpts{
Name: "disk_cache_capacity_bytes",
Help: "Total capacity of the disk cache in bytes",
},
)
diskSizeBytes = promauto.NewGauge(
prometheus.GaugeOpts{
Name: "disk_cache_size_bytes",
Help: "Total size of the disk cache in bytes",
},
)
diskReadBytes = promauto.NewCounter(
prometheus.CounterOpts{
Name: "disk_cache_read_bytes_total",
Help: "Total number of bytes read from the disk cache",
},
)
diskWriteBytes = promauto.NewCounter(
prometheus.CounterOpts{
Name: "disk_cache_write_bytes_total",
Help: "Total number of bytes written to the disk cache",
},
)
)
// Ensure DiskFS implements VFS.
@@ -23,10 +58,49 @@ type DiskFS struct {
info map[string]*vfs.FileInfo
capacity int64
mu sync.Mutex
sg sync.WaitGroup
size int64
mu sync.RWMutex
keyLocks sync.Map // map[string]*sync.RWMutex
LRU *lruList
}
bytePool sync.Pool // Pool for []byte slices
// lruList for LRU eviction
type lruList struct {
list *list.List
elem map[string]*list.Element
}
func newLruList() *lruList {
return &lruList{
list: list.New(),
elem: make(map[string]*list.Element),
}
}
func (l *lruList) MoveToFront(key string) {
if e, ok := l.elem[key]; ok {
l.list.MoveToFront(e)
}
}
func (l *lruList) Add(key string, fi *vfs.FileInfo) *list.Element {
e := l.list.PushFront(fi)
l.elem[key] = e
return e
}
func (l *lruList) Remove(key string) {
if e, ok := l.elem[key]; ok {
l.list.Remove(e)
delete(l.elem, key)
}
}
func (l *lruList) Back() *vfs.FileInfo {
if e := l.list.Back(); e != nil {
return e.Value.(*vfs.FileInfo)
}
return nil
}
// New creates a new DiskFS.
@@ -58,17 +132,18 @@ func new(root string, capacity int64, skipinit bool) *DiskFS {
root: root,
info: make(map[string]*vfs.FileInfo),
capacity: capacity,
mu: sync.Mutex{},
sg: sync.WaitGroup{},
bytePool: sync.Pool{
New: func() interface{} { return make([]byte, 0) }, // Initial capacity for pooled slices is 0, will grow as needed
},
mu: sync.RWMutex{},
keyLocks: sync.Map{},
LRU: newLruList(),
}
os.MkdirAll(dfs.root, 0755)
diskCapacityBytes.Set(float64(dfs.capacity))
if !skipinit {
dfs.init()
diskSizeBytes.Set(float64(dfs.Size()))
}
return dfs
@@ -85,8 +160,28 @@ func NewSkipInit(root string, capacity int64) *DiskFS {
func (d *DiskFS) init() {
tstart := time.Now()
d.walk(d.root)
d.sg.Wait()
err := filepath.Walk(d.root, func(npath string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if info.IsDir() {
return nil
}
d.mu.Lock()
k := strings.ReplaceAll(npath[len(d.root)+1:], "\\", "/")
fi := vfs.NewFileInfoFromOS(info, k)
d.info[k] = fi
d.LRU.Add(k, fi)
d.size += info.Size()
d.mu.Unlock()
return nil
})
if err != nil {
logger.Logger.Error().Err(err).Msg("Walk failed")
}
logger.Logger.Info().
Str("name", d.Name()).
@@ -98,34 +193,6 @@ func (d *DiskFS) init() {
Msg("init")
}
func (d *DiskFS) walk(path string) {
d.sg.Add(1)
go func() {
defer d.sg.Done()
filepath.Walk(path, func(npath string, info os.FileInfo, err error) error {
if path == npath {
return nil
}
if err != nil {
return err
}
if info.IsDir() {
d.walk(npath)
return filepath.SkipDir
}
d.mu.Lock()
k := strings.ReplaceAll(npath[len(d.root)+1:], "\\", "/")
d.info[k] = vfs.NewFileInfoFromOS(info, k)
d.mu.Unlock()
return nil
})
}()
}
func (d *DiskFS) Capacity() int64 {
return d.capacity
}
@@ -135,49 +202,114 @@ func (d *DiskFS) Name() string {
}
func (d *DiskFS) Size() int64 {
d.mu.Lock()
defer d.mu.Unlock()
var size int64
for _, v := range d.info {
size += v.Size()
}
return size
d.mu.RLock()
defer d.mu.RUnlock()
return d.size
}
func (d *DiskFS) Set(key string, src []byte) error {
func (d *DiskFS) getKeyLock(key string) *sync.RWMutex {
mu, _ := d.keyLocks.LoadOrStore(key, &sync.RWMutex{})
return mu.(*sync.RWMutex)
}
func (d *DiskFS) Create(key string, size int64) (io.WriteCloser, error) {
if key == "" {
return vfserror.ErrInvalidKey
return nil, vfserror.ErrInvalidKey
}
if key[0] == '/' {
return vfserror.ErrInvalidKey
return nil, vfserror.ErrInvalidKey
}
// Sanitize key to prevent path traversal
key = filepath.Clean(key)
key = strings.ReplaceAll(key, "\\", "/") // Ensure forward slashes for consistency
if strings.Contains(key, "..") {
return nil, vfserror.ErrInvalidKey
}
d.mu.RLock()
if d.capacity > 0 {
if size := d.Size() + int64(len(src)); size > d.capacity {
return vfserror.ErrDiskFull
if d.size+size > d.capacity {
d.mu.RUnlock()
return nil, vfserror.ErrDiskFull
}
}
d.mu.RUnlock()
if _, err := d.Stat(key); err == nil {
d.Delete(key)
}
keyMu := d.getKeyLock(key)
keyMu.Lock()
defer keyMu.Unlock()
// Check again after lock
d.mu.Lock()
defer d.mu.Unlock()
os.MkdirAll(d.root+"/"+filepath.Dir(key), 0755)
if err := os.WriteFile(d.root+"/"+key, src, 0644); err != nil {
return err
if fi, exists := d.info[key]; exists {
d.size -= fi.Size()
d.LRU.Remove(key)
delete(d.info, key)
path := filepath.Join(d.root, key)
os.Remove(path) // Ignore error, as file might not exist or other issues
}
d.mu.Unlock()
path := filepath.Join(d.root, key)
path = strings.ReplaceAll(path, "\\", "/") // Ensure forward slashes for consistency
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, err
}
fi, err := os.Stat(d.root + "/" + key)
file, err := os.Create(path)
if err != nil {
panic(err)
return nil, err
}
d.info[key] = vfs.NewFileInfoFromOS(fi, key)
return &diskWriteCloser{
Writer: file,
onClose: func(n int64) error {
fi, err := os.Stat(path)
if err != nil {
os.Remove(path)
return err
}
return nil
d.mu.Lock()
finfo := vfs.NewFileInfoFromOS(fi, key)
d.info[key] = finfo
d.LRU.Add(key, finfo)
d.size += n
d.mu.Unlock()
diskWriteBytes.Add(float64(n))
diskSizeBytes.Set(float64(d.Size()))
return nil
},
key: key,
file: file,
}, nil
}
type diskWriteCloser struct {
io.Writer
onClose func(int64) error
n int64
key string
file *os.File
}
func (wc *diskWriteCloser) Write(p []byte) (int, error) {
n, err := wc.Writer.Write(p)
wc.n += int64(n)
return n, err
}
func (wc *diskWriteCloser) Close() error {
err := wc.file.Close()
if e := wc.onClose(wc.n); e != nil {
os.Remove(wc.file.Name())
return e
}
return err
}
// Delete deletes the value of key.
@@ -189,24 +321,41 @@ func (d *DiskFS) Delete(key string) error {
return vfserror.ErrInvalidKey
}
_, err := d.Stat(key)
if err != nil {
return err
// Sanitize key to prevent path traversal
key = filepath.Clean(key)
key = strings.ReplaceAll(key, "\\", "/") // Ensure forward slashes for consistency
if strings.Contains(key, "..") {
return vfserror.ErrInvalidKey
}
keyMu := d.getKeyLock(key)
keyMu.Lock()
defer keyMu.Unlock()
d.mu.Lock()
defer d.mu.Unlock()
fi, exists := d.info[key]
if !exists {
d.mu.Unlock()
return vfserror.ErrNotFound
}
d.size -= fi.Size()
d.LRU.Remove(key)
delete(d.info, key)
if err := os.Remove(filepath.Join(d.root, key)); err != nil {
d.mu.Unlock()
path := filepath.Join(d.root, key)
path = strings.ReplaceAll(path, "\\", "/") // Ensure forward slashes for consistency
if err := os.Remove(path); err != nil {
return err
}
diskSizeBytes.Set(float64(d.Size()))
return nil
}
// Get gets the value of key and returns it.
func (d *DiskFS) Get(key string) ([]byte, error) {
// Open opens the file at key and returns it.
func (d *DiskFS) Open(key string) (io.ReadCloser, error) {
if key == "" {
return nil, vfserror.ErrInvalidKey
}
@@ -214,29 +363,59 @@ func (d *DiskFS) Get(key string) ([]byte, error) {
return nil, vfserror.ErrInvalidKey
}
_, err := d.Stat(key)
if err != nil {
return nil, err
// Sanitize key to prevent path traversal
key = filepath.Clean(key)
key = strings.ReplaceAll(key, "\\", "/") // Ensure forward slashes for consistency
if strings.Contains(key, "..") {
return nil, vfserror.ErrInvalidKey
}
keyMu := d.getKeyLock(key)
keyMu.RLock()
defer keyMu.RUnlock()
d.mu.Lock()
defer d.mu.Unlock()
fi, exists := d.info[key]
if !exists {
d.mu.Unlock()
return nil, vfserror.ErrNotFound
}
fi.ATime = time.Now()
d.LRU.MoveToFront(key)
d.mu.Unlock()
data, err := os.ReadFile(filepath.Join(d.root, key))
path := filepath.Join(d.root, key)
path = strings.ReplaceAll(path, "\\", "/") // Ensure forward slashes for consistency
file, err := os.Open(path)
if err != nil {
return nil, err
}
// Use pooled slice for return if possible, but since ReadFile allocates new, copy to pool if beneficial
dst := d.bytePool.Get().([]byte)
if cap(dst) < len(data) {
dst = make([]byte, len(data)) // create a new slice if the pool slice is too small
} else {
dst = dst[:len(data)] // reuse the pool slice, but resize it to fit
}
dst = dst[:len(data)]
copy(dst, data)
return dst, nil
// Update metrics on close
return &readCloser{
ReadCloser: file,
onClose: func(n int64) {
diskReadBytes.Add(float64(n))
},
}, nil
}
type readCloser struct {
io.ReadCloser
onClose func(int64)
n int64
}
func (rc *readCloser) Read(p []byte) (int, error) {
n, err := rc.ReadCloser.Read(p)
rc.n += int64(n)
return n, err
}
func (rc *readCloser) Close() error {
err := rc.ReadCloser.Close()
rc.onClose(rc.n)
return err
}
// Stat returns the FileInfo of key. If key is not found in the cache, it will stat the file on disk. If the file is not found on disk, it will return vfs.ErrNotFound.
@@ -248,8 +427,19 @@ func (d *DiskFS) Stat(key string) (*vfs.FileInfo, error) {
return nil, vfserror.ErrInvalidKey
}
d.mu.Lock()
defer d.mu.Unlock()
// Sanitize key to prevent path traversal
key = filepath.Clean(key)
key = strings.ReplaceAll(key, "\\", "/") // Ensure forward slashes for consistency
if strings.Contains(key, "..") {
return nil, vfserror.ErrInvalidKey
}
keyMu := d.getKeyLock(key)
keyMu.RLock()
defer keyMu.RUnlock()
d.mu.RLock()
defer d.mu.RUnlock()
if fi, ok := d.info[key]; !ok {
return nil, vfserror.ErrNotFound
@@ -258,13 +448,13 @@ func (d *DiskFS) Stat(key string) (*vfs.FileInfo, error) {
}
}
func (m *DiskFS) StatAll() []*vfs.FileInfo {
m.mu.Lock()
defer m.mu.Unlock()
func (d *DiskFS) StatAll() []*vfs.FileInfo {
d.mu.RLock()
defer d.mu.RUnlock()
// hard copy the file info to prevent modification of the original file info or the other way around
files := make([]*vfs.FileInfo, 0, len(m.info))
for _, v := range m.info {
files := make([]*vfs.FileInfo, 0, len(d.info))
for _, v := range d.info {
fi := *v
files = append(files, &fi)
}

View File

@@ -1,145 +1,180 @@
// vfs/disk/disk_test.go
package disk
import (
"errors"
"fmt"
"io"
"os"
"path/filepath"
"s1d3sw1ped/SteamCache2/vfs/vfserror"
"testing"
)
func TestAllDisk(t *testing.T) {
t.Parallel()
func TestCreateAndOpen(t *testing.T) {
m := NewSkipInit(t.TempDir(), 1024)
if err := m.Set("key", []byte("value")); err != nil {
t.Errorf("Set failed: %v", err)
}
key := "key"
value := []byte("value")
if err := m.Set("key", []byte("value1")); err != nil {
t.Errorf("Set failed: %v", err)
w, err := m.Create(key, int64(len(value)))
if err != nil {
t.Fatalf("Create failed: %v", err)
}
w.Write(value)
w.Close()
if d, err := m.Get("key"); err != nil {
t.Errorf("Get failed: %v", err)
} else if string(d) != "value1" {
t.Errorf("Get failed: got %s, want %s", d, "value1")
rc, err := m.Open(key)
if err != nil {
t.Fatalf("Open failed: %v", err)
}
got, _ := io.ReadAll(rc)
rc.Close()
if err := m.Delete("key"); err != nil {
t.Errorf("Delete failed: %v", err)
}
if _, err := m.Get("key"); err == nil {
t.Errorf("Get failed: got nil, want %v", vfserror.ErrNotFound)
}
if err := m.Delete("key"); err == nil {
t.Errorf("Delete failed: got nil, want %v", vfserror.ErrNotFound)
}
if _, err := m.Stat("key"); err == nil {
t.Errorf("Stat failed: got nil, want %v", vfserror.ErrNotFound)
}
if err := m.Set("key", []byte("value")); err != nil {
t.Errorf("Set failed: %v", err)
}
if _, err := m.Stat("key"); err != nil {
t.Errorf("Stat failed: %v", err)
if string(got) != string(value) {
t.Fatalf("expected %s, got %s", value, got)
}
}
func TestLimited(t *testing.T) {
t.Parallel()
func TestOverwrite(t *testing.T) {
m := NewSkipInit(t.TempDir(), 1024)
key := "key"
value1 := []byte("value1")
value2 := []byte("value2")
w, err := m.Create(key, int64(len(value1)))
if err != nil {
t.Fatalf("Create failed: %v", err)
}
w.Write(value1)
w.Close()
w, err = m.Create(key, int64(len(value2)))
if err != nil {
t.Fatalf("Create failed: %v", err)
}
w.Write(value2)
w.Close()
rc, err := m.Open(key)
if err != nil {
t.Fatalf("Open failed: %v", err)
}
got, _ := io.ReadAll(rc)
rc.Close()
if string(got) != string(value2) {
t.Fatalf("expected %s, got %s", value2, got)
}
}
func TestDelete(t *testing.T) {
m := NewSkipInit(t.TempDir(), 1024)
key := "key"
value := []byte("value")
w, err := m.Create(key, int64(len(value)))
if err != nil {
t.Fatalf("Create failed: %v", err)
}
w.Write(value)
w.Close()
if err := m.Delete(key); err != nil {
t.Fatalf("Delete failed: %v", err)
}
_, err = m.Open(key)
if !errors.Is(err, vfserror.ErrNotFound) {
t.Fatalf("expected %v, got %v", vfserror.ErrNotFound, err)
}
}
func TestCapacityLimit(t *testing.T) {
m := NewSkipInit(t.TempDir(), 10)
for i := 0; i < 11; i++ {
if err := m.Set(fmt.Sprintf("key%d", i), []byte("1")); err != nil && i < 10 {
t.Errorf("Set failed: %v", err)
w, err := m.Create(fmt.Sprintf("key%d", i), 1)
if err != nil && i < 10 {
t.Errorf("Create failed: %v", err)
} else if i == 10 && err == nil {
t.Errorf("Set succeeded: got nil, want %v", vfserror.ErrDiskFull)
t.Errorf("Create succeeded: got nil, want %v", vfserror.ErrDiskFull)
}
if i < 10 {
w.Write([]byte("1"))
w.Close()
}
}
}
func TestInit(t *testing.T) {
t.Parallel()
func TestInitExistingFiles(t *testing.T) {
td := t.TempDir()
path := filepath.Join(td, "test", "key")
os.MkdirAll(filepath.Dir(path), 0755)
os.WriteFile(path, []byte("value"), 0644)
m := New(td, 10)
if _, err := m.Get("test/key"); err != nil {
t.Errorf("Get failed: %v", err)
rc, err := m.Open("test/key")
if err != nil {
t.Fatalf("Open failed: %v", err)
}
got, _ := io.ReadAll(rc)
rc.Close()
if string(got) != "value" {
t.Errorf("expected value, got %s", got)
}
s, _ := m.Stat("test/key")
if s.Name() != "test/key" {
t.Errorf("Stat failed: got %s, want %s", s.Name(), "key")
s, err := m.Stat("test/key")
if err != nil {
t.Fatalf("Stat failed: %v", err)
}
if s == nil {
t.Error("Stat returned nil")
}
if s != nil && s.Name() != "test/key" {
t.Errorf("Stat failed: got %s, want %s", s.Name(), "test/key")
}
}
func TestDiskSizeDiscrepancy(t *testing.T) {
t.Parallel()
func TestSizeConsistency(t *testing.T) {
td := t.TempDir()
assumedSize := int64(6 + 5 + 6) // 6 + 5 + 6 bytes for key, key1, key2
os.WriteFile(filepath.Join(td, "key2"), []byte("value2"), 0644)
m := New(td, 1024)
if 6 != m.Size() {
t.Errorf("Size failed: got %d, want %d", m.Size(), 6)
if m.Size() != 6 {
t.Errorf("Size failed: got %d, want 6", m.Size())
}
if err := m.Set("key", []byte("value")); err != nil {
t.Errorf("Set failed: %v", err)
w, err := m.Create("key", 5)
if err != nil {
t.Errorf("Create failed: %v", err)
}
w.Write([]byte("value"))
w.Close()
if err := m.Set("key1", []byte("value1")); err != nil {
t.Errorf("Set failed: %v", err)
w, err = m.Create("key1", 6)
if err != nil {
t.Errorf("Create failed: %v", err)
}
w.Write([]byte("value1"))
w.Close()
assumedSize := int64(6 + 5 + 6)
if assumedSize != m.Size() {
t.Errorf("Size failed: got %d, want %d", m.Size(), assumedSize)
}
if d, err := m.Get("key"); err != nil {
t.Errorf("Get failed: %v", err)
} else if string(d) != "value" {
t.Errorf("Get failed: got %s, want %s", d, "value")
rc, err := m.Open("key")
if err != nil {
t.Errorf("Open failed: %v", err)
}
if d, err := m.Get("key1"); err != nil {
t.Errorf("Get failed: %v", err)
} else if string(d) != "value1" {
t.Errorf("Get failed: got %s, want %s", d, "value1")
d, _ := io.ReadAll(rc)
rc.Close()
if string(d) != "value" {
t.Errorf("Get failed: got %s, want value", d)
}
m = New(td, 1024)
if assumedSize != m.Size() {
t.Errorf("Size failed: got %d, want %d", m.Size(), assumedSize)
}
if d, err := m.Get("key"); err != nil {
t.Errorf("Get failed: %v", err)
} else if string(d) != "value" {
t.Errorf("Get failed: got %s, want %s", d, "value")
}
if d, err := m.Get("key1"); err != nil {
t.Errorf("Get failed: %v", err)
} else if string(d) != "value1" {
t.Errorf("Get failed: got %s, want %s", d, "value1")
}
if assumedSize != m.Size() {
t.Errorf("Size failed: got %d, want %d", m.Size(), assumedSize)
}

View File

@@ -1,3 +1,4 @@
// vfs/fileinfo.go
package vfs
import (

View File

@@ -1,13 +1,79 @@
// vfs/gc/gc.go
package gc
import (
"fmt"
"io"
"s1d3sw1ped/SteamCache2/steamcache/logger"
"s1d3sw1ped/SteamCache2/vfs"
"s1d3sw1ped/SteamCache2/vfs/cachestate"
"s1d3sw1ped/SteamCache2/vfs/disk"
"s1d3sw1ped/SteamCache2/vfs/memory"
"s1d3sw1ped/SteamCache2/vfs/vfserror"
"sync"
"time"
)
// LRUGC deletes files in LRU order until enough space is reclaimed.
func LRUGC(vfss vfs.VFS, size uint) {
attempts := 0
deletions := 0
var reclaimed uint
for reclaimed < size {
if attempts > 10 {
logger.Logger.Debug().
Int("attempts", attempts).
Msg("GC: Too many attempts to reclaim space, giving up")
return
}
attempts++
switch fs := vfss.(type) {
case *disk.DiskFS:
fi := fs.LRU.Back()
if fi == nil {
break
}
sz := uint(fi.Size())
err := fs.Delete(fi.Name())
if err != nil {
continue
}
reclaimed += sz
deletions++
case *memory.MemoryFS:
fi := fs.LRU.Back()
if fi == nil {
break
}
sz := uint(fi.Size())
err := fs.Delete(fi.Name())
if err != nil {
continue
}
reclaimed += sz
deletions++
default:
// Fallback to old method if not supported
stats := vfss.StatAll()
if len(stats) == 0 {
break
}
fi := stats[0] // Assume sorted or pick first
sz := uint(fi.Size())
err := vfss.Delete(fi.Name())
if err != nil {
continue
}
reclaimed += sz
deletions++
}
}
}
func PromotionDecider(fi *vfs.FileInfo, cs cachestate.CacheState) bool {
return time.Since(fi.AccessTime()) < time.Second*60 // Put hot files in the fast vfs if equipped
}
// Ensure GCFS implements VFS.
var _ vfs.VFS = (*GCFS)(nil)
@@ -17,15 +83,11 @@ type GCFS struct {
multiplier int
// protected by mu
gcHanderFunc GCHandlerFunc
lifetimeBytes, lifetimeFiles uint
reclaimedBytes, deletedFiles uint
gcTime time.Duration
mu sync.Mutex
gcHanderFunc GCHandlerFunc
}
// GCHandlerFunc is a function that is called when the disk is full and the GCFS needs to free up space. It is passed the VFS and the size of the file that needs to be written. Its up to the implementation to free up space. How much space is freed is also up to the implementation.
type GCHandlerFunc func(vfs vfs.VFS, size uint) (reclaimedBytes uint, deletedFiles uint)
type GCHandlerFunc func(vfs vfs.VFS, size uint)
func New(vfs vfs.VFS, multiplier int, gcHandlerFunc GCHandlerFunc) *GCFS {
if multiplier <= 0 {
@@ -38,47 +100,17 @@ func New(vfs vfs.VFS, multiplier int, gcHandlerFunc GCHandlerFunc) *GCFS {
}
}
// Stats returns the lifetime bytes, lifetime files, reclaimed bytes and deleted files.
// The lifetime bytes and lifetime files are the total bytes and files that have been freed up by the GC handler.
// The reclaimed bytes and deleted files are the bytes and files that have been freed up by the GC handler since last call to Stats.
// The gc time is the total time spent in the GC handler since last call to Stats.
// The reclaimed bytes and deleted files and gc time are reset to 0 after the call to Stats.
func (g *GCFS) Stats() (lifetimeBytes, lifetimeFiles, reclaimedBytes, deletedFiles uint, gcTime time.Duration) {
g.mu.Lock()
defer g.mu.Unlock()
g.lifetimeBytes += g.reclaimedBytes
g.lifetimeFiles += g.deletedFiles
lifetimeBytes = g.lifetimeBytes
lifetimeFiles = g.lifetimeFiles
reclaimedBytes = g.reclaimedBytes
deletedFiles = g.deletedFiles
gcTime = g.gcTime
g.reclaimedBytes = 0
g.deletedFiles = 0
g.gcTime = time.Duration(0)
return
}
// Set overrides the Set method of the VFS interface. It tries to set the key and src, if it fails due to disk full error, it calls the GC handler and tries again. If it still fails it returns the error.
func (g *GCFS) Set(key string, src []byte) error {
g.mu.Lock()
defer g.mu.Unlock()
err := g.VFS.Set(key, src) // try to set the key and src
// Create overrides the Create method of the VFS interface. It tries to create the key, if it fails due to disk full error, it calls the GC handler and tries again. If it still fails it returns the error.
func (g *GCFS) Create(key string, size int64) (io.WriteCloser, error) {
w, err := g.VFS.Create(key, size) // try to create the key
// if it fails due to disk full error, call the GC handler and try again in a loop that will continue until it succeeds or the error is not disk full
if err == vfserror.ErrDiskFull && g.gcHanderFunc != nil { // if the error is disk full and there is a GC handler
tstart := time.Now()
reclaimedBytes, deletedFiles := g.gcHanderFunc(g.VFS, uint(len(src)*g.multiplier)) // call the GC handler
g.gcTime += time.Since(tstart)
g.reclaimedBytes += reclaimedBytes
g.deletedFiles += deletedFiles
err = g.VFS.Set(key, src) // try again after GC if it still fails return the error
g.gcHanderFunc(g.VFS, uint(size*int64(g.multiplier))) // call the GC handler
w, err = g.VFS.Create(key, size)
}
return err
return w, err
}
func (g *GCFS) Name() string {

View File

@@ -1,105 +1,73 @@
// vfs/gc/gc_test.go
package gc
import (
"errors"
"fmt"
"s1d3sw1ped/SteamCache2/vfs"
"s1d3sw1ped/SteamCache2/vfs/memory"
"sort"
"s1d3sw1ped/SteamCache2/vfs/vfserror"
"testing"
"golang.org/x/exp/rand"
)
func TestGCSmallRandom(t *testing.T) {
t.Parallel()
func TestGCOnFull(t *testing.T) {
m := memory.New(10)
gc := New(m, 2, LRUGC)
m := memory.New(1024 * 1024 * 16)
gc := New(m, 10, func(vfs vfs.VFS, size uint) (uint, uint) {
deletions := 0
var reclaimed uint
t.Logf("GC starting to reclaim %d bytes", size)
stats := vfs.StatAll()
sort.Slice(stats, func(i, j int) bool {
// Sort by access time so we can remove the oldest files first.
return stats[i].AccessTime().Before(stats[j].AccessTime())
})
// Delete the oldest files until we've reclaimed enough space.
for _, s := range stats {
sz := uint(s.Size()) // Get the size of the file
err := vfs.Delete(s.Name())
if err != nil {
panic(err)
}
reclaimed += sz // Track how much space we've reclaimed
deletions++ // Track how many files we've deleted
// t.Logf("GC deleting %s, %v", s.Name(), s.AccessTime().Format(time.RFC3339Nano))
if reclaimed >= size { // We've reclaimed enough space
break
}
}
return uint(reclaimed), uint(deletions)
})
for i := 0; i < 10000; i++ {
if err := gc.Set(fmt.Sprintf("key:%d", i), genRandomData(1024*1, 1024*4)); err != nil {
t.Errorf("Set failed: %v", err)
for i := 0; i < 5; i++ {
w, err := gc.Create(fmt.Sprintf("key%d", i), 2)
if err != nil {
t.Fatalf("Create failed: %v", err)
}
w.Write([]byte("ab"))
w.Close()
}
if gc.Size() > 1024*1024*16 {
t.Errorf("MemoryFS size is %d, want <= 1024", m.Size())
// Cache full at 10 bytes
w, err := gc.Create("key5", 2)
if err != nil {
t.Fatalf("Create failed: %v", err)
}
w.Write([]byte("cd"))
w.Close()
if gc.Size() > 10 {
t.Errorf("Size exceeded: %d > 10", gc.Size())
}
// Check if older keys were evicted
_, err = m.Open("key0")
if err == nil {
t.Error("Expected key0 to be evicted")
}
}
func genRandomData(min int, max int) []byte {
data := make([]byte, rand.Intn(max-min)+min)
rand.Read(data)
return data
}
func TestNoGCNeeded(t *testing.T) {
m := memory.New(20)
gc := New(m, 2, LRUGC)
func TestGCLargeRandom(t *testing.T) {
t.Parallel()
m := memory.New(1024 * 1024 * 16) // 16MB
gc := New(m, 10, func(vfs vfs.VFS, size uint) (uint, uint) {
deletions := 0
var reclaimed uint
t.Logf("GC starting to reclaim %d bytes", size)
stats := vfs.StatAll()
sort.Slice(stats, func(i, j int) bool {
// Sort by access time so we can remove the oldest files first.
return stats[i].AccessTime().Before(stats[j].AccessTime())
})
// Delete the oldest files until we've reclaimed enough space.
for _, s := range stats {
sz := uint(s.Size()) // Get the size of the file
vfs.Delete(s.Name())
reclaimed += sz // Track how much space we've reclaimed
deletions++ // Track how many files we've deleted
if reclaimed >= size { // We've reclaimed enough space
break
}
}
return uint(reclaimed), uint(deletions)
})
for i := 0; i < 10000; i++ {
if err := gc.Set(fmt.Sprintf("key:%d", i), genRandomData(1024, 1024*1024)); err != nil {
t.Errorf("Set failed: %v", err)
for i := 0; i < 5; i++ {
w, err := gc.Create(fmt.Sprintf("key%d", i), 2)
if err != nil {
t.Fatalf("Create failed: %v", err)
}
w.Write([]byte("ab"))
w.Close()
}
if gc.Size() > 1024*1024*16 {
t.Errorf("MemoryFS size is %d, want <= 1024", m.Size())
if gc.Size() != 10 {
t.Errorf("Size: got %d, want 10", gc.Size())
}
}
func TestGCInsufficientSpace(t *testing.T) {
m := memory.New(5)
gc := New(m, 1, LRUGC)
w, err := gc.Create("key0", 10)
if err == nil {
w.Close()
t.Error("Expected ErrDiskFull")
} else if !errors.Is(err, vfserror.ErrDiskFull) {
t.Errorf("Unexpected error: %v", err)
}
}

View File

@@ -1,6 +1,10 @@
// vfs/memory/memory.go
package memory
import (
"bytes"
"container/list"
"io"
"s1d3sw1ped/SteamCache2/steamcache/logger"
"s1d3sw1ped/SteamCache2/vfs"
"s1d3sw1ped/SteamCache2/vfs/vfserror"
@@ -8,6 +12,38 @@ import (
"time"
"github.com/docker/go-units"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
var (
memoryCapacityBytes = promauto.NewGauge(
prometheus.GaugeOpts{
Name: "memory_cache_capacity_bytes",
Help: "Total capacity of the memory cache in bytes",
},
)
memorySizeBytes = promauto.NewGauge(
prometheus.GaugeOpts{
Name: "memory_cache_size_bytes",
Help: "Total size of the memory cache in bytes",
},
)
memoryReadBytes = promauto.NewCounter(
prometheus.CounterOpts{
Name: "memory_cache_read_bytes_total",
Help: "Total number of bytes read from the memory cache",
},
)
memoryWriteBytes = promauto.NewCounter(
prometheus.CounterOpts{
Name: "memory_cache_write_bytes_total",
Help: "Total number of bytes written to the memory cache",
},
)
)
// Ensure MemoryFS implements VFS.
@@ -23,9 +59,49 @@ type file struct {
type MemoryFS struct {
files map[string]*file
capacity int64
mu sync.Mutex
size int64
mu sync.RWMutex
keyLocks sync.Map // map[string]*sync.RWMutex
LRU *lruList
}
bytePool sync.Pool // Pool for []byte slices
// lruList for LRU eviction
type lruList struct {
list *list.List
elem map[string]*list.Element
}
func newLruList() *lruList {
return &lruList{
list: list.New(),
elem: make(map[string]*list.Element),
}
}
func (l *lruList) MoveToFront(key string) {
if e, ok := l.elem[key]; ok {
l.list.MoveToFront(e)
}
}
func (l *lruList) Add(key string, fi *vfs.FileInfo) *list.Element {
e := l.list.PushFront(fi)
l.elem[key] = e
return e
}
func (l *lruList) Remove(key string) {
if e, ok := l.elem[key]; ok {
l.list.Remove(e)
delete(l.elem, key)
}
}
func (l *lruList) Back() *vfs.FileInfo {
if e := l.list.Back(); e != nil {
return e.Value.(*vfs.FileInfo)
}
return nil
}
// New creates a new MemoryFS.
@@ -39,14 +115,18 @@ func New(capacity int64) *MemoryFS {
Str("capacity", units.HumanSize(float64(capacity))).
Msg("init")
return &MemoryFS{
mfs := &MemoryFS{
files: make(map[string]*file),
capacity: capacity,
mu: sync.Mutex{},
bytePool: sync.Pool{
New: func() interface{} { return make([]byte, 0) }, // Initial capacity for pooled slices
},
mu: sync.RWMutex{},
keyLocks: sync.Map{},
LRU: newLruList(),
}
memoryCapacityBytes.Set(float64(capacity))
memorySizeBytes.Set(float64(mfs.Size()))
return mfs
}
func (m *MemoryFS) Capacity() int64 {
@@ -58,93 +138,118 @@ func (m *MemoryFS) Name() string {
}
func (m *MemoryFS) Size() int64 {
var size int64
m.mu.Lock()
defer m.mu.Unlock()
for _, v := range m.files {
size += int64(len(v.data))
}
return size
m.mu.RLock()
defer m.mu.RUnlock()
return m.size
}
func (m *MemoryFS) Set(key string, src []byte) error {
func (m *MemoryFS) getKeyLock(key string) *sync.RWMutex {
mu, _ := m.keyLocks.LoadOrStore(key, &sync.RWMutex{})
return mu.(*sync.RWMutex)
}
func (m *MemoryFS) Create(key string, size int64) (io.WriteCloser, error) {
m.mu.RLock()
if m.capacity > 0 {
if size := m.Size() + int64(len(src)); size > m.capacity {
return vfserror.ErrDiskFull
if m.size+size > m.capacity {
m.mu.RUnlock()
return nil, vfserror.ErrDiskFull
}
}
m.mu.RUnlock()
m.mu.Lock()
defer m.mu.Unlock()
keyMu := m.getKeyLock(key)
keyMu.Lock()
defer keyMu.Unlock()
// Use pooled slice
data := m.bytePool.Get().([]byte)
if cap(data) < len(src) {
data = make([]byte, len(src)) // expand the slice if the pool slice is too small
} else {
data = data[:len(src)] // reuse the pool slice, but resize it to fit
}
copy(data, src)
buf := &bytes.Buffer{}
m.files[key] = &file{
fileinfo: vfs.NewFileInfo(
key,
int64(len(src)),
time.Now(),
),
data: data,
}
return &memWriteCloser{
Writer: buf,
onClose: func() error {
data := buf.Bytes()
m.mu.Lock()
if f, exists := m.files[key]; exists {
m.size -= int64(len(f.data))
m.LRU.Remove(key)
}
fi := vfs.NewFileInfo(key, int64(len(data)), time.Now())
m.files[key] = &file{
fileinfo: fi,
data: data,
}
m.LRU.Add(key, fi)
m.size += int64(len(data))
m.mu.Unlock()
return nil
memoryWriteBytes.Add(float64(len(data)))
memorySizeBytes.Set(float64(m.Size()))
return nil
},
}, nil
}
type memWriteCloser struct {
io.Writer
onClose func() error
}
func (wc *memWriteCloser) Close() error {
return wc.onClose()
}
func (m *MemoryFS) Delete(key string) error {
_, err := m.Stat(key)
if err != nil {
return err
}
keyMu := m.getKeyLock(key)
keyMu.Lock()
defer keyMu.Unlock()
m.mu.Lock()
defer m.mu.Unlock()
// Return data to pool
if f, ok := m.files[key]; ok {
m.bytePool.Put(f.data)
f, exists := m.files[key]
if !exists {
m.mu.Unlock()
return vfserror.ErrNotFound
}
m.size -= int64(len(f.data))
m.LRU.Remove(key)
delete(m.files, key)
m.mu.Unlock()
memorySizeBytes.Set(float64(m.Size()))
return nil
}
func (m *MemoryFS) Get(key string) ([]byte, error) {
_, err := m.Stat(key)
if err != nil {
return nil, err
}
func (m *MemoryFS) Open(key string) (io.ReadCloser, error) {
keyMu := m.getKeyLock(key)
keyMu.RLock()
defer keyMu.RUnlock()
m.mu.Lock()
defer m.mu.Unlock()
f, exists := m.files[key]
if !exists {
m.mu.Unlock()
return nil, vfserror.ErrNotFound
}
f.fileinfo.ATime = time.Now()
m.LRU.MoveToFront(key)
dataCopy := make([]byte, len(f.data))
copy(dataCopy, f.data)
m.mu.Unlock()
m.files[key].fileinfo.ATime = time.Now()
dst := make([]byte, len(m.files[key].data))
copy(dst, m.files[key].data)
memoryReadBytes.Add(float64(len(dataCopy)))
memorySizeBytes.Set(float64(m.Size()))
logger.Logger.Debug().
Str("name", key).
Str("status", "GET").
Int64("size", int64(len(dst))).
Msg("get file from memory")
return dst, nil
return io.NopCloser(bytes.NewReader(dataCopy)), nil
}
func (m *MemoryFS) Stat(key string) (*vfs.FileInfo, error) {
m.mu.Lock()
defer m.mu.Unlock()
keyMu := m.getKeyLock(key)
keyMu.RLock()
defer keyMu.RUnlock()
m.mu.RLock()
defer m.mu.RUnlock()
f, ok := m.files[key]
if !ok {
@@ -155,8 +260,8 @@ func (m *MemoryFS) Stat(key string) (*vfs.FileInfo, error) {
}
func (m *MemoryFS) StatAll() []*vfs.FileInfo {
m.mu.Lock()
defer m.mu.Unlock()
m.mu.RLock()
defer m.mu.RUnlock()
// hard copy the file info to prevent modification of the original file info or the other way around
files := make([]*vfs.FileInfo, 0, len(m.files))

View File

@@ -1,63 +1,129 @@
// vfs/memory/memory_test.go
package memory
import (
"errors"
"fmt"
"io"
"s1d3sw1ped/SteamCache2/vfs/vfserror"
"testing"
)
func TestAllMemory(t *testing.T) {
t.Parallel()
func TestCreateAndOpen(t *testing.T) {
m := New(1024)
if err := m.Set("key", []byte("value")); err != nil {
t.Errorf("Set failed: %v", err)
}
key := "key"
value := []byte("value")
if err := m.Set("key", []byte("value1")); err != nil {
t.Errorf("Set failed: %v", err)
w, err := m.Create(key, int64(len(value)))
if err != nil {
t.Fatalf("Create failed: %v", err)
}
w.Write(value)
w.Close()
if d, err := m.Get("key"); err != nil {
t.Errorf("Get failed: %v", err)
} else if string(d) != "value1" {
t.Errorf("Get failed: got %s, want %s", d, "value1")
rc, err := m.Open(key)
if err != nil {
t.Fatalf("Open failed: %v", err)
}
got, _ := io.ReadAll(rc)
rc.Close()
if err := m.Delete("key"); err != nil {
t.Errorf("Delete failed: %v", err)
}
if _, err := m.Get("key"); err == nil {
t.Errorf("Get failed: got nil, want %v", vfserror.ErrNotFound)
}
if err := m.Delete("key"); err == nil {
t.Errorf("Delete failed: got nil, want %v", vfserror.ErrNotFound)
}
if _, err := m.Stat("key"); err == nil {
t.Errorf("Stat failed: got nil, want %v", vfserror.ErrNotFound)
}
if err := m.Set("key", []byte("value")); err != nil {
t.Errorf("Set failed: %v", err)
}
if _, err := m.Stat("key"); err != nil {
t.Errorf("Stat failed: %v", err)
if string(got) != string(value) {
t.Fatalf("expected %s, got %s", value, got)
}
}
func TestLimited(t *testing.T) {
t.Parallel()
func TestOverwrite(t *testing.T) {
m := New(1024)
key := "key"
value1 := []byte("value1")
value2 := []byte("value2")
w, err := m.Create(key, int64(len(value1)))
if err != nil {
t.Fatalf("Create failed: %v", err)
}
w.Write(value1)
w.Close()
w, err = m.Create(key, int64(len(value2)))
if err != nil {
t.Fatalf("Create failed: %v", err)
}
w.Write(value2)
w.Close()
rc, err := m.Open(key)
if err != nil {
t.Fatalf("Open failed: %v", err)
}
got, _ := io.ReadAll(rc)
rc.Close()
if string(got) != string(value2) {
t.Fatalf("expected %s, got %s", value2, got)
}
}
func TestDelete(t *testing.T) {
m := New(1024)
key := "key"
value := []byte("value")
w, err := m.Create(key, int64(len(value)))
if err != nil {
t.Fatalf("Create failed: %v", err)
}
w.Write(value)
w.Close()
if err := m.Delete(key); err != nil {
t.Fatalf("Delete failed: %v", err)
}
_, err = m.Open(key)
if !errors.Is(err, vfserror.ErrNotFound) {
t.Fatalf("expected %v, got %v", vfserror.ErrNotFound, err)
}
}
func TestCapacityLimit(t *testing.T) {
m := New(10)
for i := 0; i < 11; i++ {
if err := m.Set(fmt.Sprintf("key%d", i), []byte("1")); err != nil && i < 10 {
t.Errorf("Set failed: %v", err)
w, err := m.Create(fmt.Sprintf("key%d", i), 1)
if err != nil && i < 10 {
t.Errorf("Create failed: %v", err)
} else if i == 10 && err == nil {
t.Errorf("Set succeeded: got nil, want %v", vfserror.ErrDiskFull)
t.Errorf("Create succeeded: got nil, want %v", vfserror.ErrDiskFull)
}
if i < 10 {
w.Write([]byte("1"))
w.Close()
}
}
}
func TestStat(t *testing.T) {
m := New(1024)
key := "key"
value := []byte("value")
w, err := m.Create(key, int64(len(value)))
if err != nil {
t.Fatalf("Create failed: %v", err)
}
w.Write(value)
w.Close()
info, err := m.Stat(key)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if info == nil {
t.Fatal("expected file info to be non-nil")
}
if info.Size() != int64(len(value)) {
t.Errorf("expected size %d, got %d", len(value), info.Size())
}
}

View File

@@ -1,76 +0,0 @@
package sync
// import (
// "fmt"
// "s1d3sw1ped/SteamCache2/vfs"
// "sync"
// )
// // Ensure SyncFS implements VFS.
// var _ vfs.VFS = (*SyncFS)(nil)
// type SyncFS struct {
// vfs vfs.VFS
// mu sync.RWMutex
// }
// func New(vfs vfs.VFS) *SyncFS {
// return &SyncFS{
// vfs: vfs,
// mu: sync.RWMutex{},
// }
// }
// // Name returns the name of the file system.
// func (sfs *SyncFS) Name() string {
// return fmt.Sprintf("SyncFS(%s)", sfs.vfs.Name())
// }
// // Size returns the total size of all files in the file system.
// func (sfs *SyncFS) Size() int64 {
// sfs.mu.RLock()
// defer sfs.mu.RUnlock()
// return sfs.vfs.Size()
// }
// // Set sets the value of key as src.
// // Setting the same key multiple times, the last set call takes effect.
// func (sfs *SyncFS) Set(key string, src []byte) error {
// sfs.mu.Lock()
// defer sfs.mu.Unlock()
// return sfs.vfs.Set(key, src)
// }
// // Delete deletes the value of key.
// func (sfs *SyncFS) Delete(key string) error {
// sfs.mu.Lock()
// defer sfs.mu.Unlock()
// return sfs.vfs.Delete(key)
// }
// // Get gets the value of key to dst, and returns dst no matter whether or not there is an error.
// func (sfs *SyncFS) Get(key string) ([]byte, error) {
// sfs.mu.RLock()
// defer sfs.mu.RUnlock()
// return sfs.vfs.Get(key)
// }
// // Stat returns the FileInfo of key.
// func (sfs *SyncFS) Stat(key string) (*vfs.FileInfo, error) {
// sfs.mu.RLock()
// defer sfs.mu.RUnlock()
// return sfs.vfs.Stat(key)
// }
// // StatAll returns the FileInfo of all keys.
// func (sfs *SyncFS) StatAll() []*vfs.FileInfo {
// sfs.mu.RLock()
// defer sfs.mu.RUnlock()
// return sfs.vfs.StatAll()
// }

View File

@@ -1,5 +1,8 @@
// vfs/vfs.go
package vfs
import "io"
// VFS is the interface that wraps the basic methods of a virtual file system.
type VFS interface {
// Name returns the name of the file system.
@@ -8,15 +11,14 @@ type VFS interface {
// Size returns the total size of all files in the file system.
Size() int64
// Set sets the value of key as src.
// Setting the same key multiple times, the last set call takes effect.
Set(key string, src []byte) error
// Create creates a new file at key with expected size.
Create(key string, size int64) (io.WriteCloser, error)
// Delete deletes the value of key.
Delete(key string) error
// Get gets the value of key to dst, and returns dst no matter whether or not there is an error.
Get(key string) ([]byte, error)
// Open opens the file at key.
Open(key string) (io.ReadCloser, error)
// Stat returns the FileInfo of key.
Stat(key string) (*FileInfo, error)

View File

@@ -1,3 +1,4 @@
// vfs/vfserror/vfserror.go
package vfserror
import "errors"