Commit dedd0853 authored by rui.zheng's avatar rui.zheng

#141: Add load balancing support for proxy chain

parent 18bb8ab2
...@@ -95,12 +95,17 @@ func (c *Chain) Dial(addr string) (net.Conn, error) { ...@@ -95,12 +95,17 @@ func (c *Chain) Dial(addr string) (net.Conn, error) {
return net.Dial("tcp", addr) return net.Dial("tcp", addr)
} }
conn, nodes, err := c.getConn() route, err := c.selectRoute()
if err != nil { if err != nil {
return nil, err return nil, err
} }
cc, err := nodes[len(nodes)-1].Client.Connect(conn, addr) conn, err := c.getConn(route)
if err != nil {
return nil, err
}
cc, err := route.LastNode().Client.Connect(conn, addr)
if err != nil { if err != nil {
conn.Close() conn.Close()
return nil, err return nil, err
...@@ -111,26 +116,44 @@ func (c *Chain) Dial(addr string) (net.Conn, error) { ...@@ -111,26 +116,44 @@ func (c *Chain) Dial(addr string) (net.Conn, error) {
// Conn obtains a handshaked connection to the last node of the chain. // Conn obtains a handshaked connection to the last node of the chain.
// If the chain is empty, it returns an ErrEmptyChain error. // If the chain is empty, it returns an ErrEmptyChain error.
func (c *Chain) Conn() (conn net.Conn, err error) { func (c *Chain) Conn() (conn net.Conn, err error) {
conn, _, err = c.getConn() route, err := c.selectRoute()
if err != nil {
return nil, err
}
conn, err = c.getConn(route)
return return
} }
func (c *Chain) getConn() (conn net.Conn, nodes []Node, err error) { func (c *Chain) selectRoute() (route *Chain, err error) {
if c.IsEmpty() { route = NewChain()
err = ErrEmptyChain for _, group := range c.nodeGroups {
return selector := group.Selector
} if selector == nil {
groups := c.nodeGroups selector = &defaultSelector{}
selector := groups[0].Selector }
if selector == nil { // select node from node group
selector = &defaultSelector{} node, err := selector.Select(group.Nodes(), group.Options...)
if err != nil {
return nil, err
}
if node.Client.Transporter.Multiplex() {
node.DialOptions = append(node.DialOptions,
ChainDialOption(route),
)
route = NewChain() // cutoff the chain for multiplex
}
route.AddNode(node)
} }
// select node from node group return
node, err := selector.Select(groups[0].Nodes(), groups[0].Options...) }
if err != nil {
func (c *Chain) getConn(route *Chain) (conn net.Conn, err error) {
if route.IsEmpty() {
err = ErrEmptyChain
return return
} }
nodes = append(nodes, node) nodes := route.Nodes()
node := nodes[0]
addr, err := selectIP(&node) addr, err := selectIP(&node)
if err != nil { if err != nil {
...@@ -147,21 +170,7 @@ func (c *Chain) getConn() (conn net.Conn, nodes []Node, err error) { ...@@ -147,21 +170,7 @@ func (c *Chain) getConn() (conn net.Conn, nodes []Node, err error) {
} }
preNode := node preNode := node
for i := range groups { for _, node := range nodes[1:] {
if i == len(groups)-1 {
break
}
selector = groups[i+1].Selector
if selector == nil {
selector = &defaultSelector{}
}
node, err = selector.Select(groups[i+1].Nodes(), groups[i+1].Options...)
if err != nil {
cn.Close()
return
}
nodes = append(nodes, node)
addr, err = selectIP(&node) addr, err = selectIP(&node)
if err != nil { if err != nil {
return return
...@@ -206,6 +215,7 @@ func selectIP(node *Node) (string, error) { ...@@ -206,6 +215,7 @@ func selectIP(node *Node) (string, error) {
ip = ip + ":" + sport ip = ip + ":" + sport
} }
addr = ip addr = ip
// override the original address
node.HandshakeOptions = append(node.HandshakeOptions, AddrHandshakeOption(addr)) node.HandshakeOptions = append(node.HandshakeOptions, AddrHandshakeOption(addr))
} }
log.Log("select IP:", node.Addr, node.IPs, addr) log.Log("select IP:", node.Addr, node.IPs, addr)
......
...@@ -94,7 +94,6 @@ func (tr *tcpTransporter) Multiplex() bool { ...@@ -94,7 +94,6 @@ func (tr *tcpTransporter) Multiplex() bool {
type DialOptions struct { type DialOptions struct {
Timeout time.Duration Timeout time.Duration
Chain *Chain Chain *Chain
// IPs []string
} }
// DialOption allows a common way to set dial options. // DialOption allows a common way to set dial options.
......
...@@ -62,6 +62,16 @@ func init() { ...@@ -62,6 +62,16 @@ func init() {
} }
func main() { func main() {
// generate random self-signed certificate.
cert, err := gost.GenCertificate()
if err != nil {
log.Log(err)
os.Exit(1)
}
gost.DefaultTLSConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
}
chain, err := initChain() chain, err := initChain()
if err != nil { if err != nil {
log.Log(err) log.Log(err)
...@@ -71,160 +81,188 @@ func main() { ...@@ -71,160 +81,188 @@ func main() {
log.Log(err) log.Log(err)
os.Exit(1) os.Exit(1)
} }
select {} select {}
} }
func initChain() (*gost.Chain, error) { func initChain() (*gost.Chain, error) {
chain := gost.NewChain() chain := gost.NewChain()
for _, ns := range options.ChainNodes { for _, ns := range options.ChainNodes {
node, err := gost.ParseNode(ns) // parse the base node
node, err := parseChainNode(ns)
if err != nil { if err != nil {
return nil, err return nil, err
} }
node.IPs = parseIP(node.Values.Get("ip")) ngroup := gost.NewNodeGroup(node)
node.IPSelector = &gost.RoundRobinIPSelector{}
users, err := parseUsers(node.Values.Get("secrets")) // parse node peers if exists
peerCfg, err := loadPeerConfig(node.Values.Get("peer"))
if err != nil { if err != nil {
return nil, err log.Log(err)
}
if node.User == nil && len(users) > 0 {
node.User = users[0]
} }
serverName, _, _ := net.SplitHostPort(node.Addr) ngroup.Options = append(ngroup.Options,
if serverName == "" { // gost.WithFilter(),
serverName = "localhost" // default server name gost.WithStrategy(parseStrategy(peerCfg.Strategy)),
)
for _, s := range peerCfg.Nodes {
node, err = parseChainNode(s)
if err != nil {
return nil, err
}
ngroup.AddNode(node)
} }
rootCAs, err := loadCA(node.Values.Get("ca")) chain.AddNodeGroup(ngroup)
if err != nil { }
return nil, err
}
tlsCfg := &tls.Config{
ServerName: serverName,
InsecureSkipVerify: !toBool(node.Values.Get("secure")),
RootCAs: rootCAs,
}
wsOpts := &gost.WSOptions{}
wsOpts.EnableCompression = toBool(node.Values.Get("compression"))
wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf"))
wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf"))
wsOpts.UserAgent = node.Values.Get("agent")
var tr gost.Transporter return chain, nil
switch node.Transport { }
case "tls":
tr = gost.TLSTransporter() func parseChainNode(ns string) (node gost.Node, err error) {
case "mtls": node, err = gost.ParseNode(ns)
tr = gost.MTLSTransporter() if err != nil {
case "ws": return
tr = gost.WSTransporter(wsOpts) }
case "mws":
tr = gost.MWSTransporter(wsOpts) node.IPs = parseIP(node.Values.Get("ip"))
case "wss": node.IPSelector = &gost.RoundRobinIPSelector{}
tr = gost.WSSTransporter(wsOpts)
case "mwss": users, err := parseUsers(node.Values.Get("secrets"))
tr = gost.MWSSTransporter(wsOpts) if err != nil {
case "kcp": return
}
if node.User == nil && len(users) > 0 {
node.User = users[0]
}
serverName, _, _ := net.SplitHostPort(node.Addr)
if serverName == "" {
serverName = "localhost" // default server name
}
rootCAs, err := loadCA(node.Values.Get("ca"))
if err != nil {
return
}
tlsCfg := &tls.Config{
ServerName: serverName,
InsecureSkipVerify: !toBool(node.Values.Get("secure")),
RootCAs: rootCAs,
}
wsOpts := &gost.WSOptions{}
wsOpts.EnableCompression = toBool(node.Values.Get("compression"))
wsOpts.ReadBufferSize, _ = strconv.Atoi(node.Values.Get("rbuf"))
wsOpts.WriteBufferSize, _ = strconv.Atoi(node.Values.Get("wbuf"))
wsOpts.UserAgent = node.Values.Get("agent")
var tr gost.Transporter
switch node.Transport {
case "tls":
tr = gost.TLSTransporter()
case "mtls":
tr = gost.MTLSTransporter()
case "ws":
tr = gost.WSTransporter(wsOpts)
case "mws":
tr = gost.MWSTransporter(wsOpts)
case "wss":
tr = gost.WSSTransporter(wsOpts)
case "mwss":
tr = gost.MWSSTransporter(wsOpts)
case "kcp":
/*
if !chain.IsEmpty() { if !chain.IsEmpty() {
return nil, errors.New("KCP must be the first node in the proxy chain") return nil, errors.New("KCP must be the first node in the proxy chain")
} }
config, err := parseKCPConfig(node.Values.Get("c")) */
if err != nil { config, err := parseKCPConfig(node.Values.Get("c"))
return nil, err if err != nil {
} return node, err
tr = gost.KCPTransporter(config) }
case "ssh": tr = gost.KCPTransporter(config)
if node.Protocol == "direct" || node.Protocol == "remote" { case "ssh":
tr = gost.SSHForwardTransporter() if node.Protocol == "direct" || node.Protocol == "remote" {
} else { tr = gost.SSHForwardTransporter()
tr = gost.SSHTunnelTransporter() } else {
} tr = gost.SSHTunnelTransporter()
case "quic": }
case "quic":
/*
if !chain.IsEmpty() { if !chain.IsEmpty() {
return nil, errors.New("QUIC must be the first node in the proxy chain") return nil, errors.New("QUIC must be the first node in the proxy chain")
} }
config := &gost.QUICConfig{ */
TLSConfig: tlsCfg, config := &gost.QUICConfig{
KeepAlive: toBool(node.Values.Get("keepalive")), TLSConfig: tlsCfg,
} KeepAlive: toBool(node.Values.Get("keepalive")),
tr = gost.QUICTransporter(config)
case "http2":
tr = gost.HTTP2Transporter(tlsCfg)
case "h2":
tr = gost.H2Transporter(tlsCfg)
case "h2c":
tr = gost.H2CTransporter()
case "obfs4":
if err := gost.Obfs4Init(node, false); err != nil {
return nil, err
}
tr = gost.Obfs4Transporter()
case "ohttp":
tr = gost.ObfsHTTPTransporter()
default:
tr = gost.TCPTransporter()
} }
tr = gost.QUICTransporter(config)
if tr.Multiplex() { case "http2":
node.DialOptions = append(node.DialOptions, tr = gost.HTTP2Transporter(tlsCfg)
gost.ChainDialOption(chain), case "h2":
) tr = gost.H2Transporter(tlsCfg)
chain = gost.NewChain() // cutoff the chain for multiplex case "h2c":
tr = gost.H2CTransporter()
case "obfs4":
if err := gost.Obfs4Init(node, false); err != nil {
return node, err
} }
tr = gost.Obfs4Transporter()
case "ohttp":
tr = gost.ObfsHTTPTransporter()
default:
tr = gost.TCPTransporter()
}
var connector gost.Connector var connector gost.Connector
switch node.Protocol { switch node.Protocol {
case "http2": case "http2":
connector = gost.HTTP2Connector(node.User) connector = gost.HTTP2Connector(node.User)
case "socks", "socks5": case "socks", "socks5":
connector = gost.SOCKS5Connector(node.User) connector = gost.SOCKS5Connector(node.User)
case "socks4": case "socks4":
connector = gost.SOCKS4Connector() connector = gost.SOCKS4Connector()
case "socks4a": case "socks4a":
connector = gost.SOCKS4AConnector() connector = gost.SOCKS4AConnector()
case "ss": case "ss":
connector = gost.ShadowConnector(node.User) connector = gost.ShadowConnector(node.User)
case "direct": case "direct":
connector = gost.SSHDirectForwardConnector() connector = gost.SSHDirectForwardConnector()
case "remote": case "remote":
connector = gost.SSHRemoteForwardConnector() connector = gost.SSHRemoteForwardConnector()
case "forward": case "forward":
connector = gost.ForwardConnector() connector = gost.ForwardConnector()
case "sni": case "sni":
connector = gost.SNIConnector(node.Values.Get("host")) connector = gost.SNIConnector(node.Values.Get("host"))
case "http": case "http":
fallthrough fallthrough
default: default:
node.Protocol = "http" // default protocol is HTTP node.Protocol = "http" // default protocol is HTTP
connector = gost.HTTPConnector(node.User) connector = gost.HTTPConnector(node.User)
} }
timeout, _ := strconv.Atoi(node.Values.Get("timeout")) timeout, _ := strconv.Atoi(node.Values.Get("timeout"))
node.DialOptions = append(node.DialOptions, node.DialOptions = append(node.DialOptions,
gost.TimeoutDialOption(time.Duration(timeout)*time.Second), gost.TimeoutDialOption(time.Duration(timeout)*time.Second),
) )
interval, _ := strconv.Atoi(node.Values.Get("ping")) interval, _ := strconv.Atoi(node.Values.Get("ping"))
retry, _ := strconv.Atoi(node.Values.Get("retry")) retry, _ := strconv.Atoi(node.Values.Get("retry"))
node.HandshakeOptions = append(node.HandshakeOptions, node.HandshakeOptions = append(node.HandshakeOptions,
gost.AddrHandshakeOption(node.Addr), gost.AddrHandshakeOption(node.Addr),
gost.UserHandshakeOption(node.User), gost.UserHandshakeOption(node.User),
gost.TLSConfigHandshakeOption(tlsCfg), gost.TLSConfigHandshakeOption(tlsCfg),
gost.IntervalHandshakeOption(time.Duration(interval)*time.Second), gost.IntervalHandshakeOption(time.Duration(interval)*time.Second),
gost.TimeoutHandshakeOption(time.Duration(timeout)*time.Second), gost.TimeoutHandshakeOption(time.Duration(timeout)*time.Second),
gost.RetryHandshakeOption(retry), gost.RetryHandshakeOption(retry),
) )
node.Client = &gost.Client{ node.Client = &gost.Client{
Connector: connector, Connector: connector,
Transporter: tr, Transporter: tr,
}
chain.AddNode(node)
} }
return chain, nil return
} }
func serve(chain *gost.Chain) error { func serve(chain *gost.Chain) error {
...@@ -533,3 +571,32 @@ func parseIP(s string) (ips []string) { ...@@ -533,3 +571,32 @@ func parseIP(s string) (ips []string) {
} }
return return
} }
type peerConfig struct {
Strategy string `json:"strategy"`
Filters []string `json:"filters"`
Nodes []string `json:"nodes"`
}
func loadPeerConfig(peer string) (config peerConfig, err error) {
if peer == "" {
return
}
content, err := ioutil.ReadFile(peer)
if err != nil {
return
}
err = json.Unmarshal(content, &config)
return
}
func parseStrategy(s string) gost.Strategy {
switch s {
case "round":
return &gost.RoundStrategy{}
case "random":
fallthrough
default:
return &gost.RandomStrategy{}
}
}
...@@ -38,7 +38,7 @@ var ( ...@@ -38,7 +38,7 @@ var (
// PingTimeout is the timeout for pinging. // PingTimeout is the timeout for pinging.
PingTimeout = 30 * time.Second PingTimeout = 30 * time.Second
// PingRetries is the reties of ping. // PingRetries is the reties of ping.
PingRetries = 3 PingRetries = 1
// default udp node TTL in second for udp port forwarding. // default udp node TTL in second for udp port forwarding.
defaultTTL = 60 * time.Second defaultTTL = 60 * time.Second
) )
...@@ -51,27 +51,19 @@ var ( ...@@ -51,27 +51,19 @@ var (
DefaultUserAgent = "Chrome/60.0.3112.90" DefaultUserAgent = "Chrome/60.0.3112.90"
) )
func init() {
rawCert, rawKey, err := generateKeyPair()
if err != nil {
panic(err)
}
cert, err := tls.X509KeyPair(rawCert, rawKey)
if err != nil {
panic(err)
}
DefaultTLSConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
}
// log.DefaultLogger = &LogLogger{}
}
// SetLogger sets a new logger for internal log system // SetLogger sets a new logger for internal log system
func SetLogger(logger log.Logger) { func SetLogger(logger log.Logger) {
log.DefaultLogger = logger log.DefaultLogger = logger
} }
func GenCertificate() (cert tls.Certificate, err error) {
rawCert, rawKey, err := generateKeyPair()
if err != nil {
return
}
return tls.X509KeyPair(rawCert, rawKey)
}
func generateKeyPair() (rawCert, rawKey []byte, err error) { func generateKeyPair() (rawCert, rawKey []byte, err error) {
// Create private key and self-signed certificate // Create private key and self-signed certificate
// Adapted from https://golang.org/src/crypto/tls/generate_cert.go // Adapted from https://golang.org/src/crypto/tls/generate_cert.go
......
...@@ -194,6 +194,7 @@ func (l *quicListener) sessionLoop(session quic.Session) { ...@@ -194,6 +194,7 @@ func (l *quicListener) sessionLoop(session quic.Session) {
stream, err := session.AcceptStream() stream, err := session.AcceptStream()
if err != nil { if err != nil {
log.Log("[quic] accept stream:", err) log.Log("[quic] accept stream:", err)
session.Close(err)
return return
} }
......
...@@ -11,14 +11,10 @@ var ( ...@@ -11,14 +11,10 @@ var (
ErrNoneAvailable = errors.New("none available") ErrNoneAvailable = errors.New("none available")
) )
// SelectOption used when making a select call
type SelectOption func(*SelectOptions)
// NodeSelector as a mechanism to pick nodes and mark their status. // NodeSelector as a mechanism to pick nodes and mark their status.
type NodeSelector interface { type NodeSelector interface {
Select(nodes []Node, opts ...SelectOption) (Node, error) Select(nodes []Node, opts ...SelectOption) (Node, error)
// Mark(node Node) // Mark(node Node)
String() string
} }
type defaultSelector struct { type defaultSelector struct {
...@@ -26,35 +22,70 @@ type defaultSelector struct { ...@@ -26,35 +22,70 @@ type defaultSelector struct {
func (s *defaultSelector) Select(nodes []Node, opts ...SelectOption) (Node, error) { func (s *defaultSelector) Select(nodes []Node, opts ...SelectOption) (Node, error) {
sopts := SelectOptions{ sopts := SelectOptions{
Strategy: defaultStrategy, Strategy: &RoundStrategy{},
} }
for _, opt := range opts { for _, opt := range opts {
opt(&sopts) opt(&sopts)
} }
for _, filter := range sopts.Filters { for _, filter := range sopts.Filters {
nodes = filter(nodes) nodes = filter.Filter(nodes)
} }
if len(nodes) == 0 { if len(nodes) == 0 {
return Node{}, ErrNoneAvailable return Node{}, ErrNoneAvailable
} }
return sopts.Strategy(nodes), nil return sopts.Strategy.Apply(nodes), nil
}
func (s *defaultSelector) String() string {
return "default"
} }
// Filter is used to filter a node during the selection process // Filter is used to filter a node during the selection process
type Filter func([]Node) []Node type Filter interface {
Filter([]Node) []Node
}
// Strategy is a selection strategy e.g random, round robin // Strategy is a selection strategy e.g random, round robin
type Strategy func([]Node) Node type Strategy interface {
Apply([]Node) Node
String() string
}
// RoundStrategy is a strategy for node selector
type RoundStrategy struct {
count uint64
}
// Apply applies the round robin strategy for the nodes
func (s *RoundStrategy) Apply(nodes []Node) Node {
if len(nodes) == 0 {
return Node{}
}
old := s.count
atomic.AddUint64(&s.count, 1)
return nodes[int(old%uint64(len(nodes)))]
}
func (s *RoundStrategy) String() string {
return "round"
}
// RandomStrategy is a strategy for node selector
type RandomStrategy struct{}
// Apply applies the random strategy for the nodes
func (s *RandomStrategy) Apply(nodes []Node) Node {
if len(nodes) == 0 {
return Node{}
}
func defaultStrategy(nodes []Node) Node { return nodes[time.Now().Nanosecond()%len(nodes)]
return nodes[0]
} }
func (s *RandomStrategy) String() string {
return "random"
}
// SelectOption used when making a select call
type SelectOption func(*SelectOptions)
// SelectOptions is the options for node selection // SelectOptions is the options for node selection
type SelectOptions struct { type SelectOptions struct {
Filters []Filter Filters []Filter
...@@ -108,9 +139,9 @@ func (s *RoundRobinIPSelector) Select(ips []string) (string, error) { ...@@ -108,9 +139,9 @@ func (s *RoundRobinIPSelector) Select(ips []string) (string, error) {
if len(ips) == 0 { if len(ips) == 0 {
return "", nil return "", nil
} }
old := s.count
count := atomic.AddUint64(&s.count, 1) atomic.AddUint64(&s.count, 1)
return ips[int(count%uint64(len(ips)))], nil return ips[int(old%uint64(len(ips)))], nil
} }
func (s *RoundRobinIPSelector) String() string { func (s *RoundRobinIPSelector) String() string {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment