diff --git a/http_propagation.go b/http_propagation.go index 58c56e4..9947ca3 100644 --- a/http_propagation.go +++ b/http_propagation.go @@ -25,6 +25,7 @@ import ( const ( httpHeaderMaxSize = 200 httpHeader = `X-Amzn-Trace-Id` + prefixSelf = "Self=" prefixRoot = "Root=" prefixParent = "Parent=" prefixSampled = "Sampled=" @@ -44,31 +45,18 @@ func ParseTraceHeader(header string) (trace.SpanContext, bool) { traceOptions trace.TraceOptions ) - if strings.HasPrefix(header, prefixRoot) { - header = header[len(prefixRoot):] - } - - // Parse the trace id field. - if index := strings.Index(header, `;`); index == -1 { - amazonTraceID, header = header, header[len(header):] - } else { - amazonTraceID, header = header[:index], header[index+1:] - } - - if strings.HasPrefix(header, prefixParent) { - header = header[len(prefixParent):] - - if index := strings.Index(header, `;`); index == -1 { - parentSpanID, header = header, header[len(header):] - } else { - parentSpanID, header = header[:index], header[index+1:] + for _, field := range strings.Split(header, ";") { + if field == "" { + continue } - } - - if strings.HasPrefix(header, prefixSampled) { - header = header[len(prefixSampled):] - if strings.HasPrefix(header, "1") { + if strings.HasPrefix(field, prefixRoot) { + amazonTraceID = field[len(prefixRoot):] + } else if strings.HasPrefix(field, prefixParent) { + parentSpanID = field[len(prefixParent):] + } else if field == prefixSampled+"1" { traceOptions = 1 + } else if strings.Index(field, "=") == -1 { + amazonTraceID = field } } diff --git a/http_propagation_test.go b/http_propagation_test.go index 3583e3f..5bfe4a9 100644 --- a/http_propagation_test.go +++ b/http_propagation_test.go @@ -20,6 +20,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -126,6 +127,67 @@ func TestSpanContextFromRequest(t *testing.T) { } }) + t.Run("traceID with self", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://localhost/", nil) + amazonTraceID := convertToAmazonTraceID(traceID) + req.Header.Set(httpHeader, prefixSelf+amazonTraceID+";"+prefixRoot+amazonTraceID) + + sc, ok := format.SpanContextFromRequest(req) + if !ok { + t.Errorf("expected true; got false") + } + if traceID != sc.TraceID { + t.Errorf("expected %v; got %v", traceID, sc.TraceID) + } + if zeroSpanID != sc.SpanID { + t.Errorf("expected %v; got %v", zeroSpanID, sc.SpanID) + } + if 0 != sc.TraceOptions { + t.Errorf("expected 0; got %v", sc.TraceOptions) + } + }) + + t.Run("traceID with self and parentSpanID", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://localhost/", nil) + amazonTraceID := convertToAmazonTraceID(traceID) + amazonSpanID := convertToAmazonSpanID(spanID) + req.Header.Set(httpHeader, prefixSelf+amazonTraceID+";"+prefixRoot+amazonTraceID+";"+prefixParent+amazonSpanID) + + sc, ok := format.SpanContextFromRequest(req) + if !ok { + t.Errorf("expected true; got false") + } + if traceID != sc.TraceID { + t.Errorf("expected %v; got %v", traceID, sc.TraceID) + } + if spanID != sc.SpanID { + t.Errorf("expected %v; got %v", spanID, sc.SpanID) + } + if 0 != sc.TraceOptions { + t.Errorf("expected 0; got %v", sc.TraceOptions) + } + }) + + t.Run("traceID with self and sampled", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://localhost/", nil) + amazonTraceID := convertToAmazonTraceID(traceID) + req.Header.Set(httpHeader, prefixSelf+amazonTraceID+";"+prefixRoot+amazonTraceID+";"+prefixSampled+"1") + + sc, ok := format.SpanContextFromRequest(req) + if !ok { + t.Errorf("expected true; got false") + } + if traceID != sc.TraceID { + t.Errorf("expected %v; got %v", traceID, sc.TraceID) + } + if zeroSpanID != sc.SpanID { + t.Errorf("expected %v; got %v", zeroSpanID, sc.SpanID) + } + if 1 != sc.TraceOptions { + t.Errorf("expected 1; got %v", sc.TraceOptions) + } + }) + t.Run("bad traceID", func(t *testing.T) { var ( req = httptest.NewRequest(http.MethodGet, "http://localhost/", nil) @@ -148,6 +210,16 @@ func TestSpanContextFromRequest(t *testing.T) { t.Errorf("expected false; got true") } }) + + t.Run("invalid header", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "http://localhost/", nil) + req.Header.Set(httpHeader, ";invalid-header=") + + _, ok := format.SpanContextFromRequest(req) + if ok { + t.Errorf("expected false, got true") + } + }) } func TestSpanContextToRequest(t *testing.T) { @@ -188,3 +260,18 @@ func TestSpanContextToRequest(t *testing.T) { } }) } + +func BenchmarkParseTraceHeader(b *testing.B) { + var ( + root = prefixRoot + "1-581cf771-a006649127e371903a2de979" + self = prefixSelf + "1-581cf771-a006649127e371903a2de979" + parent = prefixParent + "0102030405060708" + header = strings.Join([]string{root, self, parent, prefixSampled + "1"}, ";") + ) + + for n := 0; n < b.N; n++ { + if _, ok := ParseTraceHeader(header); !ok { + b.FailNow() + } + } +}