Commit 7ce71001 authored by Jonathan Dickinson's avatar Jonathan Dickinson Committed by Miek Gieben

- Adding tests for MX round-robin (#358)

- Implementing MX round-robin
- Slight tidy
parent 219bfd04
...@@ -28,6 +28,7 @@ func (r *RoundRobinResponseWriter) WriteMsg(res *dns.Msg) error { ...@@ -28,6 +28,7 @@ func (r *RoundRobinResponseWriter) WriteMsg(res *dns.Msg) error {
func roundRobin(in []dns.RR) []dns.RR { func roundRobin(in []dns.RR) []dns.RR {
cname := []dns.RR{} cname := []dns.RR{}
address := []dns.RR{} address := []dns.RR{}
mx := []dns.RR{}
rest := []dns.RR{} rest := []dns.RR{}
for _, r := range in { for _, r := range in {
switch r.Header().Rrtype { switch r.Header().Rrtype {
...@@ -35,17 +36,29 @@ func roundRobin(in []dns.RR) []dns.RR { ...@@ -35,17 +36,29 @@ func roundRobin(in []dns.RR) []dns.RR {
cname = append(cname, r) cname = append(cname, r)
case dns.TypeA, dns.TypeAAAA: case dns.TypeA, dns.TypeAAAA:
address = append(address, r) address = append(address, r)
case dns.TypeMX:
mx = append(mx, r)
default: default:
rest = append(rest, r) rest = append(rest, r)
} }
} }
switch l := len(address); l { roundRobinShuffle(address)
roundRobinShuffle(mx)
out := append(cname, rest...)
out = append(out, address...)
out = append(out, mx...)
return out
}
func roundRobinShuffle(records []dns.RR) {
switch l := len(records); l {
case 0, 1: case 0, 1:
break break
case 2: case 2:
if dns.Id()%2 == 0 { if dns.Id()%2 == 0 {
address[0], address[1] = address[1], address[0] records[0], records[1] = records[1], records[0]
} }
default: default:
for j := 0; j < l*(int(dns.Id())%4+1); j++ { for j := 0; j < l*(int(dns.Id())%4+1); j++ {
...@@ -54,12 +67,9 @@ func roundRobin(in []dns.RR) []dns.RR { ...@@ -54,12 +67,9 @@ func roundRobin(in []dns.RR) []dns.RR {
if q == p { if q == p {
p = (p + 1) % l p = (p + 1) % l
} }
address[q], address[p] = address[p], address[q] records[q], records[p] = records[p], records[q]
} }
} }
out := append(cname, rest...)
out = append(out, address...)
return out
} }
// Write implements the dns.ResponseWriter interface. // Write implements the dns.ResponseWriter interface.
......
...@@ -16,44 +16,66 @@ func TestLoadBalance(t *testing.T) { ...@@ -16,44 +16,66 @@ func TestLoadBalance(t *testing.T) {
// the first X records must be cnames after this test // the first X records must be cnames after this test
tests := []struct { tests := []struct {
answer []dns.RR answer []dns.RR
extra []dns.RR extra []dns.RR
cnameAnswer int cnameAnswer int
cnameExtra int cnameExtra int
addressAnswer int
addressExtra int
mxAnswer int
mxExtra int
}{ }{
{ {
answer: []dns.RR{ answer: []dns.RR{
newCNAME("cname1.region2.skydns.test. 300 IN CNAME cname2.region2.skydns.test."), newCNAME("cname1.region2.skydns.test. 300 IN CNAME cname2.region2.skydns.test."),
newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."), newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."),
newCNAME("cname5.region2.skydns.test. 300 IN CNAME cname6.region2.skydns.test."), newCNAME("cname5.region2.skydns.test. 300 IN CNAME cname6.region2.skydns.test."),
newCNAME("cname6.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."), newCNAME("cname6.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."),
newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
newMX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."),
newMX("mx.region2.skydns.test. 300 IN MX 2 mx2.region2.skydns.test."),
newMX("mx.region2.skydns.test. 300 IN MX 3 mx3.region2.skydns.test."),
}, },
cnameAnswer: 4, cnameAnswer: 4,
addressAnswer: 1,
mxAnswer: 3,
}, },
{ {
answer: []dns.RR{ answer: []dns.RR{
newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
newCNAME("cname.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."), newMX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."),
newCNAME("cname.region2.skydns.test. 300 IN CNAME endpoint.region2.skydns.test."),
}, },
cnameAnswer: 1, cnameAnswer: 1,
addressAnswer: 1,
mxAnswer: 1,
}, },
{ {
answer: []dns.RR{ answer: []dns.RR{
newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), newMX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."),
newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"), newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."), newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.2"),
newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"), newMX("mx.region2.skydns.test. 300 IN MX 1 mx2.region2.skydns.test."),
newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."),
newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"),
newMX("mx.region2.skydns.test. 300 IN MX 1 mx3.region2.skydns.test."),
}, },
extra: []dns.RR{ extra: []dns.RR{
newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"), newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.1"),
newAAAA("endpoint.region2.skydns.test. 300 IN AAAA ::1"), newAAAA("endpoint.region2.skydns.test. 300 IN AAAA ::1"),
newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."), newMX("mx.region2.skydns.test. 300 IN MX 1 mx1.region2.skydns.test."),
newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"), newCNAME("cname2.region2.skydns.test. 300 IN CNAME cname3.region2.skydns.test."),
newAAAA("endpoint.region2.skydns.test. 300 IN AAAA ::2"), newMX("mx.region2.skydns.test. 300 IN MX 1 mx2.region2.skydns.test."),
newA("endpoint.region2.skydns.test. 300 IN A 10.240.0.3"),
newAAAA("endpoint.region2.skydns.test. 300 IN AAAA ::2"),
newMX("mx.region2.skydns.test. 300 IN MX 1 mx3.region2.skydns.test."),
}, },
cnameAnswer: 1, cnameAnswer: 1,
cnameExtra: 1, cnameExtra: 1,
addressAnswer: 3,
addressExtra: 4,
mxAnswer: 3,
mxExtra: 3,
}, },
} }
...@@ -71,27 +93,71 @@ func TestLoadBalance(t *testing.T) { ...@@ -71,27 +93,71 @@ func TestLoadBalance(t *testing.T) {
continue continue
} }
cname := 0
for _, r := range rec.Msg.Answer { cname, address, mx, sorted := countRecords(rec.Msg.Answer)
if r.Header().Rrtype != dns.TypeCNAME { if !sorted {
break t.Errorf("Test %d: Expected CNAMEs, then AAAAs, then MX in Answer, but got mixed", i)
}
cname++
} }
if cname != test.cnameAnswer { if cname != test.cnameAnswer {
t.Errorf("Test %d: Expected %d cnames in Answer, but got %d", i, test.cnameAnswer, cname) t.Errorf("Test %d: Expected %d CNAMEs in Answer, but got %d", i, test.cnameAnswer, cname)
} }
cname = 0 if address != test.addressAnswer {
for _, r := range rec.Msg.Extra { t.Errorf("Test %d: Expected %d A/AAAAs in Answer, but got %d", i, test.addressAnswer, address)
if r.Header().Rrtype != dns.TypeCNAME { }
break if mx != test.mxAnswer {
} t.Errorf("Test %d: Expected %d MXs in Answer, but got %d", i, test.mxAnswer, mx)
cname++ }
cname, address, mx, sorted = countRecords(rec.Msg.Extra)
if !sorted {
t.Errorf("Test %d: Expected CNAMEs, then AAAAs, then MX in Extra, but got mixed", i)
} }
if cname != test.cnameExtra { if cname != test.cnameExtra {
t.Errorf("Test %d: Expected %d cname in Extra, but got %d", i, test.cnameExtra, cname) t.Errorf("Test %d: Expected %d CNAMEs in Extra, but got %d", i, test.cnameAnswer, cname)
}
if address != test.addressExtra {
t.Errorf("Test %d: Expected %d A/AAAAs in Extra, but got %d", i, test.addressAnswer, address)
}
if mx != test.mxExtra {
t.Errorf("Test %d: Expected %d MXs in Extra, but got %d", i, test.mxAnswer, mx)
}
}
}
func countRecords(result []dns.RR) (cname int, address int, mx int, sorted bool) {
const (
Start = iota
CNAMERecords
ARecords
MXRecords
Any
)
// The order of the records is used to determine if the round-robin actually did anything.
sorted = true
cname = 0
address = 0
mx = 0
state := Start
for _, r := range result {
switch r.Header().Rrtype {
case dns.TypeCNAME:
sorted = sorted && state <= CNAMERecords
state = CNAMERecords
cname++
case dns.TypeA, dns.TypeAAAA:
sorted = sorted && state <= ARecords
state = ARecords
address++
case dns.TypeMX:
sorted = sorted && state <= MXRecords
state = MXRecords
mx++
default:
state = Any
} }
} }
return
} }
func handler() middleware.Handler { func handler() middleware.Handler {
...@@ -104,3 +170,4 @@ func handler() middleware.Handler { ...@@ -104,3 +170,4 @@ func handler() middleware.Handler {
func newA(rr string) *dns.A { r, _ := dns.NewRR(rr); return r.(*dns.A) } func newA(rr string) *dns.A { r, _ := dns.NewRR(rr); return r.(*dns.A) }
func newAAAA(rr string) *dns.AAAA { r, _ := dns.NewRR(rr); return r.(*dns.AAAA) } func newAAAA(rr string) *dns.AAAA { r, _ := dns.NewRR(rr); return r.(*dns.AAAA) }
func newCNAME(rr string) *dns.CNAME { r, _ := dns.NewRR(rr); return r.(*dns.CNAME) } func newCNAME(rr string) *dns.CNAME { r, _ := dns.NewRR(rr); return r.(*dns.CNAME) }
func newMX(rr string) *dns.MX { r, _ := dns.NewRR(rr); return r.(*dns.MX) }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment