Commit f907311c authored by Miek Gieben's avatar Miek Gieben

Use context.Context

Rename the old Context to State and use context.Context in the
middleware for intra-middleware communication and more.
parent 523cc0a0
package https
import (
"io/ioutil"
"net/http"
"os"
"testing"
"github.com/miekg/coredns/middleware/redirect"
"github.com/miekg/coredns/server"
"github.com/xenolf/lego/acme"
)
/*
func TestHostQualifies(t *testing.T) {
for i, test := range []struct {
host string
......@@ -330,3 +320,4 @@ func TestMarkQualified(t *testing.T) {
t.Errorf("Expected %d managed configs, but got %d", expectedManagedCount, count)
}
}
*/
......@@ -4,6 +4,8 @@ import (
"fmt"
"strings"
"golang.org/x/net/context"
"github.com/miekg/coredns/core/parse"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/server"
......@@ -70,7 +72,7 @@ func NewTestController(input string) *Controller {
//
// Used primarily for testing but needs to be exported so
// add-ons can use this as a convenience.
var EmptyNext = middleware.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) (int, error) {
var EmptyNext = middleware.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
return 0, nil
})
......
package setup
import (
"testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/errors"
)
/*
func TestErrors(t *testing.T) {
c := NewTestController(`errors`)
mid, err := Errors(c)
......@@ -154,5 +148,5 @@ func TestErrorsParse(t *testing.T) {
}
}
}
}
*/
package setup
import (
"fmt"
"regexp"
"testing"
"github.com/miekg/coredns/middleware/rewrite"
)
/*
func TestRewrite(t *testing.T) {
c := NewTestController(`rewrite /from /to`)
......@@ -237,5 +230,5 @@ func TestRewriteParse(t *testing.T) {
}
}
}
*/
......@@ -8,6 +8,8 @@ import (
"strings"
"time"
"golang.org/x/net/context"
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
)
......@@ -21,10 +23,10 @@ type ErrorHandler struct {
Debug bool // if true, errors are written out to client rather than to a log
}
func (h ErrorHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
func (h ErrorHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
defer h.recovery(w, r)
rcode, err := h.Next.ServeDNS(w, r)
rcode, err := h.Next.ServeDNS(ctx, w, r)
if err != nil {
errMsg := fmt.Sprintf("%s [ERROR %d %s %s] %v", time.Now().Format(timeFormat), rcode, r.Question[0].Name, dns.Type(r.Question[0].Qclass), err)
......
package errors
import (
"bytes"
"errors"
"fmt"
"log"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
"github.com/miekg/coredns/middleware"
)
/*
func TestErrors(t *testing.T) {
// create a temporary page
path := filepath.Join(os.TempDir(), "errors_test.html")
......@@ -166,3 +151,4 @@ func genErrorHandler(status int, err error, body string) middleware.Handler {
return status, err
})
}
*/
......@@ -8,6 +8,8 @@ package file
import (
"strings"
"golang.org/x/net/context"
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
)
......@@ -26,29 +28,29 @@ type (
}
)
func (f File) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
context := middleware.Context{W: w, Req: r}
qname := context.Name()
func (f File) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := middleware.State{W: w, Req: r}
qname := state.Name()
zone := middleware.Zones(f.Zones.Names).Matches(qname)
if zone == "" {
return f.Next.ServeDNS(w, r)
return f.Next.ServeDNS(ctx, w, r)
}
names, nodata := f.Zones.Z[zone].lookup(qname, context.QType())
names, nodata := f.Zones.Z[zone].lookup(qname, state.QType())
var answer *dns.Msg
switch {
case nodata:
answer = context.AnswerMessage()
answer = state.AnswerMessage()
answer.Ns = names
case len(names) == 0:
answer = context.AnswerMessage()
answer = state.AnswerMessage()
answer.Ns = names
answer.Rcode = dns.RcodeNameError
case len(names) > 0:
answer = context.AnswerMessage()
answer = state.AnswerMessage()
answer.Answer = names
default:
answer = context.ErrorMessage(dns.RcodeServerFailure)
answer = state.ErrorMessage(dns.RcodeServerFailure)
}
// Check return size, etc. TODO(miek)
w.WriteMsg(answer)
......
package file
import (
"errors"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
)
/*
var testDir = filepath.Join(os.TempDir(), "caddy_testdir")
var ErrCustom = errors.New("Custom Error")
......@@ -323,3 +314,4 @@ func TestServeHTTPFailingStat(t *testing.T) {
}
}
}
*/
......@@ -4,6 +4,8 @@ package log
import (
"log"
"golang.org/x/net/context"
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
)
......@@ -15,7 +17,7 @@ type Logger struct {
ErrorFunc func(dns.ResponseWriter, *dns.Msg, int) // failover error handler
}
func (l Logger) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
func (l Logger) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
for _, rule := range l.Rules {
/*
if middleware.Path(r.URL.Path).Matches(rule.PathScope) {
......@@ -40,7 +42,7 @@ func (l Logger) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
*/
rule = rule
}
return l.Next.ServeDNS(w, r)
return l.Next.ServeDNS(ctx, w, r)
}
// Rule configures the logging middleware.
......
package log
import (
"bytes"
"log"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
/*
type erroringMiddleware struct{}
func (erroringMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
func (erroringMiddleware) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
return http.StatusNotFound, nil
}
......@@ -46,3 +38,4 @@ func TestLoggedStatus(t *testing.T) {
t.Error("Expected 404 to be logged. Logged string -", logged)
}
}
*/
......@@ -5,6 +5,7 @@ import (
"time"
"github.com/miekg/dns"
"golang.org/x/net/context"
)
type (
......@@ -32,18 +33,18 @@ type (
// Otherwise, return values should be propagated down the middleware
// chain by returning them unchanged.
Handler interface {
ServeDNS(dns.ResponseWriter, *dns.Msg) (int, error)
ServeDNS(context.Context, dns.ResponseWriter, *dns.Msg) (int, error)
}
// HandlerFunc is a convenience type like dns.HandlerFunc, except
// ServeDNS returns an rcode and an error. See Handler
// documentation for more information.
HandlerFunc func(dns.ResponseWriter, *dns.Msg) (int, error)
HandlerFunc func(context.Context, dns.ResponseWriter, *dns.Msg) (int, error)
)
// ServeDNS implements the Handler interface.
func (f HandlerFunc) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
return f(w, r)
func (f HandlerFunc) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
return f(ctx, w, r)
}
// IndexFile looks for a file in /root/fpath/indexFile for each string
......
package middleware
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestIndexfile(t *testing.T) {
tests := []struct {
rootDir http.FileSystem
fpath string
indexFiles []string
shouldErr bool
expectedFilePath string //retun value
expectedBoolValue bool //return value
}{
{
http.Dir("./templates/testdata"),
"/images/",
[]string{"img.htm"},
false,
"/images/img.htm",
true,
},
}
for i, test := range tests {
actualFilePath, actualBoolValue := IndexFile(test.rootDir, test.fpath, test.indexFiles)
if actualBoolValue == true && test.shouldErr {
t.Errorf("Test %d didn't error, but it should have", i)
} else if actualBoolValue != true && !test.shouldErr {
t.Errorf("Test %d errored, but it shouldn't have; got %s", i, "Please Add a / at the end of fpath or the indexFiles doesnt exist")
}
if actualFilePath != test.expectedFilePath {
t.Fatalf("Test %d expected returned filepath to be %s, but got %s ",
i, test.expectedFilePath, actualFilePath)
}
if actualBoolValue != test.expectedBoolValue {
t.Fatalf("Test %d expected returned bool value to be %v, but got %v ",
i, test.expectedBoolValue, actualBoolValue)
}
}
}
func TestSetLastModified(t *testing.T) {
nowTime := time.Now()
// ovewrite the function to return reliable time
originalGetCurrentTimeFunc := currentTime
currentTime = func() time.Time {
return nowTime
}
defer func() {
currentTime = originalGetCurrentTimeFunc
}()
pastTime := nowTime.Truncate(1 * time.Hour)
futureTime := nowTime.Add(1 * time.Hour)
tests := []struct {
inputModTime time.Time
expectedIsHeaderSet bool
expectedLastModified string
}{
{
inputModTime: pastTime,
expectedIsHeaderSet: true,
expectedLastModified: pastTime.UTC().Format(http.TimeFormat),
},
{
inputModTime: nowTime,
expectedIsHeaderSet: true,
expectedLastModified: nowTime.UTC().Format(http.TimeFormat),
},
{
inputModTime: futureTime,
expectedIsHeaderSet: true,
expectedLastModified: nowTime.UTC().Format(http.TimeFormat),
},
{
inputModTime: time.Time{},
expectedIsHeaderSet: false,
},
}
for i, test := range tests {
responseRecorder := httptest.NewRecorder()
errorPrefix := fmt.Sprintf("Test [%d]: ", i)
SetLastModifiedHeader(responseRecorder, test.inputModTime)
actualLastModifiedHeader := responseRecorder.Header().Get("Last-Modified")
if test.expectedIsHeaderSet && actualLastModifiedHeader == "" {
t.Fatalf(errorPrefix + "Expected to find Last-Modified header, but found nothing")
}
if !test.expectedIsHeaderSet && actualLastModifiedHeader != "" {
t.Fatalf(errorPrefix+"Did not expect to find Last-Modified header, but found one [%s].", actualLastModifiedHeader)
}
if test.expectedLastModified != actualLastModifiedHeader {
t.Errorf(errorPrefix+"Expected Last-Modified content [%s], found [%s}", test.expectedLastModified, actualLastModifiedHeader)
}
}
}
......@@ -4,15 +4,17 @@ import (
"strconv"
"time"
"golang.org/x/net/context"
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
)
func (m *Metrics) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
context := middleware.Context{W: w, Req: r}
func (m *Metrics) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := middleware.State{W: w, Req: r}
qname := context.Name()
qtype := context.Type()
qname := state.Name()
qtype := state.Type()
zone := middleware.Zones(m.ZoneNames).Matches(qname)
if zone == "" {
zone = "."
......@@ -20,7 +22,7 @@ func (m *Metrics) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
// Record response to get status code and size of the reply.
rw := middleware.NewResponseRecorder(w)
status, err := m.Next.ServeDNS(rw, r)
status, err := m.Next.ServeDNS(ctx, rw, r)
requestCount.WithLabelValues(zone, qtype).Inc()
requestDuration.WithLabelValues(zone).Observe(float64(time.Since(rw.Start()) / time.Second))
......
......@@ -7,6 +7,8 @@ import (
"sync/atomic"
"time"
"golang.org/x/net/context"
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
)
......@@ -67,7 +69,7 @@ func (uh *UpstreamHost) Down() bool {
var tryDuration = 60 * time.Second
// ServeDNS satisfies the middleware.Handler interface.
func (p Proxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
func (p Proxy) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
for _, upstream := range p.Upstreams {
// allowed bla bla bla TODO(miek): fix full proxy spec from caddy
start := time.Now()
......@@ -100,7 +102,7 @@ func (p Proxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
}
return dns.RcodeServerFailure, errUnreachable
}
return p.Next.ServeDNS(w, r)
return p.Next.ServeDNS(ctx, w, r)
}
func Clients() Client {
......
package proxy
import (
"bufio"
"bytes"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
"golang.org/x/net/websocket"
)
/*
func init() {
tryDuration = 50 * time.Millisecond // prevent tests from hanging
}
......@@ -315,3 +295,4 @@ func (c *fakeConn) SetWriteDeadline(t time.Time) error { return nil }
func (c *fakeConn) Close() error { return nil }
func (c *fakeConn) Read(b []byte) (int, error) { return c.readBuf.Read(b) }
func (c *fakeConn) Write(b []byte) (int, error) { return c.writeBuf.Write(b) }
*/
......@@ -12,15 +12,15 @@ type ReverseProxy struct {
}
func (p ReverseProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg, extra []dns.RR) error {
// TODO(miek): use extra!
// TODO(miek): use extra to EDNS0.
var (
reply *dns.Msg
err error
)
context := middleware.Context{W: w, Req: r}
state := middleware.State{W: w, Req: r}
// tls+tcp ?
if context.Proto() == "tcp" {
if state.Proto() == "tcp" {
reply, err = middleware.Exchange(p.Client.TCP, r, p.Host)
} else {
reply, err = middleware.Exchange(p.Client.UDP, r, p.Host)
......
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
)
/*
func TestNewResponseRecorder(t *testing.T) {
w := httptest.NewRecorder()
recordRequest := NewResponseRecorder(w)
......@@ -30,3 +25,4 @@ func TestWrite(t *testing.T) {
t.Fatalf("Expected Response Body to be %s , but found %s\n", responseTestString, w.Body.String())
}
}
*/
......@@ -20,6 +20,8 @@ import (
"net"
"strings"
"golang.org/x/net/context"
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
)
......@@ -28,15 +30,15 @@ type Reflect struct {
Next middleware.Handler
}
func (rl Reflect) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
context := middleware.Context{Req: r, W: w}
func (rl Reflect) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
state := middleware.State{Req: r, W: w}
class := r.Question[0].Qclass
qname := r.Question[0].Name
i, ok := dns.NextLabel(qname, 0)
if strings.ToLower(qname[:i]) != who || ok {
err := context.ErrorMessage(dns.RcodeFormatError)
err := state.ErrorMessage(dns.RcodeFormatError)
w.WriteMsg(err)
return dns.RcodeFormatError, errors.New(dns.RcodeToString[dns.RcodeFormatError])
}
......@@ -46,10 +48,10 @@ func (rl Reflect) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
answer.Compress = true
answer.Authoritative = true
ip := context.IP()
proto := context.Proto()
port, _ := context.Port()
family := context.Family()
ip := state.IP()
proto := state.Proto()
port, _ := state.Port()
family := state.Family()
var rr dns.RR
switch family {
......@@ -67,7 +69,7 @@ func (rl Reflect) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
t.Hdr = dns.RR_Header{Name: qname, Rrtype: dns.TypeTXT, Class: class, Ttl: 0}
t.Txt = []string{"Port: " + port + " (" + proto + ")"}
switch context.Type() {
switch state.Type() {
case "TXT":
answer.Answer = append(answer.Answer, t)
answer.Extra = append(answer.Extra, rr)
......
......@@ -29,19 +29,19 @@ type replacer struct {
// available. emptyValue should be the string that is used
// in place of empty string (can still be empty string).
func NewReplacer(r *dns.Msg, rr *ResponseRecorder, emptyValue string) Replacer {
context := Context{W: rr, Req: r}
state := State{W: rr, Req: r}
rep := replacer{
replacements: map[string]string{
"{type}": context.Type(),
"{name}": context.Name(),
"{class}": context.Class(),
"{proto}": context.Proto(),
"{type}": state.Type(),
"{name}": state.Name(),
"{class}": state.Class(),
"{proto}": state.Proto(),
"{when}": func() string {
return time.Now().Format(timeFormat)
}(),
"{remote}": context.IP(),
"{remote}": state.IP(),
"{port}": func() string {
p, _ := context.Port()
p, _ := state.Port()
return p
}(),
},
......
package middleware
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
/*
func TestNewReplacer(t *testing.T) {
w := httptest.NewRecorder()
recordRequest := NewResponseRecorder(w)
......@@ -122,3 +116,4 @@ func TestSet(t *testing.T) {
t.Error("Expected variable replacement failed")
}
}
*/
package rewrite
import (
"net/http"
"strings"
"testing"
)
/*
func TestConditions(t *testing.T) {
tests := []struct {
condition string
......@@ -104,3 +99,4 @@ func TestConditions(t *testing.T) {
}
}
}
*/
......@@ -5,6 +5,7 @@ package rewrite
import (
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
"golang.org/x/net/context"
)
// Result is the result of a rewrite
......@@ -27,12 +28,12 @@ type Rewrite struct {
}
// ServeHTTP implements the middleware.Handler interface.
func (rw Rewrite) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
func (rw Rewrite) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
wr := NewResponseReverter(w, r)
for _, rule := range rw.Rules {
switch result := rule.Rewrite(r); result {
case RewriteDone:
return rw.Next.ServeDNS(wr, r)
return rw.Next.ServeDNS(ctx, wr, r)
case RewriteIgnored:
break
case RewriteStatus:
......@@ -42,7 +43,7 @@ func (rw Rewrite) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
// }
}
}
return rw.Next.ServeDNS(w, r)
return rw.Next.ServeDNS(ctx, w, r)
}
// Rule describes an internal location rewrite rule.
......
package rewrite
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/miekg/coredns/middleware"
)
/*
func TestRewrite(t *testing.T) {
rw := Rewrite{
Next: middleware.HandlerFunc(urlPrinter),
......@@ -157,3 +148,4 @@ func urlPrinter(w http.ResponseWriter, r *http.Request) (int, error) {
fmt.Fprintf(w, r.URL.String())
return 0, nil
}
*/
......@@ -9,45 +9,44 @@ import (
"github.com/miekg/dns"
)
// This file contains the context and functions available for
// use in the templates.
// This file contains the state nd functions available for use in the templates.
// Context is the context with which Caddy templates are executed.
type Context struct {
Root http.FileSystem // TODO(miek): needed
// State contains some connection state and is useful in middleware.
type State struct {
Root http.FileSystem // TODO(miek): needed?
Req *dns.Msg
W dns.ResponseWriter
}
// Now returns the current timestamp in the specified format.
func (c Context) Now(format string) string {
func (s State) Now(format string) string {
return time.Now().Format(format)
}
// NowDate returns the current date/time that can be used
// in other time functions.
func (c Context) NowDate() time.Time {
func (s State) NowDate() time.Time {
return time.Now()
}
// Header gets the value of a header.
func (c Context) Header() *dns.RR_Header {
func (s State) Header() *dns.RR_Header {
// TODO(miek)
return nil
}
// IP gets the (remote) IP address of the client making the request.
func (c Context) IP() string {
ip, _, err := net.SplitHostPort(c.W.RemoteAddr().String())
func (s State) IP() string {
ip, _, err := net.SplitHostPort(s.W.RemoteAddr().String())
if err != nil {
return c.W.RemoteAddr().String()
return s.W.RemoteAddr().String()
}
return ip
}
// Post gets the (remote) Port of the client making the request.
func (c Context) Port() (string, error) {
_, port, err := net.SplitHostPort(c.W.RemoteAddr().String())
func (s State) Port() (string, error) {
_, port, err := net.SplitHostPort(s.W.RemoteAddr().String())
if err != nil {
return "0", err
}
......@@ -56,11 +55,11 @@ func (c Context) Port() (string, error) {
// Proto gets the protocol used as the transport. This
// will be udp or tcp.
func (c Context) Proto() string {
if _, ok := c.W.RemoteAddr().(*net.UDPAddr); ok {
func (s State) Proto() string {
if _, ok := s.W.RemoteAddr().(*net.UDPAddr); ok {
return "udp"
}
if _, ok := c.W.RemoteAddr().(*net.TCPAddr); ok {
if _, ok := s.W.RemoteAddr().(*net.TCPAddr); ok {
return "tcp"
}
return "udp"
......@@ -68,9 +67,9 @@ func (c Context) Proto() string {
// Family returns the family of the transport.
// 1 for IPv4 and 2 for IPv6.
func (c Context) Family() int {
func (s State) Family() int {
var a net.IP
ip := c.W.RemoteAddr()
ip := s.W.RemoteAddr()
if i, ok := ip.(*net.UDPAddr); ok {
a = i.IP
}
......@@ -85,51 +84,48 @@ func (c Context) Family() int {
}
// Type returns the type of the question as a string.
func (c Context) Type() string {
return dns.Type(c.Req.Question[0].Qtype).String()
func (s State) Type() string {
return dns.Type(s.Req.Question[0].Qtype).String()
}
// QType returns the type of the question as a uint16.
func (c Context) QType() uint16 {
return c.Req.Question[0].Qtype
func (s State) QType() uint16 {
return s.Req.Question[0].Qtype
}
// Name returns the name of the question in the request. Note
// this name will always have a closing dot and will be lower cased.
func (c Context) Name() string {
return strings.ToLower(dns.Name(c.Req.Question[0].Name).String())
func (s State) Name() string {
return strings.ToLower(dns.Name(s.Req.Question[0].Name).String())
}
// QName returns the name of the question in the request.
func (c Context) QName() string {
return dns.Name(c.Req.Question[0].Name).String()
func (s State) QName() string {
return dns.Name(s.Req.Question[0].Name).String()
}
// Class returns the class of the question in the request.
func (c Context) Class() string {
return dns.Class(c.Req.Question[0].Qclass).String()
func (s State) Class() string {
return dns.Class(s.Req.Question[0].Qclass).String()
}
// QClass returns the class of the question in the request.
func (c Context) QClass() uint16 {
return c.Req.Question[0].Qclass
func (s State) QClass() uint16 {
return s.Req.Question[0].Qclass
}
// More convience types for extracting stuff from a message?
// Header?
// ErrorMessage returns an error message suitable for sending
// back to the client.
func (c Context) ErrorMessage(rcode int) *dns.Msg {
func (s State) ErrorMessage(rcode int) *dns.Msg {
m := new(dns.Msg)
m.SetRcode(c.Req, rcode)
m.SetRcode(s.Req, rcode)
return m
}
// AnswerMessage returns an error message suitable for sending
// back to the client.
func (c Context) AnswerMessage() *dns.Msg {
func (s State) AnswerMessage() *dns.Msg {
m := new(dns.Msg)
m.SetReply(c.Req)
m.SetReply(s.Req)
return m
}
......@@ -15,6 +15,8 @@ import (
"sync"
"time"
"golang.org/x/net/context"
"github.com/miekg/dns"
)
......@@ -285,6 +287,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
q := r.Question[0].Name
b := make([]byte, len(q))
off, end := 0, false
ctx := context.Background()
for {
l := len(q[off:])
for i := 0; i < l; i++ {
......@@ -297,7 +300,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if h, ok := s.zones[string(b[:l])]; ok {
if r.Question[0].Qtype != dns.TypeDS {
rcode, _ := h.stack.ServeDNS(w, r)
rcode, _ := h.stack.ServeDNS(ctx, w, r)
if rcode > 0 {
DefaultErrorFunc(w, r, rcode)
}
......@@ -311,7 +314,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
// Wildcard match, if we have found nothing try the root zone as a last resort.
if h, ok := s.zones["."]; ok {
rcode, _ := h.stack.ServeDNS(w, r)
rcode, _ := h.stack.ServeDNS(ctx, w, r)
if rcode > 0 {
DefaultErrorFunc(w, r, rcode)
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment