Commit b285bcd6 authored by ginuerzh's avatar ginuerzh

support mutual TLS authentication

parent 60d7e011
...@@ -44,8 +44,9 @@ var ( ...@@ -44,8 +44,9 @@ var (
defaultKeyFile = "key.pem" defaultKeyFile = "key.pem"
) )
// Load the certificate from cert and key files, will use the default certificate if the provided info are invalid. // Load the certificate from cert & key files and optional client CA file,
func tlsConfig(certFile, keyFile string) (*tls.Config, error) { // will use the default certificate if the provided info are invalid.
func tlsConfig(certFile, keyFile, caFile string) (*tls.Config, error) {
if certFile == "" || keyFile == "" { if certFile == "" || keyFile == "" {
certFile, keyFile = defaultCertFile, defaultKeyFile certFile, keyFile = defaultCertFile, defaultKeyFile
} }
...@@ -54,7 +55,15 @@ func tlsConfig(certFile, keyFile string) (*tls.Config, error) { ...@@ -54,7 +55,15 @@ func tlsConfig(certFile, keyFile string) (*tls.Config, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &tls.Config{Certificates: []tls.Certificate{cert}}, nil
cfg := &tls.Config{Certificates: []tls.Certificate{cert}}
if pool, _ := loadCA(caFile); pool != nil {
cfg.ClientCAs = pool
cfg.ClientAuth = tls.RequireAndVerifyClientCert
}
return cfg, nil
} }
func loadCA(caFile string) (cp *x509.CertPool, err error) { func loadCA(caFile string) (cp *x509.CertPool, err error) {
......
...@@ -67,7 +67,7 @@ func main() { ...@@ -67,7 +67,7 @@ func main() {
} }
// NOTE: as of 2.6, you can use custom cert/key files to initialize the default certificate. // NOTE: as of 2.6, you can use custom cert/key files to initialize the default certificate.
tlsConfig, err := tlsConfig(defaultCertFile, defaultKeyFile) tlsConfig, err := tlsConfig(defaultCertFile, defaultKeyFile, "")
if err != nil { if err != nil {
// generate random self-signed certificate. // generate random self-signed certificate.
cert, err := gost.GenCertificate() cert, err := gost.GenCertificate()
......
...@@ -128,6 +128,10 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { ...@@ -128,6 +128,10 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
InsecureSkipVerify: !node.GetBool("secure"), InsecureSkipVerify: !node.GetBool("secure"),
RootCAs: rootCAs, RootCAs: rootCAs,
} }
if cert, err := tls.LoadX509KeyPair(node.Get("cert"), node.Get("key")); err == nil {
tlsCfg.Certificates = []tls.Certificate{cert}
}
wsOpts := &gost.WSOptions{} wsOpts := &gost.WSOptions{}
wsOpts.EnableCompression = node.GetBool("compression") wsOpts.EnableCompression = node.GetBool("compression")
wsOpts.ReadBufferSize = node.GetInt("rbuf") wsOpts.ReadBufferSize = node.GetInt("rbuf")
...@@ -343,7 +347,7 @@ func (r *route) GenRouters() ([]router, error) { ...@@ -343,7 +347,7 @@ func (r *route) GenRouters() ([]router, error) {
} }
} }
certFile, keyFile := node.Get("cert"), node.Get("key") certFile, keyFile := node.Get("cert"), node.Get("key")
tlsCfg, err := tlsConfig(certFile, keyFile) tlsCfg, err := tlsConfig(certFile, keyFile, node.Get("ca"))
if err != nil && certFile != "" && keyFile != "" { if err != nil && certFile != "" && keyFile != "" {
return nil, err return nil, err
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment