Add initial project structure with core functionality
- Created a new Go module named 'teleport' for secure port forwarding. - Added essential files including .gitignore, LICENSE, and README.md with project details. - Implemented configuration management with YAML support in config package. - Developed core client and server functionalities for handling port forwarding. - Introduced DNS server capabilities and integrated logging with sanitization. - Established rate limiting and metrics tracking for performance monitoring. - Included comprehensive tests for core components and functionalities. - Set up CI workflows for automated testing and release management using Gitea actions.
This commit is contained in:
869
internal/server/server.go
Normal file
869
internal/server/server.go
Normal file
@@ -0,0 +1,869 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"teleport/pkg/config"
|
||||
"teleport/pkg/encryption"
|
||||
"teleport/pkg/logger"
|
||||
"teleport/pkg/metrics"
|
||||
"teleport/pkg/ratelimit"
|
||||
"teleport/pkg/types"
|
||||
)
|
||||
|
||||
// TeleportServer represents the server instance
|
||||
type TeleportServer struct {
|
||||
config *config.Config
|
||||
listener net.Listener
|
||||
udpListeners map[int]*net.UDPConn
|
||||
udpClients map[string]*net.UDPConn
|
||||
udpMutex sync.RWMutex
|
||||
packetCounter uint64
|
||||
replayProtection *encryption.ReplayProtection
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
connectionSem chan struct{} // Semaphore for limiting concurrent connections
|
||||
activeConnections int64 // Atomic counter for active connections
|
||||
maxConnections int // Maximum concurrent connections
|
||||
rateLimiter *ratelimit.RateLimiter
|
||||
goroutineSem chan struct{} // Semaphore for limiting concurrent goroutines
|
||||
maxGoroutines int // Maximum concurrent goroutines
|
||||
metrics *metrics.Metrics
|
||||
}
|
||||
|
||||
// NewTeleportServer creates a new teleport server
|
||||
func NewTeleportServer(config *config.Config) *TeleportServer {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
maxConnections := 1000 // Default maximum concurrent connections
|
||||
if config.MaxConnections > 0 {
|
||||
maxConnections = config.MaxConnections
|
||||
}
|
||||
|
||||
maxGoroutines := maxConnections * 2 // Allow 2x connections for goroutines
|
||||
if maxGoroutines > 10000 {
|
||||
maxGoroutines = 10000 // Cap at 10k goroutines
|
||||
}
|
||||
|
||||
// Initialize rate limiter if enabled
|
||||
var rateLimiter *ratelimit.RateLimiter
|
||||
if config.RateLimit.Enabled {
|
||||
rateLimiter = ratelimit.NewRateLimiter(
|
||||
config.RateLimit.RequestsPerSecond,
|
||||
config.RateLimit.BurstSize,
|
||||
config.RateLimit.WindowSize,
|
||||
)
|
||||
}
|
||||
|
||||
// Initialize metrics
|
||||
metricsInstance := metrics.GetMetrics()
|
||||
|
||||
return &TeleportServer{
|
||||
config: config,
|
||||
udpListeners: make(map[int]*net.UDPConn),
|
||||
udpClients: make(map[string]*net.UDPConn),
|
||||
replayProtection: encryption.NewReplayProtection(),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
connectionSem: make(chan struct{}, maxConnections),
|
||||
maxConnections: maxConnections,
|
||||
rateLimiter: rateLimiter,
|
||||
goroutineSem: make(chan struct{}, maxGoroutines),
|
||||
maxGoroutines: maxGoroutines,
|
||||
metrics: metricsInstance,
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the teleport server
|
||||
func (ts *TeleportServer) Start() error {
|
||||
logger.WithField("listen_address", ts.config.ListenAddress).Info("Starting teleport server")
|
||||
|
||||
// Start TCP listener
|
||||
listener, err := net.Listen("tcp", ts.config.ListenAddress)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start TCP listener: %v", err)
|
||||
}
|
||||
ts.listener = listener
|
||||
|
||||
// Start UDP listeners for each port
|
||||
for _, rule := range ts.config.Ports {
|
||||
if rule.Protocol == "udp" {
|
||||
go ts.startUDPListener(rule)
|
||||
}
|
||||
}
|
||||
|
||||
// Accept TCP connections
|
||||
for {
|
||||
select {
|
||||
case <-ts.ctx.Done():
|
||||
logger.Debug("Server shutting down...")
|
||||
return ts.ctx.Err()
|
||||
default:
|
||||
conn, err := ts.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-ts.ctx.Done():
|
||||
return ts.ctx.Err()
|
||||
default:
|
||||
logger.WithField("error", err).Error("Failed to accept connection")
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Check rate limiting first
|
||||
if ts.rateLimiter != nil && !ts.rateLimiter.Allow() {
|
||||
logger.WithField("client", conn.RemoteAddr()).Warn("Rate limit exceeded, rejecting connection")
|
||||
ts.metrics.IncrementRateLimitedRequests()
|
||||
conn.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if we can accept more connections
|
||||
select {
|
||||
case ts.connectionSem <- struct{}{}:
|
||||
// Connection accepted
|
||||
atomic.AddInt64(&ts.activeConnections, 1)
|
||||
ts.metrics.IncrementTotalConnections()
|
||||
ts.metrics.IncrementActiveConnections()
|
||||
|
||||
// Check goroutine limit
|
||||
select {
|
||||
case ts.goroutineSem <- struct{}{}:
|
||||
go ts.handleConnectionWithLimit(conn)
|
||||
default:
|
||||
// Goroutine limit reached
|
||||
<-ts.connectionSem // Release connection semaphore
|
||||
atomic.AddInt64(&ts.activeConnections, -1)
|
||||
ts.metrics.DecrementActiveConnections()
|
||||
ts.metrics.IncrementRejectedConnections()
|
||||
logger.WithField("max_goroutines", ts.maxGoroutines).Warn("Goroutine limit reached, rejecting connection")
|
||||
conn.Close()
|
||||
}
|
||||
default:
|
||||
// Connection limit reached
|
||||
ts.metrics.IncrementRejectedConnections()
|
||||
logger.WithField("max_connections", ts.maxConnections).Warn("Connection limit reached, rejecting connection")
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the teleport server
|
||||
func (ts *TeleportServer) Stop() {
|
||||
logger.Info("Stopping teleport server...")
|
||||
ts.cancel()
|
||||
if ts.listener != nil {
|
||||
ts.listener.Close()
|
||||
}
|
||||
|
||||
// Close all UDP listeners
|
||||
ts.udpMutex.Lock()
|
||||
for _, conn := range ts.udpListeners {
|
||||
conn.Close()
|
||||
}
|
||||
ts.udpMutex.Unlock()
|
||||
}
|
||||
|
||||
// startUDPListener starts a UDP listener for a specific port
|
||||
func (ts *TeleportServer) startUDPListener(rule config.PortRule) {
|
||||
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", rule.LocalPort))
|
||||
if err != nil {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"port": rule.LocalPort,
|
||||
"error": err,
|
||||
}).Error("Failed to resolve UDP address")
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := net.ListenUDP("udp", addr)
|
||||
if err != nil {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"port": rule.LocalPort,
|
||||
"error": err,
|
||||
}).Error("Failed to start UDP listener")
|
||||
return
|
||||
}
|
||||
|
||||
ts.udpMutex.Lock()
|
||||
ts.udpListeners[rule.LocalPort] = conn
|
||||
ts.udpMutex.Unlock()
|
||||
|
||||
logger.WithField("port", rule.LocalPort).Info("UDP listener started")
|
||||
|
||||
buffer := make([]byte, 4096)
|
||||
for {
|
||||
n, clientAddr, err := conn.ReadFromUDP(buffer)
|
||||
if err != nil {
|
||||
logger.WithField("error", err).Error("Failed to read UDP packet")
|
||||
continue
|
||||
}
|
||||
|
||||
// Create UDP request
|
||||
request := types.UDPRequest{
|
||||
LocalPort: rule.LocalPort,
|
||||
RemotePort: rule.RemotePort,
|
||||
Protocol: rule.Protocol,
|
||||
Data: make([]byte, n),
|
||||
ClientAddr: clientAddr,
|
||||
}
|
||||
copy(request.Data, buffer[:n])
|
||||
|
||||
// Forward to all connected clients
|
||||
go ts.forwardUDPRequest(request)
|
||||
}
|
||||
}
|
||||
|
||||
// forwardUDPRequest forwards a UDP request to all connected clients
|
||||
func (ts *TeleportServer) forwardUDPRequest(request types.UDPRequest) {
|
||||
ts.udpMutex.RLock()
|
||||
defer ts.udpMutex.RUnlock()
|
||||
|
||||
// Create tagged packet with atomic counter
|
||||
packetID := atomic.AddUint64(&ts.packetCounter, 1)
|
||||
taggedPacket := types.TaggedUDPPacket{
|
||||
Header: types.UDPPacketHeader{
|
||||
ClientID: fmt.Sprintf("udp-%s", request.ClientAddr.String()),
|
||||
PacketID: packetID,
|
||||
Timestamp: time.Now().Unix(),
|
||||
},
|
||||
Data: request.Data,
|
||||
}
|
||||
|
||||
// Forward to all connected UDP clients
|
||||
for clientID, clientConn := range ts.udpClients {
|
||||
if clientID != taggedPacket.Header.ClientID {
|
||||
go ts.sendTaggedUDPPacket(clientConn, taggedPacket)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendTaggedUDPPacket sends a tagged UDP packet to a client
|
||||
func (ts *TeleportServer) sendTaggedUDPPacket(clientConn *net.UDPConn, packet types.TaggedUDPPacket) {
|
||||
// Serialize the tagged packet
|
||||
data, err := ts.serializeTaggedUDPPacket(packet)
|
||||
if err != nil {
|
||||
logger.WithField("error", err).Debug("Failed to serialize tagged UDP packet")
|
||||
return
|
||||
}
|
||||
|
||||
// Encrypt the data
|
||||
key := encryption.DeriveKey(ts.config.EncryptionKey)
|
||||
encryptedData, err := encryption.EncryptData(data, key)
|
||||
if err != nil {
|
||||
logger.WithField("error", err).Debug("Failed to encrypt UDP packet")
|
||||
return
|
||||
}
|
||||
|
||||
// Send to client
|
||||
_, err = clientConn.Write(encryptedData)
|
||||
if err != nil {
|
||||
logger.WithField("error", err).Debug("Failed to send UDP packet to client")
|
||||
}
|
||||
}
|
||||
|
||||
// serializeTaggedUDPPacket serializes a tagged UDP packet
|
||||
func (ts *TeleportServer) serializeTaggedUDPPacket(packet types.TaggedUDPPacket) ([]byte, error) {
|
||||
// Simple serialization: header length + header + data
|
||||
headerBytes := []byte(packet.Header.ClientID)
|
||||
headerLen := len(headerBytes)
|
||||
|
||||
data := make([]byte, 4+8+8+headerLen+len(packet.Data))
|
||||
offset := 0
|
||||
|
||||
// Header length
|
||||
binary.BigEndian.PutUint32(data[offset:], uint32(headerLen))
|
||||
offset += 4
|
||||
|
||||
// Packet ID
|
||||
binary.BigEndian.PutUint64(data[offset:], packet.Header.PacketID)
|
||||
offset += 8
|
||||
|
||||
// Timestamp
|
||||
binary.BigEndian.PutUint64(data[offset:], uint64(packet.Header.Timestamp))
|
||||
offset += 8
|
||||
|
||||
// Client ID
|
||||
copy(data[offset:], headerBytes)
|
||||
offset += headerLen
|
||||
|
||||
// Data
|
||||
copy(data[offset:], packet.Data)
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// handleConnectionWithLimit handles a TCP connection with connection limit management
|
||||
func (ts *TeleportServer) handleConnectionWithLimit(conn net.Conn) {
|
||||
defer func() {
|
||||
conn.Close()
|
||||
<-ts.connectionSem // Release connection semaphore
|
||||
<-ts.goroutineSem // Release goroutine semaphore
|
||||
atomic.AddInt64(&ts.activeConnections, -1)
|
||||
ts.metrics.DecrementActiveConnections()
|
||||
}()
|
||||
|
||||
// Set connection timeouts
|
||||
if ts.config.ReadTimeout > 0 {
|
||||
conn.SetReadDeadline(time.Now().Add(ts.config.ReadTimeout))
|
||||
}
|
||||
if ts.config.WriteTimeout > 0 {
|
||||
conn.SetWriteDeadline(time.Now().Add(ts.config.WriteTimeout))
|
||||
}
|
||||
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"client": conn.RemoteAddr(),
|
||||
"active_connections": atomic.LoadInt64(&ts.activeConnections),
|
||||
}).Info("New connection")
|
||||
|
||||
ts.handleConnection(conn)
|
||||
}
|
||||
|
||||
// handleConnection handles a TCP connection from a client
|
||||
func (ts *TeleportServer) handleConnection(conn net.Conn) {
|
||||
|
||||
// Don't set timeouts on the initial connection - let the individual handlers set them
|
||||
|
||||
// Read the port forward request (only one per connection)
|
||||
var request types.PortForwardRequest
|
||||
if err := ts.readRequest(conn, &request); err != nil {
|
||||
logger.WithField("error", err).Error("Failed to read request")
|
||||
return
|
||||
}
|
||||
|
||||
// Find matching port rule
|
||||
// The client sends LocalPort (client's local port) and RemotePort (server's port to forward to)
|
||||
// We need to find a server rule that matches the RemotePort the client wants
|
||||
var portRule *config.PortRule
|
||||
for _, rule := range ts.config.Ports {
|
||||
if rule.LocalPort == request.RemotePort && rule.Protocol == request.Protocol {
|
||||
portRule = &rule
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if portRule == nil {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"local_port": request.LocalPort,
|
||||
"remote_port": request.RemotePort,
|
||||
"protocol": request.Protocol,
|
||||
}).Warn("No matching port rule found")
|
||||
return
|
||||
}
|
||||
|
||||
// Handle based on protocol
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"protocol": request.Protocol,
|
||||
"port": request.RemotePort,
|
||||
}).Info("Handling connection")
|
||||
|
||||
switch request.Protocol {
|
||||
case "tcp":
|
||||
logger.Debug("Routing to TCP handler")
|
||||
ts.handleTCPForward(conn, portRule)
|
||||
case "udp":
|
||||
logger.Debug("Routing to UDP handler")
|
||||
ts.handleUDPForward(conn, portRule)
|
||||
}
|
||||
}
|
||||
|
||||
// handleTCPForward handles TCP port forwarding
|
||||
func (ts *TeleportServer) handleTCPForward(clientConn net.Conn, rule *config.PortRule) {
|
||||
// Determine target host (default to localhost if not specified)
|
||||
targetHost := rule.TargetHost
|
||||
if targetHost == "" {
|
||||
targetHost = "localhost"
|
||||
}
|
||||
|
||||
// Connect to the target service
|
||||
targetAddr := net.JoinHostPort(targetHost, fmt.Sprintf("%d", rule.RemotePort))
|
||||
targetConn, err := net.Dial("tcp", targetAddr)
|
||||
if err != nil {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"target": targetAddr,
|
||||
"error": err,
|
||||
}).Error("Failed to connect to target")
|
||||
return
|
||||
}
|
||||
defer targetConn.Close()
|
||||
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"client": clientConn.RemoteAddr(),
|
||||
"target": targetAddr,
|
||||
}).Info("TCP forwarding")
|
||||
|
||||
// Start bidirectional forwarding
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ts.forwardData(clientConn, targetConn)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ts.forwardData(targetConn, clientConn)
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// handleUDPForward handles UDP port forwarding
|
||||
func (ts *TeleportServer) handleUDPForward(clientConn net.Conn, rule *config.PortRule) {
|
||||
logger.WithField("client", clientConn.RemoteAddr()).Debug("UDP SERVER: Starting UDP forwarding")
|
||||
|
||||
// Determine target host (default to localhost if not specified)
|
||||
targetHost := rule.TargetHost
|
||||
if targetHost == "" {
|
||||
targetHost = "localhost"
|
||||
}
|
||||
|
||||
// Create UDP connection to target
|
||||
targetAddr := net.JoinHostPort(targetHost, fmt.Sprintf("%d", rule.RemotePort))
|
||||
targetConn, err := net.Dial("udp", targetAddr)
|
||||
if err != nil {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"target": targetAddr,
|
||||
"error": err,
|
||||
}).Debug("UDP SERVER: Failed to connect to UDP target")
|
||||
return
|
||||
}
|
||||
defer targetConn.Close()
|
||||
|
||||
logger.WithField("target", targetAddr).Debug("UDP SERVER: Connected to UDP target")
|
||||
|
||||
// Register this client for UDP forwarding with cleanup
|
||||
clientID := fmt.Sprintf("tcp-%s", clientConn.RemoteAddr().String())
|
||||
ts.udpMutex.Lock()
|
||||
// Clean up any existing connection for this client
|
||||
if existingConn, exists := ts.udpClients[clientID]; exists {
|
||||
existingConn.Close()
|
||||
}
|
||||
ts.udpClients[clientID] = targetConn.(*net.UDPConn)
|
||||
ts.udpMutex.Unlock()
|
||||
|
||||
logger.WithField("client", clientID).Debug("UDP SERVER: UDP forwarding registered")
|
||||
|
||||
// Handle UDP packets bidirectionally
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
// Channel to pass packet IDs from requests to responses
|
||||
// Use a larger buffer to prevent blocking under high load
|
||||
packetIDChan := make(chan uint64, 1000)
|
||||
|
||||
// Forward packets from client to target
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
buffer := make([]byte, 4096)
|
||||
for {
|
||||
select {
|
||||
case <-ts.ctx.Done():
|
||||
logger.Debug("UDP SERVER: Context cancelled in client->target goroutine")
|
||||
return
|
||||
default:
|
||||
n, err := clientConn.Read(buffer)
|
||||
if err != nil {
|
||||
logger.WithField("error", err).Debug("UDP SERVER: Failed to read from client")
|
||||
return
|
||||
}
|
||||
|
||||
logger.WithField("bytes_received", n).Debug("UDP SERVER: Received bytes from client")
|
||||
|
||||
// Decrypt the data
|
||||
key := encryption.DeriveKey(ts.config.EncryptionKey)
|
||||
decryptedData, err := encryption.DecryptData(buffer[:n], key)
|
||||
if err != nil {
|
||||
logger.WithField("error", err).Debug("UDP SERVER: Failed to decrypt UDP packet")
|
||||
continue
|
||||
}
|
||||
|
||||
logger.WithField("decrypted_bytes", len(decryptedData)).Debug("UDP SERVER: Decrypted bytes")
|
||||
|
||||
// Deserialize tagged packet
|
||||
packet, err := ts.deserializeTaggedUDPPacket(decryptedData)
|
||||
if err != nil {
|
||||
logger.WithField("error", err).Debug("UDP SERVER: Failed to deserialize tagged UDP packet")
|
||||
continue
|
||||
}
|
||||
|
||||
// Validate packet timestamp
|
||||
if !encryption.ValidatePacketTimestamp(packet.Header.Timestamp) {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"packetID": packet.Header.PacketID,
|
||||
"timestamp": packet.Header.Timestamp,
|
||||
}).Debug("UDP SERVER: Packet timestamp validation failed")
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for replay attacks
|
||||
if !ts.replayProtection.IsValidNonce(packet.Header.PacketID, packet.Header.Timestamp) {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"packetID": packet.Header.PacketID,
|
||||
"timestamp": packet.Header.Timestamp,
|
||||
}).Debug("UDP SERVER: Replay attack detected")
|
||||
continue
|
||||
}
|
||||
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"packetID": packet.Header.PacketID,
|
||||
"data_length": len(packet.Data),
|
||||
}).Debug("UDP SERVER: Deserialized packet")
|
||||
|
||||
// Forward to target
|
||||
_, err = targetConn.Write(packet.Data)
|
||||
if err != nil {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"packetID": packet.Header.PacketID,
|
||||
"error": err,
|
||||
}).Debug("UDP SERVER: Failed to forward UDP packet to target")
|
||||
return
|
||||
}
|
||||
|
||||
logger.WithField("packetID", packet.Header.PacketID).Debug("UDP SERVER: Forwarded packet to target")
|
||||
|
||||
// Send the packet ID to the response handler
|
||||
select {
|
||||
case packetIDChan <- packet.Header.PacketID:
|
||||
logger.WithField("packetID", packet.Header.PacketID).Debug("UDP SERVER: Sent packet ID to response handler")
|
||||
default:
|
||||
logger.WithField("packetID", packet.Header.PacketID).Debug("UDP SERVER: Packet ID channel full, dropping packet ID")
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Forward responses from target to client
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
buffer := make([]byte, 4096)
|
||||
for {
|
||||
select {
|
||||
case <-ts.ctx.Done():
|
||||
logger.Debug("UDP SERVER: Context cancelled in target->client goroutine")
|
||||
return
|
||||
default:
|
||||
n, err := targetConn.Read(buffer)
|
||||
if err != nil {
|
||||
logger.WithField("error", err).Debug("UDP SERVER: Failed to read from target")
|
||||
return
|
||||
}
|
||||
|
||||
logger.WithField("bytes_received", n).Debug("UDP SERVER: Received bytes from target")
|
||||
|
||||
// Get the packet ID from the request
|
||||
var originalPacketID uint64
|
||||
select {
|
||||
case originalPacketID = <-packetIDChan:
|
||||
logger.WithField("packetID", originalPacketID).Debug("UDP SERVER: Got packet ID for response")
|
||||
case <-time.After(1 * time.Second):
|
||||
logger.Debug("UDP SERVER: Timeout waiting for packet ID, using default")
|
||||
originalPacketID = 0
|
||||
}
|
||||
|
||||
// Create tagged packet for response using the original packet ID
|
||||
responsePacket := types.TaggedUDPPacket{
|
||||
Header: types.UDPPacketHeader{
|
||||
ClientID: clientID,
|
||||
PacketID: originalPacketID,
|
||||
Timestamp: time.Now().Unix(),
|
||||
},
|
||||
Data: make([]byte, n),
|
||||
}
|
||||
copy(responsePacket.Data, buffer[:n])
|
||||
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"packetID": originalPacketID,
|
||||
"data": string(responsePacket.Data),
|
||||
}).Debug("UDP SERVER: Created response packet")
|
||||
|
||||
// Serialize and encrypt response
|
||||
data, err := ts.serializeTaggedUDPPacket(responsePacket)
|
||||
if err != nil {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"packetID": originalPacketID,
|
||||
"error": err,
|
||||
}).Debug("UDP SERVER: Failed to serialize UDP response")
|
||||
continue
|
||||
}
|
||||
|
||||
key := encryption.DeriveKey(ts.config.EncryptionKey)
|
||||
encryptedData, err := encryption.EncryptData(data, key)
|
||||
if err != nil {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"packetID": originalPacketID,
|
||||
"error": err,
|
||||
}).Debug("UDP SERVER: Failed to encrypt UDP response")
|
||||
continue
|
||||
}
|
||||
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"packetID": originalPacketID,
|
||||
"encrypted_length": len(encryptedData),
|
||||
}).Debug("UDP SERVER: Encrypted response packet")
|
||||
|
||||
// Send response to client
|
||||
_, err = clientConn.Write(encryptedData)
|
||||
if err != nil {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"packetID": originalPacketID,
|
||||
"error": err,
|
||||
}).Debug("UDP SERVER: Failed to send UDP response to client")
|
||||
return
|
||||
}
|
||||
|
||||
logger.WithField("packetID", originalPacketID).Debug("UDP SERVER: Successfully sent response packet to client")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Unregister client
|
||||
ts.udpMutex.Lock()
|
||||
delete(ts.udpClients, clientID)
|
||||
ts.udpMutex.Unlock()
|
||||
}
|
||||
|
||||
// deserializeTaggedUDPPacket deserializes a tagged UDP packet
|
||||
func (ts *TeleportServer) deserializeTaggedUDPPacket(data []byte) (types.TaggedUDPPacket, error) {
|
||||
// Minimum size: 4 (headerLen) + 8 (packetID) + 8 (timestamp) = 20 bytes
|
||||
if len(data) < 20 {
|
||||
return types.TaggedUDPPacket{}, fmt.Errorf("packet data too short")
|
||||
}
|
||||
|
||||
// Maximum reasonable packet size (1MB)
|
||||
if len(data) > 1024*1024 {
|
||||
return types.TaggedUDPPacket{}, fmt.Errorf("packet data too large")
|
||||
}
|
||||
|
||||
offset := 0
|
||||
|
||||
// Header length
|
||||
if offset+4 > len(data) {
|
||||
return types.TaggedUDPPacket{}, fmt.Errorf("insufficient data for header length")
|
||||
}
|
||||
headerLen := binary.BigEndian.Uint32(data[offset:])
|
||||
offset += 4
|
||||
|
||||
// Validate header length (reasonable limits)
|
||||
if headerLen > 1024 || headerLen == 0 {
|
||||
return types.TaggedUDPPacket{}, fmt.Errorf("invalid header length: %d", headerLen)
|
||||
}
|
||||
|
||||
// Packet ID
|
||||
if offset+8 > len(data) {
|
||||
return types.TaggedUDPPacket{}, fmt.Errorf("insufficient data for packet ID")
|
||||
}
|
||||
packetID := binary.BigEndian.Uint64(data[offset:])
|
||||
offset += 8
|
||||
|
||||
// Timestamp
|
||||
if offset+8 > len(data) {
|
||||
return types.TaggedUDPPacket{}, fmt.Errorf("insufficient data for timestamp")
|
||||
}
|
||||
timestamp := binary.BigEndian.Uint64(data[offset:])
|
||||
offset += 8
|
||||
|
||||
// Client ID
|
||||
if offset+int(headerLen) > len(data) {
|
||||
return types.TaggedUDPPacket{}, fmt.Errorf("insufficient data for client ID")
|
||||
}
|
||||
clientID := string(data[offset : offset+int(headerLen)])
|
||||
offset += int(headerLen)
|
||||
|
||||
// Validate client ID (basic sanitization)
|
||||
if len(clientID) == 0 {
|
||||
return types.TaggedUDPPacket{}, fmt.Errorf("empty client ID")
|
||||
}
|
||||
// Check for reasonable client ID format
|
||||
if len(clientID) > 256 {
|
||||
return types.TaggedUDPPacket{}, fmt.Errorf("client ID too long")
|
||||
}
|
||||
|
||||
// Data
|
||||
if offset > len(data) {
|
||||
return types.TaggedUDPPacket{}, fmt.Errorf("data offset exceeds packet length")
|
||||
}
|
||||
packetData := make([]byte, len(data)-offset)
|
||||
copy(packetData, data[offset:])
|
||||
|
||||
return types.TaggedUDPPacket{
|
||||
Header: types.UDPPacketHeader{
|
||||
ClientID: clientID,
|
||||
PacketID: packetID,
|
||||
Timestamp: int64(timestamp),
|
||||
},
|
||||
Data: packetData,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// forwardData forwards data between two connections
|
||||
func (ts *TeleportServer) forwardData(src, dst net.Conn) {
|
||||
buffer := make([]byte, 4096)
|
||||
for {
|
||||
select {
|
||||
case <-ts.ctx.Done():
|
||||
return
|
||||
default:
|
||||
n, err := src.Read(buffer)
|
||||
if err != nil {
|
||||
// Close the destination connection when source closes
|
||||
dst.Close()
|
||||
return
|
||||
}
|
||||
|
||||
_, err = dst.Write(buffer[:n])
|
||||
if err != nil {
|
||||
// Close the source connection when destination closes
|
||||
src.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// readRequest reads a port forward request from the connection
|
||||
func (ts *TeleportServer) readRequest(conn net.Conn, request *types.PortForwardRequest) error {
|
||||
// Read request length
|
||||
var length uint32
|
||||
if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate request length (prevent buffer overflow attacks)
|
||||
if length == 0 {
|
||||
return fmt.Errorf("request length cannot be zero")
|
||||
}
|
||||
if length > 32*1024 { // Reduced to 32KB limit for better security
|
||||
return fmt.Errorf("request too large: %d bytes (max 32KB)", length)
|
||||
}
|
||||
|
||||
// Read encrypted request data with timeout
|
||||
conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
encryptedData := make([]byte, length)
|
||||
bytesRead := 0
|
||||
for bytesRead < int(length) {
|
||||
n, err := conn.Read(encryptedData[bytesRead:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read request data: %v", err)
|
||||
}
|
||||
bytesRead += n
|
||||
}
|
||||
|
||||
// Decrypt the data
|
||||
key := encryption.DeriveKey(ts.config.EncryptionKey)
|
||||
decryptedData, err := encryption.DecryptData(encryptedData, key)
|
||||
if err != nil {
|
||||
logger.WithFields(map[string]interface{}{
|
||||
"error": err,
|
||||
"encrypted_data_length": len(encryptedData),
|
||||
}).Error("Decryption failed")
|
||||
return err
|
||||
}
|
||||
|
||||
// Deserialize the request
|
||||
return ts.deserializeRequest(decryptedData, request)
|
||||
}
|
||||
|
||||
// deserializeRequest deserializes a port forward request
|
||||
func (ts *TeleportServer) deserializeRequest(data []byte, request *types.PortForwardRequest) error {
|
||||
// Minimum size: 4 (localPort) + 4 (remotePort) + 4 (protocolLen) + 4 (targetHostLen) = 16 bytes
|
||||
if len(data) < 16 {
|
||||
return fmt.Errorf("request data too short")
|
||||
}
|
||||
|
||||
// Maximum reasonable request size (64KB)
|
||||
if len(data) > 64*1024 {
|
||||
return fmt.Errorf("request data too large")
|
||||
}
|
||||
|
||||
offset := 0
|
||||
|
||||
// Local port
|
||||
if offset+4 > len(data) {
|
||||
return fmt.Errorf("insufficient data for local port")
|
||||
}
|
||||
request.LocalPort = int(binary.BigEndian.Uint32(data[offset:]))
|
||||
offset += 4
|
||||
|
||||
// Validate port range
|
||||
if request.LocalPort < 1 || request.LocalPort > 65535 {
|
||||
return fmt.Errorf("invalid local port: %d", request.LocalPort)
|
||||
}
|
||||
|
||||
// Remote port
|
||||
if offset+4 > len(data) {
|
||||
return fmt.Errorf("insufficient data for remote port")
|
||||
}
|
||||
request.RemotePort = int(binary.BigEndian.Uint32(data[offset:]))
|
||||
offset += 4
|
||||
|
||||
// Validate port range
|
||||
if request.RemotePort < 1 || request.RemotePort > 65535 {
|
||||
return fmt.Errorf("invalid remote port: %d", request.RemotePort)
|
||||
}
|
||||
|
||||
// Protocol length
|
||||
if offset+4 > len(data) {
|
||||
return fmt.Errorf("insufficient data for protocol length")
|
||||
}
|
||||
protocolLen := int(binary.BigEndian.Uint32(data[offset:]))
|
||||
offset += 4
|
||||
|
||||
// Validate protocol length
|
||||
if protocolLen < 1 || protocolLen > 64 {
|
||||
return fmt.Errorf("invalid protocol length: %d", protocolLen)
|
||||
}
|
||||
|
||||
if offset+protocolLen > len(data) {
|
||||
return fmt.Errorf("insufficient data for protocol")
|
||||
}
|
||||
|
||||
request.Protocol = string(data[offset : offset+protocolLen])
|
||||
offset += protocolLen
|
||||
|
||||
// Validate protocol
|
||||
if request.Protocol != "tcp" && request.Protocol != "udp" {
|
||||
return fmt.Errorf("invalid protocol: %s", request.Protocol)
|
||||
}
|
||||
|
||||
// Target host length
|
||||
if offset+4 > len(data) {
|
||||
return fmt.Errorf("insufficient data for target host length")
|
||||
}
|
||||
targetHostLen := int(binary.BigEndian.Uint32(data[offset:]))
|
||||
offset += 4
|
||||
|
||||
// Validate target host length
|
||||
if targetHostLen > 256 {
|
||||
return fmt.Errorf("target host too long: %d", targetHostLen)
|
||||
}
|
||||
|
||||
if offset+targetHostLen > len(data) {
|
||||
return fmt.Errorf("insufficient data for target host")
|
||||
}
|
||||
|
||||
request.TargetHost = string(data[offset : offset+targetHostLen])
|
||||
|
||||
// Basic validation of target host (if not empty)
|
||||
if request.TargetHost != "" {
|
||||
if len(request.TargetHost) > 253 { // RFC 1123 limit
|
||||
return fmt.Errorf("target host exceeds maximum length")
|
||||
}
|
||||
// Basic character validation (no control characters)
|
||||
for _, c := range request.TargetHost {
|
||||
if c < 32 || c > 126 {
|
||||
return fmt.Errorf("invalid character in target host")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user