Commit 898b1ef3 authored by Miek Gieben's avatar Miek Gieben Committed by John Belamaric

server: actually scrub response (#2225)

* server: actually scrub response

Did all the worked, hooked it up wrongly :(

This also needs test, but those are hard(er) because we only receive
packets after they have been decoded; i.e. we never see the wirefmt.
Signed-off-by: default avatarMiek Gieben <miek@miek.nl>

* Add tests

Add a test for checking is compression pointers are set in the packet.
This also adds an undocumented 'large' feature to the erratic plugin to
send large responses that should be compressed.

Commenting the Scrub out in server results in:

=== RUN   TestCompressScrub
--- FAIL: TestCompressScrub (0.00s)
    compression_scrub_test.go:41: Expected returned packet to be < 512, got 839
FAIL
exit status 1
FAIL    github.com/coredns/coredns/test 0.036s

Actually checking the size might be easier, but lets be thorough here
and check the pointers them selves.
Signed-off-by: default avatarMiek Gieben <miek@miek.nl>

* Fix tests
Signed-off-by: default avatarMiek Gieben <miek@miek.nl>

* plugin erratic: fix e.large

always put an rr in the reply, fix e.large in erractic and add test to
check for it.
Signed-off-by: default avatarMiek Gieben <miek@miek.nl>
parent 96529b2c
...@@ -19,6 +19,7 @@ type Erratic struct { ...@@ -19,6 +19,7 @@ type Erratic struct {
duration time.Duration duration time.Duration
truncate uint64 truncate uint64
large bool // undocumented feature; return large responses for A request (>512B, to test compression).
q uint64 // counter of queries q uint64 // counter of queries
} }
...@@ -57,6 +58,11 @@ func (e *Erratic) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg ...@@ -57,6 +58,11 @@ func (e *Erratic) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg
rr := *(rrA.(*dns.A)) rr := *(rrA.(*dns.A))
rr.Header().Name = state.QName() rr.Header().Name = state.QName()
m.Answer = append(m.Answer, &rr) m.Answer = append(m.Answer, &rr)
if e.large {
for i := 0; i < 29; i++ {
m.Answer = append(m.Answer, &rr)
}
}
case dns.TypeAAAA: case dns.TypeAAAA:
rr := *(rrAAAA.(*dns.AAAA)) rr := *(rrAAAA.(*dns.AAAA))
rr.Header().Name = state.QName() rr.Header().Name = state.QName()
......
...@@ -98,3 +98,19 @@ func TestAxfr(t *testing.T) { ...@@ -98,3 +98,19 @@ func TestAxfr(t *testing.T) {
t.Errorf("Expected for record to be %d, got %d", dns.TypeSOA, x) t.Errorf("Expected for record to be %d, got %d", dns.TypeSOA, x)
} }
} }
func TestErratic(t *testing.T) {
e := &Erratic{drop: 0, delay: 0}
ctx := context.TODO()
req := new(dns.Msg)
req.SetQuestion("example.org.", dns.TypeA)
rec := dnstest.NewRecorder(&test.ResponseWriter{})
e.ServeDNS(ctx, rec, req)
if rec.Msg.Answer[0].Header().Rrtype != dns.TypeA {
t.Errorf("Expected A response, got %d type", rec.Msg.Answer[0].Header().Rrtype)
}
}
...@@ -104,6 +104,8 @@ func parseErratic(c *caddy.Controller) (*Erratic, error) { ...@@ -104,6 +104,8 @@ func parseErratic(c *caddy.Controller) (*Erratic, error) {
return nil, fmt.Errorf("illegal amount value given %q", args[0]) return nil, fmt.Errorf("illegal amount value given %q", args[0])
} }
e.truncate = uint64(amount) e.truncate = uint64(amount)
case "large":
e.large = true
default: default:
return nil, c.Errf("unknown property '%s'", c.Val()) return nil, c.Errf("unknown property '%s'", c.Val())
} }
......
...@@ -226,11 +226,7 @@ func (r *Request) SizeAndDo(m *dns.Msg) bool { ...@@ -226,11 +226,7 @@ func (r *Request) SizeAndDo(m *dns.Msg) bool {
return true return true
} }
// Scrub is a noop function, added for backwards compatibility reasons. The original Scrub is now called // Scrub scrubs the reply message so that it will fit the client's buffer. It will first
// automatically by the server on writing the reply. See ScrubWriter.
func (r *Request) Scrub(reply *dns.Msg) (*dns.Msg, int) { return reply, 0 }
// scrub scrubs the reply message so that it will fit the client's buffer. It will first
// check if the reply fits without compression and then *with* compression. // check if the reply fits without compression and then *with* compression.
// Scrub will then use binary search to find a save cut off point in the additional section. // Scrub will then use binary search to find a save cut off point in the additional section.
// If even *without* the additional section the reply still doesn't fit we // If even *without* the additional section the reply still doesn't fit we
...@@ -238,7 +234,7 @@ func (r *Request) Scrub(reply *dns.Msg) (*dns.Msg, int) { return reply, 0 } ...@@ -238,7 +234,7 @@ func (r *Request) Scrub(reply *dns.Msg) (*dns.Msg, int) { return reply, 0 }
// we set the TC bit on the reply; indicating the client should retry over TCP. // we set the TC bit on the reply; indicating the client should retry over TCP.
// Note, the TC bit will be set regardless of protocol, even TCP message will // Note, the TC bit will be set regardless of protocol, even TCP message will
// get the bit, the client should then retry with pigeons. // get the bit, the client should then retry with pigeons.
func (r *Request) scrub(reply *dns.Msg) *dns.Msg { func (r *Request) Scrub(reply *dns.Msg) *dns.Msg {
size := r.Size() size := r.Size()
reply.Compress = false reply.Compress = false
......
...@@ -73,7 +73,7 @@ func TestRequestScrubAnswer(t *testing.T) { ...@@ -73,7 +73,7 @@ func TestRequestScrubAnswer(t *testing.T) {
fmt.Sprintf("large.example.com. 10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i))) fmt.Sprintf("large.example.com. 10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i)))
} }
req.scrub(reply) req.Scrub(reply)
if want, got := req.Size(), reply.Len(); want < got { if want, got := req.Size(), reply.Len(); want < got {
t.Errorf("Want scrub to reduce message length below %d bytes, got %d bytes", want, got) t.Errorf("Want scrub to reduce message length below %d bytes, got %d bytes", want, got)
} }
...@@ -94,7 +94,7 @@ func TestRequestScrubExtra(t *testing.T) { ...@@ -94,7 +94,7 @@ func TestRequestScrubExtra(t *testing.T) {
fmt.Sprintf("large.example.com. 10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i))) fmt.Sprintf("large.example.com. 10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i)))
} }
req.scrub(reply) req.Scrub(reply)
if want, got := req.Size(), reply.Len(); want < got { if want, got := req.Size(), reply.Len(); want < got {
t.Errorf("Want scrub to reduce message length below %d bytes, got %d bytes", want, got) t.Errorf("Want scrub to reduce message length below %d bytes, got %d bytes", want, got)
} }
...@@ -116,7 +116,7 @@ func TestRequestScrubExtraEdns0(t *testing.T) { ...@@ -116,7 +116,7 @@ func TestRequestScrubExtraEdns0(t *testing.T) {
fmt.Sprintf("large.example.com. 10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i))) fmt.Sprintf("large.example.com. 10 IN SRV 0 0 80 10-0-0-%d.default.pod.k8s.example.com.", i)))
} }
req.scrub(reply) req.Scrub(reply)
if want, got := req.Size(), reply.Len(); want < got { if want, got := req.Size(), reply.Len(); want < got {
t.Errorf("Want scrub to reduce message length below %d bytes, got %d bytes", want, got) t.Errorf("Want scrub to reduce message length below %d bytes, got %d bytes", want, got)
} }
...@@ -146,7 +146,7 @@ func TestRequestScrubExtraRegression(t *testing.T) { ...@@ -146,7 +146,7 @@ func TestRequestScrubExtraRegression(t *testing.T) {
fmt.Sprintf("10-0-0-%d.default.pod.k8s.example.com. 10 IN A 10.0.0.%d", i, i))) fmt.Sprintf("10-0-0-%d.default.pod.k8s.example.com. 10 IN A 10.0.0.%d", i, i)))
} }
reply = req.scrub(reply) reply = req.Scrub(reply)
if want, got := req.Size(), reply.Len(); want < got { if want, got := req.Size(), reply.Len(); want < got {
t.Errorf("Want scrub to reduce message length below %d bytes, got %d bytes", want, got) t.Errorf("Want scrub to reduce message length below %d bytes, got %d bytes", want, got)
} }
...@@ -171,7 +171,7 @@ func TestRequestScrubAnswerExact(t *testing.T) { ...@@ -171,7 +171,7 @@ func TestRequestScrubAnswerExact(t *testing.T) {
reply.Answer = append(reply.Answer, test.A(fmt.Sprintf("large.example.com. 10 IN A 127.0.0.%d", i))) reply.Answer = append(reply.Answer, test.A(fmt.Sprintf("large.example.com. 10 IN A 127.0.0.%d", i)))
} }
req.scrub(reply) req.Scrub(reply)
if want, got := req.Size(), reply.Len(); want < got { if want, got := req.Size(), reply.Len(); want < got {
t.Errorf("Want scrub to reduce message length below %d bytes, got %d bytes", want, got) t.Errorf("Want scrub to reduce message length below %d bytes, got %d bytes", want, got)
} }
......
...@@ -15,6 +15,6 @@ func NewScrubWriter(req *dns.Msg, w dns.ResponseWriter) *ScrubWriter { return &S ...@@ -15,6 +15,6 @@ func NewScrubWriter(req *dns.Msg, w dns.ResponseWriter) *ScrubWriter { return &S
// scrub on the message m and will then write it to the client. // scrub on the message m and will then write it to the client.
func (s *ScrubWriter) WriteMsg(m *dns.Msg) error { func (s *ScrubWriter) WriteMsg(m *dns.Msg) error {
state := Request{Req: s.req, W: s.ResponseWriter} state := Request{Req: s.req, W: s.ResponseWriter}
new, _ := state.Scrub(m) n := state.Scrub(m)
return s.ResponseWriter.WriteMsg(new) return s.ResponseWriter.WriteMsg(n)
} }
package test
import (
"net"
"testing"
"github.com/miekg/dns"
)
func TestCompressScrub(t *testing.T) {
corefile := `example.org:0 {
erratic {
drop 0
delay 0
large
}
}`
i, udp, _, err := CoreDNSServerAndPorts(corefile)
if err != nil {
t.Fatalf("Could not get CoreDNS serving instance: %s", err)
}
defer i.Stop()
c, err := net.Dial("udp", udp)
if err != nil {
t.Fatalf("Could not dial %s", err)
}
m := new(dns.Msg)
m.SetQuestion("example.org.", dns.TypeA)
q, _ := m.Pack()
c.Write(q)
buf := make([]byte, 1024)
n, err := c.Read(buf)
if err != nil || n == 0 {
t.Errorf("Expected reply, got: %s", err)
return
}
if n >= 512 {
t.Fatalf("Expected returned packet to be < 512, got %d", n)
}
buf = buf[:n]
// If there is compression in the returned packet we should look for compression pointers, if found
// the pointers should return to the domain name in the query (the first domain name that's avaiable for
// compression. This means we're looking for a combo where the pointers is detected and the offset is 12
// the position of the first name after the header. The erratic plugin adds 30 RRs that should all be compressed.
found := 0
for i := 0; i < len(buf)-1; i++ {
if buf[i]&0xC0 == 0xC0 {
off := (int(buf[i])^0xC0)<<8 | int(buf[i+1])
if off == 12 {
found++
}
}
}
if found != 30 {
t.Errorf("Failed to find all compression pointers in the packet, wanted 30, got %d", found)
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment