Commit abe40434 authored by ginuerzh's avatar ginuerzh

add chain.DialContext

parent 425099a7
package gost
import (
"context"
"errors"
"net"
"time"
......@@ -100,9 +101,14 @@ func (c *Chain) IsEmpty() bool {
return c == nil || len(c.nodeGroups) == 0
}
// Dial connects to the target address addr through the chain.
// If the chain is empty, it will use the net.Dial directly.
func (c *Chain) Dial(addr string, opts ...ChainOption) (conn net.Conn, err error) {
// Dial connects to the target TCP address addr through the chain.
// Deprecated: use DialContext instead.
func (c *Chain) Dial(address string, opts ...ChainOption) (conn net.Conn, err error) {
return c.DialContext(context.Background(), "tcp", address, opts...)
}
// DialContext connects to the address on the named network using the provided context.
func (c *Chain) DialContext(ctx context.Context, network, address string, opts ...ChainOption) (conn net.Conn, err error) {
options := &ChainOptions{}
for _, opt := range opts {
opt(options)
......@@ -117,7 +123,7 @@ func (c *Chain) Dial(addr string, opts ...ChainOption) (conn net.Conn, err error
}
for i := 0; i < retries; i++ {
conn, err = c.dialWithOptions(addr, options)
conn, err = c.dialWithOptions(ctx, network, address, options)
if err == nil {
break
}
......@@ -125,16 +131,19 @@ func (c *Chain) Dial(addr string, opts ...ChainOption) (conn net.Conn, err error
return
}
func (c *Chain) dialWithOptions(addr string, options *ChainOptions) (net.Conn, error) {
func (c *Chain) dialWithOptions(ctx context.Context, network, address string, options *ChainOptions) (net.Conn, error) {
if options == nil {
options = &ChainOptions{}
}
route, err := c.selectRouteFor(addr)
route, err := c.selectRouteFor(address)
if err != nil {
return nil, err
}
ipAddr := c.resolve(addr, options.Resolver, options.Hosts)
ipAddr := address
if address != "" {
ipAddr = c.resolve(address, options.Resolver, options.Hosts)
}
timeout := options.Timeout
if timeout <= 0 {
......@@ -142,16 +151,27 @@ func (c *Chain) dialWithOptions(addr string, options *ChainOptions) (net.Conn, e
}
if route.IsEmpty() {
return net.DialTimeout("tcp", ipAddr, timeout)
switch network {
case "udp", "udp4", "udp6":
if address == "" {
return net.ListenUDP(network, nil)
}
default:
}
d := &net.Dialer{
Timeout: timeout,
// LocalAddr: laddr, // TODO: optional local address
}
return d.DialContext(ctx, network, ipAddr)
}
conn, err := route.getConn()
conn, err := route.getConn(ctx)
if err != nil {
return nil, err
}
cOpts := append([]ConnectOption{AddrConnectOption(addr)}, route.LastNode().ConnectOptions...)
cc, err := route.LastNode().Client.Connect(conn, ipAddr, cOpts...)
cOpts := append([]ConnectOption{AddrConnectOption(address)}, route.LastNode().ConnectOptions...)
cc, err := route.LastNode().Client.ConnectContext(ctx, conn, network, ipAddr, cOpts...)
if err != nil {
conn.Close()
return nil, err
......@@ -187,6 +207,8 @@ func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) {
opt(options)
}
ctx := context.Background()
retries := 1
if c != nil && c.Retries > 0 {
retries = c.Retries
......@@ -201,7 +223,7 @@ func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) {
if err != nil {
continue
}
conn, err = route.getConn()
conn, err = route.getConn(ctx)
if err == nil {
break
}
......@@ -210,7 +232,7 @@ func (c *Chain) Conn(opts ...ChainOption) (conn net.Conn, err error) {
}
// getConn obtains a connection to the last node of the chain.
func (c *Chain) getConn() (conn net.Conn, err error) {
func (c *Chain) getConn(ctx context.Context) (conn net.Conn, err error) {
if c.IsEmpty() {
err = ErrEmptyChain
return
......@@ -234,7 +256,7 @@ func (c *Chain) getConn() (conn net.Conn, err error) {
preNode := node
for _, node := range nodes[1:] {
var cc net.Conn
cc, err = preNode.Client.Connect(cn, node.Addr, preNode.ConnectOptions...)
cc, err = preNode.Client.ConnectContext(ctx, cn, "tcp", node.Addr, preNode.ConnectOptions...)
if err != nil {
cn.Close()
node.MarkDead()
......
package gost
import (
"context"
"crypto/tls"
"net"
"net/url"
......@@ -14,23 +15,8 @@ import (
// Connector is responsible for connecting to the destination address through this proxy.
// Transporter performs a handshake with this proxy.
type Client struct {
Connector Connector
Transporter Transporter
}
// Dial connects to the target address.
func (c *Client) Dial(addr string, options ...DialOption) (net.Conn, error) {
return c.Transporter.Dial(addr, options...)
}
// Handshake performs a handshake with the proxy over connection conn.
func (c *Client) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) {
return c.Transporter.Handshake(conn, options...)
}
// Connect connects to the address addr via the proxy over connection conn.
func (c *Client) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
return c.Connector.Connect(conn, addr, options...)
Connector
Transporter
}
// DefaultClient is a standard HTTP proxy client.
......@@ -53,7 +39,36 @@ func Connect(conn net.Conn, addr string) (net.Conn, error) {
// Connector is responsible for connecting to the destination address.
type Connector interface {
Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error)
// Deprecated: use ConnectContext instead.
Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error)
ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error)
}
type autoConnector struct {
User *url.Userinfo
}
// AutoConnector is a Connector.
func AutoConnector(user *url.Userinfo) Connector {
return &autoConnector{
User: user,
}
}
func (c *autoConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
return c.ConnectContext(context.Background(), conn, "tcp", address, options...)
}
func (c *autoConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
var cnr Connector
switch network {
case "tcp", "tcp4", "tcp6":
cnr = &httpConnector{User: c.User}
default:
cnr = &socks5UDPTunConnector{User: c.User}
}
return cnr.ConnectContext(ctx, conn, network, address, options...)
}
// Transporter is responsible for handshaking with the proxy server.
......
......@@ -227,10 +227,9 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) {
case "sni":
connector = gost.SNIConnector(node.Get("host"))
case "http":
fallthrough
default:
node.Protocol = "http" // default protocol is HTTP
connector = gost.HTTPConnector(node.User)
default:
connector = gost.AutoConnector(node.User)
}
timeout := node.GetInt("timeout")
......
package gost
import (
"context"
"errors"
"net"
"strings"
......@@ -22,7 +23,11 @@ func ForwardConnector() Connector {
return &forwardConnector{}
}
func (c *forwardConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
func (c *forwardConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
return conn, nil
}
func (c *forwardConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
return conn, nil
}
......@@ -186,42 +191,12 @@ func (h *udpDirectForwardHandler) Handle(conn net.Conn) {
return
}
raddr, err := net.ResolveUDPAddr("udp", node.Addr)
cc, err := h.options.Chain.DialContext(context.Background(), "udp", node.Addr)
if err != nil {
node.MarkDead()
log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
return
}
var cc net.Conn
if h.options.Chain.IsEmpty() {
cc, err = net.DialUDP("udp", nil, raddr)
if err != nil {
node.MarkDead()
log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
return
}
} else if h.options.Chain.LastNode().Protocol == "ssu" {
cc, err = h.options.Chain.Dial(node.Addr,
RetryChainOption(h.options.Retries),
TimeoutChainOption(h.options.Timeout),
)
if err != nil {
node.MarkDead()
log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
return
}
} else {
var err error
cc, err = getSOCKS5UDPTunnel(h.options.Chain, nil)
if err != nil {
log.Logf("[udp] %s - %s : %s", conn.RemoteAddr(), node.Addr, err)
return
}
cc = &udpTunnelConn{Conn: cc, raddr: raddr}
}
defer cc.Close()
node.ResetDead()
......@@ -726,11 +701,11 @@ func (l *udpRemoteForwardListener) connect() (conn net.PacketConn, err error) {
lastNode := l.chain.LastNode()
if lastNode.Protocol == "socks5" {
var cc net.Conn
cc, err = getSOCKS5UDPTunnel(l.chain, l.addr)
cc, err = getSocks5UDPTunnel(l.chain, l.addr)
if err != nil {
log.Logf("[rudp] %s : %s", l.Addr(), err)
} else {
conn = &udpTunnelConn{Conn: cc}
conn = cc.(net.PacketConn)
}
} else {
var uc *net.UDPConn
......
......@@ -20,7 +20,7 @@ import (
)
// Version is the gost version.
const Version = "2.10.0"
const Version = "2.10.1"
// Debug is a flag that enables the debug log.
var Debug bool
......
......@@ -3,6 +3,7 @@ package gost
import (
"bufio"
"bytes"
"context"
"encoding/base64"
"fmt"
"net"
......@@ -27,7 +28,16 @@ func HTTPConnector(user *url.Userinfo) Connector {
return &httpConnector{User: user}
}
func (c *httpConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
func (c *httpConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
return c.ConnectContext(context.Background(), conn, "tcp", address, options...)
}
func (c *httpConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
switch network {
case "udp", "udp4", "udp6":
return nil, fmt.Errorf("%s unsupported", network)
}
opts := &ConnectOptions{}
for _, option := range options {
option(opts)
......@@ -47,8 +57,8 @@ func (c *httpConnector) Connect(conn net.Conn, addr string, options ...ConnectOp
req := &http.Request{
Method: http.MethodConnect,
URL: &url.URL{Host: addr},
Host: addr,
URL: &url.URL{Host: address},
Host: address,
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(http.Header),
......
......@@ -3,6 +3,7 @@ package gost
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"errors"
......@@ -33,7 +34,16 @@ func HTTP2Connector(user *url.Userinfo) Connector {
return &http2Connector{User: user}
}
func (c *http2Connector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
func (c *http2Connector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
return c.ConnectContext(context.Background(), conn, "tcp", address, options...)
}
func (c *http2Connector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
switch network {
case "udp", "udp4", "udp6":
return nil, fmt.Errorf("%s unsupported", network)
}
opts := &ConnectOptions{}
for _, option := range options {
option(opts)
......@@ -57,7 +67,7 @@ func (c *http2Connector) Connect(conn net.Conn, addr string, options ...ConnectO
ProtoMajor: 2,
ProtoMinor: 0,
Body: pr,
Host: addr,
Host: address,
ContentLength: -1,
}
req.Header.Set("User-Agent", ua)
......@@ -97,7 +107,7 @@ func (c *http2Connector) Connect(conn net.Conn, addr string, options ...ConnectO
closed: make(chan struct{}),
}
hc.remoteAddr, _ = net.ResolveTCPAddr("tcp", addr)
hc.remoteAddr, _ = net.ResolveTCPAddr("tcp", address)
hc.localAddr, _ = net.ResolveTCPAddr("tcp", cc.addr)
return hc, nil
......
......@@ -3,6 +3,7 @@
package gost
import (
"context"
"errors"
"fmt"
"net"
......@@ -132,32 +133,14 @@ func (h *udpRedirectHandler) Handle(conn net.Conn) {
return
}
var cc net.Conn
var err error
if h.options.Chain.IsEmpty() {
cc, err = net.DialUDP("udp", nil, raddr)
if err != nil {
log.Logf("[red-udp] %s - %s : %s", conn.RemoteAddr(), raddr, err)
return
}
} else if h.options.Chain.LastNode().Protocol == "ssu" {
cc, err = h.options.Chain.Dial(raddr.String(),
RetryChainOption(h.options.Retries),
TimeoutChainOption(h.options.Timeout),
)
if err != nil {
log.Logf("[red-udp] %s - %s : %s", conn.RemoteAddr(), raddr, err)
return
}
} else {
var err error
cc, err = getSOCKS5UDPTunnel(h.options.Chain, nil)
if err != nil {
log.Logf("[red-udp] %s - %s : %s", conn.RemoteAddr(), raddr, err)
return
}
cc = &udpTunnelConn{Conn: cc, raddr: raddr}
cc, err := h.options.Chain.DialContext(context.Background(),
"udp", raddr.String(),
RetryChainOption(h.options.Retries),
TimeoutChainOption(h.options.Timeout),
)
if err != nil {
log.Logf("[red-udp] %s - %s : %s", conn.RemoteAddr(), raddr, err)
return
}
defer cc.Close()
......
......@@ -606,31 +606,12 @@ func NewDNSExchanger(addr string, opts ...ExchangerOption) Exchanger {
}
}
func (ex *dnsExchanger) dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
if ex.options.chain.IsEmpty() {
d := &net.Dialer{
Timeout: ex.options.timeout,
}
return d.DialContext(ctx, network, address)
}
if ex.options.chain.LastNode().Protocol == "ssu" {
return ex.options.chain.Dial(address, TimeoutChainOption(ex.options.timeout))
}
raddr, err := net.ResolveUDPAddr(network, address)
if err != nil {
return
}
cc, err := getSOCKS5UDPTunnel(ex.options.chain, nil)
conn = &udpTunnelConn{Conn: cc, raddr: raddr}
return
}
func (ex *dnsExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) {
t := time.Now()
c, err := ex.dial(ctx, "udp", ex.addr)
c, err := ex.options.chain.DialContext(ctx,
"udp", ex.addr,
TimeoutChainOption(ex.options.timeout),
)
if err != nil {
return nil, err
}
......@@ -674,19 +655,12 @@ func NewDNSTCPExchanger(addr string, opts ...ExchangerOption) Exchanger {
}
}
func (ex *dnsTCPExchanger) dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
if ex.options.chain.IsEmpty() {
d := &net.Dialer{
Timeout: ex.options.timeout,
}
return d.DialContext(ctx, network, address)
}
return ex.options.chain.Dial(address, TimeoutChainOption(ex.options.timeout))
}
func (ex *dnsTCPExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) {
t := time.Now()
c, err := ex.dial(ctx, "tcp", ex.addr)
c, err := ex.options.chain.DialContext(ctx,
"tcp", ex.addr,
TimeoutChainOption(ex.options.timeout),
)
if err != nil {
return nil, err
}
......@@ -738,14 +712,10 @@ func NewDoTExchanger(addr string, tlsConfig *tls.Config, opts ...ExchangerOption
}
func (ex *dotExchanger) dial(ctx context.Context, network, address string) (conn net.Conn, err error) {
if ex.options.chain.IsEmpty() {
d := &net.Dialer{
Timeout: ex.options.timeout,
}
conn, err = d.DialContext(ctx, network, address)
} else {
conn, err = ex.options.chain.Dial(address, TimeoutChainOption(ex.options.timeout))
}
conn, err = ex.options.chain.DialContext(ctx,
network, address,
TimeoutChainOption(ex.options.timeout),
)
if err != nil {
return
}
......@@ -812,14 +782,11 @@ func NewDoHExchanger(urlStr *url.URL, tlsConfig *tls.Config, opts ...ExchangerOp
return ex
}
func (ex *dohExchanger) dialContext(ctx context.Context, network, address string) (conn net.Conn, err error) {
if ex.options.chain.IsEmpty() {
d := &net.Dialer{
Timeout: ex.options.timeout,
}
return d.DialContext(ctx, network, address)
}
return ex.options.chain.Dial(address, TimeoutChainOption(ex.options.timeout))
func (ex *dohExchanger) dialContext(ctx context.Context, network, address string) (net.Conn, error) {
return ex.options.chain.DialContext(ctx,
network, address,
TimeoutChainOption(ex.options.timeout),
)
}
func (ex *dohExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) {
......
name: gost
type: app
version: '2.10.0'
version: '2.10.1'
title: GO Simple Tunnel
summary: A simple security tunnel written in golang
description: |
......
......@@ -5,6 +5,7 @@ package gost
import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/binary"
"errors"
......@@ -29,8 +30,17 @@ func SNIConnector(host string) Connector {
return &sniConnector{host: host}
}
func (c *sniConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
return &sniClientConn{addr: addr, host: c.host, Conn: conn}, nil
func (c *sniConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
return c.ConnectContext(context.Background(), conn, "tcp", address, options...)
}
func (c *sniConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
switch network {
case "udp", "udp4", "udp6":
return nil, fmt.Errorf("%s unsupported", network)
}
return &sniClientConn{addr: address, host: c.host, Conn: conn}, nil
}
type sniHandler struct {
......
This diff is collapsed.
......@@ -2,6 +2,7 @@ package gost
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"io"
......@@ -15,6 +16,15 @@ import (
ss "github.com/shadowsocks/shadowsocks-go/shadowsocks"
)
const (
maxSocksAddrLen = 259
)
var (
_ net.Conn = (*shadowConn)(nil)
_ net.PacketConn = (*shadowUDPPacketConn)(nil)
)
type shadowConnector struct {
cipher core.Cipher
}
......@@ -27,7 +37,16 @@ func ShadowConnector(info *url.Userinfo) Connector {
}
}
func (c *shadowConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
func (c *shadowConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
return c.ConnectContext(context.Background(), conn, "tcp", address, options...)
}
func (c *shadowConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
switch network {
case "udp", "udp4", "udp6":
return nil, fmt.Errorf("%s unsupported", network)
}
opts := &ConnectOptions{}
for _, option := range options {
option(opts)
......@@ -38,7 +57,7 @@ func (c *shadowConnector) Connect(conn net.Conn, addr string, options ...Connect
timeout = ConnectTimeout
}
socksAddr, err := gosocks5.NewAddr(addr)
socksAddr, err := gosocks5.NewAddr(address)
if err != nil {
return nil, err
}
......@@ -183,7 +202,16 @@ func ShadowUDPConnector(info *url.Userinfo) Connector {
}
}
func (c *shadowUDPConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
func (c *shadowUDPConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
return c.ConnectContext(context.Background(), conn, "udp", address, options...)
}
func (c *shadowUDPConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
switch network {
case "tcp", "tcp4", "tcp6":
return nil, fmt.Errorf("%s unsupported", network)
}
opts := &ConnectOptions{}
for _, option := range options {
option(opts)
......@@ -197,13 +225,13 @@ func (c *shadowUDPConnector) Connect(conn net.Conn, addr string, options ...Conn
conn.SetDeadline(time.Now().Add(timeout))
defer conn.SetDeadline(time.Time{})
taddr, _ := net.ResolveUDPAddr(network, address)
if taddr == nil {
taddr = &net.UDPAddr{}
}
pc, ok := conn.(net.PacketConn)
if ok {
rawaddr, err := ss.RawAddr(addr)
if err != nil {
return nil, err
}
if c.cipher != nil {
pc = c.cipher.PacketConn(pc)
}
......@@ -211,22 +239,17 @@ func (c *shadowUDPConnector) Connect(conn net.Conn, addr string, options ...Conn
return &shadowUDPPacketConn{
PacketConn: pc,
raddr: conn.RemoteAddr(),
header: rawaddr,
taddr: taddr,
}, nil
}
taddr, err := gosocks5.NewAddr(addr)
if err != nil {
return nil, err
}
if c.cipher != nil {
conn = c.cipher.StreamConn(conn)
}
return &shadowUDPStreamConn{
Conn: conn,
addr: taddr,
return &socks5UDPTunnelConn{
Conn: conn,
taddr: taddr,
}, nil
}
......@@ -258,23 +281,13 @@ func (h *shadowUDPHandler) Init(options ...HandlerOption) {
func (h *shadowUDPHandler) Handle(conn net.Conn) {
defer conn.Close()
var err error
var cc net.PacketConn
if h.options.Chain.IsEmpty() {
cc, err = net.ListenUDP("udp", nil)
if err != nil {
log.Logf("[ssu] %s - : %s", conn.LocalAddr(), err)
return
}
} else {
var c net.Conn
c, err = getSOCKS5UDPTunnel(h.options.Chain, nil)
if err != nil {
log.Logf("[ssu] %s - : %s", conn.LocalAddr(), err)
return
}
cc = &udpTunnelConn{Conn: c}
c, err := h.options.Chain.DialContext(context.Background(), "udp", "")
if err != nil {
log.Logf("[ssu] %s: %s", conn.LocalAddr(), err)
return
}
cc = c.(net.PacketConn)
defer cc.Close()
pc, ok := conn.(net.PacketConn)
......@@ -466,24 +479,11 @@ func (c *shadowConn) Write(b []byte) (n int, err error) {
type shadowUDPPacketConn struct {
net.PacketConn
raddr net.Addr
header []byte
raddr net.Addr
taddr net.Addr
}
func (c *shadowUDPPacketConn) Write(b []byte) (n int, err error) {
n = len(b) // force byte length consistent
buf := bytes.Buffer{}
if _, err = buf.Write(c.header); err != nil {
return
}
if _, err = buf.Write(b); err != nil {
return
}
_, err = c.PacketConn.WriteTo(buf.Bytes(), c.raddr)
return
}
func (c *shadowUDPPacketConn) Read(b []byte) (n int, err error) {
func (c *shadowUDPPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
buf := mPool.Get().([]byte)
defer mPool.Put(buf)
......@@ -501,45 +501,44 @@ func (c *shadowUDPPacketConn) Read(b []byte) (n int, err error) {
return
}
n = copy(b, dgram.Data)
addr, err = net.ResolveUDPAddr("udp", dgram.Header.Addr.String())
return
}
func (c *shadowUDPPacketConn) RemoteAddr() net.Addr {
return c.raddr
}
type shadowUDPStreamConn struct {
net.Conn
addr *gosocks5.Addr
func (c *shadowUDPPacketConn) Read(b []byte) (n int, err error) {
n, _, err = c.ReadFrom(b)
return
}
func (c *shadowUDPStreamConn) Read(b []byte) (n int, err error) {
dgram, err := gosocks5.ReadUDPDatagram(c.Conn)
func (c *shadowUDPPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
sa, err := gosocks5.NewAddr(addr.String())
if err != nil {
return
}
var rawaddr [maxSocksAddrLen]byte
nn, err := sa.Encode(rawaddr[:])
if err != nil {
return
}
n = copy(b, dgram.Data)
return
}
func (c *shadowUDPStreamConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
n, err = c.Read(b)
addr = c.Conn.RemoteAddr()
buf := mPool.Get().([]byte)
defer mPool.Put(buf)
copy(buf, rawaddr[:nn])
n = copy(buf[nn:], b)
_, err = c.PacketConn.WriteTo(buf[:n+nn], c.raddr)
return
}
func (c *shadowUDPStreamConn) Write(b []byte) (n int, err error) {
n = len(b) // force byte length consistent
dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(uint16(len(b)), 0, c.addr), b)
buf := bytes.Buffer{}
dgram.Write(&buf)
_, err = c.Conn.Write(buf.Bytes())
return
func (c *shadowUDPPacketConn) Write(b []byte) (n int, err error) {
return c.WriteTo(b, c.taddr)
}
func (c *shadowUDPStreamConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
return c.Write(b)
func (c *shadowUDPPacketConn) RemoteAddr() net.Addr {
return c.raddr
}
type shadowCipher struct {
......
......@@ -138,7 +138,7 @@ var ssProxyTests = []struct {
serverCipher *url.Userinfo
pass bool
}{
{nil, nil, false},
{nil, nil, true},
{&url.Userinfo{}, &url.Userinfo{}, true},
{url.User("abc"), url.User("abc"), true},
{url.UserPassword("abc", "def"), url.UserPassword("abc", "def"), true},
......
......@@ -39,6 +39,15 @@ func SSHDirectForwardConnector() Connector {
}
func (c *sshDirectForwardConnector) Connect(conn net.Conn, raddr string, options ...ConnectOption) (net.Conn, error) {
return c.ConnectContext(context.Background(), conn, "tcp", raddr, options...)
}
func (c *sshDirectForwardConnector) ConnectContext(ctx context.Context, conn net.Conn, network, raddr string, options ...ConnectOption) (net.Conn, error) {
switch network {
case "udp", "udp4", "udp6":
return nil, fmt.Errorf("%s unsupported", network)
}
opts := &ConnectOptions{}
for _, option := range options {
option(opts)
......@@ -73,7 +82,16 @@ func SSHRemoteForwardConnector() Connector {
return &sshRemoteForwardConnector{}
}
func (c *sshRemoteForwardConnector) Connect(conn net.Conn, addr string, options ...ConnectOption) (net.Conn, error) {
func (c *sshRemoteForwardConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) {
return c.ConnectContext(context.Background(), conn, "tcp", address, options...)
}
func (c *sshRemoteForwardConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) {
switch network {
case "udp", "udp4", "udp6":
return nil, fmt.Errorf("%s unsupported", network)
}
cc, ok := conn.(*sshNopConn) // TODO: this is an ugly type assertion, need to find a better solution.
if !ok {
return nil, errors.New("ssh: wrong connection type")
......@@ -87,10 +105,10 @@ func (c *sshRemoteForwardConnector) Connect(conn net.Conn, addr string, options
if cc.session == nil || cc.session.client == nil {
return
}
if strings.HasPrefix(addr, ":") {
addr = "0.0.0.0" + addr
if strings.HasPrefix(address, ":") {
address = "0.0.0.0" + address
}
ln, err := cc.session.client.Listen("tcp", addr)
ln, err := cc.session.client.Listen("tcp", address)
if err != nil {
return
}
......@@ -99,7 +117,7 @@ func (c *sshRemoteForwardConnector) Connect(conn net.Conn, addr string, options
for {
rc, err := ln.Accept()
if err != nil {
log.Logf("[ssh-rtcp] %s <-> %s accpet : %s", ln.Addr(), addr, err)
log.Logf("[ssh-rtcp] %s <-> %s accpet : %s", ln.Addr(), address, err)
return
}
// log.Log("[ssh-rtcp] accept", rc.LocalAddr(), rc.RemoteAddr())
......@@ -107,7 +125,7 @@ func (c *sshRemoteForwardConnector) Connect(conn net.Conn, addr string, options
case cc.session.connChan <- rc:
default:
rc.Close()
log.Logf("[ssh-rtcp] %s - %s: connection queue is full", ln.Addr(), addr)
log.Logf("[ssh-rtcp] %s - %s: connection queue is full", ln.Addr(), address)
}
}
}()
......
package gost
import (
"context"
"errors"
"fmt"
"io"
......@@ -167,9 +168,11 @@ func (h *tunHandler) Handle(conn net.Conn) {
var pc net.PacketConn
// fake tcp mode will be ignored when the client specifies a chain.
if raddr != nil && !h.options.Chain.IsEmpty() {
var cc net.Conn
cc, err = getSOCKS5UDPTunnel(h.options.Chain, nil)
pc = &udpTunnelConn{Conn: cc, raddr: raddr}
cc, err := h.options.Chain.DialContext(context.Background(), "udp", raddr.String())
if err != nil {
return err
}
pc = cc.(net.PacketConn)
} else {
if h.options.TCPMode {
if raddr != nil {
......@@ -549,9 +552,11 @@ func (h *tapHandler) Handle(conn net.Conn) {
var pc net.PacketConn
// fake tcp mode will be ignored when the client specifies a chain.
if raddr != nil && !h.options.Chain.IsEmpty() {
var cc net.Conn
cc, err = getSOCKS5UDPTunnel(h.options.Chain, nil)
pc = &udpTunnelConn{Conn: cc, raddr: raddr}
cc, err := h.options.Chain.DialContext(context.Background(), "udp", raddr.String())
if err != nil {
return err
}
pc = cc.(net.PacketConn)
} else {
if h.options.TCPMode {
if raddr != nil {
......
......@@ -19,19 +19,17 @@ func UDPTransporter() Transporter {
}
func (tr *udpTransporter) Dial(addr string, options ...DialOption) (net.Conn, error) {
raddr, err := net.ResolveUDPAddr("udp", addr)
taddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
conn, err := net.ListenUDP("udp", nil)
conn, err := net.DialUDP("udp", nil, taddr)
if err != nil {
return nil, err
}
return &udpClientConn{
UDPConn: conn,
raddr: raddr,
}, nil
}
......@@ -340,19 +338,14 @@ func (c *udpServerConn) SetWriteDeadline(t time.Time) error {
type udpClientConn struct {
*net.UDPConn
raddr net.Addr
}
func (c *udpClientConn) Write(b []byte) (int, error) {
if c.raddr != nil {
return c.WriteTo(b, c.raddr)
}
func (c *udpClientConn) WriteTo(b []byte, addr net.Addr) (int, error) {
return c.UDPConn.Write(b)
}
func (c *udpClientConn) RemoteAddr() net.Addr {
if c.raddr != nil {
return c.raddr
}
return c.UDPConn.RemoteAddr()
func (c *udpClientConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
n, err = c.Read(b)
addr = c.RemoteAddr()
return
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment