Skip to content

Commit da51567

Browse files
proxy multi database support
1 parent b343028 commit da51567

8 files changed

Lines changed: 197 additions & 41 deletions

File tree

cmd/sqlite-http-proxy/main.go

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020

2121
var (
2222
port uint
23+
dbParams string
2324
allowHTTP2 bool
2425
verbose bool
2526

@@ -35,6 +36,7 @@ var (
3536

3637
func main() {
3738
flag.UintVar(&port, "p", 8080, "Server port")
39+
flag.StringVar(&dbParams, "db-params", "_journal=WAL&_sync=NORMAL&_timeout=5000&_txlock=immediate", "Database connection params")
3840
flag.BoolVar(&verbose, "v", false, "Enable verbose mode")
3941
flag.BoolVar(&allowHTTP2, "h2", false, "Allow HTTP2")
4042
flag.UintVar(&ttl, "ttl", 0, "Time to Live in seconds (0 is infinite time)")
@@ -45,36 +47,57 @@ func main() {
4547
flag.BoolVar(&readOnly, "ro", false, "Read Only mode. Do not store new HTTP responses")
4648
flag.Parse()
4749

48-
if len(flag.Args()) != 1 {
49-
log.Fatalf("Usage: %s <flags> [DSN]\n\nExample:\n\t%s file:example.db\n", os.Args[0], os.Args[0])
50+
if len(flag.Args()) == 0 {
51+
log.Fatalf("Usage: %s <flags> [DatabasePath1] [DatabasePathN\n\nExample:\n\t%s example.db example2.db example3.db\n", os.Args[0], os.Args[0])
5052
}
51-
dsn := flag.Args()[0]
5253

53-
sqlDB, err := sql.Open("sqlite3", dsn)
54-
if err != nil {
55-
log.Fatalf("open db error: %v", err)
56-
}
57-
defer sqlDB.Close()
54+
dbs := make([]*sql.DB, 0)
55+
var (
56+
repository db.Repository
57+
tableList []string
58+
err error
59+
)
60+
for _, file := range flag.Args() {
61+
var dsn string
62+
if file == ":memory:" {
63+
dsn = file + "?cache=shared"
64+
} else {
65+
dsn = fmt.Sprintf("file:%s?%s", file, dbParams)
66+
}
5867

59-
var tableList []string
60-
if responseTables == "" {
61-
tableList, err = db.ResponseTables(sqlDB)
68+
sqlDB, err := sql.Open("sqlite3", dsn)
6269
if err != nil {
63-
log.Fatalf("discovery response tables: %v", err)
70+
log.Fatalf("open db error: %v", err)
6471
}
65-
} else {
66-
tableList = strings.Split(responseTables, ",")
67-
if forceCreateTables {
68-
err := db.CreateResponseTables(sqlDB, tableList...)
72+
defer sqlDB.Close()
73+
74+
dbs = append(dbs, sqlDB)
75+
76+
if responseTables == "" {
77+
tableList, err = db.ResponseTables(sqlDB)
6978
if err != nil {
70-
log.Fatalf("force create tables: %v", err)
79+
log.Fatalf("discovery response tables: %v", err)
80+
}
81+
} else {
82+
tableList = strings.Split(responseTables, ",")
83+
if forceCreateTables {
84+
err := db.CreateResponseTables(sqlDB, tableList...)
85+
if err != nil {
86+
log.Fatalf("force create tables: %v", err)
87+
}
7188
}
7289
}
7390
}
74-
75-
repository, err := db.NewRepository(sqlDB, tableList...)
76-
if err != nil {
77-
log.Fatalf("new repository: %v", err)
91+
if len(dbs) == 1 {
92+
repository, err = db.NewRepository(dbs[0], tableList...)
93+
if err != nil {
94+
log.Fatalf("new repository: %v", err)
95+
}
96+
} else {
97+
repository, err = db.NewMultiDatabaseRepository(dbs)
98+
if err != nil {
99+
log.Fatalf("new multi database repository: %v", err)
100+
}
78101
}
79102
defer repository.Close()
80103

cmd/sqlite-http-proxy/request.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"database/sql"
66
"errors"
7+
"fmt"
78
"log/slog"
89
"net/http"
910
"time"
@@ -33,13 +34,13 @@ func (h *requestHandler) Handle(r *http.Request, ctx *goproxy.ProxyCtx) (*http.R
3334
slog.Error("database query", "error", err.Error())
3435
}
3536
// tell the responseHandler to save the new response data
36-
ctx.UserData = ""
37+
ctx.UserData = ":-1"
3738
return r, nil
3839
}
3940

4041
if !readOnly && uint(time.Since(resp.Timestamp).Seconds()) > ttl {
4142
// data is too old, tell the responseHandler to save the new data
42-
ctx.UserData = resp.TableName
43+
ctx.UserData = fmt.Sprintf("%s:%d", resp.TableName, resp.DatabaseID)
4344
return r, nil
4445
}
4546
if verbose {

cmd/sqlite-http-proxy/response.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"context"
55
"log/slog"
66
"net/http"
7+
"strconv"
8+
"strings"
79

810
"github.com/elazarl/goproxy"
911
"github.com/walterwanderley/sqlite-http-cache/db"
@@ -26,7 +28,12 @@ func (h *responseHandler) Handle(resp *http.Response, ctx *goproxy.ProxyCtx) *ht
2628
if err != nil {
2729
slog.Error("adapter response body", "error", err)
2830
} else {
29-
responseDB.TableName = ctx.UserData.(string)
31+
userData := ctx.UserData.(string)
32+
tableName, databaseID, ok := strings.Cut(userData, ":")
33+
if ok {
34+
responseDB.DatabaseID, _ = strconv.Atoi(databaseID)
35+
}
36+
responseDB.TableName = tableName
3037
err := h.writer.Write(context.Background(), ctx.Req.URL.String(), responseDB)
3138
if err != nil {
3239
slog.Error("recording response", "error", err, "url", ctx.Req.URL.String(), "status", resp.StatusCode)

db/concurrent_repository.go

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,13 @@ type concurrentRepository struct {
1616
// roundRobin strategy to choose one writer
1717
currentWriter int
1818
muWriter sync.Mutex
19+
20+
// workers for MultiDatabaseRepository
21+
source chan work
22+
quit chan struct{}
1923
}
2024

21-
func newConcurrentRepository(db *sql.DB, tableNames ...string) (*concurrentRepository, error) {
25+
func newConcurrentRepository(db *sql.DB, databaseID int, tableNames ...string) (*concurrentRepository, error) {
2226
globalQuit := make(chan struct{})
2327

2428
size := len(tableNames)
@@ -29,7 +33,7 @@ func newConcurrentRepository(db *sql.DB, tableNames ...string) (*concurrentRepos
2933
if err != nil {
3034
return nil, fmt.Errorf("prepare read query for %q: %w", tableName, err)
3135
}
32-
q := newQuerier(readStmt, tableName)
36+
q := newQuerier(readStmt, databaseID, tableName)
3337
queriers[i] = q
3438
q.start()
3539

@@ -111,27 +115,51 @@ func (r *concurrentRepository) Write(ctx context.Context, url string, resp *Resp
111115
}
112116

113117
func (r *concurrentRepository) Close() error {
118+
if r.quit != nil {
119+
close(r.quit)
120+
}
114121
close(r.globalQuit)
115122
return nil
116123
}
117124

125+
func (r *concurrentRepository) start() {
126+
r.source = make(chan work, 10)
127+
go func() {
128+
for {
129+
select {
130+
case unitWork := <-r.source:
131+
resp, err := r.FindByURL(unitWork.ctx, unitWork.url)
132+
if err != nil {
133+
unitWork.resp <- nil
134+
continue
135+
}
136+
unitWork.resp <- resp
137+
case <-r.quit:
138+
return
139+
}
140+
}
141+
}()
142+
}
143+
118144
type work struct {
119145
ctx context.Context
120146
url string
121147
resp chan *Response
122148
}
123149

124150
type querier struct {
125-
source chan work
126-
quit chan struct{}
127-
stmt *sql.Stmt
128-
tableName string
151+
source chan work
152+
quit chan struct{}
153+
stmt *sql.Stmt
154+
tableName string
155+
databaseID int
129156
}
130157

131-
func newQuerier(stmt *sql.Stmt, tableName string) *querier {
158+
func newQuerier(stmt *sql.Stmt, databaseID int, tableName string) *querier {
132159
return &querier{
133-
stmt: stmt,
134-
tableName: tableName,
160+
stmt: stmt,
161+
tableName: tableName,
162+
databaseID: databaseID,
135163
}
136164
}
137165

@@ -147,6 +175,7 @@ func (q *querier) start() {
147175
continue
148176
}
149177
response.TableName = q.tableName
178+
response.DatabaseID = q.databaseID
150179
unitWork.resp <- response
151180

152181
case <-q.quit:

db/multi_database.go

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
package db
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"errors"
7+
"sync"
8+
)
9+
10+
type MultiDatabaseRepository struct {
11+
dbs []*sql.DB
12+
concurrentRepositories []*concurrentRepository
13+
// roundRobin strategy to choose one writer
14+
currentWriter int
15+
muWriter sync.Mutex
16+
}
17+
18+
func NewMultiDatabaseRepository(dbs []*sql.DB) (*MultiDatabaseRepository, error) {
19+
concurrentRepositories := make([]*concurrentRepository, len(dbs))
20+
for i, db := range dbs {
21+
tables, err := ResponseTables(db)
22+
if err != nil {
23+
return nil, err
24+
}
25+
cr, err := newConcurrentRepository(db, i, tables...)
26+
if err != nil {
27+
return nil, err
28+
}
29+
cr.start()
30+
concurrentRepositories[i] = cr
31+
}
32+
return &MultiDatabaseRepository{
33+
dbs: dbs,
34+
concurrentRepositories: concurrentRepositories,
35+
}, nil
36+
}
37+
38+
func (r *MultiDatabaseRepository) FindByURL(ctx context.Context, url string) (*Response, error) {
39+
size := len(r.concurrentRepositories)
40+
41+
ctx, cancel := context.WithCancel(ctx)
42+
defer cancel()
43+
44+
respCh := make(chan *Response, size)
45+
unitWork := work{
46+
ctx: ctx,
47+
url: url,
48+
resp: respCh,
49+
}
50+
51+
go func() {
52+
for _, cr := range r.concurrentRepositories {
53+
cr.source <- unitWork
54+
}
55+
}()
56+
var count int
57+
for {
58+
select {
59+
case resp := <-respCh:
60+
count++
61+
if resp != nil {
62+
return resp, nil
63+
}
64+
if count == size {
65+
return nil, sql.ErrNoRows
66+
}
67+
case <-ctx.Done():
68+
return nil, ctx.Err()
69+
}
70+
}
71+
}
72+
73+
func (r *MultiDatabaseRepository) Write(ctx context.Context, url string, resp *Response) error {
74+
var cr Repository
75+
if resp.DatabaseID == -1 {
76+
r.muWriter.Lock()
77+
defer r.muWriter.Unlock()
78+
r.currentWriter++
79+
if r.currentWriter >= len(r.concurrentRepositories) {
80+
r.currentWriter = 0
81+
}
82+
cr = r.concurrentRepositories[r.currentWriter]
83+
} else {
84+
cr = r.concurrentRepositories[resp.DatabaseID]
85+
}
86+
return cr.Write(ctx, url, resp)
87+
}
88+
89+
func (r *MultiDatabaseRepository) Close() error {
90+
var err error
91+
for _, cr := range r.concurrentRepositories {
92+
err = errors.Join(err, cr.Close())
93+
}
94+
return err
95+
}

db/reository.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@ type Repository interface {
2020
}
2121

2222
type Response struct {
23-
Status int
24-
Body io.ReadCloser
25-
Headers map[string][]string
26-
Timestamp time.Time
27-
TableName string
23+
Status int
24+
Body io.ReadCloser
25+
Headers map[string][]string
26+
Timestamp time.Time
27+
DatabaseID int
28+
TableName string
2829
}
2930

3031
func NewRepository(db *sql.DB, tableNames ...string) (Repository, error) {
@@ -43,7 +44,7 @@ func NewRepository(db *sql.DB, tableNames ...string) (Repository, error) {
4344
return newSingleRepository(db, tableNames[0])
4445
}
4546

46-
return newConcurrentRepository(db, tableNames...)
47+
return newConcurrentRepository(db, 0, tableNames...)
4748
}
4849

4950
func CreateResponseTables(db *sql.DB, tableNames ...string) error {

http/ro_transport.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ func (t *readOnlyTransport) RoundTrip(req *http.Request) (*http.Response, error)
4444

4545
url := req.URL.String()
4646
resp, err := t.querier.FindByURL(req.Context(), url)
47-
if err != nil || time.Since(resp.Timestamp) > t.ttl {
47+
if err != nil || (t.ttl > 0 && time.Since(resp.Timestamp) > t.ttl) {
4848
return t.base.RoundTrip(req)
4949
}
5050

http/rw_transport.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ func (t *readWriteTransport) RoundTrip(req *http.Request) (*http.Response, error
4343

4444
url := req.URL.String()
4545
respDB, err := t.querier.FindByURL(req.Context(), url)
46-
if err != nil || time.Since(respDB.Timestamp) > t.ttl {
46+
if err != nil || (t.ttl > 0 && time.Since(respDB.Timestamp) > t.ttl) {
4747
resp, err := t.base.RoundTrip(req)
4848
if err == nil {
4949
newRespDB, err := db.HttpToResponse(resp)

0 commit comments

Comments
 (0)