- Created a new Go module named 'teleport' for secure port forwarding.
- Added essential files including .gitignore, LICENSE, and README.md with project details.
- Implemented configuration management with YAML support in the config package.
- Developed core client and server functionality for handling port forwarding.
- Introduced DNS server capabilities and integrated logging with sanitization.
- Established rate limiting and metrics tracking for performance monitoring.
- Included comprehensive tests for core components and functionality.
- Set up CI workflows for automated testing and release management using Gitea Actions.
538 lines
16 KiB
Go
538 lines
16 KiB
Go
package config
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"teleport/pkg/encryption"
|
|
|
|
"gopkg.in/yaml.v3"
|
|
)
|
|
|
|
// Config represents the teleport configuration
|
|
type Config struct {
|
|
InstanceID string `yaml:"instance_id"`
|
|
ListenAddress string `yaml:"listen_address"`
|
|
RemoteAddress string `yaml:"remote_address"`
|
|
Ports []PortRule `yaml:"ports"`
|
|
EncryptionKey string `yaml:"encryption_key"`
|
|
KeepAlive bool `yaml:"keep_alive"`
|
|
ReadTimeout time.Duration `yaml:"read_timeout"`
|
|
WriteTimeout time.Duration `yaml:"write_timeout"`
|
|
MaxConnections int `yaml:"max_connections"`
|
|
RateLimit RateLimitConfig `yaml:"rate_limit"`
|
|
DNSServer DNSServerConfig `yaml:"dns_server"`
|
|
}
|
|
|
|
// PortRule describes one port forwarding rule.
//
// In YAML a rule is written as a single URL-style string rather than a
// mapping; see UnmarshalYAML and MarshalYAML for the accepted spellings.
type PortRule struct {
	// LocalPort is the port this side listens on.
	LocalPort int `yaml:"local_port"`
	// RemotePort is the port traffic is forwarded to.
	RemotePort int `yaml:"remote_port"`
	// Protocol is either "tcp" or "udp".
	Protocol string `yaml:"protocol"`
	// TargetHost is the target host for server-side forwarding. It is empty
	// for client-style rules.
	// NOTE(review): consumers presumably fall back to localhost when this is
	// empty; that default is not applied anywhere in this file.
	TargetHost string `yaml:"target_host,omitempty"`
}
|
|
|
|
// UnmarshalYAML implements custom YAML unmarshaling for PortRule
|
|
func (p *PortRule) UnmarshalYAML(value *yaml.Node) error {
|
|
// Only support URL-style format: "protocol://target:targetport" (server) or "protocol://targetport:localport" (client)
|
|
if value.Kind != yaml.ScalarNode {
|
|
return fmt.Errorf("port must be a string in format 'protocol://target:port' or 'protocol://targetport:localport'")
|
|
}
|
|
|
|
portStr := value.Value
|
|
|
|
// Validate input length
|
|
if len(portStr) > 512 {
|
|
return fmt.Errorf("port string too long")
|
|
}
|
|
|
|
// Check if it starts with protocol://
|
|
if !strings.Contains(portStr, "://") {
|
|
return fmt.Errorf("invalid port format: %s (expected 'protocol://target:port' or 'protocol://targetport:localport')", portStr)
|
|
}
|
|
|
|
parts := strings.Split(portStr, "://")
|
|
if len(parts) != 2 {
|
|
return fmt.Errorf("invalid port format: %s (expected 'protocol://target:port')", portStr)
|
|
}
|
|
|
|
protocol := parts[0]
|
|
if protocol != "tcp" && protocol != "udp" {
|
|
return fmt.Errorf("invalid protocol: %s (expected 'tcp' or 'udp')", protocol)
|
|
}
|
|
|
|
addressPart := parts[1]
|
|
|
|
// Validate address part length
|
|
if len(addressPart) > 256 {
|
|
return fmt.Errorf("address part too long")
|
|
}
|
|
|
|
addressParts := strings.Split(addressPart, ":")
|
|
|
|
if len(addressParts) == 2 {
|
|
// Check if first part is a hostname/IP (contains dots or is not a number) - server format
|
|
if _, err := strconv.Atoi(addressParts[0]); err != nil || strings.Contains(addressParts[0], ".") {
|
|
// Server format: protocol://target:targetport
|
|
targetHost := addressParts[0]
|
|
|
|
// Validate target host
|
|
if len(targetHost) > 253 { // RFC 1123 limit
|
|
return fmt.Errorf("target host too long")
|
|
}
|
|
|
|
// Basic character validation for hostname
|
|
for _, c := range targetHost {
|
|
if c < 32 || c > 126 {
|
|
return fmt.Errorf("invalid character in target host")
|
|
}
|
|
}
|
|
|
|
targetPort, err := strconv.Atoi(addressParts[1])
|
|
if err != nil {
|
|
return fmt.Errorf("invalid target port: %s", addressParts[1])
|
|
}
|
|
|
|
// Validate port range
|
|
if targetPort < 1 || targetPort > 65535 {
|
|
return fmt.Errorf("invalid target port: %d (must be 1-65535)", targetPort)
|
|
}
|
|
|
|
p.LocalPort = targetPort // Server listens on this port
|
|
p.RemotePort = targetPort // Server forwards to this port on target
|
|
p.Protocol = protocol
|
|
p.TargetHost = targetHost
|
|
return nil
|
|
} else {
|
|
// Client format: protocol://targetport:localport (first part is a number)
|
|
targetPort, err := strconv.Atoi(addressParts[0])
|
|
if err != nil {
|
|
return fmt.Errorf("invalid target port: %s", addressParts[0])
|
|
}
|
|
localPort, err := strconv.Atoi(addressParts[1])
|
|
if err != nil {
|
|
return fmt.Errorf("invalid local port: %s", addressParts[1])
|
|
}
|
|
|
|
// Validate port ranges
|
|
if targetPort < 1 || targetPort > 65535 {
|
|
return fmt.Errorf("invalid target port: %d (must be 1-65535)", targetPort)
|
|
}
|
|
if localPort < 1 || localPort > 65535 {
|
|
return fmt.Errorf("invalid local port: %d (must be 1-65535)", localPort)
|
|
}
|
|
|
|
p.LocalPort = localPort // Client listens on this port
|
|
p.RemotePort = targetPort // Client forwards to this port on teleport server
|
|
p.Protocol = protocol
|
|
p.TargetHost = "" // Client doesn't specify target host
|
|
return nil
|
|
}
|
|
} else {
|
|
return fmt.Errorf("invalid address format: %s (expected 'target:port' for server or 'targetport:localport' for client)", addressPart)
|
|
}
|
|
}
|
|
|
|
// MarshalYAML implements custom YAML marshaling for PortRule
|
|
func (p PortRule) MarshalYAML() (interface{}, error) {
|
|
// Use new format: protocol://target:targetport (server) or protocol://targetport:localport (client)
|
|
if p.TargetHost == "" {
|
|
// Client format: protocol://targetport:localport
|
|
return fmt.Sprintf("%s://%d:%d", p.Protocol, p.RemotePort, p.LocalPort), nil
|
|
} else {
|
|
// Server format: protocol://target:targetport
|
|
return fmt.Sprintf("%s://%s:%d", p.Protocol, p.TargetHost, p.RemotePort), nil
|
|
}
|
|
}
|
|
|
|
// RateLimitConfig holds the rate limiting settings.
//
// LoadConfig fills zero values with defaults (100 requests per second, burst
// of 200, one-second window). Config.Validate checks the numeric limits only
// when Enabled is true.
type RateLimitConfig struct {
	// Enabled turns validation of the fields below on in Config.Validate.
	Enabled bool `yaml:"enabled"`
	// RequestsPerSecond must be positive and at most 10000 when enabled.
	RequestsPerSecond int `yaml:"requests_per_second"`
	// BurstSize must be positive, at least RequestsPerSecond and at most
	// 50000 when enabled.
	BurstSize int `yaml:"burst_size"`
	// WindowSize must be positive when enabled.
	WindowSize time.Duration `yaml:"window_size"`
}
|
|
|
|
// DNSServerConfig defines DNS server configuration
|
|
type DNSServerConfig struct {
|
|
Enabled bool `yaml:"enabled"`
|
|
ListenPort int `yaml:"listen_port"`
|
|
BackupServer string `yaml:"backup_server"`
|
|
CustomRecords []DNSRecord `yaml:"custom_records"`
|
|
}
|
|
|
|
// DNSRecord is one custom DNS record.
type DNSRecord struct {
	// Name is required, at most 253 bytes, printable ASCII only.
	Name string `yaml:"name"`
	// Type is one of A, AAAA, CNAME, MX, TXT, NS, SRV (matched
	// case-insensitively by Validate).
	Type string `yaml:"type"`
	// Value is required, at most 1024 bytes, printable ASCII only.
	Value string `yaml:"value"`
	// TTL is in seconds and must be within 1-86400.
	TTL uint32 `yaml:"ttl"`
	// Priority is used for MX and SRV records.
	Priority uint16 `yaml:"priority,omitempty"`
	// Weight is used for SRV records.
	Weight uint16 `yaml:"weight,omitempty"`
	// Port is used for SRV records.
	Port uint16 `yaml:"port,omitempty"`
}
|
|
|
|
// LoadConfig loads and parses the configuration file
|
|
func LoadConfig(filename string) (*Config, error) {
|
|
// Check file size before reading (max 1MB)
|
|
fileInfo, err := os.Stat(filename)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to stat config file: %v", err)
|
|
}
|
|
if fileInfo.Size() > 1024*1024 { // 1MB limit
|
|
return nil, fmt.Errorf("config file too large: %d bytes (max 1MB)", fileInfo.Size())
|
|
}
|
|
|
|
data, err := os.ReadFile(filename)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read config file: %v", err)
|
|
}
|
|
|
|
var config Config
|
|
if err := yaml.Unmarshal(data, &config); err != nil {
|
|
return nil, fmt.Errorf("failed to parse config file: %v", err)
|
|
}
|
|
|
|
// Set default values
|
|
if config.ReadTimeout == 0 {
|
|
config.ReadTimeout = 30 * time.Second
|
|
}
|
|
if config.WriteTimeout == 0 {
|
|
config.WriteTimeout = 30 * time.Second
|
|
}
|
|
if config.MaxConnections == 0 {
|
|
config.MaxConnections = 1000
|
|
}
|
|
if config.RateLimit.RequestsPerSecond == 0 {
|
|
config.RateLimit.RequestsPerSecond = 100
|
|
}
|
|
if config.RateLimit.BurstSize == 0 {
|
|
config.RateLimit.BurstSize = 200
|
|
}
|
|
if config.RateLimit.WindowSize == 0 {
|
|
config.RateLimit.WindowSize = 1 * time.Second
|
|
}
|
|
if config.DNSServer.ListenPort == 0 {
|
|
config.DNSServer.ListenPort = 5353
|
|
}
|
|
if config.DNSServer.BackupServer == "" {
|
|
config.DNSServer.BackupServer = "8.8.8.8:53"
|
|
}
|
|
|
|
// Validate configuration
|
|
if err := config.Validate(); err != nil {
|
|
return nil, fmt.Errorf("configuration validation failed: %v", err)
|
|
}
|
|
|
|
return &config, nil
|
|
}
|
|
|
|
// DetectMode determines if this should run as server or client based on config
|
|
func DetectMode(config *Config) (string, error) {
|
|
hasListen := config.ListenAddress != ""
|
|
hasRemote := config.RemoteAddress != ""
|
|
|
|
if hasListen && hasRemote {
|
|
return "", fmt.Errorf("configuration error: cannot have both 'listen_address' and 'remote_address' set. Choose one:\n" +
|
|
"- Server: set 'listen_address' and leave 'remote_address' empty\n" +
|
|
"- Client: set 'remote_address' and leave 'listen_address' empty")
|
|
}
|
|
|
|
if !hasListen && !hasRemote {
|
|
return "", fmt.Errorf("configuration error: must specify either 'listen_address' (for server) or 'remote_address' (for client)")
|
|
}
|
|
|
|
if hasListen && !hasRemote {
|
|
return "server", nil
|
|
}
|
|
|
|
if hasRemote && !hasListen {
|
|
return "client", nil
|
|
}
|
|
|
|
return "", fmt.Errorf("internal error: could not determine mode")
|
|
}
|
|
|
|
// Validate validates the configuration
|
|
func (c *Config) Validate() error {
|
|
// Validate instance ID
|
|
if c.InstanceID == "" {
|
|
return fmt.Errorf("instance_id is required")
|
|
}
|
|
|
|
// Validate encryption key
|
|
if err := encryption.ValidateEncryptionKey(c.EncryptionKey); err != nil {
|
|
return fmt.Errorf("invalid encryption key: %v", err)
|
|
}
|
|
|
|
// Validate ports
|
|
if len(c.Ports) == 0 {
|
|
return fmt.Errorf("at least one port rule is required")
|
|
}
|
|
|
|
for i, port := range c.Ports {
|
|
if err := port.Validate(); err != nil {
|
|
return fmt.Errorf("port rule %d is invalid: %v", i, err)
|
|
}
|
|
}
|
|
|
|
// Validate DNS server configuration
|
|
if c.DNSServer.Enabled {
|
|
if err := c.DNSServer.Validate(); err != nil {
|
|
return fmt.Errorf("DNS server configuration is invalid: %v", err)
|
|
}
|
|
}
|
|
|
|
// Validate rate limiting configuration
|
|
if c.RateLimit.Enabled {
|
|
if c.RateLimit.RequestsPerSecond <= 0 {
|
|
return fmt.Errorf("rate limit requests_per_second must be positive")
|
|
}
|
|
if c.RateLimit.BurstSize <= 0 {
|
|
return fmt.Errorf("rate limit burst_size must be positive")
|
|
}
|
|
if c.RateLimit.WindowSize <= 0 {
|
|
return fmt.Errorf("rate limit window_size must be positive")
|
|
}
|
|
if c.RateLimit.BurstSize < c.RateLimit.RequestsPerSecond {
|
|
return fmt.Errorf("rate limit burst_size should be at least requests_per_second")
|
|
}
|
|
if c.RateLimit.RequestsPerSecond > 10000 {
|
|
return fmt.Errorf("rate limit requests_per_second too high: %d (maximum 10000)", c.RateLimit.RequestsPerSecond)
|
|
}
|
|
if c.RateLimit.BurstSize > 50000 {
|
|
return fmt.Errorf("rate limit burst_size too high: %d (maximum 50000)", c.RateLimit.BurstSize)
|
|
}
|
|
}
|
|
|
|
// Validate connection limits
|
|
if c.MaxConnections < 0 {
|
|
return fmt.Errorf("max_connections cannot be negative")
|
|
}
|
|
if c.MaxConnections > 100000 {
|
|
return fmt.Errorf("max_connections too high: %d (maximum 100000)", c.MaxConnections)
|
|
}
|
|
|
|
// Validate timeout values
|
|
if c.ReadTimeout < 0 {
|
|
return fmt.Errorf("read_timeout cannot be negative")
|
|
}
|
|
if c.WriteTimeout < 0 {
|
|
return fmt.Errorf("write_timeout cannot be negative")
|
|
}
|
|
if c.ReadTimeout > 24*time.Hour {
|
|
return fmt.Errorf("read_timeout too high: %v (maximum 24 hours)", c.ReadTimeout)
|
|
}
|
|
if c.WriteTimeout > 24*time.Hour {
|
|
return fmt.Errorf("write_timeout too high: %v (maximum 24 hours)", c.WriteTimeout)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Validate validates a port rule
|
|
func (p *PortRule) Validate() error {
|
|
if p.LocalPort < 1 || p.LocalPort > 65535 {
|
|
return fmt.Errorf("local_port must be between 1 and 65535, got %d", p.LocalPort)
|
|
}
|
|
|
|
if p.RemotePort < 1 || p.RemotePort > 65535 {
|
|
return fmt.Errorf("remote_port must be between 1 and 65535, got %d", p.RemotePort)
|
|
}
|
|
|
|
if p.Protocol != "tcp" && p.Protocol != "udp" {
|
|
return fmt.Errorf("protocol must be 'tcp' or 'udp', got '%s'", p.Protocol)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Validate validates DNS server configuration
|
|
func (d *DNSServerConfig) Validate() error {
|
|
if d.ListenPort < 1 || d.ListenPort > 65535 {
|
|
return fmt.Errorf("DNS listen_port must be between 1 and 65535, got %d", d.ListenPort)
|
|
}
|
|
|
|
if d.BackupServer == "" {
|
|
return fmt.Errorf("DNS backup_server is required when DNS server is enabled")
|
|
}
|
|
|
|
// Validate backup server format
|
|
parts := strings.Split(d.BackupServer, ":")
|
|
if len(parts) != 2 {
|
|
return fmt.Errorf("DNS backup_server must be in format 'host:port', got '%s'", d.BackupServer)
|
|
}
|
|
|
|
port, err := strconv.Atoi(parts[1])
|
|
if err != nil || port < 1 || port > 65535 {
|
|
return fmt.Errorf("DNS backup_server port must be between 1 and 65535, got '%s'", parts[1])
|
|
}
|
|
|
|
// Validate custom records
|
|
for i, record := range d.CustomRecords {
|
|
if err := record.Validate(); err != nil {
|
|
return fmt.Errorf("DNS record %d is invalid: %v", i, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Validate validates a DNS record
|
|
func (r *DNSRecord) Validate() error {
|
|
if r.Name == "" {
|
|
return fmt.Errorf("DNS record name is required")
|
|
}
|
|
|
|
// Validate DNS name length and format
|
|
if len(r.Name) > 253 {
|
|
return fmt.Errorf("DNS record name too long")
|
|
}
|
|
|
|
// Basic character validation for DNS name
|
|
for _, c := range r.Name {
|
|
if c < 32 || c > 126 {
|
|
return fmt.Errorf("invalid character in DNS record name")
|
|
}
|
|
}
|
|
|
|
validTypes := []string{"A", "AAAA", "CNAME", "MX", "TXT", "SRV", "NS"}
|
|
valid := false
|
|
for _, t := range validTypes {
|
|
if strings.ToUpper(r.Type) == t {
|
|
valid = true
|
|
break
|
|
}
|
|
}
|
|
if !valid {
|
|
return fmt.Errorf("DNS record type must be one of %v, got '%s'", validTypes, r.Type)
|
|
}
|
|
|
|
if r.Value == "" {
|
|
return fmt.Errorf("DNS record value is required")
|
|
}
|
|
|
|
// Validate DNS record value length
|
|
if len(r.Value) > 1024 {
|
|
return fmt.Errorf("DNS record value too long")
|
|
}
|
|
|
|
// Basic character validation for DNS value
|
|
for _, c := range r.Value {
|
|
if c < 32 || c > 126 {
|
|
return fmt.Errorf("invalid character in DNS record value")
|
|
}
|
|
}
|
|
|
|
if r.TTL == 0 {
|
|
return fmt.Errorf("DNS record TTL must be greater than 0")
|
|
}
|
|
|
|
// Validate TTL range (reasonable limits)
|
|
if r.TTL > 86400 { // 24 hours
|
|
return fmt.Errorf("DNS record TTL too large (max 86400 seconds)")
|
|
}
|
|
|
|
// Note: Priority, Weight, and Port are uint16 types, so they cannot exceed 65535
|
|
// No additional validation needed for these fields as they are already constrained by their type
|
|
|
|
return nil
|
|
}
|
|
|
|
// GenerateExampleConfig generates an example configuration file
|
|
func GenerateExampleConfig(filename string) error {
|
|
// Generate a strong encryption key
|
|
strongKey, err := generateStrongEncryptionKey()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate encryption key: %v", err)
|
|
}
|
|
|
|
var config Config
|
|
if strings.Contains(filename, "server") {
|
|
// Generate server configuration
|
|
config = Config{
|
|
InstanceID: "teleport-server-01",
|
|
ListenAddress: ":8080",
|
|
RemoteAddress: "",
|
|
Ports: []PortRule{
|
|
{LocalPort: 80, RemotePort: 80, Protocol: "tcp", TargetHost: "localhost"},
|
|
},
|
|
EncryptionKey: strongKey,
|
|
KeepAlive: true,
|
|
ReadTimeout: 30 * time.Second,
|
|
WriteTimeout: 30 * time.Second,
|
|
MaxConnections: 1000,
|
|
RateLimit: RateLimitConfig{
|
|
Enabled: true,
|
|
RequestsPerSecond: 100,
|
|
BurstSize: 200,
|
|
WindowSize: 1 * time.Second,
|
|
},
|
|
DNSServer: DNSServerConfig{
|
|
ListenPort: 5353,
|
|
BackupServer: "8.8.8.8:53",
|
|
CustomRecords: []DNSRecord{},
|
|
},
|
|
}
|
|
} else {
|
|
// Generate client configuration
|
|
config = Config{
|
|
InstanceID: "teleport-client-01",
|
|
ListenAddress: "",
|
|
RemoteAddress: "localhost:8080",
|
|
Ports: []PortRule{
|
|
{LocalPort: 8080, RemotePort: 80, Protocol: "tcp", TargetHost: ""},
|
|
},
|
|
EncryptionKey: strongKey,
|
|
KeepAlive: true,
|
|
ReadTimeout: 30 * time.Second,
|
|
WriteTimeout: 30 * time.Second,
|
|
MaxConnections: 100,
|
|
RateLimit: RateLimitConfig{
|
|
Enabled: true,
|
|
RequestsPerSecond: 50,
|
|
BurstSize: 100,
|
|
WindowSize: 1 * time.Second,
|
|
},
|
|
DNSServer: DNSServerConfig{
|
|
ListenPort: 5353,
|
|
BackupServer: "8.8.8.8:53",
|
|
CustomRecords: []DNSRecord{
|
|
{Name: "app.local", Type: "A", Value: "127.0.0.1", TTL: 300},
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
data, err := yaml.Marshal(&config)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal config: %v", err)
|
|
}
|
|
|
|
err = os.WriteFile(filename, data, 0644)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to write config file: %v", err)
|
|
}
|
|
|
|
fmt.Printf("Generated example configuration: %s\n", filename)
|
|
fmt.Printf("Edit the configuration file and run: ./teleport -config %s\n", filename)
|
|
return nil
|
|
}
|
|
|
|
// generateStrongEncryptionKey returns a new random 256-bit key encoded as a
// 64-character lowercase hexadecimal string. The bytes come from
// crypto/rand; a failure to read them is returned as an error together with
// an empty string.
func generateStrongEncryptionKey() (string, error) {
	var key [32]byte // 32 bytes = 256 bits
	if _, err := rand.Read(key[:]); err != nil {
		return "", fmt.Errorf("failed to generate random key: %v", err)
	}

	// Hex keeps the key easy to copy into a configuration file.
	return hex.EncodeToString(key[:]), nil
}
|