-
-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #970 from apernet/wip-speedtest
feat: built-in speed test client & server
- Loading branch information
Showing
9 changed files
with
1,027 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,6 +36,8 @@ bandwidth: | |
|
||
ignoreClientBandwidth: true | ||
|
||
speedTest: true | ||
|
||
disableUDP: true | ||
udpIdleTimeout: 120s | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
package cmd | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
"time" | ||
|
||
"github.com/spf13/cobra" | ||
"github.com/spf13/viper" | ||
"go.uber.org/zap" | ||
|
||
"github.com/apernet/hysteria/core/client" | ||
hyErrors "github.com/apernet/hysteria/core/errors" | ||
"github.com/apernet/hysteria/extras/outbounds" | ||
"github.com/apernet/hysteria/extras/outbounds/speedtest" | ||
) | ||
|
||
var ( | ||
skipDownload bool | ||
skipUpload bool | ||
dataSize uint32 | ||
useBytes bool | ||
|
||
speedtestAddr = fmt.Sprintf("%s:%d", outbounds.SpeedtestDest, 0) | ||
) | ||
|
||
// speedtestCmd represents the speedtest command | ||
var speedtestCmd = &cobra.Command{ | ||
Use: "speedtest", | ||
Short: "Speed test mode", | ||
Long: "Perform a speed test through the proxy server. The server must have speed test support enabled.", | ||
Run: runSpeedtest, | ||
} | ||
|
||
func init() { | ||
initSpeedtestFlags() | ||
rootCmd.AddCommand(speedtestCmd) | ||
} | ||
|
||
func initSpeedtestFlags() { | ||
speedtestCmd.Flags().BoolVar(&skipDownload, "skip-download", false, "Skip download test") | ||
speedtestCmd.Flags().BoolVar(&skipUpload, "skip-upload", false, "Skip upload test") | ||
speedtestCmd.Flags().Uint32Var(&dataSize, "data-size", 1024*1024*100, "Data size for download and upload tests") | ||
speedtestCmd.Flags().BoolVar(&useBytes, "use-bytes", false, "Use bytes per second instead of bits per second") | ||
} | ||
|
||
func runSpeedtest(cmd *cobra.Command, args []string) { | ||
logger.Info("speed test mode") | ||
|
||
if err := viper.ReadInConfig(); err != nil { | ||
logger.Fatal("failed to read client config", zap.Error(err)) | ||
} | ||
var config clientConfig | ||
if err := viper.Unmarshal(&config); err != nil { | ||
logger.Fatal("failed to parse client config", zap.Error(err)) | ||
} | ||
hyConfig, err := config.Config() | ||
if err != nil { | ||
logger.Fatal("failed to load client config", zap.Error(err)) | ||
} | ||
|
||
c, info, err := client.NewClient(hyConfig) | ||
if err != nil { | ||
logger.Fatal("failed to initialize client", zap.Error(err)) | ||
} | ||
defer c.Close() | ||
logger.Info("connected to server", | ||
zap.Bool("udpEnabled", info.UDPEnabled), | ||
zap.Uint64("tx", info.Tx)) | ||
|
||
if !skipDownload { | ||
runDownloadTest(c) | ||
} | ||
if !skipUpload { | ||
runUploadTest(c) | ||
} | ||
} | ||
|
||
func runDownloadTest(c client.Client) { | ||
logger.Info("performing download test") | ||
downConn, err := c.TCP(speedtestAddr) | ||
if err != nil { | ||
if errors.As(err, &hyErrors.DialError{}) { | ||
logger.Fatal("failed to connect (server may not support speed test)", zap.Error(err)) | ||
} else { | ||
logger.Fatal("failed to connect", zap.Error(err)) | ||
} | ||
} | ||
defer downConn.Close() | ||
|
||
downClient := &speedtest.Client{Conn: downConn} | ||
currentTotal := uint32(0) | ||
err = downClient.Download(dataSize, func(d time.Duration, b uint32, done bool) { | ||
if !done { | ||
currentTotal += b | ||
logger.Info("downloading", | ||
zap.Uint32("bytes", b), | ||
zap.String("progress", fmt.Sprintf("%.2f%%", float64(currentTotal)/float64(dataSize)*100)), | ||
zap.String("speed", formatSpeed(b, d, useBytes))) | ||
} else { | ||
logger.Info("download complete", | ||
zap.Uint32("bytes", b), | ||
zap.String("speed", formatSpeed(b, d, useBytes))) | ||
} | ||
}) | ||
if err != nil { | ||
logger.Fatal("download test failed", zap.Error(err)) | ||
} | ||
logger.Info("download test complete") | ||
} | ||
|
||
func runUploadTest(c client.Client) { | ||
logger.Info("performing upload test") | ||
upConn, err := c.TCP(speedtestAddr) | ||
if err != nil { | ||
if errors.As(err, &hyErrors.DialError{}) { | ||
logger.Fatal("failed to connect (server may not support speed test)", zap.Error(err)) | ||
} else { | ||
logger.Fatal("failed to connect", zap.Error(err)) | ||
} | ||
} | ||
defer upConn.Close() | ||
|
||
upClient := &speedtest.Client{Conn: upConn} | ||
currentTotal := uint32(0) | ||
err = upClient.Upload(dataSize, func(d time.Duration, b uint32, done bool) { | ||
if !done { | ||
currentTotal += b | ||
logger.Info("uploading", | ||
zap.Uint32("bytes", b), | ||
zap.String("progress", fmt.Sprintf("%.2f%%", float64(currentTotal)/float64(dataSize)*100)), | ||
zap.String("speed", formatSpeed(b, d, useBytes))) | ||
} else { | ||
logger.Info("upload complete", | ||
zap.Uint32("bytes", b), | ||
zap.String("speed", formatSpeed(b, d, useBytes))) | ||
} | ||
}) | ||
if err != nil { | ||
logger.Fatal("upload test failed", zap.Error(err)) | ||
} | ||
logger.Info("upload test complete") | ||
} | ||
|
||
func formatSpeed(bytes uint32, duration time.Duration, useBytes bool) string { | ||
speed := float64(bytes) / duration.Seconds() | ||
var units []string | ||
if useBytes { | ||
units = []string{"B/s", "KB/s", "MB/s", "GB/s"} | ||
} else { | ||
units = []string{"bps", "Kbps", "Mbps", "Gbps"} | ||
speed *= 8 | ||
} | ||
unitIndex := 0 | ||
for speed > 1024 && unitIndex < len(units)-1 { | ||
speed /= 1024 | ||
unitIndex++ | ||
} | ||
return fmt.Sprintf("%.2f %s", speed, units[unitIndex]) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
package outbounds | ||
|
||
import ( | ||
"net" | ||
|
||
"github.com/apernet/hysteria/extras/outbounds/speedtest" | ||
) | ||
|
||
const ( | ||
SpeedtestDest = "@SpeedTest" | ||
) | ||
|
||
// speedtestHandler is a PluggableOutbound that handles speed test requests. | ||
// It's used to intercept speed test requests and return a pseudo connection that | ||
// implements the speed test protocol. | ||
type speedtestHandler struct { | ||
Next PluggableOutbound | ||
} | ||
|
||
func NewSpeedtestHandler(next PluggableOutbound) PluggableOutbound { | ||
return &speedtestHandler{ | ||
Next: next, | ||
} | ||
} | ||
|
||
func (s *speedtestHandler) TCP(reqAddr *AddrEx) (net.Conn, error) { | ||
if reqAddr.Host == SpeedtestDest { | ||
return speedtest.NewServerConn(), nil | ||
} else { | ||
return s.Next.TCP(reqAddr) | ||
} | ||
} | ||
|
||
func (s *speedtestHandler) UDP(reqAddr *AddrEx) (UDPConn, error) { | ||
return s.Next.UDP(reqAddr) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
package speedtest | ||
|
||
import ( | ||
"fmt" | ||
"io" | ||
"net" | ||
"sync/atomic" | ||
"time" | ||
) | ||
|
||
type Client struct { | ||
Conn net.Conn | ||
} | ||
|
||
// Download requests the server to send l bytes of data. | ||
// The callback function cb is called every second with the time since the last call, | ||
// and the number of bytes received in that time. | ||
func (c *Client) Download(l uint32, cb func(time.Duration, uint32, bool)) error { | ||
err := writeDownloadRequest(c.Conn, l) | ||
if err != nil { | ||
return err | ||
} | ||
ok, msg, err := readDownloadResponse(c.Conn) | ||
if err != nil { | ||
return err | ||
} | ||
if !ok { | ||
return fmt.Errorf("server rejected download request: %s", msg) | ||
} | ||
var counter uint32 | ||
stopChan := make(chan struct{}) | ||
defer close(stopChan) | ||
// Call the callback function every second, | ||
// with the time since the last call and the number of bytes received in that time. | ||
go func() { | ||
ticker := time.NewTicker(time.Second) | ||
defer ticker.Stop() | ||
t := time.Now() | ||
for { | ||
select { | ||
case <-stopChan: | ||
return | ||
case <-ticker.C: | ||
cb(time.Since(t), atomic.SwapUint32(&counter, 0), false) | ||
t = time.Now() | ||
} | ||
} | ||
}() | ||
buf := make([]byte, chunkSize) | ||
startTime := time.Now() | ||
remaining := l | ||
for remaining > 0 { | ||
n := remaining | ||
if n > chunkSize { | ||
n = chunkSize | ||
} | ||
rn, err := c.Conn.Read(buf[:n]) | ||
remaining -= uint32(rn) | ||
atomic.AddUint32(&counter, uint32(rn)) | ||
if err != nil && !(remaining == 0 && err == io.EOF) { | ||
return err | ||
} | ||
} | ||
// One last call to the callback function to report the total time and bytes received. | ||
cb(time.Since(startTime), l, true) | ||
return nil | ||
} | ||
|
||
// Upload requests the server to receive l bytes of data. | ||
// The callback function cb is called every second with the time since the last call, | ||
// and the number of bytes sent in that time. | ||
func (c *Client) Upload(l uint32, cb func(time.Duration, uint32, bool)) error { | ||
err := writeUploadRequest(c.Conn, l) | ||
if err != nil { | ||
return err | ||
} | ||
ok, msg, err := readUploadResponse(c.Conn) | ||
if err != nil { | ||
return err | ||
} | ||
if !ok { | ||
return fmt.Errorf("server rejected upload request: %s", msg) | ||
} | ||
var counter uint32 | ||
stopChan := make(chan struct{}) | ||
defer close(stopChan) | ||
// Call the callback function every second, | ||
// with the time since the last call and the number of bytes sent in that time. | ||
go func() { | ||
ticker := time.NewTicker(time.Second) | ||
defer ticker.Stop() | ||
t := time.Now() | ||
for { | ||
select { | ||
case <-stopChan: | ||
return | ||
case <-ticker.C: | ||
cb(time.Since(t), atomic.SwapUint32(&counter, 0), false) | ||
t = time.Now() | ||
} | ||
} | ||
}() | ||
buf := make([]byte, chunkSize) | ||
remaining := l | ||
for remaining > 0 { | ||
n := remaining | ||
if n > chunkSize { | ||
n = chunkSize | ||
} | ||
_, err := c.Conn.Write(buf[:n]) | ||
if err != nil { | ||
return err | ||
} | ||
remaining -= n | ||
atomic.AddUint32(&counter, n) | ||
} | ||
// Now we should receive the upload summary from the server. | ||
elapsed, received, err := readUploadSummary(c.Conn) | ||
if err != nil { | ||
return err | ||
} | ||
// One last call to the callback function to report the total time and bytes sent. | ||
cb(elapsed, received, true) | ||
return nil | ||
} |
Oops, something went wrong.