Add initial project structure with core functionality
- Created a new Go module named 'teleport' for secure port forwarding. - Added essential files including .gitignore, LICENSE, and README.md with project details. - Implemented configuration management with YAML support in config package. - Developed core client and server functionalities for handling port forwarding. - Introduced DNS server capabilities and integrated logging with sanitization. - Established rate limiting and metrics tracking for performance monitoring. - Included comprehensive tests for core components and functionalities. - Set up CI workflows for automated testing and release management using Gitea actions.
This commit is contained in:
226
pkg/encryption/encryption.go
Normal file
226
pkg/encryption/encryption.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package encryption
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
)
|
||||
|
||||
const (
|
||||
// PBKDF2 parameters for key derivation
|
||||
PBKDF2Iterations = 100000 // OWASP recommended minimum
|
||||
PBKDF2KeyLength = 32 // 256 bits
|
||||
PBKDF2SaltLength = 16 // 128 bits
|
||||
|
||||
// Replay protection parameters
|
||||
MaxPacketAge = 5 * time.Minute // Maximum age for UDP packets
|
||||
NonceWindow = 1000 // Number of nonces to track for replay protection
|
||||
)
|
||||
|
||||
// DeriveKey derives an encryption key from a password using PBKDF2
|
||||
func DeriveKey(password string) []byte {
|
||||
// Use a deterministic salt derived from the password hash for consistent key derivation
|
||||
// This ensures the same password always produces the same key while avoiding rainbow tables
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte(password))
|
||||
passwordHash := hasher.Sum(nil)
|
||||
|
||||
// Create a deterministic salt from the password hash
|
||||
salt := make([]byte, PBKDF2SaltLength)
|
||||
copy(salt, passwordHash[:PBKDF2SaltLength])
|
||||
|
||||
key := pbkdf2.Key([]byte(password), salt, PBKDF2Iterations, PBKDF2KeyLength, sha256.New)
|
||||
return key
|
||||
}
|
||||
|
||||
// DeriveKeyWithSalt derives an encryption key from a password using PBKDF2 with a custom salt
|
||||
func DeriveKeyWithSalt(password string, salt []byte) ([]byte, error) {
|
||||
if len(salt) != PBKDF2SaltLength {
|
||||
return nil, fmt.Errorf("salt length must be %d bytes, got %d", PBKDF2SaltLength, len(salt))
|
||||
}
|
||||
|
||||
key := pbkdf2.Key([]byte(password), salt, PBKDF2Iterations, PBKDF2KeyLength, sha256.New)
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// GenerateSalt generates a cryptographically secure random salt
|
||||
func GenerateSalt() ([]byte, error) {
|
||||
salt := make([]byte, PBKDF2SaltLength)
|
||||
_, err := io.ReadFull(rand.Reader, salt)
|
||||
return salt, err
|
||||
}
|
||||
|
||||
// EncryptData encrypts data using AES-GCM
|
||||
func EncryptData(data []byte, key []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ciphertext := gcm.Seal(nonce, nonce, data, nil)
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// DecryptData decrypts data using AES-GCM
|
||||
func DecryptData(data []byte, key []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(data) < nonceSize {
|
||||
return nil, fmt.Errorf("ciphertext too short")
|
||||
}
|
||||
|
||||
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// ReplayProtection tracks nonces to prevent replay attacks
|
||||
type ReplayProtection struct {
|
||||
nonces map[uint64]time.Time
|
||||
mutex sync.RWMutex
|
||||
maxNonces int // Maximum number of nonces to track
|
||||
}
|
||||
|
||||
// NewReplayProtection creates a new replay protection instance
|
||||
func NewReplayProtection() *ReplayProtection {
|
||||
return &ReplayProtection{
|
||||
nonces: make(map[uint64]time.Time),
|
||||
maxNonces: NonceWindow * 2, // Allow 2x the window size for safety
|
||||
}
|
||||
}
|
||||
|
||||
// IsValidNonce checks if a nonce is valid (not replayed and not too old)
|
||||
func (rp *ReplayProtection) IsValidNonce(nonce uint64, timestamp int64) bool {
|
||||
rp.mutex.Lock()
|
||||
defer rp.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
packetTime := time.Unix(timestamp, 0)
|
||||
|
||||
// Check if packet is too old
|
||||
if now.Sub(packetTime) > MaxPacketAge {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if nonce was already used
|
||||
if _, exists := rp.nonces[nonce]; exists {
|
||||
return false
|
||||
}
|
||||
|
||||
// Add nonce to tracking
|
||||
rp.nonces[nonce] = now
|
||||
|
||||
// Clean up old nonces to prevent memory leaks
|
||||
if len(rp.nonces) > rp.maxNonces {
|
||||
rp.cleanupOldNonces(now)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// cleanupOldNonces removes nonces older than MaxPacketAge
|
||||
func (rp *ReplayProtection) cleanupOldNonces(now time.Time) {
|
||||
for nonce, timestamp := range rp.nonces {
|
||||
if now.Sub(timestamp) > MaxPacketAge {
|
||||
delete(rp.nonces, nonce)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ValidatePacketTimestamp validates that a packet timestamp is within acceptable range
|
||||
func ValidatePacketTimestamp(timestamp int64) bool {
|
||||
now := time.Now().Unix()
|
||||
packetTime := timestamp
|
||||
|
||||
// Check if packet is too old or from the future
|
||||
age := now - packetTime
|
||||
return age >= 0 && age <= int64(MaxPacketAge.Seconds())
|
||||
}
|
||||
|
||||
// ConstantTimeCompare performs a constant-time comparison of two byte slices
|
||||
func ConstantTimeCompare(a, b []byte) bool {
|
||||
return subtle.ConstantTimeCompare(a, b) == 1
|
||||
}
|
||||
|
||||
// ValidateEncryptionKey validates that an encryption key meets security requirements
|
||||
func ValidateEncryptionKey(key string) error {
|
||||
if len(key) < 32 {
|
||||
return fmt.Errorf("encryption key must be at least 32 characters long")
|
||||
}
|
||||
|
||||
// Check for common weak keys
|
||||
weakKeys := []string{
|
||||
"password", "123456", "admin", "test", "default",
|
||||
"your-secure-encryption-key-change-this-to-something-random",
|
||||
"test-encryption-key-12345", "teleport-key", "secret",
|
||||
}
|
||||
|
||||
for _, weak := range weakKeys {
|
||||
if key == weak {
|
||||
return fmt.Errorf("encryption key is too weak, please use a strong random key")
|
||||
}
|
||||
}
|
||||
|
||||
// Check for sufficient entropy (basic check)
|
||||
entropy := calculateEntropy(key)
|
||||
if entropy < 3.5 { // Minimum entropy threshold
|
||||
return fmt.Errorf("encryption key has insufficient entropy, please use a more random key")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateEntropy calculates the Shannon entropy of a string
|
||||
func calculateEntropy(s string) float64 {
|
||||
freq := make(map[rune]int)
|
||||
for _, r := range s {
|
||||
freq[r]++
|
||||
}
|
||||
|
||||
entropy := 0.0
|
||||
length := float64(len(s))
|
||||
|
||||
for _, count := range freq {
|
||||
p := float64(count) / length
|
||||
entropy -= p * log2(p)
|
||||
}
|
||||
|
||||
return entropy
|
||||
}
|
||||
|
||||
// log2 calculates log base 2
|
||||
func log2(x float64) float64 {
|
||||
return math.Log2(x)
|
||||
}
|
||||
242
pkg/encryption/encryption_test.go
Normal file
242
pkg/encryption/encryption_test.go
Normal file
@@ -0,0 +1,242 @@
|
||||
package encryption
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDeriveKey(t *testing.T) {
|
||||
password := "test-password"
|
||||
key := DeriveKey(password)
|
||||
|
||||
if len(key) != 32 {
|
||||
t.Errorf("Expected key length 32, got %d", len(key))
|
||||
}
|
||||
|
||||
// Test that same password produces same key
|
||||
key2 := DeriveKey(password)
|
||||
if string(key) != string(key2) {
|
||||
t.Error("Same password should produce same key")
|
||||
}
|
||||
|
||||
// Test that different passwords produce different keys
|
||||
key3 := DeriveKey("different-password")
|
||||
if string(key) == string(key3) {
|
||||
t.Error("Different passwords should produce different keys")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecrypt(t *testing.T) {
|
||||
key := DeriveKey("test-key")
|
||||
originalData := []byte("Hello, World! This is a test message.")
|
||||
|
||||
// Test encryption
|
||||
encryptedData, err := EncryptData(originalData, key)
|
||||
if err != nil {
|
||||
t.Fatalf("Encryption failed: %v", err)
|
||||
}
|
||||
|
||||
if len(encryptedData) <= len(originalData) {
|
||||
t.Error("Encrypted data should be longer than original data (due to nonce)")
|
||||
}
|
||||
|
||||
// Test decryption
|
||||
decryptedData, err := DecryptData(encryptedData, key)
|
||||
if err != nil {
|
||||
t.Fatalf("Decryption failed: %v", err)
|
||||
}
|
||||
|
||||
if string(decryptedData) != string(originalData) {
|
||||
t.Errorf("Decrypted data doesn't match original. Expected: %s, Got: %s",
|
||||
string(originalData), string(decryptedData))
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecryptEmptyData(t *testing.T) {
|
||||
key := DeriveKey("test-key")
|
||||
originalData := []byte("")
|
||||
|
||||
encryptedData, err := EncryptData(originalData, key)
|
||||
if err != nil {
|
||||
t.Fatalf("Encryption of empty data failed: %v", err)
|
||||
}
|
||||
|
||||
decryptedData, err := DecryptData(encryptedData, key)
|
||||
if err != nil {
|
||||
t.Fatalf("Decryption of empty data failed: %v", err)
|
||||
}
|
||||
|
||||
if len(decryptedData) != 0 {
|
||||
t.Error("Decrypted empty data should be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecryptLargeData(t *testing.T) {
|
||||
key := DeriveKey("test-key")
|
||||
|
||||
// Create a large data block (1MB)
|
||||
originalData := make([]byte, 1024*1024)
|
||||
for i := range originalData {
|
||||
originalData[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
encryptedData, err := EncryptData(originalData, key)
|
||||
if err != nil {
|
||||
t.Fatalf("Encryption of large data failed: %v", err)
|
||||
}
|
||||
|
||||
decryptedData, err := DecryptData(encryptedData, key)
|
||||
if err != nil {
|
||||
t.Fatalf("Decryption of large data failed: %v", err)
|
||||
}
|
||||
|
||||
if len(decryptedData) != len(originalData) {
|
||||
t.Errorf("Decrypted data length mismatch. Expected: %d, Got: %d",
|
||||
len(originalData), len(decryptedData))
|
||||
}
|
||||
|
||||
for i := range originalData {
|
||||
if decryptedData[i] != originalData[i] {
|
||||
t.Errorf("Data mismatch at position %d", i)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrongKeyDecryption(t *testing.T) {
|
||||
key1 := DeriveKey("key1")
|
||||
key2 := DeriveKey("key2")
|
||||
originalData := []byte("test data")
|
||||
|
||||
encryptedData, err := EncryptData(originalData, key1)
|
||||
if err != nil {
|
||||
t.Fatalf("Encryption failed: %v", err)
|
||||
}
|
||||
|
||||
// Try to decrypt with wrong key
|
||||
_, err = DecryptData(encryptedData, key2)
|
||||
if err == nil {
|
||||
t.Error("Decryption with wrong key should fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorruptedDataDecryption(t *testing.T) {
|
||||
key := DeriveKey("test-key")
|
||||
originalData := []byte("test data")
|
||||
|
||||
encryptedData, err := EncryptData(originalData, key)
|
||||
if err != nil {
|
||||
t.Fatalf("Encryption failed: %v", err)
|
||||
}
|
||||
|
||||
// Corrupt the data
|
||||
encryptedData[0] ^= 0xFF
|
||||
|
||||
// Try to decrypt corrupted data
|
||||
_, err = DecryptData(encryptedData, key)
|
||||
if err == nil {
|
||||
t.Error("Decryption of corrupted data should fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateEncryptionKey(t *testing.T) {
|
||||
// Test valid key
|
||||
validKey := "a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03"
|
||||
if err := ValidateEncryptionKey(validKey); err != nil {
|
||||
t.Errorf("Valid key should pass validation: %v", err)
|
||||
}
|
||||
|
||||
// Test short key
|
||||
shortKey := "short"
|
||||
if err := ValidateEncryptionKey(shortKey); err == nil {
|
||||
t.Error("Short key should fail validation")
|
||||
}
|
||||
|
||||
// Test weak keys
|
||||
weakKeys := []string{
|
||||
"password", "123456", "admin", "test", "default",
|
||||
"your-secure-encryption-key-change-this-to-something-random",
|
||||
"test-encryption-key-12345", "teleport-key", "secret",
|
||||
}
|
||||
|
||||
for _, weakKey := range weakKeys {
|
||||
if err := ValidateEncryptionKey(weakKey); err == nil {
|
||||
t.Errorf("Weak key '%s' should fail validation", weakKey)
|
||||
}
|
||||
}
|
||||
|
||||
// Test low entropy key
|
||||
lowEntropyKey := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||
if err := ValidateEncryptionKey(lowEntropyKey); err == nil {
|
||||
t.Error("Low entropy key should fail validation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplayProtection(t *testing.T) {
|
||||
rp := NewReplayProtection()
|
||||
nonce := uint64(12345)
|
||||
timestamp := time.Now().Unix()
|
||||
|
||||
// First use should be valid
|
||||
if !rp.IsValidNonce(nonce, timestamp) {
|
||||
t.Error("First use of nonce should be valid")
|
||||
}
|
||||
|
||||
// Replay should be invalid
|
||||
if rp.IsValidNonce(nonce, timestamp) {
|
||||
t.Error("Replay of nonce should be invalid")
|
||||
}
|
||||
|
||||
// Different nonce should be valid
|
||||
if !rp.IsValidNonce(nonce+1, timestamp) {
|
||||
t.Error("Different nonce should be valid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePacketTimestamp(t *testing.T) {
|
||||
now := time.Now().Unix()
|
||||
|
||||
// Current timestamp should be valid
|
||||
if !ValidatePacketTimestamp(now) {
|
||||
t.Error("Current timestamp should be valid")
|
||||
}
|
||||
|
||||
// Recent timestamp should be valid
|
||||
recent := now - 60 // 1 minute ago
|
||||
if !ValidatePacketTimestamp(recent) {
|
||||
t.Error("Recent timestamp should be valid")
|
||||
}
|
||||
|
||||
// Old timestamp should be invalid
|
||||
old := now - int64(MaxPacketAge.Seconds()) - 1
|
||||
if ValidatePacketTimestamp(old) {
|
||||
t.Error("Old timestamp should be invalid")
|
||||
}
|
||||
|
||||
// Future timestamp should be invalid
|
||||
future := now + 3600 // 1 hour in future
|
||||
if ValidatePacketTimestamp(future) {
|
||||
t.Error("Future timestamp should be invalid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstantTimeCompare(t *testing.T) {
|
||||
a := []byte("test")
|
||||
b := []byte("test")
|
||||
c := []byte("different")
|
||||
|
||||
if !ConstantTimeCompare(a, b) {
|
||||
t.Error("Identical byte slices should compare equal")
|
||||
}
|
||||
|
||||
if ConstantTimeCompare(a, c) {
|
||||
t.Error("Different byte slices should not compare equal")
|
||||
}
|
||||
|
||||
// Test with empty slices
|
||||
empty1 := []byte{}
|
||||
empty2 := []byte{}
|
||||
if !ConstantTimeCompare(empty1, empty2) {
|
||||
t.Error("Empty slices should compare equal")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user