diff --git a/.gitignore b/.gitignore index 4f5ba5dbea..9cd387c21e 100644 --- a/.gitignore +++ b/.gitignore @@ -95,4 +95,5 @@ scripts/sync-release-branch.sh .claude/audit-findings.md -.superpowers/* \ No newline at end of file +.superpowers/* +.claude/worktrees/ diff --git a/apps/api/package.json b/apps/api/package.json index 6bf71b788e..01e0ee02c8 100644 --- a/apps/api/package.json +++ b/apps/api/package.json @@ -8,10 +8,57 @@ "@ai-sdk/groq": "^2.0.32", "@ai-sdk/openai": "^2.0.65", "@aws-sdk/client-ec2": "^3.911.0", - "@aws-sdk/client-s3": "^3.859.0", + "@aws-sdk/client-s3": "3.1013.0", + "@aws-sdk/client-acm": "^3.948.0", + "@aws-sdk/client-backup": "^3.948.0", + "@aws-sdk/client-cloudtrail": "^3.948.0", + "@aws-sdk/client-cloudwatch": "^3.948.0", + "@aws-sdk/client-cost-explorer": "^3.948.0", + "@aws-sdk/client-cloudwatch-logs": "^3.948.0", + "@aws-sdk/client-config-service": "^3.948.0", + "@aws-sdk/client-dynamodb": "^3.948.0", + "@aws-sdk/client-ecr": "^3.948.0", + "@aws-sdk/client-ecs": "^3.948.0", + "@aws-sdk/client-efs": "^3.948.0", + "@aws-sdk/client-eks": "^3.948.0", + "@aws-sdk/client-elastic-load-balancing-v2": "^3.948.0", + "@aws-sdk/client-guardduty": "^3.948.0", + "@aws-sdk/client-iam": "^3.948.0", + "@aws-sdk/client-inspector2": "^3.948.0", + "@aws-sdk/client-kms": "^3.948.0", + "@aws-sdk/client-lambda": "^3.948.0", + "@aws-sdk/client-macie2": "^3.948.0", + "@aws-sdk/client-opensearch": "^3.948.0", + "@aws-sdk/client-rds": "^3.948.0", + "@aws-sdk/client-redshift": "^3.948.0", + "@aws-sdk/client-route-53": "^3.948.0", + "@aws-sdk/client-secrets-manager": "^3.948.0", "@aws-sdk/client-securityhub": "^3.948.0", + "@aws-sdk/client-sns": "^3.948.0", + "@aws-sdk/client-sqs": "^3.948.0", + "@aws-sdk/client-wafv2": "^3.948.0", + "@aws-sdk/client-api-gateway": "^3.948.0", + "@aws-sdk/client-apigatewayv2": "^3.948.0", + "@aws-sdk/client-appflow": "^3.948.0", + "@aws-sdk/client-athena": "^3.948.0", + "@aws-sdk/client-cloudfront": "^3.948.0", + "@aws-sdk/client-codebuild": "^3.948.0", + "@aws-sdk/client-cognito-identity-provider": "^3.948.0", + "@aws-sdk/client-elastic-beanstalk": "^3.948.0", + "@aws-sdk/client-elasticache": "^3.948.0", + "@aws-sdk/client-emr": "^3.948.0", + "@aws-sdk/client-eventbridge": "^3.948.0", + "@aws-sdk/client-glue": "^3.948.0", + "@aws-sdk/client-kafka": "^3.948.0", + "@aws-sdk/client-kinesis": "^3.948.0", + "@aws-sdk/client-network-firewall": "^3.948.0", + "@aws-sdk/client-sagemaker": "^3.948.0", + "@aws-sdk/client-sfn": "^3.948.0", + "@aws-sdk/client-shield": "^3.948.0", + "@aws-sdk/client-ssm": "^3.948.0", "@aws-sdk/client-sts": "^3.948.0", - "@aws-sdk/s3-request-presigner": "^3.859.0", + "@aws-sdk/client-transfer": "^3.948.0", + "@aws-sdk/s3-request-presigner": "3.1013.0", "@browserbasehq/sdk": "2.6.0", "@browserbasehq/stagehand": "^3.0.5", "@mendable/firecrawl-js": "^4.9.3", diff --git a/apps/api/src/admin-organizations/admin-audit-log.interceptor.spec.ts b/apps/api/src/admin-organizations/admin-audit-log.interceptor.spec.ts index b0f2eeaec3..3050c80b64 100644 --- a/apps/api/src/admin-organizations/admin-audit-log.interceptor.spec.ts +++ b/apps/api/src/admin-organizations/admin-audit-log.interceptor.spec.ts @@ -18,12 +18,36 @@ jest.mock('@db', () => ({ }, Prisma: {}, db: { - auditLog: { get create() { return mockCreate; } }, - policy: { get findFirst() { return mockPolicyFind; } }, - taskItem: { get findFirst() { return mockTaskFind; } }, - vendor: { get findFirst() { return mockVendorFind; } }, - finding: { get findFirst() { return mockFindingFind; } }, - context: { get findFirst() { return mockContextFind; } }, + auditLog: { + get create() { + return mockCreate; + }, + }, + policy: { + get findFirst() { + return mockPolicyFind; + }, + }, + taskItem: { + get findFirst() { + return mockTaskFind; + }, + }, + vendor: { + get findFirst() { + return mockVendorFind; + }, + }, + finding: { + get findFirst() { + return mockFindingFind; + }, + }, + context: { + get findFirst() { + return mockContextFind; + }, + }, }, })); diff --git a/apps/api/src/admin-organizations/admin-context.controller.ts b/apps/api/src/admin-organizations/admin-context.controller.ts index 37d7ca82a1..2084bfbc99 100644 --- a/apps/api/src/admin-organizations/admin-context.controller.ts +++ b/apps/api/src/admin-organizations/admin-context.controller.ts @@ -43,7 +43,9 @@ export class AdminContextController { } @Post(':orgId/context') - @ApiOperation({ summary: 'Create a context entry for an organization (admin)' }) + @ApiOperation({ + summary: 'Create a context entry for an organization (admin)', + }) @UsePipes( new ValidationPipe({ whitelist: true, @@ -59,7 +61,9 @@ export class AdminContextController { } @Patch(':orgId/context/:contextId') - @ApiOperation({ summary: 'Update a context entry for an organization (admin)' }) + @ApiOperation({ + summary: 'Update a context entry for an organization (admin)', + }) @UsePipes( new ValidationPipe({ whitelist: true, diff --git a/apps/api/src/admin-organizations/admin-evidence.controller.spec.ts b/apps/api/src/admin-organizations/admin-evidence.controller.spec.ts index 8b7ecb8709..907eb2fd4f 100644 --- a/apps/api/src/admin-organizations/admin-evidence.controller.spec.ts +++ b/apps/api/src/admin-organizations/admin-evidence.controller.spec.ts @@ -28,9 +28,7 @@ describe('AdminEvidenceController', () => { beforeEach(async () => { const module: TestingModule = await Test.createTestingModule({ controllers: [AdminEvidenceController], - providers: [ - { provide: EvidenceFormsService, useValue: mockService }, - ], + providers: [{ provide: EvidenceFormsService, useValue: mockService }], }).compile(); controller = module.get(AdminEvidenceController); diff --git a/apps/api/src/admin-organizations/admin-evidence.controller.ts b/apps/api/src/admin-organizations/admin-evidence.controller.ts index 463586dd7c..42cb24d970 100644 --- a/apps/api/src/admin-organizations/admin-evidence.controller.ts +++ b/apps/api/src/admin-organizations/admin-evidence.controller.ts @@ -24,12 +24,12 @@ import { @UseInterceptors(AdminAuditLogInterceptor) @Throttle({ default: { ttl: 60000, limit: 30 } }) export class AdminEvidenceController { - constructor( - private readonly evidenceFormsService: EvidenceFormsService, - ) {} + constructor(private readonly evidenceFormsService: EvidenceFormsService) {} @Get(':orgId/evidence-forms') - @ApiOperation({ summary: 'List evidence form statuses for an organization (admin)' }) + @ApiOperation({ + summary: 'List evidence form statuses for an organization (admin)', + }) async listFormStatuses(@Param('orgId') orgId: string) { return this.evidenceFormsService.getFormStatuses(orgId); } @@ -53,8 +53,12 @@ export class AdminEvidenceController { authContext: buildPlatformAdminAuthContext(req.userId, orgId), formType, search, - limit: limit ? String(Math.min(200, Math.max(1, parseInt(limit, 10) || 1))) : undefined, - offset: offset ? String(Math.max(0, parseInt(offset, 10) || 0)) : undefined, + limit: limit + ? String(Math.min(200, Math.max(1, parseInt(limit, 10) || 1))) + : undefined, + offset: offset + ? String(Math.max(0, parseInt(offset, 10) || 0)) + : undefined, }); } } diff --git a/apps/api/src/admin-organizations/admin-findings.controller.ts b/apps/api/src/admin-organizations/admin-findings.controller.ts index d208cdf900..0b2b6965e5 100644 --- a/apps/api/src/admin-organizations/admin-findings.controller.ts +++ b/apps/api/src/admin-organizations/admin-findings.controller.ts @@ -33,10 +33,7 @@ export class AdminFindingsController { @Get(':orgId/findings') @ApiOperation({ summary: 'List all findings for an organization (admin)' }) - async list( - @Param('orgId') orgId: string, - @Query('status') status?: string, - ) { + async list(@Param('orgId') orgId: string, @Query('status') status?: string) { let validatedStatus: FindingStatus | undefined; if (status) { if (!Object.values(FindingStatus).includes(status as FindingStatus)) { diff --git a/apps/api/src/admin-organizations/admin-guard-integration.spec.ts b/apps/api/src/admin-organizations/admin-guard-integration.spec.ts index b00a032ac5..fb4310d0df 100644 --- a/apps/api/src/admin-organizations/admin-guard-integration.spec.ts +++ b/apps/api/src/admin-organizations/admin-guard-integration.spec.ts @@ -99,9 +99,7 @@ describe('PlatformAdminGuard — runtime rejection scenarios', () => { }); const ctx = buildContext({ cookie: 'session=valid' }); - await expect(guard.canActivate(ctx)).rejects.toThrow( - ForbiddenException, - ); + await expect(guard.canActivate(ctx)).rejects.toThrow(ForbiddenException); await expect(guard.canActivate(ctx)).rejects.toThrow( 'Access denied: Platform admin privileges required', ); @@ -116,9 +114,7 @@ describe('PlatformAdminGuard — runtime rejection scenarios', () => { }); const ctx = buildContext({ cookie: 'session=valid' }); - await expect(guard.canActivate(ctx)).rejects.toThrow( - ForbiddenException, - ); + await expect(guard.canActivate(ctx)).rejects.toThrow(ForbiddenException); }); it('rejects a user with role "owner" (org role, not platform admin)', async () => { @@ -130,9 +126,7 @@ describe('PlatformAdminGuard — runtime rejection scenarios', () => { }); const ctx = buildContext({ cookie: 'session=valid' }); - await expect(guard.canActivate(ctx)).rejects.toThrow( - ForbiddenException, - ); + await expect(guard.canActivate(ctx)).rejects.toThrow(ForbiddenException); }); it('rejects when session claims admin but DB says user', async () => { @@ -146,9 +140,7 @@ describe('PlatformAdminGuard — runtime rejection scenarios', () => { }); const ctx = buildContext({ authorization: 'Bearer valid' }); - await expect(guard.canActivate(ctx)).rejects.toThrow( - ForbiddenException, - ); + await expect(guard.canActivate(ctx)).rejects.toThrow(ForbiddenException); expect(mockFindUnique).toHaveBeenCalledWith({ where: { id: 'usr_sneaky' }, select: { id: true, email: true, role: true }, diff --git a/apps/api/src/admin-organizations/admin-organizations.controller.ts b/apps/api/src/admin-organizations/admin-organizations.controller.ts index 523a6a4742..fea7a03a99 100644 --- a/apps/api/src/admin-organizations/admin-organizations.controller.ts +++ b/apps/api/src/admin-organizations/admin-organizations.controller.ts @@ -43,10 +43,25 @@ export class AdminOrganizationsController { } @Get('activity') - @ApiOperation({ summary: 'Organization activity report - shows last session per org (platform admin)' }) - @ApiQuery({ name: 'inactiveDays', required: false, description: 'Filter orgs with no session in N days (default: 90)' }) - @ApiQuery({ name: 'hasAccess', required: false, description: 'Filter by hasAccess (true/false)' }) - @ApiQuery({ name: 'onboarded', required: false, description: 'Filter by onboardingCompleted (true/false)' }) + @ApiOperation({ + summary: + 'Organization activity report - shows last session per org (platform admin)', + }) + @ApiQuery({ + name: 'inactiveDays', + required: false, + description: 'Filter orgs with no session in N days (default: 90)', + }) + @ApiQuery({ + name: 'hasAccess', + required: false, + description: 'Filter by hasAccess (true/false)', + }) + @ApiQuery({ + name: 'onboarded', + required: false, + description: 'Filter by onboardingCompleted (true/false)', + }) @ApiQuery({ name: 'page', required: false }) @ApiQuery({ name: 'limit', required: false }) async activity( @@ -57,9 +72,16 @@ export class AdminOrganizationsController { @Query('limit') limit?: string, ) { return this.service.getOrgActivity({ - inactiveDays: Math.max(0, Number.isFinite(parseInt(inactiveDays ?? '90', 10)) ? parseInt(inactiveDays ?? '90', 10) : 90), - hasAccess: hasAccess === 'true' ? true : hasAccess === 'false' ? false : undefined, - onboarded: onboarded === 'true' ? true : onboarded === 'false' ? false : undefined, + inactiveDays: Math.max( + 0, + Number.isFinite(parseInt(inactiveDays ?? '90', 10)) + ? parseInt(inactiveDays ?? '90', 10) + : 90, + ), + hasAccess: + hasAccess === 'true' ? true : hasAccess === 'false' ? false : undefined, + onboarded: + onboarded === 'true' ? true : onboarded === 'false' ? false : undefined, page: Math.max(1, parseInt(page || '1', 10) || 1), limit: Math.min(100, Math.max(1, parseInt(limit || '50', 10) || 50)), }); @@ -109,9 +131,19 @@ export class AdminOrganizationsController { } @Get(':id/audit-logs') - @ApiOperation({ summary: 'Get audit logs for an organization (platform admin)' }) - @ApiQuery({ name: 'entityType', required: false, description: 'Filter by entity type (e.g. policy, task)' }) - @ApiQuery({ name: 'take', required: false, description: 'Number of logs to return (max 100, default 100)' }) + @ApiOperation({ + summary: 'Get audit logs for an organization (platform admin)', + }) + @ApiQuery({ + name: 'entityType', + required: false, + description: 'Filter by entity type (e.g. policy, task)', + }) + @ApiQuery({ + name: 'take', + required: false, + description: 'Number of logs to return (max 100, default 100)', + }) async getAuditLogs( @Param('id') id: string, @Query('entityType') entityType?: string, diff --git a/apps/api/src/admin-organizations/admin-organizations.service.ts b/apps/api/src/admin-organizations/admin-organizations.service.ts index 46a3239de4..ec3b9fa2a9 100644 --- a/apps/api/src/admin-organizations/admin-organizations.service.ts +++ b/apps/api/src/admin-organizations/admin-organizations.service.ts @@ -125,7 +125,14 @@ export class AdminOrganizationsService { createdAt: true, hasAccess: true, onboardingCompleted: true, - _count: { select: { members: true, tasks: true, policy: true, auditLog: true } }, + _count: { + select: { + members: true, + tasks: true, + policy: true, + auditLog: true, + }, + }, members: { where: { deactivated: false }, select: { @@ -168,14 +175,20 @@ export class AdminOrganizationsService { lastSession = sess; } if (member.role?.includes('owner') && !owner) { - owner = { id: member.user.id, name: member.user.name, email: member.user.email }; + owner = { + id: member.user.id, + name: member.user.name, + email: member.user.email, + }; } } const lastAuditLog = org.auditLog?.[0]?.timestamp ?? null; const lastActivity = [lastSession, lastAuditLog] .filter(Boolean) - .sort((a, b) => (b as Date).getTime() - (a as Date).getTime())[0] as Date | undefined; + .sort((a, b) => (b as Date).getTime() - (a as Date).getTime())[0] as + | Date + | undefined; const isActive = lastActivity ? lastActivity >= cutoff : false; @@ -191,7 +204,7 @@ export class AdminOrganizationsService { auditLogCount: org._count.auditLog, owner, lastSession: lastSession?.toISOString() ?? null, - lastAuditLog: lastAuditLog ? (lastAuditLog as Date).toISOString() : null, + lastAuditLog: lastAuditLog ? lastAuditLog.toISOString() : null, lastActivity: lastActivity?.toISOString() ?? null, isActive, }; diff --git a/apps/api/src/admin-organizations/admin-policies.controller.spec.ts b/apps/api/src/admin-organizations/admin-policies.controller.spec.ts index d3ec2b8e7e..36132396c5 100644 --- a/apps/api/src/admin-organizations/admin-policies.controller.spec.ts +++ b/apps/api/src/admin-organizations/admin-policies.controller.spec.ts @@ -29,9 +29,7 @@ jest.mock('@db', () => ({ jest.mock('@trigger.dev/sdk', () => ({ auth: { - createPublicToken: jest - .fn() - .mockResolvedValue('mock-public-access-token'), + createPublicToken: jest.fn().mockResolvedValue('mock-public-access-token'), }, tasks: { trigger: jest.fn().mockResolvedValue({ id: 'run_123' }), @@ -85,9 +83,9 @@ describe('AdminPoliciesController', () => { }); it('should reject missing status', async () => { - await expect( - controller.update('org_1', 'pol_1', {}), - ).rejects.toThrow(BadRequestException); + await expect(controller.update('org_1', 'pol_1', {})).rejects.toThrow( + BadRequestException, + ); }); it('should reject invalid status', async () => { diff --git a/apps/api/src/admin-organizations/admin-policies.controller.ts b/apps/api/src/admin-organizations/admin-policies.controller.ts index 1dcfa6c952..090ec175e1 100644 --- a/apps/api/src/admin-organizations/admin-policies.controller.ts +++ b/apps/api/src/admin-organizations/admin-policies.controller.ts @@ -79,9 +79,7 @@ export class AdminPoliciesController { const updateData: Record = {}; if (body.status !== undefined) { - if ( - !Object.values(PolicyStatus).includes(body.status as PolicyStatus) - ) { + if (!Object.values(PolicyStatus).includes(body.status as PolicyStatus)) { throw new BadRequestException( `Invalid status. Must be one of: ${Object.values(PolicyStatus).join(', ')}`, ); @@ -135,9 +133,7 @@ export class AdminPoliciesController { }); const uniqueFrameworks = Array.from( - new Map( - instances.map((fi) => [fi.framework.id, fi.framework]), - ).values(), + new Map(instances.map((fi) => [fi.framework.id, fi.framework])).values(), ).map((f) => ({ id: f.id, name: f.name, diff --git a/apps/api/src/admin-organizations/admin-security.spec.ts b/apps/api/src/admin-organizations/admin-security.spec.ts index 4f95a5516f..04b3e6e343 100644 --- a/apps/api/src/admin-organizations/admin-security.spec.ts +++ b/apps/api/src/admin-organizations/admin-security.spec.ts @@ -1,5 +1,8 @@ import 'reflect-metadata'; -import { GUARDS_METADATA, INTERCEPTORS_METADATA } from '@nestjs/common/constants'; +import { + GUARDS_METADATA, + INTERCEPTORS_METADATA, +} from '@nestjs/common/constants'; import { PlatformAdminGuard } from '../auth/platform-admin.guard'; import { AdminAuditLogInterceptor } from './admin-audit-log.interceptor'; import { AdminOrganizationsController } from './admin-organizations.controller'; @@ -18,7 +21,12 @@ jest.mock('../auth/auth.server', () => ({ jest.mock('@db', () => ({ db: {}, - FindingStatus: { open: 'open', ready_for_review: 'ready_for_review', needs_revision: 'needs_revision', closed: 'closed' }, + FindingStatus: { + open: 'open', + ready_for_review: 'ready_for_review', + needs_revision: 'needs_revision', + closed: 'closed', + }, FindingType: { soc2: 'soc2', iso27001: 'iso27001' }, TaskStatus: { todo: 'todo', in_progress: 'in_progress', done: 'done' }, TaskFrequency: { daily: 'daily', weekly: 'weekly', monthly: 'monthly' }, @@ -41,7 +49,10 @@ jest.mock('@trycompai/integration-platform', () => ({ })); const ORG_ADMIN_CONTROLLERS = [ - { name: 'AdminOrganizationsController', controller: AdminOrganizationsController }, + { + name: 'AdminOrganizationsController', + controller: AdminOrganizationsController, + }, { name: 'AdminFindingsController', controller: AdminFindingsController }, { name: 'AdminPoliciesController', controller: AdminPoliciesController }, { name: 'AdminTasksController', controller: AdminTasksController }, @@ -51,49 +62,46 @@ const ORG_ADMIN_CONTROLLERS = [ ]; describe('Admin controllers security baseline', () => { - describe.each(ORG_ADMIN_CONTROLLERS)( - '$name', - ({ controller }) => { - it('has PlatformAdminGuard applied at the class level', () => { - const guards = Reflect.getMetadata(GUARDS_METADATA, controller) ?? []; - const hasPlatformAdminGuard = guards.some( - (g: unknown) => g === PlatformAdminGuard, - ); - expect(hasPlatformAdminGuard).toBe(true); - }); - - it('has AdminAuditLogInterceptor applied at the class level', () => { - const interceptors = - Reflect.getMetadata(INTERCEPTORS_METADATA, controller) ?? []; - const hasAuditInterceptor = interceptors.some( - (i: unknown) => i === AdminAuditLogInterceptor, - ); - expect(hasAuditInterceptor).toBe(true); - }); - - it('uses the correct controller path prefix', () => { - const path = Reflect.getMetadata('path', controller); - expect(path).toBe('admin/organizations'); - }); - - it('uses versioned controller format', () => { - const version = Reflect.getMetadata('__version__', controller); - expect(version).toBeDefined(); - }); - - it('does NOT use HybridAuthGuard (admin controllers use PlatformAdminGuard)', () => { - const guards = Reflect.getMetadata(GUARDS_METADATA, controller) ?? []; - const guardNames = guards.map((g: { name?: string }) => g.name); - expect(guardNames).not.toContain('HybridAuthGuard'); - }); - - it('does NOT use PermissionGuard (admin controllers bypass RBAC)', () => { - const guards = Reflect.getMetadata(GUARDS_METADATA, controller) ?? []; - const guardNames = guards.map((g: { name?: string }) => g.name); - expect(guardNames).not.toContain('PermissionGuard'); - }); - }, - ); + describe.each(ORG_ADMIN_CONTROLLERS)('$name', ({ controller }) => { + it('has PlatformAdminGuard applied at the class level', () => { + const guards = Reflect.getMetadata(GUARDS_METADATA, controller) ?? []; + const hasPlatformAdminGuard = guards.some( + (g: unknown) => g === PlatformAdminGuard, + ); + expect(hasPlatformAdminGuard).toBe(true); + }); + + it('has AdminAuditLogInterceptor applied at the class level', () => { + const interceptors = + Reflect.getMetadata(INTERCEPTORS_METADATA, controller) ?? []; + const hasAuditInterceptor = interceptors.some( + (i: unknown) => i === AdminAuditLogInterceptor, + ); + expect(hasAuditInterceptor).toBe(true); + }); + + it('uses the correct controller path prefix', () => { + const path = Reflect.getMetadata('path', controller); + expect(path).toBe('admin/organizations'); + }); + + it('uses versioned controller format', () => { + const version = Reflect.getMetadata('__version__', controller); + expect(version).toBeDefined(); + }); + + it('does NOT use HybridAuthGuard (admin controllers use PlatformAdminGuard)', () => { + const guards = Reflect.getMetadata(GUARDS_METADATA, controller) ?? []; + const guardNames = guards.map((g: { name?: string }) => g.name); + expect(guardNames).not.toContain('HybridAuthGuard'); + }); + + it('does NOT use PermissionGuard (admin controllers bypass RBAC)', () => { + const guards = Reflect.getMetadata(GUARDS_METADATA, controller) ?? []; + const guardNames = guards.map((g: { name?: string }) => g.name); + expect(guardNames).not.toContain('PermissionGuard'); + }); + }); it('covers all 7 expected org-scoped admin controllers', () => { expect(ORG_ADMIN_CONTROLLERS).toHaveLength(7); diff --git a/apps/api/src/admin-organizations/admin-tasks.controller.ts b/apps/api/src/admin-organizations/admin-tasks.controller.ts index 946be61a63..e13026fb47 100644 --- a/apps/api/src/admin-organizations/admin-tasks.controller.ts +++ b/apps/api/src/admin-organizations/admin-tasks.controller.ts @@ -76,7 +76,10 @@ export class AdminTasksController { } @Get(':orgId/tasks/:taskId/details') - @ApiOperation({ summary: 'Get task details with comments, attachments, and evidence (admin)' }) + @ApiOperation({ + summary: + 'Get task details with comments, attachments, and evidence (admin)', + }) async getDetails( @Param('orgId') orgId: string, @Param('taskId') taskId: string, @@ -86,11 +89,7 @@ export class AdminTasksController { const [comments, attachments, automationRuns, integrationRuns] = await Promise.all([ - this.commentsService.getComments( - orgId, - taskId, - CommentEntityType.task, - ), + this.commentsService.getComments(orgId, taskId, CommentEntityType.task), this.attachmentsService.getAttachments( orgId, taskId, @@ -173,9 +172,7 @@ export class AdminTasksController { if (body.frequency !== undefined) { if ( body.frequency !== null && - !Object.values(TaskFrequency).includes( - body.frequency as TaskFrequency, - ) + !Object.values(TaskFrequency).includes(body.frequency as TaskFrequency) ) { throw new BadRequestException( `Invalid frequency. Must be one of: ${Object.values(TaskFrequency).join(', ')}`, diff --git a/apps/api/src/admin-organizations/admin-vendors.controller.ts b/apps/api/src/admin-organizations/admin-vendors.controller.ts index 29cedae5d0..2aea33f513 100644 --- a/apps/api/src/admin-organizations/admin-vendors.controller.ts +++ b/apps/api/src/admin-organizations/admin-vendors.controller.ts @@ -67,9 +67,7 @@ export class AdminVendorsController { const updateData: Record = {}; if (body.status !== undefined) { - if ( - !Object.values(VendorStatus).includes(body.status as VendorStatus) - ) { + if (!Object.values(VendorStatus).includes(body.status as VendorStatus)) { throw new BadRequestException( `Invalid status. Must be one of: ${Object.values(VendorStatus).join(', ')}`, ); @@ -79,9 +77,7 @@ export class AdminVendorsController { if (body.category !== undefined) { if ( - !Object.values(VendorCategory).includes( - body.category as VendorCategory, - ) + !Object.values(VendorCategory).includes(body.category as VendorCategory) ) { throw new BadRequestException( `Invalid category. Must be one of: ${Object.values(VendorCategory).join(', ')}`, diff --git a/apps/api/src/admin-organizations/dto/invite-member.dto.ts b/apps/api/src/admin-organizations/dto/invite-member.dto.ts index f34440e257..4617ee0e72 100644 --- a/apps/api/src/admin-organizations/dto/invite-member.dto.ts +++ b/apps/api/src/admin-organizations/dto/invite-member.dto.ts @@ -15,7 +15,9 @@ export class InviteMemberDto { example: 'user@example.com', }) @IsEmail({}, { message: 'A valid email address is required' }) - @Transform(({ value }) => (typeof value === 'string' ? value.toLowerCase().trim() : value)) + @Transform(({ value }) => + typeof value === 'string' ? value.toLowerCase().trim() : value, + ) email: string; @ApiProperty({ diff --git a/apps/api/src/app/s3.ts b/apps/api/src/app/s3.ts index c20f3facc7..d1e445e439 100644 --- a/apps/api/src/app/s3.ts +++ b/apps/api/src/app/s3.ts @@ -1,11 +1,25 @@ import { GetObjectCommand, + PutObjectCommand, S3Client, type GetObjectCommandOutput, } from '@aws-sdk/client-s3'; +import { getSignedUrl as _getSignedUrl } from '@aws-sdk/s3-request-presigner'; import { Logger } from '@nestjs/common'; import '../config/load-env'; +/** + * Re-export getSignedUrl with a type workaround for duplicate @smithy/types. + * Bun/Docker installs separate @smithy/types copies for @aws-sdk/client-s3 + * and @aws-sdk/s3-request-presigner even when pinned to the same version. + * The runtime types are fully compatible — only the TypeScript class identity differs. + */ +export const getSignedUrl = _getSignedUrl as unknown as ( + client: S3Client, + command: GetObjectCommand | PutObjectCommand, + options?: { expiresIn?: number }, +) => Promise; + const logger = new Logger('S3'); const APP_AWS_REGION = process.env.APP_AWS_REGION; diff --git a/apps/api/src/assistant-chat/assistant-chat-tools.ts b/apps/api/src/assistant-chat/assistant-chat-tools.ts index 9d487e3e8b..5e501a4c7b 100644 --- a/apps/api/src/assistant-chat/assistant-chat-tools.ts +++ b/apps/api/src/assistant-chat/assistant-chat-tools.ts @@ -9,12 +9,23 @@ interface ToolContext { permissions: Permissions; } -function hasPermission(permissions: Permissions, resource: string, action: string): boolean { +function hasPermission( + permissions: Permissions, + resource: string, + action: string, +): boolean { return permissions[resource]?.includes(action) ?? false; } export function buildTools(ctx: ToolContext) { - const tools: Record Promise }> = {}; + const tools: Record< + string, + { + description: string; + inputSchema: z.ZodType; + execute: (...args: any[]) => Promise; + } + > = {}; // Always available tools.findOrganization = { @@ -25,7 +36,9 @@ export function buildTools(ctx: ToolContext) { where: { id: ctx.organizationId }, select: { name: true }, }); - return org ? { organization: org } : { organization: null, message: 'Organization not found' }; + return org + ? { organization: org } + : { organization: null, message: 'Organization not found' }; }, }; @@ -57,14 +70,17 @@ export function buildTools(ctx: ToolContext) { }; tools.getPolicyContent = { - description: 'Get the content of a specific policy by id. Run getPolicies first to get ids.', + description: + 'Get the content of a specific policy by id. Run getPolicies first to get ids.', inputSchema: z.object({ id: z.string() }), execute: async ({ id }: { id: string }) => { const policy = await db.policy.findUnique({ where: { id, organizationId: ctx.organizationId }, select: { content: true }, }); - return policy ? { content: policy.content } : { content: null, message: 'Policy not found' }; + return policy + ? { content: policy.content } + : { content: null, message: 'Policy not found' }; }, }; } @@ -74,12 +90,25 @@ export function buildTools(ctx: ToolContext) { tools.getRisks = { description: 'Get risks for the organization', inputSchema: z.object({ - status: z.enum(Object.values(RiskStatus) as [RiskStatus, ...RiskStatus[]]).optional(), - department: z.enum(Object.values(Departments) as [Departments, ...Departments[]]).optional(), - category: z.enum(Object.values(RiskCategory) as [RiskCategory, ...RiskCategory[]]).optional(), + status: z + .enum(Object.values(RiskStatus) as [RiskStatus, ...RiskStatus[]]) + .optional(), + department: z + .enum(Object.values(Departments) as [Departments, ...Departments[]]) + .optional(), + category: z + .enum( + Object.values(RiskCategory) as [RiskCategory, ...RiskCategory[]], + ) + .optional(), owner: z.string().optional(), }), - execute: async (input: { status?: RiskStatus; department?: Departments; category?: RiskCategory; owner?: string }) => { + execute: async (input: { + status?: RiskStatus; + department?: Departments; + category?: RiskCategory; + owner?: string; + }) => { const risks = await db.risk.findMany({ where: { organizationId: ctx.organizationId, diff --git a/apps/api/src/assistant-chat/assistant-chat.controller.ts b/apps/api/src/assistant-chat/assistant-chat.controller.ts index 78e1da8269..bb7b680381 100644 --- a/apps/api/src/assistant-chat/assistant-chat.controller.ts +++ b/apps/api/src/assistant-chat/assistant-chat.controller.ts @@ -20,7 +20,12 @@ import { ApiTags, } from '@nestjs/swagger'; import { openai } from '@ai-sdk/openai'; -import { streamText, convertToModelMessages, stepCountIs, type UIMessage } from 'ai'; +import { + streamText, + convertToModelMessages, + stepCountIs, + type UIMessage, +} from 'ai'; import type { Response, Request } from 'express'; import { AuthContext } from '../auth/auth-context.decorator'; import { HybridAuthGuard } from '../auth/hybrid-auth.guard'; @@ -79,7 +84,9 @@ export class AssistantChatController { // @Res() bypasses NestJS exception filters, so we must handle errors manually try { if (!process.env.OPENAI_API_KEY) { - res.status(HttpStatus.SERVICE_UNAVAILABLE).json({ message: 'AI service not configured' }); + res + .status(HttpStatus.SERVICE_UNAVAILABLE) + .json({ message: 'AI service not configured' }); return; } @@ -155,8 +162,14 @@ Important: } catch (error) { this.logger.error('Completions endpoint error', error); if (!res.headersSent) { - const status = error instanceof HttpException ? error.getStatus() : HttpStatus.INTERNAL_SERVER_ERROR; - const message = error instanceof HttpException ? error.message : 'Internal server error'; + const status = + error instanceof HttpException + ? error.getStatus() + : HttpStatus.INTERNAL_SERVER_ERROR; + const message = + error instanceof HttpException + ? error.message + : 'Internal server error'; res.status(status).json({ message }); } else { res.end(); diff --git a/apps/api/src/attachments/attachments.controller.spec.ts b/apps/api/src/attachments/attachments.controller.spec.ts index 7b2d134285..e58956bfd4 100644 --- a/apps/api/src/attachments/attachments.controller.spec.ts +++ b/apps/api/src/attachments/attachments.controller.spec.ts @@ -57,9 +57,10 @@ describe('AttachmentsController', () => { 'att_abc123', ); - expect( - attachmentsService.getAttachmentDownloadUrl, - ).toHaveBeenCalledWith('org_123', 'att_abc123'); + expect(attachmentsService.getAttachmentDownloadUrl).toHaveBeenCalledWith( + 'org_123', + 'att_abc123', + ); expect(result).toEqual(downloadResult); }); diff --git a/apps/api/src/attachments/attachments.service.ts b/apps/api/src/attachments/attachments.service.ts index fa818e2622..6bb88bed94 100644 --- a/apps/api/src/attachments/attachments.service.ts +++ b/apps/api/src/attachments/attachments.service.ts @@ -5,7 +5,7 @@ import { PutObjectCommand, S3Client, } from '@aws-sdk/client-s3'; -import { getSignedUrl } from '@aws-sdk/s3-request-presigner'; +import { getSignedUrl, s3Client } from '@/app/s3'; import { AttachmentEntityType, AttachmentType, db } from '@db'; import { BadRequestException, @@ -15,7 +15,6 @@ import { import { randomBytes } from 'crypto'; import { AttachmentResponseDto } from '../tasks/dto/task-responses.dto'; import { UploadAttachmentDto } from './upload-attachment.dto'; -import { s3Client } from '@/app/s3'; import { validateFileContent } from '../utils/file-type-validation'; @Injectable() diff --git a/apps/api/src/audit/audit-log.controller.spec.ts b/apps/api/src/audit/audit-log.controller.spec.ts index 3ce604b749..17b01821b8 100644 --- a/apps/api/src/audit/audit-log.controller.spec.ts +++ b/apps/api/src/audit/audit-log.controller.spec.ts @@ -61,10 +61,7 @@ describe('AuditLogController', () => { const mockLogs = [{ id: 'log_1' }, { id: 'log_2' }]; mockFindMany.mockResolvedValue(mockLogs); - const result = await controller.getAuditLogs( - 'org_1', - mockAuthContext, - ); + const result = await controller.getAuditLogs('org_1', mockAuthContext); expect(result).toEqual({ data: mockLogs, @@ -75,7 +72,13 @@ describe('AuditLogController', () => { where: { organizationId: 'org_1' }, include: { user: { - select: { id: true, name: true, email: true, image: true, role: true }, + select: { + id: true, + name: true, + email: true, + image: true, + role: true, + }, }, member: true, organization: true, @@ -88,11 +91,7 @@ describe('AuditLogController', () => { it('should filter by single entityType', async () => { mockFindMany.mockResolvedValue([]); - await controller.getAuditLogs( - 'org_1', - mockAuthContext, - 'policy', - ); + await controller.getAuditLogs('org_1', mockAuthContext, 'policy'); expect(mockFindMany).toHaveBeenCalledWith( expect.objectContaining({ @@ -104,11 +103,7 @@ describe('AuditLogController', () => { it('should filter by multiple comma-separated entityTypes', async () => { mockFindMany.mockResolvedValue([]); - await controller.getAuditLogs( - 'org_1', - mockAuthContext, - 'risk,task', - ); + await controller.getAuditLogs('org_1', mockAuthContext, 'risk,task'); expect(mockFindMany).toHaveBeenCalledWith( expect.objectContaining({ @@ -244,10 +239,7 @@ describe('AuditLogController', () => { userRoles: null, }; - const result = await controller.getAuditLogs( - 'org_1', - authContextNoUser, - ); + const result = await controller.getAuditLogs('org_1', authContextNoUser); expect(result).toEqual({ data: [], diff --git a/apps/api/src/audit/audit-log.controller.ts b/apps/api/src/audit/audit-log.controller.ts index f950aeea97..0e5fda432b 100644 --- a/apps/api/src/audit/audit-log.controller.ts +++ b/apps/api/src/audit/audit-log.controller.ts @@ -15,10 +15,26 @@ export class AuditLogController { @Get() @RequirePermission('app', 'read') @ApiOperation({ summary: 'Get audit logs filtered by entity type and ID' }) - @ApiQuery({ name: 'entityType', required: false, description: 'Filter by entity type (e.g. policy, task, control)' }) - @ApiQuery({ name: 'entityId', required: false, description: 'Filter by entity ID' }) - @ApiQuery({ name: 'pathContains', required: false, description: 'Filter by path substring (e.g. automation ID)' }) - @ApiQuery({ name: 'take', required: false, description: 'Number of logs to return (max 100, default 50)' }) + @ApiQuery({ + name: 'entityType', + required: false, + description: 'Filter by entity type (e.g. policy, task, control)', + }) + @ApiQuery({ + name: 'entityId', + required: false, + description: 'Filter by entity ID', + }) + @ApiQuery({ + name: 'pathContains', + required: false, + description: 'Filter by path substring (e.g. automation ID)', + }) + @ApiQuery({ + name: 'take', + required: false, + description: 'Number of logs to return (max 100, default 50)', + }) async getAuditLogs( @OrganizationId() organizationId: string, @AuthContext() authContext: AuthContextType, @@ -31,12 +47,18 @@ export class AuditLogController { const where: Record = { organizationId }; if (entityType) { // Support comma-separated entity types (e.g. "risk,task") - const types = entityType.split(',').map((t) => t.trim()).filter(Boolean); + const types = entityType + .split(',') + .map((t) => t.trim()) + .filter(Boolean); where.entityType = types.length === 1 ? types[0] : { in: types }; } if (entityId) { // Support comma-separated entity IDs - const ids = entityId.split(',').map((id) => id.trim()).filter(Boolean); + const ids = entityId + .split(',') + .map((id) => id.trim()) + .filter(Boolean); where.entityId = ids.length === 1 ? ids[0] : { in: ids }; } if (pathContains) { @@ -54,7 +76,13 @@ export class AuditLogController { where, include: { user: { - select: { id: true, name: true, email: true, image: true, role: true }, + select: { + id: true, + name: true, + email: true, + image: true, + role: true, + }, }, member: true, organization: true, diff --git a/apps/api/src/audit/audit-log.interceptor.spec.ts b/apps/api/src/audit/audit-log.interceptor.spec.ts index 94be902fbc..e8130d665b 100644 --- a/apps/api/src/audit/audit-log.interceptor.spec.ts +++ b/apps/api/src/audit/audit-log.interceptor.spec.ts @@ -101,7 +101,9 @@ describe('AuditLogInterceptor', () => { } as unknown as ExecutionContext; }; - const createMockCallHandler = (response: unknown = { id: 'new_123' }): CallHandler => ({ + const createMockCallHandler = ( + response: unknown = { id: 'new_123' }, + ): CallHandler => ({ handle: () => of(response), }); @@ -441,7 +443,10 @@ describe('AuditLogInterceptor', () => { url: '/v1/risks', params: {}, }); - const handler = createMockCallHandler({ id: 'risk_789', name: 'Test Risk' }); + const handler = createMockCallHandler({ + id: 'risk_789', + name: 'Test Risk', + }); interceptor.intercept(context, handler).subscribe({ next: () => { @@ -646,7 +651,8 @@ describe('AuditLogInterceptor', () => { url: '/v1/comments', params: {}, body: { - content: '{"type":"doc","content":[{"type":"text","text":"This looks good!"}]}', + content: + '{"type":"doc","content":[{"type":"text","text":"This looks good!"}]}', entityId: 'pol_abc', entityType: 'policy', }, @@ -706,7 +712,10 @@ describe('AuditLogInterceptor', () => { data: expect.objectContaining({ data: expect.objectContaining({ changes: { - assignee: { previous: 'Alice Smith (mem_old)', current: 'Bob Jones (mem_new)' }, + assignee: { + previous: 'Alice Smith (mem_old)', + current: 'Bob Jones (mem_new)', + }, }, }), }), @@ -749,7 +758,10 @@ describe('AuditLogInterceptor', () => { data: expect.objectContaining({ data: expect.objectContaining({ changes: { - assignee: { previous: 'Unassigned', current: 'Bob Jones (mem_new)' }, + assignee: { + previous: 'Unassigned', + current: 'Bob Jones (mem_new)', + }, }, }), }), @@ -808,9 +820,7 @@ describe('AuditLogInterceptor', () => { // Existing controls on the policy mockPolicyFindUnique.mockResolvedValue({ - controls: [ - { id: 'ctrl_1', name: 'Access Control' }, - ], + controls: [{ id: 'ctrl_1', name: 'Access Control' }], }); // Resolve control names @@ -916,9 +926,7 @@ describe('AuditLogInterceptor', () => { // Only one control exists mockPolicyFindUnique.mockResolvedValue({ - controls: [ - { id: 'ctrl_1', name: 'Access Control' }, - ], + controls: [{ id: 'ctrl_1', name: 'Access Control' }], }); mockControlFindMany.mockResolvedValue([ @@ -1132,9 +1140,13 @@ describe('AuditLogInterceptor', () => { method: 'PATCH', url: '/v1/policies/pol_123/versions/ver_abc', params: { id: 'pol_123' }, - body: { content: [{ type: 'doc', content: [{ type: 'text', text: 'Hello' }] }] }, + body: { + content: [{ type: 'doc', content: [{ type: 'text', text: 'Hello' }] }], + }, + }); + const handler = createMockCallHandler({ + data: { versionId: 'ver_abc', version: 3 }, }); - const handler = createMockCallHandler({ data: { versionId: 'ver_abc', version: 3 } }); interceptor.intercept(context, handler).subscribe({ next: () => { @@ -1237,7 +1249,9 @@ describe('AuditLogInterceptor', () => { url: '/v1/policies/pol_123/pdf/signed-url?versionId=ver_456', params: { id: 'pol_123' }, }); - const handler = createMockCallHandler({ url: 'https://s3.example.com/policy.pdf' }); + const handler = createMockCallHandler({ + url: 'https://s3.example.com/policy.pdf', + }); interceptor.intercept(context, handler).subscribe({ next: () => { diff --git a/apps/api/src/audit/audit-log.interceptor.ts b/apps/api/src/audit/audit-log.interceptor.ts index 25b7d14f6c..e8f7667768 100644 --- a/apps/api/src/audit/audit-log.interceptor.ts +++ b/apps/api/src/audit/audit-log.interceptor.ts @@ -8,10 +8,7 @@ import { import { Reflector } from '@nestjs/core'; import { AuditLogEntityType, db, Prisma } from '@db'; import { Observable, from, switchMap, tap } from 'rxjs'; -import { - PERMISSIONS_KEY, - RequiredPermission, -} from '../auth/permission.guard'; +import { PERMISSIONS_KEY, RequiredPermission } from '../auth/permission.guard'; import { AuthenticatedRequest } from '../auth/types'; import { AUDIT_READ_KEY, SKIP_AUDIT_LOG_KEY } from './skip-audit-log.decorator'; import { @@ -61,11 +58,9 @@ export class AuditLogInterceptor implements NestInterceptor { return next.handle(); } - const requiredPermissions = - this.reflector.getAllAndOverride(PERMISSIONS_KEY, [ - context.getHandler(), - context.getClass(), - ]); + const requiredPermissions = this.reflector.getAllAndOverride< + RequiredPermission[] + >(PERMISSIONS_KEY, [context.getHandler(), context.getClass()]); if (!requiredPermissions?.length) { return next.handle(); @@ -108,8 +103,15 @@ export class AuditLogInterceptor implements NestInterceptor { entityId, isUpdate ? Object.keys(requestBody) : null, ).catch((err) => { - this.logger.error('Audit preflight failed, proceeding without pre-flight data', err); - return { previousValues: null, memberNames: {} as Record, relationMappingResult: null }; + this.logger.error( + 'Audit preflight failed, proceeding without pre-flight data', + err, + ); + return { + previousValues: null, + memberNames: {} as Record, + relationMappingResult: null, + }; }); return from(safePreFlightPromise).pipe( @@ -130,10 +132,7 @@ export class AuditLogInterceptor implements NestInterceptor { responseBody, requestBody, ); - const actionDesc = extractActionDescription( - request.url, - method, - ); + const actionDesc = extractActionDescription(request.url, method); const downloadDesc = extractDownloadDescription( request.url, method, @@ -150,28 +149,62 @@ export class AuditLogInterceptor implements NestInterceptor { (request as { userRoles?: string[] }).userRoles, ); let descriptionOverride: string | null = - actionDesc ?? versionDesc ?? downloadDesc ?? policyActionDesc ?? findingDesc; + actionDesc ?? + versionDesc ?? + downloadDesc ?? + policyActionDesc ?? + findingDesc; - const isAutomationUpdate = policyActionDesc && /automations/.test(request.url) && method === 'PATCH'; - const isAttachmentAction = policyActionDesc && /attachments/.test(request.url); + const isAutomationUpdate = + policyActionDesc && + /automations/.test(request.url) && + method === 'PATCH'; + const isAttachmentAction = + policyActionDesc && /attachments/.test(request.url); - if (commentCtx || versionDesc || (policyActionDesc && !isAutomationUpdate && !isAttachmentAction)) { + if ( + commentCtx || + versionDesc || + (policyActionDesc && !isAutomationUpdate && !isAttachmentAction) + ) { // Comments and version operations don't produce meaningful diffs // But preserve the comment/reason/changelog if provided in the request body const note = requestBody?.comment || requestBody?.changelog; - const noteLabel = requestBody?.changelog ? 'changelog' : 'reason'; - changes = note && typeof note === 'string' - ? { [noteLabel]: { previous: null, current: note } } - : null; + const noteLabel = requestBody?.changelog + ? 'changelog' + : 'reason'; + changes = + note && typeof note === 'string' + ? { [noteLabel]: { previous: null, current: note } } + : null; } else if (isAttachmentAction) { // For attachments, show file details in the expandable section // Upload: file info in request body. Delete: file info in response body. const attachmentChanges: ChangesRecord = {}; - const fileName = requestBody?.fileName || (responseBody && typeof responseBody === 'object' ? (responseBody as Record).fileName : null); - const fileType = requestBody?.fileType || (responseBody && typeof responseBody === 'object' ? (responseBody as Record).fileType : null); - if (fileName) attachmentChanges.file = { previous: null, current: fileName }; - if (fileType) attachmentChanges.type = { previous: null, current: fileType }; - changes = Object.keys(attachmentChanges).length > 0 ? attachmentChanges : null; + const fileName = + requestBody?.fileName || + (responseBody && typeof responseBody === 'object' + ? (responseBody as Record).fileName + : null); + const fileType = + requestBody?.fileType || + (responseBody && typeof responseBody === 'object' + ? (responseBody as Record).fileType + : null); + if (fileName) + attachmentChanges.file = { + previous: null, + current: fileName, + }; + if (fileType) + attachmentChanges.type = { + previous: null, + current: fileType, + }; + changes = + Object.keys(attachmentChanges).length > 0 + ? attachmentChanges + : null; } else if (relationMappingResult) { changes = relationMappingResult.changes; descriptionOverride ??= relationMappingResult.description; @@ -264,7 +297,8 @@ export class AuditLogInterceptor implements NestInterceptor { const entityType = commentContext?.entityType ?? RESOURCE_TO_ENTITY_TYPE[resource] ?? null; const entityId = - commentContext?.entityId ?? extractEntityId(request, method, responseBody); + commentContext?.entityId ?? + extractEntityId(request, method, responseBody); const description = commentContext?.description ?? descriptionOverride ?? diff --git a/apps/api/src/audit/audit-log.resolvers.ts b/apps/api/src/audit/audit-log.resolvers.ts index 7babbd8bd1..906d3b56c2 100644 --- a/apps/api/src/audit/audit-log.resolvers.ts +++ b/apps/api/src/audit/audit-log.resolvers.ts @@ -46,17 +46,25 @@ const RESOURCE_CONTROL_MODELS: Record = { risks: 'risk', }; -async function fetchControlIds(resource: string, parentId: string): Promise { +async function fetchControlIds( + resource: string, + parentId: string, +): Promise { const modelName = RESOURCE_CONTROL_MODELS[resource]; if (!modelName) return []; try { - const model = (db as unknown as Record)[modelName]; + const model = (db as unknown as Record)[ + modelName + ]; if (!model?.findUnique) return []; const record = await model.findUnique({ where: { id: parentId }, select: { controls: { select: { id: true } } }, }); - return (record as { controls?: { id: string }[] })?.controls?.map((c) => c.id) ?? []; + return ( + (record as { controls?: { id: string }[] })?.controls?.map((c) => c.id) ?? + [] + ); } catch { return []; } @@ -104,9 +112,7 @@ export async function buildRelationMappingChanges( } // DELETE /v1//:id/controls/:controlId — unmapping - const unmapMatch = path.match( - /\/v1\/(\w+)\/([^/]+)\/controls\/([^/]+)\/?$/, - ); + const unmapMatch = path.match(/\/v1\/(\w+)\/([^/]+)\/controls\/([^/]+)\/?$/); if (unmapMatch && method === 'DELETE') { const resource = unmapMatch[1]; const parentId = unmapMatch[2]; @@ -145,7 +151,9 @@ export async function fetchCurrentValues( const modelName = RESOURCE_TO_PRISMA_MODEL[resource]; if (!modelName) return null; - const model = (db as unknown as Record)[modelName]; + const model = (db as unknown as Record)[ + modelName + ]; if (!model?.findUnique) return null; const select: Record = {}; diff --git a/apps/api/src/audit/audit-log.utils.ts b/apps/api/src/audit/audit-log.utils.ts index 9aa24a4535..f1fe6dbc5e 100644 --- a/apps/api/src/audit/audit-log.utils.ts +++ b/apps/api/src/audit/audit-log.utils.ts @@ -66,7 +66,7 @@ export function extractActionDescription( ): string | null { if (method !== 'POST') return null; - const pathWithoutQuery = path.split('?')[0]!; + const pathWithoutQuery = path.split('?')[0]; if (/\/vendors\/[^/]+\/trigger-assessment\/?$/.test(pathWithoutQuery)) return 'Triggered vendor risk assessment'; @@ -195,7 +195,9 @@ export function extractPolicyActionDescription( ): string | null { // POST /v1/policies/:id/regenerate or /v1/tasks/:id/regenerate if (/\/regenerate\/?$/.test(path) && method === 'POST') { - return path.includes('/tasks/') ? 'Regenerated evidence' : 'Regenerated policy'; + return path.includes('/tasks/') + ? 'Regenerated evidence' + : 'Regenerated policy'; } // POST /v1/tasks/:id/approve @@ -227,11 +229,16 @@ export function extractPolicyActionDescription( } // Custom automation CRUD — /v1/tasks/:taskId/automations[/:automationId] - if (/\/tasks\/[^/]+\/automations(\/[^/]+)?\/?$/.test(path) && !/(runs|versions)/.test(path)) { + if ( + /\/tasks\/[^/]+\/automations(\/[^/]+)?\/?$/.test(path) && + !/(runs|versions)/.test(path) + ) { if (method === 'POST') return 'Created custom automation'; if (method === 'PATCH') { if (requestBody && 'isEnabled' in requestBody) { - return requestBody.isEnabled ? 'Enabled custom automation' : 'Disabled custom automation'; + return requestBody.isEnabled + ? 'Enabled custom automation' + : 'Disabled custom automation'; } if (requestBody && 'evaluationCriteria' in requestBody) { return 'Updated automation evaluation criteria'; @@ -261,7 +268,12 @@ export function extractPolicyActionDescription( } // PATCH /v1/policies/:id with isArchived field - if (method === 'PATCH' && /\/policies\/[^/]+\/?$/.test(pathWithoutQuery) && requestBody && 'isArchived' in requestBody) { + if ( + method === 'PATCH' && + /\/policies\/[^/]+\/?$/.test(pathWithoutQuery) && + requestBody && + 'isArchived' in requestBody + ) { return requestBody.isArchived ? 'Archived policy' : 'Restored policy'; } diff --git a/apps/api/src/auth/admin-rate-limit.middleware.spec.ts b/apps/api/src/auth/admin-rate-limit.middleware.spec.ts index af92fe73a9..11097e1faa 100644 --- a/apps/api/src/auth/admin-rate-limit.middleware.spec.ts +++ b/apps/api/src/auth/admin-rate-limit.middleware.spec.ts @@ -79,11 +79,7 @@ describe('adminAuthRateLimiter', () => { const next = jest.fn(); const res = buildRes(); - await adminAuthRateLimiter( - buildReq('/api/auth/admin/set-role'), - res, - next, - ); + await adminAuthRateLimiter(buildReq('/api/auth/admin/set-role'), res, next); expect(next).not.toHaveBeenCalled(); expect(res.statusCode).toBe(429); expect(res.body).toEqual({ diff --git a/apps/api/src/auth/api-key.service.spec.ts b/apps/api/src/auth/api-key.service.spec.ts index 518ab3aa65..ee77446c58 100644 --- a/apps/api/src/auth/api-key.service.spec.ts +++ b/apps/api/src/auth/api-key.service.spec.ts @@ -56,9 +56,19 @@ describe('ApiKeyService', () => { it('should include expected public resources', () => { const expected = [ - 'risk', 'vendor', 'task', 'control', 'policy', - 'evidence', 'framework', 'audit', 'finding', - 'questionnaire', 'integration', 'apiKey', 'pentest', + 'risk', + 'vendor', + 'task', + 'control', + 'policy', + 'evidence', + 'framework', + 'audit', + 'finding', + 'questionnaire', + 'integration', + 'apiKey', + 'pentest', ]; for (const resource of expected) { const matching = scopes.filter((s) => s.startsWith(`${resource}:`)); diff --git a/apps/api/src/auth/api-key.service.ts b/apps/api/src/auth/api-key.service.ts index b0fbac86de..8596207b8c 100644 --- a/apps/api/src/auth/api-key.service.ts +++ b/apps/api/src/auth/api-key.service.ts @@ -65,9 +65,7 @@ export class ApiKeyService { const availableScopes = this.getAvailableScopes(); const invalid = scopes.filter((s) => !availableScopes.includes(s)); if (invalid.length > 0) { - throw new BadRequestException( - `Invalid scopes: ${invalid.join(', ')}`, - ); + throw new BadRequestException(`Invalid scopes: ${invalid.join(', ')}`); } const apiKey = this.generateApiKey(); @@ -85,9 +83,7 @@ export class ApiKeyService { expirationDate = new Date(now.setDate(now.getDate() + 90)); break; case '1year': - expirationDate = new Date( - now.setFullYear(now.getFullYear() + 1), - ); + expirationDate = new Date(now.setFullYear(now.getFullYear() + 1)); break; default: throw new BadRequestException( @@ -179,15 +175,14 @@ export class ApiKeyService { // Use key prefix for indexed lookup when available (new keys), // fall back to full scan for legacy keys without prefix - const keyPrefix = apiKey.startsWith('comp_') ? this.extractPrefix(apiKey) : null; + const keyPrefix = apiKey.startsWith('comp_') + ? this.extractPrefix(apiKey) + : null; const apiKeyRecords = await db.apiKey.findMany({ where: { isActive: true, - OR: [ - { expiresAt: null }, - { expiresAt: { gt: new Date() } }, - ], + OR: [{ expiresAt: null }, { expiresAt: { gt: new Date() } }], ...(keyPrefix ? { keyPrefix } : {}), }, select: { @@ -215,10 +210,7 @@ export class ApiKeyService { where: { isActive: true, keyPrefix: null, - OR: [ - { expiresAt: null }, - { expiresAt: { gt: new Date() } }, - ], + OR: [{ expiresAt: null }, { expiresAt: { gt: new Date() } }], }, select: { id: true, @@ -279,10 +271,7 @@ export class ApiKeyService { * Resources from better-auth that are not used by any API endpoint's @RequirePermission. * These are handled internally by better-auth for session-based auth only. */ - private static readonly INTERNAL_ONLY_RESOURCES = [ - 'invitation', - 'team', - ]; + private static readonly INTERNAL_ONLY_RESOURCES = ['invitation', 'team']; /** * Returns all valid `resource:action` scope pairs derived from the permission statement. diff --git a/apps/api/src/auth/auth-context.decorator.ts b/apps/api/src/auth/auth-context.decorator.ts index a041e0d992..1d5590a8a0 100644 --- a/apps/api/src/auth/auth-context.decorator.ts +++ b/apps/api/src/auth/auth-context.decorator.ts @@ -76,6 +76,10 @@ export const UserId = createParamDecorator( } if (!userId) { + // For service tokens: allow if no user context needed (return a system identifier) + if (authType === 'service') { + return 'system'; + } throw new Error( 'User ID not found. Ensure HybridAuthGuard is applied and using session auth.', ); diff --git a/apps/api/src/auth/auth-server-origins.spec.ts b/apps/api/src/auth/auth-server-origins.spec.ts index 97ce7cb44d..7bb79a7709 100644 --- a/apps/api/src/auth/auth-server-origins.spec.ts +++ b/apps/api/src/auth/auth-server-origins.spec.ts @@ -6,7 +6,9 @@ * in isolation rather than importing the module directly. */ -function getTrustedOriginsLogic(authTrustedOrigins: string | undefined): string[] { +function getTrustedOriginsLogic( + authTrustedOrigins: string | undefined, +): string[] { if (authTrustedOrigins) { return authTrustedOrigins.split(',').map((o) => o.trim()); } @@ -68,7 +70,9 @@ describe('getTrustedOrigins', () => { }); it('should trim whitespace from comma-separated origins', () => { - const origins = getTrustedOriginsLogic(' https://a.com , https://b.com '); + const origins = getTrustedOriginsLogic( + ' https://a.com , https://b.com ', + ); expect(origins).toEqual(['https://a.com', 'https://b.com']); }); }); @@ -77,26 +81,45 @@ describe('isStaticTrustedOrigin', () => { const defaults = getTrustedOriginsLogic(undefined); it('should allow static trusted origins', () => { - expect(isStaticTrustedOriginLogic('https://app.trycomp.ai', defaults)).toBe(true); + expect(isStaticTrustedOriginLogic('https://app.trycomp.ai', defaults)).toBe( + true, + ); }); it('should allow trust portal subdomains of trycomp.ai', () => { - expect(isStaticTrustedOriginLogic('https://security.trycomp.ai', defaults)).toBe(true); - expect(isStaticTrustedOriginLogic('https://acme.trycomp.ai', defaults)).toBe(true); + expect( + isStaticTrustedOriginLogic('https://security.trycomp.ai', defaults), + ).toBe(true); + expect( + isStaticTrustedOriginLogic('https://acme.trycomp.ai', defaults), + ).toBe(true); }); it('should allow trust portal subdomains of staging.trycomp.ai', () => { - expect(isStaticTrustedOriginLogic('https://security.staging.trycomp.ai', defaults)).toBe(true); + expect( + isStaticTrustedOriginLogic( + 'https://security.staging.trycomp.ai', + defaults, + ), + ).toBe(true); }); it('should allow trust.inc and its subdomains', () => { - expect(isStaticTrustedOriginLogic('https://trust.inc', defaults)).toBe(true); - expect(isStaticTrustedOriginLogic('https://acme.trust.inc', defaults)).toBe(true); + expect(isStaticTrustedOriginLogic('https://trust.inc', defaults)).toBe( + true, + ); + expect(isStaticTrustedOriginLogic('https://acme.trust.inc', defaults)).toBe( + true, + ); }); it('should reject unknown origins', () => { - expect(isStaticTrustedOriginLogic('https://evil.com', defaults)).toBe(false); - expect(isStaticTrustedOriginLogic('https://trycomp.ai.evil.com', defaults)).toBe(false); + expect(isStaticTrustedOriginLogic('https://evil.com', defaults)).toBe( + false, + ); + expect( + isStaticTrustedOriginLogic('https://trycomp.ai.evil.com', defaults), + ).toBe(false); }); it('should handle invalid origins gracefully', () => { @@ -112,7 +135,9 @@ describe('isStaticTrustedOrigin', () => { ) as string; expect(mainTs).not.toContain('origin: true'); expect(mainTs).toContain('isTrustedOrigin'); - expect(mainTs).toContain("import { isTrustedOrigin } from './auth/auth.server'"); + expect(mainTs).toContain( + "import { isTrustedOrigin } from './auth/auth.server'", + ); }); }); diff --git a/apps/api/src/auth/auth.controller.ts b/apps/api/src/auth/auth.controller.ts index cfc074be26..eab2e593bd 100644 --- a/apps/api/src/auth/auth.controller.ts +++ b/apps/api/src/auth/auth.controller.ts @@ -25,7 +25,9 @@ import type { AuthContext as AuthContextType } from './types'; export class AuthController { @Get('me') @SkipOrgCheck() - @ApiOperation({ summary: 'Get current user info, organizations, and pending invitations' }) + @ApiOperation({ + summary: 'Get current user info, organizations, and pending invitations', + }) async getMe(@AuthContext() authContext: AuthContextType) { const userId = authContext.userId; if (!userId) { diff --git a/apps/api/src/auth/auth.module.ts b/apps/api/src/auth/auth.module.ts index 64ef0ce2d5..d739406a1a 100644 --- a/apps/api/src/auth/auth.module.ts +++ b/apps/api/src/auth/auth.module.ts @@ -24,12 +24,7 @@ import { PermissionGuard } from './permission.guard'; }), ], controllers: [AuthController], - providers: [ - ApiKeyService, - ApiKeyGuard, - HybridAuthGuard, - PermissionGuard, - ], + providers: [ApiKeyService, ApiKeyGuard, HybridAuthGuard, PermissionGuard], exports: [ ApiKeyService, ApiKeyGuard, diff --git a/apps/api/src/auth/auth.server.ts b/apps/api/src/auth/auth.server.ts index 03b5b3b7fd..3f7a8f4611 100644 --- a/apps/api/src/auth/auth.server.ts +++ b/apps/api/src/auth/auth.server.ts @@ -427,7 +427,7 @@ export const auth = betterAuth({ }), }); }, - ac: ac as AccessControl, + ac: ac, roles: allRoles, // Enable dynamic access control for custom roles // This allows organizations to create custom roles at runtime diff --git a/apps/api/src/auth/hybrid-auth.guard.ts b/apps/api/src/auth/hybrid-auth.guard.ts index adbbb23cde..e8a4adc15b 100644 --- a/apps/api/src/auth/hybrid-auth.guard.ts +++ b/apps/api/src/auth/hybrid-auth.guard.ts @@ -112,6 +112,23 @@ export class HybridAuthGuard implements CanActivate { request.isPlatformAdmin = false; request.userRoles = null; + // Service tokens can pass x-user-id to act on behalf of a user + // Validate that the user exists and belongs to the organization + const actingUserId = request.headers['x-user-id'] as string; + if (actingUserId) { + const member = await db.member.findFirst({ + where: { userId: actingUserId, organizationId }, + select: { userId: true }, + }); + if (member) { + request.userId = actingUserId; + } else { + this.logger.warn( + `Service token x-user-id "${actingUserId}" not found in org ${organizationId}`, + ); + } + } + this.logger.log( `Service "${service.definition.name}" authenticated for org ${organizationId}`, ); diff --git a/apps/api/src/auth/origin-check.middleware.spec.ts b/apps/api/src/auth/origin-check.middleware.spec.ts index 1690b74454..f679d97f6e 100644 --- a/apps/api/src/auth/origin-check.middleware.spec.ts +++ b/apps/api/src/auth/origin-check.middleware.spec.ts @@ -39,8 +39,12 @@ function createMockReq( /** Flush the microtask queue so async middleware completes. */ const flushPromises = () => new Promise((resolve) => setImmediate(resolve)); -function createMockRes(): Record & { statusCode?: number; body?: unknown } { - const res: Record & { statusCode?: number; body?: unknown } = {}; +function createMockRes(): Record & { + statusCode?: number; + body?: unknown; +} { + const res: Record & { statusCode?: number; body?: unknown } = + {}; res.status = jest.fn().mockImplementation((code: number) => { res.statusCode = code; return res; @@ -84,7 +88,11 @@ describe('originCheckMiddleware', () => { }); it('should allow POST from trusted origin', async () => { - const req = createMockReq('POST', '/v1/organization/api-keys', 'http://localhost:3000'); + const req = createMockReq( + 'POST', + '/v1/organization/api-keys', + 'http://localhost:3000', + ); const res = createMockRes(); const next = jest.fn(); @@ -95,7 +103,11 @@ describe('originCheckMiddleware', () => { }); it('should block POST from untrusted origin', async () => { - const req = createMockReq('POST', '/v1/organization/transfer-ownership', 'http://evil.com'); + const req = createMockReq( + 'POST', + '/v1/organization/transfer-ownership', + 'http://evil.com', + ); const res = createMockRes(); const next = jest.fn(); @@ -119,7 +131,11 @@ describe('originCheckMiddleware', () => { }); it('should block PATCH from untrusted origin', async () => { - const req = createMockReq('PATCH', '/v1/members/123/role', 'http://evil.com'); + const req = createMockReq( + 'PATCH', + '/v1/members/123/role', + 'http://evil.com', + ); const res = createMockRes(); const next = jest.fn(); @@ -161,7 +177,11 @@ describe('originCheckMiddleware', () => { }); it('should allow production origins', async () => { - const req = createMockReq('POST', '/v1/organization/api-keys', 'https://app.trycomp.ai'); + const req = createMockReq( + 'POST', + '/v1/organization/api-keys', + 'https://app.trycomp.ai', + ); const res = createMockRes(); const next = jest.fn(); diff --git a/apps/api/src/auth/origin-check.middleware.ts b/apps/api/src/auth/origin-check.middleware.ts index 27cb202264..44e4c868e7 100644 --- a/apps/api/src/auth/origin-check.middleware.ts +++ b/apps/api/src/auth/origin-check.middleware.ts @@ -8,10 +8,10 @@ const SAFE_METHODS = new Set(['GET', 'HEAD', 'OPTIONS']); * These are called by external services that don't send browser Origin headers. */ const EXEMPT_PATH_PREFIXES = [ - '/api/auth', // better-auth handles its own CSRF - '/v1/health', // health check - '/api/docs', // swagger - '/v1/trust-access', // public trust portal endpoints (no auth, no cookies) + '/api/auth', // better-auth handles its own CSRF + '/v1/health', // health check + '/api/docs', // swagger + '/v1/trust-access', // public trust portal endpoints (no auth, no cookies) ]; /** @@ -45,7 +45,7 @@ export function originCheckMiddleware( return next(); } - const origin = req.headers['origin'] as string | undefined; + const origin = req.headers['origin']; // No Origin header = not a browser request (API key, service token, curl, etc.) // These are authenticated via HybridAuthGuard, not cookies, so no CSRF risk. diff --git a/apps/api/src/auth/permission.guard.spec.ts b/apps/api/src/auth/permission.guard.spec.ts index d96c010f9e..3bf4c6ed6a 100644 --- a/apps/api/src/auth/permission.guard.spec.ts +++ b/apps/api/src/auth/permission.guard.spec.ts @@ -76,9 +76,9 @@ describe('PermissionGuard', () => { jest.useFakeTimers(); jest.setSystemTime(new Date('2026-04-19T23:59:59Z')); - jest.spyOn(reflector, 'getAllAndOverride').mockReturnValue([ - { resource: 'control', actions: ['delete'] }, - ]); + jest + .spyOn(reflector, 'getAllAndOverride') + .mockReturnValue([{ resource: 'control', actions: ['delete'] }]); const context = createMockExecutionContext({ isApiKey: true, @@ -97,9 +97,9 @@ describe('PermissionGuard', () => { jest.useFakeTimers(); jest.setSystemTime(new Date('2026-04-20T00:00:00Z')); - jest.spyOn(reflector, 'getAllAndOverride').mockReturnValue([ - { resource: 'control', actions: ['read'] }, - ]); + jest + .spyOn(reflector, 'getAllAndOverride') + .mockReturnValue([{ resource: 'control', actions: ['read'] }]); const context = createMockExecutionContext({ isApiKey: true, @@ -119,9 +119,9 @@ describe('PermissionGuard', () => { jest.useFakeTimers(); jest.setSystemTime(new Date('2026-05-01T00:00:00Z')); - jest.spyOn(reflector, 'getAllAndOverride').mockReturnValue([ - { resource: 'control', actions: ['read'] }, - ]); + jest + .spyOn(reflector, 'getAllAndOverride') + .mockReturnValue([{ resource: 'control', actions: ['read'] }]); const context = createMockExecutionContext({ isApiKey: true, @@ -138,9 +138,9 @@ describe('PermissionGuard', () => { }); it('should allow access for API keys with matching scopes', async () => { - jest.spyOn(reflector, 'getAllAndOverride').mockReturnValue([ - { resource: 'control', actions: ['read'] }, - ]); + jest + .spyOn(reflector, 'getAllAndOverride') + .mockReturnValue([{ resource: 'control', actions: ['read'] }]); const context = createMockExecutionContext({ isApiKey: true, @@ -154,9 +154,9 @@ describe('PermissionGuard', () => { }); it('should deny access for API keys with non-matching scopes', async () => { - jest.spyOn(reflector, 'getAllAndOverride').mockReturnValue([ - { resource: 'control', actions: ['read'] }, - ]); + jest + .spyOn(reflector, 'getAllAndOverride') + .mockReturnValue([{ resource: 'control', actions: ['read'] }]); const context = createMockExecutionContext({ isApiKey: true, @@ -171,9 +171,9 @@ describe('PermissionGuard', () => { }); it('should deny access when no authorization or cookie header present', async () => { - jest.spyOn(reflector, 'getAllAndOverride').mockReturnValue([ - { resource: 'control', actions: ['delete'] }, - ]); + jest + .spyOn(reflector, 'getAllAndOverride') + .mockReturnValue([{ resource: 'control', actions: ['delete'] }]); const context = createMockExecutionContext({ headers: {}, @@ -185,9 +185,9 @@ describe('PermissionGuard', () => { }); it('should allow access when SDK returns success', async () => { - jest.spyOn(reflector, 'getAllAndOverride').mockReturnValue([ - { resource: 'control', actions: ['delete'] }, - ]); + jest + .spyOn(reflector, 'getAllAndOverride') + .mockReturnValue([{ resource: 'control', actions: ['delete'] }]); mockHasPermission.mockResolvedValue({ success: true, error: null }); @@ -206,9 +206,9 @@ describe('PermissionGuard', () => { }); it('should deny access when SDK returns failure', async () => { - jest.spyOn(reflector, 'getAllAndOverride').mockReturnValue([ - { resource: 'control', actions: ['delete'] }, - ]); + jest + .spyOn(reflector, 'getAllAndOverride') + .mockReturnValue([{ resource: 'control', actions: ['delete'] }]); mockHasPermission.mockResolvedValue({ success: false, @@ -225,9 +225,9 @@ describe('PermissionGuard', () => { }); it('should deny access when SDK throws', async () => { - jest.spyOn(reflector, 'getAllAndOverride').mockReturnValue([ - { resource: 'control', actions: ['delete'] }, - ]); + jest + .spyOn(reflector, 'getAllAndOverride') + .mockReturnValue([{ resource: 'control', actions: ['delete'] }]); mockHasPermission.mockRejectedValue(new Error('SDK error')); diff --git a/apps/api/src/auth/platform-admin.guard.spec.ts b/apps/api/src/auth/platform-admin.guard.spec.ts index a5ff714c02..3cbd3ead09 100644 --- a/apps/api/src/auth/platform-admin.guard.spec.ts +++ b/apps/api/src/auth/platform-admin.guard.spec.ts @@ -24,8 +24,15 @@ jest.mock('@db', () => ({ }, })); -function buildContext(headers: Record = {}): ExecutionContext { - const request = { headers, userId: undefined, userEmail: undefined, isPlatformAdmin: undefined }; +function buildContext( + headers: Record = {}, +): ExecutionContext { + const request = { + headers, + userId: undefined, + userEmail: undefined, + isPlatformAdmin: undefined, + }; return { switchToHttp: () => ({ getRequest: () => request, diff --git a/apps/api/src/browserbase/browserbase.service.ts b/apps/api/src/browserbase/browserbase.service.ts index c4c63d8dd4..2cfb919f80 100644 --- a/apps/api/src/browserbase/browserbase.service.ts +++ b/apps/api/src/browserbase/browserbase.service.ts @@ -10,7 +10,7 @@ import { PutObjectCommand, S3Client, } from '@aws-sdk/client-s3'; -import { getSignedUrl } from '@aws-sdk/s3-request-presigner'; +import { getSignedUrl } from '@/app/s3'; const BROWSER_WIDTH = 1440; const BROWSER_HEIGHT = 900; diff --git a/apps/api/src/browserbase/dto/browserbase.dto.ts b/apps/api/src/browserbase/dto/browserbase.dto.ts index 842d96d0f2..f0bfffbf49 100644 --- a/apps/api/src/browserbase/dto/browserbase.dto.ts +++ b/apps/api/src/browserbase/dto/browserbase.dto.ts @@ -1,5 +1,11 @@ import { ApiProperty, ApiPropertyOptional } from '@nestjs/swagger'; -import { IsNotEmpty, IsOptional, IsString, IsBoolean, IsUrl } from 'class-validator'; +import { + IsNotEmpty, + IsOptional, + IsString, + IsBoolean, + IsUrl, +} from 'class-validator'; import { IsSafeUrl } from '../validators/url-safety.validator'; // ===== Session DTOs ===== diff --git a/apps/api/src/browserbase/validators/url-safety.validator.ts b/apps/api/src/browserbase/validators/url-safety.validator.ts index 9d191e6bfd..1ae2f84b89 100644 --- a/apps/api/src/browserbase/validators/url-safety.validator.ts +++ b/apps/api/src/browserbase/validators/url-safety.validator.ts @@ -37,7 +37,9 @@ function isPrivateIpv6(hostname: string): boolean { } // IPv4-mapped IPv6 in hex form: ::ffff:a9fe:a9fe (169.254.169.254) - const hexMappedMatch = stripped.match(/^::ffff:([0-9a-f]{1,4}):([0-9a-f]{1,4})$/); + const hexMappedMatch = stripped.match( + /^::ffff:([0-9a-f]{1,4}):([0-9a-f]{1,4})$/, + ); if (hexMappedMatch) { const hi = parseInt(hexMappedMatch[1], 16); const lo = parseInt(hexMappedMatch[2], 16); diff --git a/apps/api/src/cloud-security/ai-remediation.prompt.ts b/apps/api/src/cloud-security/ai-remediation.prompt.ts new file mode 100644 index 0000000000..d175f39317 --- /dev/null +++ b/apps/api/src/cloud-security/ai-remediation.prompt.ts @@ -0,0 +1,299 @@ +import { z } from 'zod'; + +// ─── Zod Schemas ──────────────────────────────────────────────────────────── + +export const awsCommandStepSchema = z.object({ + service: z + .string() + .describe( + 'AWS SDK client package suffix, e.g. "s3" for @aws-sdk/client-s3', + ), + command: z + .string() + .describe( + 'Exact AWS SDK v3 command class name with Command suffix, e.g. "PutPublicAccessBlockCommand"', + ), + params: z + .record(z.string(), z.unknown()) + .describe('Exact input parameters the command expects'), + purpose: z + .string() + .describe('Human-readable description of what this step does'), +}); + +export type AwsCommandStep = z.infer; + +export const fixPlanSchema = z.object({ + canAutoFix: z + .boolean() + .describe('Whether this finding can be auto-fixed via AWS API calls'), + risk: z + .enum(['low', 'medium', 'high', 'critical']) + .describe('Risk level of applying this fix'), + description: z.string().describe('Human-readable description of the fix'), + currentState: z + .record(z.string(), z.unknown()) + .describe( + 'What the user currently has — the actual configuration that the scan found. Use real values from the evidence.', + ), + proposedState: z + .record(z.string(), z.unknown()) + .describe( + 'What the configuration will look like after the fix is applied.', + ), + requiredPermissions: z + .array(z.string()) + .describe('IAM actions needed, e.g. ["s3:PutPublicAccessBlock"]'), + readSteps: z + .array(awsCommandStepSchema) + .describe('Steps to read current state before fixing'), + fixSteps: z.array(awsCommandStepSchema).describe('Steps to apply the fix'), + rollbackSteps: z + .array(awsCommandStepSchema) + .describe('Steps to reverse the fix using previous state'), + rollbackSupported: z + .boolean() + .describe('Whether this fix can be rolled back'), + requiresAcknowledgment: z + .boolean() + .describe('Whether user must acknowledge before execution'), + acknowledgmentMessage: z + .string() + .optional() + .describe('Message shown when acknowledgment is required'), + guidedSteps: z + .array(z.string()) + .optional() + .describe('Manual steps when canAutoFix is false'), + reason: z + .string() + .optional() + .describe('Why auto-fix is not possible when canAutoFix is false'), +}); + +export type FixPlan = z.infer; + +export const permissionFixSchema = z.object({ + missingActions: z + .array(z.string()) + .describe('IAM actions that need to be added'), + policyStatement: z.object({ + Effect: z.literal('Allow'), + Action: z.array(z.string()), + Resource: z.string(), + }), +}); + +export type PermissionFix = z.infer; + +export const completePermissionsSchema = z.object({ + permissions: z + .array(z.string()) + .describe('Every single IAM action needed for the entire fix operation'), + reasoning: z + .string() + .describe('Brief explanation of why each permission group is needed'), +}); + +export type CompletePermissions = z.infer; + +// ─── Prompt Builders ──────────────────────────────────────────────────────── + +const SYSTEM_PROMPT = `You are an AWS security remediation expert. You analyze security findings and produce structured fix plans that will be executed by an automated system using AWS SDK v3. + +A human will ALWAYS review your plan before execution. Be precise and correct. + +## OUTPUT RULES + +1. For each step, provide: + - service: The AWS SDK client package suffix (e.g., "s3" for @aws-sdk/client-s3, "kms" for @aws-sdk/client-kms, "ec2" for @aws-sdk/client-ec2, "config-service" for @aws-sdk/client-config-service, "elastic-load-balancing-v2" for @aws-sdk/client-elastic-load-balancing-v2, "cognito-identity-provider" for @aws-sdk/client-cognito-identity-provider, "wafv2" for @aws-sdk/client-wafv2) + - command: The EXACT AWS SDK v3 command class name WITH "Command" suffix (e.g., "PutPublicAccessBlockCommand", "EnableKeyRotationCommand") + - params: The EXACT input parameters the command constructor expects + - purpose: Human-readable explanation + +2. For readSteps: provide commands that READ the current state (Get*, Describe*, List*) +3. For fixSteps: provide commands that CHANGE the state to fix the issue +4. For rollbackSteps: provide commands that RESTORE the previous state. Use "{{previousState}}" as a placeholder for values that will be filled from the read step results. + +## RESOURCE ID PARSING +- Extract actual resource names from ARNs: + - "arn:aws:s3:::my-bucket" → Bucket: "my-bucket" + - "arn:aws:kms:us-east-1:123:key/abc" → KeyId: "arn:aws:kms:us-east-1:123:key/abc" (use full ARN for KMS) + - "arn:aws:rds:us-east-1:123:db:mydb" → DBInstanceIdentifier: "mydb" + - "arn:aws:ec2:us-east-1:123:vpc/vpc-abc" → VpcId: "vpc-abc" +- Use the correct parameter names that the AWS SDK expects + +## SAFETY RULES (NEVER violate) +- NEVER delete data, buckets, tables, databases, or file systems +- NEVER modify IAM policies, roles, or users in ways that could lock out users +- NEVER change resource endpoints that active applications depend on +- NEVER terminate instances, clusters, or running services +- PREFER enabling features (encryption, logging, versioning) over disabling +- ALWAYS make changes reversible when possible +- For service-linked roles: create them as a setup step using IAM CreateServiceLinkedRoleCommand + +## IDEMPOTENCY (CRITICAL) +- All fix steps MUST be safe to run even if the resource already exists +- For Create operations: our executor automatically handles "already exists" errors — they are treated as success, not failure +- Use naturally idempotent APIs when possible: PutMetricFilter (overwrites), SNS CreateTopic (returns existing ARN), PutRetentionPolicy (overwrites) +- For IAM service delivery roles: use CreateRole — if role exists, the executor handles it +- For S3 buckets: use CreateBucket — if it exists, the executor handles it +- For log groups: use CreateLogGroup — if it exists, the executor handles it + +## IMPORTANT: IAM ROLES +- CompAI-Auditor: for scanning (read-only). Created during onboarding. +- CompAI-Remediator: for ALL our API calls. Created during onboarding. NEVER create a replacement. +- AWS SERVICE delivery roles: some AWS services need their OWN role to deliver data. Example: CloudTrail needs a role trusting cloudtrail.amazonaws.com to write to CloudWatch Logs. This is NOT the same as CompAI-Remediator — it's a role for the AWS service itself. +- You MAY create service delivery roles when required. Name them: CompAI-{Service}Delivery (e.g., CompAI-CloudTrailDelivery). +- Service delivery roles MUST have a trust policy for the AWS service principal (e.g., cloudtrail.amazonaws.com, config.amazonaws.com). +- Service-linked roles (GuardDuty, Config, Inspector, Macie): use CreateServiceLinkedRole — AWS manages them. + +## NAMING CONVENTIONS FOR NEW RESOURCES (FOLLOW EXACTLY) +- S3 bucket names MUST: be lowercase only, no underscores, 3-63 chars, globally unique + - Format: compai-{purpose}-{accountId}-{region} (e.g., compai-cloudtrail-013388577167-us-east-1) + - The account ID and region make it globally unique + - Get accountId from evidence.awsAccountId, get region from the finding context +- Log groups: /compai/{service} (e.g., /compai/cloudtrail) +- SNS topics: CompAI-{Purpose} (e.g., CompAI-CIS-Alerts) +- Service delivery IAM roles: CompAI-{Service}Delivery (e.g., CompAI-CloudTrailDelivery) +- Use the AWS account ID and region from evidence for unique resource names + +## GUIDED STEPS FORMAT (when canAutoFix=false) +- Each step should be SHORT and clear — one action per step +- Separate explanation from commands: put the explanation first, then the command on its own line +- Format commands with backtick markers: wrap CLI commands in triple backticks (three backtick characters before and after the command) +- Keep each step under 2-3 sentences of explanation + 1 command block +- Do NOT put multiple commands in one step — split them into separate steps +- Do NOT inline JSON policies in the step text — instead say "Apply the required bucket policy" and put the command separately + +## CRITICAL: FIX WHAT THE SCAN ACTUALLY CHECKS +- The finding tells you WHAT is wrong. Your fix must change the EXACT AWS configuration that the scan checks. +- If the finding says "encryption not enabled" — your fix must enable encryption on THAT specific resource, not create a new encrypted resource. +- If the finding says "logging not enabled" — your fix must enable logging on THAT existing resource. +- ALWAYS read the finding title, description, and evidence carefully to understand what EXACTLY needs to change. +- The fix must make the SAME check pass on the next scan. If you're not sure what the scan checks, use the finding evidence — it contains the exact data the scan found. +- The "Existing Remediation Guidance" field contains PRECISE instructions with exact AWS SDK command names. FOLLOW THOSE INSTRUCTIONS EXACTLY — they were written by the adapter that performs the scan and knows exactly what needs to change. + +## HANDLING [MANUAL] FINDINGS +- If the remediation guidance starts with "[MANUAL]", set canAutoFix to false immediately. +- These are findings that CANNOT be auto-fixed (e.g., encryption requiring resource recreation, MFA requiring physical devices). +- Provide the explanation from the remediation guidance as guidedSteps. + +## ERROR RESILIENCE +- If a resource or setting might not exist (e.g., SSM documents, Config recorders), use a read step first to check existence before attempting to update. +- For UpdateDocument: check document existence with GetDocument first. If it doesn't exist, use CreateDocument instead. +- For UpdateServiceSetting: check the setting exists with GetServiceSetting first. If it returns ServiceSettingNotFound, set canAutoFix to false and explain the issue. +- NEVER assume a resource exists just because the finding references it — the finding may have been created because the resource is MISSING. + +## WHEN TO SET canAutoFix=true (DEFAULT — auto-fix as much as possible) +- Enable/disable features on existing resources (encryption, logging, versioning, monitoring) +- Update configuration settings (password policy, retention, rotation) +- Enable services (GuardDuty, Macie, Inspector, Config) +- Block public access, disable public endpoints +- Create metric filters, alarms, SNS topics +- Create S3 buckets, log groups (our executor handles "already exists" gracefully) +- Multi-step operations where each step is a deterministic AWS API call +- Complex setups including those that need service delivery roles (e.g., CloudTrail + S3 bucket + CloudWatch Logs + service role) +- ALWAYS provide rollback steps so the customer can undo + +## WHEN TO SET canAutoFix=false +- Remediation guidance starts with "[MANUAL]" — always respect this +- Resource RECREATION required (EFS encryption, ElastiCache encryption, RDS encryption — must snapshot + recreate + migrate data) +- Physical device required (MFA hardware tokens, root MFA) +- User must choose between exclusive options (which auth type, which security group rules to keep) +- Active data migration needed between resources +- DNS/certificate changes (external registrar actions) +- Lambda runtime updates (may require code changes) +- Secret rotation setup (requires custom Lambda function) +- A required resource/setting does not exist and cannot be created with a simple API call + +## RISK ASSESSMENT +- low: Enabling features with no impact on existing functionality (encryption, logging, versioning) +- medium: Changes that modify behavior but are reversible (access restrictions) +- high: Changes that affect production traffic or access patterns +- critical: Irreversible changes or changes affecting authentication + +## REQUIRED PERMISSIONS (VERY IMPORTANT — GET THIS RIGHT FIRST TIME) +- List EVERY IAM action needed for the COMPLETE operation, not just the direct API calls +- Think through the FULL chain: if you CreateBucket, you also need PutBucketPolicy, GetBucketPolicy, PutBucketAcl +- Include iam:CreateRole and iam:PutRolePolicy when creating AWS service delivery roles +- Include iam:PassRole when attaching a role to an AWS service (CloudTrail, Config, etc.) +- NEVER include iam:AttachRolePolicy — use iam:PutRolePolicy (inline policies) instead +- If you CreateLogGroup, you also need PutRetentionPolicy, DescribeLogGroups +- If you CreateTrail, you also need StartLogging, GetTrailStatus, PutEventSelectors +- Include iam:CreateServiceLinkedRole when the service needs a service-linked role +- Include iam:PassRole when attaching a role to a service (CloudTrail, Config, etc.) +- Include BOTH the read permissions (Get*, Describe*, List*) AND write permissions (Put*, Create*, Update*) +- ALWAYS overestimate — it's better to request one extra permission than to fail mid-execution +- Common permissions people forget: iam:PassRole, s3:PutBucketPolicy, logs:CreateLogStream, logs:PutLogEvents + +## CRITICAL: NO PLACEHOLDERS EVER +- NEVER use placeholder values like "{{variable}}", "", or template syntax +- ALWAYS use concrete values in fix step params +- If a value depends on the account (like a log group name), put the discovery in readSteps and use a reasonable default or convention in fixSteps: + - CloudTrail log group: use "CloudTrail/DefaultLogGroup" (the system will resolve the real one from readSteps) + - SNS topic: use "CompAI-CIS-Alerts" (will be created if it doesn't exist) + - KMS keys: use "alias/aws/service-name" for AWS-managed keys +- The finding evidence contains REAL data from the AWS account scan — use those values +- If a value is truly unknown and not in evidence, use a sensible default that will work + +## CURRENT STATE AND PROPOSED STATE +- currentState: ONLY what the scan evidence shows. Do NOT guess or add fields that aren't in the evidence. + - If the evidence says something doesn't exist, show it as false or null + - NEVER use "unknown" — either you know from evidence or don't include the field + - Example: { "versioning": "Disabled" } + - Example: { "metricFilterExists": false, "alarmExists": false } +- proposedState: ONLY what will change after the fix. Same keys as currentState. + - Example: { "versioning": "Enabled" } + - Example: { "metricFilterExists": true, "filterName": "cis-4.8-s3-bucket-policy-changes", "alarmExists": true, "alarmName": "cis-4.8-s3-bucket-policy-changes" } +- Both must use the SAME keys so the user can compare side by side +- Do NOT include fields you don't know the value of`; + +export function buildFixPlanPrompt(finding: { + title: string; + description: string | null; + severity: string | null; + resourceType: string; + resourceId: string; + remediation: string | null; + findingKey: string; + evidence: Record; +}): string { + return `Analyze this AWS security finding and generate a fix plan. + +IMPORTANT: Your fix must change the EXACT AWS setting/resource that caused this finding. The scan will re-check the same thing after the fix — if you fix something different, the finding will persist. + +FINDING: +- Title: ${finding.title} +- Description: ${finding.description ?? 'N/A'} +- Severity: ${finding.severity ?? 'medium'} +- Resource Type: ${finding.resourceType} +- Resource ID: ${finding.resourceId} +- Finding Key: ${finding.findingKey} +- Existing Remediation Guidance: ${finding.remediation ?? 'None'} +- Evidence: ${JSON.stringify(finding.evidence, null, 2)} + +Generate the fix plan following all the rules in your instructions.`; +} + +export function buildPermissionFixPrompt(params: { + errorMessage: string; + failedStep: AwsCommandStep; + roleName: string; +}): string { + return `An AWS remediation step failed due to missing IAM permissions. + +ERROR: ${params.errorMessage} + +FAILED STEP: +- Service: ${params.failedStep.service} +- Command: ${params.failedStep.command} +- Params: ${JSON.stringify(params.failedStep.params)} + +IAM ROLE NAME: ${params.roleName} + +Analyze the error and determine EXACTLY which IAM actions are missing. +Include any related actions needed (e.g., if CreateDetector fails with service-linked role error, include iam:CreateServiceLinkedRole).`; +} + +export { SYSTEM_PROMPT }; diff --git a/apps/api/src/cloud-security/ai-remediation.service.ts b/apps/api/src/cloud-security/ai-remediation.service.ts new file mode 100644 index 0000000000..b3549e06ff --- /dev/null +++ b/apps/api/src/cloud-security/ai-remediation.service.ts @@ -0,0 +1,446 @@ +import { Injectable, Logger } from '@nestjs/common'; +import { generateObject } from 'ai'; +import { anthropic } from '@ai-sdk/anthropic'; +import { + type FixPlan, + type PermissionFix, + type AwsCommandStep, + fixPlanSchema, + permissionFixSchema, + completePermissionsSchema, + SYSTEM_PROMPT, + buildFixPlanPrompt, + buildPermissionFixPrompt, +} from './ai-remediation.prompt'; +import { + type GcpFixPlan, + gcpFixPlanSchema, + GCP_SYSTEM_PROMPT, + buildGcpFixPlanPrompt, +} from './gcp-ai-remediation.prompt'; +import { + type AzureFixPlan, + azureFixPlanSchema, + AZURE_SYSTEM_PROMPT, + buildAzureFixPlanPrompt, +} from './azure-ai-remediation.prompt'; + +const MODEL = anthropic('claude-opus-4-6'); +const REMEDIATION_ROLE_NAME = 'CompAI-Remediator'; + +interface FindingContext { + title: string; + description: string | null; + severity: string | null; + resourceType: string; + resourceId: string; + remediation: string | null; + findingKey: string; + evidence: Record; +} + +@Injectable() +export class AiRemediationService { + private readonly logger = new Logger(AiRemediationService.name); + + /** Phase 1: Generate initial plan (read steps + preliminary fix plan). */ + async generateFixPlan(finding: FindingContext): Promise { + try { + const { object } = await generateObject({ + model: MODEL, + schema: fixPlanSchema, + system: SYSTEM_PROMPT, + prompt: buildFixPlanPrompt(finding), + temperature: 0, + }); + + this.logger.log( + `AI plan for ${finding.findingKey}: canAutoFix=${object.canAutoFix}, risk=${object.risk}`, + ); + return object; + } catch (err) { + this.logger.error( + `AI plan failed: ${err instanceof Error ? err.message : String(err)}`, + ); + return this.fallbackPlan(finding); + } + } + + /** + * Phase 2: Refine fix steps using REAL data from AWS. + * Called after read steps executed successfully. + * AI gets the actual AWS state and generates exact fix commands. + */ + async refineFixPlan(params: { + finding: FindingContext; + originalPlan: FixPlan; + realAwsState: Record; + }): Promise { + try { + const { object } = await generateObject({ + model: MODEL, + schema: fixPlanSchema, + system: SYSTEM_PROMPT, + prompt: `You previously analyzed this finding and generated read steps. Those read steps have been executed against the REAL AWS account. Here is the REAL data: + +REAL AWS STATE (from executing read steps): +${JSON.stringify(params.realAwsState, null, 2)} + +ORIGINAL FINDING: +${buildFixPlanPrompt(params.finding)} + +IMPORTANT: +1. Use the REAL AWS STATE above for ALL values in your fix steps. Do NOT guess or use defaults. +2. For requiredPermissions: list EVERY SINGLE IAM permission needed for ALL steps — read, fix, AND rollback. Think through the entire execution chain. If step 1 creates a bucket, you need s3:CreateBucket, s3:PutBucketPolicy, s3:GetBucketPolicy. If step 2 creates a role, you need iam:CreateRole, iam:PutRolePolicy, iam:GetRole, iam:PassRole. If step 3 creates a trail, you need cloudtrail:CreateTrail, cloudtrail:StartLogging, cloudtrail:GetTrailStatus, cloudtrail:DescribeTrails, cloudtrail:PutEventSelectors. Include EVERYTHING — the customer will add these permissions ONCE and should never need to add more. +3. ALWAYS overestimate permissions. It is much better to request 5 extra permissions than to fail mid-execution because one was missing. + +Generate the complete fix plan with EXACT values from the real AWS state.`, + temperature: 0, + }); + + this.logger.log(`AI refined plan for ${params.finding.findingKey}`); + return object; + } catch (err) { + this.logger.error( + `AI refine failed: ${err instanceof Error ? err.message : String(err)}`, + ); + // Fall back to original plan + return params.originalPlan; + } + } + + /** + * Dedicated permission analysis: given a complete plan, determine + * EVERY IAM permission needed. Separate AI call for maximum accuracy. + */ + async analyzeRequiredPermissions(plan: FixPlan): Promise { + try { + const allSteps = [ + ...plan.readSteps, + ...plan.fixSteps, + ...plan.rollbackSteps, + ]; + const stepsDescription = allSteps + .map((s) => `${s.service}:${s.command} — ${s.purpose}`) + .join('\n'); + + const { object } = await generateObject({ + model: MODEL, + schema: completePermissionsSchema, + system: + 'You are an AWS IAM permission expert. Given a list of AWS API calls, determine EVERY IAM permission needed. Be thorough — include all implicit permissions (iam:PassRole when roles are used, s3:PutBucketPolicy when buckets are created, etc.). It is critical that the list is COMPLETE because the customer will add these permissions once and should never need to add more.', + prompt: `These are the exact AWS SDK commands that will be executed: + +${stepsDescription} + +Full step details: +${JSON.stringify(allSteps, null, 2)} + +List EVERY IAM action needed. Include: +- The direct permission for each command (e.g., CreateBucketCommand → s3:CreateBucket) +- Implicit permissions (e.g., creating a bucket also needs s3:PutBucketPolicy, s3:GetBucketAcl) +- Dependent permissions (e.g., iam:PassRole when passing a role to CloudTrail) +- Read permissions needed for validation (e.g., cloudtrail:GetTrailStatus after creating a trail) + +OVERESTIMATE. Better to have 5 extra permissions than to miss one.`, + temperature: 0, + }); + + this.logger.log( + `AI permission analysis: ${object.permissions.length} permissions identified`, + ); + return object.permissions; + } catch (err) { + this.logger.error( + `AI permission analysis failed: ${err instanceof Error ? err.message : String(err)}`, + ); + // Fallback to plan's requiredPermissions + return plan.requiredPermissions; + } + } + + /** When a fix fails due to missing permissions. */ + async suggestPermissionFix(params: { + errorMessage: string; + failedStep: AwsCommandStep; + }): Promise { + try { + const { object } = await generateObject({ + model: MODEL, + schema: permissionFixSchema, + system: + 'You are an AWS IAM permission expert. Analyze the error and determine the exact missing IAM actions.', + prompt: buildPermissionFixPrompt({ + errorMessage: params.errorMessage, + failedStep: params.failedStep, + roleName: REMEDIATION_ROLE_NAME, + }), + temperature: 0, + }); + + const policy = JSON.stringify({ + Version: '2012-10-17', + Statement: [object.policyStatement], + }); + + return { + ...object, + fixScript: `aws iam put-role-policy --role-name ${REMEDIATION_ROLE_NAME} --policy-name CompAI-AutoFix --policy-document '${policy}'`, + }; + } catch (err) { + this.logger.error( + `AI permission fix failed: ${err instanceof Error ? err.message : String(err)}`, + ); + + const actionMatch = + params.errorMessage.match(/not authorized to perform:\s*([\w:*]+)/i) ?? + params.errorMessage.match(/required\s+([\w:*]+)\s+permission/i); + + const actions = actionMatch?.[1] ? [actionMatch[1]] : []; + if (actions.length === 0) { + return { + missingActions: [], + policyStatement: { + Effect: 'Allow' as const, + Action: [], + Resource: '*', + }, + fixScript: `# Could not determine the missing IAM action from the error. Check the error message and add the required permission manually to the ${REMEDIATION_ROLE_NAME} role.`, + }; + } + const policy = JSON.stringify({ + Version: '2012-10-17', + Statement: [{ Effect: 'Allow', Action: actions, Resource: '*' }], + }); + + return { + missingActions: actions, + policyStatement: { + Effect: 'Allow' as const, + Action: actions, + Resource: '*', + }, + fixScript: `aws iam put-role-policy --role-name ${REMEDIATION_ROLE_NAME} --policy-name CompAI-AutoFix --policy-document '${policy}'`, + }; + } + } + + // ─── GCP Methods ────────────────────────────────────────────────────── + + async generateGcpFixPlan(finding: FindingContext): Promise { + for (let attempt = 0; attempt < 2; attempt++) { + try { + const { object } = await generateObject({ + model: MODEL, + schema: gcpFixPlanSchema, + system: GCP_SYSTEM_PROMPT, + prompt: buildGcpFixPlanPrompt(finding), + temperature: 0, + }); + + this.logger.log( + `GCP AI plan for ${finding.findingKey}: canAutoFix=${object.canAutoFix}, risk=${object.risk}`, + ); + return object; + } catch (err) { + this.logger.error( + `GCP AI plan failed (attempt ${attempt + 1}): ${err instanceof Error ? err.message : String(err)}`, + ); + if (attempt === 0) continue; + return this.fallbackGcpPlan(finding); + } + } + return this.fallbackGcpPlan(finding); + } + + async refineGcpFixPlan(params: { + finding: FindingContext; + originalPlan: GcpFixPlan; + realGcpState: Record; + }): Promise { + try { + const { object } = await generateObject({ + model: MODEL, + schema: gcpFixPlanSchema, + system: GCP_SYSTEM_PROMPT, + prompt: `You previously analyzed this GCP finding and generated read steps. Those read steps have been executed against the REAL GCP account. Here is the REAL data: + +REAL GCP STATE (from executing read steps): +${JSON.stringify(params.realGcpState, null, 2)} + +ORIGINAL FINDING: +${buildGcpFixPlanPrompt(params.finding)} + +CRITICAL INSTRUCTIONS: +1. The "body" field in each fix step must contain EXACT JSON that will be sent to the GCP API. No descriptions, no placeholders, no human-readable text — ONLY valid JSON objects. +2. Use the REAL GCP STATE above for ALL values. Copy existing data structures exactly as they appear. +3. For setIamPolicy: the body MUST be { "policy": { "bindings": [...ALL existing bindings from real state...], "etag": "...from real state...", "version": 3, "auditConfigs": [...existing plus your additions...] } }. Copy the ENTIRE policy from the read step, then add/modify only what's needed. +4. For audit logging: add this to the policy's auditConfigs array: { "service": "allServices", "auditLogConfigs": [{"logType": "ADMIN_READ"}, {"logType": "DATA_READ"}, {"logType": "DATA_WRITE"}] } +5. For rollback: the body must restore the EXACT original policy from the read step (copy it verbatim). +6. The "body" field is sent directly as JSON to fetch(). If it contains strings like "enabled for all services" instead of actual JSON, the API will ignore it silently. + +Generate the complete fix plan with EXACT JSON values from the real GCP state.`, + temperature: 0, + }); + + this.logger.log(`GCP AI refined plan for ${params.finding.findingKey}`); + return object; + } catch (err) { + this.logger.error( + `GCP AI refine failed: ${err instanceof Error ? err.message : String(err)}`, + ); + return params.originalPlan; + } + } + + // ─── Azure Methods ──────────────────────────────────────────────────── + + async generateAzureFixPlan(finding: FindingContext): Promise { + try { + const { object } = await generateObject({ + model: MODEL, + schema: azureFixPlanSchema, + system: AZURE_SYSTEM_PROMPT, + prompt: buildAzureFixPlanPrompt(finding), + temperature: 0, + }); + + this.logger.log( + `Azure AI plan for ${finding.findingKey}: canAutoFix=${object.canAutoFix}, risk=${object.risk}`, + ); + return object; + } catch (err) { + this.logger.error( + `Azure AI plan failed: ${err instanceof Error ? err.message : String(err)}`, + ); + return this.fallbackAzurePlan(finding); + } + } + + async refineAzureFixPlan(params: { + finding: FindingContext; + originalPlan: AzureFixPlan; + realAzureState: Record; + }): Promise { + try { + const { object } = await generateObject({ + model: MODEL, + schema: azureFixPlanSchema, + system: AZURE_SYSTEM_PROMPT, + prompt: `You previously analyzed this Azure finding and generated read steps. Those read steps have been executed against the REAL Azure account. Here is the REAL data: + +REAL AZURE STATE (from executing read steps): +${JSON.stringify(params.realAzureState, null, 2)} + +ORIGINAL FINDING: +${buildAzureFixPlanPrompt(params.finding)} + +IMPORTANT: +1. Use the REAL AZURE STATE above for ALL values in your fix steps. Do NOT guess or use defaults. +2. For rollback steps, use the REAL values from the read steps to restore the previous configuration. +3. Make sure all URLs include the correct api-version parameter. + +Generate the complete fix plan with EXACT values from the real Azure state.`, + temperature: 0, + }); + + this.logger.log(`Azure AI refined plan for ${params.finding.findingKey}`); + return object; + } catch (err) { + this.logger.error( + `Azure AI refine failed: ${err instanceof Error ? err.message : String(err)}`, + ); + return params.originalPlan; + } + } + + private fallbackAzurePlan(finding: FindingContext): AzureFixPlan { + return { + canAutoFix: false, + risk: (finding.severity as AzureFixPlan['risk']) ?? 'medium', + description: + finding.remediation ?? finding.description ?? 'Check Azure Portal.', + currentState: {}, + proposedState: {}, + readSteps: [], + fixSteps: [], + rollbackSteps: [], + rollbackSupported: false, + requiresAcknowledgment: false, + guidedSteps: finding.remediation + ? [finding.remediation] + : ['Review the finding in Azure Portal and apply the recommended fix.'], + reason: 'AI analysis unavailable. Follow the guided steps.', + }; + } + + private fallbackGcpPlan(finding: FindingContext): GcpFixPlan { + const evidence = finding.evidence ?? {}; + const externalUri = evidence.externalUri as string | undefined; + const projectName = + (evidence.projectDisplayName as string) ?? 'your project'; + + const steps: string[] = []; + if (externalUri) { + steps.push( + `Open the resource in GCP Console: ${externalUri}`, + ); + } + if (finding.remediation) { + // Split SCC remediation text into separate steps if it contains "More info:" or multiple sentences + const parts = finding.remediation + .split(/(?:More info:|Compliance:)/i) + .map((s) => s.trim()) + .filter(Boolean); + if (parts[0]) steps.push(parts[0]); + if (parts[1]) steps.push(`Reference: ${parts[1]}`); + } + if (steps.length === 0) { + steps.push( + `Review the finding "${finding.title}" in the GCP Console for project ${projectName} and apply the recommended fix.`, + ); + } + + return { + canAutoFix: false, + risk: (finding.severity as GcpFixPlan['risk']) ?? 'medium', + description: + finding.description ?? finding.remediation ?? 'Check GCP Console.', + currentState: {}, + proposedState: {}, + readSteps: [], + fixSteps: [], + rollbackSteps: [], + rollbackSupported: false, + requiresAcknowledgment: false, + guidedSteps: steps, + reason: 'This finding requires manual remediation in the GCP Console.', + }; + } + + private fallbackPlan(finding: FindingContext): FixPlan { + return { + canAutoFix: false, + risk: + (finding.severity === 'info' + ? 'low' + : (finding.severity as FixPlan['risk'])) ?? 'medium', + description: + finding.remediation ?? + finding.description ?? + 'Check AWS documentation.', + currentState: {}, + proposedState: {}, + requiredPermissions: [], + readSteps: [], + fixSteps: [], + rollbackSteps: [], + rollbackSupported: false, + requiresAcknowledgment: false, + guidedSteps: finding.remediation + ? [finding.remediation] + : ['Review the finding in AWS Console and apply the recommended fix.'], + reason: 'AI analysis unavailable. Follow the guided steps.', + }; + } +} diff --git a/apps/api/src/cloud-security/aws-command-executor.ts b/apps/api/src/cloud-security/aws-command-executor.ts new file mode 100644 index 0000000000..2fc2837077 --- /dev/null +++ b/apps/api/src/cloud-security/aws-command-executor.ts @@ -0,0 +1,623 @@ +import type { AwsCredentialIdentity } from '@aws-sdk/types'; +import type { AwsCommandStep } from './ai-remediation.prompt'; + +import * as s3 from '@aws-sdk/client-s3'; +import * as dynamodb from '@aws-sdk/client-dynamodb'; +import * as kinesis from '@aws-sdk/client-kinesis'; +import * as redshift from '@aws-sdk/client-redshift'; +import * as backup from '@aws-sdk/client-backup'; +import * as ecr from '@aws-sdk/client-ecr'; +import * as glue from '@aws-sdk/client-glue'; +import * as athena from '@aws-sdk/client-athena'; +import * as opensearch from '@aws-sdk/client-opensearch'; +import * as secretsManager from '@aws-sdk/client-secrets-manager'; +import * as kms from '@aws-sdk/client-kms'; +import * as cloudtrail from '@aws-sdk/client-cloudtrail'; +import * as guardduty from '@aws-sdk/client-guardduty'; +import * as configService from '@aws-sdk/client-config-service'; +import * as iam from '@aws-sdk/client-iam'; +import * as sts from '@aws-sdk/client-sts'; +import * as inspector2 from '@aws-sdk/client-inspector2'; +import * as macie2 from '@aws-sdk/client-macie2'; +import * as cognito from '@aws-sdk/client-cognito-identity-provider'; +import * as shield from '@aws-sdk/client-shield'; +import * as wafv2 from '@aws-sdk/client-wafv2'; +import * as acm from '@aws-sdk/client-acm'; +import * as cwLogs from '@aws-sdk/client-cloudwatch-logs'; +import * as cloudwatch from '@aws-sdk/client-cloudwatch'; +import * as sns from '@aws-sdk/client-sns'; +import * as ec2 from '@aws-sdk/client-ec2'; +import * as lambda from '@aws-sdk/client-lambda'; +import * as eks from '@aws-sdk/client-eks'; +import * as emr from '@aws-sdk/client-emr'; +import * as codebuild from '@aws-sdk/client-codebuild'; +import * as elasticBeanstalk from '@aws-sdk/client-elastic-beanstalk'; +import * as sfn from '@aws-sdk/client-sfn'; +import * as elbv2 from '@aws-sdk/client-elastic-load-balancing-v2'; +import * as cloudfront from '@aws-sdk/client-cloudfront'; +import * as rds from '@aws-sdk/client-rds'; +import * as apigw from '@aws-sdk/client-apigatewayv2'; +import * as route53 from '@aws-sdk/client-route-53'; +import * as networkFirewall from '@aws-sdk/client-network-firewall'; +import * as transfer from '@aws-sdk/client-transfer'; +import * as sqs from '@aws-sdk/client-sqs'; +import * as eventbridge from '@aws-sdk/client-eventbridge'; +import * as ssm from '@aws-sdk/client-ssm'; +import * as kafka from '@aws-sdk/client-kafka'; +import * as sagemaker from '@aws-sdk/client-sagemaker'; +import * as efs from '@aws-sdk/client-efs'; +import * as elasticache from '@aws-sdk/client-elasticache'; + +type SdkModule = Record; + +/** Static map of service name → SDK module. Includes common aliases AI might use. */ +const SDK_MODULES: Record = { + s3: s3, + dynamodb: dynamodb, + kinesis: kinesis, + redshift: redshift, + backup: backup, + ecr: ecr, + glue: glue, + athena: athena, + opensearch: opensearch, + 'secrets-manager': secretsManager, + kms: kms, + cloudtrail: cloudtrail, + guardduty: guardduty, + 'config-service': configService, + iam: iam, + sts: sts, + inspector2: inspector2, + macie2: macie2, + 'cognito-identity-provider': cognito, + shield: shield, + wafv2: wafv2, + acm: acm, + 'cloudwatch-logs': cwLogs, + cloudwatch: cloudwatch, + sns: sns, + ec2: ec2, + lambda: lambda, + eks: eks, + emr: emr, + codebuild: codebuild, + 'elastic-beanstalk': elasticBeanstalk, + sfn: sfn, + 'elastic-load-balancing-v2': elbv2, + cloudfront: cloudfront, + rds: rds, + apigatewayv2: apigw, + 'route-53': route53, + 'network-firewall': networkFirewall, + transfer: transfer, + sqs: sqs, + eventbridge: eventbridge, + ssm: ssm, + kafka: kafka, + sagemaker: sagemaker, + efs: efs, + elasticache: elasticache, + // Common aliases AI might use + logs: cwLogs, + config: configService, + cognito: cognito, + waf: wafv2, + route53: route53, + 'step-functions': sfn, + elb: elbv2, + elbv2: elbv2, + apigateway: apigw, + msk: kafka, + inspector: inspector2, + macie: macie2, + secretsmanager: secretsManager, +}; + +/** Commands that are too dangerous or not allowed to execute. */ +const BLOCKED_COMMANDS = new Set([ + // Destructive + 'DeleteBucketCommand', + 'DeleteTableCommand', + 'DeleteDBInstanceCommand', + 'DeleteDBClusterCommand', + 'DeleteFileSystemCommand', + 'TerminateInstancesCommand', + 'DeleteClusterCommand', + 'DeleteStackCommand', + 'DeleteVpcCommand', + 'DeleteSubnetCommand', + 'DeleteUserCommand', + 'DeleteRoleCommand', + // AttachRolePolicy blocked — use PutRolePolicy (inline) instead + 'AttachRolePolicyCommand', +]); + +/** Param names that AWS expects as JSON strings, not objects. */ +const JSON_STRING_PARAMS = new Set([ + 'Content', + 'PolicyDocument', + 'AssumeRolePolicyDocument', + 'Policy', + 'TrustPolicy', + 'ResourcePolicy', + 'Configuration', + 'Definition', +]); + +/** + * Universal pre-execution param normalisation. + * Fixes common AI mistakes without per-command logic. + */ +function normaliseInputParams( + input: Record, + command: string, + region: string, +): void { + for (const [key, value] of Object.entries(input)) { + // Rule 1: Stringify any object param that AWS expects as a JSON string + if ( + value !== null && + typeof value === 'object' && + !Array.isArray(value) && + JSON_STRING_PARAMS.has(key) + ) { + input[key] = JSON.stringify(value); + } + } + + // Rule 2: S3 CreateBucket needs LocationConstraint for non-us-east-1 + if (command === 'CreateBucketCommand') { + if (input.Bucket) { + input.Bucket = String(input.Bucket).toLowerCase().replace(/_/g, '-'); + } + if (region !== 'us-east-1' && !input.CreateBucketConfiguration) { + input.CreateBucketConfiguration = { LocationConstraint: region }; + } + } + + // Rule 3: CloudTrail trails should default to multi-region + validation + if (command === 'CreateTrailCommand') { + if (!input.IsMultiRegionTrail) input.IsMultiRegionTrail = true; + if (!input.EnableLogFileValidation) input.EnableLogFileValidation = true; + } +} + +/** + * Universal send-with-retry. Handles three recoverable error classes: + * 1. Validation errors → auto-fix the offending param, retry once + * 2. Throttling → exponential backoff, up to 3 retries + * 3. IAM propagation → wait and retry (roles/policies take seconds to propagate) + * Everything else is surfaced immediately. + */ + +async function sendWithAutoRetry( + client: any, + + CommandClass: any, + input: Record, + command: string, + service: string, +): Promise> { + const MAX_ATTEMPTS = 4; // 1 initial + up to 3 retries + let validationFixed = false; + + for (let attempt = 0; attempt < MAX_ATTEMPTS; attempt++) { + try { + const result = await client.send(new CommandClass(input)); + + // After creating an IAM role, wait for propagation + if ( + command === 'CreateRoleCommand' || + command === 'PutRolePolicyCommand' + ) { + await new Promise((r) => setTimeout(r, 5000)); + } + + return (result ?? {}) as Record; + } catch (err) { + const awsErr = err as { + name?: string; + message?: string; + Code?: string; + $metadata?: { httpStatusCode?: number }; + }; + const errName = awsErr.name ?? ''; + const errMsg = + awsErr.message || + awsErr.Code || + `${errName} (HTTP ${awsErr.$metadata?.httpStatusCode ?? 'unknown'})`; + + console.error( + `AWS Command Error [${service}:${command}] attempt ${attempt + 1}:`, + errName, + errMsg, + ); + + // ── Idempotent "already exists" → treat as success ── + if ( + errName === 'ResourceAlreadyExistsException' || + errName === 'DuplicateDocumentContent' || + errName === 'DuplicateDocumentVersionName' || + errMsg.includes('already exists') || + errMsg.includes('AlreadyExists') || + errMsg.includes('same metadata and content') || + errMsg.includes('DuplicateDocument') + ) { + return { _alreadyExists: true, message: errMsg }; + } + + // ── Throttle / rate limit → backoff and retry ── + if (isThrottleError(errName, errMsg) && attempt < MAX_ATTEMPTS - 1) { + const delay = Math.min(1000 * 2 ** attempt, 8000); // 1s, 2s, 4s, 8s + console.log( + `Throttled on ${service}:${command}, retrying in ${delay}ms`, + ); + await new Promise((r) => setTimeout(r, delay)); + continue; + } + + // ── Validation error → auto-fix param and retry once ── + if (!validationFixed && isValidationError(errName, errMsg)) { + const fixed = tryAutoFixValidationError(input, errMsg); + if (fixed) { + console.log( + `Auto-fixed validation error, retrying ${service}:${command}`, + ); + validationFixed = true; + continue; + } + } + + // ── Not found → clear message ── + if ( + errName === 'ServiceSettingNotFound' || + errName === 'ResourceNotFoundException' || + errName === 'NotFoundException' || + errName === 'InvalidDocument' || + errName === 'NoSuchEntity' || + errName === 'NoSuchBucket' || + errName === 'DetectorNotFoundException' || + errMsg.includes('does not exist') || + errMsg.includes('not found') + ) { + throw new Error( + `${service}:${command} failed: target resource not found (${errName}). ${errMsg}`, + ); + } + + // ── Unknown / unrecoverable ── + if (!errMsg || errMsg === 'Unknown' || errMsg === 'UnknownError') { + throw new Error( + `${service}:${command} failed with ${errName || 'unknown error'} (HTTP ${awsErr.$metadata?.httpStatusCode ?? '?'}). Check IAM permissions and input parameters.`, + ); + } + throw err; + } + } + throw new Error( + `${service}:${command} failed after ${MAX_ATTEMPTS} attempts`, + ); +} + +function isThrottleError(errName: string, errMsg: string): boolean { + return ( + errName === 'Throttling' || + errName === 'ThrottlingException' || + errName === 'TooManyRequestsException' || + errName === 'RequestLimitExceeded' || + errMsg.includes('Rate exceeded') || + errMsg.includes('Throttling') || + errMsg.includes('Too Many Requests') + ); +} + +function isValidationError(errName: string, errMsg: string): boolean { + return ( + errName === 'ValidationException' || + errName === 'InvalidParameterValue' || + errName === 'InvalidParameterValueException' || + errMsg.includes('validation error') || + errMsg.includes('failed to satisfy constraint') + ); +} + +/** + * Parse the AWS validation error, fix the offending param, return true if fixed. + * AWS error format: "Value at 'fieldName' failed to satisfy constraint: ..." + * + * Key subtlety: AWS errors use camelCase ('documentVersion') but SDK params + * use PascalCase ('DocumentVersion'). We match case-insensitively. + */ +function tryAutoFixValidationError( + input: Record, + errMsg: string, +): boolean { + // Extract field name from error (camelCase) + const fieldMatch = errMsg.match(/Value at '(\w+)'/i); + if (!fieldMatch?.[1]) return false; + + const errorField = fieldMatch[1]; + + // Find the actual key in input (case-insensitive match) + const inputKey = Object.keys(input).find( + (k) => k.toLowerCase() === errorField.toLowerCase(), + ); + if (!inputKey) return false; + + const value = input[inputKey]; + + // Fix 1: Regex constraint (version numbers, IDs, etc.) + // → remove the param so AWS uses its default + if (errMsg.includes('regular expression pattern')) { + delete input[inputKey]; + return true; + } + + // Fix 2: Object instead of string → stringify + if (value !== null && typeof value === 'object') { + input[inputKey] = JSON.stringify(value); + return true; + } + + // Fix 3: Length constraint → truncate + const lengthMatch = errMsg.match(/length less than or equal to (\d+)/); + if (lengthMatch && typeof value === 'string') { + input[inputKey] = value.slice(0, Number(lengthMatch[1])); + return true; + } + + return false; +} + +/** + * Validate all steps in a plan BEFORE executing anything. + * Catches: unknown services, missing commands, blocked commands, placeholder values. + * Returns list of errors. Empty = valid. + */ +export function validatePlanSteps(steps: AwsCommandStep[]): string[] { + const errors: string[] = []; + + for (let i = 0; i < steps.length; i++) { + const step = steps[i]; + const prefix = `Step ${i + 1} (${step.command})`; + + // Check service exists + if (!SDK_MODULES[step.service]) { + errors.push(`${prefix}: Unknown service "${step.service}"`); + continue; + } + + // Check command exists in module (with fuzzy match for AI mistakes) + const mod = SDK_MODULES[step.service]; + let cmdExists = + mod[step.command] && typeof mod[step.command] === 'function'; + if (!cmdExists) { + const cmdBase = step.command.replace('Command', ''); + const fuzzy = Object.keys(mod).find((k) => { + if (!k.endsWith('Command') || typeof mod[k] !== 'function') + return false; + const kBase = k.replace('Command', ''); + return ( + kBase.includes(cmdBase) || + cmdBase.includes(kBase) || + kBase.replace('Bucket', '') === cmdBase.replace('Bucket', '') + ); + }); + cmdExists = Boolean(fuzzy); + } + if (!cmdExists) { + errors.push( + `${prefix}: Command "${step.command}" not found in @aws-sdk/client-${step.service}`, + ); + continue; + } + + // Check command name format + if (!step.command.endsWith('Command')) { + errors.push(`${prefix}: Command name must end with "Command"`); + } + + // Check blocked + if (BLOCKED_COMMANDS.has(step.command)) { + errors.push(`${prefix}: Command is blocked for safety`); + } + + // Check for placeholder values in params + const paramStr = JSON.stringify(step.params); + const placeholders = paramStr.match(/\{\{[\w]+\}\}|<[A-Z_]+>/g); + if (placeholders) { + errors.push( + `${prefix}: Contains placeholder values: ${placeholders.join(', ')}`, + ); + } + } + + return errors; +} + +export interface StepResult { + step: AwsCommandStep; + output: Record; +} + +export interface PlanExecutionResult { + results: StepResult[]; + error?: { stepIndex: number; message: string; step: AwsCommandStep }; +} + +/** + * Execute a single AWS SDK v3 command. + * Uses static imports — no dynamic require, no version mismatches. + */ +export async function executeAwsCommand(params: { + service: string; + command: string; + input: Record; + credentials: AwsCredentialIdentity; + region: string; + isRollback?: boolean; +}): Promise> { + const { service, command, input, credentials, region, isRollback } = params; + + const mod = SDK_MODULES[service]; + if (!mod) { + throw new Error(`Service "${service}" is not supported`); + } + + // Block dangerous commands — unless this is a rollback (rollback needs Delete to undo) + if (BLOCKED_COMMANDS.has(command) && !isRollback) { + throw new Error(`Command "${command}" is blocked for safety`); + } + + if (!command.endsWith('Command')) { + throw new Error(`Invalid command name "${command}"`); + } + + // ─── Universal param normalisation ────────────────────────────────── + // Instead of per-command hacks, apply two universal rules that cover + // every current and future AWS command the AI might generate. + + normaliseInputParams(input, command, region); + + // Try exact command name first, then fuzzy match if not found + let CommandClass = mod[command]; + if (!CommandClass || typeof CommandClass !== 'function') { + // AI sometimes generates wrong command names — try to find the closest match + const cmdBase = command.replace('Command', ''); + const match = Object.keys(mod).find((k) => { + if (!k.endsWith('Command') || typeof mod[k] !== 'function') return false; + const kBase = k.replace('Command', ''); + // Check if one contains the other (e.g., PutBucketPublicAccessBlock vs PutPublicAccessBlock) + return ( + kBase.includes(cmdBase) || + cmdBase.includes(kBase) || + kBase.replace('Bucket', '') === cmdBase.replace('Bucket', '') + ); + }); + if (match) { + // Re-check blocked commands against the resolved name + if (BLOCKED_COMMANDS.has(match) && !isRollback) { + throw new Error(`Command "${match}" is blocked for safety`); + } + CommandClass = mod[match]; + } + } + if (!CommandClass || typeof CommandClass !== 'function') { + throw new Error( + `Command "${command}" not found in @aws-sdk/client-${service}`, + ); + } + + // Find the client class from the same module (skip internal __Client) + const clientKey = Object.keys(mod).find( + (k) => + k.endsWith('Client') && + k !== 'Client' && + !k.startsWith('_') && + typeof mod[k] === 'function' && + !k.includes('Command') && + !k.includes('Exception'), + ); + if (!clientKey) { + throw new Error(`No client found in @aws-sdk/client-${service}`); + } + + const client = new mod[clientKey]({ + region, + credentials: { + accessKeyId: credentials.accessKeyId, + secretAccessKey: credentials.secretAccessKey, + sessionToken: credentials.sessionToken, + }, + }); + + try { + return await sendWithAutoRetry( + client, + CommandClass, + input, + command, + service, + ); + } finally { + client.destroy?.(); + } +} + +/** + * Execute a sequence of steps. Stops on first error. + * When `autoRollbackSteps` is provided and a step fails, automatically + * undoes completed steps in reverse order (best-effort). + * Convention: rollbackSteps[i] undoes fixSteps[i]. + */ +export async function executePlanSteps(params: { + steps: AwsCommandStep[]; + credentials: AwsCredentialIdentity; + region: string; + isRollback?: boolean; + autoRollbackSteps?: AwsCommandStep[]; +}): Promise { + const results: StepResult[] = []; + + for (let i = 0; i < params.steps.length; i++) { + const step = params.steps[i]; + try { + const output = await executeAwsCommand({ + service: step.service, + command: step.command, + input: structuredClone(step.params), + credentials: params.credentials, + region: params.region, + isRollback: params.isRollback, + }); + results.push({ step, output }); + } catch (err) { + const message = err instanceof Error ? err.message : String(err); + + // If a prior step was a no-op (already exists / duplicate content), + // this step may depend on output from that no-op (e.g., a version number). + // Skip it instead of failing the entire execution — the infra is already + // in the desired state. This is universal: works for any service. + const hasPriorNoOp = results.some( + (r) => r.output._alreadyExists || r.output._skipped, + ); + if ( + hasPriorNoOp && + (message.includes('validation error') || + message.includes('failed to satisfy constraint')) + ) { + console.log( + `Skipping step ${i + 1} (${step.command}) — prior step was no-op, this step likely depends on its output`, + ); + results.push({ step, output: { _skipped: true, reason: message } }); + continue; + } + + // Auto-rollback completed steps if rollback steps were provided + if (params.autoRollbackSteps && results.length > 0) { + const rollbackSlice = params.autoRollbackSteps + .slice(0, results.length) + .reverse(); + for (const rbStep of rollbackSlice) { + try { + await executeAwsCommand({ + service: rbStep.service, + command: rbStep.command, + input: structuredClone(rbStep.params), + credentials: params.credentials, + region: params.region, + isRollback: true, + }); + } catch { + // Best-effort rollback — don't mask original error + } + } + } + + return { results, error: { stepIndex: i, message, step } }; + } + } + + return { results }; +} diff --git a/apps/api/src/cloud-security/aws-task-mappings.ts b/apps/api/src/cloud-security/aws-task-mappings.ts new file mode 100644 index 0000000000..3584073e5d --- /dev/null +++ b/apps/api/src/cloud-security/aws-task-mappings.ts @@ -0,0 +1,58 @@ +/** + * Maps AWS service adapter IDs to framework task template IDs. + * + * When ALL findings for a service pass, the linked evidence tasks + * are auto-satisfied with scan results as proof. + * + * Only pass → done. Never mark tasks as failed from scan data. + */ +export const AWS_SERVICE_TASK_MAPPINGS: Record = { + // IAM → Employee Access, RBAC, Access Review Log + 'iam-analyzer': [ + 'frk_tt_68406ca292d9fffb264991b9', + 'frk_tt_68e80544d9734e0402cfa807', + 'frk_tt_68e805457c2dcc784e72e3cc', + ], + // KMS, S3, RDS, DynamoDB → Encryption at Rest + kms: ['frk_tt_68e52b26bf0e656af9e4e9c3'], + s3: ['frk_tt_68e52b26bf0e656af9e4e9c3'], + rds: ['frk_tt_68e52b26bf0e656af9e4e9c3', 'frk_tt_68e52b26b166e2c0a0d11956'], + // CloudTrail, CloudWatch → Monitoring & Alerting + cloudtrail: ['frk_tt_68406af04a4acb93083413b9'], + cloudwatch: ['frk_tt_68406af04a4acb93083413b9'], + // GuardDuty → Incident Response + guardduty: ['frk_tt_68406b4f40c87c12ae0479ce'], + // Secrets Manager → Secure Secrets + 'secrets-manager': ['frk_tt_68407ae5274a64092c305104'], + // ELB, ACM, CloudFront → TLS / HTTPS + elb: ['frk_tt_68406f411fe27e47a0d6d5f3'], + acm: ['frk_tt_68406f411fe27e47a0d6d5f3'], + cloudfront: ['frk_tt_68406f411fe27e47a0d6d5f3'], + // EC2/VPC, WAF, Network Firewall → Production Firewall + 'ec2-vpc': [ + 'frk_tt_68fa2a852e70f757188f0c39', + 'frk_tt_68406af04a4acb93083413b9', + ], + waf: ['frk_tt_68fa2a852e70f757188f0c39'], + 'network-firewall': ['frk_tt_68fa2a852e70f757188f0c39'], + // Shield → App Availability + shield: ['frk_tt_68406d2e86acc048d1774ea6'], + // Backup, RDS, DynamoDB → Backup logs + backup: ['frk_tt_68e52b26b166e2c0a0d11956'], + dynamodb: [ + 'frk_tt_68e52b26bf0e656af9e4e9c3', + 'frk_tt_68e52b26b166e2c0a0d11956', + ], + // Config, Inspector → Internal Security Audit + config: ['frk_tt_68e52b2618cb9d9722c6edfd'], + inspector: ['frk_tt_68e52b2618cb9d9722c6edfd'], + // Lambda → Secure Code + lambda: ['frk_tt_68406e353df3bc002994acef'], + // ECS/EKS → Separation of Environments, Monitoring + 'ecs-eks': [ + 'frk_tt_68e52a484cad0014de7a628f', + 'frk_tt_68406af04a4acb93083413b9', + ], + // Cognito → Employee Access + cognito: ['frk_tt_68406ca292d9fffb264991b9'], +}; diff --git a/apps/api/src/cloud-security/azure-ai-remediation.prompt.ts b/apps/api/src/cloud-security/azure-ai-remediation.prompt.ts new file mode 100644 index 0000000000..a573ee9a35 --- /dev/null +++ b/apps/api/src/cloud-security/azure-ai-remediation.prompt.ts @@ -0,0 +1,209 @@ +import { z } from 'zod'; + +export const azureApiStepSchema = z.object({ + method: z.enum(['GET', 'POST', 'PUT', 'PATCH', 'DELETE']), + url: z + .string() + .describe('Full HTTPS URL including api-version query parameter'), + body: z.record(z.string(), z.unknown()).optional(), + queryParams: z.record(z.string(), z.string()).optional(), + purpose: z.string().describe('What this step does'), +}); + +export type AzureApiStep = z.infer; + +export const azureFixPlanSchema = z.object({ + canAutoFix: z.boolean(), + risk: z.enum(['low', 'medium', 'high', 'critical']), + description: z.string(), + currentState: z.record(z.string(), z.unknown()), + proposedState: z.record(z.string(), z.unknown()), + readSteps: z.array(azureApiStepSchema), + fixSteps: z.array(azureApiStepSchema), + rollbackSteps: z.array(azureApiStepSchema), + rollbackSupported: z.boolean(), + requiresAcknowledgment: z.boolean(), + acknowledgmentMessage: z.string().optional(), + guidedSteps: z.array(z.string()).optional(), + reason: z.string().optional(), +}); + +export type AzureFixPlan = z.infer; + +export const AZURE_SYSTEM_PROMPT = `You are an Azure security remediation expert. You analyze Microsoft Defender for Cloud findings and generate automated fix plans using the Azure Resource Manager (ARM) REST API. + +## Output Format +Return a JSON object matching the schema. Each API step must include: method, url, body (if needed), queryParams (if needed), purpose. + +## Azure ARM API Patterns + +### URL Format +All URLs follow: https://management.azure.com/{scope}/providers/{resourceProvider}/{resourceType}/{resourceName}?api-version={version} + +### Key Vault +- GET vault: https://management.azure.com/subscriptions/{sub}/resourceGroups/{rg}/providers/Microsoft.KeyVault/vaults/{name}?api-version=2023-07-01 +- PATCH vault: Same URL, method PATCH with body: { properties: { enableSoftDelete: true, enablePurgeProtection: true } } +- Update network rules: PATCH with body: { properties: { networkAcls: { defaultAction: "Deny" } } } + +### Network Security Groups +- GET NSG: https://management.azure.com/subscriptions/{sub}/resourceGroups/{rg}/providers/Microsoft.Network/networkSecurityGroups/{nsg}?api-version=2023-11-01 +- GET rule: .../securityRules/{rule}?api-version=2023-11-01 +- PUT rule (update): Same URL, PUT with full rule definition +- DELETE rule: Same URL, DELETE method (rollback only) + +### Storage Accounts +- GET: https://management.azure.com/subscriptions/{sub}/resourceGroups/{rg}/providers/Microsoft.Storage/storageAccounts/{name}?api-version=2023-05-01 +- PATCH: { properties: { supportsHttpsTrafficOnly: true, minimumTlsVersion: "TLS1_2", allowBlobPublicAccess: false } } + +### SQL Servers & Databases +- GET server: https://management.azure.com/subscriptions/{sub}/resourceGroups/{rg}/providers/Microsoft.Sql/servers/{name}?api-version=2023-05-01-preview +- PATCH auditing: .../auditingSettings/default?api-version=2021-11-01 with { properties: { state: "Enabled" } } +- PATCH TDE: .../databases/{db}/transparentDataEncryption/current?api-version=2021-11-01 + +### Diagnostic Settings (Subscription-level) +- GET existing: https://management.azure.com/subscriptions/{sub}/providers/Microsoft.Insights/diagnosticSettings?api-version=2021-05-01-preview +- Discover workspaces: https://management.azure.com/subscriptions/{sub}/providers/Microsoft.OperationalInsights/workspaces?api-version=2022-10-01 +- PUT to create: https://management.azure.com/subscriptions/{sub}/providers/Microsoft.Insights/diagnosticSettings/{settingName}?api-version=2021-05-01-preview + Body: { "properties": { "workspaceId": "{discovered_workspace_id}", "logs": [{"category": "Administrative", "enabled": true}, {"category": "Security", "enabled": true}, {"category": "Alert", "enabled": true}, {"category": "Policy", "enabled": true}] } } + IMPORTANT: workspaceId must be a real workspace ID discovered from readSteps, NOT a placeholder + +### Activity Log Alerts +- PUT: https://management.azure.com/subscriptions/{sub}/resourceGroups/{rg}/providers/Microsoft.Insights/activityLogAlerts/{name}?api-version=2020-10-01 + +### Role Assignments (IAM) +- GET: https://management.azure.com/{scope}/providers/Microsoft.Authorization/roleAssignments?api-version=2022-04-01 +- PUT: .../roleAssignments/{assignmentId}?api-version=2022-04-01 +- DELETE: Same URL (rollback only) + +### AKS Clusters +- GET: https://management.azure.com/subscriptions/{sub}/resourceGroups/{rg}/providers/Microsoft.ContainerService/managedClusters/{name}?api-version=2024-01-01 +- PATCH: Same URL with body: { properties: { autoUpgradeProfile: { upgradeChannel: "stable" }, addonProfiles: { azurePolicy: { enabled: true } } } } +- API server access: PATCH with { properties: { apiServerAccessProfile: { authorizedIPRanges: ["x.x.x.x/32"] } } } + +### Container Registry +- GET: https://management.azure.com/subscriptions/{sub}/resourceGroups/{rg}/providers/Microsoft.ContainerRegistry/registries/{name}?api-version=2023-11-01-preview +- PATCH: { properties: { adminUserEnabled: false, publicNetworkAccess: "Disabled", policies: { trustPolicy: { status: "enabled", type: "Notary" } } } } + +### Cosmos DB +- GET: https://management.azure.com/subscriptions/{sub}/resourceGroups/{rg}/providers/Microsoft.DocumentDB/databaseAccounts/{name}?api-version=2024-02-15-preview +- PATCH: { properties: { publicNetworkAccess: "Disabled", disableLocalAuth: true, enableAutomaticFailover: true, disableKeyBasedMetadataWriteAccess: true } } + +## Safety Rules +- NEVER delete resource groups, subscriptions, or VMs +- NEVER modify role assignments that would lock out the service principal +- NEVER disable encryption on databases or storage +- PREFER enabling security features (soft delete, purge protection, HTTPS-only) +- ALWAYS make changes reversible via PATCH back to original values +- ALWAYS read current state before modifying + +## canAutoFix = true +- Enable security features (soft delete, purge protection, TDE, HTTPS-only) +- Restrict network access (NSG rules, Key Vault network ACLs) +- Enable logging/auditing (diagnostic settings, SQL auditing) +- Remove overly permissive NSG rules +- ALWAYS provide rollback steps + +## canAutoFix = false (ONLY these specific cases) +- Resource recreation required (e.g., AKS cluster needs RBAC enabled from scratch) +- Entra ID / Active Directory changes requiring admin consent +- Changes requiring VM/app restart that could cause downtime +- Resource doesn't exist or was deleted +- Enabling features that require a higher SKU tier (e.g., ACR content trust needs Premium) + +IMPORTANT: Most Azure fixes ARE auto-fixable via PATCH. Default to canAutoFix=true unless one of the above applies. The customer expects automation — avoid guidedSteps when a PATCH will work. + +## Risk Assessment +- low: Enabling features (soft delete, diagnostic settings, HTTPS enforcement) +- medium: Restricting access (NSG tightening, network ACL changes) +- high: Database/encryption settings on production resources +- critical: IAM changes, role assignment modifications + +## Rollback Patterns +- PATCH operations: rollback by PATCH back to original values (read first!) +- PUT (create): DELETE the created resource +- DELETE (remove rule): PUT back the original rule definition +- Always capture current state in readSteps before any modification + +## Irreversible Operations +Some Azure changes cannot be undone. For these, set rollbackSupported=false and requiresAcknowledgment=true: +- Key Vault purge protection (once enabled, cannot be disabled) +- AKS RBAC (cannot disable after enabling without cluster recreation) +- Cosmos DB disableLocalAuth (reverting requires careful coordination) + +For these, still provide fixSteps (they CAN be auto-applied) but set rollbackSteps=[] and explain in acknowledgmentMessage. + +## Discovery Pattern (CRITICAL) +Many fixes require referencing OTHER resources (Log Analytics workspaces, storage accounts, etc.) that aren't in the finding. Your readSteps MUST discover these first: + +1. To find a Log Analytics workspace: + GET https://management.azure.com/subscriptions/{sub}/providers/Microsoft.OperationalInsights/workspaces?api-version=2022-10-01 + +2. To find a storage account: + GET https://management.azure.com/subscriptions/{sub}/providers/Microsoft.Storage/storageAccounts?api-version=2023-05-01 + +3. To find resource groups: + GET https://management.azure.com/subscriptions/{sub}/resourcegroups?api-version=2021-04-01 + Or parse from the resourceId path: /subscriptions/{sub}/resourceGroups/{RG_IS_HERE}/providers/... + +ALWAYS include discovery GET steps in readSteps when the fix needs a workspace ID, storage account, or other external reference. Use the FIRST result from the discovery query. + +If discovery finds NO workspaces or storage accounts, CREATE them as part of the fix: +- Discover resource groups: GET https://management.azure.com/subscriptions/{sub}/resourcegroups?api-version=2021-04-01 +- If NO resource groups exist, create one in fixSteps: + PUT https://management.azure.com/subscriptions/{sub}/resourcegroups/compai-security?api-version=2021-04-01 + Body: { "location": "eastus" } + (rollback: DELETE the resource group) +- Create Log Analytics workspace in fixSteps: + PUT https://management.azure.com/subscriptions/{sub}/resourceGroups/{rg}/providers/Microsoft.OperationalInsights/workspaces/compai-security-logs?api-version=2022-10-01 + Body: { "location": "{same_location_as_rg}", "properties": { "retentionInDays": 30, "sku": { "name": "PerGB2018" } } } + (rollback: DELETE the workspace) +- Then create the diagnostic setting pointing to the new workspace +- Use the FIRST existing resource group if one exists, otherwise create "compai-security" +- IMPORTANT: If creating a Log Analytics workspace, append a short random suffix to avoid soft-delete name conflicts: "compai-security-logs-{4chars}" where {4chars} are random lowercase letters. Azure soft-deletes workspaces for 14 days, blocking the same name. +- Set requiresAcknowledgment=true when creating new resources so the user sees exactly what will be created +- The preview shows all resources that will be created — the user decides +- ALWAYS provide rollback steps that clean up created resources + +canAutoFix should be TRUE for almost everything. Only set false for: +- Organizational-level policy changes that affect all subscriptions +- Changes requiring Azure AD admin consent (app permissions, conditional access) +- Deleting/recreating resources that would cause data loss + +## Critical Rules +- ALWAYS use readSteps to: (a) capture current state AND (b) discover referenced resources +- NEVER use placeholder values — discover real IDs via readSteps +- NEVER hardcode workspace IDs, storage account names, or resource group names — always discover them +- URLs must include ?api-version= parameter +- currentState and proposedState should use matching keys +- Parse resourceId from the finding evidence to build API URLs +- The resourceId in evidence is a FULL ARM path like /subscriptions/xxx/resourceGroups/yyy/providers/Microsoft.Service/type/name — use it directly in URLs +- For PATCH operations, only include the properties you want to change (Azure merges) +- After discovery, use EXACT values from readStep responses in fixSteps — the refine step will replace placeholders with real values`; + +export function buildAzureFixPlanPrompt(finding: { + title: string; + description: string | null; + severity: string | null; + resourceType: string; + resourceId: string; + remediation: string | null; + findingKey: string; + evidence: Record; +}): string { + return `Analyze this Azure security finding and generate a fix plan: + +**Title:** ${finding.title} +**Severity:** ${finding.severity || 'Unknown'} +**Resource Type:** ${finding.resourceType} +**Resource ID:** ${finding.resourceId} +**Description:** ${finding.description || 'No description'} +**Remediation Guidance:** ${finding.remediation || 'None provided'} +**Finding Key:** ${finding.findingKey} + +**Evidence:** +\`\`\`json +${JSON.stringify(finding.evidence, null, 2)} +\`\`\` + +Generate a fix plan. Use the resource ID and evidence to build exact ARM API URLs. Read current state first, then apply the fix. Always provide rollback steps.`; +} diff --git a/apps/api/src/cloud-security/azure-command-executor.ts b/apps/api/src/cloud-security/azure-command-executor.ts new file mode 100644 index 0000000000..aae5a72420 --- /dev/null +++ b/apps/api/src/cloud-security/azure-command-executor.ts @@ -0,0 +1,456 @@ +import { Logger } from '@nestjs/common'; +import type { AzureApiStep } from './azure-ai-remediation.prompt'; + +const logger = new Logger('AzureCommandExecutor'); + +const MAX_STEP_RETRIES = 3; +const MAX_POLL_MS = 120_000; + +export interface AzureStepResult { + step: AzureApiStep; + success: boolean; + statusCode?: number; + response?: unknown; + error?: string; +} + +export interface AzureExecutionResult { + results: AzureStepResult[]; + error?: { + stepIndex: number; + step: AzureApiStep; + message: string; + }; +} + +/** + * Execute Azure ARM API steps sequentially with full self-healing: + * - Auto-registers missing resource providers + * - Retries on throttling (429) and server errors (5xx) + * - Waits for resources still provisioning + * - Auto-rolls back on partial failure + */ +export async function executeAzurePlanSteps(params: { + steps: AzureApiStep[]; + accessToken: string; + autoRollbackSteps?: AzureApiStep[]; + isRollback?: boolean; +}): Promise { + const { steps, accessToken, autoRollbackSteps, isRollback } = params; + + // Validate ALL step URLs before executing any — prevents SSRF on read/fix/rollback steps + const allSteps = [...steps, ...(autoRollbackSteps ?? [])]; + const validationErrors = validateAzurePlanSteps(allSteps); + if (validationErrors.length > 0) { + return { + results: [], + error: { + stepIndex: 0, + step: steps[0] ?? allSteps[0], + message: `URL validation failed: ${validationErrors.join('; ')}`, + }, + }; + } + + const results: AzureStepResult[] = []; + + for (let i = 0; i < steps.length; i++) { + const step = steps[i]; + const result = await executeWithRetry(step, accessToken, isRollback); + results.push(result); + + if (!result.success) { + // Auto-rollback completed steps + if (autoRollbackSteps && i > 0) { + logger.warn( + `Step ${i} failed, auto-rolling back ${Math.min(i, autoRollbackSteps.length)} steps`, + ); + for (let r = Math.min(i, autoRollbackSteps.length) - 1; r >= 0; r--) { + const rollbackStep = autoRollbackSteps[r]; + if (!rollbackStep) continue; + try { + await executeWithRetry(rollbackStep, accessToken, true); + logger.log(`Rollback step ${r} succeeded`); + } catch (err) { + logger.warn( + `Rollback step ${r} failed: ${err instanceof Error ? err.message : String(err)}`, + ); + } + } + } + + return { + results, + error: { + stepIndex: i, + step, + message: + result.error || `Step ${i} failed with status ${result.statusCode}`, + }, + }; + } + } + + return { results }; +} + +/** + * Execute a single API call with up to MAX_STEP_RETRIES attempts. + * Each retry auto-heals the specific error before retrying. + */ +async function executeWithRetry( + step: AzureApiStep, + accessToken: string, + isRollback?: boolean, +): Promise { + // Safety: block DELETE unless rolling back + if (step.method === 'DELETE' && !isRollback) { + return { + step, + success: false, + error: 'DELETE only allowed during rollback.', + }; + } + + for (let attempt = 0; attempt < MAX_STEP_RETRIES; attempt++) { + const result = await executeOnce(step, accessToken); + + if (result.success) return result; + + const err = result.error ?? ''; + const code = result.statusCode ?? 0; + const canRetry = attempt < MAX_STEP_RETRIES - 1; + + // --- Self-healing by error type --- + + // 429 Throttled → wait and retry + if (code === 429 && canRetry) { + const delay = 3000 * (attempt + 1); + logger.warn( + `Throttled (429), waiting ${delay}ms before retry ${attempt + 1}`, + ); + await new Promise((r) => setTimeout(r, delay)); + continue; + } + + // 5xx Server error → wait and retry + if (code >= 500 && canRetry) { + logger.warn(`Server error (${code}), retrying in 2s...`); + await new Promise((r) => setTimeout(r, 2000)); + continue; + } + + // Missing resource provider → register and retry + // Azure returns this as 409 MissingSubscriptionRegistration OR InvalidAuthenticationToken with "register" + const isProviderMissing = + (code === 409 || code === 401) && + (err.includes('MissingSubscriptionRegistration') || + err.includes('Please register the subscription') || + err.includes('not registered to use namespace')); + if (isProviderMissing && canRetry) { + const providerMatch = + err.match(/namespace '(Microsoft\.\w+)'/) || + err.match(/with (Microsoft\.\w+)/); + if (providerMatch) { + await registerProvider(accessToken, step.url, providerMatch[1]); + continue; + } + } + + // 409 Resource provisioning → wait and retry + if ( + code === 409 && + (err.includes('provisioning state') || + err.includes('Creating') || + err.includes('Updating')) && + canRetry + ) { + logger.warn('Resource still provisioning, waiting 10s...'); + await new Promise((r) => setTimeout(r, 10_000)); + continue; + } + + // 409 Conflict "already exists" on write → treat as success + if ( + code === 409 && + step.method !== 'GET' && + (err.includes('already exists') || err.includes('AlreadyExists')) + ) { + return { + step, + success: true, + statusCode: 409, + response: { note: 'Already exists' }, + }; + } + + // 409 Soft-deleted → can't auto-heal, return clear error + if ( + code === 409 && + (err.includes('soft-delete') || err.includes('SoftDeleted')) + ) { + return { + step, + success: false, + statusCode: 409, + error: `Name blocked by soft-deleted resource. ${err.slice(0, 200)}`, + }; + } + + // 404 on GET → valid (resource doesn't exist) + if (code === 404 && step.method === 'GET') { + return { step, success: true, statusCode: 404, response: null }; + } + + // 401 → token expired (but NOT if it's a provider registration issue — those are handled above) + if (code === 401 && !err.includes('register the subscription')) { + return { + step, + success: false, + statusCode: 401, + error: 'Access token expired. Reconnect the integration.', + }; + } + + // 403 → permission denied, return for higher-level self-healing + if (code === 403) { + return result; + } + + // 400 Bad Request → can't self-heal at this level + if (code === 400) { + return result; + } + + // Unknown error on last attempt → return as-is + return result; + } + + return { step, success: false, error: 'Max retries exceeded' }; +} + +/** Execute a single HTTP call and return the result. */ +async function executeOnce( + step: AzureApiStep, + accessToken: string, +): Promise { + const url = new URL(step.url); + if (step.queryParams) { + for (const [key, value] of Object.entries(step.queryParams)) { + url.searchParams.set(key, value); + } + } + + logger.log(`${step.method} ${url.pathname} — ${step.purpose}`); + + let response: Response; + try { + response = await fetch(url.toString(), { + method: step.method, + headers: { + Authorization: `Bearer ${accessToken}`, + 'Content-Type': 'application/json', + }, + body: step.body ? JSON.stringify(step.body) : undefined, + }); + } catch (err) { + return { + step, + success: false, + error: `Network error: ${err instanceof Error ? err.message : String(err)}`, + }; + } + + // 202 Accepted → poll async operation + if (response.status === 202) { + const pollUrl = + response.headers.get('Azure-AsyncOperation') || + response.headers.get('Location'); + if (pollUrl) { + // Validate poll URL to prevent SSRF via response headers + try { + const parsedPoll = new URL(pollUrl); + if (!AZURE_ALLOWED_HOSTS.has(parsedPoll.hostname)) { + return { + step, + success: false, + error: `Async poll URL targets disallowed host: ${parsedPoll.hostname}`, + }; + } + } catch { + return { step, success: false, error: 'Async poll URL is malformed' }; + } + try { + const finalResult = await pollAsyncOperation(pollUrl, accessToken); + if (finalResult === null) { + return { + step, + success: false, + error: 'Async operation timed out or failed to poll', + }; + } + return { step, success: true, statusCode: 200, response: finalResult }; + } catch (pollErr) { + return { + step, + success: false, + error: `Async operation failed: ${pollErr instanceof Error ? pollErr.message : String(pollErr)}`, + }; + } + } + return { step, success: true, statusCode: 202, response: null }; + } + + // 204 No Content + if (response.status === 204) { + return { step, success: true, statusCode: 204, response: null }; + } + + // Success + if (response.ok) { + const text = await response.text(); + let data: unknown = null; + if (text) { + try { + data = JSON.parse(text); + } catch { + data = { rawBody: text }; + } + } + return { step, success: true, statusCode: response.status, response: data }; + } + + // Error — read body for diagnostics + const errorText = await response.text(); + if (response.status === 409) { + logger.warn( + `409 for ${step.method} ${url.pathname}: ${errorText.slice(0, 300)}`, + ); + } + + return { + step, + success: false, + statusCode: response.status, + error: errorText, + }; +} + +// ─── Resource Provider Registration ──────────────────────────────────────── + +async function registerProvider( + accessToken: string, + stepUrl: string, + providerNamespace: string, +): Promise { + const subMatch = stepUrl.match(/\/subscriptions\/([^/]+)/); + if (!subMatch) return; + + logger.log(`Auto-registering provider: ${providerNamespace}`); + try { + const resp = await fetch( + `https://management.azure.com/subscriptions/${subMatch[1]}/providers/${providerNamespace}/register?api-version=2021-04-01`, + { method: 'POST', headers: { Authorization: `Bearer ${accessToken}` } }, + ); + if (resp.ok) { + logger.log( + `Provider ${providerNamespace} registered — waiting 15s for propagation`, + ); + await new Promise((r) => setTimeout(r, 15_000)); + } else { + logger.warn(`Failed to register ${providerNamespace}: ${resp.status}`); + } + } catch (err) { + logger.warn( + `Provider registration error: ${err instanceof Error ? err.message : String(err)}`, + ); + } +} + +// ─── Async Operation Polling ─────────────────────────────────────────────── + +async function pollAsyncOperation( + pollUrl: string, + accessToken: string, +): Promise { + const startTime = Date.now(); + let interval = 2000; + + while (Date.now() - startTime < MAX_POLL_MS) { + await new Promise((r) => setTimeout(r, interval)); + interval = Math.min(interval * 1.5, 10_000); + + const resp = await fetch(pollUrl, { + headers: { Authorization: `Bearer ${accessToken}` }, + }); + + if (!resp.ok) { + logger.warn(`Async poll failed: ${resp.status}`); + return null; + } + + const data = (await resp.json()) as Record; + const status = (data.status as string)?.toLowerCase(); + + if (status === 'succeeded' || status === 'completed') { + return data; + } + if ( + status === 'failed' || + status === 'canceled' || + status === 'cancelled' + ) { + const error = data.error as Record | undefined; + throw new Error( + `Async operation ${status}: ${(error?.message as string) ?? 'unknown'}`, + ); + } + } + + logger.warn(`Async operation timed out after ${MAX_POLL_MS}ms`); + return null; +} + +/** + * Validate Azure API steps for safety. + */ +const AZURE_ALLOWED_HOSTS = new Set([ + 'management.azure.com', + 'graph.microsoft.com', +]); + +export function validateAzurePlanSteps(steps: AzureApiStep[]): string[] { + const errors: string[] = []; + for (let i = 0; i < steps.length; i++) { + const step = steps[i]; + if (!step.url) { + errors.push(`Step ${i}: URL is required`); + continue; + } + if (!step.method) errors.push(`Step ${i}: method is required`); + try { + const parsed = new URL(step.url); + if (parsed.protocol !== 'https:') { + errors.push(`Step ${i}: URL must use HTTPS`); + } + if (!AZURE_ALLOWED_HOSTS.has(parsed.hostname)) { + errors.push(`Step ${i}: URL must target Azure Management or Graph API`); + } + } catch { + errors.push(`Step ${i}: URL must be a valid absolute URL`); + } + if ( + step.method === 'DELETE' && + step.url?.match(/\/subscriptions\/[^/]+$/) + ) { + errors.push(`Step ${i}: Cannot delete a subscription`); + } + if ( + step.method !== 'GET' && + step.url?.includes('/providers/Microsoft.Authorization/roleDefinitions/') + ) { + errors.push(`Step ${i}: Cannot modify built-in role definitions`); + } + } + return errors; +} diff --git a/apps/api/src/cloud-security/azure-remediation.service.ts b/apps/api/src/cloud-security/azure-remediation.service.ts new file mode 100644 index 0000000000..fe2bd0ad5b --- /dev/null +++ b/apps/api/src/cloud-security/azure-remediation.service.ts @@ -0,0 +1,745 @@ +import { Injectable, Logger } from '@nestjs/common'; +import { db, Prisma } from '@db'; +import { CredentialVaultService } from '../integration-platform/services/credential-vault.service'; +import { AiRemediationService } from './ai-remediation.service'; +import { AzureSecurityService } from './providers/azure-security.service'; +import { parseAzurePermissionError } from './remediation-error.utils'; +import { + executeAzurePlanSteps, + validateAzurePlanSteps, +} from './azure-command-executor'; +import type { AzureFixPlan } from './azure-ai-remediation.prompt'; + +const PLAN_CACHE_TTL = 5 * 60 * 1000; // 5 minutes + +@Injectable() +export class AzureRemediationService { + private readonly logger = new Logger(AzureRemediationService.name); + private planCache = new Map< + string, + { plan: AzureFixPlan; timestamp: number } + >(); + private readonly PLAN_CACHE_MAX = 100; + + private evictStalePlans() { + if (this.planCache.size <= this.PLAN_CACHE_MAX) return; + const now = Date.now(); + for (const [key, entry] of this.planCache) { + if (now - entry.timestamp > PLAN_CACHE_TTL) this.planCache.delete(key); + } + while (this.planCache.size > this.PLAN_CACHE_MAX) { + const firstKey = this.planCache.keys().next().value; + if (firstKey) this.planCache.delete(firstKey); + else break; + } + } + + constructor( + private readonly credentialVaultService: CredentialVaultService, + private readonly aiRemediationService: AiRemediationService, + private readonly azureSecurityService: AzureSecurityService, + ) {} + + async getCapabilities(params: { + connectionId: string; + organizationId: string; + }) { + const credentials = await this.resolveCredentials( + params.connectionId, + params.organizationId, + ); + return { + enabled: Boolean(credentials?.access_token || credentials?.clientId), + aiPowered: true, + remediations: [], + }; + } + + async previewRemediation(params: { + connectionId: string; + organizationId: string; + checkResultId: string; + remediationKey: string; + }) { + const { finding, accessToken } = await this.resolveContext( + params.connectionId, + params.organizationId, + params.checkResultId, + ); + + // Generate AI plan + let plan = await this.aiRemediationService.generateAzureFixPlan(finding); + + if (!plan.canAutoFix) { + return this.buildGuidedResponse(plan); + } + + // Execute read steps to get real Azure state + if (plan.readSteps.length > 0 && accessToken) { + const readResult = await executeAzurePlanSteps({ + steps: plan.readSteps, + accessToken, + }); + + const realState: Record = {}; + for (const r of readResult.results) { + if (r.success && r.response) { + realState[r.step.purpose] = r.response; + } + } + + // Refine plan with real Azure state + if (Object.keys(realState).length > 0) { + plan = await this.aiRemediationService.refineAzureFixPlan({ + finding, + originalPlan: plan, + realAzureState: realState, + }); + } + } + + // Validate fix steps + const validationErrors = validateAzurePlanSteps(plan.fixSteps); + if (validationErrors.length > 0) { + this.logger.warn( + `Fix plan validation errors: ${validationErrors.join(', ')}`, + ); + return this.buildGuidedResponse(plan); + } + + // Cache plan for execute + const cacheKey = `${params.connectionId}:${params.checkResultId}:${params.remediationKey}`; + this.evictStalePlans(); + this.planCache.set(cacheKey, { plan, timestamp: Date.now() }); + + return this.buildPreviewResponse(plan); + } + + async executeRemediation(params: { + connectionId: string; + organizationId: string; + checkResultId: string; + remediationKey: string; + userId: string; + acknowledgment?: string; + }) { + const { finding, accessToken } = await this.resolveContext( + params.connectionId, + params.organizationId, + params.checkResultId, + ); + + if (!accessToken) { + throw new Error('Azure access token unavailable. Check credentials.'); + } + + // Retrieve or regenerate plan + const cacheKey = `${params.connectionId}:${params.checkResultId}:${params.remediationKey}`; + const cached = this.planCache.get(cacheKey); + let plan: AzureFixPlan; + + if (cached && Date.now() - cached.timestamp < PLAN_CACHE_TTL) { + plan = cached.plan; + } else { + plan = await this.aiRemediationService.generateAzureFixPlan(finding); + if (!plan.canAutoFix) { + throw new Error( + 'This finding cannot be auto-fixed. Use guided steps instead.', + ); + } + } + + // Create action record + const action = await db.remediationAction.create({ + data: { + connectionId: params.connectionId, + organizationId: params.organizationId, + checkResultId: params.checkResultId, + remediationKey: params.remediationKey, + resourceId: finding.resourceId || params.checkResultId, + resourceType: finding.resourceType || 'azure-resource', + previousState: {}, + appliedState: {}, + status: 'executing', + riskLevel: plan.risk, + acknowledgmentText: params.acknowledgment, + acknowledgedAt: params.acknowledgment ? new Date() : null, + initiatedById: params.userId, + }, + }); + + try { + // Phase 1: Execute read steps to capture previous state + const previousState: Record = {}; + if (plan.readSteps.length > 0) { + const readResult = await executeAzurePlanSteps({ + steps: plan.readSteps, + accessToken, + }); + for (const r of readResult.results) { + if (r.success && r.response) { + previousState[r.step.purpose] = r.response; + } + } + } + + // Phase 2: Refine plan with real state + if (Object.keys(previousState).length > 0) { + plan = await this.aiRemediationService.refineAzureFixPlan({ + finding, + originalPlan: plan, + realAzureState: previousState, + }); + } + + this.logger.log( + `AI plan for ${finding.findingKey}: canAutoFix=${plan.canAutoFix}, ` + + `fixSteps=${plan.fixSteps.length}, readSteps=${plan.readSteps.length}, ` + + `rollbackSteps=${plan.rollbackSteps.length}`, + ); + + // If AI decided it can't auto-fix after seeing real state, fail clearly + if (!plan.canAutoFix || plan.fixSteps.length === 0) { + await db.remediationAction.update({ + where: { id: action.id }, + data: { + status: 'failed', + previousState: previousState as unknown as Prisma.InputJsonValue, + appliedState: { + error: + plan.reason || + 'Auto-fix not possible for this finding after analyzing real resource state.', + guidedSteps: plan.guidedSteps, + } as unknown as Prisma.InputJsonValue, + executedAt: new Date(), + }, + }); + + return { + actionId: action.id, + status: 'failed' as const, + resourceId: finding.resourceId, + error: + plan.reason || + 'Auto-fix not possible. The required resources (e.g., Log Analytics workspace) may not exist in your subscription.', + previousState, + guidedSteps: plan.guidedSteps, + }; + } + + // Phase 2.5: Pre-flight — check write permissions and self-heal before executing + const subscriptionId = this.extractSubscriptionId( + plan.fixSteps[0]?.url || finding.resourceId, + ); + if (subscriptionId) { + await this.ensureWriteAccess(accessToken, subscriptionId); + } + + // Phase 3: Execute fix steps with self-healing retry + // Executor auto-handles: provider registration, throttling, retries, provisioning waits + for (const step of plan.fixSteps) { + this.logger.log( + `Fix step: ${step.method} ${step.url} — ${step.purpose}`, + ); + } + + // Validate URLs before execution to prevent SSRF (especially after cache-miss regeneration) + const validationErrors = validateAzurePlanSteps(plan.fixSteps); + if (validationErrors.length > 0) { + throw new Error( + `Fix plan validation failed: ${validationErrors.join('; ')}`, + ); + } + + let fixResult = await executeAzurePlanSteps({ + steps: plan.fixSteps, + accessToken, + autoRollbackSteps: plan.rollbackSteps, + }); + + // If permission error, report it clearly — don't attempt self-healing role grants + if (fixResult.error) { + const permError = parseAzurePermissionError(fixResult.error.message); + if (permError?.isPermissionError) { + this.logger.warn( + `Permission error: ${fixResult.error.message}. Assign the required Azure role to the app registration.`, + ); + } + } + + // Self-healing round 2: non-permission error → regenerate plan with error context → retry + if ( + fixResult.error && + !parseAzurePermissionError(fixResult.error.message)?.isPermissionError + ) { + this.logger.log( + 'Non-permission error — regenerating fix plan with error context...', + ); + const retryPlan = await this.aiRemediationService.refineAzureFixPlan({ + finding, + originalPlan: plan, + realAzureState: { + ...previousState, + _lastError: fixResult.error.message, + _failedStep: fixResult.error.step, + }, + }); + + if (retryPlan.canAutoFix && retryPlan.fixSteps.length > 0) { + this.logger.log( + `Retrying with regenerated plan (${retryPlan.fixSteps.length} steps)...`, + ); + plan = retryPlan; + fixResult = await executeAzurePlanSteps({ + steps: plan.fixSteps, + accessToken, + autoRollbackSteps: plan.rollbackSteps, + }); + } + } + + // Log every step result for audit trail + for (const r of fixResult.results) { + this.logger.log( + `Step result: ${r.step.method} ${r.step.url} → ${r.success ? `${r.statusCode} OK` : `FAILED: ${r.error}`}`, + ); + } + + // If still failing after self-healing attempts + if (fixResult.error) { + const permError = parseAzurePermissionError(fixResult.error.message); + + // Store ALL completed steps (even partial) so we can see what was modified + await db.remediationAction.update({ + where: { id: action.id }, + data: { + status: 'failed', + previousState: previousState as unknown as Prisma.InputJsonValue, + appliedState: { + error: fixResult.error.message, + stepIndex: fixResult.error.stepIndex, + completedSteps: fixResult.results + .filter((r) => r.success) + .map((r) => ({ + method: r.step.method, + url: r.step.url, + purpose: r.step.purpose, + statusCode: r.statusCode, + })), + failedStep: { + method: fixResult.error.step.method, + url: fixResult.error.step.url, + purpose: fixResult.error.step.purpose, + error: fixResult.error.message, + }, + rollbackSteps: plan.rollbackSteps, + ...(permError && { + missingActions: permError.missingActions, + fixScript: permError.fixScript, + }), + } as unknown as Prisma.InputJsonValue, + executedAt: new Date(), + }, + }); + + return { + actionId: action.id, + status: 'failed' as const, + resourceId: finding.resourceId, + error: fixResult.error.message, + previousState, + ...(permError && { + missingPermissions: permError.missingActions, + permissionFixScript: permError.fixScript, + }), + }; + } + + // Phase 4: Verify — re-read the resource to confirm fix took effect + let verified = false; + if (plan.readSteps.length > 0) { + // Wait briefly for Azure to propagate the change + await new Promise((r) => setTimeout(r, 2000)); + + const verifyResult = await executeAzurePlanSteps({ + steps: plan.readSteps, + accessToken, + }); + + const postFixState: Record = {}; + for (const r of verifyResult.results) { + if (r.success && r.response) { + postFixState[r.step.purpose] = r.response; + } + } + + // Compare: if post-fix state differs from pre-fix state, the fix changed something + verified = + JSON.stringify(postFixState) !== JSON.stringify(previousState); + } + + const status = verified ? 'success' : 'unverified'; + await db.remediationAction.update({ + where: { id: action.id }, + data: { + status, + previousState: previousState as unknown as Prisma.InputJsonValue, + appliedState: { + steps: fixResult.results.map((r) => ({ + purpose: r.step.purpose, + statusCode: r.statusCode, + response: r.response, + })), + rollbackSteps: plan.rollbackSteps, + verified, + } as unknown as Prisma.InputJsonValue, + executedAt: new Date(), + }, + }); + + this.planCache.delete(cacheKey); + + if (!verified) { + this.logger.warn( + `Fix for ${finding.findingKey} executed but verification shows no state change. ` + + `The fix may need time to propagate or may not have addressed the finding correctly.`, + ); + } + + return { + actionId: action.id, + status: status, + resourceId: finding.resourceId, + previousState, + appliedState: { description: plan.description, verified }, + }; + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + await db.remediationAction.update({ + where: { id: action.id }, + data: { + status: 'failed', + appliedState: { error: msg } as unknown as Prisma.InputJsonValue, + executedAt: new Date(), + }, + }); + throw error; + } + } + + async rollbackRemediation(params: { + actionId: string; + organizationId: string; + }) { + const action = await db.remediationAction.findFirst({ + where: { + id: params.actionId, + organizationId: params.organizationId, + status: { in: ['success', 'unverified'] }, + }, + include: { + connection: { include: { provider: true } }, + }, + }); + + if (!action) { + throw new Error('Remediation action not found or cannot be rolled back.'); + } + + const appliedState = action.appliedState as Record | null; + const rollbackSteps = (appliedState?.rollbackSteps ?? []) as Array< + Record + >; + + if (rollbackSteps.length === 0) { + throw new Error('No rollback steps available for this action.'); + } + + // Get fresh access token + const credentials = await this.resolveCredentials( + action.connectionId, + action.organizationId, + ); + if (!credentials) { + throw new Error('Cannot retrieve Azure credentials for rollback.'); + } + + // OAuth flow: token from vault; legacy: SP client credentials + let accessToken = credentials.access_token as string | undefined; + if ( + !accessToken && + credentials.tenantId && + credentials.clientId && + credentials.clientSecret + ) { + accessToken = await this.azureSecurityService.getAccessToken( + credentials.tenantId as string, + credentials.clientId as string, + credentials.clientSecret as string, + ); + } + if (!accessToken) { + throw new Error('Cannot obtain Azure access token for rollback.'); + } + + this.logger.log( + `Rolling back action ${action.id}: ${rollbackSteps.length} steps`, + ); + for (const step of rollbackSteps) { + this.logger.log( + `Rollback step: ${(step as { method?: string }).method} ${(step as { url?: string }).url} — ${(step as { purpose?: string }).purpose}`, + ); + } + + // Pre-flight: ensure write access before rollback + const subscriptionId = this.extractSubscriptionId( + (rollbackSteps[0] as { url?: string })?.url || action.checkResultId, + ); + if (subscriptionId) { + await this.ensureWriteAccess(accessToken, subscriptionId); + } + + const result = await executeAzurePlanSteps({ + steps: rollbackSteps as Parameters< + typeof executeAzurePlanSteps + >[0]['steps'], + accessToken, + isRollback: true, + }); + + // If permission error during rollback, log clearly + if (result.error && subscriptionId) { + const permError = parseAzurePermissionError(result.error.message); + if (permError?.isPermissionError) { + this.logger.warn( + `Rollback permission error: ${result.error.message}. Assign the required Azure role to the app registration.`, + ); + } + } + + // Log each rollback step result + for (const r of result.results) { + this.logger.log( + `Rollback result: ${r.step.method} ${r.step.url} → ${r.success ? `${r.statusCode} OK` : `FAILED: ${r.error}`}`, + ); + } + + if (result.error) { + const permError = parseAzurePermissionError(result.error.message); + const completedCount = result.results.filter((r) => r.success).length; + + this.logger.error( + `Rollback failed at step ${result.error.stepIndex}: ${result.error.message}. ` + + `${completedCount}/${rollbackSteps.length} steps completed before failure.`, + ); + + await db.remediationAction.update({ + where: { id: action.id }, + data: { + status: 'rollback_failed', + rolledBackAt: new Date(), + appliedState: { + ...((action.appliedState as Record) ?? {}), + rollbackError: result.error.message, + rollbackCompletedSteps: result.results + .filter((r) => r.success) + .map((r) => ({ + method: r.step.method, + url: r.step.url, + purpose: r.step.purpose, + })), + rollbackFailedStep: { + method: result.error.step.method, + url: result.error.step.url, + purpose: result.error.step.purpose, + error: result.error.message, + }, + } as unknown as Prisma.InputJsonValue, + }, + }); + + return { + status: 'rollback_failed' as const, + connectionId: action.connectionId, + remediationKey: action.remediationKey, + resourceId: action.checkResultId, + error: result.error.message, + ...(permError && { + missingPermissions: permError.missingActions, + permissionFixScript: permError.fixScript, + }), + }; + } + + await db.remediationAction.update({ + where: { id: action.id }, + data: { + status: 'rolled_back', + rolledBackAt: new Date(), + }, + }); + + return { + status: 'rolled_back' as const, + connectionId: action.connectionId, + remediationKey: action.remediationKey, + resourceId: action.checkResultId, + }; + } + + // --- Self-healing helpers --- + + /** + * Pre-flight: check if the token has write access on the subscription. + * If not, attempt to self-grant Contributor role. + */ + private async ensureWriteAccess( + accessToken: string, + subscriptionId: string, + ): Promise { + try { + const resp = await fetch( + `https://management.azure.com/subscriptions/${subscriptionId}/providers/Microsoft.Authorization/permissions?api-version=2022-04-01`, + { headers: { Authorization: `Bearer ${accessToken}` } }, + ); + + if (!resp.ok) { + this.logger.warn('Could not check permissions — proceeding anyway'); + return; + } + + const data = (await resp.json()) as { + value: Array<{ actions: string[]; notActions: string[] }>; + }; + const allActions = data.value?.flatMap((p) => p.actions) ?? []; + const hasWrite = allActions.some( + (a) => a === '*' || a === '*/write' || a.endsWith('/write'), + ); + + if (hasWrite) { + this.logger.log('Pre-flight: write access confirmed'); + return; + } + + this.logger.warn( + 'Pre-flight: no write access detected — fix may fail. Assign Contributor role to the app registration.', + ); + } catch (err) { + this.logger.warn( + `Pre-flight permission check failed: ${err instanceof Error ? err.message : String(err)}`, + ); + } + } + + private extractSubscriptionId(resourceId: string): string | null { + const match = resourceId.match(/\/subscriptions\/([^/]+)/); + return match?.[1] ?? null; + } + + // --- Private helpers --- + + private async resolveCredentials( + connectionId: string, + organizationId: string, + ): Promise | null> { + const connection = await db.integrationConnection.findFirst({ + where: { id: connectionId, organizationId, status: 'active' }, + include: { provider: true }, + }); + if (!connection || connection.provider.slug !== 'azure') return null; + return this.credentialVaultService.getDecryptedCredentials(connectionId); + } + + private async resolveContext( + connectionId: string, + organizationId: string, + checkResultId: string, + ) { + const credentials = await this.resolveCredentials( + connectionId, + organizationId, + ); + + let accessToken: string | null = null; + // OAuth flow: token from vault + if (credentials?.access_token) { + accessToken = credentials.access_token as string; + } + // Legacy SP flow fallback + if ( + !accessToken && + credentials?.tenantId && + credentials?.clientId && + credentials?.clientSecret + ) { + accessToken = await this.azureSecurityService.getAccessToken( + credentials.tenantId as string, + credentials.clientId as string, + credentials.clientSecret as string, + ); + } + + const checkResult = await db.integrationCheckResult.findFirst({ + where: { + id: checkResultId, + checkRun: { connectionId }, + }, + }); + + if (!checkResult) { + throw new Error(`Check result ${checkResultId} not found`); + } + + const evidence = (checkResult.evidence ?? {}) as Record; + + return { + finding: { + title: checkResult.title ?? '', + description: checkResult.description, + severity: checkResult.severity, + resourceType: checkResult.resourceType ?? 'azure-resource', + resourceId: checkResult.resourceId ?? '', + remediation: checkResult.remediation, + findingKey: (evidence.findingKey as string) ?? '', + evidence, + }, + accessToken, + }; + } + + private buildPreviewResponse(plan: AzureFixPlan) { + return { + currentState: plan.currentState, + proposedState: plan.proposedState, + description: plan.description, + risk: plan.risk, + apiCalls: plan.fixSteps.map((s) => ({ + method: s.method, + endpoint: s.url, + purpose: s.purpose, + })), + guidedOnly: false, + rollbackSupported: plan.rollbackSupported, + requiresAcknowledgment: plan.requiresAcknowledgment + ? ('checkbox' as const) + : undefined, + acknowledgmentMessage: plan.acknowledgmentMessage, + }; + } + + private buildGuidedResponse(plan: AzureFixPlan) { + return { + currentState: plan.currentState, + proposedState: plan.proposedState, + description: plan.description, + risk: plan.risk, + apiCalls: [], + guidedOnly: true, + guidedSteps: plan.guidedSteps ?? [ + plan.reason || 'This finding requires manual remediation.', + ], + rollbackSupported: false, + requiresAcknowledgment: undefined, + }; + } +} diff --git a/apps/api/src/cloud-security/cloud-security-activity.service.ts b/apps/api/src/cloud-security/cloud-security-activity.service.ts new file mode 100644 index 0000000000..84574ad6e5 --- /dev/null +++ b/apps/api/src/cloud-security/cloud-security-activity.service.ts @@ -0,0 +1,182 @@ +import { Injectable } from '@nestjs/common'; +import { db, Prisma } from '@db'; + +export interface ActivityEntry { + id: string; + type: 'scan' | 'remediation' | 'rollback' | 'service_change'; + description: string; + userId: string | null; + userName: string | null; + status: 'success' | 'failed' | 'info'; + timestamp: string; + metadata?: Record; +} + +const ACTION_TYPE_MAP: Record = { + scan_started: 'scan', + scan_completed: 'scan', + remediation_executed: 'remediation', + remediation_failed: 'remediation', + rollback_executed: 'rollback', + rollback_failed: 'rollback', + service_toggled: 'service_change', +}; + +const REMEDIATION_STATUS_MAP: Record = { + success: 'success', + executing: 'info', + failed: 'failed', + rolled_back: 'info', + rollback_failed: 'failed', + pending: 'info', +}; + +@Injectable() +export class CloudSecurityActivityService { + async getActivity(params: { + connectionId: string; + organizationId: string; + take: number; + }): Promise { + // Fetch both sources in parallel + const [auditLogs, remediationActions] = await Promise.all([ + this.getAuditLogEntries(params), + this.getRemediationEntries(params), + ]); + + // Merge and sort by timestamp descending + const merged = [...auditLogs, ...remediationActions].sort( + (a, b) => + new Date(b.timestamp).getTime() - new Date(a.timestamp).getTime(), + ); + + return merged.slice(0, params.take); + } + + private async getAuditLogEntries(params: { + connectionId: string; + organizationId: string; + }): Promise { + const logs = await db.auditLog.findMany({ + where: { + organizationId: params.organizationId, + entityType: 'integration', + AND: [ + { + data: { + path: ['resource'], + equals: 'cloud-security', + } satisfies Prisma.JsonFilter, + }, + { + data: { + path: ['connectionId'], + equals: params.connectionId, + } satisfies Prisma.JsonFilter, + }, + ], + }, + include: { + user: { select: { id: true, name: true } }, + }, + orderBy: { timestamp: 'desc' }, + take: 100, + }); + + return logs.map((log) => { + const data = log.data as Record; + const action = data.action as string; + + let status: ActivityEntry['status'] = 'info'; + if (action === 'scan_completed') status = 'success'; + if (action === 'remediation_executed') status = 'success'; + if (action === 'rollback_executed') status = 'success'; + if (action === 'remediation_failed') status = 'failed'; + if (action === 'rollback_failed') status = 'failed'; + + return { + id: log.id, + type: ACTION_TYPE_MAP[action] ?? 'scan', + description: log.description ?? '', + userId: log.user?.id ?? null, + userName: log.user?.name ?? null, + status, + timestamp: log.timestamp.toISOString(), + metadata: data, + }; + }); + } + + private async getRemediationEntries(params: { + connectionId: string; + organizationId: string; + }): Promise { + const actions = await db.remediationAction.findMany({ + where: { + connectionId: params.connectionId, + organizationId: params.organizationId, + }, + orderBy: { createdAt: 'desc' }, + take: 100, + }); + + // Collect unique user IDs to fetch names + const userIds = [...new Set(actions.map((a) => a.initiatedById))]; + const filteredUserIds = userIds.filter((id) => id !== 'system'); + const users = + filteredUserIds.length > 0 + ? await db.user.findMany({ + where: { id: { in: filteredUserIds } }, + select: { id: true, name: true }, + }) + : []; + const userMap = new Map(users.map((u) => [u.id, u.name])); + userMap.set('system', 'System'); + + return actions.map((action) => { + const isRollback = + action.status === 'rolled_back' || action.status === 'rollback_failed'; + const type: ActivityEntry['type'] = isRollback + ? 'rollback' + : 'remediation'; + + let description: string; + switch (action.status) { + case 'success': + description = `Applied auto-fix: ${action.remediationKey} on ${action.resourceId}`; + break; + case 'failed': + description = `Auto-fix failed: ${action.remediationKey} on ${action.resourceId}`; + break; + case 'rolled_back': + description = `Rolled back: ${action.remediationKey} on ${action.resourceId}`; + break; + case 'rollback_failed': + description = `Rollback failed: ${action.remediationKey} on ${action.resourceId}`; + break; + case 'executing': + description = `Executing: ${action.remediationKey} on ${action.resourceId}`; + break; + default: + description = `${action.remediationKey} on ${action.resourceId} (${action.status})`; + } + + return { + id: action.id, + type, + description, + userId: action.initiatedById, + userName: userMap.get(action.initiatedById) ?? null, + status: REMEDIATION_STATUS_MAP[action.status] ?? 'info', + timestamp: (action.executedAt ?? action.createdAt).toISOString(), + metadata: { + remediationKey: action.remediationKey, + resourceId: action.resourceId, + resourceType: action.resourceType, + riskLevel: action.riskLevel, + errorMessage: action.errorMessage, + }, + }; + }); + } +} diff --git a/apps/api/src/cloud-security/cloud-security-audit.ts b/apps/api/src/cloud-security/cloud-security-audit.ts new file mode 100644 index 0000000000..75285b6827 --- /dev/null +++ b/apps/api/src/cloud-security/cloud-security-audit.ts @@ -0,0 +1,46 @@ +import { db } from '@db'; + +interface CloudSecurityAuditParams { + organizationId: string; + userId: string; + connectionId: string; + action: + | 'scan_started' + | 'scan_completed' + | 'remediation_executed' + | 'remediation_failed' + | 'rollback_executed' + | 'rollback_failed' + | 'service_toggled'; + description: string; + metadata?: Record; +} + +export async function logCloudSecurityActivity( + params: CloudSecurityAuditParams, +) { + try { + // auditLog.userId is a FK to User — skip if no real user context + if (!params.userId || params.userId === 'system') { + return; + } + + await db.auditLog.create({ + data: { + organizationId: params.organizationId, + userId: params.userId, + entityType: 'integration', + entityId: params.connectionId, + description: params.description, + data: { + action: params.action, + resource: 'cloud-security', + connectionId: params.connectionId, + ...params.metadata, + }, + }, + }); + } catch { + // Don't fail the main operation if audit logging fails + } +} diff --git a/apps/api/src/cloud-security/cloud-security-legacy.service.ts b/apps/api/src/cloud-security/cloud-security-legacy.service.ts index 47fe1d94f9..e7a230c2a1 100644 --- a/apps/api/src/cloud-security/cloud-security-legacy.service.ts +++ b/apps/api/src/cloud-security/cloud-security-legacy.service.ts @@ -6,11 +6,7 @@ import { } from '@nestjs/common'; import { db } from '@db'; import { Prisma } from '@db'; -import { - createCipheriv, - randomBytes, - scryptSync, -} from 'crypto'; +import { createCipheriv, randomBytes, scryptSync } from 'crypto'; import { DescribeRegionsCommand, EC2Client } from '@aws-sdk/client-ec2'; import { GetCallerIdentityCommand, STSClient } from '@aws-sdk/client-sts'; @@ -197,13 +193,10 @@ export class CloudSecurityLegacyService { let accountIdentity: string; try { - const identity = await stsClient.send( - new GetCallerIdentityCommand({}), - ); + const identity = await stsClient.send(new GetCallerIdentityCommand({})); accountIdentity = identity.Account || ''; } catch (error) { - const msg = - error instanceof Error ? error.message : 'Failed to validate'; + const msg = error instanceof Error ? error.message : 'Failed to validate'; throw new BadRequestException(`Invalid AWS credentials: ${msg}`); } diff --git a/apps/api/src/cloud-security/cloud-security-query.service.ts b/apps/api/src/cloud-security/cloud-security-query.service.ts index 6a506d3335..ca8cb85164 100644 --- a/apps/api/src/cloud-security/cloud-security-query.service.ts +++ b/apps/api/src/cloud-security/cloud-security-query.service.ts @@ -7,6 +7,15 @@ const CLOUD_PROVIDER_CATEGORY = 'Cloud'; /** Scan window for filtering legacy results to latest scan only */ const SCAN_WINDOW_MS = 10 * 60 * 1000; // 10 minutes +/** Extract project ID from a GCP resource path like //iam.googleapis.com/projects/my-proj/... */ +function extractProjectIdFromResource( + resourceId: string | null, +): string | null { + if (!resourceId) return null; + const match = resourceId.match(/\/projects\/([^/]+)/); + return match?.[1] ?? null; +} + export interface CloudProvider { id: string; integrationId: string; @@ -17,6 +26,7 @@ export interface CloudProvider { status: string; createdAt: Date; updatedAt: Date; + reconnectedAt?: Date; isLegacy: boolean; variables: Record | null; requiredVariables: string[]; @@ -37,6 +47,10 @@ export interface CloudFinding { completedAt: Date | null; connectionId: string; providerSlug: string; + serviceId: string | null; + findingKey: string | null; + resourceId: string | null; + projectDisplayName: string | null; integration: { integrationId: string }; } @@ -93,6 +107,12 @@ export class CloudSecurityQueryService { const newProviders: CloudProvider[] = newConnections.map((conn) => { const metadata = (conn.metadata || {}) as Record; const manifest = getManifest(conn.provider.slug); + const reconnectMarker = metadata.reconnectedAt; + const reconnectedAt = + typeof reconnectMarker === 'string' && + !Number.isNaN(new Date(reconnectMarker).getTime()) + ? new Date(reconnectMarker) + : undefined; return { id: conn.id, integrationId: conn.provider.slug, @@ -106,6 +126,7 @@ export class CloudSecurityQueryService { status: conn.status, createdAt: conn.createdAt, updatedAt: conn.updatedAt, + reconnectedAt, isLegacy: false, variables: (conn.variables as Record) ?? null, requiredVariables: getRequiredVariables(conn.provider.slug), @@ -114,14 +135,10 @@ export class CloudSecurityQueryService { ? metadata.accountId : undefined, regions: Array.isArray(metadata.regions) - ? metadata.regions.filter( - (r): r is string => typeof r === 'string', - ) + ? metadata.regions.filter((r): r is string => typeof r === 'string') : undefined, tenantId: - typeof metadata.tenantId === 'string' - ? metadata.tenantId - : undefined, + typeof metadata.tenantId === 'string' ? metadata.tenantId : undefined, subscriptionId: typeof metadata.subscriptionId === 'string' ? metadata.subscriptionId @@ -156,14 +173,10 @@ export class CloudSecurityQueryService { ? settings.accountId : undefined, regions: Array.isArray(settings.regions) - ? settings.regions.filter( - (r): r is string => typeof r === 'string', - ) + ? settings.regions.filter((r): r is string => typeof r === 'string') : undefined, tenantId: - typeof settings.tenantId === 'string' - ? settings.tenantId - : undefined, + typeof settings.tenantId === 'string' ? settings.tenantId : undefined, subscriptionId: typeof settings.subscriptionId === 'string' ? settings.subscriptionId @@ -206,6 +219,18 @@ export class CloudSecurityQueryService { connections.map((c) => [c.id, c.provider.slug]), ); + // Build project ID → name map from all GCP connections + const projectNameMap = new Map(); + for (const conn of connections) { + const vars = (conn.variables ?? {}) as Record; + const names = vars.project_names as Record | undefined; + if (names) { + for (const [id, name] of Object.entries(names)) { + projectNameMap.set(id, name); + } + } + } + const latestRuns = await db.integrationCheckRun.findMany({ where: { connectionId: { in: connectionIds }, @@ -219,9 +244,7 @@ export class CloudSecurityQueryService { const latestRunIds = latestRuns.map((r) => r.id); if (latestRunIds.length === 0) return []; - const checkRunMap = Object.fromEntries( - latestRuns.map((cr) => [cr.id, cr]), - ); + const checkRunMap = Object.fromEntries(latestRuns.map((cr) => [cr.id, cr])); const results = await db.integrationCheckResult.findMany({ where: { checkRunId: { in: latestRunIds } }, @@ -234,6 +257,8 @@ export class CloudSecurityQueryService { collectedAt: true, checkRunId: true, passed: true, + evidence: true, + resourceId: true, }, orderBy: { collectedAt: 'desc' }, }); @@ -243,6 +268,7 @@ export class CloudSecurityQueryService { const slug = checkRun ? connectionToSlug[checkRun.connectionId] || 'unknown' : 'unknown'; + const evidence = (result.evidence ?? {}) as Record; return { id: result.id, title: result.title, @@ -253,6 +279,18 @@ export class CloudSecurityQueryService { completedAt: result.collectedAt, connectionId: checkRun?.connectionId ?? '', providerSlug: slug, + serviceId: (evidence.serviceId as string) ?? null, + findingKey: (evidence.findingKey as string) ?? null, + resourceId: result.resourceId ?? null, + projectDisplayName: (() => { + const fromEvidence = evidence.projectDisplayName as + | string + | undefined; + if (fromEvidence) return fromEvidence; + const projectId = extractProjectIdFromResource(result.resourceId); + if (!projectId) return null; + return projectNameMap.get(projectId) ?? projectId; + })(), integration: { integrationId: slug }, }; }); @@ -274,9 +312,7 @@ export class CloudSecurityQueryService { if (legacyIds.length === 0) return []; const lastRunMap = new Map( - activeLegacy - .filter((i) => i.lastRunAt) - .map((i) => [i.id, i.lastRunAt!]), + activeLegacy.filter((i) => i.lastRunAt).map((i) => [i.id, i.lastRunAt!]), ); const results = await db.integrationResult.findMany({ @@ -320,6 +356,10 @@ export class CloudSecurityQueryService { completedAt: result.completedAt, connectionId: result.integration.id, providerSlug: result.integration.integrationId, + serviceId: null, + findingKey: null, + resourceId: null, + projectDisplayName: null, integration: { integrationId: result.integration.integrationId }, })); } diff --git a/apps/api/src/cloud-security/cloud-security.controller.ts b/apps/api/src/cloud-security/cloud-security.controller.ts index a3c90112e3..98d6e363f8 100644 --- a/apps/api/src/cloud-security/cloud-security.controller.ts +++ b/apps/api/src/cloud-security/cloud-security.controller.ts @@ -10,7 +10,9 @@ import { HttpException, HttpStatus, UseGuards, + Req, } from '@nestjs/common'; +import { SkipThrottle } from '@nestjs/throttler'; import { HybridAuthGuard } from '../auth/hybrid-auth.guard'; import { PermissionGuard } from '../auth/permission.guard'; import { RequirePermission } from '../auth/require-permission.decorator'; @@ -21,6 +23,14 @@ import { } from './cloud-security.service'; import { CloudSecurityQueryService } from './cloud-security-query.service'; import { CloudSecurityLegacyService } from './cloud-security-legacy.service'; +import { logCloudSecurityActivity } from './cloud-security-audit'; +import { CloudSecurityActivityService } from './cloud-security-activity.service'; +import { + GCPSecurityService, + type GcpSetupStep, + type GcpSetupStepId, +} from './providers/gcp-security.service'; +import { AzureSecurityService } from './providers/azure-security.service'; @Controller({ path: 'cloud-security', version: '1' }) export class CloudSecurityController { @@ -30,9 +40,42 @@ export class CloudSecurityController { private readonly cloudSecurityService: CloudSecurityService, private readonly queryService: CloudSecurityQueryService, private readonly legacyService: CloudSecurityLegacyService, + private readonly activityService: CloudSecurityActivityService, + private readonly gcpSecurityService: GCPSecurityService, + private readonly azureSecurityService: AzureSecurityService, ) {} + @Get('activity') + @SkipThrottle() + @UseGuards(HybridAuthGuard, PermissionGuard) + @RequirePermission('integration', 'read') + async getActivity( + @Query('connectionId') connectionId: string, + @Query('take') take: string | undefined, + @OrganizationId() organizationId: string, + ) { + if (!connectionId) { + throw new HttpException( + 'connectionId query parameter is required', + HttpStatus.BAD_REQUEST, + ); + } + + const parsedTake = take + ? Math.min(100, Math.max(1, parseInt(take, 10) || 30)) + : 30; + + const activity = await this.activityService.getActivity({ + connectionId, + organizationId, + take: parsedTake, + }); + + return { data: activity, count: activity.length }; + } + @Get('providers') + @SkipThrottle() @UseGuards(HybridAuthGuard, PermissionGuard) @RequirePermission('integration', 'read') async getProviders(@OrganizationId() organizationId: string) { @@ -41,6 +84,7 @@ export class CloudSecurityController { } @Get('findings') + @SkipThrottle() @UseGuards(HybridAuthGuard, PermissionGuard) @RequirePermission('integration', 'read') async getFindings(@OrganizationId() organizationId: string) { @@ -54,8 +98,8 @@ export class CloudSecurityController { async scan( @Param('connectionId') connectionId: string, @OrganizationId() organizationId: string, + @Req() req: { userId?: string; authType?: string }, ) { - this.logger.log( `Cloud security scan requested for connection ${connectionId}`, ); @@ -66,23 +110,642 @@ export class CloudSecurityController { ); if (!result.success) { + // GCP setup issues are user-fixable — return 400 with structured error + const isSetupError = + result.error?.startsWith('SCC_NOT_ACTIVATED:') || + result.error?.startsWith('GCP_ORG_MISSING:'); + const errorStr = result.error ?? ''; + const errorCode = isSetupError ? errorStr.split(':')[0] : undefined; + const message = isSetupError + ? errorStr.substring(errorStr.indexOf(':') + 2) + : result.error || 'Scan failed'; + throw new HttpException( { - message: result.error || 'Scan failed', + message, provider: result.provider, + ...(errorCode && { errorCode }), }, - HttpStatus.INTERNAL_SERVER_ERROR, + isSetupError + ? HttpStatus.BAD_REQUEST + : HttpStatus.INTERNAL_SERVER_ERROR, ); } + const totalFindings = result.findings.length; + const failedCount = result.findings.filter((f) => !f.passed).length; + const passedCount = result.findings.filter((f) => f.passed).length; + + // Only write audit log when we have a real userId (session auth). + // API key auth has no user context, and auditLog.userId is a FK to User. + const scanUserId = req.userId; + if (scanUserId) + await logCloudSecurityActivity({ + organizationId, + userId: scanUserId, + connectionId, + action: 'scan_completed', + description: `Ran cloud security scan — ${totalFindings} findings (${failedCount} failed, ${passedCount} passed)`, + metadata: { + totalFindings, + failedCount, + passedCount, + provider: result.provider, + }, + }); + return { success: true, provider: result.provider, - findingsCount: result.findings.length, + findingsCount: totalFindings, scannedAt: result.scannedAt, }; } + @Post('detect-services/:connectionId') + @UseGuards(HybridAuthGuard, PermissionGuard) + @RequirePermission('integration', 'read') + async detectServices( + @Param('connectionId') connectionId: string, + @OrganizationId() organizationId: string, + ) { + try { + const services = await this.cloudSecurityService.detectServices( + connectionId, + organizationId, + ); + return { services }; + } catch (error) { + if (error instanceof ConnectionNotFoundError) { + throw new HttpException('Connection not found', HttpStatus.NOT_FOUND); + } + const message = + error instanceof Error ? error.message : 'Failed to detect services'; + throw new HttpException(message, HttpStatus.BAD_REQUEST); + } + } + + @Post('detect-gcp-org/:connectionId') + @UseGuards(HybridAuthGuard, PermissionGuard) + @RequirePermission('integration', 'read') + async detectGcpOrg( + @Param('connectionId') connectionId: string, + @OrganizationId() organizationId: string, + ) { + try { + const connection = await this.cloudSecurityService.getConnectionForDetect( + connectionId, + organizationId, + ); + + const credentials = connection.credentials as Record; + const accessToken = credentials?.access_token as string; + if (!accessToken) { + throw new Error( + 'No access token found. Reconnect the GCP integration.', + ); + } + + const rawOrgs = + await this.gcpSecurityService.detectOrganizations(accessToken); + + // Fetch projects per org in parallel + const orgsWithProjects = await Promise.all( + rawOrgs.map(async (org) => { + const projects = + await this.gcpSecurityService.detectProjectsForOrg( + accessToken, + org.id, + ); + return { + id: org.id, + displayName: org.displayName, + projects: projects + .map((p) => ({ id: p.id, name: p.name, number: p.number })) + .sort((a, b) => a.name.localeCompare(b.name)), + }; + }), + ); + + // If exactly 1 org found, auto-save it + if (rawOrgs.length === 1) { + await this.cloudSecurityService.saveConnectionVariable( + connectionId, + 'organization_id', + rawOrgs[0].id, + organizationId, + ); + } + + const variables = (connection.variables ?? {}) as Record< + string, + unknown + >; + + // Return only explicitly selected projects — never auto-select + const existingProjectIds = this.readProjectIds(variables); + + return { + organizations: orgsWithProjects, + selectedProjectIds: existingProjectIds, + selectedOrganizationId: variables.organization_id as string | undefined, + }; + } catch (error) { + const message = + error instanceof Error + ? error.message + : 'Failed to detect GCP organization'; + throw new HttpException(message, HttpStatus.BAD_REQUEST); + } + } + + @Post('select-gcp-projects/:connectionId') + @UseGuards(HybridAuthGuard, PermissionGuard) + @RequirePermission('integration', 'update') + async selectGcpProjects( + @Param('connectionId') connectionId: string, + @Body() + body: { + projectIds: string[]; + projectNames?: Record; + gcpOrganizationId?: string; + }, + @OrganizationId() organizationId: string, + ) { + if (!body?.projectIds?.length) { + throw new HttpException( + 'projectIds is required', + HttpStatus.BAD_REQUEST, + ); + } + if (body.gcpOrganizationId) { + await this.cloudSecurityService.saveConnectionVariable( + connectionId, + 'organization_id', + body.gcpOrganizationId, + organizationId, + ); + } + await this.cloudSecurityService.saveConnectionVariable( + connectionId, + 'project_ids', + body.projectIds, + organizationId, + ); + if (body.projectNames) { + await this.cloudSecurityService.saveConnectionVariable( + connectionId, + 'project_names', + body.projectNames as unknown as string[], + organizationId, + ); + } + return { projectIds: body.projectIds }; + } + + @Post('setup-gcp/:connectionId') + @UseGuards(HybridAuthGuard, PermissionGuard) + @RequirePermission('integration', 'update') + async setupGcp( + @Param('connectionId') connectionId: string, + @Body() body: { projectId?: string }, + @OrganizationId() organizationId: string, + ) { + try { + const context = await this.resolveGcpSetupContext( + connectionId, + organizationId, + body?.projectId, + ); + + const result = await this.gcpSecurityService.autoSetup({ + accessToken: context.accessToken, + organizationId: context.organizationId ?? '', + projectId: context.projectId, + }); + + return { + ...result, + steps: this.withGcpResolveActions(result.steps, connectionId), + organizationId: context.organizationId, + projectId: context.projectId, + projects: context.projects, + }; + } catch (error) { + const message = + error instanceof Error ? error.message : 'GCP setup failed'; + throw new HttpException(message, HttpStatus.BAD_REQUEST); + } + } + + @Post('setup-gcp/:connectionId/resolve-step') + @UseGuards(HybridAuthGuard, PermissionGuard) + @RequirePermission('integration', 'update') + async resolveGcpSetupStep( + @Param('connectionId') connectionId: string, + @Body() body: { stepId: GcpSetupStepId }, + @OrganizationId() organizationId: string, + ) { + try { + if (!body?.stepId) { + throw new Error('stepId is required'); + } + + const context = await this.resolveGcpSetupContext( + connectionId, + organizationId, + ); + const result = await this.gcpSecurityService.resolveSetupStep({ + stepId: body.stepId, + accessToken: context.accessToken, + organizationId: context.organizationId ?? '', + projectId: context.projectId, + }); + + return { + email: result.email, + step: this.withGcpResolveActions([result.step], connectionId)[0], + organizationId: context.organizationId, + projectId: context.projectId, + projects: context.projects, + }; + } catch (error) { + const message = + error instanceof Error ? error.message : 'Failed to resolve setup step'; + throw new HttpException(message, HttpStatus.BAD_REQUEST); + } + } + + private withGcpResolveActions( + steps: GcpSetupStep[], + connectionId: string, + ): GcpSetupStep[] { + return steps.map((step) => { + if (step.success) return step; + return { + ...step, + resolveAction: { + label: 'Resolve this', + method: 'POST', + endpoint: `/v1/cloud-security/setup-gcp/${connectionId}/resolve-step`, + body: { stepId: step.id }, + }, + }; + }); + } + + /** + * Read selected project IDs from connection variables. + * Only reads the new `project_ids` array — the old `project_id` string + * was auto-saved by previous code and does NOT represent user choice. + */ + private readProjectIds( + variables: Record, + ): string[] { + if (Array.isArray(variables.project_ids)) { + return variables.project_ids as string[]; + } + return []; + } + + private async resolveGcpSetupContext( + connectionId: string, + organizationId: string, + overrideProjectId?: string, + ) { + const connection = await this.cloudSecurityService.getConnectionForDetect( + connectionId, + organizationId, + ); + + const credentials = connection.credentials as Record; + const accessToken = credentials?.access_token as string; + if (!accessToken) { + throw new Error('No access token found. Reconnect the GCP integration.'); + } + + const variables = (connection.variables ?? {}) as Record; + let gcpOrgId = variables.organization_id as string | undefined; + + if (!gcpOrgId) { + const orgs = + await this.gcpSecurityService.detectOrganizations(accessToken); + if (orgs.length > 0) { + gcpOrgId = orgs[0].id; + await this.cloudSecurityService.saveConnectionVariable( + connectionId, + 'organization_id', + gcpOrgId, + organizationId, + ); + } + } + + // Fetch projects scoped to the org when available, else all + const projects = gcpOrgId + ? await this.gcpSecurityService.detectProjectsForOrg( + accessToken, + gcpOrgId, + ) + : await this.gcpSecurityService.detectProjects(accessToken); + + // For API enablement, use override or first selected project + const selectedIds = this.readProjectIds(variables); + const projectId = + overrideProjectId || selectedIds[0] || projects[0]?.id; + + if (!projectId) { + throw new Error( + 'No GCP projects found. Ensure your account has access to at least one project.', + ); + } + + return { + accessToken, + organizationId: gcpOrgId, + projectId, + projects: projects.map((p) => ({ + id: p.id, + name: p.name, + })), + selectedProjectIds: selectedIds, + }; + } + + @Post('setup-azure/:connectionId') + @UseGuards(HybridAuthGuard, PermissionGuard) + @RequirePermission('integration', 'update') + async setupAzure( + @Param('connectionId') connectionId: string, + @OrganizationId() organizationId: string, + ) { + try { + const connection = await this.cloudSecurityService.getConnectionForDetect( + connectionId, + organizationId, + ); + + const credentials = connection.credentials as Record; + const accessToken = credentials?.access_token as string; + if (!accessToken) { + throw new Error( + 'No access token found. Reconnect the Azure integration.', + ); + } + + const variables = (connection.variables ?? {}) as Record; + const steps: Array<{ name: string; success: boolean; error?: string }> = + []; + + // Step 1: Detect subscriptions + let subscriptionId = variables.subscription_id as string | undefined; + let subscriptionName: string | undefined; + try { + const subs = + await this.azureSecurityService.detectSubscriptions(accessToken); + if (subs.length > 0) { + subscriptionId = subs[0].id; + subscriptionName = subs[0].displayName; + await this.cloudSecurityService.saveConnectionVariable( + connectionId, + 'subscription_id', + subscriptionId, + organizationId, + ); + steps.push({ + name: `Subscription detected: ${subscriptionName}`, + success: true, + }); + } else { + steps.push({ + name: 'Detect subscription', + success: false, + error: + 'No Azure subscriptions found. Ensure your account has an active subscription.', + }); + } + } catch (error) { + steps.push({ + name: 'Detect subscription', + success: false, + error: + error instanceof Error + ? error.message + : 'Failed to detect subscriptions', + }); + } + + // Step 2: Verify Defender access + if (subscriptionId) { + try { + const resp = await fetch( + `https://management.azure.com/subscriptions/${subscriptionId}/providers/Microsoft.Security/assessments?api-version=2021-06-01&$top=1`, + { headers: { Authorization: `Bearer ${accessToken}` } }, + ); + if (resp.ok) { + steps.push({ + name: 'Microsoft Defender for Cloud access verified', + success: true, + }); + } else { + steps.push({ + name: 'Microsoft Defender for Cloud access', + success: false, + error: + 'Your account needs the "Security Reader" role on this subscription.', + }); + } + } catch { + steps.push({ + name: 'Microsoft Defender for Cloud access', + success: false, + error: 'Could not verify Defender access.', + }); + } + + // Step 3: Verify general read access + try { + const resp = await fetch( + `https://management.azure.com/subscriptions/${subscriptionId}/resources?api-version=2021-04-01&$top=1`, + { headers: { Authorization: `Bearer ${accessToken}` } }, + ); + if (resp.ok) { + steps.push({ + name: 'Resource read access verified', + success: true, + }); + } else { + steps.push({ + name: 'Resource read access', + success: false, + error: + 'Your account needs at least the "Reader" role on this subscription.', + }); + } + } catch { + steps.push({ + name: 'Resource read access', + success: false, + error: 'Verification failed', + }); + } + + // Step 4: Check write permissions for auto-fix + // Use the permissions check API to see if user can write resources + let canAutoFix = false; + try { + const permResp = await fetch( + `https://management.azure.com/subscriptions/${subscriptionId}/providers/Microsoft.Authorization/permissions?api-version=2022-04-01`, + { headers: { Authorization: `Bearer ${accessToken}` } }, + ); + if (permResp.ok) { + const permData = (await permResp.json()) as { + value: Array<{ actions: string[]; notActions: string[] }>; + }; + const allActions = permData.value?.flatMap((p) => p.actions) ?? []; + canAutoFix = allActions.some((a) => a === '*' || a === '*/write'); + if (canAutoFix) { + steps.push({ + name: 'Auto-fix capability: write access available', + success: true, + }); + } else { + steps.push({ + name: 'Auto-fix capability', + success: true, + error: + 'Read-only access. Auto-fix requires Contributor role — you can still scan and view findings.', + }); + } + } + } catch { + // Non-critical — auto-fix detection failed + } + } + + return { + steps, + subscriptionId, + subscriptionName, + }; + } catch (error) { + const message = + error instanceof Error ? error.message : 'Azure setup failed'; + throw new HttpException(message, HttpStatus.BAD_REQUEST); + } + } + + @Post('validate-azure/:connectionId') + @UseGuards(HybridAuthGuard, PermissionGuard) + @RequirePermission('integration', 'read') + async validateAzure( + @Param('connectionId') connectionId: string, + @OrganizationId() organizationId: string, + ) { + try { + const connection = await this.cloudSecurityService.getConnectionForDetect( + connectionId, + organizationId, + ); + + const credentials = connection.credentials as Record; + const variables = (connection.variables ?? {}) as Record; + const tenantId = credentials?.tenantId as string; + const clientId = credentials?.clientId as string; + const clientSecret = credentials?.clientSecret as string; + const subscriptionId = (credentials?.subscriptionId ?? + variables.subscription_id) as string | undefined; + + const steps: Array<{ name: string; success: boolean; error?: string }> = + []; + + if (!subscriptionId) { + steps.push({ + name: 'Subscription ID', + success: false, + error: + 'No subscription ID configured. Go to the Azure integration settings to auto-detect your subscription.', + }); + return { steps, subscriptionId: null }; + } + + // Step 1: Validate credentials (token exchange) + let accessToken: string | null = null; + try { + accessToken = await this.azureSecurityService.getAccessToken( + tenantId, + clientId, + clientSecret, + ); + steps.push({ name: 'Authenticate with Azure', success: true }); + } catch (error) { + steps.push({ + name: 'Authenticate with Azure', + success: false, + error: + error instanceof Error ? error.message : 'Authentication failed', + }); + return { steps, subscriptionId }; + } + + // Step 2: Verify subscription access + try { + const resp = await fetch( + `https://management.azure.com/subscriptions/${subscriptionId}?api-version=2022-12-01`, + { headers: { Authorization: `Bearer ${accessToken}` } }, + ); + if (resp.ok) { + steps.push({ name: 'Subscription access verified', success: true }); + } else { + const errorText = await resp.text(); + steps.push({ + name: 'Subscription access', + success: false, + error: `Cannot access subscription. Assign "Reader" role to the app registration. (${resp.status}: ${errorText.slice(0, 200)})`, + }); + } + } catch { + steps.push({ + name: 'Subscription access', + success: false, + error: 'Network error', + }); + } + + // Step 3: Verify Defender access + try { + const resp = await fetch( + `https://management.azure.com/subscriptions/${subscriptionId}/providers/Microsoft.Security/assessments?api-version=2021-06-01&$top=1`, + { headers: { Authorization: `Bearer ${accessToken}` } }, + ); + if (resp.ok) { + steps.push({ + name: 'Microsoft Defender for Cloud access', + success: true, + }); + } else { + steps.push({ + name: 'Microsoft Defender for Cloud access', + success: false, + error: 'Assign "Security Reader" role to the app registration.', + }); + } + } catch { + steps.push({ + name: 'Microsoft Defender for Cloud access', + success: false, + error: 'Could not verify Defender access', + }); + } + + return { steps, subscriptionId }; + } catch (error) { + const message = + error instanceof Error ? error.message : 'Azure validation failed'; + throw new HttpException(message, HttpStatus.BAD_REQUEST); + } + } + @Post('trigger/:connectionId') @UseGuards(HybridAuthGuard, PermissionGuard) @RequirePermission('integration', 'update') @@ -143,7 +806,11 @@ export class CloudSecurityController { @RequirePermission('integration', 'create') async connectLegacy( @OrganizationId() organizationId: string, - @Body() body: { provider: 'aws' | 'gcp' | 'azure'; credentials: Record }, + @Body() + body: { + provider: 'aws' | 'gcp' | 'azure'; + credentials: Record; + }, ) { const result = await this.legacyService.connectLegacy( organizationId, @@ -163,7 +830,11 @@ export class CloudSecurityController { body.accessKeyId, body.secretAccessKey, ); - return { success: true, accountId: result.accountId, regions: result.regions }; + return { + success: true, + accountId: result.accountId, + regions: result.regions, + }; } @Delete('legacy/:integrationId') diff --git a/apps/api/src/cloud-security/cloud-security.module.ts b/apps/api/src/cloud-security/cloud-security.module.ts index 19f0137f34..3db256d943 100644 --- a/apps/api/src/cloud-security/cloud-security.module.ts +++ b/apps/api/src/cloud-security/cloud-security.module.ts @@ -1,4 +1,4 @@ -import { Module } from '@nestjs/common'; +import { Module, forwardRef } from '@nestjs/common'; import { CloudSecurityController } from './cloud-security.controller'; import { CloudSecurityService } from './cloud-security.service'; import { CloudSecurityQueryService } from './cloud-security-query.service'; @@ -6,19 +6,30 @@ import { CloudSecurityLegacyService } from './cloud-security-legacy.service'; import { GCPSecurityService } from './providers/gcp-security.service'; import { AWSSecurityService } from './providers/aws-security.service'; import { AzureSecurityService } from './providers/azure-security.service'; +import { RemediationController } from './remediation.controller'; +import { RemediationService } from './remediation.service'; +import { GcpRemediationService } from './gcp-remediation.service'; +import { AzureRemediationService } from './azure-remediation.service'; +import { AiRemediationService } from './ai-remediation.service'; +import { CloudSecurityActivityService } from './cloud-security-activity.service'; import { IntegrationPlatformModule } from '../integration-platform/integration-platform.module'; import { AuthModule } from '../auth/auth.module'; @Module({ - imports: [IntegrationPlatformModule, AuthModule], - controllers: [CloudSecurityController], + imports: [forwardRef(() => IntegrationPlatformModule), AuthModule], + controllers: [CloudSecurityController, RemediationController], providers: [ CloudSecurityService, CloudSecurityQueryService, CloudSecurityLegacyService, + CloudSecurityActivityService, GCPSecurityService, AWSSecurityService, AzureSecurityService, + RemediationService, + GcpRemediationService, + AzureRemediationService, + AiRemediationService, ], exports: [CloudSecurityService], }) diff --git a/apps/api/src/cloud-security/cloud-security.service.ts b/apps/api/src/cloud-security/cloud-security.service.ts index d9088e90b6..07f94d4389 100644 --- a/apps/api/src/cloud-security/cloud-security.service.ts +++ b/apps/api/src/cloud-security/cloud-security.service.ts @@ -1,5 +1,5 @@ import { Injectable, Logger } from '@nestjs/common'; -import { db } from '@db'; +import { db, Prisma } from '@db'; import { getManifest } from '@trycompai/integration-platform'; import { runs, tasks } from '@trigger.dev/sdk'; import { CredentialVaultService } from '../integration-platform/services/credential-vault.service'; @@ -7,6 +7,7 @@ import { OAuthCredentialsService } from '../integration-platform/services/oauth- import { GCPSecurityService } from './providers/gcp-security.service'; import { AWSSecurityService } from './providers/aws-security.service'; import { AzureSecurityService } from './providers/azure-security.service'; +import { AWS_SERVICE_TASK_MAPPINGS } from './aws-task-mappings'; export interface SecurityFinding { id: string; @@ -161,26 +162,107 @@ export class CloudSecurityService { // Get variables for the scan const variables = (connection.variables as Record) || {}; + // Provider baselines are always scanned regardless of toggles. + const BASELINE_SERVICES_BY_PROVIDER: Record = { + aws: [ + 'cloudtrail', + 'config', + 'guardduty', + 'iam-analyzer', + 'cloudwatch', + 'kms', + ], + gcp: ['security-command-center'], + azure: [], + }; + const baselineServices = BASELINE_SERVICES_BY_PROVIDER[providerSlug] ?? []; + + // Smart service filtering: auto-detect is additive, user can only exclude. + // Scan = (detectedServices MINUS disabledServices) UNION baselineServices. + const disabledServices = new Set( + Array.isArray(variables.disabledServices) + ? (variables.disabledServices as string[]) + : [], + ); + let enabledServices: string[] | undefined; + + if ( + Array.isArray(variables.enabledServices) && + (variables.enabledServices as string[]).length > 0 + ) { + // Legacy format: explicit enabled list (backward compat) + baseline + const userEnabled = (variables.enabledServices as string[]).filter( + (s) => !disabledServices.has(s), + ); + enabledServices = [...new Set([...userEnabled, ...baselineServices])]; + } else if ( + Array.isArray(variables.detectedServices) && + (variables.detectedServices as string[]).length > 0 + ) { + // New smart format: detected minus disabled + baseline always included + const filtered = (variables.detectedServices as string[]).filter( + (s) => !disabledServices.has(s), + ); + enabledServices = [...new Set([...filtered, ...baselineServices])]; + } + // else: undefined = scan all adapters (no detection data at all) + try { let findings: SecurityFinding[]; + // Auto-detect GCP org ID if not set + if ( + providerSlug === 'gcp' && + !variables.organization_id && + credentials.access_token + ) { + this.logger.log('GCP org ID missing — auto-detecting...'); + try { + const orgs = await this.gcpService.detectOrganizations( + credentials.access_token as string, + ); + if (orgs.length > 0) { + variables.organization_id = orgs[0].id; + this.logger.log( + `Auto-detected GCP org: ${orgs[0].displayName} (${orgs[0].id})`, + ); + // Save for future scans + await db.integrationConnection.update({ + where: { id: connectionId }, + data: { + variables: { ...variables } as unknown as Prisma.InputJsonValue, + }, + }); + } else { + this.logger.warn('No GCP organizations found for this account'); + } + } catch (err) { + this.logger.warn( + `GCP org auto-detection failed: ${err instanceof Error ? err.message : String(err)}`, + ); + } + } + switch (providerSlug) { case 'gcp': findings = await this.gcpService.scanSecurityFindings( credentials, variables, + enabledServices, ); break; case 'aws': findings = await this.awsService.scanSecurityFindings( credentials, variables, + enabledServices, ); break; case 'azure': findings = await this.azureService.scanSecurityFindings( credentials, variables, + enabledServices, ); break; default: @@ -196,7 +278,55 @@ export class CloudSecurityService { // Store findings in database await this.storeFindings(connectionId, providerSlug, findings); - // Update last sync time + // Auto-satisfy evidence tasks based on passing scan results (AWS only) + if (providerSlug === 'aws') { + await this.autoSatisfyTasks(organizationId, findings); + } + + // GCP & Azure: auto-detect services from scan findings + if ( + (providerSlug === 'gcp' || providerSlug === 'azure') && + findings.length > 0 + ) { + const serviceIds = new Set(); + for (const f of findings) { + const evidence = f.evidence; + const serviceId = evidence?.serviceId as string | undefined; + if (serviceId) serviceIds.add(serviceId); + } + if (serviceIds.size > 0) { + const currentVars = variables ?? {}; + const existingDetected = Array.isArray(currentVars.detectedServices) + ? new Set(currentVars.detectedServices as string[]) + : new Set(); + const disabledSet = new Set( + Array.isArray(currentVars.disabledServices) + ? (currentVars.disabledServices as string[]) + : [], + ); + // Only auto-enable genuinely NEW services — don't override user's explicit disables + for (const id of serviceIds) { + if (!existingDetected.has(id)) disabledSet.delete(id); + } + // Merge: keep previously detected + add newly found (AFTER the new-check above) + for (const id of serviceIds) existingDetected.add(id); + await db.integrationConnection.update({ + where: { id: connectionId }, + data: { + variables: { + ...currentVars, + detectedServices: [...existingDetected], + disabledServices: [...disabledSet], + } as unknown as Prisma.InputJsonValue, + }, + }); + this.logger.log( + `${providerSlug.toUpperCase()}: detected ${serviceIds.size} service categories: ${[...serviceIds].join(', ')}`, + ); + } + } + + // Update last sync time (AWS detectedServices is handled by detectServices via Cost Explorer) await db.integrationConnection.update({ where: { id: connectionId }, data: { lastSyncAt: new Date() }, @@ -227,6 +357,142 @@ export class CloudSecurityService { } } + /** + * Detect which AWS services are actively used (via Cost Explorer). + * Saves detected services to connection variables for the frontend. + */ + async detectServices( + connectionId: string, + organizationId: string, + ): Promise { + const connection = await db.integrationConnection.findFirst({ + where: { id: connectionId, organizationId, status: 'active' }, + include: { provider: true }, + }); + + if (!connection) { + throw new ConnectionNotFoundError(); + } + + const decrypted = + await this.credentialVaultService.getDecryptedCredentials(connectionId); + if (!decrypted) { + throw new Error('No credentials found'); + } + + const variables = (connection.variables as Record) || {}; + let detected: string[]; + let gcpServicesByProject: Record | undefined; + + if (connection.provider.slug === 'gcp') { + const accessToken = decrypted.access_token as string; + if (!accessToken) throw new Error('GCP access token not found'); + + // Use explicitly selected projects, otherwise detect all (cron fallback) + const selectedIds = Array.isArray(variables.project_ids) + ? (variables.project_ids as string[]) + : []; + + const projects = + selectedIds.length > 0 + ? selectedIds.map((id) => ({ id })) + : await this.gcpService.detectProjects(accessToken); + const result = await this.gcpService.detectServices( + accessToken, + projects, + ); + detected = result.services; + gcpServicesByProject = result.servicesByProject; + } else if (connection.provider.slug === 'aws') { + detected = await this.awsService.detectActiveServices( + decrypted, + variables, + ); + } else { + // Azure and others: services are auto-detected from scan findings, not a separate API + return []; + } + + // Merge with existing detected services and only auto-enable genuinely NEW detections. + // This preserves explicit user toggles (both enabled and disabled). + const existingDetected = new Set( + Array.isArray(variables.detectedServices) + ? (variables.detectedServices as string[]) + : [], + ); + const updatedDisabled = new Set( + Array.isArray(variables.disabledServices) + ? (variables.disabledServices as string[]) + : [], + ); + for (const id of detected) { + if (!existingDetected.has(id)) { + updatedDisabled.delete(id); + } + existingDetected.add(id); + } + + await db.integrationConnection.update({ + where: { id: connectionId }, + data: { + variables: { + ...variables, + detectedServices: [...existingDetected], + disabledServices: [...updatedDisabled], + serviceDetectionCompletedAt: new Date().toISOString(), + ...(gcpServicesByProject && { servicesByProject: gcpServicesByProject }), + }, + }, + }); + + this.logger.log( + `Detected ${detected.length} active services for ${connection.provider.slug} connection ${connectionId}`, + ); + + return detected; + } + + /** + * Get connection with decrypted credentials (for GCP org detection). + */ + async getConnectionForDetect(connectionId: string, organizationId: string) { + const connection = await db.integrationConnection.findFirst({ + where: { id: connectionId, organizationId, status: 'active' }, + include: { provider: true }, + }); + if (!connection) throw new ConnectionNotFoundError(); + + const credentials = + await this.credentialVaultService.getDecryptedCredentials(connectionId); + return { ...connection, credentials }; + } + + /** + * Save a variable to a connection (e.g., organization_id after auto-detection). + */ + async saveConnectionVariable( + connectionId: string, + key: string, + value: string | string[], + organizationId: string, + ) { + const connection = await db.integrationConnection.findFirst({ + where: { id: connectionId, organizationId }, + }); + if (!connection) throw new ConnectionNotFoundError(); + + const variables = (connection.variables as Record) || {}; + await db.integrationConnection.update({ + where: { id: connectionId }, + data: { + variables: { + ...variables, + [key]: value, + } as unknown as Prisma.InputJsonValue, + }, + }); + } + async triggerScan( connectionId: string, organizationId: string, @@ -330,4 +596,107 @@ export class CloudSecurityService { } }); } + + /** + * Auto-satisfy evidence tasks when ALL findings for a service pass. + * + * Safety rules: + * - Only sets tasks to 'done' — never failed/in_progress/todo + * - Only when ALL findings pass for a service + * - Skips tasks with status 'not_relevant' (user intent) + * - Skips tasks already 'done' (idempotent) + * - Idempotent: re-running with same results is safe + */ + private async autoSatisfyTasks( + organizationId: string, + findings: SecurityFinding[], + ): Promise { + // Group findings by serviceId + const findingsByService = new Map(); + for (const finding of findings) { + const serviceId = finding.evidence?.serviceId as string | undefined; + if (!serviceId) continue; + const group = findingsByService.get(serviceId) ?? []; + group.push(finding); + findingsByService.set(serviceId, group); + } + + // Find services where ALL findings pass + const passingServices: string[] = []; + for (const [serviceId, serviceFindings] of findingsByService) { + if ( + serviceFindings.length > 0 && + serviceFindings.every((f) => f.passed) + ) { + passingServices.push(serviceId); + } + } + + if (passingServices.length === 0) return; + + // Collect all task template IDs to satisfy + const templateIds = new Set(); + for (const serviceId of passingServices) { + const mappedTemplates = AWS_SERVICE_TASK_MAPPINGS[serviceId]; + if (mappedTemplates) { + for (const id of mappedTemplates) { + templateIds.add(id); + } + } + } + + if (templateIds.size === 0) return; + + // For each template ID, only satisfy if ALL mapped services pass. + // A task template may be linked to multiple services (e.g. Encryption at Rest + // requires KMS + S3 + RDS + DynamoDB). Only mark done if every scanned + // service that maps to this template is fully passing. + const eligibleTemplateIds: string[] = []; + for (const templateId of templateIds) { + // Find all services that map to this template + const servicesForTemplate = Object.entries(AWS_SERVICE_TASK_MAPPINGS) + .filter(([, templates]) => templates.includes(templateId)) + .map(([serviceId]) => serviceId); + + // Only consider services that were actually scanned + const scannedServicesForTemplate = servicesForTemplate.filter((s) => + findingsByService.has(s), + ); + + // If no services were scanned for this template, skip + if (scannedServicesForTemplate.length === 0) continue; + + // All scanned services for this template must be passing + const allPassing = scannedServicesForTemplate.every((s) => + passingServices.includes(s), + ); + + if (allPassing) { + eligibleTemplateIds.push(templateId); + } + } + + if (eligibleTemplateIds.length === 0) return; + + const now = new Date(); + + // Update tasks: only those in todo/in_progress/in_review/failed status + const result = await db.task.updateMany({ + where: { + organizationId, + taskTemplateId: { in: eligibleTemplateIds }, + status: { in: ['todo', 'in_progress', 'in_review', 'failed'] }, + }, + data: { + status: 'done', + lastCompletedAt: now, + }, + }); + + if (result.count > 0) { + this.logger.log( + `Auto-satisfied ${result.count} evidence task(s) from passing AWS scan (services: ${passingServices.join(', ')})`, + ); + } + } } diff --git a/apps/api/src/cloud-security/gcp-ai-remediation.prompt.ts b/apps/api/src/cloud-security/gcp-ai-remediation.prompt.ts new file mode 100644 index 0000000000..77667c2ac9 --- /dev/null +++ b/apps/api/src/cloud-security/gcp-ai-remediation.prompt.ts @@ -0,0 +1,298 @@ +import { z } from 'zod'; + +// ─── Zod Schemas ──────────────────────────────────────────────────────────── + +export const gcpApiStepSchema = z.object({ + method: z + .enum(['GET', 'POST', 'PUT', 'PATCH', 'DELETE']) + .describe('HTTP method for the GCP REST API call'), + url: z + .string() + .describe( + 'Full HTTPS URL for the GCP REST API endpoint, e.g. https://storage.googleapis.com/storage/v1/b/my-bucket', + ), + body: z + .record(z.string(), z.unknown()) + .optional() + .describe('JSON request body for POST/PUT/PATCH requests'), + queryParams: z + .record(z.string(), z.string()) + .optional() + .describe( + 'URL query parameters, e.g. { "updateMask": "iamConfiguration" }', + ), + purpose: z + .string() + .describe('Human-readable description of what this step does'), +}); + +export type GcpApiStep = z.infer; + +export const gcpFixPlanSchema = z.object({ + canAutoFix: z + .boolean() + .describe('Whether this finding can be auto-fixed via GCP REST API calls'), + risk: z + .enum(['low', 'medium', 'high', 'critical']) + .describe('Risk level of applying this fix'), + description: z.string().describe('Human-readable description of the fix'), + currentState: z + .record(z.string(), z.unknown()) + .describe('Current configuration from evidence'), + proposedState: z + .record(z.string(), z.unknown()) + .describe('Configuration after fix is applied'), + readSteps: z + .array(gcpApiStepSchema) + .describe('GET requests to read current state before fixing'), + fixSteps: z.array(gcpApiStepSchema).describe('Requests to apply the fix'), + rollbackSteps: z + .array(gcpApiStepSchema) + .describe('Requests to reverse the fix'), + rollbackSupported: z + .boolean() + .describe('Whether this fix can be rolled back'), + requiresAcknowledgment: z + .boolean() + .describe('Whether user must acknowledge before execution'), + acknowledgmentMessage: z.string().optional(), + guidedSteps: z + .array(z.string()) + .optional() + .describe('Manual steps when canAutoFix is false'), + reason: z + .string() + .optional() + .describe('Why auto-fix is not possible when canAutoFix is false'), +}); + +export type GcpFixPlan = z.infer; + +// ─── System Prompt ────────────────────────────────────────────────────────── + +export const GCP_SYSTEM_PROMPT = `You are a GCP security remediation expert. You analyze Security Command Center findings and produce structured fix plans using GCP REST API calls. + +A human will ALWAYS review your plan before execution. Be precise and correct. + +## HOW GCP REST APIs WORK + +All GCP APIs follow this pattern: +- Authentication: Bearer token in Authorization header (handled by the executor) +- Base URLs: https://{service}.googleapis.com/{version}/{resource} +- Methods: GET (read), POST (create), PUT (replace), PATCH (update), DELETE (remove) +- PATCH requests: use queryParams.updateMask to specify which fields to update + +## OUTPUT RULES + +For each step, provide: +- method: HTTP method (GET, POST, PUT, PATCH, DELETE) +- url: Full HTTPS URL to the GCP API endpoint +- body: JSON request body (for POST/PUT/PATCH) +- queryParams: URL query parameters (e.g., updateMask for PATCH) +- purpose: Human-readable explanation + +## GCP API REFERENCE + +### Cloud Storage +- Get bucket: GET https://storage.googleapis.com/storage/v1/b/{bucket}?projection=full +- Update bucket: PATCH https://storage.googleapis.com/storage/v1/b/{bucket} + queryParams: { "updateMask": "field1,field2" } +- Get bucket IAM: GET https://storage.googleapis.com/storage/v1/b/{bucket}/iam +- Set bucket IAM: PUT https://storage.googleapis.com/storage/v1/b/{bucket}/iam +- Enable uniform access: PATCH with body { "iamConfiguration": { "uniformBucketLevelAccess": { "enabled": true } } } + queryParams: { "updateMask": "iamConfiguration" } +- Enable logging: PATCH with body { "logging": { "logBucket": "{log-bucket}", "logObjectPrefix": "{prefix}" } } + queryParams: { "updateMask": "logging" } +- BUCKET_LOCK_DISABLED: canAutoFix=false — locking retention policy is IRREVERSIBLE + +### Compute Engine (Firewall Rules) +- Get firewall: GET https://compute.googleapis.com/compute/v1/projects/{project}/global/firewalls/{firewall} +- Update firewall: PATCH https://compute.googleapis.com/compute/v1/projects/{project}/global/firewalls/{firewall} +- Enable firewall logging: PATCH with body { "logConfig": { "enable": true } } +- NOTE: Compute Engine operations are long-running — the executor polls automatically + +### Compute Engine (Instances) +- Get instance: GET https://compute.googleapis.com/compute/v1/projects/{project}/zones/{zone}/instances/{instance} +- Set metadata: POST https://compute.googleapis.com/compute/v1/projects/{project}/zones/{zone}/instances/{instance}/setMetadata +- Set project metadata: POST https://compute.googleapis.com/compute/v1/projects/{project}/setCommonInstanceMetadata +- CANNOT change on running instance (canAutoFix=false): service account, IP forwarding, API scopes, shielded VM config +- OS_LOGIN: set via project metadata { "items": [{ "key": "enable-oslogin", "value": "TRUE" }] } + +### Cloud SQL +- Get instance: GET https://sqladmin.googleapis.com/v1/projects/{project}/instances/{instance} +- Update instance: PATCH https://sqladmin.googleapis.com/v1/projects/{project}/instances/{instance} +- Set root password: PUT https://sqladmin.googleapis.com/v1/projects/{project}/instances/{instance}/users?name=root&host=%25 (body: { "password": "..." }) +- NOTE: Cloud SQL updates are long-running — the executor polls automatically +- CRITICAL: databaseFlags is a REPLACE operation. You MUST read ALL existing flags first and include them ALL in the PATCH body plus the new flag. Sending only the new flag DELETES all others. Body: { "settings": { "databaseFlags": [...allExistingFlags, { "name": "new_flag", "value": "on" }] } } +- SSL enforcement: PATCH with { "settings": { "ipConfiguration": { "requireSsl": true } } } +- Disable public IP: PATCH with { "settings": { "ipConfiguration": { "ipv4Enabled": false } } } — WARNING: may break connectivity if no private IP exists +- Enable backups: PATCH with { "settings": { "backupConfiguration": { "enabled": true, "pointInTimeRecoveryEnabled": true } } } + +### Cloud KMS +- Get crypto key: GET https://cloudkms.googleapis.com/v1/{keyName} +- Update rotation: PATCH https://cloudkms.googleapis.com/v1/{keyName} + queryParams: { "updateMask": "rotationPeriod,nextRotationTime" } +- rotationPeriod format: seconds with "s" suffix, e.g., "7776000s" for 90 days +- nextRotationTime format: RFC3339 timestamp, e.g., "2024-01-01T00:00:00Z" + +### Cloud Logging +- Get sinks: GET https://logging.googleapis.com/v2/projects/{project}/sinks +- Create sink: POST https://logging.googleapis.com/v2/projects/{project}/sinks +- Update sink: PATCH https://logging.googleapis.com/v2/projects/{project}/sinks/{sinkId} + queryParams: { "updateMask": "destination,filter" } +- Create log metric: POST https://logging.googleapis.com/v2/projects/{project}/metrics (body: { "name": "...", "filter": "...", "metricDescriptor": { "metricKind": "DELTA", "valueType": "INT64" } }) +- LOCKED_RETENTION_POLICY_NOT_SET: canAutoFix=false — locking is IRREVERSIBLE + +### Cloud Monitoring (Alert Policies) +- Create alert policy: POST https://monitoring.googleapis.com/v3/projects/{project}/alertPolicies +- List alert policies: GET https://monitoring.googleapis.com/v3/projects/{project}/alertPolicies +- For "NOT_MONITORED" findings: first create a log-based metric via Cloud Logging API, then create an alert policy referencing it +- Alert policy body example: { "displayName": "...", "conditions": [{ "displayName": "...", "conditionThreshold": { "filter": "metric.type=\"logging.googleapis.com/user/{metricName}\"", "comparison": "COMPARISON_GT", "thresholdValue": 0, "duration": "0s" } }], "combiner": "OR", "enabled": true, "notificationChannels": [] } + +### Cloud DNS +- Get managed zone: GET https://dns.googleapis.com/dns/v1/projects/{project}/managedZones/{zone} +- Update managed zone: PATCH https://dns.googleapis.com/dns/v1/projects/{project}/managedZones/{zone} +- Enable DNSSEC: PATCH with body { "dnssecConfig": { "state": "on" } } +- RSASHA1_FOR_SIGNING: canAutoFix=false — changing DNSSEC algorithms can cause DNS resolution failures + +### IAM / Resource Manager +- IMPORTANT: Use the v3 API for ALL IAM policy operations (v1 silently drops auditConfigs): +- Get IAM policy: POST https://cloudresourcemanager.googleapis.com/v3/projects/{project}:getIamPolicy (body: { "options": { "requestedPolicyVersion": 3 } }) +- Set IAM policy: POST https://cloudresourcemanager.googleapis.com/v3/projects/{project}:setIamPolicy (body: { "policy": { ...fullExistingPolicy } }) +- CRITICAL: setIamPolicy REPLACES the entire policy. You MUST include ALL existing bindings, etag, version, and auditConfigs from getIamPolicy. Only ADD or MODIFY the specific field you need. +- ALWAYS use requestedPolicyVersion: 3 in getIamPolicy — without it, auditConfigs and conditions are NOT returned. +- For audit logging: merge into "auditConfigs" array. Example: { "service": "allServices", "auditLogConfigs": [{"logType": "ADMIN_READ"}, {"logType": "DATA_READ"}, {"logType": "DATA_WRITE"}] } +- AUDIT_LOGGING_DISABLED: canAutoFix=false — auditConfigs cannot be set via setIamPolicy API. Direct user to: https://console.cloud.google.com/iam-admin/audit?project={projectId} +- MFA_NOT_ENFORCED: canAutoFix=false — requires Google Workspace admin console +- SERVICE_ACCOUNT_KEY_NOT_ROTATED: canAutoFix=false — key rotation requires distributing new keys +- USER_MANAGED_SERVICE_ACCOUNT_KEY: canAutoFix=false — requires migration to workload identity + +### VPC Network +- Get subnetwork: GET https://compute.googleapis.com/compute/v1/projects/{project}/regions/{region}/subnetworks/{subnet} +- Enable flow logs: PATCH https://compute.googleapis.com/compute/v1/projects/{project}/regions/{region}/subnetworks/{subnet} + body: { "logConfig": { "enable": true } } + +### BigQuery +- Get dataset: GET https://bigquery.googleapis.com/bigquery/v2/projects/{project}/datasets/{dataset} +- Update dataset: PATCH https://bigquery.googleapis.com/bigquery/v2/projects/{project}/datasets/{dataset} +- Remove public access: PATCH to remove "allAuthenticatedUsers" or "allUsers" from the access array +- Enable CMEK: PATCH with { "defaultEncryptionConfiguration": { "kmsKeyName": "..." } } + +### Cloud Armor / SSL Policies +- Get SSL policy: GET https://compute.googleapis.com/compute/v1/projects/{project}/global/sslPolicies/{policy} +- Update SSL policy: PATCH https://compute.googleapis.com/compute/v1/projects/{project}/global/sslPolicies/{policy} +- Fix weak SSL: PATCH with { "minTlsVersion": "TLS_1_2", "profile": "MODERN" } + +### GKE (Kubernetes Engine) +- Get cluster: GET https://container.googleapis.com/v1/projects/{project}/locations/{location}/clusters/{cluster} +- Update cluster: PUT https://container.googleapis.com/v1/projects/{project}/locations/{location}/clusters/{cluster} +- Most GKE findings are HIGH RISK. Set canAutoFix=false for: PRIVATE_CLUSTER_DISABLED, POD_SECURITY_POLICY_DISABLED, WORKLOAD_IDENTITY_DISABLED (require cluster recreation or significant downtime) +- Safe to auto-fix: LEGACY_AUTHORIZATION_ENABLED (PUT with { "legacyAbac": { "enabled": false } }), MASTER_AUTHORIZED_NETWORKS_DISABLED, CLUSTER_SHIELDED_NODES_DISABLED +- NOTE: GKE updates can take 10+ minutes and may cause brief control plane unavailability + +### Pub/Sub +- PUBSUB_CMEK_DISABLED: canAutoFix=false — CMEK cannot be added to existing topics, requires recreation + +## PARSING SCC FINDING EVIDENCE + +The finding evidence contains rich data from Security Command Center: +- resourceName: Full GCP resource path (e.g., "//storage.googleapis.com/buckets/my-bucket") +- category: SCC finding category (e.g., "PUBLIC_BUCKET_ACL", "OPEN_FIREWALL") +- projectDisplayName: GCP project name +- severity: CRITICAL, HIGH, MEDIUM, LOW +- externalUri: Link to the resource in GCP Console +- compliances: Compliance mappings (CIS, PCI-DSS, etc.) + +To convert resourceName to API URL: +- "//storage.googleapis.com/buckets/my-bucket" → https://storage.googleapis.com/storage/v1/b/my-bucket +- "//compute.googleapis.com/projects/my-proj/global/firewalls/my-fw" → https://compute.googleapis.com/compute/v1/projects/my-proj/global/firewalls/my-fw +- "//sqladmin.googleapis.com/projects/my-proj/instances/my-sql" → https://sqladmin.googleapis.com/v1/projects/my-proj/instances/my-sql +- "//cloudresourcemanager.googleapis.com/projects/my-proj" → project ID is "my-proj" + +## SAFETY RULES (NEVER violate) +- NEVER delete data, buckets, instances, databases, or VPCs +- NEVER modify IAM policies in ways that could lock out users +- NEVER remove existing firewall allow rules for private ranges (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16) +- PREFER enabling security features over disabling services +- ALWAYS make changes reversible when possible +- For firewall fixes: restrict source ranges, don't delete rules + +## IDEMPOTENCY +- All fix steps should be safe to run multiple times +- PATCH operations are naturally idempotent +- POST operations may need "already exists" handling (the executor handles 409 automatically) + +## WHEN TO SET canAutoFix=true +- Enable/disable features on existing resources (logging, encryption, uniform access) +- Update configuration (firewall source ranges, SQL flags, bucket policies) +- Enable DNSSEC, flow logs, audit logging +- Restrict public access (buckets, SQL instances, firewall rules) +- Enable key rotation +- ALWAYS provide rollback steps + +## WHEN TO SET canAutoFix=false (provide guidedSteps with gcloud commands instead) +- Resource recreation required (encryption on existing disks, shielded VM on running instance) +- Requires organizational policy changes (MFA_NOT_ENFORCED) +- Requires changing service accounts/scopes on running instances (DEFAULT_SERVICE_ACCOUNT_USED, FULL_API_ACCESS) +- Requires network architecture changes (IP_FORWARDING_ENABLED, PRIVATE_CLUSTER_DISABLED) +- Instance-level changes requiring restart (COMPUTE_SECURE_BOOT_DISABLED, SHIELDED_VM_DISABLED) +- Irreversible operations (BUCKET_LOCK_DISABLED, LOCKED_RETENTION_POLICY_NOT_SET) +- Key management lifecycle (SERVICE_ACCOUNT_KEY_NOT_ROTATED, USER_MANAGED_SERVICE_ACCOUNT_KEY) +- Resource recreation needed (PUBSUB_CMEK_DISABLED, WORKLOAD_IDENTITY_DISABLED) +- DNSSEC algorithm changes (RSASHA1_FOR_SIGNING — risk of DNS outage) +- The resource in the finding doesn't exist or has been deleted + +## RISK ASSESSMENT +- low: Enabling features with no impact (logging, DNSSEC, key rotation) +- medium: Restricting access patterns (firewall rules, public access prevention) +- high: Changes affecting production traffic or database settings +- critical: Irreversible changes or IAM modifications + +## ROLLBACK PATTERNS +- PATCH operations: rollback by PATCHing back to original values (from read step) +- POST (create): rollback by DELETE (only for resources WE created) +- IAM changes: rollback by setting back the original policy (ALWAYS read first) +- Use the readStep results to capture the exact previous state for rollback + +## GUIDED STEPS FORMAT (when canAutoFix=false) +When you set canAutoFix=false, you MUST provide clear guidedSteps: +- Each step should be SHORT and clear — one action per step +- Start with opening the GCP Console link (use externalUri from evidence if available) +- Include the specific GCP Console path: e.g., "Go to IAM & Admin > Audit Logs" +- Include gcloud CLI commands when applicable, wrapped in triple backticks +- Keep each step under 2-3 sentences + 1 command block +- Reference the specific project name from the evidence +- End with how to verify the fix was applied + +## CRITICAL RULES +1. ALWAYS use readSteps to get the CURRENT state before fixing +2. NEVER use placeholder values — use concrete values from evidence +3. For PATCH requests, ALWAYS specify updateMask in queryParams +4. URLs must start with https:// and contain googleapis.com +5. currentState and proposedState must use the SAME keys for comparison +6. The fix must address the EXACT issue the SCC finding reports +7. For getIamPolicy: ALWAYS include body { "options": { "requestedPolicyVersion": 3 } } — without this, auditConfigs and conditions are NOT returned +8. For setIamPolicy: body MUST be { "policy": }. Include etag, version, ALL bindings, and auditConfigs. A partial policy DELETES everything not included`; + +// ─── Prompt Builders ──────────────────────────────────────────────────────── + +export function buildGcpFixPlanPrompt(finding: { + title: string; + description: string | null; + severity: string | null; + resourceType: string; + resourceId: string; + remediation: string | null; + findingKey: string; + evidence: Record; +}): string { + return `Analyze this GCP Security Command Center finding and generate a fix plan using GCP REST API calls. + +IMPORTANT: Your fix must change the EXACT GCP resource/setting that caused this finding. The SCC will re-check the same thing. + +FINDING: +- Title: ${finding.title} +- Description: ${finding.description ?? 'N/A'} +- Severity: ${finding.severity ?? 'medium'} +- Resource Type: ${finding.resourceType} +- Resource ID: ${finding.resourceId} +- Finding Key: ${finding.findingKey} +- Existing Remediation Guidance: ${finding.remediation ?? 'None'} +- Evidence: ${JSON.stringify(finding.evidence, null, 2)} + +Generate the fix plan following all the rules in your instructions.`; +} diff --git a/apps/api/src/cloud-security/gcp-command-executor.ts b/apps/api/src/cloud-security/gcp-command-executor.ts new file mode 100644 index 0000000000..eab8d8e93b --- /dev/null +++ b/apps/api/src/cloud-security/gcp-command-executor.ts @@ -0,0 +1,440 @@ +import { Logger } from '@nestjs/common'; + +const logger = new Logger('GcpCommandExecutor'); + +const MAX_STEP_RETRIES = 3; +const MAX_POLL_MS = 120_000; + +// ─── Types ───────────────────────────────────────────────────────────────── + +export interface GcpApiStep { + method: 'GET' | 'POST' | 'PUT' | 'PATCH' | 'DELETE'; + url: string; + body?: Record; + queryParams?: Record; + purpose: string; +} + +interface GcpStepResult { + step: GcpApiStep; + output: Record; +} + +export interface GcpExecutionResult { + results: GcpStepResult[]; + error?: { + stepIndex: number; + step: GcpApiStep; + message: string; + }; +} + +// ─── Multi-Step Execution ────────────────────────────────────────────────── + +/** + * Execute GCP API steps sequentially with self-healing: + * - Retries on 429 throttling and 5xx server errors + * - Auto-enables disabled GCP APIs + * - Polls long-running operations + * - Auto-rolls back on partial failure + */ +export async function executeGcpPlanSteps(params: { + steps: GcpApiStep[]; + accessToken: string; + autoRollbackSteps?: GcpApiStep[]; + isRollback?: boolean; +}): Promise { + // Validate ALL step URLs before executing any — prevents SSRF on read/fix/rollback steps + const allSteps = [...params.steps, ...(params.autoRollbackSteps ?? [])]; + const validationErrors = validateGcpPlanSteps(allSteps); + if (validationErrors.length > 0) { + return { + results: [], + error: { + stepIndex: 0, + step: params.steps[0] ?? allSteps[0], + message: `URL validation failed: ${validationErrors.join('; ')}`, + }, + }; + } + + const results: GcpStepResult[] = []; + + for (let i = 0; i < params.steps.length; i++) { + const step = params.steps[i]; + try { + const output = await executeWithRetry( + step, + params.accessToken, + params.isRollback, + ); + results.push({ step, output }); + } catch (error) { + const message = error instanceof Error ? error.message : String(error); + logger.error( + `Step ${i + 1} failed: ${step.method} ${step.url} — ${message}`, + ); + + // Auto-rollback completed steps + if (params.autoRollbackSteps && i > 0) { + logger.log(`Auto-rolling back ${i} completed steps...`); + for ( + let j = Math.min(i - 1, params.autoRollbackSteps.length - 1); + j >= 0; + j-- + ) { + try { + await executeWithRetry( + params.autoRollbackSteps[j], + params.accessToken, + true, + ); + logger.log(`Rollback step ${j} succeeded`); + } catch (rbErr) { + logger.warn( + `Rollback step ${j} failed: ${rbErr instanceof Error ? rbErr.message : String(rbErr)}`, + ); + } + } + } + + return { results, error: { stepIndex: i, step, message } }; + } + } + + return { results }; +} + +// ─── Single Step with Retry ──────────────────────────────────────────────── + +async function executeWithRetry( + step: GcpApiStep, + accessToken: string, + isRollback?: boolean, +): Promise> { + if (step.method === 'DELETE' && !isRollback) { + throw new Error( + `DELETE operations are blocked for safety. Step: ${step.purpose}`, + ); + } + + for (let attempt = 0; attempt < MAX_STEP_RETRIES; attempt++) { + try { + return await executeOnce(step, accessToken); + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + const canRetry = attempt < MAX_STEP_RETRIES - 1; + + // 429 Throttled → wait and retry + if (msg.includes('429') && canRetry) { + const delay = 3000 * (attempt + 1); + logger.warn( + `Throttled (429), waiting ${delay}ms before retry ${attempt + 1}`, + ); + await new Promise((r) => setTimeout(r, delay)); + continue; + } + + // 5xx Server error → wait and retry + if (/50[0-9]/.test(msg) && canRetry) { + logger.warn(`Server error, retrying in 2s...`); + await new Promise((r) => setTimeout(r, 2000)); + continue; + } + + // API not enabled → auto-enable and retry + if ( + (msg.includes('has not been used') || + msg.includes('is not enabled') || + msg.includes('SERVICE_DISABLED')) && + canRetry + ) { + const apiMatch = msg.match(/([\w.-]+\.googleapis\.com)/); + if (apiMatch) { + await enableGcpApi(accessToken, step.url, apiMatch[1]); + continue; + } + } + + // Resource in progress → wait and retry + if ( + (msg.includes('RESOURCE_IN_USE') || + msg.includes('already being') || + msg.includes('operation is in progress')) && + canRetry + ) { + logger.warn('Resource busy, waiting 10s...'); + await new Promise((r) => setTimeout(r, 10_000)); + continue; + } + + // Not retryable → throw + throw error; + } + } + + throw new Error('Max retries exceeded'); +} + +// ─── Single API Call ─────────────────────────────────────────────────────── + +async function executeOnce( + step: GcpApiStep, + accessToken: string, +): Promise> { + let url = step.url; + + // Auto-upgrade CRM v1 IAM policy URLs to v3 (v1 silently drops auditConfigs) + if ( + url.includes('cloudresourcemanager.googleapis.com/v1/projects/') && + (url.includes(':getIamPolicy') || url.includes(':setIamPolicy')) + ) { + url = url.replace( + 'cloudresourcemanager.googleapis.com/v1/projects/', + 'cloudresourcemanager.googleapis.com/v3/projects/', + ); + } + + if (step.queryParams && Object.keys(step.queryParams).length > 0) { + const qs = new URLSearchParams(step.queryParams); + url += (url.includes('?') ? '&' : '?') + qs.toString(); + } + + // Auto-inject requestedPolicyVersion: 3 for getIamPolicy calls + // Without this, GCP returns v1 policies that omit auditConfigs and conditions + let effectiveBody = step.body; + if ( + step.method === 'POST' && + step.url.includes(':getIamPolicy') && + (!step.body || !(step.body as Record).options) + ) { + effectiveBody = { + ...step.body, + options: { requestedPolicyVersion: 3 }, + }; + } + + logger.log(`${step.method} ${url} — ${step.purpose}`); + if (effectiveBody && (step.method === 'POST' || step.method === 'PUT' || step.method === 'PATCH')) { + const bodyStr = JSON.stringify(effectiveBody); + logger.debug(` Body (${bodyStr.length} chars): ${bodyStr.substring(0, 2000)}${bodyStr.length > 2000 ? '...' : ''}`); + } + + const response = await fetch(url, { + method: step.method, + headers: { + Authorization: `Bearer ${accessToken}`, + 'Content-Type': 'application/json', + }, + body: effectiveBody ? JSON.stringify(effectiveBody) : undefined, + }); + + if (!response.ok) { + return handleErrorResponse(response, step); + } + + if (response.status === 204) { + return { success: true }; + } + + let data: Record; + try { + data = (await response.json()) as Record; + } catch { + return { success: true }; + } + + // Poll long-running operations + if (isGcpOperation(data)) { + return waitForOperation(data, accessToken); + } + + return data; +} + +async function handleErrorResponse( + response: Response, + step: GcpApiStep, +): Promise> { + let errorBody: Record = {}; + const rawText = await response.text(); + try { + errorBody = JSON.parse(rawText) as Record; + } catch { + errorBody = { message: rawText }; + } + + const gcpError = errorBody.error as Record | undefined; + const errorMessage = + (gcpError?.message as string) ?? JSON.stringify(errorBody).slice(0, 500); + const errorStatus = (gcpError?.status as string) ?? ''; + + // 409 = already exists → treat as success (idempotent) + if (response.status === 409 || errorStatus === 'ALREADY_EXISTS') { + logger.log(`Already exists (success): ${step.purpose}`); + return { _alreadyExists: true, status: 409 }; + } + + // 404 on GET = resource not found (useful for read steps) + if (response.status === 404 && step.method === 'GET') { + return { _notFound: true, status: 404 }; + } + + if (response.status === 401) { + throw new Error( + 'GCP authentication failed. Access token may have expired. Please reconnect.', + ); + } + + if (response.status === 403 || errorStatus === 'PERMISSION_DENIED') { + throw new Error(`Permission denied: ${errorMessage}`); + } + + // Include status code in error for retry logic detection + throw new Error(`GCP API error (${response.status}): ${errorMessage}`); +} + +// ─── GCP API Auto-Enable ───────────────────────────────────────────────── + +async function enableGcpApi( + accessToken: string, + stepUrl: string, + apiName: string, +): Promise { + // Extract project ID from the step URL + const projectMatch = stepUrl.match(/\/projects\/([^/]+)/); + if (!projectMatch) { + logger.warn(`Cannot extract project ID from URL to enable API: ${apiName}`); + return; + } + + logger.log(`Auto-enabling GCP API: ${apiName} in project ${projectMatch[1]}`); + try { + const resp = await fetch( + `https://serviceusage.googleapis.com/v1/projects/${projectMatch[1]}/services/${apiName}:enable`, + { + method: 'POST', + headers: { + Authorization: `Bearer ${accessToken}`, + 'Content-Type': 'application/json', + }, + }, + ); + + if (resp.ok || resp.status === 409) { + logger.log(`API ${apiName} enabled — waiting 10s for propagation`); + await new Promise((r) => setTimeout(r, 10_000)); + } else { + logger.warn(`Failed to enable ${apiName}: ${resp.status}`); + } + } catch (err) { + logger.warn( + `API enablement error: ${err instanceof Error ? err.message : String(err)}`, + ); + } +} + +// ─── Long-Running Operation Polling ──────────────────────────────────────── + +function isGcpOperation(data: Record): boolean { + const kind = data.kind as string | undefined; + if (kind && kind.includes('#operation')) return true; + if (data.operationType && data.status) return true; + return false; +} + +async function waitForOperation( + operation: Record, + accessToken: string, +): Promise> { + const selfLink = operation.selfLink as string; + if (!selfLink) { + logger.warn('Operation has no selfLink — returning without polling'); + return operation; + } + + // Validate selfLink URL to prevent SSRF via response data + try { + const parsed = new URL(selfLink); + const host = parsed.hostname.toLowerCase(); + if (host !== 'googleapis.com' && !host.endsWith('.googleapis.com')) { + logger.warn(`Operation selfLink targets disallowed host: ${host}`); + return operation; + } + } catch { + logger.warn('Operation selfLink is malformed'); + return operation; + } + + const startTime = Date.now(); + let pollInterval = 2000; + + while (Date.now() - startTime < MAX_POLL_MS) { + await new Promise((r) => setTimeout(r, pollInterval)); + pollInterval = Math.min(pollInterval * 1.5, 10_000); + + const resp = await fetch(selfLink, { + headers: { Authorization: `Bearer ${accessToken}` }, + }); + + if (!resp.ok) { + logger.warn(`Operation poll failed: ${resp.status}`); + return operation; + } + + const updated = (await resp.json()) as Record; + if (updated.status === 'DONE') { + if (updated.error) { + const errors = (updated.error as Record).errors as + | Array<{ message: string }> + | undefined; + if (errors?.length) { + throw new Error(`GCP operation failed: ${errors[0].message}`); + } + } + return updated; + } + } + + logger.warn(`Operation timed out after ${MAX_POLL_MS}ms`); + return operation; +} + +// ─── Validation ──────────────────────────────────────────────────────────── + +export function validateGcpPlanSteps(steps: GcpApiStep[]): string[] { + const errors: string[] = []; + for (let i = 0; i < steps.length; i++) { + const step = steps[i]; + if (!step.url) { + errors.push(`Step ${i + 1}: URL is required`); + continue; + } + if (!step.method) errors.push(`Step ${i + 1}: method is required`); + // POST/PUT/PATCH to mutation endpoints must have a body + if ( + (step.method === 'POST' || step.method === 'PUT' || step.method === 'PATCH') && + !step.url.includes(':getIamPolicy') && + !step.url.includes('/stop') && + !step.url.includes('/start') && + (!step.body || Object.keys(step.body).length === 0) + ) { + errors.push( + `Step ${i + 1}: ${step.method} ${step.url.split('/').pop()} requires a request body but none was provided`, + ); + } + try { + const parsed = new URL(step.url); + if (parsed.protocol !== 'https:') { + errors.push(`Step ${i + 1}: URL must use HTTPS`); + } + const host = parsed.hostname.toLowerCase(); + if (host !== 'googleapis.com' && !host.endsWith('.googleapis.com')) { + errors.push(`Step ${i + 1}: URL must be a Google API endpoint`); + } + } catch { + errors.push(`Step ${i + 1}: URL must be a valid absolute URL`); + } + } + return errors; +} diff --git a/apps/api/src/cloud-security/gcp-remediation.service.ts b/apps/api/src/cloud-security/gcp-remediation.service.ts new file mode 100644 index 0000000000..079d439c5d --- /dev/null +++ b/apps/api/src/cloud-security/gcp-remediation.service.ts @@ -0,0 +1,723 @@ +import { Injectable, Logger } from '@nestjs/common'; +import { db, Prisma } from '@db'; +import { getManifest } from '@trycompai/integration-platform'; +import { CredentialVaultService } from '../integration-platform/services/credential-vault.service'; +import { OAuthCredentialsService } from '../integration-platform/services/oauth-credentials.service'; +import { AiRemediationService } from './ai-remediation.service'; +import { parseGcpPermissionError } from './remediation-error.utils'; +import { + executeGcpPlanSteps, + validateGcpPlanSteps, +} from './gcp-command-executor'; +import type { GcpFixPlan, GcpApiStep } from './gcp-ai-remediation.prompt'; + +@Injectable() +export class GcpRemediationService { + private readonly logger = new Logger(GcpRemediationService.name); + private readonly planCache = new Map< + string, + { plan: GcpFixPlan; timestamp: number } + >(); + private readonly PLAN_CACHE_MAX = 100; + private readonly PLAN_CACHE_TTL = 5 * 60 * 1000; + + private evictStalePlans() { + if (this.planCache.size <= this.PLAN_CACHE_MAX) return; + const now = Date.now(); + for (const [key, entry] of this.planCache) { + if (now - entry.timestamp > this.PLAN_CACHE_TTL) + this.planCache.delete(key); + } + while (this.planCache.size > this.PLAN_CACHE_MAX) { + const firstKey = this.planCache.keys().next().value; + if (firstKey) this.planCache.delete(firstKey); + else break; + } + } + + constructor( + private readonly credentialVaultService: CredentialVaultService, + private readonly oauthCredentialsService: OAuthCredentialsService, + private readonly aiRemediationService: AiRemediationService, + ) {} + + async getCapabilities(params: { + connectionId: string; + organizationId: string; + }) { + const credentials = + await this.credentialVaultService.getDecryptedCredentials( + params.connectionId, + ); + + return { + enabled: Boolean(credentials?.access_token), + aiPowered: true, + remediations: [], + }; + } + + async previewRemediation(params: { + connectionId: string; + organizationId: string; + checkResultId: string; + remediationKey: string; + }) { + const { finding, accessToken } = await this.resolveContext(params); + const evidence = (finding.evidence ?? {}) as Record; + const findingKey = evidence.findingKey as string; + + const plan = await this.aiRemediationService.generateGcpFixPlan({ + title: finding.title ?? 'Unknown', + description: finding.description, + severity: finding.severity, + resourceType: finding.resourceType, + resourceId: finding.resourceId, + remediation: finding.remediation, + findingKey, + evidence, + }); + + if (!plan.canAutoFix) { + return { + currentState: plan.currentState, + proposedState: {}, + description: plan.description, + risk: plan.risk, + apiCalls: [], + guidedOnly: true, + guidedSteps: plan.guidedSteps ?? [plan.reason ?? plan.description], + rollbackSupported: false, + requiresAcknowledgment: undefined, + }; + } + + // Execute read steps to get real GCP state + if (plan.readSteps.length > 0) { + const readErrors = validateGcpPlanSteps(plan.readSteps); + if (readErrors.length === 0) { + try { + const readResult = await executeGcpPlanSteps({ + steps: plan.readSteps, + accessToken, + }); + const realState = readResult.results.reduce( + (acc, r) => ({ ...acc, [r.step.purpose]: r.output }), + {} as Record, + ); + + const refined = await this.aiRemediationService.refineGcpFixPlan({ + finding: { + title: finding.title ?? 'Unknown', + description: finding.description, + severity: finding.severity, + resourceType: finding.resourceType, + resourceId: finding.resourceId, + remediation: finding.remediation, + findingKey, + evidence, + }, + originalPlan: plan, + realGcpState: realState, + }); + + if (!refined.canAutoFix) { + return { + currentState: refined.currentState, + proposedState: {}, + description: refined.description, + risk: refined.risk, + apiCalls: [], + guidedOnly: true, + guidedSteps: refined.guidedSteps ?? [ + refined.reason ?? refined.description, + ], + rollbackSupported: false, + requiresAcknowledgment: undefined, + }; + } + + this.evictStalePlans(); + this.planCache.set( + `${params.connectionId}:${params.checkResultId}:${params.remediationKey}`, + { + plan: refined, + timestamp: Date.now(), + }, + ); + + return this.buildPreviewResponse(refined); + } catch { + // Fall through to show initial plan + } + } + } + + // Fallback: show initial AI plan without real data + this.evictStalePlans(); + this.planCache.set( + `${params.connectionId}:${params.checkResultId}:${params.remediationKey}`, + { + plan, + timestamp: Date.now(), + }, + ); + return this.buildPreviewResponse(plan); + } + + async executeRemediation(params: { + connectionId: string; + organizationId: string; + checkResultId: string; + remediationKey: string; + userId: string; + acknowledgment?: string; + }) { + const { finding, accessToken } = await this.resolveContext(params); + + // Get plan from cache or regenerate + let plan: GcpFixPlan; + const cached = this.planCache.get( + `${params.connectionId}:${params.checkResultId}:${params.remediationKey}`, + ); + if (cached && Date.now() - cached.timestamp < 5 * 60 * 1000) { + plan = cached.plan; + } else { + const evidence = (finding.evidence ?? {}) as Record; + plan = await this.aiRemediationService.generateGcpFixPlan({ + title: finding.title ?? 'Unknown', + description: finding.description, + severity: finding.severity, + resourceType: finding.resourceType, + resourceId: finding.resourceId, + remediation: finding.remediation, + findingKey: evidence.findingKey as string, + evidence, + }); + } + + if (!plan.canAutoFix) { + throw new Error( + 'This finding requires manual remediation and cannot be auto-fixed.', + ); + } + if (!plan.fixSteps || plan.fixSteps.length === 0) { + throw new Error('AI generated an empty fix plan. Cannot proceed.'); + } + if (!params.acknowledgment || params.acknowledgment !== 'acknowledged') { + throw new Error( + 'Acknowledgment is required before executing any remediation.', + ); + } + + const action = await db.remediationAction.create({ + data: { + checkResultId: params.checkResultId, + connectionId: params.connectionId, + organizationId: params.organizationId, + initiatedById: params.userId, + remediationKey: params.remediationKey, + resourceId: finding.resourceId, + resourceType: finding.resourceType, + previousState: {}, + appliedState: {}, + status: 'executing', + riskLevel: plan.risk, + acknowledgmentText: params.acknowledgment ?? null, + acknowledgedAt: params.acknowledgment ? new Date() : null, + }, + }); + + let previousState: Record = {}; + let fixResult: { results: Array<{ step: GcpApiStep; output: unknown }>; error?: { stepIndex: number; step: GcpApiStep; message: string } } | undefined; + + try { + // Phase 1: Execute read steps to get real state + if (plan.readSteps.length > 0) { + const readErrors = validateGcpPlanSteps(plan.readSteps); + if (readErrors.length > 0) { + throw new Error(`Invalid read steps: ${readErrors.join('; ')}`); + } + const readResult = await executeGcpPlanSteps({ + steps: plan.readSteps, + accessToken, + }); + previousState = readResult.results.reduce( + (acc, r) => ({ ...acc, [r.step.purpose]: r.output }), + {} as Record, + ); + } + + // Phase 2: Refine plan with real data + const evidence = (finding.evidence ?? {}) as Record; + let refinedPlan = await this.aiRemediationService.refineGcpFixPlan({ + finding: { + title: finding.title ?? 'Unknown', + description: finding.description, + severity: finding.severity, + resourceType: finding.resourceType, + resourceId: finding.resourceId, + remediation: finding.remediation, + findingKey: evidence.findingKey as string, + evidence, + }, + originalPlan: plan, + realGcpState: previousState, + }); + + if (!refinedPlan.canAutoFix) { + await db.remediationAction.update({ + where: { id: action.id }, + data: { + status: 'failed', + errorMessage: refinedPlan.reason ?? 'Cannot be auto-fixed.', + }, + }); + return { + actionId: action.id, + status: 'failed' as const, + resourceId: finding.resourceId, + error: + refinedPlan.reason ?? + 'This finding requires manual setup before auto-fix is possible.', + guidedSteps: refinedPlan.guidedSteps, + }; + } + + if (!refinedPlan.fixSteps || refinedPlan.fixSteps.length === 0) { + throw new Error('AI refined plan has no fix steps. Cannot proceed.'); + } + let fixErrors = validateGcpPlanSteps(refinedPlan.fixSteps); + if (fixErrors.length > 0) { + this.logger.warn( + `Fix plan validation failed: ${fixErrors.join('; ')} — retrying with error context`, + ); + const retryPlan = await this.aiRemediationService.refineGcpFixPlan({ + finding: { + title: finding.title ?? 'Unknown', + description: finding.description, + severity: finding.severity, + resourceType: finding.resourceType, + resourceId: finding.resourceId, + remediation: finding.remediation, + findingKey: evidence.findingKey as string, + evidence, + }, + originalPlan: refinedPlan, + realGcpState: { + ...previousState, + _validationErrors: fixErrors, + }, + }); + refinedPlan = retryPlan; + fixErrors = validateGcpPlanSteps(refinedPlan.fixSteps); + if (fixErrors.length > 0) { + throw new Error(`Invalid fix steps after retry: ${fixErrors.join('; ')}`); + } + } + + // Phase 3: Execute fix steps with self-healing retry + // (executor auto-handles: API enablement, throttling, retries, long-running ops) + for (const step of refinedPlan.fixSteps) { + this.logger.log( + `Fix step: ${step.method} ${step.url} — ${step.purpose}`, + ); + } + + let currentPlan = refinedPlan; + fixResult = await executeGcpPlanSteps({ + steps: currentPlan.fixSteps, + accessToken, + autoRollbackSteps: currentPlan.rollbackSteps, + }); + + // Self-healing: if non-permission error, regenerate plan with error context and retry + if (fixResult.error) { + const isPermError = + fixResult.error.message.includes('Permission denied') || + fixResult.error.message.includes('PERMISSION_DENIED'); + + if (!isPermError) { + this.logger.log( + 'Non-permission error — regenerating fix plan with error context...', + ); + const retryPlan = await this.aiRemediationService.refineGcpFixPlan({ + finding: { + title: finding.title ?? 'Unknown', + description: finding.description, + severity: finding.severity, + resourceType: finding.resourceType, + resourceId: finding.resourceId, + remediation: finding.remediation, + findingKey: evidence.findingKey as string, + evidence, + }, + originalPlan: currentPlan, + realGcpState: { + ...previousState, + _lastError: fixResult.error.message, + _failedStep: fixResult.error.step, + }, + }); + + if (retryPlan.canAutoFix && retryPlan.fixSteps.length > 0) { + this.logger.log( + `Retrying with regenerated plan (${retryPlan.fixSteps.length} steps)...`, + ); + currentPlan = retryPlan; + fixResult = await executeGcpPlanSteps({ + steps: currentPlan.fixSteps, + accessToken, + autoRollbackSteps: currentPlan.rollbackSteps, + }); + } + } + } + + if (fixResult.error) { + throw new Error(fixResult.error.message); + } + + // Log step results + for (const r of fixResult.results) { + this.logger.log(`Step result: ${r.step.method} ${r.step.url} → OK`); + } + + // Phase 4: Verify — check the fix step responses for success indicators + let verified = false; + + // Primary verification: check if the API response from the fix step + // contains the expected changes (e.g., setIamPolicy returns the updated policy) + for (const r of fixResult.results) { + const output = r.output as Record | undefined; + if (!output) continue; + // setIamPolicy returns the updated policy — check if auditConfigs present + if ( + r.step.url.includes(':setIamPolicy') && + Array.isArray(output.auditConfigs) && + (output.auditConfigs as unknown[]).length > 0 + ) { + verified = true; + } + // Generic: if the API returned a non-empty response, the call succeeded + if (Object.keys(output).length > 0 && !verified) { + verified = true; + } + } + + // Fallback verification: re-read and compare (for non-IAM fixes) + if (!verified && currentPlan.readSteps.length > 0) { + await new Promise((r) => setTimeout(r, 2000)); + const verifyResult = await executeGcpPlanSteps({ + steps: currentPlan.readSteps, + accessToken, + }); + const postFixState: Record = {}; + for (const r of verifyResult.results) { + postFixState[r.step.purpose] = r.output; + } + const stripVolatile = (obj: unknown): unknown => { + if (!obj || typeof obj !== 'object') return obj; + if (Array.isArray(obj)) return obj.map(stripVolatile); + const cleaned: Record = {}; + for (const [k, v] of Object.entries( + obj as Record, + )) { + if (k === 'etag' || k === 'updateTime' || k === 'createTime') + continue; + cleaned[k] = stripVolatile(v); + } + return cleaned; + }; + const preStr = JSON.stringify(stripVolatile(previousState)); + const postStr = JSON.stringify(stripVolatile(postFixState)); + verified = postStr !== preStr; + if (!verified) { + this.logger.warn( + `Fix executed but verification shows no state change for ${finding.resourceId}`, + ); + } + } + + const appliedState = { + steps: fixResult.results.map((r) => ({ + command: `${r.step.method} ${r.step.url}`, + purpose: r.step.purpose, + output: r.output, + })), + rollbackSteps: currentPlan.rollbackSteps, + verified, + }; + + const status = verified ? 'success' : 'unverified'; + await db.remediationAction.update({ + where: { id: action.id }, + data: { + status, + previousState: previousState as Prisma.InputJsonValue, + appliedState: appliedState as unknown as Prisma.InputJsonValue, + executedAt: new Date(), + }, + }); + + this.logger.log( + `GCP remediation executed on ${finding.resourceId} (verified: ${verified})`, + ); + this.planCache.delete( + `${params.connectionId}:${params.checkResultId}:${params.remediationKey}`, + ); + + return { + actionId: action.id, + status: status, + resourceId: finding.resourceId, + previousState, + appliedState, + }; + } catch (error) { + const errorMessage = + error instanceof Error ? error.message : String(error); + + // Parse GCP permission errors and provide actionable fix + const evidence = (finding.evidence ?? {}) as Record; + const projectId = (evidence.projectDisplayName as string) ?? undefined; + const permInfo = parseGcpPermissionError(errorMessage, projectId); + + let permissionError: + | { missingActions: string[]; fixScript?: string } + | undefined; + if (permInfo.isPermissionError) { + permissionError = { + missingActions: permInfo.missingPermissions, + ...(permInfo.fixScript && { fixScript: permInfo.fixScript }), + }; + } + + const hasAutoRollback = Boolean( + fixResult?.error && fixResult.results.length > 0, + ); + + await db.remediationAction.update({ + where: { id: action.id }, + data: { + status: 'failed', + errorMessage, + previousState: previousState as Prisma.InputJsonValue, + appliedState: { + autoRollbackAttempted: hasAutoRollback, + failedAtStep: fixResult?.error?.stepIndex, + completedSteps: fixResult?.results.length ?? 0, + ...(permissionError && { + missingPermissions: permissionError.missingActions, + suggestedFix: permissionError.fixScript, + }), + } as unknown as Prisma.InputJsonValue, + }, + }); + + this.logger.error( + `GCP remediation failed: ${errorMessage}${hasAutoRollback ? ' (auto-rollback attempted)' : ''}${permInfo.isPermissionError ? ` | Missing: ${permInfo.missingPermissions.join(', ')}` : ''}`, + ); + + return { + actionId: action.id, + status: 'failed' as const, + resourceId: finding.resourceId, + error: errorMessage, + ...(permissionError && { permissionError }), + }; + } + } + + async rollbackRemediation(params: { + actionId: string; + organizationId: string; + }) { + const action = await db.remediationAction.findFirst({ + where: { id: params.actionId, organizationId: params.organizationId }, + }); + + if (!action) throw new Error('Remediation action not found'); + if (action.status !== 'success' && action.status !== 'unverified') { + throw new Error(`Cannot rollback action with status "${action.status}"`); + } + + const appliedState = action.appliedState as Record; + const rollbackSteps = (appliedState.rollbackSteps ?? []) as GcpApiStep[]; + + if (rollbackSteps.length === 0) { + throw new Error('No rollback steps available for this action'); + } + + const accessToken = await this.getValidGcpToken( + action.connectionId, + action.organizationId, + ); + + try { + this.logger.log( + `Rolling back GCP action ${action.id}: ${rollbackSteps.length} steps`, + ); + for (const step of rollbackSteps) { + this.logger.log( + `Rollback step: ${step.method} ${step.url} — ${step.purpose}`, + ); + } + + const result = await executeGcpPlanSteps({ + steps: rollbackSteps, + accessToken, + isRollback: true, + }); + + // Log each rollback step result + for (const r of result.results) { + this.logger.log(`Rollback result: ${r.step.method} ${r.step.url} → OK`); + } + + if (result.error) throw new Error(result.error.message); + + await db.remediationAction.update({ + where: { id: action.id }, + data: { status: 'rolled_back', rolledBackAt: new Date() }, + }); + + this.logger.log( + `GCP rollback: ${action.remediationKey} on ${action.resourceId}`, + ); + + return { + status: 'rolled_back' as const, + connectionId: action.connectionId, + remediationKey: action.remediationKey, + resourceId: action.resourceId, + }; + } catch (error) { + const errorMessage = + error instanceof Error ? error.message : String(error); + + await db.remediationAction.update({ + where: { id: action.id }, + data: { + status: 'rollback_failed', + errorMessage: `Rollback failed: ${errorMessage}`, + }, + }); + + // If permission error, include actionable info + const permInfo = parseGcpPermissionError(errorMessage); + if (permInfo.isPermissionError) { + throw new Error( + JSON.stringify({ + message: 'Rollback failed: missing permissions', + missingActions: permInfo.missingPermissions, + script: permInfo.fixScript, + }), + ); + } + + throw new Error(`Rollback failed: ${errorMessage}`); + } + } + + // ─── Private helpers ────────────────────────────────────────────────── + + private async resolveContext(params: { + connectionId: string; + organizationId: string; + checkResultId: string; + remediationKey: string; + }) { + const connection = await db.integrationConnection.findFirst({ + where: { + id: params.connectionId, + organizationId: params.organizationId, + status: 'active', + }, + include: { provider: true }, + }); + if (!connection) throw new Error('Connection not found or inactive'); + if (connection.provider.slug !== 'gcp') { + throw new Error('This service only handles GCP connections'); + } + + const finding = await db.integrationCheckResult.findFirst({ + where: { + id: params.checkResultId, + checkRun: { connectionId: params.connectionId }, + }, + }); + if (!finding) throw new Error('Finding not found'); + + const accessToken = await this.getValidGcpToken( + params.connectionId, + params.organizationId, + ); + + return { finding, accessToken }; + } + + /** + * Get a valid GCP access token, refreshing if expired. + */ + private async getValidGcpToken( + connectionId: string, + organizationId: string, + ): Promise { + const manifest = getManifest('gcp'); + const oauthConfig = manifest?.auth?.type === 'oauth2' ? manifest.auth.config : null; + + if (oauthConfig) { + const oauthCreds = await this.oauthCredentialsService.getCredentials( + 'gcp', + organizationId, + ); + if (oauthCreds) { + const token = await this.credentialVaultService.getValidAccessToken( + connectionId, + { + tokenUrl: oauthConfig.tokenUrl, + clientId: oauthCreds.clientId, + clientSecret: oauthCreds.clientSecret, + clientAuthMethod: oauthConfig.clientAuthMethod, + }, + ); + if (token) return token; + } + } + + // Fallback to raw credentials if refresh fails + const credentials = + await this.credentialVaultService.getDecryptedCredentials(connectionId); + const token = credentials?.access_token as string; + if (!token) { + throw new Error( + 'GCP access token not found. Please reconnect the integration.', + ); + } + return token; + } + + private buildPreviewResponse(plan: GcpFixPlan) { + const apiCalls = plan.fixSteps.map((s) => { + try { + return `${s.method} ${new URL(s.url).pathname}`; + } catch { + return `${s.method} ${s.url}`; + } + }); + + return { + currentState: plan.currentState, + proposedState: plan.proposedState, + description: plan.description, + risk: plan.risk, + apiCalls, + guidedOnly: false, + rollbackSupported: plan.rollbackSupported, + requiresAcknowledgment: 'checkbox' as const, + acknowledgmentMessage: + 'This fix will modify your GCP infrastructure. Please review the changes above before proceeding.', + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws-security.service.ts b/apps/api/src/cloud-security/providers/aws-security.service.ts index adea6ca2b1..a00b676e67 100644 --- a/apps/api/src/cloud-security/providers/aws-security.service.ts +++ b/apps/api/src/cloud-security/providers/aws-security.service.ts @@ -1,25 +1,114 @@ import { Injectable, Logger } from '@nestjs/common'; import { AssumeRoleCommand, STSClient } from '@aws-sdk/client-sts'; import { - GetFindingsCommand, - SecurityHubClient, - type GetFindingsCommandInput, -} from '@aws-sdk/client-securityhub'; + CostExplorerClient, + GetCostAndUsageCommand, +} from '@aws-sdk/client-cost-explorer'; import type { SecurityFinding } from '../cloud-security.service'; - -type AwsCredentials = { - accessKeyId: string; - secretAccessKey: string; - sessionToken?: string; -}; +import type { + AwsCredentials, + AwsServiceAdapter, +} from './aws/aws-service-adapter'; +import { IamAdapter } from './aws/iam.adapter'; +import { CloudTrailAdapter } from './aws/cloudtrail.adapter'; +import { S3Adapter } from './aws/s3.adapter'; +import { Ec2VpcAdapter } from './aws/ec2-vpc.adapter'; +import { RdsAdapter } from './aws/rds.adapter'; +import { KmsAdapter } from './aws/kms.adapter'; +import { CloudWatchAdapter } from './aws/cloudwatch.adapter'; +import { ConfigAdapter } from './aws/config.adapter'; +import { GuardDutyAdapter } from './aws/guardduty.adapter'; +import { SecretsManagerAdapter } from './aws/secrets-manager.adapter'; +import { WafAdapter } from './aws/waf.adapter'; +import { ElbAdapter } from './aws/elb.adapter'; +import { AcmAdapter } from './aws/acm.adapter'; +import { BackupAdapter } from './aws/backup.adapter'; +import { InspectorAdapter } from './aws/inspector.adapter'; +import { EcsEksAdapter } from './aws/ecs-eks.adapter'; +import { LambdaAdapter } from './aws/lambda.adapter'; +import { DynamoDbAdapter } from './aws/dynamodb.adapter'; +import { SnsSqsAdapter } from './aws/sns-sqs.adapter'; +import { EcrAdapter } from './aws/ecr.adapter'; +import { OpenSearchAdapter } from './aws/opensearch.adapter'; +import { RedshiftAdapter } from './aws/redshift.adapter'; +import { MacieAdapter } from './aws/macie.adapter'; +import { Route53Adapter } from './aws/route53.adapter'; +import { ApiGatewayAdapter } from './aws/api-gateway.adapter'; +import { CloudFrontAdapter } from './aws/cloudfront.adapter'; +import { CognitoAdapter } from './aws/cognito.adapter'; +import { ElastiCacheAdapter } from './aws/elasticache.adapter'; +import { EfsAdapter } from './aws/efs.adapter'; +import { MskAdapter } from './aws/msk.adapter'; +import { SageMakerAdapter } from './aws/sagemaker.adapter'; +import { SystemsManagerAdapter } from './aws/systems-manager.adapter'; +import { CodeBuildAdapter } from './aws/codebuild.adapter'; +import { NetworkFirewallAdapter } from './aws/network-firewall.adapter'; +import { ShieldAdapter } from './aws/shield.adapter'; +import { KinesisAdapter } from './aws/kinesis.adapter'; +import { GlueAdapter } from './aws/glue.adapter'; +import { AthenaAdapter } from './aws/athena.adapter'; +import { EmrAdapter } from './aws/emr.adapter'; +import { StepFunctionsAdapter } from './aws/step-functions.adapter'; +import { EventBridgeAdapter } from './aws/eventbridge.adapter'; +import { TransferFamilyAdapter } from './aws/transfer-family.adapter'; +import { ElasticBeanstalkAdapter } from './aws/elastic-beanstalk.adapter'; +import { AppFlowAdapter } from './aws/appflow.adapter'; @Injectable() export class AWSSecurityService { private readonly logger = new Logger(AWSSecurityService.name); + private readonly adapters: AwsServiceAdapter[] = [ + new IamAdapter(), + new CloudTrailAdapter(), + new S3Adapter(), + new Ec2VpcAdapter(), + new RdsAdapter(), + new KmsAdapter(), + new CloudWatchAdapter(), + new ConfigAdapter(), + new GuardDutyAdapter(), + new SecretsManagerAdapter(), + new WafAdapter(), + new ElbAdapter(), + new AcmAdapter(), + new BackupAdapter(), + new InspectorAdapter(), + new EcsEksAdapter(), + new LambdaAdapter(), + new DynamoDbAdapter(), + new SnsSqsAdapter(), + new EcrAdapter(), + new OpenSearchAdapter(), + new RedshiftAdapter(), + new MacieAdapter(), + new Route53Adapter(), + new ApiGatewayAdapter(), + new CloudFrontAdapter(), + new CognitoAdapter(), + new ElastiCacheAdapter(), + new EfsAdapter(), + new MskAdapter(), + new SageMakerAdapter(), + new SystemsManagerAdapter(), + new CodeBuildAdapter(), + new NetworkFirewallAdapter(), + new ShieldAdapter(), + new KinesisAdapter(), + new GlueAdapter(), + new AthenaAdapter(), + new EmrAdapter(), + new StepFunctionsAdapter(), + new EventBridgeAdapter(), + new TransferFamilyAdapter(), + new ElasticBeanstalkAdapter(), + new AppFlowAdapter(), + ]; + async scanSecurityFindings( credentials: Record, variables: Record, + enabledServices?: string[], ): Promise { const isRoleAuth = Boolean(credentials.roleArn && credentials.externalId); const isKeyAuth = Boolean( @@ -32,20 +121,21 @@ export class AWSSecurityService { ); } - // Get all configured regions, or default to us-east-1 const configuredRegions = this.getConfiguredRegions(credentials, variables); + const primaryRegion = configuredRegions[0]; + this.logger.log( `Scanning ${configuredRegions.length} region(s): ${configuredRegions.join(', ')}`, ); - // Assume role ONCE before scanning all regions (IAM is global, not regional) - // This avoids N×2 STS API calls when scanning N regions + // Assume role ONCE — IAM is global, credentials work across all regions let awsCredentials: AwsCredentials; - // Note: configuredRegions is guaranteed to have at least one element (defaults to ['us-east-1']) - const primaryRegion = configuredRegions[0]; - if (isRoleAuth) { - awsCredentials = await this.assumeRole(credentials, primaryRegion); + awsCredentials = await this.assumeRole({ + roleArn: credentials.roleArn as string, + externalId: credentials.externalId as string, + region: primaryRegion, + }); } else { awsCredentials = { accessKeyId: credentials.access_key_id as string, @@ -53,38 +143,81 @@ export class AWSSecurityService { }; } + // undefined = scan all (no detection data), [] = scan nothing (all disabled), [...] = scan specific + const activeAdapters = + enabledServices === undefined + ? this.adapters + : this.adapters.filter((a) => enabledServices.includes(a.serviceId)); + + this.logger.log( + `Scanning ${activeAdapters.length} service adapters` + + (enabledServices?.length + ? ` (filtered from ${this.adapters.length} total)` + : ''), + ); + const allFindings: SecurityFinding[] = []; - const successfulRegions: string[] = []; - const failedRegions: string[] = []; + const successfulRegions = new Set(); + const failedRegions = new Set(); - // Scan each region using the same credentials - for (const region of configuredRegions) { + // Run global adapters once in the primary region + const globalAdapters = activeAdapters.filter((a) => a.isGlobal); + for (const adapter of globalAdapters) { try { - const regionFindings = await this.scanRegionWithCredentials( - awsCredentials, - region, + const findings = await adapter.scan({ + credentials: awsCredentials, + region: primaryRegion, + }); + for (const f of findings) { + f.evidence = { ...f.evidence, serviceId: adapter.serviceId }; + } + allFindings.push(...findings); + this.logger.log( + `[${adapter.serviceId}] ${findings.length} findings (global)`, ); - allFindings.push(...regionFindings); - successfulRegions.push(region); } catch (error) { - const errorMessage = - error instanceof Error ? error.message : String(error); - // Use warn - per-region failures are expected (e.g., Security Hub not enabled) - this.logger.warn(`Error scanning region ${region}: ${errorMessage}`); - failedRegions.push(region); - // Continue with other regions + const msg = error instanceof Error ? error.message : String(error); + this.logger.warn(`[${adapter.serviceId}] Error (global): ${msg}`); + } + } + + // Run regional adapters per configured region + const regionalAdapters = activeAdapters.filter((a) => !a.isGlobal); + for (const region of configuredRegions) { + for (const adapter of regionalAdapters) { + try { + const findings = await adapter.scan({ + credentials: awsCredentials, + region, + }); + for (const f of findings) { + f.evidence = { ...f.evidence, serviceId: adapter.serviceId }; + } + allFindings.push(...findings); + successfulRegions.add(region); + this.logger.log( + `[${adapter.serviceId}] ${findings.length} findings in ${region}`, + ); + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + this.logger.warn(`[${adapter.serviceId}] Error in ${region}: ${msg}`); + failedRegions.add(region); + } } } - // Log summary this.logger.log( - `Scan complete: ${allFindings.length} findings from ${successfulRegions.length} regions`, + `Scan complete: ${allFindings.length} findings from ${successfulRegions.size} regions`, ); - // If ALL regions failed, throw an error so the caller knows the scan failed - if (successfulRegions.length === 0 && failedRegions.length > 0) { + // If ALL regions failed for regional adapters and no global adapters succeeded + if ( + regionalAdapters.length > 0 && + successfulRegions.size === 0 && + failedRegions.size > 0 + ) { throw new Error( - `All ${failedRegions.length} region(s) failed to scan: ${failedRegions.join(', ')}`, + `All ${failedRegions.size} region(s) failed to scan: ${[...failedRegions].join(', ')}`, ); } @@ -99,29 +232,20 @@ export class AWSSecurityService { credentials: Record, variables: Record, ): string[] { - // Check credentials.regions (array from multi-select) if (Array.isArray(credentials.regions) && credentials.regions.length > 0) { const filtered = credentials.regions.filter( (r): r is string => typeof r === 'string' && r.trim().length > 0, ); - // Only use filtered result if it has valid strings - if (filtered.length > 0) { - return filtered; - } + if (filtered.length > 0) return filtered; } - // Check variables.regions (array) if (Array.isArray(variables.regions) && variables.regions.length > 0) { const filtered = variables.regions.filter( (r): r is string => typeof r === 'string' && r.trim().length > 0, ); - // Only use filtered result if it has valid strings - if (filtered.length > 0) { - return filtered; - } + if (filtered.length > 0) return filtered; } - // Check single region in credentials or variables const singleRegion = (credentials.region as string) || (variables.region as string); @@ -133,52 +257,44 @@ export class AWSSecurityService { return [singleRegion.trim()]; } - // Default to us-east-1 return ['us-east-1']; } /** - * Scan a single AWS region using pre-obtained credentials. - * Credentials are reused across regions since IAM is global. + * Assume the remediation IAM role for write access. + * Uses a separate role ARN so the audit role stays read-only. */ - private async scanRegionWithCredentials( - awsCredentials: AwsCredentials, + async assumeRemediationRole( + credentials: Record, region: string, - ): Promise { - const securityHub = new SecurityHubClient({ + ): Promise { + const remediationRoleArn = credentials.remediationRoleArn as + | string + | undefined; + if (!remediationRoleArn) { + throw new Error( + 'Remediation role ARN not configured. Add a Remediation Role ARN to your AWS connection.', + ); + } + + return this.assumeRole({ + roleArn: remediationRoleArn, + externalId: credentials.externalId as string, region, - credentials: awsCredentials, + sessionName: 'CompSecurityRemediation', }); - - try { - const findings = await this.fetchSecurityHubFindings(securityHub, region); - this.logger.log(`Found ${findings.length} findings in region ${region}`); - return findings; - } catch (error) { - const errorMessage = - error instanceof Error ? error.message : String(error); - - if ( - errorMessage.includes('not subscribed') || - errorMessage.includes('AccessDenied') - ) { - this.logger.warn(`Security Hub not enabled in region ${region}`); - return []; - } - - throw error; - } } /** - * Assume IAM role for cross-account access + * Assume IAM role for cross-account access (two-hop) */ - private async assumeRole( - credentials: Record, - region: string, - ): Promise { - const customerRoleArn = credentials.roleArn as string; - const externalId = credentials.externalId as string; + async assumeRole(params: { + roleArn: string; + externalId: string; + region: string; + sessionName?: string; + }): Promise { + const { roleArn, externalId, region, sessionName } = params; const roleAssumerArn = process.env.SECURITY_HUB_ROLE_ASSUMER_ARN; if (!roleAssumerArn) { @@ -208,21 +324,19 @@ export class AWSSecurityService { sessionToken: roleAssumerCreds.SessionToken, }; - // Hop 2: roleAssumer -> customer role (ExternalId enforced by customer trust policy) + // Hop 2: roleAssumer -> customer role const roleAssumerSts = new STSClient({ region, credentials: roleAssumerAwsCreds, }); - this.logger.log( - `Assuming customer role ${customerRoleArn} in region ${region}`, - ); + this.logger.log(`Assuming customer role ${roleArn} in region ${region}`); const customerResp = await roleAssumerSts.send( new AssumeRoleCommand({ - RoleArn: customerRoleArn, + RoleArn: roleArn, ExternalId: externalId, - RoleSessionName: 'CompSecurityAudit', + RoleSessionName: sessionName ?? 'CompSecurityAudit', DurationSeconds: 3600, }), ); @@ -241,107 +355,156 @@ export class AWSSecurityService { }; } - private async fetchSecurityHubFindings( - securityHub: SecurityHubClient, - region: string, - ): Promise { - const allFindings: SecurityFinding[] = []; + /** + * Detect which AWS services are actively used via Cost Explorer billing data. + * Returns serviceIds matching our adapter IDs (e.g. 's3', 'rds', 'lambda'). + */ + async detectActiveServices( + credentials: Record, + variables: Record, + ): Promise { + const configuredRegions = this.getConfiguredRegions(credentials, variables); + const primaryRegion = configuredRegions[0]; - const params: GetFindingsCommandInput = { - Filters: { - WorkflowStatus: [ - { Value: 'NEW', Comparison: 'EQUALS' }, - { Value: 'NOTIFIED', Comparison: 'EQUALS' }, - ], - RecordState: [{ Value: 'ACTIVE', Comparison: 'EQUALS' }], - }, - MaxResults: 100, - }; + const isRoleAuth = Boolean(credentials.roleArn && credentials.externalId); + const isKeyAuth = Boolean( + credentials.access_key_id && credentials.secret_access_key, + ); - let response = await securityHub.send(new GetFindingsCommand(params)); + if (!isRoleAuth && !isKeyAuth) { + throw new Error('AWS credentials missing'); + } - if (response.Findings) { - for (const finding of response.Findings) { - allFindings.push(this.mapFinding(finding, region)); - } + let awsCredentials: AwsCredentials; + if (isRoleAuth) { + awsCredentials = await this.assumeRole({ + roleArn: credentials.roleArn as string, + externalId: credentials.externalId as string, + region: primaryRegion, + }); + } else { + awsCredentials = { + accessKeyId: credentials.access_key_id as string, + secretAccessKey: credentials.secret_access_key as string, + }; } - // Paginate - let nextToken = response.NextToken; - while (nextToken && allFindings.length < 500) { - response = await securityHub.send( - new GetFindingsCommand({ - ...params, - NextToken: nextToken, + const client = new CostExplorerClient({ + region: 'us-east-1', // Cost Explorer is global, always use us-east-1 + credentials: awsCredentials, + }); + + const now = new Date(); + const end = now.toISOString().slice(0, 10); // YYYY-MM-DD + const start = new Date(now.getTime() - 30 * 24 * 60 * 60 * 1000) + .toISOString() + .slice(0, 10); + + let response; + try { + response = await client.send( + new GetCostAndUsageCommand({ + TimePeriod: { Start: start, End: end }, + Granularity: 'MONTHLY', + Metrics: ['UnblendedCost'], + GroupBy: [{ Type: 'DIMENSION', Key: 'SERVICE' }], }), ); + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + this.logger.warn( + `Cost Explorer unavailable (missing ce:GetCostAndUsage permission?): ${msg}`, + ); + return []; + } - if (response.Findings) { - for (const finding of response.Findings) { - if (allFindings.length >= 500) break; - allFindings.push(this.mapFinding(finding, region)); + const activeAwsNames = new Set(); + for (const result of response.ResultsByTime ?? []) { + for (const group of result.Groups ?? []) { + const serviceName = group.Keys?.[0]; + const amount = parseFloat(group.Metrics?.UnblendedCost?.Amount ?? '0'); + if (serviceName && amount > 0) { + activeAwsNames.add(serviceName); } } - - nextToken = response.NextToken; } - return allFindings; - } - - private mapFinding( - finding: { - Id?: string; - Title?: string; - Description?: string; - Remediation?: { Recommendation?: { Text?: string } }; - Severity?: { Label?: string }; - Resources?: Array<{ Type?: string; Id?: string }>; - AwsAccountId?: string; - Region?: string; - Compliance?: { Status?: string }; - GeneratorId?: string; - CreatedAt?: string; - UpdatedAt?: string; - }, - scanRegion: string, - ): SecurityFinding { - const severityMap: Record = { - INFORMATIONAL: 'info', - LOW: 'low', - MEDIUM: 'medium', - HIGH: 'high', - CRITICAL: 'critical', - }; - - const complianceStatus = finding.Compliance?.Status; - const passed = complianceStatus === 'PASSED'; - - // Use the finding's region if available, otherwise use the scan region - const findingRegion = finding.Region || scanRegion; + // Map AWS billing service names to our adapter serviceIds + const detected: string[] = []; + for (const [awsName, serviceIds] of Object.entries( + AWS_COST_SERVICE_MAPPING, + )) { + if (activeAwsNames.has(awsName)) { + for (const id of serviceIds) { + if (!detected.includes(id)) { + detected.push(id); + } + } + } + } - // Append region to title for frontend filtering (e.g., "Finding Title (us-east-1)") - const baseTitle = finding.Title || 'Untitled Finding'; - const titleWithRegion = `${baseTitle} (${findingRegion})`; + this.logger.log( + `Cost Explorer detected ${detected.length} active services from ${activeAwsNames.size} billing entries`, + ); - return { - id: finding.Id || '', - title: titleWithRegion, - description: finding.Description || 'No description available', - severity: severityMap[finding.Severity?.Label || 'INFO'] || 'medium', - resourceType: finding.Resources?.[0]?.Type || 'unknown', - resourceId: finding.Resources?.[0]?.Id || 'unknown', - remediation: - finding.Remediation?.Recommendation?.Text || 'No remediation available', - evidence: { - awsAccountId: finding.AwsAccountId, - region: findingRegion, - complianceStatus, - generatorId: finding.GeneratorId, - updatedAt: finding.UpdatedAt, - }, - createdAt: finding.CreatedAt || new Date().toISOString(), - passed, - }; + return detected; } } + +/** + * Maps AWS Cost Explorer billing service names to our adapter serviceIds. + * One billing name may map to multiple adapters (e.g. EC2 → ec2-vpc, elb). + */ +const AWS_COST_SERVICE_MAPPING: Record = { + 'AWS Security Hub': ['security-hub'], + 'AWS IAM Access Analyzer': ['iam-analyzer'], + 'AWS CloudTrail': ['cloudtrail'], + 'Amazon Simple Storage Service': ['s3'], + 'Amazon Elastic Compute Cloud - Compute': ['ec2-vpc'], + 'EC2 - Other': ['ec2-vpc'], + 'Amazon Relational Database Service': ['rds'], + 'AWS Key Management Service': ['kms'], + 'Amazon CloudWatch': ['cloudwatch'], + 'AWS Config': ['config'], + 'Amazon GuardDuty': ['guardduty'], + 'AWS Secrets Manager': ['secrets-manager'], + 'AWS WAF': ['waf'], + 'Amazon Elastic Load Balancing': ['elb'], + 'AWS Certificate Manager': ['acm'], + 'AWS Backup': ['backup'], + 'Amazon Inspector': ['inspector'], + 'Amazon Elastic Container Service': ['ecs-eks'], + 'Amazon Elastic Kubernetes Service': ['ecs-eks'], + 'AWS Lambda': ['lambda'], + 'Amazon DynamoDB': ['dynamodb'], + 'Amazon Simple Notification Service': ['sns-sqs'], + 'Amazon Simple Queue Service': ['sns-sqs'], + 'Amazon Elastic Container Registry': ['ecr'], + 'Amazon OpenSearch Service': ['opensearch'], + 'Amazon Elasticsearch Service': ['opensearch'], // legacy name + 'Amazon Redshift': ['redshift'], + 'Amazon Macie': ['macie'], + 'Amazon Route 53': ['route53'], + 'Amazon API Gateway': ['api-gateway'], + 'Amazon CloudFront': ['cloudfront'], + 'Amazon Cognito': ['cognito'], + 'Amazon ElastiCache': ['elasticache'], + 'Amazon Elastic File System': ['efs'], + 'Amazon Managed Streaming for Apache Kafka': ['msk'], + 'Amazon SageMaker': ['sagemaker'], + 'AWS Systems Manager': ['systems-manager'], + 'AWS CodeBuild': ['codebuild'], + 'AWS Network Firewall': ['network-firewall'], + 'AWS Shield': ['shield'], + 'Amazon Kinesis': ['kinesis'], + 'Amazon Kinesis Data Firehose': ['kinesis'], + 'Amazon Kinesis Data Analytics': ['kinesis'], + 'AWS Glue': ['glue'], + 'Amazon Athena': ['athena'], + 'Amazon Elastic MapReduce': ['emr'], + 'AWS Step Functions': ['step-functions'], + 'Amazon EventBridge': ['eventbridge'], + 'AWS Transfer Family': ['transfer-family'], + 'AWS Elastic Beanstalk': ['elastic-beanstalk'], + 'Amazon AppFlow': ['appflow'], +}; diff --git a/apps/api/src/cloud-security/providers/aws/acm.adapter.ts b/apps/api/src/cloud-security/providers/aws/acm.adapter.ts new file mode 100644 index 0000000000..acc38cc99b --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/acm.adapter.ts @@ -0,0 +1,152 @@ +import { + ACMClient, + DescribeCertificateCommand, + ListCertificatesCommand, +} from '@aws-sdk/client-acm'; + +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class AcmAdapter implements AwsServiceAdapter { + readonly serviceId = 'acm'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new ACMClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + try { + let nextToken: string | undefined; + + do { + const listRes = await client.send( + new ListCertificatesCommand({ NextToken: nextToken }), + ); + + for (const summary of listRes.CertificateSummaryList ?? []) { + const arn = summary.CertificateArn; + if (!arn) continue; + + const descRes = await client.send( + new DescribeCertificateCommand({ CertificateArn: arn }), + ); + + const cert = descRes.Certificate; + if (!cert) continue; + + const notAfter = cert.NotAfter; + if (notAfter) { + const daysUntilExpiry = Math.floor( + (notAfter.getTime() - Date.now()) / (1000 * 60 * 60 * 24), + ); + + if (daysUntilExpiry < 0) { + findings.push( + this.makeFinding( + arn, + 'Certificate has expired', + `Certificate expired ${Math.abs(daysUntilExpiry)} days ago`, + 'critical', + { daysUntilExpiry, notAfter: notAfter.toISOString() }, + false, + `Use acm:RequestCertificateCommand with DomainName set to the certificate's domain and ValidationMethod set to 'DNS' (or 'EMAIL') to request a replacement certificate. After validation, update resources referencing the old certificate ARN. [MANUAL] Certificate renewal requires DNS or email validation that cannot be fully automated. Rollback: resources can be pointed back to the old certificate ARN if it is renewed.`, + ), + ); + } else if (daysUntilExpiry < 7) { + findings.push( + this.makeFinding( + arn, + 'Certificate expiring within 7 days', + `Certificate expires in ${daysUntilExpiry} days`, + 'critical', + { daysUntilExpiry, notAfter: notAfter.toISOString() }, + false, + `Use acm:RenewCertificateCommand with CertificateArn to trigger renewal for imported certificates. For ACM-issued certificates, renewal is automatic if DNS validation records are in place. [MANUAL] If DNS validation records are missing, you must add them or use acm:RequestCertificateCommand to request a new certificate. Rollback: not applicable for renewal.`, + ), + ); + } else if (daysUntilExpiry < 30) { + findings.push( + this.makeFinding( + arn, + 'Certificate expiring within 30 days', + `Certificate expires in ${daysUntilExpiry} days`, + 'high', + { daysUntilExpiry, notAfter: notAfter.toISOString() }, + false, + `Use acm:RenewCertificateCommand with CertificateArn to trigger renewal for imported certificates. For ACM-issued certificates, renewal is automatic if DNS validation records are in place. [MANUAL] If DNS validation records are missing, you must add them. Rollback: not applicable for renewal.`, + ), + ); + } else { + findings.push( + this.makeFinding( + arn, + 'Certificate is valid', + `Certificate expires in ${daysUntilExpiry} days`, + 'info', + { daysUntilExpiry, notAfter: notAfter.toISOString() }, + true, + ), + ); + } + } + + if ( + cert.Type === 'AMAZON_ISSUED' && + cert.RenewalEligibility === 'INELIGIBLE' + ) { + findings.push( + this.makeFinding( + arn, + 'ACM certificate not eligible for renewal', + 'ACM-issued certificate is marked as ineligible for automatic renewal', + 'medium', + { renewalEligibility: cert.RenewalEligibility }, + false, + `[MANUAL] Cannot be auto-fixed. The certificate is ineligible for automatic renewal, typically because DNS validation records are missing or the domain is no longer resolvable. Verify DNS validation CNAME records are present for the certificate domain. If records are missing, use acm:RequestCertificateCommand to request a new certificate with ValidationMethod 'DNS' and add the new validation records.`, + ), + ); + } + } + + nextToken = listRes.NextToken; + } while (nextToken); + } catch (error: unknown) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding( + resourceId: string, + title: string, + description: string, + severity: SecurityFinding['severity'], + evidence?: Record, + passed?: boolean, + remediation?: string, + ): SecurityFinding { + const id = `acm-${resourceId}-${title.toLowerCase().replace(/\s+/g, '-')}`; + return { + id, + title, + description, + severity, + resourceType: 'AwsAcmCertificate', + resourceId, + remediation, + evidence: { ...evidence, findingKey: id }, + createdAt: new Date().toISOString(), + passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/api-gateway.adapter.ts b/apps/api/src/cloud-security/providers/aws/api-gateway.adapter.ts new file mode 100644 index 0000000000..74bbcbceea --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/api-gateway.adapter.ts @@ -0,0 +1,159 @@ +import { + ApiGatewayV2Client, + GetApisCommand, + GetStagesCommand, + GetRoutesCommand, +} from '@aws-sdk/client-apigatewayv2'; +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class ApiGatewayAdapter implements AwsServiceAdapter { + readonly serviceId = 'api-gateway'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + accountId, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new ApiGatewayV2Client({ credentials, region }); + const findings: SecurityFinding[] = []; + + try { + let nextToken: string | undefined; + + do { + const resp = await client.send( + new GetApisCommand({ NextToken: nextToken }), + ); + + for (const api of resp.Items ?? []) { + if (!api.ApiId || !api.ApiEndpoint) continue; + + const apiName = api.Name ?? api.ApiId; + + if (api.ProtocolType === 'HTTP') { + const apiFindings = await this.checkApi( + client, + api.ApiId, + apiName, + region, + accountId, + ); + findings.push(...apiFindings); + } + } + + nextToken = resp.NextToken; + } while (nextToken); + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private async checkApi( + client: ApiGatewayV2Client, + apiId: string, + apiName: string, + region: string, + accountId?: string, + ): Promise { + const findings: SecurityFinding[] = []; + + try { + // Check routes for authorization + const routesResp = await client.send( + new GetRoutesCommand({ ApiId: apiId }), + ); + + for (const route of routesResp.Items ?? []) { + const routeKey = route.RouteKey ?? 'unknown'; + + if (!route.AuthorizationType || route.AuthorizationType === 'NONE') { + findings.push( + this.makeFinding({ + id: `apigw-no-auth-${apiId}-${routeKey}`, + title: `API Gateway "${apiName}" route "${routeKey}" has no authorization configured (${region})`, + description: `API ${apiName} route ${routeKey} does not have an authorization type configured. The route is accessible without authentication.`, + severity: 'medium', + resourceId: apiId, + remediation: `Use apigatewayv2:UpdateRouteCommand with ApiId set to "${apiId}", RouteId set to the route ID for "${routeKey}", and AuthorizationType set to 'JWT', 'AWS_IAM', or 'CUSTOM'. Provide AuthorizerId if using JWT or CUSTOM. Rollback: use apigatewayv2:UpdateRouteCommand with AuthorizationType set to 'NONE'.`, + passed: false, + accountId, + region, + }), + ); + } + } + + // Check stages for access logging + const stagesResp = await client.send( + new GetStagesCommand({ ApiId: apiId }), + ); + + for (const stage of stagesResp.Items ?? []) { + const stageName = stage.StageName ?? 'unknown'; + + if (!stage.AccessLogSettings?.DestinationArn) { + findings.push( + this.makeFinding({ + id: `apigw-no-logging-${apiId}/${stageName}`, + title: `API Gateway "${apiName}" stage "${stageName}" has access logging not enabled (${region})`, + description: `API ${apiName} stage ${stageName} does not have access logging configured. API calls are not being logged for audit and troubleshooting.`, + severity: 'medium', + resourceId: apiId, + remediation: `Use apigatewayv2:UpdateStageCommand with ApiId set to "${apiId}", StageName set to "${stageName}", and AccessLogSettings.DestinationArn set to a CloudWatch Logs log group ARN. Set AccessLogSettings.Format to a JSON log format string (e.g., '{"requestId":"$context.requestId","ip":"$context.identity.sourceIp"}'). Rollback: use apigatewayv2:UpdateStageCommand with AccessLogSettings set to empty object.`, + passed: false, + accountId, + region, + }), + ); + } + } + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding(opts: { + id: string; + title: string; + description: string; + severity: SecurityFinding['severity']; + resourceId?: string; + remediation?: string; + passed: boolean; + accountId?: string; + region?: string; + }): SecurityFinding { + return { + id: opts.id, + title: opts.title, + description: opts.description, + severity: opts.severity, + resourceType: 'AwsApiGatewayApi', + resourceId: opts.resourceId || 'unknown', + remediation: opts.remediation, + evidence: { + awsAccountId: opts.accountId, + region: opts.region, + service: 'API Gateway', + findingKey: opts.id, + }, + createdAt: new Date().toISOString(), + passed: opts.passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/appflow.adapter.ts b/apps/api/src/cloud-security/providers/aws/appflow.adapter.ts new file mode 100644 index 0000000000..1f15ecb612 --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/appflow.adapter.ts @@ -0,0 +1,101 @@ +import { + AppflowClient, + DescribeFlowCommand, + ListFlowsCommand, +} from '@aws-sdk/client-appflow'; + +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class AppFlowAdapter implements AwsServiceAdapter { + readonly serviceId = 'appflow'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new AppflowClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + try { + let nextToken: string | undefined; + + do { + const listRes = await client.send(new ListFlowsCommand({ nextToken })); + + for (const flow of listRes.flows ?? []) { + const flowName = flow.flowName; + if (!flowName) continue; + + const flowArn = flow.flowArn ?? flowName; + + const descRes = await client.send( + new DescribeFlowCommand({ flowName }), + ); + + if (!descRes.kmsArn) { + findings.push( + this.makeFinding( + flowArn, + 'Flow not encrypted with CMK', + `AppFlow flow "${flowName}" is not encrypted with a customer-managed KMS key`, + 'medium', + { flowName, service: 'AppFlow' }, + false, + `Use appflow:UpdateFlowCommand with flowName set to '${flowName}' and kmsArn set to a customer-managed KMS key ARN. You must also provide the full flow configuration (sourceFlowConfig, destinationFlowConfigList, tasks, triggerConfig). Rollback: use appflow:UpdateFlowCommand with kmsArn removed to revert to AWS-managed encryption.`, + ), + ); + } else { + findings.push( + this.makeFinding( + flowArn, + 'Flow encrypted with CMK', + `AppFlow flow "${flowName}" is encrypted with customer-managed KMS key`, + 'info', + { flowName, kmsArn: descRes.kmsArn, service: 'AppFlow' }, + true, + ), + ); + } + } + + nextToken = listRes.nextToken; + } while (nextToken); + } catch (error: unknown) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding( + resourceId: string, + title: string, + description: string, + severity: SecurityFinding['severity'], + evidence?: Record, + passed?: boolean, + remediation?: string, + ): SecurityFinding { + const id = `appflow-${resourceId}-${title.toLowerCase().replace(/\s+/g, '-')}`; + return { + id, + title, + description, + severity, + resourceType: 'AwsAppFlow', + resourceId, + remediation, + evidence: { ...evidence, findingKey: id }, + createdAt: new Date().toISOString(), + passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/athena.adapter.ts b/apps/api/src/cloud-security/providers/aws/athena.adapter.ts new file mode 100644 index 0000000000..1e11c07546 --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/athena.adapter.ts @@ -0,0 +1,144 @@ +import { + AthenaClient, + ListWorkGroupsCommand, + GetWorkGroupCommand, +} from '@aws-sdk/client-athena'; + +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class AthenaAdapter implements AwsServiceAdapter { + readonly serviceId = 'athena'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new AthenaClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + try { + let nextToken: string | undefined; + + do { + const listRes = await client.send( + new ListWorkGroupsCommand({ NextToken: nextToken }), + ); + + for (const wgSummary of listRes.WorkGroups ?? []) { + const wgName = wgSummary.Name ?? 'unknown'; + + // Skip the default "primary" workgroup — only check user-created workgroups + if (wgName === 'primary') continue; + + const resourceId = `arn:aws:athena:${region}:workgroup/${wgName}`; + + const descRes = await client.send( + new GetWorkGroupCommand({ WorkGroup: wgName }), + ); + + const config = descRes.WorkGroup?.Configuration; + + // Check query result encryption + const encryptionConfig = + config?.ResultConfiguration?.EncryptionConfiguration; + + if (!encryptionConfig) { + findings.push( + this.makeFinding( + resourceId, + 'Query results not encrypted', + `Athena workgroup "${wgName}" does not have encryption configured for query results`, + 'medium', + { workGroupName: wgName, encryptionConfiguration: null }, + false, + `Use athena:UpdateWorkGroupCommand with WorkGroup set to '${wgName}' and ConfigurationUpdates.ResultConfigurationUpdates.EncryptionConfiguration set to { EncryptionOption: 'SSE_KMS', KmsKey: '' } (or 'SSE_S3' for S3-managed encryption). Rollback: use athena:UpdateWorkGroupCommand with RemoveEncryptionConfiguration set to true.`, + ), + ); + } else { + findings.push( + this.makeFinding( + resourceId, + 'Query results encryption enabled', + `Athena workgroup "${wgName}" has encryption configured for query results`, + 'info', + { + workGroupName: wgName, + encryptionOption: encryptionConfig.EncryptionOption, + }, + true, + ), + ); + } + + // Check workgroup configuration enforcement + if (config?.EnforceWorkGroupConfiguration !== true) { + findings.push( + this.makeFinding( + resourceId, + 'Workgroup configuration not enforced (users can override)', + `Athena workgroup "${wgName}" does not enforce its configuration, allowing users to override settings`, + 'medium', + { + workGroupName: wgName, + enforceWorkGroupConfiguration: + config?.EnforceWorkGroupConfiguration, + }, + false, + `Use athena:UpdateWorkGroupCommand with WorkGroup set to '${wgName}' and ConfigurationUpdates.EnforceWorkGroupConfiguration set to true. This prevents users from overriding workgroup settings at query time. Rollback: use athena:UpdateWorkGroupCommand with EnforceWorkGroupConfiguration set to false.`, + ), + ); + } else { + findings.push( + this.makeFinding( + resourceId, + 'Workgroup configuration enforced', + `Athena workgroup "${wgName}" enforces its configuration`, + 'info', + { workGroupName: wgName, enforceWorkGroupConfiguration: true }, + true, + ), + ); + } + } + + nextToken = listRes.NextToken; + } while (nextToken); + } catch (error: unknown) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding( + resourceId: string, + title: string, + description: string, + severity: SecurityFinding['severity'], + evidence?: Record, + passed?: boolean, + remediation?: string, + ): SecurityFinding { + const id = `athena-${resourceId}-${title.toLowerCase().replace(/\s+/g, '-')}`; + return { + id, + title, + description, + severity, + resourceType: 'AwsAthenaWorkGroup', + resourceId, + remediation, + evidence: { ...evidence, service: 'Athena', findingKey: id }, + createdAt: new Date().toISOString(), + passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/aws-service-adapter.ts b/apps/api/src/cloud-security/providers/aws/aws-service-adapter.ts new file mode 100644 index 0000000000..e1cac66421 --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/aws-service-adapter.ts @@ -0,0 +1,19 @@ +import type { SecurityFinding } from '../../cloud-security.service'; + +export type AwsCredentials = { + accessKeyId: string; + secretAccessKey: string; + sessionToken?: string; +}; + +export interface AwsServiceAdapter { + /** Must match the manifest service ID (e.g. 'security-hub', 'iam-analyzer') */ + readonly serviceId: string; + /** true = scan once in primary region, false = scan per configured region */ + readonly isGlobal?: boolean; + scan(params: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise; +} diff --git a/apps/api/src/cloud-security/providers/aws/backup.adapter.ts b/apps/api/src/cloud-security/providers/aws/backup.adapter.ts new file mode 100644 index 0000000000..6c771e5a6f --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/backup.adapter.ts @@ -0,0 +1,173 @@ +import { + BackupClient, + ListBackupPlansCommand, + ListBackupSelectionsCommand, +} from '@aws-sdk/client-backup'; +import { RDSClient, DescribeDBInstancesCommand } from '@aws-sdk/client-rds'; +import { DynamoDBClient, ListTablesCommand } from '@aws-sdk/client-dynamodb'; +import { EC2Client, DescribeVolumesCommand } from '@aws-sdk/client-ec2'; + +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class BackupAdapter implements AwsServiceAdapter { + readonly serviceId = 'backup'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new BackupClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + // Prerequisite: check if there are backup-eligible resources (RDS, DynamoDB, EBS) + try { + let hasBackupEligible = false; + + const rdsClient = new RDSClient({ credentials, region }); + const rdsResp = await rdsClient.send( + new DescribeDBInstancesCommand({ MaxRecords: 20 }), + ); + if ((rdsResp.DBInstances ?? []).length > 0) { + hasBackupEligible = true; + } + + if (!hasBackupEligible) { + const ddbClient = new DynamoDBClient({ credentials, region }); + const ddbResp = await ddbClient.send( + new ListTablesCommand({ Limit: 1 }), + ); + if ((ddbResp.TableNames ?? []).length > 0) { + hasBackupEligible = true; + } + } + + if (!hasBackupEligible) { + const ec2Client = new EC2Client({ credentials, region }); + const volResp = await ec2Client.send( + new DescribeVolumesCommand({ MaxResults: 5 }), + ); + if ((volResp.Volumes ?? []).length > 0) { + hasBackupEligible = true; + } + } + + if (!hasBackupEligible) return []; + } catch { + // If prerequisite check fails (permissions), fall through to existing behavior + } + + try { + let nextToken: string | undefined; + let hasPlans = false; + + do { + const listRes = await client.send( + new ListBackupPlansCommand({ NextToken: nextToken }), + ); + + for (const plan of listRes.BackupPlansList ?? []) { + hasPlans = true; + const planId = plan.BackupPlanId; + const planArn = plan.BackupPlanArn; + const resourceId = planArn ?? planId ?? 'unknown'; + if (!planId) continue; + + let selNextToken: string | undefined; + let hasSelections = false; + + do { + const selRes = await client.send( + new ListBackupSelectionsCommand({ + BackupPlanId: planId, + NextToken: selNextToken, + }), + ); + + if ((selRes.BackupSelectionsList ?? []).length > 0) { + hasSelections = true; + } + + selNextToken = selRes.NextToken; + } while (selNextToken && !hasSelections); + + if (!hasSelections) { + findings.push( + this.makeFinding( + resourceId, + 'Backup plan has no resource selections', + `Backup plan "${plan.BackupPlanName ?? planId}" has no resources assigned for backup`, + 'medium', + { backupPlanId: planId, backupPlanName: plan.BackupPlanName }, + false, + `Use backup:CreateBackupSelectionCommand with BackupPlanId set to '${planId}' and BackupSelection containing SelectionName, IamRoleArn (for backup execution), and Resources (list of ARNs) or ListOfTags to select resources by tag. Rollback: use backup:DeleteBackupSelectionCommand with BackupPlanId and SelectionId.`, + ), + ); + } else { + findings.push( + this.makeFinding( + resourceId, + 'Backup plan has resource selections', + `Backup plan "${plan.BackupPlanName ?? planId}" has resources assigned`, + 'info', + { backupPlanId: planId, backupPlanName: plan.BackupPlanName }, + true, + ), + ); + } + } + + nextToken = listRes.NextToken; + } while (nextToken); + + if (!hasPlans) { + findings.push( + this.makeFinding( + `arn:aws:backup:${region}:no-plans`, + 'No backup plans configured', + 'No AWS Backup plans found in this region', + 'medium', + { region }, + false, + `Use backup:CreateBackupPlanCommand with BackupPlan containing BackupPlanName and Rules (array with ScheduleExpression e.g., 'cron(0 5 ? * * *)', TargetBackupVaultName, and Lifecycle.DeleteAfterDays set to 35). Then use backup:CreateBackupSelectionCommand to assign resources to the plan. Rollback: use backup:DeleteBackupPlanCommand with BackupPlanId.`, + ), + ); + } + } catch (error: unknown) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding( + resourceId: string, + title: string, + description: string, + severity: SecurityFinding['severity'], + evidence?: Record, + passed?: boolean, + remediation?: string, + ): SecurityFinding { + const id = `backup-${resourceId}-${title.toLowerCase().replace(/\s+/g, '-')}`; + return { + id, + title, + description, + severity, + resourceType: 'AwsBackupPlan', + resourceId, + remediation, + evidence: { ...evidence, findingKey: id }, + createdAt: new Date().toISOString(), + passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/cloudfront.adapter.ts b/apps/api/src/cloud-security/providers/aws/cloudfront.adapter.ts new file mode 100644 index 0000000000..f6cd94d58f --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/cloudfront.adapter.ts @@ -0,0 +1,210 @@ +import { + CloudFrontClient, + ListDistributionsCommand, + GetDistributionCommand, +} from '@aws-sdk/client-cloudfront'; +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class CloudFrontAdapter implements AwsServiceAdapter { + readonly serviceId = 'cloudfront'; + readonly isGlobal = true; + + async scan({ + credentials, + region, + accountId, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new CloudFrontClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + try { + let nextMarker: string | undefined; + let hasMore = true; + + while (hasMore) { + const resp = await client.send( + new ListDistributionsCommand({ Marker: nextMarker }), + ); + + const distList = resp.DistributionList; + if (!distList) break; + + for (const dist of distList.Items ?? []) { + if (!dist.Id) continue; + + const distId = dist.Id; + const domainName = dist.DomainName ?? distId; + + this.checkViewerProtocol( + dist, + distId, + domainName, + region, + accountId, + findings, + ); + this.checkWaf(dist, distId, domainName, region, accountId, findings); + await this.checkLogging( + client, + distId, + domainName, + region, + accountId, + findings, + ); + } + + if (distList.IsTruncated && distList.NextMarker) { + nextMarker = distList.NextMarker; + } else { + hasMore = false; + } + } + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private checkViewerProtocol( + dist: { + DefaultCacheBehavior?: { ViewerProtocolPolicy?: string }; + CacheBehaviors?: { Items?: { ViewerProtocolPolicy?: string }[] }; + }, + distId: string, + domainName: string, + region: string, + accountId: string | undefined, + findings: SecurityFinding[], + ): void { + const policies: (string | undefined)[] = []; + + if (dist.DefaultCacheBehavior?.ViewerProtocolPolicy) { + policies.push(dist.DefaultCacheBehavior.ViewerProtocolPolicy); + } + + for (const behavior of dist.CacheBehaviors?.Items ?? []) { + if (behavior.ViewerProtocolPolicy) { + policies.push(behavior.ViewerProtocolPolicy); + } + } + + const allowsHttp = policies.some((p) => p === 'allow-all'); + + if (allowsHttp) { + findings.push( + this.makeFinding({ + id: `cloudfront-http-allowed-${distId}`, + title: `CloudFront distribution "${domainName}" allows HTTP traffic (${region})`, + description: `Distribution ${distId} has a cache behavior with ViewerProtocolPolicy set to "allow-all", permitting unencrypted HTTP connections.`, + severity: 'high', + resourceId: distId, + remediation: `Use cloudfront:UpdateDistributionCommand with Id set to "${distId}". In DistributionConfig, set DefaultCacheBehavior.ViewerProtocolPolicy to 'redirect-to-https' and update all CacheBehaviors.Items[].ViewerProtocolPolicy to 'redirect-to-https'. You must include the full DistributionConfig and the current IfMatch ETag from cloudfront:GetDistributionCommand. Rollback: set ViewerProtocolPolicy back to 'allow-all'.`, + passed: false, + accountId, + region, + }), + ); + } + } + + private checkWaf( + dist: { WebACLId?: string }, + distId: string, + domainName: string, + region: string, + accountId: string | undefined, + findings: SecurityFinding[], + ): void { + if (!dist.WebACLId) { + findings.push( + this.makeFinding({ + id: `cloudfront-no-waf-${distId}`, + title: `CloudFront distribution "${domainName}" has no WAF associated (${region})`, + description: `Distribution ${distId} is not associated with an AWS WAF web ACL. There is no web application firewall protecting this distribution.`, + severity: 'medium', + resourceId: distId, + remediation: `Use cloudfront:UpdateDistributionCommand with Id set to "${distId}". In DistributionConfig, set WebACLId to the WAF web ACL ARN. You must include the full DistributionConfig and the current IfMatch ETag from cloudfront:GetDistributionCommand. Rollback: set WebACLId to an empty string to disassociate.`, + passed: false, + accountId, + region, + }), + ); + } + } + + private async checkLogging( + client: CloudFrontClient, + distId: string, + domainName: string, + region: string, + accountId: string | undefined, + findings: SecurityFinding[], + ): Promise { + try { + const resp = await client.send( + new GetDistributionCommand({ Id: distId }), + ); + + const logging = resp.Distribution?.DistributionConfig?.Logging; + + if (!logging?.Enabled) { + findings.push( + this.makeFinding({ + id: `cloudfront-no-logging-${distId}`, + title: `CloudFront distribution "${domainName}" has access logging disabled (${region})`, + description: `Distribution ${distId} does not have access logging enabled. Request logs are not being captured for audit or analysis.`, + severity: 'medium', + resourceId: distId, + remediation: `Use cloudfront:UpdateDistributionCommand with Id set to "${distId}". In DistributionConfig.Logging, set Enabled to true, Bucket to an S3 bucket domain (e.g., 'my-logs-bucket.s3.amazonaws.com'), and Prefix to a log prefix string. You must include the full DistributionConfig and the current IfMatch ETag from cloudfront:GetDistributionCommand. Rollback: set Logging.Enabled to false.`, + passed: false, + accountId, + region, + }), + ); + } + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return; + throw error; + } + } + + private makeFinding(opts: { + id: string; + title: string; + description: string; + severity: SecurityFinding['severity']; + resourceId?: string; + remediation?: string; + passed: boolean; + accountId?: string; + region?: string; + }): SecurityFinding { + return { + id: opts.id, + title: opts.title, + description: opts.description, + severity: opts.severity, + resourceType: 'AwsCloudFrontDistribution', + resourceId: opts.resourceId || 'unknown', + remediation: opts.remediation, + evidence: { + awsAccountId: opts.accountId, + region: opts.region, + service: 'CloudFront', + findingKey: opts.id, + }, + createdAt: new Date().toISOString(), + passed: opts.passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/cloudtrail.adapter.ts b/apps/api/src/cloud-security/providers/aws/cloudtrail.adapter.ts new file mode 100644 index 0000000000..10d85d9ec0 --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/cloudtrail.adapter.ts @@ -0,0 +1,151 @@ +import { + CloudTrailClient, + DescribeTrailsCommand, + GetTrailStatusCommand, +} from '@aws-sdk/client-cloudtrail'; +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class CloudTrailAdapter implements AwsServiceAdapter { + readonly serviceId = 'cloudtrail'; + readonly isGlobal = true; + + async scan(params: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const { credentials, region, accountId } = params; + const client = new CloudTrailClient({ region, credentials }); + + const findings: SecurityFinding[] = []; + + const trailsResp = await client.send(new DescribeTrailsCommand({})); + const trails = trailsResp.trailList || []; + + if (trails.length === 0) { + findings.push( + this.makeFinding({ + id: 'cloudtrail-no-trails', + title: 'No CloudTrail trails configured', + description: + 'No CloudTrail trails exist. API activity is not being logged.', + severity: 'critical', + remediation: + 'Create a multi-region trail using cloudtrail:CreateTrailCommand with Name set to "compai-cloudtrail", S3BucketName set to the target logging bucket, IsMultiRegionTrail set to true, and EnableLogFileValidation set to true. Then start logging with cloudtrail:StartLoggingCommand using the trail Name. Rollback by calling cloudtrail:StopLoggingCommand and then cloudtrail:DeleteTrailCommand with the trail Name.', + passed: false, + accountId, + }), + ); + return findings; + } + + const hasMultiRegion = trails.some((t) => t.IsMultiRegionTrail); + if (!hasMultiRegion) { + findings.push( + this.makeFinding({ + id: 'cloudtrail-no-multi-region', + title: 'No multi-region CloudTrail trail configured', + description: + 'None of the configured trails have multi-region logging enabled. Activity in other regions may not be captured.', + severity: 'high', + remediation: + 'Use cloudtrail:UpdateTrailCommand with the trail Name and IsMultiRegionTrail set to true. Rollback by calling cloudtrail:UpdateTrailCommand with IsMultiRegionTrail set to false.', + passed: false, + accountId, + }), + ); + } else { + findings.push( + this.makeFinding({ + id: 'cloudtrail-multi-region-ok', + title: 'Multi-region CloudTrail trail is configured', + description: 'At least one trail has multi-region logging enabled.', + severity: 'info', + passed: true, + accountId, + }), + ); + } + + for (const trail of trails) { + if (!trail.TrailARN || !trail.Name) continue; + + const statusResp = await client.send( + new GetTrailStatusCommand({ Name: trail.TrailARN }), + ); + + if (!statusResp.IsLogging) { + findings.push( + this.makeFinding({ + id: `cloudtrail-not-logging-${trail.Name}`, + title: `CloudTrail trail "${trail.Name}" is not logging`, + description: `Trail ${trail.Name} exists but logging is disabled.`, + severity: 'high', + resourceId: trail.TrailARN, + remediation: `Use cloudtrail:StartLoggingCommand with Name set to the trail ARN for "${trail.Name}". Rollback by calling cloudtrail:StopLoggingCommand with the same Name.`, + passed: false, + accountId, + }), + ); + } + + if (!trail.LogFileValidationEnabled) { + findings.push( + this.makeFinding({ + id: `cloudtrail-no-validation-${trail.Name}`, + title: `CloudTrail trail "${trail.Name}" has log file validation disabled`, + description: `Trail ${trail.Name} does not validate log file integrity. Tampered logs would go undetected.`, + severity: 'medium', + resourceId: trail.TrailARN, + remediation: `Use cloudtrail:UpdateTrailCommand with Name set to "${trail.Name}" and EnableLogFileValidation set to true. Rollback by calling cloudtrail:UpdateTrailCommand with EnableLogFileValidation set to false.`, + passed: false, + accountId, + }), + ); + } else { + findings.push( + this.makeFinding({ + id: `cloudtrail-validation-ok-${trail.Name}`, + title: `CloudTrail trail "${trail.Name}" has log file validation enabled`, + description: `Trail ${trail.Name} validates log file integrity.`, + severity: 'info', + resourceId: trail.TrailARN, + passed: true, + accountId, + }), + ); + } + } + + return findings; + } + + private makeFinding(opts: { + id: string; + title: string; + description: string; + severity: SecurityFinding['severity']; + resourceId?: string; + remediation?: string; + passed: boolean; + accountId?: string; + }): SecurityFinding { + return { + id: opts.id, + title: opts.title, + description: opts.description, + severity: opts.severity, + resourceType: 'AwsCloudTrailTrail', + resourceId: opts.resourceId || 'account-level', + remediation: opts.remediation, + evidence: { + awsAccountId: opts.accountId, + service: 'CloudTrail', + findingKey: opts.id, + }, + createdAt: new Date().toISOString(), + passed: opts.passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/cloudwatch.adapter.ts b/apps/api/src/cloud-security/providers/aws/cloudwatch.adapter.ts new file mode 100644 index 0000000000..c0b7d3ee7e --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/cloudwatch.adapter.ts @@ -0,0 +1,319 @@ +import { + CloudWatchLogsClient, + DescribeMetricFiltersCommand, +} from '@aws-sdk/client-cloudwatch-logs'; +import { + CloudWatchClient, + DescribeAlarmsForMetricCommand, +} from '@aws-sdk/client-cloudwatch'; +import { + CloudTrailClient, + DescribeTrailsCommand, +} from '@aws-sdk/client-cloudtrail'; +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +interface CisCheck { + id: string; + name: string; + keywords: string[]; +} + +const CIS_CHECKS: CisCheck[] = [ + { + id: 'cis-4.3', + name: 'Root account usage', + keywords: ['Root', 'userIdentity.type'], + }, + { + id: 'cis-4.1', + name: 'Unauthorized API calls', + keywords: ['UnauthorizedAccess', 'AccessDenied'], + }, + { + id: 'cis-4.5', + name: 'CloudTrail config changes', + keywords: ['CreateTrail', 'DeleteTrail'], + }, + { + id: 'cis-4.4', + name: 'IAM policy changes', + keywords: ['CreatePolicy', 'DeletePolicy', 'AttachRolePolicy'], + }, + { + id: 'cis-4.6', + name: 'Console auth failures', + keywords: ['ConsoleLogin', 'Failed'], + }, + { + id: 'cis-4.7', + name: 'CMK deletion/disabling', + keywords: ['kms.amazonaws.com', 'DisableKey'], + }, + { + id: 'cis-4.8', + name: 'S3 bucket policy changes', + keywords: ['PutBucketPolicy', 'DeleteBucketPolicy'], + }, + { + id: 'cis-4.9', + name: 'Security group changes', + keywords: ['AuthorizeSecurityGroupIngress', 'RevokeSecurityGroupIngress'], + }, + { + id: 'cis-4.10', + name: 'NACL changes', + keywords: ['CreateNetworkAcl', 'DeleteNetworkAcl'], + }, + { + id: 'cis-4.11', + name: 'Network gateway changes', + keywords: ['CreateCustomerGateway', 'AttachInternetGateway'], + }, + { + id: 'cis-4.12', + name: 'Route table changes', + keywords: ['CreateRoute', 'DeleteRoute'], + }, + { + id: 'cis-4.13', + name: 'VPC changes', + keywords: ['CreateVpc', 'DeleteVpc'], + }, + { + id: 'cis-4.14', + name: 'AWS Organizations changes', + keywords: ['organizations.amazonaws.com'], + }, +]; + +export class CloudWatchAdapter implements AwsServiceAdapter { + readonly serviceId = 'cloudwatch'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const logsClient = new CloudWatchLogsClient({ credentials, region }); + const cwClient = new CloudWatchClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + // Prerequisite: check if any CloudTrail trail has CloudWatch Logs integration + try { + const ctClient = new CloudTrailClient({ credentials, region }); + const trailsResp = await ctClient.send(new DescribeTrailsCommand({})); + const trails = trailsResp.trailList ?? []; + const hasCloudWatchIntegration = trails.some( + (trail) => !!trail.CloudWatchLogsLogGroupArn, + ); + + if (!hasCloudWatchIntegration) { + return [ + this.makeFinding({ + checkId: 'cloudwatch-no-cloudtrail-integration', + title: 'CloudTrail not integrated with CloudWatch Logs', + description: + 'No CloudTrail trail in this region is configured to send logs to CloudWatch Logs. CIS metric filter checks require CloudTrail-CloudWatch integration.', + severity: 'high', + remediation: + 'Use cloudtrail:UpdateTrailCommand with the trail Name and CloudWatchLogsLogGroupArn set to the target log group ARN, and CloudWatchLogsRoleArn set to an IAM role ARN that allows CloudTrail to write to CloudWatch Logs. Rollback by calling cloudtrail:UpdateTrailCommand with CloudWatchLogsLogGroupArn set to an empty string.', + evidence: { + trailCount: trails.length, + trailsWithCloudWatch: 0, + }, + passed: false, + }), + ]; + } + } catch { + // If prerequisite check fails (permissions), fall through to existing behavior + } + + try { + // Fetch all metric filters (limit to 1000) + const allFilters = await this.fetchAllMetricFilters(logsClient); + + // Check each CIS control + for (const check of CIS_CHECKS) { + const matchingFilter = allFilters.find((filter) => { + const pattern = filter.filterPattern ?? ''; + return check.keywords.every((keyword) => pattern.includes(keyword)); + }); + + if (!matchingFilter) { + findings.push( + this.makeFinding({ + checkId: check.id, + title: `${check.name} — metric filter missing`, + description: `No CloudWatch metric filter found for CIS ${check.id} (${check.name}). A metric filter matching keywords [${check.keywords.join(', ')}] is required.`, + severity: 'medium', + remediation: `Step 1: Create a CloudWatch Logs metric filter using logs:PutMetricFilterCommand with logGroupName set to the CloudTrail log group, filterName set to "compai-cis-${check.id}-${check.name.toLowerCase().replace(/\s+/g, '-')}", filterPattern set to the required CIS pattern for ${check.name} matching keywords [${check.keywords.join(', ')}], and metricTransformations containing metricName, metricNamespace "CloudTrailMetrics", and metricValue "1". Step 2: Create an SNS topic using sns:CreateTopicCommand with Name "compai-cis-alerts" if one does not already exist. Step 3: Create a CloudWatch alarm using cloudwatch:PutMetricAlarmCommand with AlarmName "compai-cis-${check.id}-alarm", MetricName matching the filter metric, Namespace "CloudTrailMetrics", Statistic "Sum", Period 300, EvaluationPeriods 1, Threshold 1, ComparisonOperator "GreaterThanOrEqualToThreshold", and AlarmActions set to the SNS topic ARN. Rollback by deleting the alarm with cloudwatch:DeleteAlarmsCommand, deleting the metric filter with logs:DeleteMetricFilterCommand, and optionally deleting the SNS topic with sns:DeleteTopicCommand.`, + evidence: { keywords: check.keywords, filterFound: false }, + passed: false, + }), + ); + continue; + } + + // Check if an alarm exists for the metric + const metricName = + matchingFilter.metricTransformations?.[0]?.metricName; + + if (!metricName) { + findings.push( + this.makeFinding({ + checkId: check.id, + title: `${check.name} — no metric transformation`, + description: `Metric filter for CIS ${check.id} (${check.name}) exists but has no metric transformation configured.`, + severity: 'medium', + remediation: `Step 1: Update the existing metric filter using logs:PutMetricFilterCommand with logGroupName, filterName set to the existing filter name, filterPattern preserved, and metricTransformations containing metricName "compai-cis-${check.id}-metric", metricNamespace "CloudTrailMetrics", and metricValue "1". Step 2: Create an SNS topic using sns:CreateTopicCommand with Name "compai-cis-alerts" if one does not already exist. Step 3: Create a CloudWatch alarm using cloudwatch:PutMetricAlarmCommand with AlarmName "compai-cis-${check.id}-alarm", MetricName "compai-cis-${check.id}-metric", Namespace "CloudTrailMetrics", Statistic "Sum", Period 300, EvaluationPeriods 1, Threshold 1, ComparisonOperator "GreaterThanOrEqualToThreshold", and AlarmActions set to the SNS topic ARN. Rollback by deleting the alarm with cloudwatch:DeleteAlarmsCommand and removing the metric transformation by calling logs:PutMetricFilterCommand with the original filter settings.`, + evidence: { + filterName: matchingFilter.filterName, + metricTransformations: null, + }, + passed: false, + }), + ); + continue; + } + + const hasAlarm = await this.checkAlarmExists(cwClient, metricName); + + if (!hasAlarm) { + findings.push( + this.makeFinding({ + checkId: check.id, + title: `${check.name} — alarm missing`, + description: `Metric filter for CIS ${check.id} (${check.name}) exists with metric "${metricName}", but no CloudWatch alarm is configured for it.`, + severity: 'medium', + remediation: `Step 1: Create an SNS topic using sns:CreateTopicCommand with Name "compai-cis-alerts" if one does not already exist. Step 2: Create a CloudWatch alarm using cloudwatch:PutMetricAlarmCommand with AlarmName "compai-cis-${check.id}-alarm", MetricName "${metricName}", Namespace "CloudTrailMetrics", Statistic "Sum", Period 300, EvaluationPeriods 1, Threshold 1, ComparisonOperator "GreaterThanOrEqualToThreshold", and AlarmActions set to the SNS topic ARN. Rollback by deleting the alarm with cloudwatch:DeleteAlarmsCommand and optionally deleting the SNS topic with sns:DeleteTopicCommand.`, + evidence: { + filterName: matchingFilter.filterName, + metricName, + alarmExists: false, + }, + passed: false, + }), + ); + } else { + // Both filter and alarm exist — pass + findings.push( + this.makeFinding({ + checkId: check.id, + title: `${check.name} — monitoring configured`, + description: `CIS ${check.id} (${check.name}) has both a metric filter and alarm configured.`, + severity: 'info', + evidence: { + filterName: matchingFilter.filterName, + metricName, + alarmExists: true, + }, + passed: true, + }), + ); + } + } + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private async fetchAllMetricFilters(client: CloudWatchLogsClient): Promise< + { + filterName?: string; + filterPattern?: string; + metricTransformations?: { metricName?: string }[]; + }[] + > { + const filters: { + filterName?: string; + filterPattern?: string; + metricTransformations?: { metricName?: string }[]; + }[] = []; + + let nextToken: string | undefined; + + do { + const resp = await client.send( + new DescribeMetricFiltersCommand({ nextToken }), + ); + + for (const filter of resp.metricFilters ?? []) { + filters.push({ + filterName: filter.filterName, + filterPattern: filter.filterPattern, + metricTransformations: filter.metricTransformations?.map((t) => ({ + metricName: t.metricName, + })), + }); + } + + nextToken = resp.nextToken; + + if (filters.length >= 1000) break; + } while (nextToken); + + return filters; + } + + private async checkAlarmExists( + client: CloudWatchClient, + metricName: string, + ): Promise { + // Check common namespaces — customers may use any of these + const namespaces = [ + 'CloudTrailMetrics', + 'CompAI-CIS-Metrics', + 'CISBenchmark', + ]; + try { + for (const ns of namespaces) { + const resp = await client.send( + new DescribeAlarmsForMetricCommand({ + MetricName: metricName, + Namespace: ns, + }), + ); + if ((resp.MetricAlarms ?? []).length > 0) return true; + } + return false; + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return false; + throw error; + } + } + + private makeFinding(params: { + checkId: string; + title: string; + description: string; + severity: SecurityFinding['severity']; + remediation?: string; + evidence?: Record; + passed?: boolean; + }): SecurityFinding { + const id = `cloudwatch-${params.checkId}-${params.title.toLowerCase().replace(/\s+/g, '-')}`; + return { + id, + title: params.title, + description: params.description, + severity: params.severity, + resourceType: 'AwsCloudWatchAlarm', + resourceId: params.checkId, + remediation: params.remediation, + evidence: { ...params.evidence, findingKey: id }, + createdAt: new Date().toISOString(), + passed: params.passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/codebuild.adapter.ts b/apps/api/src/cloud-security/providers/aws/codebuild.adapter.ts new file mode 100644 index 0000000000..82b852e8f0 --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/codebuild.adapter.ts @@ -0,0 +1,122 @@ +import { + CodeBuildClient, + ListProjectsCommand, + BatchGetProjectsCommand, +} from '@aws-sdk/client-codebuild'; +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class CodeBuildAdapter implements AwsServiceAdapter { + readonly serviceId = 'codebuild'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new CodeBuildClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + try { + const projectNames: string[] = []; + let nextToken: string | undefined; + + do { + const listRes = await client.send( + new ListProjectsCommand({ nextToken }), + ); + if (listRes.projects) { + projectNames.push(...listRes.projects); + } + nextToken = listRes.nextToken; + } while (nextToken); + + if (projectNames.length === 0) return findings; + + for (let i = 0; i < projectNames.length; i += 100) { + const batch = projectNames.slice(i, i + 100); + const batchRes = await client.send( + new BatchGetProjectsCommand({ names: batch }), + ); + + for (const project of batchRes.projects ?? []) { + const name = project.name ?? 'unknown'; + const arn = + project.arn ?? `arn:aws:codebuild:${region}:project/${name}`; + + if ( + !project.encryptionKey || + project.encryptionKey.includes('aws/codebuild') + ) { + findings.push( + this.makeFinding({ + id: `codebuild-default-encryption-${name}`, + title: 'Using default encryption key', + description: `CodeBuild project "${name}" uses the default AWS-managed encryption key instead of a customer-managed KMS key.`, + severity: 'low', + resourceId: arn, + evidence: { service: 'CodeBuild', projectName: name }, + remediation: `Use codebuild:UpdateProjectCommand with name set to "${name}" and encryptionKey set to a customer-managed KMS key ARN (arn:aws:kms:region:account:key/key-id). Rollback: use codebuild:UpdateProjectCommand with encryptionKey set to the default 'aws/codebuild' key ARN.`, + }), + ); + } + + if (project.environment?.privilegedMode === true) { + findings.push( + this.makeFinding({ + id: `codebuild-privileged-mode-${name}`, + title: 'Privileged mode enabled', + description: `CodeBuild project "${name}" has privileged mode enabled, granting the build container elevated permissions.`, + severity: 'medium', + resourceId: arn, + evidence: { service: 'CodeBuild', projectName: name }, + remediation: `Use codebuild:UpdateProjectCommand with name set to "${name}" and environment.privilegedMode set to false. Rollback: use codebuild:UpdateProjectCommand with environment.privilegedMode set to true. [MANUAL] Verify that the project does not require Docker-in-Docker builds before disabling, as this will break Docker image builds.`, + }), + ); + } + + const cwEnabled = + project.logsConfig?.cloudWatchLogs?.status === 'ENABLED'; + const s3Enabled = project.logsConfig?.s3Logs?.status === 'ENABLED'; + + if (!cwEnabled && !s3Enabled) { + findings.push( + this.makeFinding({ + id: `codebuild-no-logging-${name}`, + title: 'Build logging not configured', + description: `CodeBuild project "${name}" has neither CloudWatch nor S3 logging enabled.`, + severity: 'medium', + resourceId: arn, + evidence: { service: 'CodeBuild', projectName: name }, + remediation: `Use codebuild:UpdateProjectCommand with name set to "${name}" and logsConfig.cloudWatchLogs set to { status: 'ENABLED', groupName: '/aws/codebuild/${name}' }. Alternatively, set logsConfig.s3Logs to { status: 'ENABLED', location: 'bucket-name/prefix' }. Rollback: use codebuild:UpdateProjectCommand with logsConfig.cloudWatchLogs.status set to 'DISABLED'.`, + }), + ); + } + } + } + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding( + params: Omit & { + remediation?: string; + }, + ): SecurityFinding { + return { + ...params, + resourceType: 'AwsCodeBuildProject', + evidence: { ...params.evidence, findingKey: params.id }, + createdAt: new Date().toISOString(), + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/cognito.adapter.ts b/apps/api/src/cloud-security/providers/aws/cognito.adapter.ts new file mode 100644 index 0000000000..bd12955c0b --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/cognito.adapter.ts @@ -0,0 +1,186 @@ +import { + CognitoIdentityProviderClient, + ListUserPoolsCommand, + DescribeUserPoolCommand, +} from '@aws-sdk/client-cognito-identity-provider'; +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +const MIN_PASSWORD_LENGTH = 14; + +export class CognitoAdapter implements AwsServiceAdapter { + readonly serviceId = 'cognito'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + accountId, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new CognitoIdentityProviderClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + try { + let nextToken: string | undefined; + + do { + const resp = await client.send( + new ListUserPoolsCommand({ + MaxResults: 60, + NextToken: nextToken, + }), + ); + + for (const pool of resp.UserPools ?? []) { + if (!pool.Id) continue; + + const poolFindings = await this.checkPool( + client, + pool.Id, + pool.Name ?? pool.Id, + region, + accountId, + ); + findings.push(...poolFindings); + } + + nextToken = resp.NextToken; + } while (nextToken); + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private async checkPool( + client: CognitoIdentityProviderClient, + poolId: string, + poolName: string, + region: string, + accountId?: string, + ): Promise { + const findings: SecurityFinding[] = []; + + try { + const resp = await client.send( + new DescribeUserPoolCommand({ UserPoolId: poolId }), + ); + + const pool = resp.UserPool; + if (!pool) return []; + + const resourceId = pool.Arn ?? poolId; + + // Check MFA configuration + const mfaConfig = pool.MfaConfiguration; + if (mfaConfig === 'OFF') { + findings.push( + this.makeFinding({ + id: `cognito-mfa-off-${poolId}`, + title: `Cognito user pool "${poolName}" has MFA not enabled (${region})`, + description: `User pool ${poolName} (${poolId}) has multi-factor authentication disabled. Users can sign in with only a password.`, + severity: 'high', + resourceId, + remediation: `Use cognito-idp:SetUserPoolMfaConfigCommand with UserPoolId set to "${poolId}" and MfaConfiguration set to 'ON'. Configure SoftwareTokenMfaConfiguration.Enabled to true for TOTP, or SmsMfaConfiguration with SmsAuthenticationMessage and SmsConfiguration for SMS MFA. Rollback: use cognito-idp:SetUserPoolMfaConfigCommand with MfaConfiguration set to 'OFF'.`, + passed: false, + accountId, + region, + }), + ); + } else if (mfaConfig === 'OPTIONAL') { + findings.push( + this.makeFinding({ + id: `cognito-mfa-optional-${poolId}`, + title: `Cognito user pool "${poolName}" has MFA set to optional (${region})`, + description: `User pool ${poolName} (${poolId}) has MFA configured as optional. Users can choose to skip MFA enrollment.`, + severity: 'medium', + resourceId, + remediation: `Use cognito-idp:SetUserPoolMfaConfigCommand with UserPoolId set to "${poolId}" and MfaConfiguration set to 'ON' (enforced). Configure SoftwareTokenMfaConfiguration.Enabled to true. Rollback: use cognito-idp:SetUserPoolMfaConfigCommand with MfaConfiguration set to 'OPTIONAL'.`, + passed: false, + accountId, + region, + }), + ); + } + + // Check password policy + const minLength = pool.Policies?.PasswordPolicy?.MinimumLength ?? 0; + if (minLength < MIN_PASSWORD_LENGTH) { + findings.push( + this.makeFinding({ + id: `cognito-weak-password-${poolId}`, + title: `Cognito user pool "${poolName}" has weak password policy (${region})`, + description: `User pool ${poolName} (${poolId}) requires a minimum password length of ${minLength} characters. The recommended minimum is ${MIN_PASSWORD_LENGTH} characters.`, + severity: 'medium', + resourceId, + remediation: `Use cognito-idp:UpdateUserPoolCommand with UserPoolId set to "${poolId}" and Policies.PasswordPolicy.MinimumLength set to ${MIN_PASSWORD_LENGTH}. Also ensure RequireLowercase, RequireUppercase, RequireNumbers, and RequireSymbols are set to true. You must include all existing pool configuration to avoid resetting other settings. Rollback: use cognito-idp:UpdateUserPoolCommand with the previous MinimumLength value.`, + passed: false, + accountId, + region, + }), + ); + } + + // Check advanced security mode + const securityMode = pool.UserPoolAddOns?.AdvancedSecurityMode; + if (securityMode !== 'ENFORCED') { + findings.push( + this.makeFinding({ + id: `cognito-no-advanced-security-${poolId}`, + title: `Cognito user pool "${poolName}" does not have advanced security enforced (${region})`, + description: `User pool ${poolName} (${poolId}) does not have advanced security mode set to ENFORCED. Adaptive authentication and compromised credential detection are not fully active.`, + severity: 'low', + resourceId, + remediation: `Use cognito-idp:UpdateUserPoolCommand with UserPoolId set to "${poolId}" and UserPoolAddOns.AdvancedSecurityMode set to 'ENFORCED'. You must include all existing pool configuration to avoid resetting other settings. Rollback: use cognito-idp:UpdateUserPoolCommand with AdvancedSecurityMode set to 'AUDIT' or 'OFF'.`, + passed: false, + accountId, + region, + }), + ); + } + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding(opts: { + id: string; + title: string; + description: string; + severity: SecurityFinding['severity']; + resourceId?: string; + remediation?: string; + passed: boolean; + accountId?: string; + region?: string; + }): SecurityFinding { + return { + id: opts.id, + title: opts.title, + description: opts.description, + severity: opts.severity, + resourceType: 'AwsCognitoUserPool', + resourceId: opts.resourceId || 'unknown', + remediation: opts.remediation, + evidence: { + awsAccountId: opts.accountId, + region: opts.region, + service: 'Cognito', + findingKey: opts.id, + }, + createdAt: new Date().toISOString(), + passed: opts.passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/config.adapter.ts b/apps/api/src/cloud-security/providers/aws/config.adapter.ts new file mode 100644 index 0000000000..2994e1181c --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/config.adapter.ts @@ -0,0 +1,182 @@ +import { + ConfigServiceClient, + DescribeConfigurationRecordersCommand, + DescribeConfigurationRecorderStatusCommand, + DescribeDeliveryChannelsCommand, + DescribeDeliveryChannelStatusCommand, +} from '@aws-sdk/client-config-service'; +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class ConfigAdapter implements AwsServiceAdapter { + readonly serviceId = 'config'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new ConfigServiceClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + try { + const [recorderResult, deliveryResult] = await Promise.allSettled([ + this.checkRecorders(client, region), + this.checkDeliveryChannels(client, region), + ]); + + if (recorderResult.status === 'fulfilled') { + findings.push(...recorderResult.value); + } + if (deliveryResult.status === 'fulfilled') { + findings.push(...deliveryResult.value); + } + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private async checkRecorders( + client: ConfigServiceClient, + region: string, + ): Promise { + const findings: SecurityFinding[] = []; + + const { ConfigurationRecorders } = await client.send( + new DescribeConfigurationRecordersCommand({}), + ); + + if (!ConfigurationRecorders || ConfigurationRecorders.length === 0) { + findings.push( + this.makeFinding({ + id: `config-no-recorder-${region}`, + title: 'AWS Config recorder not configured', + description: `No AWS Config recorder found in ${region}.`, + severity: 'high', + resourceId: `arn:aws:config:${region}`, + remediation: + 'Step 1: Create a service-linked role using iam:CreateServiceLinkedRoleCommand with AWSServiceName set to "config.amazonaws.com" (skip if the role already exists). Step 2: Create a configuration recorder using config-service:PutConfigurationRecorderCommand with ConfigurationRecorder containing name "compai-config-recorder", roleARN set to the Config service role ARN, and recordingGroup with allSupported set to true. Step 3: Create a delivery channel using config-service:PutDeliveryChannelCommand with DeliveryChannel containing name "compai-delivery-channel" and s3BucketName set to the target bucket. Step 4: Start the recorder using config-service:StartConfigurationRecorderCommand with ConfigurationRecorderName "compai-config-recorder". Rollback by calling config-service:StopConfigurationRecorderCommand with ConfigurationRecorderName "compai-config-recorder".', + }), + ); + return findings; + } + + const { ConfigurationRecordersStatus } = await client.send( + new DescribeConfigurationRecorderStatusCommand({}), + ); + + const status = ConfigurationRecordersStatus?.[0]; + const recorder = ConfigurationRecorders[0]; + const isRecording = status?.recording === true; + const allSupported = recorder?.recordingGroup?.allSupported === true; + + if (isRecording && allSupported) { + findings.push( + this.makeFinding({ + id: `config-recorder-enabled-${region}`, + title: 'AWS Config recorder is active', + description: `AWS Config recorder in ${region} is recording all supported resources.`, + severity: 'info', + resourceId: recorder.name ?? `config-recorder-${region}`, + passed: true, + }), + ); + } else { + findings.push( + this.makeFinding({ + id: `config-recorder-incomplete-${region}`, + title: 'AWS Config recorder not fully active', + description: `AWS Config recorder in ${region} is ${!isRecording ? 'not recording' : 'not recording all supported resources'}.`, + severity: 'high', + resourceId: recorder.name ?? `config-recorder-${region}`, + remediation: + 'Use config-service:PutConfigurationRecorderCommand with ConfigurationRecorder containing the existing recorder name, roleARN, and recordingGroup with allSupported set to true. Then call config-service:StartConfigurationRecorderCommand with ConfigurationRecorderName set to the recorder name. Rollback by calling config-service:StopConfigurationRecorderCommand with ConfigurationRecorderName set to the recorder name.', + }), + ); + } + + return findings; + } + + private async checkDeliveryChannels( + client: ConfigServiceClient, + region: string, + ): Promise { + const findings: SecurityFinding[] = []; + + const { DeliveryChannels } = await client.send( + new DescribeDeliveryChannelsCommand({}), + ); + + if (!DeliveryChannels || DeliveryChannels.length === 0) { + findings.push( + this.makeFinding({ + id: `config-no-delivery-channel-${region}`, + title: 'AWS Config delivery channel not configured', + description: `No delivery channel found for AWS Config in ${region}.`, + severity: 'medium', + resourceId: `arn:aws:config:${region}`, + remediation: + 'Use config-service:PutDeliveryChannelCommand with DeliveryChannel containing name "compai-delivery-channel" and s3BucketName set to the target logging bucket. Rollback by calling config-service:DeleteDeliveryChannelCommand with DeliveryChannelName "compai-delivery-channel". Note: the configuration recorder must be stopped before deleting a delivery channel.', + }), + ); + return findings; + } + + const { DeliveryChannelsStatus } = await client.send( + new DescribeDeliveryChannelStatusCommand({}), + ); + + const channel = DeliveryChannels[0]; + const hasS3 = !!channel?.s3BucketName; + const statusOk = DeliveryChannelsStatus?.[0] !== undefined; + + if (hasS3 && statusOk) { + findings.push( + this.makeFinding({ + id: `config-delivery-channel-ok-${region}`, + title: 'AWS Config delivery channel configured', + description: `Delivery channel in ${region} is configured with S3 bucket ${channel.s3BucketName}.`, + severity: 'info', + resourceId: channel.name ?? `config-delivery-${region}`, + passed: true, + }), + ); + } else { + findings.push( + this.makeFinding({ + id: `config-delivery-channel-issue-${region}`, + title: 'AWS Config delivery channel misconfigured', + description: `Delivery channel in ${region} is missing an S3 bucket configuration.`, + severity: 'medium', + resourceId: channel?.name ?? `config-delivery-${region}`, + remediation: + 'Use config-service:PutDeliveryChannelCommand with DeliveryChannel containing the existing channel name and s3BucketName set to the target logging bucket. Rollback by calling config-service:PutDeliveryChannelCommand with the original delivery channel settings.', + }), + ); + } + + return findings; + } + + private makeFinding( + params: Omit & { + remediation?: string; + }, + ): SecurityFinding { + return { + ...params, + evidence: { ...(params.evidence ?? {}), findingKey: params.id }, + resourceType: 'AwsConfigRecorder', + createdAt: new Date().toISOString(), + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/dynamodb.adapter.ts b/apps/api/src/cloud-security/providers/aws/dynamodb.adapter.ts new file mode 100644 index 0000000000..0c150e043b --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/dynamodb.adapter.ts @@ -0,0 +1,189 @@ +import { + DescribeContinuousBackupsCommand, + DescribeTableCommand, + DynamoDBClient, + ListTablesCommand, +} from '@aws-sdk/client-dynamodb'; + +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class DynamoDbAdapter implements AwsServiceAdapter { + readonly serviceId = 'dynamodb'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new DynamoDBClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + try { + let exclusiveStartTableName: string | undefined; + + do { + const listRes = await client.send( + new ListTablesCommand({ + ExclusiveStartTableName: exclusiveStartTableName, + }), + ); + + for (const tableName of listRes.TableNames ?? []) { + const descRes = await client.send( + new DescribeTableCommand({ TableName: tableName }), + ); + + const table = descRes.Table; + if (!table) continue; + + const resourceId = table.TableArn ?? tableName; + + // Check SSE configuration + const sse = table.SSEDescription; + if (sse?.Status === 'ENABLED' && sse.SSEType === 'KMS') { + findings.push( + this.makeFinding( + resourceId, + 'DynamoDB table uses CMK encryption', + `Table "${tableName}" is encrypted with a customer-managed KMS key`, + 'info', + { tableName, sseType: sse.SSEType, sseStatus: sse.Status }, + true, + ), + ); + } else if (sse?.Status === 'ENABLED') { + findings.push( + this.makeFinding( + resourceId, + 'DynamoDB table uses default AWS-owned key', + `Table "${tableName}" uses the default AWS-owned encryption key instead of a customer-managed KMS key`, + 'low', + { + tableName, + sseType: sse.SSEType ?? 'DEFAULT', + sseStatus: sse.Status, + }, + undefined, + `Use dynamodb:UpdateTableCommand with TableName set to "${tableName}" and SSESpecification.SSEEnabled set to true and SSESpecification.SSEType set to 'KMS'. Optionally provide SSESpecification.KMSMasterKeyId for a specific CMK. Rollback by setting SSESpecification.SSEEnabled to false to revert to the default AWS-owned key.`, + ), + ); + } else { + findings.push( + this.makeFinding( + resourceId, + 'DynamoDB table uses default AWS-owned key', + `Table "${tableName}" does not have customer-managed encryption configured`, + 'medium', + { tableName, sseStatus: sse?.Status ?? 'NOT_CONFIGURED' }, + undefined, + `Use dynamodb:UpdateTableCommand with TableName set to "${tableName}" and SSESpecification.SSEEnabled set to true and SSESpecification.SSEType set to 'KMS'. Optionally provide SSESpecification.KMSMasterKeyId for a specific CMK. Rollback by setting SSESpecification.SSEEnabled to false to revert to the default AWS-owned key.`, + ), + ); + } + + // Check Point-in-Time Recovery + try { + const backupRes = await client.send( + new DescribeContinuousBackupsCommand({ TableName: tableName }), + ); + + const pitrStatus = + backupRes.ContinuousBackupsDescription + ?.PointInTimeRecoveryDescription?.PointInTimeRecoveryStatus; + + if (pitrStatus === 'ENABLED') { + findings.push( + this.makeFinding( + resourceId, + 'DynamoDB point-in-time recovery is enabled', + `Table "${tableName}" has point-in-time recovery enabled`, + 'info', + { tableName, pitrStatus }, + true, + ), + ); + } else { + findings.push( + this.makeFinding( + resourceId, + 'DynamoDB point-in-time recovery is disabled', + `Table "${tableName}" does not have point-in-time recovery enabled`, + 'medium', + { tableName, pitrStatus: pitrStatus ?? 'DISABLED' }, + undefined, + `Use dynamodb:UpdateContinuousBackupsCommand with TableName set to "${tableName}" and PointInTimeRecoverySpecification.PointInTimeRecoveryEnabled set to true. Rollback by setting PointInTimeRecoveryEnabled to false.`, + ), + ); + } + } catch (error: unknown) { + const msg = error instanceof Error ? error.message : String(error); + if (!msg.includes('AccessDenied')) throw error; + } + + // Check Deletion Protection + if (table.DeletionProtectionEnabled === true) { + findings.push( + this.makeFinding( + resourceId, + 'DynamoDB deletion protection is enabled', + `Table "${tableName}" has deletion protection enabled`, + 'info', + { tableName, deletionProtection: true }, + true, + ), + ); + } else { + findings.push( + this.makeFinding( + resourceId, + 'DynamoDB deletion protection is disabled', + `Table "${tableName}" does not have deletion protection enabled`, + 'medium', + { tableName, deletionProtection: false }, + undefined, + `Use dynamodb:UpdateTableCommand with TableName set to "${tableName}" and DeletionProtectionEnabled set to true. Rollback by setting DeletionProtectionEnabled to false.`, + ), + ); + } + } + + exclusiveStartTableName = listRes.LastEvaluatedTableName; + } while (exclusiveStartTableName); + } catch (error: unknown) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding( + resourceId: string, + title: string, + description: string, + severity: SecurityFinding['severity'], + evidence?: Record, + passed?: boolean, + remediation?: string, + ): SecurityFinding { + const id = `dynamodb-${resourceId}-${title.toLowerCase().replace(/\s+/g, '-')}`; + return { + id, + title, + description, + severity, + resourceType: 'AwsDynamoDbTable', + resourceId, + remediation, + evidence: { ...evidence, findingKey: id }, + createdAt: new Date().toISOString(), + passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/ec2-vpc.adapter.ts b/apps/api/src/cloud-security/providers/aws/ec2-vpc.adapter.ts new file mode 100644 index 0000000000..0f536f48dd --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/ec2-vpc.adapter.ts @@ -0,0 +1,295 @@ +import { + EC2Client, + DescribeSecurityGroupsCommand, + GetEbsEncryptionByDefaultCommand, + DescribeVpcsCommand, + DescribeFlowLogsCommand, + DescribeInstancesCommand, +} from '@aws-sdk/client-ec2'; +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +/** Ports that should never be open to 0.0.0.0/0 */ +const SENSITIVE_PORTS: Record = { + 22: 'SSH', + 3389: 'RDP', + 3306: 'MySQL', + 5432: 'PostgreSQL', + 1433: 'MSSQL', + 27017: 'MongoDB', + 6379: 'Redis', + 9200: 'Elasticsearch', +}; + +export class Ec2VpcAdapter implements AwsServiceAdapter { + readonly serviceId = 'ec2-vpc'; + readonly isGlobal = false; + + async scan(params: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const { credentials, region, accountId } = params; + const client = new EC2Client({ region, credentials }); + + const results = await Promise.allSettled([ + this.checkSecurityGroups(client, region, accountId), + this.checkEbsEncryptionDefault(client, region, accountId), + this.checkVpcFlowLogs(client, region, accountId), + ]); + + const findings: SecurityFinding[] = []; + for (const result of results) { + if (result.status === 'fulfilled') { + findings.push(...result.value); + } + } + return findings; + } + + private async checkSecurityGroups( + client: EC2Client, + region: string, + accountId?: string, + ): Promise { + const findings: SecurityFinding[] = []; + + let nextToken: string | undefined; + do { + const resp = await client.send( + new DescribeSecurityGroupsCommand({ + MaxResults: 100, + NextToken: nextToken, + }), + ); + + for (const sg of resp.SecurityGroups || []) { + if (!sg.GroupId) continue; + + for (const rule of sg.IpPermissions || []) { + const openRanges = [ + ...(rule.IpRanges || []).filter((r) => r.CidrIp === '0.0.0.0/0'), + ...(rule.Ipv6Ranges || []).filter((r) => r.CidrIpv6 === '::/0'), + ]; + + if (openRanges.length === 0) continue; + + // Check if this rule exposes a sensitive port + const fromPort = rule.FromPort ?? 0; + const toPort = rule.ToPort ?? 65535; + + for (const [port, service] of Object.entries(SENSITIVE_PORTS)) { + const portNum = Number(port); + if (fromPort <= portNum && portNum <= toPort) { + findings.push( + this.makeFinding({ + id: `ec2-sg-open-${sg.GroupId}-${portNum}`, + title: `Security group "${sg.GroupName || sg.GroupId}" allows ${service} (${portNum}) from 0.0.0.0/0 (${region})`, + description: `Security group ${sg.GroupId} in VPC ${sg.VpcId || 'default'} allows unrestricted inbound access on port ${portNum} (${service}). This exposes the service to the entire internet.`, + severity: + portNum === 22 || portNum === 3389 ? 'high' : 'critical', + resourceType: 'AwsEc2SecurityGroup', + resourceId: sg.GroupId, + remediation: `Use ec2:RevokeSecurityGroupIngressCommand with GroupId set to '${sg.GroupId}' and IpPermissions containing FromPort: ${portNum}, ToPort: ${portNum}, IpProtocol: 'tcp', and IpRanges: [{ CidrIp: '0.0.0.0/0' }] to remove the open rule. Then use ec2:AuthorizeSecurityGroupIngressCommand with restricted CidrIp values. Rollback: use ec2:AuthorizeSecurityGroupIngressCommand with the original 0.0.0.0/0 CidrIp.`, + passed: false, + accountId, + region, + }), + ); + } + } + + // Check for "all traffic" rule (protocol -1) + if (rule.IpProtocol === '-1') { + findings.push( + this.makeFinding({ + id: `ec2-sg-all-traffic-${sg.GroupId}`, + title: `Security group "${sg.GroupName || sg.GroupId}" allows all traffic from 0.0.0.0/0 (${region})`, + description: `Security group ${sg.GroupId} allows all inbound traffic from any source. This is a critical security risk.`, + severity: 'critical', + resourceType: 'AwsEc2SecurityGroup', + resourceId: sg.GroupId, + remediation: `Use ec2:RevokeSecurityGroupIngressCommand with GroupId set to '${sg.GroupId}' and IpPermissions containing IpProtocol: '-1' and IpRanges: [{ CidrIp: '0.0.0.0/0' }]. Then use ec2:AuthorizeSecurityGroupIngressCommand to add specific port/protocol rules with restricted CIDR ranges. Rollback: use ec2:AuthorizeSecurityGroupIngressCommand with IpProtocol '-1' and CidrIp '0.0.0.0/0'.`, + passed: false, + accountId, + region, + }), + ); + } + } + } + + nextToken = resp.NextToken; + } while (nextToken); + + return findings; + } + + private async checkEbsEncryptionDefault( + client: EC2Client, + region: string, + accountId?: string, + ): Promise { + // Prerequisite: skip EBS default encryption check if no instances or volumes exist + try { + const instanceResp = await client.send( + new DescribeInstancesCommand({ MaxResults: 5 }), + ); + const hasInstances = (instanceResp.Reservations ?? []).some( + (r) => (r.Instances ?? []).length > 0, + ); + if (!hasInstances) return []; + } catch { + // If prerequisite check fails (permissions), fall through to existing behavior + } + + const resp = await client.send(new GetEbsEncryptionByDefaultCommand({})); + + if (!resp.EbsEncryptionByDefault) { + return [ + this.makeFinding({ + id: `ec2-ebs-encryption-default-${region}`, + title: `EBS encryption by default is disabled (${region})`, + description: `New EBS volumes in ${region} are not encrypted by default. Unencrypted volumes may expose sensitive data.`, + severity: 'medium', + resourceType: 'AwsAccount', + resourceId: `${region}/ebs-default-encryption`, + remediation: `Use ec2:EnableEbsEncryptionByDefaultCommand (no parameters required, applies to the current region). Optionally use ec2:ModifyEbsDefaultKmsKeyIdCommand with KmsKeyId to set a specific CMK. Only new volumes will be encrypted; existing unencrypted volumes are not affected. Rollback: use ec2:DisableEbsEncryptionByDefaultCommand.`, + passed: false, + accountId, + region, + }), + ]; + } + + return [ + this.makeFinding({ + id: `ec2-ebs-encryption-default-${region}`, + title: `EBS encryption by default is enabled (${region})`, + description: `New EBS volumes are encrypted by default in ${region}.`, + severity: 'info', + resourceType: 'AwsAccount', + resourceId: `${region}/ebs-default-encryption`, + passed: true, + accountId, + region, + }), + ]; + } + + private async checkVpcFlowLogs( + client: EC2Client, + region: string, + accountId?: string, + ): Promise { + const findings: SecurityFinding[] = []; + + const vpcsResp = await client.send(new DescribeVpcsCommand({})); + const vpcs = vpcsResp.Vpcs || []; + + if (vpcs.length === 0) return findings; + + const flowLogsResp = await client.send( + new DescribeFlowLogsCommand({ + Filter: [{ Name: 'resource-type', Values: ['VPC'] }], + }), + ); + + const vpcsWithFlowLogs = new Set( + (flowLogsResp.FlowLogs || []).map((fl) => fl.ResourceId), + ); + + for (const vpc of vpcs) { + if (!vpc.VpcId) continue; + + // Skip default VPC if it has no running instances + if (vpc.IsDefault) { + try { + const instanceResp = await client.send( + new DescribeInstancesCommand({ + MaxResults: 5, + Filters: [ + { Name: 'vpc-id', Values: [vpc.VpcId] }, + { Name: 'instance-state-name', Values: ['running'] }, + ], + }), + ); + const hasRunning = (instanceResp.Reservations ?? []).some( + (r) => (r.Instances ?? []).length > 0, + ); + if (!hasRunning) continue; + } catch { + // If check fails (permissions), fall through to existing behavior + } + } + + const nameTag = vpc.Tags?.find((t) => t.Key === 'Name')?.Value; + const label = nameTag ? `"${nameTag}" (${vpc.VpcId})` : vpc.VpcId; + + if (!vpcsWithFlowLogs.has(vpc.VpcId)) { + findings.push( + this.makeFinding({ + id: `vpc-no-flow-logs-${vpc.VpcId}`, + title: `VPC ${label} has no flow logs enabled (${region})`, + description: `VPC ${vpc.VpcId} in ${region} does not have flow logs enabled. Network traffic is not being monitored.`, + severity: 'medium', + resourceType: 'AwsEc2Vpc', + resourceId: vpc.VpcId, + remediation: `Use ec2:CreateFlowLogsCommand with ResourceIds set to ['${vpc.VpcId}'], ResourceType set to 'VPC', TrafficType set to 'ALL', LogDestinationType set to 'cloud-watch-logs', and LogGroupName set to '/aws/vpc-flow-logs/${vpc.VpcId}'. You must provide DeliverLogsPermissionArn with an IAM role ARN that can publish to CloudWatch Logs. Rollback: use ec2:DeleteFlowLogsCommand with the FlowLogIds returned from the create call.`, + passed: false, + accountId, + region, + }), + ); + } else { + findings.push( + this.makeFinding({ + id: `vpc-flow-logs-${vpc.VpcId}`, + title: `VPC ${label} has flow logs enabled (${region})`, + description: `Flow logs are enabled for VPC ${vpc.VpcId}.`, + severity: 'info', + resourceType: 'AwsEc2Vpc', + resourceId: vpc.VpcId, + passed: true, + accountId, + region, + }), + ); + } + } + + return findings; + } + + private makeFinding(opts: { + id: string; + title: string; + description: string; + severity: SecurityFinding['severity']; + resourceType?: string; + resourceId?: string; + remediation?: string; + passed: boolean; + accountId?: string; + region?: string; + }): SecurityFinding { + return { + id: opts.id, + title: opts.title, + description: opts.description, + severity: opts.severity, + resourceType: opts.resourceType || 'AwsEc2Instance', + resourceId: opts.resourceId || 'unknown', + remediation: opts.remediation, + evidence: { + awsAccountId: opts.accountId, + region: opts.region, + service: 'EC2/VPC', + findingKey: opts.id, + }, + createdAt: new Date().toISOString(), + passed: opts.passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/ecr.adapter.ts b/apps/api/src/cloud-security/providers/aws/ecr.adapter.ts new file mode 100644 index 0000000000..efc2b1fab3 --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/ecr.adapter.ts @@ -0,0 +1,105 @@ +import { ECRClient, DescribeRepositoriesCommand } from '@aws-sdk/client-ecr'; +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class EcrAdapter implements AwsServiceAdapter { + readonly serviceId = 'ecr'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new ECRClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + try { + let nextToken: string | undefined; + + do { + const response = await client.send( + new DescribeRepositoriesCommand({ nextToken }), + ); + + for (const repo of response.repositories ?? []) { + const repoName = repo.repositoryName ?? 'unknown'; + const repoArn = + repo.repositoryArn ?? `arn:aws:ecr:${region}:repo/${repoName}`; + + if (repo.imageScanningConfiguration?.scanOnPush !== true) { + findings.push( + this.makeFinding({ + id: `ecr-scan-on-push-disabled-${repoName}`, + title: `ECR scan on push disabled for ${repoName}`, + description: `Repository ${repoName} does not have image scan on push enabled.`, + severity: 'medium', + resourceId: repoArn, + remediation: `Use ecr:PutImageScanningConfigurationCommand with repositoryName set to "${repoName}" and imageScanningConfiguration.scanOnPush set to true. Rollback by setting scanOnPush to false.`, + }), + ); + } else { + findings.push( + this.makeFinding({ + id: `ecr-scan-on-push-enabled-${repoName}`, + title: `ECR scan on push enabled for ${repoName}`, + description: `Repository ${repoName} has image scan on push enabled.`, + severity: 'info', + resourceId: repoArn, + passed: true, + }), + ); + } + + if (repo.imageTagMutability !== 'IMMUTABLE') { + findings.push( + this.makeFinding({ + id: `ecr-tag-mutable-${repoName}`, + title: `ECR image tags mutable for ${repoName}`, + description: `Repository ${repoName} allows image tag overwriting. Tags should be immutable.`, + severity: 'low', + resourceId: repoArn, + remediation: `Use ecr:PutImageTagMutabilityCommand with repositoryName set to "${repoName}" and imageTagMutability set to 'IMMUTABLE'. Rollback by setting imageTagMutability to 'MUTABLE'.`, + }), + ); + } else { + findings.push( + this.makeFinding({ + id: `ecr-tag-immutable-${repoName}`, + title: `ECR image tags immutable for ${repoName}`, + description: `Repository ${repoName} has immutable image tags configured.`, + severity: 'info', + resourceId: repoArn, + passed: true, + }), + ); + } + } + + nextToken = response.nextToken; + } while (nextToken); + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding( + params: Omit & { + remediation?: string; + }, + ): SecurityFinding { + return { + ...params, + evidence: { ...(params.evidence ?? {}), findingKey: params.id }, + resourceType: 'AwsEcrRepository', + createdAt: new Date().toISOString(), + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/ecs-eks.adapter.ts b/apps/api/src/cloud-security/providers/aws/ecs-eks.adapter.ts new file mode 100644 index 0000000000..797cf3055a --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/ecs-eks.adapter.ts @@ -0,0 +1,231 @@ +import { + ECSClient, + ListTaskDefinitionsCommand, + DescribeTaskDefinitionCommand, +} from '@aws-sdk/client-ecs'; +import { + EKSClient, + ListClustersCommand, + DescribeClusterCommand, +} from '@aws-sdk/client-eks'; +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +const EXPECTED_EKS_LOG_TYPES = [ + 'api', + 'audit', + 'authenticator', + 'controllerManager', + 'scheduler', +]; + +export class EcsEksAdapter implements AwsServiceAdapter { + readonly serviceId = 'ecs-eks'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const findings: SecurityFinding[] = []; + + try { + await this.scanEcs({ credentials, region, findings }); + await this.scanEks({ credentials, region, findings }); + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private async scanEcs({ + credentials, + region, + findings, + }: { + credentials: AwsCredentials; + region: string; + findings: SecurityFinding[]; + }): Promise { + const client = new ECSClient({ credentials, region }); + + let nextToken: string | undefined; + let taskDefArns: string[] = []; + + do { + const resp = await client.send( + new ListTaskDefinitionsCommand({ + status: 'ACTIVE', + nextToken, + }), + ); + + taskDefArns = taskDefArns.concat(resp.taskDefinitionArns ?? []); + nextToken = resp.nextToken; + + // Limit to first 50 task definitions + if (taskDefArns.length >= 50) { + taskDefArns = taskDefArns.slice(0, 50); + break; + } + } while (nextToken); + + for (const taskDefArn of taskDefArns) { + try { + const resp = await client.send( + new DescribeTaskDefinitionCommand({ taskDefinition: taskDefArn }), + ); + + const containers = resp.taskDefinition?.containerDefinitions ?? []; + + for (const container of containers) { + if (container.privileged === true) { + findings.push( + this.makeFinding({ + resourceId: taskDefArn, + resourceType: 'AwsEcsTaskDefinition', + title: `Container ${container.name ?? 'unknown'} runs in privileged mode`, + description: `ECS task definition ${taskDefArn} has container "${container.name}" running in privileged mode. Privileged containers have full access to the host.`, + severity: 'high', + remediation: + "[MANUAL] Cannot be auto-fixed for existing running tasks. Register a new task definition revision using ecs:RegisterTaskDefinitionCommand with the container definition's privileged field set to false, then update the service using ecs:UpdateServiceCommand with the new taskDefinition ARN. Rollback: register another revision with privileged set to true and update the service.", + evidence: { + containerName: container.name, + privileged: true, + }, + }), + ); + } + } + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return; + } + } + } + + private async scanEks({ + credentials, + region, + findings, + }: { + credentials: AwsCredentials; + region: string; + findings: SecurityFinding[]; + }): Promise { + const client = new EKSClient({ credentials, region }); + + let nextToken: string | undefined; + + do { + const resp = await client.send(new ListClustersCommand({ nextToken })); + + const clusterNames = resp.clusters ?? []; + + for (const clusterName of clusterNames) { + try { + const descResp = await client.send( + new DescribeClusterCommand({ name: clusterName }), + ); + + const cluster = descResp.cluster; + if (!cluster) continue; + + const clusterArn = cluster.arn ?? clusterName; + + // Check cluster logging + const clusterLogging = cluster.logging?.clusterLogging ?? []; + const enabledTypes = new Set(); + + for (const logSetup of clusterLogging) { + if (logSetup.enabled) { + for (const logType of logSetup.types ?? []) { + enabledTypes.add(logType); + } + } + } + + const disabledTypes = EXPECTED_EKS_LOG_TYPES.filter( + (t) => !enabledTypes.has(t), + ); + + if (disabledTypes.length > 0) { + findings.push( + this.makeFinding({ + resourceId: clusterArn, + resourceType: 'AwsEksCluster', + title: 'EKS cluster logging incomplete', + description: `EKS cluster ${clusterName} does not have all recommended log types enabled. Missing: ${disabledTypes.join(', ')}.`, + severity: 'medium', + remediation: `Use eks:UpdateClusterConfigCommand with name set to '${clusterName}' and logging.clusterLogging set to [{ types: ['api', 'audit', 'authenticator', 'controllerManager', 'scheduler'], enabled: true }]. Rollback: use eks:UpdateClusterConfigCommand with enabled set to false for the added log types.`, + evidence: { + enabledTypes: [...enabledTypes], + disabledTypes, + }, + }), + ); + } + + // Check public API endpoint + const vpcConfig = cluster.resourcesVpcConfig; + if (vpcConfig?.endpointPublicAccess === true) { + const cidrs = vpcConfig.publicAccessCidrs ?? []; + if (cidrs.includes('0.0.0.0/0')) { + findings.push( + this.makeFinding({ + resourceId: clusterArn, + resourceType: 'AwsEksCluster', + title: 'EKS API publicly accessible', + description: `EKS cluster ${clusterName} has its API endpoint publicly accessible from any IP address (0.0.0.0/0).`, + severity: 'high', + remediation: `Use eks:UpdateClusterConfigCommand with name set to '${clusterName}' and resourcesVpcConfig.endpointPublicAccess set to false (or keep true and set publicAccessCidrs to specific CIDR ranges instead of '0.0.0.0/0'). Ensure endpointPrivateAccess is set to true if disabling public access. Rollback: use eks:UpdateClusterConfigCommand with endpointPublicAccess set to true and publicAccessCidrs set to ['0.0.0.0/0'].`, + evidence: { + endpointPublicAccess: true, + publicAccessCidrs: cidrs, + }, + }), + ); + } + } + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return; + } + } + + nextToken = resp.nextToken; + } while (nextToken); + } + + private makeFinding(params: { + resourceId: string; + resourceType: string; + title: string; + description: string; + severity: SecurityFinding['severity']; + remediation?: string; + evidence?: Record; + passed?: boolean; + }): SecurityFinding { + const id = `ecs-eks-${params.resourceId}-${params.title.toLowerCase().replace(/\s+/g, '-')}`; + return { + id, + title: params.title, + description: params.description, + severity: params.severity, + resourceType: params.resourceType, + resourceId: params.resourceId, + remediation: params.remediation, + evidence: { ...params.evidence, findingKey: id }, + createdAt: new Date().toISOString(), + passed: params.passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/efs.adapter.ts b/apps/api/src/cloud-security/providers/aws/efs.adapter.ts new file mode 100644 index 0000000000..dc01d2d7c7 --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/efs.adapter.ts @@ -0,0 +1,92 @@ +import { EFSClient, DescribeFileSystemsCommand } from '@aws-sdk/client-efs'; + +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class EfsAdapter implements AwsServiceAdapter { + readonly serviceId = 'efs'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new EFSClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + try { + let marker: string | undefined; + + do { + const res = await client.send( + new DescribeFileSystemsCommand({ Marker: marker }), + ); + + for (const fs of res.FileSystems ?? []) { + const resourceId = fs.FileSystemId ?? 'unknown'; + + if (fs.Encrypted !== true) { + findings.push( + this.makeFinding( + resourceId, + 'EFS not encrypted at rest', + `EFS file system "${resourceId}" is not encrypted at rest`, + 'high', + { fileSystemId: resourceId, encrypted: false }, + undefined, + `[MANUAL] Cannot be auto-fixed. EFS encryption at rest must be set at file system creation time and cannot be changed afterward. To fix: create a new encrypted EFS file system using efs:CreateFileSystemCommand with Encrypted set to true and optionally KmsKeyId for a customer-managed CMK, migrate data from the unencrypted file system using AWS DataSync, update all mount targets and application references to point to the new file system, then delete the old unencrypted file system using efs:DeleteFileSystemCommand.`, + ), + ); + } else { + findings.push( + this.makeFinding( + resourceId, + 'EFS encrypted at rest', + `EFS file system "${resourceId}" is encrypted at rest`, + 'info', + { fileSystemId: resourceId, encrypted: true }, + true, + ), + ); + } + } + + marker = res.NextMarker; + } while (marker); + } catch (error: unknown) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding( + resourceId: string, + title: string, + description: string, + severity: SecurityFinding['severity'], + evidence?: Record, + passed?: boolean, + remediation?: string, + ): SecurityFinding { + const id = `efs-${resourceId}-${title.toLowerCase().replace(/\s+/g, '-')}`; + return { + id, + title, + description, + severity, + resourceType: 'AwsEfsFileSystem', + resourceId, + remediation, + evidence: { ...evidence, service: 'EFS', findingKey: id }, + createdAt: new Date().toISOString(), + passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/elastic-beanstalk.adapter.ts b/apps/api/src/cloud-security/providers/aws/elastic-beanstalk.adapter.ts new file mode 100644 index 0000000000..fd7c422665 --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/elastic-beanstalk.adapter.ts @@ -0,0 +1,164 @@ +import { + DescribeConfigurationSettingsCommand, + DescribeEnvironmentsCommand, + ElasticBeanstalkClient, +} from '@aws-sdk/client-elastic-beanstalk'; + +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class ElasticBeanstalkAdapter implements AwsServiceAdapter { + readonly serviceId = 'elastic-beanstalk'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new ElasticBeanstalkClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + try { + const envRes = await client.send( + new DescribeEnvironmentsCommand({ IncludeDeleted: false }), + ); + + for (const env of envRes.Environments ?? []) { + const envName = env.EnvironmentName ?? 'unknown'; + const envId = env.EnvironmentId ?? envName; + const envArn = env.EnvironmentArn ?? envId; + + if (env.HealthStatus && env.HealthStatus !== 'Ok') { + findings.push( + this.makeFinding( + envArn, + 'Environment health status is not Ok', + `Environment "${envName}" has health status "${env.HealthStatus}"`, + 'low', + { + environmentName: envName, + healthStatus: env.HealthStatus, + service: 'Elastic Beanstalk', + }, + false, + `[MANUAL] Cannot be auto-fixed. Investigate the environment health by reviewing recent events and logs. Use elasticbeanstalk:DescribeEventsCommand with EnvironmentName set to '${envName}' to check for errors. Common causes include failed deployments, instance health issues, or resource limits.`, + ), + ); + } + + const appName = env.ApplicationName; + if (!appName) continue; + + const configRes = await client.send( + new DescribeConfigurationSettingsCommand({ + ApplicationName: appName, + EnvironmentName: envName, + }), + ); + + const settings = + configRes.ConfigurationSettings?.[0]?.OptionSettings ?? []; + + const managedActionsOpt = settings.find( + (s) => + s.Namespace === 'aws:elasticbeanstalk:managedactions' && + s.OptionName === 'ManagedActionsEnabled', + ); + + if (managedActionsOpt?.Value !== 'true') { + findings.push( + this.makeFinding( + envArn, + 'Managed platform updates not enabled', + `Environment "${envName}" does not have managed platform updates enabled — updates must be applied manually`, + 'medium', + { + environmentName: envName, + managedActionsEnabled: managedActionsOpt?.Value ?? 'not set', + service: 'Elastic Beanstalk', + }, + false, + `Use elasticbeanstalk:UpdateEnvironmentCommand with EnvironmentName set to '${envName}' and OptionSettings containing Namespace 'aws:elasticbeanstalk:managedactions', OptionName 'ManagedActionsEnabled', Value 'true'. Also set 'aws:elasticbeanstalk:managedactions:platformupdate' with UpdateLevel 'minor' and PreferredStartTime. Rollback: use elasticbeanstalk:UpdateEnvironmentCommand with ManagedActionsEnabled set to 'false'.`, + ), + ); + } + + const healthReportingOpt = settings.find( + (s) => + s.Namespace === 'aws:elasticbeanstalk:healthreporting:system' && + s.OptionName === 'SystemType', + ); + + if (healthReportingOpt?.Value !== 'enhanced') { + findings.push( + this.makeFinding( + envArn, + 'Enhanced health reporting not enabled', + `Environment "${envName}" does not use enhanced health reporting — basic reporting provides limited visibility`, + 'medium', + { + environmentName: envName, + systemType: healthReportingOpt?.Value ?? 'not set', + service: 'Elastic Beanstalk', + }, + false, + `Use elasticbeanstalk:UpdateEnvironmentCommand with EnvironmentName set to '${envName}' and OptionSettings containing Namespace 'aws:elasticbeanstalk:healthreporting:system', OptionName 'SystemType', Value 'enhanced'. Rollback: use elasticbeanstalk:UpdateEnvironmentCommand with SystemType set to 'basic'.`, + ), + ); + } + + const isHealthy = + (!env.HealthStatus || env.HealthStatus === 'Ok') && + managedActionsOpt?.Value === 'true' && + healthReportingOpt?.Value === 'enhanced'; + + if (isHealthy) { + findings.push( + this.makeFinding( + envArn, + 'Environment is well configured', + `Environment "${envName}" has managed updates and enhanced health reporting enabled`, + 'info', + { environmentName: envName, service: 'Elastic Beanstalk' }, + true, + ), + ); + } + } + } catch (error: unknown) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding( + resourceId: string, + title: string, + description: string, + severity: SecurityFinding['severity'], + evidence?: Record, + passed?: boolean, + remediation?: string, + ): SecurityFinding { + const id = `elastic-beanstalk-${resourceId}-${title.toLowerCase().replace(/\s+/g, '-')}`; + return { + id, + title, + description, + severity, + resourceType: 'AwsElasticBeanstalkEnvironment', + resourceId, + remediation, + evidence: { ...evidence, findingKey: id }, + createdAt: new Date().toISOString(), + passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/elasticache.adapter.ts b/apps/api/src/cloud-security/providers/aws/elasticache.adapter.ts new file mode 100644 index 0000000000..0e765367c8 --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/elasticache.adapter.ts @@ -0,0 +1,205 @@ +import { + ElastiCacheClient, + DescribeReplicationGroupsCommand, + DescribeCacheClustersCommand, +} from '@aws-sdk/client-elasticache'; +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class ElastiCacheAdapter implements AwsServiceAdapter { + readonly serviceId = 'elasticache'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + accountId, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new ElastiCacheClient({ credentials, region }); + + try { + const rgFindings = await this.checkReplicationGroups( + client, + region, + accountId, + ); + + if (rgFindings.length > 0) { + return rgFindings; + } + + // Fall back to individual cache clusters if no replication groups + return await this.checkCacheClusters(client, region, accountId); + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + } + + private async checkReplicationGroups( + client: ElastiCacheClient, + region: string, + accountId?: string, + ): Promise { + const findings: SecurityFinding[] = []; + let marker: string | undefined; + + do { + const resp = await client.send( + new DescribeReplicationGroupsCommand({ Marker: marker }), + ); + + for (const group of resp.ReplicationGroups ?? []) { + if (!group.ReplicationGroupId) continue; + + const groupId = group.ReplicationGroupId; + const resourceId = group.ARN ?? groupId; + + if (group.TransitEncryptionEnabled !== true) { + findings.push( + this.makeFinding({ + id: `elasticache-no-transit-encryption-${groupId}`, + title: `ElastiCache replication group "${groupId}" has encryption in transit disabled (${region})`, + description: `Replication group ${groupId} does not have encryption in transit enabled. Data transmitted between nodes and clients is not encrypted.`, + severity: 'high', + resourceId, + remediation: `[MANUAL] Cannot be auto-fixed. ElastiCache in-transit encryption requires recreating the replication group. To fix: create a new replication group with TransitEncryptionEnabled set to true, migrate data, then delete the old group.`, + passed: false, + accountId, + region, + }), + ); + } + + if (group.AtRestEncryptionEnabled !== true) { + findings.push( + this.makeFinding({ + id: `elasticache-no-rest-encryption-${groupId}`, + title: `ElastiCache replication group "${groupId}" has encryption at rest disabled (${region})`, + description: `Replication group ${groupId} does not have encryption at rest enabled. Cached data stored on disk is not encrypted.`, + severity: 'high', + resourceId, + remediation: `[MANUAL] Cannot be auto-fixed. ElastiCache at-rest encryption requires recreating the replication group. To fix: create a new replication group with AtRestEncryptionEnabled set to true, migrate data, then delete the old group.`, + passed: false, + accountId, + region, + }), + ); + } + + if (group.AuthTokenEnabled !== true) { + findings.push( + this.makeFinding({ + id: `elasticache-no-auth-token-${groupId}`, + title: `ElastiCache replication group "${groupId}" has AUTH token not enabled (${region})`, + description: `Replication group ${groupId} does not require an AUTH token for client connections. Any client with network access can connect without authentication.`, + severity: 'medium', + resourceId, + remediation: `Use elasticache:ModifyReplicationGroupCommand with ReplicationGroupId and AuthToken to set a new AUTH token, and AuthTokenUpdateStrategy set to SET. Requires TransitEncryptionEnabled to be true. Rollback by calling elasticache:ModifyReplicationGroupCommand with AuthTokenUpdateStrategy set to DELETE.`, + passed: false, + accountId, + region, + }), + ); + } + } + + marker = resp.Marker; + } while (marker); + + return findings; + } + + private async checkCacheClusters( + client: ElastiCacheClient, + region: string, + accountId?: string, + ): Promise { + const findings: SecurityFinding[] = []; + let marker: string | undefined; + + do { + const resp = await client.send( + new DescribeCacheClustersCommand({ Marker: marker }), + ); + + for (const cluster of resp.CacheClusters ?? []) { + if (!cluster.CacheClusterId) continue; + + const clusterId = cluster.CacheClusterId; + const resourceId = cluster.ARN ?? clusterId; + + if (cluster.TransitEncryptionEnabled !== true) { + findings.push( + this.makeFinding({ + id: `elasticache-cluster-no-transit-encryption-${clusterId}`, + title: `ElastiCache cluster "${clusterId}" has encryption in transit disabled (${region})`, + description: `Cache cluster ${clusterId} does not have encryption in transit enabled. Data transmitted between the cluster and clients is not encrypted.`, + severity: 'high', + resourceId, + remediation: `[MANUAL] Cannot be auto-fixed. ElastiCache in-transit encryption requires recreating the replication group. To fix: create a new replication group with TransitEncryptionEnabled set to true, migrate data, then delete the old group.`, + passed: false, + accountId, + region, + }), + ); + } + + if (cluster.AtRestEncryptionEnabled !== true) { + findings.push( + this.makeFinding({ + id: `elasticache-cluster-no-rest-encryption-${clusterId}`, + title: `ElastiCache cluster "${clusterId}" has encryption at rest disabled (${region})`, + description: `Cache cluster ${clusterId} does not have encryption at rest enabled. Cached data stored on disk is not encrypted.`, + severity: 'high', + resourceId, + remediation: `[MANUAL] Cannot be auto-fixed. ElastiCache at-rest encryption requires recreating the replication group. To fix: create a new replication group with AtRestEncryptionEnabled set to true, migrate data, then delete the old group.`, + passed: false, + accountId, + region, + }), + ); + } + } + + marker = resp.Marker; + } while (marker); + + return findings; + } + + private makeFinding(opts: { + id: string; + title: string; + description: string; + severity: SecurityFinding['severity']; + resourceId?: string; + remediation?: string; + passed: boolean; + accountId?: string; + region?: string; + }): SecurityFinding { + return { + id: opts.id, + title: opts.title, + description: opts.description, + severity: opts.severity, + resourceType: 'AwsElastiCacheCluster', + resourceId: opts.resourceId || 'unknown', + remediation: opts.remediation, + evidence: { + awsAccountId: opts.accountId, + region: opts.region, + service: 'ElastiCache', + findingKey: opts.id, + }, + createdAt: new Date().toISOString(), + passed: opts.passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/elb.adapter.ts b/apps/api/src/cloud-security/providers/aws/elb.adapter.ts new file mode 100644 index 0000000000..70638571c2 --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/elb.adapter.ts @@ -0,0 +1,156 @@ +import { + ElasticLoadBalancingV2Client, + DescribeLoadBalancersCommand, + DescribeListenersCommand, + DescribeLoadBalancerAttributesCommand, +} from '@aws-sdk/client-elastic-load-balancing-v2'; +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class ElbAdapter implements AwsServiceAdapter { + readonly serviceId = 'elb'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new ElasticLoadBalancingV2Client({ + credentials, + region, + }); + + const findings: SecurityFinding[] = []; + + try { + let marker: string | undefined; + + do { + const resp = await client.send( + new DescribeLoadBalancersCommand({ Marker: marker }), + ); + + const loadBalancers = resp.LoadBalancers ?? []; + + for (const lb of loadBalancers) { + const arn = lb.LoadBalancerArn ?? 'unknown'; + + // Check listeners for HTTPS/TLS + try { + const listenersResp = await client.send( + new DescribeListenersCommand({ LoadBalancerArn: arn }), + ); + const listeners = listenersResp.Listeners ?? []; + const hasSecureListener = listeners.some( + (l) => l.Protocol === 'HTTPS' || l.Protocol === 'TLS', + ); + + if (!hasSecureListener && listeners.length > 0) { + findings.push( + this.makeFinding({ + resourceId: arn, + title: 'No HTTPS listeners configured', + description: `Load balancer ${lb.LoadBalancerName} has no HTTPS or TLS listeners. Traffic is transmitted unencrypted.`, + severity: 'high', + remediation: `Use elbv2:CreateListenerCommand with LoadBalancerArn set to '${arn}', Protocol set to 'HTTPS', Port set to 443, and Certificates containing the ACM certificate ARN. Set DefaultActions to forward to the target group. Rollback: use elbv2:DeleteListenerCommand with the ListenerArn returned from the create call.`, + evidence: { + protocols: listeners.map((l) => l.Protocol), + }, + }), + ); + } + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + } + + // Check attributes for access logging and deletion protection + try { + const attrsResp = await client.send( + new DescribeLoadBalancerAttributesCommand({ + LoadBalancerArn: arn, + }), + ); + const attrs = attrsResp.Attributes ?? []; + + const accessLogsAttr = attrs.find( + (a) => a.Key === 'access_logs.s3.enabled', + ); + if (accessLogsAttr?.Value !== 'true') { + findings.push( + this.makeFinding({ + resourceId: arn, + title: 'Access logging disabled', + description: `Load balancer ${lb.LoadBalancerName} does not have access logging enabled.`, + severity: 'medium', + remediation: `Use elbv2:ModifyLoadBalancerAttributesCommand with LoadBalancerArn set to '${arn}' and Attributes containing Key: 'access_logs.s3.enabled', Value: 'true' and Key: 'access_logs.s3.bucket', Value: '' and Key: 'access_logs.s3.prefix', Value: ''. Rollback: use elbv2:ModifyLoadBalancerAttributesCommand with 'access_logs.s3.enabled' set to 'false'.`, + evidence: { + accessLogsEnabled: accessLogsAttr?.Value ?? 'not set', + }, + }), + ); + } + + const deletionProtectionAttr = attrs.find( + (a) => a.Key === 'deletion_protection.enabled', + ); + if (deletionProtectionAttr?.Value !== 'true') { + findings.push( + this.makeFinding({ + resourceId: arn, + title: 'Deletion protection disabled', + description: `Load balancer ${lb.LoadBalancerName} does not have deletion protection enabled.`, + severity: 'medium', + remediation: `Use elbv2:ModifyLoadBalancerAttributesCommand with LoadBalancerArn set to '${arn}' and Attributes containing Key: 'deletion_protection.enabled', Value: 'true'. Rollback: use elbv2:ModifyLoadBalancerAttributesCommand with Key: 'deletion_protection.enabled', Value: 'false'.`, + evidence: { + deletionProtectionEnabled: + deletionProtectionAttr?.Value ?? 'not set', + }, + }), + ); + } + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + } + } + + marker = resp.NextMarker; + } while (marker); + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding(params: { + resourceId: string; + title: string; + description: string; + severity: SecurityFinding['severity']; + remediation?: string; + evidence?: Record; + passed?: boolean; + }): SecurityFinding { + const id = `elb-${params.resourceId}-${params.title.toLowerCase().replace(/\s+/g, '-')}`; + return { + id, + title: params.title, + description: params.description, + severity: params.severity, + resourceType: 'AwsElbLoadBalancer', + resourceId: params.resourceId, + remediation: params.remediation, + evidence: { ...(params.evidence ?? {}), findingKey: id }, + createdAt: new Date().toISOString(), + passed: params.passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/emr.adapter.ts b/apps/api/src/cloud-security/providers/aws/emr.adapter.ts new file mode 100644 index 0000000000..b8ac527eeb --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/emr.adapter.ts @@ -0,0 +1,171 @@ +import { + EMRClient, + ListClustersCommand, + DescribeClusterCommand, +} from '@aws-sdk/client-emr'; + +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class EmrAdapter implements AwsServiceAdapter { + readonly serviceId = 'emr'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new EMRClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + try { + let marker: string | undefined; + + do { + const listRes = await client.send( + new ListClustersCommand({ + ClusterStates: ['STARTING', 'BOOTSTRAPPING', 'RUNNING', 'WAITING'], + Marker: marker, + }), + ); + + for (const clusterSummary of listRes.Clusters ?? []) { + const clusterId = clusterSummary.Id; + if (!clusterId) continue; + + const descRes = await client.send( + new DescribeClusterCommand({ ClusterId: clusterId }), + ); + + const cluster = descRes.Cluster; + if (!cluster) continue; + + const clusterName = cluster.Name ?? clusterId; + const resourceId = `arn:aws:elasticmapreduce:${region}:cluster/${clusterId}`; + + // Check security configuration + if (!cluster.SecurityConfiguration) { + findings.push( + this.makeFinding( + resourceId, + 'No security configuration applied', + `EMR cluster "${clusterName}" (${clusterId}) does not have a security configuration applied`, + 'medium', + { clusterId, clusterName, securityConfiguration: null }, + false, + `[MANUAL] Cannot be auto-fixed on a running cluster. Security configurations can only be set at cluster launch time. Create a security configuration using emr:CreateSecurityConfigurationCommand with Name and SecurityConfiguration (JSON string with EncryptionConfiguration, AuthenticationConfiguration). Then terminate the cluster using emr:TerminateJobFlowsCommand and relaunch with emr:RunJobFlowCommand specifying SecurityConfiguration.`, + ), + ); + } else { + findings.push( + this.makeFinding( + resourceId, + 'Security configuration applied', + `EMR cluster "${clusterName}" (${clusterId}) has a security configuration applied`, + 'info', + { + clusterId, + clusterName, + securityConfiguration: cluster.SecurityConfiguration, + }, + true, + ), + ); + } + + // Check logging configuration + if (!cluster.LogUri) { + findings.push( + this.makeFinding( + resourceId, + 'Logging not configured', + `EMR cluster "${clusterName}" (${clusterId}) does not have logging configured`, + 'medium', + { clusterId, clusterName, logUri: null }, + false, + `[MANUAL] Cannot be auto-fixed on a running cluster. Logging must be configured at cluster launch time. Use emr:RunJobFlowCommand with LogUri set to an S3 path (e.g., 's3://bucket-name/emr-logs/') when creating a new cluster. The current cluster must be terminated and relaunched with logging enabled.`, + ), + ); + } else { + findings.push( + this.makeFinding( + resourceId, + 'Logging configured', + `EMR cluster "${clusterName}" (${clusterId}) has logging configured`, + 'info', + { clusterId, clusterName, logUri: cluster.LogUri }, + true, + ), + ); + } + + // Check termination protection + if (cluster.TerminationProtected !== true) { + findings.push( + this.makeFinding( + resourceId, + 'Termination protection disabled', + `EMR cluster "${clusterName}" (${clusterId}) does not have termination protection enabled`, + 'low', + { + clusterId, + clusterName, + terminationProtected: cluster.TerminationProtected, + }, + false, + `Use emr:SetTerminationProtectionCommand with JobFlowIds set to ['${clusterId}'] and TerminationProtected set to true. Rollback: use emr:SetTerminationProtectionCommand with TerminationProtected set to false.`, + ), + ); + } else { + findings.push( + this.makeFinding( + resourceId, + 'Termination protection enabled', + `EMR cluster "${clusterName}" (${clusterId}) has termination protection enabled`, + 'info', + { clusterId, clusterName, terminationProtected: true }, + true, + ), + ); + } + } + + marker = listRes.Marker; + } while (marker); + } catch (error: unknown) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding( + resourceId: string, + title: string, + description: string, + severity: SecurityFinding['severity'], + evidence?: Record, + passed?: boolean, + remediation?: string, + ): SecurityFinding { + const id = `emr-${resourceId}-${title.toLowerCase().replace(/\s+/g, '-')}`; + return { + id, + title, + description, + severity, + resourceType: 'AwsEmrCluster', + resourceId, + remediation, + evidence: { ...evidence, service: 'EMR', findingKey: id }, + createdAt: new Date().toISOString(), + passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/eventbridge.adapter.ts b/apps/api/src/cloud-security/providers/aws/eventbridge.adapter.ts new file mode 100644 index 0000000000..e89df43f6b --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/eventbridge.adapter.ts @@ -0,0 +1,129 @@ +import { + DescribeEventBusCommand, + EventBridgeClient, + ListEventBusesCommand, +} from '@aws-sdk/client-eventbridge'; + +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class EventBridgeAdapter implements AwsServiceAdapter { + readonly serviceId = 'eventbridge'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new EventBridgeClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + try { + const listRes = await client.send(new ListEventBusesCommand({})); + const buses = listRes.EventBuses ?? []; + + const customBuses = buses.filter((b) => b.Name !== 'default'); + + if (customBuses.length === 0) { + findings.push( + this.makeFinding( + `arn:aws:events:${region}:default-only`, + 'Only default event bus exists', + 'No custom event buses found — only the default bus is present', + 'info', + { region }, + true, + ), + ); + } + + for (const bus of buses) { + const busName = bus.Name ?? 'unknown'; + const busArn = bus.Arn ?? `arn:aws:events:${region}:${busName}`; + + const descRes = await client.send( + new DescribeEventBusCommand({ Name: busName }), + ); + + const policyStr = descRes.Policy; + if (!policyStr) continue; + + let policy: Record; + try { + policy = JSON.parse(policyStr) as Record; + } catch { + continue; + } + + const statements = Array.isArray(policy.Statement) + ? (policy.Statement as Record[]) + : []; + + for (const stmt of statements) { + if (stmt.Effect !== 'Allow') continue; + + const principal = stmt.Principal; + const hasCondition = + stmt.Condition != null && + typeof stmt.Condition === 'object' && + Object.keys(stmt.Condition).length > 0; + + const isPublic = + principal === '*' || + (typeof principal === 'object' && + principal !== null && + (principal as Record).AWS === '*'); + + if (isPublic && !hasCondition) { + findings.push( + this.makeFinding( + busArn, + 'Event bus has public access policy', + `Event bus "${busName}" has a resource policy granting public access without restrictive conditions`, + 'high', + { busName, service: 'EventBridge' }, + false, + `Use events:PutPermissionCommand with EventBusName set to '${busName}' and Policy set to a JSON policy string with restricted Principal (specific AWS account IDs instead of '*') and Condition keys (e.g., aws:PrincipalOrgID). Alternatively, use events:RemovePermissionCommand with EventBusName and StatementId to remove the public statement. Rollback: use events:PutPermissionCommand to restore the original policy.`, + ), + ); + break; + } + } + } + } catch (error: unknown) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding( + resourceId: string, + title: string, + description: string, + severity: SecurityFinding['severity'], + evidence?: Record, + passed?: boolean, + remediation?: string, + ): SecurityFinding { + const id = `eventbridge-${resourceId}-${title.toLowerCase().replace(/\s+/g, '-')}`; + return { + id, + title, + description, + severity, + resourceType: 'AwsEventBridgeBus', + resourceId, + remediation, + evidence: { ...evidence, findingKey: id }, + createdAt: new Date().toISOString(), + passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/glue.adapter.ts b/apps/api/src/cloud-security/providers/aws/glue.adapter.ts new file mode 100644 index 0000000000..b7396b323b --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/glue.adapter.ts @@ -0,0 +1,195 @@ +import { + GlueClient, + GetDataCatalogEncryptionSettingsCommand, + GetDatabasesCommand, + GetJobsCommand, +} from '@aws-sdk/client-glue'; + +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class GlueAdapter implements AwsServiceAdapter { + readonly serviceId = 'glue'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new GlueClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + // Prerequisite: check if there are any Glue databases or jobs + try { + const dbResp = await client.send( + new GetDatabasesCommand({ MaxResults: 1 }), + ); + const hasDBs = (dbResp.DatabaseList ?? []).length > 0; + + if (!hasDBs) { + const jobsResp = await client.send( + new GetJobsCommand({ MaxResults: 1 }), + ); + const hasJobs = (jobsResp.Jobs ?? []).length > 0; + + if (!hasJobs) return []; + } + } catch { + // If prerequisite check fails (permissions), fall through to existing behavior + } + + try { + // Check Data Catalog encryption settings + const catalogRes = await client.send( + new GetDataCatalogEncryptionSettingsCommand({}), + ); + + const encSettings = + catalogRes.DataCatalogEncryptionSettings?.EncryptionAtRest; + const catalogId = `arn:aws:glue:${region}:catalog`; + + if (encSettings?.CatalogEncryptionMode === 'DISABLED') { + findings.push( + this.makeFinding( + catalogId, + 'AwsGlueCatalog', + 'Data catalog not encrypted', + `Glue Data Catalog in ${region} does not have encryption at rest enabled`, + 'medium', + { catalogEncryptionMode: encSettings.CatalogEncryptionMode }, + false, + `Use glue:PutDataCatalogEncryptionSettingsCommand with DataCatalogEncryptionSettings.EncryptionAtRest.CatalogEncryptionMode set to 'SSE-KMS' and SseAwsKmsKeyId set to a KMS key ARN. Rollback: use glue:PutDataCatalogEncryptionSettingsCommand with CatalogEncryptionMode set to 'DISABLED'. Note: disabling encryption does not decrypt existing encrypted objects.`, + ), + ); + } else { + findings.push( + this.makeFinding( + catalogId, + 'AwsGlueCatalog', + 'Data catalog encryption enabled', + `Glue Data Catalog in ${region} has encryption at rest enabled (${encSettings?.CatalogEncryptionMode})`, + 'info', + { catalogEncryptionMode: encSettings?.CatalogEncryptionMode }, + true, + ), + ); + } + + const connPwdEnc = + catalogRes.DataCatalogEncryptionSettings?.ConnectionPasswordEncryption; + + if (connPwdEnc?.ReturnConnectionPasswordEncrypted !== true) { + findings.push( + this.makeFinding( + catalogId, + 'AwsGlueCatalog', + 'Connection passwords not encrypted', + `Glue Data Catalog in ${region} does not encrypt connection passwords`, + 'medium', + { + returnConnectionPasswordEncrypted: + connPwdEnc?.ReturnConnectionPasswordEncrypted, + }, + false, + `Use glue:PutDataCatalogEncryptionSettingsCommand with DataCatalogEncryptionSettings.ConnectionPasswordEncryption.ReturnConnectionPasswordEncrypted set to true and AwsKmsKeyId set to a KMS key ARN. Rollback: use glue:PutDataCatalogEncryptionSettingsCommand with ReturnConnectionPasswordEncrypted set to false.`, + ), + ); + } else { + findings.push( + this.makeFinding( + catalogId, + 'AwsGlueCatalog', + 'Connection passwords encrypted', + `Glue Data Catalog in ${region} encrypts connection passwords`, + 'info', + { returnConnectionPasswordEncrypted: true }, + true, + ), + ); + } + + // Check Glue Jobs for security configuration + let nextToken: string | undefined; + + do { + const jobsRes = await client.send( + new GetJobsCommand({ NextToken: nextToken }), + ); + + for (const job of jobsRes.Jobs ?? []) { + const jobName = job.Name ?? 'unknown'; + const resourceId = `arn:aws:glue:${region}:job/${jobName}`; + const hasEncryptionArg = + job.DefaultArguments?.['--encryption-type'] !== undefined; + + if (!job.SecurityConfiguration && !hasEncryptionArg) { + findings.push( + this.makeFinding( + resourceId, + 'AwsGlueJob', + 'Glue job has no security configuration', + `Glue job "${jobName}" does not have a security configuration or encryption type set`, + 'low', + { jobName, securityConfiguration: null }, + false, + `First create a security configuration using glue:CreateSecurityConfigurationCommand with Name and EncryptionConfiguration (S3Encryption, CloudWatchEncryption, JobBookmarksEncryption). Then use glue:UpdateJobCommand with JobName set to '${jobName}' and JobUpdate.SecurityConfiguration set to the security configuration name. Rollback: use glue:UpdateJobCommand with SecurityConfiguration set to empty string.`, + ), + ); + } else { + findings.push( + this.makeFinding( + resourceId, + 'AwsGlueJob', + 'Glue job has security configuration', + `Glue job "${jobName}" has a security configuration applied`, + 'info', + { + jobName, + securityConfiguration: job.SecurityConfiguration ?? null, + }, + true, + ), + ); + } + } + + nextToken = jobsRes.NextToken; + } while (nextToken); + } catch (error: unknown) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding( + resourceId: string, + resourceType: string, + title: string, + description: string, + severity: SecurityFinding['severity'], + evidence?: Record, + passed?: boolean, + remediation?: string, + ): SecurityFinding { + const id = `glue-${resourceId}-${title.toLowerCase().replace(/\s+/g, '-')}`; + return { + id, + title, + description, + severity, + resourceType, + resourceId, + remediation, + evidence: { ...evidence, service: 'Glue', findingKey: id }, + createdAt: new Date().toISOString(), + passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/guardduty.adapter.ts b/apps/api/src/cloud-security/providers/aws/guardduty.adapter.ts new file mode 100644 index 0000000000..74de9b63d1 --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/guardduty.adapter.ts @@ -0,0 +1,117 @@ +import { + GuardDutyClient, + ListDetectorsCommand, + GetDetectorCommand, +} from '@aws-sdk/client-guardduty'; +import { EC2Client, DescribeInstancesCommand } from '@aws-sdk/client-ec2'; +import { LambdaClient, ListFunctionsCommand } from '@aws-sdk/client-lambda'; +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class GuardDutyAdapter implements AwsServiceAdapter { + readonly serviceId = 'guardduty'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new GuardDutyClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + // Prerequisite: check if there are any resources in this region + try { + const ec2Client = new EC2Client({ credentials, region }); + const ec2Resp = await ec2Client.send( + new DescribeInstancesCommand({ MaxResults: 5 }), + ); + const hasEc2 = (ec2Resp.Reservations ?? []).some( + (r) => (r.Instances ?? []).length > 0, + ); + + if (!hasEc2) { + const lambdaClient = new LambdaClient({ credentials, region }); + const lambdaResp = await lambdaClient.send( + new ListFunctionsCommand({ MaxItems: 1 }), + ); + const hasLambda = (lambdaResp.Functions ?? []).length > 0; + + if (!hasLambda) return []; + } + } catch { + // If prerequisite check fails (permissions), fall through to existing behavior + } + + try { + const { DetectorIds } = await client.send(new ListDetectorsCommand({})); + + if (!DetectorIds || DetectorIds.length === 0) { + findings.push( + this.makeFinding({ + id: `guardduty-not-enabled-${region}`, + title: 'GuardDuty not enabled', + description: `GuardDuty is not enabled in ${region}. No detectors found.`, + severity: 'high', + resourceId: `arn:aws:guardduty:${region}`, + remediation: + 'Step 1: Ensure a service-linked role exists by calling iam:CreateServiceLinkedRoleCommand with AWSServiceName set to "guardduty.amazonaws.com" (skip if the role already exists). Step 2: Use guardduty:CreateDetectorCommand with Enable set to true. Rollback by calling guardduty:DeleteDetectorCommand with the DetectorId returned from creation.', + }), + ); + return findings; + } + + for (const detectorId of DetectorIds) { + const detector = await client.send( + new GetDetectorCommand({ DetectorId: detectorId }), + ); + + if (detector.Status !== 'ENABLED') { + findings.push( + this.makeFinding({ + id: `guardduty-disabled-${detectorId}`, + title: 'GuardDuty detector is disabled', + description: `GuardDuty detector ${detectorId} in ${region} is not enabled.`, + severity: 'high', + resourceId: detectorId, + remediation: `Use guardduty:UpdateDetectorCommand with DetectorId set to "${detectorId}" and Enable set to true. Rollback by calling guardduty:UpdateDetectorCommand with DetectorId set to "${detectorId}" and Enable set to false.`, + }), + ); + } else { + findings.push( + this.makeFinding({ + id: `guardduty-enabled-${detectorId}`, + title: 'GuardDuty detector is enabled', + description: `GuardDuty detector ${detectorId} in ${region} is enabled.`, + severity: 'info', + resourceId: detectorId, + passed: true, + }), + ); + } + } + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding( + params: Omit & { + remediation?: string; + }, + ): SecurityFinding { + return { + ...params, + resourceType: 'AwsGuardDutyDetector', + evidence: { ...params.evidence, findingKey: params.id }, + createdAt: new Date().toISOString(), + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/iam.adapter.ts b/apps/api/src/cloud-security/providers/aws/iam.adapter.ts new file mode 100644 index 0000000000..e11423920b --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/iam.adapter.ts @@ -0,0 +1,307 @@ +import { + IAMClient, + GetAccountPasswordPolicyCommand, + ListUsersCommand, + ListMFADevicesCommand, + ListAccessKeysCommand, + GetAccountSummaryCommand, +} from '@aws-sdk/client-iam'; +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +const STALE_KEY_DAYS = 90; + +export class IamAdapter implements AwsServiceAdapter { + readonly serviceId = 'iam-analyzer'; + readonly isGlobal = true; + + async scan(params: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const { credentials, region, accountId } = params; + const iam = new IAMClient({ region, credentials }); + + const findings: SecurityFinding[] = []; + + const results = await Promise.allSettled([ + this.checkPasswordPolicy(iam, accountId), + this.checkUsersWithoutMfa(iam, accountId), + this.checkStaleAccessKeys(iam, accountId), + this.checkRootAccessKeys(iam, accountId), + ]); + + for (const result of results) { + if (result.status === 'fulfilled') { + findings.push(...result.value); + } + } + + return findings; + } + + private async checkPasswordPolicy( + iam: IAMClient, + accountId?: string, + ): Promise { + const findings: SecurityFinding[] = []; + + try { + const resp = await iam.send(new GetAccountPasswordPolicyCommand({})); + const policy = resp.PasswordPolicy; + + if (!policy) { + findings.push( + this.makeFinding({ + id: 'iam-no-password-policy', + title: 'No IAM password policy configured', + description: + 'The AWS account does not have a custom password policy. Default password requirements may be insufficient.', + severity: 'high', + remediation: + 'Use iam:UpdateAccountPasswordPolicyCommand with MinimumPasswordLength set to 14, RequireSymbols, RequireNumbers, RequireUppercaseCharacters, RequireLowercaseCharacters all set to true, MaxPasswordAge set to 90, PasswordReusePrevention set to 24. Rollback by restoring previous password policy values.', + passed: false, + accountId, + }), + ); + return findings; + } + + if (!policy.MinimumPasswordLength || policy.MinimumPasswordLength < 14) { + findings.push( + this.makeFinding({ + id: 'iam-weak-password-length', + title: 'IAM password policy minimum length is below 14 characters', + description: `Password policy requires only ${policy.MinimumPasswordLength || 'default'} characters. CIS recommends at least 14.`, + severity: 'medium', + remediation: + 'Use iam:UpdateAccountPasswordPolicyCommand with MinimumPasswordLength set to 14, RequireSymbols, RequireNumbers, RequireUppercaseCharacters, RequireLowercaseCharacters all set to true, MaxPasswordAge set to 90, PasswordReusePrevention set to 24. Rollback by restoring previous password policy values.', + passed: false, + accountId, + }), + ); + } else { + findings.push( + this.makeFinding({ + id: 'iam-password-length-ok', + title: 'IAM password policy minimum length meets requirements', + description: `Password policy requires ${policy.MinimumPasswordLength} characters (minimum 14).`, + severity: 'info', + passed: true, + accountId, + }), + ); + } + + if ( + !policy.RequireUppercaseCharacters || + !policy.RequireLowercaseCharacters || + !policy.RequireNumbers || + !policy.RequireSymbols + ) { + findings.push( + this.makeFinding({ + id: 'iam-weak-password-complexity', + title: 'IAM password policy does not require all character types', + description: + 'Password policy should require uppercase, lowercase, numbers, and symbols.', + severity: 'medium', + remediation: + 'Use iam:UpdateAccountPasswordPolicyCommand with MinimumPasswordLength set to 14, RequireSymbols, RequireNumbers, RequireUppercaseCharacters, RequireLowercaseCharacters all set to true, MaxPasswordAge set to 90, PasswordReusePrevention set to 24. Rollback by restoring previous password policy values.', + passed: false, + accountId, + }), + ); + } + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('NoSuchEntity')) { + findings.push( + this.makeFinding({ + id: 'iam-no-password-policy', + title: 'No IAM password policy configured', + description: + 'The AWS account does not have a custom password policy.', + severity: 'high', + remediation: + 'Use iam:UpdateAccountPasswordPolicyCommand with MinimumPasswordLength set to 14, RequireSymbols, RequireNumbers, RequireUppercaseCharacters, RequireLowercaseCharacters all set to true, MaxPasswordAge set to 90, PasswordReusePrevention set to 24. Rollback by restoring previous password policy values.', + passed: false, + accountId, + }), + ); + } else { + throw error; + } + } + + return findings; + } + + private async checkUsersWithoutMfa( + iam: IAMClient, + accountId?: string, + ): Promise { + const findings: SecurityFinding[] = []; + const users = await this.listAllUsers(iam); + + for (const user of users) { + if (!user.UserName) continue; + + const mfaResp = await iam.send( + new ListMFADevicesCommand({ UserName: user.UserName }), + ); + + const hasMfa = mfaResp.MFADevices && mfaResp.MFADevices.length > 0; + + if (!hasMfa) { + findings.push( + this.makeFinding({ + id: `iam-no-mfa-${user.UserName}`, + title: `IAM user "${user.UserName}" does not have MFA enabled`, + description: `User ${user.UserName} has no MFA device configured, increasing account compromise risk.`, + severity: 'high', + resourceType: 'AwsIamUser', + resourceId: user.Arn || user.UserName, + remediation: `[MANUAL] Cannot be auto-fixed. MFA device registration requires physical access to the authentication device. Enable MFA via the IAM Console for each user.`, + passed: false, + accountId, + }), + ); + } + } + + return findings; + } + + private async checkStaleAccessKeys( + iam: IAMClient, + accountId?: string, + ): Promise { + const findings: SecurityFinding[] = []; + const users = await this.listAllUsers(iam); + const now = Date.now(); + + for (const user of users) { + if (!user.UserName) continue; + + const keysResp = await iam.send( + new ListAccessKeysCommand({ UserName: user.UserName }), + ); + + for (const key of keysResp.AccessKeyMetadata || []) { + if (key.Status !== 'Active' || !key.CreateDate) continue; + + const ageDays = Math.floor( + (now - key.CreateDate.getTime()) / (1000 * 60 * 60 * 24), + ); + + if (ageDays > STALE_KEY_DAYS) { + findings.push( + this.makeFinding({ + id: `iam-stale-key-${user.UserName}-${key.AccessKeyId}`, + title: `IAM access key for "${user.UserName}" is ${ageDays} days old`, + description: `Access key ${key.AccessKeyId} for user ${user.UserName} was created ${ageDays} days ago. Keys older than ${STALE_KEY_DAYS} days should be rotated.`, + severity: ageDays > 180 ? 'high' : 'medium', + resourceType: 'AwsIamAccessKey', + resourceId: key.AccessKeyId || 'unknown', + remediation: `Use iam:UpdateAccessKeyCommand with UserName, AccessKeyId, and Status set to 'Inactive' to deactivate the stale key. Rollback by setting Status to 'Active'.`, + passed: false, + accountId, + }), + ); + } + } + } + + return findings; + } + + private async checkRootAccessKeys( + iam: IAMClient, + accountId?: string, + ): Promise { + const resp = await iam.send(new GetAccountSummaryCommand({})); + const summary = resp.SummaryMap; + + if (!summary) return []; + + const rootKeys = summary['AccountAccessKeysPresent']; + + if (rootKeys && rootKeys > 0) { + return [ + this.makeFinding({ + id: 'iam-root-access-keys', + title: 'Root account has active access keys', + description: + 'The root account has active access keys. Root access keys provide unrestricted access and should be removed.', + severity: 'critical', + resourceType: 'AwsAccount', + resourceId: accountId || 'root', + remediation: + '[MANUAL] Cannot be auto-fixed. Root access keys must be deleted manually through the AWS Console root account security credentials page.', + passed: false, + accountId, + }), + ]; + } + + return [ + this.makeFinding({ + id: 'iam-root-access-keys', + title: 'Root account has no active access keys', + description: 'The root account does not have active access keys.', + severity: 'info', + passed: true, + accountId, + }), + ]; + } + + private async listAllUsers(iam: IAMClient) { + const users: Array<{ + UserName?: string; + Arn?: string; + }> = []; + + let marker: string | undefined; + do { + const resp = await iam.send( + new ListUsersCommand({ Marker: marker, MaxItems: 100 }), + ); + if (resp.Users) users.push(...resp.Users); + marker = resp.IsTruncated ? resp.Marker : undefined; + } while (marker); + + return users; + } + + private makeFinding(opts: { + id: string; + title: string; + description: string; + severity: SecurityFinding['severity']; + resourceType?: string; + resourceId?: string; + remediation?: string; + passed: boolean; + accountId?: string; + }): SecurityFinding { + return { + id: opts.id, + title: opts.title, + description: opts.description, + severity: opts.severity, + resourceType: opts.resourceType || 'AwsIamPolicy', + resourceId: opts.resourceId || 'account-level', + remediation: opts.remediation, + evidence: { + awsAccountId: opts.accountId, + service: 'IAM', + findingKey: opts.id, + }, + createdAt: new Date().toISOString(), + passed: opts.passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/inspector.adapter.ts b/apps/api/src/cloud-security/providers/aws/inspector.adapter.ts new file mode 100644 index 0000000000..d898256e10 --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/inspector.adapter.ts @@ -0,0 +1,158 @@ +import { + Inspector2Client, + BatchGetAccountStatusCommand, +} from '@aws-sdk/client-inspector2'; +import { EC2Client, DescribeInstancesCommand } from '@aws-sdk/client-ec2'; +import { ECRClient, DescribeRepositoriesCommand } from '@aws-sdk/client-ecr'; +import { LambdaClient, ListFunctionsCommand } from '@aws-sdk/client-lambda'; +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class InspectorAdapter implements AwsServiceAdapter { + readonly serviceId = 'inspector'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new Inspector2Client({ credentials, region }); + const findings: SecurityFinding[] = []; + + // Prerequisite: check if there are scannable resources (EC2, ECR, Lambda) + let hasEc2 = false; + let hasEcr = false; + let hasLambda = false; + + try { + const ec2Client = new EC2Client({ credentials, region }); + const ec2Resp = await ec2Client.send( + new DescribeInstancesCommand({ MaxResults: 5 }), + ); + hasEc2 = (ec2Resp.Reservations ?? []).some( + (r) => (r.Instances ?? []).length > 0, + ); + + const ecrClient = new ECRClient({ credentials, region }); + const ecrResp = await ecrClient.send( + new DescribeRepositoriesCommand({ maxResults: 1 }), + ); + hasEcr = (ecrResp.repositories ?? []).length > 0; + + const lambdaClient = new LambdaClient({ credentials, region }); + const lambdaResp = await lambdaClient.send( + new ListFunctionsCommand({ MaxItems: 1 }), + ); + hasLambda = (lambdaResp.Functions ?? []).length > 0; + + if (!hasEc2 && !hasEcr && !hasLambda) return []; + } catch { + // If prerequisite check fails (permissions), fall through to existing behavior + hasEc2 = true; + hasEcr = true; + hasLambda = true; + } + + try { + const response = await client.send( + new BatchGetAccountStatusCommand({ accountIds: [] }), + ); + + const account = response.accounts?.[0]; + + if (!account?.resourceState) { + findings.push( + this.makeFinding({ + id: `inspector-no-status-${region}`, + title: 'Inspector not enabled', + description: `AWS Inspector could not retrieve account status in ${region}.`, + severity: 'medium', + resourceId: `arn:aws:inspector2:${region}`, + remediation: `Use inspector2:EnableCommand with resourceTypes set to ['EC2', 'ECR', 'LAMBDA', 'LAMBDA_CODE'] and accountIds set to the target account ID. Rollback: use inspector2:DisableCommand with the same resourceTypes and accountIds.`, + }), + ); + return findings; + } + + const resourceState = account.resourceState; + // Only check scan types for resources that actually exist + const scanTypes = [ + ...(hasEc2 ? [{ name: 'EC2', status: resourceState.ec2?.status }] : []), + ...(hasEcr ? [{ name: 'ECR', status: resourceState.ecr?.status }] : []), + ...(hasLambda + ? [ + { name: 'Lambda', status: resourceState.lambda?.status }, + { name: 'Lambda Code', status: resourceState.lambdaCode?.status }, + ] + : []), + ]; + + const disabled = scanTypes.filter((s) => s.status !== 'ENABLED'); + + if (disabled.length > 0) { + findings.push( + this.makeFinding({ + id: `inspector-partial-${region}`, + title: 'Inspector scan types not fully enabled', + description: `The following Inspector scan types are not enabled in ${region}: ${disabled.map((d) => d.name).join(', ')}.`, + severity: 'medium', + resourceId: `arn:aws:inspector2:${region}`, + remediation: `Use inspector2:EnableCommand with resourceTypes set to ['EC2', 'ECR', 'LAMBDA', 'LAMBDA_CODE'] and accountIds set to the target account ID. Rollback: use inspector2:DisableCommand with the same resourceTypes and accountIds.`, + evidence: Object.fromEntries( + scanTypes.map((s) => [s.name, s.status ?? 'UNKNOWN']), + ), + }), + ); + } else { + findings.push( + this.makeFinding({ + id: `inspector-enabled-${region}`, + title: 'Inspector fully enabled', + description: `All Inspector scan types are enabled in ${region}.`, + severity: 'info', + resourceId: `arn:aws:inspector2:${region}`, + passed: true, + }), + ); + } + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + + if (msg.includes('not enabled')) { + findings.push( + this.makeFinding({ + id: `inspector-not-enabled-${region}`, + title: 'Inspector not enabled', + description: `AWS Inspector is not enabled in ${region}.`, + severity: 'medium', + resourceId: `arn:aws:inspector2:${region}`, + remediation: `Use inspector2:EnableCommand with resourceTypes set to ['EC2', 'ECR', 'LAMBDA', 'LAMBDA_CODE'] and accountIds set to the target account ID. Rollback: use inspector2:DisableCommand with the same resourceTypes and accountIds.`, + }), + ); + return findings; + } + + throw error; + } + + return findings; + } + + private makeFinding( + params: Omit & { + remediation?: string; + }, + ): SecurityFinding { + return { + ...params, + evidence: { ...(params.evidence ?? {}), findingKey: params.id }, + resourceType: 'AwsInspectorCoverage', + createdAt: new Date().toISOString(), + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/kinesis.adapter.ts b/apps/api/src/cloud-security/providers/aws/kinesis.adapter.ts new file mode 100644 index 0000000000..24be3a3281 --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/kinesis.adapter.ts @@ -0,0 +1,132 @@ +import { + KinesisClient, + ListStreamsCommand, + DescribeStreamSummaryCommand, +} from '@aws-sdk/client-kinesis'; +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class KinesisAdapter implements AwsServiceAdapter { + readonly serviceId = 'kinesis'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new KinesisClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + try { + const streamNames: string[] = []; + let exclusiveStartStreamName: string | undefined; + + do { + const listRes = await client.send( + new ListStreamsCommand({ + ExclusiveStartStreamName: exclusiveStartStreamName, + }), + ); + + const names = listRes.StreamNames ?? []; + streamNames.push(...names); + + if (listRes.HasMoreStreams && names.length > 0) { + exclusiveStartStreamName = names[names.length - 1]; + } else { + break; + } + } while (true); + + if (streamNames.length === 0) return findings; + + for (const streamName of streamNames) { + const descRes = await client.send( + new DescribeStreamSummaryCommand({ StreamName: streamName }), + ); + + const summary = descRes.StreamDescriptionSummary; + if (!summary) continue; + + const streamArn = + summary.StreamARN ?? `arn:aws:kinesis:${region}:stream/${streamName}`; + + if (!summary.EncryptionType || summary.EncryptionType === 'NONE') { + findings.push( + this.makeFinding({ + id: `kinesis-not-encrypted-${streamName}`, + title: 'Stream not encrypted', + description: `Kinesis stream "${streamName}" does not have server-side encryption enabled.`, + severity: 'high', + resourceId: streamArn, + evidence: { + service: 'Kinesis', + streamName, + encryptionType: summary.EncryptionType ?? 'NONE', + }, + remediation: `Use kinesis:StartStreamEncryptionCommand with StreamName set to '${streamName}', EncryptionType set to 'KMS', and KeyId set to a KMS key ARN or alias (e.g., 'alias/aws/kinesis' for AWS-managed key, or a CMK ARN). Rollback: use kinesis:StopStreamEncryptionCommand with StreamName set to '${streamName}', EncryptionType set to 'KMS', and the same KeyId.`, + }), + ); + } else { + findings.push( + this.makeFinding({ + id: `kinesis-encrypted-${streamName}`, + title: 'Stream encrypted', + description: `Kinesis stream "${streamName}" has ${summary.EncryptionType} encryption enabled.`, + severity: 'info', + resourceId: streamArn, + evidence: { + service: 'Kinesis', + streamName, + encryptionType: summary.EncryptionType, + }, + passed: true, + }), + ); + } + + const enhancedMetrics = summary.EnhancedMonitoring ?? []; + const hasShardMetrics = enhancedMetrics.some( + (m) => m.ShardLevelMetrics && m.ShardLevelMetrics.length > 0, + ); + + if (!hasShardMetrics) { + findings.push( + this.makeFinding({ + id: `kinesis-no-enhanced-monitoring-${streamName}`, + title: 'Enhanced monitoring not enabled', + description: `Kinesis stream "${streamName}" does not have shard-level enhanced monitoring enabled.`, + severity: 'low', + resourceId: streamArn, + evidence: { service: 'Kinesis', streamName }, + remediation: `Use kinesis:EnableEnhancedMonitoringCommand with StreamName set to '${streamName}' and ShardLevelMetrics set to ['ALL'] (or specific metrics like 'IncomingBytes', 'OutgoingBytes'). Rollback: use kinesis:DisableEnhancedMonitoringCommand with the same StreamName and ShardLevelMetrics.`, + }), + ); + } + } + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding( + params: Omit & { + remediation?: string; + }, + ): SecurityFinding { + return { + ...params, + resourceType: 'AwsKinesisStream', + evidence: { ...params.evidence, findingKey: params.id }, + createdAt: new Date().toISOString(), + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/kms.adapter.ts b/apps/api/src/cloud-security/providers/aws/kms.adapter.ts new file mode 100644 index 0000000000..071d4fb9bd --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/kms.adapter.ts @@ -0,0 +1,137 @@ +import { + KMSClient, + ListKeysCommand, + DescribeKeyCommand, + GetKeyRotationStatusCommand, +} from '@aws-sdk/client-kms'; +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class KmsAdapter implements AwsServiceAdapter { + readonly serviceId = 'kms'; + readonly isGlobal = false; + + async scan(params: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const { credentials, region, accountId } = params; + const client = new KMSClient({ region, credentials }); + + const findings: SecurityFinding[] = []; + let marker: string | undefined; + + do { + const resp = await client.send( + new ListKeysCommand({ Marker: marker, Limit: 100 }), + ); + + for (const key of resp.Keys || []) { + if (!key.KeyId) continue; + + const keyFindings = await this.checkKey( + client, + key.KeyId, + region, + accountId, + ); + findings.push(...keyFindings); + } + + marker = resp.Truncated ? resp.NextMarker : undefined; + } while (marker); + + return findings; + } + + private async checkKey( + client: KMSClient, + keyId: string, + region: string, + accountId?: string, + ): Promise { + const descResp = await client.send( + new DescribeKeyCommand({ KeyId: keyId }), + ); + const meta = descResp.KeyMetadata; + if (!meta) return []; + + // Only check customer-managed symmetric keys + if (meta.KeyManager !== 'CUSTOMER') return []; + if (meta.KeySpec !== 'SYMMETRIC_DEFAULT') return []; + if (meta.KeyState !== 'Enabled') return []; + + const keyArn = meta.Arn || keyId; + const description = meta.Description || keyId; + + try { + const rotResp = await client.send( + new GetKeyRotationStatusCommand({ KeyId: keyId }), + ); + + if (!rotResp.KeyRotationEnabled) { + return [ + this.makeFinding({ + id: `kms-no-rotation-${keyId}`, + title: `KMS key "${description}" does not have automatic rotation enabled (${region})`, + description: `Customer-managed KMS key ${keyId} does not have automatic annual rotation enabled. CIS Benchmark 3.8 requires automatic key rotation.`, + severity: 'medium', + resourceId: keyArn, + remediation: `Use kms:EnableKeyRotationCommand with KeyId set to the key ARN "${keyArn}". Rollback by calling kms:DisableKeyRotationCommand with the same KeyId.`, + passed: false, + accountId, + region, + }), + ]; + } + + return [ + this.makeFinding({ + id: `kms-rotation-${keyId}`, + title: `KMS key "${description}" has automatic rotation enabled (${region})`, + description: `Automatic key rotation is enabled for KMS key ${keyId}.`, + severity: 'info', + resourceId: keyArn, + passed: true, + accountId, + region, + }), + ]; + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + } + + private makeFinding(opts: { + id: string; + title: string; + description: string; + severity: SecurityFinding['severity']; + resourceId?: string; + remediation?: string; + passed: boolean; + accountId?: string; + region?: string; + }): SecurityFinding { + return { + id: opts.id, + title: opts.title, + description: opts.description, + severity: opts.severity, + resourceType: 'AwsKmsKey', + resourceId: opts.resourceId || 'unknown', + remediation: opts.remediation, + evidence: { + awsAccountId: opts.accountId, + region: opts.region, + service: 'KMS', + findingKey: opts.id, + }, + createdAt: new Date().toISOString(), + passed: opts.passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/lambda.adapter.ts b/apps/api/src/cloud-security/providers/aws/lambda.adapter.ts new file mode 100644 index 0000000000..487c40fe56 --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/lambda.adapter.ts @@ -0,0 +1,167 @@ +import { + LambdaClient, + ListFunctionsCommand, + GetPolicyCommand, +} from '@aws-sdk/client-lambda'; +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +const DEPRECATED_RUNTIMES = [ + 'nodejs14.x', + 'nodejs12.x', + 'nodejs10.x', + 'nodejs8.10', + 'python3.7', + 'python3.6', + 'python2.7', + 'ruby2.5', + 'dotnetcore3.1', + 'dotnetcore2.1', + 'java8', + 'go1.x', +]; + +export class LambdaAdapter implements AwsServiceAdapter { + readonly serviceId = 'lambda'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new LambdaClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + try { + let marker: string | undefined; + + do { + const resp = await client.send( + new ListFunctionsCommand({ Marker: marker }), + ); + + const functions = resp.Functions ?? []; + + for (const fn of functions) { + const arn = fn.FunctionArn ?? 'unknown'; + const name = fn.FunctionName ?? 'unknown'; + + // Check for deprecated runtime + if (fn.Runtime && DEPRECATED_RUNTIMES.includes(fn.Runtime)) { + findings.push( + this.makeFinding({ + resourceId: arn, + title: 'Deprecated runtime in use', + description: `Lambda function ${name} uses deprecated runtime ${fn.Runtime}. Deprecated runtimes no longer receive security patches.`, + severity: 'high', + remediation: + '[MANUAL] Cannot be auto-fixed. Updating the Lambda runtime may require code changes. Update the function runtime via lambda:UpdateFunctionConfigurationCommand with Runtime set to the latest supported version.', + evidence: { runtime: fn.Runtime }, + }), + ); + } + + // Check for public access via resource policy + try { + const policyResp = await client.send( + new GetPolicyCommand({ FunctionName: name }), + ); + + if (policyResp.Policy) { + const policy = JSON.parse(policyResp.Policy); + const statements = policy.Statement ?? []; + + const isPublic = statements.some( + (stmt: Record) => { + if (stmt.Effect !== 'Allow') return false; + const principal = stmt.Principal; + if (principal === '*') return true; + if ( + typeof principal === 'object' && + principal !== null && + (principal as Record).AWS === '*' + ) + return true; + return false; + }, + ); + + if (isPublic) { + findings.push( + this.makeFinding({ + resourceId: arn, + title: 'Lambda function is publicly accessible', + description: `Lambda function ${name} has a resource policy that allows public invocation.`, + severity: 'critical', + remediation: + 'Use lambda:RemovePermissionCommand with FunctionName and StatementId to remove the public policy statement. Rollback by calling lambda:AddPermissionCommand to restore the statement.', + evidence: { policy: policy.Statement }, + }), + ); + } + } + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + // ResourceNotFoundException means no policy — not public, skip + if (msg.includes('ResourceNotFoundException')) { + // No policy attached — this is normal + } else if (msg.includes('AccessDenied')) { + return []; + } + } + + // Check VPC configuration + if (!fn.VpcConfig?.VpcId) { + findings.push( + this.makeFinding({ + resourceId: arn, + title: 'Lambda function not deployed in VPC', + description: `Lambda function ${name} is not deployed within a VPC. Functions outside a VPC cannot access private resources and lack network-level isolation.`, + severity: 'low', + remediation: + '[MANUAL] Cannot be auto-fixed. Adding a Lambda to a VPC requires VPC subnet and security group configuration decisions.', + evidence: { vpcConfig: fn.VpcConfig ?? null }, + }), + ); + } + } + + marker = resp.NextMarker; + } while (marker); + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding(params: { + resourceId: string; + title: string; + description: string; + severity: SecurityFinding['severity']; + remediation?: string; + evidence?: Record; + passed?: boolean; + }): SecurityFinding { + const id = `lambda-${params.resourceId}-${params.title.toLowerCase().replace(/\s+/g, '-')}`; + return { + id, + title: params.title, + description: params.description, + severity: params.severity, + resourceType: 'AwsLambdaFunction', + resourceId: params.resourceId, + remediation: params.remediation, + evidence: { ...params.evidence, findingKey: id }, + createdAt: new Date().toISOString(), + passed: params.passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/macie.adapter.ts b/apps/api/src/cloud-security/providers/aws/macie.adapter.ts new file mode 100644 index 0000000000..f9eb16de11 --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/macie.adapter.ts @@ -0,0 +1,92 @@ +import { Macie2Client, GetMacieSessionCommand } from '@aws-sdk/client-macie2'; +import { S3Client, ListBucketsCommand } from '@aws-sdk/client-s3'; +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class MacieAdapter implements AwsServiceAdapter { + readonly serviceId = 'macie'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new Macie2Client({ credentials, region }); + const findings: SecurityFinding[] = []; + + // Prerequisite: check if there are any S3 buckets + try { + const s3Client = new S3Client({ credentials, region }); + const s3Resp = await s3Client.send(new ListBucketsCommand({})); + if ((s3Resp.Buckets ?? []).length === 0) return []; + } catch { + // If prerequisite check fails (permissions), fall through to existing behavior + } + + try { + const session = await client.send(new GetMacieSessionCommand({})); + + if (session.status === 'ENABLED') { + findings.push( + this.makeFinding({ + id: `macie-enabled-${region}`, + title: 'Macie is enabled', + description: `Amazon Macie is enabled in ${region}.`, + severity: 'info', + resourceId: `arn:aws:macie2:${region}`, + passed: true, + }), + ); + } else { + findings.push( + this.makeFinding({ + id: `macie-not-enabled-${region}`, + title: 'Macie not enabled', + description: `Amazon Macie is not enabled in ${region}.`, + severity: 'medium', + resourceId: `arn:aws:macie2:${region}`, + remediation: `Use macie2:EnableMacieCommand with status set to 'ENABLED' and findingPublishingFrequency set to 'FIFTEEN_MINUTES' (or 'ONE_HOUR', 'SIX_HOURS'). Rollback: use macie2:DisableMacieCommand. Note: enabling Macie incurs costs based on data scanned.`, + }), + ); + } + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + + if (msg.includes('not enabled') || msg.includes('Macie is not enabled')) { + findings.push( + this.makeFinding({ + id: `macie-not-enabled-${region}`, + title: 'Macie not enabled', + description: `Amazon Macie is not enabled in ${region}.`, + severity: 'medium', + resourceId: `arn:aws:macie2:${region}`, + remediation: `Use macie2:EnableMacieCommand with status set to 'ENABLED' and findingPublishingFrequency set to 'FIFTEEN_MINUTES' (or 'ONE_HOUR', 'SIX_HOURS'). Rollback: use macie2:DisableMacieCommand. Note: enabling Macie incurs costs based on data scanned.`, + }), + ); + return findings; + } + + throw error; + } + + return findings; + } + + private makeFinding( + params: Omit & { + remediation?: string; + }, + ): SecurityFinding { + return { + ...params, + evidence: { ...(params.evidence ?? {}), findingKey: params.id }, + resourceType: 'AwsMacieSession', + createdAt: new Date().toISOString(), + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/msk.adapter.ts b/apps/api/src/cloud-security/providers/aws/msk.adapter.ts new file mode 100644 index 0000000000..ae9a723330 --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/msk.adapter.ts @@ -0,0 +1,177 @@ +import { KafkaClient, ListClustersV2Command } from '@aws-sdk/client-kafka'; + +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class MskAdapter implements AwsServiceAdapter { + readonly serviceId = 'msk'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new KafkaClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + try { + let nextToken: string | undefined; + + do { + const res = await client.send( + new ListClustersV2Command({ NextToken: nextToken }), + ); + + for (const cluster of res.ClusterInfoList ?? []) { + const clusterName = cluster.ClusterName ?? 'unknown'; + const clusterArn = cluster.ClusterArn ?? clusterName; + const provisioned = cluster.Provisioned; + + if (!provisioned) continue; + + // Check encryption in transit + const clientBroker = + provisioned.EncryptionInfo?.EncryptionInTransit?.ClientBroker; + + if (clientBroker === 'TLS') { + findings.push( + this.makeFinding( + clusterArn, + 'Encryption in transit enforced', + `MSK cluster "${clusterName}" enforces TLS-only encryption in transit`, + 'info', + { clusterName, clientBroker }, + true, + ), + ); + } else if (clientBroker === 'TLS_PLAINTEXT') { + findings.push( + this.makeFinding( + clusterArn, + 'Encryption in transit allows plaintext', + `MSK cluster "${clusterName}" allows both TLS and plaintext connections`, + 'medium', + { clusterName, clientBroker }, + false, + `Use kafka:UpdateSecurityCommand with ClusterArn set to '${clusterArn}', CurrentVersion set to the cluster's current version, and EncryptionInTransit.ClientBroker set to 'TLS'. Rollback: use kafka:UpdateSecurityCommand with ClientBroker set to 'TLS_PLAINTEXT'.`, + ), + ); + } else { + findings.push( + this.makeFinding( + clusterArn, + 'Encryption in transit not enforced', + `MSK cluster "${clusterName}" does not enforce TLS encryption in transit`, + 'high', + { clusterName, clientBroker: clientBroker ?? 'NOT_CONFIGURED' }, + false, + `Use kafka:UpdateSecurityCommand with ClusterArn set to '${clusterArn}', CurrentVersion set to the cluster's current version, and EncryptionInTransit.ClientBroker set to 'TLS'. Rollback: use kafka:UpdateSecurityCommand with ClientBroker set to 'PLAINTEXT'.`, + ), + ); + } + + // Check encryption at rest + const kmsKeyId = + provisioned.EncryptionInfo?.EncryptionAtRest?.DataVolumeKMSKeyId; + + if (kmsKeyId) { + findings.push( + this.makeFinding( + clusterArn, + 'Custom encryption key configured', + `MSK cluster "${clusterName}" uses a customer-managed KMS key for encryption at rest`, + 'info', + { clusterName, kmsKeyId }, + true, + ), + ); + } else { + findings.push( + this.makeFinding( + clusterArn, + 'Using default encryption key', + `MSK cluster "${clusterName}" uses the default AWS-managed encryption key`, + 'medium', + { clusterName }, + false, + `[MANUAL] Cannot be auto-fixed. Encryption at rest with a customer-managed KMS key can only be configured at cluster creation time. Create a new MSK cluster using kafka:CreateClusterCommand with EncryptionInfo.EncryptionAtRest.DataVolumeKMSKeyId set to a KMS key ARN, then migrate topics and data.`, + ), + ); + } + + // Check broker logging + const brokerLogs = provisioned.LoggingInfo?.BrokerLogs; + const hasCloudWatch = brokerLogs?.CloudWatchLogs?.Enabled === true; + const hasS3 = brokerLogs?.S3?.Enabled === true; + const hasFirehose = brokerLogs?.Firehose?.Enabled === true; + + if (hasCloudWatch || hasS3 || hasFirehose) { + findings.push( + this.makeFinding( + clusterArn, + 'Broker logging configured', + `MSK cluster "${clusterName}" has broker logging enabled`, + 'info', + { + clusterName, + cloudWatch: hasCloudWatch, + s3: hasS3, + firehose: hasFirehose, + }, + true, + ), + ); + } else { + findings.push( + this.makeFinding( + clusterArn, + 'Broker logging not configured', + `MSK cluster "${clusterName}" does not have any broker log destination configured`, + 'medium', + { clusterName }, + false, + `Use kafka:UpdateMonitoringCommand with ClusterArn set to '${clusterArn}', CurrentVersion set to the cluster's current version, and LoggingInfo.BrokerLogs.CloudWatchLogs set to { Enabled: true, LogGroup: '/aws/msk/${clusterName}' }. Alternatively, use S3 or Firehose log destinations. Rollback: use kafka:UpdateMonitoringCommand with CloudWatchLogs.Enabled set to false.`, + ), + ); + } + } + + nextToken = res.NextToken; + } while (nextToken); + } catch (error: unknown) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding( + resourceId: string, + title: string, + description: string, + severity: SecurityFinding['severity'], + evidence?: Record, + passed?: boolean, + remediation?: string, + ): SecurityFinding { + const id = `msk-${resourceId}-${title.toLowerCase().replace(/\s+/g, '-')}`; + return { + id, + title, + description, + severity, + resourceType: 'AwsMskCluster', + resourceId, + remediation, + evidence: { ...evidence, service: 'MSK', findingKey: id }, + createdAt: new Date().toISOString(), + passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/network-firewall.adapter.ts b/apps/api/src/cloud-security/providers/aws/network-firewall.adapter.ts new file mode 100644 index 0000000000..128f4ae189 --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/network-firewall.adapter.ts @@ -0,0 +1,118 @@ +import { + NetworkFirewallClient, + ListFirewallsCommand, + DescribeFirewallCommand, + DescribeLoggingConfigurationCommand, +} from '@aws-sdk/client-network-firewall'; +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class NetworkFirewallAdapter implements AwsServiceAdapter { + readonly serviceId = 'network-firewall'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new NetworkFirewallClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + try { + const firewalls: { name: string; arn: string }[] = []; + let nextToken: string | undefined; + + do { + const listRes = await client.send( + new ListFirewallsCommand({ NextToken: nextToken }), + ); + + for (const fw of listRes.Firewalls ?? []) { + if (fw.FirewallName && fw.FirewallArn) { + firewalls.push({ name: fw.FirewallName, arn: fw.FirewallArn }); + } + } + + nextToken = listRes.NextToken; + } while (nextToken); + + if (firewalls.length === 0) { + findings.push( + this.makeFinding({ + id: `network-firewall-none-${region}`, + title: 'No Network Firewalls found', + description: `No AWS Network Firewalls are deployed in ${region}.`, + severity: 'info', + resourceId: `arn:aws:network-firewall:${region}`, + evidence: { service: 'Network Firewall', region }, + }), + ); + return findings; + } + + for (const fw of firewalls) { + const descRes = await client.send( + new DescribeFirewallCommand({ FirewallArn: fw.arn }), + ); + + if (!descRes.Firewall?.FirewallPolicyArn) { + findings.push( + this.makeFinding({ + id: `network-firewall-no-policy-${fw.name}`, + title: 'Firewall has no policy attached', + description: `Network Firewall "${fw.name}" does not have a firewall policy configured.`, + severity: 'high', + resourceId: fw.arn, + evidence: { service: 'Network Firewall', firewallName: fw.name }, + remediation: `Use network-firewall:AssociateFirewallPolicyCommand with FirewallArn set to '${fw.arn}' and FirewallPolicyArn set to the policy ARN. If no policy exists, first create one with network-firewall:CreateFirewallPolicyCommand with FirewallPolicyName and FirewallPolicy containing StatelessDefaultActions and StatefulRuleGroupReferences. Rollback: use network-firewall:AssociateFirewallPolicyCommand to revert to the previous policy ARN.`, + }), + ); + } + + const logRes = await client.send( + new DescribeLoggingConfigurationCommand({ FirewallArn: fw.arn }), + ); + + const logConfigs = + logRes.LoggingConfiguration?.LogDestinationConfigs ?? []; + + if (logConfigs.length === 0) { + findings.push( + this.makeFinding({ + id: `network-firewall-no-logging-${fw.name}`, + title: 'Firewall logging not configured', + description: `Network Firewall "${fw.name}" does not have logging configured.`, + severity: 'medium', + resourceId: fw.arn, + evidence: { service: 'Network Firewall', firewallName: fw.name }, + remediation: `Use network-firewall:UpdateLoggingConfigurationCommand with FirewallArn set to '${fw.arn}' and LoggingConfiguration.LogDestinationConfigs containing LogType 'ALERT' (or 'FLOW'), LogDestinationType 'CloudWatchLogs' (or 'S3', 'KinesisDataFirehose'), and LogDestination with the destination details (e.g., logGroup for CloudWatch). Rollback: use network-firewall:UpdateLoggingConfigurationCommand with an empty LogDestinationConfigs array.`, + }), + ); + } + } + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding( + params: Omit & { + remediation?: string; + }, + ): SecurityFinding { + return { + ...params, + evidence: { ...(params.evidence ?? {}), findingKey: params.id }, + resourceType: 'AwsNetworkFirewall', + createdAt: new Date().toISOString(), + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/opensearch.adapter.ts b/apps/api/src/cloud-security/providers/aws/opensearch.adapter.ts new file mode 100644 index 0000000000..cd2a0159f4 --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/opensearch.adapter.ts @@ -0,0 +1,185 @@ +import { + DescribeDomainCommand, + ListDomainNamesCommand, + OpenSearchClient, +} from '@aws-sdk/client-opensearch'; + +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class OpenSearchAdapter implements AwsServiceAdapter { + readonly serviceId = 'opensearch'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new OpenSearchClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + try { + const listRes = await client.send(new ListDomainNamesCommand({})); + + for (const domainInfo of listRes.DomainNames ?? []) { + const domainName = domainInfo.DomainName; + if (!domainName) continue; + + try { + const descRes = await client.send( + new DescribeDomainCommand({ DomainName: domainName }), + ); + + const domain = descRes.DomainStatus; + if (!domain) continue; + + const resourceId = domain.ARN ?? domainName; + + if (domain.EncryptionAtRestOptions?.Enabled !== true) { + findings.push( + this.makeFinding( + resourceId, + 'OpenSearch encryption at rest is disabled', + `Domain "${domainName}" does not have encryption at rest enabled`, + 'high', + { domainName, encryptionAtRest: false }, + false, + `Use opensearch:UpdateDomainConfigCommand with DomainName and EncryptionAtRestOptions.Enabled set to true. Rollback by setting to false.`, + ), + ); + } else { + findings.push( + this.makeFinding( + resourceId, + 'OpenSearch encryption at rest is enabled', + `Domain "${domainName}" has encryption at rest enabled`, + 'info', + { domainName, encryptionAtRest: true }, + true, + ), + ); + } + + if (domain.NodeToNodeEncryptionOptions?.Enabled !== true) { + findings.push( + this.makeFinding( + resourceId, + 'OpenSearch node-to-node encryption is disabled', + `Domain "${domainName}" does not have node-to-node encryption enabled`, + 'high', + { domainName, nodeToNodeEncryption: false }, + false, + `Use opensearch:UpdateDomainConfigCommand with NodeToNodeEncryptionOptions.Enabled set to true. Rollback by setting to false.`, + ), + ); + } else { + findings.push( + this.makeFinding( + resourceId, + 'OpenSearch node-to-node encryption is enabled', + `Domain "${domainName}" has node-to-node encryption enabled`, + 'info', + { domainName, nodeToNodeEncryption: true }, + true, + ), + ); + } + + const vpcOptions = domain.VPCOptions; + const hasVpc = + vpcOptions && + ((vpcOptions.SubnetIds ?? []).length > 0 || + (vpcOptions.SecurityGroupIds ?? []).length > 0); + + if (!hasVpc) { + findings.push( + this.makeFinding( + resourceId, + 'OpenSearch domain is publicly accessible', + `Domain "${domainName}" is not deployed within a VPC and may be publicly accessible`, + 'high', + { domainName, vpcConfigured: false }, + false, + `[MANUAL] Cannot be auto-fixed. Moving an OpenSearch domain into a VPC requires domain recreation.`, + ), + ); + } else { + findings.push( + this.makeFinding( + resourceId, + 'OpenSearch domain is in a VPC', + `Domain "${domainName}" is deployed within a VPC`, + 'info', + { domainName, vpcConfigured: true }, + true, + ), + ); + } + + if (domain.AdvancedSecurityOptions?.Enabled !== true) { + findings.push( + this.makeFinding( + resourceId, + 'OpenSearch fine-grained access control is disabled', + `Domain "${domainName}" does not have advanced security options (fine-grained access control) enabled`, + 'medium', + { domainName, advancedSecurity: false }, + false, + `Use opensearch:UpdateDomainConfigCommand with AdvancedSecurityOptions.Enabled set to true. Requires HTTPS enforcement.`, + ), + ); + } else { + findings.push( + this.makeFinding( + resourceId, + 'OpenSearch fine-grained access control is enabled', + `Domain "${domainName}" has advanced security options enabled`, + 'info', + { domainName, advancedSecurity: true }, + true, + ), + ); + } + } catch (error: unknown) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('ResourceNotFoundException')) continue; + throw error; + } + } + } catch (error: unknown) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding( + resourceId: string, + title: string, + description: string, + severity: SecurityFinding['severity'], + evidence?: Record, + passed?: boolean, + remediation?: string, + ): SecurityFinding { + const id = `opensearch-${resourceId}-${title.toLowerCase().replace(/\s+/g, '-')}`; + return { + id, + title, + description, + severity, + resourceType: 'AwsOpenSearchDomain', + resourceId, + remediation, + evidence: { ...evidence, findingKey: id }, + createdAt: new Date().toISOString(), + passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/rds.adapter.ts b/apps/api/src/cloud-security/providers/aws/rds.adapter.ts new file mode 100644 index 0000000000..bc30909892 --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/rds.adapter.ts @@ -0,0 +1,166 @@ +import { RDSClient, DescribeDBInstancesCommand } from '@aws-sdk/client-rds'; +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +const MIN_BACKUP_RETENTION_DAYS = 7; + +export class RdsAdapter implements AwsServiceAdapter { + readonly serviceId = 'rds'; + readonly isGlobal = false; + + async scan(params: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const { credentials, region, accountId } = params; + const client = new RDSClient({ region, credentials }); + + const findings: SecurityFinding[] = []; + let marker: string | undefined; + + do { + const resp = await client.send( + new DescribeDBInstancesCommand({ Marker: marker, MaxRecords: 100 }), + ); + + for (const db of resp.DBInstances || []) { + if (!db.DBInstanceIdentifier) continue; + findings.push(...this.checkInstance(db, region, accountId)); + } + + marker = resp.Marker; + } while (marker); + + return findings; + } + + private checkInstance( + db: { + DBInstanceIdentifier?: string; + DBInstanceArn?: string; + PubliclyAccessible?: boolean; + StorageEncrypted?: boolean; + BackupRetentionPeriod?: number; + MultiAZ?: boolean; + DeletionProtection?: boolean; + Engine?: string; + }, + region: string, + accountId?: string, + ): SecurityFinding[] { + const findings: SecurityFinding[] = []; + const id = db.DBInstanceIdentifier!; + const arn = db.DBInstanceArn || id; + + if (db.PubliclyAccessible) { + findings.push( + this.makeFinding({ + id: `rds-public-${id}`, + title: `RDS instance "${id}" is publicly accessible (${region})`, + description: `Database instance ${id} (${db.Engine || 'unknown'}) is publicly accessible. This exposes the database to potential attacks from the internet.`, + severity: 'critical', + resourceId: arn, + remediation: `Use rds:ModifyDBInstanceCommand with DBInstanceIdentifier and PubliclyAccessible set to false. Rollback by setting PubliclyAccessible to true.`, + passed: false, + accountId, + region, + }), + ); + } else { + findings.push( + this.makeFinding({ + id: `rds-public-${id}`, + title: `RDS instance "${id}" is not publicly accessible (${region})`, + description: `Database instance ${id} is not publicly accessible.`, + severity: 'info', + resourceId: arn, + passed: true, + accountId, + region, + }), + ); + } + + if (!db.StorageEncrypted) { + findings.push( + this.makeFinding({ + id: `rds-encryption-${id}`, + title: `RDS instance "${id}" is not encrypted (${region})`, + description: `Database instance ${id} does not have storage encryption enabled. Data at rest is not protected.`, + severity: 'high', + resourceId: arn, + remediation: `[MANUAL] Cannot be auto-fixed. RDS encryption can only be enabled at creation time. To fix: create a snapshot using rds:CreateDBSnapshotCommand, copy the snapshot with encryption using rds:CopyDBSnapshotCommand with KmsKeyId, then restore from the encrypted snapshot using rds:RestoreDBInstanceFromDBSnapshotCommand.`, + passed: false, + accountId, + region, + }), + ); + } + + const retention = db.BackupRetentionPeriod ?? 0; + if (retention < MIN_BACKUP_RETENTION_DAYS) { + findings.push( + this.makeFinding({ + id: `rds-backup-${id}`, + title: `RDS instance "${id}" has insufficient backup retention (${retention} days) (${region})`, + description: `Database instance ${id} has a backup retention period of ${retention} day(s). Minimum recommended is ${MIN_BACKUP_RETENTION_DAYS} days.`, + severity: 'medium', + resourceId: arn, + remediation: `Use rds:ModifyDBInstanceCommand with BackupRetentionPeriod set to at least 7. Rollback by restoring previous BackupRetentionPeriod value.`, + passed: false, + accountId, + region, + }), + ); + } + + if (!db.DeletionProtection) { + findings.push( + this.makeFinding({ + id: `rds-deletion-protection-${id}`, + title: `RDS instance "${id}" has no deletion protection (${region})`, + description: `Database instance ${id} does not have deletion protection enabled. The instance could be accidentally deleted.`, + severity: 'medium', + resourceId: arn, + remediation: `Use rds:ModifyDBInstanceCommand with DeletionProtection set to true. Rollback by setting DeletionProtection to false.`, + passed: false, + accountId, + region, + }), + ); + } + + return findings; + } + + private makeFinding(opts: { + id: string; + title: string; + description: string; + severity: SecurityFinding['severity']; + resourceId?: string; + remediation?: string; + passed: boolean; + accountId?: string; + region?: string; + }): SecurityFinding { + return { + id: opts.id, + title: opts.title, + description: opts.description, + severity: opts.severity, + resourceType: 'AwsRdsDbInstance', + resourceId: opts.resourceId || 'unknown', + remediation: opts.remediation, + evidence: { + awsAccountId: opts.accountId, + region: opts.region, + service: 'RDS', + findingKey: opts.id, + }, + createdAt: new Date().toISOString(), + passed: opts.passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/redshift.adapter.ts b/apps/api/src/cloud-security/providers/aws/redshift.adapter.ts new file mode 100644 index 0000000000..dd7921ca3d --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/redshift.adapter.ts @@ -0,0 +1,157 @@ +import { + DescribeClustersCommand, + DescribeLoggingStatusCommand, + RedshiftClient, +} from '@aws-sdk/client-redshift'; + +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class RedshiftAdapter implements AwsServiceAdapter { + readonly serviceId = 'redshift'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new RedshiftClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + try { + let marker: string | undefined; + + do { + const listRes = await client.send( + new DescribeClustersCommand({ Marker: marker }), + ); + + for (const cluster of listRes.Clusters ?? []) { + const clusterId = cluster.ClusterIdentifier ?? 'unknown'; + + if (cluster.Encrypted !== true) { + findings.push( + this.makeFinding( + clusterId, + 'Redshift cluster is not encrypted', + `Cluster "${clusterId}" does not have encryption at rest enabled`, + 'high', + { encrypted: cluster.Encrypted }, + false, + `[MANUAL] Cannot be auto-fixed. Redshift cluster encryption requires cluster recreation. To fix: create a new encrypted cluster using redshift:CreateClusterCommand with Encrypted set to true, migrate data, then delete the old cluster.`, + ), + ); + } else { + findings.push( + this.makeFinding( + clusterId, + 'Redshift cluster encryption enabled', + `Cluster "${clusterId}" has encryption at rest enabled`, + 'info', + { encrypted: true }, + true, + ), + ); + } + + if (cluster.PubliclyAccessible === true) { + findings.push( + this.makeFinding( + clusterId, + 'Redshift cluster is publicly accessible', + `Cluster "${clusterId}" is configured with public access`, + 'critical', + { publiclyAccessible: true }, + false, + `Use redshift:ModifyClusterCommand with ClusterIdentifier and PubliclyAccessible set to false. Rollback by setting PubliclyAccessible to true.`, + ), + ); + } else { + findings.push( + this.makeFinding( + clusterId, + 'Redshift cluster is not publicly accessible', + `Cluster "${clusterId}" is not publicly accessible`, + 'info', + { publiclyAccessible: false }, + true, + ), + ); + } + + try { + const logRes = await client.send( + new DescribeLoggingStatusCommand({ + ClusterIdentifier: clusterId, + }), + ); + + if (logRes.LoggingEnabled !== true) { + findings.push( + this.makeFinding( + clusterId, + 'Redshift audit logging is disabled', + `Cluster "${clusterId}" does not have audit logging enabled`, + 'medium', + { loggingEnabled: false }, + false, + `Use redshift:EnableLoggingCommand with ClusterIdentifier and BucketName for the S3 logging bucket. Rollback by calling redshift:DisableLoggingCommand.`, + ), + ); + } else { + findings.push( + this.makeFinding( + clusterId, + 'Redshift audit logging is enabled', + `Cluster "${clusterId}" has audit logging enabled`, + 'info', + { loggingEnabled: true }, + true, + ), + ); + } + } catch (error: unknown) { + const msg = error instanceof Error ? error.message : String(error); + if (!msg.includes('AccessDenied')) throw error; + } + } + + marker = listRes.Marker; + } while (marker); + } catch (error: unknown) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding( + resourceId: string, + title: string, + description: string, + severity: SecurityFinding['severity'], + evidence?: Record, + passed?: boolean, + remediation?: string, + ): SecurityFinding { + const id = `redshift-${resourceId}-${title.toLowerCase().replace(/\s+/g, '-')}`; + return { + id, + title, + description, + severity, + resourceType: 'AwsRedshiftCluster', + resourceId, + remediation, + evidence: { ...evidence, findingKey: id }, + createdAt: new Date().toISOString(), + passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/route53.adapter.ts b/apps/api/src/cloud-security/providers/aws/route53.adapter.ts new file mode 100644 index 0000000000..e18de36e1c --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/route53.adapter.ts @@ -0,0 +1,146 @@ +import { + Route53Client, + ListHostedZonesCommand, + GetDNSSECCommand, + ListQueryLoggingConfigsCommand, +} from '@aws-sdk/client-route-53'; +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class Route53Adapter implements AwsServiceAdapter { + readonly serviceId = 'route53'; + readonly isGlobal = true; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new Route53Client({ credentials, region }); + const findings: SecurityFinding[] = []; + + try { + let marker: string | undefined; + + do { + const resp = await client.send( + new ListHostedZonesCommand({ + Marker: marker, + MaxItems: 100, + }), + ); + + const zones = resp.HostedZones ?? []; + + for (const zone of zones) { + const rawId = zone.Id ?? ''; + const zoneId = rawId.replace('/hostedzone/', ''); + const zoneName = zone.Name ?? 'unknown'; + + // DNSSEC check — only for public zones + if (zone.Config?.PrivateZone !== true) { + try { + const dnssecResp = await client.send( + new GetDNSSECCommand({ HostedZoneId: zoneId }), + ); + + if (dnssecResp.Status?.ServeSignature !== 'SIGNING') { + findings.push( + this.makeFinding({ + resourceId: zoneId, + title: 'DNSSEC not enabled', + description: `Hosted zone ${zoneName} (${zoneId}) does not have DNSSEC signing enabled. DNSSEC protects against DNS spoofing attacks.`, + severity: 'medium', + remediation: `Use route53:CreateKeySigningKeyCommand with HostedZoneId set to '${zoneId}', Name set to a KSK name, KeyManagementServiceArn set to a KMS key ARN (must be in us-east-1, asymmetric ECC_NIST_P256). Then use route53:EnableHostedZoneDNSSECCommand with HostedZoneId set to '${zoneId}'. [MANUAL] You must also create a DS record in the parent zone. Rollback: use route53:DisableHostedZoneDNSSECCommand with HostedZoneId.`, + evidence: { + serveSignature: + dnssecResp.Status?.ServeSignature ?? 'not set', + }, + }), + ); + } + } catch (error) { + const msg = + error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + if ( + msg.includes('DNSSECNotFound') || + msg.includes('InvalidArgument') + ) { + findings.push( + this.makeFinding({ + resourceId: zoneId, + title: 'DNSSEC not enabled', + description: `Hosted zone ${zoneName} (${zoneId}) does not have DNSSEC configured.`, + severity: 'medium', + remediation: `Use route53:CreateKeySigningKeyCommand with HostedZoneId set to '${zoneId}', Name set to a KSK name, KeyManagementServiceArn set to a KMS key ARN (must be in us-east-1, asymmetric ECC_NIST_P256). Then use route53:EnableHostedZoneDNSSECCommand with HostedZoneId set to '${zoneId}'. [MANUAL] You must also create a DS record in the parent zone. Rollback: use route53:DisableHostedZoneDNSSECCommand with HostedZoneId.`, + evidence: { error: msg }, + }), + ); + } + } + } + + // Query logging check + try { + const loggingResp = await client.send( + new ListQueryLoggingConfigsCommand({ HostedZoneId: zoneId }), + ); + + const configs = loggingResp.QueryLoggingConfigs ?? []; + if (configs.length === 0) { + findings.push( + this.makeFinding({ + resourceId: zoneId, + title: 'Query logging not enabled', + description: `Hosted zone ${zoneName} (${zoneId}) does not have DNS query logging enabled.`, + severity: 'low', + remediation: `Use route53:CreateQueryLoggingConfigCommand with HostedZoneId set to '${zoneId}' and CloudWatchLogsLogGroupArn set to a CloudWatch Logs log group ARN in us-east-1 (required region). The log group must have a resource policy allowing Route53 to write to it. Rollback: use route53:DeleteQueryLoggingConfigCommand with the Id returned from the create call.`, + evidence: { queryLoggingConfigs: 0 }, + }), + ); + } + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + } + } + + marker = resp.NextMarker; + } while (marker); + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding(params: { + resourceId: string; + title: string; + description: string; + severity: SecurityFinding['severity']; + remediation?: string; + evidence?: Record; + passed?: boolean; + }): SecurityFinding { + const id = `route53-${params.resourceId}-${params.title.toLowerCase().replace(/\s+/g, '-')}`; + return { + id, + title: params.title, + description: params.description, + severity: params.severity, + resourceType: 'AwsRoute53HostedZone', + resourceId: params.resourceId, + remediation: params.remediation, + evidence: { ...(params.evidence ?? {}), findingKey: id }, + createdAt: new Date().toISOString(), + passed: params.passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/s3.adapter.ts b/apps/api/src/cloud-security/providers/aws/s3.adapter.ts new file mode 100644 index 0000000000..125b6d8055 --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/s3.adapter.ts @@ -0,0 +1,246 @@ +import { + S3Client, + ListBucketsCommand, + GetPublicAccessBlockCommand, + GetBucketEncryptionCommand, + GetBucketVersioningCommand, +} from '@aws-sdk/client-s3'; +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +const MAX_BUCKETS = 100; + +export class S3Adapter implements AwsServiceAdapter { + readonly serviceId = 's3'; + readonly isGlobal = true; + + async scan(params: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const { credentials, region, accountId } = params; + const client = new S3Client({ region, credentials }); + + const findings: SecurityFinding[] = []; + + const listResp = await client.send(new ListBucketsCommand({})); + const buckets = (listResp.Buckets || []).slice(0, MAX_BUCKETS); + + if (buckets.length === 0) return findings; + + for (const bucket of buckets) { + if (!bucket.Name) continue; + + const results = await Promise.allSettled([ + this.checkPublicAccess(client, bucket.Name, accountId), + this.checkEncryption(client, bucket.Name, accountId), + this.checkVersioning(client, bucket.Name, accountId), + ]); + + for (const result of results) { + if (result.status === 'fulfilled') { + findings.push(...result.value); + } + } + } + + return findings; + } + + private async checkPublicAccess( + client: S3Client, + bucketName: string, + accountId?: string, + ): Promise { + try { + const resp = await client.send( + new GetPublicAccessBlockCommand({ Bucket: bucketName }), + ); + const config = resp.PublicAccessBlockConfiguration; + + if ( + !config?.BlockPublicAcls || + !config?.BlockPublicPolicy || + !config?.IgnorePublicAcls || + !config?.RestrictPublicBuckets + ) { + return [ + this.makeFinding({ + id: `s3-public-access-${bucketName}`, + title: `S3 bucket "${bucketName}" does not block all public access`, + description: `Bucket ${bucketName} has incomplete public access block settings. All four public access block settings should be enabled.`, + severity: 'high', + resourceId: `arn:aws:s3:::${bucketName}`, + remediation: `Use s3:PutPublicAccessBlockCommand with Bucket set to "${bucketName}" and PublicAccessBlockConfiguration with BlockPublicAcls, IgnorePublicAcls, BlockPublicPolicy, and RestrictPublicBuckets all set to true. Rollback by restoring previous PublicAccessBlockConfiguration settings.`, + passed: false, + accountId, + }), + ]; + } + + return [ + this.makeFinding({ + id: `s3-public-access-${bucketName}`, + title: `S3 bucket "${bucketName}" blocks all public access`, + description: `All public access block settings are enabled.`, + severity: 'info', + resourceId: `arn:aws:s3:::${bucketName}`, + passed: true, + accountId, + }), + ]; + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + // No public access block configured at all + if (msg.includes('NoSuchPublicAccessBlockConfiguration')) { + return [ + this.makeFinding({ + id: `s3-public-access-${bucketName}`, + title: `S3 bucket "${bucketName}" has no public access block configured`, + description: `Bucket ${bucketName} has no public access block configuration, making it potentially publicly accessible.`, + severity: 'high', + resourceId: `arn:aws:s3:::${bucketName}`, + remediation: `Use s3:PutPublicAccessBlockCommand with Bucket set to "${bucketName}" and PublicAccessBlockConfiguration with BlockPublicAcls, IgnorePublicAcls, BlockPublicPolicy, and RestrictPublicBuckets all set to true. Rollback by removing the public access block configuration via s3:DeletePublicAccessBlockCommand.`, + passed: false, + accountId, + }), + ]; + } + if (msg.includes('AccessDenied')) return []; + throw error; + } + } + + private async checkEncryption( + client: S3Client, + bucketName: string, + accountId?: string, + ): Promise { + try { + const resp = await client.send( + new GetBucketEncryptionCommand({ Bucket: bucketName }), + ); + const rules = resp.ServerSideEncryptionConfiguration?.Rules || []; + + if (rules.length === 0) { + return [ + this.makeFinding({ + id: `s3-encryption-${bucketName}`, + title: `S3 bucket "${bucketName}" has no default encryption`, + description: `Bucket ${bucketName} does not have server-side encryption configured by default.`, + severity: 'high', + resourceId: `arn:aws:s3:::${bucketName}`, + remediation: `Use s3:PutBucketEncryptionCommand with Bucket set to "${bucketName}" and ServerSideEncryptionConfiguration.Rules containing a rule with ApplyServerSideEncryptionByDefault.SSEAlgorithm set to 'AES256'. For KMS encryption, set SSEAlgorithm to 'aws:kms' and provide KMSMasterKeyID. Rollback by calling s3:DeleteBucketEncryptionCommand with Bucket set to "${bucketName}".`, + passed: false, + accountId, + }), + ]; + } + + return [ + this.makeFinding({ + id: `s3-encryption-${bucketName}`, + title: `S3 bucket "${bucketName}" has default encryption enabled`, + description: `Default server-side encryption is configured.`, + severity: 'info', + resourceId: `arn:aws:s3:::${bucketName}`, + passed: true, + accountId, + }), + ]; + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if ( + msg.includes('ServerSideEncryptionConfigurationNotFound') || + msg.includes('NoSuchBucket') + ) { + return [ + this.makeFinding({ + id: `s3-encryption-${bucketName}`, + title: `S3 bucket "${bucketName}" has no default encryption`, + description: `No server-side encryption configuration found for bucket ${bucketName}.`, + severity: 'high', + resourceId: `arn:aws:s3:::${bucketName}`, + remediation: `Use s3:PutBucketEncryptionCommand with Bucket set to "${bucketName}" and ServerSideEncryptionConfiguration.Rules containing a rule with ApplyServerSideEncryptionByDefault.SSEAlgorithm set to 'AES256'. For KMS encryption, set SSEAlgorithm to 'aws:kms' and provide KMSMasterKeyID. Rollback by calling s3:DeleteBucketEncryptionCommand with Bucket set to "${bucketName}".`, + passed: false, + accountId, + }), + ]; + } + if (msg.includes('AccessDenied')) return []; + throw error; + } + } + + private async checkVersioning( + client: S3Client, + bucketName: string, + accountId?: string, + ): Promise { + try { + const resp = await client.send( + new GetBucketVersioningCommand({ Bucket: bucketName }), + ); + + if (resp.Status !== 'Enabled') { + return [ + this.makeFinding({ + id: `s3-versioning-${bucketName}`, + title: `S3 bucket "${bucketName}" does not have versioning enabled`, + description: `Bucket ${bucketName} does not have versioning enabled. Without versioning, deleted or overwritten objects cannot be recovered.`, + severity: 'medium', + resourceId: `arn:aws:s3:::${bucketName}`, + remediation: `Use s3:PutBucketVersioningCommand with Bucket set to "${bucketName}" and VersioningConfiguration.Status set to 'Enabled'. Rollback by calling s3:PutBucketVersioningCommand with VersioningConfiguration.Status set to 'Suspended'. Note: versioning cannot be fully disabled once enabled, only suspended.`, + passed: false, + accountId, + }), + ]; + } + + return [ + this.makeFinding({ + id: `s3-versioning-${bucketName}`, + title: `S3 bucket "${bucketName}" has versioning enabled`, + description: `Versioning is enabled for data protection.`, + severity: 'info', + resourceId: `arn:aws:s3:::${bucketName}`, + passed: true, + accountId, + }), + ]; + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + } + + private makeFinding(opts: { + id: string; + title: string; + description: string; + severity: SecurityFinding['severity']; + resourceId?: string; + remediation?: string; + passed: boolean; + accountId?: string; + }): SecurityFinding { + return { + id: opts.id, + title: opts.title, + description: opts.description, + severity: opts.severity, + resourceType: 'AwsS3Bucket', + resourceId: opts.resourceId || 'unknown', + remediation: opts.remediation, + evidence: { + awsAccountId: opts.accountId, + service: 'S3', + findingKey: opts.id, + }, + createdAt: new Date().toISOString(), + passed: opts.passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/sagemaker.adapter.ts b/apps/api/src/cloud-security/providers/aws/sagemaker.adapter.ts new file mode 100644 index 0000000000..4f3a17b19a --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/sagemaker.adapter.ts @@ -0,0 +1,186 @@ +import { + SageMakerClient, + ListNotebookInstancesCommand, + DescribeNotebookInstanceCommand, +} from '@aws-sdk/client-sagemaker'; + +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class SageMakerAdapter implements AwsServiceAdapter { + readonly serviceId = 'sagemaker'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new SageMakerClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + try { + let nextToken: string | undefined; + + do { + const listRes = await client.send( + new ListNotebookInstancesCommand({ NextToken: nextToken }), + ); + + for (const nb of listRes.NotebookInstances ?? []) { + const notebookName = nb.NotebookInstanceName ?? 'unknown'; + const notebookArn = nb.NotebookInstanceArn ?? notebookName; + + const descRes = await client.send( + new DescribeNotebookInstanceCommand({ + NotebookInstanceName: notebookName, + }), + ); + + // Check KMS encryption + if (descRes.KmsKeyId) { + findings.push( + this.makeFinding( + notebookArn, + 'Notebook encrypted with CMK', + `SageMaker notebook "${notebookName}" is encrypted with a customer-managed KMS key`, + 'info', + { notebookName, kmsKeyId: descRes.KmsKeyId }, + true, + ), + ); + } else { + findings.push( + this.makeFinding( + notebookArn, + 'Notebook not encrypted with CMK', + `SageMaker notebook "${notebookName}" is not encrypted with a customer-managed KMS key`, + 'medium', + { notebookName }, + false, + `[MANUAL] Cannot be auto-fixed on an existing notebook instance. KMS encryption must be set at creation time. Stop the notebook using sagemaker:StopNotebookInstanceCommand, then delete it using sagemaker:DeleteNotebookInstanceCommand, and recreate using sagemaker:CreateNotebookInstanceCommand with KmsKeyId set to a customer-managed KMS key ARN. Ensure data is backed up before deletion.`, + ), + ); + } + + // Check root access + if (descRes.RootAccess === 'Enabled') { + findings.push( + this.makeFinding( + notebookArn, + 'Root access enabled on notebook', + `SageMaker notebook "${notebookName}" has root access enabled`, + 'medium', + { notebookName, rootAccess: 'Enabled' }, + false, + `First stop the notebook using sagemaker:StopNotebookInstanceCommand with NotebookInstanceName set to '${notebookName}'. Then use sagemaker:UpdateNotebookInstanceCommand with NotebookInstanceName set to '${notebookName}' and RootAccess set to 'Disabled'. Finally restart with sagemaker:StartNotebookInstanceCommand. Rollback: use sagemaker:UpdateNotebookInstanceCommand with RootAccess set to 'Enabled'.`, + ), + ); + } else { + findings.push( + this.makeFinding( + notebookArn, + 'Root access disabled on notebook', + `SageMaker notebook "${notebookName}" has root access disabled`, + 'info', + { notebookName, rootAccess: descRes.RootAccess ?? 'Disabled' }, + true, + ), + ); + } + + // Check direct internet access + if (descRes.DirectInternetAccess === 'Enabled') { + findings.push( + this.makeFinding( + notebookArn, + 'Direct internet access enabled', + `SageMaker notebook "${notebookName}" has direct internet access enabled`, + 'high', + { notebookName, directInternetAccess: 'Enabled' }, + false, + `[MANUAL] Cannot be auto-fixed. DirectInternetAccess can only be set at creation time. Stop the notebook using sagemaker:StopNotebookInstanceCommand, delete using sagemaker:DeleteNotebookInstanceCommand, and recreate using sagemaker:CreateNotebookInstanceCommand with DirectInternetAccess set to 'Disabled' and SubnetId/SecurityGroupIds for VPC access. Ensure data is backed up before deletion.`, + ), + ); + } else { + findings.push( + this.makeFinding( + notebookArn, + 'Direct internet access disabled', + `SageMaker notebook "${notebookName}" has direct internet access disabled`, + 'info', + { + notebookName, + directInternetAccess: + descRes.DirectInternetAccess ?? 'Disabled', + }, + true, + ), + ); + } + + // Check VPC configuration + if (descRes.SubnetId) { + findings.push( + this.makeFinding( + notebookArn, + 'Notebook deployed in VPC', + `SageMaker notebook "${notebookName}" is deployed within a VPC`, + 'info', + { notebookName, subnetId: descRes.SubnetId }, + true, + ), + ); + } else { + findings.push( + this.makeFinding( + notebookArn, + 'Not in VPC', + `SageMaker notebook "${notebookName}" is not deployed within a VPC`, + 'medium', + { notebookName }, + false, + `[MANUAL] Cannot be auto-fixed. VPC configuration can only be set at creation time. Stop the notebook using sagemaker:StopNotebookInstanceCommand, delete using sagemaker:DeleteNotebookInstanceCommand, and recreate using sagemaker:CreateNotebookInstanceCommand with SubnetId set to a VPC subnet and SecurityGroupIds set to security group IDs. Ensure data is backed up before deletion.`, + ), + ); + } + } + + nextToken = listRes.NextToken; + } while (nextToken); + } catch (error: unknown) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding( + resourceId: string, + title: string, + description: string, + severity: SecurityFinding['severity'], + evidence?: Record, + passed?: boolean, + remediation?: string, + ): SecurityFinding { + const id = `sagemaker-${resourceId}-${title.toLowerCase().replace(/\s+/g, '-')}`; + return { + id, + title, + description, + severity, + resourceType: 'AwsSageMakerNotebook', + resourceId, + remediation, + evidence: { ...evidence, service: 'SageMaker', findingKey: id }, + createdAt: new Date().toISOString(), + passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/secrets-manager.adapter.ts b/apps/api/src/cloud-security/providers/aws/secrets-manager.adapter.ts new file mode 100644 index 0000000000..cbc09bdca7 --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/secrets-manager.adapter.ts @@ -0,0 +1,113 @@ +import { + SecretsManagerClient, + ListSecretsCommand, +} from '@aws-sdk/client-secrets-manager'; +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +const NINETY_DAYS_MS = 90 * 24 * 60 * 60 * 1000; + +export class SecretsManagerAdapter implements AwsServiceAdapter { + readonly serviceId = 'secrets-manager'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new SecretsManagerClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + try { + let nextToken: string | undefined; + + do { + const response = await client.send( + new ListSecretsCommand({ NextToken: nextToken }), + ); + + for (const secret of response.SecretList ?? []) { + const secretName = secret.Name ?? 'unknown'; + const secretArn = + secret.ARN ?? + `arn:aws:secretsmanager:${region}:secret/${secretName}`; + + if (secret.RotationEnabled !== true) { + findings.push( + this.makeFinding({ + id: `secrets-no-rotation-${secretName}`, + title: `Secret ${secretName} does not have rotation enabled`, + description: `Secret ${secretName} does not have automatic rotation configured.`, + severity: 'medium', + resourceId: secretArn, + remediation: + '[MANUAL] Cannot be auto-fixed. Enabling secret rotation requires creating a Lambda rotation function specific to the secret type (database credentials, API keys, etc.). Configure rotation via secretsmanager:RotateSecretCommand after setting up the Lambda function.', + }), + ); + continue; + } + + if (secret.LastRotatedDate) { + const age = Date.now() - secret.LastRotatedDate.getTime(); + + if (age > NINETY_DAYS_MS) { + const daysSince = Math.floor(age / (24 * 60 * 60 * 1000)); + findings.push( + this.makeFinding({ + id: `secrets-rotation-overdue-${secretName}`, + title: `Secret ${secretName} rotation overdue`, + description: `Secret ${secretName} was last rotated ${daysSince} days ago, exceeding the 90-day threshold.`, + severity: 'medium', + resourceId: secretArn, + remediation: + 'Trigger an immediate rotation and verify the rotation schedule.', + evidence: { + lastRotated: secret.LastRotatedDate.toISOString(), + daysSinceRotation: daysSince, + }, + }), + ); + continue; + } + } + + findings.push( + this.makeFinding({ + id: `secrets-rotation-ok-${secretName}`, + title: `Secret ${secretName} rotation is configured`, + description: `Secret ${secretName} has rotation enabled and is within the 90-day rotation window.`, + severity: 'info', + resourceId: secretArn, + passed: true, + }), + ); + } + + nextToken = response.NextToken; + } while (nextToken); + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding( + params: Omit & { + remediation?: string; + }, + ): SecurityFinding { + return { + ...params, + evidence: { ...params.evidence, findingKey: params.id }, + resourceType: 'AwsSecretsManagerSecret', + createdAt: new Date().toISOString(), + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/security-hub.adapter.ts b/apps/api/src/cloud-security/providers/aws/security-hub.adapter.ts new file mode 100644 index 0000000000..3c082933ee --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/security-hub.adapter.ts @@ -0,0 +1,126 @@ +import { + GetFindingsCommand, + SecurityHubClient, + type GetFindingsCommandInput, +} from '@aws-sdk/client-securityhub'; +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class SecurityHubAdapter implements AwsServiceAdapter { + readonly serviceId = 'security-hub'; + readonly isGlobal = false; + + async scan(params: { + credentials: AwsCredentials; + region: string; + }): Promise { + const { credentials, region } = params; + + const securityHub = new SecurityHubClient({ region, credentials }); + + try { + return await this.fetchFindings(securityHub, region); + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('not subscribed') || msg.includes('AccessDenied')) { + return []; + } + throw error; + } + } + + private async fetchFindings( + client: SecurityHubClient, + region: string, + ): Promise { + const findings: SecurityFinding[] = []; + + const params: GetFindingsCommandInput = { + Filters: { + WorkflowStatus: [ + { Value: 'NEW', Comparison: 'EQUALS' }, + { Value: 'NOTIFIED', Comparison: 'EQUALS' }, + ], + RecordState: [{ Value: 'ACTIVE', Comparison: 'EQUALS' }], + }, + MaxResults: 100, + }; + + let response = await client.send(new GetFindingsCommand(params)); + + if (response.Findings) { + for (const f of response.Findings) { + findings.push(this.mapFinding(f, region)); + } + } + + let nextToken = response.NextToken; + while (nextToken && findings.length < 500) { + response = await client.send( + new GetFindingsCommand({ ...params, NextToken: nextToken }), + ); + + if (response.Findings) { + for (const f of response.Findings) { + if (findings.length >= 500) break; + findings.push(this.mapFinding(f, region)); + } + } + + nextToken = response.NextToken; + } + + return findings; + } + + private mapFinding( + finding: { + Id?: string; + Title?: string; + Description?: string; + Remediation?: { Recommendation?: { Text?: string } }; + Severity?: { Label?: string }; + Resources?: Array<{ Type?: string; Id?: string }>; + AwsAccountId?: string; + Region?: string; + Compliance?: { Status?: string }; + GeneratorId?: string; + CreatedAt?: string; + UpdatedAt?: string; + }, + scanRegion: string, + ): SecurityFinding { + const severityMap: Record = { + INFORMATIONAL: 'info', + LOW: 'low', + MEDIUM: 'medium', + HIGH: 'high', + CRITICAL: 'critical', + }; + + const complianceStatus = finding.Compliance?.Status; + const passed = complianceStatus === 'PASSED'; + const findingRegion = finding.Region || scanRegion; + const baseTitle = finding.Title || 'Untitled Finding'; + + return { + id: finding.Id || '', + title: `${baseTitle} (${findingRegion})`, + description: finding.Description || 'No description available', + severity: severityMap[finding.Severity?.Label || 'INFO'] || 'medium', + resourceType: finding.Resources?.[0]?.Type || 'unknown', + resourceId: finding.Resources?.[0]?.Id || 'unknown', + remediation: + finding.Remediation?.Recommendation?.Text || 'No remediation available', + evidence: { + awsAccountId: finding.AwsAccountId, + region: findingRegion, + complianceStatus, + generatorId: finding.GeneratorId, + updatedAt: finding.UpdatedAt, + }, + createdAt: finding.CreatedAt || new Date().toISOString(), + passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/shield.adapter.ts b/apps/api/src/cloud-security/providers/aws/shield.adapter.ts new file mode 100644 index 0000000000..30e1890c08 --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/shield.adapter.ts @@ -0,0 +1,138 @@ +import { + ShieldClient, + GetSubscriptionStateCommand, +} from '@aws-sdk/client-shield'; +import { + ElasticLoadBalancingV2Client, + DescribeLoadBalancersCommand, +} from '@aws-sdk/client-elastic-load-balancing-v2'; +import { + CloudFrontClient, + ListDistributionsCommand, +} from '@aws-sdk/client-cloudfront'; +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class ShieldAdapter implements AwsServiceAdapter { + readonly serviceId = 'shield'; + readonly isGlobal = true; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new ShieldClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + // Prerequisite: check if there are public-facing resources (ELBs, CloudFront distributions) + try { + let hasPublicResources = false; + + const elbClient = new ElasticLoadBalancingV2Client({ + credentials, + region, + }); + const elbResp = await elbClient.send( + new DescribeLoadBalancersCommand({ PageSize: 1 }), + ); + if ((elbResp.LoadBalancers ?? []).length > 0) { + hasPublicResources = true; + } + + if (!hasPublicResources) { + const cfClient = new CloudFrontClient({ + credentials, + region: 'us-east-1', + }); + const cfResp = await cfClient.send( + new ListDistributionsCommand({ MaxItems: 1 }), + ); + if ((cfResp.DistributionList?.Items ?? []).length > 0) { + hasPublicResources = true; + } + } + + if (!hasPublicResources) return []; + } catch { + // If prerequisite check fails (permissions), fall through to existing behavior + } + + try { + const res = await client.send(new GetSubscriptionStateCommand({})); + + if (res.SubscriptionState === 'ACTIVE') { + findings.push( + this.makeFinding({ + id: 'shield-advanced-active', + title: 'Shield Advanced is active', + description: + 'AWS Shield Advanced subscription is active, providing enhanced DDoS protection.', + severity: 'info', + resourceId: 'arn:aws:shield::subscription', + evidence: { service: 'Shield', subscriptionState: 'ACTIVE' }, + passed: true, + }), + ); + } else { + findings.push( + this.makeFinding({ + id: 'shield-advanced-not-enabled', + title: 'Shield Advanced not enabled', + description: + 'AWS Shield Advanced is not enabled. Only basic Shield (free) protection is in place.', + severity: 'medium', + resourceId: 'arn:aws:shield::subscription', + evidence: { + service: 'Shield', + subscriptionState: res.SubscriptionState, + }, + remediation: + '[MANUAL] Cannot be fully auto-fixed. Use shield:CreateSubscriptionCommand to enable Shield Advanced. This incurs a $3,000/month commitment with a 1-year minimum. After subscription, use shield:CreateProtectionCommand with ResourceArn for each resource to protect. Rollback: Shield Advanced subscriptions cannot be cancelled during the commitment period.', + }), + ); + } + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + if ( + msg.includes('SubscriptionNotFoundException') || + msg.includes('ResourceNotFoundException') + ) { + findings.push( + this.makeFinding({ + id: 'shield-advanced-not-enabled', + title: 'Shield Advanced not enabled', + description: + 'AWS Shield Advanced subscription is not available for this account.', + severity: 'medium', + resourceId: 'arn:aws:shield::subscription', + evidence: { service: 'Shield', error: msg }, + remediation: + '[MANUAL] Cannot be fully auto-fixed. Use shield:CreateSubscriptionCommand to enable Shield Advanced. This incurs a $3,000/month commitment with a 1-year minimum. Rollback: Shield Advanced subscriptions cannot be cancelled during the commitment period.', + }), + ); + return findings; + } + throw error; + } + + return findings; + } + + private makeFinding( + params: Omit & { + remediation?: string; + }, + ): SecurityFinding { + return { + ...params, + evidence: { ...params.evidence, findingKey: params.id }, + resourceType: 'AwsShieldSubscription', + createdAt: new Date().toISOString(), + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/sns-sqs.adapter.ts b/apps/api/src/cloud-security/providers/aws/sns-sqs.adapter.ts new file mode 100644 index 0000000000..47c3e23a29 --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/sns-sqs.adapter.ts @@ -0,0 +1,234 @@ +import { + SNSClient, + ListTopicsCommand, + GetTopicAttributesCommand, +} from '@aws-sdk/client-sns'; +import { + SQSClient, + ListQueuesCommand, + GetQueueAttributesCommand, +} from '@aws-sdk/client-sqs'; +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +function isPublicPolicy(policyJson: string): boolean { + try { + const policy = JSON.parse(policyJson); + const statements = policy.Statement ?? []; + + return statements.some((stmt: Record) => { + if (stmt.Effect !== 'Allow') return false; + if (stmt.Condition && Object.keys(stmt.Condition).length > 0) + return false; + const principal = stmt.Principal; + if (principal === '*') return true; + if ( + typeof principal === 'object' && + principal !== null && + (principal as Record).AWS === '*' + ) + return true; + return false; + }); + } catch { + return false; + } +} + +export class SnsSqsAdapter implements AwsServiceAdapter { + readonly serviceId = 'sns-sqs'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const findings: SecurityFinding[] = []; + + try { + await this.scanSns({ credentials, region, findings }); + await this.scanSqs({ credentials, region, findings }); + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private async scanSns({ + credentials, + region, + findings, + }: { + credentials: AwsCredentials; + region: string; + findings: SecurityFinding[]; + }): Promise { + const client = new SNSClient({ credentials, region }); + let nextToken: string | undefined; + + do { + const resp = await client.send( + new ListTopicsCommand({ NextToken: nextToken }), + ); + + const topics = resp.Topics ?? []; + + for (const topic of topics) { + const arn = topic.TopicArn ?? 'unknown'; + + try { + const attrsResp = await client.send( + new GetTopicAttributesCommand({ TopicArn: arn }), + ); + const attrs = attrsResp.Attributes ?? {}; + + // Check for public access + if (attrs.Policy && isPublicPolicy(attrs.Policy)) { + findings.push( + this.makeFinding({ + resourceId: arn, + resourceType: 'AwsSnsTopic', + title: 'SNS topic is publicly accessible', + description: `SNS topic ${arn} has a resource policy that allows public access without conditions.`, + severity: 'high', + remediation: + "Use sns:SetTopicAttributesCommand with TopicArn and AttributeName 'Policy' to restrict access. Set the policy to deny public access while allowing the topic owner. Rollback by restoring the previous policy.", + evidence: { policy: JSON.parse(attrs.Policy).Statement }, + }), + ); + } + + // Check for KMS encryption + if (!attrs.KmsMasterKeyId) { + findings.push( + this.makeFinding({ + resourceId: arn, + resourceType: 'AwsSnsTopic', + title: 'SNS topic not encrypted with KMS', + description: `SNS topic ${arn} does not use KMS encryption for messages at rest.`, + severity: 'medium', + remediation: + "Use sns:SetTopicAttributesCommand with TopicArn and AttributeName 'KmsMasterKeyId' set to 'alias/aws/sns'. Rollback by calling sns:SetTopicAttributesCommand with AttributeName 'KmsMasterKeyId' set to empty string.", + evidence: { kmsMasterKeyId: null }, + }), + ); + } + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return; + } + } + + nextToken = resp.NextToken; + } while (nextToken); + } + + private async scanSqs({ + credentials, + region, + findings, + }: { + credentials: AwsCredentials; + region: string; + findings: SecurityFinding[]; + }): Promise { + const client = new SQSClient({ credentials, region }); + let nextToken: string | undefined; + + do { + const resp = await client.send( + new ListQueuesCommand({ NextToken: nextToken }), + ); + + const queueUrls = resp.QueueUrls ?? []; + + for (const queueUrl of queueUrls) { + try { + const attrsResp = await client.send( + new GetQueueAttributesCommand({ + QueueUrl: queueUrl, + AttributeNames: [ + 'Policy', + 'KmsMasterKeyId', + 'SqsManagedSseEnabled', + ], + }), + ); + const attrs = attrsResp.Attributes ?? {}; + + // Check for public access + if (attrs.Policy && isPublicPolicy(attrs.Policy)) { + findings.push( + this.makeFinding({ + resourceId: queueUrl, + resourceType: 'AwsSqsQueue', + title: 'SQS queue is publicly accessible', + description: `SQS queue ${queueUrl} has a resource policy that allows public access without conditions.`, + severity: 'high', + remediation: + "Use sqs:SetQueueAttributesCommand with QueueUrl and Attributes.Policy to restrict access. Remove the statement that allows '*' principal. Rollback by restoring the previous policy.", + evidence: { policy: JSON.parse(attrs.Policy).Statement }, + }), + ); + } + + // Check for encryption + if (!attrs.KmsMasterKeyId && attrs.SqsManagedSseEnabled !== 'true') { + findings.push( + this.makeFinding({ + resourceId: queueUrl, + resourceType: 'AwsSqsQueue', + title: 'SQS queue not encrypted', + description: `SQS queue ${queueUrl} does not have KMS or SQS-managed server-side encryption enabled.`, + severity: 'medium', + remediation: + "Use sqs:SetQueueAttributesCommand with Attributes.KmsMasterKeyId set to 'alias/aws/sqs' for SQS. Rollback by removing the KmsMasterKeyId attribute.", + evidence: { + kmsMasterKeyId: attrs.KmsMasterKeyId ?? null, + sqsManagedSseEnabled: attrs.SqsManagedSseEnabled ?? null, + }, + }), + ); + } + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return; + } + } + + nextToken = resp.NextToken; + } while (nextToken); + } + + private makeFinding(params: { + resourceId: string; + resourceType: string; + title: string; + description: string; + severity: SecurityFinding['severity']; + remediation?: string; + evidence?: Record; + passed?: boolean; + }): SecurityFinding { + const id = `sns-sqs-${params.resourceId}-${params.title.toLowerCase().replace(/\s+/g, '-')}`; + return { + id, + title: params.title, + description: params.description, + severity: params.severity, + resourceType: params.resourceType, + resourceId: params.resourceId, + remediation: params.remediation, + evidence: { ...params.evidence, findingKey: id }, + createdAt: new Date().toISOString(), + passed: params.passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/step-functions.adapter.ts b/apps/api/src/cloud-security/providers/aws/step-functions.adapter.ts new file mode 100644 index 0000000000..14b7558376 --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/step-functions.adapter.ts @@ -0,0 +1,175 @@ +import { + SFNClient, + ListStateMachinesCommand, + DescribeStateMachineCommand, +} from '@aws-sdk/client-sfn'; + +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class StepFunctionsAdapter implements AwsServiceAdapter { + readonly serviceId = 'step-functions'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new SFNClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + try { + let nextToken: string | undefined; + + do { + const listRes = await client.send( + new ListStateMachinesCommand({ nextToken }), + ); + + for (const sm of listRes.stateMachines ?? []) { + const smArn = sm.stateMachineArn; + if (!smArn) continue; + + const smName = sm.name ?? smArn; + + const descRes = await client.send( + new DescribeStateMachineCommand({ stateMachineArn: smArn }), + ); + + // Check logging configuration + const logLevel = descRes.loggingConfiguration?.level; + + if (logLevel === 'OFF' || !logLevel) { + findings.push( + this.makeFinding( + smArn, + 'State machine logging disabled', + `Step Functions state machine "${smName}" does not have logging enabled`, + 'medium', + { stateMachineName: smName, loggingLevel: logLevel ?? 'OFF' }, + false, + `Use sfn:UpdateStateMachineCommand with stateMachineArn set to '${smArn}' and loggingConfiguration set to { level: 'ALL', includeExecutionData: true, destinations: [{ cloudWatchLogsLogGroup: { logGroupArn: '' } }] }. Create the CloudWatch log group first using logs:CreateLogGroupCommand. Rollback: use sfn:UpdateStateMachineCommand with loggingConfiguration.level set to 'OFF'.`, + ), + ); + } else { + findings.push( + this.makeFinding( + smArn, + 'State machine logging enabled', + `Step Functions state machine "${smName}" has logging enabled (level: ${logLevel})`, + 'info', + { stateMachineName: smName, loggingLevel: logLevel }, + true, + ), + ); + } + + // Check X-Ray tracing + if (descRes.tracingConfiguration?.enabled !== true) { + findings.push( + this.makeFinding( + smArn, + 'X-Ray tracing not enabled', + `Step Functions state machine "${smName}" does not have X-Ray tracing enabled`, + 'low', + { + stateMachineName: smName, + tracingEnabled: descRes.tracingConfiguration?.enabled, + }, + false, + `Use sfn:UpdateStateMachineCommand with stateMachineArn set to '${smArn}' and tracingConfiguration set to { enabled: true }. Ensure the state machine's IAM role has xray:PutTraceSegments and xray:PutTelemetryRecords permissions. Rollback: use sfn:UpdateStateMachineCommand with tracingConfiguration.enabled set to false.`, + ), + ); + } else { + findings.push( + this.makeFinding( + smArn, + 'X-Ray tracing enabled', + `Step Functions state machine "${smName}" has X-Ray tracing enabled`, + 'info', + { stateMachineName: smName, tracingEnabled: true }, + true, + ), + ); + } + + // Check encryption configuration + const encType = descRes.encryptionConfiguration?.type; + + if (encType && encType !== 'CUSTOMER_MANAGED_KMS_KEY') { + findings.push( + this.makeFinding( + smArn, + 'State machine using AWS-managed encryption key', + `Step Functions state machine "${smName}" uses an AWS-managed key instead of a customer-managed KMS key`, + 'low', + { stateMachineName: smName, encryptionType: encType }, + false, + `Use sfn:UpdateStateMachineCommand with stateMachineArn set to '${smArn}' and encryptionConfiguration set to { type: 'CUSTOMER_MANAGED_KMS_KEY', kmsKeyId: '', kmsDataKeyReusePeriodSeconds: 300 }. Rollback: use sfn:UpdateStateMachineCommand with encryptionConfiguration.type set to 'AWS_OWNED_KEY'.`, + ), + ); + } else if (encType === 'CUSTOMER_MANAGED_KMS_KEY') { + findings.push( + this.makeFinding( + smArn, + 'State machine using customer-managed KMS key', + `Step Functions state machine "${smName}" uses a customer-managed KMS key for encryption`, + 'info', + { stateMachineName: smName, encryptionType: encType }, + true, + ), + ); + } else { + findings.push( + this.makeFinding( + smArn, + 'State machine using AWS-managed encryption key', + `Step Functions state machine "${smName}" uses default AWS-managed encryption`, + 'low', + { stateMachineName: smName, encryptionType: 'AWS_OWNED_KEY' }, + false, + `Use sfn:UpdateStateMachineCommand with stateMachineArn set to '${smArn}' and encryptionConfiguration set to { type: 'CUSTOMER_MANAGED_KMS_KEY', kmsKeyId: '', kmsDataKeyReusePeriodSeconds: 300 }. Rollback: use sfn:UpdateStateMachineCommand with encryptionConfiguration.type set to 'AWS_OWNED_KEY'.`, + ), + ); + } + } + + nextToken = listRes.nextToken; + } while (nextToken); + } catch (error: unknown) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding( + resourceId: string, + title: string, + description: string, + severity: SecurityFinding['severity'], + evidence?: Record, + passed?: boolean, + remediation?: string, + ): SecurityFinding { + const id = `step-functions-${resourceId}-${title.toLowerCase().replace(/\s+/g, '-')}`; + return { + id, + title, + description, + severity, + resourceType: 'AwsStepFunction', + resourceId, + remediation, + evidence: { ...evidence, service: 'Step Functions', findingKey: id }, + createdAt: new Date().toISOString(), + passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/systems-manager.adapter.ts b/apps/api/src/cloud-security/providers/aws/systems-manager.adapter.ts new file mode 100644 index 0000000000..0c91d0c847 --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/systems-manager.adapter.ts @@ -0,0 +1,217 @@ +import { + SSMClient, + DescribeParametersCommand, + GetDocumentCommand, + DescribeInstanceInformationCommand, +} from '@aws-sdk/client-ssm'; + +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +const SENSITIVE_NAME_PATTERN = /password|secret|key|token/i; + +export class SystemsManagerAdapter implements AwsServiceAdapter { + readonly serviceId = 'systems-manager'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new SSMClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + // Prerequisite: check if there are any managed instances + try { + const instanceResp = await client.send( + new DescribeInstanceInformationCommand({ MaxResults: 1 }), + ); + if ((instanceResp.InstanceInformationList ?? []).length === 0) return []; + } catch { + // If prerequisite check fails (permissions), fall through to existing behavior + } + + try { + // Check Session Manager logging via the SSM-SessionManagerRunShell document + // This is the actual source of truth for Session Manager logging config + try { + const docRes = await client.send( + new GetDocumentCommand({ + Name: 'SSM-SessionManagerRunShell', + DocumentVersion: '$LATEST', + }), + ); + const content = docRes.Content ? JSON.parse(docRes.Content) : {}; + const inputs = content?.inputs ?? {}; + + // Logging is configured if there's an actual destination (non-empty bucket or log group) + const s3Bucket = + typeof inputs.s3BucketName === 'string' + ? inputs.s3BucketName.trim() + : ''; + const cwLogGroup = + typeof inputs.cloudWatchLogGroupName === 'string' + ? inputs.cloudWatchLogGroupName.trim() + : ''; + const hasLogging = s3Bucket.length > 0 || cwLogGroup.length > 0; + + if (hasLogging) { + findings.push( + this.makeFinding( + `arn:aws:ssm:${region}:session-logging`, + 'Session Manager logging enabled', + 'Session Manager logging is enabled for this region', + 'info', + { + documentExists: true, + loggingEnabled: true, + s3BucketName: s3Bucket || null, + cloudWatchLogGroupName: cwLogGroup || null, + }, + true, + ), + ); + } else { + findings.push( + this.makeFinding( + `arn:aws:ssm:${region}:session-logging`, + 'Session Manager logging not configured', + 'Session Manager preferences exist but no logging destination (S3 bucket or CloudWatch log group) is configured.', + 'medium', + { + documentExists: true, + loggingEnabled: false, + currentInputs: inputs, + }, + undefined, + 'Update the SSM-SessionManagerRunShell document to enable logging. Use ssm:UpdateDocumentCommand with Name "SSM-SessionManagerRunShell" and Content that sets inputs.cloudWatchLogGroupName to a log group name (e.g. "/aws/ssm/session-logs") and inputs.cloudWatchEncryptionEnabled to true. The document Content must be a JSON string with schemaVersion "1.0" and sessionType "Standard_Stream". Create the log group first with logs:CreateLogGroupCommand if it does not exist. Rollback by calling ssm:UpdateDocumentCommand with the original Content.', + ), + ); + } + } catch (docError: unknown) { + const docMsg = + docError instanceof Error ? docError.message : String(docError); + const errName = + docError instanceof Error + ? ((docError as { name?: string }).name ?? '') + : ''; + + const isPermissionError = + docMsg.toLowerCase().includes('accessdenied') || + docMsg.toLowerCase().includes('not authorized') || + docMsg.toLowerCase().includes('access denied') || + errName === 'AccessDeniedException'; + + if (isPermissionError) { + // Skip silently — auditor role may not have ssm:GetDocument + } else { + // Document doesn't exist — Session Manager preferences not set up + findings.push( + this.makeFinding( + `arn:aws:ssm:${region}:session-logging`, + 'Session Manager logging not configured', + 'Session Manager preferences document does not exist. Logging is not configured.', + 'medium', + { documentExists: false, loggingEnabled: false }, + undefined, + 'Create the SSM-SessionManagerRunShell document with logging enabled. Use ssm:CreateDocumentCommand with Name "SSM-SessionManagerRunShell", DocumentType "Session", and Content as a JSON string with schemaVersion "1.0", sessionType "Standard_Stream", and inputs containing cloudWatchLogGroupName set to "/aws/ssm/session-logs" and cloudWatchEncryptionEnabled set to true. Create the log group first with logs:CreateLogGroupCommand. Rollback by calling ssm:DeleteDocumentCommand with Name "SSM-SessionManagerRunShell".', + ), + ); + } + } + + // Check parameters + let nextToken: string | undefined; + let paramCount = 0; + + do { + const paramRes = await client.send( + new DescribeParametersCommand({ + NextToken: nextToken, + MaxResults: 50, + }), + ); + + for (const param of paramRes.Parameters ?? []) { + paramCount++; + if (paramCount > 100) break; + + const paramName = param.Name ?? 'unknown'; + const resourceId = paramName; + + if (param.Type === 'SecureString') { + if (!param.KeyId || param.KeyId === 'alias/aws/ssm') { + findings.push( + this.makeFinding( + resourceId, + 'SecureString parameter uses default key', + `Parameter "${paramName}" is a SecureString but uses the default AWS-managed key`, + 'low', + { parameterName: paramName, keyId: param.KeyId ?? 'default' }, + ), + ); + } else { + findings.push( + this.makeFinding( + resourceId, + 'SecureString parameter uses CMK', + `Parameter "${paramName}" is encrypted with a customer-managed KMS key`, + 'info', + { parameterName: paramName, keyId: param.KeyId }, + true, + ), + ); + } + } else if (SENSITIVE_NAME_PATTERN.test(paramName)) { + findings.push( + this.makeFinding( + resourceId, + 'Potentially sensitive parameter not encrypted', + `Parameter "${paramName}" has a name suggesting sensitive content but is stored as ${param.Type ?? 'String'} instead of SecureString`, + 'medium', + { parameterName: paramName, type: param.Type ?? 'String' }, + ), + ); + } + } + + if (paramCount > 100) break; + nextToken = paramRes.NextToken; + } while (nextToken); + } catch (error: unknown) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding( + resourceId: string, + title: string, + description: string, + severity: SecurityFinding['severity'], + evidence?: Record, + passed?: boolean, + remediation?: string, + ): SecurityFinding { + const id = `systems-manager-${resourceId}-${title.toLowerCase().replace(/\s+/g, '-')}`; + return { + id, + title, + description, + severity, + resourceType: 'AwsSsmParameter', + resourceId, + remediation, + evidence: { ...evidence, service: 'Systems Manager', findingKey: id }, + createdAt: new Date().toISOString(), + passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/transfer-family.adapter.ts b/apps/api/src/cloud-security/providers/aws/transfer-family.adapter.ts new file mode 100644 index 0000000000..f3f1da80a2 --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/transfer-family.adapter.ts @@ -0,0 +1,148 @@ +import { + DescribeServerCommand, + ListServersCommand, + TransferClient, +} from '@aws-sdk/client-transfer'; + +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class TransferFamilyAdapter implements AwsServiceAdapter { + readonly serviceId = 'transfer-family'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new TransferClient({ credentials, region }); + const findings: SecurityFinding[] = []; + + try { + let nextToken: string | undefined; + + do { + const listRes = await client.send( + new ListServersCommand({ NextToken: nextToken }), + ); + + for (const server of listRes.Servers ?? []) { + const serverId = server.ServerId; + if (!serverId) continue; + + const descRes = await client.send( + new DescribeServerCommand({ ServerId: serverId }), + ); + + const desc = descRes.Server; + if (!desc) continue; + + const serverArn = desc.Arn ?? serverId; + const protocols = desc.Protocols ?? []; + + if (protocols.includes('FTP')) { + findings.push( + this.makeFinding( + serverArn, + 'FTP protocol enabled (unencrypted)', + `Transfer server "${serverId}" has FTP enabled which transmits data unencrypted — use SFTP or FTPS instead`, + 'high', + { serverId, protocols, service: 'Transfer Family' }, + false, + `Use transfer:UpdateServerCommand with ServerId set to '${serverId}' and Protocols set to ['SFTP'] (or ['SFTP', 'FTPS'] if FTPS is also needed). Remove 'FTP' from the Protocols array. Rollback: use transfer:UpdateServerCommand to add 'FTP' back to Protocols. [MANUAL] Ensure all clients are updated to use SFTP/FTPS before removing FTP.`, + ), + ); + } + + const logDestinations = desc.StructuredLogDestinations ?? []; + + if (logDestinations.length === 0) { + findings.push( + this.makeFinding( + serverArn, + 'Structured logging not configured', + `Transfer server "${serverId}" does not have structured logging configured`, + 'medium', + { serverId, service: 'Transfer Family' }, + false, + `Use transfer:UpdateServerCommand with ServerId set to '${serverId}' and StructuredLogDestinations set to a CloudWatch Logs log group ARN (e.g., ['arn:aws:logs:region:account:log-group:/aws/transfer/${serverId}']). Ensure the server's IAM role has logs:CreateLogGroup, logs:CreateLogStream, and logs:PutLogEvents permissions. Rollback: use transfer:UpdateServerCommand with StructuredLogDestinations set to an empty array.`, + ), + ); + } + + if (desc.EndpointType === 'PUBLIC') { + findings.push( + this.makeFinding( + serverArn, + 'Server has public endpoint', + `Transfer server "${serverId}" uses a public endpoint — consider using VPC or VPC_ENDPOINT type`, + 'medium', + { + serverId, + endpointType: desc.EndpointType, + service: 'Transfer Family', + }, + false, + `Use transfer:UpdateServerCommand with ServerId set to '${serverId}' and EndpointType set to 'VPC', along with EndpointDetails containing VpcId, SubnetIds, and SecurityGroupIds. [MANUAL] Changing endpoint type causes server downtime and DNS changes. Ensure clients are updated with the new endpoint. Rollback: use transfer:UpdateServerCommand with EndpointType set to 'PUBLIC'.`, + ), + ); + } + + const hasNoIssues = + !protocols.includes('FTP') && + logDestinations.length > 0 && + desc.EndpointType !== 'PUBLIC'; + + if (hasNoIssues) { + findings.push( + this.makeFinding( + serverArn, + 'Transfer server is well configured', + `Transfer server "${serverId}" uses secure protocols, has logging, and a non-public endpoint`, + 'info', + { serverId, protocols, service: 'Transfer Family' }, + true, + ), + ); + } + } + + nextToken = listRes.NextToken; + } while (nextToken); + } catch (error: unknown) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding( + resourceId: string, + title: string, + description: string, + severity: SecurityFinding['severity'], + evidence?: Record, + passed?: boolean, + remediation?: string, + ): SecurityFinding { + const id = `transfer-family-${resourceId}-${title.toLowerCase().replace(/\s+/g, '-')}`; + return { + id, + title, + description, + severity, + resourceType: 'AwsTransferServer', + resourceId, + remediation, + evidence: { ...evidence, findingKey: id }, + createdAt: new Date().toISOString(), + passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/aws/waf.adapter.ts b/apps/api/src/cloud-security/providers/aws/waf.adapter.ts new file mode 100644 index 0000000000..adc10dc584 --- /dev/null +++ b/apps/api/src/cloud-security/providers/aws/waf.adapter.ts @@ -0,0 +1,161 @@ +import { + GetWebACLCommand, + ListWebACLsCommand, + WAFV2Client, +} from '@aws-sdk/client-wafv2'; +import { + ElasticLoadBalancingV2Client, + DescribeLoadBalancersCommand, +} from '@aws-sdk/client-elastic-load-balancing-v2'; +import { + ApiGatewayV2Client, + GetApisCommand, +} from '@aws-sdk/client-apigatewayv2'; + +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AwsCredentials, AwsServiceAdapter } from './aws-service-adapter'; + +export class WafAdapter implements AwsServiceAdapter { + readonly serviceId = 'waf'; + readonly isGlobal = false; + + async scan({ + credentials, + region, + }: { + credentials: AwsCredentials; + region: string; + accountId?: string; + }): Promise { + const client = new WAFV2Client({ credentials, region }); + const findings: SecurityFinding[] = []; + + // Prerequisite: check if there are web-facing resources (ALBs, API Gateways) + try { + let hasWebResources = false; + + const elbClient = new ElasticLoadBalancingV2Client({ + credentials, + region, + }); + const elbResp = await elbClient.send( + new DescribeLoadBalancersCommand({ PageSize: 1 }), + ); + if ((elbResp.LoadBalancers ?? []).length > 0) { + hasWebResources = true; + } + + if (!hasWebResources) { + const apigwClient = new ApiGatewayV2Client({ credentials, region }); + const apigwResp = await apigwClient.send( + new GetApisCommand({ MaxResults: '1' }), + ); + if ((apigwResp.Items ?? []).length > 0) { + hasWebResources = true; + } + } + + if (!hasWebResources) return []; + } catch { + // If prerequisite check fails (permissions), fall through to existing behavior + } + + try { + let nextMarker: string | undefined; + let hasAcls = false; + + do { + const listRes = await client.send( + new ListWebACLsCommand({ Scope: 'REGIONAL', NextMarker: nextMarker }), + ); + + for (const summary of listRes.WebACLs ?? []) { + hasAcls = true; + const arn = summary.ARN; + if (!arn || !summary.Name || !summary.Id) continue; + + try { + const aclRes = await client.send( + new GetWebACLCommand({ + Name: summary.Name, + Scope: 'REGIONAL', + Id: summary.Id, + }), + ); + + const rules = aclRes.WebACL?.Rules ?? []; + + if (rules.length === 0) { + findings.push( + this.makeFinding( + arn, + 'WAF ACL has no rules', + `Web ACL "${summary.Name}" has no rules configured, providing no protection`, + 'medium', + { aclName: summary.Name }, + ), + ); + } else { + findings.push( + this.makeFinding( + arn, + 'WAF ACL has rules configured', + `Web ACL "${summary.Name}" has ${rules.length} rule(s) configured`, + 'info', + { aclName: summary.Name, ruleCount: rules.length }, + true, + ), + ); + } + } catch (error: unknown) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('WAFNonexistentItemException')) continue; + throw error; + } + } + + nextMarker = listRes.NextMarker; + } while (nextMarker); + + if (!hasAcls) { + findings.push( + this.makeFinding( + `arn:aws:wafv2:${region}:no-acls`, + 'No WAF web ACLs configured', + 'No regional WAF web ACLs found in this region', + 'medium', + { region }, + ), + ); + } + } catch (error: unknown) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('AccessDenied')) return []; + throw error; + } + + return findings; + } + + private makeFinding( + resourceId: string, + title: string, + description: string, + severity: SecurityFinding['severity'], + evidence?: Record, + passed?: boolean, + ): SecurityFinding { + const id = `waf-${resourceId}-${title.toLowerCase().replace(/\s+/g, '-')}`; + return { + id, + title, + description, + severity, + resourceType: 'AwsWafWebAcl', + resourceId, + evidence: { ...evidence, findingKey: id }, + createdAt: new Date().toISOString(), + passed, + }; + } +} diff --git a/apps/api/src/cloud-security/providers/azure-security.service.ts b/apps/api/src/cloud-security/providers/azure-security.service.ts index 4b87c25560..cb5e22823e 100644 --- a/apps/api/src/cloud-security/providers/azure-security.service.ts +++ b/apps/api/src/cloud-security/providers/azure-security.service.ts @@ -1,5 +1,23 @@ import { Injectable, Logger } from '@nestjs/common'; import type { SecurityFinding } from '../cloud-security.service'; +import { + type AzureServiceAdapter, + fetchAllPages, + AZURE_CATEGORY_TO_SERVICE, + AZURE_SERVICE_NAMES, + AksAdapter, + AppServiceAdapter, + ContainerRegistryAdapter, + CosmosDbAdapter, + EntraIdAdapter, + KeyVaultAdapter, + MonitorAdapter, + NetworkWatcherAdapter, + PolicyAdapter, + SqlDatabaseAdapter, + StorageAccountAdapter, + VirtualMachineAdapter, +} from './azure'; interface AzureSecurityAlert { name: string; @@ -21,10 +39,7 @@ interface AzureSecurityAssessment { name: string; properties: { displayName: string; - status: { - code: string; - description?: string; - }; + status: { code: string; description?: string }; metadata?: { severity?: string; description?: string; @@ -35,10 +50,21 @@ interface AzureSecurityAssessment { }; } -interface AzureListResponse { - value: T[]; - nextLink?: string; -} +/** All implemented service adapters beyond Defender. */ +const SERVICE_ADAPTERS: AzureServiceAdapter[] = [ + new AksAdapter(), + new AppServiceAdapter(), + new ContainerRegistryAdapter(), + new CosmosDbAdapter(), + new EntraIdAdapter(), + new KeyVaultAdapter(), + new MonitorAdapter(), + new NetworkWatcherAdapter(), + new PolicyAdapter(), + new SqlDatabaseAdapter(), + new StorageAccountAdapter(), + new VirtualMachineAdapter(), +]; @Injectable() export class AzureSecurityService { @@ -46,37 +72,82 @@ export class AzureSecurityService { async scanSecurityFindings( credentials: Record, - _variables: Record, + variables: Record, + enabledServices?: string[], ): Promise { - const tenantId = credentials.tenantId as string; - const clientId = credentials.clientId as string; - const clientSecret = credentials.clientSecret as string; - const subscriptionId = credentials.subscriptionId as string; + // OAuth flow: access_token from vault + subscription_id from variables + const accessToken = credentials.access_token as string | undefined; + const subscriptionId = + (variables.subscription_id as string) || + (credentials.subscriptionId as string); + + // Legacy flow fallback: client credentials + let token = accessToken; + if (!token) { + const tenantId = credentials.tenantId as string; + const clientId = credentials.clientId as string; + const clientSecret = credentials.clientSecret as string; + if (tenantId && clientId && clientSecret) { + token = await this.getAccessToken(tenantId, clientId, clientSecret); + } + } - if (!tenantId || !clientId || !clientSecret || !subscriptionId) { + if (!token) { throw new Error( - 'Azure credentials incomplete. Ensure tenantId, clientId, clientSecret, and subscriptionId are configured.', + 'Azure credentials missing. Please reconnect the integration.', + ); + } + + if (!subscriptionId) { + throw new Error( + 'AZURE_SUB_MISSING: Azure Subscription ID not configured. Run the Azure setup to auto-detect it.', ); } this.logger.log(`Scanning Azure subscription ${subscriptionId}`); + const findings: SecurityFinding[] = []; - // Get access token - const accessToken = await this.getAccessToken( - tenantId, - clientId, - clientSecret, - ); + // 1. Defender alerts + assessments (always runs) + if (!enabledServices || enabledServices.includes('defender')) { + const defenderFindings = await this.scanDefender(token, subscriptionId); + findings.push(...defenderFindings); + } + + // 2. Run service adapters in parallel + const adapterPromises = SERVICE_ADAPTERS.filter( + (a) => !enabledServices || enabledServices.includes(a.serviceId), + ).map(async (adapter) => { + try { + return await adapter.scan({ accessToken: token, subscriptionId }); + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + this.logger.warn(`Azure ${adapter.serviceId} scan failed: ${msg}`); + return []; + } + }); + + const adapterResults = await Promise.all(adapterPromises); + for (const result of adapterResults) { + findings.push(...result); + } + + this.logger.log(`Azure scan complete: ${findings.length} total findings`); + return findings; + } + /** Scan Defender for Cloud alerts and assessments. */ + private async scanDefender( + accessToken: string, + subscriptionId: string, + ): Promise { const findings: SecurityFinding[] = []; - // Fetch security alerts + // Alerts try { const alerts = await this.getSecurityAlerts(accessToken, subscriptionId); const activeAlerts = alerts.filter( (a) => a.properties.status === 'Active', ); - this.logger.log(`Found ${activeAlerts.length} active security alerts`); for (const alert of activeAlerts) { @@ -91,6 +162,9 @@ export class AzureSecurityService { alert.properties.remediationSteps?.join('\n') || 'Review the alert in Microsoft Defender for Cloud', evidence: { + serviceId: 'defender', + serviceName: 'Microsoft Defender', + findingKey: `azure-defender-alert-${alert.properties.alertType || alert.name}`, alertType: alert.properties.alertType, compromisedEntity: alert.properties.compromisedEntity, intent: alert.properties.intent, @@ -100,30 +174,15 @@ export class AzureSecurityService { }); } } catch (error) { - const errorMsg = error instanceof Error ? error.message : String(error); - this.logger.warn(`Failed to fetch security alerts: ${errorMsg}`); - - if ( - errorMsg.includes('403') || - errorMsg.includes('AuthorizationFailed') - ) { - findings.push({ - id: `permission-alerts-${subscriptionId}`, - title: 'Unable to access Security Alerts', - description: - 'The service principal does not have permission to read security alerts.', - severity: 'medium', - resourceType: 'security-alerts', - resourceId: subscriptionId, - remediation: - 'Assign the "Security Reader" role to your App Registration on the subscription.', - evidence: { error: errorMsg }, - createdAt: new Date().toISOString(), - }); - } + this.handlePermissionError( + findings, + error, + 'Security Alerts', + subscriptionId, + ); } - // Fetch security assessments + // Assessments try { const assessments = await this.getSecurityAssessments( accessToken, @@ -132,18 +191,19 @@ export class AzureSecurityService { const unhealthy = assessments.filter( (a) => a.properties.status.code === 'Unhealthy', ); - this.logger.log( `Found ${unhealthy.length} unhealthy security assessments`, ); - // Limit to 50 to avoid overwhelming for (const assessment of unhealthy.slice(0, 50)) { + const category = assessment.properties.metadata?.category; + const serviceId = + (category && AZURE_CATEGORY_TO_SERVICE[category]) || 'defender'; + findings.push({ id: assessment.name, title: assessment.properties.displayName || - assessment.name || 'Unhealthy security assessment', description: assessment.properties.metadata?.description || @@ -158,47 +218,101 @@ export class AzureSecurityService { assessment.properties.metadata?.remediationDescription || 'Review and remediate in Microsoft Defender for Cloud', evidence: { + serviceId, + serviceName: AZURE_SERVICE_NAMES[serviceId] ?? serviceId, + findingKey: `azure-defender-assessment-${assessment.name}`, status: assessment.properties.status, - category: assessment.properties.metadata?.category, + category, }, createdAt: new Date().toISOString(), }); } } catch (error) { - const errorMsg = error instanceof Error ? error.message : String(error); - this.logger.warn(`Failed to fetch security assessments: ${errorMsg}`); - - if ( - errorMsg.includes('403') || - errorMsg.includes('AuthorizationFailed') - ) { - findings.push({ - id: `permission-assessments-${subscriptionId}`, - title: 'Unable to access Security Assessments', - description: - 'The service principal does not have permission to read security assessments.', - severity: 'medium', - resourceType: 'security-assessments', - resourceId: subscriptionId, - remediation: - 'Assign the "Security Reader" role to your App Registration on the subscription.', - evidence: { error: errorMsg }, - createdAt: new Date().toISOString(), - }); - } + this.handlePermissionError( + findings, + error, + 'Security Assessments', + subscriptionId, + ); } - this.logger.log(`Azure scan complete: ${findings.length} total findings`); return findings; } - private async getAccessToken( + private handlePermissionError( + findings: SecurityFinding[], + error: unknown, + component: string, + subscriptionId: string, + ): void { + const msg = error instanceof Error ? error.message : String(error); + this.logger.warn(`Failed to fetch ${component}: ${msg}`); + + if (msg.includes('403') || msg.includes('AuthorizationFailed')) { + findings.push({ + id: `permission-${component.toLowerCase().replace(/\s/g, '-')}-${subscriptionId}`, + title: `Unable to access ${component}`, + description: `The service principal does not have permission to read ${component.toLowerCase()}.`, + severity: 'medium', + resourceType: component.toLowerCase().replace(/\s/g, '-'), + resourceId: subscriptionId, + remediation: + 'Assign the "Security Reader" role to your App Registration on the subscription.', + evidence: { + serviceId: 'defender', + serviceName: 'Microsoft Defender', + error: msg, + }, + createdAt: new Date().toISOString(), + }); + } + } + + /** + * Detect Azure subscriptions accessible by the user's OAuth token. + */ + async detectSubscriptions( + accessToken: string, + ): Promise> { + const response = await fetch( + 'https://management.azure.com/subscriptions?api-version=2022-12-01', + { + headers: { + Authorization: `Bearer ${accessToken}`, + 'Content-Type': 'application/json', + }, + }, + ); + + if (!response.ok) { + const error = await response.text(); + throw new Error(`Failed to list Azure subscriptions: ${error}`); + } + + const data = (await response.json()) as { + value: Array<{ + subscriptionId: string; + displayName: string; + state: string; + }>; + }; + + return (data.value ?? []) + .filter((s) => s.state === 'Enabled') + .map((s) => ({ + id: s.subscriptionId, + displayName: s.displayName, + state: s.state, + })); + } + + /** Legacy: get access token via Service Principal client credentials. */ + async getAccessToken( tenantId: string, clientId: string, clientSecret: string, ): Promise { const tokenUrl = `https://login.microsoftonline.com/${tenantId}/oauth2/v2.0/token`; - const body = new URLSearchParams({ client_id: clientId, client_secret: clientSecret, @@ -217,52 +331,25 @@ export class AzureSecurityService { throw new Error(`Azure authentication failed: ${error}`); } - const data = await response.json(); + const data = (await response.json()) as { access_token: string }; return data.access_token; } - private async getSecurityAlerts( - accessToken: string, - subscriptionId: string, - ): Promise { - const url = `https://management.azure.com/subscriptions/${subscriptionId}/providers/Microsoft.Security/alerts?api-version=2022-01-01`; - return this.fetchAllPages(accessToken, url); + private async getSecurityAlerts(accessToken: string, subscriptionId: string) { + return fetchAllPages( + accessToken, + `https://management.azure.com/subscriptions/${subscriptionId}/providers/Microsoft.Security/alerts?api-version=2022-01-01`, + ); } private async getSecurityAssessments( accessToken: string, subscriptionId: string, - ): Promise { - const url = `https://management.azure.com/subscriptions/${subscriptionId}/providers/Microsoft.Security/assessments?api-version=2021-06-01`; - return this.fetchAllPages(accessToken, url); - } - - private async fetchAllPages( - accessToken: string, - initialUrl: string, - ): Promise { - const results: T[] = []; - let url: string | undefined = initialUrl; - - while (url) { - const response = await fetch(url, { - headers: { - Authorization: `Bearer ${accessToken}`, - 'Content-Type': 'application/json', - }, - }); - - if (!response.ok) { - const error = await response.text(); - throw new Error(`Azure API error (${response.status}): ${error}`); - } - - const data: AzureListResponse = await response.json(); - results.push(...data.value); - url = data.nextLink; - } - - return results; + ) { + return fetchAllPages( + accessToken, + `https://management.azure.com/subscriptions/${subscriptionId}/providers/Microsoft.Security/assessments?api-version=2021-06-01`, + ); } private mapSeverity(azureSeverity: string): SecurityFinding['severity'] { diff --git a/apps/api/src/cloud-security/providers/azure/aks.adapter.ts b/apps/api/src/cloud-security/providers/azure/aks.adapter.ts new file mode 100644 index 0000000000..7eb03387ba --- /dev/null +++ b/apps/api/src/cloud-security/providers/azure/aks.adapter.ts @@ -0,0 +1,210 @@ +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AzureServiceAdapter } from './azure-service-adapter'; +import { fetchAllPages } from './azure-service-adapter'; + +interface AksCluster { + id: string; + name: string; + location: string; + properties: { + kubernetesVersion?: string; + enableRBAC?: boolean; + aadProfile?: { + managed?: boolean; + enableAzureRBAC?: boolean; + }; + apiServerAccessProfile?: { + authorizedIPRanges?: string[]; + enablePrivateCluster?: boolean; + }; + networkProfile?: { + networkPolicy?: string; // 'azure' | 'calico' | null + outboundType?: string; + }; + addonProfiles?: { + azurePolicy?: { enabled: boolean }; + omsagent?: { enabled: boolean }; + }; + autoUpgradeProfile?: { + upgradeChannel?: string; // 'none' | 'patch' | 'rapid' | 'stable' | 'node-image' + }; + }; +} + +export class AksAdapter implements AzureServiceAdapter { + readonly serviceId = 'aks'; + + async scan({ + accessToken, + subscriptionId, + }: { + accessToken: string; + subscriptionId: string; + }): Promise { + const findings: SecurityFinding[] = []; + + const clusters = await fetchAllPages( + accessToken, + `https://management.azure.com/subscriptions/${subscriptionId}/providers/Microsoft.ContainerService/managedClusters?api-version=2024-01-01`, + ); + + if (clusters.length === 0) return findings; + + for (const cluster of clusters) { + const props = cluster.properties; + + // Check 1: Kubernetes RBAC + if (props.enableRBAC !== true) { + findings.push( + this.finding(cluster, { + key: 'rbac-disabled', + title: `Kubernetes RBAC Disabled: ${cluster.name}`, + description: `AKS cluster "${cluster.name}" does not have Kubernetes RBAC enabled. All users have full access to all resources.`, + severity: 'critical', + remediation: + 'Enable Kubernetes RBAC. Note: this requires cluster recreation for existing clusters.', + }), + ); + } + + // Check 2: Azure AD integration + if (!props.aadProfile?.managed) { + findings.push( + this.finding(cluster, { + key: 'no-aad-integration', + title: `No Azure AD Integration: ${cluster.name}`, + description: `AKS cluster "${cluster.name}" is not integrated with Azure AD. Use Azure AD for centralized identity management.`, + severity: 'medium', + remediation: 'Enable managed Azure AD integration on the cluster.', + }), + ); + } + + // Check 3: Network policy + if (!props.networkProfile?.networkPolicy) { + findings.push( + this.finding(cluster, { + key: 'no-network-policy', + title: `No Network Policy: ${cluster.name}`, + description: `AKS cluster "${cluster.name}" has no network policy plugin configured. All pods can communicate with each other without restriction.`, + severity: 'high', + remediation: + 'Enable Azure or Calico network policy plugin. Note: requires cluster recreation.', + }), + ); + } + + // Check 4: Private cluster / API server access + const apiAccess = props.apiServerAccessProfile; + if ( + !apiAccess?.enablePrivateCluster && + (!apiAccess?.authorizedIPRanges || + apiAccess.authorizedIPRanges.length === 0) + ) { + findings.push( + this.finding(cluster, { + key: 'api-server-public', + title: `API Server Publicly Accessible: ${cluster.name}`, + description: `AKS cluster "${cluster.name}" API server is accessible from the internet without IP restrictions.`, + severity: 'high', + remediation: + 'Enable private cluster or configure authorized IP ranges for the API server.', + }), + ); + } + + // Check 5: Azure Policy addon + if (!props.addonProfiles?.azurePolicy?.enabled) { + findings.push( + this.finding(cluster, { + key: 'no-azure-policy', + title: `Azure Policy Not Enabled: ${cluster.name}`, + description: `AKS cluster "${cluster.name}" does not have the Azure Policy addon enabled for Kubernetes governance.`, + severity: 'low', + remediation: 'Enable the Azure Policy addon on the cluster.', + }), + ); + } + + // Check 6: Auto-upgrade + const upgradeChannel = props.autoUpgradeProfile?.upgradeChannel; + if (!upgradeChannel || upgradeChannel === 'none') { + findings.push( + this.finding(cluster, { + key: 'no-auto-upgrade', + title: `Auto-Upgrade Disabled: ${cluster.name}`, + description: `AKS cluster "${cluster.name}" does not have auto-upgrade configured. Clusters may fall behind on security patches.`, + severity: 'medium', + remediation: + 'Set auto-upgrade channel to "patch" or "stable" for automatic security updates.', + }), + ); + } + + // Check 7: Monitoring + if (!props.addonProfiles?.omsagent?.enabled) { + findings.push( + this.finding(cluster, { + key: 'no-monitoring', + title: `Container Monitoring Disabled: ${cluster.name}`, + description: `AKS cluster "${cluster.name}" does not have Container Insights (OMS agent) enabled.`, + severity: 'medium', + remediation: + 'Enable the monitoring addon to collect container logs and metrics.', + }), + ); + } + } + + if (findings.length === 0) { + findings.push({ + id: `azure-aks-ok-${subscriptionId}`, + title: 'AKS Cluster Security', + description: `All ${clusters.length} AKS cluster(s) are properly configured.`, + severity: 'info', + resourceType: 'aks', + resourceId: subscriptionId, + remediation: 'No action needed.', + evidence: { + serviceId: this.serviceId, + serviceName: 'AKS', + findingKey: 'azure-aks-all-ok', + }, + createdAt: new Date().toISOString(), + passed: true, + }); + } + + return findings; + } + + private finding( + cluster: AksCluster, + opts: { + key: string; + title: string; + description: string; + severity: SecurityFinding['severity']; + remediation: string; + }, + ): SecurityFinding { + return { + id: `azure-aks-${opts.key}-${cluster.name}`, + title: opts.title, + description: opts.description, + severity: opts.severity, + resourceType: 'aks', + resourceId: cluster.id, + remediation: opts.remediation, + evidence: { + serviceId: this.serviceId, + serviceName: 'AKS', + findingKey: `azure-aks-${opts.key}`, + clusterName: cluster.name, + location: cluster.location, + kubernetesVersion: cluster.properties.kubernetesVersion, + }, + createdAt: new Date().toISOString(), + }; + } +} diff --git a/apps/api/src/cloud-security/providers/azure/app-service.adapter.ts b/apps/api/src/cloud-security/providers/azure/app-service.adapter.ts new file mode 100644 index 0000000000..43e56e6833 --- /dev/null +++ b/apps/api/src/cloud-security/providers/azure/app-service.adapter.ts @@ -0,0 +1,187 @@ +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AzureServiceAdapter } from './azure-service-adapter'; +import { fetchAllPages } from './azure-service-adapter'; + +interface WebApp { + id: string; + name: string; + location: string; + kind: string; // 'app' | 'functionapp' | 'app,linux' etc. + identity?: { + type: string; + }; + properties: { + httpsOnly?: boolean; + clientCertEnabled?: boolean; + siteConfig?: { + minTlsVersion?: string; + ftpsState?: string; + remoteDebuggingEnabled?: boolean; + http20Enabled?: boolean; + managedPipelineMode?: string; + }; + state?: string; + }; +} + +export class AppServiceAdapter implements AzureServiceAdapter { + readonly serviceId = 'app-service'; + + async scan({ + accessToken, + subscriptionId, + }: { + accessToken: string; + subscriptionId: string; + }): Promise { + const findings: SecurityFinding[] = []; + + const apps = await fetchAllPages( + accessToken, + `https://management.azure.com/subscriptions/${subscriptionId}/providers/Microsoft.Web/sites?api-version=2023-12-01`, + ); + + // Only check running apps + const activeApps = apps.filter((a) => a.properties.state === 'Running'); + if (activeApps.length === 0) return findings; + + for (const app of activeApps) { + const props = app.properties; + const config = props.siteConfig; + + // Check 1: HTTPS-only + if (props.httpsOnly !== true) { + findings.push( + this.finding(app, { + key: 'https-disabled', + title: `HTTPS Not Enforced: ${app.name}`, + description: `App Service "${app.name}" does not enforce HTTPS-only traffic. HTTP requests are not redirected.`, + severity: 'high', + remediation: 'Enable "HTTPS Only" in the app TLS/SSL settings.', + }), + ); + } + + // Check 2: TLS version + if (config?.minTlsVersion && config.minTlsVersion < '1.2') { + findings.push( + this.finding(app, { + key: 'tls-outdated', + title: `Outdated TLS Version: ${app.name}`, + description: `App Service "${app.name}" allows TLS versions below 1.2 (current: ${config.minTlsVersion}).`, + severity: 'medium', + remediation: + 'Set minimum TLS version to 1.2 in the TLS/SSL settings.', + }), + ); + } + + // Check 3: Remote debugging + if (config?.remoteDebuggingEnabled === true) { + findings.push( + this.finding(app, { + key: 'remote-debug', + title: `Remote Debugging Enabled: ${app.name}`, + description: `App Service "${app.name}" has remote debugging enabled. This opens additional ports and should only be used during development.`, + severity: 'high', + remediation: 'Disable remote debugging in the app configuration.', + }), + ); + } + + // Check 4: FTPS state + if (config?.ftpsState === 'AllAllowed') { + findings.push( + this.finding(app, { + key: 'ftp-allowed', + title: `FTP Access Allowed: ${app.name}`, + description: `App Service "${app.name}" allows unencrypted FTP. Use FTPS or disable FTP entirely.`, + severity: 'medium', + remediation: + 'Set FTPS state to "FtpsOnly" or "Disabled" in deployment settings.', + }), + ); + } + + // Check 5: Managed identity + const hasIdentity = app.identity?.type && app.identity.type !== 'None'; + if (!hasIdentity) { + findings.push( + this.finding(app, { + key: 'no-managed-identity', + title: `No Managed Identity: ${app.name}`, + description: `App Service "${app.name}" does not use a managed identity. Use managed identities for secure authentication to Azure services.`, + severity: 'low', + remediation: + 'Enable system-assigned or user-assigned managed identity.', + }), + ); + } + + // Check 6: HTTP/2 + if (config?.http20Enabled === false) { + findings.push( + this.finding(app, { + key: 'http2-disabled', + title: `HTTP/2 Disabled: ${app.name}`, + description: `App Service "${app.name}" does not have HTTP/2 enabled. HTTP/2 provides performance and security improvements.`, + severity: 'info', + remediation: + 'Enable HTTP/2 in the app configuration for improved performance.', + }), + ); + } + } + + if (findings.length === 0) { + findings.push({ + id: `azure-appservice-ok-${subscriptionId}`, + title: 'App Service Security', + description: `All ${activeApps.length} active App Service(s) are properly configured.`, + severity: 'info', + resourceType: 'app-service', + resourceId: subscriptionId, + remediation: 'No action needed.', + evidence: { + serviceId: this.serviceId, + serviceName: 'App Service', + findingKey: 'azure-app-service-all-ok', + }, + createdAt: new Date().toISOString(), + passed: true, + }); + } + + return findings; + } + + private finding( + app: WebApp, + opts: { + key: string; + title: string; + description: string; + severity: SecurityFinding['severity']; + remediation: string; + }, + ): SecurityFinding { + return { + id: `azure-app-${opts.key}-${app.name}`, + title: opts.title, + description: opts.description, + severity: opts.severity, + resourceType: 'app-service', + resourceId: app.id, + remediation: opts.remediation, + evidence: { + serviceId: this.serviceId, + serviceName: 'App Service', + findingKey: `azure-app-service-${opts.key}`, + appName: app.name, + kind: app.kind, + location: app.location, + }, + createdAt: new Date().toISOString(), + }; + } +} diff --git a/apps/api/src/cloud-security/providers/azure/azure-service-adapter.ts b/apps/api/src/cloud-security/providers/azure/azure-service-adapter.ts new file mode 100644 index 0000000000..43ae9ec496 --- /dev/null +++ b/apps/api/src/cloud-security/providers/azure/azure-service-adapter.ts @@ -0,0 +1,81 @@ +import type { SecurityFinding } from '../../cloud-security.service'; + +export interface AzureServiceAdapter { + /** Must match the manifest service ID (e.g. 'defender', 'entra-id') */ + readonly serviceId: string; + scan(params: { + accessToken: string; + subscriptionId: string; + }): Promise; +} + +/** Shared pagination helper for Azure ARM list APIs. */ +export async function fetchAllPages( + accessToken: string, + initialUrl: string, +): Promise { + const results: T[] = []; + let url: string | undefined = initialUrl; + + while (url) { + const response = await fetch(url, { + headers: { + Authorization: `Bearer ${accessToken}`, + 'Content-Type': 'application/json', + }, + }); + + if (!response.ok) { + const error = await response.text(); + throw new Error(`Azure API error (${response.status}): ${error}`); + } + + const data = (await response.json()) as { value: T[]; nextLink?: string }; + results.push(...data.value); + url = data.nextLink; + } + + return results; +} + +/** Map Defender assessment categories → our service IDs. */ +export const AZURE_CATEGORY_TO_SERVICE: Record = { + // Identity & Access + 'Identity and Access': 'entra-id', + IdentityAndAccess: 'entra-id', + // Network + Networking: 'network-watcher', + Network: 'network-watcher', + // Data + Data: 'key-vault', + 'Data Protection': 'key-vault', + Encryption: 'key-vault', + // Compute + Compute: 'defender', + Container: 'defender', + AppServices: 'defender', + // Governance + 'Regulatory Compliance': 'policy', + Governance: 'policy', + // Monitoring + 'Logging and Threat Detection': 'monitor', + IoT: 'defender', + API: 'defender', +}; + +/** Human-readable service names for UI grouping. */ +export const AZURE_SERVICE_NAMES: Record = { + defender: 'Microsoft Defender', + 'entra-id': 'Entra ID', + policy: 'Azure Policy', + 'key-vault': 'Key Vault', + monitor: 'Azure Monitor', + 'network-watcher': 'Network Watcher', + 'storage-account': 'Storage Accounts', + 'sql-database': 'SQL Database', + 'virtual-machine': 'Virtual Machines', + 'app-service': 'App Service', + aks: 'AKS', + 'container-registry': 'Container Registry', + 'cosmos-db': 'Cosmos DB', +}; diff --git a/apps/api/src/cloud-security/providers/azure/container-registry.adapter.ts b/apps/api/src/cloud-security/providers/azure/container-registry.adapter.ts new file mode 100644 index 0000000000..450e5cbbf8 --- /dev/null +++ b/apps/api/src/cloud-security/providers/azure/container-registry.adapter.ts @@ -0,0 +1,176 @@ +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AzureServiceAdapter } from './azure-service-adapter'; +import { fetchAllPages } from './azure-service-adapter'; + +interface ContainerRegistry { + id: string; + name: string; + location: string; + sku: { name: string; tier: string }; + properties: { + adminUserEnabled?: boolean; + publicNetworkAccess?: string; + networkRuleSet?: { + defaultAction: string; + }; + encryption?: { + status: string; + keyVaultProperties?: unknown; + }; + policies?: { + trustPolicy?: { status: string; type?: string }; + retentionPolicy?: { status: string; days?: number }; + quarantinePolicy?: { status: string }; + }; + anonymousPullEnabled?: boolean; + }; +} + +export class ContainerRegistryAdapter implements AzureServiceAdapter { + readonly serviceId = 'container-registry'; + + async scan({ + accessToken, + subscriptionId, + }: { + accessToken: string; + subscriptionId: string; + }): Promise { + const findings: SecurityFinding[] = []; + + const registries = await fetchAllPages( + accessToken, + `https://management.azure.com/subscriptions/${subscriptionId}/providers/Microsoft.ContainerRegistry/registries?api-version=2023-11-01-preview`, + ); + + if (registries.length === 0) return findings; + + for (const reg of registries) { + const props = reg.properties; + + // Check 1: Admin user + if (props.adminUserEnabled === true) { + findings.push( + this.finding(reg, { + key: 'admin-enabled', + title: `Admin User Enabled: ${reg.name}`, + description: `Container Registry "${reg.name}" has the admin user enabled. Use service principals or managed identities instead.`, + severity: 'high', + remediation: + 'Disable the admin user and use Azure AD service principals or managed identities for authentication.', + }), + ); + } + + // Check 2: Public network access + const isPublic = + props.publicNetworkAccess !== 'Disabled' && + props.networkRuleSet?.defaultAction !== 'Deny'; + if (isPublic) { + findings.push( + this.finding(reg, { + key: 'public-access', + title: `Public Network Access: ${reg.name}`, + description: `Container Registry "${reg.name}" is publicly accessible. Restrict to private endpoints or specific networks.`, + severity: 'medium', + remediation: + 'Disable public network access and use private endpoints. Requires Premium SKU.', + }), + ); + } + + // Check 3: Content trust (image signing) + if (props.policies?.trustPolicy?.status !== 'enabled') { + findings.push( + this.finding(reg, { + key: 'no-content-trust', + title: `Content Trust Disabled: ${reg.name}`, + description: `Container Registry "${reg.name}" does not have content trust enabled. Images are not verified for integrity.`, + severity: 'medium', + remediation: + 'Enable content trust policy to require signed images. Requires Premium SKU.', + }), + ); + } + + // Check 4: Anonymous pull + if (props.anonymousPullEnabled === true) { + findings.push( + this.finding(reg, { + key: 'anonymous-pull', + title: `Anonymous Pull Enabled: ${reg.name}`, + description: `Container Registry "${reg.name}" allows anonymous (unauthenticated) image pulls.`, + severity: 'medium', + remediation: + 'Disable anonymous pull unless the registry is intentionally public.', + }), + ); + } + + // Check 5: Retention policy + if (props.policies?.retentionPolicy?.status !== 'enabled') { + findings.push( + this.finding(reg, { + key: 'no-retention', + title: `No Retention Policy: ${reg.name}`, + description: `Container Registry "${reg.name}" has no retention policy for untagged manifests. Old images accumulate without cleanup.`, + severity: 'low', + remediation: + 'Enable a retention policy to automatically purge untagged manifests. Requires Premium SKU.', + }), + ); + } + } + + if (findings.length === 0) { + findings.push({ + id: `azure-acr-ok-${subscriptionId}`, + title: 'Container Registry Security', + description: `All ${registries.length} container registr(ies) are properly configured.`, + severity: 'info', + resourceType: 'container-registry', + resourceId: subscriptionId, + remediation: 'No action needed.', + evidence: { + serviceId: this.serviceId, + serviceName: 'Container Registry', + findingKey: 'azure-container-registry-all-ok', + }, + createdAt: new Date().toISOString(), + passed: true, + }); + } + + return findings; + } + + private finding( + reg: ContainerRegistry, + opts: { + key: string; + title: string; + description: string; + severity: SecurityFinding['severity']; + remediation: string; + }, + ): SecurityFinding { + return { + id: `azure-acr-${opts.key}-${reg.name}`, + title: opts.title, + description: opts.description, + severity: opts.severity, + resourceType: 'container-registry', + resourceId: reg.id, + remediation: opts.remediation, + evidence: { + serviceId: this.serviceId, + serviceName: 'Container Registry', + findingKey: `azure-container-registry-${opts.key}`, + registryName: reg.name, + sku: reg.sku.name, + location: reg.location, + }, + createdAt: new Date().toISOString(), + }; + } +} diff --git a/apps/api/src/cloud-security/providers/azure/cosmos-db.adapter.ts b/apps/api/src/cloud-security/providers/azure/cosmos-db.adapter.ts new file mode 100644 index 0000000000..75ca0342d1 --- /dev/null +++ b/apps/api/src/cloud-security/providers/azure/cosmos-db.adapter.ts @@ -0,0 +1,182 @@ +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AzureServiceAdapter } from './azure-service-adapter'; +import { fetchAllPages } from './azure-service-adapter'; + +interface CosmosDbAccount { + id: string; + name: string; + location: string; + properties: { + publicNetworkAccess?: string; + isVirtualNetworkFilterEnabled?: boolean; + ipRules?: Array<{ ipAddressOrRange: string }>; + disableKeyBasedMetadataWriteAccess?: boolean; + enableAutomaticFailover?: boolean; + enableMultipleWriteLocations?: boolean; + disableLocalAuth?: boolean; + networkAclBypass?: string; + minimalTlsVersion?: string; + backupPolicy?: { + type: string; // 'Periodic' | 'Continuous' + periodicModeProperties?: { + backupIntervalInMinutes?: number; + backupRetentionIntervalInHours?: number; + }; + }; + }; +} + +export class CosmosDbAdapter implements AzureServiceAdapter { + readonly serviceId = 'cosmos-db'; + + async scan({ + accessToken, + subscriptionId, + }: { + accessToken: string; + subscriptionId: string; + }): Promise { + const findings: SecurityFinding[] = []; + + const accounts = await fetchAllPages( + accessToken, + `https://management.azure.com/subscriptions/${subscriptionId}/providers/Microsoft.DocumentDB/databaseAccounts?api-version=2024-02-15-preview`, + ); + + if (accounts.length === 0) return findings; + + for (const acct of accounts) { + const props = acct.properties; + + // Check 1: Public network access + const hasIpRules = (props.ipRules?.length ?? 0) > 0; + const hasVnetFilter = props.isVirtualNetworkFilterEnabled === true; + if ( + props.publicNetworkAccess !== 'Disabled' && + !hasIpRules && + !hasVnetFilter + ) { + findings.push( + this.finding(acct, { + key: 'public-unrestricted', + title: `Public Access Unrestricted: ${acct.name}`, + description: `Cosmos DB account "${acct.name}" is publicly accessible without IP or VNet restrictions.`, + severity: 'high', + remediation: + 'Disable public network access or add IP rules / VNet service endpoints.', + }), + ); + } + + // Check 2: Local auth (key-based) + if (props.disableLocalAuth !== true) { + findings.push( + this.finding(acct, { + key: 'local-auth-enabled', + title: `Key-Based Auth Enabled: ${acct.name}`, + description: `Cosmos DB account "${acct.name}" allows key-based authentication. Use Azure AD authentication for better security and auditing.`, + severity: 'medium', + remediation: + 'Disable local authentication and use Azure AD RBAC for data plane access.', + }), + ); + } + + // Check 3: Automatic failover + if (props.enableAutomaticFailover !== true) { + findings.push( + this.finding(acct, { + key: 'no-auto-failover', + title: `Automatic Failover Disabled: ${acct.name}`, + description: `Cosmos DB account "${acct.name}" does not have automatic failover enabled. Manual intervention required during regional outages.`, + severity: 'low', + remediation: 'Enable automatic failover for high availability.', + }), + ); + } + + // Check 4: Backup policy + const backup = props.backupPolicy; + if (backup?.type === 'Periodic') { + const retention = + backup.periodicModeProperties?.backupRetentionIntervalInHours ?? 0; + if (retention < 24) { + findings.push( + this.finding(acct, { + key: 'low-backup-retention', + title: `Low Backup Retention: ${acct.name}`, + description: `Cosmos DB account "${acct.name}" has backup retention of only ${retention} hours. Consider increasing for disaster recovery.`, + severity: 'medium', + remediation: + 'Increase backup retention or switch to continuous backup mode.', + }), + ); + } + } + + // Check 5: Metadata write access + if (props.disableKeyBasedMetadataWriteAccess !== true) { + findings.push( + this.finding(acct, { + key: 'metadata-write-enabled', + title: `Key-Based Metadata Write Enabled: ${acct.name}`, + description: `Cosmos DB account "${acct.name}" allows key-based metadata write access. This means account keys can modify database resources (create/delete databases, containers).`, + severity: 'low', + remediation: + 'Disable key-based metadata write access to prevent accidental resource modification via account keys.', + }), + ); + } + } + + if (findings.length === 0) { + findings.push({ + id: `azure-cosmos-ok-${subscriptionId}`, + title: 'Cosmos DB Security', + description: `All ${accounts.length} Cosmos DB account(s) are properly configured.`, + severity: 'info', + resourceType: 'cosmos-db', + resourceId: subscriptionId, + remediation: 'No action needed.', + evidence: { + serviceId: this.serviceId, + serviceName: 'Cosmos DB', + findingKey: 'azure-cosmos-db-all-ok', + }, + createdAt: new Date().toISOString(), + passed: true, + }); + } + + return findings; + } + + private finding( + acct: CosmosDbAccount, + opts: { + key: string; + title: string; + description: string; + severity: SecurityFinding['severity']; + remediation: string; + }, + ): SecurityFinding { + return { + id: `azure-cosmos-${opts.key}-${acct.name}`, + title: opts.title, + description: opts.description, + severity: opts.severity, + resourceType: 'cosmos-db', + resourceId: acct.id, + remediation: opts.remediation, + evidence: { + serviceId: this.serviceId, + serviceName: 'Cosmos DB', + findingKey: `azure-cosmos-db-${opts.key}`, + accountName: acct.name, + location: acct.location, + }, + createdAt: new Date().toISOString(), + }; + } +} diff --git a/apps/api/src/cloud-security/providers/azure/entra-id.adapter.ts b/apps/api/src/cloud-security/providers/azure/entra-id.adapter.ts new file mode 100644 index 0000000000..c069aa4b3d --- /dev/null +++ b/apps/api/src/cloud-security/providers/azure/entra-id.adapter.ts @@ -0,0 +1,184 @@ +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AzureServiceAdapter } from './azure-service-adapter'; +import { fetchAllPages } from './azure-service-adapter'; + +interface RoleAssignment { + id: string; + properties: { + roleDefinitionId: string; + principalId: string; + principalType: string; + scope: string; + createdOn?: string; + }; +} + +interface RoleDefinition { + id: string; + properties: { + roleName: string; + type: string; // 'BuiltInRole' | 'CustomRole' + permissions: Array<{ actions: string[]; notActions: string[] }>; + }; +} + +const PRIVILEGED_ROLES = new Set([ + 'Owner', + 'Contributor', + 'User Access Administrator', + 'Global Administrator', + 'Privileged Role Administrator', +]); + +export class EntraIdAdapter implements AzureServiceAdapter { + readonly serviceId = 'entra-id'; + + async scan({ + accessToken, + subscriptionId, + }: { + accessToken: string; + subscriptionId: string; + }): Promise { + const findings: SecurityFinding[] = []; + const baseUrl = 'https://management.azure.com'; + + // Fetch role assignments at subscription scope + const assignments = await fetchAllPages( + accessToken, + `${baseUrl}/subscriptions/${subscriptionId}/providers/Microsoft.Authorization/roleAssignments?api-version=2022-04-01`, + ); + + // Fetch role definitions to resolve names + const definitions = await fetchAllPages( + accessToken, + `${baseUrl}/subscriptions/${subscriptionId}/providers/Microsoft.Authorization/roleDefinitions?api-version=2022-04-01`, + ); + + const defMap = new Map(definitions.map((d) => [d.id, d])); + + // Check 1: Count privileged role assignments + const privilegedAssignments = assignments.filter((a) => { + const def = defMap.get(a.properties.roleDefinitionId); + return def && PRIVILEGED_ROLES.has(def.properties.roleName); + }); + + if (privilegedAssignments.length > 5) { + findings.push({ + id: `azure-entra-excessive-privileged-${subscriptionId}`, + title: 'Excessive Privileged Role Assignments', + description: `${privilegedAssignments.length} principals have privileged roles (Owner, Contributor, User Access Administrator). Limit to essential accounts only.`, + severity: 'high', + resourceType: 'subscription', + resourceId: subscriptionId, + remediation: + 'Review privileged role assignments and remove unnecessary ones. Use just-in-time access via Azure PIM.', + evidence: { + serviceId: this.serviceId, + serviceName: 'Entra ID', + findingKey: 'azure-entra-id-excessive-privileged-roles', + count: privilegedAssignments.length, + principals: privilegedAssignments.slice(0, 10).map((a) => ({ + principalId: a.properties.principalId, + principalType: a.properties.principalType, + role: defMap.get(a.properties.roleDefinitionId)?.properties + .roleName, + })), + }, + createdAt: new Date().toISOString(), + }); + } else { + findings.push({ + id: `azure-entra-privileged-ok-${subscriptionId}`, + title: 'Privileged Role Assignments', + description: `${privilegedAssignments.length} privileged role assignments found — within acceptable range.`, + severity: 'info', + resourceType: 'subscription', + resourceId: subscriptionId, + remediation: 'No action needed.', + evidence: { + serviceId: this.serviceId, + serviceName: 'Entra ID', + findingKey: 'azure-entra-id-excessive-privileged-roles', + }, + createdAt: new Date().toISOString(), + passed: true, + }); + } + + // Check 2: Custom roles with wildcard actions + const dangerousCustomRoles = definitions.filter((d) => { + if (d.properties.type !== 'CustomRole') return false; + return d.properties.permissions.some((p) => + p.actions.some((a) => a === '*' || a.endsWith('/*')), + ); + }); + + for (const role of dangerousCustomRoles) { + findings.push({ + id: `azure-entra-wildcard-role-${role.id}`, + title: `Custom Role with Wildcard Permissions: ${role.properties.roleName}`, + description: `Custom role "${role.properties.roleName}" grants wildcard (*) permissions. This is overly permissive.`, + severity: 'high', + resourceType: 'role-definition', + resourceId: role.id, + remediation: + 'Restrict custom role permissions to only the specific actions required.', + evidence: { + serviceId: this.serviceId, + serviceName: 'Entra ID', + findingKey: 'azure-entra-id-wildcard-custom-role', + roleName: role.properties.roleName, + permissions: role.properties.permissions, + }, + createdAt: new Date().toISOString(), + }); + } + + if (dangerousCustomRoles.length === 0) { + findings.push({ + id: `azure-entra-custom-roles-ok-${subscriptionId}`, + title: 'Custom Role Permissions', + description: 'No custom roles with wildcard (*) permissions found.', + severity: 'info', + resourceType: 'subscription', + resourceId: subscriptionId, + remediation: 'No action needed.', + evidence: { + serviceId: this.serviceId, + serviceName: 'Entra ID', + findingKey: 'azure-entra-id-wildcard-custom-role', + }, + createdAt: new Date().toISOString(), + passed: true, + }); + } + + // Check 3: Service principals with Owner/Contributor + const spWithPrivileged = privilegedAssignments.filter( + (a) => a.properties.principalType === 'ServicePrincipal', + ); + + if (spWithPrivileged.length > 0) { + findings.push({ + id: `azure-entra-sp-privileged-${subscriptionId}`, + title: 'Service Principals with Privileged Roles', + description: `${spWithPrivileged.length} service principal(s) have privileged roles. Service principals should use least-privilege access.`, + severity: 'medium', + resourceType: 'subscription', + resourceId: subscriptionId, + remediation: + 'Replace broad roles with scoped custom roles for service principals.', + evidence: { + serviceId: this.serviceId, + serviceName: 'Entra ID', + findingKey: 'azure-entra-id-sp-privileged', + count: spWithPrivileged.length, + }, + createdAt: new Date().toISOString(), + }); + } + + return findings; + } +} diff --git a/apps/api/src/cloud-security/providers/azure/index.ts b/apps/api/src/cloud-security/providers/azure/index.ts new file mode 100644 index 0000000000..53534008fe --- /dev/null +++ b/apps/api/src/cloud-security/providers/azure/index.ts @@ -0,0 +1,19 @@ +export type { AzureServiceAdapter } from './azure-service-adapter'; +export { + fetchAllPages, + AZURE_CATEGORY_TO_SERVICE, + AZURE_SERVICE_NAMES, +} from './azure-service-adapter'; + +export { AksAdapter } from './aks.adapter'; +export { AppServiceAdapter } from './app-service.adapter'; +export { ContainerRegistryAdapter } from './container-registry.adapter'; +export { CosmosDbAdapter } from './cosmos-db.adapter'; +export { EntraIdAdapter } from './entra-id.adapter'; +export { KeyVaultAdapter } from './key-vault.adapter'; +export { MonitorAdapter } from './monitor.adapter'; +export { NetworkWatcherAdapter } from './network-watcher.adapter'; +export { PolicyAdapter } from './policy.adapter'; +export { SqlDatabaseAdapter } from './sql-database.adapter'; +export { StorageAccountAdapter } from './storage-account.adapter'; +export { VirtualMachineAdapter } from './virtual-machine.adapter'; diff --git a/apps/api/src/cloud-security/providers/azure/key-vault.adapter.ts b/apps/api/src/cloud-security/providers/azure/key-vault.adapter.ts new file mode 100644 index 0000000000..5d97c3a089 --- /dev/null +++ b/apps/api/src/cloud-security/providers/azure/key-vault.adapter.ts @@ -0,0 +1,157 @@ +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AzureServiceAdapter } from './azure-service-adapter'; +import { fetchAllPages } from './azure-service-adapter'; + +interface KeyVault { + id: string; + name: string; + location: string; + properties: { + enableSoftDelete?: boolean; + enablePurgeProtection?: boolean; + enableRbacAuthorization?: boolean; + publicNetworkAccess?: string; + networkAcls?: { + defaultAction: string; + bypass: string; + }; + vaultUri: string; + }; +} + +export class KeyVaultAdapter implements AzureServiceAdapter { + readonly serviceId = 'key-vault'; + + async scan({ + accessToken, + subscriptionId, + }: { + accessToken: string; + subscriptionId: string; + }): Promise { + const findings: SecurityFinding[] = []; + const baseUrl = 'https://management.azure.com'; + + const vaults = await fetchAllPages( + accessToken, + `${baseUrl}/subscriptions/${subscriptionId}/providers/Microsoft.KeyVault/vaults?api-version=2023-07-01`, + ); + + if (vaults.length === 0) return findings; + + for (const vault of vaults) { + const props = vault.properties; + + // Check 1: Soft delete + if (!props.enableSoftDelete) { + findings.push( + this.finding(vault, { + key: 'soft-delete-disabled', + title: `Key Vault Soft Delete Disabled: ${vault.name}`, + description: `Key Vault "${vault.name}" does not have soft delete enabled. Deleted keys/secrets cannot be recovered.`, + severity: 'high', + remediation: + 'Enable soft delete on the Key Vault to allow recovery of deleted items.', + }), + ); + } + + // Check 2: Purge protection + if (!props.enablePurgeProtection) { + findings.push( + this.finding(vault, { + key: 'purge-protection-disabled', + title: `Key Vault Purge Protection Disabled: ${vault.name}`, + description: `Key Vault "${vault.name}" does not have purge protection. Deleted items can be permanently removed before the retention period.`, + severity: 'medium', + remediation: + 'Enable purge protection to prevent permanent deletion during the retention period.', + }), + ); + } + + // Check 3: Public network access + const isPublic = + props.publicNetworkAccess === 'Enabled' || + props.networkAcls?.defaultAction === 'Allow'; + if (isPublic) { + findings.push( + this.finding(vault, { + key: 'public-access', + title: `Key Vault Publicly Accessible: ${vault.name}`, + description: `Key Vault "${vault.name}" allows public network access. Restrict to private endpoints or specific networks.`, + severity: 'high', + remediation: + 'Configure network ACLs to deny public access and use private endpoints.', + }), + ); + } + + // Check 4: RBAC vs access policies + if (!props.enableRbacAuthorization) { + findings.push( + this.finding(vault, { + key: 'no-rbac', + title: `Key Vault Using Legacy Access Policies: ${vault.name}`, + description: `Key Vault "${vault.name}" uses vault access policies instead of Azure RBAC. RBAC provides finer-grained, auditable access control.`, + severity: 'low', + remediation: + 'Migrate to Azure RBAC permission model for better access control.', + }), + ); + } + } + + // Passing check if all vaults are well-configured + const failCount = findings.length; + if (failCount === 0) { + findings.push({ + id: `azure-key-vault-ok-${subscriptionId}`, + title: 'Key Vault Configuration', + description: `All ${vaults.length} Key Vault(s) are properly configured.`, + severity: 'info', + resourceType: 'key-vault', + resourceId: subscriptionId, + remediation: 'No action needed.', + evidence: { + serviceId: this.serviceId, + serviceName: 'Key Vault', + findingKey: 'azure-key-vault-all-ok', + }, + createdAt: new Date().toISOString(), + passed: true, + }); + } + + return findings; + } + + private finding( + vault: KeyVault, + opts: { + key: string; + title: string; + description: string; + severity: SecurityFinding['severity']; + remediation: string; + }, + ): SecurityFinding { + return { + id: `azure-kv-${opts.key}-${vault.name}`, + title: opts.title, + description: opts.description, + severity: opts.severity, + resourceType: 'key-vault', + resourceId: vault.id, + remediation: opts.remediation, + evidence: { + serviceId: this.serviceId, + serviceName: 'Key Vault', + findingKey: `azure-key-vault-${opts.key}`, + vaultName: vault.name, + location: vault.location, + }, + createdAt: new Date().toISOString(), + }; + } +} diff --git a/apps/api/src/cloud-security/providers/azure/monitor.adapter.ts b/apps/api/src/cloud-security/providers/azure/monitor.adapter.ts new file mode 100644 index 0000000000..4001aeee07 --- /dev/null +++ b/apps/api/src/cloud-security/providers/azure/monitor.adapter.ts @@ -0,0 +1,223 @@ +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AzureServiceAdapter } from './azure-service-adapter'; +import { fetchAllPages } from './azure-service-adapter'; + +interface ActivityLogAlert { + id: string; + name: string; + properties: { + enabled: boolean; + description?: string; + condition?: { + allOf?: Array<{ field: string; equals: string }>; + }; + }; +} + +interface DiagnosticSetting { + id: string; + name: string; + properties: { + logs: Array<{ enabled: boolean; category?: string }>; + workspaceId?: string; + storageAccountId?: string; + eventHubAuthorizationRuleId?: string; + }; +} + +/** Critical operations that should have activity log alerts. */ +const RECOMMENDED_ALERTS = [ + { + operation: 'Microsoft.Authorization/policyAssignments/write', + name: 'Policy assignment changes', + }, + { + operation: 'Microsoft.Security/securitySolutions/write', + name: 'Security solution changes', + }, + { + operation: 'Microsoft.Network/networkSecurityGroups/write', + name: 'NSG changes', + }, + { + operation: 'Microsoft.Sql/servers/firewallRules/write', + name: 'SQL firewall rule changes', + }, +]; + +export class MonitorAdapter implements AzureServiceAdapter { + readonly serviceId = 'monitor'; + + async scan({ + accessToken, + subscriptionId, + }: { + accessToken: string; + subscriptionId: string; + }): Promise { + const findings: SecurityFinding[] = []; + const baseUrl = 'https://management.azure.com'; + + // Check 1: Activity log alerts for critical operations + try { + const alerts = await fetchAllPages( + accessToken, + `${baseUrl}/subscriptions/${subscriptionId}/providers/Microsoft.Insights/activityLogAlerts?api-version=2020-10-01`, + ); + + const enabledAlerts = alerts.filter((a) => a.properties.enabled); + const alertOperations = new Set(); + + for (const alert of enabledAlerts) { + const conditions = alert.properties.condition?.allOf ?? []; + for (const c of conditions) { + if (c.field === 'operationName') { + alertOperations.add(c.equals); + } + } + } + + for (const rec of RECOMMENDED_ALERTS) { + const hasAlert = alertOperations.has(rec.operation); + if (!hasAlert) { + findings.push({ + id: `azure-monitor-missing-alert-${rec.operation}`, + title: `Missing Activity Log Alert: ${rec.name}`, + description: `No activity log alert is configured for "${rec.operation}". Critical operations should trigger alerts.`, + severity: 'medium', + resourceType: 'activity-log-alert', + resourceId: subscriptionId, + remediation: `Create an activity log alert for operation "${rec.operation}" in Azure Monitor.`, + evidence: { + serviceId: this.serviceId, + serviceName: 'Azure Monitor', + findingKey: `azure-monitor-missing-alert-${rec.operation.split('/').pop()}`, + operation: rec.operation, + }, + createdAt: new Date().toISOString(), + }); + } + } + + if (findings.length === 0) { + findings.push({ + id: `azure-monitor-alerts-ok-${subscriptionId}`, + title: 'Activity Log Alerts', + description: 'All recommended activity log alerts are configured.', + severity: 'info', + resourceType: 'activity-log-alert', + resourceId: subscriptionId, + remediation: 'No action needed.', + evidence: { + serviceId: this.serviceId, + serviceName: 'Azure Monitor', + findingKey: 'azure-monitor-alerts-ok', + }, + createdAt: new Date().toISOString(), + passed: true, + }); + } + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (msg.includes('403') || msg.includes('AuthorizationFailed')) { + findings.push({ + id: `azure-monitor-permission-${subscriptionId}`, + title: 'Unable to Access Activity Log Alerts', + description: + 'The service principal does not have permission to read activity log alerts.', + severity: 'medium', + resourceType: 'activity-log-alert', + resourceId: subscriptionId, + remediation: + 'Assign the "Monitoring Reader" role to your App Registration.', + evidence: { + serviceId: this.serviceId, + serviceName: 'Azure Monitor', + findingKey: 'azure-monitor-permission', + error: msg, + }, + createdAt: new Date().toISOString(), + }); + } + } + + // Check 2: Subscription-level diagnostic settings + try { + const response = await fetch( + `${baseUrl}/subscriptions/${subscriptionId}/providers/Microsoft.Insights/diagnosticSettings?api-version=2021-05-01-preview`, + { + headers: { + Authorization: `Bearer ${accessToken}`, + 'Content-Type': 'application/json', + }, + }, + ); + + if (response.ok) { + const data = (await response.json()) as { value: DiagnosticSetting[] }; + const settings = data.value ?? []; + + // Log what Azure returns so we can debug scan vs fix mismatches + if (settings.length > 0) { + for (const s of settings) { + console.log( + `[AzureMonitor] Diagnostic setting "${s.name}": workspaceId=${s.properties.workspaceId ?? 'none'}, storageAccountId=${s.properties.storageAccountId ?? 'none'}, eventHub=${s.properties.eventHubAuthorizationRuleId ?? 'none'}, logs=${JSON.stringify(s.properties.logs?.filter((l) => l.enabled).map((l) => l.category))}`, + ); + } + } else { + console.log( + '[AzureMonitor] No diagnostic settings found on subscription', + ); + } + + const hasLogExport = settings.some( + (s) => + s.properties.workspaceId || + s.properties.storageAccountId || + s.properties.eventHubAuthorizationRuleId, + ); + + if (!hasLogExport) { + findings.push({ + id: `azure-monitor-no-diag-${subscriptionId}`, + title: 'No Diagnostic Log Export Configured', + description: + 'Subscription activity logs are not exported to a Log Analytics workspace, storage account, or event hub.', + severity: 'medium', + resourceType: 'diagnostic-settings', + resourceId: subscriptionId, + remediation: + 'Configure a diagnostic setting to export activity logs to Log Analytics or a storage account.', + evidence: { + serviceId: this.serviceId, + serviceName: 'Azure Monitor', + findingKey: 'azure-monitor-no-diagnostic-export', + }, + createdAt: new Date().toISOString(), + }); + } else { + findings.push({ + id: `azure-monitor-diag-ok-${subscriptionId}`, + title: 'Diagnostic Log Export', + description: 'Subscription activity logs are being exported.', + severity: 'info', + resourceType: 'diagnostic-settings', + resourceId: subscriptionId, + remediation: 'No action needed.', + evidence: { + serviceId: this.serviceId, + serviceName: 'Azure Monitor', + findingKey: 'azure-monitor-diagnostic-export-ok', + }, + createdAt: new Date().toISOString(), + passed: true, + }); + } + } + } catch { + // Non-critical — skip diagnostic settings check + } + + return findings; + } +} diff --git a/apps/api/src/cloud-security/providers/azure/network-watcher.adapter.ts b/apps/api/src/cloud-security/providers/azure/network-watcher.adapter.ts new file mode 100644 index 0000000000..ded54f1c9b --- /dev/null +++ b/apps/api/src/cloud-security/providers/azure/network-watcher.adapter.ts @@ -0,0 +1,206 @@ +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AzureServiceAdapter } from './azure-service-adapter'; +import { fetchAllPages } from './azure-service-adapter'; + +interface NetworkSecurityGroup { + id: string; + name: string; + location: string; + properties: { + securityRules: SecurityRule[]; + }; +} + +interface SecurityRule { + name: string; + properties: { + direction: 'Inbound' | 'Outbound'; + access: 'Allow' | 'Deny'; + protocol: string; + sourceAddressPrefix?: string; + sourceAddressPrefixes?: string[]; + destinationPortRange?: string; + destinationPortRanges?: string[]; + priority: number; + }; +} + +const DANGEROUS_PORTS = new Set([ + '22', + '3389', + '3306', + '5432', + '1433', + '27017', +]); +const WILDCARD_SOURCES = new Set(['*', '0.0.0.0/0', 'Internet', 'Any']); + +export class NetworkWatcherAdapter implements AzureServiceAdapter { + readonly serviceId = 'network-watcher'; + + async scan({ + accessToken, + subscriptionId, + }: { + accessToken: string; + subscriptionId: string; + }): Promise { + const findings: SecurityFinding[] = []; + const baseUrl = 'https://management.azure.com'; + + const nsgs = await fetchAllPages( + accessToken, + `${baseUrl}/subscriptions/${subscriptionId}/providers/Microsoft.Network/networkSecurityGroups?api-version=2023-11-01`, + ); + + if (nsgs.length === 0) return findings; + + for (const nsg of nsgs) { + const inboundAllows = nsg.properties.securityRules.filter( + (r) => + r.properties.direction === 'Inbound' && + r.properties.access === 'Allow', + ); + + for (const rule of inboundAllows) { + const sources = this.getSources(rule); + const ports = this.getPorts(rule); + const isWildcard = sources.some((s) => WILDCARD_SOURCES.has(s)); + + if (!isWildcard) continue; + + // Check 1: SSH open to internet + if (ports.includes('22')) { + findings.push( + this.finding(nsg, rule, { + key: 'ssh-open', + title: `SSH Open to Internet: ${nsg.name}/${rule.name}`, + description: `NSG "${nsg.name}" allows SSH (port 22) from the internet. Restrict to specific IP ranges or use a bastion host.`, + severity: 'high', + remediation: + 'Restrict source address to specific IPs or use Azure Bastion for SSH access.', + }), + ); + } + + // Check 2: RDP open to internet + if (ports.includes('3389')) { + findings.push( + this.finding(nsg, rule, { + key: 'rdp-open', + title: `RDP Open to Internet: ${nsg.name}/${rule.name}`, + description: `NSG "${nsg.name}" allows RDP (port 3389) from the internet. This is a common attack vector.`, + severity: 'critical', + remediation: + 'Restrict source address to specific IPs or use Azure Bastion for RDP access.', + }), + ); + } + + // Check 3: Database ports open to internet + const openDbPorts = ports.filter( + (p) => DANGEROUS_PORTS.has(p) && p !== '22' && p !== '3389', + ); + if (openDbPorts.length > 0) { + findings.push( + this.finding(nsg, rule, { + key: 'db-ports-open', + title: `Database Ports Open to Internet: ${nsg.name}/${rule.name}`, + description: `NSG "${nsg.name}" exposes database ports (${openDbPorts.join(', ')}) to the internet.`, + severity: 'critical', + remediation: + 'Restrict database access to private networks only. Use Private Link for database connections.', + }), + ); + } + + // Check 4: All ports open to internet + if ( + ports.includes('*') || + rule.properties.destinationPortRange === '*' + ) { + findings.push( + this.finding(nsg, rule, { + key: 'all-ports-open', + title: `All Ports Open to Internet: ${nsg.name}/${rule.name}`, + description: `NSG "${nsg.name}" allows all ports from the internet. This effectively bypasses network security.`, + severity: 'critical', + remediation: + 'Replace with specific port rules following least-privilege principle.', + }), + ); + } + } + } + + if (findings.length === 0) { + findings.push({ + id: `azure-nsg-ok-${subscriptionId}`, + title: 'Network Security Groups', + description: `All ${nsgs.length} NSG(s) have no overly permissive inbound rules.`, + severity: 'info', + resourceType: 'nsg', + resourceId: subscriptionId, + remediation: 'No action needed.', + evidence: { + serviceId: this.serviceId, + serviceName: 'Network Watcher', + findingKey: 'azure-network-watcher-all-ok', + }, + createdAt: new Date().toISOString(), + passed: true, + }); + } + + return findings; + } + + private getSources(rule: SecurityRule): string[] { + if (rule.properties.sourceAddressPrefixes?.length) { + return rule.properties.sourceAddressPrefixes; + } + return rule.properties.sourceAddressPrefix + ? [rule.properties.sourceAddressPrefix] + : []; + } + + private getPorts(rule: SecurityRule): string[] { + if (rule.properties.destinationPortRanges?.length) { + return rule.properties.destinationPortRanges; + } + return rule.properties.destinationPortRange + ? [rule.properties.destinationPortRange] + : []; + } + + private finding( + nsg: NetworkSecurityGroup, + rule: SecurityRule, + opts: { + key: string; + title: string; + description: string; + severity: SecurityFinding['severity']; + remediation: string; + }, + ): SecurityFinding { + return { + id: `azure-nw-${opts.key}-${nsg.name}-${rule.name}`, + title: opts.title, + description: opts.description, + severity: opts.severity, + resourceType: 'nsg', + resourceId: nsg.id, + remediation: opts.remediation, + evidence: { + serviceId: this.serviceId, + serviceName: 'Network Watcher', + findingKey: `azure-network-watcher-${opts.key}`, + nsgName: nsg.name, + ruleName: rule.name, + priority: rule.properties.priority, + }, + createdAt: new Date().toISOString(), + }; + } +} diff --git a/apps/api/src/cloud-security/providers/azure/policy.adapter.ts b/apps/api/src/cloud-security/providers/azure/policy.adapter.ts new file mode 100644 index 0000000000..f32c2ccf19 --- /dev/null +++ b/apps/api/src/cloud-security/providers/azure/policy.adapter.ts @@ -0,0 +1,172 @@ +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AzureServiceAdapter } from './azure-service-adapter'; + +interface PolicyStateSummary { + results: { + queryResultsTable: { + rows: unknown[][]; + columns: Array<{ name: string; type: string }>; + }; + }; + 'policyAssignments@odata.count'?: number; + policyAssignments?: Array<{ + policyAssignmentId: string; + results: { + nonCompliantResources: number; + nonCompliantPolicies: number; + }; + }>; +} + +interface PolicySummaryResponse { + value: Array<{ + policyAssignmentId: string; + policyDefinitionId: string; + results: { + nonCompliantResources: number; + nonCompliantPolicies: number; + resourceDetails?: Array<{ + complianceState: string; + count: number; + }>; + }; + }>; +} + +export class PolicyAdapter implements AzureServiceAdapter { + readonly serviceId = 'policy'; + + async scan({ + accessToken, + subscriptionId, + }: { + accessToken: string; + subscriptionId: string; + }): Promise { + const findings: SecurityFinding[] = []; + const baseUrl = 'https://management.azure.com'; + + try { + const response = await fetch( + `${baseUrl}/subscriptions/${subscriptionId}/providers/Microsoft.PolicyInsights/policyStates/latest/summarize?api-version=2019-10-01`, + { + method: 'POST', + headers: { + Authorization: `Bearer ${accessToken}`, + 'Content-Type': 'application/json', + }, + }, + ); + + if (!response.ok) { + const error = await response.text(); + if (error.includes('403') || error.includes('AuthorizationFailed')) { + findings.push({ + id: `azure-policy-permission-${subscriptionId}`, + title: 'Unable to Access Policy Compliance', + description: + 'The service principal does not have permission to read policy states.', + severity: 'medium', + resourceType: 'policy', + resourceId: subscriptionId, + remediation: + 'Assign the "Reader" role to your App Registration on the subscription.', + evidence: { + serviceId: this.serviceId, + serviceName: 'Azure Policy', + findingKey: 'azure-policy-permission', + }, + createdAt: new Date().toISOString(), + }); + return findings; + } + throw new Error(`Azure Policy API error: ${error}`); + } + + const data = (await response.json()) as { + value: PolicySummaryResponse['value']; + }; + const assignments = data.value ?? []; + + let totalNonCompliant = 0; + const topOffenders: Array<{ id: string; count: number }> = []; + + for (const assignment of assignments) { + const count = assignment.results.nonCompliantResources; + if (count > 0) { + totalNonCompliant += count; + topOffenders.push({ id: assignment.policyAssignmentId, count }); + } + } + + topOffenders.sort((a, b) => b.count - a.count); + + if (totalNonCompliant > 0) { + findings.push({ + id: `azure-policy-noncompliant-${subscriptionId}`, + title: 'Non-Compliant Resources Detected', + description: `${totalNonCompliant} resource(s) across ${topOffenders.length} policy assignment(s) are non-compliant. Review and remediate compliance violations.`, + severity: totalNonCompliant > 20 ? 'high' : 'medium', + resourceType: 'policy-state', + resourceId: subscriptionId, + remediation: + 'Review non-compliant resources in Azure Policy and remediate or create exemptions for known exceptions.', + evidence: { + serviceId: this.serviceId, + serviceName: 'Azure Policy', + findingKey: 'azure-policy-non-compliant-resources', + totalNonCompliant, + topAssignments: topOffenders.slice(0, 5), + }, + createdAt: new Date().toISOString(), + }); + } else { + findings.push({ + id: `azure-policy-compliant-${subscriptionId}`, + title: 'Policy Compliance', + description: + 'All resources are compliant with assigned Azure Policies.', + severity: 'info', + resourceType: 'policy-state', + resourceId: subscriptionId, + remediation: 'No action needed.', + evidence: { + serviceId: this.serviceId, + serviceName: 'Azure Policy', + findingKey: 'azure-policy-compliant', + }, + createdAt: new Date().toISOString(), + passed: true, + }); + } + + // Check: No policies assigned at all + if (assignments.length === 0) { + findings.push({ + id: `azure-policy-none-${subscriptionId}`, + title: 'No Azure Policies Assigned', + description: + 'This subscription has no Azure Policy assignments. Consider applying security baseline policies.', + severity: 'medium', + resourceType: 'policy-state', + resourceId: subscriptionId, + remediation: + 'Assign the Azure Security Benchmark initiative or other security-focused policy sets.', + evidence: { + serviceId: this.serviceId, + serviceName: 'Azure Policy', + findingKey: 'azure-policy-no-assignments', + }, + createdAt: new Date().toISOString(), + }); + } + } catch (error) { + const msg = error instanceof Error ? error.message : String(error); + if (!msg.includes('permission')) { + throw error; + } + } + + return findings; + } +} diff --git a/apps/api/src/cloud-security/providers/azure/sql-database.adapter.ts b/apps/api/src/cloud-security/providers/azure/sql-database.adapter.ts new file mode 100644 index 0000000000..a05c59ca8a --- /dev/null +++ b/apps/api/src/cloud-security/providers/azure/sql-database.adapter.ts @@ -0,0 +1,206 @@ +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AzureServiceAdapter } from './azure-service-adapter'; +import { fetchAllPages } from './azure-service-adapter'; + +interface SqlServer { + id: string; + name: string; + location: string; + properties: { + administratorLogin?: string; + fullyQualifiedDomainName?: string; + publicNetworkAccess?: string; + minimalTlsVersion?: string; + }; +} + +interface SqlFirewallRule { + id: string; + name: string; + properties: { + startIpAddress: string; + endIpAddress: string; + }; +} + +interface AuditingSetting { + properties: { + state: string; + retentionDays?: number; + }; +} + +const BASE = 'https://management.azure.com'; + +export class SqlDatabaseAdapter implements AzureServiceAdapter { + readonly serviceId = 'sql-database'; + + async scan({ + accessToken, + subscriptionId, + }: { + accessToken: string; + subscriptionId: string; + }): Promise { + const findings: SecurityFinding[] = []; + + const servers = await fetchAllPages( + accessToken, + `${BASE}/subscriptions/${subscriptionId}/providers/Microsoft.Sql/servers?api-version=2023-05-01-preview`, + ); + + if (servers.length === 0) return findings; + + for (const server of servers) { + const props = server.properties; + + // Check 1: Public network access + if (props.publicNetworkAccess === 'Enabled') { + findings.push( + this.finding(server, { + key: 'public-access', + title: `SQL Server Public Access Enabled: ${server.name}`, + description: `SQL Server "${server.name}" allows public network access. Use private endpoints instead.`, + severity: 'high', + remediation: + 'Disable public network access and configure private endpoint connections.', + }), + ); + } + + // Check 2: TLS version + if (!props.minimalTlsVersion || props.minimalTlsVersion < '1.2') { + findings.push( + this.finding(server, { + key: 'tls-outdated', + title: `Outdated TLS Version: ${server.name}`, + description: `SQL Server "${server.name}" allows TLS versions below 1.2.`, + severity: 'medium', + remediation: 'Set minimum TLS version to 1.2.', + }), + ); + } + + // Check 3: Auditing + try { + const resp = await fetch( + `${BASE}${server.id}/auditingSettings/default?api-version=2021-11-01`, + { headers: { Authorization: `Bearer ${accessToken}` } }, + ); + if (resp.ok) { + const data = (await resp.json()) as AuditingSetting; + if (data.properties.state !== 'Enabled') { + findings.push( + this.finding(server, { + key: 'auditing-disabled', + title: `SQL Auditing Disabled: ${server.name}`, + description: `SQL Server "${server.name}" does not have auditing enabled. Enable auditing to track database operations.`, + severity: 'high', + remediation: + 'Enable SQL auditing in the server security settings.', + }), + ); + } + } + } catch { + /* skip if can't check */ + } + + // Check 4: Firewall rules — check for "allow all Azure services" (0.0.0.0) + try { + const rules = await fetchAllPages( + accessToken, + `${BASE}${server.id}/firewallRules?api-version=2023-05-01-preview`, + ); + + const allowAll = rules.find( + (r) => + r.properties.startIpAddress === '0.0.0.0' && + r.properties.endIpAddress === '0.0.0.0', + ); + if (allowAll) { + findings.push( + this.finding(server, { + key: 'allow-azure-services', + title: `SQL Allows All Azure Services: ${server.name}`, + description: `SQL Server "${server.name}" has "Allow Azure services" enabled. This allows ANY Azure service (including other tenants) to connect.`, + severity: 'medium', + remediation: + 'Remove the 0.0.0.0 rule and use specific VNet service endpoints or private endpoints.', + }), + ); + } + + const wideOpen = rules.find( + (r) => + r.properties.startIpAddress === '0.0.0.0' && + r.properties.endIpAddress === '255.255.255.255', + ); + if (wideOpen) { + findings.push( + this.finding(server, { + key: 'firewall-wide-open', + title: `SQL Firewall Wide Open: ${server.name}`, + description: `SQL Server "${server.name}" allows connections from any IP address.`, + severity: 'critical', + remediation: + 'Remove the 0.0.0.0-255.255.255.255 rule and restrict to specific IPs.', + }), + ); + } + } catch { + /* skip if can't check */ + } + } + + if (findings.length === 0) { + findings.push({ + id: `azure-sql-ok-${subscriptionId}`, + title: 'SQL Database Security', + description: `All ${servers.length} SQL Server(s) are properly configured.`, + severity: 'info', + resourceType: 'sql-server', + resourceId: subscriptionId, + remediation: 'No action needed.', + evidence: { + serviceId: this.serviceId, + serviceName: 'SQL Database', + findingKey: 'azure-sql-database-all-ok', + }, + createdAt: new Date().toISOString(), + passed: true, + }); + } + + return findings; + } + + private finding( + server: SqlServer, + opts: { + key: string; + title: string; + description: string; + severity: SecurityFinding['severity']; + remediation: string; + }, + ): SecurityFinding { + return { + id: `azure-sql-${opts.key}-${server.name}`, + title: opts.title, + description: opts.description, + severity: opts.severity, + resourceType: 'sql-server', + resourceId: server.id, + remediation: opts.remediation, + evidence: { + serviceId: this.serviceId, + serviceName: 'SQL Database', + findingKey: `azure-sql-database-${opts.key}`, + serverName: server.name, + location: server.location, + }, + createdAt: new Date().toISOString(), + }; + } +} diff --git a/apps/api/src/cloud-security/providers/azure/storage-account.adapter.ts b/apps/api/src/cloud-security/providers/azure/storage-account.adapter.ts new file mode 100644 index 0000000000..08b8f7c8d8 --- /dev/null +++ b/apps/api/src/cloud-security/providers/azure/storage-account.adapter.ts @@ -0,0 +1,174 @@ +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AzureServiceAdapter } from './azure-service-adapter'; +import { fetchAllPages } from './azure-service-adapter'; + +interface StorageAccount { + id: string; + name: string; + location: string; + properties: { + supportsHttpsTrafficOnly?: boolean; + minimumTlsVersion?: string; + allowBlobPublicAccess?: boolean; + encryption?: { + services?: { + blob?: { enabled: boolean }; + file?: { enabled: boolean }; + }; + keySource?: string; + }; + networkAcls?: { + defaultAction: string; + bypass: string; + }; + publicNetworkAccess?: string; + }; +} + +export class StorageAccountAdapter implements AzureServiceAdapter { + readonly serviceId = 'storage-account'; + + async scan({ + accessToken, + subscriptionId, + }: { + accessToken: string; + subscriptionId: string; + }): Promise { + const findings: SecurityFinding[] = []; + + const accounts = await fetchAllPages( + accessToken, + `https://management.azure.com/subscriptions/${subscriptionId}/providers/Microsoft.Storage/storageAccounts?api-version=2023-05-01`, + ); + + if (accounts.length === 0) return findings; + + for (const acct of accounts) { + const props = acct.properties; + + // Check 1: HTTPS-only + if (props.supportsHttpsTrafficOnly === false) { + findings.push( + this.finding(acct, { + key: 'https-disabled', + title: `HTTPS Not Enforced: ${acct.name}`, + description: `Storage account "${acct.name}" allows insecure HTTP traffic.`, + severity: 'high', + remediation: + 'Enable "Secure transfer required" to enforce HTTPS-only access.', + }), + ); + } + + // Check 2: TLS version + if (!props.minimumTlsVersion || props.minimumTlsVersion < 'TLS1_2') { + findings.push( + this.finding(acct, { + key: 'tls-outdated', + title: `Outdated TLS Version: ${acct.name}`, + description: `Storage account "${acct.name}" allows TLS versions below 1.2 (current: ${props.minimumTlsVersion || 'not set'}).`, + severity: 'medium', + remediation: 'Set minimum TLS version to TLS 1.2.', + }), + ); + } + + // Check 3: Public blob access + if (props.allowBlobPublicAccess === true) { + findings.push( + this.finding(acct, { + key: 'public-blob', + title: `Public Blob Access Enabled: ${acct.name}`, + description: `Storage account "${acct.name}" allows anonymous public access to blobs. This can expose sensitive data.`, + severity: 'high', + remediation: + 'Disable "Allow Blob public access" unless explicitly required.', + }), + ); + } + + // Check 4: Network access + const isPublic = + props.publicNetworkAccess === 'Enabled' || + props.networkAcls?.defaultAction === 'Allow'; + if (isPublic) { + findings.push( + this.finding(acct, { + key: 'public-network', + title: `Public Network Access: ${acct.name}`, + description: `Storage account "${acct.name}" allows access from all networks. Restrict to specific VNets or IP ranges.`, + severity: 'medium', + remediation: + 'Configure network ACLs to deny public access and add specific VNet/IP rules.', + }), + ); + } + + // Check 5: Encryption + const blobEncrypted = props.encryption?.services?.blob?.enabled !== false; + const fileEncrypted = props.encryption?.services?.file?.enabled !== false; + if (!blobEncrypted || !fileEncrypted) { + findings.push( + this.finding(acct, { + key: 'encryption-disabled', + title: `Encryption Not Fully Enabled: ${acct.name}`, + description: `Storage account "${acct.name}" does not have encryption enabled for all services.`, + severity: 'high', + remediation: 'Enable encryption for blob and file services.', + }), + ); + } + } + + if (findings.length === 0) { + findings.push({ + id: `azure-storage-ok-${subscriptionId}`, + title: 'Storage Account Security', + description: `All ${accounts.length} storage account(s) are properly configured.`, + severity: 'info', + resourceType: 'storage-account', + resourceId: subscriptionId, + remediation: 'No action needed.', + evidence: { + serviceId: this.serviceId, + serviceName: 'Storage Accounts', + findingKey: 'azure-storage-account-all-ok', + }, + createdAt: new Date().toISOString(), + passed: true, + }); + } + + return findings; + } + + private finding( + acct: StorageAccount, + opts: { + key: string; + title: string; + description: string; + severity: SecurityFinding['severity']; + remediation: string; + }, + ): SecurityFinding { + return { + id: `azure-sa-${opts.key}-${acct.name}`, + title: opts.title, + description: opts.description, + severity: opts.severity, + resourceType: 'storage-account', + resourceId: acct.id, + remediation: opts.remediation, + evidence: { + serviceId: this.serviceId, + serviceName: 'Storage Accounts', + findingKey: `azure-storage-account-${opts.key}`, + accountName: acct.name, + location: acct.location, + }, + createdAt: new Date().toISOString(), + }; + } +} diff --git a/apps/api/src/cloud-security/providers/azure/virtual-machine.adapter.ts b/apps/api/src/cloud-security/providers/azure/virtual-machine.adapter.ts new file mode 100644 index 0000000000..694901412b --- /dev/null +++ b/apps/api/src/cloud-security/providers/azure/virtual-machine.adapter.ts @@ -0,0 +1,175 @@ +import type { SecurityFinding } from '../../cloud-security.service'; +import type { AzureServiceAdapter } from './azure-service-adapter'; +import { fetchAllPages } from './azure-service-adapter'; + +interface VirtualMachine { + id: string; + name: string; + location: string; + identity?: { + type: string; // 'SystemAssigned' | 'UserAssigned' | 'None' + }; + properties: { + storageProfile?: { + osDisk?: { + managedDisk?: { + diskEncryptionSet?: { id: string }; + }; + encryptionSettings?: { enabled: boolean }; + }; + }; + osProfile?: { + linuxConfiguration?: { + disablePasswordAuthentication?: boolean; + }; + windowsConfiguration?: unknown; + }; + networkProfile?: { + networkInterfaces?: Array<{ id: string }>; + }; + securityProfile?: { + securityType?: string; // 'TrustedLaunch' | etc. + uefiSettings?: { + secureBootEnabled?: boolean; + vTpmEnabled?: boolean; + }; + }; + }; +} + +export class VirtualMachineAdapter implements AzureServiceAdapter { + readonly serviceId = 'virtual-machine'; + + async scan({ + accessToken, + subscriptionId, + }: { + accessToken: string; + subscriptionId: string; + }): Promise { + const findings: SecurityFinding[] = []; + + const vms = await fetchAllPages( + accessToken, + `https://management.azure.com/subscriptions/${subscriptionId}/providers/Microsoft.Compute/virtualMachines?api-version=2024-03-01`, + ); + + if (vms.length === 0) return findings; + + for (const vm of vms) { + // Check 1: Managed identity + const hasIdentity = vm.identity?.type && vm.identity.type !== 'None'; + if (!hasIdentity) { + findings.push( + this.finding(vm, { + key: 'no-managed-identity', + title: `No Managed Identity: ${vm.name}`, + description: `VM "${vm.name}" does not use a managed identity. Managed identities eliminate credential management and are more secure than service principals.`, + severity: 'medium', + remediation: + 'Enable system-assigned or user-assigned managed identity on the VM.', + }), + ); + } + + // Check 2: OS disk encryption + const osDisk = vm.properties.storageProfile?.osDisk; + const hasEncryption = + osDisk?.managedDisk?.diskEncryptionSet?.id || + osDisk?.encryptionSettings?.enabled; + if (!hasEncryption) { + findings.push( + this.finding(vm, { + key: 'disk-not-encrypted', + title: `OS Disk Not Encrypted with CMK: ${vm.name}`, + description: `VM "${vm.name}" OS disk does not use customer-managed key encryption. Azure encrypts by default with platform keys, but CMK provides more control.`, + severity: 'low', + remediation: + 'Enable disk encryption with a customer-managed key via Azure Disk Encryption or Disk Encryption Sets.', + }), + ); + } + + // Check 3: Linux VMs — password auth + const linuxConfig = vm.properties.osProfile?.linuxConfiguration; + if (linuxConfig && linuxConfig.disablePasswordAuthentication === false) { + findings.push( + this.finding(vm, { + key: 'password-auth-enabled', + title: `SSH Password Authentication Enabled: ${vm.name}`, + description: `Linux VM "${vm.name}" allows SSH password authentication. Use SSH keys instead for stronger security.`, + severity: 'medium', + remediation: + 'Disable password authentication and use SSH key-based authentication only.', + }), + ); + } + + // Check 4: Trusted Launch + const secProfile = vm.properties.securityProfile; + if (secProfile?.securityType === 'TrustedLaunch') { + if (!secProfile.uefiSettings?.secureBootEnabled) { + findings.push( + this.finding(vm, { + key: 'secure-boot-disabled', + title: `Secure Boot Disabled: ${vm.name}`, + description: `VM "${vm.name}" supports Trusted Launch but Secure Boot is not enabled.`, + severity: 'low', + remediation: 'Enable Secure Boot in the VM security settings.', + }), + ); + } + } + } + + if (findings.length === 0) { + findings.push({ + id: `azure-vm-ok-${subscriptionId}`, + title: 'Virtual Machine Security', + description: `All ${vms.length} VM(s) are properly configured.`, + severity: 'info', + resourceType: 'virtual-machine', + resourceId: subscriptionId, + remediation: 'No action needed.', + evidence: { + serviceId: this.serviceId, + serviceName: 'Virtual Machines', + findingKey: 'azure-virtual-machine-all-ok', + }, + createdAt: new Date().toISOString(), + passed: true, + }); + } + + return findings; + } + + private finding( + vm: VirtualMachine, + opts: { + key: string; + title: string; + description: string; + severity: SecurityFinding['severity']; + remediation: string; + }, + ): SecurityFinding { + return { + id: `azure-vm-${opts.key}-${vm.name}`, + title: opts.title, + description: opts.description, + severity: opts.severity, + resourceType: 'virtual-machine', + resourceId: vm.id, + remediation: opts.remediation, + evidence: { + serviceId: this.serviceId, + serviceName: 'Virtual Machines', + findingKey: `azure-virtual-machine-${opts.key}`, + vmName: vm.name, + location: vm.location, + }, + createdAt: new Date().toISOString(), + }; + } +} diff --git a/apps/api/src/cloud-security/providers/gcp-security.service.ts b/apps/api/src/cloud-security/providers/gcp-security.service.ts index 01f0663da9..79f1a844f0 100644 --- a/apps/api/src/cloud-security/providers/gcp-security.service.ts +++ b/apps/api/src/cloud-security/providers/gcp-security.service.ts @@ -1,26 +1,1030 @@ import { Injectable, Logger } from '@nestjs/common'; import type { SecurityFinding } from '../cloud-security.service'; +import { parseGcpPermissionError } from '../remediation-error.utils'; -interface GCPFindingResult { - finding: { +/** Full SCC finding structure with all useful fields. */ +interface SCCFinding { + name: string; + category: string; + severity: string; + state: string; + resourceName: string; + description?: string; + createTime: string; + eventTime: string; + externalUri?: string; + nextSteps?: string; + sourceProperties?: Record; + findingClass?: string; + compliances?: Array<{ + standard: string; + version: string; + ids: string[]; + }>; + parentDisplayName?: string; +} + +interface SCCFindingResult { + finding: SCCFinding; + resource?: { name: string; - category: string; - severity: string; - state: string; - resourceName: string; - description?: string; - createTime: string; - eventTime: string; + projectDisplayName?: string; + type?: string; + displayName?: string; }; } +/** Map SCC category → our serviceId for grouping in the UI. */ +const CATEGORY_TO_SERVICE: Record = { + // Cloud Storage + PUBLIC_BUCKET_ACL: 'cloud-storage', + BUCKET_POLICY_ONLY_DISABLED: 'cloud-storage', + BUCKET_LOGGING_DISABLED: 'cloud-storage', + BUCKET_LOCK_DISABLED: 'cloud-storage', + BUCKET_CMEK_DISABLED: 'cloud-storage', + // Compute / VPC + OPEN_FIREWALL: 'vpc-network', + OPEN_SSH_PORT: 'vpc-network', + OPEN_RDP_PORT: 'vpc-network', + FIREWALL_RULE_LOGGING_DISABLED: 'vpc-network', + FLOW_LOGS_DISABLED: 'vpc-network', + DEFAULT_SERVICE_ACCOUNT_USED: 'compute-engine', + COMPUTE_SECURE_BOOT_DISABLED: 'compute-engine', + OS_LOGIN_DISABLED: 'compute-engine', + PUBLIC_IP_ADDRESS: 'compute-engine', + IP_FORWARDING_ENABLED: 'compute-engine', + SERIAL_PORT_ENABLED: 'compute-engine', + FULL_API_ACCESS: 'compute-engine', + SHIELDED_VM_DISABLED: 'compute-engine', + // IAM + ADMIN_SERVICE_ACCOUNT: 'iam', + MFA_NOT_ENFORCED: 'iam', + OVER_PRIVILEGED_SERVICE_ACCOUNT_USER: 'iam', + SERVICE_ACCOUNT_KEY_NOT_ROTATED: 'iam', + USER_MANAGED_SERVICE_ACCOUNT_KEY: 'iam', + NON_ORG_IAM_MEMBER: 'iam', + OVER_PRIVILEGED_ACCOUNT: 'iam', + PRIMITIVE_ROLES_USED: 'iam', + KMS_ROLE_SEPARATION: 'iam', + // Cloud SQL + SQL_PUBLIC_IP: 'cloud-sql', + SQL_NO_ROOT_PASSWORD: 'cloud-sql', + SQL_CROSS_DB_OWNERSHIP_CHAINING: 'cloud-sql', + SQL_LOCAL_INFILE: 'cloud-sql', + SSL_NOT_ENFORCED: 'cloud-sql', + AUTO_BACKUP_DISABLED: 'cloud-sql', + SQL_CONTAINED_DATABASE_AUTHENTICATION: 'cloud-sql', + SQL_LOG_DISCONNECTIONS_DISABLED: 'cloud-sql', + SQL_LOG_CONNECTIONS_DISABLED: 'cloud-sql', + SQL_LOG_ERROR_VERBOSITY: 'cloud-sql', + SQL_LOG_MIN_MESSAGES: 'cloud-sql', + SQL_LOG_MIN_DURATION_STATEMENT_ENABLED: 'cloud-sql', + // GKE + CLUSTER_PRIVATE_GOOGLE_ACCESS_DISABLED: 'gke', + CLUSTER_SHIELDED_NODES_DISABLED: 'gke', + LEGACY_AUTHORIZATION_ENABLED: 'gke', + MASTER_AUTHORIZED_NETWORKS_DISABLED: 'gke', + NETWORK_POLICY_DISABLED: 'gke', + POD_SECURITY_POLICY_DISABLED: 'gke', + PRIVATE_CLUSTER_DISABLED: 'gke', + RELEASE_CHANNEL_DISABLED: 'gke', + WEB_UI_ENABLED: 'gke', + WORKLOAD_IDENTITY_DISABLED: 'gke', + // KMS + KMS_KEY_NOT_ROTATED: 'cloud-kms', + KMS_PROJECT_HAS_OWNER: 'cloud-kms', + // Logging / Monitoring + AUDIT_LOGGING_DISABLED: 'cloud-logging', + LOG_NOT_EXPORTED: 'cloud-logging', + LOCKED_RETENTION_POLICY_NOT_SET: 'cloud-logging', + AUDIT_CONFIG_NOT_MONITORED: 'cloud-monitoring', + CUSTOM_ROLE_NOT_MONITORED: 'cloud-monitoring', + FIREWALL_NOT_MONITORED: 'cloud-monitoring', + NETWORK_NOT_MONITORED: 'cloud-monitoring', + ROUTE_NOT_MONITORED: 'cloud-monitoring', + SQL_INSTANCE_NOT_MONITORED: 'cloud-monitoring', + // DNS + DNSSEC_DISABLED: 'cloud-dns', + RSASHA1_FOR_SIGNING: 'cloud-dns', + // BigQuery + DATASET_CMEK_DISABLED: 'bigquery', + PUBLIC_DATASET: 'bigquery', + // Pub/Sub + PUBSUB_CMEK_DISABLED: 'pubsub', + // Cloud Armor / Load Balancing + SSL_POLICY_WEAK: 'cloud-armor', +}; + +/** Human-readable service names for UI grouping. */ +const SERVICE_NAMES: Record = { + 'cloud-storage': 'Cloud Storage', + 'vpc-network': 'VPC Network', + 'compute-engine': 'Compute Engine', + iam: 'IAM', + 'cloud-sql': 'Cloud SQL', + gke: 'GKE', + 'cloud-kms': 'Cloud KMS', + 'cloud-logging': 'Cloud Logging', + 'cloud-monitoring': 'Cloud Monitoring', + 'cloud-dns': 'Cloud DNS', + bigquery: 'BigQuery', + pubsub: 'Pub/Sub', + 'cloud-armor': 'Cloud Armor', + 'security-command-center': 'Security Command Center', +}; + +/** Map GCP API service names → our service category IDs. */ +const GCP_API_TO_SERVICE: Record = { + 'storage.googleapis.com': ['cloud-storage'], + 'storage-component.googleapis.com': ['cloud-storage'], + 'compute.googleapis.com': ['compute-engine', 'vpc-network'], + 'securitycenter.googleapis.com': ['security-command-center'], + 'sqladmin.googleapis.com': ['cloud-sql'], + 'container.googleapis.com': ['gke'], + 'cloudkms.googleapis.com': ['cloud-kms'], + 'logging.googleapis.com': ['cloud-logging'], + 'monitoring.googleapis.com': ['cloud-monitoring'], + 'dns.googleapis.com': ['cloud-dns'], + 'bigquery.googleapis.com': ['bigquery'], + 'bigquerystorage.googleapis.com': ['bigquery'], + 'pubsub.googleapis.com': ['pubsub'], + 'networksecurity.googleapis.com': ['cloud-armor'], + 'iam.googleapis.com': ['iam'], + 'iamcredentials.googleapis.com': ['iam'], +}; + +export type GcpSetupStepId = + | 'enable_security_command_center_api' + | 'enable_cloud_resource_manager_api' + | 'enable_service_usage_api' + | 'grant_findings_viewer_role'; + +export type GcpSetupAdminAction = + | { kind: 'link'; label: string; url: string } + | { kind: 'command'; label: string; command: string }; + +export type GcpSetupResolveAction = { + label: string; + method: 'POST'; + endpoint: string; + body: { stepId: GcpSetupStepId }; +}; + +export type GcpSetupStep = { + id: GcpSetupStepId; + name: string; + success: boolean; + error?: string; + actionUrl?: string; + actionText?: string; + requiredForScan: boolean; + resolveAction?: GcpSetupResolveAction; + adminActions?: GcpSetupAdminAction[]; +}; + +const REQUIRED_GCP_API_STEPS: Array<{ + id: GcpSetupStepId; + api: string; + name: string; + actionUrl: string; + actionText: string; + requiredForScan: boolean; +}> = [ + { + id: 'enable_security_command_center_api', + api: 'securitycenter.googleapis.com', + name: 'Enable Security Command Center API', + actionUrl: + 'https://console.cloud.google.com/apis/library/securitycenter.googleapis.com', + actionText: 'Open API', + requiredForScan: true, + }, + { + id: 'enable_cloud_resource_manager_api', + api: 'cloudresourcemanager.googleapis.com', + name: 'Enable Cloud Resource Manager API', + actionUrl: + 'https://console.cloud.google.com/apis/library/cloudresourcemanager.googleapis.com', + actionText: 'Open API', + requiredForScan: false, + }, + { + id: 'enable_service_usage_api', + api: 'serviceusage.googleapis.com', + name: 'Enable Service Usage API', + actionUrl: + 'https://console.cloud.google.com/apis/library/serviceusage.googleapis.com', + actionText: 'Open API', + requiredForScan: false, + }, +]; + +const FINDINGS_VIEWER_ACTION = { + actionUrl: 'https://console.cloud.google.com/iam-admin/iam', + actionText: 'Open IAM', + requiredForScan: true, +}; + @Injectable() export class GCPSecurityService { private readonly logger = new Logger(GCPSecurityService.name); + /** + * One-click GCP setup: enable required APIs, detect user email, + * and grant the Findings Viewer role at the organization level. + * Returns status of each step so the frontend can show what succeeded/failed. + */ + async autoSetup(params: { + accessToken: string; + organizationId: string; + projectId: string; + }): Promise<{ + email: string | null; + steps: GcpSetupStep[]; + }> { + const { accessToken, organizationId, projectId } = params; + const steps: GcpSetupStep[] = []; + const email = await this.detectEmail(accessToken); + const hasFindingsAccess = organizationId + ? await this.canReadFindings(accessToken, organizationId) + : false; + + for (const stepDef of REQUIRED_GCP_API_STEPS) { + let step = await this.runEnableApiSetupStep({ + stepDef, + accessToken, + projectId, + }); + + // If findings are already readable, SCC API access is effectively working for this org. + if ( + stepDef.id === 'enable_security_command_center_api' && + !step.success && + hasFindingsAccess + ) { + step = { + ...step, + success: true, + error: undefined, + }; + } + + steps.push(step); + } + + steps.push( + await this.runGrantFindingsViewerSetupStep({ + accessToken, + organizationId, + email, + hasFindingsAccess, + }), + ); + + this.logger.log( + `GCP auto-setup: ${steps.filter((s) => s.success).length}/${steps.length} steps succeeded`, + ); + return { email, steps }; + } + + async resolveSetupStep(params: { + stepId: GcpSetupStepId; + accessToken: string; + organizationId: string; + projectId: string; + email?: string | null; + }): Promise<{ email: string | null; step: GcpSetupStep }> { + const { stepId, accessToken, organizationId, projectId } = params; + const email = params.email ?? (await this.detectEmail(accessToken)); + const hasFindingsAccess = organizationId + ? await this.canReadFindings(accessToken, organizationId) + : false; + + if (stepId === 'grant_findings_viewer_role') { + return { + email, + step: await this.runGrantFindingsViewerSetupStep({ + accessToken, + organizationId, + email, + hasFindingsAccess, + }), + }; + } + + const stepDef = REQUIRED_GCP_API_STEPS.find((s) => s.id === stepId); + if (!stepDef) { + throw new Error(`Unsupported GCP setup step: ${stepId}`); + } + + let step = await this.runEnableApiSetupStep({ + stepDef, + accessToken, + projectId, + }); + + if ( + stepDef.id === 'enable_security_command_center_api' && + !step.success && + hasFindingsAccess + ) { + step = { + ...step, + success: true, + error: undefined, + }; + } + + return { email, step }; + } + + private async detectEmail(accessToken: string): Promise { + try { + const resp = await fetch( + 'https://www.googleapis.com/oauth2/v2/userinfo', + { + headers: { Authorization: `Bearer ${accessToken}` }, + }, + ); + if (resp.ok) { + const info = (await resp.json()) as { email?: string }; + return info.email ?? null; + } + } catch { + this.logger.warn('Could not fetch user email'); + } + return null; + } + + private async runEnableApiSetupStep(params: { + stepDef: (typeof REQUIRED_GCP_API_STEPS)[number]; + accessToken: string; + projectId: string; + }): Promise { + const { stepDef, accessToken, projectId } = params; + const actionUrl = this.getApiConsoleUrl(stepDef.api, projectId); + + try { + const resp = await fetch( + `https://serviceusage.googleapis.com/v1/projects/${projectId}/services/${stepDef.api}:enable`, + { + method: 'POST', + headers: { + Authorization: `Bearer ${accessToken}`, + 'Content-Type': 'application/json', + }, + body: '{}', + }, + ); + + if (resp.ok || resp.status === 409) { + return { + id: stepDef.id, + name: stepDef.name, + success: true, + actionUrl, + actionText: stepDef.actionText, + requiredForScan: stepDef.requiredForScan, + }; + } + + const rawError = await resp.text(); + const message = this.getEnableApiErrorMessage(stepDef.api, rawError); + const isPermissionError = + resp.status === 403 || + /permission denied|does not have permission|forbidden|PERMISSION_DENIED/i.test( + rawError, + ); + + if (isPermissionError) { + const alreadyEnabled = await this.isApiAlreadyEnabled( + accessToken, + projectId, + stepDef.api, + ); + if (alreadyEnabled) { + return { + id: stepDef.id, + name: stepDef.name, + success: true, + actionUrl, + actionText: stepDef.actionText, + requiredForScan: stepDef.requiredForScan, + }; + } + } + + return { + id: stepDef.id, + name: stepDef.name, + success: false, + error: message, + actionUrl, + actionText: stepDef.actionText, + requiredForScan: stepDef.requiredForScan, + adminActions: this.buildEnableApiAdminActions( + stepDef, + projectId, + rawError, + actionUrl, + ), + }; + } catch (err) { + const rawError = err instanceof Error ? err.message : String(err); + return { + id: stepDef.id, + name: stepDef.name, + success: false, + error: this.getEnableApiErrorMessage(stepDef.api, rawError), + actionUrl, + actionText: stepDef.actionText, + requiredForScan: stepDef.requiredForScan, + adminActions: this.buildEnableApiAdminActions( + stepDef, + projectId, + rawError, + actionUrl, + ), + }; + } + } + + private async runGrantFindingsViewerSetupStep(params: { + accessToken: string; + organizationId: string; + email: string | null; + hasFindingsAccess?: boolean; + }): Promise { + const { accessToken, organizationId, email, hasFindingsAccess } = params; + + // If we can already read findings, required scan permission exists. + // Don't fail setup just because this user cannot grant IAM roles. + if (hasFindingsAccess) { + return { + id: 'grant_findings_viewer_role', + name: 'Grant Findings Viewer role', + success: true, + ...FINDINGS_VIEWER_ACTION, + }; + } + + if (!email) { + return { + id: 'grant_findings_viewer_role', + name: 'Grant Findings Viewer role', + success: false, + error: + 'Could not identify your Google account email. Reconnect GCP and approve profile/email access.', + ...FINDINGS_VIEWER_ACTION, + }; + } + + if (!organizationId) { + return { + id: 'grant_findings_viewer_role', + name: 'Grant Findings Viewer role', + success: false, + error: 'Organization ID not detected yet.', + ...FINDINGS_VIEWER_ACTION, + }; + } + + const adminActions = this.buildFindingsViewerAdminActions({ + organizationId, + email, + }); + + try { + const getPolicyResp = await fetch( + `https://cloudresourcemanager.googleapis.com/v3/organizations/${organizationId}:getIamPolicy`, + { + method: 'POST', + headers: { + Authorization: `Bearer ${accessToken}`, + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ options: { requestedPolicyVersion: 3 } }), + }, + ); + + if (!getPolicyResp.ok) { + const rawError = await getPolicyResp.text(); + return { + id: 'grant_findings_viewer_role', + name: 'Grant Findings Viewer role', + success: false, + error: this.getFindingsViewerErrorMessage(rawError), + ...FINDINGS_VIEWER_ACTION, + adminActions, + }; + } + + const policy = (await getPolicyResp.json()) as { + version?: number; + bindings?: Array<{ role: string; members: string[] }>; + etag?: string; + }; + + const role = 'roles/securitycenter.findingsViewer'; + const member = `user:${email}`; + const bindings = policy.bindings ?? []; + const existing = bindings.find((b) => b.role === role); + + if (existing && existing.members.includes(member)) { + return { + id: 'grant_findings_viewer_role', + name: 'Grant Findings Viewer role', + success: true, + ...FINDINGS_VIEWER_ACTION, + }; + } + + if (existing) { + existing.members.push(member); + } else { + bindings.push({ role, members: [member] }); + } + + const setPolicyResp = await fetch( + `https://cloudresourcemanager.googleapis.com/v3/organizations/${organizationId}:setIamPolicy`, + { + method: 'POST', + headers: { + Authorization: `Bearer ${accessToken}`, + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + policy: { + version: policy.version ?? 3, + bindings, + ...(policy.etag ? { etag: policy.etag } : {}), + }, + updateMask: 'bindings', + }), + }, + ); + + if (setPolicyResp.ok) { + return { + id: 'grant_findings_viewer_role', + name: 'Grant Findings Viewer role', + success: true, + ...FINDINGS_VIEWER_ACTION, + }; + } + + const rawError = await setPolicyResp.text(); + return { + id: 'grant_findings_viewer_role', + name: 'Grant Findings Viewer role', + success: false, + error: this.getFindingsViewerErrorMessage(rawError), + ...FINDINGS_VIEWER_ACTION, + adminActions, + }; + } catch (err) { + const rawError = err instanceof Error ? err.message : String(err); + return { + id: 'grant_findings_viewer_role', + name: 'Grant Findings Viewer role', + success: false, + error: this.getFindingsViewerErrorMessage(rawError), + ...FINDINGS_VIEWER_ACTION, + adminActions, + }; + } + } + + private buildEnableApiAdminActions( + stepDef: (typeof REQUIRED_GCP_API_STEPS)[number], + projectId: string, + rawError?: string, + actionUrl?: string, + ): GcpSetupAdminAction[] { + const actions: GcpSetupAdminAction[] = [ + { + kind: 'link', + label: stepDef.actionText, + url: actionUrl ?? this.getApiConsoleUrl(stepDef.api, projectId), + }, + { + kind: 'command', + label: `Copy: enable ${stepDef.api}`, + command: `gcloud services enable ${stepDef.api} --project=${projectId}`, + }, + ]; + + if (rawError) { + const permInfo = parseGcpPermissionError(rawError, projectId); + if (permInfo.fixScript) { + actions.push({ + kind: 'command', + label: 'Copy: grant required project role', + command: permInfo.fixScript, + }); + } + } + + return actions; + } + + private getApiConsoleUrl(apiName: string, projectId: string): string { + return `https://console.cloud.google.com/apis/library/${apiName}?project=${encodeURIComponent(projectId)}`; + } + + private async isApiAlreadyEnabled( + accessToken: string, + projectId: string, + apiName: string, + ): Promise { + try { + const resp = await fetch( + `https://serviceusage.googleapis.com/v1/projects/${projectId}/services/${apiName}`, + { + headers: { + Authorization: `Bearer ${accessToken}`, + 'Content-Type': 'application/json', + }, + }, + ); + + if (!resp.ok) return false; + + const data = (await resp.json()) as { state?: string }; + return data.state === 'ENABLED'; + } catch { + return false; + } + } + + private async canReadFindings( + accessToken: string, + organizationId: string, + ): Promise { + try { + const url = new URL( + `https://securitycenter.googleapis.com/v2/organizations/${organizationId}/sources/-/findings`, + ); + url.searchParams.set('pageSize', '1'); + url.searchParams.set('filter', 'state="ACTIVE"'); + + const response = await fetch(url.toString(), { + headers: { + Authorization: `Bearer ${accessToken}`, + 'Content-Type': 'application/json', + }, + }); + + return response.ok; + } catch { + return false; + } + } + + private buildFindingsViewerAdminActions(params: { + organizationId: string; + email: string; + }): GcpSetupAdminAction[] { + const { organizationId, email } = params; + return [ + { + kind: 'link', + label: FINDINGS_VIEWER_ACTION.actionText, + url: FINDINGS_VIEWER_ACTION.actionUrl, + }, + { + kind: 'command', + label: 'Copy: grant Findings Viewer role', + command: [ + `gcloud organizations add-iam-policy-binding ${organizationId}`, + ` --member='user:${email}'`, + " --role='roles/securitycenter.findingsViewer'", + ].join(' \\\n'), + }, + ]; + } + + private extractGcpError(raw: string): string { + let message = raw; + try { + const parsed = JSON.parse(raw) as { error?: { message?: string } }; + message = parsed.error?.message ?? raw; + } catch { + message = raw; + } + return message + .replace(/\s*Help Token:\s*[\w-]+/gi, '') + .replace(/\s+/g, ' ') + .trim() + .slice(0, 240); + } + + private getEnableApiErrorMessage(apiName: string, raw: string): string { + const message = this.extractGcpError(raw); + + if ( + /permission denied|does not have permission|forbidden|PERMISSION_DENIED/i.test( + message, + ) + ) { + return `Your account cannot enable ${apiName}. Ask a project owner/editor to enable it.`; + } + + return message || `Failed to enable ${apiName}.`; + } + + private getFindingsViewerErrorMessage(raw: string): string { + const message = this.extractGcpError(raw); + + if ( + /getIamPolicy|resourcemanager\.organizations\.getIamPolicy/i.test(message) + ) { + return 'Your account cannot read organization IAM policy. Ask a GCP organization admin to grant roles/securitycenter.findingsViewer.'; + } + + if ( + /setIamPolicy|resourcemanager\.organizations\.setIamPolicy/i.test(message) + ) { + return 'Your account cannot grant org IAM roles. Ask a GCP organization admin to grant roles/securitycenter.findingsViewer.'; + } + + if ( + /permission denied|does not have permission|forbidden|PERMISSION_DENIED/i.test( + message, + ) + ) { + return 'Your account does not have organization IAM permissions required for auto-setup. Ask a GCP organization admin to grant roles/securitycenter.findingsViewer.'; + } + + return ( + message || + 'Unable to grant Findings Viewer role automatically. Ask a GCP organization admin to grant roles/securitycenter.findingsViewer.' + ); + } + + /** + * Auto-detect GCP organizations accessible by the OAuth token. + */ + async detectOrganizations( + accessToken: string, + ): Promise> { + // v3 search API — works for listing all orgs the user has access to + const response = await fetch( + 'https://cloudresourcemanager.googleapis.com/v3/organizations:search', + { + headers: { + Authorization: `Bearer ${accessToken}`, + 'Content-Type': 'application/json', + }, + }, + ); + + if (!response.ok) { + const errorText = await response.text(); + this.logger.warn(`Failed to search GCP organizations: ${errorText}`); + return []; + } + + const data = await response.json(); + const orgs = (data.organizations ?? []) as Array<{ + name: string; + displayName?: string; + state?: string; + }>; + + return orgs + .filter((o) => o.state === 'ACTIVE') + .map((o) => ({ + // name is "organizations/123456" + id: o.name.replace('organizations/', ''), + displayName: o.displayName ?? o.name, + })); + } + + /** + * Detect active GCP projects scoped to a specific organization. + * Returns only projects whose parent is the given org ID. + */ + async detectProjectsForOrg( + accessToken: string, + organizationId: string, + ): Promise> { + const params = new URLSearchParams({ + pageSize: '50', + filter: `lifecycleState:ACTIVE AND parent.id:${organizationId}`, + }); + const response = await fetch( + `https://cloudresourcemanager.googleapis.com/v1/projects?${params.toString()}`, + { + headers: { + Authorization: `Bearer ${accessToken}`, + 'Content-Type': 'application/json', + }, + }, + ); + + if (!response.ok) { + this.logger.warn( + `Failed to list GCP projects for org ${organizationId}: ${await response.text()}`, + ); + return []; + } + + const data = await response.json(); + return ( + (data.projects ?? []) as Array<{ + projectId: string; + name: string; + projectNumber: string; + }> + ).map((p) => ({ + id: p.projectId, + name: p.name, + number: p.projectNumber, + })); + } + + /** + * Auto-detect active GCP projects accessible by the OAuth token. + * Tries a direct project list first; if empty (common for org-centric accounts), + * lists projects under each accessible organization (parent filter). + */ + async detectProjects( + accessToken: string, + ): Promise> { + const mapRow = (p: { + projectId: string; + name: string; + projectNumber: string; + }) => ({ + id: p.projectId, + name: p.name, + number: p.projectNumber, + }); + + const listProjectsWithFilter = async ( + filter: string, + ): Promise< + Array<{ id: string; name: string; number: string }> + > => { + const params = new URLSearchParams({ + pageSize: '50', + filter, + }); + const response = await fetch( + `https://cloudresourcemanager.googleapis.com/v1/projects?${params.toString()}`, + { + headers: { + Authorization: `Bearer ${accessToken}`, + 'Content-Type': 'application/json', + }, + }, + ); + + if (!response.ok) { + const errorText = await response.text(); + this.logger.warn( + `Failed to list GCP projects (filter=${filter}): ${errorText}`, + ); + return []; + } + + const data = await response.json(); + return ( + (data.projects ?? []) as Array<{ + projectId: string; + name: string; + projectNumber: string; + }> + ).map(mapRow); + }; + + const direct = await listProjectsWithFilter('lifecycleState:ACTIVE'); + if (direct.length > 0) { + this.logger.log( + `GCP detectProjects: ${direct.length} project(s) via direct list`, + ); + return direct; + } + + const orgs = await this.detectOrganizations(accessToken); + if (orgs.length === 0) { + this.logger.warn( + 'GCP detectProjects: no projects from direct list and no organizations — Service Usage detection may be empty', + ); + return []; + } + + const seen = new Set(); + const merged: Array<{ id: string; name: string; number: string }> = []; + + for (const org of orgs) { + const underOrg = await listProjectsWithFilter( + `lifecycleState:ACTIVE AND parent.id:${org.id}`, + ); + for (const p of underOrg) { + if (!seen.has(p.id)) { + seen.add(p.id); + merged.push(p); + } + if (merged.length >= 20) break; + } + if (merged.length >= 20) break; + } + + if (merged.length > 0) { + this.logger.log( + `GCP detectProjects: ${merged.length} project(s) via organization scope`, + ); + } else { + this.logger.warn( + 'GCP detectProjects: organization-scoped list returned no projects — check resourcemanager.projects.list on the org', + ); + } + + return merged; + } + + /** + * Detect which GCP services the customer actually uses by querying + * the Service Usage API for each project. Maps GCP API names to + * our service category IDs. + */ + async detectServices( + accessToken: string, + projects: Array<{ id: string }>, + ): Promise<{ + services: string[]; + servicesByProject: Record; + }> { + const detected = new Set(); + const unmappedApis = new Set(); + const servicesByProject: Record = {}; + + for (const project of projects) { + const projectServices = new Set(); + try { + const response = await fetch( + `https://serviceusage.googleapis.com/v1/projects/${encodeURIComponent(project.id)}/services?filter=state:ENABLED&pageSize=200`, + { + headers: { + Authorization: `Bearer ${accessToken}`, + 'Content-Type': 'application/json', + }, + }, + ); + + if (!response.ok) { + this.logger.warn(`Failed to list services for project ${project.id}`); + continue; + } + + const data = await response.json(); + const services = (data.services ?? []) as Array<{ + name: string; + config?: { name: string }; + }>; + + for (const svc of services) { + const apiName = svc.config?.name ?? svc.name.split('/').pop() ?? ''; + const mapped = GCP_API_TO_SERVICE[apiName]; + if (mapped) { + for (const id of mapped) { + detected.add(id); + projectServices.add(id); + } + } else if (apiName.endsWith('.googleapis.com')) { + unmappedApis.add(apiName); + } + } + } catch (err) { + this.logger.warn( + `Service detection failed for ${project.id}: ${err instanceof Error ? err.message : String(err)}`, + ); + } + servicesByProject[project.id] = [...projectServices]; + } + + if (detected.size === 0 && unmappedApis.size > 0) { + this.logger.warn( + `GCP Service Usage: ${unmappedApis.size} enabled API(s) had no UI mapping (sample): ${[...unmappedApis].slice(0, 8).join(', ')}`, + ); + } + + this.logger.log( + `Detected ${detected.size} GCP service categories: ${[...detected].join(', ')}`, + ); + return { services: [...detected], servicesByProject }; + } + + /** + * Scan GCP Security Command Center for all active findings. + * Pulls rich data: description, remediation steps, compliance mappings, service grouping. + */ async scanSecurityFindings( credentials: Record, variables: Record, + enabledServices?: string[], ): Promise { const accessToken = credentials.access_token as string; const organizationId = variables.organization_id as string; @@ -29,68 +1033,118 @@ export class GCPSecurityService { throw new Error('Access token is required'); } - if (!organizationId) { + // Read explicitly selected projects + const projectIds: string[] = Array.isArray(variables.project_ids) + ? (variables.project_ids as string[]) + : []; + + // If projects are selected, query per-project; otherwise fall back to org-level + const scopes: Array< + { type: 'project'; id: string } | { type: 'organization'; id: string } + > = + projectIds.length > 0 + ? projectIds.map((id) => ({ type: 'project' as const, id })) + : organizationId + ? [{ type: 'organization' as const, id: organizationId }] + : []; + + if (scopes.length === 0) { + this.logger.warn('GCP: No projects selected and no Organization ID'); throw new Error( - 'Organization ID is required. Configure it in the integration variables.', + 'GCP_ORG_MISSING: No projects selected and Organization ID not detected. Go to the GCP integration settings to configure.', ); } - this.logger.log( - `Scanning GCP Security Command Center for org ${organizationId}`, - ); + const scopeLabel = + projectIds.length > 0 + ? `${projectIds.length} project(s): ${projectIds.join(', ')}` + : `org ${organizationId}`; + this.logger.log(`Scanning GCP SCC for ${scopeLabel}`); const allFindings: SecurityFinding[] = []; - let pageToken: string | undefined; + const enabledServiceSet = enabledServices ? new Set(enabledServices) : null; + const seenIds = new Set(); - do { - const response = await this.fetchSecurityFindings( - accessToken, - organizationId, - pageToken, - ); - - for (const result of response.findings) { - const finding = result.finding; - const severity = this.mapSeverity(finding.severity); - - allFindings.push({ - id: finding.name, - title: finding.category, - description: - finding.description || `Security finding: ${finding.category}`, - severity, - resourceType: 'gcp-resource', - resourceId: finding.resourceName, - remediation: - 'Review and remediate this finding in GCP Security Command Centre', - evidence: { - findingName: finding.name, - category: finding.category, - state: finding.state, - resourceName: finding.resourceName, - severity: finding.severity, - eventTime: finding.eventTime, - }, - createdAt: finding.createTime, - }); - } + for (const scope of scopes) { + try { + let pageToken: string | undefined; + do { + const response = await this.fetchFindings( + accessToken, + scope, + pageToken, + ); + + for (const result of response.findings) { + const f = result.finding; + // Deduplicate across project scopes + if (seenIds.has(f.name)) continue; + seenIds.add(f.name); - pageToken = response.nextPageToken; - } while (pageToken); + const serviceId = + CATEGORY_TO_SERVICE[f.category] ?? 'security-command-center'; + if (enabledServiceSet && !enabledServiceSet.has(serviceId)) { + continue; + } + const findingKey = `gcp-${serviceId}-${f.category.toLowerCase().replace(/_/g, '-')}`; + const remediation = this.buildRemediation(f); + + allFindings.push({ + id: f.name, + title: this.formatTitle(f.category), + description: + f.description || `Security finding: ${f.category}`, + severity: this.mapSeverity(f.severity), + resourceType: result.resource?.type ?? 'gcp-resource', + resourceId: f.resourceName, + remediation, + evidence: { + findingKey, + serviceId, + serviceName: SERVICE_NAMES[serviceId] ?? serviceId, + category: f.category, + state: f.state, + resourceName: f.resourceName, + severity: f.severity, + eventTime: f.eventTime, + externalUri: f.externalUri, + findingClass: f.findingClass, + compliances: f.compliances, + sourceProperties: f.sourceProperties, + projectDisplayName: result.resource?.projectDisplayName, + resourceDisplayName: result.resource?.displayName, + }, + createdAt: f.createTime, + }); + } + + pageToken = response.nextPageToken; + } while (pageToken); + } catch (err) { + // Log and continue with remaining projects — don't fail the whole scan + this.logger.warn( + `GCP SCC query failed for ${scope.type} ${scope.id}: ${err instanceof Error ? err.message : String(err)}`, + ); + } + } this.logger.log(`Found ${allFindings.length} GCP security findings`); return allFindings; } - private async fetchSecurityFindings( + private async fetchFindings( accessToken: string, - organizationId: string, + scope: { type: 'organization' | 'project'; id: string }, pageToken?: string, - ): Promise<{ findings: GCPFindingResult[]; nextPageToken?: string }> { + ): Promise<{ findings: SCCFindingResult[]; nextPageToken?: string }> { + const parent = + scope.type === 'project' + ? `projects/${scope.id}` + : `organizations/${scope.id}`; const url = new URL( - `https://securitycenter.googleapis.com/v2/organizations/${organizationId}/sources/-/findings`, + `https://securitycenter.googleapis.com/v2/${parent}/sources/-/findings`, ); - url.searchParams.set('pageSize', '100'); + url.searchParams.set('pageSize', '500'); url.searchParams.set('filter', 'state="ACTIVE"'); if (pageToken) { @@ -98,7 +1152,6 @@ export class GCPSecurityService { } const response = await fetch(url.toString(), { - method: 'GET', headers: { Authorization: `Bearer ${accessToken}`, 'Content-Type': 'application/json', @@ -107,21 +1160,29 @@ export class GCPSecurityService { if (!response.ok) { const errorText = await response.text(); - this.logger.error(`GCP API error: ${errorText}`); + this.logger.error(`GCP SCC API error: ${errorText}`); - // Parse and provide helpful error messages if (errorText.includes('ACCESS_TOKEN_SCOPE_INSUFFICIENT')) { throw new Error( - 'OAuth scopes insufficient. Please disconnect and reconnect the GCP integration.', + 'OAuth scopes insufficient. Reconnect the GCP integration.', + ); + } + if ( + errorText.includes('SERVICE_DISABLED') || + errorText.includes('has not been used') || + errorText.includes('Security Command Center API') + ) { + throw new Error( + 'SCC_NOT_ACTIVATED: Security Command Center is not activated on your GCP organization. ' + + 'Enable it at https://console.cloud.google.com/security/command-center — the Standard tier is free.', ); } - if ( errorText.includes('PERMISSION_DENIED') || errorText.includes('403') ) { throw new Error( - 'Permission denied. Grant the "Security Center Findings Viewer" role to your Google account at the organization level.', + 'Permission denied. Grant "Security Center Findings Viewer" role at the organization level.', ); } @@ -144,13 +1205,45 @@ export class GCPSecurityService { } const data = await response.json(); - return { - findings: data.listFindingsResults || [], + findings: data.listFindingsResults ?? [], nextPageToken: data.nextPageToken, }; } + /** Build remediation text from SCC's nextSteps + API context for AI auto-fix. */ + private buildRemediation(f: SCCFinding): string { + const parts: string[] = []; + + if (f.nextSteps) { + parts.push(f.nextSteps); + } + + if (f.externalUri) { + parts.push(`More info: ${f.externalUri}`); + } + + if (f.compliances?.length) { + const standards = f.compliances.map( + (c) => `${c.standard} ${c.version} (${c.ids.join(', ')})`, + ); + parts.push(`Compliance: ${standards.join('; ')}`); + } + + return ( + parts.join('\n\n') || + `Review and remediate this ${f.category} finding in GCP Console.` + ); + } + + /** Convert SCC SCREAMING_SNAKE_CASE category to readable title. */ + private formatTitle(category: string): string { + return category + .split('_') + .map((w) => w.charAt(0) + w.slice(1).toLowerCase()) + .join(' '); + } + private mapSeverity(gcpSeverity: string): SecurityFinding['severity'] { const map: Record = { CRITICAL: 'critical', @@ -158,6 +1251,6 @@ export class GCPSecurityService { MEDIUM: 'medium', LOW: 'low', }; - return map[gcpSeverity] || 'medium'; + return map[gcpSeverity] ?? 'medium'; } } diff --git a/apps/api/src/cloud-security/remediation-error.utils.spec.ts b/apps/api/src/cloud-security/remediation-error.utils.spec.ts new file mode 100644 index 0000000000..55d825bc98 --- /dev/null +++ b/apps/api/src/cloud-security/remediation-error.utils.spec.ts @@ -0,0 +1,60 @@ +import { parseAwsPermissionError } from './remediation-error.utils'; + +describe('parseAwsPermissionError', () => { + it('detects "required X permission" pattern', () => { + const msg = + 'The request was rejected because you do not have the required iam:CreateServiceLinkedRole permission.'; + const result = parseAwsPermissionError(msg); + expect(result.isPermissionError).toBe(true); + expect(result.missingActions).toContain('iam:CreateServiceLinkedRole'); + }); + + it('detects "not authorized to perform" pattern', () => { + const msg = + 'User: arn:aws:sts::123456789012:assumed-role/CompAI-Remediator/session is not authorized to perform: guardduty:CreateDetector on resource: *'; + const result = parseAwsPermissionError(msg); + expect(result.isPermissionError).toBe(true); + expect(result.missingActions).toContain('guardduty:CreateDetector'); + }); + + it('detects AccessDeniedException', () => { + const msg = + 'AccessDeniedException: User is not authorized to perform: kms:EnableKeyRotation'; + const result = parseAwsPermissionError(msg); + expect(result.isPermissionError).toBe(true); + expect(result.missingActions).toContain('kms:EnableKeyRotation'); + }); + + it('detects access denied with action', () => { + const msg = 'Access Denied for action: s3:PutBucketEncryption'; + const result = parseAwsPermissionError(msg); + expect(result.isPermissionError).toBe(true); + expect(result.missingActions).toContain('s3:PutBucketEncryption'); + }); + + it('detects permission error without extractable action', () => { + const msg = 'Access Denied'; + const result = parseAwsPermissionError(msg); + expect(result.isPermissionError).toBe(true); + expect(result.missingActions).toEqual([]); + }); + + it('returns false for non-permission errors', () => { + const msg = 'ResourceNotFoundException: Detector not found'; + const result = parseAwsPermissionError(msg); + expect(result.isPermissionError).toBe(false); + expect(result.missingActions).toEqual([]); + }); + + it('returns false for network errors', () => { + const msg = 'NetworkingError: connect ECONNREFUSED'; + const result = parseAwsPermissionError(msg); + expect(result.isPermissionError).toBe(false); + }); + + it('preserves rawMessage', () => { + const msg = 'some error with not authorized text'; + const result = parseAwsPermissionError(msg); + expect(result.rawMessage).toBe(msg); + }); +}); diff --git a/apps/api/src/cloud-security/remediation-error.utils.ts b/apps/api/src/cloud-security/remediation-error.utils.ts new file mode 100644 index 0000000000..72647c4a47 --- /dev/null +++ b/apps/api/src/cloud-security/remediation-error.utils.ts @@ -0,0 +1,221 @@ +export interface PermissionErrorInfo { + isPermissionError: boolean; + missingActions: string[]; + rawMessage: string; +} + +export interface GcpPermissionErrorInfo { + isPermissionError: boolean; + missingPermissions: string[]; + suggestedRole: string | null; + fixScript: string | null; + rawMessage: string; +} + +const PERMISSION_KEYWORDS = [ + 'not authorized', + 'accessdenied', + 'accessdeniedexception', + 'access denied', + 'unauthorizedaccess', + 'do not have the required', + 'forbidden', +] as const; + +/** + * Patterns to extract the specific IAM action from AWS error messages. + * Each pattern should have a capture group for the action string. + */ +const ACTION_PATTERNS: RegExp[] = [ + // "is not authorized to perform: iam:CreateServiceLinkedRole on resource" + /not authorized to perform:\s*([\w:*]+)/i, + // "you do not have the required iam:CreateServiceLinkedRole permission" + /required\s+([\w:*]+)\s+permission/i, + // "User ... is not authorized to perform: ec2:DescribeInstances" + /not authorized to perform:\s*([\w:*]+)/i, + // "Access Denied for action: s3:PutBucketEncryption" + /denied.*?(?:action|for):\s*([\w:*]+)/i, + // "UnauthorizedAccess: guardduty:CreateDetector" + /UnauthorizedAccess.*?([\w]+:[\w*]+)/i, +]; + +/** + * Parse an AWS error message to detect permission errors and extract + * the specific missing IAM action(s). + * + * Gracefully degrades: if it detects a permission error but cannot + * extract the action, `missingActions` will be empty. + */ +export function parseAwsPermissionError( + errorMessage: string, +): PermissionErrorInfo { + const lower = errorMessage.toLowerCase(); + const isPermissionError = PERMISSION_KEYWORDS.some((kw) => + lower.includes(kw), + ); + + if (!isPermissionError) { + return { + isPermissionError: false, + missingActions: [], + rawMessage: errorMessage, + }; + } + + const actions = new Set(); + for (const pattern of ACTION_PATTERNS) { + const match = errorMessage.match(pattern); + if (match?.[1]) { + actions.add(match[1]); + } + } + + return { + isPermissionError: true, + missingActions: [...actions], + rawMessage: errorMessage, + }; +} + +// ─── GCP Permission Error Parsing ────────────────────────────────────────── + +/** Map GCP permission prefixes to recommended predefined roles. */ +const GCP_PERMISSION_TO_ROLE: Array<{ prefix: string; role: string }> = [ + { prefix: 'storage.', role: 'roles/storage.admin' }, + { prefix: 'compute.firewalls', role: 'roles/compute.securityAdmin' }, + { prefix: 'compute.instances', role: 'roles/compute.instanceAdmin.v1' }, + { prefix: 'compute.subnetworks', role: 'roles/compute.networkAdmin' }, + { prefix: 'compute.networks', role: 'roles/compute.networkAdmin' }, + { prefix: 'compute.', role: 'roles/compute.admin' }, + { prefix: 'cloudsql.', role: 'roles/cloudsql.admin' }, + { prefix: 'cloudkms.', role: 'roles/cloudkms.admin' }, + { prefix: 'logging.', role: 'roles/logging.admin' }, + { prefix: 'dns.', role: 'roles/dns.admin' }, + { prefix: 'container.', role: 'roles/container.admin' }, + { prefix: 'iam.', role: 'roles/iam.securityAdmin' }, + { prefix: 'resourcemanager.', role: 'roles/resourcemanager.projectIamAdmin' }, + { prefix: 'pubsub.', role: 'roles/pubsub.admin' }, + { prefix: 'bigquery.', role: 'roles/bigquery.admin' }, +]; + +/** GCP permission extraction patterns. */ +const GCP_PERMISSION_PATTERNS: RegExp[] = [ + // "Permission denied: caller does not have permission 'storage.buckets.update'" + /permission\s+'([\w.]+)'/i, + // From metadata: "permission": "storage.buckets.update" + /"permission":\s*"([\w.]+)"/i, + // "required permission(s): storage.buckets.update" + /required permission[s]?:\s*([\w.]+)/i, + // GCP format: "does not have storage.buckets.update access" + /does not have\s+([\w.]+)\s+access/i, + // Inline: Permission 'compute.firewalls.update' denied + /'([\w.]+)'\s*denied/i, +]; + +/** + * Parse a GCP API error to detect permission errors, extract the missing + * permission, suggest a role, and generate a ready-to-paste gcloud command. + */ +export function parseGcpPermissionError( + errorMessage: string, + projectId?: string, +): GcpPermissionErrorInfo { + const lower = errorMessage.toLowerCase(); + const isPermissionError = + lower.includes('permission_denied') || + lower.includes('permission denied') || + lower.includes('does not have') || + (lower.includes('403') && lower.includes('permission')); + + if (!isPermissionError) { + return { + isPermissionError: false, + missingPermissions: [], + suggestedRole: null, + fixScript: null, + rawMessage: errorMessage, + }; + } + + // Extract the specific permission + const permissions = new Set(); + for (const pattern of GCP_PERMISSION_PATTERNS) { + const match = errorMessage.match(pattern); + if (match?.[1]) permissions.add(match[1]); + } + + // Find best matching role + const permList = [...permissions]; + let suggestedRole: string | null = null; + for (const perm of permList) { + const entry = GCP_PERMISSION_TO_ROLE.find((r) => perm.startsWith(r.prefix)); + if (entry) { + suggestedRole = entry.role; + break; + } + } + + // Build gcloud fix script + let fixScript: string | null = null; + if (suggestedRole) { + const project = projectId ?? 'YOUR_PROJECT_ID'; + fixScript = [ + 'gcloud projects add-iam-policy-binding ' + project, + " --member='user:YOUR_EMAIL'", + ` --role='${suggestedRole}'`, + ].join(' \\\n'); + } + + return { + isPermissionError: true, + missingPermissions: permList, + suggestedRole, + fixScript, + rawMessage: errorMessage, + }; +} + +export interface AzurePermissionErrorInfo { + isPermissionError: boolean; + missingActions: string[]; + fixScript: string | null; + rawMessage: string; +} + +/** + * Parse an Azure API error to detect permission (403/AuthorizationFailed) errors. + */ +export function parseAzurePermissionError( + errorMessage: string, +): AzurePermissionErrorInfo | null { + const lower = errorMessage.toLowerCase(); + const isPermissionError = + lower.includes('authorizationfailed') || + lower.includes('authorization failed') || + lower.includes('403') || + lower.includes('does not have authorization') || + lower.includes('forbidden'); + + if (!isPermissionError) return null; + + // Try to extract action from Azure error: "does not have authorization to perform action 'X' over scope" + const actionMatch = errorMessage.match(/perform action '([^']+)'/); + const missingActions = actionMatch ? [actionMatch[1]] : []; + + const fixScript = + missingActions.length > 0 + ? [ + 'az role assignment create \\', + " --assignee '' \\", + " --role 'Contributor' \\", + " --scope '/subscriptions/'", + ].join('\n') + : null; + + return { + isPermissionError: true, + missingActions, + fixScript, + rawMessage: errorMessage, + }; +} diff --git a/apps/api/src/cloud-security/remediation.controller.spec.ts b/apps/api/src/cloud-security/remediation.controller.spec.ts new file mode 100644 index 0000000000..06142d4650 --- /dev/null +++ b/apps/api/src/cloud-security/remediation.controller.spec.ts @@ -0,0 +1,305 @@ +import { Test, TestingModule } from '@nestjs/testing'; +import { HttpException, HttpStatus } from '@nestjs/common'; +import { RemediationController } from './remediation.controller'; +import { RemediationService } from './remediation.service'; +import { HybridAuthGuard } from '../auth/hybrid-auth.guard'; +import { PermissionGuard } from '../auth/permission.guard'; + +// Mock auth.server to avoid importing better-auth ESM in Jest +jest.mock('../auth/auth.server', () => ({ + auth: { api: { getSession: jest.fn() } }, +})); + +jest.mock('@trycompai/auth', () => ({ + statement: {}, + BUILT_IN_ROLE_PERMISSIONS: {}, +})); + +jest.mock('./cloud-security-audit', () => ({ + logCloudSecurityActivity: jest.fn().mockResolvedValue(undefined), +})); + +describe('RemediationController', () => { + let controller: RemediationController; + let service: jest.Mocked; + + const mockService = { + getCapabilities: jest.fn(), + previewRemediation: jest.fn(), + executeRemediation: jest.fn(), + rollbackRemediation: jest.fn(), + getActions: jest.fn(), + }; + + const mockGuard = { canActivate: jest.fn().mockReturnValue(true) }; + + const orgId = 'org_123'; + const userId = 'usr_456'; + const connectionId = 'conn_789'; + + beforeEach(async () => { + const module: TestingModule = await Test.createTestingModule({ + controllers: [RemediationController], + providers: [{ provide: RemediationService, useValue: mockService }], + }) + .overrideGuard(HybridAuthGuard) + .useValue(mockGuard) + .overrideGuard(PermissionGuard) + .useValue(mockGuard) + .compile(); + + controller = module.get(RemediationController); + service = module.get(RemediationService); + + jest.clearAllMocks(); + }); + + describe('getCapabilities', () => { + it('should call service with connectionId and organizationId', async () => { + const capabilities = { + enabled: true, + remediations: [{ remediationKey: 's3-block-public-access' }], + }; + mockService.getCapabilities.mockResolvedValue(capabilities); + + const result = await controller.getCapabilities(connectionId, orgId); + + expect(service.getCapabilities).toHaveBeenCalledWith({ + connectionId, + organizationId: orgId, + }); + expect(result).toEqual(capabilities); + }); + + it('should throw BAD_REQUEST when connectionId is missing', async () => { + await expect(controller.getCapabilities('', orgId)).rejects.toThrow( + HttpException, + ); + + await expect(controller.getCapabilities('', orgId)).rejects.toMatchObject( + { + status: HttpStatus.BAD_REQUEST, + }, + ); + + expect(service.getCapabilities).not.toHaveBeenCalled(); + }); + + it('should throw BAD_REQUEST when service throws', async () => { + mockService.getCapabilities.mockRejectedValue( + new Error('Connection not found'), + ); + + await expect( + controller.getCapabilities(connectionId, orgId), + ).rejects.toThrow(HttpException); + }); + }); + + describe('preview', () => { + const body = { + connectionId, + checkResultId: 'cr_001', + remediationKey: 's3-block-public-access', + }; + + it('should call service with body params and organizationId', async () => { + const preview = { + description: 'Will block public access', + risk: 'low', + apiCalls: ['s3:PutPublicAccessBlock'], + }; + mockService.previewRemediation.mockResolvedValue(preview); + + const result = await controller.preview(body, orgId); + + expect(service.previewRemediation).toHaveBeenCalledWith({ + connectionId: body.connectionId, + organizationId: orgId, + checkResultId: body.checkResultId, + remediationKey: body.remediationKey, + }); + expect(result).toEqual(preview); + }); + + it('should throw BAD_REQUEST when service throws', async () => { + mockService.previewRemediation.mockRejectedValue( + new Error('Finding not found'), + ); + + await expect(controller.preview(body, orgId)).rejects.toThrow( + HttpException, + ); + }); + }); + + describe('execute', () => { + const body = { + connectionId, + checkResultId: 'cr_001', + remediationKey: 's3-block-public-access', + }; + + it('should call service with body params, organizationId, and userId', async () => { + const result = { + actionId: 'act_001', + status: 'success' as const, + resourceId: 'my-bucket', + previousState: { publicAccess: true }, + appliedState: { publicAccess: false }, + }; + mockService.executeRemediation.mockResolvedValue(result); + + const response = await controller.execute(body, orgId, userId); + + expect(service.executeRemediation).toHaveBeenCalledWith({ + connectionId: body.connectionId, + organizationId: orgId, + checkResultId: body.checkResultId, + remediationKey: body.remediationKey, + userId, + }); + expect(response).toEqual(result); + }); + + it('should pass acknowledgment to service', async () => { + const bodyWithAck = { + ...body, + acknowledgment: 'acknowledged', + }; + const result = { + actionId: 'act_001', + status: 'success' as const, + resourceId: 'my-resource', + previousState: { subscriptionState: 'INACTIVE' }, + appliedState: { subscriptionState: 'ACTIVE' }, + }; + mockService.executeRemediation.mockResolvedValue(result); + + const response = await controller.execute(bodyWithAck, orgId, userId); + + expect(service.executeRemediation).toHaveBeenCalledWith({ + connectionId: body.connectionId, + organizationId: orgId, + checkResultId: body.checkResultId, + remediationKey: body.remediationKey, + userId, + acknowledgment: 'acknowledged', + }); + expect(response).toEqual(result); + }); + + it('should pass type-to-confirm acknowledgment to service', async () => { + const bodyWithTypeConfirm = { + ...body, + acknowledgment: 'enable shield advanced', + }; + const result = { + actionId: 'act_002', + status: 'success' as const, + resourceId: 'my-resource', + previousState: { subscriptionState: 'INACTIVE' }, + appliedState: { subscriptionState: 'ACTIVE' }, + }; + mockService.executeRemediation.mockResolvedValue(result); + + const response = await controller.execute( + bodyWithTypeConfirm, + orgId, + userId, + ); + + expect(service.executeRemediation).toHaveBeenCalledWith({ + connectionId: body.connectionId, + organizationId: orgId, + checkResultId: body.checkResultId, + remediationKey: body.remediationKey, + userId, + acknowledgment: 'enable shield advanced', + }); + expect(response).toEqual(result); + }); + + it('should throw BAD_REQUEST when service throws', async () => { + mockService.executeRemediation.mockRejectedValue( + new Error('No credentials found'), + ); + + await expect(controller.execute(body, orgId, userId)).rejects.toThrow( + HttpException, + ); + }); + }); + + describe('rollback', () => { + const actionId = 'act_001'; + + it('should call service with actionId and organizationId', async () => { + const rollbackResult = { + status: 'rolled_back' as const, + connectionId: 'conn_789', + remediationKey: 's3-block-public-access', + resourceId: 'my-bucket', + }; + mockService.rollbackRemediation.mockResolvedValue(rollbackResult); + + const result = await controller.rollback(actionId, orgId, userId); + + expect(service.rollbackRemediation).toHaveBeenCalledWith({ + actionId, + organizationId: orgId, + }); + expect(result).toEqual({ status: 'rolled_back' }); + }); + + it('should throw BAD_REQUEST when service throws', async () => { + mockService.rollbackRemediation.mockRejectedValue( + new Error('Remediation action not found'), + ); + + await expect( + controller.rollback(actionId, orgId, userId), + ).rejects.toThrow(HttpException); + }); + }); + + describe('getActions', () => { + it('should return actions with count from service', async () => { + const actions = [ + { id: 'act_001', status: 'success' }, + { id: 'act_002', status: 'failed' }, + ]; + mockService.getActions.mockResolvedValue(actions); + + const result = await controller.getActions(connectionId, orgId); + + expect(service.getActions).toHaveBeenCalledWith({ + connectionId, + organizationId: orgId, + }); + expect(result).toEqual({ data: actions, count: 2 }); + }); + + it('should throw BAD_REQUEST when connectionId is missing', async () => { + await expect(controller.getActions('', orgId)).rejects.toThrow( + HttpException, + ); + + await expect(controller.getActions('', orgId)).rejects.toMatchObject({ + status: HttpStatus.BAD_REQUEST, + }); + + expect(service.getActions).not.toHaveBeenCalled(); + }); + + it('should throw BAD_REQUEST when service throws', async () => { + mockService.getActions.mockRejectedValue( + new Error('Connection not found'), + ); + + await expect(controller.getActions(connectionId, orgId)).rejects.toThrow( + HttpException, + ); + }); + }); +}); diff --git a/apps/api/src/cloud-security/remediation.controller.ts b/apps/api/src/cloud-security/remediation.controller.ts new file mode 100644 index 0000000000..ca7e9f82c0 --- /dev/null +++ b/apps/api/src/cloud-security/remediation.controller.ts @@ -0,0 +1,374 @@ +import { + Controller, + Post, + Get, + Patch, + Param, + Query, + Body, + Logger, + HttpException, + HttpStatus, + UseGuards, +} from '@nestjs/common'; +import { SkipThrottle } from '@nestjs/throttler'; +import { db } from '@db'; +import { HybridAuthGuard } from '../auth/hybrid-auth.guard'; +import { PermissionGuard } from '../auth/permission.guard'; +import { RequirePermission } from '../auth/require-permission.decorator'; +import { OrganizationId, UserId } from '../auth/auth-context.decorator'; +import { RemediationService } from './remediation.service'; +import { logCloudSecurityActivity } from './cloud-security-audit'; + +@Controller({ path: 'cloud-security/remediation', version: '1' }) +@UseGuards(HybridAuthGuard, PermissionGuard) +export class RemediationController { + private readonly logger = new Logger(RemediationController.name); + + constructor(private readonly remediationService: RemediationService) {} + + @Get('capabilities') + @SkipThrottle() + @RequirePermission('integration', 'read') + async getCapabilities( + @Query('connectionId') connectionId: string, + @OrganizationId() organizationId: string, + ) { + if (!connectionId) { + throw new HttpException( + 'connectionId query parameter is required', + HttpStatus.BAD_REQUEST, + ); + } + + try { + return await this.remediationService.getCapabilities({ + connectionId, + organizationId, + }); + } catch (error) { + const message = + error instanceof Error ? error.message : 'Failed to get capabilities'; + throw new HttpException(message, HttpStatus.BAD_REQUEST); + } + } + + @Post('preview') + @RequirePermission('integration', 'update') + async preview( + @Body() + body: { + connectionId: string; + checkResultId: string; + remediationKey: string; + cachedPermissions?: string[]; + }, + @OrganizationId() organizationId: string, + ) { + try { + return await this.remediationService.previewRemediation({ + connectionId: body.connectionId, + organizationId, + checkResultId: body.checkResultId, + remediationKey: body.remediationKey, + cachedPermissions: body.cachedPermissions, + }); + } catch (error) { + const message = error instanceof Error ? error.message : 'Preview failed'; + this.logger.error(`Remediation preview failed: ${message}`); + throw new HttpException(message, HttpStatus.BAD_REQUEST); + } + } + + @Post('execute') + @RequirePermission('integration', 'update') + async execute( + @Body() + body: { + connectionId: string; + checkResultId: string; + remediationKey: string; + acknowledgment?: string; + }, + @OrganizationId() organizationId: string, + @UserId() userId: string, + ) { + try { + const result = await this.remediationService.executeRemediation({ + connectionId: body.connectionId, + organizationId, + checkResultId: body.checkResultId, + remediationKey: body.remediationKey, + userId, + acknowledgment: body.acknowledgment, + }); + + if (result.status === 'success') { + await logCloudSecurityActivity({ + organizationId, + userId, + connectionId: body.connectionId, + action: 'remediation_executed', + description: `Applied auto-fix: ${body.remediationKey} on ${result.resourceId}`, + metadata: { + remediationKey: body.remediationKey, + actionId: result.actionId, + resourceId: result.resourceId, + acknowledgmentText: body.acknowledgment, + acknowledgedBy: userId, + acknowledgedAt: new Date().toISOString(), + previousState: result.previousState, + appliedState: result.appliedState, + verified: (result.appliedState as Record) + ?.verified, + }, + }); + } else if (result.status === 'failed') { + await logCloudSecurityActivity({ + organizationId, + userId, + connectionId: body.connectionId, + action: 'remediation_failed', + description: `Auto-fix failed: ${body.remediationKey} on ${result.resourceId} — ${result.error}`, + metadata: { + remediationKey: body.remediationKey, + actionId: result.actionId, + resourceId: result.resourceId, + acknowledgmentText: body.acknowledgment, + acknowledgedBy: userId, + error: result.error, + }, + }); + } else { + await logCloudSecurityActivity({ + organizationId, + userId, + connectionId: body.connectionId, + action: 'remediation_failed', + description: `Auto-fix did not succeed: ${body.remediationKey} on ${result.resourceId} (status ${result.status})`, + metadata: { + remediationKey: body.remediationKey, + actionId: result.actionId, + resourceId: result.resourceId, + acknowledgmentText: body.acknowledgment, + acknowledgedBy: userId, + status: result.status, + }, + }); + } + + return result; + } catch (error) { + const message = + error instanceof Error ? error.message : 'Execution failed'; + this.logger.error(`Remediation execution failed: ${message}`); + throw new HttpException(message, HttpStatus.BAD_REQUEST); + } + } + + @Post(':actionId/rollback') + @RequirePermission('integration', 'update') + async rollback( + @Param('actionId') actionId: string, + @OrganizationId() organizationId: string, + @UserId() userId: string, + ) { + try { + const result = await this.remediationService.rollbackRemediation({ + actionId, + organizationId, + }); + + const isSuccess = result.status === 'rolled_back'; + await logCloudSecurityActivity({ + organizationId, + userId, + connectionId: result.connectionId, + action: isSuccess ? 'rollback_executed' : 'rollback_failed', + description: isSuccess + ? `Rolled back: ${result.remediationKey} on ${result.resourceId}` + : `Rollback failed: ${result.remediationKey} on ${result.resourceId} — ${(result as { error?: string }).error}`, + metadata: { + actionId, + remediationKey: result.remediationKey, + resourceId: result.resourceId, + status: result.status, + rolledBackBy: userId, + rolledBackAt: new Date().toISOString(), + ...((result as { error?: string }).error && { + error: (result as { error?: string }).error, + }), + }, + }); + + return result; + } catch (error) { + const raw = error instanceof Error ? error.message : 'Rollback failed'; + this.logger.error(`Remediation rollback failed: ${raw}`); + + // Log the failure to audit trail + await logCloudSecurityActivity({ + organizationId, + userId, + connectionId: actionId, // best effort — action ID as fallback + action: 'rollback_failed', + description: `Rollback failed for action ${actionId}: ${raw}`, + metadata: { actionId, error: raw, rolledBackBy: userId }, + }).catch(() => {}); // don't let audit log failure block the response + + // Try to parse structured permission error + try { + const parsed = JSON.parse(raw); + if (parsed.missingActions) { + throw new HttpException( + { + message: parsed.message, + missingActions: parsed.missingActions, + script: parsed.script, + }, + HttpStatus.BAD_REQUEST, + ); + } + } catch (parseErr) { + if (parseErr instanceof HttpException) throw parseErr; + } + + throw new HttpException(raw, HttpStatus.BAD_REQUEST); + } + } + + @Get('actions') + @RequirePermission('integration', 'read') + async getActions( + @Query('connectionId') connectionId: string, + @OrganizationId() organizationId: string, + ) { + if (!connectionId) { + throw new HttpException( + 'connectionId query parameter is required', + HttpStatus.BAD_REQUEST, + ); + } + + try { + const actions = await this.remediationService.getActions({ + connectionId, + organizationId, + }); + return { data: actions, count: actions.length }; + } catch (error) { + const message = + error instanceof Error ? error.message : 'Failed to get actions'; + throw new HttpException(message, HttpStatus.BAD_REQUEST); + } + } + + // ─── Batch endpoints ────────────────────────────────────────────── + + /** Get active batch for a connection (if any). */ + @Get('batch/active') + @RequirePermission('integration', 'read') + async getActiveBatch( + @Query('connectionId') connectionId: string, + @OrganizationId() organizationId: string, + ) { + const batch = await db.remediationBatch.findFirst({ + where: { + connectionId, + organizationId, + status: { in: ['pending', 'running'] }, + }, + orderBy: { createdAt: 'desc' }, + }); + return { data: batch }; + } + + /** Create a new batch record (called before triggering the task). */ + @Post('batch') + @RequirePermission('integration', 'update') + async createBatch( + @Body() + body: { + connectionId: string; + findings: Array<{ id: string; key: string; title: string }>; + }, + @OrganizationId() organizationId: string, + @UserId() userId: string, + ) { + const findings = body.findings.map((f) => ({ + id: f.id, + key: f.key, + title: f.title, + status: 'pending', + })); + + const batch = await db.remediationBatch.create({ + data: { + connectionId: body.connectionId, + organizationId, + initiatedById: userId, + status: 'pending', + findings, + }, + }); + + await logCloudSecurityActivity({ + organizationId, + userId, + connectionId: body.connectionId, + action: 'remediation_executed', + description: `Started batch fix: ${body.findings.length} findings`, + metadata: { batchId: batch.id, findingCount: body.findings.length }, + }); + + return { data: batch }; + } + + /** Update a batch (set triggerRunId after task starts). */ + @Patch('batch/:batchId') + @RequirePermission('integration', 'update') + async updateBatch( + @Param('batchId') batchId: string, + @Body() body: { triggerRunId?: string; status?: string }, + @OrganizationId() organizationId: string, + ) { + const batch = await db.remediationBatch.update({ + where: { id: batchId, organizationId }, + data: { + ...(body.triggerRunId && { triggerRunId: body.triggerRunId }), + ...(body.status && { status: body.status }), + }, + }); + return { data: batch }; + } + + /** Skip a specific finding in an active batch. */ + @Post('batch/:batchId/skip/:findingId') + @RequirePermission('integration', 'update') + async skipFinding( + @Param('batchId') batchId: string, + @Param('findingId') findingId: string, + @OrganizationId() organizationId: string, + ) { + const batch = await db.remediationBatch.findFirst({ + where: { id: batchId, organizationId }, + }); + if (!batch) { + throw new HttpException('Batch not found', HttpStatus.NOT_FOUND); + } + + const findings = batch.findings as Array<{ id: string; status: string }>; + const updated = findings.map((f) => + f.id === findingId && f.status === 'pending' + ? { ...f, status: 'cancelled' } + : f, + ); + + await db.remediationBatch.update({ + where: { id: batchId }, + data: { findings: updated }, + }); + + return { success: true }; + } +} diff --git a/apps/api/src/cloud-security/remediation.service.ts b/apps/api/src/cloud-security/remediation.service.ts new file mode 100644 index 0000000000..65d257bbad --- /dev/null +++ b/apps/api/src/cloud-security/remediation.service.ts @@ -0,0 +1,1009 @@ +import { Injectable, Logger } from '@nestjs/common'; +import { db, Prisma } from '@db'; +import { CredentialVaultService } from '../integration-platform/services/credential-vault.service'; +import { parseAwsPermissionError } from './remediation-error.utils'; +import { AWSSecurityService } from './providers/aws-security.service'; +import { AiRemediationService } from './ai-remediation.service'; +import { GcpRemediationService } from './gcp-remediation.service'; +import { AzureRemediationService } from './azure-remediation.service'; +import { + executeAwsCommand, + executePlanSteps, + validatePlanSteps, +} from './aws-command-executor'; +import type { FixPlan, AwsCommandStep } from './ai-remediation.prompt'; + +@Injectable() +export class RemediationService { + private readonly logger = new Logger(RemediationService.name); + /** Cache fix plans between preview and execute to avoid double AI calls. */ + private readonly planCache = new Map< + string, + { plan: FixPlan; timestamp: number; permissionsList?: string[] } + >(); + private readonly PLAN_CACHE_MAX = 100; + private readonly PLAN_CACHE_TTL = 5 * 60 * 1000; + + private evictStalePlans() { + if (this.planCache.size <= this.PLAN_CACHE_MAX) return; + const now = Date.now(); + for (const [key, entry] of this.planCache) { + if (now - entry.timestamp > this.PLAN_CACHE_TTL) + this.planCache.delete(key); + } + // If still over limit, delete oldest + while (this.planCache.size > this.PLAN_CACHE_MAX) { + const firstKey = this.planCache.keys().next().value; + if (firstKey) this.planCache.delete(firstKey); + else break; + } + } + + constructor( + private readonly credentialVaultService: CredentialVaultService, + private readonly awsSecurityService: AWSSecurityService, + private readonly aiRemediationService: AiRemediationService, + private readonly gcpRemediationService: GcpRemediationService, + private readonly azureRemediationService: AzureRemediationService, + ) {} + + async getCapabilities(params: { + connectionId: string; + organizationId: string; + }) { + const connection = await this.getConnection(params); + + if (connection.provider.slug === 'gcp') { + return this.gcpRemediationService.getCapabilities(params); + } + + if (connection.provider.slug === 'azure') { + return this.azureRemediationService.getCapabilities(params); + } + + if (connection.provider.slug !== 'aws') { + return { enabled: false, remediations: [] }; + } + + const credentials = + await this.credentialVaultService.getDecryptedCredentials( + params.connectionId, + ); + + return { + enabled: Boolean(credentials?.remediationRoleArn), + aiPowered: true, + remediations: [], + }; + } + + async previewRemediation(params: { + connectionId: string; + organizationId: string; + checkResultId: string; + remediationKey: string; + cachedPermissions?: string[]; + }) { + // Delegate GCP/Azure to dedicated services + const connection = await this.getConnection(params); + if (connection.provider.slug === 'gcp') { + return this.gcpRemediationService.previewRemediation(params); + } + if (connection.provider.slug === 'azure') { + return this.azureRemediationService.previewRemediation(params); + } + + const { finding, credentials, region } = await this.resolveContext(params); + + const evidence = (finding.evidence ?? {}) as Record; + const findingKey = evidence.findingKey as string; + + // RECHECK MODE: if frontend sends cachedPermissions, skip AI entirely + // Just re-read the role and compare against the SAME list + if (params.cachedPermissions && params.cachedPermissions.length > 0) { + this.logger.log( + `Recheck mode: checking ${params.cachedPermissions.length} cached permissions: ${params.cachedPermissions.slice(0, 5).join(', ')}...`, + ); + const remediationCreds = + await this.awsSecurityService.assumeRemediationRole( + credentials, + region, + ); + let missingPermissions: string[] | undefined; + let permissionFixScript: string | undefined; + try { + const existingActions = await this.getExistingRolePermissions( + remediationCreds, + region, + ); + this.logger.log(`Role has ${existingActions.size} actions`); + const missing = params.cachedPermissions.filter( + (p) => !this.isPermissionCovered(p, existingActions), + ); + if (missing.length > 0) { + missingPermissions = missing; + // Always include ALL cached permissions in script — not just missing ones + // This prevents overwrite issues with IAM eventual consistency + permissionFixScript = this.buildStaticPermissionScript( + params.cachedPermissions, + ); + } + } catch (err) { + this.logger.warn( + `Cannot read role policies on recheck: ${err instanceof Error ? err.message : String(err)}`, + ); + missingPermissions = params.cachedPermissions; + } + + // Return cached plan data with updated permission status + const cached = this.planCache.get( + `${params.connectionId}:${params.checkResultId}:${params.remediationKey}`, + ); + const cachedPlan = cached?.plan; + + return { + currentState: cachedPlan?.currentState ?? {}, + proposedState: cachedPlan?.proposedState ?? {}, + description: cachedPlan?.description ?? 'Recheck permissions', + risk: cachedPlan?.risk ?? 'medium', + apiCalls: cachedPlan?.requiredPermissions ?? params.cachedPermissions, + guidedOnly: false, + rollbackSupported: cachedPlan?.rollbackSupported ?? true, + requiresAcknowledgment: 'checkbox' as const, + acknowledgmentMessage: + 'This fix will modify your AWS infrastructure. Please review the changes above before proceeding.', + allRequiredPermissions: params.cachedPermissions, + ...(missingPermissions && + missingPermissions.length > 0 && { + missingPermissions, + permissionFixScript, + }), + }; + } + + const plan = await this.aiRemediationService.generateFixPlan({ + title: finding.title ?? 'Unknown', + description: finding.description, + severity: finding.severity, + resourceType: finding.resourceType, + resourceId: finding.resourceId, + remediation: finding.remediation, + findingKey, + evidence, + }); + + if (!plan.canAutoFix) { + return { + currentState: plan.currentState, + proposedState: {}, + description: plan.description, + risk: plan.risk, + apiCalls: [], + guidedOnly: true, + guidedSteps: plan.guidedSteps ?? [plan.reason ?? plan.description], + rollbackSupported: false, + requiresAcknowledgment: undefined, + }; + } + + // If plan has read steps, execute them now to get REAL state and refine the plan + if (plan.readSteps.length > 0) { + const readErrors = validatePlanSteps(plan.readSteps); + if (readErrors.length === 0) { + try { + const remediationCreds = + await this.awsSecurityService.assumeRemediationRole( + credentials, + region, + ); + const readResult = await executePlanSteps({ + steps: plan.readSteps, + credentials: remediationCreds, + region, + }); + const realState = readResult.results.reduce( + (acc, r) => ({ ...acc, [r.step.purpose]: r.output }), + {} as Record, + ); + + // Refine plan with real data + const refined = await this.aiRemediationService.refineFixPlan({ + finding: { + title: finding.title ?? 'Unknown', + description: finding.description, + severity: finding.severity, + resourceType: finding.resourceType, + resourceId: finding.resourceId, + remediation: finding.remediation, + findingKey, + evidence, + }, + originalPlan: plan, + realAwsState: realState, + }); + + // If AI now says it can't auto-fix, show guided steps + if (!refined.canAutoFix) { + return { + currentState: refined.currentState, + proposedState: {}, + description: refined.description, + risk: refined.risk, + apiCalls: [], + guidedOnly: true, + guidedSteps: refined.guidedSteps ?? [ + refined.reason ?? refined.description, + ], + rollbackSupported: false, + requiresAcknowledgment: undefined, + }; + } + + // Build the COMPLETE permission list from ALL sources + const aiPermissions = + await this.aiRemediationService.analyzeRequiredPermissions(refined); + + // Merge: AI analysis + refined plan's requiredPermissions + derived from commands + const allPerms = new Set([ + ...aiPermissions, + ...refined.requiredPermissions, + ]); + + // Also derive from actual step commands + const svcMap: Record = { + s3: 's3', + logs: 'logs', + 'cloudwatch-logs': 'logs', + cloudtrail: 'cloudtrail', + cloudwatch: 'cloudwatch', + iam: 'iam', + sns: 'sns', + ec2: 'ec2', + rds: 'rds', + kms: 'kms', + 'config-service': 'config', + guardduty: 'guardduty', + lambda: 'lambda', + dynamodb: 'dynamodb', + cloudfront: 'cloudfront', + }; + for (const step of [...refined.readSteps, ...refined.fixSteps]) { + const iamSvc = svcMap[step.service] ?? step.service; + // Resolve the REAL command name from the SDK (handles AI fuzzy names) + const realAction = this.resolveRealActionName( + step.service, + step.command, + ); + allPerms.add(`${iamSvc}:${realAction}`); + } + // Always add iam:PassRole if any role is being used + const allStepStr = JSON.stringify([...refined.fixSteps]); + if (allStepStr.includes('Role') || allStepStr.includes('role')) { + allPerms.add('iam:PassRole'); + } + + // Filter out dangerous + unnecessary + const dangerousActions = /Delete|Remove|Terminate|Deregister/i; + const permissionsList = [...allPerms] + .filter((p) => !dangerousActions.test(p.split(':')[1] ?? '')) + .filter( + (p) => p !== 'sts:GetCallerIdentity' && p !== 'sts:AssumeRole', + ) + .sort(); + // Check permissions by reading the ACTUAL policies on CompAI-Remediator + let missingPermissions: string[] | undefined; + let permissionFixScript: string | undefined; + try { + const existingActions = await this.getExistingRolePermissions( + remediationCreds, + region, + ); + this.logger.log( + `CompAI-Remediator has ${existingActions.size} actions. Needed: ${permissionsList.length}`, + ); + const missing = permissionsList.filter( + (p) => !this.isPermissionCovered(p, existingActions), + ); + if (missing.length > 0) { + this.logger.log( + `Missing ${missing.length} permissions: ${missing.join(', ')}`, + ); + missingPermissions = missing; + permissionFixScript = + this.buildStaticPermissionScript(permissionsList); + } + } catch (err) { + this.logger.warn( + `Cannot read role policies: ${err instanceof Error ? err.message : String(err)}`, + ); + missingPermissions = permissionsList; + permissionFixScript = + this.buildStaticPermissionScript(permissionsList); + } + + // Cache the refined plan + permissions for execute and Recheck + this.evictStalePlans(); + this.planCache.set( + `${params.connectionId}:${params.checkResultId}:${params.remediationKey}`, + { plan: refined, timestamp: Date.now(), permissionsList }, + ); + + return { + currentState: refined.currentState, + proposedState: refined.proposedState, + description: refined.description, + risk: refined.risk, + apiCalls: refined.requiredPermissions, + guidedOnly: false, + rollbackSupported: refined.rollbackSupported, + requiresAcknowledgment: 'checkbox' as const, + acknowledgmentMessage: + 'This fix will modify your AWS infrastructure. Please review the changes above before proceeding.', + allRequiredPermissions: permissionsList, + ...(missingPermissions && + missingPermissions.length > 0 && { + missingPermissions, + permissionFixScript, + }), + }; + } catch { + // If read fails, fall through to show the AI's initial plan + } + } + } + + // Fallback: show initial AI plan without real data + this.evictStalePlans(); + this.planCache.set( + `${params.connectionId}:${params.checkResultId}:${params.remediationKey}`, + { + plan, + timestamp: Date.now(), + permissionsList: plan.requiredPermissions, + }, + ); + + return { + currentState: plan.currentState, + proposedState: plan.proposedState, + description: plan.description, + risk: plan.risk, + apiCalls: plan.requiredPermissions, + guidedOnly: false, + rollbackSupported: plan.rollbackSupported, + requiresAcknowledgment: 'checkbox' as const, + acknowledgmentMessage: + 'This fix will modify your AWS infrastructure. Please review the changes above before proceeding.', + }; + } + + async executeRemediation(params: { + connectionId: string; + organizationId: string; + checkResultId: string; + remediationKey: string; + userId: string; + acknowledgment?: string; + }) { + // Delegate GCP/Azure to dedicated services + const connection = await this.getConnection(params); + if (connection.provider.slug === 'gcp') { + return this.gcpRemediationService.executeRemediation(params); + } + if (connection.provider.slug === 'azure') { + return this.azureRemediationService.executeRemediation(params); + } + + const { finding, credentials, region } = await this.resolveContext(params); + + // Get plan from cache or regenerate + let plan: FixPlan; + const cached = this.planCache.get( + `${params.connectionId}:${params.checkResultId}:${params.remediationKey}`, + ); + if (cached && Date.now() - cached.timestamp < 5 * 60 * 1000) { + plan = cached.plan; + } else { + const evidence = (finding.evidence ?? {}) as Record; + plan = await this.aiRemediationService.generateFixPlan({ + title: finding.title ?? 'Unknown', + description: finding.description, + severity: finding.severity, + resourceType: finding.resourceType, + resourceId: finding.resourceId, + remediation: finding.remediation, + findingKey: evidence.findingKey as string, + evidence, + }); + } + + if (!plan.canAutoFix) { + throw new Error( + 'This finding requires manual remediation and cannot be auto-fixed.', + ); + } + + // Universal plan validation — reject plans that would leave infra in a bad state + if (!plan.fixSteps || plan.fixSteps.length === 0) { + throw new Error('AI generated an empty fix plan. Cannot proceed.'); + } + if (!plan.rollbackSteps || plan.rollbackSteps.length === 0) { + this.logger.warn( + `No rollback steps for ${params.remediationKey} — fix is irreversible`, + ); + } + + // Always require acknowledgment — we're modifying cloud infrastructure + if (!params.acknowledgment || params.acknowledgment !== 'acknowledged') { + throw new Error( + 'Acknowledgment is required before executing any remediation.', + ); + } + + // Create the action record + const action = await db.remediationAction.create({ + data: { + checkResultId: params.checkResultId, + connectionId: params.connectionId, + organizationId: params.organizationId, + initiatedById: params.userId, + remediationKey: params.remediationKey, + resourceId: finding.resourceId, + resourceType: finding.resourceType, + previousState: {}, + appliedState: {}, + status: 'executing', + riskLevel: plan.risk, + acknowledgmentText: params.acknowledgment ?? null, + acknowledgedAt: params.acknowledgment ? new Date() : null, + }, + }); + + try { + // Validate read steps first + const readErrors = validatePlanSteps(plan.readSteps); + if (readErrors.length > 0) { + throw new Error(`Invalid read steps: ${readErrors.join('; ')}`); + } + + const remediationCreds = + await this.awsSecurityService.assumeRemediationRole( + credentials, + region, + ); + + // Phase 1: Execute read steps to get REAL AWS state + const readResult = await executePlanSteps({ + steps: plan.readSteps, + credentials: remediationCreds, + region, + }); + const previousState = readResult.results.reduce( + (acc, r) => ({ ...acc, [r.step.purpose]: r.output }), + {} as Record, + ); + + // Phase 2: Send real AWS state back to AI to generate EXACT fix steps + const evidence = (finding.evidence ?? {}) as Record; + const refinedPlan = await this.aiRemediationService.refineFixPlan({ + finding: { + title: finding.title ?? 'Unknown', + description: finding.description, + severity: finding.severity, + resourceType: finding.resourceType, + resourceId: finding.resourceId, + remediation: finding.remediation, + findingKey: evidence.findingKey as string, + evidence, + }, + originalPlan: plan, + realAwsState: previousState, + }); + + if (!refinedPlan.canAutoFix) { + // AI found the fix can't be automated after seeing real state — return as failed with guidance + await db.remediationAction.update({ + where: { id: action.id }, + data: { + status: 'failed', + errorMessage: refinedPlan.reason ?? 'Cannot be auto-fixed.', + }, + }); + return { + actionId: action.id, + status: 'failed' as const, + resourceId: finding.resourceId, + error: + refinedPlan.reason ?? + 'This finding requires manual setup before auto-fix is possible.', + guidedSteps: refinedPlan.guidedSteps, + }; + } + + // Validate refined fix steps + if (!refinedPlan.fixSteps || refinedPlan.fixSteps.length === 0) { + throw new Error('AI refined plan has no fix steps. Cannot proceed.'); + } + const fixErrors = validatePlanSteps(refinedPlan.fixSteps); + if (fixErrors.length > 0) { + throw new Error(`Invalid fix steps: ${fixErrors.join('; ')}`); + } + + // Phase 3: Execute the refined fix steps (now with REAL values) + // Pass rollback steps for automatic undo on partial failure + const fixResult = await executePlanSteps({ + steps: refinedPlan.fixSteps, + credentials: remediationCreds, + region, + autoRollbackSteps: refinedPlan.rollbackSteps, + }); + + if (fixResult.error) { + this.logger.error( + `Fix step ${fixResult.error.stepIndex + 1} failed: ${fixResult.error.step.service}:${fixResult.error.step.command} — ${fixResult.error.message}`, + ); + this.logger.error( + `Step params: ${JSON.stringify(fixResult.error.step.params).slice(0, 500)}`, + ); + throw new Error(fixResult.error.message); + } + + const appliedState = { + steps: fixResult.results.map((r) => ({ + command: `${r.step.service}:${r.step.command}`, + output: r.output, + })), + rollbackSteps: refinedPlan.rollbackSteps, + }; + + await db.remediationAction.update({ + where: { id: action.id }, + data: { + status: 'success', + previousState: previousState as Prisma.InputJsonValue, + appliedState: appliedState as Prisma.InputJsonValue, + executedAt: new Date(), + }, + }); + + this.logger.log(`Remediation executed on ${finding.resourceId}`); + this.planCache.delete( + `${params.connectionId}:${params.checkResultId}:${params.remediationKey}`, + ); + + return { + actionId: action.id, + status: 'success' as const, + resourceId: finding.resourceId, + previousState, + appliedState, + }; + } catch (error) { + const errorMessage = + error instanceof Error ? error.message : String(error); + const permissionInfo = parseAwsPermissionError(errorMessage); + + // If permission error, build script with ALL needed permissions (not just the one that failed) + // This prevents overwriting CompAI-AutoFix with a partial list + let permissionError: + | { missingActions: string[]; fixScript?: string } + | undefined; + if (permissionInfo.isPermissionError && plan.fixSteps.length > 0) { + try { + const suggestion = + await this.aiRemediationService.suggestPermissionFix({ + errorMessage, + failedStep: plan.fixSteps[0], + }); + // Merge: cached permissions from preview + newly discovered missing ones + const cached = this.planCache.get( + `${params.connectionId}:${params.checkResultId}:${params.remediationKey}`, + ); + const allPerms = new Set([ + ...(cached?.permissionsList ?? plan.requiredPermissions), + ...suggestion.missingActions, + ]); + const mergedScript = this.buildStaticPermissionScript([...allPerms]); + permissionError = { + missingActions: suggestion.missingActions, + fixScript: mergedScript, + }; + } catch { + permissionError = { missingActions: permissionInfo.missingActions }; + } + } + + await db.remediationAction.update({ + where: { id: action.id }, + data: { status: 'failed', errorMessage }, + }); + + this.logger.error(`Remediation failed: ${errorMessage}`); + + return { + actionId: action.id, + status: 'failed' as const, + resourceId: finding.resourceId, + error: errorMessage, + ...(permissionError && { permissionError }), + }; + } + } + + async rollbackRemediation(params: { + actionId: string; + organizationId: string; + }) { + // Check provider to delegate GCP rollback + const actionWithProvider = await db.remediationAction.findFirst({ + where: { id: params.actionId, organizationId: params.organizationId }, + include: { connection: { include: { provider: true } } }, + }); + + if (!actionWithProvider) throw new Error('Remediation action not found'); + + if (actionWithProvider.connection?.provider?.slug === 'gcp') { + return this.gcpRemediationService.rollbackRemediation(params); + } + if (actionWithProvider.connection?.provider?.slug === 'azure') { + return this.azureRemediationService.rollbackRemediation(params); + } + + const action = actionWithProvider; + if (action.status !== 'success' && action.status !== 'unverified') { + throw new Error(`Cannot rollback action with status "${action.status}"`); + } + + const appliedState = action.appliedState as Record; + const rollbackSteps = (appliedState.rollbackSteps ?? + []) as AwsCommandStep[]; + + if (rollbackSteps.length === 0) { + throw new Error('No rollback steps available for this action'); + } + + const credentials = + await this.credentialVaultService.getDecryptedCredentials( + action.connectionId, + ); + if (!credentials) throw new Error('No credentials found'); + + const region = this.getRegion(credentials); + const remediationCreds = + await this.awsSecurityService.assumeRemediationRole(credentials, region); + + try { + const result = await executePlanSteps({ + steps: rollbackSteps, + credentials: remediationCreds, + region, + isRollback: true, + }); + + if (result.error) throw new Error(result.error.message); + + await db.remediationAction.update({ + where: { id: action.id }, + data: { status: 'rolled_back', rolledBackAt: new Date() }, + }); + + this.logger.log( + `Rolled back ${action.remediationKey} on ${action.resourceId}`, + ); + + return { + status: 'rolled_back' as const, + connectionId: action.connectionId, + remediationKey: action.remediationKey, + resourceId: action.resourceId, + }; + } catch (error) { + const errorMessage = + error instanceof Error ? error.message : String(error); + const permissionInfo = parseAwsPermissionError(errorMessage); + + await db.remediationAction.update({ + where: { id: action.id }, + data: { + status: 'rollback_failed', + errorMessage: `Rollback failed: ${errorMessage}`, + }, + }); + + // If permission error, include actionable info + if (permissionInfo.isPermissionError) { + const missingActions = + permissionInfo.missingActions.length > 0 + ? permissionInfo.missingActions + : ['(could not determine specific action)']; + const script = this.buildStaticPermissionScript(missingActions); + throw new Error( + JSON.stringify({ + message: `Rollback failed: missing permissions`, + missingActions, + script, + }), + ); + } + + throw new Error(`Rollback failed: ${errorMessage}`); + } + } + + async getActions(params: { connectionId: string; organizationId: string }) { + const actions = await db.remediationAction.findMany({ + where: { + connectionId: params.connectionId, + organizationId: params.organizationId, + }, + orderBy: { createdAt: 'desc' }, + take: 50, + }); + + const userIds = [...new Set(actions.map((a) => a.initiatedById))].filter( + (id) => id !== 'system', + ); + const users = userIds.length + ? await db.user.findMany({ + where: { id: { in: userIds } }, + select: { id: true, name: true }, + }) + : []; + const userMap = new Map(users.map((u) => [u.id, u.name])); + userMap.set('system', 'System'); + + return actions.map((a) => ({ + ...a, + initiatedByName: userMap.get(a.initiatedById) ?? null, + })); + } + + // ─── Private helpers ────────────────────────────────────────────────── + + private async getConnection(params: { + connectionId: string; + organizationId: string; + }) { + const connection = await db.integrationConnection.findFirst({ + where: { + id: params.connectionId, + organizationId: params.organizationId, + status: 'active', + }, + include: { provider: true }, + }); + if (!connection) throw new Error('Connection not found or inactive'); + return connection; + } + + private async resolveContext(params: { + connectionId: string; + organizationId: string; + checkResultId: string; + remediationKey: string; + }) { + const connection = await this.getConnection(params); + if (connection.provider.slug !== 'aws') { + throw new Error('Remediation is only supported for AWS'); + } + + const finding = await db.integrationCheckResult.findFirst({ + where: { + id: params.checkResultId, + checkRun: { connectionId: params.connectionId }, + }, + }); + if (!finding) throw new Error('Finding not found'); + + const credentials = + await this.credentialVaultService.getDecryptedCredentials( + params.connectionId, + ); + if (!credentials) throw new Error('No credentials found'); + + // Extract region from finding evidence or resourceId (not just first configured region) + const region = this.getRegionForFinding(finding, credentials); + return { finding, credentials, region }; + } + + /** + * Determine the correct AWS region for a finding. + * Priority: evidence.region > ARN region > first configured region > us-east-1 + */ + private getRegionForFinding( + finding: { resourceId: string | null; evidence: unknown }, + credentials: Record, + ): string { + // 1. Check evidence for explicit region + const evidence = (finding.evidence ?? {}) as Record; + if (typeof evidence.region === 'string' && evidence.region) { + return evidence.region; + } + + // 2. Extract region from ARN (arn:aws:service:REGION:account:resource) + const resourceId = finding.resourceId ?? ''; + const arnMatch = resourceId.match(/^arn:aws[^:]*:[^:]+:([a-z0-9-]+):/); + if (arnMatch?.[1] && arnMatch[1] !== '*') { + return arnMatch[1]; + } + + // 3. Fall back to first configured region + return this.getRegion(credentials); + } + + /** + * Resolve AI-generated command name to the REAL SDK command name, + * then derive the correct IAM action. + * e.g., "PutBucketPublicAccessBlockCommand" → finds "PutPublicAccessBlockCommand" → returns "PutPublicAccessBlock" + */ + private resolveRealActionName(service: string, command: string): string { + // Import the SDK module statically (same as executor) + try { + // eslint-disable-next-line @typescript-eslint/no-require-imports + const mod = require(`@aws-sdk/client-${service}`) as Record< + string, + unknown + >; + + // Exact match + if (mod[command] && typeof mod[command] === 'function') { + return command.replace('Command', ''); + } + + // Fuzzy match — find closest command in the module + const cmdBase = command.replace('Command', ''); + const match = Object.keys(mod).find((k) => { + if (!k.endsWith('Command') || typeof mod[k] !== 'function') + return false; + const kBase = k.replace('Command', ''); + return ( + kBase.includes(cmdBase) || + cmdBase.includes(kBase) || + kBase.replace('Bucket', '') === cmdBase.replace('Bucket', '') + ); + }); + + if (match) { + return match.replace('Command', ''); + } + } catch { + // Module not found — fall back to raw name + } + + return command.replace('Command', ''); + } + + /** + * Build a permission script. Always includes ALL provided permissions. + * No IAM reads, no merging — avoids eventual consistency issues. + */ + private buildStaticPermissionScript(permissions: string[]): string { + const sorted = [...new Set(permissions)].sort(); + return [ + 'aws iam put-role-policy', + ' --role-name CompAI-Remediator', + ' --policy-name CompAI-AutoFix', + ` --policy-document '${JSON.stringify({ Version: '2012-10-17', Statement: [{ Effect: 'Allow', Action: sorted, Resource: '*' }] })}'`, + ].join(' \\\n'); + } + + private getRegion(credentials: Record): string { + if (Array.isArray(credentials.regions) && credentials.regions.length > 0) { + return credentials.regions[0] as string; + } + return 'us-east-1'; + } + + /** + * Read the ACTUAL IAM policies attached to CompAI-Remediator and return + * a Set of all allowed actions. This is deterministic — no simulation. + */ + private async getExistingRolePermissions( + credentials: { + accessKeyId: string; + secretAccessKey: string; + sessionToken?: string; + }, + region: string, + ): Promise> { + const { IAMClient, ListRolePoliciesCommand, GetRolePolicyCommand } = + await import('@aws-sdk/client-iam'); + const iam = new IAMClient({ + region, + credentials: { + accessKeyId: credentials.accessKeyId, + secretAccessKey: credentials.secretAccessKey, + sessionToken: credentials.sessionToken, + }, + }); + + const actions = new Set(); + const roleName = 'CompAI-Remediator'; + + try { + // List all inline policies + const listResp = await iam.send( + new ListRolePoliciesCommand({ RoleName: roleName }), + ); + const policyNames = listResp.PolicyNames ?? []; + this.logger.log( + `Role ${roleName} has ${policyNames.length} inline policies: ${policyNames.join(', ')}`, + ); + + // Read each policy and extract actions + for (const policyName of policyNames) { + try { + const policyResp = await iam.send( + new GetRolePolicyCommand({ + RoleName: roleName, + PolicyName: policyName, + }), + ); + const doc = JSON.parse( + decodeURIComponent(policyResp.PolicyDocument ?? '{}'), + ); + this.logger.log( + `Policy ${policyName}: ${JSON.stringify(doc).slice(0, 200)}`, + ); + const statements = Array.isArray(doc.Statement) ? doc.Statement : []; + for (const stmt of statements) { + if (stmt.Effect !== 'Allow') continue; + const stmtActions = Array.isArray(stmt.Action) + ? stmt.Action + : [stmt.Action]; + for (const action of stmtActions) { + if (typeof action === 'string') { + if (action === '*') { + actions.add('*'); + } else if (action.includes('*')) { + // Wildcard like "s3:*" or "cloudtrail:*" + actions.add(action); + } else { + actions.add(action); + } + } + } + } + } catch (policyErr) { + this.logger.warn( + `Failed to read policy ${policyName}: ${policyErr instanceof Error ? policyErr.message : String(policyErr)}`, + ); + } + } + this.logger.log( + `Total actions found on role: ${actions.size}. Sample: ${[...actions].slice(0, 10).join(', ')}`, + ); + } finally { + iam.destroy?.(); + } + + return actions; + } + + /** + * Check if a required permission is covered by the existing policy. + * Handles wildcards and common AI naming mistakes. + */ + private isPermissionCovered( + required: string, + existing: Set, + ): boolean { + if (existing.has('*')) return true; + if (existing.has(required)) return true; + // Check service wildcards: "s3:*" covers "s3:CreateBucket" + const [svc] = required.split(':'); + if (svc && existing.has(`${svc}:*`)) return true; + // AI sometimes adds "Bucket" in action names: s3:PutBucketPublicAccessBlock vs s3:PutPublicAccessBlock + const withoutBucket = required + .replace(':PutBucket', ':Put') + .replace(':GetBucket', ':Get') + .replace(':DeleteBucket', ':Delete'); + if (withoutBucket !== required && existing.has(withoutBucket)) return true; + const withBucket = required + .replace(':Put', ':PutBucket') + .replace(':Get', ':GetBucket') + .replace(':Delete', ':DeleteBucket'); + if (withBucket !== required && existing.has(withBucket)) return true; + return false; + } +} diff --git a/apps/api/src/comments/comments.controller.spec.ts b/apps/api/src/comments/comments.controller.spec.ts index 1d53e22fd5..fc3897b23d 100644 --- a/apps/api/src/comments/comments.controller.spec.ts +++ b/apps/api/src/comments/comments.controller.spec.ts @@ -10,7 +10,15 @@ import { CommentsService } from './comments.service'; jest.mock('@db', () => ({ ...jest.requireActual('@prisma/client'), db: {}, - Prisma: { PrismaClientKnownRequestError: class PrismaClientKnownRequestError extends Error { code: string; constructor(message: string, { code }: { code: string }) { super(message); this.code = code; } } }, + Prisma: { + PrismaClientKnownRequestError: class PrismaClientKnownRequestError extends Error { + code: string; + constructor(message: string, { code }: { code: string }) { + super(message); + this.code = code; + } + }, + }, })); jest.mock('../auth/auth.server', () => ({ @@ -77,7 +85,11 @@ describe('CommentsController', () => { const comments = [{ id: 'cmt_1', content: 'Hello' }]; mockCommentsService.getComments.mockResolvedValue(comments); - const result = await controller.getComments('org_123', 'tsk_1', 'task' as never); + const result = await controller.getComments( + 'org_123', + 'tsk_1', + 'task' as never, + ); expect(commentsService.getComments).toHaveBeenCalledWith( 'org_123', @@ -163,7 +175,10 @@ describe('CommentsController', () => { describe('updateComment', () => { it('should call commentsService.updateComment with correct parameters for session auth', async () => { - const dto = { content: 'Updated content', contextUrl: 'https://example.com' }; + const dto = { + content: 'Updated content', + contextUrl: 'https://example.com', + }; const updated = { id: 'cmt_1', content: 'Updated content' }; mockCommentsService.updateComment.mockResolvedValue(updated); diff --git a/apps/api/src/common/filters/cors-exception.filter.ts b/apps/api/src/common/filters/cors-exception.filter.ts index 009d267d4c..f8fe0a8aa5 100644 --- a/apps/api/src/common/filters/cors-exception.filter.ts +++ b/apps/api/src/common/filters/cors-exception.filter.ts @@ -16,7 +16,7 @@ export class CorsExceptionFilter implements ExceptionFilter { const status = exception.getStatus(); // Get the request origin - const origin = request.headers.origin as string | undefined; + const origin = request.headers.origin; // Set CORS headers on error responses for trusted origins. // Uses the sync check only — the main CORS middleware already validated diff --git a/apps/api/src/context/context.controller.spec.ts b/apps/api/src/context/context.controller.spec.ts index 8ecc5ddd40..25a6104834 100644 --- a/apps/api/src/context/context.controller.spec.ts +++ b/apps/api/src/context/context.controller.spec.ts @@ -178,7 +178,11 @@ describe('ContextController', () => { describe('updateContext', () => { it('should call contextService.updateById with id, organizationId, and dto', async () => { const dto = { answer: 'Updated answer' }; - const updated = { id: 'ctx_1', question: 'What is SOC2?', answer: 'Updated answer' }; + const updated = { + id: 'ctx_1', + question: 'What is SOC2?', + answer: 'Updated answer', + }; mockContextService.updateById.mockResolvedValue(updated); const result = await controller.updateContext( diff --git a/apps/api/src/context/context.controller.ts b/apps/api/src/context/context.controller.ts index 13f1e2108a..d8afa28aa5 100644 --- a/apps/api/src/context/context.controller.ts +++ b/apps/api/src/context/context.controller.ts @@ -45,8 +45,16 @@ export class ContextController { @Get() @RequirePermission('evidence', 'read') @ApiOperation(CONTEXT_OPERATIONS.getAllContext) - @ApiQuery({ name: 'search', required: false, description: 'Search by question text' }) - @ApiQuery({ name: 'page', required: false, description: 'Page number (1-based)' }) + @ApiQuery({ + name: 'search', + required: false, + description: 'Search by question text', + }) + @ApiQuery({ + name: 'page', + required: false, + description: 'Page number (1-based)', + }) @ApiQuery({ name: 'perPage', required: false, description: 'Items per page' }) @ApiResponse(GET_ALL_CONTEXT_RESPONSES[200]) @ApiResponse(GET_ALL_CONTEXT_RESPONSES[401]) @@ -61,7 +69,11 @@ export class ContextController { ) { const result = await this.contextService.findAllByOrganization( organizationId, - { search, page: page ? parseInt(page, 10) : undefined, perPage: perPage ? parseInt(perPage, 10) : undefined }, + { + search, + page: page ? parseInt(page, 10) : undefined, + perPage: perPage ? parseInt(perPage, 10) : undefined, + }, ); return { diff --git a/apps/api/src/controls/controls.controller.spec.ts b/apps/api/src/controls/controls.controller.spec.ts index 25b61809da..ccbeb405af 100644 --- a/apps/api/src/controls/controls.controller.spec.ts +++ b/apps/api/src/controls/controls.controller.spec.ts @@ -99,7 +99,14 @@ describe('ControlsController', () => { it('should parse sortDesc as false when not "true"', async () => { mockService.findAll.mockResolvedValue({ data: [], count: 0 }); - await controller.findAll('org_1', undefined, undefined, undefined, undefined, 'false'); + await controller.findAll( + 'org_1', + undefined, + undefined, + undefined, + undefined, + 'false', + ); expect(service.findAll).toHaveBeenCalledWith('org_1', { page: 1, @@ -147,8 +154,15 @@ describe('ControlsController', () => { describe('create', () => { it('should call service.create with organizationId and dto', async () => { - const dto: CreateControlDto = { name: 'New Control', description: 'A test control' }; - const mockCreated = { id: 'ctrl_new', name: 'New Control', description: 'A test control' }; + const dto: CreateControlDto = { + name: 'New Control', + description: 'A test control', + }; + const mockCreated = { + id: 'ctrl_new', + name: 'New Control', + description: 'A test control', + }; mockService.create.mockResolvedValue(mockCreated); const result = await controller.create('org_1', dto); diff --git a/apps/api/src/controls/controls.controller.ts b/apps/api/src/controls/controls.controller.ts index 605e814d4f..f19bb4766b 100644 --- a/apps/api/src/controls/controls.controller.ts +++ b/apps/api/src/controls/controls.controller.ts @@ -8,7 +8,12 @@ import { Query, UseGuards, } from '@nestjs/common'; -import { ApiTags, ApiBearerAuth, ApiOperation, ApiQuery } from '@nestjs/swagger'; +import { + ApiTags, + ApiBearerAuth, + ApiOperation, + ApiQuery, +} from '@nestjs/swagger'; import { HybridAuthGuard } from '../auth/hybrid-auth.guard'; import { PermissionGuard } from '../auth/permission.guard'; import { RequirePermission } from '../auth/require-permission.decorator'; @@ -28,9 +33,21 @@ export class ControlsController { @ApiOperation({ summary: 'List controls with relations' }) @ApiQuery({ name: 'page', required: false }) @ApiQuery({ name: 'perPage', required: false }) - @ApiQuery({ name: 'name', required: false, description: 'Filter by name (case-insensitive contains)' }) - @ApiQuery({ name: 'sortBy', required: false, description: 'Field to sort by (default: name)' }) - @ApiQuery({ name: 'sortDesc', required: false, description: 'Sort descending (true/false)' }) + @ApiQuery({ + name: 'name', + required: false, + description: 'Filter by name (case-insensitive contains)', + }) + @ApiQuery({ + name: 'sortBy', + required: false, + description: 'Field to sort by (default: name)', + }) + @ApiQuery({ + name: 'sortDesc', + required: false, + description: 'Sort descending (true/false)', + }) async findAll( @OrganizationId() organizationId: string, @Query('page') page?: string, diff --git a/apps/api/src/controls/controls.service.ts b/apps/api/src/controls/controls.service.ts index c349ebb705..03698193d2 100644 --- a/apps/api/src/controls/controls.service.ts +++ b/apps/api/src/controls/controls.service.ts @@ -1,7 +1,4 @@ -import { - Injectable, - NotFoundException, -} from '@nestjs/common'; +import { Injectable, NotFoundException } from '@nestjs/common'; import { db, Prisma } from '@db'; import { CreateControlDto } from './dto/create-control.dto'; @@ -107,7 +104,8 @@ export class ControlsService { progress: { total: totalItems, completed, - progress: totalItems > 0 ? Math.round((completed / totalItems) * 100) : 0, + progress: + totalItems > 0 ? Math.round((completed / totalItems) * 100) : 0, byType: { policy: { total: policies.length, completed: policyCompleted }, task: { total: tasks.length, completed: taskCompleted }, diff --git a/apps/api/src/device-agent/device-agent-auth.service.spec.ts b/apps/api/src/device-agent/device-agent-auth.service.spec.ts index 7025e899a5..e8c3c5f199 100644 --- a/apps/api/src/device-agent/device-agent-auth.service.spec.ts +++ b/apps/api/src/device-agent/device-agent-auth.service.spec.ts @@ -1,4 +1,8 @@ -import { ForbiddenException, NotFoundException, UnauthorizedException } from '@nestjs/common'; +import { + ForbiddenException, + NotFoundException, + UnauthorizedException, +} from '@nestjs/common'; import { DeviceAgentAuthService } from './device-agent-auth.service'; // Mock dependencies @@ -42,7 +46,9 @@ import { deviceAgentRedisClient } from './device-agent-kv'; const mockDb = db as jest.Mocked; const mockAuth = auth as jest.Mocked; -const mockKv = deviceAgentRedisClient as jest.Mocked; +const mockKv = deviceAgentRedisClient as jest.Mocked< + typeof deviceAgentRedisClient +>; describe('DeviceAgentAuthService', () => { let service: DeviceAgentAuthService; @@ -62,7 +68,10 @@ describe('DeviceAgentAuthService', () => { const headers = new Headers(); headers.set('cookie', 'session=abc'); - const result = await service.generateAuthCode({ headers, state: 'test-state' }); + const result = await service.generateAuthCode({ + headers, + state: 'test-state', + }); expect(result.code).toHaveLength(64); // 32 bytes hex expect(mockKv.set).toHaveBeenCalledWith( @@ -137,7 +146,9 @@ describe('DeviceAgentAuthService', () => { }); expect(mockDb.member.findMany).toHaveBeenCalledWith({ where: { userId: 'user-1', deactivated: false }, - include: { organization: { select: { id: true, name: true, slug: true } } }, + include: { + organization: { select: { id: true, name: true, slug: true } }, + }, }); }); }); @@ -160,11 +171,16 @@ describe('DeviceAgentAuthService', () => { }); it('should create a new device without serial number', async () => { - (mockDb.member.findFirst as jest.Mock).mockResolvedValue({ id: 'member-1' }); + (mockDb.member.findFirst as jest.Mock).mockResolvedValue({ + id: 'member-1', + }); (mockDb.device.findFirst as jest.Mock).mockResolvedValue(null); (mockDb.device.create as jest.Mock).mockResolvedValue({ id: 'dev-1' }); - const result = await service.registerDevice({ userId: 'user-1', dto: baseDto }); + const result = await service.registerDevice({ + userId: 'user-1', + dto: baseDto, + }); expect(result).toEqual({ deviceId: 'dev-1' }); expect(mockDb.device.create).toHaveBeenCalledWith({ @@ -178,7 +194,9 @@ describe('DeviceAgentAuthService', () => { }); it('should create a new device with serial number', async () => { - (mockDb.member.findFirst as jest.Mock).mockResolvedValue({ id: 'member-1' }); + (mockDb.member.findFirst as jest.Mock).mockResolvedValue({ + id: 'member-1', + }); (mockDb.device.findUnique as jest.Mock).mockResolvedValue(null); (mockDb.device.create as jest.Mock).mockResolvedValue({ id: 'dev-2' }); @@ -189,12 +207,16 @@ describe('DeviceAgentAuthService', () => { }); it('should update existing device when same member re-registers', async () => { - (mockDb.member.findFirst as jest.Mock).mockResolvedValue({ id: 'member-1' }); + (mockDb.member.findFirst as jest.Mock).mockResolvedValue({ + id: 'member-1', + }); (mockDb.device.findUnique as jest.Mock).mockResolvedValue({ id: 'dev-existing', memberId: 'member-1', }); - (mockDb.device.update as jest.Mock).mockResolvedValue({ id: 'dev-existing' }); + (mockDb.device.update as jest.Mock).mockResolvedValue({ + id: 'dev-existing', + }); const dto = { ...baseDto, serialNumber: 'ABC123' }; const result = await service.registerDevice({ userId: 'user-1', dto }); @@ -204,14 +226,18 @@ describe('DeviceAgentAuthService', () => { }); it('should use fallback serial when serial belongs to different member', async () => { - (mockDb.member.findFirst as jest.Mock).mockResolvedValue({ id: 'member-2' }); + (mockDb.member.findFirst as jest.Mock).mockResolvedValue({ + id: 'member-2', + }); (mockDb.device.findUnique as jest.Mock).mockResolvedValue({ id: 'dev-other', memberId: 'member-1', }); // No existing fallback (mockDb.device.findFirst as jest.Mock).mockResolvedValue(null); - (mockDb.device.create as jest.Mock).mockResolvedValue({ id: 'dev-fallback' }); + (mockDb.device.create as jest.Mock).mockResolvedValue({ + id: 'dev-fallback', + }); const dto = { ...baseDto, serialNumber: 'GENERIC-SERIAL' }; const result = await service.registerDevice({ userId: 'user-1', dto }); @@ -235,17 +261,35 @@ describe('DeviceAgentAuthService', () => { screenLockEnabled: false, checkDetails: {}, }); - (mockDb.device.update as jest.Mock).mockResolvedValue({ isCompliant: true }); + (mockDb.device.update as jest.Mock).mockResolvedValue({ + isCompliant: true, + }); const result = await service.checkIn({ userId: 'user-1', dto: { deviceId: 'dev-1', checks: [ - { checkType: 'disk_encryption', passed: true, checkedAt: new Date().toISOString() }, - { checkType: 'antivirus', passed: true, checkedAt: new Date().toISOString() }, - { checkType: 'password_policy', passed: true, checkedAt: new Date().toISOString() }, - { checkType: 'screen_lock', passed: true, checkedAt: new Date().toISOString() }, + { + checkType: 'disk_encryption', + passed: true, + checkedAt: new Date().toISOString(), + }, + { + checkType: 'antivirus', + passed: true, + checkedAt: new Date().toISOString(), + }, + { + checkType: 'password_policy', + passed: true, + checkedAt: new Date().toISOString(), + }, + { + checkType: 'screen_lock', + passed: true, + checkedAt: new Date().toISOString(), + }, ], }, }); @@ -273,7 +317,11 @@ describe('DeviceAgentAuthService', () => { dto: { deviceId: 'dev-missing', checks: [ - { checkType: 'disk_encryption', passed: true, checkedAt: new Date().toISOString() }, + { + checkType: 'disk_encryption', + passed: true, + checkedAt: new Date().toISOString(), + }, ], }, }), @@ -289,15 +337,25 @@ describe('DeviceAgentAuthService', () => { screenLockEnabled: false, checkDetails: {}, }); - (mockDb.device.update as jest.Mock).mockResolvedValue({ isCompliant: false }); + (mockDb.device.update as jest.Mock).mockResolvedValue({ + isCompliant: false, + }); const result = await service.checkIn({ userId: 'user-1', dto: { deviceId: 'dev-1', checks: [ - { checkType: 'disk_encryption', passed: true, checkedAt: new Date().toISOString() }, - { checkType: 'antivirus', passed: false, checkedAt: new Date().toISOString() }, + { + checkType: 'disk_encryption', + passed: true, + checkedAt: new Date().toISOString(), + }, + { + checkType: 'antivirus', + passed: false, + checkedAt: new Date().toISOString(), + }, ], }, }); diff --git a/apps/api/src/device-agent/device-agent-auth.service.ts b/apps/api/src/device-agent/device-agent-auth.service.ts index 65bf2955bc..c81682e026 100644 --- a/apps/api/src/device-agent/device-agent-auth.service.ts +++ b/apps/api/src/device-agent/device-agent-auth.service.ts @@ -69,9 +69,7 @@ export class DeviceAgentAuthService { ); if (!stored) { - throw new UnauthorizedException( - 'Invalid or expired authorization code', - ); + throw new UnauthorizedException('Invalid or expired authorization code'); } return { diff --git a/apps/api/src/device-agent/device-agent.controller.ts b/apps/api/src/device-agent/device-agent.controller.ts index 339f7cc5b5..b765b8c73e 100644 --- a/apps/api/src/device-agent/device-agent.controller.ts +++ b/apps/api/src/device-agent/device-agent.controller.ts @@ -11,8 +11,17 @@ import { StreamableFile, UseGuards, } from '@nestjs/common'; -import { ApiOperation, ApiResponse, ApiSecurity, ApiTags } from '@nestjs/swagger'; -import { AuthContext, OrganizationId, UserId } from '../auth/auth-context.decorator'; +import { + ApiOperation, + ApiResponse, + ApiSecurity, + ApiTags, +} from '@nestjs/swagger'; +import { + AuthContext, + OrganizationId, + UserId, +} from '../auth/auth-context.decorator'; import { HybridAuthGuard } from '../auth/hybrid-auth.guard'; import { PermissionGuard } from '../auth/permission.guard'; import { Public } from '../auth/public.decorator'; @@ -94,7 +103,7 @@ export class DeviceAgentController { // Construct Web API Headers from Express IncomingHttpHeaders const headers = new Headers(); const authHeader = req.headers['authorization']; - if (authHeader) headers.set('authorization', authHeader as string); + if (authHeader) headers.set('authorization', authHeader); const cookieHeader = req.headers['cookie']; if (cookieHeader) headers.set('cookie', cookieHeader); @@ -114,7 +123,10 @@ export class DeviceAgentController { @Post('register') @UseGuards(HybridAuthGuard) @SkipOrgCheck() - async registerDevice(@UserId() userId: string, @Body() dto: RegisterDeviceDto) { + async registerDevice( + @UserId() userId: string, + @Body() dto: RegisterDeviceDto, + ) { return this.deviceAgentAuthService.registerDevice({ userId, dto }); } diff --git a/apps/api/src/device-agent/device-agent.service.ts b/apps/api/src/device-agent/device-agent.service.ts index d763d8c1b9..60cf93c963 100644 --- a/apps/api/src/device-agent/device-agent.service.ts +++ b/apps/api/src/device-agent/device-agent.service.ts @@ -167,11 +167,11 @@ export class DeviceAgentService { } } - async getUpdateFile({ - filename, - }: { - filename: string; - }): Promise<{ stream: Readable; contentType: string; contentLength?: number }> { + async getUpdateFile({ filename }: { filename: string }): Promise<{ + stream: Readable; + contentType: string; + contentLength?: number; + }> { if (!isValidFilename(filename)) { throw new NotFoundException('Not found'); } diff --git a/apps/api/src/device-agent/dto/check-in.dto.ts b/apps/api/src/device-agent/dto/check-in.dto.ts index 06bafa1d4e..45cb75eb3b 100644 --- a/apps/api/src/device-agent/dto/check-in.dto.ts +++ b/apps/api/src/device-agent/dto/check-in.dto.ts @@ -33,7 +33,11 @@ class CheckDetailsDto { export class CheckResultDto { @IsEnum(['disk_encryption', 'antivirus', 'password_policy', 'screen_lock']) - checkType: 'disk_encryption' | 'antivirus' | 'password_policy' | 'screen_lock'; + checkType: + | 'disk_encryption' + | 'antivirus' + | 'password_policy' + | 'screen_lock'; @IsBoolean() passed: boolean; diff --git a/apps/api/src/device-agent/dto/register-device.dto.ts b/apps/api/src/device-agent/dto/register-device.dto.ts index eb28500738..eef13da5e3 100644 --- a/apps/api/src/device-agent/dto/register-device.dto.ts +++ b/apps/api/src/device-agent/dto/register-device.dto.ts @@ -1,9 +1,4 @@ -import { - IsEnum, - IsOptional, - IsString, - MinLength, -} from 'class-validator'; +import { IsEnum, IsOptional, IsString, MinLength } from 'class-validator'; export class RegisterDeviceDto { @IsString() diff --git a/apps/api/src/devices/devices.controller.spec.ts b/apps/api/src/devices/devices.controller.spec.ts index 7d03773005..048605c62c 100644 --- a/apps/api/src/devices/devices.controller.spec.ts +++ b/apps/api/src/devices/devices.controller.spec.ts @@ -8,7 +8,15 @@ import type { AuthContext as AuthContextType } from '../auth/types'; jest.mock('@db', () => ({ ...jest.requireActual('@prisma/client'), db: {}, - Prisma: { PrismaClientKnownRequestError: class PrismaClientKnownRequestError extends Error { code: string; constructor(message: string, { code }: { code: string }) { super(message); this.code = code; } } }, + Prisma: { + PrismaClientKnownRequestError: class PrismaClientKnownRequestError extends Error { + code: string; + constructor(message: string, { code }: { code: string }) { + super(message); + this.code = code; + } + }, + }, })); jest.mock('../auth/auth.server', () => ({ diff --git a/apps/api/src/email/unsubscribe.controller.ts b/apps/api/src/email/unsubscribe.controller.ts index 079525fe02..8799e0868d 100644 --- a/apps/api/src/email/unsubscribe.controller.ts +++ b/apps/api/src/email/unsubscribe.controller.ts @@ -1,4 +1,11 @@ -import { Controller, Post, Body, Query, HttpCode, BadRequestException } from '@nestjs/common'; +import { + Controller, + Post, + Body, + Query, + HttpCode, + BadRequestException, +} from '@nestjs/common'; import { ApiOperation, ApiTags } from '@nestjs/swagger'; import { db } from '@db'; import { generateUnsubscribeToken } from '@trycompai/email'; diff --git a/apps/api/src/evidence-forms/evidence-forms.service.spec.ts b/apps/api/src/evidence-forms/evidence-forms.service.spec.ts index 9a02a59303..670ee8d6ad 100644 --- a/apps/api/src/evidence-forms/evidence-forms.service.spec.ts +++ b/apps/api/src/evidence-forms/evidence-forms.service.spec.ts @@ -47,7 +47,8 @@ describe('EvidenceFormsService', () => { const authContext: AuthContext = { organizationId: 'org_123', authType: 'session', - isApiKey: false, isPlatformAdmin: false, + isApiKey: false, + isPlatformAdmin: false, userRoles: ['admin'], userId: 'usr_reviewer', userEmail: 'reviewer@example.com', diff --git a/apps/api/src/evidence-forms/evidence-forms.service.ts b/apps/api/src/evidence-forms/evidence-forms.service.ts index 01408529de..53e6c6edae 100644 --- a/apps/api/src/evidence-forms/evidence-forms.service.ts +++ b/apps/api/src/evidence-forms/evidence-forms.service.ts @@ -163,7 +163,9 @@ export class EvidenceFormsService { private requireEvidenceDeleteAccess(authContext: AuthContext): string { const userId = this.requireJwtUser(authContext); const roles = authContext.userRoles ?? []; - const canDelete = EVIDENCE_FORM_DELETE_ROLES.some((role) => roles.includes(role)); + const canDelete = EVIDENCE_FORM_DELETE_ROLES.some((role) => + roles.includes(role), + ); if (!canDelete) { throw new UnauthorizedException( @@ -214,10 +216,9 @@ export class EvidenceFormsService { typeof (value as Record).fileKey === 'string' ) { const fileObj = value as Record; - const freshUrl = - await this.attachmentsService.getPresignedDownloadUrl( - fileObj.fileKey as string, - ); + const freshUrl = await this.attachmentsService.getPresignedDownloadUrl( + fileObj.fileKey as string, + ); refreshed[key] = { ...fileObj, downloadUrl: freshUrl }; } } diff --git a/apps/api/src/findings/finding-notifier.service.spec.ts b/apps/api/src/findings/finding-notifier.service.spec.ts index 9ba8a4e873..a65f76300d 100644 --- a/apps/api/src/findings/finding-notifier.service.spec.ts +++ b/apps/api/src/findings/finding-notifier.service.spec.ts @@ -80,7 +80,9 @@ const { db, FindingType } = mockDbModule; describe('FindingNotifierService', () => { const mockedDb = db; - const mockedTriggerEmail = triggerEmail as jest.MockedFunction; + const mockedTriggerEmail = triggerEmail as jest.MockedFunction< + typeof triggerEmail + >; const mockedIsUserUnsubscribed = isUserUnsubscribed as jest.MockedFunction< typeof isUserUnsubscribed >; diff --git a/apps/api/src/findings/finding-notifier.service.ts b/apps/api/src/findings/finding-notifier.service.ts index a7fff542bc..e233cf0adf 100644 --- a/apps/api/src/findings/finding-notifier.service.ts +++ b/apps/api/src/findings/finding-notifier.service.ts @@ -410,7 +410,12 @@ export class FindingNotifierService { try { // Check unsubscribe preferences - const isUnsubscribed = await isUserUnsubscribed(db, recipient.email, 'findingNotifications', organizationId); + const isUnsubscribed = await isUserUnsubscribed( + db, + recipient.email, + 'findingNotifications', + organizationId, + ); if (isUnsubscribed) { this.logger.log( diff --git a/apps/api/src/findings/findings.controller.spec.ts b/apps/api/src/findings/findings.controller.spec.ts index 2010d9b328..b911cf0fe4 100644 --- a/apps/api/src/findings/findings.controller.spec.ts +++ b/apps/api/src/findings/findings.controller.spec.ts @@ -54,7 +54,8 @@ describe('FindingsController', () => { const authContext: AuthContext = { organizationId: 'org_123', authType: 'session', - isApiKey: false, isPlatformAdmin: false, + isApiKey: false, + isPlatformAdmin: false, userRoles: ['admin'], userId: 'usr_123', userEmail: 'admin@example.com', diff --git a/apps/api/src/framework-editor/control-template/control-template.controller.ts b/apps/api/src/framework-editor/control-template/control-template.controller.ts index 08f5506618..b108f5a17f 100644 --- a/apps/api/src/framework-editor/control-template/control-template.controller.ts +++ b/apps/api/src/framework-editor/control-template/control-template.controller.ts @@ -50,10 +50,7 @@ export class ControlTemplateController { @Patch(':id') @UsePipes(new ValidationPipe({ whitelist: true, transform: true })) - async update( - @Param('id') id: string, - @Body() dto: UpdateControlTemplateDto, - ) { + async update(@Param('id') id: string, @Body() dto: UpdateControlTemplateDto) { return this.service.update(id, dto); } @@ -95,10 +92,7 @@ export class ControlTemplateController { } @Post(':id/task-templates/:ttId') - async linkTaskTemplate( - @Param('id') id: string, - @Param('ttId') ttId: string, - ) { + async linkTaskTemplate(@Param('id') id: string, @Param('ttId') ttId: string) { return this.service.linkTaskTemplate(id, ttId); } diff --git a/apps/api/src/framework-editor/control-template/control-template.service.ts b/apps/api/src/framework-editor/control-template/control-template.service.ts index 16f346a141..1e91e4073f 100644 --- a/apps/api/src/framework-editor/control-template/control-template.service.ts +++ b/apps/api/src/framework-editor/control-template/control-template.service.ts @@ -1,4 +1,9 @@ -import { Injectable, NotFoundException, ConflictException, Logger } from '@nestjs/common'; +import { + Injectable, + NotFoundException, + ConflictException, + Logger, +} from '@nestjs/common'; import { db, Prisma } from '@db'; import type { EvidenceFormType } from '@db'; import { CreateControlTemplateDto } from './dto/create-control-template.dto'; diff --git a/apps/api/src/framework-editor/framework/framework-export.service.ts b/apps/api/src/framework-editor/framework/framework-export.service.ts index ba2455bd49..6637eb4c74 100644 --- a/apps/api/src/framework-editor/framework/framework-export.service.ts +++ b/apps/api/src/framework-editor/framework/framework-export.service.ts @@ -211,7 +211,7 @@ export class FrameworkExportService { data: { name: ct.name, description: ct.description, - documentTypes: (ct.documentTypes ?? []) as EvidenceFormType[], + documentTypes: ct.documentTypes ?? [], requirements: { connect: (ct.requirementIndices ?? []).map((i) => ({ id: createdRequirements[i].id, diff --git a/apps/api/src/framework-editor/framework/framework.controller.ts b/apps/api/src/framework-editor/framework/framework.controller.ts index 4564b32df0..8d05c3a9ef 100644 --- a/apps/api/src/framework-editor/framework/framework.controller.ts +++ b/apps/api/src/framework-editor/framework/framework.controller.ts @@ -29,10 +29,7 @@ export class FrameworkEditorFrameworkController { ) {} @Get() - async findAll( - @Query('take') take?: string, - @Query('skip') skip?: string, - ) { + async findAll(@Query('take') take?: string, @Query('skip') skip?: string) { const limit = Math.min(Number(take) || 500, 500); const offset = Number(skip) || 0; return this.frameworkService.findAll(limit, offset); @@ -100,10 +97,7 @@ export class FrameworkEditorFrameworkController { } @Post(':id/link-task/:taskId') - async linkTask( - @Param('id') id: string, - @Param('taskId') taskId: string, - ) { + async linkTask(@Param('id') id: string, @Param('taskId') taskId: string) { return this.frameworkService.linkTask(id, taskId); } diff --git a/apps/api/src/framework-editor/framework/framework.service.ts b/apps/api/src/framework-editor/framework/framework.service.ts index 4e9e3736d4..ecab18013f 100644 --- a/apps/api/src/framework-editor/framework/framework.service.ts +++ b/apps/api/src/framework-editor/framework/framework.service.ts @@ -1,4 +1,9 @@ -import { Injectable, NotFoundException, ConflictException, Logger } from '@nestjs/common'; +import { + Injectable, + NotFoundException, + ConflictException, + Logger, +} from '@nestjs/common'; import { db, Prisma } from '@db'; import { CreateFrameworkDto } from './dto/create-framework.dto'; import { UpdateFrameworkDto } from './dto/update-framework.dto'; @@ -186,9 +191,7 @@ export class FrameworkEditorFrameworkService { data: { requirements: { connect: requirementIds } }, }); - this.logger.log( - `Linked control ${controlId} to framework ${frameworkId}`, - ); + this.logger.log(`Linked control ${controlId} to framework ${frameworkId}`); return { message: 'Control linked to framework' }; } @@ -238,9 +241,7 @@ export class FrameworkEditorFrameworkService { data: { controlTemplates: { connect: controlIds } }, }); - this.logger.log( - `Linked policy ${policyId} to framework ${frameworkId}`, - ); + this.logger.log(`Linked policy ${policyId} to framework ${frameworkId}`); return { message: 'Policy linked to framework' }; } } diff --git a/apps/api/src/framework-editor/policy-template/dto/create-policy-template.dto.ts b/apps/api/src/framework-editor/policy-template/dto/create-policy-template.dto.ts index 9f744ec8d7..fb1638a9ae 100644 --- a/apps/api/src/framework-editor/policy-template/dto/create-policy-template.dto.ts +++ b/apps/api/src/framework-editor/policy-template/dto/create-policy-template.dto.ts @@ -1,10 +1,5 @@ import { ApiProperty } from '@nestjs/swagger'; -import { - IsString, - IsNotEmpty, - IsEnum, - MaxLength, -} from 'class-validator'; +import { IsString, IsNotEmpty, IsEnum, MaxLength } from 'class-validator'; import { Frequency, Departments } from '@db'; export class CreatePolicyTemplateDto { diff --git a/apps/api/src/framework-editor/policy-template/policy-template.controller.ts b/apps/api/src/framework-editor/policy-template/policy-template.controller.ts index 6898361f76..939f6b9ddb 100644 --- a/apps/api/src/framework-editor/policy-template/policy-template.controller.ts +++ b/apps/api/src/framework-editor/policy-template/policy-template.controller.ts @@ -51,10 +51,7 @@ export class PolicyTemplateController { @Patch(':id') @UsePipes(new ValidationPipe({ whitelist: true, transform: true })) - async update( - @Param('id') id: string, - @Body() dto: UpdatePolicyTemplateDto, - ) { + async update(@Param('id') id: string, @Body() dto: UpdatePolicyTemplateDto) { return this.service.update(id, dto); } diff --git a/apps/api/src/framework-editor/policy-template/policy-template.service.ts b/apps/api/src/framework-editor/policy-template/policy-template.service.ts index 4b60ba77d6..082ea290c5 100644 --- a/apps/api/src/framework-editor/policy-template/policy-template.service.ts +++ b/apps/api/src/framework-editor/policy-template/policy-template.service.ts @@ -1,4 +1,9 @@ -import { Injectable, NotFoundException, ConflictException, Logger } from '@nestjs/common'; +import { + Injectable, + NotFoundException, + ConflictException, + Logger, +} from '@nestjs/common'; import { db, Prisma } from '@db'; import { CreatePolicyTemplateDto } from './dto/create-policy-template.dto'; import { UpdatePolicyTemplateDto } from './dto/update-policy-template.dto'; diff --git a/apps/api/src/framework-editor/requirement/dto/create-requirement.dto.ts b/apps/api/src/framework-editor/requirement/dto/create-requirement.dto.ts index 4b5d627ca1..1971a5c4ee 100644 --- a/apps/api/src/framework-editor/requirement/dto/create-requirement.dto.ts +++ b/apps/api/src/framework-editor/requirement/dto/create-requirement.dto.ts @@ -1,10 +1,5 @@ import { ApiProperty, ApiPropertyOptional } from '@nestjs/swagger'; -import { - IsString, - IsNotEmpty, - IsOptional, - MaxLength, -} from 'class-validator'; +import { IsString, IsNotEmpty, IsOptional, MaxLength } from 'class-validator'; export class CreateRequirementDto { @ApiProperty({ example: 'frk_abc123' }) diff --git a/apps/api/src/framework-editor/requirement/requirement.controller.ts b/apps/api/src/framework-editor/requirement/requirement.controller.ts index 8ed9159b47..cf8a56980c 100644 --- a/apps/api/src/framework-editor/requirement/requirement.controller.ts +++ b/apps/api/src/framework-editor/requirement/requirement.controller.ts @@ -24,10 +24,7 @@ export class RequirementController { constructor(private readonly service: RequirementService) {} @Get() - async findAll( - @Query('take') take?: string, - @Query('skip') skip?: string, - ) { + async findAll(@Query('take') take?: string, @Query('skip') skip?: string) { const limit = Math.min(Number(take) || 500, 500); const offset = Number(skip) || 0; return this.service.findAll(limit, offset); @@ -41,10 +38,7 @@ export class RequirementController { @Patch(':id') @UsePipes(new ValidationPipe({ whitelist: true, transform: true })) - async update( - @Param('id') id: string, - @Body() dto: UpdateRequirementDto, - ) { + async update(@Param('id') id: string, @Body() dto: UpdateRequirementDto) { return this.service.update(id, dto); } diff --git a/apps/api/src/framework-editor/requirement/requirement.service.ts b/apps/api/src/framework-editor/requirement/requirement.service.ts index 0f0067e513..b38305feb8 100644 --- a/apps/api/src/framework-editor/requirement/requirement.service.ts +++ b/apps/api/src/framework-editor/requirement/requirement.service.ts @@ -33,9 +33,7 @@ export class RequirementService { where: { id: dto.frameworkId }, }); if (!framework) { - throw new NotFoundException( - `Framework ${dto.frameworkId} not found`, - ); + throw new NotFoundException(`Framework ${dto.frameworkId} not found`); } const req = await db.frameworkEditorRequirement.create({ diff --git a/apps/api/src/framework-editor/task-template/task-template.controller.ts b/apps/api/src/framework-editor/task-template/task-template.controller.ts index fae00cc706..b261546dfd 100644 --- a/apps/api/src/framework-editor/task-template/task-template.controller.ts +++ b/apps/api/src/framework-editor/task-template/task-template.controller.ts @@ -59,9 +59,7 @@ export class TaskTemplateController { @ApiResponse(GET_ALL_TASK_TEMPLATES_RESPONSES[200]) @ApiResponse(GET_ALL_TASK_TEMPLATES_RESPONSES[401]) @ApiResponse(GET_ALL_TASK_TEMPLATES_RESPONSES[500]) - async getAllTaskTemplates( - @Query('frameworkId') frameworkId?: string, - ) { + async getAllTaskTemplates(@Query('frameworkId') frameworkId?: string) { return await this.taskTemplateService.findAll(frameworkId); } diff --git a/apps/api/src/framework-editor/validators/max-json-size.validator.ts b/apps/api/src/framework-editor/validators/max-json-size.validator.ts index 9a2255fce7..2d6bc81aaf 100644 --- a/apps/api/src/framework-editor/validators/max-json-size.validator.ts +++ b/apps/api/src/framework-editor/validators/max-json-size.validator.ts @@ -28,7 +28,10 @@ export class MaxJsonSizeConstraint implements ValidatorConstraintInterface { } } -export function MaxJsonSize(maxBytes?: number, validationOptions?: ValidationOptions) { +export function MaxJsonSize( + maxBytes?: number, + validationOptions?: ValidationOptions, +) { return function (object: object, propertyName: string) { registerDecorator({ target: object.constructor, diff --git a/apps/api/src/frameworks/frameworks-scores.helper.ts b/apps/api/src/frameworks/frameworks-scores.helper.ts index 72a213437c..04041b2cba 100644 --- a/apps/api/src/frameworks/frameworks-scores.helper.ts +++ b/apps/api/src/frameworks/frameworks-scores.helper.ts @@ -13,26 +13,27 @@ const GENERAL_TRAINING_IDS = ['sat-1', 'sat-2', 'sat-3', 'sat-4', 'sat-5']; const HIPAA_TRAINING_ID = 'hipaa-sat-1'; export async function getOverviewScores(organizationId: string) { - const [allPolicies, allTasks, employees, onboarding, org, hipaaInstance] = await Promise.all([ - db.policy.findMany({ where: { organizationId } }), - db.task.findMany({ where: { organizationId } }), - db.member.findMany({ - where: { organizationId, deactivated: false }, - include: { user: true }, - }), - db.onboarding.findUnique({ - where: { organizationId }, - select: { triggerJobId: true }, - }), - db.organization.findUnique({ - where: { id: organizationId }, - select: { securityTrainingStepEnabled: true }, - }), - db.frameworkInstance.findFirst({ - where: { organizationId, framework: { name: 'HIPAA' } }, - select: { id: true }, - }), - ]); + const [allPolicies, allTasks, employees, onboarding, org, hipaaInstance] = + await Promise.all([ + db.policy.findMany({ where: { organizationId } }), + db.task.findMany({ where: { organizationId } }), + db.member.findMany({ + where: { organizationId, deactivated: false }, + include: { user: true }, + }), + db.onboarding.findUnique({ + where: { organizationId }, + select: { triggerJobId: true }, + }), + db.organization.findUnique({ + where: { id: organizationId }, + select: { securityTrainingStepEnabled: true }, + }), + db.frameworkInstance.findFirst({ + where: { organizationId, framework: { name: 'HIPAA' } }, + select: { id: true }, + }), + ]); const securityTrainingStepEnabled = org?.securityTrainingStepEnabled === true; const hasHipaaFramework = !!hipaaInstance; @@ -56,14 +57,16 @@ export async function getOverviewScores(organizationId: string) { ); // People score — filter to members with compliance:required permission - const activeEmployees = await filterComplianceMembers(employees, organizationId); + const activeEmployees = await filterComplianceMembers( + employees, + organizationId, + ); let completedMembers = 0; if (activeEmployees.length > 0) { const requiredPolicies = allPolicies.filter( - (p) => - p.isRequiredToSign && p.status === 'published' && !p.isArchived, + (p) => p.isRequiredToSign && p.status === 'published' && !p.isArchived, ); const memberIds = activeEmployees.map((e) => e.id); @@ -92,7 +95,11 @@ export async function getOverviewScores(organizationId: string) { ? completedVideoIds.includes(HIPAA_TRAINING_ID) : true; - if (hasAcceptedAllPolicies && hasCompletedAllTraining && hasCompletedHipaa) { + if ( + hasAcceptedAllPolicies && + hasCompletedAllTraining && + hasCompletedHipaa + ) { completedMembers++; } } @@ -138,18 +145,25 @@ async function computeDocumentsScore(organizationId: string) { }; } - const includedForms = evidenceFormDefinitionList.filter((f) => !f.hidden && !f.optional); + const includedForms = evidenceFormDefinitionList.filter( + (f) => !f.hidden && !f.optional, + ); const totalDocuments = includedForms.length; const outstandingDocuments = includedForms.reduce((count, form) => { if (form.type === 'meeting') { const allMeetingsOutstanding = meetingSubTypeValues.every((subType) => { const lastSubmitted = statuses[subType]?.lastSubmittedAt; - return !lastSubmitted || Date.now() - new Date(lastSubmitted).getTime() > SIX_MONTHS_MS; + return ( + !lastSubmitted || + Date.now() - new Date(lastSubmitted).getTime() > SIX_MONTHS_MS + ); }); return allMeetingsOutstanding ? count + 1 : count; } const lastSubmitted = statuses[form.type]?.lastSubmittedAt; - const isOutstanding = !lastSubmitted || Date.now() - new Date(lastSubmitted).getTime() > SIX_MONTHS_MS; + const isOutstanding = + !lastSubmitted || + Date.now() - new Date(lastSubmitted).getTime() > SIX_MONTHS_MS; return isOutstanding ? count + 1 : count; }, 0); @@ -184,10 +198,7 @@ async function getOrganizationFindings(organizationId: string) { })); } -export async function getCurrentMember( - organizationId: string, - userId: string, -) { +export async function getCurrentMember(organizationId: string, userId: string) { const member = await db.member.findFirst({ where: { userId, organizationId, deactivated: false }, select: { id: true, role: true }, diff --git a/apps/api/src/frameworks/frameworks.controller.spec.ts b/apps/api/src/frameworks/frameworks.controller.spec.ts index 01554ddd53..a8250426f1 100644 --- a/apps/api/src/frameworks/frameworks.controller.spec.ts +++ b/apps/api/src/frameworks/frameworks.controller.spec.ts @@ -40,8 +40,16 @@ describe('FrameworksController', () => { describe('findAll', () => { it('should return framework instances with count', async () => { const mockData = [ - { id: 'fi1', frameworkId: 'f1', framework: { id: 'f1', name: 'ISO 27001' } }, - { id: 'fi2', frameworkId: 'f2', framework: { id: 'f2', name: 'SOC 2' } }, + { + id: 'fi1', + frameworkId: 'f1', + framework: { id: 'f1', name: 'ISO 27001' }, + }, + { + id: 'fi2', + frameworkId: 'f2', + framework: { id: 'f2', name: 'SOC 2' }, + }, ]; mockService.findAll.mockResolvedValue(mockData); diff --git a/apps/api/src/frameworks/frameworks.controller.ts b/apps/api/src/frameworks/frameworks.controller.ts index e3e03cb9ea..b8e5c3002d 100644 --- a/apps/api/src/frameworks/frameworks.controller.ts +++ b/apps/api/src/frameworks/frameworks.controller.ts @@ -8,7 +8,12 @@ import { Query, UseGuards, } from '@nestjs/common'; -import { ApiTags, ApiBearerAuth, ApiOperation, ApiQuery } from '@nestjs/swagger'; +import { + ApiTags, + ApiBearerAuth, + ApiOperation, + ApiQuery, +} from '@nestjs/swagger'; import { HybridAuthGuard } from '../auth/hybrid-auth.guard'; import { PermissionGuard } from '../auth/permission.guard'; import { RequirePermission } from '../auth/require-permission.decorator'; @@ -44,7 +49,10 @@ export class FrameworksController { @Get('available') @SkipOrgCheck() - @ApiOperation({ summary: 'List available frameworks (requires session, no active org needed — used during onboarding)' }) + @ApiOperation({ + summary: + 'List available frameworks (requires session, no active org needed — used during onboarding)', + }) async findAvailable() { const data = await this.frameworksService.findAvailable(); return { data, count: data.length }; diff --git a/apps/api/src/frameworks/frameworks.service.spec.ts b/apps/api/src/frameworks/frameworks.service.spec.ts index b3dba39007..f714be36bb 100644 --- a/apps/api/src/frameworks/frameworks.service.spec.ts +++ b/apps/api/src/frameworks/frameworks.service.spec.ts @@ -91,7 +91,9 @@ describe('FrameworksService', () => { }); it('should throw NotFoundException when instance not found', async () => { - (mockDb.frameworkInstance.findUnique as jest.Mock).mockResolvedValue(null); + (mockDb.frameworkInstance.findUnique as jest.Mock).mockResolvedValue( + null, + ); await expect(service.delete('missing', 'org_1')).rejects.toThrow( NotFoundException, diff --git a/apps/api/src/frameworks/frameworks.service.ts b/apps/api/src/frameworks/frameworks.service.ts index 149a5f8b04..311783ec2c 100644 --- a/apps/api/src/frameworks/frameworks.service.ts +++ b/apps/api/src/frameworks/frameworks.service.ts @@ -129,31 +129,35 @@ export class FrameworksService { } } - const [requirementDefinitions, tasks, requirementMaps, evidenceSubmissions] = - await Promise.all([ - db.frameworkEditorRequirement.findMany({ - where: { frameworkId: fi.frameworkId }, - orderBy: { name: 'asc' }, - }), - db.task.findMany({ - where: { organizationId, controls: { some: { organizationId } } }, - include: { controls: true }, - }), - db.requirementMap.findMany({ - where: { frameworkInstanceId }, - include: { control: true }, - }), - allFormTypes.size > 0 - ? db.evidenceSubmission.findMany({ - where: { - organizationId, - formType: { in: Array.from(allFormTypes) }, - }, - select: { id: true, formType: true, createdAt: true }, - orderBy: { createdAt: 'desc' }, - }) - : Promise.resolve([]), - ]); + const [ + requirementDefinitions, + tasks, + requirementMaps, + evidenceSubmissions, + ] = await Promise.all([ + db.frameworkEditorRequirement.findMany({ + where: { frameworkId: fi.frameworkId }, + orderBy: { name: 'asc' }, + }), + db.task.findMany({ + where: { organizationId, controls: { some: { organizationId } } }, + include: { controls: true }, + }), + db.requirementMap.findMany({ + where: { frameworkInstanceId }, + include: { control: true }, + }), + allFormTypes.size > 0 + ? db.evidenceSubmission.findMany({ + where: { + organizationId, + formType: { in: Array.from(allFormTypes) }, + }, + select: { id: true, formType: true, createdAt: true }, + orderBy: { createdAt: 'desc' }, + }) + : Promise.resolve([]), + ]); return { ...rest, @@ -181,10 +185,7 @@ export class FrameworksService { return { ...scores, currentMember }; } - async addFrameworks( - organizationId: string, - frameworkIds: string[], - ) { + async addFrameworks(organizationId: string, frameworkIds: string[]) { const result = await db.$transaction(async (tx) => { const frameworks = await tx.frameworkEditorFramework.findMany({ where: { id: { in: frameworkIds }, visible: true }, @@ -262,16 +263,17 @@ export class FrameworksService { } } - const evidenceSubmissions = formTypes.size > 0 - ? await db.evidenceSubmission.findMany({ - where: { - organizationId, - formType: { in: Array.from(formTypes) }, - }, - select: { id: true, formType: true, createdAt: true }, - orderBy: { createdAt: 'desc' }, - }) - : []; + const evidenceSubmissions = + formTypes.size > 0 + ? await db.evidenceSubmission.findMany({ + where: { + organizationId, + formType: { in: Array.from(formTypes) }, + }, + select: { id: true, formType: true, createdAt: true }, + orderBy: { createdAt: 'desc' }, + }) + : []; const siblingRequirements = allReqDefs .filter((r) => r.id !== requirementKey) diff --git a/apps/api/src/integration-platform/controllers/admin-integrations.controller.ts b/apps/api/src/integration-platform/controllers/admin-integrations.controller.ts index 3f525b57e8..bf3526a08e 100644 --- a/apps/api/src/integration-platform/controllers/admin-integrations.controller.ts +++ b/apps/api/src/integration-platform/controllers/admin-integrations.controller.ts @@ -72,8 +72,11 @@ export class AdminIntegrationsController { encryptedClientId: credential?.encryptedClientId, encryptedClientSecret: credential?.encryptedClientSecret, existingCustomSettings: - (credential as { customSettings?: Record } | undefined) - ?.customSettings || undefined, + ( + credential as + | { customSettings?: Record } + | undefined + )?.customSettings || undefined, ...(manifest.auth.type === 'oauth2' && { setupInstructions: manifest.auth.config.setupInstructions, createAppUrl: manifest.auth.config.createAppUrl, diff --git a/apps/api/src/integration-platform/controllers/checks.controller.ts b/apps/api/src/integration-platform/controllers/checks.controller.ts index e2e7881e2f..ccd5134a67 100644 --- a/apps/api/src/integration-platform/controllers/checks.controller.ts +++ b/apps/api/src/integration-platform/controllers/checks.controller.ts @@ -76,7 +76,10 @@ export class ChecksController { @Param('connectionId') connectionId: string, @OrganizationId() organizationId: string, ) { - await this.connectionService.getConnectionForOrg(connectionId, organizationId); + await this.connectionService.getConnectionForOrg( + connectionId, + organizationId, + ); const connection = await this.connectionRepository.findById(connectionId); if (!connection) { throw new HttpException('Connection not found', HttpStatus.NOT_FOUND); @@ -116,7 +119,10 @@ export class ChecksController { @Body() body: RunChecksDto, @OrganizationId() organizationId: string, ) { - await this.connectionService.getConnectionForOrg(connectionId, organizationId); + await this.connectionService.getConnectionForOrg( + connectionId, + organizationId, + ); const connection = await this.connectionRepository.findById(connectionId); if (!connection) { throw new HttpException('Connection not found', HttpStatus.NOT_FOUND); @@ -291,9 +297,10 @@ export class ChecksController { totalChecked: result.results.length, passedCount: result.totalPassing, failedCount: result.totalFindings, - logs: allLogs.length > 0 - ? (allLogs as unknown as Prisma.InputJsonValue) - : undefined, + logs: + allLogs.length > 0 + ? (allLogs as unknown as Prisma.InputJsonValue) + : undefined, }); return { @@ -305,7 +312,8 @@ export class ChecksController { } catch (error) { // Mark the check run as failed with error details const startTime = checkRun.startedAt?.getTime() || Date.now(); - const errorMessage = error instanceof Error ? error.message : String(error); + const errorMessage = + error instanceof Error ? error.message : String(error); const errorStack = error instanceof Error ? error.stack : undefined; await this.checkRunRepository.complete(checkRun.id, { status: 'failed', @@ -314,13 +322,15 @@ export class ChecksController { passedCount: 0, failedCount: 0, errorMessage, - logs: [{ - check: body.checkId || 'all', - level: 'error', - message: errorMessage, - ...(errorStack ? { data: { stack: errorStack } } : {}), - timestamp: new Date().toISOString(), - }] as unknown as Prisma.InputJsonValue, + logs: [ + { + check: body.checkId || 'all', + level: 'error', + message: errorMessage, + ...(errorStack ? { data: { stack: errorStack } } : {}), + timestamp: new Date().toISOString(), + }, + ] as unknown as Prisma.InputJsonValue, }); this.logger.error(`Check execution failed: ${error}`); diff --git a/apps/api/src/integration-platform/controllers/connections.controller.spec.ts b/apps/api/src/integration-platform/controllers/connections.controller.spec.ts index a53dad789b..31da8f9477 100644 --- a/apps/api/src/integration-platform/controllers/connections.controller.spec.ts +++ b/apps/api/src/integration-platform/controllers/connections.controller.spec.ts @@ -504,10 +504,7 @@ describe('ConnectionsController', () => { api_key: 'test-key', }); - const result = await controller.ensureValidCredentials( - 'conn_1', - 'org_1', - ); + const result = await controller.ensureValidCredentials('conn_1', 'org_1'); expect(result.success).toBe(true); expect(result.credentials).toEqual({ api_key: 'test-key' }); diff --git a/apps/api/src/integration-platform/controllers/connections.controller.ts b/apps/api/src/integration-platform/controllers/connections.controller.ts index 30df781e74..e318dce508 100644 --- a/apps/api/src/integration-platform/controllers/connections.controller.ts +++ b/apps/api/src/integration-platform/controllers/connections.controller.ts @@ -14,11 +14,12 @@ import { UseGuards, } from '@nestjs/common'; import { ApiTags, ApiSecurity } from '@nestjs/swagger'; -import { AssumeRoleCommand, STSClient } from '@aws-sdk/client-sts'; +import { db } from '@db'; import { - DescribeHubCommand, - SecurityHubClient, -} from '@aws-sdk/client-securityhub'; + AssumeRoleCommand, + GetCallerIdentityCommand, + STSClient, +} from '@aws-sdk/client-sts'; import { HybridAuthGuard } from '../../auth/hybrid-auth.guard'; import { PermissionGuard } from '../../auth/permission.guard'; import { RequirePermission } from '../../auth/require-permission.decorator'; @@ -28,6 +29,7 @@ import { CredentialVaultService } from '../services/credential-vault.service'; import { OAuthCredentialsService } from '../services/oauth-credentials.service'; import { AutoCheckRunnerService } from '../services/auto-check-runner.service'; import { ProviderRepository } from '../repositories/provider.repository'; +import { ConnectionRepository } from '../repositories/connection.repository'; import { getManifest, getAllManifests, @@ -63,6 +65,7 @@ export class ConnectionsController { private readonly oauthCredentialsService: OAuthCredentialsService, private readonly autoCheckRunnerService: AutoCheckRunnerService, private readonly providerRepository: ProviderRepository, + private readonly connectionRepository: ConnectionRepository, ) {} /** @@ -99,6 +102,9 @@ export class ConnectionsController { const setupInstructions = m.auth.type === 'custom' ? m.auth.config.setupInstructions : undefined; + const setupScript = + m.auth.type === 'custom' ? m.auth.config.setupScript : undefined; + // For OAuth providers, check if platform credentials are configured const oauthConfigured = m.auth.type === 'oauth2' @@ -147,10 +153,19 @@ export class ConnectionsController { docsUrl: m.docsUrl, credentialFields, setupInstructions, + setupScript, oauthConfigured, mappedTasks, requiredVariables: Array.from(requiredVariables), supportsMultipleConnections: m.supportsMultipleConnections ?? false, + services: + m.services?.map((s) => ({ + id: s.id, + name: s.name, + description: s.description, + enabledByDefault: s.enabledByDefault ?? true, + implemented: s.implemented ?? true, + })) ?? [], }; }); } @@ -181,6 +196,11 @@ export class ConnectionsController { ? manifest.auth.config.setupInstructions : undefined; + const setupScript = + manifest.auth.type === 'custom' + ? manifest.auth.config.setupScript + : undefined; + // Get mapped tasks from checks const mappedTasks: Array<{ id: string; name: string }> = []; const seenTaskIds = new Set(); @@ -225,8 +245,19 @@ export class ConnectionsController { docsUrl: manifest.docsUrl, credentialFields, setupInstructions, + setupScript, mappedTasks, requiredVariables: Array.from(requiredVariables), + supportsMultipleConnections: + manifest.supportsMultipleConnections ?? false, + services: + manifest.services?.map((s) => ({ + id: s.id, + name: s.name, + description: s.description, + enabledByDefault: s.enabledByDefault ?? true, + implemented: s.implemented ?? true, + })) ?? [], }; } @@ -264,7 +295,10 @@ export class ConnectionsController { @Param('id') id: string, @OrganizationId() organizationId: string, ) { - const connection = await this.connectionService.getConnectionForOrg(id, organizationId); + const connection = await this.connectionService.getConnectionForOrg( + id, + organizationId, + ); const providerSlug = (connection as { provider?: { slug: string } }) .provider?.slug; @@ -289,6 +323,39 @@ export class ConnectionsController { } } + // Backfill metadata from credentials if missing (for connections created before metadata sync) + let metadata = (connection.metadata ?? {}) as Record; + if (providerSlug === 'aws' && !metadata.accountId) { + try { + const creds = + await this.credentialVaultService.getDecryptedCredentials(id); + if (creds) { + const updates: Record = {}; + if (typeof creds.roleArn === 'string') { + updates.roleArn = creds.roleArn; + const m = creds.roleArn.match(/^arn:aws:iam::(\d{12}):role\/.+$/); + if (m) updates.accountId = m[1]; + } + if (typeof creds.remediationRoleArn === 'string') { + updates.remediationRoleArn = creds.remediationRoleArn; + } + if (Array.isArray(creds.regions)) { + updates.regions = creds.regions; + } + if (typeof creds.externalId === 'string') { + updates.externalId = creds.externalId; + } + if (Object.keys(updates).length > 0) { + metadata = { ...metadata, ...updates }; + // Persist so this only runs once + await this.connectionRepository.update(id, { metadata }); + } + } + } catch { + // Non-critical — just use whatever metadata we have + } + } + return { id: connection.id, providerId: connection.providerId, @@ -300,7 +367,7 @@ export class ConnectionsController { lastSyncAt: connection.lastSyncAt, nextSyncAt: connection.nextSyncAt, syncCadence: connection.syncCadence, - metadata: connection.metadata, + metadata, variables: connection.variables, errorMessage: connection.errorMessage, createdAt: connection.createdAt, @@ -390,6 +457,12 @@ export class ConnectionsController { if (typeof credentials.externalId === 'string') { metadata.externalId = credentials.externalId; } + if ( + typeof credentials.remediationRoleArn === 'string' && + credentials.remediationRoleArn + ) { + metadata.remediationRoleArn = credentials.remediationRoleArn; + } // Store Azure tenant/subscription IDs in metadata for display and pre-filling if (typeof credentials.tenantId === 'string') { metadata.tenantId = credentials.tenantId; @@ -549,84 +622,35 @@ export class ConnectionsController { } this.logger.log( - 'Validating AWS: Role assumption successful, checking Security Hub...', + 'Validating AWS: Role assumption successful, verifying identity...', ); - // Step 3: Check Security Hub in each region + // Step 3: Verify assumed identity works const awsCredentials = { accessKeyId: customerCreds.AccessKeyId, secretAccessKey: customerCreds.SecretAccessKey, sessionToken: customerCreds.SessionToken, }; - const regionResults: { - region: string; - enabled: boolean; - error?: string; - }[] = []; - - for (const region of regions) { - try { - const securityHub = new SecurityHubClient({ - region, - credentials: awsCredentials, - }); - - await securityHub.send(new DescribeHubCommand({})); - regionResults.push({ region, enabled: true }); - this.logger.log(`Security Hub is enabled in ${region}`); - } catch (error) { - const errorMessage = - error instanceof Error ? error.message : String(error); - if ( - errorMessage.includes('not subscribed') || - errorMessage.includes('InvalidAccessException') - ) { - regionResults.push({ - region, - enabled: false, - error: 'Security Hub not enabled', - }); - this.logger.warn(`Security Hub not enabled in ${region}`); - } else if (errorMessage.includes('AccessDenied')) { - regionResults.push({ - region, - enabled: false, - error: 'Access denied - check IAM permissions', - }); - } else { - regionResults.push({ region, enabled: false, error: errorMessage }); - } - } - } - - // Check if ALL regions have Security Hub enabled - const failedRegions = regionResults.filter((r) => !r.enabled); - - if (failedRegions.length > 0) { - const failedRegionNames = failedRegions.map((r) => r.region).join(', '); - const errorMsg = - failedRegions.length === 1 - ? `Security Hub is not enabled in region: ${failedRegionNames}. Please enable Security Hub in this region or remove it from your selection.` - : `Security Hub is not enabled in ${failedRegions.length} regions: ${failedRegionNames}. Please enable Security Hub in these regions or remove them from your selection.`; - - return { - success: false, - message: errorMsg, - details: { regions: regionResults }, - }; - } + const customerSts = new STSClient({ + region: primaryRegion, + credentials: awsCredentials, + }); + const identity = await customerSts.send(new GetCallerIdentityCommand({})); + this.logger.log( + `Validated AWS identity: ${identity.Arn} (Account: ${identity.Account})`, + ); // All validations passed! const message = regions.length === 1 - ? `Validated! Security Hub is enabled in ${regions[0]}.` - : `Validated! Security Hub is enabled in all ${regions.length} regions.`; + ? `Validated! Connected to AWS account ${identity.Account} in ${regions[0]}.` + : `Validated! Connected to AWS account ${identity.Account} in ${regions.length} regions.`; return { success: true, message, - details: { regions: regionResults }, + details: { account: identity.Account, regions }, }; } catch (err) { const errorMessage = @@ -661,7 +685,10 @@ export class ConnectionsController { @Param('id') id: string, @OrganizationId() organizationId: string, ) { - const connection = await this.connectionService.getConnectionForOrg(id, organizationId); + const connection = await this.connectionService.getConnectionForOrg( + id, + organizationId, + ); const providerSlug = (connection as any).provider?.slug; if (!providerSlug) { @@ -811,7 +838,10 @@ export class ConnectionsController { @OrganizationId() organizationId: string, @Body() body: { metadata?: Record }, ) { - const connection = await this.connectionService.getConnectionForOrg(id, organizationId); + const connection = await this.connectionService.getConnectionForOrg( + id, + organizationId, + ); if (body.metadata && Object.keys(body.metadata).length > 0) { // Merge with existing metadata @@ -840,7 +870,10 @@ export class ConnectionsController { @Param('id') id: string, @OrganizationId() organizationId: string, ) { - const connection = await this.connectionService.getConnectionForOrg(id, organizationId); + const connection = await this.connectionService.getConnectionForOrg( + id, + organizationId, + ); if (connection.status !== 'active') { throw new HttpException( @@ -990,6 +1023,74 @@ export class ConnectionsController { }; } + /** + * Update enabled services for a connection + */ + @Put(':id/services') + @RequirePermission('integration', 'update') + async updateConnectionServices( + @Param('id') id: string, + @Body() body: { services: string[] }, + @OrganizationId() organizationId: string, + ) { + if (!Array.isArray(body.services)) { + throw new HttpException( + 'services must be an array of service IDs', + HttpStatus.BAD_REQUEST, + ); + } + + const connection = await this.connectionService.getConnectionForOrg( + id, + organizationId, + ); + + const raw = connection.variables; + const existingVariables: Record = + raw && typeof raw === 'object' && !Array.isArray(raw) + ? (raw as Record) + : {}; + + // Get ALL possible services from the manifest + const provider = await db.integrationProvider.findUnique({ + where: { id: connection.providerId }, + select: { slug: true }, + }); + const manifest = provider ? getManifest(provider.slug) : null; + const allManifestServices = new Set( + manifest?.services?.map((s: { id: string }) => s.id) ?? [], + ); + + // disabledServices = all manifest services MINUS what user sent as enabled + const enabledSet = new Set(body.services); + const disabledServices = [...allManifestServices].filter( + (s) => !enabledSet.has(s), + ); + + // Merge user-enabled services into detectedServices so the GET + // logic treats them as "known" services (user intent > auto-detection) + const currentDetected = new Set( + Array.isArray(existingVariables.detectedServices) + ? (existingVariables.detectedServices as string[]) + : [], + ); + for (const id of body.services) { + currentDetected.add(id); + } + + await this.connectionRepository.update(id, { + variables: { + ...existingVariables, + disabledServices, + detectedServices: [...currentDetected], + // Clear legacy enabledServices to use new smart logic + enabledServices: undefined, + }, + }); + + return { success: true, disabledServices }; + } + /** * Update credentials for a custom auth connection */ @@ -1000,7 +1101,10 @@ export class ConnectionsController { @OrganizationId() organizationId: string, @Body() body: { credentials: Record }, ) { - const connection = await this.connectionService.getConnectionForOrg(id, organizationId); + const connection = await this.connectionService.getConnectionForOrg( + id, + organizationId, + ); const providerSlug = (connection as { provider?: { slug: string } }) .provider?.slug; @@ -1059,6 +1163,29 @@ export class ConnectionsController { mergedCredentials, ); + // Sync non-secret fields to metadata for display (pre-fill settings forms) + const metaUpdates: Record = {}; + if (typeof mergedCredentials.roleArn === 'string') { + metaUpdates.roleArn = mergedCredentials.roleArn; + const arnMatch = mergedCredentials.roleArn.match( + /^arn:aws:iam::(\d{12}):role\/.+$/, + ); + if (arnMatch) metaUpdates.accountId = arnMatch[1]; + } + if (typeof mergedCredentials.remediationRoleArn === 'string') { + metaUpdates.remediationRoleArn = mergedCredentials.remediationRoleArn; + } + if (Array.isArray(mergedCredentials.regions)) { + metaUpdates.regions = mergedCredentials.regions; + } + if (Object.keys(metaUpdates).length > 0) { + const existingMeta = + (connection.metadata as Record) ?? {}; + await this.connectionRepository.update(id, { + metadata: { ...existingMeta, ...metaUpdates }, + }); + } + // Only activate the connection if it was in error state (don't resume paused connections) if (connection.status === 'error') { await this.connectionService.activateConnection(id); diff --git a/apps/api/src/integration-platform/controllers/dynamic-integrations.controller.ts b/apps/api/src/integration-platform/controllers/dynamic-integrations.controller.ts index f906bbad65..5a7b61546c 100644 --- a/apps/api/src/integration-platform/controllers/dynamic-integrations.controller.ts +++ b/apps/api/src/integration-platform/controllers/dynamic-integrations.controller.ts @@ -47,7 +47,10 @@ export class DynamicIntegrationsController { const validation = validateIntegrationDefinition(body); if (!validation.success) { throw new HttpException( - { message: 'Invalid integration definition', errors: validation.errors }, + { + message: 'Invalid integration definition', + errors: validation.errors, + }, HttpStatus.BAD_REQUEST, ); } @@ -55,7 +58,7 @@ export class DynamicIntegrationsController { const def = validation.data!; // Validate and store syncDefinition through Zod to apply defaults (e.g., employeesPath) - const rawSyncDef = (body as Record).syncDefinition; + const rawSyncDef = body.syncDefinition; const validatedSyncDef = rawSyncDef ? SyncDefinitionSchema.parse(rawSyncDef) : undefined; @@ -74,12 +77,17 @@ export class DynamicIntegrationsController { capabilities: def.capabilities as unknown as Prisma.InputJsonValue, supportsMultipleConnections: def.supportsMultipleConnections, syncDefinition: validatedSyncDef - ? (JSON.parse(JSON.stringify(validatedSyncDef)) as Prisma.InputJsonValue) + ? (JSON.parse( + JSON.stringify(validatedSyncDef), + ) as Prisma.InputJsonValue) : null, + services: (def.services as unknown as Prisma.InputJsonValue) ?? undefined, }); // Delete checks not in the new definition, then upsert the rest - const existingChecks = await this.dynamicCheckRepo.findByIntegrationId(integration.id); + const existingChecks = await this.dynamicCheckRepo.findByIntegrationId( + integration.id, + ); const newCheckSlugs = new Set(def.checks.map((c) => c.checkSlug)); for (const existing of existingChecks) { if (!newCheckSlugs.has(existing.checkSlug)) { @@ -99,6 +107,7 @@ export class DynamicIntegrationsController { variables: (check.variables ?? []) as unknown as Prisma.InputJsonValue, isEnabled: check.isEnabled ?? true, sortOrder: check.sortOrder ?? index, + service: check.service ?? undefined, }); } @@ -114,7 +123,9 @@ export class DynamicIntegrationsController { // Refresh registry await this.loaderService.invalidateCache(); - this.logger.log(`Upserted dynamic integration: ${def.slug} with ${def.checks.length} checks`); + this.logger.log( + `Upserted dynamic integration: ${def.slug} with ${def.checks.length} checks`, + ); return { success: true, @@ -132,7 +143,10 @@ export class DynamicIntegrationsController { const validation = validateIntegrationDefinition(body); if (!validation.success) { throw new HttpException( - { message: 'Invalid integration definition', errors: validation.errors }, + { + message: 'Invalid integration definition', + errors: validation.errors, + }, HttpStatus.BAD_REQUEST, ); } @@ -147,7 +161,7 @@ export class DynamicIntegrationsController { ); } - const rawSyncDefCreate = (body as Record).syncDefinition; + const rawSyncDefCreate = body.syncDefinition; const validatedSyncDefCreate = rawSyncDefCreate ? SyncDefinitionSchema.parse(rawSyncDefCreate) : undefined; @@ -164,7 +178,9 @@ export class DynamicIntegrationsController { capabilities: def.capabilities as unknown as Prisma.InputJsonValue, supportsMultipleConnections: def.supportsMultipleConnections, syncDefinition: validatedSyncDefCreate - ? (JSON.parse(JSON.stringify(validatedSyncDefCreate)) as Prisma.InputJsonValue) + ? (JSON.parse( + JSON.stringify(validatedSyncDefCreate), + ) as Prisma.InputJsonValue) : undefined, }); @@ -194,7 +210,9 @@ export class DynamicIntegrationsController { await this.loaderService.invalidateCache(); - this.logger.log(`Created dynamic integration: ${def.slug} with ${def.checks.length} checks`); + this.logger.log( + `Created dynamic integration: ${def.slug} with ${def.checks.length} checks`, + ); return { success: true, id: integration.id, slug: integration.slug }; } @@ -225,7 +243,10 @@ export class DynamicIntegrationsController { async getById(@Param('id') id: string) { const integration = await this.dynamicIntegrationRepo.findById(id); if (!integration) { - throw new HttpException('Dynamic integration not found', HttpStatus.NOT_FOUND); + throw new HttpException( + 'Dynamic integration not found', + HttpStatus.NOT_FOUND, + ); } return integration; } @@ -237,7 +258,10 @@ export class DynamicIntegrationsController { async update(@Param('id') id: string, @Body() body: Record) { const existing = await this.dynamicIntegrationRepo.findById(id); if (!existing) { - throw new HttpException('Dynamic integration not found', HttpStatus.NOT_FOUND); + throw new HttpException( + 'Dynamic integration not found', + HttpStatus.NOT_FOUND, + ); } await this.dynamicIntegrationRepo.update(id, body); @@ -253,7 +277,10 @@ export class DynamicIntegrationsController { async remove(@Param('id') id: string) { const existing = await this.dynamicIntegrationRepo.findById(id); if (!existing) { - throw new HttpException('Dynamic integration not found', HttpStatus.NOT_FOUND); + throw new HttpException( + 'Dynamic integration not found', + HttpStatus.NOT_FOUND, + ); } await this.dynamicIntegrationRepo.delete(id); @@ -275,7 +302,10 @@ export class DynamicIntegrationsController { ) { const integration = await this.dynamicIntegrationRepo.findById(id); if (!integration) { - throw new HttpException('Dynamic integration not found', HttpStatus.NOT_FOUND); + throw new HttpException( + 'Dynamic integration not found', + HttpStatus.NOT_FOUND, + ); } const check = await this.dynamicCheckRepo.create({ @@ -344,7 +374,10 @@ export class DynamicIntegrationsController { async activate(@Param('id') id: string) { const integration = await this.dynamicIntegrationRepo.findById(id); if (!integration) { - throw new HttpException('Dynamic integration not found', HttpStatus.NOT_FOUND); + throw new HttpException( + 'Dynamic integration not found', + HttpStatus.NOT_FOUND, + ); } for (const check of integration.checks) { @@ -360,7 +393,9 @@ export class DynamicIntegrationsController { slug: integration.slug, name: integration.name, category: integration.category, - capabilities: (integration.capabilities as unknown as string[]) ?? ['checks'], + capabilities: (integration.capabilities as unknown as string[]) ?? [ + 'checks', + ], isActive: true, }); @@ -378,7 +413,10 @@ export class DynamicIntegrationsController { async deactivate(@Param('id') id: string) { const integration = await this.dynamicIntegrationRepo.findById(id); if (!integration) { - throw new HttpException('Dynamic integration not found', HttpStatus.NOT_FOUND); + throw new HttpException( + 'Dynamic integration not found', + HttpStatus.NOT_FOUND, + ); } await this.dynamicIntegrationRepo.update(id, { isActive: false }); @@ -419,7 +457,7 @@ export class DynamicIntegrationsController { capabilities: definition.capabilities, checksCount: definition.checks.length, checkSlugs: definition.checks.map((c) => c.checkSlug), - hasSyncDefinition: !!(body as Record).syncDefinition, + hasSyncDefinition: !!body.syncDefinition, }, }; } @@ -432,7 +470,10 @@ export class DynamicIntegrationsController { async getCheckRuns(@Param('id') id: string) { const integration = await this.dynamicIntegrationRepo.findById(id); if (!integration) { - throw new HttpException('Dynamic integration not found', HttpStatus.NOT_FOUND); + throw new HttpException( + 'Dynamic integration not found', + HttpStatus.NOT_FOUND, + ); } // Find all connections for this provider @@ -518,7 +559,10 @@ export class DynamicIntegrationsController { logs: run.logs, results: run.results, provider: run.connection?.provider - ? { slug: run.connection.provider.slug, name: run.connection.provider.name } + ? { + slug: run.connection.provider.slug, + name: run.connection.provider.name, + } : null, }; } diff --git a/apps/api/src/integration-platform/controllers/oauth-apps.controller.ts b/apps/api/src/integration-platform/controllers/oauth-apps.controller.ts index 0fc392f700..4b71d37054 100644 --- a/apps/api/src/integration-platform/controllers/oauth-apps.controller.ts +++ b/apps/api/src/integration-platform/controllers/oauth-apps.controller.ts @@ -111,12 +111,7 @@ export class OAuthAppsController { @OrganizationId() organizationId: string, @Body() body: SaveOAuthAppDto, ) { - const { - providerSlug, - clientId, - clientSecret, - customScopes, - } = body; + const { providerSlug, clientId, clientSecret, customScopes } = body; // Validate provider const manifest = getManifest(providerSlug); diff --git a/apps/api/src/integration-platform/controllers/oauth.controller.spec.ts b/apps/api/src/integration-platform/controllers/oauth.controller.spec.ts index 9577a09a91..a48d75e5b2 100644 --- a/apps/api/src/integration-platform/controllers/oauth.controller.spec.ts +++ b/apps/api/src/integration-platform/controllers/oauth.controller.spec.ts @@ -10,11 +10,20 @@ import { CredentialVaultService } from '../services/credential-vault.service'; import { ConnectionService } from '../services/connection.service'; import { OAuthCredentialsService } from '../services/oauth-credentials.service'; import { AutoCheckRunnerService } from '../services/auto-check-runner.service'; +import { CloudSecurityService } from '../../cloud-security/cloud-security.service'; jest.mock('../../auth/auth.server', () => ({ auth: { api: { getSession: jest.fn() } }, })); +jest.mock('../../auth/hybrid-auth.guard', () => ({ + HybridAuthGuard: class HybridAuthGuard {}, +})); + +jest.mock('../../auth/permission.guard', () => ({ + PermissionGuard: class PermissionGuard {}, +})); + jest.mock('@trycompai/auth', () => ({ statement: { integration: ['create', 'read', 'update', 'delete'], @@ -26,11 +35,21 @@ jest.mock('@trycompai/integration-platform', () => ({ getManifest: jest.fn(), })); +jest.mock('@trigger.dev/sdk', () => ({ + tasks: { + trigger: jest.fn(), + }, +})); + import { getManifest } from '@trycompai/integration-platform'; +import { tasks } from '@trigger.dev/sdk'; const mockedGetManifest = getManifest as jest.MockedFunction< typeof getManifest >; +const mockedTriggerTask = tasks.trigger as jest.MockedFunction< + typeof tasks.trigger +>; describe('OAuthController', () => { let controller: OAuthController; @@ -48,6 +67,7 @@ describe('OAuthController', () => { const mockConnectionRepository = { findByProviderAndOrg: jest.fn(), + update: jest.fn(), }; const mockCredentialVaultService = { @@ -56,6 +76,7 @@ describe('OAuthController', () => { const mockConnectionService = { createConnection: jest.fn(), + activateConnection: jest.fn(), }; const mockOAuthCredentialsService = { @@ -67,6 +88,10 @@ describe('OAuthController', () => { tryAutoRunChecks: jest.fn().mockResolvedValue(false), }; + const mockCloudSecurityService = { + detectServices: jest.fn().mockResolvedValue([]), + }; + const mockGuard = { canActivate: jest.fn().mockReturnValue(true) }; beforeEach(async () => { @@ -89,6 +114,10 @@ describe('OAuthController', () => { provide: AutoCheckRunnerService, useValue: mockAutoCheckRunnerService, }, + { + provide: CloudSecurityService, + useValue: mockCloudSecurityService, + }, ], }) .overrideGuard(HybridAuthGuard) @@ -101,6 +130,8 @@ describe('OAuthController', () => { jest.clearAllMocks(); mockAutoCheckRunnerService.tryAutoRunChecks.mockResolvedValue(false); + mockCloudSecurityService.detectServices.mockResolvedValue([]); + mockedTriggerTask.mockResolvedValue({ id: 'run_1' } as never); }); describe('checkAvailability', () => { @@ -122,9 +153,9 @@ describe('OAuthController', () => { }); it('should throw BAD_REQUEST when providerSlug is empty', async () => { - await expect( - controller.checkAvailability('', 'org_1'), - ).rejects.toThrow(HttpException); + await expect(controller.checkAvailability('', 'org_1')).rejects.toThrow( + HttpException, + ); }); }); @@ -296,10 +327,7 @@ describe('OAuthController', () => { }); it('should redirect with error when code or state is missing', async () => { - await controller.oauthCallback( - { code: '', state: '' }, - mockResponse, - ); + await controller.oauthCallback({ code: '', state: '' }, mockResponse); expect(mockResponse.redirect).toHaveBeenCalled(); const redirectUrl = (mockResponse.redirect as jest.Mock).mock.calls[0][0]; @@ -367,5 +395,168 @@ describe('OAuthController', () => { const redirectUrl = (mockResponse.redirect as jest.Mock).mock.calls[0][0]; expect(redirectUrl).toContain('error=token_exchange_failed'); }); + + it('should trigger initial GCP service discovery scan on successful first connect', async () => { + const futureDate = new Date(Date.now() + 600000); + mockOAuthStateRepository.findByState.mockResolvedValue({ + state: 'valid_gcp_state', + providerSlug: 'gcp', + organizationId: 'org_1', + userId: 'user_1', + codeVerifier: null, + redirectUrl: null, + expiresAt: futureDate, + }); + mockedGetManifest.mockReturnValue({ + id: 'gcp', + name: 'Google Cloud Platform', + category: 'Cloud', + auth: { + type: 'oauth2', + config: { + authorizeUrl: 'https://accounts.google.com/o/oauth2/v2/auth', + tokenUrl: 'https://oauth2.googleapis.com/token', + }, + }, + capabilities: [], + isActive: true, + } as never); + mockOAuthCredentialsService.getCredentials.mockResolvedValue({ + clientId: 'client_123', + clientSecret: 'secret_456', + scopes: ['openid'], + source: 'platform', + }); + mockProviderRepository.findBySlug.mockResolvedValue({ + id: 'provider_1', + }); + mockConnectionRepository.findByProviderAndOrg.mockResolvedValue({ + id: 'conn_1', + metadata: {}, + variables: {}, + lastSyncAt: null, + }); + mockConnectionRepository.update.mockResolvedValue({ + id: 'conn_1', + metadata: {}, + variables: {}, + lastSyncAt: null, + }); + mockConnectionService.activateConnection.mockResolvedValue({ + id: 'conn_1', + }); + + const fetchSpy = jest.spyOn(global, 'fetch').mockResolvedValue({ + ok: true, + status: 200, + json: () => Promise.resolve({ access_token: 'access_123' }), + text: () => Promise.resolve(''), + } as unknown as Response); + + await controller.oauthCallback( + { code: 'auth_code', state: 'valid_gcp_state' }, + mockResponse, + ); + + await new Promise((resolve) => setImmediate(resolve)); + + expect(mockCloudSecurityService.detectServices).toHaveBeenCalledWith( + 'conn_1', + 'org_1', + ); + expect(mockedTriggerTask).toHaveBeenCalledWith( + 'run-cloud-security-scan', + { + connectionId: 'conn_1', + organizationId: 'org_1', + providerSlug: 'gcp', + connectionName: 'conn_1', + }, + ); + expect(mockResponse.redirect).toHaveBeenCalled(); + const redirectUrl = (mockResponse.redirect as jest.Mock).mock.calls[0][0]; + expect(redirectUrl).toContain('success=true'); + expect(redirectUrl).toContain('provider=gcp'); + + fetchSpy.mockRestore(); + }); + + it('should skip initial GCP service discovery scan when detection already completed', async () => { + const futureDate = new Date(Date.now() + 600000); + mockOAuthStateRepository.findByState.mockResolvedValue({ + state: 'valid_gcp_state', + providerSlug: 'gcp', + organizationId: 'org_1', + userId: 'user_1', + codeVerifier: null, + redirectUrl: null, + expiresAt: futureDate, + }); + mockedGetManifest.mockReturnValue({ + id: 'gcp', + name: 'Google Cloud Platform', + category: 'Cloud', + auth: { + type: 'oauth2', + config: { + authorizeUrl: 'https://accounts.google.com/o/oauth2/v2/auth', + tokenUrl: 'https://oauth2.googleapis.com/token', + }, + }, + capabilities: [], + isActive: true, + } as never); + mockOAuthCredentialsService.getCredentials.mockResolvedValue({ + clientId: 'client_123', + clientSecret: 'secret_456', + scopes: ['openid'], + source: 'platform', + }); + mockProviderRepository.findBySlug.mockResolvedValue({ + id: 'provider_1', + }); + mockConnectionRepository.findByProviderAndOrg.mockResolvedValue({ + id: 'conn_1', + metadata: {}, + variables: { + serviceDetectionCompletedAt: new Date().toISOString(), + detectedServices: ['compute-engine'], + }, + lastSyncAt: null, + }); + mockConnectionRepository.update.mockResolvedValue({ + id: 'conn_1', + metadata: {}, + variables: { + serviceDetectionCompletedAt: new Date().toISOString(), + detectedServices: ['compute-engine'], + }, + lastSyncAt: null, + }); + mockConnectionService.activateConnection.mockResolvedValue({ + id: 'conn_1', + }); + + const fetchSpy = jest.spyOn(global, 'fetch').mockResolvedValue({ + ok: true, + status: 200, + json: () => Promise.resolve({ access_token: 'access_123' }), + text: () => Promise.resolve(''), + } as unknown as Response); + + await controller.oauthCallback( + { code: 'auth_code', state: 'valid_gcp_state' }, + mockResponse, + ); + + expect(mockedTriggerTask).not.toHaveBeenCalledWith( + 'run-cloud-security-scan', + expect.anything(), + ); + expect(mockCloudSecurityService.detectServices).not.toHaveBeenCalled(); + expect(mockResponse.redirect).toHaveBeenCalled(); + + fetchSpy.mockRestore(); + }); }); }); diff --git a/apps/api/src/integration-platform/controllers/oauth.controller.ts b/apps/api/src/integration-platform/controllers/oauth.controller.ts index 6c58a74466..8c820e9d50 100644 --- a/apps/api/src/integration-platform/controllers/oauth.controller.ts +++ b/apps/api/src/integration-platform/controllers/oauth.controller.ts @@ -315,6 +315,23 @@ export class OAuthController { // Store tokens and mark connection as active await this.credentialVaultService.storeOAuthTokens(connection.id, tokens); + + // Mark cloud OAuth reconnect completion so reconnect banners clear after successful OAuth. + if (manifest.category === 'Cloud') { + const metadata = + connection.metadata && + typeof connection.metadata === 'object' && + !Array.isArray(connection.metadata) + ? (connection.metadata as Record) + : {}; + connection = await this.connectionRepository.update(connection.id, { + metadata: { + ...metadata, + reconnectedAt: new Date().toISOString(), + }, + }); + } + await this.connectionService.activateConnection(connection.id); // Provider-specific post-OAuth actions @@ -347,6 +364,10 @@ export class OAuthController { ); }); + // GCP: skip automatic service detection and scan after OAuth. + // The user must first select projects on the integrations page. + // Service detection and scanning run after project selection. + // Redirect to success URL const successUrl = this.buildRedirectUrl( oauthState.redirectUrl, @@ -510,4 +531,5 @@ export class OAuthController { } return url.toString(); } + } diff --git a/apps/api/src/integration-platform/controllers/services.controller.ts b/apps/api/src/integration-platform/controllers/services.controller.ts new file mode 100644 index 0000000000..c00061ecc9 --- /dev/null +++ b/apps/api/src/integration-platform/controllers/services.controller.ts @@ -0,0 +1,161 @@ +import { + Controller, + Get, + Param, + HttpException, + HttpStatus, + UseGuards, +} from '@nestjs/common'; +import { ApiTags, ApiSecurity } from '@nestjs/swagger'; +import { HybridAuthGuard } from '../../auth/hybrid-auth.guard'; +import { PermissionGuard } from '../../auth/permission.guard'; +import { RequirePermission } from '../../auth/require-permission.decorator'; +import { OrganizationId } from '../../auth/auth-context.decorator'; +import { ConnectionService } from '../services/connection.service'; +import { getManifest } from '@trycompai/integration-platform'; + +@Controller({ path: 'integrations/connections', version: '1' }) +@ApiTags('Integrations') +@UseGuards(HybridAuthGuard, PermissionGuard) +@ApiSecurity('apikey') +export class ServicesController { + constructor(private readonly connectionService: ConnectionService) {} + + /** + * Get services for a connection with their enabled state + */ + @Get(':id/services') + @RequirePermission('integration', 'read') + async getConnectionServices( + @Param('id') id: string, + @OrganizationId() organizationId: string, + ) { + const connection = await this.connectionService.getConnectionForOrg( + id, + organizationId, + ); + + const providerSlug = (connection as { provider?: { slug: string } }) + .provider?.slug; + if (!providerSlug) { + throw new HttpException( + 'Connection has no provider', + HttpStatus.BAD_REQUEST, + ); + } + + const manifest = getManifest(providerSlug); + if (!manifest?.services?.length) { + return { services: [] }; + } + + const raw = connection.variables; + const variables: Record = + raw && typeof raw === 'object' && !Array.isArray(raw) + ? (raw as Record) + : {}; + const disabledServices = new Set( + Array.isArray(variables.disabledServices) + ? (variables.disabledServices as string[]) + : [], + ); + const rawDetected = Array.isArray(variables.detectedServices) + ? (variables.detectedServices as string[]) + : []; + const detectedServices = rawDetected.length > 0 ? rawDetected : null; + // Legacy format support + const legacyEnabledServices = Array.isArray(variables.enabledServices) + ? (variables.enabledServices as string[]) + : null; + + const source = legacyEnabledServices + ? 'legacy-enabled' + : detectedServices + ? 'detected' + : 'manifest-default'; + const detectionCompletedAt = + typeof variables.serviceDetectionCompletedAt === 'string' + ? variables.serviceDetectionCompletedAt + : null; + const detectionReady = + providerSlug === 'gcp' + ? source !== 'manifest-default' || Boolean(detectionCompletedAt) + : true; + + // AWS security baseline: always scanned, hidden from Services tab + const BASELINE_SERVICES = + providerSlug === 'aws' + ? new Set([ + 'cloudtrail', + 'config', + 'guardduty', + 'iam', + 'cloudwatch', + 'kms', + ]) + : new Set(); + + // Per-project service mapping (GCP only) + const servicesByProject = + variables.servicesByProject && + typeof variables.servicesByProject === 'object' && + !Array.isArray(variables.servicesByProject) + ? (variables.servicesByProject as Record) + : {}; + // Invert: service → project IDs + const projectsByService: Record = {}; + for (const [projectId, serviceIds] of Object.entries(servicesByProject)) { + if (!Array.isArray(serviceIds)) continue; + for (const svcId of serviceIds) { + (projectsByService[svcId] ??= []).push(projectId); + } + } + + return { + services: manifest.services + .filter((s) => !BASELINE_SERVICES.has(s.id)) + .map((s) => { + // Unimplemented services are never enabled + if (s.implemented === false) { + return { + id: s.id, + name: s.name, + description: s.description, + implemented: false, + enabled: false, + projects: [] as string[], + }; + } + + let enabled: boolean; + if (legacyEnabledServices) { + enabled = + legacyEnabledServices.includes(s.id) && + !disabledServices.has(s.id); + } else if (detectedServices) { + enabled = + detectedServices.includes(s.id) && !disabledServices.has(s.id); + } else { + // Default: use enabledByDefault from manifest, otherwise enabled + enabled = + (s.enabledByDefault ?? true) && !disabledServices.has(s.id); + } + + return { + id: s.id, + name: s.name, + description: s.description, + implemented: true, + enabled, + projects: projectsByService[s.id] ?? [], + }; + }), + meta: { + providerSlug, + source, + detectionReady, + detectionCompletedAt, + }, + }; + } +} diff --git a/apps/api/src/integration-platform/controllers/sync-gws.controller.spec.ts b/apps/api/src/integration-platform/controllers/sync-gws.controller.spec.ts index 4320387f89..f1b4ce11a3 100644 --- a/apps/api/src/integration-platform/controllers/sync-gws.controller.spec.ts +++ b/apps/api/src/integration-platform/controllers/sync-gws.controller.spec.ts @@ -31,9 +31,9 @@ jest.mock('@trycompai/auth', () => ({ })); jest.mock('@trycompai/integration-platform', () => { - const actual = jest.requireActual( - '@trycompai/integration-platform', - ); + const actual = jest.requireActual< + typeof import('@trycompai/integration-platform') + >('@trycompai/integration-platform'); return { ...actual, getManifest: jest.fn().mockReturnValue({ @@ -193,7 +193,9 @@ describe('SyncController - Google Workspace employees', () => { email: 'new@example.com', }); (mockedDb.member.findFirst as jest.Mock).mockResolvedValue(null); - (mockedDb.member.create as jest.Mock).mockResolvedValue({ id: 'mem_new' }); + (mockedDb.member.create as jest.Mock).mockResolvedValue({ + id: 'mem_new', + }); (mockedDb.member.findMany as jest.Mock).mockResolvedValue([]); const result = await controller.syncGoogleWorkspaceEmployees( @@ -677,7 +679,7 @@ describe('SyncController - Google Workspace employees', () => { ], }); - let callCount = 0; + const callCount = 0; (mockedDb.user.findUnique as jest.Mock).mockImplementation( ({ where }: { where: { email: string } }) => { const map: Record = { @@ -714,7 +716,9 @@ describe('SyncController - Google Workspace employees', () => { return Promise.resolve(map[where.userId] ?? null); }, ); - (mockedDb.member.create as jest.Mock).mockResolvedValue({ id: 'mem_new' }); + (mockedDb.member.create as jest.Mock).mockResolvedValue({ + id: 'mem_new', + }); // Deactivation pass: suspended@example.com member is active in org (mockedDb.member.findMany as jest.Mock).mockResolvedValue([ @@ -782,9 +786,7 @@ describe('SyncController - Google Workspace employees', () => { expect(result.errors).toBe(1); expect(result.imported).toBe(0); - const detail = result.details.find( - (d) => d.email === 'fail@example.com', - ); + const detail = result.details.find((d) => d.email === 'fail@example.com'); expect(detail?.status).toBe('error'); expect(detail?.reason).toBe('DB write failed'); }); diff --git a/apps/api/src/integration-platform/controllers/sync-ou-filter.ts b/apps/api/src/integration-platform/controllers/sync-ou-filter.ts index c29070c956..e798f21d46 100644 --- a/apps/api/src/integration-platform/controllers/sync-ou-filter.ts +++ b/apps/api/src/integration-platform/controllers/sync-ou-filter.ts @@ -6,9 +6,10 @@ * @param targetOrgUnits - Array of OU paths to include (undefined/empty = all users) * @returns Filtered array of users */ -export function filterUsersByOrgUnits< - T extends { orgUnitPath?: string }, ->(users: T[], targetOrgUnits: string[] | undefined): T[] { +export function filterUsersByOrgUnits( + users: T[], + targetOrgUnits: string[] | undefined, +): T[] { if (!targetOrgUnits || targetOrgUnits.length === 0) { return users; } @@ -16,8 +17,7 @@ export function filterUsersByOrgUnits< return users.filter((user) => { const userOu = user.orgUnitPath ?? '/'; return targetOrgUnits.some( - (ou) => - ou === '/' || userOu === ou || userOu.startsWith(`${ou}/`), + (ou) => ou === '/' || userOu === ou || userOu.startsWith(`${ou}/`), ); }); } diff --git a/apps/api/src/integration-platform/controllers/sync.controller.ts b/apps/api/src/integration-platform/controllers/sync.controller.ts index ae43e1afb4..852599c604 100644 --- a/apps/api/src/integration-platform/controllers/sync.controller.ts +++ b/apps/api/src/integration-platform/controllers/sync.controller.ts @@ -14,10 +14,7 @@ import { ApiTags, ApiSecurity } from '@nestjs/swagger'; import { HybridAuthGuard } from '../../auth/hybrid-auth.guard'; import { PermissionGuard } from '../../auth/permission.guard'; import { RequirePermission } from '../../auth/require-permission.decorator'; -import { - OrganizationId, - AuthContext, -} from '../../auth/auth-context.decorator'; +import { OrganizationId, AuthContext } from '../../auth/auth-context.decorator'; import type { AuthContext as AuthContextType } from '../../auth/types'; import { db } from '@db'; import type { Prisma } from '@db'; @@ -313,7 +310,9 @@ export class SyncController { activeUsers.map((u) => u.primaryEmail.toLowerCase()), ); const allSuspendedEmails = new Set( - ouFilteredUsers.filter((u) => u.suspended).map((u) => u.primaryEmail.toLowerCase()), + ouFilteredUsers + .filter((u) => u.suspended) + .map((u) => u.primaryEmail.toLowerCase()), ); const allActiveEmails = new Set( ouFilteredUsers @@ -435,7 +434,11 @@ export class SyncController { const deactivationGwDomains = effectiveSyncFilterMode === 'include' - ? new Set(ouFilteredUsers.map((u) => u.primaryEmail.split('@')[1]?.toLowerCase())) + ? new Set( + ouFilteredUsers.map((u) => + u.primaryEmail.split('@')[1]?.toLowerCase(), + ), + ) : new Set( filteredUsers.map((u) => u.primaryEmail.split('@')[1]?.toLowerCase(), @@ -523,10 +526,7 @@ export class SyncController { */ @Post('google-workspace/status') @RequirePermission('integration', 'read') - async getGoogleWorkspaceStatus( - @OrganizationId() organizationId: string, - ) { - + async getGoogleWorkspaceStatus(@OrganizationId() organizationId: string) { const connection = await this.connectionRepository.findBySlugAndOrg( 'google-workspace', organizationId, @@ -925,7 +925,6 @@ export class SyncController { @Post('rippling/status') @RequirePermission('integration', 'read') async getRipplingStatus(@OrganizationId() organizationId: string) { - const connection = await this.connectionRepository.findBySlugAndOrg( 'rippling', organizationId, @@ -1464,7 +1463,6 @@ export class SyncController { @Post('jumpcloud/status') @RequirePermission('integration', 'read') async getJumpCloudStatus(@OrganizationId() organizationId: string) { - const connection = await this.connectionRepository.findBySlugAndOrg( 'jumpcloud', organizationId, @@ -1492,10 +1490,7 @@ export class SyncController { */ @Get('employee-sync-provider') @RequirePermission('integration', 'read') - async getEmployeeSyncProvider( - @OrganizationId() organizationId: string, - ) { - + async getEmployeeSyncProvider(@OrganizationId() organizationId: string) { const org = await db.organization.findUnique({ where: { id: organizationId }, select: { employeeSyncProvider: true }, @@ -1528,7 +1523,6 @@ export class SyncController { @OrganizationId() organizationId: string, @Body() body: { provider: string | null }, ) { - const { provider } = body; // Validate provider if set @@ -1583,9 +1577,7 @@ export class SyncController { */ @Get('available-providers') @RequirePermission('integration', 'read') - async getAvailableSyncProviders( - @OrganizationId() organizationId: string, - ) { + async getAvailableSyncProviders(@OrganizationId() organizationId: string) { const allManifests = registry.getActiveManifests(); const syncProviders = allManifests.filter((m) => m.capabilities?.includes('sync'), @@ -1690,7 +1682,7 @@ export class SyncController { // Try to refresh OAuth token if applicable if (manifest.auth.type === 'oauth2' && credentials.refresh_token) { - const oauthConfig = manifest.auth.config as OAuthConfig; + const oauthConfig = manifest.auth.config; try { const oauthCredentials = await this.oauthCredentialsService.getCredentials( @@ -1698,17 +1690,16 @@ export class SyncController { organizationId, ); if (oauthCredentials) { - const newToken = - await this.credentialVaultService.refreshOAuthTokens( - connectionId, - { - tokenUrl: oauthConfig.tokenUrl, - refreshUrl: oauthConfig.refreshUrl, - clientId: oauthCredentials.clientId, - clientSecret: oauthCredentials.clientSecret, - clientAuthMethod: oauthConfig.clientAuthMethod, - }, - ); + const newToken = await this.credentialVaultService.refreshOAuthTokens( + connectionId, + { + tokenUrl: oauthConfig.tokenUrl, + refreshUrl: oauthConfig.refreshUrl, + clientId: oauthCredentials.clientId, + clientSecret: oauthCredentials.clientSecret, + clientAuthMethod: oauthConfig.clientAuthMethod, + }, + ); if (newToken) { credentials = await this.credentialVaultService.getDecryptedCredentials( @@ -1740,12 +1731,11 @@ export class SyncController { manifest, accessToken: typeof accessToken === 'string' ? accessToken : undefined, credentials: (credentials ?? {}) as Record, - variables: - ((connection.variables as Record) ?? {}) as Record, + variables: ((connection.variables as Record) ?? + {}) as Record, connectionId, organizationId, - metadata: - (connection.metadata as Record) ?? {}, + metadata: (connection.metadata as Record) ?? {}, logger: { info: (msg, data) => this.logger.log(msg, data), warn: (msg, data) => this.logger.warn(msg, data), @@ -1791,9 +1781,10 @@ export class SyncController { totalChecked: result.totalFound, passedCount: result.imported + result.reactivated, failedCount: result.errors, - logs: executionLogs.length > 0 - ? (executionLogs as unknown as Prisma.InputJsonValue) - : undefined, + logs: + executionLogs.length > 0 + ? (executionLogs as unknown as Prisma.InputJsonValue) + : undefined, }); this.logger.log( @@ -1813,7 +1804,8 @@ export class SyncController { timestamp: log.timestamp.toISOString(), })); - const errorMessage = error instanceof Error ? error.message : String(error); + const errorMessage = + error instanceof Error ? error.message : String(error); const errorStack = error instanceof Error ? error.stack : undefined; const startTime = syncRun.startedAt?.getTime() || Date.now(); diff --git a/apps/api/src/integration-platform/controllers/variables.controller.spec.ts b/apps/api/src/integration-platform/controllers/variables.controller.spec.ts index 5444b79b1e..a069385603 100644 --- a/apps/api/src/integration-platform/controllers/variables.controller.spec.ts +++ b/apps/api/src/integration-platform/controllers/variables.controller.spec.ts @@ -228,9 +228,9 @@ describe('VariablesController', () => { }); mockProviderRepository.findById.mockResolvedValue(null); - await expect( - controller.getConnectionVariables('conn_1'), - ).rejects.toThrow(HttpException); + await expect(controller.getConnectionVariables('conn_1')).rejects.toThrow( + HttpException, + ); }); it('should throw NOT_FOUND when manifest does not exist', async () => { @@ -245,9 +245,9 @@ describe('VariablesController', () => { }); mockedGetManifest.mockReturnValue(undefined as never); - await expect( - controller.getConnectionVariables('conn_1'), - ).rejects.toThrow(HttpException); + await expect(controller.getConnectionVariables('conn_1')).rejects.toThrow( + HttpException, + ); }); }); @@ -380,9 +380,9 @@ describe('VariablesController', () => { variables: { key: 'val' }, }); - expect( - mockAutoCheckRunnerService.tryAutoRunChecks, - ).toHaveBeenCalledWith('conn_1'); + expect(mockAutoCheckRunnerService.tryAutoRunChecks).toHaveBeenCalledWith( + 'conn_1', + ); }); }); }); diff --git a/apps/api/src/integration-platform/controllers/variables.controller.ts b/apps/api/src/integration-platform/controllers/variables.controller.ts index bc0ea0d0ed..b218606c15 100644 --- a/apps/api/src/integration-platform/controllers/variables.controller.ts +++ b/apps/api/src/integration-platform/controllers/variables.controller.ts @@ -14,7 +14,10 @@ import { HybridAuthGuard } from '../../auth/hybrid-auth.guard'; import { PermissionGuard } from '../../auth/permission.guard'; import { RequirePermission } from '../../auth/require-permission.decorator'; import { OrganizationId } from '../../auth/auth-context.decorator'; -import { getManifest, type CheckVariable } from '@trycompai/integration-platform'; +import { + getManifest, + type CheckVariable, +} from '@trycompai/integration-platform'; import { ConnectionRepository } from '../repositories/connection.repository'; import { ConnectionService } from '../services/connection.service'; import { ProviderRepository } from '../repositories/provider.repository'; @@ -116,7 +119,10 @@ export class VariablesController { @Param('connectionId') connectionId: string, @OrganizationId() organizationId: string, ) { - await this.connectionService.getConnectionForOrg(connectionId, organizationId); + await this.connectionService.getConnectionForOrg( + connectionId, + organizationId, + ); const connection = await this.connectionRepository.findById(connectionId); if (!connection) { @@ -189,7 +195,10 @@ export class VariablesController { @Param('variableId') variableId: string, @OrganizationId() organizationId: string, ): Promise<{ options: VariableOption[] }> { - await this.connectionService.getConnectionForOrg(connectionId, organizationId); + await this.connectionService.getConnectionForOrg( + connectionId, + organizationId, + ); const connection = await this.connectionRepository.findById(connectionId); if (!connection) { @@ -399,7 +408,10 @@ export class VariablesController { @Body() body: SaveVariablesDto, @OrganizationId() organizationId: string, ) { - await this.connectionService.getConnectionForOrg(connectionId, organizationId); + await this.connectionService.getConnectionForOrg( + connectionId, + organizationId, + ); const connection = await this.connectionRepository.findById(connectionId); if (!connection) { diff --git a/apps/api/src/integration-platform/integration-platform.module.ts b/apps/api/src/integration-platform/integration-platform.module.ts index 05fb6d43ee..36e52e415d 100644 --- a/apps/api/src/integration-platform/integration-platform.module.ts +++ b/apps/api/src/integration-platform/integration-platform.module.ts @@ -1,5 +1,6 @@ -import { Module } from '@nestjs/common'; +import { Module, forwardRef } from '@nestjs/common'; import { AuthModule } from '../auth/auth.module'; +import { CloudSecurityModule } from '../cloud-security/cloud-security.module'; import { OAuthController } from './controllers/oauth.controller'; import { OAuthAppsController } from './controllers/oauth-apps.controller'; import { ConnectionsController } from './controllers/connections.controller'; @@ -10,6 +11,7 @@ import { VariablesController } from './controllers/variables.controller'; import { TaskIntegrationsController } from './controllers/task-integrations.controller'; import { WebhookController } from './controllers/webhook.controller'; import { SyncController } from './controllers/sync.controller'; +import { ServicesController } from './controllers/services.controller'; import { CredentialVaultService } from './services/credential-vault.service'; import { ConnectionService } from './services/connection.service'; import { OAuthCredentialsService } from './services/oauth-credentials.service'; @@ -31,7 +33,7 @@ import { IntegrationSyncLoggerService } from './services/integration-sync-logger import { GenericEmployeeSyncService } from './services/generic-employee-sync.service'; @Module({ - imports: [AuthModule], + imports: [AuthModule, forwardRef(() => CloudSecurityModule)], controllers: [ OAuthController, OAuthAppsController, @@ -43,6 +45,7 @@ import { GenericEmployeeSyncService } from './services/generic-employee-sync.ser TaskIntegrationsController, WebhookController, SyncController, + ServicesController, ], providers: [ // Services diff --git a/apps/api/src/integration-platform/interceptors/platform-audit-log.interceptor.ts b/apps/api/src/integration-platform/interceptors/platform-audit-log.interceptor.ts index 7aaf00653f..cc0570033f 100644 --- a/apps/api/src/integration-platform/interceptors/platform-audit-log.interceptor.ts +++ b/apps/api/src/integration-platform/interceptors/platform-audit-log.interceptor.ts @@ -35,10 +35,25 @@ export class PlatformAuditLogInterceptor implements NestInterceptor { return next.handle().pipe( tap({ next: () => { - void this.persistAuditEntry(userId, action, method, request.url, providerSlug, false); + void this.persistAuditEntry( + userId, + action, + method, + request.url, + providerSlug, + false, + ); }, error: (err: Error) => { - void this.persistAuditEntry(userId, action, method, request.url, providerSlug, true, err.message); + void this.persistAuditEntry( + userId, + action, + method, + request.url, + providerSlug, + true, + err.message, + ); }, }), ); diff --git a/apps/api/src/integration-platform/repositories/connection.repository.ts b/apps/api/src/integration-platform/repositories/connection.repository.ts index ff27f64670..260d202b43 100644 --- a/apps/api/src/integration-platform/repositories/connection.repository.ts +++ b/apps/api/src/integration-platform/repositories/connection.repository.ts @@ -1,9 +1,6 @@ import { Injectable } from '@nestjs/common'; import { db } from '@db'; -import type { - IntegrationConnection, - IntegrationConnectionStatus, -} from '@db'; +import type { IntegrationConnection, IntegrationConnectionStatus } from '@db'; export interface CreateConnectionDto { providerId: string; @@ -39,7 +36,11 @@ export class ConnectionRepository { organizationId: string, ): Promise { return db.integrationConnection.findFirst({ - where: { providerId, organizationId }, + where: { + providerId, + organizationId, + status: { not: 'disconnected' }, + }, orderBy: { createdAt: 'desc' }, include: { provider: true, diff --git a/apps/api/src/integration-platform/repositories/dynamic-check.repository.ts b/apps/api/src/integration-platform/repositories/dynamic-check.repository.ts index d15dcb32d8..a62e0df4d2 100644 --- a/apps/api/src/integration-platform/repositories/dynamic-check.repository.ts +++ b/apps/api/src/integration-platform/repositories/dynamic-check.repository.ts @@ -72,6 +72,7 @@ export class DynamicCheckRepository { variables?: Prisma.InputJsonValue; isEnabled?: boolean; sortOrder?: number; + service?: string; }): Promise { return db.dynamicCheck.upsert({ where: { @@ -91,6 +92,7 @@ export class DynamicCheckRepository { variables: data.variables ?? [], isEnabled: data.isEnabled ?? true, sortOrder: data.sortOrder ?? 0, + service: data.service, }, update: { name: data.name, @@ -101,6 +103,7 @@ export class DynamicCheckRepository { variables: data.variables ?? [], isEnabled: data.isEnabled ?? true, sortOrder: data.sortOrder ?? 0, + service: data.service, }, }); } diff --git a/apps/api/src/integration-platform/repositories/dynamic-integration.repository.ts b/apps/api/src/integration-platform/repositories/dynamic-integration.repository.ts index f3d6e753ed..3f28cc5125 100644 --- a/apps/api/src/integration-platform/repositories/dynamic-integration.repository.ts +++ b/apps/api/src/integration-platform/repositories/dynamic-integration.repository.ts @@ -56,6 +56,7 @@ export class DynamicIntegrationRepository { capabilities?: Prisma.InputJsonValue; supportsMultipleConnections?: boolean; syncDefinition?: Prisma.InputJsonValue; + services?: Prisma.InputJsonValue; }): Promise { return db.dynamicIntegration.create({ data: { @@ -71,6 +72,7 @@ export class DynamicIntegrationRepository { capabilities: data.capabilities ?? ['checks'], supportsMultipleConnections: data.supportsMultipleConnections ?? false, syncDefinition: data.syncDefinition ?? undefined, + services: data.services ?? undefined, }, }); } @@ -102,6 +104,7 @@ export class DynamicIntegrationRepository { capabilities?: Prisma.InputJsonValue; supportsMultipleConnections?: boolean; syncDefinition?: Prisma.InputJsonValue | null; + services?: Prisma.InputJsonValue; }): Promise { return db.dynamicIntegration.upsert({ where: { slug: data.slug }, @@ -118,6 +121,7 @@ export class DynamicIntegrationRepository { capabilities: data.capabilities ?? ['checks'], supportsMultipleConnections: data.supportsMultipleConnections ?? false, syncDefinition: data.syncDefinition ?? undefined, + services: data.services ?? undefined, }, update: { name: data.name, @@ -130,9 +134,11 @@ export class DynamicIntegrationRepository { authConfig: data.authConfig, capabilities: data.capabilities ?? ['checks'], supportsMultipleConnections: data.supportsMultipleConnections ?? false, - syncDefinition: data.syncDefinition === null - ? Prisma.DbNull - : (data.syncDefinition ?? undefined), + syncDefinition: + data.syncDefinition === null + ? Prisma.DbNull + : (data.syncDefinition ?? undefined), + services: data.services ?? undefined, }, }); } diff --git a/apps/api/src/integration-platform/services/connection.service.ts b/apps/api/src/integration-platform/services/connection.service.ts index 259bd89e3b..b3a8d976f2 100644 --- a/apps/api/src/integration-platform/services/connection.service.ts +++ b/apps/api/src/integration-platform/services/connection.service.ts @@ -7,10 +7,7 @@ import { getManifest } from '@trycompai/integration-platform'; import { ConnectionRepository } from '../repositories/connection.repository'; import { ProviderRepository } from '../repositories/provider.repository'; import { ConnectionAuthTeardownService } from './connection-auth-teardown.service'; -import type { - IntegrationConnection, - IntegrationConnectionStatus, -} from '@db'; +import type { IntegrationConnection, IntegrationConnectionStatus } from '@db'; export interface CreateConnectionInput { providerSlug: string; @@ -149,7 +146,13 @@ export class ConnectionService { await this.getConnection(connectionId); // Verify exists await this.connectionAuthTeardownService.teardown({ connectionId }); - await this.connectionRepository.delete(connectionId); + // Soft-delete: preserve findings, remediation history, and activity logs + // for audit trail and compliance. Only clear credentials and mark as disconnected. + await this.connectionRepository.update(connectionId, { + status: 'disconnected', + activeCredentialVersionId: null, + errorMessage: null, + }); } async updateLastSync(connectionId: string): Promise { diff --git a/apps/api/src/integration-platform/services/dynamic-manifest-loader.service.ts b/apps/api/src/integration-platform/services/dynamic-manifest-loader.service.ts index ee84c02cfa..8a827f18cd 100644 --- a/apps/api/src/integration-platform/services/dynamic-manifest-loader.service.ts +++ b/apps/api/src/integration-platform/services/dynamic-manifest-loader.service.ts @@ -1,19 +1,31 @@ -import { Injectable, Logger, OnModuleInit } from '@nestjs/common'; +import { + Injectable, + Logger, + OnModuleDestroy, + OnModuleInit, +} from '@nestjs/common'; +import { Prisma } from '@db'; import { registry, interpretDeclarativeCheck, type IntegrationManifest, + type IntegrationService, type AuthStrategy, type IntegrationCategory, type IntegrationCapability, type FindingSeverity, type CheckVariable, } from '@trycompai/integration-platform'; -import { DynamicIntegrationRepository, type DynamicIntegrationWithChecks } from '../repositories/dynamic-integration.repository'; +import { + DynamicIntegrationRepository, + type DynamicIntegrationWithChecks, +} from '../repositories/dynamic-integration.repository'; import type { DynamicCheck } from '@db'; @Injectable() -export class DynamicManifestLoaderService implements OnModuleInit { +export class DynamicManifestLoaderService + implements OnModuleInit, OnModuleDestroy +{ private readonly logger = new Logger(DynamicManifestLoaderService.name); private refreshTimer: ReturnType | null = null; @@ -24,17 +36,54 @@ export class DynamicManifestLoaderService implements OnModuleInit { async onModuleInit() { try { await this.loadDynamicManifests(); - // Background refresh every 60 seconds as safety net - this.refreshTimer = setInterval(() => { - this.loadDynamicManifests().catch((err) => { - this.logger.error('Background refresh failed', err); - }); - }, 60_000); } catch (error) { - this.logger.error('Failed to load dynamic manifests on boot', error); + this.logManifestLoadFailure(error, 'boot'); + } + + // Always schedule refresh so manifests load after Postgres comes online (common in local dev). + this.refreshTimer = setInterval(() => { + this.loadDynamicManifests().catch((err) => { + if (this.isDatabaseUnavailable(err)) { + this.logger.debug( + 'Dynamic manifests skipped: database still unreachable', + ); + return; + } + this.logger.error( + 'Background refresh of dynamic manifests failed', + err, + ); + }); + }, 60_000); + } + + onModuleDestroy() { + if (this.refreshTimer) { + clearInterval(this.refreshTimer); + this.refreshTimer = null; } } + private isDatabaseUnavailable(error: unknown): boolean { + if (error instanceof Prisma.PrismaClientInitializationError) { + return true; + } + if (error instanceof Error) { + return error.message.includes("Can't reach database server"); + } + return false; + } + + private logManifestLoadFailure(error: unknown, phase: 'boot') { + if (this.isDatabaseUnavailable(error)) { + this.logger.warn( + 'Dynamic integration manifests not loaded: database unreachable. Start Postgres (e.g. packages/db docker) or set DATABASE_URL. Manifests will load when the DB is reachable.', + ); + return; + } + this.logger.error(`Failed to load dynamic manifests on ${phase}`, error); + } + /** * Load all active dynamic integrations from DB and merge into the registry. */ @@ -55,7 +104,9 @@ export class DynamicManifestLoaderService implements OnModuleInit { } registry.refreshDynamic(manifests); - this.logger.log(`Loaded ${manifests.length} dynamic integrations into registry`); + this.logger.log( + `Loaded ${manifests.length} dynamic integrations into registry`, + ); } /** @@ -84,7 +135,10 @@ export class DynamicManifestLoaderService implements OnModuleInit { // Collect manifest-level variables from syncDefinition (if present) // These appear in the customer configuration UI (ManageIntegrationDialog) - const syncDef = integration.syncDefinition as Record | null; + const syncDef = integration.syncDefinition as Record< + string, + unknown + > | null; const syncVariables = syncDef?.variables as CheckVariable[] | undefined; return { @@ -96,10 +150,17 @@ export class DynamicManifestLoaderService implements OnModuleInit { docsUrl: integration.docsUrl ?? undefined, auth, baseUrl: integration.baseUrl ?? undefined, - defaultHeaders: (integration.defaultHeaders as Record) ?? undefined, - capabilities: (integration.capabilities as unknown as IntegrationCapability[]) ?? ['checks'], + defaultHeaders: + (integration.defaultHeaders as Record) ?? undefined, + capabilities: + (integration.capabilities as unknown as IntegrationCapability[]) ?? [ + 'checks', + ], supportsMultipleConnections: integration.supportsMultipleConnections, - variables: syncVariables && syncVariables.length > 0 ? syncVariables : undefined, + variables: + syncVariables && syncVariables.length > 0 ? syncVariables : undefined, + services: + (integration.services as IntegrationService[] | null) ?? undefined, checks, isActive: integration.isActive, }; @@ -116,10 +177,13 @@ export class DynamicManifestLoaderService implements OnModuleInit { id: check.checkSlug, name: check.name, description: check.description, - definition: definition as Parameters[0]['definition'], + definition: definition as Parameters< + typeof interpretDeclarativeCheck + >[0]['definition'], taskMapping: check.taskMapping ?? undefined, defaultSeverity: (check.defaultSeverity as FindingSeverity) ?? 'medium', variables: variables && variables.length > 0 ? variables : undefined, + service: check.service ?? undefined, }); } } diff --git a/apps/api/src/integration-platform/services/generic-employee-sync.service.ts b/apps/api/src/integration-platform/services/generic-employee-sync.service.ts index c42959d891..8b0e01b452 100644 --- a/apps/api/src/integration-platform/services/generic-employee-sync.service.ts +++ b/apps/api/src/integration-platform/services/generic-employee-sync.service.ts @@ -8,12 +8,7 @@ import type { SyncEmployee } from '@trycompai/integration-platform'; export interface SyncResultDetail { email: string; - status: - | 'imported' - | 'skipped' - | 'deactivated' - | 'reactivated' - | 'error'; + status: 'imported' | 'skipped' | 'deactivated' | 'reactivated' | 'error'; reason?: string; } @@ -259,7 +254,9 @@ export class GenericEmployeeSyncService { : `User was removed from ${providerName}`, }); } catch (error) { - this.logger.error(`Error deactivating member ${memberEmail}: ${error}`); + this.logger.error( + `Error deactivating member ${memberEmail}: ${error}`, + ); results.errors++; } } diff --git a/apps/api/src/integration-platform/services/integration-sync-logger.service.ts b/apps/api/src/integration-platform/services/integration-sync-logger.service.ts index 989a3d60e9..70eb31aa84 100644 --- a/apps/api/src/integration-platform/services/integration-sync-logger.service.ts +++ b/apps/api/src/integration-platform/services/integration-sync-logger.service.ts @@ -36,14 +36,15 @@ export class IntegrationSyncLoggerService { }, }); - this.logger.log( - `Sync log started: ${log.id} (${provider}/${eventType})`, - ); + this.logger.log(`Sync log started: ${log.id} (${provider}/${eventType})`); return log.id; } - async completeLog(id: string, result: Record): Promise { + async completeLog( + id: string, + result: Record, + ): Promise { const log = await db.integrationSyncLog.findUnique({ where: { id } }); if (!log) { this.logger.warn(`Sync log not found: ${id}`); diff --git a/apps/api/src/knowledge-base/knowledge-base.controller.spec.ts b/apps/api/src/knowledge-base/knowledge-base.controller.spec.ts index 1a44dc6700..7e0b499b1f 100644 --- a/apps/api/src/knowledge-base/knowledge-base.controller.spec.ts +++ b/apps/api/src/knowledge-base/knowledge-base.controller.spec.ts @@ -61,9 +61,7 @@ describe('KnowledgeBaseController', () => { describe('listManualAnswers', () => { it('should return manual answers from service', async () => { - const mockAnswers = [ - { id: 'ma1', question: 'Q1?', answer: 'A1' }, - ]; + const mockAnswers = [{ id: 'ma1', question: 'Q1?', answer: 'A1' }]; mockService.listManualAnswers.mockResolvedValue(mockAnswers); const result = await controller.listManualAnswers('org_1'); diff --git a/apps/api/src/knowledge-base/knowledge-base.service.spec.ts b/apps/api/src/knowledge-base/knowledge-base.service.spec.ts index 52a4c6817d..3b829abce4 100644 --- a/apps/api/src/knowledge-base/knowledge-base.service.spec.ts +++ b/apps/api/src/knowledge-base/knowledge-base.service.spec.ts @@ -39,18 +39,12 @@ jest.mock('./utils/constants', () => ({ isViewableInBrowser: jest.fn(), })); -jest.mock( - '@/trigger/vector-store/process-knowledge-base-document', - () => ({}), -); +jest.mock('@/trigger/vector-store/process-knowledge-base-document', () => ({})); jest.mock( '@/trigger/vector-store/process-knowledge-base-documents-orchestrator', () => ({}), ); -jest.mock( - '@/trigger/vector-store/delete-knowledge-base-document', - () => ({}), -); +jest.mock('@/trigger/vector-store/delete-knowledge-base-document', () => ({})); jest.mock('@/trigger/vector-store/delete-manual-answer', () => ({})); jest.mock( '@/trigger/vector-store/delete-all-manual-answers-orchestrator', diff --git a/apps/api/src/knowledge-base/utils/s3-operations.ts b/apps/api/src/knowledge-base/utils/s3-operations.ts index 8acfe99c9e..cee6a7af3b 100644 --- a/apps/api/src/knowledge-base/utils/s3-operations.ts +++ b/apps/api/src/knowledge-base/utils/s3-operations.ts @@ -3,9 +3,12 @@ import { DeleteObjectCommand, GetObjectCommand, } from '@aws-sdk/client-s3'; -import { getSignedUrl } from '@aws-sdk/s3-request-presigner'; import { randomBytes } from 'crypto'; -import { s3Client, APP_AWS_KNOWLEDGE_BASE_BUCKET } from '@/app/s3'; +import { + s3Client, + APP_AWS_KNOWLEDGE_BASE_BUCKET, + getSignedUrl, +} from '@/app/s3'; import { MAX_FILE_SIZE_BYTES, SIGNED_URL_EXPIRATION_SECONDS, diff --git a/apps/api/src/main.ts b/apps/api/src/main.ts index 6736642596..0aebc23a2c 100644 --- a/apps/api/src/main.ts +++ b/apps/api/src/main.ts @@ -72,16 +72,25 @@ async function bootstrap(): Promise { // Express-level middleware runs BEFORE NestJS module middleware, so without this // skip, express.json() would consume the stream before better-auth's handler. const jsonParser = express.json({ limit: '150mb' }); - const urlencodedParser = express.urlencoded({ limit: '150mb', extended: true }); - app.use((req: express.Request, res: express.Response, next: express.NextFunction) => { - if (req.path.startsWith('/api/auth')) { - return next(); - } - jsonParser(req, res, (err?: unknown) => { - if (err) return next(err); - urlencodedParser(req, res, next); - }); + const urlencodedParser = express.urlencoded({ + limit: '150mb', + extended: true, }); + app.use( + ( + req: express.Request, + res: express.Response, + next: express.NextFunction, + ) => { + if (req.path.startsWith('/api/auth')) { + return next(); + } + jsonParser(req, res, (err?: unknown) => { + if (err) return next(err); + urlencodedParser(req, res, next); + }); + }, + ); // STEP 4b: Enable global pipes and filters app.useGlobalPipes( diff --git a/apps/api/src/notifications/check-unsubscribe.spec.ts b/apps/api/src/notifications/check-unsubscribe.spec.ts index fd9c05482b..f3238fe5a0 100644 --- a/apps/api/src/notifications/check-unsubscribe.spec.ts +++ b/apps/api/src/notifications/check-unsubscribe.spec.ts @@ -166,7 +166,9 @@ describe('isUserUnsubscribed', () => { }); expect(await isUserUnsubscribed(db, email, 'taskReminders')).toBe(true); - expect(await isUserUnsubscribed(db, email, 'policyNotifications')).toBe(false); + expect(await isUserUnsubscribed(db, email, 'policyNotifications')).toBe( + false, + ); }); }); @@ -180,7 +182,12 @@ describe('isUserUnsubscribed', () => { roleSettings: [{ ...ALL_OFF }], }); - const result = await isUserUnsubscribed(db, email, 'taskReminders', orgId); + const result = await isUserUnsubscribed( + db, + email, + 'taskReminders', + orgId, + ); expect(result).toBe(true); }); @@ -194,7 +201,12 @@ describe('isUserUnsubscribed', () => { ], }); - const result = await isUserUnsubscribed(db, email, 'taskReminders', orgId); + const result = await isUserUnsubscribed( + db, + email, + 'taskReminders', + orgId, + ); expect(result).toBe(false); }); @@ -209,7 +221,12 @@ describe('isUserUnsubscribed', () => { roleSettings: [{ ...ALL_ON }], // role says ON }); - const result = await isUserUnsubscribed(db, email, 'taskReminders', orgId); + const result = await isUserUnsubscribed( + db, + email, + 'taskReminders', + orgId, + ); expect(result).toBe(true); // existing opt-out is preserved }); @@ -224,7 +241,12 @@ describe('isUserUnsubscribed', () => { roleSettings: [{ ...ALL_ON }], // role says ON }); - const result = await isUserUnsubscribed(db, email, 'taskReminders', orgId); + const result = await isUserUnsubscribed( + db, + email, + 'taskReminders', + orgId, + ); expect(result).toBe(false); // defaults to enabled }); @@ -239,7 +261,12 @@ describe('isUserUnsubscribed', () => { roleSettings: [{ ...ALL_ON }], // role says ON }); - const result = await isUserUnsubscribed(db, email, 'taskReminders', orgId); + const result = await isUserUnsubscribed( + db, + email, + 'taskReminders', + orgId, + ); expect(result).toBe(true); // admin can opt out }); @@ -254,7 +281,12 @@ describe('isUserUnsubscribed', () => { roleSettings: [{ ...ALL_ON }], }); - const result = await isUserUnsubscribed(db, email, 'weeklyTaskDigest', orgId); + const result = await isUserUnsubscribed( + db, + email, + 'weeklyTaskDigest', + orgId, + ); expect(result).toBe(true); }); @@ -269,7 +301,12 @@ describe('isUserUnsubscribed', () => { roleSettings: [], // no role settings configured }); - const result = await isUserUnsubscribed(db, email, 'taskReminders', orgId); + const result = await isUserUnsubscribed( + db, + email, + 'taskReminders', + orgId, + ); expect(result).toBe(true); // falls through to personal pref }); @@ -284,7 +321,12 @@ describe('isUserUnsubscribed', () => { roleSettings: [], }); - const result = await isUserUnsubscribed(db, email, 'taskReminders', orgId); + const result = await isUserUnsubscribed( + db, + email, + 'taskReminders', + orgId, + ); expect(result).toBe(true); // falls through to personal pref }); @@ -302,7 +344,12 @@ describe('isUserUnsubscribed', () => { ], }); - const result = await isUserUnsubscribed(db, email, 'taskReminders', orgId); + const result = await isUserUnsubscribed( + db, + email, + 'taskReminders', + orgId, + ); expect(result).toBe(false); // employee role enables it, no personal opt-out }); @@ -317,7 +364,12 @@ describe('isUserUnsubscribed', () => { roleSettings: [{ ...ALL_ON }, { ...ALL_ON }], }); - const result = await isUserUnsubscribed(db, email, 'taskReminders', orgId); + const result = await isUserUnsubscribed( + db, + email, + 'taskReminders', + orgId, + ); expect(result).toBe(true); // admin portion allows opt-out }); @@ -360,14 +412,26 @@ describe('isUserUnsubscribed', () => { }); // ON by role - expect(await isUserUnsubscribed(db, email, 'policyNotifications', orgId)).toBe(false); - expect(await isUserUnsubscribed(db, email, 'taskAssignments', orgId)).toBe(false); - expect(await isUserUnsubscribed(db, email, 'weeklyTaskDigest', orgId)).toBe(false); + expect( + await isUserUnsubscribed(db, email, 'policyNotifications', orgId), + ).toBe(false); + expect( + await isUserUnsubscribed(db, email, 'taskAssignments', orgId), + ).toBe(false); + expect( + await isUserUnsubscribed(db, email, 'weeklyTaskDigest', orgId), + ).toBe(false); // OFF by role - expect(await isUserUnsubscribed(db, email, 'taskReminders', orgId)).toBe(true); - expect(await isUserUnsubscribed(db, email, 'taskMentions', orgId)).toBe(true); - expect(await isUserUnsubscribed(db, email, 'findingNotifications', orgId)).toBe(true); + expect(await isUserUnsubscribed(db, email, 'taskReminders', orgId)).toBe( + true, + ); + expect(await isUserUnsubscribed(db, email, 'taskMentions', orgId)).toBe( + true, + ); + expect( + await isUserUnsubscribed(db, email, 'findingNotifications', orgId), + ).toBe(true); }); }); @@ -395,7 +459,12 @@ describe('isUserUnsubscribed', () => { }, }; - const result = await isUserUnsubscribed(db, email, 'taskReminders', orgId); + const result = await isUserUnsubscribed( + db, + email, + 'taskReminders', + orgId, + ); // Should fall through to personal preferences since db.member is missing expect(result).toBe(true); @@ -405,7 +474,12 @@ describe('isUserUnsubscribed', () => { const db = createMockDb(); db.member.findMany.mockRejectedValue(new Error('Query failed')); - const result = await isUserUnsubscribed(db, email, 'taskReminders', orgId); + const result = await isUserUnsubscribed( + db, + email, + 'taskReminders', + orgId, + ); expect(result).toBe(false); }); diff --git a/apps/api/src/org-chart/org-chart.service.ts b/apps/api/src/org-chart/org-chart.service.ts index 8e36d5d299..a68acb3f63 100644 --- a/apps/api/src/org-chart/org-chart.service.ts +++ b/apps/api/src/org-chart/org-chart.service.ts @@ -10,9 +10,8 @@ import { PutObjectCommand, S3Client, } from '@aws-sdk/client-s3'; -import { getSignedUrl } from '@aws-sdk/s3-request-presigner'; import { db } from '@db'; -import { s3Client, BUCKET_NAME } from '@/app/s3'; +import { s3Client, BUCKET_NAME, getSignedUrl } from '@/app/s3'; import type { UpsertOrgChartDto } from './dto/upsert-org-chart.dto'; import type { UploadOrgChartDto } from './dto/upload-org-chart.dto'; diff --git a/apps/api/src/organization/organization.controller.ts b/apps/api/src/organization/organization.controller.ts index 3024ddea42..9c9730bd65 100644 --- a/apps/api/src/organization/organization.controller.ts +++ b/apps/api/src/organization/organization.controller.ts @@ -18,10 +18,7 @@ import { ApiSecurity, ApiTags, } from '@nestjs/swagger'; -import { - AuthContext, - OrganizationId, -} from '../auth/auth-context.decorator'; +import { AuthContext, OrganizationId } from '../auth/auth-context.decorator'; import { HybridAuthGuard } from '../auth/hybrid-auth.guard'; import { PermissionGuard } from '../auth/permission.guard'; import { RequirePermission } from '../auth/require-permission.decorator'; @@ -54,7 +51,11 @@ export class OrganizationController { @Get() @RequirePermission('organization', 'read') @ApiOperation(ORGANIZATION_OPERATIONS.getOrganization) - @ApiQuery({ name: 'includeOwnership', required: false, description: 'Include ownership data for transfer UI' }) + @ApiQuery({ + name: 'includeOwnership', + required: false, + description: 'Include ownership data for transfer UI', + }) @ApiResponse(GET_ORGANIZATION_RESPONSES[200]) @ApiResponse(GET_ORGANIZATION_RESPONSES[401]) async getOrganization( diff --git a/apps/api/src/organization/organization.service.ts b/apps/api/src/organization/organization.service.ts index f053e1ef75..e54546c06e 100644 --- a/apps/api/src/organization/organization.service.ts +++ b/apps/api/src/organization/organization.service.ts @@ -8,9 +8,8 @@ import { } from '@nestjs/common'; import { allRoles } from '@trycompai/auth'; import { GetObjectCommand, PutObjectCommand } from '@aws-sdk/client-s3'; -import { getSignedUrl } from '@aws-sdk/s3-request-presigner'; import { db, Role } from '@db'; -import { APP_AWS_ORG_ASSETS_BUCKET, s3Client } from '../app/s3'; +import { APP_AWS_ORG_ASSETS_BUCKET, s3Client, getSignedUrl } from '../app/s3'; import type { UpdateOrganizationDto } from './dto/update-organization.dto'; import type { TransferOwnershipResponseDto } from './dto/transfer-ownership.dto'; @@ -319,10 +318,7 @@ export class OrganizationService { async getRoleNotificationSettings(organizationId: string) { const BUILT_IN_ROLES = Object.keys(allRoles); - const BUILT_IN_DEFAULTS: Record< - string, - Record - > = { + const BUILT_IN_DEFAULTS: Record> = { owner: { policyNotifications: true, taskReminders: true, @@ -432,7 +428,9 @@ export class OrganizationService { return { data: configs }; } - async getLogoSignedUrl(logoKey: string | null | undefined): Promise { + async getLogoSignedUrl( + logoKey: string | null | undefined, + ): Promise { if (!logoKey || !s3Client || !APP_AWS_ORG_ASSETS_BUCKET) { return null; } diff --git a/apps/api/src/people/dto/people-responses.dto.ts b/apps/api/src/people/dto/people-responses.dto.ts index 1b3d206935..65668b52a4 100644 --- a/apps/api/src/people/dto/people-responses.dto.ts +++ b/apps/api/src/people/dto/people-responses.dto.ts @@ -53,7 +53,8 @@ export class UserResponseDto { lastLogin: Date | null; @ApiProperty({ - description: 'Platform role of the user (managed by Better Auth admin plugin)', + description: + 'Platform role of the user (managed by Better Auth admin plugin)', example: 'user', nullable: true, }) diff --git a/apps/api/src/people/people-fleet.helper.ts b/apps/api/src/people/people-fleet.helper.ts index 9e3e8d1bbd..784d26cc89 100644 --- a/apps/api/src/people/people-fleet.helper.ts +++ b/apps/api/src/people/people-fleet.helper.ts @@ -17,7 +17,11 @@ export interface FleetPolicyResult { function buildPoliciesWithResults( host: Record, - results: { fleetPolicyId: number; fleetPolicyResponse: string | null; attachments: unknown }[], + results: { + fleetPolicyId: number; + fleetPolicyResponse: string | null; + attachments: unknown; + }[], ) { const platform = (host.platform as string)?.toLowerCase(); const osVersion = (host.os_version as string)?.toLowerCase(); @@ -27,13 +31,23 @@ function buildPoliciesWithResults( platform === 'osx' || osVersion?.includes('mac'); - const hostPolicies = (host.policies || []) as { id: number; name: string; response: string }[]; + const hostPolicies = (host.policies || []) as { + id: number; + name: string; + response: string; + }[]; const mdm = host.mdm as { connected_to_fleet?: boolean } | undefined; const allPolicies = [ ...hostPolicies, ...(isMacOS && mdm - ? [{ id: MDM_POLICY_ID, name: 'MDM Enabled', response: mdm.connected_to_fleet ? 'pass' : 'fail' }] + ? [ + { + id: MDM_POLICY_ID, + name: 'MDM Enabled', + response: mdm.connected_to_fleet ? 'pass' : 'fail', + }, + ] : []), ]; @@ -42,7 +56,8 @@ function buildPoliciesWithResults( return { ...policy, response: - policy.response === 'pass' || policyResult?.fleetPolicyResponse === 'pass' + policy.response === 'pass' || + policyResult?.fleetPolicyResponse === 'pass' ? 'pass' : 'fail', attachments: policyResult?.attachments || [], @@ -62,7 +77,8 @@ export async function getFleetComplianceForMember( } try { - const labelHostsData = await fleetService.getHostsByLabel(memberFleetLabelId); + const labelHostsData = + await fleetService.getHostsByLabel(memberFleetLabelId); const firstHost = labelHostsData?.hosts?.[0]; if (!firstHost) { @@ -117,7 +133,9 @@ export async function getAllEmployeeDevices( const labelResponses = await Promise.all( membersWithLabels.map(async (employee) => { try { - const data = await fleetService.getHostsByLabel(employee.fleetDmLabelId!); + const data = await fleetService.getHostsByLabel( + employee.fleetDmLabelId!, + ); return { userId: employee.userId, userName: employee.user?.name, @@ -125,7 +143,12 @@ export async function getAllEmployeeDevices( hosts: data?.hosts || [], }; } catch { - return { userId: employee.userId, userName: employee.user?.name, memberId: employee.id, hosts: [] }; + return { + userId: employee.userId, + userName: employee.user?.name, + memberId: employee.id, + hosts: [], + }; } }), ); diff --git a/apps/api/src/people/people-invite.service.spec.ts b/apps/api/src/people/people-invite.service.spec.ts index 8e9c5b3232..f3c9e31fd3 100644 --- a/apps/api/src/people/people-invite.service.spec.ts +++ b/apps/api/src/people/people-invite.service.spec.ts @@ -123,7 +123,9 @@ describe('PeopleInviteService', () => { (mockDb.member.create as jest.Mock).mockResolvedValue({ id: 'member_new', }); - (mockDb.employeeTrainingVideoCompletion.createMany as jest.Mock).mockResolvedValue({ + ( + mockDb.employeeTrainingVideoCompletion.createMany as jest.Mock + ).mockResolvedValue({ count: 5, }); @@ -189,7 +191,9 @@ describe('PeopleInviteService', () => { (mockDb.member.create as jest.Mock).mockResolvedValue({ id: 'member_new', }); - (mockDb.employeeTrainingVideoCompletion.createMany as jest.Mock).mockResolvedValue({ + ( + mockDb.employeeTrainingVideoCompletion.createMany as jest.Mock + ).mockResolvedValue({ count: 5, }); @@ -218,7 +222,9 @@ describe('PeopleInviteService', () => { (mockDb.member.create as jest.Mock).mockResolvedValue({ id: 'member_new', }); - (mockDb.employeeTrainingVideoCompletion.createMany as jest.Mock).mockResolvedValue({ + ( + mockDb.employeeTrainingVideoCompletion.createMany as jest.Mock + ).mockResolvedValue({ count: 5, }); @@ -248,7 +254,9 @@ describe('PeopleInviteService', () => { (mockDb.member.create as jest.Mock).mockResolvedValue({ id: 'member_new', }); - (mockDb.employeeTrainingVideoCompletion.createMany as jest.Mock).mockResolvedValue({ + ( + mockDb.employeeTrainingVideoCompletion.createMany as jest.Mock + ).mockResolvedValue({ count: 5, }); mockTriggerEmail.mockRejectedValueOnce(new Error('Email service down')); diff --git a/apps/api/src/people/people.controller.ts b/apps/api/src/people/people.controller.ts index 53c4394f13..899ada7376 100644 --- a/apps/api/src/people/people.controller.ts +++ b/apps/api/src/people/people.controller.ts @@ -121,7 +121,9 @@ export class PeopleController { @Get('devices') @RequirePermission('member', 'read') - @ApiOperation({ summary: 'Get all employee devices with fleet compliance data' }) + @ApiOperation({ + summary: 'Get all employee devices with fleet compliance data', + }) async getDevices( @OrganizationId() organizationId: string, @AuthContext() authContext: AuthContextType, @@ -143,12 +145,15 @@ export class PeopleController { @Get('test-stats/by-assignee') @RequirePermission('member', 'read') - @ApiOperation({ summary: 'Get integration test statistics grouped by assignee' }) + @ApiOperation({ + summary: 'Get integration test statistics grouped by assignee', + }) async getTestStatsByAssignee( @OrganizationId() organizationId: string, @AuthContext() authContext: AuthContextType, ) { - const data = await this.peopleService.getTestStatsByAssignee(organizationId); + const data = + await this.peopleService.getTestStatsByAssignee(organizationId); return { data, @@ -226,7 +231,9 @@ export class PeopleController { @Get('mentionable') @RequirePermission('member', 'read') - @ApiOperation({ summary: 'Get members who can read a specific resource type' }) + @ApiOperation({ + summary: 'Get members who can read a specific resource type', + }) async getMentionableMembers( @OrganizationId() organizationId: string, @AuthContext() authContext: AuthContextType, @@ -527,7 +534,9 @@ export class PeopleController { } @Put('me/email-preferences') - @ApiOperation({ summary: 'Update current user email notification preferences' }) + @ApiOperation({ + summary: 'Update current user email notification preferences', + }) async updateEmailPreferences( @AuthContext() authContext: AuthContextType, @Body() body: UpdateEmailPreferencesDto, @@ -538,9 +547,8 @@ export class PeopleController { ); } - return this.peopleService.updateEmailPreferences( - authContext.userId, - { ...body.preferences }, - ); + return this.peopleService.updateEmailPreferences(authContext.userId, { + ...body.preferences, + }); } } diff --git a/apps/api/src/people/people.service.spec.ts b/apps/api/src/people/people.service.spec.ts index a2f5c4e810..b6ce94971f 100644 --- a/apps/api/src/people/people.service.spec.ts +++ b/apps/api/src/people/people.service.spec.ts @@ -51,8 +51,14 @@ jest.mock('@db', () => ({ jest.mock('@trycompai/auth', () => ({ BUILT_IN_ROLE_PERMISSIONS: { - owner: { organization: ['read', 'update', 'delete'], member: ['create', 'read', 'update', 'delete'] }, - admin: { organization: ['read', 'update'], member: ['create', 'read', 'update', 'delete'] }, + owner: { + organization: ['read', 'update', 'delete'], + member: ['create', 'read', 'update', 'delete'], + }, + admin: { + organization: ['read', 'update'], + member: ['create', 'read', 'update', 'delete'], + }, auditor: { organization: ['read'], member: ['read'] }, employee: { compliance: ['required'] }, contractor: { compliance: ['required'] }, @@ -220,7 +226,11 @@ describe('PeopleService', () => { userId: 'usr_1', role: 'employee', }; - const updatedMember = { id: 'mem_1', user: { name: 'Alice' }, role: 'admin' }; + const updatedMember = { + id: 'mem_1', + user: { name: 'Alice' }, + role: 'admin', + }; (MemberValidator.validateOrganization as jest.Mock).mockResolvedValue( undefined, @@ -253,7 +263,11 @@ describe('PeopleService', () => { userId: 'usr_old', role: 'employee', }; - const updatedMember = { id: 'mem_1', user: { name: 'New' }, role: 'employee' }; + const updatedMember = { + id: 'mem_1', + user: { name: 'New' }, + role: 'employee', + }; (MemberValidator.validateOrganization as jest.Mock).mockResolvedValue( undefined, @@ -476,9 +490,9 @@ describe('PeopleService', () => { null, ); - await expect( - service.unlinkDevice('mem_none', 'org_123'), - ).rejects.toThrow(NotFoundException); + await expect(service.unlinkDevice('mem_none', 'org_123')).rejects.toThrow( + NotFoundException, + ); }); }); diff --git a/apps/api/src/people/people.service.ts b/apps/api/src/people/people.service.ts index b6692373a4..9f36b2cba9 100644 --- a/apps/api/src/people/people.service.ts +++ b/apps/api/src/people/people.service.ts @@ -66,7 +66,10 @@ export class PeopleService { // Collect all unique role names across members const allRoleNames = new Set(); for (const member of members) { - const roles = member.role.split(',').map((r) => r.trim()).filter(Boolean); + const roles = member.role + .split(',') + .map((r) => r.trim()) + .filter(Boolean); for (const role of roles) { allRoleNames.add(role); } @@ -92,16 +95,20 @@ export class PeopleService { where: { organizationId, name: { in: customRoleNames } }, }); for (const role of customRoles) { - const perms = typeof role.permissions === 'string' - ? JSON.parse(role.permissions) as Record - : role.permissions as Record; + const perms = + typeof role.permissions === 'string' + ? (JSON.parse(role.permissions) as Record) + : (role.permissions as Record); permissionMap.set(role.name, perms); } } // Filter members whose combined permissions include the required permission return members.filter((member) => { - const roles = member.role.split(',').map((r) => r.trim()).filter(Boolean); + const roles = member.role + .split(',') + .map((r) => r.trim()) + .filter(Boolean); for (const role of roles) { const perms = permissionMap.get(role); if (perms && perms[resource]?.includes('read')) { @@ -476,7 +483,10 @@ export class PeopleService { ); } - const updatedMember = await MemberQueries.unlinkDevice(memberId, organizationId); + const updatedMember = await MemberQueries.unlinkDevice( + memberId, + organizationId, + ); this.logger.log( `Unlinked device for member: ${updatedMember.user.name} (${memberId})`, diff --git a/apps/api/src/people/utils/member-deactivation.ts b/apps/api/src/people/utils/member-deactivation.ts index cb0c6cf862..c373a2baac 100644 --- a/apps/api/src/people/utils/member-deactivation.ts +++ b/apps/api/src/people/utils/member-deactivation.ts @@ -12,7 +12,6 @@ export interface UnassignedItem { const logger = new Logger('MemberDeactivation'); - /** * Collect all items assigned to a member (tasks, policies, risks, vendors). */ @@ -44,9 +43,11 @@ export async function collectAssignedItems({ const items: UnassignedItem[] = []; for (const t of tasks) items.push({ type: 'task', id: t.id, name: t.title }); - for (const p of policies) items.push({ type: 'policy', id: p.id, name: p.name }); + for (const p of policies) + items.push({ type: 'policy', id: p.id, name: p.name }); for (const r of risks) items.push({ type: 'risk', id: r.id, name: r.title }); - for (const v of vendors) items.push({ type: 'vendor', id: v.id, name: v.name }); + for (const v of vendors) + items.push({ type: 'vendor', id: v.id, name: v.name }); return items; } diff --git a/apps/api/src/people/utils/member-queries.ts b/apps/api/src/people/utils/member-queries.ts index a6683a3974..3e789f0b2b 100644 --- a/apps/api/src/people/utils/member-queries.ts +++ b/apps/api/src/people/utils/member-queries.ts @@ -118,8 +118,7 @@ export class MemberQueries { updatePayload.fleetDmLabelId = null; } - const hasUserUpdates = - name !== undefined || email !== undefined; + const hasUserUpdates = name !== undefined || email !== undefined; const hasMemberUpdates = Object.keys(updatePayload).length > 0; // If we need to update both user and member, use a transaction diff --git a/apps/api/src/policies/policies.controller.spec.ts b/apps/api/src/policies/policies.controller.spec.ts index b58c0f97ff..767fbcba3d 100644 --- a/apps/api/src/policies/policies.controller.spec.ts +++ b/apps/api/src/policies/policies.controller.spec.ts @@ -110,9 +110,7 @@ describe('PoliciesController', () => { beforeEach(async () => { const module: TestingModule = await Test.createTestingModule({ controllers: [PoliciesController], - providers: [ - { provide: PoliciesService, useValue: mockPoliciesService }, - ], + providers: [{ provide: PoliciesService, useValue: mockPoliciesService }], }) .overrideGuard(HybridAuthGuard) .useValue(mockGuard) @@ -208,7 +206,11 @@ describe('PoliciesController', () => { const mockPolicy = { id: 'pol_1', name: 'Test Policy' }; mockPoliciesService.findById.mockResolvedValue(mockPolicy); - const result = await controller.getPolicy('pol_1', orgId, mockAuthContext); + const result = await controller.getPolicy( + 'pol_1', + orgId, + mockAuthContext, + ); expect(policiesService.findById).toHaveBeenCalledWith('pol_1', orgId); expect(result).toEqual({ diff --git a/apps/api/src/policies/policies.controller.ts b/apps/api/src/policies/policies.controller.ts index 6e8fbf66fc..7cde6fa40c 100644 --- a/apps/api/src/policies/policies.controller.ts +++ b/apps/api/src/policies/policies.controller.ts @@ -244,7 +244,9 @@ export class PoliciesController { where: { organizationId }, orderBy: { createdAt: 'asc' }, }); - const contextHub = contextEntries.map((c) => `${c.question}\n${c.answer}`).join('\n'); + const contextHub = contextEntries + .map((c) => `${c.question}\n${c.answer}`) + .join('\n'); const handle = await tasks.trigger('update-policy', { organizationId, @@ -316,7 +318,7 @@ export class PoliciesController { // Generate signed URL const { S3Client, GetObjectCommand } = await import('@aws-sdk/client-s3'); - const { getSignedUrl } = await import('@aws-sdk/s3-request-presigner'); + const { getSignedUrl } = await import('../app/s3.js'); const bucketName = process.env.APP_AWS_BUCKET_NAME; if (!bucketName) { @@ -345,19 +347,33 @@ export class PoliciesController { @ApiParam(POLICY_PARAMS.policyId) async uploadPolicyPdf( @Param('id') id: string, - @Body() body: { versionId?: string; fileName: string; fileType: string; fileData: string }, + @Body() + body: { + versionId?: string; + fileName: string; + fileType: string; + fileData: string; + }, @OrganizationId() organizationId: string, @AuthContext() authContext: AuthContextType, ) { - const { S3Client, PutObjectCommand, DeleteObjectCommand } = await import('@aws-sdk/client-s3'); + const { S3Client, PutObjectCommand, DeleteObjectCommand } = + await import('@aws-sdk/client-s3'); const bucketName = process.env.APP_AWS_BUCKET_NAME; - if (!bucketName) throw new BadRequestException('File storage is not configured'); + if (!bucketName) + throw new BadRequestException('File storage is not configured'); const s3 = new S3Client({ region: process.env.AWS_REGION || 'us-east-1' }); const policy = await db.policy.findFirst({ where: { id, organizationId }, - select: { id: true, status: true, pdfUrl: true, currentVersionId: true, pendingVersionId: true }, + select: { + id: true, + status: true, + pdfUrl: true, + currentVersionId: true, + pendingVersionId: true, + }, }); if (!policy) throw new NotFoundException('Policy not found'); @@ -371,19 +387,39 @@ export class PoliciesController { }); if (!version) throw new NotFoundException('Version not found'); if (version.id === policy.currentVersionId && policy.status !== 'draft') { - throw new BadRequestException('Cannot upload PDF to the published version'); + throw new BadRequestException( + 'Cannot upload PDF to the published version', + ); } if (version.id === policy.pendingVersionId) { - throw new BadRequestException('Cannot upload PDF to a version pending approval'); + throw new BadRequestException( + 'Cannot upload PDF to a version pending approval', + ); } const s3Key = `${organizationId}/policies/${id}/v${version.version}-${Date.now()}-${sanitizedFileName}`; - await s3.send(new PutObjectCommand({ Bucket: bucketName, Key: s3Key, Body: fileBuffer, ContentType: body.fileType })); + await s3.send( + new PutObjectCommand({ + Bucket: bucketName, + Key: s3Key, + Body: fileBuffer, + ContentType: body.fileType, + }), + ); const oldPdfUrl = version.pdfUrl; - await db.policyVersion.update({ where: { id: body.versionId }, data: { pdfUrl: s3Key } }); + await db.policyVersion.update({ + where: { id: body.versionId }, + data: { pdfUrl: s3Key }, + }); if (oldPdfUrl && oldPdfUrl !== s3Key) { - try { await s3.send(new DeleteObjectCommand({ Bucket: bucketName, Key: oldPdfUrl })); } catch { /* ignore */ } + try { + await s3.send( + new DeleteObjectCommand({ Bucket: bucketName, Key: oldPdfUrl }), + ); + } catch { + /* ignore */ + } } return { data: { s3Key }, authType: authContext.authType }; @@ -391,12 +427,28 @@ export class PoliciesController { // Legacy: upload to policy level const s3Key = `${organizationId}/policies/${id}/${Date.now()}-${sanitizedFileName}`; - await s3.send(new PutObjectCommand({ Bucket: bucketName, Key: s3Key, Body: fileBuffer, ContentType: body.fileType })); + await s3.send( + new PutObjectCommand({ + Bucket: bucketName, + Key: s3Key, + Body: fileBuffer, + ContentType: body.fileType, + }), + ); const oldPdfUrl = policy.pdfUrl; - await db.policy.update({ where: { id }, data: { pdfUrl: s3Key, displayFormat: 'PDF' } }); + await db.policy.update({ + where: { id }, + data: { pdfUrl: s3Key, displayFormat: 'PDF' }, + }); if (oldPdfUrl && oldPdfUrl !== s3Key) { - try { await s3.send(new DeleteObjectCommand({ Bucket: bucketName, Key: oldPdfUrl })); } catch { /* ignore */ } + try { + await s3.send( + new DeleteObjectCommand({ Bucket: bucketName, Key: oldPdfUrl }), + ); + } catch { + /* ignore */ + } } return { data: { s3Key }, authType: authContext.authType }; @@ -413,9 +465,11 @@ export class PoliciesController { @AuthContext() authContext: AuthContextType, @Query('versionId') versionId?: string, ) { - const { S3Client, DeleteObjectCommand } = await import('@aws-sdk/client-s3'); + const { S3Client, DeleteObjectCommand } = + await import('@aws-sdk/client-s3'); const bucketName = process.env.APP_AWS_BUCKET_NAME; - if (!bucketName) throw new BadRequestException('File storage is not configured'); + if (!bucketName) + throw new BadRequestException('File storage is not configured'); const s3 = new S3Client({ region: process.env.AWS_REGION || 'us-east-1' }); @@ -426,8 +480,20 @@ export class PoliciesController { }); if (!version) throw new NotFoundException('Version not found'); if (version.pdfUrl) { - try { await s3.send(new DeleteObjectCommand({ Bucket: bucketName, Key: version.pdfUrl })); } catch { /* ignore */ } - await db.policyVersion.update({ where: { id: versionId }, data: { pdfUrl: null } }); + try { + await s3.send( + new DeleteObjectCommand({ + Bucket: bucketName, + Key: version.pdfUrl, + }), + ); + } catch { + /* ignore */ + } + await db.policyVersion.update({ + where: { id: versionId }, + data: { pdfUrl: null }, + }); } } else { const policy = await db.policy.findFirst({ @@ -436,7 +502,13 @@ export class PoliciesController { }); if (!policy) throw new NotFoundException('Policy not found'); if (policy.pdfUrl) { - try { await s3.send(new DeleteObjectCommand({ Bucket: bucketName, Key: policy.pdfUrl })); } catch { /* ignore */ } + try { + await s3.send( + new DeleteObjectCommand({ Bucket: bucketName, Key: policy.pdfUrl }), + ); + } catch { + /* ignore */ + } await db.policy.update({ where: { id }, data: { pdfUrl: null } }); } } @@ -445,7 +517,10 @@ export class PoliciesController { success: true, authType: authContext.authType, ...(authContext.userId && { - authenticatedUser: { id: authContext.userId, email: authContext.userEmail }, + authenticatedUser: { + id: authContext.userId, + email: authContext.userEmail, + }, }), }; } @@ -480,12 +555,16 @@ export class PoliciesController { if (!pdfUrl) return { url: null }; const { S3Client, GetObjectCommand } = await import('@aws-sdk/client-s3'); - const { getSignedUrl } = await import('@aws-sdk/s3-request-presigner'); + const { getSignedUrl } = await import('../app/s3.js'); const bucketName = process.env.APP_AWS_BUCKET_NAME; if (!bucketName) return { url: null }; const s3 = new S3Client({ region: process.env.AWS_REGION || 'us-east-1' }); - const url = await getSignedUrl(s3, new GetObjectCommand({ Bucket: bucketName, Key: pdfUrl }), { expiresIn: 900 }); + const url = await getSignedUrl( + s3, + new GetObjectCommand({ Bucket: bucketName, Key: pdfUrl }), + { expiresIn: 900 }, + ); return { url }; } @@ -933,7 +1012,9 @@ export class PoliciesController { @Post(':id/accept-changes') @RequirePermission('policy', 'update') - @ApiOperation({ summary: 'Accept pending policy changes and publish the version' }) + @ApiOperation({ + summary: 'Accept pending policy changes and publish the version', + }) @ApiParam(POLICY_PARAMS.policyId) async acceptPolicyChanges( @Param('id') id: string, diff --git a/apps/api/src/policies/policies.service.ts b/apps/api/src/policies/policies.service.ts index 634b8a24b8..7209143967 100644 --- a/apps/api/src/policies/policies.service.ts +++ b/apps/api/src/policies/policies.service.ts @@ -97,11 +97,7 @@ export class PoliciesService { } } - async publishAll( - organizationId: string, - userId?: string, - memberId?: string, - ) { + async publishAll(organizationId: string, userId?: string, memberId?: string) { const draftPolicies = await db.policy.findMany({ where: { organizationId, status: 'draft', isArchived: false }, select: { id: true, name: true, frequency: true }, @@ -243,7 +239,9 @@ export class PoliciesService { include: { user: { select: { role: true } } }, }); if (assignee?.user.role === 'admin') { - throw new BadRequestException('Cannot assign a platform admin as assignee'); + throw new BadRequestException( + 'Cannot assign a platform admin as assignee', + ); } } const contentValue = createData.content as Prisma.InputJsonValue[]; @@ -1032,7 +1030,9 @@ export class PoliciesService { } if (policy.approverId !== dto.approverId) { - throw new BadRequestException('Only the assigned approver can accept changes'); + throw new BadRequestException( + 'Only the assigned approver can accept changes', + ); } const version = await db.policyVersion.findUnique({ @@ -1091,7 +1091,9 @@ export class PoliciesService { } if (policy.approverId !== dto.approverId) { - throw new BadRequestException('Only the assigned approver can deny changes'); + throw new BadRequestException( + 'Only the assigned approver can deny changes', + ); } // Revert policy to previous state (draft if never published, published if it was) @@ -1220,7 +1222,8 @@ export class PoliciesService { draft: 2, }; policies.sort( - (a, b) => (statusPriority[a.status] ?? 3) - (statusPriority[b.status] ?? 3), + (a, b) => + (statusPriority[a.status] ?? 3) - (statusPriority[b.status] ?? 3), ); const mergedPdf = await PDFDocument.create(); diff --git a/apps/api/src/policies/schemas/get-policy-by-id.responses.ts b/apps/api/src/policies/schemas/get-policy-by-id.responses.ts index 151c7b972f..8d01937fcd 100644 --- a/apps/api/src/policies/schemas/get-policy-by-id.responses.ts +++ b/apps/api/src/policies/schemas/get-policy-by-id.responses.ts @@ -40,7 +40,8 @@ export const GET_POLICY_BY_ID_RESPONSES: Record = { }, 403: { status: 403, - description: 'Forbidden - User does not have permission to access this policy', + description: + 'Forbidden - User does not have permission to access this policy', content: { 'application/json': { schema: { diff --git a/apps/api/src/questionnaire/questionnaire.controller.spec.ts b/apps/api/src/questionnaire/questionnaire.controller.spec.ts index 0422468af6..4e6ea5a3e1 100644 --- a/apps/api/src/questionnaire/questionnaire.controller.spec.ts +++ b/apps/api/src/questionnaire/questionnaire.controller.spec.ts @@ -218,10 +218,7 @@ describe('QuestionnaireController', () => { error: undefined, }); - const result = await controller.answerSingleQuestion( - dto as any, - 'org_1', - ); + const result = await controller.answerSingleQuestion(dto as any, 'org_1'); expect(result.success).toBe(true); expect(result.data.answer).toBe('Our policy covers...'); diff --git a/apps/api/src/questionnaire/questionnaire.controller.ts b/apps/api/src/questionnaire/questionnaire.controller.ts index 24a20d693b..4f7f5146a4 100644 --- a/apps/api/src/questionnaire/questionnaire.controller.ts +++ b/apps/api/src/questionnaire/questionnaire.controller.ts @@ -29,10 +29,7 @@ import { HybridAuthGuard } from '../auth/hybrid-auth.guard'; import { PermissionGuard } from '../auth/permission.guard'; import { Public } from '../auth/public.decorator'; import { RequirePermission } from '../auth/require-permission.decorator'; -import { - OrganizationId, - AuthContext, -} from '../auth/auth-context.decorator'; +import { OrganizationId, AuthContext } from '../auth/auth-context.decorator'; import { AuditRead } from '../audit/skip-audit-log.decorator'; import type { AuthContext as AuthContextType } from '../auth/types'; import { ParseQuestionnaireDto } from './dto/parse-questionnaire.dto'; diff --git a/apps/api/src/questionnaire/questionnaire.service.spec.ts b/apps/api/src/questionnaire/questionnaire.service.spec.ts index 150d6370e3..21d009943c 100644 --- a/apps/api/src/questionnaire/questionnaire.service.spec.ts +++ b/apps/api/src/questionnaire/questionnaire.service.spec.ts @@ -297,9 +297,7 @@ describe('QuestionnaireService', () => { const result = await service.saveAnswer(baseSaveDto as any); expect(result).toEqual({ success: true }); - expect( - mockDb.questionnaireQuestionAnswer.update, - ).toHaveBeenCalledWith( + expect(mockDb.questionnaireQuestionAnswer.update).toHaveBeenCalledWith( expect.objectContaining({ where: { id: 'qa1' }, data: expect.objectContaining({ @@ -397,9 +395,7 @@ describe('QuestionnaireService', () => { } as any); expect(result).toEqual({ success: true }); - expect( - mockDb.questionnaireQuestionAnswer.update, - ).toHaveBeenCalledWith({ + expect(mockDb.questionnaireQuestionAnswer.update).toHaveBeenCalledWith({ where: { id: 'qa1' }, data: expect.objectContaining({ answer: null, diff --git a/apps/api/src/questionnaire/utils/content-extractor.spec.ts b/apps/api/src/questionnaire/utils/content-extractor.spec.ts index 5326df8437..9dbdf62064 100644 --- a/apps/api/src/questionnaire/utils/content-extractor.spec.ts +++ b/apps/api/src/questionnaire/utils/content-extractor.spec.ts @@ -52,8 +52,20 @@ describe('content-extractor: extractContentFromFile', () => { it('should extract content from multiple sheets', async () => { const buffer = await createTestExcelBuffer([ - { name: 'General', rows: [['Info', 'Details'], ['Name', 'Acme Corp']] }, - { name: 'Security', rows: [['Control', 'Status'], ['MFA', 'Enabled']] }, + { + name: 'General', + rows: [ + ['Info', 'Details'], + ['Name', 'Acme Corp'], + ], + }, + { + name: 'Security', + rows: [ + ['Control', 'Status'], + ['MFA', 'Enabled'], + ], + }, ]); const base64 = buffer.toString('base64'); diff --git a/apps/api/src/questionnaire/utils/content-extractor.ts b/apps/api/src/questionnaire/utils/content-extractor.ts index 77c5348fca..635e0740a1 100644 --- a/apps/api/src/questionnaire/utils/content-extractor.ts +++ b/apps/api/src/questionnaire/utils/content-extractor.ts @@ -722,9 +722,11 @@ async function extractFromExcelStandard( worksheet.eachRow((row) => { const cells = row.values as unknown[]; jsonData.push( - cells.slice(1).map((cell) => - cell !== null && cell !== undefined ? String(cell).trim() : '', - ), + cells + .slice(1) + .map((cell) => + cell !== null && cell !== undefined ? String(cell).trim() : '', + ), ); }); diff --git a/apps/api/src/questionnaire/utils/export-generator.spec.ts b/apps/api/src/questionnaire/utils/export-generator.spec.ts index 224d27b51f..09316bc90b 100644 --- a/apps/api/src/questionnaire/utils/export-generator.spec.ts +++ b/apps/api/src/questionnaire/utils/export-generator.spec.ts @@ -104,7 +104,11 @@ describe('generatePDF', () => { describe('generateExportFile', () => { it('should generate XLSX export with correct metadata', async () => { - const result = await generateExportFile(sampleQAs, 'xlsx', 'vendor-test.pdf'); + const result = await generateExportFile( + sampleQAs, + 'xlsx', + 'vendor-test.pdf', + ); expect(result.mimeType).toBe( 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', @@ -114,7 +118,11 @@ describe('generateExportFile', () => { }); it('should generate CSV export with correct metadata', async () => { - const result = await generateExportFile(sampleQAs, 'csv', 'vendor-test.xlsx'); + const result = await generateExportFile( + sampleQAs, + 'csv', + 'vendor-test.xlsx', + ); expect(result.mimeType).toBe('text/csv'); expect(result.filename).toBe('vendor-test.csv'); diff --git a/apps/api/src/questionnaire/utils/export-generator.ts b/apps/api/src/questionnaire/utils/export-generator.ts index 25d9fcd607..3d4dda6493 100644 --- a/apps/api/src/questionnaire/utils/export-generator.ts +++ b/apps/api/src/questionnaire/utils/export-generator.ts @@ -68,7 +68,11 @@ export async function generateXLSX( // Add data rows for (let i = 0; i < questionsAndAnswers.length; i++) { const qa = questionsAndAnswers[i]; - worksheet.addRow({ num: i + 1, question: qa.question, answer: qa.answer || '' }); + worksheet.addRow({ + num: i + 1, + question: qa.question, + answer: qa.answer || '', + }); } const xlsxBuffer = await workbook.xlsx.writeBuffer(); diff --git a/apps/api/src/risks/dto/get-risks-query.dto.ts b/apps/api/src/risks/dto/get-risks-query.dto.ts index 33c61b41ad..0ed01747df 100644 --- a/apps/api/src/risks/dto/get-risks-query.dto.ts +++ b/apps/api/src/risks/dto/get-risks-query.dto.ts @@ -1,18 +1,7 @@ import { ApiPropertyOptional } from '@nestjs/swagger'; -import { - IsEnum, - IsInt, - IsOptional, - IsString, - Max, - Min, -} from 'class-validator'; +import { IsEnum, IsInt, IsOptional, IsString, Max, Min } from 'class-validator'; import { Type } from 'class-transformer'; -import { - RiskCategory, - Departments, - RiskStatus, -} from '@db'; +import { RiskCategory, Departments, RiskStatus } from '@db'; export enum RiskSortBy { CREATED_AT = 'createdAt', diff --git a/apps/api/src/risks/risks.controller.spec.ts b/apps/api/src/risks/risks.controller.spec.ts index 06e494133f..a0d043371f 100644 --- a/apps/api/src/risks/risks.controller.spec.ts +++ b/apps/api/src/risks/risks.controller.spec.ts @@ -10,7 +10,15 @@ import { RisksService } from './risks.service'; jest.mock('@db', () => ({ ...jest.requireActual('@prisma/client'), db: {}, - Prisma: { PrismaClientKnownRequestError: class PrismaClientKnownRequestError extends Error { code: string; constructor(message: string, { code }: { code: string }) { super(message); this.code = code; } } }, + Prisma: { + PrismaClientKnownRequestError: class PrismaClientKnownRequestError extends Error { + code: string; + constructor(message: string, { code }: { code: string }) { + super(message); + this.code = code; + } + }, + }, })); jest.mock('../auth/auth.server', () => ({ @@ -110,7 +118,7 @@ describe('RisksController', () => { .compile(); controller = module.get(RisksController); - risksService = module.get(RisksService) as jest.Mocked; + risksService = module.get(RisksService); jest.clearAllMocks(); mockBuildRiskAssignmentFilter.mockReturnValue({}); @@ -126,7 +134,11 @@ describe('RisksController', () => { }; it('should call findAllByOrganization with correct parameters', async () => { - risksService.findAllByOrganization.mockResolvedValue(paginatedResult as unknown as Awaited>); + risksService.findAllByOrganization.mockResolvedValue( + paginatedResult as unknown as Awaited< + ReturnType + >, + ); const query = { page: 1, perPage: 10 }; await controller.getAllRisks(query, orgId, authContext); @@ -144,7 +156,11 @@ describe('RisksController', () => { }); it('should return paginated data with auth info', async () => { - risksService.findAllByOrganization.mockResolvedValue(paginatedResult as unknown as Awaited>); + risksService.findAllByOrganization.mockResolvedValue( + paginatedResult as unknown as Awaited< + ReturnType + >, + ); const result = await controller.getAllRisks({}, orgId, authContext); @@ -162,7 +178,11 @@ describe('RisksController', () => { }); it('should omit authenticatedUser when userId is not present', async () => { - risksService.findAllByOrganization.mockResolvedValue(paginatedResult as unknown as Awaited>); + risksService.findAllByOrganization.mockResolvedValue( + paginatedResult as unknown as Awaited< + ReturnType + >, + ); const result = await controller.getAllRisks({}, orgId, authContextNoUser); @@ -173,7 +193,11 @@ describe('RisksController', () => { it('should pass assignment filter from buildRiskAssignmentFilter', async () => { const assignmentFilter = { assigneeId: 'mem_123' }; mockBuildRiskAssignmentFilter.mockReturnValue(assignmentFilter); - risksService.findAllByOrganization.mockResolvedValue(paginatedResult as unknown as Awaited>); + risksService.findAllByOrganization.mockResolvedValue( + paginatedResult as unknown as Awaited< + ReturnType + >, + ); await controller.getAllRisks({}, orgId, authContext); @@ -240,7 +264,11 @@ describe('RisksController', () => { ]; it('should call getStatsByDepartment with organizationId', async () => { - risksService.getStatsByDepartment.mockResolvedValue(deptStats as unknown as Awaited>); + risksService.getStatsByDepartment.mockResolvedValue( + deptStats as unknown as Awaited< + ReturnType + >, + ); await controller.getStatsByDepartment(orgId, authContext); @@ -248,7 +276,11 @@ describe('RisksController', () => { }); it('should return data with auth info', async () => { - risksService.getStatsByDepartment.mockResolvedValue(deptStats as unknown as Awaited>); + risksService.getStatsByDepartment.mockResolvedValue( + deptStats as unknown as Awaited< + ReturnType + >, + ); const result = await controller.getStatsByDepartment(orgId, authContext); @@ -265,7 +297,11 @@ describe('RisksController', () => { describe('getRiskById', () => { it('should call findById with correct parameters', async () => { - risksService.findById.mockResolvedValue(mockRisk as unknown as Awaited>); + risksService.findById.mockResolvedValue( + mockRisk as unknown as Awaited< + ReturnType + >, + ); await controller.getRiskById('risk_1', orgId, authContext); @@ -273,7 +309,11 @@ describe('RisksController', () => { }); it('should return risk with auth info', async () => { - risksService.findById.mockResolvedValue(mockRisk as unknown as Awaited>); + risksService.findById.mockResolvedValue( + mockRisk as unknown as Awaited< + ReturnType + >, + ); const result = await controller.getRiskById('risk_1', orgId, authContext); @@ -288,7 +328,11 @@ describe('RisksController', () => { }); it('should check hasRiskAccess and throw ForbiddenException if denied', async () => { - risksService.findById.mockResolvedValue(mockRisk as unknown as Awaited>); + risksService.findById.mockResolvedValue( + mockRisk as unknown as Awaited< + ReturnType + >, + ); mockHasRiskAccess.mockReturnValue(false); await expect( @@ -304,7 +348,11 @@ describe('RisksController', () => { }); it('should pass isApiKey option to hasRiskAccess', async () => { - risksService.findById.mockResolvedValue(mockRisk as unknown as Awaited>); + risksService.findById.mockResolvedValue( + mockRisk as unknown as Awaited< + ReturnType + >, + ); await controller.getRiskById('risk_1', orgId, authContextNoUser); @@ -325,7 +373,9 @@ describe('RisksController', () => { }; it('should call create with organizationId and dto', async () => { - risksService.create.mockResolvedValue(mockRisk as unknown as Awaited>); + risksService.create.mockResolvedValue( + mockRisk as unknown as Awaited>, + ); await controller.createRisk(createDto, orgId, authContext); @@ -333,7 +383,9 @@ describe('RisksController', () => { }); it('should return created risk with auth info', async () => { - risksService.create.mockResolvedValue(mockRisk as unknown as Awaited>); + risksService.create.mockResolvedValue( + mockRisk as unknown as Awaited>, + ); const result = await controller.createRisk(createDto, orgId, authContext); @@ -348,7 +400,9 @@ describe('RisksController', () => { }); it('should omit authenticatedUser for API key auth', async () => { - risksService.create.mockResolvedValue(mockRisk as unknown as Awaited>); + risksService.create.mockResolvedValue( + mockRisk as unknown as Awaited>, + ); const result = await controller.createRisk( createDto, @@ -366,7 +420,11 @@ describe('RisksController', () => { const updatedRisk = { ...mockRisk, title: 'Updated Risk' }; it('should call updateById with correct parameters', async () => { - risksService.updateById.mockResolvedValue(updatedRisk as unknown as Awaited>); + risksService.updateById.mockResolvedValue( + updatedRisk as unknown as Awaited< + ReturnType + >, + ); await controller.updateRisk('risk_1', updateDto, orgId, authContext); @@ -378,7 +436,11 @@ describe('RisksController', () => { }); it('should return updated risk with auth info', async () => { - risksService.updateById.mockResolvedValue(updatedRisk as unknown as Awaited>); + risksService.updateById.mockResolvedValue( + updatedRisk as unknown as Awaited< + ReturnType + >, + ); const result = await controller.updateRisk( 'risk_1', diff --git a/apps/api/src/risks/risks.controller.ts b/apps/api/src/risks/risks.controller.ts index d0f7696d9f..44802f2b12 100644 --- a/apps/api/src/risks/risks.controller.ts +++ b/apps/api/src/risks/risks.controller.ts @@ -153,7 +153,11 @@ export class RisksController { const risk = await this.risksService.findById(riskId, organizationId); // Check assignment access for restricted roles - if (!hasRiskAccess(risk, authContext.memberId, authContext.userRoles, { isApiKey: authContext.isApiKey })) { + if ( + !hasRiskAccess(risk, authContext.memberId, authContext.userRoles, { + isApiKey: authContext.isApiKey, + }) + ) { throw new ForbiddenException('You do not have access to this risk'); } diff --git a/apps/api/src/risks/risks.service.ts b/apps/api/src/risks/risks.service.ts index 440c44d709..37c920d4ec 100644 --- a/apps/api/src/risks/risks.service.ts +++ b/apps/api/src/risks/risks.service.ts @@ -1,4 +1,9 @@ -import { BadRequestException, Injectable, NotFoundException, Logger } from '@nestjs/common'; +import { + BadRequestException, + Injectable, + NotFoundException, + Logger, +} from '@nestjs/common'; import { db, Prisma } from '@db'; import { CreateRiskDto } from './dto/create-risk.dto'; import { GetRisksQueryDto } from './dto/get-risks-query.dto'; @@ -25,13 +30,18 @@ export interface PaginatedRisksResult { export class RisksService { private readonly logger = new Logger(RisksService.name); - private async validateAssigneeNotPlatformAdmin(assigneeId: string, organizationId: string) { + private async validateAssigneeNotPlatformAdmin( + assigneeId: string, + organizationId: string, + ) { const member = await db.member.findFirst({ where: { id: assigneeId, organizationId }, include: { user: { select: { role: true } } }, }); if (member?.user.role === 'admin') { - throw new BadRequestException('Cannot assign a platform admin as assignee'); + throw new BadRequestException( + 'Cannot assign a platform admin as assignee', + ); } } @@ -141,7 +151,10 @@ export class RisksService { async create(organizationId: string, createRiskDto: CreateRiskDto) { try { if (createRiskDto.assigneeId) { - await this.validateAssigneeNotPlatformAdmin(createRiskDto.assigneeId, organizationId); + await this.validateAssigneeNotPlatformAdmin( + createRiskDto.assigneeId, + organizationId, + ); } const risk = await db.risk.create({ data: { @@ -173,7 +186,10 @@ export class RisksService { await this.findById(id, organizationId); if (updateRiskDto.assigneeId) { - await this.validateAssigneeNotPlatformAdmin(updateRiskDto.assigneeId, organizationId); + await this.validateAssigneeNotPlatformAdmin( + updateRiskDto.assigneeId, + organizationId, + ); } const updatedRisk = await db.risk.update({ diff --git a/apps/api/src/risks/schemas/get-risk-by-id.responses.ts b/apps/api/src/risks/schemas/get-risk-by-id.responses.ts index b4b8ae5cf7..34a6d94e77 100644 --- a/apps/api/src/risks/schemas/get-risk-by-id.responses.ts +++ b/apps/api/src/risks/schemas/get-risk-by-id.responses.ts @@ -154,7 +154,8 @@ export const GET_RISK_BY_ID_RESPONSES: Record = { }, 403: { status: 403, - description: 'Forbidden - User does not have permission to access this risk', + description: + 'Forbidden - User does not have permission to access this risk', content: { 'application/json': { schema: { diff --git a/apps/api/src/roles/dto/create-role.dto.ts b/apps/api/src/roles/dto/create-role.dto.ts index dc017851c1..3b549f3f86 100644 --- a/apps/api/src/roles/dto/create-role.dto.ts +++ b/apps/api/src/roles/dto/create-role.dto.ts @@ -1,5 +1,13 @@ import { ApiProperty } from '@nestjs/swagger'; -import { IsNotEmpty, IsObject, IsOptional, IsString, MaxLength, MinLength, Matches } from 'class-validator'; +import { + IsNotEmpty, + IsObject, + IsOptional, + IsString, + MaxLength, + MinLength, + Matches, +} from 'class-validator'; export class CreateRoleDto { @ApiProperty({ @@ -13,12 +21,14 @@ export class CreateRoleDto { @MinLength(2) @MaxLength(50) @Matches(/^[a-zA-Z][a-zA-Z0-9\s-]*$/, { - message: 'Role name must start with a letter and contain only letters, numbers, spaces, and hyphens', + message: + 'Role name must start with a letter and contain only letters, numbers, spaces, and hyphens', }) name: string; @ApiProperty({ - description: 'Permissions for the role. Keys are resource names, values are arrays of allowed actions.', + description: + 'Permissions for the role. Keys are resource names, values are arrays of allowed actions.', example: { control: ['read', 'update'], policy: ['read', 'update'], @@ -30,7 +40,8 @@ export class CreateRoleDto { permissions: Record; @ApiProperty({ - description: 'Obligations for the role. Boolean flags for requirements like compliance.', + description: + 'Obligations for the role. Boolean flags for requirements like compliance.', example: { compliance: true }, required: false, }) diff --git a/apps/api/src/roles/dto/update-role.dto.ts b/apps/api/src/roles/dto/update-role.dto.ts index 268cb8fbc6..f24b61e251 100644 --- a/apps/api/src/roles/dto/update-role.dto.ts +++ b/apps/api/src/roles/dto/update-role.dto.ts @@ -1,5 +1,12 @@ import { ApiProperty } from '@nestjs/swagger'; -import { IsObject, IsOptional, IsString, MaxLength, MinLength, Matches } from 'class-validator'; +import { + IsObject, + IsOptional, + IsString, + MaxLength, + MinLength, + Matches, +} from 'class-validator'; export class UpdateRoleDto { @ApiProperty({ @@ -14,12 +21,14 @@ export class UpdateRoleDto { @MinLength(2) @MaxLength(50) @Matches(/^[a-zA-Z][a-zA-Z0-9\s-]*$/, { - message: 'Role name must start with a letter and contain only letters, numbers, spaces, and hyphens', + message: + 'Role name must start with a letter and contain only letters, numbers, spaces, and hyphens', }) name?: string; @ApiProperty({ - description: 'Updated permissions for the role. Keys are resource names, values are arrays of allowed actions.', + description: + 'Updated permissions for the role. Keys are resource names, values are arrays of allowed actions.', example: { control: ['read', 'update', 'delete'], policy: ['read', 'update', 'delete'], diff --git a/apps/api/src/roles/roles.controller.spec.ts b/apps/api/src/roles/roles.controller.spec.ts index b1a8cb02a2..47aab46b6b 100644 --- a/apps/api/src/roles/roles.controller.spec.ts +++ b/apps/api/src/roles/roles.controller.spec.ts @@ -81,10 +81,16 @@ describe('RolesController', () => { mockRolesService.createRole.mockResolvedValue(expectedRole); - const result = await controller.createRole('org_123', mockAuthContext, dto); + const result = await controller.createRole( + 'org_123', + mockAuthContext, + dto, + ); expect(result).toEqual(expectedRole); - expect(rolesService.createRole).toHaveBeenCalledWith('org_123', dto, ['owner']); + expect(rolesService.createRole).toHaveBeenCalledWith('org_123', dto, [ + 'owner', + ]); }); it('should pass multiple roles to service', async () => { @@ -121,7 +127,9 @@ describe('RolesController', () => { await controller.createRole('org_123', noRoleContext, dto); - expect(rolesService.createRole).toHaveBeenCalledWith('org_123', dto, ['employee']); + expect(rolesService.createRole).toHaveBeenCalledWith('org_123', dto, [ + 'employee', + ]); }); }); @@ -200,23 +208,31 @@ describe('RolesController', () => { await controller.updateRole('org_123', multiRoleContext, 'rol_123', dto); - expect(rolesService.updateRole).toHaveBeenCalledWith('org_123', 'rol_123', dto, [ - 'owner', - 'admin', - ]); + expect(rolesService.updateRole).toHaveBeenCalledWith( + 'org_123', + 'rol_123', + dto, + ['owner', 'admin'], + ); }); }); describe('deleteRole', () => { it('should delete a role', async () => { - const expectedResult = { success: true, message: "Role 'custom-role' deleted" }; + const expectedResult = { + success: true, + message: "Role 'custom-role' deleted", + }; mockRolesService.deleteRole.mockResolvedValue(expectedResult); const result = await controller.deleteRole('org_123', 'rol_123'); expect(result).toEqual(expectedResult); - expect(rolesService.deleteRole).toHaveBeenCalledWith('org_123', 'rol_123'); + expect(rolesService.deleteRole).toHaveBeenCalledWith( + 'org_123', + 'rol_123', + ); }); }); }); diff --git a/apps/api/src/roles/roles.controller.ts b/apps/api/src/roles/roles.controller.ts index c9e7b14372..6515a9e0ca 100644 --- a/apps/api/src/roles/roles.controller.ts +++ b/apps/api/src/roles/roles.controller.ts @@ -37,7 +37,8 @@ export class RolesController { @RequirePermission('ac', 'create') @ApiOperation({ summary: 'Create a custom role', - description: 'Create a new custom role with specified permissions. Only admins and owners can create roles.', + description: + 'Create a new custom role with specified permissions. Only admins and owners can create roles.', }) @ApiBody({ type: CreateRoleDto }) @ApiResponse({ @@ -61,9 +62,15 @@ export class RolesController { }, }, }) - @ApiResponse({ status: 400, description: 'Invalid role data or role already exists' }) + @ApiResponse({ + status: 400, + description: 'Invalid role data or role already exists', + }) @ApiResponse({ status: 401, description: 'Unauthorized' }) - @ApiResponse({ status: 403, description: 'Forbidden - cannot grant permissions you do not have' }) + @ApiResponse({ + status: 403, + description: 'Forbidden - cannot grant permissions you do not have', + }) async createRole( @OrganizationId() organizationId: string, @AuthContext() authContext: AuthContextType, @@ -80,7 +87,8 @@ export class RolesController { @RequirePermission('ac', 'read') @ApiOperation({ summary: 'List all roles', - description: 'List all roles for the organization, including built-in and custom roles.', + description: + 'List all roles for the organization, including built-in and custom roles.', }) @ApiResponse({ status: 200, @@ -153,16 +161,14 @@ export class RolesController { .split(',') .map((r) => r.trim()) .filter(Boolean); - const permissions = - await this.rolesService.getPermissionsForRoles( - organizationId, - roleNames, - ); - const obligations = - await this.rolesService.getObligationsForRoles( - organizationId, - roleNames, - ); + const permissions = await this.rolesService.getPermissionsForRoles( + organizationId, + roleNames, + ); + const obligations = await this.rolesService.getObligationsForRoles( + organizationId, + roleNames, + ); return { permissions, obligations }; } @@ -201,7 +207,8 @@ export class RolesController { @RequirePermission('ac', 'update') @ApiOperation({ summary: 'Update a custom role', - description: 'Update the name or permissions of a custom role. Cannot modify built-in roles.', + description: + 'Update the name or permissions of a custom role. Cannot modify built-in roles.', }) @ApiParam({ name: 'roleId', description: 'Role ID', example: 'rol_abc123' }) @ApiBody({ type: UpdateRoleDto }) @@ -222,7 +229,10 @@ export class RolesController { }) @ApiResponse({ status: 400, description: 'Invalid role data' }) @ApiResponse({ status: 401, description: 'Unauthorized' }) - @ApiResponse({ status: 403, description: 'Forbidden - cannot grant permissions you do not have' }) + @ApiResponse({ + status: 403, + description: 'Forbidden - cannot grant permissions you do not have', + }) @ApiResponse({ status: 404, description: 'Role not found' }) async updateRole( @OrganizationId() organizationId: string, @@ -242,7 +252,8 @@ export class RolesController { @RequirePermission('ac', 'delete') @ApiOperation({ summary: 'Delete a custom role', - description: 'Delete a custom role. Cannot delete if members are still assigned to it.', + description: + 'Delete a custom role. Cannot delete if members are still assigned to it.', }) @ApiParam({ name: 'roleId', description: 'Role ID', example: 'rol_abc123' }) @ApiResponse({ @@ -256,7 +267,10 @@ export class RolesController { }, }, }) - @ApiResponse({ status: 400, description: 'Cannot delete - members assigned to role' }) + @ApiResponse({ + status: 400, + description: 'Cannot delete - members assigned to role', + }) @ApiResponse({ status: 401, description: 'Unauthorized' }) @ApiResponse({ status: 404, description: 'Role not found' }) async deleteRole( diff --git a/apps/api/src/roles/roles.service.spec.ts b/apps/api/src/roles/roles.service.spec.ts index 7b065efa0b..3aeba0fd09 100644 --- a/apps/api/src/roles/roles.service.spec.ts +++ b/apps/api/src/roles/roles.service.spec.ts @@ -1,5 +1,9 @@ import { Test, TestingModule } from '@nestjs/testing'; -import { BadRequestException, ForbiddenException, NotFoundException } from '@nestjs/common'; +import { + BadRequestException, + ForbiddenException, + NotFoundException, +} from '@nestjs/common'; import { RolesService } from './roles.service'; // Mock @trycompai/auth to avoid ESM import issues with better-auth in Jest @@ -27,9 +31,15 @@ jest.mock('@trycompai/auth', () => { }; const BUILT_IN_ROLE_PERMISSIONS: Record> = { - owner: { ...Object.fromEntries(Object.entries(statement).map(([k, v]) => [k, [...v]])) }, + owner: { + ...Object.fromEntries( + Object.entries(statement).map(([k, v]) => [k, [...v]]), + ), + }, admin: { - ...Object.fromEntries(Object.entries(statement).map(([k, v]) => [k, [...v]])), + ...Object.fromEntries( + Object.entries(statement).map(([k, v]) => [k, [...v]]), + ), organization: ['read', 'update'], }, auditor: { @@ -70,7 +80,12 @@ jest.mock('@trycompai/auth', () => { contractor: { compliance: true }, }; - return { statement, allRoles, BUILT_IN_ROLE_PERMISSIONS, BUILT_IN_ROLE_OBLIGATIONS }; + return { + statement, + allRoles, + BUILT_IN_ROLE_PERMISSIONS, + BUILT_IN_ROLE_OBLIGATIONS, + }; }); // Mock the database @@ -132,7 +147,9 @@ describe('RolesService', () => { (mockDb.organizationRole.count as jest.Mock).mockResolvedValue(0); (mockDb.organizationRole.create as jest.Mock).mockResolvedValue(mockRole); - const result = await service.createRole(organizationId, validDto, ['owner']); + const result = await service.createRole(organizationId, validDto, [ + 'owner', + ]); expect(result.permissions).toEqual(validDto.permissions); expect(mockDb.organizationRole.create).toHaveBeenCalledWith({ @@ -148,12 +165,12 @@ describe('RolesService', () => { it('should reject built-in role names', async () => { const dto = { name: 'owner', permissions: { control: ['read'] } }; - await expect(service.createRole(organizationId, dto, ['owner'])).rejects.toThrow( - BadRequestException, - ); - await expect(service.createRole(organizationId, dto, ['owner'])).rejects.toThrow( - 'Cannot create role with reserved name: owner', - ); + await expect( + service.createRole(organizationId, dto, ['owner']), + ).rejects.toThrow(BadRequestException); + await expect( + service.createRole(organizationId, dto, ['owner']), + ).rejects.toThrow('Cannot create role with reserved name: owner'); }); it('should reject invalid resource names', async () => { @@ -162,12 +179,12 @@ describe('RolesService', () => { permissions: { invalidResource: ['read'] }, }; - await expect(service.createRole(organizationId, dto, ['owner'])).rejects.toThrow( - BadRequestException, - ); - await expect(service.createRole(organizationId, dto, ['owner'])).rejects.toThrow( - 'Invalid resource: invalidResource', - ); + await expect( + service.createRole(organizationId, dto, ['owner']), + ).rejects.toThrow(BadRequestException); + await expect( + service.createRole(organizationId, dto, ['owner']), + ).rejects.toThrow('Invalid resource: invalidResource'); }); it('should reject invalid actions for valid resources', async () => { @@ -176,10 +193,12 @@ describe('RolesService', () => { permissions: { control: ['read', 'invalidAction'] }, }; - await expect(service.createRole(organizationId, dto, ['owner'])).rejects.toThrow( - BadRequestException, - ); - await expect(service.createRole(organizationId, dto, ['owner'])).rejects.toThrow( + await expect( + service.createRole(organizationId, dto, ['owner']), + ).rejects.toThrow(BadRequestException); + await expect( + service.createRole(organizationId, dto, ['owner']), + ).rejects.toThrow( "Invalid action 'invalidAction' for resource 'control'", ); }); @@ -190,24 +209,24 @@ describe('RolesService', () => { name: validDto.name, }); - await expect(service.createRole(organizationId, validDto, ['owner'])).rejects.toThrow( - BadRequestException, - ); - await expect(service.createRole(organizationId, validDto, ['owner'])).rejects.toThrow( - `Role '${validDto.name}' already exists`, - ); + await expect( + service.createRole(organizationId, validDto, ['owner']), + ).rejects.toThrow(BadRequestException); + await expect( + service.createRole(organizationId, validDto, ['owner']), + ).rejects.toThrow(`Role '${validDto.name}' already exists`); }); it('should enforce maximum 20 roles per organization', async () => { (mockDb.organizationRole.findFirst as jest.Mock).mockResolvedValue(null); (mockDb.organizationRole.count as jest.Mock).mockResolvedValue(20); - await expect(service.createRole(organizationId, validDto, ['owner'])).rejects.toThrow( - BadRequestException, - ); - await expect(service.createRole(organizationId, validDto, ['owner'])).rejects.toThrow( - 'Maximum of 20 custom roles per organization', - ); + await expect( + service.createRole(organizationId, validDto, ['owner']), + ).rejects.toThrow(BadRequestException); + await expect( + service.createRole(organizationId, validDto, ['owner']), + ).rejects.toThrow('Maximum of 20 custom roles per organization'); }); it('should prevent privilege escalation - cannot grant permissions you do not have', async () => { @@ -219,9 +238,9 @@ describe('RolesService', () => { }, }; - await expect(service.createRole(organizationId, dto, ['employee'])).rejects.toThrow( - ForbiddenException, - ); + await expect( + service.createRole(organizationId, dto, ['employee']), + ).rejects.toThrow(ForbiddenException); }); it('should allow owners to grant organization:delete', async () => { @@ -243,7 +262,9 @@ describe('RolesService', () => { }); // Should not throw for owner - await expect(service.createRole(organizationId, dto, ['owner'])).resolves.toBeDefined(); + await expect( + service.createRole(organizationId, dto, ['owner']), + ).resolves.toBeDefined(); }); it('should prevent non-owners from granting organization:delete', async () => { @@ -255,10 +276,12 @@ describe('RolesService', () => { }; // Admin doesn't have organization:delete permission, so privilege escalation check fails first - await expect(service.createRole(organizationId, dto, ['admin'])).rejects.toThrow( - ForbiddenException, - ); - await expect(service.createRole(organizationId, dto, ['admin'])).rejects.toThrow( + await expect( + service.createRole(organizationId, dto, ['admin']), + ).rejects.toThrow(ForbiddenException); + await expect( + service.createRole(organizationId, dto, ['admin']), + ).rejects.toThrow( "Cannot grant 'organization:delete' permission - you don't have this permission", ); }); @@ -305,7 +328,9 @@ describe('RolesService', () => { }, ]; - (mockDb.organizationRole.findMany as jest.Mock).mockResolvedValue(customRoles); + (mockDb.organizationRole.findMany as jest.Mock).mockResolvedValue( + customRoles, + ); const result = await service.listRoles('org_123'); @@ -333,7 +358,9 @@ describe('RolesService', () => { updatedAt: new Date(), }; - (mockDb.organizationRole.findFirst as jest.Mock).mockResolvedValue(mockRole); + (mockDb.organizationRole.findFirst as jest.Mock).mockResolvedValue( + mockRole, + ); const result = await service.getRole('org_123', 'rol_123'); @@ -344,9 +371,9 @@ describe('RolesService', () => { it('should throw NotFoundException for non-existent role', async () => { (mockDb.organizationRole.findFirst as jest.Mock).mockResolvedValue(null); - await expect(service.getRole('org_123', 'rol_nonexistent')).rejects.toThrow( - NotFoundException, - ); + await expect( + service.getRole('org_123', 'rol_nonexistent'), + ).rejects.toThrow(NotFoundException); }); }); @@ -391,10 +418,14 @@ describe('RolesService', () => { obligations: '{}', }; - (mockDb.organizationRole.findFirst as jest.Mock).mockResolvedValue(existingRole); + (mockDb.organizationRole.findFirst as jest.Mock).mockResolvedValue( + existingRole, + ); await expect( - service.updateRole(organizationId, roleId, { name: 'admin' }, ['owner']), + service.updateRole(organizationId, roleId, { name: 'admin' }, [ + 'owner', + ]), ).rejects.toThrow(BadRequestException); }); @@ -402,7 +433,9 @@ describe('RolesService', () => { (mockDb.organizationRole.findFirst as jest.Mock).mockResolvedValue(null); await expect( - service.updateRole(organizationId, roleId, { name: 'new-name' }, ['owner']), + service.updateRole(organizationId, roleId, { name: 'new-name' }, [ + 'owner', + ]), ).rejects.toThrow(NotFoundException); }); @@ -414,7 +447,9 @@ describe('RolesService', () => { obligations: '{}', }; - (mockDb.organizationRole.findFirst as jest.Mock).mockResolvedValue(existingRole); + (mockDb.organizationRole.findFirst as jest.Mock).mockResolvedValue( + existingRole, + ); // Employee trying to add organization:delete to a role await expect( @@ -440,9 +475,13 @@ describe('RolesService', () => { obligations: '{}', }; - (mockDb.organizationRole.findFirst as jest.Mock).mockResolvedValue(existingRole); + (mockDb.organizationRole.findFirst as jest.Mock).mockResolvedValue( + existingRole, + ); (mockDb.member.count as jest.Mock).mockResolvedValue(0); - (mockDb.organizationRole.delete as jest.Mock).mockResolvedValue(existingRole); + (mockDb.organizationRole.delete as jest.Mock).mockResolvedValue( + existingRole, + ); const result = await service.deleteRole(organizationId, roleId); @@ -460,7 +499,9 @@ describe('RolesService', () => { obligations: '{}', }; - (mockDb.organizationRole.findFirst as jest.Mock).mockResolvedValue(existingRole); + (mockDb.organizationRole.findFirst as jest.Mock).mockResolvedValue( + existingRole, + ); (mockDb.member.count as jest.Mock).mockResolvedValue(3); await expect(service.deleteRole(organizationId, roleId)).rejects.toThrow( diff --git a/apps/api/src/roles/roles.service.ts b/apps/api/src/roles/roles.service.ts index a00b5bb0c7..3d77ce81b1 100644 --- a/apps/api/src/roles/roles.service.ts +++ b/apps/api/src/roles/roles.service.ts @@ -1,6 +1,17 @@ -import { Injectable, BadRequestException, NotFoundException, ForbiddenException } from '@nestjs/common'; +import { + Injectable, + BadRequestException, + NotFoundException, + ForbiddenException, +} from '@nestjs/common'; import { db } from '@db'; -import { statement, allRoles, BUILT_IN_ROLE_PERMISSIONS, BUILT_IN_ROLE_OBLIGATIONS, type RoleObligations } from '@trycompai/auth'; +import { + statement, + allRoles, + BUILT_IN_ROLE_PERMISSIONS, + BUILT_IN_ROLE_OBLIGATIONS, + type RoleObligations, +} from '@trycompai/auth'; import type { CreateRoleDto } from './dto/create-role.dto'; import type { UpdateRoleDto } from './dto/update-role.dto'; @@ -27,7 +38,7 @@ export class RolesService { for (const action of actions) { if (!validActions.includes(action)) { throw new BadRequestException( - `Invalid action '${action}' for resource '${resource}'. Valid actions: ${validActions.join(', ')}` + `Invalid action '${action}' for resource '${resource}'. Valid actions: ${validActions.join(', ')}`, ); } } @@ -44,7 +55,10 @@ export class RolesService { organizationId: string, ): Promise { // Get the caller's combined effective permissions from all their roles - const callerPermissions = await this.getCombinedPermissions(callerRoles, organizationId); + const callerPermissions = await this.getCombinedPermissions( + callerRoles, + organizationId, + ); for (const [resource, actions] of Object.entries(permissions)) { const callerActions = callerPermissions[resource] || []; @@ -52,16 +66,19 @@ export class RolesService { for (const action of actions) { if (!callerActions.includes(action)) { throw new ForbiddenException( - `Cannot grant '${resource}:${action}' permission - you don't have this permission` + `Cannot grant '${resource}:${action}' permission - you don't have this permission`, ); } } } // Special check: only owners can grant organization:delete - if (permissions.organization?.includes('delete') && !callerRoles.includes('owner')) { + if ( + permissions.organization?.includes('delete') && + !callerRoles.includes('owner') + ) { throw new ForbiddenException( - 'Only organization owners can grant organization:delete permission' + 'Only organization owners can grant organization:delete permission', ); } } @@ -87,7 +104,10 @@ export class RolesService { const combined: Record = {}; for (const roleName of roleNames) { - const rolePermissions = await this.getEffectivePermissions(roleName, organizationId); + const rolePermissions = await this.getEffectivePermissions( + roleName, + organizationId, + ); for (const [resource, actions] of Object.entries(rolePermissions)) { if (!combined[resource]) { @@ -126,9 +146,10 @@ export class RolesService { }); if (customRole) { - const perms = typeof customRole.permissions === 'string' - ? JSON.parse(customRole.permissions) - : customRole.permissions; + const perms = + typeof customRole.permissions === 'string' + ? JSON.parse(customRole.permissions) + : customRole.permissions; return perms as Record; } @@ -146,14 +167,20 @@ export class RolesService { ) { // Validate role name isn't a built-in role if (BUILT_IN_ROLES.includes(dto.name)) { - throw new BadRequestException(`Cannot create role with reserved name: ${dto.name}`); + throw new BadRequestException( + `Cannot create role with reserved name: ${dto.name}`, + ); } // Validate permissions this.validatePermissions(dto.permissions); // Check for privilege escalation - await this.validateNoPrivilegeEscalation(callerRoles, dto.permissions, organizationId); + await this.validateNoPrivilegeEscalation( + callerRoles, + dto.permissions, + organizationId, + ); // Check if role already exists const existing = await db.organizationRole.findFirst({ @@ -173,7 +200,9 @@ export class RolesService { }); if (roleCount >= 20) { - throw new BadRequestException('Maximum of 20 custom roles per organization'); + throw new BadRequestException( + 'Maximum of 20 custom roles per organization', + ); } // Create the role @@ -215,7 +244,7 @@ export class RolesService { const countMap = new Map(memberCounts.map((mc) => [mc.roleId, mc.count])); // Include built-in roles info - const builtInRoles = BUILT_IN_ROLES.map(name => ({ + const builtInRoles = BUILT_IN_ROLES.map((name) => ({ name, isBuiltIn: true, description: this.getBuiltInRoleDescription(name), @@ -223,11 +252,17 @@ export class RolesService { return { builtInRoles, - customRoles: customRoles.map(r => ({ + customRoles: customRoles.map((r) => ({ id: r.id, name: r.name, - permissions: typeof r.permissions === 'string' ? JSON.parse(r.permissions) : r.permissions, - obligations: typeof r.obligations === 'string' ? JSON.parse(r.obligations) : (r.obligations || {}), + permissions: + typeof r.permissions === 'string' + ? JSON.parse(r.permissions) + : r.permissions, + obligations: + typeof r.obligations === 'string' + ? JSON.parse(r.obligations) + : r.obligations || {}, isBuiltIn: false, createdAt: r.createdAt.toISOString(), updatedAt: r.updatedAt.toISOString(), @@ -258,8 +293,14 @@ export class RolesService { return { id: role.id, name: role.name, - permissions: typeof role.permissions === 'string' ? JSON.parse(role.permissions) : role.permissions, - obligations: typeof role.obligations === 'string' ? JSON.parse(role.obligations) : (role.obligations || {}), + permissions: + typeof role.permissions === 'string' + ? JSON.parse(role.permissions) + : role.permissions, + obligations: + typeof role.obligations === 'string' + ? JSON.parse(role.obligations) + : role.obligations || {}, isBuiltIn: false, createdAt: role.createdAt.toISOString(), updatedAt: role.updatedAt.toISOString(), @@ -310,7 +351,11 @@ export class RolesService { // Validate and check permissions if provided if (dto.permissions) { this.validatePermissions(dto.permissions); - await this.validateNoPrivilegeEscalation(callerRoles, dto.permissions, organizationId); + await this.validateNoPrivilegeEscalation( + callerRoles, + dto.permissions, + organizationId, + ); } // Update the role @@ -318,16 +363,26 @@ export class RolesService { where: { id: roleId }, data: { ...(dto.name && { name: dto.name }), - ...(dto.permissions && { permissions: JSON.stringify(dto.permissions) }), - ...(dto.obligations !== undefined && { obligations: JSON.stringify(dto.obligations) }), + ...(dto.permissions && { + permissions: JSON.stringify(dto.permissions), + }), + ...(dto.obligations !== undefined && { + obligations: JSON.stringify(dto.obligations), + }), }, }); return { id: updated.id, name: updated.name, - permissions: typeof updated.permissions === 'string' ? JSON.parse(updated.permissions) : updated.permissions, - obligations: typeof updated.obligations === 'string' ? JSON.parse(updated.obligations) : (updated.obligations || {}), + permissions: + typeof updated.permissions === 'string' + ? JSON.parse(updated.permissions) + : updated.permissions, + obligations: + typeof updated.obligations === 'string' + ? JSON.parse(updated.obligations) + : updated.obligations || {}, isBuiltIn: false, createdAt: updated.createdAt, updatedAt: updated.updatedAt, @@ -360,7 +415,7 @@ export class RolesService { if (membersWithRole > 0) { throw new BadRequestException( `Cannot delete role '${role.name}' - ${membersWithRole} member(s) are assigned to it. ` + - `Reassign them to a different role first.` + `Reassign them to a different role first.`, ); } @@ -429,9 +484,10 @@ export class RolesService { const combined: RoleObligations = {}; for (const role of customRoles) { - const obligations = typeof role.obligations === 'string' - ? JSON.parse(role.obligations) - : (role.obligations || {}); + const obligations = + typeof role.obligations === 'string' + ? JSON.parse(role.obligations) + : role.obligations || {}; if (obligations.compliance) combined.compliance = true; } @@ -445,8 +501,10 @@ export class RolesService { const descriptions: Record = { owner: 'Full access to everything including organization deletion', admin: 'Full access except organization deletion', - auditor: 'Read-only access with export capabilities for compliance audits', - employee: 'Limited access to assigned tasks and basic compliance activities', + auditor: + 'Read-only access with export capabilities for compliance audits', + employee: + 'Limited access to assigned tasks and basic compliance activities', contractor: 'Limited access similar to employee for external contractors', }; return descriptions[name] || ''; diff --git a/apps/api/src/scripts/seed-dynamic-integration.ts b/apps/api/src/scripts/seed-dynamic-integration.ts index ec14a6168e..88c3ab8b46 100644 --- a/apps/api/src/scripts/seed-dynamic-integration.ts +++ b/apps/api/src/scripts/seed-dynamic-integration.ts @@ -34,7 +34,9 @@ async function main() { const content = readFileSync(absolutePath, 'utf-8'); rawJson = JSON.parse(content); } catch (error) { - console.error(`Failed to read/parse JSON file: ${error instanceof Error ? error.message : String(error)}`); + console.error( + `Failed to read/parse JSON file: ${error instanceof Error ? error.message : String(error)}`, + ); process.exit(1); } @@ -49,7 +51,9 @@ async function main() { } const def = validation.data!; - console.log(`Validated: ${def.name} (${def.slug}) with ${def.checks.length} checks`); + console.log( + `Validated: ${def.name} (${def.slug}) with ${def.checks.length} checks`, + ); // Helper to convert to Prisma-compatible JSON const toJson = (val: unknown) => JSON.parse(JSON.stringify(val)); @@ -65,7 +69,9 @@ async function main() { logoUrl: def.logoUrl, docsUrl: def.docsUrl, baseUrl: def.baseUrl, - defaultHeaders: def.defaultHeaders ? toJson(def.defaultHeaders) : undefined, + defaultHeaders: def.defaultHeaders + ? toJson(def.defaultHeaders) + : undefined, authConfig: toJson(def.authConfig), capabilities: toJson(def.capabilities), supportsMultipleConnections: def.supportsMultipleConnections ?? false, @@ -81,7 +87,9 @@ async function main() { logoUrl: def.logoUrl, docsUrl: def.docsUrl, baseUrl: def.baseUrl, - defaultHeaders: def.defaultHeaders ? toJson(def.defaultHeaders) : undefined, + defaultHeaders: def.defaultHeaders + ? toJson(def.defaultHeaders) + : undefined, authConfig: toJson(def.authConfig), capabilities: toJson(def.capabilities), supportsMultipleConnections: def.supportsMultipleConnections ?? false, @@ -149,7 +157,9 @@ async function main() { console.log(`Upserted IntegrationProvider for ${def.slug}`); console.log(`\nDone! Integration "${def.name}" is now live.`); - console.log('The registry will pick it up on the next refresh (within 60 seconds) or on API restart.'); + console.log( + 'The registry will pick it up on the next refresh (within 60 seconds) or on API restart.', + ); process.exit(0); } diff --git a/apps/api/src/secrets/encryption.util.ts b/apps/api/src/secrets/encryption.util.ts index 7a28488dd5..1a6c39c7d1 100644 --- a/apps/api/src/secrets/encryption.util.ts +++ b/apps/api/src/secrets/encryption.util.ts @@ -1,4 +1,9 @@ -import { createCipheriv, createDecipheriv, randomBytes, scryptSync } from 'node:crypto'; +import { + createCipheriv, + createDecipheriv, + randomBytes, + scryptSync, +} from 'node:crypto'; const ALGORITHM = 'aes-256-gcm'; const IV_LENGTH = 12; @@ -27,7 +32,10 @@ export function encrypt(text: string): EncryptedData { const key = deriveKey(secretKey, salt); const cipher = createCipheriv(ALGORITHM, key, iv); - const encrypted = Buffer.concat([cipher.update(text, 'utf8'), cipher.final()]); + const encrypted = Buffer.concat([ + cipher.update(text, 'utf8'), + cipher.final(), + ]); const tag = cipher.getAuthTag(); return { @@ -53,6 +61,9 @@ export function decrypt(encryptedData: EncryptedData): string { const decipher = createDecipheriv(ALGORITHM, key, iv); decipher.setAuthTag(tag); - const decrypted = Buffer.concat([decipher.update(encrypted), decipher.final()]); + const decrypted = Buffer.concat([ + decipher.update(encrypted), + decipher.final(), + ]); return decrypted.toString('utf8'); } diff --git a/apps/api/src/secrets/secrets.controller.spec.ts b/apps/api/src/secrets/secrets.controller.spec.ts index 58c0088b7d..776d5478f4 100644 --- a/apps/api/src/secrets/secrets.controller.spec.ts +++ b/apps/api/src/secrets/secrets.controller.spec.ts @@ -129,10 +129,7 @@ describe('SecretsController', () => { mockAuthContext, ); - expect(secretsService.createSecret).toHaveBeenCalledWith( - 'org_123', - body, - ); + expect(secretsService.createSecret).toHaveBeenCalledWith('org_123', body); expect(result).toEqual({ secret: created, authType: 'session', diff --git a/apps/api/src/secrets/secrets.controller.ts b/apps/api/src/secrets/secrets.controller.ts index 427ff8c15b..8eb09428e4 100644 --- a/apps/api/src/secrets/secrets.controller.ts +++ b/apps/api/src/secrets/secrets.controller.ts @@ -8,7 +8,13 @@ import { Put, UseGuards, } from '@nestjs/common'; -import { ApiBody, ApiOperation, ApiParam, ApiSecurity, ApiTags } from '@nestjs/swagger'; +import { + ApiBody, + ApiOperation, + ApiParam, + ApiSecurity, + ApiTags, +} from '@nestjs/swagger'; import { AuthContext, OrganizationId } from '../auth/auth-context.decorator'; import { HybridAuthGuard } from '../auth/hybrid-auth.guard'; import { PermissionGuard } from '../auth/permission.guard'; @@ -86,7 +92,13 @@ export class SecretsController { }, }) async createSecret( - @Body() body: { name: string; value: string; description?: string; category?: string }, + @Body() + body: { + name: string; + value: string; + description?: string; + category?: string; + }, @OrganizationId() organizationId: string, @AuthContext() authContext: AuthContextType, ) { diff --git a/apps/api/src/secrets/secrets.service.ts b/apps/api/src/secrets/secrets.service.ts index 1ffb5330e5..ea6010ea7b 100644 --- a/apps/api/src/secrets/secrets.service.ts +++ b/apps/api/src/secrets/secrets.service.ts @@ -38,9 +38,7 @@ export class SecretsService { throw new NotFoundException('Secret not found'); } - const decryptedValue = decrypt( - JSON.parse(secret.value) as EncryptedData, - ); + const decryptedValue = decrypt(JSON.parse(secret.value) as EncryptedData); return { ...secret, value: decryptedValue }; } diff --git a/apps/api/src/security-penetration-tests/dto/create-penetration-test.dto.ts b/apps/api/src/security-penetration-tests/dto/create-penetration-test.dto.ts index 37c91f8938..12f09a8a80 100644 --- a/apps/api/src/security-penetration-tests/dto/create-penetration-test.dto.ts +++ b/apps/api/src/security-penetration-tests/dto/create-penetration-test.dto.ts @@ -52,7 +52,8 @@ export class CreatePenetrationTestDto { workspace?: string; @ApiPropertyOptional({ - description: 'Optional webhook URL to notify when report generation completes', + description: + 'Optional webhook URL to notify when report generation completes', required: false, }) @IsOptional() diff --git a/apps/api/src/security-penetration-tests/maced-client.ts b/apps/api/src/security-penetration-tests/maced-client.ts index 2876a7f798..2529e8d631 100644 --- a/apps/api/src/security-penetration-tests/maced-client.ts +++ b/apps/api/src/security-penetration-tests/maced-client.ts @@ -228,7 +228,9 @@ export class MacedClient { ); } - async createPentest(payload: MacedCreatePentestPayload): Promise { + async createPentest( + payload: MacedCreatePentestPayload, + ): Promise { const validatedPayload = macedCreatePentestPayloadSchema.safeParse(payload); if (!validatedPayload.success) { this.logger.error( diff --git a/apps/api/src/security-penetration-tests/pentest-billing.controller.ts b/apps/api/src/security-penetration-tests/pentest-billing.controller.ts index 8f7500df33..ff8868f4c9 100644 --- a/apps/api/src/security-penetration-tests/pentest-billing.controller.ts +++ b/apps/api/src/security-penetration-tests/pentest-billing.controller.ts @@ -54,7 +54,9 @@ export class PentestBillingController { @Post('subscribe') @RequirePermission('pentest', 'create') @HttpCode(200) - @ApiOperation({ summary: 'Create a Stripe checkout session for pentest subscription' }) + @ApiOperation({ + summary: 'Create a Stripe checkout session for pentest subscription', + }) @ApiResponse({ status: 200, description: 'Checkout URL returned' }) async subscribe( @OrganizationId() organizationId: string, @@ -76,7 +78,10 @@ export class PentestBillingController { @OrganizationId() organizationId: string, @Body() body: HandleSuccessDto, ) { - await this.billingService.handleSubscriptionSuccess(organizationId, body.sessionId); + await this.billingService.handleSubscriptionSuccess( + organizationId, + body.sessionId, + ); return { success: true }; } @@ -89,7 +94,10 @@ export class PentestBillingController { @OrganizationId() organizationId: string, @Body() body: PortalDto, ) { - return this.billingService.createBillingPortalSession(organizationId, body.returnUrl); + return this.billingService.createBillingPortalSession( + organizationId, + body.returnUrl, + ); } @Post('charge') @@ -101,6 +109,9 @@ export class PentestBillingController { @OrganizationId() organizationId: string, @Body() body: ChargeDto, ) { - return this.billingService.checkAndChargeBilling(organizationId, body.runId); + return this.billingService.checkAndChargeBilling( + organizationId, + body.runId, + ); } } diff --git a/apps/api/src/security-penetration-tests/pentest-billing.service.ts b/apps/api/src/security-penetration-tests/pentest-billing.service.ts index b10073b209..54b9810f00 100644 --- a/apps/api/src/security-penetration-tests/pentest-billing.service.ts +++ b/apps/api/src/security-penetration-tests/pentest-billing.service.ts @@ -35,7 +35,9 @@ export class PentestBillingService { const allowedOrigin = new URL(appUrl).origin; if (parsed.origin !== allowedOrigin) { - throw new BadRequestException('Redirect URL must belong to the application origin.'); + throw new BadRequestException( + 'Redirect URL must belong to the application origin.', + ); } } @@ -51,7 +53,9 @@ export class PentestBillingService { const priceId = process.env.STRIPE_PENTEST_SUBSCRIPTION_PRICE_ID; if (!priceId) { - throw new BadRequestException('STRIPE_PENTEST_SUBSCRIPTION_PRICE_ID is not configured.'); + throw new BadRequestException( + 'STRIPE_PENTEST_SUBSCRIPTION_PRICE_ID is not configured.', + ); } const org = await db.organization.findUnique({ @@ -89,7 +93,9 @@ export class PentestBillingService { }); if (!session.url) { - throw new BadRequestException('Failed to create Stripe Checkout session URL.'); + throw new BadRequestException( + 'Failed to create Stripe Checkout session URL.', + ); } return { url: session.url }; @@ -113,7 +119,7 @@ export class PentestBillingService { const stripeCustomerId = typeof session.customer === 'string' ? session.customer - : session.customer?.id ?? ''; + : (session.customer?.id ?? ''); const existingBilling = await db.organizationBilling.findUnique({ where: { organizationId }, @@ -121,12 +127,19 @@ export class PentestBillingService { if (existingBilling) { if (existingBilling.stripeCustomerId !== stripeCustomerId) { - throw new ForbiddenException('Checkout session does not belong to this organization.'); + throw new ForbiddenException( + 'Checkout session does not belong to this organization.', + ); } } else { const customer = await stripe.customers.retrieve(stripeCustomerId); - if (customer.deleted || customer.metadata?.organizationId !== organizationId) { - throw new ForbiddenException('Checkout session does not belong to this organization.'); + if ( + customer.deleted || + customer.metadata?.organizationId !== organizationId + ) { + throw new ForbiddenException( + 'Checkout session does not belong to this organization.', + ); } } @@ -144,7 +157,8 @@ export class PentestBillingService { organizationBillingId: billing.id, stripeSubscriptionId: subscription.id, stripePriceId: item?.price.id ?? '', - status: subscription.status === 'active' ? 'active' : subscription.status, + status: + subscription.status === 'active' ? 'active' : subscription.status, currentPeriodStart: new Date((item?.current_period_start ?? 0) * 1000), currentPeriodEnd: new Date((item?.current_period_end ?? 0) * 1000), }, @@ -152,7 +166,8 @@ export class PentestBillingService { organizationBillingId: billing.id, stripeSubscriptionId: subscription.id, stripePriceId: item?.price.id ?? '', - status: subscription.status === 'active' ? 'active' : subscription.status, + status: + subscription.status === 'active' ? 'active' : subscription.status, currentPeriodStart: new Date((item?.current_period_start ?? 0) * 1000), currentPeriodEnd: new Date((item?.current_period_end ?? 0) * 1000), }, @@ -172,7 +187,9 @@ export class PentestBillingService { }); if (!billing) { - throw new NotFoundException('No billing record found for this organization.'); + throw new NotFoundException( + 'No billing record found for this organization.', + ); } const portalSession = await stripe.billingPortal.sessions.create({ @@ -209,7 +226,10 @@ export class PentestBillingService { } if (subscription.status !== 'active') { - throw new HttpException('Pentest subscription is not active.', HttpStatus.PAYMENT_REQUIRED); + throw new HttpException( + 'Pentest subscription is not active.', + HttpStatus.PAYMENT_REQUIRED, + ); } const runsThisPeriod = await db.securityPenetrationTestRun.count({ @@ -230,7 +250,9 @@ export class PentestBillingService { const overagePriceId = process.env.STRIPE_PENTEST_OVERAGE_PRICE_ID; if (!overagePriceId) { - throw new BadRequestException('STRIPE_PENTEST_OVERAGE_PRICE_ID is not configured.'); + throw new BadRequestException( + 'STRIPE_PENTEST_OVERAGE_PRICE_ID is not configured.', + ); } const price = await stripe.prices.retrieve(overagePriceId); @@ -247,7 +269,8 @@ export class PentestBillingService { throw new BadRequestException('Stripe customer not found.'); } - const defaultPaymentMethod = customer.invoice_settings?.default_payment_method; + const defaultPaymentMethod = + customer.invoice_settings?.default_payment_method; if (!defaultPaymentMethod) { throw new HttpException( 'No payment method on file. Update billing at /settings/billing.', @@ -278,7 +301,10 @@ export class PentestBillingService { ); if (paymentIntent.status !== 'succeeded') { - throw new HttpException('Overage payment failed. Check billing.', HttpStatus.PAYMENT_REQUIRED); + throw new HttpException( + 'Overage payment failed. Check billing.', + HttpStatus.PAYMENT_REQUIRED, + ); } return { charged: true }; diff --git a/apps/api/src/security-penetration-tests/security-penetration-tests.controller.spec.ts b/apps/api/src/security-penetration-tests/security-penetration-tests.controller.spec.ts index c1a7d90870..4560107fe4 100644 --- a/apps/api/src/security-penetration-tests/security-penetration-tests.controller.spec.ts +++ b/apps/api/src/security-penetration-tests/security-penetration-tests.controller.spec.ts @@ -11,7 +11,8 @@ import type { SecurityPenetrationTestsService } from './security-penetration-tes import type { Request as ExpressRequest } from 'express'; describe('SecurityPenetrationTestsController', () => { - const originalWebhookBase = process.env.SECURITY_PENETRATION_TESTS_WEBHOOK_URL; + const originalWebhookBase = + process.env.SECURITY_PENETRATION_TESTS_WEBHOOK_URL; const createReportMock = jest.fn(); const listReportsMock = jest.fn(); const getReportMock = jest.fn(); @@ -34,7 +35,8 @@ describe('SecurityPenetrationTestsController', () => { beforeEach(() => { jest.clearAllMocks(); - process.env.SECURITY_PENETRATION_TESTS_WEBHOOK_URL = 'https://callback.example.com/webhook'; + process.env.SECURITY_PENETRATION_TESTS_WEBHOOK_URL = + 'https://callback.example.com/webhook'; }); afterAll(() => { @@ -110,7 +112,11 @@ describe('SecurityPenetrationTestsController', () => { }); const responseMock = { set: jest.fn() }; - const output = await controller.getReport('org_123', 'run_1', responseMock as never); + const output = await controller.getReport( + 'org_123', + 'run_1', + responseMock as never, + ); expect(getReportOutputMock).toHaveBeenCalledWith('org_123', 'run_1'); expect(responseMock.set).toHaveBeenCalledWith({ @@ -127,12 +133,17 @@ describe('SecurityPenetrationTestsController', () => { }); const responseMock = { set: jest.fn() }; - const output = await controller.getPdf('org_123', 'run_1', responseMock as never); + const output = await controller.getPdf( + 'org_123', + 'run_1', + responseMock as never, + ); expect(getReportPdfMock).toHaveBeenCalledWith('org_123', 'run_1'); expect(responseMock.set).toHaveBeenCalledWith({ 'Content-Type': 'application/pdf', - 'Content-Disposition': 'attachment; filename="penetration-test-run_1.pdf"', + 'Content-Disposition': + 'attachment; filename="penetration-test-run_1.pdf"', 'Cache-Control': 'no-store', }); expect(output).toBeDefined(); @@ -153,13 +164,10 @@ describe('SecurityPenetrationTestsController', () => { await controller.handleWebhook(requestMock, webhookPayload); - expect(handleWebhookMock).toHaveBeenCalledWith( - webhookPayload, - { - webhookToken: undefined, - eventId: undefined, - }, - ); + expect(handleWebhookMock).toHaveBeenCalledWith(webhookPayload, { + webhookToken: undefined, + eventId: undefined, + }); }); it('accepts first query value from array form in webhook extraction', async () => { @@ -177,13 +185,10 @@ describe('SecurityPenetrationTestsController', () => { await controller.handleWebhook(requestMock, webhookPayload); - expect(handleWebhookMock).toHaveBeenCalledWith( - webhookPayload, - { - webhookToken: 'query-token', - eventId: undefined, - }, - ); + expect(handleWebhookMock).toHaveBeenCalledWith(webhookPayload, { + webhookToken: 'query-token', + eventId: undefined, + }); }); it('passes webhook token and event id metadata into service handler', async () => { @@ -201,12 +206,9 @@ describe('SecurityPenetrationTestsController', () => { await controller.handleWebhook(requestMock, webhookPayload); - expect(handleWebhookMock).toHaveBeenCalledWith( - webhookPayload, - { - webhookToken: 'query-token', - eventId: 'evt_123', - }, - ); + expect(handleWebhookMock).toHaveBeenCalledWith(webhookPayload, { + webhookToken: 'query-token', + eventId: 'evt_123', + }); }); }); diff --git a/apps/api/src/security-penetration-tests/security-penetration-tests.controller.ts b/apps/api/src/security-penetration-tests/security-penetration-tests.controller.ts index d65251240b..714709ee4a 100644 --- a/apps/api/src/security-penetration-tests/security-penetration-tests.controller.ts +++ b/apps/api/src/security-penetration-tests/security-penetration-tests.controller.ts @@ -33,21 +33,19 @@ import { SecurityPenetrationTestsService } from './security-penetration-tests.se @ApiSecurity('apikey') @ApiHeader({ name: 'X-Organization-Id', - description: 'Organization ID (required for session auth, optional for API key auth)', + description: + 'Organization ID (required for session auth, optional for API key auth)', required: false, }) @UseGuards(HybridAuthGuard, PermissionGuard) export class SecurityPenetrationTestsController { - constructor( - private readonly service: SecurityPenetrationTestsService, - ) {} + constructor(private readonly service: SecurityPenetrationTestsService) {} @Get() @RequirePermission('pentest', 'read') @ApiOperation({ summary: 'List penetration test runs', - description: - 'Returns all penetration tests created for the organization.', + description: 'Returns all penetration tests created for the organization.', }) @ApiResponse({ status: 200, @@ -126,7 +124,10 @@ export class SecurityPenetrationTestsController { status: 200, description: 'Progress returned', }) - async getProgress(@OrganizationId() organizationId: string, @Param('id') id: string) { + async getProgress( + @OrganizationId() organizationId: string, + @Param('id') id: string, + ) { return this.service.getReportProgress(organizationId, id); } @@ -234,7 +235,10 @@ export class SecurityPenetrationTestsController { }); } - private extractStringFromQuery(request: Request, key: string): string | undefined { + private extractStringFromQuery( + request: Request, + key: string, + ): string | undefined { const queryValue = request.query[key]; if (Array.isArray(queryValue)) { return this.extractStringFromQueryValue(queryValue[0]); diff --git a/apps/api/src/security-penetration-tests/security-penetration-tests.service.spec.ts b/apps/api/src/security-penetration-tests/security-penetration-tests.service.spec.ts index 623f8ff0a5..d9bf2aba90 100644 --- a/apps/api/src/security-penetration-tests/security-penetration-tests.service.spec.ts +++ b/apps/api/src/security-penetration-tests/security-penetration-tests.service.spec.ts @@ -5,7 +5,9 @@ import type { CredentialVaultService } from '../integration-platform/services/cr import type { CreatePenetrationTestDto } from './dto/create-penetration-test.dto'; import { SecurityPenetrationTestsService } from './security-penetration-tests.service'; -const mockCredentialVaultService: jest.Mocked> = { +const mockCredentialVaultService: jest.Mocked< + Pick +> = { getDecryptedCredentials: jest.fn(), }; @@ -52,12 +54,13 @@ type MockDb = { describe('SecurityPenetrationTestsService', () => { const originalFetch = global.fetch; const originalMacedApiKey = process.env.MACED_API_KEY; - const originalWebhookBase = process.env.SECURITY_PENETRATION_TESTS_WEBHOOK_URL; + const originalWebhookBase = + process.env.SECURITY_PENETRATION_TESTS_WEBHOOK_URL; const defaultWebhookToken = 'test-webhook-token'; const defaultWebhookTokenHash = createHash('sha256') .update(defaultWebhookToken) .digest('hex'); - const fetchMock = jest.fn() as jest.Mock; + const fetchMock = jest.fn(); const mockedDb = db as unknown as MockDb; let service: SecurityPenetrationTestsService; @@ -130,7 +133,8 @@ describe('SecurityPenetrationTestsService', () => { }); it('creates report payload with resolved webhook URL', async () => { - process.env.SECURITY_PENETRATION_TESTS_WEBHOOK_URL = 'https://api.trycomp.ai/webhook'; + process.env.SECURITY_PENETRATION_TESTS_WEBHOOK_URL = + 'https://api.trycomp.ai/webhook'; const expectedPayload = { id: 'run_456', status: 'provisioning', @@ -150,7 +154,10 @@ describe('SecurityPenetrationTestsService', () => { await service.createReport('org_123', payload); const [, options] = fetchMock.mock.calls[0]; - const requestBody = JSON.parse(options.body as string) as Record; + const requestBody = JSON.parse(options.body as string) as Record< + string, + unknown + >; expect(requestBody.webhookUrl).toBe( 'https://api.trycomp.ai/webhook/v1/security-penetration-tests/webhook', @@ -158,7 +165,10 @@ describe('SecurityPenetrationTestsService', () => { expect(requestBody.targetUrl).toBe(payload.targetUrl); expect(requestBody.repoUrl).toBe(payload.repoUrl); expect(requestBody.testMode).toBe(true); - expect(requestBody).not.toHaveProperty('webhookUrl', 'https://report-callback.example.com/webhook'); + expect(requestBody).not.toHaveProperty( + 'webhookUrl', + 'https://report-callback.example.com/webhook', + ); expect(mockedDb.secret.upsert).toHaveBeenCalledTimes(1); expect(mockedDb.securityPenetrationTestRun.upsert).toHaveBeenCalledTimes(1); }); @@ -183,7 +193,10 @@ describe('SecurityPenetrationTestsService', () => { }); const [, options] = fetchMock.mock.calls[0]; - const requestBody = JSON.parse(options.body as string) as Record; + const requestBody = JSON.parse(options.body as string) as Record< + string, + unknown + >; expect(requestBody.webhookUrl).toBe( 'https://api.trycomp.ai/v1/security-penetration-tests/webhook', @@ -354,17 +367,13 @@ describe('SecurityPenetrationTestsService', () => { }); it('returns empty list for empty payload', async () => { - fetchMock.mockResolvedValueOnce( - new Response('', { status: 200 }), - ); + fetchMock.mockResolvedValueOnce(new Response('', { status: 200 })); await expect(service.listReports('org_123')).resolves.toEqual([]); }); it('maps invalid list payload to bad gateway', async () => { - fetchMock.mockResolvedValueOnce( - new Response('not-json', { status: 200 }), - ); + fetchMock.mockResolvedValueOnce(new Response('not-json', { status: 200 })); await expect(service.listReports('org_123')).rejects.toEqual( expect.objectContaining({ @@ -382,7 +391,8 @@ describe('SecurityPenetrationTestsService', () => { status: 'provisioning', webhookToken: 'provider-token', }; - const webhookUrl = 'https://app.company.test/security-penetration-tests/webhook'; + const webhookUrl = + 'https://app.company.test/security-penetration-tests/webhook'; process.env.SECURITY_PENETRATION_TESTS_WEBHOOK_URL = webhookUrl; @@ -399,9 +409,14 @@ describe('SecurityPenetrationTestsService', () => { await service.createReport('org_123', payload); const [, options] = fetchMock.mock.calls[0]; - const requestBody = JSON.parse(options.body as string) as Record; + const requestBody = JSON.parse(options.body as string) as Record< + string, + unknown + >; - expect(requestBody.webhookUrl).toBe('https://app.company.test/v1/security-penetration-tests/webhook'); + expect(requestBody.webhookUrl).toBe( + 'https://app.company.test/v1/security-penetration-tests/webhook', + ); }); it('allows third-party webhook URLs without requiring provider webhook token', async () => { @@ -428,7 +443,10 @@ describe('SecurityPenetrationTestsService', () => { ); const [, options] = fetchMock.mock.calls[0]; - const requestBody = JSON.parse(options.body as string) as Record; + const requestBody = JSON.parse(options.body as string) as Record< + string, + unknown + >; expect(requestBody.webhookUrl).toBe( 'https://external-webhook.example.com/callback/v1/security-penetration-tests/webhook', @@ -457,7 +475,10 @@ describe('SecurityPenetrationTestsService', () => { }); const [, options] = fetchMock.mock.calls[0]; - const requestBody = JSON.parse(options.body as string) as Record; + const requestBody = JSON.parse(options.body as string) as Record< + string, + unknown + >; expect(requestBody.webhookUrl).toBe( 'https://app.company.test/v1/security-penetration-tests/webhook?foo=bar', @@ -470,7 +491,8 @@ describe('SecurityPenetrationTestsService', () => { status: 'provisioning', }; - const webhookUrl = 'https://app.company.test/v1/security-penetration-tests/webhook?foo=bar'; + const webhookUrl = + 'https://app.company.test/v1/security-penetration-tests/webhook?foo=bar'; fetchMock.mockResolvedValueOnce( new Response(JSON.stringify(expectedPayload), { status: 200 }), @@ -483,9 +505,14 @@ describe('SecurityPenetrationTestsService', () => { }); const [, options] = fetchMock.mock.calls[0]; - const requestBody = JSON.parse(options.body as string) as Record; + const requestBody = JSON.parse(options.body as string) as Record< + string, + unknown + >; - expect(requestBody.webhookUrl).toBe('https://app.company.test/v1/security-penetration-tests/webhook?foo=bar'); + expect(requestBody.webhookUrl).toBe( + 'https://app.company.test/v1/security-penetration-tests/webhook?foo=bar', + ); }); it('supports absolute webhook URLs that require appending the expected endpoint', async () => { @@ -505,9 +532,14 @@ describe('SecurityPenetrationTestsService', () => { }); const [, options] = fetchMock.mock.calls[0]; - const requestBody = JSON.parse(options.body as string) as Record; + const requestBody = JSON.parse(options.body as string) as Record< + string, + unknown + >; - expect(requestBody.webhookUrl).toBe('https://callback.example.com/hook/v1/security-penetration-tests/webhook'); + expect(requestBody.webhookUrl).toBe( + 'https://callback.example.com/hook/v1/security-penetration-tests/webhook', + ); }); it('strips webhookToken query parameter before forwarding webhook URL to provider', async () => { @@ -530,7 +562,10 @@ describe('SecurityPenetrationTestsService', () => { }); const [, options] = fetchMock.mock.calls[0]; - const requestBody = JSON.parse(options.body as string) as Record; + const requestBody = JSON.parse(options.body as string) as Record< + string, + unknown + >; expect(requestBody.webhookUrl).toBe( 'https://callback.example.com/hook/v1/security-penetration-tests/webhook?foo=bar', @@ -805,7 +840,10 @@ describe('SecurityPenetrationTestsService', () => { }), ); - const output = await service.getReportOutput('org_123', 'run_output_no_type'); + const output = await service.getReportOutput( + 'org_123', + 'run_output_no_type', + ); expect(output.contentType).toBe('text/markdown; charset=utf-8'); expect(output.contentDisposition).toBeNull(); @@ -849,9 +887,7 @@ describe('SecurityPenetrationTestsService', () => { }); it('maps empty get report response to bad gateway', async () => { - fetchMock.mockResolvedValueOnce( - new Response('', { status: 200 }), - ); + fetchMock.mockResolvedValueOnce(new Response('', { status: 200 })); await expect(service.getReport('org_123', 'run_123')).rejects.toEqual( expect.objectContaining({ @@ -864,11 +900,11 @@ describe('SecurityPenetrationTestsService', () => { }); it('maps empty get progress response to bad gateway', async () => { - fetchMock.mockResolvedValueOnce( - new Response('', { status: 200 }), - ); + fetchMock.mockResolvedValueOnce(new Response('', { status: 200 })); - await expect(service.getReportProgress('org_123', 'run_123')).rejects.toEqual( + await expect( + service.getReportProgress('org_123', 'run_123'), + ).rejects.toEqual( expect.objectContaining({ status: HttpStatus.BAD_GATEWAY, response: { @@ -879,11 +915,11 @@ describe('SecurityPenetrationTestsService', () => { }); it('maps invalid report progress payload to bad gateway', async () => { - fetchMock.mockResolvedValueOnce( - new Response('nope', { status: 200 }), - ); + fetchMock.mockResolvedValueOnce(new Response('nope', { status: 200 })); - await expect(service.getReportProgress('org_123', 'run_123')).rejects.toEqual( + await expect( + service.getReportProgress('org_123', 'run_123'), + ).rejects.toEqual( expect.objectContaining({ status: HttpStatus.BAD_GATEWAY, response: { @@ -930,7 +966,9 @@ describe('SecurityPenetrationTestsService', () => { ); expect(output.buffer).toEqual(Buffer.from(fixtureBuffer)); expect(output.contentType).toBe('application/pdf'); - expect(output.contentDisposition).toBe('attachment; filename="penetration-test-run_pdf.pdf"'); + expect(output.contentDisposition).toBe( + 'attachment; filename="penetration-test-run_pdf.pdf"', + ); }); it('throws a mapped HttpException for failed provider calls', async () => { diff --git a/apps/api/src/security-penetration-tests/security-penetration-tests.service.ts b/apps/api/src/security-penetration-tests/security-penetration-tests.service.ts index 1cae0b2230..8e93b05e5e 100644 --- a/apps/api/src/security-penetration-tests/security-penetration-tests.service.ts +++ b/apps/api/src/security-penetration-tests/security-penetration-tests.service.ts @@ -94,9 +94,12 @@ export class SecurityPenetrationTestsService { private readonly logger = new Logger(SecurityPenetrationTestsService.name); private readonly macedClient = new MacedClient(); - constructor(private readonly credentialVaultService: CredentialVaultService) {} + constructor( + private readonly credentialVaultService: CredentialVaultService, + ) {} - private readonly canonicalWebhookPath = '/v1/security-penetration-tests/webhook'; + private readonly canonicalWebhookPath = + '/v1/security-penetration-tests/webhook'; private readonly defaultWebhookBaseUrl = 'https://api.trycomp.ai'; private readonly defaultCompWebhookHosts = new Set([ 'api.trycomp.ai', @@ -111,7 +114,9 @@ export class SecurityPenetrationTestsService { ); } - async listReports(organizationId: string): Promise { + async listReports( + organizationId: string, + ): Promise { const ownedRunIds = await this.listOwnedRunIds(organizationId); if (ownedRunIds.size === 0) { return []; @@ -119,9 +124,11 @@ export class SecurityPenetrationTestsService { const reports = await this.macedClient.listPentests(); - return reports.filter((report) => { - return ownedRunIds.has(report.id); - }).map((report) => this.mapMacedRunToSecurityPenetrationTest(report)); + return reports + .filter((report) => { + return ownedRunIds.has(report.id); + }) + .map((report) => this.mapMacedRunToSecurityPenetrationTest(report)); } async createReport( @@ -158,7 +165,8 @@ export class SecurityPenetrationTestsService { (await this.getGithubTokenForOrg(organizationId)) ?? undefined; } - const createdReport = await this.macedClient.createPentest(sanitizedPayload); + const createdReport = + await this.macedClient.createPentest(sanitizedPayload); const providerRunId = createdReport.id; @@ -267,7 +275,10 @@ export class SecurityPenetrationTestsService { } } - async getReport(organizationId: string, id: string): Promise { + async getReport( + organizationId: string, + id: string, + ): Promise { await this.assertRunOwnership(organizationId, id); const report = await this.macedClient.getPentest(id); return this.mapMacedRunToSecurityPenetrationTest(report); @@ -281,19 +292,26 @@ export class SecurityPenetrationTestsService { return this.macedClient.getPentestProgress(id); } - async getReportOutput(organizationId: string, id: string): Promise { + async getReportOutput( + organizationId: string, + id: string, + ): Promise { await this.getReport(organizationId, id); const response = await this.macedClient.getPentestReportRaw(id); return { buffer: Buffer.from(await response.arrayBuffer()), - contentType: response.headers.get('Content-Type') || 'text/markdown; charset=utf-8', + contentType: + response.headers.get('Content-Type') || 'text/markdown; charset=utf-8', contentDisposition: response.headers.get('Content-Disposition'), }; } - async getReportPdf(organizationId: string, id: string): Promise { + async getReportPdf( + organizationId: string, + id: string, + ): Promise { await this.getReport(organizationId, id); const response = await this.macedClient.getPentestReportPdf(id); @@ -344,7 +362,8 @@ export class SecurityPenetrationTestsService { throw new BadRequestException('Webhook payload must include a report id'); } - const organizationId = await this.resolveOrganizationForRun(payloadReportId); + const organizationId = + await this.resolveOrganizationForRun(payloadReportId); const duplicate = await this.verifyAndRecordWebhookHandshake({ organizationId, @@ -360,7 +379,11 @@ export class SecurityPenetrationTestsService { (completedEvent ? 'completed' : undefined) || (failedEvent ? 'failed' : undefined); - const eventType: WebhookEventType = completedEvent ? 'completed' : failedEvent ? 'failed' : 'status'; + const eventType: WebhookEventType = completedEvent + ? 'completed' + : failedEvent + ? 'failed' + : 'status'; this.logger.log( `[Webhook] Received penetration test ${eventType} event for org=${organizationId}${payloadReportId ? ` run=${payloadReportId}` : ''} status=${payloadStatus ?? 'unknown'}`, @@ -394,7 +417,9 @@ export class SecurityPenetrationTestsService { }; } - private async getGithubTokenForOrg(organizationId: string): Promise { + private async getGithubTokenForOrg( + organizationId: string, + ): Promise { try { const provider = await db.integrationProvider.findUnique({ where: { slug: 'github' }, @@ -418,9 +443,10 @@ export class SecurityPenetrationTestsService { return null; } - const credentials = await this.credentialVaultService.getDecryptedCredentials( - connection.id, - ); + const credentials = + await this.credentialVaultService.getDecryptedCredentials( + connection.id, + ); const token = credentials?.access_token; return typeof token === 'string' && token.length > 0 ? token : null; @@ -470,7 +496,10 @@ export class SecurityPenetrationTestsService { for (const suffix of legacySuffixes) { if (normalizedPath.endsWith(suffix)) { - const basePath = normalizedPath.slice(0, normalizedPath.length - suffix.length); + const basePath = normalizedPath.slice( + 0, + normalizedPath.length - suffix.length, + ); return basePath ? `${basePath}${this.canonicalWebhookPath}` : this.canonicalWebhookPath; @@ -488,9 +517,7 @@ export class SecurityPenetrationTestsService { return path.endsWith(this.canonicalWebhookPath); } - private resolveWebhookUrl( - providedUrl?: string, - ): string | undefined { + private resolveWebhookUrl(providedUrl?: string): string | undefined { const baseUrl = providedUrl?.trim() || this.defaultWebhookBase; if (!baseUrl) { return undefined; @@ -517,7 +544,9 @@ export class SecurityPenetrationTestsService { } const value = payload[key]; - return typeof value === 'string' && value.trim().length > 0 ? value.trim() : undefined; + return typeof value === 'string' && value.trim().length > 0 + ? value.trim() + : undefined; } private extractNumberField( @@ -529,7 +558,9 @@ export class SecurityPenetrationTestsService { } const value = payload[key]; - return typeof value === 'number' && Number.isFinite(value) ? value : undefined; + return typeof value === 'number' && Number.isFinite(value) + ? value + : undefined; } private extractCompletedWebhookPayload( @@ -547,7 +578,7 @@ export class SecurityPenetrationTestsService { return null; } - const reportRecord = reportValue as Record; + const reportRecord = reportValue; const markdown = this.extractStringField(reportRecord, 'markdown'); const costUsd = this.extractNumberField(reportRecord, 'costUsd'); const durationMs = this.extractNumberField(reportRecord, 'durationMs'); @@ -658,10 +689,7 @@ export class SecurityPenetrationTestsService { ): Promise { const ownerOrganizationId = await this.resolveOrganizationForRun( reportId, - new HttpException( - { error: 'Report not found' }, - HttpStatus.NOT_FOUND, - ), + new HttpException({ error: 'Report not found' }, HttpStatus.NOT_FOUND), ); if (ownerOrganizationId !== organizationId) { @@ -674,7 +702,9 @@ export class SecurityPenetrationTestsService { private async resolveOrganizationForRun( reportId: string, - notFoundError: Error = new ForbiddenException('Run ownership mapping not found'), + notFoundError: Error = new ForbiddenException( + 'Run ownership mapping not found', + ), ): Promise { const marker = await db.securityPenetrationTestRun.findUnique({ where: { @@ -693,14 +723,15 @@ export class SecurityPenetrationTestsService { } private async listOwnedRunIds(organizationId: string): Promise> { - const markers = (await db.securityPenetrationTestRun.findMany({ - where: { - organizationId, - }, - select: { - providerRunId: true, - }, - })) ?? []; + const markers = + (await db.securityPenetrationTestRun.findMany({ + where: { + organizationId, + }, + select: { + providerRunId: true, + }, + })) ?? []; return new Set(markers.map(({ providerRunId }) => providerRunId)); } @@ -737,7 +768,9 @@ export class SecurityPenetrationTestsService { try { hosts.add(new URL(candidate).host.toLowerCase()); } catch { - this.logger.warn(`Ignoring invalid trusted webhook host URL: ${candidate}`); + this.logger.warn( + `Ignoring invalid trusted webhook host URL: ${candidate}`, + ); } } @@ -822,7 +855,11 @@ export class SecurityPenetrationTestsService { ): Promise { for (let attempt = 1; attempt <= 3; attempt += 1) { try { - await this.persistWebhookHandshake(organizationId, reportId, webhookToken); + await this.persistWebhookHandshake( + organizationId, + reportId, + webhookToken, + ); return true; } catch (error) { this.logger.error( diff --git a/apps/api/src/soa/soa.service.spec.ts b/apps/api/src/soa/soa.service.spec.ts index f3f3220fc8..8c87ad045f 100644 --- a/apps/api/src/soa/soa.service.spec.ts +++ b/apps/api/src/soa/soa.service.spec.ts @@ -67,13 +67,13 @@ describe('SOAService', () => { it('throws NotFoundException when framework not found', async () => { mockDb.frameworkEditorFramework.findUnique.mockResolvedValue(null); - await expect(service.ensureSetup(dto)).rejects.toThrow( - NotFoundException, - ); + await expect(service.ensureSetup(dto)).rejects.toThrow(NotFoundException); }); it('returns success:false for non-ISO 27001 framework', async () => { - (mockDb.frameworkEditorFramework.findUnique as jest.Mock).mockResolvedValue({ + ( + mockDb.frameworkEditorFramework.findUnique as jest.Mock + ).mockResolvedValue({ id: 'fw-1', name: 'SOC 2', }); @@ -83,14 +83,18 @@ describe('SOAService', () => { }); it('throws InternalServerErrorException when config creation fails', async () => { - (mockDb.frameworkEditorFramework.findUnique as jest.Mock).mockResolvedValue({ + ( + mockDb.frameworkEditorFramework.findUnique as jest.Mock + ).mockResolvedValue({ id: 'fw-1', name: 'ISO 27001', }); - (mockDb.sOAFrameworkConfiguration.findFirst as jest.Mock).mockResolvedValue(null); - (mockDb.frameworkEditorFramework.findFirst as jest.Mock).mockRejectedValue( - new Error('DB error'), - ); + ( + mockDb.sOAFrameworkConfiguration.findFirst as jest.Mock + ).mockResolvedValue(null); + ( + mockDb.frameworkEditorFramework.findFirst as jest.Mock + ).mockRejectedValue(new Error('DB error')); await expect(service.ensureSetup(dto)).rejects.toThrow( InternalServerErrorException, ); @@ -98,15 +102,19 @@ describe('SOAService', () => { it('throws InternalServerErrorException when document creation fails', async () => { const config = { id: 'cfg-1', questions: [] }; - (mockDb.frameworkEditorFramework.findUnique as jest.Mock).mockResolvedValue({ + ( + mockDb.frameworkEditorFramework.findUnique as jest.Mock + ).mockResolvedValue({ id: 'fw-1', name: 'ISO 27001', }); - (mockDb.sOAFrameworkConfiguration.findFirst as jest.Mock).mockResolvedValue(config); + ( + mockDb.sOAFrameworkConfiguration.findFirst as jest.Mock + ).mockResolvedValue(config); (mockDb.sOADocument.findFirst as jest.Mock).mockResolvedValue(null); - (mockDb.sOAFrameworkConfiguration.findUnique as jest.Mock).mockRejectedValue( - new Error('DB error'), - ); + ( + mockDb.sOAFrameworkConfiguration.findUnique as jest.Mock + ).mockRejectedValue(new Error('DB error')); await expect(service.ensureSetup(dto)).rejects.toThrow( InternalServerErrorException, ); @@ -115,11 +123,15 @@ describe('SOAService', () => { it('returns existing document when setup already complete', async () => { const config = { id: 'cfg-1', questions: [{ id: 'q1' }] }; const doc = { id: 'doc-1', answers: [] }; - (mockDb.frameworkEditorFramework.findUnique as jest.Mock).mockResolvedValue({ + ( + mockDb.frameworkEditorFramework.findUnique as jest.Mock + ).mockResolvedValue({ id: 'fw-1', name: 'ISO 27001', }); - (mockDb.sOAFrameworkConfiguration.findFirst as jest.Mock).mockResolvedValue(config); + ( + mockDb.sOAFrameworkConfiguration.findFirst as jest.Mock + ).mockResolvedValue(config); (mockDb.sOADocument.findFirst as jest.Mock).mockResolvedValue(doc); const result = await service.ensureSetup(dto); expect(result.success).toBe(true); @@ -230,7 +242,9 @@ describe('SOAService', () => { userId: 'user-1', role: 'admin', }); - (mockDb.user.findUnique as jest.Mock).mockResolvedValue({ role: 'admin' }); + (mockDb.user.findUnique as jest.Mock).mockResolvedValue({ + role: 'admin', + }); await expect(service.submitForApproval(dto)).rejects.toThrow( BadRequestException, ); diff --git a/apps/api/src/soa/soa.service.ts b/apps/api/src/soa/soa.service.ts index 8f4cef7c5e..aa085444ac 100644 --- a/apps/api/src/soa/soa.service.ts +++ b/apps/api/src/soa/soa.service.ts @@ -187,7 +187,9 @@ export class SOAService { }); if (!configuration) { - throw new NotFoundException('No SOA configuration found for this framework'); + throw new NotFoundException( + 'No SOA configuration found for this framework', + ); } const existingLatestDocument = await db.sOADocument.findFirst({ @@ -397,7 +399,9 @@ export class SOAService { select: { role: true }, }); if (approverUser?.role === 'admin') { - throw new BadRequestException('Cannot assign a platform admin as approver'); + throw new BadRequestException( + 'Cannot assign a platform admin as approver', + ); } const isOwnerOrAdmin = @@ -525,7 +529,9 @@ export class SOAService { member.role.includes('owner') || member.role.includes('admin'); if (!isOwnerOrAdmin) { - throw new ForbiddenException('Only owners and admins can perform this action'); + throw new ForbiddenException( + 'Only owners and admins can perform this action', + ); } return member; diff --git a/apps/api/src/task-management/task-item-assignment-notifier.service.ts b/apps/api/src/task-management/task-item-assignment-notifier.service.ts index 4ee32165ad..8abc54e0b4 100644 --- a/apps/api/src/task-management/task-item-assignment-notifier.service.ts +++ b/apps/api/src/task-management/task-item-assignment-notifier.service.ts @@ -93,7 +93,10 @@ export class TaskItemAssignmentNotifierService { } // Skip notifications for platform admin members unless they are an owner - const isOwner = assigneeMember.role?.split(',').map((r: string) => r.trim()).includes('owner'); + const isOwner = assigneeMember.role + ?.split(',') + .map((r: string) => r.trim()) + .includes('owner'); if (assigneeUser.role === 'admin' && !isOwner) { this.logger.log( `Skipping assignment notification: assignee ${assigneeUser.email} is a platform admin (non-owner)`, diff --git a/apps/api/src/task-management/task-management.service.ts b/apps/api/src/task-management/task-management.service.ts index 1c992ee315..6e67a899c0 100644 --- a/apps/api/src/task-management/task-management.service.ts +++ b/apps/api/src/task-management/task-management.service.ts @@ -271,7 +271,9 @@ export class TaskManagementService { include: { user: { select: { role: true } } }, }); if (assigneeMember?.user.role === 'admin') { - throw new BadRequestException('Cannot assign a platform admin as assignee'); + throw new BadRequestException( + 'Cannot assign a platform admin as assignee', + ); } } @@ -322,7 +324,6 @@ export class TaskManagementService { assigneeMemberId: createTaskItemDto.assigneeId, assignedByUserId: authContext.userId, }); - } // Notify mentioned users @@ -486,7 +487,9 @@ export class TaskManagementService { include: { user: { select: { role: true } } }, }); if (assigneeMember?.user.role === 'admin') { - throw new BadRequestException('Cannot assign a platform admin as assignee'); + throw new BadRequestException( + 'Cannot assign a platform admin as assignee', + ); } } updateData.assigneeId = updateTaskItemDto.assigneeId; @@ -537,13 +540,11 @@ export class TaskManagementService { assigneeMemberId: taskItem.assigneeId, assignedByUserId: authContext.userId, }); - } else { // Assignee removed this.logger.log( `[ASSIGNEE DEBUG] Assignee removed from task ${taskItem.id}`, ); - } } else if (updateTaskItemDto.assigneeId !== undefined) { this.logger.log( diff --git a/apps/api/src/tasks/attachments.service.ts b/apps/api/src/tasks/attachments.service.ts index 2f636e46b0..6934162522 100644 --- a/apps/api/src/tasks/attachments.service.ts +++ b/apps/api/src/tasks/attachments.service.ts @@ -4,7 +4,7 @@ import { PutObjectCommand, S3Client, } from '@aws-sdk/client-s3'; -import { getSignedUrl } from '@aws-sdk/s3-request-presigner'; +import { getSignedUrl, s3Client } from '@/app/s3'; import { AttachmentEntityType, AttachmentType } from '@db'; import { BadRequestException, @@ -15,7 +15,6 @@ import { db } from '@db'; import { randomBytes } from 'crypto'; import { AttachmentResponseDto } from './dto/task-responses.dto'; import { UploadAttachmentDto } from './dto/upload-attachment.dto'; -import { s3Client } from '@/app/s3'; import { validateFileContent } from '../utils/file-type-validation'; @Injectable() diff --git a/apps/api/src/tasks/automations/automations.controller.ts b/apps/api/src/tasks/automations/automations.controller.ts index 8dc1407549..05ce95b6cc 100644 --- a/apps/api/src/tasks/automations/automations.controller.ts +++ b/apps/api/src/tasks/automations/automations.controller.ts @@ -250,7 +250,9 @@ export class AutomationsController { @Post(':automationId/versions') @RequirePermission('task', 'update') - @ApiOperation({ summary: 'Create a published version record for an automation' }) + @ApiOperation({ + summary: 'Create a published version record for an automation', + }) @ApiParam({ name: 'taskId', description: 'Task ID' }) @ApiParam({ name: 'automationId', description: 'Automation ID' }) async createVersion( diff --git a/apps/api/src/tasks/description-framework-filter.spec.ts b/apps/api/src/tasks/description-framework-filter.spec.ts index 2890fde9f2..6133c1baa3 100644 --- a/apps/api/src/tasks/description-framework-filter.spec.ts +++ b/apps/api/src/tasks/description-framework-filter.spec.ts @@ -30,10 +30,7 @@ describe('filterDescriptionByFrameworks', () => { it('keeps all framework paragraphs when all are active', () => { const desc = 'Base task.\n\nFor ISO 27001: ISO requirement.\n\nFor HIPAA: HIPAA requirement.'; - const result = filterDescriptionByFrameworks(desc, [ - 'ISO 27001', - 'HIPAA', - ]); + const result = filterDescriptionByFrameworks(desc, ['ISO 27001', 'HIPAA']); expect(result).toContain('For ISO 27001'); expect(result).toContain('For HIPAA'); expect(result).toContain('Base task.'); diff --git a/apps/api/src/tasks/task-notifier.service.ts b/apps/api/src/tasks/task-notifier.service.ts index f6bb705f5a..91dbf22d11 100644 --- a/apps/api/src/tasks/task-notifier.service.ts +++ b/apps/api/src/tasks/task-notifier.service.ts @@ -98,7 +98,10 @@ export class TaskNotifierService { // Build recipient list: all members excluding actor. // The isUserUnsubscribed check handles role-based filtering via the notification matrix. - const recipientMap = new Map(); + const recipientMap = new Map< + string, + { id: string; name: string; email: string } + >(); for (const member of allMembers) { if (member.user?.id && member.user.email) { @@ -452,7 +455,10 @@ export class TaskNotifierService { // Build recipient list: all members excluding actor. // The isUserUnsubscribed check handles role-based filtering via the notification matrix. - const recipientMap = new Map(); + const recipientMap = new Map< + string, + { id: string; name: string; email: string } + >(); for (const member of allMembers) { if (member.user?.id && member.user.email) { @@ -1266,7 +1272,9 @@ export class TaskNotifierService { const { organizationId, tasks: failedTasks } = params; if (failedTasks.length === 0) { - this.logger.log('[notifyBulkAutomationFailures] No failed tasks, skipping'); + this.logger.log( + '[notifyBulkAutomationFailures] No failed tasks, skipping', + ); return; } diff --git a/apps/api/src/tasks/tasks.controller.spec.ts b/apps/api/src/tasks/tasks.controller.spec.ts index 7e3ecc305d..2474d5fdd5 100644 --- a/apps/api/src/tasks/tasks.controller.spec.ts +++ b/apps/api/src/tasks/tasks.controller.spec.ts @@ -11,7 +11,15 @@ import { AttachmentsService } from '../attachments/attachments.service'; jest.mock('@db', () => ({ ...jest.requireActual('@prisma/client'), db: {}, - Prisma: { PrismaClientKnownRequestError: class PrismaClientKnownRequestError extends Error { code: string; constructor(message: string, { code }: { code: string }) { super(message); this.code = code; } } }, + Prisma: { + PrismaClientKnownRequestError: class PrismaClientKnownRequestError extends Error { + code: string; + constructor(message: string, { code }: { code: string }) { + super(message); + this.code = code; + } + }, + }, })); jest.mock('../auth/auth.server', () => ({ @@ -265,9 +273,7 @@ describe('TasksController', () => { status: TaskStatus.done, }); - expect(mockTasksService.getApiKeyActorUserId).toHaveBeenCalledWith( - orgId, - ); + expect(mockTasksService.getApiKeyActorUserId).toHaveBeenCalledWith(orgId); expect(mockTasksService.updateTasksStatus).toHaveBeenCalledWith( orgId, ['tsk_1'], @@ -623,9 +629,7 @@ describe('TasksController', () => { title: 'Updated', }); - expect(mockTasksService.getApiKeyActorUserId).toHaveBeenCalledWith( - orgId, - ); + expect(mockTasksService.getApiKeyActorUserId).toHaveBeenCalledWith(orgId); }); }); @@ -845,7 +849,9 @@ describe('TasksController', () => { userId: 'usr_dto', }; mockTasksService.verifyTaskAccess.mockResolvedValue(undefined); - mockAttachmentsService.uploadAttachment.mockResolvedValue({ id: 'att_1' }); + mockAttachmentsService.uploadAttachment.mockResolvedValue({ + id: 'att_1', + }); await controller.uploadTaskAttachment( apiKeyAuth, diff --git a/apps/api/src/tasks/tasks.controller.ts b/apps/api/src/tasks/tasks.controller.ts index 1f33cb65c9..47984b0a99 100644 --- a/apps/api/src/tasks/tasks.controller.ts +++ b/apps/api/src/tasks/tasks.controller.ts @@ -111,7 +111,11 @@ export class TasksController { }, }, }) - @ApiQuery({ name: 'includeRelations', required: false, description: 'Include controls and automations with runs' }) + @ApiQuery({ + name: 'includeRelations', + required: false, + description: 'Include controls and automations with runs', + }) async getTasks( @OrganizationId() organizationId: string, @AuthContext() authContext: AuthContextType, @@ -134,12 +138,15 @@ export class TasksController { @RequirePermission('task', 'read') @ApiOperation({ summary: 'Get task templates', - description: 'Retrieve all available task templates, optionally filtered by framework.', + description: + 'Retrieve all available task templates, optionally filtered by framework.', }) - @ApiQuery({ name: 'frameworkId', required: false, description: 'Filter templates by framework ID' }) - async getTaskTemplates( - @Query('frameworkId') frameworkId?: string, - ) { + @ApiQuery({ + name: 'frameworkId', + required: false, + description: 'Filter templates by framework ID', + }) + async getTaskTemplates(@Query('frameworkId') frameworkId?: string) { return await this.tasksService.getTaskTemplates(frameworkId); } @@ -421,7 +428,8 @@ export class TasksController { @ApiResponse({ status: 400, description: 'Invalid request body' }) async reorderTasks( @OrganizationId() organizationId: string, - @Body() body: { updates: { id: string; order: number; status: TaskStatus }[] }, + @Body() + body: { updates: { id: string; order: number; status: TaskStatus }[] }, ): Promise<{ success: boolean }> { if (!Array.isArray(body.updates) || body.updates.length === 0) { throw new BadRequestException('updates must be a non-empty array'); @@ -605,9 +613,16 @@ export class TasksController { // Check assignment access for restricted roles // The task object from service includes assigneeId even though DTO doesn't declare it - const taskWithAssignee = task as TaskResponseDto & { assigneeId: string | null }; + const taskWithAssignee = task as TaskResponseDto & { + assigneeId: string | null; + }; if ( - !hasTaskAccess(taskWithAssignee, authContext.memberId, authContext.userRoles, { isApiKey: authContext.isApiKey }) + !hasTaskAccess( + taskWithAssignee, + authContext.memberId, + authContext.userRoles, + { isApiKey: authContext.isApiKey }, + ) ) { throw new ForbiddenException('You do not have access to this task'); } diff --git a/apps/api/src/tasks/tasks.service.ts b/apps/api/src/tasks/tasks.service.ts index deb99dcd7c..f9a89caf3d 100644 --- a/apps/api/src/tasks/tasks.service.ts +++ b/apps/api/src/tasks/tasks.service.ts @@ -10,7 +10,9 @@ import { db, TaskStatus, Prisma, TaskFrequency, Departments } from '@db'; import { TaskResponseDto } from './dto/task-responses.dto'; import { TaskNotifierService } from './task-notifier.service'; -function computeNextTaskReviewDate(frequency: TaskFrequency | null | undefined): Date { +function computeNextTaskReviewDate( + frequency: TaskFrequency | null | undefined, +): Date { const now = new Date(); switch (frequency) { case TaskFrequency.daily: @@ -175,10 +177,7 @@ export class TasksService { /** * Get a single task by ID */ - async getTask( - organizationId: string, - taskId: string, - ) { + async getTask(organizationId: string, taskId: string) { try { const [task, activeFrameworkNames] = await Promise.all([ db.task.findFirst({ @@ -256,7 +255,13 @@ export class TasksService { where, include: { user: { - select: { id: true, name: true, email: true, image: true, role: true }, + select: { + id: true, + name: true, + email: true, + image: true, + role: true, + }, }, }, orderBy: { timestamp: 'desc' }, @@ -423,7 +428,9 @@ export class TasksService { include: { user: { select: { role: true } } }, }); if (assigneeMember?.user.role === 'admin') { - throw new BadRequestException('Cannot assign a platform admin as assignee'); + throw new BadRequestException( + 'Cannot assign a platform admin as assignee', + ); } } @@ -562,7 +569,10 @@ export class TasksService { } if (updateData.status !== undefined) { // Prevent bypassing the approval workflow via direct status change - if (existingTask.status === 'in_review' && updateData.status !== 'in_review') { + if ( + existingTask.status === 'in_review' && + updateData.status !== 'in_review' + ) { throw new BadRequestException( 'Cannot change status directly while task is in review. Use the approve or reject actions instead.', ); @@ -587,7 +597,9 @@ export class TasksService { include: { user: { select: { role: true } } }, }); if (assigneeMember?.user.role === 'admin') { - throw new BadRequestException('Cannot assign a platform admin as assignee'); + throw new BadRequestException( + 'Cannot assign a platform admin as assignee', + ); } } dataToUpdate.assigneeId = @@ -600,7 +612,9 @@ export class TasksService { if (updateData.frequency !== undefined) { dataToUpdate.frequency = updateData.frequency; // When frequency changes, recalculate the review date - dataToUpdate.reviewDate = computeNextTaskReviewDate(updateData.frequency); + dataToUpdate.reviewDate = computeNextTaskReviewDate( + updateData.frequency, + ); } if (updateData.department !== undefined) { dataToUpdate.department = updateData.department; @@ -821,10 +835,7 @@ export class TasksService { /** * Regenerate task from its associated template */ - async regenerateFromTemplate( - organizationId: string, - taskId: string, - ) { + async regenerateFromTemplate(organizationId: string, taskId: string) { const task = await db.task.findFirst({ where: { id: taskId, organizationId }, include: { taskTemplate: true }, @@ -835,7 +846,9 @@ export class TasksService { } if (!task.taskTemplate) { - throw new BadRequestException('Task has no associated template to regenerate from'); + throw new BadRequestException( + 'Task has no associated template to regenerate from', + ); } const updated = await db.task.update({ @@ -868,10 +881,7 @@ export class TasksService { /** * Delete a single task by ID */ - async deleteTask( - organizationId: string, - taskId: string, - ): Promise { + async deleteTask(organizationId: string, taskId: string): Promise { const task = await db.task.findFirst({ where: { id: taskId, @@ -924,7 +934,9 @@ export class TasksService { } if (approver.user.role === 'admin') { - throw new BadRequestException('Cannot assign a platform admin as approver'); + throw new BadRequestException( + 'Cannot assign a platform admin as approver', + ); } const currentMember = await db.member.findFirst({ @@ -1001,7 +1013,9 @@ export class TasksService { } if (approver.user.role === 'admin') { - throw new BadRequestException('Cannot assign a platform admin as approver'); + throw new BadRequestException( + 'Cannot assign a platform admin as approver', + ); } const tasks = await db.task.findMany({ diff --git a/apps/api/src/training/permissions-regression.spec.ts b/apps/api/src/training/permissions-regression.spec.ts index d8a0c77688..1ec83db5f2 100644 --- a/apps/api/src/training/permissions-regression.spec.ts +++ b/apps/api/src/training/permissions-regression.spec.ts @@ -92,9 +92,7 @@ describe('Built-in role permissions — regression', () => { }); it('should have trust read/update', () => { - expect(perms.trust).toEqual( - expect.arrayContaining(['read', 'update']), - ); + expect(perms.trust).toEqual(expect.arrayContaining(['read', 'update'])); }); it('should have pentest create/read/delete', () => { @@ -110,9 +108,7 @@ describe('Built-in role permissions — regression', () => { }); it('should have portal read/update', () => { - expect(perms.portal).toEqual( - expect.arrayContaining(['read', 'update']), - ); + expect(perms.portal).toEqual(expect.arrayContaining(['read', 'update'])); }); it('should have organization read/update/delete', () => { @@ -172,9 +168,7 @@ describe('Built-in role permissions — regression', () => { }); it('should have portal read/update', () => { - expect(perms.portal).toEqual( - expect.arrayContaining(['read', 'update']), - ); + expect(perms.portal).toEqual(expect.arrayContaining(['read', 'update'])); }); it('should have pentest create/read/delete', () => { @@ -253,9 +247,7 @@ describe('Built-in role permissions — regression', () => { }); it('should have portal read/update', () => { - expect(perms.portal).toEqual( - expect.arrayContaining(['read', 'update']), - ); + expect(perms.portal).toEqual(expect.arrayContaining(['read', 'update'])); }); it('should NOT have app access', () => { diff --git a/apps/api/src/training/training-hipaa.spec.ts b/apps/api/src/training/training-hipaa.spec.ts index 7c3a3d6f23..6550067a95 100644 --- a/apps/api/src/training/training-hipaa.spec.ts +++ b/apps/api/src/training/training-hipaa.spec.ts @@ -85,11 +85,7 @@ describe('TrainingService — HIPAA training', () => { }); db.employeeTrainingVideoCompletion.findMany.mockResolvedValue([]); - const result = await service.markVideoComplete( - 'mem_1', - 'org_1', - 'sat-3', - ); + const result = await service.markVideoComplete('mem_1', 'org_1', 'sat-3'); expect(result.videoId).toBe('sat-3'); }); diff --git a/apps/api/src/training/training.controller.spec.ts b/apps/api/src/training/training.controller.spec.ts index a18b608865..dbfdcd95ef 100644 --- a/apps/api/src/training/training.controller.spec.ts +++ b/apps/api/src/training/training.controller.spec.ts @@ -30,9 +30,7 @@ describe('TrainingController', () => { beforeEach(async () => { const module: TestingModule = await Test.createTestingModule({ controllers: [TrainingController], - providers: [ - { provide: TrainingService, useValue: mockTrainingService }, - ], + providers: [{ provide: TrainingService, useValue: mockTrainingService }], }) .overrideGuard(HybridAuthGuard) .useValue(mockGuard) @@ -181,9 +179,7 @@ describe('TrainingController', () => { PERMISSIONS_KEY, controller.getCompletions, ); - expect(permissions).toEqual([ - { resource: 'portal', actions: ['read'] }, - ]); + expect(permissions).toEqual([{ resource: 'portal', actions: ['read'] }]); }); it('markVideoComplete should require portal:update', () => { diff --git a/apps/api/src/training/training.controller.ts b/apps/api/src/training/training.controller.ts index 1bc4b22939..bb75f593ae 100644 --- a/apps/api/src/training/training.controller.ts +++ b/apps/api/src/training/training.controller.ts @@ -158,7 +158,10 @@ export class TrainingController { }) @ApiProduces('application/pdf') @ApiResponse({ status: 200, description: 'PDF certificate file' }) - @ApiResponse({ status: 400, description: 'HIPAA training not complete or member not found' }) + @ApiResponse({ + status: 400, + description: 'HIPAA training not complete or member not found', + }) async generateHipaaCertificate( @OrganizationId() organizationId: string, @Body() dto: SendTrainingCompletionDto, diff --git a/apps/api/src/training/training.service.ts b/apps/api/src/training/training.service.ts index 2f4d9d6524..816f944e93 100644 --- a/apps/api/src/training/training.service.ts +++ b/apps/api/src/training/training.service.ts @@ -163,7 +163,10 @@ export class TrainingService { return { sent: false, reason: 'training_not_complete' }; } - const resolved = await this.resolveMemberForCertificate(memberId, organizationId); + const resolved = await this.resolveMemberForCertificate( + memberId, + organizationId, + ); if ('reason' in resolved) return { sent: false, reason: resolved.reason }; const completedAt = await this.getTrainingCompletionDate(memberId); @@ -186,24 +189,35 @@ export class TrainingService { const isComplete = await this.hasCompletedAllTraining(memberId); if (!isComplete) return { error: 'training_not_complete' }; - const resolved = await this.resolveMemberForCertificate(memberId, organizationId); + const resolved = await this.resolveMemberForCertificate( + memberId, + organizationId, + ); if ('reason' in resolved) return { error: resolved.reason }; const completedAt = await this.getTrainingCompletionDate(memberId); if (!completedAt) return { error: 'no_completion_date' }; - const pdf = await this.trainingCertificatePdfService.generateTrainingCertificatePdf({ - userName: resolved.userName, - organizationName: resolved.organizationName, - completedAt, - }); + const pdf = + await this.trainingCertificatePdfService.generateTrainingCertificatePdf({ + userName: resolved.userName, + organizationName: resolved.organizationName, + completedAt, + }); - return { pdf, fileName: `training-certificate-${resolved.userName.replace(/\s+/g, '-').toLowerCase()}.pdf` }; + return { + pdf, + fileName: `training-certificate-${resolved.userName.replace(/\s+/g, '-').toLowerCase()}.pdf`, + }; } async getHipaaCompletionDate(memberId: string): Promise { const completion = await db.employeeTrainingVideoCompletion.findFirst({ - where: { memberId, videoId: HIPAA_TRAINING_ID, completedAt: { not: null } }, + where: { + memberId, + videoId: HIPAA_TRAINING_ID, + completedAt: { not: null }, + }, }); return completion?.completedAt ?? null; } @@ -213,9 +227,13 @@ export class TrainingService { organizationId: string, ): Promise<{ sent: boolean; reason?: string }> { const isComplete = await this.hasCompletedHipaaTraining(memberId); - if (!isComplete) return { sent: false, reason: 'hipaa_training_not_complete' }; + if (!isComplete) + return { sent: false, reason: 'hipaa_training_not_complete' }; - const resolved = await this.resolveMemberForCertificate(memberId, organizationId); + const resolved = await this.resolveMemberForCertificate( + memberId, + organizationId, + ); if ('reason' in resolved) return { sent: false, reason: resolved.reason }; const completedAt = await this.getHipaaCompletionDate(memberId); @@ -227,7 +245,9 @@ export class TrainingService { organizationName: resolved.organizationName, completedAt, }); - this.logger.log(`HIPAA training completion email sent to ${resolved.email}`); + this.logger.log( + `HIPAA training completion email sent to ${resolved.email}`, + ); return { sent: true }; } @@ -238,19 +258,26 @@ export class TrainingService { const isComplete = await this.hasCompletedHipaaTraining(memberId); if (!isComplete) return { error: 'hipaa_training_not_complete' }; - const resolved = await this.resolveMemberForCertificate(memberId, organizationId); + const resolved = await this.resolveMemberForCertificate( + memberId, + organizationId, + ); if ('reason' in resolved) return { error: resolved.reason }; const completedAt = await this.getHipaaCompletionDate(memberId); if (!completedAt) return { error: 'no_completion_date' }; - const pdf = await this.trainingCertificatePdfService.generateHipaaCertificatePdf({ - userName: resolved.userName, - organizationName: resolved.organizationName, - completedAt, - }); + const pdf = + await this.trainingCertificatePdfService.generateHipaaCertificatePdf({ + userName: resolved.userName, + organizationName: resolved.organizationName, + completedAt, + }); - return { pdf, fileName: `hipaa-training-certificate-${resolved.userName.replace(/\s+/g, '-').toLowerCase()}.pdf` }; + return { + pdf, + fileName: `hipaa-training-certificate-${resolved.userName.replace(/\s+/g, '-').toLowerCase()}.pdf`, + }; } private async resolveMemberForCertificate( @@ -270,7 +297,8 @@ export class TrainingService { if (!member?.user) return { reason: 'member_not_found' }; if (!member.user.email) return { reason: 'no_email' }; - if (member.organizationId !== organizationId) return { reason: 'organization_mismatch' }; + if (member.organizationId !== organizationId) + return { reason: 'organization_mismatch' }; return { userName: member.user.name || 'Team Member', diff --git a/apps/api/src/trigger/cloud-security/run-cloud-security-scan.ts b/apps/api/src/trigger/cloud-security/run-cloud-security-scan.ts index 31786c6e9d..06da14c6cb 100644 --- a/apps/api/src/trigger/cloud-security/run-cloud-security-scan.ts +++ b/apps/api/src/trigger/cloud-security/run-cloud-security-scan.ts @@ -22,20 +22,15 @@ export const runCloudSecurityScan = task({ await tags.add([`org:${organizationId}`]); - logger.info( - `Starting cloud security scan for connection: ${connectionName}`, - { - connectionId, - provider: providerSlug, - organizationId, - }, - ); - try { - // Verify connection is still active + // Verify connection is still active and resolve provider (payload may use legacy "platform") const connection = await db.integrationConnection.findUnique({ where: { id: connectionId }, - select: { id: true, status: true }, + select: { + id: true, + status: true, + provider: { select: { slug: true } }, + }, }); if (!connection) { @@ -52,19 +47,43 @@ export const runCloudSecurityScan = task({ }; } - // Call the cloud security scan API endpoint + const resolvedProviderSlug = + connection.provider?.slug ?? providerSlug; + + logger.info( + `Starting cloud security scan for connection: ${connectionName}`, + { + connectionId, + payloadProviderSlug: providerSlug, + resolvedProviderSlug, + organizationId, + }, + ); + const apiUrl = process.env.BASE_URL || 'http://localhost:3333'; + const headers = { + 'Content-Type': 'application/json', + 'x-service-token': process.env.SERVICE_TOKEN_TRIGGER!, + 'x-organization-id': organizationId, + }; + // Auto-detect services before scanning (AWS via Cost Explorer, GCP via Service Usage API) + // Azure uses scan-based detection instead, so skip the pre-scan detect call + if (resolvedProviderSlug === 'aws' || resolvedProviderSlug === 'gcp') { + try { + await fetch( + `${apiUrl}/v1/cloud-security/detect-services/${connectionId}`, + { method: 'POST', headers }, + ); + } catch { + // Non-critical — scan proceeds even if detect fails + } + } + + // Run the scan const response = await fetch( `${apiUrl}/v1/cloud-security/scan/${connectionId}`, - { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - 'x-service-token': process.env.SERVICE_TOKEN_TRIGGER!, - 'x-organization-id': organizationId, - }, - }, + { method: 'POST', headers }, ); if (!response.ok) { diff --git a/apps/api/src/trigger/email/send-email.ts b/apps/api/src/trigger/email/send-email.ts index ad14d0f2fe..212e87f1d4 100644 --- a/apps/api/src/trigger/email/send-email.ts +++ b/apps/api/src/trigger/email/send-email.ts @@ -53,7 +53,8 @@ export const sendEmailTask = schemaTask({ try { // Build List-Unsubscribe headers for Gmail/RFC 8058 one-click compliance - const apiBaseUrl = process.env.NEXT_PUBLIC_API_URL || 'https://api.trycomp.ai'; + const apiBaseUrl = + process.env.NEXT_PUBLIC_API_URL || 'https://api.trycomp.ai'; const token = generateUnsubscribeToken(params.to); const oneClickUrl = `${apiBaseUrl}/v1/email/unsubscribe?email=${encodeURIComponent(params.to)}&token=${encodeURIComponent(token)}`; const headers: Record = { diff --git a/apps/api/src/trigger/integration-platform/sync-employees-schedule.ts b/apps/api/src/trigger/integration-platform/sync-employees-schedule.ts index 4bfe108f62..429e024f21 100644 --- a/apps/api/src/trigger/integration-platform/sync-employees-schedule.ts +++ b/apps/api/src/trigger/integration-platform/sync-employees-schedule.ts @@ -199,7 +199,11 @@ async function syncProvider(params: SyncProviderParams): Promise { default: // Try generic dynamic sync endpoint for non-built-in providers - return syncDynamicProvider({ providerSlug, connectionId, organizationId }); + return syncDynamicProvider({ + providerSlug, + connectionId, + organizationId, + }); } } diff --git a/apps/api/src/trigger/policies/update-policy-helpers.ts b/apps/api/src/trigger/policies/update-policy-helpers.ts index 0739ad3e3b..c4fd5fec66 100644 --- a/apps/api/src/trigger/policies/update-policy-helpers.ts +++ b/apps/api/src/trigger/policies/update-policy-helpers.ts @@ -1,6 +1,11 @@ import { openai } from '@ai-sdk/openai'; import { db } from '@db'; -import type { FrameworkEditorFramework, FrameworkEditorPolicyTemplate, Policy, Prisma } from '@db'; +import type { + FrameworkEditorFramework, + FrameworkEditorPolicyTemplate, + Policy, + Prisma, +} from '@db'; import { logger } from '@trigger.dev/sdk'; import { generateObject, NoObjectGeneratedError } from 'ai'; import { z } from 'zod'; @@ -10,7 +15,7 @@ import { generatePrompt } from './update-policy-prompts'; const PLACEHOLDER_REGEX = /<<\s*TO\s*REVIEW\s*>>/gi; function extractText(node: Record): string { - const text = node && typeof node['text'] === 'string' ? (node['text'] as string) : ''; + const text = node && typeof node['text'] === 'string' ? node['text'] : ''; const content = Array.isArray((node as any)?.content) ? ((node as any).content as Record[]) : null; @@ -20,10 +25,12 @@ function extractText(node: Record): string { return text || ''; } -function sanitizeNodePlaceholders(node: Record): Record { +function sanitizeNodePlaceholders( + node: Record, +): Record { const cloned: Record = { ...node }; if (typeof cloned['text'] === 'string') { - const replaced = (cloned['text'] as string) + const replaced = cloned['text'] .replace(PLACEHOLDER_REGEX, '') .replace(/\s{2,}/g, ' ') .trim(); @@ -40,7 +47,10 @@ function sanitizeNodePlaceholders(node: Record): Record[] = []; let i = 0; while (i < content.length) { - const node = content[i] as Record; - const nodeType = typeof node['type'] === 'string' ? (node['type'] as string) : ''; + const node = content[i]; + const nodeType = typeof node['type'] === 'string' ? node['type'] : ''; if (nodeType === 'heading') { const headingText = extractText(node); if (shouldRemoveAuditorArtifactsHeading(headingText)) { i += 1; while (i < content.length) { - const nextNode = content[i] as Record; - const nextType = typeof nextNode['type'] === 'string' ? (nextNode['type'] as string) : ''; + const nextNode = content[i]; + const nextType = + typeof nextNode['type'] === 'string' ? nextNode['type'] : ''; if (nextType === 'heading') break; i += 1; } @@ -71,15 +82,17 @@ function removeAuditorArtifactsSection( } function extractHeadingText(node: Record): string { - const type = typeof node['type'] === 'string' ? (node['type'] as string) : ''; + const type = typeof node['type'] === 'string' ? node['type'] : ''; if (type !== 'heading') return ''; return extractText(node).trim(); } -function getAllowedTopLevelHeadings(originalContent: Record[]): string[] { +function getAllowedTopLevelHeadings( + originalContent: Record[], +): string[] { const allowed: string[] = []; for (const node of originalContent) { - const type = typeof node['type'] === 'string' ? (node['type'] as string) : ''; + const type = typeof node['type'] === 'string' ? node['type'] : ''; if (type === 'heading') { const level = (node as any)?.attrs?.level; if (typeof level === 'number' && level >= 1 && level <= 2) { @@ -133,15 +146,18 @@ export async function fetchOrganizationAndPolicy( }), ]); - if (!organization) throw new Error(`Organization not found for ${organizationId}`); + if (!organization) + throw new Error(`Organization not found for ${organizationId}`); if (!policy) throw new Error(`Policy not found for ${policyId}`); - if (!policy.policyTemplateId) throw new Error(`Policy template not found for ${policyId}`); + if (!policy.policyTemplateId) + throw new Error(`Policy template not found for ${policyId}`); const policyTemplate = await db.frameworkEditorPolicyTemplate.findUnique({ where: { id: policy.policyTemplateId }, }); - if (!policyTemplate) throw new Error(`Policy template not found for ${policy.policyTemplateId}`); + if (!policyTemplate) + throw new Error(`Policy template not found for ${policy.policyTemplateId}`); return { organization, policy, policyTemplate }; } @@ -193,10 +209,15 @@ Return the complete TipTap document following ALL the above requirements using p const parsed = object as { type?: string; content?: unknown }; if (parsed?.type !== 'document' || !Array.isArray(parsed?.content)) { - throw new Error('AI response did not match expected TipTap document structure'); + throw new Error( + 'AI response did not match expected TipTap document structure', + ); } - return { type: 'document' as const, content: parsed.content as Record[] }; + return { + type: 'document' as const, + content: parsed.content as Record[], + }; } catch (aiError) { logger.error(`Error generating AI content: ${aiError}`); if (NoObjectGeneratedError.isInstance(aiError)) { @@ -233,7 +254,8 @@ export async function updatePolicyInDatabase( if (pdfUrlsToDelete.length > 0) { try { - const { S3Client, DeleteObjectCommand } = await import('@aws-sdk/client-s3'); + const { S3Client, DeleteObjectCommand } = + await import('@aws-sdk/client-s3'); const bucketName = process.env.APP_AWS_BUCKET_NAME; if (bucketName) { const s3 = new S3Client({ @@ -241,7 +263,9 @@ export async function updatePolicyInDatabase( }); await Promise.allSettled( pdfUrlsToDelete.map((pdfUrl) => - s3.send(new DeleteObjectCommand({ Bucket: bucketName, Key: pdfUrl })), + s3.send( + new DeleteObjectCommand({ Bucket: bucketName, Key: pdfUrl }), + ), ), ); } @@ -290,11 +314,21 @@ export async function updatePolicyInDatabase( } } -export async function processPolicyUpdate(params: UpdatePolicyParams): Promise { +export async function processPolicyUpdate( + params: UpdatePolicyParams, +): Promise { const { organizationId, policyId, contextHub, frameworks, memberId } = params; - const { organization, policyTemplate } = await fetchOrganizationAndPolicy(organizationId, policyId); - const prompt = await generatePolicyPrompt(policyTemplate, contextHub, organization, frameworks); + const { organization, policyTemplate } = await fetchOrganizationAndPolicy( + organizationId, + policyId, + ); + const prompt = await generatePolicyPrompt( + policyTemplate, + contextHub, + organization, + frameworks, + ); const updatedContent = await generatePolicyContent(prompt); await updatePolicyInDatabase(policyId, updatedContent.content, memberId); diff --git a/apps/api/src/trigger/policies/update-policy-prompts.ts b/apps/api/src/trigger/policies/update-policy-prompts.ts index be848cdb30..99a8c23b99 100644 --- a/apps/api/src/trigger/policies/update-policy-prompts.ts +++ b/apps/api/src/trigger/policies/update-policy-prompts.ts @@ -1,4 +1,7 @@ -import type { FrameworkEditorFramework, FrameworkEditorPolicyTemplate } from '@db'; +import type { + FrameworkEditorFramework, + FrameworkEditorPolicyTemplate, +} from '@db'; import { logger } from '@trigger.dev/sdk'; export const generatePrompt = ({ @@ -18,7 +21,9 @@ export const generatePrompt = ({ logger.info(`Company Name: ${companyName}`); logger.info(`Company Website: ${companyWebsite}`); logger.info(`Context: ${contextHub}`); - logger.info(`Existing Policy Content: ${JSON.stringify(policyTemplate.content)}`); + logger.info( + `Existing Policy Content: ${JSON.stringify(policyTemplate.content)}`, + ); logger.info( `Frameworks: ${JSON.stringify( frameworks.map((f) => ({ id: f.id, name: f.name, version: f.version })), @@ -29,7 +34,9 @@ export const generatePrompt = ({ frameworks.length > 0 ? frameworks.map((f) => `${f.name} v${f.version}`).join(', ') : 'None explicitly selected'; - const hasHIPAA = frameworks.some((f) => f.name.toLowerCase().includes('hipaa')); + const hasHIPAA = frameworks.some((f) => + f.name.toLowerCase().includes('hipaa'), + ); const hasSOC2 = frameworks.some( (f) => /soc\s*2/i.test(f.name) || f.name.toLowerCase().includes('soc'), ); diff --git a/apps/api/src/trigger/policies/update-policy.ts b/apps/api/src/trigger/policies/update-policy.ts index 2e177b6f84..ae10c57284 100644 --- a/apps/api/src/trigger/policies/update-policy.ts +++ b/apps/api/src/trigger/policies/update-policy.ts @@ -2,7 +2,10 @@ import { logger, metadata, queue, schemaTask, tags } from '@trigger.dev/sdk'; import { z } from 'zod'; import { processPolicyUpdate } from './update-policy-helpers'; -const updatePolicyQueue = queue({ name: 'update-policy', concurrencyLimit: 50 }); +const updatePolicyQueue = queue({ + name: 'update-policy', + concurrencyLimit: 50, +}); export const updatePolicy = schemaTask({ id: 'update-policy', diff --git a/apps/api/src/trigger/questionnaire/parse-questionnaire.ts b/apps/api/src/trigger/questionnaire/parse-questionnaire.ts index 9a04a1a916..399730255a 100644 --- a/apps/api/src/trigger/questionnaire/parse-questionnaire.ts +++ b/apps/api/src/trigger/questionnaire/parse-questionnaire.ts @@ -346,8 +346,9 @@ export const parseQuestionnaireTask = task({ 'questionnaire'; const s3Key = payload.s3Key || ''; const fileType = payload.fileType || 'application/octet-stream'; - const fileSize = payload.fileSize - ?? (payload.fileData + const fileSize = + payload.fileSize ?? + (payload.fileData ? Buffer.from(payload.fileData, 'base64').length : 0); diff --git a/apps/api/src/trigger/vector-store/helpers/extract-content-from-file.spec.ts b/apps/api/src/trigger/vector-store/helpers/extract-content-from-file.spec.ts index c7455d0315..5988d2dbd8 100644 --- a/apps/api/src/trigger/vector-store/helpers/extract-content-from-file.spec.ts +++ b/apps/api/src/trigger/vector-store/helpers/extract-content-from-file.spec.ts @@ -127,9 +127,9 @@ describe('extractContentFromFile - Excel handling', () => { it('should throw on corrupt Excel data', async () => { const badData = Buffer.from('not an excel file').toString('base64'); - await expect( - extractContentFromFile(badData, XLSX_MIME), - ).rejects.toThrow('Failed to parse Excel file'); + await expect(extractContentFromFile(badData, XLSX_MIME)).rejects.toThrow( + 'Failed to parse Excel file', + ); }); }); @@ -251,8 +251,8 @@ describe('extractContentFromFile - image extraction', () => { const base64 = Buffer.from('fake-image-data').toString('base64'); - await expect( - extractContentFromFile(base64, 'image/png'), - ).rejects.toThrow('Failed to extract image content: Vision API error'); + await expect(extractContentFromFile(base64, 'image/png')).rejects.toThrow( + 'Failed to extract image content: Vision API error', + ); }); }); diff --git a/apps/api/src/trigger/vendor/vendor-risk-assessment-task.ts b/apps/api/src/trigger/vendor/vendor-risk-assessment-task.ts index 986c7877fa..78210a5138 100644 --- a/apps/api/src/trigger/vendor/vendor-risk-assessment-task.ts +++ b/apps/api/src/trigger/vendor/vendor-risk-assessment-task.ts @@ -791,397 +791,428 @@ export const vendorRiskAssessmentTask: Task< const frameworkChecklist = buildFrameworkChecklist(organizationFrameworks); try { - // Helper to append a progress message to run metadata - const messages: ResearchMessage[] = []; - const pushMessage = (text: string, type: ResearchMessage['type'], url?: string) => { - const msg: ResearchMessage = { text, type, timestamp: Date.now(), ...url ? { url } : {} }; - messages.push(msg); - metadata.set('messages', messages); - }; + // Helper to append a progress message to run metadata + const messages: ResearchMessage[] = []; + const pushMessage = ( + text: string, + type: ResearchMessage['type'], + url?: string, + ) => { + const msg: ResearchMessage = { + text, + type, + timestamp: Date.now(), + ...(url ? { url } : {}), + }; + messages.push(msg); + metadata.set('messages', messages); + }; - // Initialize metadata - metadata.set('phase', 'starting'); - metadata.set('messages', []); - metadata.set('coreReady', false); - metadata.set('newsReady', false); + // Initialize metadata + metadata.set('phase', 'starting'); + metadata.set('messages', []); + metadata.set('coreReady', false); + metadata.set('newsReady', false); - metadata.set('phase', 'researching'); - pushMessage(`Analyzing ${payload.vendorWebsite}...`, 'searching'); + metadata.set('phase', 'researching'); + pushMessage(`Analyzing ${payload.vendorWebsite}...`, 'searching'); - logger.info('🚀 Starting parallel research', { - vendor: payload.vendorName, - website: payload.vendorWebsite, - organizationId: payload.organizationId, - }); + logger.info('🚀 Starting parallel research', { + vendor: payload.vendorName, + website: payload.vendorWebsite, + organizationId: payload.organizationId, + }); - const coreStartedAt = Date.now(); - const newsStartedAt = Date.now(); + const coreStartedAt = Date.now(); + const newsStartedAt = Date.now(); - const sleep = (ms: number) => - new Promise((resolve) => setTimeout(resolve, ms)); + const sleep = (ms: number) => + new Promise((resolve) => setTimeout(resolve, ms)); - // Run core research and news research in parallel - const [coreResult, newsResult] = await Promise.allSettled([ - (async () => { - pushMessage('Crawling vendor website...', 'searching'); - logger.info('🔍 Core research started', { - vendor: payload.vendorName, - website: payload.vendorWebsite, - }); - const result = await firecrawlResearchCore({ - vendorName: payload.vendorName, - vendorWebsite: payload.vendorWebsite!, - }); - const durationMs = Date.now() - coreStartedAt; - if (result) { - const certCount = result.certifications?.length ?? 0; - const verifiedCount = - result.certifications?.filter((c) => c.status === 'verified') - .length ?? 0; - const linkCount = result.links?.length ?? 0; - logger.info('✅ Core research completed', { + // Run core research and news research in parallel + const [coreResult, newsResult] = await Promise.allSettled([ + (async () => { + pushMessage('Crawling vendor website...', 'searching'); + logger.info('🔍 Core research started', { vendor: payload.vendorName, - durationMs, - certifications: certCount, - verifiedCertifications: verifiedCount, - links: linkCount, - hasAssessment: Boolean(result.securityAssessment), - riskLevel: result.riskLevel ?? 'none', + website: payload.vendorWebsite, }); - - // Report each finding individually with delays so the UI - // shows them appearing one by one in real time - if (result.certifications?.length) { - pushMessage('Extracting certifications...', 'analyzing'); - await sleep(300); - for (const cert of result.certifications) { - if (cert.status === 'verified') { - pushMessage(`Found ${cert.type}`, 'found', cert.url ?? undefined); - await sleep(250); + const result = await firecrawlResearchCore({ + vendorName: payload.vendorName, + vendorWebsite: payload.vendorWebsite!, + }); + const durationMs = Date.now() - coreStartedAt; + if (result) { + const certCount = result.certifications?.length ?? 0; + const verifiedCount = + result.certifications?.filter((c) => c.status === 'verified') + .length ?? 0; + const linkCount = result.links?.length ?? 0; + logger.info('✅ Core research completed', { + vendor: payload.vendorName, + durationMs, + certifications: certCount, + verifiedCertifications: verifiedCount, + links: linkCount, + hasAssessment: Boolean(result.securityAssessment), + riskLevel: result.riskLevel ?? 'none', + }); + + // Report each finding individually with delays so the UI + // shows them appearing one by one in real time + if (result.certifications?.length) { + pushMessage('Extracting certifications...', 'analyzing'); + await sleep(300); + for (const cert of result.certifications) { + if (cert.status === 'verified') { + pushMessage( + `Found ${cert.type}`, + 'found', + cert.url ?? undefined, + ); + await sleep(250); + } } } - } - if (result.links?.length) { - pushMessage('Extracting security and legal links...', 'analyzing'); - await sleep(300); - for (const link of result.links) { - pushMessage(`Found ${link.label}`, 'found', link.url); - await sleep(200); + if (result.links?.length) { + pushMessage( + 'Extracting security and legal links...', + 'analyzing', + ); + await sleep(300); + for (const link of result.links) { + pushMessage(`Found ${link.label}`, 'found', link.url); + await sleep(200); + } } - } - if (result.securityAssessment) { - pushMessage('Generating security assessment...', 'analyzing'); - await sleep(400); - pushMessage('Security assessment complete', 'found'); + if (result.securityAssessment) { + pushMessage('Generating security assessment...', 'analyzing'); + await sleep(400); + pushMessage('Security assessment complete', 'found'); + } + } else { + logger.warn('⚠️ Core research returned null', { + vendor: payload.vendorName, + durationMs, + }); } - } else { - logger.warn('⚠️ Core research returned null', { + return result; + })(), + (async () => { + logger.info('📰 News research started', { vendor: payload.vendorName, - durationMs, + website: payload.vendorWebsite, }); - } - return result; - })(), - (async () => { - logger.info('📰 News research started', { - vendor: payload.vendorName, - website: payload.vendorWebsite, - }); - const result = await firecrawlResearchNews({ - vendorName: payload.vendorName, - vendorWebsite: payload.vendorWebsite!, - }); - const durationMs = Date.now() - newsStartedAt; - if (result?.length) { - logger.info('✅ News research completed', { - vendor: payload.vendorName, - durationMs, - newsItems: result.length, + const result = await firecrawlResearchNews({ + vendorName: payload.vendorName, + vendorWebsite: payload.vendorWebsite!, }); - // Stagger news reporting - pushMessage('Processing recent news...', 'analyzing'); - await sleep(200); - for (const item of result) { - pushMessage(`Found: ${item.title}`, 'found', item.url ?? undefined); - await sleep(150); + const durationMs = Date.now() - newsStartedAt; + if (result?.length) { + logger.info('✅ News research completed', { + vendor: payload.vendorName, + durationMs, + newsItems: result.length, + }); + // Stagger news reporting + pushMessage('Processing recent news...', 'analyzing'); + await sleep(200); + for (const item of result) { + pushMessage( + `Found: ${item.title}`, + 'found', + item.url ?? undefined, + ); + await sleep(150); + } + } else { + logger.info('📰 News research returned no items', { + vendor: payload.vendorName, + durationMs, + }); } - } else { - logger.info('📰 News research returned no items', { - vendor: payload.vendorName, - durationMs, - }); - } - return result; - })(), - ]); - - logger.info('🏁 Both research calls settled', { - vendor: payload.vendorName, - coreStatus: coreResult.status, - newsStatus: newsResult.status, - coreError: - coreResult.status === 'rejected' ? String(coreResult.reason) : null, - newsError: - newsResult.status === 'rejected' ? String(newsResult.reason) : null, - }); - - // --- Process core results --- - const coreData = - coreResult.status === 'fulfilled' ? coreResult.value : null; + return result; + })(), + ]); - if (coreData) { - pushMessage('Writing core research to database...', 'analyzing'); - logger.info('💾 Writing core data to GlobalVendors', { + logger.info('🏁 Both research calls settled', { vendor: payload.vendorName, - domain, - normalizedWebsite, + coreStatus: coreResult.status, + newsStatus: newsResult.status, + coreError: + coreResult.status === 'rejected' ? String(coreResult.reason) : null, + newsError: + newsResult.status === 'rejected' ? String(newsResult.reason) : null, }); - const description = buildRiskAssessmentDescription({ - vendorName: payload.vendorName, - vendorWebsite: payload.vendorWebsite ?? null, - research: { ...coreData, news: null }, - frameworkChecklist, - organizationFrameworks, - }); - const data = parseRiskAssessmentJson(description); - - // Upsert GlobalVendors (same advisory lock pattern as before) - const lockKey = domain ?? normalizedWebsite; - const { nextVersion, updatedWebsites } = await withAdvisoryLock({ - lockKey, - run: async () => { - const latestGlobalVendors = domain - ? await db.globalVendors.findMany({ - where: { website: { contains: domain } }, - select: { - website: true, - riskAssessmentVersion: true, - riskAssessmentUpdatedAt: true, - }, - orderBy: [ - { riskAssessmentUpdatedAt: 'desc' }, - { createdAt: 'desc' }, - ], - }) - : []; - - const currentMax = maxVersion(latestGlobalVendors); - const computedNext = incrementVersion(currentMax); - const now = new Date(); - - if (latestGlobalVendors.length > 0) { - for (const gv of latestGlobalVendors) { - await db.globalVendors.update({ - where: { website: gv.website }, - data: { - company_name: payload.vendorName, - riskAssessmentData: data, - riskAssessmentVersion: computedNext, - riskAssessmentUpdatedAt: now, - }, - }); + // --- Process core results --- + const coreData = + coreResult.status === 'fulfilled' ? coreResult.value : null; + + if (coreData) { + pushMessage('Writing core research to database...', 'analyzing'); + logger.info('💾 Writing core data to GlobalVendors', { + vendor: payload.vendorName, + domain, + normalizedWebsite, + }); + + const description = buildRiskAssessmentDescription({ + vendorName: payload.vendorName, + vendorWebsite: payload.vendorWebsite ?? null, + research: { ...coreData, news: null }, + frameworkChecklist, + organizationFrameworks, + }); + const data = parseRiskAssessmentJson(description); + + // Upsert GlobalVendors (same advisory lock pattern as before) + const lockKey = domain ?? normalizedWebsite; + const { nextVersion, updatedWebsites } = await withAdvisoryLock({ + lockKey, + run: async () => { + const latestGlobalVendors = domain + ? await db.globalVendors.findMany({ + where: { website: { contains: domain } }, + select: { + website: true, + riskAssessmentVersion: true, + riskAssessmentUpdatedAt: true, + }, + orderBy: [ + { riskAssessmentUpdatedAt: 'desc' }, + { createdAt: 'desc' }, + ], + }) + : []; + + const currentMax = maxVersion(latestGlobalVendors); + const computedNext = incrementVersion(currentMax); + const now = new Date(); + + if (latestGlobalVendors.length > 0) { + for (const gv of latestGlobalVendors) { + await db.globalVendors.update({ + where: { website: gv.website }, + data: { + company_name: payload.vendorName, + riskAssessmentData: data, + riskAssessmentVersion: computedNext, + riskAssessmentUpdatedAt: now, + }, + }); + } + return { + nextVersion: computedNext, + updatedWebsites: latestGlobalVendors.map((gv) => gv.website), + }; } + + await db.globalVendors.upsert({ + where: { website: normalizedWebsite }, + create: { + website: normalizedWebsite, + company_name: payload.vendorName, + riskAssessmentData: data, + riskAssessmentVersion: computedNext, + riskAssessmentUpdatedAt: now, + }, + update: { + company_name: payload.vendorName, + riskAssessmentData: data, + riskAssessmentVersion: computedNext, + riskAssessmentUpdatedAt: now, + }, + }); + return { nextVersion: computedNext, - updatedWebsites: latestGlobalVendors.map((gv) => gv.website), + updatedWebsites: [normalizedWebsite], }; - } + }, + }); - await db.globalVendors.upsert({ - where: { website: normalizedWebsite }, - create: { - website: normalizedWebsite, - company_name: payload.vendorName, - riskAssessmentData: data, - riskAssessmentVersion: computedNext, - riskAssessmentUpdatedAt: now, - }, - update: { - company_name: payload.vendorName, - riskAssessmentData: data, - riskAssessmentVersion: computedNext, - riskAssessmentUpdatedAt: now, + logger.info('💾 GlobalVendors upsert complete', { + vendor: payload.vendorName, + version: nextVersion, + updatedWebsites, + }); + + // Extract risk level and badges + logger.info('🎯 Normalizing risk level', { + vendor: payload.vendorName, + }); + const rawRiskLevel = extractRiskLevel(data); + const normalizedRiskLvl = await normalizeRiskLevel(rawRiskLevel); + const inherentProbability = mapRiskLevelToLikelihood(normalizedRiskLvl); + const inherentImpact = mapRiskLevelToImpact(normalizedRiskLvl); + const residualProbability = mapRiskLevelToLikelihood(normalizedRiskLvl); + const residualImpact = mapRiskLevelToImpact(normalizedRiskLvl); + const complianceBadges = extractComplianceBadges(data); + const logoUrl = generateLogoUrl(vendor.website); + + logger.info('📊 Risk level and badges extracted', { + vendor: payload.vendorName, + rawRiskLevel, + normalizedRiskLevel: normalizedRiskLvl, + hasBadges: Boolean(complianceBadges), + badgeCount: Array.isArray(complianceBadges) + ? complianceBadges.length + : 0, + hasLogo: Boolean(logoUrl), + }); + + // Update vendor with core data (keep status in_progress — news may still be loading) + await db.vendor.update({ + where: { id: vendor.id }, + data: { + inherentProbability, + inherentImpact, + residualProbability, + residualImpact, + ...(complianceBadges ? { complianceBadges } : {}), + ...(logoUrl ? { logoUrl } : {}), + }, + }); + + metadata.set('phase', 'core_complete'); + metadata.set('coreReady', true); + + logger.info( + '🎉 Core phase complete — vendor updated, metadata.coreReady=true', + { + vendor: payload.vendorName, + vendorId: vendor.id, + version: nextVersion, + }, + ); + + // --- Process news results (merge into existing data) --- + const newsData = + newsResult.status === 'fulfilled' ? newsResult.value : null; + + if (newsData && newsData.length > 0) { + pushMessage('Adding news to research data...', 'analyzing'); + + await withAdvisoryLock({ + lockKey, + run: async () => { + // Read current data, merge news, write back + const websites = + updatedWebsites.length > 0 + ? updatedWebsites + : [normalizedWebsite]; + for (const website of websites) { + const gv = await db.globalVendors.findUnique({ + where: { website }, + select: { riskAssessmentData: true }, + }); + if (!gv?.riskAssessmentData) continue; + + const existingParsed = gv.riskAssessmentData as Record< + string, + unknown + >; + const existingTyped = + existingParsed as unknown as import('./vendor-risk-assessment/agent-types').VendorRiskAssessmentDataV1; + const merged = mergeNewsIntoRiskAssessment( + existingTyped, + newsData, + ); + + await db.globalVendors.update({ + where: { website }, + data: { + riskAssessmentData: JSON.parse(JSON.stringify(merged)), + }, + }); + } }, }); - return { - nextVersion: computedNext, - updatedWebsites: [normalizedWebsite], - }; - }, - }); - - logger.info('💾 GlobalVendors upsert complete', { - vendor: payload.vendorName, - version: nextVersion, - updatedWebsites, - }); + metadata.set('newsReady', true); + logger.info( + '📰 News merged into GlobalVendors — metadata.newsReady=true', + { + vendor: payload.vendorName, + vendorId: vendor.id, + newsCount: newsData.length, + websites: + updatedWebsites.length > 0 + ? updatedWebsites + : [normalizedWebsite], + }, + ); + } else if (newsResult.status === 'rejected') { + pushMessage('News research could not be completed', 'error'); + logger.warn('News research failed, continuing with core data only', { + vendor: payload.vendorName, + error: + newsResult.reason instanceof Error + ? newsResult.reason.message + : String(newsResult.reason), + }); + } + } else { + // Core research failed + if (coreResult.status === 'rejected') { + pushMessage('Research encountered an issue', 'error'); + metadata.set('phase', 'failed'); + throw coreResult.reason; + } + // Core returned null (API key missing, invalid URL, etc.) + pushMessage('Could not complete research for this vendor', 'error'); + metadata.set('phase', 'failed'); + throw new Error( + `Core research returned null for ${payload.vendorName} — vendor will not be marked as assessed`, + ); + } - // Extract risk level and badges - logger.info('🎯 Normalizing risk level', { - vendor: payload.vendorName, - }); - const rawRiskLevel = extractRiskLevel(data); - const normalizedRiskLvl = await normalizeRiskLevel(rawRiskLevel); - const inherentProbability = mapRiskLevelToLikelihood(normalizedRiskLvl); - const inherentImpact = mapRiskLevelToImpact(normalizedRiskLvl); - const residualProbability = mapRiskLevelToLikelihood(normalizedRiskLvl); - const residualImpact = mapRiskLevelToImpact(normalizedRiskLvl); - const complianceBadges = extractComplianceBadges(data); - const logoUrl = generateLogoUrl(vendor.website); - - logger.info('📊 Risk level and badges extracted', { + // Mark vendor as assessed and flip verify task + logger.info('🏷️ Setting vendor status to assessed', { vendor: payload.vendorName, - rawRiskLevel, - normalizedRiskLevel: normalizedRiskLvl, - hasBadges: Boolean(complianceBadges), - badgeCount: Array.isArray(complianceBadges) ? complianceBadges.length : 0, - hasLogo: Boolean(logoUrl), + vendorId: vendor.id, }); - - // Update vendor with core data (keep status in_progress — news may still be loading) await db.vendor.update({ where: { id: vendor.id }, + data: { status: VendorStatus.assessed }, + }); + + await db.taskItem.updateMany({ + where: { + id: verifyTaskItemId, + status: { notIn: [TaskItemStatus.done, TaskItemStatus.canceled] }, + }, data: { - inherentProbability, - inherentImpact, - residualProbability, - residualImpact, - ...(complianceBadges ? { complianceBadges } : {}), - ...(logoUrl ? { logoUrl } : {}), + status: TaskItemStatus.todo, + description: + 'Review the latest Risk Assessment and confirm it is accurate.', + assigneeId: assigneeMemberId, + updatedById: creatorMemberId, }, }); - metadata.set('phase', 'core_complete'); - metadata.set('coreReady', true); + metadata.set('phase', 'complete'); - logger.info('🎉 Core phase complete — vendor updated, metadata.coreReady=true', { + logger.info('✅ COMPLETED — all phases done', { vendor: payload.vendorName, vendorId: vendor.id, - version: nextVersion, + researched: Boolean(coreData), + hasNews: newsResult.status === 'fulfilled' && Boolean(newsResult.value), + coreStatus: coreResult.status, + newsStatus: newsResult.status, }); - // --- Process news results (merge into existing data) --- - const newsData = - newsResult.status === 'fulfilled' ? newsResult.value : null; - - if (newsData && newsData.length > 0) { - pushMessage('Adding news to research data...', 'analyzing'); - - await withAdvisoryLock({ - lockKey, - run: async () => { - // Read current data, merge news, write back - const websites = - updatedWebsites.length > 0 - ? updatedWebsites - : [normalizedWebsite]; - for (const website of websites) { - const gv = await db.globalVendors.findUnique({ - where: { website }, - select: { riskAssessmentData: true }, - }); - if (!gv?.riskAssessmentData) continue; - - const existingParsed = gv.riskAssessmentData as Record< - string, - unknown - >; - const existingTyped = - existingParsed as unknown as import('./vendor-risk-assessment/agent-types').VendorRiskAssessmentDataV1; - const merged = mergeNewsIntoRiskAssessment( - existingTyped, - newsData, - ); - - await db.globalVendors.update({ - where: { website }, - data: { - riskAssessmentData: JSON.parse(JSON.stringify(merged)), - }, - }); - } - }, - }); - - metadata.set('newsReady', true); - logger.info('📰 News merged into GlobalVendors — metadata.newsReady=true', { - vendor: payload.vendorName, - vendorId: vendor.id, - newsCount: newsData.length, - websites: updatedWebsites.length > 0 ? updatedWebsites : [normalizedWebsite], - }); - } else if (newsResult.status === 'rejected') { - pushMessage('News research could not be completed', 'error'); - logger.warn('News research failed, continuing with core data only', { - vendor: payload.vendorName, - error: - newsResult.reason instanceof Error - ? newsResult.reason.message - : String(newsResult.reason), - }); - } - } else { - // Core research failed - if (coreResult.status === 'rejected') { - pushMessage('Research encountered an issue', 'error'); - metadata.set('phase', 'failed'); - throw coreResult.reason; - } - // Core returned null (API key missing, invalid URL, etc.) - pushMessage('Could not complete research for this vendor', 'error'); - metadata.set('phase', 'failed'); - throw new Error( - `Core research returned null for ${payload.vendorName} — vendor will not be marked as assessed`, - ); - } - - // Mark vendor as assessed and flip verify task - logger.info('🏷️ Setting vendor status to assessed', { - vendor: payload.vendorName, - vendorId: vendor.id, - }); - await db.vendor.update({ - where: { id: vendor.id }, - data: { status: VendorStatus.assessed }, - }); - - await db.taskItem.updateMany({ - where: { - id: verifyTaskItemId, - status: { notIn: [TaskItemStatus.done, TaskItemStatus.canceled] }, - }, - data: { - status: TaskItemStatus.todo, - description: - 'Review the latest Risk Assessment and confirm it is accurate.', - assigneeId: assigneeMemberId, - updatedById: creatorMemberId, - }, - }); - - metadata.set('phase', 'complete'); - - logger.info('✅ COMPLETED — all phases done', { - vendor: payload.vendorName, - vendorId: vendor.id, - researched: Boolean(coreData), - hasNews: newsResult.status === 'fulfilled' && Boolean(newsResult.value), - coreStatus: coreResult.status, - newsStatus: newsResult.status, - }); - - return { - success: true, - vendorId: vendor.id, - deduped: false, - researched: Boolean(coreData), - riskAssessmentVersion: coreData ? 'latest' : null, - verifyTaskItemId, - }; + return { + success: true, + vendorId: vendor.id, + deduped: false, + researched: Boolean(coreData), + riskAssessmentVersion: coreData ? 'latest' : null, + verifyTaskItemId, + }; } catch (error) { // Reset vendor status so the UI no longer shows an infinite loading state. // The user can retry later once the underlying issue is resolved. diff --git a/apps/api/src/trigger/vendor/vendor-risk-assessment/firecrawl-agent-core.ts b/apps/api/src/trigger/vendor/vendor-risk-assessment/firecrawl-agent-core.ts index a2e366fdfc..cb92637fe3 100644 --- a/apps/api/src/trigger/vendor/vendor-risk-assessment/firecrawl-agent-core.ts +++ b/apps/api/src/trigger/vendor/vendor-risk-assessment/firecrawl-agent-core.ts @@ -56,11 +56,13 @@ Focus on the official website ${vendorWebsite} and its trust/security/compliance properties: { risk_level: { type: 'string', - description: 'Overall vendor risk level: critical, high, medium, low, or very_low', + description: + 'Overall vendor risk level: critical, high, medium, low, or very_low', }, security_assessment: { type: 'string', - description: 'A detailed paragraph summarizing the vendor security posture, including strengths, weaknesses, and key findings', + description: + 'A detailed paragraph summarizing the vendor security posture, including strengths, weaknesses, and key findings', }, last_researched_at: { type: 'string', @@ -68,30 +70,36 @@ Focus on the official website ${vendorWebsite} and its trust/security/compliance }, certifications: { type: 'array', - description: 'All security and compliance certifications found on the vendor website', + description: + 'All security and compliance certifications found on the vendor website', items: { type: 'object', properties: { type: { type: 'string', - description: 'Certification name, e.g. SOC 2 Type II, ISO 27001, FedRAMP, HIPAA, PCI DSS, GDPR, ISO 42001, ISO 27017, ISO 27018, TISAX, CSA STAR, C5, etc.', + description: + 'Certification name, e.g. SOC 2 Type II, ISO 27001, FedRAMP, HIPAA, PCI DSS, GDPR, ISO 42001, ISO 27017, ISO 27018, TISAX, CSA STAR, C5, etc.', }, status: { type: 'string', enum: ['verified', 'expired', 'not_certified', 'unknown'], - description: 'Whether the certification is currently active/verified, expired, not certified, or unknown', + description: + 'Whether the certification is currently active/verified, expired, not certified, or unknown', }, issued_at: { type: 'string', - description: 'ISO 8601 date when the certification was issued, if mentioned', + description: + 'ISO 8601 date when the certification was issued, if mentioned', }, expires_at: { type: 'string', - description: 'ISO 8601 date when the certification expires, if mentioned', + description: + 'ISO 8601 date when the certification expires, if mentioned', }, url: { type: 'string', - description: 'Direct URL to the certification report or trust page on the vendor domain', + description: + 'Direct URL to the certification report or trust page on the vendor domain', }, }, required: ['type'], @@ -99,7 +107,8 @@ Focus on the official website ${vendorWebsite} and its trust/security/compliance }, links: { type: 'object', - description: 'Direct URLs to key legal and security pages on the vendor domain', + description: + 'Direct URLs to key legal and security pages on the vendor domain', properties: { privacy_policy_url: { type: 'string', @@ -111,15 +120,18 @@ Focus on the official website ${vendorWebsite} and its trust/security/compliance }, trust_center_url: { type: 'string', - description: 'Direct URL to the trust portal where customers can review security posture and request reports. Prefer the dedicated trust portal (often on trust.page, safebase.io, vanta.com, or a trust. subdomain) over documentation pages.', + description: + 'Direct URL to the trust portal where customers can review security posture and request reports. Prefer the dedicated trust portal (often on trust.page, safebase.io, vanta.com, or a trust. subdomain) over documentation pages.', }, security_page_url: { type: 'string', - description: 'Direct URL to the security overview or security practices page', + description: + 'Direct URL to the security overview or security practices page', }, soc2_report_url: { type: 'string', - description: 'Direct URL to request or download the SOC 2 report', + description: + 'Direct URL to request or download the SOC 2 report', }, }, }, @@ -128,7 +140,11 @@ Focus on the official website ${vendorWebsite} and its trust/security/compliance }, }); } catch (error) { - return handleFirecrawlError(error, { vendorName, vendorWebsite, callType: 'core' }); + return handleFirecrawlError(error, { + vendorName, + vendorWebsite, + callType: 'core', + }); } if (!agentResponse.success || agentResponse.status === 'failed') { @@ -154,16 +170,25 @@ Focus on the official website ${vendorWebsite} and its trust/security/compliance if (links?.trust_center_url) linkPairs.push({ label: 'Trust & Security', url: links.trust_center_url }); if (links?.security_page_url) - linkPairs.push({ label: 'Security Overview', url: links.security_page_url }); + linkPairs.push({ + label: 'Security Overview', + url: links.security_page_url, + }); if (links?.soc2_report_url) linkPairs.push({ label: 'SOC 2 Report', url: links.soc2_report_url }); if (links?.privacy_policy_url) linkPairs.push({ label: 'Privacy Policy', url: links.privacy_policy_url }); if (links?.terms_of_service_url) - linkPairs.push({ label: 'Terms of Service', url: links.terms_of_service_url }); + linkPairs.push({ + label: 'Terms of Service', + url: links.terms_of_service_url, + }); const normalizedLinks = linkPairs - .map((l) => ({ ...l, url: validateVendorUrl(l.url, vendorDomain, l.label) })) + .map((l) => ({ + ...l, + url: validateVendorUrl(l.url, vendorDomain, l.label), + })) .filter((l): l is { label: string; url: string } => Boolean(l.url)); const certifications = @@ -177,7 +202,10 @@ Focus on the official website ${vendorWebsite} and its trust/security/compliance logger.info('Firecrawl core research completed', { vendorWebsite, - found: { links: normalizedLinks.length, certifications: certifications.length }, + found: { + links: normalizedLinks.length, + certifications: certifications.length, + }, }); return { @@ -185,7 +213,8 @@ Focus on the official website ${vendorWebsite} and its trust/security/compliance vendorName, vendorWebsite, lastResearchedAt: - normalizeIso(parsed.data.last_researched_at ?? null) ?? new Date().toISOString(), + normalizeIso(parsed.data.last_researched_at ?? null) ?? + new Date().toISOString(), riskLevel: parsed.data.risk_level ?? null, securityAssessment: parsed.data.security_assessment ?? null, certifications: certifications.length > 0 ? certifications : null, diff --git a/apps/api/src/trigger/vendor/vendor-risk-assessment/firecrawl-agent-news.ts b/apps/api/src/trigger/vendor/vendor-risk-assessment/firecrawl-agent-news.ts index 11ca1bff65..56b7154dcc 100644 --- a/apps/api/src/trigger/vendor/vendor-risk-assessment/firecrawl-agent-news.ts +++ b/apps/api/src/trigger/vendor/vendor-risk-assessment/firecrawl-agent-news.ts @@ -13,7 +13,8 @@ const newsResponseSchema = { properties: { news: { type: 'array' as const, - description: 'Recent news articles about the company from the last 12 months, ordered by date descending', + description: + 'Recent news articles about the company from the last 12 months, ordered by date descending', items: { type: 'object' as const, properties: { @@ -31,7 +32,8 @@ const newsResponseSchema = { }, source: { type: 'string' as const, - description: 'Publication name, e.g. TechCrunch, Reuters, company blog', + description: + 'Publication name, e.g. TechCrunch, Reuters, company blog', }, url: { type: 'string' as const, @@ -40,7 +42,8 @@ const newsResponseSchema = { sentiment: { type: 'string' as const, enum: ['positive', 'negative', 'neutral'], - description: 'Whether the news is positive (funding, partnerships), negative (breaches, lawsuits), or neutral', + description: + 'Whether the news is positive (funding, partnerships), negative (breaches, lawsuits), or neutral', }, }, required: ['date', 'title'], @@ -84,7 +87,11 @@ Search the company's blog, newsroom, press releases, and reputable tech news sou schema: newsResponseSchema, }); } catch (error) { - return handleFirecrawlError(error, { vendorName, vendorWebsite, callType: 'news' }); + return handleFirecrawlError(error, { + vendorName, + vendorWebsite, + callType: 'news', + }); } if (!agentResponse.success || agentResponse.status === 'failed') { @@ -96,10 +103,14 @@ Search the company's blog, newsroom, press releases, and reputable tech news sou return null; } - const data = agentResponse.data as { news?: Array> } | undefined; + const data = agentResponse.data as + | { news?: Array> } + | undefined; const rawNews = data?.news; if (!Array.isArray(rawNews) || rawNews.length === 0) { - logger.info('Firecrawl news research returned no news items', { vendorWebsite }); + logger.info('Firecrawl news research returned no news items', { + vendorWebsite, + }); return null; } @@ -107,14 +118,17 @@ Search the company's blog, newsroom, press releases, and reputable tech news sou .flatMap((n) => { const isoDate = normalizeIso(n.date as string | undefined); if (!isoDate) return []; - return [{ - date: isoDate, - title: (n.title as string) ?? '', - summary: (n.summary as string) ?? null, - source: (n.source as string) ?? null, - url: normalizeUrl(n.url as string | undefined), - sentiment: (n.sentiment as 'positive' | 'negative' | 'neutral') ?? null, - }]; + return [ + { + date: isoDate, + title: (n.title as string) ?? '', + summary: (n.summary as string) ?? null, + source: (n.source as string) ?? null, + url: normalizeUrl(n.url as string | undefined), + sentiment: + (n.sentiment as 'positive' | 'negative' | 'neutral') ?? null, + }, + ]; }) .filter(Boolean); diff --git a/apps/api/src/trigger/vendor/vendor-risk-assessment/firecrawl-agent-shared.ts b/apps/api/src/trigger/vendor/vendor-risk-assessment/firecrawl-agent-shared.ts index 13e53b41ee..347df01eaa 100644 --- a/apps/api/src/trigger/vendor/vendor-risk-assessment/firecrawl-agent-shared.ts +++ b/apps/api/src/trigger/vendor/vendor-risk-assessment/firecrawl-agent-shared.ts @@ -43,7 +43,9 @@ export function setupFirecrawlClient(params: { }): FirecrawlSetup | null { const apiKey = process.env.FIRECRAWL_API_KEY; if (!apiKey) { - logger.warn('FIRECRAWL_API_KEY is not configured; skipping vendor research'); + logger.warn( + 'FIRECRAWL_API_KEY is not configured; skipping vendor research', + ); return null; } @@ -95,11 +97,14 @@ export function handleFirecrawlError( message.includes('Too Many Requests'); if (isBillingOrRateLimit) { - logger.error(`Firecrawl API billing or rate limit error (${context.callType})`, { - vendorName: context.vendorName, - vendorWebsite: context.vendorWebsite, - error: message, - }); + logger.error( + `Firecrawl API billing or rate limit error (${context.callType})`, + { + vendorName: context.vendorName, + vendorWebsite: context.vendorWebsite, + error: message, + }, + ); throw error; } diff --git a/apps/api/src/trigger/vendor/vendor-risk-assessment/firecrawl-agent.ts b/apps/api/src/trigger/vendor/vendor-risk-assessment/firecrawl-agent.ts index 1ab6303a1b..349f3e9930 100644 --- a/apps/api/src/trigger/vendor/vendor-risk-assessment/firecrawl-agent.ts +++ b/apps/api/src/trigger/vendor/vendor-risk-assessment/firecrawl-agent.ts @@ -179,8 +179,7 @@ Focus on their official website ${vendorWebsite} (especially trust/security/comp }, }); } catch (error) { - const message = - error instanceof Error ? error.message : String(error); + const message = error instanceof Error ? error.message : String(error); const isBillingOrRateLimit = message.includes('402') || message.includes('429') || diff --git a/apps/api/src/trigger/vendor/vendor-risk-assessment/url-validation.spec.ts b/apps/api/src/trigger/vendor/vendor-risk-assessment/url-validation.spec.ts index a026c95cd6..281a40ec72 100644 --- a/apps/api/src/trigger/vendor/vendor-risk-assessment/url-validation.spec.ts +++ b/apps/api/src/trigger/vendor/vendor-risk-assessment/url-validation.spec.ts @@ -17,15 +17,15 @@ describe('isUrlFromVendorDomain', () => { }); it('accepts www subdomain', () => { - expect( - isUrlFromVendorDomain('https://www.wix.com/terms', 'wix.com'), - ).toBe(true); + expect(isUrlFromVendorDomain('https://www.wix.com/terms', 'wix.com')).toBe( + true, + ); }); it('accepts other subdomains', () => { - expect( - isUrlFromVendorDomain('https://trust.wix.com', 'wix.com'), - ).toBe(true); + expect(isUrlFromVendorDomain('https://trust.wix.com', 'wix.com')).toBe( + true, + ); expect( isUrlFromVendorDomain('https://security.wix.com/page', 'wix.com'), ).toBe(true); @@ -35,25 +35,25 @@ describe('isUrlFromVendorDomain', () => { expect(isUrlFromVendorDomain('https://x.com/privacy', 'wix.com')).toBe( false, ); - expect( - isUrlFromVendorDomain('https://twitter.com/wix', 'wix.com'), - ).toBe(false); + expect(isUrlFromVendorDomain('https://twitter.com/wix', 'wix.com')).toBe( + false, + ); }); it('rejects domains that end with vendor domain but are different', () => { // "notwix.com" ends with "wix.com" as a string, but is a different domain - expect( - isUrlFromVendorDomain('https://notwix.com/privacy', 'wix.com'), - ).toBe(false); + expect(isUrlFromVendorDomain('https://notwix.com/privacy', 'wix.com')).toBe( + false, + ); }); it('is case-insensitive', () => { expect( isUrlFromVendorDomain('https://WWW.WIX.COM/privacy', 'wix.com'), ).toBe(true); - expect( - isUrlFromVendorDomain('https://wix.com/privacy', 'WIX.COM'), - ).toBe(true); + expect(isUrlFromVendorDomain('https://wix.com/privacy', 'WIX.COM')).toBe( + true, + ); }); it('returns false for invalid URLs', () => { @@ -84,7 +84,9 @@ describe('extractVendorDomain', () => { it('extracts root domain from subdomain websites', () => { expect(extractVendorDomain('https://app.slack.com')).toBe('slack.com'); expect(extractVendorDomain('https://trust.wix.com')).toBe('wix.com'); - expect(extractVendorDomain('https://dashboard.stripe.com')).toBe('stripe.com'); + expect(extractVendorDomain('https://dashboard.stripe.com')).toBe( + 'stripe.com', + ); }); it('extracts root domain from multi-level subdomains', () => { @@ -92,16 +94,20 @@ describe('extractVendorDomain', () => { }); it('handles two-part TLDs correctly', () => { - expect(extractVendorDomain('https://app.example.co.uk')).toBe('example.co.uk'); - expect(extractVendorDomain('https://www.example.com.au')).toBe('example.com.au'); + expect(extractVendorDomain('https://app.example.co.uk')).toBe( + 'example.co.uk', + ); + expect(extractVendorDomain('https://www.example.com.au')).toBe( + 'example.com.au', + ); }); }); describe('validateVendorUrl', () => { it('returns normalized URL for valid vendor URLs', () => { - expect(validateVendorUrl('https://wix.com/privacy', 'wix.com', 'privacy')).toBe( - 'https://wix.com/privacy', - ); + expect( + validateVendorUrl('https://wix.com/privacy', 'wix.com', 'privacy'), + ).toBe('https://wix.com/privacy'); }); it('accepts URLs from any domain (domain filtering removed — trusts AI agent)', () => { @@ -123,9 +129,9 @@ describe('validateVendorUrl', () => { }); it('accepts subdomain URLs', () => { - expect( - validateVendorUrl('https://trust.wix.com', 'wix.com', 'trust'), - ).toBe('https://trust.wix.com/'); + expect(validateVendorUrl('https://trust.wix.com', 'wix.com', 'trust')).toBe( + 'https://trust.wix.com/', + ); }); it('accepts parent domain URLs when vendor website is a subdomain', () => { diff --git a/apps/api/src/trigger/vendor/vendor-risk-assessment/url-validation.ts b/apps/api/src/trigger/vendor/vendor-risk-assessment/url-validation.ts index ee7467936d..454daf6bf4 100644 --- a/apps/api/src/trigger/vendor/vendor-risk-assessment/url-validation.ts +++ b/apps/api/src/trigger/vendor/vendor-risk-assessment/url-validation.ts @@ -3,10 +3,10 @@ import { getDomain } from 'tldts'; // Well-known trust portal domains that vendors use to host their security pages const TRUSTED_PORTAL_DOMAINS = [ - 'trust.page', // SafeBase - 'vanta.com', // Vanta trust centers - 'drata.com', // Drata trust centers - 'safebase.io', // SafeBase + 'trust.page', // SafeBase + 'vanta.com', // Vanta trust centers + 'drata.com', // Drata trust centers + 'safebase.io', // SafeBase 'securityscorecard.com', 'whistic.com', 'conveyor.com', @@ -32,7 +32,7 @@ export function isUrlFromVendorDomain( const parsed = new URL(url); const hostname = parsed.hostname.toLowerCase(); const domain = vendorDomain.toLowerCase(); - const vendorName = domain.split('.')[0]!; // "github" from "github.com" + const vendorName = domain.split('.')[0]; // "github" from "github.com" // Direct match: github.com or *.github.com if (hostname === domain || hostname.endsWith(`.${domain}`)) { @@ -60,9 +60,7 @@ export function isUrlFromVendorDomain( * For example, "https://app.slack.com" → "slack.com". * Returns null if the URL is invalid. */ -export function extractVendorDomain( - website: string, -): string | null { +export function extractVendorDomain(website: string): string | null { try { const urlObj = new URL( /^https?:\/\//i.test(website) ? website : `https://${website}`, diff --git a/apps/api/src/trust-portal/policy-pdf-renderer.service.spec.ts b/apps/api/src/trust-portal/policy-pdf-renderer.service.spec.ts index f2508b48ab..4dcc2e8186 100644 --- a/apps/api/src/trust-portal/policy-pdf-renderer.service.spec.ts +++ b/apps/api/src/trust-portal/policy-pdf-renderer.service.spec.ts @@ -200,7 +200,7 @@ describe('PolicyPdfRendererService', () => { const result = service.renderPoliciesPdfBuffer( [ { - name: 'Politique d\'Authentification', + name: "Politique d'Authentification", content: { type: 'doc', content: [ @@ -209,7 +209,7 @@ describe('PolicyPdfRendererService', () => { content: [ { type: 'text', - text: '🇫🇷 Résumé des règles d\'authentification café', + text: "🇫🇷 Résumé des règles d'authentification café", }, ], }, diff --git a/apps/api/src/trust-portal/policy-pdf-renderer.service.ts b/apps/api/src/trust-portal/policy-pdf-renderer.service.ts index 7cf0ee28e0..d2bf5e43e2 100644 --- a/apps/api/src/trust-portal/policy-pdf-renderer.service.ts +++ b/apps/api/src/trust-portal/policy-pdf-renderer.service.ts @@ -92,13 +92,27 @@ export class PolicyPdfRendererService { ); const replacements: { [key: string]: string } = { - '\u2018': "'", '\u2019': "'", '\u201C': '"', '\u201D': '"', - '\u2013': '-', '\u2014': '-', '\u2026': '...', - '\u2265': '>=', '\u2264': '<=', '\u00B0': 'deg', - '\u00A9': '(c)', '\u00AE': '(R)', '\u2122': 'TM', - '\u00A0': ' ', '\u2022': '•', '\u00B1': '+/-', - '\u00D7': 'x', '\u00F7': '/', '\u2192': '->', - '\u2190': '<-', '\u2194': '<->', + '\u2018': "'", + '\u2019': "'", + '\u201C': '"', + '\u201D': '"', + '\u2013': '-', + '\u2014': '-', + '\u2026': '...', + '\u2265': '>=', + '\u2264': '<=', + '\u00B0': 'deg', + '\u00A9': '(c)', + '\u00AE': '(R)', + '\u2122': 'TM', + '\u00A0': ' ', + '\u2022': '•', + '\u00B1': '+/-', + '\u00D7': 'x', + '\u00F7': '/', + '\u2192': '->', + '\u2190': '<-', + '\u2194': '<->', }; let cleanedText = strippedText; @@ -113,14 +127,63 @@ export class PolicyPdfRendererService { return char; } const fallbacks: { [key: string]: string } = { - à: 'a', á: 'a', â: 'a', ã: 'a', ä: 'a', å: 'a', æ: 'ae', - è: 'e', é: 'e', ê: 'e', ë: 'e', ì: 'i', í: 'i', î: 'i', ï: 'i', - ò: 'o', ó: 'o', ô: 'o', õ: 'o', ö: 'o', ø: 'o', - ù: 'u', ú: 'u', û: 'u', ü: 'u', ñ: 'n', ç: 'c', ß: 'ss', ÿ: 'y', - À: 'A', Á: 'A', Â: 'A', Ã: 'A', Ä: 'A', Å: 'A', Æ: 'AE', - È: 'E', É: 'E', Ê: 'E', Ë: 'E', Ì: 'I', Í: 'I', Î: 'I', Ï: 'I', - Ò: 'O', Ó: 'O', Ô: 'O', Õ: 'O', Ö: 'O', Ø: 'O', - Ù: 'U', Ú: 'U', Û: 'U', Ü: 'U', Ñ: 'N', Ç: 'C', Ý: 'Y', + à: 'a', + á: 'a', + â: 'a', + ã: 'a', + ä: 'a', + å: 'a', + æ: 'ae', + è: 'e', + é: 'e', + ê: 'e', + ë: 'e', + ì: 'i', + í: 'i', + î: 'i', + ï: 'i', + ò: 'o', + ó: 'o', + ô: 'o', + õ: 'o', + ö: 'o', + ø: 'o', + ù: 'u', + ú: 'u', + û: 'u', + ü: 'u', + ñ: 'n', + ç: 'c', + ß: 'ss', + ÿ: 'y', + À: 'A', + Á: 'A', + Â: 'A', + Ã: 'A', + Ä: 'A', + Å: 'A', + Æ: 'AE', + È: 'E', + É: 'E', + Ê: 'E', + Ë: 'E', + Ì: 'I', + Í: 'I', + Î: 'I', + Ï: 'I', + Ò: 'O', + Ó: 'O', + Ô: 'O', + Õ: 'O', + Ö: 'O', + Ø: 'O', + Ù: 'U', + Ú: 'U', + Û: 'U', + Ü: 'U', + Ñ: 'N', + Ç: 'C', + Ý: 'Y', }; return fallbacks[char] ?? char; }); diff --git a/apps/api/src/trust-portal/trust-access.controller.spec.ts b/apps/api/src/trust-portal/trust-access.controller.spec.ts index 4e165fc0bc..cb83ddaa82 100644 --- a/apps/api/src/trust-portal/trust-access.controller.spec.ts +++ b/apps/api/src/trust-portal/trust-access.controller.spec.ts @@ -87,7 +87,11 @@ describe('TrustAccessController', () => { const req = mockRequest(); mockService.createAccessRequest.mockResolvedValue({ id: 'req_1' }); - const result = await controller.createAccessRequest('my-portal', dto, req); + const result = await controller.createAccessRequest( + 'my-portal', + dto, + req, + ); expect(result).toEqual({ id: 'req_1' }); expect(service.createAccessRequest).toHaveBeenCalledWith( @@ -134,7 +138,10 @@ describe('TrustAccessController', () => { const result = await controller.approveRequest(orgId, 'req_1', dto, req); expect(result).toEqual({ success: true }); - expect(service.getMemberIdFromUserId).toHaveBeenCalledWith('user_1', orgId); + expect(service.getMemberIdFromUserId).toHaveBeenCalledWith( + 'user_1', + orgId, + ); expect(service.approveRequest).toHaveBeenCalledWith( orgId, 'req_1', @@ -163,7 +170,10 @@ describe('TrustAccessController', () => { const result = await controller.denyRequest(orgId, 'req_1', dto, req); expect(result).toEqual({ success: true }); - expect(service.getMemberIdFromUserId).toHaveBeenCalledWith('user_1', orgId); + expect(service.getMemberIdFromUserId).toHaveBeenCalledWith( + 'user_1', + orgId, + ); expect(service.denyRequest).toHaveBeenCalledWith( orgId, 'req_1', @@ -204,7 +214,10 @@ describe('TrustAccessController', () => { const result = await controller.revokeGrant(orgId, 'grant_1', dto, req); expect(result).toEqual({ success: true }); - expect(service.getMemberIdFromUserId).toHaveBeenCalledWith('user_1', orgId); + expect(service.getMemberIdFromUserId).toHaveBeenCalledWith( + 'user_1', + orgId, + ); expect(service.revokeGrant).toHaveBeenCalledWith( orgId, 'grant_1', @@ -230,7 +243,10 @@ describe('TrustAccessController', () => { const result = await controller.resendAccessEmail(orgId, 'grant_1'); expect(result).toEqual({ success: true }); - expect(service.resendAccessGrantEmail).toHaveBeenCalledWith(orgId, 'grant_1'); + expect(service.resendAccessGrantEmail).toHaveBeenCalledWith( + orgId, + 'grant_1', + ); }); }); @@ -314,7 +330,11 @@ describe('TrustAccessController', () => { const dto = { email: 'user@example.com' }; mockService.reclaimAccess.mockResolvedValue({ success: true }); - const result = await controller.reclaimAccess('my-portal', dto as any, 'security-questionnaire'); + const result = await controller.reclaimAccess( + 'my-portal', + dto as any, + 'security-questionnaire', + ); expect(result).toEqual({ success: true }); expect(service.reclaimAccess).toHaveBeenCalledWith( @@ -358,43 +378,58 @@ describe('TrustAccessController', () => { const result = await controller.getPoliciesByAccessToken('token_abc'); expect(result).toEqual(mockResult); - expect(service.getPoliciesByAccessToken).toHaveBeenCalledWith('token_abc'); + expect(service.getPoliciesByAccessToken).toHaveBeenCalledWith( + 'token_abc', + ); }); }); describe('downloadAllPolicies', () => { it('should call service.downloadAllPoliciesByAccessToken with token', async () => { const mockResult = { url: 'https://download-url' }; - mockService.downloadAllPoliciesByAccessToken.mockResolvedValue(mockResult); + mockService.downloadAllPoliciesByAccessToken.mockResolvedValue( + mockResult, + ); const result = await controller.downloadAllPolicies('token_abc'); expect(result).toEqual(mockResult); - expect(service.downloadAllPoliciesByAccessToken).toHaveBeenCalledWith('token_abc'); + expect(service.downloadAllPoliciesByAccessToken).toHaveBeenCalledWith( + 'token_abc', + ); }); }); describe('downloadAllPoliciesAsZip', () => { it('should call service.downloadAllPoliciesAsZipByAccessToken with token', async () => { const mockResult = { url: 'https://zip-url' }; - mockService.downloadAllPoliciesAsZipByAccessToken.mockResolvedValue(mockResult); + mockService.downloadAllPoliciesAsZipByAccessToken.mockResolvedValue( + mockResult, + ); const result = await controller.downloadAllPoliciesAsZip('token_abc'); expect(result).toEqual(mockResult); - expect(service.downloadAllPoliciesAsZipByAccessToken).toHaveBeenCalledWith('token_abc'); + expect( + service.downloadAllPoliciesAsZipByAccessToken, + ).toHaveBeenCalledWith('token_abc'); }); }); describe('getComplianceResourcesByAccessToken', () => { it('should call service.getComplianceResourcesByAccessToken with token', async () => { const mockResult = [{ id: 'cr_1' }]; - mockService.getComplianceResourcesByAccessToken.mockResolvedValue(mockResult); + mockService.getComplianceResourcesByAccessToken.mockResolvedValue( + mockResult, + ); - const result = await controller.getComplianceResourcesByAccessToken('token_abc'); + const result = + await controller.getComplianceResourcesByAccessToken('token_abc'); expect(result).toEqual(mockResult); - expect(service.getComplianceResourcesByAccessToken).toHaveBeenCalledWith('token_abc'); + expect(service.getComplianceResourcesByAccessToken).toHaveBeenCalledWith( + 'token_abc', + ); }); }); @@ -403,46 +438,68 @@ describe('TrustAccessController', () => { const mockResult = [{ id: 'td_1' }]; mockService.getTrustDocumentsByAccessToken.mockResolvedValue(mockResult); - const result = await controller.getTrustDocumentsByAccessToken('token_abc'); + const result = + await controller.getTrustDocumentsByAccessToken('token_abc'); expect(result).toEqual(mockResult); - expect(service.getTrustDocumentsByAccessToken).toHaveBeenCalledWith('token_abc'); + expect(service.getTrustDocumentsByAccessToken).toHaveBeenCalledWith( + 'token_abc', + ); }); }); describe('downloadAllTrustDocuments', () => { it('should call service.downloadAllTrustDocumentsByAccessToken with token', async () => { const mockResult = { url: 'https://zip-url' }; - mockService.downloadAllTrustDocumentsByAccessToken.mockResolvedValue(mockResult); + mockService.downloadAllTrustDocumentsByAccessToken.mockResolvedValue( + mockResult, + ); const result = await controller.downloadAllTrustDocuments('token_abc'); expect(result).toEqual(mockResult); - expect(service.downloadAllTrustDocumentsByAccessToken).toHaveBeenCalledWith('token_abc'); + expect( + service.downloadAllTrustDocumentsByAccessToken, + ).toHaveBeenCalledWith('token_abc'); }); }); describe('getTrustDocumentUrlByAccessToken', () => { it('should call service with token and documentId', async () => { const mockResult = { url: 'https://signed-url' }; - mockService.getTrustDocumentUrlByAccessToken.mockResolvedValue(mockResult); + mockService.getTrustDocumentUrlByAccessToken.mockResolvedValue( + mockResult, + ); - const result = await controller.getTrustDocumentUrlByAccessToken('token_abc', 'tdoc_1'); + const result = await controller.getTrustDocumentUrlByAccessToken( + 'token_abc', + 'tdoc_1', + ); expect(result).toEqual(mockResult); - expect(service.getTrustDocumentUrlByAccessToken).toHaveBeenCalledWith('token_abc', 'tdoc_1'); + expect(service.getTrustDocumentUrlByAccessToken).toHaveBeenCalledWith( + 'token_abc', + 'tdoc_1', + ); }); }); describe('getComplianceResourceUrlByAccessToken', () => { it('should call service with token and framework', async () => { const mockResult = { url: 'https://signed-url' }; - mockService.getComplianceResourceUrlByAccessToken.mockResolvedValue(mockResult); + mockService.getComplianceResourceUrlByAccessToken.mockResolvedValue( + mockResult, + ); - const result = await controller.getComplianceResourceUrlByAccessToken('token_abc', 'SOC2'); + const result = await controller.getComplianceResourceUrlByAccessToken( + 'token_abc', + 'SOC2', + ); expect(result).toEqual(mockResult); - expect(service.getComplianceResourceUrlByAccessToken).toHaveBeenCalledWith('token_abc', 'SOC2'); + expect( + service.getComplianceResourceUrlByAccessToken, + ).toHaveBeenCalledWith('token_abc', 'SOC2'); }); }); diff --git a/apps/api/src/trust-portal/trust-access.service.spec.ts b/apps/api/src/trust-portal/trust-access.service.spec.ts new file mode 100644 index 0000000000..19e80f3c43 --- /dev/null +++ b/apps/api/src/trust-portal/trust-access.service.spec.ts @@ -0,0 +1,174 @@ +import { GetObjectCommand } from '@aws-sdk/client-s3'; +import { db } from '@db'; +import { getSignedUrl } from '../app/s3'; +import { TrustAccessService } from './trust-access.service'; + +jest.mock('@db', () => ({ + db: { + trust: { + findUnique: jest.fn(), + upsert: jest.fn(), + }, + trustNDAAgreement: { + findUnique: jest.fn(), + }, + trustAccessGrant: { + findUnique: jest.fn(), + }, + }, + Prisma: { + PrismaClientKnownRequestError: class PrismaClientKnownRequestError extends Error { + code: string; + + constructor(code: string) { + super(); + this.code = code; + } + }, + }, + TrustFramework: { + iso_27001: 'iso_27001', + iso_42001: 'iso_42001', + gdpr: 'gdpr', + hipaa: 'hipaa', + soc2_type1: 'soc2_type1', + soc2_type2: 'soc2_type2', + pci_dss: 'pci_dss', + nen_7510: 'nen_7510', + iso_9001: 'iso_9001', + }, +})); + +jest.mock('../app/s3', () => ({ + APP_AWS_ORG_ASSETS_BUCKET: 'org-assets', + s3Client: { send: jest.fn() }, + getSignedUrl: jest.fn(), +})); + +const mockDb = db as unknown as { + trust: { + findUnique: jest.Mock; + upsert: jest.Mock; + }; + trustNDAAgreement: { + findUnique: jest.Mock; + }; + trustAccessGrant: { + findUnique: jest.Mock; + }; +}; + +const mockGetSignedUrl = getSignedUrl as jest.MockedFunction< + typeof getSignedUrl +>; + +describe('TrustAccessService favicon branding', () => { + const service = new TrustAccessService( + { + getSignedUrl: jest.fn(), + } as any, + {} as any, + {} as any, + {} as any, + ); + + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('falls back to organizationId lookup when getPublicFavicon route id is not a friendlyUrl', async () => { + mockDb.trust.findUnique.mockResolvedValueOnce(null).mockResolvedValueOnce({ + favicon: 'org_123/trust/favicon/icon.png', + }); + mockGetSignedUrl.mockResolvedValue('https://cdn.example.com/favicon.png'); + + const result = await service.getPublicFavicon('org_123'); + + expect(mockDb.trust.findUnique).toHaveBeenNthCalledWith(1, { + where: { friendlyUrl: 'org_123' }, + select: { favicon: true }, + }); + expect(mockDb.trust.findUnique).toHaveBeenNthCalledWith(2, { + where: { organizationId: 'org_123' }, + select: { favicon: true }, + }); + expect(result).toBe('https://cdn.example.com/favicon.png'); + expect(mockGetSignedUrl).toHaveBeenCalledTimes(1); + expect(mockGetSignedUrl.mock.calls[0][1]).toBeInstanceOf(GetObjectCommand); + }); + + it('includes friendlyUrl and faviconUrl in getGrantByAccessToken response', async () => { + const futureDate = new Date(Date.now() + 60 * 60 * 1000); + + mockDb.trustAccessGrant.findUnique.mockResolvedValue({ + id: 'grant_1', + status: 'active', + expiresAt: futureDate, + accessTokenExpiresAt: futureDate, + subjectEmail: 'alice@example.com', + accessRequest: { + organizationId: 'org_123', + name: 'Alice', + organization: { + name: 'Acme Security', + }, + }, + ndaAgreement: null, + }); + mockDb.trust.findUnique.mockResolvedValue({ + friendlyUrl: 'acme-security', + favicon: 'org_123/trust/favicon/icon.png', + }); + mockGetSignedUrl.mockResolvedValue('https://cdn.example.com/favicon.png'); + + const result = await service.getGrantByAccessToken('grant-token'); + + expect(result).toMatchObject({ + organizationName: 'Acme Security', + friendlyUrl: 'acme-security', + faviconUrl: 'https://cdn.example.com/favicon.png', + subjectEmail: 'alice@example.com', + }); + }); + + it('includes friendlyUrl and faviconUrl in getNdaByToken response', async () => { + const futureDate = new Date(Date.now() + 60 * 60 * 1000); + + mockDb.trustNDAAgreement.findUnique.mockResolvedValue({ + id: 'nda_1', + organizationId: 'org_123', + signTokenExpiresAt: futureDate, + status: 'pending', + accessRequest: { + name: 'Alice', + email: 'alice@example.com', + organization: { + name: 'Acme Security', + }, + }, + grant: null, + }); + mockDb.trust.findUnique + .mockResolvedValueOnce({ + domain: null, + domainVerified: false, + friendlyUrl: 'acme-security', + }) + .mockResolvedValueOnce({ + friendlyUrl: 'acme-security', + favicon: 'org_123/trust/favicon/icon.png', + }); + mockGetSignedUrl.mockResolvedValue('https://cdn.example.com/favicon.png'); + + const result = await service.getNdaByToken('nda-token'); + + expect(result).toMatchObject({ + id: 'nda_1', + status: 'pending', + organizationName: 'Acme Security', + friendlyUrl: 'acme-security', + faviconUrl: 'https://cdn.example.com/favicon.png', + }); + expect(result.portalUrl).toContain('/acme-security'); + }); +}); diff --git a/apps/api/src/trust-portal/trust-access.service.ts b/apps/api/src/trust-portal/trust-access.service.ts index b29de78d15..237aba4c0f 100644 --- a/apps/api/src/trust-portal/trust-access.service.ts +++ b/apps/api/src/trust-portal/trust-access.service.ts @@ -18,8 +18,7 @@ import { NdaPdfService } from './nda-pdf.service'; import { AttachmentsService } from '../attachments/attachments.service'; import { PolicyPdfRendererService } from './policy-pdf-renderer.service'; import { GetObjectCommand, PutObjectCommand } from '@aws-sdk/client-s3'; -import { getSignedUrl } from '@aws-sdk/s3-request-presigner'; -import { APP_AWS_ORG_ASSETS_BUCKET, s3Client } from '../app/s3'; +import { APP_AWS_ORG_ASSETS_BUCKET, s3Client, getSignedUrl } from '../app/s3'; import { Prisma, TrustFramework } from '@db'; import archiver from 'archiver'; import { PassThrough, Readable } from 'stream'; @@ -987,10 +986,15 @@ export class TrustAccessService { organizationId: nda.organizationId, organizationName: nda.accessRequest.organization.name, }); + const branding = await this.getTrustBrandingByOrganizationId( + nda.organizationId, + ); const baseResponse = { id: nda.id, organizationName: nda.accessRequest.organization.name, + friendlyUrl: branding.friendlyUrl, + faviconUrl: branding.faviconUrl, requesterName: nda.accessRequest.name, requesterEmail: nda.accessRequest.email, expiresAt: nda.signTokenExpiresAt, @@ -1435,15 +1439,54 @@ export class TrustAccessService { const ndaPdfUrl = grant.ndaAgreement?.pdfSignedKey ? await this.ndaPdfService.getSignedUrl(grant.ndaAgreement.pdfSignedKey) : null; + const branding = await this.getTrustBrandingByOrganizationId( + grant.accessRequest.organizationId, + ); return { organizationName: grant.accessRequest.organization.name, + friendlyUrl: branding.friendlyUrl, + faviconUrl: branding.faviconUrl, expiresAt: grant.expiresAt, subjectEmail: grant.subjectEmail, ndaPdfUrl, }; } + private async getTrustBrandingByOrganizationId( + organizationId: string, + ): Promise<{ friendlyUrl: string; faviconUrl: string | null }> { + const trust = await db.trust.findUnique({ + where: { organizationId }, + select: { friendlyUrl: true, favicon: true }, + }); + + const friendlyUrl = trust?.friendlyUrl ?? organizationId; + const faviconUrl = trust?.favicon + ? await this.getFaviconSignedUrl(trust.favicon) + : null; + + return { friendlyUrl, faviconUrl }; + } + + private async getFaviconSignedUrl( + faviconKey: string, + ): Promise { + if (!s3Client || !APP_AWS_ORG_ASSETS_BUCKET) { + return null; + } + + try { + const command = new GetObjectCommand({ + Bucket: APP_AWS_ORG_ASSETS_BUCKET, + Key: faviconKey, + }); + return await getSignedUrl(s3Client, command, { expiresIn: 86400 }); // 24 hours + } catch { + return null; + } + } + async validateAccessTokenAndGetOrganizationId( token: string, ): Promise { @@ -2429,24 +2472,23 @@ export class TrustAccessService { } async getPublicFavicon(friendlyUrl: string): Promise { - const trust = await db.trust.findUnique({ + let trust = await db.trust.findUnique({ where: { friendlyUrl }, select: { favicon: true }, }); - if (!trust?.favicon || !s3Client || !APP_AWS_ORG_ASSETS_BUCKET) { - return null; + if (!trust) { + trust = await db.trust.findUnique({ + where: { organizationId: friendlyUrl }, + select: { favicon: true }, + }); } - try { - const command = new GetObjectCommand({ - Bucket: APP_AWS_ORG_ASSETS_BUCKET, - Key: trust.favicon, - }); - return await getSignedUrl(s3Client, command, { expiresIn: 86400 }); // 24 hours - } catch { + if (!trust?.favicon) { return null; } + + return this.getFaviconSignedUrl(trust.favicon); } async getPublicVendors(friendlyUrl: string) { diff --git a/apps/api/src/trust-portal/trust-portal.controller.spec.ts b/apps/api/src/trust-portal/trust-portal.controller.spec.ts index 368c64edef..142fc1bc61 100644 --- a/apps/api/src/trust-portal/trust-portal.controller.spec.ts +++ b/apps/api/src/trust-portal/trust-portal.controller.spec.ts @@ -99,7 +99,11 @@ describe('TrustPortalController', () => { describe('uploadFavicon', () => { it('should call service.uploadFavicon with organizationId and body', async () => { - const body = { fileName: 'fav.png', fileType: 'image/png', fileData: 'base64data' }; + const body = { + fileName: 'fav.png', + fileType: 'image/png', + fileData: 'base64data', + }; const mockResult = { url: 'https://example.com/fav.png' }; mockService.uploadFavicon.mockResolvedValue(mockResult); @@ -140,7 +144,10 @@ describe('TrustPortalController', () => { const mockResult = { id: 'cr_1' }; mockService.uploadComplianceResource.mockResolvedValue(mockResult); - const result = await controller.uploadComplianceResource(dto, authContext); + const result = await controller.uploadComplianceResource( + dto, + authContext, + ); expect(result).toEqual(mockResult); expect(service.uploadComplianceResource).toHaveBeenCalledWith(dto); @@ -161,7 +168,10 @@ describe('TrustPortalController', () => { const mockResult = { url: 'https://signed-url' }; mockService.getComplianceResourceUrl.mockResolvedValue(mockResult); - const result = await controller.getComplianceResourceUrl(dto, authContext); + const result = await controller.getComplianceResourceUrl( + dto, + authContext, + ); expect(result).toEqual(mockResult); expect(service.getComplianceResourceUrl).toHaveBeenCalledWith(dto); @@ -182,7 +192,10 @@ describe('TrustPortalController', () => { const mockResult = [{ id: 'cr_1' }]; mockService.listComplianceResources.mockResolvedValue(mockResult); - const result = await controller.listComplianceResources(dto as any, authContext); + const result = await controller.listComplianceResources( + dto as any, + authContext, + ); expect(result).toEqual(mockResult); expect(service.listComplianceResources).toHaveBeenCalledWith(orgId); @@ -224,7 +237,10 @@ describe('TrustPortalController', () => { const mockResult = [{ id: 'td_1' }]; mockService.listTrustDocuments.mockResolvedValue(mockResult); - const result = await controller.listTrustDocuments(dto as any, authContext); + const result = await controller.listTrustDocuments( + dto as any, + authContext, + ); expect(result).toEqual(mockResult); expect(service.listTrustDocuments).toHaveBeenCalledWith(orgId); @@ -238,7 +254,11 @@ describe('TrustPortalController', () => { const mockResult = { url: 'https://signed-url' }; mockService.getTrustDocumentUrl.mockResolvedValue(mockResult); - const result = await controller.getTrustDocumentUrl(dto, documentId, authContext); + const result = await controller.getTrustDocumentUrl( + dto, + documentId, + authContext, + ); expect(result).toEqual(mockResult); expect(service.getTrustDocumentUrl).toHaveBeenCalledWith(documentId, dto); @@ -259,7 +279,11 @@ describe('TrustPortalController', () => { const documentId = 'td_1'; mockService.deleteTrustDocument.mockResolvedValue({ success: true }); - const result = await controller.deleteTrustDocument(dto, documentId, authContext); + const result = await controller.deleteTrustDocument( + dto, + documentId, + authContext, + ); expect(result).toEqual({ success: true }); expect(service.deleteTrustDocument).toHaveBeenCalledWith(documentId, dto); @@ -276,7 +300,11 @@ describe('TrustPortalController', () => { describe('togglePortal', () => { it('should call service.togglePortal with correct params', async () => { - const body = { enabled: true, contactEmail: 'test@example.com', primaryColor: '#000' }; + const body = { + enabled: true, + contactEmail: 'test@example.com', + primaryColor: '#000', + }; mockService.togglePortal.mockResolvedValue({ enabled: true }); const result = await controller.togglePortal(orgId, body); @@ -308,12 +336,17 @@ describe('TrustPortalController', () => { describe('addCustomDomain', () => { it('should call service.addCustomDomain with organizationId and domain', async () => { const body = { domain: 'trust.example.com' }; - mockService.addCustomDomain.mockResolvedValue({ domain: 'trust.example.com' }); + mockService.addCustomDomain.mockResolvedValue({ + domain: 'trust.example.com', + }); const result = await controller.addCustomDomain(orgId, body); expect(result).toEqual({ domain: 'trust.example.com' }); - expect(service.addCustomDomain).toHaveBeenCalledWith(orgId, 'trust.example.com'); + expect(service.addCustomDomain).toHaveBeenCalledWith( + orgId, + 'trust.example.com', + ); }); it('should throw BadRequestException when domain is empty', async () => { @@ -332,7 +365,10 @@ describe('TrustPortalController', () => { const result = await controller.checkDnsRecords(orgId, body); expect(result).toEqual(mockResult); - expect(service.checkDnsRecords).toHaveBeenCalledWith(orgId, 'trust.example.com'); + expect(service.checkDnsRecords).toHaveBeenCalledWith( + orgId, + 'trust.example.com', + ); }); it('should throw BadRequestException when domain is empty', async () => { @@ -396,7 +432,11 @@ describe('TrustPortalController', () => { describe('updateOverview', () => { it('should call service.updateOverview with organizationId and parsed dto', async () => { - const body = { organizationId: orgId, overviewTitle: 'Our Trust', overviewContent: 'Desc' }; + const body = { + organizationId: orgId, + overviewTitle: 'Our Trust', + overviewContent: 'Desc', + }; mockService.updateOverview.mockResolvedValue({ success: true }); const result = await controller.updateOverview(body as any, authContext); @@ -404,7 +444,10 @@ describe('TrustPortalController', () => { expect(result).toEqual({ success: true }); expect(service.updateOverview).toHaveBeenCalledWith( orgId, - expect.objectContaining({ overviewTitle: 'Our Trust', overviewContent: 'Desc' }), + expect.objectContaining({ + overviewTitle: 'Our Trust', + overviewContent: 'Desc', + }), ); }); @@ -437,10 +480,17 @@ describe('TrustPortalController', () => { describe('createCustomLink', () => { it('should call service.createCustomLink with organizationId and parsed dto', async () => { - const body = { organizationId: orgId, title: 'Link', url: 'https://example.com' }; + const body = { + organizationId: orgId, + title: 'Link', + url: 'https://example.com', + }; mockService.createCustomLink.mockResolvedValue({ id: 'cl_1' }); - const result = await controller.createCustomLink(body as any, authContext); + const result = await controller.createCustomLink( + body as any, + authContext, + ); expect(result).toEqual({ id: 'cl_1' }); expect(service.createCustomLink).toHaveBeenCalledWith( @@ -450,7 +500,11 @@ describe('TrustPortalController', () => { }); it('should throw BadRequestException for organization mismatch', async () => { - const body = { organizationId: orgId, title: 'Link', url: 'https://example.com' }; + const body = { + organizationId: orgId, + title: 'Link', + url: 'https://example.com', + }; await expect( controller.createCustomLink(body as any, otherOrgAuthContext), @@ -463,7 +517,11 @@ describe('TrustPortalController', () => { const body = { title: 'Updated', url: 'https://new.com' }; mockService.updateCustomLink.mockResolvedValue({ id: 'cl_1' }); - const result = await controller.updateCustomLink('cl_1', body as any, authContext); + const result = await controller.updateCustomLink( + 'cl_1', + body as any, + authContext, + ); expect(result).toEqual({ id: 'cl_1' }); expect(service.updateCustomLink).toHaveBeenCalledWith( @@ -490,10 +548,16 @@ describe('TrustPortalController', () => { const body = { organizationId: orgId, linkIds: ['cl_1', 'cl_2'] }; mockService.reorderCustomLinks.mockResolvedValue({ success: true }); - const result = await controller.reorderCustomLinks(body as any, authContext); + const result = await controller.reorderCustomLinks( + body as any, + authContext, + ); expect(result).toEqual({ success: true }); - expect(service.reorderCustomLinks).toHaveBeenCalledWith(orgId, ['cl_1', 'cl_2']); + expect(service.reorderCustomLinks).toHaveBeenCalledWith(orgId, [ + 'cl_1', + 'cl_2', + ]); }); it('should throw BadRequestException for organization mismatch', async () => { @@ -526,9 +590,15 @@ describe('TrustPortalController', () => { describe('updateVendorTrustSettings', () => { it('should call service.updateVendorTrustSettings with vendorId, dto, and organizationId', async () => { const body = { showOnTrustPortal: true }; - mockService.updateVendorTrustSettings.mockResolvedValue({ success: true }); + mockService.updateVendorTrustSettings.mockResolvedValue({ + success: true, + }); - const result = await controller.updateVendorTrustSettings('v_1', body as any, authContext); + const result = await controller.updateVendorTrustSettings( + 'v_1', + body as any, + authContext, + ); expect(result).toEqual({ success: true }); expect(service.updateVendorTrustSettings).toHaveBeenCalledWith( diff --git a/apps/api/src/trust-portal/trust-portal.controller.ts b/apps/api/src/trust-portal/trust-portal.controller.ts index bbbf2c0ecf..30f67b916d 100644 --- a/apps/api/src/trust-portal/trust-portal.controller.ts +++ b/apps/api/src/trust-portal/trust-portal.controller.ts @@ -332,7 +332,9 @@ export class TrustPortalController { @Post('settings/custom-domain') @RequirePermission('trust', 'update') - @ApiOperation({ summary: 'Add or update a custom domain for the trust portal' }) + @ApiOperation({ + summary: 'Add or update a custom domain for the trust portal', + }) async addCustomDomain( @OrganizationId() organizationId: string, @Body() body: { domain: string }, @@ -353,10 +355,7 @@ export class TrustPortalController { if (!body.domain) { throw new BadRequestException('Domain is required'); } - return this.trustPortalService.checkDnsRecords( - organizationId, - body.domain, - ); + return this.trustPortalService.checkDnsRecords(organizationId, body.domain); } @Put('settings/faqs') @@ -366,10 +365,7 @@ export class TrustPortalController { @OrganizationId() organizationId: string, @Body() body: { faqs: Array<{ question: string; answer: string }> }, ) { - return this.trustPortalService.updateFaqs( - organizationId, - body.faqs ?? [], - ); + return this.trustPortalService.updateFaqs(organizationId, body.faqs ?? []); } @Put('settings/allowed-domains') @@ -557,7 +553,11 @@ export class TrustPortalController { @ApiOperation({ summary: 'List vendors configured for trust portal', }) - @ApiQuery({ name: 'all', required: false, description: 'When true, returns all org vendors with sync' }) + @ApiQuery({ + name: 'all', + required: false, + description: 'When true, returns all org vendors with sync', + }) async listVendors( @OrganizationId() organizationId: string, @Query('all') all?: string, diff --git a/apps/api/src/trust-portal/trust-portal.service.ts b/apps/api/src/trust-portal/trust-portal.service.ts index 0fb95662d6..fc63fc07d8 100644 --- a/apps/api/src/trust-portal/trust-portal.service.ts +++ b/apps/api/src/trust-portal/trust-portal.service.ts @@ -11,7 +11,6 @@ import { GetObjectCommand, PutObjectCommand, } from '@aws-sdk/client-s3'; -import { getSignedUrl } from '@aws-sdk/s3-request-presigner'; import { db } from '@db'; import { DomainStatusResponseDto, @@ -25,7 +24,7 @@ import { UploadComplianceResourceDto, } from './dto/compliance-resource.dto'; import * as dns from 'node:dns'; -import { APP_AWS_ORG_ASSETS_BUCKET, s3Client } from '../app/s3'; +import { APP_AWS_ORG_ASSETS_BUCKET, s3Client, getSignedUrl } from '../app/s3'; import { DeleteTrustDocumentDto, TrustDocumentResponseDto, @@ -101,7 +100,8 @@ export class TrustPortalService { if (!resp.ok) { const errorBody = await resp.json().catch(() => ({})); const err = new Error( - errorBody?.error?.message || `Vercel API ${method} ${path} failed (${resp.status})`, + errorBody?.error?.message || + `Vercel API ${method} ${path} failed (${resp.status})`, ) as Error & { status: number; responseData: unknown }; err.status = resp.status; err.responseData = errorBody; @@ -210,7 +210,7 @@ export class TrustPortalService { // Get domain information including verification status // Vercel API endpoint: GET /v9/projects/{projectId}/domains/{domain} - const teamId = process.env.VERCEL_TEAM_ID!; + const teamId = process.env.VERCEL_TEAM_ID; const [domainResponse, configResponse] = await Promise.all([ this.vercelFetch({ method: 'GET', @@ -771,9 +771,7 @@ export class TrustPortalService { }); const domainVerified = - currentTrust?.domain === domain - ? currentTrust.domainVerified - : false; + currentTrust?.domain === domain ? currentTrust.domainVerified : false; // Remove old domain from Vercel if switching to a different one if (currentTrust?.domain && currentTrust.domain !== domain) { @@ -826,8 +824,7 @@ export class TrustPortalService { const statusData = statusResp.data; const isVercelDomain = statusData.verified === false; - const vercelVerification = - statusData.verification?.[0]?.value || null; + const vercelVerification = statusData.verification?.[0]?.value || null; await db.trust.upsert({ where: { organizationId }, @@ -863,8 +860,7 @@ export class TrustPortalService { const addData = addResp.data; const isVercelDomain = addData.verified === false; - const vercelVerification = - addData.verification?.[0]?.value || null; + const vercelVerification = addData.verification?.[0]?.value || null; await db.trust.upsert({ where: { organizationId }, @@ -889,7 +885,17 @@ export class TrustPortalService { }; } catch (error) { // Handle Vercel 409 conflict — domain already exists on the project - const vercelError = error as Error & { status?: number; responseData?: { error?: { code?: string; projectId?: string; message?: string; domain?: VercelDomainResponse } } }; + const vercelError = error as Error & { + status?: number; + responseData?: { + error?: { + code?: string; + projectId?: string; + message?: string; + domain?: VercelDomainResponse; + }; + }; + }; if (vercelError.status === 409) { const errorData = vercelError.responseData?.error; @@ -943,7 +949,9 @@ export class TrustPortalService { // Extract meaningful error message const errorMessage = - error instanceof Error ? error.message : 'Failed to update custom domain'; + error instanceof Error + ? error.message + : 'Failed to update custom domain'; this.logger.error(`Custom domain error for ${domain}:`, error); throw new BadRequestException(errorMessage); @@ -951,7 +959,8 @@ export class TrustPortalService { } /** Validate domain to prevent path injection in API URLs */ - private static readonly VALID_DOMAIN_PATTERN = /^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$/; + private static readonly VALID_DOMAIN_PATTERN = + /^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$/; private validateDomain(domain: string): void { if (!TrustPortalService.VALID_DOMAIN_PATTERN.test(domain)) { @@ -1013,8 +1022,7 @@ export class TrustPortalService { }); const vercelData = vercelStatusResp.data; liveIsVercelDomain = vercelData.verified === false; - liveVercelVerification = - vercelData.verification?.[0]?.value || null; + liveVercelVerification = vercelData.verification?.[0]?.value || null; // Sync DB with live Vercel state await db.trust.update({ @@ -1107,9 +1115,7 @@ export class TrustPortalService { // the _vercel TXT record before the domain will serve traffic. // For same-account domains, DNS verification is sufficient — Vercel will // pick up the CNAME on its own, so don't block on the verify response. - const domainFullyVerified = requiresVercelTxt - ? vercelVerified - : true; + const domainFullyVerified = requiresVercelTxt ? vercelVerified : true; await db.trust.update({ where: { organizationId }, @@ -1485,8 +1491,8 @@ export class TrustPortalService { soc2type1Status: trust.soc2type1_status ?? 'started', soc2type2Status: !trust.soc2type2 && trust.soc2 - ? trust.soc2_status ?? 'started' - : trust.soc2type2_status ?? 'started', + ? (trust.soc2_status ?? 'started') + : (trust.soc2type2_status ?? 'started'), iso27001Status: trust.iso27001_status ?? 'started', iso42001Status: trust.iso42001_status ?? 'started', gdprStatus: trust.gdpr_status ?? 'started', @@ -1627,9 +1633,9 @@ export class TrustPortalService { globalVendor.riskAssessmentData, ); if (extractedBadges && extractedBadges.length > 0) { - const currentBadges = vendor.complianceBadges as - | Array<{ type: string }> - | null; + const currentBadges = vendor.complianceBadges as Array<{ + type: string; + }> | null; const currentTypes = new Set( currentBadges?.map((b) => b.type) ?? [], ); diff --git a/apps/api/src/utils/assignment-filter.spec.ts b/apps/api/src/utils/assignment-filter.spec.ts index b260862833..c884b287f1 100644 --- a/apps/api/src/utils/assignment-filter.spec.ts +++ b/apps/api/src/utils/assignment-filter.spec.ts @@ -104,8 +104,12 @@ describe('Assignment Filter Utilities', () => { }); it('should return empty filter for API key auth regardless of roles/memberId', () => { - expect(buildTaskAssignmentFilter(null, null, { isApiKey: true })).toEqual({}); - expect(buildTaskAssignmentFilter(undefined, [], { isApiKey: true })).toEqual({}); + expect(buildTaskAssignmentFilter(null, null, { isApiKey: true })).toEqual( + {}, + ); + expect( + buildTaskAssignmentFilter(undefined, [], { isApiKey: true }), + ).toEqual({}); }); }); @@ -200,8 +204,12 @@ describe('Assignment Filter Utilities', () => { }); it('should return true for API key auth regardless of assignment or memberId', () => { - expect(hasTaskAccess(unassignedTask, null, null, { isApiKey: true })).toBe(true); - expect(hasTaskAccess(noAssigneeTask, undefined, [], { isApiKey: true })).toBe(true); + expect( + hasTaskAccess(unassignedTask, null, null, { isApiKey: true }), + ).toBe(true); + expect( + hasTaskAccess(noAssigneeTask, undefined, [], { isApiKey: true }), + ).toBe(true); }); }); @@ -225,7 +233,9 @@ describe('Assignment Filter Utilities', () => { }); it('should return true for API key auth regardless of assignment or memberId', () => { - expect(hasRiskAccess(unassignedRisk, null, null, { isApiKey: true })).toBe(true); + expect( + hasRiskAccess(unassignedRisk, null, null, { isApiKey: true }), + ).toBe(true); }); }); @@ -261,9 +271,9 @@ describe('Assignment Filter Utilities', () => { }); it('should return false for restricted role when control has no tasks', () => { - expect( - hasControlAccess(controlWithNoTasks, memberId, ['employee']), - ).toBe(false); + expect(hasControlAccess(controlWithNoTasks, memberId, ['employee'])).toBe( + false, + ); }); it('should return false for restricted role with no memberId', () => { diff --git a/apps/api/src/utils/compliance-filters.ts b/apps/api/src/utils/compliance-filters.ts index 8411ca621e..dd64b3c948 100644 --- a/apps/api/src/utils/compliance-filters.ts +++ b/apps/api/src/utils/compliance-filters.ts @@ -46,7 +46,10 @@ export async function filterComplianceMembers( const allCustomRoleNames = new Set(); const builtInRoleNames = new Set(Object.keys(allRoles)); const memberRoles = members.map((member) => { - const roleNames = member.role.split(',').map((r) => r.trim()).filter(Boolean); + const roleNames = member.role + .split(',') + .map((r) => r.trim()) + .filter(Boolean); const customNames = roleNames.filter((n) => !builtInRoleNames.has(n)); for (const name of customNames) allCustomRoleNames.add(name); return { member, roleNames }; @@ -61,9 +64,10 @@ export async function filterComplianceMembers( }); customObligationMap = Object.fromEntries( customRoles.map((r) => { - const obligations = typeof r.obligations === 'string' - ? JSON.parse(r.obligations) - : (r.obligations || {}); + const obligations = + typeof r.obligations === 'string' + ? JSON.parse(r.obligations) + : r.obligations || {}; return [r.name, obligations as RoleObligations]; }), ); diff --git a/apps/api/src/utils/department-visibility.spec.ts b/apps/api/src/utils/department-visibility.spec.ts index 9f6fe23631..8dba3b7f8b 100644 --- a/apps/api/src/utils/department-visibility.spec.ts +++ b/apps/api/src/utils/department-visibility.spec.ts @@ -196,9 +196,9 @@ describe('Department Visibility Utilities', () => { }; it('should return true when member department is in visible list', () => { - expect( - canViewPolicy(itAndHrPolicy, Departments.it, ['employee']), - ).toBe(true); + expect(canViewPolicy(itAndHrPolicy, Departments.it, ['employee'])).toBe( + true, + ); expect( canViewPolicy(itAndHrPolicy, Departments.hr, ['contractor']), ).toBe(true); @@ -233,9 +233,9 @@ describe('Department Visibility Utilities', () => { visibility: 'UNKNOWN' as PolicyVisibility, visibleToDepartments: [], }; - expect( - canViewPolicy(unknownPolicy, Departments.it, ['employee']), - ).toBe(false); + expect(canViewPolicy(unknownPolicy, Departments.it, ['employee'])).toBe( + false, + ); }); it('should handle empty visibleToDepartments array', () => { diff --git a/apps/api/src/utils/file-type-validation.spec.ts b/apps/api/src/utils/file-type-validation.spec.ts index 2552d2682c..7297e30b6a 100644 --- a/apps/api/src/utils/file-type-validation.spec.ts +++ b/apps/api/src/utils/file-type-validation.spec.ts @@ -3,48 +3,70 @@ import { validateFileContent } from './file-type-validation'; describe('validateFileContent', () => { it('should accept a valid PNG file', () => { - const pngBuffer = Buffer.from([0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a]); - expect(() => validateFileContent(pngBuffer, 'image/png', 'test.png')).not.toThrow(); + const pngBuffer = Buffer.from([ + 0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a, + ]); + expect(() => + validateFileContent(pngBuffer, 'image/png', 'test.png'), + ).not.toThrow(); }); it('should accept a valid PDF file', () => { const pdfBuffer = Buffer.from('%PDF-1.4 some content'); - expect(() => validateFileContent(pdfBuffer, 'application/pdf', 'test.pdf')).not.toThrow(); + expect(() => + validateFileContent(pdfBuffer, 'application/pdf', 'test.pdf'), + ).not.toThrow(); }); it('should accept a valid JPEG file', () => { const jpegBuffer = Buffer.from([0xff, 0xd8, 0xff, 0xe0, 0x00, 0x10]); - expect(() => validateFileContent(jpegBuffer, 'image/jpeg', 'test.jpg')).not.toThrow(); + expect(() => + validateFileContent(jpegBuffer, 'image/jpeg', 'test.jpg'), + ).not.toThrow(); }); it('should reject HTML content disguised as PNG', () => { const htmlBuffer = Buffer.from(''); - expect(() => validateFileContent(htmlBuffer, 'image/png', 'test.png')).toThrow(BadRequestException); + expect(() => + validateFileContent(htmlBuffer, 'image/png', 'test.png'), + ).toThrow(BadRequestException); }); it('should reject PNG with wrong magic bytes', () => { const fakeBuffer = Buffer.from([0x00, 0x00, 0x00, 0x00]); - expect(() => validateFileContent(fakeBuffer, 'image/png', 'test.png')).toThrow(BadRequestException); + expect(() => + validateFileContent(fakeBuffer, 'image/png', 'test.png'), + ).toThrow(BadRequestException); }); it('should reject files containing script tags regardless of type', () => { - const malicious = Buffer.from(''); - expect(() => validateFileContent(malicious, 'text/plain', 'readme.txt')).toThrow(BadRequestException); + const malicious = Buffer.from( + '', + ); + expect(() => + validateFileContent(malicious, 'text/plain', 'readme.txt'), + ).toThrow(BadRequestException); }); it('should reject files with event handlers', () => { const malicious = Buffer.from(''); - expect(() => validateFileContent(malicious, 'text/plain', 'readme.txt')).toThrow(BadRequestException); + expect(() => + validateFileContent(malicious, 'text/plain', 'readme.txt'), + ).toThrow(BadRequestException); }); it('should allow text files that are actually text', () => { const textBuffer = Buffer.from('Hello, this is a normal text file.'); - expect(() => validateFileContent(textBuffer, 'text/plain', 'readme.txt')).not.toThrow(); + expect(() => + validateFileContent(textBuffer, 'text/plain', 'readme.txt'), + ).not.toThrow(); }); it('should allow unknown MIME types without magic byte check', () => { const csvBuffer = Buffer.from('name,email\njohn,john@example.com'); - expect(() => validateFileContent(csvBuffer, 'text/csv', 'data.csv')).not.toThrow(); + expect(() => + validateFileContent(csvBuffer, 'text/csv', 'data.csv'), + ).not.toThrow(); }); it('should accept a valid WebP file', () => { @@ -53,7 +75,9 @@ describe('validateFileContent', () => { webpBuffer.write('RIFF', 0); webpBuffer.writeUInt32LE(8, 4); webpBuffer.write('WEBP', 8); - expect(() => validateFileContent(webpBuffer, 'image/webp', 'photo.webp')).not.toThrow(); + expect(() => + validateFileContent(webpBuffer, 'image/webp', 'photo.webp'), + ).not.toThrow(); }); it('should reject a WAV file disguised as WebP', () => { @@ -62,7 +86,9 @@ describe('validateFileContent', () => { wavBuffer.write('RIFF', 0); wavBuffer.writeUInt32LE(8, 4); wavBuffer.write('WAVE', 8); - expect(() => validateFileContent(wavBuffer, 'image/webp', 'fake.webp')).toThrow(BadRequestException); + expect(() => + validateFileContent(wavBuffer, 'image/webp', 'fake.webp'), + ).toThrow(BadRequestException); }); it('should reject a RIFF file with script content disguised as WebP', () => { @@ -70,6 +96,8 @@ describe('validateFileContent', () => { malicious.write('RIFF', 0); malicious.writeUInt32LE(56, 4); malicious.write('AVI ', 8); // Not WEBP - expect(() => validateFileContent(malicious, 'image/webp', 'evil.webp')).toThrow(BadRequestException); + expect(() => + validateFileContent(malicious, 'image/webp', 'evil.webp'), + ).toThrow(BadRequestException); }); }); diff --git a/apps/api/src/vendors/dto/create-vendor.dto.ts b/apps/api/src/vendors/dto/create-vendor.dto.ts index 4e5813f5fe..1542c93de9 100644 --- a/apps/api/src/vendors/dto/create-vendor.dto.ts +++ b/apps/api/src/vendors/dto/create-vendor.dto.ts @@ -8,12 +8,7 @@ import { IsBoolean, } from 'class-validator'; import { Transform } from 'class-transformer'; -import { - VendorCategory, - VendorStatus, - Likelihood, - Impact, -} from '@db'; +import { VendorCategory, VendorStatus, Likelihood, Impact } from '@db'; export class CreateVendorDto { @ApiProperty({ diff --git a/apps/api/src/vendors/dto/trigger-vendor-risk-assessment.dto.ts b/apps/api/src/vendors/dto/trigger-vendor-risk-assessment.dto.ts index 54c722db61..98e06ab1cc 100644 --- a/apps/api/src/vendors/dto/trigger-vendor-risk-assessment.dto.ts +++ b/apps/api/src/vendors/dto/trigger-vendor-risk-assessment.dto.ts @@ -29,7 +29,11 @@ export class TriggerVendorRiskAssessmentVendorDto { } export class TriggerSingleVendorRiskAssessmentDto { - @ApiProperty({ description: 'Organization ID (deprecated — use auth context)', example: 'org_abc123', required: false }) + @ApiProperty({ + description: 'Organization ID (deprecated — use auth context)', + example: 'org_abc123', + required: false, + }) @IsOptional() @IsString() organizationId?: string; @@ -59,7 +63,11 @@ export class TriggerSingleVendorRiskAssessmentDto { } export class TriggerVendorRiskAssessmentBatchDto { - @ApiProperty({ description: 'Organization ID (deprecated — use auth context)', example: 'org_abc123', required: false }) + @ApiProperty({ + description: 'Organization ID (deprecated — use auth context)', + example: 'org_abc123', + required: false, + }) @IsOptional() @IsString() organizationId?: string; diff --git a/apps/api/src/vendors/dto/update-vendor.dto.spec.ts b/apps/api/src/vendors/dto/update-vendor.dto.spec.ts index cc25082e2d..53641b820f 100644 --- a/apps/api/src/vendors/dto/update-vendor.dto.spec.ts +++ b/apps/api/src/vendors/dto/update-vendor.dto.spec.ts @@ -23,19 +23,28 @@ describe('UpdateVendorDto', () => { isSubProcessor: false, assigneeId: 'mem_abc123', }); - const errors = await validate(dto, { whitelist: true, forbidNonWhitelisted: true }); + const errors = await validate(dto, { + whitelist: true, + forbidNonWhitelisted: true, + }); expect(errors).toHaveLength(0); }); it('should accept a minimal update (single field)', async () => { const dto = toDto({ website: 'https://www.acronis.com' }); - const errors = await validate(dto, { whitelist: true, forbidNonWhitelisted: true }); + const errors = await validate(dto, { + whitelist: true, + forbidNonWhitelisted: true, + }); expect(errors).toHaveLength(0); }); it('should accept an empty body (no fields to update)', async () => { const dto = toDto({}); - const errors = await validate(dto, { whitelist: true, forbidNonWhitelisted: true }); + const errors = await validate(dto, { + whitelist: true, + forbidNonWhitelisted: true, + }); expect(errors).toHaveLength(0); }); @@ -49,13 +58,19 @@ describe('UpdateVendorDto', () => { website: 'https://www.acronis.com', isSubProcessor: false, }); - const errors = await validate(dto, { whitelist: true, forbidNonWhitelisted: true }); + const errors = await validate(dto, { + whitelist: true, + forbidNonWhitelisted: true, + }); expect(errors).toHaveLength(0); }); it('should still reject empty name', async () => { const dto = toDto({ name: '' }); - const errors = await validate(dto, { whitelist: true, forbidNonWhitelisted: true }); + const errors = await validate(dto, { + whitelist: true, + forbidNonWhitelisted: true, + }); expect(errors.length).toBeGreaterThan(0); expect(errors[0].property).toBe('name'); }); @@ -63,27 +78,39 @@ describe('UpdateVendorDto', () => { // ── assigneeId: null (unassigned vendor) ────────────────────────── it('should accept assigneeId: null', async () => { const dto = toDto({ assigneeId: null }); - const errors = await validate(dto, { whitelist: true, forbidNonWhitelisted: true }); + const errors = await validate(dto, { + whitelist: true, + forbidNonWhitelisted: true, + }); expect(errors).toHaveLength(0); }); // ── website handling ────────────────────────────────────────────── it('should transform empty website to undefined (skip validation)', async () => { const dto = toDto({ website: '' }); - const errors = await validate(dto, { whitelist: true, forbidNonWhitelisted: true }); + const errors = await validate(dto, { + whitelist: true, + forbidNonWhitelisted: true, + }); expect(errors).toHaveLength(0); expect(dto.website).toBeUndefined(); }); it('should accept a valid website URL', async () => { const dto = toDto({ website: 'https://www.cloudflare.com' }); - const errors = await validate(dto, { whitelist: true, forbidNonWhitelisted: true }); + const errors = await validate(dto, { + whitelist: true, + forbidNonWhitelisted: true, + }); expect(errors).toHaveLength(0); }); it('should reject an invalid website URL', async () => { const dto = toDto({ website: 'not-a-url' }); - const errors = await validate(dto, { whitelist: true, forbidNonWhitelisted: true }); + const errors = await validate(dto, { + whitelist: true, + forbidNonWhitelisted: true, + }); expect(errors.length).toBeGreaterThan(0); expect(errors[0].property).toBe('website'); }); @@ -91,14 +118,20 @@ describe('UpdateVendorDto', () => { // ── enum validation ─────────────────────────────────────────────── it('should reject invalid category enum', async () => { const dto = toDto({ category: 'invalid_category' }); - const errors = await validate(dto, { whitelist: true, forbidNonWhitelisted: true }); + const errors = await validate(dto, { + whitelist: true, + forbidNonWhitelisted: true, + }); expect(errors.length).toBeGreaterThan(0); expect(errors[0].property).toBe('category'); }); it('should reject invalid status enum', async () => { const dto = toDto({ status: 'invalid_status' }); - const errors = await validate(dto, { whitelist: true, forbidNonWhitelisted: true }); + const errors = await validate(dto, { + whitelist: true, + forbidNonWhitelisted: true, + }); expect(errors.length).toBeGreaterThan(0); expect(errors[0].property).toBe('status'); }); @@ -106,7 +139,10 @@ describe('UpdateVendorDto', () => { // ── forbidNonWhitelisted ────────────────────────────────────────── it('should reject unknown properties', async () => { const dto = toDto({ name: 'Acronis', unknownField: 'value' }); - const errors = await validate(dto, { whitelist: true, forbidNonWhitelisted: true }); + const errors = await validate(dto, { + whitelist: true, + forbidNonWhitelisted: true, + }); expect(errors.length).toBeGreaterThan(0); expect(errors.some((e) => e.property === 'unknownField')).toBe(true); }); diff --git a/apps/api/src/vendors/dto/update-vendor.dto.ts b/apps/api/src/vendors/dto/update-vendor.dto.ts index 1d2c87deb4..c7186b3f70 100644 --- a/apps/api/src/vendors/dto/update-vendor.dto.ts +++ b/apps/api/src/vendors/dto/update-vendor.dto.ts @@ -8,12 +8,7 @@ import { IsBoolean, } from 'class-validator'; import { Transform } from 'class-transformer'; -import { - VendorCategory, - VendorStatus, - Likelihood, - Impact, -} from '@db'; +import { VendorCategory, VendorStatus, Likelihood, Impact } from '@db'; /** * DTO for PATCH /vendors/:id @@ -46,7 +41,10 @@ export class UpdateVendorDto { @IsEnum(VendorStatus) status?: VendorStatus; - @ApiPropertyOptional({ description: 'Inherent probability', enum: Likelihood }) + @ApiPropertyOptional({ + description: 'Inherent probability', + enum: Likelihood, + }) @IsOptional() @IsEnum(Likelihood) inherentProbability?: Likelihood; @@ -56,7 +54,10 @@ export class UpdateVendorDto { @IsEnum(Impact) inherentImpact?: Impact; - @ApiPropertyOptional({ description: 'Residual probability', enum: Likelihood }) + @ApiPropertyOptional({ + description: 'Residual probability', + enum: Likelihood, + }) @IsOptional() @IsEnum(Likelihood) residualProbability?: Likelihood; diff --git a/apps/api/src/vendors/dto/vendor-response.dto.ts b/apps/api/src/vendors/dto/vendor-response.dto.ts index 5e5d8aa472..7c295d24b8 100644 --- a/apps/api/src/vendors/dto/vendor-response.dto.ts +++ b/apps/api/src/vendors/dto/vendor-response.dto.ts @@ -1,10 +1,5 @@ import { ApiProperty } from '@nestjs/swagger'; -import { - VendorCategory, - VendorStatus, - Likelihood, - Impact, -} from '@db'; +import { VendorCategory, VendorStatus, Likelihood, Impact } from '@db'; export class VendorResponseDto { @ApiProperty({ diff --git a/apps/api/src/vendors/internal-vendor-automation.controller.ts b/apps/api/src/vendors/internal-vendor-automation.controller.ts index aafaad292c..9526d7fb0b 100644 --- a/apps/api/src/vendors/internal-vendor-automation.controller.ts +++ b/apps/api/src/vendors/internal-vendor-automation.controller.ts @@ -1,5 +1,10 @@ import { Body, Controller, HttpCode, Post, UseGuards } from '@nestjs/common'; -import { ApiOperation, ApiResponse, ApiSecurity, ApiTags } from '@nestjs/swagger'; +import { + ApiOperation, + ApiResponse, + ApiSecurity, + ApiTags, +} from '@nestjs/swagger'; import { HybridAuthGuard } from '../auth/hybrid-auth.guard'; import { PermissionGuard } from '../auth/permission.guard'; import { RequirePermission } from '../auth/require-permission.decorator'; diff --git a/apps/api/src/vendors/vendors.controller.spec.ts b/apps/api/src/vendors/vendors.controller.spec.ts index 163fc92bac..76391db6e6 100644 --- a/apps/api/src/vendors/vendors.controller.spec.ts +++ b/apps/api/src/vendors/vendors.controller.spec.ts @@ -59,9 +59,7 @@ describe('VendorsController', () => { beforeEach(async () => { const module: TestingModule = await Test.createTestingModule({ controllers: [VendorsController], - providers: [ - { provide: VendorsService, useValue: mockVendorsService }, - ], + providers: [{ provide: VendorsService, useValue: mockVendorsService }], }) .overrideGuard(HybridAuthGuard) .useValue(mockGuard) @@ -124,7 +122,10 @@ describe('VendorsController', () => { it('should not include authenticatedUser when userId is missing', async () => { mockVendorsService.findAllByOrganization.mockResolvedValue([]); - const result = await controller.getAllVendors('org_123', apiKeyAuthContext); + const result = await controller.getAllVendors( + 'org_123', + apiKeyAuthContext, + ); expect(result.authenticatedUser).toBeUndefined(); expect(result.authType).toBe('api-key'); @@ -171,7 +172,11 @@ describe('VendorsController', () => { describe('createVendor', () => { it('should create a vendor and return with auth context', async () => { const dto = { name: 'New Vendor', category: 'SaaS' }; - const createdVendor = { id: 'vnd_new', name: 'New Vendor', category: 'SaaS' }; + const createdVendor = { + id: 'vnd_new', + name: 'New Vendor', + category: 'SaaS', + }; mockVendorsService.create.mockResolvedValue(createdVendor); const result = await controller.createVendor( @@ -280,7 +285,9 @@ describe('VendorsController', () => { }); it('should pass undefined userId when auth context has no userId', async () => { - mockVendorsService.triggerAssessment.mockResolvedValue({ status: 'pending' }); + mockVendorsService.triggerAssessment.mockResolvedValue({ + status: 'pending', + }); await controller.triggerAssessment('vnd_1', 'org_123', apiKeyAuthContext); diff --git a/apps/api/src/vendors/vendors.controller.ts b/apps/api/src/vendors/vendors.controller.ts index cfb1ef6d2e..41d47bb101 100644 --- a/apps/api/src/vendors/vendors.controller.ts +++ b/apps/api/src/vendors/vendors.controller.ts @@ -45,10 +45,12 @@ export class VendorsController { @Get('global/search') @RequirePermission('vendor', 'read') @ApiOperation({ summary: 'Search global vendors database' }) - @ApiQuery({ name: 'name', required: false, description: 'Vendor name to search for' }) - async searchGlobalVendors( - @Query('name') name?: string, - ) { + @ApiQuery({ + name: 'name', + required: false, + description: 'Vendor name to search for', + }) + async searchGlobalVendors(@Query('name') name?: string) { return this.vendorsService.searchGlobal(name ?? ''); } diff --git a/apps/api/src/vendors/vendors.service.ts b/apps/api/src/vendors/vendors.service.ts index cb94bd8404..47c49b0d17 100644 --- a/apps/api/src/vendors/vendors.service.ts +++ b/apps/api/src/vendors/vendors.service.ts @@ -1,4 +1,9 @@ -import { BadRequestException, Injectable, NotFoundException, Logger } from '@nestjs/common'; +import { + BadRequestException, + Injectable, + NotFoundException, + Logger, +} from '@nestjs/common'; import { db, TaskItemPriority, TaskItemStatus } from '@db'; import { CreateVendorDto } from './dto/create-vendor.dto'; import { UpdateVendorDto } from './dto/update-vendor.dto'; @@ -198,13 +203,18 @@ export class VendorsService { } } - private async validateAssigneeNotPlatformAdmin(assigneeId: string, organizationId: string) { + private async validateAssigneeNotPlatformAdmin( + assigneeId: string, + organizationId: string, + ) { const member = await db.member.findFirst({ where: { id: assigneeId, organizationId }, include: { user: { select: { role: true } } }, }); if (member?.user.role === 'admin') { - throw new BadRequestException('Cannot assign a platform admin as assignee'); + throw new BadRequestException( + 'Cannot assign a platform admin as assignee', + ); } } @@ -215,7 +225,10 @@ export class VendorsService { ) { try { if (createVendorDto.assigneeId) { - await this.validateAssigneeNotPlatformAdmin(createVendorDto.assigneeId, organizationId); + await this.validateAssigneeNotPlatformAdmin( + createVendorDto.assigneeId, + organizationId, + ); } const vendor = await db.vendor.create({ data: { @@ -607,7 +620,10 @@ export class VendorsService { updateVendorDto.assigneeId && updateVendorDto.assigneeId !== existing.assigneeId ) { - await this.validateAssigneeNotPlatformAdmin(updateVendorDto.assigneeId, organizationId); + await this.validateAssigneeNotPlatformAdmin( + updateVendorDto.assigneeId, + organizationId, + ); } const updatedVendor = await db.vendor.update({ diff --git a/apps/api/test/maced-contract.e2e-spec.ts b/apps/api/test/maced-contract.e2e-spec.ts index 6cbb299582..28a9f1b25d 100644 --- a/apps/api/test/maced-contract.e2e-spec.ts +++ b/apps/api/test/maced-contract.e2e-spec.ts @@ -1,4 +1,7 @@ -import { MacedClient, type MacedPentestRun } from '../src/security-penetration-tests/maced-client'; +import { + MacedClient, + type MacedPentestRun, +} from '../src/security-penetration-tests/maced-client'; const enabledValues = new Set(['1', 'true', 'yes']); const isContractCanaryEnabled = enabledValues.has( diff --git a/apps/app/package.json b/apps/app/package.json index 9c091828a9..508d84aa6e 100644 --- a/apps/app/package.json +++ b/apps/app/package.json @@ -9,11 +9,11 @@ "@ai-sdk/provider": "^3.0.0", "@ai-sdk/react": "^3.0.0", "@ai-sdk/rsc": "^2.0.0", - "@aws-sdk/client-ec2": "^3.911.0", - "@aws-sdk/client-lambda": "^3.891.0", - "@aws-sdk/client-s3": "^3.859.0", - "@aws-sdk/client-sts": "^3.808.0", - "@aws-sdk/s3-request-presigner": "^3.859.0", + "@aws-sdk/client-ec2": "^3.948.0", + "@aws-sdk/client-lambda": "^3.948.0", + "@aws-sdk/client-s3": "3.1013.0", + "@aws-sdk/client-sts": "^3.948.0", + "@aws-sdk/s3-request-presigner": "3.1013.0", "@azure/core-rest-pipeline": "^1.21.0", "@browserbasehq/sdk": "2.6.0", "@browserbasehq/stagehand": "^3.0.5", diff --git a/apps/app/src/actions/files/upload-file.ts b/apps/app/src/actions/files/upload-file.ts index d45f570b17..57d02d909e 100644 --- a/apps/app/src/actions/files/upload-file.ts +++ b/apps/app/src/actions/files/upload-file.ts @@ -4,7 +4,7 @@ import { BUCKET_NAME, s3Client } from '@/app/s3'; import { auth } from '@/utils/auth'; import { logger } from '@/utils/logger'; import { GetObjectCommand, PutObjectCommand } from '@aws-sdk/client-s3'; -import { getSignedUrl } from '@aws-sdk/s3-request-presigner'; +import { getSignedUrl } from '@/lib/s3-presigner'; import { AttachmentEntityType, AttachmentType, db } from '@db/server'; import { revalidatePath } from 'next/cache'; import { headers } from 'next/headers'; diff --git a/apps/app/src/app/(app)/[orgId]/cloud-tests/actions/batch-fix.ts b/apps/app/src/app/(app)/[orgId]/cloud-tests/actions/batch-fix.ts new file mode 100644 index 0000000000..01a93c90fb --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/cloud-tests/actions/batch-fix.ts @@ -0,0 +1,164 @@ +'use server'; + +import { auth, runs, tasks } from '@trigger.dev/sdk'; +import { serverApi } from '@/lib/api-server'; + +interface BatchFixInput { + organizationId: string; + connectionId: string; + findings: Array<{ id: string; key: string; title: string }>; +} + +export async function startBatchFix( + input: BatchFixInput, +): Promise<{ data?: { batchId: string; runId: string; accessToken: string }; error?: string }> { + try { + // Step 1: Create batch record in DB via API + const api = serverApi; + const batchResp = await api.post<{ data: { id: string } }>('/v1/cloud-security/remediation/batch', { + connectionId: input.connectionId, + findings: input.findings, + }); + + if (batchResp.error || !batchResp.data?.data?.id) { + return { error: 'Failed to create batch record' }; + } + + const batchId = batchResp.data.data.id; + + // Step 2: Trigger the API-layer task + const handle = await tasks.trigger('remediate-batch', { + batchId, + organizationId: input.organizationId, + connectionId: input.connectionId, + }); + + // Step 3: Store triggerRunId on the batch + await api.patch(`/v1/cloud-security/remediation/batch/${batchId}`, { + triggerRunId: handle.id, + status: 'running', + }); + + // Step 4: Create public access token for real-time progress + const accessToken = await auth.createPublicToken({ + scopes: { read: { runs: [handle.id] } }, + }); + + return { data: { batchId, runId: handle.id, accessToken } }; + } catch (err) { + console.error('Failed to start batch fix:', err); + return { error: err instanceof Error ? err.message : 'Failed to start batch fix' }; + } +} + +export async function cancelBatchFix(runId: string, batchId: string): Promise { + try { + // Mark batch as cancelled in DB — task will check this before next finding + const api = serverApi; + await api.patch(`/v1/cloud-security/remediation/batch/${batchId}`, { + status: 'cancelled', + }); + // Also cancel the trigger run + await runs.cancel(runId); + } catch { + // Run may have already completed + } +} + +/** Check for an active batch on page load — returns batch + access token if found. */ +export async function getActiveBatch( + connectionId: string, +): Promise<{ + batchId: string; + triggerRunId: string; + accessToken: string; + findings: Array<{ id: string; title: string; status: string; error?: string }>; +} | null> { + try { + const resp = await serverApi.get( + `/v1/cloud-security/remediation/batch/active?connectionId=${connectionId}`, + ); + const batch = (resp.data as { data?: { id: string; triggerRunId?: string; findings: unknown[] } })?.data; + if (!batch?.triggerRunId) return null; + + // Verify the trigger run is actually still active + try { + const run = await runs.retrieve(batch.triggerRunId); + if (run.status === 'COMPLETED' || run.status === 'FAILED' || run.status === 'CANCELED' || run.status === 'SYSTEM_FAILURE') { + // Run is done — mark batch as done in DB so it doesn't show up again + await serverApi.patch(`/v1/cloud-security/remediation/batch/${batch.id}`, { + status: 'done', + }); + return null; + } + } catch { + // Can't verify run — mark batch as done to be safe + await serverApi.patch(`/v1/cloud-security/remediation/batch/${batch.id}`, { + status: 'done', + }); + return null; + } + + const accessToken = await auth.createPublicToken({ + scopes: { read: { runs: [batch.triggerRunId] } }, + }); + + return { + batchId: batch.id, + triggerRunId: batch.triggerRunId, + accessToken, + findings: batch.findings as Array<{ id: string; title: string; status: string; error?: string }>, + }; + } catch { + return null; + } +} + +export async function skipBatchFinding(batchId: string, findingId: string): Promise { + try { + await serverApi.post(`/v1/cloud-security/remediation/batch/${batchId}/skip/${findingId}`, {}); + } catch { + // Best effort + } +} + +/** Retry a single finding immediately (user added permissions and wants instant retry). */ +export async function retryFinding( + connectionId: string, + checkResultId: string, + remediationKey: string, +): Promise<{ status: 'fixed' | 'failed' | 'needs_permissions'; error?: string; missingPermissions?: string[] }> { + try { + // Preview first + const preview = await serverApi.post<{ + guidedOnly?: boolean; + missingPermissions?: string[]; + }>('/v1/cloud-security/remediation/preview', { + connectionId, + checkResultId, + remediationKey, + }); + + if (preview.error) return { status: 'failed', error: String(preview.error) }; + + const data = preview.data as { guidedOnly?: boolean; missingPermissions?: string[] } | undefined; + if (data?.missingPermissions && data.missingPermissions.length > 0) { + return { status: 'needs_permissions', missingPermissions: data.missingPermissions }; + } + + // Execute + const execute = await serverApi.post<{ status: string; error?: string }>( + '/v1/cloud-security/remediation/execute', + { connectionId, checkResultId, remediationKey, acknowledgment: 'acknowledged' }, + ); + + const execData = execute.data as { status?: string; error?: string } | undefined; + if (execute.error || execData?.status === 'failed') { + return { status: 'failed', error: String(execute.error ?? execData?.error ?? 'Failed') }; + } + + return { status: 'fixed' }; + } catch (err) { + return { status: 'failed', error: err instanceof Error ? err.message : 'Failed' }; + } +} diff --git a/apps/app/src/app/(app)/[orgId]/cloud-tests/components/AcknowledgmentPanel.tsx b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/AcknowledgmentPanel.tsx new file mode 100644 index 0000000000..cbc483f0a9 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/AcknowledgmentPanel.tsx @@ -0,0 +1,115 @@ +'use client'; + +import { Input } from '@trycompai/ui/input'; +import { AlertTriangle, ListOrdered } from 'lucide-react'; + +interface AcknowledgmentPanelProps { + requiresAcknowledgment?: 'type-to-confirm' | 'checkbox'; + acknowledgmentMessage?: string; + confirmationPhrase?: string; + guidedOnly?: boolean; + guidedSteps?: string[]; + onAcknowledgmentChange: (value: string | null) => void; + acknowledged: boolean; +} + +export function AcknowledgmentPanel({ + requiresAcknowledgment, + acknowledgmentMessage, + confirmationPhrase, + guidedOnly, + guidedSteps, + onAcknowledgmentChange, + acknowledged, +}: AcknowledgmentPanelProps) { + if (guidedOnly) { + return ( +
+
+ + + Manual Steps Required + +
+

+ This remediation must be performed manually. Follow these steps: +

+ {guidedSteps && guidedSteps.length > 0 && ( +
    + {guidedSteps.map((step, index) => ( +
  1. + {step} +
  2. + ))} +
+ )} +
+ ); + } + + if (requiresAcknowledgment === 'type-to-confirm') { + return ( +
+
+ +

+ {acknowledgmentMessage} +

+
+
+ + + onAcknowledgmentChange(e.target.value || null) + } + className={ + acknowledged + ? 'border-emerald-300 focus-visible:ring-emerald-500' + : '' + } + /> +
+
+ ); + } + + if (requiresAcknowledgment === 'checkbox') { + return ( +
+
+ +

+ {acknowledgmentMessage} +

+
+ +
+ ); + } + + return null; +} diff --git a/apps/app/src/app/(app)/[orgId]/cloud-tests/components/AzureSetupGuide.tsx b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/AzureSetupGuide.tsx new file mode 100644 index 0000000000..8b27cf80d1 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/AzureSetupGuide.tsx @@ -0,0 +1,226 @@ +'use client'; + +import { useApi } from '@/hooks/use-api'; +import { Check, ExternalLink, Loader2, X } from 'lucide-react'; +import { useEffect, useRef, useState } from 'react'; +import { toast } from 'sonner'; + +interface AzureSetupGuideProps { + connectionId: string; + hasSubscriptionId: boolean; + onRunScan: () => void; + isScanning: boolean; +} + +interface SetupStep { + name: string; + success: boolean; + error?: string; +} + +const MANUAL_STEPS = [ + { + title: 'Ensure you have a Subscription', + description: 'An active Azure subscription is required.', + link: 'https://portal.azure.com/#blade/Microsoft_Azure_Billing/SubscriptionsBlade', + linkText: 'Subscriptions', + }, + { + title: 'Assign Security Reader role', + description: 'Your account needs Security Reader on the subscription.', + link: 'https://portal.azure.com/#blade/Microsoft_Azure_Billing/SubscriptionsBlade', + linkText: 'Subscription → IAM', + }, + { + title: 'Enable Microsoft Defender for Cloud', + description: 'Enable at least the free tier of Defender for Cloud.', + link: 'https://portal.azure.com/#blade/Microsoft_Azure_Security/SecurityMenuBlade/EnvironmentSettings', + linkText: 'Defender Settings', + }, +]; + +export function AzureSetupGuide({ + connectionId, + hasSubscriptionId, + onRunScan, + isScanning, +}: AzureSetupGuideProps) { + const api = useApi(); + const [isSettingUp, setIsSettingUp] = useState(false); + const [setupResult, setSetupResult] = useState<{ + steps: SetupStep[]; + subscriptionId?: string; + subscriptionName?: string; + } | null>(null); + + const ranRef = useRef(false); + + // Auto-run setup on first mount — no user action needed + useEffect(() => { + if (ranRef.current) return; + ranRef.current = true; + handleAutoSetup(); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []); + + const handleAutoSetup = async () => { + setIsSettingUp(true); + try { + const resp = await api.post<{ + steps: SetupStep[]; + subscriptionId?: string; + subscriptionName?: string; + }>(`/v1/cloud-security/setup-azure/${connectionId}`, {}); + + if (resp.error) { + toast.error(typeof resp.error === 'string' ? resp.error : 'Setup failed'); + return; + } + + if (resp.data) { + setSetupResult(resp.data); + const succeeded = resp.data.steps.filter((s) => s.success).length; + const total = resp.data.steps.length; + if (succeeded === total) { + toast.success('Azure setup complete — running first scan...'); + // Auto-run scan when everything passes + onRunScan(); + } else { + toast.message(`${succeeded}/${total} steps completed. See details below.`); + } + } + } catch { + toast.error('Setup failed'); + } finally { + setIsSettingUp(false); + } + }; + + const allStepsSucceeded = setupResult?.steps.every((s) => s.success); + + return ( +
+
+
+

Get started with Azure scanning

+

+ We'll detect your subscription and verify access. You can do it automatically or follow the manual steps. +

+
+ + {/* Auto-setup in progress */} + {!setupResult && ( +
+ + {hasSubscriptionId && } + +
+ +

Verifying access and configuring...

+
+
+ )} + + {/* Setup results */} + {setupResult && ( +
+ + {setupResult.steps.map((step, i) => ( + + ))} +
+ )} + + {/* Manual fallback */} + {setupResult && !allStepsSucceeded && ( +
+

+ Some checks need attention: +

+
+ {MANUAL_STEPS.map((step, i) => ( +
+ {step.title} + + {step.linkText} + +
+ ))} +
+
+ )} + + {/* Run scan button — only shown if setup partially failed */} + {setupResult && !allStepsSucceeded && ( + + )} +
+ + {/* Auto-fix info — shown only before setup runs */} + {!setupResult && ( +
+

+ Scanning works with Reader + Security Reader roles.{' '} + Auto-fix requires Contributor-level access. + We'll detect your permissions automatically during setup. +

+
+ )} +
+ ); +} + +function StepRow({ + done, + failed, + label, + error, +}: { + done?: boolean; + failed?: boolean; + label: string; + error?: string; +}) { + return ( +
+
+ {done && } + {failed && } +
+
+

+ {label} +

+ {error && ( +

{error}

+ )} +
+
+ ); +} diff --git a/apps/app/src/app/(app)/[orgId]/cloud-tests/components/BatchRemediationDialog.tsx b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/BatchRemediationDialog.tsx new file mode 100644 index 0000000000..10c107be43 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/BatchRemediationDialog.tsx @@ -0,0 +1,672 @@ +'use client'; + +import { Badge } from '@trycompai/ui/badge'; +import { Button } from '@trycompai/ui/button'; +import { Checkbox } from '@trycompai/ui/checkbox'; +import { + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, +} from '@trycompai/ui/dialog'; +import { + Check, + Copy, + ExternalLink, + Loader2, + Play, + RefreshCw, + ShieldAlert, + SkipForward, + X, + Zap, +} from 'lucide-react'; +import { toast } from 'sonner'; +import { useCallback, useEffect, useMemo, useState } from 'react'; +import { useRealtimeRun } from '@trigger.dev/react-hooks'; +import { + startBatchFix, + cancelBatchFix, + skipBatchFinding, + retryFinding, +} from '../actions/batch-fix'; + +interface Finding { + id: string; + title: string | null; + key: string; + severity: string; +} + +interface BatchRemediationDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; + serviceName: string; + findings: Finding[]; + connectionId: string; + organizationId: string; + onComplete?: () => void; + /** Called when the trigger run starts — parent uses this to enable the floating pill. */ + onRunStarted?: (info: { batchId: string; triggerRunId: string; accessToken: string }) => void; + /** Resume an active batch (loaded on page mount). */ + activeBatch?: { + batchId: string; + triggerRunId: string; + accessToken: string; + findings: Array<{ id: string; title: string; status: string; error?: string }>; + } | null; +} + +type FindingStatus = 'pending' | 'fixing' | 'fixed' | 'needs_permissions' | 'skipped' | 'failed' | 'cancelled'; + +interface FindingProgress { + id: string; + key?: string; + title: string; + severity?: string; + status: FindingStatus; + error?: string; + missingPermissions?: string[]; +} + +interface BatchProgress { + current: number; + total: number; + fixed: number; + skipped: number; + failed: number; + findings: FindingProgress[]; + phase: 'running' | 'retrying' | 'scanning' | 'waiting_for_permissions' | 'done' | 'cancelled'; + permChecksLeft?: number; +} + +const STATUS_CONFIG: Record = { + pending: { icon: Loader2, color: 'text-muted-foreground/40', bg: '' }, + fixing: { icon: Loader2, color: 'text-primary', bg: 'bg-primary/[0.04]' }, + fixed: { icon: Check, color: 'text-emerald-500', bg: '' }, + needs_permissions: { icon: ShieldAlert, color: 'text-muted-foreground', bg: '' }, + skipped: { icon: SkipForward, color: 'text-muted-foreground', bg: '' }, + failed: { icon: X, color: 'text-red-500', bg: '' }, + cancelled: { icon: X, color: 'text-muted-foreground', bg: '' }, +}; + +const SEVERITY_DOT: Record = { + critical: 'bg-red-500', + high: 'bg-red-400', + medium: 'bg-amber-400', + low: 'bg-blue-400', + info: 'bg-gray-300', +}; + +/** Per-finding inline permissions with copy/cloudshell/retry. */ +function FindingPermissions({ + permissions, + onRetry, +}: { + permissions: string[]; + onRetry: () => void; +}) { + const [copied, setCopied] = useState(false); + const [retrying, setRetrying] = useState(false); + + // Group by service + const grouped = useMemo(() => { + const groups: Record = {}; + for (const p of permissions) { + const [svc, action] = p.split(':'); + if (svc && action) (groups[svc] ??= []).push(action); + } + return groups; + }, [permissions]); + + const script = [ + 'ROLE="CompAI-Remediator" POLICY="CompAI-BatchPermissions"', + `NEW='${JSON.stringify(permissions)}'`, + 'CUR=$(aws iam get-role-policy --role-name "$ROLE" --policy-name "$POLICY" --query \'PolicyDocument.Statement[0].Action\' --output json 2>/dev/null || echo \'[]\')', + 'MERGED=$(echo "$CUR $NEW" | jq -s \'add | unique\')', + 'aws iam put-role-policy --role-name "$ROLE" --policy-name "$POLICY" --policy-document "{\\"Version\\":\\"2012-10-17\\",\\"Statement\\":[{\\"Effect\\":\\"Allow\\",\\"Action\\":$MERGED,\\"Resource\\":\\"*\\"}]}"', + ].join('\n'); + + return ( +
+
+ {Object.entries(grouped).map(([svc, actions]) => ( +
+ {svc}: + {actions.map((a) => ( + {a} + ))} +
+ ))} +
+
+ + + + CloudShell + + +
+
+ ); +} + +/** Consolidated banner — shows permissions grouped by AWS service, merge-safe script. */ +function MissingPermsBanner({ + findings, + confirmedPermissions, +}: { + findings: FindingProgress[]; + confirmedPermissions: string[]; +}) { + const [copied, setCopied] = useState(false); + const confirmed = useMemo(() => new Set(confirmedPermissions), [confirmedPermissions]); + + // Group by AWS service + const grouped = useMemo(() => { + const perms = new Set(); + for (const f of findings) { + if (f.missingPermissions) { + for (const p of f.missingPermissions) { + if (!confirmed.has(p)) perms.add(p); + } + } + } + const groups: Record = {}; + for (const p of [...perms].sort()) { + const [svc, action] = p.split(':'); + if (!svc || !action) continue; + (groups[svc] ??= []).push(action); + } + return groups; + }, [findings, confirmed]); + + const allMissing = Object.entries(grouped).flatMap(([svc, actions]) => + actions.map((a) => `${svc}:${a}`), + ); + + if (allMissing.length === 0) return null; + + // Merge-safe script: reads existing policy, merges new permissions, writes combined + // Uses jq (available in AWS CloudShell) to avoid overwriting existing perms + const newPermsJson = JSON.stringify(allMissing); + const script = [ + '# Merge new permissions with existing (won\'t overwrite)', + 'ROLE="CompAI-Remediator"', + 'POLICY="CompAI-BatchPermissions"', + `NEW_PERMS='${newPermsJson}'`, + '', + '# Get existing permissions (empty array if policy doesn\'t exist yet)', + 'EXISTING=$(aws iam get-role-policy --role-name "$ROLE" --policy-name "$POLICY" \\', + ' --query \'PolicyDocument.Statement[0].Action\' --output json 2>/dev/null || echo \'[]\')', + '', + '# Merge and deduplicate', + 'MERGED=$(echo "$EXISTING $NEW_PERMS" | jq -s \'add | unique\')', + '', + '# Apply combined policy', + 'aws iam put-role-policy --role-name "$ROLE" --policy-name "$POLICY" \\', + ' --policy-document "{\\"Version\\":\\"2012-10-17\\",\\"Statement\\":[{\\"Effect\\":\\"Allow\\",\\"Action\\":$MERGED,\\"Resource\\":\\"*\\"}]}"', + '', + 'echo "Added $(echo $NEW_PERMS | jq length) permissions ($(echo $MERGED | jq length) total)"', + ].join('\n'); + + const handleCopy = () => { + navigator.clipboard.writeText(script); + setCopied(true); + toast.success('Permission script copied'); + setTimeout(() => setCopied(false), 2000); + }; + + const serviceCount = Object.keys(grouped).length; + + return ( +
+
+
+ +
+
+

+ {allMissing.length} permission{allMissing.length !== 1 ? 's' : ''} needed across {serviceCount} service{serviceCount !== 1 ? 's' : ''} +

+

+ Run the script below — it merges with existing permissions, nothing gets overwritten. +

+
+
+ + {/* Grouped by service */} +
+ {Object.entries(grouped).map(([svc, actions]) => ( +
+ {svc} +
+ {actions.map((a) => ( + + {a} + + ))} +
+
+ ))} +
+ +
+ + + + CloudShell + +
+
+ ); +} + +export function BatchRemediationDialog({ + open, + onOpenChange, + serviceName, + findings, + connectionId, + organizationId, + onComplete, + onRunStarted, + activeBatch, +}: BatchRemediationDialogProps) { + const [selected, setSelected] = useState>(new Set()); + const [acknowledged, setAcknowledged] = useState(false); + const [batchId, setBatchId] = useState(null); + const [runId, setRunId] = useState(null); + const [accessToken, setAccessToken] = useState(null); + const [starting, setStarting] = useState(false); + const [cancelling, setCancelling] = useState(false); + + // Resume active batch if provided + useEffect(() => { + if (activeBatch && open) { + setBatchId(activeBatch.batchId); + setRunId(activeBatch.triggerRunId); + setAccessToken(activeBatch.accessToken); + } + }, [activeBatch, open]); + + // Real-time task progress + const { run } = useRealtimeRun(runId ?? '', { + accessToken: accessToken ?? undefined, + enabled: Boolean(runId && accessToken), + }); + + const progress = (run?.metadata as { progress?: BatchProgress } | undefined) + ?.progress ?? null; + + // Detect if the trigger run itself is finished (cancelled, failed, completed) + const runStatus = run?.status; + const runFinished = runStatus === 'COMPLETED' || runStatus === 'FAILED' || runStatus === 'CANCELED' || runStatus === 'SYSTEM_FAILURE'; + + const isRunning = Boolean(runId) && !runFinished && (!progress || progress.phase === 'running' || progress.phase === 'retrying'); + const isWaitingPerms = progress?.phase === 'waiting_for_permissions'; + const isScanning = progress?.phase === 'scanning'; + const isDone = progress?.phase === 'done' || progress?.phase === 'cancelled' || runFinished; + + // Reset on open + useEffect(() => { + if (open && !activeBatch) { + setSelected(new Set(findings.map((f) => f.id))); + setAcknowledged(false); + setBatchId(null); + setRunId(null); + setAccessToken(null); + setCancelling(false); + } + }, [open, findings, activeBatch]); + + // Auto-complete + auto-close when all findings are fixed + useEffect(() => { + if (isDone && progress && progress.fixed > 0) { + onComplete?.(); + // Auto-close if everything succeeded (no failures or skips) + const allFixed = progress.failed === 0 && progress.skipped === 0; + if (allFixed) { + const timer = setTimeout(() => onOpenChange(false), 3000); + return () => clearTimeout(timer); + } + } + }, [isDone, progress, onComplete, onOpenChange]); + + // Findings with progress (from task metadata or initial list) + const findingsWithProgress = useMemo((): FindingProgress[] => { + if (progress?.findings) return progress.findings; + if (activeBatch?.findings) { + return activeBatch.findings.map((f) => ({ + id: f.id, + title: f.title, + status: (f.status as FindingStatus) || 'pending', + error: f.error, + })); + } + if (runId) { + return findings + .filter((f) => selected.has(f.id)) + .map((f) => ({ id: f.id, title: f.title ?? 'Untitled', status: 'pending' as FindingStatus })); + } + return []; + }, [progress, runId, findings, selected, activeBatch]); + + const handleToggle = useCallback((id: string) => { + setSelected((prev) => { + const next = new Set(prev); + if (next.has(id)) next.delete(id); + else next.add(id); + return next; + }); + }, []); + + const handleToggleAll = useCallback(() => { + if (selected.size === findings.length) setSelected(new Set()); + else setSelected(new Set(findings.map((f) => f.id))); + }, [selected.size, findings]); + + const handleStart = async () => { + const selectedFindings = findings + .filter((f) => selected.has(f.id)) + .map((f) => ({ id: f.id, key: f.key, title: f.title ?? 'Untitled' })); + if (selectedFindings.length === 0) return; + + setStarting(true); + const result = await startBatchFix({ + organizationId, + connectionId, + findings: selectedFindings, + }); + setStarting(false); + + if (result.error || !result.data) return; + + setBatchId(result.data.batchId); + setRunId(result.data.runId); + setAccessToken(result.data.accessToken); + onRunStarted?.({ + batchId: result.data.batchId, + triggerRunId: result.data.runId, + accessToken: result.data.accessToken, + }); + }; + + const handleCancel = async () => { + if (!runId || !batchId) return; + setCancelling(true); + await cancelBatchFix(runId, batchId); + }; + + const handleSkipFinding = async (findingId: string) => { + if (!batchId) return; + await skipBatchFinding(batchId, findingId); + }; + + // Retry: create a new batch with only the skipped/failed findings + const handleRetrySkipped = async () => { + const retryFindings = findingsWithProgress + .filter((f) => f.status === 'skipped' || f.status === 'failed') + .map((f) => { + const orig = findings.find((o) => o.id === f.id); + return orig ? { id: orig.id, key: orig.key, title: orig.title ?? 'Untitled' } : null; + }) + .filter((f): f is { id: string; key: string; title: string } => f !== null); + + if (retryFindings.length === 0) return; + + setStarting(true); + const result = await startBatchFix({ organizationId, connectionId, findings: retryFindings }); + setStarting(false); + + if (result.error || !result.data) return; + + setBatchId(result.data.batchId); + setRunId(result.data.runId); + setAccessToken(result.data.accessToken); + onRunStarted?.({ + batchId: result.data.batchId, + triggerRunId: result.data.runId, + accessToken: result.data.accessToken, + }); + }; + + const handleClose = () => { + // Allow close even while running — task continues in background + onOpenChange(false); + }; + + const selectedCount = selected.size; + const allSelected = selectedCount === findings.length; + const pct = progress ? Math.round((progress.current / progress.total) * 100) : 0; + const hasSkippedOrFailed = findingsWithProgress.some( + (f) => f.status === 'skipped' || f.status === 'failed' || f.status === 'needs_permissions', + ); + + return ( + + + + + + Fix All — {serviceName} + + + {runId + ? `Processing ${progress?.total ?? selectedCount} findings` + : `${selectedCount} finding${selectedCount !== 1 ? 's' : ''} selected for auto-fix`} + + + + {/* ─── Pre-start: Selection ─── */} + {!runId && ( + <> +
+ + + {selectedCount} selected +
+ +
+ {findings.map((f) => ( + + ))} +
+ +
+ + +
+ + )} + + {/* ─── In-progress / Done ─── */} + {runId && ( + <> + {/* Progress bar */} +
+
+
+
+
+ + {isScanning ? 'Re-scanning to verify...' + : isDone ? (progress?.phase === 'cancelled' ? 'Cancelled' : 'Complete') + : isWaitingPerms ? `Waiting for permissions... (${progress?.permChecksLeft ?? 0} checks left)` + : progress?.phase === 'retrying' ? 'Retrying with new permissions...' + : `Fixing ${progress?.current ?? 0} of ${progress?.total ?? selectedCount}...`} + +
+ {(progress?.fixed ?? 0) > 0 && {progress!.fixed} fixed} + {(progress?.skipped ?? 0) > 0 && {progress!.skipped} skipped} + {(progress?.failed ?? 0) > 0 && {progress!.failed} failed} +
+
+
+ + {/* Finding progress list */} +
+ {findingsWithProgress.map((f) => { + const config = STATUS_CONFIG[f.status] ?? STATUS_CONFIG.pending; + const Icon = config.icon; + const canSkip = f.status === 'pending' && !isDone; + const isMissingPerms = f.status === 'needs_permissions' && f.missingPermissions && f.missingPermissions.length > 0; + + return ( +
+
+
+ +
+
+

+ {f.title} +

+ {f.error && !isMissingPerms && ( +

{f.error}

+ )} +
+ {canSkip && ( + + )} + {f.status === 'fixed' && Done} + {f.status === 'cancelled' && Removed} +
+ {/* Per-finding permissions — only shows for THIS finding */} + {isMissingPerms && ( + { + // Find the original finding data for key + const orig = findings.find((o) => o.id === f.id); + if (!orig) return; + const result = await retryFinding(connectionId, f.id, orig.key); + if (result.status === 'fixed') { + toast.success(`Fixed: ${f.title}`); + onComplete?.(); + } else if (result.status === 'needs_permissions') { + toast.error('Still missing permissions'); + } else { + toast.error(result.error ?? 'Retry failed'); + } + }} + /> + )} +
+ ); + })} +
+ + {/* Actions */} +
+ {!isDone && !isScanning && ( + + )} + {isScanning && ( + + )} + {isDone && hasSkippedOrFailed && ( + + )} + {isDone && ( + + )} + {!isDone && ( + + )} +
+ + )} + +
+ ); +} diff --git a/apps/app/src/app/(app)/[orgId]/cloud-tests/components/CloudSettingsModal.tsx b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/CloudSettingsModal.tsx index 394535f80c..be17668aa9 100644 --- a/apps/app/src/app/(app)/[orgId]/cloud-tests/components/CloudSettingsModal.tsx +++ b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/CloudSettingsModal.tsx @@ -7,14 +7,12 @@ import { Dialog, DialogContent, DialogDescription, - DialogFooter, DialogHeader, DialogTitle, } from '@trycompai/ui/dialog'; import { Button, Tabs, - TabsContent, TabsList, TabsTrigger, cn, @@ -24,8 +22,8 @@ import { useState } from 'react'; import { toast } from 'sonner'; interface CloudProvider { - id: string; // Provider slug (aws, gcp, azure) - connectionId: string; // The actual connection ID + id: string; + connectionId: string; name: string; status: string; accountId?: string; @@ -40,20 +38,12 @@ interface CloudSettingsModalProps { onUpdate: () => void; } -/** - * Get the appropriate text color class based on connection status - */ const getStatusColorClass = (status: string): string => { switch (status.toLowerCase()) { case 'active': return 'text-green-600 dark:text-green-400'; case 'error': return 'text-red-600 dark:text-red-400'; - case 'pending': - return 'text-amber-600 dark:text-amber-400'; - case 'paused': - case 'disconnected': - return 'text-muted-foreground'; default: return 'text-muted-foreground'; } @@ -68,24 +58,18 @@ export function CloudSettingsModal({ const api = useApi(); const { hasPermission } = usePermissions(); const canDelete = hasPermission('integration', 'delete'); - const [activeTab, setActiveTab] = useState(connectedProviders[0]?.connectionId || 'aws'); + const [activeProvider, setActiveProvider] = useState(connectedProviders[0]?.connectionId || ''); const [isDeleting, setIsDeleting] = useState(false); const { deleteConnection } = useIntegrationMutations(); + const currentProvider = connectedProviders.find((p) => p.connectionId === activeProvider) ?? connectedProviders[0]; + const handleDisconnect = async (provider: CloudProvider) => { - if ( - !confirm( - 'Are you sure you want to disconnect this cloud provider? All scan results will be deleted.', - ) - ) { - return; - } + if (!confirm('Are you sure? All scan results will be deleted.')) return; try { setIsDeleting(true); - if (provider.isLegacy) { - // Legacy providers use the old Integration table const response = await api.delete(`/v1/cloud-security/legacy/${provider.connectionId}`); if (!response.error) { toast.success('Cloud provider disconnected'); @@ -96,8 +80,6 @@ export function CloudSettingsModal({ } return; } - - // New platform providers use the IntegrationConnection table const result = await deleteConnection(provider.connectionId); if (result.success) { toast.success('Cloud provider disconnected'); @@ -106,87 +88,106 @@ export function CloudSettingsModal({ } else { toast.error(result.error || 'Failed to disconnect'); } - } catch (error) { - console.error('Disconnect error:', error); + } catch { toast.error('An unexpected error occurred'); } finally { setIsDeleting(false); } }; - if (connectedProviders.length === 0) { - return null; - } + if (connectedProviders.length === 0) return null; return ( - + - Manage Cloud Connections + Connection Settings - Manage your cloud provider connections. To update credentials, disconnect and reconnect. + Manage your cloud provider connections. - - - {connectedProviders.map((provider) => ( - - {provider.name} - - ))} - - - {connectedProviders.map((provider) => ( - -
-
-

- {provider.name} is connected. Credentials are securely stored and encrypted at - rest. -

- {(provider.accountId || provider.regions?.length) && ( -
- {provider.accountId &&

Account: {provider.accountId}

} - {provider.regions?.length &&

Regions: {provider.regions.join(', ')}

} -
- )} -
- -
-
- Connection Status - - {provider.status} - -
-

- To update credentials, disconnect this provider and reconnect with new IAM role - settings. -

-
- - - {canDelete && ( - - )} - -
-
- ))} -
+ {/* Provider selector (if multiple) */} + {connectedProviders.length > 1 && ( + + + {connectedProviders.map((p) => ( + + {p.name} + + ))} + + + )} + + {currentProvider && ( + + )}
); } + +// ─── Connection Tab ───────────────────────────────────────────────────── + +function ConnectionTab({ + provider, + canDelete, + isDeleting, + onDisconnect, +}: { + provider: CloudProvider; + canDelete: boolean; + isDeleting: boolean; + onDisconnect: (p: CloudProvider) => void; +}) { + return ( +
+
+
+ Status + + {provider.status} + +
+ {provider.accountId && ( +
+ Account + {provider.accountId} +
+ )} + {provider.regions && provider.regions.length > 0 && ( +
+ Regions + {provider.regions.length} region{provider.regions.length !== 1 ? 's' : ''} +
+ )} +
+ +

+ {provider.id === 'aws' + ? 'To update credentials, disconnect and reconnect with new IAM role settings.' + : provider.id === 'gcp' + ? 'To update credentials, disconnect and reconnect with your Google account.' + : 'To update credentials, disconnect and reconnect with your Microsoft account.'} +

+ + {canDelete && ( + + )} +
+ ); +} diff --git a/apps/app/src/app/(app)/[orgId]/cloud-tests/components/CloudTestsSection.tsx b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/CloudTestsSection.tsx new file mode 100644 index 0000000000..636ec4063c --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/CloudTestsSection.tsx @@ -0,0 +1,1343 @@ +'use client'; + +import { useApi } from '@/hooks/use-api'; +import { Badge } from '@trycompai/ui/badge'; +import { Button } from '@trycompai/ui/button'; +import { + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, +} from '@trycompai/ui/dialog'; +import { + AlertTriangle, + Check, + ChevronDown, + ChevronRight, + Copy, + ExternalLink, + ListOrdered, + Loader2, + RefreshCw, + Search, + ShieldAlert, + ShieldCheck, + ShieldX, + Terminal, + Wrench, + X, + Zap, +} from 'lucide-react'; +import { awsRemediationScript } from '@trycompai/integration-platform'; +import { useCallback, useEffect, useMemo, useState } from 'react'; +import { toast } from 'sonner'; +import { getActiveBatch } from '../actions/batch-fix'; +import { BatchRemediationDialog } from './BatchRemediationDialog'; +import { AzureSetupGuide } from './AzureSetupGuide'; +import { GcpSetupGuide } from './GcpSetupGuide'; +import { RemediationDialog } from './RemediationDialog'; +import { ScheduledScanPopover } from './ScheduledScanPopover'; + +import type { Finding } from '../types'; + +interface RemediationCapabilities { + enabled: boolean; + remediations: Array<{ + remediationKey: string; + findingPattern: string; + description: string; + risk: string; + guidedOnly?: boolean; + guidedSteps?: string[]; + rollbackSupported?: boolean; + }>; +} + +interface CloudTestsSectionProps { + providerSlug: string; + connectionId: string; + onScanComplete?: () => void; + orgId: string; + /** When the last scan completed — null means never scanned */ + lastRunAt?: Date | null; + /** Connection variables (e.g., GCP org ID) */ + variables?: Record; +} + +const SEVERITY_ORDER: Record = { + critical: 0, high: 1, medium: 2, low: 3, info: 4, +}; + +const SEVERITY_STYLES: Record = { + critical: { + dot: 'bg-red-500', + badge: 'border-red-200 bg-red-50 text-red-700 dark:border-red-800 dark:bg-red-950 dark:text-red-400', + }, + high: { + dot: 'bg-orange-500', + badge: 'border-orange-200 bg-orange-50 text-orange-700 dark:border-orange-800 dark:bg-orange-950 dark:text-orange-400', + }, + medium: { + dot: 'bg-yellow-500', + badge: 'border-yellow-200 bg-yellow-50 text-yellow-700 dark:border-yellow-800 dark:bg-yellow-950 dark:text-yellow-400', + }, + low: { + dot: 'bg-blue-500', + badge: 'border-blue-200 bg-blue-50 text-blue-700 dark:border-blue-800 dark:bg-blue-950 dark:text-blue-400', + }, + info: { + dot: 'bg-gray-400', + badge: 'border-gray-200 bg-gray-50 text-gray-600 dark:border-gray-700 dark:bg-gray-900 dark:text-gray-400', + }, +}; + +const SERVICE_NAMES: Record = { + 'security-hub': 'Security Hub', + 'iam-analyzer': 'IAM Access Analyzer', + 'cloudtrail': 'CloudTrail', + 's3': 'S3 Bucket Security', + 'ec2-vpc': 'EC2 & VPC Security', + 'rds': 'RDS Security', + 'kms': 'KMS', + 'cloudwatch': 'CloudWatch', + 'config': 'AWS Config', + 'guardduty': 'GuardDuty', + 'secrets-manager': 'Secrets Manager', + 'waf': 'WAF', + 'elb': 'ELB / ALB', + 'acm': 'ACM', + 'backup': 'AWS Backup', + 'inspector': 'Inspector', + 'ecs-eks': 'ECS & EKS', + 'lambda': 'Lambda', + 'dynamodb': 'DynamoDB', + 'sns-sqs': 'SNS & SQS', + 'ecr': 'ECR', + 'opensearch': 'OpenSearch', + 'redshift': 'Redshift', + 'macie': 'Macie', + 'route53': 'Route 53', + 'api-gateway': 'API Gateway', + 'cloudfront': 'CloudFront', + 'cognito': 'Cognito', + 'elasticache': 'ElastiCache', + 'efs': 'EFS', + 'msk': 'MSK', + 'sagemaker': 'SageMaker', + 'systems-manager': 'Systems Manager', + 'codebuild': 'CodeBuild', + 'network-firewall': 'Network Firewall', + 'shield': 'Shield', + 'kinesis': 'Kinesis', + 'glue': 'Glue', + 'athena': 'Athena', + 'emr': 'EMR', + 'step-functions': 'Step Functions', + 'eventbridge': 'EventBridge', + 'transfer-family': 'Transfer Family', + 'elastic-beanstalk': 'Elastic Beanstalk', + 'appflow': 'AppFlow', +}; + +interface ServiceGroup { + serviceId: string; + name: string; + findings: Finding[]; + passed: number; + failed: number; +} + +export function CloudTestsSection({ + providerSlug, + connectionId, + onScanComplete, + orgId, + lastRunAt, + variables, +}: CloudTestsSectionProps) { + const api = useApi(); + const [scanCompleted, setScanCompleted] = useState(false); + const [scanError, setScanError] = useState<{ message: string; errorCode?: string } | null>(null); + const [isScanning, setIsScanning] = useState(false); + const [batchServiceId, setBatchServiceId] = useState(null); + const [activeBatch, setActiveBatch] = useState<{ + batchId: string; + triggerRunId: string; + accessToken: string; + findings: Array<{ id: string; title: string; status: string; error?: string }>; + } | null>(null); + const [expandedIds, setExpandedIds] = useState>(new Set()); + const [expandedGroups, setExpandedGroups] = useState>(new Set()); + const [severityFilter, setSeverityFilter] = useState(null); + const [projectFilter, setProjectFilter] = useState(null); + const [searchQuery, setSearchQuery] = useState(''); + const [capabilities, setCapabilities] = + useState(null); + const [capabilitiesLoaded, setCapabilitiesLoaded] = useState(false); + const [remediationTarget, setRemediationTarget] = useState<{ + connectionId: string; + checkResultId: string; + remediationKey: string; + findingTitle: string; + guidedOnly?: boolean; + guidedSteps?: string[]; + risk?: string; + description?: string; + } | null>(null); + const [showSetupDialog, setShowSetupDialog] = useState(false); + + const findingsResponse = api.useSWR<{ data: Finding[]; count: number }>( + '/v1/cloud-security/findings', + { revalidateOnFocus: true }, + ); + + const allFindings = Array.isArray(findingsResponse.data?.data?.data) + ? findingsResponse.data.data.data + : []; + const hasLoadedFindings = + findingsResponse.data !== undefined || findingsResponse.error !== undefined; + + // Load remediation capabilities for the selected connection + useEffect(() => { + if (!connectionId || (providerSlug !== 'aws' && providerSlug !== 'gcp' && providerSlug !== 'azure')) return; + + const loadCapabilities = async () => { + const resp = await api.get( + `/v1/cloud-security/remediation/capabilities?connectionId=${connectionId}`, + ); + if (!resp.error && resp.data) { + setCapabilities(resp.data as RemediationCapabilities); + } + setCapabilitiesLoaded(true); + }; + loadCapabilities(); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [connectionId, providerSlug]); + + // Check for active batch once on mount (separate from capabilities to avoid re-runs) + useEffect(() => { + if (!connectionId || (providerSlug !== 'aws' && providerSlug !== 'gcp' && providerSlug !== 'azure')) return; + let cancelled = false; + + getActiveBatch(connectionId).then((batch) => { + if (cancelled) return; + if (batch) { + setActiveBatch(batch); + setBatchServiceId('_active'); + } + }); + + return () => { cancelled = true; }; + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [connectionId]); + + const canFixFinding = useCallback( + (finding: Finding): { key: string; enabled: boolean } | null => { + if (!capabilities?.enabled || !finding.findingKey) return null; + // AI-powered: every finding with a findingKey can be analyzed + return { key: finding.findingKey, enabled: true }; + }, + [capabilities], + ); + + const findings = useMemo(() => { + return allFindings + .filter( + (f) => + f.providerSlug === providerSlug || f.connectionId === connectionId, + ) + .filter( + (f) => !projectFilter || f.projectDisplayName === projectFilter, + ) + .sort( + (a, b) => + (SEVERITY_ORDER[a.severity ?? 'info'] ?? 5) - + (SEVERITY_ORDER[b.severity ?? 'info'] ?? 5), + ); + }, [allFindings, providerSlug, connectionId, projectFilter]); + + // Unique project names across all findings (for filter pills) + const projectNames = useMemo(() => { + const names = new Set(); + for (const f of allFindings) { + if ( + (f.providerSlug === providerSlug || f.connectionId === connectionId) && + f.projectDisplayName + ) { + names.add(f.projectDisplayName); + } + } + return [...names].sort((a, b) => a.localeCompare(b)); + }, [allFindings, providerSlug, connectionId]); + + const failedFindings = findings.filter( + (f) => f.status === 'failed' || f.status === 'FAILED', + ); + const passedFindings = findings.filter( + (f) => f.status === 'passed' || f.status === 'success', + ); + + // Group findings by serviceId + const serviceGroups = useMemo(() => { + const q = searchQuery.toLowerCase().trim(); + const groupMap = new Map(); + for (const f of findings) { + const key = f.serviceId ?? 'other'; + const group = groupMap.get(key) ?? []; + group.push(f); + groupMap.set(key, group); + } + + const groups: ServiceGroup[] = []; + for (const [serviceId, groupFindings] of groupMap) { + const serviceName = SERVICE_NAMES[serviceId] ?? serviceId; + const serviceMatches = q ? serviceName.toLowerCase().includes(q) : true; + + const failed = groupFindings.filter( + (f) => f.status === 'failed' || f.status === 'FAILED', + ); + const passed = groupFindings.filter( + (f) => f.status === 'passed' || f.status === 'success', + ); + + let filteredFailed = severityFilter + ? failed.filter((f) => f.severity?.toLowerCase() === severityFilter) + : failed; + + // If search query exists and service name doesn't match, filter findings by title + if (q && !serviceMatches) { + filteredFailed = filteredFailed.filter( + (f) => + f.title?.toLowerCase().includes(q) || + f.description?.toLowerCase().includes(q) || + f.findingKey?.toLowerCase().includes(q), + ); + } + + groups.push({ + serviceId, + name: serviceName, + findings: filteredFailed, + passed: passed.length, + failed: failed.length, + }); + } + + return groups + .filter((g) => g.findings.length > 0 || (!severityFilter && !q && g.passed > 0)) + .sort((a, b) => b.failed - a.failed || a.name.localeCompare(b.name)); + }, [findings, severityFilter, searchQuery]); + + // Split into baseline (security fundamentals) vs service-specific + const BASELINE_SERVICE_IDS = new Set(['cloudtrail', 'config', 'guardduty', 'iam', 'cloudwatch', 'kms']); + const baselineGroups = providerSlug === 'aws' + ? serviceGroups.filter((g) => BASELINE_SERVICE_IDS.has(g.serviceId)) + : []; + const regularGroups = providerSlug === 'aws' + ? serviceGroups.filter((g) => !BASELINE_SERVICE_IDS.has(g.serviceId)) + : serviceGroups; + + const severityCounts = useMemo(() => { + const counts: Record = {}; + for (const f of failedFindings) { + const sev = f.severity?.toLowerCase() ?? 'info'; + counts[sev] = (counts[sev] ?? 0) + 1; + } + return counts; + }, [failedFindings]); + + const handleRunScan = useCallback(async () => { + if (!connectionId) return; + setIsScanning(true); + const startTime = Date.now(); + toast.message('Starting security scan...'); + setScanError(null); + try { + const response = await api.post<{ + success?: boolean; + message?: string; + errorCode?: string; + }>( + `/v1/cloud-security/scan/${connectionId}`, + {}, + ); + if (response.error) { + const data = response.data as { message?: string; errorCode?: string } | undefined; + const errorCode = data?.errorCode; + const message = data?.message ?? (typeof response.error === 'string' ? response.error : 'Scan failed'); + // GCP setup errors get persistent inline message + if (errorCode === 'SCC_NOT_ACTIVATED' || errorCode === 'GCP_ORG_MISSING') { + setScanError({ message, errorCode }); + } else { + toast.error(message); + } + return; + } + await findingsResponse.mutate(); + onScanComplete?.(); + setScanCompleted(true); + const elapsed = Math.round((Date.now() - startTime) / 1000); + toast.success(`Scan completed in ${elapsed}s!`); + } catch (err) { + toast.error( + `Scan failed: ${err instanceof Error ? err.message : 'Unknown error'}`, + ); + } finally { + setIsScanning(false); + } + }, [connectionId, api, findingsResponse, onScanComplete]); + + const toggleExpanded = (id: string) => { + setExpandedIds((prev) => { + const next = new Set(prev); + if (next.has(id)) next.delete(id); + else next.add(id); + return next; + }); + }; + + const toggleGroup = (serviceId: string) => { + setExpandedGroups((prev) => { + const next = new Set(prev); + if (next.has(serviceId)) next.delete(serviceId); + else next.add(serviceId); + return next; + }); + }; + + // Track if batch dialog is open so we can show the floating pill when minimized + // Compute batch dialog data LIVE from current findings (never stale) + const batchTarget = useMemo(() => { + if (!batchServiceId) return null; + const group = serviceGroups.find((g) => g.serviceId === batchServiceId); + if (!group) return null; + const fixable = group.findings.filter((f) => { + const match = canFixFinding(f); + return match?.key && match.enabled; + }); + if (fixable.length === 0) return null; + return { + serviceName: group.name, + findings: fixable.map((f) => ({ + id: f.id, + title: f.title, + key: f.findingKey!, + severity: f.severity ?? 'medium', + })), + }; + }, [batchServiceId, serviceGroups, canFixFinding]); + + const batchDialogOpen = Boolean(batchServiceId); + + if (!connectionId) return null; + + return ( +
+ {/* Active batch pill — shows when batch is running but dialog is minimized */} + {activeBatch && !batchDialogOpen && ( + + )} + + {/* Scanning banner */} + {isScanning && ( +
+ +
+

Scanning...

+

Verifying your cloud security posture. This may take a moment.

+
+
+ )} + + {/* Header with scan button */} +
+
+

Security Findings

+

+ {findings.length} total findings for this account +

+
+
+ + +
+
+ + {/* Selected projects indicator (GCP) */} + {providerSlug === 'gcp' && (() => { + const ids = Array.isArray(variables?.project_ids) + ? (variables.project_ids as string[]) + : []; + const savedNames = (variables?.project_names ?? {}) as Record; + return ids.length > 0 ? ( +
+ + + + {ids.length} project{ids.length > 1 ? 's' : ''}: +
+ {ids.map((id: string) => { + const name = savedNames[id]; + return ( + + {name ?? id} + {name && {id}} + + ); + })} +
+ + Change + +
+ ) : ( +
+
+ + + + No projects selected — select projects to scope your scan. +
+ + Select projects + +
+ ); + })()} + + {/* Stats row */} +
+ } + value={passedFindings.length} + label="Passed" + accent="emerald" + /> + } + value={failedFindings.length} + label="Failed" + accent="red" + /> + } + value={findings.length} + label="Total" + accent="gray" + /> +
+ + {/* Severity filter pills */} + {failedFindings.length > 0 && ( +
+ + {Object.entries(SEVERITY_ORDER) + .sort(([, a], [, b]) => a - b) + .map(([sev]) => + severityCounts[sev] ? ( + + ) : null, + )} +
+ )} + + {/* Project filter (GCP multi-project) */} + {projectNames.length > 1 && ( +
+ + Project + + + {projectNames.map((name) => ( + + ))} +
+ )} + + {/* Search */} + {findings.length > 0 && ( +
+ + setSearchQuery(e.target.value)} + className="min-w-0 flex-1 bg-transparent text-xs outline-none placeholder:text-muted-foreground/40" + /> + {searchQuery && ( + + )} +
+ )} + + {/* Service findings */} + {regularGroups.length > 0 && ( +
+ {regularGroups.map((group) => { + const isGroupExpanded = expandedGroups.has(group.serviceId); + const hasFailures = group.findings.length > 0; + + return ( +
+ + {isGroupExpanded && ( +
+ {group.findings.length > 0 ? ( + group.findings.map((finding) => { + const match = canFixFinding(finding); + return ( + toggleExpanded(finding.id)} + remediationKey={match?.key ?? null} + remediationEnabled={match?.enabled ?? false} + capabilitiesLoaded={capabilitiesLoaded} + onFix={(key) => + setRemediationTarget({ + connectionId: finding.connectionId, + checkResultId: finding.id, + remediationKey: key, + findingTitle: finding.title ?? 'Finding', + }) + } + onSetup={() => setShowSetupDialog(true)} + /> + ); + }) + ) : ( +
+ + All {group.passed} checks passed +
+ )} +
+ )} +
+ ); + })} +
+ )} + + {/* Security baseline findings */} + {baselineGroups.length > 0 && ( +
+
+

Security Baseline

+
+
+

+ Core security checks that apply to every cloud account, regardless of which services you use. +

+ {baselineGroups.map((group) => { + const isGroupExpanded = expandedGroups.has(group.serviceId); + const hasFailures = group.findings.length > 0; + + return ( +
+ + {isGroupExpanded && ( +
+ {group.findings.length > 0 ? ( + group.findings.map((finding) => { + const match = canFixFinding(finding); + return ( + toggleExpanded(finding.id)} + remediationKey={match?.key ?? null} + remediationEnabled={match?.enabled ?? false} + capabilitiesLoaded={capabilitiesLoaded} + onFix={(key) => + setRemediationTarget({ + connectionId: finding.connectionId, + checkResultId: finding.id, + remediationKey: key, + findingTitle: finding.title ?? 'Finding', + }) + } + onSetup={() => setShowSetupDialog(true)} + /> + ); + }) + ) : ( +
+ + All {group.passed} checks passed +
+ )} +
+ )} +
+ ); + })} +
+ )} + + {/* No search results */} + {searchQuery && serviceGroups.length === 0 && findings.length > 0 && ( +
+ +

+ No findings matching "{searchQuery}" +

+ +
+ )} + + {/* GCP setup error — SCC not activated or org missing */} + {scanError && ( +
+
+
+ +
+
+

+ {scanError.errorCode === 'SCC_NOT_ACTIVATED' + ? 'Security Command Center is not activated' + : scanError.errorCode === 'GCP_ORG_MISSING' + ? 'GCP Organization not detected' + : 'Setup required'} +

+

+ {scanError.message} +

+ {scanError.errorCode === 'SCC_NOT_ACTIVATED' && ( + + Open GCP Console + + + )} +
+
+
+ )} + + {/* Empty state — never scanned */} + {hasLoadedFindings && findings.length === 0 && !lastRunAt && !scanCompleted && !scanError && ( + providerSlug === 'gcp' ? ( + 0 + : Boolean(variables?.project_id) + } + onRunScan={handleRunScan} + isScanning={isScanning} + orgId={orgId} + /> + ) : providerSlug === 'azure' ? ( + + ) : ( +
+
+ +
+

No scan results yet

+

+ Run a security scan to check your cloud posture. You can configure which services to scan in the Services tab. +

+
+ ) + )} + + {/* All checks passed — clean posture (AWS: has passed findings; GCP: scan ran but 0 findings) */} + {failedFindings.length === 0 && !findingsResponse.isValidating && (passedFindings.length > 0 || ((lastRunAt || scanCompleted) && findings.length === 0)) && serviceGroups.length === 0 && ( +
+
+ +
+

+ Looking good! +

+

+ {passedFindings.length > 0 + ? `All ${passedFindings.length} security checks passed — no issues found` + : 'Security scan completed — no issues found'} +

+
+ )} + + {/* Remediation dialog */} + {remediationTarget && ( + { + if (!open) setRemediationTarget(null); + }} + connectionId={remediationTarget.connectionId} + checkResultId={remediationTarget.checkResultId} + remediationKey={remediationTarget.remediationKey} + findingTitle={remediationTarget.findingTitle} + providerSlug={providerSlug} + guidedOnly={remediationTarget.guidedOnly} + guidedSteps={remediationTarget.guidedSteps} + risk={remediationTarget.risk} + description={remediationTarget.description} + onComplete={() => { + toast.message('Re-scanning to verify fix...'); + handleRunScan(); + }} + /> + )} + + {/* Batch remediation dialog */} + {batchServiceId && ( + { + if (!open) setBatchServiceId(null); + }} + serviceName={batchTarget?.serviceName ?? 'Resuming'} + findings={batchTarget?.findings ?? []} + connectionId={connectionId} + organizationId={orgId} + activeBatch={activeBatch} + onRunStarted={(info) => { + setActiveBatch({ ...info, findings: [] }); + }} + onComplete={() => { + setActiveBatch(null); + // Task already triggers a re-scan — just refresh the findings list + findingsResponse.mutate(); + }} + /> + )} + + {/* Remediation setup dialog */} + { + setShowSetupDialog(false); + // Reload capabilities after role ARN is saved + const loadCapabilities = async () => { + const resp = await api.get( + `/v1/cloud-security/remediation/capabilities?connectionId=${connectionId}`, + ); + if (!resp.error && resp.data) { + setCapabilities(resp.data as RemediationCapabilities); + } + }; + loadCapabilities(); + }} + /> + +
+ ); +} + +function StatCard({ + icon, + value, + label, + accent, +}: { + icon: React.ReactNode; + value: number; + label: string; + accent: string; +}) { + return ( +
+
+ {icon} +
+
+

{value}

+

{label}

+
+
+ ); +} + +function RemediationSetupDialog({ + open, + onOpenChange, + orgId, + connectionId, + onSaved, +}: { + open: boolean; + onOpenChange: (open: boolean) => void; + orgId: string; + connectionId: string; + onSaved?: () => void; +}) { + const api = useApi(); + const [copied, setCopied] = useState(false); + const [roleArn, setRoleArn] = useState(''); + const [saving, setSaving] = useState(false); + const [saveError, setSaveError] = useState(null); + + const finalScript = awsRemediationScript.replace( + /YOUR_EXTERNAL_ID/g, + orgId, + ); + + const handleCopy = useCallback(() => { + navigator.clipboard.writeText(finalScript); + setCopied(true); + setTimeout(() => setCopied(false), 2000); + }, [finalScript]); + + const handleSaveRoleArn = useCallback(async () => { + if (!roleArn.trim() || !connectionId) return; + + const arnPattern = /^arn:aws:iam::\d{12}:role\/.+$/; + if (!arnPattern.test(roleArn.trim())) { + setSaveError('Invalid ARN format. Expected: arn:aws:iam:::role/'); + return; + } + + setSaving(true); + setSaveError(null); + try { + const resp = await api.put(`/v1/connections/${connectionId}/credentials`, { + credentials: { remediationRoleArn: roleArn.trim() }, + }); + if (resp.error) { + setSaveError(typeof resp.error === 'string' ? resp.error : 'Failed to save Role ARN'); + return; + } + toast.success('Remediation Role ARN saved'); + setRoleArn(''); + onSaved?.(); + } catch { + setSaveError('Failed to save Role ARN'); + } finally { + setSaving(false); + } + }, [api, connectionId, roleArn, onSaved]); + + return ( + + + + Enable Auto-Remediation + + Set up a remediation IAM role to enable auto-fix capabilities for + your AWS security findings. + + + +
+
+
+
+ +
+
+

Remediation Role Setup

+

+ Create a write-access IAM role for auto-fix +

+
+
+ +
+
+ + 1 + +

+ Copy the setup script and run it in AWS CloudShell +

+
+
+ + 2 + +

+ Paste the{' '} + Role ARN{' '} + from the output below +

+
+
+ +
+ + + + Open CloudShell + +
+
+ + {/* Role ARN input */} +
+ +
+ { + setRoleArn(e.target.value); + setSaveError(null); + }} + className="flex-1 rounded-md border bg-background px-3 py-2 text-xs placeholder:text-muted-foreground/50 focus:outline-none focus:ring-2 focus:ring-primary/30" + /> + +
+ {saveError && ( +

{saveError}

+ )} +
+ +

+ The remediation role is separate from your audit role — your audit + role stays read-only. +

+
+
+
+ ); +} + +function FindingRow({ + finding, + expanded, + onToggle, + remediationKey, + remediationEnabled, + capabilitiesLoaded, + onFix, + onSetup, +}: { + finding: Finding; + expanded: boolean; + onToggle: () => void; + remediationKey: string | null; + remediationEnabled: boolean; + capabilitiesLoaded: boolean; + onFix: (key: string) => void; + onSetup: () => void; +}) { + const severity = finding.severity?.toLowerCase() ?? 'info'; + const styles = SEVERITY_STYLES[severity] ?? SEVERITY_STYLES.info; + + const handleFixClick = (e: React.MouseEvent) => { + e.stopPropagation(); + if (!remediationKey) return; + if (remediationEnabled) { + onFix(remediationKey); + } else { + onSetup(); + } + }; + + const renderFixButton = () => { + if (!capabilitiesLoaded) { + return ( + e.stopPropagation()} className="shrink-0"> + + + ); + } + + if (!remediationKey) { + return null; + } + + // AI-powered: every finding with a key gets Fix + return ( + + ); + }; + + return ( +
+
{ + // Don't toggle if user clicked a button or interactive element + const target = e.target as HTMLElement; + if (target.closest('button') || target.closest('a') || target.tagName === 'BUTTON') return; + onToggle(); + }} + onKeyDown={(e) => { + if (e.key === 'Enter' || e.key === ' ') { + e.preventDefault(); + onToggle(); + } + }} + > + + {expanded ? ( + + ) : ( + + )} + + + + {finding.title ?? 'Untitled finding'} + + {finding.projectDisplayName && ( + + {finding.projectDisplayName} + + )} + e.stopPropagation()} onKeyDown={(e) => e.stopPropagation()}> + {renderFixButton()} + + + {severity} + +
+ {expanded && ( +
+ {finding.description && ( +

+ {finding.description} +

+ )} + {finding.remediation && ( +
+

Remediation

+

+ {finding.remediation} +

+
+ )} +
+ )} +
+ ); +} diff --git a/apps/app/src/app/(app)/[orgId]/cloud-tests/components/EmptyState.tsx b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/EmptyState.tsx index 3a2d1d00d3..e159de2dec 100644 --- a/apps/app/src/app/(app)/[orgId]/cloud-tests/components/EmptyState.tsx +++ b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/EmptyState.tsx @@ -67,7 +67,7 @@ interface ProviderFieldWithOptions extends ProviderFieldBase { options?: { value: string; label: string }[]; } -const PROVIDER_FIELDS: Record<'aws' | 'gcp' | 'azure', ProviderFieldWithOptions[]> = { +const PROVIDER_FIELDS: Partial> = { aws: [ { id: 'connectionName', @@ -89,48 +89,6 @@ const PROVIDER_FIELDS: Record<'aws' | 'gcp' | 'azure', ProviderFieldWithOptions[ type: 'password', }, ], - gcp: [ - { - id: 'organization_id', - label: 'Organization ID', - placeholder: '123456789012', - helpText: 'Console → IAM & Admin → Settings', - }, - { - id: 'service_account_key', - label: 'Service Account Key', - placeholder: 'Paste your JSON key here', - helpText: 'IAM & Admin → Service Accounts → Keys → Add Key', - type: 'textarea', - }, - ], - azure: [ - { - id: 'AZURE_SUBSCRIPTION_ID', - label: 'Subscription ID', - placeholder: '00000000-0000-0000-0000-000000000000', - helpText: 'Azure Portal → Subscriptions', - }, - { - id: 'AZURE_TENANT_ID', - label: 'Tenant ID', - placeholder: '00000000-0000-0000-0000-000000000000', - helpText: 'Azure Active Directory → Overview', - }, - { - id: 'AZURE_CLIENT_ID', - label: 'Client ID', - placeholder: '00000000-0000-0000-0000-000000000000', - helpText: 'App registrations → Overview', - }, - { - id: 'AZURE_CLIENT_SECRET', - label: 'Client Secret', - placeholder: 'Enter your client secret', - helpText: 'App registrations → Certificates & secrets', - type: 'password', - }, - ], }; type TriggerInfo = { @@ -154,7 +112,7 @@ export function EmptyState({ const api = useApi(); const { hasPermission } = usePermissions(); const canCreate = hasPermission('integration', 'create'); - const initialUsesDialog = initialProvider === 'aws' || initialProvider === 'azure'; + const initialUsesDialog = initialProvider === 'aws' || initialProvider === 'gcp' || initialProvider === 'azure'; const [step, setStep] = useState( initialProvider && !initialUsesDialog ? 'connect' : 'choose', ); @@ -162,8 +120,8 @@ export function EmptyState({ initialProvider && !initialUsesDialog ? initialProvider : null, ); const [showConnectDialog, setShowConnectDialog] = useState(initialUsesDialog); - const [connectDialogProvider, setConnectDialogProvider] = useState<'aws' | 'azure'>( - initialProvider === 'azure' ? 'azure' : 'aws', + const [connectDialogProvider, setConnectDialogProvider] = useState<'aws' | 'gcp' | 'azure'>( + initialProvider === 'azure' ? 'azure' : initialProvider === 'gcp' ? 'gcp' : 'aws', ); const [credentials, setCredentials] = useState>({}); const [errors, setErrors] = useState>({}); @@ -172,14 +130,14 @@ export function EmptyState({ const [awsAccountId, setAwsAccountId] = useState(''); useEffect(() => { - if (initialProvider === 'aws' || initialProvider === 'azure') { + if (initialProvider === 'aws' || initialProvider === 'gcp' || initialProvider === 'azure') { setConnectDialogProvider(initialProvider); setShowConnectDialog(true); } }, [initialProvider]); const handleProviderSelect = (providerId: CloudProvider) => { - if (providerId === 'aws' || providerId === 'azure') { + if (providerId === 'aws' || providerId === 'gcp' || providerId === 'azure') { setConnectDialogProvider(providerId); setShowConnectDialog(true); return; @@ -215,6 +173,7 @@ export function EmptyState({ const validateFields = (): boolean => { if (!selectedProvider) return false; const fields = PROVIDER_FIELDS[selectedProvider]; + if (!fields) return true; // OAuth providers (GCP/Azure) don't have credential fields const newErrors: Record = {}; fields.forEach((field) => { @@ -450,12 +409,18 @@ export function EmptyState({ onOpenChange={(open) => setShowConnectDialog(open)} integrationId={connectDialogProvider} integrationName={ - connectDialogProvider === 'azure' ? 'Microsoft Azure' : 'Amazon Web Services' + connectDialogProvider === 'gcp' + ? 'Google Cloud Platform' + : connectDialogProvider === 'azure' + ? 'Microsoft Azure' + : 'Amazon Web Services' } integrationLogoUrl={ - connectDialogProvider === 'azure' - ? 'https://img.logo.dev/azure.microsoft.com?token=pk_AZatYxV5QDSfWpRDaBxzRQ' - : 'https://img.logo.dev/aws.amazon.com?token=pk_AZatYxV5QDSfWpRDaBxzRQ' + connectDialogProvider === 'gcp' + ? 'https://img.logo.dev/cloud.google.com?token=pk_AZatYxV5QDSfWpRDaBxzRQ' + : connectDialogProvider === 'azure' + ? 'https://img.logo.dev/azure.microsoft.com?token=pk_AZatYxV5QDSfWpRDaBxzRQ' + : 'https://img.logo.dev/aws.amazon.com?token=pk_AZatYxV5QDSfWpRDaBxzRQ' } onConnected={() => { setShowConnectDialog(false); @@ -503,7 +468,7 @@ export function EmptyState({ // Step 2: Connect (Form) if (step === 'connect' && provider) { - const fields = PROVIDER_FIELDS[provider.id]; + const fields = PROVIDER_FIELDS[provider.id as keyof typeof PROVIDER_FIELDS]; return ( @@ -547,7 +512,7 @@ export function EmptyState({ - {fields.map((field) => { + {fields?.map((field) => { const stringValue: string = typeof credentials[field.id] === 'string' ? (credentials[field.id] as string) diff --git a/apps/app/src/app/(app)/[orgId]/cloud-tests/components/GcpSetupGuide.test.tsx b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/GcpSetupGuide.test.tsx new file mode 100644 index 0000000000..a2f63857ce --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/GcpSetupGuide.test.tsx @@ -0,0 +1,109 @@ +import { render, screen, waitFor } from '@testing-library/react'; +import { describe, expect, it, vi, beforeEach } from 'vitest'; + +const mockPost = vi.fn(); + +vi.mock('@/hooks/use-api', () => ({ + useApi: () => ({ + post: mockPost, + }), +})); + +vi.mock('sonner', () => ({ + toast: { + error: vi.fn(), + success: vi.fn(), + message: vi.fn(), + }, +})); + +import { GcpSetupGuide } from './GcpSetupGuide'; + +describe('GcpSetupGuide', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('renders actionable failed setup steps from API response', async () => { + mockPost.mockResolvedValue({ + data: { + email: 'user@example.com', + organizationId: '123456789', + steps: [ + { + id: 'enable_security_command_center_api', + name: 'Enable Security Command Center API', + success: false, + error: 'Permission denied', + requiredForScan: true, + resolveAction: { + label: 'Resolve this', + method: 'POST', + endpoint: '/v1/cloud-security/setup-gcp/conn_1/resolve-step', + body: { stepId: 'enable_security_command_center_api' }, + }, + adminActions: [ + { + kind: 'link', + label: 'Open API', + url: 'https://console.cloud.google.com/apis/library/securitycenter.googleapis.com', + }, + ], + }, + { + id: 'grant_findings_viewer_role', + name: 'Grant Findings Viewer role', + success: false, + error: 'Need org admin role', + requiredForScan: true, + resolveAction: { + label: 'Resolve this', + method: 'POST', + endpoint: '/v1/cloud-security/setup-gcp/conn_1/resolve-step', + body: { stepId: 'grant_findings_viewer_role' }, + }, + adminActions: [ + { + kind: 'link', + label: 'Open IAM', + url: 'https://console.cloud.google.com/iam-admin/iam', + }, + ], + }, + ], + }, + }); + + render( + , + ); + + await waitFor(() => + expect( + screen.getByText('Some required setup steps need manual action:'), + ).toBeInTheDocument(), + ); + + expect( + screen.getAllByText('Enable Security Command Center API').length, + ).toBeGreaterThan(0); + expect(screen.getAllByText('Grant Findings Viewer role').length).toBeGreaterThan(0); + expect(screen.getByRole('link', { name: /Open API/i })).toHaveAttribute( + 'href', + 'https://console.cloud.google.com/apis/library/securitycenter.googleapis.com', + ); + expect(screen.getByRole('link', { name: /Open IAM/i })).toHaveAttribute( + 'href', + 'https://console.cloud.google.com/iam-admin/iam', + ); + expect(screen.getByRole('button', { name: 'Resolve all' })).toBeInTheDocument(); + expect(screen.getAllByRole('button', { name: 'Resolve this' })).toHaveLength(2); + }); +}); diff --git a/apps/app/src/app/(app)/[orgId]/cloud-tests/components/GcpSetupGuide.tsx b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/GcpSetupGuide.tsx new file mode 100644 index 0000000000..53cb567793 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/GcpSetupGuide.tsx @@ -0,0 +1,450 @@ +'use client'; + +import { useApi } from '@/hooks/use-api'; +import { Check, Copy, ExternalLink, Loader2, X } from 'lucide-react'; +import { useEffect, useRef, useState } from 'react'; +import { toast } from 'sonner'; + +interface GcpSetupGuideProps { + connectionId: string; + hasOrgId: boolean; + hasSelectedProjects: boolean; + onRunScan: () => void; + isScanning: boolean; + orgId: string; +} + +interface GcpProject { + id: string; + name: string; +} + +interface SetupStep { + id: string; + name: string; + success: boolean; + error?: string; + actionUrl?: string; + actionText?: string; + requiredForScan?: boolean; + resolveAction?: { + label: string; + method: 'POST'; + endpoint: string; + body: { stepId: string }; + }; + adminActions?: Array< + | { kind: 'link'; label: string; url: string } + | { kind: 'command'; label: string; command: string } + >; +} + + +export function GcpSetupGuide({ + connectionId, + hasOrgId, + hasSelectedProjects, + onRunScan, + isScanning, + orgId, +}: GcpSetupGuideProps) { + const api = useApi(); + const [isSettingUp, setIsSettingUp] = useState(false); + const [resolvingStepId, setResolvingStepId] = useState(null); + const [copiedCommandKey, setCopiedCommandKey] = useState(null); + const [projects, setProjects] = useState([]); + const [selectedProjectId, setSelectedProjectId] = useState(null); + const [setupResult, setSetupResult] = useState<{ + email: string | null; + steps: SetupStep[]; + organizationId?: string; + } | null>(null); + + const ranRef = useRef(false); + + // Auto-run setup on first mount (only if projects are selected) + useEffect(() => { + if (ranRef.current || !hasSelectedProjects) return; + ranRef.current = true; + handleAutoSetup(); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [hasSelectedProjects]); + + const hasBlockingFailuresForSteps = (steps: SetupStep[]) => + steps.some((step) => !step.success && step.requiredForScan !== false); + + const handleAutoSetup = async (overrideProjectId?: string) => { + setIsSettingUp(true); + try { + const body: { projectId?: string } = {}; + const projectToUse = overrideProjectId ?? selectedProjectId; + if (projectToUse) body.projectId = projectToUse; + + const resp = await api.post<{ + email: string | null; + steps: SetupStep[]; + organizationId?: string; + projectId?: string; + projects?: GcpProject[]; + }>(`/v1/cloud-security/setup-gcp/${connectionId}`, body); + + if (resp.error) { + toast.error(typeof resp.error === 'string' ? resp.error : 'Setup failed'); + return; + } + + if (resp.data) { + setSetupResult(resp.data); + if (resp.data.projects?.length) setProjects(resp.data.projects); + if (resp.data.projectId) setSelectedProjectId(resp.data.projectId); + + const succeeded = resp.data.steps.filter((s) => s.success).length; + const total = resp.data.steps.length; + const hasBlockingFailures = hasBlockingFailuresForSteps(resp.data.steps); + if (!hasBlockingFailures) { + toast.success('Required setup complete — running first scan...'); + onRunScan(); + } else { + toast.message(`${succeeded}/${total} steps completed. See details below.`); + } + } + } catch { + toast.error('Setup failed'); + } finally { + setIsSettingUp(false); + } + }; + + const handleResolveStep = async (step: SetupStep) => { + if (!step.resolveAction) return; + setResolvingStepId(step.id); + try { + const resp = await api.post<{ + email: string | null; + step: SetupStep; + organizationId?: string; + projects?: GcpProject[]; + }>(step.resolveAction.endpoint, step.resolveAction.body); + + if (resp.data?.projects?.length) setProjects(resp.data.projects); + + if (resp.error || !resp.data?.step) { + toast.error(typeof resp.error === 'string' ? resp.error : 'Could not resolve this step'); + return; + } + + const wasBlocking = setupResult ? hasBlockingFailuresForSteps(setupResult.steps) : true; + let nextSteps: SetupStep[] = []; + + setSetupResult((prev) => { + const previous = prev ?? { + email: resp.data?.email ?? null, + organizationId: resp.data?.organizationId, + steps: [], + }; + const existing = previous.steps.find((s) => s.id === resp.data!.step.id); + nextSteps = existing + ? previous.steps.map((s) => (s.id === resp.data!.step.id ? resp.data!.step : s)) + : [...previous.steps, resp.data!.step]; + + return { + ...previous, + email: resp.data?.email ?? previous.email, + organizationId: resp.data?.organizationId ?? previous.organizationId, + steps: nextSteps, + }; + }); + + if (resp.data.step.success) { + toast.success(`${resp.data.step.name} resolved`); + } else { + toast.message(`Still blocked: ${resp.data.step.name}`); + } + + const isBlockingNow = hasBlockingFailuresForSteps( + nextSteps.length > 0 ? nextSteps : setupResult?.steps ?? [], + ); + if (wasBlocking && !isBlockingNow) { + toast.success('Required setup complete — running first scan...'); + onRunScan(); + } + } catch { + toast.error('Could not resolve this step'); + } finally { + setResolvingStepId(null); + } + }; + + const handleCopyCommand = async (copyKey: string, command: string) => { + try { + await navigator.clipboard.writeText(command); + setCopiedCommandKey(copyKey); + setTimeout(() => setCopiedCommandKey(null), 1600); + toast.success('Command copied'); + } catch { + toast.error('Failed to copy command'); + } + }; + + const allStepsSucceeded = setupResult?.steps.every((s) => s.success); + const failedSteps = setupResult?.steps.filter((s) => !s.success) ?? []; + const failedRequiredSteps = failedSteps.filter((step) => step.requiredForScan !== false); + const failedOptionalSteps = failedSteps.filter((step) => step.requiredForScan === false); + const hasBlockingFailures = failedRequiredSteps.length > 0; + const getAdminActions = (step: SetupStep) => + step.adminActions ?? + (step.actionUrl && step.actionText + ? [{ kind: 'link' as const, label: step.actionText, url: step.actionUrl }] + : []); + + return ( +
+
+
+

Get started with GCP scanning

+

+ OAuth signs in your account, but GCP still requires org-level IAM/API access for Security Command Center. We'll try to set it up automatically first. +

+

+ For full auto-fix and rollback capabilities, connect with a GCP account that has Owner or Editor role on the selected projects. +

+
+ + {/* No projects selected — direct user to integrations page */} + {!hasSelectedProjects && ( + + )} + + {/* Auto-setup in progress */} + {hasSelectedProjects && !setupResult && ( +
+ + {hasOrgId && } + +
+ +

Setting up GCP scanning...

+
+
+ )} + + {/* Setup results */} + {setupResult && ( +
+ + {setupResult.organizationId && ( + + )} + {setupResult.email && ( + + )} + + {/* Project info */} + {selectedProjectId && ( + p.id === selectedProjectId)?.name ?? selectedProjectId}`} + /> + )} + + {setupResult.steps.map((step) => ( + + ))} +
+ )} + + {/* Manual fallback for failed steps */} + {setupResult && !allStepsSucceeded && ( +
+

+ {hasBlockingFailures + ? 'Some required setup steps need manual action:' + : 'Scan can still work. The remaining steps are optional for auto-setup:'} +

+
+ {failedSteps.map((step) => ( +
+
+
+

{step.name}

+ {step.error && ( +

{step.error}

+ )} +
+ + {step.requiredForScan === false ? 'Optional' : 'Required'} + +
+
+ {step.resolveAction && ( + + )} + {getAdminActions(step).map((action, index) => + action.kind === 'link' ? ( + + {action.label} + + + ) : ( + + ), + )} +
+
+ ))} +
+ {!hasBlockingFailures && failedOptionalSteps.length > 0 && ( +

+ Optional steps improve automatic setup and future onboarding, but they are not required for reading findings. +

+ )} +
+ )} + + {/* Run scan button — only shown if setup partially failed */} + {setupResult && !allStepsSucceeded && ( +
+ + +
+ )} +
+ +
+ ); +} + +function StepRow({ + done, + failed, + optional, + label, + error, +}: { + done?: boolean; + failed?: boolean; + optional?: boolean; + label: string; + error?: string; +}) { + return ( +
+
+ {done && } + {failed && } +
+
+

+ {label} +

+ {error && ( +

{error}

+ )} +
+
+ ); +} diff --git a/apps/app/src/app/(app)/[orgId]/cloud-tests/components/PermissionErrorPanel.tsx b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/PermissionErrorPanel.tsx new file mode 100644 index 0000000000..4211316a79 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/PermissionErrorPanel.tsx @@ -0,0 +1,290 @@ +'use client'; + +import { Button } from '@trycompai/ui/button'; +import { Check, Copy, ExternalLink, RefreshCw, ShieldAlert } from 'lucide-react'; +import { useState } from 'react'; +import { toast } from 'sonner'; + +interface PermissionErrorPanelProps { + error: string; + /** Missing IAM actions extracted from the error (preferred). */ + missingActions?: string[]; + /** Planned API calls from the preview (fallback for AWS). */ + apiCalls?: string[]; + /** Ready-to-paste fix script from backend (preferred over client-side). */ + fixScript?: string; + /** Cloud provider — affects script format and links. */ + provider?: 'aws' | 'gcp' | 'azure'; + /** Retry the remediation after the user fixes permissions. */ + onRetry?: () => void; + isRetrying?: boolean; + isWaiting?: boolean; +} + +/** Extract IAM actions from the error message itself (client-side parsing). */ +function extractActionsFromError(error: string): string[] { + const patterns = [ + // AWS patterns + /not authorized to perform:\s*([\w:*]+)/i, + /required\s+([\w:*]+)\s+permission/i, + /denied.*?(?:action|for):\s*([\w:*]+)/i, + // GCP patterns + /permission\s+'([\w.]+)'/i, + /does not have\s+([\w.]+)\s+access/i, + /'([\w.]+)'\s*denied/i, + ]; + const actions = new Set(); + for (const pattern of patterns) { + const match = error.match(pattern); + if (match?.[1]) actions.add(match[1]); + } + return [...actions]; +} + +/** Known AWS service-linked role patterns. */ +const SERVICE_LINKED_ROLE_PATTERNS: { pattern: RegExp; service: string; command: string }[] = [ + { pattern: /config.*service-linked role/i, service: 'AWS Config', command: 'aws iam create-service-linked-role --aws-service-name config.amazonaws.com' }, + { pattern: /guardduty.*service-linked role|service-linked role.*guardduty/i, service: 'GuardDuty', command: 'aws iam create-service-linked-role --aws-service-name guardduty.amazonaws.com' }, + { pattern: /inspector.*service-linked role/i, service: 'Inspector', command: 'aws iam create-service-linked-role --aws-service-name inspector2.amazonaws.com' }, + { pattern: /macie.*service-linked role/i, service: 'Macie', command: 'aws iam create-service-linked-role --aws-service-name macie.amazonaws.com' }, +]; + +function detectServiceLinkedRole(error: string): { service: string; command: string } | null { + if (!error.toLowerCase().includes('service-linked role')) return null; + for (const entry of SERVICE_LINKED_ROLE_PATTERNS) { + if (entry.pattern.test(error)) return entry; + } + return null; +} + +function buildAwsFixScript(actions: string[]): string | null { + if (actions.length === 0) return null; + const policy = JSON.stringify({ + Version: '2012-10-17', + Statement: [{ Effect: 'Allow', Action: actions, Resource: '*' }], + }); + return `aws iam put-role-policy --role-name CompAI-Remediator --policy-name CompAI-AutoFix --policy-document '${policy}'`; +} + +function isAzureError(error: string): boolean { + return ( + error.includes('AuthorizationFailed') || + error.includes('management.azure.com') || + error.includes('does not have authorization') + ); +} + +function isGcpError(error: string): boolean { + return ( + error.includes('PERMISSION_DENIED') || + error.includes('googleapis.com') || + /does not have\s+[\w.]+\s+access/i.test(error) || + /permission\s+'[\w.]+'/i.test(error) + ); +} + +export function PermissionErrorPanel({ + error, + missingActions, + apiCalls, + fixScript: backendScript, + provider, + onRetry, + isRetrying, + isWaiting, +}: PermissionErrorPanelProps) { + const [copied, setCopied] = useState(false); + + // Auto-detect provider if not specified + const detectedProvider = provider ?? (isAzureError(error) ? 'azure' : isGcpError(error) ? 'gcp' : 'aws'); + const isGcp = detectedProvider === 'gcp'; + const isAzure = detectedProvider === 'azure'; + + const serviceLinkedRole = isGcp ? null : detectServiceLinkedRole(error); + const isPermissionError = + serviceLinkedRole !== null || + error.includes('not authorized') || + error.includes('AccessDenied') || + error.includes('access denied') || + error.includes('PERMISSION_DENIED') || + error.includes('Permission denied') || + (error.includes('required') && error.includes('permission')); + + if (!isPermissionError) { + // Truncate long AI-generated messages for clean UX + const shortError = error.length > 150 + ? error.slice(0, 150).replace(/\s+\S*$/, '') + '…' + : error; + const hasDetails = error.length > 150; + + return ( +
+

Fix could not be applied

+

{shortError}

+ {hasDetails && ( +
+ + Show full details + +

{error}

+
+ )} + {onRetry && ( + + )} +
+ ); + } + + // Priority: service-linked role > backend script > client-parsed + const parsedFromError = extractActionsFromError(error); + const actions = missingActions?.length + ? missingActions + : parsedFromError.length + ? parsedFromError + : (apiCalls ?? []); + + const script = serviceLinkedRole + ? serviceLinkedRole.command + : backendScript ?? (isGcp || isAzure ? null : buildAwsFixScript(actions)); + + const shellName = isAzure ? 'Cloud Shell' : isGcp ? 'Cloud Shell' : 'CloudShell'; + const shellUrl = isAzure + ? 'https://portal.azure.com/#cloudshell/' + : isGcp + ? 'https://console.cloud.google.com/cloudshell' + : 'https://console.aws.amazon.com/cloudshell'; + const propagationText = isAzure + ? 'Role assignment changes in Azure may take a few minutes to propagate.' + : isGcp + ? 'IAM changes in GCP may take a few minutes to propagate.' + : 'IAM permission changes can take up to 10 seconds to propagate in AWS.'; + + const handleCopy = () => { + if (!script) return; + navigator.clipboard.writeText(script); + setCopied(true); + toast.success('Script copied to clipboard'); + setTimeout(() => setCopied(false), 2000); + }; + + return ( +
+
+
+ +
+

+ {serviceLinkedRole + ? 'Missing Service-Linked Role' + : isGcp + ? 'Missing GCP IAM Permission' + : 'Missing IAM Permission'} +

+

+ {serviceLinkedRole + ? `${serviceLinkedRole.service} requires a service-linked role. Create it with the command below, then retry.` + : isGcp + ? ( + <> + Your GCP account is missing permissions needed for this fix. + {actions.length > 0 && ( + <> + {' '}Missing:{' '} + {actions.map((a, i) => ( + + {i > 0 && ', '} + {a} + + ))} + + )} + + ) + : ( + <> + The remediation role is missing permissions needed for this fix. + {actions.length > 0 && ( + <> + {' '}Required:{' '} + {actions.map((a, i) => ( + + {i > 0 && ', '} + {a} + + ))} + + )} + + )} +

+
+
+
+ + {script && ( +
+

+ Run this in {isAzure ? 'Azure' : isGcp ? 'Google' : 'AWS'} {shellName} to add the permission: +

+
+            {script}
+          
+ {isGcp && ( +

+ Replace YOUR_EMAIL with your Google account email and YOUR_PROJECT_ID with your GCP project ID. +

+ )} +
+ + + + Open {shellName} + + {onRetry && ( + + )} +
+

+ {propagationText} +

+
+ )} +
+ ); +} diff --git a/apps/app/src/app/(app)/[orgId]/cloud-tests/components/ProviderTabs.tsx b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/ProviderTabs.tsx index 6d3268160d..b414cbd5ba 100644 --- a/apps/app/src/app/(app)/[orgId]/cloud-tests/components/ProviderTabs.tsx +++ b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/ProviderTabs.tsx @@ -1,8 +1,16 @@ +import { useApi } from '@/hooks/use-api'; +import { useConnectionServices } from '@/hooks/use-integration-platform'; +import { CLOUD_RECONNECT_CUTOFF_LABEL } from '@/lib/cloud-reconnect-policy'; import { Button, Select, SelectContent, SelectItem, SelectTrigger, SelectValue, Tabs, TabsContent, TabsList, TabsTrigger } from '@trycompai/design-system'; import { Add } from '@trycompai/design-system/icons'; -import { useState } from 'react'; +import { useCallback, useEffect, useRef, useState } from 'react'; +import { toast } from 'sonner'; import type { Finding, Provider } from '../types'; +import { ActivitySection } from '@/app/(app)/[orgId]/integrations/[slug]/components/ActivitySection'; +import { RemediationHistorySection } from '@/app/(app)/[orgId]/integrations/[slug]/components/RemediationHistorySection'; +import { CloudTestsSection } from './CloudTestsSection'; import { ResultsView } from './ResultsView'; +import { ServicesGrid } from './ServicesGrid'; interface ProviderTabsProps { providerGroups: Record; @@ -17,8 +25,10 @@ interface ProviderTabsProps { onAddConnection: (providerType: string) => void; onConfigure: (provider: Provider) => void; needsConfiguration: (provider: Provider) => boolean; + requiresReconnect: (provider: Provider) => boolean; canRunScan?: boolean; canAddConnection?: boolean; + orgId: string; } const formatProviderLabel = (providerType: string): string => { @@ -35,58 +45,6 @@ const formatProviderLabel = (providerType: string): string => { .join(' '); }; -/** - * AWS region pattern: matches formats like us-east-1, eu-west-2, ap-southeast-1, etc. - */ -const AWS_REGION_PATTERN = - /^(us|eu|ap|sa|ca|me|af|il)-(north|south|east|west|central|northeast|southeast|northwest|southwest)-\d$/; - -const isValidAwsRegion = (value: string): boolean => { - return AWS_REGION_PATTERN.test(value); -}; - -const extractRegionFromTitle = (title: string | null | undefined): string | null => { - if (!title) return null; - const match = title.match(/\s\(([-a-z0-9]+)\)\s*$/i); - if (!match) return null; - const candidate = match[1].toLowerCase(); - // Only return if it's actually an AWS region, not other suffixes like "nacl", "public", etc. - return isValidAwsRegion(candidate) ? candidate : null; -}; - -const stripRegionSuffix = (title: string | null | undefined): string | null => { - if (!title) return null; - // Only strip the suffix if it's a valid AWS region - const match = title.match(/\s\(([-a-z0-9]+)\)\s*$/i); - if (!match) return title; - const candidate = match[1].toLowerCase(); - // Keep non-region suffixes like "(nacl)", "(public)" as part of the title - return isValidAwsRegion(candidate) ? title.replace(/\s\(([-a-z0-9]+)\)\s*$/i, '').trim() : title; -}; - -const buildRegionOptions = ( - connection: Provider, - findings: Finding[], -): Array<{ id: string; label: string }> => { - const regionMap = new Map(); - - if (connection.regions?.length) { - for (const region of connection.regions) { - regionMap.set(region.toLowerCase(), region); - } - } else { - for (const finding of findings) { - const region = extractRegionFromTitle(finding.title); - if (region && !regionMap.has(region)) { - regionMap.set(region, region); - } - } - } - - return Array.from(regionMap.entries()) - .sort(([a], [b]) => a.localeCompare(b)) - .map(([id, label]) => ({ id, label })); -}; function ConnectionDetails({ connection }: { connection: Provider }) { const details: string[] = []; @@ -119,6 +77,153 @@ function ConnectionDetails({ connection }: { connection: Provider }) { ); } +/** Cloud provider connection with full tabbed UI (AWS + GCP) */ +function CloudConnectionContent({ + connection, + orgId, + onScanComplete, +}: { + connection: Provider; + orgId: string; + onScanComplete: () => void; +}) { + const api = useApi(); + const { + services, + meta: servicesMeta, + refresh: refreshServices, + updateServices, + } = useConnectionServices(connection.id); + const [togglingService, setTogglingService] = useState(null); + const detectedRef = useRef(false); + + // Auto-detect services on first load (AWS via Cost Explorer, GCP via Service Usage API) + useEffect(() => { + if (detectedRef.current || !connection.id) return; + if (connection.integrationId !== 'aws' && connection.integrationId !== 'gcp') return; + detectedRef.current = true; + + api.post(`/v1/cloud-security/detect-services/${connection.id}`, {}).then((resp) => { + if (!resp.error) { + const data = resp.data as { services?: string[] }; + if (data?.services?.length) { + toast.success(`${data.services.length} services detected`); + refreshServices(); + } + } + }); + }, [connection.id, connection.integrationId, api, refreshServices]); + + const handleToggleService = useCallback( + async (serviceId: string, enabled: boolean): Promise => { + setTogglingService(serviceId); + try { + await updateServices(serviceId, enabled); + return true; + } catch { + return false; + } finally { + setTogglingService(null); + } + }, + [updateServices], + ); + + // Derive manifest-like services from connection services + const manifestServices = services.map((s) => ({ + id: s.id, + name: s.name ?? s.id, + description: s.description ?? '', + implemented: s.implemented ?? true, + })); + + const enabledCount = services.filter((s) => s.enabled).length; + const waitingForDetection = connection.integrationId === 'gcp' && servicesMeta.detectionReady === false; + const showEnabledCount = !waitingForDetection; + + return ( + + + Findings + Activity + Remediations + + Services{showEnabledCount && enabledCount > 0 ? ` (${enabledCount})` : ''} + + + + +
+ +
+
+ + +
+ +
+
+ + +
+ +
+
+ + +
+
+
+

Scan Configuration

+

+ Toggle which services to include in scans.{connection.integrationId === 'aws' ? ' New services are auto-detected from your AWS usage.' : ''} +

+
+
+
+

Daily automated scan

+

+ Runs every day at 5:00 AM UTC{enabledCount > 0 ? ` across ${enabledCount} service${enabledCount !== 1 ? 's' : ''} + security baseline` : ' on security baseline checks'} +

+
+ + Active + +
+
+ {waitingForDetection ? ( +
+

Detecting active GCP services...

+

+ We'll show real service toggles as soon as detection completes. +

+
+ ) : manifestServices.length > 0 ? ( + + ) : ( +

+ No services detected yet.{connection.integrationId === 'aws' ? ' Services are auto-detected from your AWS billing data.' : ' Run a scan to detect services.'} +

+ )} +
+
+
+ ); +} + export function ProviderTabs({ providerGroups, providerTypes, @@ -132,11 +237,11 @@ export function ProviderTabs({ onAddConnection, onConfigure, needsConfiguration, + requiresReconnect, canRunScan, canAddConnection, + orgId, }: ProviderTabsProps) { - const [activeRegionTabs, setActiveRegionTabs] = useState>({}); - return (
@@ -208,58 +313,67 @@ export function ProviderTabs({
{connections.map((connection) => { - const connFindings = findingsByProvider[connection.id] ?? []; - const regionOptions = buildRegionOptions(connection, connFindings); - const showRegionTabs = - connection.integrationId.toLowerCase() === 'aws' && regionOptions.length >= 1; - const activeRegion = activeRegionTabs[connection.id] || 'all'; - const filteredFindings = - showRegionTabs && activeRegion !== 'all' - ? connFindings.filter( - (finding) => extractRegionFromTitle(finding.title) === activeRegion, - ) - : connFindings; - const displayFindings = filteredFindings.map((finding) => ({ - ...finding, - title: stripRegionSuffix(finding.title), - })); + const reconnectRequired = requiresReconnect(connection); return (
- - - {showRegionTabs && ( -
- - setActiveRegionTabs((prev) => ({ - ...prev, - [connection.id]: value, - })) - } - > - - All regions - {regionOptions.map((region) => ( - - {region.label} - - ))} - - + {reconnectRequired && ( +
+
+
+

Reconnect this account

+

+ This connection was created before {CLOUD_RECONNECT_CUTOFF_LABEL}. Reconnect it to keep scans and remediation fully reliable. +

+
+ {canAddConnection !== false && ( + + )} +
)} - onRunScan(connection.id)} - isScanning={isScanning} - needsConfiguration={needsConfiguration(connection)} - onConfigure={() => onConfigure(connection)} - canRunScan={canRunScan} - /> + + + {/* New platform connections get full tabbed UI */} + {!connection.isLegacy ? ( + onRunScan(connection.id)} + /> + ) : ( + <> + {/* Upgrade banner for legacy AWS connections */} + {connection.isLegacy && connection.integrationId === 'aws' && ( + +
+

Auto-fix is available

+

+ Upgrade to the new connection to enable one-click fixes, batch remediation, and rollback for all findings. +

+
+ + Upgrade → + +
+ )} + onRunScan(connection.id)} + isScanning={isScanning} + needsConfiguration={needsConfiguration(connection)} + onConfigure={() => onConfigure(connection)} + canRunScan={canRunScan} + /> + + )}
); diff --git a/apps/app/src/app/(app)/[orgId]/cloud-tests/components/RemediationDialog.tsx b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/RemediationDialog.tsx new file mode 100644 index 0000000000..53f2da103b --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/RemediationDialog.tsx @@ -0,0 +1,676 @@ +'use client'; + +import { useApi } from '@/hooks/use-api'; +import { Badge } from '@trycompai/ui/badge'; +import { Button } from '@trycompai/ui/button'; +import { + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, +} from '@trycompai/ui/dialog'; +import { AlertTriangle, ListOrdered, Loader2, RotateCcw } from 'lucide-react'; +import { useCallback, useEffect, useRef, useState } from 'react'; +import { toast } from 'sonner'; +import { AcknowledgmentPanel } from './AcknowledgmentPanel'; +import { PermissionErrorPanel } from './PermissionErrorPanel'; + +interface RemediationDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; + connectionId: string; + checkResultId: string; + remediationKey: string; + findingTitle: string; + providerSlug?: string; + guidedOnly?: boolean; + guidedSteps?: string[]; + risk?: string; + description?: string; + onComplete?: () => void; +} + +interface PreviewData { + currentState: Record; + proposedState: Record; + description: string; + risk: string; + apiCalls: string[]; + requiresAcknowledgment?: 'type-to-confirm' | 'checkbox'; + acknowledgmentMessage?: string; + confirmationPhrase?: string; + guidedOnly?: boolean; + guidedSteps?: string[]; + rollbackSupported?: boolean; + missingPermissions?: string[]; + permissionFixScript?: string; + allRequiredPermissions?: string[]; +} + +const RISK_STYLES: Record = { + low: 'border-emerald-200 bg-emerald-50 text-emerald-700', + medium: 'border-yellow-200 bg-yellow-50 text-yellow-700', + high: 'border-red-200 bg-red-50 text-red-700', + critical: 'border-purple-200 bg-purple-50 text-purple-700', +}; + +// ─── Helper components (must be declared before RemediationDialog) ────── + +function RichText({ text }: { text: string }) { + const urlRegex = /(https?:\/\/[^\s)]+)/g; + const parts = text.split(urlRegex); + if (parts.length === 1) { + return

{text}

; + } + return ( +

+ {parts.map((part, i) => + part.match(urlRegex) ? ( + {part} + ) : ({part}), + )} +

+ ); +} + +function CodeBlock({ code }: { code: string }) { + const [copied, setCopied] = useState(false); + const handleCopy = () => { + navigator.clipboard.writeText(code); + setCopied(true); + setTimeout(() => setCopied(false), 2000); + }; + return ( +
+ +
+        {code}
+      
+
+ ); +} + +function TextSegment({ text }: { text: string }) { + const parts = text.split(/(`[^`]+`)/g); + const rendered: React.ReactNode[] = []; + for (let i = 0; i < parts.length; i++) { + const part = parts[i] ?? ''; + if (part.startsWith('`') && part.endsWith('`')) { + const code = part.slice(1, -1); + if (code.length > 60 || code.startsWith('aws ') || code.includes(' --')) { + rendered.push(); + } else { + rendered.push({code}); + } + } else if (part.trim()) { + rendered.push(); + } + } + return <>{rendered}; +} + +function TextWithInlineCode({ text }: { text: string }) { + const jsonSplit = text.split(/(\{[^{}]*"(?:Version|Effect|Statement)"[^{}]*(?:\{[^{}]*\}[^{}]*)*\})/g); + const elements: React.ReactNode[] = []; + for (let i = 0; i < jsonSplit.length; i++) { + const segment = jsonSplit[i] ?? ''; + if (segment.startsWith('{') && (segment.includes('"Version"') || segment.includes('"Effect"'))) { + try { + elements.push(); + } catch { elements.push(); } + } else if (segment.trim()) { + elements.push(); + } + } + return <>{elements}; +} + +function StepContent({ text }: { text: string }) { + const tripleBacktickParts = text.split(/(```[\s\S]*?```)/g); + if (tripleBacktickParts.length > 1) { + return ( + <> + {tripleBacktickParts.map((part, i) => { + if (part.startsWith('```') && part.endsWith('```')) { + return ; + } + const trimmed = part.trim(); + if (!trimmed) return null; + return ; + })} + + ); + } + return ; +} + +function StateBlock({ label, state }: { label: string; state: Record }) { + return ( +
+

{label}

+
{JSON.stringify(state, null, 2)}
+
+ ); +} + +/** Animated loading steps that show progress during analysis. */ +function LoadingSteps({ providerSlug }: { providerSlug?: string }) { + const [step, setStep] = useState(0); + useEffect(() => { + const timers = [ + setTimeout(() => setStep(1), 1500), + setTimeout(() => setStep(2), 4000), + setTimeout(() => setStep(3), 7000), + ]; + return () => timers.forEach(clearTimeout); + }, []); + + const providerName = providerSlug === 'gcp' ? 'GCP' : providerSlug === 'azure' ? 'Azure' : 'AWS'; + const steps = [ + { label: 'Analyzing finding', sub: 'Reviewing security configuration' }, + { label: `Reading ${providerName} configuration`, sub: 'Fetching current resource state' }, + { label: 'Checking required permissions', sub: 'Verifying access' }, + { label: 'Preparing fix plan', sub: 'Generating remediation steps' }, + ]; + + const progress = ((step + 1) / steps.length) * 100; + + return ( +
+ {/* Progress bar */} +
+
+
+ +
+ {steps.map(({ label, sub }, i) => { + const done = i < step; + const active = i === step; + const pending = i > step; + + return ( +
+
+ {done ? ( + + + + ) : active ? ( + + ) : ( +
+ )} +
+
+

+ {label} +

+ {active && ( +

{sub}

+ )} +
+
+ ); + })} +
+
+ ); +} + +// ─── Main component ───────────────────────────────────────────────────── + +export function RemediationDialog({ + open, + onOpenChange, + connectionId, + checkResultId, + remediationKey, + findingTitle, + providerSlug, + guidedOnly, + guidedSteps, + risk, + description, + onComplete, +}: RemediationDialogProps) { + const api = useApi(); + const [preview, setPreview] = useState(null); + const [isLoadingPreview, setIsLoadingPreview] = useState(false); + const [isExecuting, setIsExecuting] = useState(false); + const [isWaitingPropagation, setIsWaitingPropagation] = useState(false); + const [succeeded, setSucceeded] = useState(false); + const [error, setError] = useState(null); + const [permissionError, setPermissionError] = useState<{ missingActions: string[]; fixScript?: string } | null>(null); + const [acknowledgment, setAcknowledgment] = useState(null); + + // Ref to store permissions across rechecks (avoids stale closure in useCallback) + const permissionsRef = useRef(undefined); + + const loadPreview = useCallback(async (recheck = false) => { + setIsLoadingPreview(true); + setError(null); + try { + const response = await api.post( + '/v1/cloud-security/remediation/preview', + { + connectionId, + checkResultId, + remediationKey, + // On recheck, send the cached permissions so backend doesn't re-run AI + ...(recheck && permissionsRef.current && { + cachedPermissions: permissionsRef.current, + }), + }, + ); + if (response.error) { + setError( + typeof response.error === 'string' + ? response.error + : 'Failed to load preview', + ); + return; + } + const previewData = response.data as PreviewData; + setPreview(previewData); + // Store permissions in ref so Recheck can access them without stale closure + if (previewData.allRequiredPermissions) { + permissionsRef.current = previewData.allRequiredPermissions; + } + } catch { + setError('Failed to load preview'); + } finally { + setIsLoadingPreview(false); + } + }, [api, connectionId, checkResultId, remediationKey]); + + useEffect(() => { + if (!open) return; + setError(null); + setPermissionError(null); + setAcknowledgment(null); + + // Guided-only: skip API call, use local data + if (guidedOnly && guidedSteps) { + setPreview({ + currentState: {}, + proposedState: {}, + description: description ?? '', + risk: risk ?? 'medium', + apiCalls: [], + guidedOnly: true, + guidedSteps, + rollbackSupported: false, + }); + return; + } + + setPreview(null); + loadPreview(); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [open, remediationKey]); + + const handleExecute = async () => { + setIsExecuting(true); + setError(null); + setPermissionError(null); + try { + const response = await api.post<{ + status: string; + error?: string; + permissionError?: { missingActions: string[]; fixScript?: string }; + }>( + '/v1/cloud-security/remediation/execute', + { connectionId, checkResultId, remediationKey, acknowledgment }, + ); + if (response.error) { + const msg = + typeof response.error === 'string' + ? response.error + : 'Remediation failed'; + setError(msg); + return; + } + + const data = response.data; + if (data?.status === 'success') { + setPreview(null); + setError(null); + setSucceeded(true); + toast.success('Fix applied successfully'); + // Trigger re-scan, then close dialog after user sees confirmation + onComplete?.(); + setTimeout(() => { + onOpenChange(false); + setSucceeded(false); + }, 4000); + } else { + const msg = data?.error || 'Remediation failed'; + setError(msg); + if (data?.permissionError) { + setPermissionError(data.permissionError); + } + } + } catch { + setError('Remediation failed. Please try again.'); + } finally { + setIsExecuting(false); + } + }; + + const handleRetry = async () => { + setIsWaitingPropagation(true); + // IAM permission changes take up to 10s to propagate in AWS + await new Promise((r) => setTimeout(r, 10_000)); + setIsWaitingPropagation(false); + await handleExecute(); + }; + + const isGuided = preview?.guidedOnly; + + return ( + + +
+ + + {isGuided ? 'Remediation Steps' : 'Auto-Remediate Finding'} + + + {findingTitle} + + + +
+ {/* Applying state — shown while executing */} + {isExecuting && !succeeded && !error && ( +
+
+ +
+
+

Applying fix...

+

+ Executing changes to your cloud infrastructure. This may take a moment. +

+
+
+ )} + + {/* Success state */} + {succeeded && ( +
+
+ + + +
+
+

Fix applied successfully

+

+ Re-scanning to verify the changes... +

+
+
+ )} + + {isLoadingPreview && !succeeded && ( + preview ? ( + /* Recheck — just verifying permissions */ +
+ +

Verifying permissions

+
+ ) : ( + /* First load — full analysis */ + + ) + )} + + {error && !succeeded && ( + + )} + + {preview && !isLoadingPreview && ( + <> + {/* Guided-only: show steps directly */} + {isGuided && preview.guidedSteps && ( +
+ {/* Description + risk row */} +
+ {preview.description && ( +

+ {preview.description} +

+ )} + + {preview.risk} + +
+ + {/* Steps card */} +
+
+
+ +
+ + Follow these steps in the {providerSlug === 'azure' ? 'Azure Portal' : providerSlug === 'gcp' ? 'GCP Console' : 'AWS Console'} + +
+
+
    + {preview.guidedSteps.map((step, i) => ( +
  1. + + {i + 1} + +
    + +
    +
  2. + ))} +
+
+
+ + {/* Footer */} +
+

+ {preview.guidedSteps.length} steps to complete +

+ +
+
+ )} + + {/* Auto-fix: show preview + execute */} + {!isGuided && ( + <> +

+ {preview.description} +

+ +
+
+ Risk: + + {preview.risk} + +
+ {preview.rollbackSupported !== false && ( +
+ + Rollback available +
+ )} + {preview.rollbackSupported === false && ( +
+ + Irreversible +
+ )} +
+ + {/* Current vs Proposed */} +
+ + +
+ + {/* API calls — collapsible if many */} + {preview.apiCalls.length > 0 && ( +
+ + {preview.apiCalls.length} API calls + +
+ {preview.apiCalls.map((call, i) => { + const label = typeof call === 'string' + ? call + : `${(call as { method?: string }).method} ${(call as { endpoint?: string }).endpoint}`; + return ( + {label} + ); + })} +
+
+ )} + + {/* Missing permissions — show setup step BEFORE apply */} + {preview.missingPermissions && preview.missingPermissions.length > 0 ? ( +
+
+

+ {preview.missingPermissions.length} permissions needed +

+

Run in CloudShell, then Recheck

+
+
+ {preview.permissionFixScript && ( +
+                            {preview.permissionFixScript}
+                          
+ )} +
+ + + CloudShell + + +
+
+
+ ) : ( + /* Permissions OK — show acknowledgment */ + + )} + +
+ + +
+ + )} + + )} +
+
+
+
+ ); +} + diff --git a/apps/app/src/app/(app)/[orgId]/cloud-tests/components/RemediationHistorySection.tsx b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/RemediationHistorySection.tsx new file mode 100644 index 0000000000..affc5809fd --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/RemediationHistorySection.tsx @@ -0,0 +1,311 @@ +'use client'; + +import { useApi } from '@/hooks/use-api'; +import { Button, Section, Stack, Text } from '@trycompai/design-system'; +import { RecentlyViewed, Undo } from '@trycompai/design-system/icons'; +import { + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, +} from '@trycompai/ui/dialog'; +import { formatDistanceToNow } from 'date-fns'; +import { Copy, ExternalLink, Loader2 } from 'lucide-react'; +import { useState } from 'react'; +import { toast } from 'sonner'; + +interface RemediationActionItem { + id: string; + remediationKey: string; + resourceId: string; + resourceType: string; + status: string; + riskLevel: string | null; + errorMessage: string | null; + initiatedById: string; + initiatedByName: string | null; + executedAt: string | null; + rolledBackAt: string | null; + createdAt: string; +} + +const STATUS_BADGE: Record = { + success: { variant: 'default', label: 'Success' }, + failed: { variant: 'destructive', label: 'Failed' }, + rolled_back: { variant: 'outline', label: 'Rolled Back' }, + rollback_failed: { variant: 'destructive', label: 'Rollback Failed' }, + executing: { variant: 'secondary', label: 'Executing' }, +}; + +function formatRemediationKey(key: string): string { + // Extract a clean, short description from the finding key + // e.g., "cloudwatch-cloudwatch-no-cloudtrail-integration-cloudtrail-not-integrated-with-cloudwatch-logs" + // → "CloudTrail not integrated with CloudWatch Logs" + const parts = key.split('-'); + // Skip the service prefix (first 1-2 parts that repeat) + const seen = new Set(); + const meaningful = parts.filter((p) => { + const lower = p.toLowerCase(); + if (seen.has(lower) || lower.length <= 2) return false; + seen.add(lower); + return true; + }); + if (meaningful.length === 0) return key; + return meaningful + .map((w) => w.charAt(0).toUpperCase() + w.slice(1)) + .join(' '); +} + +const getInitials = (name: string | null) => + name + ? name.split(' ').map((p) => p[0]).join('').toUpperCase().slice(0, 2) + : 'S'; + +function RemediationRow({ + action, + onRollback, +}: { + action: RemediationActionItem; + onRollback: (action: RemediationActionItem) => void; +}) { + const displayName = action.initiatedByName ?? 'System'; + const timeAgo = formatDistanceToNow(new Date(action.executedAt ?? action.createdAt), { + addSuffix: true, + }); + const badge = STATUS_BADGE[action.status] ?? STATUS_BADGE.executing; + const canRollback = action.status === 'success'; + const hasError = + action.errorMessage && + (action.status === 'failed' || action.status === 'rollback_failed'); + + return ( +
+ +
+

+ {displayName} + {' '} + applied {formatRemediationKey(action.remediationKey)} +

+

{timeAgo}

+
+ {canRollback && ( + + )} +
+ ); +} + +function RollbackConfirmDialog({ + action, + open, + onOpenChange, + onConfirm, + isLoading, + permError, + providerSlug, +}: { + action: RemediationActionItem | null; + open: boolean; + onOpenChange: (open: boolean) => void; + onConfirm: () => void; + isLoading: boolean; + permError?: { missingActions: string[]; script: string } | null; + providerSlug?: string; +}) { + if (!action) return null; + + const friendlyKey = formatRemediationKey(action.remediationKey); + const appliedAt = action.executedAt + ? formatDistanceToNow(new Date(action.executedAt), { addSuffix: true }) + : null; + + return ( + + + + Rollback + + This will undo the fix and revert your {providerSlug === 'azure' ? 'Azure' : providerSlug === 'gcp' ? 'GCP' : 'AWS'} infrastructure to its previous state. + + + +
+

{friendlyKey}

+ {appliedAt && ( +

Applied {appliedAt}

+ )} +
+ + {/* Permission error for rollback */} + {permError && ( +
+

Missing permissions for rollback

+
+ {permError.missingActions.map((a) => ( + {a} + ))} +
+
+ + + + {providerSlug === 'azure' ? 'Azure Shell' : providerSlug === 'gcp' ? 'Cloud Shell' : 'CloudShell'} + +
+
+ )} + +
+ + +
+
+
+ ); +} + +export function RemediationHistorySection({ connectionId, providerSlug }: { connectionId: string; providerSlug?: string }) { + const api = useApi(); + const [rollbackTarget, setRollbackTarget] = useState(null); + const [isRollingBack, setIsRollingBack] = useState(false); + const [rollbackPermError, setRollbackPermError] = useState<{ + missingActions: string[]; + script: string; + } | null>(null); + + const { data, isLoading, mutate } = api.useSWR<{ data: RemediationActionItem[]; count: number }>( + connectionId ? `/v1/cloud-security/remediation/actions?connectionId=${connectionId}` : null, + { revalidateOnFocus: false }, + ); + + const allActions = Array.isArray(data?.data?.data) ? data.data.data : (Array.isArray(data?.data) ? data.data : []); + const actions = allActions.filter((a) => a.status !== 'failed' && a.status !== 'executing'); + + const handleRollback = async () => { + if (!rollbackTarget) return; + setIsRollingBack(true); + setRollbackPermError(null); + try { + const response = await api.post<{ + status?: string; + message?: string; + missingActions?: string[]; + script?: string; + }>( + `/v1/cloud-security/remediation/${rollbackTarget.id}/rollback`, + {}, + ); + if (response.error) { + // Check for structured permission error + const errData = response.data; + if (errData?.missingActions && errData.script) { + setRollbackPermError({ + missingActions: errData.missingActions, + script: errData.script, + }); + return; + } + toast.error(typeof response.error === 'string' ? response.error : 'Rollback failed'); + return; + } + toast.success('Remediation rolled back successfully'); + setRollbackTarget(null); + await mutate(); + } catch { + toast.error('Rollback failed'); + } finally { + setIsRollingBack(false); + } + }; + + if (isLoading) { + return ( +
+ +
+ ); + } + + if (actions.length === 0) { + return ( +
+
+ + + + No remediations have been performed yet + + +
+
+ ); + } + + return ( + <> +
+
+ {actions.map((action) => ( + + ))} +
+
+ + { + if (!open) { + setRollbackTarget(null); + setRollbackPermError(null); + } + }} + onConfirm={handleRollback} + isLoading={isRollingBack} + permError={rollbackPermError} + /> + + ); +} diff --git a/apps/app/src/app/(app)/[orgId]/cloud-tests/components/ScheduledScanPopover.tsx b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/ScheduledScanPopover.tsx new file mode 100644 index 0000000000..9ae0fe14c3 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/ScheduledScanPopover.tsx @@ -0,0 +1,160 @@ +'use client'; + +import { useApi } from '@/hooks/use-api'; +import { useConnectionServices } from '@/hooks/use-integration-platform'; +import { Popover, PopoverContent, PopoverTrigger } from '@trycompai/ui/popover'; +import { Checkbox } from '@trycompai/ui/checkbox'; +import { Button, cn } from '@trycompai/design-system'; +import { EventSchedule } from '@trycompai/design-system/icons'; +import { Search } from 'lucide-react'; +import { useCallback, useMemo, useState } from 'react'; +import { toast } from 'sonner'; + +interface ScheduledScanPopoverProps { + connectionId: string; +} + +export function ScheduledScanPopover({ connectionId }: ScheduledScanPopoverProps) { + const apiClient = useApi(); + const { services, meta, refresh: refreshServices } = useConnectionServices(connectionId); + const [search, setSearch] = useState(''); + const [saving, setSaving] = useState(null); + const waitingForDetection = meta.providerSlug === 'gcp' && meta.detectionReady === false; + + const filteredServices = useMemo(() => { + if (!search) return services.filter((s) => s.implemented !== false); + const q = search.toLowerCase(); + return services + .filter((s) => s.implemented !== false) + .filter( + (s) => s.name.toLowerCase().includes(q) || s.id.toLowerCase().includes(q), + ); + }, [services, search]); + + const implementedServices = services.filter((s) => s.implemented !== false); + const enabledCount = implementedServices.filter((s) => s.enabled).length; + + const handleToggle = useCallback(async (serviceId: string, enabled: boolean) => { + setSaving(serviceId); + try { + const newEnabledIds = services + .filter((s) => (s.id === serviceId ? enabled : s.enabled)) + .map((s) => s.id); + await apiClient.put( + `/v1/integrations/connections/${connectionId}/services`, + { services: newEnabledIds }, + ); + await refreshServices(); + } finally { + setSaving(null); + } + }, [services, connectionId, apiClient, refreshServices]); + + const handleEnableAll = useCallback(async () => { + setSaving('all'); + try { + const allIds = implementedServices.map((s) => s.id); + await apiClient.put( + `/v1/integrations/connections/${connectionId}/services`, + { services: allIds }, + ); + await refreshServices(); + toast.success('All services enabled'); + } finally { + setSaving(null); + } + }, [implementedServices, connectionId, apiClient, refreshServices]); + + return ( + + + + + + {/* Schedule header */} +
+
+

Daily scan

+

+ {waitingForDetection + ? 'Detecting active GCP services...' + : 'Every day at 5:00 AM UTC'} +

+
+ + Active + +
+ + {/* Service toggles */} +
+
+

+ {waitingForDetection + ? 'Waiting for detection' + : `${enabledCount} of ${implementedServices.length} services`} +

+ +
+
+ + {/* Search (only if many services) */} + {implementedServices.length > 8 && !waitingForDetection && ( +
+
+ + setSearch(e.target.value)} + className="w-full rounded-md border bg-background pl-7 pr-3 py-1.5 text-xs placeholder:text-muted-foreground/50 focus:outline-none focus:ring-1 focus:ring-primary/30" + /> +
+
+ )} + + {/* Service list */} +
+ {waitingForDetection ? ( +

+ Service toggles will appear once detection completes. +

+ ) : ( + filteredServices.map((service) => ( + + )) + )} + {!waitingForDetection && filteredServices.length === 0 && search && ( +

+ No services match "{search}" +

+ )} +
+
+
+ ); +} diff --git a/apps/app/src/app/(app)/[orgId]/cloud-tests/components/ServiceCard.tsx b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/ServiceCard.tsx new file mode 100644 index 0000000000..b243a7f118 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/ServiceCard.tsx @@ -0,0 +1,181 @@ +'use client'; + +import { useConnectionServices } from '@/hooks/use-integration-platform'; +import { Badge } from '@trycompai/ui/badge'; +import { + Cloud, + Database, + Globe, + HardDrive, + Key, + Lock, + MonitorCheck, + Network, + ScanSearch, + Server, + Shield, + Terminal, + Workflow, +} from 'lucide-react'; + +const SERVICE_ICONS: Record = { + 'security-hub': Shield, + 'iam-analyzer': Key, + 'cloudtrail': ScanSearch, + 's3': HardDrive, + 'ec2-vpc': Server, + 'rds': Database, + 'kms': Lock, + 'cloudwatch': MonitorCheck, + 'config': MonitorCheck, + 'guardduty': Shield, + 'secrets-manager': Key, + 'waf': Shield, + 'elb': Network, + 'acm': Lock, + 'backup': HardDrive, + 'inspector': ScanSearch, + 'ecs-eks': Server, + 'lambda': Terminal, + 'dynamodb': Database, + 'sns-sqs': Workflow, + 'ecr': Server, + 'opensearch': Database, + 'redshift': Database, + 'macie': ScanSearch, + 'route53': Globe, + 'api-gateway': Network, + 'cloudfront': Globe, + 'cognito': Key, + 'elasticache': Database, + 'efs': HardDrive, + 'msk': Workflow, + 'sagemaker': Cloud, + 'systems-manager': Terminal, + 'codebuild': Terminal, + 'network-firewall': Shield, + 'shield': Shield, + 'kinesis': Workflow, + 'glue': Workflow, + 'athena': Database, + 'emr': Cloud, + 'step-functions': Workflow, + 'eventbridge': Workflow, + 'transfer-family': Network, + 'elastic-beanstalk': Cloud, + 'appflow': Workflow, + // GCP services + 'cloud-storage': HardDrive, + 'compute-engine': Server, + 'vpc-network': Network, + 'iam': Key, + 'cloud-sql': Database, + 'gke': Server, + 'cloud-kms': Lock, + 'cloud-logging': ScanSearch, + 'cloud-monitoring': MonitorCheck, + 'cloud-dns': Globe, + 'bigquery': Database, + 'pubsub': Workflow, + 'cloud-armor': Shield, + 'security-command-center': Shield, +}; + +interface ServiceMeta { + id: string; + name: string; + description: string; + enabledByDefault?: boolean; + implemented?: boolean; +} + +function ServiceIcon({ serviceId }: { serviceId: string }) { + const Icon = SERVICE_ICONS[serviceId] as React.ComponentType<{ className?: string }> | undefined; + if (!Icon) return null; + return ( +
+ +
+ ); +} + +interface ServiceCardProps { + service: ServiceMeta; + connectionId: string | null; + isConnected: boolean; + onToggle?: (id: string, enabled: boolean) => void | Promise; + toggling?: boolean; +} + +export function ServiceCard({ + service, + connectionId, + isConnected, + onToggle, + toggling, +}: ServiceCardProps) { + const { services } = useConnectionServices(connectionId); + + const isImplemented = service.implemented !== false; + const liveService = services.find((s) => s.id === service.id); + const isEnabled = liveService?.enabled ?? false; + const showToggle = isImplemented && isConnected && onToggle; + + return ( +
+
+ +
+
+ {service.name} + {!isImplemented && ( + + Coming Soon + + )} +
+

+ {service.description} +

+ {liveService?.projects && liveService.projects.length > 0 && ( +
+ {liveService.projects.map((pid) => ( + + {pid} + + ))} +
+ )} +
+ {showToggle && ( + + )} +
+
+ ); +} diff --git a/apps/app/src/app/(app)/[orgId]/cloud-tests/components/ServicesGrid.tsx b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/ServicesGrid.tsx new file mode 100644 index 0000000000..db3e86f87e --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/ServicesGrid.tsx @@ -0,0 +1,91 @@ +'use client'; + +import { orderServicesForConnectionGrid } from '@/lib/connection-services-display-order'; +import { Search } from '@trycompai/design-system/icons'; +import { useCallback, useEffect, useMemo, useState } from 'react'; +import { ServiceCard } from './ServiceCard'; + +export function ServicesGrid({ + services, + connectionServices = [], + connectionId, + onToggle, + togglingService, +}: { + services: Array<{ id: string; name: string; description: string; implemented?: boolean }>; + connectionServices?: Array<{ id: string; enabled: boolean }>; + connectionId: string | null; + onToggle: (id: string, enabled: boolean) => boolean | void | Promise; + togglingService: string | null; +}) { + const [search, setSearch] = useState(''); + const [tailEnabledIds, setTailEnabledIds] = useState([]); + + useEffect(() => { + setTailEnabledIds([]); + }, [connectionId]); + + const handleToggle = useCallback( + async (id: string, enabled: boolean) => { + let rollback: string[] | null = null; + setTailEnabledIds((prev) => { + rollback = [...prev]; + if (enabled) return [...prev.filter((x) => x !== id), id]; + return prev.filter((x) => x !== id); + }); + const result = await Promise.resolve(onToggle(id, enabled)); + if (result === false && rollback) { + setTailEnabledIds(rollback); + } + }, + [onToggle], + ); + + const displayedServices = useMemo( + () => + orderServicesForConnectionGrid({ + manifestServices: services, + connectionServices, + search, + tailEnabledIds, + }), + [services, connectionServices, search, tailEnabledIds], + ); + + return ( +
+
+
+ + setSearch(e.target.value)} + className="w-44 rounded-md border bg-background py-1.5 pl-7 pr-3 text-xs placeholder:text-muted-foreground/50 focus:outline-none focus:ring-2 focus:ring-primary/30" + /> +
+
+
+ {displayedServices.map((service) => ( + + ))} + {displayedServices.length === 0 && search && ( +

+ No services matching "{search}" +

+ )} +
+
+ ); +} diff --git a/apps/app/src/app/(app)/[orgId]/cloud-tests/components/TestsLayout.tsx b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/TestsLayout.tsx index 1f5ee9460c..228e505441 100644 --- a/apps/app/src/app/(app)/[orgId]/cloud-tests/components/TestsLayout.tsx +++ b/apps/app/src/app/(app)/[orgId]/cloud-tests/components/TestsLayout.tsx @@ -4,9 +4,11 @@ import { ConnectIntegrationDialog } from '@/components/integrations/ConnectInteg import { useApi } from '@/hooks/use-api'; import { usePermissions } from '@/hooks/use-permissions'; import { ManageIntegrationDialog } from '@/components/integrations/ManageIntegrationDialog'; +import { CLOUD_RECONNECT_CUTOFF_LABEL, requiresCloudReconnect } from '@/lib/cloud-reconnect-policy'; import { Button, PageHeader, PageHeaderDescription, PageLayout } from '@trycompai/design-system'; import { Add, Settings } from '@trycompai/design-system/icons'; -import { useMemo, useState } from 'react'; +import { useSearchParams, useRouter } from 'next/navigation'; +import { useCallback, useMemo, useState } from 'react'; import { toast } from 'sonner'; import { isCloudProviderSlug } from '../constants'; import type { Finding, Provider } from '../types'; @@ -52,7 +54,21 @@ export function TestsLayout({ initialFindings, initialProviders, orgId }: TestsL const [showSettings, setShowSettings] = useState(false); const [viewingResults, setViewingResults] = useState(true); const [isScanning, setIsScanning] = useState(false); - const [activeProviderTab, setActiveProviderTab] = useState(null); + const searchParams = useSearchParams(); + const router = useRouter(); + const [activeProviderTab, setActiveProviderTabState] = useState( + searchParams.get('provider'), + ); + const setActiveProviderTab = useCallback((tab: string | null) => { + setActiveProviderTabState(tab); + const params = new URLSearchParams(searchParams.toString()); + if (tab) { + params.set('provider', tab); + } else { + params.delete('provider'); + } + router.replace(`?${params.toString()}`, { scroll: false }); + }, [searchParams, router]); const [activeConnectionTabs, setActiveConnectionTabs] = useState>({}); const [addConnectionProvider, setAddConnectionProvider] = useState(null); const [configureDialogOpen, setConfigureDialogOpen] = useState(false); @@ -86,6 +102,19 @@ export function TestsLayout({ initialFindings, initialProviders, orgId }: TestsL const isProvidersValidating = providersResponse.isValidating; const connectedProviders = providers; + const reconnectRequiredCount = useMemo( + () => + connectedProviders.filter((provider) => + requiresCloudReconnect({ + providerId: provider.integrationId, + createdAt: provider.createdAt, + reconnectedAt: provider.reconnectedAt, + isLegacy: provider.isLegacy, + status: provider.status, + }), + ).length, + [connectedProviders], + ); // Group connections by provider type (aws, gcp, azure) const providerGroups = useMemo(() => { @@ -254,6 +283,17 @@ export function TestsLayout({ initialFindings, initialProviders, orgId }: TestsL {multiProviderDescription} + {reconnectRequiredCount > 0 && ( +
+

+ Reconnect required for {reconnectRequiredCount} cloud connection{reconnectRequiredCount === 1 ? '' : 's'} +

+

+ Connections created before {CLOUD_RECONNECT_CUTOFF_LABEL} should be re-added to keep scans and remediation fully reliable. +

+
+ )} + + requiresCloudReconnect({ + providerId: provider.integrationId, + createdAt: provider.createdAt, + reconnectedAt: provider.reconnectedAt, + isLegacy: provider.isLegacy, + status: provider.status, + }) + } canRunScan={canRunScan} canAddConnection={canCreateIntegration} + orgId={orgId} /> {/* CloudSettingsModal for single-connection providers AND legacy connections */} @@ -298,7 +348,6 @@ export function TestsLayout({ initialFindings, initialProviders, orgId }: TestsL open={showSettings} onOpenChange={setShowSettings} connectedProviders={connectedProviders - .filter((p) => !p.supportsMultipleConnections || p.isLegacy) .map((p) => ({ id: p.integrationId, connectionId: p.id, diff --git a/apps/app/src/app/(app)/[orgId]/cloud-tests/types.ts b/apps/app/src/app/(app)/[orgId]/cloud-tests/types.ts index 5538a930fa..b78b72879b 100644 --- a/apps/app/src/app/(app)/[orgId]/cloud-tests/types.ts +++ b/apps/app/src/app/(app)/[orgId]/cloud-tests/types.ts @@ -5,6 +5,10 @@ export interface Finding { remediation: string | null; status: string | null; severity: string | null; + serviceId: string | null; + findingKey: string | null; + resourceId: string | null; + projectDisplayName: string | null; completedAt: Date | null; connectionId: string; providerSlug: string; @@ -23,6 +27,7 @@ export interface Provider { status: string; createdAt: Date; updatedAt: Date; + reconnectedAt?: Date | string | null; isLegacy?: boolean; variables?: Record | null; requiredVariables?: string[]; diff --git a/apps/app/src/app/(app)/[orgId]/documents/[formType]/loading.tsx b/apps/app/src/app/(app)/[orgId]/documents/[formType]/loading.tsx new file mode 100644 index 0000000000..834c94b27f --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/documents/[formType]/loading.tsx @@ -0,0 +1,5 @@ +import { PageLayout } from '@trycompai/design-system'; + +export default function Loading() { + return ; +} diff --git a/apps/app/src/app/(app)/[orgId]/documents/[formType]/page.tsx b/apps/app/src/app/(app)/[orgId]/documents/[formType]/page.tsx index 4fd1326fb9..43720357bb 100644 --- a/apps/app/src/app/(app)/[orgId]/documents/[formType]/page.tsx +++ b/apps/app/src/app/(app)/[orgId]/documents/[formType]/page.tsx @@ -2,7 +2,6 @@ import { CompanyFormPageClient } from '@/app/(app)/[orgId]/documents/components/ import { Breadcrumb, PageLayout } from '@trycompai/design-system'; import Link from 'next/link'; import { notFound } from 'next/navigation'; -import { Suspense } from 'react'; import { evidenceFormDefinitions, evidenceFormTypeSchema } from '../forms'; import { auth } from '@/utils/auth'; import { headers } from 'next/headers'; @@ -43,13 +42,11 @@ export default async function CompanyFormDetailPage({ { label: formDefinition.title, isCurrent: true }, ]} /> - - - + ); } diff --git a/apps/app/src/app/(app)/[orgId]/documents/components/CompanyFormPageClient.tsx b/apps/app/src/app/(app)/[orgId]/documents/components/CompanyFormPageClient.tsx index 4b577b9246..86b452af17 100644 --- a/apps/app/src/app/(app)/[orgId]/documents/components/CompanyFormPageClient.tsx +++ b/apps/app/src/app/(app)/[orgId]/documents/components/CompanyFormPageClient.tsx @@ -426,19 +426,7 @@ export function CompanyFormPageClient({
- {isLoading ? ( - - - - - - No submissions yet - - Start by creating a new submission, click the New Submission button above. - - - - ) : !data || data.submissions.length === 0 ? ( + {!data || data.submissions.length === 0 ? ( diff --git a/apps/app/src/app/(app)/[orgId]/documents/components/CompanySubmissionWizard.tsx b/apps/app/src/app/(app)/[orgId]/documents/components/CompanySubmissionWizard.tsx index 4cb560d043..3c6f32b224 100644 --- a/apps/app/src/app/(app)/[orgId]/documents/components/CompanySubmissionWizard.tsx +++ b/apps/app/src/app/(app)/[orgId]/documents/components/CompanySubmissionWizard.tsx @@ -16,6 +16,7 @@ import { api } from '@/lib/api-client'; import { meetingFields } from '@trycompai/company'; import { zodResolver } from '@hookform/resolvers/zod'; import { + Alert, Button, Field, FieldError, @@ -548,6 +549,13 @@ export function CompanySubmissionWizard({ {textareaFields.length === 0 && ( No additional fields required for this step. )} + {textareaFields.some((f) => f.placeholder) && ( + + )} {textareaFields.map((field) => ( ))} + {extendedFields.some((f) => f.type === 'textarea' && f.placeholder) && ( + + )} {extendedFields.map((field) => ( -
-
- ); - } - if (allError) { return ( -
- - Failed to load findings -
+ + + + + + Failed to load findings + + Something went wrong. Please try refreshing the page. + + + ); } - if (sortedFindings.length === 0) { + if (allIsLoading || sortedFindings.length === 0) { return ( -
-

No findings for this document

+ + + + + + No findings yet + + Findings will appear here when an auditor flags issues requiring attention. + + {canCreateFinding && ( -
- -
+ )} -
+
); } @@ -216,14 +233,7 @@ export function DocumentFindingsSection({
- {sortedFindings.length === 0 ? ( -
- -

No findings for this document

- {canCreateFinding &&

Create a finding to flag an issue

} -
- ) : ( -
+
{visibleFindings.map((finding: Finding) => ( INITIAL_DISPLAY_COUNT && (
-
)}
- )}
); diff --git a/apps/app/src/app/(app)/[orgId]/integrations/[slug]/actions/batch-fix.ts b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/actions/batch-fix.ts new file mode 100644 index 0000000000..01a93c90fb --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/actions/batch-fix.ts @@ -0,0 +1,164 @@ +'use server'; + +import { auth, runs, tasks } from '@trigger.dev/sdk'; +import { serverApi } from '@/lib/api-server'; + +interface BatchFixInput { + organizationId: string; + connectionId: string; + findings: Array<{ id: string; key: string; title: string }>; +} + +export async function startBatchFix( + input: BatchFixInput, +): Promise<{ data?: { batchId: string; runId: string; accessToken: string }; error?: string }> { + try { + // Step 1: Create batch record in DB via API + const api = serverApi; + const batchResp = await api.post<{ data: { id: string } }>('/v1/cloud-security/remediation/batch', { + connectionId: input.connectionId, + findings: input.findings, + }); + + if (batchResp.error || !batchResp.data?.data?.id) { + return { error: 'Failed to create batch record' }; + } + + const batchId = batchResp.data.data.id; + + // Step 2: Trigger the API-layer task + const handle = await tasks.trigger('remediate-batch', { + batchId, + organizationId: input.organizationId, + connectionId: input.connectionId, + }); + + // Step 3: Store triggerRunId on the batch + await api.patch(`/v1/cloud-security/remediation/batch/${batchId}`, { + triggerRunId: handle.id, + status: 'running', + }); + + // Step 4: Create public access token for real-time progress + const accessToken = await auth.createPublicToken({ + scopes: { read: { runs: [handle.id] } }, + }); + + return { data: { batchId, runId: handle.id, accessToken } }; + } catch (err) { + console.error('Failed to start batch fix:', err); + return { error: err instanceof Error ? err.message : 'Failed to start batch fix' }; + } +} + +export async function cancelBatchFix(runId: string, batchId: string): Promise { + try { + // Mark batch as cancelled in DB — task will check this before next finding + const api = serverApi; + await api.patch(`/v1/cloud-security/remediation/batch/${batchId}`, { + status: 'cancelled', + }); + // Also cancel the trigger run + await runs.cancel(runId); + } catch { + // Run may have already completed + } +} + +/** Check for an active batch on page load — returns batch + access token if found. */ +export async function getActiveBatch( + connectionId: string, +): Promise<{ + batchId: string; + triggerRunId: string; + accessToken: string; + findings: Array<{ id: string; title: string; status: string; error?: string }>; +} | null> { + try { + const resp = await serverApi.get( + `/v1/cloud-security/remediation/batch/active?connectionId=${connectionId}`, + ); + const batch = (resp.data as { data?: { id: string; triggerRunId?: string; findings: unknown[] } })?.data; + if (!batch?.triggerRunId) return null; + + // Verify the trigger run is actually still active + try { + const run = await runs.retrieve(batch.triggerRunId); + if (run.status === 'COMPLETED' || run.status === 'FAILED' || run.status === 'CANCELED' || run.status === 'SYSTEM_FAILURE') { + // Run is done — mark batch as done in DB so it doesn't show up again + await serverApi.patch(`/v1/cloud-security/remediation/batch/${batch.id}`, { + status: 'done', + }); + return null; + } + } catch { + // Can't verify run — mark batch as done to be safe + await serverApi.patch(`/v1/cloud-security/remediation/batch/${batch.id}`, { + status: 'done', + }); + return null; + } + + const accessToken = await auth.createPublicToken({ + scopes: { read: { runs: [batch.triggerRunId] } }, + }); + + return { + batchId: batch.id, + triggerRunId: batch.triggerRunId, + accessToken, + findings: batch.findings as Array<{ id: string; title: string; status: string; error?: string }>, + }; + } catch { + return null; + } +} + +export async function skipBatchFinding(batchId: string, findingId: string): Promise { + try { + await serverApi.post(`/v1/cloud-security/remediation/batch/${batchId}/skip/${findingId}`, {}); + } catch { + // Best effort + } +} + +/** Retry a single finding immediately (user added permissions and wants instant retry). */ +export async function retryFinding( + connectionId: string, + checkResultId: string, + remediationKey: string, +): Promise<{ status: 'fixed' | 'failed' | 'needs_permissions'; error?: string; missingPermissions?: string[] }> { + try { + // Preview first + const preview = await serverApi.post<{ + guidedOnly?: boolean; + missingPermissions?: string[]; + }>('/v1/cloud-security/remediation/preview', { + connectionId, + checkResultId, + remediationKey, + }); + + if (preview.error) return { status: 'failed', error: String(preview.error) }; + + const data = preview.data as { guidedOnly?: boolean; missingPermissions?: string[] } | undefined; + if (data?.missingPermissions && data.missingPermissions.length > 0) { + return { status: 'needs_permissions', missingPermissions: data.missingPermissions }; + } + + // Execute + const execute = await serverApi.post<{ status: string; error?: string }>( + '/v1/cloud-security/remediation/execute', + { connectionId, checkResultId, remediationKey, acknowledgment: 'acknowledged' }, + ); + + const execData = execute.data as { status?: string; error?: string } | undefined; + if (execute.error || execData?.status === 'failed') { + return { status: 'failed', error: String(execute.error ?? execData?.error ?? 'Failed') }; + } + + return { status: 'fixed' }; + } catch (err) { + return { status: 'failed', error: err instanceof Error ? err.message : 'Failed' }; + } +} diff --git a/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/AccountSelector.tsx b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/AccountSelector.tsx new file mode 100644 index 0000000000..fc6f5fd5bf --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/AccountSelector.tsx @@ -0,0 +1,83 @@ +'use client'; + +import type { ConnectionListItem } from '@/hooks/use-integration-platform'; +import { Select, SelectContent, SelectItem, SelectTrigger } from '@trycompai/design-system'; +import { getConnectionDisplayLabel } from './connection-display'; + +interface AccountSelectorProps { + connections: ConnectionListItem[]; + selectedId: string; + onSelect: (id: string) => void; + /** Sit inside a parent toolbar — no outer border (parent provides the frame). */ + embedded?: boolean; + /** Tighter label size for dense layouts. */ + compact?: boolean; +} + +const STATUS_DOT: Record = { + active: 'bg-emerald-500', + pending: 'bg-yellow-500', + error: 'bg-red-500', +}; + +export function AccountSelector({ + connections, + selectedId, + onSelect, + embedded = false, + compact = false, +}: AccountSelectorProps) { + const selected = connections.find((c) => c.id === selectedId); + const selectedName = selected ? getConnectionDisplayLabel(selected) : 'Select account'; + + const triggerStyle = embedded + ? { + width: '100%' as const, + minWidth: 0, + justifyContent: 'start' as const, + gap: 6, + border: 'none', + background: 'transparent', + boxShadow: 'none', + } + : { + width: '100%' as const, + justifyContent: 'start' as const, + gap: 6, + minWidth: 0, + }; + + return ( + + ); +} diff --git a/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/AccountSettingsSheet.tsx b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/AccountSettingsSheet.tsx new file mode 100644 index 0000000000..abd2e3b8ca --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/AccountSettingsSheet.tsx @@ -0,0 +1,56 @@ +'use client'; + +import type { IntegrationProvider } from '@/hooks/use-integration-platform'; +import { Sheet, SheetBody, SheetContent, SheetHeader, SheetTitle } from '@trycompai/ui/sheet'; +import { AccountSettingsOAuthBody } from './account-settings-oauth'; +import { AwsAccountSettingsBody } from './aws-account-settings-body'; + +interface AccountSettingsSheetProps { + open: boolean; + onOpenChange: (open: boolean) => void; + connectionId: string; + provider: IntegrationProvider; + orgId: string; + onUpdated?: () => void; +} + +export function AccountSettingsSheet({ + open, + onOpenChange, + connectionId, + provider, + orgId, + onUpdated, +}: AccountSettingsSheetProps) { + const isAws = provider.id === 'aws'; + + return ( + + + + Account Settings +

{provider.name}

+
+ + {isAws ? ( + + ) : ( + + )} + +
+
+ ); +} diff --git a/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/AcknowledgmentPanel.tsx b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/AcknowledgmentPanel.tsx new file mode 100644 index 0000000000..b413cf4c4a --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/AcknowledgmentPanel.tsx @@ -0,0 +1,2 @@ +// Re-export from cloud-tests (canonical location) +export { AcknowledgmentPanel } from '@/app/(app)/[orgId]/cloud-tests/components/AcknowledgmentPanel'; diff --git a/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/ActivitySection.tsx b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/ActivitySection.tsx new file mode 100644 index 0000000000..ff7a9ff5b3 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/ActivitySection.tsx @@ -0,0 +1,103 @@ +'use client'; + +import { api } from '@/lib/api-client'; +import { Avatar, AvatarFallback } from '@trycompai/ui/avatar'; +import { Section, Text, HStack, Stack } from '@trycompai/design-system'; +import { Activity } from '@trycompai/design-system/icons'; +import { formatDistanceToNow } from 'date-fns'; +import { Loader2 } from 'lucide-react'; +import useSWR from 'swr'; + +interface ActivityEntry { + id: string; + type: 'scan' | 'remediation' | 'rollback' | 'service_change'; + description: string; + userId: string | null; + userName: string | null; + status: 'success' | 'failed' | 'info'; + timestamp: string; + metadata?: Record; +} + +interface ActivitySectionProps { + connectionId: string; +} + +const getInitials = (name: string | null) => + name + ? name.split(' ').map((p) => p[0]).join('').toUpperCase().slice(0, 2) + : 'S'; + +function ActivityRow({ entry }: { entry: ActivityEntry }) { + const displayName = entry.userName ?? 'System'; + const timeAgo = formatDistanceToNow(new Date(entry.timestamp), { addSuffix: true }); + + return ( + + + + {getInitials(entry.userName)} + + + + + {displayName} + {' '} + {entry.description} + + +
+ {timeAgo} +
+
+ ); +} + +export function ActivitySection({ connectionId }: ActivitySectionProps) { + const { data, isLoading } = useSWR( + connectionId ? ['cloud-activity', connectionId] : null, + async () => { + const response = await api.get<{ data: ActivityEntry[] }>( + `/v1/cloud-security/activity?connectionId=${connectionId}&take=50`, + ); + if (response.error) throw new Error(response.error); + return response.data?.data ?? []; + }, + { revalidateOnFocus: false, dedupingInterval: 5000 }, + ); + + const entries = data ?? []; + + if (isLoading) { + return ( +
+ +
+ ); + } + + if (entries.length === 0) { + return ( +
+
+ + + + Activity will appear here when scans or remediations are run + + +
+
+ ); + } + + return ( +
+
+ {entries.map((entry) => ( + + ))} +
+
+ ); +} diff --git a/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/BatchRemediationDialog.tsx b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/BatchRemediationDialog.tsx new file mode 100644 index 0000000000..acdbac77b3 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/BatchRemediationDialog.tsx @@ -0,0 +1,2 @@ +// Re-export from cloud-tests (canonical location) +export { BatchRemediationDialog } from '@/app/(app)/[orgId]/cloud-tests/components/BatchRemediationDialog'; diff --git a/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/CloudTestsSection.tsx b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/CloudTestsSection.tsx new file mode 100644 index 0000000000..24d94d24a2 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/CloudTestsSection.tsx @@ -0,0 +1,2 @@ +// Re-export from cloud-tests (canonical location) +export { CloudTestsSection } from '@/app/(app)/[orgId]/cloud-tests/components/CloudTestsSection'; diff --git a/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/ConnectionSection.tsx b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/ConnectionSection.tsx new file mode 100644 index 0000000000..6686f39ce4 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/ConnectionSection.tsx @@ -0,0 +1,149 @@ +'use client'; + +import type { + ConnectionListItem, + IntegrationProvider, +} from '@/hooks/use-integration-platform'; +import { Badge } from '@trycompai/ui/badge'; +import { Button } from '@trycompai/ui/button'; +import { + AlertCircle, + CheckCircle2, + Clock, + Globe, + Plus, + Server, +} from 'lucide-react'; + +interface ConnectionSectionProps { + provider: IntegrationProvider; + connections: ConnectionListItem[]; + onConnect: () => void; +} + +export function ConnectionSection({ + provider, + connections, + onConnect, +}: ConnectionSectionProps) { + if (connections.length === 0) { + return ( +
+ +

+ No connections yet +

+

+ Connect your {provider.name} account to get started +

+ +
+ ); + } + + return ( +
+
+

+ Connections + + ({connections.length}) + +

+ {provider.supportsMultipleConnections && ( + + )} +
+ +
+ {connections.map((connection) => ( + + ))} +
+
+ ); +} + +function ConnectionRow({ connection }: { connection: ConnectionListItem }) { + const metadata = (connection.metadata ?? {}) as Record; + const displayName = + (metadata.connectionName as string) ?? + (metadata.accountId as string) ?? + connection.id; + const accountId = metadata.accountId as string | undefined; + const regions = metadata.regions as string[] | undefined; + + return ( +
+
+
+ +
+
+

{displayName}

+
+ {accountId && ( + {accountId} + )} + {regions && regions.length > 0 && ( + + + {regions.length} {regions.length === 1 ? 'region' : 'regions'} + + )} + {connection.lastSyncAt && ( + + + {new Date(connection.lastSyncAt).toLocaleDateString(undefined, { + month: 'short', + day: 'numeric', + hour: '2-digit', + minute: '2-digit', + })} + + )} +
+
+
+ +
+ ); +} + +function ConnectionStatusBadge({ status }: { status: string }) { + switch (status) { + case 'active': + return ( + + + Active + + ); + case 'error': + return ( + + + Error + + ); + case 'pending': + return ( + + + Pending + + ); + default: + return ( + + {status} + + ); + } +} diff --git a/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/EmptyStateOnboarding.test.tsx b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/EmptyStateOnboarding.test.tsx new file mode 100644 index 0000000000..846029b15b --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/EmptyStateOnboarding.test.tsx @@ -0,0 +1,124 @@ +import { fireEvent, render, screen, waitFor } from '@testing-library/react'; +import { beforeEach, describe, expect, it, vi } from 'vitest'; +import { EmptyStateOnboarding } from './EmptyStateOnboarding'; + +const mockCreateConnection = vi.fn(); +const mockToastSuccess = vi.fn(); +const mockToastError = vi.fn(); + +vi.mock('@/hooks/use-integration-platform', () => ({ + useIntegrationMutations: () => ({ + createConnection: mockCreateConnection, + }), +})); + +vi.mock('@/components/integrations/CloudShellSetup', () => ({ + CloudShellSetup: () =>
, +})); + +vi.mock('@/components/integrations/CredentialInput', () => ({ + CredentialInput: ({ field, value, onChange }: any) => ( + onChange(event.target.value)} + /> + ), +})); + +vi.mock('@trycompai/design-system', () => ({ + Button: ({ children, disabled, loading, onClick }: any) => ( + + ), + Label: ({ children, htmlFor }: any) => , +})); + +vi.mock('lucide-react', () => ({ + ArrowRight: () => , + Shield: () => , +})); + +vi.mock('@trycompai/integration-platform', () => ({ + awsRemediationScript: '', +})); + +vi.mock('sonner', () => ({ + toast: { + success: (...args: unknown[]) => mockToastSuccess(...args), + error: (...args: unknown[]) => mockToastError(...args), + }, +})); + +describe('EmptyStateOnboarding', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('allows connecting dynamic custom integrations with no credential fields', async () => { + mockCreateConnection.mockResolvedValue({ success: true }); + const onConnected = vi.fn(); + + render( + , + ); + + fireEvent.click(screen.getByRole('button', { name: /connect account/i })); + + await waitFor(() => { + expect(mockCreateConnection).toHaveBeenCalledWith('dynamic-security', {}); + }); + expect(onConnected).toHaveBeenCalled(); + expect(mockToastSuccess).toHaveBeenCalledWith('Dynamic Security connected!'); + }); + + it('uses API key fallback field when credential fields are missing', async () => { + mockCreateConnection.mockResolvedValue({ success: true }); + + render( + , + ); + + fireEvent.click(screen.getByRole('button', { name: /connect account/i })); + expect(screen.getByText('API Key is required')).toBeInTheDocument(); + expect(mockCreateConnection).not.toHaveBeenCalled(); + + fireEvent.change(screen.getByLabelText('API Key'), { target: { value: 'secret' } }); + fireEvent.click(screen.getByRole('button', { name: /connect account/i })); + + await waitFor(() => { + expect(mockCreateConnection).toHaveBeenCalledWith('dynamic-api', { api_key: 'secret' }); + }); + }); +}); + diff --git a/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/EmptyStateOnboarding.tsx b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/EmptyStateOnboarding.tsx new file mode 100644 index 0000000000..303d854cc2 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/EmptyStateOnboarding.tsx @@ -0,0 +1,650 @@ +'use client'; + +import { CloudShellSetup } from '@/components/integrations/CloudShellSetup'; +import { CredentialInput } from '@/components/integrations/CredentialInput'; +import type { IntegrationProvider } from '@/hooks/use-integration-platform'; +import { useIntegrationMutations } from '@/hooks/use-integration-platform'; +import { Button, Label } from '@trycompai/design-system'; +import { awsRemediationScript } from '@trycompai/integration-platform'; +import { ArrowRight, Shield } from 'lucide-react'; +import { useCallback, useMemo, useState } from 'react'; +import { toast } from 'sonner'; + +// ─── Primitives ───────────────────────────────────────────────────────── + +function StepHeader({ step, title }: { step: number; title: string }) { + return ( +
+ + {step} + +

{title}

+
+ ); +} + +function FieldRow({ + field, + value, + error, + onChange, +}: { + field: { id: string; label: string; required?: boolean; helpText?: string; type?: string }; + value: string | string[]; + error?: string; + onChange: (value: string | string[]) => void; +}) { + return ( +
+ + [0]['field']} + value={value} + onChange={onChange} + /> + {field.helpText && ( +

{field.helpText}

+ )} + {error &&

{error}

} +
+ ); +} + +/** Compact setup guide — shows only headings as collapsible sections, max 3-4 key steps each. */ +function SetupGuide({ text, fallback, docsUrl }: { text?: string | null; fallback: string; docsUrl?: string | null }) { + const raw = text || fallback; + const [expandedSection, setExpandedSection] = useState(null); + + // Parse into sections (split on ### headings) + const sections = useMemo(() => { + const lines = raw.split('\n'); + const result: Array<{ title: string; steps: string[] }> = []; + let current: { title: string; steps: string[] } | null = null; + + for (const line of lines) { + const trimmed = line.trim(); + if (!trimmed) continue; + + if (trimmed.startsWith('##')) { + if (current) result.push(current); + current = { title: trimmed.replace(/^#{1,4}\s*/, ''), steps: [] }; + } else if (current && (/^\d+[\.\)]\s/.test(trimmed) || trimmed.startsWith('- '))) { + current.steps.push(trimmed.replace(/^\d+[\.\)]\s*/, '').replace(/^-\s*/, '')); + } else if (current && trimmed.startsWith('>')) { + current.steps.push(trimmed.replace(/^>\s*/, '')); + } else if (current) { + current.steps.push(trimmed); + } + } + if (current) result.push(current); + return result; + }, [raw]); + + // No structured content — simple fallback + if (sections.length === 0) { + return ( +

{formatInline(raw)}

+ ); + } + + return ( +
+ {sections.map((section, i) => { + const isOpen = expandedSection === i; + const previewSteps = section.steps.slice(0, 3); + + return ( +
+ + {isOpen && ( +
+ {(previewSteps).map((step, j) => ( +
+ + {j + 1} + +

+ {formatInline(step)} +

+
+ ))} + {section.steps.length > 3 && ( +

+ +{section.steps.length - 3} more step{section.steps.length - 3 !== 1 ? 's' : ''} in docs +

+ )} +
+ )} +
+ ); + })} + {docsUrl && ( + + View full documentation + + + )} +
+ ); +} + +/** Format inline markdown: **bold**, `code`, [links](url) */ +function formatInline(text: string): React.ReactNode { + const parts = text.split(/(\*\*[^*]+\*\*|`[^`]+`|\[[^\]]+\]\([^)]+\))/g); + return parts.map((part, i) => { + if (part.startsWith('**') && part.endsWith('**')) { + return {part.slice(2, -2)}; + } + if (part.startsWith('`') && part.endsWith('`')) { + return {part.slice(1, -1)}; + } + const linkMatch = part.match(/^\[([^\]]+)\]\(([^)]+)\)$/); + if (linkMatch) { + return {linkMatch[1]}; + } + return {part}; + }); +} + +// ─── Main ─────────────────────────────────────────────────────────────── + +interface EmptyStateOnboardingProps { + provider: IntegrationProvider; + orgId: string; + onConnected: () => void; + /** For OAuth providers — opens the OAuth flow */ + onOAuthConnect?: () => void; +} + +export function EmptyStateOnboarding({ + provider, + orgId, + onConnected, + onOAuthConnect, +}: EmptyStateOnboardingProps) { + const isOAuth = provider.authType === 'oauth2'; + const isCloudProvider = provider.category === 'Cloud'; + const isComingSoon = isOAuth && provider.oauthConfigured === false; + + // Coming soon — show info + notify + if (isComingSoon) { + return ; + } + + // OAuth providers get a simple connect card + if (isOAuth) { + return ; + } + + // Cloud providers with setup scripts get the full guided flow + if (isCloudProvider && provider.setupScript) { + return ( + + ); + } + + // Everything else: API key / basic / custom credentials + return ( + + ); +} + +// ─── OAuth (GitHub, Google Workspace, etc.) ───────────────────────────── + +function ComingSoonState({ provider }: { provider: IntegrationProvider }) { + return ( +
+
+
+
+ {provider.logoUrl && ( + + )} +
+

{provider.name}

+

{provider.description}

+
+
+
+

Coming Soon

+

+ This integration is under development. We'll notify you when it's ready. +

+
+
+
+
+ ); +} + +function OAuthSetup({ + provider, + onConnect, +}: { + provider: IntegrationProvider; + onConnect?: () => void; +}) { + return ( +
+
+
+ {provider.logoUrl && ( + + )} +
+

Connect {provider.name}

+

+ You'll be redirected to authorize access. Takes about 30 seconds. +

+
+
+ +
+
+ ); +} + +// ─── API Key / Basic / Custom Credentials ────────────────────────────── + +function CredentialSetup({ + provider, + orgId, + onConnected, +}: { + provider: IntegrationProvider; + orgId: string; + onConnected: () => void; +}) { + const { createConnection } = useIntegrationMutations(); + const [connecting, setConnecting] = useState(false); + const [credentials, setCredentials] = useState>({}); + const [errors, setErrors] = useState>({}); + + const fields = useMemo(() => { + const configuredFields = provider.credentialFields ?? []; + + if (provider.authType === 'basic' && configuredFields.length === 0) { + return [ + { + id: 'username', + label: 'Username', + type: 'text' as const, + required: true, + placeholder: 'Enter username', + }, + { + id: 'password', + label: 'Password', + type: 'password' as const, + required: true, + placeholder: 'Enter password', + }, + ]; + } + + if (provider.authType === 'api_key' && configuredFields.length === 0) { + return [ + { + id: 'api_key', + label: 'API Key', + type: 'password' as const, + required: true, + placeholder: 'Enter your API key', + }, + ]; + } + + return configuredFields; + }, [provider.authType, provider.credentialFields]); + const hasConfigurableFields = fields.length > 0; + + const updateCredential = (fieldId: string, value: string | string[]) => { + setCredentials((prev) => ({ ...prev, [fieldId]: value })); + if (errors[fieldId]) { + setErrors((prev) => { + const next = { ...prev }; + delete next[fieldId]; + return next; + }); + } + }; + + const handleConnect = useCallback(async () => { + const newErrors: Record = {}; + for (const field of fields) { + const value = credentials[field.id]; + const isMissing = + field.type === 'multi-select' + ? !Array.isArray(value) || value.length === 0 + : !String(value ?? '').trim(); + if (field.required && isMissing) { + newErrors[field.id] = `${field.label} is required`; + } + } + if (Object.keys(newErrors).length > 0) { + setErrors(newErrors); + return; + } + + setConnecting(true); + try { + const result = await createConnection(provider.id, credentials); + if (!result.success) { + toast.error(result.error || 'Failed to connect'); + return; + } + toast.success(`${provider.name} connected!`); + onConnected(); + } catch { + toast.error('Failed to connect'); + } finally { + setConnecting(false); + } + }, [fields, credentials, createConnection, provider, onConnected]); + + return ( +
+
+

Connect {provider.name}

+

+ {provider.description || 'Enter your credentials to get started.'} +

+
+ +
+ {/* Main form */} +
+
+ {hasConfigurableFields ? ( + fields.map((field) => ( + updateCredential(field.id, v)} + /> + )) + ) : ( +
+ No additional setup fields are required for this integration. + {provider.docsUrl ? ( + <> + {' '} + + Open docs + + . + + ) : null} +
+ )} +
+
+ +
+
+ + {/* Sidebar — setup guide */} +
+
+ {provider.logoUrl && ( + + )} +

Setup guide

+
+
+ +
+
+
+
+ ); +} + +// ─── Cloud Providers (AWS, GCP, Azure) ───────────────────────────────── + +function CloudSetup({ + provider, + orgId, + onConnected, +}: { + provider: IntegrationProvider; + orgId: string; + onConnected: () => void; +}) { + const { createConnection } = useIntegrationMutations(); + const [connecting, setConnecting] = useState(false); + const [credentials, setCredentials] = useState>({}); + const [errors, setErrors] = useState>({}); + + const allFields = provider.credentialFields ?? []; + const visibleFields = allFields.filter( + (field) => field.id !== 'externalId' && field.id !== 'connectionName', + ); + + const updateCredential = (fieldId: string, value: string | string[]) => { + setCredentials((prev) => ({ ...prev, [fieldId]: value })); + if (errors[fieldId]) { + setErrors((prev) => { + const next = { ...prev }; + delete next[fieldId]; + return next; + }); + } + }; + + const handleConnect = useCallback(async () => { + const finalCredentials = { ...credentials }; + if (!finalCredentials.externalId) finalCredentials.externalId = orgId; + if (!finalCredentials.connectionName) { + const arnMatch = String(finalCredentials.roleArn ?? '').match(/:(\d{12}):/); + finalCredentials.connectionName = arnMatch ? `AWS ${arnMatch[1]}` : 'AWS Account'; + } + + const newErrors: Record = {}; + for (const field of allFields) { + if (field.id === 'externalId' || field.id === 'connectionName') continue; + const value = finalCredentials[field.id]; + const isMissing = + field.type === 'multi-select' + ? !Array.isArray(value) || value.length === 0 + : !String(value ?? '').trim(); + if (field.required && isMissing) { + newErrors[field.id] = `${field.label} is required`; + } + } + if (Object.keys(newErrors).length > 0) { + setErrors(newErrors); + return; + } + + setConnecting(true); + try { + const result = await createConnection(provider.id, finalCredentials); + if (!result.success) { + toast.error(result.error || 'Failed to connect'); + return; + } + toast.success(`${provider.name} connected and verified!`); + setCredentials({}); + onConnected(); + } catch { + toast.error('Failed to connect'); + } finally { + setConnecting(false); + } + }, [allFields, credentials, createConnection, provider, orgId, onConnected]); + + const connectionFields = visibleFields.filter((f) => f.id !== 'remediationRoleArn' && f.id !== 'regions'); + const regionFields = visibleFields.filter((f) => f.id === 'regions'); + const remediationFields = visibleFields.filter((f) => f.id === 'remediationRoleArn'); + const hasRemediation = provider.id === 'aws' && remediationFields.length > 0; + + return ( +
+
+

Get started

+

+ Connect a read-only IAM role to start scanning your cloud security posture. +

+
+ +
+ {/* ─── Left: Unified setup flow ─── */} +
+ {/* Step 1 */} + {provider.setupScript && ( +
+ + +

+ Connecting multiple accounts? Run the script in each account and add them one by one. +

+
+ )} + +
+ + {/* Step 2 */} +
+ + {connectionFields.map((field) => ( + updateCredential(field.id, v)} + /> + ))} +
+ + {/* Step 3 */} + {regionFields.length > 0 && ( + <> +
+
+ + {regionFields.map((field) => ( + updateCredential(field.id, v)} + /> + ))} +
+ + )} + + {/* CTA */} +
+ +
+
+ + {/* ─── Right: Optional sidebar ─── */} + {hasRemediation && ( +
+
+
+
+ +
+

Auto-Remediation

+ + Optional + +
+
+

+ Enable one-click fixes for security findings. This creates a separate write-access role — your audit role stays read-only. +

+ + {remediationFields.map((field) => ( + updateCredential(field.id, v)} + /> + ))} +
+
+

+ You can always enable this later from Settings. +

+
+ )} +
+
+ ); +} diff --git a/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/GcpProjectPicker.tsx b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/GcpProjectPicker.tsx new file mode 100644 index 0000000000..cbde2e1116 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/GcpProjectPicker.tsx @@ -0,0 +1,159 @@ +'use client'; + +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuSub, + DropdownMenuSubContent, + DropdownMenuSubTrigger, + DropdownMenuTrigger, +} from '@trycompai/ui/dropdown-menu'; + +export interface GcpOrg { + id: string; + displayName: string; + projects: Array<{ id: string; name: string }>; +} + +interface GcpProjectPickerProps { + organizations: GcpOrg[]; + selectedProjectIds: string[]; + onToggleProject: (orgId: string, projectId: string) => void; +} + +export function GcpProjectPicker({ + organizations, + selectedProjectIds, + onToggleProject, +}: GcpProjectPickerProps) { + const selectedSet = new Set(selectedProjectIds); + const count = selectedProjectIds.length; + + const allProjects = organizations.flatMap((o) => + o.projects.map((p) => ({ ...p, orgId: o.id })), + ); + const selectedNames = allProjects + .filter((p) => selectedSet.has(p.id)) + .map((p) => p.name); + + const label = + count === 0 + ? 'Select projects' + : count <= 2 + ? selectedNames.join(', ') + : `${selectedNames[0]} +${count - 1} more`; + + return ( +
+
+

GCP Projects

+

+ Select which projects to scan and monitor. Findings and service + detection are scoped to these projects. +

+
+ + + + + + + e.preventDefault()} + > + {organizations.map((org) => { + const orgSelectedCount = org.projects.filter((p) => + selectedSet.has(p.id), + ).length; + return ( + + + {org.displayName} + {orgSelectedCount > 0 && ( + + {orgSelectedCount} + + )} + + + {org.projects.length === 0 ? ( +
+ No projects found +
+ ) : ( + org.projects.map((p) => { + const checked = selectedSet.has(p.id); + return ( + + ); + }) + )} +
+
+ ); + })} +
+
+
+ ); +} diff --git a/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/IntegrationProviderHero.tsx b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/IntegrationProviderHero.tsx new file mode 100644 index 0000000000..bab3fe8218 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/IntegrationProviderHero.tsx @@ -0,0 +1,164 @@ +'use client'; + +import type { ConnectionListItem, IntegrationProvider } from '@/hooks/use-integration-platform'; +import { Button } from '@trycompai/design-system'; +import { Add, Launch, Settings } from '@trycompai/design-system/icons'; +import Image from 'next/image'; +import { AccountSelector } from './AccountSelector'; +import { getConnectionDisplayLabel } from './connection-display'; + +type HeroProps = { + provider: IntegrationProvider; + isConnected: boolean; + activeConnections: ConnectionListItem[]; + selectedConnection: ConnectionListItem | null; + onSelectConnection: (id: string) => void; + onOpenSettings: () => void; + onAddAccount: () => void; +}; + +export function IntegrationProviderHero({ + provider, + isConnected, + activeConnections, + selectedConnection, + onSelectConnection, + onOpenSettings, + onAddAccount, +}: HeroProps) { + return ( +
+
+
+
+ {provider.logoUrl ? ( +
+ {provider.name} +
+ ) : null} +
+
+

+ {provider.name} +

+ {isConnected ? ( + + {activeConnections.length === 1 + ? 'Connected' + : `${activeConnections.length} accounts`} + + ) : ( + + Not connected + + )} +
+

+ {provider.description} +

+
+
+ +
+ {provider.docsUrl || isConnected ? ( +
+ {!isConnected && provider.docsUrl ? ( + + ) : null} + {isConnected ? ( +
+ {/* Row 1: Docs + Settings */} +
+
+ {provider.docsUrl ? ( +
+ +
+ ) : null} +
+ +
+
+
+ {/* Row 2: account + Add */} +
+
+
+ {activeConnections.length === 1 && selectedConnection ? ( +
+ + + {getConnectionDisplayLabel(selectedConnection)} + +
+ ) : ( +
+ +
+ )} +
+ +
+
+
+
+
+ ) : null} +
+ ) : null} +
+
+
+
+ ); +} diff --git a/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/PermissionErrorPanel.tsx b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/PermissionErrorPanel.tsx new file mode 100644 index 0000000000..b841f3ba60 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/PermissionErrorPanel.tsx @@ -0,0 +1,2 @@ +// Re-export from cloud-tests (canonical location) +export { PermissionErrorPanel } from '@/app/(app)/[orgId]/cloud-tests/components/PermissionErrorPanel'; diff --git a/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/ProviderDetailView.tsx b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/ProviderDetailView.tsx new file mode 100644 index 0000000000..307c17fcbf --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/ProviderDetailView.tsx @@ -0,0 +1,442 @@ +'use client'; + +import { ConnectIntegrationDialog } from '@/components/integrations/ConnectIntegrationDialog'; +import { + useConnectionServices, + useIntegrationConnections, + useIntegrationMutations, + type ConnectionListItem, + type IntegrationProvider, +} from '@/hooks/use-integration-platform'; +import { + CLOUD_RECONNECT_CUTOFF_LABEL, + requiresCloudReconnect, +} from '@/lib/cloud-reconnect-policy'; +import { api } from '@/lib/api-client'; +import { + Breadcrumb, + Button, + Stack, +} from '@trycompai/design-system'; +import { Add, Security } from '@trycompai/design-system/icons'; +import { GcpProjectPicker } from './GcpProjectPicker'; +import Link from 'next/link'; +import { useParams, useRouter } from 'next/navigation'; +import { useCallback, useEffect, useMemo, useRef, useState } from 'react'; +import { toast } from 'sonner'; +import { AccountSettingsSheet } from './AccountSettingsSheet'; +import { getConnectionDisplayLabel } from './connection-display'; +import { IntegrationProviderHero } from './IntegrationProviderHero'; +import { EmptyStateOnboarding } from './EmptyStateOnboarding'; +import { ServicesGrid } from './services-grid'; + +interface ProviderDetailViewProps { + provider: IntegrationProvider; + initialConnections: ConnectionListItem[]; + /** Server passes true when URL was ?success=true&provider=gcp (OAuth return) */ + gcpOAuthJustConnected?: boolean; +} + +export function ProviderDetailView({ + provider, + initialConnections, + gcpOAuthJustConnected = false, +}: ProviderDetailViewProps) { + const { orgId } = useParams<{ orgId: string }>(); + const router = useRouter(); + const { connections: allConnections, refresh: refreshConnections } = useIntegrationConnections(); + const { startOAuth } = useIntegrationMutations(); + const [showAddAccount, setShowAddAccount] = useState(false); + const [settingsOpen, setSettingsOpen] = useState(false); + const [reconnectDialogOpen, setReconnectDialogOpen] = useState(false); + + const connections = useMemo(() => { + const live = allConnections.filter((c) => c.providerSlug === provider.id); + return live.length > 0 ? live : initialConnections; + }, [allConnections, initialConnections, provider.id]); + + const activeConnections = connections.filter( + (c) => c.status === 'active' || c.status === 'pending', + ); + const isConnected = activeConnections.length > 0; + const [selectedConnectionId, setSelectedConnectionId] = useState(null); + + const selectedConnection = useMemo(() => { + if (selectedConnectionId) { + return ( + activeConnections.find((c) => c.id === selectedConnectionId) ?? activeConnections[0] ?? null + ); + } + return activeConnections[0] ?? null; + }, [selectedConnectionId, activeConnections]); + + const services = + ( + provider as IntegrationProvider & { + services?: Array<{ id: string; name: string; description: string; implemented?: boolean }>; + } + ).services ?? []; + const isCloudProvider = provider.category === 'Cloud'; + const selectedConnectionRequiresReconnect = useMemo(() => { + if (!isCloudProvider || !selectedConnection) return false; + const metadata = (selectedConnection.metadata || {}) as Record; + return requiresCloudReconnect({ + providerId: provider.id, + createdAt: selectedConnection.createdAt, + reconnectedAt: + typeof metadata.reconnectedAt === 'string' ? metadata.reconnectedAt : null, + status: selectedConnection.status, + }); + }, [isCloudProvider, provider.id, selectedConnection]); + + // Services hook for the selected connection + const { + services: connectionServices, + meta: servicesMeta, + refresh: refreshServices, + updateServices, + } = useConnectionServices(selectedConnection?.id ?? null); + const [togglingService, setTogglingService] = useState(null); + const [gcpOrgs, setGcpOrgs] = useState< + Array<{ + id: string; + displayName: string; + projects: Array<{ id: string; name: string }>; + }> + >([]); + const [gcpSelectedProjectIds, setGcpSelectedProjectIds] = useState([]); + const oauthBootstrapHandledRef = useRef(false); + + const handleToggleService = useCallback( + async (serviceId: string, enabled: boolean): Promise => { + setTogglingService(serviceId); + try { + await updateServices(serviceId, enabled); + toast.success( + `${services.find((s) => s.id === serviceId)?.name ?? serviceId} ${enabled ? 'enabled' : 'disabled'}`, + ); + return true; + } catch (err) { + toast.error(err instanceof Error ? err.message : 'Failed to update'); + return false; + } finally { + setTogglingService(null); + } + }, + [updateServices, services], + ); + + // OAuth return (?success=true): strip query, detect org/projects (NOT services yet — user must select projects first) + useEffect(() => { + if ( + !gcpOAuthJustConnected || + provider.id !== 'gcp' || + !selectedConnection?.id || + oauthBootstrapHandledRef.current + ) { + return; + } + oauthBootstrapHandledRef.current = true; + router.replace(`/${orgId}/integrations/gcp`, { scroll: false }); + + // For GCP: only detect orgs/projects — service detection waits for project selection + // (The detect-gcp-org effect already fires on connection) + }, [ + gcpOAuthJustConnected, + provider.id, + selectedConnection?.id, + orgId, + router, + ]); + + // Auto-detect GCP organization after OAuth connect + const gcpDetectedRef = useRef>(new Set()); + useEffect(() => { + if ( + provider.id !== 'gcp' || + !isConnected || + !selectedConnection?.id || + gcpDetectedRef.current.has(selectedConnection.id) + ) { + return; + } + gcpDetectedRef.current.add(selectedConnection.id); + api + .post<{ + organizations: Array<{ + id: string; + displayName: string; + projects: Array<{ id: string; name: string }>; + }>; + selectedProjectIds?: string[]; + }>( + `/v1/cloud-security/detect-gcp-org/${selectedConnection.id}`, + ) + .then((res) => { + const orgs = res.data?.organizations ?? []; + if (orgs.length > 0) setGcpOrgs(orgs); + if (res.data?.selectedProjectIds?.length) { + setGcpSelectedProjectIds(res.data.selectedProjectIds); + } + if (orgs.length === 1) { + toast.success(`Connected to GCP organization: ${orgs[0].displayName}`); + } else if (orgs.length > 1) { + toast.info(`${orgs.length} GCP organizations found — select projects below.`); + } + }) + .catch(() => {}); + }, [provider.id, isConnected, selectedConnection?.id]); + + // Auto-detect services when switching accounts (skip GCP — needs project selection first) + const detectedConnections = useRef>(new Set()); + useEffect(() => { + if ( + !isCloudProvider || + !isConnected || + !selectedConnection?.id || + detectedConnections.current.has(selectedConnection.id) || + provider.id === 'gcp' + ) { + return; + } + + const connectionForRun = selectedConnection; + api + .post<{ services: string[] }>( + `/v1/cloud-security/detect-services/${connectionForRun.id}`, + ) + .then((res) => { + if (res.error) return; + detectedConnections.current.add(connectionForRun.id); + const count = res.data?.services?.length; + if (count) { + const name = getConnectionDisplayLabel(connectionForRun); + toast.success(`${count} services detected${name ? ` in ${name}` : ''}`); + } + return refreshServices(); + }) + .catch(() => {}); + }, [isCloudProvider, isConnected, selectedConnection?.id, refreshServices, provider.id]); + + const handleConnect = useCallback(async () => { + if (provider.authType === 'oauth2') { + const redirectUrl = `${window.location.origin}/${orgId}/integrations/${provider.id}?success=true`; + const result = await startOAuth(provider.id, redirectUrl); + if (result?.authorizationUrl) { + window.location.href = result.authorizationUrl; + } else { + toast.error(result.error || 'Failed to start connection'); + } + return; + } else { + // For non-OAuth, show the inline add-account form + setShowAddAccount(true); + } + }, [provider, orgId, startOAuth]); + + return ( + <> + + }, + }, + { label: provider.name, isCurrent: true }, + ]} + /> + + setSelectedConnectionId(id)} + onOpenSettings={() => setSettingsOpen(true)} + onAddAccount={() => void handleConnect()} + /> + + {selectedConnectionRequiresReconnect && ( +
+
+

Reconnect this account

+

+ This connection was created before {CLOUD_RECONNECT_CUTOFF_LABEL}. Reconnect it to keep scans and remediation fully reliable. +

+
+ +
+ )} + + {/* Content: zero state OR findings */} + {!isConnected && ( + { + if (isCloudProvider) { + // Redirect to Cloud Tests after connecting a cloud provider + window.location.href = `/${orgId}/cloud-tests?provider=${provider.id}`; + } + }} + onOAuthConnect={handleConnect} + /> + )} + + {isConnected && isCloudProvider && ( +
+ {/* Link to Cloud Tests — findings + fix now live there */} + +
+

View Findings & Auto-Fix

+

+ Security findings, auto-remediation, and batch fixes are now in Cloud Tests. +

+
+ Open Cloud Tests → +
+ + {/* GCP org → project selector */} + {provider.id === 'gcp' && gcpOrgs.length > 0 && selectedConnection && ( + { + const prev = gcpSelectedProjectIds; + const next = prev.includes(projectId) + ? prev.filter((id) => id !== projectId) + : [...prev, projectId]; + setGcpSelectedProjectIds(next); + // Build name map keyed by both project ID and project number + const allProjects = gcpOrgs.flatMap((o) => o.projects); + const projectNames: Record = {}; + for (const pid of next) { + const p = allProjects.find((proj) => proj.id === pid) as + | { id: string; name: string; number?: string } + | undefined; + if (p) { + projectNames[pid] = p.name; + if (p.number) projectNames[p.number] = p.name; + } + } + void api + .post( + `/v1/cloud-security/select-gcp-projects/${selectedConnection.id}`, + { projectIds: next, projectNames, gcpOrganizationId: gcpOrgId }, + ) + .then(async (res) => { + if (res.error) { + setGcpSelectedProjectIds(prev); + toast.error('Failed to update projects'); + return; + } + toast.success( + next.length === 0 + ? 'All projects deselected' + : `${next.length} project(s) selected`, + ); + // Re-run service detection for the new project selection + if (next.length > 0) { + await api.post( + `/v1/cloud-security/detect-services/${selectedConnection.id}`, + ); + refreshServices(); + } + }); + }} + /> + )} + + {/* Services config */} + {services.length > 0 && ( +
+

Services

+ {provider.id === 'gcp' && gcpSelectedProjectIds.length === 0 ? ( +
+
+ + + +
+
+

Select projects first

+

+ Choose GCP projects above to detect which services are active. Service toggles will appear here once projects are selected. +

+
+
+ ) : provider.id === 'gcp' && + servicesMeta.detectionReady === false ? ( +
+

+ Detecting active GCP services… +

+

+ Checking which APIs are enabled in your selected projects. +

+
+
+
+
+ ) : ( + + )} +
+ )} +
+ )} + + + {/* Account settings sheet */} + {selectedConnection && ( + + )} + + {/* Inline add-account form (shown when clicking "+ Add" while already connected) */} + {showAddAccount && isConnected && ( + setShowAddAccount(false)} + onOAuthConnect={handleConnect} + /> + )} + + {selectedConnectionRequiresReconnect && ( + { + setReconnectDialogOpen(false); + setShowAddAccount(false); + refreshConnections(); + }} + /> + )} + + ); +} + diff --git a/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/RemediationDialog.tsx b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/RemediationDialog.tsx new file mode 100644 index 0000000000..9727ada31c --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/RemediationDialog.tsx @@ -0,0 +1,2 @@ +// Re-export from cloud-tests (canonical location) +export { RemediationDialog } from '@/app/(app)/[orgId]/cloud-tests/components/RemediationDialog'; diff --git a/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/RemediationHistorySection.tsx b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/RemediationHistorySection.tsx new file mode 100644 index 0000000000..58bb7703c9 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/RemediationHistorySection.tsx @@ -0,0 +1,2 @@ +// Re-export from cloud-tests (canonical location) +export { RemediationHistorySection } from '@/app/(app)/[orgId]/cloud-tests/components/RemediationHistorySection'; diff --git a/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/ServiceCard.tsx b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/ServiceCard.tsx new file mode 100644 index 0000000000..e09ad4dc66 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/ServiceCard.tsx @@ -0,0 +1,166 @@ +'use client'; + +import { useConnectionServices } from '@/hooks/use-integration-platform'; +import { Badge } from '@trycompai/ui/badge'; +import { + Cloud, + Database, + Globe, + HardDrive, + Key, + Lock, + MonitorCheck, + Network, + ScanSearch, + Server, + Shield, + Terminal, + Workflow, +} from 'lucide-react'; + +const SERVICE_ICONS: Record = { + 'security-hub': Shield, + 'iam-analyzer': Key, + 'cloudtrail': ScanSearch, + 's3': HardDrive, + 'ec2-vpc': Server, + 'rds': Database, + 'kms': Lock, + 'cloudwatch': MonitorCheck, + 'config': MonitorCheck, + 'guardduty': Shield, + 'secrets-manager': Key, + 'waf': Shield, + 'elb': Network, + 'acm': Lock, + 'backup': HardDrive, + 'inspector': ScanSearch, + 'ecs-eks': Server, + 'lambda': Terminal, + 'dynamodb': Database, + 'sns-sqs': Workflow, + 'ecr': Server, + 'opensearch': Database, + 'redshift': Database, + 'macie': ScanSearch, + 'route53': Globe, + 'api-gateway': Network, + 'cloudfront': Globe, + 'cognito': Key, + 'elasticache': Database, + 'efs': HardDrive, + 'msk': Workflow, + 'sagemaker': Cloud, + 'systems-manager': Terminal, + 'codebuild': Terminal, + 'network-firewall': Shield, + 'shield': Shield, + 'kinesis': Workflow, + 'glue': Workflow, + 'athena': Database, + 'emr': Cloud, + 'step-functions': Workflow, + 'eventbridge': Workflow, + 'transfer-family': Network, + 'elastic-beanstalk': Cloud, + 'appflow': Workflow, +}; + +interface ServiceMeta { + id: string; + name: string; + description: string; + enabledByDefault?: boolean; + implemented?: boolean; +} + +function ServiceIcon({ serviceId }: { serviceId: string }) { + const Icon = SERVICE_ICONS[serviceId] as React.ComponentType<{ className?: string }> | undefined; + if (!Icon) return null; + return ( +
+ +
+ ); +} + +interface ServiceCardProps { + service: ServiceMeta; + connectionId: string | null; + isConnected: boolean; + onToggle?: (id: string, enabled: boolean) => void | Promise; + toggling?: boolean; +} + +export function ServiceCard({ + service, + connectionId, + isConnected, + onToggle, + toggling, +}: ServiceCardProps) { + const { services } = useConnectionServices(connectionId); + + const isImplemented = service.implemented !== false; + const liveService = services.find((s) => s.id === service.id); + const isEnabled = liveService?.enabled ?? false; + const showToggle = isImplemented && isConnected && onToggle; + + return ( +
+
+ +
+
+ {service.name} + {!isImplemented && ( + + Coming Soon + + )} +
+

+ {service.description} +

+ {liveService?.projects && liveService.projects.length > 0 && ( +
+ {liveService.projects.map((pid) => ( + + {pid} + + ))} +
+ )} +
+ {showToggle && ( + + )} +
+
+ ); +} diff --git a/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/account-settings-oauth.tsx b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/account-settings-oauth.tsx new file mode 100644 index 0000000000..0e90c65c43 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/account-settings-oauth.tsx @@ -0,0 +1,225 @@ +'use client'; + +import type { IntegrationProvider } from '@/hooks/use-integration-platform'; +import { + useIntegrationConnection, + useIntegrationMutations, +} from '@/hooks/use-integration-platform'; +import { Button } from '@trycompai/design-system'; +import { Badge } from '@trycompai/ui/badge'; +import { AlertTriangle, CheckCircle2, Loader2 } from 'lucide-react'; +import { useCallback, useEffect, useState } from 'react'; +import { toast } from 'sonner'; +import { AccountSettingsInfoRow } from './account-settings-shared-ui'; +import { + OAuthConnectionVariablesForm, + type OAuthVariableRow, +} from './oauth-connection-variables-form'; + +interface VariablesResponse { + connectionId: string; + providerSlug: string; + variables: OAuthVariableRow[]; +} + +type AccountSettingsOAuthProps = { + open: boolean; + connectionId: string; + provider: IntegrationProvider; + onUpdated?: () => void; + onOpenChange: (open: boolean) => void; +}; + +export function AccountSettingsOAuthBody({ + open, + connectionId, + provider, + onUpdated, + onOpenChange, +}: AccountSettingsOAuthProps) { + const { connection, isLoading } = useIntegrationConnection( + open ? connectionId : null, + ); + const { + getConnectionVariables, + saveConnectionVariables, + getVariableOptions, + deleteConnection, + } = useIntegrationMutations(); + + const [variables, setVariables] = useState([]); + const [variableValues, setVariableValues] = useState< + Record + >({}); + const [loadingVariables, setLoadingVariables] = useState(false); + const [savingVariables, setSavingVariables] = useState(false); + const [disconnecting, setDisconnecting] = useState(false); + const [dynamicOptions, setDynamicOptions] = useState< + Record + >({}); + const [loadingOptions, setLoadingOptions] = useState>({}); + + const loadVariables = useCallback(async () => { + setLoadingVariables(true); + try { + const result = await getConnectionVariables(connectionId); + if (result.data?.variables) { + setVariables(result.data.variables); + const next: Record = {}; + for (const v of result.data.variables) { + if (v.currentValue !== undefined) { + next[v.id] = v.currentValue as string | number | boolean | string[]; + } + } + setVariableValues(next); + } + if (result.error) toast.error('Failed to load settings'); + } catch { + toast.error('Failed to load settings'); + } finally { + setLoadingVariables(false); + } + }, [connectionId, getConnectionVariables]); + + useEffect(() => { + if (!open) return; + void loadVariables(); + }, [open, loadVariables]); + + const fetchOptions = useCallback( + async (variableId: string) => { + setLoadingOptions((p) => ({ ...p, [variableId]: true })); + try { + const result = await getVariableOptions(connectionId, variableId); + if (result.options) { + setDynamicOptions((p) => ({ ...p, [variableId]: result.options! })); + } + } finally { + setLoadingOptions((p) => ({ ...p, [variableId]: false })); + } + }, + [connectionId, getVariableOptions], + ); + + const handleSaveVariables = useCallback(async () => { + setSavingVariables(true); + try { + const result = await saveConnectionVariables(connectionId, variableValues); + if (!result.success) { + toast.error(result.error || 'Failed to save'); + return; + } + toast.success('Settings saved'); + onUpdated?.(); + await loadVariables(); + } catch { + toast.error('Failed to save'); + } finally { + setSavingVariables(false); + } + }, [connectionId, saveConnectionVariables, variableValues, onUpdated, loadVariables]); + + const handleDisconnect = useCallback(async () => { + if (!confirm('Are you sure? All associated data will be removed.')) return; + setDisconnecting(true); + try { + const result = await deleteConnection(connectionId); + if (result.success) { + toast.success('Disconnected'); + onOpenChange(false); + onUpdated?.(); + } else toast.error(result.error || 'Failed'); + } catch { + toast.error('Failed'); + } finally { + setDisconnecting(false); + } + }, [connectionId, deleteConnection, onOpenChange, onUpdated]); + + if (!open) { + return null; + } + + if (isLoading || loadingVariables) { + return ( +
+ +
+ ); + } + + return ( +
+
+ + + + Active + + ) : ( + + {connection?.status} + + ) + } + /> + {connection?.createdAt && ( + + )} +
+ + + +
+
+
+ +
+

Disconnect

+

+ Remove this account and all data +

+
+
+ +
+
+
+ ); +} diff --git a/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/account-settings-shared-ui.tsx b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/account-settings-shared-ui.tsx new file mode 100644 index 0000000000..f059de1d50 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/account-settings-shared-ui.tsx @@ -0,0 +1,75 @@ +'use client'; + +import { Label } from '@trycompai/design-system'; +import type { ReactNode } from 'react'; + +export function AccountSettingsSection({ + label, + children, +}: { + label: string; + children: ReactNode; +}) { + return ( +
+

+ {label} +

+ {children} +
+ ); +} + +export function AccountSettingsFieldGroup({ + label, + children, +}: { + label: string; + children: ReactNode; +}) { + return ( +
+
+ +
+ {children} +
+ ); +} + +export function AccountSettingsInfoRow({ + label, + value, + mono, + badge, + valueTruncate, +}: { + label: string; + value?: string; + mono?: boolean; + badge?: ReactNode; + valueTruncate?: boolean; +}) { + return ( +
+ + {label} + + {badge ?? ( + + {value} + + )} +
+ ); +} diff --git a/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/aws-account-settings-body.tsx b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/aws-account-settings-body.tsx new file mode 100644 index 0000000000..e81948d8c8 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/aws-account-settings-body.tsx @@ -0,0 +1,300 @@ +'use client'; + +import { CloudShellSetup } from '@/components/integrations/CloudShellSetup'; +import { CredentialInput } from '@/components/integrations/CredentialInput'; +import type { IntegrationProvider } from '@/hooks/use-integration-platform'; +import { + useIntegrationConnection, + useIntegrationMutations, +} from '@/hooks/use-integration-platform'; +import { Button } from '@trycompai/design-system'; +import { Badge } from '@trycompai/ui/badge'; +import { awsRemediationScript } from '@trycompai/integration-platform'; +import { AlertTriangle, CheckCircle2, Loader2 } from 'lucide-react'; +import { useCallback, useEffect, useState } from 'react'; +import { toast } from 'sonner'; +import { + AccountSettingsFieldGroup, + AccountSettingsInfoRow, + AccountSettingsSection, +} from './account-settings-shared-ui'; + +export function AwsAccountSettingsBody({ + open, + connectionId, + provider, + orgId, + onUpdated, +}: { + open: boolean; + connectionId: string; + provider: IntegrationProvider; + orgId: string; + onUpdated?: () => void; +}) { + const { connection, isLoading } = useIntegrationConnection(open ? connectionId : null); + const { updateConnectionCredentials, updateConnectionMetadata, deleteConnection } = + useIntegrationMutations(); + + const [roleArn, setRoleArn] = useState(''); + const [remediationRoleArn, setRemediationRoleArn] = useState(''); + const [regions, setRegions] = useState([]); + const [savingCredentials, setSavingCredentials] = useState(false); + const [savingRemediation, setSavingRemediation] = useState(false); + const [savingRegions, setSavingRegions] = useState(false); + const [disconnecting, setDisconnecting] = useState(false); + + const metadata = (connection?.metadata ?? {}) as Record; + const displayName = + (metadata.connectionName as string) ?? (metadata.accountId as string) ?? connectionId; + const accountId = metadata.accountId as string | undefined; + const externalId = (metadata.externalId as string) ?? orgId; + const hasRemediation = Boolean(metadata.remediationRoleArn); + const regionsField = provider.credentialFields?.find((f) => f.id === 'regions'); + + useEffect(() => { + if (!connection) return; + setRoleArn((metadata.roleArn as string) ?? ''); + setRemediationRoleArn((metadata.remediationRoleArn as string) ?? ''); + setRegions(Array.isArray(metadata.regions) ? (metadata.regions as string[]) : []); + }, [connection, metadata.roleArn, metadata.remediationRoleArn, metadata.regions]); + + const saveField = useCallback( + async ( + creds: Record, + metaUpdates: Record, + setLoading: (v: boolean) => void, + successMsg: string, + ) => { + setLoading(true); + try { + const result = await updateConnectionCredentials(connectionId, creds); + if (!result.success) { + toast.error(result.error || 'Failed to save'); + return; + } + if (Object.keys(metaUpdates).length > 0) { + await updateConnectionMetadata(connectionId, metaUpdates); + } + toast.success(successMsg); + onUpdated?.(); + } catch { + toast.error('Failed to save'); + } finally { + setLoading(false); + } + }, + [connectionId, updateConnectionCredentials, updateConnectionMetadata, onUpdated], + ); + + const handleSaveCredentials = useCallback(async () => { + if (!roleArn.trim()) { + toast.error('Role ARN is required'); + return; + } + const meta: Record = { roleArn }; + const arnMatch = roleArn.match(/^arn:aws:iam::(\d{12}):role\/.+$/); + if (arnMatch) meta.accountId = arnMatch[1]; + await saveField({ roleArn }, meta, setSavingCredentials, 'Credentials saved'); + }, [roleArn, saveField]); + + const handleSaveRemediation = useCallback(async () => { + await saveField( + { remediationRoleArn }, + { remediationRoleArn }, + setSavingRemediation, + 'Remediation role saved', + ); + }, [remediationRoleArn, saveField]); + + const handleSaveRegions = useCallback(async () => { + if (regions.length === 0) { + toast.error('Select at least one region'); + return; + } + await saveField({ regions }, { regions }, setSavingRegions, 'Regions saved'); + }, [regions, saveField]); + + const handleDisconnect = useCallback(async () => { + if (!confirm('Are you sure? All associated data will be removed.')) return; + setDisconnecting(true); + try { + const result = await deleteConnection(connectionId); + if (result.success) { + toast.success('Disconnected'); + onUpdated?.(); + } else toast.error(result.error || 'Failed'); + } catch { + toast.error('Failed'); + } finally { + setDisconnecting(false); + } + }, [connectionId, deleteConnection, onUpdated]); + + if (isLoading) { + return ( +
+ +
+ ); + } + + return ( +
+
+ + + Active + + ) : ( + + {connection?.status} + + ) + } + /> + {accountId && ( + + )} + {displayName && !accountId && ( + + )} + {regions.length > 0 && ( + + )} + {connection?.createdAt && ( + + )} +
+ + + + setRoleArn(v as string)} + /> + + +

+ {externalId} +

+
+ +
+ + +
+ Status + {hasRemediation ? ( + + + Configured + + ) : ( + + Not configured + + )} +
+ + + setRemediationRoleArn(v as string)} + /> + + +
+ + {regionsField && ( + + setRegions(v as string[])} + /> + + + )} + +
+
+
+ +
+

Disconnect

+

Remove this account and all data

+
+
+ +
+
+
+ ); +} diff --git a/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/connection-display.test.ts b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/connection-display.test.ts new file mode 100644 index 0000000000..b4c7541e74 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/connection-display.test.ts @@ -0,0 +1,71 @@ +import type { ConnectionListItem } from '@/hooks/use-integration-platform'; +import { describe, expect, it } from 'vitest'; +import { getConnectionDisplayLabel, getRegionCount } from './connection-display'; + +function conn(overrides: Partial & { id: string }): ConnectionListItem { + return { + providerId: 'prv_x', + providerSlug: 'aws', + providerName: 'AWS', + status: 'active', + authStrategy: 'custom', + lastSyncAt: null, + nextSyncAt: null, + errorMessage: null, + variables: null, + createdAt: '2026-01-01T00:00:00.000Z', + ...overrides, + }; +} + +describe('getConnectionDisplayLabel', () => { + it('prefers connectionName from metadata', () => { + expect( + getConnectionDisplayLabel( + conn({ + id: 'icn_abc', + metadata: { connectionName: 'Production', accountId: '123' }, + }), + ), + ).toBe('Production'); + }); + + it('uses AWS accountId when no connectionName', () => { + expect( + getConnectionDisplayLabel( + conn({ + id: 'icn_abc', + metadata: { accountId: '013388577167' }, + }), + ), + ).toBe('AWS 013388577167'); + }); + + it('parses account id from roleArn', () => { + expect( + getConnectionDisplayLabel( + conn({ + id: 'icn_abc', + metadata: { roleArn: 'arn:aws:iam::013388577167:role/x' }, + }), + ), + ).toBe('AWS 013388577167'); + }); +}); + +describe('getRegionCount', () => { + it('returns length of regions array in metadata', () => { + expect( + getRegionCount( + conn({ + id: 'icn_x', + metadata: { regions: ['us-east-1', 'eu-west-1'] }, + }), + ), + ).toBe(2); + }); + + it('returns 0 when missing', () => { + expect(getRegionCount(null)).toBe(0); + }); +}); diff --git a/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/connection-display.ts b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/connection-display.ts new file mode 100644 index 0000000000..4248c24fae --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/connection-display.ts @@ -0,0 +1,23 @@ +import type { ConnectionListItem } from '@/hooks/use-integration-platform'; + +/** Human-readable label for a connection (matches AccountSelector). */ +export function getConnectionDisplayLabel(connection: ConnectionListItem): string { + const meta = (connection.metadata ?? {}) as Record; + if (typeof meta.connectionName === 'string' && meta.connectionName) { + return meta.connectionName; + } + if (typeof meta.accountId === 'string' && meta.accountId) { + return `AWS ${meta.accountId}`; + } + const roleArn = meta.roleArn as string | undefined; + const arnMatch = roleArn?.match(/arn:aws:iam::(\d{12})/); + if (arnMatch) return `AWS ${arnMatch[1]}`; + return `Account ${connection.id.slice(4, 12)}`; +} + +export function getRegionCount(connection: ConnectionListItem | null): number { + if (!connection) return 0; + const meta = (connection.metadata ?? {}) as Record; + if (Array.isArray(meta.regions)) return meta.regions.length; + return 0; +} diff --git a/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/oauth-connection-variables-form.tsx b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/oauth-connection-variables-form.tsx new file mode 100644 index 0000000000..aab999ddc6 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/oauth-connection-variables-form.tsx @@ -0,0 +1,157 @@ +'use client'; + +import { Button, Label } from '@trycompai/design-system'; +import { Input } from '@trycompai/ui/input'; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from '@trycompai/ui/select'; +import { Loader2 } from 'lucide-react'; +import type { Dispatch, SetStateAction } from 'react'; + +export type OAuthVariableRow = { + id: string; + label: string; + type: string; + required: boolean; + helpText?: string; + placeholder?: string; + description?: string; + currentValue?: string | number | boolean | string[]; + hasDynamicOptions?: boolean; + options?: { value: string; label: string }[]; +}; + +type Props = { + variables: OAuthVariableRow[]; + variableValues: Record; + setVariableValues: Dispatch< + SetStateAction> + >; + dynamicOptions: Record; + loadingOptions: Record; + fetchOptions: (variableId: string) => void; + onSave: () => void; + savingVariables: boolean; +}; + +export function OAuthConnectionVariablesForm({ + variables, + variableValues, + setVariableValues, + dynamicOptions, + loadingOptions, + fetchOptions, + onSave, + savingVariables, +}: Props) { + if (variables.length === 0) { + return ( +

+ No extra settings for this connection. You can disconnect below if needed. +

+ ); + } + + return ( +
+

+ Configuration +

+ {variables.map((variable) => { + const options = dynamicOptions[variable.id] ?? variable.options ?? []; + return ( +
+ + {variable.description ? ( +

{variable.description}

+ ) : null} + {variable.helpText ? ( +

{variable.helpText}

+ ) : null} + + {variable.type === 'boolean' ? ( + + ) : variable.type === 'select' ? ( + + ) : ( + + setVariableValues((prev) => ({ + ...prev, + [variable.id]: + variable.type === 'number' ? Number(e.target.value) : e.target.value, + })) + } + placeholder={variable.placeholder} + /> + )} +
+ ); + })} + +
+ ); +} diff --git a/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/services-grid.tsx b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/services-grid.tsx new file mode 100644 index 0000000000..db3e86f87e --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/components/services-grid.tsx @@ -0,0 +1,91 @@ +'use client'; + +import { orderServicesForConnectionGrid } from '@/lib/connection-services-display-order'; +import { Search } from '@trycompai/design-system/icons'; +import { useCallback, useEffect, useMemo, useState } from 'react'; +import { ServiceCard } from './ServiceCard'; + +export function ServicesGrid({ + services, + connectionServices = [], + connectionId, + onToggle, + togglingService, +}: { + services: Array<{ id: string; name: string; description: string; implemented?: boolean }>; + connectionServices?: Array<{ id: string; enabled: boolean }>; + connectionId: string | null; + onToggle: (id: string, enabled: boolean) => boolean | void | Promise; + togglingService: string | null; +}) { + const [search, setSearch] = useState(''); + const [tailEnabledIds, setTailEnabledIds] = useState([]); + + useEffect(() => { + setTailEnabledIds([]); + }, [connectionId]); + + const handleToggle = useCallback( + async (id: string, enabled: boolean) => { + let rollback: string[] | null = null; + setTailEnabledIds((prev) => { + rollback = [...prev]; + if (enabled) return [...prev.filter((x) => x !== id), id]; + return prev.filter((x) => x !== id); + }); + const result = await Promise.resolve(onToggle(id, enabled)); + if (result === false && rollback) { + setTailEnabledIds(rollback); + } + }, + [onToggle], + ); + + const displayedServices = useMemo( + () => + orderServicesForConnectionGrid({ + manifestServices: services, + connectionServices, + search, + tailEnabledIds, + }), + [services, connectionServices, search, tailEnabledIds], + ); + + return ( +
+
+
+ + setSearch(e.target.value)} + className="w-44 rounded-md border bg-background py-1.5 pl-7 pr-3 text-xs placeholder:text-muted-foreground/50 focus:outline-none focus:ring-2 focus:ring-primary/30" + /> +
+
+
+ {displayedServices.map((service) => ( + + ))} + {displayedServices.length === 0 && search && ( +

+ No services matching "{search}" +

+ )} +
+
+ ); +} diff --git a/apps/app/src/app/(app)/[orgId]/integrations/[slug]/page.tsx b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/page.tsx new file mode 100644 index 0000000000..2059b80bb0 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/integrations/[slug]/page.tsx @@ -0,0 +1,48 @@ +import { serverApi } from '@/lib/api-server'; +import type { IntegrationProviderResponse } from '@trycompai/integration-platform'; +import type { ConnectionListItemResponse } from '@trycompai/integration-platform'; +import { PageLayout } from '@trycompai/design-system'; +import { redirect } from 'next/navigation'; +import { ProviderDetailView } from './components/ProviderDetailView'; + +interface PageProps { + params: Promise<{ orgId: string; slug: string }>; + searchParams: Promise>; +} + +export default async function ProviderDetailPage({ params, searchParams }: PageProps) { + const { orgId, slug } = await params; + const sp = await searchParams; + const success = typeof sp.success === 'string' ? sp.success : ''; + const providerParam = typeof sp.provider === 'string' ? sp.provider : ''; + const gcpOAuthJustConnected = + slug === 'gcp' && success === 'true' && providerParam === 'gcp'; + + const [providerResult, connectionsResult] = await Promise.all([ + serverApi.get( + `/v1/integrations/connections/providers/${slug}`, + ), + serverApi.get( + '/v1/integrations/connections', + ), + ]); + + if (!providerResult.data || providerResult.error) { + redirect(`/${orgId}/integrations`); + } + + const provider = providerResult.data; + const connections = (connectionsResult.data ?? []).filter( + (c) => c.providerSlug === slug, + ); + + return ( + + + + ); +} diff --git a/apps/app/src/app/(app)/[orgId]/integrations/components/PlatformIntegrations.test.tsx b/apps/app/src/app/(app)/[orgId]/integrations/components/PlatformIntegrations.test.tsx index b10e3ff0f9..c90a4b8326 100644 --- a/apps/app/src/app/(app)/[orgId]/integrations/components/PlatformIntegrations.test.tsx +++ b/apps/app/src/app/(app)/[orgId]/integrations/components/PlatformIntegrations.test.tsx @@ -17,36 +17,35 @@ vi.mock('@/hooks/use-permissions', () => ({ })); // Mock integration platform hooks -const mockStartOAuth = vi.fn(); +const { + mockStartOAuth, + mockUseIntegrationProviders, + mockUseIntegrationConnections, + mockUseVendors, +} = vi.hoisted(() => ({ + mockStartOAuth: vi.fn(), + mockUseIntegrationProviders: vi.fn(), + mockUseIntegrationConnections: vi.fn(), + mockUseVendors: vi.fn(), +})); + +const { mockRouterPush, mockUseSearchParams } = vi.hoisted(() => ({ + mockRouterPush: vi.fn(), + mockUseSearchParams: vi.fn(() => new URLSearchParams()), +})); + vi.mock('@/hooks/use-integration-platform', () => ({ - useIntegrationProviders: () => ({ - providers: [ - { - id: 'github', - name: 'GitHub', - description: 'Code hosting', - category: 'Development', - logoUrl: '/github.png', - authType: 'oauth2', - oauthConfigured: true, - isActive: true, - requiredVariables: [], - mappedTasks: [], - supportsMultipleConnections: false, - }, - ], - isLoading: false, - }), - useIntegrationConnections: () => ({ - connections: [], - isLoading: false, - refresh: vi.fn(), - }), + useIntegrationProviders: mockUseIntegrationProviders, + useIntegrationConnections: mockUseIntegrationConnections, useIntegrationMutations: () => ({ startOAuth: mockStartOAuth, }), })); +vi.mock('@/hooks/use-vendors', () => ({ + useVendors: mockUseVendors, +})); + // Mock integrations data vi.mock('../data/integrations', () => ({ CATEGORIES: ['Development'], @@ -92,8 +91,8 @@ vi.mock('next/link', () => ({ // Mock next/navigation vi.mock('next/navigation', () => ({ useParams: () => ({ orgId: 'org-1' }), - useRouter: () => ({ push: vi.fn() }), - useSearchParams: () => new URLSearchParams(), + useRouter: () => ({ push: mockRouterPush }), + useSearchParams: mockUseSearchParams, })); // Mock @trycompai/ui components @@ -159,6 +158,39 @@ const defaultProps = { describe('PlatformIntegrations', () => { beforeEach(() => { vi.clearAllMocks(); + mockUseSearchParams.mockReturnValue(new URLSearchParams() as any); + mockUseIntegrationProviders.mockReturnValue({ + providers: [ + { + id: 'github', + name: 'GitHub', + description: 'Code hosting', + category: 'Development', + logoUrl: '/github.png', + authType: 'oauth2', + oauthConfigured: true, + isActive: true, + requiredVariables: [], + mappedTasks: [], + supportsMultipleConnections: false, + }, + ], + isLoading: false, + }); + mockUseIntegrationConnections.mockReturnValue({ + connections: [], + isLoading: false, + refresh: vi.fn(), + }); + mockUseVendors.mockReturnValue({ + data: { + data: { + data: [], + count: 0, + }, + status: 200, + }, + }); }); describe('Permission gating', () => { @@ -203,6 +235,53 @@ describe('PlatformIntegrations', () => { expect(screen.getByText('GitHub')).toBeInTheDocument(); expect(screen.getByTestId('search-input')).toBeInTheDocument(); }); + + it('treats provider as connected when an active connection exists alongside older disconnected rows', () => { + setMockPermissions(ADMIN_PERMISSIONS); + mockUseIntegrationProviders.mockReturnValue({ + providers: [ + { + id: 'gcp', + name: 'Google Cloud Platform', + description: 'Cloud security', + category: 'Cloud', + logoUrl: '/gcp.png', + authType: 'oauth2', + oauthConfigured: true, + isActive: true, + requiredVariables: [], + mappedTasks: [], + supportsMultipleConnections: true, + }, + ], + isLoading: false, + }); + mockUseIntegrationConnections.mockReturnValue({ + connections: [ + // Newest row returned first by API + { + id: 'conn-new-active', + providerSlug: 'gcp', + status: 'active', + variables: {}, + createdAt: '2026-04-14T00:00:00.000Z', + }, + { + id: 'conn-old-disconnected', + providerSlug: 'gcp', + status: 'disconnected', + variables: {}, + createdAt: '2026-04-01T00:00:00.000Z', + }, + ] as any, + isLoading: false, + refresh: vi.fn(), + }); + + render(); + + expect(screen.queryByText('Connect')).not.toBeInTheDocument(); + }); }); describe('Employee sync import prompt', () => { @@ -248,10 +327,7 @@ describe('PlatformIntegrations', () => { }); // Mock useSearchParams to simulate OAuth callback - const { useSearchParams: mockUseSearchParams } = vi.mocked( - await import('next/navigation'), - ); - vi.mocked(mockUseSearchParams).mockReturnValue( + mockUseSearchParams.mockReturnValue( new URLSearchParams('success=true&provider=google-workspace') as any, ); @@ -312,10 +388,7 @@ describe('PlatformIntegrations', () => { refresh: vi.fn(), }); - const { useSearchParams: mockUseSearchParams } = vi.mocked( - await import('next/navigation'), - ); - vi.mocked(mockUseSearchParams).mockReturnValue( + mockUseSearchParams.mockReturnValue( new URLSearchParams('success=true&provider=github') as any, ); @@ -329,4 +402,77 @@ describe('PlatformIntegrations', () => { expect(toast.info).not.toHaveBeenCalled(); }); }); + + describe('Vendor-prioritized ordering', () => { + it('shows integrations from vendor list before non-vendor integrations', () => { + mockUseIntegrationProviders.mockReturnValue({ + providers: [ + { + id: 'github', + name: 'GitHub', + description: 'Code hosting', + category: 'Development', + logoUrl: '/github.png', + authType: 'oauth2', + oauthConfigured: true, + isActive: true, + requiredVariables: [], + mappedTasks: [], + supportsMultipleConnections: false, + }, + { + id: 'slack', + name: 'Slack', + description: 'Team communication', + category: 'Communication', + logoUrl: '/slack.png', + authType: 'api_key', + isActive: true, + requiredVariables: [], + mappedTasks: [], + supportsMultipleConnections: false, + }, + ], + isLoading: false, + }); + mockUseIntegrationConnections.mockReturnValue({ + connections: [ + { + id: 'conn-1', + providerSlug: 'github', + status: 'active', + variables: {}, + }, + ], + isLoading: false, + refresh: vi.fn(), + }); + mockUseVendors.mockReturnValue({ + data: { + data: { + data: [ + { + id: 'vnd-1', + name: 'Slack', + }, + ], + count: 1, + }, + status: 200, + }, + }); + + setMockPermissions(ADMIN_PERMISSIONS); + + render(); + + const integrationTitles = screen + .getAllByRole('heading', { level: 3 }) + .map((heading) => heading.textContent?.trim()) + .filter(Boolean); + + expect(integrationTitles[0]).toBe('Slack'); + expect(integrationTitles[1]).toBe('GitHub'); + }); + }); }); diff --git a/apps/app/src/app/(app)/[orgId]/integrations/components/PlatformIntegrations.tsx b/apps/app/src/app/(app)/[orgId]/integrations/components/PlatformIntegrations.tsx index 752f6273d4..d099fc24af 100644 --- a/apps/app/src/app/(app)/[orgId]/integrations/components/PlatformIntegrations.tsx +++ b/apps/app/src/app/(app)/[orgId]/integrations/components/PlatformIntegrations.tsx @@ -3,6 +3,7 @@ import { ConnectIntegrationDialog } from '@/components/integrations/ConnectIntegrationDialog'; import { ManageIntegrationDialog } from '@/components/integrations/ManageIntegrationDialog'; import { usePermissions } from '@/hooks/use-permissions'; +import { useVendors } from '@/hooks/use-vendors'; import { ConnectionListItem, IntegrationProvider, @@ -21,6 +22,7 @@ import { DialogTitle, } from '@trycompai/ui/dialog'; import { Skeleton } from '@trycompai/ui/skeleton'; +import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@trycompai/ui/tooltip'; import { AlertCircle, AlertTriangle, @@ -53,7 +55,6 @@ const EMPLOYEE_SYNC_PROVIDERS = new Set([ 'google-workspace', 'rippling', 'jumpcloud', - 'ramp', ]); // Check if a provider needs variable configuration based on manifest's required variables @@ -67,6 +68,16 @@ const providerNeedsConfiguration = ( return requiredVariables.some((varId) => !currentVars[varId]); }; +const normalizeIntegrationName = (value: string): string => { + return value + .toLowerCase() + .replace(/\s*\([^)]*\)\s*$/, '') + .replace(/[_-]+/g, ' ') + .replace(/[^a-z0-9 ]+/g, '') + .replace(/\s+/g, ' ') + .trim(); +}; + interface RelevantTask { taskId: string; // Actual task ID for navigation taskTemplateId: string; @@ -84,6 +95,38 @@ interface PlatformIntegrationsProps { taskTemplates: Array<{ id: string; taskId: string; name: string; description: string }>; } +const CONNECTION_STATUS_PRIORITY: Record = { + active: 5, + pending: 4, + error: 3, + paused: 2, + disconnected: 1, +}; + +const getConnectionPriority = (connection: ConnectionListItem): number => { + return CONNECTION_STATUS_PRIORITY[connection.status] ?? 0; +}; + +const getConnectionCreatedAtMs = (connection: ConnectionListItem): number => { + const date = new Date(connection.createdAt); + return Number.isNaN(date.getTime()) ? 0 : date.getTime(); +}; + +const shouldReplaceProviderConnection = ( + current: ConnectionListItem | undefined, + candidate: ConnectionListItem, +): boolean => { + if (!current) return true; + + const currentPriority = getConnectionPriority(current); + const candidatePriority = getConnectionPriority(candidate); + if (candidatePriority !== currentPriority) { + return candidatePriority > currentPriority; + } + + return getConnectionCreatedAtMs(candidate) > getConnectionCreatedAtMs(current); +}; + export function PlatformIntegrations({ className, taskTemplates }: PlatformIntegrationsProps) { const { orgId } = useParams<{ orgId: string }>(); const router = useRouter(); @@ -97,6 +140,7 @@ export function PlatformIntegrations({ className, taskTemplates }: PlatformInteg const { hasPermission } = usePermissions(); const canCreate = hasPermission('integration', 'create'); const { startOAuth } = useIntegrationMutations(); + const { data: vendorsResponse } = useVendors(); const [searchQuery, setSearchQuery] = useState(''); const [selectedCategory, setSelectedCategory] = useState('All'); @@ -126,7 +170,7 @@ export function PlatformIntegrations({ className, taskTemplates }: PlatformInteg if (provider.authType === 'oauth2') { setConnectingProvider(provider.id); try { - const redirectUrl = window.location.href; + const redirectUrl = `${window.location.origin}/${orgId}/integrations/${provider.id}?success=true`; const result = await startOAuth(provider.id, redirectUrl); if (result.authorizationUrl) { window.location.href = result.authorizationUrl; @@ -141,9 +185,8 @@ export function PlatformIntegrations({ className, taskTemplates }: PlatformInteg return; } - // For non-OAuth (api_key, basic, custom), open the connect dialog - setConnectingProviderInfo(provider); - setConnectDialogOpen(true); + // For non-OAuth (api_key, basic, custom), navigate to detail page + router.push(`/${orgId}/integrations/${provider.id}`); }; const handleConnectDialogSuccess = () => { @@ -180,13 +223,49 @@ export function PlatformIntegrations({ className, taskTemplates }: PlatformInteg }; // Map connections by provider slug - const connectionsByProvider = useMemo( - () => new Map(connections?.map((c) => [c.providerSlug, c]) || []), - [connections], - ); + const connectionsByProvider = useMemo(() => { + const map = new Map(); + for (const connection of connections ?? []) { + const current = map.get(connection.providerSlug); + if (shouldReplaceProviderConnection(current, connection)) { + map.set(connection.providerSlug, connection); + } + } + return map; + }, [connections]); + + const vendorNames = useMemo(() => { + const vendors = vendorsResponse?.data?.data; + if (!Array.isArray(vendors)) { + return new Set(); + } - // Merge and sort: platform first (warnings, then connected, then disconnected), then custom + return new Set( + vendors + .map((vendor) => normalizeIntegrationName(vendor.name)) + .filter((name) => name.length > 0), + ); + }, [vendorsResponse]); + + // Merge/sort integrations, then prioritize entries matching vendors in the org's vendor list. const unifiedIntegrations = useMemo(() => { + const platformSortTier = ( + item: UnifiedIntegration & { type: 'platform' }, + ): 0 | 1 | 2 => { + const { provider, connection } = item; + const isComingSoon = + provider.authType === 'oauth2' && provider.oauthConfigured === false; + if (isComingSoon) return 2; + + const hasEstablishedConnection = + connection && + connection.status !== 'disconnected' && + ['active', 'pending', 'error', 'paused'].includes(connection.status); + if (hasEstablishedConnection) return 0; + + return 1; + }; + const platformIntegrations: UnifiedIntegration[] = (providers?.filter((p) => p.isActive) || []) .map((provider) => ({ type: 'platform' as const, @@ -194,30 +273,58 @@ export function PlatformIntegrations({ className, taskTemplates }: PlatformInteg connection: connectionsByProvider.get(provider.id), })) .sort((a, b) => { - const aConnected = a.connection?.status === 'active'; - const bConnected = b.connection?.status === 'active'; - const aNeedsConfig = aConnected && a.connection?.variables === null; - const bNeedsConfig = bConnected && b.connection?.variables === null; + const tierA = platformSortTier(a); + const tierB = platformSortTier(b); + if (tierA !== tierB) return tierA - tierB; + return a.provider.name.localeCompare(b.provider.name); + }); - // Warnings first - if (aNeedsConfig && !bNeedsConfig) return -1; - if (!aNeedsConfig && bNeedsConfig) return 1; + // AI Agent integrations are hidden — they are not real integrations + const allIntegrations = [...platformIntegrations]; + if (vendorNames.size === 0) { + return allIntegrations; + } - // Then connected - if (aConnected && !bConnected) return -1; - if (!aConnected && bConnected) return 1; + const vendorListedIntegrations: UnifiedIntegration[] = []; + const otherIntegrations: UnifiedIntegration[] = []; - // Then alphabetical - return a.provider.name.localeCompare(b.provider.name); - }); + allIntegrations.forEach((integration) => { + const candidateNames = + integration.type === 'platform' + ? [integration.provider.name, integration.provider.id] + : [integration.integration.name, integration.integration.id]; - const customIntegrations: UnifiedIntegration[] = INTEGRATIONS.map((integration) => ({ - type: 'custom' as const, - integration, - })); + const isVendorListed = candidateNames + .map((candidateName) => normalizeIntegrationName(candidateName)) + .some((normalizedCandidateName) => vendorNames.has(normalizedCandidateName)); - return [...platformIntegrations, ...customIntegrations]; - }, [providers, connectionsByProvider]); + if (isVendorListed) { + vendorListedIntegrations.push(integration); + } else { + otherIntegrations.push(integration); + } + }); + + // Connected integrations always appear first, then vendor-listed, then others + const ESTABLISHED_STATUSES = new Set(['active', 'pending', 'error', 'paused']); + const sortByConnection = (list: UnifiedIntegration[]) => + list.sort((a, b) => { + const aConnected = + a.type === 'platform' && + a.connection?.status && + ESTABLISHED_STATUSES.has(a.connection.status) + ? 0 + : 1; + const bConnected = + b.type === 'platform' && + b.connection?.status && + ESTABLISHED_STATUSES.has(b.connection.status) + ? 0 + : 1; + return aConnected - bConnected; + }); + return sortByConnection([...vendorListedIntegrations, ...otherIntegrations]); + }, [providers, connectionsByProvider, vendorNames]); // Get all unique categories const allCategories = useMemo(() => { @@ -457,16 +564,51 @@ export function PlatformIntegrations({ className, taskTemplates }: PlatformInteg connection?.variables as Record | null, ); + const isComingSoon = provider.authType === 'oauth2' && provider.oauthConfigured === false; + + /** Primary CTA is Connect / Set up — card still opens details on click; hide redundant “View details” row */ + const showConnectOrSetup = + canCreate && + !needsConfiguration && + !isConnected && + !hasError && + !isComingSoon; + return ( - router.push(`/${orgId}/integrations/${provider.id}`)} + onKeyDown={ + isComingSoon + ? undefined + : (e) => { + if (e.key === 'Enter' || e.key === ' ') { + e.preventDefault(); + router.push(`/${orgId}/integrations/${provider.id}`); + } + } + } + > +
@@ -504,7 +646,7 @@ export function PlatformIntegrations({ className, taskTemplates }: PlatformInteg variant="ghost" size="sm" className="h-8 w-8 p-0" - onClick={() => handleOpenManageDialog(connection, provider)} + onClick={(e) => { e.preventDefault(); e.stopPropagation(); handleOpenManageDialog(connection, provider); }} > @@ -552,24 +694,59 @@ export function PlatformIntegrations({ className, taskTemplates }: PlatformInteg ); })} {provider.mappedTasks.length > 3 && ( - - +{provider.mappedTasks.length - 3} more - + + + + + +{provider.mappedTasks.length - 3} more + + + +
+ {provider.mappedTasks.slice(3).map((t) => { + const tid = templateToTaskMap.get(t.id); + return tid ? ( + e.stopPropagation()} + > + {t.name} + + ) : ( + {t.name} + ); + })} +
+
+
+
)}
)}
+ {!isComingSoon && !showConnectOrSetup && ( +
+ View details + +
+ )} +
{needsConfiguration ? ( @@ -583,7 +760,7 @@ export function PlatformIntegrations({ className, taskTemplates }: PlatformInteg size="sm" variant="outline" className="w-full" - onClick={() => handleConnect(provider)} + onClick={(e) => { e.preventDefault(); e.stopPropagation(); handleConnect(provider); }} disabled={isConnecting} > {isConnecting ? ( @@ -605,7 +782,7 @@ export function PlatformIntegrations({ className, taskTemplates }: PlatformInteg ) : null}
+ {!isComingSoon && ( +
+ )} +
); } @@ -791,22 +975,7 @@ export function PlatformIntegrations({ className, taskTemplates }: PlatformInteg /> ))} - {/* Connect Dialog (for non-OAuth integrations) */} - {connectingProviderInfo && ( - { - if (!open) { - setConnectDialogOpen(false); - setConnectingProviderInfo(null); - } - }} - integrationId={connectingProviderInfo.id} - integrationName={connectingProviderInfo.name} - integrationLogoUrl={connectingProviderInfo.logoUrl} - onConnected={handleConnectDialogSuccess} - /> - )} + {/* Connect dialog removed — non-OAuth providers navigate to detail page */} {/* Custom Integration Detail Modal */} {selectedCustomIntegration && ( diff --git a/apps/app/src/app/(app)/[orgId]/layout.tsx b/apps/app/src/app/(app)/[orgId]/layout.tsx index 1a3aa8daed..869c58e6e4 100644 --- a/apps/app/src/app/(app)/[orgId]/layout.tsx +++ b/apps/app/src/app/(app)/[orgId]/layout.tsx @@ -7,7 +7,7 @@ import { resolveUserPermissions } from '@/lib/permissions.server'; import type { OrganizationFromMe } from '@/types'; import { auth } from '@/utils/auth'; import { GetObjectCommand } from '@aws-sdk/client-s3'; -import { getSignedUrl } from '@aws-sdk/s3-request-presigner'; +import { getSignedUrl } from '@/lib/s3-presigner'; import { db, Role } from '@db/server'; import dynamic from 'next/dynamic'; import { cookies, headers } from 'next/headers'; diff --git a/apps/app/src/app/(app)/[orgId]/people/page.tsx b/apps/app/src/app/(app)/[orgId]/people/page.tsx index 01a725cf1d..aced60ac1d 100644 --- a/apps/app/src/app/(app)/[orgId]/people/page.tsx +++ b/apps/app/src/app/(app)/[orgId]/people/page.tsx @@ -2,7 +2,7 @@ import { filterComplianceMembers } from '@/lib/compliance'; import { auth } from '@/utils/auth'; import { s3Client, BUCKET_NAME } from '@/app/s3'; import { GetObjectCommand } from '@aws-sdk/client-s3'; -import { getSignedUrl } from '@aws-sdk/s3-request-presigner'; +import { getSignedUrl } from '@/lib/s3-presigner'; import { db } from '@db/server'; import type { Metadata } from 'next'; import { headers } from 'next/headers'; diff --git a/apps/app/src/app/(app)/[orgId]/policies/[policyId]/actions/get-policy-pdf-url.ts b/apps/app/src/app/(app)/[orgId]/policies/[policyId]/actions/get-policy-pdf-url.ts index 18842b8037..051d281b5f 100644 --- a/apps/app/src/app/(app)/[orgId]/policies/[policyId]/actions/get-policy-pdf-url.ts +++ b/apps/app/src/app/(app)/[orgId]/policies/[policyId]/actions/get-policy-pdf-url.ts @@ -3,7 +3,7 @@ import { authActionClient } from '@/actions/safe-action'; import { BUCKET_NAME, s3Client } from '@/app/s3'; import { GetObjectCommand } from '@aws-sdk/client-s3'; -import { getSignedUrl } from '@aws-sdk/s3-request-presigner'; +import { getSignedUrl } from '@/lib/s3-presigner'; import { db } from '@db/server'; import { z } from 'zod'; diff --git a/apps/app/src/app/(app)/[orgId]/policies/all/components/PolicyFilters.tsx b/apps/app/src/app/(app)/[orgId]/policies/all/components/PolicyFilters.tsx index bb8a43ee80..bd1a18ceb4 100644 --- a/apps/app/src/app/(app)/[orgId]/policies/all/components/PolicyFilters.tsx +++ b/apps/app/src/app/(app)/[orgId]/policies/all/components/PolicyFilters.tsx @@ -16,6 +16,7 @@ import { import { Search } from '@trycompai/design-system/icons'; import { useMemo, useState } from 'react'; import { PoliciesTableDS } from './PoliciesTableDS'; +import { comparePoliciesByName } from './policy-name-sort'; interface PolicyFiltersProps { policies: Policy[]; @@ -33,8 +34,8 @@ export function PolicyFilters({ policies }: PolicyFiltersProps) { const [searchQuery, setSearchQuery] = useState(''); const [statusFilter, setStatusFilter] = useState('all'); const [departmentFilter, setDepartmentFilter] = useState('all'); - const [sortColumn, setSortColumn] = useState<'name' | 'status' | 'updatedAt'>('updatedAt'); - const [sortDirection, setSortDirection] = useState<'asc' | 'desc'>('desc'); + const [sortColumn, setSortColumn] = useState<'name' | 'status' | 'updatedAt'>('name'); + const [sortDirection, setSortDirection] = useState<'asc' | 'desc'>('asc'); // Get unique departments from policies const departments = useMemo(() => { @@ -74,7 +75,7 @@ export function PolicyFilters({ policies }: PolicyFiltersProps) { result.sort((a, b) => { let comparison = 0; if (sortColumn === 'name') { - comparison = a.name.localeCompare(b.name); + comparison = comparePoliciesByName(a, b); } else if (sortColumn === 'status') { comparison = a.status.localeCompare(b.status); } else if (sortColumn === 'updatedAt') { diff --git a/apps/app/src/app/(app)/[orgId]/policies/all/components/policy-name-sort.test.ts b/apps/app/src/app/(app)/[orgId]/policies/all/components/policy-name-sort.test.ts new file mode 100644 index 0000000000..6546b3fa17 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/policies/all/components/policy-name-sort.test.ts @@ -0,0 +1,32 @@ +import { describe, expect, it } from 'vitest'; +import { comparePoliciesByName } from './policy-name-sort'; + +describe('comparePoliciesByName', () => { + it('sorts policy names alphabetically without case sensitivity', () => { + const policies = [ + { id: '2', name: 'zebra policy' }, + { id: '3', name: 'Alpha policy' }, + { id: '1', name: 'beta policy' }, + ]; + + const sorted = [...policies].sort(comparePoliciesByName); + + expect(sorted.map((policy) => policy.name)).toEqual([ + 'Alpha policy', + 'beta policy', + 'zebra policy', + ]); + }); + + it('falls back to deterministic ordering when names only differ by case', () => { + const policies = [ + { id: 'b', name: 'Policy' }, + { id: 'a', name: 'policy' }, + { id: 'c', name: 'policy' }, + ]; + + const sorted = [...policies].sort(comparePoliciesByName); + + expect(sorted.map((policy) => policy.id)).toEqual(['a', 'b', 'c']); + }); +}); diff --git a/apps/app/src/app/(app)/[orgId]/policies/all/components/policy-name-sort.ts b/apps/app/src/app/(app)/[orgId]/policies/all/components/policy-name-sort.ts new file mode 100644 index 0000000000..e999de8a8a --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/policies/all/components/policy-name-sort.ts @@ -0,0 +1,18 @@ +type NamedPolicy = { + id: string; + name: string; +}; + +const POLICY_NAME_COLLATOR = new Intl.Collator(undefined, { sensitivity: 'base' }); + +export function comparePoliciesByName( + a: NamedPolicy, + b: NamedPolicy, +): number { + const byName = POLICY_NAME_COLLATOR.compare(a.name, b.name); + if (byName !== 0) { + return byName; + } + + return a.id.localeCompare(b.id); +} diff --git a/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/components/SingleTask.tsx b/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/components/SingleTask.tsx index f44c5cf06a..b5b9b5893b 100644 --- a/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/components/SingleTask.tsx +++ b/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/components/SingleTask.tsx @@ -172,6 +172,7 @@ export function SingleTask({ await updateTask({ status: updates.status, assigneeId: updates.assigneeId, + approverId: updates.approverId, frequency: updates.frequency, department: updates.department, reviewDate: updates.reviewDate ? String(updates.reviewDate) : undefined, diff --git a/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/components/TaskIntegrationChecks.tsx b/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/components/TaskIntegrationChecks.tsx index 064795e994..9e66b1c6b3 100644 --- a/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/components/TaskIntegrationChecks.tsx +++ b/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/components/TaskIntegrationChecks.tsx @@ -162,12 +162,14 @@ export function TaskIntegrationChecks({ const handleConfirmDisconnect = useCallback(async () => { if (!disconnectTarget) return; - const { connectionId, checkId, checkName } = disconnectTarget; + const { connectionId, checkId, checkName, integrationName } = + disconnectTarget; + const monitorName = integrationName || checkName; setTogglingCheck(checkId); setDisconnectError(null); try { await disconnectCheckFromTask(connectionId, checkId); - toast.success(`Disconnected "${checkName}" from this task.`); + toast.success(`Disconnected "${monitorName}" from this task.`); setDisconnectTarget(null); } catch (err) { console.error('Failed to disconnect check:', err); @@ -198,6 +200,12 @@ export function TaskIntegrationChecks({ [reconnectCheckToTask], ); + const getMonitorDisplayName = useCallback( + (check: Pick) => + check.integrationName || check.checkName, + [], + ); + if (loading) { return (
@@ -380,6 +388,7 @@ export function TaskIntegrationChecks({ const isRunning = runningCheck === check.checkId; const isExpanded = expandedCheck === check.checkId; const needsConfig = check.needsConfiguration; + const monitorName = getMonitorDisplayName(check); // Determine status from latest run const hasFailed = latestRun @@ -459,8 +468,13 @@ export function TaskIntegrationChecks({

- {check.checkName} + {monitorName}

+ {check.checkName !== monitorName && ( + + {check.checkName} + + )}
{needsConfig ? (

@@ -640,6 +654,7 @@ export function TaskIntegrationChecks({

{disabledForTaskChecks.map((check) => { const isToggling = togglingCheck === check.checkId; + const monitorName = getMonitorDisplayName(check); return (

- {check.checkName} + {monitorName}

Will not run until reconnected @@ -671,7 +686,7 @@ export function TaskIntegrationChecks({ handleReconnect( check.connectionId!, check.checkId, - check.checkName, + monitorName, ) } > @@ -696,33 +711,36 @@ export function TaskIntegrationChecks({ More integrations available

- {disconnectedChecks.map((check) => ( - -
- {check.integrationName} -
-

- {check.checkName} -

+ {disconnectedChecks.map((check) => { + const monitorName = getMonitorDisplayName(check); + return ( + +
+ {check.integrationName} +
+

+ {monitorName} +

+
-
- - - ))} + + + ); + })}
)} @@ -748,11 +766,21 @@ export function TaskIntegrationChecks({ {disconnectTarget ? ( <> - {disconnectTarget.checkName} from{' '} - {disconnectTarget.integrationName} will no - longer run for this task. The integration itself stays - connected and will continue running for other tasks. You can - reconnect it to this task at any time. + + {disconnectTarget.integrationName || + disconnectTarget.checkName} + + {disconnectTarget.checkName !== + (disconnectTarget.integrationName || + disconnectTarget.checkName) && ( + <> + {' '} + ({disconnectTarget.checkName} check) + + )}{' '} + will no longer run for this task. The integration itself + stays connected and will continue running for other tasks. You + can reconnect it to this task at any time. ) : null} diff --git a/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/components/findings/CreateFindingSheet.test.tsx b/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/components/findings/CreateFindingSheet.test.tsx index b613d89381..85aaa4b60a 100644 --- a/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/components/findings/CreateFindingSheet.test.tsx +++ b/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/components/findings/CreateFindingSheet.test.tsx @@ -32,6 +32,15 @@ vi.mock('@/hooks/use-findings-api', () => ({ }, ], FINDING_CATEGORY_LABELS: { general: 'General' }, + FINDING_TYPE_FRAMEWORK_OPTIONS: [ + { value: 'soc2', label: 'SOC 2' }, + { value: 'iso27001', label: 'ISO 27001' }, + { value: 'pci_dss', label: 'PCI DSS' }, + { value: 'hipaa', label: 'HIPAA' }, + { value: 'gdpr', label: 'GDPR' }, + { value: 'iso9001', label: 'ISO 9001' }, + { value: 'iso42001', label: 'ISO 42001' }, + ], FINDING_TYPE_LABELS: { soc2: 'SOC 2', iso27001: 'ISO 27001' }, useFindingActions: () => ({ createFinding: vi.fn(), @@ -99,7 +108,14 @@ vi.mock('@trycompai/design-system', () => ({ Select: ({ children }: any) =>
{children}
, SelectContent: ({ children }: any) =>
{children}
, SelectGroup: ({ children }: any) =>
{children}
, - SelectItem: ({ children }: any) =>
{children}
, + SelectItem: ({ children, value, disabled }: any) => ( +
+ {children} +
+ ), SelectLabel: ({ children }: any) =>
{children}
, SelectTrigger: ({ children }: any) => , Sheet: ({ children, open }: any) => (open ?
{children}
: null), @@ -189,4 +205,32 @@ describe('CreateFindingSheet permission gating', () => { expect(screen.queryByTestId('sheet')).not.toBeInTheDocument(); }); + + it('shows all required frameworks in finding type dropdown options', () => { + setMockPermissions(ADMIN_PERMISSIONS); + + render(); + + expect(screen.getByText('SOC 2')).toBeInTheDocument(); + expect(screen.getByText('ISO 27001')).toBeInTheDocument(); + expect(screen.getByText('PCI DSS')).toBeInTheDocument(); + expect(screen.getByText('HIPAA')).toBeInTheDocument(); + expect(screen.getByText('GDPR')).toBeInTheDocument(); + expect(screen.getByText('ISO 9001')).toBeInTheDocument(); + expect(screen.getByText('ISO 42001')).toBeInTheDocument(); + }); + + it('only enables framework options currently supported by FindingType enum', () => { + setMockPermissions(ADMIN_PERMISSIONS); + + render(); + + expect(screen.getByTestId('select-item-soc2')).toHaveAttribute('data-disabled', 'false'); + expect(screen.getByTestId('select-item-iso27001')).toHaveAttribute('data-disabled', 'false'); + expect(screen.getByTestId('select-item-pci_dss')).toHaveAttribute('data-disabled', 'true'); + expect(screen.getByTestId('select-item-hipaa')).toHaveAttribute('data-disabled', 'true'); + expect(screen.getByTestId('select-item-gdpr')).toHaveAttribute('data-disabled', 'true'); + expect(screen.getByTestId('select-item-iso9001')).toHaveAttribute('data-disabled', 'true'); + expect(screen.getByTestId('select-item-iso42001')).toHaveAttribute('data-disabled', 'true'); + }); }); diff --git a/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/components/findings/CreateFindingSheet.tsx b/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/components/findings/CreateFindingSheet.tsx index ed582387d9..7034178e69 100644 --- a/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/components/findings/CreateFindingSheet.tsx +++ b/apps/app/src/app/(app)/[orgId]/tasks/[taskId]/components/findings/CreateFindingSheet.tsx @@ -3,6 +3,7 @@ import { DEFAULT_FINDING_TEMPLATES, FINDING_CATEGORY_LABELS, + FINDING_TYPE_FRAMEWORK_OPTIONS, FINDING_TYPE_LABELS, useFindingActions, useFindingTemplates, @@ -68,6 +69,7 @@ export function CreateFindingSheet({ const [isSubmitting, setIsSubmitting] = useState(false); const { hasPermission } = usePermissions(); const canCreateFinding = hasPermission('finding', 'create'); + const supportedFindingTypes = new Set(Object.values(FindingType)); const { data: templatesData } = useFindingTemplates(); const { createFinding } = useFindingActions(); @@ -165,8 +167,12 @@ export function CreateFindingSheet({ -
- -
- ); - } - - if (field.type === 'textarea') { - return ( -