diff --git a/apps/api/src/integration-platform/controllers/task-integrations.controller.ts b/apps/api/src/integration-platform/controllers/task-integrations.controller.ts index 0a9383181d..f0e4cd9a40 100644 --- a/apps/api/src/integration-platform/controllers/task-integrations.controller.ts +++ b/apps/api/src/integration-platform/controllers/task-integrations.controller.ts @@ -20,14 +20,15 @@ import { getManifest, runAllChecks, type CheckRunResult, - type OAuthConfig, } from '@trycompai/integration-platform'; import { ConnectionRepository } from '../repositories/connection.repository'; import { ProviderRepository } from '../repositories/provider.repository'; import { CheckRunRepository } from '../repositories/check-run.repository'; import { CredentialVaultService } from '../services/credential-vault.service'; import { OAuthCredentialsService } from '../services/oauth-credentials.service'; +import { TaskIntegrationChecksService } from '../services/task-integration-checks.service'; import { getStringValue, toStringCredentials } from '../utils/credential-utils'; +import { isCheckDisabledForTask } from '../utils/disabled-task-checks'; import { db } from '@db'; import type { Prisma } from '@db'; @@ -39,6 +40,8 @@ interface TaskIntegrationCheck { checkName: string; checkDescription: string; isConnected: boolean; + /** True when the check has been manually disconnected from this task. */ + isDisabledForTask: boolean; needsConfiguration: boolean; connectionId?: string; connectionStatus?: string; @@ -56,6 +59,11 @@ interface RunCheckForTaskDto { checkId: string; } +interface ToggleCheckForTaskDto { + connectionId: string; + checkId: string; +} + @Controller({ path: 'integrations/tasks', version: '1' }) @ApiTags('Integrations') @UseGuards(HybridAuthGuard, PermissionGuard) @@ -69,18 +77,22 @@ export class TaskIntegrationsController { private readonly checkRunRepository: CheckRunRepository, private readonly credentialVaultService: CredentialVaultService, private readonly oauthCredentialsService: OAuthCredentialsService, + private readonly taskIntegrationChecksService: TaskIntegrationChecksService, ) {} /** - * Get all integration checks that can auto-complete a specific task template + * Get all integration checks that can auto-complete a specific task template. + * When a specific `taskId` is also provided, per-task disable state is + * resolved from the matching connection's metadata so the UI can show + * which checks have been manually disconnected from that task. */ @Get('template/:templateId/checks') @RequirePermission('integration', 'read') async getChecksForTaskTemplate( @Param('templateId') templateId: string, @OrganizationId() organizationId: string, + taskIdForDisableState?: string, ): Promise<{ checks: TaskIntegrationCheck[] }> { - const manifests = getActiveManifests(); const checks: TaskIntegrationCheck[] = []; @@ -136,6 +148,15 @@ export class TaskIntegrationsController { oauthConfigured = availability.available; } + const isDisabledForTask = + !!taskIdForDisableState && + !!connection && + isCheckDisabledForTask( + connection.metadata, + taskIdForDisableState, + check.id, + ); + checks.push({ integrationId: manifest.id, integrationName: manifest.name, @@ -144,6 +165,7 @@ export class TaskIntegrationsController { checkName: check.name, checkDescription: check.description, isConnected: !!connection && connection.status === 'active', + isDisabledForTask, needsConfiguration, connectionId: connection?.id, connectionStatus: connection?.status, @@ -169,7 +191,6 @@ export class TaskIntegrationsController { checks: TaskIntegrationCheck[]; task: { id: string; title: string; templateId: string | null }; }> { - // Get the task to find its template ID const task = await db.task.findUnique({ where: { id: taskId, organizationId }, @@ -187,10 +208,11 @@ export class TaskIntegrationsController { }; } - // Get checks for this template + // Get checks for this template, annotated with per-task disable state const { checks } = await this.getChecksForTaskTemplate( task.taskTemplateId, organizationId, + task.id, ); return { @@ -215,7 +237,6 @@ export class TaskIntegrationsController { checkRunId?: string; taskStatus?: string | null; }> { - const { connectionId, checkId } = body; // Verify task exists @@ -240,6 +261,14 @@ export class TaskIntegrationsController { ); } + // Reject runs for checks that have been disconnected from this task. + if (isCheckDisabledForTask(connection.metadata, taskId, checkId)) { + throw new HttpException( + 'This check is disconnected from the task. Reconnect it before running.', + HttpStatus.BAD_REQUEST, + ); + } + // Get provider and manifest const provider = await this.providerRepository.findById( connection.providerId, @@ -493,6 +522,48 @@ export class TaskIntegrationsController { } } + /** + * Disconnect a single integration check from a specific task. + * Does not affect the connection itself or any other task that uses the + * same check. Scheduled runs, manual runs, and the task detail UI will all + * skip this (task, check) pair until it is reconnected. + */ + @Post(':taskId/checks/disconnect') + @RequirePermission('integration', 'update') + async disconnectCheckFromTask( + @Param('taskId') taskId: string, + @OrganizationId() organizationId: string, + @Body() body: ToggleCheckForTaskDto, + ): Promise<{ success: true; disabled: true }> { + await this.taskIntegrationChecksService.disconnectCheckFromTask({ + taskId, + connectionId: body.connectionId, + checkId: body.checkId, + organizationId, + }); + return { success: true, disabled: true }; + } + + /** + * Re-enable a previously disconnected integration check for a specific + * task. Inverse of the disconnect endpoint. + */ + @Post(':taskId/checks/reconnect') + @RequirePermission('integration', 'update') + async reconnectCheckToTask( + @Param('taskId') taskId: string, + @OrganizationId() organizationId: string, + @Body() body: ToggleCheckForTaskDto, + ): Promise<{ success: true; disabled: false }> { + await this.taskIntegrationChecksService.reconnectCheckToTask({ + taskId, + connectionId: body.connectionId, + checkId: body.checkId, + organizationId, + }); + return { success: true, disabled: false }; + } + /** * Get check run history for a task */ @@ -502,7 +573,6 @@ export class TaskIntegrationsController { @Param('taskId') taskId: string, @Query('limit') limit?: string, ) { - const runs = await this.checkRunRepository.findByTask( taskId, limit ? parseInt(limit, 10) : 10, diff --git a/apps/api/src/integration-platform/integration-platform.module.ts b/apps/api/src/integration-platform/integration-platform.module.ts index 0a24c44859..05fb6d43ee 100644 --- a/apps/api/src/integration-platform/integration-platform.module.ts +++ b/apps/api/src/integration-platform/integration-platform.module.ts @@ -17,6 +17,7 @@ import { AutoCheckRunnerService } from './services/auto-check-runner.service'; import { ConnectionAuthTeardownService } from './services/connection-auth-teardown.service'; import { OAuthTokenRevocationService } from './services/oauth-token-revocation.service'; import { DynamicManifestLoaderService } from './services/dynamic-manifest-loader.service'; +import { TaskIntegrationChecksService } from './services/task-integration-checks.service'; import { ProviderRepository } from './repositories/provider.repository'; import { ConnectionRepository } from './repositories/connection.repository'; import { CredentialRepository } from './repositories/credential.repository'; @@ -52,6 +53,7 @@ import { GenericEmployeeSyncService } from './services/generic-employee-sync.ser OAuthTokenRevocationService, ConnectionAuthTeardownService, DynamicManifestLoaderService, + TaskIntegrationChecksService, IntegrationSyncLoggerService, GenericEmployeeSyncService, // Repositories diff --git a/apps/api/src/integration-platform/services/task-integration-checks.service.spec.ts b/apps/api/src/integration-platform/services/task-integration-checks.service.spec.ts new file mode 100644 index 0000000000..e335410472 --- /dev/null +++ b/apps/api/src/integration-platform/services/task-integration-checks.service.spec.ts @@ -0,0 +1,259 @@ +import { Test, TestingModule } from '@nestjs/testing'; +import { BadRequestException, NotFoundException } from '@nestjs/common'; +import { TaskIntegrationChecksService } from './task-integration-checks.service'; +import { ConnectionRepository } from '../repositories/connection.repository'; +import { ProviderRepository } from '../repositories/provider.repository'; +import { ConnectionService } from './connection.service'; +import { DISABLED_TASK_CHECKS_KEY } from '../utils/disabled-task-checks'; + +jest.mock('@db', () => ({ + db: { + task: { + findUnique: jest.fn(), + }, + }, +})); + +jest.mock('@trycompai/integration-platform', () => ({ + getManifest: jest.fn(), +})); + +import { db } from '@db'; +import { getManifest } from '@trycompai/integration-platform'; + +const mockedGetManifest = getManifest as jest.MockedFunction< + typeof getManifest +>; +// Grabbing through the module reference avoids the `unbound-method` lint rule +// that fires when you extract an instance method from an object literal. +const mockedFindTask = (db.task as unknown as { findUnique: jest.Mock }) + .findUnique; + +describe('TaskIntegrationChecksService', () => { + let service: TaskIntegrationChecksService; + + const mockConnectionRepository = { + findById: jest.fn(), + }; + + const mockConnectionService = { + updateConnectionMetadata: jest.fn(), + }; + + const mockProviderRepository = { + findById: jest.fn(), + }; + + const ORG_ID = 'org_1'; + const TASK_ID = 'tsk_1'; + const CONNECTION_ID = 'icn_1'; + const PROVIDER_ID = 'prv_1'; + const CHECK_ID = 'branch_protection'; + + const baseConnection = { + id: CONNECTION_ID, + organizationId: ORG_ID, + providerId: PROVIDER_ID, + metadata: { connectionName: 'My GitHub' }, + }; + + const baseManifest = { + id: 'github', + checks: [ + { id: CHECK_ID, name: 'Branch Protection' }, + { id: 'dependabot', name: 'Dependabot' }, + ], + }; + + beforeEach(async () => { + jest.clearAllMocks(); + + const module: TestingModule = await Test.createTestingModule({ + providers: [ + TaskIntegrationChecksService, + { provide: ConnectionRepository, useValue: mockConnectionRepository }, + { provide: ConnectionService, useValue: mockConnectionService }, + { provide: ProviderRepository, useValue: mockProviderRepository }, + ], + }).compile(); + + service = module.get(TaskIntegrationChecksService); + + mockConnectionRepository.findById.mockResolvedValue(baseConnection); + mockProviderRepository.findById.mockResolvedValue({ + id: PROVIDER_ID, + slug: 'github', + }); + mockedGetManifest.mockReturnValue(baseManifest as never); + mockedFindTask.mockResolvedValue({ id: TASK_ID } as never); + mockConnectionService.updateConnectionMetadata.mockResolvedValue( + baseConnection as never, + ); + }); + + describe('disconnectCheckFromTask', () => { + it('marks the check as disabled and persists merged metadata', async () => { + await service.disconnectCheckFromTask({ + taskId: TASK_ID, + connectionId: CONNECTION_ID, + checkId: CHECK_ID, + organizationId: ORG_ID, + }); + + expect( + mockConnectionService.updateConnectionMetadata, + ).toHaveBeenCalledTimes(1); + const [persistedId, persistedMetadata] = + mockConnectionService.updateConnectionMetadata.mock.calls[0]; + expect(persistedId).toBe(CONNECTION_ID); + expect(persistedMetadata.connectionName).toBe('My GitHub'); + expect(persistedMetadata[DISABLED_TASK_CHECKS_KEY]).toEqual({ + [TASK_ID]: [CHECK_ID], + }); + }); + + it('is idempotent if called twice for the same check', async () => { + mockConnectionRepository.findById + .mockResolvedValueOnce(baseConnection) + .mockResolvedValueOnce({ + ...baseConnection, + metadata: { + ...baseConnection.metadata, + [DISABLED_TASK_CHECKS_KEY]: { [TASK_ID]: [CHECK_ID] }, + }, + }); + + await service.disconnectCheckFromTask({ + taskId: TASK_ID, + connectionId: CONNECTION_ID, + checkId: CHECK_ID, + organizationId: ORG_ID, + }); + await service.disconnectCheckFromTask({ + taskId: TASK_ID, + connectionId: CONNECTION_ID, + checkId: CHECK_ID, + organizationId: ORG_ID, + }); + + const secondCallMetadata = + mockConnectionService.updateConnectionMetadata.mock.calls[1][1]; + expect(secondCallMetadata[DISABLED_TASK_CHECKS_KEY]).toEqual({ + [TASK_ID]: [CHECK_ID], + }); + }); + + it('throws NotFound when the connection belongs to another org', async () => { + mockConnectionRepository.findById.mockResolvedValue({ + ...baseConnection, + organizationId: 'another_org', + }); + + await expect( + service.disconnectCheckFromTask({ + taskId: TASK_ID, + connectionId: CONNECTION_ID, + checkId: CHECK_ID, + organizationId: ORG_ID, + }), + ).rejects.toBeInstanceOf(NotFoundException); + expect( + mockConnectionService.updateConnectionMetadata, + ).not.toHaveBeenCalled(); + }); + + it('throws NotFound when the task does not belong to the org', async () => { + mockedFindTask.mockResolvedValue(null); + + await expect( + service.disconnectCheckFromTask({ + taskId: TASK_ID, + connectionId: CONNECTION_ID, + checkId: CHECK_ID, + organizationId: ORG_ID, + }), + ).rejects.toBeInstanceOf(NotFoundException); + expect( + mockConnectionService.updateConnectionMetadata, + ).not.toHaveBeenCalled(); + }); + + it('throws BadRequest when the check id is unknown for the provider', async () => { + await expect( + service.disconnectCheckFromTask({ + taskId: TASK_ID, + connectionId: CONNECTION_ID, + checkId: 'does_not_exist', + organizationId: ORG_ID, + }), + ).rejects.toBeInstanceOf(BadRequestException); + expect( + mockConnectionService.updateConnectionMetadata, + ).not.toHaveBeenCalled(); + }); + }); + + describe('reconnectCheckToTask', () => { + it('removes the check from the disabled list and preserves other metadata', async () => { + mockConnectionRepository.findById.mockResolvedValue({ + ...baseConnection, + metadata: { + connectionName: 'My GitHub', + [DISABLED_TASK_CHECKS_KEY]: { + [TASK_ID]: [CHECK_ID, 'dependabot'], + }, + }, + }); + + await service.reconnectCheckToTask({ + taskId: TASK_ID, + connectionId: CONNECTION_ID, + checkId: CHECK_ID, + organizationId: ORG_ID, + }); + + const [, persistedMetadata] = + mockConnectionService.updateConnectionMetadata.mock.calls[0]; + expect(persistedMetadata.connectionName).toBe('My GitHub'); + expect(persistedMetadata[DISABLED_TASK_CHECKS_KEY]).toEqual({ + [TASK_ID]: ['dependabot'], + }); + }); + + it('cleans up the task entry when its list becomes empty', async () => { + mockConnectionRepository.findById.mockResolvedValue({ + ...baseConnection, + metadata: { + [DISABLED_TASK_CHECKS_KEY]: { + [TASK_ID]: [CHECK_ID], + }, + }, + }); + + await service.reconnectCheckToTask({ + taskId: TASK_ID, + connectionId: CONNECTION_ID, + checkId: CHECK_ID, + organizationId: ORG_ID, + }); + + const [, persistedMetadata] = + mockConnectionService.updateConnectionMetadata.mock.calls[0]; + expect(persistedMetadata[DISABLED_TASK_CHECKS_KEY]).toEqual({}); + }); + + it('is a no-op when the check was not disabled', async () => { + await service.reconnectCheckToTask({ + taskId: TASK_ID, + connectionId: CONNECTION_ID, + checkId: CHECK_ID, + organizationId: ORG_ID, + }); + + expect(mockConnectionService.updateConnectionMetadata).toHaveBeenCalled(); + const [, persistedMetadata] = + mockConnectionService.updateConnectionMetadata.mock.calls[0]; + expect(persistedMetadata[DISABLED_TASK_CHECKS_KEY]).toEqual({}); + }); + }); +}); diff --git a/apps/api/src/integration-platform/services/task-integration-checks.service.ts b/apps/api/src/integration-platform/services/task-integration-checks.service.ts new file mode 100644 index 0000000000..8f53746b9c --- /dev/null +++ b/apps/api/src/integration-platform/services/task-integration-checks.service.ts @@ -0,0 +1,140 @@ +import { + BadRequestException, + Injectable, + Logger, + NotFoundException, +} from '@nestjs/common'; +import { db } from '@db'; +import { getManifest } from '@trycompai/integration-platform'; +import { ConnectionRepository } from '../repositories/connection.repository'; +import { ProviderRepository } from '../repositories/provider.repository'; +import { ConnectionService } from './connection.service'; +import { + withCheckDisabled, + withCheckEnabled, +} from '../utils/disabled-task-checks'; + +/** + * Handles enable/disable of a single integration check for a single task. + * + * This does NOT disconnect the whole integration — only removes one check from + * one task. The disable state lives on the connection's metadata so it + * survives alongside credentials and is scoped to the specific connection that + * provides the check. + */ +@Injectable() +export class TaskIntegrationChecksService { + private readonly logger = new Logger(TaskIntegrationChecksService.name); + + constructor( + private readonly connectionRepository: ConnectionRepository, + private readonly connectionService: ConnectionService, + private readonly providerRepository: ProviderRepository, + ) {} + + /** + * Disconnect a single check from a single task. The connection stays active + * for all other tasks that use it. Validates that: + * - the task exists and belongs to the org + * - the connection exists and belongs to the org + * - the provider has a check with this id + */ + async disconnectCheckFromTask(params: { + taskId: string; + connectionId: string; + checkId: string; + organizationId: string; + }): Promise<{ disabled: true }> { + const { taskId, connectionId, checkId, organizationId } = params; + + const connection = await this.loadConnectionForOrg( + connectionId, + organizationId, + ); + await this.assertTaskInOrg(taskId, organizationId); + await this.assertCheckExists(connection.providerId, checkId); + + const nextMetadata = withCheckDisabled( + connection.metadata, + taskId, + checkId, + ); + await this.connectionService.updateConnectionMetadata( + connectionId, + nextMetadata, + ); + + this.logger.log( + `Disabled check ${checkId} for task ${taskId} on connection ${connectionId}`, + ); + return { disabled: true }; + } + + /** + * Re-enable a single check for a single task. Inverse of disconnect. + */ + async reconnectCheckToTask(params: { + taskId: string; + connectionId: string; + checkId: string; + organizationId: string; + }): Promise<{ disabled: false }> { + const { taskId, connectionId, checkId, organizationId } = params; + + const connection = await this.loadConnectionForOrg( + connectionId, + organizationId, + ); + await this.assertTaskInOrg(taskId, organizationId); + await this.assertCheckExists(connection.providerId, checkId); + + const nextMetadata = withCheckEnabled(connection.metadata, taskId, checkId); + await this.connectionService.updateConnectionMetadata( + connectionId, + nextMetadata, + ); + + this.logger.log( + `Re-enabled check ${checkId} for task ${taskId} on connection ${connectionId}`, + ); + return { disabled: false }; + } + + private async loadConnectionForOrg( + connectionId: string, + organizationId: string, + ) { + const connection = await this.connectionRepository.findById(connectionId); + if (!connection || connection.organizationId !== organizationId) { + throw new NotFoundException('Connection not found'); + } + return connection; + } + + private async assertTaskInOrg(taskId: string, organizationId: string) { + const task = await db.task.findUnique({ + where: { id: taskId, organizationId }, + select: { id: true }, + }); + if (!task) { + throw new NotFoundException('Task not found'); + } + } + + private async assertCheckExists(providerId: string, checkId: string) { + const provider = await this.providerRepository.findById(providerId); + if (!provider) { + throw new NotFoundException('Provider not found'); + } + const manifest = getManifest(provider.slug); + if (!manifest) { + throw new NotFoundException('Manifest not found'); + } + const check = manifest.checks?.find((c) => c.id === checkId); + if (!check) { + throw new BadRequestException( + `Check "${checkId}" is not defined for provider "${provider.slug}"`, + ); + } + } +} diff --git a/apps/api/src/integration-platform/utils/disabled-task-checks.spec.ts b/apps/api/src/integration-platform/utils/disabled-task-checks.spec.ts new file mode 100644 index 0000000000..5d16e5665d --- /dev/null +++ b/apps/api/src/integration-platform/utils/disabled-task-checks.spec.ts @@ -0,0 +1,244 @@ +import { + DISABLED_TASK_CHECKS_KEY, + isCheckDisabledForTask, + parseDisabledTaskChecks, + withCheckDisabled, + withCheckEnabled, +} from './disabled-task-checks'; + +describe('disabled-task-checks utils', () => { + describe('parseDisabledTaskChecks', () => { + it('returns empty map for null/undefined', () => { + expect(parseDisabledTaskChecks(null)).toEqual({}); + expect(parseDisabledTaskChecks(undefined)).toEqual({}); + }); + + it('returns empty map when metadata is not an object', () => { + expect(parseDisabledTaskChecks('string')).toEqual({}); + expect(parseDisabledTaskChecks(123)).toEqual({}); + expect(parseDisabledTaskChecks([])).toEqual({}); + }); + + it('returns empty map when the key is missing', () => { + expect(parseDisabledTaskChecks({ somethingElse: true })).toEqual({}); + }); + + it('parses a valid map', () => { + const metadata = { + connectionName: 'My GitHub', + [DISABLED_TASK_CHECKS_KEY]: { + tsk_abc: ['branch_protection', 'dependabot'], + tsk_xyz: ['sanitized_inputs'], + }, + }; + expect(parseDisabledTaskChecks(metadata)).toEqual({ + tsk_abc: ['branch_protection', 'dependabot'], + tsk_xyz: ['sanitized_inputs'], + }); + }); + + it('drops non-string/empty check ids', () => { + const metadata = { + [DISABLED_TASK_CHECKS_KEY]: { + tsk_abc: ['branch_protection', 42, null, '', 'dependabot'], + }, + }; + expect(parseDisabledTaskChecks(metadata)).toEqual({ + tsk_abc: ['branch_protection', 'dependabot'], + }); + }); + + it('drops task entries where all check ids are invalid', () => { + const metadata = { + [DISABLED_TASK_CHECKS_KEY]: { + tsk_abc: [null, 42, ''], + tsk_xyz: ['valid'], + }, + }; + expect(parseDisabledTaskChecks(metadata)).toEqual({ + tsk_xyz: ['valid'], + }); + }); + + it('skips non-array check lists', () => { + const metadata = { + [DISABLED_TASK_CHECKS_KEY]: { + tsk_abc: 'not-an-array', + tsk_xyz: ['valid'], + }, + }; + expect(parseDisabledTaskChecks(metadata)).toEqual({ + tsk_xyz: ['valid'], + }); + }); + }); + + describe('isCheckDisabledForTask', () => { + const metadata = { + [DISABLED_TASK_CHECKS_KEY]: { + tsk_abc: ['branch_protection'], + }, + }; + + it('returns true when the check is disabled', () => { + expect( + isCheckDisabledForTask(metadata, 'tsk_abc', 'branch_protection'), + ).toBe(true); + }); + + it('returns false when the check is not in the list', () => { + expect(isCheckDisabledForTask(metadata, 'tsk_abc', 'dependabot')).toBe( + false, + ); + }); + + it('returns false when the task has no disabled checks', () => { + expect( + isCheckDisabledForTask(metadata, 'tsk_xyz', 'branch_protection'), + ).toBe(false); + }); + + it('returns false for empty metadata', () => { + expect(isCheckDisabledForTask(null, 'tsk_abc', 'branch_protection')).toBe( + false, + ); + }); + }); + + describe('withCheckDisabled', () => { + it('adds a check to an empty metadata object', () => { + const result = withCheckDisabled(null, 'tsk_abc', 'branch_protection'); + expect(result[DISABLED_TASK_CHECKS_KEY]).toEqual({ + tsk_abc: ['branch_protection'], + }); + }); + + it('preserves existing metadata fields', () => { + const metadata = { + connectionName: 'My GitHub', + accountId: '12345', + }; + const result = withCheckDisabled( + metadata, + 'tsk_abc', + 'branch_protection', + ); + expect(result.connectionName).toBe('My GitHub'); + expect(result.accountId).toBe('12345'); + expect(result[DISABLED_TASK_CHECKS_KEY]).toEqual({ + tsk_abc: ['branch_protection'], + }); + }); + + it('adds to an existing task list', () => { + const metadata = { + [DISABLED_TASK_CHECKS_KEY]: { + tsk_abc: ['branch_protection'], + }, + }; + const result = withCheckDisabled(metadata, 'tsk_abc', 'dependabot'); + expect(result[DISABLED_TASK_CHECKS_KEY]).toEqual({ + tsk_abc: ['branch_protection', 'dependabot'], + }); + }); + + it('is idempotent when the check is already disabled', () => { + const metadata = { + [DISABLED_TASK_CHECKS_KEY]: { + tsk_abc: ['branch_protection'], + }, + }; + const result = withCheckDisabled( + metadata, + 'tsk_abc', + 'branch_protection', + ); + expect(result[DISABLED_TASK_CHECKS_KEY]).toEqual({ + tsk_abc: ['branch_protection'], + }); + }); + + it('does not mutate the input metadata', () => { + const metadata = { + [DISABLED_TASK_CHECKS_KEY]: { + tsk_abc: ['branch_protection'], + }, + }; + const snapshot = JSON.stringify(metadata); + withCheckDisabled(metadata, 'tsk_abc', 'dependabot'); + expect(JSON.stringify(metadata)).toBe(snapshot); + }); + + it('works across multiple tasks independently', () => { + let metadata: Record = {}; + metadata = withCheckDisabled(metadata, 'tsk_abc', 'branch_protection'); + metadata = withCheckDisabled(metadata, 'tsk_xyz', 'sanitized_inputs'); + expect(metadata[DISABLED_TASK_CHECKS_KEY]).toEqual({ + tsk_abc: ['branch_protection'], + tsk_xyz: ['sanitized_inputs'], + }); + }); + }); + + describe('withCheckEnabled', () => { + it('removes the check from a task list', () => { + const metadata = { + [DISABLED_TASK_CHECKS_KEY]: { + tsk_abc: ['branch_protection', 'dependabot'], + }, + }; + const result = withCheckEnabled(metadata, 'tsk_abc', 'branch_protection'); + expect(result[DISABLED_TASK_CHECKS_KEY]).toEqual({ + tsk_abc: ['dependabot'], + }); + }); + + it('removes the task entry when its list becomes empty', () => { + const metadata = { + [DISABLED_TASK_CHECKS_KEY]: { + tsk_abc: ['branch_protection'], + tsk_xyz: ['sanitized_inputs'], + }, + }; + const result = withCheckEnabled(metadata, 'tsk_abc', 'branch_protection'); + expect(result[DISABLED_TASK_CHECKS_KEY]).toEqual({ + tsk_xyz: ['sanitized_inputs'], + }); + }); + + it('is a no-op when the check was not disabled', () => { + const metadata = { + [DISABLED_TASK_CHECKS_KEY]: { + tsk_abc: ['branch_protection'], + }, + }; + const result = withCheckEnabled(metadata, 'tsk_abc', 'dependabot'); + expect(result[DISABLED_TASK_CHECKS_KEY]).toEqual({ + tsk_abc: ['branch_protection'], + }); + }); + + it('preserves other metadata fields', () => { + const metadata = { + connectionName: 'My GitHub', + [DISABLED_TASK_CHECKS_KEY]: { + tsk_abc: ['branch_protection'], + }, + }; + const result = withCheckEnabled(metadata, 'tsk_abc', 'branch_protection'); + expect(result.connectionName).toBe('My GitHub'); + expect(result[DISABLED_TASK_CHECKS_KEY]).toEqual({}); + }); + + it('does not mutate the input metadata', () => { + const metadata = { + [DISABLED_TASK_CHECKS_KEY]: { + tsk_abc: ['branch_protection'], + }, + }; + const snapshot = JSON.stringify(metadata); + withCheckEnabled(metadata, 'tsk_abc', 'branch_protection'); + expect(JSON.stringify(metadata)).toBe(snapshot); + }); + }); +}); diff --git a/apps/api/src/integration-platform/utils/disabled-task-checks.ts b/apps/api/src/integration-platform/utils/disabled-task-checks.ts new file mode 100644 index 0000000000..21c93414e1 --- /dev/null +++ b/apps/api/src/integration-platform/utils/disabled-task-checks.ts @@ -0,0 +1,129 @@ +/** + * Helpers for reading and writing per-task disabled integration checks from + * `IntegrationConnection.metadata`. + * + * Per-task disable state is stored under `metadata.disabledTaskChecks` as a map + * from task ID to the list of manifest check IDs that are disabled for that + * task on this connection. Example: + * + * { + * ...otherMetadata, + * disabledTaskChecks: { + * "tsk_abc123": ["branch_protection", "dependabot"], + * "tsk_xyz789": ["sanitized_inputs"] + * } + * } + * + * Storing on the connection gives us "reconnect = fresh state" for free and + * transparently supports orgs with multiple connections per provider — each + * connection has its own disable state. + */ + +export const DISABLED_TASK_CHECKS_KEY = 'disabledTaskChecks'; + +export type DisabledTaskChecksMap = Record; + +/** + * Parse the disabled task checks map from a connection's metadata JSON blob. + * Returns an empty map if the metadata is missing, malformed, or doesn't + * contain a `disabledTaskChecks` entry. Never throws. + */ +export function parseDisabledTaskChecks( + metadata: unknown, +): DisabledTaskChecksMap { + if (!metadata || typeof metadata !== 'object') { + return {}; + } + const raw = (metadata as Record)[DISABLED_TASK_CHECKS_KEY]; + if (!raw || typeof raw !== 'object') { + return {}; + } + + const result: DisabledTaskChecksMap = {}; + for (const [taskId, checkIds] of Object.entries( + raw as Record, + )) { + if (!Array.isArray(checkIds)) continue; + const cleaned = checkIds.filter( + (id): id is string => typeof id === 'string' && id.length > 0, + ); + if (cleaned.length > 0) { + result[taskId] = cleaned; + } + } + return result; +} + +/** + * Returns true if the given checkId is disabled for the given taskId on this + * connection's metadata. + */ +export function isCheckDisabledForTask( + metadata: unknown, + taskId: string, + checkId: string, +): boolean { + const map = parseDisabledTaskChecks(metadata); + const disabled = map[taskId]; + return Array.isArray(disabled) && disabled.includes(checkId); +} + +/** + * Returns a new metadata object with the given check marked as disabled for + * the given task. Does not mutate the input. If the check is already disabled, + * returns the metadata unchanged (same reference). + */ +export function withCheckDisabled( + metadata: unknown, + taskId: string, + checkId: string, +): Record { + const base: Record = + metadata && typeof metadata === 'object' + ? { ...(metadata as Record) } + : {}; + const map = parseDisabledTaskChecks(base); + const current = map[taskId] ?? []; + if (current.includes(checkId)) { + // Already disabled — return a merged copy so callers can safely write back. + base[DISABLED_TASK_CHECKS_KEY] = map; + return base; + } + const nextMap: DisabledTaskChecksMap = { + ...map, + [taskId]: [...current, checkId], + }; + base[DISABLED_TASK_CHECKS_KEY] = nextMap; + return base; +} + +/** + * Returns a new metadata object with the given check re-enabled for the given + * task. Cleans up empty arrays. If the check wasn't disabled, returns a merged + * copy unchanged. + */ +export function withCheckEnabled( + metadata: unknown, + taskId: string, + checkId: string, +): Record { + const base: Record = + metadata && typeof metadata === 'object' + ? { ...(metadata as Record) } + : {}; + const map = parseDisabledTaskChecks(base); + const current = map[taskId]; + if (!current || !current.includes(checkId)) { + base[DISABLED_TASK_CHECKS_KEY] = map; + return base; + } + const nextChecks = current.filter((id) => id !== checkId); + const nextMap: DisabledTaskChecksMap = { ...map }; + if (nextChecks.length === 0) { + delete nextMap[taskId]; + } else { + nextMap[taskId] = nextChecks; + } + base[DISABLED_TASK_CHECKS_KEY] = nextMap; + return base; +} diff --git a/apps/api/src/trigger/integration-platform/run-integration-checks-schedule.ts b/apps/api/src/trigger/integration-platform/run-integration-checks-schedule.ts index c57f95b83a..0a80d22fa6 100644 --- a/apps/api/src/trigger/integration-platform/run-integration-checks-schedule.ts +++ b/apps/api/src/trigger/integration-platform/run-integration-checks-schedule.ts @@ -2,6 +2,7 @@ import { getManifest } from '@trycompai/integration-platform'; import { db } from '@db'; import { logger, schedules } from '@trigger.dev/sdk'; import { runTaskIntegrationChecks } from './run-task-integration-checks'; +import { parseDisabledTaskChecks } from '../../integration-platform/utils/disabled-task-checks'; /** * Daily scheduled task (orchestrator) that finds all tasks with integration checks @@ -74,10 +75,21 @@ export const integrationChecksSchedule = schedules.task({ }, }); + // Per-task disabled checks are stored on the connection's metadata so + // users can disconnect individual checks from individual tasks without + // tearing down the whole integration. Resolve once per connection. + const disabledByTask = parseDisabledTaskChecks(connection.metadata); + for (const t of tasks) { - // Find which checks apply to this task + const disabledForThisTask = new Set(disabledByTask[t.id] ?? []); + + // Find which checks apply to this task, minus any the user disabled const checksForTask = manifest.checks - .filter((c) => c.taskMapping === t.taskTemplateId) + .filter( + (c) => + c.taskMapping === t.taskTemplateId && + !disabledForThisTask.has(c.id), + ) .map((c) => c.id); if (checksForTask.length > 0) { diff --git a/apps/api/src/trigger/integration-platform/run-task-integration-checks.ts b/apps/api/src/trigger/integration-platform/run-task-integration-checks.ts index 24ca17a9ee..57aef35305 100644 --- a/apps/api/src/trigger/integration-platform/run-task-integration-checks.ts +++ b/apps/api/src/trigger/integration-platform/run-task-integration-checks.ts @@ -4,6 +4,7 @@ import { logger, tags, task } from '@trigger.dev/sdk'; import { triggerEmail } from '../../email/trigger-email'; import { TaskStatusChangedEmail } from '../../email/templates/task-status-changed'; import { isUserUnsubscribed } from '@trycompai/email'; +import { parseDisabledTaskChecks } from '../../integration-platform/utils/disabled-task-checks'; /** * Send email notifications for task status change @@ -294,6 +295,26 @@ export const runTaskIntegrationChecks = task({ string | number | boolean | string[] | undefined >) || {}; + // Defensive per-task disable filter: the orchestrator already removes + // disabled checks, but a user may disconnect a check between batching and + // execution. Re-resolve the disabled set from the just-fetched connection + // metadata and skip anything that's now disabled. The rest of the flow + // (lastSyncAt update, task status evaluation, return payload) runs as + // before — just over the filtered list instead of the original one. + const disabledForThisTask = new Set( + parseDisabledTaskChecks(connection.metadata)[taskId] ?? [], + ); + const effectiveCheckIds = checkIds.filter( + (id) => !disabledForThisTask.has(id), + ); + if (effectiveCheckIds.length < checkIds.length) { + logger.info( + `Skipping ${ + checkIds.length - effectiveCheckIds.length + } disabled check(s) for task ${taskId}`, + ); + } + // Track overall results across all checks for this task let totalFindings = 0; let totalPassing = 0; @@ -301,7 +322,7 @@ export const runTaskIntegrationChecks = task({ // Run only the checks that apply to this task try { - for (const checkId of checkIds) { + for (const checkId of effectiveCheckIds) { const result = await runAllChecks({ manifest, accessToken: credentials.access_token ?? undefined, @@ -467,7 +488,7 @@ export const runTaskIntegrationChecks = task({ return { success: true, taskId, - checksRun: checkIds.length, + checksRun: effectiveCheckIds.length, totalPassing, totalFindings, taskStatus: diff --git a/apps/api/src/trust-portal/trust-portal.service.ts b/apps/api/src/trust-portal/trust-portal.service.ts index 35f2dc47e0..0fb95662d6 100644 --- a/apps/api/src/trust-portal/trust-portal.service.ts +++ b/apps/api/src/trust-portal/trust-portal.service.ts @@ -598,6 +598,7 @@ export class TrustPortalService { // Map framework boolean fields (frontend sends camelCase, DB uses snake_case) const boolFieldMap: Record = { + soc2: 'soc2', soc2type1: 'soc2type1', soc2type2: 'soc2type2', iso27001: 'iso27001', diff --git a/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/components/TaskIntegrationChecks.tsx b/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/components/TaskIntegrationChecks.tsx index d361e8a4b1..064795e994 100644 --- a/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/components/TaskIntegrationChecks.tsx +++ b/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/components/TaskIntegrationChecks.tsx @@ -7,6 +7,16 @@ import type { TaskIntegrationCheck, StoredCheckRun } from '../hooks/useIntegrati import { useIntegrationChecks } from '../hooks/useIntegrationChecks'; import { cn } from '@/lib/utils'; import { useActiveOrganization } from '@/utils/auth-client'; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, +} from '@trycompai/ui/alert-dialog'; import { Badge } from '@trycompai/ui/badge'; import { Button } from '@trycompai/ui/button'; import { addDays, formatDistanceToNow, isBefore, setHours, setMinutes } from 'date-fns'; @@ -21,9 +31,11 @@ import { ExternalLink, Loader2, Play, + Plug, PlugZap, Settings2, TrendingUp, + Unplug, XCircle, } from 'lucide-react'; import Image from 'next/image'; @@ -58,11 +70,21 @@ export function TaskIntegrationChecks({ error: hookError, mutateChecks, runCheck, + disconnectCheckFromTask, + reconnectCheckToTask, } = useIntegrationChecks({ taskId, orgId }); const [runningCheck, setRunningCheck] = useState(null); + const [togglingCheck, setTogglingCheck] = useState(null); const [expandedCheck, setExpandedCheck] = useState(null); const [error, setError] = useState(null); + const [disconnectTarget, setDisconnectTarget] = useState<{ + connectionId: string; + checkId: string; + checkName: string; + integrationName: string; + } | null>(null); + const [disconnectError, setDisconnectError] = useState(null); // Sync hook-level error into local state useEffect(() => { @@ -138,6 +160,44 @@ export function TaskIntegrationChecks({ [runCheck, onTaskUpdated], ); + const handleConfirmDisconnect = useCallback(async () => { + if (!disconnectTarget) return; + const { connectionId, checkId, checkName } = disconnectTarget; + setTogglingCheck(checkId); + setDisconnectError(null); + try { + await disconnectCheckFromTask(connectionId, checkId); + toast.success(`Disconnected "${checkName}" from this task.`); + setDisconnectTarget(null); + } catch (err) { + console.error('Failed to disconnect check:', err); + setDisconnectError( + err instanceof Error ? err.message : 'Failed to disconnect check', + ); + } finally { + setTogglingCheck(null); + } + }, [disconnectCheckFromTask, disconnectTarget]); + + const handleReconnect = useCallback( + async (connectionId: string, checkId: string, checkName: string) => { + setTogglingCheck(checkId); + setError(null); + try { + await reconnectCheckToTask(connectionId, checkId); + toast.success(`Reconnected "${checkName}" to this task.`); + } catch (err) { + console.error('Failed to reconnect check:', err); + setError( + err instanceof Error ? err.message : 'Failed to reconnect check', + ); + } finally { + setTogglingCheck(null); + } + }, + [reconnectCheckToTask], + ); + if (loading) { return (
@@ -155,7 +215,16 @@ export function TaskIntegrationChecks({ ); } - const connectedChecks = checks.filter((c) => c.isConnected); + // Split checks into three groups: + // 1. connectedChecks — active + not disabled for this task + // 2. disabledForTaskChecks — connected but manually disconnected from this task + // 3. disconnectedChecks — no connection at all (suggestions) + const connectedChecks = checks.filter( + (c) => c.isConnected && !c.isDisabledForTask, + ); + const disabledForTaskChecks = checks.filter( + (c) => c.isConnected && c.isDisabledForTask, + ); const disconnectedChecks = checks.filter((c) => !c.isConnected); // If there are no checks at all for this task, don't render anything @@ -229,18 +298,11 @@ export function TaskIntegrationChecks({ {/* Card Content */}
- {connectedChecks.length === 0 && disconnectedChecks.length === 0 ? ( - - ) : connectedChecks.length === 0 ? ( + {connectedChecks.length === 0 && + disabledForTaskChecks.length === 0 ? ( + )} @@ -547,6 +631,64 @@ export function TaskIntegrationChecks({ })}
+ {/* Checks that are connected but manually disabled for this task */} + {disabledForTaskChecks.length > 0 && ( +
+

+ Disconnected from this task +

+
+ {disabledForTaskChecks.map((check) => { + const isToggling = togglingCheck === check.checkId; + return ( +
+
+ {check.integrationName} +
+

+ {check.checkName} +

+

+ Will not run until reconnected +

+
+
+ +
+ ); + })} +
+
+ )} + {/* Disconnected Checks as Suggestions */} {disconnectedChecks.length > 0 && (
@@ -588,6 +730,60 @@ export function TaskIntegrationChecks({ )}
+ {/* Confirm disconnect-from-task dialog */} + { + // Don't let Escape / click-outside close the dialog mid-request — + // the in-flight operation still owns the target state. + if (!open && togglingCheck === null) { + setDisconnectTarget(null); + setDisconnectError(null); + } + }} + > + + + Disconnect check from task? + + {disconnectTarget ? ( + <> + {disconnectTarget.checkName} from{' '} + {disconnectTarget.integrationName} will no + longer run for this task. The integration itself stays + connected and will continue running for other tasks. You can + reconnect it to this task at any time. + + ) : null} + + + {disconnectError && ( +
+ + {disconnectError} +
+ )} + + + Cancel + + { + // Radix's AlertDialogAction auto-closes the dialog on click. + // Stop that so our async handler controls the close, keeping + // the "Disconnecting…" state visible until the request lands + // and surfacing any error inside the dialog context. + e.preventDefault(); + void handleConfirmDisconnect(); + }} + disabled={togglingCheck !== null} + > + {togglingCheck !== null ? 'Disconnecting...' : 'Disconnect'} + + +
+
+ {/* Configure Integration Dialog - opens after OAuth success or when clicking Configure */} {configureConnection && ( = {}) => ({ + integrationId: 'github', + integrationName: 'GitHub', + integrationLogoUrl: '/github.png', + checkId: 'branch_protection', + checkName: 'Branch Protection', + checkDescription: 'Ensures branches are protected', + isConnected: true, + isDisabledForTask: false, + needsConfiguration: false, + connectionId: 'icn_1', + connectionStatus: 'active', + ...overrides, +}); + +const createJsonResponse = (body: unknown, status = 200): Response => + new Response(JSON.stringify(body), { + status, + headers: { 'Content-Type': 'application/json' }, + }); + +const wrapper = ({ children }: { children: ReactNode }) => ( + new Map(), + dedupingInterval: 0, + shouldRetryOnError: false, + revalidateOnFocus: false, + refreshInterval: 0, + }} + > + {children} + +); + +describe('useIntegrationChecks', () => { + let fetchMock: ReturnType; + + beforeEach(() => { + fetchMock = vi.fn(); + vi.stubGlobal('fetch', fetchMock); + }); + + afterEach(() => { + vi.unstubAllGlobals(); + vi.clearAllMocks(); + }); + + const mockInitialLoad = ( + checks: ReturnType[], + runs: unknown[] = [], + ) => { + fetchMock.mockImplementation((url: string) => { + if (url.includes('/checks')) { + return Promise.resolve( + createJsonResponse({ + checks, + task: { id: TASK_ID, title: 'Test', templateId: 'tpl_1' }, + }), + ); + } + if (url.includes('/runs')) { + return Promise.resolve(createJsonResponse({ runs })); + } + return Promise.resolve(createJsonResponse({})); + }); + }; + + it('loads checks and exposes the disabled flag', async () => { + mockInitialLoad([ + makeCheck({ checkId: 'a', isDisabledForTask: false }), + makeCheck({ checkId: 'b', isDisabledForTask: true }), + ]); + + const { result } = renderHook( + () => useIntegrationChecks({ taskId: TASK_ID, orgId: ORG_ID }), + { wrapper }, + ); + + await waitFor(() => expect(result.current.isLoading).toBe(false)); + + expect(result.current.checks).toHaveLength(2); + expect(result.current.checks[0]!.isDisabledForTask).toBe(false); + expect(result.current.checks[1]!.isDisabledForTask).toBe(true); + }); + + it('disconnectCheckFromTask POSTs to the disconnect endpoint and updates the cache', async () => { + mockInitialLoad([makeCheck({ checkId: 'branch_protection' })]); + + const { result } = renderHook( + () => useIntegrationChecks({ taskId: TASK_ID, orgId: ORG_ID }), + { wrapper }, + ); + + await waitFor(() => expect(result.current.isLoading).toBe(false)); + + // After initial load, queue the disconnect POST response and a refetch + // (SWR revalidates after mutate). + fetchMock.mockImplementation((url: string) => { + if (url.includes('/checks/disconnect')) { + return Promise.resolve( + createJsonResponse({ success: true, disabled: true }), + ); + } + if (url.includes('/checks?')) { + return Promise.resolve( + createJsonResponse({ + checks: [ + makeCheck({ + checkId: 'branch_protection', + isDisabledForTask: true, + }), + ], + task: { id: TASK_ID, title: 'Test', templateId: 'tpl_1' }, + }), + ); + } + if (url.includes('/runs')) { + return Promise.resolve(createJsonResponse({ runs: [] })); + } + return Promise.resolve(createJsonResponse({})); + }); + + await act(async () => { + await result.current.disconnectCheckFromTask( + 'icn_1', + 'branch_protection', + ); + }); + + // Verify the POST was sent + const disconnectCall = fetchMock.mock.calls.find(([url]) => + String(url).includes('/checks/disconnect'), + ); + expect(disconnectCall).toBeTruthy(); + const disconnectInit = disconnectCall![1] as RequestInit; + expect(disconnectInit.method).toBe('POST'); + expect(JSON.parse(disconnectInit.body as string)).toEqual({ + connectionId: 'icn_1', + checkId: 'branch_protection', + }); + + // Cache should reflect the updated state + await waitFor(() => + expect(result.current.checks[0]!.isDisabledForTask).toBe(true), + ); + }); + + it('reconnectCheckToTask POSTs to the reconnect endpoint and updates the cache', async () => { + mockInitialLoad([ + makeCheck({ checkId: 'branch_protection', isDisabledForTask: true }), + ]); + + const { result } = renderHook( + () => useIntegrationChecks({ taskId: TASK_ID, orgId: ORG_ID }), + { wrapper }, + ); + + await waitFor(() => expect(result.current.isLoading).toBe(false)); + + fetchMock.mockImplementation((url: string) => { + if (url.includes('/checks/reconnect')) { + return Promise.resolve( + createJsonResponse({ success: true, disabled: false }), + ); + } + if (url.includes('/checks?')) { + return Promise.resolve( + createJsonResponse({ + checks: [ + makeCheck({ + checkId: 'branch_protection', + isDisabledForTask: false, + }), + ], + task: { id: TASK_ID, title: 'Test', templateId: 'tpl_1' }, + }), + ); + } + if (url.includes('/runs')) { + return Promise.resolve(createJsonResponse({ runs: [] })); + } + return Promise.resolve(createJsonResponse({})); + }); + + await act(async () => { + await result.current.reconnectCheckToTask('icn_1', 'branch_protection'); + }); + + const reconnectCall = fetchMock.mock.calls.find(([url]) => + String(url).includes('/checks/reconnect'), + ); + expect(reconnectCall).toBeTruthy(); + + await waitFor(() => + expect(result.current.checks[0]!.isDisabledForTask).toBe(false), + ); + }); + + it('throws and rolls back optimistic updates when the disconnect request fails', async () => { + mockInitialLoad([makeCheck({ checkId: 'branch_protection' })]); + + const { result } = renderHook( + () => useIntegrationChecks({ taskId: TASK_ID, orgId: ORG_ID }), + { wrapper }, + ); + + await waitFor(() => expect(result.current.isLoading).toBe(false)); + + fetchMock.mockImplementation((url: string) => { + if (url.includes('/checks/disconnect')) { + return Promise.resolve( + createJsonResponse({ message: 'Server exploded' }, 500), + ); + } + if (url.includes('/checks?')) { + return Promise.resolve( + createJsonResponse({ + checks: [makeCheck({ checkId: 'branch_protection' })], + task: { id: TASK_ID, title: 'Test', templateId: 'tpl_1' }, + }), + ); + } + if (url.includes('/runs')) { + return Promise.resolve(createJsonResponse({ runs: [] })); + } + return Promise.resolve(createJsonResponse({})); + }); + + await expect( + act(async () => { + await result.current.disconnectCheckFromTask( + 'icn_1', + 'branch_protection', + ); + }), + ).rejects.toThrow(); + + // Cache should have rolled back + await waitFor(() => + expect(result.current.checks[0]!.isDisabledForTask).toBe(false), + ); + }); +}); diff --git a/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/hooks/useIntegrationChecks.ts b/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/hooks/useIntegrationChecks.ts index 6ff8cb8d8d..3c85c3c171 100644 --- a/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/hooks/useIntegrationChecks.ts +++ b/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/hooks/useIntegrationChecks.ts @@ -11,6 +11,8 @@ interface TaskIntegrationCheck { checkName: string; checkDescription: string; isConnected: boolean; + /** True when the user has disconnected this specific check from the task. */ + isDisabledForTask: boolean; needsConfiguration: boolean; connectionId?: string; connectionStatus?: string; @@ -134,6 +136,90 @@ export function useIntegrationChecks({ taskId, orgId }: UseIntegrationChecksOpti throw new Error('Failed to run check'); }; + /** + * Disconnect a single check from the current task. The integration itself + * stays connected — only the (task, check) pair is affected. Applies an + * optimistic update to the SWR cache and revalidates in the background. + */ + const disconnectCheckFromTask = async ( + connectionId: string, + checkId: string, + ): Promise => { + await mutateChecks( + async (current) => { + const response = await api.post<{ + success: boolean; + disabled: true; + error?: string; + }>( + `/v1/integrations/tasks/${taskId}/checks/disconnect?organizationId=${orgId}`, + { connectionId, checkId }, + ); + + if (response.error || !response.data?.success) { + throw new Error(response.error || 'Failed to disconnect check'); + } + + return (current ?? []).map((c) => + c.checkId === checkId && c.connectionId === connectionId + ? { ...c, isDisabledForTask: true } + : c, + ); + }, + { + optimisticData: (current) => + (current ?? []).map((c) => + c.checkId === checkId && c.connectionId === connectionId + ? { ...c, isDisabledForTask: true } + : c, + ), + rollbackOnError: true, + revalidate: true, + }, + ); + }; + + /** + * Re-enable a previously disconnected check for the current task. + */ + const reconnectCheckToTask = async ( + connectionId: string, + checkId: string, + ): Promise => { + await mutateChecks( + async (current) => { + const response = await api.post<{ + success: boolean; + disabled: false; + error?: string; + }>( + `/v1/integrations/tasks/${taskId}/checks/reconnect?organizationId=${orgId}`, + { connectionId, checkId }, + ); + + if (response.error || !response.data?.success) { + throw new Error(response.error || 'Failed to reconnect check'); + } + + return (current ?? []).map((c) => + c.checkId === checkId && c.connectionId === connectionId + ? { ...c, isDisabledForTask: false } + : c, + ); + }, + { + optimisticData: (current) => + (current ?? []).map((c) => + c.checkId === checkId && c.connectionId === connectionId + ? { ...c, isDisabledForTask: false } + : c, + ), + rollbackOnError: true, + revalidate: true, + }, + ); + }; + return { checks: Array.isArray(checks) ? checks : [], runs: Array.isArray(runs) ? runs : [], @@ -142,5 +228,7 @@ export function useIntegrationChecks({ taskId, orgId }: UseIntegrationChecksOpti mutateChecks, mutateRuns, runCheck, + disconnectCheckFromTask, + reconnectCheckToTask, }; } diff --git a/apps/app/src/app/(app)/[orgId]/trust/portal-settings/components/TrustPortalSwitch.tsx b/apps/app/src/app/(app)/[orgId]/trust/portal-settings/components/TrustPortalSwitch.tsx index 7dd9c52955..19fcc8cc7f 100644 --- a/apps/app/src/app/(app)/[orgId]/trust/portal-settings/components/TrustPortalSwitch.tsx +++ b/apps/app/src/app/(app)/[orgId]/trust/portal-settings/components/TrustPortalSwitch.tsx @@ -668,9 +668,11 @@ export function TrustPortalSwitch({ }} onToggle={async (checked) => { try { - await updateFrameworkSettings({ - soc2type2: checked, - }); + await updateFrameworkSettings( + checked + ? { soc2type2: true } + : { soc2: false, soc2type2: false }, + ); toast.success('SOC 2 Type 2 status updated'); } catch (error) { toast.error('Failed to update SOC 2 Type 2 status');