Skip to content

Commit

Permalink
refactor logic little bit (#304)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tantalor93 authored Oct 28, 2024
1 parent 5cdb91a commit dc54142
Show file tree
Hide file tree
Showing 5 changed files with 240 additions and 203 deletions.
232 changes: 34 additions & 198 deletions pkg/dnsbench/benchmark.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package dnsbench
import (
"bufio"
"context"
"crypto/tls"
"encoding/hex"
"errors"
"fmt"
Expand All @@ -22,13 +21,9 @@ import (

"github.com/fatih/color"
"github.com/miekg/dns"
"github.com/quic-go/quic-go/http3"
"github.com/schollz/progressbar/v3"
"github.com/tantalor93/dnspyre/v3/pkg/printutils"
"github.com/tantalor93/doh-go/doh"
"github.com/tantalor93/doq-go/doq"
"go.uber.org/ratelimit"
"golang.org/x/net/http2"
)

var client = http.Client{
Expand Down Expand Up @@ -280,36 +275,6 @@ func (b *Benchmark) init() error {
return nil
}

func (b *Benchmark) parseRequestDelay() error {
if len(b.RequestDelay) == 0 {
return nil
}
requestDelayRegex := regexp.MustCompile(`^(\d+(?:ms|ns|[smhdw]))(?:-(\d+(?:ms|ns|[smhdw])))?$`)

durations := requestDelayRegex.FindStringSubmatch(b.RequestDelay)
if len(durations) != 3 {
return fmt.Errorf("'%s' has unexpected format, either <GO duration> or <GO duration>-<Go duration> is expected", b.RequestDelay)
}
if len(durations[1]) != 0 {
durationStart, err := time.ParseDuration(durations[1])
if err != nil {
return err
}
b.requestDelayStart = durationStart
}
if len(durations[2]) != 0 {
durationEnd, err := time.ParseDuration(durations[2])
if err != nil {
return err
}
b.requestDelayEnd = durationEnd
}
if b.requestDelayEnd > 0 && b.requestDelayStart > 0 && b.requestDelayEnd-b.requestDelayStart <= 0 {
return fmt.Errorf("'%s' is invalid interval, start should be strictly less than end", b.RequestDelay)
}
return nil
}

// Run executes benchmark, if benchmark is unable to start the error is returned, otherwise array of results from parallel benchmark goroutines is returned.
func (b *Benchmark) Run(ctx context.Context) ([]*ResultStats, error) {
color.NoColor = !b.Color
Expand Down Expand Up @@ -347,7 +312,7 @@ func (b *Benchmark) Run(ctx context.Context) ([]*ResultStats, error) {
qTypes = append(qTypes, dns.StringToType[v])
}

queryFactory := b.queryFactory()
queryFactory := workerQueryFactory(b)

limits := ""
var limit ratelimit.Limiter
Expand Down Expand Up @@ -451,7 +416,8 @@ func (b *Benchmark) Run(ctx context.Context) ([]*ResultStats, error) {
if b.useQuic {
req.Id = 0
} else {
req.Id = uint16(rando.Uint32())
// nolint:gosec
req.Id = uint16(rand.Intn(1 << 16))
}

if b.Edns0 > 0 {
Expand Down Expand Up @@ -480,7 +446,7 @@ func (b *Benchmark) Run(ctx context.Context) ([]*ResultStats, error) {
}
dur := time.Since(start)
if b.RequestLogEnabled {
b.logRequest(workerID, req, resp, err, dur)
logRequest(workerID, req, resp, err, dur)
}
st.record(&req, resp, err, start, dur)

Expand Down Expand Up @@ -568,118 +534,6 @@ func (b *Benchmark) network() string {
return network
}

func (b *Benchmark) queryFactory() func() queryFunc {
// for DoH and DoQ we want to share the client, for plain DNS and DoT we want to have each worker have separate connection
// that is maintained by the worker, this allows DoT and plain DNS protocols to supports counting queries per connection
// and granular control of the connection
switch {
case b.useDoH:
if b.SeparateWorkerConnections {
return func() queryFunc {
return b.dohQuery()
}
}
dohQuery := b.dohQuery()
return func() queryFunc {
return dohQuery
}
case b.useQuic:
h, _, _ := net.SplitHostPort(b.Server)
if b.SeparateWorkerConnections {
return func() queryFunc {
// nolint:gosec
quicClient := doq.NewClient(b.Server,
doq.WithTLSConfig(&tls.Config{ServerName: h, InsecureSkipVerify: b.Insecure}),
doq.WithReadTimeout(b.ReadTimeout),
doq.WithWriteTimeout(b.WriteTimeout),
doq.WithConnectTimeout(b.ConnectTimeout),
)
return quicClient.Send
}
}
// nolint:gosec
quicClient := doq.NewClient(b.Server,
doq.WithTLSConfig(&tls.Config{ServerName: h, InsecureSkipVerify: b.Insecure}),
doq.WithReadTimeout(b.ReadTimeout),
doq.WithWriteTimeout(b.WriteTimeout),
doq.WithConnectTimeout(b.ConnectTimeout),
)
return func() queryFunc {
return quicClient.Send
}
default:
queryFactory := func() queryFunc {
dnsClient := b.getDNSClient()
var co *dns.Conn
var i int64
return func(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
if co != nil && b.QperConn > 0 && i%b.QperConn == 0 {
co.Close()
co = nil
}
i++
if co == nil {
var err error
co, err = dnsClient.DialContext(ctx, b.Server)
if err != nil {
return nil, err
}
}
r, _, err := dnsClient.ExchangeWithConnContext(ctx, msg, co)
if err != nil {
co.Close()
co = nil
return nil, err
}
return r, nil
}
}
return queryFactory
}
}

func (b *Benchmark) logRequest(workerID uint32, req dns.Msg, resp *dns.Msg, err error, dur time.Duration) {
rcode := "<nil>"
respid := "<nil>"
respflags := "<nil>"
if resp != nil {
rcode = dns.RcodeToString[resp.Rcode]
respid = fmt.Sprint(resp.Id)
respflags = getFlags(resp)
}
log.Printf("worker:[%v] reqid:[%d] qname:[%s] qtype:[%s] respid:[%s] rcode:[%s] respflags:[%s] err:[%v] duration:[%v]",
workerID, req.Id, req.Question[0].Name, dns.TypeToString[req.Question[0].Qtype], respid, rcode, respflags, err, dur)
}

func getFlags(resp *dns.Msg) string {
respflags := ""
if resp.Response {
respflags += "qr"
}
if resp.Authoritative {
respflags += " aa"
}
if resp.Truncated {
respflags += " tc"
}
if resp.RecursionDesired {
respflags += " rd"
}
if resp.RecursionAvailable {
respflags += " ra"
}
if resp.Zero {
respflags += " z"
}
if resp.AuthenticatedData {
respflags += " ad"
}
if resp.CheckingDisabled {
respflags += " cd"
}
return respflags
}

func addEdnsOpt(m *dns.Msg, ednsOpt string) {
o := m.IsEdns0()
if o == nil {
Expand Down Expand Up @@ -723,54 +577,6 @@ func isHTTPUrl(s string) (ok bool, network string) {
return false, ""
}

func (b *Benchmark) dohQuery() queryFunc {
var tr http.RoundTripper
switch b.DohProtocol {
case HTTP3Proto:
// nolint:gosec
tr = &http3.RoundTripper{TLSClientConfig: &tls.Config{InsecureSkipVerify: b.Insecure}}
case HTTP2Proto:
// nolint:gosec
tr = &http2.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: b.Insecure}}
case HTTP1Proto:
fallthrough
default:
// nolint:gosec
tr = &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: b.Insecure}}
}
c := http.Client{Transport: tr, Timeout: b.ReadTimeout}
dohClient := doh.NewClient(b.Server, doh.WithHTTPClient(&c))

switch b.DohMethod {
case PostHTTPMethod:
return dohClient.SendViaPost
case GetHTTPMethod:
return dohClient.SendViaGet
default:
return dohClient.SendViaPost
}
}

func (b *Benchmark) getDNSClient() *dns.Client {
network := UDPTransport
if b.TCP {
network = TCPTransport
}
if b.DOT {
network = TLSTransport
}

return &dns.Client{
Net: network,
DialTimeout: b.ConnectTimeout,
WriteTimeout: b.WriteTimeout,
ReadTimeout: b.ReadTimeout,
Timeout: b.RequestTimeout,
// nolint:gosec
TLSConfig: &tls.Config{InsecureSkipVerify: b.Insecure},
}
}

func (b *Benchmark) prepareQuestions() ([]string, error) {
var questions []string
for _, q := range b.Queries {
Expand Down Expand Up @@ -807,3 +613,33 @@ func checkLimit(ctx context.Context, limiter ratelimit.Limiter) error {
return ctx.Err()
}
}

func (b *Benchmark) parseRequestDelay() error {
if len(b.RequestDelay) == 0 {
return nil
}
requestDelayRegex := regexp.MustCompile(`^(\d+(?:ms|ns|[smhdw]))(?:-(\d+(?:ms|ns|[smhdw])))?$`)

durations := requestDelayRegex.FindStringSubmatch(b.RequestDelay)
if len(durations) != 3 {
return fmt.Errorf("'%s' has unexpected format, either <GO duration> or <GO duration>-<Go duration> is expected", b.RequestDelay)
}
if len(durations[1]) != 0 {
durationStart, err := time.ParseDuration(durations[1])
if err != nil {
return err
}
b.requestDelayStart = durationStart
}
if len(durations[2]) != 0 {
durationEnd, err := time.ParseDuration(durations[2])
if err != nil {
return err
}
b.requestDelayEnd = durationEnd
}
if b.requestDelayEnd > 0 && b.requestDelayStart > 0 && b.requestDelayEnd-b.requestDelayStart <= 0 {
return fmt.Errorf("'%s' is invalid interval, start should be strictly less than end", b.RequestDelay)
}
return nil
}
2 changes: 2 additions & 0 deletions pkg/dnsbench/benchmark_doq_api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ func (d *doqServer) start() {
return
}
packWithPrefix := make([]byte, 2+len(pack))
// nolint:gosec
binary.BigEndian.PutUint16(packWithPrefix, uint16(len(pack)))
copy(packWithPrefix[2:], pack)
_, _ = stream.Write(packWithPrefix)
Expand Down Expand Up @@ -346,6 +347,7 @@ func readDOQMessage(r io.Reader) (*dns.Msg, error) {
// A client or server receives a STREAM FIN before receiving all the bytes
// for a message indicated in the 2-octet length field.
// See https://www.rfc-editor.org/rfc/rfc9250#section-4.3.3-2.2
// nolint:gosec
if size != uint16(len(buf)) {
return nil, fmt.Errorf("message size does not match 2-byte prefix")
}
Expand Down
Loading

0 comments on commit dc54142

Please sign in to comment.