@@ -420,13 +420,7 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
420420 if err != nil {
421421 return nil , err
422422 }
423-
424- // TODO: Limit calls to sourceTables
425- tables , err := sourceTables (c , raw .Stmt )
426- if err != nil {
427- return nil , err
428- }
429- expanded , err := expand (raw , tables , rawSQL )
423+ expanded , err := expand (c , raw , rawSQL )
430424 if err != nil {
431425 return nil , err
432426 }
@@ -468,25 +462,65 @@ type edit struct {
468462 New string
469463}
470464
471- func expand (raw nodes. RawStmt , tables []core. Table , sql string ) (string , error ) {
465+ func expand (c core. Catalog , raw nodes. RawStmt , sql string ) (string , error ) {
472466 list := search (raw , func (node nodes.Node ) bool {
473- res , ok := node .(nodes. ResTarget )
474- if ! ok {
475- return false
476- }
477- ref , ok := res . Val .( nodes.ColumnRef )
478- if ! ok {
467+ switch node .(type ) {
468+ case nodes. DeleteStmt :
469+ case nodes. InsertStmt :
470+ case nodes. SelectStmt :
471+ case nodes.UpdateStmt :
472+ default :
479473 return false
480474 }
481- return HasStarRef ( ref )
475+ return true
482476 })
483477 if len (list .Items ) == 0 {
484478 return sql , nil
485479 }
486480 var edits []edit
487481 for _ , item := range list .Items {
488- res := item .(nodes.ResTarget )
489- ref := res .Val .(nodes.ColumnRef )
482+ edit , err := expandStmt (c , raw , item )
483+ if err != nil {
484+ return "" , err
485+ }
486+ edits = append (edits , edit ... )
487+ }
488+ return editQuery (sql , edits )
489+ }
490+
491+ func expandStmt (c core.Catalog , raw nodes.RawStmt , node nodes.Node ) ([]edit , error ) {
492+ tables , err := sourceTables (c , node )
493+ if err != nil {
494+ return nil , err
495+ }
496+
497+ var targets nodes.List
498+ switch n := node .(type ) {
499+ case nodes.DeleteStmt :
500+ targets = n .ReturningList
501+ case nodes.InsertStmt :
502+ targets = n .ReturningList
503+ case nodes.SelectStmt :
504+ targets = n .TargetList
505+ case nodes.UpdateStmt :
506+ targets = n .ReturningList
507+ default :
508+ return nil , fmt .Errorf ("outputColumns: unsupported node type: %T" , n )
509+ }
510+
511+ var edits []edit
512+ for _ , target := range targets .Items {
513+ res , ok := target .(nodes.ResTarget )
514+ if ! ok {
515+ continue
516+ }
517+ ref , ok := res .Val .(nodes.ColumnRef )
518+ if ! ok {
519+ continue
520+ }
521+ if ! HasStarRef (ref ) {
522+ continue
523+ }
490524 var parts , cols []string
491525 for _ , f := range ref .Fields .Items {
492526 switch field := f .(type ) {
@@ -495,7 +529,7 @@ func expand(raw nodes.RawStmt, tables []core.Table, sql string) (string, error)
495529 case nodes.A_Star :
496530 parts = append (parts , "*" )
497531 default :
498- return "" , fmt .Errorf ("unknown field in ColumnRef: %T" , f )
532+ return nil , fmt .Errorf ("unknown field in ColumnRef: %T" , f )
499533 }
500534 }
501535 for _ , t := range tables {
@@ -520,7 +554,7 @@ func expand(raw nodes.RawStmt, tables []core.Table, sql string) (string, error)
520554 New : strings .Join (cols , ", " ),
521555 })
522556 }
523- return editQuery ( sql , edits )
557+ return edits , nil
524558}
525559
526560func editQuery (raw string , a []edit ) (string , error ) {
0 commit comments