Skip to content

Commit

Permalink
Functional options (#8)
Browse files Browse the repository at this point in the history
* functional options

* update tests
  • Loading branch information
TuanKiri authored May 8, 2024
1 parent b6a790b commit 1bdac7a
Show file tree
Hide file tree
Showing 7 changed files with 274 additions and 208 deletions.
11 changes: 4 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ Create your `.go` file. For example: `main.go`.
package main

import (
"context"
"log"
"os"
"os/signal"

"github.com/JC5LZiy3HVfV5ux/socks5"
)
Expand All @@ -43,15 +46,9 @@ func main() {
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
defer stop()

// Options allowed as nil. Example options:
// &socks5.Options{
// Authentication: true,
// ListenAddress: "0.0.0.0:1080",
// }
srv := socks5.New(nil)
srv := socks5.New()

go func() {
// Default address: 127.0.0.1:1080
if err := srv.ListenAndServe(); err != nil {
log.Fatal(err)
}
Expand Down
18 changes: 8 additions & 10 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,23 @@ import (
)

type Driver interface {
Listen() (net.Listener, error)
ListenPacket() (net.PacketConn, error)
Listen(network, address string) (net.Listener, error)
ListenPacket(network, address string) (net.PacketConn, error)
Dial(network, address string) (net.Conn, error)
}

type netDriver struct {
listenAddress string
bindAddress string
dialTimeout time.Duration
timeout time.Duration
}

func (d *netDriver) Listen() (net.Listener, error) {
return net.Listen("tcp", d.listenAddress)
func (d *netDriver) Listen(network, address string) (net.Listener, error) {
return net.Listen(network, address)
}

func (d *netDriver) ListenPacket() (net.PacketConn, error) {
return net.ListenPacket("udp", d.bindAddress)
func (d *netDriver) ListenPacket(network, address string) (net.PacketConn, error) {
return net.ListenPacket(network, address)
}

func (d *netDriver) Dial(network, address string) (net.Conn, error) {
return net.DialTimeout(network, address, d.dialTimeout)
return net.DialTimeout(network, address, d.timeout)
}
85 changes: 35 additions & 50 deletions helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,58 +39,57 @@ lPQTC4uW7AAywREZ2ekd8XUX8JwdTJHmPg==
-----END EC PRIVATE KEY-----
`

var cert tls.Certificate

func init() {
c, err := tls.X509KeyPair([]byte(certPem), []byte(keyPem))
if err != nil {
panic(err)
}
cert = c

mux := http.NewServeMux()

mux.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, "pong!")
})

go func() {
if err := http.ListenAndServe(":5444", mux); err != nil {
log.Fatalf("runRemoteServer: %v", err)
}
}()

go func() {
if err := listenAndServeTLS(":6444", mux); err != nil {
log.Fatalf("runRemoteServer: %v", err)
}
}()
}

type testTLSDial struct{}

func (p testTLSDial) Dial(network, address string) (c net.Conn, err error) {
return tls.Dial(network, address, &tls.Config{InsecureSkipVerify: true})
}

type testTLSDriver struct {
listenAddress string
tlsConfig *tls.Config
tlsConfig *tls.Config
}

func (d testTLSDriver) Listen() (net.Listener, error) {
return tls.Listen("tcp", d.listenAddress, d.tlsConfig)
func (d testTLSDriver) Listen(network, address string) (net.Listener, error) {
return tls.Listen(network, address, d.tlsConfig)
}

func (d testTLSDriver) Dial(network, address string) (net.Conn, error) {
return tls.Dial("tcp", address, &tls.Config{InsecureSkipVerify: true})
return tls.Dial(network, address, &tls.Config{InsecureSkipVerify: true})
}

func (d testTLSDriver) ListenPacket() (net.PacketConn, error) {
func (d testTLSDriver) ListenPacket(network, address string) (net.PacketConn, error) {
return nil, nil
}

func runRemoteServer(address string, useTLS bool) {
if address == "" {
return
}

mux := http.NewServeMux()

mux.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, "pong!")
})

if useTLS {
if err := listenAndServeTLS(address, mux); err != nil {
log.Fatalf("runRemoteServer: %v", err)
}
return
}

if err := http.ListenAndServe(address, mux); err != nil {
log.Fatalf("runRemoteServer: %v", err)
}
}

func listenAndServeTLS(address string, handler http.Handler) error {
cert, err := tls.X509KeyPair([]byte(certPem), []byte(keyPem))
if err != nil {
return err
}

server := http.Server{
Addr: address,
TLSConfig: &tls.Config{
Expand All @@ -102,22 +101,8 @@ func listenAndServeTLS(address string, handler http.Handler) error {
return server.ListenAndServeTLS("", "")
}

func runProxy(opts *socks5.Options, useTLS bool) {
if useTLS {
cert, err := tls.X509KeyPair([]byte(certPem), []byte(keyPem))
if err != nil {
log.Fatalf("runProxy: %v", err)
}

opts.Driver = &testTLSDriver{
listenAddress: opts.ListenAddress,
tlsConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
},
}
}

srv := socks5.New(opts)
func runProxy(opts []socks5.Option) {
srv := socks5.New(opts...)

if err := srv.ListenAndServe(); err != nil {
log.Fatalf("runProxy: %v", err)
Expand Down
163 changes: 119 additions & 44 deletions options.go
Original file line number Diff line number Diff line change
@@ -1,32 +1,36 @@
package socks5

import (
"fmt"
"log"
"net"
"os"
"time"
)

type Options struct {
ListenAddress string // default: 127.0.0.1:1080
PublicIP net.IP // default: 127.0.0.1. Only IPv4 address that is visible to the external connections. Port is assigned automatically.
ReadTimeout time.Duration // default: none
WriteTimeout time.Duration // default: none
DialTimeout time.Duration // default: none
GetPasswordTimeout time.Duration // default: none
Authentication bool // default: no authentication required
StaticCredentials map[string]string // default: root / password
Logger Logger // default: stdoutLogger
Store Store // default: mapStore
Driver Driver // default: netDriver
Metrics Metrics // default: nopMetrics
}

func (o Options) authMethods() map[byte]struct{} {
type Option func(*options)

type options struct {
host string
port int
publicIP net.IP
readTimeout time.Duration
writeTimeout time.Duration
dialTimeout time.Duration
getPasswordTimeout time.Duration
passwordAuthentication bool
staticCredentials map[string]string
logger Logger
store Store
driver Driver
metrics Metrics
}

func (o options) authMethods() map[byte]struct{} {
methods := make(map[byte]struct{})

switch {
case o.Authentication:
case o.passwordAuthentication:
methods[usernamePasswordAuthentication] = struct{}{}
default:
methods[noAuthenticationRequired] = struct{}{}
Expand All @@ -35,53 +39,124 @@ func (o Options) authMethods() map[byte]struct{} {
return methods
}

func optsWithDefaults(opts *Options) *Options {
if opts == nil {
opts = &Options{}
func (o options) listenAddress() string {
return fmt.Sprintf("%s:%d", o.host, o.port)
}

func optsWithDefaults(opts *options) *options {
if opts.port == 0 {
opts.port = 1080
}

if opts.publicIP == nil {
opts.publicIP = net.ParseIP("127.0.0.1")
}

if opts.Logger == nil {
opts.Logger = &stdoutLogger{
if opts.logger == nil {
opts.logger = &stdoutLogger{
log: log.New(os.Stdout, "[socks5] - ", log.Ldate|log.Ltime),
}
}

if opts.Store == nil {
if opts.StaticCredentials == nil {
opts.StaticCredentials = map[string]string{
if opts.store == nil {
if opts.staticCredentials == nil {
opts.staticCredentials = map[string]string{
"root": "password",
}
}

opts.Store = &mapStore{
db: opts.StaticCredentials,
opts.store = &mapStore{
db: opts.staticCredentials,
}
}

if opts.Driver == nil {
if opts.ListenAddress == "" {
opts.ListenAddress = "127.0.0.1:1080"
if opts.driver == nil {
opts.driver = &netDriver{
timeout: opts.dialTimeout,
}
}

host, _, err := net.SplitHostPort(opts.ListenAddress)
if err != nil {
host = "127.0.0.1"
}
if opts.metrics == nil {
opts.metrics = &nopMetrics{}
}

opts.Driver = &netDriver{
listenAddress: opts.ListenAddress,
bindAddress: host + ":0",
dialTimeout: opts.DialTimeout,
}
return opts
}

func WithHost(val string) Option {
return func(o *options) {
o.host = val
}
}

if opts.PublicIP == nil {
opts.PublicIP = net.ParseIP("127.0.0.1")
func WithPort(val int) Option {
return func(o *options) {
o.port = val
}
}

if opts.Metrics == nil {
opts.Metrics = &nopMetrics{}
func WithPublicIP(val net.IP) Option {
return func(o *options) {
o.publicIP = val
}
}

return opts
func WithReadTimeout(val time.Duration) Option {
return func(o *options) {
o.readTimeout = val
}
}

func WithWriteTimeout(val time.Duration) Option {
return func(o *options) {
o.writeTimeout = val
}
}

func WithDialTimeout(val time.Duration) Option {
return func(o *options) {
o.dialTimeout = val
}
}

func WithGetPasswordTimeout(val time.Duration) Option {
return func(o *options) {
o.getPasswordTimeout = val
}
}

func WithPasswordAuthentication() Option {
return func(o *options) {
o.passwordAuthentication = true
}
}

func WithStaticCredentials(val map[string]string) Option {
return func(o *options) {
o.staticCredentials = val
}
}

func WithLogger(val Logger) Option {
return func(o *options) {
o.logger = val
}
}

func WithStore(val Store) Option {
return func(o *options) {
o.store = val
}
}

func WithDriver(val Driver) Option {
return func(o *options) {
o.driver = val
}
}

func WithMetrics(val Metrics) Option {
return func(o *options) {
o.metrics = val
}
}
Loading

0 comments on commit 1bdac7a

Please sign in to comment.