diff --git a/.travis.yml b/.travis.yml index 92beb2eb8..c181ddaf8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,8 @@ language: go go: - - 1.6 - 1.7 - 1.8 + - 1.9 install: - go get github.com/bwmarrin/discordgo - go get -v . @@ -10,5 +10,5 @@ install: script: - diff <(gofmt -d .) <(echo -n) - go vet -x ./... - - golint ./... + - golint -set_exit_status ./... - go test -v -race ./... diff --git a/discord.go b/discord.go index 40eabe183..99fda30b4 100644 --- a/discord.go +++ b/discord.go @@ -21,7 +21,7 @@ import ( ) // VERSION of DiscordGo, follows Semantic Versioning. (http://semver.org/) -const VERSION = "0.17.0" +const VERSION = "0.18.0" // ErrMFA will be risen by New when the user has 2FA. var ErrMFA = errors.New("account has 2FA enabled") @@ -50,7 +50,7 @@ func New(args ...interface{}) (s *Session, err error) { // Create an empty Session interface. s = &Session{ State: NewState(), - ratelimiter: NewRatelimiter(), + Ratelimiter: NewRatelimiter(), StateEnabled: true, Compress: true, ShouldReconnectOnError: true, diff --git a/discord_test.go b/discord_test.go index ca4472ada..318bed0a3 100644 --- a/discord_test.go +++ b/discord_test.go @@ -1,6 +1,7 @@ package discordgo import ( + "fmt" "os" "runtime" "sync/atomic" @@ -14,29 +15,25 @@ var ( dg *Session // Stores a global discordgo user session dgBot *Session // Stores a global discordgo bot session - envToken = os.Getenv("DG_TOKEN") // Token to use when authenticating the user account - envBotToken = os.Getenv("DGB_TOKEN") // Token to use when authenticating the bot account - envEmail = os.Getenv("DG_EMAIL") // Email to use when authenticating - envPassword = os.Getenv("DG_PASSWORD") // Password to use when authenticating - envGuild = os.Getenv("DG_GUILD") // Guild ID to use for tests - envChannel = os.Getenv("DG_CHANNEL") // Channel ID to use for tests - // envUser = os.Getenv("DG_USER") // User ID to use for tests - envAdmin = os.Getenv("DG_ADMIN") // User ID of admin user to use for tests + envToken = os.Getenv("DGU_TOKEN") // Token to use when authenticating the user account + envBotToken = os.Getenv("DGB_TOKEN") // Token to use when authenticating the bot account + envGuild = os.Getenv("DG_GUILD") // Guild ID to use for tests + envChannel = os.Getenv("DG_CHANNEL") // Channel ID to use for tests + envAdmin = os.Getenv("DG_ADMIN") // User ID of admin user to use for tests ) func init() { + fmt.Println("Init is being called.") if envBotToken != "" { if d, err := New(envBotToken); err == nil { dgBot = d } } - if envEmail == "" || envPassword == "" || envToken == "" { - return - } - - if d, err := New(envEmail, envPassword, envToken); err == nil { + if d, err := New(envToken); err == nil { dg = d + } else { + fmt.Println("dg is nil, error", err) } } @@ -67,58 +64,11 @@ func TestInvalidToken(t *testing.T) { } } -// TestInvalidUserPass tests the New() function with an invalid Email and Pass -func TestInvalidEmailPass(t *testing.T) { - - _, err := New("invalidemail", "invalidpassword") - if err == nil { - t.Errorf("New(InvalidEmail, InvalidPass) returned nil error.") - } - -} - -// TestInvalidPass tests the New() function with an invalid Password -func TestInvalidPass(t *testing.T) { - - if envEmail == "" { - t.Skip("Skipping New(username,InvalidPass), DG_EMAIL not set") - return - } - _, err := New(envEmail, "invalidpassword") - if err == nil { - t.Errorf("New(Email, InvalidPass) returned nil error.") - } -} - -// TestNewUserPass tests the New() function with a username and password. -// This should return a valid Session{}, a valid Session.Token. -func TestNewUserPass(t *testing.T) { - - if envEmail == "" || envPassword == "" { - t.Skip("Skipping New(username,password), DG_EMAIL or DG_PASSWORD not set") - return - } - - d, err := New(envEmail, envPassword) - if err != nil { - t.Fatalf("New(user,pass) returned error: %+v", err) - } - - if d == nil { - t.Fatal("New(user,pass), d is nil, should be Session{}") - } - - if d.Token == "" { - t.Fatal("New(user,pass), d.Token is empty, should be a valid Token.") - } -} - -// TestNewToken tests the New() function with a Token. This should return -// the same as the TestNewUserPass function. +// TestNewToken tests the New() function with a Token. func TestNewToken(t *testing.T) { if envToken == "" { - t.Skip("Skipping New(token), DG_TOKEN not set") + t.Skip("Skipping New(token), DGU_TOKEN not set") } d, err := New(envToken) @@ -135,32 +85,9 @@ func TestNewToken(t *testing.T) { } } -// TestNewUserPassToken tests the New() function with a username, password and token. -// This should return the same as the TestNewUserPass function. -func TestNewUserPassToken(t *testing.T) { - - if envEmail == "" || envPassword == "" || envToken == "" { - t.Skip("Skipping New(username,password,token), DG_EMAIL, DG_PASSWORD or DG_TOKEN not set") - return - } - - d, err := New(envEmail, envPassword, envToken) - if err != nil { - t.Fatalf("New(user,pass,token) returned error: %+v", err) - } - - if d == nil { - t.Fatal("New(user,pass,token), d is nil, should be Session{}") - } - - if d.Token == "" { - t.Fatal("New(user,pass,token), d.Token is empty, should be a valid Token.") - } -} - func TestOpenClose(t *testing.T) { if envToken == "" { - t.Skip("Skipping TestClose, DG_TOKEN not set") + t.Skip("Skipping TestClose, DGU_TOKEN not set") } d, err := New(envToken) diff --git a/endpoints.go b/endpoints.go index b10f95890..335e224dd 100644 --- a/endpoints.go +++ b/endpoints.go @@ -71,7 +71,6 @@ var ( EndpointUserNotes = func(uID string) string { return EndpointUsers + "@me/notes/" + uID } EndpointGuild = func(gID string) string { return EndpointGuilds + gID } - EndpointGuildInivtes = func(gID string) string { return EndpointGuilds + gID + "/invites" } EndpointGuildChannels = func(gID string) string { return EndpointGuilds + gID + "/channels" } EndpointGuildMembers = func(gID string) string { return EndpointGuilds + gID + "/members" } EndpointGuildMember = func(gID, uID string) string { return EndpointGuilds + gID + "/members/" + uID } @@ -98,7 +97,7 @@ var ( EndpointChannelMessages = func(cID string) string { return EndpointChannels + cID + "/messages" } EndpointChannelMessage = func(cID, mID string) string { return EndpointChannels + cID + "/messages/" + mID } EndpointChannelMessageAck = func(cID, mID string) string { return EndpointChannels + cID + "/messages/" + mID + "/ack" } - EndpointChannelMessagesBulkDelete = func(cID string) string { return EndpointChannel(cID) + "/messages/bulk_delete" } + EndpointChannelMessagesBulkDelete = func(cID string) string { return EndpointChannel(cID) + "/messages/bulk-delete" } EndpointChannelMessagesPins = func(cID string) string { return EndpointChannel(cID) + "/pins" } EndpointChannelMessagePin = func(cID, mID string) string { return EndpointChannel(cID) + "/pins/" + mID } @@ -122,6 +121,8 @@ var ( EndpointRelationship = func(uID string) string { return EndpointRelationships() + "/" + uID } EndpointRelationshipsMutual = func(uID string) string { return EndpointUsers + uID + "/relationships" } + EndpointGuildCreate = EndpointAPI + "guilds" + EndpointInvite = func(iID string) string { return EndpointAPI + "invite/" + iID } EndpointIntegrationsJoin = func(iID string) string { return EndpointAPI + "integrations/" + iID + "/join" } diff --git a/event.go b/event.go index 3a03f46dd..bba396cbb 100644 --- a/event.go +++ b/event.go @@ -6,7 +6,7 @@ type EventHandler interface { Type() string // Handle is called whenever an event of Type() happens. - // It is the recievers responsibility to type assert that the interface + // It is the receivers responsibility to type assert that the interface // is the expected struct. Handle(*Session, interface{}) } diff --git a/examples/appmaker/main.go b/examples/appmaker/main.go index 286fe1694..5581dd930 100644 --- a/examples/appmaker/main.go +++ b/examples/appmaker/main.go @@ -79,7 +79,7 @@ func main() { ap.Name = Name ap, err = dg.ApplicationCreate(ap) if err != nil { - fmt.Println("error creating new applicaiton,", err) + fmt.Println("error creating new application,", err) return } diff --git a/logging.go b/logging.go index 70d78d601..6460b35ba 100644 --- a/logging.go +++ b/logging.go @@ -23,7 +23,7 @@ const ( LogError int = iota // LogWarning level is used for very abnormal events and errors that are - // also returend to a calling function. + // also returned to a calling function. LogWarning // LogInformational level is used for normal non-error activity @@ -34,26 +34,34 @@ const ( LogDebug ) +// Logger can be used to replace the standard logging for discordgo +var Logger func(msgL, caller int, format string, a ...interface{}) + // msglog provides package wide logging consistancy for discordgo // the format, a... portion this command follows that of fmt.Printf // msgL : LogLevel of the message // caller : 1 + the number of callers away from the message source // format : Printf style message format -// a ... : comma seperated list of values to pass +// a ... : comma separated list of values to pass func msglog(msgL, caller int, format string, a ...interface{}) { - pc, file, line, _ := runtime.Caller(caller) + if Logger != nil { + Logger(msgL, caller, format, a...) + } else { + + pc, file, line, _ := runtime.Caller(caller) - files := strings.Split(file, "/") - file = files[len(files)-1] + files := strings.Split(file, "/") + file = files[len(files)-1] - name := runtime.FuncForPC(pc).Name() - fns := strings.Split(name, ".") - name = fns[len(fns)-1] + name := runtime.FuncForPC(pc).Name() + fns := strings.Split(name, ".") + name = fns[len(fns)-1] - msg := fmt.Sprintf(format, a...) + msg := fmt.Sprintf(format, a...) - log.Printf("[DG%d] %s:%d:%s() %s\n", msgL, file, line, name, msg) + log.Printf("[DG%d] %s:%d:%s() %s\n", msgL, file, line, name, msg) + } } // helper function that wraps msglog for the Session struct diff --git a/oauth2_test.go b/oauth2_test.go index 30526eb20..0ff0ca0e7 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -10,7 +10,7 @@ import ( func ExampleApplication() { // Authentication Token pulled from environment variable DG_TOKEN - Token := os.Getenv("DG_TOKEN") + Token := os.Getenv("DGU_TOKEN") if Token == "" { return } diff --git a/ratelimit.go b/ratelimit.go index 223c0d04e..dc48c9240 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -41,8 +41,8 @@ func NewRatelimiter() *RateLimiter { } } -// getBucket retrieves or creates a bucket -func (r *RateLimiter) getBucket(key string) *Bucket { +// GetBucket retrieves or creates a bucket +func (r *RateLimiter) GetBucket(key string) *Bucket { r.Lock() defer r.Unlock() @@ -51,7 +51,7 @@ func (r *RateLimiter) getBucket(key string) *Bucket { } b := &Bucket{ - remaining: 1, + Remaining: 1, Key: key, global: r.global, } @@ -68,27 +68,37 @@ func (r *RateLimiter) getBucket(key string) *Bucket { return b } -// LockBucket Locks until a request can be made -func (r *RateLimiter) LockBucket(bucketID string) *Bucket { - - b := r.getBucket(bucketID) - - b.Lock() - +// GetWaitTime returns the duration you should wait for a Bucket +func (r *RateLimiter) GetWaitTime(b *Bucket, minRemaining int) time.Duration { // If we ran out of calls and the reset time is still ahead of us // then we need to take it easy and relax a little - if b.remaining < 1 && b.reset.After(time.Now()) { - time.Sleep(b.reset.Sub(time.Now())) - + if b.Remaining < minRemaining && b.reset.After(time.Now()) { + return b.reset.Sub(time.Now()) } // Check for global ratelimits sleepTo := time.Unix(0, atomic.LoadInt64(r.global)) if now := time.Now(); now.Before(sleepTo) { - time.Sleep(sleepTo.Sub(now)) + return sleepTo.Sub(now) + } + + return 0 +} + +// LockBucket Locks until a request can be made +func (r *RateLimiter) LockBucket(bucketID string) *Bucket { + return r.LockBucketObject(r.GetBucket(bucketID)) +} + +// LockBucketObject Locks an already resolved bucket until a request can be made +func (r *RateLimiter) LockBucketObject(b *Bucket) *Bucket { + b.Lock() + + if wait := r.GetWaitTime(b, 1); wait > 0 { + time.Sleep(wait) } - b.remaining-- + b.Remaining-- return b } @@ -96,13 +106,14 @@ func (r *RateLimiter) LockBucket(bucketID string) *Bucket { type Bucket struct { sync.Mutex Key string - remaining int + Remaining int limit int reset time.Time global *int64 lastReset time.Time customRateLimit *customRateLimit + Userdata interface{} } // Release unlocks the bucket and reads the headers to update the buckets ratelimit info @@ -113,10 +124,10 @@ func (b *Bucket) Release(headers http.Header) error { // Check if the bucket uses a custom ratelimiter if rl := b.customRateLimit; rl != nil { if time.Now().Sub(b.lastReset) >= rl.reset { - b.remaining = rl.requests - 1 + b.Remaining = rl.requests - 1 b.lastReset = time.Now() } - if b.remaining < 1 { + if b.Remaining < 1 { b.reset = time.Now().Add(rl.reset) } return nil @@ -176,7 +187,7 @@ func (b *Bucket) Release(headers http.Header) error { if err != nil { return err } - b.remaining = int(parsedRemaining) + b.Remaining = int(parsedRemaining) } return nil diff --git a/restapi.go b/restapi.go index 836e4a41c..5dc0467f3 100644 --- a/restapi.go +++ b/restapi.go @@ -65,9 +65,11 @@ func (s *Session) request(method, urlStr, contentType string, b []byte, bucketID if bucketID == "" { bucketID = strings.SplitN(urlStr, "?", 2)[0] } + return s.RequestWithLockedBucket(method, urlStr, contentType, b, s.Ratelimiter.LockBucket(bucketID), sequence) +} - bucket := s.ratelimiter.LockBucket(bucketID) - +// RequestWithLockedBucket makes a request using a bucket that's already been locked +func (s *Session) RequestWithLockedBucket(method, urlStr, contentType string, b []byte, bucket *Bucket, sequence int) (response []byte, err error) { if s.Debug { log.Printf("API REQUEST %8s :: %s\n", method, urlStr) log.Printf("API REQUEST PAYLOAD :: [%s]\n", string(b)) @@ -139,7 +141,7 @@ func (s *Session) request(method, urlStr, contentType string, b []byte, bucketID if sequence < s.MaxRestRetries { s.log(LogInformational, "%s Failed (%s), Retrying...", urlStr, resp.Status) - response, err = s.request(method, urlStr, contentType, b, bucketID, sequence+1) + response, err = s.RequestWithLockedBucket(method, urlStr, contentType, b, s.Ratelimiter.LockBucketObject(bucket), sequence+1) } else { err = fmt.Errorf("Exceeded Max retries HTTP %s, %s", resp.Status, response) } @@ -158,7 +160,7 @@ func (s *Session) request(method, urlStr, contentType string, b []byte, bucketID // we can make the above smarter // this method can cause longer delays than required - response, err = s.request(method, urlStr, contentType, b, bucketID, sequence) + response, err = s.RequestWithLockedBucket(method, urlStr, contentType, b, s.Ratelimiter.LockBucketObject(bucket), sequence) default: // Error condition err = newRestError(req, resp, response) @@ -585,7 +587,7 @@ func (s *Session) GuildCreate(name string) (st *Guild, err error) { Name string `json:"name"` }{name} - body, err := s.RequestWithBucketID("POST", EndpointGuilds, data, EndpointGuilds) + body, err := s.RequestWithBucketID("POST", EndpointGuildCreate, data, EndpointGuildCreate) if err != nil { return } @@ -907,7 +909,7 @@ func (s *Session) GuildChannelsReorder(guildID string, channels []*Channel) (err // GuildInvites returns an array of Invite structures for the given guild // guildID : The ID of a Guild. func (s *Session) GuildInvites(guildID string) (st []*Invite, err error) { - body, err := s.RequestWithBucketID("GET", EndpointGuildInvites(guildID), nil, EndpointGuildInivtes(guildID)) + body, err := s.RequestWithBucketID("GET", EndpointGuildInvites(guildID), nil, EndpointGuildInvites(guildID)) if err != nil { return } @@ -957,6 +959,7 @@ func (s *Session) GuildRoleEdit(guildID, roleID, name string, color int, hoist b // Prevent sending a color int that is too big. if color > 0xFFFFFF { err = fmt.Errorf("color value cannot be larger than 0xFFFFFF") + return nil, err } data := struct { @@ -1020,6 +1023,9 @@ func (s *Session) GuildPruneCount(guildID string, days uint32) (count uint32, er uri := EndpointGuildPrune(guildID) + fmt.Sprintf("?days=%d", days) body, err := s.RequestWithBucketID("GET", uri, nil, EndpointGuildPrune(guildID)) + if err != nil { + return + } err = unmarshal(body, &p) if err != nil { @@ -1204,7 +1210,7 @@ func (s *Session) GuildEmbedEdit(guildID string, enabled bool, channelID string) // Functions specific to Discord Channels // ------------------------------------------------------------------------------------------------ -// Channel returns a Channel strucutre of a specific Channel. +// Channel returns a Channel structure of a specific Channel. // channelID : The ID of the Channel you want returned. func (s *Session) Channel(channelID string) (st *Channel, err error) { body, err := s.RequestWithBucketID("GET", EndpointChannel(channelID), nil, EndpointChannel(channelID)) @@ -1219,12 +1225,16 @@ func (s *Session) Channel(channelID string) (st *Channel, err error) { // ChannelEdit edits the given channel // channelID : The ID of a Channel // name : The new name to assign the channel. -func (s *Session) ChannelEdit(channelID, name string) (st *Channel, err error) { - - data := struct { - Name string `json:"name"` - }{name} +func (s *Session) ChannelEdit(channelID, name string) (*Channel, error) { + return s.ChannelEditComplex(channelID, &ChannelEdit{ + Name: name, + }) +} +// ChannelEditComplex edits an existing channel, replacing the parameters entirely with ChannelEdit struct +// channelID : The ID of a Channel +// data : The channel struct to send +func (s *Session) ChannelEditComplex(channelID string, data *ChannelEdit) (st *Channel, err error) { body, err := s.RequestWithBucketID("PATCH", EndpointChannel(channelID), data, EndpointChannel(channelID)) if err != nil { return @@ -1476,7 +1486,7 @@ func (s *Session) ChannelMessageDelete(channelID, messageID string) (err error) } // ChannelMessagesBulkDelete bulk deletes the messages from the channel for the provided messageIDs. -// If only one messageID is in the slice call channelMessageDelete funciton. +// If only one messageID is in the slice call channelMessageDelete function. // If the slice is empty do nothing. // channelID : The ID of the channel for the messages to delete. // messages : The IDs of the messages to be deleted. A slice of string IDs. A maximum of 100 messages. @@ -1569,16 +1579,14 @@ func (s *Session) ChannelInvites(channelID string) (st []*Invite, err error) { // ChannelInviteCreate creates a new invite for the given channel. // channelID : The ID of a Channel -// i : An Invite struct with the values MaxAge, MaxUses, Temporary, -// and XkcdPass defined. +// i : An Invite struct with the values MaxAge, MaxUses and Temporary defined. func (s *Session) ChannelInviteCreate(channelID string, i Invite) (st *Invite, err error) { data := struct { - MaxAge int `json:"max_age"` - MaxUses int `json:"max_uses"` - Temporary bool `json:"temporary"` - XKCDPass string `json:"xkcdpass"` - }{i.MaxAge, i.MaxUses, i.Temporary, i.XkcdPass} + MaxAge int `json:"max_age"` + MaxUses int `json:"max_uses"` + Temporary bool `json:"temporary"` + }{i.MaxAge, i.MaxUses, i.Temporary} body, err := s.RequestWithBucketID("POST", EndpointChannelInvites(channelID), data, EndpointChannelInvites(channelID)) if err != nil { @@ -1618,7 +1626,7 @@ func (s *Session) ChannelPermissionDelete(channelID, targetID string) (err error // ------------------------------------------------------------------------------------------------ // Invite returns an Invite structure of the given invite -// inviteID : The invite code (or maybe xkcdpass?) +// inviteID : The invite code func (s *Session) Invite(inviteID string) (st *Invite, err error) { body, err := s.RequestWithBucketID("GET", EndpointInvite(inviteID), nil, EndpointInvite("")) @@ -1631,7 +1639,7 @@ func (s *Session) Invite(inviteID string) (st *Invite, err error) { } // InviteDelete deletes an existing invite -// inviteID : the code (or maybe xkcdpass?) of an invite +// inviteID : the code of an invite func (s *Session) InviteDelete(inviteID string) (st *Invite, err error) { body, err := s.RequestWithBucketID("DELETE", EndpointInvite(inviteID), nil, EndpointInvite("")) @@ -1644,7 +1652,7 @@ func (s *Session) InviteDelete(inviteID string) (st *Invite, err error) { } // InviteAccept accepts an Invite to a Guild or Channel -// inviteID : The invite code (or maybe xkcdpass?) +// inviteID : The invite code func (s *Session) InviteAccept(inviteID string) (st *Invite, err error) { body, err := s.RequestWithBucketID("POST", EndpointInvite(inviteID), nil, EndpointInvite("")) diff --git a/restapi_test.go b/restapi_test.go index 7aa4e604d..a2da6344e 100644 --- a/restapi_test.go +++ b/restapi_test.go @@ -227,7 +227,7 @@ func TestGuildMemberNickname(t *testing.T) { t.Skip("Skipping, dg not set.") } - err := dg.GuildMemberNickname(envGuild, "@me/nick", "testnickname") + err := dg.GuildMemberNickname(envGuild, "@me/nick", "B1nzyRocks") if err != nil { t.Errorf("GuildNickname returned error: %+v", err) } diff --git a/state.go b/state.go index 35a8e7578..8158708b3 100644 --- a/state.go +++ b/state.go @@ -531,7 +531,7 @@ func (s *State) PrivateChannel(channelID string) (*Channel, error) { return s.Channel(channelID) } -// Channel gets a channel by ID, it will look in all guilds an private channels. +// Channel gets a channel by ID, it will look in all guilds and private channels. func (s *State) Channel(channelID string) (*Channel, error) { if s == nil { return nil, ErrNilState @@ -816,6 +816,13 @@ func (s *State) OnInterface(se *Session, i interface{}) (err error) { if s.TrackMembers { err = s.MemberRemove(t.Member) } + case *GuildMembersChunk: + if s.TrackMembers { + for i := range t.Members { + t.Members[i].GuildID = t.GuildID + err = s.MemberAdd(t.Members[i]) + } + } case *GuildRoleCreate: if s.TrackRoles { err = s.RoleAdd(t.GuildID, t.Role) diff --git a/structs.go b/structs.go index c3e395667..19d2bad73 100644 --- a/structs.go +++ b/structs.go @@ -14,7 +14,6 @@ package discordgo import ( "encoding/json" "net/http" - "strconv" "sync" "time" @@ -85,6 +84,9 @@ type Session struct { // Stores the last HeartbeatAck that was recieved (in UTC) LastHeartbeatAck time.Time + // used to deal with rate limits + Ratelimiter *RateLimiter + // Event handlers handlersMu sync.RWMutex handlers map[string][]*eventHandlerInstance @@ -96,9 +98,6 @@ type Session struct { // When nil, the session is not listening. listening chan interface{} - // used to deal with rate limits - ratelimiter *RateLimiter - // sequence tracks the current gateway api websocket sequence number sequence *int64 @@ -143,9 +142,9 @@ type Invite struct { MaxAge int `json:"max_age"` Uses int `json:"uses"` MaxUses int `json:"max_uses"` - XkcdPass string `json:"xkcdpass"` Revoked bool `json:"revoked"` Temporary bool `json:"temporary"` + Unique bool `json:"unique"` } // ChannelType is the type of a Channel @@ -171,9 +170,22 @@ type Channel struct { NSFW bool `json:"nsfw"` Position int `json:"position"` Bitrate int `json:"bitrate"` - Recipients []*User `json:"recipient"` + Recipients []*User `json:"recipients"` Messages []*Message `json:"-"` PermissionOverwrites []*PermissionOverwrite `json:"permission_overwrites"` + ParentID string `json:"parent_id"` +} + +// A ChannelEdit holds Channel Feild data for a channel edit. +type ChannelEdit struct { + Name string `json:"name,omitempty"` + Topic string `json:"topic,omitempty"` + NSFW bool `json:"nsfw,omitempty"` + Position int `json:"position"` + Bitrate int `json:"bitrate,omitempty"` + UserLimit int `json:"user_limit,omitempty"` + PermissionOverwrites []*PermissionOverwrite `json:"permission_overwrites,omitempty"` + ParentID string `json:"parent_id,omitempty"` } // A PermissionOverwrite holds permission overwrite data for a Channel @@ -191,6 +203,7 @@ type Emoji struct { Roles []string `json:"roles"` Managed bool `json:"managed"` RequireColons bool `json:"require_colons"` + Animated bool `json:"animated"` } // APIName returns an correctly formatted API name for use in the MessageReactions endpoints. @@ -204,7 +217,7 @@ func (e *Emoji) APIName() string { return e.ID } -// VerificationLevel type defination +// VerificationLevel type definition type VerificationLevel int // Constants for VerificationLevel levels from 0 to 3 inclusive @@ -314,45 +327,58 @@ type Presence struct { Since *int `json:"since"` } +// GameType is the type of "game" (see GameType* consts) in the Game struct +type GameType int + +// Valid GameType values +const ( + GameTypeGame GameType = iota + GameTypeStreaming +) + // A Game struct holds the name of the "playing .." game for a user type Game struct { - Name string `json:"name"` - Type int `json:"type"` - URL string `json:"url,omitempty"` -} - -// UnmarshalJSON unmarshals json to Game struct -func (g *Game) UnmarshalJSON(bytes []byte) error { - temp := &struct { - Name json.Number `json:"name"` - Type json.RawMessage `json:"type"` - URL string `json:"url"` + Name string `json:"name"` + Type GameType `json:"type"` + URL string `json:"url,omitempty"` + Details string `json:"details,omitempty"` + State string `json:"state,omitempty"` + TimeStamps TimeStamps `json:"timestamps,omitempty"` + Assets Assets `json:"assets,omitempty"` + ApplicationID string `json:"application_id,omitempty"` + Instance int8 `json:"instance,omitempty"` + // TODO: Party and Secrets (unknown structure) +} + +// A TimeStamps struct contains start and end times used in the rich presence "playing .." Game +type TimeStamps struct { + EndTimestamp int64 `json:"end,omitempty"` + StartTimestamp int64 `json:"start,omitempty"` +} + +// UnmarshalJSON unmarshals JSON into TimeStamps struct +func (t *TimeStamps) UnmarshalJSON(b []byte) error { + temp := struct { + End float64 `json:"end,omitempty"` + Start float64 `json:"start,omitempty"` }{} - err := json.Unmarshal(bytes, temp) + err := json.Unmarshal(b, &temp) if err != nil { return err } - g.URL = temp.URL - g.Name = temp.Name.String() - - if temp.Type != nil { - err = json.Unmarshal(temp.Type, &g.Type) - if err == nil { - return nil - } - - s := "" - err = json.Unmarshal(temp.Type, &s) - if err == nil { - g.Type, err = strconv.Atoi(s) - } - - return err - } - + t.EndTimestamp = int64(temp.End) + t.StartTimestamp = int64(temp.Start) return nil } +// An Assets struct contains assets and labels used in the rich presence "playing .." Game +type Assets struct { + LargeImageID string `json:"large_image,omitempty"` + SmallImageID string `json:"small_image,omitempty"` + LargeText string `json:"large_text,omitempty"` + SmallText string `json:"small_text,omitempty"` +} + // A Member stores user information for Guild members. type Member struct { GuildID string `json:"guild_id"` @@ -383,7 +409,7 @@ type Settings struct { DeveloperMode bool `json:"developer_mode"` } -// Status type defination +// Status type definition type Status string // Constants for Status with the different current available status diff --git a/user.go b/user.go index 76abdd1d7..a710f2865 100644 --- a/user.go +++ b/user.go @@ -29,7 +29,9 @@ func (u *User) Mention() string { } // AvatarURL returns a URL to the user's avatar. -// size: The size of the user's avatar as a power of two +// size: The size of the user's avatar as a power of two +// if size is an empty string, no size parameter will +// be added to the URL. func (u *User) AvatarURL(size string) string { var URL string if strings.HasPrefix(u.Avatar, "a_") { @@ -38,5 +40,8 @@ func (u *User) AvatarURL(size string) string { URL = EndpointUserAvatar(u.ID, u.Avatar) } - return URL + "?size=" + size + if size != "" { + return URL + "?size=" + size + } + return URL } diff --git a/voice.go b/voice.go index 8f033aa00..3bbf6212b 100644 --- a/voice.go +++ b/voice.go @@ -13,7 +13,6 @@ import ( "encoding/binary" "encoding/json" "fmt" - "log" "net" "strings" "sync" @@ -69,7 +68,7 @@ type VoiceConnection struct { voiceSpeakingUpdateHandlers []VoiceSpeakingUpdateHandler } -// VoiceSpeakingUpdateHandler type provides a function defination for the +// VoiceSpeakingUpdateHandler type provides a function definition for the // VoiceSpeakingUpdate event type VoiceSpeakingUpdateHandler func(vc *VoiceConnection, vs *VoiceSpeakingUpdate) @@ -104,7 +103,7 @@ func (v *VoiceConnection) Speaking(b bool) (err error) { defer v.Unlock() if err != nil { v.speaking = false - log.Println("Speaking() write json error:", err) + v.log(LogError, "Speaking() write json error:", err) return } @@ -181,7 +180,7 @@ func (v *VoiceConnection) Close() { v.log(LogInformational, "closing udp") err := v.udpConn.Close() if err != nil { - log.Println("error closing udp connection: ", err) + v.log(LogError, "error closing udp connection: ", err) } v.udpConn = nil } @@ -247,7 +246,7 @@ type voiceOP2 struct { } // WaitUntilConnected waits for the Voice Connection to -// become ready, if it does not become ready it retuns an err +// become ready, if it does not become ready it returns an err func (v *VoiceConnection) waitUntilConnected() error { v.log(LogInformational, "called") @@ -858,7 +857,7 @@ func (v *VoiceConnection) reconnect() { } if v.session.DataReady == false || v.session.wsConn == nil { - v.log(LogInformational, "cannot reconenct to channel %s with unready session", v.ChannelID) + v.log(LogInformational, "cannot reconnect to channel %s with unready session", v.ChannelID) continue } diff --git a/wsapi.go b/wsapi.go index df87092e0..de66f6931 100644 --- a/wsapi.go +++ b/wsapi.go @@ -15,6 +15,7 @@ import ( "compress/zlib" "encoding/json" "errors" + "fmt" "io" "net/http" "runtime" @@ -45,65 +46,93 @@ type resumePacket struct { } `json:"d"` } -// Open opens a websocket connection to Discord. -func (s *Session) Open() (err error) { - +// Open creates a websocket connection to Discord. +// See: https://discordapp.com/developers/docs/topics/gateway#connecting +func (s *Session) Open() error { s.log(LogInformational, "called") - s.Lock() - defer func() { - if err != nil { - s.Unlock() - } - }() + var err error - // A basic state is a hard requirement for Voice. - if s.State == nil { - state := NewState() - state.TrackChannels = false - state.TrackEmojis = false - state.TrackMembers = false - state.TrackRoles = false - state.TrackVoice = false - s.State = state - } + // Prevent Open or other major Session functions from + // being called while Open is still running. + s.Lock() + defer s.Unlock() + // If the websock is already open, bail out here. if s.wsConn != nil { - err = ErrWSAlreadyOpen - return - } - - if s.VoiceConnections == nil { - s.log(LogInformational, "creating new VoiceConnections map") - s.VoiceConnections = make(map[string]*VoiceConnection) + return ErrWSAlreadyOpen } // Get the gateway to use for the Websocket connection if s.gateway == "" { s.gateway, err = s.Gateway() if err != nil { - return + return err } // Add the version and encoding to the URL s.gateway = s.gateway + "?v=" + APIVersion + "&encoding=json" } + // Connect to the Gateway + s.log(LogInformational, "connecting to gateway %s", s.gateway) header := http.Header{} header.Add("accept-encoding", "zlib") - - s.log(LogInformational, "connecting to gateway %s", s.gateway) s.wsConn, _, err = websocket.DefaultDialer.Dial(s.gateway, header) if err != nil { s.log(LogWarning, "error connecting to gateway %s, %s", s.gateway, err) s.gateway = "" // clear cached gateway - // TODO: should we add a retry block here? - return + s.wsConn = nil // Just to be safe. + return err + } + + defer func() { + // because of this, all code below must set err to the error + // when exiting with an error :) Maybe someone has a better + // way :) + if err != nil { + s.wsConn.Close() + s.wsConn = nil + } + }() + + // The first response from Discord should be an Op 10 (Hello) Packet. + // When processed by onEvent the heartbeat goroutine will be started. + mt, m, err := s.wsConn.ReadMessage() + if err != nil { + return err + } + e, err := s.onEvent(mt, m) + if err != nil { + return err + } + if e.Operation != 10 { + err = fmt.Errorf("expecting Op 10, got Op %d instead", e.Operation) + return err + } + s.log(LogInformational, "Op 10 Hello Packet received from Discord") + s.LastHeartbeatAck = time.Now().UTC() + var h helloOp + if err = json.Unmarshal(e.RawData, &h); err != nil { + err = fmt.Errorf("error unmarshalling helloOp, %s", err) + return err } + // Now we send either an Op 2 Identity if this is a brand new + // connection or Op 6 Resume if we are resuming an existing connection. sequence := atomic.LoadInt64(s.sequence) - if s.sessionID != "" && sequence > 0 { + if s.sessionID == "" && sequence == 0 { + // Send Op 2 Identity Packet + err = s.identify() + if err != nil { + err = fmt.Errorf("error sending identify packet to gateway, %s, %s", s.gateway, err) + return err + } + + } else { + + // Send Op 6 Resume Packet p := resumePacket{} p.Op = 6 p.Data.Token = s.Token @@ -111,34 +140,66 @@ func (s *Session) Open() (err error) { p.Data.Sequence = sequence s.log(LogInformational, "sending resume packet to gateway") + s.wsMutex.Lock() err = s.wsConn.WriteJSON(p) + s.wsMutex.Unlock() if err != nil { - s.log(LogWarning, "error sending gateway resume packet, %s, %s", s.gateway, err) - return + err = fmt.Errorf("error sending gateway resume packet, %s, %s", s.gateway, err) + return err } - } else { - - err = s.identify() - if err != nil { - s.log(LogWarning, "error sending gateway identify packet, %s, %s", s.gateway, err) - return - } } - // Create listening outside of listen, as it needs to happen inside the mutex - // lock. - s.listening = make(chan interface{}) - go s.listen(s.wsConn, s.listening) - s.LastHeartbeatAck = time.Now().UTC() + // A basic state is a hard requirement for Voice. + // We create it here so the below READY/RESUMED packet can populate + // the state :) + // XXX: Move to New() func? + if s.State == nil { + state := NewState() + state.TrackChannels = false + state.TrackEmojis = false + state.TrackMembers = false + state.TrackRoles = false + state.TrackVoice = false + s.State = state + } - s.Unlock() + // Now Discord should send us a READY or RESUMED packet. + mt, m, err = s.wsConn.ReadMessage() + if err != nil { + return err + } + e, err = s.onEvent(mt, m) + if err != nil { + return err + } + if e.Type != `READY` && e.Type != `RESUMED` { + // This is not fatal, but it does not follow their API documentation. + s.log(LogWarning, "Expected READY/RESUMED, instead got:\n%#v\n", e) + } + s.log(LogInformational, "First Packet:\n%#v\n", e) - s.log(LogInformational, "emit connect event") + s.log(LogInformational, "We are now connected to Discord, emitting connect event") s.handleEvent(connectEventType, &Connect{}) + // A VoiceConnections map is a hard requirement for Voice. + // XXX: can this be moved to when opening a voice connection? + if s.VoiceConnections == nil { + s.log(LogInformational, "creating new VoiceConnections map") + s.VoiceConnections = make(map[string]*VoiceConnection) + } + + // Create listening chan outside of listen, as it needs to happen inside the + // mutex lock and needs to exist before calling heartbeat and listen + // go rountines. + s.listening = make(chan interface{}) + + // Start sending heartbeats and reading messages from Discord. + go s.heartbeat(s.wsConn, s.listening, h.HeartbeatInterval) + go s.listen(s.wsConn, s.listening) + s.log(LogInformational, "exiting") - return + return nil } // listen polls the websocket connection for events, it will stop when the @@ -249,7 +310,8 @@ func (s *Session) heartbeat(wsConn *websocket.Conn, listening <-chan interface{} } } -type updateStatusData struct { +// UpdateStatusData ia provided to UpdateStatusComplex() +type UpdateStatusData struct { IdleSince *int `json:"since"` Game *Game `json:"game"` AFK bool `json:"afk"` @@ -258,7 +320,7 @@ type updateStatusData struct { type updateStatusOp struct { Op int `json:"op"` - Data updateStatusData `json:"d"` + Data UpdateStatusData `json:"d"` } // UpdateStreamingStatus is used to update the user's streaming status. @@ -270,13 +332,7 @@ func (s *Session) UpdateStreamingStatus(idle int, game string, url string) (err s.log(LogInformational, "called") - s.RLock() - defer s.RUnlock() - if s.wsConn == nil { - return ErrWSNotFound - } - - usd := updateStatusData{ + usd := UpdateStatusData{ Status: "online", } @@ -285,9 +341,9 @@ func (s *Session) UpdateStreamingStatus(idle int, game string, url string) (err } if game != "" { - gameType := 0 + gameType := GameTypeGame if url != "" { - gameType = 1 + gameType = GameTypeStreaming } usd.Game = &Game{ Name: game, @@ -296,6 +352,18 @@ func (s *Session) UpdateStreamingStatus(idle int, game string, url string) (err } } + return s.UpdateStatusComplex(usd) +} + +// UpdateStatusComplex allows for sending the raw status update data untouched by discordgo. +func (s *Session) UpdateStatusComplex(usd UpdateStatusData) (err error) { + + s.RLock() + defer s.RUnlock() + if s.wsConn == nil { + return ErrWSNotFound + } + s.wsMutex.Lock() err = s.wsConn.WriteJSON(updateStatusOp{3, usd}) s.wsMutex.Unlock() @@ -357,9 +425,7 @@ func (s *Session) RequestGuildMembers(guildID, query string, limit int) (err err // // If you use the AddHandler() function to register a handler for the // "OnEvent" event then all events will be passed to that handler. -// -// TODO: You may also register a custom event handler entirely using... -func (s *Session) onEvent(messageType int, message []byte) { +func (s *Session) onEvent(messageType int, message []byte) (*Event, error) { var err error var reader io.Reader @@ -371,7 +437,7 @@ func (s *Session) onEvent(messageType int, message []byte) { z, err2 := zlib.NewReader(reader) if err2 != nil { s.log(LogError, "error uncompressing websocket message, %s", err) - return + return nil, err2 } defer func() { @@ -389,7 +455,7 @@ func (s *Session) onEvent(messageType int, message []byte) { decoder := json.NewDecoder(reader) if err = decoder.Decode(&e); err != nil { s.log(LogError, "error decoding websocket message, %s", err) - return + return e, err } s.log(LogDebug, "Op: %d, Seq: %d, Type: %s, Data: %s\n\n", e.Operation, e.Sequence, e.Type, string(e.RawData)) @@ -403,10 +469,10 @@ func (s *Session) onEvent(messageType int, message []byte) { s.wsMutex.Unlock() if err != nil { s.log(LogError, "error sending heartbeat in response to Op1") - return + return e, err } - return + return e, nil } // Reconnect @@ -415,7 +481,7 @@ func (s *Session) onEvent(messageType int, message []byte) { s.log(LogInformational, "Closing and reconnecting in response to Op7") s.Close() s.reconnect() - return + return e, nil } // Invalid Session @@ -427,20 +493,15 @@ func (s *Session) onEvent(messageType int, message []byte) { err = s.identify() if err != nil { s.log(LogWarning, "error sending gateway identify packet, %s, %s", s.gateway, err) - return + return e, err } - return + return e, nil } if e.Operation == 10 { - var h helloOp - if err = json.Unmarshal(e.RawData, &h); err != nil { - s.log(LogError, "error unmarshalling helloOp, %s", err) - } else { - go s.heartbeat(s.wsConn, s.listening, h.HeartbeatInterval) - } - return + // Op10 is handled by Open() + return e, nil } if e.Operation == 11 { @@ -448,7 +509,7 @@ func (s *Session) onEvent(messageType int, message []byte) { s.LastHeartbeatAck = time.Now().UTC() s.Unlock() s.log(LogInformational, "got heartbeat ACK") - return + return e, nil } // Do not try to Dispatch a non-Dispatch Message @@ -456,7 +517,7 @@ func (s *Session) onEvent(messageType int, message []byte) { // But we probably should be doing something with them. // TEMP s.log(LogWarning, "unknown Op: %d, Seq: %d, Type: %s, Data: %s, message: %s", e.Operation, e.Sequence, e.Type, string(e.RawData), string(message)) - return + return e, nil } // Store the message sequence @@ -485,6 +546,8 @@ func (s *Session) onEvent(messageType int, message []byte) { // For legacy reasons, we send the raw event also, this could be useful for handling unknown events. s.handleEvent(eventEventType, e) + + return e, nil } // ------------------------------------------------------------------------------------------------ @@ -610,7 +673,7 @@ func (s *Session) onVoiceServerUpdate(st *VoiceServerUpdate) { voice.GuildID = st.GuildID voice.Unlock() - // Open a conenction to the voice server + // Open a connection to the voice server err := voice.open() if err != nil { s.log(LogError, "onVoiceServerUpdate voice.open, %s", err)