Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions invokeai/frontend/web/src/services/events/nodeExecutionState.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import { deepClone } from 'common/util/deepClone';
import type { NodeExecutionState } from 'features/nodes/types/invocation';
import { zNodeStatus } from 'features/nodes/types/invocation';
import type { S } from 'services/api/types';

const getInvocationKey = (data: { item_id: number; invocation: { id: string } }) =>
`${data.item_id}:${data.invocation.id}`;

export const getUpdatedNodeExecutionStateOnInvocationStarted = (
nodeExecutionState: NodeExecutionState | undefined,
data: S['InvocationStartedEvent'],
completedInvocationKeys: Set<string>
) => {
if (!nodeExecutionState) {
return;
}

if (completedInvocationKeys.has(getInvocationKey(data))) {
return;
}

const _nodeExecutionState = deepClone(nodeExecutionState);
_nodeExecutionState.status = zNodeStatus.enum.IN_PROGRESS;

return _nodeExecutionState;
};

export const getUpdatedNodeExecutionStateOnInvocationProgress = (
nodeExecutionState: NodeExecutionState | undefined,
data: S['InvocationProgressEvent'],
completedInvocationKeys: Set<string>
) => {
if (!nodeExecutionState) {
return;
}

if (completedInvocationKeys.has(getInvocationKey(data))) {
return;
}

const _nodeExecutionState = deepClone(nodeExecutionState);
_nodeExecutionState.status = zNodeStatus.enum.IN_PROGRESS;
_nodeExecutionState.progress = data.percentage ?? null;
_nodeExecutionState.progressImage = data.image ?? null;

return _nodeExecutionState;
};

export const getUpdatedNodeExecutionStateOnInvocationComplete = (
nodeExecutionState: NodeExecutionState | undefined,
data: S['InvocationCompleteEvent'],
completedInvocationKeys: Set<string>
) => {
if (!nodeExecutionState) {
return;
}

const completedInvocationKey = getInvocationKey(data);

if (completedInvocationKeys.has(completedInvocationKey)) {
return;
}

const _nodeExecutionState = deepClone(nodeExecutionState);
_nodeExecutionState.status = zNodeStatus.enum.COMPLETED;
if (_nodeExecutionState.progress !== null) {
_nodeExecutionState.progress = 1;
}
_nodeExecutionState.outputs.push(data.result);
completedInvocationKeys.add(completedInvocationKey);

return _nodeExecutionState;
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import type { NodeExecutionState } from 'features/nodes/types/invocation';
import { zNodeStatus } from 'features/nodes/types/invocation';
import type { S } from 'services/api/types';
import { describe, expect, it } from 'vitest';

import {
getUpdatedNodeExecutionStateOnInvocationComplete,
getUpdatedNodeExecutionStateOnInvocationProgress,
getUpdatedNodeExecutionStateOnInvocationStarted,
} from './nodeExecutionState';

const buildNodeExecutionState = (overrides: Partial<NodeExecutionState> = {}): NodeExecutionState => ({
nodeId: 'node-1',
status: zNodeStatus.enum.PENDING,
progress: null,
progressImage: null,
outputs: [],
error: null,
...overrides,
});

const buildInvocationStartedEvent = (
overrides: Partial<S['InvocationStartedEvent']> = {}
): S['InvocationStartedEvent'] =>
({
queue_id: 'default',
item_id: 1,
batch_id: 'batch-1',
origin: 'workflows',
destination: 'gallery',
user_id: 'user-1',
session_id: 'session-1',
invocation_source_id: 'node-1',
invocation: {
id: 'prepared-node-1',
type: 'add',
},
...overrides,
}) as S['InvocationStartedEvent'];

const buildInvocationProgressEvent = (
overrides: Partial<S['InvocationProgressEvent']> = {}
): S['InvocationProgressEvent'] =>
({
queue_id: 'default',
item_id: 1,
batch_id: 'batch-1',
origin: 'workflows',
destination: 'gallery',
user_id: 'user-1',
session_id: 'session-1',
invocation_source_id: 'node-1',
invocation: {
id: 'prepared-node-1',
type: 'add',
},
percentage: 0.42,
image: {
dataURL: 'data:image/png;base64,abc',
width: 64,
height: 64,
},
message: 'working',
...overrides,
}) as S['InvocationProgressEvent'];

const buildInvocationCompleteEvent = (
overrides: Partial<S['InvocationCompleteEvent']> = {}
): S['InvocationCompleteEvent'] =>
({
queue_id: 'default',
item_id: 1,
batch_id: 'batch-1',
origin: 'workflows',
destination: 'gallery',
user_id: 'user-1',
session_id: 'session-1',
invocation_source_id: 'node-1',
invocation: {
id: 'prepared-node-1',
type: 'add',
},
result: {
type: 'integer_output',
value: 42,
},
...overrides,
}) as S['InvocationCompleteEvent'];

describe(getUpdatedNodeExecutionStateOnInvocationStarted.name, () => {
it('marks the node in progress on invocation start', () => {
const updated = getUpdatedNodeExecutionStateOnInvocationStarted(
buildNodeExecutionState(),
buildInvocationStartedEvent(),
new Set<string>()
);

expect(updated?.status).toBe(zNodeStatus.enum.IN_PROGRESS);
});

it('ignores a late started event after that invocation already completed', () => {
const event = buildInvocationStartedEvent();
const updated = getUpdatedNodeExecutionStateOnInvocationStarted(
buildNodeExecutionState({ status: zNodeStatus.enum.COMPLETED, progress: 1 }),
event,
new Set([`${event.item_id}:${event.invocation.id}`])
);

expect(updated).toBeUndefined();
});
});

describe(getUpdatedNodeExecutionStateOnInvocationProgress.name, () => {
it('marks the node in progress and preserves progress updates', () => {
const event = buildInvocationProgressEvent();
const updated = getUpdatedNodeExecutionStateOnInvocationProgress(
buildNodeExecutionState(),
event,
new Set<string>()
);

expect(updated?.status).toBe(zNodeStatus.enum.IN_PROGRESS);
expect(updated?.progress).toBe(event.percentage);
expect(updated?.progressImage).toEqual(event.image);
});

it('ignores a late progress event after that invocation already completed', () => {
const event = buildInvocationProgressEvent();
const updated = getUpdatedNodeExecutionStateOnInvocationProgress(
buildNodeExecutionState({ status: zNodeStatus.enum.COMPLETED, progress: 1 }),
event,
new Set([`${event.item_id}:${event.invocation.id}`])
);

expect(updated).toBeUndefined();
});
});

describe(getUpdatedNodeExecutionStateOnInvocationComplete.name, () => {
it('records a completed invocation result once', () => {
const event = buildInvocationCompleteEvent();
const completedInvocationKeys = new Set<string>();

const updated = getUpdatedNodeExecutionStateOnInvocationComplete(
buildNodeExecutionState({ status: zNodeStatus.enum.IN_PROGRESS, progress: 0.5 }),
event,
completedInvocationKeys
);

expect(updated?.status).toBe(zNodeStatus.enum.COMPLETED);
expect(updated?.progress).toBe(1);
expect(updated?.outputs).toEqual([event.result]);
expect(completedInvocationKeys).toEqual(new Set([`${event.item_id}:${event.invocation.id}`]));
});

it('ignores duplicate completion events for the same invocation', () => {
const event = buildInvocationCompleteEvent();
const updated = getUpdatedNodeExecutionStateOnInvocationComplete(
buildNodeExecutionState({ status: zNodeStatus.enum.COMPLETED, progress: 1, outputs: [event.result] }),
event,
new Set([`${event.item_id}:${event.invocation.id}`])
);

expect(updated).toBeUndefined();
});

it('allows the same prepared invocation id on a different queue item', () => {
const firstEvent = buildInvocationCompleteEvent({
item_id: 1,
result: { type: 'integer_output', value: 1 } as unknown as S['InvocationCompleteEvent']['result'],
});
const secondEvent = buildInvocationCompleteEvent({
item_id: 2,
result: { type: 'integer_output', value: 2 } as unknown as S['InvocationCompleteEvent']['result'],
});
const completedInvocationKeys = new Set<string>();

const firstUpdate = getUpdatedNodeExecutionStateOnInvocationComplete(
buildNodeExecutionState(),
firstEvent,
completedInvocationKeys
);
const secondUpdate = getUpdatedNodeExecutionStateOnInvocationComplete(
firstUpdate,
secondEvent,
completedInvocationKeys
);

expect(secondUpdate?.outputs).toEqual([firstEvent.result, secondEvent.result]);
});
});
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { logger } from 'app/logging/logger';
import type { AppDispatch, AppGetState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { canvasWorkflowIntegrationProcessingCompleted } from 'features/controlLayers/store/canvasWorkflowIntegrationSlice';
import {
selectAutoSwitch,
Expand All @@ -12,15 +11,14 @@ import {
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
import { isImageField, isImageFieldCollection } from 'features/nodes/types/common';
import { zNodeStatus } from 'features/nodes/types/invocation';
import type { LRUCache } from 'lru-cache';
import { LIST_ALL_TAG } from 'services/api';
import { boardsApi } from 'services/api/endpoints/boards';
import { getImageDTOSafe, imagesApi } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { ImageDTO, S } from 'services/api/types';
import { getCategories } from 'services/api/util';
import { insertImageIntoNamesResult } from 'services/api/util/optimisticUpdates';
import { getUpdatedNodeExecutionStateOnInvocationComplete } from 'services/events/nodeExecutionState';
import { $lastProgressEvent } from 'services/events/stores';
import stableHash from 'stable-hash';
import type { Param0 } from 'tsafe';
Expand All @@ -38,13 +36,13 @@ const nodeTypeDenylist = ['load_image', 'image'];
*
* @param getState The Redux getState function.
* @param dispatch The Redux dispatch function.
* @param finishedQueueItemIds A cache of finished queue item IDs to prevent duplicate handling and avoid race
* conditions that can happen when a graph finishes very quickly.
* @param completedInvocationKeys A listener-local set used to dedupe repeated invocation completion events and to
* share completion knowledge with the other invocation event handlers.
*/
export const buildOnInvocationComplete = (
getState: AppGetState,
dispatch: AppDispatch,
finishedQueueItemIds: LRUCache<number, boolean>
completedInvocationKeys: Set<string>
) => {
const addImagesToGallery = async (data: S['InvocationCompleteEvent']) => {
if (nodeTypeDenylist.includes(data.invocation.type)) {
Expand Down Expand Up @@ -242,22 +240,24 @@ export const buildOnInvocationComplete = (
};

return async (data: S['InvocationCompleteEvent']) => {
if (finishedQueueItemIds.has(data.item_id)) {
log.trace({ data } as JsonObject, `Received event for already-finished queue item ${data.item_id}`);
return;
}
log.debug({ data } as JsonObject, `Invocation complete (${data.invocation.type}, ${data.invocation_source_id})`);

const nodeExecutionState = $nodeExecutionStates.get()[data.invocation_source_id];
const updatedNodeExecutionState = getUpdatedNodeExecutionStateOnInvocationComplete(
nodeExecutionState,
data,
completedInvocationKeys
);

if (nodeExecutionState) {
const _nodeExecutionState = deepClone(nodeExecutionState);
_nodeExecutionState.status = zNodeStatus.enum.COMPLETED;
if (_nodeExecutionState.progress !== null) {
_nodeExecutionState.progress = 1;
}
_nodeExecutionState.outputs.push(data.result);
upsertExecutionState(_nodeExecutionState.nodeId, _nodeExecutionState);
if (nodeExecutionState && !updatedNodeExecutionState) {
log.trace(
{ data } as JsonObject,
`Ignoring duplicate invocation complete (${data.invocation.type}, ${data.invocation_source_id})`
);
}

if (updatedNodeExecutionState) {
upsertExecutionState(updatedNodeExecutionState.nodeId, updatedNodeExecutionState);
}

// Clear canvas workflow integration processing state if needed
Expand Down
Loading
Loading