Files
teleport/pkg/encryption/encryption.go
Justin Harms d24d1dc5ae Add initial project structure with core functionality
- Created a new Go module named 'teleport' for secure port forwarding.
- Added essential files including .gitignore, LICENSE, and README.md with project details.
- Implemented configuration management with YAML support in config package.
- Developed core client and server functionalities for handling port forwarding.
- Introduced DNS server capabilities and integrated logging with sanitization.
- Established rate limiting and metrics tracking for performance monitoring.
- Included comprehensive tests for core components and functionalities.
- Set up CI workflows for automated testing and release management using Gitea actions.
2025-09-20 18:07:08 -05:00

227 lines
5.7 KiB
Go

package encryption
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"fmt"
"io"
"math"
"sync"
"time"
"golang.org/x/crypto/pbkdf2"
)
const (
	// PBKDF2 parameters used by DeriveKey, DeriveKeyWithSalt and GenerateSalt.
	// Changing any of them changes the key derived from a given password.
	//
	// PBKDF2Iterations is the PBKDF2-HMAC-SHA256 work factor. NOTE: 100,000
	// is below the current OWASP Password Storage Cheat Sheet recommendation
	// (600,000 for PBKDF2-HMAC-SHA256). It is not raised here because doing
	// so would change every key already derived from an existing password.
	PBKDF2Iterations = 100000
	PBKDF2KeyLength  = 32 // derived key size in bytes (256 bits; a 32-byte key selects AES-256 in aes.NewCipher)
	PBKDF2SaltLength = 16 // salt size in bytes (128 bits)

	// Replay protection parameters.
	MaxPacketAge = 5 * time.Minute // oldest packet timestamp still accepted
	NonceWindow  = 1000            // ReplayProtection sweeps expired nonces once more than 2*NonceWindow are tracked
)
// DeriveKey derives a PBKDF2KeyLength-byte key from password using
// PBKDF2-HMAC-SHA256 with PBKDF2Iterations rounds.
//
// The salt is not random: it is the first PBKDF2SaltLength bytes of
// SHA-256(password). The result is therefore a pure function of the
// password, so anyone holding only the password derives the same key
// without a salt having to be stored or exchanged.
//
// Security note: because the salt depends only on the password, it does NOT
// defend against precomputation. A single dictionary of candidate
// password -> key pairs applies to every user of this function; only the
// PBKDF2 work factor slows an attacker down. Where a salt can be stored or
// transmitted, prefer DeriveKeyWithSalt with a salt from GenerateSalt.
// The derivation must stay byte-for-byte stable, or previously derived keys
// stop matching.
func DeriveKey(password string) []byte {
	digest := sha256.Sum256([]byte(password))
	salt := digest[:PBKDF2SaltLength]
	return pbkdf2.Key([]byte(password), salt, PBKDF2Iterations, PBKDF2KeyLength, sha256.New)
}
// DeriveKeyWithSalt derives a PBKDF2KeyLength-byte key from password using
// PBKDF2-HMAC-SHA256 with PBKDF2Iterations rounds and the caller's salt.
//
// salt must be exactly PBKDF2SaltLength bytes (GenerateSalt produces a
// suitable one). Otherwise a nil key and a descriptive error are returned.
func DeriveKeyWithSalt(password string, salt []byte) ([]byte, error) {
	if n := len(salt); n != PBKDF2SaltLength {
		return nil, fmt.Errorf("salt length must be %d bytes, got %d", PBKDF2SaltLength, n)
	}
	derived := pbkdf2.Key([]byte(password), salt, PBKDF2Iterations, PBKDF2KeyLength, sha256.New)
	return derived, nil
}
// GenerateSalt returns PBKDF2SaltLength bytes read from crypto/rand. The
// result is suitable as the salt argument to DeriveKeyWithSalt.
//
// On failure it returns a nil slice and a wrapped error (errors.Is still
// matches the underlying reader error). A partially filled buffer is never
// returned, because using it as a salt would silently weaken the derivation.
func GenerateSalt() ([]byte, error) {
	salt := make([]byte, PBKDF2SaltLength)
	if _, err := io.ReadFull(rand.Reader, salt); err != nil {
		return nil, fmt.Errorf("generating salt: %w", err)
	}
	return salt, nil
}
// EncryptData seals data with AES-GCM under key. It returns
// nonce || ciphertext || tag, which is the layout DecryptData expects.
//
// key must be a valid AES key length (16, 24 or 32 bytes); otherwise the
// error from aes.NewCipher is returned. Each call draws a fresh random nonce
// of the AEAD's standard size from crypto/rand.
func EncryptData(data []byte, key []byte) ([]byte, error) {
	blk, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	aead, err := cipher.NewGCM(blk)
	if err != nil {
		return nil, err
	}

	// The random nonce doubles as the output prefix.
	iv := make([]byte, aead.NonceSize())
	if _, err := io.ReadFull(rand.Reader, iv); err != nil {
		return nil, err
	}
	// Seal appends ciphertext||tag after the nonce bytes already in iv.
	return aead.Seal(iv, iv, data, nil), nil
}
// DecryptData opens a message produced by EncryptData: data is expected to be
// nonce || ciphertext || tag, sealed with AES-GCM under key.
//
// It returns the error from aes.NewCipher for an invalid key length. It
// returns "ciphertext too short" when data is shorter than the GCM nonce.
// Otherwise it returns whatever error gcm.Open reports, e.g. when
// authentication fails for tampered data or the wrong key.
func DecryptData(data []byte, key []byte) ([]byte, error) {
	blk, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	aead, err := cipher.NewGCM(blk)
	if err != nil {
		return nil, err
	}

	n := aead.NonceSize()
	if n > len(data) {
		return nil, fmt.Errorf("ciphertext too short")
	}
	// Open verifies the tag before returning; on failure nothing is usable.
	pt, err := aead.Open(nil, data[:n], data[n:], nil)
	if err != nil {
		return nil, err
	}
	return pt, nil
}
// ReplayProtection rejects packets whose nonce was already accepted or whose
// timestamp is more than MaxPacketAge in the past. All state is guarded by
// mutex, so IsValidNonce may be called concurrently. Construct it with
// NewReplayProtection: the zero value has a nil map, and IsValidNonce would
// panic writing to it.
type ReplayProtection struct {
	// nonces maps each accepted nonce to the timestamp of the packet that
	// carried it. An entry is evicted only once that timestamp is older than
	// MaxPacketAge, i.e. once a replay of the packet would fail the age check
	// in IsValidNonce anyway.
	nonces map[uint64]time.Time
	mutex  sync.RWMutex
	// maxNonces is a soft threshold: above it IsValidNonce sweeps expired
	// entries. Unexpired entries are never dropped (that would reopen the
	// replay window), so the map may exceed this size under sustained load.
	maxNonces int
	// lastCleanup records when the last sweep ran; used to rate-limit sweeps.
	lastCleanup time.Time
}

// NewReplayProtection returns an empty ReplayProtection. It sweeps expired
// nonces once more than 2*NonceWindow are tracked.
func NewReplayProtection() *ReplayProtection {
	return &ReplayProtection{
		nonces:    make(map[uint64]time.Time),
		maxNonces: NonceWindow * 2, // Allow 2x the window size for safety
	}
}

// IsValidNonce reports whether a packet carrying nonce and timestamp (Unix
// seconds) should be accepted, and records the nonce if so. It returns false
// in two cases:
//   - the timestamp is more than MaxPacketAge before the local clock;
//   - the nonce was previously accepted and is still tracked.
//
// Timestamps in the future are accepted. Their nonces are retained until the
// timestamp itself ages past MaxPacketAge, so they cannot be replayed later.
func (rp *ReplayProtection) IsValidNonce(nonce uint64, timestamp int64) bool {
	// A sweep is O(len(nonces)). Rate-limiting it keeps a burst of live
	// nonces above maxNonces from turning every call into a full map scan.
	const cleanupInterval = time.Second

	rp.mutex.Lock()
	defer rp.mutex.Unlock()

	now := time.Now()
	packetTime := time.Unix(timestamp, 0)
	if now.Sub(packetTime) > MaxPacketAge {
		return false
	}
	if _, seen := rp.nonces[nonce]; seen {
		return false
	}
	// Record the packet's own timestamp, not the receive time. Eviction is
	// keyed on this value, so a nonce can only be forgotten once the age check
	// above would reject a replay of the same packet. Keying on receive time
	// let future-stamped packets be replayed after their nonce was evicted.
	rp.nonces[nonce] = packetTime

	if len(rp.nonces) > rp.maxNonces && now.Sub(rp.lastCleanup) >= cleanupInterval {
		rp.cleanupOldNonces(now)
		rp.lastCleanup = now
	}
	return true
}

// cleanupOldNonces deletes every tracked nonce whose recorded packet
// timestamp is more than MaxPacketAge before now. The caller must hold
// rp.mutex.
func (rp *ReplayProtection) cleanupOldNonces(now time.Time) {
	for nonce, stamped := range rp.nonces {
		if now.Sub(stamped) > MaxPacketAge {
			delete(rp.nonces, nonce)
		}
	}
}
// ValidatePacketTimestamp reports whether timestamp (Unix seconds) lies in
// [now - MaxPacketAge, now] on the local clock, compared in whole seconds.
// Any timestamp ahead of the local clock, even by one second, is rejected.
// This is stricter than ReplayProtection.IsValidNonce, which accepts future
// timestamps.
func ValidatePacketTimestamp(timestamp int64) bool {
	maxAgeSecs := int64(MaxPacketAge.Seconds())
	age := time.Now().Unix() - timestamp
	if age < 0 {
		// Stamped in the future.
		return false
	}
	return age <= maxAgeSecs
}
// ConstantTimeCompare performs a constant-time comparison of two byte slices
func ConstantTimeCompare(a, b []byte) bool {
return subtle.ConstantTimeCompare(a, b) == 1
}
// ValidateEncryptionKey validates that an encryption key meets security requirements
func ValidateEncryptionKey(key string) error {
if len(key) < 32 {
return fmt.Errorf("encryption key must be at least 32 characters long")
}
// Check for common weak keys
weakKeys := []string{
"password", "123456", "admin", "test", "default",
"your-secure-encryption-key-change-this-to-something-random",
"test-encryption-key-12345", "teleport-key", "secret",
}
for _, weak := range weakKeys {
if key == weak {
return fmt.Errorf("encryption key is too weak, please use a strong random key")
}
}
// Check for sufficient entropy (basic check)
entropy := calculateEntropy(key)
if entropy < 3.5 { // Minimum entropy threshold
return fmt.Errorf("encryption key has insufficient entropy, please use a more random key")
}
return nil
}
// calculateEntropy returns the Shannon entropy of s in bits per rune:
// -Σ p·log2(p) over the distinct runes of s, where p is the fraction of all
// runes in s that equal that rune. The empty string has entropy 0.
//
// The denominator is the rune count, not len(s). Dividing rune frequencies
// by the byte length made the probabilities for non-ASCII (multi-byte) runes
// sum to less than 1 and produced a wrong result.
func calculateEntropy(s string) float64 {
	freq := make(map[rune]int)
	total := 0
	for _, r := range s {
		freq[r]++
		total++
	}
	if total == 0 {
		return 0
	}

	n := float64(total)
	entropy := 0.0
	for _, count := range freq {
		p := float64(count) / n
		entropy -= p * log2(p)
	}
	return entropy
}
// log2 returns the base-2 logarithm of x. It is a thin wrapper over
// math.Log2, so log2(0) is -Inf and log2 of a negative number is NaN.
func log2(x float64) float64 {
	result := math.Log2(x)
	return result
}