Commit 2fd31cd3 authored by Miek Gieben's avatar Miek Gieben Committed by GitHub

plugin/metadata: some cleanups (#1906)

* plugin/metadata: some cleanups

Name to provider.go as that's what being defined right now in the file.
Use request.Request because that's done in variables.go anyway. Name the
main storage M, because there is no further meaning behind.

Remove superfluous methods
Signed-off-by: default avatarMiek Gieben <miek@miek.nl>

* Fix test
Signed-off-by: default avatarMiek Gieben <miek@miek.nl>
parent e6c00f39
...@@ -24,15 +24,16 @@ func (m *Metadata) Name() string { return "metadata" } ...@@ -24,15 +24,16 @@ func (m *Metadata) Name() string { return "metadata" }
// ServeDNS implements the plugin.Handler interface. // ServeDNS implements the plugin.Handler interface.
func (m *Metadata) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { func (m *Metadata) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
md, ctx := newMD(ctx) ctx = context.WithValue(ctx, metadataKey{}, M{})
md, _ := FromContext(ctx)
state := request.Request{W: w, Req: r} state := request.Request{W: w, Req: r}
if plugin.Zones(m.Zones).Matches(state.Name()) != "" { if plugin.Zones(m.Zones).Matches(state.Name()) != "" {
// Go through all Providers and collect metadata // Go through all Providers and collect metadata.
for _, provider := range m.Providers { for _, provider := range m.Providers {
for _, varName := range provider.MetadataVarNames() { for _, varName := range provider.MetadataVarNames() {
if val, ok := provider.Metadata(ctx, w, r, varName); ok { if val, ok := provider.Metadata(ctx, state, varName); ok {
md.setValue(varName, val) md.SetValue(varName, val)
} }
} }
} }
...@@ -47,8 +48,8 @@ func (m *Metadata) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms ...@@ -47,8 +48,8 @@ func (m *Metadata) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Ms
func (m *Metadata) MetadataVarNames() []string { return variables.All } func (m *Metadata) MetadataVarNames() []string { return variables.All }
// Metadata implements the plugin.Provider interface. // Metadata implements the plugin.Provider interface.
func (m *Metadata) Metadata(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, varName string) (interface{}, bool) { func (m *Metadata) Metadata(ctx context.Context, state request.Request, varName string) (interface{}, bool) {
if val, err := variables.GetValue(varName, w, r); err == nil { if val, err := variables.GetValue(state, varName); err == nil {
return val, true return val, true
} }
return nil, false return nil, false
......
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"testing" "testing"
"github.com/coredns/coredns/plugin/test" "github.com/coredns/coredns/plugin/test"
"github.com/coredns/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
...@@ -20,12 +21,12 @@ func (m testProvider) MetadataVarNames() []string { ...@@ -20,12 +21,12 @@ func (m testProvider) MetadataVarNames() []string {
return keys return keys
} }
func (m testProvider) Metadata(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, key string) (val interface{}, ok bool) { func (m testProvider) Metadata(ctx context.Context, state request.Request, key string) (val interface{}, ok bool) {
value, ok := m[key] value, ok := m[key]
return value, ok return value, ok
} }
// testHandler implements plugin.Handler // testHandler implements plugin.Handler.
type testHandler struct{ ctx context.Context } type testHandler struct{ ctx context.Context }
func (m *testHandler) Name() string { return "testHandler" } func (m *testHandler) Name() string { return "testHandler" }
...@@ -35,7 +36,7 @@ func (m *testHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns ...@@ -35,7 +36,7 @@ func (m *testHandler) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns
return 0, nil return 0, nil
} }
func TestMetadataServDns(t *testing.T) { func TestMetadataServeDNS(t *testing.T) {
expectedMetadata := []testProvider{ expectedMetadata := []testProvider{
testProvider{"testkey1": "testvalue1"}, testProvider{"testkey1": "testvalue1"},
testProvider{"testkey2": 2, "testkey3": "testvalue3"}, testProvider{"testkey2": 2, "testkey3": "testvalue3"},
...@@ -45,9 +46,8 @@ func TestMetadataServDns(t *testing.T) { ...@@ -45,9 +46,8 @@ func TestMetadataServDns(t *testing.T) {
for _, e := range expectedMetadata { for _, e := range expectedMetadata {
providers = append(providers, e) providers = append(providers, e)
} }
// Fake handler which stores the resulting context
next := &testHandler{}
next := &testHandler{} // fake handler which stores the resulting context
metadata := Metadata{ metadata := Metadata{
Zones: []string{"."}, Zones: []string{"."},
Providers: providers, Providers: providers,
......
...@@ -3,7 +3,7 @@ package metadata ...@@ -3,7 +3,7 @@ package metadata
import ( import (
"context" "context"
"github.com/miekg/dns" "github.com/coredns/coredns/request"
) )
// Provider interface needs to be implemented by each plugin willing to provide // Provider interface needs to be implemented by each plugin willing to provide
...@@ -16,38 +16,32 @@ type Provider interface { ...@@ -16,38 +16,32 @@ type Provider interface {
// Metadata is expected to return a value with metadata information by the key // Metadata is expected to return a value with metadata information by the key
// from 4th argument. Value can be later retrieved from context by any other plugin. // from 4th argument. Value can be later retrieved from context by any other plugin.
// If value is not available by some reason returned boolean value should be false. // If value is not available by some reason returned boolean value should be false.
Metadata(context.Context, dns.ResponseWriter, *dns.Msg, string) (interface{}, bool) Metadata(ctx context.Context, state request.Request, variable string) (interface{}, bool)
} }
// MD is metadata information storage // M is metadata information storage.
type MD map[string]interface{} type M map[string]interface{}
// metadataKey defines the type of key that is used to save metadata into the context // FromContext retrieves the metadata from the context.
type metadataKey struct{} func FromContext(ctx context.Context) (M, bool) {
// newMD initializes MD and attaches it to context
func newMD(ctx context.Context) (MD, context.Context) {
m := MD{}
return m, context.WithValue(ctx, metadataKey{}, m)
}
// FromContext retrieves MD struct from context.
func FromContext(ctx context.Context) (md MD, ok bool) {
if metadata := ctx.Value(metadataKey{}); metadata != nil { if metadata := ctx.Value(metadataKey{}); metadata != nil {
if md, ok := metadata.(MD); ok { if m, ok := metadata.(M); ok {
return md, true return m, true
} }
} }
return MD{}, false return M{}, false
} }
// Value returns metadata value by key. // Value returns metadata value by key.
func (m MD) Value(key string) (value interface{}, ok bool) { func (m M) Value(key string) (value interface{}, ok bool) {
value, ok = m[key] value, ok = m[key]
return value, ok return value, ok
} }
// setValue adds metadata value. // SetValue sets the metadata value under key.
func (m MD) setValue(key string, val interface{}) { func (m M) SetValue(key string, val interface{}) {
m[key] = val m[key] = val
} }
// metadataKey defines the type of key that is used to save metadata into the context.
type metadataKey struct{}
...@@ -25,23 +25,24 @@ func TestMD(t *testing.T) { ...@@ -25,23 +25,24 @@ func TestMD(t *testing.T) {
// Using one same md and ctx for all test cases // Using one same md and ctx for all test cases
ctx := context.TODO() ctx := context.TODO()
md, ctx := newMD(ctx) ctx = context.WithValue(ctx, metadataKey{}, M{})
m, _ := FromContext(ctx)
for i, tc := range tests { for i, tc := range tests {
for k, v := range tc.addValues { for k, v := range tc.addValues {
md.setValue(k, v) m.SetValue(k, v)
} }
if !reflect.DeepEqual(tc.expectedValues, map[string]interface{}(md)) { if !reflect.DeepEqual(tc.expectedValues, map[string]interface{}(m)) {
t.Errorf("Test %d: Expected %v but got %v", i, tc.expectedValues, md) t.Errorf("Test %d: Expected %v but got %v", i, tc.expectedValues, m)
} }
// Make sure that MD is recieved from context successfullly // Make sure that md is recieved from context successfullly
mdFromContext, ok := FromContext(ctx) mFromContext, ok := FromContext(ctx)
if !ok { if !ok {
t.Errorf("Test %d: MD is not recieved from the context", i) t.Errorf("Test %d: md is not recieved from the context", i)
} }
if !reflect.DeepEqual(md, mdFromContext) { if !reflect.DeepEqual(m, mFromContext) {
t.Errorf("Test %d: MD recieved from context differs from initial. Initial: %v, from context: %v", i, md, mdFromContext) t.Errorf("Test %d: md recieved from context differs from initial. Initial: %v, from context: %v", i, m, mFromContext)
} }
} }
} }
...@@ -7,8 +7,6 @@ import ( ...@@ -7,8 +7,6 @@ import (
"strconv" "strconv"
"github.com/coredns/coredns/request" "github.com/coredns/coredns/request"
"github.com/miekg/dns"
) )
const ( const (
...@@ -26,35 +24,32 @@ var All = []string{queryName, queryType, clientIP, clientPort, protocol, serverI ...@@ -26,35 +24,32 @@ var All = []string{queryName, queryType, clientIP, clientPort, protocol, serverI
// GetValue calculates and returns the data specified by the variable name. // GetValue calculates and returns the data specified by the variable name.
// Supported varNames are listed in allProvidedVars. // Supported varNames are listed in allProvidedVars.
func GetValue(varName string, w dns.ResponseWriter, r *dns.Msg) ([]byte, error) { func GetValue(state request.Request, varName string) ([]byte, error) {
req := request.Request{W: w, Req: r}
switch varName { switch varName {
case queryName: case queryName:
//Query name is written as ascii string return []byte(state.QName()), nil
return []byte(req.QName()), nil
case queryType: case queryType:
return uint16ToWire(req.QType()), nil return uint16ToWire(state.QType()), nil
case clientIP: case clientIP:
return ipToWire(req.Family(), req.IP()) return ipToWire(state.Family(), state.IP())
case clientPort: case clientPort:
return portToWire(req.Port()) return portToWire(state.Port())
case protocol: case protocol:
// Proto is written as ascii string return []byte(state.Proto()), nil
return []byte(req.Proto()), nil
case serverIP: case serverIP:
ip, _, err := net.SplitHostPort(w.LocalAddr().String()) ip, _, err := net.SplitHostPort(state.W.LocalAddr().String())
if err != nil { if err != nil {
ip = w.RemoteAddr().String() ip = state.W.RemoteAddr().String()
} }
return ipToWire(family(w.RemoteAddr()), ip) return ipToWire(state.Family(), ip)
case serverPort: case serverPort:
_, port, err := net.SplitHostPort(w.LocalAddr().String()) _, port, err := net.SplitHostPort(state.W.LocalAddr().String())
if err != nil { if err != nil {
port = "0" port = "0"
} }
......
...@@ -5,6 +5,8 @@ import ( ...@@ -5,6 +5,8 @@ import (
"testing" "testing"
"github.com/coredns/coredns/plugin/test" "github.com/coredns/coredns/plugin/test"
"github.com/coredns/coredns/request"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
...@@ -63,8 +65,9 @@ func TestGetValue(t *testing.T) { ...@@ -63,8 +65,9 @@ func TestGetValue(t *testing.T) {
m := new(dns.Msg) m := new(dns.Msg)
m.SetQuestion("example.com.", dns.TypeA) m.SetQuestion("example.com.", dns.TypeA)
m.Question[0].Qclass = dns.ClassINET m.Question[0].Qclass = dns.ClassINET
state := request.Request{W: &test.ResponseWriter{}, Req: m}
value, err := GetValue(tc.varName, &test.ResponseWriter{}, m) value, err := GetValue(state, tc.varName)
if tc.shouldErr && err == nil { if tc.shouldErr && err == nil {
t.Errorf("Test %d: Expected error, but didn't recieve", i) t.Errorf("Test %d: Expected error, but didn't recieve", i)
......
...@@ -202,7 +202,8 @@ func (rule *edns0VariableRule) ruleData(ctx context.Context, w dns.ResponseWrite ...@@ -202,7 +202,8 @@ func (rule *edns0VariableRule) ruleData(ctx context.Context, w dns.ResponseWrite
} }
} }
} else { // No metadata available means metadata plugin is disabled. Try to get the value directly. } else { // No metadata available means metadata plugin is disabled. Try to get the value directly.
return variables.GetValue(rule.variable, w, r) state := request.Request{W: w, Req: r} // TODO(miek): every rule needs to take a request.Request.
return variables.GetValue(state, rule.variable)
} }
return nil, fmt.Errorf("unable to extract data for variable %s", rule.variable) return nil, fmt.Errorf("unable to extract data for variable %s", rule.variable)
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment