Commit 806bf0e8 authored by ginuerzh's avatar ginuerzh

update vendor

parent d61407c7
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
.vscode
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
*.test
*.prof
language: go
go:
- "1.8.x"
- "1.9.x"
- "1.10.x"
env:
- TRAVIS_GOARCH=amd64
- TRAVIS_GOARCH=386
before_install:
- export GOARCH=$TRAVIS_GOARCH
branches:
only:
- master
before_script:
- go get -u github.com/klauspost/asmfmt/cmd/asmfmt
script:
- diff -au <(gofmt -d .) <(printf "")
- diff -au <(asmfmt -d .) <(printf "")
- go test -v ./...
The MIT License (MIT)
Copyright (c) 2016 Andreas Auernhammer
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
[![Godoc Reference](https://godoc.org/github.com/aead/chacha20?status.svg)](https://godoc.org/github.com/aead/chacha20)
[![Build Status](https://travis-ci.org/aead/chacha20.svg?branch=master)](https://travis-ci.org/aead/chacha20)
[![Go Report Card](https://goreportcard.com/badge/aead/chacha20)](https://goreportcard.com/report/aead/chacha20)
## The ChaCha20 stream cipher
ChaCha is a stream cipher family created by Daniel J. Bernstein.
The most common ChaCha variant is ChaCha20 (20 rounds). ChaCha20 is
standardized in [RFC 7539](https://tools.ietf.org/html/rfc7539 "RFC 7539").
This package provides implementations of three ChaCha versions:
- ChaCha20 with a 64 bit nonce (can en/decrypt up to 2^64 * 64 bytes for one key-nonce combination)
- ChaCha20 with a 96 bit nonce (can en/decrypt up to 2^32 * 64 bytes ~ 256 GB for one key-nonce combination)
- XChaCha20 with a 192 bit nonce (can en/decrypt up to 2^64 * 64 bytes for one key-nonce combination)
Furthermore the chacha sub package implements ChaCha20/12 and ChaCha20/8.
These versions use 12 or 8 rounds instead of 20.
But it's recommended to use ChaCha20 (with 20 rounds) - it will be fast enough for almost all purposes.
### Installation
Install in your GOPATH: `go get -u github.com/aead/chacha20`
### Requirements
All go versions >= 1.8.7 are supported.
The code may also work on Go 1.7 but this is not tested.
### Performance
#### AMD64
Hardware: Intel i7-6500U 2.50GHz x 2
System: Linux Ubuntu 16.04 - kernel: 4.4.0-62-generic
Go version: 1.8.0
```
AVX2
name speed cpb
ChaCha20_64-4 573MB/s ± 0% 4.16
ChaCha20_1K-4 2.19GB/s ± 0% 1.06
XChaCha20_64-4 261MB/s ± 0% 9.13
XChaCha20_1K-4 1.69GB/s ± 4% 1.37
XORKeyStream64-4 474MB/s ± 2% 5.02
XORKeyStream1K-4 2.09GB/s ± 1% 1.11
XChaCha20_XORKeyStream64-4 262MB/s ± 0% 9.09
XChaCha20_XORKeyStream1K-4 1.71GB/s ± 1% 1.36
SSSE3
name speed cpb
ChaCha20_64-4 583MB/s ± 0% 4.08
ChaCha20_1K-4 1.15GB/s ± 1% 2.02
XChaCha20_64-4 267MB/s ± 0% 8.92
XChaCha20_1K-4 984MB/s ± 5% 2.42
XORKeyStream64-4 492MB/s ± 1% 4.84
XORKeyStream1K-4 1.10GB/s ± 5% 2.11
XChaCha20_XORKeyStream64-4 266MB/s ± 0% 8.96
XChaCha20_XORKeyStream1K-4 1.00GB/s ± 2% 2.32
```
#### 386
Hardware: Intel i7-6500U 2.50GHz x 2
System: Linux Ubuntu 16.04 - kernel: 4.4.0-62-generic
Go version: 1.8.0
```
SSSE3
name                        speed cpb
ChaCha20_64-4               570MB/s ± 0% 4.18
ChaCha20_1K-4               650MB/s ± 0% 3.66
XChaCha20_64-4              223MB/s ± 0% 10.69
XChaCha20_1K-4              584MB/s ± 1% 4.08
XORKeyStream64-4            392MB/s ± 1% 6.08
XORKeyStream1K-4            629MB/s ± 1% 3.79
XChaCha20_XORKeyStream64-4  222MB/s ± 0% 10.73
XChaCha20_XORKeyStream1K-4  585MB/s ± 0% 4.07
SSE2
name speed cpb
ChaCha20_64-4 509MB/s ± 0% 4.68
ChaCha20_1K-4 553MB/s ± 2% 4.31
XChaCha20_64-4 201MB/s ± 0% 11.86
XChaCha20_1K-4 498MB/s ± 4% 4.78
XORKeyStream64-4 359MB/s ± 1% 6.64
XORKeyStream1K-4 545MB/s ± 0% 4.37
XChaCha20_XORKeyStream64-4 201MB/s ± 1% 11.86
XChaCha20_XORKeyStream1K-4 507MB/s ± 0% 4.70
```
// Copyright (c) 2016 Andreas Auernhammer. All rights reserved.
// Use of this source code is governed by a license that can be
// found in the LICENSE file.
// Package chacha implements some low-level functions of the
// ChaCha cipher family.
package chacha // import "github.com/aead/chacha20/chacha"
import (
"encoding/binary"
"errors"
"math"
)
const (
// NonceSize is the size of the ChaCha20 nonce in bytes.
NonceSize = 8
// INonceSize is the size of the IETF-ChaCha20 nonce in bytes.
INonceSize = 12
// XNonceSize is the size of the XChaCha20 nonce in bytes.
XNonceSize = 24
// KeySize is the size of the key in bytes.
KeySize = 32
)
var (
useSSE2 bool
useSSSE3 bool
useAVX bool
useAVX2 bool
)
var (
errKeySize = errors.New("chacha20/chacha: bad key length")
errInvalidNonce = errors.New("chacha20/chacha: bad nonce length")
)
func setup(state *[64]byte, nonce, key []byte) (err error) {
if len(key) != KeySize {
err = errKeySize
return
}
var Nonce [16]byte
switch len(nonce) {
case NonceSize:
copy(Nonce[8:], nonce)
initialize(state, key, &Nonce)
case INonceSize:
copy(Nonce[4:], nonce)
initialize(state, key, &Nonce)
case XNonceSize:
var tmpKey [32]byte
var hNonce [16]byte
copy(hNonce[:], nonce[:16])
copy(tmpKey[:], key)
HChaCha20(&tmpKey, &hNonce, &tmpKey)
copy(Nonce[8:], nonce[16:])
initialize(state, tmpKey[:], &Nonce)
// BUG(aead): A "good" compiler will remove this (optimizations)
// But using the provided key instead of tmpKey,
// will change the key (-> probably confuses users)
for i := range tmpKey {
tmpKey[i] = 0
}
default:
err = errInvalidNonce
}
return
}
// XORKeyStream crypts bytes from src to dst using the given nonce and key.
// The length of the nonce determinds the version of ChaCha20:
// - NonceSize: ChaCha20/r with a 64 bit nonce and a 2^64 * 64 byte period.
// - INonceSize: ChaCha20/r as defined in RFC 7539 and a 2^32 * 64 byte period.
// - XNonceSize: XChaCha20/r with a 192 bit nonce and a 2^64 * 64 byte period.
// The rounds argument specifies the number of rounds performed for keystream
// generation - valid values are 8, 12 or 20. The src and dst may be the same slice
// but otherwise should not overlap. If len(dst) < len(src) this function panics.
// If the nonce is neither 64, 96 nor 192 bits long, this function panics.
func XORKeyStream(dst, src, nonce, key []byte, rounds int) {
if rounds != 20 && rounds != 12 && rounds != 8 {
panic("chacha20/chacha: bad number of rounds")
}
if len(dst) < len(src) {
panic("chacha20/chacha: dst buffer is to small")
}
if len(nonce) == INonceSize && uint64(len(src)) > (1<<38) {
panic("chacha20/chacha: src is too large")
}
var block, state [64]byte
if err := setup(&state, nonce, key); err != nil {
panic(err)
}
xorKeyStream(dst, src, &block, &state, rounds)
}
// Cipher implements ChaCha20/r (XChaCha20/r) for a given number of rounds r.
type Cipher struct {
state, block [64]byte
off int
rounds int // 20 for ChaCha20
noncesize int
}
// NewCipher returns a new *chacha.Cipher implementing the ChaCha20/r or XChaCha20/r
// (r = 8, 12 or 20) stream cipher. The nonce must be unique for one key for all time.
// The length of the nonce determinds the version of ChaCha20:
// - NonceSize: ChaCha20/r with a 64 bit nonce and a 2^64 * 64 byte period.
// - INonceSize: ChaCha20/r as defined in RFC 7539 and a 2^32 * 64 byte period.
// - XNonceSize: XChaCha20/r with a 192 bit nonce and a 2^64 * 64 byte period.
// If the nonce is neither 64, 96 nor 192 bits long, a non-nil error is returned.
func NewCipher(nonce, key []byte, rounds int) (*Cipher, error) {
if rounds != 20 && rounds != 12 && rounds != 8 {
panic("chacha20/chacha: bad number of rounds")
}
c := new(Cipher)
if err := setup(&(c.state), nonce, key); err != nil {
return nil, err
}
c.rounds = rounds
if len(nonce) == INonceSize {
c.noncesize = INonceSize
} else {
c.noncesize = NonceSize
}
return c, nil
}
// XORKeyStream crypts bytes from src to dst. Src and dst may be the same slice
// but otherwise should not overlap. If len(dst) < len(src) the function panics.
func (c *Cipher) XORKeyStream(dst, src []byte) {
if len(dst) < len(src) {
panic("chacha20/chacha: dst buffer is to small")
}
if c.off > 0 {
n := len(c.block[c.off:])
if len(src) <= n {
for i, v := range src {
dst[i] = v ^ c.block[c.off]
c.off++
}
if c.off == 64 {
c.off = 0
}
return
}
for i, v := range c.block[c.off:] {
dst[i] = src[i] ^ v
}
src = src[n:]
dst = dst[n:]
c.off = 0
}
// check for counter overflow
blocksToXOR := len(src) / 64
if len(src)%64 != 0 {
blocksToXOR++
}
var overflow bool
if c.noncesize == INonceSize {
overflow = binary.LittleEndian.Uint32(c.state[48:]) > math.MaxUint32-uint32(blocksToXOR)
} else {
overflow = binary.LittleEndian.Uint64(c.state[48:]) > math.MaxUint64-uint64(blocksToXOR)
}
if overflow {
panic("chacha20/chacha: counter overflow")
}
c.off += xorKeyStream(dst, src, &(c.block), &(c.state), c.rounds)
}
// SetCounter skips ctr * 64 byte blocks. SetCounter(0) resets the cipher.
// This function always skips the unused keystream of the current 64 byte block.
func (c *Cipher) SetCounter(ctr uint64) {
if c.noncesize == INonceSize {
binary.LittleEndian.PutUint32(c.state[48:], uint32(ctr))
} else {
binary.LittleEndian.PutUint64(c.state[48:], ctr)
}
c.off = 0
}
// HChaCha20 generates 32 pseudo-random bytes from a 128 bit nonce and a 256 bit secret key.
// It can be used as a key-derivation-function (KDF).
func HChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) { hChaCha20(out, nonce, key) }
// Copyright (c) 2016 Andreas Auernhammer. All rights reserved.
// Use of this source code is governed by a license that can be
// found in the LICENSE file.
// +build amd64,!gccgo,!appengine,!nacl
#include "const.s"
#include "macro.s"
#define TWO 0(SP)
#define C16 32(SP)
#define C8 64(SP)
#define STATE_0 96(SP)
#define STATE_1 128(SP)
#define STATE_2 160(SP)
#define STATE_3 192(SP)
#define TMP_0 224(SP)
#define TMP_1 256(SP)
// func xorKeyStreamAVX(dst, src []byte, block, state *[64]byte, rounds int) int
TEXT ·xorKeyStreamAVX2(SB), 4, $320-80
MOVQ dst_base+0(FP), DI
MOVQ src_base+24(FP), SI
MOVQ block+48(FP), BX
MOVQ state+56(FP), AX
MOVQ rounds+64(FP), DX
MOVQ src_len+32(FP), CX
MOVQ SP, R8
ADDQ $32, SP
ANDQ $-32, SP
VMOVDQU 0(AX), Y2
VMOVDQU 32(AX), Y3
VPERM2I128 $0x22, Y2, Y0, Y0
VPERM2I128 $0x33, Y2, Y1, Y1
VPERM2I128 $0x22, Y3, Y2, Y2
VPERM2I128 $0x33, Y3, Y3, Y3
TESTQ CX, CX
JZ done
VMOVDQU ·one_AVX2<>(SB), Y4
VPADDD Y4, Y3, Y3
VMOVDQA Y0, STATE_0
VMOVDQA Y1, STATE_1
VMOVDQA Y2, STATE_2
VMOVDQA Y3, STATE_3
VMOVDQU ·rol16_AVX2<>(SB), Y4
VMOVDQU ·rol8_AVX2<>(SB), Y5
VMOVDQU ·two_AVX2<>(SB), Y6
VMOVDQA Y4, Y14
VMOVDQA Y5, Y15
VMOVDQA Y4, C16
VMOVDQA Y5, C8
VMOVDQA Y6, TWO
CMPQ CX, $64
JBE between_0_and_64
CMPQ CX, $192
JBE between_64_and_192
CMPQ CX, $320
JBE between_192_and_320
CMPQ CX, $448
JBE between_320_and_448
at_least_512:
VMOVDQA Y0, Y4
VMOVDQA Y1, Y5
VMOVDQA Y2, Y6
VPADDQ TWO, Y3, Y7
VMOVDQA Y0, Y8
VMOVDQA Y1, Y9
VMOVDQA Y2, Y10
VPADDQ TWO, Y7, Y11
VMOVDQA Y0, Y12
VMOVDQA Y1, Y13
VMOVDQA Y2, Y14
VPADDQ TWO, Y11, Y15
MOVQ DX, R9
chacha_loop_512:
VMOVDQA Y8, TMP_0
CHACHA_QROUND_AVX(Y0, Y1, Y2, Y3, Y8, C16, C8)
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y8, C16, C8)
VMOVDQA TMP_0, Y8
VMOVDQA Y0, TMP_0
CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y0, C16, C8)
CHACHA_QROUND_AVX(Y12, Y13, Y14, Y15, Y0, C16, C8)
CHACHA_SHUFFLE_AVX(Y1, Y2, Y3)
CHACHA_SHUFFLE_AVX(Y5, Y6, Y7)
CHACHA_SHUFFLE_AVX(Y9, Y10, Y11)
CHACHA_SHUFFLE_AVX(Y13, Y14, Y15)
CHACHA_QROUND_AVX(Y12, Y13, Y14, Y15, Y0, C16, C8)
CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y0, C16, C8)
VMOVDQA TMP_0, Y0
VMOVDQA Y8, TMP_0
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y8, C16, C8)
CHACHA_QROUND_AVX(Y0, Y1, Y2, Y3, Y8, C16, C8)
VMOVDQA TMP_0, Y8
CHACHA_SHUFFLE_AVX(Y3, Y2, Y1)
CHACHA_SHUFFLE_AVX(Y7, Y6, Y5)
CHACHA_SHUFFLE_AVX(Y11, Y10, Y9)
CHACHA_SHUFFLE_AVX(Y15, Y14, Y13)
SUBQ $2, R9
JA chacha_loop_512
VMOVDQA Y12, TMP_0
VMOVDQA Y13, TMP_1
VPADDD STATE_0, Y0, Y0
VPADDD STATE_1, Y1, Y1
VPADDD STATE_2, Y2, Y2
VPADDD STATE_3, Y3, Y3
XOR_AVX2(DI, SI, 0, Y0, Y1, Y2, Y3, Y12, Y13)
VMOVDQA STATE_0, Y0
VMOVDQA STATE_1, Y1
VMOVDQA STATE_2, Y2
VMOVDQA STATE_3, Y3
VPADDQ TWO, Y3, Y3
VPADDD Y0, Y4, Y4
VPADDD Y1, Y5, Y5
VPADDD Y2, Y6, Y6
VPADDD Y3, Y7, Y7
XOR_AVX2(DI, SI, 128, Y4, Y5, Y6, Y7, Y12, Y13)
VPADDQ TWO, Y3, Y3
VPADDD Y0, Y8, Y8
VPADDD Y1, Y9, Y9
VPADDD Y2, Y10, Y10
VPADDD Y3, Y11, Y11
XOR_AVX2(DI, SI, 256, Y8, Y9, Y10, Y11, Y12, Y13)
VPADDQ TWO, Y3, Y3
VPADDD TMP_0, Y0, Y12
VPADDD TMP_1, Y1, Y13
VPADDD Y2, Y14, Y14
VPADDD Y3, Y15, Y15
VPADDQ TWO, Y3, Y3
CMPQ CX, $512
JB less_than_512
XOR_AVX2(DI, SI, 384, Y12, Y13, Y14, Y15, Y4, Y5)
VMOVDQA Y3, STATE_3
ADDQ $512, SI
ADDQ $512, DI
SUBQ $512, CX
CMPQ CX, $448
JA at_least_512
TESTQ CX, CX
JZ done
VMOVDQA C16, Y14
VMOVDQA C8, Y15
CMPQ CX, $64
JBE between_0_and_64
CMPQ CX, $192
JBE between_64_and_192
CMPQ CX, $320
JBE between_192_and_320
JMP between_320_and_448
less_than_512:
XOR_UPPER_AVX2(DI, SI, 384, Y12, Y13, Y14, Y15, Y4, Y5)
EXTRACT_LOWER(BX, Y12, Y13, Y14, Y15, Y4)
ADDQ $448, SI
ADDQ $448, DI
SUBQ $448, CX
JMP finalize
between_320_and_448:
VMOVDQA Y0, Y4
VMOVDQA Y1, Y5
VMOVDQA Y2, Y6
VPADDQ TWO, Y3, Y7
VMOVDQA Y0, Y8
VMOVDQA Y1, Y9
VMOVDQA Y2, Y10
VPADDQ TWO, Y7, Y11
MOVQ DX, R9
chacha_loop_384:
CHACHA_QROUND_AVX(Y0, Y1, Y2, Y3, Y13, Y14, Y15)
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
CHACHA_SHUFFLE_AVX(Y1, Y2, Y3)
CHACHA_SHUFFLE_AVX(Y5, Y6, Y7)
CHACHA_SHUFFLE_AVX(Y9, Y10, Y11)
CHACHA_QROUND_AVX(Y0, Y1, Y2, Y3, Y13, Y14, Y15)
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
CHACHA_SHUFFLE_AVX(Y3, Y2, Y1)
CHACHA_SHUFFLE_AVX(Y7, Y6, Y5)
CHACHA_SHUFFLE_AVX(Y11, Y10, Y9)
SUBQ $2, R9
JA chacha_loop_384
VPADDD STATE_0, Y0, Y0
VPADDD STATE_1, Y1, Y1
VPADDD STATE_2, Y2, Y2
VPADDD STATE_3, Y3, Y3
XOR_AVX2(DI, SI, 0, Y0, Y1, Y2, Y3, Y12, Y13)
VMOVDQA STATE_0, Y0
VMOVDQA STATE_1, Y1
VMOVDQA STATE_2, Y2
VMOVDQA STATE_3, Y3
VPADDQ TWO, Y3, Y3
VPADDD Y0, Y4, Y4
VPADDD Y1, Y5, Y5
VPADDD Y2, Y6, Y6
VPADDD Y3, Y7, Y7
XOR_AVX2(DI, SI, 128, Y4, Y5, Y6, Y7, Y12, Y13)
VPADDQ TWO, Y3, Y3
VPADDD Y0, Y8, Y8
VPADDD Y1, Y9, Y9
VPADDD Y2, Y10, Y10
VPADDD Y3, Y11, Y11
VPADDQ TWO, Y3, Y3
CMPQ CX, $384
JB less_than_384
XOR_AVX2(DI, SI, 256, Y8, Y9, Y10, Y11, Y12, Y13)
SUBQ $384, CX
TESTQ CX, CX
JE done
ADDQ $384, SI
ADDQ $384, DI
JMP between_0_and_64
less_than_384:
XOR_UPPER_AVX2(DI, SI, 256, Y8, Y9, Y10, Y11, Y12, Y13)
EXTRACT_LOWER(BX, Y8, Y9, Y10, Y11, Y12)
ADDQ $320, SI
ADDQ $320, DI
SUBQ $320, CX
JMP finalize
between_192_and_320:
VMOVDQA Y0, Y4
VMOVDQA Y1, Y5
VMOVDQA Y2, Y6
VMOVDQA Y3, Y7
VMOVDQA Y0, Y8
VMOVDQA Y1, Y9
VMOVDQA Y2, Y10
VPADDQ TWO, Y3, Y11
MOVQ DX, R9
chacha_loop_256:
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
CHACHA_SHUFFLE_AVX(Y5, Y6, Y7)
CHACHA_SHUFFLE_AVX(Y9, Y10, Y11)
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_QROUND_AVX(Y8, Y9, Y10, Y11, Y13, Y14, Y15)
CHACHA_SHUFFLE_AVX(Y7, Y6, Y5)
CHACHA_SHUFFLE_AVX(Y11, Y10, Y9)
SUBQ $2, R9
JA chacha_loop_256
VPADDD Y0, Y4, Y4
VPADDD Y1, Y5, Y5
VPADDD Y2, Y6, Y6
VPADDD Y3, Y7, Y7
VPADDQ TWO, Y3, Y3
XOR_AVX2(DI, SI, 0, Y4, Y5, Y6, Y7, Y12, Y13)
VPADDD Y0, Y8, Y8
VPADDD Y1, Y9, Y9
VPADDD Y2, Y10, Y10
VPADDD Y3, Y11, Y11
VPADDQ TWO, Y3, Y3
CMPQ CX, $256
JB less_than_256
XOR_AVX2(DI, SI, 128, Y8, Y9, Y10, Y11, Y12, Y13)
SUBQ $256, CX
TESTQ CX, CX
JE done
ADDQ $256, SI
ADDQ $256, DI
JMP between_0_and_64
less_than_256:
XOR_UPPER_AVX2(DI, SI, 128, Y8, Y9, Y10, Y11, Y12, Y13)
EXTRACT_LOWER(BX, Y8, Y9, Y10, Y11, Y12)
ADDQ $192, SI
ADDQ $192, DI
SUBQ $192, CX
JMP finalize
between_64_and_192:
VMOVDQA Y0, Y4
VMOVDQA Y1, Y5
VMOVDQA Y2, Y6
VMOVDQA Y3, Y7
MOVQ DX, R9
chacha_loop_128:
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_SHUFFLE_AVX(Y5, Y6, Y7)
CHACHA_QROUND_AVX(Y4, Y5, Y6, Y7, Y13, Y14, Y15)
CHACHA_SHUFFLE_AVX(Y7, Y6, Y5)
SUBQ $2, R9
JA chacha_loop_128
VPADDD Y0, Y4, Y4
VPADDD Y1, Y5, Y5
VPADDD Y2, Y6, Y6
VPADDD Y3, Y7, Y7
VPADDQ TWO, Y3, Y3
CMPQ CX, $128
JB less_than_128
XOR_AVX2(DI, SI, 0, Y4, Y5, Y6, Y7, Y12, Y13)
SUBQ $128, CX
TESTQ CX, CX
JE done
ADDQ $128, SI
ADDQ $128, DI
JMP between_0_and_64
less_than_128:
XOR_UPPER_AVX2(DI, SI, 0, Y4, Y5, Y6, Y7, Y12, Y13)
EXTRACT_LOWER(BX, Y4, Y5, Y6, Y7, Y13)
ADDQ $64, SI
ADDQ $64, DI
SUBQ $64, CX
JMP finalize
between_0_and_64:
VMOVDQA X0, X4
VMOVDQA X1, X5
VMOVDQA X2, X6
VMOVDQA X3, X7
MOVQ DX, R9
chacha_loop_64:
CHACHA_QROUND_AVX(X4, X5, X6, X7, X13, X14, X15)
CHACHA_SHUFFLE_AVX(X5, X6, X7)
CHACHA_QROUND_AVX(X4, X5, X6, X7, X13, X14, X15)
CHACHA_SHUFFLE_AVX(X7, X6, X5)
SUBQ $2, R9
JA chacha_loop_64
VPADDD X0, X4, X4
VPADDD X1, X5, X5
VPADDD X2, X6, X6
VPADDD X3, X7, X7
VMOVDQU ·one<>(SB), X0
VPADDQ X0, X3, X3
CMPQ CX, $64
JB less_than_64
XOR_AVX(DI, SI, 0, X4, X5, X6, X7, X13)
SUBQ $64, CX
JMP done
less_than_64:
VMOVDQU X4, 0(BX)
VMOVDQU X5, 16(BX)
VMOVDQU X6, 32(BX)
VMOVDQU X7, 48(BX)
finalize:
XORQ R11, R11
XORQ R12, R12
MOVQ CX, BP
xor_loop:
MOVB 0(SI), R11
MOVB 0(BX), R12
XORQ R11, R12
MOVB R12, 0(DI)
INCQ SI
INCQ BX
INCQ DI
DECQ BP
JA xor_loop
done:
VMOVDQU X3, 48(AX)
VZEROUPPER
MOVQ R8, SP
MOVQ CX, ret+72(FP)
RET
// Copyright (c) 2016 Andreas Auernhammer. All rights reserved.
// Use of this source code is governed by a license that can be
// found in the LICENSE file.
// +build 386,!gccgo,!appengine,!nacl
package chacha
import (
"encoding/binary"
"golang.org/x/sys/cpu"
)
func init() {
useSSE2 = cpu.X86.HasSSE2
useSSSE3 = cpu.X86.HasSSSE3
useAVX = false
useAVX2 = false
}
func initialize(state *[64]byte, key []byte, nonce *[16]byte) {
binary.LittleEndian.PutUint32(state[0:], sigma[0])
binary.LittleEndian.PutUint32(state[4:], sigma[1])
binary.LittleEndian.PutUint32(state[8:], sigma[2])
binary.LittleEndian.PutUint32(state[12:], sigma[3])
copy(state[16:], key[:])
copy(state[48:], nonce[:])
}
// This function is implemented in chacha_386.s
//go:noescape
func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte)
// This function is implemented in chacha_386.s
//go:noescape
func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte)
// This function is implemented in chacha_386.s
//go:noescape
func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int
func hChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) {
switch {
case useSSSE3:
hChaCha20SSSE3(out, nonce, key)
case useSSE2:
hChaCha20SSE2(out, nonce, key)
default:
hChaCha20Generic(out, nonce, key)
}
}
func xorKeyStream(dst, src []byte, block, state *[64]byte, rounds int) int {
if useSSE2 {
return xorKeyStreamSSE2(dst, src, block, state, rounds)
} else {
return xorKeyStreamGeneric(dst, src, block, state, rounds)
}
}
// Copyright (c) 2016 Andreas Auernhammer. All rights reserved.
// Use of this source code is governed by a license that can be
// found in the LICENSE file.
// +build 386,!gccgo,!appengine,!nacl
#include "const.s"
#include "macro.s"
// FINALIZE xors len bytes from src and block using
// the temp. registers t0 and t1 and writes the result
// to dst.
#define FINALIZE(dst, src, block, len, t0, t1) \
XORL t0, t0; \
XORL t1, t1; \
FINALIZE_LOOP:; \
MOVB 0(src), t0; \
MOVB 0(block), t1; \
XORL t0, t1; \
MOVB t1, 0(dst); \
INCL src; \
INCL block; \
INCL dst; \
DECL len; \
JG FINALIZE_LOOP \
#define Dst DI
#define Nonce AX
#define Key BX
#define Rounds DX
// func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte)
TEXT ·hChaCha20SSE2(SB), 4, $0-12
MOVL out+0(FP), Dst
MOVL nonce+4(FP), Nonce
MOVL key+8(FP), Key
MOVOU ·sigma<>(SB), X0
MOVOU 0*16(Key), X1
MOVOU 1*16(Key), X2
MOVOU 0*16(Nonce), X3
MOVL $20, Rounds
chacha_loop:
CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4)
CHACHA_SHUFFLE_SSE(X1, X2, X3)
CHACHA_QROUND_SSE2(X0, X1, X2, X3, X4)
CHACHA_SHUFFLE_SSE(X3, X2, X1)
SUBL $2, Rounds
JNZ chacha_loop
MOVOU X0, 0*16(Dst)
MOVOU X3, 1*16(Dst)
RET
// func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte)
TEXT ·hChaCha20SSSE3(SB), 4, $0-12
MOVL out+0(FP), Dst
MOVL nonce+4(FP), Nonce
MOVL key+8(FP), Key
MOVOU ·sigma<>(SB), X0
MOVOU 0*16(Key), X1
MOVOU 1*16(Key), X2
MOVOU 0*16(Nonce), X3
MOVL $20, Rounds
MOVOU ·rol16<>(SB), X5
MOVOU ·rol8<>(SB), X6
chacha_loop:
CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6)
CHACHA_SHUFFLE_SSE(X1, X2, X3)
CHACHA_QROUND_SSSE3(X0, X1, X2, X3, X4, X5, X6)
CHACHA_SHUFFLE_SSE(X3, X2, X1)
SUBL $2, Rounds
JNZ chacha_loop
MOVOU X0, 0*16(Dst)
MOVOU X3, 1*16(Dst)
RET
#undef Dst
#undef Nonce
#undef Key
#undef Rounds
#define State AX
#define Dst DI
#define Src SI
#define Len DX
#define Tmp0 BX
#define Tmp1 BP
// func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int
TEXT ·xorKeyStreamSSE2(SB), 4, $0-40
MOVL dst_base+0(FP), Dst
MOVL src_base+12(FP), Src
MOVL state+28(FP), State
MOVL src_len+16(FP), Len
MOVL $0, ret+36(FP) // Number of bytes written to the keystream buffer - 0 iff len mod 64 == 0
MOVOU 0*16(State), X0
MOVOU 1*16(State), X1
MOVOU 2*16(State), X2
MOVOU 3*16(State), X3
TESTL Len, Len
JZ DONE
GENERATE_KEYSTREAM:
MOVO X0, X4
MOVO X1, X5
MOVO X2, X6
MOVO X3, X7
MOVL rounds+32(FP), Tmp0
CHACHA_LOOP:
CHACHA_QROUND_SSE2(X4, X5, X6, X7, X0)
CHACHA_SHUFFLE_SSE(X5, X6, X7)
CHACHA_QROUND_SSE2(X4, X5, X6, X7, X0)
CHACHA_SHUFFLE_SSE(X7, X6, X5)
SUBL $2, Tmp0
JA CHACHA_LOOP
MOVOU 0*16(State), X0 // Restore X0 from state
PADDL X0, X4
PADDL X1, X5
PADDL X2, X6
PADDL X3, X7
MOVOU ·one<>(SB), X0
PADDQ X0, X3
CMPL Len, $64
JL BUFFER_KEYSTREAM
XOR_SSE(Dst, Src, 0, X4, X5, X6, X7, X0)
MOVOU 0*16(State), X0 // Restore X0 from state
ADDL $64, Src
ADDL $64, Dst
SUBL $64, Len
JZ DONE
JMP GENERATE_KEYSTREAM // There is at least one more plaintext byte
BUFFER_KEYSTREAM:
MOVL block+24(FP), State
MOVOU X4, 0(State)
MOVOU X5, 16(State)
MOVOU X6, 32(State)
MOVOU X7, 48(State)
MOVL Len, ret+36(FP) // Number of bytes written to the keystream buffer - 0 < Len < 64
FINALIZE(Dst, Src, State, Len, Tmp0, Tmp1)
DONE:
MOVL state+28(FP), State
MOVOU X3, 3*16(State)
RET
#undef State
#undef Dst
#undef Src
#undef Len
#undef Tmp0
#undef Tmp1
// Copyright (c) 2017 Andreas Auernhammer. All rights reserved.
// Use of this source code is governed by a license that can be
// found in the LICENSE file.
// +build go1.7,amd64,!gccgo,!appengine,!nacl
package chacha
import "golang.org/x/sys/cpu"
func init() {
useSSE2 = cpu.X86.HasSSE2
useSSSE3 = cpu.X86.HasSSSE3
useAVX = cpu.X86.HasAVX
useAVX2 = cpu.X86.HasAVX2
}
// This function is implemented in chacha_amd64.s
//go:noescape
func initialize(state *[64]byte, key []byte, nonce *[16]byte)
// This function is implemented in chacha_amd64.s
//go:noescape
func hChaCha20SSE2(out *[32]byte, nonce *[16]byte, key *[32]byte)
// This function is implemented in chacha_amd64.s
//go:noescape
func hChaCha20SSSE3(out *[32]byte, nonce *[16]byte, key *[32]byte)
// This function is implemented in chachaAVX2_amd64.s
//go:noescape
func hChaCha20AVX(out *[32]byte, nonce *[16]byte, key *[32]byte)
// This function is implemented in chacha_amd64.s
//go:noescape
func xorKeyStreamSSE2(dst, src []byte, block, state *[64]byte, rounds int) int
// This function is implemented in chacha_amd64.s
//go:noescape
func xorKeyStreamSSSE3(dst, src []byte, block, state *[64]byte, rounds int) int
// This function is implemented in chacha_amd64.s
//go:noescape
func xorKeyStreamAVX(dst, src []byte, block, state *[64]byte, rounds int) int
// This function is implemented in chachaAVX2_amd64.s
//go:noescape
func xorKeyStreamAVX2(dst, src []byte, block, state *[64]byte, rounds int) int
func hChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) {
switch {
case useAVX:
hChaCha20AVX(out, nonce, key)
case useSSSE3:
hChaCha20SSSE3(out, nonce, key)
case useSSE2:
hChaCha20SSE2(out, nonce, key)
default:
hChaCha20Generic(out, nonce, key)
}
}
func xorKeyStream(dst, src []byte, block, state *[64]byte, rounds int) int {
switch {
case useAVX2:
return xorKeyStreamAVX2(dst, src, block, state, rounds)
case useAVX:
return xorKeyStreamAVX(dst, src, block, state, rounds)
case useSSSE3:
return xorKeyStreamSSSE3(dst, src, block, state, rounds)
case useSSE2:
return xorKeyStreamSSE2(dst, src, block, state, rounds)
default:
return xorKeyStreamGeneric(dst, src, block, state, rounds)
}
}
This diff is collapsed.
// Copyright (c) 2016 Andreas Auernhammer. All rights reserved.
// Use of this source code is governed by a license that can be
// found in the LICENSE file.
package chacha
import "encoding/binary"
var sigma = [4]uint32{0x61707865, 0x3320646e, 0x79622d32, 0x6b206574}
func xorKeyStreamGeneric(dst, src []byte, block, state *[64]byte, rounds int) int {
for len(src) >= 64 {
chachaGeneric(block, state, rounds)
for i, v := range block {
dst[i] = src[i] ^ v
}
src = src[64:]
dst = dst[64:]
}
n := len(src)
if n > 0 {
chachaGeneric(block, state, rounds)
for i, v := range src {
dst[i] = v ^ block[i]
}
}
return n
}
func chachaGeneric(dst *[64]byte, state *[64]byte, rounds int) {
v00 := binary.LittleEndian.Uint32(state[0:])
v01 := binary.LittleEndian.Uint32(state[4:])
v02 := binary.LittleEndian.Uint32(state[8:])
v03 := binary.LittleEndian.Uint32(state[12:])
v04 := binary.LittleEndian.Uint32(state[16:])
v05 := binary.LittleEndian.Uint32(state[20:])
v06 := binary.LittleEndian.Uint32(state[24:])
v07 := binary.LittleEndian.Uint32(state[28:])
v08 := binary.LittleEndian.Uint32(state[32:])
v09 := binary.LittleEndian.Uint32(state[36:])
v10 := binary.LittleEndian.Uint32(state[40:])
v11 := binary.LittleEndian.Uint32(state[44:])
v12 := binary.LittleEndian.Uint32(state[48:])
v13 := binary.LittleEndian.Uint32(state[52:])
v14 := binary.LittleEndian.Uint32(state[56:])
v15 := binary.LittleEndian.Uint32(state[60:])
s00, s01, s02, s03, s04, s05, s06, s07 := v00, v01, v02, v03, v04, v05, v06, v07
s08, s09, s10, s11, s12, s13, s14, s15 := v08, v09, v10, v11, v12, v13, v14, v15
for i := 0; i < rounds; i += 2 {
v00 += v04
v12 ^= v00
v12 = (v12 << 16) | (v12 >> 16)
v08 += v12
v04 ^= v08
v04 = (v04 << 12) | (v04 >> 20)
v00 += v04
v12 ^= v00
v12 = (v12 << 8) | (v12 >> 24)
v08 += v12
v04 ^= v08
v04 = (v04 << 7) | (v04 >> 25)
v01 += v05
v13 ^= v01
v13 = (v13 << 16) | (v13 >> 16)
v09 += v13
v05 ^= v09
v05 = (v05 << 12) | (v05 >> 20)
v01 += v05
v13 ^= v01
v13 = (v13 << 8) | (v13 >> 24)
v09 += v13
v05 ^= v09
v05 = (v05 << 7) | (v05 >> 25)
v02 += v06
v14 ^= v02
v14 = (v14 << 16) | (v14 >> 16)
v10 += v14
v06 ^= v10
v06 = (v06 << 12) | (v06 >> 20)
v02 += v06
v14 ^= v02
v14 = (v14 << 8) | (v14 >> 24)
v10 += v14
v06 ^= v10
v06 = (v06 << 7) | (v06 >> 25)
v03 += v07
v15 ^= v03
v15 = (v15 << 16) | (v15 >> 16)
v11 += v15
v07 ^= v11
v07 = (v07 << 12) | (v07 >> 20)
v03 += v07
v15 ^= v03
v15 = (v15 << 8) | (v15 >> 24)
v11 += v15
v07 ^= v11
v07 = (v07 << 7) | (v07 >> 25)
v00 += v05
v15 ^= v00
v15 = (v15 << 16) | (v15 >> 16)
v10 += v15
v05 ^= v10
v05 = (v05 << 12) | (v05 >> 20)
v00 += v05
v15 ^= v00
v15 = (v15 << 8) | (v15 >> 24)
v10 += v15
v05 ^= v10
v05 = (v05 << 7) | (v05 >> 25)
v01 += v06
v12 ^= v01
v12 = (v12 << 16) | (v12 >> 16)
v11 += v12
v06 ^= v11
v06 = (v06 << 12) | (v06 >> 20)
v01 += v06
v12 ^= v01
v12 = (v12 << 8) | (v12 >> 24)
v11 += v12
v06 ^= v11
v06 = (v06 << 7) | (v06 >> 25)
v02 += v07
v13 ^= v02
v13 = (v13 << 16) | (v13 >> 16)
v08 += v13
v07 ^= v08
v07 = (v07 << 12) | (v07 >> 20)
v02 += v07
v13 ^= v02
v13 = (v13 << 8) | (v13 >> 24)
v08 += v13
v07 ^= v08
v07 = (v07 << 7) | (v07 >> 25)
v03 += v04
v14 ^= v03
v14 = (v14 << 16) | (v14 >> 16)
v09 += v14
v04 ^= v09
v04 = (v04 << 12) | (v04 >> 20)
v03 += v04
v14 ^= v03
v14 = (v14 << 8) | (v14 >> 24)
v09 += v14
v04 ^= v09
v04 = (v04 << 7) | (v04 >> 25)
}
v00 += s00
v01 += s01
v02 += s02
v03 += s03
v04 += s04
v05 += s05
v06 += s06
v07 += s07
v08 += s08
v09 += s09
v10 += s10
v11 += s11
v12 += s12
v13 += s13
v14 += s14
v15 += s15
s12++
binary.LittleEndian.PutUint32(state[48:], s12)
if s12 == 0 { // indicates overflow
s13++
binary.LittleEndian.PutUint32(state[52:], s13)
}
binary.LittleEndian.PutUint32(dst[0:], v00)
binary.LittleEndian.PutUint32(dst[4:], v01)
binary.LittleEndian.PutUint32(dst[8:], v02)
binary.LittleEndian.PutUint32(dst[12:], v03)
binary.LittleEndian.PutUint32(dst[16:], v04)
binary.LittleEndian.PutUint32(dst[20:], v05)
binary.LittleEndian.PutUint32(dst[24:], v06)
binary.LittleEndian.PutUint32(dst[28:], v07)
binary.LittleEndian.PutUint32(dst[32:], v08)
binary.LittleEndian.PutUint32(dst[36:], v09)
binary.LittleEndian.PutUint32(dst[40:], v10)
binary.LittleEndian.PutUint32(dst[44:], v11)
binary.LittleEndian.PutUint32(dst[48:], v12)
binary.LittleEndian.PutUint32(dst[52:], v13)
binary.LittleEndian.PutUint32(dst[56:], v14)
binary.LittleEndian.PutUint32(dst[60:], v15)
}
func hChaCha20Generic(out *[32]byte, nonce *[16]byte, key *[32]byte) {
v00 := sigma[0]
v01 := sigma[1]
v02 := sigma[2]
v03 := sigma[3]
v04 := binary.LittleEndian.Uint32(key[0:])
v05 := binary.LittleEndian.Uint32(key[4:])
v06 := binary.LittleEndian.Uint32(key[8:])
v07 := binary.LittleEndian.Uint32(key[12:])
v08 := binary.LittleEndian.Uint32(key[16:])
v09 := binary.LittleEndian.Uint32(key[20:])
v10 := binary.LittleEndian.Uint32(key[24:])
v11 := binary.LittleEndian.Uint32(key[28:])
v12 := binary.LittleEndian.Uint32(nonce[0:])
v13 := binary.LittleEndian.Uint32(nonce[4:])
v14 := binary.LittleEndian.Uint32(nonce[8:])
v15 := binary.LittleEndian.Uint32(nonce[12:])
for i := 0; i < 20; i += 2 {
v00 += v04
v12 ^= v00
v12 = (v12 << 16) | (v12 >> 16)
v08 += v12
v04 ^= v08
v04 = (v04 << 12) | (v04 >> 20)
v00 += v04
v12 ^= v00
v12 = (v12 << 8) | (v12 >> 24)
v08 += v12
v04 ^= v08
v04 = (v04 << 7) | (v04 >> 25)
v01 += v05
v13 ^= v01
v13 = (v13 << 16) | (v13 >> 16)
v09 += v13
v05 ^= v09
v05 = (v05 << 12) | (v05 >> 20)
v01 += v05
v13 ^= v01
v13 = (v13 << 8) | (v13 >> 24)
v09 += v13
v05 ^= v09
v05 = (v05 << 7) | (v05 >> 25)
v02 += v06
v14 ^= v02
v14 = (v14 << 16) | (v14 >> 16)
v10 += v14
v06 ^= v10
v06 = (v06 << 12) | (v06 >> 20)
v02 += v06
v14 ^= v02
v14 = (v14 << 8) | (v14 >> 24)
v10 += v14
v06 ^= v10
v06 = (v06 << 7) | (v06 >> 25)
v03 += v07
v15 ^= v03
v15 = (v15 << 16) | (v15 >> 16)
v11 += v15
v07 ^= v11
v07 = (v07 << 12) | (v07 >> 20)
v03 += v07
v15 ^= v03
v15 = (v15 << 8) | (v15 >> 24)
v11 += v15
v07 ^= v11
v07 = (v07 << 7) | (v07 >> 25)
v00 += v05
v15 ^= v00
v15 = (v15 << 16) | (v15 >> 16)
v10 += v15
v05 ^= v10
v05 = (v05 << 12) | (v05 >> 20)
v00 += v05
v15 ^= v00
v15 = (v15 << 8) | (v15 >> 24)
v10 += v15
v05 ^= v10
v05 = (v05 << 7) | (v05 >> 25)
v01 += v06
v12 ^= v01
v12 = (v12 << 16) | (v12 >> 16)
v11 += v12
v06 ^= v11
v06 = (v06 << 12) | (v06 >> 20)
v01 += v06
v12 ^= v01
v12 = (v12 << 8) | (v12 >> 24)
v11 += v12
v06 ^= v11
v06 = (v06 << 7) | (v06 >> 25)
v02 += v07
v13 ^= v02
v13 = (v13 << 16) | (v13 >> 16)
v08 += v13
v07 ^= v08
v07 = (v07 << 12) | (v07 >> 20)
v02 += v07
v13 ^= v02
v13 = (v13 << 8) | (v13 >> 24)
v08 += v13
v07 ^= v08
v07 = (v07 << 7) | (v07 >> 25)
v03 += v04
v14 ^= v03
v14 = (v14 << 16) | (v14 >> 16)
v09 += v14
v04 ^= v09
v04 = (v04 << 12) | (v04 >> 20)
v03 += v04
v14 ^= v03
v14 = (v14 << 8) | (v14 >> 24)
v09 += v14
v04 ^= v09
v04 = (v04 << 7) | (v04 >> 25)
}
binary.LittleEndian.PutUint32(out[0:], v00)
binary.LittleEndian.PutUint32(out[4:], v01)
binary.LittleEndian.PutUint32(out[8:], v02)
binary.LittleEndian.PutUint32(out[12:], v03)
binary.LittleEndian.PutUint32(out[16:], v12)
binary.LittleEndian.PutUint32(out[20:], v13)
binary.LittleEndian.PutUint32(out[24:], v14)
binary.LittleEndian.PutUint32(out[28:], v15)
}
// Copyright (c) 2016 Andreas Auernhammer. All rights reserved.
// Use of this source code is governed by a license that can be
// found in the LICENSE file.
// +build !amd64,!386 gccgo appengine nacl
package chacha
import "encoding/binary"
func init() {
useSSE2 = false
useSSSE3 = false
useAVX = false
useAVX2 = false
}
func initialize(state *[64]byte, key []byte, nonce *[16]byte) {
binary.LittleEndian.PutUint32(state[0:], sigma[0])
binary.LittleEndian.PutUint32(state[4:], sigma[1])
binary.LittleEndian.PutUint32(state[8:], sigma[2])
binary.LittleEndian.PutUint32(state[12:], sigma[3])
copy(state[16:], key[:])
copy(state[48:], nonce[:])
}
func xorKeyStream(dst, src []byte, block, state *[64]byte, rounds int) int {
return xorKeyStreamGeneric(dst, src, block, state, rounds)
}
func hChaCha20(out *[32]byte, nonce *[16]byte, key *[32]byte) {
hChaCha20Generic(out, nonce, key)
}
// Copyright (c) 2018 Andreas Auernhammer. All rights reserved.
// Use of this source code is governed by a license that can be
// found in the LICENSE file.
// +build 386,!gccgo,!appengine,!nacl amd64,!gccgo,!appengine,!nacl
#include "textflag.h"
DATA ·sigma<>+0x00(SB)/4, $0x61707865
DATA ·sigma<>+0x04(SB)/4, $0x3320646e
DATA ·sigma<>+0x08(SB)/4, $0x79622d32
DATA ·sigma<>+0x0C(SB)/4, $0x6b206574
GLOBL ·sigma<>(SB), (NOPTR+RODATA), $16 // The 4 ChaCha initialization constants
// SSE2/SSE3/AVX constants
DATA ·one<>+0x00(SB)/8, $1
DATA ·one<>+0x08(SB)/8, $0
GLOBL ·one<>(SB), (NOPTR+RODATA), $16 // The constant 1 as 128 bit value
DATA ·rol16<>+0x00(SB)/8, $0x0504070601000302
DATA ·rol16<>+0x08(SB)/8, $0x0D0C0F0E09080B0A
GLOBL ·rol16<>(SB), (NOPTR+RODATA), $16 // The PSHUFB 16 bit left rotate constant
DATA ·rol8<>+0x00(SB)/8, $0x0605040702010003
DATA ·rol8<>+0x08(SB)/8, $0x0E0D0C0F0A09080B
GLOBL ·rol8<>(SB), (NOPTR+RODATA), $16 // The PSHUFB 8 bit left rotate constant
// AVX2 constants
DATA ·one_AVX2<>+0x00(SB)/8, $0
DATA ·one_AVX2<>+0x08(SB)/8, $0
DATA ·one_AVX2<>+0x10(SB)/8, $1
DATA ·one_AVX2<>+0x18(SB)/8, $0
GLOBL ·one_AVX2<>(SB), (NOPTR+RODATA), $32 // The constant 1 as 256 bit value
DATA ·two_AVX2<>+0x00(SB)/8, $2
DATA ·two_AVX2<>+0x08(SB)/8, $0
DATA ·two_AVX2<>+0x10(SB)/8, $2
DATA ·two_AVX2<>+0x18(SB)/8, $0
GLOBL ·two_AVX2<>(SB), (NOPTR+RODATA), $32
DATA ·rol16_AVX2<>+0x00(SB)/8, $0x0504070601000302
DATA ·rol16_AVX2<>+0x08(SB)/8, $0x0D0C0F0E09080B0A
DATA ·rol16_AVX2<>+0x10(SB)/8, $0x0504070601000302
DATA ·rol16_AVX2<>+0x18(SB)/8, $0x0D0C0F0E09080B0A
GLOBL ·rol16_AVX2<>(SB), (NOPTR+RODATA), $32 // The VPSHUFB 16 bit left rotate constant
DATA ·rol8_AVX2<>+0x00(SB)/8, $0x0605040702010003
DATA ·rol8_AVX2<>+0x08(SB)/8, $0x0E0D0C0F0A09080B
DATA ·rol8_AVX2<>+0x10(SB)/8, $0x0605040702010003
DATA ·rol8_AVX2<>+0x18(SB)/8, $0x0E0D0C0F0A09080B
GLOBL ·rol8_AVX2<>(SB), (NOPTR+RODATA), $32 // The VPSHUFB 8 bit left rotate constant
// Copyright (c) 2018 Andreas Auernhammer. All rights reserved.
// Use of this source code is governed by a license that can be
// found in the LICENSE file.
// +build 386,!gccgo,!appengine,!nacl amd64,!gccgo,!appengine,!nacl
// ROTL_SSE rotates all 4 32 bit values of the XMM register v
// left by n bits using SSE2 instructions (0 <= n <= 32).
// The XMM register t is used as a temp. register.
#define ROTL_SSE(n, t, v) \
MOVO v, t; \
PSLLL $n, t; \
PSRLL $(32-n), v; \
PXOR t, v
// ROTL_AVX rotates all 4/8 32 bit values of the AVX/AVX2 register v
// left by n bits using AVX/AVX2 instructions (0 <= n <= 32).
// The AVX/AVX2 register t is used as a temp. register.
#define ROTL_AVX(n, t, v) \
VPSLLD $n, v, t; \
VPSRLD $(32-n), v, v; \
VPXOR v, t, v
// CHACHA_QROUND_SSE2 performs a ChaCha quarter-round using the
// 4 XMM registers v0, v1, v2 and v3. It uses only ROTL_SSE2 for
// rotations. The XMM register t is used as a temp. register.
#define CHACHA_QROUND_SSE2(v0, v1, v2, v3, t) \
PADDL v1, v0; \
PXOR v0, v3; \
ROTL_SSE(16, t, v3); \
PADDL v3, v2; \
PXOR v2, v1; \
ROTL_SSE(12, t, v1); \
PADDL v1, v0; \
PXOR v0, v3; \
ROTL_SSE(8, t, v3); \
PADDL v3, v2; \
PXOR v2, v1; \
ROTL_SSE(7, t, v1)
// CHACHA_QROUND_SSSE3 performs a ChaCha quarter-round using the
// 4 XMM registers v0, v1, v2 and v3. It uses PSHUFB for 8/16 bit
// rotations. The XMM register t is used as a temp. register.
//
// r16 holds the PSHUFB constant for a 16 bit left rotate.
// r8 holds the PSHUFB constant for a 8 bit left rotate.
#define CHACHA_QROUND_SSSE3(v0, v1, v2, v3, t, r16, r8) \
PADDL v1, v0; \
PXOR v0, v3; \
PSHUFB r16, v3; \
PADDL v3, v2; \
PXOR v2, v1; \
ROTL_SSE(12, t, v1); \
PADDL v1, v0; \
PXOR v0, v3; \
PSHUFB r8, v3; \
PADDL v3, v2; \
PXOR v2, v1; \
ROTL_SSE(7, t, v1)
// CHACHA_QROUND_AVX performs a ChaCha quarter-round using the
// 4 AVX/AVX2 registers v0, v1, v2 and v3. It uses VPSHUFB for 8/16 bit
// rotations. The AVX/AVX2 register t is used as a temp. register.
//
// r16 holds the VPSHUFB constant for a 16 bit left rotate.
// r8 holds the VPSHUFB constant for a 8 bit left rotate.
#define CHACHA_QROUND_AVX(v0, v1, v2, v3, t, r16, r8) \
VPADDD v0, v1, v0; \
VPXOR v3, v0, v3; \
VPSHUFB r16, v3, v3; \
VPADDD v2, v3, v2; \
VPXOR v1, v2, v1; \
ROTL_AVX(12, t, v1); \
VPADDD v0, v1, v0; \
VPXOR v3, v0, v3; \
VPSHUFB r8, v3, v3; \
VPADDD v2, v3, v2; \
VPXOR v1, v2, v1; \
ROTL_AVX(7, t, v1)
// CHACHA_SHUFFLE_SSE performs a ChaCha shuffle using the
// 3 XMM registers v1, v2 and v3. The inverse shuffle is
// performed by switching v1 and v3: CHACHA_SHUFFLE_SSE(v3, v2, v1).
#define CHACHA_SHUFFLE_SSE(v1, v2, v3) \
PSHUFL $0x39, v1, v1; \
PSHUFL $0x4E, v2, v2; \
PSHUFL $0x93, v3, v3
// CHACHA_SHUFFLE_AVX performs a ChaCha shuffle using the
// 3 AVX/AVX2 registers v1, v2 and v3. The inverse shuffle is
// performed by switching v1 and v3: CHACHA_SHUFFLE_AVX(v3, v2, v1).
#define CHACHA_SHUFFLE_AVX(v1, v2, v3) \
VPSHUFD $0x39, v1, v1; \
VPSHUFD $0x4E, v2, v2; \
VPSHUFD $0x93, v3, v3
// XOR_SSE extracts 4x16 byte vectors from src at
// off, xors all vectors with the corresponding XMM
// register (v0 - v3) and writes the result to dst
// at off.
// The XMM register t is used as a temp. register.
#define XOR_SSE(dst, src, off, v0, v1, v2, v3, t) \
MOVOU 0+off(src), t; \
PXOR v0, t; \
MOVOU t, 0+off(dst); \
MOVOU 16+off(src), t; \
PXOR v1, t; \
MOVOU t, 16+off(dst); \
MOVOU 32+off(src), t; \
PXOR v2, t; \
MOVOU t, 32+off(dst); \
MOVOU 48+off(src), t; \
PXOR v3, t; \
MOVOU t, 48+off(dst)
// XOR_AVX extracts 4x16 byte vectors from src at
// off, xors all vectors with the corresponding AVX
// register (v0 - v3) and writes the result to dst
// at off.
// The XMM register t is used as a temp. register.
#define XOR_AVX(dst, src, off, v0, v1, v2, v3, t) \
VPXOR 0+off(src), v0, t; \
VMOVDQU t, 0+off(dst); \
VPXOR 16+off(src), v1, t; \
VMOVDQU t, 16+off(dst); \
VPXOR 32+off(src), v2, t; \
VMOVDQU t, 32+off(dst); \
VPXOR 48+off(src), v3, t; \
VMOVDQU t, 48+off(dst)
#define XOR_AVX2(dst, src, off, v0, v1, v2, v3, t0, t1) \
VMOVDQU (0+off)(src), t0; \
VPERM2I128 $32, v1, v0, t1; \
VPXOR t0, t1, t0; \
VMOVDQU t0, (0+off)(dst); \
VMOVDQU (32+off)(src), t0; \
VPERM2I128 $32, v3, v2, t1; \
VPXOR t0, t1, t0; \
VMOVDQU t0, (32+off)(dst); \
VMOVDQU (64+off)(src), t0; \
VPERM2I128 $49, v1, v0, t1; \
VPXOR t0, t1, t0; \
VMOVDQU t0, (64+off)(dst); \
VMOVDQU (96+off)(src), t0; \
VPERM2I128 $49, v3, v2, t1; \
VPXOR t0, t1, t0; \
VMOVDQU t0, (96+off)(dst)
#define XOR_UPPER_AVX2(dst, src, off, v0, v1, v2, v3, t0, t1) \
VMOVDQU (0+off)(src), t0; \
VPERM2I128 $32, v1, v0, t1; \
VPXOR t0, t1, t0; \
VMOVDQU t0, (0+off)(dst); \
VMOVDQU (32+off)(src), t0; \
VPERM2I128 $32, v3, v2, t1; \
VPXOR t0, t1, t0; \
VMOVDQU t0, (32+off)(dst); \
#define EXTRACT_LOWER(dst, v0, v1, v2, v3, t0) \
VPERM2I128 $49, v1, v0, t0; \
VMOVDQU t0, 0(dst); \
VPERM2I128 $49, v3, v2, t0; \
VMOVDQU t0, 32(dst)
// Copyright (c) 2016 Andreas Auernhammer. All rights reserved.
// Use of this source code is governed by a license that can be
// found in the LICENSE file.
// Package chacha20 implements the ChaCha20 / XChaCha20 stream chipher.
// Notice that one specific key-nonce combination must be unique for all time.
//
// There are three versions of ChaCha20:
// - ChaCha20 with a 64 bit nonce (en/decrypt up to 2^64 * 64 bytes for one key-nonce combination)
// - ChaCha20 with a 96 bit nonce (en/decrypt up to 2^32 * 64 bytes (~256 GB) for one key-nonce combination)
// - XChaCha20 with a 192 bit nonce (en/decrypt up to 2^64 * 64 bytes for one key-nonce combination)
package chacha20 // import "github.com/aead/chacha20"
import (
"crypto/cipher"
"github.com/aead/chacha20/chacha"
)
// XORKeyStream crypts bytes from src to dst using the given nonce and key.
// The length of the nonce determinds the version of ChaCha20:
// - 8 bytes: ChaCha20 with a 64 bit nonce and a 2^64 * 64 byte period.
// - 12 bytes: ChaCha20 as defined in RFC 7539 and a 2^32 * 64 byte period.
// - 24 bytes: XChaCha20 with a 192 bit nonce and a 2^64 * 64 byte period.
// Src and dst may be the same slice but otherwise should not overlap.
// If len(dst) < len(src) this function panics.
// If the nonce is neither 64, 96 nor 192 bits long, this function panics.
func XORKeyStream(dst, src, nonce, key []byte) {
chacha.XORKeyStream(dst, src, nonce, key, 20)
}
// NewCipher returns a new cipher.Stream implementing a ChaCha20 version.
// The nonce must be unique for one key for all time.
// The length of the nonce determinds the version of ChaCha20:
// - 8 bytes: ChaCha20 with a 64 bit nonce and a 2^64 * 64 byte period.
// - 12 bytes: ChaCha20 as defined in RFC 7539 and a 2^32 * 64 byte period.
// - 24 bytes: XChaCha20 with a 192 bit nonce and a 2^64 * 64 byte period.
// If the nonce is neither 64, 96 nor 192 bits long, a non-nil error is returned.
func NewCipher(nonce, key []byte) (cipher.Stream, error) {
return chacha.NewCipher(nonce, key, 20)
}
...@@ -99,8 +99,8 @@ work: ...@@ -99,8 +99,8 @@ work:
## Examples ## Examples
A short "how to use the API" is at the beginning of doc.go (this also will show A short "how to use the API" is at the beginning of doc.go (this also will show when you call `godoc
when you call `godoc github.com/miekg/dns`). github.com/miekg/dns`).
Example programs can be found in the `github.com/miekg/exdns` repository. Example programs can be found in the `github.com/miekg/exdns` repository.
...@@ -158,8 +158,9 @@ Example programs can be found in the `github.com/miekg/exdns` repository. ...@@ -158,8 +158,9 @@ Example programs can be found in the `github.com/miekg/exdns` repository.
* 7553 - URI record * 7553 - URI record
* 7858 - DNS over TLS: Initiation and Performance Considerations * 7858 - DNS over TLS: Initiation and Performance Considerations
* 7871 - EDNS0 Client Subnet * 7871 - EDNS0 Client Subnet
* 7873 - Domain Name System (DNS) Cookies (draft-ietf-dnsop-cookies) * 7873 - Domain Name System (DNS) Cookies
* 8080 - EdDSA for DNSSEC * 8080 - EdDSA for DNSSEC
* 8499 - DNS Terminology
## Loosely Based Upon ## Loosely Based Upon
......
...@@ -10,7 +10,7 @@ type MsgAcceptFunc func(dh Header) MsgAcceptAction ...@@ -10,7 +10,7 @@ type MsgAcceptFunc func(dh Header) MsgAcceptAction
// * opcode isn't OpcodeQuery or OpcodeNotify // * opcode isn't OpcodeQuery or OpcodeNotify
// * Zero bit isn't zero // * Zero bit isn't zero
// * has more than 1 question in the question section // * has more than 1 question in the question section
// * has more than 0 RRs in the Answer section // * has more than 1 RR in the Answer section
// * has more than 0 RRs in the Authority section // * has more than 0 RRs in the Authority section
// * has more than 2 RRs in the Additional section // * has more than 2 RRs in the Additional section
var DefaultMsgAcceptFunc MsgAcceptFunc = defaultMsgAcceptFunc var DefaultMsgAcceptFunc MsgAcceptFunc = defaultMsgAcceptFunc
...@@ -24,7 +24,7 @@ const ( ...@@ -24,7 +24,7 @@ const (
MsgIgnore // Ignore the error and send nothing back. MsgIgnore // Ignore the error and send nothing back.
) )
var defaultMsgAcceptFunc = func(dh Header) MsgAcceptAction { func defaultMsgAcceptFunc(dh Header) MsgAcceptAction {
if isResponse := dh.Bits&_QR != 0; isResponse { if isResponse := dh.Bits&_QR != 0; isResponse {
return MsgIgnore return MsgIgnore
} }
...@@ -41,10 +41,12 @@ var defaultMsgAcceptFunc = func(dh Header) MsgAcceptAction { ...@@ -41,10 +41,12 @@ var defaultMsgAcceptFunc = func(dh Header) MsgAcceptAction {
if dh.Qdcount != 1 { if dh.Qdcount != 1 {
return MsgReject return MsgReject
} }
if dh.Ancount != 0 { // NOTIFY requests can have a SOA in the ANSWER section. See RFC 1996 Section 3.7 and 3.11.
if dh.Ancount > 1 {
return MsgReject return MsgReject
} }
if dh.Nscount != 0 { // IXFR request could have one SOA RR in the NS section. See RFC 1995, section 3.
if dh.Nscount > 1 {
return MsgReject return MsgReject
} }
if dh.Arcount > 2 { if dh.Arcount > 2 {
......
...@@ -320,16 +320,12 @@ func (co *Conn) Read(p []byte) (n int, err error) { ...@@ -320,16 +320,12 @@ func (co *Conn) Read(p []byte) (n int, err error) {
return 0, err return 0, err
} }
if l > len(p) { if l > len(p) {
return int(l), io.ErrShortBuffer return l, io.ErrShortBuffer
} }
return tcpRead(r, p[:l]) return tcpRead(r, p[:l])
} }
// UDP connection // UDP connection
n, err = co.Conn.Read(p) return co.Conn.Read(p)
if err != nil {
return n, err
}
return n, err
} }
// WriteMsg sends a message through the connection co. // WriteMsg sends a message through the connection co.
...@@ -351,10 +347,8 @@ func (co *Conn) WriteMsg(m *Msg) (err error) { ...@@ -351,10 +347,8 @@ func (co *Conn) WriteMsg(m *Msg) (err error) {
if err != nil { if err != nil {
return err return err
} }
if _, err = co.Write(out); err != nil { _, err = co.Write(out)
return err return err
}
return nil
} }
// Write implements the net.Conn Write method. // Write implements the net.Conn Write method.
...@@ -376,8 +370,7 @@ func (co *Conn) Write(p []byte) (n int, err error) { ...@@ -376,8 +370,7 @@ func (co *Conn) Write(p []byte) (n int, err error) {
n, err := io.Copy(w, bytes.NewReader(p)) n, err := io.Copy(w, bytes.NewReader(p))
return int(n), err return int(n), err
} }
n, err = co.Conn.Write(p) return co.Conn.Write(p)
return n, err
} }
// Return the appropriate timeout for a specific request // Return the appropriate timeout for a specific request
...@@ -444,11 +437,7 @@ func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) { ...@@ -444,11 +437,7 @@ func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) {
// DialTimeout acts like Dial but takes a timeout. // DialTimeout acts like Dial but takes a timeout.
func DialTimeout(network, address string, timeout time.Duration) (conn *Conn, err error) { func DialTimeout(network, address string, timeout time.Duration) (conn *Conn, err error) {
client := Client{Net: network, Dialer: &net.Dialer{Timeout: timeout}} client := Client{Net: network, Dialer: &net.Dialer{Timeout: timeout}}
conn, err = client.Dial(address) return client.Dial(address)
if err != nil {
return nil, err
}
return conn, nil
} }
// DialWithTLS connects to the address on the named network with TLS. // DialWithTLS connects to the address on the named network with TLS.
...@@ -457,12 +446,7 @@ func DialWithTLS(network, address string, tlsConfig *tls.Config) (conn *Conn, er ...@@ -457,12 +446,7 @@ func DialWithTLS(network, address string, tlsConfig *tls.Config) (conn *Conn, er
network += "-tls" network += "-tls"
} }
client := Client{Net: network, TLSConfig: tlsConfig} client := Client{Net: network, TLSConfig: tlsConfig}
conn, err = client.Dial(address) return client.Dial(address)
if err != nil {
return nil, err
}
return conn, nil
} }
// DialTimeoutWithTLS acts like DialWithTLS but takes a timeout. // DialTimeoutWithTLS acts like DialWithTLS but takes a timeout.
...@@ -471,11 +455,7 @@ func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout ...@@ -471,11 +455,7 @@ func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout
network += "-tls" network += "-tls"
} }
client := Client{Net: network, Dialer: &net.Dialer{Timeout: timeout}, TLSConfig: tlsConfig} client := Client{Net: network, Dialer: &net.Dialer{Timeout: timeout}, TLSConfig: tlsConfig}
conn, err = client.Dial(address) return client.Dial(address)
if err != nil {
return nil, err
}
return conn, nil
} }
// ExchangeContext acts like Exchange, but honors the deadline on the provided // ExchangeContext acts like Exchange, but honors the deadline on the provided
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"errors" "errors"
"net" "net"
"strconv" "strconv"
"strings"
) )
const hexDigit = "0123456789abcdef" const hexDigit = "0123456789abcdef"
...@@ -163,11 +164,72 @@ func (dns *Msg) IsEdns0() *OPT { ...@@ -163,11 +164,72 @@ func (dns *Msg) IsEdns0() *OPT {
// the number of labels. When false is returned the number of labels is not // the number of labels. When false is returned the number of labels is not
// defined. Also note that this function is extremely liberal; almost any // defined. Also note that this function is extremely liberal; almost any
// string is a valid domain name as the DNS is 8 bit protocol. It checks if each // string is a valid domain name as the DNS is 8 bit protocol. It checks if each
// label fits in 63 characters, but there is no length check for the entire // label fits in 63 characters and that the entire name will fit into the 255
// string s. I.e. a domain name longer than 255 characters is considered valid. // octet wire format limit.
func IsDomainName(s string) (labels int, ok bool) { func IsDomainName(s string) (labels int, ok bool) {
_, labels, err := packDomainName(s, nil, 0, compressionMap{}, false) // XXX: The logic in this function was copied from packDomainName and
return labels, err == nil // should be kept in sync with that function.
const lenmsg = 256
if len(s) == 0 { // Ok, for instance when dealing with update RR without any rdata.
return 0, false
}
s = Fqdn(s)
// Each dot ends a segment of the name. Except for escaped dots (\.), which
// are normal dots.
var (
off int
begin int
wasDot bool
)
for i := 0; i < len(s); i++ {
switch s[i] {
case '\\':
if off+1 > lenmsg {
return labels, false
}
// check for \DDD
if i+3 < len(s) && isDigit(s[i+1]) && isDigit(s[i+2]) && isDigit(s[i+3]) {
i += 3
begin += 3
} else {
i++
begin++
}
wasDot = false
case '.':
if wasDot {
// two dots back to back is not legal
return labels, false
}
wasDot = true
labelLen := i - begin
if labelLen >= 1<<6 { // top two bits of length must be clear
return labels, false
}
// off can already (we're in a loop) be bigger than lenmsg
// this happens when a name isn't fully qualified
off += 1 + labelLen
if off > lenmsg {
return labels, false
}
labels++
begin = i + 1
default:
wasDot = false
}
}
return labels, true
} }
// IsSubDomain checks if child is indeed a child of the parent. If child and parent // IsSubDomain checks if child is indeed a child of the parent. If child and parent
...@@ -181,7 +243,7 @@ func IsSubDomain(parent, child string) bool { ...@@ -181,7 +243,7 @@ func IsSubDomain(parent, child string) bool {
// The checking is performed on the binary payload. // The checking is performed on the binary payload.
func IsMsg(buf []byte) error { func IsMsg(buf []byte) error {
// Header // Header
if len(buf) < 12 { if len(buf) < headerSize {
return errors.New("dns: bad message header") return errors.New("dns: bad message header")
} }
// Header: Opcode // Header: Opcode
...@@ -191,11 +253,18 @@ func IsMsg(buf []byte) error { ...@@ -191,11 +253,18 @@ func IsMsg(buf []byte) error {
// IsFqdn checks if a domain name is fully qualified. // IsFqdn checks if a domain name is fully qualified.
func IsFqdn(s string) bool { func IsFqdn(s string) bool {
l := len(s) s2 := strings.TrimSuffix(s, ".")
if l == 0 { if s == s2 {
return false return false
} }
return s[l-1] == '.'
i := strings.LastIndexFunc(s2, func(r rune) bool {
return r != '\\'
})
// Test whether we have an even number of escape sequences before
// the dot or none.
return (len(s2)-i)%2 != 0
} }
// IsRRset checks if a set of RRs is a valid RRset as defined by RFC 2181. // IsRRset checks if a set of RRs is a valid RRset as defined by RFC 2181.
...@@ -244,12 +313,19 @@ func ReverseAddr(addr string) (arpa string, err error) { ...@@ -244,12 +313,19 @@ func ReverseAddr(addr string) (arpa string, err error) {
if ip == nil { if ip == nil {
return "", &Error{err: "unrecognized address: " + addr} return "", &Error{err: "unrecognized address: " + addr}
} }
if ip.To4() != nil { if v4 := ip.To4(); v4 != nil {
return strconv.Itoa(int(ip[15])) + "." + strconv.Itoa(int(ip[14])) + "." + strconv.Itoa(int(ip[13])) + "." + buf := make([]byte, 0, net.IPv4len*4+len("in-addr.arpa."))
strconv.Itoa(int(ip[12])) + ".in-addr.arpa.", nil // Add it, in reverse, to the buffer
for i := len(v4) - 1; i >= 0; i-- {
buf = strconv.AppendInt(buf, int64(v4[i]), 10)
buf = append(buf, '.')
}
// Append "in-addr.arpa." and return (buf already has the final .)
buf = append(buf, "in-addr.arpa."...)
return string(buf), nil
} }
// Must be IPv6 // Must be IPv6
buf := make([]byte, 0, len(ip)*4+len("ip6.arpa.")) buf := make([]byte, 0, net.IPv6len*4+len("ip6.arpa."))
// Add it, in reverse, to the buffer // Add it, in reverse, to the buffer
for i := len(ip) - 1; i >= 0; i-- { for i := len(ip) - 1; i >= 0; i-- {
v := ip[i] v := ip[i]
......
...@@ -41,8 +41,23 @@ type RR interface { ...@@ -41,8 +41,23 @@ type RR interface {
// size will be returned and domain names will be added to the map for future compression. // size will be returned and domain names will be added to the map for future compression.
len(off int, compression map[string]struct{}) int len(off int, compression map[string]struct{}) int
// pack packs an RR into wire format. // pack packs the records RDATA into wire format. The header will
pack(msg []byte, off int, compression compressionMap, compress bool) (headerEnd int, off1 int, err error) // already have been packed into msg.
pack(msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error)
// unpack unpacks an RR from wire format.
//
// This will only be called on a new and empty RR type with only the header populated. It
// will only be called if the record's RDATA is non-empty.
unpack(msg []byte, off int) (off1 int, err error)
// parse parses an RR from zone file format.
//
// This will only be called on a new and empty RR type with only the header populated.
parse(c *zlexer, origin, file string) *ParseError
// isDuplicate returns whether the two RRs are duplicates.
isDuplicate(r2 RR) bool
} }
// RR_Header is the header all DNS resource records share. // RR_Header is the header all DNS resource records share.
...@@ -81,6 +96,19 @@ func (h *RR_Header) len(off int, compression map[string]struct{}) int { ...@@ -81,6 +96,19 @@ func (h *RR_Header) len(off int, compression map[string]struct{}) int {
return l return l
} }
func (h *RR_Header) pack(msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {
// RR_Header has no RDATA to pack.
return off, nil
}
func (h *RR_Header) unpack(msg []byte, off int) (int, error) {
panic("dns: internal error: unpack should never be called on RR_Header")
}
func (h *RR_Header) parse(c *zlexer, origin, file string) *ParseError {
panic("dns: internal error: parse should never be called on RR_Header")
}
// ToRFC3597 converts a known RR to the unknown RR representation from RFC 3597. // ToRFC3597 converts a known RR to the unknown RR representation from RFC 3597.
func (rr *RFC3597) ToRFC3597(r RR) error { func (rr *RFC3597) ToRFC3597(r RR) error {
buf := make([]byte, Len(r)*2) buf := make([]byte, Len(r)*2)
...@@ -90,14 +118,17 @@ func (rr *RFC3597) ToRFC3597(r RR) error { ...@@ -90,14 +118,17 @@ func (rr *RFC3597) ToRFC3597(r RR) error {
} }
buf = buf[:off] buf = buf[:off]
hdr := *r.Header() *rr = RFC3597{Hdr: *r.Header()}
hdr.Rdlength = uint16(off - headerEnd) rr.Hdr.Rdlength = uint16(off - headerEnd)
if noRdata(rr.Hdr) {
return nil
}
rfc3597, _, err := unpackRFC3597(hdr, buf, headerEnd) _, err = rr.unpack(buf, headerEnd)
if err != nil { if err != nil {
return err return err
} }
*rr = *rfc3597.(*RFC3597)
return nil return nil
} }
...@@ -67,9 +67,6 @@ var AlgorithmToString = map[uint8]string{ ...@@ -67,9 +67,6 @@ var AlgorithmToString = map[uint8]string{
PRIVATEOID: "PRIVATEOID", PRIVATEOID: "PRIVATEOID",
} }
// StringToAlgorithm is the reverse of AlgorithmToString.
var StringToAlgorithm = reverseInt8(AlgorithmToString)
// AlgorithmToHash is a map of algorithm crypto hash IDs to crypto.Hash's. // AlgorithmToHash is a map of algorithm crypto hash IDs to crypto.Hash's.
var AlgorithmToHash = map[uint8]crypto.Hash{ var AlgorithmToHash = map[uint8]crypto.Hash{
RSAMD5: crypto.MD5, // Deprecated in RFC 6725 RSAMD5: crypto.MD5, // Deprecated in RFC 6725
...@@ -102,9 +99,6 @@ var HashToString = map[uint8]string{ ...@@ -102,9 +99,6 @@ var HashToString = map[uint8]string{
SHA512: "SHA512", SHA512: "SHA512",
} }
// StringToHash is a map of names to hash IDs.
var StringToHash = reverseInt8(HashToString)
// DNSKEY flag values. // DNSKEY flag values.
const ( const (
SEP = 1 SEP = 1
...@@ -268,16 +262,17 @@ func (rr *RRSIG) Sign(k crypto.Signer, rrset []RR) error { ...@@ -268,16 +262,17 @@ func (rr *RRSIG) Sign(k crypto.Signer, rrset []RR) error {
return ErrKey return ErrKey
} }
h0 := rrset[0].Header()
rr.Hdr.Rrtype = TypeRRSIG rr.Hdr.Rrtype = TypeRRSIG
rr.Hdr.Name = rrset[0].Header().Name rr.Hdr.Name = h0.Name
rr.Hdr.Class = rrset[0].Header().Class rr.Hdr.Class = h0.Class
if rr.OrigTtl == 0 { // If set don't override if rr.OrigTtl == 0 { // If set don't override
rr.OrigTtl = rrset[0].Header().Ttl rr.OrigTtl = h0.Ttl
} }
rr.TypeCovered = rrset[0].Header().Rrtype rr.TypeCovered = h0.Rrtype
rr.Labels = uint8(CountLabel(rrset[0].Header().Name)) rr.Labels = uint8(CountLabel(h0.Name))
if strings.HasPrefix(rrset[0].Header().Name, "*") { if strings.HasPrefix(h0.Name, "*") {
rr.Labels-- // wildcard, remove from label count rr.Labels-- // wildcard, remove from label count
} }
...@@ -411,10 +406,7 @@ func (rr *RRSIG) Verify(k *DNSKEY, rrset []RR) error { ...@@ -411,10 +406,7 @@ func (rr *RRSIG) Verify(k *DNSKEY, rrset []RR) error {
// IsRRset checked that we have at least one RR and that the RRs in // IsRRset checked that we have at least one RR and that the RRs in
// the set have consistent type, class, and name. Also check that type and // the set have consistent type, class, and name. Also check that type and
// class matches the RRSIG record. // class matches the RRSIG record.
if rrset[0].Header().Class != rr.Hdr.Class { if h0 := rrset[0].Header(); h0.Class != rr.Hdr.Class || h0.Rrtype != rr.TypeCovered {
return ErrRRset
}
if rrset[0].Header().Rrtype != rr.TypeCovered {
return ErrRRset return ErrRRset
} }
...@@ -563,7 +555,7 @@ func (k *DNSKEY) publicKeyRSA() *rsa.PublicKey { ...@@ -563,7 +555,7 @@ func (k *DNSKEY) publicKeyRSA() *rsa.PublicKey {
pubkey := new(rsa.PublicKey) pubkey := new(rsa.PublicKey)
expo := uint64(0) var expo uint64
for i := 0; i < int(explen); i++ { for i := 0; i < int(explen); i++ {
expo <<= 8 expo <<= 8
expo |= uint64(keybuf[keyoff+i]) expo |= uint64(keybuf[keyoff+i])
...@@ -658,15 +650,16 @@ func rawSignatureData(rrset []RR, s *RRSIG) (buf []byte, err error) { ...@@ -658,15 +650,16 @@ func rawSignatureData(rrset []RR, s *RRSIG) (buf []byte, err error) {
wires := make(wireSlice, len(rrset)) wires := make(wireSlice, len(rrset))
for i, r := range rrset { for i, r := range rrset {
r1 := r.copy() r1 := r.copy()
r1.Header().Ttl = s.OrigTtl h := r1.Header()
labels := SplitDomainName(r1.Header().Name) h.Ttl = s.OrigTtl
labels := SplitDomainName(h.Name)
// 6.2. Canonical RR Form. (4) - wildcards // 6.2. Canonical RR Form. (4) - wildcards
if len(labels) > int(s.Labels) { if len(labels) > int(s.Labels) {
// Wildcard // Wildcard
r1.Header().Name = "*." + strings.Join(labels[len(labels)-int(s.Labels):], ".") + "." h.Name = "*." + strings.Join(labels[len(labels)-int(s.Labels):], ".") + "."
} }
// RFC 4034: 6.2. Canonical RR Form. (2) - domain name to lowercase // RFC 4034: 6.2. Canonical RR Form. (2) - domain name to lowercase
r1.Header().Name = strings.ToLower(r1.Header().Name) h.Name = strings.ToLower(h.Name)
// 6.2. Canonical RR Form. (3) - domain rdata to lowercase. // 6.2. Canonical RR Form. (3) - domain rdata to lowercase.
// NS, MD, MF, CNAME, SOA, MB, MG, MR, PTR, // NS, MD, MF, CNAME, SOA, MB, MG, MR, PTR,
// HINFO, MINFO, MX, RP, AFSDB, RT, SIG, PX, NXT, NAPTR, KX, // HINFO, MINFO, MX, RP, AFSDB, RT, SIG, PX, NXT, NAPTR, KX,
......
...@@ -7,18 +7,31 @@ package dns ...@@ -7,18 +7,31 @@ package dns
// is so, otherwise false. // is so, otherwise false.
// It's is a protocol violation to have identical RRs in a message. // It's is a protocol violation to have identical RRs in a message.
func IsDuplicate(r1, r2 RR) bool { func IsDuplicate(r1, r2 RR) bool {
if r1.Header().Class != r2.Header().Class { // Check whether the record header is identical.
if !r1.Header().isDuplicate(r2.Header()) {
return false return false
} }
if r1.Header().Rrtype != r2.Header().Rrtype {
// Check whether the RDATA is identical.
return r1.isDuplicate(r2)
}
func (r1 *RR_Header) isDuplicate(_r2 RR) bool {
r2, ok := _r2.(*RR_Header)
if !ok {
return false
}
if r1.Class != r2.Class {
return false
}
if r1.Rrtype != r2.Rrtype {
return false return false
} }
if !isDulicateName(r1.Header().Name, r2.Header().Name) { if !isDulicateName(r1.Name, r2.Name) {
return false return false
} }
// ignore TTL // ignore TTL
return true
return isDuplicateRdata(r1, r2)
} }
// isDulicateName checks if the domain names s1 and s2 are equal. // isDulicateName checks if the domain names s1 and s2 are equal.
......
...@@ -57,10 +57,7 @@ func main() { ...@@ -57,10 +57,7 @@ func main() {
continue continue
} }
if name == "PrivateRR" || name == "RFC3597" { if name == "PrivateRR" || name == "OPT" {
continue
}
if name == "OPT" || name == "ANY" || name == "IXFR" || name == "AXFR" {
continue continue
} }
...@@ -70,22 +67,6 @@ func main() { ...@@ -70,22 +67,6 @@ func main() {
b := &bytes.Buffer{} b := &bytes.Buffer{}
b.WriteString(packageHdr) b.WriteString(packageHdr)
// Generate the giant switch that calls the correct function for each type.
fmt.Fprint(b, "// isDuplicateRdata calls the rdata specific functions\n")
fmt.Fprint(b, "func isDuplicateRdata(r1, r2 RR) bool {\n")
fmt.Fprint(b, "switch r1.Header().Rrtype {\n")
for _, name := range namedTypes {
o := scope.Lookup(name)
_, isEmbedded := getTypeStruct(o.Type(), scope)
if isEmbedded {
continue
}
fmt.Fprintf(b, "case Type%s:\nreturn isDuplicate%s(r1.(*%s), r2.(*%s))\n", name, name, name, name)
}
fmt.Fprintf(b, "}\nreturn false\n}\n")
// Generate the duplicate check for each type. // Generate the duplicate check for each type.
fmt.Fprint(b, "// isDuplicate() functions\n\n") fmt.Fprint(b, "// isDuplicate() functions\n\n")
for _, name := range namedTypes { for _, name := range namedTypes {
...@@ -95,7 +76,10 @@ func main() { ...@@ -95,7 +76,10 @@ func main() {
if isEmbedded { if isEmbedded {
continue continue
} }
fmt.Fprintf(b, "func isDuplicate%s(r1, r2 *%s) bool {\n", name, name) fmt.Fprintf(b, "func (r1 *%s) isDuplicate(_r2 RR) bool {\n", name)
fmt.Fprintf(b, "r2, ok := _r2.(*%s)\n", name)
fmt.Fprint(b, "if !ok { return false }\n")
fmt.Fprint(b, "_ = r2\n")
for i := 1; i < st.NumFields(); i++ { for i := 1; i < st.NumFields(); i++ {
field := st.Field(i).Name() field := st.Field(i).Name()
o2 := func(s string) { fmt.Fprintf(b, s+"\n", field, field) } o2 := func(s string) { fmt.Fprintf(b, s+"\n", field, field) }
...@@ -103,7 +87,7 @@ func main() { ...@@ -103,7 +87,7 @@ func main() {
// For some reason, a and aaaa don't pop up as *types.Slice here (mostly like because the are // For some reason, a and aaaa don't pop up as *types.Slice here (mostly like because the are
// *indirectly* defined as a slice in the net package). // *indirectly* defined as a slice in the net package).
if _, ok := st.Field(i).Type().(*types.Slice); ok || st.Tag(i) == `dns:"a"` || st.Tag(i) == `dns:"aaaa"` { if _, ok := st.Field(i).Type().(*types.Slice); ok {
o2("if len(r1.%s) != len(r2.%s) {\nreturn false\n}") o2("if len(r1.%s) != len(r2.%s) {\nreturn false\n}")
if st.Tag(i) == `dns:"cdomain-name"` || st.Tag(i) == `dns:"domain-name"` { if st.Tag(i) == `dns:"cdomain-name"` || st.Tag(i) == `dns:"domain-name"` {
...@@ -128,6 +112,8 @@ func main() { ...@@ -128,6 +112,8 @@ func main() {
switch st.Tag(i) { switch st.Tag(i) {
case `dns:"-"`: case `dns:"-"`:
// ignored // ignored
case `dns:"a"`, `dns:"aaaa"`:
o2("if !r1.%s.Equal(r2.%s) {\nreturn false\n}")
case `dns:"cdomain-name"`, `dns:"domain-name"`: case `dns:"cdomain-name"`, `dns:"domain-name"`:
o2("if !isDulicateName(r1.%s, r2.%s) {\nreturn false\n}") o2("if !isDulicateName(r1.%s, r2.%s) {\nreturn false\n}")
default: default:
......
...@@ -88,6 +88,12 @@ func (rr *OPT) len(off int, compression map[string]struct{}) int { ...@@ -88,6 +88,12 @@ func (rr *OPT) len(off int, compression map[string]struct{}) int {
return l return l
} }
func (rr *OPT) parse(c *zlexer, origin, file string) *ParseError {
panic("dns: internal error: parse should never be called on OPT")
}
func (r1 *OPT) isDuplicate(r2 RR) bool { return false }
// return the old value -> delete SetVersion? // return the old value -> delete SetVersion?
// Version returns the EDNS version used. Only zero is defined. // Version returns the EDNS version used. Only zero is defined.
...@@ -183,7 +189,7 @@ func (e *EDNS0_NSID) pack() ([]byte, error) { ...@@ -183,7 +189,7 @@ func (e *EDNS0_NSID) pack() ([]byte, error) {
// Option implements the EDNS0 interface. // Option implements the EDNS0 interface.
func (e *EDNS0_NSID) Option() uint16 { return EDNS0NSID } // Option returns the option code. func (e *EDNS0_NSID) Option() uint16 { return EDNS0NSID } // Option returns the option code.
func (e *EDNS0_NSID) unpack(b []byte) error { e.Nsid = hex.EncodeToString(b); return nil } func (e *EDNS0_NSID) unpack(b []byte) error { e.Nsid = hex.EncodeToString(b); return nil }
func (e *EDNS0_NSID) String() string { return string(e.Nsid) } func (e *EDNS0_NSID) String() string { return e.Nsid }
// EDNS0_SUBNET is the subnet option that is used to give the remote nameserver // EDNS0_SUBNET is the subnet option that is used to give the remote nameserver
// an idea of where the client lives. See RFC 7871. It can then give back a different // an idea of where the client lives. See RFC 7871. It can then give back a different
...@@ -411,7 +417,7 @@ func (e *EDNS0_LLQ) unpack(b []byte) error { ...@@ -411,7 +417,7 @@ func (e *EDNS0_LLQ) unpack(b []byte) error {
func (e *EDNS0_LLQ) String() string { func (e *EDNS0_LLQ) String() string {
s := strconv.FormatUint(uint64(e.Version), 10) + " " + strconv.FormatUint(uint64(e.Opcode), 10) + s := strconv.FormatUint(uint64(e.Version), 10) + " " + strconv.FormatUint(uint64(e.Opcode), 10) +
" " + strconv.FormatUint(uint64(e.Error), 10) + " " + strconv.FormatUint(uint64(e.Id), 10) + " " + strconv.FormatUint(uint64(e.Error), 10) + " " + strconv.FormatUint(e.Id, 10) +
" " + strconv.FormatUint(uint64(e.LeaseLife), 10) " " + strconv.FormatUint(uint64(e.LeaseLife), 10)
return s return s
} }
...@@ -498,10 +504,7 @@ func (e *EDNS0_EXPIRE) String() string { return strconv.FormatUint(uint64(e.Expi ...@@ -498,10 +504,7 @@ func (e *EDNS0_EXPIRE) String() string { return strconv.FormatUint(uint64(e.Expi
func (e *EDNS0_EXPIRE) pack() ([]byte, error) { func (e *EDNS0_EXPIRE) pack() ([]byte, error) {
b := make([]byte, 4) b := make([]byte, 4)
b[0] = byte(e.Expire >> 24) binary.BigEndian.PutUint32(b, e.Expire)
b[1] = byte(e.Expire >> 16)
b[2] = byte(e.Expire >> 8)
b[3] = byte(e.Expire)
return b, nil return b, nil
} }
......
...@@ -20,7 +20,7 @@ func Field(r RR, i int) string { ...@@ -20,7 +20,7 @@ func Field(r RR, i int) string {
return "" return ""
} }
d := reflect.ValueOf(r).Elem().Field(i) d := reflect.ValueOf(r).Elem().Field(i)
switch k := d.Kind(); k { switch d.Kind() {
case reflect.String: case reflect.String:
return d.String() return d.String()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
......
...@@ -16,7 +16,7 @@ func SplitDomainName(s string) (labels []string) { ...@@ -16,7 +16,7 @@ func SplitDomainName(s string) (labels []string) {
fqdnEnd := 0 // offset of the final '.' or the length of the name fqdnEnd := 0 // offset of the final '.' or the length of the name
idx := Split(s) idx := Split(s)
begin := 0 begin := 0
if s[len(s)-1] == '.' { if IsFqdn(s) {
fqdnEnd = len(s) - 1 fqdnEnd = len(s) - 1
} else { } else {
fqdnEnd = len(s) fqdnEnd = len(s)
...@@ -36,8 +36,7 @@ func SplitDomainName(s string) (labels []string) { ...@@ -36,8 +36,7 @@ func SplitDomainName(s string) (labels []string) {
} }
} }
labels = append(labels, s[begin:fqdnEnd]) return append(labels, s[begin:fqdnEnd])
return labels
} }
// CompareDomainName compares the names s1 and s2 and // CompareDomainName compares the names s1 and s2 and
......
...@@ -231,29 +231,21 @@ func (m compressionMap) find(s string) (int, bool) { ...@@ -231,29 +231,21 @@ func (m compressionMap) find(s string) (int, bool) {
// map needs to hold a mapping between domain names and offsets // map needs to hold a mapping between domain names and offsets
// pointing into msg. // pointing into msg.
func PackDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) { func PackDomainName(s string, msg []byte, off int, compression map[string]int, compress bool) (off1 int, err error) {
off1, _, err = packDomainName(s, msg, off, compressionMap{ext: compression}, compress) return packDomainName(s, msg, off, compressionMap{ext: compression}, compress)
return
} }
func packDomainName(s string, msg []byte, off int, compression compressionMap, compress bool) (off1 int, labels int, err error) { func packDomainName(s string, msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {
// special case if msg == nil // XXX: A logical copy of this function exists in IsDomainName and
lenmsg := 256 // should be kept in sync with this function.
if msg != nil {
lenmsg = len(msg)
}
ls := len(s) ls := len(s)
if ls == 0 { // Ok, for instance when dealing with update RR without any rdata. if ls == 0 { // Ok, for instance when dealing with update RR without any rdata.
return off, 0, nil return off, nil
} }
// If not fully qualified, error out, but only if msg != nil #ugly // If not fully qualified, error out.
if s[ls-1] != '.' { if !IsFqdn(s) {
if msg != nil { return len(msg), ErrFqdn
return lenmsg, 0, ErrFqdn
}
s += "."
ls++
} }
// Each dot ends a segment of the name. // Each dot ends a segment of the name.
...@@ -283,8 +275,8 @@ loop: ...@@ -283,8 +275,8 @@ loop:
switch c { switch c {
case '\\': case '\\':
if off+1 > lenmsg { if off+1 > len(msg) {
return lenmsg, labels, ErrBuf return len(msg), ErrBuf
} }
if bs == nil { if bs == nil {
...@@ -307,19 +299,19 @@ loop: ...@@ -307,19 +299,19 @@ loop:
case '.': case '.':
if wasDot { if wasDot {
// two dots back to back is not legal // two dots back to back is not legal
return lenmsg, labels, ErrRdata return len(msg), ErrRdata
} }
wasDot = true wasDot = true
labelLen := i - begin labelLen := i - begin
if labelLen >= 1<<6 { // top two bits of length must be clear if labelLen >= 1<<6 { // top two bits of length must be clear
return lenmsg, labels, ErrRdata return len(msg), ErrRdata
} }
// off can already (we're in a loop) be bigger than len(msg) // off can already (we're in a loop) be bigger than len(msg)
// this happens when a name isn't fully qualified // this happens when a name isn't fully qualified
if off+1+labelLen > lenmsg { if off+1+labelLen > len(msg) {
return lenmsg, labels, ErrBuf return len(msg), ErrBuf
} }
// Don't try to compress '.' // Don't try to compress '.'
...@@ -344,7 +336,6 @@ loop: ...@@ -344,7 +336,6 @@ loop:
} }
// The following is covered by the length check above. // The following is covered by the length check above.
if msg != nil {
msg[off] = byte(labelLen) msg[off] = byte(labelLen)
if bs == nil { if bs == nil {
...@@ -352,10 +343,8 @@ loop: ...@@ -352,10 +343,8 @@ loop:
} else { } else {
copy(msg[off+1:], bs[begin:i]) copy(msg[off+1:], bs[begin:i])
} }
}
off += 1 + labelLen off += 1 + labelLen
labels++
begin = i + 1 begin = i + 1
compBegin = begin + compOff compBegin = begin + compOff
default: default:
...@@ -365,22 +354,21 @@ loop: ...@@ -365,22 +354,21 @@ loop:
// Root label is special // Root label is special
if isRootLabel(s, bs, 0, ls) { if isRootLabel(s, bs, 0, ls) {
return off, labels, nil return off, nil
} }
// If we did compression and we find something add the pointer here // If we did compression and we find something add the pointer here
if pointer != -1 { if pointer != -1 {
// We have two bytes (14 bits) to put the pointer in // We have two bytes (14 bits) to put the pointer in
// if msg == nil, we will never do compression
binary.BigEndian.PutUint16(msg[off:], uint16(pointer^0xC000)) binary.BigEndian.PutUint16(msg[off:], uint16(pointer^0xC000))
return off + 2, labels, nil return off + 2, nil
} }
if msg != nil && off < lenmsg { if off < len(msg) {
msg[off] = 0 msg[off] = 0
} }
return off + 1, labels, nil return off + 1, nil
} }
// isRootLabel returns whether s or bs, from off to end, is the root // isRootLabel returns whether s or bs, from off to end, is the root
...@@ -633,7 +621,12 @@ func packRR(rr RR, msg []byte, off int, compression compressionMap, compress boo ...@@ -633,7 +621,12 @@ func packRR(rr RR, msg []byte, off int, compression compressionMap, compress boo
return len(msg), len(msg), &Error{err: "nil rr"} return len(msg), len(msg), &Error{err: "nil rr"}
} }
headerEnd, off1, err = rr.pack(msg, off, compression, compress) headerEnd, err = rr.Header().packHeader(msg, off, compression, compress)
if err != nil {
return headerEnd, len(msg), err
}
off1, err = rr.pack(msg, headerEnd, compression, compress)
if err != nil { if err != nil {
return headerEnd, len(msg), err return headerEnd, len(msg), err
} }
...@@ -661,17 +654,28 @@ func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) { ...@@ -661,17 +654,28 @@ func UnpackRR(msg []byte, off int) (rr RR, off1 int, err error) {
// UnpackRRWithHeader unpacks the record type specific payload given an existing // UnpackRRWithHeader unpacks the record type specific payload given an existing
// RR_Header. // RR_Header.
func UnpackRRWithHeader(h RR_Header, msg []byte, off int) (rr RR, off1 int, err error) { func UnpackRRWithHeader(h RR_Header, msg []byte, off int) (rr RR, off1 int, err error) {
if newFn, ok := TypeToRR[h.Rrtype]; ok {
rr = newFn()
*rr.Header() = h
} else {
rr = &RFC3597{Hdr: h}
}
if noRdata(h) {
return rr, off, nil
}
end := off + int(h.Rdlength) end := off + int(h.Rdlength)
if fn, known := typeToUnpack[h.Rrtype]; !known { off, err = rr.unpack(msg, off)
rr, off, err = unpackRFC3597(h, msg, off) if err != nil {
} else { return nil, end, err
rr, off, err = fn(h, msg, off)
} }
if off != end { if off != end {
return &h, end, &Error{err: "bad rdlength"} return &h, end, &Error{err: "bad rdlength"}
} }
return rr, off, err
return rr, off, nil
} }
// unpackRRslice unpacks msg[off:] into an []RR. // unpackRRslice unpacks msg[off:] into an []RR.
...@@ -984,7 +988,7 @@ func (dns *Msg) Len() int { ...@@ -984,7 +988,7 @@ func (dns *Msg) Len() int {
} }
func msgLenWithCompressionMap(dns *Msg, compression map[string]struct{}) int { func msgLenWithCompressionMap(dns *Msg, compression map[string]struct{}) int {
l := 12 // Message header is always 12 bytes l := headerSize
for _, r := range dns.Question { for _, r := range dns.Question {
l += r.len(l, compression) l += r.len(l, compression)
...@@ -1068,7 +1072,7 @@ func compressionLenSearch(c map[string]struct{}, s string, msgOff int) (int, boo ...@@ -1068,7 +1072,7 @@ func compressionLenSearch(c map[string]struct{}, s string, msgOff int) (int, boo
} }
// Copy returns a new RR which is a deep-copy of r. // Copy returns a new RR which is a deep-copy of r.
func Copy(r RR) RR { r1 := r.copy(); return r1 } func Copy(r RR) RR { return r.copy() }
// Len returns the length (in octets) of the uncompressed RR in wire format. // Len returns the length (in octets) of the uncompressed RR in wire format.
func Len(r RR) int { return r.len(0, nil) } func Len(r RR) int { return r.len(0, nil) }
...@@ -1120,7 +1124,7 @@ func (dns *Msg) CopyTo(r1 *Msg) *Msg { ...@@ -1120,7 +1124,7 @@ func (dns *Msg) CopyTo(r1 *Msg) *Msg {
} }
func (q *Question) pack(msg []byte, off int, compression compressionMap, compress bool) (int, error) { func (q *Question) pack(msg []byte, off int, compression compressionMap, compress bool) (int, error) {
off, _, err := packDomainName(q.Name, msg, off, compression, compress) off, err := packDomainName(q.Name, msg, off, compression, compress)
if err != nil { if err != nil {
return off, err return off, err
} }
...@@ -1183,7 +1187,10 @@ func (dh *Header) pack(msg []byte, off int, compression compressionMap, compress ...@@ -1183,7 +1187,10 @@ func (dh *Header) pack(msg []byte, off int, compression compressionMap, compress
return off, err return off, err
} }
off, err = packUint16(dh.Arcount, msg, off) off, err = packUint16(dh.Arcount, msg, off)
if err != nil {
return off, err return off, err
}
return off, nil
} }
func unpackMsgHdr(msg []byte, off int) (Header, int, error) { func unpackMsgHdr(msg []byte, off int) (Header, int, error) {
...@@ -1212,7 +1219,10 @@ func unpackMsgHdr(msg []byte, off int) (Header, int, error) { ...@@ -1212,7 +1219,10 @@ func unpackMsgHdr(msg []byte, off int) (Header, int, error) {
return dh, off, err return dh, off, err
} }
dh.Arcount, off, err = unpackUint16(msg, off) dh.Arcount, off, err = unpackUint16(msg, off)
if err != nil {
return dh, off, err return dh, off, err
}
return dh, off, nil
} }
// setHdr set the header in the dns using the binary data in dh. // setHdr set the header in the dns using the binary data in dh.
......
...@@ -80,17 +80,12 @@ func main() { ...@@ -80,17 +80,12 @@ func main() {
o := scope.Lookup(name) o := scope.Lookup(name)
st, _ := getTypeStruct(o.Type(), scope) st, _ := getTypeStruct(o.Type(), scope)
fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression compressionMap, compress bool) (int, int, error) {\n", name) fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression compressionMap, compress bool) (off1 int, err error) {\n", name)
fmt.Fprint(b, `headerEnd, off, err := rr.Hdr.pack(msg, off, compression, compress)
if err != nil {
return headerEnd, off, err
}
`)
for i := 1; i < st.NumFields(); i++ { for i := 1; i < st.NumFields(); i++ {
o := func(s string) { o := func(s string) {
fmt.Fprintf(b, s, st.Field(i).Name()) fmt.Fprintf(b, s, st.Field(i).Name())
fmt.Fprint(b, `if err != nil { fmt.Fprint(b, `if err != nil {
return headerEnd, off, err return off, err
} }
`) `)
} }
...@@ -115,9 +110,9 @@ return headerEnd, off, err ...@@ -115,9 +110,9 @@ return headerEnd, off, err
switch { switch {
case st.Tag(i) == `dns:"-"`: // ignored case st.Tag(i) == `dns:"-"`: // ignored
case st.Tag(i) == `dns:"cdomain-name"`: case st.Tag(i) == `dns:"cdomain-name"`:
o("off, _, err = packDomainName(rr.%s, msg, off, compression, compress)\n") o("off, err = packDomainName(rr.%s, msg, off, compression, compress)\n")
case st.Tag(i) == `dns:"domain-name"`: case st.Tag(i) == `dns:"domain-name"`:
o("off, _, err = packDomainName(rr.%s, msg, off, compression, false)\n") o("off, err = packDomainName(rr.%s, msg, off, compression, false)\n")
case st.Tag(i) == `dns:"a"`: case st.Tag(i) == `dns:"a"`:
o("off, err = packDataA(rr.%s, msg, off)\n") o("off, err = packDataA(rr.%s, msg, off)\n")
case st.Tag(i) == `dns:"aaaa"`: case st.Tag(i) == `dns:"aaaa"`:
...@@ -144,7 +139,7 @@ return headerEnd, off, err ...@@ -144,7 +139,7 @@ return headerEnd, off, err
if rr.%s != "-" { if rr.%s != "-" {
off, err = packStringHex(rr.%s, msg, off) off, err = packStringHex(rr.%s, msg, off)
if err != nil { if err != nil {
return headerEnd, off, err return off, err
} }
} }
`, field, field) `, field, field)
...@@ -153,7 +148,8 @@ if rr.%s != "-" { ...@@ -153,7 +148,8 @@ if rr.%s != "-" {
fallthrough fallthrough
case st.Tag(i) == `dns:"hex"`: case st.Tag(i) == `dns:"hex"`:
o("off, err = packStringHex(rr.%s, msg, off)\n") o("off, err = packStringHex(rr.%s, msg, off)\n")
case st.Tag(i) == `dns:"any"`:
o("off, err = packStringAny(rr.%s, msg, off)\n")
case st.Tag(i) == `dns:"octet"`: case st.Tag(i) == `dns:"octet"`:
o("off, err = packStringOctet(rr.%s, msg, off)\n") o("off, err = packStringOctet(rr.%s, msg, off)\n")
case st.Tag(i) == "": case st.Tag(i) == "":
...@@ -175,7 +171,7 @@ if rr.%s != "-" { ...@@ -175,7 +171,7 @@ if rr.%s != "-" {
log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
} }
} }
fmt.Fprintln(b, "return headerEnd, off, nil }\n") fmt.Fprintln(b, "return off, nil }\n")
} }
fmt.Fprint(b, "// unpack*() functions\n\n") fmt.Fprint(b, "// unpack*() functions\n\n")
...@@ -183,14 +179,8 @@ if rr.%s != "-" { ...@@ -183,14 +179,8 @@ if rr.%s != "-" {
o := scope.Lookup(name) o := scope.Lookup(name)
st, _ := getTypeStruct(o.Type(), scope) st, _ := getTypeStruct(o.Type(), scope)
fmt.Fprintf(b, "func unpack%s(h RR_Header, msg []byte, off int) (RR, int, error) {\n", name) fmt.Fprintf(b, "func (rr *%s) unpack(msg []byte, off int) (off1 int, err error) {\n", name)
fmt.Fprintf(b, "rr := new(%s)\n", name) fmt.Fprint(b, `rdStart := off
fmt.Fprint(b, "rr.Hdr = h\n")
fmt.Fprint(b, `if noRdata(h) {
return rr, off, nil
}
var err error
rdStart := off
_ = rdStart _ = rdStart
`) `)
...@@ -198,7 +188,7 @@ _ = rdStart ...@@ -198,7 +188,7 @@ _ = rdStart
o := func(s string) { o := func(s string) {
fmt.Fprintf(b, s, st.Field(i).Name()) fmt.Fprintf(b, s, st.Field(i).Name())
fmt.Fprint(b, `if err != nil { fmt.Fprint(b, `if err != nil {
return rr, off, err return off, err
} }
`) `)
} }
...@@ -218,7 +208,7 @@ return rr, off, err ...@@ -218,7 +208,7 @@ return rr, off, err
log.Fatalln(name, st.Field(i).Name(), st.Tag(i)) log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
} }
fmt.Fprint(b, `if err != nil { fmt.Fprint(b, `if err != nil {
return rr, off, err return off, err
} }
`) `)
continue continue
...@@ -261,6 +251,8 @@ return rr, off, err ...@@ -261,6 +251,8 @@ return rr, off, err
o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
case `dns:"hex"`: case `dns:"hex"`:
o("rr.%s, off, err = unpackStringHex(msg, off, rdStart + int(rr.Hdr.Rdlength))\n") o("rr.%s, off, err = unpackStringHex(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
case `dns:"any"`:
o("rr.%s, off, err = unpackStringAny(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
case `dns:"octet"`: case `dns:"octet"`:
o("rr.%s, off, err = unpackStringOctet(msg, off)\n") o("rr.%s, off, err = unpackStringOctet(msg, off)\n")
case "": case "":
...@@ -284,22 +276,13 @@ return rr, off, err ...@@ -284,22 +276,13 @@ return rr, off, err
// If we've hit len(msg) we return without error. // If we've hit len(msg) we return without error.
if i < st.NumFields()-1 { if i < st.NumFields()-1 {
fmt.Fprintf(b, `if off == len(msg) { fmt.Fprintf(b, `if off == len(msg) {
return rr, off, nil return off, nil
} }
`) `)
} }
} }
fmt.Fprintf(b, "return rr, off, err }\n\n") fmt.Fprintf(b, "return off, nil }\n\n")
}
// Generate typeToUnpack map
fmt.Fprintln(b, "var typeToUnpack = map[uint16]func(RR_Header, []byte, int) (RR, int, error){")
for _, name := range namedTypes {
if name == "RFC3597" {
continue
}
fmt.Fprintf(b, "Type%s: unpack%s,\n", name, name)
} }
fmt.Fprintln(b, "}\n")
// gofmt // gofmt
res, err := format.Source(b.Bytes()) res, err := format.Source(b.Bytes())
......
...@@ -99,34 +99,34 @@ func unpackHeader(msg []byte, off int) (rr RR_Header, off1 int, truncmsg []byte, ...@@ -99,34 +99,34 @@ func unpackHeader(msg []byte, off int) (rr RR_Header, off1 int, truncmsg []byte,
return hdr, off, msg, err return hdr, off, msg, err
} }
// pack packs an RR header, returning the offset to the end of the header. // packHeader packs an RR header, returning the offset to the end of the header.
// See PackDomainName for documentation about the compression. // See PackDomainName for documentation about the compression.
func (hdr RR_Header) pack(msg []byte, off int, compression compressionMap, compress bool) (int, int, error) { func (hdr RR_Header) packHeader(msg []byte, off int, compression compressionMap, compress bool) (int, error) {
if off == len(msg) { if off == len(msg) {
return off, off, nil return off, nil
} }
off, _, err := packDomainName(hdr.Name, msg, off, compression, compress) off, err := packDomainName(hdr.Name, msg, off, compression, compress)
if err != nil { if err != nil {
return off, len(msg), err return len(msg), err
} }
off, err = packUint16(hdr.Rrtype, msg, off) off, err = packUint16(hdr.Rrtype, msg, off)
if err != nil { if err != nil {
return off, len(msg), err return len(msg), err
} }
off, err = packUint16(hdr.Class, msg, off) off, err = packUint16(hdr.Class, msg, off)
if err != nil { if err != nil {
return off, len(msg), err return len(msg), err
} }
off, err = packUint32(hdr.Ttl, msg, off) off, err = packUint32(hdr.Ttl, msg, off)
if err != nil { if err != nil {
return off, len(msg), err return len(msg), err
} }
off, err = packUint16(0, msg, off) // The RDLENGTH field will be set later in packRR. off, err = packUint16(0, msg, off) // The RDLENGTH field will be set later in packRR.
if err != nil { if err != nil {
return off, len(msg), err return len(msg), err
} }
return off, off, nil return off, nil
} }
// helper helper functions. // helper helper functions.
...@@ -177,14 +177,14 @@ func unpackUint8(msg []byte, off int) (i uint8, off1 int, err error) { ...@@ -177,14 +177,14 @@ func unpackUint8(msg []byte, off int) (i uint8, off1 int, err error) {
if off+1 > len(msg) { if off+1 > len(msg) {
return 0, len(msg), &Error{err: "overflow unpacking uint8"} return 0, len(msg), &Error{err: "overflow unpacking uint8"}
} }
return uint8(msg[off]), off + 1, nil return msg[off], off + 1, nil
} }
func packUint8(i uint8, msg []byte, off int) (off1 int, err error) { func packUint8(i uint8, msg []byte, off int) (off1 int, err error) {
if off+1 > len(msg) { if off+1 > len(msg) {
return len(msg), &Error{err: "overflow packing uint8"} return len(msg), &Error{err: "overflow packing uint8"}
} }
msg[off] = byte(i) msg[off] = i
return off + 1, nil return off + 1, nil
} }
...@@ -363,6 +363,22 @@ func packStringHex(s string, msg []byte, off int) (int, error) { ...@@ -363,6 +363,22 @@ func packStringHex(s string, msg []byte, off int) (int, error) {
return off, nil return off, nil
} }
func unpackStringAny(msg []byte, off, end int) (string, int, error) {
if end > len(msg) {
return "", len(msg), &Error{err: "overflow unpacking anything"}
}
return string(msg[off:end]), end, nil
}
func packStringAny(s string, msg []byte, off int) (int, error) {
if off+len(s) > len(msg) {
return len(msg), &Error{err: "overflow packing anything"}
}
copy(msg[off:off+len(s)], s)
off += len(s)
return off, nil
}
func unpackStringTxt(msg []byte, off int) ([]string, int, error) { func unpackStringTxt(msg []byte, off int) ([]string, int, error) {
txt, off, err := unpackTxt(msg, off) txt, off, err := unpackTxt(msg, off)
if err != nil { if err != nil {
...@@ -383,7 +399,7 @@ func packStringTxt(s []string, msg []byte, off int) (int, error) { ...@@ -383,7 +399,7 @@ func packStringTxt(s []string, msg []byte, off int) (int, error) {
func unpackDataOpt(msg []byte, off int) ([]EDNS0, int, error) { func unpackDataOpt(msg []byte, off int) ([]EDNS0, int, error) {
var edns []EDNS0 var edns []EDNS0
Option: Option:
code := uint16(0) var code uint16
if off+4 > len(msg) { if off+4 > len(msg) {
return nil, len(msg), &Error{err: "overflow unpacking opt"} return nil, len(msg), &Error{err: "overflow unpacking opt"}
} }
...@@ -624,7 +640,7 @@ func unpackDataDomainNames(msg []byte, off, end int) ([]string, int, error) { ...@@ -624,7 +640,7 @@ func unpackDataDomainNames(msg []byte, off, end int) ([]string, int, error) {
func packDataDomainNames(names []string, msg []byte, off int, compression compressionMap, compress bool) (int, error) { func packDataDomainNames(names []string, msg []byte, off int, compression compressionMap, compress bool) (int, error) {
var err error var err error
for j := 0; j < len(names); j++ { for j := 0; j < len(names); j++ {
off, _, err = packDomainName(names[j], msg, off, compression, compress) off, err = packDomainName(names[j], msg, off, compression, compress)
if err != nil { if err != nil {
return len(msg), err return len(msg), err
} }
......
...@@ -39,11 +39,12 @@ func mkPrivateRR(rrtype uint16) *PrivateRR { ...@@ -39,11 +39,12 @@ func mkPrivateRR(rrtype uint16) *PrivateRR {
} }
anyrr := rrfunc() anyrr := rrfunc()
switch rr := anyrr.(type) { rr, ok := anyrr.(*PrivateRR)
case *PrivateRR: if !ok {
return rr
}
panic(fmt.Sprintf("dns: RR is not a PrivateRR, TypeToRR[%d] generator returned %T", rrtype, anyrr)) panic(fmt.Sprintf("dns: RR is not a PrivateRR, TypeToRR[%d] generator returned %T", rrtype, anyrr))
}
return rr
} }
// Header return the RR header of r. // Header return the RR header of r.
...@@ -70,52 +71,25 @@ func (r *PrivateRR) copy() RR { ...@@ -70,52 +71,25 @@ func (r *PrivateRR) copy() RR {
return rr return rr
} }
func (r *PrivateRR) pack(msg []byte, off int, compression compressionMap, compress bool) (int, int, error) { func (r *PrivateRR) pack(msg []byte, off int, compression compressionMap, compress bool) (int, error) {
headerEnd, off, err := r.Hdr.pack(msg, off, compression, compress)
if err != nil {
return off, off, err
}
n, err := r.Data.Pack(msg[off:]) n, err := r.Data.Pack(msg[off:])
if err != nil { if err != nil {
return headerEnd, len(msg), err return len(msg), err
} }
off += n off += n
return headerEnd, off, nil return off, nil
} }
// PrivateHandle registers a private resource record type. It requires func (r *PrivateRR) unpack(msg []byte, off int) (int, error) {
// string and numeric representation of private RR type and generator function as argument. off1, err := r.Data.Unpack(msg[off:])
func PrivateHandle(rtypestr string, rtype uint16, generator func() PrivateRdata) {
rtypestr = strings.ToUpper(rtypestr)
TypeToRR[rtype] = func() RR { return &PrivateRR{RR_Header{}, generator()} }
TypeToString[rtype] = rtypestr
StringToType[rtypestr] = rtype
typeToUnpack[rtype] = func(h RR_Header, msg []byte, off int) (RR, int, error) {
if noRdata(h) {
return &h, off, nil
}
var err error
rr := mkPrivateRR(h.Rrtype)
rr.Hdr = h
off1, err := rr.Data.Unpack(msg[off:])
off += off1 off += off1
if err != nil { return off, err
return rr, off, err }
}
return rr, off, err
}
setPrivateRR := func(h RR_Header, c *zlexer, o, f string) (RR, *ParseError, string) {
rr := mkPrivateRR(h.Rrtype)
rr.Hdr = h
func (r *PrivateRR) parse(c *zlexer, origin, file string) *ParseError {
var l lex var l lex
text := make([]string, 0, 2) // could be 0..N elements, median is probably 1 text := make([]string, 0, 2) // could be 0..N elements, median is probably 1
Fetch: Fetch:
for { for {
// TODO(miek): we could also be returning _QUOTE, this might or might not // TODO(miek): we could also be returning _QUOTE, this might or might not
// be an issue (basically parsing TXT becomes hard) // be an issue (basically parsing TXT becomes hard)
...@@ -127,15 +101,24 @@ func PrivateHandle(rtypestr string, rtype uint16, generator func() PrivateRdata) ...@@ -127,15 +101,24 @@ func PrivateHandle(rtypestr string, rtype uint16, generator func() PrivateRdata)
} }
} }
err := rr.Data.Parse(text) err := r.Data.Parse(text)
if err != nil { if err != nil {
return nil, &ParseError{f, err.Error(), l}, "" return &ParseError{file, err.Error(), l}
} }
return rr, nil, "" return nil
} }
typeToparserFunc[rtype] = parserFunc{setPrivateRR, true} func (r1 *PrivateRR) isDuplicate(r2 RR) bool { return false }
// PrivateHandle registers a private resource record type. It requires
// string and numeric representation of private RR type and generator function as argument.
func PrivateHandle(rtypestr string, rtype uint16, generator func() PrivateRdata) {
rtypestr = strings.ToUpper(rtypestr)
TypeToRR[rtype] = func() RR { return &PrivateRR{RR_Header{}, generator()} }
TypeToString[rtype] = rtypestr
StringToType[rtypestr] = rtype
} }
// PrivateHandleRemove removes definitions required to support private RR type. // PrivateHandleRemove removes definitions required to support private RR type.
...@@ -144,8 +127,6 @@ func PrivateHandleRemove(rtype uint16) { ...@@ -144,8 +127,6 @@ func PrivateHandleRemove(rtype uint16) {
if ok { if ok {
delete(TypeToRR, rtype) delete(TypeToRR, rtype)
delete(TypeToString, rtype) delete(TypeToString, rtype)
delete(typeToparserFunc, rtype)
delete(StringToType, rtypestr) delete(StringToType, rtypestr)
delete(typeToUnpack, rtype)
} }
} }
...@@ -17,6 +17,15 @@ func init() { ...@@ -17,6 +17,15 @@ func init() {
StringToRcode["NOTIMPL"] = RcodeNotImplemented StringToRcode["NOTIMPL"] = RcodeNotImplemented
} }
// StringToAlgorithm is the reverse of AlgorithmToString.
var StringToAlgorithm = reverseInt8(AlgorithmToString)
// StringToHash is a map of names to hash IDs.
var StringToHash = reverseInt8(HashToString)
// StringToCertType is the reverseof CertTypeToString.
var StringToCertType = reverseInt16(CertTypeToString)
// Reverse a map // Reverse a map
func reverseInt8(m map[uint8]string) map[string]uint8 { func reverseInt8(m map[uint8]string) map[string]uint8 {
n := make(map[string]uint8, len(m)) n := make(map[string]uint8, len(m))
......
...@@ -15,10 +15,11 @@ func Dedup(rrs []RR, m map[string]RR) []RR { ...@@ -15,10 +15,11 @@ func Dedup(rrs []RR, m map[string]RR) []RR {
for _, r := range rrs { for _, r := range rrs {
key := normalizedString(r) key := normalizedString(r)
keys = append(keys, &key) keys = append(keys, &key)
if _, ok := m[key]; ok { if mr, ok := m[key]; ok {
// Shortest TTL wins. // Shortest TTL wins.
if m[key].Header().Ttl > r.Header().Ttl { rh, mrh := r.Header(), mr.Header()
m[key].Header().Ttl = r.Header().Ttl if mrh.Ttl > rh.Ttl {
mrh.Ttl = rh.Ttl
} }
continue continue
} }
......
...@@ -85,7 +85,6 @@ type lex struct { ...@@ -85,7 +85,6 @@ type lex struct {
torc uint16 // type or class as parsed in the lexer, we only need to look this up in the grammar torc uint16 // type or class as parsed in the lexer, we only need to look this up in the grammar
line int // line in the file line int // line in the file
column int // column in the file column int // column in the file
comment string // any comment text seen
} }
// Token holds the token that are returned when a zone file is parsed. // Token holds the token that are returned when a zone file is parsed.
...@@ -244,8 +243,6 @@ type ZoneParser struct { ...@@ -244,8 +243,6 @@ type ZoneParser struct {
sub *ZoneParser sub *ZoneParser
osFile *os.File osFile *os.File
com string
includeDepth uint8 includeDepth uint8
includeAllowed bool includeAllowed bool
...@@ -318,12 +315,19 @@ func (zp *ZoneParser) setParseError(err string, l lex) (RR, bool) { ...@@ -318,12 +315,19 @@ func (zp *ZoneParser) setParseError(err string, l lex) (RR, bool) {
// Comment returns an optional text comment that occurred alongside // Comment returns an optional text comment that occurred alongside
// the RR. // the RR.
func (zp *ZoneParser) Comment() string { func (zp *ZoneParser) Comment() string {
return zp.com if zp.parseErr != nil {
return ""
}
if zp.sub != nil {
return zp.sub.Comment()
}
return zp.c.Comment()
} }
func (zp *ZoneParser) subNext() (RR, bool) { func (zp *ZoneParser) subNext() (RR, bool) {
if rr, ok := zp.sub.Next(); ok { if rr, ok := zp.sub.Next(); ok {
zp.com = zp.sub.com
return rr, true return rr, true
} }
...@@ -347,8 +351,6 @@ func (zp *ZoneParser) subNext() (RR, bool) { ...@@ -347,8 +351,6 @@ func (zp *ZoneParser) subNext() (RR, bool) {
// error. After Next returns (nil, false), the Err method will return // error. After Next returns (nil, false), the Err method will return
// any error that occurred during parsing. // any error that occurred during parsing.
func (zp *ZoneParser) Next() (RR, bool) { func (zp *ZoneParser) Next() (RR, bool) {
zp.com = ""
if zp.parseErr != nil { if zp.parseErr != nil {
return nil, false return nil, false
} }
...@@ -501,7 +503,7 @@ func (zp *ZoneParser) Next() (RR, bool) { ...@@ -501,7 +503,7 @@ func (zp *ZoneParser) Next() (RR, bool) {
return zp.setParseError("expecting $TTL value, not this...", l) return zp.setParseError("expecting $TTL value, not this...", l)
} }
if e, _ := slurpRemainder(zp.c, zp.file); e != nil { if e := slurpRemainder(zp.c, zp.file); e != nil {
zp.parseErr = e zp.parseErr = e
return nil, false return nil, false
} }
...@@ -525,7 +527,7 @@ func (zp *ZoneParser) Next() (RR, bool) { ...@@ -525,7 +527,7 @@ func (zp *ZoneParser) Next() (RR, bool) {
return zp.setParseError("expecting $ORIGIN value, not this...", l) return zp.setParseError("expecting $ORIGIN value, not this...", l)
} }
if e, _ := slurpRemainder(zp.c, zp.file); e != nil { if e := slurpRemainder(zp.c, zp.file); e != nil {
zp.parseErr = e zp.parseErr = e
return nil, false return nil, false
} }
...@@ -648,7 +650,7 @@ func (zp *ZoneParser) Next() (RR, bool) { ...@@ -648,7 +650,7 @@ func (zp *ZoneParser) Next() (RR, bool) {
st = zExpectRdata st = zExpectRdata
case zExpectRdata: case zExpectRdata:
r, e, c1 := setRR(*h, zp.c, zp.origin, zp.file) r, e := setRR(*h, zp.c, zp.origin, zp.file)
if e != nil { if e != nil {
// If e.lex is nil than we have encounter a unknown RR type // If e.lex is nil than we have encounter a unknown RR type
// in that case we substitute our current lex token // in that case we substitute our current lex token
...@@ -660,7 +662,6 @@ func (zp *ZoneParser) Next() (RR, bool) { ...@@ -660,7 +662,6 @@ func (zp *ZoneParser) Next() (RR, bool) {
return nil, false return nil, false
} }
zp.com = c1
return r, true return r, true
} }
} }
...@@ -678,7 +679,8 @@ type zlexer struct { ...@@ -678,7 +679,8 @@ type zlexer struct {
line int line int
column int column int
com string comBuf string
comment string
l lex l lex
...@@ -767,14 +769,15 @@ func (zl *zlexer) Next() (lex, bool) { ...@@ -767,14 +769,15 @@ func (zl *zlexer) Next() (lex, bool) {
escape bool escape bool
) )
if zl.com != "" { if zl.comBuf != "" {
comi = copy(com[:], zl.com) comi = copy(com[:], zl.comBuf)
zl.com = "" zl.comBuf = ""
} }
zl.comment = ""
for x, ok := zl.readByte(); ok; x, ok = zl.readByte() { for x, ok := zl.readByte(); ok; x, ok = zl.readByte() {
l.line, l.column = zl.line, zl.column l.line, l.column = zl.line, zl.column
l.comment = ""
if stri >= len(str) { if stri >= len(str) {
l.token = "token length insufficient for parsing" l.token = "token length insufficient for parsing"
...@@ -898,7 +901,7 @@ func (zl *zlexer) Next() (lex, bool) { ...@@ -898,7 +901,7 @@ func (zl *zlexer) Next() (lex, bool) {
} }
zl.commt = true zl.commt = true
zl.com = "" zl.comBuf = ""
if comi > 1 { if comi > 1 {
// A newline was previously seen inside a comment that // A newline was previously seen inside a comment that
...@@ -911,7 +914,7 @@ func (zl *zlexer) Next() (lex, bool) { ...@@ -911,7 +914,7 @@ func (zl *zlexer) Next() (lex, bool) {
comi++ comi++
if stri > 0 { if stri > 0 {
zl.com = string(com[:comi]) zl.comBuf = string(com[:comi])
l.value = zString l.value = zString
l.token = string(str[:stri]) l.token = string(str[:stri])
...@@ -947,11 +950,11 @@ func (zl *zlexer) Next() (lex, bool) { ...@@ -947,11 +950,11 @@ func (zl *zlexer) Next() (lex, bool) {
l.value = zNewline l.value = zNewline
l.token = "\n" l.token = "\n"
l.comment = string(com[:comi]) zl.comment = string(com[:comi])
return *l, true return *l, true
} }
zl.com = string(com[:comi]) zl.comBuf = string(com[:comi])
break break
} }
...@@ -977,9 +980,9 @@ func (zl *zlexer) Next() (lex, bool) { ...@@ -977,9 +980,9 @@ func (zl *zlexer) Next() (lex, bool) {
l.value = zNewline l.value = zNewline
l.token = "\n" l.token = "\n"
l.comment = zl.com
zl.com = "" zl.comment = zl.comBuf
zl.comBuf = ""
zl.rrtype = false zl.rrtype = false
zl.owner = true zl.owner = true
...@@ -1115,7 +1118,7 @@ func (zl *zlexer) Next() (lex, bool) { ...@@ -1115,7 +1118,7 @@ func (zl *zlexer) Next() (lex, bool) {
// Send remainder of com // Send remainder of com
l.value = zNewline l.value = zNewline
l.token = "\n" l.token = "\n"
l.comment = string(com[:comi]) zl.comment = string(com[:comi])
if retL != (lex{}) { if retL != (lex{}) {
zl.nextL = true zl.nextL = true
...@@ -1126,7 +1129,6 @@ func (zl *zlexer) Next() (lex, bool) { ...@@ -1126,7 +1129,6 @@ func (zl *zlexer) Next() (lex, bool) {
} }
if zl.brace != 0 { if zl.brace != 0 {
l.comment = "" // in case there was left over string and comment
l.token = "unbalanced brace" l.token = "unbalanced brace"
l.err = true l.err = true
return *l, true return *l, true
...@@ -1135,6 +1137,14 @@ func (zl *zlexer) Next() (lex, bool) { ...@@ -1135,6 +1137,14 @@ func (zl *zlexer) Next() (lex, bool) {
return lex{value: zEOF}, false return lex{value: zEOF}, false
} }
func (zl *zlexer) Comment() string {
if zl.l.err {
return ""
}
return zl.comment
}
// Extract the class number from CLASSxx // Extract the class number from CLASSxx
func classToInt(token string) (uint16, bool) { func classToInt(token string) (uint16, bool) {
offset := 5 offset := 5
...@@ -1163,8 +1173,7 @@ func typeToInt(token string) (uint16, bool) { ...@@ -1163,8 +1173,7 @@ func typeToInt(token string) (uint16, bool) {
// stringToTTL parses things like 2w, 2m, etc, and returns the time in seconds. // stringToTTL parses things like 2w, 2m, etc, and returns the time in seconds.
func stringToTTL(token string) (uint32, bool) { func stringToTTL(token string) (uint32, bool) {
s := uint32(0) var s, i uint32
i := uint32(0)
for _, c := range token { for _, c := range token {
switch c { switch c {
case 's', 'S': case 's', 'S':
...@@ -1252,7 +1261,7 @@ func toAbsoluteName(name, origin string) (absolute string, ok bool) { ...@@ -1252,7 +1261,7 @@ func toAbsoluteName(name, origin string) (absolute string, ok bool) {
} }
// check if name is already absolute // check if name is already absolute
if name[len(name)-1] == '.' { if IsFqdn(name) {
return name, true return name, true
} }
...@@ -1292,24 +1301,21 @@ func locCheckEast(token string, longitude uint32) (uint32, bool) { ...@@ -1292,24 +1301,21 @@ func locCheckEast(token string, longitude uint32) (uint32, bool) {
return longitude, false return longitude, false
} }
// "Eat" the rest of the "line". Return potential comments // "Eat" the rest of the "line"
func slurpRemainder(c *zlexer, f string) (*ParseError, string) { func slurpRemainder(c *zlexer, f string) *ParseError {
l, _ := c.Next() l, _ := c.Next()
com := ""
switch l.value { switch l.value {
case zBlank: case zBlank:
l, _ = c.Next() l, _ = c.Next()
com = l.comment
if l.value != zNewline && l.value != zEOF { if l.value != zNewline && l.value != zEOF {
return &ParseError{f, "garbage after rdata", l}, "" return &ParseError{f, "garbage after rdata", l}
} }
case zNewline: case zNewline:
com = l.comment
case zEOF: case zEOF:
default: default:
return &ParseError{f, "garbage after rdata", l}, "" return &ParseError{f, "garbage after rdata", l}
} }
return nil, com return nil
} }
// Parse a 64 bit-like ipv6 address: "0014:4fff:ff20:ee64" // Parse a 64 bit-like ipv6 address: "0014:4fff:ff20:ee64"
......
This diff is collapsed.
...@@ -162,11 +162,11 @@ type defaultReader struct { ...@@ -162,11 +162,11 @@ type defaultReader struct {
*Server *Server
} }
func (dr *defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) { func (dr defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
return dr.readTCP(conn, timeout) return dr.readTCP(conn, timeout)
} }
func (dr *defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) { func (dr defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
return dr.readUDP(conn, timeout) return dr.readUDP(conn, timeout)
} }
...@@ -463,11 +463,10 @@ var testShutdownNotify *sync.Cond ...@@ -463,11 +463,10 @@ var testShutdownNotify *sync.Cond
// getReadTimeout is a helper func to use system timeout if server did not intend to change it. // getReadTimeout is a helper func to use system timeout if server did not intend to change it.
func (srv *Server) getReadTimeout() time.Duration { func (srv *Server) getReadTimeout() time.Duration {
rtimeout := dnsTimeout
if srv.ReadTimeout != 0 { if srv.ReadTimeout != 0 {
rtimeout = srv.ReadTimeout return srv.ReadTimeout
} }
return rtimeout return dnsTimeout
} }
// serveTCP starts a TCP listener for the server. // serveTCP starts a TCP listener for the server.
...@@ -518,7 +517,7 @@ func (srv *Server) serveUDP(l *net.UDPConn) error { ...@@ -518,7 +517,7 @@ func (srv *Server) serveUDP(l *net.UDPConn) error {
srv.NotifyStartedFunc() srv.NotifyStartedFunc()
} }
reader := Reader(&defaultReader{srv}) reader := Reader(defaultReader{srv})
if srv.DecorateReader != nil { if srv.DecorateReader != nil {
reader = srv.DecorateReader(reader) reader = srv.DecorateReader(reader)
} }
...@@ -588,7 +587,7 @@ func (srv *Server) serve(w *response) { ...@@ -588,7 +587,7 @@ func (srv *Server) serve(w *response) {
w.wg.Done() w.wg.Done()
}() }()
reader := Reader(&defaultReader{srv}) reader := Reader(defaultReader{srv})
if srv.DecorateReader != nil { if srv.DecorateReader != nil {
reader = srv.DecorateReader(reader) reader = srv.DecorateReader(reader)
} }
...@@ -783,8 +782,7 @@ func (w *response) Write(m []byte) (int, error) { ...@@ -783,8 +782,7 @@ func (w *response) Write(m []byte) (int, error) {
switch { switch {
case w.udp != nil: case w.udp != nil:
n, err := WriteToSessionUDP(w.udp, m, w.udpSession) return WriteToSessionUDP(w.udp, m, w.udpSession)
return n, err
case w.tcp != nil: case w.tcp != nil:
lm := len(m) lm := len(m)
if lm < 2 { if lm < 2 {
......
...@@ -21,13 +21,9 @@ func (rr *SIG) Sign(k crypto.Signer, m *Msg) ([]byte, error) { ...@@ -21,13 +21,9 @@ func (rr *SIG) Sign(k crypto.Signer, m *Msg) ([]byte, error) {
if rr.KeyTag == 0 || len(rr.SignerName) == 0 || rr.Algorithm == 0 { if rr.KeyTag == 0 || len(rr.SignerName) == 0 || rr.Algorithm == 0 {
return nil, ErrKey return nil, ErrKey
} }
rr.Header().Rrtype = TypeSIG
rr.Header().Class = ClassANY rr.Hdr = RR_Header{Name: ".", Rrtype: TypeSIG, Class: ClassANY, Ttl: 0}
rr.Header().Ttl = 0 rr.OrigTtl, rr.TypeCovered, rr.Labels = 0, 0, 0
rr.Header().Name = "."
rr.OrigTtl = 0
rr.TypeCovered = 0
rr.Labels = 0
buf := make([]byte, m.Len()+Len(rr)) buf := make([]byte, m.Len()+Len(rr))
mbuf, err := m.PackBuffer(buf) mbuf, err := m.PackBuffer(buf)
...@@ -107,7 +103,7 @@ func (rr *SIG) Verify(k *KEY, buf []byte) error { ...@@ -107,7 +103,7 @@ func (rr *SIG) Verify(k *KEY, buf []byte) error {
anc := binary.BigEndian.Uint16(buf[6:]) anc := binary.BigEndian.Uint16(buf[6:])
auc := binary.BigEndian.Uint16(buf[8:]) auc := binary.BigEndian.Uint16(buf[8:])
adc := binary.BigEndian.Uint16(buf[10:]) adc := binary.BigEndian.Uint16(buf[10:])
offset := 12 offset := headerSize
var err error var err error
for i := uint16(0); i < qdc && offset < buflen; i++ { for i := uint16(0); i < qdc && offset < buflen; i++ {
_, offset, err = UnpackDomainName(buf, offset) _, offset, err = UnpackDomainName(buf, offset)
......
...@@ -23,6 +23,8 @@ type call struct { ...@@ -23,6 +23,8 @@ type call struct {
type singleflight struct { type singleflight struct {
sync.Mutex // protects m sync.Mutex // protects m
m map[string]*call // lazily initialized m map[string]*call // lazily initialized
dontDeleteForTesting bool // this is only to be used by TestConcurrentExchanges
} }
// Do executes and returns the results of the given function, making // Do executes and returns the results of the given function, making
...@@ -49,9 +51,11 @@ func (g *singleflight) Do(key string, fn func() (*Msg, time.Duration, error)) (v ...@@ -49,9 +51,11 @@ func (g *singleflight) Do(key string, fn func() (*Msg, time.Duration, error)) (v
c.val, c.rtt, c.err = fn() c.val, c.rtt, c.err = fn()
c.wg.Done() c.wg.Done()
if !g.dontDeleteForTesting {
g.Lock() g.Lock()
delete(g.m, key) delete(g.m, key)
g.Unlock() g.Unlock()
}
return c.val, c.rtt, c.err, c.dups > 0 return c.val, c.rtt, c.err, c.dups > 0
} }
...@@ -14,10 +14,7 @@ func (r *SMIMEA) Sign(usage, selector, matchingType int, cert *x509.Certificate) ...@@ -14,10 +14,7 @@ func (r *SMIMEA) Sign(usage, selector, matchingType int, cert *x509.Certificate)
r.MatchingType = uint8(matchingType) r.MatchingType = uint8(matchingType)
r.Certificate, err = CertificateToDANE(r.Selector, r.MatchingType, cert) r.Certificate, err = CertificateToDANE(r.Selector, r.MatchingType, cert)
if err != nil {
return err return err
}
return nil
} }
// Verify verifies a SMIMEA record against an SSL certificate. If it is OK // Verify verifies a SMIMEA record against an SSL certificate. If it is OK
......
...@@ -14,10 +14,7 @@ func (r *TLSA) Sign(usage, selector, matchingType int, cert *x509.Certificate) ( ...@@ -14,10 +14,7 @@ func (r *TLSA) Sign(usage, selector, matchingType int, cert *x509.Certificate) (
r.MatchingType = uint8(matchingType) r.MatchingType = uint8(matchingType)
r.Certificate, err = CertificateToDANE(r.Selector, r.MatchingType, cert) r.Certificate, err = CertificateToDANE(r.Selector, r.MatchingType, cert)
if err != nil {
return err return err
}
return nil
} }
// Verify verifies a TLSA record against an SSL certificate. If it is OK // Verify verifies a TLSA record against an SSL certificate. If it is OK
......
...@@ -54,6 +54,10 @@ func (rr *TSIG) String() string { ...@@ -54,6 +54,10 @@ func (rr *TSIG) String() string {
return s return s
} }
func (rr *TSIG) parse(c *zlexer, origin, file string) *ParseError {
panic("dns: internal error: parse should never be called on TSIG")
}
// The following values must be put in wireformat, so that the MAC can be calculated. // The following values must be put in wireformat, so that the MAC can be calculated.
// RFC 2845, section 3.4.2. TSIG Variables. // RFC 2845, section 3.4.2. TSIG Variables.
type tsigWireFmt struct { type tsigWireFmt struct {
...@@ -113,13 +117,13 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s ...@@ -113,13 +117,13 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s
var h hash.Hash var h hash.Hash
switch strings.ToLower(rr.Algorithm) { switch strings.ToLower(rr.Algorithm) {
case HmacMD5: case HmacMD5:
h = hmac.New(md5.New, []byte(rawsecret)) h = hmac.New(md5.New, rawsecret)
case HmacSHA1: case HmacSHA1:
h = hmac.New(sha1.New, []byte(rawsecret)) h = hmac.New(sha1.New, rawsecret)
case HmacSHA256: case HmacSHA256:
h = hmac.New(sha256.New, []byte(rawsecret)) h = hmac.New(sha256.New, rawsecret)
case HmacSHA512: case HmacSHA512:
h = hmac.New(sha512.New, []byte(rawsecret)) h = hmac.New(sha512.New, rawsecret)
default: default:
return nil, "", ErrKeyAlg return nil, "", ErrKeyAlg
} }
...@@ -134,12 +138,11 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s ...@@ -134,12 +138,11 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s
t.OrigId = m.Id t.OrigId = m.Id
tbuf := make([]byte, Len(t)) tbuf := make([]byte, Len(t))
if off, err := PackRR(t, tbuf, 0, nil, false); err == nil { off, err := PackRR(t, tbuf, 0, nil, false)
tbuf = tbuf[:off] // reset to actual size used if err != nil {
} else {
return nil, "", err return nil, "", err
} }
mbuf = append(mbuf, tbuf...) mbuf = append(mbuf, tbuf[:off]...)
// Update the ArCount directly in the buffer. // Update the ArCount directly in the buffer.
binary.BigEndian.PutUint16(mbuf[10:], uint16(len(m.Extra)+1)) binary.BigEndian.PutUint16(mbuf[10:], uint16(len(m.Extra)+1))
......
...@@ -205,9 +205,6 @@ var CertTypeToString = map[uint16]string{ ...@@ -205,9 +205,6 @@ var CertTypeToString = map[uint16]string{
CertOID: "OID", CertOID: "OID",
} }
// StringToCertType is the reverseof CertTypeToString.
var StringToCertType = reverseInt16(CertTypeToString)
//go:generate go run types_generate.go //go:generate go run types_generate.go
// Question holds a DNS question. There can be multiple questions in the // Question holds a DNS question. There can be multiple questions in the
...@@ -241,6 +238,25 @@ type ANY struct { ...@@ -241,6 +238,25 @@ type ANY struct {
func (rr *ANY) String() string { return rr.Hdr.String() } func (rr *ANY) String() string { return rr.Hdr.String() }
func (rr *ANY) parse(c *zlexer, origin, file string) *ParseError {
panic("dns: internal error: parse should never be called on ANY")
}
// NULL RR. See RFC 1035.
type NULL struct {
Hdr RR_Header
Data string `dns:"any"`
}
func (rr *NULL) String() string {
// There is no presentation format; prefix string with a comment.
return ";" + rr.Hdr.String() + rr.Data
}
func (rr *NULL) parse(c *zlexer, origin, file string) *ParseError {
panic("dns: internal error: parse should never be called on NULL")
}
// CNAME RR. See RFC 1034. // CNAME RR. See RFC 1034.
type CNAME struct { type CNAME struct {
Hdr RR_Header Hdr RR_Header
...@@ -1050,10 +1066,16 @@ type TKEY struct { ...@@ -1050,10 +1066,16 @@ type TKEY struct {
// TKEY has no official presentation format, but this will suffice. // TKEY has no official presentation format, but this will suffice.
func (rr *TKEY) String() string { func (rr *TKEY) String() string {
s := "\n;; TKEY PSEUDOSECTION:\n" s := ";" + rr.Hdr.String() +
s += rr.Hdr.String() + " " + rr.Algorithm + " " + " " + rr.Algorithm +
strconv.Itoa(int(rr.KeySize)) + " " + rr.Key + " " + " " + TimeToString(rr.Inception) +
strconv.Itoa(int(rr.OtherLen)) + " " + rr.OtherData " " + TimeToString(rr.Expiration) +
" " + strconv.Itoa(int(rr.Mode)) +
" " + strconv.Itoa(int(rr.Error)) +
" " + strconv.Itoa(int(rr.KeySize)) +
" " + rr.Key +
" " + strconv.Itoa(int(rr.OtherLen)) +
" " + rr.OtherData
return s return s
} }
......
...@@ -193,6 +193,8 @@ func main() { ...@@ -193,6 +193,8 @@ func main() {
fallthrough fallthrough
case st.Tag(i) == `dns:"hex"`: case st.Tag(i) == `dns:"hex"`:
o("l += len(rr.%s)/2 + 1\n") o("l += len(rr.%s)/2 + 1\n")
case st.Tag(i) == `dns:"any"`:
o("l += len(rr.%s)\n")
case st.Tag(i) == `dns:"a"`: case st.Tag(i) == `dns:"a"`:
o("l += net.IPv4len // %s\n") o("l += net.IPv4len // %s\n")
case st.Tag(i) == `dns:"aaaa"`: case st.Tag(i) == `dns:"aaaa"`:
......
...@@ -20,15 +20,13 @@ func ReadFromSessionUDP(conn *net.UDPConn, b []byte) (int, *SessionUDP, error) { ...@@ -20,15 +20,13 @@ func ReadFromSessionUDP(conn *net.UDPConn, b []byte) (int, *SessionUDP, error) {
if err != nil { if err != nil {
return n, nil, err return n, nil, err
} }
session := &SessionUDP{raddr.(*net.UDPAddr)} return n, &SessionUDP{raddr.(*net.UDPAddr)}, err
return n, session, err
} }
// WriteToSessionUDP acts just like net.UDPConn.WriteTo(), but uses a *SessionUDP instead of a net.Addr. // WriteToSessionUDP acts just like net.UDPConn.WriteTo(), but uses a *SessionUDP instead of a net.Addr.
// TODO(fastest963): Once go1.10 is released, use WriteMsgUDP. // TODO(fastest963): Once go1.10 is released, use WriteMsgUDP.
func WriteToSessionUDP(conn *net.UDPConn, b []byte, session *SessionUDP) (int, error) { func WriteToSessionUDP(conn *net.UDPConn, b []byte, session *SessionUDP) (int, error) {
n, err := conn.WriteTo(b, session.raddr) return conn.WriteTo(b, session.raddr)
return n, err
} }
// TODO(fastest963): Once go1.10 is released and we can use *MsgUDP methods // TODO(fastest963): Once go1.10 is released and we can use *MsgUDP methods
......
...@@ -44,7 +44,8 @@ func (u *Msg) RRsetUsed(rr []RR) { ...@@ -44,7 +44,8 @@ func (u *Msg) RRsetUsed(rr []RR) {
u.Answer = make([]RR, 0, len(rr)) u.Answer = make([]RR, 0, len(rr))
} }
for _, r := range rr { for _, r := range rr {
u.Answer = append(u.Answer, &ANY{Hdr: RR_Header{Name: r.Header().Name, Ttl: 0, Rrtype: r.Header().Rrtype, Class: ClassANY}}) h := r.Header()
u.Answer = append(u.Answer, &ANY{Hdr: RR_Header{Name: h.Name, Ttl: 0, Rrtype: h.Rrtype, Class: ClassANY}})
} }
} }
...@@ -55,7 +56,8 @@ func (u *Msg) RRsetNotUsed(rr []RR) { ...@@ -55,7 +56,8 @@ func (u *Msg) RRsetNotUsed(rr []RR) {
u.Answer = make([]RR, 0, len(rr)) u.Answer = make([]RR, 0, len(rr))
} }
for _, r := range rr { for _, r := range rr {
u.Answer = append(u.Answer, &ANY{Hdr: RR_Header{Name: r.Header().Name, Ttl: 0, Rrtype: r.Header().Rrtype, Class: ClassNONE}}) h := r.Header()
u.Answer = append(u.Answer, &ANY{Hdr: RR_Header{Name: h.Name, Ttl: 0, Rrtype: h.Rrtype, Class: ClassNONE}})
} }
} }
...@@ -79,7 +81,8 @@ func (u *Msg) RemoveRRset(rr []RR) { ...@@ -79,7 +81,8 @@ func (u *Msg) RemoveRRset(rr []RR) {
u.Ns = make([]RR, 0, len(rr)) u.Ns = make([]RR, 0, len(rr))
} }
for _, r := range rr { for _, r := range rr {
u.Ns = append(u.Ns, &ANY{Hdr: RR_Header{Name: r.Header().Name, Ttl: 0, Rrtype: r.Header().Rrtype, Class: ClassANY}}) h := r.Header()
u.Ns = append(u.Ns, &ANY{Hdr: RR_Header{Name: h.Name, Ttl: 0, Rrtype: h.Rrtype, Class: ClassANY}})
} }
} }
...@@ -99,8 +102,9 @@ func (u *Msg) Remove(rr []RR) { ...@@ -99,8 +102,9 @@ func (u *Msg) Remove(rr []RR) {
u.Ns = make([]RR, 0, len(rr)) u.Ns = make([]RR, 0, len(rr))
} }
for _, r := range rr { for _, r := range rr {
r.Header().Class = ClassNONE h := r.Header()
r.Header().Ttl = 0 h.Class = ClassNONE
h.Ttl = 0
u.Ns = append(u.Ns, r) u.Ns = append(u.Ns, r)
} }
} }
...@@ -3,7 +3,7 @@ package dns ...@@ -3,7 +3,7 @@ package dns
import "fmt" import "fmt"
// Version is current version of this library. // Version is current version of this library.
var Version = V{1, 1, 1} var Version = V{1, 1, 3}
// V holds the version of this library. // V holds the version of this library.
type V struct { type V struct {
......
...@@ -35,30 +35,36 @@ type Transfer struct { ...@@ -35,30 +35,36 @@ type Transfer struct {
// channel, err := transfer.In(message, master) // channel, err := transfer.In(message, master)
// //
func (t *Transfer) In(q *Msg, a string) (env chan *Envelope, err error) { func (t *Transfer) In(q *Msg, a string) (env chan *Envelope, err error) {
switch q.Question[0].Qtype {
case TypeAXFR, TypeIXFR:
default:
return nil, &Error{"unsupported question type"}
}
timeout := dnsTimeout timeout := dnsTimeout
if t.DialTimeout != 0 { if t.DialTimeout != 0 {
timeout = t.DialTimeout timeout = t.DialTimeout
} }
if t.Conn == nil { if t.Conn == nil {
t.Conn, err = DialTimeout("tcp", a, timeout) t.Conn, err = DialTimeout("tcp", a, timeout)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
if err := t.WriteMsg(q); err != nil { if err := t.WriteMsg(q); err != nil {
return nil, err return nil, err
} }
env = make(chan *Envelope) env = make(chan *Envelope)
go func() { switch q.Question[0].Qtype {
if q.Question[0].Qtype == TypeAXFR { case TypeAXFR:
go t.inAxfr(q, env) go t.inAxfr(q, env)
return case TypeIXFR:
}
if q.Question[0].Qtype == TypeIXFR {
go t.inIxfr(q, env) go t.inIxfr(q, env)
return
} }
}()
return env, nil return env, nil
} }
...@@ -111,7 +117,7 @@ func (t *Transfer) inAxfr(q *Msg, c chan *Envelope) { ...@@ -111,7 +117,7 @@ func (t *Transfer) inAxfr(q *Msg, c chan *Envelope) {
} }
func (t *Transfer) inIxfr(q *Msg, c chan *Envelope) { func (t *Transfer) inIxfr(q *Msg, c chan *Envelope) {
serial := uint32(0) // The first serial seen is the current server serial var serial uint32 // The first serial seen is the current server serial
axfr := true axfr := true
n := 0 n := 0
qser := q.Ns[0].(*SOA).Serial qser := q.Ns[0].(*SOA).Serial
...@@ -237,24 +243,18 @@ func (t *Transfer) WriteMsg(m *Msg) (err error) { ...@@ -237,24 +243,18 @@ func (t *Transfer) WriteMsg(m *Msg) (err error) {
if err != nil { if err != nil {
return err return err
} }
if _, err = t.Write(out); err != nil { _, err = t.Write(out)
return err return err
}
return nil
} }
func isSOAFirst(in *Msg) bool { func isSOAFirst(in *Msg) bool {
if len(in.Answer) > 0 { return len(in.Answer) > 0 &&
return in.Answer[0].Header().Rrtype == TypeSOA in.Answer[0].Header().Rrtype == TypeSOA
}
return false
} }
func isSOALast(in *Msg) bool { func isSOALast(in *Msg) bool {
if len(in.Answer) > 0 { return len(in.Answer) > 0 &&
return in.Answer[len(in.Answer)-1].Header().Rrtype == TypeSOA in.Answer[len(in.Answer)-1].Header().Rrtype == TypeSOA
}
return false
} }
const errXFR = "bad xfr rcode: %d" const errXFR = "bad xfr rcode: %d"
This diff is collapsed.
This diff is collapsed.
...@@ -54,6 +54,7 @@ var TypeToRR = map[uint16]func() RR{ ...@@ -54,6 +54,7 @@ var TypeToRR = map[uint16]func() RR{
TypeNSEC: func() RR { return new(NSEC) }, TypeNSEC: func() RR { return new(NSEC) },
TypeNSEC3: func() RR { return new(NSEC3) }, TypeNSEC3: func() RR { return new(NSEC3) },
TypeNSEC3PARAM: func() RR { return new(NSEC3PARAM) }, TypeNSEC3PARAM: func() RR { return new(NSEC3PARAM) },
TypeNULL: func() RR { return new(NULL) },
TypeOPENPGPKEY: func() RR { return new(OPENPGPKEY) }, TypeOPENPGPKEY: func() RR { return new(OPENPGPKEY) },
TypeOPT: func() RR { return new(OPT) }, TypeOPT: func() RR { return new(OPT) },
TypePTR: func() RR { return new(PTR) }, TypePTR: func() RR { return new(PTR) },
...@@ -209,6 +210,7 @@ func (rr *NSAPPTR) Header() *RR_Header { return &rr.Hdr } ...@@ -209,6 +210,7 @@ func (rr *NSAPPTR) Header() *RR_Header { return &rr.Hdr }
func (rr *NSEC) Header() *RR_Header { return &rr.Hdr } func (rr *NSEC) Header() *RR_Header { return &rr.Hdr }
func (rr *NSEC3) Header() *RR_Header { return &rr.Hdr } func (rr *NSEC3) Header() *RR_Header { return &rr.Hdr }
func (rr *NSEC3PARAM) Header() *RR_Header { return &rr.Hdr } func (rr *NSEC3PARAM) Header() *RR_Header { return &rr.Hdr }
func (rr *NULL) Header() *RR_Header { return &rr.Hdr }
func (rr *OPENPGPKEY) Header() *RR_Header { return &rr.Hdr } func (rr *OPENPGPKEY) Header() *RR_Header { return &rr.Hdr }
func (rr *OPT) Header() *RR_Header { return &rr.Hdr } func (rr *OPT) Header() *RR_Header { return &rr.Hdr }
func (rr *PTR) Header() *RR_Header { return &rr.Hdr } func (rr *PTR) Header() *RR_Header { return &rr.Hdr }
...@@ -473,6 +475,11 @@ func (rr *NSEC3PARAM) len(off int, compression map[string]struct{}) int { ...@@ -473,6 +475,11 @@ func (rr *NSEC3PARAM) len(off int, compression map[string]struct{}) int {
l += len(rr.Salt) / 2 l += len(rr.Salt) / 2
return l return l
} }
func (rr *NULL) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len(off, compression)
l += len(rr.Data)
return l
}
func (rr *OPENPGPKEY) len(off int, compression map[string]struct{}) int { func (rr *OPENPGPKEY) len(off int, compression map[string]struct{}) int {
l := rr.Hdr.len(off, compression) l := rr.Hdr.len(off, compression)
l += base64.StdEncoding.DecodedLen(len(rr.PublicKey)) l += base64.StdEncoding.DecodedLen(len(rr.PublicKey))
...@@ -783,6 +790,9 @@ func (rr *NSEC3) copy() RR { ...@@ -783,6 +790,9 @@ func (rr *NSEC3) copy() RR {
func (rr *NSEC3PARAM) copy() RR { func (rr *NSEC3PARAM) copy() RR {
return &NSEC3PARAM{rr.Hdr, rr.Hash, rr.Flags, rr.Iterations, rr.SaltLength, rr.Salt} return &NSEC3PARAM{rr.Hdr, rr.Hash, rr.Flags, rr.Iterations, rr.SaltLength, rr.Salt}
} }
func (rr *NULL) copy() RR {
return &NULL{rr.Hdr, rr.Data}
}
func (rr *OPENPGPKEY) copy() RR { func (rr *OPENPGPKEY) copy() RR {
return &OPENPGPKEY{rr.Hdr, rr.PublicKey} return &OPENPGPKEY{rr.Hdr, rr.PublicKey}
} }
......
language: go language: go
go_import_path: github.com/pkg/errors go_import_path: github.com/pkg/errors
go: go:
- 1.4.3 - 1.4.x
- 1.5.4 - 1.5.x
- 1.6.2 - 1.6.x
- 1.7.1 - 1.7.x
- 1.8.x
- 1.9.x
- 1.10.x
- 1.11.x
- tip - tip
script: script:
......
# errors [![Travis-CI](https://travis-ci.org/pkg/errors.svg)](https://travis-ci.org/pkg/errors) [![AppVeyor](https://ci.appveyor.com/api/projects/status/b98mptawhudj53ep/branch/master?svg=true)](https://ci.appveyor.com/project/davecheney/errors/branch/master) [![GoDoc](https://godoc.org/github.com/pkg/errors?status.svg)](http://godoc.org/github.com/pkg/errors) [![Report card](https://goreportcard.com/badge/github.com/pkg/errors)](https://goreportcard.com/report/github.com/pkg/errors) # errors [![Travis-CI](https://travis-ci.org/pkg/errors.svg)](https://travis-ci.org/pkg/errors) [![AppVeyor](https://ci.appveyor.com/api/projects/status/b98mptawhudj53ep/branch/master?svg=true)](https://ci.appveyor.com/project/davecheney/errors/branch/master) [![GoDoc](https://godoc.org/github.com/pkg/errors?status.svg)](http://godoc.org/github.com/pkg/errors) [![Report card](https://goreportcard.com/badge/github.com/pkg/errors)](https://goreportcard.com/report/github.com/pkg/errors) [![Sourcegraph](https://sourcegraph.com/github.com/pkg/errors/-/badge.svg)](https://sourcegraph.com/github.com/pkg/errors?badge)
Package errors provides simple error handling primitives. Package errors provides simple error handling primitives.
...@@ -47,6 +47,6 @@ We welcome pull requests, bug fixes and issue reports. With that said, the bar f ...@@ -47,6 +47,6 @@ We welcome pull requests, bug fixes and issue reports. With that said, the bar f
Before proposing a change, please discuss your change by raising an issue. Before proposing a change, please discuss your change by raising an issue.
## Licence ## License
BSD-2-Clause BSD-2-Clause
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
// return err // return err
// } // }
// //
// which applied recursively up the call stack results in error reports // which when applied recursively up the call stack results in error reports
// without context or debugging information. The errors package allows // without context or debugging information. The errors package allows
// programmers to add context to the failure path in their code in a way // programmers to add context to the failure path in their code in a way
// that does not destroy the original value of the error. // that does not destroy the original value of the error.
...@@ -15,16 +15,17 @@ ...@@ -15,16 +15,17 @@
// //
// The errors.Wrap function returns a new error that adds context to the // The errors.Wrap function returns a new error that adds context to the
// original error by recording a stack trace at the point Wrap is called, // original error by recording a stack trace at the point Wrap is called,
// and the supplied message. For example // together with the supplied message. For example
// //
// _, err := ioutil.ReadAll(r) // _, err := ioutil.ReadAll(r)
// if err != nil { // if err != nil {
// return errors.Wrap(err, "read failed") // return errors.Wrap(err, "read failed")
// } // }
// //
// If additional control is required the errors.WithStack and errors.WithMessage // If additional control is required, the errors.WithStack and
// functions destructure errors.Wrap into its component operations of annotating // errors.WithMessage functions destructure errors.Wrap into its component
// an error with a stack trace and an a message, respectively. // operations: annotating an error with a stack trace and with a message,
// respectively.
// //
// Retrieving the cause of an error // Retrieving the cause of an error
// //
...@@ -38,7 +39,7 @@ ...@@ -38,7 +39,7 @@
// } // }
// //
// can be inspected by errors.Cause. errors.Cause will recursively retrieve // can be inspected by errors.Cause. errors.Cause will recursively retrieve
// the topmost error which does not implement causer, which is assumed to be // the topmost error that does not implement causer, which is assumed to be
// the original cause. For example: // the original cause. For example:
// //
// switch err := errors.Cause(err).(type) { // switch err := errors.Cause(err).(type) {
...@@ -48,16 +49,16 @@ ...@@ -48,16 +49,16 @@
// // unknown error // // unknown error
// } // }
// //
// causer interface is not exported by this package, but is considered a part // Although the causer interface is not exported by this package, it is
// of stable public API. // considered a part of its stable public interface.
// //
// Formatted printing of errors // Formatted printing of errors
// //
// All error values returned from this package implement fmt.Formatter and can // All error values returned from this package implement fmt.Formatter and can
// be formatted by the fmt package. The following verbs are supported // be formatted by the fmt package. The following verbs are supported:
// //
// %s print the error. If the error has a Cause it will be // %s print the error. If the error has a Cause it will be
// printed recursively // printed recursively.
// %v see %s // %v see %s
// %+v extended format. Each Frame of the error's StackTrace will // %+v extended format. Each Frame of the error's StackTrace will
// be printed in detail. // be printed in detail.
...@@ -65,13 +66,13 @@ ...@@ -65,13 +66,13 @@
// Retrieving the stack trace of an error or wrapper // Retrieving the stack trace of an error or wrapper
// //
// New, Errorf, Wrap, and Wrapf record a stack trace at the point they are // New, Errorf, Wrap, and Wrapf record a stack trace at the point they are
// invoked. This information can be retrieved with the following interface. // invoked. This information can be retrieved with the following interface:
// //
// type stackTracer interface { // type stackTracer interface {
// StackTrace() errors.StackTrace // StackTrace() errors.StackTrace
// } // }
// //
// Where errors.StackTrace is defined as // The returned errors.StackTrace type is defined as
// //
// type StackTrace []Frame // type StackTrace []Frame
// //
...@@ -85,8 +86,8 @@ ...@@ -85,8 +86,8 @@
// } // }
// } // }
// //
// stackTracer interface is not exported by this package, but is considered a part // Although the stackTracer interface is not exported by this package, it is
// of stable public API. // considered a part of its stable public interface.
// //
// See the documentation for Frame.Format for more details. // See the documentation for Frame.Format for more details.
package errors package errors
...@@ -192,7 +193,7 @@ func Wrap(err error, message string) error { ...@@ -192,7 +193,7 @@ func Wrap(err error, message string) error {
} }
// Wrapf returns an error annotating err with a stack trace // Wrapf returns an error annotating err with a stack trace
// at the point Wrapf is call, and the format specifier. // at the point Wrapf is called, and the format specifier.
// If err is nil, Wrapf returns nil. // If err is nil, Wrapf returns nil.
func Wrapf(err error, format string, args ...interface{}) error { func Wrapf(err error, format string, args ...interface{}) error {
if err == nil { if err == nil {
...@@ -220,6 +221,18 @@ func WithMessage(err error, message string) error { ...@@ -220,6 +221,18 @@ func WithMessage(err error, message string) error {
} }
} }
// WithMessagef annotates err with the format specifier.
// If err is nil, WithMessagef returns nil.
func WithMessagef(err error, format string, args ...interface{}) error {
if err == nil {
return nil
}
return &withMessage{
cause: err,
msg: fmt.Sprintf(format, args...),
}
}
type withMessage struct { type withMessage struct {
cause error cause error
msg string msg string
......
...@@ -46,7 +46,8 @@ func (f Frame) line() int { ...@@ -46,7 +46,8 @@ func (f Frame) line() int {
// //
// Format accepts flags that alter the printing of some verbs, as follows: // Format accepts flags that alter the printing of some verbs, as follows:
// //
// %+s path of source file relative to the compile time GOPATH // %+s function name and path of source file relative to the compile time
// GOPATH separated by \n\t (<funcname>\n\t<path>)
// %+v equivalent to %+s:%d // %+v equivalent to %+s:%d
func (f Frame) Format(s fmt.State, verb rune) { func (f Frame) Format(s fmt.State, verb rune) {
switch verb { switch verb {
...@@ -79,6 +80,14 @@ func (f Frame) Format(s fmt.State, verb rune) { ...@@ -79,6 +80,14 @@ func (f Frame) Format(s fmt.State, verb rune) {
// StackTrace is stack of Frames from innermost (newest) to outermost (oldest). // StackTrace is stack of Frames from innermost (newest) to outermost (oldest).
type StackTrace []Frame type StackTrace []Frame
// Format formats the stack of Frames according to the fmt.Formatter interface.
//
// %s lists source files for each Frame in the stack
// %v lists the source file and line number for each Frame in the stack
//
// Format accepts flags that alter the printing of some verbs, as follows:
//
// %+v Prints filename, function, and line number for each Frame in the stack.
func (st StackTrace) Format(s fmt.State, verb rune) { func (st StackTrace) Format(s fmt.State, verb rune) {
switch verb { switch verb {
case 'v': case 'v':
...@@ -136,43 +145,3 @@ func funcname(name string) string { ...@@ -136,43 +145,3 @@ func funcname(name string) string {
i = strings.Index(name, ".") i = strings.Index(name, ".")
return name[i+1:] return name[i+1:]
} }
func trimGOPATH(name, file string) string {
// Here we want to get the source file path relative to the compile time
// GOPATH. As of Go 1.6.x there is no direct way to know the compiled
// GOPATH at runtime, but we can infer the number of path segments in the
// GOPATH. We note that fn.Name() returns the function name qualified by
// the import path, which does not include the GOPATH. Thus we can trim
// segments from the beginning of the file path until the number of path
// separators remaining is one more than the number of path separators in
// the function name. For example, given:
//
// GOPATH /home/user
// file /home/user/src/pkg/sub/file.go
// fn.Name() pkg/sub.Type.Method
//
// We want to produce:
//
// pkg/sub/file.go
//
// From this we can easily see that fn.Name() has one less path separator
// than our desired output. We count separators from the end of the file
// path until it finds two more than in the function name and then move
// one character forward to preserve the initial path segment without a
// leading separator.
const sep = "/"
goal := strings.Count(name, sep) + 2
i := len(file)
for n := 0; n < goal; n++ {
i = strings.LastIndex(file[:i], sep)
if i == -1 {
// not enough separators found, set i so that the slice expression
// below leaves file unmodified
i = -len(sep)
break
}
}
// get back to 0 or trim the leading separator
file = file[i+len(sep):]
return file
}
This diff is collapsed.
This diff is collapsed.
// Package core implements essential parts of Shadowsocks
package core
package core
import "net"
func ListenPacket(network, address string, ciph PacketConnCipher) (net.PacketConn, error) {
c, err := net.ListenPacket(network, address)
return ciph.PacketConn(c), err
}
package core
import "net"
type listener struct {
net.Listener
StreamConnCipher
}
func Listen(network, address string, ciph StreamConnCipher) (net.Listener, error) {
l, err := net.Listen(network, address)
return &listener{l, ciph}, err
}
func (l *listener) Accept() (net.Conn, error) {
c, err := l.Listener.Accept()
return l.StreamConn(c), err
}
func Dial(network, address string, ciph StreamConnCipher) (net.Conn, error) {
c, err := net.Dial(network, address)
return ciph.StreamConn(c), err
}
package shadowaead
import (
"crypto/aes"
"crypto/cipher"
"crypto/sha1"
"io"
"strconv"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/hkdf"
)
type Cipher interface {
KeySize() int
SaltSize() int
Encrypter(salt []byte) (cipher.AEAD, error)
Decrypter(salt []byte) (cipher.AEAD, error)
}
type KeySizeError int
func (e KeySizeError) Error() string {
return "key size error: need " + strconv.Itoa(int(e)) + " bytes"
}
func hkdfSHA1(secret, salt, info, outkey []byte) {
r := hkdf.New(sha1.New, secret, salt, info)
if _, err := io.ReadFull(r, outkey); err != nil {
panic(err) // should never happen
}
}
type metaCipher struct {
psk []byte
makeAEAD func(key []byte) (cipher.AEAD, error)
}
func (a *metaCipher) KeySize() int { return len(a.psk) }
func (a *metaCipher) SaltSize() int {
if ks := a.KeySize(); ks > 16 {
return ks
}
return 16
}
func (a *metaCipher) Encrypter(salt []byte) (cipher.AEAD, error) {
subkey := make([]byte, a.KeySize())
hkdfSHA1(a.psk, salt, []byte("ss-subkey"), subkey)
return a.makeAEAD(subkey)
}
func (a *metaCipher) Decrypter(salt []byte) (cipher.AEAD, error) {
subkey := make([]byte, a.KeySize())
hkdfSHA1(a.psk, salt, []byte("ss-subkey"), subkey)
return a.makeAEAD(subkey)
}
func aesGCM(key []byte) (cipher.AEAD, error) {
blk, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return cipher.NewGCM(blk)
}
// AESGCM creates a new Cipher with a pre-shared key. len(psk) must be
// one of 16, 24, or 32 to select AES-128/196/256-GCM.
func AESGCM(psk []byte) (Cipher, error) {
switch l := len(psk); l {
case 16, 24, 32: // AES 128/196/256
default:
return nil, aes.KeySizeError(l)
}
return &metaCipher{psk: psk, makeAEAD: aesGCM}, nil
}
// Chacha20Poly1305 creates a new Cipher with a pre-shared key. len(psk)
// must be 32.
func Chacha20Poly1305(psk []byte) (Cipher, error) {
if len(psk) != chacha20poly1305.KeySize {
return nil, KeySizeError(chacha20poly1305.KeySize)
}
return &metaCipher{psk: psk, makeAEAD: chacha20poly1305.New}, nil
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment