Commit ab87ca05 authored by rui.zheng's avatar rui.zheng

update quic-go

parent 5efffa7d
MIT License MIT License
Copyright (c) 2016 the quic-go authors & Google, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights in the Software without restriction, including without limitation the rights
......
...@@ -14,6 +14,7 @@ quic-go is an implementation of the [QUIC](https://en.wikipedia.org/wiki/QUIC) p ...@@ -14,6 +14,7 @@ quic-go is an implementation of the [QUIC](https://en.wikipedia.org/wiki/QUIC) p
Done: Done:
- Basic protocol with support for QUIC version 34-36 - Basic protocol with support for QUIC version 34-36
- QUIC client
- HTTP/2 support - HTTP/2 support
- Crypto (RSA / ECDSA certificates, Curve25519 for key exchange, AES-GCM or Chacha20-Poly1305 as stream cipher) - Crypto (RSA / ECDSA certificates, Curve25519 for key exchange, AES-GCM or Chacha20-Poly1305 as stream cipher)
- Loss detection and retransmission (currently fast retransmission & RTO) - Loss detection and retransmission (currently fast retransmission & RTO)
...@@ -22,11 +23,10 @@ Done: ...@@ -22,11 +23,10 @@ Done:
Major TODOs: Major TODOs:
- Security, especially DOS protections - Security, especially DoS protections
- Performance - Performance
- Better packet loss detection - Better packet loss detection
- Connection migration - Connection migration
- QUIC client
## Guides ## Guides
...@@ -38,20 +38,26 @@ Running tests: ...@@ -38,20 +38,26 @@ Running tests:
go test ./... go test ./...
Running the example server: ### Running the example server
go run example/main.go -www /var/www/ go run example/main.go -www /var/www/
Using the `quic_client` from chromium: Using the `quic_client` from chromium:
quic_client --quic-version=32 --host=127.0.0.1 --port=6121 --v=1 https://quic.clemente.io quic_client --host=127.0.0.1 --port=6121 --v=1 https://quic.clemente.io
Using Chrome: Using Chrome:
/Applications/Google\ Chrome.app/Contents/MacOS/Google\ Chrome --user-data-dir=/tmp/chrome --no-proxy-server --enable-quic --origin-to-force-quic-on=quic.clemente.io:443 --host-resolver-rules='MAP quic.clemente.io:443 127.0.0.1:6121' https://quic.clemente.io /Applications/Google\ Chrome.app/Contents/MacOS/Google\ Chrome --user-data-dir=/tmp/chrome --no-proxy-server --enable-quic --origin-to-force-quic-on=quic.clemente.io:443 --host-resolver-rules='MAP quic.clemente.io:443 127.0.0.1:6121' https://quic.clemente.io
### Using the example client
go run example/client/main.go https://quic.clemente.io
## Usage ## Usage
### As a server
See the [example server](example/main.go) or try out [Caddy](https://github.com/mholt/caddy) (from version 0.9, [instructions here](https://github.com/mholt/caddy/wiki/QUIC)). Starting a QUIC server is very similar to the standard lib http in go: See the [example server](example/main.go) or try out [Caddy](https://github.com/mholt/caddy) (from version 0.9, [instructions here](https://github.com/mholt/caddy/wiki/QUIC)). Starting a QUIC server is very similar to the standard lib http in go:
```go ```go
...@@ -59,6 +65,16 @@ http.Handle("/", http.FileServer(http.Dir(wwwDir))) ...@@ -59,6 +65,16 @@ http.Handle("/", http.FileServer(http.Dir(wwwDir)))
h2quic.ListenAndServeQUIC("localhost:4242", "/path/to/cert/chain.pem", "/path/to/privkey.pem", nil) h2quic.ListenAndServeQUIC("localhost:4242", "/path/to/cert/chain.pem", "/path/to/privkey.pem", nil)
``` ```
### As a client
See the [example client](example/client/main.go). Use a `QuicRoundTripper` as a `Transport` in a `http.Client`.
```go
http.Client{
Transport: &h2quic.QuicRoundTripper{},
}
```
## Building on Windows ## Building on Windows
Due to the low Windows timer resolution (see [StackOverflow question](http://stackoverflow.com/questions/37706834/high-resolution-timers-millisecond-precision-in-go-on-windows)) available with Go 1.6.x, some optimizations might not work when compiled with this version of the compiler. Please use Go 1.7 on Windows. Due to the low Windows timer resolution (see [StackOverflow question](http://stackoverflow.com/questions/37706834/high-resolution-timers-millisecond-precision-in-go-on-windows)) available with Go 1.6.x, some optimizations might not work when compiled with this version of the compiler. Please use Go 1.7 on Windows.
...@@ -28,8 +28,8 @@ type SentPacketHandler interface { ...@@ -28,8 +28,8 @@ type SentPacketHandler interface {
// ReceivedPacketHandler handles ACKs needed to send for incoming packets // ReceivedPacketHandler handles ACKs needed to send for incoming packets
type ReceivedPacketHandler interface { type ReceivedPacketHandler interface {
ReceivedPacket(packetNumber protocol.PacketNumber) error ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error
ReceivedStopWaiting(*frames.StopWaitingFrame) error ReceivedStopWaiting(*frames.StopWaitingFrame) error
GetAckFrame(dequeue bool) (*frames.AckFrame, error) GetAckFrame() *frames.AckFrame
} }
...@@ -19,31 +19,17 @@ type Packet struct { ...@@ -19,31 +19,17 @@ type Packet struct {
SendTime time.Time SendTime time.Time
} }
// GetStreamFramesForRetransmission gets all the streamframes for retransmission // GetFramesForRetransmission gets all the frames for retransmission
func (p *Packet) GetStreamFramesForRetransmission() []*frames.StreamFrame { func (p *Packet) GetFramesForRetransmission() []frames.Frame {
var streamFrames []*frames.StreamFrame var fs []frames.Frame
for _, frame := range p.Frames { for _, frame := range p.Frames {
if streamFrame, isStreamFrame := frame.(*frames.StreamFrame); isStreamFrame { switch frame.(type) {
streamFrames = append(streamFrames, streamFrame) case *frames.AckFrame:
} continue
} case *frames.StopWaitingFrame:
return streamFrames
}
// GetControlFramesForRetransmission gets all the control frames for retransmission
func (p *Packet) GetControlFramesForRetransmission() []frames.Frame {
var controlFrames []frames.Frame
for _, frame := range p.Frames {
// omit ACKs
if _, isStreamFrame := frame.(*frames.StreamFrame); isStreamFrame {
continue continue
} }
fs = append(fs, frame)
_, isAck := frame.(*frames.AckFrame)
_, isStopWaiting := frame.(*frames.StopWaitingFrame)
if !isAck && !isStopWaiting {
controlFrames = append(controlFrames, frame)
}
} }
return controlFrames return fs
} }
...@@ -6,45 +6,48 @@ import ( ...@@ -6,45 +6,48 @@ import (
"github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
) )
var ( var (
// ErrDuplicatePacket occurres when a duplicate packet is received // ErrDuplicatePacket occurres when a duplicate packet is received
ErrDuplicatePacket = errors.New("ReceivedPacketHandler: Duplicate Packet") ErrDuplicatePacket = errors.New("ReceivedPacketHandler: Duplicate Packet")
// ErrMapAccess occurs when a NACK contains invalid NACK ranges
ErrMapAccess = qerr.Error(qerr.InvalidAckData, "Packet does not exist in PacketHistory")
// ErrPacketSmallerThanLastStopWaiting occurs when a packet arrives with a packet number smaller than the largest LeastUnacked of a StopWaitingFrame. If this error occurs, the packet should be ignored // ErrPacketSmallerThanLastStopWaiting occurs when a packet arrives with a packet number smaller than the largest LeastUnacked of a StopWaitingFrame. If this error occurs, the packet should be ignored
ErrPacketSmallerThanLastStopWaiting = errors.New("ReceivedPacketHandler: Packet number smaller than highest StopWaiting") ErrPacketSmallerThanLastStopWaiting = errors.New("ReceivedPacketHandler: Packet number smaller than highest StopWaiting")
) )
var ( var errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet number")
errInvalidPacketNumber = errors.New("ReceivedPacketHandler: Invalid packet number")
errTooManyOutstandingReceivedPackets = qerr.Error(qerr.TooManyOutstandingReceivedPackets, "")
)
type receivedPacketHandler struct { type receivedPacketHandler struct {
largestInOrderObserved protocol.PacketNumber largestObserved protocol.PacketNumber
largestObserved protocol.PacketNumber ignorePacketsBelow protocol.PacketNumber
ignorePacketsBelow protocol.PacketNumber largestObservedReceivedTime time.Time
currentAckFrame *frames.AckFrame
stateChanged bool // has an ACK for this state already been sent? Will be set to false every time a new packet arrives, and to false every time an ACK is sent
packetHistory *receivedPacketHistory packetHistory *receivedPacketHistory
receivedTimes map[protocol.PacketNumber]time.Time ackSendDelay time.Duration
lowestInReceivedTimes protocol.PacketNumber
packetsReceivedSinceLastAck int
retransmittablePacketsReceivedSinceLastAck int
ackQueued bool
ackAlarm time.Time
ackAlarmResetCallback func(time.Time)
lastAck *frames.AckFrame
} }
// NewReceivedPacketHandler creates a new receivedPacketHandler // NewReceivedPacketHandler creates a new receivedPacketHandler
func NewReceivedPacketHandler() ReceivedPacketHandler { func NewReceivedPacketHandler(ackAlarmResetCallback func(time.Time)) ReceivedPacketHandler {
// create a stopped timer, see https://github.com/golang/go/issues/12721#issuecomment-143010182
timer := time.NewTimer(0)
<-timer.C
return &receivedPacketHandler{ return &receivedPacketHandler{
receivedTimes: make(map[protocol.PacketNumber]time.Time), packetHistory: newReceivedPacketHistory(),
packetHistory: newReceivedPacketHistory(), ackAlarmResetCallback: ackAlarmResetCallback,
ackSendDelay: protocol.AckSendDelay,
} }
} }
func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber) error { func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, shouldInstigateAck bool) error {
if packetNumber == 0 { if packetNumber == 0 {
return errInvalidPacketNumber return errInvalidPacketNumber
} }
...@@ -55,30 +58,21 @@ func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumbe ...@@ -55,30 +58,21 @@ func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumbe
return ErrPacketSmallerThanLastStopWaiting return ErrPacketSmallerThanLastStopWaiting
} }
_, ok := h.receivedTimes[packetNumber] if h.packetHistory.IsDuplicate(packetNumber) {
if packetNumber <= h.largestInOrderObserved || ok {
return ErrDuplicatePacket return ErrDuplicatePacket
} }
h.packetHistory.ReceivedPacket(packetNumber) err := h.packetHistory.ReceivedPacket(packetNumber)
if err != nil {
h.stateChanged = true return err
h.currentAckFrame = nil }
if packetNumber > h.largestObserved { if packetNumber > h.largestObserved {
h.largestObserved = packetNumber h.largestObserved = packetNumber
h.largestObservedReceivedTime = time.Now()
} }
if packetNumber == h.largestInOrderObserved+1 { h.maybeQueueAck(packetNumber, shouldInstigateAck)
h.largestInOrderObserved = packetNumber
}
h.receivedTimes[packetNumber] = time.Now()
if protocol.PacketNumber(len(h.receivedTimes)) > protocol.MaxTrackedReceivedPackets {
return errTooManyOutstandingReceivedPackets
}
return nil return nil
} }
...@@ -89,55 +83,84 @@ func (h *receivedPacketHandler) ReceivedStopWaiting(f *frames.StopWaitingFrame) ...@@ -89,55 +83,84 @@ func (h *receivedPacketHandler) ReceivedStopWaiting(f *frames.StopWaitingFrame)
} }
h.ignorePacketsBelow = f.LeastUnacked - 1 h.ignorePacketsBelow = f.LeastUnacked - 1
h.garbageCollectReceivedTimes()
// the LeastUnacked is the smallest packet number of any packet for which the sender is still awaiting an ack. So the largestInOrderObserved is one less than that
if f.LeastUnacked > h.largestInOrderObserved {
h.largestInOrderObserved = f.LeastUnacked - 1
}
h.packetHistory.DeleteBelow(f.LeastUnacked) h.packetHistory.DeleteBelow(f.LeastUnacked)
return nil return nil
} }
func (h *receivedPacketHandler) GetAckFrame(dequeue bool) (*frames.AckFrame, error) { func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, shouldInstigateAck bool) {
if !h.stateChanged { var ackAlarmSet bool
return nil, nil h.packetsReceivedSinceLastAck++
if shouldInstigateAck {
h.retransmittablePacketsReceivedSinceLastAck++
} }
if dequeue { // always ack the first packet
h.stateChanged = false if h.lastAck == nil {
h.ackQueued = true
} }
if h.currentAckFrame != nil { // Always send an ack every 20 packets in order to allow the peer to discard
return h.currentAckFrame, nil // information from the SentPacketManager and provide an RTT measurement.
if h.packetsReceivedSinceLastAck >= protocol.MaxPacketsReceivedBeforeAckSend {
h.ackQueued = true
} }
packetReceivedTime, ok := h.receivedTimes[h.largestObserved] // if the packet number is smaller than the largest acked packet, it must have been reported missing with the last ACK
if !ok { // note that it cannot be a duplicate because they're already filtered out by ReceivedPacket()
return nil, ErrMapAccess if h.lastAck != nil && packetNumber < h.lastAck.LargestAcked {
h.ackQueued = true
}
// check if a new missing range above the previously was created
if h.lastAck != nil && h.packetHistory.GetHighestAckRange().FirstPacketNumber > h.lastAck.LargestAcked {
h.ackQueued = true
}
if !h.ackQueued && shouldInstigateAck {
if h.retransmittablePacketsReceivedSinceLastAck >= protocol.RetransmittablePacketsBeforeAck {
h.ackQueued = true
} else {
if h.ackAlarm.IsZero() {
h.ackAlarm = time.Now().Add(h.ackSendDelay)
ackAlarmSet = true
}
}
}
if h.ackQueued {
// cancel the ack alarm
h.ackAlarm = time.Time{}
ackAlarmSet = false
}
if ackAlarmSet {
h.ackAlarmResetCallback(h.ackAlarm)
}
}
func (h *receivedPacketHandler) GetAckFrame() *frames.AckFrame {
if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(time.Now())) {
return nil
} }
ackRanges := h.packetHistory.GetAckRanges() ackRanges := h.packetHistory.GetAckRanges()
h.currentAckFrame = &frames.AckFrame{ ack := &frames.AckFrame{
LargestAcked: h.largestObserved, LargestAcked: h.largestObserved,
LowestAcked: ackRanges[len(ackRanges)-1].FirstPacketNumber, LowestAcked: ackRanges[len(ackRanges)-1].FirstPacketNumber,
PacketReceivedTime: packetReceivedTime, PacketReceivedTime: h.largestObservedReceivedTime,
} }
if len(ackRanges) > 1 { if len(ackRanges) > 1 {
h.currentAckFrame.AckRanges = ackRanges ack.AckRanges = ackRanges
} }
return h.currentAckFrame, nil h.lastAck = ack
} h.ackAlarm = time.Time{}
h.ackQueued = false
h.packetsReceivedSinceLastAck = 0
h.retransmittablePacketsReceivedSinceLastAck = 0
func (h *receivedPacketHandler) garbageCollectReceivedTimes() { return ack
for i := h.lowestInReceivedTimes; i <= h.ignorePacketsBelow; i++ {
delete(h.receivedTimes, i)
}
if h.ignorePacketsBelow > h.lowestInReceivedTimes {
h.lowestInReceivedTimes = h.ignorePacketsBelow + 1
}
} }
package ackhandler package ackhandler
import ( import (
"sync"
"github.com/lucas-clemente/quic-go/frames" "github.com/lucas-clemente/quic-go/frames"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils" "github.com/lucas-clemente/quic-go/utils"
) )
type receivedPacketHistory struct { type receivedPacketHistory struct {
ranges *utils.PacketIntervalList ranges *utils.PacketIntervalList
mutex sync.RWMutex // the map is used as a replacement for a set here. The bool is always supposed to be set to true
receivedPacketNumbers map[protocol.PacketNumber]bool
lowestInReceivedPacketNumbers protocol.PacketNumber
} }
var (
errTooManyOutstandingReceivedAckRanges = qerr.Error(qerr.TooManyOutstandingReceivedPackets, "Too many outstanding received ACK ranges")
errTooManyOutstandingReceivedPackets = qerr.Error(qerr.TooManyOutstandingReceivedPackets, "Too many outstanding received packets")
)
// newReceivedPacketHistory creates a new received packet history // newReceivedPacketHistory creates a new received packet history
func newReceivedPacketHistory() *receivedPacketHistory { func newReceivedPacketHistory() *receivedPacketHistory {
return &receivedPacketHistory{ return &receivedPacketHistory{
ranges: utils.NewPacketIntervalList(), ranges: utils.NewPacketIntervalList(),
receivedPacketNumbers: make(map[protocol.PacketNumber]bool),
} }
} }
// ReceivedPacket registers a packet with PacketNumber p and updates the ranges // ReceivedPacket registers a packet with PacketNumber p and updates the ranges
func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) { func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) error {
h.mutex.Lock() if h.ranges.Len() >= protocol.MaxTrackedReceivedAckRanges {
defer h.mutex.Unlock() return errTooManyOutstandingReceivedAckRanges
}
if len(h.receivedPacketNumbers) >= protocol.MaxTrackedReceivedPackets {
return errTooManyOutstandingReceivedPackets
}
h.receivedPacketNumbers[p] = true
if h.ranges.Len() == 0 { if h.ranges.Len() == 0 {
h.ranges.PushBack(utils.PacketInterval{Start: p, End: p}) h.ranges.PushBack(utils.PacketInterval{Start: p, End: p})
return return nil
} }
for el := h.ranges.Back(); el != nil; el = el.Prev() { for el := h.ranges.Back(); el != nil; el = el.Prev() {
// p already included in an existing range. Nothing to do here // p already included in an existing range. Nothing to do here
if p >= el.Value.Start && p <= el.Value.End { if p >= el.Value.Start && p <= el.Value.End {
return return nil
} }
var rangeExtended bool var rangeExtended bool
...@@ -52,46 +66,61 @@ func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) { ...@@ -52,46 +66,61 @@ func (h *receivedPacketHistory) ReceivedPacket(p protocol.PacketNumber) {
if prev != nil && prev.Value.End+1 == el.Value.Start { // merge two ranges if prev != nil && prev.Value.End+1 == el.Value.Start { // merge two ranges
prev.Value.End = el.Value.End prev.Value.End = el.Value.End
h.ranges.Remove(el) h.ranges.Remove(el)
return return nil
} }
return // if the two ranges were not merge, we're done here return nil // if the two ranges were not merge, we're done here
} }
// create a new range at the end // create a new range at the end
if p > el.Value.End { if p > el.Value.End {
h.ranges.InsertAfter(utils.PacketInterval{Start: p, End: p}, el) h.ranges.InsertAfter(utils.PacketInterval{Start: p, End: p}, el)
return return nil
} }
} }
// create a new range at the beginning // create a new range at the beginning
h.ranges.InsertBefore(utils.PacketInterval{Start: p, End: p}, h.ranges.Front()) h.ranges.InsertBefore(utils.PacketInterval{Start: p, End: p}, h.ranges.Front())
return nil
} }
// DeleteBelow deletes all entries below the leastUnacked packet number
func (h *receivedPacketHistory) DeleteBelow(leastUnacked protocol.PacketNumber) { func (h *receivedPacketHistory) DeleteBelow(leastUnacked protocol.PacketNumber) {
h.mutex.Lock() h.lowestInReceivedPacketNumbers = utils.MaxPacketNumber(h.lowestInReceivedPacketNumbers, leastUnacked)
defer h.mutex.Unlock()
nextEl := h.ranges.Front() nextEl := h.ranges.Front()
for el := h.ranges.Front(); nextEl != nil; el = nextEl { for el := h.ranges.Front(); nextEl != nil; el = nextEl {
nextEl = el.Next() nextEl = el.Next()
if leastUnacked > el.Value.Start && leastUnacked <= el.Value.End { if leastUnacked > el.Value.Start && leastUnacked <= el.Value.End {
for i := el.Value.Start; i < leastUnacked; i++ { // adjust start value of a range
delete(h.receivedPacketNumbers, i)
}
el.Value.Start = leastUnacked el.Value.Start = leastUnacked
} } else if el.Value.End < leastUnacked { // delete a whole range
if el.Value.End < leastUnacked { // delete a whole range for i := el.Value.Start; i <= el.Value.End; i++ {
delete(h.receivedPacketNumbers, i)
}
h.ranges.Remove(el) h.ranges.Remove(el)
} else { } else { // no ranges affected. Nothing to do
return return
} }
} }
} }
// IsDuplicate determines if a packet should be regarded as a duplicate packet
// note that after receiving a StopWaitingFrame, all packets below the LeastUnacked should be regarded as duplicates, even if the packet was just delayed
func (h *receivedPacketHistory) IsDuplicate(p protocol.PacketNumber) bool {
if p < h.lowestInReceivedPacketNumbers {
return true
}
_, ok := h.receivedPacketNumbers[p]
return ok
}
// GetAckRanges gets a slice of all AckRanges that can be used in an AckFrame // GetAckRanges gets a slice of all AckRanges that can be used in an AckFrame
func (h *receivedPacketHistory) GetAckRanges() []frames.AckRange { func (h *receivedPacketHistory) GetAckRanges() []frames.AckRange {
h.mutex.RLock()
defer h.mutex.RUnlock()
if h.ranges.Len() == 0 { if h.ranges.Len() == 0 {
return nil return nil
} }
...@@ -104,3 +133,13 @@ func (h *receivedPacketHistory) GetAckRanges() []frames.AckRange { ...@@ -104,3 +133,13 @@ func (h *receivedPacketHistory) GetAckRanges() []frames.AckRange {
return ackRanges return ackRanges
} }
func (h *receivedPacketHistory) GetHighestAckRange() frames.AckRange {
ackRange := frames.AckRange{}
if h.ranges.Len() > 0 {
r := h.ranges.Back().Value
ackRange.FirstPacketNumber = r.Start
ackRange.LastPacketNumber = r.End
}
return ackRange
}
...@@ -47,9 +47,7 @@ type sentPacketHandler struct { ...@@ -47,9 +47,7 @@ type sentPacketHandler struct {
} }
// NewSentPacketHandler creates a new sentPacketHandler // NewSentPacketHandler creates a new sentPacketHandler
func NewSentPacketHandler() SentPacketHandler { func NewSentPacketHandler(rttStats *congestion.RTTStats) SentPacketHandler {
rttStats := &congestion.RTTStats{}
congestion := congestion.NewCubicSender( congestion := congestion.NewCubicSender(
congestion.DefaultClock{}, congestion.DefaultClock{},
rttStats, rttStats,
......
...@@ -13,8 +13,8 @@ clone_folder: c:\gopath\src\github.com\lucas-clemente\quic-go ...@@ -13,8 +13,8 @@ clone_folder: c:\gopath\src\github.com\lucas-clemente\quic-go
install: install:
- rmdir c:\go /s /q - rmdir c:\go /s /q
- appveyor DownloadFile https://storage.googleapis.com/golang/go1.7.1.windows-amd64.zip - appveyor DownloadFile https://storage.googleapis.com/golang/go1.7.5.windows-amd64.zip
- 7z x go1.7.1.windows-amd64.zip -y -oC:\ > NUL - 7z x go1.7.5.windows-amd64.zip -y -oC:\ > NUL
- set PATH=%PATH%;%GOPATH%\bin\windows_%GOARCH%;%GOPATH%\bin - set PATH=%PATH%;%GOPATH%\bin\windows_%GOARCH%;%GOPATH%\bin
- echo %PATH% - echo %PATH%
- echo %GOPATH% - echo %GOPATH%
......
package quic
import (
"bytes"
"crypto/tls"
"errors"
"net"
"strings"
"sync/atomic"
"time"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
)
// A Client of QUIC
type Client struct {
addr *net.UDPAddr
conn *net.UDPConn
hostname string
connectionID protocol.ConnectionID
version protocol.VersionNumber
versionNegotiated bool
closed uint32 // atomic bool
tlsConfig *tls.Config
cryptoChangeCallback CryptoChangeCallback
versionNegotiateCallback VersionNegotiateCallback
session packetHandler
}
// VersionNegotiateCallback is called once the client has a negotiated version
type VersionNegotiateCallback func() error
var errHostname = errors.New("Invalid hostname")
var (
errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
)
// NewClient makes a new client
func NewClient(host string, tlsConfig *tls.Config, cryptoChangeCallback CryptoChangeCallback, versionNegotiateCallback VersionNegotiateCallback) (*Client, error) {
udpAddr, err := net.ResolveUDPAddr("udp", host)
if err != nil {
return nil, err
}
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
return nil, err
}
connectionID, err := utils.GenerateConnectionID()
if err != nil {
return nil, err
}
hostname, _, err := net.SplitHostPort(host)
if err != nil {
return nil, err
}
client := &Client{
addr: udpAddr,
conn: conn,
hostname: hostname,
version: protocol.SupportedVersions[len(protocol.SupportedVersions)-1], // use the highest supported version by default
connectionID: connectionID,
tlsConfig: tlsConfig,
cryptoChangeCallback: cryptoChangeCallback,
versionNegotiateCallback: versionNegotiateCallback,
}
utils.Infof("Starting new connection to %s (%s), connectionID %x, version %d", host, udpAddr.String(), connectionID, client.version)
err = client.createNewSession(nil)
if err != nil {
return nil, err
}
return client, nil
}
// Listen listens
func (c *Client) Listen() error {
for {
data := getPacketBuffer()
data = data[:protocol.MaxPacketSize]
n, _, err := c.conn.ReadFromUDP(data)
if err != nil {
if strings.HasSuffix(err.Error(), "use of closed network connection") {
return nil
}
return err
}
data = data[:n]
err = c.handlePacket(data)
if err != nil {
utils.Errorf("error handling packet: %s", err.Error())
c.session.Close(err)
return err
}
}
}
// OpenStream opens a stream, for client-side created streams (i.e. odd streamIDs)
func (c *Client) OpenStream(id protocol.StreamID) (utils.Stream, error) {
return c.session.OpenStream(id)
}
// Close closes the connection
func (c *Client) Close(e error) error {
// Only close once
if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
return nil
}
_ = c.session.Close(e)
return c.conn.Close()
}
func (c *Client) handlePacket(packet []byte) error {
if protocol.ByteCount(len(packet)) > protocol.MaxPacketSize {
return qerr.PacketTooLarge
}
rcvTime := time.Now()
r := bytes.NewReader(packet)
hdr, err := ParsePublicHeader(r, protocol.PerspectiveServer)
if err != nil {
return qerr.Error(qerr.InvalidPacketHeader, err.Error())
}
hdr.Raw = packet[:len(packet)-r.Len()]
// ignore delayed / duplicated version negotiation packets
if c.versionNegotiated && hdr.VersionFlag {
return nil
}
// this is the first packet after the client sent a packet with the VersionFlag set
// if the server doesn't send a version negotiation packet, it supports the suggested version
if !hdr.VersionFlag && !c.versionNegotiated {
c.versionNegotiated = true
err = c.versionNegotiateCallback()
if err != nil {
return err
}
}
if hdr.VersionFlag {
var hasCommonVersion bool // check if we're supporting any of the offered versions
for _, v := range hdr.SupportedVersions {
// check if the server sent the offered version in supported versions
if v == c.version {
return qerr.Error(qerr.InvalidVersionNegotiationPacket, "Server already supports client's version and should have accepted the connection.")
}
if v != protocol.VersionUnsupported {
hasCommonVersion = true
}
}
if !hasCommonVersion {
utils.Infof("No common version found.")
return qerr.InvalidVersion
}
ok, highestSupportedVersion := protocol.HighestSupportedVersion(hdr.SupportedVersions)
if !ok {
return qerr.VersionNegotiationMismatch
}
utils.Infof("Switching to QUIC version %d", highestSupportedVersion)
c.version = highestSupportedVersion
c.versionNegotiated = true
c.session.Close(errCloseSessionForNewVersion)
err = c.createNewSession(hdr.SupportedVersions)
if err != nil {
return err
}
err = c.versionNegotiateCallback()
if err != nil {
return err
}
return nil // version negotiation packets have no payload
}
c.session.handlePacket(&receivedPacket{
remoteAddr: c.addr,
publicHeader: hdr,
data: packet[len(packet)-r.Len():],
rcvTime: rcvTime,
})
return nil
}
func (c *Client) createNewSession(negotiatedVersions []protocol.VersionNumber) error {
var err error
c.session, err = newClientSession(c.conn, c.addr, c.hostname, c.version, c.connectionID, c.tlsConfig, c.streamCallback, c.closeCallback, c.cryptoChangeCallback, negotiatedVersions)
if err != nil {
return err
}
go c.session.run()
return nil
}
func (c *Client) streamCallback(session *Session, stream utils.Stream) {}
func (c *Client) closeCallback(id protocol.ConnectionID) {
utils.Infof("Connection %x closed.", id)
}
...@@ -2,6 +2,8 @@ coverage: ...@@ -2,6 +2,8 @@ coverage:
round: nearest round: nearest
ignore: ignore:
- ackhandler/packet_linkedlist.go - ackhandler/packet_linkedlist.go
- h2quic/gzipreader.go
- h2quic/response.go
- utils/byteinterval_linkedlist.go - utils/byteinterval_linkedlist.go
- utils/packetinterval_linkedlist.go - utils/packetinterval_linkedlist.go
status: status:
......
package crypto package crypto
import ( import (
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/tls" "crypto/tls"
"errors" "errors"
"strings" "strings"
) )
// A CertChain holds a certificate and a private key
type CertChain interface {
SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error)
GetCertsCompressed(sni string, commonSetHashes, cachedHashes []byte) ([]byte, error)
GetLeafCert(sni string) ([]byte, error)
}
// proofSource stores a key and a certificate for the server proof // proofSource stores a key and a certificate for the server proof
type proofSource struct { type certChain struct {
config *tls.Config config *tls.Config
} }
// NewProofSource loads the key and cert from files var _ CertChain = &certChain{}
func NewProofSource(tlsConfig *tls.Config) (Signer, error) {
return &proofSource{config: tlsConfig}, nil var errNoMatchingCertificate = errors.New("no matching certificate found")
// NewCertChain loads the key and cert from files
func NewCertChain(tlsConfig *tls.Config) CertChain {
return &certChain{config: tlsConfig}
} }
// SignServerProof signs CHLO and server config for use in the server proof // SignServerProof signs CHLO and server config for use in the server proof
func (ps *proofSource) SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) { func (c *certChain) SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error) {
cert, err := ps.getCertForSNI(sni) cert, err := c.getCertForSNI(sni)
if err != nil { if err != nil {
return nil, err return nil, err
} }
hash := sha256.New() return signServerProof(cert, chlo, serverConfigData)
hash.Write([]byte("QUIC CHLO and server config signature\x00"))
chloHash := sha256.Sum256(chlo)
hash.Write([]byte{32, 0, 0, 0})
hash.Write(chloHash[:])
hash.Write(serverConfigData)
key, ok := cert.PrivateKey.(crypto.Signer)
if !ok {
return nil, errors.New("expected PrivateKey to implement crypto.Signer")
}
opts := crypto.SignerOpts(crypto.SHA256)
if _, ok = key.(*rsa.PrivateKey); ok {
opts = &rsa.PSSOptions{SaltLength: 32, Hash: crypto.SHA256}
}
return key.Sign(rand.Reader, hash.Sum(nil), opts)
} }
// GetCertsCompressed gets the certificate in the format described by the QUIC crypto doc // GetCertsCompressed gets the certificate in the format described by the QUIC crypto doc
func (ps *proofSource) GetCertsCompressed(sni string, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) { func (c *certChain) GetCertsCompressed(sni string, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
cert, err := ps.getCertForSNI(sni) cert, err := c.getCertForSNI(sni)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -58,17 +47,17 @@ func (ps *proofSource) GetCertsCompressed(sni string, pCommonSetHashes, pCachedH ...@@ -58,17 +47,17 @@ func (ps *proofSource) GetCertsCompressed(sni string, pCommonSetHashes, pCachedH
} }
// GetLeafCert gets the leaf certificate // GetLeafCert gets the leaf certificate
func (ps *proofSource) GetLeafCert(sni string) ([]byte, error) { func (c *certChain) GetLeafCert(sni string) ([]byte, error) {
cert, err := ps.getCertForSNI(sni) cert, err := c.getCertForSNI(sni)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return cert.Certificate[0], nil return cert.Certificate[0], nil
} }
func (ps *proofSource) getCertForSNI(sni string) (*tls.Certificate, error) { func (c *certChain) getCertForSNI(sni string) (*tls.Certificate, error) {
if ps.config.GetCertificate != nil { if c.config.GetCertificate != nil {
cert, err := ps.config.GetCertificate(&tls.ClientHelloInfo{ServerName: sni}) cert, err := c.config.GetCertificate(&tls.ClientHelloInfo{ServerName: sni})
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -76,17 +65,20 @@ func (ps *proofSource) getCertForSNI(sni string) (*tls.Certificate, error) { ...@@ -76,17 +65,20 @@ func (ps *proofSource) getCertForSNI(sni string) (*tls.Certificate, error) {
return cert, nil return cert, nil
} }
} }
if len(ps.config.NameToCertificate) != 0 {
if cert, ok := ps.config.NameToCertificate[sni]; ok { if len(c.config.NameToCertificate) != 0 {
if cert, ok := c.config.NameToCertificate[sni]; ok {
return cert, nil return cert, nil
} }
wildcardSNI := "*" + strings.TrimLeftFunc(sni, func(r rune) bool { return r != '.' }) wildcardSNI := "*" + strings.TrimLeftFunc(sni, func(r rune) bool { return r != '.' })
if cert, ok := ps.config.NameToCertificate[wildcardSNI]; ok { if cert, ok := c.config.NameToCertificate[wildcardSNI]; ok {
return cert, nil return cert, nil
} }
} }
if len(ps.config.Certificates) != 0 {
return &ps.config.Certificates[0], nil if len(c.config.Certificates) != 0 {
return &c.config.Certificates[0], nil
} }
return nil, errors.New("no matching certificate found")
return nil, errNoMatchingCertificate
} }
...@@ -22,8 +22,8 @@ const ( ...@@ -22,8 +22,8 @@ const (
type entry struct { type entry struct {
t entryType t entryType
h uint64 h uint64 // set hash
i uint32 i uint32 // index
} }
func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) { func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]byte, error) {
...@@ -41,7 +41,7 @@ func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]by ...@@ -41,7 +41,7 @@ func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]by
chainHashes := make([]uint64, len(chain)) chainHashes := make([]uint64, len(chain))
for i := range chain { for i := range chain {
chainHashes[i] = hashCert(chain[i]) chainHashes[i] = HashCert(chain[i])
} }
entries := buildEntries(chain, chainHashes, cachedHashes, setHashes) entries := buildEntries(chain, chainHashes, cachedHashes, setHashes)
...@@ -89,6 +89,111 @@ func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]by ...@@ -89,6 +89,111 @@ func compressChain(chain [][]byte, pCommonSetHashes, pCachedHashes []byte) ([]by
return res.Bytes(), nil return res.Bytes(), nil
} }
func decompressChain(data []byte) ([][]byte, error) {
var chain [][]byte
var entries []entry
r := bytes.NewReader(data)
var numCerts int
var hasCompressedCerts bool
for {
entryTypeByte, err := r.ReadByte()
if entryTypeByte == 0 {
break
}
et := entryType(entryTypeByte)
if err != nil {
return nil, err
}
numCerts++
switch et {
case entryCached:
// we're not sending any certificate hashes in the CHLO, so there shouldn't be any cached certificates in the chain
return nil, errors.New("unexpected cached certificate")
case entryCommon:
e := entry{t: entryCommon}
e.h, err = utils.ReadUint64(r)
if err != nil {
return nil, err
}
e.i, err = utils.ReadUint32(r)
if err != nil {
return nil, err
}
certSet, ok := certSets[e.h]
if !ok {
return nil, errors.New("unknown certSet")
}
if e.i >= uint32(len(certSet)) {
return nil, errors.New("certificate not found in certSet")
}
entries = append(entries, e)
chain = append(chain, certSet[e.i])
case entryCompressed:
hasCompressedCerts = true
entries = append(entries, entry{t: entryCompressed})
chain = append(chain, nil)
default:
return nil, errors.New("unknown entryType")
}
}
if numCerts == 0 {
return make([][]byte, 0, 0), nil
}
if hasCompressedCerts {
uncompressedLength, err := utils.ReadUint32(r)
if err != nil {
fmt.Println(4)
return nil, err
}
zlibDict := buildZlibDictForEntries(entries, chain)
gz, err := zlib.NewReaderDict(r, zlibDict)
if err != nil {
return nil, err
}
defer gz.Close()
var totalLength uint32
var certIndex int
for totalLength < uncompressedLength {
lenBytes := make([]byte, 4)
_, err := gz.Read(lenBytes)
if err != nil {
return nil, err
}
certLen := binary.LittleEndian.Uint32(lenBytes)
cert := make([]byte, certLen)
n, err := gz.Read(cert)
if uint32(n) != certLen && err != nil {
return nil, err
}
for {
if certIndex >= len(entries) {
return nil, errors.New("CertCompression BUG: no element to save uncompressed certificate")
}
if entries[certIndex].t == entryCompressed {
chain[certIndex] = cert
certIndex++
break
}
certIndex++
}
totalLength += 4 + certLen
}
}
return chain, nil
}
func buildEntries(chain [][]byte, chainHashes, cachedHashes, setHashes []uint64) []entry { func buildEntries(chain [][]byte, chainHashes, cachedHashes, setHashes []uint64) []entry {
res := make([]entry, len(chain)) res := make([]entry, len(chain))
chainLoop: chainLoop:
...@@ -149,8 +254,19 @@ func splitHashes(hashes []byte) ([]uint64, error) { ...@@ -149,8 +254,19 @@ func splitHashes(hashes []byte) ([]uint64, error) {
return res, nil return res, nil
} }
func hashCert(cert []byte) uint64 { func getCommonCertificateHashes() []byte {
h := fnv.New64() ccs := make([]byte, 8*len(certSets), 8*len(certSets))
i := 0
for certSetHash := range certSets {
binary.LittleEndian.PutUint64(ccs[i*8:(i+1)*8], certSetHash)
i++
}
return ccs
}
// HashCert calculates the FNV1a hash of a certificate
func HashCert(cert []byte) uint64 {
h := fnv.New64a()
h.Write(cert) h.Write(cert)
return h.Sum64() return h.Sum64()
} }
package crypto
import (
"crypto/tls"
"crypto/x509"
"errors"
"hash/fnv"
"time"
"github.com/lucas-clemente/quic-go/qerr"
)
// CertManager manages the certificates sent by the server
type CertManager interface {
SetData([]byte) error
GetCommonCertificateHashes() []byte
GetLeafCert() []byte
GetLeafCertHash() (uint64, error)
VerifyServerProof(proof, chlo, serverConfigData []byte) bool
Verify(hostname string) error
}
type certManager struct {
chain []*x509.Certificate
config *tls.Config
}
var _ CertManager = &certManager{}
var errNoCertificateChain = errors.New("CertManager BUG: No certicifate chain loaded")
// NewCertManager creates a new CertManager
func NewCertManager(tlsConfig *tls.Config) CertManager {
return &certManager{config: tlsConfig}
}
// SetData takes the byte-slice sent in the SHLO and decompresses it into the certificate chain
func (c *certManager) SetData(data []byte) error {
byteChain, err := decompressChain(data)
if err != nil {
return qerr.Error(qerr.InvalidCryptoMessageParameter, "Certificate data invalid")
}
chain := make([]*x509.Certificate, len(byteChain), len(byteChain))
for i, data := range byteChain {
cert, err := x509.ParseCertificate(data)
if err != nil {
return err
}
chain[i] = cert
}
c.chain = chain
return nil
}
func (c *certManager) GetCommonCertificateHashes() []byte {
return getCommonCertificateHashes()
}
// GetLeafCert returns the leaf certificate of the certificate chain
// it returns nil if the certificate chain has not yet been set
func (c *certManager) GetLeafCert() []byte {
if len(c.chain) == 0 {
return nil
}
return c.chain[0].Raw
}
// GetLeafCertHash calculates the FNV1a_64 hash of the leaf certificate
func (c *certManager) GetLeafCertHash() (uint64, error) {
leafCert := c.GetLeafCert()
if leafCert == nil {
return 0, errNoCertificateChain
}
h := fnv.New64a()
_, err := h.Write(leafCert)
if err != nil {
return 0, err
}
return h.Sum64(), nil
}
// VerifyServerProof verifies the signature of the server config
// it should only be called after the certificate chain has been set, otherwise it returns false
func (c *certManager) VerifyServerProof(proof, chlo, serverConfigData []byte) bool {
if len(c.chain) == 0 {
return false
}
return verifyServerProof(proof, c.chain[0], chlo, serverConfigData)
}
// Verify verifies the certificate chain
func (c *certManager) Verify(hostname string) error {
if len(c.chain) == 0 {
return errNoCertificateChain
}
if c.config != nil && c.config.InsecureSkipVerify {
return nil
}
leafCert := c.chain[0]
var opts x509.VerifyOptions
if c.config != nil {
opts.Roots = c.config.RootCAs
opts.DNSName = c.config.ServerName
if c.config.Time == nil {
opts.CurrentTime = time.Now()
} else {
opts.CurrentTime = c.config.Time()
}
} else {
opts.DNSName = hostname
}
// the first certificate is the leaf certificate, all others are intermediates
if len(c.chain) > 1 {
intermediates := x509.NewCertPool()
for i := 1; i < len(c.chain); i++ {
intermediates.AddCert(c.chain[i])
}
opts.Intermediates = intermediates
}
_, err := leafCert.Verify(opts)
return err
}
// +build ignore
package crypto
import (
"crypto/rand"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
var _ = Describe("Chacha20poly1305", func() {
var (
alice, bob AEAD
keyAlice, keyBob, ivAlice, ivBob []byte
)
BeforeEach(func() {
keyAlice = make([]byte, 32)
keyBob = make([]byte, 32)
ivAlice = make([]byte, 4)
ivBob = make([]byte, 4)
rand.Reader.Read(keyAlice)
rand.Reader.Read(keyBob)
rand.Reader.Read(ivAlice)
rand.Reader.Read(ivBob)
var err error
alice, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice)
Expect(err).ToNot(HaveOccurred())
bob, err = NewAEADChacha20Poly1305(keyAlice, keyBob, ivAlice, ivBob)
Expect(err).ToNot(HaveOccurred())
})
It("seals and opens", func() {
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
text, err := bob.Open(nil, b, 42, []byte("aad"))
Expect(err).ToNot(HaveOccurred())
Expect(text).To(Equal([]byte("foobar")))
})
It("seals and opens reverse", func() {
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
text, err := alice.Open(nil, b, 42, []byte("aad"))
Expect(err).ToNot(HaveOccurred())
Expect(text).To(Equal([]byte("foobar")))
})
It("has the proper length", func() {
b := bob.Seal(nil, []byte("foobar"), 42, []byte("aad"))
Expect(b).To(HaveLen(6 + 12))
})
It("fails with wrong aad", func() {
b := alice.Seal(nil, []byte("foobar"), 42, []byte("aad"))
_, err := bob.Open(nil, b, 42, []byte("aad2"))
Expect(err).To(HaveOccurred())
})
It("rejects wrong key and iv sizes", func() {
var err error
e := "chacha20poly1305: expected 32-byte keys and 4-byte IVs"
_, err = NewAEADChacha20Poly1305(keyBob[1:], keyAlice, ivBob, ivAlice)
Expect(err).To(MatchError(e))
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice[1:], ivBob, ivAlice)
Expect(err).To(MatchError(e))
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob[1:], ivAlice)
Expect(err).To(MatchError(e))
_, err = NewAEADChacha20Poly1305(keyBob, keyAlice, ivBob, ivAlice[1:])
Expect(err).To(MatchError(e))
})
})
...@@ -21,15 +21,21 @@ import ( ...@@ -21,15 +21,21 @@ import (
// } // }
// DeriveKeysAESGCM derives the client and server keys and creates a matching AES-GCM AEAD instance // DeriveKeysAESGCM derives the client and server keys and creates a matching AES-GCM AEAD instance
func DeriveKeysAESGCM(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte) (AEAD, error) { func DeriveKeysAESGCM(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo []byte, scfg []byte, cert []byte, divNonce []byte, pers protocol.Perspective) (AEAD, error) {
otherKey, myKey, otherIV, myIV, err := deriveKeys(forwardSecure, sharedSecret, nonces, connID, chlo, scfg, cert, divNonce, 16) var swap bool
if pers == protocol.PerspectiveClient {
swap = true
}
otherKey, myKey, otherIV, myIV, err := deriveKeys(forwardSecure, sharedSecret, nonces, connID, chlo, scfg, cert, divNonce, 16, swap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewAEADAESGCM(otherKey, myKey, otherIV, myIV) return NewAEADAESGCM(otherKey, myKey, otherIV, myIV)
} }
func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo, scfg, cert, divNonce []byte, keyLen int) ([]byte, []byte, []byte, []byte, error) { // deriveKeys derives the keys and the IVs
// swap should be set true if generating the values for the client, and false for the server
func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol.ConnectionID, chlo, scfg, cert, divNonce []byte, keyLen int, swap bool) ([]byte, []byte, []byte, []byte, error) {
var info bytes.Buffer var info bytes.Buffer
if forwardSecure { if forwardSecure {
info.Write([]byte("QUIC forward secure key expansion\x00")) info.Write([]byte("QUIC forward secure key expansion\x00"))
...@@ -47,17 +53,33 @@ func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol ...@@ -47,17 +53,33 @@ func deriveKeys(forwardSecure bool, sharedSecret, nonces []byte, connID protocol
if _, err := io.ReadFull(r, s); err != nil { if _, err := io.ReadFull(r, s); err != nil {
return nil, nil, nil, nil, err return nil, nil, nil, nil, err
} }
otherKey := s[:keyLen]
myKey := s[keyLen : 2*keyLen] key1 := s[:keyLen]
otherIV := s[2*keyLen : 2*keyLen+4] key2 := s[keyLen : 2*keyLen]
myIV := s[2*keyLen+4:] iv1 := s[2*keyLen : 2*keyLen+4]
iv2 := s[2*keyLen+4:]
var otherKey, myKey []byte
var otherIV, myIV []byte
if !forwardSecure { if !forwardSecure {
if err := diversify(myKey, myIV, divNonce); err != nil { if err := diversify(key2, iv2, divNonce); err != nil {
return nil, nil, nil, nil, err return nil, nil, nil, nil, err
} }
} }
if swap {
otherKey = key2
myKey = key1
otherIV = iv2
myIV = iv1
} else {
otherKey = key1
myKey = key2
otherIV = iv1
myIV = iv2
}
return otherKey, myKey, otherIV, myIV, nil return otherKey, myKey, otherIV, myIV, nil
} }
......
package crypto
import (
"crypto"
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/asn1"
"errors"
"math/big"
)
type ecdsaSignature struct {
R, S *big.Int
}
// signServerProof signs CHLO and server config for use in the server proof
func signServerProof(cert *tls.Certificate, chlo []byte, serverConfigData []byte) ([]byte, error) {
hash := sha256.New()
hash.Write([]byte("QUIC CHLO and server config signature\x00"))
chloHash := sha256.Sum256(chlo)
hash.Write([]byte{32, 0, 0, 0})
hash.Write(chloHash[:])
hash.Write(serverConfigData)
key, ok := cert.PrivateKey.(crypto.Signer)
if !ok {
return nil, errors.New("expected PrivateKey to implement crypto.Signer")
}
opts := crypto.SignerOpts(crypto.SHA256)
if _, ok = key.(*rsa.PrivateKey); ok {
opts = &rsa.PSSOptions{SaltLength: 32, Hash: crypto.SHA256}
}
return key.Sign(rand.Reader, hash.Sum(nil), opts)
}
// verifyServerProof verifies the server proof signature
func verifyServerProof(proof []byte, cert *x509.Certificate, chlo []byte, serverConfigData []byte) bool {
hash := sha256.New()
hash.Write([]byte("QUIC CHLO and server config signature\x00"))
chloHash := sha256.Sum256(chlo)
hash.Write([]byte{32, 0, 0, 0})
hash.Write(chloHash[:])
hash.Write(serverConfigData)
// RSA
if cert.PublicKeyAlgorithm == x509.RSA {
opts := &rsa.PSSOptions{SaltLength: 32, Hash: crypto.SHA256}
err := rsa.VerifyPSS(cert.PublicKey.(*rsa.PublicKey), crypto.SHA256, hash.Sum(nil), proof, opts)
return err == nil
}
// ECDSA
signature := &ecdsaSignature{}
rest, err := asn1.Unmarshal(proof, signature)
if err != nil || len(rest) != 0 {
return false
}
return ecdsa.Verify(cert.PublicKey.(*ecdsa.PublicKey), hash.Sum(nil), signature.R, signature.S)
}
package crypto
// A Signer holds a certificate and a private key
type Signer interface {
SignServerProof(sni string, chlo []byte, serverConfigData []byte) ([]byte, error)
GetCertsCompressed(sni string, commonSetHashes, cachedHashes []byte) ([]byte, error)
GetLeafCert(sni string) ([]byte, error)
}
...@@ -2,38 +2,39 @@ package flowcontrol ...@@ -2,38 +2,39 @@ package flowcontrol
import ( import (
"errors" "errors"
"fmt"
"sync" "sync"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils" "github.com/lucas-clemente/quic-go/utils"
) )
type flowControlManager struct { type flowControlManager struct {
connectionParametersManager *handshake.ConnectionParametersManager connectionParameters handshake.ConnectionParametersManager
rttStats *congestion.RTTStats
streamFlowController map[protocol.StreamID]*flowController streamFlowController map[protocol.StreamID]*flowController
contributesToConnectionFlowControl map[protocol.StreamID]bool contributesToConnectionFlowControl map[protocol.StreamID]bool
mutex sync.RWMutex mutex sync.RWMutex
} }
var ( var _ FlowControlManager = &flowControlManager{}
// ErrStreamFlowControlViolation is a stream flow control violation
ErrStreamFlowControlViolation = errors.New("Stream level flow control violation")
// ErrConnectionFlowControlViolation is a connection level flow control violation
ErrConnectionFlowControlViolation = errors.New("Connection level flow control violation")
)
var errMapAccess = errors.New("Error accessing the flowController map.") var errMapAccess = errors.New("Error accessing the flowController map.")
// NewFlowControlManager creates a new flow control manager // NewFlowControlManager creates a new flow control manager
func NewFlowControlManager(connectionParametersManager *handshake.ConnectionParametersManager) FlowControlManager { func NewFlowControlManager(connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) FlowControlManager {
fcm := flowControlManager{ fcm := flowControlManager{
connectionParametersManager: connectionParametersManager, connectionParameters: connectionParameters,
rttStats: rttStats,
streamFlowController: make(map[protocol.StreamID]*flowController), streamFlowController: make(map[protocol.StreamID]*flowController),
contributesToConnectionFlowControl: make(map[protocol.StreamID]bool), contributesToConnectionFlowControl: make(map[protocol.StreamID]bool),
} }
// initialize connection level flow controller // initialize connection level flow controller
fcm.streamFlowController[0] = newFlowController(0, connectionParametersManager) fcm.streamFlowController[0] = newFlowController(0, connectionParameters, rttStats)
fcm.contributesToConnectionFlowControl[0] = false fcm.contributesToConnectionFlowControl[0] = false
return &fcm return &fcm
} }
...@@ -47,7 +48,7 @@ func (f *flowControlManager) NewStream(streamID protocol.StreamID, contributesTo ...@@ -47,7 +48,7 @@ func (f *flowControlManager) NewStream(streamID protocol.StreamID, contributesTo
return return
} }
f.streamFlowController[streamID] = newFlowController(streamID, f.connectionParametersManager) f.streamFlowController[streamID] = newFlowController(streamID, f.connectionParameters, f.rttStats)
f.contributesToConnectionFlowControl[streamID] = contributesToConnectionFlow f.contributesToConnectionFlowControl[streamID] = contributesToConnectionFlow
} }
...@@ -59,6 +60,48 @@ func (f *flowControlManager) RemoveStream(streamID protocol.StreamID) { ...@@ -59,6 +60,48 @@ func (f *flowControlManager) RemoveStream(streamID protocol.StreamID) {
f.mutex.Unlock() f.mutex.Unlock()
} }
// ResetStream should be called when receiving a RstStreamFrame
// it updates the byte offset to the value in the RstStreamFrame
// streamID must not be 0 here
func (f *flowControlManager) ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error {
f.mutex.Lock()
defer f.mutex.Unlock()
streamFlowController, err := f.getFlowController(streamID)
if err != nil {
return err
}
increment, err := streamFlowController.UpdateHighestReceived(byteOffset)
if err != nil {
return qerr.StreamDataAfterTermination
}
if streamFlowController.CheckFlowControlViolation() {
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveFlowControlWindow))
}
if f.contributesToConnectionFlowControl[streamID] {
connectionFlowController := f.streamFlowController[0]
connectionFlowController.IncrementHighestReceived(increment)
if connectionFlowController.CheckFlowControlViolation() {
return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", byteOffset, connectionFlowController.receiveFlowControlWindow))
}
}
return nil
}
func (f *flowControlManager) GetBytesSent(streamID protocol.StreamID) (protocol.ByteCount, error) {
f.mutex.Lock()
defer f.mutex.Unlock()
fc, err := f.getFlowController(streamID)
if err != nil {
return 0, err
}
return fc.GetBytesSent(), nil
}
// UpdateHighestReceived updates the highest received byte offset for a stream // UpdateHighestReceived updates the highest received byte offset for a stream
// it adds the number of additional bytes to connection level flow control // it adds the number of additional bytes to connection level flow control
// streamID must not be 0 here // streamID must not be 0 here
...@@ -70,17 +113,19 @@ func (f *flowControlManager) UpdateHighestReceived(streamID protocol.StreamID, b ...@@ -70,17 +113,19 @@ func (f *flowControlManager) UpdateHighestReceived(streamID protocol.StreamID, b
if err != nil { if err != nil {
return err return err
} }
increment := streamFlowController.UpdateHighestReceived(byteOffset) // UpdateHighestReceived returns an ErrReceivedSmallerByteOffset when StreamFrames got reordered
// this error can be ignored here
increment, _ := streamFlowController.UpdateHighestReceived(byteOffset)
if streamFlowController.CheckFlowControlViolation() { if streamFlowController.CheckFlowControlViolation() {
return ErrStreamFlowControlViolation return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes on stream %d, allowed %d bytes", byteOffset, streamID, streamFlowController.receiveFlowControlWindow))
} }
if f.contributesToConnectionFlowControl[streamID] { if f.contributesToConnectionFlowControl[streamID] {
connectionFlowController := f.streamFlowController[0] connectionFlowController := f.streamFlowController[0]
connectionFlowController.IncrementHighestReceived(increment) connectionFlowController.IncrementHighestReceived(increment)
if connectionFlowController.CheckFlowControlViolation() { if connectionFlowController.CheckFlowControlViolation() {
return ErrConnectionFlowControlViolation return qerr.Error(qerr.FlowControlReceivedTooMuchData, fmt.Sprintf("Received %d bytes for the connection, allowed %d bytes", byteOffset, connectionFlowController.receiveFlowControlWindow))
} }
} }
...@@ -117,6 +162,16 @@ func (f *flowControlManager) GetWindowUpdates() (res []WindowUpdate) { ...@@ -117,6 +162,16 @@ func (f *flowControlManager) GetWindowUpdates() (res []WindowUpdate) {
return res return res
} }
func (f *flowControlManager) GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error) {
f.mutex.Lock()
defer f.mutex.Unlock()
flowController, err := f.getFlowController(streamID)
if err != nil {
return 0, err
}
return flowController.receiveFlowControlWindow, nil
}
// streamID must not be 0 here // streamID must not be 0 here
func (f *flowControlManager) AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error { func (f *flowControlManager) AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error {
// Only lock the part reading from the map, since send-windows are only accessed from the session goroutine. // Only lock the part reading from the map, since send-windows are only accessed from the session goroutine.
......
package flowcontrol package flowcontrol
import ( import (
"errors"
"time"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
) )
type flowController struct { type flowController struct {
streamID protocol.StreamID streamID protocol.StreamID
connectionParametersManager *handshake.ConnectionParametersManager connectionParameters handshake.ConnectionParametersManager
rttStats *congestion.RTTStats
bytesSent protocol.ByteCount bytesSent protocol.ByteCount
sendFlowControlWindow protocol.ByteCount sendFlowControlWindow protocol.ByteCount
bytesRead protocol.ByteCount lastWindowUpdateTime time.Time
highestReceived protocol.ByteCount
receiveFlowControlWindow protocol.ByteCount bytesRead protocol.ByteCount
receiveFlowControlWindowIncrement protocol.ByteCount highestReceived protocol.ByteCount
receiveFlowControlWindow protocol.ByteCount
receiveFlowControlWindowIncrement protocol.ByteCount
maxReceiveFlowControlWindowIncrement protocol.ByteCount
} }
// ErrReceivedSmallerByteOffset occurs if the ByteOffset received is smaller than a ByteOffset that was set previously
var ErrReceivedSmallerByteOffset = errors.New("Received a smaller byte offset")
// newFlowController gets a new flow controller // newFlowController gets a new flow controller
func newFlowController(streamID protocol.StreamID, connectionParametersManager *handshake.ConnectionParametersManager) *flowController { func newFlowController(streamID protocol.StreamID, connectionParameters handshake.ConnectionParametersManager, rttStats *congestion.RTTStats) *flowController {
fc := flowController{ fc := flowController{
streamID: streamID, streamID: streamID,
connectionParametersManager: connectionParametersManager, connectionParameters: connectionParameters,
rttStats: rttStats,
} }
if streamID == 0 { if streamID == 0 {
fc.receiveFlowControlWindow = connectionParametersManager.GetReceiveConnectionFlowControlWindow() fc.receiveFlowControlWindow = connectionParameters.GetReceiveConnectionFlowControlWindow()
fc.receiveFlowControlWindowIncrement = fc.receiveFlowControlWindow fc.receiveFlowControlWindowIncrement = fc.receiveFlowControlWindow
fc.maxReceiveFlowControlWindowIncrement = connectionParameters.GetMaxReceiveConnectionFlowControlWindow()
} else { } else {
fc.receiveFlowControlWindow = connectionParametersManager.GetReceiveStreamFlowControlWindow() fc.receiveFlowControlWindow = connectionParameters.GetReceiveStreamFlowControlWindow()
fc.receiveFlowControlWindowIncrement = fc.receiveFlowControlWindow fc.receiveFlowControlWindowIncrement = fc.receiveFlowControlWindow
fc.maxReceiveFlowControlWindowIncrement = connectionParameters.GetMaxReceiveStreamFlowControlWindow()
} }
return &fc return &fc
...@@ -40,9 +55,9 @@ func newFlowController(streamID protocol.StreamID, connectionParametersManager * ...@@ -40,9 +55,9 @@ func newFlowController(streamID protocol.StreamID, connectionParametersManager *
func (c *flowController) getSendFlowControlWindow() protocol.ByteCount { func (c *flowController) getSendFlowControlWindow() protocol.ByteCount {
if c.sendFlowControlWindow == 0 { if c.sendFlowControlWindow == 0 {
if c.streamID == 0 { if c.streamID == 0 {
return c.connectionParametersManager.GetSendConnectionFlowControlWindow() return c.connectionParameters.GetSendConnectionFlowControlWindow()
} }
return c.connectionParametersManager.GetSendStreamFlowControlWindow() return c.connectionParameters.GetSendStreamFlowControlWindow()
} }
return c.sendFlowControlWindow return c.sendFlowControlWindow
} }
...@@ -51,6 +66,10 @@ func (c *flowController) AddBytesSent(n protocol.ByteCount) { ...@@ -51,6 +66,10 @@ func (c *flowController) AddBytesSent(n protocol.ByteCount) {
c.bytesSent += n c.bytesSent += n
} }
func (c *flowController) GetBytesSent() protocol.ByteCount {
return c.bytesSent
}
// UpdateSendWindow should be called after receiving a WindowUpdateFrame // UpdateSendWindow should be called after receiving a WindowUpdateFrame
// it returns true if the window was actually updated // it returns true if the window was actually updated
func (c *flowController) UpdateSendWindow(newOffset protocol.ByteCount) bool { func (c *flowController) UpdateSendWindow(newOffset protocol.ByteCount) bool {
...@@ -76,13 +95,19 @@ func (c *flowController) SendWindowOffset() protocol.ByteCount { ...@@ -76,13 +95,19 @@ func (c *flowController) SendWindowOffset() protocol.ByteCount {
// UpdateHighestReceived updates the highestReceived value, if the byteOffset is higher // UpdateHighestReceived updates the highestReceived value, if the byteOffset is higher
// Should **only** be used for the stream-level FlowController // Should **only** be used for the stream-level FlowController
func (c *flowController) UpdateHighestReceived(byteOffset protocol.ByteCount) protocol.ByteCount { // it returns an ErrReceivedSmallerByteOffset if the received byteOffset is smaller than any byteOffset received before
// This error occurs every time StreamFrames get reordered and has to be ignored in that case
// It should only be treated as an error when resetting a stream
func (c *flowController) UpdateHighestReceived(byteOffset protocol.ByteCount) (protocol.ByteCount, error) {
if byteOffset == c.highestReceived {
return 0, nil
}
if byteOffset > c.highestReceived { if byteOffset > c.highestReceived {
increment := byteOffset - c.highestReceived increment := byteOffset - c.highestReceived
c.highestReceived = byteOffset c.highestReceived = byteOffset
return increment return increment, nil
} }
return 0 return 0, ErrReceivedSmallerByteOffset
} }
// IncrementHighestReceived adds an increment to the highestReceived value // IncrementHighestReceived adds an increment to the highestReceived value
...@@ -99,14 +124,52 @@ func (c *flowController) AddBytesRead(n protocol.ByteCount) { ...@@ -99,14 +124,52 @@ func (c *flowController) AddBytesRead(n protocol.ByteCount) {
// if so, it returns true and the offset of the window // if so, it returns true and the offset of the window
func (c *flowController) MaybeTriggerWindowUpdate() (bool, protocol.ByteCount) { func (c *flowController) MaybeTriggerWindowUpdate() (bool, protocol.ByteCount) {
diff := c.receiveFlowControlWindow - c.bytesRead diff := c.receiveFlowControlWindow - c.bytesRead
// Chromium implements the same threshold // Chromium implements the same threshold
if diff < (c.receiveFlowControlWindowIncrement / 2) { if diff < (c.receiveFlowControlWindowIncrement / 2) {
c.maybeAdjustWindowIncrement()
c.lastWindowUpdateTime = time.Now()
c.receiveFlowControlWindow = c.bytesRead + c.receiveFlowControlWindowIncrement c.receiveFlowControlWindow = c.bytesRead + c.receiveFlowControlWindowIncrement
return true, c.receiveFlowControlWindow return true, c.receiveFlowControlWindow
} }
return false, 0 return false, 0
} }
// maybeAdjustWindowIncrement increases the receiveFlowControlWindowIncrement if we're sending WindowUpdates too often
func (c *flowController) maybeAdjustWindowIncrement() {
if c.lastWindowUpdateTime.IsZero() {
return
}
rtt := c.rttStats.SmoothedRTT()
if rtt == 0 {
return
}
timeSinceLastWindowUpdate := time.Now().Sub(c.lastWindowUpdateTime)
// interval between the window updates is sufficiently large, no need to increase the increment
if timeSinceLastWindowUpdate >= 2*rtt {
return
}
oldWindowSize := c.receiveFlowControlWindowIncrement
c.receiveFlowControlWindowIncrement = utils.MinByteCount(2*c.receiveFlowControlWindowIncrement, c.maxReceiveFlowControlWindowIncrement)
// debug log, if the window size was actually increased
if oldWindowSize < c.receiveFlowControlWindowIncrement {
newWindowSize := c.receiveFlowControlWindowIncrement / (1 << 10)
if c.streamID == 0 {
utils.Debugf("Increasing receive flow control window for the connection to %d kB", newWindowSize)
} else {
utils.Debugf("Increasing receive flow control window increment for stream %d to %d kB", c.streamID, newWindowSize)
}
}
}
func (c *flowController) CheckFlowControlViolation() bool { func (c *flowController) CheckFlowControlViolation() bool {
if c.highestReceived > c.receiveFlowControlWindow { if c.highestReceived > c.receiveFlowControlWindow {
return true return true
......
...@@ -13,9 +13,11 @@ type FlowControlManager interface { ...@@ -13,9 +13,11 @@ type FlowControlManager interface {
NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool) NewStream(streamID protocol.StreamID, contributesToConnectionFlow bool)
RemoveStream(streamID protocol.StreamID) RemoveStream(streamID protocol.StreamID)
// methods needed for receiving data // methods needed for receiving data
ResetStream(streamID protocol.StreamID, byteOffset protocol.ByteCount) error
UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error UpdateHighestReceived(streamID protocol.StreamID, byteOffset protocol.ByteCount) error
AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error AddBytesRead(streamID protocol.StreamID, n protocol.ByteCount) error
GetWindowUpdates() []WindowUpdate GetWindowUpdates() []WindowUpdate
GetReceiveWindow(streamID protocol.StreamID) (protocol.ByteCount, error)
// methods needed for sending data // methods needed for sending data
AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error AddBytesSent(streamID protocol.StreamID, n protocol.ByteCount) error
SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error) SendWindowSize(streamID protocol.StreamID) (protocol.ByteCount, error)
......
...@@ -27,8 +27,10 @@ type AckFrame struct { ...@@ -27,8 +27,10 @@ type AckFrame struct {
LowestAcked protocol.PacketNumber LowestAcked protocol.PacketNumber
AckRanges []AckRange // has to be ordered. The ACK range with the highest FirstPacketNumber goes first, the ACK range with the lowest FirstPacketNumber goes last AckRanges []AckRange // has to be ordered. The ACK range with the highest FirstPacketNumber goes first, the ACK range with the lowest FirstPacketNumber goes last
// time when the LargestAcked was receiveid
// this field Will not be set for received ACKs frames
PacketReceivedTime time.Time
DelayTime time.Duration DelayTime time.Duration
PacketReceivedTime time.Time // only for received packets. Will not be modified for received ACKs frames
} }
// ParseAckFrame reads an ACK frame // ParseAckFrame reads an ACK frame
...@@ -83,7 +85,7 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, ...@@ -83,7 +85,7 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame,
if err != nil { if err != nil {
return nil, err return nil, err
} }
if ackBlockLength < 1 { if frame.LargestAcked > 0 && ackBlockLength < 1 {
return nil, ErrInvalidFirstAckRange return nil, ErrInvalidFirstAckRange
} }
...@@ -141,7 +143,11 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame, ...@@ -141,7 +143,11 @@ func ParseAckFrame(r *bytes.Reader, version protocol.VersionNumber) (*AckFrame,
frame.LowestAcked = frame.AckRanges[len(frame.AckRanges)-1].FirstPacketNumber frame.LowestAcked = frame.AckRanges[len(frame.AckRanges)-1].FirstPacketNumber
} else { } else {
frame.LowestAcked = protocol.PacketNumber(largestAcked + 1 - ackBlockLength) if frame.LargestAcked == 0 {
frame.LowestAcked = 0
} else {
frame.LowestAcked = protocol.PacketNumber(largestAcked + 1 - ackBlockLength)
}
} }
if !frame.validateAckRanges() { if !frame.validateAckRanges() {
......
...@@ -11,9 +11,18 @@ func LogFrame(frame Frame, sent bool) { ...@@ -11,9 +11,18 @@ func LogFrame(frame Frame, sent bool) {
if sent { if sent {
dir = "->" dir = "->"
} }
if sf, ok := frame.(*StreamFrame); ok { switch f := frame.(type) {
utils.Debugf("\t%s &frames.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, sf.StreamID, sf.FinBit, sf.Offset, sf.DataLen(), sf.Offset+sf.DataLen()) case *StreamFrame:
return utils.Debugf("\t%s &frames.StreamFrame{StreamID: %d, FinBit: %t, Offset: 0x%x, Data length: 0x%x, Offset + Data length: 0x%x}", dir, f.StreamID, f.FinBit, f.Offset, f.DataLen(), f.Offset+f.DataLen())
case *StopWaitingFrame:
if sent {
utils.Debugf("\t%s &frames.StopWaitingFrame{LeastUnacked: 0x%x, PacketNumberLen: 0x%x}", dir, f.LeastUnacked, f.PacketNumberLen)
} else {
utils.Debugf("\t%s &frames.StopWaitingFrame{LeastUnacked: 0x%x}", dir, f.LeastUnacked)
}
case *AckFrame:
utils.Debugf("\t%s &frames.AckFrame{LargestAcked: 0x%x, LowestAcked: 0x%x, AckRanges: %#v, DelayTime: %s}", dir, f.LargestAcked, f.LowestAcked, f.AckRanges, f.DelayTime.String())
default:
utils.Debugf("\t%s %#v", dir, frame)
} }
utils.Debugf("\t%s %#v", dir, frame)
} }
package h2quic
import (
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"
"sync"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"golang.org/x/net/idna"
quic "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
)
type quicClient interface {
OpenStream(protocol.StreamID) (utils.Stream, error)
Close(error) error
Listen() error
}
// Client is a HTTP2 client doing QUIC requests
type Client struct {
mutex sync.RWMutex
cryptoChangedCond sync.Cond
t *QuicRoundTripper
hostname string
encryptionLevel protocol.EncryptionLevel
client quicClient
headerStream utils.Stream
headerErr *qerr.QuicError
highestOpenedStream protocol.StreamID
requestWriter *requestWriter
responses map[protocol.StreamID]chan *http.Response
}
var _ h2quicClient = &Client{}
// NewClient creates a new client
func NewClient(t *QuicRoundTripper, tlsConfig *tls.Config, hostname string) (*Client, error) {
c := &Client{
t: t,
hostname: authorityAddr("https", hostname),
highestOpenedStream: 3,
responses: make(map[protocol.StreamID]chan *http.Response),
}
c.cryptoChangedCond = sync.Cond{L: &c.mutex}
var err error
c.client, err = quic.NewClient(c.hostname, tlsConfig, c.cryptoChangeCallback, c.versionNegotiateCallback)
if err != nil {
return nil, err
}
go c.client.Listen()
return c, nil
}
func (c *Client) handleStreamCb(session *quic.Session, stream utils.Stream) {
utils.Debugf("Handling stream %d", stream.StreamID())
}
func (c *Client) cryptoChangeCallback(isForwardSecure bool) {
c.cryptoChangedCond.L.Lock()
defer c.cryptoChangedCond.L.Unlock()
if isForwardSecure {
c.encryptionLevel = protocol.EncryptionForwardSecure
utils.Debugf("is forward secure")
} else {
c.encryptionLevel = protocol.EncryptionSecure
utils.Debugf("is secure")
}
c.cryptoChangedCond.Broadcast()
}
func (c *Client) versionNegotiateCallback() error {
var err error
// once the version has been negotiated, open the header stream
c.headerStream, err = c.client.OpenStream(3)
if err != nil {
return err
}
c.requestWriter = newRequestWriter(c.headerStream)
go c.handleHeaderStream()
return nil
}
func (c *Client) handleHeaderStream() {
decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {})
h2framer := http2.NewFramer(nil, c.headerStream)
var lastStream protocol.StreamID
for {
frame, err := h2framer.ReadFrame()
if err != nil {
c.headerErr = qerr.Error(qerr.InvalidStreamData, "cannot read frame")
break
}
lastStream = protocol.StreamID(frame.Header().StreamID)
hframe, ok := frame.(*http2.HeadersFrame)
if !ok {
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "not a headers frame")
break
}
mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe}
mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment())
if err != nil {
c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, "cannot read header fields")
break
}
c.mutex.RLock()
headerChan, ok := c.responses[protocol.StreamID(hframe.StreamID)]
c.mutex.RUnlock()
if !ok {
c.headerErr = qerr.Error(qerr.InternalError, fmt.Sprintf("h2client BUG: response channel for stream %d not found", lastStream))
break
}
rsp, err := responseFromHeaders(mhframe)
if err != nil {
c.headerErr = qerr.Error(qerr.InternalError, err.Error())
}
headerChan <- rsp
}
// stop all running request
utils.Debugf("Error handling header stream %d: %s", lastStream, c.headerErr.Error())
c.mutex.Lock()
for _, responseChan := range c.responses {
responseChan <- nil
}
c.mutex.Unlock()
}
// Do executes a request and returns a response
func (c *Client) Do(req *http.Request) (*http.Response, error) {
// TODO: add port to address, if it doesn't have one
if req.URL.Scheme != "https" {
return nil, errors.New("quic http2: unsupported scheme")
}
if authorityAddr("https", hostnameFromRequest(req)) != c.hostname {
utils.Debugf("%s vs %s", req.Host, c.hostname)
return nil, errors.New("h2quic Client BUG: Do called for the wrong client")
}
hasBody := (req.Body != nil)
c.mutex.Lock()
c.highestOpenedStream += 2
dataStreamID := c.highestOpenedStream
for c.encryptionLevel != protocol.EncryptionForwardSecure {
c.cryptoChangedCond.Wait()
}
hdrChan := make(chan *http.Response)
c.responses[dataStreamID] = hdrChan
c.mutex.Unlock()
// TODO: think about what to do with a TooManyOpenStreams error. Wait and retry?
dataStream, err := c.client.OpenStream(dataStreamID)
if err != nil {
c.Close(err)
return nil, err
}
var requestedGzip bool
if !c.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" {
requestedGzip = true
}
// TODO: add support for trailers
endStream := !hasBody
err = c.requestWriter.WriteRequest(req, dataStreamID, endStream, requestedGzip)
if err != nil {
c.Close(err)
return nil, err
}
resc := make(chan error, 1)
if hasBody {
go func() {
resc <- c.writeRequestBody(dataStream, req.Body)
}()
}
var res *http.Response
var receivedResponse bool
var bodySent bool
if !hasBody {
bodySent = true
}
for !(bodySent && receivedResponse) {
select {
case res = <-hdrChan:
receivedResponse = true
c.mutex.Lock()
delete(c.responses, dataStreamID)
c.mutex.Unlock()
if res == nil { // an error occured on the header stream
c.Close(c.headerErr)
return nil, c.headerErr
}
case err := <-resc:
bodySent = true
if err != nil {
return nil, err
}
}
}
// TODO: correctly set this variable
var streamEnded bool
isHead := (req.Method == "HEAD")
res = setLength(res, isHead, streamEnded)
if streamEnded || isHead {
res.Body = noBody
} else {
res.Body = dataStream
if requestedGzip && res.Header.Get("Content-Encoding") == "gzip" {
res.Header.Del("Content-Encoding")
res.Header.Del("Content-Length")
res.ContentLength = -1
res.Body = &gzipReader{body: res.Body}
setUncompressed(res)
}
}
res.Request = req
return res, nil
}
func (c *Client) writeRequestBody(dataStream utils.Stream, body io.ReadCloser) (err error) {
defer func() {
cerr := body.Close()
if err == nil {
// TODO: what to do with dataStream here? Maybe reset it?
err = cerr
}
}()
_, err = io.Copy(dataStream, body)
if err != nil {
// TODO: what to do with dataStream here? Maybe reset it?
return err
}
return dataStream.Close()
}
// Close closes the client
func (c *Client) Close(e error) {
_ = c.client.Close(e)
}
// copied from net/transport.go
// authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
// and returns a host:port. The port 443 is added if needed.
func authorityAddr(scheme string, authority string) (addr string) {
host, port, err := net.SplitHostPort(authority)
if err != nil { // authority didn't have a port
port = "443"
if scheme == "http" {
port = "80"
}
host = authority
}
if a, err := idna.ToASCII(host); err == nil {
host = a
}
// IPv6 address literal, without a port:
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
return host + ":" + port
}
return net.JoinHostPort(host, port)
}
package h2quic
// copied from net/transport.go
// gzipReader wraps a response body so it can lazily
// call gzip.NewReader on the first call to Read
import (
"compress/gzip"
"io"
)
// call gzip.NewReader on the first call to Read
type gzipReader struct {
body io.ReadCloser // underlying Response.Body
zr *gzip.Reader // lazily-initialized gzip reader
zerr error // sticky error
}
func (gz *gzipReader) Read(p []byte) (n int, err error) {
if gz.zerr != nil {
return 0, gz.zerr
}
if gz.zr == nil {
gz.zr, err = gzip.NewReader(gz.body)
if err != nil {
gz.zerr = err
return 0, err
}
}
return gz.zr.Read(p)
}
func (gz *gzipReader) Close() error {
return gz.body.Close()
}
package h2quic package h2quic
import ( import (
"crypto/tls"
"errors" "errors"
"net/http" "net/http"
"net/url" "net/url"
"strconv"
"strings"
"golang.org/x/net/http2/hpack" "golang.org/x/net/http2/hpack"
) )
func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) { func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) {
var path, authority, method string var path, authority, method, contentLengthStr string
httpHeaders := http.Header{} httpHeaders := http.Header{}
for _, h := range headers { for _, h := range headers {
...@@ -20,6 +23,8 @@ func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) { ...@@ -20,6 +23,8 @@ func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) {
method = h.Value method = h.Value
case ":authority": case ":authority":
authority = h.Value authority = h.Value
case "content-length":
contentLengthStr = h.Value
default: default:
if !h.IsPseudo() { if !h.IsPseudo() {
httpHeaders.Add(h.Name, h.Value) httpHeaders.Add(h.Name, h.Value)
...@@ -27,6 +32,11 @@ func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) { ...@@ -27,6 +32,11 @@ func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) {
} }
} }
// concatenate cookie headers, see https://tools.ietf.org/html/rfc6265#section-5.4
if len(httpHeaders["Cookie"]) > 0 {
httpHeaders.Set("Cookie", strings.Join(httpHeaders["Cookie"], "; "))
}
if len(path) == 0 || len(authority) == 0 || len(method) == 0 { if len(path) == 0 || len(authority) == 0 || len(method) == 0 {
return nil, errors.New(":path, :authority and :method must not be empty") return nil, errors.New(":path, :authority and :method must not be empty")
} }
...@@ -36,16 +46,35 @@ func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) { ...@@ -36,16 +46,35 @@ func requestFromHeaders(headers []hpack.HeaderField) (*http.Request, error) {
return nil, err return nil, err
} }
var contentLength int64
if len(contentLengthStr) > 0 {
contentLength, err = strconv.ParseInt(contentLengthStr, 10, 64)
if err != nil {
return nil, err
}
}
return &http.Request{ return &http.Request{
Method: method, Method: method,
URL: u, URL: u,
Proto: "HTTP/2.0", Proto: "HTTP/2.0",
ProtoMajor: 2, ProtoMajor: 2,
ProtoMinor: 0, ProtoMinor: 0,
Header: httpHeaders, Header: httpHeaders,
Body: nil, Body: nil,
// ContentLength: -1, ContentLength: contentLength,
Host: authority, Host: authority,
RequestURI: path, RequestURI: path,
TLS: &tls.ConnectionState{},
}, nil }, nil
} }
func hostnameFromRequest(req *http.Request) string {
if len(req.Host) > 0 {
return req.Host
}
if req.URL != nil {
return req.URL.Host
}
return ""
}
package h2quic
import (
"io"
"github.com/lucas-clemente/quic-go/utils"
)
type requestBody struct {
requestRead bool
dataStream utils.Stream
}
// make sure the requestBody can be used as a http.Request.Body
var _ io.ReadCloser = &requestBody{}
func newRequestBody(stream utils.Stream) *requestBody {
return &requestBody{dataStream: stream}
}
func (b *requestBody) Read(p []byte) (int, error) {
b.requestRead = true
return b.dataStream.Read(p)
}
func (b *requestBody) Close() error {
// stream's Close() closes the write side, not the read side
return nil
}
package h2quic
import (
"bytes"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"golang.org/x/net/lex/httplex"
"github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils"
)
type requestWriter struct {
mutex sync.Mutex
headerStream utils.Stream
henc *hpack.Encoder
hbuf bytes.Buffer // HPACK encoder writes into this
}
const defaultUserAgent = "quic-go"
func newRequestWriter(headerStream utils.Stream) *requestWriter {
rw := &requestWriter{
headerStream: headerStream,
}
rw.henc = hpack.NewEncoder(&rw.hbuf)
return rw
}
func (w *requestWriter) WriteRequest(req *http.Request, dataStreamID protocol.StreamID, endStream, requestGzip bool) error {
// TODO: add support for trailers
// TODO: add support for gzip compression
// TODO: write continuation frames, if the header frame is too long
w.mutex.Lock()
defer w.mutex.Unlock()
w.encodeHeaders(req, requestGzip, "", actualContentLength(req))
h2framer := http2.NewFramer(w.headerStream, nil)
return h2framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: uint32(dataStreamID),
EndHeaders: true,
EndStream: endStream,
BlockFragment: w.hbuf.Bytes(),
Priority: http2.PriorityParam{Weight: 0xff},
})
}
// the rest of this files is copied from http2.Transport
func (w *requestWriter) encodeHeaders(req *http.Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) {
w.hbuf.Reset()
host := req.Host
if host == "" {
host = req.URL.Host
}
host, err := httplex.PunycodeHostPort(host)
if err != nil {
return nil, err
}
var path string
if req.Method != "CONNECT" {
path = req.URL.RequestURI()
if !validPseudoPath(path) {
orig := path
path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host)
if !validPseudoPath(path) {
if req.URL.Opaque != "" {
return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque)
} else {
return nil, fmt.Errorf("invalid request :path %q", orig)
}
}
}
}
// Check for any invalid headers and return an error before we
// potentially pollute our hpack state. (We want to be able to
// continue to reuse the hpack encoder for future requests)
for k, vv := range req.Header {
if !httplex.ValidHeaderFieldName(k) {
return nil, fmt.Errorf("invalid HTTP header name %q", k)
}
for _, v := range vv {
if !httplex.ValidHeaderFieldValue(v) {
return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k)
}
}
}
// 8.1.2.3 Request Pseudo-Header Fields
// The :path pseudo-header field includes the path and query parts of the
// target URI (the path-absolute production and optionally a '?' character
// followed by the query production (see Sections 3.3 and 3.4 of
// [RFC3986]).
w.writeHeader(":authority", host)
w.writeHeader(":method", req.Method)
if req.Method != "CONNECT" {
w.writeHeader(":path", path)
w.writeHeader(":scheme", req.URL.Scheme)
}
if trailers != "" {
w.writeHeader("trailer", trailers)
}
var didUA bool
for k, vv := range req.Header {
lowKey := strings.ToLower(k)
switch lowKey {
case "host", "content-length":
// Host is :authority, already sent.
// Content-Length is automatic, set below.
continue
case "connection", "proxy-connection", "transfer-encoding", "upgrade", "keep-alive":
// Per 8.1.2.2 Connection-Specific Header
// Fields, don't send connection-specific
// fields. We have already checked if any
// are error-worthy so just ignore the rest.
continue
case "user-agent":
// Match Go's http1 behavior: at most one
// User-Agent. If set to nil or empty string,
// then omit it. Otherwise if not mentioned,
// include the default (below).
didUA = true
if len(vv) < 1 {
continue
}
vv = vv[:1]
if vv[0] == "" {
continue
}
}
for _, v := range vv {
w.writeHeader(lowKey, v)
}
}
if shouldSendReqContentLength(req.Method, contentLength) {
w.writeHeader("content-length", strconv.FormatInt(contentLength, 10))
}
if addGzipHeader {
w.writeHeader("accept-encoding", "gzip")
}
if !didUA {
w.writeHeader("user-agent", defaultUserAgent)
}
return w.hbuf.Bytes(), nil
}
func (w *requestWriter) writeHeader(name, value string) {
utils.Debugf("http2: Transport encoding header %q = %q", name, value)
w.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
}
// shouldSendReqContentLength reports whether the http2.Transport should send
// a "content-length" request header. This logic is basically a copy of the net/http
// transferWriter.shouldSendContentLength.
// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
// -1 means unknown.
func shouldSendReqContentLength(method string, contentLength int64) bool {
if contentLength > 0 {
return true
}
if contentLength < 0 {
return false
}
// For zero bodies, whether we send a content-length depends on the method.
// It also kinda doesn't matter for http2 either way, with END_STREAM.
switch method {
case "POST", "PUT", "PATCH":
return true
default:
return false
}
}
func validPseudoPath(v string) bool {
return (len(v) > 0 && v[0] == '/' && (len(v) == 1 || v[1] != '/')) || v == "*"
}
// actualContentLength returns a sanitized version of
// req.ContentLength, where 0 actually means zero (not unknown) and -1
// means unknown.
func actualContentLength(req *http.Request) int64 {
if req.Body == nil {
return 0
}
if req.ContentLength != 0 {
return req.ContentLength
}
return -1
}
package h2quic
import (
"bytes"
"errors"
"io"
"io/ioutil"
"net/http"
"net/textproto"
"strconv"
"strings"
"golang.org/x/net/http2"
)
// copied from net/http2/transport.go
var errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit")
var noBody io.ReadCloser = ioutil.NopCloser(bytes.NewReader(nil))
// from the handleResponse function
func responseFromHeaders(f *http2.MetaHeadersFrame) (*http.Response, error) {
if f.Truncated {
return nil, errResponseHeaderListSize
}
status := f.PseudoValue("status")
if status == "" {
return nil, errors.New("missing status pseudo header")
}
statusCode, err := strconv.Atoi(status)
if err != nil {
return nil, errors.New("malformed non-numeric status pseudo header")
}
if statusCode == 100 {
// TODO: handle this
// traceGot100Continue(cs.trace)
// if cs.on100 != nil {
// cs.on100() // forces any write delay timer to fire
// }
// cs.pastHeaders = false // do it all again
// return nil, nil
}
header := make(http.Header)
res := &http.Response{
Proto: "HTTP/2.0",
ProtoMajor: 2,
Header: header,
StatusCode: statusCode,
Status: status + " " + http.StatusText(statusCode),
}
for _, hf := range f.RegularFields() {
key := http.CanonicalHeaderKey(hf.Name)
if key == "Trailer" {
t := res.Trailer
if t == nil {
t = make(http.Header)
res.Trailer = t
}
foreachHeaderElement(hf.Value, func(v string) {
t[http.CanonicalHeaderKey(v)] = nil
})
} else {
header[key] = append(header[key], hf.Value)
}
}
return res, nil
}
// continuation of the handleResponse function
func setLength(res *http.Response, isHead, streamEnded bool) *http.Response {
if !streamEnded || isHead {
res.ContentLength = -1
if clens := res.Header["Content-Length"]; len(clens) == 1 {
if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
res.ContentLength = clen64
} else {
// TODO: care? unlike http/1, it won't mess up our framing, so it's
// more safe smuggling-wise to ignore.
}
} else if len(clens) > 1 {
// TODO: care? unlike http/1, it won't mess up our framing, so it's
// more safe smuggling-wise to ignore.
}
}
return res
}
// copied from net/http/server.go
// foreachHeaderElement splits v according to the "#rule" construction
// in RFC 2616 section 2.1 and calls fn for each non-empty element.
func foreachHeaderElement(v string, fn func(string)) {
v = textproto.TrimString(v)
if v == "" {
return
}
if !strings.Contains(v, ",") {
fn(v)
return
}
for _, f := range strings.Split(v, ",") {
if f = textproto.TrimString(f); f != "" {
fn(f)
}
}
}
// +build go1.7
package h2quic
import "net/http"
func setUncompressed(res *http.Response) {
res.Uncompressed = true
}
// +build !go1.7
package h2quic
import "net/http"
func setUncompressed(res *http.Response) {
// http.Response.Uncompressed was introduced in go 1.7
}
...@@ -21,6 +21,7 @@ type responseWriter struct { ...@@ -21,6 +21,7 @@ type responseWriter struct {
headerStreamMutex *sync.Mutex headerStreamMutex *sync.Mutex
header http.Header header http.Header
status int // status code passed to WriteHeader
headerWritten bool headerWritten bool
} }
...@@ -43,6 +44,7 @@ func (w *responseWriter) WriteHeader(status int) { ...@@ -43,6 +44,7 @@ func (w *responseWriter) WriteHeader(status int) {
return return
} }
w.headerWritten = true w.headerWritten = true
w.status = status
var headers bytes.Buffer var headers bytes.Buffer
enc := hpack.NewEncoder(&headers) enc := hpack.NewEncoder(&headers)
...@@ -72,6 +74,9 @@ func (w *responseWriter) Write(p []byte) (int, error) { ...@@ -72,6 +74,9 @@ func (w *responseWriter) Write(p []byte) (int, error) {
if !w.headerWritten { if !w.headerWritten {
w.WriteHeader(200) w.WriteHeader(200)
} }
if !bodyAllowedForStatus(w.status) {
return 0, http.ErrBodyNotAllowed
}
return w.dataStream.Write(p) return w.dataStream.Write(p)
} }
...@@ -79,3 +84,18 @@ func (w *responseWriter) Flush() {} ...@@ -79,3 +84,18 @@ func (w *responseWriter) Flush() {}
// test that we implement http.Flusher // test that we implement http.Flusher
var _ http.Flusher = &responseWriter{} var _ http.Flusher = &responseWriter{}
// copied from http2/http2.go
// bodyAllowedForStatus reports whether a given response status code
// permits a body. See RFC 2616, section 4.4.
func bodyAllowedForStatus(status int) bool {
switch {
case status >= 100 && status <= 199:
return false
case status == 204:
return false
case status == 304:
return false
}
return true
}
package h2quic
import (
"crypto/tls"
"errors"
"fmt"
"net/http"
"strings"
"sync"
"golang.org/x/net/lex/httplex"
)
type h2quicClient interface {
Do(*http.Request) (*http.Response, error)
}
// QuicRoundTripper implements the http.RoundTripper interface
type QuicRoundTripper struct {
mutex sync.Mutex
// DisableCompression, if true, prevents the Transport from
// requesting compression with an "Accept-Encoding: gzip"
// request header when the Request contains no existing
// Accept-Encoding value. If the Transport requests gzip on
// its own and gets a gzipped response, it's transparently
// decoded in the Response.Body. However, if the user
// explicitly requested gzip it is not automatically
// uncompressed.
DisableCompression bool
// TLSClientConfig specifies the TLS configuration to use with
// tls.Client. If nil, the default configuration is used.
TLSClientConfig *tls.Config
clients map[string]h2quicClient
}
var _ http.RoundTripper = &QuicRoundTripper{}
// RoundTrip does a round trip
func (r *QuicRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if req.URL == nil {
closeRequestBody(req)
return nil, errors.New("quic: nil Request.URL")
}
if req.URL.Host == "" {
closeRequestBody(req)
return nil, errors.New("quic: no Host in request URL")
}
if req.Header == nil {
closeRequestBody(req)
return nil, errors.New("quic: nil Request.Header")
}
if req.URL.Scheme == "https" {
for k, vv := range req.Header {
if !httplex.ValidHeaderFieldName(k) {
return nil, fmt.Errorf("quic: invalid http header field name %q", k)
}
for _, v := range vv {
if !httplex.ValidHeaderFieldValue(v) {
return nil, fmt.Errorf("quic: invalid http header field value %q for key %v", v, k)
}
}
}
} else {
closeRequestBody(req)
return nil, fmt.Errorf("quic: unsupported protocol scheme: %s", req.URL.Scheme)
}
if req.Method != "" && !validMethod(req.Method) {
closeRequestBody(req)
return nil, fmt.Errorf("quic: invalid method %q", req.Method)
}
hostname := authorityAddr("https", hostnameFromRequest(req))
client, err := r.getClient(hostname)
if err != nil {
return nil, err
}
return client.Do(req)
}
func (r *QuicRoundTripper) getClient(hostname string) (h2quicClient, error) {
r.mutex.Lock()
defer r.mutex.Unlock()
if r.clients == nil {
r.clients = make(map[string]h2quicClient)
}
client, ok := r.clients[hostname]
if !ok {
var err error
client, err = NewClient(r, r.TLSClientConfig, hostname)
if err != nil {
return nil, err
}
r.clients[hostname] = client
}
return client, nil
}
func (r *QuicRoundTripper) disableCompression() bool {
return r.DisableCompression
}
func closeRequestBody(req *http.Request) {
if req.Body != nil {
req.Body.Close()
}
}
func validMethod(method string) bool {
/*
Method = "OPTIONS" ; Section 9.2
| "GET" ; Section 9.3
| "HEAD" ; Section 9.4
| "POST" ; Section 9.5
| "PUT" ; Section 9.6
| "DELETE" ; Section 9.7
| "TRACE" ; Section 9.8
| "CONNECT" ; Section 9.9
| extension-method
extension-method = token
token = 1*<any CHAR except CTLs or separators>
*/
return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
}
// copied from net/http/http.go
func isNotToken(r rune) bool {
return !httplex.IsTokenRune(r)
}
...@@ -4,7 +4,6 @@ import ( ...@@ -4,7 +4,6 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"runtime" "runtime"
...@@ -113,6 +112,7 @@ func (s *Server) handleStream(session streamCreator, stream utils.Stream) { ...@@ -113,6 +112,7 @@ func (s *Server) handleStream(session streamCreator, stream utils.Stream) {
if _, ok := err.(*qerr.QuicError); !ok { if _, ok := err.(*qerr.QuicError); !ok {
utils.Errorf("error handling h2 request: %s", err.Error()) utils.Errorf("error handling h2 request: %s", err.Error())
} }
session.Close(qerr.Error(qerr.InvalidHeadersStreamData, err.Error()))
return return
} }
} }
...@@ -124,7 +124,10 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream, ...@@ -124,7 +124,10 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream,
if err != nil { if err != nil {
return err return err
} }
h2headersFrame := h2frame.(*http2.HeadersFrame) h2headersFrame, ok := h2frame.(*http2.HeadersFrame)
if !ok {
return qerr.Error(qerr.InvalidHeadersStreamData, "expected a header frame")
}
if !h2headersFrame.HeadersEnded() { if !h2headersFrame.HeadersEnded() {
return errors.New("http2 header continuation not implemented") return errors.New("http2 header continuation not implemented")
} }
...@@ -152,13 +155,15 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream, ...@@ -152,13 +155,15 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream,
return err return err
} }
var streamEnded bool
if h2headersFrame.StreamEnded() { if h2headersFrame.StreamEnded() {
dataStream.CloseRemote(0) dataStream.CloseRemote(0)
streamEnded = true
_, _ = dataStream.Read([]byte{0}) // read the eof _, _ = dataStream.Read([]byte{0}) // read the eof
} }
// stream's Close() closes the write side, not the read side reqBody := newRequestBody(dataStream)
req.Body = ioutil.NopCloser(dataStream) req.Body = reqBody
responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID)) responseWriter := newResponseWriter(headerStream, headerStreamMutex, dataStream, protocol.StreamID(h2headersFrame.StreamID))
...@@ -187,6 +192,9 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream, ...@@ -187,6 +192,9 @@ func (s *Server) handleRequest(session streamCreator, headerStream utils.Stream,
responseWriter.WriteHeader(200) responseWriter.WriteHeader(200)
} }
if responseWriter.dataStream != nil { if responseWriter.dataStream != nil {
if !streamEnded && !reqBody.requestRead {
responseWriter.dataStream.Reset(nil)
}
responseWriter.dataStream.Close() responseWriter.dataStream.Close()
} }
if s.CloseAfterFirstRequest { if s.CloseAfterFirstRequest {
......
package handshake
import "github.com/lucas-clemente/quic-go/protocol"
// CryptoSetup is a crypto setup
type CryptoSetup interface {
HandleCryptoStream() error
Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, error)
Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte
LockForSealing()
UnlockForSealing()
HandshakeComplete() bool
// TODO: clean up this interface
DiversificationNonce() []byte // only needed for cryptoSetupServer
SetDiversificationNonce([]byte) error // only needed for cryptoSetupClient
}
...@@ -95,12 +95,19 @@ func WriteHandshakeMessage(b *bytes.Buffer, messageTag Tag, data map[Tag][]byte) ...@@ -95,12 +95,19 @@ func WriteHandshakeMessage(b *bytes.Buffer, messageTag Tag, data map[Tag][]byte)
func printHandshakeMessage(data map[Tag][]byte) string { func printHandshakeMessage(data map[Tag][]byte) string {
var res string var res string
var pad string
for k, v := range data { for k, v := range data {
if k == TagPAD { if k == TagPAD {
continue pad = fmt.Sprintf("\t%s: (%d bytes)\n", tagToString(k), len(v))
} else {
res += fmt.Sprintf("\t%s: %#v\n", tagToString(k), string(v))
} }
res += fmt.Sprintf("\t%s: %#v\n", tagToString(k), string(v))
} }
if len(pad) > 0 {
res += pad
}
return res return res
} }
......
...@@ -10,13 +10,14 @@ import ( ...@@ -10,13 +10,14 @@ import (
// ServerConfig is a server config // ServerConfig is a server config
type ServerConfig struct { type ServerConfig struct {
kex crypto.KeyExchange kex crypto.KeyExchange
signer crypto.Signer certChain crypto.CertChain
ID []byte ID []byte
obit []byte
stkSource crypto.StkSource stkSource crypto.StkSource
} }
// NewServerConfig creates a new server config // NewServerConfig creates a new server config
func NewServerConfig(kex crypto.KeyExchange, signer crypto.Signer) (*ServerConfig, error) { func NewServerConfig(kex crypto.KeyExchange, certChain crypto.CertChain) (*ServerConfig, error) {
id := make([]byte, 16) id := make([]byte, 16)
_, err := rand.Read(id) _, err := rand.Read(id)
if err != nil { if err != nil {
...@@ -27,6 +28,12 @@ func NewServerConfig(kex crypto.KeyExchange, signer crypto.Signer) (*ServerConfi ...@@ -27,6 +28,12 @@ func NewServerConfig(kex crypto.KeyExchange, signer crypto.Signer) (*ServerConfi
if _, err = rand.Read(stkSecret); err != nil { if _, err = rand.Read(stkSecret); err != nil {
return nil, err return nil, err
} }
obit := make([]byte, 8)
if _, err = rand.Read(obit); err != nil {
return nil, err
}
stkSource, err := crypto.NewStkSource(stkSecret) stkSource, err := crypto.NewStkSource(stkSecret)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -34,8 +41,9 @@ func NewServerConfig(kex crypto.KeyExchange, signer crypto.Signer) (*ServerConfi ...@@ -34,8 +41,9 @@ func NewServerConfig(kex crypto.KeyExchange, signer crypto.Signer) (*ServerConfi
return &ServerConfig{ return &ServerConfig{
kex: kex, kex: kex,
signer: signer, certChain: certChain,
ID: id, ID: id,
obit: obit,
stkSource: stkSource, stkSource: stkSource,
}, nil }, nil
} }
...@@ -48,7 +56,7 @@ func (s *ServerConfig) Get() []byte { ...@@ -48,7 +56,7 @@ func (s *ServerConfig) Get() []byte {
TagKEXS: []byte("C255"), TagKEXS: []byte("C255"),
TagAEAD: []byte("AESG"), TagAEAD: []byte("AESG"),
TagPUBS: append([]byte{0x20, 0x00, 0x00}, s.kex.PublicKey()...), TagPUBS: append([]byte{0x20, 0x00, 0x00}, s.kex.PublicKey()...),
TagOBIT: {0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7}, TagOBIT: s.obit,
TagEXPY: {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, TagEXPY: {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
}) })
return serverConfig.Bytes() return serverConfig.Bytes()
...@@ -56,10 +64,10 @@ func (s *ServerConfig) Get() []byte { ...@@ -56,10 +64,10 @@ func (s *ServerConfig) Get() []byte {
// Sign the server config and CHLO with the server's keyData // Sign the server config and CHLO with the server's keyData
func (s *ServerConfig) Sign(sni string, chlo []byte) ([]byte, error) { func (s *ServerConfig) Sign(sni string, chlo []byte) ([]byte, error) {
return s.signer.SignServerProof(sni, chlo, s.Get()) return s.certChain.SignServerProof(sni, chlo, s.Get())
} }
// GetCertsCompressed returns the certificate data // GetCertsCompressed returns the certificate data
func (s *ServerConfig) GetCertsCompressed(sni string, commonSetHashes, compressedHashes []byte) ([]byte, error) { func (s *ServerConfig) GetCertsCompressed(sni string, commonSetHashes, compressedHashes []byte) ([]byte, error) {
return s.signer.GetCertsCompressed(sni, commonSetHashes, compressedHashes) return s.certChain.GetCertsCompressed(sni, commonSetHashes, compressedHashes)
} }
package handshake
import (
"bytes"
"encoding/binary"
"errors"
"math"
"time"
"github.com/lucas-clemente/quic-go/crypto"
"github.com/lucas-clemente/quic-go/qerr"
"github.com/lucas-clemente/quic-go/utils"
)
type serverConfigClient struct {
raw []byte
ID []byte
obit []byte
expiry time.Time
kex crypto.KeyExchange
sharedSecret []byte
}
var (
errMessageNotServerConfig = errors.New("ServerConfig must have TagSCFG")
)
// parseServerConfig parses a server config
func parseServerConfig(data []byte) (*serverConfigClient, error) {
tag, tagMap, err := ParseHandshakeMessage(bytes.NewReader(data))
if err != nil {
return nil, err
}
if tag != TagSCFG {
return nil, errMessageNotServerConfig
}
scfg := &serverConfigClient{raw: data}
err = scfg.parseValues(tagMap)
if err != nil {
return nil, err
}
return scfg, nil
}
func (s *serverConfigClient) parseValues(tagMap map[Tag][]byte) error {
// SCID
scfgID, ok := tagMap[TagSCID]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "SCID")
}
if len(scfgID) != 16 {
return qerr.Error(qerr.CryptoInvalidValueLength, "SCID")
}
s.ID = scfgID
// KEXS
// TODO: allow for P256 in the list
// TODO: setup Key Exchange
kexs, ok := tagMap[TagKEXS]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "KEXS")
}
if len(kexs)%4 != 0 {
return qerr.Error(qerr.CryptoInvalidValueLength, "KEXS")
}
if !bytes.Equal(kexs, []byte("C255")) {
return qerr.Error(qerr.CryptoNoSupport, "KEXS")
}
// AEAD
aead, ok := tagMap[TagAEAD]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "AEAD")
}
if len(aead)%4 != 0 {
return qerr.Error(qerr.CryptoInvalidValueLength, "AEAD")
}
var aesgFound bool
for i := 0; i < len(aead)/4; i++ {
if bytes.Equal(aead[4*i:4*i+4], []byte("AESG")) {
aesgFound = true
break
}
}
if !aesgFound {
return qerr.Error(qerr.CryptoNoSupport, "AEAD")
}
// PUBS
// TODO: save this value
pubs, ok := tagMap[TagPUBS]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "PUBS")
}
if len(pubs) != 35 {
return qerr.Error(qerr.CryptoInvalidValueLength, "PUBS")
}
var err error
s.kex, err = crypto.NewCurve25519KEX()
if err != nil {
return err
}
// the PUBS value is always prepended by []byte{0x20, 0x00, 0x00}
s.sharedSecret, err = s.kex.CalculateSharedKey(pubs[3:])
if err != nil {
return err
}
// OBIT
obit, ok := tagMap[TagOBIT]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "OBIT")
}
if len(obit) != 8 {
return qerr.Error(qerr.CryptoInvalidValueLength, "OBIT")
}
s.obit = obit
// EXPY
expy, ok := tagMap[TagEXPY]
if !ok {
return qerr.Error(qerr.CryptoMessageParameterNotFound, "EXPY")
}
if len(expy) != 8 {
return qerr.Error(qerr.CryptoInvalidValueLength, "EXPY")
}
// make sure that the value doesn't overflow an int64
// furthermore, values close to MaxInt64 are not a valid input to time.Unix, thus set MaxInt64/2 as the maximum value here
expyTimestamp := utils.MinUint64(binary.LittleEndian.Uint64(expy), math.MaxInt64/2)
s.expiry = time.Unix(int64(expyTimestamp), 0)
// TODO: implement VER
return nil
}
func (s *serverConfigClient) IsExpired() bool {
return s.expiry.Before(time.Now())
}
func (s *serverConfigClient) Get() []byte {
return s.raw
}
...@@ -59,6 +59,8 @@ const ( ...@@ -59,6 +59,8 @@ const (
// TagNONC is the client nonce // TagNONC is the client nonce
TagNONC Tag = 'N' + 'O'<<8 + 'N'<<16 + 'C'<<24 TagNONC Tag = 'N' + 'O'<<8 + 'N'<<16 + 'C'<<24
// TagXLCT is the expected leaf certificate
TagXLCT Tag = 'X' + 'L'<<8 + 'C'<<16 + 'T'<<24
// TagSCID is the server config ID // TagSCID is the server config ID
TagSCID Tag = 'S' + 'C'<<8 + 'I'<<16 + 'D'<<24 TagSCID Tag = 'S' + 'C'<<8 + 'I'<<16 + 'D'<<24
......
...@@ -18,58 +18,72 @@ type packedPacket struct { ...@@ -18,58 +18,72 @@ type packedPacket struct {
type packetPacker struct { type packetPacker struct {
connectionID protocol.ConnectionID connectionID protocol.ConnectionID
perspective protocol.Perspective
version protocol.VersionNumber version protocol.VersionNumber
cryptoSetup *handshake.CryptoSetup cryptoSetup handshake.CryptoSetup
packetNumberGenerator *packetNumberGenerator packetNumberGenerator *packetNumberGenerator
connectionParametersManager *handshake.ConnectionParametersManager connectionParameters handshake.ConnectionParametersManager
streamFramer *streamFramer streamFramer *streamFramer
controlFrames []frames.Frame controlFrames []frames.Frame
} }
func newPacketPacker(connectionID protocol.ConnectionID, cryptoSetup *handshake.CryptoSetup, connectionParametersHandler *handshake.ConnectionParametersManager, streamFramer *streamFramer, version protocol.VersionNumber) *packetPacker { func newPacketPacker(connectionID protocol.ConnectionID, cryptoSetup handshake.CryptoSetup, connectionParameters handshake.ConnectionParametersManager, streamFramer *streamFramer, perspective protocol.Perspective, version protocol.VersionNumber) *packetPacker {
return &packetPacker{ return &packetPacker{
cryptoSetup: cryptoSetup, cryptoSetup: cryptoSetup,
connectionID: connectionID, connectionID: connectionID,
connectionParametersManager: connectionParametersHandler, connectionParameters: connectionParameters,
version: version, perspective: perspective,
streamFramer: streamFramer, version: version,
packetNumberGenerator: newPacketNumberGenerator(protocol.SkipPacketAveragePeriodLength), streamFramer: streamFramer,
packetNumberGenerator: newPacketNumberGenerator(protocol.SkipPacketAveragePeriodLength),
} }
} }
func (p *packetPacker) PackConnectionClose(frame *frames.ConnectionCloseFrame, leastUnacked protocol.PacketNumber) (*packedPacket, error) { // PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame
return p.packPacket(nil, []frames.Frame{frame}, leastUnacked, true, false) func (p *packetPacker) PackConnectionClose(ccf *frames.ConnectionCloseFrame, leastUnacked protocol.PacketNumber) (*packedPacket, error) {
// in case the connection is closed, all queued control frames aren't of any use anymore
// discard them and queue the ConnectionCloseFrame
p.controlFrames = []frames.Frame{ccf}
return p.packPacket(nil, leastUnacked)
} }
func (p *packetPacker) PackPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, leastUnacked protocol.PacketNumber, maySendOnlyAck bool) (*packedPacket, error) { // PackPacket packs a new packet
return p.packPacket(stopWaitingFrame, controlFrames, leastUnacked, false, maySendOnlyAck) // the stopWaitingFrame is *guaranteed* to be included in the next packet
// the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise
func (p *packetPacker) PackPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, leastUnacked protocol.PacketNumber) (*packedPacket, error) {
p.controlFrames = append(p.controlFrames, controlFrames...)
return p.packPacket(stopWaitingFrame, leastUnacked)
} }
func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, controlFrames []frames.Frame, leastUnacked protocol.PacketNumber, onlySendOneControlFrame, maySendOnlyAck bool) (*packedPacket, error) { func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, leastUnacked protocol.PacketNumber) (*packedPacket, error) {
if len(controlFrames) > 0 {
p.controlFrames = append(p.controlFrames, controlFrames...)
}
currentPacketNumber := p.packetNumberGenerator.Peek()
// cryptoSetup needs to be locked here, so that the AEADs are not changed between // cryptoSetup needs to be locked here, so that the AEADs are not changed between
// calling DiversificationNonce() and Seal(). // calling DiversificationNonce() and Seal().
p.cryptoSetup.LockForSealing() p.cryptoSetup.LockForSealing()
defer p.cryptoSetup.UnlockForSealing() defer p.cryptoSetup.UnlockForSealing()
currentPacketNumber := p.packetNumberGenerator.Peek()
packetNumberLen := protocol.GetPacketNumberLengthForPublicHeader(currentPacketNumber, leastUnacked) packetNumberLen := protocol.GetPacketNumberLengthForPublicHeader(currentPacketNumber, leastUnacked)
responsePublicHeader := &PublicHeader{ responsePublicHeader := &PublicHeader{
ConnectionID: p.connectionID, ConnectionID: p.connectionID,
PacketNumber: currentPacketNumber, PacketNumber: currentPacketNumber,
PacketNumberLen: packetNumberLen, PacketNumberLen: packetNumberLen,
TruncateConnectionID: p.connectionParametersManager.TruncateConnectionID(), TruncateConnectionID: p.connectionParameters.TruncateConnectionID(),
DiversificationNonce: p.cryptoSetup.DiversificationNonce(), }
if p.perspective == protocol.PerspectiveServer {
responsePublicHeader.DiversificationNonce = p.cryptoSetup.DiversificationNonce()
} }
publicHeaderLength, err := responsePublicHeader.GetLength() // TODO: stop sending version numbers once a version has been negotiated
if p.perspective == protocol.PerspectiveClient {
responsePublicHeader.VersionFlag = true
responsePublicHeader.VersionNumber = p.version
}
publicHeaderLength, err := responsePublicHeader.GetLength(p.perspective)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -79,9 +93,15 @@ func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, con ...@@ -79,9 +93,15 @@ func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, con
stopWaitingFrame.PacketNumberLen = packetNumberLen stopWaitingFrame.PacketNumberLen = packetNumberLen
} }
// we're packing a ConnectionClose, don't add any StreamFrames
var isConnectionClose bool
if len(p.controlFrames) == 1 {
_, isConnectionClose = p.controlFrames[0].(*frames.ConnectionCloseFrame)
}
var payloadFrames []frames.Frame var payloadFrames []frames.Frame
if onlySendOneControlFrame { if isConnectionClose {
payloadFrames = []frames.Frame{controlFrames[0]} payloadFrames = []frames.Frame{p.controlFrames[0]}
} else { } else {
payloadFrames, err = p.composeNextPacket(stopWaitingFrame, publicHeaderLength) payloadFrames, err = p.composeNextPacket(stopWaitingFrame, publicHeaderLength)
if err != nil { if err != nil {
...@@ -94,26 +114,14 @@ func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, con ...@@ -94,26 +114,14 @@ func (p *packetPacker) packPacket(stopWaitingFrame *frames.StopWaitingFrame, con
return nil, nil return nil, nil
} }
// Don't send out packets that only contain a StopWaitingFrame // Don't send out packets that only contain a StopWaitingFrame
if !onlySendOneControlFrame && len(payloadFrames) == 1 && stopWaitingFrame != nil { if len(payloadFrames) == 1 && stopWaitingFrame != nil {
return nil, nil return nil, nil
} }
// Don't send out packets that only contain an ACK (plus optional STOP_WAITING), if requested
if !maySendOnlyAck {
if len(payloadFrames) == 1 {
if _, ok := payloadFrames[0].(*frames.AckFrame); ok {
return nil, nil
}
} else if len(payloadFrames) == 2 && stopWaitingFrame != nil {
if _, ok := payloadFrames[1].(*frames.AckFrame); ok {
return nil, nil
}
}
}
raw := getPacketBuffer() raw := getPacketBuffer()
buffer := bytes.NewBuffer(raw) buffer := bytes.NewBuffer(raw)
if err = responsePublicHeader.WritePublicHeader(buffer, p.version); err != nil { if err = responsePublicHeader.Write(buffer, p.version, p.perspective); err != nil {
return nil, err return nil, err
} }
......
...@@ -11,10 +11,6 @@ import ( ...@@ -11,10 +11,6 @@ import (
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
) )
type unpackedPacket struct {
frames []frames.Frame
}
type packetUnpacker struct { type packetUnpacker struct {
version protocol.VersionNumber version protocol.VersionNumber
aead crypto.AEAD aead crypto.AEAD
......
package protocol
// EncryptionLevel is the encryption level
// Default value is Unencrypted
type EncryptionLevel int
const (
// Unencrypted is not encrypted
Unencrypted EncryptionLevel = iota
// EncryptionSecure is encrypted, but not forward secure
EncryptionSecure
// EncryptionForwardSecure is forward secure
EncryptionForwardSecure
)
package protocol
// Perspective determines if we're acting as a server or a client
type Perspective int
// the perspectives
const (
PerspectiveServer Perspective = 1
PerspectiveClient Perspective = 2
)
...@@ -64,3 +64,9 @@ const MaxRetransmissionTime = 60 * time.Second ...@@ -64,3 +64,9 @@ const MaxRetransmissionTime = 60 * time.Second
// ClientHelloMinimumSize is the minimum size the server expects an inchoate CHLO to have. // ClientHelloMinimumSize is the minimum size the server expects an inchoate CHLO to have.
const ClientHelloMinimumSize = 1024 const ClientHelloMinimumSize = 1024
// MaxClientHellos is the maximum number of times we'll send a client hello
// The value 3 accounts for:
// * one failure due to an incorrect or missing source-address token
// * one failure due the server's certificate chain being unavailible and the server being unwilling to send it without a valid source-address token
const MaxClientHellos = 3
...@@ -3,31 +3,48 @@ package protocol ...@@ -3,31 +3,48 @@ package protocol
import "time" import "time"
// DefaultMaxCongestionWindow is the default for the max congestion window // DefaultMaxCongestionWindow is the default for the max congestion window
const DefaultMaxCongestionWindow PacketNumber = 1000 const DefaultMaxCongestionWindow = 1000
// InitialCongestionWindow is the initial congestion window in QUIC packets // InitialCongestionWindow is the initial congestion window in QUIC packets
const InitialCongestionWindow PacketNumber = 32 const InitialCongestionWindow = 32
// MaxUndecryptablePackets limits the number of undecryptable packets that a // MaxUndecryptablePackets limits the number of undecryptable packets that a
// session queues for later until it sends a public reset. // session queues for later until it sends a public reset.
const MaxUndecryptablePackets = 10 const MaxUndecryptablePackets = 10
// AckSendDelay is the maximal time delay applied to packets containing only ACKs // AckSendDelay is the maximum delay that can be applied to an ACK for a retransmittable packet
const AckSendDelay = 5 * time.Millisecond // This is the value Chromium is using
const AckSendDelay = 25 * time.Millisecond
// ReceiveStreamFlowControlWindow is the stream-level flow control window for receiving data // ReceiveStreamFlowControlWindow is the stream-level flow control window for receiving data
// This is the value that Google servers are using // This is the value that Google servers are using
const ReceiveStreamFlowControlWindow ByteCount = (1 << 20) // 1 MB const ReceiveStreamFlowControlWindow ByteCount = (1 << 10) * 32 // 32 kB
// ReceiveConnectionFlowControlWindow is the stream-level flow control window for receiving data // ReceiveConnectionFlowControlWindow is the connection-level flow control window for receiving data
// This is the value that Google servers are using // This is the value that Google servers are using
const ReceiveConnectionFlowControlWindow ByteCount = (1 << 20) * 1.5 // 1.5 MB const ReceiveConnectionFlowControlWindow ByteCount = (1 << 10) * 48 // 48 kB
// MaxReceiveStreamFlowControlWindowServer is the maximum stream-level flow control window for receiving data
// This is the value that Google servers are using
const MaxReceiveStreamFlowControlWindowServer ByteCount = 1 * (1 << 20) // 1 MB
// MaxReceiveConnectionFlowControlWindowServer is the connection-level flow control window for receiving data
// This is the value that Google servers are using
const MaxReceiveConnectionFlowControlWindowServer ByteCount = 1.5 * (1 << 20) // 1.5 MB
// MaxReceiveStreamFlowControlWindowClient is the maximum stream-level flow control window for receiving data, for the client
// This is the value that Chromium is using
const MaxReceiveStreamFlowControlWindowClient ByteCount = 6 * (1 << 20) // 6 MB
// MaxReceiveConnectionFlowControlWindowClient is the connection-level flow control window for receiving data, for the server
// This is the value that Google servers are using
const MaxReceiveConnectionFlowControlWindowClient ByteCount = 15 * (1 << 20) // 15 MB
// MaxStreamsPerConnection is the maximum value accepted for the number of streams per connection // MaxStreamsPerConnection is the maximum value accepted for the number of streams per connection
const MaxStreamsPerConnection = 100 const MaxStreamsPerConnection = 100
// MaxIncomingDynamicStreams is the maximum value accepted for the incoming number of dynamic streams per connection // MaxIncomingDynamicStreamsPerConnection is the maximum value accepted for the incoming number of dynamic streams per connection
const MaxIncomingDynamicStreams = 100 const MaxIncomingDynamicStreamsPerConnection = 100
// MaxStreamsMultiplier is the slack the client is allowed for the maximum number of streams per connection, needed e.g. when packets are out of order or dropped. The minimum of this procentual increase and the absolute increment specified by MaxStreamsMinimumIncrement is used. // MaxStreamsMultiplier is the slack the client is allowed for the maximum number of streams per connection, needed e.g. when packets are out of order or dropped. The minimum of this procentual increase and the absolute increment specified by MaxStreamsMinimumIncrement is used.
const MaxStreamsMultiplier = 1.1 const MaxStreamsMultiplier = 1.1
...@@ -60,8 +77,17 @@ const MaxTrackedSentPackets = 2 * DefaultMaxCongestionWindow ...@@ -60,8 +77,17 @@ const MaxTrackedSentPackets = 2 * DefaultMaxCongestionWindow
// MaxTrackedReceivedPackets is the maximum number of received packets saved for doing the entropy calculations // MaxTrackedReceivedPackets is the maximum number of received packets saved for doing the entropy calculations
const MaxTrackedReceivedPackets = 2 * DefaultMaxCongestionWindow const MaxTrackedReceivedPackets = 2 * DefaultMaxCongestionWindow
// MaxTrackedReceivedAckRanges is the maximum number of ACK ranges tracked
const MaxTrackedReceivedAckRanges = DefaultMaxCongestionWindow
// MaxPacketsReceivedBeforeAckSend is the number of packets that can be received before an ACK frame is sent
const MaxPacketsReceivedBeforeAckSend = 20
// RetransmittablePacketsBeforeAck is the number of retransmittable that an ACK is sent for
const RetransmittablePacketsBeforeAck = 2
// MaxStreamFrameSorterGaps is the maximum number of gaps between received StreamFrames // MaxStreamFrameSorterGaps is the maximum number of gaps between received StreamFrames
// prevents DOS attacks against the streamFrameSorter // prevents DoS attacks against the streamFrameSorter
const MaxStreamFrameSorterGaps = 1000 const MaxStreamFrameSorterGaps = 1000
// CryptoMaxParams is the upper limit for the number of parameters in a crypto message. // CryptoMaxParams is the upper limit for the number of parameters in a crypto message.
...@@ -69,7 +95,7 @@ const MaxStreamFrameSorterGaps = 1000 ...@@ -69,7 +95,7 @@ const MaxStreamFrameSorterGaps = 1000
const CryptoMaxParams = 128 const CryptoMaxParams = 128
// CryptoParameterMaxLength is the upper limit for the length of a parameter in a crypto message. // CryptoParameterMaxLength is the upper limit for the length of a parameter in a crypto message.
const CryptoParameterMaxLength = 2000 const CryptoParameterMaxLength = 4000
// EphermalKeyLifetime is the lifetime of the ephermal key during the handshake, see handshake.getEphermalKEX. // EphermalKeyLifetime is the lifetime of the ephermal key during the handshake, see handshake.getEphermalKEX.
const EphermalKeyLifetime = time.Minute const EphermalKeyLifetime = time.Minute
...@@ -77,14 +103,21 @@ const EphermalKeyLifetime = time.Minute ...@@ -77,14 +103,21 @@ const EphermalKeyLifetime = time.Minute
// InitialIdleTimeout is the timeout before the handshake succeeds. // InitialIdleTimeout is the timeout before the handshake succeeds.
const InitialIdleTimeout = 5 * time.Second const InitialIdleTimeout = 5 * time.Second
// DefaultIdleTimeout is the default idle timeout. // DefaultIdleTimeout is the default idle timeout, for the server
const DefaultIdleTimeout = 30 * time.Second const DefaultIdleTimeout = 30 * time.Second
// MaxIdleTimeout is the maximum idle timeout that can be negotiated. // MaxIdleTimeoutServer is the maximum idle timeout that can be negotiated, for the server
const MaxIdleTimeout = 1 * time.Minute const MaxIdleTimeoutServer = 1 * time.Minute
// MaxIdleTimeoutClient is the idle timeout that the client suggests to the server
const MaxIdleTimeoutClient = 2 * time.Minute
// MaxTimeForCryptoHandshake is the default timeout for a connection until the crypto handshake succeeds. // MaxTimeForCryptoHandshake is the default timeout for a connection until the crypto handshake succeeds.
const MaxTimeForCryptoHandshake = 10 * time.Second const MaxTimeForCryptoHandshake = 10 * time.Second
// ClosedSessionDeleteTimeout the server ignores packets arriving on a connection that is already closed
// after this time all information about the old connection will be deleted
const ClosedSessionDeleteTimeout = time.Minute
// NumCachedCertificates is the number of cached compressed certificate chains, each taking ~1K space // NumCachedCertificates is the number of cached compressed certificate chains, each taking ~1K space
const NumCachedCertificates = 128 const NumCachedCertificates = 128
...@@ -14,10 +14,12 @@ const ( ...@@ -14,10 +14,12 @@ const (
Version34 VersionNumber = 34 + iota Version34 VersionNumber = 34 + iota
Version35 Version35
Version36 Version36
VersionWhatever = 0 // for when the version doesn't matter VersionWhatever = 0 // for when the version doesn't matter
VersionUnsupported = -1
) )
// SupportedVersions lists the versions that the server supports // SupportedVersions lists the versions that the server supports
// must be in sorted order
var SupportedVersions = []VersionNumber{ var SupportedVersions = []VersionNumber{
Version34, Version35, Version36, Version34, Version35, Version36,
} }
...@@ -49,6 +51,28 @@ func IsSupportedVersion(v VersionNumber) bool { ...@@ -49,6 +51,28 @@ func IsSupportedVersion(v VersionNumber) bool {
return false return false
} }
// HighestSupportedVersion finds the highest version number that is both present in other and in SupportedVersions
// the versions in other do not need to be ordered
// it returns true and the version number, if there is one, otherwise false
func HighestSupportedVersion(other []VersionNumber) (bool, VersionNumber) {
var otherSupported []VersionNumber
for _, ver := range other {
if ver != VersionUnsupported {
otherSupported = append(otherSupported, ver)
}
}
for i := len(SupportedVersions) - 1; i >= 0; i-- {
for _, ver := range otherSupported {
if ver == SupportedVersions[i] {
return true, ver
}
}
}
return false, 0
}
func init() { func init() {
var b bytes.Buffer var b bytes.Buffer
for _, v := range SupportedVersions { for _, v := range SupportedVersions {
......
...@@ -3,7 +3,6 @@ package quic ...@@ -3,7 +3,6 @@ package quic
import ( import (
"bytes" "bytes"
"errors" "errors"
"io"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/qerr" "github.com/lucas-clemente/quic-go/qerr"
...@@ -11,11 +10,11 @@ import ( ...@@ -11,11 +10,11 @@ import (
) )
var ( var (
errPacketNumberLenNotSet = errors.New("PublicHeader: PacketNumberLen not set") errPacketNumberLenNotSet = errors.New("PublicHeader: PacketNumberLen not set")
errResetAndVersionFlagSet = errors.New("PublicHeader: Reset Flag and Version Flag should not be set at the same time") errResetAndVersionFlagSet = errors.New("PublicHeader: Reset Flag and Version Flag should not be set at the same time")
errReceivedTruncatedConnectionID = qerr.Error(qerr.InvalidPacketHeader, "receiving packets with truncated ConnectionID is not supported") errReceivedTruncatedConnectionID = qerr.Error(qerr.InvalidPacketHeader, "receiving packets with truncated ConnectionID is not supported")
errInvalidConnectionID = qerr.Error(qerr.InvalidPacketHeader, "connection ID cannot be 0") errInvalidConnectionID = qerr.Error(qerr.InvalidPacketHeader, "connection ID cannot be 0")
errGetLengthOnlyForRegularPackets = errors.New("PublicHeader: GetLength can only be called for regular packets") errGetLengthNotForVersionNegotiation = errors.New("PublicHeader: GetLength cannot be called for VersionNegotiation packets")
) )
// The PublicHeader of a QUIC packet // The PublicHeader of a QUIC packet
...@@ -27,16 +26,19 @@ type PublicHeader struct { ...@@ -27,16 +26,19 @@ type PublicHeader struct {
TruncateConnectionID bool TruncateConnectionID bool
PacketNumberLen protocol.PacketNumberLen PacketNumberLen protocol.PacketNumberLen
PacketNumber protocol.PacketNumber PacketNumber protocol.PacketNumber
VersionNumber protocol.VersionNumber VersionNumber protocol.VersionNumber // VersionNumber sent by the client
SupportedVersions []protocol.VersionNumber // VersionNumbers sent by the server
DiversificationNonce []byte DiversificationNonce []byte
} }
// WritePublicHeader writes a public header // Write writes a public header
func (h *PublicHeader) WritePublicHeader(b *bytes.Buffer, version protocol.VersionNumber) error { func (h *PublicHeader) Write(b *bytes.Buffer, version protocol.VersionNumber, pers protocol.Perspective) error {
publicFlagByte := uint8(0x00) publicFlagByte := uint8(0x00)
if h.VersionFlag && h.ResetFlag { if h.VersionFlag && h.ResetFlag {
return errResetAndVersionFlagSet return errResetAndVersionFlagSet
} }
if h.VersionFlag { if h.VersionFlag {
publicFlagByte |= 0x01 publicFlagByte |= 0x01
} }
...@@ -54,7 +56,8 @@ func (h *PublicHeader) WritePublicHeader(b *bytes.Buffer, version protocol.Versi ...@@ -54,7 +56,8 @@ func (h *PublicHeader) WritePublicHeader(b *bytes.Buffer, version protocol.Versi
publicFlagByte |= 0x04 publicFlagByte |= 0x04
} }
if !h.ResetFlag && !h.VersionFlag { // only set PacketNumberLen bits if a packet number will be written
if h.hasPacketNumber(pers) {
switch h.PacketNumberLen { switch h.PacketNumberLen {
case protocol.PacketNumberLen1: case protocol.PacketNumberLen1:
publicFlagByte |= 0x00 publicFlagByte |= 0x00
...@@ -73,30 +76,42 @@ func (h *PublicHeader) WritePublicHeader(b *bytes.Buffer, version protocol.Versi ...@@ -73,30 +76,42 @@ func (h *PublicHeader) WritePublicHeader(b *bytes.Buffer, version protocol.Versi
utils.WriteUint64(b, uint64(h.ConnectionID)) utils.WriteUint64(b, uint64(h.ConnectionID))
} }
if h.VersionFlag && pers == protocol.PerspectiveClient {
utils.WriteUint32(b, protocol.VersionNumberToTag(h.VersionNumber))
}
if len(h.DiversificationNonce) > 0 { if len(h.DiversificationNonce) > 0 {
b.Write(h.DiversificationNonce) b.Write(h.DiversificationNonce)
} }
if !h.ResetFlag && !h.VersionFlag { // if we're a server, and the VersionFlag is set, we must not include anything else in the packet
switch h.PacketNumberLen { if !h.hasPacketNumber(pers) {
case protocol.PacketNumberLen1: return nil
b.WriteByte(uint8(h.PacketNumber)) }
case protocol.PacketNumberLen2:
utils.WriteUint16(b, uint16(h.PacketNumber)) if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 && h.PacketNumberLen != protocol.PacketNumberLen6 {
case protocol.PacketNumberLen4: return errPacketNumberLenNotSet
utils.WriteUint32(b, uint32(h.PacketNumber)) }
case protocol.PacketNumberLen6:
utils.WriteUint48(b, uint64(h.PacketNumber)) switch h.PacketNumberLen {
default: case protocol.PacketNumberLen1:
return errPacketNumberLenNotSet b.WriteByte(uint8(h.PacketNumber))
} case protocol.PacketNumberLen2:
utils.WriteUint16(b, uint16(h.PacketNumber))
case protocol.PacketNumberLen4:
utils.WriteUint32(b, uint32(h.PacketNumber))
case protocol.PacketNumberLen6:
utils.WriteUint48(b, uint64(h.PacketNumber))
default:
return errPacketNumberLenNotSet
} }
return nil return nil
} }
// ParsePublicHeader parses a QUIC packet's public header // ParsePublicHeader parses a QUIC packet's public header
func ParsePublicHeader(b io.ByteReader) (*PublicHeader, error) { // the packetSentBy is the perspective of the peer that sent this PublicHeader, i.e. if we're the server, packetSentBy should be PerspectiveClient
func ParsePublicHeader(b *bytes.Reader, packetSentBy protocol.Perspective) (*PublicHeader, error) {
header := &PublicHeader{} header := &PublicHeader{}
// First byte // First byte
...@@ -117,15 +132,17 @@ func ParsePublicHeader(b io.ByteReader) (*PublicHeader, error) { ...@@ -117,15 +132,17 @@ func ParsePublicHeader(b io.ByteReader) (*PublicHeader, error) {
return nil, errReceivedTruncatedConnectionID return nil, errReceivedTruncatedConnectionID
} }
switch publicFlagByte & 0x30 { if header.hasPacketNumber(packetSentBy) {
case 0x30: switch publicFlagByte & 0x30 {
header.PacketNumberLen = protocol.PacketNumberLen6 case 0x30:
case 0x20: header.PacketNumberLen = protocol.PacketNumberLen6
header.PacketNumberLen = protocol.PacketNumberLen4 case 0x20:
case 0x10: header.PacketNumberLen = protocol.PacketNumberLen4
header.PacketNumberLen = protocol.PacketNumberLen2 case 0x10:
case 0x00: header.PacketNumberLen = protocol.PacketNumberLen2
header.PacketNumberLen = protocol.PacketNumberLen1 case 0x00:
header.PacketNumberLen = protocol.PacketNumberLen1
}
} }
// Connection ID // Connection ID
...@@ -133,46 +150,111 @@ func ParsePublicHeader(b io.ByteReader) (*PublicHeader, error) { ...@@ -133,46 +150,111 @@ func ParsePublicHeader(b io.ByteReader) (*PublicHeader, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
header.ConnectionID = protocol.ConnectionID(connID) header.ConnectionID = protocol.ConnectionID(connID)
if header.ConnectionID == 0 { if header.ConnectionID == 0 {
return nil, errInvalidConnectionID return nil, errInvalidConnectionID
} }
if packetSentBy == protocol.PerspectiveServer && publicFlagByte&0x04 > 0 {
// TODO: remove the if once the Google servers send the correct value
// assume that a packet doesn't contain a diversification nonce if the version flag or the reset flag is set, no matter what the public flag says
// see https://github.com/lucas-clemente/quic-go/issues/232
if !header.VersionFlag && !header.ResetFlag {
header.DiversificationNonce = make([]byte, 32)
// this Read can never return an EOF for a valid packet, since the diversification nonce is followed by the packet number
_, err = b.Read(header.DiversificationNonce)
if err != nil {
return nil, err
}
}
}
// Version (optional) // Version (optional)
if header.VersionFlag { if !header.ResetFlag {
var versionTag uint32 if header.VersionFlag {
versionTag, err = utils.ReadUint32(b) if packetSentBy == protocol.PerspectiveClient {
if err != nil { var versionTag uint32
return nil, err versionTag, err = utils.ReadUint32(b)
if err != nil {
return nil, err
}
header.VersionNumber = protocol.VersionTagToNumber(versionTag)
} else { // parse the version negotiaton packet
if b.Len()%4 != 0 {
return nil, qerr.InvalidVersionNegotiationPacket
}
header.SupportedVersions = make([]protocol.VersionNumber, 0)
for {
var versionTag uint32
versionTag, err = utils.ReadUint32(b)
if err != nil {
break
}
v := protocol.VersionTagToNumber(versionTag)
if !protocol.IsSupportedVersion(v) {
v = protocol.VersionUnsupported
}
header.SupportedVersions = append(header.SupportedVersions, v)
}
}
} }
header.VersionNumber = protocol.VersionTagToNumber(versionTag)
} }
// Packet number // Packet number
packetNumber, err := utils.ReadUintN(b, uint8(header.PacketNumberLen)) if header.hasPacketNumber(packetSentBy) {
if err != nil { packetNumber, err := utils.ReadUintN(b, uint8(header.PacketNumberLen))
return nil, err if err != nil {
return nil, err
}
header.PacketNumber = protocol.PacketNumber(packetNumber)
} }
header.PacketNumber = protocol.PacketNumber(packetNumber)
return header, nil return header, nil
} }
// GetLength gets the length of the publicHeader in bytes // GetLength gets the length of the publicHeader in bytes
// can only be called for regular packets // can only be called for regular packets
func (h *PublicHeader) GetLength() (protocol.ByteCount, error) { func (h *PublicHeader) GetLength(pers protocol.Perspective) (protocol.ByteCount, error) {
if h.VersionFlag || h.ResetFlag { if h.VersionFlag && h.ResetFlag {
return 0, errGetLengthOnlyForRegularPackets return 0, errResetAndVersionFlagSet
}
if h.VersionFlag && pers == protocol.PerspectiveServer {
return 0, errGetLengthNotForVersionNegotiation
} }
length := protocol.ByteCount(1) // 1 byte for public flags length := protocol.ByteCount(1) // 1 byte for public flags
if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 && h.PacketNumberLen != protocol.PacketNumberLen6 {
return 0, errPacketNumberLenNotSet if h.hasPacketNumber(pers) {
if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 && h.PacketNumberLen != protocol.PacketNumberLen6 {
return 0, errPacketNumberLenNotSet
}
length += protocol.ByteCount(h.PacketNumberLen)
} }
if !h.TruncateConnectionID { if !h.TruncateConnectionID {
length += 8 // 8 bytes for the connection ID length += 8 // 8 bytes for the connection ID
} }
// Version Number in packets sent by the client
if h.VersionFlag {
length += 4
}
length += protocol.ByteCount(len(h.DiversificationNonce)) length += protocol.ByteCount(len(h.DiversificationNonce))
length += protocol.ByteCount(h.PacketNumberLen)
return length, nil return length, nil
} }
// hasPacketNumber determines if this PublicHeader will contain a packet number
// this depends on the ResetFlag, the VersionFlag and who sent the packet
func (h *PublicHeader) hasPacketNumber(packetSentBy protocol.Perspective) bool {
if h.ResetFlag {
return false
}
if h.VersionFlag && packetSentBy == protocol.PerspectiveServer {
return false
}
return true
}
...@@ -2,12 +2,19 @@ package quic ...@@ -2,12 +2,19 @@ package quic
import ( import (
"bytes" "bytes"
"encoding/binary"
"errors"
"github.com/lucas-clemente/quic-go/handshake" "github.com/lucas-clemente/quic-go/handshake"
"github.com/lucas-clemente/quic-go/protocol" "github.com/lucas-clemente/quic-go/protocol"
"github.com/lucas-clemente/quic-go/utils" "github.com/lucas-clemente/quic-go/utils"
) )
type publicReset struct {
rejectedPacketNumber protocol.PacketNumber
nonce uint64
}
func writePublicReset(connectionID protocol.ConnectionID, rejectedPacketNumber protocol.PacketNumber, nonceProof uint64) []byte { func writePublicReset(connectionID protocol.ConnectionID, rejectedPacketNumber protocol.PacketNumber, nonceProof uint64) []byte {
b := &bytes.Buffer{} b := &bytes.Buffer{}
b.WriteByte(0x0a) b.WriteByte(0x0a)
...@@ -22,3 +29,34 @@ func writePublicReset(connectionID protocol.ConnectionID, rejectedPacketNumber p ...@@ -22,3 +29,34 @@ func writePublicReset(connectionID protocol.ConnectionID, rejectedPacketNumber p
utils.WriteUint64(b, uint64(rejectedPacketNumber)) utils.WriteUint64(b, uint64(rejectedPacketNumber))
return b.Bytes() return b.Bytes()
} }
func parsePublicReset(r *bytes.Reader) (*publicReset, error) {
pr := publicReset{}
tag, tagMap, err := handshake.ParseHandshakeMessage(r)
if err != nil {
return nil, err
}
if tag != handshake.TagPRST {
return nil, errors.New("wrong public reset tag")
}
rseq, ok := tagMap[handshake.TagRSEQ]
if !ok {
return nil, errors.New("RSEQ missing")
}
if len(rseq) != 8 {
return nil, errors.New("invalid RSEQ tag")
}
pr.rejectedPacketNumber = protocol.PacketNumber(binary.LittleEndian.Uint64(rseq))
rnon, ok := tagMap[handshake.TagRNON]
if !ok {
return nil, errors.New("RNON missing")
}
if len(rnon) != 8 {
return nil, errors.New("invalid RNON tag")
}
pr.nonce = binary.LittleEndian.Uint64(rnon)
return &pr, nil
}
...@@ -3,6 +3,7 @@ package quic ...@@ -3,6 +3,7 @@ package quic
import ( import (
"bytes" "bytes"
"crypto/tls" "crypto/tls"
"errors"
"net" "net"
"strings" "strings"
"sync" "sync"
...@@ -18,6 +19,7 @@ import ( ...@@ -18,6 +19,7 @@ import (
// packetHandler handles packets // packetHandler handles packets
type packetHandler interface { type packetHandler interface {
handlePacket(*receivedPacket) handlePacket(*receivedPacket)
OpenStream(protocol.StreamID) (utils.Stream, error)
run() run()
Close(error) error Close(error) error
} }
...@@ -29,11 +31,12 @@ type Server struct { ...@@ -29,11 +31,12 @@ type Server struct {
conn *net.UDPConn conn *net.UDPConn
connMutex sync.Mutex connMutex sync.Mutex
signer crypto.Signer certChain crypto.CertChain
scfg *handshake.ServerConfig scfg *handshake.ServerConfig
sessions map[protocol.ConnectionID]packetHandler sessions map[protocol.ConnectionID]packetHandler
sessionsMutex sync.RWMutex sessionsMutex sync.RWMutex
deleteClosedSessionsAfter time.Duration
streamCallback StreamCallback streamCallback StreamCallback
...@@ -42,16 +45,13 @@ type Server struct { ...@@ -42,16 +45,13 @@ type Server struct {
// NewServer makes a new server // NewServer makes a new server
func NewServer(addr string, tlsConfig *tls.Config, cb StreamCallback) (*Server, error) { func NewServer(addr string, tlsConfig *tls.Config, cb StreamCallback) (*Server, error) {
signer, err := crypto.NewProofSource(tlsConfig) certChain := crypto.NewCertChain(tlsConfig)
if err != nil {
return nil, err
}
kex, err := crypto.NewCurve25519KEX() kex, err := crypto.NewCurve25519KEX()
if err != nil { if err != nil {
return nil, err return nil, err
} }
scfg, err := handshake.NewServerConfig(kex, signer) scfg, err := handshake.NewServerConfig(kex, certChain)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -62,12 +62,13 @@ func NewServer(addr string, tlsConfig *tls.Config, cb StreamCallback) (*Server, ...@@ -62,12 +62,13 @@ func NewServer(addr string, tlsConfig *tls.Config, cb StreamCallback) (*Server,
} }
return &Server{ return &Server{
addr: udpAddr, addr: udpAddr,
signer: signer, certChain: certChain,
scfg: scfg, scfg: scfg,
streamCallback: cb, streamCallback: cb,
sessions: map[protocol.ConnectionID]packetHandler{}, sessions: map[protocol.ConnectionID]packetHandler{},
newSession: newSession, newSession: newSession,
deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout,
}, nil }, nil
} }
...@@ -135,12 +136,39 @@ func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet ...@@ -135,12 +136,39 @@ func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet
r := bytes.NewReader(packet) r := bytes.NewReader(packet)
hdr, err := ParsePublicHeader(r) hdr, err := ParsePublicHeader(r, protocol.PerspectiveClient)
if err != nil { if err != nil {
return qerr.Error(qerr.InvalidPacketHeader, err.Error()) return qerr.Error(qerr.InvalidPacketHeader, err.Error())
} }
hdr.Raw = packet[:len(packet)-r.Len()] hdr.Raw = packet[:len(packet)-r.Len()]
s.sessionsMutex.RLock()
session, ok := s.sessions[hdr.ConnectionID]
s.sessionsMutex.RUnlock()
// ignore all Public Reset packets
if hdr.ResetFlag {
if ok {
var pr *publicReset
pr, err = parsePublicReset(r)
if err != nil {
utils.Infof("Received a Public Reset for connection %x. An error occurred parsing the packet.")
} else {
utils.Infof("Received a Public Reset for connection %x, rejected packet number: 0x%x.", hdr.ConnectionID, pr.rejectedPacketNumber)
}
} else {
utils.Infof("Received Public Reset for unknown connection %x.", hdr.ConnectionID)
}
return nil
}
// a session is only created once the client sent a supported version
// if we receive a packet for a connection that already has session, it's probably an old packet that was sent by the client before the version was negotiated
// it is safe to drop it
if ok && hdr.VersionFlag && !protocol.IsSupportedVersion(hdr.VersionNumber) {
return nil
}
// Send Version Negotiation Packet if the client is speaking a different protocol version // Send Version Negotiation Packet if the client is speaking a different protocol version
if hdr.VersionFlag && !protocol.IsSupportedVersion(hdr.VersionNumber) { if hdr.VersionFlag && !protocol.IsSupportedVersion(hdr.VersionNumber) {
utils.Infof("Client offered version %d, sending VersionNegotiationPacket", hdr.VersionNumber) utils.Infof("Client offered version %d, sending VersionNegotiationPacket", hdr.VersionNumber)
...@@ -148,15 +176,20 @@ func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet ...@@ -148,15 +176,20 @@ func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet
return err return err
} }
s.sessionsMutex.RLock()
session, ok := s.sessions[hdr.ConnectionID]
s.sessionsMutex.RUnlock()
if !ok { if !ok {
utils.Infof("Serving new connection: %x, version %d from %v", hdr.ConnectionID, hdr.VersionNumber, remoteAddr) if !hdr.VersionFlag {
_, err = conn.WriteToUDP(writePublicReset(hdr.ConnectionID, hdr.PacketNumber, 0), remoteAddr)
return err
}
version := hdr.VersionNumber
if !protocol.IsSupportedVersion(version) {
return errors.New("Server BUG: negotiated version not supported")
}
utils.Infof("Serving new connection: %x, version %d from %v", hdr.ConnectionID, version, remoteAddr)
session, err = s.newSession( session, err = s.newSession(
&udpConn{conn: conn, currentAddr: remoteAddr}, &udpConn{conn: conn, currentAddr: remoteAddr},
hdr.VersionNumber, version,
hdr.ConnectionID, hdr.ConnectionID,
s.scfg, s.scfg,
s.streamCallback, s.streamCallback,
...@@ -187,6 +220,12 @@ func (s *Server) closeCallback(id protocol.ConnectionID) { ...@@ -187,6 +220,12 @@ func (s *Server) closeCallback(id protocol.ConnectionID) {
s.sessionsMutex.Lock() s.sessionsMutex.Lock()
s.sessions[id] = nil s.sessions[id] = nil
s.sessionsMutex.Unlock() s.sessionsMutex.Unlock()
time.AfterFunc(s.deleteClosedSessionsAfter, func() {
s.sessionsMutex.Lock()
delete(s.sessions, id)
s.sessionsMutex.Unlock()
})
} }
func composeVersionNegotiation(connectionID protocol.ConnectionID) []byte { func composeVersionNegotiation(connectionID protocol.ConnectionID) []byte {
...@@ -196,7 +235,7 @@ func composeVersionNegotiation(connectionID protocol.ConnectionID) []byte { ...@@ -196,7 +235,7 @@ func composeVersionNegotiation(connectionID protocol.ConnectionID) []byte {
PacketNumber: 1, PacketNumber: 1,
VersionFlag: true, VersionFlag: true,
} }
err := responsePublicHeader.WritePublicHeader(fullReply, protocol.Version35) err := responsePublicHeader.Write(fullReply, protocol.Version35, protocol.PerspectiveServer)
if err != nil { if err != nil {
utils.Errorf("error composing version negotiation packet: %s", err.Error()) utils.Errorf("error composing version negotiation packet: %s", err.Error())
} }
......
...@@ -119,7 +119,7 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res [] ...@@ -119,7 +119,7 @@ func (f *streamFramer) maybePopNormalFrames(maxBytes protocol.ByteCount) (res []
if f.flowControlManager.RemainingConnectionWindowSize() == 0 { if f.flowControlManager.RemainingConnectionWindowSize() == 0 {
// We are now connection-level FC blocked // We are now connection-level FC blocked
f.blockedFrameQueue = append(f.blockedFrameQueue, &frames.BlockedFrame{StreamID: 0}) f.blockedFrameQueue = append(f.blockedFrameQueue, &frames.BlockedFrame{StreamID: 0})
} else if sendWindowSize-frame.DataLen() == 0 { } else if !frame.FinBit && sendWindowSize-frame.DataLen() == 0 {
// We are now stream-level FC blocked // We are now stream-level FC blocked
f.blockedFrameQueue = append(f.blockedFrameQueue, &frames.BlockedFrame{StreamID: s.StreamID()}) f.blockedFrameQueue = append(f.blockedFrameQueue, &frames.BlockedFrame{StreamID: s.StreamID()})
} }
......
package quic package quic
import "net" import (
"net"
"sync"
)
type connection interface { type connection interface {
write([]byte) error write([]byte) error
...@@ -9,6 +12,8 @@ type connection interface { ...@@ -9,6 +12,8 @@ type connection interface {
} }
type udpConn struct { type udpConn struct {
mutex sync.RWMutex
conn *net.UDPConn conn *net.UDPConn
currentAddr *net.UDPAddr currentAddr *net.UDPAddr
} }
...@@ -21,9 +26,14 @@ func (c *udpConn) write(p []byte) error { ...@@ -21,9 +26,14 @@ func (c *udpConn) write(p []byte) error {
} }
func (c *udpConn) setCurrentRemoteAddr(addr interface{}) { func (c *udpConn) setCurrentRemoteAddr(addr interface{}) {
c.mutex.Lock()
c.currentAddr = addr.(*net.UDPAddr) c.currentAddr = addr.(*net.UDPAddr)
c.mutex.Unlock()
} }
func (c *udpConn) RemoteAddr() *net.UDPAddr { func (c *udpConn) RemoteAddr() *net.UDPAddr {
return c.currentAddr c.mutex.RLock()
addr := c.currentAddr
c.mutex.RUnlock()
return addr
} }
package quic
import "github.com/lucas-clemente/quic-go/frames"
type unpackedPacket struct {
frames []frames.Frame
}
func (u *unpackedPacket) IsRetransmittable() bool {
for _, f := range u.frames {
switch f.(type) {
case *frames.StreamFrame:
return true
case *frames.RstStreamFrame:
return true
case *frames.WindowUpdateFrame:
return true
case *frames.BlockedFrame:
return true
case *frames.PingFrame:
return true
case *frames.GoawayFrame:
return true
}
}
return false
}
...@@ -34,6 +34,14 @@ func MaxUint64(a, b uint64) uint64 { ...@@ -34,6 +34,14 @@ func MaxUint64(a, b uint64) uint64 {
return a return a
} }
// MinUint64 returns the maximum of two uint64
func MinUint64(a, b uint64) uint64 {
if a < b {
return a
}
return b
}
// Min returns the minimum of two Ints // Min returns the minimum of two Ints
func Min(a, b int) int { func Min(a, b int) int {
if a < b { if a < b {
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment