Commit cbc9c1f7 authored by ginuerzh's avatar ginuerzh

dns: add edns0 subnet option support

parent 8121e20c
...@@ -557,6 +557,8 @@ func (r *route) GenRouters() ([]router, error) { ...@@ -557,6 +557,8 @@ func (r *route) GenRouters() ([]router, error) {
gost.ChainResolverOption(chain), gost.ChainResolverOption(chain),
gost.TimeoutResolverOption(timeout), gost.TimeoutResolverOption(timeout),
gost.TTLResolverOption(ttl), gost.TTLResolverOption(ttl),
gost.PreferResolverOption(node.Get("prefer")),
gost.SrcIPResolverOption(net.ParseIP(node.Get("ip"))),
) )
} }
......
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"bytes" "bytes"
"context" "context"
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
...@@ -122,6 +123,8 @@ type resolverOptions struct { ...@@ -122,6 +123,8 @@ type resolverOptions struct {
chain *Chain chain *Chain
timeout time.Duration timeout time.Duration
ttl time.Duration ttl time.Duration
prefer string
srcIP net.IP
} }
// ResolverOption allows a common way to set Resolver options. // ResolverOption allows a common way to set Resolver options.
...@@ -148,6 +151,20 @@ func TTLResolverOption(ttl time.Duration) ResolverOption { ...@@ -148,6 +151,20 @@ func TTLResolverOption(ttl time.Duration) ResolverOption {
} }
} }
// PreferResolverOption sets the prefer for Resolver.
func PreferResolverOption(prefer string) ResolverOption {
return func(opts *resolverOptions) {
opts.prefer = prefer
}
}
// SrcIPResolverOption sets the source IP for Resolver.
func SrcIPResolverOption(ip net.IP) ResolverOption {
return func(opts *resolverOptions) {
opts.srcIP = ip
}
}
// Resolver is a name resolver for domain name. // Resolver is a name resolver for domain name.
// It contains a list of name servers. // It contains a list of name servers.
type Resolver interface { type Resolver interface {
...@@ -177,6 +194,7 @@ type resolver struct { ...@@ -177,6 +194,7 @@ type resolver struct {
stopped chan struct{} stopped chan struct{}
mux sync.RWMutex mux sync.RWMutex
prefer string // ipv4 or ipv6 prefer string // ipv4 or ipv6
srcIP net.IP // for edns0 subnet option
options resolverOptions options resolverOptions
} }
...@@ -217,6 +235,12 @@ func (r *resolver) Init(opts ...ResolverOption) error { ...@@ -217,6 +235,12 @@ func (r *resolver) Init(opts ...ResolverOption) error {
if r.options.ttl != 0 { if r.options.ttl != 0 {
r.ttl = r.options.ttl r.ttl = r.options.ttl
} }
if r.options.prefer != "" {
r.prefer = r.options.prefer
}
if r.options.srcIP != nil {
r.srcIP = r.options.srcIP
}
var nss []NameServer var nss []NameServer
for _, ns := range r.servers { for _, ns := range r.servers {
...@@ -259,8 +283,9 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) { ...@@ -259,8 +283,9 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
host = host + "." + domain host = host + "." + domain
} }
ctx := context.Background()
for _, ns := range r.copyServers() { for _, ns := range r.copyServers() {
ips, err = r.resolve(ns.exchanger, host) ips, err = r.resolve(ctx, ns.exchanger, host)
if err != nil { if err != nil {
log.Logf("[resolver] %s via %s : %s", host, ns.String(), err) log.Logf("[resolver] %s via %s : %s", host, ns.String(), err)
continue continue
...@@ -277,7 +302,7 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) { ...@@ -277,7 +302,7 @@ func (r *resolver) Resolve(host string) (ips []net.IP, err error) {
return return
} }
func (r *resolver) resolve(ex Exchanger, host string) (ips []net.IP, err error) { func (r *resolver) resolve(ctx context.Context, ex Exchanger, host string) (ips []net.IP, err error) {
if ex == nil { if ex == nil {
return return
} }
...@@ -286,7 +311,6 @@ func (r *resolver) resolve(ex Exchanger, host string) (ips []net.IP, err error) ...@@ -286,7 +311,6 @@ func (r *resolver) resolve(ex Exchanger, host string) (ips []net.IP, err error)
prefer := r.prefer prefer := r.prefer
r.mux.RUnlock() r.mux.RUnlock()
ctx := context.Background()
if prefer == "ipv6" { // prefer ipv6 if prefer == "ipv6" { // prefer ipv6
mq := &dns.Msg{} mq := &dns.Msg{}
mq.SetQuestion(dns.Fqdn(host), dns.TypeAAAA) mq.SetQuestion(dns.Fqdn(host), dns.TypeAAAA)
...@@ -302,9 +326,15 @@ func (r *resolver) resolve(ex Exchanger, host string) (ips []net.IP, err error) ...@@ -302,9 +326,15 @@ func (r *resolver) resolve(ex Exchanger, host string) (ips []net.IP, err error)
} }
func (r *resolver) resolveIPs(ctx context.Context, ex Exchanger, mq *dns.Msg) (ips []net.IP, err error) { func (r *resolver) resolveIPs(ctx context.Context, ex Exchanger, mq *dns.Msg) (ips []net.IP, err error) {
mr, _, err := r.exchangeMsg(ctx, ex, mq) key := newResolverCacheKey(&mq.Question[0])
if err != nil { mr := r.cache.loadCache(key)
return if mr == nil {
r.addSubnetOpt(mq)
mr, err = r.exchangeMsg(ctx, ex, mq)
if err != nil {
return
}
r.cache.storeCache(key, mr, r.TTL())
} }
for _, ans := range mr.Answer { for _, ans := range mr.Answer {
...@@ -319,49 +349,73 @@ func (r *resolver) resolveIPs(ctx context.Context, ex Exchanger, mq *dns.Msg) (i ...@@ -319,49 +349,73 @@ func (r *resolver) resolveIPs(ctx context.Context, ex Exchanger, mq *dns.Msg) (i
return return
} }
func (r *resolver) addSubnetOpt(m *dns.Msg) {
if m == nil || r.srcIP == nil {
return
}
opt := new(dns.OPT)
opt.Hdr.Name = "."
opt.Hdr.Rrtype = dns.TypeOPT
e := new(dns.EDNS0_SUBNET)
e.Code = dns.EDNS0SUBNET
if ip := r.srcIP.To4(); ip != nil {
e.Family = 1
e.SourceNetmask = 32
e.Address = ip.To4()
} else {
e.Family = 2
e.SourceNetmask = 128
e.Address = r.srcIP
}
opt.Option = append(opt.Option, e)
m.Extra = append(m.Extra, opt)
}
func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, err error) { func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, err error) {
mq := &dns.Msg{} mq := &dns.Msg{}
if err = mq.Unpack(query); err != nil { if err = mq.Unpack(query); err != nil {
return return
} }
var qs string if len(mq.Question) == 0 {
if len(mq.Question) > 0 { return nil, errors.New("empty question")
qs = mq.Question[0].String()
} }
var mr *dns.Msg var mr *dns.Msg
for _, ns := range r.copyServers() {
var cache bool
mr, cache, err = r.exchangeMsg(ctx, ns.exchanger, mq)
log.Logf("[dns] exchange message %d via %s (cache hit: %v): %s", mq.Id, ns.String(), cache, qs)
if err == nil {
break
}
log.Logf("[dns] exchange message %d via %s: %s", mq.Id, ns.String(), err)
}
if err != nil {
return
}
return mr.Pack()
}
func (r *resolver) exchangeMsg(ctx context.Context, ex Exchanger, mq *dns.Msg) (mr *dns.Msg, cache bool, err error) {
// Only cache for single question. // Only cache for single question.
if len(mq.Question) == 1 { if len(mq.Question) == 1 {
key := newResolverCacheKey(&mq.Question[0]) key := newResolverCacheKey(&mq.Question[0])
mr = r.cache.loadCache(key) mr = r.cache.loadCache(key)
if mr != nil { if mr != nil {
cache = true log.Logf("[dns] exchange message %d (cached): %s", mq.Id, mq.Question[0].String())
mr.Id = mq.Id mr.Id = mq.Id
return return mr.Pack()
} }
defer func() { defer func() {
r.cache.storeCache(key, mr, r.TTL()) if mr != nil {
r.cache.storeCache(key, mr, r.TTL())
}
}() }()
} }
r.addSubnetOpt(mq)
for _, ns := range r.copyServers() {
log.Logf("[dns] exchange message %d via %s: %s", mq.Id, ns.String(), mq.Question[0].String())
mr, err = r.exchangeMsg(ctx, ns.exchanger, mq)
if err == nil {
break
}
log.Logf("[dns] exchange message %d via %s: %s", mq.Id, ns.String(), err)
}
if err != nil {
return
}
return mr.Pack()
}
func (r *resolver) exchangeMsg(ctx context.Context, ex Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) {
query, err := mq.Pack() query, err := mq.Pack()
if err != nil { if err != nil {
return return
...@@ -386,6 +440,7 @@ func (r *resolver) TTL() time.Duration { ...@@ -386,6 +440,7 @@ func (r *resolver) TTL() time.Duration {
func (r *resolver) Reload(rd io.Reader) error { func (r *resolver) Reload(rd io.Reader) error {
var ttl, timeout, period time.Duration var ttl, timeout, period time.Duration
var domain, prefer string var domain, prefer string
var srcIP net.IP
var nss []NameServer var nss []NameServer
if rd == nil || r.Stopped() { if rd == nil || r.Stopped() {
...@@ -422,6 +477,10 @@ func (r *resolver) Reload(rd io.Reader) error { ...@@ -422,6 +477,10 @@ func (r *resolver) Reload(rd io.Reader) error {
if len(ss) > 1 { if len(ss) > 1 {
prefer = strings.ToLower(ss[1]) prefer = strings.ToLower(ss[1])
} }
case "ip":
if len(ss) > 1 {
srcIP = net.ParseIP(ss[1])
}
case "nameserver": // nameserver option, compatible with /etc/resolv.conf case "nameserver": // nameserver option, compatible with /etc/resolv.conf
if len(ss) <= 1 { if len(ss) <= 1 {
break break
...@@ -461,6 +520,7 @@ func (r *resolver) Reload(rd io.Reader) error { ...@@ -461,6 +520,7 @@ func (r *resolver) Reload(rd io.Reader) error {
r.domain = domain r.domain = domain
r.period = period r.period = period
r.prefer = prefer r.prefer = prefer
r.srcIP = srcIP
r.servers = nss r.servers = nss
r.mux.Unlock() r.mux.Unlock()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment