// Package dns implements the built-in DNS server: queries are answered from
// custom records in the config, or forwarded to a backup DNS server.
package dns

import (
	"fmt"
	"net"
	"strings"
	"time"

	"teleport/pkg/config"
	"teleport/pkg/logger"

	"github.com/miekg/dns"
)

// StartDNSServer starts the built-in DNS server using miekg/dns.
//
// It does nothing when cfg.DNSServer.Enabled is false. Otherwise it serves
// UDP on cfg.DNSServer.ListenPort and blocks until the server stops; a bind
// or serve failure is logged rather than returned.
func StartDNSServer(cfg *config.Config) {
	if !cfg.DNSServer.Enabled {
		return
	}

	port := cfg.DNSServer.ListenPort
	server := &dns.Server{
		Addr: fmt.Sprintf(":%d", port),
		Net:  "udp",
		// The handler is bound to this server instead of being registered on
		// the package-global dns.DefaultServeMux, so each server owns its
		// handler and the cfg it captures.
		Handler: dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
			handleDNSQuery(w, r, cfg)
		}),
		// Called once the socket is actually bound. Logging before
		// ListenAndServe reported "started" even when the bind then failed.
		NotifyStartedFunc: func() {
			logger.WithField("port", port).Info("DNS server started")
		},
	}

	if err := server.ListenAndServe(); err != nil {
		logger.WithField("error", err).Error("Failed to start DNS server")
	}
}

// handleDNSQuery answers a single query. Matching custom records take
// precedence; otherwise the query is forwarded to the backup server, and if
// that also yields nothing the client receives SERVFAIL.
func handleDNSQuery(w dns.ResponseWriter, r *dns.Msg, cfg *config.Config) {
	response := checkCustomRecords(r, cfg)
	if response == nil {
		response = forwardToBackupDNS(r, cfg)
	}
	if response == nil {
		response = new(dns.Msg)
		response.SetRcode(r, dns.RcodeServerFailure)
	}

	// A failed write (for example a record that cannot be packed) would
	// otherwise leave the client waiting for a timeout with no trace here.
	if err := w.WriteMsg(response); err != nil {
		logger.WithField("error", err).Error("Failed to write DNS response")
	}
}

// checkCustomRecords builds an authoritative reply from the configured custom
// records matching the first question's name and type.
//
// Names are compared case-insensitively and as fully-qualified names, so a
// record configured as "host.example.com" matches the query name
// "host.example.com.". It returns nil when the query has no question, the
// query name fails validation, or no usable record matches; the caller then
// falls back to forwarding.
func checkCustomRecords(query *dns.Msg, cfg *config.Config) *dns.Msg {
	if len(query.Question) == 0 {
		return nil
	}

	question := query.Question[0]
	questionName := strings.ToLower(question.Name)

	// Validate question name length and format.
	if len(questionName) > 253 {
		logger.WithField("name", questionName).Warn("DNS query name too long, ignoring")
		return nil
	}
	// Reject control characters and anything outside printable ASCII.
	for _, c := range questionName {
		if c < 32 || c > 126 {
			logger.WithField("name", questionName).Warn("DNS query name contains invalid characters, ignoring")
			return nil
		}
	}

	var answers []dns.RR
	for _, record := range cfg.DNSServer.CustomRecords {
		if getRecordType(record.Type) != question.Qtype {
			continue
		}
		// Query names arrive fully qualified (with the trailing root dot)
		// while config names usually omit it, so compare them as FQDNs. An
		// empty name is skipped rather than being turned into the root ".".
		if record.Name == "" || !strings.EqualFold(dns.Fqdn(record.Name), question.Name) {
			continue
		}
		// Own the answer by the name exactly as asked, preserving the
		// question's letter case.
		if answer := createDNSRecord(record, question.Name); answer != nil {
			answers = append(answers, answer)
		}
	}

	if len(answers) == 0 {
		return nil
	}

	response := new(dns.Msg)
	response.SetReply(query)
	response.Authoritative = true
	response.Answer = answers
	return response
}

// getRecordType maps a config record type name (case-insensitive) to its DNS
// type code. It returns 0 for types this server does not serve.
func getRecordType(recordType string) uint16 {
	switch strings.ToUpper(recordType) {
	case "A":
		return dns.TypeA
	case "AAAA":
		return dns.TypeAAAA
	case "CNAME":
		return dns.TypeCNAME
	case "MX":
		return dns.TypeMX
	case "TXT":
		return dns.TypeTXT
	case "NS":
		return dns.TypeNS
	case "SRV":
		return dns.TypeSRV
	default:
		return 0
	}
}

// warnInvalidRecord logs, at warning level, why a configured custom record is
// being skipped.
func warnInvalidRecord(record config.DNSRecord, name, msg string) {
	logger.WithFields(map[string]interface{}{
		"type":  record.Type,
		"name":  name,
		"value": record.Value,
	}).Warn(msg)
}

// checkTarget returns a warning message if value cannot be used as the
// domain-name target of a record of the given kind (CNAME, MX, NS, SRV), or ""
// if it is acceptable. A target must be at most 253 characters and must pass
// dns.IsDomainName (no empty labels, no over-long labels).
func checkTarget(kind, value string) string {
	if len(value) > 253 {
		return kind + " target too long, ignoring"
	}
	if _, ok := dns.IsDomainName(value); !ok {
		return kind + " target is not a valid domain name, ignoring"
	}
	return ""
}

// createDNSRecord turns one configured custom record into a resource record
// owned by name, using the record's TTL and class IN. Supported types are A,
// AAAA, CNAME, MX, TXT, NS and SRV. It logs a warning and returns nil when
// the value is invalid for its type, and returns nil for unsupported types.
func createDNSRecord(record config.DNSRecord, name string) dns.RR {
	if len(record.Value) > 1024 {
		warnInvalidRecord(record, name, "DNS record value too long, ignoring")
		return nil
	}

	header := func(rrtype uint16) dns.RR_Header {
		return dns.RR_Header{
			Name:   name,
			Rrtype: rrtype,
			Class:  dns.ClassINET,
			Ttl:    record.TTL,
		}
	}

	switch strings.ToUpper(record.Type) {
	case "A": // IPv4 address
		ip := net.ParseIP(record.Value)
		if ip == nil || ip.To4() == nil {
			warnInvalidRecord(record, name, "Invalid IPv4 address in DNS record, ignoring")
			return nil
		}
		return &dns.A{Hdr: header(dns.TypeA), A: ip.To4()}

	case "AAAA": // IPv6 address
		// ParseIP also accepts dotted IPv4 text, and To16 never fails on a
		// parsed address, so require IPv6 notation (which always contains a
		// colon) instead of serving an IPv4 value as ::ffff:a.b.c.d.
		ip := net.ParseIP(record.Value)
		if ip == nil || !strings.Contains(record.Value, ":") {
			warnInvalidRecord(record, name, "Invalid IPv6 address in DNS record, ignoring")
			return nil
		}
		return &dns.AAAA{Hdr: header(dns.TypeAAAA), AAAA: ip.To16()}

	case "CNAME": // Canonical name
		if msg := checkTarget("CNAME", record.Value); msg != "" {
			warnInvalidRecord(record, name, msg)
			return nil
		}
		return &dns.CNAME{Hdr: header(dns.TypeCNAME), Target: dns.Fqdn(record.Value)}

	case "MX": // Mail exchange
		if msg := checkTarget("MX", record.Value); msg != "" {
			warnInvalidRecord(record, name, msg)
			return nil
		}
		return &dns.MX{
			Hdr:        header(dns.TypeMX),
			Preference: record.Priority,
			Mx:         dns.Fqdn(record.Value),
		}

	case "TXT": // Text record
		// A single TXT character-string is limited to 255 octets (RFC 1035).
		if len(record.Value) > 255 {
			warnInvalidRecord(record, name, "TXT record too long, ignoring")
			return nil
		}
		return &dns.TXT{Hdr: header(dns.TypeTXT), Txt: []string{record.Value}}

	case "NS": // Name server
		if msg := checkTarget("NS", record.Value); msg != "" {
			warnInvalidRecord(record, name, msg)
			return nil
		}
		return &dns.NS{Hdr: header(dns.TypeNS), Ns: dns.Fqdn(record.Value)}

	case "SRV": // Service record
		if msg := checkTarget("SRV", record.Value); msg != "" {
			warnInvalidRecord(record, name, msg)
			return nil
		}
		return &dns.SRV{
			Hdr:      header(dns.TypeSRV),
			Priority: record.Priority,
			Weight:   record.Weight,
			Port:     record.Port,
			Target:   dns.Fqdn(record.Value),
		}

	default:
		return nil
	}
}
forwards the query to the backup DNS server func forwardToBackupDNS(query *dns.Msg, cfg *config.Config) *dns.Msg { // Create DNS client with timeout client := new(dns.Client) client.Net = "udp" client.Timeout = 5 * time.Second // 5 second timeout // Forward query to backup server response, _, err := client.Exchange(query, cfg.DNSServer.BackupServer) if err != nil { logger.WithField("error", err).Error("Failed to forward DNS query to backup server") return nil } // Validate response if response == nil { logger.Warn("Received nil response from backup DNS server") return nil } // Basic response validation if response.Id != query.Id { logger.Warn("DNS response ID mismatch, potential spoofing attempt") return nil } // Limit response size to prevent amplification attacks if len(response.Answer) > 10 { logger.WithField("answer_count", len(response.Answer)).Warn("DNS response has too many answers, truncating") response.Answer = response.Answer[:10] } return response }