Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
45 changes: 25 additions & 20 deletions cmd/cone/form_fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,19 @@ import (
"github.com/conductorone/cone/pkg/output"
)

// formFields returns the field definitions for a form, or nil if the form is empty.
func formFields(form *shared.RequestSchemaForm) []shared.FormField {
if form == nil {
return nil
}
return form.Fields
}

// collectFormFields collects form field values from the user based on the form definition.
// Returns a map of field names to their values, or nil if no form fields are present.
func collectFormFields(ctx context.Context, v *viper.Viper, form *shared.FormInput) (map[string]any, error) {
if form == nil || len(form.Fields) == 0 {
func collectFormFields(ctx context.Context, v *viper.Viper, form *shared.RequestSchemaForm) (map[string]any, error) {
fields := formFields(form)
if len(fields) == 0 {
return nil, nil
}

Expand All @@ -34,7 +43,7 @@ func collectFormFields(ctx context.Context, v *viper.Viper, form *shared.FormInp
return nil, err
}

for _, field := range form.Fields {
for _, field := range fields {
fieldName := client.StringFromPtr(field.Name)
if fieldName == "" {
continue
Expand Down Expand Up @@ -81,7 +90,7 @@ func collectFormFields(ctx context.Context, v *viper.Viper, form *shared.FormInp
}

// collectFieldValue collects a single field value from the user based on field type.
func collectFieldValue(ctx context.Context, field shared.FieldInput, displayName, description string) (any, error) {
func collectFieldValue(ctx context.Context, field shared.FormField, displayName, description string) (any, error) {
// Check for default value first
if defaultValue := getFieldDefaultValue(field); defaultValue != nil {
// Show default value and ask for confirmation
Expand All @@ -100,8 +109,8 @@ func collectFieldValue(ctx context.Context, field shared.FieldInput, displayName

// Collect based on field type
switch {
case field.StringField != nil:
return collectStringField(ctx, field.StringField, displayName, description)
case field.FormStringField != nil:
return collectStringField(ctx, field.FormStringField, displayName, description)
case field.BoolField != nil:
return collectBoolField(ctx, field.BoolField, displayName, description)
case field.Int64Field != nil:
Expand All @@ -116,7 +125,7 @@ func collectFieldValue(ctx context.Context, field shared.FieldInput, displayName
}

// collectStringField collects a string field value with validation.
func collectStringField(ctx context.Context, field *shared.StringField, displayName, description string) (string, error) {
func collectStringField(ctx context.Context, field *shared.FormStringField, displayName, description string) (string, error) {
validator := StringFieldValidator{
field: field,
displayName: displayName,
Expand Down Expand Up @@ -235,10 +244,10 @@ func collectStringSliceField(ctx context.Context, field *shared.StringSliceField
}

// getFieldDefaultValue extracts the default value from a field based on its type.
func getFieldDefaultValue(field shared.FieldInput) any {
func getFieldDefaultValue(field shared.FormField) any {
switch {
case field.StringField != nil && field.StringField.DefaultValue != nil:
return *field.StringField.DefaultValue
case field.FormStringField != nil && field.FormStringField.DefaultValue != nil:
return *field.FormStringField.DefaultValue
case field.BoolField != nil && field.BoolField.DefaultValue != nil:
return *field.BoolField.DefaultValue
case field.Int64Field != nil && field.Int64Field.DefaultValue != nil:
Expand Down Expand Up @@ -270,7 +279,7 @@ func parseFormDataFlag(formDataFlag string) (map[string]any, error) {

// StringFieldValidator validates string field input.
type StringFieldValidator struct {
field *shared.StringField
field *shared.FormStringField
displayName string
description string
}
Expand Down Expand Up @@ -385,10 +394,10 @@ func (v Int64FieldValidator) Prompt(isFirstRun bool) {
}

// isFieldRequired checks if a field is required based on its validation rules.
func isFieldRequired(field shared.FieldInput) bool {
func isFieldRequired(field shared.FormField) bool {
switch {
case field.StringField != nil:
rules := field.StringField.StringRules
case field.FormStringField != nil:
rules := field.FormStringField.StringRules
if rules == nil {
return false
}
Expand Down Expand Up @@ -420,12 +429,8 @@ func isFieldRequired(field shared.FieldInput) bool {
}

// validateFormData validates that all required form fields are present.
func validateFormData(form *shared.FormInput, requestData map[string]any) error {
if form == nil {
return nil
}

for _, field := range form.Fields {
func validateFormData(form *shared.RequestSchemaForm, requestData map[string]any) error {
for _, field := range formFields(form) {
fieldName := client.StringFromPtr(field.Name)
if fieldName == "" {
continue
Expand Down
9 changes: 3 additions & 6 deletions cmd/cone/get_drop_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,15 +271,12 @@ func runGet(cmd *cobra.Command, args []string) error {
task := accessRequest.TaskView.Task

// Check if the task has form fields
hasFormFields := task.Form != nil
if hasFormFields {
hasFormFields = len(task.Form.Fields) > 0
}
hasFormFields := len(formFields(task.RequestSchemaForm)) > 0

if hasFormFields {
// Collect form fields if not already provided
if len(requestData) == 0 {
collectedData, err := collectFormFields(ctx, v, task.Form)
collectedData, err := collectFormFields(ctx, v, task.RequestSchemaForm)
if err != nil {
return nil, fmt.Errorf("error collecting form fields: %w", err)
}
Expand All @@ -293,7 +290,7 @@ func runGet(cmd *cobra.Command, args []string) error {
}
} else {
// Validate that provided form data matches the form structure
if err := validateFormData(task.Form, requestData); err != nil {
if err := validateFormData(task.RequestSchemaForm, requestData); err != nil {
pterm.Warning.Printf("Form data validation warning: %v\n", err)
}
}
Expand Down
1 change: 1 addition & 0 deletions cmd/cone/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ func runCli(ctx context.Context) int {
cliCmd.AddCommand(virtualEntitlementsCmd())
cliCmd.AddCommand(generateAliasCmd())
cliCmd.AddCommand(awsCmd())
cliCmd.AddCommand(secretCmd())

err = cliCmd.ExecuteContext(ctx)
if err != nil {
Expand Down
Loading
Loading