|
| 1 | +package main |
| 2 | + |
| 3 | +import ( |
| 4 | + "crypto/tls" |
| 5 | + "crypto/x509" |
| 6 | + "database/sql" |
| 7 | + "fmt" |
| 8 | + "log" |
| 9 | + "net" |
| 10 | + "net/http" |
| 11 | + "os" |
| 12 | + "path/filepath" |
| 13 | + "strconv" |
| 14 | + "strings" |
| 15 | + "time" |
| 16 | + |
| 17 | + "github.com/elazarl/goproxy" |
| 18 | + "github.com/peterbourgon/ff/v4" |
| 19 | + "github.com/peterbourgon/ff/v4/ffhelp" |
| 20 | + "github.com/tursodatabase/go-libsql" |
| 21 | + |
| 22 | + "github.com/walterwanderley/sqlite-http-cache/config" |
| 23 | + "github.com/walterwanderley/sqlite-http-cache/db" |
| 24 | + proxyhandler "github.com/walterwanderley/sqlite-http-cache/http/proxy" |
| 25 | +) |
| 26 | + |
| 27 | +func main() { |
| 28 | + fs := ff.NewFlagSet("libsql-http-proxy") |
| 29 | + port := fs.Uint('p', "port", 8080, "Server port") |
| 30 | + dbPrimaryURL := fs.StringLong("db-primary-url", "", "Database primary URL") |
| 31 | + dbSyncInterval := fs.DurationLong("db-sync-interval", 30*time.Second, "Database sync interval") |
| 32 | + dbAuthToken := fs.StringLong("db-token", "", "Database authorization token") |
| 33 | + dbEncryptionKey := fs.StringLong("db-key", "", "Database encryption key") |
| 34 | + verbose := fs.Bool('v', "verbose", "Enable verbose mode") |
| 35 | + allowHTTP2 := fs.BoolLong("h2", "Allow HTTP2") |
| 36 | + statusCodes := fs.StringListLong("status-code", fmt.Sprintf("List of cacheable status code. Defaults to the heuristically cacheable codes: %v", config.DefaultStatusCodes())) |
| 37 | + ttl := fs.IntLong("ttl", 0, "Time to Live in seconds (0 is infinite time)") |
| 38 | + responseTables := fs.StringListLong("response-table", "List of database tables used to store response data") |
| 39 | + caCert := fs.StringLong("ca-cert", "", "Path to CA Certificate file (required to HTTPS proxy)") |
| 40 | + caCertKey := fs.StringLong("ca-cert-key", "", "Path to CA Certificate Key file (required to HTTPS proxy)") |
| 41 | + readOnly := fs.BoolLong("ro", "Read Only mode. Do not store new HTTP responses") |
| 42 | + rfc9111 := fs.BoolLong("rfc9111", "Use RFC9111 spec") |
| 43 | + shared := fs.BoolLong("shared", "Enable shared cache mode") |
| 44 | + _ = fs.String('c', "config", "", "config file (optional)") |
| 45 | + |
| 46 | + if err := ff.Parse(fs, os.Args[1:], |
| 47 | + ff.WithEnvVarPrefix("LIBSQL_HTTP_PROXY"), |
| 48 | + ff.WithConfigFileFlag("config"), |
| 49 | + ff.WithConfigFileParser(ff.PlainParser), |
| 50 | + ); err != nil { |
| 51 | + fmt.Printf("%s\n", ffhelp.Flags(fs)) |
| 52 | + fmt.Printf("err=%v\n", err) |
| 53 | + return |
| 54 | + } |
| 55 | + |
| 56 | + if len(fs.GetArgs()) == 0 { |
| 57 | + log.Fatalf("Usage: %s <FLAGS> [DatabasePath1] [DatabasePathN\n\nExample:\n\t%s example.db example2.db example3.db\n", os.Args[0], os.Args[0]) |
| 58 | + } |
| 59 | + |
| 60 | + if *verbose { |
| 61 | + fmt.Printf("Using options: port=%d db-primary-url=%s, h2=%v, ttl=%d, response-tables=%v, ca-cert=%s, ca-cert-key=%s, read-only=%v, rfc9111=%v shared-cache=%v\n", |
| 62 | + *port, *dbPrimaryURL, *allowHTTP2, *ttl, *responseTables, *caCert, *caCertKey, *readOnly, *rfc9111, *shared) |
| 63 | + } |
| 64 | + |
| 65 | + dbs := make([]*sql.DB, 0) |
| 66 | + var ( |
| 67 | + repository db.Repository |
| 68 | + tableList []string |
| 69 | + err error |
| 70 | + ) |
| 71 | + |
| 72 | + dbOpts := make([]libsql.Option, 0) |
| 73 | + if *dbAuthToken != "" { |
| 74 | + dbOpts = append(dbOpts, libsql.WithAuthToken(*dbAuthToken)) |
| 75 | + } |
| 76 | + if *dbEncryptionKey != "" { |
| 77 | + dbOpts = append(dbOpts, libsql.WithEncryption(*dbEncryptionKey)) |
| 78 | + } |
| 79 | + if *dbSyncInterval > 0 { |
| 80 | + dbOpts = append(dbOpts, libsql.WithSyncInterval(*dbSyncInterval)) |
| 81 | + } |
| 82 | + |
| 83 | + fnRegisterResonseTables := func(sqlDB *sql.DB, dbPath string) { |
| 84 | + if responseTables == nil || len(*responseTables) == 0 { |
| 85 | + tableList, err = db.ResponseTables(sqlDB) |
| 86 | + if err != nil { |
| 87 | + log.Fatalf("discovery response tables: %v", err) |
| 88 | + } |
| 89 | + } else { |
| 90 | + tableList = *responseTables |
| 91 | + err := db.CreateResponseTables(sqlDB, tableList...) |
| 92 | + if err != nil { |
| 93 | + log.Fatalf("create response tables on DB %q: %v", dbPath, err) |
| 94 | + } |
| 95 | + } |
| 96 | + } |
| 97 | + |
| 98 | + for _, dbPath := range fs.GetArgs() { |
| 99 | + if strings.HasPrefix(dbPath, "file:") { |
| 100 | + sqlDB, err := sql.Open("libsql", dbPath) |
| 101 | + if err != nil { |
| 102 | + log.Fatalf("connecting to database %q: %v", dbPath, err) |
| 103 | + } |
| 104 | + fnRegisterResonseTables(sqlDB, dbPath) |
| 105 | + dbs = append(dbs, sqlDB) |
| 106 | + continue |
| 107 | + } |
| 108 | + dir, err := os.MkdirTemp("", "libsql-*") |
| 109 | + if err != nil { |
| 110 | + log.Fatalf("creating database directory: %v", err) |
| 111 | + } |
| 112 | + defer os.RemoveAll(dir) |
| 113 | + |
| 114 | + connector, err := libsql.NewEmbeddedReplicaConnector(filepath.Join(dir, dbPath), *dbPrimaryURL, dbOpts...) |
| 115 | + if err != nil { |
| 116 | + log.Fatalf("creating database connector: %v", err) |
| 117 | + } |
| 118 | + defer connector.Close() |
| 119 | + |
| 120 | + sqlDB := sql.OpenDB(connector) |
| 121 | + defer func() { |
| 122 | + if closeError := sqlDB.Close(); closeError != nil { |
| 123 | + fmt.Println("Error closing database", closeError) |
| 124 | + if err == nil { |
| 125 | + err = closeError |
| 126 | + } |
| 127 | + } |
| 128 | + }() |
| 129 | + |
| 130 | + fnRegisterResonseTables(sqlDB, dbPath) |
| 131 | + dbs = append(dbs, sqlDB) |
| 132 | + |
| 133 | + } |
| 134 | + if len(dbs) == 1 { |
| 135 | + repository, err = db.NewRepository(dbs[0], tableList...) |
| 136 | + if err != nil { |
| 137 | + log.Fatalf("new repository: %v", err) |
| 138 | + } |
| 139 | + } else { |
| 140 | + repository, err = db.NewMultiDatabaseRepository(dbs) |
| 141 | + if err != nil { |
| 142 | + log.Fatalf("new multi database repository: %v", err) |
| 143 | + } |
| 144 | + } |
| 145 | + defer repository.Close() |
| 146 | + |
| 147 | + proxy := goproxy.NewProxyHttpServer() |
| 148 | + proxy.Verbose = *verbose |
| 149 | + proxy.AllowHTTP2 = *allowHTTP2 |
| 150 | + |
| 151 | + if *caCert != "" && *caCertKey != "" { |
| 152 | + proxy.Logger.Printf("INFO: Starting HTTP/HTTPS Proxy...") |
| 153 | + cert, err := parseCA([]byte(*caCert), []byte(*caCertKey)) |
| 154 | + if err != nil { |
| 155 | + log.Fatal(err) |
| 156 | + } |
| 157 | + |
| 158 | + customCaMitm := &goproxy.ConnectAction{Action: goproxy.ConnectMitm, TLSConfig: goproxy.TLSConfigFromCA(cert)} |
| 159 | + var customAlwaysMitm goproxy.FuncHttpsHandler = func(host string, ctx *goproxy.ProxyCtx) (*goproxy.ConnectAction, string) { |
| 160 | + return customCaMitm, host |
| 161 | + } |
| 162 | + proxy.OnRequest().HandleConnect(customAlwaysMitm) |
| 163 | + } else { |
| 164 | + proxy.Logger.Printf("INFO: Starting HTTP Proxy...") |
| 165 | + } |
| 166 | + |
| 167 | + cacheableStatus := make([]int, 0) |
| 168 | + for _, status := range *statusCodes { |
| 169 | + statusStr := strings.TrimSpace(status) |
| 170 | + code, err := strconv.Atoi(statusStr) |
| 171 | + if err != nil { |
| 172 | + log.Fatalf("Invalid status-code %q. Must be integer: %v", status, err) |
| 173 | + } |
| 174 | + cacheableStatus = append(cacheableStatus, code) |
| 175 | + } |
| 176 | + if len(cacheableStatus) == 0 { |
| 177 | + cacheableStatus = config.DefaultStatusCodes() |
| 178 | + } |
| 179 | + |
| 180 | + proxy.OnRequest().Do(proxyhandler.NewRequestHandler( |
| 181 | + proxyhandler.RequestConfig{ |
| 182 | + Querier: repository, |
| 183 | + CacheableStatus: cacheableStatus, |
| 184 | + TTL: *ttl, |
| 185 | + RFC9111: *rfc9111, |
| 186 | + SharedCache: *shared, |
| 187 | + ReadOnly: *readOnly, |
| 188 | + Verbose: *verbose, |
| 189 | + }, |
| 190 | + )) |
| 191 | + |
| 192 | + if !*readOnly { |
| 193 | + proxy.OnResponse().Do(proxyhandler.NewResponseHandler( |
| 194 | + proxyhandler.ResponseConfig{ |
| 195 | + Writer: repository, |
| 196 | + RFC9111: *rfc9111, |
| 197 | + TTL: *ttl, |
| 198 | + SharedCache: *shared, |
| 199 | + Verbose: *verbose, |
| 200 | + }, |
| 201 | + )) |
| 202 | + } |
| 203 | + |
| 204 | + lis, err := net.Listen("tcp", fmt.Sprintf(":%d", *port)) |
| 205 | + if err != nil { |
| 206 | + log.Fatalf("cannot open port %d: %v", port, err) |
| 207 | + } |
| 208 | + |
| 209 | + proxy.Logger.Printf("LibSQL-HTTP-Proxy listening port=%d", *port) |
| 210 | + log.Fatal(http.Serve(lis, proxy)) |
| 211 | +} |
| 212 | + |
| 213 | +func parseCA(caCert, caKey []byte) (*tls.Certificate, error) { |
| 214 | + parsedCert, err := tls.X509KeyPair(caCert, caKey) |
| 215 | + if err != nil { |
| 216 | + return nil, err |
| 217 | + } |
| 218 | + if parsedCert.Leaf, err = x509.ParseCertificate(parsedCert.Certificate[0]); err != nil { |
| 219 | + return nil, err |
| 220 | + } |
| 221 | + return &parsedCert, nil |
| 222 | +} |
0 commit comments