Skip to content

Commit 1eaf7c5

Browse files
committed
feat: add context to cache interface
1 parent 76d3c63 commit 1eaf7c5

2 files changed

Lines changed: 29 additions & 22 deletions

File tree

cache/cache.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package cache
22

33
import (
4+
"context"
45
"database/sql/driver"
56
"time"
67
)
@@ -17,7 +18,7 @@ type Cacher interface {
1718
// Get must return a pointer to the item, a boolean representing whether
1819
// item is present or not, and an error (must be nil when key is not
1920
// present).
20-
Get(key string) (*Item, bool, error)
21+
Get(ctx context.Context, key string) (*Item, bool, error)
2122
// Set sets the item into cache with the given TTL.
22-
Set(key string, item *Item, ttl time.Duration) error
23+
Set(ctx context.Context, key string, item *Item, ttl time.Duration) error
2324
}

interceptor.go

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,17 @@ func (i *Interceptor) Disable() {
8484
}
8585

8686
// StmtQueryContext intecepts database/sql's stmt.QueryContext calls from a prepared statement.
87-
func (i *Interceptor) StmtQueryContext(ctx context.Context, conn driver.StmtQueryContext, query string, args []driver.NamedValue) (driver.Rows, error) {
87+
func (i *Interceptor) StmtQueryContext(ctx context.Context, conn driver.StmtQueryContext, query string, args []driver.NamedValue) (context.Context, driver.Rows, error) {
8888

8989
if i.disabled {
90-
return conn.QueryContext(ctx, args)
90+
r , err := conn.QueryContext(ctx, args)
91+
return ctx, r, err
9192
}
9293

9394
attrs := getAttrs(query)
9495
if attrs == nil {
95-
return conn.QueryContext(ctx, args)
96+
r , err := conn.QueryContext(ctx, args)
97+
return ctx, r, err
9698
}
9799

98100
hash, err := i.hashFunc(query, args)
@@ -101,20 +103,21 @@ func (i *Interceptor) StmtQueryContext(ctx context.Context, conn driver.StmtQuer
101103
if i.onErr != nil {
102104
i.onErr(fmt.Errorf("HashFunc failed: %w", err))
103105
}
104-
return conn.QueryContext(ctx, args)
106+
r , err := conn.QueryContext(ctx, args)
107+
return ctx, r, err
105108
}
106109

107-
if cached := i.checkCache(hash); cached != nil {
108-
return cached, nil
110+
if cached := i.checkCache(ctx, hash); cached != nil {
111+
return ctx, cached, nil
109112
}
110113

111114
rows, err := conn.QueryContext(ctx, args)
112115
if err != nil {
113-
return rows, err
116+
return ctx, rows, err
114117
}
115118

116119
cacheSetter := func(item *cache.Item) {
117-
err := i.c.Set(hash, item, time.Duration(attrs.ttl)*time.Second)
120+
err := i.c.Set(ctx, hash, item, time.Duration(attrs.ttl)*time.Second)
118121
if err != nil {
119122
atomic.AddUint64(&i.stats.Errors, 1)
120123
if i.onErr != nil {
@@ -123,19 +126,21 @@ func (i *Interceptor) StmtQueryContext(ctx context.Context, conn driver.StmtQuer
123126
}
124127
}
125128

126-
return newRowsRecorder(cacheSetter, rows, attrs.maxRows), err
129+
return ctx, newRowsRecorder(cacheSetter, rows, attrs.maxRows), err
127130
}
128131

129132
// ConnQueryContext intecepts database/sql's DB.QueryContext Conn.QueryContext calls.
130-
func (i *Interceptor) ConnQueryContext(ctx context.Context, conn driver.QueryerContext, query string, args []driver.NamedValue) (driver.Rows, error) {
133+
func (i *Interceptor) ConnQueryContext(ctx context.Context, conn driver.QueryerContext, query string, args []driver.NamedValue) (context.Context, driver.Rows, error) {
131134

132135
if i.disabled {
133-
return conn.QueryContext(ctx, query, args)
136+
r, err := conn.QueryContext(ctx, query, args)
137+
return ctx, r, err
134138
}
135139

136140
attrs := getAttrs(query)
137141
if attrs == nil {
138-
return conn.QueryContext(ctx, query, args)
142+
r, err := conn.QueryContext(ctx, query, args)
143+
return ctx, r, err
139144
}
140145

141146
hash, err := i.hashFunc(query, args)
@@ -144,20 +149,21 @@ func (i *Interceptor) ConnQueryContext(ctx context.Context, conn driver.QueryerC
144149
if i.onErr != nil {
145150
i.onErr(fmt.Errorf("HashFunc failed: %w", err))
146151
}
147-
return conn.QueryContext(ctx, query, args)
152+
r, err := conn.QueryContext(ctx, query, args)
153+
return ctx, r, err
148154
}
149155

150-
if cached := i.checkCache(hash); cached != nil {
151-
return cached, nil
156+
if cached := i.checkCache(ctx, hash); cached != nil {
157+
return ctx, cached, nil
152158
}
153159

154160
rows, err := conn.QueryContext(ctx, query, args)
155161
if err != nil {
156-
return rows, err
162+
return ctx, rows, err
157163
}
158164

159165
cacheSetter := func(item *cache.Item) {
160-
err := i.c.Set(hash, item, time.Duration(attrs.ttl)*time.Second)
166+
err := i.c.Set(ctx, hash, item, time.Duration(attrs.ttl)*time.Second)
161167
if err != nil {
162168
atomic.AddUint64(&i.stats.Errors, 1)
163169
if i.onErr != nil {
@@ -166,11 +172,11 @@ func (i *Interceptor) ConnQueryContext(ctx context.Context, conn driver.QueryerC
166172
}
167173
}
168174

169-
return newRowsRecorder(cacheSetter, rows, attrs.maxRows), err
175+
return ctx, newRowsRecorder(cacheSetter, rows, attrs.maxRows), err
170176
}
171177

172-
func (i *Interceptor) checkCache(hash string) driver.Rows {
173-
item, ok, err := i.c.Get(hash)
178+
func (i *Interceptor) checkCache(ctx context.Context, hash string) driver.Rows {
179+
item, ok, err := i.c.Get(ctx, hash)
174180
if err != nil {
175181
atomic.AddUint64(&i.stats.Errors, 1)
176182
if i.onErr != nil {

0 commit comments

Comments
 (0)