Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ linters:
- bodyclose # Checks HTTP response body closed

settings:
errcheck:
# defer rows.Close() / stmt.Close() etc. are idiomatic and safe to ignore.
exclude-functions:
- (*database/sql.Rows).Close
- (*database/sql.Stmt).Close
- (*database/sql.DB).Close
- (*database/sql.Tx).Close

gocritic:
disabled-checks:
- ifElseChain
Expand Down
77 changes: 48 additions & 29 deletions base/db.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package base

import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"time"

"github.com/go-sql-driver/mysql"
"golang.org/x/oauth2"
)

Expand All @@ -21,18 +25,46 @@ func NewDB(db *sql.DB) *DB {
}
}

func (d *DB) RunTxn(fn func(tx *sql.Tx) error) error {
tx, err := d.Begin()
func (d *DB) RunTxnContext(ctx context.Context, fn func(tx *sql.Tx) error) error {
tx, err := d.BeginTx(ctx, nil)
if err != nil {
return err
}
if err := fn(tx); err != nil {
if rerr := tx.Rollback(); rerr != nil {
if rerr := tx.Rollback(); rerr != nil && !errors.Is(rerr, sql.ErrTxDone) {
fmt.Printf("unable to rollback: %v", rerr)
}
return normalizeCtxErr(ctx, err)
}
return normalizeCommitErr(ctx, tx.Commit())
}

// normalizeCommitErr recovers the true cause of a failed Commit. When the
// context is cancelled, the driver rolls back the transaction and marks it
// done, so tx.Commit() returns sql.ErrTxDone instead of context.Canceled.
// Returning ctx.Err() lets callers distinguish a client abort from a real DB error.
func normalizeCommitErr(ctx context.Context, err error) error {
if err == nil {
return nil
}
if ctxErr := ctx.Err(); ctxErr != nil && errors.Is(err, sql.ErrTxDone) {
return ctxErr
}
return err
}

// normalizeCtxErr maps opaque driver errors (ErrBadConn, ErrInvalidConn) back
// to ctx.Err() when the context is already done. This lets callers distinguish
// a client abort from a real storage failure without checking at every call site.
func normalizeCtxErr(ctx context.Context, err error) error {
if err == nil || errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return err
}
return tx.Commit()
if ctxErr := ctx.Err(); ctxErr != nil &&
(errors.Is(err, driver.ErrBadConn) || errors.Is(err, mysql.ErrInvalidConn)) {
return ctxErr
}
return err
}

type BaseOAuthDB struct { //nolint
Expand All @@ -45,9 +77,9 @@ func NewBaseOAuthDB(db *sql.DB) *BaseOAuthDB {
}
}

func (d *BaseOAuthDB) GetState(state string) (*OAuthRequest, error) {
func (d *BaseOAuthDB) GetState(ctx context.Context, state string) (*OAuthRequest, error) {
var oauthState OAuthRequest
row := d.QueryRow(`SELECT identifier, conv_id, msg_id, is_complete
row := d.QueryRowContext(ctx, `SELECT identifier, conv_id, msg_id, is_complete
FROM oauth_state
WHERE state = ?`, state)
err := row.Scan(&oauthState.TokenIdentifier, &oauthState.ConvID,
Expand All @@ -62,28 +94,22 @@ func (d *BaseOAuthDB) GetState(state string) (*OAuthRequest, error) {
}
}

func (d *BaseOAuthDB) PutState(state string, oauthState *OAuthRequest) error {
err := d.RunTxn(func(tx *sql.Tx) error {
_, err := tx.Exec(`INSERT INTO oauth_state
func (d *BaseOAuthDB) PutState(ctx context.Context, state string, oauthState *OAuthRequest) error {
_, err := d.ExecContext(ctx, `INSERT INTO oauth_state
(state, identifier, conv_id, msg_id)
VALUES (?, ?, ?, ?)
ON DUPLICATE KEY UPDATE
identifier=VALUES(identifier),
conv_id=VALUES(conv_id),
msg_id=VALUES(msg_id)
`, state, oauthState.TokenIdentifier, oauthState.ConvID, oauthState.MsgID)
return err
})
return err
}

func (d *BaseOAuthDB) CompleteState(state string) error {
err := d.RunTxn(func(tx *sql.Tx) error {
_, err := tx.Exec(`UPDATE oauth_state
func (d *BaseOAuthDB) CompleteState(ctx context.Context, state string) error {
_, err := d.ExecContext(ctx, `UPDATE oauth_state
SET is_complete=true
WHERE state = ?`, state)
return err
})
return err
}

Expand All @@ -97,10 +123,10 @@ func NewOAuthDB(db *sql.DB) *OAuthDB {
}
}

func (d *OAuthDB) GetToken(identifier string) (*oauth2.Token, error) {
func (d *OAuthDB) GetToken(ctx context.Context, identifier string) (*oauth2.Token, error) {
var token oauth2.Token
var expiry int64
row := d.QueryRow(`SELECT access_token, token_type, refresh_token, ROUND(UNIX_TIMESTAMP(expiry))
row := d.QueryRowContext(ctx, `SELECT access_token, token_type, refresh_token, ROUND(UNIX_TIMESTAMP(expiry))
FROM oauth
WHERE identifier = ?`, identifier)
err := row.Scan(&token.AccessToken, &token.TokenType,
Expand All @@ -116,9 +142,8 @@ func (d *OAuthDB) GetToken(identifier string) (*oauth2.Token, error) {
}
}

func (d *OAuthDB) PutToken(identifier string, token *oauth2.Token) error {
err := d.RunTxn(func(tx *sql.Tx) error {
_, err := tx.Exec(`INSERT INTO oauth
func (d *OAuthDB) PutToken(ctx context.Context, identifier string, token *oauth2.Token) error {
_, err := d.ExecContext(ctx, `INSERT INTO oauth
(identifier, access_token, token_type, refresh_token, expiry, ctime, mtime)
VALUES (?, ?, ?, ?, ?, NOW(), NOW())
ON DUPLICATE KEY UPDATE
Expand All @@ -127,16 +152,10 @@ func (d *OAuthDB) PutToken(identifier string, token *oauth2.Token) error {
expiry=VALUES(expiry),
mtime=VALUES(mtime)
`, identifier, token.AccessToken, token.TokenType, token.RefreshToken, token.Expiry)
return err
})
return err
}

func (d *OAuthDB) DeleteToken(identifier string) error {
err := d.RunTxn(func(tx *sql.Tx) error {
_, err := tx.Exec(`DELETE FROM oauth
WHERE identifier = ?`, identifier)
return err
})
func (d *OAuthDB) DeleteToken(ctx context.Context, identifier string) error {
_, err := d.ExecContext(ctx, `DELETE FROM oauth WHERE identifier = ?`, identifier)
return err
}
28 changes: 12 additions & 16 deletions base/multi.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package base

import (
"context"
"database/sql"
"fmt"
"sync"
Expand Down Expand Up @@ -58,20 +59,17 @@ func (m *multi) IsLeader() bool {
}

func (m *multi) heartbeat() {
ctx := context.Background()
// update ourselves first
err := m.db.RunTxn(func(tx *sql.Tx) error {
_, err := tx.Exec(`
INSERT INTO heartbeats (id, name, mtime)
VALUES (?, ?, NOW(6)) ON DUPLICATE KEY UPDATE mtime=NOW(6)
`, m.id, m.name)
return err
})
if err != nil {
if _, err := m.db.ExecContext(ctx, `
INSERT INTO heartbeats (id, name, mtime)
VALUES (?, ?, NOW(6)) ON DUPLICATE KEY UPDATE mtime=NOW(6)
`, m.id, m.name); err != nil {
m.Errorf("failed to register heartbeat tx: %s", err)
return
}
// see if we are the leader
row := m.db.QueryRow(fmt.Sprintf(`
row := m.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT id FROM heartbeats
WHERE mtime > NOW(6) - INTERVAL %d SECOND AND name = ?
ORDER BY id DESC
Expand Down Expand Up @@ -101,13 +99,11 @@ func (m *multi) heartbeat() {
}

func (m *multi) deregister() {
err := m.db.RunTxn(func(tx *sql.Tx) error {
_, err := tx.Exec(`
DELETE from heartbeats
WHERE id = ? OR mtime < NOW() - INTERVAL 1 MINUTE
`, m.id)
return err
})
ctx := context.Background()
_, err := m.db.ExecContext(ctx, `
DELETE from heartbeats
WHERE id = ? OR mtime < NOW() - INTERVAL 1 MINUTE
`, m.id)
if err != nil {
m.Errorf("deregister: failed to execute : %s", err)
}
Expand Down
40 changes: 22 additions & 18 deletions base/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,21 @@ func (e OAuthRequiredError) Error() string {
}

type OAuthStorage interface {
GetToken(identifier string) (*oauth2.Token, error)
PutToken(identifier string, token *oauth2.Token) error
DeleteToken(identifier string) error
GetToken(ctx context.Context, identifier string) (*oauth2.Token, error)
PutToken(ctx context.Context, identifier string, token *oauth2.Token) error
DeleteToken(ctx context.Context, identifier string) error

GetState(state string) (*OAuthRequest, error)
PutState(state string, req *OAuthRequest) error
CompleteState(state string) error
GetState(ctx context.Context, state string) (*OAuthRequest, error)
PutState(ctx context.Context, state string, req *OAuthRequest) error
CompleteState(ctx context.Context, state string) error
}

type OAuthHTTPSrv struct {
*HTTPSrv
kbc *kbchat.API
oauth *oauth2.Config
storage OAuthStorage
callback func(msg chat1.MsgSummary, identifier string) error
callback func(ctx context.Context, msg chat1.MsgSummary, identifier string) error
htmlTitle string
htmlLogoB64 string
htmlLogoSrc string
Expand All @@ -48,7 +48,7 @@ func NewOAuthHTTPSrv(
debugConfig *ChatDebugOutputConfig,
oauth *oauth2.Config,
storage OAuthStorage,
callback func(msg chat1.MsgSummary, identifier string) error,
callback func(ctx context.Context, msg chat1.MsgSummary, identifier string) error,
htmlTitle string,
htmlLogoB64 string,
urlPrefix string,
Expand Down Expand Up @@ -100,8 +100,11 @@ func (o *OAuthHTTPSrv) oauthHandler(w http.ResponseWriter, r *http.Request) {

query := r.URL.Query()
state := query.Get("state")
// WithoutCancel: the browser may close after the redirect; DB writes and the
// HandleAuth callback (which sends Keybase messages) must finish regardless.
ctx := context.WithoutCancel(r.Context())

req, err := o.storage.GetState(state)
req, err := o.storage.GetState(ctx, state)
if err != nil {
err = fmt.Errorf("could not get state %q: %v", state, err)
return
Expand All @@ -128,23 +131,23 @@ func (o *OAuthHTTPSrv) oauthHandler(w http.ResponseWriter, r *http.Request) {
o.showOAuthError(w)
return
}
token, err := o.oauth.Exchange(context.TODO(), code)
token, err := o.oauth.Exchange(ctx, code)
if err != nil {
return
}

if err = o.storage.PutToken(req.TokenIdentifier, token); err != nil {
if err = o.storage.PutToken(ctx, req.TokenIdentifier, token); err != nil {
return
}
if err = o.storage.CompleteState(state); err != nil {
if err = o.storage.CompleteState(ctx, state); err != nil {
return
}
callbackMsg, err := o.getCallbackMsg(*req)
if err != nil {
return
}

if err = o.callback(callbackMsg, req.TokenIdentifier); err != nil {
if err = o.callback(ctx, callbackMsg, req.TokenIdentifier); err != nil {
return
}

Expand Down Expand Up @@ -189,14 +192,15 @@ type GetOAuthOpts struct {
}

func GetOAuthClient(
ctx context.Context,
tokenIdentifier string,
callbackMsg chat1.MsgSummary,
kbc *kbchat.API,
config *oauth2.Config,
storage OAuthStorage,
opts GetOAuthOpts,
) (*http.Client, error) {
token, err := storage.GetToken(tokenIdentifier)
token, err := storage.GetToken(ctx, tokenIdentifier)
if err != nil {
return nil, err
}
Expand All @@ -216,7 +220,7 @@ func GetOAuthClient(
if err != nil {
return nil, err
}
if err := storage.PutState(state, &OAuthRequest{
if err := storage.PutState(ctx, state, &OAuthRequest{
TokenIdentifier: tokenIdentifier,
ConvID: callbackMsg.ConvID,
MsgID: callbackMsg.Id,
Expand Down Expand Up @@ -257,16 +261,16 @@ func GetOAuthClient(
}
// renew token
if token.Expiry.Before(time.Now()) {
newToken, err := config.TokenSource(context.Background(), token).Token()
newToken, err := config.TokenSource(ctx, token).Token()
if err != nil {
return nil, fmt.Errorf("unable to renew token: %s", err)
}
err = storage.PutToken(tokenIdentifier, newToken)
err = storage.PutToken(ctx, tokenIdentifier, newToken)
if err != nil {
return nil, fmt.Errorf("unable to update token: %s", err)
}
token = newToken
}

return config.Client(context.Background(), token), nil
return config.Client(ctx, token), nil
}
Loading
Loading