Add initial project structure with core functionality
- Created a new Go module named 'teleport' for secure port forwarding. - Added essential files including .gitignore, LICENSE, and README.md with project details. - Implemented configuration management with YAML support in config package. - Developed core client and server functionalities for handling port forwarding. - Introduced DNS server capabilities and integrated logging with sanitization. - Established rate limiting and metrics tracking for performance monitoring. - Included comprehensive tests for core components and functionalities. - Set up CI workflows for automated testing and release management using Gitea actions.
This commit is contained in:
537
pkg/config/config.go
Normal file
537
pkg/config/config.go
Normal file
@@ -0,0 +1,537 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"teleport/pkg/encryption"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Config represents the teleport configuration
|
||||
type Config struct {
|
||||
InstanceID string `yaml:"instance_id"`
|
||||
ListenAddress string `yaml:"listen_address"`
|
||||
RemoteAddress string `yaml:"remote_address"`
|
||||
Ports []PortRule `yaml:"ports"`
|
||||
EncryptionKey string `yaml:"encryption_key"`
|
||||
KeepAlive bool `yaml:"keep_alive"`
|
||||
ReadTimeout time.Duration `yaml:"read_timeout"`
|
||||
WriteTimeout time.Duration `yaml:"write_timeout"`
|
||||
MaxConnections int `yaml:"max_connections"`
|
||||
RateLimit RateLimitConfig `yaml:"rate_limit"`
|
||||
DNSServer DNSServerConfig `yaml:"dns_server"`
|
||||
}
|
||||
|
||||
// PortRule defines a port forwarding rule
|
||||
type PortRule struct {
|
||||
LocalPort int `yaml:"local_port"`
|
||||
RemotePort int `yaml:"remote_port"`
|
||||
Protocol string `yaml:"protocol"` // "tcp" or "udp"
|
||||
TargetHost string `yaml:"target_host,omitempty"` // Target host for server-side forwarding (defaults to localhost)
|
||||
}
|
||||
|
||||
// UnmarshalYAML implements custom YAML unmarshaling for PortRule
|
||||
func (p *PortRule) UnmarshalYAML(value *yaml.Node) error {
|
||||
// Only support URL-style format: "protocol://target:targetport" (server) or "protocol://targetport:localport" (client)
|
||||
if value.Kind != yaml.ScalarNode {
|
||||
return fmt.Errorf("port must be a string in format 'protocol://target:port' or 'protocol://targetport:localport'")
|
||||
}
|
||||
|
||||
portStr := value.Value
|
||||
|
||||
// Validate input length
|
||||
if len(portStr) > 512 {
|
||||
return fmt.Errorf("port string too long")
|
||||
}
|
||||
|
||||
// Check if it starts with protocol://
|
||||
if !strings.Contains(portStr, "://") {
|
||||
return fmt.Errorf("invalid port format: %s (expected 'protocol://target:port' or 'protocol://targetport:localport')", portStr)
|
||||
}
|
||||
|
||||
parts := strings.Split(portStr, "://")
|
||||
if len(parts) != 2 {
|
||||
return fmt.Errorf("invalid port format: %s (expected 'protocol://target:port')", portStr)
|
||||
}
|
||||
|
||||
protocol := parts[0]
|
||||
if protocol != "tcp" && protocol != "udp" {
|
||||
return fmt.Errorf("invalid protocol: %s (expected 'tcp' or 'udp')", protocol)
|
||||
}
|
||||
|
||||
addressPart := parts[1]
|
||||
|
||||
// Validate address part length
|
||||
if len(addressPart) > 256 {
|
||||
return fmt.Errorf("address part too long")
|
||||
}
|
||||
|
||||
addressParts := strings.Split(addressPart, ":")
|
||||
|
||||
if len(addressParts) == 2 {
|
||||
// Check if first part is a hostname/IP (contains dots or is not a number) - server format
|
||||
if _, err := strconv.Atoi(addressParts[0]); err != nil || strings.Contains(addressParts[0], ".") {
|
||||
// Server format: protocol://target:targetport
|
||||
targetHost := addressParts[0]
|
||||
|
||||
// Validate target host
|
||||
if len(targetHost) > 253 { // RFC 1123 limit
|
||||
return fmt.Errorf("target host too long")
|
||||
}
|
||||
|
||||
// Basic character validation for hostname
|
||||
for _, c := range targetHost {
|
||||
if c < 32 || c > 126 {
|
||||
return fmt.Errorf("invalid character in target host")
|
||||
}
|
||||
}
|
||||
|
||||
targetPort, err := strconv.Atoi(addressParts[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid target port: %s", addressParts[1])
|
||||
}
|
||||
|
||||
// Validate port range
|
||||
if targetPort < 1 || targetPort > 65535 {
|
||||
return fmt.Errorf("invalid target port: %d (must be 1-65535)", targetPort)
|
||||
}
|
||||
|
||||
p.LocalPort = targetPort // Server listens on this port
|
||||
p.RemotePort = targetPort // Server forwards to this port on target
|
||||
p.Protocol = protocol
|
||||
p.TargetHost = targetHost
|
||||
return nil
|
||||
} else {
|
||||
// Client format: protocol://targetport:localport (first part is a number)
|
||||
targetPort, err := strconv.Atoi(addressParts[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid target port: %s", addressParts[0])
|
||||
}
|
||||
localPort, err := strconv.Atoi(addressParts[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid local port: %s", addressParts[1])
|
||||
}
|
||||
|
||||
// Validate port ranges
|
||||
if targetPort < 1 || targetPort > 65535 {
|
||||
return fmt.Errorf("invalid target port: %d (must be 1-65535)", targetPort)
|
||||
}
|
||||
if localPort < 1 || localPort > 65535 {
|
||||
return fmt.Errorf("invalid local port: %d (must be 1-65535)", localPort)
|
||||
}
|
||||
|
||||
p.LocalPort = localPort // Client listens on this port
|
||||
p.RemotePort = targetPort // Client forwards to this port on teleport server
|
||||
p.Protocol = protocol
|
||||
p.TargetHost = "" // Client doesn't specify target host
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("invalid address format: %s (expected 'target:port' for server or 'targetport:localport' for client)", addressPart)
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalYAML implements custom YAML marshaling for PortRule
|
||||
func (p PortRule) MarshalYAML() (interface{}, error) {
|
||||
// Use new format: protocol://target:targetport (server) or protocol://targetport:localport (client)
|
||||
if p.TargetHost == "" {
|
||||
// Client format: protocol://targetport:localport
|
||||
return fmt.Sprintf("%s://%d:%d", p.Protocol, p.RemotePort, p.LocalPort), nil
|
||||
} else {
|
||||
// Server format: protocol://target:targetport
|
||||
return fmt.Sprintf("%s://%s:%d", p.Protocol, p.TargetHost, p.RemotePort), nil
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimitConfig defines rate limiting configuration
|
||||
type RateLimitConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
RequestsPerSecond int `yaml:"requests_per_second"`
|
||||
BurstSize int `yaml:"burst_size"`
|
||||
WindowSize time.Duration `yaml:"window_size"`
|
||||
}
|
||||
|
||||
// DNSServerConfig defines DNS server configuration
|
||||
type DNSServerConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
ListenPort int `yaml:"listen_port"`
|
||||
BackupServer string `yaml:"backup_server"`
|
||||
CustomRecords []DNSRecord `yaml:"custom_records"`
|
||||
}
|
||||
|
||||
// DNSRecord defines a custom DNS record
|
||||
type DNSRecord struct {
|
||||
Name string `yaml:"name"`
|
||||
Type string `yaml:"type"` // A, AAAA, CNAME, MX, TXT, NS, SRV
|
||||
Value string `yaml:"value"`
|
||||
TTL uint32 `yaml:"ttl"`
|
||||
Priority uint16 `yaml:"priority,omitempty"` // For MX and SRV records
|
||||
Weight uint16 `yaml:"weight,omitempty"` // For SRV records
|
||||
Port uint16 `yaml:"port,omitempty"` // For SRV records
|
||||
}
|
||||
|
||||
// LoadConfig loads and parses the configuration file
|
||||
func LoadConfig(filename string) (*Config, error) {
|
||||
// Check file size before reading (max 1MB)
|
||||
fileInfo, err := os.Stat(filename)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to stat config file: %v", err)
|
||||
}
|
||||
if fileInfo.Size() > 1024*1024 { // 1MB limit
|
||||
return nil, fmt.Errorf("config file too large: %d bytes (max 1MB)", fileInfo.Size())
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config file: %v", err)
|
||||
}
|
||||
|
||||
var config Config
|
||||
if err := yaml.Unmarshal(data, &config); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config file: %v", err)
|
||||
}
|
||||
|
||||
// Set default values
|
||||
if config.ReadTimeout == 0 {
|
||||
config.ReadTimeout = 30 * time.Second
|
||||
}
|
||||
if config.WriteTimeout == 0 {
|
||||
config.WriteTimeout = 30 * time.Second
|
||||
}
|
||||
if config.MaxConnections == 0 {
|
||||
config.MaxConnections = 1000
|
||||
}
|
||||
if config.RateLimit.RequestsPerSecond == 0 {
|
||||
config.RateLimit.RequestsPerSecond = 100
|
||||
}
|
||||
if config.RateLimit.BurstSize == 0 {
|
||||
config.RateLimit.BurstSize = 200
|
||||
}
|
||||
if config.RateLimit.WindowSize == 0 {
|
||||
config.RateLimit.WindowSize = 1 * time.Second
|
||||
}
|
||||
if config.DNSServer.ListenPort == 0 {
|
||||
config.DNSServer.ListenPort = 5353
|
||||
}
|
||||
if config.DNSServer.BackupServer == "" {
|
||||
config.DNSServer.BackupServer = "8.8.8.8:53"
|
||||
}
|
||||
|
||||
// Validate configuration
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("configuration validation failed: %v", err)
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// DetectMode determines if this should run as server or client based on config
|
||||
func DetectMode(config *Config) (string, error) {
|
||||
hasListen := config.ListenAddress != ""
|
||||
hasRemote := config.RemoteAddress != ""
|
||||
|
||||
if hasListen && hasRemote {
|
||||
return "", fmt.Errorf("configuration error: cannot have both 'listen_address' and 'remote_address' set. Choose one:\n" +
|
||||
"- Server: set 'listen_address' and leave 'remote_address' empty\n" +
|
||||
"- Client: set 'remote_address' and leave 'listen_address' empty")
|
||||
}
|
||||
|
||||
if !hasListen && !hasRemote {
|
||||
return "", fmt.Errorf("configuration error: must specify either 'listen_address' (for server) or 'remote_address' (for client)")
|
||||
}
|
||||
|
||||
if hasListen && !hasRemote {
|
||||
return "server", nil
|
||||
}
|
||||
|
||||
if hasRemote && !hasListen {
|
||||
return "client", nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("internal error: could not determine mode")
|
||||
}
|
||||
|
||||
// Validate validates the configuration
|
||||
func (c *Config) Validate() error {
|
||||
// Validate instance ID
|
||||
if c.InstanceID == "" {
|
||||
return fmt.Errorf("instance_id is required")
|
||||
}
|
||||
|
||||
// Validate encryption key
|
||||
if err := encryption.ValidateEncryptionKey(c.EncryptionKey); err != nil {
|
||||
return fmt.Errorf("invalid encryption key: %v", err)
|
||||
}
|
||||
|
||||
// Validate ports
|
||||
if len(c.Ports) == 0 {
|
||||
return fmt.Errorf("at least one port rule is required")
|
||||
}
|
||||
|
||||
for i, port := range c.Ports {
|
||||
if err := port.Validate(); err != nil {
|
||||
return fmt.Errorf("port rule %d is invalid: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate DNS server configuration
|
||||
if c.DNSServer.Enabled {
|
||||
if err := c.DNSServer.Validate(); err != nil {
|
||||
return fmt.Errorf("DNS server configuration is invalid: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate rate limiting configuration
|
||||
if c.RateLimit.Enabled {
|
||||
if c.RateLimit.RequestsPerSecond <= 0 {
|
||||
return fmt.Errorf("rate limit requests_per_second must be positive")
|
||||
}
|
||||
if c.RateLimit.BurstSize <= 0 {
|
||||
return fmt.Errorf("rate limit burst_size must be positive")
|
||||
}
|
||||
if c.RateLimit.WindowSize <= 0 {
|
||||
return fmt.Errorf("rate limit window_size must be positive")
|
||||
}
|
||||
if c.RateLimit.BurstSize < c.RateLimit.RequestsPerSecond {
|
||||
return fmt.Errorf("rate limit burst_size should be at least requests_per_second")
|
||||
}
|
||||
if c.RateLimit.RequestsPerSecond > 10000 {
|
||||
return fmt.Errorf("rate limit requests_per_second too high: %d (maximum 10000)", c.RateLimit.RequestsPerSecond)
|
||||
}
|
||||
if c.RateLimit.BurstSize > 50000 {
|
||||
return fmt.Errorf("rate limit burst_size too high: %d (maximum 50000)", c.RateLimit.BurstSize)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate connection limits
|
||||
if c.MaxConnections < 0 {
|
||||
return fmt.Errorf("max_connections cannot be negative")
|
||||
}
|
||||
if c.MaxConnections > 100000 {
|
||||
return fmt.Errorf("max_connections too high: %d (maximum 100000)", c.MaxConnections)
|
||||
}
|
||||
|
||||
// Validate timeout values
|
||||
if c.ReadTimeout < 0 {
|
||||
return fmt.Errorf("read_timeout cannot be negative")
|
||||
}
|
||||
if c.WriteTimeout < 0 {
|
||||
return fmt.Errorf("write_timeout cannot be negative")
|
||||
}
|
||||
if c.ReadTimeout > 24*time.Hour {
|
||||
return fmt.Errorf("read_timeout too high: %v (maximum 24 hours)", c.ReadTimeout)
|
||||
}
|
||||
if c.WriteTimeout > 24*time.Hour {
|
||||
return fmt.Errorf("write_timeout too high: %v (maximum 24 hours)", c.WriteTimeout)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate validates a port rule
|
||||
func (p *PortRule) Validate() error {
|
||||
if p.LocalPort < 1 || p.LocalPort > 65535 {
|
||||
return fmt.Errorf("local_port must be between 1 and 65535, got %d", p.LocalPort)
|
||||
}
|
||||
|
||||
if p.RemotePort < 1 || p.RemotePort > 65535 {
|
||||
return fmt.Errorf("remote_port must be between 1 and 65535, got %d", p.RemotePort)
|
||||
}
|
||||
|
||||
if p.Protocol != "tcp" && p.Protocol != "udp" {
|
||||
return fmt.Errorf("protocol must be 'tcp' or 'udp', got '%s'", p.Protocol)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate validates DNS server configuration
|
||||
func (d *DNSServerConfig) Validate() error {
|
||||
if d.ListenPort < 1 || d.ListenPort > 65535 {
|
||||
return fmt.Errorf("DNS listen_port must be between 1 and 65535, got %d", d.ListenPort)
|
||||
}
|
||||
|
||||
if d.BackupServer == "" {
|
||||
return fmt.Errorf("DNS backup_server is required when DNS server is enabled")
|
||||
}
|
||||
|
||||
// Validate backup server format
|
||||
parts := strings.Split(d.BackupServer, ":")
|
||||
if len(parts) != 2 {
|
||||
return fmt.Errorf("DNS backup_server must be in format 'host:port', got '%s'", d.BackupServer)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(parts[1])
|
||||
if err != nil || port < 1 || port > 65535 {
|
||||
return fmt.Errorf("DNS backup_server port must be between 1 and 65535, got '%s'", parts[1])
|
||||
}
|
||||
|
||||
// Validate custom records
|
||||
for i, record := range d.CustomRecords {
|
||||
if err := record.Validate(); err != nil {
|
||||
return fmt.Errorf("DNS record %d is invalid: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate validates a DNS record
|
||||
func (r *DNSRecord) Validate() error {
|
||||
if r.Name == "" {
|
||||
return fmt.Errorf("DNS record name is required")
|
||||
}
|
||||
|
||||
// Validate DNS name length and format
|
||||
if len(r.Name) > 253 {
|
||||
return fmt.Errorf("DNS record name too long")
|
||||
}
|
||||
|
||||
// Basic character validation for DNS name
|
||||
for _, c := range r.Name {
|
||||
if c < 32 || c > 126 {
|
||||
return fmt.Errorf("invalid character in DNS record name")
|
||||
}
|
||||
}
|
||||
|
||||
validTypes := []string{"A", "AAAA", "CNAME", "MX", "TXT", "SRV", "NS"}
|
||||
valid := false
|
||||
for _, t := range validTypes {
|
||||
if strings.ToUpper(r.Type) == t {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
return fmt.Errorf("DNS record type must be one of %v, got '%s'", validTypes, r.Type)
|
||||
}
|
||||
|
||||
if r.Value == "" {
|
||||
return fmt.Errorf("DNS record value is required")
|
||||
}
|
||||
|
||||
// Validate DNS record value length
|
||||
if len(r.Value) > 1024 {
|
||||
return fmt.Errorf("DNS record value too long")
|
||||
}
|
||||
|
||||
// Basic character validation for DNS value
|
||||
for _, c := range r.Value {
|
||||
if c < 32 || c > 126 {
|
||||
return fmt.Errorf("invalid character in DNS record value")
|
||||
}
|
||||
}
|
||||
|
||||
if r.TTL == 0 {
|
||||
return fmt.Errorf("DNS record TTL must be greater than 0")
|
||||
}
|
||||
|
||||
// Validate TTL range (reasonable limits)
|
||||
if r.TTL > 86400 { // 24 hours
|
||||
return fmt.Errorf("DNS record TTL too large (max 86400 seconds)")
|
||||
}
|
||||
|
||||
// Note: Priority, Weight, and Port are uint16 types, so they cannot exceed 65535
|
||||
// No additional validation needed for these fields as they are already constrained by their type
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateExampleConfig generates an example configuration file
|
||||
func GenerateExampleConfig(filename string) error {
|
||||
// Generate a strong encryption key
|
||||
strongKey, err := generateStrongEncryptionKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate encryption key: %v", err)
|
||||
}
|
||||
|
||||
var config Config
|
||||
if strings.Contains(filename, "server") {
|
||||
// Generate server configuration
|
||||
config = Config{
|
||||
InstanceID: "teleport-server-01",
|
||||
ListenAddress: ":8080",
|
||||
RemoteAddress: "",
|
||||
Ports: []PortRule{
|
||||
{LocalPort: 80, RemotePort: 80, Protocol: "tcp", TargetHost: "localhost"},
|
||||
},
|
||||
EncryptionKey: strongKey,
|
||||
KeepAlive: true,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
MaxConnections: 1000,
|
||||
RateLimit: RateLimitConfig{
|
||||
Enabled: true,
|
||||
RequestsPerSecond: 100,
|
||||
BurstSize: 200,
|
||||
WindowSize: 1 * time.Second,
|
||||
},
|
||||
DNSServer: DNSServerConfig{
|
||||
ListenPort: 5353,
|
||||
BackupServer: "8.8.8.8:53",
|
||||
CustomRecords: []DNSRecord{},
|
||||
},
|
||||
}
|
||||
} else {
|
||||
// Generate client configuration
|
||||
config = Config{
|
||||
InstanceID: "teleport-client-01",
|
||||
ListenAddress: "",
|
||||
RemoteAddress: "localhost:8080",
|
||||
Ports: []PortRule{
|
||||
{LocalPort: 8080, RemotePort: 80, Protocol: "tcp", TargetHost: ""},
|
||||
},
|
||||
EncryptionKey: strongKey,
|
||||
KeepAlive: true,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
MaxConnections: 100,
|
||||
RateLimit: RateLimitConfig{
|
||||
Enabled: true,
|
||||
RequestsPerSecond: 50,
|
||||
BurstSize: 100,
|
||||
WindowSize: 1 * time.Second,
|
||||
},
|
||||
DNSServer: DNSServerConfig{
|
||||
ListenPort: 5353,
|
||||
BackupServer: "8.8.8.8:53",
|
||||
CustomRecords: []DNSRecord{
|
||||
{Name: "app.local", Type: "A", Value: "127.0.0.1", TTL: 300},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(&config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %v", err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(filename, data, 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write config file: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Generated example configuration: %s\n", filename)
|
||||
fmt.Printf("Edit the configuration file and run: ./teleport -config %s\n", filename)
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateStrongEncryptionKey generates a cryptographically secure encryption key
|
||||
func generateStrongEncryptionKey() (string, error) {
|
||||
// Generate 32 random bytes (256 bits) for a strong encryption key
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", fmt.Errorf("failed to generate random key: %v", err)
|
||||
}
|
||||
|
||||
// Convert to hexadecimal string for easy copying
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
363
pkg/config/config_test.go
Normal file
363
pkg/config/config_test.go
Normal file
@@ -0,0 +1,363 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLoadConfig(t *testing.T) {
|
||||
// Create a temporary config file
|
||||
tempDir := t.TempDir()
|
||||
configFile := filepath.Join(tempDir, "test-config.yaml")
|
||||
|
||||
configContent := `
|
||||
instance_id: test-instance
|
||||
listen_address: 127.0.0.1:8080
|
||||
remote_address: ""
|
||||
ports:
|
||||
- tcp://localhost:22
|
||||
- tcp://localhost:80
|
||||
encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03
|
||||
keep_alive: true
|
||||
read_timeout: 30s
|
||||
write_timeout: 30s
|
||||
dns_server:
|
||||
enabled: true
|
||||
listen_port: 5353
|
||||
backup_server: 8.8.8.8:53
|
||||
custom_records:
|
||||
- name: test.local
|
||||
type: A
|
||||
value: 192.168.1.100
|
||||
ttl: 300
|
||||
`
|
||||
|
||||
err := os.WriteFile(configFile, []byte(configContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
// Test loading config
|
||||
config, err := LoadConfig(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Verify config values
|
||||
if config.InstanceID != "test-instance" {
|
||||
t.Errorf("Expected InstanceID 'test-instance', got '%s'", config.InstanceID)
|
||||
}
|
||||
|
||||
if config.ListenAddress != "127.0.0.1:8080" {
|
||||
t.Errorf("Expected ListenAddress '127.0.0.1:8080', got '%s'", config.ListenAddress)
|
||||
}
|
||||
|
||||
if config.RemoteAddress != "" {
|
||||
t.Errorf("Expected empty RemoteAddress, got '%s'", config.RemoteAddress)
|
||||
}
|
||||
|
||||
if len(config.Ports) != 2 {
|
||||
t.Errorf("Expected 2 ports, got %d", len(config.Ports))
|
||||
}
|
||||
|
||||
if config.Ports[0].LocalPort != 22 || config.Ports[0].RemotePort != 22 || config.Ports[0].Protocol != "tcp" || config.Ports[0].TargetHost != "localhost" {
|
||||
t.Error("First port rule is incorrect")
|
||||
}
|
||||
|
||||
if config.EncryptionKey != "a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03" {
|
||||
t.Errorf("Expected EncryptionKey 'a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03', got '%s'", config.EncryptionKey)
|
||||
}
|
||||
|
||||
if !config.KeepAlive {
|
||||
t.Error("Expected KeepAlive to be true")
|
||||
}
|
||||
|
||||
if config.ReadTimeout != 30*time.Second {
|
||||
t.Errorf("Expected ReadTimeout 30s, got %v", config.ReadTimeout)
|
||||
}
|
||||
|
||||
if config.WriteTimeout != 30*time.Second {
|
||||
t.Errorf("Expected WriteTimeout 30s, got %v", config.WriteTimeout)
|
||||
}
|
||||
|
||||
if !config.DNSServer.Enabled {
|
||||
t.Error("Expected DNS server to be enabled")
|
||||
}
|
||||
|
||||
if config.DNSServer.ListenPort != 5353 {
|
||||
t.Errorf("Expected DNS listen port 5353, got %d", config.DNSServer.ListenPort)
|
||||
}
|
||||
|
||||
if len(config.DNSServer.CustomRecords) != 1 {
|
||||
t.Errorf("Expected 1 custom DNS record, got %d", len(config.DNSServer.CustomRecords))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigDefaults(t *testing.T) {
|
||||
// Create a minimal config file
|
||||
tempDir := t.TempDir()
|
||||
configFile := filepath.Join(tempDir, "minimal-config.yaml")
|
||||
|
||||
configContent := `
|
||||
instance_id: test-instance
|
||||
listen_address: 127.0.0.1:8080
|
||||
ports:
|
||||
- tcp://localhost:22
|
||||
encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03
|
||||
`
|
||||
|
||||
err := os.WriteFile(configFile, []byte(configContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
config, err := LoadConfig(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Check default values
|
||||
if config.ReadTimeout != 30*time.Second {
|
||||
t.Errorf("Expected default ReadTimeout 30s, got %v", config.ReadTimeout)
|
||||
}
|
||||
|
||||
if config.WriteTimeout != 30*time.Second {
|
||||
t.Errorf("Expected default WriteTimeout 30s, got %v", config.WriteTimeout)
|
||||
}
|
||||
|
||||
if config.DNSServer.ListenPort != 5353 {
|
||||
t.Errorf("Expected default DNS listen port 5353, got %d", config.DNSServer.ListenPort)
|
||||
}
|
||||
|
||||
if config.DNSServer.BackupServer != "8.8.8.8:53" {
|
||||
t.Errorf("Expected default backup server '8.8.8.8:53', got '%s'", config.DNSServer.BackupServer)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
listenAddress string
|
||||
remoteAddress string
|
||||
expectedMode string
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "Server mode",
|
||||
listenAddress: "127.0.0.1:8080",
|
||||
remoteAddress: "",
|
||||
expectedMode: "server",
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "Client mode",
|
||||
listenAddress: "",
|
||||
remoteAddress: "127.0.0.1:8080",
|
||||
expectedMode: "client",
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "Both addresses set - error",
|
||||
listenAddress: "127.0.0.1:8080",
|
||||
remoteAddress: "127.0.0.1:8080",
|
||||
expectedMode: "",
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "Neither address set - error",
|
||||
listenAddress: "",
|
||||
remoteAddress: "",
|
||||
expectedMode: "",
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := &Config{
|
||||
ListenAddress: tt.listenAddress,
|
||||
RemoteAddress: tt.remoteAddress,
|
||||
}
|
||||
|
||||
mode, err := DetectMode(config)
|
||||
|
||||
if tt.expectedError {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if mode != tt.expectedMode {
|
||||
t.Errorf("Expected mode '%s', got '%s'", tt.expectedMode, mode)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigFileNotFound(t *testing.T) {
|
||||
_, err := LoadConfig("nonexistent-config.yaml")
|
||||
if err == nil {
|
||||
t.Error("Expected error for nonexistent config file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigInvalidYAML(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
configFile := filepath.Join(tempDir, "invalid-config.yaml")
|
||||
|
||||
// Write invalid YAML
|
||||
err := os.WriteFile(configFile, []byte("invalid: yaml: content: ["), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
_, err = LoadConfig(configFile)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid YAML")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortRuleURLFormat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config string
|
||||
expected PortRule
|
||||
}{
|
||||
{
|
||||
name: "Server format with localhost",
|
||||
config: `
|
||||
instance_id: test
|
||||
listen_address: :8080
|
||||
ports:
|
||||
- tcp://localhost:80
|
||||
encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03
|
||||
`,
|
||||
expected: PortRule{
|
||||
LocalPort: 80,
|
||||
RemotePort: 80,
|
||||
Protocol: "tcp",
|
||||
TargetHost: "localhost",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Server format with remote host",
|
||||
config: `
|
||||
instance_id: test
|
||||
listen_address: :8080
|
||||
ports:
|
||||
- tcp://server-a:22
|
||||
encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03
|
||||
`,
|
||||
expected: PortRule{
|
||||
LocalPort: 22,
|
||||
RemotePort: 22,
|
||||
Protocol: "tcp",
|
||||
TargetHost: "server-a",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Client format",
|
||||
config: `
|
||||
instance_id: test
|
||||
remote_address: localhost:8080
|
||||
ports:
|
||||
- tcp://80:8080
|
||||
encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03
|
||||
`,
|
||||
expected: PortRule{
|
||||
LocalPort: 8080,
|
||||
RemotePort: 80,
|
||||
Protocol: "tcp",
|
||||
TargetHost: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "UDP format",
|
||||
config: `
|
||||
instance_id: test
|
||||
listen_address: :8080
|
||||
ports:
|
||||
- udp://dns-server:53
|
||||
encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03
|
||||
`,
|
||||
expected: PortRule{
|
||||
LocalPort: 53,
|
||||
RemotePort: 53,
|
||||
Protocol: "udp",
|
||||
TargetHost: "dns-server",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
configFile := filepath.Join(tempDir, "test-config.yaml")
|
||||
|
||||
err := os.WriteFile(configFile, []byte(tt.config), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
config, err := LoadConfig(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
if len(config.Ports) != 1 {
|
||||
t.Fatalf("Expected 1 port, got %d", len(config.Ports))
|
||||
}
|
||||
|
||||
port := config.Ports[0]
|
||||
if port.LocalPort != tt.expected.LocalPort {
|
||||
t.Errorf("Expected LocalPort %d, got %d", tt.expected.LocalPort, port.LocalPort)
|
||||
}
|
||||
if port.RemotePort != tt.expected.RemotePort {
|
||||
t.Errorf("Expected RemotePort %d, got %d", tt.expected.RemotePort, port.RemotePort)
|
||||
}
|
||||
if port.Protocol != tt.expected.Protocol {
|
||||
t.Errorf("Expected Protocol %s, got %s", tt.expected.Protocol, port.Protocol)
|
||||
}
|
||||
if port.TargetHost != tt.expected.TargetHost {
|
||||
t.Errorf("Expected TargetHost %s, got %s", tt.expected.TargetHost, port.TargetHost)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortRuleOldFormatFails(t *testing.T) {
|
||||
// Test that old object format now fails
|
||||
tempDir := t.TempDir()
|
||||
configFile := filepath.Join(tempDir, "old-format-config.yaml")
|
||||
|
||||
configContent := `
|
||||
instance_id: test
|
||||
listen_address: :8080
|
||||
ports:
|
||||
- local_port: 22
|
||||
remote_port: 22
|
||||
protocol: tcp
|
||||
encryption_key: test-key
|
||||
`
|
||||
|
||||
err := os.WriteFile(configFile, []byte(configContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
_, err = LoadConfig(configFile)
|
||||
if err == nil {
|
||||
t.Error("Expected error for old format, but got none")
|
||||
}
|
||||
|
||||
// Check that the error message mentions the expected format
|
||||
if !strings.Contains(err.Error(), "protocol://") {
|
||||
t.Errorf("Expected error message to mention 'protocol://', got: %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user