// Package config loads, validates, and generates teleport configuration files.
package config

import (
	"crypto/rand"
	"encoding/hex"
	"fmt"
	"os"
	"strconv"
	"strings"
	"time"

	"teleport/pkg/encryption"

	"gopkg.in/yaml.v3"
)

// Config represents the teleport configuration.
//
// ListenAddress and RemoteAddress are mutually exclusive; DetectMode uses
// them to decide whether the instance runs as a server or as a client.
type Config struct {
	InstanceID     string          `yaml:"instance_id"`
	ListenAddress  string          `yaml:"listen_address"`
	RemoteAddress  string          `yaml:"remote_address"`
	Ports          []PortRule      `yaml:"ports"`
	EncryptionKey  string          `yaml:"encryption_key"`
	KeepAlive      bool            `yaml:"keep_alive"`
	ReadTimeout    time.Duration   `yaml:"read_timeout"`
	WriteTimeout   time.Duration   `yaml:"write_timeout"`
	MaxConnections int             `yaml:"max_connections"`
	RateLimit      RateLimitConfig `yaml:"rate_limit"`
	DNSServer      DNSServerConfig `yaml:"dns_server"`
}

// PortRule defines a port forwarding rule.
//
// In YAML a rule is written as a single URL-style string; see UnmarshalYAML
// and MarshalYAML for the exact format.
type PortRule struct {
	LocalPort  int    `yaml:"local_port"`
	RemotePort int    `yaml:"remote_port"`
	Protocol   string `yaml:"protocol"`              // "tcp" or "udp"
	TargetHost string `yaml:"target_host,omitempty"` // Target host for server-side forwarding (defaults to localhost)
}

// UnmarshalYAML implements custom YAML unmarshaling for PortRule.
//
// Only the URL-style scalar form is supported:
//
//	protocol://target:targetport     (server) LocalPort and RemotePort are both
//	                                 set to targetport, TargetHost to target
//	protocol://targetport:localport  (client) RemotePort is targetport,
//	                                 LocalPort is localport, TargetHost is ""
//
// The form is chosen by the part before the colon: if it parses as an integer
// the rule is treated as a client rule, otherwise as a server rule. The
// receiver is only modified when parsing succeeds.
func (p *PortRule) UnmarshalYAML(value *yaml.Node) error {
	if value.Kind != yaml.ScalarNode {
		return fmt.Errorf("port must be a string in format 'protocol://target:port' or 'protocol://targetport:localport'")
	}

	portStr := value.Value

	// Bound the input before doing any splitting.
	if len(portStr) > 512 {
		return fmt.Errorf("port string too long")
	}

	if !strings.Contains(portStr, "://") {
		return fmt.Errorf("invalid port format: %s (expected 'protocol://target:port' or 'protocol://targetport:localport')", portStr)
	}

	// More than one "://" yields more than two parts and is rejected.
	parts := strings.Split(portStr, "://")
	if len(parts) != 2 {
		return fmt.Errorf("invalid port format: %s (expected 'protocol://target:port')", portStr)
	}

	protocol := parts[0]
	if protocol != "tcp" && protocol != "udp" {
		return fmt.Errorf("invalid protocol: %s (expected 'tcp' or 'udp')", protocol)
	}

	addressPart := parts[1]
	if len(addressPart) > 256 {
		return fmt.Errorf("address part too long")
	}

	// Exactly one colon is required, so IPv6 literals are not accepted here.
	addressParts := strings.Split(addressPart, ":")
	if len(addressParts) != 2 {
		return fmt.Errorf("invalid address format: %s (expected 'target:port' for server or 'targetport:localport' for client)", addressPart)
	}
	first, second := addressParts[0], addressParts[1]

	// Anything that does not parse as an integer (host names, dotted IPs,
	// the empty string) selects the server form. strconv.Atoi never accepts
	// a string containing '.', so no separate dot check is needed.
	if _, err := strconv.Atoi(first); err != nil {
		targetHost := first

		if len(targetHost) > 253 { // RFC 1123 limit
			return fmt.Errorf("target host too long")
		}

		// Only printable ASCII is allowed in the host.
		for _, c := range targetHost {
			if c < 32 || c > 126 {
				return fmt.Errorf("invalid character in target host")
			}
		}

		targetPort, err := strconv.Atoi(second)
		if err != nil {
			return fmt.Errorf("invalid target port: %s", second)
		}
		if targetPort < 1 || targetPort > 65535 {
			return fmt.Errorf("invalid target port: %d (must be 1-65535)", targetPort)
		}

		p.LocalPort = targetPort  // Server listens on this port
		p.RemotePort = targetPort // Server forwards to this port on target
		p.Protocol = protocol
		p.TargetHost = targetHost
		return nil
	}

	// Client form: both parts must be numeric. Both are parsed before either
	// is range-checked so that syntax errors are reported first.
	targetPort, err := strconv.Atoi(first)
	if err != nil {
		return fmt.Errorf("invalid target port: %s", first)
	}
	localPort, err := strconv.Atoi(second)
	if err != nil {
		return fmt.Errorf("invalid local port: %s", second)
	}

	if targetPort < 1 || targetPort > 65535 {
		return fmt.Errorf("invalid target port: %d (must be 1-65535)", targetPort)
	}
	if localPort < 1 || localPort > 65535 {
		return fmt.Errorf("invalid local port: %d (must be 1-65535)", localPort)
	}

	p.LocalPort = localPort   // Client listens on this port
	p.RemotePort = targetPort // Client forwards to this port on teleport server
	p.Protocol = protocol
	p.TargetHost = "" // Client doesn't specify target host
	return nil
}

// MarshalYAML implements custom YAML marshaling for PortRule.
//
// A rule with an empty TargetHost is written in the client form
// "protocol://targetport:localport"; any other rule is written in the server
// form "protocol://target:targetport". The server form emits only RemotePort,
// so a LocalPort that differs from RemotePort is not preserved.
func (p PortRule) MarshalYAML() (interface{}, error) {
	if p.TargetHost == "" {
		return fmt.Sprintf("%s://%d:%d", p.Protocol, p.RemotePort, p.LocalPort), nil
	}
	return fmt.Sprintf("%s://%s:%d", p.Protocol, p.TargetHost, p.RemotePort), nil
}

// RateLimitConfig defines rate limiting configuration.
type RateLimitConfig struct {
	Enabled           bool          `yaml:"enabled"`
	RequestsPerSecond int           `yaml:"requests_per_second"`
	BurstSize         int           `yaml:"burst_size"`
	WindowSize        time.Duration `yaml:"window_size"`
}

// DNSServerConfig defines DNS server configuration.
type DNSServerConfig struct {
	Enabled       bool        `yaml:"enabled"`
	ListenPort    int         `yaml:"listen_port"`
	BackupServer  string      `yaml:"backup_server"`
	CustomRecords []DNSRecord `yaml:"custom_records"`
}

// DNSRecord defines a custom DNS record.
type DNSRecord struct {
	Name     string `yaml:"name"`
	Type     string `yaml:"type"` // A, AAAA, CNAME, MX, TXT, NS, SRV
	Value    string `yaml:"value"`
	TTL      uint32 `yaml:"ttl"`
	Priority uint16 `yaml:"priority,omitempty"` // For MX and SRV records
	Weight   uint16 `yaml:"weight,omitempty"`   // For SRV records
	Port     uint16 `yaml:"port,omitempty"`     // For SRV records
}

// LoadConfig loads and parses the configuration file.
//
// Files larger than 1MB are rejected before being read. After parsing,
// zero-valued timeout, connection, rate-limit, and DNS fields are replaced
// with defaults, and the result is checked with Validate. Underlying errors
// are wrapped, so callers can inspect them with errors.Is / errors.As.
func LoadConfig(filename string) (*Config, error) {
	const maxConfigSize = 1024 * 1024 // 1MB limit

	fileInfo, err := os.Stat(filename)
	if err != nil {
		return nil, fmt.Errorf("failed to stat config file: %w", err)
	}
	if fileInfo.Size() > maxConfigSize {
		return nil, fmt.Errorf("config file too large: %d bytes (max 1MB)", fileInfo.Size())
	}

	data, err := os.ReadFile(filename)
	if err != nil {
		return nil, fmt.Errorf("failed to read config file: %w", err)
	}

	var config Config
	if err := yaml.Unmarshal(data, &config); err != nil {
		return nil, fmt.Errorf("failed to parse config file: %w", err)
	}

	// Fill in defaults for anything the file left at its zero value.
	if config.ReadTimeout == 0 {
		config.ReadTimeout = 30 * time.Second
	}
	if config.WriteTimeout == 0 {
		config.WriteTimeout = 30 * time.Second
	}
	if config.MaxConnections == 0 {
		config.MaxConnections = 1000
	}
	if config.RateLimit.RequestsPerSecond == 0 {
		config.RateLimit.RequestsPerSecond = 100
	}
	if config.RateLimit.BurstSize == 0 {
		config.RateLimit.BurstSize = 200
	}
	if config.RateLimit.WindowSize == 0 {
		config.RateLimit.WindowSize = 1 * time.Second
	}
	if config.DNSServer.ListenPort == 0 {
		config.DNSServer.ListenPort = 5353
	}
	if config.DNSServer.BackupServer == "" {
		config.DNSServer.BackupServer = "8.8.8.8:53"
	}

	if err := config.Validate(); err != nil {
		return nil, fmt.Errorf("configuration validation failed: %w", err)
	}

	return &config, nil
}

// DetectMode determines if this should run as server or client based on
// config.
//
// It returns "server" when only ListenAddress is set, "client" when only
// RemoteAddress is set, and an error when both or neither are set.
func DetectMode(config *Config) (string, error) {
	hasListen := config.ListenAddress != ""
	hasRemote := config.RemoteAddress != ""

	switch {
	case hasListen && hasRemote:
		return "", fmt.Errorf("configuration error: cannot have both 'listen_address' and 'remote_address' set. Choose one:\n" +
			"- Server: set 'listen_address' and leave 'remote_address' empty\n" +
			"- Client: set 'remote_address' and leave 'listen_address' empty")
	case hasListen:
		return "server", nil
	case hasRemote:
		return "client", nil
	default:
		return "", fmt.Errorf("configuration error: must specify either 'listen_address' (for server) or 'remote_address' (for client)")
	}
}

// Validate validates the configuration.
//
// It checks, in order: the instance ID, the encryption key, the port rules,
// the DNS server section (only when enabled), the rate-limit section (only
// when enabled), the connection limit, and the timeouts. The first failing
// check is returned.
func (c *Config) Validate() error {
	if c.InstanceID == "" {
		return fmt.Errorf("instance_id is required")
	}

	if err := encryption.ValidateEncryptionKey(c.EncryptionKey); err != nil {
		return fmt.Errorf("invalid encryption key: %w", err)
	}

	if len(c.Ports) == 0 {
		return fmt.Errorf("at least one port rule is required")
	}
	for i := range c.Ports {
		if err := c.Ports[i].Validate(); err != nil {
			return fmt.Errorf("port rule %d is invalid: %w", i, err)
		}
	}

	if c.DNSServer.Enabled {
		if err := c.DNSServer.Validate(); err != nil {
			return fmt.Errorf("DNS server configuration is invalid: %w", err)
		}
	}

	if c.RateLimit.Enabled {
		rl := c.RateLimit
		if rl.RequestsPerSecond <= 0 {
			return fmt.Errorf("rate limit requests_per_second must be positive")
		}
		if rl.BurstSize <= 0 {
			return fmt.Errorf("rate limit burst_size must be positive")
		}
		if rl.WindowSize <= 0 {
			return fmt.Errorf("rate limit window_size must be positive")
		}
		if rl.BurstSize < rl.RequestsPerSecond {
			return fmt.Errorf("rate limit burst_size should be at least requests_per_second")
		}
		if rl.RequestsPerSecond > 10000 {
			return fmt.Errorf("rate limit requests_per_second too high: %d (maximum 10000)", rl.RequestsPerSecond)
		}
		if rl.BurstSize > 50000 {
			return fmt.Errorf("rate limit burst_size too high: %d (maximum 50000)", rl.BurstSize)
		}
	}

	if c.MaxConnections < 0 {
		return fmt.Errorf("max_connections cannot be negative")
	}
	if c.MaxConnections > 100000 {
		return fmt.Errorf("max_connections too high: %d (maximum 100000)", c.MaxConnections)
	}

	if c.ReadTimeout < 0 {
		return fmt.Errorf("read_timeout cannot be negative")
	}
	if c.WriteTimeout < 0 {
		return fmt.Errorf("write_timeout cannot be negative")
	}
	if c.ReadTimeout > 24*time.Hour {
		return fmt.Errorf("read_timeout too high: %v (maximum 24 hours)", c.ReadTimeout)
	}
	if c.WriteTimeout > 24*time.Hour {
		return fmt.Errorf("write_timeout too high: %v (maximum 24 hours)", c.WriteTimeout)
	}

	return nil
}

// Validate validates a port rule: both ports must be in 1-65535 and the
// protocol must be "tcp" or "udp". TargetHost is not checked here.
func (p *PortRule) Validate() error {
	if p.LocalPort < 1 || p.LocalPort > 65535 {
		return fmt.Errorf("local_port must be between 1 and 65535, got %d", p.LocalPort)
	}
	if p.RemotePort < 1 || p.RemotePort > 65535 {
		return fmt.Errorf("remote_port must be between 1 and 65535, got %d", p.RemotePort)
	}
	if p.Protocol != "tcp" && p.Protocol != "udp" {
		return fmt.Errorf("protocol must be 'tcp' or 'udp', got '%s'", p.Protocol)
	}
	return nil
}

// Validate validates DNS server configuration: the listen port, the
// "host:port" backup server, and every custom record.
func (d *DNSServerConfig) Validate() error {
	if d.ListenPort < 1 || d.ListenPort > 65535 {
		return fmt.Errorf("DNS listen_port must be between 1 and 65535, got %d", d.ListenPort)
	}

	if d.BackupServer == "" {
		return fmt.Errorf("DNS backup_server is required when DNS server is enabled")
	}

	// NOTE(review): splitting on ":" requires exactly one colon, so IPv6
	// literals such as "[::1]:53" are rejected. Confirm whether that is
	// intended before relaxing it.
	parts := strings.Split(d.BackupServer, ":")
	if len(parts) != 2 {
		return fmt.Errorf("DNS backup_server must be in format 'host:port', got '%s'", d.BackupServer)
	}

	port, err := strconv.Atoi(parts[1])
	if err != nil || port < 1 || port > 65535 {
		return fmt.Errorf("DNS backup_server port must be between 1 and 65535, got '%s'", parts[1])
	}

	for i := range d.CustomRecords {
		if err := d.CustomRecords[i].Validate(); err != nil {
			return fmt.Errorf("DNS record %d is invalid: %w", i, err)
		}
	}

	return nil
}

// Validate validates a DNS record: name and value must be non-empty printable
// ASCII within their length limits, the type must be one of the supported
// record types (case-insensitive), and the TTL must be in 1-86400 seconds.
func (r *DNSRecord) Validate() error {
	if r.Name == "" {
		return fmt.Errorf("DNS record name is required")
	}
	if len(r.Name) > 253 {
		return fmt.Errorf("DNS record name too long")
	}
	for _, c := range r.Name {
		if c < 32 || c > 126 {
			return fmt.Errorf("invalid character in DNS record name")
		}
	}

	validTypes := []string{"A", "AAAA", "CNAME", "MX", "TXT", "SRV", "NS"}
	recordType := strings.ToUpper(r.Type) // normalise once, not per iteration
	valid := false
	for _, t := range validTypes {
		if recordType == t {
			valid = true
			break
		}
	}
	if !valid {
		return fmt.Errorf("DNS record type must be one of %v, got '%s'", validTypes, r.Type)
	}

	if r.Value == "" {
		return fmt.Errorf("DNS record value is required")
	}
	if len(r.Value) > 1024 {
		return fmt.Errorf("DNS record value too long")
	}
	for _, c := range r.Value {
		if c < 32 || c > 126 {
			return fmt.Errorf("invalid character in DNS record value")
		}
	}

	if r.TTL == 0 {
		return fmt.Errorf("DNS record TTL must be greater than 0")
	}
	if r.TTL > 86400 { // 24 hours
		return fmt.Errorf("DNS record TTL too large (max 86400 seconds)")
	}

	// Priority, Weight, and Port are uint16, so their type already limits
	// them to 0-65535; no further range check is needed.
	return nil
}

// GenerateExampleConfig generates an example configuration file.
//
// A server example is produced when filename contains "server"; otherwise a
// client example is produced. The file embeds a freshly generated encryption
// key, so it is created with owner-only permissions (0600). If the file
// already exists, os.WriteFile truncates it without changing its mode.
func GenerateExampleConfig(filename string) error {
	strongKey, err := generateStrongEncryptionKey()
	if err != nil {
		return fmt.Errorf("failed to generate encryption key: %w", err)
	}

	var config Config
	if strings.Contains(filename, "server") {
		config = Config{
			InstanceID:    "teleport-server-01",
			ListenAddress: ":8080",
			RemoteAddress: "",
			Ports: []PortRule{
				{LocalPort: 80, RemotePort: 80, Protocol: "tcp", TargetHost: "localhost"},
			},
			EncryptionKey:  strongKey,
			KeepAlive:      true,
			ReadTimeout:    30 * time.Second,
			WriteTimeout:   30 * time.Second,
			MaxConnections: 1000,
			RateLimit: RateLimitConfig{
				Enabled:           true,
				RequestsPerSecond: 100,
				BurstSize:         200,
				WindowSize:        1 * time.Second,
			},
			DNSServer: DNSServerConfig{
				ListenPort:    5353,
				BackupServer:  "8.8.8.8:53",
				CustomRecords: []DNSRecord{},
			},
		}
	} else {
		config = Config{
			InstanceID:    "teleport-client-01",
			ListenAddress: "",
			RemoteAddress: "localhost:8080",
			Ports: []PortRule{
				{LocalPort: 8080, RemotePort: 80, Protocol: "tcp", TargetHost: ""},
			},
			EncryptionKey:  strongKey,
			KeepAlive:      true,
			ReadTimeout:    30 * time.Second,
			WriteTimeout:   30 * time.Second,
			MaxConnections: 100,
			RateLimit: RateLimitConfig{
				Enabled:           true,
				RequestsPerSecond: 50,
				BurstSize:         100,
				WindowSize:        1 * time.Second,
			},
			DNSServer: DNSServerConfig{
				ListenPort:   5353,
				BackupServer: "8.8.8.8:53",
				CustomRecords: []DNSRecord{
					{Name: "app.local", Type: "A", Value: "127.0.0.1", TTL: 300},
				},
			},
		}
	}

	data, err := yaml.Marshal(&config)
	if err != nil {
		return fmt.Errorf("failed to marshal config: %w", err)
	}

	// 0600: the file contains the encryption key and must not be readable
	// by other users.
	if err := os.WriteFile(filename, data, 0600); err != nil {
		return fmt.Errorf("failed to write config file: %w", err)
	}

	fmt.Printf("Generated example configuration: %s\n", filename)
	fmt.Printf("Edit the configuration file and run: ./teleport -config %s\n", filename)

	return nil
}

// generateStrongEncryptionKey generates a cryptographically secure encryption
// key: 32 bytes (256 bits) read from crypto/rand, returned as a 64-character
// hexadecimal string.
func generateStrongEncryptionKey() (string, error) {
	key := make([]byte, 32)
	if _, err := rand.Read(key); err != nil {
		return "", fmt.Errorf("failed to generate random key: %w", err)
	}
	return hex.EncodeToString(key), nil
}