Add initial project structure with core functionality

- Created a new Go module named 'teleport' for secure port forwarding.
- Added essential files including .gitignore, LICENSE, and README.md with project details.
- Implemented configuration management with YAML support in config package.
- Developed core client and server functionalities for handling port forwarding.
- Introduced DNS server capabilities and integrated logging with sanitization.
- Established rate limiting and metrics tracking for performance monitoring.
- Included comprehensive tests for core components and functionalities.
- Set up CI workflows for automated testing and release management using Gitea actions.
This commit is contained in:
2025-09-20 18:07:08 -05:00
commit d24d1dc5ae
26 changed files with 6065 additions and 0 deletions

785
internal/client/client.go Normal file
View File

@@ -0,0 +1,785 @@
package client
import (
"context"
"encoding/binary"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"teleport/pkg/config"
"teleport/pkg/dns"
"teleport/pkg/encryption"
"teleport/pkg/logger"
"teleport/pkg/types"
)
// TeleportClient represents the client instance
type TeleportClient struct {
config *config.Config
serverConn net.Conn
udpListeners map[int]*net.UDPConn
udpMutex sync.RWMutex
packetCounter uint64
replayProtection *encryption.ReplayProtection
ctx context.Context
cancel context.CancelFunc
connectionPool chan net.Conn
maxPoolSize int
}
// NewTeleportClient creates a new teleport client
func NewTeleportClient(config *config.Config) *TeleportClient {
ctx, cancel := context.WithCancel(context.Background())
maxPoolSize := 10 // Default connection pool size
if config.MaxConnections > 0 {
maxPoolSize = config.MaxConnections / 10 // Use 10% of max connections for pool
if maxPoolSize < 5 {
maxPoolSize = 5
}
}
return &TeleportClient{
config: config,
udpListeners: make(map[int]*net.UDPConn),
replayProtection: encryption.NewReplayProtection(),
ctx: ctx,
cancel: cancel,
connectionPool: make(chan net.Conn, maxPoolSize),
maxPoolSize: maxPoolSize,
}
}
// Start starts the teleport client
func (tc *TeleportClient) Start() error {
logger.WithField("remote_address", tc.config.RemoteAddress).Info("Starting teleport client")
// Connect to server
conn, err := net.Dial("tcp", tc.config.RemoteAddress)
if err != nil {
return fmt.Errorf("failed to connect to server: %v", err)
}
tc.serverConn = conn
// Start DNS server if enabled
if tc.config.DNSServer.Enabled {
go dns.StartDNSServer(tc.config)
}
// Start port forwarding for each port rule
for _, rule := range tc.config.Ports {
go tc.startPortForwarding(rule)
}
// Keep the client running with proper error handling
<-tc.ctx.Done()
logger.Info("Client shutting down...")
return tc.ctx.Err()
}
// startPortForwarding starts port forwarding for a specific rule
func (tc *TeleportClient) startPortForwarding(rule config.PortRule) {
switch rule.Protocol {
case "tcp":
tc.startTCPForwarding(rule)
case "udp":
tc.startUDPForwarding(rule)
}
}
// startTCPForwarding starts TCP port forwarding
func (tc *TeleportClient) startTCPForwarding(rule config.PortRule) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", rule.LocalPort))
if err != nil {
logger.WithFields(map[string]interface{}{
"port": rule.LocalPort,
"error": err,
}).Error("Failed to start TCP listener")
return
}
defer listener.Close()
logger.WithFields(map[string]interface{}{
"local_port": rule.LocalPort,
"remote_addr": tc.config.RemoteAddress,
"remote_port": rule.RemotePort,
"protocol": rule.Protocol,
}).Info("TCP forwarding started")
for {
select {
case <-tc.ctx.Done():
logger.Debug("Client shutting down...")
return
default:
clientConn, err := listener.Accept()
if err != nil {
select {
case <-tc.ctx.Done():
return
default:
logger.WithField("error", err).Error("Failed to accept TCP connection")
continue
}
}
go tc.handleTCPConnection(clientConn, rule)
}
}
}
// Stop stops the teleport client
func (tc *TeleportClient) Stop() {
logger.Info("Stopping teleport client...")
tc.cancel()
if tc.serverConn != nil {
tc.serverConn.Close()
}
// Close all connections in the pool
for {
select {
case conn := <-tc.connectionPool:
conn.Close()
default:
goto poolClosed
}
}
poolClosed:
// Close all UDP listeners
tc.udpMutex.Lock()
for _, conn := range tc.udpListeners {
conn.Close()
}
tc.udpMutex.Unlock()
}
// getConnection gets a connection from the pool or creates a new one
func (tc *TeleportClient) getConnection() (net.Conn, error) {
// Try to get a healthy connection from the pool
for {
select {
case conn := <-tc.connectionPool:
if conn != nil && tc.isConnectionHealthy(conn) {
return conn, nil
}
// Connection is nil or unhealthy, close it and try again
if conn != nil {
conn.Close()
}
default:
// No connection in pool, create new one
break
}
break
}
// Create new connection
conn, err := net.Dial("tcp", tc.config.RemoteAddress)
if err != nil {
return nil, fmt.Errorf("failed to create connection: %v", err)
}
// Set connection timeouts
if tc.config.ReadTimeout > 0 {
conn.SetReadDeadline(time.Now().Add(tc.config.ReadTimeout))
}
if tc.config.WriteTimeout > 0 {
conn.SetWriteDeadline(time.Now().Add(tc.config.WriteTimeout))
}
return conn, nil
}
// isConnectionHealthy checks if a connection is still alive and usable
func (tc *TeleportClient) isConnectionHealthy(conn net.Conn) bool {
if conn == nil {
return false
}
// Try to set a very short read deadline to test the connection
// This is a non-blocking way to check if the connection is still alive
conn.SetReadDeadline(time.Now().Add(1 * time.Millisecond))
// Try to read one byte (this will fail immediately if connection is dead)
one := make([]byte, 1)
_, err := conn.Read(one)
// Clear the deadline
conn.SetReadDeadline(time.Time{})
// If we get an error, the connection is likely dead
// We expect to get a timeout error for a healthy connection (since we set a 1ms deadline)
if err != nil {
// Check if it's a timeout error (which is expected for a healthy connection)
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return true // Connection is healthy (timeout is expected)
}
// Any other error means the connection is dead
return false
}
// If we actually read data, that's unexpected but the connection is alive
return true
}
// returnConnection returns a connection to the pool
func (tc *TeleportClient) returnConnection(conn net.Conn) {
if conn == nil {
return
}
// Check if connection is still healthy before returning to pool
if !tc.isConnectionHealthy(conn) {
conn.Close()
return
}
select {
case tc.connectionPool <- conn:
// Connection returned to pool
case <-tc.ctx.Done():
// Context cancelled, close the connection
conn.Close()
default:
// Pool is full, close the connection
conn.Close()
}
}
// startUDPForwarding starts UDP port forwarding
func (tc *TeleportClient) startUDPForwarding(rule config.PortRule) {
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", rule.LocalPort))
if err != nil {
logger.WithFields(map[string]interface{}{
"port": rule.LocalPort,
"error": err,
}).Error("Failed to resolve UDP address")
return
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
logger.WithFields(map[string]interface{}{
"port": rule.LocalPort,
"error": err,
}).Error("Failed to start UDP listener")
return
}
tc.udpMutex.Lock()
tc.udpListeners[rule.LocalPort] = conn
tc.udpMutex.Unlock()
logger.WithFields(map[string]interface{}{
"local_port": rule.LocalPort,
"remote_addr": tc.config.RemoteAddress,
"remote_port": rule.RemotePort,
"protocol": rule.Protocol,
}).Info("UDP forwarding started")
buffer := make([]byte, 4096)
for {
select {
case <-tc.ctx.Done():
logger.Debug("Client context cancelled, stopping UDP forwarding")
return
default:
n, clientAddr, err := conn.ReadFromUDP(buffer)
if err != nil {
select {
case <-tc.ctx.Done():
return
default:
logger.WithField("error", err).Error("Failed to read UDP packet")
continue
}
}
// Create tagged packet with atomic counter
packetID := atomic.AddUint64(&tc.packetCounter, 1)
taggedPacket := types.TaggedUDPPacket{
Header: types.UDPPacketHeader{
ClientID: tc.config.InstanceID,
PacketID: packetID,
Timestamp: time.Now().Unix(),
},
Data: make([]byte, n),
}
copy(taggedPacket.Data, buffer[:n])
logger.WithFields(map[string]interface{}{
"client": clientAddr,
"data_length": len(taggedPacket.Data),
"packetID": taggedPacket.Header.PacketID,
}).Debug("UDP packet received")
// Send to server and wait for response
go tc.sendTaggedUDPPacketWithResponse(taggedPacket, conn, clientAddr)
}
}
}
// sendTaggedUDPPacketWithResponse sends a tagged UDP packet to the server and forwards response back
func (tc *TeleportClient) sendTaggedUDPPacketWithResponse(packet types.TaggedUDPPacket, udpConn *net.UDPConn, clientAddr *net.UDPAddr) {
logger.WithField("packetID", packet.Header.PacketID).Debug("Starting to send UDP packet to server")
// Establish a new connection to the server for this UDP packet
serverConn, err := net.Dial("tcp", tc.config.RemoteAddress)
if err != nil {
logger.WithField("error", err).Debug("UDP CLIENT: Failed to connect to server for UDP packet")
return
}
defer serverConn.Close()
logger.WithField("packetID", packet.Header.PacketID).Debug("UDP CLIENT: Connected to server, sending port forward request")
// Find the UDP port rule to get the correct remote port
var remotePort int
for _, rule := range tc.config.Ports {
if rule.Protocol == "udp" {
remotePort = rule.RemotePort
break
}
}
// Send port forward request first (like TCP does)
request := types.PortForwardRequest{
LocalPort: int(packet.Header.PacketID), // Use packet ID as local port for identification
RemotePort: remotePort, // UDP port we want to forward to
Protocol: "udp",
TargetHost: "",
}
if err := tc.sendRequestToConnection(serverConn, request); err != nil {
logger.WithFields(map[string]interface{}{
"packetID": packet.Header.PacketID,
"error": err,
}).Debug("UDP CLIENT: Failed to send port forward request")
return
}
logger.WithField("packetID", packet.Header.PacketID).Debug("UDP CLIENT: Port forward request sent, sending packet")
// Send UDP packet through the new connection
tc.sendTaggedUDPPacketToConnection(serverConn, packet)
logger.WithField("packetID", packet.Header.PacketID).Debug("UDP CLIENT: Packet sent, waiting for response")
// Wait for response from server and forward it back
tc.waitForUDPResponseAndForward(serverConn, packet.Header.PacketID, udpConn, clientAddr)
}
// sendTaggedUDPPacketToConnection sends a tagged UDP packet to a specific connection
func (tc *TeleportClient) sendTaggedUDPPacketToConnection(conn net.Conn, packet types.TaggedUDPPacket) {
// Serialize the tagged packet
data, err := tc.serializeTaggedUDPPacket(packet)
if err != nil {
logger.WithFields(map[string]interface{}{
"packetID": packet.Header.PacketID,
"error": err,
}).Debug("UDP CLIENT: Failed to serialize tagged UDP packet")
return
}
logger.WithFields(map[string]interface{}{
"packetID": packet.Header.PacketID,
"data_length": len(data),
}).Debug("UDP CLIENT: Serialized packet")
// Encrypt the data
key := encryption.DeriveKey(tc.config.EncryptionKey)
encryptedData, err := encryption.EncryptData(data, key)
if err != nil {
logger.WithFields(map[string]interface{}{
"packetID": packet.Header.PacketID,
"error": err,
}).Debug("UDP CLIENT: Failed to encrypt UDP packet")
return
}
logger.WithFields(map[string]interface{}{
"packetID": packet.Header.PacketID,
"encrypted_length": len(encryptedData),
}).Debug("UDP CLIENT: Encrypted packet")
// Send to server
_, err = conn.Write(encryptedData)
if err != nil {
logger.WithFields(map[string]interface{}{
"packetID": packet.Header.PacketID,
"error": err,
}).Debug("UDP CLIENT: Failed to send UDP packet to server")
} else {
logger.WithField("packetID", packet.Header.PacketID).Debug("UDP CLIENT: Successfully sent packet to server")
}
}
// waitForUDPResponseAndForward waits for a UDP response from the server and forwards it back
func (tc *TeleportClient) waitForUDPResponseAndForward(conn net.Conn, expectedPacketID uint64, udpConn *net.UDPConn, clientAddr *net.UDPAddr) {
logger.WithField("packetID", expectedPacketID).Debug("UDP CLIENT: Waiting for response to packet")
buffer := make([]byte, 4096)
for {
select {
case <-tc.ctx.Done():
logger.WithField("packetID", expectedPacketID).Debug("UDP CLIENT: Context cancelled while waiting for response")
return
default:
// Set a timeout for reading the response
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
n, err := conn.Read(buffer)
if err != nil {
logger.WithFields(map[string]interface{}{
"packetID": expectedPacketID,
"error": err,
}).Debug("UDP CLIENT: Failed to read UDP response")
return
}
logger.WithFields(map[string]interface{}{
"packetID": expectedPacketID,
"bytes_received": n,
}).Debug("UDP CLIENT: Received response bytes")
// Decrypt the response
key := encryption.DeriveKey(tc.config.EncryptionKey)
decryptedData, err := encryption.DecryptData(buffer[:n], key)
if err != nil {
logger.WithFields(map[string]interface{}{
"packetID": expectedPacketID,
"error": err,
}).Debug("UDP CLIENT: Failed to decrypt UDP response")
continue
}
// Deserialize the response packet
responsePacket, err := tc.deserializeTaggedUDPPacket(decryptedData)
if err != nil {
logger.WithFields(map[string]interface{}{
"packetID": expectedPacketID,
"error": err,
}).Debug("UDP CLIENT: Failed to deserialize UDP response")
continue
}
// Validate packet timestamp
if !encryption.ValidatePacketTimestamp(responsePacket.Header.Timestamp) {
logger.WithFields(map[string]interface{}{
"packetID": responsePacket.Header.PacketID,
"timestamp": responsePacket.Header.Timestamp,
}).Debug("UDP CLIENT: Response packet timestamp validation failed")
continue
}
// Check for replay attacks
if !tc.replayProtection.IsValidNonce(responsePacket.Header.PacketID, responsePacket.Header.Timestamp) {
logger.WithFields(map[string]interface{}{
"packetID": responsePacket.Header.PacketID,
"timestamp": responsePacket.Header.Timestamp,
}).Debug("UDP CLIENT: Replay attack detected in response")
continue
}
logger.WithFields(map[string]interface{}{
"packetID": responsePacket.Header.PacketID,
"data_length": len(responsePacket.Data),
}).Debug("UDP CLIENT: Deserialized response packet")
// Check if this is the response we're waiting for
if responsePacket.Header.PacketID == expectedPacketID {
// Forward the response back to the original UDP client
_, err := udpConn.WriteToUDP(responsePacket.Data, clientAddr)
if err != nil {
logger.WithFields(map[string]interface{}{
"packetID": responsePacket.Header.PacketID,
"error": err,
}).Debug("UDP CLIENT: Failed to forward UDP response to client")
} else {
logger.WithFields(map[string]interface{}{
"packetID": responsePacket.Header.PacketID,
"client": clientAddr,
"data_length": len(responsePacket.Data),
}).Debug("UDP CLIENT: Successfully forwarded UDP response")
}
return
} else {
logger.WithFields(map[string]interface{}{
"received_packetID": responsePacket.Header.PacketID,
"expected_packetID": expectedPacketID,
}).Debug("UDP CLIENT: Received response for different packet")
}
}
}
}
// serializeTaggedUDPPacket serializes a tagged UDP packet
func (tc *TeleportClient) serializeTaggedUDPPacket(packet types.TaggedUDPPacket) ([]byte, error) {
// Simple serialization: header length + header + data
headerBytes := []byte(packet.Header.ClientID)
headerLen := len(headerBytes)
data := make([]byte, 4+8+8+headerLen+len(packet.Data))
offset := 0
// Header length
binary.BigEndian.PutUint32(data[offset:], uint32(headerLen))
offset += 4
// Packet ID
binary.BigEndian.PutUint64(data[offset:], packet.Header.PacketID)
offset += 8
// Timestamp
binary.BigEndian.PutUint64(data[offset:], uint64(packet.Header.Timestamp))
offset += 8
// Client ID
copy(data[offset:], headerBytes)
offset += headerLen
// Data
copy(data[offset:], packet.Data)
return data, nil
}
// deserializeTaggedUDPPacket deserializes a tagged UDP packet
func (tc *TeleportClient) deserializeTaggedUDPPacket(data []byte) (types.TaggedUDPPacket, error) {
// Minimum size: 4 (headerLen) + 8 (packetID) + 8 (timestamp) = 20 bytes
if len(data) < 20 {
return types.TaggedUDPPacket{}, fmt.Errorf("packet data too short")
}
// Maximum reasonable packet size (1MB)
if len(data) > 1024*1024 {
return types.TaggedUDPPacket{}, fmt.Errorf("packet data too large")
}
offset := 0
// Header length
if offset+4 > len(data) {
return types.TaggedUDPPacket{}, fmt.Errorf("insufficient data for header length")
}
headerLen := binary.BigEndian.Uint32(data[offset:])
offset += 4
// Validate header length (reasonable limits)
if headerLen > 1024 || headerLen == 0 {
return types.TaggedUDPPacket{}, fmt.Errorf("invalid header length: %d", headerLen)
}
// Packet ID
if offset+8 > len(data) {
return types.TaggedUDPPacket{}, fmt.Errorf("insufficient data for packet ID")
}
packetID := binary.BigEndian.Uint64(data[offset:])
offset += 8
// Timestamp
if offset+8 > len(data) {
return types.TaggedUDPPacket{}, fmt.Errorf("insufficient data for timestamp")
}
timestamp := binary.BigEndian.Uint64(data[offset:])
offset += 8
// Validate timestamp is not too old or in the future
now := time.Now().Unix()
if timestamp > uint64(now+300) || timestamp < uint64(now-300) { // 5 minute window
return types.TaggedUDPPacket{}, fmt.Errorf("timestamp out of range: %d (current: %d)", timestamp, now)
}
// Client ID
if offset+int(headerLen) > len(data) {
return types.TaggedUDPPacket{}, fmt.Errorf("insufficient data for client ID")
}
clientID := string(data[offset : offset+int(headerLen)])
offset += int(headerLen)
// Validate client ID (basic sanitization)
if len(clientID) == 0 {
return types.TaggedUDPPacket{}, fmt.Errorf("empty client ID")
}
// Check for reasonable client ID format
if len(clientID) > 256 {
return types.TaggedUDPPacket{}, fmt.Errorf("client ID too long")
}
// Data
if offset > len(data) {
return types.TaggedUDPPacket{}, fmt.Errorf("data offset exceeds packet length")
}
dataLen := len(data) - offset
packetData := make([]byte, dataLen)
copy(packetData, data[offset:])
return types.TaggedUDPPacket{
Header: types.UDPPacketHeader{
ClientID: clientID,
PacketID: packetID,
Timestamp: int64(timestamp),
},
Data: packetData,
}, nil
}
// handleTCPConnection handles a TCP connection from a local client
func (tc *TeleportClient) handleTCPConnection(clientConn net.Conn, rule config.PortRule) {
defer clientConn.Close()
// Get a connection from the pool or create a new one
serverConn, err := tc.getConnection()
if err != nil {
logger.WithField("error", err).Error("Failed to get server connection")
return
}
defer tc.returnConnection(serverConn)
// Send port forward request to server
request := types.PortForwardRequest(rule)
if err := tc.sendRequestToConnection(serverConn, request); err != nil {
logger.WithField("error", err).Error("Failed to send port forward request")
return
}
// Now forward the actual data bidirectionally
var wg sync.WaitGroup
wg.Add(2)
// Forward data from client to server
go func() {
defer wg.Done()
tc.forwardData(clientConn, serverConn)
}()
// Forward data from server to client
go func() {
defer wg.Done()
tc.forwardData(serverConn, clientConn)
}()
wg.Wait()
}
// forwardData forwards data from src to dst
func (tc *TeleportClient) forwardData(src, dst net.Conn) {
buffer := make([]byte, 4096)
for {
select {
case <-tc.ctx.Done():
return
default:
n, err := src.Read(buffer)
if err != nil {
// Close the destination connection when source closes
dst.Close()
return
}
_, err = dst.Write(buffer[:n])
if err != nil {
// Close the source connection when destination closes
src.Close()
return
}
}
}
}
// sendRequest sends a port forward request to the server
func (tc *TeleportClient) sendRequest(request types.PortForwardRequest) error {
return tc.sendRequestToConnection(tc.serverConn, request)
}
// sendRequestToConnection sends a port forward request to a specific connection
func (tc *TeleportClient) sendRequestToConnection(conn net.Conn, request types.PortForwardRequest) error {
// Serialize the request
data, err := tc.serializeRequest(request)
if err != nil {
return err
}
// Encrypt the data
key := encryption.DeriveKey(tc.config.EncryptionKey)
encryptedData, err := encryption.EncryptData(data, key)
if err != nil {
logger.WithFields(map[string]interface{}{
"error": err,
"data_length": len(data),
"request": fmt.Sprintf("%s:%d->%d", request.Protocol, request.LocalPort, request.RemotePort),
}).Error("Encryption failed")
return err
}
logger.WithFields(map[string]interface{}{
"original_length": len(data),
"encrypted_length": len(encryptedData),
"compression_ratio": fmt.Sprintf("%.2f", float64(len(encryptedData))/float64(len(data))),
"request": fmt.Sprintf("%s:%d->%d", request.Protocol, request.LocalPort, request.RemotePort),
}).Debug("Data encrypted successfully")
// Validate request size before sending
if len(encryptedData) == 0 {
return fmt.Errorf("encrypted data cannot be empty")
}
if len(encryptedData) > 64*1024 { // 64KB limit
return fmt.Errorf("request too large: %d bytes", len(encryptedData))
}
// Send request length
length := uint32(len(encryptedData))
if err := binary.Write(conn, binary.BigEndian, length); err != nil {
return err
}
// Send encrypted request data
bytesWritten := 0
for bytesWritten < len(encryptedData) {
n, err := conn.Write(encryptedData[bytesWritten:])
if err != nil {
return fmt.Errorf("failed to write request data: %v", err)
}
bytesWritten += n
}
return nil
}
// serializeRequest serializes a port forward request
func (tc *TeleportClient) serializeRequest(request types.PortForwardRequest) ([]byte, error) {
protocolBytes := []byte(request.Protocol)
protocolLen := len(protocolBytes)
targetHostBytes := []byte(request.TargetHost)
targetHostLen := len(targetHostBytes)
data := make([]byte, 4+4+4+4+protocolLen+targetHostLen)
offset := 0
// Local port
binary.BigEndian.PutUint32(data[offset:], uint32(request.LocalPort))
offset += 4
// Remote port
binary.BigEndian.PutUint32(data[offset:], uint32(request.RemotePort))
offset += 4
// Protocol length
binary.BigEndian.PutUint32(data[offset:], uint32(protocolLen))
offset += 4
// Protocol
copy(data[offset:], protocolBytes)
offset += protocolLen
// Target host length
binary.BigEndian.PutUint32(data[offset:], uint32(targetHostLen))
offset += 4
// Target host
copy(data[offset:], targetHostBytes)
return data, nil
}

869
internal/server/server.go Normal file
View File

@@ -0,0 +1,869 @@
package server
import (
"context"
"encoding/binary"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"teleport/pkg/config"
"teleport/pkg/encryption"
"teleport/pkg/logger"
"teleport/pkg/metrics"
"teleport/pkg/ratelimit"
"teleport/pkg/types"
)
// TeleportServer represents the server instance
type TeleportServer struct {
config *config.Config
listener net.Listener
udpListeners map[int]*net.UDPConn
udpClients map[string]*net.UDPConn
udpMutex sync.RWMutex
packetCounter uint64
replayProtection *encryption.ReplayProtection
ctx context.Context
cancel context.CancelFunc
connectionSem chan struct{} // Semaphore for limiting concurrent connections
activeConnections int64 // Atomic counter for active connections
maxConnections int // Maximum concurrent connections
rateLimiter *ratelimit.RateLimiter
goroutineSem chan struct{} // Semaphore for limiting concurrent goroutines
maxGoroutines int // Maximum concurrent goroutines
metrics *metrics.Metrics
}
// NewTeleportServer creates a new teleport server
func NewTeleportServer(config *config.Config) *TeleportServer {
ctx, cancel := context.WithCancel(context.Background())
maxConnections := 1000 // Default maximum concurrent connections
if config.MaxConnections > 0 {
maxConnections = config.MaxConnections
}
maxGoroutines := maxConnections * 2 // Allow 2x connections for goroutines
if maxGoroutines > 10000 {
maxGoroutines = 10000 // Cap at 10k goroutines
}
// Initialize rate limiter if enabled
var rateLimiter *ratelimit.RateLimiter
if config.RateLimit.Enabled {
rateLimiter = ratelimit.NewRateLimiter(
config.RateLimit.RequestsPerSecond,
config.RateLimit.BurstSize,
config.RateLimit.WindowSize,
)
}
// Initialize metrics
metricsInstance := metrics.GetMetrics()
return &TeleportServer{
config: config,
udpListeners: make(map[int]*net.UDPConn),
udpClients: make(map[string]*net.UDPConn),
replayProtection: encryption.NewReplayProtection(),
ctx: ctx,
cancel: cancel,
connectionSem: make(chan struct{}, maxConnections),
maxConnections: maxConnections,
rateLimiter: rateLimiter,
goroutineSem: make(chan struct{}, maxGoroutines),
maxGoroutines: maxGoroutines,
metrics: metricsInstance,
}
}
// Start starts the teleport server
func (ts *TeleportServer) Start() error {
logger.WithField("listen_address", ts.config.ListenAddress).Info("Starting teleport server")
// Start TCP listener
listener, err := net.Listen("tcp", ts.config.ListenAddress)
if err != nil {
return fmt.Errorf("failed to start TCP listener: %v", err)
}
ts.listener = listener
// Start UDP listeners for each port
for _, rule := range ts.config.Ports {
if rule.Protocol == "udp" {
go ts.startUDPListener(rule)
}
}
// Accept TCP connections
for {
select {
case <-ts.ctx.Done():
logger.Debug("Server shutting down...")
return ts.ctx.Err()
default:
conn, err := ts.listener.Accept()
if err != nil {
select {
case <-ts.ctx.Done():
return ts.ctx.Err()
default:
logger.WithField("error", err).Error("Failed to accept connection")
continue
}
}
// Check rate limiting first
if ts.rateLimiter != nil && !ts.rateLimiter.Allow() {
logger.WithField("client", conn.RemoteAddr()).Warn("Rate limit exceeded, rejecting connection")
ts.metrics.IncrementRateLimitedRequests()
conn.Close()
continue
}
// Check if we can accept more connections
select {
case ts.connectionSem <- struct{}{}:
// Connection accepted
atomic.AddInt64(&ts.activeConnections, 1)
ts.metrics.IncrementTotalConnections()
ts.metrics.IncrementActiveConnections()
// Check goroutine limit
select {
case ts.goroutineSem <- struct{}{}:
go ts.handleConnectionWithLimit(conn)
default:
// Goroutine limit reached
<-ts.connectionSem // Release connection semaphore
atomic.AddInt64(&ts.activeConnections, -1)
ts.metrics.DecrementActiveConnections()
ts.metrics.IncrementRejectedConnections()
logger.WithField("max_goroutines", ts.maxGoroutines).Warn("Goroutine limit reached, rejecting connection")
conn.Close()
}
default:
// Connection limit reached
ts.metrics.IncrementRejectedConnections()
logger.WithField("max_connections", ts.maxConnections).Warn("Connection limit reached, rejecting connection")
conn.Close()
}
}
}
}
// Stop stops the teleport server
func (ts *TeleportServer) Stop() {
logger.Info("Stopping teleport server...")
ts.cancel()
if ts.listener != nil {
ts.listener.Close()
}
// Close all UDP listeners
ts.udpMutex.Lock()
for _, conn := range ts.udpListeners {
conn.Close()
}
ts.udpMutex.Unlock()
}
// startUDPListener starts a UDP listener for a specific port
func (ts *TeleportServer) startUDPListener(rule config.PortRule) {
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", rule.LocalPort))
if err != nil {
logger.WithFields(map[string]interface{}{
"port": rule.LocalPort,
"error": err,
}).Error("Failed to resolve UDP address")
return
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
logger.WithFields(map[string]interface{}{
"port": rule.LocalPort,
"error": err,
}).Error("Failed to start UDP listener")
return
}
ts.udpMutex.Lock()
ts.udpListeners[rule.LocalPort] = conn
ts.udpMutex.Unlock()
logger.WithField("port", rule.LocalPort).Info("UDP listener started")
buffer := make([]byte, 4096)
for {
n, clientAddr, err := conn.ReadFromUDP(buffer)
if err != nil {
logger.WithField("error", err).Error("Failed to read UDP packet")
continue
}
// Create UDP request
request := types.UDPRequest{
LocalPort: rule.LocalPort,
RemotePort: rule.RemotePort,
Protocol: rule.Protocol,
Data: make([]byte, n),
ClientAddr: clientAddr,
}
copy(request.Data, buffer[:n])
// Forward to all connected clients
go ts.forwardUDPRequest(request)
}
}
// forwardUDPRequest forwards a UDP request to all connected clients
func (ts *TeleportServer) forwardUDPRequest(request types.UDPRequest) {
ts.udpMutex.RLock()
defer ts.udpMutex.RUnlock()
// Create tagged packet with atomic counter
packetID := atomic.AddUint64(&ts.packetCounter, 1)
taggedPacket := types.TaggedUDPPacket{
Header: types.UDPPacketHeader{
ClientID: fmt.Sprintf("udp-%s", request.ClientAddr.String()),
PacketID: packetID,
Timestamp: time.Now().Unix(),
},
Data: request.Data,
}
// Forward to all connected UDP clients
for clientID, clientConn := range ts.udpClients {
if clientID != taggedPacket.Header.ClientID {
go ts.sendTaggedUDPPacket(clientConn, taggedPacket)
}
}
}
// sendTaggedUDPPacket sends a tagged UDP packet to a client
func (ts *TeleportServer) sendTaggedUDPPacket(clientConn *net.UDPConn, packet types.TaggedUDPPacket) {
// Serialize the tagged packet
data, err := ts.serializeTaggedUDPPacket(packet)
if err != nil {
logger.WithField("error", err).Debug("Failed to serialize tagged UDP packet")
return
}
// Encrypt the data
key := encryption.DeriveKey(ts.config.EncryptionKey)
encryptedData, err := encryption.EncryptData(data, key)
if err != nil {
logger.WithField("error", err).Debug("Failed to encrypt UDP packet")
return
}
// Send to client
_, err = clientConn.Write(encryptedData)
if err != nil {
logger.WithField("error", err).Debug("Failed to send UDP packet to client")
}
}
// serializeTaggedUDPPacket serializes a tagged UDP packet
func (ts *TeleportServer) serializeTaggedUDPPacket(packet types.TaggedUDPPacket) ([]byte, error) {
// Simple serialization: header length + header + data
headerBytes := []byte(packet.Header.ClientID)
headerLen := len(headerBytes)
data := make([]byte, 4+8+8+headerLen+len(packet.Data))
offset := 0
// Header length
binary.BigEndian.PutUint32(data[offset:], uint32(headerLen))
offset += 4
// Packet ID
binary.BigEndian.PutUint64(data[offset:], packet.Header.PacketID)
offset += 8
// Timestamp
binary.BigEndian.PutUint64(data[offset:], uint64(packet.Header.Timestamp))
offset += 8
// Client ID
copy(data[offset:], headerBytes)
offset += headerLen
// Data
copy(data[offset:], packet.Data)
return data, nil
}
// handleConnectionWithLimit handles a TCP connection with connection limit management
func (ts *TeleportServer) handleConnectionWithLimit(conn net.Conn) {
defer func() {
conn.Close()
<-ts.connectionSem // Release connection semaphore
<-ts.goroutineSem // Release goroutine semaphore
atomic.AddInt64(&ts.activeConnections, -1)
ts.metrics.DecrementActiveConnections()
}()
// Set connection timeouts
if ts.config.ReadTimeout > 0 {
conn.SetReadDeadline(time.Now().Add(ts.config.ReadTimeout))
}
if ts.config.WriteTimeout > 0 {
conn.SetWriteDeadline(time.Now().Add(ts.config.WriteTimeout))
}
logger.WithFields(map[string]interface{}{
"client": conn.RemoteAddr(),
"active_connections": atomic.LoadInt64(&ts.activeConnections),
}).Info("New connection")
ts.handleConnection(conn)
}
// handleConnection handles a TCP connection from a client
func (ts *TeleportServer) handleConnection(conn net.Conn) {
// Don't set timeouts on the initial connection - let the individual handlers set them
// Read the port forward request (only one per connection)
var request types.PortForwardRequest
if err := ts.readRequest(conn, &request); err != nil {
logger.WithField("error", err).Error("Failed to read request")
return
}
// Find matching port rule
// The client sends LocalPort (client's local port) and RemotePort (server's port to forward to)
// We need to find a server rule that matches the RemotePort the client wants
var portRule *config.PortRule
for _, rule := range ts.config.Ports {
if rule.LocalPort == request.RemotePort && rule.Protocol == request.Protocol {
portRule = &rule
break
}
}
if portRule == nil {
logger.WithFields(map[string]interface{}{
"local_port": request.LocalPort,
"remote_port": request.RemotePort,
"protocol": request.Protocol,
}).Warn("No matching port rule found")
return
}
// Handle based on protocol
logger.WithFields(map[string]interface{}{
"protocol": request.Protocol,
"port": request.RemotePort,
}).Info("Handling connection")
switch request.Protocol {
case "tcp":
logger.Debug("Routing to TCP handler")
ts.handleTCPForward(conn, portRule)
case "udp":
logger.Debug("Routing to UDP handler")
ts.handleUDPForward(conn, portRule)
}
}
// handleTCPForward handles TCP port forwarding
func (ts *TeleportServer) handleTCPForward(clientConn net.Conn, rule *config.PortRule) {
// Determine target host (default to localhost if not specified)
targetHost := rule.TargetHost
if targetHost == "" {
targetHost = "localhost"
}
// Connect to the target service
targetAddr := net.JoinHostPort(targetHost, fmt.Sprintf("%d", rule.RemotePort))
targetConn, err := net.Dial("tcp", targetAddr)
if err != nil {
logger.WithFields(map[string]interface{}{
"target": targetAddr,
"error": err,
}).Error("Failed to connect to target")
return
}
defer targetConn.Close()
logger.WithFields(map[string]interface{}{
"client": clientConn.RemoteAddr(),
"target": targetAddr,
}).Info("TCP forwarding")
// Start bidirectional forwarding
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
ts.forwardData(clientConn, targetConn)
}()
go func() {
defer wg.Done()
ts.forwardData(targetConn, clientConn)
}()
wg.Wait()
}
// handleUDPForward handles UDP port forwarding
func (ts *TeleportServer) handleUDPForward(clientConn net.Conn, rule *config.PortRule) {
logger.WithField("client", clientConn.RemoteAddr()).Debug("UDP SERVER: Starting UDP forwarding")
// Determine target host (default to localhost if not specified)
targetHost := rule.TargetHost
if targetHost == "" {
targetHost = "localhost"
}
// Create UDP connection to target
targetAddr := net.JoinHostPort(targetHost, fmt.Sprintf("%d", rule.RemotePort))
targetConn, err := net.Dial("udp", targetAddr)
if err != nil {
logger.WithFields(map[string]interface{}{
"target": targetAddr,
"error": err,
}).Debug("UDP SERVER: Failed to connect to UDP target")
return
}
defer targetConn.Close()
logger.WithField("target", targetAddr).Debug("UDP SERVER: Connected to UDP target")
// Register this client for UDP forwarding with cleanup
clientID := fmt.Sprintf("tcp-%s", clientConn.RemoteAddr().String())
ts.udpMutex.Lock()
// Clean up any existing connection for this client
if existingConn, exists := ts.udpClients[clientID]; exists {
existingConn.Close()
}
ts.udpClients[clientID] = targetConn.(*net.UDPConn)
ts.udpMutex.Unlock()
logger.WithField("client", clientID).Debug("UDP SERVER: UDP forwarding registered")
// Handle UDP packets bidirectionally
var wg sync.WaitGroup
wg.Add(2)
// Channel to pass packet IDs from requests to responses
// Use a larger buffer to prevent blocking under high load
packetIDChan := make(chan uint64, 1000)
// Forward packets from client to target
go func() {
defer wg.Done()
buffer := make([]byte, 4096)
for {
select {
case <-ts.ctx.Done():
logger.Debug("UDP SERVER: Context cancelled in client->target goroutine")
return
default:
n, err := clientConn.Read(buffer)
if err != nil {
logger.WithField("error", err).Debug("UDP SERVER: Failed to read from client")
return
}
logger.WithField("bytes_received", n).Debug("UDP SERVER: Received bytes from client")
// Decrypt the data
key := encryption.DeriveKey(ts.config.EncryptionKey)
decryptedData, err := encryption.DecryptData(buffer[:n], key)
if err != nil {
logger.WithField("error", err).Debug("UDP SERVER: Failed to decrypt UDP packet")
continue
}
logger.WithField("decrypted_bytes", len(decryptedData)).Debug("UDP SERVER: Decrypted bytes")
// Deserialize tagged packet
packet, err := ts.deserializeTaggedUDPPacket(decryptedData)
if err != nil {
logger.WithField("error", err).Debug("UDP SERVER: Failed to deserialize tagged UDP packet")
continue
}
// Validate packet timestamp
if !encryption.ValidatePacketTimestamp(packet.Header.Timestamp) {
logger.WithFields(map[string]interface{}{
"packetID": packet.Header.PacketID,
"timestamp": packet.Header.Timestamp,
}).Debug("UDP SERVER: Packet timestamp validation failed")
continue
}
// Check for replay attacks
if !ts.replayProtection.IsValidNonce(packet.Header.PacketID, packet.Header.Timestamp) {
logger.WithFields(map[string]interface{}{
"packetID": packet.Header.PacketID,
"timestamp": packet.Header.Timestamp,
}).Debug("UDP SERVER: Replay attack detected")
continue
}
logger.WithFields(map[string]interface{}{
"packetID": packet.Header.PacketID,
"data_length": len(packet.Data),
}).Debug("UDP SERVER: Deserialized packet")
// Forward to target
_, err = targetConn.Write(packet.Data)
if err != nil {
logger.WithFields(map[string]interface{}{
"packetID": packet.Header.PacketID,
"error": err,
}).Debug("UDP SERVER: Failed to forward UDP packet to target")
return
}
logger.WithField("packetID", packet.Header.PacketID).Debug("UDP SERVER: Forwarded packet to target")
// Send the packet ID to the response handler
select {
case packetIDChan <- packet.Header.PacketID:
logger.WithField("packetID", packet.Header.PacketID).Debug("UDP SERVER: Sent packet ID to response handler")
default:
logger.WithField("packetID", packet.Header.PacketID).Debug("UDP SERVER: Packet ID channel full, dropping packet ID")
}
}
}
}()
// Forward responses from target to client
go func() {
defer wg.Done()
buffer := make([]byte, 4096)
for {
select {
case <-ts.ctx.Done():
logger.Debug("UDP SERVER: Context cancelled in target->client goroutine")
return
default:
n, err := targetConn.Read(buffer)
if err != nil {
logger.WithField("error", err).Debug("UDP SERVER: Failed to read from target")
return
}
logger.WithField("bytes_received", n).Debug("UDP SERVER: Received bytes from target")
// Get the packet ID from the request
var originalPacketID uint64
select {
case originalPacketID = <-packetIDChan:
logger.WithField("packetID", originalPacketID).Debug("UDP SERVER: Got packet ID for response")
case <-time.After(1 * time.Second):
logger.Debug("UDP SERVER: Timeout waiting for packet ID, using default")
originalPacketID = 0
}
// Create tagged packet for response using the original packet ID
responsePacket := types.TaggedUDPPacket{
Header: types.UDPPacketHeader{
ClientID: clientID,
PacketID: originalPacketID,
Timestamp: time.Now().Unix(),
},
Data: make([]byte, n),
}
copy(responsePacket.Data, buffer[:n])
logger.WithFields(map[string]interface{}{
"packetID": originalPacketID,
"data": string(responsePacket.Data),
}).Debug("UDP SERVER: Created response packet")
// Serialize and encrypt response
data, err := ts.serializeTaggedUDPPacket(responsePacket)
if err != nil {
logger.WithFields(map[string]interface{}{
"packetID": originalPacketID,
"error": err,
}).Debug("UDP SERVER: Failed to serialize UDP response")
continue
}
key := encryption.DeriveKey(ts.config.EncryptionKey)
encryptedData, err := encryption.EncryptData(data, key)
if err != nil {
logger.WithFields(map[string]interface{}{
"packetID": originalPacketID,
"error": err,
}).Debug("UDP SERVER: Failed to encrypt UDP response")
continue
}
logger.WithFields(map[string]interface{}{
"packetID": originalPacketID,
"encrypted_length": len(encryptedData),
}).Debug("UDP SERVER: Encrypted response packet")
// Send response to client
_, err = clientConn.Write(encryptedData)
if err != nil {
logger.WithFields(map[string]interface{}{
"packetID": originalPacketID,
"error": err,
}).Debug("UDP SERVER: Failed to send UDP response to client")
return
}
logger.WithField("packetID", originalPacketID).Debug("UDP SERVER: Successfully sent response packet to client")
}
}
}()
wg.Wait()
// Unregister client
ts.udpMutex.Lock()
delete(ts.udpClients, clientID)
ts.udpMutex.Unlock()
}
// deserializeTaggedUDPPacket deserializes a tagged UDP packet
func (ts *TeleportServer) deserializeTaggedUDPPacket(data []byte) (types.TaggedUDPPacket, error) {
// Minimum size: 4 (headerLen) + 8 (packetID) + 8 (timestamp) = 20 bytes
if len(data) < 20 {
return types.TaggedUDPPacket{}, fmt.Errorf("packet data too short")
}
// Maximum reasonable packet size (1MB)
if len(data) > 1024*1024 {
return types.TaggedUDPPacket{}, fmt.Errorf("packet data too large")
}
offset := 0
// Header length
if offset+4 > len(data) {
return types.TaggedUDPPacket{}, fmt.Errorf("insufficient data for header length")
}
headerLen := binary.BigEndian.Uint32(data[offset:])
offset += 4
// Validate header length (reasonable limits)
if headerLen > 1024 || headerLen == 0 {
return types.TaggedUDPPacket{}, fmt.Errorf("invalid header length: %d", headerLen)
}
// Packet ID
if offset+8 > len(data) {
return types.TaggedUDPPacket{}, fmt.Errorf("insufficient data for packet ID")
}
packetID := binary.BigEndian.Uint64(data[offset:])
offset += 8
// Timestamp
if offset+8 > len(data) {
return types.TaggedUDPPacket{}, fmt.Errorf("insufficient data for timestamp")
}
timestamp := binary.BigEndian.Uint64(data[offset:])
offset += 8
// Client ID
if offset+int(headerLen) > len(data) {
return types.TaggedUDPPacket{}, fmt.Errorf("insufficient data for client ID")
}
clientID := string(data[offset : offset+int(headerLen)])
offset += int(headerLen)
// Validate client ID (basic sanitization)
if len(clientID) == 0 {
return types.TaggedUDPPacket{}, fmt.Errorf("empty client ID")
}
// Check for reasonable client ID format
if len(clientID) > 256 {
return types.TaggedUDPPacket{}, fmt.Errorf("client ID too long")
}
// Data
if offset > len(data) {
return types.TaggedUDPPacket{}, fmt.Errorf("data offset exceeds packet length")
}
packetData := make([]byte, len(data)-offset)
copy(packetData, data[offset:])
return types.TaggedUDPPacket{
Header: types.UDPPacketHeader{
ClientID: clientID,
PacketID: packetID,
Timestamp: int64(timestamp),
},
Data: packetData,
}, nil
}
// forwardData forwards data between two connections
func (ts *TeleportServer) forwardData(src, dst net.Conn) {
buffer := make([]byte, 4096)
for {
select {
case <-ts.ctx.Done():
return
default:
n, err := src.Read(buffer)
if err != nil {
// Close the destination connection when source closes
dst.Close()
return
}
_, err = dst.Write(buffer[:n])
if err != nil {
// Close the source connection when destination closes
src.Close()
return
}
}
}
}
// readRequest reads a port forward request from the connection
func (ts *TeleportServer) readRequest(conn net.Conn, request *types.PortForwardRequest) error {
// Read request length
var length uint32
if err := binary.Read(conn, binary.BigEndian, &length); err != nil {
return err
}
// Validate request length (prevent buffer overflow attacks)
if length == 0 {
return fmt.Errorf("request length cannot be zero")
}
if length > 32*1024 { // Reduced to 32KB limit for better security
return fmt.Errorf("request too large: %d bytes (max 32KB)", length)
}
// Read encrypted request data with timeout
conn.SetReadDeadline(time.Now().Add(10 * time.Second))
encryptedData := make([]byte, length)
bytesRead := 0
for bytesRead < int(length) {
n, err := conn.Read(encryptedData[bytesRead:])
if err != nil {
return fmt.Errorf("failed to read request data: %v", err)
}
bytesRead += n
}
// Decrypt the data
key := encryption.DeriveKey(ts.config.EncryptionKey)
decryptedData, err := encryption.DecryptData(encryptedData, key)
if err != nil {
logger.WithFields(map[string]interface{}{
"error": err,
"encrypted_data_length": len(encryptedData),
}).Error("Decryption failed")
return err
}
// Deserialize the request
return ts.deserializeRequest(decryptedData, request)
}
// deserializeRequest deserializes a port forward request
func (ts *TeleportServer) deserializeRequest(data []byte, request *types.PortForwardRequest) error {
// Minimum size: 4 (localPort) + 4 (remotePort) + 4 (protocolLen) + 4 (targetHostLen) = 16 bytes
if len(data) < 16 {
return fmt.Errorf("request data too short")
}
// Maximum reasonable request size (64KB)
if len(data) > 64*1024 {
return fmt.Errorf("request data too large")
}
offset := 0
// Local port
if offset+4 > len(data) {
return fmt.Errorf("insufficient data for local port")
}
request.LocalPort = int(binary.BigEndian.Uint32(data[offset:]))
offset += 4
// Validate port range
if request.LocalPort < 1 || request.LocalPort > 65535 {
return fmt.Errorf("invalid local port: %d", request.LocalPort)
}
// Remote port
if offset+4 > len(data) {
return fmt.Errorf("insufficient data for remote port")
}
request.RemotePort = int(binary.BigEndian.Uint32(data[offset:]))
offset += 4
// Validate port range
if request.RemotePort < 1 || request.RemotePort > 65535 {
return fmt.Errorf("invalid remote port: %d", request.RemotePort)
}
// Protocol length
if offset+4 > len(data) {
return fmt.Errorf("insufficient data for protocol length")
}
protocolLen := int(binary.BigEndian.Uint32(data[offset:]))
offset += 4
// Validate protocol length
if protocolLen < 1 || protocolLen > 64 {
return fmt.Errorf("invalid protocol length: %d", protocolLen)
}
if offset+protocolLen > len(data) {
return fmt.Errorf("insufficient data for protocol")
}
request.Protocol = string(data[offset : offset+protocolLen])
offset += protocolLen
// Validate protocol
if request.Protocol != "tcp" && request.Protocol != "udp" {
return fmt.Errorf("invalid protocol: %s", request.Protocol)
}
// Target host length
if offset+4 > len(data) {
return fmt.Errorf("insufficient data for target host length")
}
targetHostLen := int(binary.BigEndian.Uint32(data[offset:]))
offset += 4
// Validate target host length
if targetHostLen > 256 {
return fmt.Errorf("target host too long: %d", targetHostLen)
}
if offset+targetHostLen > len(data) {
return fmt.Errorf("insufficient data for target host")
}
request.TargetHost = string(data[offset : offset+targetHostLen])
// Basic validation of target host (if not empty)
if request.TargetHost != "" {
if len(request.TargetHost) > 253 { // RFC 1123 limit
return fmt.Errorf("target host exceeds maximum length")
}
// Basic character validation (no control characters)
for _, c := range request.TargetHost {
if c < 32 || c > 126 {
return fmt.Errorf("invalid character in target host")
}
}
}
return nil
}