@@ -3,6 +3,7 @@ package provider
33import (
44 "context"
55 "fmt"
6+ "reflect"
67 "regexp"
78 "strings"
89
@@ -35,20 +36,18 @@ type tableGrantResourceModel struct {
3536 Role types.String `tfsdk:"role"`
3637}
3738
38- func (t * tableGrantResourceModel ) hasAllPrivileges () bool {
39- for _ , priv := range t .Privileges {
40- if priv .Privilege .ValueString () == "ALL" {
41- return true
42- }
43- }
44- return false
45- }
46-
4739type tablePrivilegeModel struct {
4840 Privilege types.String `tfsdk:"privilege"`
4941 WithGrantOption types.Bool `tfsdk:"with_grant_option"`
5042}
5143
44+ func (t * tableGrantResourceModel ) isAllPrivilegesWithGrantOption () (bool , bool ) {
45+ if len (t .Privileges ) != 1 {
46+ return false , false
47+ }
48+ return t .Privileges [0 ].Privilege .ValueString () == "ALL" , t .Privileges [0 ].WithGrantOption .ValueBool ()
49+ }
50+
5251func newTableGrantResource () resource.Resource {
5352 return & tableGrantResource {}
5453}
@@ -86,7 +85,7 @@ func (r *tableGrantResource) Schema(_ context.Context, _ resource.SchemaRequest,
8685 stringplanmodifier .RequiresReplace (),
8786 },
8887 },
89- "table" : schema.StringAttribute { // make table property safe, because otherwise someone can put like "table1, table2" and the grant will work
88+ "table" : schema.StringAttribute { // TODO - make table property safe, because otherwise someone can put like "table1, table2" and the grant will work
9089 Description : "The table on which the privileges will be granted for this role." ,
9190 MarkdownDescription : "The table on which the privileges will be granted for this role." ,
9291 Required : true ,
@@ -168,9 +167,9 @@ func (r *tableGrantResource) Create(ctx context.Context, req resource.CreateRequ
168167 }
169168
170169 tablePlaceholder := "TABLE " + schema + "." + table
171- // if table == "*" { // will not support granting on all tables in a schema, because it will be difficult to see the difference otherwise later
172- // tablePlaceholder = "ALL TABLES IN SCHEMA " + schema
173- // }
170+ if table == "*" {
171+ tablePlaceholder = "ALL TABLES IN SCHEMA " + schema
172+ }
174173
175174 if len (privilegesGrant ) > 0 {
176175 sqlStatement := fmt .Sprintf ("GRANT %s ON %s TO %s WITH GRANT OPTION" , strings .Join (privilegesGrant , ", " ), tablePlaceholder , role )
@@ -213,6 +212,7 @@ func (r *tableGrantResource) Read(ctx context.Context, req resource.ReadRequest,
213212 }
214213
215214 table := state .Table .ValueString ()
215+ isAllTables := table == "*"
216216 schema := state .Schema .ValueString ()
217217 database := state .Database .ValueString ()
218218 role := state .Role .ValueString ()
@@ -226,16 +226,13 @@ func (r *tableGrantResource) Read(ctx context.Context, req resource.ReadRequest,
226226 return
227227 }
228228
229- oid , err := fetchOidForRole (ctx , db , role )
230- if err != nil {
231- resp .Diagnostics .AddError (
232- "Error reading table grant" ,
233- "Unable to fetch oid for the role '" + role + "', unexpected error: " + err .Error (),
234- )
235- return
236- }
237-
238- rows , err := db .QueryContext (ctx , "SELECT privilege_type, is_grantable FROM (select (aclexplode(c.relacl)).* from pg_catalog.pg_class as c left join pg_catalog.pg_namespace as n on n.oid = c.relnamespace where n.nspname = $1 and c.relname = $2 and c.relkind = 'r') as acl WHERE grantee = $3" , schema , table , oid )
229+ sqlStatement := `select acls.relname, acls.privilege_type, acls.is_grantable
230+ from (
231+ select relname, (aclexplode(relacl)).* from pg_catalog.pg_class as c
232+ where relnamespace = $1::regnamespace
233+ ) as acls
234+ WHERE acls.grantee = $2::regrole`
235+ rows , err := db .QueryContext (ctx , sqlStatement , schema , role )
239236 if err != nil {
240237 resp .Diagnostics .AddError (
241238 "Error reading table grant" ,
@@ -244,33 +241,61 @@ func (r *tableGrantResource) Read(ctx context.Context, req resource.ReadRequest,
244241 return
245242 }
246243
247- var privileges [] tablePrivilegeModel
244+ privileges := make ( map [ string ][] tablePrivilegeModel )
248245 for rows .Next () {
246+ var relname string
249247 var privilege string
250248 var isGrantable bool
251- err = rows .Scan (& privilege , & isGrantable )
249+ err = rows .Scan (& relname , & privilege , & isGrantable )
252250 if err != nil {
253251 resp .Diagnostics .AddError (
254252 "Error reading table grant" ,
255253 "Unable to read privileges for '" + role + "' on table '" + schema + "." + table + "', unexpected error: " + err .Error (),
256254 )
257255 return
258256 }
259- privileges = append (privileges , tablePrivilegeModel {
257+
258+ if ! isAllTables && relname != table {
259+ continue // skip
260+ }
261+
262+ privileges [relname ] = append (privileges [relname ], tablePrivilegeModel {
260263 Privilege : types .StringValue (privilege ),
261264 WithGrantOption : types .BoolValue (isGrantable ),
262265 })
263266 }
264267
265- if state .hasAllPrivileges () && containsAllPrivileges (privileges ) {
266- allPrivileges := []tablePrivilegeModel {}
267- allPrivileges = append (allPrivileges , tablePrivilegeModel {
268- Privilege : types .StringValue ("ALL" ),
269- WithGrantOption : privileges [0 ].WithGrantOption ,
270- })
271- state .Privileges = allPrivileges
272- } else {
273- state .Privileges = privileges
268+ for _ , priv := range privileges {
269+ allPriv , grantOption := state .isAllPrivilegesWithGrantOption ()
270+
271+ if allPriv {
272+ eq := true
273+ for _ , p := range priv {
274+ if p .WithGrantOption .ValueBool () != grantOption {
275+ state .Privileges = priv
276+ eq = false
277+ break
278+ }
279+ }
280+ if ! eq {
281+ break
282+ }
283+
284+ if containsAllPrivileges (priv ) {
285+ allPrivileges := []tablePrivilegeModel {}
286+ allPrivileges = append (allPrivileges , tablePrivilegeModel {
287+ Privilege : types .StringValue ("ALL" ),
288+ WithGrantOption : priv [0 ].WithGrantOption ,
289+ })
290+ state .Privileges = allPrivileges
291+ continue
292+ }
293+ }
294+
295+ state .Privileges = priv
296+ if ! reflect .DeepEqual (state .Privileges , priv ) { // Be sure to show the changes of the one table that has different privileges
297+ break
298+ }
274299 }
275300
276301 diags = resp .State .Set (ctx , & state )
@@ -312,7 +337,12 @@ func (r *tableGrantResource) Delete(ctx context.Context, req resource.DeleteRequ
312337 privileges = append (privileges , priv .Privilege .ValueString ())
313338 }
314339
315- sqlStatement := fmt .Sprintf ("REVOKE %s ON TABLE %s.%s FROM %s" , strings .Join (privileges , ", " ), schema , table , role )
340+ tablePlaceholder := "TABLE " + schema + "." + table
341+ if table == "*" {
342+ tablePlaceholder = "ALL TABLES IN SCHEMA " + schema
343+ }
344+
345+ sqlStatement := fmt .Sprintf ("REVOKE %s ON %s FROM %s" , strings .Join (privileges , ", " ), tablePlaceholder , role )
316346
317347 _ , err = db .ExecContext (ctx , sqlStatement )
318348 if err != nil {
0 commit comments