Skip to content

Commit

Permalink
parse config in factory function
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinForReal committed Jun 16, 2023
1 parent 7ca7b78 commit 25e68ac
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 26 deletions.
53 changes: 36 additions & 17 deletions armbalancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package armbalancer
import (
"fmt"
"math"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -32,6 +34,9 @@ type Options struct {
// that a connections lands on an ARM instance that already has a depleted rate limiting quota.
// Default: 10
MinReqsBeforeRecycle int64

// TransportFactory is a function that creates a new transport for a given connection.
TransportFactory func(id int, parent *http.Transport, host string, port string, recycleThreshold, minReqsBeforeRecycle int64) http.RoundTripper
}

// New wraps a transport to provide smart connection pooling and client-side load balancing.
Expand All @@ -42,6 +47,20 @@ func New(opts Options) http.RoundTripper {
if opts.Host == "" {
opts.Host = "management.azure.com"
}
if i := strings.Index(opts.Host, string(':')); i < 0 {
opts.Host += ":443"
}

host, port, err := net.SplitHostPort(opts.Host)
if err != nil {
panic(fmt.Sprintf("invalid host %q: %s", host, err))
}
if host == "" {
host = "management.azure.com"
}
if port == "" {
port = "443"
}
if opts.PoolSize == 0 {
opts.PoolSize = 8
}
Expand All @@ -52,9 +71,13 @@ func New(opts Options) http.RoundTripper {
opts.MinReqsBeforeRecycle = 10
}

if opts.TransportFactory == nil {
opts.TransportFactory = newRecyclableTransport
}

t := &transportPool{pool: make([]http.RoundTripper, opts.PoolSize)}
for i := range t.pool {
t.pool[i] = newRecyclableTransport(i, opts.Transport, opts.Host, opts.RecycleThreshold, opts.MinReqsBeforeRecycle)
t.pool[i] = newRecyclableTransport(i, opts.Transport, host, port, opts.RecycleThreshold, opts.MinReqsBeforeRecycle)
}
return t
}
Expand All @@ -72,18 +95,21 @@ func (t *transportPool) RoundTrip(req *http.Request) (*http.Response, error) {
type recyclableTransport struct {
lock sync.Mutex // only hold while copying pointer - not calling RoundTrip
host string
port string
current *http.Transport
counter int64 // atomic
activeCount *sync.WaitGroup
state *connState
signal chan struct{}
}

func newRecyclableTransport(id int, parent *http.Transport, host string, recycleThreshold, minReqsBeforeRecycle int64) *recyclableTransport {
func newRecyclableTransport(id int, parent *http.Transport, host string, port string, recycleThreshold, minReqsBeforeRecycle int64) http.RoundTripper {
tx := parent.Clone()
tx.MaxConnsPerHost = 1

r := &recyclableTransport{
host: host,
port: port,
current: tx.Clone(),
activeCount: &sync.WaitGroup{},
state: newConnState(),
Expand Down Expand Up @@ -113,26 +139,19 @@ func newRecyclableTransport(id int, parent *http.Transport, host string, recycle
}

// return retrue if transport host matched with request host
func (t *recyclableTransport) compareHost(reqHost string) bool {
idx := strings.Index(reqHost, ":")
idx1 := strings.Index(t.host, ":")

// both host have ":" or not, directly compare reqest host name with transport host
if idx == idx1 {
return reqHost == t.host
func (t *recyclableTransport) compareHost(request *url.URL) bool {
parsedHostName := request.Hostname()
if t.host != parsedHostName {
return false
}

// reqHost has ":", but transportHost doesn't, compare reqHost with port-appened transport host
if idx != -1 {
return reqHost == t.host+reqHost[idx:]
if len(request.Host) == len(parsedHostName) {
return true
}

// reqHost doesn't have ":", but transportHost does, compare reqHost with non-port transport host
return reqHost == t.host[:idx1]
return t.port == request.Port()
}

func (t *recyclableTransport) RoundTrip(req *http.Request) (*http.Response, error) {
matched := t.compareHost(req.URL.Host)
matched := t.compareHost(req.URL)
if !matched {
return nil, fmt.Errorf("host %q is not supported by the configured ARM balancer, supported host name is %q", req.URL.Host, t.host)
}
Expand Down
109 changes: 100 additions & 9 deletions armbalancer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func TestSoak(t *testing.T) {
wg.Wait()

_, err := client.Get("http://not-the-host")
if err == nil || err.Error() != fmt.Sprintf(`Get "http://not-the-host": host "not-the-host" is not supported by the configured ARM balancer, supported host name is %q`, u.Host) {
if err == nil || err.Error() != fmt.Sprintf(`Get "http://not-the-host": host "not-the-host" is not supported by the configured ARM balancer, supported host name is %q`, u.Hostname()) {
t.Errorf("expected error when requesting host other than the one configured, got: %s", err)
}

Expand Down Expand Up @@ -118,6 +118,7 @@ type testCase struct {
name string
reqHost string
transHost string
transPort string
expected bool
}

Expand All @@ -127,30 +128,35 @@ func TestCompareHost(t *testing.T) {
name: "matched since all without port number",
reqHost: "host.com",
transHost: "host.com",
transPort: "443",
expected: true,
},
{
name: "matched since all with port number",
reqHost: "host.com:443",
transHost: "host.com:443",
transHost: "host.com",
transPort: "443",
expected: true,
},
{
name: "matched with appending port name",
reqHost: "host.com:443",
transHost: "host.com",
transPort: "443",
expected: true,
},
{
name: "matched with removing port name",
reqHost: "host.com",
transHost: "host.com:443",
transHost: "host.com",
transPort: "443",
expected: true,
},
{
name: "not matched since different port number",
reqHost: "host.com:443",
transHost: "host.com:11254",
transHost: "host.com",
transPort: "11254",
expected: false,
},
{
Expand All @@ -162,31 +168,116 @@ func TestCompareHost(t *testing.T) {
{
name: "not matched since differnt host name with port number",
reqHost: "host.com:443",
transHost: "abc.com:443",
transHost: "abc.com",
transPort: "443",
expected: false,
},
{
name: "not matched since differnt host name with port number for reqHost only",
reqHost: "host.com:443",
transHost: "abc.com",
transPort: "443",
expected: false,
},
{
name: "not matched since differnt host name with port number for transHost only",
reqHost: "host.com",
transHost: "abc.com:443",
transHost: "abc.com",
transPort: "443",
expected: false,
},
}

for _, c := range cases {
for index, c := range cases {
t.Run(c.name, func(t *testing.T) {
r := recyclableTransport{
host: c.transHost,
port: c.transPort,
}
v := r.compareHost(c.reqHost)
v := r.compareHost(&url.URL{Host: c.reqHost})
if v != c.expected {
t.Errorf("expected result \"%t\" is not same as we get: %t", c.expected, v)
t.Errorf("expected %d result \"%t\" is not same as we get: %t", index, c.expected, v)
}
})
}
}

func TestNew(t *testing.T) {
type args struct {
opts Options
}
tests := []struct {
name string
args args
wantHost string
wantPort string
paniced bool
}{
{
name: "invalid host",
args: args{
opts: Options{
Host: "invalid:host:invalidport",
},
},
paniced: true,
},
{
name: "host is not assigned",
args: args{
opts: Options{
Host: ":445",
},
},
wantHost: "management.azure.com",
wantPort: "445",
paniced: false,
},
{
name: "port is not assigned",
args: args{
opts: Options{
Host: "management.azure.com",
},
},
wantHost: "management.azure.com",
wantPort: "443",
paniced: false,
},
{
name: "hosturl is not assigned",
args: args{
opts: Options{
Host: "",
},
},
wantHost: "management.azure.com",
wantPort: "443",
paniced: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.paniced {
defer func() {
if r := recover(); r != nil {
return
}
t.Errorf("New() did not panic")
}()
} else {
tt.args.opts.TransportFactory = func(id int, parent *http.Transport, host string, port string, recycleThreshold, minReqsBeforeRecycle int64) http.RoundTripper {
if host != tt.wantHost {
t.Errorf("New() host = %v, want %v", host, tt.wantHost)
}
if port != tt.wantPort {
t.Errorf("New() port = %v, want %v", port, tt.wantPort)
}
return nil
}
}
if got := New(tt.args.opts); got == nil {
t.Errorf("New() returned nil")
}
})
}
Expand Down

0 comments on commit 25e68ac

Please sign in to comment.