Skip to content

Commit 4bdd9a3

Browse files
authored
parser: Error on mismatched INSERT input (#135)
Error if the length of the target columns do not match the length of the value expressions. Trim space off comments before splitting. Trailing whitespace should not causing parsing metadata to fail.
1 parent 15294de commit 4bdd9a3

3 files changed

Lines changed: 46 additions & 4 deletions

File tree

internal/dinosql/checks.go

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ func validateParamRef(n nodes.Node) error {
3333
}
3434
}
3535
}
36-
3736
return nil
3837
}
3938

@@ -95,3 +94,29 @@ func validateFuncCall(c *pg.Catalog, n nodes.Node) error {
9594
Walk(&visitor, n)
9695
return visitor.err
9796
}
97+
98+
func validateInsertStmt(stmt nodes.InsertStmt) error {
99+
sel, ok := stmt.SelectStmt.(nodes.SelectStmt)
100+
if !ok {
101+
return nil
102+
}
103+
if len(sel.ValuesLists) != 1 {
104+
return nil
105+
}
106+
107+
colsLen := len(stmt.Cols.Items)
108+
valsLen := len(sel.ValuesLists[0])
109+
switch {
110+
case colsLen > valsLen:
111+
return pg.Error{
112+
Code: "42601",
113+
Message: "INSERT has more target columns than expressions",
114+
}
115+
case colsLen < valsLen:
116+
return pg.Error{
117+
Code: "42601",
118+
Message: "INSERT has more expressions than target columns",
119+
}
120+
}
121+
return nil
122+
}

internal/dinosql/parser.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ func parseMetadata(t string) (string, string, error) {
317317
if !strings.HasPrefix(line, "-- name:") {
318318
continue
319319
}
320-
part := strings.Split(line, " ")
320+
part := strings.Split(strings.TrimSpace(line), " ")
321321
if len(part) == 2 {
322322
return "", "", fmt.Errorf("missing query type [':one', ':many', ':exec', ':execrows']: %s", line)
323323
}
@@ -371,10 +371,13 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
371371
if !ok {
372372
return nil, nil
373373
}
374-
switch raw.Stmt.(type) {
374+
switch n := raw.Stmt.(type) {
375375
case nodes.SelectStmt:
376376
case nodes.DeleteStmt:
377377
case nodes.InsertStmt:
378+
if err := validateInsertStmt(n); err != nil {
379+
return nil, err
380+
}
378381
case nodes.UpdateStmt:
379382
default:
380383
return nil, nil

internal/dinosql/query_test.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -867,12 +867,26 @@ func TestInvalidQueries(t *testing.T) {
867867
`,
868868
`query "InsertFoo" specifies parameter ":one" without containing a RETURNING clause`,
869869
},
870+
{
871+
`
872+
CREATE TABLE foo (bar text not null, baz text not null);
873+
INSERT INTO foo (bar, baz) VALUES ($1);
874+
`,
875+
`INSERT has more target columns than expressions`,
876+
},
877+
{
878+
`
879+
CREATE TABLE foo (bar text not null, baz text not null);
880+
INSERT INTO foo (bar) VALUES ($1, $2);
881+
`,
882+
`INSERT has more expressions than target columns`,
883+
},
870884
} {
871885
test := tc
872886
t.Run(strconv.Itoa(i), func(t *testing.T) {
873887
_, err := parseSQL(test.stmt)
874888
if err == nil {
875-
t.Errorf("expected err, got nil")
889+
t.Fatalf("expected err, got nil")
876890
}
877891
if diff := cmp.Diff(test.msg, err.Error()); diff != "" {
878892
t.Errorf("error message differs: \n%s", diff)

0 commit comments

Comments
 (0)