Skip to content

Commit 80f0219

Browse files
authored
validation: Move query validation to separate package (#498)
1 parent e36acce commit 80f0219

9 files changed

Lines changed: 231 additions & 194 deletions

File tree

internal/dinosql/checks.go

Lines changed: 0 additions & 150 deletions
This file was deleted.

internal/dinosql/parser.go

Lines changed: 11 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
core "github.com/kyleconroy/sqlc/internal/pg"
1818
"github.com/kyleconroy/sqlc/internal/postgres"
1919
"github.com/kyleconroy/sqlc/internal/postgresql/ast"
20+
"github.com/kyleconroy/sqlc/internal/postgresql/validate"
2021
"github.com/kyleconroy/sqlc/internal/sql/sqlpath"
2122

2223
"github.com/davecgh/go-spew/spew"
@@ -49,7 +50,7 @@ func ParseCatalog(schemas []string) (core.Catalog, error) {
4950
continue
5051
}
5152
for _, stmt := range tree.Statements {
52-
if err := validateFuncCall(&c, stmt); err != nil {
53+
if err := validate.FuncCall(&c, stmt); err != nil {
5354
merr.Add(filename, contents, location(stmt), err)
5455
continue
5556
}
@@ -72,7 +73,7 @@ func ParseCatalog(schemas []string) (core.Catalog, error) {
7273

7374
func updateCatalog(c *core.Catalog, tree pg.ParsetreeList) error {
7475
for _, stmt := range tree.Statements {
75-
if err := validateFuncCall(c, stmt); err != nil {
76+
if err := validate.FuncCall(c, stmt); err != nil {
7677
return err
7778
}
7879
if err := catalog.Update(c, stmt); err != nil {
@@ -301,10 +302,10 @@ func validateCmd(n nodes.Node, name, cmd string) error {
301302
var errUnsupportedStatementType = errors.New("parseQuery: unsupported statement type")
302303

303304
func parseQuery(c core.Catalog, stmt nodes.Node, source string, rewriteParameters bool) (*Query, error) {
304-
if err := validateParamStyle(stmt); err != nil {
305+
if err := validate.ParamStyle(stmt); err != nil {
305306
return nil, err
306307
}
307-
if err := validateParamRef(stmt); err != nil {
308+
if err := validate.ParamRef(stmt); err != nil {
308309
return nil, err
309310
}
310311
raw, ok := stmt.(nodes.RawStmt)
@@ -315,7 +316,7 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string, rewriteParameter
315316
case nodes.SelectStmt:
316317
case nodes.DeleteStmt:
317318
case nodes.InsertStmt:
318-
if err := validateInsertStmt(n); err != nil {
319+
if err := validate.InsertStmt(n); err != nil {
319320
return nil, err
320321
}
321322
case nodes.TruncateStmt:
@@ -331,7 +332,7 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string, rewriteParameter
331332
if rawSQL == "" {
332333
return nil, errors.New("missing semicolon at end of file")
333334
}
334-
if err := validateFuncCall(&c, raw); err != nil {
335+
if err := validate.FuncCall(&c, raw); err != nil {
335336
return nil, err
336337
}
337338
name, cmd, err := ParseMetadata(strings.TrimSpace(rawSQL), CommentSyntaxDash)
@@ -437,7 +438,7 @@ type edit struct {
437438
}
438439

439440
func expand(qc *QueryCatalog, raw nodes.RawStmt) ([]edit, error) {
440-
list := search(raw, func(node nodes.Node) bool {
441+
list := ast.Search(raw, func(node nodes.Node) bool {
441442
switch node.(type) {
442443
case nodes.DeleteStmt:
443444
case nodes.InsertStmt:
@@ -655,7 +656,7 @@ func sourceTables(qc *QueryCatalog, node nodes.Node) ([]core.Table, error) {
655656
Items: []nodes.Node{*n.Relation},
656657
}
657658
case nodes.SelectStmt:
658-
list = search(n.FromClause, func(node nodes.Node) bool {
659+
list = ast.Search(n.FromClause, func(node nodes.Node) bool {
659660
switch node.(type) {
660661
case nodes.RangeVar, nodes.RangeSubselect:
661662
return true
@@ -664,7 +665,7 @@ func sourceTables(qc *QueryCatalog, node nodes.Node) ([]core.Table, error) {
664665
}
665666
})
666667
case nodes.TruncateStmt:
667-
list = search(n.Relations, func(node nodes.Node) bool {
668+
list = ast.Search(n.Relations, func(node nodes.Node) bool {
668669
_, ok := node.(nodes.RangeVar)
669670
return ok
670671
})
@@ -1095,24 +1096,6 @@ func findParameters(root nodes.Node) []paramRef {
10951096
return refs
10961097
}
10971098

1098-
type nodeSearch struct {
1099-
list nodes.List
1100-
check func(nodes.Node) bool
1101-
}
1102-
1103-
func (s *nodeSearch) Visit(node nodes.Node) ast.Visitor {
1104-
if s.check(node) {
1105-
s.list.Items = append(s.list.Items, node)
1106-
}
1107-
return s
1108-
}
1109-
1110-
func search(root nodes.Node, f func(nodes.Node) bool) nodes.List {
1111-
ns := &nodeSearch{check: f}
1112-
ast.Walk(ns, root)
1113-
return ns.list
1114-
}
1115-
11161099
func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef, names map[int]string) ([]Parameter, error) {
11171100
aliasMap := map[string]core.FQN{}
11181101
// TODO: Deprecate defaultTable
@@ -1194,7 +1177,7 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef, n
11941177
case nodes.A_Expr:
11951178
// TODO: While this works for a wide range of simple expressions,
11961179
// more complicated expressions will cause this logic to fail.
1197-
list := search(n.Lexpr, func(node nodes.Node) bool {
1180+
list := ast.Search(n.Lexpr, func(node nodes.Node) bool {
11981181
_, ok := node.(nodes.ColumnRef)
11991182
return ok
12001183
})

internal/dinosql/rewrite.go

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55

66
nodes "github.com/lfittl/pg_query_go/nodes"
77

8+
"github.com/kyleconroy/sqlc/internal/postgresql"
89
"github.com/kyleconroy/sqlc/internal/postgresql/ast"
910
)
1011

@@ -16,7 +17,7 @@ func flatten(root nodes.Node) (string, bool) {
1617
}
1718

1819
type stringWalker struct {
19-
String string
20+
String string
2021
IsConst bool
2122
}
2223

@@ -30,16 +31,6 @@ func (s *stringWalker) Visit(node nodes.Node) ast.Visitor {
3031
return s
3132
}
3233

33-
func isNamedParamFunc(node nodes.Node) bool {
34-
fun, ok := node.(nodes.FuncCall)
35-
return ok && ast.Join(fun.Funcname, ".") == "sqlc.arg"
36-
}
37-
38-
func isNamedParamSign(node nodes.Node) bool {
39-
expr, ok := node.(nodes.A_Expr)
40-
return ok && ast.Join(expr.Name, ".") == "@"
41-
}
42-
4334
func isNamedParamSignCast(node nodes.Node) bool {
4435
expr, ok := node.(nodes.A_Expr)
4536
if !ok {
@@ -50,8 +41,8 @@ func isNamedParamSignCast(node nodes.Node) bool {
5041
}
5142

5243
func rewriteNamedParameters(raw nodes.RawStmt) (nodes.RawStmt, map[int]string, []edit) {
53-
foundFunc := search(raw, isNamedParamFunc)
54-
foundSign := search(raw, isNamedParamSign)
44+
foundFunc := ast.Search(raw, postgresql.IsNamedParamFunc)
45+
foundSign := ast.Search(raw, postgresql.IsNamedParamSign)
5546
if len(foundFunc.Items)+len(foundSign.Items) == 0 {
5647
return raw, map[int]string{}, nil
5748
}
@@ -63,7 +54,7 @@ func rewriteNamedParameters(raw nodes.RawStmt) (nodes.RawStmt, map[int]string, [
6354
node := cr.Node()
6455
switch {
6556

66-
case isNamedParamFunc(node):
57+
case postgresql.IsNamedParamFunc(node):
6758
fun := node.(nodes.FuncCall)
6859
param, isConst := flatten(fun.Args)
6960
if num, ok := args[param]; ok {
@@ -120,7 +111,7 @@ func rewriteNamedParameters(raw nodes.RawStmt) (nodes.RawStmt, map[int]string, [
120111
})
121112
return false
122113

123-
case isNamedParamSign(node):
114+
case postgresql.IsNamedParamSign(node):
124115
expr := node.(nodes.A_Expr)
125116
param, _ := flatten(expr.Rexpr)
126117
if num, ok := args[param]; ok {

internal/postgresql/ast/search.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package ast
2+
3+
import (
4+
nodes "github.com/lfittl/pg_query_go/nodes"
5+
)
6+
7+
type nodeSearch struct {
8+
list nodes.List
9+
check func(nodes.Node) bool
10+
}
11+
12+
func (s *nodeSearch) Visit(node nodes.Node) Visitor {
13+
if s.check(node) {
14+
s.list.Items = append(s.list.Items, node)
15+
}
16+
return s
17+
}
18+
19+
func Search(root nodes.Node, f func(nodes.Node) bool) nodes.List {
20+
ns := &nodeSearch{check: f}
21+
Walk(ns, root)
22+
return ns.list
23+
}

internal/postgresql/utils.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
package postgresql
22

3-
import nodes "github.com/lfittl/pg_query_go/nodes"
3+
import (
4+
nodes "github.com/lfittl/pg_query_go/nodes"
5+
6+
"github.com/kyleconroy/sqlc/internal/postgresql/ast"
7+
)
48

59
func isNotNull(n nodes.ColumnDef) bool {
610
if n.IsNotNull {
@@ -19,3 +23,13 @@ func isNotNull(n nodes.ColumnDef) bool {
1923
}
2024
return false
2125
}
26+
27+
func IsNamedParamFunc(node nodes.Node) bool {
28+
fun, ok := node.(nodes.FuncCall)
29+
return ok && ast.Join(fun.Funcname, ".") == "sqlc.arg"
30+
}
31+
32+
func IsNamedParamSign(node nodes.Node) bool {
33+
expr, ok := node.(nodes.A_Expr)
34+
return ok && ast.Join(expr.Name, ".") == "@"
35+
}

0 commit comments

Comments
 (0)