Commit 6613df99 authored by rui.zheng's avatar rui.zheng

add gost

parent 97d2de15
......@@ -78,12 +78,12 @@ func (c *Chain) Conn() (net.Conn, error) {
}
nodes := c.nodes
conn, err := nodes[0].Client.Dial(nodes[0].Addr, TimeoutDialOption(DialTimeout))
conn, err := nodes[0].Client.Dial(nodes[0].Addr, nodes[0].DialOptions...)
if err != nil {
return nil, err
}
conn, err = nodes[0].Client.Handshake(conn, AddrHandshakeOption(nodes[0].Addr))
conn, err = nodes[0].Client.Handshake(conn, nodes[0].HandshakeOptions...)
if err != nil {
return nil, err
}
......@@ -99,7 +99,7 @@ func (c *Chain) Conn() (net.Conn, error) {
conn.Close()
return nil, err
}
cc, err = next.Client.Handshake(cc, AddrHandshakeOption(next.Addr))
cc, err = next.Client.Handshake(cc, next.HandshakeOptions...)
if err != nil {
conn.Close()
return nil, err
......
This diff is collapsed.
......@@ -367,12 +367,16 @@ func (c *udpServerConn) writeLoop() {
}
func (c *udpServerConn) ttlWait() {
timer := time.NewTimer(c.ttl)
ttl := c.ttl
if ttl == 0 {
ttl = defaultTTL
}
timer := time.NewTimer(ttl)
for {
select {
case <-c.nopChan:
timer.Reset(c.ttl)
timer.Reset(ttl)
case <-timer.C:
close(c.brokenChan)
return
......@@ -452,7 +456,7 @@ func (l *tcpRemoteForwardListener) Accept() (net.Conn, error) {
func (l *tcpRemoteForwardListener) accept() (conn net.Conn, err error) {
lastNode := l.chain.LastNode()
if lastNode.Protocol == "forward" && lastNode.Transport == "ssh" {
if lastNode.Protocol == "remote" && lastNode.Transport == "ssh" {
conn, err = l.chain.Dial(l.addr.String())
} else if lastNode.Protocol == "socks5" {
cc, er := l.chain.Conn()
......
......@@ -13,7 +13,7 @@ import (
)
// Version is the gost version.
const Version = "2.4-dev20170722"
const Version = "2.4-dev20170803"
// Debug is a flag that enables the debug log.
var Debug bool
......@@ -39,7 +39,7 @@ var (
// PingRetries is the reties of ping.
PingRetries = 3
// default udp node TTL in second for udp port forwarding.
defaultTTL = 60
defaultTTL = 60 * time.Second
)
var (
......
......@@ -391,6 +391,15 @@ func HTTP2Listener(addr string, config *tls.Config) (Listener, error) {
connChan: make(chan *http2ServerConn, 1024),
errChan: make(chan error, 1),
}
if config == nil {
cert, err := tls.X509KeyPair(defaultRawCert, defaultRawKey)
if err != nil {
return nil, err
}
config = &tls.Config{
Certificates: []tls.Certificate{cert},
}
}
server := &http.Server{
Addr: addr,
Handler: http.HandlerFunc(l.handleFunc),
......@@ -400,6 +409,7 @@ func HTTP2Listener(addr string, config *tls.Config) (Listener, error) {
return nil, err
}
l.server = server
go server.ListenAndServeTLS("", "")
return l, nil
......@@ -462,6 +472,16 @@ func H2Listener(addr string, config *tls.Config) (Listener, error) {
if err != nil {
return nil, err
}
if config == nil {
cert, err := tls.X509KeyPair(defaultRawCert, defaultRawKey)
if err != nil {
return nil, err
}
config = &tls.Config{
Certificates: []tls.Certificate{cert},
}
}
l := &h2Listener{
Listener: ln,
server: &http2.Server{
......
package gost
import (
"bufio"
"net"
"net/url"
"os"
"strconv"
"strings"
......@@ -18,13 +16,11 @@ type Node struct {
Transport string
Remote string // remote address, used by tcp/udp port forwarding
User *url.Userinfo
users []*url.Userinfo // authentication or cipher for proxy
Whitelist *Permissions
Blacklist *Permissions
values url.Values
serverName string
Chain *Chain
Values url.Values
Client *Client
Server *Server
DialOptions []DialOption
HandshakeOptions []HandshakeOption
}
func ParseNode(s string) (node Node, err error) {
......@@ -36,49 +32,10 @@ func ParseNode(s string) (node Node, err error) {
return
}
query := u.Query()
node = Node{
Addr: u.Host,
values: query,
serverName: u.Host,
}
if query.Get("whitelist") != "" {
if node.Whitelist, err = ParsePermissions(query.Get("whitelist")); err != nil {
return
}
} else {
// By default allow for everyting
node.Whitelist, _ = ParsePermissions("*:*:*")
}
if query.Get("blacklist") != "" {
if node.Blacklist, err = ParsePermissions(query.Get("blacklist")); err != nil {
return
}
} else {
// By default block nothing
node.Blacklist, _ = ParsePermissions("")
}
if u.User != nil {
node.User = u.User
node.users = append(node.users, u.User)
}
users, er := parseUsers(node.values.Get("secrets"))
if users != nil {
node.users = append(node.users, users...)
}
if er != nil {
log.Log("load secrets:", er)
}
if strings.Contains(u.Host, ":") {
node.serverName, _, _ = net.SplitHostPort(u.Host)
if node.serverName == "" {
node.serverName = "localhost" // default server name
}
Values: u.Query(),
User: u.User,
}
schemes := strings.Split(u.Scheme, "+")
......@@ -105,9 +62,9 @@ func ParseNode(s string) (node Node, err error) {
}
switch node.Protocol {
case "http", "http2", "socks4", "socks4a", "socks", "socks5", "ss":
case "http", "http2", "socks4", "socks4a", "socks", "socks5", "ss", "ssu":
case "tcp", "udp", "rtcp", "rudp": // port forwarding
case "direct", "remote": // SSH port forwarding
case "direct", "remote", "forward": // SSH port forwarding
default:
node.Protocol = ""
}
......@@ -115,34 +72,6 @@ func ParseNode(s string) (node Node, err error) {
return
}
func parseUsers(authFile string) (users []*url.Userinfo, err error) {
if authFile == "" {
return
}
file, err := os.Open(authFile)
if err != nil {
return
}
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
s := strings.SplitN(line, " ", 2)
if len(s) == 1 {
users = append(users, url.User(strings.TrimSpace(s[0])))
} else if len(s) == 2 {
users = append(users, url.UserPassword(strings.TrimSpace(s[0]), strings.TrimSpace(s[1])))
}
}
err = scanner.Err()
return
}
func Can(action string, addr string, whitelist, blacklist *Permissions) bool {
if !strings.Contains(addr, ":") {
addr = addr + ":80"
......@@ -159,7 +88,8 @@ func Can(action string, addr string, whitelist, blacklist *Permissions) bool {
return false
}
if Debug {
log.Logf("Can action: %s, host: %s, port %d", action, host, port)
}
return whitelist.Can(action, host, port) && !blacklist.Can(action, host, port)
}
......@@ -318,7 +318,6 @@ type sshSession struct {
}
func (s *sshSession) Ping(interval time.Duration, retries int) {
interval = 30 * time.Second
if interval <= 0 {
return
}
......@@ -620,13 +619,24 @@ func SSHTunnelListener(addr string, config *SSHConfig) (Listener, error) {
if len(config.Users) == 0 {
sshConfig.NoClientAuth = true
}
if config.TLSConfig != nil && len(config.TLSConfig.Certificates) > 0 {
if config.TLSConfig == nil {
cert, err := tls.X509KeyPair(defaultRawCert, defaultRawKey)
if err != nil {
ln.Close()
return nil, err
}
config.TLSConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
}
}
signer, err := ssh.NewSignerFromKey(config.TLSConfig.Certificates[0].PrivateKey)
if err != nil {
log.Log("[sshf]", err)
ln.Close()
return nil, err
}
sshConfig.AddHostKey(signer)
}
l := &sshTunnelListener{
Listener: ln,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment