Commit 9ac3cab1 authored by Miek Gieben's avatar Miek Gieben Committed by GitHub

Make CoreDNS a server type plugin for Caddy (#220)

* Make CoreDNS a server type plugin for Caddy

Remove code we don't need and port all middleware over. Fix all tests
and rework the documentation.

Also make `go generate` build a caddy binary which we then copy into
our directory. This means `go build`-builds remain working as-is.

And new etc instances in each etcd test for better isolation.
Fix more tests and rework test.Server with the newer support Caddy offers.

Fix Makefile to support new mode of operation.
parent a1989c35
......@@ -12,7 +12,7 @@ go:
# In the Travis VM-based build environment, IPv6 networking is not
# enabled by default. The sysctl operations below enable IPv6.
# IPv6 is needed by some of the CoreDNS test cases. The VM environment
# is needed to have access to sudo in the test environment. Sudo is
# is needed to have access to sudo in the test environment. Sudo is
# needed to have docker in the test environment. Docker is needed to
# launch a kubernetes instance in the test environment.
# (Dependencies are fun! :) )
......@@ -38,9 +38,9 @@ before_script:
- if which docker &>/dev/null ; then docker pull gcr.io/google_containers/hyperkube-amd64:v1.2.4 ; docker ps -a ; fi
- if which docker &>/dev/null ; then ./contrib/kubernetes/testscripts/start_k8s_with_services.sh ; docker ps -a ; fi
# Get golang dependencies, and build coredns binary
- go get -v -d
- go get -v -d ./...
- go get github.com/coreos/go-etcd/etcd
- go build -v -ldflags="-s -w"
#- go build -v -ldflags="-s -w"
script:
- go test -tags etcd -race -bench=. ./...
......
#BUILD_VERBOSE :=
BUILD_VERBOSE := -v
TEST_VERBOSE :=
#TEST_VERBOSE :=
TEST_VERBOSE := -v
DOCKER_IMAGE_NAME := $$USER/coredns
all: coredns
all:
# Phony this to ensure we always build the binary.
# TODO: Add .go file dependencies.
.PHONY: coredns
coredns: generate deps
go build $(BUILD_VERBOSE) -ldflags="-s -w"
.PHONY: docker
docker: all
GOOS=linux go build -a -tags netgo -installsuffix netgo -ldflags="-s -w"
docker: deps
CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w"
docker build -t $(DOCKER_IMAGE_NAME) .
../../mholt/caddy:
# Get caddy so we can generate into that codebase
# before getting all other dependencies.
go get ${BUILD_VERBOSE} github.com/mholt/caddy
.PHONY: generate
generate: ../../mholt/caddy
go generate $(BUILD_VERSOSE)
.PHONY: deps
deps:
deps: generate
go get ${BUILD_VERBOSE}
.PHONY: test
test:
test: deps
go test $(TEST_VERBOSE) ./...
.PHONY: testk8s
testk8s:
testk8s: deps
# With -args --v=100 the k8s API response data will be printed in the log:
#go test $(TEST_VERBOSE) -tags=k8s -run 'TestK8sIntegration' ./test -args --v=100
# Without the k8s API response data:
go test $(TEST_VERBOSE) -tags=k8s -run 'TestK8sIntegration' ./test
.PHONY: testk8s-setup
testk8s-setup:
go test -v ./core/setup -run TestKubernetes
testk8s-setup: deps
go test -v ./middleware/kubernetes/... -run TestKubernetes
.PHONY: clean
clean:
go clean
rm -f coredns
.PHONY: distclean
distclean: clean
# Clean all dependencies and build artifacts
find $(GOPATH)/pkg -maxdepth 1 -mindepth 1 | xargs rm -rf
find $(GOPATH)/bin -maxdepth 1 -mindepth 1 | xargs rm -rf
find $(GOPATH)/src -maxdepth 1 -mindepth 1 | grep -v github | xargs rm -rf
find $(GOPATH)/src -maxdepth 2 -mindepth 2 | grep -v miekg | xargs rm -rf
find $(GOPATH)/src/github.com/miekg -maxdepth 1 -mindepth 1 \! -name \*coredns\* | xargs rm -rf
# CoreDNS
CoreDNS is DNS server that started as a fork of [Caddy](https://github.com/mholt/caddy/). It has the
same model: it chains middleware.
same model: it chains middleware. In fact to similar that CoreDNS is now a server type plugin for
CAddy, i.e. you'll need Caddy to compile CoreDNS.
CoreDNS is the successor of [SkyDNS](https://github.com/skynetservices/skydns). SkyDNS is a thin
layer that exposes services in etcd in the DNS. CoreDNS builds on this idea and is a generic DNS
......@@ -38,7 +39,7 @@ There are still few [issues](https://github.com/miekg/coredns/issues), and work
things fast and reduce the memory usage.
All in all, CoreDNS should be able to provide you with enough functionality to replace parts of
BIND9, Knot, NSD or PowerDNS.
BIND9, Knot, NSD or PowerDNS and SkyDNS.
Most documentation is in the source and some blog articles can be [found
here](https://miek.nl/tags/coredns/). If you do want to use CoreDNS in production, please let us
know and how we can help.
......@@ -46,6 +47,25 @@ know and how we can help.
<https://caddyserver.com/> is also full of examples on how to structure a Corefile (renamed from
Caddyfile when I forked it).
## Compilation
CoreDNS (as a servertype plugin for Caddy) has a hard dependency on Caddy - this is *almost* like
the normal Go dependencies, but with a small twist, caddy (the source) need to know that CoreDNS
exists and for this we need to add 1 line `_ "github.com/miekg/coredns/core"` to file in caddy.
You have the source of CoreDNS, this should preferably be downloaded under your `$GOPATH`. Get all
dependencies:
go get ./...
Then, execute `go generate`, this will patch Caddy to add CoreDNS, and then `go build` as you would
normally do:
go generate
go build
Should yield a `coredns` binary.
## Examples
Start a simple proxy:
......@@ -95,15 +115,17 @@ All the above examples are possible with the *current* CoreDNS.
## What remains to be done
* Website?
* Logo?
* Optimizations.
* Load testing.
* The [issues](https://github.com/miekg/coredns/issues).
## Blog
## Blog and Contact
Website: <https://coredns.io>
Twitter: `@coredns.io`
Docs: <https://miek.nl/tags/coredns/>
Github: <https://github.com/miekg/coredns>
<https://miek.nl/tags/coredns/>
## Systemd service file
......@@ -123,7 +145,7 @@ LimitNOFILE=8192
User=coredns
WorkingDirectory=/home/coredns
ExecStartPre=/sbin/setcap cap_net_bind_service=+ep /opt/bin/coredns
ExecStart=/opt/bin/coredns -pidfile /home/coredns/coredns.pid -conf=/etc/coredns
ExecStart=/opt/bin/coredns -pidfile /home/coredns/coredns.pid -conf=/etc/coredns/Corefile
ExecReload=/bin/kill -SIGUSR1 $MAINPID
Restart=on-failure
......
package assets
import (
"os"
"path/filepath"
"runtime"
)
// Path returns the path to the folder
// where the application may store data. This
// currently resolves to ~/.coredns
func Path() string {
return filepath.Join(userHomeDir(), ".coredns")
}
// userHomeDir returns the user's home directory according to
// environment variables.
//
// Credit: http://stackoverflow.com/a/7922977/1048862
func userHomeDir() string {
if runtime.GOOS == "windows" {
home := os.Getenv("HOMEDRIVE") + os.Getenv("HOMEPATH")
if home == "" {
home = os.Getenv("USERPROFILE")
}
return home
}
return os.Getenv("HOME")
}
package assets
import (
"strings"
"testing"
)
func TestPath(t *testing.T) {
if actual := Path(); !strings.HasSuffix(actual, ".coredns") {
t.Errorf("Expected path to be a .coredns folder, got: %v", actual)
}
}
package core
/*
func TestCaddyStartStop(t *testing.T) {
corefile := "localhost:1984"
for i := 0; i < 2; i++ {
err := Start(CorefileInput{Contents: []byte(corefile)})
if err != nil {
t.Fatalf("Error starting, iteration %d: %v", i, err)
}
client := http.Client{
Timeout: time.Duration(2 * time.Second),
}
resp, err := client.Get("http://localhost:1984")
if err != nil {
t.Fatalf("Expected GET request to succeed (iteration %d), but it failed: %v", i, err)
}
resp.Body.Close()
err = Stop()
if err != nil {
t.Fatalf("Error stopping, iteration %d: %v", i, err)
}
}
}
*/
This diff is collapsed.
package core
import (
"reflect"
"sync"
"testing"
"github.com/miekg/coredns/server"
)
func TestDefaultInput(t *testing.T) {
if actual, expected := string(DefaultInput().Body()), ":53\nroot ."; actual != expected {
t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual)
}
// next few tests simulate user providing -host and/or -port flags
Host = "not-localhost.com"
if actual, expected := string(DefaultInput().Body()), "not-localhost.com:53\nroot ."; actual != expected {
t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual)
}
Host = "[::1]"
if actual, expected := string(DefaultInput().Body()), "[::1]:53\nroot ."; actual != expected {
t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual)
}
Host = "127.0.1.1"
if actual, expected := string(DefaultInput().Body()), "127.0.1.1:53\nroot ."; actual != expected {
t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual)
}
Host = "not-localhost.com"
Port = "1234"
if actual, expected := string(DefaultInput().Body()), "not-localhost.com:1234\nroot ."; actual != expected {
t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual)
}
Host = DefaultHost
Port = "1234"
if actual, expected := string(DefaultInput().Body()), ":1234\nroot ."; actual != expected {
t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual)
}
}
func TestResolveAddr(t *testing.T) {
// NOTE: If tests fail due to comparing to string "127.0.0.1",
// it's possible that system env resolves with IPv6, or ::1.
// If that happens, maybe we should use actualAddr.IP.IsLoopback()
// for the assertion, rather than a direct string comparison.
// NOTE: Tests with {Host: "", Port: ""} and {Host: "localhost", Port: ""}
// will not behave the same cross-platform, so they have been omitted.
for i, test := range []struct {
config server.Config
shouldWarnErr bool
shouldFatalErr bool
expectedIP string
expectedPort int
}{
{server.Config{Host: "127.0.0.1", Port: "1234"}, false, false, "<nil>", 1234},
{server.Config{Host: "localhost", Port: "80"}, false, false, "<nil>", 80},
{server.Config{BindHost: "localhost", Port: "1234"}, false, false, "127.0.0.1", 1234},
{server.Config{BindHost: "127.0.0.1", Port: "1234"}, false, false, "127.0.0.1", 1234},
{server.Config{BindHost: "should-not-resolve", Port: "1234"}, true, false, "<nil>", 1234},
{server.Config{BindHost: "localhost", Port: "http"}, false, false, "127.0.0.1", 80},
{server.Config{BindHost: "localhost", Port: "https"}, false, false, "127.0.0.1", 443},
{server.Config{BindHost: "", Port: "1234"}, false, false, "<nil>", 1234},
{server.Config{BindHost: "localhost", Port: "abcd"}, false, true, "", 0},
{server.Config{BindHost: "127.0.0.1", Host: "should-not-be-used", Port: "1234"}, false, false, "127.0.0.1", 1234},
{server.Config{BindHost: "localhost", Host: "should-not-be-used", Port: "1234"}, false, false, "127.0.0.1", 1234},
{server.Config{BindHost: "should-not-resolve", Host: "localhost", Port: "1234"}, true, false, "<nil>", 1234},
} {
actualAddr, warnErr, fatalErr := resolveAddr(test.config)
if test.shouldFatalErr && fatalErr == nil {
t.Errorf("Test %d: Expected error, but there wasn't any", i)
}
if !test.shouldFatalErr && fatalErr != nil {
t.Errorf("Test %d: Expected no error, but there was one: %v", i, fatalErr)
}
if fatalErr != nil {
continue
}
if test.shouldWarnErr && warnErr == nil {
t.Errorf("Test %d: Expected warning, but there wasn't any", i)
}
if !test.shouldWarnErr && warnErr != nil {
t.Errorf("Test %d: Expected no warning, but there was one: %v", i, warnErr)
}
if actual, expected := actualAddr.IP.String(), test.expectedIP; actual != expected {
t.Errorf("Test %d: IP was %s but expected %s", i, actual, expected)
}
if actual, expected := actualAddr.Port, test.expectedPort; actual != expected {
t.Errorf("Test %d: Port was %d but expected %d", i, actual, expected)
}
}
}
func TestMakeOnces(t *testing.T) {
directives := []directive{
{"dummy", nil},
{"dummy2", nil},
}
directiveOrder = directives
onces := makeOnces()
if len(onces) != len(directives) {
t.Errorf("onces had len %d , expected %d", len(onces), len(directives))
}
expected := map[string]*sync.Once{
"dummy": new(sync.Once),
"dummy2": new(sync.Once),
}
if !reflect.DeepEqual(onces, expected) {
t.Errorf("onces was %v, expected %v", onces, expected)
}
}
func TestMakeStorages(t *testing.T) {
directives := []directive{
{"dummy", nil},
{"dummy2", nil},
}
directiveOrder = directives
storages := makeStorages()
if len(storages) != len(directives) {
t.Errorf("storages had len %d , expected %d", len(storages), len(directives))
}
expected := map[string]interface{}{
"dummy": nil,
"dummy2": nil,
}
if !reflect.DeepEqual(storages, expected) {
t.Errorf("storages was %v, expected %v", storages, expected)
}
}
func TestValidDirective(t *testing.T) {
directives := []directive{
{"dummy", nil},
{"dummy2", nil},
}
directiveOrder = directives
for i, test := range []struct {
directive string
valid bool
}{
{"dummy", true},
{"dummy2", true},
{"dummy3", false},
} {
if actual, expected := validDirective(test.directive), test.valid; actual != expected {
t.Errorf("Test %d: valid was %t, expected %t", i, actual, expected)
}
}
}
This diff is collapsed.
package core
import (
// plug in the server
_ "github.com/miekg/coredns/core/dnsserver"
// plug in the standard directives
_ "github.com/miekg/coredns/middleware/bind"
_ "github.com/miekg/coredns/middleware/health"
_ "github.com/miekg/coredns/middleware/pprof"
_ "github.com/miekg/coredns/middleware/errors"
_ "github.com/miekg/coredns/middleware/loadbalance"
_ "github.com/miekg/coredns/middleware/log"
_ "github.com/miekg/coredns/middleware/metrics"
_ "github.com/miekg/coredns/middleware/rewrite"
_ "github.com/miekg/coredns/middleware/cache"
_ "github.com/miekg/coredns/middleware/chaos"
_ "github.com/miekg/coredns/middleware/dnssec"
_ "github.com/miekg/coredns/middleware/etcd"
_ "github.com/miekg/coredns/middleware/file"
_ "github.com/miekg/coredns/middleware/kubernetes"
_ "github.com/miekg/coredns/middleware/proxy"
_ "github.com/miekg/coredns/middleware/secondary"
)
package coremain
import (
"errors"
"flag"
"fmt"
"io/ioutil"
"log"
"os"
"runtime"
"strconv"
"strings"
"time"
"github.com/miekg/coredns/core"
"github.com/miekg/coredns/core/https"
"github.com/xenolf/lego/acme"
"gopkg.in/natefinch/lumberjack.v2"
)
func init() {
core.TrapSignals()
setVersion()
flag.BoolVar(&https.Agreed, "agree", false, "Agree to Let's Encrypt Subscriber Agreement")
flag.StringVar(&https.CAUrl, "ca", "https://acme-v01.api.letsencrypt.org/directory", "Certificate authority ACME server")
flag.StringVar(&conf, "conf", "", "Configuration file to use (default="+core.DefaultConfigFile+")")
flag.StringVar(&cpu, "cpu", "100%", "CPU cap")
flag.StringVar(&https.DefaultEmail, "email", "", "Default Let's Encrypt account email address")
flag.DurationVar(&core.GracefulTimeout, "grace", 5*time.Second, "Maximum duration of graceful shutdown")
flag.StringVar(&core.Host, "host", core.DefaultHost, "Default host")
flag.StringVar(&logfile, "log", "", "Process log file")
flag.StringVar(&core.PidFile, "pidfile", "", "Path to write pid file")
flag.StringVar(&core.Port, "port", core.DefaultPort, "Default port")
flag.BoolVar(&core.Quiet, "quiet", false, "Quiet mode (no initialization output)")
flag.StringVar(&revoke, "revoke", "", "Hostname for which to revoke the certificate")
flag.StringVar(&core.Root, "root", core.DefaultRoot, "Root path to default zone files")
flag.BoolVar(&version, "version", false, "Show version")
}
func Run() {
flag.Parse() // called here in Run() to allow other packages to set flags in their inits
core.AppName = appName
core.AppVersion = appVersion
acme.UserAgent = appName + "/" + appVersion
// set up process log before anything bad happens
switch logfile {
case "stdout":
log.SetOutput(os.Stdout)
case "stderr":
log.SetOutput(os.Stderr)
case "":
log.SetOutput(ioutil.Discard)
default:
log.SetOutput(&lumberjack.Logger{
Filename: logfile,
MaxSize: 100,
MaxAge: 14,
MaxBackups: 10,
})
}
if revoke != "" {
err := https.Revoke(revoke)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Revoked certificate for %s\n", revoke)
os.Exit(0)
}
if version {
fmt.Printf("%s %s\n", appName, appVersion)
if devBuild && gitShortStat != "" {
fmt.Printf("%s\n%s\n", gitShortStat, gitFilesModified)
}
os.Exit(0)
}
// Set CPU cap
err := setCPU(cpu)
if err != nil {
mustLogFatal(err)
}
// Get Corefile input
corefile, err := core.LoadCorefile(loadCorefile)
if err != nil {
mustLogFatal(err)
}
// Start your engines
err = core.Start(corefile)
if err != nil {
mustLogFatal(err)
}
// Twiddle your thumbs
core.Wait()
}
// mustLogFatal just wraps log.Fatal() in a way that ensures the
// output is always printed to stderr so the user can see it
// if the user is still there, even if the process log was not
// enabled. If this process is a restart, however, and the user
// might not be there anymore, this just logs to the process log
// and exits.
func mustLogFatal(args ...interface{}) {
if !core.IsRestart() {
log.SetOutput(os.Stderr)
}
log.Fatal(args...)
}
func loadCorefile() (core.Input, error) {
// Try -conf flag
if conf != "" {
if conf == "stdin" {
return core.CorefileFromPipe(os.Stdin)
}
contents, err := ioutil.ReadFile(conf)
if err != nil {
return nil, err
}
return core.CorefileInput{
Contents: contents,
Filepath: conf,
RealFile: true,
}, nil
}
// command line args
if flag.NArg() > 0 {
confBody := core.Host + ":" + core.Port + "\n" + strings.Join(flag.Args(), "\n")
return core.CorefileInput{
Contents: []byte(confBody),
Filepath: "args",
}, nil
}
// Corefile in cwd
contents, err := ioutil.ReadFile(core.DefaultConfigFile)
if err != nil {
if os.IsNotExist(err) {
return core.DefaultInput(), nil
}
return nil, err
}
return core.CorefileInput{
Contents: contents,
Filepath: core.DefaultConfigFile,
RealFile: true,
}, nil
}
// setCPU parses string cpu and sets GOMAXPROCS
// according to its value. It accepts either
// a number (e.g. 3) or a percent (e.g. 50%).
func setCPU(cpu string) error {
var numCPU int
availCPU := runtime.NumCPU()
if strings.HasSuffix(cpu, "%") {
// Percent
var percent float32
pctStr := cpu[:len(cpu)-1]
pctInt, err := strconv.Atoi(pctStr)
if err != nil || pctInt < 1 || pctInt > 100 {
return errors.New("invalid CPU value: percentage must be between 1-100")
}
percent = float32(pctInt) / 100
numCPU = int(float32(availCPU) * percent)
} else {
// Number
num, err := strconv.Atoi(cpu)
if err != nil || num < 1 {
return errors.New("invalid CPU value: provide a number or percent greater than 0")
}
numCPU = num
}
if numCPU > availCPU {
numCPU = availCPU
}
runtime.GOMAXPROCS(numCPU)
return nil
}
// setVersion figures out the version information based on
// variables set by -ldflags.
func setVersion() {
// A development build is one that's not at a tag or has uncommitted changes
devBuild = gitTag == "" || gitShortStat != ""
// Only set the appVersion if -ldflags was used
if gitNearestTag != "" || gitTag != "" {
if devBuild && gitNearestTag != "" {
appVersion = fmt.Sprintf("%s (+%s %s)",
strings.TrimPrefix(gitNearestTag, "v"), gitCommit, buildDate)
} else if gitTag != "" {
appVersion = strings.TrimPrefix(gitTag, "v")
}
}
}
const appName = "CoreDNS"
// Flags that control program flow or startup
var (
conf string
cpu string
logfile string
revoke string
version bool
)
// Build information obtained with the help of -ldflags
var (
appVersion = "(untracked dev build)" // inferred at startup
devBuild = true // inferred at startup
buildDate string // date -u
gitTag string // git describe --exact-match HEAD 2> /dev/null
gitNearestTag string // git describe --abbrev=0 --tags HEAD
gitCommit string // git rev-parse HEAD
gitShortStat string // git diff-index --shortstat
gitFilesModified string // git diff-index --name-only HEAD
)
package coremain
import (
"runtime"
"testing"
)
func TestSetCPU(t *testing.T) {
currentCPU := runtime.GOMAXPROCS(-1)
maxCPU := runtime.NumCPU()
halfCPU := int(0.5 * float32(maxCPU))
if halfCPU < 1 {
halfCPU = 1
}
for i, test := range []struct {
input string
output int
shouldErr bool
}{
{"1", 1, false},
{"-1", currentCPU, true},
{"0", currentCPU, true},
{"100%", maxCPU, false},
{"50%", halfCPU, false},
{"110%", currentCPU, true},
{"-10%", currentCPU, true},
{"invalid input", currentCPU, true},
{"invalid input%", currentCPU, true},
{"9999", maxCPU, false}, // over available CPU
} {
err := setCPU(test.input)
if test.shouldErr && err == nil {
t.Errorf("Test %d: Expected error, but there wasn't any", i)
}
if !test.shouldErr && err != nil {
t.Errorf("Test %d: Expected no error, but there was one: %v", i, err)
}
if actual, expected := runtime.GOMAXPROCS(-1), test.output; actual != expected {
t.Errorf("Test %d: GOMAXPROCS was %d but expected %d", i, actual, expected)
}
// teardown
runtime.GOMAXPROCS(currentCPU)
}
}
func TestSetVersion(t *testing.T) {
setVersion()
if !devBuild {
t.Error("Expected default to assume development build, but it didn't")
}
if got, want := appVersion, "(untracked dev build)"; got != want {
t.Errorf("Expected appVersion='%s', got: '%s'", want, got)
}
gitTag = "v1.1"
setVersion()
if devBuild {
t.Error("Expected a stable build if gitTag is set with no changes")
}
if got, want := appVersion, "1.1"; got != want {
t.Errorf("Expected appVersion='%s', got: '%s'", want, got)
}
gitTag = ""
gitNearestTag = "v1.0"
gitCommit = "deadbeef"
buildDate = "Fri Feb 26 06:53:17 UTC 2016"
setVersion()
if !devBuild {
t.Error("Expected inferring a dev build when gitTag is empty")
}
if got, want := appVersion, "1.0 (+deadbeef Fri Feb 26 06:53:17 UTC 2016)"; got != want {
t.Errorf("Expected appVersion='%s', got: '%s'", want, got)
}
}
package core
import (
"github.com/miekg/coredns/core/https"
"github.com/miekg/coredns/core/parse"
"github.com/miekg/coredns/core/setup"
"github.com/miekg/coredns/middleware"
)
func init() {
// The parse package must know which directives
// are valid, but it must not import the setup
// or config package. To solve this problem, we
// fill up this map in our init function here.
// The parse package does not need to know the
// ordering of the directives.
for _, dir := range directiveOrder {
parse.ValidDirectives[dir.name] = struct{}{}
}
}
// Directives are registered in the order they should be
// executed. Middleware (directives that inject a handler)
// are executed in the order A-B-C-*-C-B-A, assuming
// they all call the Next handler in the chain.
//
// Ordering is VERY important. Every middleware will
// feel the effects of all other middleware below
// (after) them during a request, but they must not
// care what middleware above them are doing.
//
// For example, log needs to know the status code and
// exactly how many bytes were written to the client,
// which every other middleware can affect, so it gets
// registered first. The errors middleware does not
// care if gzip or log modifies its response, so it
// gets registered below them. Gzip, on the other hand,
// DOES care what errors does to the response since it
// must compress every output to the client, even error
// pages, so it must be registered before the errors
// middleware and any others that would write to the
// response.
var directiveOrder = []directive{
// Essential directives that initialize vital configuration settings
{"root", setup.Root},
{"bind", setup.BindHost},
{"tls", https.Setup},
{"health", setup.Health},
{"pprof", setup.PProf},
// Other directives that don't create HTTP handlers
{"startup", setup.Startup},
{"shutdown", setup.Shutdown},
// Directives that inject handlers (middleware)
{"prometheus", setup.Prometheus},
{"errors", setup.Errors},
{"log", setup.Log},
{"chaos", setup.Chaos},
{"rewrite", setup.Rewrite},
{"loadbalance", setup.Loadbalance},
{"cache", setup.Cache},
{"dnssec", setup.Dnssec},
{"file", setup.File},
{"secondary", setup.Secondary},
{"etcd", setup.Etcd},
{"kubernetes", setup.Kubernetes},
{"proxy", setup.Proxy},
}
// RegisterDirective adds the given directive to CoreDNS's list of directives.
// Pass the name of a directive you want it to be placed after,
// otherwise it will be placed at the bottom of the stack.
func RegisterDirective(name string, setup SetupFunc, after string) {
dir := directive{name: name, setup: setup}
idx := len(directiveOrder)
for i := range directiveOrder {
if directiveOrder[i].name == after {
idx = i + 1
break
}
}
newDirectives := append(directiveOrder[:idx], append([]directive{dir}, directiveOrder[idx:]...)...)
directiveOrder = newDirectives
parse.ValidDirectives[name] = struct{}{}
}
// directive ties together a directive name with its setup function.
type directive struct {
name string
setup SetupFunc
}
// SetupFunc takes a controller and may optionally return a middleware.
// If the resulting middleware is not nil, it will be chained into
// the DNS handlers in the order specified in this package.
type SetupFunc func(c *setup.Controller) (middleware.Middleware, error)
package core
import (
"reflect"
"testing"
)
func TestRegister(t *testing.T) {
directives := []directive{
{"dummy", nil},
{"dummy2", nil},
}
directiveOrder = directives
RegisterDirective("foo", nil, "dummy")
if len(directiveOrder) != 3 {
t.Fatal("Should have 3 directives now")
}
getNames := func() (s []string) {
for _, d := range directiveOrder {
s = append(s, d.name)
}
return s
}
if !reflect.DeepEqual(getNames(), []string{"dummy", "foo", "dummy2"}) {
t.Fatalf("directive order doesn't match: %s", getNames())
}
RegisterDirective("bar", nil, "ASDASD")
if !reflect.DeepEqual(getNames(), []string{"dummy", "foo", "dummy2", "bar"}) {
t.Fatalf("directive order doesn't match: %s", getNames())
}
}
package dns
import (
"path/filepath"
"github.com/miekg/coredns/core/assets"
)
var storage = Storage(assets.Path())
// Storage is a root directory and facilitates
// forming file paths derived from it.
type Storage string
// Zones gets the directory that stores zones data.
func (s Storage) Zones() string {
return filepath.Join(string(s), "zones")
}
// Zone returns the path to the folder containing assets for domain.
func (s Storage) Zone(domain string) string {
return filepath.Join(s.Zones(), domain)
}
// SecondaryZoneFile returns the path to domain's secondary zone file (when fetched).
func (s Storage) SecondaryZoneFile(domain string) string {
return filepath.Join(s.Zone(domain), "db."+domain)
}
package dns
import (
"path/filepath"
"testing"
)
func TestStorage(t *testing.T) {
storage = Storage("./le_test")
if expected, actual := filepath.Join("le_test", "zones"), storage.Zones(); actual != expected {
t.Errorf("Expected Zones() to return '%s' but got '%s'", expected, actual)
}
if expected, actual := filepath.Join("le_test", "zones", "test.com"), storage.Zone("test.com"); actual != expected {
t.Errorf("Expected Site() to return '%s' but got '%s'", expected, actual)
}
if expected, actual := filepath.Join("le_test", "zones", "test.com", "db.test.com"), storage.SecondaryZoneFile("test.com"); actual != expected {
t.Errorf("Expected SecondaryZoneFile() to return '%s' but got '%s'", expected, actual)
}
}
package dnsserver
import (
"fmt"
"net"
"strings"
"github.com/miekg/dns"
)
type zoneAddr struct {
Zone string
Port string
}
// String return z.Zone + ":" + z.Port as a string.
func (z zoneAddr) String() string { return z.Zone + ":" + z.Port }
// normalizeZone parses an zone string into a structured format with separate
// host, and port portions, as well as the original input string.
func normalizeZone(str string) (zoneAddr, error) {
var err error
// separate host and port
host, port, err := net.SplitHostPort(str)
if err != nil {
host, port, err = net.SplitHostPort(str + ":")
// no error check here; return err at end of function
}
if len(host) > 255 {
return zoneAddr{}, fmt.Errorf("specified zone is too long: %d > 255", len(host))
}
_, d := dns.IsDomainName(host)
if !d {
return zoneAddr{}, fmt.Errorf("zone is not a valid domain name: %s", host)
}
if port == "" {
port = "53"
}
return zoneAddr{Zone: strings.ToLower(dns.Fqdn(host)), Port: port}, err
}
package dnsserver
import "github.com/mholt/caddy"
// Config configuration for a single server.
type Config struct {
// The zone of the site.
Zone string
// The hostname to bind listener to, defaults to the wildcard address
ListenHost string
// The port to listen on.
Port string
// The directory from which to parse db files, and store keys.
Root string
// Middleware stack.
Middleware []Middleware
// Compiled middleware stack.
middlewareChain Handler
}
// GetConfig gets the Config that corresponds to c.
// If none exist nil is returned.
func GetConfig(c *caddy.Controller) *Config {
ctx := c.Context().(*dnsContext)
if cfg, ok := ctx.keysToConfigs[c.Key]; ok {
return cfg
}
// we should only get here during tests because directive
// actions typically skip the server blocks where we make
// the configs.
ctx.saveConfig(c.Key, &Config{Root: Root})
return GetConfig(c)
}
package dnsserver
// Add here, and in core/coredns.go to use them.
// Directives are registered in the order they should be
// executed.
//
// Ordering is VERY important. Every middleware will
// feel the effects of all other middleware below
// (after) them during a request, but they must not
// care what middleware above them are doing.
var Directives = []string{
"bind",
"health",
"pprof",
"prometheus",
"errors",
"log",
"chaos",
"cache",
"rewrite",
"loadbalance",
"dnssec",
"file",
"secondary",
"etcd",
"kubernetes",
"proxy",
}
package dnsserver
import (
"github.com/miekg/dns"
"golang.org/x/net/context"
)
type (
// Middleware is the middle layer which represents the traditional
// idea of middleware: it chains one Handler to the next by being
// passed the next Handler in the chain.
Middleware func(Handler) Handler
// Handler is like dns.Handler except ServeDNS may return an rcode
// and/or error.
//
// If ServeDNS writes to the response body, it should return a status
// code. If the status code is not one of the following:
// * SERVFAIL (dns.RcodeServerFailure)
// * REFUSED (dns.RecodeRefused)
// * FORMERR (dns.RcodeFormatError)
// * NOTIMP (dns.RcodeNotImplemented)
//
// CoreDNS assumes *no* reply has yet been written. All other response
// codes signal other handlers above it that the response message is
// already written, and that they should not write to it also.
//
// If ServeDNS encounters an error, it should return the error value
// so it can be logged by designated error-handling middleware.
//
// If writing a response after calling another ServeDNS method, the
// returned rcode SHOULD be used when writing the response.
//
// If handling errors after calling another ServeDNS method, the
// returned error value SHOULD be logged or handled accordingly.
//
// Otherwise, return values should be propagated down the middleware
// chain by returning them unchanged.
Handler interface {
ServeDNS(context.Context, dns.ResponseWriter, *dns.Msg) (int, error)
}
// HandlerFunc is a convenience type like dns.HandlerFunc, except
// ServeDNS returns an rcode and an error. See Handler
// documentation for more information.
HandlerFunc func(context.Context, dns.ResponseWriter, *dns.Msg) (int, error)
)
// ServeDNS implements the Handler interface.
func (f HandlerFunc) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
return f(ctx, w, r)
}
package dnsserver
import (
"fmt"
"net"
"time"
"github.com/mholt/caddy"
"github.com/mholt/caddy/caddyfile"
)
const serverType = "dns"
func init() {
caddy.RegisterServerType(serverType, caddy.ServerType{
Directives: Directives,
DefaultInput: func() caddy.Input {
if Port == DefaultPort && Zone != "" {
return caddy.CaddyfileInput{
Filepath: "Corefile",
Contents: nil,
ServerTypeName: serverType,
}
}
return caddy.CaddyfileInput{
Filepath: "Corefile",
Contents: nil,
ServerTypeName: serverType,
}
},
NewContext: newContext,
})
}
var TestNewContext = newContext
func newContext() caddy.Context {
return &dnsContext{keysToConfigs: make(map[string]*Config)}
}
type dnsContext struct {
keysToConfigs map[string]*Config
// configs is the master list of all site configs.
configs []*Config
}
func (h *dnsContext) saveConfig(key string, cfg *Config) {
h.configs = append(h.configs, cfg)
h.keysToConfigs[key] = cfg
}
// InspectServerBlocks make sure that everything checks out before
// executing directives and otherwise prepares the directives to
// be parsed and executed.
func (h *dnsContext) InspectServerBlocks(sourceFile string, serverBlocks []caddyfile.ServerBlock) ([]caddyfile.ServerBlock, error) {
// Normalize and check all the zone names and check for duplicates
dups := map[string]string{}
for _, s := range serverBlocks {
for i, k := range s.Keys {
za, err := normalizeZone(k)
if err != nil {
return nil, err
}
s.Keys[i] = za.String()
if v, ok := dups[za.Zone]; ok {
return nil, fmt.Errorf("cannot serve %s - zone already defined for %v", za, v)
}
dups[za.Zone] = za.String()
// Save the config to our master list, and key it for lookups
cfg := &Config{
Zone: za.Zone,
Port: za.Port,
// TODO(miek): more?
}
h.saveConfig(za.String(), cfg)
}
}
return serverBlocks, nil
}
// MakeServers uses the newly-created siteConfigs to create and return a list of server instances.
func (h *dnsContext) MakeServers() ([]caddy.Server, error) {
// we must map (group) each config to a bind address
groups, err := groupConfigsByListenAddr(h.configs)
if err != nil {
return nil, err
}
// then we create a server for each group
var servers []caddy.Server
for addr, group := range groups {
s, err := NewServer(addr, group)
if err != nil {
return nil, err
}
servers = append(servers, s)
}
return servers, nil
}
// AddMiddleware adds a middleware to a site's middleware stack.
func (sc *Config) AddMiddleware(m Middleware) {
sc.Middleware = append(sc.Middleware, m)
}
// groupSiteConfigsByListenAddr groups site configs by their listen
// (bind) address, so sites that use the same listener can be served
// on the same server instance. The return value maps the listen
// address (what you pass into net.Listen) to the list of site configs.
// This function does NOT vet the configs to ensure they are compatible.
func groupConfigsByListenAddr(configs []*Config) (map[string][]*Config, error) {
groups := make(map[string][]*Config)
for _, conf := range configs {
if conf.Port == "" {
conf.Port = Port
}
addr, err := net.ResolveTCPAddr("tcp", net.JoinHostPort(conf.ListenHost, conf.Port))
if err != nil {
return nil, err
}
addrstr := addr.String()
groups[addrstr] = append(groups[addrstr], conf)
}
return groups, nil
}
const (
// DefaultZone is the default zone.
DefaultZone = "."
// DefaultPort is the default port.
DefaultPort = "2053"
// DefaultRoot is the default root folder.
DefaultRoot = "."
)
// These "soft defaults" are configurable by
// command line flags, etc.
var (
// Root is the site root
Root = DefaultRoot
// Host is the site host
Zone = DefaultZone
// Port is the site port
Port = DefaultPort
// GracefulTimeout is the maximum duration of a graceful shutdown.
GracefulTimeout time.Duration
)
package dnsserver
import (
"log"
"net"
"runtime"
"sync"
"time"
"github.com/miekg/coredns/middleware"
"github.com/miekg/dns"
"golang.org/x/net/context"
)
// Server represents an instance of a server, which serves
// DNS requests at a particular address (host and port). A
// server is capable of serving numerous zones on
// the same address and the listener may be stopped for
// graceful termination (POSIX only).
type Server struct {
Addr string // Address we listen on
mux *dns.ServeMux
server [2]*dns.Server // 0 is a net.Listener, 1 is a net.PacketConn (a *UDPConn) in our case.
l net.Listener
p net.PacketConn
m sync.Mutex // protects listener and packetconn
zones map[string]*Config // zones keyed by their address
dnsWg sync.WaitGroup // used to wait on outstanding connections
connTimeout time.Duration // the maximum duration of a graceful shutdown
}
func NewServer(addr string, group []*Config) (*Server, error) {
s := &Server{
Addr: addr,
zones: make(map[string]*Config),
connTimeout: 5 * time.Second, // TODO(miek): was configurable
}
mux := dns.NewServeMux()
mux.Handle(".", s) // wildcard handler, everything will go through here
s.mux = mux
// We have to bound our wg with one increment
// to prevent a "race condition" that is hard-coded
// into sync.WaitGroup.Wait() - basically, an add
// with a positive delta must be guaranteed to
// occur before Wait() is called on the wg.
// In a way, this kind of acts as a safety barrier.
s.dnsWg.Add(1)
for _, site := range group {
// set the config per zone
s.zones[site.Zone] = site
// compile custom middleware for everything
var stack Handler
for i := len(site.Middleware) - 1; i >= 0; i-- {
stack = site.Middleware[i](stack)
}
site.middlewareChain = stack
}
return s, nil
}
// LocalAddr return the addresses where the server is bound to.
func (s *Server) LocalAddr() net.Addr {
s.m.Lock()
defer s.m.Unlock()
return s.l.Addr()
}
// LocalAddrPacket return the net.PacketConn address where the server is bound to.
func (s *Server) LocalAddrPacket() net.Addr {
s.m.Lock()
defer s.m.Unlock()
return s.p.LocalAddr()
}
// Serve starts the server with an existing listener. It blocks until the server stops.
func (s *Server) Serve(l net.Listener) error {
s.m.Lock()
s.server[tcp] = &dns.Server{Listener: l, Net: "tcp", Handler: s.mux}
s.m.Unlock()
return s.server[tcp].ActivateAndServe()
}
// ServePacket starts the server with an existing packetconn. It blocks until the server stops.
func (s *Server) ServePacket(p net.PacketConn) error {
s.m.Lock()
s.server[udp] = &dns.Server{PacketConn: p, Net: "udp", Handler: s.mux}
s.m.Unlock()
return s.server[udp].ActivateAndServe()
}
func (s *Server) Listen() (net.Listener, error) {
l, err := net.Listen("tcp", s.Addr)
if err != nil {
return nil, err
}
s.m.Lock()
s.l = l
s.m.Unlock()
return l, nil
}
func (s *Server) ListenPacket() (net.PacketConn, error) {
p, err := net.ListenPacket("udp", s.Addr)
if err != nil {
return nil, err
}
s.m.Lock()
s.p = p
s.m.Unlock()
return p, nil
}
// Stop stops the server. It blocks until the server is
// totally stopped. On POSIX systems, it will wait for
// connections to close (up to a max timeout of a few
// seconds); on Windows it will close the listener
// immediately.
func (s *Server) Stop() (err error) {
if runtime.GOOS != "windows" {
// force connections to close after timeout
done := make(chan struct{})
go func() {
s.dnsWg.Done() // decrement our initial increment used as a barrier
s.dnsWg.Wait()
close(done)
}()
// Wait for remaining connections to finish or
// force them all to close after timeout
select {
case <-time.After(s.connTimeout):
case <-done:
}
}
// Close the listener now; this stops the server without delay
s.m.Lock()
if s.l != nil {
err = s.l.Close()
}
if s.p != nil {
err = s.p.Close()
}
for _, s1 := range s.server {
err = s1.Shutdown()
}
s.m.Unlock()
return
}
// ServeDNS is the entry point for every request to the address that s
// is bound to. It acts as a multiplexer for the requests zonename as
// defined in the request so that the correct zone
// (configuration and middleware stack) will handle the request.
func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// TODO(miek): expensive to use defer
defer func() {
// In case the user doesn't enable error middleware, we still
// need to make sure that we stay alive up here
if rec := recover(); rec != nil {
DefaultErrorFunc(w, r, dns.RcodeServerFailure)
}
}()
if m, err := middleware.Edns0Version(r); err != nil { // Wrong EDNS version, return at once.
w.WriteMsg(m)
return
}
q := r.Question[0].Name
b := make([]byte, len(q))
off, end := 0, false
ctx := context.Background()
for {
l := len(q[off:])
for i := 0; i < l; i++ {
b[i] = q[off+i]
// normalize the name for the lookup
if b[i] >= 'A' && b[i] <= 'Z' {
b[i] |= ('a' - 'A')
}
}
if h, ok := s.zones[string(b[:l])]; ok {
if r.Question[0].Qtype != dns.TypeDS {
rcode, _ := h.middlewareChain.ServeDNS(ctx, w, r)
if RcodeNoClientWrite(rcode) {
DefaultErrorFunc(w, r, rcode)
}
return
}
}
off, end = dns.NextLabel(q, off)
if end {
break
}
}
// Wildcard match, if we have found nothing try the root zone as a last resort.
if h, ok := s.zones["."]; ok {
rcode, _ := h.middlewareChain.ServeDNS(ctx, w, r)
if RcodeNoClientWrite(rcode) {
DefaultErrorFunc(w, r, rcode)
}
return
}
// Still here? Error out with REFUSED and some logging
remoteHost := w.RemoteAddr().String()
DefaultErrorFunc(w, r, dns.RcodeRefused)
log.Printf("[INFO] \"%s %s %s\" - No such zone at %s (Remote: %s)", dns.Type(r.Question[0].Qtype), dns.Class(r.Question[0].Qclass), q, s.Addr, remoteHost)
}
// DefaultErrorFunc responds to an DNS request with an error.
func DefaultErrorFunc(w dns.ResponseWriter, r *dns.Msg, rcode int) {
state := middleware.State{W: w, Req: r}
answer := new(dns.Msg)
answer.SetRcode(r, rcode)
state.SizeAndDo(answer)
w.WriteMsg(answer)
}
func RcodeNoClientWrite(rcode int) bool {
switch rcode {
case dns.RcodeServerFailure:
fallthrough
case dns.RcodeRefused:
fallthrough
case dns.RcodeFormatError:
fallthrough
case dns.RcodeNotImplemented:
return true
}
return false
}
const (
tcp = 0
udp = 1
)
package core
import (
"bytes"
"fmt"
"io/ioutil"
"log"
"os"
"os/exec"
"runtime"
"strconv"
"strings"
"sync"
)
// isLocalhost returns true if host looks explicitly like a localhost address.
func isLocalhost(host string) bool {
return host == "localhost" || host == "::1" || strings.HasPrefix(host, "127.")
}
// checkFdlimit issues a warning if the OS max file descriptors is below a recommended minimum.
func checkFdlimit() {
const min = 4096
// Warn if ulimit is too low for production sites
if runtime.GOOS == "linux" || runtime.GOOS == "darwin" {
out, err := exec.Command("sh", "-c", "ulimit -n").Output() // use sh because ulimit isn't in Linux $PATH
if err == nil {
// Note that an error here need not be reported
lim, err := strconv.Atoi(string(bytes.TrimSpace(out)))
if err == nil && lim < min {
fmt.Printf("Warning: File descriptor limit %d is too low for production sites. At least %d is recommended. Set with \"ulimit -n %d\".\n", lim, min, min)
}
}
}
}
// signalSuccessToParent tells the parent our status using pipe at index 3.
// If this process is not a restart, this function does nothing.
// Calling this function once this process has successfully initialized
// is vital so that the parent process can unblock and kill itself.
// This function is idempotent; it executes at most once per process.
func signalSuccessToParent() {
signalParentOnce.Do(func() {
if IsRestart() {
ppipe := os.NewFile(3, "") // parent is reading from pipe at index 3
_, err := ppipe.Write([]byte("success")) // we must send some bytes to the parent
if err != nil {
log.Printf("[ERROR] Communicating successful init to parent: %v", err)
}
ppipe.Close()
}
})
}
// signalParentOnce is used to make sure that the parent is only
// signaled once; doing so more than once breaks whatever socket is
// at fd 4 (the reason for this is still unclear - to reproduce,
// call Stop() and Start() in succession at least once after a
// restart, then try loading first host of Corefile in the browser).
// Do not use this directly - call signalSuccessToParent instead.
var signalParentOnce sync.Once
// corefileGob maps bind address to index of the file descriptor
// in the Files array passed to the child process. It also contains
// the corefile contents and other state needed by the new process.
// Used only during graceful restarts where a new process is spawned.
type corefileGob struct {
ListenerFds map[string]uintptr
Corefile Input
OnDemandTLSCertsIssued int32
}
// IsRestart returns whether this process is, according
// to env variables, a fork as part of a graceful restart.
func IsRestart() bool {
return os.Getenv("COREDNS_RESTART") == "true"
}
// writePidFile writes the process ID to the file at PidFile, if specified.
func writePidFile() error {
pid := []byte(strconv.Itoa(os.Getpid()) + "\n")
return ioutil.WriteFile(PidFile, pid, 0644)
}
// CorefileInput represents a Corefile as input
// and is simply a convenient way to implement
// the Input interface.
type CorefileInput struct {
Filepath string
Contents []byte
RealFile bool
}
// Body returns c.Contents.
func (c CorefileInput) Body() []byte { return c.Contents }
// Path returns c.Filepath.
func (c CorefileInput) Path() string { return c.Filepath }
// IsFile returns true if the original input was a real file on the file system.
func (c CorefileInput) IsFile() bool { return c.RealFile }
package https
import (
"crypto/tls"
"crypto/x509"
"errors"
"io/ioutil"
"log"
"strings"
"sync"
"time"
"github.com/xenolf/lego/acme"
"golang.org/x/crypto/ocsp"
)
// certCache stores certificates in memory,
// keying certificates by name.
var certCache = make(map[string]Certificate)
var certCacheMu sync.RWMutex
// Certificate is a tls.Certificate with associated metadata tacked on.
// Even if the metadata can be obtained by parsing the certificate,
// we can be more efficient by extracting the metadata once so it's
// just there, ready to use.
type Certificate struct {
tls.Certificate
// Names is the list of names this certificate is written for.
// The first is the CommonName (if any), the rest are SAN.
Names []string
// NotAfter is when the certificate expires.
NotAfter time.Time
// Managed certificates are certificates that CoreDNS is managing,
// as opposed to the user specifying a certificate and key file
// or directory and managing the certificate resources themselves.
Managed bool
// OnDemand certificates are obtained or loaded on-demand during TLS
// handshakes (as opposed to preloaded certificates, which are loaded
// at startup). If OnDemand is true, Managed must necessarily be true.
// OnDemand certificates are maintained in the background just like
// preloaded ones, however, if an OnDemand certificate fails to renew,
// it is removed from the in-memory cache.
OnDemand bool
// OCSP contains the certificate's parsed OCSP response.
OCSP *ocsp.Response
}
// getCertificate gets a certificate that matches name (a server name)
// from the in-memory cache. If there is no exact match for name, it
// will be checked against names of the form '*.example.com' (wildcard
// certificates) according to RFC 6125. If a match is found, matched will
// be true. If no matches are found, matched will be false and a default
// certificate will be returned with defaulted set to true. If no default
// certificate is set, defaulted will be set to false.
//
// The logic in this function is adapted from the Go standard library,
// which is by the Go Authors.
//
// This function is safe for concurrent use.
func getCertificate(name string) (cert Certificate, matched, defaulted bool) {
var ok bool
// Not going to trim trailing dots here since RFC 3546 says,
// "The hostname is represented ... without a trailing dot."
// Just normalize to lowercase.
name = strings.ToLower(name)
certCacheMu.RLock()
defer certCacheMu.RUnlock()
// exact match? great, let's use it
if cert, ok = certCache[name]; ok {
matched = true
return
}
// try replacing labels in the name with wildcards until we get a match
labels := strings.Split(name, ".")
for i := range labels {
labels[i] = "*"
candidate := strings.Join(labels, ".")
if cert, ok = certCache[candidate]; ok {
matched = true
return
}
}
// if nothing matches, use the default certificate or bust
cert, defaulted = certCache[""]
return
}
// cacheManagedCertificate loads the certificate for domain into the
// cache, flagging it as Managed and, if onDemand is true, as OnDemand
// (meaning that it was obtained or loaded during a TLS handshake).
//
// This function is safe for concurrent use.
func cacheManagedCertificate(domain string, onDemand bool) (Certificate, error) {
cert, err := makeCertificateFromDisk(storage.SiteCertFile(domain), storage.SiteKeyFile(domain))
if err != nil {
return cert, err
}
cert.Managed = true
cert.OnDemand = onDemand
cacheCertificate(cert)
return cert, nil
}
// cacheUnmanagedCertificatePEMFile loads a certificate for host using certFile
// and keyFile, which must be in PEM format. It stores the certificate in
// memory. The Managed and OnDemand flags of the certificate will be set to
// false.
//
// This function is safe for concurrent use.
func cacheUnmanagedCertificatePEMFile(certFile, keyFile string) error {
cert, err := makeCertificateFromDisk(certFile, keyFile)
if err != nil {
return err
}
cacheCertificate(cert)
return nil
}
// cacheUnmanagedCertificatePEMBytes makes a certificate out of the PEM bytes
// of the certificate and key, then caches it in memory.
//
// This function is safe for concurrent use.
func cacheUnmanagedCertificatePEMBytes(certBytes, keyBytes []byte) error {
cert, err := makeCertificate(certBytes, keyBytes)
if err != nil {
return err
}
cacheCertificate(cert)
return nil
}
// makeCertificateFromDisk makes a Certificate by loading the
// certificate and key files. It fills out all the fields in
// the certificate except for the Managed and OnDemand flags.
// (It is up to the caller to set those.)
func makeCertificateFromDisk(certFile, keyFile string) (Certificate, error) {
certPEMBlock, err := ioutil.ReadFile(certFile)
if err != nil {
return Certificate{}, err
}
keyPEMBlock, err := ioutil.ReadFile(keyFile)
if err != nil {
return Certificate{}, err
}
return makeCertificate(certPEMBlock, keyPEMBlock)
}
// makeCertificate turns a certificate PEM bundle and a key PEM block into
// a Certificate, with OCSP and other relevant metadata tagged with it,
// except for the OnDemand and Managed flags. It is up to the caller to
// set those properties.
func makeCertificate(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
var cert Certificate
// Convert to a tls.Certificate
tlsCert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
if err != nil {
return cert, err
}
if len(tlsCert.Certificate) == 0 {
return cert, errors.New("certificate is empty")
}
// Parse leaf certificate and extract relevant metadata
leaf, err := x509.ParseCertificate(tlsCert.Certificate[0])
if err != nil {
return cert, err
}
if leaf.Subject.CommonName != "" {
cert.Names = []string{strings.ToLower(leaf.Subject.CommonName)}
}
for _, name := range leaf.DNSNames {
if name != leaf.Subject.CommonName {
cert.Names = append(cert.Names, strings.ToLower(name))
}
}
cert.NotAfter = leaf.NotAfter
// Staple OCSP
ocspBytes, ocspResp, err := acme.GetOCSPForCert(certPEMBlock)
if err != nil {
// An error here is not a problem because a certificate may simply
// not contain a link to an OCSP server. But we should log it anyway.
log.Printf("[WARNING] No OCSP stapling for %v: %v", cert.Names, err)
} else if ocspResp.Status == ocsp.Good {
tlsCert.OCSPStaple = ocspBytes
cert.OCSP = ocspResp
}
cert.Certificate = tlsCert
return cert, nil
}
// cacheCertificate adds cert to the in-memory cache. If the cache is
// empty, cert will be used as the default certificate. If the cache is
// full, random entries are deleted until there is room to map all the
// names on the certificate.
//
// This certificate will be keyed to the names in cert.Names. Any name
// that is already a key in the cache will be replaced with this cert.
//
// This function is safe for concurrent use.
func cacheCertificate(cert Certificate) {
certCacheMu.Lock()
if _, ok := certCache[""]; !ok {
// use as default
cert.Names = append(cert.Names, "")
certCache[""] = cert
}
for len(certCache)+len(cert.Names) > 10000 {
// for simplicity, just remove random elements
for key := range certCache {
if key == "" { // ... but not the default cert
continue
}
delete(certCache, key)
break
}
}
for _, name := range cert.Names {
certCache[name] = cert
}
certCacheMu.Unlock()
}
package https
import "testing"
func TestUnexportedGetCertificate(t *testing.T) {
defer func() { certCache = make(map[string]Certificate) }()
// When cache is empty
if _, matched, defaulted := getCertificate("example.com"); matched || defaulted {
t.Errorf("Got a certificate when cache was empty; matched=%v, defaulted=%v", matched, defaulted)
}
// When cache has one certificate in it (also is default)
defaultCert := Certificate{Names: []string{"example.com", ""}}
certCache[""] = defaultCert
certCache["example.com"] = defaultCert
if cert, matched, defaulted := getCertificate("Example.com"); !matched || defaulted || cert.Names[0] != "example.com" {
t.Errorf("Didn't get a cert for 'Example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
}
if cert, matched, defaulted := getCertificate(""); !matched || defaulted || cert.Names[0] != "example.com" {
t.Errorf("Didn't get a cert for '' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
}
// When retrieving wildcard certificate
certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}}
if cert, matched, defaulted := getCertificate("sub.example.com"); !matched || defaulted || cert.Names[0] != "*.example.com" {
t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
}
// When no certificate matches, the default is returned
if cert, matched, defaulted := getCertificate("nomatch"); matched || !defaulted {
t.Errorf("Expected matched=false, defaulted=true; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
} else if cert.Names[0] != "example.com" {
t.Errorf("Expected default cert, got: %v", cert)
}
}
func TestCacheCertificate(t *testing.T) {
defer func() { certCache = make(map[string]Certificate) }()
cacheCertificate(Certificate{Names: []string{"example.com", "sub.example.com"}})
if _, ok := certCache["example.com"]; !ok {
t.Error("Expected first cert to be cached by key 'example.com', but it wasn't")
}
if _, ok := certCache["sub.example.com"]; !ok {
t.Error("Expected first cert to be cached by key 'sub.exmaple.com', but it wasn't")
}
if cert, ok := certCache[""]; !ok || cert.Names[2] != "" {
t.Error("Expected first cert to be cached additionally as the default certificate with empty name added, but it wasn't")
}
cacheCertificate(Certificate{Names: []string{"example2.com"}})
if _, ok := certCache["example2.com"]; !ok {
t.Error("Expected second cert to be cached by key 'exmaple2.com', but it wasn't")
}
if cert, ok := certCache[""]; ok && cert.Names[0] == "example2.com" {
t.Error("Expected second cert to NOT be cached as default, but it was")
}
}
package https
import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net"
"sync"
"time"
"github.com/miekg/coredns/server"
"github.com/xenolf/lego/acme"
)
// acmeMu ensures that only one ACME challenge occurs at a time.
var acmeMu sync.Mutex
// ACMEClient is an acme.Client with custom state attached.
type ACMEClient struct {
*acme.Client
AllowPrompts bool // if false, we assume AlternatePort must be used
}
// NewACMEClient creates a new ACMEClient given an email and whether
// prompting the user is allowed. Clients should not be kept and
// re-used over long periods of time, but immediate re-use is more
// efficient than re-creating on every iteration.
var NewACMEClient = func(email string, allowPrompts bool) (*ACMEClient, error) {
// Look up or create the LE user account
leUser, err := getUser(email)
if err != nil {
return nil, err
}
// The client facilitates our communication with the CA server.
client, err := acme.NewClient(CAUrl, &leUser, KeyType)
if err != nil {
return nil, err
}
// If not registered, the user must register an account with the CA
// and agree to terms
if leUser.Registration == nil {
reg, err := client.Register()
if err != nil {
return nil, errors.New("registration error: " + err.Error())
}
leUser.Registration = reg
if allowPrompts { // can't prompt a user who isn't there
if !Agreed && reg.TosURL == "" {
Agreed = promptUserAgreement(saURL, false) // TODO - latest URL
}
if !Agreed && reg.TosURL == "" {
return nil, errors.New("user must agree to terms")
}
}
err = client.AgreeToTOS()
if err != nil {
saveUser(leUser) // Might as well try, right?
return nil, errors.New("error agreeing to terms: " + err.Error())
}
// save user to the file system
err = saveUser(leUser)
if err != nil {
return nil, errors.New("could not save user: " + err.Error())
}
}
return &ACMEClient{
Client: client,
AllowPrompts: allowPrompts,
}, nil
}
// NewACMEClientGetEmail creates a new ACMEClient and gets an email
// address at the same time (a server config is required, since it
// may contain an email address in it).
func NewACMEClientGetEmail(config server.Config, allowPrompts bool) (*ACMEClient, error) {
return NewACMEClient(getEmail(config, allowPrompts), allowPrompts)
}
// Configure configures c according to bindHost, which is the host (not
// whole address) to bind the listener to in solving the http and tls-sni
// challenges.
func (c *ACMEClient) Configure(bindHost string) {
// If we allow prompts, operator must be present. In our case,
// that is synonymous with saying the server is not already
// started. So if the user is still there, we don't use
// AlternatePort because we don't need to proxy the challenges.
// Conversely, if the operator is not there, the server has
// already started and we need to proxy the challenge.
if c.AllowPrompts {
// Operator is present; server is not already listening
c.SetHTTPAddress(net.JoinHostPort(bindHost, ""))
c.SetTLSAddress(net.JoinHostPort(bindHost, ""))
//c.ExcludeChallenges([]acme.Challenge{acme.DNS01})
} else {
// Operator is not present; server is started, so proxy challenges
c.SetHTTPAddress(net.JoinHostPort(bindHost, AlternatePort))
c.SetTLSAddress(net.JoinHostPort(bindHost, AlternatePort))
//c.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01, acme.DNS01})
}
c.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01, acme.DNS01}) // TODO: can we proxy TLS challenges? and we should support DNS...
}
// Obtain obtains a single certificate for names. It stores the certificate
// on the disk if successful.
func (c *ACMEClient) Obtain(names []string) error {
Attempts:
for attempts := 0; attempts < 2; attempts++ {
acmeMu.Lock()
certificate, failures := c.ObtainCertificate(names, true, nil)
acmeMu.Unlock()
if len(failures) > 0 {
// Error - try to fix it or report it to the user and abort
var errMsg string // we'll combine all the failures into a single error message
var promptedForAgreement bool // only prompt user for agreement at most once
for errDomain, obtainErr := range failures {
// TODO: Double-check, will obtainErr ever be nil?
if tosErr, ok := obtainErr.(acme.TOSError); ok {
// Terms of Service agreement error; we can probably deal with this
if !Agreed && !promptedForAgreement && c.AllowPrompts {
Agreed = promptUserAgreement(tosErr.Detail, true) // TODO: Use latest URL
promptedForAgreement = true
}
if Agreed || !c.AllowPrompts {
err := c.AgreeToTOS()
if err != nil {
return errors.New("error agreeing to updated terms: " + err.Error())
}
continue Attempts
}
}
// If user did not agree or it was any other kind of error, just append to the list of errors
errMsg += "[" + errDomain + "] failed to get certificate: " + obtainErr.Error() + "\n"
}
return errors.New(errMsg)
}
// Success - immediately save the certificate resource
err := saveCertResource(certificate)
if err != nil {
return fmt.Errorf("error saving assets for %v: %v", names, err)
}
break
}
return nil
}
// Renew renews the managed certificate for name. Right now our storage
// mechanism only supports one name per certificate, so this function only
// accepts one domain as input. It can be easily modified to support SAN
// certificates if, one day, they become desperately needed enough that our
// storage mechanism is upgraded to be more complex to support SAN certs.
//
// Anyway, this function is safe for concurrent use.
func (c *ACMEClient) Renew(name string) error {
// Prepare for renewal (load PEM cert, key, and meta)
certBytes, err := ioutil.ReadFile(storage.SiteCertFile(name))
if err != nil {
return err
}
keyBytes, err := ioutil.ReadFile(storage.SiteKeyFile(name))
if err != nil {
return err
}
metaBytes, err := ioutil.ReadFile(storage.SiteMetaFile(name))
if err != nil {
return err
}
var certMeta acme.CertificateResource
err = json.Unmarshal(metaBytes, &certMeta)
certMeta.Certificate = certBytes
certMeta.PrivateKey = keyBytes
// Perform renewal and retry if necessary, but not too many times.
var newCertMeta acme.CertificateResource
var success bool
for attempts := 0; attempts < 2; attempts++ {
acmeMu.Lock()
newCertMeta, err = c.RenewCertificate(certMeta, true)
acmeMu.Unlock()
if err == nil {
success = true
break
}
// If the legal terms changed and need to be agreed to again,
// we can handle that.
if _, ok := err.(acme.TOSError); ok {
err := c.AgreeToTOS()
if err != nil {
return err
}
continue
}
// For any other kind of error, wait 10s and try again.
time.Sleep(10 * time.Second)
}
if !success {
return errors.New("too many renewal attempts; last error: " + err.Error())
}
return saveCertResource(newCertMeta)
}
package https
import (
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"io/ioutil"
"os"
)
// loadPrivateKey loads a PEM-encoded ECC/RSA private key from file.
func loadPrivateKey(file string) (crypto.PrivateKey, error) {
keyBytes, err := ioutil.ReadFile(file)
if err != nil {
return nil, err
}
keyBlock, _ := pem.Decode(keyBytes)
switch keyBlock.Type {
case "RSA PRIVATE KEY":
return x509.ParsePKCS1PrivateKey(keyBlock.Bytes)
case "EC PRIVATE KEY":
return x509.ParseECPrivateKey(keyBlock.Bytes)
}
return nil, errors.New("unknown private key type")
}
// savePrivateKey saves a PEM-encoded ECC/RSA private key to file.
func savePrivateKey(key crypto.PrivateKey, file string) error {
var pemType string
var keyBytes []byte
switch key := key.(type) {
case *ecdsa.PrivateKey:
var err error
pemType = "EC"
keyBytes, err = x509.MarshalECPrivateKey(key)
if err != nil {
return err
}
case *rsa.PrivateKey:
pemType = "RSA"
keyBytes = x509.MarshalPKCS1PrivateKey(key)
}
pemKey := pem.Block{Type: pemType + " PRIVATE KEY", Bytes: keyBytes}
keyOut, err := os.Create(file)
if err != nil {
return err
}
keyOut.Chmod(0600)
defer keyOut.Close()
return pem.Encode(keyOut, &pemKey)
}
package https
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"os"
"runtime"
"testing"
)
func TestSaveAndLoadRSAPrivateKey(t *testing.T) {
keyFile := "test.key"
defer os.Remove(keyFile)
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
// test save
err = savePrivateKey(privateKey, keyFile)
if err != nil {
t.Fatal("error saving private key:", err)
}
// it doesn't make sense to test file permission on windows
if runtime.GOOS != "windows" {
// get info of the key file
info, err := os.Stat(keyFile)
if err != nil {
t.Fatal("error stating private key:", err)
}
// verify permission of key file is correct
if info.Mode().Perm() != 0600 {
t.Error("Expected key file to have permission 0600, but it wasn't")
}
}
// test load
loadedKey, err := loadPrivateKey(keyFile)
if err != nil {
t.Error("error loading private key:", err)
}
// verify loaded key is correct
if !PrivateKeysSame(privateKey, loadedKey) {
t.Error("Expected key bytes to be the same, but they weren't")
}
}
func TestSaveAndLoadECCPrivateKey(t *testing.T) {
keyFile := "test.key"
defer os.Remove(keyFile)
privateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
if err != nil {
t.Fatal(err)
}
// test save
err = savePrivateKey(privateKey, keyFile)
if err != nil {
t.Fatal("error saving private key:", err)
}
// it doesn't make sense to test file permission on windows
if runtime.GOOS != "windows" {
// get info of the key file
info, err := os.Stat(keyFile)
if err != nil {
t.Fatal("error stating private key:", err)
}
// verify permission of key file is correct
if info.Mode().Perm() != 0600 {
t.Error("Expected key file to have permission 0600, but it wasn't")
}
}
// test load
loadedKey, err := loadPrivateKey(keyFile)
if err != nil {
t.Error("error loading private key:", err)
}
// verify loaded key is correct
if !PrivateKeysSame(privateKey, loadedKey) {
t.Error("Expected key bytes to be the same, but they weren't")
}
}
// PrivateKeysSame compares the bytes of a and b and returns true if they are the same.
func PrivateKeysSame(a, b crypto.PrivateKey) bool {
return bytes.Equal(PrivateKeyBytes(a), PrivateKeyBytes(b))
}
// PrivateKeyBytes returns the bytes of DER-encoded key.
func PrivateKeyBytes(key crypto.PrivateKey) []byte {
var keyBytes []byte
switch key := key.(type) {
case *rsa.PrivateKey:
keyBytes = x509.MarshalPKCS1PrivateKey(key)
case *ecdsa.PrivateKey:
keyBytes, _ = x509.MarshalECPrivateKey(key)
}
return keyBytes
}
package https
import (
"crypto/tls"
"log"
"net/http"
"net/http/httputil"
"net/url"
"strings"
)
const challengeBasePath = "/.well-known/acme-challenge"
// RequestCallback proxies challenge requests to ACME client if the
// request path starts with challengeBasePath. It returns true if it
// handled the request and no more needs to be done; it returns false
// if this call was a no-op and the request still needs handling.
func RequestCallback(w http.ResponseWriter, r *http.Request) bool {
if strings.HasPrefix(r.URL.Path, challengeBasePath) {
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
upstream, err := url.Parse(scheme + "://localhost:" + AlternatePort)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
log.Printf("[ERROR] ACME proxy handler: %v", err)
return true
}
proxy := httputil.NewSingleHostReverseProxy(upstream)
proxy.Transport = &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // solver uses self-signed certs
}
proxy.ServeHTTP(w, r)
return true
}
return false
}
package https
import (
"net"
"net/http"
"net/http/httptest"
"testing"
)
func TestRequestCallbackNoOp(t *testing.T) {
// try base paths that aren't handled by this handler
for _, url := range []string{
"http://localhost/",
"http://localhost/foo.html",
"http://localhost/.git",
"http://localhost/.well-known/",
"http://localhost/.well-known/acme-challenging",
} {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
t.Fatalf("Could not craft request, got error: %v", err)
}
rw := httptest.NewRecorder()
if RequestCallback(rw, req) {
t.Errorf("Got true with this URL, but shouldn't have: %s", url)
}
}
}
func TestRequestCallbackSuccess(t *testing.T) {
expectedPath := challengeBasePath + "/asdf"
// Set up fake acme handler backend to make sure proxying succeeds
var proxySuccess bool
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proxySuccess = true
if r.URL.Path != expectedPath {
t.Errorf("Expected path '%s' but got '%s' instead", expectedPath, r.URL.Path)
}
}))
// Custom listener that uses the port we expect
ln, err := net.Listen("tcp", "127.0.0.1:"+AlternatePort)
if err != nil {
t.Fatalf("Unable to start test server listener: %v", err)
}
ts.Listener = ln
// Start our engines and run the test
ts.Start()
defer ts.Close()
req, err := http.NewRequest("GET", "http://127.0.0.1:"+AlternatePort+expectedPath, nil)
if err != nil {
t.Fatalf("Could not craft request, got error: %v", err)
}
rw := httptest.NewRecorder()
RequestCallback(rw, req)
if !proxySuccess {
t.Fatal("Expected request to be proxied, but it wasn't")
}
}
This diff is collapsed.
package https
import (
"crypto/tls"
"crypto/x509"
"testing"
)
func TestGetCertificate(t *testing.T) {
defer func() { certCache = make(map[string]Certificate) }()
hello := &tls.ClientHelloInfo{ServerName: "example.com"}
helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"}
helloNoSNI := &tls.ClientHelloInfo{}
helloNoMatch := &tls.ClientHelloInfo{ServerName: "nomatch"}
// When cache is empty
if cert, err := GetCertificate(hello); err == nil {
t.Errorf("GetCertificate should return error when cache is empty, got: %v", cert)
}
if cert, err := GetCertificate(helloNoSNI); err == nil {
t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert)
}
// When cache has one certificate in it (also is default)
defaultCert := Certificate{Names: []string{"example.com", ""}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}}
certCache[""] = defaultCert
certCache["example.com"] = defaultCert
if cert, err := GetCertificate(hello); err != nil {
t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err)
} else if cert.Leaf.DNSNames[0] != "example.com" {
t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert)
}
if cert, err := GetCertificate(helloNoSNI); err != nil {
t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err)
} else if cert.Leaf.DNSNames[0] != "example.com" {
t.Errorf("Got wrong certificate for no SNI; expected 'example.com' as default, got: %v", cert)
}
// When retrieving wildcard certificate
certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}}}
if cert, err := GetCertificate(helloSub); err != nil {
t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err)
} else if cert.Leaf.DNSNames[0] != "*.example.com" {
t.Errorf("Got wrong certificate, expected wildcard: %v", cert)
}
// When no certificate matches, the default is returned
if cert, err := GetCertificate(helloNoMatch); err != nil {
t.Errorf("Expected default certificate with no error when no matches, got err: %v", err)
} else if cert.Leaf.DNSNames[0] != "example.com" {
t.Errorf("Expected default cert with no matches, got: %v", cert)
}
}
This diff is collapsed.
This diff is collapsed.
package https
import (
"log"
"time"
"github.com/miekg/coredns/server"
"golang.org/x/crypto/ocsp"
)
const (
// RenewInterval is how often to check certificates for renewal.
RenewInterval = 12 * time.Hour
// OCSPInterval is how often to check if OCSP stapling needs updating.
OCSPInterval = 1 * time.Hour
)
// maintainAssets is a permanently-blocking function
// that loops indefinitely and, on a regular schedule, checks
// certificates for expiration and initiates a renewal of certs
// that are expiring soon. It also updates OCSP stapling and
// performs other maintenance of assets.
//
// You must pass in the channel which you'll close when
// maintenance should stop, to allow this goroutine to clean up
// after itself and unblock.
func maintainAssets(stopChan chan struct{}) {
renewalTicker := time.NewTicker(RenewInterval)
ocspTicker := time.NewTicker(OCSPInterval)
for {
select {
case <-renewalTicker.C:
log.Println("[INFO] Scanning for expiring certificates")
renewManagedCertificates(false)
log.Println("[INFO] Done checking certificates")
case <-ocspTicker.C:
log.Println("[INFO] Scanning for stale OCSP staples")
updateOCSPStaples()
log.Println("[INFO] Done checking OCSP staples")
case <-stopChan:
renewalTicker.Stop()
ocspTicker.Stop()
log.Println("[INFO] Stopped background maintenance routine")
return
}
}
}
func renewManagedCertificates(allowPrompts bool) (err error) {
var renewed, deleted []Certificate
var client *ACMEClient
visitedNames := make(map[string]struct{})
certCacheMu.RLock()
for name, cert := range certCache {
if !cert.Managed {
continue
}
// the list of names on this cert should never be empty...
if cert.Names == nil || len(cert.Names) == 0 {
log.Printf("[WARNING] Certificate keyed by '%s' has no names: %v", name, cert.Names)
deleted = append(deleted, cert)
continue
}
// skip names whose certificate we've already renewed
if _, ok := visitedNames[name]; ok {
continue
}
for _, name := range cert.Names {
visitedNames[name] = struct{}{}
}
timeLeft := cert.NotAfter.Sub(time.Now().UTC())
if timeLeft < renewDurationBefore {
log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft)
if client == nil {
client, err = NewACMEClientGetEmail(server.Config{}, allowPrompts)
if err != nil {
return err
}
client.Configure("") // TODO: Bind address of relevant listener, yuck
}
err := client.Renew(cert.Names[0]) // managed certs better have only one name
if err != nil {
if client.AllowPrompts && timeLeft < 0 {
// Certificate renewal failed, the operator is present, and the certificate
// is already expired; we should stop immediately and return the error. Note
// that we used to do this any time a renewal failed at startup. However,
// after discussion in https://github.com/miekg/coredns/issues/642 we decided to
// only stop startup if the certificate is expired. We still log the error
// otherwise.
certCacheMu.RUnlock()
return err
}
log.Printf("[ERROR] %v", err)
if cert.OnDemand {
deleted = append(deleted, cert)
}
} else {
renewed = append(renewed, cert)
}
}
}
certCacheMu.RUnlock()
// Apply changes to the cache
for _, cert := range renewed {
_, err := cacheManagedCertificate(cert.Names[0], cert.OnDemand)
if err != nil {
if client.AllowPrompts {
return err // operator is present, so report error immediately
}
log.Printf("[ERROR] %v", err)
}
}
for _, cert := range deleted {
certCacheMu.Lock()
for _, name := range cert.Names {
delete(certCache, name)
}
certCacheMu.Unlock()
}
return nil
}
func updateOCSPStaples() {
// Create a temporary place to store updates
// until we release the potentially long-lived
// read lock and use a short-lived write lock.
type ocspUpdate struct {
rawBytes []byte
parsed *ocsp.Response
}
updated := make(map[string]ocspUpdate)
// A single SAN certificate maps to multiple names, so we use this
// set to make sure we don't waste cycles checking OCSP for the same
// certificate multiple times.
visited := make(map[string]struct{})
certCacheMu.RLock()
for name, cert := range certCache {
// skip this certificate if we've already visited it,
// and if not, mark all the names as visited
if _, ok := visited[name]; ok {
continue
}
for _, n := range cert.Names {
visited[n] = struct{}{}
}
// no point in updating OCSP for expired certificates
if time.Now().After(cert.NotAfter) {
continue
}
var lastNextUpdate time.Time
if cert.OCSP != nil {
// start checking OCSP staple about halfway through validity period for good measure
lastNextUpdate = cert.OCSP.NextUpdate
refreshTime := cert.OCSP.ThisUpdate.Add(lastNextUpdate.Sub(cert.OCSP.ThisUpdate) / 2)
// since OCSP is already stapled, we need only check if we're in that "refresh window"
if time.Now().Before(refreshTime) {
continue
}
}
err := stapleOCSP(&cert, nil)
if err != nil {
if cert.OCSP != nil {
// if it was no staple before, that's fine, otherwise we should log the error
log.Printf("[ERROR] Checking OCSP for %s: %v", name, err)
}
continue
}
// By this point, we've obtained the latest OCSP response.
// If there was no staple before, or if the response is updated, make
// sure we apply the update to all names on the certificate.
if lastNextUpdate.IsZero() || lastNextUpdate != cert.OCSP.NextUpdate {
log.Printf("[INFO] Advancing OCSP staple for %v from %s to %s",
cert.Names, lastNextUpdate, cert.OCSP.NextUpdate)
for _, n := range cert.Names {
updated[n] = ocspUpdate{rawBytes: cert.Certificate.OCSPStaple, parsed: cert.OCSP}
}
}
}
certCacheMu.RUnlock()
// This write lock should be brief since we have all the info we need now.
certCacheMu.Lock()
for name, update := range updated {
cert := certCache[name]
cert.OCSP = update.parsed
cert.Certificate.OCSPStaple = update.rawBytes
certCache[name] = cert
}
certCacheMu.Unlock()
}
// renewDurationBefore is how long before expiration to renew certificates.
const renewDurationBefore = (24 * time.Hour) * 30
This diff is collapsed.
package https
// TODO(miek): all fail
/*
func TestMain(m *testing.M) {
// Write test certificates to disk before tests, and clean up
// when we're done.
err := ioutil.WriteFile(certFile, testCert, 0644)
if err != nil {
log.Fatal(err)
}
err = ioutil.WriteFile(keyFile, testKey, 0644)
if err != nil {
os.Remove(certFile)
log.Fatal(err)
}
result := m.Run()
os.Remove(certFile)
os.Remove(keyFile)
os.Exit(result)
}
func TestSetupParseBasic(t *testing.T) {
c := setup.NewTestController(`tls ` + certFile + ` ` + keyFile + ``)
_, err := Setup(c)
if err != nil {
t.Errorf("Expected no errors, got: %v", err)
}
// Basic checks
if !c.TLS.Manual {
t.Error("Expected TLS Manual=true, but was false")
}
if !c.TLS.Enabled {
t.Error("Expected TLS Enabled=true, but was false")
}
// Security defaults
if c.TLS.ProtocolMinVersion != tls.VersionTLS10 {
t.Errorf("Expected 'tls1.0 (0x0301)' as ProtocolMinVersion, got %#v", c.TLS.ProtocolMinVersion)
}
if c.TLS.ProtocolMaxVersion != tls.VersionTLS12 {
t.Errorf("Expected 'tls1.2 (0x0303)' as ProtocolMaxVersion, got %v", c.TLS.ProtocolMaxVersion)
}
// Cipher checks
expectedCiphers := []uint16{
tls.TLS_FALLBACK_SCSV,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
tls.TLS_RSA_WITH_AES_256_CBC_SHA,
tls.TLS_RSA_WITH_AES_128_CBC_SHA,
}
// Ensure count is correct (plus one for TLS_FALLBACK_SCSV)
if len(c.TLS.Ciphers) != len(expectedCiphers) {
t.Errorf("Expected %v Ciphers (including TLS_FALLBACK_SCSV), got %v",
len(expectedCiphers), len(c.TLS.Ciphers))
}
// Ensure ordering is correct
for i, actual := range c.TLS.Ciphers {
if actual != expectedCiphers[i] {
t.Errorf("Expected cipher in position %d to be %0x, got %0x", i, expectedCiphers[i], actual)
}
}
if !c.TLS.PreferServerCipherSuites {
t.Error("Expected PreferServerCipherSuites = true, but was false")
}
}
func TestSetupParseIncompleteParams(t *testing.T) {
// Using tls without args is an error because it's unnecessary.
c := setup.NewTestController(`tls`)
_, err := Setup(c)
if err == nil {
t.Error("Expected an error, but didn't get one")
}
}
func TestSetupParseWithOptionalParams(t *testing.T) {
params := `tls ` + certFile + ` ` + keyFile + ` {
protocols ssl3.0 tls1.2
ciphers RSA-AES256-CBC-SHA ECDHE-RSA-AES128-GCM-SHA256 ECDHE-ECDSA-AES256-GCM-SHA384
}`
c := setup.NewTestController(params)
_, err := Setup(c)
if err != nil {
t.Errorf("Expected no errors, got: %v", err)
}
if c.TLS.ProtocolMinVersion != tls.VersionSSL30 {
t.Errorf("Expected 'ssl3.0 (0x0300)' as ProtocolMinVersion, got %#v", c.TLS.ProtocolMinVersion)
}
if c.TLS.ProtocolMaxVersion != tls.VersionTLS12 {
t.Errorf("Expected 'tls1.2 (0x0302)' as ProtocolMaxVersion, got %#v", c.TLS.ProtocolMaxVersion)
}
if len(c.TLS.Ciphers)-1 != 3 {
t.Errorf("Expected 3 Ciphers (not including TLS_FALLBACK_SCSV), got %v", len(c.TLS.Ciphers)-1)
}
}
func TestSetupDefaultWithOptionalParams(t *testing.T) {
params := `tls {
ciphers RSA-3DES-EDE-CBC-SHA
}`
c := setup.NewTestController(params)
_, err := Setup(c)
if err != nil {
t.Errorf("Expected no errors, got: %v", err)
}
if len(c.TLS.Ciphers)-1 != 1 {
t.Errorf("Expected 1 ciphers (not including TLS_FALLBACK_SCSV), got %v", len(c.TLS.Ciphers)-1)
}
}
// TODO: If we allow this... but probably not a good idea.
// func TestSetupDisableHTTPRedirect(t *testing.T) {
// c := NewTestController(`tls {
// allow_http
// }`)
// _, err := TLS(c)
// if err != nil {
// t.Errorf("Expected no error, but got %v", err)
// }
// if !c.TLS.DisableHTTPRedir {
// t.Error("Expected HTTP redirect to be disabled, but it wasn't")
// }
// }
func TestSetupParseWithWrongOptionalParams(t *testing.T) {
// Test protocols wrong params
params := `tls ` + certFile + ` ` + keyFile + ` {
protocols ssl tls
}`
c := setup.NewTestController(params)
_, err := Setup(c)
if err == nil {
t.Errorf("Expected errors, but no error returned")
}
// Test ciphers wrong params
params = `tls ` + certFile + ` ` + keyFile + ` {
ciphers not-valid-cipher
}`
c = setup.NewTestController(params)
_, err = Setup(c)
if err == nil {
t.Errorf("Expected errors, but no error returned")
}
}
func TestSetupParseWithClientAuth(t *testing.T) {
params := `tls ` + certFile + ` ` + keyFile + ` {
clients client_ca.crt client2_ca.crt
}`
c := setup.NewTestController(params)
_, err := Setup(c)
if err != nil {
t.Errorf("Expected no errors, got: %v", err)
}
if count := len(c.TLS.ClientCerts); count != 2 {
t.Fatalf("Expected two client certs, had %d", count)
}
if actual := c.TLS.ClientCerts[0]; actual != "client_ca.crt" {
t.Errorf("Expected first client cert file to be '%s', but was '%s'", "client_ca.crt", actual)
}
if actual := c.TLS.ClientCerts[1]; actual != "client2_ca.crt" {
t.Errorf("Expected second client cert file to be '%s', but was '%s'", "client2_ca.crt", actual)
}
// Test missing client cert file
params = `tls ` + certFile + ` ` + keyFile + ` {
clients
}`
c = setup.NewTestController(params)
_, err = Setup(c)
if err == nil {
t.Errorf("Expected an error, but no error returned")
}
}
const (
certFile = "test_cert.pem"
keyFile = "test_key.pem"
)
var testCert = []byte(`-----BEGIN CERTIFICATE-----
MIIBkjCCATmgAwIBAgIJANfFCBcABL6LMAkGByqGSM49BAEwFDESMBAGA1UEAxMJ
bG9jYWxob3N0MB4XDTE2MDIxMDIyMjAyNFoXDTE4MDIwOTIyMjAyNFowFDESMBAG
A1UEAxMJbG9jYWxob3N0MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEs22MtnG7
9K1mvIyjEO9GLx7BFD0tBbGnwQ0VPsuCxC6IeVuXbQDLSiVQvFZ6lUszTlczNxVk
pEfqrM6xAupB7qN1MHMwHQYDVR0OBBYEFHxYDvAxUwL4XrjPev6qZ/BiLDs5MEQG
A1UdIwQ9MDuAFHxYDvAxUwL4XrjPev6qZ/BiLDs5oRikFjAUMRIwEAYDVQQDEwls
b2NhbGhvc3SCCQDXxQgXAAS+izAMBgNVHRMEBTADAQH/MAkGByqGSM49BAEDSAAw
RQIgRvBqbyJM2JCJqhA1FmcoZjeMocmhxQHTt1c+1N2wFUgCIQDtvrivbBPA688N
Qh3sMeAKNKPsx5NxYdoWuu9KWcKz9A==
-----END CERTIFICATE-----
`)
var testKey = []byte(`-----BEGIN EC PARAMETERS-----
BggqhkjOPQMBBw==
-----END EC PARAMETERS-----
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIGLtRmwzYVcrH3J0BnzYbGPdWVF10i9p6mxkA4+b2fURoAoGCCqGSM49
AwEHoUQDQgAEs22MtnG79K1mvIyjEO9GLx7BFD0tBbGnwQ0VPsuCxC6IeVuXbQDL
SiVQvFZ6lUszTlczNxVkpEfqrM6xAupB7g==
-----END EC PRIVATE KEY-----
`)
*/
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
glob0.host0 {
dir2 arg1
}
glob0.host1 {
}
glob1.host0 {
dir1
dir2 arg1
}
glob2.host0 {
dir2 arg1
}
dir2 arg1 arg2
dir3
\ No newline at end of file
host1 {
dir1
dir2 arg1
}
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
// Package parse provides facilities for parsing configuration files.
package parse
import "io"
// ServerBlocks parses the input just enough to organize tokens,
// in order, by server block. No further parsing is performed.
// If checkDirectives is true, only valid directives will be allowed
// otherwise we consider it a parse error. Server blocks are returned
// in the order in which they appear.
func ServerBlocks(filename string, input io.Reader, checkDirectives bool) ([]ServerBlock, error) {
p := parser{Dispenser: NewDispenser(filename, input)}
p.checkDirectives = checkDirectives
blocks, err := p.parseAll()
return blocks, err
}
// allTokens lexes the entire input, but does not parse it.
// It returns all the tokens from the input, unstructured
// and in order.
func allTokens(input io.Reader) (tokens []token) {
l := new(lexer)
l.load(input)
for l.next() {
tokens = append(tokens, l.token)
}
return
}
// ValidDirectives is a set of directives that are valid (unordered). Populated
// by config package's init function.
var ValidDirectives = make(map[string]struct{})
package parse
import (
"strings"
"testing"
)
func TestAllTokens(t *testing.T) {
input := strings.NewReader("a b c\nd e")
expected := []string{"a", "b", "c", "d", "e"}
tokens := allTokens(input)
if len(tokens) != len(expected) {
t.Fatalf("Expected %d tokens, got %d", len(expected), len(tokens))
}
for i, val := range expected {
if tokens[i].text != val {
t.Errorf("Token %d should be '%s' but was '%s'", i, val, tokens[i].text)
}
}
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
package core
func trapSignalsPosix() {}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment