Add initial project structure with core functionality
- Created a new Go module named 'teleport' for secure port forwarding. - Added essential files including .gitignore, LICENSE, and README.md with project details. - Implemented configuration management with YAML support in config package. - Developed core client and server functionalities for handling port forwarding. - Introduced DNS server capabilities and integrated logging with sanitization. - Established rate limiting and metrics tracking for performance monitoring. - Included comprehensive tests for core components and functionalities. - Set up CI workflows for automated testing and release management using Gitea actions.
This commit is contained in:
785
internal/client/client.go
Normal file
785
internal/client/client.go
Normal file
@@ -0,0 +1,785 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"teleport/pkg/config"
|
||||
"teleport/pkg/dns"
|
||||
"teleport/pkg/encryption"
|
||||
"teleport/pkg/logger"
|
||||
"teleport/pkg/types"
|
||||
)
|
||||
|
||||
// TeleportClient represents the client instance
|
||||
type TeleportClient struct {
|
||||
config *config.Config
|
||||
serverConn net.Conn
|
||||
udpListeners map[int]*net.UDPConn
|
||||
udpMutex sync.RWMutex
|
||||
packetCounter uint64
|
||||
replayProtection *encryption.ReplayProtection
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
connectionPool chan net.Conn
|
||||
maxPoolSize int
|
||||
}
|
||||
|
||||
// NewTeleportClient creates a new teleport client
|
||||
func NewTeleportClient(config *config.Config) *TeleportClient {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
maxPoolSize := 10 // Default connection pool size
|
||||
if config.MaxConnections > 0 {
|
||||
maxPoolSize = config.MaxConnections / 10 // Use 10% of max connections for pool
|
||||
if maxPoolSize < 5 {
|
||||
maxPoolSize = 5
|
||||
}
|
||||
}
|
||||
|
||||
return &TeleportClient{
|
||||
config: config,
|
||||
udpListeners: make(map[int]*net.UDPConn),
|
||||
replayProtection: encryption.NewReplayProtection(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
connectionPool: make(chan net.Conn, maxPoolSize),
|
||||
maxPoolSize: maxPoolSize,
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the teleport client
|
||||
func (tc *TeleportClient) Start() error {
|
||||
logger.WithField("remote_address", tc.config.RemoteAddress).Info("Starting teleport client")
|
||||
|
||||
// Connect to server
|
||||
conn, err := net.Dial("tcp", tc.config.RemoteAddress)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to server: %v", err)
|
||||
}
|
||||
tc.serverConn = conn
|
||||
|
||||
// Start DNS server if enabled
|
||||
if tc.config.DNSServer.Enabled {
|
||||
go dns.StartDNSServer(tc.config)
|
||||
}
|
||||
|
||||
// Start port forwarding for each port rule
|
||||
for _, rule := range tc.config.Ports {
|
||||
go tc.startPortForwarding(rule)
|
||||
}
|
||||
|
||||
// Keep the client running with proper error handling
|
||||
<-tc.ctx.Done()
|
||||
logger.Info("Client shutting down...")
|
||||
return tc.ctx.Err()
|
||||
}
|
||||
|
||||
// startPortForwarding starts port forwarding for a specific rule
|
||||
func (tc *TeleportClient) startPortForwarding(rule config.PortRule) {
|
||||
switch rule.Protocol {
|
||||
case "tcp":
|
||||
tc.startTCPForwarding(rule)
|
||||
case "udp":
|
||||
tc.startUDPForwarding(rule)
|
||||
}
|
||||
}
|
||||
|
||||
// startTCPForwarding starts TCP port forwarding
|
||||
func (tc *TeleportClient) startTCPForwarding(rule config.PortRule) {
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", rule.LocalPort))
|
||||
if err != nil {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"port": rule.LocalPort,
|
||||
"error": err,
|
||||
}).Error("Failed to start TCP listener")
|
||||
return
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"local_port": rule.LocalPort,
|
||||
"remote_addr": tc.config.RemoteAddress,
|
||||
"remote_port": rule.RemotePort,
|
||||
"protocol": rule.Protocol,
|
||||
}).Info("TCP forwarding started")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-tc.ctx.Done():
|
||||
logger.Debug("Client shutting down...")
|
||||
return
|
||||
default:
|
||||
clientConn, err := listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-tc.ctx.Done():
|
||||
return
|
||||
default:
|
||||
logger.WithField("error", err).Error("Failed to accept TCP connection")
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
go tc.handleTCPConnection(clientConn, rule)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the teleport client
|
||||
func (tc *TeleportClient) Stop() {
|
||||
logger.Info("Stopping teleport client...")
|
||||
tc.cancel()
|
||||
if tc.serverConn != nil {
|
||||
tc.serverConn.Close()
|
||||
}
|
||||
|
||||
// Close all connections in the pool
|
||||
for {
|
||||
select {
|
||||
case conn := <-tc.connectionPool:
|
||||
conn.Close()
|
||||
default:
|
||||
goto poolClosed
|
||||
}
|
||||
}
|
||||
poolClosed:
|
||||
|
||||
// Close all UDP listeners
|
||||
tc.udpMutex.Lock()
|
||||
for _, conn := range tc.udpListeners {
|
||||
conn.Close()
|
||||
}
|
||||
tc.udpMutex.Unlock()
|
||||
}
|
||||
|
||||
// getConnection gets a connection from the pool or creates a new one
|
||||
func (tc *TeleportClient) getConnection() (net.Conn, error) {
|
||||
// Try to get a healthy connection from the pool
|
||||
for {
|
||||
select {
|
||||
case conn := <-tc.connectionPool:
|
||||
if conn != nil && tc.isConnectionHealthy(conn) {
|
||||
return conn, nil
|
||||
}
|
||||
// Connection is nil or unhealthy, close it and try again
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
default:
|
||||
// No connection in pool, create new one
|
||||
break
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// Create new connection
|
||||
conn, err := net.Dial("tcp", tc.config.RemoteAddress)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create connection: %v", err)
|
||||
}
|
||||
|
||||
// Set connection timeouts
|
||||
if tc.config.ReadTimeout > 0 {
|
||||
conn.SetReadDeadline(time.Now().Add(tc.config.ReadTimeout))
|
||||
}
|
||||
if tc.config.WriteTimeout > 0 {
|
||||
conn.SetWriteDeadline(time.Now().Add(tc.config.WriteTimeout))
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// isConnectionHealthy checks if a connection is still alive and usable
|
||||
func (tc *TeleportClient) isConnectionHealthy(conn net.Conn) bool {
|
||||
if conn == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Try to set a very short read deadline to test the connection
|
||||
// This is a non-blocking way to check if the connection is still alive
|
||||
conn.SetReadDeadline(time.Now().Add(1 * time.Millisecond))
|
||||
|
||||
// Try to read one byte (this will fail immediately if connection is dead)
|
||||
one := make([]byte, 1)
|
||||
_, err := conn.Read(one)
|
||||
|
||||
// Clear the deadline
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
|
||||
// If we get an error, the connection is likely dead
|
||||
// We expect to get a timeout error for a healthy connection (since we set a 1ms deadline)
|
||||
if err != nil {
|
||||
// Check if it's a timeout error (which is expected for a healthy connection)
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
return true // Connection is healthy (timeout is expected)
|
||||
}
|
||||
// Any other error means the connection is dead
|
||||
return false
|
||||
}
|
||||
|
||||
// If we actually read data, that's unexpected but the connection is alive
|
||||
return true
|
||||
}
|
||||
|
||||
// returnConnection returns a connection to the pool
|
||||
func (tc *TeleportClient) returnConnection(conn net.Conn) {
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if connection is still healthy before returning to pool
|
||||
if !tc.isConnectionHealthy(conn) {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case tc.connectionPool <- conn:
|
||||
// Connection returned to pool
|
||||
case <-tc.ctx.Done():
|
||||
// Context cancelled, close the connection
|
||||
conn.Close()
|
||||
default:
|
||||
// Pool is full, close the connection
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// startUDPForwarding starts UDP port forwarding
|
||||
func (tc *TeleportClient) startUDPForwarding(rule config.PortRule) {
|
||||
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", rule.LocalPort))
|
||||
if err != nil {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"port": rule.LocalPort,
|
||||
"error": err,
|
||||
}).Error("Failed to resolve UDP address")
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
if err != nil {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"port": rule.LocalPort,
|
||||
"error": err,
|
||||
}).Error("Failed to start UDP listener")
|
||||
return
|
||||
}
|
||||
|
||||
tc.udpMutex.Lock()
|
||||
tc.udpListeners[rule.LocalPort] = conn
|
||||
tc.udpMutex.Unlock()
|
||||
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"local_port": rule.LocalPort,
|
||||
"remote_addr": tc.config.RemoteAddress,
|
||||
"remote_port": rule.RemotePort,
|
||||
"protocol": rule.Protocol,
|
||||
}).Info("UDP forwarding started")
|
||||
|
||||
buffer := make([]byte, 4096)
|
||||
for {
|
||||
select {
|
||||
case <-tc.ctx.Done():
|
||||
logger.Debug("Client context cancelled, stopping UDP forwarding")
|
||||
return
|
||||
default:
|
||||
n, clientAddr, err := conn.ReadFromUDP(buffer)
|
||||
if err != nil {
|
||||
select {
|
||||
case <-tc.ctx.Done():
|
||||
return
|
||||
default:
|
||||
logger.WithField("error", err).Error("Failed to read UDP packet")
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Create tagged packet with atomic counter
|
||||
packetID := atomic.AddUint64(&tc.packetCounter, 1)
|
||||
taggedPacket := types.TaggedUDPPacket{
|
||||
Header: types.UDPPacketHeader{
|
||||
ClientID: tc.config.InstanceID,
|
||||
PacketID: packetID,
|
||||
Timestamp: time.Now().Unix(),
|
||||
},
|
||||
Data: make([]byte, n),
|
||||
}
|
||||
copy(taggedPacket.Data, buffer[:n])
|
||||
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"client": clientAddr,
|
||||
"data_length": len(taggedPacket.Data),
|
||||
"packetID": taggedPacket.Header.PacketID,
|
||||
}).Debug("UDP packet received")
|
||||
|
||||
// Send to server and wait for response
|
||||
go tc.sendTaggedUDPPacketWithResponse(taggedPacket, conn, clientAddr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendTaggedUDPPacketWithResponse sends a tagged UDP packet to the server and forwards response back
|
||||
func (tc *TeleportClient) sendTaggedUDPPacketWithResponse(packet types.TaggedUDPPacket, udpConn *net.UDPConn, clientAddr *net.UDPAddr) {
|
||||
logger.WithField("packetID", packet.Header.PacketID).Debug("Starting to send UDP packet to server")
|
||||
|
||||
// Establish a new connection to the server for this UDP packet
|
||||
serverConn, err := net.Dial("tcp", tc.config.RemoteAddress)
|
||||
if err != nil {
|
||||
logger.WithField("error", err).Debug("UDP CLIENT: Failed to connect to server for UDP packet")
|
||||
return
|
||||
}
|
||||
defer serverConn.Close()
|
||||
|
||||
logger.WithField("packetID", packet.Header.PacketID).Debug("UDP CLIENT: Connected to server, sending port forward request")
|
||||
|
||||
// Find the UDP port rule to get the correct remote port
|
||||
var remotePort int
|
||||
for _, rule := range tc.config.Ports {
|
||||
if rule.Protocol == "udp" {
|
||||
remotePort = rule.RemotePort
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Send port forward request first (like TCP does)
|
||||
request := types.PortForwardRequest{
|
||||
LocalPort: int(packet.Header.PacketID), // Use packet ID as local port for identification
|
||||
RemotePort: remotePort, // UDP port we want to forward to
|
||||
Protocol: "udp",
|
||||
TargetHost: "",
|
||||
}
|
||||
|
||||
if err := tc.sendRequestToConnection(serverConn, request); err != nil {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"packetID": packet.Header.PacketID,
|
||||
"error": err,
|
||||
}).Debug("UDP CLIENT: Failed to send port forward request")
|
||||
return
|
||||
}
|
||||
|
||||
logger.WithField("packetID", packet.Header.PacketID).Debug("UDP CLIENT: Port forward request sent, sending packet")
|
||||
|
||||
// Send UDP packet through the new connection
|
||||
tc.sendTaggedUDPPacketToConnection(serverConn, packet)
|
||||
|
||||
logger.WithField("packetID", packet.Header.PacketID).Debug("UDP CLIENT: Packet sent, waiting for response")
|
||||
|
||||
// Wait for response from server and forward it back
|
||||
tc.waitForUDPResponseAndForward(serverConn, packet.Header.PacketID, udpConn, clientAddr)
|
||||
}
|
||||
|
||||
// sendTaggedUDPPacketToConnection sends a tagged UDP packet to a specific connection
|
||||
func (tc *TeleportClient) sendTaggedUDPPacketToConnection(conn net.Conn, packet types.TaggedUDPPacket) {
|
||||
// Serialize the tagged packet
|
||||
data, err := tc.serializeTaggedUDPPacket(packet)
|
||||
if err != nil {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"packetID": packet.Header.PacketID,
|
||||
"error": err,
|
||||
}).Debug("UDP CLIENT: Failed to serialize tagged UDP packet")
|
||||
return
|
||||
}
|
||||
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"packetID": packet.Header.PacketID,
|
||||
"data_length": len(data),
|
||||
}).Debug("UDP CLIENT: Serialized packet")
|
||||
|
||||
// Encrypt the data
|
||||
key := encryption.DeriveKey(tc.config.EncryptionKey)
|
||||
encryptedData, err := encryption.EncryptData(data, key)
|
||||
if err != nil {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"packetID": packet.Header.PacketID,
|
||||
"error": err,
|
||||
}).Debug("UDP CLIENT: Failed to encrypt UDP packet")
|
||||
return
|
||||
}
|
||||
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"packetID": packet.Header.PacketID,
|
||||
"encrypted_length": len(encryptedData),
|
||||
}).Debug("UDP CLIENT: Encrypted packet")
|
||||
|
||||
// Send to server
|
||||
_, err = conn.Write(encryptedData)
|
||||
if err != nil {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"packetID": packet.Header.PacketID,
|
||||
"error": err,
|
||||
}).Debug("UDP CLIENT: Failed to send UDP packet to server")
|
||||
} else {
|
||||
logger.WithField("packetID", packet.Header.PacketID).Debug("UDP CLIENT: Successfully sent packet to server")
|
||||
}
|
||||
}
|
||||
|
||||
// waitForUDPResponseAndForward waits for a UDP response from the server and forwards it back
|
||||
func (tc *TeleportClient) waitForUDPResponseAndForward(conn net.Conn, expectedPacketID uint64, udpConn *net.UDPConn, clientAddr *net.UDPAddr) {
|
||||
logger.WithField("packetID", expectedPacketID).Debug("UDP CLIENT: Waiting for response to packet")
|
||||
buffer := make([]byte, 4096)
|
||||
for {
|
||||
select {
|
||||
case <-tc.ctx.Done():
|
||||
logger.WithField("packetID", expectedPacketID).Debug("UDP CLIENT: Context cancelled while waiting for response")
|
||||
return
|
||||
default:
|
||||
// Set a timeout for reading the response
|
||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
n, err := conn.Read(buffer)
|
||||
if err != nil {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"packetID": expectedPacketID,
|
||||
"error": err,
|
||||
}).Debug("UDP CLIENT: Failed to read UDP response")
|
||||
return
|
||||
}
|
||||
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"packetID": expectedPacketID,
|
||||
"bytes_received": n,
|
||||
}).Debug("UDP CLIENT: Received response bytes")
|
||||
|
||||
// Decrypt the response
|
||||
key := encryption.DeriveKey(tc.config.EncryptionKey)
|
||||
decryptedData, err := encryption.DecryptData(buffer[:n], key)
|
||||
if err != nil {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"packetID": expectedPacketID,
|
||||
"error": err,
|
||||
}).Debug("UDP CLIENT: Failed to decrypt UDP response")
|
||||
continue
|
||||
}
|
||||
|
||||
// Deserialize the response packet
|
||||
responsePacket, err := tc.deserializeTaggedUDPPacket(decryptedData)
|
||||
if err != nil {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"packetID": expectedPacketID,
|
||||
"error": err,
|
||||
}).Debug("UDP CLIENT: Failed to deserialize UDP response")
|
||||
continue
|
||||
}
|
||||
|
||||
// Validate packet timestamp
|
||||
if !encryption.ValidatePacketTimestamp(responsePacket.Header.Timestamp) {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"packetID": responsePacket.Header.PacketID,
|
||||
"timestamp": responsePacket.Header.Timestamp,
|
||||
}).Debug("UDP CLIENT: Response packet timestamp validation failed")
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for replay attacks
|
||||
if !tc.replayProtection.IsValidNonce(responsePacket.Header.PacketID, responsePacket.Header.Timestamp) {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"packetID": responsePacket.Header.PacketID,
|
||||
"timestamp": responsePacket.Header.Timestamp,
|
||||
}).Debug("UDP CLIENT: Replay attack detected in response")
|
||||
continue
|
||||
}
|
||||
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"packetID": responsePacket.Header.PacketID,
|
||||
"data_length": len(responsePacket.Data),
|
||||
}).Debug("UDP CLIENT: Deserialized response packet")
|
||||
|
||||
// Check if this is the response we're waiting for
|
||||
if responsePacket.Header.PacketID == expectedPacketID {
|
||||
// Forward the response back to the original UDP client
|
||||
_, err := udpConn.WriteToUDP(responsePacket.Data, clientAddr)
|
||||
if err != nil {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"packetID": responsePacket.Header.PacketID,
|
||||
"error": err,
|
||||
}).Debug("UDP CLIENT: Failed to forward UDP response to client")
|
||||
} else {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"packetID": responsePacket.Header.PacketID,
|
||||
"client": clientAddr,
|
||||
"data_length": len(responsePacket.Data),
|
||||
}).Debug("UDP CLIENT: Successfully forwarded UDP response")
|
||||
}
|
||||
return
|
||||
} else {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"received_packetID": responsePacket.Header.PacketID,
|
||||
"expected_packetID": expectedPacketID,
|
||||
}).Debug("UDP CLIENT: Received response for different packet")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// serializeTaggedUDPPacket serializes a tagged UDP packet
|
||||
func (tc *TeleportClient) serializeTaggedUDPPacket(packet types.TaggedUDPPacket) ([]byte, error) {
|
||||
// Simple serialization: header length + header + data
|
||||
headerBytes := []byte(packet.Header.ClientID)
|
||||
headerLen := len(headerBytes)
|
||||
|
||||
data := make([]byte, 4+8+8+headerLen+len(packet.Data))
|
||||
offset := 0
|
||||
|
||||
// Header length
|
||||
binary.BigEndian.PutUint32(data[offset:], uint32(headerLen))
|
||||
offset += 4
|
||||
|
||||
// Packet ID
|
||||
binary.BigEndian.PutUint64(data[offset:], packet.Header.PacketID)
|
||||
offset += 8
|
||||
|
||||
// Timestamp
|
||||
binary.BigEndian.PutUint64(data[offset:], uint64(packet.Header.Timestamp))
|
||||
offset += 8
|
||||
|
||||
// Client ID
|
||||
copy(data[offset:], headerBytes)
|
||||
offset += headerLen
|
||||
|
||||
// Data
|
||||
copy(data[offset:], packet.Data)
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// deserializeTaggedUDPPacket deserializes a tagged UDP packet
|
||||
func (tc *TeleportClient) deserializeTaggedUDPPacket(data []byte) (types.TaggedUDPPacket, error) {
|
||||
// Minimum size: 4 (headerLen) + 8 (packetID) + 8 (timestamp) = 20 bytes
|
||||
if len(data) < 20 {
|
||||
return types.TaggedUDPPacket{}, fmt.Errorf("packet data too short")
|
||||
}
|
||||
|
||||
// Maximum reasonable packet size (1MB)
|
||||
if len(data) > 1024*1024 {
|
||||
return types.TaggedUDPPacket{}, fmt.Errorf("packet data too large")
|
||||
}
|
||||
|
||||
offset := 0
|
||||
|
||||
// Header length
|
||||
if offset+4 > len(data) {
|
||||
return types.TaggedUDPPacket{}, fmt.Errorf("insufficient data for header length")
|
||||
}
|
||||
headerLen := binary.BigEndian.Uint32(data[offset:])
|
||||
offset += 4
|
||||
|
||||
// Validate header length (reasonable limits)
|
||||
if headerLen > 1024 || headerLen == 0 {
|
||||
return types.TaggedUDPPacket{}, fmt.Errorf("invalid header length: %d", headerLen)
|
||||
}
|
||||
|
||||
// Packet ID
|
||||
if offset+8 > len(data) {
|
||||
return types.TaggedUDPPacket{}, fmt.Errorf("insufficient data for packet ID")
|
||||
}
|
||||
packetID := binary.BigEndian.Uint64(data[offset:])
|
||||
offset += 8
|
||||
|
||||
// Timestamp
|
||||
if offset+8 > len(data) {
|
||||
return types.TaggedUDPPacket{}, fmt.Errorf("insufficient data for timestamp")
|
||||
}
|
||||
timestamp := binary.BigEndian.Uint64(data[offset:])
|
||||
offset += 8
|
||||
|
||||
// Validate timestamp is not too old or in the future
|
||||
now := time.Now().Unix()
|
||||
if timestamp > uint64(now+300) || timestamp < uint64(now-300) { // 5 minute window
|
||||
return types.TaggedUDPPacket{}, fmt.Errorf("timestamp out of range: %d (current: %d)", timestamp, now)
|
||||
}
|
||||
|
||||
// Client ID
|
||||
if offset+int(headerLen) > len(data) {
|
||||
return types.TaggedUDPPacket{}, fmt.Errorf("insufficient data for client ID")
|
||||
}
|
||||
clientID := string(data[offset : offset+int(headerLen)])
|
||||
offset += int(headerLen)
|
||||
|
||||
// Validate client ID (basic sanitization)
|
||||
if len(clientID) == 0 {
|
||||
return types.TaggedUDPPacket{}, fmt.Errorf("empty client ID")
|
||||
}
|
||||
// Check for reasonable client ID format
|
||||
if len(clientID) > 256 {
|
||||
return types.TaggedUDPPacket{}, fmt.Errorf("client ID too long")
|
||||
}
|
||||
|
||||
// Data
|
||||
if offset > len(data) {
|
||||
return types.TaggedUDPPacket{}, fmt.Errorf("data offset exceeds packet length")
|
||||
}
|
||||
dataLen := len(data) - offset
|
||||
packetData := make([]byte, dataLen)
|
||||
copy(packetData, data[offset:])
|
||||
|
||||
return types.TaggedUDPPacket{
|
||||
Header: types.UDPPacketHeader{
|
||||
ClientID: clientID,
|
||||
PacketID: packetID,
|
||||
Timestamp: int64(timestamp),
|
||||
},
|
||||
Data: packetData,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// handleTCPConnection handles a TCP connection from a local client
|
||||
func (tc *TeleportClient) handleTCPConnection(clientConn net.Conn, rule config.PortRule) {
|
||||
defer clientConn.Close()
|
||||
|
||||
// Get a connection from the pool or create a new one
|
||||
serverConn, err := tc.getConnection()
|
||||
if err != nil {
|
||||
logger.WithField("error", err).Error("Failed to get server connection")
|
||||
return
|
||||
}
|
||||
defer tc.returnConnection(serverConn)
|
||||
|
||||
// Send port forward request to server
|
||||
request := types.PortForwardRequest(rule)
|
||||
|
||||
if err := tc.sendRequestToConnection(serverConn, request); err != nil {
|
||||
logger.WithField("error", err).Error("Failed to send port forward request")
|
||||
return
|
||||
}
|
||||
|
||||
// Now forward the actual data bidirectionally
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
// Forward data from client to server
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
tc.forwardData(clientConn, serverConn)
|
||||
}()
|
||||
|
||||
// Forward data from server to client
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
tc.forwardData(serverConn, clientConn)
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// forwardData forwards data from src to dst
|
||||
func (tc *TeleportClient) forwardData(src, dst net.Conn) {
|
||||
buffer := make([]byte, 4096)
|
||||
for {
|
||||
select {
|
||||
case <-tc.ctx.Done():
|
||||
return
|
||||
default:
|
||||
n, err := src.Read(buffer)
|
||||
if err != nil {
|
||||
// Close the destination connection when source closes
|
||||
dst.Close()
|
||||
return
|
||||
}
|
||||
|
||||
_, err = dst.Write(buffer[:n])
|
||||
if err != nil {
|
||||
// Close the source connection when destination closes
|
||||
src.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendRequest sends a port forward request to the server
|
||||
func (tc *TeleportClient) sendRequest(request types.PortForwardRequest) error {
|
||||
return tc.sendRequestToConnection(tc.serverConn, request)
|
||||
}
|
||||
|
||||
// sendRequestToConnection sends a port forward request to a specific connection
|
||||
func (tc *TeleportClient) sendRequestToConnection(conn net.Conn, request types.PortForwardRequest) error {
|
||||
// Serialize the request
|
||||
data, err := tc.serializeRequest(request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Encrypt the data
|
||||
key := encryption.DeriveKey(tc.config.EncryptionKey)
|
||||
encryptedData, err := encryption.EncryptData(data, key)
|
||||
if err != nil {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"error": err,
|
||||
"data_length": len(data),
|
||||
"request": fmt.Sprintf("%s:%d->%d", request.Protocol, request.LocalPort, request.RemotePort),
|
||||
}).Error("Encryption failed")
|
||||
return err
|
||||
}
|
||||
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"original_length": len(data),
|
||||
"encrypted_length": len(encryptedData),
|
||||
"compression_ratio": fmt.Sprintf("%.2f", float64(len(encryptedData))/float64(len(data))),
|
||||
"request": fmt.Sprintf("%s:%d->%d", request.Protocol, request.LocalPort, request.RemotePort),
|
||||
}).Debug("Data encrypted successfully")
|
||||
|
||||
// Validate request size before sending
|
||||
if len(encryptedData) == 0 {
|
||||
return fmt.Errorf("encrypted data cannot be empty")
|
||||
}
|
||||
if len(encryptedData) > 64*1024 { // 64KB limit
|
||||
return fmt.Errorf("request too large: %d bytes", len(encryptedData))
|
||||
}
|
||||
|
||||
// Send request length
|
||||
length := uint32(len(encryptedData))
|
||||
if err := binary.Write(conn, binary.BigEndian, length); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Send encrypted request data
|
||||
bytesWritten := 0
|
||||
for bytesWritten < len(encryptedData) {
|
||||
n, err := conn.Write(encryptedData[bytesWritten:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write request data: %v", err)
|
||||
}
|
||||
bytesWritten += n
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// serializeRequest serializes a port forward request
|
||||
func (tc *TeleportClient) serializeRequest(request types.PortForwardRequest) ([]byte, error) {
|
||||
protocolBytes := []byte(request.Protocol)
|
||||
protocolLen := len(protocolBytes)
|
||||
targetHostBytes := []byte(request.TargetHost)
|
||||
targetHostLen := len(targetHostBytes)
|
||||
|
||||
data := make([]byte, 4+4+4+4+protocolLen+targetHostLen)
|
||||
offset := 0
|
||||
|
||||
// Local port
|
||||
binary.BigEndian.PutUint32(data[offset:], uint32(request.LocalPort))
|
||||
offset += 4
|
||||
|
||||
// Remote port
|
||||
binary.BigEndian.PutUint32(data[offset:], uint32(request.RemotePort))
|
||||
offset += 4
|
||||
|
||||
// Protocol length
|
||||
binary.BigEndian.PutUint32(data[offset:], uint32(protocolLen))
|
||||
offset += 4
|
||||
|
||||
// Protocol
|
||||
copy(data[offset:], protocolBytes)
|
||||
offset += protocolLen
|
||||
|
||||
// Target host length
|
||||
binary.BigEndian.PutUint32(data[offset:], uint32(targetHostLen))
|
||||
offset += 4
|
||||
|
||||
// Target host
|
||||
copy(data[offset:], targetHostBytes)
|
||||
|
||||
return data, nil
|
||||
}
|
||||
Reference in New Issue
Block a user