Skip to content

Commit

Permalink
feat: add support for ALBTargetGroupRequest/Response
Browse files Browse the repository at this point in the history
  • Loading branch information
its-felix committed Sep 4, 2024
1 parent 5f61940 commit bc6743a
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 2 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Simple HTTP adapter for AWS Lambda
- AWS Lambda Function URL (both normal and streaming)
- API Gateway (v1)
- API Gateway (v2)
- Application Load Balancer

## Builtin support for these HTTP frameworks:
- `net/http`
Expand Down Expand Up @@ -252,6 +253,7 @@ Once this build-tag is present, the following build-tags are available:
- `lambdahttpadapter.apigwv1` (enables API Gateway V1 handler)
- `lambdahttpadapter.apigwv2` (enables API Gateway V2 handler)
- `lambdahttpadapter.functionurl` (enables Lambda Function URL handler)
- `lambdahttpadapter.alb` (enables Application Load Balancer handler)

Also note that Lambda Function URL in Streaming-Mode requires the following build-tag to be set:
- `lambda.norpc`
Expand Down
165 changes: 165 additions & 0 deletions handler/alb.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
//go:build !lambdahttpadapter.partial || (lambdahttpadapter.partial && lambdahttpadapter.alb)

package handler

import (
"bytes"
"context"
"encoding/base64"
"github.com/aws/aws-lambda-go/events"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"unicode/utf8"
)

func convertALBRequest(ctx context.Context, event events.ALBTargetGroupRequest) (*http.Request, error) {
q := make(url.Values)

if len(event.MultiValueQueryStringParameters) > 0 {
for k, values := range event.MultiValueQueryStringParameters {
for _, v := range values {
q.Add(k, v)
}
}
} else if len(event.QueryStringParameters) > 0 {
for k, v := range event.QueryStringParameters {
q.Add(k, v)
}
}

headers := make(http.Header)
if event.Headers != nil {
for k, v := range event.Headers {
headers.Add(k, v)
}
}

if event.MultiValueHeaders != nil {
for k, values := range event.MultiValueHeaders {
for _, v := range values {
headers.Add(k, v)
}
}
}

host := headers.Get("X-Forwarded-Host")
if host == "" {
host = headers.Get("Host")
if host == "" {
host = "127.0.0.1"
}
}

sourceIp := headers.Get("X-Forwarded-For")
if sourceIp == "" {
sourceIp = "127.0.0.1"
}

proto := headers.Get("X-Forwarded-Proto")
if proto == "" {
proto = "http"
}

rUrl := buildFullRequestURLWithProto(proto, host, event.Path, "", q.Encode())
req, err := http.NewRequestWithContext(ctx, event.HTTPMethod, rUrl, getBody(event.Body, event.IsBase64Encoded))
if err != nil {
return nil, err
}

req.Header = headers
req.RemoteAddr = buildRemoteAddr(sourceIp)
req.RequestURI = req.URL.RequestURI()

return req, nil
}

type albResponseWriter struct {
multiValueHeaders bool
headersWritten bool
contentTypeSet bool
contentLengthSet bool
headers http.Header
body bytes.Buffer
res events.ALBTargetGroupResponse
}

func (w *albResponseWriter) Header() http.Header {
return w.headers
}

func (w *albResponseWriter) Write(p []byte) (int, error) {
w.WriteHeader(http.StatusOK)
return w.body.Write(p)
}

func (w *albResponseWriter) WriteHeader(statusCode int) {
if !w.headersWritten {
w.headersWritten = true
w.res.StatusCode = statusCode

for k, values := range w.headers {
if w.multiValueHeaders {
w.res.MultiValueHeaders[k] = values
} else {
w.res.Headers[k] = strings.Join(values, ",")
}
}
}
}

func handleALB(multiValueHeaders bool) func(ctx context.Context, event events.ALBTargetGroupRequest, adapter AdapterFunc) (events.ALBTargetGroupResponse, error) {
return func(ctx context.Context, event events.ALBTargetGroupRequest, adapter AdapterFunc) (events.ALBTargetGroupResponse, error) {
req, err := convertALBRequest(ctx, event)
if err != nil {
var def events.ALBTargetGroupResponse
return def, err
}

w := albResponseWriter{
multiValueHeaders: multiValueHeaders,
headers: make(http.Header),
res: events.ALBTargetGroupResponse{},
}

if multiValueHeaders {
w.res.MultiValueHeaders = make(map[string][]string)
} else {
w.res.Headers = make(map[string]string)
}

if err = adapter(ctx, req, &w); err != nil {
var def events.ALBTargetGroupResponse
return def, err
}

b, err := io.ReadAll(&w.body)
if err != nil {
var def events.ALBTargetGroupResponse
return def, err
}

if !w.contentTypeSet {
w.res.Headers["Content-Type"] = http.DetectContentType(b)
}

if !w.contentLengthSet {
w.res.Headers["Content-Length"] = strconv.Itoa(len(b))
}

if utf8.Valid(b) {
w.res.Body = string(b)
} else {
w.res.IsBase64Encoded = true
w.res.Body = base64.StdEncoding.EncodeToString(b)
}

return w.res, nil
}
}

func NewALBHandler(adapter AdapterFunc, multiValueHeaders bool) func(context.Context, events.ALBTargetGroupRequest) (events.ALBTargetGroupResponse, error) {
return NewHandler(handleALB(multiValueHeaders), adapter)
}
8 changes: 6 additions & 2 deletions handler/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ func buildQuery(rawQuery string, queryParams map[string]string) string {
return ""
}

func buildFullRequestURL(host string, path string, altPath string, query string) string {
func buildFullRequestURL(host, path, altPath, query string) string {
return buildFullRequestURLWithProto("https", host, path, altPath, query)
}

func buildFullRequestURLWithProto(proto, host, path, altPath, query string) string {
rUrl := path

if rUrl == "" {
Expand All @@ -35,7 +39,7 @@ func buildFullRequestURL(host string, path string, altPath string, query string)
rUrl = "/" + rUrl
}

rUrl = "https://" + host + rUrl
rUrl = proto + "://" + host + rUrl

if query != "" {
rUrl += "?" + query
Expand Down

0 comments on commit bc6743a

Please sign in to comment.