Commit 7c7a51ec authored by ginuerzh's avatar ginuerzh

ws supports user defined url path

parent 88efafca
...@@ -109,6 +109,7 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { ...@@ -109,6 +109,7 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
wsOpts.ReadBufferSize = node.GetInt("rbuf") wsOpts.ReadBufferSize = node.GetInt("rbuf")
wsOpts.WriteBufferSize = node.GetInt("wbuf") wsOpts.WriteBufferSize = node.GetInt("wbuf")
wsOpts.UserAgent = node.Get("agent") wsOpts.UserAgent = node.Get("agent")
wsOpts.Path = node.Get("path")
var host string var host string
...@@ -276,6 +277,7 @@ func (r *route) GenRouters() ([]router, error) { ...@@ -276,6 +277,7 @@ func (r *route) GenRouters() ([]router, error) {
wsOpts.EnableCompression = node.GetBool("compression") wsOpts.EnableCompression = node.GetBool("compression")
wsOpts.ReadBufferSize = node.GetInt("rbuf") wsOpts.ReadBufferSize = node.GetInt("rbuf")
wsOpts.WriteBufferSize = node.GetInt("wbuf") wsOpts.WriteBufferSize = node.GetInt("wbuf")
wsOpts.Path = node.Get("path")
var ln gost.Listener var ln gost.Listener
switch node.Transport { switch node.Transport {
...@@ -284,7 +286,6 @@ func (r *route) GenRouters() ([]router, error) { ...@@ -284,7 +286,6 @@ func (r *route) GenRouters() ([]router, error) {
case "mtls": case "mtls":
ln, err = gost.MTLSListener(node.Addr, tlsCfg) ln, err = gost.MTLSListener(node.Addr, tlsCfg)
case "ws": case "ws":
wsOpts.WriteBufferSize = node.GetInt("wbuf")
ln, err = gost.WSListener(node.Addr, wsOpts) ln, err = gost.WSListener(node.Addr, wsOpts)
case "mws": case "mws":
ln, err = gost.MWSListener(node.Addr, wsOpts) ln, err = gost.MWSListener(node.Addr, wsOpts)
......
...@@ -19,6 +19,10 @@ import ( ...@@ -19,6 +19,10 @@ import (
smux "gopkg.in/xtaci/smux.v1" smux "gopkg.in/xtaci/smux.v1"
) )
const (
defaultWSPath = "/ws"
)
// WSOptions describes the options for websocket. // WSOptions describes the options for websocket.
type WSOptions struct { type WSOptions struct {
ReadBufferSize int ReadBufferSize int
...@@ -26,6 +30,7 @@ type WSOptions struct { ...@@ -26,6 +30,7 @@ type WSOptions struct {
HandshakeTimeout time.Duration HandshakeTimeout time.Duration
EnableCompression bool EnableCompression bool
UserAgent string UserAgent string
Path string
} }
type wsTransporter struct { type wsTransporter struct {
...@@ -49,7 +54,15 @@ func (tr *wsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (n ...@@ -49,7 +54,15 @@ func (tr *wsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (n
if opts.WSOptions != nil { if opts.WSOptions != nil {
wsOptions = opts.WSOptions wsOptions = opts.WSOptions
} }
url := url.URL{Scheme: "ws", Host: opts.Host, Path: "/ws"} if wsOptions == nil {
wsOptions = &WSOptions{}
}
path := wsOptions.Path
if path == "" {
path = defaultWSPath
}
url := url.URL{Scheme: "ws", Host: opts.Host, Path: path}
return websocketClientConn(url.String(), conn, nil, wsOptions) return websocketClientConn(url.String(), conn, nil, wsOptions)
} }
...@@ -148,7 +161,15 @@ func (tr *mwsTransporter) initSession(addr string, conn net.Conn, opts *Handshak ...@@ -148,7 +161,15 @@ func (tr *mwsTransporter) initSession(addr string, conn net.Conn, opts *Handshak
if opts.WSOptions != nil { if opts.WSOptions != nil {
wsOptions = opts.WSOptions wsOptions = opts.WSOptions
} }
url := url.URL{Scheme: "ws", Host: opts.Host, Path: "/ws"} if wsOptions == nil {
wsOptions = &WSOptions{}
}
path := wsOptions.Path
if path == "" {
path = defaultWSPath
}
url := url.URL{Scheme: "ws", Host: opts.Host, Path: path}
conn, err := websocketClientConn(url.String(), conn, nil, wsOptions) conn, err := websocketClientConn(url.String(), conn, nil, wsOptions)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -187,10 +208,18 @@ func (tr *wssTransporter) Handshake(conn net.Conn, options ...HandshakeOption) ( ...@@ -187,10 +208,18 @@ func (tr *wssTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (
if opts.WSOptions != nil { if opts.WSOptions != nil {
wsOptions = opts.WSOptions wsOptions = opts.WSOptions
} }
if wsOptions == nil {
wsOptions = &WSOptions{}
}
if opts.TLSConfig == nil { if opts.TLSConfig == nil {
opts.TLSConfig = &tls.Config{InsecureSkipVerify: true} opts.TLSConfig = &tls.Config{InsecureSkipVerify: true}
} }
url := url.URL{Scheme: "wss", Host: opts.Host, Path: "/ws"} path := wsOptions.Path
if path == "" {
path = defaultWSPath
}
url := url.URL{Scheme: "wss", Host: opts.Host, Path: path}
return websocketClientConn(url.String(), conn, opts.TLSConfig, wsOptions) return websocketClientConn(url.String(), conn, opts.TLSConfig, wsOptions)
} }
...@@ -288,11 +317,19 @@ func (tr *mwssTransporter) initSession(addr string, conn net.Conn, opts *Handsha ...@@ -288,11 +317,19 @@ func (tr *mwssTransporter) initSession(addr string, conn net.Conn, opts *Handsha
if opts.WSOptions != nil { if opts.WSOptions != nil {
wsOptions = opts.WSOptions wsOptions = opts.WSOptions
} }
if wsOptions == nil {
wsOptions = &WSOptions{}
}
tlsConfig := opts.TLSConfig tlsConfig := opts.TLSConfig
if tlsConfig == nil { if tlsConfig == nil {
tlsConfig = &tls.Config{InsecureSkipVerify: true} tlsConfig = &tls.Config{InsecureSkipVerify: true}
} }
url := url.URL{Scheme: "wss", Host: opts.Host, Path: "/ws"} path := wsOptions.Path
if path == "" {
path = defaultWSPath
}
url := url.URL{Scheme: "wss", Host: opts.Host, Path: path}
conn, err := websocketClientConn(url.String(), conn, tlsConfig, wsOptions) conn, err := websocketClientConn(url.String(), conn, tlsConfig, wsOptions)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -338,8 +375,12 @@ func WSListener(addr string, options *WSOptions) (Listener, error) { ...@@ -338,8 +375,12 @@ func WSListener(addr string, options *WSOptions) (Listener, error) {
errChan: make(chan error, 1), errChan: make(chan error, 1),
} }
path := options.Path
if path == "" {
path = defaultWSPath
}
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle("/ws", http.HandlerFunc(l.upgrade)) mux.Handle(path, http.HandlerFunc(l.upgrade))
l.srv = &http.Server{ l.srv = &http.Server{
Addr: addr, Addr: addr,
Handler: mux, Handler: mux,
...@@ -431,8 +472,13 @@ func MWSListener(addr string, options *WSOptions) (Listener, error) { ...@@ -431,8 +472,13 @@ func MWSListener(addr string, options *WSOptions) (Listener, error) {
errChan: make(chan error, 1), errChan: make(chan error, 1),
} }
path := options.Path
if path == "" {
path = defaultWSPath
}
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle("/ws", http.HandlerFunc(l.upgrade)) mux.Handle(path, http.HandlerFunc(l.upgrade))
l.srv = &http.Server{ l.srv = &http.Server{
Addr: addr, Addr: addr,
Handler: mux, Handler: mux,
...@@ -551,8 +597,13 @@ func WSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Listen ...@@ -551,8 +597,13 @@ func WSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Listen
tlsConfig = DefaultTLSConfig tlsConfig = DefaultTLSConfig
} }
path := options.Path
if path == "" {
path = defaultWSPath
}
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle("/ws", http.HandlerFunc(l.upgrade)) mux.Handle(path, http.HandlerFunc(l.upgrade))
l.srv = &http.Server{ l.srv = &http.Server{
Addr: addr, Addr: addr,
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
...@@ -612,8 +663,13 @@ func MWSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Liste ...@@ -612,8 +663,13 @@ func MWSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Liste
tlsConfig = DefaultTLSConfig tlsConfig = DefaultTLSConfig
} }
path := options.Path
if path == "" {
path = defaultWSPath
}
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle("/ws", http.HandlerFunc(l.upgrade)) mux.Handle(path, http.HandlerFunc(l.upgrade))
l.srv = &http.Server{ l.srv = &http.Server{
Addr: addr, Addr: addr,
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment