Files
teleport/pkg/config/config_test.go
Justin Harms d24d1dc5ae Add initial project structure with core functionality
- Created a new Go module named 'teleport' for secure port forwarding.
- Added essential files including .gitignore, LICENSE, and README.md with project details.
- Implemented configuration management with YAML support in config package.
- Developed core client and server functionalities for handling port forwarding.
- Introduced DNS server capabilities and integrated logging with sanitization.
- Established rate limiting and metrics tracking for performance monitoring.
- Included comprehensive tests for core components and functionalities.
- Set up CI workflows for automated testing and release management using Gitea actions.
2025-09-20 18:07:08 -05:00

364 lines
8.7 KiB
Go

package config
import (
"os"
"path/filepath"
"strings"
"testing"
"time"
)
// TestLoadConfig writes a fully populated YAML config to a temporary file and
// checks that LoadConfig decodes every top-level field, both port rules, and
// the DNS server section.
func TestLoadConfig(t *testing.T) {
	// Create a temporary config file
	tempDir := t.TempDir()
	configFile := filepath.Join(tempDir, "test-config.yaml")
	configContent := `
instance_id: test-instance
listen_address: 127.0.0.1:8080
remote_address: ""
ports:
  - tcp://localhost:22
  - tcp://localhost:80
encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03
keep_alive: true
read_timeout: 30s
write_timeout: 30s
dns_server:
  enabled: true
  listen_port: 5353
  backup_server: 8.8.8.8:53
  custom_records:
    - name: test.local
      type: A
      value: 192.168.1.100
      ttl: 300
`
	err := os.WriteFile(configFile, []byte(configContent), 0644)
	if err != nil {
		t.Fatalf("Failed to write test config: %v", err)
	}

	// Test loading config
	config, err := LoadConfig(configFile)
	if err != nil {
		t.Fatalf("Failed to load config: %v", err)
	}

	// Verify config values
	if config.InstanceID != "test-instance" {
		t.Errorf("Expected InstanceID 'test-instance', got '%s'", config.InstanceID)
	}
	if config.ListenAddress != "127.0.0.1:8080" {
		t.Errorf("Expected ListenAddress '127.0.0.1:8080', got '%s'", config.ListenAddress)
	}
	if config.RemoteAddress != "" {
		t.Errorf("Expected empty RemoteAddress, got '%s'", config.RemoteAddress)
	}

	// Fatal rather than Error: the checks below index into config.Ports, so
	// continuing with too few rules would panic instead of failing cleanly.
	if len(config.Ports) != 2 {
		t.Fatalf("Expected 2 ports, got %d", len(config.Ports))
	}
	if p := config.Ports[0]; p.LocalPort != 22 || p.RemotePort != 22 || p.Protocol != "tcp" || p.TargetHost != "localhost" {
		t.Errorf("First port rule is incorrect: %+v", p)
	}
	if p := config.Ports[1]; p.LocalPort != 80 || p.RemotePort != 80 || p.Protocol != "tcp" || p.TargetHost != "localhost" {
		t.Errorf("Second port rule is incorrect: %+v", p)
	}

	if config.EncryptionKey != "a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03" {
		t.Errorf("Expected EncryptionKey 'a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03', got '%s'", config.EncryptionKey)
	}
	if !config.KeepAlive {
		t.Error("Expected KeepAlive to be true")
	}
	if config.ReadTimeout != 30*time.Second {
		t.Errorf("Expected ReadTimeout 30s, got %v", config.ReadTimeout)
	}
	if config.WriteTimeout != 30*time.Second {
		t.Errorf("Expected WriteTimeout 30s, got %v", config.WriteTimeout)
	}

	// DNS server section.
	if !config.DNSServer.Enabled {
		t.Error("Expected DNS server to be enabled")
	}
	if config.DNSServer.ListenPort != 5353 {
		t.Errorf("Expected DNS listen port 5353, got %d", config.DNSServer.ListenPort)
	}
	if config.DNSServer.BackupServer != "8.8.8.8:53" {
		t.Errorf("Expected backup server '8.8.8.8:53', got '%s'", config.DNSServer.BackupServer)
	}
	if len(config.DNSServer.CustomRecords) != 1 {
		t.Errorf("Expected 1 custom DNS record, got %d", len(config.DNSServer.CustomRecords))
	}
}
// TestLoadConfigDefaults loads a config that omits the timeout and DNS
// settings and checks that LoadConfig fills them in with its defaults.
func TestLoadConfigDefaults(t *testing.T) {
	// Only the required fields; everything else should come from defaults.
	path := filepath.Join(t.TempDir(), "minimal-config.yaml")
	minimal := `
instance_id: test-instance
listen_address: 127.0.0.1:8080
ports:
  - tcp://localhost:22
encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03
`
	if writeErr := os.WriteFile(path, []byte(minimal), 0644); writeErr != nil {
		t.Fatalf("Failed to write test config: %v", writeErr)
	}

	cfg, loadErr := LoadConfig(path)
	if loadErr != nil {
		t.Fatalf("Failed to load config: %v", loadErr)
	}

	const wantTimeout = 30 * time.Second
	if got := cfg.ReadTimeout; got != wantTimeout {
		t.Errorf("Expected default ReadTimeout 30s, got %v", got)
	}
	if got := cfg.WriteTimeout; got != wantTimeout {
		t.Errorf("Expected default WriteTimeout 30s, got %v", got)
	}

	dns := cfg.DNSServer
	if dns.ListenPort != 5353 {
		t.Errorf("Expected default DNS listen port 5353, got %d", dns.ListenPort)
	}
	if dns.BackupServer != "8.8.8.8:53" {
		t.Errorf("Expected default backup server '8.8.8.8:53', got '%s'", dns.BackupServer)
	}
}
func TestDetectMode(t *testing.T) {
tests := []struct {
name string
listenAddress string
remoteAddress string
expectedMode string
expectedError bool
}{
{
name: "Server mode",
listenAddress: "127.0.0.1:8080",
remoteAddress: "",
expectedMode: "server",
expectedError: false,
},
{
name: "Client mode",
listenAddress: "",
remoteAddress: "127.0.0.1:8080",
expectedMode: "client",
expectedError: false,
},
{
name: "Both addresses set - error",
listenAddress: "127.0.0.1:8080",
remoteAddress: "127.0.0.1:8080",
expectedMode: "",
expectedError: true,
},
{
name: "Neither address set - error",
listenAddress: "",
remoteAddress: "",
expectedMode: "",
expectedError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := &Config{
ListenAddress: tt.listenAddress,
RemoteAddress: tt.remoteAddress,
}
mode, err := DetectMode(config)
if tt.expectedError {
if err == nil {
t.Error("Expected error but got none")
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if mode != tt.expectedMode {
t.Errorf("Expected mode '%s', got '%s'", tt.expectedMode, mode)
}
}
})
}
}
// TestLoadConfigFileNotFound checks that LoadConfig returns an error when the
// config path does not exist.
func TestLoadConfigFileNotFound(t *testing.T) {
	// Use a path inside a fresh, empty temp dir so it is guaranteed not to exist.
	// A bare relative name would resolve against the package directory, and a
	// stray file of that name there would make the test pass for the wrong reason.
	missing := filepath.Join(t.TempDir(), "nonexistent-config.yaml")
	if _, err := LoadConfig(missing); err == nil {
		t.Error("Expected error for nonexistent config file")
	}
}
// TestLoadConfigInvalidYAML checks that LoadConfig rejects a file whose
// contents are not well-formed YAML.
func TestLoadConfigInvalidYAML(t *testing.T) {
	path := filepath.Join(t.TempDir(), "invalid-config.yaml")

	// The unterminated flow sequence "[" makes this syntactically invalid.
	malformed := []byte("invalid: yaml: content: [")
	if writeErr := os.WriteFile(path, malformed, 0644); writeErr != nil {
		t.Fatalf("Failed to write test config: %v", writeErr)
	}

	if _, loadErr := LoadConfig(path); loadErr == nil {
		t.Error("Expected error for invalid YAML")
	}
}
// TestPortRuleURLFormat checks how LoadConfig turns each "protocol://..." port
// string into a PortRule. It covers server-style "proto://host:port" entries
// and client-style "proto://remote:local" entries.
func TestPortRuleURLFormat(t *testing.T) {
	cases := []struct {
		name string
		yaml string
		want PortRule
	}{
		{
			name: "Server format with localhost",
			yaml: `
instance_id: test
listen_address: :8080
ports:
  - tcp://localhost:80
encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03
`,
			want: PortRule{LocalPort: 80, RemotePort: 80, Protocol: "tcp", TargetHost: "localhost"},
		},
		{
			name: "Server format with remote host",
			yaml: `
instance_id: test
listen_address: :8080
ports:
  - tcp://server-a:22
encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03
`,
			want: PortRule{LocalPort: 22, RemotePort: 22, Protocol: "tcp", TargetHost: "server-a"},
		},
		{
			// Client entries are "remote:local" and carry no target host.
			name: "Client format",
			yaml: `
instance_id: test
remote_address: localhost:8080
ports:
  - tcp://80:8080
encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03
`,
			want: PortRule{LocalPort: 8080, RemotePort: 80, Protocol: "tcp", TargetHost: ""},
		},
		{
			name: "UDP format",
			yaml: `
instance_id: test
listen_address: :8080
ports:
  - udp://dns-server:53
encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03
`,
			want: PortRule{LocalPort: 53, RemotePort: 53, Protocol: "udp", TargetHost: "dns-server"},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			path := filepath.Join(t.TempDir(), "test-config.yaml")
			if writeErr := os.WriteFile(path, []byte(tc.yaml), 0644); writeErr != nil {
				t.Fatalf("Failed to write test config: %v", writeErr)
			}

			cfg, loadErr := LoadConfig(path)
			if loadErr != nil {
				t.Fatalf("Failed to load config: %v", loadErr)
			}
			if n := len(cfg.Ports); n != 1 {
				t.Fatalf("Expected 1 port, got %d", n)
			}

			got := cfg.Ports[0]
			if got.LocalPort != tc.want.LocalPort {
				t.Errorf("Expected LocalPort %d, got %d", tc.want.LocalPort, got.LocalPort)
			}
			if got.RemotePort != tc.want.RemotePort {
				t.Errorf("Expected RemotePort %d, got %d", tc.want.RemotePort, got.RemotePort)
			}
			if got.Protocol != tc.want.Protocol {
				t.Errorf("Expected Protocol %s, got %s", tc.want.Protocol, got.Protocol)
			}
			if got.TargetHost != tc.want.TargetHost {
				t.Errorf("Expected TargetHost %s, got %s", tc.want.TargetHost, got.TargetHost)
			}
		})
	}
}
// TestPortRuleOldFormatFails checks that the legacy object form of port rules
// (local_port/remote_port/protocol keys) is rejected. It also checks that the
// error message points the user at the "protocol://" URL syntax.
func TestPortRuleOldFormatFails(t *testing.T) {
	// Test that old object format now fails
	tempDir := t.TempDir()
	configFile := filepath.Join(tempDir, "old-format-config.yaml")
	// The encryption key is the same valid key used by the other tests, so the
	// legacy port syntax is the only thing wrong with this config.
	configContent := `
instance_id: test
listen_address: :8080
ports:
  - local_port: 22
    remote_port: 22
    protocol: tcp
encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03
`
	err := os.WriteFile(configFile, []byte(configContent), 0644)
	if err != nil {
		t.Fatalf("Failed to write test config: %v", err)
	}

	_, err = LoadConfig(configFile)
	if err == nil {
		// Fatal: the message check below calls err.Error() and would panic on nil.
		t.Fatal("Expected error for old format, but got none")
	}
	// Check that the error message mentions the expected format
	if !strings.Contains(err.Error(), "protocol://") {
		t.Errorf("Expected error message to mention 'protocol://', got: %v", err)
	}
}