Skip to content

Commit 9b7cd7d

Browse files
author
Abdelkarim Boujida
committed
Make it possible to grant privileges on all tables of a schema + bug fixes
1 parent 92e2679 commit 9b7cd7d

2 files changed

Lines changed: 67 additions & 37 deletions

File tree

internal/provider/resource_grant_role.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ func (r *roleGrantResource) Read(ctx context.Context, req resource.ReadRequest,
136136
if err != nil {
137137
resp.Diagnostics.AddError(
138138
"Error reading grant role",
139-
"Unable to connect to database, unexpected error: "+err.Error(),
139+
"Unable to read grant role, unexpected error: "+err.Error(),
140140
)
141141
return
142142
}

internal/provider/resource_grant_table.go

Lines changed: 66 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package provider
33
import (
44
"context"
55
"fmt"
6+
"reflect"
67
"regexp"
78
"strings"
89

@@ -35,20 +36,18 @@ type tableGrantResourceModel struct {
3536
Role types.String `tfsdk:"role"`
3637
}
3738

38-
func (t *tableGrantResourceModel) hasAllPrivileges() bool {
39-
for _, priv := range t.Privileges {
40-
if priv.Privilege.ValueString() == "ALL" {
41-
return true
42-
}
43-
}
44-
return false
45-
}
46-
4739
type tablePrivilegeModel struct {
4840
Privilege types.String `tfsdk:"privilege"`
4941
WithGrantOption types.Bool `tfsdk:"with_grant_option"`
5042
}
5143

44+
func (t *tableGrantResourceModel) isAllPrivilegesWithGrantOption() (bool, bool) {
45+
if len(t.Privileges) != 1 {
46+
return false, false
47+
}
48+
return t.Privileges[0].Privilege.ValueString() == "ALL", t.Privileges[0].WithGrantOption.ValueBool()
49+
}
50+
5251
func newTableGrantResource() resource.Resource {
5352
return &tableGrantResource{}
5453
}
@@ -86,7 +85,7 @@ func (r *tableGrantResource) Schema(_ context.Context, _ resource.SchemaRequest,
8685
stringplanmodifier.RequiresReplace(),
8786
},
8887
},
89-
"table": schema.StringAttribute{ // make table property safe, because otherwise someone can put like "table1, table2" and the grant will work
88+
"table": schema.StringAttribute{ // TODO - make table property safe, because otherwise someone can put like "table1, table2" and the grant will work
9089
Description: "The table on which the privileges will be granted for this role.",
9190
MarkdownDescription: "The table on which the privileges will be granted for this role.",
9291
Required: true,
@@ -168,9 +167,9 @@ func (r *tableGrantResource) Create(ctx context.Context, req resource.CreateRequ
168167
}
169168

170169
tablePlaceholder := "TABLE " + schema + "." + table
171-
// if table == "*" { // will not support granting on all tables in a schema, because it will be difficult to see the difference otherwise later
172-
// tablePlaceholder = "ALL TABLES IN SCHEMA " + schema
173-
// }
170+
if table == "*" {
171+
tablePlaceholder = "ALL TABLES IN SCHEMA " + schema
172+
}
174173

175174
if len(privilegesGrant) > 0 {
176175
sqlStatement := fmt.Sprintf("GRANT %s ON %s TO %s WITH GRANT OPTION", strings.Join(privilegesGrant, ", "), tablePlaceholder, role)
@@ -213,6 +212,7 @@ func (r *tableGrantResource) Read(ctx context.Context, req resource.ReadRequest,
213212
}
214213

215214
table := state.Table.ValueString()
215+
isAllTables := table == "*"
216216
schema := state.Schema.ValueString()
217217
database := state.Database.ValueString()
218218
role := state.Role.ValueString()
@@ -226,16 +226,13 @@ func (r *tableGrantResource) Read(ctx context.Context, req resource.ReadRequest,
226226
return
227227
}
228228

229-
oid, err := fetchOidForRole(ctx, db, role)
230-
if err != nil {
231-
resp.Diagnostics.AddError(
232-
"Error reading table grant",
233-
"Unable to fetch oid for the role '"+role+"', unexpected error: "+err.Error(),
234-
)
235-
return
236-
}
237-
238-
rows, err := db.QueryContext(ctx, "SELECT privilege_type, is_grantable FROM (select (aclexplode(c.relacl)).* from pg_catalog.pg_class as c left join pg_catalog.pg_namespace as n on n.oid = c.relnamespace where n.nspname = $1 and c.relname = $2 and c.relkind = 'r') as acl WHERE grantee = $3", schema, table, oid)
229+
sqlStatement := `select acls.relname, acls.privilege_type, acls.is_grantable
230+
from (
231+
select relname, (aclexplode(relacl)).* from pg_catalog.pg_class as c
232+
where relnamespace = $1::regnamespace
233+
) as acls
234+
WHERE acls.grantee = $2::regrole`
235+
rows, err := db.QueryContext(ctx, sqlStatement, schema, role)
239236
if err != nil {
240237
resp.Diagnostics.AddError(
241238
"Error reading table grant",
@@ -244,33 +241,61 @@ func (r *tableGrantResource) Read(ctx context.Context, req resource.ReadRequest,
244241
return
245242
}
246243

247-
var privileges []tablePrivilegeModel
244+
privileges := make(map[string][]tablePrivilegeModel)
248245
for rows.Next() {
246+
var relname string
249247
var privilege string
250248
var isGrantable bool
251-
err = rows.Scan(&privilege, &isGrantable)
249+
err = rows.Scan(&relname, &privilege, &isGrantable)
252250
if err != nil {
253251
resp.Diagnostics.AddError(
254252
"Error reading table grant",
255253
"Unable to read privileges for '"+role+"' on table '"+schema+"."+table+"', unexpected error: "+err.Error(),
256254
)
257255
return
258256
}
259-
privileges = append(privileges, tablePrivilegeModel{
257+
258+
if !isAllTables && relname != table {
259+
continue // skip
260+
}
261+
262+
privileges[relname] = append(privileges[relname], tablePrivilegeModel{
260263
Privilege: types.StringValue(privilege),
261264
WithGrantOption: types.BoolValue(isGrantable),
262265
})
263266
}
264267

265-
if state.hasAllPrivileges() && containsAllPrivileges(privileges) {
266-
allPrivileges := []tablePrivilegeModel{}
267-
allPrivileges = append(allPrivileges, tablePrivilegeModel{
268-
Privilege: types.StringValue("ALL"),
269-
WithGrantOption: privileges[0].WithGrantOption,
270-
})
271-
state.Privileges = allPrivileges
272-
} else {
273-
state.Privileges = privileges
268+
for _, priv := range privileges {
269+
allPriv, grantOption := state.isAllPrivilegesWithGrantOption()
270+
271+
if allPriv {
272+
eq := true
273+
for _, p := range priv {
274+
if p.WithGrantOption.ValueBool() != grantOption {
275+
state.Privileges = priv
276+
eq = false
277+
break
278+
}
279+
}
280+
if !eq {
281+
break
282+
}
283+
284+
if containsAllPrivileges(priv) {
285+
allPrivileges := []tablePrivilegeModel{}
286+
allPrivileges = append(allPrivileges, tablePrivilegeModel{
287+
Privilege: types.StringValue("ALL"),
288+
WithGrantOption: priv[0].WithGrantOption,
289+
})
290+
state.Privileges = allPrivileges
291+
continue
292+
}
293+
}
294+
295+
state.Privileges = priv
296+
if !reflect.DeepEqual(state.Privileges, priv) { // Be sure to show the changes of the one table that has different privileges
297+
break
298+
}
274299
}
275300

276301
diags = resp.State.Set(ctx, &state)
@@ -312,7 +337,12 @@ func (r *tableGrantResource) Delete(ctx context.Context, req resource.DeleteRequ
312337
privileges = append(privileges, priv.Privilege.ValueString())
313338
}
314339

315-
sqlStatement := fmt.Sprintf("REVOKE %s ON TABLE %s.%s FROM %s", strings.Join(privileges, ", "), schema, table, role)
340+
tablePlaceholder := "TABLE " + schema + "." + table
341+
if table == "*" {
342+
tablePlaceholder = "ALL TABLES IN SCHEMA " + schema
343+
}
344+
345+
sqlStatement := fmt.Sprintf("REVOKE %s ON %s FROM %s", strings.Join(privileges, ", "), tablePlaceholder, role)
316346

317347
_, err = db.ExecContext(ctx, sqlStatement)
318348
if err != nil {

0 commit comments

Comments
 (0)