Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 41 additions & 11 deletions identity/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ func TestHandler(t *testing.T) {
})

t.Run("suite=PATCH identities", func(t *testing.T) {
t.Run("case=fails on > 100 identities", func(t *testing.T) {
t.Run("case=fails with too many patches", func(t *testing.T) {
tooMany := make([]*identity.BatchIdentityPatch, identity.BatchPatchIdentitiesLimit+1)
for i := range tooMany {
tooMany[i] = &identity.BatchIdentityPatch{Create: validCreateIdentityBody("too-many-patches", i)}
Expand All @@ -767,8 +767,8 @@ func TestHandler(t *testing.T) {
t.Run("case=fails some on a bad identity", func(t *testing.T) {
// Test setup: we have a list of valid identitiy patches and a list of invalid ones.
// Each run adds one invalid patch to the list and sends it to the server.
// --> we expect the server to fail all patches in the list.
// Finally, we send just the valid patches
// --> we expect the server to fail only the bad patches in the list.
// Finally, we send just valid patches
// --> we expect the server to succeed all patches in the list.

t.Run("case=invalid patches fail", func(t *testing.T) {
Expand All @@ -782,24 +782,23 @@ func TestHandler(t *testing.T) {
{Create: &identity.CreateIdentityBody{Traits: json.RawMessage(`"invalid traits"`)}}, // <-- invalid traits
{Create: validCreateIdentityBody("valid", 4)},
}
expectedToPass := []*identity.BatchIdentityPatch{patches[0], patches[1], patches[3], patches[5], patches[7]}

// Create unique IDs for each patch
var patchIDs []string
patchIDs := make([]string, len(patches))
for i, p := range patches {
id := uuid.NewV5(uuid.Nil, fmt.Sprintf("%d", i))
p.ID = &id
patchIDs = append(patchIDs, id.String())
patchIDs[i] = id.String()
}

req := &identity.BatchPatchIdentitiesBody{Identities: patches}
body := send(t, adminTS, "PATCH", "/identities", http.StatusOK, req)
var actions []string
for _, a := range body.Get("identities.#.action").Array() {
actions = append(actions, a.String())
}
assert.Equal(t,
require.NoErrorf(t, json.Unmarshal(([]byte)(body.Get("identities.#.action").Raw), &actions), "%s", body)
assert.Equalf(t,
[]string{"create", "create", "error", "create", "error", "create", "error", "create"},
actions, body)
actions, "%s", body)

// Check that all patch IDs are returned
for i, gotPatchID := range body.Get("identities.#.patch_id").Array() {
Expand All @@ -811,6 +810,37 @@ func TestHandler(t *testing.T) {
assert.Equal(t, "Conflict", body.Get("identities.4.error.status").String())
assert.Equal(t, "Bad Request", body.Get("identities.6.error.status").String())

// Only collect identity IDs from successful patches. Error entries have
// no "identity" field, so iterating the full array and unmarshaling
// would yield zero UUIDs for those positions and corrupt the slice.
var identityIDs []uuid.UUID
for _, item := range body.Get("identities").Array() {
if item.Get("action").String() == string(identity.ActionCreate) {
id := uuid.FromStringOrNil(item.Get("identity").String())
require.NotZerof(t, id, "expected non-zero UUID for create action: %s", body)
identityIDs = append(identityIDs, id)
}
}
require.Lenf(t, identityIDs, len(expectedToPass), "%s", body)

actualIdentities, _, err := reg.Persister().ListIdentities(ctx, identity.ListIdentityParameters{IdsFilter: identityIDs})
require.NoError(t, err)
actualIdentityIDs := make([]uuid.UUID, len(actualIdentities))
for i, id := range actualIdentities {
actualIdentityIDs[i] = id.ID
}
assert.ElementsMatchf(t, identityIDs, actualIdentityIDs, "%s", body)

expectedTraits := make(map[string]string, len(expectedToPass))
for i, p := range expectedToPass {
expectedTraits[identityIDs[i].String()] = string(p.Create.Traits)
}
actualTraits := make(map[string]string, len(actualIdentities))
for _, id := range actualIdentities {
actualTraits[id.ID.String()] = string(id.Traits)
}

assert.Equal(t, expectedTraits, actualTraits)
})

t.Run("valid patches succeed", func(t *testing.T) {
Expand Down Expand Up @@ -1890,7 +1920,7 @@ func validCreateIdentityBody(prefix string, i int) *identity.CreateIdentityBody
identity.VerifiableAddressStatusCompleted,
}

for j := 0; j < 4; j++ {
for j := range 4 {
email := fmt.Sprintf("%s-%d-%d@ory.sh", prefix, i, j)
traits.Emails = append(traits.Emails, email)
verifiableAddresses = append(verifiableAddresses, identity.VerifiableAddress{
Expand Down
10 changes: 8 additions & 2 deletions identity/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,12 @@ type CreateIdentitiesError struct {
failedIdentities map[*Identity]*herodot.DefaultError
}

func NewCreateIdentitiesError(capacity int) *CreateIdentitiesError {
return &CreateIdentitiesError{
failedIdentities: make(map[*Identity]*herodot.DefaultError, capacity),
}
}

func (e *CreateIdentitiesError) Error() string {
e.init()
return fmt.Sprintf("create identities error: %d identities failed", len(e.failedIdentities))
Expand Down Expand Up @@ -370,7 +376,7 @@ func (e *CreateIdentitiesError) Find(ident *Identity) *FailedIdentity {
return nil
}
func (e *CreateIdentitiesError) ErrOrNil() error {
if e.failedIdentities == nil || len(e.failedIdentities) == 0 {
if e == nil || len(e.failedIdentities) == 0 {
return nil
}
return e
Expand All @@ -385,7 +391,7 @@ func (m *Manager) CreateIdentities(ctx context.Context, identities []*Identity,
ctx, span := m.r.Tracer(ctx).Tracer().Start(ctx, "identity.Manager.CreateIdentities")
defer otelx.End(span, &err)

createIdentitiesError := &CreateIdentitiesError{}
createIdentitiesError := NewCreateIdentitiesError(len(identities))
validIdentities := make([]*Identity, 0, len(identities))
for _, ident := range identities {
if ident.SchemaID == "" {
Expand Down
68 changes: 68 additions & 0 deletions identity/test/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,12 +351,80 @@ func TestPool(ctx context.Context, p persistence.Persister, m *identity.Manager,
assert.Equal(t, id.Credentials["password"].Identifiers, credFromDB.Identifiers)
assert.WithinDuration(t, time.Now().UTC(), credFromDB.CreatedAt, time.Minute)
assert.WithinDuration(t, time.Now().UTC(), credFromDB.UpdatedAt, time.Minute)
// because of mysql precision
assert.WithinDuration(t, id.CreatedAt, idFromDB.CreatedAt, time.Second)
assert.WithinDuration(t, id.UpdatedAt, idFromDB.UpdatedAt, time.Second)

require.NoError(t, p.DeleteIdentity(ctx, id.ID))
}
})

t.Run("create exactly the non-conflicting ones", func(t *testing.T) {
// Pre-insert 40 identities (indices 0–39) so the conflicts in the
// batch below are against rows already in the DB. Without this,
// intra-batch conflict resolution is non-deterministic — either side
// of each duplicate pair could win.
preExisting := make([]*identity.Identity, 40)
for i := range preExisting {
preExisting[i] = NewTestIdentity(4, "persister-create-multiple-2", i)
}
require.NoError(t, p.CreateIdentities(ctx, preExisting...))
defer func() {
for _, id := range preExisting {
require.NoError(t, p.DeleteIdentity(ctx, id.ID))
}
}()

// First 60 use indices 100–159 (no conflicts); last 40 duplicate the
// pre-existing indices 0–39 and are guaranteed to fail.
identities := make([]*identity.Identity, 100)
for i := range identities[:60] {
identities[i] = NewTestIdentity(4, "persister-create-multiple-2", 100+i)
}
for i := range identities[60:] {
identities[60+i] = NewTestIdentity(4, "persister-create-multiple-2", i)
}
err := p.CreateIdentities(ctx, identities...)
if dbname == "mysql" {
// partial inserts are not supported on mysql
assert.ErrorIs(t, err, sqlcon.ErrUniqueViolation)
return
}

errWithCtx := new(identity.CreateIdentitiesError)
require.ErrorAsf(t, err, &errWithCtx, "%#v", err)

for _, id := range identities[:60] {
require.NotZero(t, id.ID)

idFromDB, err := p.GetIdentity(ctx, id.ID, identity.ExpandEverything)
require.NoError(t, err)

credFromDB := idFromDB.Credentials[identity.CredentialsTypePassword]
assert.Equal(t, id.ID, idFromDB.ID)
assert.Equal(t, id.SchemaID, idFromDB.SchemaID)
assert.Equal(t, id.SchemaURL, idFromDB.SchemaURL)
assert.Equal(t, id.State, idFromDB.State)

// We test that the values are plausible in the handler test already.
assert.Equal(t, len(id.VerifiableAddresses), len(idFromDB.VerifiableAddresses))
assert.Equal(t, len(id.RecoveryAddresses), len(idFromDB.RecoveryAddresses))

assert.Equal(t, id.Credentials["password"].Identifiers, credFromDB.Identifiers)
assert.WithinDuration(t, time.Now().UTC(), credFromDB.CreatedAt, time.Minute)
assert.WithinDuration(t, time.Now().UTC(), credFromDB.UpdatedAt, time.Minute)
// because of mysql precision
assert.WithinDuration(t, id.CreatedAt, idFromDB.CreatedAt, time.Second)
assert.WithinDuration(t, id.UpdatedAt, idFromDB.UpdatedAt, time.Second)

require.NoError(t, p.DeleteIdentity(ctx, id.ID))
}

for _, id := range identities[60:] {
failed := errWithCtx.Find(id)
assert.NotNil(t, failed)
}
})
})

t.Run("case=should error when the identity ID does not exist", func(t *testing.T) {
Expand Down
1 change: 1 addition & 0 deletions internal/client-go/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
Expand Down
19 changes: 14 additions & 5 deletions persistence/sql/identity/persister_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -593,14 +593,16 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ...
}
}()

return p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {
var partialErr *identity.CreateIdentitiesError
if err := p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {
conn := &batch.TracerConnection{
Tracer: p.r.Tracer(ctx),
Connection: tx,
}

succeededIDs = make([]uuid.UUID, 0, len(identities))
failedIdentityIDs := make(map[uuid.UUID]struct{})
partialErr = nil

// Don't use batch.WithPartialInserts, because identities have no other
// constraints other than the primary key that could cause conflicts.
Expand Down Expand Up @@ -653,7 +655,7 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ...
// If any of the batch inserts failed on conflict, let's delete the corresponding
// identities and return a list of failed identities in the error.
if len(failedIdentityIDs) > 0 {
partialErr := &identity.CreateIdentitiesError{}
partialErr = identity.NewCreateIdentitiesError(len(failedIdentityIDs))
failedIDs := make([]uuid.UUID, 0, len(failedIdentityIDs))

for _, ident := range identities {
Expand All @@ -667,10 +669,14 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ...
// Manually roll back by deleting the identities that were inserted before the
// error occurred.
if err := p.DeleteIdentities(ctx, failedIDs); err != nil {
return sqlcon.HandleError(err)
// If cleanup fails (e.g. transient DB error), log and still commit
// the successful inserts rather than rolling back everything.
// The orphaned identity records may need manual cleanup.
p.r.Logger().WithError(sqlcon.HandleError(err)).
Error("Failed to delete conflicting identities during batch create; orphaned records may remain")
}

return partialErr
return nil
} else {
// No failures: report all identities as created.
for _, ident := range identities {
Expand All @@ -679,7 +685,10 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ...
}

return nil
})
}); err != nil {
return err
}
return partialErr.ErrOrNil()
}

func (p *IdentityPersister) HydrateIdentityAssociations(ctx context.Context, i *identity.Identity, expand identity.Expandables) (err error) {
Expand Down