diff --git a/.golangci.yml b/.golangci.yml index fc124c57..754438b8 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -29,6 +29,14 @@ linters: - bodyclose # Checks HTTP response body closed settings: + errcheck: + # defer rows.Close() / stmt.Close() etc. are idiomatic and safe to ignore. + exclude-functions: + - (*database/sql.Rows).Close + - (*database/sql.Stmt).Close + - (*database/sql.DB).Close + - (*database/sql.Tx).Close + gocritic: disabled-checks: - ifElseChain diff --git a/base/db.go b/base/db.go index 24d22a68..172bbeab 100644 --- a/base/db.go +++ b/base/db.go @@ -1,10 +1,14 @@ package base import ( + "context" "database/sql" + "database/sql/driver" + "errors" "fmt" "time" + "github.com/go-sql-driver/mysql" "golang.org/x/oauth2" ) @@ -21,18 +25,46 @@ func NewDB(db *sql.DB) *DB { } } -func (d *DB) RunTxn(fn func(tx *sql.Tx) error) error { - tx, err := d.Begin() +func (d *DB) RunTxnContext(ctx context.Context, fn func(tx *sql.Tx) error) error { + tx, err := d.BeginTx(ctx, nil) if err != nil { return err } if err := fn(tx); err != nil { - if rerr := tx.Rollback(); rerr != nil { + if rerr := tx.Rollback(); rerr != nil && !errors.Is(rerr, sql.ErrTxDone) { fmt.Printf("unable to rollback: %v", rerr) } + return normalizeCtxErr(ctx, err) + } + return normalizeCommitErr(ctx, tx.Commit()) +} + +// normalizeCommitErr recovers the true cause of a failed Commit. When the +// context is cancelled, the driver rolls back the transaction and marks it +// done, so tx.Commit() returns sql.ErrTxDone instead of context.Canceled. +// Returning ctx.Err() lets callers distinguish a client abort from a real DB error. +func normalizeCommitErr(ctx context.Context, err error) error { + if err == nil { + return nil + } + if ctxErr := ctx.Err(); ctxErr != nil && errors.Is(err, sql.ErrTxDone) { + return ctxErr + } + return err +} + +// normalizeCtxErr maps opaque driver errors (ErrBadConn, ErrInvalidConn) back +// to ctx.Err() when the context is already done. This lets callers distinguish +// a client abort from a real storage failure without checking at every call site. +func normalizeCtxErr(ctx context.Context, err error) error { + if err == nil || errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return err } - return tx.Commit() + if ctxErr := ctx.Err(); ctxErr != nil && + (errors.Is(err, driver.ErrBadConn) || errors.Is(err, mysql.ErrInvalidConn)) { + return ctxErr + } + return err } type BaseOAuthDB struct { //nolint @@ -45,9 +77,9 @@ func NewBaseOAuthDB(db *sql.DB) *BaseOAuthDB { } } -func (d *BaseOAuthDB) GetState(state string) (*OAuthRequest, error) { +func (d *BaseOAuthDB) GetState(ctx context.Context, state string) (*OAuthRequest, error) { var oauthState OAuthRequest - row := d.QueryRow(`SELECT identifier, conv_id, msg_id, is_complete + row := d.QueryRowContext(ctx, `SELECT identifier, conv_id, msg_id, is_complete FROM oauth_state WHERE state = ?`, state) err := row.Scan(&oauthState.TokenIdentifier, &oauthState.ConvID, @@ -62,9 +94,8 @@ func (d *BaseOAuthDB) GetState(state string) (*OAuthRequest, error) { } } -func (d *BaseOAuthDB) PutState(state string, oauthState *OAuthRequest) error { - err := d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(`INSERT INTO oauth_state +func (d *BaseOAuthDB) PutState(ctx context.Context, state string, oauthState *OAuthRequest) error { + _, err := d.ExecContext(ctx, `INSERT INTO oauth_state (state, identifier, conv_id, msg_id) VALUES (?, ?, ?, ?) ON DUPLICATE KEY UPDATE @@ -72,18 +103,13 @@ func (d *BaseOAuthDB) PutState(state string, oauthState *OAuthRequest) error { conv_id=VALUES(conv_id), msg_id=VALUES(msg_id) `, state, oauthState.TokenIdentifier, oauthState.ConvID, oauthState.MsgID) - return err - }) return err } -func (d *BaseOAuthDB) CompleteState(state string) error { - err := d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(`UPDATE oauth_state +func (d *BaseOAuthDB) CompleteState(ctx context.Context, state string) error { + _, err := d.ExecContext(ctx, `UPDATE oauth_state SET is_complete=true WHERE state = ?`, state) - return err - }) return err } @@ -97,10 +123,10 @@ func NewOAuthDB(db *sql.DB) *OAuthDB { } } -func (d *OAuthDB) GetToken(identifier string) (*oauth2.Token, error) { +func (d *OAuthDB) GetToken(ctx context.Context, identifier string) (*oauth2.Token, error) { var token oauth2.Token var expiry int64 - row := d.QueryRow(`SELECT access_token, token_type, refresh_token, ROUND(UNIX_TIMESTAMP(expiry)) + row := d.QueryRowContext(ctx, `SELECT access_token, token_type, refresh_token, ROUND(UNIX_TIMESTAMP(expiry)) FROM oauth WHERE identifier = ?`, identifier) err := row.Scan(&token.AccessToken, &token.TokenType, @@ -116,9 +142,8 @@ func (d *OAuthDB) GetToken(identifier string) (*oauth2.Token, error) { } } -func (d *OAuthDB) PutToken(identifier string, token *oauth2.Token) error { - err := d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(`INSERT INTO oauth +func (d *OAuthDB) PutToken(ctx context.Context, identifier string, token *oauth2.Token) error { + _, err := d.ExecContext(ctx, `INSERT INTO oauth (identifier, access_token, token_type, refresh_token, expiry, ctime, mtime) VALUES (?, ?, ?, ?, ?, NOW(), NOW()) ON DUPLICATE KEY UPDATE @@ -127,16 +152,10 @@ func (d *OAuthDB) PutToken(identifier string, token *oauth2.Token) error { expiry=VALUES(expiry), mtime=VALUES(mtime) `, identifier, token.AccessToken, token.TokenType, token.RefreshToken, token.Expiry) - return err - }) return err } -func (d *OAuthDB) DeleteToken(identifier string) error { - err := d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(`DELETE FROM oauth - WHERE identifier = ?`, identifier) - return err - }) +func (d *OAuthDB) DeleteToken(ctx context.Context, identifier string) error { + _, err := d.ExecContext(ctx, `DELETE FROM oauth WHERE identifier = ?`, identifier) return err } diff --git a/base/multi.go b/base/multi.go index 6b764ee1..b301cf61 100644 --- a/base/multi.go +++ b/base/multi.go @@ -1,6 +1,7 @@ package base import ( + "context" "database/sql" "fmt" "sync" @@ -58,20 +59,17 @@ func (m *multi) IsLeader() bool { } func (m *multi) heartbeat() { + ctx := context.Background() // update ourselves first - err := m.db.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - INSERT INTO heartbeats (id, name, mtime) - VALUES (?, ?, NOW(6)) ON DUPLICATE KEY UPDATE mtime=NOW(6) - `, m.id, m.name) - return err - }) - if err != nil { + if _, err := m.db.ExecContext(ctx, ` + INSERT INTO heartbeats (id, name, mtime) + VALUES (?, ?, NOW(6)) ON DUPLICATE KEY UPDATE mtime=NOW(6) + `, m.id, m.name); err != nil { m.Errorf("failed to register heartbeat tx: %s", err) return } // see if we are the leader - row := m.db.QueryRow(fmt.Sprintf(` + row := m.db.QueryRowContext(ctx, fmt.Sprintf(` SELECT id FROM heartbeats WHERE mtime > NOW(6) - INTERVAL %d SECOND AND name = ? ORDER BY id DESC @@ -101,13 +99,11 @@ func (m *multi) heartbeat() { } func (m *multi) deregister() { - err := m.db.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - DELETE from heartbeats - WHERE id = ? OR mtime < NOW() - INTERVAL 1 MINUTE - `, m.id) - return err - }) + ctx := context.Background() + _, err := m.db.ExecContext(ctx, ` + DELETE from heartbeats + WHERE id = ? OR mtime < NOW() - INTERVAL 1 MINUTE + `, m.id) if err != nil { m.Errorf("deregister: failed to execute : %s", err) } diff --git a/base/oauth.go b/base/oauth.go index 700b9067..86e25a70 100644 --- a/base/oauth.go +++ b/base/oauth.go @@ -22,13 +22,13 @@ func (e OAuthRequiredError) Error() string { } type OAuthStorage interface { - GetToken(identifier string) (*oauth2.Token, error) - PutToken(identifier string, token *oauth2.Token) error - DeleteToken(identifier string) error + GetToken(ctx context.Context, identifier string) (*oauth2.Token, error) + PutToken(ctx context.Context, identifier string, token *oauth2.Token) error + DeleteToken(ctx context.Context, identifier string) error - GetState(state string) (*OAuthRequest, error) - PutState(state string, req *OAuthRequest) error - CompleteState(state string) error + GetState(ctx context.Context, state string) (*OAuthRequest, error) + PutState(ctx context.Context, state string, req *OAuthRequest) error + CompleteState(ctx context.Context, state string) error } type OAuthHTTPSrv struct { @@ -36,7 +36,7 @@ type OAuthHTTPSrv struct { kbc *kbchat.API oauth *oauth2.Config storage OAuthStorage - callback func(msg chat1.MsgSummary, identifier string) error + callback func(ctx context.Context, msg chat1.MsgSummary, identifier string) error htmlTitle string htmlLogoB64 string htmlLogoSrc string @@ -48,7 +48,7 @@ func NewOAuthHTTPSrv( debugConfig *ChatDebugOutputConfig, oauth *oauth2.Config, storage OAuthStorage, - callback func(msg chat1.MsgSummary, identifier string) error, + callback func(ctx context.Context, msg chat1.MsgSummary, identifier string) error, htmlTitle string, htmlLogoB64 string, urlPrefix string, @@ -100,8 +100,11 @@ func (o *OAuthHTTPSrv) oauthHandler(w http.ResponseWriter, r *http.Request) { query := r.URL.Query() state := query.Get("state") + // WithoutCancel: the browser may close after the redirect; DB writes and the + // HandleAuth callback (which sends Keybase messages) must finish regardless. + ctx := context.WithoutCancel(r.Context()) - req, err := o.storage.GetState(state) + req, err := o.storage.GetState(ctx, state) if err != nil { err = fmt.Errorf("could not get state %q: %v", state, err) return @@ -128,15 +131,15 @@ func (o *OAuthHTTPSrv) oauthHandler(w http.ResponseWriter, r *http.Request) { o.showOAuthError(w) return } - token, err := o.oauth.Exchange(context.TODO(), code) + token, err := o.oauth.Exchange(ctx, code) if err != nil { return } - if err = o.storage.PutToken(req.TokenIdentifier, token); err != nil { + if err = o.storage.PutToken(ctx, req.TokenIdentifier, token); err != nil { return } - if err = o.storage.CompleteState(state); err != nil { + if err = o.storage.CompleteState(ctx, state); err != nil { return } callbackMsg, err := o.getCallbackMsg(*req) @@ -144,7 +147,7 @@ func (o *OAuthHTTPSrv) oauthHandler(w http.ResponseWriter, r *http.Request) { return } - if err = o.callback(callbackMsg, req.TokenIdentifier); err != nil { + if err = o.callback(ctx, callbackMsg, req.TokenIdentifier); err != nil { return } @@ -189,6 +192,7 @@ type GetOAuthOpts struct { } func GetOAuthClient( + ctx context.Context, tokenIdentifier string, callbackMsg chat1.MsgSummary, kbc *kbchat.API, @@ -196,7 +200,7 @@ func GetOAuthClient( storage OAuthStorage, opts GetOAuthOpts, ) (*http.Client, error) { - token, err := storage.GetToken(tokenIdentifier) + token, err := storage.GetToken(ctx, tokenIdentifier) if err != nil { return nil, err } @@ -216,7 +220,7 @@ func GetOAuthClient( if err != nil { return nil, err } - if err := storage.PutState(state, &OAuthRequest{ + if err := storage.PutState(ctx, state, &OAuthRequest{ TokenIdentifier: tokenIdentifier, ConvID: callbackMsg.ConvID, MsgID: callbackMsg.Id, @@ -257,16 +261,16 @@ func GetOAuthClient( } // renew token if token.Expiry.Before(time.Now()) { - newToken, err := config.TokenSource(context.Background(), token).Token() + newToken, err := config.TokenSource(ctx, token).Token() if err != nil { return nil, fmt.Errorf("unable to renew token: %s", err) } - err = storage.PutToken(tokenIdentifier, newToken) + err = storage.PutToken(ctx, tokenIdentifier, newToken) if err != nil { return nil, fmt.Errorf("unable to update token: %s", err) } token = newToken } - return config.Client(context.Background(), token), nil + return config.Client(ctx, token), nil } diff --git a/base/server.go b/base/server.go index 04b1cc14..5ae56eab 100644 --- a/base/server.go +++ b/base/server.go @@ -1,7 +1,9 @@ package base import ( + "context" "database/sql" + "errors" "fmt" "io" "os" @@ -18,8 +20,8 @@ import ( ) type Handler interface { - HandleCommand(chat1.MsgSummary) error - HandleNewConv(chat1.ConvSummary) error + HandleCommand(context.Context, chat1.MsgSummary) error + HandleNewConv(context.Context, chat1.ConvSummary) error } type Shutdowner interface { @@ -160,9 +162,18 @@ func (s *Server) Listen(handler Handler) (err error) { s.Lock() shutdownCh := s.shutdownCh s.Unlock() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + select { + case <-shutdownCh: + cancel() + case <-ctx.Done(): + } + }() eg := &errgroup.Group{} - s.GoWithRecover(eg, func() error { return s.listenForMsgs(shutdownCh, sub, handler) }) - s.GoWithRecover(eg, func() error { return s.listenForConvs(shutdownCh, sub, handler) }) + s.GoWithRecover(eg, func() error { return s.listenForMsgs(ctx, shutdownCh, sub, handler) }) + s.GoWithRecover(eg, func() error { return s.listenForConvs(ctx, shutdownCh, sub, handler) }) s.GoWithRecover(eg, func() error { return s.multi.Heartbeat(shutdownCh) }) if err := eg.Wait(); err != nil { s.Debug("wait error: %s", err) @@ -171,7 +182,7 @@ func (s *Server) Listen(handler Handler) (err error) { return nil } -func (s *Server) listenForMsgs(shutdownCh chan struct{}, sub *kbchat.Subscription, handler Handler) (err error) { +func (s *Server) listenForMsgs(ctx context.Context, shutdownCh chan struct{}, sub *kbchat.Subscription, handler Handler) (err error) { for { select { case <-shutdownCh: @@ -224,16 +235,19 @@ func (s *Server) listenForMsgs(shutdownCh chan struct{}, sub *kbchat.Subscriptio } } - err = handler.HandleCommand(msg) - switch err := err.(type) { - case nil, OAuthRequiredError: + err = handler.HandleCommand(ctx, msg) + switch { + case err == nil: + case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded): + s.Debug("listenForMsgs: suppressing shutdown context error: %v", err) + case errors.As(err, new(OAuthRequiredError)): default: s.ChatErrorf(msg.ConvID, "listenForMsgs: unable to HandleCommand: %v", err) } } } -func (s *Server) listenForConvs(shutdownCh chan struct{}, sub *kbchat.Subscription, handler Handler) error { +func (s *Server) listenForConvs(ctx context.Context, shutdownCh chan struct{}, sub *kbchat.Subscription, handler Handler) error { for { select { case <-shutdownCh: @@ -253,7 +267,11 @@ func (s *Server) listenForConvs(shutdownCh chan struct{}, sub *kbchat.Subscripti continue } - if err := handler.HandleNewConv(c.Conversation); err != nil { + if err := handler.HandleNewConv(ctx, c.Conversation); err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + s.Debug("listenForConvs: suppressing shutdown context error: %v", err) + continue + } s.Errorf("listenForConvs: unable to HandleNewConv: %v", err) } } diff --git a/canarybot/canarybot/handler.go b/canarybot/canarybot/handler.go index 451a0720..4e05be1c 100644 --- a/canarybot/canarybot/handler.go +++ b/canarybot/canarybot/handler.go @@ -1,6 +1,7 @@ package canarybot import ( + "context" "strings" "github.com/keybase/go-keybase-chat-bot/kbchat" @@ -35,12 +36,12 @@ func (h *Handler) handleEcho(cmd string, msg chat1.MsgSummary) error { return nil } -func (h *Handler) HandleNewConv(conv chat1.ConvSummary) error { +func (h *Handler) HandleNewConv(_ context.Context, conv chat1.ConvSummary) error { welcomeMsg := "Hey there I'm canarybot. Seems like I'm alive because you're getting this message. Happy days." return base.HandleNewTeam(h.stats, h.DebugOutput, h.kbc, conv, welcomeMsg) } -func (h *Handler) HandleCommand(msg chat1.MsgSummary) error { +func (h *Handler) HandleCommand(_ context.Context, msg chat1.MsgSummary) error { if msg.Content.Text == nil || !strings.HasPrefix(msg.Content.Text.Body, "!canary") { return nil } diff --git a/elastiwatch/elastiwatch/db.go b/elastiwatch/elastiwatch/db.go index 94f6e1ae..a7fa4eaa 100644 --- a/elastiwatch/elastiwatch/db.go +++ b/elastiwatch/elastiwatch/db.go @@ -1,8 +1,8 @@ package elastiwatch import ( + "context" "database/sql" - "fmt" "time" "github.com/keybase/managed-bots/base" @@ -18,13 +18,11 @@ func NewDB(db *sql.DB) *DB { } } -func (d *DB) Create(regex, author string) error { - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - INSERT INTO deferrals (regex, author, ctime) VALUES (?, ?, NOW()) - `, regex, author) - return err - }) +func (d *DB) Create(ctx context.Context, regex, author string) error { + _, err := d.ExecContext(ctx, ` + INSERT INTO deferrals (regex, author, ctime) VALUES (?, ?, NOW()) + `, regex, author) + return err } type Deferral struct { @@ -34,18 +32,14 @@ type Deferral struct { Ctime time.Time } -func (d *DB) List() (res []Deferral, err error) { - rows, err := d.Query(` +func (d *DB) List(ctx context.Context) (res []Deferral, err error) { + rows, err := d.QueryContext(ctx, ` SELECT id, regex, author, ctime FROM deferrals `) if err != nil { return res, err } - defer func() { - if cerr := rows.Close(); cerr != nil { - fmt.Printf("elastiwatch: failed to close rows: %v\n", cerr) - } - }() + defer rows.Close() for rows.Next() { var def Deferral if err := rows.Scan(&def.ID, &def.Regex, &def.Author, &def.Ctime); err != nil { @@ -53,14 +47,10 @@ func (d *DB) List() (res []Deferral, err error) { } res = append(res, def) } - return res, nil + return res, rows.Err() } -func (d *DB) Remove(id int) error { - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - DELETE FROM deferrals WHERE id = ? - `, id) - return err - }) +func (d *DB) Remove(ctx context.Context, id int) error { + _, err := d.ExecContext(ctx, `DELETE FROM deferrals WHERE id = ?`, id) + return err } diff --git a/elastiwatch/elastiwatch/handler.go b/elastiwatch/elastiwatch/handler.go index 1abb7e48..ef1af9c7 100644 --- a/elastiwatch/elastiwatch/handler.go +++ b/elastiwatch/elastiwatch/handler.go @@ -1,6 +1,7 @@ package elastiwatch import ( + "context" "fmt" "strconv" "strings" @@ -31,7 +32,7 @@ func NewHandler(kbc *kbchat.API, debugConfig *base.ChatDebugOutputConfig, httpSr } } -func (h *Handler) handleDefer(convID chat1.ConvIDStr, author, cmd string) error { +func (h *Handler) handleDefer(ctx context.Context, convID chat1.ConvIDStr, author, cmd string) error { toks := strings.Split(cmd, " ") if len(toks) < 3 { h.ChatEcho(convID, "must specify a regular expression") @@ -39,15 +40,15 @@ func (h *Handler) handleDefer(convID chat1.ConvIDStr, author, cmd string) error } regex := strings.Join(toks[2:], " ") h.ChatEcho(convID, "adding deferral: %s", regex) - if err := h.db.Create(regex, author); err != nil { + if err := h.db.Create(ctx, regex, author); err != nil { return err } h.ChatEcho(convID, "Success!") return nil } -func (h *Handler) handleDeferrals(convID chat1.ConvIDStr, _ string) error { - deferrals, err := h.db.List() +func (h *Handler) handleDeferrals(ctx context.Context, convID chat1.ConvIDStr, _ string) error { + deferrals, err := h.db.List(ctx) if err != nil { return err } @@ -63,7 +64,7 @@ func (h *Handler) handleDeferrals(convID chat1.ConvIDStr, _ string) error { return nil } -func (h *Handler) handleUndefer(convID chat1.ConvIDStr, cmd string) error { +func (h *Handler) handleUndefer(ctx context.Context, convID chat1.ConvIDStr, cmd string) error { toks := strings.Split(cmd, " ") if len(toks) < 3 { h.ChatEcho(convID, "must specify an ID") @@ -75,7 +76,7 @@ func (h *Handler) handleUndefer(convID chat1.ConvIDStr, cmd string) error { return nil } h.ChatEcho(convID, "removing deferral: %d", id) - if err := h.db.Remove(int(id)); err != nil { + if err := h.db.Remove(ctx, int(id)); err != nil { return err } h.ChatEcho(convID, "Success!") @@ -87,24 +88,24 @@ func (h *Handler) handleDump() error { return nil } -func (h *Handler) HandleCommand(msg chat1.MsgSummary) error { +func (h *Handler) HandleCommand(ctx context.Context, msg chat1.MsgSummary) error { if msg.Content.Text == nil { return nil } cmd := strings.TrimSpace(msg.Content.Text.Body) switch { case strings.HasPrefix(cmd, "!elastiwatch defer"): - return h.handleDefer(msg.ConvID, msg.Sender.Username, cmd) + return h.handleDefer(ctx, msg.ConvID, msg.Sender.Username, cmd) case strings.HasPrefix(cmd, "!elastiwatch list-defers"): - return h.handleDeferrals(msg.ConvID, cmd) + return h.handleDeferrals(ctx, msg.ConvID, cmd) case strings.HasPrefix(cmd, "!elastiwatch undefer"): - return h.handleUndefer(msg.ConvID, cmd) + return h.handleUndefer(ctx, msg.ConvID, cmd) case strings.HasPrefix(cmd, "!elastiwatch dump"): return h.handleDump() } return nil } -func (h *Handler) HandleNewConv(chat1.ConvSummary) error { +func (h *Handler) HandleNewConv(context.Context, chat1.ConvSummary) error { return nil } diff --git a/elastiwatch/elastiwatch/logs.go b/elastiwatch/elastiwatch/logs.go index d911e221..105fd0e6 100644 --- a/elastiwatch/elastiwatch/logs.go +++ b/elastiwatch/elastiwatch/logs.go @@ -93,7 +93,7 @@ func (l *LogWatch) alertEmail(subject string, chunks []chunk) { func (l *LogWatch) filterEntries(entries []*entry) (res []*entry) { // get regexes - deferrals, err := l.db.List() + deferrals, err := l.db.List(context.Background()) if err != nil { l.Errorf("failed to get filter list: %s", err) return entries diff --git a/gcalbot/gcalbot/account.go b/gcalbot/gcalbot/account.go index ca2839a7..a60dce23 100644 --- a/gcalbot/gcalbot/account.go +++ b/gcalbot/gcalbot/account.go @@ -15,9 +15,9 @@ import ( "google.golang.org/api/option" ) -func (h *Handler) handleAccountsList(msg chat1.MsgSummary) error { +func (h *Handler) handleAccountsList(ctx context.Context, msg chat1.MsgSummary) error { username := msg.Sender.Username - accounts, err := h.db.GetAccountListForUsername(username) + accounts, err := h.db.GetAccountListForUsername(ctx, username) if err != nil { return fmt.Errorf("error fetching accounts from database %q", err) } @@ -37,7 +37,7 @@ func (h *Handler) handleAccountsList(msg chat1.MsgSummary) error { return nil } -func (h *Handler) handleAccountsConnect(msg chat1.MsgSummary, args []string) error { +func (h *Handler) handleAccountsConnect(ctx context.Context, msg chat1.MsgSummary, args []string) error { if len(args) != 1 { h.ChatEcho(msg.ConvID, "Invalid number of arguments.") return nil @@ -46,7 +46,7 @@ func (h *Handler) handleAccountsConnect(msg chat1.MsgSummary, args []string) err keybaseUsername := msg.Sender.Username accountNickname := args[0] - exists, err := h.db.ExistsAccount(keybaseUsername, accountNickname) + exists, err := h.db.ExistsAccount(ctx, keybaseUsername, accountNickname) if err != nil { return fmt.Errorf("error checking for account: %s", err) } else if exists { @@ -54,10 +54,10 @@ func (h *Handler) handleAccountsConnect(msg chat1.MsgSummary, args []string) err return nil } - return h.requestOAuth(msg, accountNickname) + return h.requestOAuth(ctx, msg, accountNickname) } -func (h *Handler) handleAccountsDisconnect(msg chat1.MsgSummary, args []string) error { +func (h *Handler) handleAccountsDisconnect(ctx context.Context, msg chat1.MsgSummary, args []string) error { if len(args) != 1 { h.ChatEcho(msg.ConvID, "Invalid number of arguments.") return nil @@ -66,7 +66,7 @@ func (h *Handler) handleAccountsDisconnect(msg chat1.MsgSummary, args []string) keybaseUsername := msg.Sender.Username accountNickname := args[0] - exists, err := h.db.ExistsAccount(keybaseUsername, accountNickname) + exists, err := h.db.ExistsAccount(ctx, keybaseUsername, accountNickname) if err != nil { return fmt.Errorf("error checking for account: %s", err) } else if !exists { @@ -74,7 +74,7 @@ func (h *Handler) handleAccountsDisconnect(msg chat1.MsgSummary, args []string) return nil } - err = h.deleteAccount(keybaseUsername, accountNickname) + err = h.deleteAccount(ctx, keybaseUsername, accountNickname) if err != nil { return err } @@ -83,18 +83,18 @@ func (h *Handler) handleAccountsDisconnect(msg chat1.MsgSummary, args []string) return nil } -func (h *Handler) deleteAccount(keybaseUsername, accountNickname string) error { - account, err := h.db.GetAccount(keybaseUsername, accountNickname) +func (h *Handler) deleteAccount(ctx context.Context, keybaseUsername, accountNickname string) error { + account, err := h.db.GetAccount(ctx, keybaseUsername, accountNickname) if err != nil || account == nil { return fmt.Errorf("error getting account: %s", err) } - srv, err := GetCalendarService(account, h.oauth, h.db) + srv, err := GetCalendarService(ctx, account, h.oauth, h.db) if err != nil { return err } - channels, err := h.db.GetChannelListByAccount(account) + channels, err := h.db.GetChannelListByAccount(ctx, account) if err != nil { return err } @@ -118,25 +118,25 @@ func (h *Handler) deleteAccount(keybaseUsername, accountNickname string) error { } // cascading delete of account, oauth, subscriptions, channels and invites - err = h.db.DeleteAccount(keybaseUsername, accountNickname) + err = h.db.DeleteAccount(ctx, keybaseUsername, accountNickname) return err } -func GetCalendarService(account *Account, config *oauth2.Config, db *DB) (srv *calendar.Service, err error) { +func GetCalendarService(ctx context.Context, account *Account, config *oauth2.Config, db *DB) (srv *calendar.Service, err error) { if account.Token.Expiry.Before(time.Now()) { - newToken, err := config.TokenSource(context.Background(), &account.Token).Token() + newToken, err := config.TokenSource(ctx, &account.Token).Token() if err != nil { return nil, err } if newToken.AccessToken != account.Token.AccessToken { account.Token = *newToken - err = db.InsertAccount(*account) + err = db.InsertAccount(ctx, *account) if err != nil { return nil, fmt.Errorf("unable to update account token: %s", err) } } } - client := config.Client(context.Background(), &account.Token) - return calendar.NewService(context.Background(), option.WithHTTPClient(client)) + client := config.Client(ctx, &account.Token) + return calendar.NewService(ctx, option.WithHTTPClient(client)) } diff --git a/gcalbot/gcalbot/calendar.go b/gcalbot/gcalbot/calendar.go index 98d7b809..f2e9e990 100644 --- a/gcalbot/gcalbot/calendar.go +++ b/gcalbot/gcalbot/calendar.go @@ -8,7 +8,7 @@ import ( "google.golang.org/api/calendar/v3" ) -func (h *Handler) handleCalendarsList(msg chat1.MsgSummary, args []string) error { +func (h *Handler) handleCalendarsList(ctx context.Context, msg chat1.MsgSummary, args []string) error { if len(args) != 1 { h.ChatEcho(msg.ConvID, "Invalid number of arguments.") return nil @@ -17,7 +17,7 @@ func (h *Handler) handleCalendarsList(msg chat1.MsgSummary, args []string) error keybaseUsername := msg.Sender.Username accountNickname := args[0] - account, err := h.db.GetAccount(keybaseUsername, accountNickname) + account, err := h.db.GetAccount(ctx, keybaseUsername, accountNickname) if err != nil { return err } else if account == nil { @@ -25,12 +25,12 @@ func (h *Handler) handleCalendarsList(msg chat1.MsgSummary, args []string) error return nil } - srv, err := GetCalendarService(account, h.oauth, h.db) + srv, err := GetCalendarService(ctx, account, h.oauth, h.db) if err != nil { return err } - calendarList, err := getCalendarList(srv) + calendarList, err := getCalendarList(ctx, srv) if err != nil { return err } @@ -55,8 +55,8 @@ func (h *Handler) handleCalendarsList(msg chat1.MsgSummary, args []string) error return nil } -func getCalendarList(srv *calendar.Service) (list []*calendar.CalendarListEntry, err error) { - err = srv.CalendarList.List().Pages(context.Background(), func(page *calendar.CalendarList) error { +func getCalendarList(ctx context.Context, srv *calendar.Service) (list []*calendar.CalendarListEntry, err error) { + err = srv.CalendarList.List().Pages(ctx, func(page *calendar.CalendarList) error { list = append(list, page.Items...) return nil }) diff --git a/gcalbot/gcalbot/db.go b/gcalbot/gcalbot/db.go index 08514cdd..9f0f8ab6 100644 --- a/gcalbot/gcalbot/db.go +++ b/gcalbot/gcalbot/db.go @@ -1,6 +1,7 @@ package gcalbot import ( + "context" "database/sql" "strings" "time" @@ -26,9 +27,9 @@ func NewDB( } // OAuth state -func (d *DB) GetState(state string) (*OAuthRequest, error) { +func (d *DB) GetState(ctx context.Context, state string) (*OAuthRequest, error) { var oauthState OAuthRequest - row := d.QueryRow(` + row := d.QueryRowContext(ctx, ` SELECT keybase_username, account_nickname, keybase_conv_id, is_complete FROM oauth_state WHERE state = ? @@ -45,54 +46,48 @@ func (d *DB) GetState(state string) (*OAuthRequest, error) { } } -func (d *DB) PutState(state string, oauthState OAuthRequest) error { - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - INSERT INTO oauth_state - (state, keybase_username, account_nickname, keybase_conv_id) - VALUES (?, ?, ?, ?) - ON DUPLICATE KEY UPDATE - keybase_username=VALUES(keybase_username), - account_nickname=VALUES(account_nickname), - keybase_conv_id=VALUES(keybase_conv_id) - `, state, oauthState.KeybaseUsername, oauthState.AccountNickname, oauthState.KeybaseConvID) - return err - }) +func (d *DB) PutState(ctx context.Context, state string, oauthState OAuthRequest) error { + _, err := d.ExecContext(ctx, ` + INSERT INTO oauth_state + (state, keybase_username, account_nickname, keybase_conv_id) + VALUES (?, ?, ?, ?) + ON DUPLICATE KEY UPDATE + keybase_username=VALUES(keybase_username), + account_nickname=VALUES(account_nickname), + keybase_conv_id=VALUES(keybase_conv_id) + `, state, oauthState.KeybaseUsername, oauthState.AccountNickname, oauthState.KeybaseConvID) + return err } -func (d *DB) CompleteState(state string) error { - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - UPDATE oauth_state - SET is_complete = true - WHERE state = ? - `, state) - return err - }) +func (d *DB) CompleteState(ctx context.Context, state string) error { + _, err := d.ExecContext(ctx, ` + UPDATE oauth_state + SET is_complete = true + WHERE state = ? + `, state) + return err } // Account -func (d *DB) InsertAccount(account Account) error { - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - INSERT INTO account - (keybase_username, account_nickname, access_token, token_type, refresh_token, expiry, ctime, mtime) - VALUES (?, ?, ?, ?, ?, ?, NOW(), NOW()) - ON DUPLICATE KEY UPDATE - access_token=VALUES(access_token), - refresh_token=VALUES(refresh_token), - expiry=VALUES(expiry), - mtime=VALUES(mtime) - `, account.KeybaseUsername, account.AccountNickname, account.Token.AccessToken, account.Token.TokenType, - account.Token.RefreshToken, account.Token.Expiry) - return err - }) -} - -func (d *DB) GetAccount(keybaseUsername, accountNickname string) (account *Account, err error) { +func (d *DB) InsertAccount(ctx context.Context, account Account) error { + _, err := d.ExecContext(ctx, ` + INSERT INTO account + (keybase_username, account_nickname, access_token, token_type, refresh_token, expiry, ctime, mtime) + VALUES (?, ?, ?, ?, ?, ?, NOW(), NOW()) + ON DUPLICATE KEY UPDATE + access_token=VALUES(access_token), + refresh_token=VALUES(refresh_token), + expiry=VALUES(expiry), + mtime=VALUES(mtime) + `, account.KeybaseUsername, account.AccountNickname, account.Token.AccessToken, account.Token.TokenType, + account.Token.RefreshToken, account.Token.Expiry) + return err +} + +func (d *DB) GetAccount(ctx context.Context, keybaseUsername, accountNickname string) (account *Account, err error) { account = &Account{} var expiry int64 - row := d.QueryRow(` + row := d.QueryRowContext(ctx, ` SELECT keybase_username, account_nickname, access_token, token_type, refresh_token, ROUND(UNIX_TIMESTAMP(expiry)) FROM account WHERE keybase_username = ? AND account_nickname = ? @@ -110,10 +105,10 @@ func (d *DB) GetAccount(keybaseUsername, accountNickname string) (account *Accou } } -func (d *DB) DeleteAccount(keybaseUsername, accountNickname string) error { - return d.RunTxn(func(tx *sql.Tx) error { +func (d *DB) DeleteAccount(ctx context.Context, keybaseUsername, accountNickname string) error { + return d.RunTxnContext(ctx, func(tx *sql.Tx) error { // remove subscriptions first due to foreign key constraint - _, err := tx.Exec(` + _, err := tx.ExecContext(ctx, ` DELETE FROM subscription WHERE keybase_username = ? AND account_nickname = ? `, keybaseUsername, accountNickname) @@ -121,7 +116,7 @@ func (d *DB) DeleteAccount(keybaseUsername, accountNickname string) error { return err } // remove account (and cascading remove associated channels and invites) - _, err = tx.Exec(` + _, err = tx.ExecContext(ctx, ` DELETE FROM account WHERE keybase_username = ? AND account_nickname = ? `, keybaseUsername, accountNickname) @@ -129,16 +124,16 @@ func (d *DB) DeleteAccount(keybaseUsername, accountNickname string) error { }) } -func (d *DB) ExistsAccount(keybaseUsername string, accountNickname string) (exists bool, err error) { - row := d.QueryRow(` +func (d *DB) ExistsAccount(ctx context.Context, keybaseUsername string, accountNickname string) (exists bool, err error) { + row := d.QueryRowContext(ctx, ` SELECT EXISTS(SELECT * FROM account WHERE keybase_username = ? AND account_nickname = ?) `, keybaseUsername, accountNickname) err = row.Scan(&exists) return exists, err } -func (d *DB) GetAccountListForUsername(keybaseUsername string) (accounts []*Account, err error) { - rows, err := d.Query(` +func (d *DB) GetAccountListForUsername(ctx context.Context, keybaseUsername string) (accounts []*Account, err error) { + rows, err := d.QueryContext(ctx, ` SELECT keybase_username, account_nickname, access_token, token_type, refresh_token, ROUND(UNIX_TIMESTAMP(expiry)) FROM account WHERE keybase_username = ? @@ -147,11 +142,7 @@ func (d *DB) GetAccountListForUsername(keybaseUsername string) (accounts []*Acco if err != nil { return nil, err } - defer func() { - if cerr := rows.Close(); cerr != nil { - d.Errorf("GetAccountListForUsername: failed to close rows: %s", cerr) - } - }() + defer rows.Close() for rows.Next() { var account Account var expiry int64 @@ -163,48 +154,42 @@ func (d *DB) GetAccountListForUsername(keybaseUsername string) (accounts []*Acco } accounts = append(accounts, &account) } - return accounts, nil + return accounts, rows.Err() } // Channel -func (d *DB) InsertChannel(account *Account, channel Channel) error { - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - INSERT INTO channel - (channel_id, keybase_username, account_nickname, calendar_id, resource_id, expiry, next_sync_token) - VALUES (?, ?, ?, ?, ?, ?, ?) - `, channel.ChannelID, account.KeybaseUsername, account.AccountNickname, channel.CalendarID, channel.ResourceID, - channel.Expiry, channel.NextSyncToken) - return err - }) -} - -func (d *DB) UpdateChannel(oldChannelID, newChannelID string, resourceID string, expiry time.Time) error { - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - UPDATE channel - SET channel_id = ?, resource_id = ?, expiry = ? - WHERE channel_id = ? - `, newChannelID, resourceID, expiry, oldChannelID) - return err - }) +func (d *DB) InsertChannel(ctx context.Context, account *Account, channel Channel) error { + _, err := d.ExecContext(ctx, ` + INSERT INTO channel + (channel_id, keybase_username, account_nickname, calendar_id, resource_id, expiry, next_sync_token) + VALUES (?, ?, ?, ?, ?, ?, ?) + `, channel.ChannelID, account.KeybaseUsername, account.AccountNickname, channel.CalendarID, channel.ResourceID, + channel.Expiry, channel.NextSyncToken) + return err +} + +func (d *DB) UpdateChannel(ctx context.Context, oldChannelID, newChannelID string, resourceID string, expiry time.Time) error { + _, err := d.ExecContext(ctx, ` + UPDATE channel + SET channel_id = ?, resource_id = ?, expiry = ? + WHERE channel_id = ? + `, newChannelID, resourceID, expiry, oldChannelID) + return err } -func (d *DB) UpdateChannelNextSyncToken(channelID, nextSyncToken string) error { - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - UPDATE channel - SET next_sync_token = ? - WHERE channel_id = ? - `, nextSyncToken, channelID) - return err - }) +func (d *DB) UpdateChannelNextSyncToken(ctx context.Context, channelID, nextSyncToken string) error { + _, err := d.ExecContext(ctx, ` + UPDATE channel + SET next_sync_token = ? + WHERE channel_id = ? + `, nextSyncToken, channelID) + return err } -func (d *DB) GetChannel(account *Account, calendarID string) (channel *Channel, err error) { +func (d *DB) GetChannel(ctx context.Context, account *Account, calendarID string) (channel *Channel, err error) { channel = &Channel{} var expiry int64 - row := d.QueryRow(` + row := d.QueryRowContext(ctx, ` SELECT channel_id, calendar_id, resource_id, ROUND(UNIX_TIMESTAMP(channel.expiry)), next_sync_token FROM channel WHERE keybase_username = ? AND account_nickname = ? AND calendar_id = ? @@ -221,12 +206,12 @@ func (d *DB) GetChannel(account *Account, calendarID string) (channel *Channel, } } -func (d *DB) GetChannelAndAccountByID(channelID string) (channel *Channel, account *Account, err error) { +func (d *DB) GetChannelAndAccountByID(ctx context.Context, channelID string) (channel *Channel, account *Account, err error) { channel = &Channel{} account = &Account{} var channelExpiry int64 var tokenExpiry int64 - row := d.QueryRow(` + row := d.QueryRowContext(ctx, ` SELECT channel_id, calendar_id, resource_id, ROUND(UNIX_TIMESTAMP(channel.expiry)), next_sync_token, account.keybase_username, account.account_nickname, access_token, token_type, refresh_token, ROUND(UNIX_TIMESTAMP(account.expiry)) @@ -249,8 +234,8 @@ func (d *DB) GetChannelAndAccountByID(channelID string) (channel *Channel, accou } } -func (d *DB) GetChannelListByAccount(account *Account) (channels []*Channel, err error) { - rows, err := d.Query(` +func (d *DB) GetChannelListByAccount(ctx context.Context, account *Account) (channels []*Channel, err error) { + rows, err := d.QueryContext(ctx, ` SELECT channel_id, calendar_id, resource_id, ROUND(UNIX_TIMESTAMP(expiry)), next_sync_token FROM channel WHERE keybase_username = ? AND account_nickname = ? @@ -258,9 +243,7 @@ func (d *DB) GetChannelListByAccount(account *Account) (channels []*Channel, err if err != nil { return nil, err } - defer func() { - _ = rows.Close() - }() + defer rows.Close() for rows.Next() { var channel Channel var expiry int64 @@ -271,12 +254,12 @@ func (d *DB) GetChannelListByAccount(account *Account) (channels []*Channel, err channel.Expiry = time.Unix(expiry, 0) channels = append(channels, &channel) } - return channels, nil + return channels, rows.Err() } -func (d *DB) GetExpiringChannelAndAccountList() (pairs []*ChannelAndAccount, err error) { +func (d *DB) GetExpiringChannelAndAccountList(ctx context.Context) (pairs []*ChannelAndAccount, err error) { // query all channels that are expiring in less than a day - rows, err := d.Query(` + rows, err := d.QueryContext(ctx, ` SELECT channel_id, calendar_id, resource_id, ROUND(UNIX_TIMESTAMP(channel.expiry)), next_sync_token, account.keybase_username, account.account_nickname, access_token, token_type, refresh_token, ROUND(UNIX_TIMESTAMP(account.expiry)) @@ -287,9 +270,7 @@ func (d *DB) GetExpiringChannelAndAccountList() (pairs []*ChannelAndAccount, err if err != nil { return nil, err } - defer func() { - _ = rows.Close() - }() + defer rows.Close() for rows.Next() { var pair ChannelAndAccount var channelExpiry int64 @@ -305,44 +286,40 @@ func (d *DB) GetExpiringChannelAndAccountList() (pairs []*ChannelAndAccount, err pair.Account.Token.Expiry = time.Unix(accountExpiry, 0) pairs = append(pairs, &pair) } - return pairs, nil + return pairs, rows.Err() } -func (d *DB) ExistsChannelByAccountAndCalendar(account *Account, calendarID string) (exists bool, err error) { - row := d.QueryRow(` +func (d *DB) ExistsChannelByAccountAndCalendar(ctx context.Context, account *Account, calendarID string) (exists bool, err error) { + row := d.QueryRowContext(ctx, ` SELECT EXISTS(SELECT * FROM channel WHERE keybase_username = ? AND account_nickname = ? AND calendar_id = ?) `, account.KeybaseUsername, account.AccountNickname, calendarID) err = row.Scan(&exists) return exists, err } -func (d *DB) DeleteChannelByChannelID(channelID string) error { - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - DELETE FROM channel - WHERE channel_id = ? - `, channelID) - return err - }) +func (d *DB) DeleteChannelByChannelID(ctx context.Context, channelID string) error { + _, err := d.ExecContext(ctx, ` + DELETE FROM channel + WHERE channel_id = ? + `, channelID) + return err } // Subscription -func (d *DB) InsertSubscription(account *Account, subscription Subscription) error { +func (d *DB) InsertSubscription(ctx context.Context, account *Account, subscription Subscription) error { minutesBefore := GetMinutesFromDuration(subscription.DurationBefore) - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - INSERT INTO subscription - (keybase_username, account_nickname, calendar_id, keybase_conv_id, minutes_before, type) - VALUES (?, ?, ?, ?, ?, ?) - `, account.KeybaseUsername, account.AccountNickname, subscription.CalendarID, - subscription.KeybaseConvID, minutesBefore, subscription.Type) - return err - }) + _, err := d.ExecContext(ctx, ` + INSERT INTO subscription + (keybase_username, account_nickname, calendar_id, keybase_conv_id, minutes_before, type) + VALUES (?, ?, ?, ?, ?, ?) + `, account.KeybaseUsername, account.AccountNickname, subscription.CalendarID, + subscription.KeybaseConvID, minutesBefore, subscription.Type) + return err } -func (d *DB) ExistsSubscription(account *Account, subscription Subscription) (exists bool, err error) { +func (d *DB) ExistsSubscription(ctx context.Context, account *Account, subscription Subscription) (exists bool, err error) { minutesBefore := GetMinutesFromDuration(subscription.DurationBefore) - row := d.QueryRow(` + row := d.QueryRowContext(ctx, ` SELECT EXISTS( SELECT * FROM subscription @@ -355,16 +332,16 @@ func (d *DB) ExistsSubscription(account *Account, subscription Subscription) (ex return exists, err } -func (d *DB) CountSubscriptionsByAccountAndCalender(account *Account, calendarID string) (count int, err error) { - row := d.QueryRow(` +func (d *DB) CountSubscriptionsByAccountAndCalender(ctx context.Context, account *Account, calendarID string) (count int, err error) { + row := d.QueryRowContext(ctx, ` SELECT COUNT(*) FROM subscription WHERE keybase_username = ? AND account_nickname = ? AND calendar_id = ? `, account.KeybaseUsername, account.AccountNickname, calendarID) err = row.Scan(&count) return count, err } -func (d *DB) GetReminderSubscriptionAndAccountPairs() (pairs []*SubscriptionAndAccount, err error) { - row, err := d.Query(` +func (d *DB) GetReminderSubscriptionAndAccountPairs(ctx context.Context) (pairs []*SubscriptionAndAccount, err error) { + rows, err := d.QueryContext(ctx, ` SELECT calendar_id, keybase_conv_id, minutes_before, type, -- subscription account.keybase_username, account.account_nickname, access_token, token_type, refresh_token, ROUND(UNIX_TIMESTAMP(expiry)) -- account @@ -375,16 +352,12 @@ func (d *DB) GetReminderSubscriptionAndAccountPairs() (pairs []*SubscriptionAndA if err != nil { return nil, err } - defer func() { - if cerr := row.Close(); cerr != nil { - d.Errorf("GetReminderSubscriptionAndAccountPairs: failed to close rows: %s", cerr) - } - }() - for row.Next() { + defer rows.Close() + for rows.Next() { var pair SubscriptionAndAccount var subscriptionMinutesBefore int var tokenExpiry int64 - err = row.Scan(&pair.Subscription.CalendarID, &pair.Subscription.KeybaseConvID, &subscriptionMinutesBefore, &pair.Subscription.Type, + err = rows.Scan(&pair.Subscription.CalendarID, &pair.Subscription.KeybaseConvID, &subscriptionMinutesBefore, &pair.Subscription.Type, &pair.Account.KeybaseUsername, &pair.Account.AccountNickname, &pair.Account.Token.AccessToken, &pair.Account.Token.TokenType, &pair.Account.Token.RefreshToken, &tokenExpiry) if err != nil { @@ -394,15 +367,16 @@ func (d *DB) GetReminderSubscriptionAndAccountPairs() (pairs []*SubscriptionAndA pair.Account.Token.Expiry = time.Unix(tokenExpiry, 0) pairs = append(pairs, &pair) } - return pairs, nil + return pairs, rows.Err() } func (d *DB) GetReminderSubscriptionsByAccountAndCalendar( + ctx context.Context, account *Account, calendarID string, subscriptionType SubscriptionType, ) (subscriptions []*Subscription, err error) { - row, err := d.Query(` + rows, err := d.QueryContext(ctx, ` SELECT calendar_id, keybase_conv_id, minutes_before, type FROM subscription WHERE keybase_username = ? AND account_nickname = ? AND calendar_id = ? AND type = ? @@ -410,26 +384,22 @@ func (d *DB) GetReminderSubscriptionsByAccountAndCalendar( if err != nil { return nil, err } - defer func() { - if cerr := row.Close(); cerr != nil { - d.Errorf("GetReminderSubscriptionsByAccountAndCalendar: failed to close rows: %s", cerr) - } - }() - for row.Next() { + defer rows.Close() + for rows.Next() { var subscription Subscription var minutesBefore int - err = row.Scan(&subscription.CalendarID, &subscription.KeybaseConvID, &minutesBefore, &subscription.Type) + err = rows.Scan(&subscription.CalendarID, &subscription.KeybaseConvID, &minutesBefore, &subscription.Type) if err != nil { return nil, err } subscription.DurationBefore = GetDurationFromMinutes(minutesBefore) subscriptions = append(subscriptions, &subscription) } - return subscriptions, nil + return subscriptions, rows.Err() } -func (d *DB) GetSubscriptions(account *Account, calendarID string, keybaseConvID chat1.ConvIDStr) (subscriptions []*Subscription, err error) { - rows, err := d.Query(` +func (d *DB) GetSubscriptions(ctx context.Context, account *Account, calendarID string, keybaseConvID chat1.ConvIDStr) (subscriptions []*Subscription, err error) { + rows, err := d.QueryContext(ctx, ` SELECT calendar_id, keybase_conv_id, minutes_before, type FROM subscription WHERE keybase_username = ? AND account_nickname = ? AND calendar_id = ? AND keybase_conv_id = ? @@ -437,11 +407,7 @@ func (d *DB) GetSubscriptions(account *Account, calendarID string, keybaseConvID if err != nil { return nil, err } - defer func() { - if cerr := rows.Close(); cerr != nil { - d.Errorf("GetSubscriptions: failed to close rows: %s", cerr) - } - }() + defer rows.Close() for rows.Next() { var subscription Subscription var minutesBefore int @@ -452,38 +418,34 @@ func (d *DB) GetSubscriptions(account *Account, calendarID string, keybaseConvID subscription.DurationBefore = GetDurationFromMinutes(minutesBefore) subscriptions = append(subscriptions, &subscription) } - return subscriptions, nil + return subscriptions, rows.Err() } -func (d *DB) DeleteSubscription(account *Account, subscription Subscription) error { +func (d *DB) DeleteSubscription(ctx context.Context, account *Account, subscription Subscription) error { minutesBefore := GetMinutesFromDuration(subscription.DurationBefore) - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - DELETE FROM subscription - WHERE keybase_username = ? AND account_nickname = ? AND calendar_id = ? AND keybase_conv_id = ? AND - minutes_before = ? AND type = ? - `, account.KeybaseUsername, account.AccountNickname, subscription.CalendarID, - subscription.KeybaseConvID, minutesBefore, subscription.Type) - return err - }) + _, err := d.ExecContext(ctx, ` + DELETE FROM subscription + WHERE keybase_username = ? AND account_nickname = ? AND calendar_id = ? AND keybase_conv_id = ? AND + minutes_before = ? AND type = ? + `, account.KeybaseUsername, account.AccountNickname, subscription.CalendarID, + subscription.KeybaseConvID, minutesBefore, subscription.Type) + return err } // Invite -func (d *DB) InsertInvite(account *Account, invite Invite) error { - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - INSERT INTO invite - (keybase_username, account_nickname, calendar_id, event_id, message_id) - VALUES (?, ?, ?, ?, ?) - ON DUPLICATE KEY UPDATE - message_id=message_id -- message id stays the same - `, account.KeybaseUsername, account.AccountNickname, invite.CalendarID, invite.EventID, invite.MessageID) - return err - }) -} - -func (d *DB) ExistsInvite(account *Account, calendarID, eventID string) (exists bool, err error) { - row := d.QueryRow(` +func (d *DB) InsertInvite(ctx context.Context, account *Account, invite Invite) error { + _, err := d.ExecContext(ctx, ` + INSERT INTO invite + (keybase_username, account_nickname, calendar_id, event_id, message_id) + VALUES (?, ?, ?, ?, ?) + ON DUPLICATE KEY UPDATE + message_id=message_id -- message id stays the same + `, account.KeybaseUsername, account.AccountNickname, invite.CalendarID, invite.EventID, invite.MessageID) + return err +} + +func (d *DB) ExistsInvite(ctx context.Context, account *Account, calendarID, eventID string) (exists bool, err error) { + row := d.QueryRowContext(ctx, ` SELECT EXISTS( SELECT * FROM invite WHERE keybase_username = ? AND account_nickname = ? AND calendar_id = ? AND event_id = ? ) @@ -492,11 +454,11 @@ func (d *DB) ExistsInvite(account *Account, calendarID, eventID string) (exists return exists, err } -func (d *DB) GetInviteAndAccountByUserMessage(keybaseUsername string, messageID chat1.MessageID) (invite *Invite, account *Account, err error) { +func (d *DB) GetInviteAndAccountByUserMessage(ctx context.Context, keybaseUsername string, messageID chat1.MessageID) (invite *Invite, account *Account, err error) { invite = &Invite{} account = &Account{} var expiry int64 - row := d.QueryRow(` + row := d.QueryRowContext(ctx, ` SELECT calendar_id, event_id, message_id, account.keybase_username, account.account_nickname, access_token, token_type, refresh_token, ROUND(UNIX_TIMESTAMP(expiry)) @@ -519,26 +481,24 @@ func (d *DB) GetInviteAndAccountByUserMessage(keybaseUsername string, messageID } // Daily Schedule Subscription -func (d *DB) InsertDailyScheduleSubscription(account *Account, subscription DailyScheduleSubscription) error { +func (d *DB) InsertDailyScheduleSubscription(ctx context.Context, account *Account, subscription DailyScheduleSubscription) error { notificationTime := GetTimeStringFromDuration(subscription.NotificationTime) - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - INSERT INTO daily_schedule_subscription - (keybase_username, account_nickname, calendar_id, keybase_conv_id, timezone, days_to_send, schedule_to_send, notification_time) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - ON DUPLICATE KEY UPDATE - timezone=VALUES(timezone), - days_to_send=VALUES(days_to_send), - schedule_to_send=VALUES(schedule_to_send), - notification_time=VALUES(notification_time) - `, account.KeybaseUsername, account.AccountNickname, subscription.CalendarID, subscription.KeybaseConvID, - subscription.Timezone.String(), subscription.DaysToSend, subscription.ScheduleToSend, notificationTime) - return err - }) + _, err := d.ExecContext(ctx, ` + INSERT INTO daily_schedule_subscription + (keybase_username, account_nickname, calendar_id, keybase_conv_id, timezone, days_to_send, schedule_to_send, notification_time) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + ON DUPLICATE KEY UPDATE + timezone=VALUES(timezone), + days_to_send=VALUES(days_to_send), + schedule_to_send=VALUES(schedule_to_send), + notification_time=VALUES(notification_time) + `, account.KeybaseUsername, account.AccountNickname, subscription.CalendarID, subscription.KeybaseConvID, + subscription.Timezone.String(), subscription.DaysToSend, subscription.ScheduleToSend, notificationTime) + return err } -func (d *DB) GetAggregatedDailyScheduleSubscription(scheduleToSend ScheduleToSendType) (subscriptions []*AggregatedDailyScheduleSubscription, err error) { - row, err := d.Query(` +func (d *DB) GetAggregatedDailyScheduleSubscription(ctx context.Context, scheduleToSend ScheduleToSendType) (subscriptions []*AggregatedDailyScheduleSubscription, err error) { + rows, err := d.QueryContext(ctx, ` SELECT GROUP_CONCAT(calendar_id) as calendar_ids, keybase_conv_id, timezone, days_to_send, schedule_to_send, notification_time, account.keybase_username, account.account_nickname, access_token, token_type, refresh_token, ROUND(UNIX_TIMESTAMP(expiry)) @@ -550,18 +510,14 @@ func (d *DB) GetAggregatedDailyScheduleSubscription(scheduleToSend ScheduleToSen if err != nil { return nil, err } - defer func() { - if cerr := row.Close(); cerr != nil { - d.Errorf("GetAggregatedDailyScheduleSubscription: failed to close rows: %s", cerr) - } - }() - for row.Next() { + defer rows.Close() + for rows.Next() { var pair AggregatedDailyScheduleSubscription var concatCalendarIDs string var timezone string var notificationTime string var tokenExpiry int64 - err = row.Scan(&concatCalendarIDs, &pair.KeybaseConvID, &timezone, &pair.DaysToSend, &pair.ScheduleToSend, ¬ificationTime, + err = rows.Scan(&concatCalendarIDs, &pair.KeybaseConvID, &timezone, &pair.DaysToSend, &pair.ScheduleToSend, ¬ificationTime, &pair.Account.KeybaseUsername, &pair.Account.AccountNickname, &pair.Account.Token.AccessToken, &pair.Account.Token.TokenType, &pair.Account.Token.RefreshToken, &tokenExpiry) if err != nil { @@ -581,10 +537,11 @@ func (d *DB) GetAggregatedDailyScheduleSubscription(scheduleToSend ScheduleToSen pair.Account.Token.Expiry = time.Unix(tokenExpiry, 0) subscriptions = append(subscriptions, &pair) } - return subscriptions, nil + return subscriptions, rows.Err() } func (d *DB) GetDailyScheduleSubscription( + ctx context.Context, account *Account, calendarID string, keybaseConvID chat1.ConvIDStr, @@ -592,7 +549,7 @@ func (d *DB) GetDailyScheduleSubscription( subscription = &DailyScheduleSubscription{} var timezone string var notificationTime string - row := d.QueryRow(` + row := d.QueryRowContext(ctx, ` SELECT calendar_id, keybase_conv_id, timezone, days_to_send, schedule_to_send, notification_time FROM daily_schedule_subscription WHERE keybase_username = ? AND account_nickname = ? AND calendar_id = ? AND keybase_conv_id = ? @@ -617,12 +574,10 @@ func (d *DB) GetDailyScheduleSubscription( } } -func (d *DB) DeleteDailyScheduleSubscription(account *Account, calendarID string, keybaseConvID chat1.ConvIDStr) error { - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - DELETE FROM daily_schedule_subscription - WHERE keybase_username = ? AND account_nickname = ? AND calendar_id = ? AND keybase_conv_id = ? - `, account.KeybaseUsername, account.AccountNickname, calendarID, keybaseConvID) - return err - }) +func (d *DB) DeleteDailyScheduleSubscription(ctx context.Context, account *Account, calendarID string, keybaseConvID chat1.ConvIDStr) error { + _, err := d.ExecContext(ctx, ` + DELETE FROM daily_schedule_subscription + WHERE keybase_username = ? AND account_nickname = ? AND calendar_id = ? AND keybase_conv_id = ? + `, account.KeybaseUsername, account.AccountNickname, calendarID, keybaseConvID) + return err } diff --git a/gcalbot/gcalbot/handler.go b/gcalbot/gcalbot/handler.go index ad1a2acd..8f7c1ed5 100644 --- a/gcalbot/gcalbot/handler.go +++ b/gcalbot/gcalbot/handler.go @@ -1,6 +1,7 @@ package gcalbot import ( + "context" "crypto/hmac" "crypto/sha256" "encoding/hex" @@ -53,14 +54,14 @@ func NewHandler( } } -func (h *Handler) HandleNewConv(conv chat1.ConvSummary) error { +func (h *Handler) HandleNewConv(_ context.Context, conv chat1.ConvSummary) error { welcomeMsg := "Hello! I can get you set up with Google Calendar anytime, just send me `!gcal accounts connect `." return base.HandleNewTeam(h.stats, h.DebugOutput, h.kbc, conv, welcomeMsg) } -func (h *Handler) HandleCommand(msg chat1.MsgSummary) error { +func (h *Handler) HandleCommand(ctx context.Context, msg chat1.MsgSummary) error { if msg.Content.Reaction != nil && msg.Sender.Username != h.kbc.GetUsername() { - return h.handleReaction(msg) + return h.handleReaction(ctx, msg) } if msg.Content.Text == nil { @@ -84,17 +85,17 @@ func (h *Handler) HandleCommand(msg chat1.MsgSummary) error { switch { case strings.HasPrefix(cmd, "!gcal accounts list"): h.stats.Count("accounts list") - return h.handleAccountsList(msg) + return h.handleAccountsList(ctx, msg) case strings.HasPrefix(cmd, "!gcal accounts connect"): h.stats.Count("accounts connect") - return h.handleAccountsConnect(msg, tokens[3:]) + return h.handleAccountsConnect(ctx, msg, tokens[3:]) case strings.HasPrefix(cmd, "!gcal accounts disconnect"): h.stats.Count("accounts disconnect") - return h.handleAccountsDisconnect(msg, tokens[3:]) + return h.handleAccountsDisconnect(ctx, msg, tokens[3:]) case strings.HasPrefix(cmd, "!gcal calendars list"): h.stats.Count("calendars list") - return h.handleCalendarsList(msg, tokens[3:]) + return h.handleCalendarsList(ctx, msg, tokens[3:]) case strings.HasPrefix(cmd, "!gcal configure"): h.stats.Count("configure") @@ -106,16 +107,16 @@ func (h *Handler) HandleCommand(msg chat1.MsgSummary) error { } } -func (h *Handler) handleReaction(msg chat1.MsgSummary) error { +func (h *Handler) handleReaction(ctx context.Context, msg chat1.MsgSummary) error { username := msg.Sender.Username messageID := msg.Content.Reaction.MessageID reaction := msg.Content.Reaction.Body - invite, account, err := h.db.GetInviteAndAccountByUserMessage(username, messageID) + invite, account, err := h.db.GetInviteAndAccountByUserMessage(ctx, username, messageID) if err != nil { return err } else if invite != nil && account != nil { - err = h.updateEventResponseStatus(invite, account, InviteReaction(reaction)) + err = h.updateEventResponseStatus(ctx, invite, account, InviteReaction(reaction)) if err != nil { return fmt.Errorf("error updating event response status: %s", err) } diff --git a/gcalbot/gcalbot/http.go b/gcalbot/gcalbot/http.go index 32c3c5a8..c17e8c0b 100644 --- a/gcalbot/gcalbot/http.go +++ b/gcalbot/gcalbot/http.go @@ -2,6 +2,7 @@ package gcalbot import ( "bytes" + "context" "crypto/hmac" "encoding/base64" "fmt" @@ -156,7 +157,11 @@ func (h *HTTPSrv) configHandler(w http.ResponseWriter, r *http.Request) { dsTime = GetDurationFromMinutes(dsTimeMinutes) } - accounts, err := h.db.GetAccountListForUsername(keybaseUsername) + // WithoutCancel: the browser submits the config form and may close the + // connection before DB writes complete; work must finish regardless. + ctx := context.WithoutCancel(r.Context()) + + accounts, err := h.db.GetAccountListForUsername(ctx, keybaseUsername) if err != nil { return } @@ -197,7 +202,7 @@ func (h *HTTPSrv) configHandler(w http.ResponseWriter, r *http.Request) { return } - srv, err := GetCalendarService(selectedAccount, h.oauth, h.db) + srv, err := GetCalendarService(ctx, selectedAccount, h.oauth, h.db) if err != nil { return } @@ -225,7 +230,7 @@ func (h *HTTPSrv) configHandler(w http.ResponseWriter, r *http.Request) { page.CalendarID = calendarID var subscriptions []*Subscription - subscriptions, err = h.db.GetSubscriptions(selectedAccount, calendarID, keybaseConvID) + subscriptions, err = h.db.GetSubscriptions(ctx, selectedAccount, calendarID, keybaseConvID) if err != nil { return } @@ -238,7 +243,7 @@ func (h *HTTPSrv) configHandler(w http.ResponseWriter, r *http.Request) { } } - dsSubscription, dsSubExists, err := h.db.GetDailyScheduleSubscription(selectedAccount, calendarID, keybaseConvID) + dsSubscription, dsSubExists, err := h.db.GetDailyScheduleSubscription(ctx, selectedAccount, calendarID, keybaseConvID) if err != nil { return } @@ -319,7 +324,7 @@ func (h *HTTPSrv) configHandler(w http.ResponseWriter, r *http.Request) { if (!page.Invite && page.Reminder == "") && (inviteInput != "" || reminderInput != "") { // this update must open a new webhook channel, do that now and if it errors, fail early - err = h.handler.createEventChannel(selectedAccount, calendarID) + err = h.handler.createEventChannel(ctx, selectedAccount, calendarID) switch typedErr := err.(type) { case nil: case *googleapi.Error: @@ -336,7 +341,7 @@ func (h *HTTPSrv) configHandler(w http.ResponseWriter, r *http.Request) { } if dsEnabled { - err = h.db.InsertDailyScheduleSubscription(selectedAccount, DailyScheduleSubscription{ + err = h.db.InsertDailyScheduleSubscription(ctx, selectedAccount, DailyScheduleSubscription{ CalendarID: calendarID, KeybaseConvID: keybaseConvID, Timezone: dsTimezone, @@ -353,7 +358,7 @@ func (h *HTTPSrv) configHandler(w http.ResponseWriter, r *http.Request) { page.DSTime = strconv.Itoa(GetMinutesFromDuration(dsTime)) } else if !dsEnabled && dsSubExists { page.DSEnabled = false - err = h.db.DeleteDailyScheduleSubscription(selectedAccount, calendarID, keybaseConvID) + err = h.db.DeleteDailyScheduleSubscription(ctx, selectedAccount, calendarID, keybaseConvID) if err != nil { return } @@ -371,14 +376,14 @@ func (h *HTTPSrv) configHandler(w http.ResponseWriter, r *http.Request) { if page.Invite && !invite { // remove invite subscription h.Stats.Count("config - update - invite - remove") - err = h.handler.removeSubscription(selectedAccount, inviteSubscription) + err = h.handler.removeSubscription(ctx, selectedAccount, inviteSubscription) if err != nil { return } } else if !page.Invite && invite { // create invite subscription h.Stats.Count("config - update - invite - create") - if err = h.handler.createSubscription(selectedAccount, inviteSubscription); err != nil { + if err = h.handler.createSubscription(ctx, selectedAccount, inviteSubscription); err != nil { return } } @@ -396,7 +401,7 @@ func (h *HTTPSrv) configHandler(w http.ResponseWriter, r *http.Request) { return } - err = h.handler.removeSubscription(selectedAccount, Subscription{ + err = h.handler.removeSubscription(ctx, selectedAccount, Subscription{ CalendarID: calendarID, KeybaseConvID: keybaseConvID, DurationBefore: GetDurationFromMinutes(oldMinutesBefore), @@ -415,7 +420,7 @@ func (h *HTTPSrv) configHandler(w http.ResponseWriter, r *http.Request) { return } - err = h.handler.createSubscription(selectedAccount, Subscription{ + err = h.handler.createSubscription(ctx, selectedAccount, Subscription{ CalendarID: calendarID, KeybaseConvID: keybaseConvID, DurationBefore: GetDurationFromMinutes(newMinutesBefore), diff --git a/gcalbot/gcalbot/invite.go b/gcalbot/gcalbot/invite.go index 3d45ab61..ce951389 100644 --- a/gcalbot/gcalbot/invite.go +++ b/gcalbot/gcalbot/invite.go @@ -28,7 +28,7 @@ const ( ResponseStatusAccepted ResponseStatus = "accepted" ) -func (h *Handler) sendEventInvite(account *Account, channel *Channel, event *calendar.Event) error { +func (h *Handler) sendEventInvite(ctx context.Context, account *Account, channel *Channel, event *calendar.Event) error { h.stats.Count("sendEventInvite") message := `You've been invited to %s: %s @@ -41,7 +41,7 @@ Awaiting your response. *Are you going?*` eventType = "a recurring event" } - srv, err := GetCalendarService(account, h.oauth, h.db) + srv, err := GetCalendarService(ctx, account, h.oauth, h.db) if err != nil { return err } @@ -67,7 +67,7 @@ Awaiting your response. *Are you going?*` return err } - err = h.db.InsertInvite(account, Invite{ + err = h.db.InsertInvite(ctx, account, Invite{ CalendarID: invitedCalendar.Id, EventID: event.Id, MessageID: *sendRes.Result.MessageID, @@ -87,7 +87,7 @@ Awaiting your response. *Are you going?*` return nil } -func (h *Handler) updateEventResponseStatus(invite *Invite, account *Account, reaction InviteReaction) error { +func (h *Handler) updateEventResponseStatus(ctx context.Context, invite *Invite, account *Account, reaction InviteReaction) error { h.stats.Count("updateEventResponseStatus") var responseStatus ResponseStatus @@ -107,7 +107,7 @@ func (h *Handler) updateEventResponseStatus(invite *Invite, account *Account, re return nil } - srv, err := GetCalendarService(account, h.oauth, h.db) + srv, err := GetCalendarService(ctx, account, h.oauth, h.db) if err != nil { return err } @@ -176,6 +176,7 @@ func (h *Handler) syncAllInvites(account *Account, srv *calendar.Service, channe var nextSyncToken string var events []*calendar.Event + // context.Background() because syncAllInvites is a background goroutine that outlives the request context err := srv.Events.List(calendarID). Pages(context.Background(), func(page *calendar.Events) error { if page.NextPageToken == "" { @@ -235,7 +236,7 @@ func (h *Handler) syncAllInvites(account *Account, srv *calendar.Service, channe for _, attendee := range event.Attendees { responseStatus := ResponseStatus(attendee.ResponseStatus) if attendee.Self && !attendee.Organizer && responseStatus == ResponseStatusNeedsAction { - err = h.db.InsertInvite(account, Invite{ + err = h.db.InsertInvite(context.Background(), account, Invite{ CalendarID: calendarID, EventID: event.Id, }) @@ -247,7 +248,7 @@ func (h *Handler) syncAllInvites(account *Account, srv *calendar.Service, channe } } - err = h.db.UpdateChannelNextSyncToken(channelID, nextSyncToken) + err = h.db.UpdateChannelNextSyncToken(context.Background(), channelID, nextSyncToken) if err != nil { h.Errorf("unable to update sync token: %v", err) return diff --git a/gcalbot/gcalbot/oauth.go b/gcalbot/gcalbot/oauth.go index a4f7a671..4f1069ae 100644 --- a/gcalbot/gcalbot/oauth.go +++ b/gcalbot/gcalbot/oauth.go @@ -30,8 +30,11 @@ func (h *HTTPSrv) oauthHandler(w http.ResponseWriter, r *http.Request) { query := r.URL.Query() state := query.Get("state") + // WithoutCancel: the browser may close after the redirect; DB writes and + // Keybase messaging triggered by OAuth completion must finish regardless. + ctx := context.WithoutCancel(r.Context()) - req, err := h.db.GetState(state) + req, err := h.db.GetState(ctx, state) if err != nil { err = fmt.Errorf("could not get state %q: %v", state, err) return @@ -58,7 +61,7 @@ func (h *HTTPSrv) oauthHandler(w http.ResponseWriter, r *http.Request) { h.showOAuthError(w) return } - token, err := h.oauth.Exchange(context.TODO(), code) + token, err := h.oauth.Exchange(ctx, code) if err != nil { return } @@ -68,11 +71,11 @@ func (h *HTTPSrv) oauthHandler(w http.ResponseWriter, r *http.Request) { AccountNickname: req.AccountNickname, Token: *token, } - err = h.db.InsertAccount(account) + err = h.db.InsertAccount(ctx, account) if err != nil { return } - if err = h.db.CompleteState(state); err != nil { + if err = h.db.CompleteState(ctx, state); err != nil { return } @@ -84,7 +87,7 @@ func (h *HTTPSrv) oauthHandler(w http.ResponseWriter, r *http.Request) { // if account was created in a 1on1 conv, create default subscription to invites & 5 minute reminder for primary calendar if base.IsDirectPrivateMessage(h.kbc.GetUsername(), req.KeybaseUsername, conv.Channel) { var srv *calendar.Service - srv, err = GetCalendarService(&account, h.oauth, h.db) + srv, err = GetCalendarService(ctx, &account, h.oauth, h.db) if err != nil { return } @@ -95,7 +98,7 @@ func (h *HTTPSrv) oauthHandler(w http.ResponseWriter, r *http.Request) { return } - err = h.handler.createSubscription(&account, Subscription{ + err = h.handler.createSubscription(ctx, &account, Subscription{ CalendarID: primaryCalendar.Id, KeybaseConvID: req.KeybaseConvID, Type: SubscriptionTypeInvite, @@ -103,7 +106,7 @@ func (h *HTTPSrv) oauthHandler(w http.ResponseWriter, r *http.Request) { if err != nil { return } - err = h.handler.createSubscription(&account, Subscription{ + err = h.handler.createSubscription(ctx, &account, Subscription{ CalendarID: primaryCalendar.Id, KeybaseConvID: req.KeybaseConvID, DurationBefore: GetDurationFromMinutes(5), @@ -134,13 +137,13 @@ func (h *HTTPSrv) showOAuthError(w http.ResponseWriter) { } } -func (h *Handler) requestOAuth(msg chat1.MsgSummary, accountNickname string) error { +func (h *Handler) requestOAuth(ctx context.Context, msg chat1.MsgSummary, accountNickname string) error { state, err := base.MakeRequestID() if err != nil { return err } - err = h.db.PutState(state, OAuthRequest{ + err = h.db.PutState(ctx, state, OAuthRequest{ KeybaseUsername: msg.Sender.Username, AccountNickname: accountNickname, KeybaseConvID: msg.ConvID, diff --git a/gcalbot/gcalbot/reminderscheduler/sync.go b/gcalbot/gcalbot/reminderscheduler/sync.go index 3afb3522..10d5e638 100644 --- a/gcalbot/gcalbot/reminderscheduler/sync.go +++ b/gcalbot/gcalbot/reminderscheduler/sync.go @@ -19,7 +19,7 @@ func (r *ReminderScheduler) eventSyncLoop(shutdownCh chan struct{}) error { }() eventSync := func(syncMinute time.Time) { - pairs, err := r.db.GetReminderSubscriptionAndAccountPairs() + pairs, err := r.db.GetReminderSubscriptionAndAccountPairs(context.Background()) r.stats.ValueInt("eventSyncLoop - subscriptions - count", len(pairs)) if err != nil { r.Errorf("error getting reminder subscriptions to sync: %s", err) @@ -47,7 +47,7 @@ func (r *ReminderScheduler) eventSyncLoop(shutdownCh chan struct{}) error { } func (r *ReminderScheduler) syncEvents(account *gcalbot.Account, subscription *gcalbot.Subscription) { - srv, err := gcalbot.GetCalendarService(account, r.oauth, r.db) + srv, err := gcalbot.GetCalendarService(context.Background(), account, r.oauth, r.db) switch err.(type) { case nil: case *oauth2.RetrieveError: @@ -124,7 +124,7 @@ func (r *ReminderScheduler) UpdateOrCreateReminderEvent( } }) - srv, err := gcalbot.GetCalendarService(account, r.oauth, r.db) + srv, err := gcalbot.GetCalendarService(context.Background(), account, r.oauth, r.db) switch err.(type) { case nil: case *oauth2.RetrieveError: diff --git a/gcalbot/gcalbot/schedulescheduler/send.go b/gcalbot/gcalbot/schedulescheduler/send.go index 4c9ae40e..b0ace4e2 100644 --- a/gcalbot/gcalbot/schedulescheduler/send.go +++ b/gcalbot/gcalbot/schedulescheduler/send.go @@ -50,13 +50,13 @@ func (s *ScheduleScheduler) sendDailySchedulesForMinute(sendMinute time.Time, sh var subscriptions []*gcalbot.AggregatedDailyScheduleSubscription - todaySubscriptions, err := s.db.GetAggregatedDailyScheduleSubscription(gcalbot.ScheduleToSendToday) + todaySubscriptions, err := s.db.GetAggregatedDailyScheduleSubscription(context.Background(), gcalbot.ScheduleToSendToday) if err != nil { s.Errorf("error getting daily schedule subscriptions to sync: %s", err) } subscriptions = todaySubscriptions - tomorrowSubscriptions, err := s.db.GetAggregatedDailyScheduleSubscription(gcalbot.ScheduleToSendTomorrow) + tomorrowSubscriptions, err := s.db.GetAggregatedDailyScheduleSubscription(context.Background(), gcalbot.ScheduleToSendTomorrow) if err != nil { s.Errorf("error getting daily schedule subscriptions to sync: %s", err) } @@ -101,7 +101,7 @@ func (s *ScheduleScheduler) SendDailyScheduleMessage(sendMinute time.Time, subsc s.stats.Count("SendDailyScheduleMessage") s.stats.CountMult("SendDailyScheduleMessage - calendars", len(subscription.CalendarIDs)) - srv, err := gcalbot.GetCalendarService(&subscription.Account, s.oauth, s.db) + srv, err := gcalbot.GetCalendarService(context.Background(), &subscription.Account, s.oauth, s.db) switch err.(type) { case nil: case *oauth2.RetrieveError: diff --git a/gcalbot/gcalbot/webhook.go b/gcalbot/gcalbot/webhook.go index 2d2097ee..561d489f 100644 --- a/gcalbot/gcalbot/webhook.go +++ b/gcalbot/gcalbot/webhook.go @@ -28,10 +28,13 @@ func (h *HTTPSrv) handleEventUpdateWebhook(w http.ResponseWriter, r *http.Reques // sync header, safe to ignore return } + // WithoutCancel: Google sends the webhook and may close the connection + // immediately; DB writes and reminder scheduling must complete regardless. + ctx := context.WithoutCancel(r.Context()) channelID := r.Header.Get("X-Goog-Channel-ID") resourceID := r.Header.Get("X-Goog-Resource-ID") - channel, account, err := h.db.GetChannelAndAccountByID(channelID) + channel, account, err := h.db.GetChannelAndAccountByID(ctx, channelID) if err != nil { return } else if channel == nil { @@ -47,17 +50,17 @@ func (h *HTTPSrv) handleEventUpdateWebhook(w http.ResponseWriter, r *http.Reques } reminderSubscriptions, err := h.db.GetReminderSubscriptionsByAccountAndCalendar( - account, channel.CalendarID, SubscriptionTypeReminder) + ctx, account, channel.CalendarID, SubscriptionTypeReminder) if err != nil { return } inviteSubscriptions, err := h.db.GetReminderSubscriptionsByAccountAndCalendar( - account, channel.CalendarID, SubscriptionTypeInvite) + ctx, account, channel.CalendarID, SubscriptionTypeInvite) if err != nil { return } - srv, err := GetCalendarService(account, h.oauth, h.db) + srv, err := GetCalendarService(ctx, account, h.oauth, h.db) switch err.(type) { case nil: case *oauth2.RetrieveError: @@ -94,7 +97,7 @@ func (h *HTTPSrv) handleEventUpdateWebhook(w http.ResponseWriter, r *http.Reques return } var exists bool - exists, err = h.db.ExistsInvite(account, channel.CalendarID, event.Id) + exists, err = h.db.ExistsInvite(ctx, account, channel.CalendarID, event.Id) if err != nil { return } @@ -102,7 +105,7 @@ func (h *HTTPSrv) handleEventUpdateWebhook(w http.ResponseWriter, r *http.Reques // user was recently invited to the event for range inviteSubscriptions { // TODO(marcel): use subscription convid - err = h.handler.sendEventInvite(account, channel, event) + err = h.handler.sendEventInvite(ctx, account, channel, event) if err != nil { return } @@ -117,7 +120,7 @@ func (h *HTTPSrv) handleEventUpdateWebhook(w http.ResponseWriter, r *http.Reques err = srv.Events. List(channel.CalendarID). SyncToken(syncToken). - Pages(context.Background(), func(page *calendar.Events) error { + Pages(ctx, func(page *calendar.Events) error { if page.NextPageToken == "" { // set the sync token when the page token is empty nextSyncToken = page.NextSyncToken @@ -190,7 +193,7 @@ func (h *HTTPSrv) handleEventUpdateWebhook(w http.ResponseWriter, r *http.Reques } } - err = h.db.UpdateChannelNextSyncToken(channelID, nextSyncToken) + err = h.db.UpdateChannelNextSyncToken(ctx, channelID, nextSyncToken) if err != nil { return } @@ -202,19 +205,19 @@ func (h *HTTPSrv) handleEventUpdateWebhook(w http.ResponseWriter, r *http.Reques } func (h *Handler) createSubscription( - account *Account, subscription Subscription, + ctx context.Context, account *Account, subscription Subscription, ) error { - exists, err := h.db.ExistsSubscription(account, subscription) + exists, err := h.db.ExistsSubscription(ctx, account, subscription) if err != nil || exists { // if no error, subscription exists, short circuit return err } - if err := h.createEventChannel(account, subscription.CalendarID); err != nil { + if err := h.createEventChannel(ctx, account, subscription.CalendarID); err != nil { return err } - if err := h.db.InsertSubscription(account, subscription); err != nil { + if err := h.db.InsertSubscription(ctx, account, subscription); err != nil { return err } @@ -224,9 +227,9 @@ func (h *Handler) createSubscription( } func (h *Handler) removeSubscription( - account *Account, subscription Subscription, + ctx context.Context, account *Account, subscription Subscription, ) error { - err := h.db.DeleteSubscription(account, subscription) + err := h.db.DeleteSubscription(ctx, account, subscription) if err != nil { // if no error, subscription doesn't exist, short circuit return err @@ -234,20 +237,20 @@ func (h *Handler) removeSubscription( h.reminderScheduler.RemoveSubscription(account, subscription) - subscriptionCount, err := h.db.CountSubscriptionsByAccountAndCalender(account, subscription.CalendarID) + subscriptionCount, err := h.db.CountSubscriptionsByAccountAndCalender(ctx, account, subscription.CalendarID) if err != nil { return err } if subscriptionCount == 0 { // if there are no more subscriptions for this account + calendar, remove the channel - channel, err := h.db.GetChannel(account, subscription.CalendarID) + channel, err := h.db.GetChannel(ctx, account, subscription.CalendarID) if err != nil { return err } if channel != nil { - srv, err := GetCalendarService(account, h.oauth, h.db) + srv, err := GetCalendarService(ctx, account, h.oauth, h.db) if err != nil { return err } @@ -266,7 +269,7 @@ func (h *Handler) removeSubscription( return err } - err = h.db.DeleteChannelByChannelID(channel.ChannelID) + err = h.db.DeleteChannelByChannelID(ctx, channel.ChannelID) if err != nil { return err } @@ -276,12 +279,12 @@ func (h *Handler) removeSubscription( return nil } -func (h *Handler) createEventChannel(account *Account, calendarID string) error { - srv, err := GetCalendarService(account, h.oauth, h.db) +func (h *Handler) createEventChannel(ctx context.Context, account *Account, calendarID string) error { + srv, err := GetCalendarService(ctx, account, h.oauth, h.db) if err != nil { return err } - exists, err := h.db.ExistsChannelByAccountAndCalendar(account, calendarID) + exists, err := h.db.ExistsChannelByAccountAndCalendar(ctx, account, calendarID) if err != nil || exists { // if err is nil but the channel exists, return return err @@ -303,7 +306,7 @@ func (h *Handler) createEventChannel(account *Account, calendarID string) error return err } - err = h.db.InsertChannel(account, Channel{ + err = h.db.InsertChannel(ctx, account, Channel{ ChannelID: channelID, CalendarID: calendarID, ResourceID: res.ResourceId, @@ -315,6 +318,7 @@ func (h *Handler) createEventChannel(account *Account, calendarID string) error // pre-fill db with invites so we don't send old invites // there could be a race since this process can take up to a few seconds + // context.Background() because syncAllInvites is a background goroutine that outlives the request go h.syncAllInvites(account, srv, channelID, calendarID) return nil @@ -381,7 +385,7 @@ func (r *RenewChannelScheduler) renewScheduler(shutdownCh chan struct{}) { case <-shutdownCh: return case renewMinute := <-ticker.C: - pairs, err := r.db.GetExpiringChannelAndAccountList() + pairs, err := r.db.GetExpiringChannelAndAccountList(context.Background()) if err != nil { r.Errorf("error getting expiring pairs: %s", err) } @@ -403,7 +407,7 @@ func (r *RenewChannelScheduler) renewScheduler(shutdownCh chan struct{}) { func (r *RenewChannelScheduler) renewChannel(account *Account, channel *Channel) error { r.stats.Count("renewChannel") - srv, err := GetCalendarService(account, r.config, r.db) + srv, err := GetCalendarService(context.Background(), account, r.config, r.db) switch err.(type) { case nil: case *oauth2.RetrieveError: @@ -428,7 +432,7 @@ func (r *RenewChannelScheduler) renewChannel(account *Account, channel *Channel) return err } - err = r.db.UpdateChannel(channel.ChannelID, newChannelID, res.ResourceId, time.Unix(res.Expiration/1e3, 0)) + err = r.db.UpdateChannel(context.Background(), channel.ChannelID, newChannelID, res.ResourceId, time.Unix(res.Expiration/1e3, 0)) if err != nil { return err } diff --git a/githubbot/githubbot/db.go b/githubbot/githubbot/db.go index ef416773..9bf48372 100644 --- a/githubbot/githubbot/db.go +++ b/githubbot/githubbot/db.go @@ -1,8 +1,8 @@ package githubbot import ( + "context" "database/sql" - "fmt" "strings" "github.com/keybase/go-keybase-chat-bot/kbchat/types/chat1" @@ -22,74 +22,62 @@ func NewDB(db *sql.DB) *DB { // webhook subscription methods -func (d *DB) CreateSubscription(convID chat1.ConvIDStr, repo string, installationID int64) error { - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - INSERT INTO subscriptions - (conv_id, repo, installation_id) - VALUES - (?, ?, ?) - ON DUPLICATE KEY UPDATE - installation_id=VALUES(installation_id) - `, convID, repo, installationID) - return err - }) +func (d *DB) CreateSubscription(ctx context.Context, convID chat1.ConvIDStr, repo string, installationID int64) error { + _, err := d.ExecContext(ctx, ` + INSERT INTO subscriptions + (conv_id, repo, installation_id) + VALUES + (?, ?, ?) + ON DUPLICATE KEY UPDATE + installation_id=VALUES(installation_id) + `, convID, repo, installationID) + return err } -func (d *DB) DeleteSubscription(convID chat1.ConvIDStr, repo string) error { - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - DELETE FROM subscriptions - WHERE conv_id = ? AND repo = ? - `, convID, repo) - return err - }) +func (d *DB) DeleteSubscription(ctx context.Context, convID chat1.ConvIDStr, repo string) error { + _, err := d.ExecContext(ctx, ` + DELETE FROM subscriptions + WHERE conv_id = ? AND repo = ? + `, convID, repo) + return err } -func (d *DB) WatchBranch(convID chat1.ConvIDStr, repo string, branch string) error { - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - INSERT IGNORE INTO branches - (conv_id, repo, branch) - VALUES - (?, ?, ?) - `, convID, repo, branch) - return err - }) +func (d *DB) WatchBranch(ctx context.Context, convID chat1.ConvIDStr, repo string, branch string) error { + _, err := d.ExecContext(ctx, ` + INSERT IGNORE INTO branches + (conv_id, repo, branch) + VALUES + (?, ?, ?) + `, convID, repo, branch) + return err } -func (d *DB) UnwatchBranch(convID chat1.ConvIDStr, repo string, branch string) error { - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - DELETE FROM branches - WHERE conv_id = ? AND repo = ? AND branch = ? - `, convID, repo, branch) - return err - }) +func (d *DB) UnwatchBranch(ctx context.Context, convID chat1.ConvIDStr, repo string, branch string) error { + _, err := d.ExecContext(ctx, ` + DELETE FROM branches + WHERE conv_id = ? AND repo = ? AND branch = ? + `, convID, repo, branch) + return err } -func (d *DB) DeleteSubscriptionsForRepo(convID chat1.ConvIDStr, repo string) error { - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - DELETE FROM subscriptions - WHERE conv_id = ? AND repo = ? - `, convID, repo) - return err - }) +func (d *DB) DeleteSubscriptionsForRepo(ctx context.Context, convID chat1.ConvIDStr, repo string) error { + _, err := d.ExecContext(ctx, ` + DELETE FROM subscriptions + WHERE conv_id = ? AND repo = ? + `, convID, repo) + return err } -func (d *DB) DeleteBranchesForRepo(convID chat1.ConvIDStr, repo string) error { - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - DELETE FROM subscriptions - WHERE conv_id = ? AND repo = ? - `, convID, repo) - return err - }) +func (d *DB) DeleteBranchesForRepo(ctx context.Context, convID chat1.ConvIDStr, repo string) error { + _, err := d.ExecContext(ctx, ` + DELETE FROM branches + WHERE conv_id = ? AND repo = ? + `, convID, repo) + return err } -func (d *DB) GetConvIDsFromRepoInstallation(repo string, installationID int64) (res []chat1.ConvIDStr, err error) { - rows, err := d.Query(` +func (d *DB) GetConvIDsFromRepoInstallation(ctx context.Context, repo string, installationID int64) (res []chat1.ConvIDStr, err error) { + rows, err := d.QueryContext(ctx, ` SELECT conv_id FROM subscriptions WHERE repo = ? AND installation_id = ? @@ -98,11 +86,7 @@ func (d *DB) GetConvIDsFromRepoInstallation(repo string, installationID int64) ( if err != nil { return res, err } - defer func() { - if cerr := rows.Close(); cerr != nil { - fmt.Printf("GetConvIDsFromRepoInstallation: failed to close rows: %v\n", cerr) - } - }() + defer rows.Close() for rows.Next() { var convID chat1.ConvIDStr if err := rows.Scan(&convID); err != nil { @@ -110,11 +94,11 @@ func (d *DB) GetConvIDsFromRepoInstallation(repo string, installationID int64) ( } res = append(res, convID) } - return res, nil + return res, rows.Err() } -func (d *DB) GetSubscriptionForBranchExists(convID chat1.ConvIDStr, repo string, branch string) (exists bool, err error) { - row := d.QueryRow(` +func (d *DB) GetSubscriptionForBranchExists(ctx context.Context, convID chat1.ConvIDStr, repo string, branch string) (exists bool, err error) { + row := d.QueryRowContext(ctx, ` SELECT 1 FROM branches WHERE conv_id = ? AND repo = ? AND branch = ? @@ -132,8 +116,8 @@ func (d *DB) GetSubscriptionForBranchExists(convID chat1.ConvIDStr, repo string, } } -func (d *DB) GetSubscriptionForRepoExists(convID chat1.ConvIDStr, repo string) (exists bool, err error) { - row := d.QueryRow(` +func (d *DB) GetSubscriptionForRepoExists(ctx context.Context, convID chat1.ConvIDStr, repo string) (exists bool, err error) { + row := d.QueryRowContext(ctx, ` SELECT 1 FROM subscriptions WHERE conv_id = ? AND repo = ? @@ -150,17 +134,15 @@ func (d *DB) GetSubscriptionForRepoExists(convID chat1.ConvIDStr, repo string) ( } } -func (d *DB) GetAllBranchesForRepo(convID chat1.ConvIDStr, repo string) ([]string, error) { - rows, err := d.Query(`SELECT branch +func (d *DB) GetAllBranchesForRepo(ctx context.Context, convID chat1.ConvIDStr, repo string) ([]string, error) { + rows, err := d.QueryContext(ctx, `SELECT branch FROM branches WHERE conv_id = ? AND repo = ?`, convID, repo) if err != nil { return nil, err } res := []string{} - defer func() { - _ = rows.Close() - }() + defer rows.Close() for rows.Next() { var branch string if err := rows.Scan(&branch); err != nil { @@ -168,7 +150,7 @@ func (d *DB) GetAllBranchesForRepo(convID chat1.ConvIDStr, repo string) ([]strin } res = append(res, branch) } - return res, nil + return res, rows.Err() } // subscription preferences @@ -209,26 +191,24 @@ func (f *Features) String() string { return strings.Join(res, ", ") } -func (d *DB) SetFeatures(convID chat1.ConvIDStr, repo string, features *Features) error { - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - INSERT INTO features - (conv_id, repo, issues, pull_requests, commits, statuses, releases) - VALUES - (?, ?, ?, ?, ?, ?, ?) - ON DUPLICATE KEY UPDATE - issues=VALUES(issues), - pull_requests=VALUES(pull_requests), - commits=VALUES(commits), - statuses=VALUES(statuses), - releases=VALUES(releases) - `, convID, repo, features.Issues, features.PullRequests, features.Commits, features.Statuses, features.Releases) - return err - }) +func (d *DB) SetFeatures(ctx context.Context, convID chat1.ConvIDStr, repo string, features *Features) error { + _, err := d.ExecContext(ctx, ` + INSERT INTO features + (conv_id, repo, issues, pull_requests, commits, statuses, releases) + VALUES + (?, ?, ?, ?, ?, ?, ?) + ON DUPLICATE KEY UPDATE + issues=VALUES(issues), + pull_requests=VALUES(pull_requests), + commits=VALUES(commits), + statuses=VALUES(statuses), + releases=VALUES(releases) + `, convID, repo, features.Issues, features.PullRequests, features.Commits, features.Statuses, features.Releases) + return err } -func (d *DB) GetFeatures(convID chat1.ConvIDStr, repo string) (*Features, error) { - row := d.QueryRow(`SELECT issues, pull_requests, commits, statuses, releases +func (d *DB) GetFeatures(ctx context.Context, convID chat1.ConvIDStr, repo string) (*Features, error) { + row := d.QueryRowContext(ctx, `SELECT issues, pull_requests, commits, statuses, releases FROM features WHERE conv_id = ? AND repo = ?`, convID, repo) features := &Features{} @@ -243,8 +223,8 @@ func (d *DB) GetFeatures(convID chat1.ConvIDStr, repo string) (*Features, error) } } -func (d *DB) GetFeaturesForAllRepos(convID chat1.ConvIDStr) (map[string]Features, error) { - rows, err := d.Query(`SELECT repo, COALESCE(issues, true), COALESCE(pull_requests, true), +func (d *DB) GetFeaturesForAllRepos(ctx context.Context, convID chat1.ConvIDStr) (map[string]Features, error) { + rows, err := d.QueryContext(ctx, `SELECT repo, COALESCE(issues, true), COALESCE(pull_requests, true), COALESCE(commits, true), COALESCE(statuses, true), COALESCE(releases, true) FROM subscriptions LEFT JOIN features USING(conv_id, repo) @@ -253,9 +233,7 @@ func (d *DB) GetFeaturesForAllRepos(convID chat1.ConvIDStr) (map[string]Features return nil, err } res := make(map[string]Features) - defer func() { - _ = rows.Close() - }() + defer rows.Close() for rows.Next() { var repo string var features Features @@ -264,24 +242,22 @@ func (d *DB) GetFeaturesForAllRepos(convID chat1.ConvIDStr) (map[string]Features } res[repo] = features } - return res, nil + return res, rows.Err() } -func (d *DB) DeleteFeaturesForRepo(convID chat1.ConvIDStr, repo string) error { - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - DELETE FROM features - WHERE conv_id = ? AND repo = ? - `, convID, repo) - return err - }) +func (d *DB) DeleteFeaturesForRepo(ctx context.Context, convID chat1.ConvIDStr, repo string) error { + _, err := d.ExecContext(ctx, ` + DELETE FROM features + WHERE conv_id = ? AND repo = ? + `, convID, repo) + return err } // OAuth2 token methods -func (d *DB) GetToken(identifier string) (*oauth2.Token, error) { +func (d *DB) GetToken(ctx context.Context, identifier string) (*oauth2.Token, error) { var token oauth2.Token - row := d.QueryRow(`SELECT access_token, token_type + row := d.QueryRowContext(ctx, `SELECT access_token, token_type FROM oauth WHERE identifier = ?`, identifier) err := row.Scan(&token.AccessToken, &token.TokenType) @@ -295,25 +271,19 @@ func (d *DB) GetToken(identifier string) (*oauth2.Token, error) { } } -func (d *DB) PutToken(identifier string, token *oauth2.Token) error { - err := d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(`INSERT INTO oauth +func (d *DB) PutToken(ctx context.Context, identifier string, token *oauth2.Token) error { + _, err := d.ExecContext(ctx, `INSERT INTO oauth (identifier, access_token, token_type, ctime, mtime) VALUES (?, ?, ?, NOW(), NOW()) ON DUPLICATE KEY UPDATE access_token=VALUES(access_token), mtime=VALUES(mtime) `, identifier, token.AccessToken, token.TokenType) - return err - }) return err } -func (d *DB) DeleteToken(identifier string) error { - err := d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec("DELETE FROM oauth WHERE identifier = ?", identifier) - return err - }) +func (d *DB) DeleteToken(ctx context.Context, identifier string) error { + _, err := d.ExecContext(ctx, "DELETE FROM oauth WHERE identifier = ?", identifier) return err } @@ -323,8 +293,8 @@ type UserPreferences struct { Mention bool } -func (d *DB) GetUserPreferences(username string, convID chat1.ConvIDStr) (*UserPreferences, error) { - row := d.QueryRow(`SELECT mention +func (d *DB) GetUserPreferences(ctx context.Context, username string, convID chat1.ConvIDStr) (*UserPreferences, error) { + row := d.QueryRowContext(ctx, `SELECT mention FROM user_prefs WHERE username = ? AND conv_id = ?`, username, convID) prefs := &UserPreferences{} @@ -342,16 +312,13 @@ func (d *DB) GetUserPreferences(username string, convID chat1.ConvIDStr) (*UserP } } -func (d *DB) SetUserPreferences(username string, convID chat1.ConvIDStr, prefs *UserPreferences) error { - err := d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(`INSERT INTO user_prefs +func (d *DB) SetUserPreferences(ctx context.Context, username string, convID chat1.ConvIDStr, prefs *UserPreferences) error { + _, err := d.ExecContext(ctx, `INSERT INTO user_prefs (username, conv_id, mention) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE mention=VALUES(mention) `, username, convID, prefs.Mention) - return err - }) return err } @@ -362,17 +329,15 @@ type DBSubscription struct { InstallationID int64 } -func (d *DB) GetAllSubscriptions() (res []DBSubscription, err error) { - rows, err := d.Query(` +func (d *DB) GetAllSubscriptions(ctx context.Context) (res []DBSubscription, err error) { + rows, err := d.QueryContext(ctx, ` SELECT conv_id, repo, installation_id FROM subscriptions `) if err != nil { return res, err } - defer func() { - _ = rows.Close() - }() + defer rows.Close() for rows.Next() { var subscription DBSubscription if err := rows.Scan(&subscription.ConvID, &subscription.Repo, &subscription.InstallationID); err != nil { @@ -380,5 +345,5 @@ func (d *DB) GetAllSubscriptions() (res []DBSubscription, err error) { } res = append(res, subscription) } - return res, nil + return res, rows.Err() } diff --git a/githubbot/githubbot/handler.go b/githubbot/githubbot/handler.go index 289cb031..5c0090f0 100644 --- a/githubbot/githubbot/handler.go +++ b/githubbot/githubbot/handler.go @@ -42,7 +42,7 @@ func NewHandler(stats *base.StatsRegistry, kbc *kbchat.API, debugConfig *base.Ch } } -func (h *Handler) HandleNewConv(conv chat1.ConvSummary) error { +func (h *Handler) HandleNewConv(_ context.Context, conv chat1.ConvSummary) error { welcomeMsg := fmt.Sprintf( "Hi! I can notify you whenever something happens on a GitHub repository. To get started, install the Keybase integration on your repository, then send `!github subscribe `\n\ngithub.com/apps/%s/installations/new", h.appName, @@ -50,11 +50,11 @@ func (h *Handler) HandleNewConv(conv chat1.ConvSummary) error { return base.HandleNewTeam(h.stats, h.DebugOutput, h.kbc, conv, welcomeMsg) } -func (h *Handler) HandleAuth(msg chat1.MsgSummary, _ string) error { - return h.HandleCommand(msg) +func (h *Handler) HandleAuth(ctx context.Context, msg chat1.MsgSummary, _ string) error { + return h.HandleCommand(ctx, msg) } -func (h *Handler) HandleCommand(msg chat1.MsgSummary) error { +func (h *Handler) HandleCommand(ctx context.Context, msg chat1.MsgSummary) error { if msg.Content.Text == nil { return nil } @@ -68,27 +68,27 @@ func (h *Handler) HandleCommand(msg chat1.MsgSummary) error { if strings.HasPrefix(cmd, "!github mentions") { // handle user preferences without needing oauth h.stats.Count("mentions") - return h.handleMentionPref(cmd, msg) + return h.handleMentionPref(ctx, cmd, msg) } client := github.NewClient(&http.Client{Transport: h.atr}) switch { case strings.HasPrefix(cmd, "!github subscribe"): h.stats.Count("subscribe") - return h.handleSubscribe(cmd, msg, true, client) + return h.handleSubscribe(ctx, cmd, msg, true, client) case strings.HasPrefix(cmd, "!github unsubscribe"): h.stats.Count("unsubscribe") - return h.handleSubscribe(cmd, msg, false, client) + return h.handleSubscribe(ctx, cmd, msg, false, client) case strings.HasPrefix(cmd, "!github list"): h.stats.Count("list") - return h.handleListSubscriptions(msg) + return h.handleListSubscriptions(ctx, msg) default: h.Debug("ignoring unknown command %q", cmd) } return nil } -func (h *Handler) handleSubscribe(cmd string, msg chat1.MsgSummary, create bool, client *github.Client) (err error) { +func (h *Handler) handleSubscribe(ctx context.Context, cmd string, msg chat1.MsgSummary, create bool, client *github.Client) (err error) { toks, userErr, err := base.SplitTokens(cmd) if err != nil { return err @@ -118,14 +118,14 @@ func (h *Handler) handleSubscribe(cmd string, msg chat1.MsgSummary, create bool, repo := args[0] // Check if command is subscribing to a branch - alreadyExists, err := h.db.GetSubscriptionForRepoExists(msg.ConvID, repo) + alreadyExists, err := h.db.GetSubscriptionForRepoExists(ctx, msg.ConvID, repo) if err != nil { return fmt.Errorf("error checking subscription: %s", err) } if len(args) == 2 { if !alreadyExists { if create { - if created, err := h.handleNewSubscription(repo, msg, client); err != nil { + if created, err := h.handleNewSubscription(ctx, repo, msg, client); err != nil { if _, ok := err.(base.OAuthRequiredError); ok { return nil } @@ -140,9 +140,9 @@ func (h *Handler) handleSubscribe(cmd string, msg chat1.MsgSummary, create bool, } switch args[1] { case "issues", "pulls", "statuses", "commits", "releases": - return h.handleSubscribeToFeature(repo, args[1], msg, create) + return h.handleSubscribeToFeature(ctx, repo, args[1], msg, create) default: - return h.handleSubscribeToBranch(repo, args[1], msg, create) + return h.handleSubscribeToBranch(ctx, repo, args[1], msg, create) } } @@ -151,7 +151,7 @@ func (h *Handler) handleSubscribe(cmd string, msg chat1.MsgSummary, create bool, h.ChatEcho(msg.ConvID, "You're already receiving notifications for `%s` here!", repo) return nil } - created, err := h.handleNewSubscription(repo, msg, client) + created, err := h.handleNewSubscription(ctx, repo, msg, client) if err != nil { if _, ok := err.(base.OAuthRequiredError); ok { return nil @@ -170,17 +170,17 @@ func (h *Handler) handleSubscribe(cmd string, msg chat1.MsgSummary, create bool, return nil } - err = h.db.DeleteSubscriptionsForRepo(msg.ConvID, repo) + err = h.db.DeleteSubscriptionsForRepo(ctx, msg.ConvID, repo) if err != nil { return fmt.Errorf("error deleting subscriptions: %s", err) } - err = h.db.DeleteBranchesForRepo(msg.ConvID, repo) + err = h.db.DeleteBranchesForRepo(ctx, msg.ConvID, repo) if err != nil { return fmt.Errorf("error deleting branches: %s", err) } - err = h.db.DeleteFeaturesForRepo(msg.ConvID, repo) + err = h.db.DeleteFeaturesForRepo(ctx, msg.ConvID, repo) if err != nil { return fmt.Errorf("error deleting features: %s", err) } @@ -188,8 +188,8 @@ func (h *Handler) handleSubscribe(cmd string, msg chat1.MsgSummary, create bool, return nil } -func (h *Handler) handleListSubscriptions(msg chat1.MsgSummary) (err error) { - features, err := h.db.GetFeaturesForAllRepos(msg.ConvID) +func (h *Handler) handleListSubscriptions(ctx context.Context, msg chat1.MsgSummary) (err error) { + features, err := h.db.GetFeaturesForAllRepos(ctx, msg.ConvID) if err != nil { return fmt.Errorf("Error getting current features: %s", err) } @@ -203,7 +203,7 @@ func (h *Handler) handleListSubscriptions(msg chat1.MsgSummary) (err error) { for repo, f := range features { res.WriteString(fmt.Sprintf("- *%s* (%s)\n", repo, &f)) if f.Commits { - branches, err := h.db.GetAllBranchesForRepo(msg.ConvID, repo) + branches, err := h.db.GetAllBranchesForRepo(ctx, msg.ConvID, repo) if err != nil { return fmt.Errorf("error getting branches for repo: %s", err) } @@ -217,7 +217,7 @@ func (h *Handler) handleListSubscriptions(msg chat1.MsgSummary) (err error) { return nil } -func (h *Handler) handleNewSubscription(repo string, msg chat1.MsgSummary, client *github.Client) (created bool, err error) { +func (h *Handler) handleNewSubscription(ctx context.Context, repo string, msg chat1.MsgSummary, client *github.Client) (created bool, err error) { parsedRepo := strings.Split(repo, "/") if len(parsedRepo) != 2 { h.ChatEcho(msg.ConvID, "`%s` doesn't look like a repository to me! Try sending `!github subscribe `", repo) @@ -235,7 +235,7 @@ func (h *Handler) handleNewSubscription(repo string, msg chat1.MsgSummary, clien } // check that user has authorization - tc, err := base.GetOAuthClient(msg.Sender.Username, msg, h.kbc, h.oauthConfig, h.db, + tc, err := base.GetOAuthClient(ctx, msg.Sender.Username, msg, h.kbc, h.oauthConfig, h.db, base.GetOAuthOpts{ AuthMessageTemplate: "Authorize me by clicking this link:\n%s", }) @@ -268,20 +268,20 @@ func (h *Handler) handleNewSubscription(repo string, msg chat1.MsgSummary, clien return false, fmt.Errorf("error getting default branch: %s", err) } - err = h.db.WatchBranch(msg.ConvID, repo, defaultBranch) + err = h.db.WatchBranch(ctx, msg.ConvID, repo, defaultBranch) if err != nil { return false, fmt.Errorf("error watching branch: %s", err) } - err = h.db.CreateSubscription(msg.ConvID, repo, repoInstallation.GetID()) + err = h.db.CreateSubscription(ctx, msg.ConvID, repo, repoInstallation.GetID()) if err != nil { return false, fmt.Errorf("error creating subscription: %s", err) } return true, nil } -func (h *Handler) handleSubscribeToFeature(repo, feature string, msg chat1.MsgSummary, enable bool) (err error) { - exists, err := h.db.GetSubscriptionForRepoExists(msg.ConvID, repo) +func (h *Handler) handleSubscribeToFeature(ctx context.Context, repo, feature string, msg chat1.MsgSummary, enable bool) (err error) { + exists, err := h.db.GetSubscriptionForRepoExists(ctx, msg.ConvID, repo) if err != nil { return fmt.Errorf("error getting subscription: %s", err) } else if !exists { @@ -293,7 +293,7 @@ func (h *Handler) handleSubscribeToFeature(repo, feature string, msg chat1.MsgSu return nil } - currentFeatures, err := h.db.GetFeatures(msg.ConvID, repo) + currentFeatures, err := h.db.GetFeatures(ctx, msg.ConvID, repo) if err != nil { return fmt.Errorf("Error getting current features: %s", err) } @@ -317,7 +317,7 @@ func (h *Handler) handleSubscribeToFeature(repo, feature string, msg chat1.MsgSu return fmt.Errorf("Error subscribing to feature: %s is not a valid feature", feature) } - err = h.db.SetFeatures(msg.ConvID, repo, currentFeatures) + err = h.db.SetFeatures(ctx, msg.ConvID, repo, currentFeatures) if err != nil { return fmt.Errorf("Error setting features: %s", err) } @@ -329,8 +329,8 @@ func (h *Handler) handleSubscribeToFeature(repo, feature string, msg chat1.MsgSu return nil } -func (h *Handler) handleSubscribeToBranch(repo, branch string, msg chat1.MsgSummary, create bool) (err error) { - exists, err := h.db.GetSubscriptionForRepoExists(msg.ConvID, repo) +func (h *Handler) handleSubscribeToBranch(ctx context.Context, repo, branch string, msg chat1.MsgSummary, create bool) (err error) { + exists, err := h.db.GetSubscriptionForRepoExists(ctx, msg.ConvID, repo) if err != nil { return fmt.Errorf("error getting subscription: %s", err) } else if !exists { @@ -343,7 +343,7 @@ func (h *Handler) handleSubscribeToBranch(repo, branch string, msg chat1.MsgSumm } if create { - err = h.db.WatchBranch(msg.ConvID, repo, branch) + err = h.db.WatchBranch(ctx, msg.ConvID, repo, branch) if err != nil { return fmt.Errorf("error creating branch subscription: %s", err) } @@ -351,7 +351,7 @@ func (h *Handler) handleSubscribeToBranch(repo, branch string, msg chat1.MsgSumm h.ChatEcho(msg.ConvID, "Now subscribed to notifications for `%s/%s`.", repo, branch) return nil } - err = h.db.UnwatchBranch(msg.ConvID, repo, branch) + err = h.db.UnwatchBranch(ctx, msg.ConvID, repo, branch) if err != nil { return fmt.Errorf("error deleting branch subscription: %s", err) } @@ -361,7 +361,7 @@ func (h *Handler) handleSubscribeToBranch(repo, branch string, msg chat1.MsgSumm } // user preferences -func (h *Handler) handleMentionPref(cmd string, msg chat1.MsgSummary) (err error) { +func (h *Handler) handleMentionPref(ctx context.Context, cmd string, msg chat1.MsgSummary) (err error) { toks, userErr, err := base.SplitTokens(cmd) if err != nil { return err @@ -376,7 +376,7 @@ func (h *Handler) handleMentionPref(cmd string, msg chat1.MsgSummary) (err error } allowMentions := args[0] == "enable" - err = h.db.SetUserPreferences(msg.Sender.Username, msg.ConvID, &UserPreferences{Mention: allowMentions}) + err = h.db.SetUserPreferences(ctx, msg.Sender.Username, msg.ConvID, &UserPreferences{Mention: allowMentions}) if err != nil { return fmt.Errorf("error setting user preference: %s", err) } diff --git a/githubbot/githubbot/http.go b/githubbot/githubbot/http.go index 33944163..7c43c8a4 100644 --- a/githubbot/githubbot/http.go +++ b/githubbot/githubbot/http.go @@ -91,14 +91,17 @@ func (h *HTTPSrv) handleWebhook(_ http.ResponseWriter, r *http.Request) { return } - convs, err := h.db.GetConvIDsFromRepoInstallation(repo, installationID) + // WithoutCancel: GitHub sends the webhook and may close the connection + // immediately; DB reads and Keybase message sends must complete regardless. + ctx := context.WithoutCancel(r.Context()) + convs, err := h.db.GetConvIDsFromRepoInstallation(ctx, repo, installationID) if err != nil { h.Errorf("Error getting subscriptions for repo: %s", err) return } for _, convID := range convs { - features, err := h.db.GetFeatures(convID, repo) + features, err := h.db.GetFeatures(ctx, convID, repo) if err != nil { h.Errorf("Error getting features for repo and convID: %s", err) return @@ -110,7 +113,7 @@ func (h *HTTPSrv) handleWebhook(_ http.ResponseWriter, r *http.Request) { continue } - message, branch := h.formatMessage(convID, event, repo, client) + message, branch := h.formatMessage(ctx, convID, event, repo, client) if message == "" { // if we don't have a message to send, bail continue @@ -118,7 +121,7 @@ func (h *HTTPSrv) handleWebhook(_ http.ResponseWriter, r *http.Request) { if branch != "" { // if the event has a branch associated with it, check if we're subscribed to that branch - subscriptionExists, err := h.db.GetSubscriptionForBranchExists(convID, repo, branch) + subscriptionExists, err := h.db.GetSubscriptionForBranchExists(ctx, convID, repo, branch) if err != nil { h.Errorf("could not get subscription: %s\n", err) return @@ -134,7 +137,7 @@ func (h *HTTPSrv) handleWebhook(_ http.ResponseWriter, r *http.Request) { } } -func (h *HTTPSrv) formatMessage(convID chat1.ConvIDStr, event any, repo string, client *github.Client) (message string, branch string) { +func (h *HTTPSrv) formatMessage(ctx context.Context, convID chat1.ConvIDStr, event any, repo string, client *github.Client) (message string, branch string) { parsedRepo := strings.Split(repo, "/") if len(parsedRepo) != 2 { h.Debug("invalid repo: %s", repo) @@ -142,7 +145,7 @@ func (h *HTTPSrv) formatMessage(convID chat1.ConvIDStr, event any, repo string, } switch event := event.(type) { case *github.IssuesEvent: - author := getPossibleKBUser(h.kbc, h.db, h.DebugOutput, event.GetSender().GetLogin(), convID) + author := getPossibleKBUser(ctx, h.kbc, h.db, h.DebugOutput, event.GetSender().GetLogin(), convID) return git.FormatIssueMsg( *event.Action, author.String(), @@ -152,7 +155,7 @@ func (h *HTTPSrv) formatMessage(convID chat1.ConvIDStr, event any, repo string, event.GetIssue().GetHTMLURL(), ), "" case *github.ReleaseEvent: - author := getPossibleKBUser(h.kbc, h.db, h.DebugOutput, event.GetSender().GetLogin(), convID) + author := getPossibleKBUser(ctx, h.kbc, h.db, h.DebugOutput, event.GetSender().GetLogin(), convID) return git.FormatReleaseMsg( *event.Action, author.String(), @@ -165,9 +168,9 @@ func (h *HTTPSrv) formatMessage(convID chat1.ConvIDStr, event any, repo string, case *github.PullRequestEvent: var author username if event.GetPullRequest().GetMerged() { - author = getPossibleKBUser(h.kbc, h.db, h.DebugOutput, event.GetPullRequest().GetMergedBy().GetLogin(), convID) + author = getPossibleKBUser(ctx, h.kbc, h.db, h.DebugOutput, event.GetPullRequest().GetMergedBy().GetLogin(), convID) } else { - author = getPossibleKBUser(h.kbc, h.db, h.DebugOutput, event.GetPullRequest().GetUser().GetLogin(), convID) + author = getPossibleKBUser(ctx, h.kbc, h.db, h.DebugOutput, event.GetPullRequest().GetUser().GetLogin(), convID) } action := *event.Action @@ -222,20 +225,20 @@ func (h *HTTPSrv) formatMessage(convID chat1.ConvIDStr, event any, repo string, } // fetch the pull request object so we can get the right author - pr, _, err := client.PullRequests.Get(context.TODO(), parsedRepo[0], parsedRepo[1], runPR.GetNumber()) + pr, _, err := client.PullRequests.Get(ctx, parsedRepo[0], parsedRepo[1], runPR.GetNumber()) if err != nil { if !strings.Contains(err.Error(), "401 Bad credentials") { h.Errorf("Error getting pull request object: %s", err) } return formatCheckRunMessage(event, ""), branch } - author = getPossibleKBUser(h.kbc, h.db, h.DebugOutput, pr.GetUser().GetLogin(), convID) + author = getPossibleKBUser(ctx, h.kbc, h.db, h.DebugOutput, pr.GetUser().GetLogin(), convID) return formatCheckRunMessage(event, author.String()), branch case *github.StatusEvent: var author username pullRequests, _, err := client.PullRequests.ListPullRequestsWithCommit( - context.TODO(), + ctx, event.GetRepo().GetOwner().GetLogin(), event.GetRepo().GetName(), event.GetSHA(), @@ -259,11 +262,11 @@ func (h *HTTPSrv) formatMessage(convID chat1.ConvIDStr, event any, repo string, } if runPR != nil { - author = getPossibleKBUser(h.kbc, h.db, h.DebugOutput, runPR.GetUser().GetLogin(), convID) + author = getPossibleKBUser(ctx, h.kbc, h.db, h.DebugOutput, runPR.GetUser().GetLogin(), convID) } else if len(event.Branches) >= 1 { // this is a branch test, not associated with a PR branch = event.Branches[0].GetName() - author = getPossibleKBUser(h.kbc, h.db, h.DebugOutput, event.GetCommit().GetAuthor().GetLogin(), convID) + author = getPossibleKBUser(ctx, h.kbc, h.db, h.DebugOutput, event.GetCommit().GetAuthor().GetLogin(), convID) } else { h.Debug("status event had no pull requests or branches") return "", "" diff --git a/githubbot/githubbot/util.go b/githubbot/githubbot/util.go index 339305c3..d967ac1b 100644 --- a/githubbot/githubbot/util.go +++ b/githubbot/githubbot/util.go @@ -175,7 +175,7 @@ type keybaseID struct { Username string `json:"username"` } -func getPossibleKBUser(kbc *kbchat.API, d *DB, debug *base.DebugOutput, githubUsername string, convID chat1.ConvIDStr) (u username) { +func getPossibleKBUser(ctx context.Context, kbc *kbchat.API, d *DB, debug *base.DebugOutput, githubUsername string, convID chat1.ConvIDStr) (u username) { u = username{githubUsername: githubUsername} id := kbc.Command("id", "-j", fmt.Sprintf("%s@github", githubUsername)) output, err := id.Output() @@ -191,7 +191,7 @@ func getPossibleKBUser(kbc *kbchat.API, d *DB, debug *base.DebugOutput, githubUs return u } - prefs, err := d.GetUserPreferences(i.Username, convID) + prefs, err := d.GetUserPreferences(ctx, i.Username, convID) if err != nil { debug.Debug("getPossibleKBUser: couldn't get user preferences: %s", err) return u diff --git a/githubbot/migrations/default_branch.go b/githubbot/migrations/default_branch.go index 926e86c1..a29ec885 100644 --- a/githubbot/migrations/default_branch.go +++ b/githubbot/migrations/default_branch.go @@ -2,6 +2,7 @@ package main import ( + "context" "database/sql" "flag" "fmt" @@ -79,7 +80,7 @@ func mainInner() int { return 1 } - subs, err := db.GetAllSubscriptions() + subs, err := db.GetAllSubscriptions(context.Background()) if err != nil { fmt.Printf("failed to get all subscriptions: %s", err) return 1 @@ -96,7 +97,7 @@ func mainInner() int { continue } - err = db.WatchBranch(subscription.ConvID, subscription.Repo, defaultBranch) + err = db.WatchBranch(context.Background(), subscription.ConvID, subscription.Repo, defaultBranch) if err != nil { fmt.Printf("Error watching branch: %s", err) return 1 diff --git a/gitlabbot/gitlabbot/db.go b/gitlabbot/gitlabbot/db.go index 42cb481b..c96c1c14 100644 --- a/gitlabbot/gitlabbot/db.go +++ b/gitlabbot/gitlabbot/db.go @@ -1,8 +1,8 @@ package gitlabbot import ( + "context" "database/sql" - "fmt" "github.com/keybase/go-keybase-chat-bot/kbchat/types/chat1" @@ -22,41 +22,35 @@ func NewDB(db *sql.DB) *DB { // webhook subscription methods -func (d *DB) CreateSubscription(convID chat1.ConvIDStr, repo string, oauthIdentifier string) error { - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - INSERT INTO subscriptions - (conv_id, repo, oauth_identifier) - VALUES (?, ?, ?) - ON DUPLICATE KEY UPDATE - oauth_identifier=VALUES(oauth_identifier) - `, convID, repo, oauthIdentifier) - return err - }) +func (d *DB) CreateSubscription(ctx context.Context, convID chat1.ConvIDStr, repo string, oauthIdentifier string) error { + _, err := d.ExecContext(ctx, ` + INSERT INTO subscriptions + (conv_id, repo, oauth_identifier) + VALUES (?, ?, ?) + ON DUPLICATE KEY UPDATE + oauth_identifier=VALUES(oauth_identifier) + `, convID, repo, oauthIdentifier) + return err } -func (d *DB) DeleteSubscription(convID chat1.ConvIDStr, repo string) error { - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - DELETE FROM subscriptions - WHERE (conv_id = ? AND repo = ?) - `, convID, repo) - return err - }) +func (d *DB) DeleteSubscription(ctx context.Context, convID chat1.ConvIDStr, repo string) error { + _, err := d.ExecContext(ctx, ` + DELETE FROM subscriptions + WHERE (conv_id = ? AND repo = ?) + `, convID, repo) + return err } -func (d *DB) DeleteSubscriptionsForRepo(convID chat1.ConvIDStr, repo string) error { - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - DELETE FROM subscriptions - WHERE (conv_id = ? AND repo = ?) - `, convID, repo) - return err - }) +func (d *DB) DeleteSubscriptionsForRepo(ctx context.Context, convID chat1.ConvIDStr, repo string) error { + _, err := d.ExecContext(ctx, ` + DELETE FROM subscriptions + WHERE (conv_id = ? AND repo = ?) + `, convID, repo) + return err } -func (d *DB) GetSubscribedConvs(repo string) (res []chat1.ConvIDStr, err error) { - rows, err := d.Query(` +func (d *DB) GetSubscribedConvs(ctx context.Context, repo string) (res []chat1.ConvIDStr, err error) { + rows, err := d.QueryContext(ctx, ` SELECT conv_id FROM subscriptions WHERE repo = ? @@ -65,11 +59,7 @@ func (d *DB) GetSubscribedConvs(repo string) (res []chat1.ConvIDStr, err error) if err != nil { return res, err } - defer func() { - if cerr := rows.Close(); cerr != nil { - fmt.Printf("GetSubscribedConvs: failed to close rows: %v\n", cerr) - } - }() + defer rows.Close() for rows.Next() { var convID chat1.ConvIDStr if err := rows.Scan(&convID); err != nil { @@ -77,11 +67,11 @@ func (d *DB) GetSubscribedConvs(repo string) (res []chat1.ConvIDStr, err error) } res = append(res, convID) } - return res, nil + return res, rows.Err() } -func (d *DB) GetSubscriptionExists(convID chat1.ConvIDStr, repo string) (exists bool, err error) { - row := d.QueryRow(` +func (d *DB) GetSubscriptionExists(ctx context.Context, convID chat1.ConvIDStr, repo string) (exists bool, err error) { + row := d.QueryRowContext(ctx, ` SELECT 1 FROM subscriptions WHERE (conv_id = ? AND repo = ?) @@ -99,8 +89,8 @@ func (d *DB) GetSubscriptionExists(convID chat1.ConvIDStr, repo string) (exists } } -func (d *DB) GetSubscriptionForRepoExists(convID chat1.ConvIDStr, repo string) (exists bool, err error) { - row := d.QueryRow(` +func (d *DB) GetSubscriptionForRepoExists(ctx context.Context, convID chat1.ConvIDStr, repo string) (exists bool, err error) { + row := d.QueryRowContext(ctx, ` SELECT 1 FROM subscriptions WHERE (conv_id = ? AND repo = ?) @@ -117,8 +107,8 @@ func (d *DB) GetSubscriptionForRepoExists(convID chat1.ConvIDStr, repo string) ( } } -func (d *DB) GetAllSubscriptionsForConvID(convID chat1.ConvIDStr) (res []string, err error) { - rows, err := d.Query(` +func (d *DB) GetAllSubscriptionsForConvID(ctx context.Context, convID chat1.ConvIDStr) (res []string, err error) { + rows, err := d.QueryContext(ctx, ` SELECT repo FROM subscriptions WHERE conv_id = ? @@ -127,11 +117,7 @@ func (d *DB) GetAllSubscriptionsForConvID(convID chat1.ConvIDStr) (res []string, if err != nil { return nil, err } - defer func() { - if cerr := rows.Close(); cerr != nil { - fmt.Printf("GetAllSubscriptionsForConvID: failed to close rows: %v\n", cerr) - } - }() + defer rows.Close() for rows.Next() { var repo string if err := rows.Scan(&repo); err != nil { @@ -139,14 +125,14 @@ func (d *DB) GetAllSubscriptionsForConvID(convID chat1.ConvIDStr) (res []string, } res = append(res, repo) } - return res, nil + return res, rows.Err() } // OAuth2 token methods -func (d *DB) GetToken(identifier string) (*oauth2.Token, error) { +func (d *DB) GetToken(ctx context.Context, identifier string) (*oauth2.Token, error) { var token oauth2.Token - row := d.QueryRow(`SELECT access_token, token_type + row := d.QueryRowContext(ctx, `SELECT access_token, token_type FROM oauth WHERE identifier = ?`, identifier) err := row.Scan(&token.AccessToken, &token.TokenType) @@ -160,24 +146,18 @@ func (d *DB) GetToken(identifier string) (*oauth2.Token, error) { } } -func (d *DB) PutToken(identifier string, token *oauth2.Token) error { - err := d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(`INSERT INTO oauth +func (d *DB) PutToken(ctx context.Context, identifier string, token *oauth2.Token) error { + _, err := d.ExecContext(ctx, `INSERT INTO oauth (identifier, access_token, token_type, ctime, mtime) VALUES (?, ?, ?, NOW(), NOW()) ON DUPLICATE KEY UPDATE access_token=VALUES(access_token), mtime=VALUES(mtime) `, identifier, token.AccessToken, token.TokenType) - return err - }) return err } -func (d *DB) DeleteToken(identifier string) error { - err := d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec("DELETE FROM oauth WHERE identifier = ?", identifier) - return err - }) +func (d *DB) DeleteToken(ctx context.Context, identifier string) error { + _, err := d.ExecContext(ctx, "DELETE FROM oauth WHERE identifier = ?", identifier) return err } diff --git a/gitlabbot/gitlabbot/handler.go b/gitlabbot/gitlabbot/handler.go index 216a65a7..033afd0e 100644 --- a/gitlabbot/gitlabbot/handler.go +++ b/gitlabbot/gitlabbot/handler.go @@ -1,6 +1,7 @@ package gitlabbot import ( + "context" "fmt" "strings" @@ -39,16 +40,16 @@ func NewHandler( } } -func (h *Handler) HandleNewConv(conv chat1.ConvSummary) error { +func (h *Handler) HandleNewConv(_ context.Context, conv chat1.ConvSummary) error { welcomeMsg := "Hi! I can notify you whenever something happens on a GitLab repository. To get started, set up a repository by sending `!gitlab subscribe `" return base.HandleNewTeam(h.stats, h.DebugOutput, h.kbc, conv, welcomeMsg) } -func (h *Handler) HandleAuth(msg chat1.MsgSummary, _ string) error { - return h.HandleCommand(msg) +func (h *Handler) HandleAuth(ctx context.Context, msg chat1.MsgSummary, _ string) error { + return h.HandleCommand(ctx, msg) } -func (h *Handler) HandleCommand(msg chat1.MsgSummary) error { +func (h *Handler) HandleCommand(ctx context.Context, msg chat1.MsgSummary) error { if msg.Content.Text == nil { return nil } @@ -61,18 +62,18 @@ func (h *Handler) HandleCommand(msg chat1.MsgSummary) error { switch { case strings.HasPrefix(cmd, "!gitlab subscribe"): h.stats.Count("subscribe") - return h.handleSubscribe(cmd, msg, true) + return h.handleSubscribe(ctx, cmd, msg, true) case strings.HasPrefix(cmd, "!gitlab unsubscribe"): h.stats.Count("unsubscribe") - return h.handleSubscribe(cmd, msg, false) + return h.handleSubscribe(ctx, cmd, msg, false) case strings.HasPrefix(cmd, "!gitlab list"): h.stats.Count("list") - return h.handleListSubscriptions(msg) + return h.handleListSubscriptions(ctx, msg) } return nil } -func (h *Handler) handleSubscribe(cmd string, msg chat1.MsgSummary, create bool) (err error) { +func (h *Handler) handleSubscribe(ctx context.Context, cmd string, msg chat1.MsgSummary, create bool) (err error) { toks, userErr, err := base.SplitTokens(cmd) if err != nil { return err @@ -93,14 +94,14 @@ func (h *Handler) handleSubscribe(cmd string, msg chat1.MsgSummary, create bool) return nil } - alreadyExists, err := h.db.GetSubscriptionForRepoExists(msg.ConvID, repo) + alreadyExists, err := h.db.GetSubscriptionForRepoExists(ctx, msg.ConvID, repo) if err != nil { return fmt.Errorf("error checking subscription: %s", err) } if create { if !alreadyExists { - err = h.db.CreateSubscription(msg.ConvID, repo, base.IdentifierFromMsg(msg)) + err = h.db.CreateSubscription(ctx, msg.ConvID, repo, base.IdentifierFromMsg(msg)) if err != nil { return fmt.Errorf("error creating subscription: %s", err) } @@ -119,7 +120,7 @@ func (h *Handler) handleSubscribe(cmd string, msg chat1.MsgSummary, create bool) } if alreadyExists { - err = h.db.DeleteSubscriptionsForRepo(msg.ConvID, repo) + err = h.db.DeleteSubscriptionsForRepo(ctx, msg.ConvID, repo) if err != nil { return fmt.Errorf("error deleting subscriptions: %s", err) } @@ -131,8 +132,8 @@ func (h *Handler) handleSubscribe(cmd string, msg chat1.MsgSummary, create bool) return nil } -func (h *Handler) handleListSubscriptions(msg chat1.MsgSummary) (err error) { - subscriptions, err := h.db.GetAllSubscriptionsForConvID(msg.ConvID) +func (h *Handler) handleListSubscriptions(ctx context.Context, msg chat1.MsgSummary) (err error) { + subscriptions, err := h.db.GetAllSubscriptionsForConvID(ctx, msg.ConvID) if err != nil { return fmt.Errorf("error getting current repos: %s", err) } diff --git a/gitlabbot/gitlabbot/http.go b/gitlabbot/gitlabbot/http.go index c3fd47ea..554b7e74 100644 --- a/gitlabbot/gitlabbot/http.go +++ b/gitlabbot/gitlabbot/http.go @@ -1,6 +1,7 @@ package gitlabbot import ( + "context" "fmt" "io" "net/http" @@ -116,8 +117,11 @@ func (h *HTTPSrv) handleWebhook(_ http.ResponseWriter, r *http.Request) { } repo = strings.ToLower(repo) signature := r.Header.Get("X-Gitlab-Token") + // WithoutCancel: GitLab sends the webhook and may close the connection + // immediately; DB reads and Keybase message sends must complete regardless. + ctx := context.WithoutCancel(r.Context()) - convs, err := h.db.GetSubscribedConvs(repo) + convs, err := h.db.GetSubscribedConvs(ctx, repo) if err != nil { h.Errorf("Error getting subscriptions for repo: %s", err) return diff --git a/macrobot/macrobot/db.go b/macrobot/macrobot/db.go index eaab01c5..f7888341 100644 --- a/macrobot/macrobot/db.go +++ b/macrobot/macrobot/db.go @@ -1,8 +1,8 @@ package macrobot import ( + "context" "database/sql" - "fmt" "github.com/keybase/go-keybase-chat-bot/kbchat/types/chat1" "github.com/keybase/managed-bots/base" @@ -18,35 +18,31 @@ func NewDB(db *sql.DB) *DB { } } -func (d *DB) Create(name string, convID chat1.ConvIDStr, isConv bool, macroName, macroMessage string) (created bool, err error) { - err = d.RunTxn(func(tx *sql.Tx) error { - if isConv { - name = string(convID) - } - res, err := tx.Exec(` - INSERT INTO macro - (channel_name, is_conv, macro_name, macro_message) - VALUES - (?, ?, ?, ?) - ON DUPLICATE KEY UPDATE - macro_message=VALUES(macro_message) - `, name, isConv, macroName, macroMessage) - if err != nil { - return err - } - numRows, err := res.RowsAffected() - if err != nil { - return err - } - // https://dev.mysql.com/doc/refman/5.7/en/insert-on-duplicate.html - created = numRows == 1 - return nil - }) - return created, err +func (d *DB) Create(ctx context.Context, name string, convID chat1.ConvIDStr, isConv bool, macroName, macroMessage string) (bool, error) { + if isConv { + name = string(convID) + } + res, err := d.ExecContext(ctx, ` + INSERT INTO macro + (channel_name, is_conv, macro_name, macro_message) + VALUES + (?, ?, ?, ?) + ON DUPLICATE KEY UPDATE + macro_message=VALUES(macro_message) + `, name, isConv, macroName, macroMessage) + if err != nil { + return false, err + } + numRows, err := res.RowsAffected() + if err != nil { + return false, err + } + // https://dev.mysql.com/doc/refman/5.7/en/insert-on-duplicate.html + return numRows == 1, nil } -func (d *DB) Get(name string, convID chat1.ConvIDStr, macroName string) (message string, err error) { - row := d.QueryRow(` +func (d *DB) Get(ctx context.Context, name string, convID chat1.ConvIDStr, macroName string) (message string, err error) { + row := d.QueryRowContext(ctx, ` SELECT macro_message FROM macro WHERE (channel_name = ? OR channel_name = ?) AND macro_name = ? @@ -64,8 +60,8 @@ type Macro struct { IsConv bool } -func (d *DB) List(name string, convID chat1.ConvIDStr) (list []Macro, err error) { - rows, err := d.Query(` +func (d *DB) List(ctx context.Context, name string, convID chat1.ConvIDStr) (list []Macro, err error) { + rows, err := d.QueryContext(ctx, ` SELECT macro_name, macro_message, is_conv FROM macro WHERE channel_name = ? @@ -76,11 +72,7 @@ func (d *DB) List(name string, convID chat1.ConvIDStr) (list []Macro, err error) if err != nil { return nil, err } - defer func() { - if cerr := rows.Close(); cerr != nil { - fmt.Printf("List: failed to close rows: %v\n", cerr) - } - }() + defer rows.Close() for rows.Next() { var macro Macro if err := rows.Scan(¯o.Name, ¯o.Message, ¯o.IsConv); err != nil { @@ -88,13 +80,13 @@ func (d *DB) List(name string, convID chat1.ConvIDStr) (list []Macro, err error) } list = append(list, macro) } - return list, nil + return list, rows.Err() } -func (d *DB) Remove(name string, convID chat1.ConvIDStr, macroName string) (removed bool, err error) { - err = d.RunTxn(func(tx *sql.Tx) error { +func (d *DB) Remove(ctx context.Context, name string, convID chat1.ConvIDStr, macroName string) (removed bool, err error) { + err = d.RunTxnContext(ctx, func(tx *sql.Tx) error { // First try to delete for the conv - res, err := tx.Exec(` + res, err := tx.ExecContext(ctx, ` DELETE FROM macro WHERE channel_name = ? AND macro_name = ? `, convID, macroName) @@ -109,7 +101,7 @@ func (d *DB) Remove(name string, convID chat1.ConvIDStr, macroName string) (remo return nil } // Now try teamwide - res, err = tx.Exec(` + res, err = tx.ExecContext(ctx, ` DELETE FROM macro WHERE channel_name = ? AND macro_name = ? `, name, macroName) diff --git a/macrobot/macrobot/handler.go b/macrobot/macrobot/handler.go index 399f16c3..46960566 100644 --- a/macrobot/macrobot/handler.go +++ b/macrobot/macrobot/handler.go @@ -1,6 +1,7 @@ package macrobot import ( + "context" "database/sql" "fmt" "strings" @@ -40,14 +41,14 @@ func NewHandler(stats *base.StatsRegistry, kbc *kbchat.API, debugConfig *base.Ch } } -func (h *Handler) HandleNewConv(conv chat1.ConvSummary) error { +func (h *Handler) HandleNewConv(ctx context.Context, conv chat1.ConvSummary) error { h.Lock() defer h.Unlock() // When we're put into a team conv, tell the team about the // `create-for-channel` option. if _, ok := h.newConvCache[conv.Channel.Name]; !ok && conv.Channel.MembersType == "team" { - if err := h.doPrivateAdvertisement(conv.Channel, conv.Id); err != nil { + if err := h.doPrivateAdvertisement(ctx, conv.Channel, conv.Id); err != nil { h.Errorf("unable to advertise on new conv: %v", err) } h.newConvCache[conv.Channel.Name] = struct{}{} @@ -63,7 +64,7 @@ func (h *Handler) HandleNewConv(conv chat1.ConvSummary) error { return base.HandleNewTeam(h.stats, h.DebugOutput, h.kbc, conv, welcomeMsg) } -func (h *Handler) HandleCommand(msg chat1.MsgSummary) error { +func (h *Handler) HandleCommand(ctx context.Context, msg chat1.MsgSummary) error { if msg.Content.Text == nil { return nil } @@ -84,21 +85,21 @@ func (h *Handler) HandleCommand(msg chat1.MsgSummary) error { switch { case strings.HasPrefix(cmd, "!macro create "): - return h.handleCreate(msg, false, tokens[2:]) + return h.handleCreate(ctx, msg, false, tokens[2:]) case strings.HasPrefix(cmd, "!macro create-for-channel"): - return h.handleCreate(msg, true, tokens[2:]) + return h.handleCreate(ctx, msg, true, tokens[2:]) case strings.HasPrefix(cmd, "!macro list"): - return h.handleList(msg) + return h.handleList(ctx, msg) case strings.HasPrefix(cmd, "!macro remove"): - return h.handleRemove(msg, tokens[2:]) + return h.handleRemove(ctx, msg, tokens[2:]) default: - return h.handleRun(msg, tokens) + return h.handleRun(ctx, msg, tokens) } } -func (h *Handler) handleRun(msg chat1.MsgSummary, args []string) error { +func (h *Handler) handleRun(ctx context.Context, msg chat1.MsgSummary, args []string) error { macroName := strings.TrimPrefix(args[0], "!") - macroMessage, err := h.db.Get(msg.Channel.Name, msg.ConvID, macroName) + macroMessage, err := h.db.Get(ctx, msg.Channel.Name, msg.ConvID, macroName) switch err { case nil: case sql.ErrNoRows: @@ -111,7 +112,7 @@ func (h *Handler) handleRun(msg chat1.MsgSummary, args []string) error { return nil } -func (h *Handler) handleCreate(msg chat1.MsgSummary, forceConv bool, args []string) error { +func (h *Handler) handleCreate(ctx context.Context, msg chat1.MsgSummary, forceConv bool, args []string) error { if len(args) != 2 { h.ChatEcho(msg.ConvID, "Invalid number of arguments. Expected two: ") return nil @@ -137,12 +138,12 @@ func (h *Handler) handleCreate(msg chat1.MsgSummary, forceConv bool, args []stri // non-team conversations always get a conv type advertisement. Teams have // the option of registering a per team or per channel macro. isConv := msg.Channel.MembersType != "team" || forceConv - created, err := h.db.Create(msg.Channel.Name, msg.ConvID, isConv, macroName, macroMessage) + created, err := h.db.Create(ctx, msg.Channel.Name, msg.ConvID, isConv, macroName, macroMessage) if err != nil { return err } - if err = h.doPrivateAdvertisement(msg.Channel, msg.ConvID); err != nil { + if err = h.doPrivateAdvertisement(ctx, msg.Channel, msg.ConvID); err != nil { return err } if created { @@ -153,8 +154,8 @@ func (h *Handler) handleCreate(msg chat1.MsgSummary, forceConv bool, args []stri return nil } -func (h *Handler) handleList(msg chat1.MsgSummary) error { - macroList, err := h.db.List(msg.Channel.Name, msg.ConvID) +func (h *Handler) handleList(ctx context.Context, msg chat1.MsgSummary) error { + macroList, err := h.db.List(ctx, msg.Channel.Name, msg.ConvID) if err != nil { return err } else if len(macroList) == 0 { @@ -187,7 +188,7 @@ func (h *Handler) handleList(msg chat1.MsgSummary) error { return nil } -func (h *Handler) handleRemove(msg chat1.MsgSummary, args []string) error { +func (h *Handler) handleRemove(ctx context.Context, msg chat1.MsgSummary, args []string) error { if len(args) != 1 { h.ChatEcho(msg.ConvID, "Invalid number of arguments. Expected one: ") return nil @@ -202,12 +203,12 @@ func (h *Handler) handleRemove(msg chat1.MsgSummary, args []string) error { } macroName := args[0] - removed, err := h.db.Remove(msg.Channel.Name, msg.ConvID, macroName) + removed, err := h.db.Remove(ctx, msg.Channel.Name, msg.ConvID, macroName) if err != nil { return err } - if err = h.doPrivateAdvertisement(msg.Channel, msg.ConvID); err != nil { + if err = h.doPrivateAdvertisement(ctx, msg.Channel, msg.ConvID); err != nil { return err } @@ -219,8 +220,8 @@ func (h *Handler) handleRemove(msg chat1.MsgSummary, args []string) error { return nil } -func (h *Handler) doPrivateAdvertisement(channel chat1.ChatChannel, convID chat1.ConvIDStr) error { - macroList, err := h.db.List(channel.Name, convID) +func (h *Handler) doPrivateAdvertisement(ctx context.Context, channel chat1.ChatChannel, convID chat1.ConvIDStr) error { + macroList, err := h.db.List(ctx, channel.Name, convID) if err != nil { return err } diff --git a/meetbot/meetbot/handler.go b/meetbot/meetbot/handler.go index 084dccca..b6e7ecd0 100644 --- a/meetbot/meetbot/handler.go +++ b/meetbot/meetbot/handler.go @@ -41,16 +41,16 @@ func NewHandler( } } -func (h *Handler) HandleNewConv(conv chat1.ConvSummary) error { +func (h *Handler) HandleNewConv(_ context.Context, conv chat1.ConvSummary) error { welcomeMsg := "Hello! I can get you set up with a Google Meet video call anytime, just send me `!meet`." return base.HandleNewTeam(h.stats, h.DebugOutput, h.kbc, conv, welcomeMsg) } -func (h *Handler) HandleAuth(msg chat1.MsgSummary, _ string) error { - return h.HandleCommand(msg) +func (h *Handler) HandleAuth(ctx context.Context, msg chat1.MsgSummary, _ string) error { + return h.HandleCommand(ctx, msg) } -func (h *Handler) HandleCommand(msg chat1.MsgSummary) error { +func (h *Handler) HandleCommand(ctx context.Context, msg chat1.MsgSummary) error { if msg.Content.Text == nil { return nil } @@ -58,20 +58,20 @@ func (h *Handler) HandleCommand(msg chat1.MsgSummary) error { cmd := strings.TrimSpace(msg.Content.Text.Body) if strings.HasPrefix(cmd, "!meet") { h.stats.Count("meet") - return h.meetHandler(msg) + return h.meetHandler(ctx, msg) } return nil } -func (h *Handler) meetHandler(msg chat1.MsgSummary) error { +func (h *Handler) meetHandler(ctx context.Context, msg chat1.MsgSummary) error { retry := func() error { // retry auth after nuking stored credentials - if err := h.db.DeleteToken(base.IdentifierFromMsg(msg)); err != nil { + if err := h.db.DeleteToken(ctx, base.IdentifierFromMsg(msg)); err != nil { return err } - return h.meetHandlerInner(msg) + return h.meetHandlerInner(ctx, msg) } - err := h.meetHandlerInner(msg) + err := h.meetHandlerInner(ctx, msg) switch err.(type) { case nil, base.OAuthRequiredError: return nil @@ -88,9 +88,9 @@ func (h *Handler) meetHandler(msg chat1.MsgSummary) error { } } -func (h *Handler) meetHandlerInner(msg chat1.MsgSummary) error { +func (h *Handler) meetHandlerInner(ctx context.Context, msg chat1.MsgSummary) error { identifier := base.IdentifierFromMsg(msg) - client, err := base.GetOAuthClient(identifier, msg, h.kbc, h.config, h.db, + client, err := base.GetOAuthClient(ctx, identifier, msg, h.kbc, h.config, h.db, base.GetOAuthOpts{ AuthMessageTemplate: "Authorize me by clicking this link:\n%s", OAuthOfflineAccessType: true, @@ -101,7 +101,7 @@ func (h *Handler) meetHandlerInner(msg chat1.MsgSummary) error { h.Errorf("unable to get oauth client: %q", identifier) } - srv, err := calendar.NewService(context.Background(), option.WithHTTPClient(client)) + srv, err := calendar.NewService(ctx, option.WithHTTPClient(client)) if err != nil { return err } diff --git a/pollbot/pollbot/db.go b/pollbot/pollbot/db.go index 2c437e0a..d6b32a5a 100644 --- a/pollbot/pollbot/db.go +++ b/pollbot/pollbot/db.go @@ -1,8 +1,8 @@ package pollbot import ( + "context" "database/sql" - "fmt" "github.com/keybase/go-keybase-chat-bot/kbchat/types/chat1" "github.com/keybase/managed-bots/base" @@ -25,20 +25,18 @@ func NewDB(db *sql.DB) *DB { } } -func (d *DB) CreatePoll(id string, convID chat1.ConvIDStr, msgID chat1.MessageID, resultMsgID chat1.MessageID, numChoices int) error { - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - INSERT INTO polls - (id, conv_id, msg_id, result_msg_id, choices) - VALUES - (?, ?, ?, ?, ?) - `, id, convID, msgID, resultMsgID, numChoices) - return err - }) +func (d *DB) CreatePoll(ctx context.Context, id string, convID chat1.ConvIDStr, msgID chat1.MessageID, resultMsgID chat1.MessageID, numChoices int) error { + _, err := d.ExecContext(ctx, ` + INSERT INTO polls + (id, conv_id, msg_id, result_msg_id, choices) + VALUES + (?, ?, ?, ?, ?) + `, id, convID, msgID, resultMsgID, numChoices) + return err } -func (d *DB) GetPollInfo(id string) (convID chat1.ConvIDStr, resultMsgID chat1.MessageID, numChoices int, err error) { - row := d.QueryRow(` +func (d *DB) GetPollInfo(ctx context.Context, id string) (convID chat1.ConvIDStr, resultMsgID chat1.MessageID, numChoices int, err error) { + row := d.QueryRowContext(ctx, ` SELECT conv_id, result_msg_id, choices FROM polls WHERE id = ? @@ -49,8 +47,8 @@ func (d *DB) GetPollInfo(id string) (convID chat1.ConvIDStr, resultMsgID chat1.M return convID, resultMsgID, numChoices, nil } -func (d *DB) GetTally(id string) (res Tally, err error) { - rows, err := d.Query(` +func (d *DB) GetTally(ctx context.Context, id string) (res Tally, err error) { + rows, err := d.QueryContext(ctx, ` SELECT choice, count(*) FROM votes WHERE id = ? @@ -59,11 +57,7 @@ func (d *DB) GetTally(id string) (res Tally, err error) { if err != nil { return res, err } - defer func() { - if cerr := rows.Close(); cerr != nil { - fmt.Printf("GetTally: failed to close rows: %v\n", cerr) - } - }() + defer rows.Close() for rows.Next() { var tres TallyResult if err := rows.Scan(&tres.choice, &tres.votes); err != nil { @@ -71,17 +65,15 @@ func (d *DB) GetTally(id string) (res Tally, err error) { } res = append(res, tres) } - return res, nil + return res, rows.Err() } -func (d *DB) CastVote(username string, vote Vote) error { - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - REPLACE INTO votes - (id, username, choice) - VALUES - (?, ?, ?) - `, vote.ID, username, vote.Choice) - return err - }) +func (d *DB) CastVote(ctx context.Context, username string, vote Vote) error { + _, err := d.ExecContext(ctx, ` + REPLACE INTO votes + (id, username, choice) + VALUES + (?, ?, ?) + `, vote.ID, username, vote.Choice) + return err } diff --git a/pollbot/pollbot/handler.go b/pollbot/pollbot/handler.go index f8446954..546b29e1 100644 --- a/pollbot/pollbot/handler.go +++ b/pollbot/pollbot/handler.go @@ -1,6 +1,7 @@ package pollbot import ( + "context" "flag" "fmt" "net/url" @@ -47,7 +48,7 @@ func (h *Handler) generateVoteLink(id string, choice int) string { return strings.ReplaceAll(link, "%", "%%") } -func (h *Handler) generateAnonymousPoll(convID chat1.ConvIDStr, prompt string, options []string) error { +func (h *Handler) generateAnonymousPoll(ctx context.Context, convID chat1.ConvIDStr, prompt string, options []string) error { id := base.RandHexString(8) promptBody := fmt.Sprintf("Anonymous Poll: *%s*\n\n", prompt) sendRes, err := h.kbc.SendMessageByConvID(convID, "%s", promptBody) @@ -71,7 +72,7 @@ func (h *Handler) generateAnonymousPoll(convID chat1.ConvIDStr, prompt string, o return fmt.Errorf("failed to get ID of result message") } resultMsgID := *sendRes.Result.MessageID - if err := h.db.CreatePoll(id, convID, promptMsgID, resultMsgID, len(options)); err != nil { + if err := h.db.CreatePoll(ctx, id, convID, promptMsgID, resultMsgID, len(options)); err != nil { return fmt.Errorf("failed to create poll: %s", err) } return nil @@ -100,7 +101,7 @@ func (h *Handler) generatePoll(convID chat1.ConvIDStr, prompt string, options [] return nil } -func (h *Handler) handlePoll(cmd string, convID chat1.ConvIDStr) error { +func (h *Handler) handlePoll(ctx context.Context, cmd string, convID chat1.ConvIDStr) error { cmd = strings.ReplaceAll(cmd, "‘", "'") cmd = strings.ReplaceAll(cmd, "’", "'") cmd = strings.ReplaceAll(cmd, "“", "\"") @@ -128,7 +129,7 @@ func (h *Handler) handlePoll(cmd string, convID chat1.ConvIDStr) error { h.stats.Count("handlePoll") if anonymous { h.stats.Count("handlePoll - anonymous") - return h.generateAnonymousPoll(convID, prompt, args[1:]) + return h.generateAnonymousPoll(ctx, convID, prompt, args[1:]) } return h.generatePoll(convID, prompt, args[1:]) } @@ -151,19 +152,19 @@ To login your web browser in order to vote in anonymous polls, please follow the } } -func (h *Handler) HandleNewConv(conv chat1.ConvSummary) error { +func (h *Handler) HandleNewConv(_ context.Context, conv chat1.ConvSummary) error { welcomeMsg := "Find out the answers to the hardest questions. Try `!poll 'Should we move the office to a beach?' Yes No`" return base.HandleNewTeam(h.stats, h.DebugOutput, h.kbc, conv, welcomeMsg) } -func (h *Handler) HandleCommand(msg chat1.MsgSummary) error { +func (h *Handler) HandleCommand(ctx context.Context, msg chat1.MsgSummary) error { if msg.Content.Text == nil { return nil } cmd := strings.TrimSpace(msg.Content.Text.Body) switch { case strings.HasPrefix(cmd, "!poll"): - return h.handlePoll(cmd, msg.ConvID) + return h.handlePoll(ctx, cmd, msg.ConvID) case strings.ToLower(cmd) == "login": h.handleLogin(msg.Channel.Name, msg.Sender.Username) } diff --git a/pollbot/pollbot/http.go b/pollbot/pollbot/http.go index c3a16679..da9776d7 100644 --- a/pollbot/pollbot/http.go +++ b/pollbot/pollbot/http.go @@ -2,6 +2,7 @@ package pollbot import ( "bytes" + "context" "crypto/hmac" "crypto/sha256" "encoding/base64" @@ -90,18 +91,21 @@ func (h *HTTPSrv) handleVote(w http.ResponseWriter, r *http.Request) { } vstr := r.URL.Query().Get("") vote := NewVoteFromEncoded(vstr) - if err := h.db.CastVote(username, vote); err != nil { + // WithoutCancel: the browser submits the vote and may close the connection; + // DB writes and poll result updates must complete regardless. + ctx := context.WithoutCancel(r.Context()) + if err := h.db.CastVote(ctx, username, vote); err != nil { h.Errorf("failed to cast vote: %s", err) h.showError(w) return } - convID, resultMsgID, numChoices, err := h.db.GetPollInfo(vote.ID) + convID, resultMsgID, numChoices, err := h.db.GetPollInfo(ctx, vote.ID) if err != nil { h.Errorf("failed to find poll result msg: %s", err) h.showError(w) return } - tally, err := h.db.GetTally(vote.ID) + tally, err := h.db.GetTally(ctx, vote.ID) if err != nil { h.Errorf("failed to get tally: %s", err) h.showError(w) diff --git a/triviabot/triviabot/db.go b/triviabot/triviabot/db.go index b23789d5..94a353d3 100644 --- a/triviabot/triviabot/db.go +++ b/triviabot/triviabot/db.go @@ -1,8 +1,8 @@ package triviabot import ( + "context" "database/sql" - "fmt" "github.com/keybase/go-keybase-chat-bot/kbchat/types/chat1" "github.com/keybase/managed-bots/base" @@ -18,25 +18,21 @@ func NewDB(db *sql.DB) *DB { } } -func (d *DB) RecordAnswer(convID chat1.ConvIDStr, username string, pointAdjust int, isCorrect bool) error { - return d.RunTxn(func(tx *sql.Tx) error { - correct := 0 - incorrect := 0 - if isCorrect { - correct = 1 - } else { - incorrect = 1 - } - if _, err := tx.Exec(` - INSERT INTO leaderboard (conv_id, username, points, correct, incorrect) - VALUES (?, ?, ?, ?, ?) - ON DUPLICATE KEY UPDATE points=points+VALUES(points),correct=correct+VALUES(correct), - incorrect=incorrect+VALUES(incorrect) - `, base.ShortConvID(convID), username, pointAdjust, correct, incorrect); err != nil { - return err - } - return nil - }) +func (d *DB) RecordAnswer(ctx context.Context, convID chat1.ConvIDStr, username string, pointAdjust int, isCorrect bool) error { + correct := 0 + incorrect := 0 + if isCorrect { + correct = 1 + } else { + incorrect = 1 + } + _, err := d.ExecContext(ctx, ` + INSERT INTO leaderboard (conv_id, username, points, correct, incorrect) + VALUES (?, ?, ?, ?, ?) + ON DUPLICATE KEY UPDATE points=points+VALUES(points),correct=correct+VALUES(correct), + incorrect=incorrect+VALUES(incorrect) + `, base.ShortConvID(convID), username, pointAdjust, correct, incorrect) + return err } type TopUser struct { @@ -46,8 +42,8 @@ type TopUser struct { Incorrect int } -func (d *DB) TopUsers(convID chat1.ConvIDStr) (res []TopUser, err error) { - rows, err := d.Query(` +func (d *DB) TopUsers(ctx context.Context, convID chat1.ConvIDStr) (res []TopUser, err error) { + rows, err := d.QueryContext(ctx, ` SELECT username, points, correct, incorrect FROM leaderboard WHERE conv_id = ? @@ -57,11 +53,7 @@ func (d *DB) TopUsers(convID chat1.ConvIDStr) (res []TopUser, err error) { if err != nil { return res, err } - defer func() { - if cerr := rows.Close(); cerr != nil { - fmt.Printf("TopUsers: failed to close rows: %v\n", cerr) - } - }() + defer rows.Close() for rows.Next() { var user TopUser if err := rows.Scan(&user.Username, &user.Points, &user.Correct, &user.Incorrect); err != nil { @@ -69,22 +61,16 @@ func (d *DB) TopUsers(convID chat1.ConvIDStr) (res []TopUser, err error) { } res = append(res, user) } - return res, nil + return res, rows.Err() } -func (d *DB) ResetConv(convID chat1.ConvIDStr) error { - return d.RunTxn(func(tx *sql.Tx) error { - if _, err := tx.Exec(` - DELETE FROM leaderboard WHERE conv_id = ? - `, base.ShortConvID(convID)); err != nil { - return err - } - return nil - }) +func (d *DB) ResetConv(ctx context.Context, convID chat1.ConvIDStr) error { + _, err := d.ExecContext(ctx, `DELETE FROM leaderboard WHERE conv_id = ?`, base.ShortConvID(convID)) + return err } -func (d *DB) GetAPIToken(convID chat1.ConvIDStr) (res string, err error) { - row := d.QueryRow(` +func (d *DB) GetAPIToken(ctx context.Context, convID chat1.ConvIDStr) (res string, err error) { + row := d.QueryRowContext(ctx, ` SELECT token FROM tokens where conv_id = ? `, base.ShortConvID(convID)) if err := row.Scan(&res); err != nil { @@ -93,13 +79,7 @@ func (d *DB) GetAPIToken(convID chat1.ConvIDStr) (res string, err error) { return res, nil } -func (d *DB) SetAPIToken(convID chat1.ConvIDStr, token string) error { - return d.RunTxn(func(tx *sql.Tx) error { - if _, err := tx.Exec(` - REPLACE INTO tokens (conv_id, token) VALUES (?, ?) - `, base.ShortConvID(convID), token); err != nil { - return err - } - return nil - }) +func (d *DB) SetAPIToken(ctx context.Context, convID chat1.ConvIDStr, token string) error { + _, err := d.ExecContext(ctx, `REPLACE INTO tokens (conv_id, token) VALUES (?, ?)`, base.ShortConvID(convID), token) + return err } diff --git a/triviabot/triviabot/handler.go b/triviabot/triviabot/handler.go index 422f2f78..79a3e16a 100644 --- a/triviabot/triviabot/handler.go +++ b/triviabot/triviabot/handler.go @@ -1,6 +1,7 @@ package triviabot import ( + "context" "fmt" "strings" "sync" @@ -44,7 +45,8 @@ func (h *Handler) handleStart(msg chat1.MsgSummary) { base.GoWithRecover(h.DebugOutput, func() { <-doneCb h.ChatEcho(convID, "Session complete, here are the top players") - err := h.handleTop(convID) + // context.Background() because this goroutine outlives the original request context + err := h.handleTop(context.Background(), convID) if err != nil { h.ChatErrorf(msg.ConvID, "%s", err.Error()) } @@ -65,8 +67,8 @@ func (h *Handler) handleStop(msg chat1.MsgSummary) { h.ChatEcho(convID, "Session stopped") } -func (h *Handler) handleTop(convID chat1.ConvIDStr) error { - users, err := h.db.TopUsers(convID) +func (h *Handler) handleTop(ctx context.Context, convID chat1.ConvIDStr) error { + users, err := h.db.TopUsers(ctx, convID) if err != nil { return fmt.Errorf("handleTop: failed to get top users: %s", err) } @@ -82,9 +84,9 @@ func (h *Handler) handleTop(convID chat1.ConvIDStr) error { return nil } -func (h *Handler) handleReset(msg chat1.MsgSummary) error { +func (h *Handler) handleReset(ctx context.Context, msg chat1.MsgSummary) error { convID := msg.ConvID - if err := h.db.ResetConv(convID); err != nil { + if err := h.db.ResetConv(ctx, convID); err != nil { return fmt.Errorf("handleReset: failed to reset: %s", err) } h.ChatEcho(convID, "Leaderboard reset") @@ -106,12 +108,12 @@ func (h *Handler) handleAnswer(convID chat1.ConvIDStr, reaction chat1.MessageRea } } -func (h *Handler) HandleNewConv(conv chat1.ConvSummary) error { +func (h *Handler) HandleNewConv(_ context.Context, conv chat1.ConvSummary) error { welcomeMsg := "Are you up to the challenge? Try `!trivia begin` to find out." return base.HandleNewTeam(h.stats, h.DebugOutput, h.kbc, conv, welcomeMsg) } -func (h *Handler) HandleCommand(msg chat1.MsgSummary) error { +func (h *Handler) HandleCommand(ctx context.Context, msg chat1.MsgSummary) error { if msg.Content.Reaction != nil && msg.Sender.Username != h.kbc.GetUsername() { h.handleAnswer(msg.ConvID, *msg.Content.Reaction, msg.Sender.Username) return nil @@ -129,10 +131,10 @@ func (h *Handler) HandleCommand(msg chat1.MsgSummary) error { h.handleStop(msg) case strings.HasPrefix(cmd, "!trivia top"): h.stats.Count("top") - return h.handleTop(msg.ConvID) + return h.handleTop(ctx, msg.ConvID) case strings.HasPrefix(cmd, "!trivia reset"): h.stats.Count("reset") - return h.handleReset(msg) + return h.handleReset(ctx, msg) } return nil } diff --git a/triviabot/triviabot/session.go b/triviabot/triviabot/session.go index 626a45e0..20abfe0d 100644 --- a/triviabot/triviabot/session.go +++ b/triviabot/triviabot/session.go @@ -1,6 +1,7 @@ package triviabot import ( + "context" "encoding/json" "errors" "fmt" @@ -144,7 +145,7 @@ var errForceAPI = errors.New("API token fetch requested") func (s *session) getToken(forceAPI bool) (token string, err error) { if !forceAPI { - token, err = s.db.GetAPIToken(s.convID) + token, err = s.db.GetAPIToken(context.Background(), s.convID) } else { err = errForceAPI } @@ -154,7 +155,7 @@ func (s *session) getToken(forceAPI bool) (token string, err error) { s.ChatErrorf(s.convID, "getToken: failed to get token from API: %s", err) return "", err } - if err := s.db.SetAPIToken(s.convID, token); err != nil { + if err := s.db.SetAPIToken(context.Background(), s.convID, token); err != nil { s.Errorf("getToken: failed to set token in DB: %s", err) } } else { @@ -318,7 +319,7 @@ func (s *session) waitForCorrectAnswer() { continue } isCorrect, pointAdjust := s.getAnswerPoints(answer, *s.curQuestion) - if err := s.db.RecordAnswer(s.convID, answer.username, pointAdjust, isCorrect); err != nil { + if err := s.db.RecordAnswer(context.Background(), s.convID, answer.username, pointAdjust, isCorrect); err != nil { s.Errorf("waitForCorrectAnswer: failed to record answer: %s", err) } s.regDupe(answer.username) diff --git a/webhookbot/webhookbot/db.go b/webhookbot/webhookbot/db.go index f9fd2c1f..3ffc4bb5 100644 --- a/webhookbot/webhookbot/db.go +++ b/webhookbot/webhookbot/db.go @@ -1,11 +1,11 @@ package webhookbot import ( + "context" "crypto/hmac" "crypto/sha256" "database/sql" "encoding/hex" - "fmt" "github.com/keybase/go-keybase-chat-bot/kbchat/types/chat1" "github.com/keybase/managed-bots/base" @@ -36,27 +36,22 @@ func (d *DB) makeID(name string, convID chat1.ConvIDStr) (string, error) { return base.URLEncoder().EncodeToString(h.Sum(nil)[:20]), nil } -func (d *DB) Create(name string, convID chat1.ConvIDStr) (string, error) { +func (d *DB) Create(ctx context.Context, name string, convID chat1.ConvIDStr) (string, error) { id, err := d.makeID(name, convID) if err != nil { return "", err } - err = d.RunTxn(func(tx *sql.Tx) error { - if _, err := tx.Exec(` - INSERT INTO hooks - (id, name, conv_id) - VALUES - (?, ?, ?) - `, id, name, convID); err != nil { - return err - } - return nil - }) + _, err = d.ExecContext(ctx, ` + INSERT INTO hooks + (id, name, conv_id) + VALUES + (?, ?, ?) + `, id, name, convID) return id, err } -func (d *DB) GetHook(id string) (res Webhook, err error) { - row := d.QueryRow(` +func (d *DB) GetHook(ctx context.Context, id string) (res Webhook, err error) { + row := d.QueryRowContext(ctx, ` SELECT conv_id, name FROM hooks WHERE id = ? `, id) if err := row.Scan(&res.ConvID, &res.Name); err != nil { @@ -71,18 +66,14 @@ type Webhook struct { Name string } -func (d *DB) List(convID chat1.ConvIDStr) (res []Webhook, err error) { - rows, err := d.Query(` +func (d *DB) List(ctx context.Context, convID chat1.ConvIDStr) (res []Webhook, err error) { + rows, err := d.QueryContext(ctx, ` SELECT id, name FROM hooks WHERE conv_id = ? `, convID) if err != nil { return nil, err } - defer func() { - if cerr := rows.Close(); cerr != nil { - fmt.Printf("List: failed to close rows: %v\n", cerr) - } - }() + defer rows.Close() for rows.Next() { var hook Webhook hook.ConvID = convID @@ -91,14 +82,10 @@ func (d *DB) List(convID chat1.ConvIDStr) (res []Webhook, err error) { } res = append(res, hook) } - return res, nil + return res, rows.Err() } -func (d *DB) Remove(name string, convID chat1.ConvIDStr) error { - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - DELETE FROM hooks WHERE conv_id = ? AND name = ? - `, convID, name) - return err - }) +func (d *DB) Remove(ctx context.Context, name string, convID chat1.ConvIDStr) error { + _, err := d.ExecContext(ctx, `DELETE FROM hooks WHERE conv_id = ? AND name = ?`, convID, name) + return err } diff --git a/webhookbot/webhookbot/handler.go b/webhookbot/webhookbot/handler.go index 153a2fe5..d22acdee 100644 --- a/webhookbot/webhookbot/handler.go +++ b/webhookbot/webhookbot/handler.go @@ -1,6 +1,7 @@ package webhookbot import ( + "context" "errors" "fmt" "strings" @@ -50,7 +51,7 @@ func (h *Handler) checkAllowed(msg chat1.MsgSummary) error { return nil } -func (h *Handler) handleRemove(cmd string, msg chat1.MsgSummary) (err error) { +func (h *Handler) handleRemove(ctx context.Context, cmd string, msg chat1.MsgSummary) (err error) { convID := msg.ConvID toks := strings.Split(cmd, " ") if len(toks) != 3 { @@ -68,16 +69,16 @@ func (h *Handler) handleRemove(cmd string, msg chat1.MsgSummary) (err error) { } h.stats.Count("remove") name := toks[2] - if err := h.db.Remove(name, convID); err != nil { + if err := h.db.Remove(ctx, name, convID); err != nil { return fmt.Errorf("handleRemove: failed to remove webhook: %s", err) } h.ChatEcho(convID, "Success!") return nil } -func (h *Handler) handleList(_ string, msg chat1.MsgSummary) (err error) { +func (h *Handler) handleList(ctx context.Context, _ string, msg chat1.MsgSummary) (err error) { convID := msg.ConvID - hooks, err := h.db.List(convID) + hooks, err := h.db.List(ctx, convID) if err != nil { return fmt.Errorf("handleList: failed to list hook: %s", err) } @@ -107,7 +108,7 @@ func (h *Handler) handleList(_ string, msg chat1.MsgSummary) (err error) { return nil } -func (h *Handler) handleCreate(cmd string, msg chat1.MsgSummary) (err error) { +func (h *Handler) handleCreate(ctx context.Context, cmd string, msg chat1.MsgSummary) (err error) { convID := msg.ConvID toks := strings.Split(cmd, " ") if len(toks) != 3 { @@ -126,7 +127,7 @@ func (h *Handler) handleCreate(cmd string, msg chat1.MsgSummary) (err error) { h.stats.Count("create") name := toks[2] - id, err := h.db.Create(name, convID) + id, err := h.db.Create(ctx, name, convID) if err != nil { return fmt.Errorf("handleCreate: failed to create webhook: %s", err) } @@ -137,23 +138,23 @@ func (h *Handler) handleCreate(cmd string, msg chat1.MsgSummary) (err error) { return nil } -func (h *Handler) HandleNewConv(conv chat1.ConvSummary) error { +func (h *Handler) HandleNewConv(_ context.Context, conv chat1.ConvSummary) error { welcomeMsg := "I can create generic webhooks into Keybase! Try `!webhook create` to get started." return base.HandleNewTeam(h.stats, h.DebugOutput, h.kbc, conv, welcomeMsg) } -func (h *Handler) HandleCommand(msg chat1.MsgSummary) error { +func (h *Handler) HandleCommand(ctx context.Context, msg chat1.MsgSummary) error { if msg.Content.Text == nil { return nil } cmd := strings.TrimSpace(msg.Content.Text.Body) switch { case strings.HasPrefix(cmd, "!webhook create"): - return h.handleCreate(cmd, msg) + return h.handleCreate(ctx, cmd, msg) case strings.HasPrefix(cmd, "!webhook list"): - return h.handleList(cmd, msg) + return h.handleList(ctx, cmd, msg) case strings.HasPrefix(cmd, "!webhook remove"): - return h.handleRemove(cmd, msg) + return h.handleRemove(ctx, cmd, msg) } return nil } diff --git a/webhookbot/webhookbot/http.go b/webhookbot/webhookbot/http.go index 6ebb301f..96c80438 100644 --- a/webhookbot/webhookbot/http.go +++ b/webhookbot/webhookbot/http.go @@ -2,6 +2,7 @@ package webhookbot import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -94,7 +95,11 @@ func (h *HTTPSrv) safeWriteToFile(hookName, content string) (string, error) { func (h *HTTPSrv) handleHook(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) id := vars["id"] - hook, err := h.db.GetHook(id) + // WithoutCancel: the caller may close the connection immediately after + // sending the webhook payload; DB reads and Keybase message sends must + // complete regardless. + ctx := context.WithoutCancel(r.Context()) + hook, err := h.db.GetHook(ctx, id) if err != nil { h.Stats.Count("handle - not found") h.Debug("handleHook: failed to find hook for ID: %s", id) diff --git a/zoombot/zoombot/db.go b/zoombot/zoombot/db.go index 7c72bbe2..aa392c7e 100644 --- a/zoombot/zoombot/db.go +++ b/zoombot/zoombot/db.go @@ -1,6 +1,7 @@ package zoombot import ( + "context" "database/sql" "github.com/keybase/managed-bots/base" @@ -16,27 +17,23 @@ func NewDB(db *sql.DB) *DB { } } -func (d *DB) CreateUser(userID, accountID, identifier string) error { - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - INSERT INTO user - (user_id, account_id, identifier) - VALUES (?, ?, ?) - ON DUPLICATE KEY UPDATE - identifier=identifier -- identifier stays the same - `, userID, accountID, identifier) - return err - }) +func (d *DB) CreateUser(ctx context.Context, userID, accountID, identifier string) error { + _, err := d.ExecContext(ctx, ` + INSERT INTO user + (user_id, account_id, identifier) + VALUES (?, ?, ?) + ON DUPLICATE KEY UPDATE + identifier=identifier -- identifier stays the same + `, userID, accountID, identifier) + return err } -func (d *DB) DeleteUserAndToken(userID, accountID string) error { - return d.RunTxn(func(tx *sql.Tx) error { - _, err := tx.Exec(` - DELETE user, oauth - FROM user - JOIN oauth USING(identifier) - WHERE user_id = ? AND account_id = ? - `, userID, accountID) - return err - }) +func (d *DB) DeleteUserAndToken(ctx context.Context, userID, accountID string) error { + _, err := d.ExecContext(ctx, ` + DELETE user, oauth + FROM user + JOIN oauth USING(identifier) + WHERE user_id = ? AND account_id = ? + `, userID, accountID) + return err } diff --git a/zoombot/zoombot/handler.go b/zoombot/zoombot/handler.go index e7248b00..bea588de 100644 --- a/zoombot/zoombot/handler.go +++ b/zoombot/zoombot/handler.go @@ -34,31 +34,31 @@ func NewHandler(stats *base.StatsRegistry, kbc *kbchat.API, debugConfig *base.Ch } } -func (h *Handler) HandleNewConv(conv chat1.ConvSummary) error { +func (h *Handler) HandleNewConv(_ context.Context, conv chat1.ConvSummary) error { welcomeMsg := "Hello! I can get you set up with a Zoom instant meeting anytime, just send me `!zoom`." return base.HandleNewTeam(h.stats, h.DebugOutput, h.kbc, conv, welcomeMsg) } -func (h *Handler) HandleAuth(msg chat1.MsgSummary, identifier string) error { - token, err := h.db.GetToken(identifier) +func (h *Handler) HandleAuth(ctx context.Context, msg chat1.MsgSummary, identifier string) error { + token, err := h.db.GetToken(ctx, identifier) if err != nil { return fmt.Errorf("error getting token: %s", err) } - client := h.config.Client(context.Background(), token) + client := h.config.Client(ctx, token) user, err := GetUser(client, currentUserID) if err != nil { return err } - err = h.db.CreateUser(user.ID, user.AccountID, identifier) + err = h.db.CreateUser(ctx, user.ID, user.AccountID, identifier) if err != nil { return fmt.Errorf("error creating user entry: %s", err) } - return h.HandleCommand(msg) + return h.HandleCommand(ctx, msg) } -func (h *Handler) HandleCommand(msg chat1.MsgSummary) error { +func (h *Handler) HandleCommand(ctx context.Context, msg chat1.MsgSummary) error { if msg.Content.Text == nil { return nil } @@ -66,21 +66,21 @@ func (h *Handler) HandleCommand(msg chat1.MsgSummary) error { cmd := strings.TrimSpace(msg.Content.Text.Body) if strings.HasPrefix(cmd, "!zoom") { h.stats.Count("zoom") - return h.zoomHandler(msg, 0) + return h.zoomHandler(ctx, msg, 0) } return nil } -func (h *Handler) zoomHandler(msg chat1.MsgSummary, attempts int) error { +func (h *Handler) zoomHandler(ctx context.Context, msg chat1.MsgSummary, attempts int) error { retry := func() error { // retry auth after nuking stored credentials - if err := h.db.DeleteToken(IdentifierFromMsg(msg)); err != nil { + if err := h.db.DeleteToken(ctx, IdentifierFromMsg(msg)); err != nil { return err } attempts++ - return h.zoomHandlerInner(msg, attempts) + return h.zoomHandlerInner(ctx, msg, attempts) } - err := h.zoomHandlerInner(msg, attempts) + err := h.zoomHandlerInner(ctx, msg, attempts) switch err := err.(type) { case nil, base.OAuthRequiredError: return nil @@ -100,9 +100,9 @@ func (h *Handler) zoomHandler(msg chat1.MsgSummary, attempts int) error { } } -func (h *Handler) zoomHandlerInner(msg chat1.MsgSummary, attempts int) error { +func (h *Handler) zoomHandlerInner(ctx context.Context, msg chat1.MsgSummary, attempts int) error { identifier := IdentifierFromMsg(msg) - client, err := base.GetOAuthClient(identifier, msg, h.kbc, h.config, h.db, + client, err := base.GetOAuthClient(ctx, identifier, msg, h.kbc, h.config, h.db, base.GetOAuthOpts{ AuthMessageTemplate: "Authorize me by clicking this link:\n%s", OAuthOfflineAccessType: true, @@ -127,7 +127,7 @@ func (h *Handler) zoomHandlerInner(msg chat1.MsgSummary, attempts int) error { attempts++ h.Debug("zoomHandlerInner: retrying attempt #%d: %v", attempts, err) time.Sleep(500 * time.Millisecond) - err := h.zoomHandler(msg, attempts) + err := h.zoomHandler(ctx, msg, attempts) switch err := err.(type) { case nil, base.OAuthRequiredError: default: diff --git a/zoombot/zoombot/http.go b/zoombot/zoombot/http.go index f55c72b7..1d456926 100644 --- a/zoombot/zoombot/http.go +++ b/zoombot/zoombot/http.go @@ -2,6 +2,7 @@ package zoombot import ( "bytes" + "context" "crypto/hmac" "crypto/sha256" "encoding/hex" @@ -111,7 +112,10 @@ func (h *HTTPSrv) zoomDeauthorize(w http.ResponseWriter, r *http.Request) { return } - err = h.db.DeleteUserAndToken(deauthorizationRequest.Payload.UserID, deauthorizationRequest.Payload.AccountID) + // WithoutCancel: Zoom sends the deauthorization webhook and may close the + // connection immediately; the token deletion must complete regardless. + ctx := context.WithoutCancel(r.Context()) + err = h.db.DeleteUserAndToken(ctx, deauthorizationRequest.Payload.UserID, deauthorizationRequest.Payload.AccountID) if err != nil { h.Errorf("zoomDeauthorize: unable to delete user: %s", err) http.Error(w, "unable to delete user", http.StatusBadRequest)