Commit ab87ca05 authored by rui.zheng's avatar rui.zheng

update quic-go

parent 5efffa7d
MIT License
Copyright (c) 2016 the quic-go authors & Google, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
......
......@@ -14,6 +14,7 @@ quic-go is an implementation of the [QUIC](https://en.wikipedia.org/wiki/QUIC) p
Done:
- Basic protocol with support for QUIC version 34-36
- QUIC client
- HTTP/2 support
- Crypto (RSA / ECDSA certificates, Curve25519 for key exchange, AES-GCM or Chacha20-Poly1305 as stream cipher)
- Loss detection and retransmission (currently fast retransmission & RTO)
......@@ -22,11 +23,10 @@ Done:
Major TODOs:
- Security, especially DOS protections
- Security, especially DoS protections
- Performance
- Better packet loss detection
- Connection migration
- QUIC client
## Guides
......@@ -38,20 +38,26 @@ Running tests:
go test ./...
Running the example server:
### Running the example server
go run example/main.go -www /var/www/
Using the `quic_client` from chromium:
quic_client --quic-version=32 --host=127.0.0.1 --port=6121 --v=1 https://quic.clemente.io
quic_client --host=127.0.0.1 --port=6121 --v=1 https://quic.clemente.io
Using Chrome:
/Applications/Google\ Chrome.app/Contents/MacOS/Google\ Chrome --user-data-dir=/tmp/chrome --no-proxy-server --enable-quic --origin-to-force-quic-on=quic.clemente.io:443 --host-resolver-rules='MAP quic.clemente.io:443 127.0.0.1:6121' https://quic.clemente.io
### Using the example client
go run example/client/main.go https://quic.clemente.io
## Usage
### As a server
See the [example server](example/main.go) or try out [Caddy](https://github.com/mholt/caddy) (from version 0.9, [instructions here](https://github.com/mholt/caddy/wiki/QUIC)). Starting a QUIC server is very similar to the standard lib http in go:
```go
......@@ -59,6 +65,16 @@ http.Handle("/", http.FileServer(http.Dir(wwwDir)))
h2quic.ListenAndServeQUIC("localhost:4242", "/path/to/cert/chain.pem", "/path/to/privkey.pem", nil)
```
### As a client
See the [example client](example/client/main.go). Use a `QuicRoundTripper` as a `Transport` in a `http.Client`.
```go
http.Client{
Transport: &h2quic.QuicRoundTripper{},
}
```
## Building on Windows
Due to the low Windows timer resolution (see [StackOverflow question](http://stackoverflow.com/questions/37706834/high-resolution-timers-millisecond-precision-in-go-on-windows)) available with Go 1.6.x, some optimizations might not work when compiled with this version of the compiler. Please use Go 1.7 on Windows.
......@@ -28,8 +28,8 @@ type SentPacketHandler interface {
// ReceivedPacketHandler handles ACKs needed to send for incoming packets
type ReceivedPacketHandler interface {
ReceivedPacket(packetNumber protocol.PacketNumber) error
ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error
ReceivedStopWaiting(*frames.StopWaitingFrame) error
GetAckFrame(dequeue bool) (*frames.AckFrame, error)
GetAckFrame() *frames.AckFrame
}
......@@ -19,31 +19,17 @@ type Packet struct {
SendTime time.Time
}
// GetStreamFramesForRetransmission gets all the streamframes for retransmission
func (p *Packet) GetStreamFramesForRetransmission() []*frames.StreamFrame {
var streamFrames []*frames.StreamFrame
// GetFramesForRetransmission gets all the frames for retransmission
func (p *Packet) GetFramesForRetransmission() []frames.Frame {
var fs []frames.Frame
for _, frame := range p.Frames {
if streamFrame, isStreamFrame := frame.(*frames.StreamFrame); isStreamFrame {
streamFrames = append(streamFrames, streamFrame)
}
}
return streamFrames
}
// GetControlFramesForRetransmission gets all the control frames for retransmission
func (p *Packet) GetControlFramesForRetransmission() []frames.Frame {
var controlFrames []frames.Frame
for _, frame := range p.Frames {
// omit ACKs
if _, isStreamFrame := frame.(*frames.StreamFrame); isStreamFrame {
switch frame.(type) {
case *frames.AckFrame:
continue
case *frames.StopWaitingFrame:
continue
}
_, isAck := frame.(*frames.AckFrame)
_, isStopWaiting := frame.(*frames.StopWaitingFrame)
if !isAck && !isStopWaiting {
controlFrames = append(controlFrames, frame)
}
fs = append(fs, frame)
}
return controlFrames
return fs
}
......@@ -6,45 +6,48 @@ import (
"github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
)
var (
// ErrDuplicatePacket occurres when a duplicate packet is received
ErrDuplicatePacket = errors.New("ReceivedPacketHandler: Duplicate Packet")
// ErrMapAccess occurs when a NACK contains invalid NACK ranges
ErrMapAccess = qerr.Error(qerr.InvalidAckData, "Packet does not exist in PacketHistory")
// ErrPacketSmallerThanLastStopWaiting occurs when a packet arrives with a packet number smaller than the largest LeastUnacked of a StopWaitingFrame. If this error occurs, the packet should be ignored
ErrPacketSmallerThanLastStopWaiting = errors.New("ReceivedPacketHandler: Packet number smaller than highest StopWaiting")
)
var (
errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet number")
errTooManyOutstandingReceivedPackets = qerr.Error(qerr.TooManyOutstandingReceivedPackets, "")
)
var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet number")
type receivedPacketHandler struct {
largestInOrderObserved protocol.PacketNumber
largestObserved protocol.PacketNumber
ignorePacketsBelow protocol.PacketNumber
currentAckFrame *frames.AckFrame
stateChanged bool // has an ACK for this state already been sent? Will be set to false every time a new packet arrives, and to false every time an ACK is sent
largestObserved protocol.PacketNumber
ignorePacketsBelow protocol.PacketNumber
largestObservedReceivedTime time.Time
packetHistory *receivedPacketHistory
receivedTimes map[protocol.PacketNumber]time.Time
lowestInReceivedTimes protocol.PacketNumber
ackSendDelay time.Duration
packetsReceivedSinceLastAck int
retransmittablePacketsReceivedSinceLastAck int
ackQueued bool
ackAlarm time.Time
ackAlarmResetCallback func(time.Time)
lastAck *frames.AckFrame
}
// NewReceivedPacketHandler creates a new receivedPacketHandler
func NewReceivedPacketHandler() ReceivedPacketHandler {
func NewReceivedPacketHandler(ackAlarmResetCallback func(time.Time)) ReceivedPacketHandler {
// create a stopped timer, see https://github.com/golang/go/issues/12721#issuecomment-143010182
timer := time.NewTimer(0)
<-timer.C
return &receivedPacketHandler{
receivedTimes: make(map[protocol.PacketNumber]time.Time),
packetHistory: newReceivedPacketHistory(),
packetHistory: newReceivedPacketHistory(),
ackAlarmResetCallback: ackAlarmResetCallback,
ackSendDelay: protocol.AckSendDelay,
}
}
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber) error {
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error {
if packetNumber == 0 {
return errInvalidPacketNumber
}
......@@ -55,30 +58,21 @@ func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumbe
return ErrPacketSmallerThanLastStopWaiting
}
_, ok := h.receivedTimes[packetNumber]
if packetNumber <= h.largestInOrderObserved || ok {
if h.packetHistory.IsDuplicate(packetNumber) {
return ErrDuplicatePacket
}
h.packetHistory.ReceivedPacket(packetNumber)
h.stateChanged = true
h.currentAckFrame = nil
err := h.packetHistory.ReceivedPacket(packetNumber)
if err != nil {
return err
}
if packetNumber > h.largestObserved {
h.largestObserved = packetNumber
h.largestObservedReceivedTime = time.Now()
}
if packetNumber == h.largestInOrderObserved+1 {
h.largestInOrderObserved = packetNumber
}
h.receivedTimes[packetNumber] = time.Now()
if protocol.PacketNumber(len(h.receivedTimes)) > protocol.MaxTrackedReceivedPackets {
return errTooManyOutstandingReceivedPackets
}
h.maybeQueueAck(packetNumber, shouldInstigateAck)
return nil
}
......@@ -89,55 +83,84 @@ func (h *receivedPacketHandler) ReceivedStopWaiting(f *frames.StopWaitingFrame)
}
h.ignorePacketsBelow = f.LeastUnacked - 1
h.garbageCollectReceivedTimes()
// the LeastUnacked is the smallest packet number of any packet for which the sender is still awaiting an ack. So the largestInOrderObserved is one less than that
if f.LeastUnacked > h.largestInOrderObserved {
h.largestInOrderObserved = f.LeastUnacked - 1
}
h.packetHistory.DeleteBelow(f.LeastUnacked)
return nil
}
func (h *receivedPacketHandler) GetAckFrame(dequeue bool) (*frames.AckFrame, error) {
if !h.stateChanged {
return nil, nil
func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) {
var ackAlarmSet bool
h.packetsReceivedSinceLastAck++
if shouldInstigateAck {
h.retransmittablePacketsReceivedSinceLastAck++
}
if dequeue {
h.stateChanged = false
// always ack the first packet
if h.lastAck == nil {
h.ackQueued = true
}
if h.currentAckFrame != nil {
return h.currentAckFrame, nil
// Always send an ack every 20 packets in order to allow the peer to discard
// information from the SentPacketManager and provide an RTT measurement.
if h.packetsReceivedSinceLastAck >= protocol.MaxPacketsReceivedBeforeAckSend {
h.ackQueued = true
}
packetReceivedTime, ok := h.receivedTimes[h.largestObserved]
if !ok {
return nil, ErrMapAccess
// if the packet number is smaller than the largest acked packet, it must have been reported missing with the last ACK
// note that it cannot be a duplicate because they're already filtered out by ReceivedPacket()
if h.lastAck != nil && packetNumber < h.lastAck.LargestAcked {
h.ackQueued = true
}
// check if a new missing range above the previously was created
if h.lastAck != nil && h.packetHistory.GetHighestAckRange().FirstPacketNumber > h.lastAck.LargestAcked {
h.ackQueued = true
}
if !h.ackQueued && shouldInstigateAck {
if h.retransmittablePacketsReceivedSinceLastAck >= protocol.RetransmittablePacketsBeforeAck {
h.ackQueued = true
} else {
if h.ackAlarm.IsZero() {
h.ackAlarm = time.Now().Add(h.ackSendDelay)
ackAlarmSet = true
}
}
}
if h.ackQueued {
// cancel the ack alarm
h.ackAlarm = time.Time{}
ackAlarmSet = false
}
if ackAlarmSet {
h.ackAlarmResetCallback(h.ackAlarm)
}
}
func (h *receivedPacketHandler) GetAckFrame() *frames.AckFrame {
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(time.Now())) {
return nil
}
ackRanges := h.packetHistory.GetAckRanges()
h.currentAckFrame = &frames.AckFrame{
ack := &frames.AckFrame{
LargestAcked: h.largestObserved,
LowestAcked: ackRanges[len(ackRanges)-1].FirstPacketNumber,
PacketReceivedTime: packetReceivedTime,
PacketReceivedTime: h.largestObservedReceivedTime,
}
if len(ackRanges) > 1 {
h.currentAckFrame.AckRanges = ackRanges
ack.AckRanges = ackRanges
}
return h.currentAckFrame, nil
}
h.lastAck = ack
h.ackAlarm = time.Time{}
h.ackQueued = false
h.packetsReceivedSinceLastAck = 0
h.retransmittablePacketsReceivedSinceLastAck = 0
func (h *receivedPacketHandler) garbageCollectReceivedTimes() {
for i := h.lowestInReceivedTimes; i <= h.ignorePacketsBelow; i++ {
delete(h.receivedTimes, i)
}
if h.ignorePacketsBelow > h.lowestInReceivedTimes {
h.lowestInReceivedTimes = h.ignorePacketsBelow + 1
}
return ack
}
package ackhandler
import (
"sync"
"github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
)
type receivedPacketHistory struct {
ranges *utils.PacketIntervalList
mutex sync.RWMutex
// the map is used as a replacement for a set here. The bool is always supposed to be set to true
receivedPacketNumbers map[protocol.PacketNumber]bool
lowestInReceivedPacketNumbers protocol.PacketNumber
}
var (
errTooManyOutstandingReceivedAckRanges = qerr.Error(qerr.TooManyOutstandingReceivedPackets, "Too many outstanding received ACK ranges")
errTooManyOutstandingReceivedPackets = qerr.Error(qerr.TooManyOutstandingReceivedPackets, "Too many outstanding received packets")
)
// newReceivedPacketHistory creates a new received packet history
func newReceivedPacketHistory() *receivedPacketHistory {
return &receivedPacketHistory{
ranges: utils.NewPacketIntervalList(),
ranges: utils.NewPacketIntervalList(),
receivedPacketNumbers: make(map[protocol.PacketNumber]bool),
}
}
// ReceivedPacket registers a packet with PacketNumber p and updates the ranges
func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) {
h.mutex.Lock()
defer h.mutex.Unlock()
func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) error {
if h.ranges.Len() >= protocol.MaxTrackedReceivedAckRanges {
return errTooManyOutstandingReceivedAckRanges
}
if len(h.receivedPacketNumbers) >= protocol.MaxTrackedReceivedPackets {
return errTooManyOutstandingReceivedPackets
}
h.receivedPacketNumbers[p] = true
if h.ranges.Len() == 0 {
h.ranges.PushBack(utils.PacketInterval{Start: p, End: p})
return
return nil
}
for el := h.ranges.Back(); el != nil; el = el.Prev() {
// p already included in an existing range. Nothing to do here
if p >= el.Value.Start && p <= el.Value.End {
return
return nil
}
var rangeExtended bool
......@@ -52,46 +66,61 @@ func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) {
if prev != nil && prev.Value.End+1 == el.Value.Start { // merge two ranges
prev.Value.End = el.Value.End
h.ranges.Remove(el)
return
return nil
}
return // if the two ranges were not merge, we're done here
return nil // if the two ranges were not merge, we're done here
}
// create a new range at the end
if p > el.Value.End {
h.ranges.InsertAfter(utils.PacketInterval{Start: p, End: p}, el)
return
return nil
}
}
// create a new range at the beginning
h.ranges.InsertBefore(utils.PacketInterval{Start: p, End: p}, h.ranges.Front())
return nil
}
// DeleteBelow deletes all entries below the leastUnacked packet number
func (h *receivedPacketHistory) DeleteBelow(leastUnacked protocol.PacketNumber) {
h.mutex.Lock()
defer h.mutex.Unlock()
h.lowestInReceivedPacketNumbers = utils.MaxPacketNumber(h.lowestInReceivedPacketNumbers, leastUnacked)
nextEl := h.ranges.Front()
for el := h.ranges.Front(); nextEl != nil; el = nextEl {
nextEl = el.Next()
if leastUnacked > el.Value.Start && leastUnacked <= el.Value.End {
for i := el.Value.Start; i < leastUnacked; i++ { // adjust start value of a range
delete(h.receivedPacketNumbers, i)
}
el.Value.Start = leastUnacked
}
if el.Value.End < leastUnacked { // delete a whole range
} else if el.Value.End < leastUnacked { // delete a whole range
for i := el.Value.Start; i <= el.Value.End; i++ {
delete(h.receivedPacketNumbers, i)
}
h.ranges.Remove(el)
} else {
} else { // no ranges affected. Nothing to do
return
}
}
}
// IsDuplicate determines if a packet should be regarded as a duplicate packet
// note that after receiving a StopWaitingFrame, all packets below the LeastUnacked should be regarded as duplicates, even if the packet was just delayed
func (h *receivedPacketHistory) IsDuplicate(p protocol.PacketNumber) bool {
if p < h.lowestInReceivedPacketNumbers {
return true
}
_, ok := h.receivedPacketNumbers[p]
return ok
}
// GetAckRanges gets a slice of all AckRanges that can be used in an AckFrame
func (h *receivedPacketHistory) GetAckRanges() []frames.AckRange {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.ranges.Len() == 0 {
return nil
}
......@@ -104,3 +133,13 @@ func (h *receivedPacketHistory) GetAckRanges() []frames.AckRange {
return ackRanges
}
func (h *receivedPacketHistory) GetHighestAckRange() frames.AckRange {
ackRange := frames.AckRange{}
if h.ranges.Len() > 0 {
r := h.ranges.Back().Value
ackRange.FirstPacketNumber = r.Start
ackRange.LastPacketNumber = r.End
}
return ackRange
}
......@@ -47,9 +47,7 @@ type sentPacketHandler struct {
}
// NewSentPacketHandler creates a new sentPacketHandler
func NewSentPacketHandler() SentPacketHandler {
rttStats := &congestion.RTTStats{}
func NewSentPacketHandler(rttStats *congestion.RTTStats) SentPacketHandler {
congestion := congestion.NewCubicSender(
congestion.DefaultClock{},
rttStats,
......
......@@ -13,8 +13,8 @@ clone_folder: c:\gopath\src\github.com\lucas-clemente\quic-go
install:
- rmdir c:\go /s /q
- appveyor DownloadFile https://storage.googleapis.com/golang/go1.7.1.windows-amd64.zip
- 7z x go1.7.1.windows-amd64.zip -y -oC:\ > NUL
- appveyor DownloadFile https://storage.googleapis.com/golang/go1.7.5.windows-amd64.zip
- 7z x go1.7.5.windows-amd64.zip -y -oC:\ > NUL
- set PATH=%PATH%;%GOPATH%\bin\windows_%GOARCH%;%GOPATH%\bin
- echo %PATH%
- echo %GOPATH%
......
package quic
import (
"bytes"
"crypto/tls"
"errors"
"net"
"strings"
"sync/atomic"
"time"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
)
// A Client of QUIC
type Client struct {
addr *net.UDPAddr
conn *net.UDPConn
hostname string
connectionID protocol.ConnectionID
version protocol.VersionNumber
versionNegotiated bool
closed uint32 // atomic bool
tlsConfig *tls.Config
cryptoChangeCallback CryptoChangeCallback
versionNegotiateCallback VersionNegotiateCallback
session packetHandler
}
// VersionNegotiateCallback is called once the client has a negotiated version
type VersionNegotiateCallback func() error
var errHostname = errors.New("Invalid hostname")
var (
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
)
// NewClient makes a new client
func NewClient(host string, tlsConfig *tls.Config, cryptoChangeCallback CryptoChangeCallback, versionNegotiateCallback VersionNegotiateCallback) (*Client, error) {
udpAddr, err := net.ResolveUDPAddr("udp", host)
if err != nil {
return nil, err
}
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return nil, err
}
connectionID, err := utils.GenerateConnectionID()
if err != nil {
return nil, err
}
hostname, _, err := net.SplitHostPort(host)
if err != nil {
return nil, err
}
client := &Client{
addr: udpAddr,
conn: conn,
hostname: hostname,
version: protocol.SupportedVersions[len(protocol.SupportedVersions)-1], // use the highest supported version by default
connectionID: connectionID,
tlsConfig: tlsConfig,
cryptoChangeCallback: cryptoChangeCallback,
versionNegotiateCallback: versionNegotiateCallback,
}
utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", host, udpAddr.String(), connectionID, client.version)
err = client.createNewSession(nil)
if err != nil {
return nil, err
}
return client, nil
}
// Listen listens
func (c *Client) Listen() error {
for {
data := getPacketBuffer()
data = data[:protocol.MaxPacketSize]
n, _, err := c.conn.ReadFromUDP(data)
if err != nil {
if strings.HasSuffix(err.Error(), "use of closed network connection") {
return nil
}
return err
}
data = data[:n]
err = c.handlePacket(data)
if err != nil {
utils.Errorf("error handling packet: %s", err.Error())
c.session.Close(err)
return err
}
}
}
// OpenStream opens a stream, for client-side created streams (i.e. odd streamIDs)
func (c *Client) OpenStream(id protocol.StreamID) (utils.Stream, error) {
return c.session.OpenStream(id)
}
// Close closes the connection
func (c *Client) Close(e error) error {
// Only close once
if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
return nil
}
_ = c.session.Close(e)
return c.conn.Close()
}
func (c *Client) handlePacket(packet []byte) error {
if protocol.ByteCount(len(packet)) > protocol.MaxPacketSize {
return qerr.PacketTooLarge
}
rcvTime := time.Now()
r := bytes.NewReader(packet)
hdr, err := ParsePublicHeader(r, protocol.PerspectiveServer)
if err != nil {
return qerr.Error(qerr.InvalidPacketHeader, err.Error())
}
hdr.Raw = packet[:len(packet)-r.Len()]
// ignore delayed / duplicated version negotiation packets
if c.versionNegotiated && hdr.VersionFlag {
return nil
}
// this is the first packet after the client sent a packet with the VersionFlag set
// if the server doesn't send a version negotiation packet, it supports the suggested version
if !hdr.VersionFlag && !c.versionNegotiated {
c.versionNegotiated = true
err = c.versionNegotiateCallback()
if err != nil {
return err
}
}
if hdr.VersionFlag {
var hasCommonVersion bool // check if we're supporting any of the offered versions
for _, v := range hdr.SupportedVersions {
// check if the server sent the offered version in supported versions
if v == c.version {
return qerr.Error(qerr.InvalidVersionNegotiationPacket, "Server already supports client's version and should have accepted the connection.")
}
if v != protocol.VersionUnsupported {
hasCommonVersion = true
}
}
if !hasCommonVersion {
utils.Infof("No common version found.")
return qerr.InvalidVersion
}
ok, highestSupportedVersion := protocol.HighestSupportedVersion(hdr.SupportedVersions)
if !ok {
return qerr.VersionNegotiationMismatch
}
utils.Infof("Switching to QUIC version %d", highestSupportedVersion)
c.version = highestSupportedVersion
c.versionNegotiated = true
c.session.Close(errCloseSessionForNewVersion)
err = c.createNewSession(hdr.SupportedVersions)
if err != nil {
return err
}
err = c.versionNegotiateCallback()
if err != nil {
return err
}
return nil // version negotiation packets have no payload
}
c.session.handlePacket(&receivedPacket{
remoteAddr: c.addr,
publicHeader: hdr,
data: packet[len(packet)-r.Len():],
rcvTime: rcvTime,
})
return nil
}
func (c *Client) createNewSession(negotiatedVersions []protocol.VersionNumber) error {
var err error
c.session, err = newClientSession(c.conn, c.addr, c.hostname, c.version, c.connectionID, c.tlsConfig, c.streamCallback, c.closeCallback, c.cryptoChangeCallback, negotiatedVersions)
if err != nil {
return err
}
go c.session.run()
return nil
}
func (c *Client) streamCallback(session *Session, stream utils.Stream) {}
func (c *Client) closeCallback(id protocol.ConnectionID) {
utils.Infof("Connection %x closed.", id)
}
......@@ -2,6 +2,8 @@ coverage:
round: nearest
ignore:
- ackhandler/packet_linkedlist.go
- h2quic/gzipreader.go
- h2quic/response.go
- utils/byteinterval_linkedlist.go
- utils/packetinterval_linkedlist.go
status:
......
package crypto
import (
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"errors"
"strings"
)
// A CertChain holds a certificate and a private key
type CertChain interface {
SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error)
GetCertsCompressed(sni string, commonSetHashes, cachedHashes []byte) ([]byte, error)
GetLeafCert(sni string) ([]byte, error)
}
// proofSource stores a key and a certificate for the server proof
type proofSource struct {
type certChain struct {
config *tls.Config
}
// NewProofSource loads the key and cert from files
func NewProofSource(tlsConfig *tls.Config) (Signer, error) {
return &proofSource{config: tlsConfig}, nil
var _ CertChain = &certChain{}
var errNoMatchingCertificate = errors.New("no matching certificate found")
// NewCertChain loads the key and cert from files
func NewCertChain(tlsConfig *tls.Config) CertChain {
return &certChain{config: tlsConfig}
}
// SignServerProof signs CHLO and server config for use in the server proof
func (ps *proofSource) SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) {
cert, err := ps.getCertForSNI(sni)
func (c *certChain) SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) {
cert, err := c.getCertForSNI(sni)
if err != nil {
return nil, err
}
hash := sha256.New()
hash.Write([]byte("QUIC CHLO and server config signature\x00"))
chloHash := sha256.Sum256(chlo)
hash.Write([]byte{32, 0, 0, 0})
hash.Write(chloHash[:])
hash.Write(serverConfigData)
key, ok := cert.PrivateKey.(crypto.Signer)
if !ok {
return nil, errors.New("expected PrivateKey to implement crypto.Signer")
}
opts := crypto.SignerOpts(crypto.SHA256)
if _, ok = key.(*rsa.PrivateKey); ok {
opts = &rsa.PSSOptions{SaltLength: 32, Hash: crypto.SHA256}
}
return key.Sign(rand.Reader, hash.Sum(nil), opts)
return signServerProof(cert, chlo, serverConfigData)
}
// GetCertsCompressed gets the certificate in the format described by the QUIC crypto doc
func (ps *proofSource) GetCertsCompressed(sni string, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
cert, err := ps.getCertForSNI(sni)
func (c *certChain) GetCertsCompressed(sni string, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
cert, err := c.getCertForSNI(sni)
if err != nil {
return nil, err
}
......@@ -58,17 +47,17 @@ func (ps *proofSource) GetCertsCompressed(sni string, pCommonSetHashes, pCachedH
}
// GetLeafCert gets the leaf certificate
func (ps *proofSource) GetLeafCert(sni string) ([]byte, error) {
cert, err := ps.getCertForSNI(sni)
func (c *certChain) GetLeafCert(sni string) ([]byte, error) {
cert, err := c.getCertForSNI(sni)
if err != nil {
return nil, err
}
return cert.Certificate[0], nil
}
func (ps *proofSource) getCertForSNI(sni string) (*tls.Certificate, error) {
if ps.config.GetCertificate != nil {
cert, err := ps.config.GetCertificate(&tls.ClientHelloInfo{ServerName: sni})
func (c *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
if c.config.GetCertificate != nil {
cert, err := c.config.GetCertificate(&tls.ClientHelloInfo{ServerName: sni})
if err != nil {
return nil, err
}
......@@ -76,17 +65,20 @@ func (ps *proofSource) getCertForSNI(sni string) (*tls.Certificate, error) {
return cert, nil
}
}
if len(ps.config.NameToCertificate) != 0 {
if cert, ok := ps.config.NameToCertificate[sni]; ok {
if len(c.config.NameToCertificate) != 0 {
if cert, ok := c.config.NameToCertificate[sni]; ok {
return cert, nil
}
wildcardSNI := "*" + strings.TrimLeftFunc(sni, func(r rune) bool { return r != '.' })
if cert, ok := ps.config.NameToCertificate[wildcardSNI]; ok {
if cert, ok := c.config.NameToCertificate[wildcardSNI]; ok {
return cert, nil
}
}
if len(ps.config.Certificates) != 0 {
return &ps.config.Certificates[0], nil
if len(c.config.Certificates) != 0 {
return &c.config.Certificates[0], nil
}
return nil, errors.New("no matching certificate found")
return nil, errNoMatchingCertificate
}
......@@ -22,8 +22,8 @@ const (
type entry struct {
t entryType
h uint64
i uint32
h uint64 // set hash
i uint32 // index
}
func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
......@@ -41,7 +41,7 @@ func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]by
chainHashes := make([]uint64, len(chain))
for i := range chain {
chainHashes[i] = hashCert(chain[i])
chainHashes[i] = HashCert(chain[i])
}
entries := buildEntries(chain, chainHashes, cachedHashes, setHashes)
......@@ -89,6 +89,111 @@ func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]by
return res.Bytes(), nil
}
func decompressChain(data []byte) ([][]byte, error) {
var chain [][]byte
var entries []entry
r := bytes.NewReader(data)
var numCerts int
var hasCompressedCerts bool
for {
entryTypeByte, err := r.ReadByte()
if entryTypeByte == 0 {
break
}
et := entryType(entryTypeByte)
if err != nil {
return nil, err
}
numCerts++
switch et {
case entryCached:
// we're not sending any certificate hashes in the CHLO, so there shouldn't be any cached certificates in the chain
return nil, errors.New("unexpected cached certificate")
case entryCommon:
e := entry{t: entryCommon}
e.h, err = utils.ReadUint64(r)
if err != nil {
return nil, err
}
e.i, err = utils.ReadUint32(r)
if err != nil {
return nil, err
}
certSet, ok := certSets[e.h]
if !ok {
return nil, errors.New("unknown certSet")
}
if e.i >= uint32(len(certSet)) {
return nil, errors.New("certificate not found in certSet")
}
entries = append(entries, e)
chain = append(chain, certSet[e.i])
case entryCompressed:
hasCompressedCerts = true
entries = append(entries, entry{t: entryCompressed})
chain = append(chain, nil)
default:
return nil, errors.New("unknown entryType")
}
}
if numCerts == 0 {
return make([][]byte, 0, 0), nil
}
if hasCompressedCerts {
uncompressedLength, err := utils.ReadUint32(r)
if err != nil {
fmt.Println(4)
return nil, err
}
zlibDict := buildZlibDictForEntries(entries, chain)
gz, err := zlib.NewReaderDict(r, zlibDict)
if err != nil {
return nil, err
}
defer gz.Close()
var totalLength uint32
var certIndex int
for totalLength < uncompressedLength {
lenBytes := make([]byte, 4)
_, err := gz.Read(lenBytes)
if err != nil {
return nil, err
}
certLen := binary.LittleEndian.Uint32(lenBytes)
cert := make([]byte, certLen)
n, err := gz.Read(cert)
if uint32(n) != certLen && err != nil {
return nil, err
}
for {
if certIndex >= len(entries) {
return nil, errors.New("CertCompression BUG: no element to save uncompressed certificate")
}
if entries[certIndex].t == entryCompressed {
chain[certIndex] = cert
certIndex++
break
}
certIndex++
}
totalLength += 4 + certLen
}
}
return chain, nil
}
func buildEntries(chain [][]byte, chainHashes, cachedHashes, setHashes []uint64) []entry {
res := make([]entry, len(chain))
chainLoop:
......@@ -149,8 +254,19 @@ func splitHashes(hashes []byte) ([]uint64, error) {
return res, nil
}
func hashCert(cert []byte) uint64 {
h := fnv.New64()
func getCommonCertificateHashes() []byte {
ccs := make([]byte, 8*len(certSets), 8*len(certSets))
i := 0
for certSetHash := range certSets {
binary.LittleEndian.PutUint64(ccs[i*8:(i+1)*8], certSetHash)
i++
}
return ccs
}
// HashCert calculates the FNV1a hash of a certificate
func HashCert(cert []byte) uint64 {
h := fnv.New64a()
h.Write(cert)
return h.Sum64()
}
package crypto
import (
"crypto/tls"
"crypto/x509"
"errors"
"hash/fnv"
"time"
"github.com/lucas-clemente/quic-go/qerr"
)
// CertManager manages the certificates sent by the server
type CertManager interface {
SetData([]byte) error
GetCommonCertificateHashes() []byte
GetLeafCert() []byte
GetLeafCertHash() (uint64, error)
VerifyServerProof(proof, chlo, serverConfigData []byte) bool
Verify(hostname string) error
}
type certManager struct {
chain []*x509.Certificate
config *tls.Config
}
var _ CertManager = &certManager{}
var errNoCertificateChain = errors.New("CertManager BUG: No certicifate chain loaded")
// NewCertManager creates a new CertManager
func NewCertManager(tlsConfig *tls.Config) CertManager {
return &certManager{config: tlsConfig}
}
// SetData takes the byte-slice sent in the SHLO and decompresses it into the certificate chain
func (c *certManager) SetData(data []byte) error {
byteChain, err := decompressChain(data)
if err != nil {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid")
}
chain := make([]*x509.Certificate, len(byteChain), len(byteChain))
for i, data := range byteChain {
cert, err := x509.ParseCertificate(data)
if err != nil {
return err
}
chain[i] = cert
}
c.chain = chain
return nil
}
func (c *certManager) GetCommonCertificateHashes() []byte {
return getCommonCertificateHashes()
}
// GetLeafCert returns the leaf certificate of the certificate chain
// it returns nil if the certificate chain has not yet been set
func (c *certManager) GetLeafCert() []byte {
if len(c.chain) == 0 {
return nil
}
return c.chain[0].Raw
}
// GetLeafCertHash calculates the FNV1a_64 hash of the leaf certificate
func (c *certManager) GetLeafCertHash() (uint64, error) {
leafCert := c.GetLeafCert()
if leafCert == nil {
return 0, errNoCertificateChain
}
h := fnv.New64a()
_, err := h.Write(leafCert)
if err != nil {
return 0, err
}
return h.Sum64(), nil
}
// VerifyServerProof verifies the signature of the server config
// it should only be called after the certificate chain has been set, otherwise it returns false
func (c *certManager) VerifyServerProof(proof, chlo, serverConfigData []byte) bool {
if len(c.chain) == 0 {
return false
}
return verifyServerProof(proof, c.chain[0], chlo, serverConfigData)
}
// Verify verifies the certificate chain
func (c *certManager) Verify(hostname string) error {
if len(c.chain) == 0 {
return errNoCertificateChain
}
if c.config != nil && c.config.InsecureSkipVerify {
return nil
}
leafCert := c.chain[0]
var opts x509.VerifyOptions
if c.config != nil {
opts.Roots = c.config.RootCAs
opts.DNSName = c.config.ServerName
if c.config.Time == nil {
opts.CurrentTime = time.Now()
} else {
opts.CurrentTime = c.config.Time()
}
} else {
opts.DNSName = hostname
}
// the first certificate is the leaf certificate, all others are intermediates
if len(c.chain) > 1 {
intermediates := x509.NewCertPool()
for i := 1; i < len(c.chain); i++ {
intermediates.AddCert(c.chain[i])
}
opts.Intermediates = intermediates
}
_, err := leafCert.Verify(opts)
return err
}
// +build ignore
package crypto
import (
"crypto/rand"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Chacha20poly1305", func() {
var (
alice, bob AEAD
keyAlice, keyBob, ivAlice, ivBob []byte
)
BeforeEach(func() {
keyAlice = make([]byte, 32)
keyBob = make([]byte, 32)
ivAlice = make([]byte, 4)
ivBob = make([]byte, 4)
rand.Reader.Read(keyAlice)
rand.Reader.Read(keyBob)
rand.Reader.Read(ivAlice)
rand.Reader.Read(ivBob)
var err error
alice, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice)
Expect(err).ToNot(HaveOccurred())
bob, err = NewAEADChacha20Poly1305(keyAlice, keyBob, ivAlice, ivBob)
Expect(err).ToNot(HaveOccurred())
})
It("seals and opens", func() {
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
text, err := bob.Open(nil, b, 42, []byte("aad"))
Expect(err).ToNot(HaveOccurred())
Expect(text).To(Equal([]byte("foobar")))
})
It("seals and opens reverse", func() {
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
text, err := alice.Open(nil, b, 42, []byte("aad"))
Expect(err).ToNot(HaveOccurred())
Expect(text).To(Equal([]byte("foobar")))
})
It("has the proper length", func() {
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
Expect(b).To(HaveLen(6 + 12))
})
It("fails with wrong aad", func() {
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
_, err := bob.Open(nil, b, 42, []byte("aad2"))
Expect(err).To(HaveOccurred())
})
It("rejects wrong key and iv sizes", func() {
var err error
e := "chacha20poly1305: expected 32-byte keys and 4-byte IVs"
_, err = NewAEADChacha20Poly1305(keyBob[1:], keyAlice, ivBob, ivAlice)
Expect(err).To(MatchError(e))
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice[1:], ivBob, ivAlice)
Expect(err).To(MatchError(e))
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob[1:], ivAlice)
Expect(err).To(MatchError(e))
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice[1:])
Expect(err).To(MatchError(e))
})
})
......@@ -21,15 +21,21 @@ import (
// }
// DeriveKeysAESGCM derives the client and server keys and creates a matching AES-GCM AEAD instance
func DeriveKeysAESGCM(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte) (AEAD, error) {
otherKey, myKey, otherIV, myIV, err := deriveKeys(forwardSecure, sharedSecret, nonces, connID, chlo, scfg, cert, divNonce, 16)
func DeriveKeysAESGCM(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (AEAD, error) {
var swap bool
if pers == protocol.PerspectiveClient {
swap = true
}
otherKey, myKey, otherIV, myIV, err := deriveKeys(forwardSecure, sharedSecret, nonces, connID, chlo, scfg, cert, divNonce, 16, swap)
if err != nil {
return nil, err
}
return NewAEADAESGCM(otherKey, myKey, otherIV, myIV)
}
func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo, scfg, cert, divNonce []byte, keyLen int) ([]byte, []byte, []byte, []byte, error) {
// deriveKeys derives the keys and the IVs
// swap should be set true if generating the values for the client, and false for the server
func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo, scfg, cert, divNonce []byte, keyLen int, swap bool) ([]byte, []byte, []byte, []byte, error) {
var info bytes.Buffer
if forwardSecure {
info.Write([]byte("QUIC forward secure key expansion\x00"))
......@@ -47,17 +53,33 @@ func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol
if _, err := io.ReadFull(r, s); err != nil {
return nil, nil, nil, nil, err
}
otherKey := s[:keyLen]
myKey := s[keyLen : 2*keyLen]
otherIV := s[2*keyLen : 2*keyLen+4]
myIV := s[2*keyLen+4:]
key1 := s[:keyLen]
key2 := s[keyLen : 2*keyLen]
iv1 := s[2*keyLen : 2*keyLen+4]
iv2 := s[2*keyLen+4:]
var otherKey, myKey []byte
var otherIV, myIV []byte
if !forwardSecure {
if err := diversify(myKey, myIV, divNonce); err != nil {
if err := diversify(key2, iv2, divNonce); err != nil {
return nil, nil, nil, nil, err
}
}
if swap {
otherKey = key2
myKey = key1
otherIV = iv2
myIV = iv1
} else {
otherKey = key1
myKey = key2
otherIV = iv1
myIV = iv2
}
return otherKey, myKey, otherIV, myIV, nil
}
......
package crypto
import (
"crypto"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/asn1"
"errors"
"math/big"
)
type ecdsaSignature struct {
R, S *big.Int
}
// signServerProof signs CHLO and server config for use in the server proof
func signServerProof(cert *tls.Certificate, chlo []byte, serverConfigData []byte) ([]byte, error) {
hash := sha256.New()
hash.Write([]byte("QUIC CHLO and server config signature\x00"))
chloHash := sha256.Sum256(chlo)
hash.Write([]byte{32, 0, 0, 0})
hash.Write(chloHash[:])
hash.Write(serverConfigData)
key, ok := cert.PrivateKey.(crypto.Signer)
if !ok {
return nil, errors.New("expected PrivateKey to implement crypto.Signer")
}
opts := crypto.SignerOpts(crypto.SHA256)
if _, ok = key.(*rsa.PrivateKey); ok {
opts = &rsa.PSSOptions{SaltLength: 32, Hash: crypto.SHA256}
}
return key.Sign(rand.Reader, hash.Sum(nil), opts)
}
// verifyServerProof verifies the server proof signature
func verifyServerProof(proof []byte, cert *x509.Certificate, chlo []byte, serverConfigData []byte) bool {
hash := sha256.New()
hash.Write([]byte("QUIC CHLO and server config signature\x00"))
chloHash := sha256.Sum256(chlo)
hash.Write([]byte{32, 0, 0, 0})
hash.Write(chloHash[:])
hash.Write(serverConfigData)
// RSA
if cert.PublicKeyAlgorithm == x509.RSA {
opts := &rsa.PSSOptions{SaltLength: 32, Hash: crypto.SHA256}
err := rsa.VerifyPSS(cert.PublicKey.(*rsa.PublicKey), crypto.SHA256, hash.Sum(nil), proof, opts)
return err == nil
}
// ECDSA
signature := &ecdsaSignature{}
rest, err := asn1.Unmarshal(proof, signature)
if err != nil || len(rest) != 0 {
return false
}
return ecdsa.Verify(cert.PublicKey.(*ecdsa.PublicKey), hash.Sum(nil), signature.R, signature.S)
}
package crypto
// A Signer holds a certificate and a private key
type Signer interface {
SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error)
GetCertsCompressed(sni string, commonSetHashes, cachedHashes []byte) ([]byte, error)
GetLeafCert(sni string) ([]byte, error)
}
......@@ -2,38 +2,39 @@ package flowcontrol
import (
"errors"
"fmt"
"sync"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
)
type flowControlManager struct {
connectionParametersManager *handshake.ConnectionParametersManager
connectionParameters handshake.ConnectionParametersManager
rttStats *congestion.RTTStats
streamFlowController map[protocol.StreamID]*flowController
contributesToConnectionFlowControl map[protocol.StreamID]bool
mutex sync.RWMutex
}
var (
// ErrStreamFlowControlViolation is a stream flow control violation
ErrStreamFlowControlViolation = errors.New("Stream level flow control violation")
// ErrConnectionFlowControlViolation is a connection level flow control violation
ErrConnectionFlowControlViolation = errors.New("Connection level flow control violation")
)
var _ FlowControlManager = &flowControlManager{}
var errMapAccess = errors.New("Error accessing the flowController map.")
// NewFlowControlManager creates a new flow control manager
func NewFlowControlManager(connectionParametersManager *handshake.ConnectionParametersManager) FlowControlManager {
func NewFlowControlManager(connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) FlowControlManager {
fcm := flowControlManager{
connectionParametersManager: connectionParametersManager,
connectionParameters: connectionParameters,
rttStats: rttStats,
streamFlowController: make(map[protocol.StreamID]*flowController),
contributesToConnectionFlowControl: make(map[protocol.StreamID]bool),
}
// initialize connection level flow controller
fcm.streamFlowController[0] = newFlowController(0, connectionParametersManager)
fcm.streamFlowController[0] = newFlowController(0, connectionParameters, rttStats)
fcm.contributesToConnectionFlowControl[0] = false
return &fcm
}
......@@ -47,7 +48,7 @@ func (f *flowControlManager) NewStream(streamID protocol.StreamID, contributesTo
return
}
f.streamFlowController[streamID] = newFlowController(streamID, f.connectionParametersManager)
f.streamFlowController[streamID] = newFlowController(streamID, f.connectionParameters, f.rttStats)
f.contributesToConnectionFlowControl[streamID] = contributesToConnectionFlow
}
......@@ -59,6 +60,48 @@ func (f *flowControlManager) RemoveStream(streamID protocol.StreamID) {
f.mutex.Unlock()
}
// ResetStream should be called when receiving a RstStreamFrame
// it updates the byte offset to the value in the RstStreamFrame
// streamID must not be 0 here
func (f *flowControlManager) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error {
f.mutex.Lock()
defer f.mutex.Unlock()
streamFlowController, err := f.getFlowController(streamID)
if err != nil {
return err
}
increment, err := streamFlowController.UpdateHighestReceived(byteOffset)
if err != nil {
return qerr.StreamDataAfterTermination
}
if streamFlowController.CheckFlowControlViolation() {
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveFlowControlWindow))
}
if f.contributesToConnectionFlowControl[streamID] {
connectionFlowController := f.streamFlowController[0]
connectionFlowController.IncrementHighestReceived(increment)
if connectionFlowController.CheckFlowControlViolation() {
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", byteOffset, connectionFlowController.receiveFlowControlWindow))
}
}
return nil
}
func (f *flowControlManager) GetBytesSent(streamID protocol.StreamID) (protocol.ByteCount, error) {
f.mutex.Lock()
defer f.mutex.Unlock()
fc, err := f.getFlowController(streamID)
if err != nil {
return 0, err
}
return fc.GetBytesSent(), nil
}
// UpdateHighestReceived updates the highest received byte offset for a stream
// it adds the number of additional bytes to connection level flow control
// streamID must not be 0 here
......@@ -70,17 +113,19 @@ func (f *flowControlManager) UpdateHighestReceived(streamID protocol.StreamID, b
if err != nil {
return err
}
increment := streamFlowController.UpdateHighestReceived(byteOffset)
// UpdateHighestReceived returns an ErrReceivedSmallerByteOffset when StreamFrames got reordered
// this error can be ignored here
increment, _ := streamFlowController.UpdateHighestReceived(byteOffset)
if streamFlowController.CheckFlowControlViolation() {
return ErrStreamFlowControlViolation
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveFlowControlWindow))
}
if f.contributesToConnectionFlowControl[streamID] {
connectionFlowController := f.streamFlowController[0]
connectionFlowController.IncrementHighestReceived(increment)
if connectionFlowController.CheckFlowControlViolation() {
return ErrConnectionFlowControlViolation
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", byteOffset, connectionFlowController.receiveFlowControlWindow))
}
}
......@@ -117,6 +162,16 @@ func (f *flowControlManager) GetWindowUpdates() (res []WindowUpdate) {
return res
}
func (f *flowControlManager) GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error) {
f.mutex.Lock()
defer f.mutex.Unlock()
flowController, err := f.getFlowController(streamID)
if err != nil {
return 0, err
}
return flowController.receiveFlowControlWindow, nil
}
// streamID must not be 0 here
func (f *flowControlManager) AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error {
// Only lock the part reading from the map, since send-windows are only accessed from the session goroutine.
......
package flowcontrol
import (
"errors"
"time"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
)
type flowController struct {
streamID protocol.StreamID
connectionParametersManager *handshake.ConnectionParametersManager
connectionParameters handshake.ConnectionParametersManager
rttStats *congestion.RTTStats
bytesSent protocol.ByteCount
sendFlowControlWindow protocol.ByteCount
bytesRead protocol.ByteCount
highestReceived protocol.ByteCount
receiveFlowControlWindow protocol.ByteCount
receiveFlowControlWindowIncrement protocol.ByteCount
lastWindowUpdateTime time.Time
bytesRead protocol.ByteCount
highestReceived protocol.ByteCount
receiveFlowControlWindow protocol.ByteCount
receiveFlowControlWindowIncrement protocol.ByteCount
maxReceiveFlowControlWindowIncrement protocol.ByteCount
}
// ErrReceivedSmallerByteOffset occurs if the ByteOffset received is smaller than a ByteOffset that was set previously
var ErrReceivedSmallerByteOffset = errors.New("Received a smaller byte offset")
// newFlowController gets a new flow controller
func newFlowController(streamID protocol.StreamID, connectionParametersManager *handshake.ConnectionParametersManager) *flowController {
func newFlowController(streamID protocol.StreamID, connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) *flowController {
fc := flowController{
streamID: streamID,
connectionParametersManager: connectionParametersManager,
streamID: streamID,
connectionParameters: connectionParameters,
rttStats: rttStats,
}
if streamID == 0 {
fc.receiveFlowControlWindow = connectionParametersManager.GetReceiveConnectionFlowControlWindow()
fc.receiveFlowControlWindow = connectionParameters.GetReceiveConnectionFlowControlWindow()
fc.receiveFlowControlWindowIncrement = fc.receiveFlowControlWindow
fc.maxReceiveFlowControlWindowIncrement = connectionParameters.GetMaxReceiveConnectionFlowControlWindow()
} else {
fc.receiveFlowControlWindow = connectionParametersManager.GetReceiveStreamFlowControlWindow()
fc.receiveFlowControlWindow = connectionParameters.GetReceiveStreamFlowControlWindow()
fc.receiveFlowControlWindowIncrement = fc.receiveFlowControlWindow
fc.maxReceiveFlowControlWindowIncrement = connectionParameters.GetMaxReceiveStreamFlowControlWindow()
}
return &fc
......@@ -40,9 +55,9 @@ func newFlowController(streamID protocol.StreamID, connectionParametersManager *
func (c *flowController) getSendFlowControlWindow() protocol.ByteCount {
if c.sendFlowControlWindow == 0 {
if c.streamID == 0 {
return c.connectionParametersManager.GetSendConnectionFlowControlWindow()
return c.connectionParameters.GetSendConnectionFlowControlWindow()
}
return c.connectionParametersManager.GetSendStreamFlowControlWindow()
return c.connectionParameters.GetSendStreamFlowControlWindow()
}
return c.sendFlowControlWindow
}
......@@ -51,6 +66,10 @@ func (c *flowController) AddBytesSent(n protocol.ByteCount) {
c.bytesSent += n
}
func (c *flowController) GetBytesSent() protocol.ByteCount {
return c.bytesSent
}
// UpdateSendWindow should be called after receiving a WindowUpdateFrame
// it returns true if the window was actually updated
func (c *flowController) UpdateSendWindow(newOffset protocol.ByteCount) bool {
......@@ -76,13 +95,19 @@ func (c *flowController) SendWindowOffset() protocol.ByteCount {
// UpdateHighestReceived updates the highestReceived value, if the byteOffset is higher
// Should **only** be used for the stream-level FlowController
func (c *flowController) UpdateHighestReceived(byteOffset protocol.ByteCount) protocol.ByteCount {
// it returns an ErrReceivedSmallerByteOffset if the received byteOffset is smaller than any byteOffset received before
// This error occurs every time StreamFrames get reordered and has to be ignored in that case
// It should only be treated as an error when resetting a stream
func (c *flowController) UpdateHighestReceived(byteOffset protocol.ByteCount) (protocol.ByteCount, error) {
if byteOffset == c.highestReceived {
return 0, nil
}
if byteOffset > c.highestReceived {
increment := byteOffset - c.highestReceived
c.highestReceived = byteOffset
return increment
return increment, nil
}
return 0
return 0, ErrReceivedSmallerByteOffset
}
// IncrementHighestReceived adds an increment to the highestReceived value
......@@ -99,14 +124,52 @@ func (c *flowController) AddBytesRead(n protocol.ByteCount) {
// if so, it returns true and the offset of the window
func (c *flowController) MaybeTriggerWindowUpdate() (bool, protocol.ByteCount) {
diff := c.receiveFlowControlWindow - c.bytesRead
// Chromium implements the same threshold
if diff < (c.receiveFlowControlWindowIncrement / 2) {
c.maybeAdjustWindowIncrement()
c.lastWindowUpdateTime = time.Now()
c.receiveFlowControlWindow = c.bytesRead + c.receiveFlowControlWindowIncrement
return true, c.receiveFlowControlWindow
}
return false, 0
}
// maybeAdjustWindowIncrement increases the receiveFlowControlWindowIncrement if we're sending WindowUpdates too often
func (c *flowController) maybeAdjustWindowIncrement() {
if c.lastWindowUpdateTime.IsZero() {
return
}
rtt := c.rttStats.SmoothedRTT()
if rtt == 0 {
return
}
timeSinceLastWindowUpdate := time.Now().Sub(c.lastWindowUpdateTime)
// interval between the window updates is sufficiently large, no need to increase the increment
if timeSinceLastWindowUpdate >= 2*rtt {
return
}
oldWindowSize := c.receiveFlowControlWindowIncrement
c.receiveFlowControlWindowIncrement = utils.MinByteCount(2*c.receiveFlowControlWindowIncrement, c.maxReceiveFlowControlWindowIncrement)
// debug log, if the window size was actually increased
if oldWindowSize < c.receiveFlowControlWindowIncrement {
newWindowSize := c.receiveFlowControlWindowIncrement / (1 << 10)
if c.streamID == 0 {
utils.Debugf("Increasing receive flow control window for the connection to %d kB", newWindowSize)
} else {
utils.Debugf("Increasing receive flow control window increment for stream %d to %d kB", c.streamID, newWindowSize)
}
}
}
func (c *flowController) CheckFlowControlViolation() bool {
if c.highestReceived > c.receiveFlowControlWindow {
return true
......
......@@ -13,9 +13,11 @@ type FlowControlManager interface {
NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool)
RemoveStream(streamID protocol.StreamID)
// methods needed for receiving data
ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error
UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error
AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error
GetWindowUpdates() []WindowUpdate
GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error)
// methods needed for sending data
AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error
SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error)
......
......@@ -27,8 +27,10 @@ type AckFrame struct {
LowestAcked protocol.PacketNumber
AckRanges []AckRange // has to be ordered. The ACK range with the highest FirstPacketNumber goes first, the ACK range with the lowest FirstPacketNumber goes last
// time when the LargestAcked was receiveid
// this field Will not be set for received ACKs frames
PacketReceivedTime time.Time
DelayTime time.Duration
PacketReceivedTime time.Time // only for received packets. Will not be modified for received ACKs frames
}
// ParseAckFrame reads an ACK frame
......@@ -83,7 +85,7 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame,
if err != nil {
return nil, err
}
if ackBlockLength < 1 {
if frame.LargestAcked > 0 && ackBlockLength < 1 {
return nil, ErrInvalidFirstAckRange
}
......@@ -141,7 +143,11 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame,
frame.LowestAcked = frame.AckRanges[len(frame.AckRanges)-1].FirstPacketNumber
} else {
frame.LowestAcked = protocol.PacketNumber(largestAcked + 1 - ackBlockLength)
if frame.LargestAcked == 0 {
frame.LowestAcked = 0
} else {
frame.LowestAcked = protocol.PacketNumber(largestAcked + 1 - ackBlockLength)
}
}
if !frame.validateAckRanges() {
......
......@@ -11,9 +11,18 @@ func LogFrame(frame Frame, sent bool) {
if sent {
dir = "->"
}
if sf, ok := frame.(*StreamFrame); ok {
utils.Debugf("\t%s &frames.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, sf.StreamID, sf.FinBit, sf.Offset, sf.DataLen(), sf.Offset+sf.DataLen())
return
switch f := frame.(type) {
case *StreamFrame:
utils.Debugf("\t%s &frames.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, f.StreamID, f.FinBit, f.Offset, f.DataLen(), f.Offset+f.DataLen())
case *StopWaitingFrame:
if sent {
utils.Debugf("\t%s &frames.StopWaitingFrame{LeastUnacked: 0x%x, PacketNumberLen: 0x%x}", dir, f.LeastUnacked, f.PacketNumberLen)
} else {
utils.Debugf("\t%s &frames.StopWaitingFrame{LeastUnacked: 0x%x}", dir, f.LeastUnacked)
}
case *AckFrame:
utils.Debugf("\t%s &frames.AckFrame{LargestAcked: 0x%x, LowestAcked: 0x%x, AckRanges: %#v, DelayTime: %s}", dir, f.LargestAcked, f.LowestAcked, f.AckRanges, f.DelayTime.String())
default:
utils.Debugf("\t%s %#v", dir, frame)
}
utils.Debugf("\t%s %#v", dir, frame)
}
package h2quic
import (
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"golang.org/x/net/idna"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
)
type quicClient interface {
OpenStream(protocol.StreamID) (utils.Stream, error)
Close(error) error
Listen() error
}
// Client is a HTTP2 client doing QUIC requests
type Client struct {
mutex sync.RWMutex
cryptoChangedCond sync.Cond
t *QuicRoundTripper
hostname string
encryptionLevel protocol.EncryptionLevel
client quicClient
headerStream utils.Stream
headerErr *qerr.QuicError
highestOpenedStream protocol.StreamID
requestWriter *requestWriter
responses map[protocol.StreamID]chan *http.Response
}
var _ h2quicClient = &Client{}
// NewClient creates a new client
func NewClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) (*Client, error) {
c := &Client{
t: t,
hostname: authorityAddr("https", hostname),
highestOpenedStream: 3,
responses: make(map[protocol.StreamID]chan *http.Response),
}
c.cryptoChangedCond = sync.Cond{L: &c.mutex}
var err error
c.client, err = quic.NewClient(c.hostname, tlsConfig, c.cryptoChangeCallback, c.versionNegotiateCallback)
if err != nil {
return nil, err
}
go c.client.Listen()
return c, nil
}
func (c *Client) handleStreamCb(session *quic.Session, stream utils.Stream) {
utils.Debugf("Handling stream %d", stream.StreamID())
}
func (c *Client) cryptoChangeCallback(isForwardSecure bool) {
c.cryptoChangedCond.L.Lock()
defer c.cryptoChangedCond.L.Unlock()
if isForwardSecure {
c.encryptionLevel = protocol.EncryptionForwardSecure
utils.Debugf("is forward secure")
} else {
c.encryptionLevel = protocol.EncryptionSecure
utils.Debugf("is secure")
}
c.cryptoChangedCond.Broadcast()
}
func (c *Client) versionNegotiateCallback() error {
var err error
// once the version has been negotiated, open the header stream
c.headerStream, err = c.client.OpenStream(3)
if err != nil {
return err
}
c.requestWriter = newRequestWriter(c.headerStream)
go c.handleHeaderStream()
return nil
}
func (c *Client) handleHeaderStream() {
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
h2framer := http2.NewFramer(nil, c.headerStream)
var lastStream protocol.StreamID
for {
frame, err := h2framer.ReadFrame()
if err != nil {
c.headerErr = qerr.Error(qerr.InvalidStreamData, "cannot read frame")
break
}
lastStream = protocol.StreamID(frame.Header().StreamID)
hframe, ok := frame.(*http2.HeadersFrame)
if !ok {
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame")
break
}
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
if err != nil {
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields")
break
}
c.mutex.RLock()
headerChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
c.mutex.RUnlock()
if !ok {
c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream))
break
}
rsp, err := responseFromHeaders(mhframe)
if err != nil {
c.headerErr = qerr.Error(qerr.InternalError, err.Error())
}
headerChan <- rsp
}
// stop all running request
utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error())
c.mutex.Lock()
for _, responseChan := range c.responses {
responseChan <- nil
}
c.mutex.Unlock()
}
// Do executes a request and returns a response
func (c *Client) Do(req *http.Request) (*http.Response, error) {
// TODO: add port to address, if it doesn't have one
if req.URL.Scheme != "https" {
return nil, errors.New("quic http2: unsupported scheme")
}
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
utils.Debugf("%s vs %s", req.Host, c.hostname)
return nil, errors.New("h2quic Client BUG: Do called for the wrong client")
}
hasBody := (req.Body != nil)
c.mutex.Lock()
c.highestOpenedStream += 2
dataStreamID := c.highestOpenedStream
for c.encryptionLevel != protocol.EncryptionForwardSecure {
c.cryptoChangedCond.Wait()
}
hdrChan := make(chan *http.Response)
c.responses[dataStreamID] = hdrChan
c.mutex.Unlock()
// TODO: think about what to do with a TooManyOpenStreams error. Wait and retry?
dataStream, err := c.client.OpenStream(dataStreamID)
if err != nil {
c.Close(err)
return nil, err
}
var requestedGzip bool
if !c.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
requestedGzip = true
}
// TODO: add support for trailers
endStream := !hasBody
err = c.requestWriter.WriteRequest(req, dataStreamID, endStream, requestedGzip)
if err != nil {
c.Close(err)
return nil, err
}
resc := make(chan error, 1)
if hasBody {
go func() {
resc <- c.writeRequestBody(dataStream, req.Body)
}()
}
var res *http.Response
var receivedResponse bool
var bodySent bool
if !hasBody {
bodySent = true
}
for !(bodySent && receivedResponse) {
select {
case res = <-hdrChan:
receivedResponse = true
c.mutex.Lock()
delete(c.responses, dataStreamID)
c.mutex.Unlock()
if res == nil { // an error occured on the header stream
c.Close(c.headerErr)
return nil, c.headerErr
}
case err := <-resc:
bodySent = true
if err != nil {
return nil, err
}
}
}
// TODO: correctly set this variable
var streamEnded bool
isHead := (req.Method == "HEAD")
res = setLength(res, isHead, streamEnded)
if streamEnded || isHead {
res.Body = noBody
} else {
res.Body = dataStream
if requestedGzip && res.Header.Get("Content-Encoding") == "gzip" {
res.Header.Del("Content-Encoding")
res.Header.Del("Content-Length")
res.ContentLength = -1
res.Body = &gzipReader{body: res.Body}
setUncompressed(res)
}
}
res.Request = req
return res, nil
}
func (c *Client) writeRequestBody(dataStream utils.Stream, body io.ReadCloser) (err error) {
defer func() {
cerr := body.Close()
if err == nil {
// TODO: what to do with dataStream here? Maybe reset it?
err = cerr
}
}()
_, err = io.Copy(dataStream, body)
if err != nil {
// TODO: what to do with dataStream here? Maybe reset it?
return err
}
return dataStream.Close()
}
// Close closes the client
func (c *Client) Close(e error) {
_ = c.client.Close(e)
}
// copied from net/transport.go
// authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
// and returns a host:port. The port 443 is added if needed.
func authorityAddr(scheme string, authority string) (addr string) {
host, port, err := net.SplitHostPort(authority)
if err != nil { // authority didn't have a port
port = "443"
if scheme == "http" {
port = "80"
}
host = authority
}
if a, err := idna.ToASCII(host); err == nil {
host = a
}
// IPv6 address literal, without a port:
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
return host + ":" + port
}
return net.JoinHostPort(host, port)
}
package h2quic
// copied from net/transport.go
// gzipReader wraps a response body so it can lazily
// call gzip.NewReader on the first call to Read
import (
"compress/gzip"
"io"
)
// call gzip.NewReader on the first call to Read
type gzipReader struct {
body io.ReadCloser // underlying Response.Body
zr *gzip.Reader // lazily-initialized gzip reader
zerr error // sticky error
}
func (gz *gzipReader) Read(p []byte) (n int, err error) {
if gz.zerr != nil {
return 0, gz.zerr
}
if gz.zr == nil {
gz.zr, err = gzip.NewReader(gz.body)
if err != nil {
gz.zerr = err
return 0, err
}
}
return gz.zr.Read(p)
}
func (gz *gzipReader) Close() error {
return gz.body.Close()
}
package h2quic
import (
"crypto/tls"
"errors"
"net/http"
"net/url"
"strconv"
"strings"
"golang.org/x/net/http2/hpack"
)
func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) {
var path, authority, method string
var path, authority, method, contentLengthStr string
httpHeaders := http.Header{}
for _, h := range headers {
......@@ -20,6 +23,8 @@ func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) {
method = h.Value
case ":authority":
authority = h.Value
case "content-length":
contentLengthStr = h.Value
default:
if !h.IsPseudo() {
httpHeaders.Add(h.Name, h.Value)
......@@ -27,6 +32,11 @@ func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) {
}
}
// concatenate cookie headers, see https://tools.ietf.org/html/rfc6265#section-5.4
if len(httpHeaders["Cookie"]) > 0 {
httpHeaders.Set("Cookie", strings.Join(httpHeaders["Cookie"], "; "))
}
if len(path) == 0 || len(authority) == 0 || len(method) == 0 {
return nil, errors.New(":path, :authority and :method must not be empty")
}
......@@ -36,16 +46,35 @@ func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) {
return nil, err
}
var contentLength int64
if len(contentLengthStr) > 0 {
contentLength, err = strconv.ParseInt(contentLengthStr, 10, 64)
if err != nil {
return nil, err
}
}
return &http.Request{
Method: method,
URL: u,
Proto: "HTTP/2.0",
ProtoMajor: 2,
ProtoMinor: 0,
Header: httpHeaders,
Body: nil,
// ContentLength: -1,
Host: authority,
RequestURI: path,
Method: method,
URL: u,
Proto: "HTTP/2.0",
ProtoMajor: 2,
ProtoMinor: 0,
Header: httpHeaders,
Body: nil,
ContentLength: contentLength,
Host: authority,
RequestURI: path,
TLS: &tls.ConnectionState{},
}, nil
}
func hostnameFromRequest(req *http.Request) string {
if len(req.Host) > 0 {
return req.Host
}
if req.URL != nil {
return req.URL.Host
}
return ""
}
package h2quic
import (
"io"
"github.com/lucas-clemente/quic-go/utils"
)
type requestBody struct {
requestRead bool
dataStream utils.Stream
}
// make sure the requestBody can be used as a http.Request.Body
var _ io.ReadCloser = &requestBody{}
func newRequestBody(stream utils.Stream) *requestBody {
return &requestBody{dataStream: stream}
}
func (b *requestBody) Read(p []byte) (int, error) {
b.requestRead = true
return b.dataStream.Read(p)
}
func (b *requestBody) Close() error {
// stream's Close() closes the write side, not the read side
return nil
}
package h2quic
import (
"bytes"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"golang.org/x/net/lex/httplex"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
)
type requestWriter struct {
mutex sync.Mutex
headerStream utils.Stream
henc *hpack.Encoder
hbuf bytes.Buffer // HPACK encoder writes into this
}
const defaultUserAgent = "quic-go"
func newRequestWriter(headerStream utils.Stream) *requestWriter {
rw := &requestWriter{
headerStream: headerStream,
}
rw.henc = hpack.NewEncoder(&rw.hbuf)
return rw
}
func (w *requestWriter) WriteRequest(req *http.Request, dataStreamID protocol.StreamID, endStream, requestGzip bool) error {
// TODO: add support for trailers
// TODO: add support for gzip compression
// TODO: write continuation frames, if the header frame is too long
w.mutex.Lock()
defer w.mutex.Unlock()
w.encodeHeaders(req, requestGzip, "", actualContentLength(req))
h2framer := http2.NewFramer(w.headerStream, nil)
return h2framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: uint32(dataStreamID),
EndHeaders: true,
EndStream: endStream,
BlockFragment: w.hbuf.Bytes(),
Priority: http2.PriorityParam{Weight: 0xff},
})
}
// the rest of this files is copied from http2.Transport
func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) {
w.hbuf.Reset()
host := req.Host
if host == "" {
host = req.URL.Host
}
host, err := httplex.PunycodeHostPort(host)
if err != nil {
return nil, err
}
var path string
if req.Method != "CONNECT" {
path = req.URL.RequestURI()
if !validPseudoPath(path) {
orig := path
path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
if !validPseudoPath(path) {
if req.URL.Opaque != "" {
return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
} else {
return nil, fmt.Errorf("invalid request :path %q", orig)
}
}
}
}
// Check for any invalid headers and return an error before we
// potentially pollute our hpack state. (We want to be able to
// continue to reuse the hpack encoder for future requests)
for k, vv := range req.Header {
if !httplex.ValidHeaderFieldName(k) {
return nil, fmt.Errorf("invalid HTTP header name %q", k)
}
for _, v := range vv {
if !httplex.ValidHeaderFieldValue(v) {
return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k)
}
}
}
// 8.1.2.3 Request Pseudo-Header Fields
// The :path pseudo-header field includes the path and query parts of the
// target URI (the path-absolute production and optionally a '?' character
// followed by the query production (see Sections 3.3 and 3.4 of
// [RFC3986]).
w.writeHeader(":authority", host)
w.writeHeader(":method", req.Method)
if req.Method != "CONNECT" {
w.writeHeader(":path", path)
w.writeHeader(":scheme", req.URL.Scheme)
}
if trailers != "" {
w.writeHeader("trailer", trailers)
}
var didUA bool
for k, vv := range req.Header {
lowKey := strings.ToLower(k)
switch lowKey {
case "host", "content-length":
// Host is :authority, already sent.
// Content-Length is automatic, set below.
continue
case "connection", "proxy-connection", "transfer-encoding", "upgrade", "keep-alive":
// Per 8.1.2.2 Connection-Specific Header
// Fields, don't send connection-specific
// fields. We have already checked if any
// are error-worthy so just ignore the rest.
continue
case "user-agent":
// Match Go's http1 behavior: at most one
// User-Agent. If set to nil or empty string,
// then omit it. Otherwise if not mentioned,
// include the default (below).
didUA = true
if len(vv) < 1 {
continue
}
vv = vv[:1]
if vv[0] == "" {
continue
}
}
for _, v := range vv {
w.writeHeader(lowKey, v)
}
}
if shouldSendReqContentLength(req.Method, contentLength) {
w.writeHeader("content-length", strconv.FormatInt(contentLength, 10))
}
if addGzipHeader {
w.writeHeader("accept-encoding", "gzip")
}
if !didUA {
w.writeHeader("user-agent", defaultUserAgent)
}
return w.hbuf.Bytes(), nil
}
func (w *requestWriter) writeHeader(name, value string) {
utils.Debugf("http2: Transport encoding header %q = %q", name, value)
w.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
}
// shouldSendReqContentLength reports whether the http2.Transport should send
// a "content-length" request header. This logic is basically a copy of the net/http
// transferWriter.shouldSendContentLength.
// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
// -1 means unknown.
func shouldSendReqContentLength(method string, contentLength int64) bool {
if contentLength > 0 {
return true
}
if contentLength < 0 {
return false
}
// For zero bodies, whether we send a content-length depends on the method.
// It also kinda doesn't matter for http2 either way, with END_STREAM.
switch method {
case "POST", "PUT", "PATCH":
return true
default:
return false
}
}
func validPseudoPath(v string) bool {
return (len(v) > 0 && v[0] == '/' && (len(v) == 1 || v[1] != '/')) || v == "*"
}
// actualContentLength returns a sanitized version of
// req.ContentLength, where 0 actually means zero (not unknown) and -1
// means unknown.
func actualContentLength(req *http.Request) int64 {
if req.Body == nil {
return 0
}
if req.ContentLength != 0 {
return req.ContentLength
}
return -1
}
package h2quic
import (
"bytes"
"errors"
"io"
"io/ioutil"
"net/http"
"net/textproto"
"strconv"
"strings"
"golang.org/x/net/http2"
)
// copied from net/http2/transport.go
var errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit")
var noBody io.ReadCloser = ioutil.NopCloser(bytes.NewReader(nil))
// from the handleResponse function
func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) {
if f.Truncated {
return nil, errResponseHeaderListSize
}
status := f.PseudoValue("status")
if status == "" {
return nil, errors.New("missing status pseudo header")
}
statusCode, err := strconv.Atoi(status)
if err != nil {
return nil, errors.New("malformed non-numeric status pseudo header")
}
if statusCode == 100 {
// TODO: handle this
// traceGot100Continue(cs.trace)
// if cs.on100 != nil {
// cs.on100() // forces any write delay timer to fire
// }
// cs.pastHeaders = false // do it all again
// return nil, nil
}
header := make(http.Header)
res := &http.Response{
Proto: "HTTP/2.0",
ProtoMajor: 2,
Header: header,
StatusCode: statusCode,
Status: status + " " + http.StatusText(statusCode),
}
for _, hf := range f.RegularFields() {
key := http.CanonicalHeaderKey(hf.Name)
if key == "Trailer" {
t := res.Trailer
if t == nil {
t = make(http.Header)
res.Trailer = t
}
foreachHeaderElement(hf.Value, func(v string) {
t[http.CanonicalHeaderKey(v)] = nil
})
} else {
header[key] = append(header[key], hf.Value)
}
}
return res, nil
}
// continuation of the handleResponse function
func setLength(res *http.Response, isHead, streamEnded bool) *http.Response {
if !streamEnded || isHead {
res.ContentLength = -1
if clens := res.Header["Content-Length"]; len(clens) == 1 {
if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
res.ContentLength = clen64
} else {
// TODO: care? unlike http/1, it won't mess up our framing, so it's
// more safe smuggling-wise to ignore.
}
} else if len(clens) > 1 {
// TODO: care? unlike http/1, it won't mess up our framing, so it's
// more safe smuggling-wise to ignore.
}
}
return res
}
// copied from net/http/server.go
// foreachHeaderElement splits v according to the "#rule" construction
// in RFC 2616 section 2.1 and calls fn for each non-empty element.
func foreachHeaderElement(v string, fn func(string)) {
v = textproto.TrimString(v)
if v == "" {
return
}
if !strings.Contains(v, ",") {
fn(v)
return
}
for _, f := range strings.Split(v, ",") {
if f = textproto.TrimString(f); f != "" {
fn(f)
}
}
}
// +build go1.7
package h2quic
import "net/http"
func setUncompressed(res *http.Response) {
res.Uncompressed = true
}
// +build !go1.7
package h2quic
import "net/http"
func setUncompressed(res *http.Response) {
// http.Response.Uncompressed was introduced in go 1.7
}
......@@ -21,6 +21,7 @@ type responseWriter struct {
headerStreamMutex *sync.Mutex
header http.Header
status int // status code passed to WriteHeader
headerWritten bool
}
......@@ -43,6 +44,7 @@ func (w *responseWriter) WriteHeader(status int) {
return
}
w.headerWritten = true
w.status = status
var headers bytes.Buffer
enc := hpack.NewEncoder(&headers)
......@@ -72,6 +74,9 @@ func (w *responseWriter) Write(p []byte) (int, error) {
if !w.headerWritten {
w.WriteHeader(200)
}
if !bodyAllowedForStatus(w.status) {
return 0, http.ErrBodyNotAllowed
}
return w.dataStream.Write(p)
}
......@@ -79,3 +84,18 @@ func (w *responseWriter) Flush() {}
// test that we implement http.Flusher
var _ http.Flusher = &responseWriter{}
// copied from http2/http2.go
// bodyAllowedForStatus reports whether a given response status code
// permits a body. See RFC 2616, section 4.4.
func bodyAllowedForStatus(status int) bool {
switch {
case status >= 100 && status <= 199:
return false
case status == 204:
return false
case status == 304:
return false
}
return true
}
package h2quic
import (
"crypto/tls"
"errors"
"fmt"
"net/http"
"strings"
"sync"
"golang.org/x/net/lex/httplex"
)
type h2quicClient interface {
Do(*http.Request) (*http.Response, error)
}
// QuicRoundTripper implements the http.RoundTripper interface
type QuicRoundTripper struct {
mutex sync.Mutex
// DisableCompression, if true, prevents the Transport from
// requesting compression with an "Accept-Encoding: gzip"
// request header when the Request contains no existing
// Accept-Encoding value. If the Transport requests gzip on
// its own and gets a gzipped response, it's transparently
// decoded in the Response.Body. However, if the user
// explicitly requested gzip it is not automatically
// uncompressed.
DisableCompression bool
// TLSClientConfig specifies the TLS configuration to use with
// tls.Client. If nil, the default configuration is used.
TLSClientConfig *tls.Config
clients map[string]h2quicClient
}
var _ http.RoundTripper = &QuicRoundTripper{}
// RoundTrip does a round trip
func (r *QuicRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if req.URL == nil {
closeRequestBody(req)
return nil, errors.New("quic: nil Request.URL")
}
if req.URL.Host == "" {
closeRequestBody(req)
return nil, errors.New("quic: no Host in request URL")
}
if req.Header == nil {
closeRequestBody(req)
return nil, errors.New("quic: nil Request.Header")
}
if req.URL.Scheme == "https" {
for k, vv := range req.Header {
if !httplex.ValidHeaderFieldName(k) {
return nil, fmt.Errorf("quic: invalid http header field name %q", k)
}
for _, v := range vv {
if !httplex.ValidHeaderFieldValue(v) {
return nil, fmt.Errorf("quic: invalid http header field value %q for key %v", v, k)
}
}
}
} else {
closeRequestBody(req)
return nil, fmt.Errorf("quic: unsupported protocol scheme: %s", req.URL.Scheme)
}
if req.Method != "" && !validMethod(req.Method) {
closeRequestBody(req)
return nil, fmt.Errorf("quic: invalid method %q", req.Method)
}
hostname := authorityAddr("https", hostnameFromRequest(req))
client, err := r.getClient(hostname)
if err != nil {
return nil, err
}
return client.Do(req)
}
func (r *QuicRoundTripper) getClient(hostname string) (h2quicClient, error) {
r.mutex.Lock()
defer r.mutex.Unlock()
if r.clients == nil {
r.clients = make(map[string]h2quicClient)
}
client, ok := r.clients[hostname]
if !ok {
var err error
client, err = NewClient(r, r.TLSClientConfig, hostname)
if err != nil {
return nil, err
}
r.clients[hostname] = client
}
return client, nil
}
func (r *QuicRoundTripper) disableCompression() bool {
return r.DisableCompression
}
func closeRequestBody(req *http.Request) {
if req.Body != nil {
req.Body.Close()
}
}
func validMethod(method string) bool {
/*
Method = "OPTIONS" ; Section 9.2
| "GET" ; Section 9.3
| "HEAD" ; Section 9.4
| "POST" ; Section 9.5
| "PUT" ; Section 9.6
| "DELETE" ; Section 9.7
| "TRACE" ; Section 9.8
| "CONNECT" ; Section 9.9
| extension-method
extension-method = token
token = 1*<any CHAR except CTLs or separators>
*/
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
}
// copied from net/http/http.go
func isNotToken(r rune) bool {
return !httplex.IsTokenRune(r)
}
......@@ -4,7 +4,6 @@ import (
"crypto/tls"
"errors"
"fmt"
"io/ioutil"
"net"
"net/http"
"runtime"
......@@ -113,6 +112,7 @@ func (s *Server) handleStream(session streamCreator, stream utils.Stream) {
if _, ok := err.(*qerr.QuicError); !ok {
utils.Errorf("error handling h2 request: %s", err.Error())
}
session.Close(qerr.Error(qerr.InvalidHeadersStreamData, err.Error()))
return
}
}
......@@ -124,7 +124,10 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream,
if err != nil {
return err
}
h2headersFrame := h2frame.(*http2.HeadersFrame)
h2headersFrame, ok := h2frame.(*http2.HeadersFrame)
if !ok {
return qerr.Error(qerr.InvalidHeadersStreamData, "expected a header frame")
}
if !h2headersFrame.HeadersEnded() {
return errors.New("http2 header continuation not implemented")
}
......@@ -152,13 +155,15 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream,
return err
}
var streamEnded bool
if h2headersFrame.StreamEnded() {
dataStream.CloseRemote(0)
streamEnded = true
_, _ = dataStream.Read([]byte{0}) // read the eof
}
// stream's Close() closes the write side, not the read side
req.Body = ioutil.NopCloser(dataStream)
reqBody := newRequestBody(dataStream)
req.Body = reqBody
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID))
......@@ -187,6 +192,9 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream,
responseWriter.WriteHeader(200)
}
if responseWriter.dataStream != nil {
if !streamEnded && !reqBody.requestRead {
responseWriter.dataStream.Reset(nil)
}
responseWriter.dataStream.Close()
}
if s.CloseAfterFirstRequest {
......
package handshake
import "github.com/lucas-clemente/quic-go/protocol"
// CryptoSetup is a crypto setup
type CryptoSetup interface {
HandleCryptoStream() error
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
LockForSealing()
UnlockForSealing()
HandshakeComplete() bool
// TODO: clean up this interface
DiversificationNonce() []byte // only needed for cryptoSetupServer
SetDiversificationNonce([]byte) error // only needed for cryptoSetupClient
}
......@@ -95,12 +95,19 @@ func WriteHandshakeMessage(b *bytes.Buffer, messageTag Tag, data map[Tag][]byte)
func printHandshakeMessage(data map[Tag][]byte) string {
var res string
var pad string
for k, v := range data {
if k == TagPAD {
continue
pad = fmt.Sprintf("\t%s: (%d bytes)\n", tagToString(k), len(v))
} else {
res += fmt.Sprintf("\t%s: %#v\n", tagToString(k), string(v))
}
res += fmt.Sprintf("\t%s: %#v\n", tagToString(k), string(v))
}
if len(pad) > 0 {
res += pad
}
return res
}
......
......@@ -10,13 +10,14 @@ import (
// ServerConfig is a server config
type ServerConfig struct {
kex crypto.KeyExchange
signer crypto.Signer
certChain crypto.CertChain
ID []byte
obit []byte
stkSource crypto.StkSource
}
// NewServerConfig creates a new server config
func NewServerConfig(kex crypto.KeyExchange, signer crypto.Signer) (*ServerConfig, error) {
func NewServerConfig(kex crypto.KeyExchange, certChain crypto.CertChain) (*ServerConfig, error) {
id := make([]byte, 16)
_, err := rand.Read(id)
if err != nil {
......@@ -27,6 +28,12 @@ func NewServerConfig(kex crypto.KeyExchange, signer crypto.Signer) (*ServerConfi
if _, err = rand.Read(stkSecret); err != nil {
return nil, err
}
obit := make([]byte, 8)
if _, err = rand.Read(obit); err != nil {
return nil, err
}
stkSource, err := crypto.NewStkSource(stkSecret)
if err != nil {
return nil, err
......@@ -34,8 +41,9 @@ func NewServerConfig(kex crypto.KeyExchange, signer crypto.Signer) (*ServerConfi
return &ServerConfig{
kex: kex,
signer: signer,
certChain: certChain,
ID: id,
obit: obit,
stkSource: stkSource,
}, nil
}
......@@ -48,7 +56,7 @@ func (s *ServerConfig) Get() []byte {
TagKEXS: []byte("C255"),
TagAEAD: []byte("AESG"),
TagPUBS: append([]byte{0x20, 0x00, 0x00}, s.kex.PublicKey()...),
TagOBIT: {0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7},
TagOBIT: s.obit,
TagEXPY: {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
})
return serverConfig.Bytes()
......@@ -56,10 +64,10 @@ func (s *ServerConfig) Get() []byte {
// Sign the server config and CHLO with the server's keyData
func (s *ServerConfig) Sign(sni string, chlo []byte) ([]byte, error) {
return s.signer.SignServerProof(sni, chlo, s.Get())
return s.certChain.SignServerProof(sni, chlo, s.Get())
}
// GetCertsCompressed returns the certificate data
func (s *ServerConfig) GetCertsCompressed(sni string, commonSetHashes, compressedHashes []byte) ([]byte, error) {
return s.signer.GetCertsCompressed(sni, commonSetHashes, compressedHashes)
return s.certChain.GetCertsCompressed(sni, commonSetHashes, compressedHashes)
}
package handshake
import (
"bytes"
"encoding/binary"
"errors"
"math"
"time"
"github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
)
type serverConfigClient struct {
raw []byte
ID []byte
obit []byte
expiry time.Time
kex crypto.KeyExchange
sharedSecret []byte
}
var (
errMessageNotServerConfig = errors.New("ServerConfig must have TagSCFG")
)
// parseServerConfig parses a server config
func parseServerConfig(data []byte) (*serverConfigClient, error) {
tag, tagMap, err := ParseHandshakeMessage(bytes.NewReader(data))
if err != nil {
return nil, err
}
if tag != TagSCFG {
return nil, errMessageNotServerConfig
}
scfg := &serverConfigClient{raw: data}
err = scfg.parseValues(tagMap)
if err != nil {
return nil, err
}
return scfg, nil
}
func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error {
// SCID
scfgID, ok := tagMap[TagSCID]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "SCID")
}
if len(scfgID) != 16 {
return qerr.Error(qerr.CryptoInvalidValueLength, "SCID")
}
s.ID = scfgID
// KEXS
// TODO: allow for P256 in the list
// TODO: setup Key Exchange
kexs, ok := tagMap[TagKEXS]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS")
}
if len(kexs)%4 != 0 {
return qerr.Error(qerr.CryptoInvalidValueLength, "KEXS")
}
if !bytes.Equal(kexs, []byte("C255")) {
return qerr.Error(qerr.CryptoNoSupport, "KEXS")
}
// AEAD
aead, ok := tagMap[TagAEAD]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "AEAD")
}
if len(aead)%4 != 0 {
return qerr.Error(qerr.CryptoInvalidValueLength, "AEAD")
}
var aesgFound bool
for i := 0; i < len(aead)/4; i++ {
if bytes.Equal(aead[4*i:4*i+4], []byte("AESG")) {
aesgFound = true
break
}
}
if !aesgFound {
return qerr.Error(qerr.CryptoNoSupport, "AEAD")
}
// PUBS
// TODO: save this value
pubs, ok := tagMap[TagPUBS]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
}
if len(pubs) != 35 {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
}
var err error
s.kex, err = crypto.NewCurve25519KEX()
if err != nil {
return err
}
// the PUBS value is always prepended by []byte{0x20, 0x00, 0x00}
s.sharedSecret, err = s.kex.CalculateSharedKey(pubs[3:])
if err != nil {
return err
}
// OBIT
obit, ok := tagMap[TagOBIT]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "OBIT")
}
if len(obit) != 8 {
return qerr.Error(qerr.CryptoInvalidValueLength, "OBIT")
}
s.obit = obit
// EXPY
expy, ok := tagMap[TagEXPY]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "EXPY")
}
if len(expy) != 8 {
return qerr.Error(qerr.CryptoInvalidValueLength, "EXPY")
}
// make sure that the value doesn't overflow an int64
// furthermore, values close to MaxInt64 are not a valid input to time.Unix, thus set MaxInt64/2 as the maximum value here
expyTimestamp := utils.MinUint64(binary.LittleEndian.Uint64(expy), math.MaxInt64/2)
s.expiry = time.Unix(int64(expyTimestamp), 0)
// TODO: implement VER
return nil
}
func (s *serverConfigClient) IsExpired() bool {
return s.expiry.Before(time.Now())
}
func (s *serverConfigClient) Get() []byte {
return s.raw
}
......@@ -59,6 +59,8 @@ const (
// TagNONC is the client nonce
TagNONC Tag = 'N' + 'O'<<8 + 'N'<<16 + 'C'<<24
// TagXLCT is the expected leaf certificate
TagXLCT Tag = 'X' + 'L'<<8 + 'C'<<16 + 'T'<<24
// TagSCID is the server config ID
TagSCID Tag = 'S' + 'C'<<8 + 'I'<<16 + 'D'<<24
......
......@@ -18,58 +18,72 @@ type packedPacket struct {
type packetPacker struct {
connectionID protocol.ConnectionID
perspective protocol.Perspective
version protocol.VersionNumber
cryptoSetup *handshake.CryptoSetup
cryptoSetup handshake.CryptoSetup
packetNumberGenerator *packetNumberGenerator
connectionParametersManager *handshake.ConnectionParametersManager
connectionParameters handshake.ConnectionParametersManager
streamFramer *streamFramer
controlFrames []frames.Frame
}
func newPacketPacker(connectionID protocol.ConnectionID, cryptoSetup *handshake.CryptoSetup, connectionParametersHandler *handshake.ConnectionParametersManager, streamFramer *streamFramer, version protocol.VersionNumber) *packetPacker {
func newPacketPacker(connectionID protocol.ConnectionID, cryptoSetup handshake.CryptoSetup, connectionParameters handshake.ConnectionParametersManager, streamFramer *streamFramer, perspective protocol.Perspective, version protocol.VersionNumber) *packetPacker {
return &packetPacker{
cryptoSetup: cryptoSetup,
connectionID: connectionID,
connectionParametersManager: connectionParametersHandler,
version: version,
streamFramer: streamFramer,
packetNumberGenerator: newPacketNumberGenerator(protocol.SkipPacketAveragePeriodLength),
cryptoSetup: cryptoSetup,
connectionID: connectionID,
connectionParameters: connectionParameters,
perspective: perspective,
version: version,
streamFramer: streamFramer,
packetNumberGenerator: newPacketNumberGenerator(protocol.SkipPacketAveragePeriodLength),
}
}
func (p *packetPacker) PackConnectionClose(frame *frames.ConnectionCloseFrame, leastUnacked protocol.PacketNumber) (*packedPacket, error) {
return p.packPacket(nil, []frames.Frame{frame}, leastUnacked, true, false)
// PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame
func (p *packetPacker) PackConnectionClose(ccf *frames.ConnectionCloseFrame, leastUnacked protocol.PacketNumber) (*packedPacket, error) {
// in case the connection is closed, all queued control frames aren't of any use anymore
// discard them and queue the ConnectionCloseFrame
p.controlFrames = []frames.Frame{ccf}
return p.packPacket(nil, leastUnacked)
}
func (p *packetPacker) PackPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, leastUnacked protocol.PacketNumber, maySendOnlyAck bool) (*packedPacket, error) {
return p.packPacket(stopWaitingFrame, controlFrames, leastUnacked, false, maySendOnlyAck)
// PackPacket packs a new packet
// the stopWaitingFrame is *guaranteed* to be included in the next packet
// the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise
func (p *packetPacker) PackPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, leastUnacked protocol.PacketNumber) (*packedPacket, error) {
p.controlFrames = append(p.controlFrames, controlFrames...)
return p.packPacket(stopWaitingFrame, leastUnacked)
}
func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, leastUnacked protocol.PacketNumber, onlySendOneControlFrame, maySendOnlyAck bool) (*packedPacket, error) {
if len(controlFrames) > 0 {
p.controlFrames = append(p.controlFrames, controlFrames...)
}
currentPacketNumber := p.packetNumberGenerator.Peek()
func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, leastUnacked protocol.PacketNumber) (*packedPacket, error) {
// cryptoSetup needs to be locked here, so that the AEADs are not changed between
// calling DiversificationNonce() and Seal().
p.cryptoSetup.LockForSealing()
defer p.cryptoSetup.UnlockForSealing()
currentPacketNumber := p.packetNumberGenerator.Peek()
packetNumberLen := protocol.GetPacketNumberLengthForPublicHeader(currentPacketNumber, leastUnacked)
responsePublicHeader := &PublicHeader{
ConnectionID: p.connectionID,
PacketNumber: currentPacketNumber,
PacketNumberLen: packetNumberLen,
TruncateConnectionID: p.connectionParametersManager.TruncateConnectionID(),
DiversificationNonce: p.cryptoSetup.DiversificationNonce(),
TruncateConnectionID: p.connectionParameters.TruncateConnectionID(),
}
if p.perspective == protocol.PerspectiveServer {
responsePublicHeader.DiversificationNonce = p.cryptoSetup.DiversificationNonce()
}
publicHeaderLength, err := responsePublicHeader.GetLength()
// TODO: stop sending version numbers once a version has been negotiated
if p.perspective == protocol.PerspectiveClient {
responsePublicHeader.VersionFlag = true
responsePublicHeader.VersionNumber = p.version
}
publicHeaderLength, err := responsePublicHeader.GetLength(p.perspective)
if err != nil {
return nil, err
}
......@@ -79,9 +93,15 @@ func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, con
stopWaitingFrame.PacketNumberLen = packetNumberLen
}
// we're packing a ConnectionClose, don't add any StreamFrames
var isConnectionClose bool
if len(p.controlFrames) == 1 {
_, isConnectionClose = p.controlFrames[0].(*frames.ConnectionCloseFrame)
}
var payloadFrames []frames.Frame
if onlySendOneControlFrame {
payloadFrames = []frames.Frame{controlFrames[0]}
if isConnectionClose {
payloadFrames = []frames.Frame{p.controlFrames[0]}
} else {
payloadFrames, err = p.composeNextPacket(stopWaitingFrame, publicHeaderLength)
if err != nil {
......@@ -94,26 +114,14 @@ func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, con
return nil, nil
}
// Don't send out packets that only contain a StopWaitingFrame
if !onlySendOneControlFrame && len(payloadFrames) == 1 && stopWaitingFrame != nil {
if len(payloadFrames) == 1 && stopWaitingFrame != nil {
return nil, nil
}
// Don't send out packets that only contain an ACK (plus optional STOP_WAITING), if requested
if !maySendOnlyAck {
if len(payloadFrames) == 1 {
if _, ok := payloadFrames[0].(*frames.AckFrame); ok {
return nil, nil
}
} else if len(payloadFrames) == 2 && stopWaitingFrame != nil {
if _, ok := payloadFrames[1].(*frames.AckFrame); ok {
return nil, nil
}
}
}
raw := getPacketBuffer()
buffer := bytes.NewBuffer(raw)
if err = responsePublicHeader.WritePublicHeader(buffer, p.version); err != nil {
if err = responsePublicHeader.Write(buffer, p.version, p.perspective); err != nil {
return nil, err
}
......
......@@ -11,10 +11,6 @@ import (
"github.com/lucas-clemente/quic-go/qerr"
)
type unpackedPacket struct {
frames []frames.Frame
}
type packetUnpacker struct {
version protocol.VersionNumber
aead crypto.AEAD
......
package protocol
// EncryptionLevel is the encryption level
// Default value is Unencrypted
type EncryptionLevel int
const (
// Unencrypted is not encrypted
Unencrypted EncryptionLevel = iota
// EncryptionSecure is encrypted, but not forward secure
EncryptionSecure
// EncryptionForwardSecure is forward secure
EncryptionForwardSecure
)
package protocol
// Perspective determines if we're acting as a server or a client
type Perspective int
// the perspectives
const (
PerspectiveServer Perspective = 1
PerspectiveClient Perspective = 2
)
......@@ -64,3 +64,9 @@ const MaxRetransmissionTime = 60 * time.Second
// ClientHelloMinimumSize is the minimum size the server expects an inchoate CHLO to have.
const ClientHelloMinimumSize = 1024
// MaxClientHellos is the maximum number of times we'll send a client hello
// The value 3 accounts for:
// * one failure due to an incorrect or missing source-address token
// * one failure due the server's certificate chain being unavailible and the server being unwilling to send it without a valid source-address token
const MaxClientHellos = 3
......@@ -3,31 +3,48 @@ package protocol
import "time"
// DefaultMaxCongestionWindow is the default for the max congestion window
const DefaultMaxCongestionWindow PacketNumber = 1000
const DefaultMaxCongestionWindow = 1000
// InitialCongestionWindow is the initial congestion window in QUIC packets
const InitialCongestionWindow PacketNumber = 32
const InitialCongestionWindow = 32
// MaxUndecryptablePackets limits the number of undecryptable packets that a
// session queues for later until it sends a public reset.
const MaxUndecryptablePackets = 10
// AckSendDelay is the maximal time delay applied to packets containing only ACKs
const AckSendDelay = 5 * time.Millisecond
// AckSendDelay is the maximum delay that can be applied to an ACK for a retransmittable packet
// This is the value Chromium is using
const AckSendDelay = 25 * time.Millisecond
// ReceiveStreamFlowControlWindow is the stream-level flow control window for receiving data
// This is the value that Google servers are using
const ReceiveStreamFlowControlWindow ByteCount = (1 << 20) // 1 MB
const ReceiveStreamFlowControlWindow ByteCount = (1 << 10) * 32 // 32 kB
// ReceiveConnectionFlowControlWindow is the stream-level flow control window for receiving data
// ReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data
// This is the value that Google servers are using
const ReceiveConnectionFlowControlWindow ByteCount = (1 << 20) * 1.5 // 1.5 MB
const ReceiveConnectionFlowControlWindow ByteCount = (1 << 10) * 48 // 48 kB
// MaxReceiveStreamFlowControlWindowServer is the maximum stream-level flow control window for receiving data
// This is the value that Google servers are using
const MaxReceiveStreamFlowControlWindowServer ByteCount = 1 * (1 << 20) // 1 MB
// MaxReceiveConnectionFlowControlWindowServer is the connection-level flow control window for receiving data
// This is the value that Google servers are using
const MaxReceiveConnectionFlowControlWindowServer ByteCount = 1.5 * (1 << 20) // 1.5 MB
// MaxReceiveStreamFlowControlWindowClient is the maximum stream-level flow control window for receiving data, for the client
// This is the value that Chromium is using
const MaxReceiveStreamFlowControlWindowClient ByteCount = 6 * (1 << 20) // 6 MB
// MaxReceiveConnectionFlowControlWindowClient is the connection-level flow control window for receiving data, for the server
// This is the value that Google servers are using
const MaxReceiveConnectionFlowControlWindowClient ByteCount = 15 * (1 << 20) // 15 MB
// MaxStreamsPerConnection is the maximum value accepted for the number of streams per connection
const MaxStreamsPerConnection = 100
// MaxIncomingDynamicStreams is the maximum value accepted for the incoming number of dynamic streams per connection
const MaxIncomingDynamicStreams = 100
// MaxIncomingDynamicStreamsPerConnection is the maximum value accepted for the incoming number of dynamic streams per connection
const MaxIncomingDynamicStreamsPerConnection = 100
// MaxStreamsMultiplier is the slack the client is allowed for the maximum number of streams per connection, needed e.g. when packets are out of order or dropped. The minimum of this procentual increase and the absolute increment specified by MaxStreamsMinimumIncrement is used.
const MaxStreamsMultiplier = 1.1
......@@ -60,8 +77,17 @@ const MaxTrackedSentPackets = 2 * DefaultMaxCongestionWindow
// MaxTrackedReceivedPackets is the maximum number of received packets saved for doing the entropy calculations
const MaxTrackedReceivedPackets = 2 * DefaultMaxCongestionWindow
// MaxTrackedReceivedAckRanges is the maximum number of ACK ranges tracked
const MaxTrackedReceivedAckRanges = DefaultMaxCongestionWindow
// MaxPacketsReceivedBeforeAckSend is the number of packets that can be received before an ACK frame is sent
const MaxPacketsReceivedBeforeAckSend = 20
// RetransmittablePacketsBeforeAck is the number of retransmittable that an ACK is sent for
const RetransmittablePacketsBeforeAck = 2
// MaxStreamFrameSorterGaps is the maximum number of gaps between received StreamFrames
// prevents DOS attacks against the streamFrameSorter
// prevents DoS attacks against the streamFrameSorter
const MaxStreamFrameSorterGaps = 1000
// CryptoMaxParams is the upper limit for the number of parameters in a crypto message.
......@@ -69,7 +95,7 @@ const MaxStreamFrameSorterGaps = 1000
const CryptoMaxParams = 128
// CryptoParameterMaxLength is the upper limit for the length of a parameter in a crypto message.
const CryptoParameterMaxLength = 2000
const CryptoParameterMaxLength = 4000
// EphermalKeyLifetime is the lifetime of the ephermal key during the handshake, see handshake.getEphermalKEX.
const EphermalKeyLifetime = time.Minute
......@@ -77,14 +103,21 @@ const EphermalKeyLifetime = time.Minute
// InitialIdleTimeout is the timeout before the handshake succeeds.
const InitialIdleTimeout = 5 * time.Second
// DefaultIdleTimeout is the default idle timeout.
// DefaultIdleTimeout is the default idle timeout, for the server
const DefaultIdleTimeout = 30 * time.Second
// MaxIdleTimeout is the maximum idle timeout that can be negotiated.
const MaxIdleTimeout = 1 * time.Minute
// MaxIdleTimeoutServer is the maximum idle timeout that can be negotiated, for the server
const MaxIdleTimeoutServer = 1 * time.Minute
// MaxIdleTimeoutClient is the idle timeout that the client suggests to the server
const MaxIdleTimeoutClient = 2 * time.Minute
// MaxTimeForCryptoHandshake is the default timeout for a connection until the crypto handshake succeeds.
const MaxTimeForCryptoHandshake = 10 * time.Second
// ClosedSessionDeleteTimeout the server ignores packets arriving on a connection that is already closed
// after this time all information about the old connection will be deleted
const ClosedSessionDeleteTimeout = time.Minute
// NumCachedCertificates is the number of cached compressed certificate chains, each taking ~1K space
const NumCachedCertificates = 128
......@@ -14,10 +14,12 @@ const (
Version34 VersionNumber = 34 + iota
Version35
Version36
VersionWhatever = 0 // for when the version doesn't matter
VersionWhatever = 0 // for when the version doesn't matter
VersionUnsupported = -1
)
// SupportedVersions lists the versions that the server supports
// must be in sorted order
var SupportedVersions = []VersionNumber{
Version34, Version35, Version36,
}
......@@ -49,6 +51,28 @@ func IsSupportedVersion(v VersionNumber) bool {
return false
}
// HighestSupportedVersion finds the highest version number that is both present in other and in SupportedVersions
// the versions in other do not need to be ordered
// it returns true and the version number, if there is one, otherwise false
func HighestSupportedVersion(other []VersionNumber) (bool, VersionNumber) {
var otherSupported []VersionNumber
for _, ver := range other {
if ver != VersionUnsupported {
otherSupported = append(otherSupported, ver)
}
}
for i := len(SupportedVersions) - 1; i >= 0; i-- {
for _, ver := range otherSupported {
if ver == SupportedVersions[i] {
return true, ver
}
}
}
return false, 0
}
func init() {
var b bytes.Buffer
for _, v := range SupportedVersions {
......
......@@ -3,7 +3,6 @@ package quic
import (
"bytes"
"errors"
"io"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
......@@ -11,11 +10,11 @@ import (
)
var (
errPacketNumberLenNotSet = errors.New("PublicHeader: PacketNumberLen not set")
errResetAndVersionFlagSet = errors.New("PublicHeader: Reset Flag and Version Flag should not be set at the same time")
errReceivedTruncatedConnectionID = qerr.Error(qerr.InvalidPacketHeader, "receiving packets with truncated ConnectionID is not supported")
errInvalidConnectionID = qerr.Error(qerr.InvalidPacketHeader, "connection ID cannot be 0")
errGetLengthOnlyForRegularPackets = errors.New("PublicHeader: GetLength can only be called for regular packets")
errPacketNumberLenNotSet = errors.New("PublicHeader: PacketNumberLen not set")
errResetAndVersionFlagSet = errors.New("PublicHeader: Reset Flag and Version Flag should not be set at the same time")
errReceivedTruncatedConnectionID = qerr.Error(qerr.InvalidPacketHeader, "receiving packets with truncated ConnectionID is not supported")
errInvalidConnectionID = qerr.Error(qerr.InvalidPacketHeader, "connection ID cannot be 0")
errGetLengthNotForVersionNegotiation = errors.New("PublicHeader: GetLength cannot be called for VersionNegotiation packets")
)
// The PublicHeader of a QUIC packet
......@@ -27,16 +26,19 @@ type PublicHeader struct {
TruncateConnectionID bool
PacketNumberLen protocol.PacketNumberLen
PacketNumber protocol.PacketNumber
VersionNumber protocol.VersionNumber
VersionNumber protocol.VersionNumber // VersionNumber sent by the client
SupportedVersions []protocol.VersionNumber // VersionNumbers sent by the server
DiversificationNonce []byte
}
// WritePublicHeader writes a public header
func (h *PublicHeader) WritePublicHeader(b *bytes.Buffer, version protocol.VersionNumber) error {
// Write writes a public header
func (h *PublicHeader) Write(b *bytes.Buffer, version protocol.VersionNumber, pers protocol.Perspective) error {
publicFlagByte := uint8(0x00)
if h.VersionFlag && h.ResetFlag {
return errResetAndVersionFlagSet
}
if h.VersionFlag {
publicFlagByte |= 0x01
}
......@@ -54,7 +56,8 @@ func (h *PublicHeader) WritePublicHeader(b *bytes.Buffer, version protocol.Versi
publicFlagByte |= 0x04
}
if !h.ResetFlag && !h.VersionFlag {
// only set PacketNumberLen bits if a packet number will be written
if h.hasPacketNumber(pers) {
switch h.PacketNumberLen {
case protocol.PacketNumberLen1:
publicFlagByte |= 0x00
......@@ -73,30 +76,42 @@ func (h *PublicHeader) WritePublicHeader(b *bytes.Buffer, version protocol.Versi
utils.WriteUint64(b, uint64(h.ConnectionID))
}
if h.VersionFlag && pers == protocol.PerspectiveClient {
utils.WriteUint32(b, protocol.VersionNumberToTag(h.VersionNumber))
}
if len(h.DiversificationNonce) > 0 {
b.Write(h.DiversificationNonce)
}
if !h.ResetFlag && !h.VersionFlag {
switch h.PacketNumberLen {
case protocol.PacketNumberLen1:
b.WriteByte(uint8(h.PacketNumber))
case protocol.PacketNumberLen2:
utils.WriteUint16(b, uint16(h.PacketNumber))
case protocol.PacketNumberLen4:
utils.WriteUint32(b, uint32(h.PacketNumber))
case protocol.PacketNumberLen6:
utils.WriteUint48(b, uint64(h.PacketNumber))
default:
return errPacketNumberLenNotSet
}
// if we're a server, and the VersionFlag is set, we must not include anything else in the packet
if !h.hasPacketNumber(pers) {
return nil
}
if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 && h.PacketNumberLen != protocol.PacketNumberLen6 {
return errPacketNumberLenNotSet
}
switch h.PacketNumberLen {
case protocol.PacketNumberLen1:
b.WriteByte(uint8(h.PacketNumber))
case protocol.PacketNumberLen2:
utils.WriteUint16(b, uint16(h.PacketNumber))
case protocol.PacketNumberLen4:
utils.WriteUint32(b, uint32(h.PacketNumber))
case protocol.PacketNumberLen6:
utils.WriteUint48(b, uint64(h.PacketNumber))
default:
return errPacketNumberLenNotSet
}
return nil
}
// ParsePublicHeader parses a QUIC packet's public header
func ParsePublicHeader(b io.ByteReader) (*PublicHeader, error) {
// the packetSentBy is the perspective of the peer that sent this PublicHeader, i.e. if we're the server, packetSentBy should be PerspectiveClient
func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*PublicHeader, error) {
header := &PublicHeader{}
// First byte
......@@ -117,15 +132,17 @@ func ParsePublicHeader(b io.ByteReader) (*PublicHeader, error) {
return nil, errReceivedTruncatedConnectionID
}
switch publicFlagByte & 0x30 {
case 0x30:
header.PacketNumberLen = protocol.PacketNumberLen6
case 0x20:
header.PacketNumberLen = protocol.PacketNumberLen4
case 0x10:
header.PacketNumberLen = protocol.PacketNumberLen2
case 0x00:
header.PacketNumberLen = protocol.PacketNumberLen1
if header.hasPacketNumber(packetSentBy) {
switch publicFlagByte & 0x30 {
case 0x30:
header.PacketNumberLen = protocol.PacketNumberLen6
case 0x20:
header.PacketNumberLen = protocol.PacketNumberLen4
case 0x10:
header.PacketNumberLen = protocol.PacketNumberLen2
case 0x00:
header.PacketNumberLen = protocol.PacketNumberLen1
}
}
// Connection ID
......@@ -133,46 +150,111 @@ func ParsePublicHeader(b io.ByteReader) (*PublicHeader, error) {
if err != nil {
return nil, err
}
header.ConnectionID = protocol.ConnectionID(connID)
if header.ConnectionID == 0 {
return nil, errInvalidConnectionID
}
if packetSentBy == protocol.PerspectiveServer && publicFlagByte&0x04 > 0 {
// TODO: remove the if once the Google servers send the correct value
// assume that a packet doesn't contain a diversification nonce if the version flag or the reset flag is set, no matter what the public flag says
// see https://github.com/lucas-clemente/quic-go/issues/232
if !header.VersionFlag && !header.ResetFlag {
header.DiversificationNonce = make([]byte, 32)
// this Read can never return an EOF for a valid packet, since the diversification nonce is followed by the packet number
_, err = b.Read(header.DiversificationNonce)
if err != nil {
return nil, err
}
}
}
// Version (optional)
if header.VersionFlag {
var versionTag uint32
versionTag, err = utils.ReadUint32(b)
if err != nil {
return nil, err
if !header.ResetFlag {
if header.VersionFlag {
if packetSentBy == protocol.PerspectiveClient {
var versionTag uint32
versionTag, err = utils.ReadUint32(b)
if err != nil {
return nil, err
}
header.VersionNumber = protocol.VersionTagToNumber(versionTag)
} else { // parse the version negotiaton packet
if b.Len()%4 != 0 {
return nil, qerr.InvalidVersionNegotiationPacket
}
header.SupportedVersions = make([]protocol.VersionNumber, 0)
for {
var versionTag uint32
versionTag, err = utils.ReadUint32(b)
if err != nil {
break
}
v := protocol.VersionTagToNumber(versionTag)
if !protocol.IsSupportedVersion(v) {
v = protocol.VersionUnsupported
}
header.SupportedVersions = append(header.SupportedVersions, v)
}
}
}
header.VersionNumber = protocol.VersionTagToNumber(versionTag)
}
// Packet number
packetNumber, err := utils.ReadUintN(b, uint8(header.PacketNumberLen))
if err != nil {
return nil, err
if header.hasPacketNumber(packetSentBy) {
packetNumber, err := utils.ReadUintN(b, uint8(header.PacketNumberLen))
if err != nil {
return nil, err
}
header.PacketNumber = protocol.PacketNumber(packetNumber)
}
header.PacketNumber = protocol.PacketNumber(packetNumber)
return header, nil
}
// GetLength gets the length of the publicHeader in bytes
// can only be called for regular packets
func (h *PublicHeader) GetLength() (protocol.ByteCount, error) {
if h.VersionFlag || h.ResetFlag {
return 0, errGetLengthOnlyForRegularPackets
func (h *PublicHeader) GetLength(pers protocol.Perspective) (protocol.ByteCount, error) {
if h.VersionFlag && h.ResetFlag {
return 0, errResetAndVersionFlagSet
}
if h.VersionFlag && pers == protocol.PerspectiveServer {
return 0, errGetLengthNotForVersionNegotiation
}
length := protocol.ByteCount(1) // 1 byte for public flags
if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 && h.PacketNumberLen != protocol.PacketNumberLen6 {
return 0, errPacketNumberLenNotSet
if h.hasPacketNumber(pers) {
if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 && h.PacketNumberLen != protocol.PacketNumberLen6 {
return 0, errPacketNumberLenNotSet
}
length += protocol.ByteCount(h.PacketNumberLen)
}
if !h.TruncateConnectionID {
length += 8 // 8 bytes for the connection ID
}
// Version Number in packets sent by the client
if h.VersionFlag {
length += 4
}
length += protocol.ByteCount(len(h.DiversificationNonce))
length += protocol.ByteCount(h.PacketNumberLen)
return length, nil
}
// hasPacketNumber determines if this PublicHeader will contain a packet number
// this depends on the ResetFlag, the VersionFlag and who sent the packet
func (h *PublicHeader) hasPacketNumber(packetSentBy protocol.Perspective) bool {
if h.ResetFlag {
return false
}
if h.VersionFlag && packetSentBy == protocol.PerspectiveServer {
return false
}
return true
}
......@@ -2,12 +2,19 @@ package quic
import (
"bytes"
"encoding/binary"
"errors"
"github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
)
type publicReset struct {
rejectedPacketNumber protocol.PacketNumber
nonce uint64
}
func writePublicReset(connectionID protocol.ConnectionID, rejectedPacketNumber protocol.PacketNumber, nonceProof uint64) []byte {
b := &bytes.Buffer{}
b.WriteByte(0x0a)
......@@ -22,3 +29,34 @@ func writePublicReset(connectionID protocol.ConnectionID, rejectedPacketNumber p
utils.WriteUint64(b, uint64(rejectedPacketNumber))
return b.Bytes()
}
func parsePublicReset(r *bytes.Reader) (*publicReset, error) {
pr := publicReset{}
tag, tagMap, err := handshake.ParseHandshakeMessage(r)
if err != nil {
return nil, err
}
if tag != handshake.TagPRST {
return nil, errors.New("wrong public reset tag")
}
rseq, ok := tagMap[handshake.TagRSEQ]
if !ok {
return nil, errors.New("RSEQ missing")
}
if len(rseq) != 8 {
return nil, errors.New("invalid RSEQ tag")
}
pr.rejectedPacketNumber = protocol.PacketNumber(binary.LittleEndian.Uint64(rseq))
rnon, ok := tagMap[handshake.TagRNON]
if !ok {
return nil, errors.New("RNON missing")
}
if len(rnon) != 8 {
return nil, errors.New("invalid RNON tag")
}
pr.nonce = binary.LittleEndian.Uint64(rnon)
return &pr, nil
}
......@@ -3,6 +3,7 @@ package quic
import (
"bytes"
"crypto/tls"
"errors"
"net"
"strings"
"sync"
......@@ -18,6 +19,7 @@ import (
// packetHandler handles packets
type packetHandler interface {
handlePacket(*receivedPacket)
OpenStream(protocol.StreamID) (utils.Stream, error)
run()
Close(error) error
}
......@@ -29,11 +31,12 @@ type Server struct {
conn *net.UDPConn
connMutex sync.Mutex
signer crypto.Signer
scfg *handshake.ServerConfig
certChain crypto.CertChain
scfg *handshake.ServerConfig
sessions map[protocol.ConnectionID]packetHandler
sessionsMutex sync.RWMutex
sessions map[protocol.ConnectionID]packetHandler
sessionsMutex sync.RWMutex
deleteClosedSessionsAfter time.Duration
streamCallback StreamCallback
......@@ -42,16 +45,13 @@ type Server struct {
// NewServer makes a new server
func NewServer(addr string, tlsConfig *tls.Config, cb StreamCallback) (*Server, error) {
signer, err := crypto.NewProofSource(tlsConfig)
if err != nil {
return nil, err
}
certChain := crypto.NewCertChain(tlsConfig)
kex, err := crypto.NewCurve25519KEX()
if err != nil {
return nil, err
}
scfg, err := handshake.NewServerConfig(kex, signer)
scfg, err := handshake.NewServerConfig(kex, certChain)
if err != nil {
return nil, err
}
......@@ -62,12 +62,13 @@ func NewServer(addr string, tlsConfig *tls.Config, cb StreamCallback) (*Server,
}
return &Server{
addr: udpAddr,
signer: signer,
scfg: scfg,
streamCallback: cb,
sessions: map[protocol.ConnectionID]packetHandler{},
newSession: newSession,
addr: udpAddr,
certChain: certChain,
scfg: scfg,
streamCallback: cb,
sessions: map[protocol.ConnectionID]packetHandler{},
newSession: newSession,
deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout,
}, nil
}
......@@ -135,12 +136,39 @@ func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet
r := bytes.NewReader(packet)
hdr, err := ParsePublicHeader(r)
hdr, err := ParsePublicHeader(r, protocol.PerspectiveClient)
if err != nil {
return qerr.Error(qerr.InvalidPacketHeader, err.Error())
}
hdr.Raw = packet[:len(packet)-r.Len()]
s.sessionsMutex.RLock()
session, ok := s.sessions[hdr.ConnectionID]
s.sessionsMutex.RUnlock()
// ignore all Public Reset packets
if hdr.ResetFlag {
if ok {
var pr *publicReset
pr, err = parsePublicReset(r)
if err != nil {
utils.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.")
} else {
utils.Infof("Received a Public Reset for connection %x, rejected packet number: 0x%x.", hdr.ConnectionID, pr.rejectedPacketNumber)
}
} else {
utils.Infof("Received Public Reset for unknown connection %x.", hdr.ConnectionID)
}
return nil
}
// a session is only created once the client sent a supported version
// if we receive a packet for a connection that already has session, it's probably an old packet that was sent by the client before the version was negotiated
// it is safe to drop it
if ok && hdr.VersionFlag && !protocol.IsSupportedVersion(hdr.VersionNumber) {
return nil
}
// Send Version Negotiation Packet if the client is speaking a different protocol version
if hdr.VersionFlag && !protocol.IsSupportedVersion(hdr.VersionNumber) {
utils.Infof("Client offered version %d, sending VersionNegotiationPacket", hdr.VersionNumber)
......@@ -148,15 +176,20 @@ func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet
return err
}
s.sessionsMutex.RLock()
session, ok := s.sessions[hdr.ConnectionID]
s.sessionsMutex.RUnlock()
if !ok {
utils.Infof("Serving new connection: %x, version %d from %v", hdr.ConnectionID, hdr.VersionNumber, remoteAddr)
if !hdr.VersionFlag {
_, err = conn.WriteToUDP(writePublicReset(hdr.ConnectionID, hdr.PacketNumber, 0), remoteAddr)
return err
}
version := hdr.VersionNumber
if !protocol.IsSupportedVersion(version) {
return errors.New("Server BUG: negotiated version not supported")
}
utils.Infof("Serving new connection: %x, version %d from %v", hdr.ConnectionID, version, remoteAddr)
session, err = s.newSession(
&udpConn{conn: conn, currentAddr: remoteAddr},
hdr.VersionNumber,
version,
hdr.ConnectionID,
s.scfg,
s.streamCallback,
......@@ -187,6 +220,12 @@ func (s *Server) closeCallback(id protocol.ConnectionID) {
s.sessionsMutex.Lock()
s.sessions[id] = nil
s.sessionsMutex.Unlock()
time.AfterFunc(s.deleteClosedSessionsAfter, func() {
s.sessionsMutex.Lock()
delete(s.sessions, id)
s.sessionsMutex.Unlock()
})
}
func composeVersionNegotiation(connectionID protocol.ConnectionID) []byte {
......@@ -196,7 +235,7 @@ func composeVersionNegotiation(connectionID protocol.ConnectionID) []byte {
PacketNumber: 1,
VersionFlag: true,
}
err := responsePublicHeader.WritePublicHeader(fullReply, protocol.Version35)
err := responsePublicHeader.Write(fullReply, protocol.Version35, protocol.PerspectiveServer)
if err != nil {
utils.Errorf("error composing version negotiation packet: %s", err.Error())
}
......
......@@ -119,7 +119,7 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []
if f.flowControlManager.RemainingConnectionWindowSize() == 0 {
// We are now connection-level FC blocked
f.blockedFrameQueue = append(f.blockedFrameQueue, &frames.BlockedFrame{StreamID: 0})
} else if sendWindowSize-frame.DataLen() == 0 {
} else if !frame.FinBit && sendWindowSize-frame.DataLen() == 0 {
// We are now stream-level FC blocked
f.blockedFrameQueue = append(f.blockedFrameQueue, &frames.BlockedFrame{StreamID: s.StreamID()})
}
......
package quic
import "net"
import (
"net"
"sync"
)
type connection interface {
write([]byte) error
......@@ -9,6 +12,8 @@ type connection interface {
}
type udpConn struct {
mutex sync.RWMutex
conn *net.UDPConn
currentAddr *net.UDPAddr
}
......@@ -21,9 +26,14 @@ func (c *udpConn) write(p []byte) error {
}
func (c *udpConn) setCurrentRemoteAddr(addr interface{}) {
c.mutex.Lock()
c.currentAddr = addr.(*net.UDPAddr)
c.mutex.Unlock()
}
func (c *udpConn) RemoteAddr() *net.UDPAddr {
return c.currentAddr
c.mutex.RLock()
addr := c.currentAddr
c.mutex.RUnlock()
return addr
}
package quic
import "github.com/lucas-clemente/quic-go/frames"
type unpackedPacket struct {
frames []frames.Frame
}
func (u *unpackedPacket) IsRetransmittable() bool {
for _, f := range u.frames {
switch f.(type) {
case *frames.StreamFrame:
return true
case *frames.RstStreamFrame:
return true
case *frames.WindowUpdateFrame:
return true
case *frames.BlockedFrame:
return true
case *frames.PingFrame:
return true
case *frames.GoawayFrame:
return true
}
}
return false
}
......@@ -34,6 +34,14 @@ func MaxUint64(a, b uint64) uint64 {
return a
}
// MinUint64 returns the maximum of two uint64
func MinUint64(a, b uint64) uint64 {
if a < b {
return a
}
return b
}
// Min returns the minimum of two Ints
func Min(a, b int) int {
if a < b {
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment