Commit b2c221c9 authored by Miek Gieben's avatar Miek Gieben

Add test for CNAME ordering

Add a test for SkyDNS':
https://github.com/skynetservices/skydns/issues/217

Put the CNAME in front for both answer and extra sections. Note that
the etcd middleware seems to already to the correct thing though.
parent 8cf1c897
// +build etcd
package etcd
// etcd needs to be running on http://127.0.0.1:2379
import (
"testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/etcd/msg"
"github.com/miekg/dns"
)
// Check the ordering of returned cname.
func TestCnameLookup(t *testing.T) {
for _, serv := range servicesCname {
set(t, etc, serv.Key, 0, serv)
defer delete(t, etc, serv.Key)
}
for _, tc := range dnsTestCasesCname {
m := new(dns.Msg)
m.SetQuestion(dns.Fqdn(tc.Qname), tc.Qtype)
rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{})
_, err := etc.ServeDNS(ctx, rec, m)
if err != nil {
t.Errorf("expected no error, got %v\n", err)
return
}
resp := rec.Msg()
if resp.Rcode != tc.Rcode {
t.Errorf("rcode is %q, expected %q", dns.RcodeToString[resp.Rcode], dns.RcodeToString[tc.Rcode])
t.Logf("%v\n", resp)
continue
}
if len(resp.Answer) != len(tc.Answer) {
t.Errorf("answer for %q contained %d results, %d expected", tc.Qname, len(resp.Answer), len(tc.Answer))
t.Logf("%v\n", resp)
continue
}
if len(resp.Ns) != len(tc.Ns) {
t.Errorf("authority for %q contained %d results, %d expected", tc.Qname, len(resp.Ns), len(tc.Ns))
t.Logf("%v\n", resp)
continue
}
if len(resp.Extra) != len(tc.Extra) {
t.Errorf("additional for %q contained %d results, %d expected", tc.Qname, len(resp.Extra), len(tc.Extra))
t.Logf("%v\n", resp)
continue
}
if !checkSection(t, tc, Answer, resp.Answer) {
t.Logf("%v\n", resp)
}
if !checkSection(t, tc, Ns, resp.Ns) {
t.Logf("%v\n", resp)
}
if !checkSection(t, tc, Extra, resp.Extra) {
t.Logf("%v\n", resp)
}
}
}
var servicesCname = []*msg.Service{
{Host: "cname1.region2.skydns.test", Key: "a.server1.dev.region1.skydns.test."},
{Host: "cname2.region2.skydns.test", Key: "cname1.region2.skydns.test."},
{Host: "cname3.region2.skydns.test", Key: "cname2.region2.skydns.test."},
{Host: "cname4.region2.skydns.test", Key: "cname3.region2.skydns.test."},
{Host: "cname5.region2.skydns.test", Key: "cname4.region2.skydns.test."},
{Host: "cname6.region2.skydns.test", Key: "cname5.region2.skydns.test."},
{Host: "endpoint.region2.skydns.test", Key: "cname6.region2.skydns.test."},
{Host: "10.240.0.1", Key: "endpoint.region2.skydns.test."},
}
var dnsTestCasesCname = []dnsTestCase{
{
Qname: "a.server1.dev.region1.skydns.test.", Qtype: dns.TypeSRV,
Answer: []dns.RR{
newSRV("a.server1.dev.region1.skydns.test. 300 IN SRV 10 100 0 cname1.region2.skydns.test."),
},
Extra: []dns.RR{
newCNAME("cname1.region2.skydns.test. 300 IN CNAME cname2.region2.skydns.test."),
newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."),
newCNAME("cname3.region2.skydns.test. 300 IN CNAME cname4.region2.skydns.test."),
newCNAME("cname4.region2.skydns.test. 300 IN CNAME cname5.region2.skydns.test."),
newCNAME("cname5.region2.skydns.test. 300 IN CNAME cname6.region2.skydns.test."),
newCNAME("cname6.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."),
newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
},
},
}
...@@ -14,18 +14,21 @@ func (r *RoundRobinResponseWriter) WriteMsg(res *dns.Msg) error { ...@@ -14,18 +14,21 @@ func (r *RoundRobinResponseWriter) WriteMsg(res *dns.Msg) error {
if res.Rcode != dns.RcodeSuccess { if res.Rcode != dns.RcodeSuccess {
return r.ResponseWriter.WriteMsg(res) return r.ResponseWriter.WriteMsg(res)
} }
if len(res.Answer) < 2 { // don't even bother
return r.ResponseWriter.WriteMsg(res)
}
// put CNAMEs first, randomize a/aaaa's and put packet back together. res.Answer = roundRobin(res.Answer)
// TODO(miek): check family and give v6 more prio? res.Extra = roundRobin(res.Extra)
return r.ResponseWriter.WriteMsg(res)
}
func roundRobin(in []dns.RR) []dns.RR {
cname := []dns.RR{} cname := []dns.RR{}
address := []dns.RR{} address := []dns.RR{}
rest := []dns.RR{} rest := []dns.RR{}
for _, r := range res.Answer { for _, r := range in {
switch r.Header().Rrtype { switch r.Header().Rrtype {
case dns.TypeCNAME: case dns.TypeCNAME:
// d d d d DNAME and friends here as well?
cname = append(cname, r) cname = append(cname, r)
case dns.TypeA, dns.TypeAAAA: case dns.TypeA, dns.TypeAAAA:
address = append(address, r) address = append(address, r)
...@@ -36,7 +39,7 @@ func (r *RoundRobinResponseWriter) WriteMsg(res *dns.Msg) error { ...@@ -36,7 +39,7 @@ func (r *RoundRobinResponseWriter) WriteMsg(res *dns.Msg) error {
switch l := len(address); l { switch l := len(address); l {
case 0, 1: case 0, 1:
return r.ResponseWriter.WriteMsg(res) break
case 2: case 2:
if dns.Id()%2 == 0 { if dns.Id()%2 == 0 {
address[0], address[1] = address[1], address[0] address[0], address[1] = address[1], address[0]
...@@ -51,9 +54,9 @@ func (r *RoundRobinResponseWriter) WriteMsg(res *dns.Msg) error { ...@@ -51,9 +54,9 @@ func (r *RoundRobinResponseWriter) WriteMsg(res *dns.Msg) error {
address[q], address[p] = address[p], address[q] address[q], address[p] = address[p], address[q]
} }
} }
res.Answer = append(cname, rest...) out := append(cname, rest...)
res.Answer = append(res.Answer, address...) out = append(out, address...)
return r.ResponseWriter.WriteMsg(res) return out
} }
// Should we pack and unpack here to fiddle with the packet... Not likely. // Should we pack and unpack here to fiddle with the packet... Not likely.
......
...@@ -4,13 +4,15 @@ ...@@ -4,13 +4,15 @@
message. See [Wikipedia](https://en.wikipedia.org/wiki/Round-robin_DNS) about the pros and cons message. See [Wikipedia](https://en.wikipedia.org/wiki/Round-robin_DNS) about the pros and cons
on this setup. on this setup.
It will take care to sort any CNAMEs before any address records.
## Syntax ## Syntax
~~~ ~~~
loadbalance [policy] loadbalance [policy]
~~~ ~~~
* policy is how to balance, the default is "round_robin" * `policy` is how to balance, the default is "round_robin"
## Examples ## Examples
......
package loadbalance
import (
"testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
"golang.org/x/net/context"
)
func TestLoadBalance(t *testing.T) {
rm := RoundRobin{Next: handler()}
// the first X records must be cnames after this test
tests := []struct {
answer []dns.RR
extra []dns.RR
cnameAnswer int
cnameExtra int
}{
{
answer: []dns.RR{
newCNAME("cname1.region2.skydns.test. 300 IN CNAME cname2.region2.skydns.test."),
newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."),
newCNAME("cname5.region2.skydns.test. 300 IN CNAME cname6.region2.skydns.test."),
newCNAME("cname6.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."),
newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
},
cnameAnswer: 4,
},
{
answer: []dns.RR{
newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
newCNAME("cname.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."),
},
cnameAnswer: 1,
},
{
answer: []dns.RR{
newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"),
newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."),
newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"),
},
extra: []dns.RR{
newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
newAAAA("endpoint.region2.skydns.test. 300 IN AAAA ::1"),
newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."),
newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"),
newAAAA("endpoint.region2.skydns.test. 300 IN AAAA ::2"),
},
cnameAnswer: 1,
cnameExtra: 1,
},
}
rec := middleware.NewResponseRecorder(&middleware.TestResponseWriter{})
for i, test := range tests {
req := new(dns.Msg)
req.SetQuestion("region2.skydns.test.", dns.TypeSRV)
req.Answer = test.answer
req.Extra = test.extra
_, err := rm.ServeDNS(context.TODO(), rec, req)
if err != nil {
t.Errorf("Test %d: Expected no error, but got %s", i, err)
continue
}
cname := 0
for _, r := range rec.Msg().Answer {
if r.Header().Rrtype != dns.TypeCNAME {
break
}
cname++
}
if cname != test.cnameAnswer {
t.Errorf("Test %d: Expected %d cnames in Answer, but got %d", i, test.cnameAnswer, cname)
}
cname = 0
for _, r := range rec.Msg().Extra {
if r.Header().Rrtype != dns.TypeCNAME {
break
}
cname++
}
if cname != test.cnameExtra {
t.Errorf("Test %d: Expected %d cname in Extra, but got %d", i, test.cnameExtra, cname)
}
}
}
func handler() middleware.Handler {
return middleware.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
w.WriteMsg(r)
return dns.RcodeSuccess, nil
})
}
func newA(rr string) *dns.A { r, _ := dns.NewRR(rr); return r.(*dns.A) }
func newAAAA(rr string) *dns.AAAA { r, _ := dns.NewRR(rr); return r.(*dns.AAAA) }
func newCNAME(rr string) *dns.CNAME { r, _ := dns.NewRR(rr); return r.(*dns.CNAME) }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment