diff --git a/router/realm.go b/router/realm.go index 37ae764f..5dcbd3e8 100644 --- a/router/realm.go +++ b/router/realm.go @@ -789,7 +789,22 @@ func (r *realm) metaProcedureHandler() { func (r *realm) sessionCount(msg *wamp.Invocation) wamp.Message { var filter []string if len(msg.Arguments) != 0 { - filter, _ = msg.Arguments[0].([]string) + filterList, ok := wamp.AsList(msg.Arguments[0]) + if !ok { + return &wamp.Error{ + Type: wamp.INVOCATION, + Error: wamp.ErrInvalidArgument, + Request: msg.Request, + } + } + filter, ok = wamp.ListToStrings(filterList) + if !ok { + return &wamp.Error{ + Type: wamp.INVOCATION, + Error: wamp.ErrInvalidArgument, + Request: msg.Request, + } + } } retChan := make(chan int) @@ -826,7 +841,22 @@ func (r *realm) sessionCount(msg *wamp.Invocation) wamp.Message { func (r *realm) sessionList(msg *wamp.Invocation) wamp.Message { var filter []string if len(msg.Arguments) != 0 { - filter, _ = msg.Arguments[0].([]string) + filterList, ok := wamp.AsList(msg.Arguments[0]) + if !ok { + return &wamp.Error{ + Type: wamp.INVOCATION, + Error: wamp.ErrInvalidArgument, + Request: msg.Request, + } + } + filter, ok = wamp.ListToStrings(filterList) + if !ok { + return &wamp.Error{ + Type: wamp.INVOCATION, + Error: wamp.ErrInvalidArgument, + Request: msg.Request, + } + } } retChan := make(chan []wamp.ID) @@ -1061,12 +1091,10 @@ func (r *realm) testamentAdd(msg *wamp.Invocation) wamp.Message { } topic, ok := wamp.AsURI(msg.Arguments[0]) if !ok { - fmt.Printf("invalid topic") return makeError(msg.Request, wamp.ErrInvalidArgument) } args, ok := wamp.AsList(msg.Arguments[1]) if !ok { - fmt.Printf("invalid args") return makeError(msg.Request, wamp.ErrInvalidArgument) } kwargs, ok := wamp.AsDict(msg.Arguments[2]) diff --git a/router/router_test.go b/router/router_test.go index fa67afce..21c101e3 100644 --- a/router/router_test.go +++ b/router/router_test.go @@ -396,7 +396,119 @@ func TestRouterCall(t *testing.T) { } } -func TestSessionMetaProcedures(t *testing.T) { +func TestSessionCountMetaProcedure(t *testing.T) { + defer leaktest.Check(t)() + r, err := newTestRouter() + if err != nil { + t.Error(err) + } + defer r.Close() + + caller, err := testClient(r) + if err != nil { + t.Fatal(err) + } + + // Call wamp.MetaProcSessionCount + req := &wamp.Call{Request: wamp.GlobalID(), Procedure: wamp.MetaProcSessionCount} + caller.Send(req) + msg, err := wamp.RecvTimeout(caller, time.Second) + if err != nil { + t.Fatal(err) + } + result, ok := msg.(*wamp.Result) + if !ok { + t.Fatal("expected RESULT, got", msg.MessageType()) + } + if result.Request != req.Request { + t.Fatal("wrong result ID") + } + if len(result.Arguments) == 0 { + t.Fatal("missing expected arguemnt") + } + count, ok := result.Arguments[0].(int) + if !ok { + t.Fatal("expected int arguemnt") + } + if count != 1 { + t.Fatal("wrong session count") + } + + // Call wamp.MetaProcSessionCount with invalid argument + req = &wamp.Call{ + Request: wamp.GlobalID(), + Procedure: wamp.MetaProcSessionCount, + Arguments: wamp.List{"should-be-a-list"}, + } + caller.Send(req) + msg, err = wamp.RecvTimeout(caller, time.Second) + if err != nil { + t.Fatal(err) + } + errResult, ok := msg.(*wamp.Error) + if !ok { + t.Fatal("expected ERROR, got", msg.MessageType()) + } + if errResult.Request != req.Request { + t.Fatal("wrong result ID") + } + + // Call wamp.MetaProcSessionCount with non-matching filter + filter := wamp.List{"user", "def"} + req = &wamp.Call{ + Request: wamp.GlobalID(), + Procedure: wamp.MetaProcSessionCount, + Arguments: wamp.List{filter}, + } + caller.Send(req) + msg, err = wamp.RecvTimeout(caller, time.Second) + if err != nil { + t.Fatal(err) + } + result, ok = msg.(*wamp.Result) + if !ok { + t.Fatal("expected RESULT, got", msg.MessageType()) + } + if result.Request != req.Request { + t.Fatal("wrong result ID") + } + if len(result.Arguments) == 0 { + t.Fatal("missing expected arguemnt") + } + count = result.Arguments[0].(int) + if count != 0 { + t.Fatal("wrong session count") + } + + // Call wamp.MetaProcSessionCount with matching filter + filter = wamp.List{"trusted", "user", "def"} + req = &wamp.Call{ + Request: wamp.GlobalID(), + Procedure: wamp.MetaProcSessionCount, + Arguments: wamp.List{filter}, + } + caller.Send(req) + msg, err = wamp.RecvTimeout(caller, time.Second) + if err != nil { + t.Fatal(err) + } + result, ok = msg.(*wamp.Result) + if !ok { + t.Fatal("expected RESULT, got", msg.MessageType()) + } + if result.Request != req.Request { + t.Fatal("wrong result ID") + } + if len(result.Arguments) == 0 { + t.Fatal("missing expected arguemnt") + } + count = result.Arguments[0].(int) + if count != 1 { + t.Fatal("wrong session count") + } +} + +func TestListSessionMetaProcedures(t *testing.T) { defer leaktest.Check(t)() r, err := newTestRouter() if err != nil { @@ -410,77 +522,84 @@ func TestSessionMetaProcedures(t *testing.T) { } sessID := caller.ID - // Call session meta-procedure to get session count. - sessionCountRequests := []*wamp.Call{ - // Normal call - &wamp.Call{Request: wamp.GlobalID(), Procedure: wamp.MetaProcSessionCount}, - // Call with extra arguments (but invalid) - &wamp.Call{Request: wamp.GlobalID(), Procedure: wamp.MetaProcSessionCount, Arguments: wamp.List{"invalidarg"}}, - } - var msg wamp.Message - for _, req := range sessionCountRequests { - callID := req.Request - caller.Send(req) - msg, err = wamp.RecvTimeout(caller, time.Second) - if err != nil { - t.Fatal(err) - } - result, ok := msg.(*wamp.Result) - if !ok { - t.Fatal("expected RESULT, got", msg.MessageType()) - } - if result.Request != callID { - t.Fatal("wrong result ID") - } + // Call wamp.MetaProcSessionList to get session list. + req := &wamp.Call{ + Request: wamp.GlobalID(), + Procedure: wamp.MetaProcSessionList, + } + caller.Send(req) + msg, err := wamp.RecvTimeout(caller, time.Second) + if err != nil { + t.Fatal(err) + } + result, ok := msg.(*wamp.Result) + if !ok { + t.Fatal("expected RESULT, got", msg.MessageType()) + } + if result.Request != req.Request { + t.Fatal("wrong result ID") + } + if len(result.Arguments) == 0 { + t.Fatal("missing expected arguemnt") + } + ids, ok := result.Arguments[0].([]wamp.ID) + if !ok { + t.Fatal("wrong arg type") + } + if len(ids) != 1 { + t.Fatal("wrong number of session IDs") + } + if sessID != ids[0] { + t.Fatal("wrong session ID") + } - if len(result.Arguments) == 0 { - t.Fatal("missing expected arguemnt") - } - count, ok := result.Arguments[0].(int) - if !ok { - t.Fatal("expected int arguemnt") - } - if count != 1 { - t.Fatal("wrong session count") - } + // Call wamp.MetaProcSessionList with matching filter + filter := wamp.List{"trusted"} + req = &wamp.Call{ + Request: wamp.GlobalID(), + Procedure: wamp.MetaProcSessionList, + Arguments: wamp.List{filter}, } + caller.Send(req) + msg, err = wamp.RecvTimeout(caller, time.Second) + if err != nil { + t.Fatal(err) + } + result, ok = msg.(*wamp.Result) + if !ok { + t.Fatal("expected RESULT, got", msg.MessageType()) + } + if result.Request != req.Request { + t.Fatal("wrong result ID") + } + if len(result.Arguments) == 0 { + t.Fatal("missing expected arguemnt") + } + ids, ok = result.Arguments[0].([]wamp.ID) + if !ok { + t.Fatal("wrong arg type") + } + if len(ids) != 1 { + t.Fatal("wrong number of session IDs") + } + if sessID != ids[0] { + t.Fatal("wrong session ID") + } +} - // Call session meta-procedure to get session list. - sessionListRequests := []*wamp.Call{ - // Normal call - &wamp.Call{Request: wamp.GlobalID(), Procedure: wamp.MetaProcSessionList}, - // Call with extra arguments (but invalid) - &wamp.Call{Request: wamp.GlobalID(), Procedure: wamp.MetaProcSessionList, Arguments: wamp.List{"invalidarg"}}, +func TestGetSessionMetaProcedures(t *testing.T) { + defer leaktest.Check(t)() + r, err := newTestRouter() + if err != nil { + t.Error(err) } - for _, req := range sessionListRequests { - callID := req.Request - caller.Send(&wamp.Call{Request: callID, Procedure: wamp.MetaProcSessionList}) - msg, err = wamp.RecvTimeout(caller, time.Second) - if err != nil { - t.Fatal(err) - } - result, ok := msg.(*wamp.Result) - if !ok { - t.Fatal("expected RESULT, got", msg.MessageType()) - } - if result.Request != callID { - t.Fatal("wrong result ID") - } + defer r.Close() - if len(result.Arguments) == 0 { - t.Fatal("missing expected arguemnt") - } - ids, ok := result.Arguments[0].([]wamp.ID) - if !ok { - t.Fatal("wrong arg type") - } - if len(ids) != 1 { - t.Fatal("wrong number of session IDs") - } - if sessID != ids[0] { - t.Fatal("wrong session ID") - } + caller, err := testClient(r) + if err != nil { + t.Fatal(err) } + sessID := caller.ID // Call session meta-procedure with bad session ID callID := wamp.GlobalID() @@ -489,7 +608,7 @@ func TestSessionMetaProcedures(t *testing.T) { Procedure: wamp.MetaProcSessionGet, Arguments: wamp.List{wamp.ID(123456789)}, }) - msg, err = wamp.RecvTimeout(caller, time.Second) + msg, err := wamp.RecvTimeout(caller, time.Second) if err != nil { t.Fatal(err) } diff --git a/wamp/convert.go b/wamp/convert.go index fdcd8cd4..fe43e286 100644 --- a/wamp/convert.go +++ b/wamp/convert.go @@ -117,6 +117,23 @@ func AsList(v interface{}) (List, bool) { return list, true } +// ListToStrings converts a List to a slice of string. Returns the string +// slice and a boolean indicating if the conversion was successful. +func ListToStrings(list List) ([]string, bool) { + if len(list) == 0 { + return nil, true + } + strs := make([]string, len(list)) + for i := range list { + s, ok := AsString(list[i]) + if !ok { + return nil, false + } + strs[i] = s + } + return strs, true +} + // OptionString returns named value as string; empty string if missing or not // string type. func OptionString(opts Dict, optionName string) string { diff --git a/wamp/convert_test.go b/wamp/convert_test.go index 72d095c4..f32dd62e 100644 --- a/wamp/convert_test.go +++ b/wamp/convert_test.go @@ -241,3 +241,17 @@ func TestAsFloat64(t *testing.T) { t.Error(shouldFailMsg) } } + +func TestListToStrings(t *testing.T) { + strs, ok := ListToStrings(List{"hello", "world"}) + if !ok { + t.Fatal("not convered") + } + if strs[0] != "hello" || strs[1] != "world" { + t.Fatal("bad conversion") + } + + if _, ok = ListToStrings(List{"hello", 123}); ok { + t.Fatal("should not have converted") + } +}