diff --git a/identity/handler_test.go b/identity/handler_test.go index 12bbde6cfa80..85473ab4c4fa 100644 --- a/identity/handler_test.go +++ b/identity/handler_test.go @@ -754,7 +754,7 @@ func TestHandler(t *testing.T) { }) t.Run("suite=PATCH identities", func(t *testing.T) { - t.Run("case=fails on > 100 identities", func(t *testing.T) { + t.Run("case=fails with too many patches", func(t *testing.T) { tooMany := make([]*identity.BatchIdentityPatch, identity.BatchPatchIdentitiesLimit+1) for i := range tooMany { tooMany[i] = &identity.BatchIdentityPatch{Create: validCreateIdentityBody("too-many-patches", i)} @@ -767,8 +767,8 @@ func TestHandler(t *testing.T) { t.Run("case=fails some on a bad identity", func(t *testing.T) { // Test setup: we have a list of valid identitiy patches and a list of invalid ones. // Each run adds one invalid patch to the list and sends it to the server. - // --> we expect the server to fail all patches in the list. - // Finally, we send just the valid patches + // --> we expect the server to fail only the bad patches in the list. + // Finally, we send just valid patches // --> we expect the server to succeed all patches in the list. t.Run("case=invalid patches fail", func(t *testing.T) { @@ -782,24 +782,23 @@ func TestHandler(t *testing.T) { {Create: &identity.CreateIdentityBody{Traits: json.RawMessage(`"invalid traits"`)}}, // <-- invalid traits {Create: validCreateIdentityBody("valid", 4)}, } + expectedToPass := []*identity.BatchIdentityPatch{patches[0], patches[1], patches[3], patches[5], patches[7]} // Create unique IDs for each patch - var patchIDs []string + patchIDs := make([]string, len(patches)) for i, p := range patches { id := uuid.NewV5(uuid.Nil, fmt.Sprintf("%d", i)) p.ID = &id - patchIDs = append(patchIDs, id.String()) + patchIDs[i] = id.String() } req := &identity.BatchPatchIdentitiesBody{Identities: patches} body := send(t, adminTS, "PATCH", "/identities", http.StatusOK, req) var actions []string - for _, a := range body.Get("identities.#.action").Array() { - actions = append(actions, a.String()) - } - assert.Equal(t, + require.NoErrorf(t, json.Unmarshal(([]byte)(body.Get("identities.#.action").Raw), &actions), "%s", body) + assert.Equalf(t, []string{"create", "create", "error", "create", "error", "create", "error", "create"}, - actions, body) + actions, "%s", body) // Check that all patch IDs are returned for i, gotPatchID := range body.Get("identities.#.patch_id").Array() { @@ -811,6 +810,37 @@ func TestHandler(t *testing.T) { assert.Equal(t, "Conflict", body.Get("identities.4.error.status").String()) assert.Equal(t, "Bad Request", body.Get("identities.6.error.status").String()) + // Only collect identity IDs from successful patches. Error entries have + // no "identity" field, so iterating the full array and unmarshaling + // would yield zero UUIDs for those positions and corrupt the slice. + var identityIDs []uuid.UUID + for _, item := range body.Get("identities").Array() { + if item.Get("action").String() == string(identity.ActionCreate) { + id := uuid.FromStringOrNil(item.Get("identity").String()) + require.NotZerof(t, id, "expected non-zero UUID for create action: %s", body) + identityIDs = append(identityIDs, id) + } + } + require.Lenf(t, identityIDs, len(expectedToPass), "%s", body) + + actualIdentities, _, err := reg.Persister().ListIdentities(ctx, identity.ListIdentityParameters{IdsFilter: identityIDs}) + require.NoError(t, err) + actualIdentityIDs := make([]uuid.UUID, len(actualIdentities)) + for i, id := range actualIdentities { + actualIdentityIDs[i] = id.ID + } + assert.ElementsMatchf(t, identityIDs, actualIdentityIDs, "%s", body) + + expectedTraits := make(map[string]string, len(expectedToPass)) + for i, p := range expectedToPass { + expectedTraits[identityIDs[i].String()] = string(p.Create.Traits) + } + actualTraits := make(map[string]string, len(actualIdentities)) + for _, id := range actualIdentities { + actualTraits[id.ID.String()] = string(id.Traits) + } + + assert.Equal(t, expectedTraits, actualTraits) }) t.Run("valid patches succeed", func(t *testing.T) { @@ -1890,7 +1920,7 @@ func validCreateIdentityBody(prefix string, i int) *identity.CreateIdentityBody identity.VerifiableAddressStatusCompleted, } - for j := 0; j < 4; j++ { + for j := range 4 { email := fmt.Sprintf("%s-%d-%d@ory.sh", prefix, i, j) traits.Emails = append(traits.Emails, email) verifiableAddresses = append(verifiableAddresses, identity.VerifiableAddress{ diff --git a/identity/manager.go b/identity/manager.go index 3bc5b08e0158..a09a08a778cd 100644 --- a/identity/manager.go +++ b/identity/manager.go @@ -333,6 +333,12 @@ type CreateIdentitiesError struct { failedIdentities map[*Identity]*herodot.DefaultError } +func NewCreateIdentitiesError(capacity int) *CreateIdentitiesError { + return &CreateIdentitiesError{ + failedIdentities: make(map[*Identity]*herodot.DefaultError, capacity), + } +} + func (e *CreateIdentitiesError) Error() string { e.init() return fmt.Sprintf("create identities error: %d identities failed", len(e.failedIdentities)) @@ -370,7 +376,7 @@ func (e *CreateIdentitiesError) Find(ident *Identity) *FailedIdentity { return nil } func (e *CreateIdentitiesError) ErrOrNil() error { - if e.failedIdentities == nil || len(e.failedIdentities) == 0 { + if e == nil || len(e.failedIdentities) == 0 { return nil } return e @@ -385,7 +391,7 @@ func (m *Manager) CreateIdentities(ctx context.Context, identities []*Identity, ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.CreateIdentities") defer otelx.End(span, &err) - createIdentitiesError := &CreateIdentitiesError{} + createIdentitiesError := NewCreateIdentitiesError(len(identities)) validIdentities := make([]*Identity, 0, len(identities)) for _, ident := range identities { if ident.SchemaID == "" { diff --git a/identity/test/pool.go b/identity/test/pool.go index 4d9f4c440910..99f92813831e 100644 --- a/identity/test/pool.go +++ b/identity/test/pool.go @@ -351,12 +351,80 @@ func TestPool(ctx context.Context, p persistence.Persister, m *identity.Manager, assert.Equal(t, id.Credentials["password"].Identifiers, credFromDB.Identifiers) assert.WithinDuration(t, time.Now().UTC(), credFromDB.CreatedAt, time.Minute) assert.WithinDuration(t, time.Now().UTC(), credFromDB.UpdatedAt, time.Minute) + // because of mysql precision assert.WithinDuration(t, id.CreatedAt, idFromDB.CreatedAt, time.Second) assert.WithinDuration(t, id.UpdatedAt, idFromDB.UpdatedAt, time.Second) require.NoError(t, p.DeleteIdentity(ctx, id.ID)) } }) + + t.Run("create exactly the non-conflicting ones", func(t *testing.T) { + // Pre-insert 40 identities (indices 0–39) so the conflicts in the + // batch below are against rows already in the DB. Without this, + // intra-batch conflict resolution is non-deterministic — either side + // of each duplicate pair could win. + preExisting := make([]*identity.Identity, 40) + for i := range preExisting { + preExisting[i] = NewTestIdentity(4, "persister-create-multiple-2", i) + } + require.NoError(t, p.CreateIdentities(ctx, preExisting...)) + defer func() { + for _, id := range preExisting { + require.NoError(t, p.DeleteIdentity(ctx, id.ID)) + } + }() + + // First 60 use indices 100–159 (no conflicts); last 40 duplicate the + // pre-existing indices 0–39 and are guaranteed to fail. + identities := make([]*identity.Identity, 100) + for i := range identities[:60] { + identities[i] = NewTestIdentity(4, "persister-create-multiple-2", 100+i) + } + for i := range identities[60:] { + identities[60+i] = NewTestIdentity(4, "persister-create-multiple-2", i) + } + err := p.CreateIdentities(ctx, identities...) + if dbname == "mysql" { + // partial inserts are not supported on mysql + assert.ErrorIs(t, err, sqlcon.ErrUniqueViolation) + return + } + + errWithCtx := new(identity.CreateIdentitiesError) + require.ErrorAsf(t, err, &errWithCtx, "%#v", err) + + for _, id := range identities[:60] { + require.NotZero(t, id.ID) + + idFromDB, err := p.GetIdentity(ctx, id.ID, identity.ExpandEverything) + require.NoError(t, err) + + credFromDB := idFromDB.Credentials[identity.CredentialsTypePassword] + assert.Equal(t, id.ID, idFromDB.ID) + assert.Equal(t, id.SchemaID, idFromDB.SchemaID) + assert.Equal(t, id.SchemaURL, idFromDB.SchemaURL) + assert.Equal(t, id.State, idFromDB.State) + + // We test that the values are plausible in the handler test already. + assert.Equal(t, len(id.VerifiableAddresses), len(idFromDB.VerifiableAddresses)) + assert.Equal(t, len(id.RecoveryAddresses), len(idFromDB.RecoveryAddresses)) + + assert.Equal(t, id.Credentials["password"].Identifiers, credFromDB.Identifiers) + assert.WithinDuration(t, time.Now().UTC(), credFromDB.CreatedAt, time.Minute) + assert.WithinDuration(t, time.Now().UTC(), credFromDB.UpdatedAt, time.Minute) + // because of mysql precision + assert.WithinDuration(t, id.CreatedAt, idFromDB.CreatedAt, time.Second) + assert.WithinDuration(t, id.UpdatedAt, idFromDB.UpdatedAt, time.Second) + + require.NoError(t, p.DeleteIdentity(ctx, id.ID)) + } + + for _, id := range identities[60:] { + failed := errWithCtx.Find(id) + assert.NotNil(t, failed) + } + }) }) t.Run("case=should error when the identity ID does not exist", func(t *testing.T) { diff --git a/internal/client-go/go.sum b/internal/client-go/go.sum index c966c8ddfd0d..6cc3f5911d11 100644 --- a/internal/client-go/go.sum +++ b/internal/client-go/go.sum @@ -4,6 +4,7 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index 4523ce7f2146..4aa2bf596e45 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -593,7 +593,8 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... } }() - return p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error { + var partialErr *identity.CreateIdentitiesError + if err := p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error { conn := &batch.TracerConnection{ Tracer: p.r.Tracer(ctx), Connection: tx, @@ -601,6 +602,7 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... succeededIDs = make([]uuid.UUID, 0, len(identities)) failedIdentityIDs := make(map[uuid.UUID]struct{}) + partialErr = nil // Don't use batch.WithPartialInserts, because identities have no other // constraints other than the primary key that could cause conflicts. @@ -653,7 +655,7 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... // If any of the batch inserts failed on conflict, let's delete the corresponding // identities and return a list of failed identities in the error. if len(failedIdentityIDs) > 0 { - partialErr := &identity.CreateIdentitiesError{} + partialErr = identity.NewCreateIdentitiesError(len(failedIdentityIDs)) failedIDs := make([]uuid.UUID, 0, len(failedIdentityIDs)) for _, ident := range identities { @@ -667,10 +669,14 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... // Manually roll back by deleting the identities that were inserted before the // error occurred. if err := p.DeleteIdentities(ctx, failedIDs); err != nil { - return sqlcon.HandleError(err) + // If cleanup fails (e.g. transient DB error), log and still commit + // the successful inserts rather than rolling back everything. + // The orphaned identity records may need manual cleanup. + p.r.Logger().WithError(sqlcon.HandleError(err)). + Error("Failed to delete conflicting identities during batch create; orphaned records may remain") } - return partialErr + return nil } else { // No failures: report all identities as created. for _, ident := range identities { @@ -679,7 +685,10 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... } return nil - }) + }); err != nil { + return err + } + return partialErr.ErrOrNil() } func (p *IdentityPersister) HydrateIdentityAssociations(ctx context.Context, i *identity.Identity, expand identity.Expandables) (err error) {