bettercap/modules/dns_proxy/dns_proxy_base.go
2024-10-12 22:05:29 +02:00

242 lines
5.4 KiB
Go

package dns_proxy
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"strings"
"time"
"github.com/bettercap/bettercap/v2/firewall"
"github.com/bettercap/bettercap/v2/session"
"github.com/evilsocket/islazy/log"
"github.com/miekg/dns"
)
// Timeouts handed to the upstream dns.Client that Configure builds: each
// phase of an exchange with the resolver (dial, read, write) gets the same
// two second budget.
const (
	dialTimeout  time.Duration = 2 * time.Second
	readTimeout  time.Duration = 2 * time.Second
	writeTimeout time.Duration = 2 * time.Second
)
// DNSProxy bundles the configuration and runtime state of a DNS proxy: the
// listening server, the optional firewall redirection that steers traffic to
// it, the optional script that filters requests and responses, and the
// session used for logging and firewall access.
type DNSProxy struct {
	// Name is the proxy identifier; Address is the listen address stored by
	// Configure.
	Name, Address string
	// Server is the DNS server created by Configure and run by Start.
	Server *dns.Server
	// Redirection is the firewall rule installed by Configure when
	// redirection is enabled; Stop removes it and resets it to nil.
	Redirection *firewall.Redirection
	// NOTE(review): Nameserver and NetProtocol are not assigned by any code
	// visible in this file - confirm whether another file populates them.
	Nameserver, NetProtocol string
	// Script is the loaded proxy script, or nil when no script path was given.
	Script *DnsProxyScript
	// CertFile and KeyFile are the certificate/key paths used for "tcp-tls".
	CertFile, KeyFile string
	// Blacklist and Whitelist hold client IPs consulted by shouldProxy; a
	// whitelist match wins, and "*" in the blacklist matches every client.
	Blacklist, Whitelist []string
	// Sess is the owning session (events log, firewall, interface).
	Sess *session.Session

	// doRedirect records whether a firewall redirection was requested;
	// isRunning is set by dnsWorker when the server starts serving.
	doRedirect, isRunning bool
	// tag is the prefix prepended to every log message.
	tag string
}
// shouldProxy reports whether queries coming from clientIP must go through
// the proxy logic. A whitelist match always wins; otherwise the client is
// rejected when the blacklist contains either its IP or the "*" wildcard.
// Clients found in neither list are proxied.
func (p *DNSProxy) shouldProxy(clientIP string) bool {
	// whitelist entries take precedence over anything in the blacklist
	for i := range p.Whitelist {
		if p.Whitelist[i] == clientIP {
			return true
		}
	}

	// "*" blacklists every client that was not explicitly whitelisted
	for i := range p.Blacklist {
		switch p.Blacklist[i] {
		case "*", clientIP:
			return false
		}
	}

	return true
}
// Configure prepares the proxy so that Start can run it.
//
// It stores the listen address, redirection flag and certificate paths,
// optionally loads the proxy script at scriptPath, builds the DNS server
// listening on address:proxyPort with netProtocol, and, when doRedirect is
// true, enables forwarding and installs a firewall redirection built from
// dnsPort, address and proxyPort. Finally it routes unknown session commands
// to the loaded script, if any.
//
// Queries are passed to onRequestFilter first; when that yields no response
// they are forwarded to nameserver:dnsPort and the upstream answer is passed
// through onResponseFilter before being written back. Upstream failures and
// nil filtered responses are answered with SERVFAIL.
//
// When netProtocol is "tcp-tls" and both certFile and keyFile are set, the
// key pair is loaded into the server TLS configuration.
//
// An error is returned if the script cannot be loaded, the certificate or
// key cannot be read or parsed, or the redirection cannot be enabled.
func (p *DNSProxy) Configure(address string, dnsPort int, doRedirect bool, nameserver string, netProtocol string, proxyPort int, scriptPath string, certFile string, keyFile string) error {
	var err error

	p.Address = address
	p.doRedirect = doRedirect
	p.CertFile = certFile
	p.KeyFile = keyFile

	if scriptPath != "" {
		if err, p.Script = LoadDnsProxyScript(scriptPath, p.Sess); err != nil {
			return err
		}
		p.Debug("proxy script %s loaded.", scriptPath)
	}

	dnsClient := dns.Client{
		DialTimeout:  dialTimeout,
		Net:          netProtocol,
		ReadTimeout:  readTimeout,
		WriteTimeout: writeTimeout,
	}

	// NOTE(review): an IPv6 literal nameserver would need to be bracketed
	// here ("[host]:port"); only hostnames and IPv4 addresses format correctly.
	resolverAddr := fmt.Sprintf("%s:%d", nameserver, dnsPort)

	handler := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
		// send writes msg to the client and logs any write failure.
		send := func(msg *dns.Msg) {
			if err := w.WriteMsg(msg); err != nil {
				p.Error("Error writing response: %s", err)
			}
		}

		m := new(dns.Msg)
		m.SetReply(req)

		// servfail answers the given request with a SERVFAIL reply.
		servfail := func(request *dns.Msg) {
			m.SetRcode(request, dns.RcodeServerFailure)
			send(m)
		}

		// The remote address is "host:port", with IPv6 hosts wrapped in
		// brackets ("[::1]:53"). Cut at the LAST colon and drop the brackets
		// so IPv6 clients are not reduced to "[".
		clientIP := w.RemoteAddr().String()
		if sep := strings.LastIndex(clientIP, ":"); sep != -1 {
			clientIP = clientIP[:sep]
		}
		clientIP = strings.Trim(clientIP, "[]")

		req, res := p.onRequestFilter(req, clientIP)
		if res != nil {
			// the request filter produced the answer itself
			send(res)
			return
		}

		// the second return value (round trip time) is not needed
		upstream, _, err := dnsClient.Exchange(req, resolverAddr)
		if err != nil {
			p.Debug("error while resolving DNS query: %s", err.Error())
			servfail(req)
			return
		}

		answer := p.onResponseFilter(req, upstream, clientIP)
		if answer == nil {
			p.Debug("response is nil")
			servfail(req)
			return
		}

		send(answer)
	})

	p.Server = &dns.Server{
		Addr:    fmt.Sprintf("%s:%d", address, proxyPort),
		Net:     netProtocol,
		Handler: handler,
	}

	if netProtocol == "tcp-tls" && p.CertFile != "" && p.KeyFile != "" {
		// Report unreadable files explicitly instead of letting the key pair
		// parser fail later on empty input.
		rawCert, err := ioutil.ReadFile(p.CertFile)
		if err != nil {
			return fmt.Errorf("reading certificate file %q: %w", p.CertFile, err)
		}

		rawKey, err := ioutil.ReadFile(p.KeyFile)
		if err != nil {
			return fmt.Errorf("reading key file %q: %w", p.KeyFile, err)
		}

		ourCa, err := tls.X509KeyPair(rawCert, rawKey)
		if err != nil {
			return err
		}

		if ourCa.Leaf, err = x509.ParseCertificate(ourCa.Certificate[0]); err != nil {
			return err
		}

		p.Server.TLSConfig = &tls.Config{
			Certificates: []tls.Certificate{ourCa},
		}
	}

	if p.doRedirect {
		if !p.Sess.Firewall.IsForwardingEnabled() {
			p.Info("enabling forwarding.")
			p.Sess.Firewall.EnableForwarding(true)
		}

		// the redirection is given plain "tcp" when the server speaks DNS over TLS
		redirectProtocol := netProtocol
		if redirectProtocol == "tcp-tls" {
			redirectProtocol = "tcp"
		}

		p.Redirection = firewall.NewRedirection(p.Sess.Interface.Name(),
			redirectProtocol,
			dnsPort,
			p.Address,
			proxyPort)

		if err := p.Sess.Firewall.EnableRedirection(p.Redirection, true); err != nil {
			return err
		}

		p.Debug("applied redirection %s", p.Redirection.String())
	} else {
		p.Warning("port redirection disabled, the proxy must be set manually to work")
	}

	// let the script, when present, handle commands the session does not know
	p.Sess.UnkCmdCallback = func(cmd string) bool {
		if p.Script != nil {
			return p.Script.OnCommand(cmd)
		}
		return false
	}

	return nil
}
// dnsWorker runs the DNS server and blocks until it stops, returning
// whatever ListenAndServe returns. isRunning is true only while the server
// is serving: it is cleared again once ListenAndServe returns, so the flag
// does not stay stale after a shutdown or a failed start.
func (p *DNSProxy) dnsWorker() error {
	p.isRunning = true
	defer func() { p.isRunning = false }()

	return p.Server.ListenAndServe()
}
// Debug logs a formatted message at log.DEBUG level through
// p.Sess.Events.Log, with the proxy tag prepended to the format string.
func (p *DNSProxy) Debug(format string, args ...interface{}) {
	tagged := p.tag + format
	p.Sess.Events.Log(log.DEBUG, tagged, args...)
}
// Info logs a formatted message at log.INFO level through
// p.Sess.Events.Log, with the proxy tag prepended to the format string.
func (p *DNSProxy) Info(format string, args ...interface{}) {
	tagged := p.tag + format
	p.Sess.Events.Log(log.INFO, tagged, args...)
}
// Warning logs a formatted message at log.WARNING level through
// p.Sess.Events.Log, with the proxy tag prepended to the format string.
func (p *DNSProxy) Warning(format string, args ...interface{}) {
	tagged := p.tag + format
	p.Sess.Events.Log(log.WARNING, tagged, args...)
}
// Error logs a formatted message at log.ERROR level through
// p.Sess.Events.Log, with the proxy tag prepended to the format string.
func (p *DNSProxy) Error(format string, args ...interface{}) {
	tagged := p.tag + format
	p.Sess.Events.Log(log.ERROR, tagged, args...)
}
// Fatal logs a formatted message at log.FATAL level through
// p.Sess.Events.Log, with the proxy tag prepended to the format string.
// NOTE(review): how the events log reacts to the FATAL level is not visible
// from this file - confirm before relying on it to terminate the process.
func (p *DNSProxy) Fatal(format string, args ...interface{}) {
	tagged := p.tag + format
	p.Sess.Events.Log(log.FATAL, tagged, args...)
}
// NewDNSProxy returns a DNSProxy named "dns.proxy" that is bound to session
// s and logs with the tag derived from tag via session.AsTag. Redirection is
// enabled by default and no server exists until Configure is called.
func NewDNSProxy(s *session.Session, tag string) *DNSProxy {
	proxy := new(DNSProxy)

	proxy.Name = "dns.proxy"
	proxy.Sess = s
	proxy.Server = nil // created later by Configure
	proxy.doRedirect = true
	proxy.tag = session.AsTag(tag)

	return proxy
}
// Start runs the DNS server in a background goroutine and returns
// immediately. Any error returned by the worker is logged as fatal, except
// the one whose text is "dns: Server closed", which is treated as a regular
// shutdown.
func (p *DNSProxy) Start() {
	serve := func() {
		p.Info("started on %s", p.Server.Addr)

		// TODO: check the dns server closed error
		if err := p.dnsWorker(); err != nil {
			if err.Error() != "dns: Server closed" {
				p.Fatal("%s", err)
			}
		}
	}

	go serve()
}
// Stop tears the proxy down: it removes the firewall redirection (when one
// was installed), clears the session's unknown-command callback and shuts
// the DNS server down, waiting at most five seconds.
//
// Every step is attempted even if an earlier one fails, so that a failure to
// remove the redirection does not leave the server listening. The first
// error encountered is returned; a later shutdown error is logged instead.
// When the redirection cannot be removed p.Redirection is kept, so removal
// can be attempted again. Calling Stop before Configure created the server
// is safe.
func (p *DNSProxy) Stop() error {
	var stopErr error

	if p.doRedirect && p.Redirection != nil {
		p.Debug("disabling redirection %s", p.Redirection.String())
		if err := p.Sess.Firewall.EnableRedirection(p.Redirection, false); err != nil {
			stopErr = err
		} else {
			p.Redirection = nil
		}
	}

	p.Sess.UnkCmdCallback = nil

	// NewDNSProxy leaves Server nil until Configure runs
	if p.Server == nil {
		return stopErr
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	if err := p.Server.ShutdownContext(ctx); err != nil {
		if stopErr == nil {
			return err
		}
		// the redirection error is the one returned; do not lose this one
		p.Error("error shutting down the DNS server: %s", err)
	}

	return stopErr
}