package encryption import ( "crypto/aes" "crypto/cipher" "crypto/rand" "crypto/sha256" "crypto/subtle" "fmt" "io" "math" "sync" "time" "golang.org/x/crypto/pbkdf2" ) const ( // PBKDF2 parameters for key derivation PBKDF2Iterations = 100000 // OWASP recommended minimum PBKDF2KeyLength = 32 // 256 bits PBKDF2SaltLength = 16 // 128 bits // Replay protection parameters MaxPacketAge = 5 * time.Minute // Maximum age for UDP packets NonceWindow = 1000 // Number of nonces to track for replay protection ) // DeriveKey derives an encryption key from a password using PBKDF2 func DeriveKey(password string) []byte { // Use a deterministic salt derived from the password hash for consistent key derivation // This ensures the same password always produces the same key while avoiding rainbow tables hasher := sha256.New() hasher.Write([]byte(password)) passwordHash := hasher.Sum(nil) // Create a deterministic salt from the password hash salt := make([]byte, PBKDF2SaltLength) copy(salt, passwordHash[:PBKDF2SaltLength]) key := pbkdf2.Key([]byte(password), salt, PBKDF2Iterations, PBKDF2KeyLength, sha256.New) return key } // DeriveKeyWithSalt derives an encryption key from a password using PBKDF2 with a custom salt func DeriveKeyWithSalt(password string, salt []byte) ([]byte, error) { if len(salt) != PBKDF2SaltLength { return nil, fmt.Errorf("salt length must be %d bytes, got %d", PBKDF2SaltLength, len(salt)) } key := pbkdf2.Key([]byte(password), salt, PBKDF2Iterations, PBKDF2KeyLength, sha256.New) return key, nil } // GenerateSalt generates a cryptographically secure random salt func GenerateSalt() ([]byte, error) { salt := make([]byte, PBKDF2SaltLength) _, err := io.ReadFull(rand.Reader, salt) return salt, err } // EncryptData encrypts data using AES-GCM func EncryptData(data []byte, key []byte) ([]byte, error) { block, err := aes.NewCipher(key) if err != nil { return nil, err } gcm, err := cipher.NewGCM(block) if err != nil { return nil, err } nonce := make([]byte, gcm.NonceSize()) if _, err = io.ReadFull(rand.Reader, nonce); err != nil { return nil, err } ciphertext := gcm.Seal(nonce, nonce, data, nil) return ciphertext, nil } // DecryptData decrypts data using AES-GCM func DecryptData(data []byte, key []byte) ([]byte, error) { block, err := aes.NewCipher(key) if err != nil { return nil, err } gcm, err := cipher.NewGCM(block) if err != nil { return nil, err } nonceSize := gcm.NonceSize() if len(data) < nonceSize { return nil, fmt.Errorf("ciphertext too short") } nonce, ciphertext := data[:nonceSize], data[nonceSize:] plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) if err != nil { return nil, err } return plaintext, nil } // ReplayProtection tracks nonces to prevent replay attacks type ReplayProtection struct { nonces map[uint64]time.Time mutex sync.RWMutex maxNonces int // Maximum number of nonces to track } // NewReplayProtection creates a new replay protection instance func NewReplayProtection() *ReplayProtection { return &ReplayProtection{ nonces: make(map[uint64]time.Time), maxNonces: NonceWindow * 2, // Allow 2x the window size for safety } } // IsValidNonce checks if a nonce is valid (not replayed and not too old) func (rp *ReplayProtection) IsValidNonce(nonce uint64, timestamp int64) bool { rp.mutex.Lock() defer rp.mutex.Unlock() now := time.Now() packetTime := time.Unix(timestamp, 0) // Check if packet is too old if now.Sub(packetTime) > MaxPacketAge { return false } // Check if nonce was already used if _, exists := rp.nonces[nonce]; exists { return false } // Add nonce to tracking rp.nonces[nonce] = now // Clean up old nonces to prevent memory leaks if len(rp.nonces) > rp.maxNonces { rp.cleanupOldNonces(now) } return true } // cleanupOldNonces removes nonces older than MaxPacketAge func (rp *ReplayProtection) cleanupOldNonces(now time.Time) { for nonce, timestamp := range rp.nonces { if now.Sub(timestamp) > MaxPacketAge { delete(rp.nonces, nonce) } } } // ValidatePacketTimestamp validates that a packet timestamp is within acceptable range func ValidatePacketTimestamp(timestamp int64) bool { now := time.Now().Unix() packetTime := timestamp // Check if packet is too old or from the future age := now - packetTime return age >= 0 && age <= int64(MaxPacketAge.Seconds()) } // ConstantTimeCompare performs a constant-time comparison of two byte slices func ConstantTimeCompare(a, b []byte) bool { return subtle.ConstantTimeCompare(a, b) == 1 } // ValidateEncryptionKey validates that an encryption key meets security requirements func ValidateEncryptionKey(key string) error { if len(key) < 32 { return fmt.Errorf("encryption key must be at least 32 characters long") } // Check for common weak keys weakKeys := []string{ "password", "123456", "admin", "test", "default", "your-secure-encryption-key-change-this-to-something-random", "test-encryption-key-12345", "teleport-key", "secret", } for _, weak := range weakKeys { if key == weak { return fmt.Errorf("encryption key is too weak, please use a strong random key") } } // Check for sufficient entropy (basic check) entropy := calculateEntropy(key) if entropy < 3.5 { // Minimum entropy threshold return fmt.Errorf("encryption key has insufficient entropy, please use a more random key") } return nil } // calculateEntropy calculates the Shannon entropy of a string func calculateEntropy(s string) float64 { freq := make(map[rune]int) for _, r := range s { freq[r]++ } entropy := 0.0 length := float64(len(s)) for _, count := range freq { p := float64(count) / length entropy -= p * log2(p) } return entropy } // log2 calculates log base 2 func log2(x float64) float64 { return math.Log2(x) }