package client import ( "context" "encoding/binary" "fmt" "net" "sync" "sync/atomic" "time" "teleport/pkg/config" "teleport/pkg/dns" "teleport/pkg/encryption" "teleport/pkg/logger" "teleport/pkg/types" ) // TeleportClient represents the client instance type TeleportClient struct { config *config.Config serverConn net.Conn udpListeners map[int]*net.UDPConn udpMutex sync.RWMutex packetCounter uint64 replayProtection *encryption.ReplayProtection ctx context.Context cancel context.CancelFunc connectionPool chan net.Conn maxPoolSize int } // NewTeleportClient creates a new teleport client func NewTeleportClient(config *config.Config) *TeleportClient { ctx, cancel := context.WithCancel(context.Background()) maxPoolSize := 10 // Default connection pool size if config.MaxConnections > 0 { maxPoolSize = config.MaxConnections / 10 // Use 10% of max connections for pool if maxPoolSize < 5 { maxPoolSize = 5 } } return &TeleportClient{ config: config, udpListeners: make(map[int]*net.UDPConn), replayProtection: encryption.NewReplayProtection(), ctx: ctx, cancel: cancel, connectionPool: make(chan net.Conn, maxPoolSize), maxPoolSize: maxPoolSize, } } // Start starts the teleport client func (tc *TeleportClient) Start() error { logger.WithField("remote_address", tc.config.RemoteAddress).Info("Starting teleport client") // Connect to server conn, err := net.Dial("tcp", tc.config.RemoteAddress) if err != nil { return fmt.Errorf("failed to connect to server: %v", err) } tc.serverConn = conn // Start DNS server if enabled if tc.config.DNSServer.Enabled { go dns.StartDNSServer(tc.config) } // Start port forwarding for each port rule for _, rule := range tc.config.Ports { go tc.startPortForwarding(rule) } // Keep the client running with proper error handling <-tc.ctx.Done() logger.Info("Client shutting down...") return tc.ctx.Err() } // startPortForwarding starts port forwarding for a specific rule func (tc *TeleportClient) startPortForwarding(rule config.PortRule) { switch rule.Protocol { case "tcp": tc.startTCPForwarding(rule) case "udp": tc.startUDPForwarding(rule) } } // startTCPForwarding starts TCP port forwarding func (tc *TeleportClient) startTCPForwarding(rule config.PortRule) { listener, err := net.Listen("tcp", fmt.Sprintf(":%d", rule.LocalPort)) if err != nil { logger.WithFields(map[string]interface{}{ "port": rule.LocalPort, "error": err, }).Error("Failed to start TCP listener") return } defer listener.Close() logger.WithFields(map[string]interface{}{ "local_port": rule.LocalPort, "remote_addr": tc.config.RemoteAddress, "remote_port": rule.RemotePort, "protocol": rule.Protocol, }).Info("TCP forwarding started") for { select { case <-tc.ctx.Done(): logger.Debug("Client shutting down...") return default: clientConn, err := listener.Accept() if err != nil { select { case <-tc.ctx.Done(): return default: logger.WithField("error", err).Error("Failed to accept TCP connection") continue } } go tc.handleTCPConnection(clientConn, rule) } } } // Stop stops the teleport client func (tc *TeleportClient) Stop() { logger.Info("Stopping teleport client...") tc.cancel() if tc.serverConn != nil { tc.serverConn.Close() } // Close all connections in the pool for { select { case conn := <-tc.connectionPool: conn.Close() default: goto poolClosed } } poolClosed: // Close all UDP listeners tc.udpMutex.Lock() for _, conn := range tc.udpListeners { conn.Close() } tc.udpMutex.Unlock() } // getConnection gets a connection from the pool or creates a new one func (tc *TeleportClient) getConnection() (net.Conn, error) { // Try to get a healthy connection from the pool for { select { case conn := <-tc.connectionPool: if conn != nil && tc.isConnectionHealthy(conn) { return conn, nil } // Connection is nil or unhealthy, close it and try again if conn != nil { conn.Close() } default: // No connection in pool, create new one break } break } // Create new connection conn, err := net.Dial("tcp", tc.config.RemoteAddress) if err != nil { return nil, fmt.Errorf("failed to create connection: %v", err) } // Set connection timeouts if tc.config.ReadTimeout > 0 { conn.SetReadDeadline(time.Now().Add(tc.config.ReadTimeout)) } if tc.config.WriteTimeout > 0 { conn.SetWriteDeadline(time.Now().Add(tc.config.WriteTimeout)) } return conn, nil } // isConnectionHealthy checks if a connection is still alive and usable func (tc *TeleportClient) isConnectionHealthy(conn net.Conn) bool { if conn == nil { return false } // Try to set a very short read deadline to test the connection // This is a non-blocking way to check if the connection is still alive conn.SetReadDeadline(time.Now().Add(1 * time.Millisecond)) // Try to read one byte (this will fail immediately if connection is dead) one := make([]byte, 1) _, err := conn.Read(one) // Clear the deadline conn.SetReadDeadline(time.Time{}) // If we get an error, the connection is likely dead // We expect to get a timeout error for a healthy connection (since we set a 1ms deadline) if err != nil { // Check if it's a timeout error (which is expected for a healthy connection) if netErr, ok := err.(net.Error); ok && netErr.Timeout() { return true // Connection is healthy (timeout is expected) } // Any other error means the connection is dead return false } // If we actually read data, that's unexpected but the connection is alive return true } // returnConnection returns a connection to the pool func (tc *TeleportClient) returnConnection(conn net.Conn) { if conn == nil { return } // Check if connection is still healthy before returning to pool if !tc.isConnectionHealthy(conn) { conn.Close() return } select { case tc.connectionPool <- conn: // Connection returned to pool case <-tc.ctx.Done(): // Context cancelled, close the connection conn.Close() default: // Pool is full, close the connection conn.Close() } } // startUDPForwarding starts UDP port forwarding func (tc *TeleportClient) startUDPForwarding(rule config.PortRule) { addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", rule.LocalPort)) if err != nil { logger.WithFields(map[string]interface{}{ "port": rule.LocalPort, "error": err, }).Error("Failed to resolve UDP address") return } conn, err := net.ListenUDP("udp", addr) if err != nil { logger.WithFields(map[string]interface{}{ "port": rule.LocalPort, "error": err, }).Error("Failed to start UDP listener") return } tc.udpMutex.Lock() tc.udpListeners[rule.LocalPort] = conn tc.udpMutex.Unlock() logger.WithFields(map[string]interface{}{ "local_port": rule.LocalPort, "remote_addr": tc.config.RemoteAddress, "remote_port": rule.RemotePort, "protocol": rule.Protocol, }).Info("UDP forwarding started") buffer := make([]byte, 4096) for { select { case <-tc.ctx.Done(): logger.Debug("Client context cancelled, stopping UDP forwarding") return default: n, clientAddr, err := conn.ReadFromUDP(buffer) if err != nil { select { case <-tc.ctx.Done(): return default: logger.WithField("error", err).Error("Failed to read UDP packet") continue } } // Create tagged packet with atomic counter packetID := atomic.AddUint64(&tc.packetCounter, 1) taggedPacket := types.TaggedUDPPacket{ Header: types.UDPPacketHeader{ ClientID: tc.config.InstanceID, PacketID: packetID, Timestamp: time.Now().Unix(), }, Data: make([]byte, n), } copy(taggedPacket.Data, buffer[:n]) logger.WithFields(map[string]interface{}{ "client": clientAddr, "data_length": len(taggedPacket.Data), "packetID": taggedPacket.Header.PacketID, }).Debug("UDP packet received") // Send to server and wait for response go tc.sendTaggedUDPPacketWithResponse(taggedPacket, conn, clientAddr) } } } // sendTaggedUDPPacketWithResponse sends a tagged UDP packet to the server and forwards response back func (tc *TeleportClient) sendTaggedUDPPacketWithResponse(packet types.TaggedUDPPacket, udpConn *net.UDPConn, clientAddr *net.UDPAddr) { logger.WithField("packetID", packet.Header.PacketID).Debug("Starting to send UDP packet to server") // Establish a new connection to the server for this UDP packet serverConn, err := net.Dial("tcp", tc.config.RemoteAddress) if err != nil { logger.WithField("error", err).Debug("UDP CLIENT: Failed to connect to server for UDP packet") return } defer serverConn.Close() logger.WithField("packetID", packet.Header.PacketID).Debug("UDP CLIENT: Connected to server, sending port forward request") // Find the UDP port rule to get the correct remote port var remotePort int for _, rule := range tc.config.Ports { if rule.Protocol == "udp" { remotePort = rule.RemotePort break } } // Send port forward request first (like TCP does) request := types.PortForwardRequest{ LocalPort: int(packet.Header.PacketID), // Use packet ID as local port for identification RemotePort: remotePort, // UDP port we want to forward to Protocol: "udp", TargetHost: "", } if err := tc.sendRequestToConnection(serverConn, request); err != nil { logger.WithFields(map[string]interface{}{ "packetID": packet.Header.PacketID, "error": err, }).Debug("UDP CLIENT: Failed to send port forward request") return } logger.WithField("packetID", packet.Header.PacketID).Debug("UDP CLIENT: Port forward request sent, sending packet") // Send UDP packet through the new connection tc.sendTaggedUDPPacketToConnection(serverConn, packet) logger.WithField("packetID", packet.Header.PacketID).Debug("UDP CLIENT: Packet sent, waiting for response") // Wait for response from server and forward it back tc.waitForUDPResponseAndForward(serverConn, packet.Header.PacketID, udpConn, clientAddr) } // sendTaggedUDPPacketToConnection sends a tagged UDP packet to a specific connection func (tc *TeleportClient) sendTaggedUDPPacketToConnection(conn net.Conn, packet types.TaggedUDPPacket) { // Serialize the tagged packet data, err := tc.serializeTaggedUDPPacket(packet) if err != nil { logger.WithFields(map[string]interface{}{ "packetID": packet.Header.PacketID, "error": err, }).Debug("UDP CLIENT: Failed to serialize tagged UDP packet") return } logger.WithFields(map[string]interface{}{ "packetID": packet.Header.PacketID, "data_length": len(data), }).Debug("UDP CLIENT: Serialized packet") // Encrypt the data key := encryption.DeriveKey(tc.config.EncryptionKey) encryptedData, err := encryption.EncryptData(data, key) if err != nil { logger.WithFields(map[string]interface{}{ "packetID": packet.Header.PacketID, "error": err, }).Debug("UDP CLIENT: Failed to encrypt UDP packet") return } logger.WithFields(map[string]interface{}{ "packetID": packet.Header.PacketID, "encrypted_length": len(encryptedData), }).Debug("UDP CLIENT: Encrypted packet") // Send to server _, err = conn.Write(encryptedData) if err != nil { logger.WithFields(map[string]interface{}{ "packetID": packet.Header.PacketID, "error": err, }).Debug("UDP CLIENT: Failed to send UDP packet to server") } else { logger.WithField("packetID", packet.Header.PacketID).Debug("UDP CLIENT: Successfully sent packet to server") } } // waitForUDPResponseAndForward waits for a UDP response from the server and forwards it back func (tc *TeleportClient) waitForUDPResponseAndForward(conn net.Conn, expectedPacketID uint64, udpConn *net.UDPConn, clientAddr *net.UDPAddr) { logger.WithField("packetID", expectedPacketID).Debug("UDP CLIENT: Waiting for response to packet") buffer := make([]byte, 4096) for { select { case <-tc.ctx.Done(): logger.WithField("packetID", expectedPacketID).Debug("UDP CLIENT: Context cancelled while waiting for response") return default: // Set a timeout for reading the response conn.SetReadDeadline(time.Now().Add(5 * time.Second)) n, err := conn.Read(buffer) if err != nil { logger.WithFields(map[string]interface{}{ "packetID": expectedPacketID, "error": err, }).Debug("UDP CLIENT: Failed to read UDP response") return } logger.WithFields(map[string]interface{}{ "packetID": expectedPacketID, "bytes_received": n, }).Debug("UDP CLIENT: Received response bytes") // Decrypt the response key := encryption.DeriveKey(tc.config.EncryptionKey) decryptedData, err := encryption.DecryptData(buffer[:n], key) if err != nil { logger.WithFields(map[string]interface{}{ "packetID": expectedPacketID, "error": err, }).Debug("UDP CLIENT: Failed to decrypt UDP response") continue } // Deserialize the response packet responsePacket, err := tc.deserializeTaggedUDPPacket(decryptedData) if err != nil { logger.WithFields(map[string]interface{}{ "packetID": expectedPacketID, "error": err, }).Debug("UDP CLIENT: Failed to deserialize UDP response") continue } // Validate packet timestamp if !encryption.ValidatePacketTimestamp(responsePacket.Header.Timestamp) { logger.WithFields(map[string]interface{}{ "packetID": responsePacket.Header.PacketID, "timestamp": responsePacket.Header.Timestamp, }).Debug("UDP CLIENT: Response packet timestamp validation failed") continue } // Check for replay attacks if !tc.replayProtection.IsValidNonce(responsePacket.Header.PacketID, responsePacket.Header.Timestamp) { logger.WithFields(map[string]interface{}{ "packetID": responsePacket.Header.PacketID, "timestamp": responsePacket.Header.Timestamp, }).Debug("UDP CLIENT: Replay attack detected in response") continue } logger.WithFields(map[string]interface{}{ "packetID": responsePacket.Header.PacketID, "data_length": len(responsePacket.Data), }).Debug("UDP CLIENT: Deserialized response packet") // Check if this is the response we're waiting for if responsePacket.Header.PacketID == expectedPacketID { // Forward the response back to the original UDP client _, err := udpConn.WriteToUDP(responsePacket.Data, clientAddr) if err != nil { logger.WithFields(map[string]interface{}{ "packetID": responsePacket.Header.PacketID, "error": err, }).Debug("UDP CLIENT: Failed to forward UDP response to client") } else { logger.WithFields(map[string]interface{}{ "packetID": responsePacket.Header.PacketID, "client": clientAddr, "data_length": len(responsePacket.Data), }).Debug("UDP CLIENT: Successfully forwarded UDP response") } return } else { logger.WithFields(map[string]interface{}{ "received_packetID": responsePacket.Header.PacketID, "expected_packetID": expectedPacketID, }).Debug("UDP CLIENT: Received response for different packet") } } } } // serializeTaggedUDPPacket serializes a tagged UDP packet func (tc *TeleportClient) serializeTaggedUDPPacket(packet types.TaggedUDPPacket) ([]byte, error) { // Simple serialization: header length + header + data headerBytes := []byte(packet.Header.ClientID) headerLen := len(headerBytes) data := make([]byte, 4+8+8+headerLen+len(packet.Data)) offset := 0 // Header length binary.BigEndian.PutUint32(data[offset:], uint32(headerLen)) offset += 4 // Packet ID binary.BigEndian.PutUint64(data[offset:], packet.Header.PacketID) offset += 8 // Timestamp binary.BigEndian.PutUint64(data[offset:], uint64(packet.Header.Timestamp)) offset += 8 // Client ID copy(data[offset:], headerBytes) offset += headerLen // Data copy(data[offset:], packet.Data) return data, nil } // deserializeTaggedUDPPacket deserializes a tagged UDP packet func (tc *TeleportClient) deserializeTaggedUDPPacket(data []byte) (types.TaggedUDPPacket, error) { // Minimum size: 4 (headerLen) + 8 (packetID) + 8 (timestamp) = 20 bytes if len(data) < 20 { return types.TaggedUDPPacket{}, fmt.Errorf("packet data too short") } // Maximum reasonable packet size (1MB) if len(data) > 1024*1024 { return types.TaggedUDPPacket{}, fmt.Errorf("packet data too large") } offset := 0 // Header length if offset+4 > len(data) { return types.TaggedUDPPacket{}, fmt.Errorf("insufficient data for header length") } headerLen := binary.BigEndian.Uint32(data[offset:]) offset += 4 // Validate header length (reasonable limits) if headerLen > 1024 || headerLen == 0 { return types.TaggedUDPPacket{}, fmt.Errorf("invalid header length: %d", headerLen) } // Packet ID if offset+8 > len(data) { return types.TaggedUDPPacket{}, fmt.Errorf("insufficient data for packet ID") } packetID := binary.BigEndian.Uint64(data[offset:]) offset += 8 // Timestamp if offset+8 > len(data) { return types.TaggedUDPPacket{}, fmt.Errorf("insufficient data for timestamp") } timestamp := binary.BigEndian.Uint64(data[offset:]) offset += 8 // Validate timestamp is not too old or in the future now := time.Now().Unix() if timestamp > uint64(now+300) || timestamp < uint64(now-300) { // 5 minute window return types.TaggedUDPPacket{}, fmt.Errorf("timestamp out of range: %d (current: %d)", timestamp, now) } // Client ID if offset+int(headerLen) > len(data) { return types.TaggedUDPPacket{}, fmt.Errorf("insufficient data for client ID") } clientID := string(data[offset : offset+int(headerLen)]) offset += int(headerLen) // Validate client ID (basic sanitization) if len(clientID) == 0 { return types.TaggedUDPPacket{}, fmt.Errorf("empty client ID") } // Check for reasonable client ID format if len(clientID) > 256 { return types.TaggedUDPPacket{}, fmt.Errorf("client ID too long") } // Data if offset > len(data) { return types.TaggedUDPPacket{}, fmt.Errorf("data offset exceeds packet length") } dataLen := len(data) - offset packetData := make([]byte, dataLen) copy(packetData, data[offset:]) return types.TaggedUDPPacket{ Header: types.UDPPacketHeader{ ClientID: clientID, PacketID: packetID, Timestamp: int64(timestamp), }, Data: packetData, }, nil } // handleTCPConnection handles a TCP connection from a local client func (tc *TeleportClient) handleTCPConnection(clientConn net.Conn, rule config.PortRule) { defer clientConn.Close() // Get a connection from the pool or create a new one serverConn, err := tc.getConnection() if err != nil { logger.WithField("error", err).Error("Failed to get server connection") return } defer tc.returnConnection(serverConn) // Send port forward request to server request := types.PortForwardRequest(rule) if err := tc.sendRequestToConnection(serverConn, request); err != nil { logger.WithField("error", err).Error("Failed to send port forward request") return } // Now forward the actual data bidirectionally var wg sync.WaitGroup wg.Add(2) // Forward data from client to server go func() { defer wg.Done() tc.forwardData(clientConn, serverConn) }() // Forward data from server to client go func() { defer wg.Done() tc.forwardData(serverConn, clientConn) }() wg.Wait() } // forwardData forwards data from src to dst func (tc *TeleportClient) forwardData(src, dst net.Conn) { buffer := make([]byte, 4096) for { select { case <-tc.ctx.Done(): return default: n, err := src.Read(buffer) if err != nil { // Close the destination connection when source closes dst.Close() return } _, err = dst.Write(buffer[:n]) if err != nil { // Close the source connection when destination closes src.Close() return } } } } // sendRequest sends a port forward request to the server func (tc *TeleportClient) sendRequest(request types.PortForwardRequest) error { return tc.sendRequestToConnection(tc.serverConn, request) } // sendRequestToConnection sends a port forward request to a specific connection func (tc *TeleportClient) sendRequestToConnection(conn net.Conn, request types.PortForwardRequest) error { // Serialize the request data, err := tc.serializeRequest(request) if err != nil { return err } // Encrypt the data key := encryption.DeriveKey(tc.config.EncryptionKey) encryptedData, err := encryption.EncryptData(data, key) if err != nil { logger.WithFields(map[string]interface{}{ "error": err, "data_length": len(data), "request": fmt.Sprintf("%s:%d->%d", request.Protocol, request.LocalPort, request.RemotePort), }).Error("Encryption failed") return err } logger.WithFields(map[string]interface{}{ "original_length": len(data), "encrypted_length": len(encryptedData), "compression_ratio": fmt.Sprintf("%.2f", float64(len(encryptedData))/float64(len(data))), "request": fmt.Sprintf("%s:%d->%d", request.Protocol, request.LocalPort, request.RemotePort), }).Debug("Data encrypted successfully") // Validate request size before sending if len(encryptedData) == 0 { return fmt.Errorf("encrypted data cannot be empty") } if len(encryptedData) > 64*1024 { // 64KB limit return fmt.Errorf("request too large: %d bytes", len(encryptedData)) } // Send request length length := uint32(len(encryptedData)) if err := binary.Write(conn, binary.BigEndian, length); err != nil { return err } // Send encrypted request data bytesWritten := 0 for bytesWritten < len(encryptedData) { n, err := conn.Write(encryptedData[bytesWritten:]) if err != nil { return fmt.Errorf("failed to write request data: %v", err) } bytesWritten += n } return nil } // serializeRequest serializes a port forward request func (tc *TeleportClient) serializeRequest(request types.PortForwardRequest) ([]byte, error) { protocolBytes := []byte(request.Protocol) protocolLen := len(protocolBytes) targetHostBytes := []byte(request.TargetHost) targetHostLen := len(targetHostBytes) data := make([]byte, 4+4+4+4+protocolLen+targetHostLen) offset := 0 // Local port binary.BigEndian.PutUint32(data[offset:], uint32(request.LocalPort)) offset += 4 // Remote port binary.BigEndian.PutUint32(data[offset:], uint32(request.RemotePort)) offset += 4 // Protocol length binary.BigEndian.PutUint32(data[offset:], uint32(protocolLen)) offset += 4 // Protocol copy(data[offset:], protocolBytes) offset += protocolLen // Target host length binary.BigEndian.PutUint32(data[offset:], uint32(targetHostLen)) offset += 4 // Target host copy(data[offset:], targetHostBytes) return data, nil }