Add initial project structure with core functionality
- Created a new Go module named 'teleport' for secure port forwarding. - Added essential files including .gitignore, LICENSE, and README.md with project details. - Implemented configuration management with YAML support in config package. - Developed core client and server functionalities for handling port forwarding. - Introduced DNS server capabilities and integrated logging with sanitization. - Established rate limiting and metrics tracking for performance monitoring. - Included comprehensive tests for core components and functionalities. - Set up CI workflows for automated testing and release management using Gitea actions.
This commit is contained in:
537
pkg/config/config.go
Normal file
537
pkg/config/config.go
Normal file
@@ -0,0 +1,537 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"teleport/pkg/encryption"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Config represents the teleport configuration
|
||||
type Config struct {
|
||||
InstanceID string `yaml:"instance_id"`
|
||||
ListenAddress string `yaml:"listen_address"`
|
||||
RemoteAddress string `yaml:"remote_address"`
|
||||
Ports []PortRule `yaml:"ports"`
|
||||
EncryptionKey string `yaml:"encryption_key"`
|
||||
KeepAlive bool `yaml:"keep_alive"`
|
||||
ReadTimeout time.Duration `yaml:"read_timeout"`
|
||||
WriteTimeout time.Duration `yaml:"write_timeout"`
|
||||
MaxConnections int `yaml:"max_connections"`
|
||||
RateLimit RateLimitConfig `yaml:"rate_limit"`
|
||||
DNSServer DNSServerConfig `yaml:"dns_server"`
|
||||
}
|
||||
|
||||
// PortRule defines a port forwarding rule
|
||||
type PortRule struct {
|
||||
LocalPort int `yaml:"local_port"`
|
||||
RemotePort int `yaml:"remote_port"`
|
||||
Protocol string `yaml:"protocol"` // "tcp" or "udp"
|
||||
TargetHost string `yaml:"target_host,omitempty"` // Target host for server-side forwarding (defaults to localhost)
|
||||
}
|
||||
|
||||
// UnmarshalYAML implements custom YAML unmarshaling for PortRule
|
||||
func (p *PortRule) UnmarshalYAML(value *yaml.Node) error {
|
||||
// Only support URL-style format: "protocol://target:targetport" (server) or "protocol://targetport:localport" (client)
|
||||
if value.Kind != yaml.ScalarNode {
|
||||
return fmt.Errorf("port must be a string in format 'protocol://target:port' or 'protocol://targetport:localport'")
|
||||
}
|
||||
|
||||
portStr := value.Value
|
||||
|
||||
// Validate input length
|
||||
if len(portStr) > 512 {
|
||||
return fmt.Errorf("port string too long")
|
||||
}
|
||||
|
||||
// Check if it starts with protocol://
|
||||
if !strings.Contains(portStr, "://") {
|
||||
return fmt.Errorf("invalid port format: %s (expected 'protocol://target:port' or 'protocol://targetport:localport')", portStr)
|
||||
}
|
||||
|
||||
parts := strings.Split(portStr, "://")
|
||||
if len(parts) != 2 {
|
||||
return fmt.Errorf("invalid port format: %s (expected 'protocol://target:port')", portStr)
|
||||
}
|
||||
|
||||
protocol := parts[0]
|
||||
if protocol != "tcp" && protocol != "udp" {
|
||||
return fmt.Errorf("invalid protocol: %s (expected 'tcp' or 'udp')", protocol)
|
||||
}
|
||||
|
||||
addressPart := parts[1]
|
||||
|
||||
// Validate address part length
|
||||
if len(addressPart) > 256 {
|
||||
return fmt.Errorf("address part too long")
|
||||
}
|
||||
|
||||
addressParts := strings.Split(addressPart, ":")
|
||||
|
||||
if len(addressParts) == 2 {
|
||||
// Check if first part is a hostname/IP (contains dots or is not a number) - server format
|
||||
if _, err := strconv.Atoi(addressParts[0]); err != nil || strings.Contains(addressParts[0], ".") {
|
||||
// Server format: protocol://target:targetport
|
||||
targetHost := addressParts[0]
|
||||
|
||||
// Validate target host
|
||||
if len(targetHost) > 253 { // RFC 1123 limit
|
||||
return fmt.Errorf("target host too long")
|
||||
}
|
||||
|
||||
// Basic character validation for hostname
|
||||
for _, c := range targetHost {
|
||||
if c < 32 || c > 126 {
|
||||
return fmt.Errorf("invalid character in target host")
|
||||
}
|
||||
}
|
||||
|
||||
targetPort, err := strconv.Atoi(addressParts[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid target port: %s", addressParts[1])
|
||||
}
|
||||
|
||||
// Validate port range
|
||||
if targetPort < 1 || targetPort > 65535 {
|
||||
return fmt.Errorf("invalid target port: %d (must be 1-65535)", targetPort)
|
||||
}
|
||||
|
||||
p.LocalPort = targetPort // Server listens on this port
|
||||
p.RemotePort = targetPort // Server forwards to this port on target
|
||||
p.Protocol = protocol
|
||||
p.TargetHost = targetHost
|
||||
return nil
|
||||
} else {
|
||||
// Client format: protocol://targetport:localport (first part is a number)
|
||||
targetPort, err := strconv.Atoi(addressParts[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid target port: %s", addressParts[0])
|
||||
}
|
||||
localPort, err := strconv.Atoi(addressParts[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid local port: %s", addressParts[1])
|
||||
}
|
||||
|
||||
// Validate port ranges
|
||||
if targetPort < 1 || targetPort > 65535 {
|
||||
return fmt.Errorf("invalid target port: %d (must be 1-65535)", targetPort)
|
||||
}
|
||||
if localPort < 1 || localPort > 65535 {
|
||||
return fmt.Errorf("invalid local port: %d (must be 1-65535)", localPort)
|
||||
}
|
||||
|
||||
p.LocalPort = localPort // Client listens on this port
|
||||
p.RemotePort = targetPort // Client forwards to this port on teleport server
|
||||
p.Protocol = protocol
|
||||
p.TargetHost = "" // Client doesn't specify target host
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("invalid address format: %s (expected 'target:port' for server or 'targetport:localport' for client)", addressPart)
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalYAML implements custom YAML marshaling for PortRule
|
||||
func (p PortRule) MarshalYAML() (interface{}, error) {
|
||||
// Use new format: protocol://target:targetport (server) or protocol://targetport:localport (client)
|
||||
if p.TargetHost == "" {
|
||||
// Client format: protocol://targetport:localport
|
||||
return fmt.Sprintf("%s://%d:%d", p.Protocol, p.RemotePort, p.LocalPort), nil
|
||||
} else {
|
||||
// Server format: protocol://target:targetport
|
||||
return fmt.Sprintf("%s://%s:%d", p.Protocol, p.TargetHost, p.RemotePort), nil
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimitConfig defines rate limiting configuration
|
||||
type RateLimitConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
RequestsPerSecond int `yaml:"requests_per_second"`
|
||||
BurstSize int `yaml:"burst_size"`
|
||||
WindowSize time.Duration `yaml:"window_size"`
|
||||
}
|
||||
|
||||
// DNSServerConfig defines DNS server configuration
|
||||
type DNSServerConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
ListenPort int `yaml:"listen_port"`
|
||||
BackupServer string `yaml:"backup_server"`
|
||||
CustomRecords []DNSRecord `yaml:"custom_records"`
|
||||
}
|
||||
|
||||
// DNSRecord defines a custom DNS record
|
||||
type DNSRecord struct {
|
||||
Name string `yaml:"name"`
|
||||
Type string `yaml:"type"` // A, AAAA, CNAME, MX, TXT, NS, SRV
|
||||
Value string `yaml:"value"`
|
||||
TTL uint32 `yaml:"ttl"`
|
||||
Priority uint16 `yaml:"priority,omitempty"` // For MX and SRV records
|
||||
Weight uint16 `yaml:"weight,omitempty"` // For SRV records
|
||||
Port uint16 `yaml:"port,omitempty"` // For SRV records
|
||||
}
|
||||
|
||||
// LoadConfig loads and parses the configuration file
|
||||
func LoadConfig(filename string) (*Config, error) {
|
||||
// Check file size before reading (max 1MB)
|
||||
fileInfo, err := os.Stat(filename)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to stat config file: %v", err)
|
||||
}
|
||||
if fileInfo.Size() > 1024*1024 { // 1MB limit
|
||||
return nil, fmt.Errorf("config file too large: %d bytes (max 1MB)", fileInfo.Size())
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config file: %v", err)
|
||||
}
|
||||
|
||||
var config Config
|
||||
if err := yaml.Unmarshal(data, &config); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config file: %v", err)
|
||||
}
|
||||
|
||||
// Set default values
|
||||
if config.ReadTimeout == 0 {
|
||||
config.ReadTimeout = 30 * time.Second
|
||||
}
|
||||
if config.WriteTimeout == 0 {
|
||||
config.WriteTimeout = 30 * time.Second
|
||||
}
|
||||
if config.MaxConnections == 0 {
|
||||
config.MaxConnections = 1000
|
||||
}
|
||||
if config.RateLimit.RequestsPerSecond == 0 {
|
||||
config.RateLimit.RequestsPerSecond = 100
|
||||
}
|
||||
if config.RateLimit.BurstSize == 0 {
|
||||
config.RateLimit.BurstSize = 200
|
||||
}
|
||||
if config.RateLimit.WindowSize == 0 {
|
||||
config.RateLimit.WindowSize = 1 * time.Second
|
||||
}
|
||||
if config.DNSServer.ListenPort == 0 {
|
||||
config.DNSServer.ListenPort = 5353
|
||||
}
|
||||
if config.DNSServer.BackupServer == "" {
|
||||
config.DNSServer.BackupServer = "8.8.8.8:53"
|
||||
}
|
||||
|
||||
// Validate configuration
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("configuration validation failed: %v", err)
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// DetectMode determines if this should run as server or client based on config
|
||||
func DetectMode(config *Config) (string, error) {
|
||||
hasListen := config.ListenAddress != ""
|
||||
hasRemote := config.RemoteAddress != ""
|
||||
|
||||
if hasListen && hasRemote {
|
||||
return "", fmt.Errorf("configuration error: cannot have both 'listen_address' and 'remote_address' set. Choose one:\n" +
|
||||
"- Server: set 'listen_address' and leave 'remote_address' empty\n" +
|
||||
"- Client: set 'remote_address' and leave 'listen_address' empty")
|
||||
}
|
||||
|
||||
if !hasListen && !hasRemote {
|
||||
return "", fmt.Errorf("configuration error: must specify either 'listen_address' (for server) or 'remote_address' (for client)")
|
||||
}
|
||||
|
||||
if hasListen && !hasRemote {
|
||||
return "server", nil
|
||||
}
|
||||
|
||||
if hasRemote && !hasListen {
|
||||
return "client", nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("internal error: could not determine mode")
|
||||
}
|
||||
|
||||
// Validate validates the configuration
|
||||
func (c *Config) Validate() error {
|
||||
// Validate instance ID
|
||||
if c.InstanceID == "" {
|
||||
return fmt.Errorf("instance_id is required")
|
||||
}
|
||||
|
||||
// Validate encryption key
|
||||
if err := encryption.ValidateEncryptionKey(c.EncryptionKey); err != nil {
|
||||
return fmt.Errorf("invalid encryption key: %v", err)
|
||||
}
|
||||
|
||||
// Validate ports
|
||||
if len(c.Ports) == 0 {
|
||||
return fmt.Errorf("at least one port rule is required")
|
||||
}
|
||||
|
||||
for i, port := range c.Ports {
|
||||
if err := port.Validate(); err != nil {
|
||||
return fmt.Errorf("port rule %d is invalid: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate DNS server configuration
|
||||
if c.DNSServer.Enabled {
|
||||
if err := c.DNSServer.Validate(); err != nil {
|
||||
return fmt.Errorf("DNS server configuration is invalid: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate rate limiting configuration
|
||||
if c.RateLimit.Enabled {
|
||||
if c.RateLimit.RequestsPerSecond <= 0 {
|
||||
return fmt.Errorf("rate limit requests_per_second must be positive")
|
||||
}
|
||||
if c.RateLimit.BurstSize <= 0 {
|
||||
return fmt.Errorf("rate limit burst_size must be positive")
|
||||
}
|
||||
if c.RateLimit.WindowSize <= 0 {
|
||||
return fmt.Errorf("rate limit window_size must be positive")
|
||||
}
|
||||
if c.RateLimit.BurstSize < c.RateLimit.RequestsPerSecond {
|
||||
return fmt.Errorf("rate limit burst_size should be at least requests_per_second")
|
||||
}
|
||||
if c.RateLimit.RequestsPerSecond > 10000 {
|
||||
return fmt.Errorf("rate limit requests_per_second too high: %d (maximum 10000)", c.RateLimit.RequestsPerSecond)
|
||||
}
|
||||
if c.RateLimit.BurstSize > 50000 {
|
||||
return fmt.Errorf("rate limit burst_size too high: %d (maximum 50000)", c.RateLimit.BurstSize)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate connection limits
|
||||
if c.MaxConnections < 0 {
|
||||
return fmt.Errorf("max_connections cannot be negative")
|
||||
}
|
||||
if c.MaxConnections > 100000 {
|
||||
return fmt.Errorf("max_connections too high: %d (maximum 100000)", c.MaxConnections)
|
||||
}
|
||||
|
||||
// Validate timeout values
|
||||
if c.ReadTimeout < 0 {
|
||||
return fmt.Errorf("read_timeout cannot be negative")
|
||||
}
|
||||
if c.WriteTimeout < 0 {
|
||||
return fmt.Errorf("write_timeout cannot be negative")
|
||||
}
|
||||
if c.ReadTimeout > 24*time.Hour {
|
||||
return fmt.Errorf("read_timeout too high: %v (maximum 24 hours)", c.ReadTimeout)
|
||||
}
|
||||
if c.WriteTimeout > 24*time.Hour {
|
||||
return fmt.Errorf("write_timeout too high: %v (maximum 24 hours)", c.WriteTimeout)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate validates a port rule
|
||||
func (p *PortRule) Validate() error {
|
||||
if p.LocalPort < 1 || p.LocalPort > 65535 {
|
||||
return fmt.Errorf("local_port must be between 1 and 65535, got %d", p.LocalPort)
|
||||
}
|
||||
|
||||
if p.RemotePort < 1 || p.RemotePort > 65535 {
|
||||
return fmt.Errorf("remote_port must be between 1 and 65535, got %d", p.RemotePort)
|
||||
}
|
||||
|
||||
if p.Protocol != "tcp" && p.Protocol != "udp" {
|
||||
return fmt.Errorf("protocol must be 'tcp' or 'udp', got '%s'", p.Protocol)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate validates DNS server configuration
|
||||
func (d *DNSServerConfig) Validate() error {
|
||||
if d.ListenPort < 1 || d.ListenPort > 65535 {
|
||||
return fmt.Errorf("DNS listen_port must be between 1 and 65535, got %d", d.ListenPort)
|
||||
}
|
||||
|
||||
if d.BackupServer == "" {
|
||||
return fmt.Errorf("DNS backup_server is required when DNS server is enabled")
|
||||
}
|
||||
|
||||
// Validate backup server format
|
||||
parts := strings.Split(d.BackupServer, ":")
|
||||
if len(parts) != 2 {
|
||||
return fmt.Errorf("DNS backup_server must be in format 'host:port', got '%s'", d.BackupServer)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(parts[1])
|
||||
if err != nil || port < 1 || port > 65535 {
|
||||
return fmt.Errorf("DNS backup_server port must be between 1 and 65535, got '%s'", parts[1])
|
||||
}
|
||||
|
||||
// Validate custom records
|
||||
for i, record := range d.CustomRecords {
|
||||
if err := record.Validate(); err != nil {
|
||||
return fmt.Errorf("DNS record %d is invalid: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate validates a DNS record
|
||||
func (r *DNSRecord) Validate() error {
|
||||
if r.Name == "" {
|
||||
return fmt.Errorf("DNS record name is required")
|
||||
}
|
||||
|
||||
// Validate DNS name length and format
|
||||
if len(r.Name) > 253 {
|
||||
return fmt.Errorf("DNS record name too long")
|
||||
}
|
||||
|
||||
// Basic character validation for DNS name
|
||||
for _, c := range r.Name {
|
||||
if c < 32 || c > 126 {
|
||||
return fmt.Errorf("invalid character in DNS record name")
|
||||
}
|
||||
}
|
||||
|
||||
validTypes := []string{"A", "AAAA", "CNAME", "MX", "TXT", "SRV", "NS"}
|
||||
valid := false
|
||||
for _, t := range validTypes {
|
||||
if strings.ToUpper(r.Type) == t {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valid {
|
||||
return fmt.Errorf("DNS record type must be one of %v, got '%s'", validTypes, r.Type)
|
||||
}
|
||||
|
||||
if r.Value == "" {
|
||||
return fmt.Errorf("DNS record value is required")
|
||||
}
|
||||
|
||||
// Validate DNS record value length
|
||||
if len(r.Value) > 1024 {
|
||||
return fmt.Errorf("DNS record value too long")
|
||||
}
|
||||
|
||||
// Basic character validation for DNS value
|
||||
for _, c := range r.Value {
|
||||
if c < 32 || c > 126 {
|
||||
return fmt.Errorf("invalid character in DNS record value")
|
||||
}
|
||||
}
|
||||
|
||||
if r.TTL == 0 {
|
||||
return fmt.Errorf("DNS record TTL must be greater than 0")
|
||||
}
|
||||
|
||||
// Validate TTL range (reasonable limits)
|
||||
if r.TTL > 86400 { // 24 hours
|
||||
return fmt.Errorf("DNS record TTL too large (max 86400 seconds)")
|
||||
}
|
||||
|
||||
// Note: Priority, Weight, and Port are uint16 types, so they cannot exceed 65535
|
||||
// No additional validation needed for these fields as they are already constrained by their type
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateExampleConfig generates an example configuration file
|
||||
func GenerateExampleConfig(filename string) error {
|
||||
// Generate a strong encryption key
|
||||
strongKey, err := generateStrongEncryptionKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate encryption key: %v", err)
|
||||
}
|
||||
|
||||
var config Config
|
||||
if strings.Contains(filename, "server") {
|
||||
// Generate server configuration
|
||||
config = Config{
|
||||
InstanceID: "teleport-server-01",
|
||||
ListenAddress: ":8080",
|
||||
RemoteAddress: "",
|
||||
Ports: []PortRule{
|
||||
{LocalPort: 80, RemotePort: 80, Protocol: "tcp", TargetHost: "localhost"},
|
||||
},
|
||||
EncryptionKey: strongKey,
|
||||
KeepAlive: true,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
MaxConnections: 1000,
|
||||
RateLimit: RateLimitConfig{
|
||||
Enabled: true,
|
||||
RequestsPerSecond: 100,
|
||||
BurstSize: 200,
|
||||
WindowSize: 1 * time.Second,
|
||||
},
|
||||
DNSServer: DNSServerConfig{
|
||||
ListenPort: 5353,
|
||||
BackupServer: "8.8.8.8:53",
|
||||
CustomRecords: []DNSRecord{},
|
||||
},
|
||||
}
|
||||
} else {
|
||||
// Generate client configuration
|
||||
config = Config{
|
||||
InstanceID: "teleport-client-01",
|
||||
ListenAddress: "",
|
||||
RemoteAddress: "localhost:8080",
|
||||
Ports: []PortRule{
|
||||
{LocalPort: 8080, RemotePort: 80, Protocol: "tcp", TargetHost: ""},
|
||||
},
|
||||
EncryptionKey: strongKey,
|
||||
KeepAlive: true,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
MaxConnections: 100,
|
||||
RateLimit: RateLimitConfig{
|
||||
Enabled: true,
|
||||
RequestsPerSecond: 50,
|
||||
BurstSize: 100,
|
||||
WindowSize: 1 * time.Second,
|
||||
},
|
||||
DNSServer: DNSServerConfig{
|
||||
ListenPort: 5353,
|
||||
BackupServer: "8.8.8.8:53",
|
||||
CustomRecords: []DNSRecord{
|
||||
{Name: "app.local", Type: "A", Value: "127.0.0.1", TTL: 300},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(&config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %v", err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(filename, data, 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write config file: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Generated example configuration: %s\n", filename)
|
||||
fmt.Printf("Edit the configuration file and run: ./teleport -config %s\n", filename)
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateStrongEncryptionKey generates a cryptographically secure encryption key
|
||||
func generateStrongEncryptionKey() (string, error) {
|
||||
// Generate 32 random bytes (256 bits) for a strong encryption key
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", fmt.Errorf("failed to generate random key: %v", err)
|
||||
}
|
||||
|
||||
// Convert to hexadecimal string for easy copying
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
363
pkg/config/config_test.go
Normal file
363
pkg/config/config_test.go
Normal file
@@ -0,0 +1,363 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLoadConfig(t *testing.T) {
|
||||
// Create a temporary config file
|
||||
tempDir := t.TempDir()
|
||||
configFile := filepath.Join(tempDir, "test-config.yaml")
|
||||
|
||||
configContent := `
|
||||
instance_id: test-instance
|
||||
listen_address: 127.0.0.1:8080
|
||||
remote_address: ""
|
||||
ports:
|
||||
- tcp://localhost:22
|
||||
- tcp://localhost:80
|
||||
encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03
|
||||
keep_alive: true
|
||||
read_timeout: 30s
|
||||
write_timeout: 30s
|
||||
dns_server:
|
||||
enabled: true
|
||||
listen_port: 5353
|
||||
backup_server: 8.8.8.8:53
|
||||
custom_records:
|
||||
- name: test.local
|
||||
type: A
|
||||
value: 192.168.1.100
|
||||
ttl: 300
|
||||
`
|
||||
|
||||
err := os.WriteFile(configFile, []byte(configContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
// Test loading config
|
||||
config, err := LoadConfig(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Verify config values
|
||||
if config.InstanceID != "test-instance" {
|
||||
t.Errorf("Expected InstanceID 'test-instance', got '%s'", config.InstanceID)
|
||||
}
|
||||
|
||||
if config.ListenAddress != "127.0.0.1:8080" {
|
||||
t.Errorf("Expected ListenAddress '127.0.0.1:8080', got '%s'", config.ListenAddress)
|
||||
}
|
||||
|
||||
if config.RemoteAddress != "" {
|
||||
t.Errorf("Expected empty RemoteAddress, got '%s'", config.RemoteAddress)
|
||||
}
|
||||
|
||||
if len(config.Ports) != 2 {
|
||||
t.Errorf("Expected 2 ports, got %d", len(config.Ports))
|
||||
}
|
||||
|
||||
if config.Ports[0].LocalPort != 22 || config.Ports[0].RemotePort != 22 || config.Ports[0].Protocol != "tcp" || config.Ports[0].TargetHost != "localhost" {
|
||||
t.Error("First port rule is incorrect")
|
||||
}
|
||||
|
||||
if config.EncryptionKey != "a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03" {
|
||||
t.Errorf("Expected EncryptionKey 'a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03', got '%s'", config.EncryptionKey)
|
||||
}
|
||||
|
||||
if !config.KeepAlive {
|
||||
t.Error("Expected KeepAlive to be true")
|
||||
}
|
||||
|
||||
if config.ReadTimeout != 30*time.Second {
|
||||
t.Errorf("Expected ReadTimeout 30s, got %v", config.ReadTimeout)
|
||||
}
|
||||
|
||||
if config.WriteTimeout != 30*time.Second {
|
||||
t.Errorf("Expected WriteTimeout 30s, got %v", config.WriteTimeout)
|
||||
}
|
||||
|
||||
if !config.DNSServer.Enabled {
|
||||
t.Error("Expected DNS server to be enabled")
|
||||
}
|
||||
|
||||
if config.DNSServer.ListenPort != 5353 {
|
||||
t.Errorf("Expected DNS listen port 5353, got %d", config.DNSServer.ListenPort)
|
||||
}
|
||||
|
||||
if len(config.DNSServer.CustomRecords) != 1 {
|
||||
t.Errorf("Expected 1 custom DNS record, got %d", len(config.DNSServer.CustomRecords))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigDefaults(t *testing.T) {
|
||||
// Create a minimal config file
|
||||
tempDir := t.TempDir()
|
||||
configFile := filepath.Join(tempDir, "minimal-config.yaml")
|
||||
|
||||
configContent := `
|
||||
instance_id: test-instance
|
||||
listen_address: 127.0.0.1:8080
|
||||
ports:
|
||||
- tcp://localhost:22
|
||||
encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03
|
||||
`
|
||||
|
||||
err := os.WriteFile(configFile, []byte(configContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
config, err := LoadConfig(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Check default values
|
||||
if config.ReadTimeout != 30*time.Second {
|
||||
t.Errorf("Expected default ReadTimeout 30s, got %v", config.ReadTimeout)
|
||||
}
|
||||
|
||||
if config.WriteTimeout != 30*time.Second {
|
||||
t.Errorf("Expected default WriteTimeout 30s, got %v", config.WriteTimeout)
|
||||
}
|
||||
|
||||
if config.DNSServer.ListenPort != 5353 {
|
||||
t.Errorf("Expected default DNS listen port 5353, got %d", config.DNSServer.ListenPort)
|
||||
}
|
||||
|
||||
if config.DNSServer.BackupServer != "8.8.8.8:53" {
|
||||
t.Errorf("Expected default backup server '8.8.8.8:53', got '%s'", config.DNSServer.BackupServer)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
listenAddress string
|
||||
remoteAddress string
|
||||
expectedMode string
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "Server mode",
|
||||
listenAddress: "127.0.0.1:8080",
|
||||
remoteAddress: "",
|
||||
expectedMode: "server",
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "Client mode",
|
||||
listenAddress: "",
|
||||
remoteAddress: "127.0.0.1:8080",
|
||||
expectedMode: "client",
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "Both addresses set - error",
|
||||
listenAddress: "127.0.0.1:8080",
|
||||
remoteAddress: "127.0.0.1:8080",
|
||||
expectedMode: "",
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "Neither address set - error",
|
||||
listenAddress: "",
|
||||
remoteAddress: "",
|
||||
expectedMode: "",
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := &Config{
|
||||
ListenAddress: tt.listenAddress,
|
||||
RemoteAddress: tt.remoteAddress,
|
||||
}
|
||||
|
||||
mode, err := DetectMode(config)
|
||||
|
||||
if tt.expectedError {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if mode != tt.expectedMode {
|
||||
t.Errorf("Expected mode '%s', got '%s'", tt.expectedMode, mode)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigFileNotFound(t *testing.T) {
|
||||
_, err := LoadConfig("nonexistent-config.yaml")
|
||||
if err == nil {
|
||||
t.Error("Expected error for nonexistent config file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigInvalidYAML(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
configFile := filepath.Join(tempDir, "invalid-config.yaml")
|
||||
|
||||
// Write invalid YAML
|
||||
err := os.WriteFile(configFile, []byte("invalid: yaml: content: ["), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
_, err = LoadConfig(configFile)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid YAML")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortRuleURLFormat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config string
|
||||
expected PortRule
|
||||
}{
|
||||
{
|
||||
name: "Server format with localhost",
|
||||
config: `
|
||||
instance_id: test
|
||||
listen_address: :8080
|
||||
ports:
|
||||
- tcp://localhost:80
|
||||
encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03
|
||||
`,
|
||||
expected: PortRule{
|
||||
LocalPort: 80,
|
||||
RemotePort: 80,
|
||||
Protocol: "tcp",
|
||||
TargetHost: "localhost",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Server format with remote host",
|
||||
config: `
|
||||
instance_id: test
|
||||
listen_address: :8080
|
||||
ports:
|
||||
- tcp://server-a:22
|
||||
encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03
|
||||
`,
|
||||
expected: PortRule{
|
||||
LocalPort: 22,
|
||||
RemotePort: 22,
|
||||
Protocol: "tcp",
|
||||
TargetHost: "server-a",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Client format",
|
||||
config: `
|
||||
instance_id: test
|
||||
remote_address: localhost:8080
|
||||
ports:
|
||||
- tcp://80:8080
|
||||
encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03
|
||||
`,
|
||||
expected: PortRule{
|
||||
LocalPort: 8080,
|
||||
RemotePort: 80,
|
||||
Protocol: "tcp",
|
||||
TargetHost: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "UDP format",
|
||||
config: `
|
||||
instance_id: test
|
||||
listen_address: :8080
|
||||
ports:
|
||||
- udp://dns-server:53
|
||||
encryption_key: a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03
|
||||
`,
|
||||
expected: PortRule{
|
||||
LocalPort: 53,
|
||||
RemotePort: 53,
|
||||
Protocol: "udp",
|
||||
TargetHost: "dns-server",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
configFile := filepath.Join(tempDir, "test-config.yaml")
|
||||
|
||||
err := os.WriteFile(configFile, []byte(tt.config), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
config, err := LoadConfig(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
if len(config.Ports) != 1 {
|
||||
t.Fatalf("Expected 1 port, got %d", len(config.Ports))
|
||||
}
|
||||
|
||||
port := config.Ports[0]
|
||||
if port.LocalPort != tt.expected.LocalPort {
|
||||
t.Errorf("Expected LocalPort %d, got %d", tt.expected.LocalPort, port.LocalPort)
|
||||
}
|
||||
if port.RemotePort != tt.expected.RemotePort {
|
||||
t.Errorf("Expected RemotePort %d, got %d", tt.expected.RemotePort, port.RemotePort)
|
||||
}
|
||||
if port.Protocol != tt.expected.Protocol {
|
||||
t.Errorf("Expected Protocol %s, got %s", tt.expected.Protocol, port.Protocol)
|
||||
}
|
||||
if port.TargetHost != tt.expected.TargetHost {
|
||||
t.Errorf("Expected TargetHost %s, got %s", tt.expected.TargetHost, port.TargetHost)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortRuleOldFormatFails(t *testing.T) {
|
||||
// Test that old object format now fails
|
||||
tempDir := t.TempDir()
|
||||
configFile := filepath.Join(tempDir, "old-format-config.yaml")
|
||||
|
||||
configContent := `
|
||||
instance_id: test
|
||||
listen_address: :8080
|
||||
ports:
|
||||
- local_port: 22
|
||||
remote_port: 22
|
||||
protocol: tcp
|
||||
encryption_key: test-key
|
||||
`
|
||||
|
||||
err := os.WriteFile(configFile, []byte(configContent), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
_, err = LoadConfig(configFile)
|
||||
if err == nil {
|
||||
t.Error("Expected error for old format, but got none")
|
||||
}
|
||||
|
||||
// Check that the error message mentions the expected format
|
||||
if !strings.Contains(err.Error(), "protocol://") {
|
||||
t.Errorf("Expected error message to mention 'protocol://', got: %v", err)
|
||||
}
|
||||
}
|
||||
329
pkg/dns/dns.go
Normal file
329
pkg/dns/dns.go
Normal file
@@ -0,0 +1,329 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"teleport/pkg/config"
|
||||
"teleport/pkg/logger"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// StartDNSServer starts the built-in DNS server using miekg/dns
|
||||
func StartDNSServer(cfg *config.Config) {
|
||||
if !cfg.DNSServer.Enabled {
|
||||
return
|
||||
}
|
||||
|
||||
// Create DNS server
|
||||
server := &dns.Server{
|
||||
Addr: fmt.Sprintf(":%d", cfg.DNSServer.ListenPort),
|
||||
Net: "udp",
|
||||
}
|
||||
|
||||
// Set up handler
|
||||
dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
handleDNSQuery(w, r, cfg)
|
||||
})
|
||||
|
||||
logger.WithField("port", cfg.DNSServer.ListenPort).Info("DNS server started")
|
||||
|
||||
// Start server
|
||||
if err := server.ListenAndServe(); err != nil {
|
||||
logger.WithField("error", err).Error("Failed to start DNS server")
|
||||
}
|
||||
}
|
||||
|
||||
// handleDNSQuery handles DNS queries using miekg/dns
|
||||
func handleDNSQuery(w dns.ResponseWriter, r *dns.Msg, cfg *config.Config) {
|
||||
// Check if we have custom records for this query
|
||||
response := checkCustomRecords(r, cfg)
|
||||
if response == nil {
|
||||
// Forward to backup DNS server
|
||||
response = forwardToBackupDNS(r, cfg)
|
||||
if response == nil {
|
||||
// Send error response
|
||||
response = new(dns.Msg)
|
||||
response.SetRcode(r, dns.RcodeServerFailure)
|
||||
}
|
||||
}
|
||||
|
||||
// Send response
|
||||
w.WriteMsg(response)
|
||||
}
|
||||
|
||||
// checkCustomRecords checks if we have custom records for the query
|
||||
func checkCustomRecords(query *dns.Msg, cfg *config.Config) *dns.Msg {
|
||||
if len(query.Question) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
question := query.Question[0]
|
||||
questionName := strings.ToLower(question.Name)
|
||||
|
||||
// Validate question name length and format
|
||||
if len(questionName) > 253 {
|
||||
logger.WithField("name", questionName).Warn("DNS query name too long, ignoring")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Basic character validation for DNS name
|
||||
for _, c := range questionName {
|
||||
if c < 32 || c > 126 {
|
||||
logger.WithField("name", questionName).Warn("DNS query name contains invalid characters, ignoring")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Look for matching custom records
|
||||
var answers []dns.RR
|
||||
for _, record := range cfg.DNSServer.CustomRecords {
|
||||
if strings.ToLower(record.Name) == questionName && getRecordType(record.Type) == question.Qtype {
|
||||
answer := createDNSRecord(record, questionName)
|
||||
if answer != nil {
|
||||
answers = append(answers, answer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(answers) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create response
|
||||
response := new(dns.Msg)
|
||||
response.SetReply(query)
|
||||
response.Authoritative = true
|
||||
response.Answer = answers
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
// getRecordType converts string record type to DNS type code
|
||||
func getRecordType(recordType string) uint16 {
|
||||
switch strings.ToUpper(recordType) {
|
||||
case "A":
|
||||
return dns.TypeA
|
||||
case "AAAA":
|
||||
return dns.TypeAAAA
|
||||
case "CNAME":
|
||||
return dns.TypeCNAME
|
||||
case "MX":
|
||||
return dns.TypeMX
|
||||
case "TXT":
|
||||
return dns.TypeTXT
|
||||
case "NS":
|
||||
return dns.TypeNS
|
||||
case "SRV":
|
||||
return dns.TypeSRV
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// createDNSRecord creates a DNS resource record from a custom record
|
||||
func createDNSRecord(record config.DNSRecord, name string) dns.RR {
|
||||
// Validate record value length
|
||||
if len(record.Value) > 1024 {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"type": record.Type,
|
||||
"name": name,
|
||||
"value": record.Value,
|
||||
}).Warn("DNS record value too long, ignoring")
|
||||
return nil
|
||||
}
|
||||
|
||||
switch strings.ToUpper(record.Type) {
|
||||
case "A":
|
||||
// IPv4 address
|
||||
ip := net.ParseIP(record.Value)
|
||||
if ip == nil || ip.To4() == nil {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"type": record.Type,
|
||||
"name": name,
|
||||
"value": record.Value,
|
||||
}).Warn("Invalid IPv4 address in DNS record, ignoring")
|
||||
return nil
|
||||
}
|
||||
return &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: name,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: record.TTL,
|
||||
},
|
||||
A: ip.To4(),
|
||||
}
|
||||
|
||||
case "AAAA":
|
||||
// IPv6 address
|
||||
ip := net.ParseIP(record.Value)
|
||||
if ip == nil || ip.To16() == nil {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"type": record.Type,
|
||||
"name": name,
|
||||
"value": record.Value,
|
||||
}).Warn("Invalid IPv6 address in DNS record, ignoring")
|
||||
return nil
|
||||
}
|
||||
return &dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: name,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: record.TTL,
|
||||
},
|
||||
AAAA: ip.To16(),
|
||||
}
|
||||
|
||||
case "CNAME":
|
||||
// Canonical name
|
||||
// Validate CNAME target length
|
||||
if len(record.Value) > 253 {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"type": record.Type,
|
||||
"name": name,
|
||||
"value": record.Value,
|
||||
}).Warn("CNAME target too long, ignoring")
|
||||
return nil
|
||||
}
|
||||
return &dns.CNAME{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: name,
|
||||
Rrtype: dns.TypeCNAME,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: record.TTL,
|
||||
},
|
||||
Target: dns.Fqdn(record.Value),
|
||||
}
|
||||
|
||||
case "MX":
|
||||
// Mail exchange
|
||||
// Validate MX target length
|
||||
if len(record.Value) > 253 {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"type": record.Type,
|
||||
"name": name,
|
||||
"value": record.Value,
|
||||
}).Warn("MX target too long, ignoring")
|
||||
return nil
|
||||
}
|
||||
return &dns.MX{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: name,
|
||||
Rrtype: dns.TypeMX,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: record.TTL,
|
||||
},
|
||||
Preference: record.Priority,
|
||||
Mx: dns.Fqdn(record.Value),
|
||||
}
|
||||
|
||||
case "TXT":
|
||||
// Text record
|
||||
// Validate TXT record length (RFC 1035 limit)
|
||||
if len(record.Value) > 255 {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"type": record.Type,
|
||||
"name": name,
|
||||
"value": record.Value,
|
||||
}).Warn("TXT record too long, ignoring")
|
||||
return nil
|
||||
}
|
||||
return &dns.TXT{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: name,
|
||||
Rrtype: dns.TypeTXT,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: record.TTL,
|
||||
},
|
||||
Txt: []string{record.Value},
|
||||
}
|
||||
|
||||
case "NS":
|
||||
// Name server
|
||||
// Validate NS target length
|
||||
if len(record.Value) > 253 {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"type": record.Type,
|
||||
"name": name,
|
||||
"value": record.Value,
|
||||
}).Warn("NS target too long, ignoring")
|
||||
return nil
|
||||
}
|
||||
return &dns.NS{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: name,
|
||||
Rrtype: dns.TypeNS,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: record.TTL,
|
||||
},
|
||||
Ns: dns.Fqdn(record.Value),
|
||||
}
|
||||
|
||||
case "SRV":
|
||||
// Service record
|
||||
// Validate SRV target length
|
||||
if len(record.Value) > 253 {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"type": record.Type,
|
||||
"name": name,
|
||||
"value": record.Value,
|
||||
}).Warn("SRV target too long, ignoring")
|
||||
return nil
|
||||
}
|
||||
return &dns.SRV{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: name,
|
||||
Rrtype: dns.TypeSRV,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: record.TTL,
|
||||
},
|
||||
Priority: record.Priority,
|
||||
Weight: record.Weight,
|
||||
Port: record.Port,
|
||||
Target: dns.Fqdn(record.Value),
|
||||
}
|
||||
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// forwardToBackupDNS forwards the query to the backup DNS server
|
||||
func forwardToBackupDNS(query *dns.Msg, cfg *config.Config) *dns.Msg {
|
||||
// Create DNS client with timeout
|
||||
client := new(dns.Client)
|
||||
client.Net = "udp"
|
||||
client.Timeout = 5 * time.Second // 5 second timeout
|
||||
|
||||
// Forward query to backup server
|
||||
response, _, err := client.Exchange(query, cfg.DNSServer.BackupServer)
|
||||
if err != nil {
|
||||
logger.WithField("error", err).Error("Failed to forward DNS query to backup server")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate response
|
||||
if response == nil {
|
||||
logger.Warn("Received nil response from backup DNS server")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Basic response validation
|
||||
if response.Id != query.Id {
|
||||
logger.Warn("DNS response ID mismatch, potential spoofing attempt")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Limit response size to prevent amplification attacks
|
||||
if len(response.Answer) > 10 {
|
||||
logger.WithField("answer_count", len(response.Answer)).Warn("DNS response has too many answers, truncating")
|
||||
response.Answer = response.Answer[:10]
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
226
pkg/encryption/encryption.go
Normal file
226
pkg/encryption/encryption.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package encryption
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
)
|
||||
|
||||
const (
|
||||
// PBKDF2 parameters for key derivation
|
||||
PBKDF2Iterations = 100000 // OWASP recommended minimum
|
||||
PBKDF2KeyLength = 32 // 256 bits
|
||||
PBKDF2SaltLength = 16 // 128 bits
|
||||
|
||||
// Replay protection parameters
|
||||
MaxPacketAge = 5 * time.Minute // Maximum age for UDP packets
|
||||
NonceWindow = 1000 // Number of nonces to track for replay protection
|
||||
)
|
||||
|
||||
// DeriveKey derives an encryption key from a password using PBKDF2
|
||||
func DeriveKey(password string) []byte {
|
||||
// Use a deterministic salt derived from the password hash for consistent key derivation
|
||||
// This ensures the same password always produces the same key while avoiding rainbow tables
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte(password))
|
||||
passwordHash := hasher.Sum(nil)
|
||||
|
||||
// Create a deterministic salt from the password hash
|
||||
salt := make([]byte, PBKDF2SaltLength)
|
||||
copy(salt, passwordHash[:PBKDF2SaltLength])
|
||||
|
||||
key := pbkdf2.Key([]byte(password), salt, PBKDF2Iterations, PBKDF2KeyLength, sha256.New)
|
||||
return key
|
||||
}
|
||||
|
||||
// DeriveKeyWithSalt derives an encryption key from a password using PBKDF2 with a custom salt
|
||||
func DeriveKeyWithSalt(password string, salt []byte) ([]byte, error) {
|
||||
if len(salt) != PBKDF2SaltLength {
|
||||
return nil, fmt.Errorf("salt length must be %d bytes, got %d", PBKDF2SaltLength, len(salt))
|
||||
}
|
||||
|
||||
key := pbkdf2.Key([]byte(password), salt, PBKDF2Iterations, PBKDF2KeyLength, sha256.New)
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// GenerateSalt generates a cryptographically secure random salt
|
||||
func GenerateSalt() ([]byte, error) {
|
||||
salt := make([]byte, PBKDF2SaltLength)
|
||||
_, err := io.ReadFull(rand.Reader, salt)
|
||||
return salt, err
|
||||
}
|
||||
|
||||
// EncryptData encrypts data using AES-GCM
|
||||
func EncryptData(data []byte, key []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ciphertext := gcm.Seal(nonce, nonce, data, nil)
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// DecryptData decrypts data using AES-GCM
|
||||
func DecryptData(data []byte, key []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(data) < nonceSize {
|
||||
return nil, fmt.Errorf("ciphertext too short")
|
||||
}
|
||||
|
||||
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// ReplayProtection tracks nonces to prevent replay attacks
|
||||
type ReplayProtection struct {
|
||||
nonces map[uint64]time.Time
|
||||
mutex sync.RWMutex
|
||||
maxNonces int // Maximum number of nonces to track
|
||||
}
|
||||
|
||||
// NewReplayProtection creates a new replay protection instance
|
||||
func NewReplayProtection() *ReplayProtection {
|
||||
return &ReplayProtection{
|
||||
nonces: make(map[uint64]time.Time),
|
||||
maxNonces: NonceWindow * 2, // Allow 2x the window size for safety
|
||||
}
|
||||
}
|
||||
|
||||
// IsValidNonce checks if a nonce is valid (not replayed and not too old)
|
||||
func (rp *ReplayProtection) IsValidNonce(nonce uint64, timestamp int64) bool {
|
||||
rp.mutex.Lock()
|
||||
defer rp.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
packetTime := time.Unix(timestamp, 0)
|
||||
|
||||
// Check if packet is too old
|
||||
if now.Sub(packetTime) > MaxPacketAge {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if nonce was already used
|
||||
if _, exists := rp.nonces[nonce]; exists {
|
||||
return false
|
||||
}
|
||||
|
||||
// Add nonce to tracking
|
||||
rp.nonces[nonce] = now
|
||||
|
||||
// Clean up old nonces to prevent memory leaks
|
||||
if len(rp.nonces) > rp.maxNonces {
|
||||
rp.cleanupOldNonces(now)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// cleanupOldNonces removes nonces older than MaxPacketAge
|
||||
func (rp *ReplayProtection) cleanupOldNonces(now time.Time) {
|
||||
for nonce, timestamp := range rp.nonces {
|
||||
if now.Sub(timestamp) > MaxPacketAge {
|
||||
delete(rp.nonces, nonce)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ValidatePacketTimestamp validates that a packet timestamp is within acceptable range
|
||||
func ValidatePacketTimestamp(timestamp int64) bool {
|
||||
now := time.Now().Unix()
|
||||
packetTime := timestamp
|
||||
|
||||
// Check if packet is too old or from the future
|
||||
age := now - packetTime
|
||||
return age >= 0 && age <= int64(MaxPacketAge.Seconds())
|
||||
}
|
||||
|
||||
// ConstantTimeCompare performs a constant-time comparison of two byte slices
|
||||
func ConstantTimeCompare(a, b []byte) bool {
|
||||
return subtle.ConstantTimeCompare(a, b) == 1
|
||||
}
|
||||
|
||||
// ValidateEncryptionKey validates that an encryption key meets security requirements
|
||||
func ValidateEncryptionKey(key string) error {
|
||||
if len(key) < 32 {
|
||||
return fmt.Errorf("encryption key must be at least 32 characters long")
|
||||
}
|
||||
|
||||
// Check for common weak keys
|
||||
weakKeys := []string{
|
||||
"password", "123456", "admin", "test", "default",
|
||||
"your-secure-encryption-key-change-this-to-something-random",
|
||||
"test-encryption-key-12345", "teleport-key", "secret",
|
||||
}
|
||||
|
||||
for _, weak := range weakKeys {
|
||||
if key == weak {
|
||||
return fmt.Errorf("encryption key is too weak, please use a strong random key")
|
||||
}
|
||||
}
|
||||
|
||||
// Check for sufficient entropy (basic check)
|
||||
entropy := calculateEntropy(key)
|
||||
if entropy < 3.5 { // Minimum entropy threshold
|
||||
return fmt.Errorf("encryption key has insufficient entropy, please use a more random key")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateEntropy calculates the Shannon entropy of a string
|
||||
func calculateEntropy(s string) float64 {
|
||||
freq := make(map[rune]int)
|
||||
for _, r := range s {
|
||||
freq[r]++
|
||||
}
|
||||
|
||||
entropy := 0.0
|
||||
length := float64(len(s))
|
||||
|
||||
for _, count := range freq {
|
||||
p := float64(count) / length
|
||||
entropy -= p * log2(p)
|
||||
}
|
||||
|
||||
return entropy
|
||||
}
|
||||
|
||||
// log2 calculates log base 2
|
||||
func log2(x float64) float64 {
|
||||
return math.Log2(x)
|
||||
}
|
||||
242
pkg/encryption/encryption_test.go
Normal file
242
pkg/encryption/encryption_test.go
Normal file
@@ -0,0 +1,242 @@
|
||||
package encryption
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDeriveKey(t *testing.T) {
|
||||
password := "test-password"
|
||||
key := DeriveKey(password)
|
||||
|
||||
if len(key) != 32 {
|
||||
t.Errorf("Expected key length 32, got %d", len(key))
|
||||
}
|
||||
|
||||
// Test that same password produces same key
|
||||
key2 := DeriveKey(password)
|
||||
if string(key) != string(key2) {
|
||||
t.Error("Same password should produce same key")
|
||||
}
|
||||
|
||||
// Test that different passwords produce different keys
|
||||
key3 := DeriveKey("different-password")
|
||||
if string(key) == string(key3) {
|
||||
t.Error("Different passwords should produce different keys")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecrypt(t *testing.T) {
|
||||
key := DeriveKey("test-key")
|
||||
originalData := []byte("Hello, World! This is a test message.")
|
||||
|
||||
// Test encryption
|
||||
encryptedData, err := EncryptData(originalData, key)
|
||||
if err != nil {
|
||||
t.Fatalf("Encryption failed: %v", err)
|
||||
}
|
||||
|
||||
if len(encryptedData) <= len(originalData) {
|
||||
t.Error("Encrypted data should be longer than original data (due to nonce)")
|
||||
}
|
||||
|
||||
// Test decryption
|
||||
decryptedData, err := DecryptData(encryptedData, key)
|
||||
if err != nil {
|
||||
t.Fatalf("Decryption failed: %v", err)
|
||||
}
|
||||
|
||||
if string(decryptedData) != string(originalData) {
|
||||
t.Errorf("Decrypted data doesn't match original. Expected: %s, Got: %s",
|
||||
string(originalData), string(decryptedData))
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecryptEmptyData(t *testing.T) {
|
||||
key := DeriveKey("test-key")
|
||||
originalData := []byte("")
|
||||
|
||||
encryptedData, err := EncryptData(originalData, key)
|
||||
if err != nil {
|
||||
t.Fatalf("Encryption of empty data failed: %v", err)
|
||||
}
|
||||
|
||||
decryptedData, err := DecryptData(encryptedData, key)
|
||||
if err != nil {
|
||||
t.Fatalf("Decryption of empty data failed: %v", err)
|
||||
}
|
||||
|
||||
if len(decryptedData) != 0 {
|
||||
t.Error("Decrypted empty data should be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecryptLargeData(t *testing.T) {
|
||||
key := DeriveKey("test-key")
|
||||
|
||||
// Create a large data block (1MB)
|
||||
originalData := make([]byte, 1024*1024)
|
||||
for i := range originalData {
|
||||
originalData[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
encryptedData, err := EncryptData(originalData, key)
|
||||
if err != nil {
|
||||
t.Fatalf("Encryption of large data failed: %v", err)
|
||||
}
|
||||
|
||||
decryptedData, err := DecryptData(encryptedData, key)
|
||||
if err != nil {
|
||||
t.Fatalf("Decryption of large data failed: %v", err)
|
||||
}
|
||||
|
||||
if len(decryptedData) != len(originalData) {
|
||||
t.Errorf("Decrypted data length mismatch. Expected: %d, Got: %d",
|
||||
len(originalData), len(decryptedData))
|
||||
}
|
||||
|
||||
for i := range originalData {
|
||||
if decryptedData[i] != originalData[i] {
|
||||
t.Errorf("Data mismatch at position %d", i)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrongKeyDecryption(t *testing.T) {
|
||||
key1 := DeriveKey("key1")
|
||||
key2 := DeriveKey("key2")
|
||||
originalData := []byte("test data")
|
||||
|
||||
encryptedData, err := EncryptData(originalData, key1)
|
||||
if err != nil {
|
||||
t.Fatalf("Encryption failed: %v", err)
|
||||
}
|
||||
|
||||
// Try to decrypt with wrong key
|
||||
_, err = DecryptData(encryptedData, key2)
|
||||
if err == nil {
|
||||
t.Error("Decryption with wrong key should fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorruptedDataDecryption(t *testing.T) {
|
||||
key := DeriveKey("test-key")
|
||||
originalData := []byte("test data")
|
||||
|
||||
encryptedData, err := EncryptData(originalData, key)
|
||||
if err != nil {
|
||||
t.Fatalf("Encryption failed: %v", err)
|
||||
}
|
||||
|
||||
// Corrupt the data
|
||||
encryptedData[0] ^= 0xFF
|
||||
|
||||
// Try to decrypt corrupted data
|
||||
_, err = DecryptData(encryptedData, key)
|
||||
if err == nil {
|
||||
t.Error("Decryption of corrupted data should fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateEncryptionKey(t *testing.T) {
|
||||
// Test valid key
|
||||
validKey := "a0e3dd20a761b118ca234160dd8b87230a001e332a97c9cfe3b8b9c99efaae03"
|
||||
if err := ValidateEncryptionKey(validKey); err != nil {
|
||||
t.Errorf("Valid key should pass validation: %v", err)
|
||||
}
|
||||
|
||||
// Test short key
|
||||
shortKey := "short"
|
||||
if err := ValidateEncryptionKey(shortKey); err == nil {
|
||||
t.Error("Short key should fail validation")
|
||||
}
|
||||
|
||||
// Test weak keys
|
||||
weakKeys := []string{
|
||||
"password", "123456", "admin", "test", "default",
|
||||
"your-secure-encryption-key-change-this-to-something-random",
|
||||
"test-encryption-key-12345", "teleport-key", "secret",
|
||||
}
|
||||
|
||||
for _, weakKey := range weakKeys {
|
||||
if err := ValidateEncryptionKey(weakKey); err == nil {
|
||||
t.Errorf("Weak key '%s' should fail validation", weakKey)
|
||||
}
|
||||
}
|
||||
|
||||
// Test low entropy key
|
||||
lowEntropyKey := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||
if err := ValidateEncryptionKey(lowEntropyKey); err == nil {
|
||||
t.Error("Low entropy key should fail validation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplayProtection(t *testing.T) {
|
||||
rp := NewReplayProtection()
|
||||
nonce := uint64(12345)
|
||||
timestamp := time.Now().Unix()
|
||||
|
||||
// First use should be valid
|
||||
if !rp.IsValidNonce(nonce, timestamp) {
|
||||
t.Error("First use of nonce should be valid")
|
||||
}
|
||||
|
||||
// Replay should be invalid
|
||||
if rp.IsValidNonce(nonce, timestamp) {
|
||||
t.Error("Replay of nonce should be invalid")
|
||||
}
|
||||
|
||||
// Different nonce should be valid
|
||||
if !rp.IsValidNonce(nonce+1, timestamp) {
|
||||
t.Error("Different nonce should be valid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePacketTimestamp(t *testing.T) {
|
||||
now := time.Now().Unix()
|
||||
|
||||
// Current timestamp should be valid
|
||||
if !ValidatePacketTimestamp(now) {
|
||||
t.Error("Current timestamp should be valid")
|
||||
}
|
||||
|
||||
// Recent timestamp should be valid
|
||||
recent := now - 60 // 1 minute ago
|
||||
if !ValidatePacketTimestamp(recent) {
|
||||
t.Error("Recent timestamp should be valid")
|
||||
}
|
||||
|
||||
// Old timestamp should be invalid
|
||||
old := now - int64(MaxPacketAge.Seconds()) - 1
|
||||
if ValidatePacketTimestamp(old) {
|
||||
t.Error("Old timestamp should be invalid")
|
||||
}
|
||||
|
||||
// Future timestamp should be invalid
|
||||
future := now + 3600 // 1 hour in future
|
||||
if ValidatePacketTimestamp(future) {
|
||||
t.Error("Future timestamp should be invalid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstantTimeCompare(t *testing.T) {
|
||||
a := []byte("test")
|
||||
b := []byte("test")
|
||||
c := []byte("different")
|
||||
|
||||
if !ConstantTimeCompare(a, b) {
|
||||
t.Error("Identical byte slices should compare equal")
|
||||
}
|
||||
|
||||
if ConstantTimeCompare(a, c) {
|
||||
t.Error("Different byte slices should not compare equal")
|
||||
}
|
||||
|
||||
// Test with empty slices
|
||||
empty1 := []byte{}
|
||||
empty2 := []byte{}
|
||||
if !ConstantTimeCompare(empty1, empty2) {
|
||||
t.Error("Empty slices should compare equal")
|
||||
}
|
||||
}
|
||||
269
pkg/logger/logger.go
Normal file
269
pkg/logger/logger.go
Normal file
@@ -0,0 +1,269 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
Log *logrus.Logger
|
||||
once sync.Once
|
||||
mu sync.RWMutex
|
||||
)
|
||||
|
||||
// Config holds logging configuration
|
||||
type Config struct {
|
||||
Level string `yaml:"level"` // debug, info, warn, error
|
||||
Format string `yaml:"format"` // json, text
|
||||
File string `yaml:"file"` // log file path (empty for stdout)
|
||||
MaxSize int `yaml:"max_size"` // max log file size in MB
|
||||
MaxBackups int `yaml:"max_backups"` // max number of backup files
|
||||
MaxAge int `yaml:"max_age"` // max age of backup files in days
|
||||
Compress bool `yaml:"compress"` // compress backup files
|
||||
}
|
||||
|
||||
// Init initializes the global logger with the given configuration
|
||||
func Init(config Config) error {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
Log = logrus.New()
|
||||
|
||||
// Set log level
|
||||
level, err := logrus.ParseLevel(config.Level)
|
||||
if err != nil {
|
||||
level = logrus.InfoLevel
|
||||
}
|
||||
Log.SetLevel(level)
|
||||
|
||||
// Set log format with sanitization
|
||||
switch config.Format {
|
||||
case "json":
|
||||
Log.SetFormatter(&SanitizedJSONFormatter{
|
||||
TimestampFormat: "2006-01-02 15:04:05",
|
||||
})
|
||||
default:
|
||||
Log.SetFormatter(&SanitizedTextFormatter{
|
||||
FullTimestamp: true,
|
||||
TimestampFormat: "2006-01-02 15:04:05",
|
||||
})
|
||||
}
|
||||
|
||||
// Set output
|
||||
if config.File != "" {
|
||||
// Ensure directory exists
|
||||
dir := filepath.Dir(config.File)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Open log file with secure permissions (owner read/write only)
|
||||
file, err := os.OpenFile(config.File, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set output to both file and stdout
|
||||
Log.SetOutput(io.MultiWriter(file, os.Stdout))
|
||||
} else {
|
||||
Log.SetOutput(os.Stdout)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLogger returns the global logger instance
|
||||
func GetLogger() *logrus.Logger {
|
||||
mu.RLock()
|
||||
if Log != nil {
|
||||
mu.RUnlock()
|
||||
return Log
|
||||
}
|
||||
mu.RUnlock()
|
||||
|
||||
// Initialize with default config if not already initialized
|
||||
once.Do(func() {
|
||||
Init(Config{
|
||||
Level: "info",
|
||||
Format: "text",
|
||||
File: "",
|
||||
})
|
||||
})
|
||||
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
return Log
|
||||
}
|
||||
|
||||
// Helper functions for common logging operations
|
||||
func Debug(args ...interface{}) {
|
||||
GetLogger().Debug(args...)
|
||||
}
|
||||
|
||||
func Debugf(format string, args ...interface{}) {
|
||||
GetLogger().Debugf(format, args...)
|
||||
}
|
||||
|
||||
func Info(args ...interface{}) {
|
||||
GetLogger().Info(args...)
|
||||
}
|
||||
|
||||
func Infof(format string, args ...interface{}) {
|
||||
GetLogger().Infof(format, args...)
|
||||
}
|
||||
|
||||
func Warn(args ...interface{}) {
|
||||
GetLogger().Warn(args...)
|
||||
}
|
||||
|
||||
func Warnf(format string, args ...interface{}) {
|
||||
GetLogger().Warnf(format, args...)
|
||||
}
|
||||
|
||||
func Error(args ...interface{}) {
|
||||
GetLogger().Error(args...)
|
||||
}
|
||||
|
||||
func Errorf(format string, args ...interface{}) {
|
||||
GetLogger().Errorf(format, args...)
|
||||
}
|
||||
|
||||
func Fatal(args ...interface{}) {
|
||||
GetLogger().Fatal(args...)
|
||||
}
|
||||
|
||||
func Fatalf(format string, args ...interface{}) {
|
||||
GetLogger().Fatalf(format, args...)
|
||||
}
|
||||
|
||||
// WithField creates a new logger entry with a field
|
||||
func WithField(key string, value interface{}) *logrus.Entry {
|
||||
return GetLogger().WithField(key, value)
|
||||
}
|
||||
|
||||
// WithFields creates a new logger entry with multiple fields
|
||||
func WithFields(fields logrus.Fields) *logrus.Entry {
|
||||
return GetLogger().WithFields(fields)
|
||||
}
|
||||
|
||||
// SanitizedTextFormatter sanitizes log output
|
||||
type SanitizedTextFormatter struct {
|
||||
FullTimestamp bool
|
||||
TimestampFormat string
|
||||
}
|
||||
|
||||
// Format formats the log entry with sanitization
|
||||
func (f *SanitizedTextFormatter) Format(entry *logrus.Entry) ([]byte, error) {
|
||||
// Sanitize sensitive data
|
||||
sanitizedEntry := *entry
|
||||
sanitizedEntry.Message = sanitizeString(entry.Message)
|
||||
|
||||
// Sanitize fields
|
||||
sanitizedFields := make(logrus.Fields)
|
||||
for k, v := range entry.Data {
|
||||
sanitizedFields[k] = sanitizeValue(v)
|
||||
}
|
||||
sanitizedEntry.Data = sanitizedFields
|
||||
|
||||
// Use default text formatter
|
||||
formatter := &logrus.TextFormatter{
|
||||
FullTimestamp: f.FullTimestamp,
|
||||
TimestampFormat: f.TimestampFormat,
|
||||
}
|
||||
return formatter.Format(&sanitizedEntry)
|
||||
}
|
||||
|
||||
// SanitizedJSONFormatter sanitizes JSON log output
|
||||
type SanitizedJSONFormatter struct {
|
||||
TimestampFormat string
|
||||
}
|
||||
|
||||
// Format formats the log entry with sanitization
|
||||
func (f *SanitizedJSONFormatter) Format(entry *logrus.Entry) ([]byte, error) {
|
||||
// Sanitize sensitive data
|
||||
sanitizedEntry := *entry
|
||||
sanitizedEntry.Message = sanitizeString(entry.Message)
|
||||
|
||||
// Sanitize fields
|
||||
sanitizedFields := make(logrus.Fields)
|
||||
for k, v := range entry.Data {
|
||||
sanitizedFields[k] = sanitizeValue(v)
|
||||
}
|
||||
sanitizedEntry.Data = sanitizedFields
|
||||
|
||||
// Use default JSON formatter
|
||||
formatter := &logrus.JSONFormatter{
|
||||
TimestampFormat: f.TimestampFormat,
|
||||
}
|
||||
return formatter.Format(&sanitizedEntry)
|
||||
}
|
||||
|
||||
// sanitizeString removes or masks sensitive information
|
||||
func sanitizeString(s string) string {
|
||||
// Remove potential encryption keys (hex strings longer than 32 chars)
|
||||
keyPattern := regexp.MustCompile(`[a-fA-F0-9]{32,}`)
|
||||
s = keyPattern.ReplaceAllString(s, "[REDACTED_KEY]")
|
||||
|
||||
// Remove potential passwords and tokens
|
||||
passwordPattern := regexp.MustCompile(`(?i)(password|pass|pwd|secret|key|token|auth|credential)\s*[:=]\s*[^\s]+`)
|
||||
s = passwordPattern.ReplaceAllString(s, "$1=[REDACTED]")
|
||||
|
||||
// Remove potential API keys and tokens
|
||||
apiKeyPattern := regexp.MustCompile(`(?i)(api[_-]?key|access[_-]?token|bearer[_-]?token)\s*[:=]\s*[^\s]+`)
|
||||
s = apiKeyPattern.ReplaceAllString(s, "$1=[REDACTED]")
|
||||
|
||||
// Remove potential database connection strings
|
||||
dbPattern := regexp.MustCompile(`(?i)(mysql|postgres|mongodb|redis)://[^@]+@[^\s]+`)
|
||||
s = dbPattern.ReplaceAllString(s, "[REDACTED_DB_CONNECTION]")
|
||||
|
||||
// Remove potential JWT tokens
|
||||
jwtPattern := regexp.MustCompile(`eyJ[A-Za-z0-9_-]*\.[A-Za-z0-9_-]*\.[A-Za-z0-9_-]*`)
|
||||
s = jwtPattern.ReplaceAllString(s, "[REDACTED_JWT]")
|
||||
|
||||
// Remove potential credit card numbers
|
||||
ccPattern := regexp.MustCompile(`\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b`)
|
||||
s = ccPattern.ReplaceAllString(s, "[REDACTED_CC]")
|
||||
|
||||
// Remove potential SSNs
|
||||
ssnPattern := regexp.MustCompile(`\b\d{3}-\d{2}-\d{4}\b`)
|
||||
s = ssnPattern.ReplaceAllString(s, "[REDACTED_SSN]")
|
||||
|
||||
// Remove potential IP addresses in sensitive contexts
|
||||
ipPattern := regexp.MustCompile(`\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b`)
|
||||
s = ipPattern.ReplaceAllString(s, "[REDACTED_IP]")
|
||||
|
||||
// Remove potential email addresses
|
||||
emailPattern := regexp.MustCompile(`\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b`)
|
||||
s = emailPattern.ReplaceAllString(s, "[REDACTED_EMAIL]")
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// sanitizeValue sanitizes various value types
|
||||
func sanitizeValue(v interface{}) interface{} {
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
return sanitizeString(val)
|
||||
case []byte:
|
||||
return "[BINARY_DATA]"
|
||||
case map[string]interface{}:
|
||||
sanitized := make(map[string]interface{})
|
||||
for k, v := range val {
|
||||
sanitized[k] = sanitizeValue(v)
|
||||
}
|
||||
return sanitized
|
||||
case []interface{}:
|
||||
sanitized := make([]interface{}, len(val))
|
||||
for i, v := range val {
|
||||
sanitized[i] = sanitizeValue(v)
|
||||
}
|
||||
return sanitized
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
||||
200
pkg/metrics/metrics.go
Normal file
200
pkg/metrics/metrics.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NOTE: These metrics are for internal use only and are not exposed externally.
|
||||
// They are primarily used for testing, debugging, and internal monitoring.
|
||||
// No HTTP endpoints or external APIs are provided to access these metrics.
|
||||
|
||||
// Metrics collects application metrics
|
||||
type Metrics struct {
|
||||
// Connection metrics
|
||||
TotalConnections int64
|
||||
ActiveConnections int64
|
||||
RejectedConnections int64
|
||||
FailedConnections int64
|
||||
|
||||
// Request metrics
|
||||
TotalRequests int64
|
||||
SuccessfulRequests int64
|
||||
FailedRequests int64
|
||||
RateLimitedRequests int64
|
||||
|
||||
// DNS metrics
|
||||
DNSQueries int64
|
||||
DNSCustomRecords int64
|
||||
DNSForwardedQueries int64
|
||||
DNSFailedQueries int64
|
||||
|
||||
// UDP metrics
|
||||
UDPPacketsReceived int64
|
||||
UDPPacketsSent int64
|
||||
UDPReplayDetected int64
|
||||
|
||||
// Performance metrics
|
||||
AverageResponseTime time.Duration
|
||||
MaxResponseTime time.Duration
|
||||
MinResponseTime time.Duration
|
||||
|
||||
// System metrics
|
||||
StartTime time.Time
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// Global metrics instance
|
||||
var globalMetrics = &Metrics{
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
// GetMetrics returns the global metrics instance
|
||||
func GetMetrics() *Metrics {
|
||||
return globalMetrics
|
||||
}
|
||||
|
||||
// IncrementTotalConnections increments the total connections counter
|
||||
func (m *Metrics) IncrementTotalConnections() {
|
||||
atomic.AddInt64(&m.TotalConnections, 1)
|
||||
}
|
||||
|
||||
// IncrementActiveConnections increments the active connections counter
|
||||
func (m *Metrics) IncrementActiveConnections() {
|
||||
atomic.AddInt64(&m.ActiveConnections, 1)
|
||||
}
|
||||
|
||||
// DecrementActiveConnections decrements the active connections counter
|
||||
func (m *Metrics) DecrementActiveConnections() {
|
||||
atomic.AddInt64(&m.ActiveConnections, -1)
|
||||
}
|
||||
|
||||
// IncrementRejectedConnections increments the rejected connections counter
|
||||
func (m *Metrics) IncrementRejectedConnections() {
|
||||
atomic.AddInt64(&m.RejectedConnections, 1)
|
||||
}
|
||||
|
||||
// IncrementFailedConnections increments the failed connections counter
|
||||
func (m *Metrics) IncrementFailedConnections() {
|
||||
atomic.AddInt64(&m.FailedConnections, 1)
|
||||
}
|
||||
|
||||
// IncrementTotalRequests increments the total requests counter
|
||||
func (m *Metrics) IncrementTotalRequests() {
|
||||
atomic.AddInt64(&m.TotalRequests, 1)
|
||||
}
|
||||
|
||||
// IncrementSuccessfulRequests increments the successful requests counter
|
||||
func (m *Metrics) IncrementSuccessfulRequests() {
|
||||
atomic.AddInt64(&m.SuccessfulRequests, 1)
|
||||
}
|
||||
|
||||
// IncrementFailedRequests increments the failed requests counter
|
||||
func (m *Metrics) IncrementFailedRequests() {
|
||||
atomic.AddInt64(&m.FailedRequests, 1)
|
||||
}
|
||||
|
||||
// IncrementRateLimitedRequests increments the rate limited requests counter
|
||||
func (m *Metrics) IncrementRateLimitedRequests() {
|
||||
atomic.AddInt64(&m.RateLimitedRequests, 1)
|
||||
}
|
||||
|
||||
// IncrementDNSQueries increments the DNS queries counter
|
||||
func (m *Metrics) IncrementDNSQueries() {
|
||||
atomic.AddInt64(&m.DNSQueries, 1)
|
||||
}
|
||||
|
||||
// IncrementDNSCustomRecords increments the DNS custom records counter
|
||||
func (m *Metrics) IncrementDNSCustomRecords() {
|
||||
atomic.AddInt64(&m.DNSCustomRecords, 1)
|
||||
}
|
||||
|
||||
// IncrementDNSForwardedQueries increments the DNS forwarded queries counter
|
||||
func (m *Metrics) IncrementDNSForwardedQueries() {
|
||||
atomic.AddInt64(&m.DNSForwardedQueries, 1)
|
||||
}
|
||||
|
||||
// IncrementDNSFailedQueries increments the DNS failed queries counter
|
||||
func (m *Metrics) IncrementDNSFailedQueries() {
|
||||
atomic.AddInt64(&m.DNSFailedQueries, 1)
|
||||
}
|
||||
|
||||
// IncrementUDPPacketsReceived increments the UDP packets received counter
|
||||
func (m *Metrics) IncrementUDPPacketsReceived() {
|
||||
atomic.AddInt64(&m.UDPPacketsReceived, 1)
|
||||
}
|
||||
|
||||
// IncrementUDPPacketsSent increments the UDP packets sent counter
|
||||
func (m *Metrics) IncrementUDPPacketsSent() {
|
||||
atomic.AddInt64(&m.UDPPacketsSent, 1)
|
||||
}
|
||||
|
||||
// IncrementUDPReplayDetected increments the UDP replay detected counter
|
||||
func (m *Metrics) IncrementUDPReplayDetected() {
|
||||
atomic.AddInt64(&m.UDPReplayDetected, 1)
|
||||
}
|
||||
|
||||
// RecordResponseTime records a response time
|
||||
func (m *Metrics) RecordResponseTime(duration time.Duration) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
// Update average response time (simple moving average)
|
||||
if m.AverageResponseTime == 0 {
|
||||
m.AverageResponseTime = duration
|
||||
} else {
|
||||
m.AverageResponseTime = (m.AverageResponseTime + duration) / 2
|
||||
}
|
||||
|
||||
// Update min/max response times
|
||||
if m.MaxResponseTime == 0 || duration > m.MaxResponseTime {
|
||||
m.MaxResponseTime = duration
|
||||
}
|
||||
if m.MinResponseTime == 0 || duration < m.MinResponseTime {
|
||||
m.MinResponseTime = duration
|
||||
}
|
||||
}
|
||||
|
||||
// GetUptime returns the application uptime
|
||||
func (m *Metrics) GetUptime() time.Duration {
|
||||
return time.Since(m.StartTime)
|
||||
}
|
||||
|
||||
// GetStats returns a snapshot of all metrics
|
||||
func (m *Metrics) GetStats() map[string]interface{} {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
return map[string]interface{}{
|
||||
"uptime": m.GetUptime().String(),
|
||||
"connections": map[string]int64{
|
||||
"total": atomic.LoadInt64(&m.TotalConnections),
|
||||
"active": atomic.LoadInt64(&m.ActiveConnections),
|
||||
"rejected": atomic.LoadInt64(&m.RejectedConnections),
|
||||
"failed": atomic.LoadInt64(&m.FailedConnections),
|
||||
},
|
||||
"requests": map[string]int64{
|
||||
"total": atomic.LoadInt64(&m.TotalRequests),
|
||||
"successful": atomic.LoadInt64(&m.SuccessfulRequests),
|
||||
"failed": atomic.LoadInt64(&m.FailedRequests),
|
||||
"rate_limited": atomic.LoadInt64(&m.RateLimitedRequests),
|
||||
},
|
||||
"dns": map[string]int64{
|
||||
"queries": atomic.LoadInt64(&m.DNSQueries),
|
||||
"custom_records": atomic.LoadInt64(&m.DNSCustomRecords),
|
||||
"forwarded": atomic.LoadInt64(&m.DNSForwardedQueries),
|
||||
"failed": atomic.LoadInt64(&m.DNSFailedQueries),
|
||||
},
|
||||
"udp": map[string]int64{
|
||||
"packets_received": atomic.LoadInt64(&m.UDPPacketsReceived),
|
||||
"packets_sent": atomic.LoadInt64(&m.UDPPacketsSent),
|
||||
"replay_detected": atomic.LoadInt64(&m.UDPReplayDetected),
|
||||
},
|
||||
"performance": map[string]interface{}{
|
||||
"avg_response_time": m.AverageResponseTime.String(),
|
||||
"max_response_time": m.MaxResponseTime.String(),
|
||||
"min_response_time": m.MinResponseTime.String(),
|
||||
},
|
||||
}
|
||||
}
|
||||
89
pkg/ratelimit/ratelimit.go
Normal file
89
pkg/ratelimit/ratelimit.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RateLimiter implements a token bucket rate limiter
|
||||
type RateLimiter struct {
|
||||
requestsPerSecond int
|
||||
burstSize int
|
||||
windowSize time.Duration
|
||||
tokens int
|
||||
lastRefill time.Time
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new rate limiter
|
||||
func NewRateLimiter(requestsPerSecond, burstSize int, windowSize time.Duration) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
requestsPerSecond: requestsPerSecond,
|
||||
burstSize: burstSize,
|
||||
tokens: burstSize,
|
||||
lastRefill: time.Now(),
|
||||
windowSize: windowSize,
|
||||
}
|
||||
}
|
||||
|
||||
// Allow checks if a request is allowed under the rate limit
|
||||
func (rl *RateLimiter) Allow() bool {
|
||||
rl.mutex.Lock()
|
||||
defer rl.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Calculate tokens to add based on time elapsed
|
||||
elapsed := now.Sub(rl.lastRefill)
|
||||
tokensToAdd := int(elapsed.Seconds() * float64(rl.requestsPerSecond))
|
||||
|
||||
if tokensToAdd > 0 {
|
||||
rl.tokens += tokensToAdd
|
||||
if rl.tokens > rl.burstSize {
|
||||
rl.tokens = rl.burstSize
|
||||
}
|
||||
rl.lastRefill = now
|
||||
}
|
||||
|
||||
// Check if we have tokens available
|
||||
if rl.tokens > 0 {
|
||||
rl.tokens--
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// AllowWithContext checks if a request is allowed with context cancellation
|
||||
func (rl *RateLimiter) AllowWithContext(ctx context.Context) bool {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
default:
|
||||
return rl.Allow()
|
||||
}
|
||||
}
|
||||
|
||||
// Wait blocks until a request is allowed
|
||||
func (rl *RateLimiter) Wait(ctx context.Context) error {
|
||||
for {
|
||||
if rl.Allow() {
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(rl.windowSize / time.Duration(rl.requestsPerSecond)):
|
||||
// Wait for next token
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetStats returns current rate limiter statistics
|
||||
func (rl *RateLimiter) GetStats() (tokens int, lastRefill time.Time) {
|
||||
rl.mutex.Lock()
|
||||
defer rl.mutex.Unlock()
|
||||
return rl.tokens, rl.lastRefill
|
||||
}
|
||||
33
pkg/types/types.go
Normal file
33
pkg/types/types.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package types
|
||||
|
||||
import "net"
|
||||
|
||||
// PortForwardRequest represents a port forwarding request
|
||||
type PortForwardRequest struct {
|
||||
LocalPort int
|
||||
RemotePort int
|
||||
Protocol string
|
||||
TargetHost string
|
||||
}
|
||||
|
||||
// UDPPacketHeader represents the header for UDP packets
|
||||
type UDPPacketHeader struct {
|
||||
ClientID string
|
||||
PacketID uint64
|
||||
Timestamp int64
|
||||
}
|
||||
|
||||
// TaggedUDPPacket represents a UDP packet with tagging information
|
||||
type TaggedUDPPacket struct {
|
||||
Header UDPPacketHeader
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// UDPRequest represents a UDP forwarding request
|
||||
type UDPRequest struct {
|
||||
LocalPort int
|
||||
RemotePort int
|
||||
Protocol string
|
||||
Data []byte
|
||||
ClientAddr *net.UDPAddr
|
||||
}
|
||||
112
pkg/types/types_test.go
Normal file
112
pkg/types/types_test.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPortForwardRequest(t *testing.T) {
|
||||
req := PortForwardRequest{
|
||||
LocalPort: 8080,
|
||||
RemotePort: 80,
|
||||
Protocol: "tcp",
|
||||
}
|
||||
|
||||
if req.LocalPort != 8080 {
|
||||
t.Errorf("Expected LocalPort 8080, got %d", req.LocalPort)
|
||||
}
|
||||
|
||||
if req.RemotePort != 80 {
|
||||
t.Errorf("Expected RemotePort 80, got %d", req.RemotePort)
|
||||
}
|
||||
|
||||
if req.Protocol != "tcp" {
|
||||
t.Errorf("Expected Protocol 'tcp', got '%s'", req.Protocol)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUDPPacketHeader(t *testing.T) {
|
||||
header := UDPPacketHeader{
|
||||
ClientID: "test-client",
|
||||
PacketID: 12345,
|
||||
Timestamp: 1640995200,
|
||||
}
|
||||
|
||||
if header.ClientID != "test-client" {
|
||||
t.Errorf("Expected ClientID 'test-client', got '%s'", header.ClientID)
|
||||
}
|
||||
|
||||
if header.PacketID != 12345 {
|
||||
t.Errorf("Expected PacketID 12345, got %d", header.PacketID)
|
||||
}
|
||||
|
||||
if header.Timestamp != 1640995200 {
|
||||
t.Errorf("Expected Timestamp 1640995200, got %d", header.Timestamp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaggedUDPPacket(t *testing.T) {
|
||||
header := UDPPacketHeader{
|
||||
ClientID: "test-client",
|
||||
PacketID: 12345,
|
||||
Timestamp: 1640995200,
|
||||
}
|
||||
|
||||
data := []byte("test data")
|
||||
|
||||
packet := TaggedUDPPacket{
|
||||
Header: header,
|
||||
Data: data,
|
||||
}
|
||||
|
||||
if packet.Header.ClientID != "test-client" {
|
||||
t.Errorf("Expected ClientID 'test-client', got '%s'", packet.Header.ClientID)
|
||||
}
|
||||
|
||||
if len(packet.Data) != len(data) {
|
||||
t.Errorf("Expected data length %d, got %d", len(data), len(packet.Data))
|
||||
}
|
||||
|
||||
for i, b := range data {
|
||||
if packet.Data[i] != b {
|
||||
t.Errorf("Data mismatch at position %d", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUDPRequest(t *testing.T) {
|
||||
addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:8080")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to resolve UDP address: %v", err)
|
||||
}
|
||||
|
||||
data := []byte("test data")
|
||||
|
||||
req := UDPRequest{
|
||||
LocalPort: 8080,
|
||||
RemotePort: 80,
|
||||
Protocol: "udp",
|
||||
Data: data,
|
||||
ClientAddr: addr,
|
||||
}
|
||||
|
||||
if req.LocalPort != 8080 {
|
||||
t.Errorf("Expected LocalPort 8080, got %d", req.LocalPort)
|
||||
}
|
||||
|
||||
if req.RemotePort != 80 {
|
||||
t.Errorf("Expected RemotePort 80, got %d", req.RemotePort)
|
||||
}
|
||||
|
||||
if req.Protocol != "udp" {
|
||||
t.Errorf("Expected Protocol 'udp', got '%s'", req.Protocol)
|
||||
}
|
||||
|
||||
if len(req.Data) != len(data) {
|
||||
t.Errorf("Expected data length %d, got %d", len(data), len(req.Data))
|
||||
}
|
||||
|
||||
if req.ClientAddr.String() != "127.0.0.1:8080" {
|
||||
t.Errorf("Expected ClientAddr '127.0.0.1:8080', got '%s'", req.ClientAddr.String())
|
||||
}
|
||||
}
|
||||
47
pkg/version/version.go
Normal file
47
pkg/version/version.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package version
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// These variables are set during build time via ldflags
|
||||
var (
|
||||
Version = "dev"
|
||||
GitCommit = "unknown"
|
||||
BuildDate = "unknown"
|
||||
GoVersion = runtime.Version()
|
||||
Platform = fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH)
|
||||
)
|
||||
|
||||
// Info holds version information
|
||||
type Info struct {
|
||||
Version string `json:"version"`
|
||||
GitCommit string `json:"git_commit"`
|
||||
BuildDate string `json:"build_date"`
|
||||
GoVersion string `json:"go_version"`
|
||||
Platform string `json:"platform"`
|
||||
}
|
||||
|
||||
// Get returns version information
|
||||
func Get() Info {
|
||||
return Info{
|
||||
Version: Version,
|
||||
GitCommit: GitCommit,
|
||||
BuildDate: BuildDate,
|
||||
GoVersion: GoVersion,
|
||||
Platform: Platform,
|
||||
}
|
||||
}
|
||||
|
||||
// String returns a formatted version string
|
||||
func String() string {
|
||||
info := Get()
|
||||
return fmt.Sprintf("teleport version %s (commit: %s, built: %s, go: %s, platform: %s)",
|
||||
info.Version, info.GitCommit, info.BuildDate, info.GoVersion, info.Platform)
|
||||
}
|
||||
|
||||
// Short returns a short version string
|
||||
func Short() string {
|
||||
return fmt.Sprintf("teleport %s", Version)
|
||||
}
|
||||
Reference in New Issue
Block a user