@@ -84,15 +84,17 @@ func (i *Interceptor) Disable() {
8484}
8585
8686// StmtQueryContext intecepts database/sql's stmt.QueryContext calls from a prepared statement.
87- func (i * Interceptor ) StmtQueryContext (ctx context.Context , conn driver.StmtQueryContext , query string , args []driver.NamedValue ) (driver.Rows , error ) {
87+ func (i * Interceptor ) StmtQueryContext (ctx context.Context , conn driver.StmtQueryContext , query string , args []driver.NamedValue ) (context. Context , driver.Rows , error ) {
8888
8989 if i .disabled {
90- return conn .QueryContext (ctx , args )
90+ r , err := conn .QueryContext (ctx , args )
91+ return ctx , r , err
9192 }
9293
9394 attrs := getAttrs (query )
9495 if attrs == nil {
95- return conn .QueryContext (ctx , args )
96+ r , err := conn .QueryContext (ctx , args )
97+ return ctx , r , err
9698 }
9799
98100 hash , err := i .hashFunc (query , args )
@@ -101,20 +103,21 @@ func (i *Interceptor) StmtQueryContext(ctx context.Context, conn driver.StmtQuer
101103 if i .onErr != nil {
102104 i .onErr (fmt .Errorf ("HashFunc failed: %w" , err ))
103105 }
104- return conn .QueryContext (ctx , args )
106+ r , err := conn .QueryContext (ctx , args )
107+ return ctx , r , err
105108 }
106109
107- if cached := i .checkCache (hash ); cached != nil {
108- return cached , nil
110+ if cached := i .checkCache (ctx , hash ); cached != nil {
111+ return ctx , cached , nil
109112 }
110113
111114 rows , err := conn .QueryContext (ctx , args )
112115 if err != nil {
113- return rows , err
116+ return ctx , rows , err
114117 }
115118
116119 cacheSetter := func (item * cache.Item ) {
117- err := i .c .Set (hash , item , time .Duration (attrs .ttl )* time .Second )
120+ err := i .c .Set (ctx , hash , item , time .Duration (attrs .ttl )* time .Second )
118121 if err != nil {
119122 atomic .AddUint64 (& i .stats .Errors , 1 )
120123 if i .onErr != nil {
@@ -123,19 +126,21 @@ func (i *Interceptor) StmtQueryContext(ctx context.Context, conn driver.StmtQuer
123126 }
124127 }
125128
126- return newRowsRecorder (cacheSetter , rows , attrs .maxRows ), err
129+ return ctx , newRowsRecorder (cacheSetter , rows , attrs .maxRows ), err
127130}
128131
129132// ConnQueryContext intecepts database/sql's DB.QueryContext Conn.QueryContext calls.
130- func (i * Interceptor ) ConnQueryContext (ctx context.Context , conn driver.QueryerContext , query string , args []driver.NamedValue ) (driver.Rows , error ) {
133+ func (i * Interceptor ) ConnQueryContext (ctx context.Context , conn driver.QueryerContext , query string , args []driver.NamedValue ) (context. Context , driver.Rows , error ) {
131134
132135 if i .disabled {
133- return conn .QueryContext (ctx , query , args )
136+ r , err := conn .QueryContext (ctx , query , args )
137+ return ctx , r , err
134138 }
135139
136140 attrs := getAttrs (query )
137141 if attrs == nil {
138- return conn .QueryContext (ctx , query , args )
142+ r , err := conn .QueryContext (ctx , query , args )
143+ return ctx , r , err
139144 }
140145
141146 hash , err := i .hashFunc (query , args )
@@ -144,20 +149,21 @@ func (i *Interceptor) ConnQueryContext(ctx context.Context, conn driver.QueryerC
144149 if i .onErr != nil {
145150 i .onErr (fmt .Errorf ("HashFunc failed: %w" , err ))
146151 }
147- return conn .QueryContext (ctx , query , args )
152+ r , err := conn .QueryContext (ctx , query , args )
153+ return ctx , r , err
148154 }
149155
150- if cached := i .checkCache (hash ); cached != nil {
151- return cached , nil
156+ if cached := i .checkCache (ctx , hash ); cached != nil {
157+ return ctx , cached , nil
152158 }
153159
154160 rows , err := conn .QueryContext (ctx , query , args )
155161 if err != nil {
156- return rows , err
162+ return ctx , rows , err
157163 }
158164
159165 cacheSetter := func (item * cache.Item ) {
160- err := i .c .Set (hash , item , time .Duration (attrs .ttl )* time .Second )
166+ err := i .c .Set (ctx , hash , item , time .Duration (attrs .ttl )* time .Second )
161167 if err != nil {
162168 atomic .AddUint64 (& i .stats .Errors , 1 )
163169 if i .onErr != nil {
@@ -166,11 +172,11 @@ func (i *Interceptor) ConnQueryContext(ctx context.Context, conn driver.QueryerC
166172 }
167173 }
168174
169- return newRowsRecorder (cacheSetter , rows , attrs .maxRows ), err
175+ return ctx , newRowsRecorder (cacheSetter , rows , attrs .maxRows ), err
170176}
171177
172- func (i * Interceptor ) checkCache (hash string ) driver.Rows {
173- item , ok , err := i .c .Get (hash )
178+ func (i * Interceptor ) checkCache (ctx context. Context , hash string ) driver.Rows {
179+ item , ok , err := i .c .Get (ctx , hash )
174180 if err != nil {
175181 atomic .AddUint64 (& i .stats .Errors , 1 )
176182 if i .onErr != nil {
0 commit comments