Commit f907311c authored by Miek Gieben's avatar Miek Gieben

Use context.Context

Rename the old Context to State and use context.Context in the
middleware for intra-middleware communication and more.
parent 523cc0a0
package https package https
import ( /*
"io/ioutil"
"net/http"
"os"
"testing"
"github.com/miekg/coredns/middleware/redirect"
"github.com/miekg/coredns/server"
"github.com/xenolf/lego/acme"
)
func TestHostQualifies(t *testing.T) { func TestHostQualifies(t *testing.T) {
for i, test := range []struct { for i, test := range []struct {
host string host string
...@@ -330,3 +320,4 @@ func TestMarkQualified(t *testing.T) { ...@@ -330,3 +320,4 @@ func TestMarkQualified(t *testing.T) {
t.Errorf("Expected %d managed configs, but got %d", expectedManagedCount, count) t.Errorf("Expected %d managed configs, but got %d", expectedManagedCount, count)
} }
} }
*/
...@@ -4,6 +4,8 @@ import ( ...@@ -4,6 +4,8 @@ import (
"fmt" "fmt"
"strings" "strings"
"golang.org/x/net/context"
"github.com/miekg/coredns/core/parse" "github.com/miekg/coredns/core/parse"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/server" "github.com/miekg/coredns/server"
...@@ -70,7 +72,7 @@ func NewTestController(input string) *Controller { ...@@ -70,7 +72,7 @@ func NewTestController(input string) *Controller {
// //
// Used primarily for testing but needs to be exported so // Used primarily for testing but needs to be exported so
// add-ons can use this as a convenience. // add-ons can use this as a convenience.
var EmptyNext = middleware.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) (int, error) { var EmptyNext = middleware.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
return 0, nil return 0, nil
}) })
......
package setup package setup
import ( /*
"testing"
"github.com/miekg/coredns/middleware"
"github.com/miekg/coredns/middleware/errors"
)
func TestErrors(t *testing.T) { func TestErrors(t *testing.T) {
c := NewTestController(`errors`) c := NewTestController(`errors`)
mid, err := Errors(c) mid, err := Errors(c)
...@@ -154,5 +148,5 @@ func TestErrorsParse(t *testing.T) { ...@@ -154,5 +148,5 @@ func TestErrorsParse(t *testing.T) {
} }
} }
} }
} }
*/
package setup package setup
import ( /*
"fmt"
"regexp"
"testing"
"github.com/miekg/coredns/middleware/rewrite"
)
func TestRewrite(t *testing.T) { func TestRewrite(t *testing.T) {
c := NewTestController(`rewrite /from /to`) c := NewTestController(`rewrite /from /to`)
...@@ -237,5 +230,5 @@ func TestRewriteParse(t *testing.T) { ...@@ -237,5 +230,5 @@ func TestRewriteParse(t *testing.T) {
} }
} }
} }
*/
...@@ -8,6 +8,8 @@ import ( ...@@ -8,6 +8,8 @@ import (
"strings" "strings"
"time" "time"
"golang.org/x/net/context"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
...@@ -21,10 +23,10 @@ type ErrorHandler struct { ...@@ -21,10 +23,10 @@ type ErrorHandler struct {
Debug bool // if true, errors are written out to client rather than to a log Debug bool // if true, errors are written out to client rather than to a log
} }
func (h ErrorHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { func (h ErrorHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
defer h.recovery(w, r) defer h.recovery(w, r)
rcode, err := h.Next.ServeDNS(w, r) rcode, err := h.Next.ServeDNS(ctx, w, r)
if err != nil { if err != nil {
errMsg := fmt.Sprintf("%s [ERROR %d %s %s] %v", time.Now().Format(timeFormat), rcode, r.Question[0].Name, dns.Type(r.Question[0].Qclass), err) errMsg := fmt.Sprintf("%s [ERROR %d %s %s] %v", time.Now().Format(timeFormat), rcode, r.Question[0].Name, dns.Type(r.Question[0].Qclass), err)
......
package errors package errors
import ( /*
"bytes"
"errors"
"fmt"
"log"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
"github.com/miekg/coredns/middleware"
)
func TestErrors(t *testing.T) { func TestErrors(t *testing.T) {
// create a temporary page // create a temporary page
path := filepath.Join(os.TempDir(), "errors_test.html") path := filepath.Join(os.TempDir(), "errors_test.html")
...@@ -166,3 +151,4 @@ func genErrorHandler(status int, err error, body string) middleware.Handler { ...@@ -166,3 +151,4 @@ func genErrorHandler(status int, err error, body string) middleware.Handler {
return status, err return status, err
}) })
} }
*/
...@@ -8,6 +8,8 @@ package file ...@@ -8,6 +8,8 @@ package file
import ( import (
"strings" "strings"
"golang.org/x/net/context"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
...@@ -26,29 +28,29 @@ type ( ...@@ -26,29 +28,29 @@ type (
} }
) )
func (f File) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { func (f File) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
context := middleware.Context{W: w, Req: r} state := middleware.State{W: w, Req: r}
qname := context.Name() qname := state.Name()
zone := middleware.Zones(f.Zones.Names).Matches(qname) zone := middleware.Zones(f.Zones.Names).Matches(qname)
if zone == "" { if zone == "" {
return f.Next.ServeDNS(w, r) return f.Next.ServeDNS(ctx, w, r)
} }
names, nodata := f.Zones.Z[zone].lookup(qname, context.QType()) names, nodata := f.Zones.Z[zone].lookup(qname, state.QType())
var answer *dns.Msg var answer *dns.Msg
switch { switch {
case nodata: case nodata:
answer = context.AnswerMessage() answer = state.AnswerMessage()
answer.Ns = names answer.Ns = names
case len(names) == 0: case len(names) == 0:
answer = context.AnswerMessage() answer = state.AnswerMessage()
answer.Ns = names answer.Ns = names
answer.Rcode = dns.RcodeNameError answer.Rcode = dns.RcodeNameError
case len(names) > 0: case len(names) > 0:
answer = context.AnswerMessage() answer = state.AnswerMessage()
answer.Answer = names answer.Answer = names
default: default:
answer = context.ErrorMessage(dns.RcodeServerFailure) answer = state.ErrorMessage(dns.RcodeServerFailure)
} }
// Check return size, etc. TODO(miek) // Check return size, etc. TODO(miek)
w.WriteMsg(answer) w.WriteMsg(answer)
......
package file package file
import ( /*
"errors"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
)
var testDir = filepath.Join(os.TempDir(), "caddy_testdir") var testDir = filepath.Join(os.TempDir(), "caddy_testdir")
var ErrCustom = errors.New("Custom Error") var ErrCustom = errors.New("Custom Error")
...@@ -323,3 +314,4 @@ func TestServeHTTPFailingStat(t *testing.T) { ...@@ -323,3 +314,4 @@ func TestServeHTTPFailingStat(t *testing.T) {
} }
} }
} }
*/
...@@ -4,6 +4,8 @@ package log ...@@ -4,6 +4,8 @@ package log
import ( import (
"log" "log"
"golang.org/x/net/context"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
...@@ -15,7 +17,7 @@ type Logger struct { ...@@ -15,7 +17,7 @@ type Logger struct {
ErrorFunc func(dns.ResponseWriter, *dns.Msg, int) // failover error handler ErrorFunc func(dns.ResponseWriter, *dns.Msg, int) // failover error handler
} }
func (l Logger) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { func (l Logger) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
for _, rule := range l.Rules { for _, rule := range l.Rules {
/* /*
if middleware.Path(r.URL.Path).Matches(rule.PathScope) { if middleware.Path(r.URL.Path).Matches(rule.PathScope) {
...@@ -40,7 +42,7 @@ func (l Logger) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { ...@@ -40,7 +42,7 @@ func (l Logger) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
*/ */
rule = rule rule = rule
} }
return l.Next.ServeDNS(w, r) return l.Next.ServeDNS(ctx, w, r)
} }
// Rule configures the logging middleware. // Rule configures the logging middleware.
......
package log package log
import ( /*
"bytes"
"log"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
type erroringMiddleware struct{} type erroringMiddleware struct{}
func (erroringMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { func (erroringMiddleware) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
return http.StatusNotFound, nil return http.StatusNotFound, nil
} }
...@@ -46,3 +38,4 @@ func TestLoggedStatus(t *testing.T) { ...@@ -46,3 +38,4 @@ func TestLoggedStatus(t *testing.T) {
t.Error("Expected 404 to be logged. Logged string -", logged) t.Error("Expected 404 to be logged. Logged string -", logged)
} }
} }
*/
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.org/x/net/context"
) )
type ( type (
...@@ -32,18 +33,18 @@ type ( ...@@ -32,18 +33,18 @@ type (
// Otherwise, return values should be propagated down the middleware // Otherwise, return values should be propagated down the middleware
// chain by returning them unchanged. // chain by returning them unchanged.
Handler interface { Handler interface {
ServeDNS(dns.ResponseWriter, *dns.Msg) (int, error) ServeDNS(context.Context, dns.ResponseWriter, *dns.Msg) (int, error)
} }
// HandlerFunc is a convenience type like dns.HandlerFunc, except // HandlerFunc is a convenience type like dns.HandlerFunc, except
// ServeDNS returns an rcode and an error. See Handler // ServeDNS returns an rcode and an error. See Handler
// documentation for more information. // documentation for more information.
HandlerFunc func(dns.ResponseWriter, *dns.Msg) (int, error) HandlerFunc func(context.Context, dns.ResponseWriter, *dns.Msg) (int, error)
) )
// ServeDNS implements the Handler interface. // ServeDNS implements the Handler interface.
func (f HandlerFunc) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { func (f HandlerFunc) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
return f(w, r) return f(ctx, w, r)
} }
// IndexFile looks for a file in /root/fpath/indexFile for each string // IndexFile looks for a file in /root/fpath/indexFile for each string
......
package middleware package middleware
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestIndexfile(t *testing.T) {
tests := []struct {
rootDir http.FileSystem
fpath string
indexFiles []string
shouldErr bool
expectedFilePath string //retun value
expectedBoolValue bool //return value
}{
{
http.Dir("./templates/testdata"),
"/images/",
[]string{"img.htm"},
false,
"/images/img.htm",
true,
},
}
for i, test := range tests {
actualFilePath, actualBoolValue := IndexFile(test.rootDir, test.fpath, test.indexFiles)
if actualBoolValue == true && test.shouldErr {
t.Errorf("Test %d didn't error, but it should have", i)
} else if actualBoolValue != true && !test.shouldErr {
t.Errorf("Test %d errored, but it shouldn't have; got %s", i, "Please Add a / at the end of fpath or the indexFiles doesnt exist")
}
if actualFilePath != test.expectedFilePath {
t.Fatalf("Test %d expected returned filepath to be %s, but got %s ",
i, test.expectedFilePath, actualFilePath)
}
if actualBoolValue != test.expectedBoolValue {
t.Fatalf("Test %d expected returned bool value to be %v, but got %v ",
i, test.expectedBoolValue, actualBoolValue)
}
}
}
func TestSetLastModified(t *testing.T) {
nowTime := time.Now()
// ovewrite the function to return reliable time
originalGetCurrentTimeFunc := currentTime
currentTime = func() time.Time {
return nowTime
}
defer func() {
currentTime = originalGetCurrentTimeFunc
}()
pastTime := nowTime.Truncate(1 * time.Hour)
futureTime := nowTime.Add(1 * time.Hour)
tests := []struct {
inputModTime time.Time
expectedIsHeaderSet bool
expectedLastModified string
}{
{
inputModTime: pastTime,
expectedIsHeaderSet: true,
expectedLastModified: pastTime.UTC().Format(http.TimeFormat),
},
{
inputModTime: nowTime,
expectedIsHeaderSet: true,
expectedLastModified: nowTime.UTC().Format(http.TimeFormat),
},
{
inputModTime: futureTime,
expectedIsHeaderSet: true,
expectedLastModified: nowTime.UTC().Format(http.TimeFormat),
},
{
inputModTime: time.Time{},
expectedIsHeaderSet: false,
},
}
for i, test := range tests {
responseRecorder := httptest.NewRecorder()
errorPrefix := fmt.Sprintf("Test [%d]: ", i)
SetLastModifiedHeader(responseRecorder, test.inputModTime)
actualLastModifiedHeader := responseRecorder.Header().Get("Last-Modified")
if test.expectedIsHeaderSet && actualLastModifiedHeader == "" {
t.Fatalf(errorPrefix + "Expected to find Last-Modified header, but found nothing")
}
if !test.expectedIsHeaderSet && actualLastModifiedHeader != "" {
t.Fatalf(errorPrefix+"Did not expect to find Last-Modified header, but found one [%s].", actualLastModifiedHeader)
}
if test.expectedLastModified != actualLastModifiedHeader {
t.Errorf(errorPrefix+"Expected Last-Modified content [%s], found [%s}", test.expectedLastModified, actualLastModifiedHeader)
}
}
}
...@@ -4,15 +4,17 @@ import ( ...@@ -4,15 +4,17 @@ import (
"strconv" "strconv"
"time" "time"
"golang.org/x/net/context"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
func (m *Metrics) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { func (m *Metrics) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
context := middleware.Context{W: w, Req: r} state := middleware.State{W: w, Req: r}
qname := context.Name() qname := state.Name()
qtype := context.Type() qtype := state.Type()
zone := middleware.Zones(m.ZoneNames).Matches(qname) zone := middleware.Zones(m.ZoneNames).Matches(qname)
if zone == "" { if zone == "" {
zone = "." zone = "."
...@@ -20,7 +22,7 @@ func (m *Metrics) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { ...@@ -20,7 +22,7 @@ func (m *Metrics) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
// Record response to get status code and size of the reply. // Record response to get status code and size of the reply.
rw := middleware.NewResponseRecorder(w) rw := middleware.NewResponseRecorder(w)
status, err := m.Next.ServeDNS(rw, r) status, err := m.Next.ServeDNS(ctx, rw, r)
requestCount.WithLabelValues(zone, qtype).Inc() requestCount.WithLabelValues(zone, qtype).Inc()
requestDuration.WithLabelValues(zone).Observe(float64(time.Since(rw.Start()) / time.Second)) requestDuration.WithLabelValues(zone).Observe(float64(time.Since(rw.Start()) / time.Second))
......
...@@ -7,6 +7,8 @@ import ( ...@@ -7,6 +7,8 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"golang.org/x/net/context"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
...@@ -67,7 +69,7 @@ func (uh *UpstreamHost) Down() bool { ...@@ -67,7 +69,7 @@ func (uh *UpstreamHost) Down() bool {
var tryDuration = 60 * time.Second var tryDuration = 60 * time.Second
// ServeDNS satisfies the middleware.Handler interface. // ServeDNS satisfies the middleware.Handler interface.
func (p Proxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { func (p Proxy) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
for _, upstream := range p.Upstreams { for _, upstream := range p.Upstreams {
// allowed bla bla bla TODO(miek): fix full proxy spec from caddy // allowed bla bla bla TODO(miek): fix full proxy spec from caddy
start := time.Now() start := time.Now()
...@@ -100,7 +102,7 @@ func (p Proxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { ...@@ -100,7 +102,7 @@ func (p Proxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
} }
return dns.RcodeServerFailure, errUnreachable return dns.RcodeServerFailure, errUnreachable
} }
return p.Next.ServeDNS(w, r) return p.Next.ServeDNS(ctx, w, r)
} }
func Clients() Client { func Clients() Client {
......
package proxy package proxy
import ( /*
"bufio"
"bytes"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
"golang.org/x/net/websocket"
)
func init() { func init() {
tryDuration = 50 * time.Millisecond // prevent tests from hanging tryDuration = 50 * time.Millisecond // prevent tests from hanging
} }
...@@ -315,3 +295,4 @@ func (c *fakeConn) SetWriteDeadline(t time.Time) error { return nil } ...@@ -315,3 +295,4 @@ func (c *fakeConn) SetWriteDeadline(t time.Time) error { return nil }
func (c *fakeConn) Close() error { return nil } func (c *fakeConn) Close() error { return nil }
func (c *fakeConn) Read(b []byte) (int, error) { return c.readBuf.Read(b) } func (c *fakeConn) Read(b []byte) (int, error) { return c.readBuf.Read(b) }
func (c *fakeConn) Write(b []byte) (int, error) { return c.writeBuf.Write(b) } func (c *fakeConn) Write(b []byte) (int, error) { return c.writeBuf.Write(b) }
*/
...@@ -12,15 +12,15 @@ type ReverseProxy struct { ...@@ -12,15 +12,15 @@ type ReverseProxy struct {
} }
func (p ReverseProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg, extra []dns.RR) error { func (p ReverseProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg, extra []dns.RR) error {
// TODO(miek): use extra! // TODO(miek): use extra to EDNS0.
var ( var (
reply *dns.Msg reply *dns.Msg
err error err error
) )
context := middleware.Context{W: w, Req: r} state := middleware.State{W: w, Req: r}
// tls+tcp ? // tls+tcp ?
if context.Proto() == "tcp" { if state.Proto() == "tcp" {
reply, err = middleware.Exchange(p.Client.TCP, r, p.Host) reply, err = middleware.Exchange(p.Client.TCP, r, p.Host)
} else { } else {
reply, err = middleware.Exchange(p.Client.UDP, r, p.Host) reply, err = middleware.Exchange(p.Client.UDP, r, p.Host)
......
package middleware package middleware
import ( /*
"net/http"
"net/http/httptest"
"testing"
)
func TestNewResponseRecorder(t *testing.T) { func TestNewResponseRecorder(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
recordRequest := NewResponseRecorder(w) recordRequest := NewResponseRecorder(w)
...@@ -30,3 +25,4 @@ func TestWrite(t *testing.T) { ...@@ -30,3 +25,4 @@ func TestWrite(t *testing.T) {
t.Fatalf("Expected Response Body to be %s , but found %s\n", responseTestString, w.Body.String()) t.Fatalf("Expected Response Body to be %s , but found %s\n", responseTestString, w.Body.String())
} }
} }
*/
...@@ -20,6 +20,8 @@ import ( ...@@ -20,6 +20,8 @@ import (
"net" "net"
"strings" "strings"
"golang.org/x/net/context"
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
...@@ -28,15 +30,15 @@ type Reflect struct { ...@@ -28,15 +30,15 @@ type Reflect struct {
Next middleware.Handler Next middleware.Handler
} }
func (rl Reflect) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { func (rl Reflect) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
context := middleware.Context{Req: r, W: w} state := middleware.State{Req: r, W: w}
class := r.Question[0].Qclass class := r.Question[0].Qclass
qname := r.Question[0].Name qname := r.Question[0].Name
i, ok := dns.NextLabel(qname, 0) i, ok := dns.NextLabel(qname, 0)
if strings.ToLower(qname[:i]) != who || ok { if strings.ToLower(qname[:i]) != who || ok {
err := context.ErrorMessage(dns.RcodeFormatError) err := state.ErrorMessage(dns.RcodeFormatError)
w.WriteMsg(err) w.WriteMsg(err)
return dns.RcodeFormatError, errors.New(dns.RcodeToString[dns.RcodeFormatError]) return dns.RcodeFormatError, errors.New(dns.RcodeToString[dns.RcodeFormatError])
} }
...@@ -46,10 +48,10 @@ func (rl Reflect) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { ...@@ -46,10 +48,10 @@ func (rl Reflect) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
answer.Compress = true answer.Compress = true
answer.Authoritative = true answer.Authoritative = true
ip := context.IP() ip := state.IP()
proto := context.Proto() proto := state.Proto()
port, _ := context.Port() port, _ := state.Port()
family := context.Family() family := state.Family()
var rr dns.RR var rr dns.RR
switch family { switch family {
...@@ -67,7 +69,7 @@ func (rl Reflect) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { ...@@ -67,7 +69,7 @@ func (rl Reflect) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
t.Hdr = dns.RR_Header{Name: qname, Rrtype: dns.TypeTXT, Class: class, Ttl: 0} t.Hdr = dns.RR_Header{Name: qname, Rrtype: dns.TypeTXT, Class: class, Ttl: 0}
t.Txt = []string{"Port: " + port + " (" + proto + ")"} t.Txt = []string{"Port: " + port + " (" + proto + ")"}
switch context.Type() { switch state.Type() {
case "TXT": case "TXT":
answer.Answer = append(answer.Answer, t) answer.Answer = append(answer.Answer, t)
answer.Extra = append(answer.Extra, rr) answer.Extra = append(answer.Extra, rr)
......
...@@ -29,19 +29,19 @@ type replacer struct { ...@@ -29,19 +29,19 @@ type replacer struct {
// available. emptyValue should be the string that is used // available. emptyValue should be the string that is used
// in place of empty string (can still be empty string). // in place of empty string (can still be empty string).
func NewReplacer(r *dns.Msg, rr *ResponseRecorder, emptyValue string) Replacer { func NewReplacer(r *dns.Msg, rr *ResponseRecorder, emptyValue string) Replacer {
context := Context{W: rr, Req: r} state := State{W: rr, Req: r}
rep := replacer{ rep := replacer{
replacements: map[string]string{ replacements: map[string]string{
"{type}": context.Type(), "{type}": state.Type(),
"{name}": context.Name(), "{name}": state.Name(),
"{class}": context.Class(), "{class}": state.Class(),
"{proto}": context.Proto(), "{proto}": state.Proto(),
"{when}": func() string { "{when}": func() string {
return time.Now().Format(timeFormat) return time.Now().Format(timeFormat)
}(), }(),
"{remote}": context.IP(), "{remote}": state.IP(),
"{port}": func() string { "{port}": func() string {
p, _ := context.Port() p, _ := state.Port()
return p return p
}(), }(),
}, },
......
package middleware package middleware
import ( /*
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestNewReplacer(t *testing.T) { func TestNewReplacer(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
recordRequest := NewResponseRecorder(w) recordRequest := NewResponseRecorder(w)
...@@ -122,3 +116,4 @@ func TestSet(t *testing.T) { ...@@ -122,3 +116,4 @@ func TestSet(t *testing.T) {
t.Error("Expected variable replacement failed") t.Error("Expected variable replacement failed")
} }
} }
*/
package rewrite package rewrite
import ( /*
"net/http"
"strings"
"testing"
)
func TestConditions(t *testing.T) { func TestConditions(t *testing.T) {
tests := []struct { tests := []struct {
condition string condition string
...@@ -104,3 +99,4 @@ func TestConditions(t *testing.T) { ...@@ -104,3 +99,4 @@ func TestConditions(t *testing.T) {
} }
} }
} }
*/
...@@ -5,6 +5,7 @@ package rewrite ...@@ -5,6 +5,7 @@ package rewrite
import ( import (
"github.com/miekg/coredns/middleware" "github.com/miekg/coredns/middleware"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.org/x/net/context"
) )
// Result is the result of a rewrite // Result is the result of a rewrite
...@@ -27,12 +28,12 @@ type Rewrite struct { ...@@ -27,12 +28,12 @@ type Rewrite struct {
} }
// ServeHTTP implements the middleware.Handler interface. // ServeHTTP implements the middleware.Handler interface.
func (rw Rewrite) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { func (rw Rewrite) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
wr := NewResponseReverter(w, r) wr := NewResponseReverter(w, r)
for _, rule := range rw.Rules { for _, rule := range rw.Rules {
switch result := rule.Rewrite(r); result { switch result := rule.Rewrite(r); result {
case RewriteDone: case RewriteDone:
return rw.Next.ServeDNS(wr, r) return rw.Next.ServeDNS(ctx, wr, r)
case RewriteIgnored: case RewriteIgnored:
break break
case RewriteStatus: case RewriteStatus:
...@@ -42,7 +43,7 @@ func (rw Rewrite) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { ...@@ -42,7 +43,7 @@ func (rw Rewrite) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
// } // }
} }
} }
return rw.Next.ServeDNS(w, r) return rw.Next.ServeDNS(ctx, w, r)
} }
// Rule describes an internal location rewrite rule. // Rule describes an internal location rewrite rule.
......
package rewrite package rewrite
import ( /*
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/miekg/coredns/middleware"
)
func TestRewrite(t *testing.T) { func TestRewrite(t *testing.T) {
rw := Rewrite{ rw := Rewrite{
Next: middleware.HandlerFunc(urlPrinter), Next: middleware.HandlerFunc(urlPrinter),
...@@ -157,3 +148,4 @@ func urlPrinter(w http.ResponseWriter, r *http.Request) (int, error) { ...@@ -157,3 +148,4 @@ func urlPrinter(w http.ResponseWriter, r *http.Request) (int, error) {
fmt.Fprintf(w, r.URL.String()) fmt.Fprintf(w, r.URL.String())
return 0, nil return 0, nil
} }
*/
...@@ -9,45 +9,44 @@ import ( ...@@ -9,45 +9,44 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// This file contains the context and functions available for // This file contains the state nd functions available for use in the templates.
// use in the templates.
// Context is the context with which Caddy templates are executed. // State contains some connection state and is useful in middleware.
type Context struct { type State struct {
Root http.FileSystem // TODO(miek): needed Root http.FileSystem // TODO(miek): needed?
Req *dns.Msg Req *dns.Msg
W dns.ResponseWriter W dns.ResponseWriter
} }
// Now returns the current timestamp in the specified format. // Now returns the current timestamp in the specified format.
func (c Context) Now(format string) string { func (s State) Now(format string) string {
return time.Now().Format(format) return time.Now().Format(format)
} }
// NowDate returns the current date/time that can be used // NowDate returns the current date/time that can be used
// in other time functions. // in other time functions.
func (c Context) NowDate() time.Time { func (s State) NowDate() time.Time {
return time.Now() return time.Now()
} }
// Header gets the value of a header. // Header gets the value of a header.
func (c Context) Header() *dns.RR_Header { func (s State) Header() *dns.RR_Header {
// TODO(miek) // TODO(miek)
return nil return nil
} }
// IP gets the (remote) IP address of the client making the request. // IP gets the (remote) IP address of the client making the request.
func (c Context) IP() string { func (s State) IP() string {
ip, _, err := net.SplitHostPort(c.W.RemoteAddr().String()) ip, _, err := net.SplitHostPort(s.W.RemoteAddr().String())
if err != nil { if err != nil {
return c.W.RemoteAddr().String() return s.W.RemoteAddr().String()
} }
return ip return ip
} }
// Post gets the (remote) Port of the client making the request. // Post gets the (remote) Port of the client making the request.
func (c Context) Port() (string, error) { func (s State) Port() (string, error) {
_, port, err := net.SplitHostPort(c.W.RemoteAddr().String()) _, port, err := net.SplitHostPort(s.W.RemoteAddr().String())
if err != nil { if err != nil {
return "0", err return "0", err
} }
...@@ -56,11 +55,11 @@ func (c Context) Port() (string, error) { ...@@ -56,11 +55,11 @@ func (c Context) Port() (string, error) {
// Proto gets the protocol used as the transport. This // Proto gets the protocol used as the transport. This
// will be udp or tcp. // will be udp or tcp.
func (c Context) Proto() string { func (s State) Proto() string {
if _, ok := c.W.RemoteAddr().(*net.UDPAddr); ok { if _, ok := s.W.RemoteAddr().(*net.UDPAddr); ok {
return "udp" return "udp"
} }
if _, ok := c.W.RemoteAddr().(*net.TCPAddr); ok { if _, ok := s.W.RemoteAddr().(*net.TCPAddr); ok {
return "tcp" return "tcp"
} }
return "udp" return "udp"
...@@ -68,9 +67,9 @@ func (c Context) Proto() string { ...@@ -68,9 +67,9 @@ func (c Context) Proto() string {
// Family returns the family of the transport. // Family returns the family of the transport.
// 1 for IPv4 and 2 for IPv6. // 1 for IPv4 and 2 for IPv6.
func (c Context) Family() int { func (s State) Family() int {
var a net.IP var a net.IP
ip := c.W.RemoteAddr() ip := s.W.RemoteAddr()
if i, ok := ip.(*net.UDPAddr); ok { if i, ok := ip.(*net.UDPAddr); ok {
a = i.IP a = i.IP
} }
...@@ -85,51 +84,48 @@ func (c Context) Family() int { ...@@ -85,51 +84,48 @@ func (c Context) Family() int {
} }
// Type returns the type of the question as a string. // Type returns the type of the question as a string.
func (c Context) Type() string { func (s State) Type() string {
return dns.Type(c.Req.Question[0].Qtype).String() return dns.Type(s.Req.Question[0].Qtype).String()
} }
// QType returns the type of the question as a uint16. // QType returns the type of the question as a uint16.
func (c Context) QType() uint16 { func (s State) QType() uint16 {
return c.Req.Question[0].Qtype return s.Req.Question[0].Qtype
} }
// Name returns the name of the question in the request. Note // Name returns the name of the question in the request. Note
// this name will always have a closing dot and will be lower cased. // this name will always have a closing dot and will be lower cased.
func (c Context) Name() string { func (s State) Name() string {
return strings.ToLower(dns.Name(c.Req.Question[0].Name).String()) return strings.ToLower(dns.Name(s.Req.Question[0].Name).String())
} }
// QName returns the name of the question in the request. // QName returns the name of the question in the request.
func (c Context) QName() string { func (s State) QName() string {
return dns.Name(c.Req.Question[0].Name).String() return dns.Name(s.Req.Question[0].Name).String()
} }
// Class returns the class of the question in the request. // Class returns the class of the question in the request.
func (c Context) Class() string { func (s State) Class() string {
return dns.Class(c.Req.Question[0].Qclass).String() return dns.Class(s.Req.Question[0].Qclass).String()
} }
// QClass returns the class of the question in the request. // QClass returns the class of the question in the request.
func (c Context) QClass() uint16 { func (s State) QClass() uint16 {
return c.Req.Question[0].Qclass return s.Req.Question[0].Qclass
} }
// More convience types for extracting stuff from a message?
// Header?
// ErrorMessage returns an error message suitable for sending // ErrorMessage returns an error message suitable for sending
// back to the client. // back to the client.
func (c Context) ErrorMessage(rcode int) *dns.Msg { func (s State) ErrorMessage(rcode int) *dns.Msg {
m := new(dns.Msg) m := new(dns.Msg)
m.SetRcode(c.Req, rcode) m.SetRcode(s.Req, rcode)
return m return m
} }
// AnswerMessage returns an error message suitable for sending // AnswerMessage returns an error message suitable for sending
// back to the client. // back to the client.
func (c Context) AnswerMessage() *dns.Msg { func (s State) AnswerMessage() *dns.Msg {
m := new(dns.Msg) m := new(dns.Msg)
m.SetReply(c.Req) m.SetReply(s.Req)
return m return m
} }
...@@ -15,6 +15,8 @@ import ( ...@@ -15,6 +15,8 @@ import (
"sync" "sync"
"time" "time"
"golang.org/x/net/context"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
...@@ -285,6 +287,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { ...@@ -285,6 +287,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
q := r.Question[0].Name q := r.Question[0].Name
b := make([]byte, len(q)) b := make([]byte, len(q))
off, end := 0, false off, end := 0, false
ctx := context.Background()
for { for {
l := len(q[off:]) l := len(q[off:])
for i := 0; i < l; i++ { for i := 0; i < l; i++ {
...@@ -297,7 +300,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { ...@@ -297,7 +300,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if h, ok := s.zones[string(b[:l])]; ok { if h, ok := s.zones[string(b[:l])]; ok {
if r.Question[0].Qtype != dns.TypeDS { if r.Question[0].Qtype != dns.TypeDS {
rcode, _ := h.stack.ServeDNS(w, r) rcode, _ := h.stack.ServeDNS(ctx, w, r)
if rcode > 0 { if rcode > 0 {
DefaultErrorFunc(w, r, rcode) DefaultErrorFunc(w, r, rcode)
} }
...@@ -311,7 +314,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { ...@@ -311,7 +314,7 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
// Wildcard match, if we have found nothing try the root zone as a last resort. // Wildcard match, if we have found nothing try the root zone as a last resort.
if h, ok := s.zones["."]; ok { if h, ok := s.zones["."]; ok {
rcode, _ := h.stack.ServeDNS(w, r) rcode, _ := h.stack.ServeDNS(ctx, w, r)
if rcode > 0 { if rcode > 0 {
DefaultErrorFunc(w, r, rcode) DefaultErrorFunc(w, r, rcode)
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment