Files
teleport/pkg/dns/dns.go
Justin Harms d24d1dc5ae Add initial project structure with core functionality
- Created a new Go module named 'teleport' for secure port forwarding.
- Added essential files including .gitignore, LICENSE, and README.md with project details.
- Implemented configuration management with YAML support in config package.
- Developed core client and server functionalities for handling port forwarding.
- Introduced DNS server capabilities and integrated logging with sanitization.
- Established rate limiting and metrics tracking for performance monitoring.
- Included comprehensive tests for core components and functionalities.
- Set up CI workflows for automated testing and release management using Gitea actions.
2025-09-20 18:07:08 -05:00

330 lines
7.5 KiB
Go

package dns
import (
"fmt"
"net"
"strings"
"time"
"teleport/pkg/config"
"teleport/pkg/logger"
"github.com/miekg/dns"
)
// StartDNSServer starts the built-in DNS server (UDP only) on
// cfg.DNSServer.ListenPort using miekg/dns.
//
// It returns immediately when cfg.DNSServer.Enabled is false. Otherwise it
// blocks in ListenAndServe and logs any error ListenAndServe returns.
// Every query is dispatched to handleDNSQuery.
func StartDNSServer(cfg *config.Config) {
	if !cfg.DNSServer.Enabled {
		return
	}

	// Use a private mux rather than dns.HandleFunc. HandleFunc registers on
	// the package-global dns.DefaultServeMux, which would expose this handler
	// (and the cfg it captures) to any other miekg/dns server in the process.
	mux := dns.NewServeMux()
	mux.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
		handleDNSQuery(w, r, cfg)
	})

	server := &dns.Server{
		Addr:    fmt.Sprintf(":%d", cfg.DNSServer.ListenPort),
		Net:     "udp",
		Handler: mux,
		// Log only once the socket is actually listening. Otherwise a bind
		// failure (e.g. port in use) would print a misleading "started" line
		// just before the error.
		NotifyStartedFunc: func() {
			logger.WithField("port", cfg.DNSServer.ListenPort).Info("DNS server started")
		},
	}

	if err := server.ListenAndServe(); err != nil {
		logger.WithField("error", err).Error("Failed to start DNS server")
	}
}
// handleDNSQuery answers a single DNS query and writes the reply to w.
//
// Resolution order:
//  1. custom records from cfg (checkCustomRecords);
//  2. the configured backup server (forwardToBackupDNS);
//  3. a SERVFAIL reply if neither produced a response.
//
// A failure to write the reply is logged; there is no other recovery.
func handleDNSQuery(w dns.ResponseWriter, r *dns.Msg, cfg *config.Config) {
	response := checkCustomRecords(r, cfg)
	if response == nil {
		response = forwardToBackupDNS(r, cfg)
	}
	if response == nil {
		response = new(dns.Msg)
		response.SetRcode(r, dns.RcodeServerFailure)
	}

	// WriteMsg can fail (e.g. the reply cannot be packed, or the write to the
	// client fails). Surface it instead of silently dropping the error.
	if err := w.WriteMsg(response); err != nil {
		logger.WithField("error", err).Error("Failed to write DNS response")
	}
}
// checkCustomRecords builds an authoritative reply from the configured custom
// records when one or more of them match the first question in query.
//
// A record matches when its name equals the question name and its type
// equals the question type. Names are compared case-insensitively and fully
// qualified, so a config entry "host.example.com" matches the wire name
// "host.example.com.".
//
// It returns nil in these cases:
//   - the query has no question;
//   - the name fails the basic length or printable-ASCII checks;
//   - no valid matching record exists.
//
// In each of these cases the caller falls back to the backup server.
func checkCustomRecords(query *dns.Msg, cfg *config.Config) *dns.Msg {
	if len(query.Question) == 0 {
		return nil
	}
	question := query.Question[0]
	questionName := strings.ToLower(question.Name)

	// Question names are fully qualified (trailing dot), so the longest legal
	// presentation form is 253 characters plus the final dot = 254.
	if len(questionName) > 254 {
		logger.WithField("name", questionName).Warn("DNS query name too long, ignoring")
		return nil
	}

	// Basic character validation: printable ASCII only.
	for _, c := range questionName {
		if c < 32 || c > 126 {
			logger.WithField("name", questionName).Warn("DNS query name contains invalid characters, ignoring")
			return nil
		}
	}

	var answers []dns.RR
	for _, record := range cfg.DNSServer.CustomRecords {
		// dns.Fqdn("") would be "." and match root queries; treat an empty
		// name as never matching, as before.
		if record.Name == "" {
			continue
		}
		// Config names are usually written without the trailing dot that is
		// always present on the wire; normalise before comparing.
		recordName := strings.ToLower(dns.Fqdn(record.Name))
		if recordName != questionName || getRecordType(record.Type) != question.Qtype {
			continue
		}
		if answer := createDNSRecord(record, questionName); answer != nil {
			answers = append(answers, answer)
		}
	}
	if len(answers) == 0 {
		return nil
	}

	response := new(dns.Msg)
	response.SetReply(query)
	response.Authoritative = true
	response.Answer = answers
	return response
}
// getRecordType maps a record-type mnemonic from the config (any letter case)
// to its miekg/dns type code.
//
// Only A, AAAA, CNAME, MX, TXT, NS and SRV are recognised. Any other string
// yields 0, which matches no real query type.
func getRecordType(recordType string) uint16 {
	wanted := strings.ToUpper(recordType)
	supported := [...]struct {
		mnemonic string
		code     uint16
	}{
		{"A", dns.TypeA},
		{"AAAA", dns.TypeAAAA},
		{"CNAME", dns.TypeCNAME},
		{"MX", dns.TypeMX},
		{"TXT", dns.TypeTXT},
		{"NS", dns.TypeNS},
		{"SRV", dns.TypeSRV},
	}
	for _, entry := range supported {
		if entry.mnemonic == wanted {
			return entry.code
		}
	}
	return 0
}
// createDNSRecord builds a resource record owned by name from a configured
// custom record. Every record is class IN and carries record.TTL.
//
// Supported types (case-insensitive):
//   - A and AAAA: record.Value is an IPv4 or IPv6 address respectively.
//   - CNAME and NS: record.Value is the target name, emitted fully qualified.
//   - MX: record.Value is the exchange; record.Priority is the preference.
//   - SRV: record.Value is the target; record.Priority, record.Weight and
//     record.Port fill the remaining fields.
//   - TXT: record.Value is a single character-string.
//
// It returns nil, after logging a warning, when the value is invalid for the
// type or longer than 1024 bytes. It returns nil without logging for an
// unsupported type.
func createDNSRecord(record config.DNSRecord, name string) dns.RR {
	// reject logs why the record is being skipped and returns a nil RR.
	reject := func(msg string) dns.RR {
		logger.WithFields(map[string]interface{}{
			"type":  record.Type,
			"name":  name,
			"value": record.Value,
		}).Warn(msg)
		return nil
	}
	// header builds the RR header shared by every record type.
	header := func(rrtype uint16) dns.RR_Header {
		return dns.RR_Header{
			Name:   name,
			Rrtype: rrtype,
			Class:  dns.ClassINET,
			Ttl:    record.TTL,
		}
	}

	// Upper bound on any configured value, checked before type-specific parsing.
	if len(record.Value) > 1024 {
		return reject("DNS record value too long, ignoring")
	}

	switch strings.ToUpper(record.Type) {
	case "A":
		// To4 is nil for a nil (unparseable) IP and for genuine IPv6 addresses.
		ip4 := net.ParseIP(record.Value).To4()
		if ip4 == nil {
			return reject("Invalid IPv4 address in DNS record, ignoring")
		}
		return &dns.A{Hdr: header(dns.TypeA), A: ip4}

	case "AAAA":
		// net.ParseIP also accepts dotted IPv4 ("192.0.2.1") and returns its
		// 16-byte IPv4-mapped form, so checking To16() alone never rejects it.
		// Every textual IPv6 address contains a colon; require one so an IPv4
		// value is not silently published as ::ffff:a.b.c.d.
		ip := net.ParseIP(record.Value)
		if ip == nil || !strings.Contains(record.Value, ":") {
			return reject("Invalid IPv6 address in DNS record, ignoring")
		}
		return &dns.AAAA{Hdr: header(dns.TypeAAAA), AAAA: ip.To16()}

	case "CNAME":
		// 253 characters is the longest domain name without the trailing dot
		// that dns.Fqdn appends. The same limit applies to MX/NS/SRV below.
		if len(record.Value) > 253 {
			return reject("CNAME target too long, ignoring")
		}
		return &dns.CNAME{Hdr: header(dns.TypeCNAME), Target: dns.Fqdn(record.Value)}

	case "MX":
		if len(record.Value) > 253 {
			return reject("MX target too long, ignoring")
		}
		return &dns.MX{
			Hdr:        header(dns.TypeMX),
			Preference: record.Priority,
			Mx:         dns.Fqdn(record.Value),
		}

	case "TXT":
		// A single TXT character-string holds at most 255 octets (RFC 1035).
		// Longer values are rejected rather than split across strings.
		if len(record.Value) > 255 {
			return reject("TXT record too long, ignoring")
		}
		return &dns.TXT{Hdr: header(dns.TypeTXT), Txt: []string{record.Value}}

	case "NS":
		if len(record.Value) > 253 {
			return reject("NS target too long, ignoring")
		}
		return &dns.NS{Hdr: header(dns.TypeNS), Ns: dns.Fqdn(record.Value)}

	case "SRV":
		if len(record.Value) > 253 {
			return reject("SRV target too long, ignoring")
		}
		return &dns.SRV{
			Hdr:      header(dns.TypeSRV),
			Priority: record.Priority,
			Weight:   record.Weight,
			Port:     record.Port,
			Target:   dns.Fqdn(record.Value),
		}

	default:
		return nil
	}
}
// forwardToBackupDNS relays query to cfg.DNSServer.BackupServer over UDP with
// a 5 second timeout and returns the upstream reply.
//
// BackupServer may be "host:port" or a bare host/IP. A bare host or IP
// (including a bracketed IPv6 literal such as "[::1]") is dialled on the
// standard DNS port 53.
//
// It returns nil when:
//   - no backup server is configured;
//   - the exchange fails;
//   - the reply is nil;
//   - the reply ID does not match the query ID.
//
// A reply with more than 10 answers is cut to its first 10 before returning.
func forwardToBackupDNS(query *dns.Msg, cfg *config.Config) *dns.Msg {
	server := cfg.DNSServer.BackupServer
	if server == "" {
		logger.Warn("No backup DNS server configured, cannot forward query")
		return nil
	}
	// The client dials the address as given, so a value without a port
	// (e.g. "8.8.8.8") would fail on every query. Default it to port 53.
	if _, _, err := net.SplitHostPort(server); err != nil {
		server = net.JoinHostPort(strings.Trim(server, "[]"), "53")
	}

	client := &dns.Client{
		Net:     "udp",
		Timeout: 5 * time.Second,
	}
	response, _, err := client.Exchange(query, server)
	if err != nil {
		logger.WithField("error", err).Error("Failed to forward DNS query to backup server")
		return nil
	}

	if response == nil {
		logger.Warn("Received nil response from backup DNS server")
		return nil
	}

	// Defensive: only relay a reply that answers this exact query.
	if response.Id != query.Id {
		logger.Warn("DNS response ID mismatch, potential spoofing attempt")
		return nil
	}

	// Cap the number of relayed answers to bound reply size.
	const maxAnswers = 10
	if len(response.Answer) > maxAnswers {
		logger.WithField("answer_count", len(response.Answer)).Warn("DNS response has too many answers, truncating")
		response.Answer = response.Answer[:maxAnswers]
	}
	return response
}