- Created a new Go module named 'teleport' for secure port forwarding. - Added essential files including .gitignore, LICENSE, and README.md with project details. - Implemented configuration management with YAML support in config package. - Developed core client and server functionalities for handling port forwarding. - Introduced DNS server capabilities and integrated logging with sanitization. - Established rate limiting and metrics tracking for performance monitoring. - Included comprehensive tests for core components and functionalities. - Set up CI workflows for automated testing and release management using Gitea actions.
330 lines
7.5 KiB
Go
330 lines
7.5 KiB
Go
package dns
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"strings"
|
|
"time"
|
|
|
|
"teleport/pkg/config"
|
|
"teleport/pkg/logger"
|
|
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
// StartDNSServer starts the built-in DNS server using miekg/dns
|
|
func StartDNSServer(cfg *config.Config) {
|
|
if !cfg.DNSServer.Enabled {
|
|
return
|
|
}
|
|
|
|
// Create DNS server
|
|
server := &dns.Server{
|
|
Addr: fmt.Sprintf(":%d", cfg.DNSServer.ListenPort),
|
|
Net: "udp",
|
|
}
|
|
|
|
// Set up handler
|
|
dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
|
|
handleDNSQuery(w, r, cfg)
|
|
})
|
|
|
|
logger.WithField("port", cfg.DNSServer.ListenPort).Info("DNS server started")
|
|
|
|
// Start server
|
|
if err := server.ListenAndServe(); err != nil {
|
|
logger.WithField("error", err).Error("Failed to start DNS server")
|
|
}
|
|
}
|
|
|
|
// handleDNSQuery handles DNS queries using miekg/dns
|
|
func handleDNSQuery(w dns.ResponseWriter, r *dns.Msg, cfg *config.Config) {
|
|
// Check if we have custom records for this query
|
|
response := checkCustomRecords(r, cfg)
|
|
if response == nil {
|
|
// Forward to backup DNS server
|
|
response = forwardToBackupDNS(r, cfg)
|
|
if response == nil {
|
|
// Send error response
|
|
response = new(dns.Msg)
|
|
response.SetRcode(r, dns.RcodeServerFailure)
|
|
}
|
|
}
|
|
|
|
// Send response
|
|
w.WriteMsg(response)
|
|
}
|
|
|
|
// checkCustomRecords checks if we have custom records for the query
|
|
func checkCustomRecords(query *dns.Msg, cfg *config.Config) *dns.Msg {
|
|
if len(query.Question) == 0 {
|
|
return nil
|
|
}
|
|
|
|
question := query.Question[0]
|
|
questionName := strings.ToLower(question.Name)
|
|
|
|
// Validate question name length and format
|
|
if len(questionName) > 253 {
|
|
logger.WithField("name", questionName).Warn("DNS query name too long, ignoring")
|
|
return nil
|
|
}
|
|
|
|
// Basic character validation for DNS name
|
|
for _, c := range questionName {
|
|
if c < 32 || c > 126 {
|
|
logger.WithField("name", questionName).Warn("DNS query name contains invalid characters, ignoring")
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Look for matching custom records
|
|
var answers []dns.RR
|
|
for _, record := range cfg.DNSServer.CustomRecords {
|
|
if strings.ToLower(record.Name) == questionName && getRecordType(record.Type) == question.Qtype {
|
|
answer := createDNSRecord(record, questionName)
|
|
if answer != nil {
|
|
answers = append(answers, answer)
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(answers) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Create response
|
|
response := new(dns.Msg)
|
|
response.SetReply(query)
|
|
response.Authoritative = true
|
|
response.Answer = answers
|
|
|
|
return response
|
|
}
|
|
|
|
// getRecordType converts string record type to DNS type code
|
|
func getRecordType(recordType string) uint16 {
|
|
switch strings.ToUpper(recordType) {
|
|
case "A":
|
|
return dns.TypeA
|
|
case "AAAA":
|
|
return dns.TypeAAAA
|
|
case "CNAME":
|
|
return dns.TypeCNAME
|
|
case "MX":
|
|
return dns.TypeMX
|
|
case "TXT":
|
|
return dns.TypeTXT
|
|
case "NS":
|
|
return dns.TypeNS
|
|
case "SRV":
|
|
return dns.TypeSRV
|
|
default:
|
|
return 0
|
|
}
|
|
}
|
|
|
|
// createDNSRecord creates a DNS resource record from a custom record
|
|
func createDNSRecord(record config.DNSRecord, name string) dns.RR {
|
|
// Validate record value length
|
|
if len(record.Value) > 1024 {
|
|
logger.WithFields(map[string]interface{}{
|
|
"type": record.Type,
|
|
"name": name,
|
|
"value": record.Value,
|
|
}).Warn("DNS record value too long, ignoring")
|
|
return nil
|
|
}
|
|
|
|
switch strings.ToUpper(record.Type) {
|
|
case "A":
|
|
// IPv4 address
|
|
ip := net.ParseIP(record.Value)
|
|
if ip == nil || ip.To4() == nil {
|
|
logger.WithFields(map[string]interface{}{
|
|
"type": record.Type,
|
|
"name": name,
|
|
"value": record.Value,
|
|
}).Warn("Invalid IPv4 address in DNS record, ignoring")
|
|
return nil
|
|
}
|
|
return &dns.A{
|
|
Hdr: dns.RR_Header{
|
|
Name: name,
|
|
Rrtype: dns.TypeA,
|
|
Class: dns.ClassINET,
|
|
Ttl: record.TTL,
|
|
},
|
|
A: ip.To4(),
|
|
}
|
|
|
|
case "AAAA":
|
|
// IPv6 address
|
|
ip := net.ParseIP(record.Value)
|
|
if ip == nil || ip.To16() == nil {
|
|
logger.WithFields(map[string]interface{}{
|
|
"type": record.Type,
|
|
"name": name,
|
|
"value": record.Value,
|
|
}).Warn("Invalid IPv6 address in DNS record, ignoring")
|
|
return nil
|
|
}
|
|
return &dns.AAAA{
|
|
Hdr: dns.RR_Header{
|
|
Name: name,
|
|
Rrtype: dns.TypeAAAA,
|
|
Class: dns.ClassINET,
|
|
Ttl: record.TTL,
|
|
},
|
|
AAAA: ip.To16(),
|
|
}
|
|
|
|
case "CNAME":
|
|
// Canonical name
|
|
// Validate CNAME target length
|
|
if len(record.Value) > 253 {
|
|
logger.WithFields(map[string]interface{}{
|
|
"type": record.Type,
|
|
"name": name,
|
|
"value": record.Value,
|
|
}).Warn("CNAME target too long, ignoring")
|
|
return nil
|
|
}
|
|
return &dns.CNAME{
|
|
Hdr: dns.RR_Header{
|
|
Name: name,
|
|
Rrtype: dns.TypeCNAME,
|
|
Class: dns.ClassINET,
|
|
Ttl: record.TTL,
|
|
},
|
|
Target: dns.Fqdn(record.Value),
|
|
}
|
|
|
|
case "MX":
|
|
// Mail exchange
|
|
// Validate MX target length
|
|
if len(record.Value) > 253 {
|
|
logger.WithFields(map[string]interface{}{
|
|
"type": record.Type,
|
|
"name": name,
|
|
"value": record.Value,
|
|
}).Warn("MX target too long, ignoring")
|
|
return nil
|
|
}
|
|
return &dns.MX{
|
|
Hdr: dns.RR_Header{
|
|
Name: name,
|
|
Rrtype: dns.TypeMX,
|
|
Class: dns.ClassINET,
|
|
Ttl: record.TTL,
|
|
},
|
|
Preference: record.Priority,
|
|
Mx: dns.Fqdn(record.Value),
|
|
}
|
|
|
|
case "TXT":
|
|
// Text record
|
|
// Validate TXT record length (RFC 1035 limit)
|
|
if len(record.Value) > 255 {
|
|
logger.WithFields(map[string]interface{}{
|
|
"type": record.Type,
|
|
"name": name,
|
|
"value": record.Value,
|
|
}).Warn("TXT record too long, ignoring")
|
|
return nil
|
|
}
|
|
return &dns.TXT{
|
|
Hdr: dns.RR_Header{
|
|
Name: name,
|
|
Rrtype: dns.TypeTXT,
|
|
Class: dns.ClassINET,
|
|
Ttl: record.TTL,
|
|
},
|
|
Txt: []string{record.Value},
|
|
}
|
|
|
|
case "NS":
|
|
// Name server
|
|
// Validate NS target length
|
|
if len(record.Value) > 253 {
|
|
logger.WithFields(map[string]interface{}{
|
|
"type": record.Type,
|
|
"name": name,
|
|
"value": record.Value,
|
|
}).Warn("NS target too long, ignoring")
|
|
return nil
|
|
}
|
|
return &dns.NS{
|
|
Hdr: dns.RR_Header{
|
|
Name: name,
|
|
Rrtype: dns.TypeNS,
|
|
Class: dns.ClassINET,
|
|
Ttl: record.TTL,
|
|
},
|
|
Ns: dns.Fqdn(record.Value),
|
|
}
|
|
|
|
case "SRV":
|
|
// Service record
|
|
// Validate SRV target length
|
|
if len(record.Value) > 253 {
|
|
logger.WithFields(map[string]interface{}{
|
|
"type": record.Type,
|
|
"name": name,
|
|
"value": record.Value,
|
|
}).Warn("SRV target too long, ignoring")
|
|
return nil
|
|
}
|
|
return &dns.SRV{
|
|
Hdr: dns.RR_Header{
|
|
Name: name,
|
|
Rrtype: dns.TypeSRV,
|
|
Class: dns.ClassINET,
|
|
Ttl: record.TTL,
|
|
},
|
|
Priority: record.Priority,
|
|
Weight: record.Weight,
|
|
Port: record.Port,
|
|
Target: dns.Fqdn(record.Value),
|
|
}
|
|
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// forwardToBackupDNS forwards the query to the backup DNS server
|
|
func forwardToBackupDNS(query *dns.Msg, cfg *config.Config) *dns.Msg {
|
|
// Create DNS client with timeout
|
|
client := new(dns.Client)
|
|
client.Net = "udp"
|
|
client.Timeout = 5 * time.Second // 5 second timeout
|
|
|
|
// Forward query to backup server
|
|
response, _, err := client.Exchange(query, cfg.DNSServer.BackupServer)
|
|
if err != nil {
|
|
logger.WithField("error", err).Error("Failed to forward DNS query to backup server")
|
|
return nil
|
|
}
|
|
|
|
// Validate response
|
|
if response == nil {
|
|
logger.Warn("Received nil response from backup DNS server")
|
|
return nil
|
|
}
|
|
|
|
// Basic response validation
|
|
if response.Id != query.Id {
|
|
logger.Warn("DNS response ID mismatch, potential spoofing attempt")
|
|
return nil
|
|
}
|
|
|
|
// Limit response size to prevent amplification attacks
|
|
if len(response.Answer) > 10 {
|
|
logger.WithField("answer_count", len(response.Answer)).Warn("DNS response has too many answers, truncating")
|
|
response.Answer = response.Answer[:10]
|
|
}
|
|
|
|
return response
|
|
}
|