Add initial project structure with core functionality

- Created a new Go module named 'teleport' for secure port forwarding.
- Added essential files including .gitignore, LICENSE, and README.md with project details.
- Implemented configuration management with YAML support in config package.
- Developed core client and server functionalities for handling port forwarding.
- Introduced DNS server capabilities and integrated logging with sanitization.
- Established rate limiting and metrics tracking for performance monitoring.
- Included comprehensive tests for core components and functionalities.
- Set up CI workflows for automated testing and release management using Gitea actions.
This commit is contained in:
2025-09-20 18:07:08 -05:00
commit d24d1dc5ae
26 changed files with 6065 additions and 0 deletions

537
pkg/config/config.go Normal file
View File

@@ -0,0 +1,537 @@
package config
import (
"crypto/rand"
"encoding/hex"
"fmt"
"os"
"strconv"
"strings"
"time"
"teleport/pkg/encryption"
"gopkg.in/yaml.v3"
)
// Config represents the teleport configuration
type Config struct {
InstanceID string `yaml:"instance_id"`
ListenAddress string `yaml:"listen_address"`
RemoteAddress string `yaml:"remote_address"`
Ports []PortRule `yaml:"ports"`
EncryptionKey string `yaml:"encryption_key"`
KeepAlive bool `yaml:"keep_alive"`
ReadTimeout time.Duration `yaml:"read_timeout"`
WriteTimeout time.Duration `yaml:"write_timeout"`
MaxConnections int `yaml:"max_connections"`
RateLimit RateLimitConfig `yaml:"rate_limit"`
DNSServer DNSServerConfig `yaml:"dns_server"`
}
// PortRule defines a port forwarding rule
type PortRule struct {
LocalPort int `yaml:"local_port"`
RemotePort int `yaml:"remote_port"`
Protocol string `yaml:"protocol"` // "tcp" or "udp"
TargetHost string `yaml:"target_host,omitempty"` // Target host for server-side forwarding (defaults to localhost)
}
// UnmarshalYAML implements custom YAML unmarshaling for PortRule
func (p *PortRule) UnmarshalYAML(value *yaml.Node) error {
// Only support URL-style format: "protocol://target:targetport" (server) or "protocol://targetport:localport" (client)
if value.Kind != yaml.ScalarNode {
return fmt.Errorf("port must be a string in format 'protocol://target:port' or 'protocol://targetport:localport'")
}
portStr := value.Value
// Validate input length
if len(portStr) > 512 {
return fmt.Errorf("port string too long")
}
// Check if it starts with protocol://
if !strings.Contains(portStr, "://") {
return fmt.Errorf("invalid port format: %s (expected 'protocol://target:port' or 'protocol://targetport:localport')", portStr)
}
parts := strings.Split(portStr, "://")
if len(parts) != 2 {
return fmt.Errorf("invalid port format: %s (expected 'protocol://target:port')", portStr)
}
protocol := parts[0]
if protocol != "tcp" && protocol != "udp" {
return fmt.Errorf("invalid protocol: %s (expected 'tcp' or 'udp')", protocol)
}
addressPart := parts[1]
// Validate address part length
if len(addressPart) > 256 {
return fmt.Errorf("address part too long")
}
addressParts := strings.Split(addressPart, ":")
if len(addressParts) == 2 {
// Check if first part is a hostname/IP (contains dots or is not a number) - server format
if _, err := strconv.Atoi(addressParts[0]); err != nil || strings.Contains(addressParts[0], ".") {
// Server format: protocol://target:targetport
targetHost := addressParts[0]
// Validate target host
if len(targetHost) > 253 { // RFC 1123 limit
return fmt.Errorf("target host too long")
}
// Basic character validation for hostname
for _, c := range targetHost {
if c < 32 || c > 126 {
return fmt.Errorf("invalid character in target host")
}
}
targetPort, err := strconv.Atoi(addressParts[1])
if err != nil {
return fmt.Errorf("invalid target port: %s", addressParts[1])
}
// Validate port range
if targetPort < 1 || targetPort > 65535 {
return fmt.Errorf("invalid target port: %d (must be 1-65535)", targetPort)
}
p.LocalPort = targetPort // Server listens on this port
p.RemotePort = targetPort // Server forwards to this port on target
p.Protocol = protocol
p.TargetHost = targetHost
return nil
} else {
// Client format: protocol://targetport:localport (first part is a number)
targetPort, err := strconv.Atoi(addressParts[0])
if err != nil {
return fmt.Errorf("invalid target port: %s", addressParts[0])
}
localPort, err := strconv.Atoi(addressParts[1])
if err != nil {
return fmt.Errorf("invalid local port: %s", addressParts[1])
}
// Validate port ranges
if targetPort < 1 || targetPort > 65535 {
return fmt.Errorf("invalid target port: %d (must be 1-65535)", targetPort)
}
if localPort < 1 || localPort > 65535 {
return fmt.Errorf("invalid local port: %d (must be 1-65535)", localPort)
}
p.LocalPort = localPort // Client listens on this port
p.RemotePort = targetPort // Client forwards to this port on teleport server
p.Protocol = protocol
p.TargetHost = "" // Client doesn't specify target host
return nil
}
} else {
return fmt.Errorf("invalid address format: %s (expected 'target:port' for server or 'targetport:localport' for client)", addressPart)
}
}
// MarshalYAML implements custom YAML marshaling for PortRule
func (p PortRule) MarshalYAML() (interface{}, error) {
// Use new format: protocol://target:targetport (server) or protocol://targetport:localport (client)
if p.TargetHost == "" {
// Client format: protocol://targetport:localport
return fmt.Sprintf("%s://%d:%d", p.Protocol, p.RemotePort, p.LocalPort), nil
} else {
// Server format: protocol://target:targetport
return fmt.Sprintf("%s://%s:%d", p.Protocol, p.TargetHost, p.RemotePort), nil
}
}
// RateLimitConfig defines rate limiting configuration
type RateLimitConfig struct {
Enabled bool `yaml:"enabled"`
RequestsPerSecond int `yaml:"requests_per_second"`
BurstSize int `yaml:"burst_size"`
WindowSize time.Duration `yaml:"window_size"`
}
// DNSServerConfig defines DNS server configuration
type DNSServerConfig struct {
Enabled bool `yaml:"enabled"`
ListenPort int `yaml:"listen_port"`
BackupServer string `yaml:"backup_server"`
CustomRecords []DNSRecord `yaml:"custom_records"`
}
// DNSRecord defines a custom DNS record
type DNSRecord struct {
Name string `yaml:"name"`
Type string `yaml:"type"` // A, AAAA, CNAME, MX, TXT, NS, SRV
Value string `yaml:"value"`
TTL uint32 `yaml:"ttl"`
Priority uint16 `yaml:"priority,omitempty"` // For MX and SRV records
Weight uint16 `yaml:"weight,omitempty"` // For SRV records
Port uint16 `yaml:"port,omitempty"` // For SRV records
}
// LoadConfig loads and parses the configuration file
func LoadConfig(filename string) (*Config, error) {
// Check file size before reading (max 1MB)
fileInfo, err := os.Stat(filename)
if err != nil {
return nil, fmt.Errorf("failed to stat config file: %v", err)
}
if fileInfo.Size() > 1024*1024 { // 1MB limit
return nil, fmt.Errorf("config file too large: %d bytes (max 1MB)", fileInfo.Size())
}
data, err := os.ReadFile(filename)
if err != nil {
return nil, fmt.Errorf("failed to read config file: %v", err)
}
var config Config
if err := yaml.Unmarshal(data, &config); err != nil {
return nil, fmt.Errorf("failed to parse config file: %v", err)
}
// Set default values
if config.ReadTimeout == 0 {
config.ReadTimeout = 30 * time.Second
}
if config.WriteTimeout == 0 {
config.WriteTimeout = 30 * time.Second
}
if config.MaxConnections == 0 {
config.MaxConnections = 1000
}
if config.RateLimit.RequestsPerSecond == 0 {
config.RateLimit.RequestsPerSecond = 100
}
if config.RateLimit.BurstSize == 0 {
config.RateLimit.BurstSize = 200
}
if config.RateLimit.WindowSize == 0 {
config.RateLimit.WindowSize = 1 * time.Second
}
if config.DNSServer.ListenPort == 0 {
config.DNSServer.ListenPort = 5353
}
if config.DNSServer.BackupServer == "" {
config.DNSServer.BackupServer = "8.8.8.8:53"
}
// Validate configuration
if err := config.Validate(); err != nil {
return nil, fmt.Errorf("configuration validation failed: %v", err)
}
return &config, nil
}
// DetectMode determines if this should run as server or client based on config
func DetectMode(config *Config) (string, error) {
hasListen := config.ListenAddress != ""
hasRemote := config.RemoteAddress != ""
if hasListen && hasRemote {
return "", fmt.Errorf("configuration error: cannot have both 'listen_address' and 'remote_address' set. Choose one:\n" +
"- Server: set 'listen_address' and leave 'remote_address' empty\n" +
"- Client: set 'remote_address' and leave 'listen_address' empty")
}
if !hasListen && !hasRemote {
return "", fmt.Errorf("configuration error: must specify either 'listen_address' (for server) or 'remote_address' (for client)")
}
if hasListen && !hasRemote {
return "server", nil
}
if hasRemote && !hasListen {
return "client", nil
}
return "", fmt.Errorf("internal error: could not determine mode")
}
// Validate validates the configuration
func (c *Config) Validate() error {
// Validate instance ID
if c.InstanceID == "" {
return fmt.Errorf("instance_id is required")
}
// Validate encryption key
if err := encryption.ValidateEncryptionKey(c.EncryptionKey); err != nil {
return fmt.Errorf("invalid encryption key: %v", err)
}
// Validate ports
if len(c.Ports) == 0 {
return fmt.Errorf("at least one port rule is required")
}
for i, port := range c.Ports {
if err := port.Validate(); err != nil {
return fmt.Errorf("port rule %d is invalid: %v", i, err)
}
}
// Validate DNS server configuration
if c.DNSServer.Enabled {
if err := c.DNSServer.Validate(); err != nil {
return fmt.Errorf("DNS server configuration is invalid: %v", err)
}
}
// Validate rate limiting configuration
if c.RateLimit.Enabled {
if c.RateLimit.RequestsPerSecond <= 0 {
return fmt.Errorf("rate limit requests_per_second must be positive")
}
if c.RateLimit.BurstSize <= 0 {
return fmt.Errorf("rate limit burst_size must be positive")
}
if c.RateLimit.WindowSize <= 0 {
return fmt.Errorf("rate limit window_size must be positive")
}
if c.RateLimit.BurstSize < c.RateLimit.RequestsPerSecond {
return fmt.Errorf("rate limit burst_size should be at least requests_per_second")
}
if c.RateLimit.RequestsPerSecond > 10000 {
return fmt.Errorf("rate limit requests_per_second too high: %d (maximum 10000)", c.RateLimit.RequestsPerSecond)
}
if c.RateLimit.BurstSize > 50000 {
return fmt.Errorf("rate limit burst_size too high: %d (maximum 50000)", c.RateLimit.BurstSize)
}
}
// Validate connection limits
if c.MaxConnections < 0 {
return fmt.Errorf("max_connections cannot be negative")
}
if c.MaxConnections > 100000 {
return fmt.Errorf("max_connections too high: %d (maximum 100000)", c.MaxConnections)
}
// Validate timeout values
if c.ReadTimeout < 0 {
return fmt.Errorf("read_timeout cannot be negative")
}
if c.WriteTimeout < 0 {
return fmt.Errorf("write_timeout cannot be negative")
}
if c.ReadTimeout > 24*time.Hour {
return fmt.Errorf("read_timeout too high: %v (maximum 24 hours)", c.ReadTimeout)
}
if c.WriteTimeout > 24*time.Hour {
return fmt.Errorf("write_timeout too high: %v (maximum 24 hours)", c.WriteTimeout)
}
return nil
}
// Validate validates a port rule
func (p *PortRule) Validate() error {
if p.LocalPort < 1 || p.LocalPort > 65535 {
return fmt.Errorf("local_port must be between 1 and 65535, got %d", p.LocalPort)
}
if p.RemotePort < 1 || p.RemotePort > 65535 {
return fmt.Errorf("remote_port must be between 1 and 65535, got %d", p.RemotePort)
}
if p.Protocol != "tcp" && p.Protocol != "udp" {
return fmt.Errorf("protocol must be 'tcp' or 'udp', got '%s'", p.Protocol)
}
return nil
}
// Validate validates DNS server configuration
func (d *DNSServerConfig) Validate() error {
if d.ListenPort < 1 || d.ListenPort > 65535 {
return fmt.Errorf("DNS listen_port must be between 1 and 65535, got %d", d.ListenPort)
}
if d.BackupServer == "" {
return fmt.Errorf("DNS backup_server is required when DNS server is enabled")
}
// Validate backup server format
parts := strings.Split(d.BackupServer, ":")
if len(parts) != 2 {
return fmt.Errorf("DNS backup_server must be in format 'host:port', got '%s'", d.BackupServer)
}
port, err := strconv.Atoi(parts[1])
if err != nil || port < 1 || port > 65535 {
return fmt.Errorf("DNS backup_server port must be between 1 and 65535, got '%s'", parts[1])
}
// Validate custom records
for i, record := range d.CustomRecords {
if err := record.Validate(); err != nil {
return fmt.Errorf("DNS record %d is invalid: %v", i, err)
}
}
return nil
}
// Validate validates a DNS record
func (r *DNSRecord) Validate() error {
if r.Name == "" {
return fmt.Errorf("DNS record name is required")
}
// Validate DNS name length and format
if len(r.Name) > 253 {
return fmt.Errorf("DNS record name too long")
}
// Basic character validation for DNS name
for _, c := range r.Name {
if c < 32 || c > 126 {
return fmt.Errorf("invalid character in DNS record name")
}
}
validTypes := []string{"A", "AAAA", "CNAME", "MX", "TXT", "SRV", "NS"}
valid := false
for _, t := range validTypes {
if strings.ToUpper(r.Type) == t {
valid = true
break
}
}
if !valid {
return fmt.Errorf("DNS record type must be one of %v, got '%s'", validTypes, r.Type)
}
if r.Value == "" {
return fmt.Errorf("DNS record value is required")
}
// Validate DNS record value length
if len(r.Value) > 1024 {
return fmt.Errorf("DNS record value too long")
}
// Basic character validation for DNS value
for _, c := range r.Value {
if c < 32 || c > 126 {
return fmt.Errorf("invalid character in DNS record value")
}
}
if r.TTL == 0 {
return fmt.Errorf("DNS record TTL must be greater than 0")
}
// Validate TTL range (reasonable limits)
if r.TTL > 86400 { // 24 hours
return fmt.Errorf("DNS record TTL too large (max 86400 seconds)")
}
// Note: Priority, Weight, and Port are uint16 types, so they cannot exceed 65535
// No additional validation needed for these fields as they are already constrained by their type
return nil
}
// GenerateExampleConfig generates an example configuration file
func GenerateExampleConfig(filename string) error {
// Generate a strong encryption key
strongKey, err := generateStrongEncryptionKey()
if err != nil {
return fmt.Errorf("failed to generate encryption key: %v", err)
}
var config Config
if strings.Contains(filename, "server") {
// Generate server configuration
config = Config{
InstanceID: "teleport-server-01",
ListenAddress: ":8080",
RemoteAddress: "",
Ports: []PortRule{
{LocalPort: 80, RemotePort: 80, Protocol: "tcp", TargetHost: "localhost"},
},
EncryptionKey: strongKey,
KeepAlive: true,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
MaxConnections: 1000,
RateLimit: RateLimitConfig{
Enabled: true,
RequestsPerSecond: 100,
BurstSize: 200,
WindowSize: 1 * time.Second,
},
DNSServer: DNSServerConfig{
ListenPort: 5353,
BackupServer: "8.8.8.8:53",
CustomRecords: []DNSRecord{},
},
}
} else {
// Generate client configuration
config = Config{
InstanceID: "teleport-client-01",
ListenAddress: "",
RemoteAddress: "localhost:8080",
Ports: []PortRule{
{LocalPort: 8080, RemotePort: 80, Protocol: "tcp", TargetHost: ""},
},
EncryptionKey: strongKey,
KeepAlive: true,
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
MaxConnections: 100,
RateLimit: RateLimitConfig{
Enabled: true,
RequestsPerSecond: 50,
BurstSize: 100,
WindowSize: 1 * time.Second,
},
DNSServer: DNSServerConfig{
ListenPort: 5353,
BackupServer: "8.8.8.8:53",
CustomRecords: []DNSRecord{
{Name: "app.local", Type: "A", Value: "127.0.0.1", TTL: 300},
},
},
}
}
data, err := yaml.Marshal(&config)
if err != nil {
return fmt.Errorf("failed to marshal config: %v", err)
}
err = os.WriteFile(filename, data, 0644)
if err != nil {
return fmt.Errorf("failed to write config file: %v", err)
}
fmt.Printf("Generated example configuration: %s\n", filename)
fmt.Printf("Edit the configuration file and run: ./teleport -config %s\n", filename)
return nil
}
// generateStrongEncryptionKey generates a cryptographically secure encryption key
func generateStrongEncryptionKey() (string, error) {
// Generate 32 random bytes (256 bits) for a strong encryption key
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("failed to generate random key: %v", err)
}
// Convert to hexadecimal string for easy copying
return hex.EncodeToString(bytes), nil
}

363
pkg/config/config_test.go Normal file
View File

@@ -0,0 +1,363 @@
package config
import (
"os"
"path/filepath"
"strings"
"testing"
"time"
)
func TestLoadConfig(t *testing.T) {
// Create a temporary config file
tempDir := t.TempDir()
configFile := filepath.Join(tempDir, "test-config.yaml")
configContent := `
instance_id: test-instance
listen_address: 127.0.0.1:8080
remote_address: ""
ports:
- tcp://localhost:22
- tcp://localhost:80
encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03
keep_alive: true
read_timeout: 30s
write_timeout: 30s
dns_server:
enabled: true
listen_port: 5353
backup_server: 8.8.8.8:53
custom_records:
- name: test.local
type: A
value: 192.168.1.100
ttl: 300
`
err := os.WriteFile(configFile, []byte(configContent), 0644)
if err != nil {
t.Fatalf("Failed to write test config: %v", err)
}
// Test loading config
config, err := LoadConfig(configFile)
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
// Verify config values
if config.InstanceID != "test-instance" {
t.Errorf("Expected InstanceID 'test-instance', got '%s'", config.InstanceID)
}
if config.ListenAddress != "127.0.0.1:8080" {
t.Errorf("Expected ListenAddress '127.0.0.1:8080', got '%s'", config.ListenAddress)
}
if config.RemoteAddress != "" {
t.Errorf("Expected empty RemoteAddress, got '%s'", config.RemoteAddress)
}
if len(config.Ports) != 2 {
t.Errorf("Expected 2 ports, got %d", len(config.Ports))
}
if config.Ports[0].LocalPort != 22 || config.Ports[0].RemotePort != 22 || config.Ports[0].Protocol != "tcp" || config.Ports[0].TargetHost != "localhost" {
t.Error("First port rule is incorrect")
}
if config.EncryptionKey != "a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03" {
t.Errorf("Expected EncryptionKey 'a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03', got '%s'", config.EncryptionKey)
}
if !config.KeepAlive {
t.Error("Expected KeepAlive to be true")
}
if config.ReadTimeout != 30*time.Second {
t.Errorf("Expected ReadTimeout 30s, got %v", config.ReadTimeout)
}
if config.WriteTimeout != 30*time.Second {
t.Errorf("Expected WriteTimeout 30s, got %v", config.WriteTimeout)
}
if !config.DNSServer.Enabled {
t.Error("Expected DNS server to be enabled")
}
if config.DNSServer.ListenPort != 5353 {
t.Errorf("Expected DNS listen port 5353, got %d", config.DNSServer.ListenPort)
}
if len(config.DNSServer.CustomRecords) != 1 {
t.Errorf("Expected 1 custom DNS record, got %d", len(config.DNSServer.CustomRecords))
}
}
func TestLoadConfigDefaults(t *testing.T) {
// Create a minimal config file
tempDir := t.TempDir()
configFile := filepath.Join(tempDir, "minimal-config.yaml")
configContent := `
instance_id: test-instance
listen_address: 127.0.0.1:8080
ports:
- tcp://localhost:22
encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03
`
err := os.WriteFile(configFile, []byte(configContent), 0644)
if err != nil {
t.Fatalf("Failed to write test config: %v", err)
}
config, err := LoadConfig(configFile)
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
// Check default values
if config.ReadTimeout != 30*time.Second {
t.Errorf("Expected default ReadTimeout 30s, got %v", config.ReadTimeout)
}
if config.WriteTimeout != 30*time.Second {
t.Errorf("Expected default WriteTimeout 30s, got %v", config.WriteTimeout)
}
if config.DNSServer.ListenPort != 5353 {
t.Errorf("Expected default DNS listen port 5353, got %d", config.DNSServer.ListenPort)
}
if config.DNSServer.BackupServer != "8.8.8.8:53" {
t.Errorf("Expected default backup server '8.8.8.8:53', got '%s'", config.DNSServer.BackupServer)
}
}
func TestDetectMode(t *testing.T) {
tests := []struct {
name string
listenAddress string
remoteAddress string
expectedMode string
expectedError bool
}{
{
name: "Server mode",
listenAddress: "127.0.0.1:8080",
remoteAddress: "",
expectedMode: "server",
expectedError: false,
},
{
name: "Client mode",
listenAddress: "",
remoteAddress: "127.0.0.1:8080",
expectedMode: "client",
expectedError: false,
},
{
name: "Both addresses set - error",
listenAddress: "127.0.0.1:8080",
remoteAddress: "127.0.0.1:8080",
expectedMode: "",
expectedError: true,
},
{
name: "Neither address set - error",
listenAddress: "",
remoteAddress: "",
expectedMode: "",
expectedError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := &Config{
ListenAddress: tt.listenAddress,
RemoteAddress: tt.remoteAddress,
}
mode, err := DetectMode(config)
if tt.expectedError {
if err == nil {
t.Error("Expected error but got none")
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if mode != tt.expectedMode {
t.Errorf("Expected mode '%s', got '%s'", tt.expectedMode, mode)
}
}
})
}
}
func TestLoadConfigFileNotFound(t *testing.T) {
_, err := LoadConfig("nonexistent-config.yaml")
if err == nil {
t.Error("Expected error for nonexistent config file")
}
}
func TestLoadConfigInvalidYAML(t *testing.T) {
tempDir := t.TempDir()
configFile := filepath.Join(tempDir, "invalid-config.yaml")
// Write invalid YAML
err := os.WriteFile(configFile, []byte("invalid: yaml: content: ["), 0644)
if err != nil {
t.Fatalf("Failed to write test config: %v", err)
}
_, err = LoadConfig(configFile)
if err == nil {
t.Error("Expected error for invalid YAML")
}
}
func TestPortRuleURLFormat(t *testing.T) {
tests := []struct {
name string
config string
expected PortRule
}{
{
name: "Server format with localhost",
config: `
instance_id: test
listen_address: :8080
ports:
- tcp://localhost:80
encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03
`,
expected: PortRule{
LocalPort: 80,
RemotePort: 80,
Protocol: "tcp",
TargetHost: "localhost",
},
},
{
name: "Server format with remote host",
config: `
instance_id: test
listen_address: :8080
ports:
- tcp://server-a:22
encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03
`,
expected: PortRule{
LocalPort: 22,
RemotePort: 22,
Protocol: "tcp",
TargetHost: "server-a",
},
},
{
name: "Client format",
config: `
instance_id: test
remote_address: localhost:8080
ports:
- tcp://80:8080
encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03
`,
expected: PortRule{
LocalPort: 8080,
RemotePort: 80,
Protocol: "tcp",
TargetHost: "",
},
},
{
name: "UDP format",
config: `
instance_id: test
listen_address: :8080
ports:
- udp://dns-server:53
encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03
`,
expected: PortRule{
LocalPort: 53,
RemotePort: 53,
Protocol: "udp",
TargetHost: "dns-server",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tempDir := t.TempDir()
configFile := filepath.Join(tempDir, "test-config.yaml")
err := os.WriteFile(configFile, []byte(tt.config), 0644)
if err != nil {
t.Fatalf("Failed to write test config: %v", err)
}
config, err := LoadConfig(configFile)
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
if len(config.Ports) != 1 {
t.Fatalf("Expected 1 port, got %d", len(config.Ports))
}
port := config.Ports[0]
if port.LocalPort != tt.expected.LocalPort {
t.Errorf("Expected LocalPort %d, got %d", tt.expected.LocalPort, port.LocalPort)
}
if port.RemotePort != tt.expected.RemotePort {
t.Errorf("Expected RemotePort %d, got %d", tt.expected.RemotePort, port.RemotePort)
}
if port.Protocol != tt.expected.Protocol {
t.Errorf("Expected Protocol %s, got %s", tt.expected.Protocol, port.Protocol)
}
if port.TargetHost != tt.expected.TargetHost {
t.Errorf("Expected TargetHost %s, got %s", tt.expected.TargetHost, port.TargetHost)
}
})
}
}
func TestPortRuleOldFormatFails(t *testing.T) {
// Test that old object format now fails
tempDir := t.TempDir()
configFile := filepath.Join(tempDir, "old-format-config.yaml")
configContent := `
instance_id: test
listen_address: :8080
ports:
- local_port: 22
remote_port: 22
protocol: tcp
encryption_key: test-key
`
err := os.WriteFile(configFile, []byte(configContent), 0644)
if err != nil {
t.Fatalf("Failed to write test config: %v", err)
}
_, err = LoadConfig(configFile)
if err == nil {
t.Error("Expected error for old format, but got none")
}
// Check that the error message mentions the expected format
if !strings.Contains(err.Error(), "protocol://") {
t.Errorf("Expected error message to mention 'protocol://', got: %v", err)
}
}

329
pkg/dns/dns.go Normal file
View File

@@ -0,0 +1,329 @@
package dns
import (
"fmt"
"net"
"strings"
"time"
"teleport/pkg/config"
"teleport/pkg/logger"
"github.com/miekg/dns"
)
// StartDNSServer starts the built-in DNS server using miekg/dns
func StartDNSServer(cfg *config.Config) {
if !cfg.DNSServer.Enabled {
return
}
// Create DNS server
server := &dns.Server{
Addr: fmt.Sprintf(":%d", cfg.DNSServer.ListenPort),
Net: "udp",
}
// Set up handler
dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
handleDNSQuery(w, r, cfg)
})
logger.WithField("port", cfg.DNSServer.ListenPort).Info("DNS server started")
// Start server
if err := server.ListenAndServe(); err != nil {
logger.WithField("error", err).Error("Failed to start DNS server")
}
}
// handleDNSQuery handles DNS queries using miekg/dns
func handleDNSQuery(w dns.ResponseWriter, r *dns.Msg, cfg *config.Config) {
// Check if we have custom records for this query
response := checkCustomRecords(r, cfg)
if response == nil {
// Forward to backup DNS server
response = forwardToBackupDNS(r, cfg)
if response == nil {
// Send error response
response = new(dns.Msg)
response.SetRcode(r, dns.RcodeServerFailure)
}
}
// Send response
w.WriteMsg(response)
}
// checkCustomRecords checks if we have custom records for the query
func checkCustomRecords(query *dns.Msg, cfg *config.Config) *dns.Msg {
if len(query.Question) == 0 {
return nil
}
question := query.Question[0]
questionName := strings.ToLower(question.Name)
// Validate question name length and format
if len(questionName) > 253 {
logger.WithField("name", questionName).Warn("DNS query name too long, ignoring")
return nil
}
// Basic character validation for DNS name
for _, c := range questionName {
if c < 32 || c > 126 {
logger.WithField("name", questionName).Warn("DNS query name contains invalid characters, ignoring")
return nil
}
}
// Look for matching custom records
var answers []dns.RR
for _, record := range cfg.DNSServer.CustomRecords {
if strings.ToLower(record.Name) == questionName && getRecordType(record.Type) == question.Qtype {
answer := createDNSRecord(record, questionName)
if answer != nil {
answers = append(answers, answer)
}
}
}
if len(answers) == 0 {
return nil
}
// Create response
response := new(dns.Msg)
response.SetReply(query)
response.Authoritative = true
response.Answer = answers
return response
}
// getRecordType converts string record type to DNS type code
func getRecordType(recordType string) uint16 {
switch strings.ToUpper(recordType) {
case "A":
return dns.TypeA
case "AAAA":
return dns.TypeAAAA
case "CNAME":
return dns.TypeCNAME
case "MX":
return dns.TypeMX
case "TXT":
return dns.TypeTXT
case "NS":
return dns.TypeNS
case "SRV":
return dns.TypeSRV
default:
return 0
}
}
// createDNSRecord creates a DNS resource record from a custom record
func createDNSRecord(record config.DNSRecord, name string) dns.RR {
// Validate record value length
if len(record.Value) > 1024 {
logger.WithFields(map[string]interface{}{
"type": record.Type,
"name": name,
"value": record.Value,
}).Warn("DNS record value too long, ignoring")
return nil
}
switch strings.ToUpper(record.Type) {
case "A":
// IPv4 address
ip := net.ParseIP(record.Value)
if ip == nil || ip.To4() == nil {
logger.WithFields(map[string]interface{}{
"type": record.Type,
"name": name,
"value": record.Value,
}).Warn("Invalid IPv4 address in DNS record, ignoring")
return nil
}
return &dns.A{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: record.TTL,
},
A: ip.To4(),
}
case "AAAA":
// IPv6 address
ip := net.ParseIP(record.Value)
if ip == nil || ip.To16() == nil {
logger.WithFields(map[string]interface{}{
"type": record.Type,
"name": name,
"value": record.Value,
}).Warn("Invalid IPv6 address in DNS record, ignoring")
return nil
}
return &dns.AAAA{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: record.TTL,
},
AAAA: ip.To16(),
}
case "CNAME":
// Canonical name
// Validate CNAME target length
if len(record.Value) > 253 {
logger.WithFields(map[string]interface{}{
"type": record.Type,
"name": name,
"value": record.Value,
}).Warn("CNAME target too long, ignoring")
return nil
}
return &dns.CNAME{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeCNAME,
Class: dns.ClassINET,
Ttl: record.TTL,
},
Target: dns.Fqdn(record.Value),
}
case "MX":
// Mail exchange
// Validate MX target length
if len(record.Value) > 253 {
logger.WithFields(map[string]interface{}{
"type": record.Type,
"name": name,
"value": record.Value,
}).Warn("MX target too long, ignoring")
return nil
}
return &dns.MX{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeMX,
Class: dns.ClassINET,
Ttl: record.TTL,
},
Preference: record.Priority,
Mx: dns.Fqdn(record.Value),
}
case "TXT":
// Text record
// Validate TXT record length (RFC 1035 limit)
if len(record.Value) > 255 {
logger.WithFields(map[string]interface{}{
"type": record.Type,
"name": name,
"value": record.Value,
}).Warn("TXT record too long, ignoring")
return nil
}
return &dns.TXT{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: record.TTL,
},
Txt: []string{record.Value},
}
case "NS":
// Name server
// Validate NS target length
if len(record.Value) > 253 {
logger.WithFields(map[string]interface{}{
"type": record.Type,
"name": name,
"value": record.Value,
}).Warn("NS target too long, ignoring")
return nil
}
return &dns.NS{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeNS,
Class: dns.ClassINET,
Ttl: record.TTL,
},
Ns: dns.Fqdn(record.Value),
}
case "SRV":
// Service record
// Validate SRV target length
if len(record.Value) > 253 {
logger.WithFields(map[string]interface{}{
"type": record.Type,
"name": name,
"value": record.Value,
}).Warn("SRV target too long, ignoring")
return nil
}
return &dns.SRV{
Hdr: dns.RR_Header{
Name: name,
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: record.TTL,
},
Priority: record.Priority,
Weight: record.Weight,
Port: record.Port,
Target: dns.Fqdn(record.Value),
}
default:
return nil
}
}
// forwardToBackupDNS forwards the query to the backup DNS server
func forwardToBackupDNS(query *dns.Msg, cfg *config.Config) *dns.Msg {
// Create DNS client with timeout
client := new(dns.Client)
client.Net = "udp"
client.Timeout = 5 * time.Second // 5 second timeout
// Forward query to backup server
response, _, err := client.Exchange(query, cfg.DNSServer.BackupServer)
if err != nil {
logger.WithField("error", err).Error("Failed to forward DNS query to backup server")
return nil
}
// Validate response
if response == nil {
logger.Warn("Received nil response from backup DNS server")
return nil
}
// Basic response validation
if response.Id != query.Id {
logger.Warn("DNS response ID mismatch, potential spoofing attempt")
return nil
}
// Limit response size to prevent amplification attacks
if len(response.Answer) > 10 {
logger.WithField("answer_count", len(response.Answer)).Warn("DNS response has too many answers, truncating")
response.Answer = response.Answer[:10]
}
return response
}

View File

@@ -0,0 +1,226 @@
package encryption
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"fmt"
"io"
"math"
"sync"
"time"
"golang.org/x/crypto/pbkdf2"
)
const (
// PBKDF2 parameters for key derivation
PBKDF2Iterations = 100000 // OWASP recommended minimum
PBKDF2KeyLength = 32 // 256 bits
PBKDF2SaltLength = 16 // 128 bits
// Replay protection parameters
MaxPacketAge = 5 * time.Minute // Maximum age for UDP packets
NonceWindow = 1000 // Number of nonces to track for replay protection
)
// DeriveKey derives an encryption key from a password using PBKDF2
func DeriveKey(password string) []byte {
// Use a deterministic salt derived from the password hash for consistent key derivation
// This ensures the same password always produces the same key while avoiding rainbow tables
hasher := sha256.New()
hasher.Write([]byte(password))
passwordHash := hasher.Sum(nil)
// Create a deterministic salt from the password hash
salt := make([]byte, PBKDF2SaltLength)
copy(salt, passwordHash[:PBKDF2SaltLength])
key := pbkdf2.Key([]byte(password), salt, PBKDF2Iterations, PBKDF2KeyLength, sha256.New)
return key
}
// DeriveKeyWithSalt derives an encryption key from a password using PBKDF2 with a custom salt
func DeriveKeyWithSalt(password string, salt []byte) ([]byte, error) {
if len(salt) != PBKDF2SaltLength {
return nil, fmt.Errorf("salt length must be %d bytes, got %d", PBKDF2SaltLength, len(salt))
}
key := pbkdf2.Key([]byte(password), salt, PBKDF2Iterations, PBKDF2KeyLength, sha256.New)
return key, nil
}
// GenerateSalt generates a cryptographically secure random salt
func GenerateSalt() ([]byte, error) {
salt := make([]byte, PBKDF2SaltLength)
_, err := io.ReadFull(rand.Reader, salt)
return salt, err
}
// EncryptData encrypts data using AES-GCM
func EncryptData(data []byte, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonce := make([]byte, gcm.NonceSize())
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
return nil, err
}
ciphertext := gcm.Seal(nonce, nonce, data, nil)
return ciphertext, nil
}
// DecryptData decrypts data using AES-GCM
func DecryptData(data []byte, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return nil, fmt.Errorf("ciphertext too short")
}
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return nil, err
}
return plaintext, nil
}
// ReplayProtection tracks nonces to prevent replay attacks
type ReplayProtection struct {
nonces map[uint64]time.Time
mutex sync.RWMutex
maxNonces int // Maximum number of nonces to track
}
// NewReplayProtection creates a new replay protection instance
func NewReplayProtection() *ReplayProtection {
return &ReplayProtection{
nonces: make(map[uint64]time.Time),
maxNonces: NonceWindow * 2, // Allow 2x the window size for safety
}
}
// IsValidNonce checks if a nonce is valid (not replayed and not too old)
func (rp *ReplayProtection) IsValidNonce(nonce uint64, timestamp int64) bool {
rp.mutex.Lock()
defer rp.mutex.Unlock()
now := time.Now()
packetTime := time.Unix(timestamp, 0)
// Check if packet is too old
if now.Sub(packetTime) > MaxPacketAge {
return false
}
// Check if nonce was already used
if _, exists := rp.nonces[nonce]; exists {
return false
}
// Add nonce to tracking
rp.nonces[nonce] = now
// Clean up old nonces to prevent memory leaks
if len(rp.nonces) > rp.maxNonces {
rp.cleanupOldNonces(now)
}
return true
}
// cleanupOldNonces removes nonces older than MaxPacketAge
func (rp *ReplayProtection) cleanupOldNonces(now time.Time) {
for nonce, timestamp := range rp.nonces {
if now.Sub(timestamp) > MaxPacketAge {
delete(rp.nonces, nonce)
}
}
}
// ValidatePacketTimestamp validates that a packet timestamp is within acceptable range
func ValidatePacketTimestamp(timestamp int64) bool {
now := time.Now().Unix()
packetTime := timestamp
// Check if packet is too old or from the future
age := now - packetTime
return age >= 0 && age <= int64(MaxPacketAge.Seconds())
}
// ConstantTimeCompare performs a constant-time comparison of two byte slices
func ConstantTimeCompare(a, b []byte) bool {
return subtle.ConstantTimeCompare(a, b) == 1
}
// ValidateEncryptionKey validates that an encryption key meets security requirements
func ValidateEncryptionKey(key string) error {
if len(key) < 32 {
return fmt.Errorf("encryption key must be at least 32 characters long")
}
// Check for common weak keys
weakKeys := []string{
"password", "123456", "admin", "test", "default",
"your-secure-encryption-key-change-this-to-something-random",
"test-encryption-key-12345", "teleport-key", "secret",
}
for _, weak := range weakKeys {
if key == weak {
return fmt.Errorf("encryption key is too weak, please use a strong random key")
}
}
// Check for sufficient entropy (basic check)
entropy := calculateEntropy(key)
if entropy < 3.5 { // Minimum entropy threshold
return fmt.Errorf("encryption key has insufficient entropy, please use a more random key")
}
return nil
}
// calculateEntropy calculates the Shannon entropy of a string
func calculateEntropy(s string) float64 {
freq := make(map[rune]int)
for _, r := range s {
freq[r]++
}
entropy := 0.0
length := float64(len(s))
for _, count := range freq {
p := float64(count) / length
entropy -= p * log2(p)
}
return entropy
}
// log2 calculates log base 2
func log2(x float64) float64 {
return math.Log2(x)
}

View File

@@ -0,0 +1,242 @@
package encryption
import (
"testing"
"time"
)
func TestDeriveKey(t *testing.T) {
password := "test-password"
key := DeriveKey(password)
if len(key) != 32 {
t.Errorf("Expected key length 32, got %d", len(key))
}
// Test that same password produces same key
key2 := DeriveKey(password)
if string(key) != string(key2) {
t.Error("Same password should produce same key")
}
// Test that different passwords produce different keys
key3 := DeriveKey("different-password")
if string(key) == string(key3) {
t.Error("Different passwords should produce different keys")
}
}
func TestEncryptDecrypt(t *testing.T) {
key := DeriveKey("test-key")
originalData := []byte("Hello, World! This is a test message.")
// Test encryption
encryptedData, err := EncryptData(originalData, key)
if err != nil {
t.Fatalf("Encryption failed: %v", err)
}
if len(encryptedData) <= len(originalData) {
t.Error("Encrypted data should be longer than original data (due to nonce)")
}
// Test decryption
decryptedData, err := DecryptData(encryptedData, key)
if err != nil {
t.Fatalf("Decryption failed: %v", err)
}
if string(decryptedData) != string(originalData) {
t.Errorf("Decrypted data doesn't match original. Expected: %s, Got: %s",
string(originalData), string(decryptedData))
}
}
func TestEncryptDecryptEmptyData(t *testing.T) {
key := DeriveKey("test-key")
originalData := []byte("")
encryptedData, err := EncryptData(originalData, key)
if err != nil {
t.Fatalf("Encryption of empty data failed: %v", err)
}
decryptedData, err := DecryptData(encryptedData, key)
if err != nil {
t.Fatalf("Decryption of empty data failed: %v", err)
}
if len(decryptedData) != 0 {
t.Error("Decrypted empty data should be empty")
}
}
func TestEncryptDecryptLargeData(t *testing.T) {
key := DeriveKey("test-key")
// Create a large data block (1MB)
originalData := make([]byte, 1024*1024)
for i := range originalData {
originalData[i] = byte(i % 256)
}
encryptedData, err := EncryptData(originalData, key)
if err != nil {
t.Fatalf("Encryption of large data failed: %v", err)
}
decryptedData, err := DecryptData(encryptedData, key)
if err != nil {
t.Fatalf("Decryption of large data failed: %v", err)
}
if len(decryptedData) != len(originalData) {
t.Errorf("Decrypted data length mismatch. Expected: %d, Got: %d",
len(originalData), len(decryptedData))
}
for i := range originalData {
if decryptedData[i] != originalData[i] {
t.Errorf("Data mismatch at position %d", i)
break
}
}
}
func TestWrongKeyDecryption(t *testing.T) {
key1 := DeriveKey("key1")
key2 := DeriveKey("key2")
originalData := []byte("test data")
encryptedData, err := EncryptData(originalData, key1)
if err != nil {
t.Fatalf("Encryption failed: %v", err)
}
// Try to decrypt with wrong key
_, err = DecryptData(encryptedData, key2)
if err == nil {
t.Error("Decryption with wrong key should fail")
}
}
func TestCorruptedDataDecryption(t *testing.T) {
key := DeriveKey("test-key")
originalData := []byte("test data")
encryptedData, err := EncryptData(originalData, key)
if err != nil {
t.Fatalf("Encryption failed: %v", err)
}
// Corrupt the data
encryptedData[0] ^= 0xFF
// Try to decrypt corrupted data
_, err = DecryptData(encryptedData, key)
if err == nil {
t.Error("Decryption of corrupted data should fail")
}
}
func TestValidateEncryptionKey(t *testing.T) {
// Test valid key
validKey := "a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03"
if err := ValidateEncryptionKey(validKey); err != nil {
t.Errorf("Valid key should pass validation: %v", err)
}
// Test short key
shortKey := "short"
if err := ValidateEncryptionKey(shortKey); err == nil {
t.Error("Short key should fail validation")
}
// Test weak keys
weakKeys := []string{
"password", "123456", "admin", "test", "default",
"your-secure-encryption-key-change-this-to-something-random",
"test-encryption-key-12345", "teleport-key", "secret",
}
for _, weakKey := range weakKeys {
if err := ValidateEncryptionKey(weakKey); err == nil {
t.Errorf("Weak key '%s' should fail validation", weakKey)
}
}
// Test low entropy key
lowEntropyKey := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
if err := ValidateEncryptionKey(lowEntropyKey); err == nil {
t.Error("Low entropy key should fail validation")
}
}
func TestReplayProtection(t *testing.T) {
rp := NewReplayProtection()
nonce := uint64(12345)
timestamp := time.Now().Unix()
// First use should be valid
if !rp.IsValidNonce(nonce, timestamp) {
t.Error("First use of nonce should be valid")
}
// Replay should be invalid
if rp.IsValidNonce(nonce, timestamp) {
t.Error("Replay of nonce should be invalid")
}
// Different nonce should be valid
if !rp.IsValidNonce(nonce+1, timestamp) {
t.Error("Different nonce should be valid")
}
}
func TestValidatePacketTimestamp(t *testing.T) {
now := time.Now().Unix()
// Current timestamp should be valid
if !ValidatePacketTimestamp(now) {
t.Error("Current timestamp should be valid")
}
// Recent timestamp should be valid
recent := now - 60 // 1 minute ago
if !ValidatePacketTimestamp(recent) {
t.Error("Recent timestamp should be valid")
}
// Old timestamp should be invalid
old := now - int64(MaxPacketAge.Seconds()) - 1
if ValidatePacketTimestamp(old) {
t.Error("Old timestamp should be invalid")
}
// Future timestamp should be invalid
future := now + 3600 // 1 hour in future
if ValidatePacketTimestamp(future) {
t.Error("Future timestamp should be invalid")
}
}
func TestConstantTimeCompare(t *testing.T) {
a := []byte("test")
b := []byte("test")
c := []byte("different")
if !ConstantTimeCompare(a, b) {
t.Error("Identical byte slices should compare equal")
}
if ConstantTimeCompare(a, c) {
t.Error("Different byte slices should not compare equal")
}
// Test with empty slices
empty1 := []byte{}
empty2 := []byte{}
if !ConstantTimeCompare(empty1, empty2) {
t.Error("Empty slices should compare equal")
}
}

269
pkg/logger/logger.go Normal file
View File

@@ -0,0 +1,269 @@
package logger
import (
"io"
"os"
"path/filepath"
"regexp"
"sync"
"github.com/sirupsen/logrus"
)
var (
Log *logrus.Logger
once sync.Once
mu sync.RWMutex
)
// Config holds logging configuration
type Config struct {
Level string `yaml:"level"` // debug, info, warn, error
Format string `yaml:"format"` // json, text
File string `yaml:"file"` // log file path (empty for stdout)
MaxSize int `yaml:"max_size"` // max log file size in MB
MaxBackups int `yaml:"max_backups"` // max number of backup files
MaxAge int `yaml:"max_age"` // max age of backup files in days
Compress bool `yaml:"compress"` // compress backup files
}
// Init initializes the global logger with the given configuration
func Init(config Config) error {
mu.Lock()
defer mu.Unlock()
Log = logrus.New()
// Set log level
level, err := logrus.ParseLevel(config.Level)
if err != nil {
level = logrus.InfoLevel
}
Log.SetLevel(level)
// Set log format with sanitization
switch config.Format {
case "json":
Log.SetFormatter(&SanitizedJSONFormatter{
TimestampFormat: "2006-01-02 15:04:05",
})
default:
Log.SetFormatter(&SanitizedTextFormatter{
FullTimestamp: true,
TimestampFormat: "2006-01-02 15:04:05",
})
}
// Set output
if config.File != "" {
// Ensure directory exists
dir := filepath.Dir(config.File)
if err := os.MkdirAll(dir, 0755); err != nil {
return err
}
// Open log file with secure permissions (owner read/write only)
file, err := os.OpenFile(config.File, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
if err != nil {
return err
}
// Set output to both file and stdout
Log.SetOutput(io.MultiWriter(file, os.Stdout))
} else {
Log.SetOutput(os.Stdout)
}
return nil
}
// GetLogger returns the global logger instance
func GetLogger() *logrus.Logger {
mu.RLock()
if Log != nil {
mu.RUnlock()
return Log
}
mu.RUnlock()
// Initialize with default config if not already initialized
once.Do(func() {
Init(Config{
Level: "info",
Format: "text",
File: "",
})
})
mu.RLock()
defer mu.RUnlock()
return Log
}
// Helper functions for common logging operations
func Debug(args ...interface{}) {
GetLogger().Debug(args...)
}
func Debugf(format string, args ...interface{}) {
GetLogger().Debugf(format, args...)
}
func Info(args ...interface{}) {
GetLogger().Info(args...)
}
func Infof(format string, args ...interface{}) {
GetLogger().Infof(format, args...)
}
func Warn(args ...interface{}) {
GetLogger().Warn(args...)
}
func Warnf(format string, args ...interface{}) {
GetLogger().Warnf(format, args...)
}
func Error(args ...interface{}) {
GetLogger().Error(args...)
}
func Errorf(format string, args ...interface{}) {
GetLogger().Errorf(format, args...)
}
func Fatal(args ...interface{}) {
GetLogger().Fatal(args...)
}
func Fatalf(format string, args ...interface{}) {
GetLogger().Fatalf(format, args...)
}
// WithField creates a new logger entry with a field
func WithField(key string, value interface{}) *logrus.Entry {
return GetLogger().WithField(key, value)
}
// WithFields creates a new logger entry with multiple fields
func WithFields(fields logrus.Fields) *logrus.Entry {
return GetLogger().WithFields(fields)
}
// SanitizedTextFormatter sanitizes log output
type SanitizedTextFormatter struct {
FullTimestamp bool
TimestampFormat string
}
// Format formats the log entry with sanitization
func (f *SanitizedTextFormatter) Format(entry *logrus.Entry) ([]byte, error) {
// Sanitize sensitive data
sanitizedEntry := *entry
sanitizedEntry.Message = sanitizeString(entry.Message)
// Sanitize fields
sanitizedFields := make(logrus.Fields)
for k, v := range entry.Data {
sanitizedFields[k] = sanitizeValue(v)
}
sanitizedEntry.Data = sanitizedFields
// Use default text formatter
formatter := &logrus.TextFormatter{
FullTimestamp: f.FullTimestamp,
TimestampFormat: f.TimestampFormat,
}
return formatter.Format(&sanitizedEntry)
}
// SanitizedJSONFormatter sanitizes JSON log output
type SanitizedJSONFormatter struct {
TimestampFormat string
}
// Format formats the log entry with sanitization
func (f *SanitizedJSONFormatter) Format(entry *logrus.Entry) ([]byte, error) {
// Sanitize sensitive data
sanitizedEntry := *entry
sanitizedEntry.Message = sanitizeString(entry.Message)
// Sanitize fields
sanitizedFields := make(logrus.Fields)
for k, v := range entry.Data {
sanitizedFields[k] = sanitizeValue(v)
}
sanitizedEntry.Data = sanitizedFields
// Use default JSON formatter
formatter := &logrus.JSONFormatter{
TimestampFormat: f.TimestampFormat,
}
return formatter.Format(&sanitizedEntry)
}
// sanitizeString removes or masks sensitive information
func sanitizeString(s string) string {
// Remove potential encryption keys (hex strings longer than 32 chars)
keyPattern := regexp.MustCompile(`[a-fA-F0-9]{32,}`)
s = keyPattern.ReplaceAllString(s, "[REDACTED_KEY]")
// Remove potential passwords and tokens
passwordPattern := regexp.MustCompile(`(?i)(password|pass|pwd|secret|key|token|auth|credential)\s*[:=]\s*[^\s]+`)
s = passwordPattern.ReplaceAllString(s, "$1=[REDACTED]")
// Remove potential API keys and tokens
apiKeyPattern := regexp.MustCompile(`(?i)(api[_-]?key|access[_-]?token|bearer[_-]?token)\s*[:=]\s*[^\s]+`)
s = apiKeyPattern.ReplaceAllString(s, "$1=[REDACTED]")
// Remove potential database connection strings
dbPattern := regexp.MustCompile(`(?i)(mysql|postgres|mongodb|redis)://[^@]+@[^\s]+`)
s = dbPattern.ReplaceAllString(s, "[REDACTED_DB_CONNECTION]")
// Remove potential JWT tokens
jwtPattern := regexp.MustCompile(`eyJ[A-Za-z0-9_-]*\.[A-Za-z0-9_-]*\.[A-Za-z0-9_-]*`)
s = jwtPattern.ReplaceAllString(s, "[REDACTED_JWT]")
// Remove potential credit card numbers
ccPattern := regexp.MustCompile(`\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b`)
s = ccPattern.ReplaceAllString(s, "[REDACTED_CC]")
// Remove potential SSNs
ssnPattern := regexp.MustCompile(`\b\d{3}-\d{2}-\d{4}\b`)
s = ssnPattern.ReplaceAllString(s, "[REDACTED_SSN]")
// Remove potential IP addresses in sensitive contexts
ipPattern := regexp.MustCompile(`\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b`)
s = ipPattern.ReplaceAllString(s, "[REDACTED_IP]")
// Remove potential email addresses
emailPattern := regexp.MustCompile(`\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b`)
s = emailPattern.ReplaceAllString(s, "[REDACTED_EMAIL]")
return s
}
// sanitizeValue sanitizes various value types
func sanitizeValue(v interface{}) interface{} {
switch val := v.(type) {
case string:
return sanitizeString(val)
case []byte:
return "[BINARY_DATA]"
case map[string]interface{}:
sanitized := make(map[string]interface{})
for k, v := range val {
sanitized[k] = sanitizeValue(v)
}
return sanitized
case []interface{}:
sanitized := make([]interface{}, len(val))
for i, v := range val {
sanitized[i] = sanitizeValue(v)
}
return sanitized
default:
return v
}
}

200
pkg/metrics/metrics.go Normal file
View File

@@ -0,0 +1,200 @@
package metrics
import (
"sync"
"sync/atomic"
"time"
)
// NOTE: These metrics are for internal use only and are not exposed externally.
// They are primarily used for testing, debugging, and internal monitoring.
// No HTTP endpoints or external APIs are provided to access these metrics.
// Metrics collects application metrics
type Metrics struct {
// Connection metrics
TotalConnections int64
ActiveConnections int64
RejectedConnections int64
FailedConnections int64
// Request metrics
TotalRequests int64
SuccessfulRequests int64
FailedRequests int64
RateLimitedRequests int64
// DNS metrics
DNSQueries int64
DNSCustomRecords int64
DNSForwardedQueries int64
DNSFailedQueries int64
// UDP metrics
UDPPacketsReceived int64
UDPPacketsSent int64
UDPReplayDetected int64
// Performance metrics
AverageResponseTime time.Duration
MaxResponseTime time.Duration
MinResponseTime time.Duration
// System metrics
StartTime time.Time
mutex sync.RWMutex
}
// Global metrics instance
var globalMetrics = &Metrics{
StartTime: time.Now(),
}
// GetMetrics returns the global metrics instance
func GetMetrics() *Metrics {
return globalMetrics
}
// IncrementTotalConnections increments the total connections counter
func (m *Metrics) IncrementTotalConnections() {
atomic.AddInt64(&m.TotalConnections, 1)
}
// IncrementActiveConnections increments the active connections counter
func (m *Metrics) IncrementActiveConnections() {
atomic.AddInt64(&m.ActiveConnections, 1)
}
// DecrementActiveConnections decrements the active connections counter
func (m *Metrics) DecrementActiveConnections() {
atomic.AddInt64(&m.ActiveConnections, -1)
}
// IncrementRejectedConnections increments the rejected connections counter
func (m *Metrics) IncrementRejectedConnections() {
atomic.AddInt64(&m.RejectedConnections, 1)
}
// IncrementFailedConnections increments the failed connections counter
func (m *Metrics) IncrementFailedConnections() {
atomic.AddInt64(&m.FailedConnections, 1)
}
// IncrementTotalRequests increments the total requests counter
func (m *Metrics) IncrementTotalRequests() {
atomic.AddInt64(&m.TotalRequests, 1)
}
// IncrementSuccessfulRequests increments the successful requests counter
func (m *Metrics) IncrementSuccessfulRequests() {
atomic.AddInt64(&m.SuccessfulRequests, 1)
}
// IncrementFailedRequests increments the failed requests counter
func (m *Metrics) IncrementFailedRequests() {
atomic.AddInt64(&m.FailedRequests, 1)
}
// IncrementRateLimitedRequests increments the rate limited requests counter
func (m *Metrics) IncrementRateLimitedRequests() {
atomic.AddInt64(&m.RateLimitedRequests, 1)
}
// IncrementDNSQueries increments the DNS queries counter
func (m *Metrics) IncrementDNSQueries() {
atomic.AddInt64(&m.DNSQueries, 1)
}
// IncrementDNSCustomRecords increments the DNS custom records counter
func (m *Metrics) IncrementDNSCustomRecords() {
atomic.AddInt64(&m.DNSCustomRecords, 1)
}
// IncrementDNSForwardedQueries increments the DNS forwarded queries counter
func (m *Metrics) IncrementDNSForwardedQueries() {
atomic.AddInt64(&m.DNSForwardedQueries, 1)
}
// IncrementDNSFailedQueries increments the DNS failed queries counter
func (m *Metrics) IncrementDNSFailedQueries() {
atomic.AddInt64(&m.DNSFailedQueries, 1)
}
// IncrementUDPPacketsReceived increments the UDP packets received counter
func (m *Metrics) IncrementUDPPacketsReceived() {
atomic.AddInt64(&m.UDPPacketsReceived, 1)
}
// IncrementUDPPacketsSent increments the UDP packets sent counter
func (m *Metrics) IncrementUDPPacketsSent() {
atomic.AddInt64(&m.UDPPacketsSent, 1)
}
// IncrementUDPReplayDetected increments the UDP replay detected counter
func (m *Metrics) IncrementUDPReplayDetected() {
atomic.AddInt64(&m.UDPReplayDetected, 1)
}
// RecordResponseTime records a response time
func (m *Metrics) RecordResponseTime(duration time.Duration) {
m.mutex.Lock()
defer m.mutex.Unlock()
// Update average response time (simple moving average)
if m.AverageResponseTime == 0 {
m.AverageResponseTime = duration
} else {
m.AverageResponseTime = (m.AverageResponseTime + duration) / 2
}
// Update min/max response times
if m.MaxResponseTime == 0 || duration > m.MaxResponseTime {
m.MaxResponseTime = duration
}
if m.MinResponseTime == 0 || duration < m.MinResponseTime {
m.MinResponseTime = duration
}
}
// GetUptime returns the application uptime
func (m *Metrics) GetUptime() time.Duration {
return time.Since(m.StartTime)
}
// GetStats returns a snapshot of all metrics
func (m *Metrics) GetStats() map[string]interface{} {
m.mutex.RLock()
defer m.mutex.RUnlock()
return map[string]interface{}{
"uptime": m.GetUptime().String(),
"connections": map[string]int64{
"total": atomic.LoadInt64(&m.TotalConnections),
"active": atomic.LoadInt64(&m.ActiveConnections),
"rejected": atomic.LoadInt64(&m.RejectedConnections),
"failed": atomic.LoadInt64(&m.FailedConnections),
},
"requests": map[string]int64{
"total": atomic.LoadInt64(&m.TotalRequests),
"successful": atomic.LoadInt64(&m.SuccessfulRequests),
"failed": atomic.LoadInt64(&m.FailedRequests),
"rate_limited": atomic.LoadInt64(&m.RateLimitedRequests),
},
"dns": map[string]int64{
"queries": atomic.LoadInt64(&m.DNSQueries),
"custom_records": atomic.LoadInt64(&m.DNSCustomRecords),
"forwarded": atomic.LoadInt64(&m.DNSForwardedQueries),
"failed": atomic.LoadInt64(&m.DNSFailedQueries),
},
"udp": map[string]int64{
"packets_received": atomic.LoadInt64(&m.UDPPacketsReceived),
"packets_sent": atomic.LoadInt64(&m.UDPPacketsSent),
"replay_detected": atomic.LoadInt64(&m.UDPReplayDetected),
},
"performance": map[string]interface{}{
"avg_response_time": m.AverageResponseTime.String(),
"max_response_time": m.MaxResponseTime.String(),
"min_response_time": m.MinResponseTime.String(),
},
}
}

View File

@@ -0,0 +1,89 @@
package ratelimit
import (
"context"
"sync"
"time"
)
// RateLimiter implements a token bucket rate limiter
type RateLimiter struct {
requestsPerSecond int
burstSize int
windowSize time.Duration
tokens int
lastRefill time.Time
mutex sync.Mutex
}
// NewRateLimiter creates a new rate limiter
func NewRateLimiter(requestsPerSecond, burstSize int, windowSize time.Duration) *RateLimiter {
return &RateLimiter{
requestsPerSecond: requestsPerSecond,
burstSize: burstSize,
tokens: burstSize,
lastRefill: time.Now(),
windowSize: windowSize,
}
}
// Allow checks if a request is allowed under the rate limit
func (rl *RateLimiter) Allow() bool {
rl.mutex.Lock()
defer rl.mutex.Unlock()
now := time.Now()
// Calculate tokens to add based on time elapsed
elapsed := now.Sub(rl.lastRefill)
tokensToAdd := int(elapsed.Seconds() * float64(rl.requestsPerSecond))
if tokensToAdd > 0 {
rl.tokens += tokensToAdd
if rl.tokens > rl.burstSize {
rl.tokens = rl.burstSize
}
rl.lastRefill = now
}
// Check if we have tokens available
if rl.tokens > 0 {
rl.tokens--
return true
}
return false
}
// AllowWithContext checks if a request is allowed with context cancellation
func (rl *RateLimiter) AllowWithContext(ctx context.Context) bool {
select {
case <-ctx.Done():
return false
default:
return rl.Allow()
}
}
// Wait blocks until a request is allowed
func (rl *RateLimiter) Wait(ctx context.Context) error {
for {
if rl.Allow() {
return nil
}
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(rl.windowSize / time.Duration(rl.requestsPerSecond)):
// Wait for next token
}
}
}
// GetStats returns current rate limiter statistics
func (rl *RateLimiter) GetStats() (tokens int, lastRefill time.Time) {
rl.mutex.Lock()
defer rl.mutex.Unlock()
return rl.tokens, rl.lastRefill
}

33
pkg/types/types.go Normal file
View File

@@ -0,0 +1,33 @@
package types
import "net"
// PortForwardRequest represents a port forwarding request
type PortForwardRequest struct {
LocalPort int
RemotePort int
Protocol string
TargetHost string
}
// UDPPacketHeader represents the header for UDP packets
type UDPPacketHeader struct {
ClientID string
PacketID uint64
Timestamp int64
}
// TaggedUDPPacket represents a UDP packet with tagging information
type TaggedUDPPacket struct {
Header UDPPacketHeader
Data []byte
}
// UDPRequest represents a UDP forwarding request
type UDPRequest struct {
LocalPort int
RemotePort int
Protocol string
Data []byte
ClientAddr *net.UDPAddr
}

112
pkg/types/types_test.go Normal file
View File

@@ -0,0 +1,112 @@
package types
import (
"net"
"testing"
)
func TestPortForwardRequest(t *testing.T) {
req := PortForwardRequest{
LocalPort: 8080,
RemotePort: 80,
Protocol: "tcp",
}
if req.LocalPort != 8080 {
t.Errorf("Expected LocalPort 8080, got %d", req.LocalPort)
}
if req.RemotePort != 80 {
t.Errorf("Expected RemotePort 80, got %d", req.RemotePort)
}
if req.Protocol != "tcp" {
t.Errorf("Expected Protocol 'tcp', got '%s'", req.Protocol)
}
}
func TestUDPPacketHeader(t *testing.T) {
header := UDPPacketHeader{
ClientID: "test-client",
PacketID: 12345,
Timestamp: 1640995200,
}
if header.ClientID != "test-client" {
t.Errorf("Expected ClientID 'test-client', got '%s'", header.ClientID)
}
if header.PacketID != 12345 {
t.Errorf("Expected PacketID 12345, got %d", header.PacketID)
}
if header.Timestamp != 1640995200 {
t.Errorf("Expected Timestamp 1640995200, got %d", header.Timestamp)
}
}
func TestTaggedUDPPacket(t *testing.T) {
header := UDPPacketHeader{
ClientID: "test-client",
PacketID: 12345,
Timestamp: 1640995200,
}
data := []byte("test data")
packet := TaggedUDPPacket{
Header: header,
Data: data,
}
if packet.Header.ClientID != "test-client" {
t.Errorf("Expected ClientID 'test-client', got '%s'", packet.Header.ClientID)
}
if len(packet.Data) != len(data) {
t.Errorf("Expected data length %d, got %d", len(data), len(packet.Data))
}
for i, b := range data {
if packet.Data[i] != b {
t.Errorf("Data mismatch at position %d", i)
}
}
}
func TestUDPRequest(t *testing.T) {
addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:8080")
if err != nil {
t.Fatalf("Failed to resolve UDP address: %v", err)
}
data := []byte("test data")
req := UDPRequest{
LocalPort: 8080,
RemotePort: 80,
Protocol: "udp",
Data: data,
ClientAddr: addr,
}
if req.LocalPort != 8080 {
t.Errorf("Expected LocalPort 8080, got %d", req.LocalPort)
}
if req.RemotePort != 80 {
t.Errorf("Expected RemotePort 80, got %d", req.RemotePort)
}
if req.Protocol != "udp" {
t.Errorf("Expected Protocol 'udp', got '%s'", req.Protocol)
}
if len(req.Data) != len(data) {
t.Errorf("Expected data length %d, got %d", len(data), len(req.Data))
}
if req.ClientAddr.String() != "127.0.0.1:8080" {
t.Errorf("Expected ClientAddr '127.0.0.1:8080', got '%s'", req.ClientAddr.String())
}
}

47
pkg/version/version.go Normal file
View File

@@ -0,0 +1,47 @@
package version
import (
"fmt"
"runtime"
)
// These variables are set during build time via ldflags
var (
Version = "dev"
GitCommit = "unknown"
BuildDate = "unknown"
GoVersion = runtime.Version()
Platform = fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH)
)
// Info holds version information
type Info struct {
Version string `json:"version"`
GitCommit string `json:"git_commit"`
BuildDate string `json:"build_date"`
GoVersion string `json:"go_version"`
Platform string `json:"platform"`
}
// Get returns version information
func Get() Info {
return Info{
Version: Version,
GitCommit: GitCommit,
BuildDate: BuildDate,
GoVersion: GoVersion,
Platform: Platform,
}
}
// String returns a formatted version string
func String() string {
info := Get()
return fmt.Sprintf("teleport version %s (commit: %s, built: %s, go: %s, platform: %s)",
info.Version, info.GitCommit, info.BuildDate, info.GoVersion, info.Platform)
}
// Short returns a short version string
func Short() string {
return fmt.Sprintf("teleport %s", Version)
}