Files
teleport/pkg/config/config.go
Justin Harms d24d1dc5ae Add initial project structure with core functionality
- Created a new Go module named 'teleport' for secure port forwarding.
- Added essential files including .gitignore, LICENSE, and README.md with project details.
- Implemented configuration management with YAML support in config package.
- Developed core client and server functionalities for handling port forwarding.
- Introduced DNS server capabilities and integrated logging with sanitization.
- Established rate limiting and metrics tracking for performance monitoring.
- Included comprehensive tests for core components and functionalities.
- Set up CI workflows for automated testing and release management using Gitea actions.
2025-09-20 18:07:08 -05:00

538 lines
16 KiB
Go

package config
import (
"crypto/rand"
"encoding/hex"
"fmt"
"os"
"strconv"
"strings"
"time"
"teleport/pkg/encryption"
"gopkg.in/yaml.v3"
)
// Config represents the teleport configuration as read from a YAML file.
//
// Exactly one of ListenAddress and RemoteAddress is expected to be set:
// DetectMode reports "server" when only ListenAddress is non-empty and
// "client" when only RemoteAddress is non-empty. LoadConfig fills in defaults
// for several zero-valued fields and then calls Validate.
type Config struct {
// InstanceID identifies this instance; Validate rejects an empty value.
InstanceID string `yaml:"instance_id"`
// ListenAddress, when non-empty, selects server mode in DetectMode.
ListenAddress string `yaml:"listen_address"`
// RemoteAddress, when non-empty, selects client mode in DetectMode.
RemoteAddress string `yaml:"remote_address"`
// Ports lists the forwarding rules; Validate requires at least one.
Ports []PortRule `yaml:"ports"`
// EncryptionKey is checked by encryption.ValidateEncryptionKey in Validate.
EncryptionKey string `yaml:"encryption_key"`
// NOTE(review): KeepAlive is not read anywhere in this file; presumably it
// toggles connection keep-alive in the client/server code — confirm there.
KeepAlive bool `yaml:"keep_alive"`
// ReadTimeout: LoadConfig replaces 0 with 30s; Validate rejects negative
// values and values above 24 hours.
ReadTimeout time.Duration `yaml:"read_timeout"`
// WriteTimeout: LoadConfig replaces 0 with 30s; Validate rejects negative
// values and values above 24 hours.
WriteTimeout time.Duration `yaml:"write_timeout"`
// MaxConnections: LoadConfig replaces 0 with 1000; Validate rejects negative
// values and values above 100000.
MaxConnections int `yaml:"max_connections"`
// RateLimit holds the rate limiting settings; Validate checks them only
// when RateLimit.Enabled is true.
RateLimit RateLimitConfig `yaml:"rate_limit"`
// DNSServer holds the DNS server settings; Validate checks them only when
// DNSServer.Enabled is true.
DNSServer DNSServerConfig `yaml:"dns_server"`
}
// PortRule defines a port forwarding rule.
//
// In YAML a rule is written as a single string and is converted by the
// UnmarshalYAML and MarshalYAML methods on this type; those methods do not
// consult the field tags below.
type PortRule struct {
// LocalPort is the port this side listens on (1-65535, see Validate).
LocalPort int `yaml:"local_port"`
// RemotePort is the port traffic is forwarded to (1-65535, see Validate).
RemotePort int `yaml:"remote_port"`
Protocol string `yaml:"protocol"` // "tcp" or "udp"; anything else is rejected by Validate
TargetHost string `yaml:"target_host,omitempty"` // Target host for server-side forwarding; empty for client rules. NOTE(review): the original comment says it defaults to localhost — that default is not applied in this file.
}
// UnmarshalYAML decodes a PortRule from a single YAML scalar.
//
// Two spellings are accepted:
//
//   - server: "protocol://host:port" — LocalPort and RemotePort are both set
//     to port and TargetHost is set to host;
//   - client: "protocol://remoteport:localport" — TargetHost is left empty.
//
// The text before the colon decides which spelling applies: if it parses as
// an integer the rule is read as a client rule, otherwise as a server rule.
// protocol must be "tcp" or "udp" and every port must lie in 1-65535. The
// receiver is only written to once the whole string has been accepted.
func (p *PortRule) UnmarshalYAML(value *yaml.Node) error {
	if value.Kind != yaml.ScalarNode {
		return fmt.Errorf("port must be a string in format 'protocol://target:port' or 'protocol://targetport:localport'")
	}

	raw := value.Value
	if len(raw) > 512 {
		return fmt.Errorf("port string too long")
	}
	if !strings.Contains(raw, "://") {
		return fmt.Errorf("invalid port format: %s (expected 'protocol://target:port' or 'protocol://targetport:localport')", raw)
	}

	// More than one "://" yields more than two pieces and is rejected.
	pieces := strings.Split(raw, "://")
	if len(pieces) != 2 {
		return fmt.Errorf("invalid port format: %s (expected 'protocol://target:port')", raw)
	}
	proto, addr := pieces[0], pieces[1]

	switch proto {
	case "tcp", "udp":
		// supported
	default:
		return fmt.Errorf("invalid protocol: %s (expected 'tcp' or 'udp')", proto)
	}

	if len(addr) > 256 {
		return fmt.Errorf("address part too long")
	}

	fields := strings.Split(addr, ":")
	if len(fields) != 2 {
		return fmt.Errorf("invalid address format: %s (expected 'target:port' for server or 'targetport:localport' for client)", addr)
	}
	first, second := fields[0], fields[1]

	// A purely numeric first field means the client spelling.
	remotePort, numErr := strconv.Atoi(first)
	if numErr == nil && !strings.Contains(first, ".") {
		localPort, err := strconv.Atoi(second)
		if err != nil {
			return fmt.Errorf("invalid local port: %s", second)
		}
		if remotePort < 1 || remotePort > 65535 {
			return fmt.Errorf("invalid target port: %d (must be 1-65535)", remotePort)
		}
		if localPort < 1 || localPort > 65535 {
			return fmt.Errorf("invalid local port: %d (must be 1-65535)", localPort)
		}

		p.LocalPort = localPort   // client listens here
		p.RemotePort = remotePort // port requested on the teleport server
		p.Protocol = proto
		p.TargetHost = "" // client rules carry no target host
		return nil
	}

	// Otherwise the first field is a host name or IP: server spelling.
	host := first
	if len(host) > 253 { // RFC 1123 limit
		return fmt.Errorf("target host too long")
	}
	// Only printable ASCII is allowed; any byte of a multi-byte rune is
	// above 126, so a byte-wise scan rejects the same inputs as a rune scan.
	for i := 0; i < len(host); i++ {
		if b := host[i]; b < 32 || b > 126 {
			return fmt.Errorf("invalid character in target host")
		}
	}

	port, err := strconv.Atoi(second)
	if err != nil {
		return fmt.Errorf("invalid target port: %s", second)
	}
	if port < 1 || port > 65535 {
		return fmt.Errorf("invalid target port: %d (must be 1-65535)", port)
	}

	p.LocalPort = port  // server listens on this port
	p.RemotePort = port // and forwards to the same port on the target
	p.Protocol = proto
	p.TargetHost = host
	return nil
}
// MarshalYAML encodes the rule as the single-string form understood by
// UnmarshalYAML.
//
// A rule with a TargetHost is written in the server spelling
// "protocol://host:remoteport" (LocalPort is not emitted); a rule without one
// is written in the client spelling "protocol://remoteport:localport".
// The returned error is always nil.
func (p PortRule) MarshalYAML() (interface{}, error) {
	if p.TargetHost != "" {
		// Server spelling.
		return fmt.Sprintf("%s://%s:%d", p.Protocol, p.TargetHost, p.RemotePort), nil
	}
	// Client spelling.
	return fmt.Sprintf("%s://%d:%d", p.Protocol, p.RemotePort, p.LocalPort), nil
}
// RateLimitConfig defines rate limiting configuration.
//
// LoadConfig substitutes defaults for zero values (100 requests per second,
// burst of 200, one-second window). Config.Validate checks the numeric
// fields only when Enabled is true.
type RateLimitConfig struct {
// Enabled turns validation of the fields below on in Config.Validate.
Enabled bool `yaml:"enabled"`
// RequestsPerSecond must be positive and at most 10000 when Enabled.
RequestsPerSecond int `yaml:"requests_per_second"`
// BurstSize must be positive, at least RequestsPerSecond and at most 50000
// when Enabled.
BurstSize int `yaml:"burst_size"`
// WindowSize must be positive when Enabled.
WindowSize time.Duration `yaml:"window_size"`
}
// DNSServerConfig defines DNS server configuration.
//
// LoadConfig substitutes defaults for zero values (port 5353 and backup
// server "8.8.8.8:53"). Config.Validate calls Validate on this struct only
// when Enabled is true.
type DNSServerConfig struct {
// Enabled turns validation of this section on in Config.Validate.
Enabled bool `yaml:"enabled"`
// ListenPort must be within 1-65535 (see Validate).
ListenPort int `yaml:"listen_port"`
// BackupServer is a "host:port" string (see Validate).
// NOTE(review): presumably the upstream resolver used for names without a
// custom record — confirm against the DNS server code.
BackupServer string `yaml:"backup_server"`
// CustomRecords are validated one by one with DNSRecord.Validate.
CustomRecords []DNSRecord `yaml:"custom_records"`
}
// DNSRecord defines a custom DNS record.
//
// Validate enforces the limits noted on each field.
type DNSRecord struct {
// Name is required, at most 253 bytes, printable ASCII only.
Name string `yaml:"name"`
Type string `yaml:"type"` // A, AAAA, CNAME, MX, TXT, NS, SRV (matched case-insensitively)
// Value is required, at most 1024 bytes, printable ASCII only.
Value string `yaml:"value"`
// TTL is in seconds and must be within 1-86400.
TTL uint32 `yaml:"ttl"`
Priority uint16 `yaml:"priority,omitempty"` // For MX and SRV records
Weight uint16 `yaml:"weight,omitempty"` // For SRV records
Port uint16 `yaml:"port,omitempty"` // For SRV records
}
// LoadConfig reads the YAML configuration at filename, fills in defaults for
// unset fields and validates the result.
//
// Files larger than 1MB are rejected. The returned error wraps the underlying
// cause, so callers can test it with errors.Is / errors.As (for example
// errors.Is(err, os.ErrNotExist) for a missing file). On error the returned
// *Config is nil.
func LoadConfig(filename string) (*Config, error) {
	const maxConfigSize = 1024 * 1024 // 1MB limit

	// Cheap pre-check so an obviously oversized file is never read.
	info, err := os.Stat(filename)
	if err != nil {
		return nil, fmt.Errorf("failed to stat config file: %w", err)
	}
	if info.Size() > maxConfigSize {
		return nil, fmt.Errorf("config file too large: %d bytes (max 1MB)", info.Size())
	}

	data, err := os.ReadFile(filename)
	if err != nil {
		return nil, fmt.Errorf("failed to read config file: %w", err)
	}
	// The stat result can under-report (the file may have grown since, or may
	// not be a regular file), so enforce the limit on what was actually read.
	if len(data) > maxConfigSize {
		return nil, fmt.Errorf("config file too large: %d bytes (max 1MB)", len(data))
	}

	var cfg Config
	if err := yaml.Unmarshal(data, &cfg); err != nil {
		return nil, fmt.Errorf("failed to parse config file: %w", err)
	}

	// Defaults are applied before validation; a zero value means "unset".
	if cfg.ReadTimeout == 0 {
		cfg.ReadTimeout = 30 * time.Second
	}
	if cfg.WriteTimeout == 0 {
		cfg.WriteTimeout = 30 * time.Second
	}
	if cfg.MaxConnections == 0 {
		cfg.MaxConnections = 1000
	}
	if cfg.RateLimit.RequestsPerSecond == 0 {
		cfg.RateLimit.RequestsPerSecond = 100
	}
	if cfg.RateLimit.BurstSize == 0 {
		cfg.RateLimit.BurstSize = 200
		// Validate requires burst_size >= requests_per_second. Without this
		// adjustment a config that only raises requests_per_second would be
		// rejected because of a burst size the user never wrote.
		if cfg.RateLimit.RequestsPerSecond > cfg.RateLimit.BurstSize {
			cfg.RateLimit.BurstSize = cfg.RateLimit.RequestsPerSecond
		}
	}
	if cfg.RateLimit.WindowSize == 0 {
		cfg.RateLimit.WindowSize = 1 * time.Second
	}
	if cfg.DNSServer.ListenPort == 0 {
		cfg.DNSServer.ListenPort = 5353
	}
	if cfg.DNSServer.BackupServer == "" {
		cfg.DNSServer.BackupServer = "8.8.8.8:53"
	}

	if err := cfg.Validate(); err != nil {
		return nil, fmt.Errorf("configuration validation failed: %w", err)
	}
	return &cfg, nil
}
// DetectMode determines whether this instance should run as a server or as a
// client.
//
// It returns "server" when only listen_address is set and "client" when only
// remote_address is set. Setting both, setting neither, or passing a nil
// config is reported as an error with an empty mode string.
func DetectMode(config *Config) (string, error) {
	if config == nil {
		return "", fmt.Errorf("configuration error: config is nil")
	}

	hasListen := config.ListenAddress != ""
	hasRemote := config.RemoteAddress != ""

	// The four combinations are exhaustive, so no fallback return is needed.
	switch {
	case hasListen && hasRemote:
		return "", fmt.Errorf("configuration error: cannot have both 'listen_address' and 'remote_address' set. Choose one:\n" +
			"- Server: set 'listen_address' and leave 'remote_address' empty\n" +
			"- Client: set 'remote_address' and leave 'listen_address' empty")
	case hasListen:
		return "server", nil
	case hasRemote:
		return "client", nil
	default:
		return "", fmt.Errorf("configuration error: must specify either 'listen_address' (for server) or 'remote_address' (for client)")
	}
}
// Validate checks the configuration and returns the first problem found.
//
// Checks run in this order: instance ID, encryption key, port rules, DNS
// server section (only when enabled), rate limit section (only when
// enabled), connection limit, then timeouts.
func (c *Config) Validate() error {
	if c.InstanceID == "" {
		return fmt.Errorf("instance_id is required")
	}

	if keyErr := encryption.ValidateEncryptionKey(c.EncryptionKey); keyErr != nil {
		return fmt.Errorf("invalid encryption key: %v", keyErr)
	}

	if len(c.Ports) == 0 {
		return fmt.Errorf("at least one port rule is required")
	}
	for idx := range c.Ports {
		if ruleErr := c.Ports[idx].Validate(); ruleErr != nil {
			return fmt.Errorf("port rule %d is invalid: %v", idx, ruleErr)
		}
	}

	// The DNS section is only inspected when the DNS server is switched on.
	if dns := &c.DNSServer; dns.Enabled {
		if dnsErr := dns.Validate(); dnsErr != nil {
			return fmt.Errorf("DNS server configuration is invalid: %v", dnsErr)
		}
	}

	// Likewise, rate limit numbers only matter when rate limiting is on.
	if rl := c.RateLimit; rl.Enabled {
		switch {
		case rl.RequestsPerSecond <= 0:
			return fmt.Errorf("rate limit requests_per_second must be positive")
		case rl.BurstSize <= 0:
			return fmt.Errorf("rate limit burst_size must be positive")
		case rl.WindowSize <= 0:
			return fmt.Errorf("rate limit window_size must be positive")
		case rl.BurstSize < rl.RequestsPerSecond:
			return fmt.Errorf("rate limit burst_size should be at least requests_per_second")
		case rl.RequestsPerSecond > 10000:
			return fmt.Errorf("rate limit requests_per_second too high: %d (maximum 10000)", rl.RequestsPerSecond)
		case rl.BurstSize > 50000:
			return fmt.Errorf("rate limit burst_size too high: %d (maximum 50000)", rl.BurstSize)
		}
	}

	// Connection limit and timeouts.
	const maxTimeout = 24 * time.Hour
	switch {
	case c.MaxConnections < 0:
		return fmt.Errorf("max_connections cannot be negative")
	case c.MaxConnections > 100000:
		return fmt.Errorf("max_connections too high: %d (maximum 100000)", c.MaxConnections)
	case c.ReadTimeout < 0:
		return fmt.Errorf("read_timeout cannot be negative")
	case c.WriteTimeout < 0:
		return fmt.Errorf("write_timeout cannot be negative")
	case c.ReadTimeout > maxTimeout:
		return fmt.Errorf("read_timeout too high: %v (maximum 24 hours)", c.ReadTimeout)
	case c.WriteTimeout > maxTimeout:
		return fmt.Errorf("write_timeout too high: %v (maximum 24 hours)", c.WriteTimeout)
	}
	return nil
}
// Validate reports the first invalid field of a port rule: LocalPort and
// RemotePort must be within 1-65535 and Protocol must be "tcp" or "udp".
// TargetHost is not checked here.
func (p *PortRule) Validate() error {
	const lowest, highest = 1, 65535

	switch {
	case p.LocalPort < lowest || p.LocalPort > highest:
		return fmt.Errorf("local_port must be between 1 and 65535, got %d", p.LocalPort)
	case p.RemotePort < lowest || p.RemotePort > highest:
		return fmt.Errorf("remote_port must be between 1 and 65535, got %d", p.RemotePort)
	}

	switch p.Protocol {
	case "tcp", "udp":
		return nil
	default:
		return fmt.Errorf("protocol must be 'tcp' or 'udp', got '%s'", p.Protocol)
	}
}
// Validate checks the DNS server section and returns the first problem found.
//
// ListenPort must be within 1-65535. BackupServer must be "host:port" with a
// port within 1-65535; an IPv6 host has to be written in brackets, e.g.
// "[2001:4860:4860::8888]:53". Every custom record is then checked with
// DNSRecord.Validate.
func (d *DNSServerConfig) Validate() error {
	if d.ListenPort < 1 || d.ListenPort > 65535 {
		return fmt.Errorf("DNS listen_port must be between 1 and 65535, got %d", d.ListenPort)
	}
	if d.BackupServer == "" {
		return fmt.Errorf("DNS backup_server is required when DNS server is enabled")
	}

	// Split on the LAST colon so that the colons inside a bracketed IPv6
	// literal stay with the host part.
	sep := strings.LastIndex(d.BackupServer, ":")
	if sep < 0 {
		return fmt.Errorf("DNS backup_server must be in format 'host:port', got '%s'", d.BackupServer)
	}
	host, portText := d.BackupServer[:sep], d.BackupServer[sep+1:]

	// A host that still contains colons is only unambiguous when bracketed;
	// "a:b:53" and bare "::1" remain invalid, as before.
	bracketed := strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]")
	if strings.Contains(host, ":") && !bracketed {
		return fmt.Errorf("DNS backup_server must be in format 'host:port', got '%s'", d.BackupServer)
	}

	port, err := strconv.Atoi(portText)
	if err != nil || port < 1 || port > 65535 {
		return fmt.Errorf("DNS backup_server port must be between 1 and 65535, got '%s'", portText)
	}

	for i := range d.CustomRecords {
		if err := d.CustomRecords[i].Validate(); err != nil {
			return fmt.Errorf("DNS record %d is invalid: %v", i, err)
		}
	}
	return nil
}
// Validate checks a custom DNS record and returns the first problem found.
//
// Name and Value are required and may only contain printable ASCII (bytes
// 32-126); Name is limited to 253 bytes and Value to 1024. Type is compared
// case-insensitively against A, AAAA, CNAME, MX, TXT, SRV and NS. TTL must be
// within 1-86400 seconds. Priority, Weight and Port are not checked: being
// uint16 they cannot exceed 65535.
func (r *DNSRecord) Validate() error {
	// printable reports whether s consists solely of bytes 32-126. Every
	// byte of a non-ASCII rune is above 126, so scanning bytes rejects the
	// same strings as scanning runes.
	printable := func(s string) bool {
		for i := 0; i < len(s); i++ {
			if s[i] < 32 || s[i] > 126 {
				return false
			}
		}
		return true
	}

	switch {
	case r.Name == "":
		return fmt.Errorf("DNS record name is required")
	case len(r.Name) > 253:
		return fmt.Errorf("DNS record name too long")
	case !printable(r.Name):
		return fmt.Errorf("invalid character in DNS record name")
	}

	switch strings.ToUpper(r.Type) {
	case "A", "AAAA", "CNAME", "MX", "TXT", "SRV", "NS":
		// supported record type
	default:
		validTypes := []string{"A", "AAAA", "CNAME", "MX", "TXT", "SRV", "NS"}
		return fmt.Errorf("DNS record type must be one of %v, got '%s'", validTypes, r.Type)
	}

	switch {
	case r.Value == "":
		return fmt.Errorf("DNS record value is required")
	case len(r.Value) > 1024:
		return fmt.Errorf("DNS record value too long")
	case !printable(r.Value):
		return fmt.Errorf("invalid character in DNS record value")
	}

	switch {
	case r.TTL == 0:
		return fmt.Errorf("DNS record TTL must be greater than 0")
	case r.TTL > 86400: // 24 hours
		return fmt.Errorf("DNS record TTL too large (max 86400 seconds)")
	}
	return nil
}
// GenerateExampleConfig writes an example configuration to filename and
// prints a short hint on how to use it.
//
// A server example is produced when filename contains the substring "server"
// (anywhere in the path); otherwise a client example is produced. Each call
// embeds a freshly generated random encryption key, so the file is created
// with owner-only permissions. Returned errors wrap the underlying cause.
func GenerateExampleConfig(filename string) error {
	strongKey, err := generateStrongEncryptionKey()
	if err != nil {
		return fmt.Errorf("failed to generate encryption key: %w", err)
	}

	var config Config
	if strings.Contains(filename, "server") {
		// Server example: listens on :8080 and forwards port 80 to localhost.
		config = Config{
			InstanceID:    "teleport-server-01",
			ListenAddress: ":8080",
			RemoteAddress: "",
			Ports: []PortRule{
				{LocalPort: 80, RemotePort: 80, Protocol: "tcp", TargetHost: "localhost"},
			},
			EncryptionKey:  strongKey,
			KeepAlive:      true,
			ReadTimeout:    30 * time.Second,
			WriteTimeout:   30 * time.Second,
			MaxConnections: 1000,
			RateLimit: RateLimitConfig{
				Enabled:           true,
				RequestsPerSecond: 100,
				BurstSize:         200,
				WindowSize:        1 * time.Second,
			},
			DNSServer: DNSServerConfig{
				ListenPort:    5353,
				BackupServer:  "8.8.8.8:53",
				CustomRecords: []DNSRecord{},
			},
		}
	} else {
		// Client example: connects to localhost:8080 and exposes remote port
		// 80 on local port 8080.
		config = Config{
			InstanceID:    "teleport-client-01",
			ListenAddress: "",
			RemoteAddress: "localhost:8080",
			Ports: []PortRule{
				{LocalPort: 8080, RemotePort: 80, Protocol: "tcp", TargetHost: ""},
			},
			EncryptionKey:  strongKey,
			KeepAlive:      true,
			ReadTimeout:    30 * time.Second,
			WriteTimeout:   30 * time.Second,
			MaxConnections: 100,
			RateLimit: RateLimitConfig{
				Enabled:           true,
				RequestsPerSecond: 50,
				BurstSize:         100,
				WindowSize:        1 * time.Second,
			},
			DNSServer: DNSServerConfig{
				ListenPort:   5353,
				BackupServer: "8.8.8.8:53",
				CustomRecords: []DNSRecord{
					{Name: "app.local", Type: "A", Value: "127.0.0.1", TTL: 300},
				},
			},
		}
	}

	data, err := yaml.Marshal(&config)
	if err != nil {
		return fmt.Errorf("failed to marshal config: %w", err)
	}

	// The file contains the encryption key in clear text, so it must not be
	// readable by other users. The mode only takes effect when the file is
	// created; an already existing file keeps its current permissions.
	if err := os.WriteFile(filename, data, 0600); err != nil {
		return fmt.Errorf("failed to write config file: %w", err)
	}

	fmt.Printf("Generated example configuration: %s\n", filename)
	fmt.Printf("Edit the configuration file and run: ./teleport -config %s\n", filename)
	return nil
}
// generateStrongEncryptionKey returns 32 bytes (256 bits) read from
// crypto/rand, encoded as a 64-character lowercase hexadecimal string.
func generateStrongEncryptionKey() (string, error) {
	const keyLen = 32 // 256 bits

	raw := make([]byte, keyLen)
	_, err := rand.Read(raw)
	if err != nil {
		return "", fmt.Errorf("failed to generate random key: %v", err)
	}

	// Hex keeps the key easy to copy and paste.
	encoded := make([]byte, hex.EncodedLen(len(raw)))
	hex.Encode(encoded, raw)
	return string(encoded), nil
}