package config import ( "os" "path/filepath" "strings" "testing" "time" ) func TestLoadConfig(t *testing.T) { // Create a temporary config file tempDir := t.TempDir() configFile := filepath.Join(tempDir, "test-config.yaml") configContent := ` instance_id: test-instance listen_address: 127.0.0.1:8080 remote_address: "" ports: - tcp://localhost:22 - tcp://localhost:80 encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03 keep_alive: true read_timeout: 30s write_timeout: 30s dns_server: enabled: true listen_port: 5353 backup_server: 8.8.8.8:53 custom_records: - name: test.local type: A value: 192.168.1.100 ttl: 300 ` err := os.WriteFile(configFile, []byte(configContent), 0644) if err != nil { t.Fatalf("Failed to write test config: %v", err) } // Test loading config config, err := LoadConfig(configFile) if err != nil { t.Fatalf("Failed to load config: %v", err) } // Verify config values if config.InstanceID != "test-instance" { t.Errorf("Expected InstanceID 'test-instance', got '%s'", config.InstanceID) } if config.ListenAddress != "127.0.0.1:8080" { t.Errorf("Expected ListenAddress '127.0.0.1:8080', got '%s'", config.ListenAddress) } if config.RemoteAddress != "" { t.Errorf("Expected empty RemoteAddress, got '%s'", config.RemoteAddress) } if len(config.Ports) != 2 { t.Errorf("Expected 2 ports, got %d", len(config.Ports)) } if config.Ports[0].LocalPort != 22 || config.Ports[0].RemotePort != 22 || config.Ports[0].Protocol != "tcp" || config.Ports[0].TargetHost != "localhost" { t.Error("First port rule is incorrect") } if config.EncryptionKey != "a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03" { t.Errorf("Expected EncryptionKey 'a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03', got '%s'", config.EncryptionKey) } if !config.KeepAlive { t.Error("Expected KeepAlive to be true") } if config.ReadTimeout != 30*time.Second { t.Errorf("Expected ReadTimeout 30s, got %v", config.ReadTimeout) } if config.WriteTimeout != 30*time.Second { t.Errorf("Expected WriteTimeout 30s, got %v", config.WriteTimeout) } if !config.DNSServer.Enabled { t.Error("Expected DNS server to be enabled") } if config.DNSServer.ListenPort != 5353 { t.Errorf("Expected DNS listen port 5353, got %d", config.DNSServer.ListenPort) } if len(config.DNSServer.CustomRecords) != 1 { t.Errorf("Expected 1 custom DNS record, got %d", len(config.DNSServer.CustomRecords)) } } func TestLoadConfigDefaults(t *testing.T) { // Create a minimal config file tempDir := t.TempDir() configFile := filepath.Join(tempDir, "minimal-config.yaml") configContent := ` instance_id: test-instance listen_address: 127.0.0.1:8080 ports: - tcp://localhost:22 encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03 ` err := os.WriteFile(configFile, []byte(configContent), 0644) if err != nil { t.Fatalf("Failed to write test config: %v", err) } config, err := LoadConfig(configFile) if err != nil { t.Fatalf("Failed to load config: %v", err) } // Check default values if config.ReadTimeout != 30*time.Second { t.Errorf("Expected default ReadTimeout 30s, got %v", config.ReadTimeout) } if config.WriteTimeout != 30*time.Second { t.Errorf("Expected default WriteTimeout 30s, got %v", config.WriteTimeout) } if config.DNSServer.ListenPort != 5353 { t.Errorf("Expected default DNS listen port 5353, got %d", config.DNSServer.ListenPort) } if config.DNSServer.BackupServer != "8.8.8.8:53" { t.Errorf("Expected default backup server '8.8.8.8:53', got '%s'", config.DNSServer.BackupServer) } } func TestDetectMode(t *testing.T) { tests := []struct { name string listenAddress string remoteAddress string expectedMode string expectedError bool }{ { name: "Server mode", listenAddress: "127.0.0.1:8080", remoteAddress: "", expectedMode: "server", expectedError: false, }, { name: "Client mode", listenAddress: "", remoteAddress: "127.0.0.1:8080", expectedMode: "client", expectedError: false, }, { name: "Both addresses set - error", listenAddress: "127.0.0.1:8080", remoteAddress: "127.0.0.1:8080", expectedMode: "", expectedError: true, }, { name: "Neither address set - error", listenAddress: "", remoteAddress: "", expectedMode: "", expectedError: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { config := &Config{ ListenAddress: tt.listenAddress, RemoteAddress: tt.remoteAddress, } mode, err := DetectMode(config) if tt.expectedError { if err == nil { t.Error("Expected error but got none") } } else { if err != nil { t.Errorf("Unexpected error: %v", err) } if mode != tt.expectedMode { t.Errorf("Expected mode '%s', got '%s'", tt.expectedMode, mode) } } }) } } func TestLoadConfigFileNotFound(t *testing.T) { _, err := LoadConfig("nonexistent-config.yaml") if err == nil { t.Error("Expected error for nonexistent config file") } } func TestLoadConfigInvalidYAML(t *testing.T) { tempDir := t.TempDir() configFile := filepath.Join(tempDir, "invalid-config.yaml") // Write invalid YAML err := os.WriteFile(configFile, []byte("invalid: yaml: content: ["), 0644) if err != nil { t.Fatalf("Failed to write test config: %v", err) } _, err = LoadConfig(configFile) if err == nil { t.Error("Expected error for invalid YAML") } } func TestPortRuleURLFormat(t *testing.T) { tests := []struct { name string config string expected PortRule }{ { name: "Server format with localhost", config: ` instance_id: test listen_address: :8080 ports: - tcp://localhost:80 encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03 `, expected: PortRule{ LocalPort: 80, RemotePort: 80, Protocol: "tcp", TargetHost: "localhost", }, }, { name: "Server format with remote host", config: ` instance_id: test listen_address: :8080 ports: - tcp://server-a:22 encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03 `, expected: PortRule{ LocalPort: 22, RemotePort: 22, Protocol: "tcp", TargetHost: "server-a", }, }, { name: "Client format", config: ` instance_id: test remote_address: localhost:8080 ports: - tcp://80:8080 encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03 `, expected: PortRule{ LocalPort: 8080, RemotePort: 80, Protocol: "tcp", TargetHost: "", }, }, { name: "UDP format", config: ` instance_id: test listen_address: :8080 ports: - udp://dns-server:53 encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03 `, expected: PortRule{ LocalPort: 53, RemotePort: 53, Protocol: "udp", TargetHost: "dns-server", }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tempDir := t.TempDir() configFile := filepath.Join(tempDir, "test-config.yaml") err := os.WriteFile(configFile, []byte(tt.config), 0644) if err != nil { t.Fatalf("Failed to write test config: %v", err) } config, err := LoadConfig(configFile) if err != nil { t.Fatalf("Failed to load config: %v", err) } if len(config.Ports) != 1 { t.Fatalf("Expected 1 port, got %d", len(config.Ports)) } port := config.Ports[0] if port.LocalPort != tt.expected.LocalPort { t.Errorf("Expected LocalPort %d, got %d", tt.expected.LocalPort, port.LocalPort) } if port.RemotePort != tt.expected.RemotePort { t.Errorf("Expected RemotePort %d, got %d", tt.expected.RemotePort, port.RemotePort) } if port.Protocol != tt.expected.Protocol { t.Errorf("Expected Protocol %s, got %s", tt.expected.Protocol, port.Protocol) } if port.TargetHost != tt.expected.TargetHost { t.Errorf("Expected TargetHost %s, got %s", tt.expected.TargetHost, port.TargetHost) } }) } } func TestPortRuleOldFormatFails(t *testing.T) { // Test that old object format now fails tempDir := t.TempDir() configFile := filepath.Join(tempDir, "old-format-config.yaml") configContent := ` instance_id: test listen_address: :8080 ports: - local_port: 22 remote_port: 22 protocol: tcp encryption_key: test-key ` err := os.WriteFile(configFile, []byte(configContent), 0644) if err != nil { t.Fatalf("Failed to write test config: %v", err) } _, err = LoadConfig(configFile) if err == nil { t.Error("Expected error for old format, but got none") } // Check that the error message mentions the expected format if !strings.Contains(err.Error(), "protocol://") { t.Errorf("Expected error message to mention 'protocol://', got: %v", err) } }