package server

import (
	"context"
	"encoding/binary"
	"fmt"
	"io"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"teleport/pkg/config"
	"teleport/pkg/encryption"
	"teleport/pkg/logger"
	"teleport/pkg/metrics"
	"teleport/pkg/ratelimit"
	"teleport/pkg/types"
)

// TeleportServer represents the server instance. It accepts TCP connections,
// reads one encrypted port-forward request per connection, matches it against
// the configured port rules and forwards traffic to the rule's target over
// TCP or UDP.
type TeleportServer struct {
	config   *config.Config
	listener net.Listener

	// udpListeners holds the sockets opened by startUDPListener, keyed by local
	// port; udpClients holds the target connections registered by
	// handleUDPForward, keyed by client ID. Both are guarded by udpMutex.
	udpListeners map[int]*net.UDPConn
	udpClients   map[string]*net.UDPConn
	udpMutex     sync.RWMutex

	// packetCounter is incremented atomically to tag forwarded UDP packets.
	packetCounter    uint64
	replayProtection *encryption.ReplayProtection

	// ctx is cancelled by Stop; accept/read loops and active tunnels watch it.
	ctx    context.Context
	cancel context.CancelFunc

	connectionSem     chan struct{} // Semaphore for limiting concurrent connections
	activeConnections int64         // Atomic counter for active connections
	maxConnections    int           // Maximum concurrent connections

	// rateLimiter is nil when rate limiting is disabled in the config.
	rateLimiter *ratelimit.RateLimiter

	goroutineSem  chan struct{} // Semaphore for limiting concurrent goroutines
	maxGoroutines int           // Maximum concurrent goroutines

	metrics *metrics.Metrics
}

// NewTeleportServer creates a new teleport server from config.
//
// The connection limit defaults to 1000 when config.MaxConnections is not
// positive; the goroutine limit is twice the connection limit, capped at
// 10000. A rate limiter is created only when config.RateLimit.Enabled is set.
func NewTeleportServer(config *config.Config) *TeleportServer {
	ctx, cancel := context.WithCancel(context.Background())

	maxConnections := 1000 // Default maximum concurrent connections
	if config.MaxConnections > 0 {
		maxConnections = config.MaxConnections
	}

	maxGoroutines := maxConnections * 2 // Allow 2x connections for goroutines
	if maxGoroutines > 10000 {
		maxGoroutines = 10000 // Cap at 10k goroutines
	}

	// Initialize rate limiter if enabled
	var rateLimiter *ratelimit.RateLimiter
	if config.RateLimit.Enabled {
		rateLimiter = ratelimit.NewRateLimiter(
			config.RateLimit.RequestsPerSecond,
			config.RateLimit.BurstSize,
			config.RateLimit.WindowSize,
		)
	}

	metricsInstance := metrics.GetMetrics()

	return &TeleportServer{
		config:           config,
		udpListeners:     make(map[int]*net.UDPConn),
		udpClients:       make(map[string]*net.UDPConn),
		replayProtection: encryption.NewReplayProtection(),
		ctx:              ctx,
		cancel:           cancel,
		connectionSem:    make(chan struct{}, maxConnections),
		maxConnections:   maxConnections,
		rateLimiter:      rateLimiter,
		goroutineSem:     make(chan struct{}, maxGoroutines),
		maxGoroutines:    maxGoroutines,
		metrics:          metricsInstance,
	}
}

// Start opens the TCP listener, launches one UDP listener per "udp" port
// rule, and then accepts TCP connections until Stop is called.
//
// It returns an error if the TCP listener cannot be opened, and otherwise
// returns the context's error (context.Canceled) once the server is stopped.
func (ts *TeleportServer) Start() error {
	logger.WithField("listen_address", ts.config.ListenAddress).Info("Starting teleport server")

	listener, err := net.Listen("tcp", ts.config.ListenAddress)
	if err != nil {
		return fmt.Errorf("failed to start TCP listener: %w", err)
	}
	ts.listener = listener

	// Start UDP listeners for each port
	for _, rule := range ts.config.Ports {
		if rule.Protocol == "udp" {
			go ts.startUDPListener(rule)
		}
	}

	var acceptDelay time.Duration
	for {
		select {
		case <-ts.ctx.Done():
			logger.Debug("Server shutting down...")
			return ts.ctx.Err()
		default:
		}

		conn, err := ts.listener.Accept()
		if err != nil {
			if ts.ctx.Err() != nil {
				return ts.ctx.Err()
			}
			logger.WithField("error", err).Error("Failed to accept connection")

			// Back off on consecutive failures (e.g. EMFILE when the process is
			// out of file descriptors) instead of spinning on Accept.
			if acceptDelay == 0 {
				acceptDelay = 5 * time.Millisecond
			} else {
				acceptDelay *= 2
			}
			if acceptDelay > time.Second {
				acceptDelay = time.Second
			}
			select {
			case <-ts.ctx.Done():
				return ts.ctx.Err()
			case <-time.After(acceptDelay):
			}
			continue
		}
		acceptDelay = 0

		ts.admitConnection(conn)
	}
}

// admitConnection applies the rate limiter, the connection limit and the
// goroutine limit to a freshly accepted connection. It either hands conn to
// a handleConnectionWithLimit goroutine, which owns both semaphore slots from
// then on, or closes conn and records the rejection.
func (ts *TeleportServer) admitConnection(conn net.Conn) {
	// Check rate limiting first
	if ts.rateLimiter != nil && !ts.rateLimiter.Allow() {
		logger.WithField("client", conn.RemoteAddr()).Warn("Rate limit exceeded, rejecting connection")
		ts.metrics.IncrementRateLimitedRequests()
		conn.Close()
		return
	}

	select {
	case ts.connectionSem <- struct{}{}:
	default:
		ts.metrics.IncrementRejectedConnections()
		logger.WithField("max_connections", ts.maxConnections).Warn("Connection limit reached, rejecting connection")
		conn.Close()
		return
	}

	atomic.AddInt64(&ts.activeConnections, 1)
	ts.metrics.IncrementTotalConnections()
	ts.metrics.IncrementActiveConnections()

	select {
	case ts.goroutineSem <- struct{}{}:
		go ts.handleConnectionWithLimit(conn)
	default:
		// Goroutine limit reached: undo the connection accounting above.
		<-ts.connectionSem
		atomic.AddInt64(&ts.activeConnections, -1)
		ts.metrics.DecrementActiveConnections()
		ts.metrics.IncrementRejectedConnections()
		logger.WithField("max_goroutines", ts.maxGoroutines).Warn("Goroutine limit reached, rejecting connection")
		conn.Close()
	}
}

// Stop cancels the server context and closes the TCP listener and every UDP
// listener. Cancelling the context also tears down active tunnels: each
// forwarding handler watches the context and closes its connections.
func (ts *TeleportServer) Stop() {
	logger.Info("Stopping teleport server...")
	ts.cancel()

	if ts.listener != nil {
		ts.listener.Close()
	}

	// Close all UDP listeners
	ts.udpMutex.Lock()
	for _, conn := range ts.udpListeners {
		conn.Close()
	}
	ts.udpMutex.Unlock()
}

// startUDPListener opens a UDP socket on rule.LocalPort and forwards every
// datagram it receives via forwardUDPRequest until the server is stopped.
func (ts *TeleportServer) startUDPListener(rule config.PortRule) {
	addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", rule.LocalPort))
	if err != nil {
		logger.WithFields(map[string]interface{}{
			"port":  rule.LocalPort,
			"error": err,
		}).Error("Failed to resolve UDP address")
		return
	}

	conn, err := net.ListenUDP("udp", addr)
	if err != nil {
		logger.WithFields(map[string]interface{}{
			"port":  rule.LocalPort,
			"error": err,
		}).Error("Failed to start UDP listener")
		return
	}

	// Register under udpMutex and re-check the context while holding it. Stop
	// cancels the context before taking the lock to close listeners, so a
	// listener that loses that race is closed here instead of leaking.
	ts.udpMutex.Lock()
	if ts.ctx.Err() != nil {
		ts.udpMutex.Unlock()
		conn.Close()
		return
	}
	ts.udpListeners[rule.LocalPort] = conn
	ts.udpMutex.Unlock()

	logger.WithField("port", rule.LocalPort).Info("UDP listener started")

	buffer := make([]byte, 4096)
	for {
		n, clientAddr, err := conn.ReadFromUDP(buffer)
		if err != nil {
			// After Stop closes the socket every read fails immediately;
			// exit rather than spin logging the same error forever.
			if ts.ctx.Err() != nil {
				return
			}
			logger.WithField("error", err).Error("Failed to read UDP packet")
			continue
		}

		// Copy the payload out: buffer is reused by the next read.
		request := types.UDPRequest{
			LocalPort:  rule.LocalPort,
			RemotePort: rule.RemotePort,
			Protocol:   rule.Protocol,
			Data:       make([]byte, n),
			ClientAddr: clientAddr,
		}
		copy(request.Data, buffer[:n])

		// Forward to all connected clients
		go ts.forwardUDPRequest(request)
	}
}

// forwardUDPRequest wraps request in a tagged packet with a fresh packet ID
// and sends it to every registered UDP client except the originator.
func (ts *TeleportServer) forwardUDPRequest(request types.UDPRequest) {
	ts.udpMutex.RLock()
	defer ts.udpMutex.RUnlock()

	packetID := atomic.AddUint64(&ts.packetCounter, 1)
	taggedPacket := types.TaggedUDPPacket{
		Header: types.UDPPacketHeader{
			ClientID:  fmt.Sprintf("udp-%s", request.ClientAddr.String()),
			PacketID:  packetID,
			Timestamp: time.Now().Unix(),
		},
		Data: request.Data,
	}

	for clientID, clientConn := range ts.udpClients {
		if clientID != taggedPacket.Header.ClientID {
			go ts.sendTaggedUDPPacket(clientConn, taggedPacket)
		}
	}
}

// sendTaggedUDPPacket serializes and encrypts packet, then writes it to
// clientConn. Failures are logged at debug level and otherwise dropped.
func (ts *TeleportServer) sendTaggedUDPPacket(clientConn *net.UDPConn, packet types.TaggedUDPPacket) {
	data, err := ts.serializeTaggedUDPPacket(packet)
	if err != nil {
		logger.WithField("error", err).Debug("Failed to serialize tagged UDP packet")
		return
	}

	key := encryption.DeriveKey(ts.config.EncryptionKey)
	encryptedData, err := encryption.EncryptData(data, key)
	if err != nil {
		logger.WithField("error", err).Debug("Failed to encrypt UDP packet")
		return
	}

	if _, err = clientConn.Write(encryptedData); err != nil {
		logger.WithField("error", err).Debug("Failed to send UDP packet to client")
	}
}

// serializeTaggedUDPPacket encodes packet as
//
//	uint32 len(ClientID) | uint64 PacketID | uint64 Timestamp | ClientID | Data
//
// with all integers big-endian. It is the inverse of
// deserializeTaggedUDPPacket and never returns a non-nil error.
func (ts *TeleportServer) serializeTaggedUDPPacket(packet types.TaggedUDPPacket) ([]byte, error) {
	headerBytes := []byte(packet.Header.ClientID)
	headerLen := len(headerBytes)

	data := make([]byte, 4+8+8+headerLen+len(packet.Data))
	offset := 0

	binary.BigEndian.PutUint32(data[offset:], uint32(headerLen))
	offset += 4

	binary.BigEndian.PutUint64(data[offset:], packet.Header.PacketID)
	offset += 8

	binary.BigEndian.PutUint64(data[offset:], uint64(packet.Header.Timestamp))
	offset += 8

	copy(data[offset:], headerBytes)
	offset += headerLen

	copy(data[offset:], packet.Data)

	return data, nil
}

// handleConnectionWithLimit runs handleConnection for conn and, when it
// returns, closes conn and releases the connection and goroutine semaphore
// slots acquired by admitConnection.
//
// No connection-lifetime deadlines are set here: absolute deadlines would
// kill healthy long-lived tunnels. Instead readRequest bounds the handshake
// (config.ReadTimeout, default 10s) and every forwarded write to the peer is
// bounded by config.WriteTimeout.
func (ts *TeleportServer) handleConnectionWithLimit(conn net.Conn) {
	defer func() {
		conn.Close()
		<-ts.connectionSem // Release connection semaphore
		<-ts.goroutineSem  // Release goroutine semaphore
		atomic.AddInt64(&ts.activeConnections, -1)
		ts.metrics.DecrementActiveConnections()
	}()

	logger.WithFields(map[string]interface{}{
		"client":             conn.RemoteAddr(),
		"active_connections": atomic.LoadInt64(&ts.activeConnections),
	}).Info("New connection")

	ts.handleConnection(conn)
}

// handleConnection reads the single port-forward request sent on conn,
// finds the server rule whose LocalPort equals the request's RemotePort and
// whose protocol matches, and dispatches to the TCP or UDP forwarder.
func (ts *TeleportServer) handleConnection(conn net.Conn) {
	var request types.PortForwardRequest
	if err := ts.readRequest(conn, &request); err != nil {
		logger.WithField("error", err).Error("Failed to read request")
		return
	}

	// The client sends LocalPort (its own port) and RemotePort (the server
	// port to forward to); match RemotePort against the server rules.
	var portRule *config.PortRule
	for _, rule := range ts.config.Ports {
		if rule.LocalPort == request.RemotePort && rule.Protocol == request.Protocol {
			matched := rule // explicit copy: don't keep a pointer to the loop variable
			portRule = &matched
			break
		}
	}

	if portRule == nil {
		logger.WithFields(map[string]interface{}{
			"local_port":  request.LocalPort,
			"remote_port": request.RemotePort,
			"protocol":    request.Protocol,
		}).Warn("No matching port rule found")
		return
	}

	logger.WithFields(map[string]interface{}{
		"protocol": request.Protocol,
		"port":     request.RemotePort,
	}).Info("Handling connection")

	switch request.Protocol {
	case "tcp":
		logger.Debug("Routing to TCP handler")
		ts.handleTCPForward(conn, portRule)
	case "udp":
		logger.Debug("Routing to UDP handler")
		ts.handleUDPForward(conn, portRule)
	}
}

// closeOnShutdown starts a watcher goroutine that closes conns if the server
// context is cancelled before done is closed. Closing the connections
// unblocks any Read/Write the forwarding goroutines are parked in, so Stop
// tears active tunnels down. Callers must close done when they finish so
// the watcher exits.
func (ts *TeleportServer) closeOnShutdown(done <-chan struct{}, conns ...net.Conn) {
	go func() {
		select {
		case <-ts.ctx.Done():
			for _, c := range conns {
				c.Close()
			}
		case <-done:
		}
	}()
}

// writeWithTimeout writes b to conn. When config.WriteTimeout is set, the
// write must complete within that duration, so a stalled peer cannot block
// the forwarding goroutine forever.
func (ts *TeleportServer) writeWithTimeout(conn net.Conn, b []byte) (int, error) {
	if ts.config.WriteTimeout > 0 {
		if err := conn.SetWriteDeadline(time.Now().Add(ts.config.WriteTimeout)); err != nil {
			return 0, err
		}
	}
	return conn.Write(b)
}

// handleTCPForward dials the rule's target (TargetHost, default "localhost",
// on RemotePort) and copies bytes in both directions until either side
// closes or the server stops.
func (ts *TeleportServer) handleTCPForward(clientConn net.Conn, rule *config.PortRule) {
	targetHost := rule.TargetHost
	if targetHost == "" {
		targetHost = "localhost"
	}

	targetAddr := net.JoinHostPort(targetHost, fmt.Sprintf("%d", rule.RemotePort))
	// Dial with the server context so shutdown is not held up by a slow dial.
	var dialer net.Dialer
	targetConn, err := dialer.DialContext(ts.ctx, "tcp", targetAddr)
	if err != nil {
		logger.WithFields(map[string]interface{}{
			"target": targetAddr,
			"error":  err,
		}).Error("Failed to connect to target")
		return
	}
	defer targetConn.Close()

	done := make(chan struct{})
	defer close(done)
	ts.closeOnShutdown(done, clientConn, targetConn)

	logger.WithFields(map[string]interface{}{
		"client": clientConn.RemoteAddr(),
		"target": targetAddr,
	}).Info("TCP forwarding")

	// Each direction closes the opposite connection when it ends, which
	// unblocks the other direction, so wg.Wait always returns.
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		ts.forwardData(clientConn, targetConn)
	}()
	go func() {
		defer wg.Done()
		ts.forwardData(targetConn, clientConn)
	}()
	wg.Wait()
}

// handleUDPForward tunnels UDP traffic between the client's TCP connection
// and the rule's UDP target. Client packets arrive encrypted and tagged;
// they are validated (timestamp and replay checks) and their payload is sent
// to the target. Target responses are tagged with the ID of the request they
// most likely answer, encrypted and written back to the client.
//
// The target connection is registered in udpClients for the tunnel's
// lifetime so forwardUDPRequest can also reach it.
func (ts *TeleportServer) handleUDPForward(clientConn net.Conn, rule *config.PortRule) {
	logger.WithField("client", clientConn.RemoteAddr()).Debug("UDP SERVER: Starting UDP forwarding")

	targetHost := rule.TargetHost
	if targetHost == "" {
		targetHost = "localhost"
	}

	targetAddr := net.JoinHostPort(targetHost, fmt.Sprintf("%d", rule.RemotePort))
	targetConn, err := net.Dial("udp", targetAddr)
	if err != nil {
		logger.WithFields(map[string]interface{}{
			"target": targetAddr,
			"error":  err,
		}).Debug("UDP SERVER: Failed to connect to UDP target")
		return
	}
	defer targetConn.Close()

	udpTarget, ok := targetConn.(*net.UDPConn)
	if !ok {
		logger.WithField("target", targetAddr).Debug("UDP SERVER: Target connection is not a UDP connection")
		return
	}

	logger.WithField("target", targetAddr).Debug("UDP SERVER: Connected to UDP target")

	// Register this client for UDP forwarding, replacing (and closing) any
	// stale connection under the same ID.
	clientID := fmt.Sprintf("tcp-%s", clientConn.RemoteAddr().String())
	ts.udpMutex.Lock()
	if existingConn, exists := ts.udpClients[clientID]; exists {
		existingConn.Close()
	}
	ts.udpClients[clientID] = udpTarget
	ts.udpMutex.Unlock()
	defer func() {
		// Only remove our own registration; a newer tunnel may have
		// replaced it under the same ID.
		ts.udpMutex.Lock()
		if ts.udpClients[clientID] == udpTarget {
			delete(ts.udpClients, clientID)
		}
		ts.udpMutex.Unlock()
	}()

	logger.WithField("client", clientID).Debug("UDP SERVER: UDP forwarding registered")

	// When either direction ends, close both connections so the other
	// direction's blocked Read returns; otherwise wg.Wait would hang and leak
	// this handler together with its semaphore slots.
	var closeOnce sync.Once
	closeBoth := func() {
		closeOnce.Do(func() {
			clientConn.Close()
			targetConn.Close()
		})
	}

	done := make(chan struct{})
	defer close(done)
	ts.closeOnShutdown(done, clientConn, targetConn)

	// The key depends only on configuration; derive it once per tunnel.
	key := encryption.DeriveKey(ts.config.EncryptionKey)

	// Channel to pass packet IDs from requests to responses.
	// Buffered so the request side does not block under high load.
	packetIDChan := make(chan uint64, 1000)

	var wg sync.WaitGroup
	wg.Add(2)

	// Forward packets from client to target.
	go func() {
		defer wg.Done()
		defer closeBoth()
		buffer := make([]byte, 4096)
		for {
			if ts.ctx.Err() != nil {
				logger.Debug("UDP SERVER: Context cancelled in client->target goroutine")
				return
			}

			// NOTE(review): clientConn is a TCP stream, so a single Read is not
			// guaranteed to return exactly one encrypted packet. This assumes
			// the client's writes arrive one per Read; reliable framing would
			// need a length prefix on both client and server.
			n, err := clientConn.Read(buffer)
			if err != nil {
				logger.WithField("error", err).Debug("UDP SERVER: Failed to read from client")
				return
			}

			logger.WithField("bytes_received", n).Debug("UDP SERVER: Received bytes from client")

			decryptedData, err := encryption.DecryptData(buffer[:n], key)
			if err != nil {
				logger.WithField("error", err).Debug("UDP SERVER: Failed to decrypt UDP packet")
				continue
			}

			logger.WithField("decrypted_bytes", len(decryptedData)).Debug("UDP SERVER: Decrypted bytes")

			packet, err := ts.deserializeTaggedUDPPacket(decryptedData)
			if err != nil {
				logger.WithField("error", err).Debug("UDP SERVER: Failed to deserialize tagged UDP packet")
				continue
			}

			if !encryption.ValidatePacketTimestamp(packet.Header.Timestamp) {
				logger.WithFields(map[string]interface{}{
					"packetID":  packet.Header.PacketID,
					"timestamp": packet.Header.Timestamp,
				}).Debug("UDP SERVER: Packet timestamp validation failed")
				continue
			}

			if !ts.replayProtection.IsValidNonce(packet.Header.PacketID, packet.Header.Timestamp) {
				logger.WithFields(map[string]interface{}{
					"packetID":  packet.Header.PacketID,
					"timestamp": packet.Header.Timestamp,
				}).Debug("UDP SERVER: Replay attack detected")
				continue
			}

			logger.WithFields(map[string]interface{}{
				"packetID":    packet.Header.PacketID,
				"data_length": len(packet.Data),
			}).Debug("UDP SERVER: Deserialized packet")

			if _, err = targetConn.Write(packet.Data); err != nil {
				logger.WithFields(map[string]interface{}{
					"packetID": packet.Header.PacketID,
					"error":    err,
				}).Debug("UDP SERVER: Failed to forward UDP packet to target")
				return
			}

			logger.WithField("packetID", packet.Header.PacketID).Debug("UDP SERVER: Forwarded packet to target")

			// Hand the ID to the response side without ever blocking here.
			select {
			case packetIDChan <- packet.Header.PacketID:
				logger.WithField("packetID", packet.Header.PacketID).Debug("UDP SERVER: Sent packet ID to response handler")
			default:
				logger.WithField("packetID", packet.Header.PacketID).Debug("UDP SERVER: Packet ID channel full, dropping packet ID")
			}
		}
	}()

	// Forward responses from target to client.
	go func() {
		defer wg.Done()
		defer closeBoth()
		buffer := make([]byte, 4096)
		for {
			if ts.ctx.Err() != nil {
				logger.Debug("UDP SERVER: Context cancelled in target->client goroutine")
				return
			}

			n, err := targetConn.Read(buffer)
			if err != nil {
				logger.WithField("error", err).Debug("UDP SERVER: Failed to read from target")
				return
			}

			logger.WithField("bytes_received", n).Debug("UDP SERVER: Received bytes from target")

			// Pair the response with the oldest outstanding request ID; fall
			// back to 0 if none arrives within a second.
			var originalPacketID uint64
			select {
			case originalPacketID = <-packetIDChan:
				logger.WithField("packetID", originalPacketID).Debug("UDP SERVER: Got packet ID for response")
			case <-time.After(1 * time.Second):
				logger.Debug("UDP SERVER: Timeout waiting for packet ID, using default")
				originalPacketID = 0
			}

			responsePacket := types.TaggedUDPPacket{
				Header: types.UDPPacketHeader{
					ClientID:  clientID,
					PacketID:  originalPacketID,
					Timestamp: time.Now().Unix(),
				},
				Data: make([]byte, n),
			}
			copy(responsePacket.Data, buffer[:n])

			// Log only the size: the payload is decrypted tunnel traffic and
			// must not end up in log files.
			logger.WithFields(map[string]interface{}{
				"packetID":    originalPacketID,
				"data_length": len(responsePacket.Data),
			}).Debug("UDP SERVER: Created response packet")

			data, err := ts.serializeTaggedUDPPacket(responsePacket)
			if err != nil {
				logger.WithFields(map[string]interface{}{
					"packetID": originalPacketID,
					"error":    err,
				}).Debug("UDP SERVER: Failed to serialize UDP response")
				continue
			}

			encryptedData, err := encryption.EncryptData(data, key)
			if err != nil {
				logger.WithFields(map[string]interface{}{
					"packetID": originalPacketID,
					"error":    err,
				}).Debug("UDP SERVER: Failed to encrypt UDP response")
				continue
			}

			logger.WithFields(map[string]interface{}{
				"packetID":         originalPacketID,
				"encrypted_length": len(encryptedData),
			}).Debug("UDP SERVER: Encrypted response packet")

			if _, err = ts.writeWithTimeout(clientConn, encryptedData); err != nil {
				logger.WithFields(map[string]interface{}{
					"packetID": originalPacketID,
					"error":    err,
				}).Debug("UDP SERVER: Failed to send UDP response to client")
				return
			}

			logger.WithField("packetID", originalPacketID).Debug("UDP SERVER: Successfully sent response packet to client")
		}
	}()

	wg.Wait()
}

// deserializeTaggedUDPPacket decodes the format produced by
// serializeTaggedUDPPacket. It rejects inputs shorter than the 20-byte fixed
// header, larger than 1 MiB, or with a client-ID length of 0 or above 1024
// (and client IDs longer than 256 bytes).
func (ts *TeleportServer) deserializeTaggedUDPPacket(data []byte) (types.TaggedUDPPacket, error) {
	// Minimum size: 4 (headerLen) + 8 (packetID) + 8 (timestamp) = 20 bytes
	if len(data) < 20 {
		return types.TaggedUDPPacket{}, fmt.Errorf("packet data too short")
	}

	// Maximum reasonable packet size (1MB)
	if len(data) > 1024*1024 {
		return types.TaggedUDPPacket{}, fmt.Errorf("packet data too large")
	}

	offset := 0

	if offset+4 > len(data) {
		return types.TaggedUDPPacket{}, fmt.Errorf("insufficient data for header length")
	}
	headerLen := binary.BigEndian.Uint32(data[offset:])
	offset += 4

	if headerLen > 1024 || headerLen == 0 {
		return types.TaggedUDPPacket{}, fmt.Errorf("invalid header length: %d", headerLen)
	}

	if offset+8 > len(data) {
		return types.TaggedUDPPacket{}, fmt.Errorf("insufficient data for packet ID")
	}
	packetID := binary.BigEndian.Uint64(data[offset:])
	offset += 8

	if offset+8 > len(data) {
		return types.TaggedUDPPacket{}, fmt.Errorf("insufficient data for timestamp")
	}
	timestamp := binary.BigEndian.Uint64(data[offset:])
	offset += 8

	if offset+int(headerLen) > len(data) {
		return types.TaggedUDPPacket{}, fmt.Errorf("insufficient data for client ID")
	}
	clientID := string(data[offset : offset+int(headerLen)])
	offset += int(headerLen)

	if len(clientID) == 0 {
		return types.TaggedUDPPacket{}, fmt.Errorf("empty client ID")
	}
	if len(clientID) > 256 {
		return types.TaggedUDPPacket{}, fmt.Errorf("client ID too long")
	}

	if offset > len(data) {
		return types.TaggedUDPPacket{}, fmt.Errorf("data offset exceeds packet length")
	}
	packetData := make([]byte, len(data)-offset)
	copy(packetData, data[offset:])

	return types.TaggedUDPPacket{
		Header: types.UDPPacketHeader{
			ClientID:  clientID,
			PacketID:  packetID,
			Timestamp: int64(timestamp),
		},
		Data: packetData,
	}, nil
}

// forwardData copies bytes from src to dst until a read or write fails or
// the server context is cancelled. When src fails it closes dst, and when
// dst fails it closes src, so the opposite forwarding goroutine also stops.
// Each write is bounded by config.WriteTimeout when configured.
func (ts *TeleportServer) forwardData(src, dst net.Conn) {
	buffer := make([]byte, 4096)
	for {
		if ts.ctx.Err() != nil {
			return
		}

		n, err := src.Read(buffer)
		// Per the io.Reader contract, deliver any bytes read before
		// acting on the error.
		if n > 0 {
			if _, werr := ts.writeWithTimeout(dst, buffer[:n]); werr != nil {
				src.Close()
				return
			}
		}
		if err != nil {
			dst.Close()
			return
		}
	}
}

// readRequest reads one length-prefixed (uint32 big-endian) encrypted
// request from conn, decrypts it and decodes it into request.
//
// The whole handshake must finish within config.ReadTimeout, or 10 seconds
// when that is unset. Without that bound, a client that never sends would
// hold a connection slot forever. The deadline is cleared before returning
// so it does not carry over into the forwarding phase.
func (ts *TeleportServer) readRequest(conn net.Conn, request *types.PortForwardRequest) error {
	timeout := 10 * time.Second
	if ts.config.ReadTimeout > 0 {
		timeout = ts.config.ReadTimeout
	}
	if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
		return fmt.Errorf("failed to set request read deadline: %w", err)
	}
	defer conn.SetReadDeadline(time.Time{})

	var lengthBuf [4]byte
	if _, err := io.ReadFull(conn, lengthBuf[:]); err != nil {
		return fmt.Errorf("failed to read request length: %w", err)
	}
	length := binary.BigEndian.Uint32(lengthBuf[:])

	// Validate request length (prevent oversized allocations)
	if length == 0 {
		return fmt.Errorf("request length cannot be zero")
	}
	if length > 32*1024 {
		return fmt.Errorf("request too large: %d bytes (max 32KB)", length)
	}

	encryptedData := make([]byte, length)
	if _, err := io.ReadFull(conn, encryptedData); err != nil {
		return fmt.Errorf("failed to read request data: %w", err)
	}

	key := encryption.DeriveKey(ts.config.EncryptionKey)
	decryptedData, err := encryption.DecryptData(encryptedData, key)
	if err != nil {
		logger.WithFields(map[string]interface{}{
			"error":                 err,
			"encrypted_data_length": len(encryptedData),
		}).Error("Decryption failed")
		return err
	}

	return ts.deserializeRequest(decryptedData, request)
}

// deserializeRequest decodes a port-forward request laid out as
//
//	uint32 LocalPort | uint32 RemotePort | uint32 len | Protocol | uint32 len | TargetHost
//
// (big-endian). Ports must be in 1..65535, and the protocol must be "tcp" or
// "udp". The target host may be empty; otherwise it is at most 253 bytes of
// printable ASCII.
func (ts *TeleportServer) deserializeRequest(data []byte, request *types.PortForwardRequest) error {
	// Minimum size: 4 (localPort) + 4 (remotePort) + 4 (protocolLen) + 4 (targetHostLen) = 16 bytes
	if len(data) < 16 {
		return fmt.Errorf("request data too short")
	}

	// Maximum reasonable request size (64KB)
	if len(data) > 64*1024 {
		return fmt.Errorf("request data too large")
	}

	offset := 0

	if offset+4 > len(data) {
		return fmt.Errorf("insufficient data for local port")
	}
	request.LocalPort = int(binary.BigEndian.Uint32(data[offset:]))
	offset += 4

	if request.LocalPort < 1 || request.LocalPort > 65535 {
		return fmt.Errorf("invalid local port: %d", request.LocalPort)
	}

	if offset+4 > len(data) {
		return fmt.Errorf("insufficient data for remote port")
	}
	request.RemotePort = int(binary.BigEndian.Uint32(data[offset:]))
	offset += 4

	if request.RemotePort < 1 || request.RemotePort > 65535 {
		return fmt.Errorf("invalid remote port: %d", request.RemotePort)
	}

	if offset+4 > len(data) {
		return fmt.Errorf("insufficient data for protocol length")
	}
	protocolLen := int(binary.BigEndian.Uint32(data[offset:]))
	offset += 4

	if protocolLen < 1 || protocolLen > 64 {
		return fmt.Errorf("invalid protocol length: %d", protocolLen)
	}

	if offset+protocolLen > len(data) {
		return fmt.Errorf("insufficient data for protocol")
	}
	request.Protocol = string(data[offset : offset+protocolLen])
	offset += protocolLen

	if request.Protocol != "tcp" && request.Protocol != "udp" {
		return fmt.Errorf("invalid protocol: %s", request.Protocol)
	}

	if offset+4 > len(data) {
		return fmt.Errorf("insufficient data for target host length")
	}
	targetHostLen := int(binary.BigEndian.Uint32(data[offset:]))
	offset += 4

	if targetHostLen > 256 {
		return fmt.Errorf("target host too long: %d", targetHostLen)
	}

	if offset+targetHostLen > len(data) {
		return fmt.Errorf("insufficient data for target host")
	}
	request.TargetHost = string(data[offset : offset+targetHostLen])

	if request.TargetHost != "" {
		if len(request.TargetHost) > 253 { // RFC 1123 limit
			return fmt.Errorf("target host exceeds maximum length")
		}

		// Reject control characters and anything outside printable ASCII.
		for _, c := range request.TargetHost {
			if c < 32 || c > 126 {
				return fmt.Errorf("invalid character in target host")
			}
		}
	}

	return nil
}