diff --git a/apps/api/src/admin-organizations/admin-policies.controller.ts b/apps/api/src/admin-organizations/admin-policies.controller.ts index 81bb2389a3..98caca9679 100644 --- a/apps/api/src/admin-organizations/admin-policies.controller.ts +++ b/apps/api/src/admin-organizations/admin-policies.controller.ts @@ -130,20 +130,41 @@ export class AdminPoliciesController { ) { const instances = await db.frameworkInstance.findMany({ where: { organizationId: orgId }, - include: { framework: true }, + include: { framework: true, customFramework: true }, }); + const normalized = instances.map((fi) => { + if (fi.framework) { + return { + id: fi.framework.id, + name: fi.framework.name, + version: fi.framework.version, + description: fi.framework.description, + visible: fi.framework.visible, + createdAt: fi.framework.createdAt, + updatedAt: fi.framework.updatedAt, + }; + } + if (fi.customFramework) { + return { + id: fi.customFramework.id, + name: fi.customFramework.name, + version: fi.customFramework.version, + description: fi.customFramework.description, + visible: true, + createdAt: fi.customFramework.createdAt, + updatedAt: fi.customFramework.updatedAt, + }; + } + return null; + }); const uniqueFrameworks = Array.from( - new Map(instances.map((fi) => [fi.framework.id, fi.framework])).values(), - ).map((f) => ({ - id: f.id, - name: f.name, - version: f.version, - description: f.description, - visible: f.visible, - createdAt: f.createdAt, - updatedAt: f.updatedAt, - })); + new Map( + normalized + .filter((f): f is NonNullable => f !== null) + .map((f) => [f.id, f]), + ).values(), + ); const contextEntries = await db.context.findMany({ where: { organizationId: orgId }, diff --git a/apps/api/src/controls/controls.controller.ts b/apps/api/src/controls/controls.controller.ts index f19bb4766b..1b98543931 100644 --- a/apps/api/src/controls/controls.controller.ts +++ b/apps/api/src/controls/controls.controller.ts @@ -4,6 +4,7 @@ import { Delete, Get, Param, + ParseEnumPipe, Post, Query, UseGuards, @@ -20,6 +21,11 @@ import { RequirePermission } from '../auth/require-permission.decorator'; import { OrganizationId } from '../auth/auth-context.decorator'; import { ControlsService } from './controls.service'; import { CreateControlDto } from './dto/create-control.dto'; +import { LinkPoliciesDto } from './dto/link-policies.dto'; +import { LinkTasksDto } from './dto/link-tasks.dto'; +import { LinkRequirementsToControlDto } from './dto/link-requirements.dto'; +import { LinkDocumentTypesDto } from './dto/link-document-types.dto'; +import { EvidenceFormType } from '@db'; @ApiTags('Controls') @ApiBearerAuth() @@ -92,6 +98,78 @@ export class ControlsController { return this.controlsService.create(organizationId, dto); } + @Post(':id/policies/link') + @RequirePermission('control', 'update') + @ApiOperation({ summary: 'Link existing policies to a control' }) + async linkPolicies( + @OrganizationId() organizationId: string, + @Param('id') id: string, + @Body() dto: LinkPoliciesDto, + ) { + return this.controlsService.linkPolicies( + id, + organizationId, + dto.policyIds, + ); + } + + @Post(':id/tasks/link') + @RequirePermission('control', 'update') + @ApiOperation({ summary: 'Link existing tasks to a control' }) + async linkTasks( + @OrganizationId() organizationId: string, + @Param('id') id: string, + @Body() dto: LinkTasksDto, + ) { + return this.controlsService.linkTasks(id, organizationId, dto.taskIds); + } + + @Post(':id/requirements/link') + @RequirePermission('control', 'update') + @ApiOperation({ summary: 'Link existing requirements to a control' }) + async linkRequirements( + @OrganizationId() organizationId: string, + @Param('id') id: string, + @Body() dto: LinkRequirementsToControlDto, + ) { + return this.controlsService.linkRequirements( + id, + organizationId, + dto.requirements, + ); + } + + @Post(':id/document-types/link') + @RequirePermission('control', 'update') + @ApiOperation({ summary: 'Link required document types to a control' }) + async linkDocumentTypes( + @OrganizationId() organizationId: string, + @Param('id') id: string, + @Body() dto: LinkDocumentTypesDto, + ) { + return this.controlsService.linkDocumentTypes( + id, + organizationId, + dto.formTypes, + ); + } + + @Delete(':id/document-types/:formType') + @RequirePermission('control', 'update') + @ApiOperation({ summary: 'Remove a required document type from a control' }) + async unlinkDocumentType( + @OrganizationId() organizationId: string, + @Param('id') id: string, + @Param('formType', new ParseEnumPipe(EvidenceFormType)) + formType: EvidenceFormType, + ) { + return this.controlsService.unlinkDocumentType( + id, + organizationId, + formType, + ); + } + @Delete(':id') @RequirePermission('control', 'delete') @ApiOperation({ summary: 'Delete a control' }) diff --git a/apps/api/src/controls/controls.service.ts b/apps/api/src/controls/controls.service.ts index 03698193d2..6f436989d7 100644 --- a/apps/api/src/controls/controls.service.ts +++ b/apps/api/src/controls/controls.service.ts @@ -1,5 +1,9 @@ -import { Injectable, NotFoundException } from '@nestjs/common'; -import { db, Prisma } from '@db'; +import { + BadRequestException, + Injectable, + NotFoundException, +} from '@nestjs/common'; +import { db, EvidenceFormType, Prisma } from '@db'; import { CreateControlDto } from './dto/create-control.dto'; const controlInclude = { @@ -12,11 +16,14 @@ const controlInclude = { requirementsMapped: { include: { frameworkInstance: { - include: { framework: true }, + include: { framework: true, customFramework: true }, }, requirement: { select: { name: true, identifier: true }, }, + customRequirement: { + select: { name: true, identifier: true }, + }, }, }, } satisfies Prisma.ControlInclude; @@ -67,12 +74,14 @@ export class ControlsService { include: { policies: true, tasks: true, + controlDocumentTypes: true, requirementsMapped: { include: { frameworkInstance: { - include: { framework: true }, + include: { framework: true, customFramework: true }, }, requirement: true, + customRequirement: true, }, }, }, @@ -82,6 +91,24 @@ export class ControlsService { throw new NotFoundException('Control not found'); } + const formTypes = (control.controlDocumentTypes ?? []).map( + (d) => d.formType, + ); + const submissionCountsByFormType: Record = {}; + if (formTypes.length > 0) { + const grouped = await db.evidenceSubmission.groupBy({ + by: ['formType'], + where: { + organizationId, + formType: { in: formTypes }, + }, + _count: { _all: true }, + }); + for (const g of grouped) { + submissionCountsByFormType[g.formType] = g._count._all; + } + } + // Compute progress const policies = control.policies || []; const tasks = control.tasks || []; @@ -101,6 +128,7 @@ export class ControlsService { return { ...control, + submissionCountsByFormType, progress: { total: totalItems, completed, @@ -136,63 +164,420 @@ export class ControlsService { }, }, }, + customFramework: { + include: { + requirements: { + select: { id: true, name: true, identifier: true }, + }, + }, + }, }, }), ]); - const requirements = frameworkInstances.flatMap((fi) => - fi.framework.requirements.map((req) => ({ - id: req.id, - name: req.name, - identifier: req.identifier, - frameworkInstanceId: fi.id, - frameworkName: fi.framework.name, - })), - ); + type RequirementOption = { + id: string; + name: string; + identifier: string; + frameworkInstanceId: string; + frameworkName: string; + isCustom: boolean; + requirementId?: string; + customRequirementId?: string; + }; + const requirements: RequirementOption[] = []; + for (const fi of frameworkInstances) { + if (fi.customFramework) { + for (const req of fi.customFramework.requirements) { + requirements.push({ + id: req.id, + name: req.name, + identifier: req.identifier, + customRequirementId: req.id, + frameworkInstanceId: fi.id, + frameworkName: fi.customFramework.name, + isCustom: true, + }); + } + } else if (fi.framework) { + for (const req of fi.framework.requirements) { + requirements.push({ + id: req.id, + name: req.name, + identifier: req.identifier, + requirementId: req.id, + frameworkInstanceId: fi.id, + frameworkName: fi.framework.name, + isCustom: false, + }); + } + } + } return { policies, tasks, requirements }; } async create(organizationId: string, dto: CreateControlDto) { - const { name, description, policyIds, taskIds, requirementMappings } = dto; + const { + name, + description, + policyIds, + taskIds, + requirementMappings, + documentTypes, + } = dto; - const control = await db.control.create({ - data: { - name, - description, - organizationId, - ...(policyIds && - policyIds.length > 0 && { + for (const m of requirementMappings ?? []) { + const hasPlatform = Boolean(m.requirementId); + const hasCustom = Boolean(m.customRequirementId); + if (hasPlatform === hasCustom) { + throw new BadRequestException( + 'Each requirement mapping must set exactly one of requirementId or customRequirementId', + ); + } + } + + // Scope every FK supplied by the client to the caller's org before trusting + // it. Prisma FKs only check row existence, not tenancy. + const scopedPolicyIds = await this.validatePolicyIds( + policyIds, + organizationId, + ); + const scopedTaskIds = await this.validateTaskIds(taskIds, organizationId); + const scopedRequirementMappings = await this.validateRequirementMappings( + requirementMappings, + organizationId, + ); + + return db.$transaction(async (tx) => { + const control = await tx.control.create({ + data: { + name, + description, + organizationId, + ...(scopedPolicyIds.length > 0 && { policies: { - connect: policyIds.map((id) => ({ id })), + connect: scopedPolicyIds.map((id) => ({ id })), }, }), - ...(taskIds && - taskIds.length > 0 && { + ...(scopedTaskIds.length > 0 && { tasks: { - connect: taskIds.map((id) => ({ id })), + connect: scopedTaskIds.map((id) => ({ id })), }, }), - }, + }, + }); + + if (scopedRequirementMappings.length > 0) { + await tx.requirementMap.createMany({ + data: scopedRequirementMappings.map((mapping) => ({ + controlId: control.id, + frameworkInstanceId: mapping.frameworkInstanceId, + requirementId: mapping.requirementId ?? null, + customRequirementId: mapping.customRequirementId ?? null, + })), + skipDuplicates: true, + }); + } + + if (documentTypes && documentTypes.length > 0) { + await tx.controlDocumentType.createMany({ + data: documentTypes.map((formType) => ({ + controlId: control.id, + formType, + })), + skipDuplicates: true, + }); + } + + return control; }); + } - if (requirementMappings && requirementMappings.length > 0) { - await Promise.all( - requirementMappings.map((mapping) => - db.requirementMap.create({ - data: { - controlId: control.id, - requirementId: mapping.requirementId, - frameworkInstanceId: mapping.frameworkInstanceId, - }, - }), - ), + private async validatePolicyIds( + policyIds: string[] | undefined, + organizationId: string, + ): Promise { + if (!policyIds || policyIds.length === 0) return []; + const uniqueIds = Array.from(new Set(policyIds)); + const policies = await db.policy.findMany({ + where: { id: { in: uniqueIds }, organizationId }, + select: { id: true }, + }); + if (policies.length !== uniqueIds.length) { + throw new BadRequestException('One or more policies are invalid'); + } + return policies.map((p) => p.id); + } + + private async validateTaskIds( + taskIds: string[] | undefined, + organizationId: string, + ): Promise { + if (!taskIds || taskIds.length === 0) return []; + const uniqueIds = Array.from(new Set(taskIds)); + const tasks = await db.task.findMany({ + where: { id: { in: uniqueIds }, organizationId }, + select: { id: true }, + }); + if (tasks.length !== uniqueIds.length) { + throw new BadRequestException('One or more tasks are invalid'); + } + return tasks.map((t) => t.id); + } + + private async validateRequirementMappings( + mappings: + | { + requirementId?: string; + customRequirementId?: string; + frameworkInstanceId: string; + }[] + | undefined, + organizationId: string, + ) { + if (!mappings || mappings.length === 0) return []; + + const frameworkInstanceIds = Array.from( + new Set(mappings.map((m) => m.frameworkInstanceId)), + ); + const instances = await db.frameworkInstance.findMany({ + where: { id: { in: frameworkInstanceIds }, organizationId }, + select: { id: true, frameworkId: true, customFrameworkId: true }, + }); + const instanceById = new Map(instances.map((i) => [i.id, i])); + if (instances.length !== frameworkInstanceIds.length) { + throw new BadRequestException( + 'One or more framework instances are invalid', ); } + const platformReqIds = mappings + .map((m) => m.requirementId) + .filter((id): id is string => Boolean(id)); + const customReqIds = mappings + .map((m) => m.customRequirementId) + .filter((id): id is string => Boolean(id)); + + const [platformReqs, customReqs] = await Promise.all([ + platformReqIds.length > 0 + ? db.frameworkEditorRequirement.findMany({ + where: { id: { in: platformReqIds } }, + select: { id: true, frameworkId: true }, + }) + : Promise.resolve<{ id: string; frameworkId: string }[]>([]), + customReqIds.length > 0 + ? db.customRequirement.findMany({ + where: { id: { in: customReqIds }, organizationId }, + select: { id: true, customFrameworkId: true }, + }) + : Promise.resolve<{ id: string; customFrameworkId: string }[]>([]), + ]); + const platformReqFwById = new Map( + platformReqs.map((r) => [r.id, r.frameworkId]), + ); + const customReqFwById = new Map( + customReqs.map((r) => [r.id, r.customFrameworkId]), + ); + + for (const m of mappings) { + const instance = instanceById.get(m.frameworkInstanceId); + if (!instance) { + throw new BadRequestException( + 'One or more framework instances are invalid', + ); + } + if (m.requirementId) { + const reqFwId = platformReqFwById.get(m.requirementId); + if (!reqFwId || reqFwId !== instance.frameworkId) { + throw new BadRequestException( + 'One or more requirement mappings are invalid', + ); + } + } else if (m.customRequirementId) { + const reqFwId = customReqFwById.get(m.customRequirementId); + if (!reqFwId || reqFwId !== instance.customFrameworkId) { + throw new BadRequestException( + 'One or more requirement mappings are invalid', + ); + } + } + } + + return mappings; + } + + private async ensureControl(controlId: string, organizationId: string) { + const control = await db.control.findUnique({ + where: { id: controlId, organizationId }, + select: { id: true }, + }); + if (!control) { + throw new NotFoundException('Control not found'); + } return control; } + async linkPolicies( + controlId: string, + organizationId: string, + policyIds: string[], + ) { + await this.ensureControl(controlId, organizationId); + + const policies = await db.policy.findMany({ + where: { id: { in: policyIds }, organizationId }, + select: { id: true }, + }); + if (policies.length === 0) { + throw new BadRequestException('No valid policies to link'); + } + + await db.control.update({ + where: { id: controlId }, + data: { policies: { connect: policies.map((p) => ({ id: p.id })) } }, + }); + + return { count: policies.length }; + } + + async linkTasks( + controlId: string, + organizationId: string, + taskIds: string[], + ) { + await this.ensureControl(controlId, organizationId); + + const tasks = await db.task.findMany({ + where: { id: { in: taskIds }, organizationId }, + select: { id: true }, + }); + if (tasks.length === 0) { + throw new BadRequestException('No valid tasks to link'); + } + + await db.control.update({ + where: { id: controlId }, + data: { tasks: { connect: tasks.map((t) => ({ id: t.id })) } }, + }); + + return { count: tasks.length }; + } + + async linkRequirements( + controlId: string, + organizationId: string, + mappings: { + requirementId?: string; + customRequirementId?: string; + frameworkInstanceId: string; + }[], + ) { + await this.ensureControl(controlId, organizationId); + + for (const m of mappings) { + const hasPlatform = Boolean(m.requirementId); + const hasCustom = Boolean(m.customRequirementId); + if (hasPlatform === hasCustom) { + throw new BadRequestException( + 'Each mapping must set exactly one of requirementId or customRequirementId', + ); + } + } + + const frameworkInstanceIds = Array.from( + new Set(mappings.map((m) => m.frameworkInstanceId)), + ); + const instances = await db.frameworkInstance.findMany({ + where: { id: { in: frameworkInstanceIds }, organizationId }, + select: { id: true, frameworkId: true, customFrameworkId: true }, + }); + const instanceById = new Map(instances.map((i) => [i.id, i])); + + const platformReqIds = mappings + .map((m) => m.requirementId) + .filter((id): id is string => Boolean(id)); + const customReqIds = mappings + .map((m) => m.customRequirementId) + .filter((id): id is string => Boolean(id)); + + const [platformReqs, customReqs] = await Promise.all([ + platformReqIds.length > 0 + ? db.frameworkEditorRequirement.findMany({ + where: { id: { in: platformReqIds } }, + select: { id: true, frameworkId: true }, + }) + : Promise.resolve<{ id: string; frameworkId: string }[]>([]), + customReqIds.length > 0 + ? db.customRequirement.findMany({ + where: { id: { in: customReqIds }, organizationId }, + select: { id: true, customFrameworkId: true }, + }) + : Promise.resolve<{ id: string; customFrameworkId: string }[]>([]), + ]); + const platformReqFwById = new Map( + platformReqs.map((r) => [r.id, r.frameworkId]), + ); + const customReqFwById = new Map( + customReqs.map((r) => [r.id, r.customFrameworkId]), + ); + + const validMappings = mappings.filter((m) => { + const instance = instanceById.get(m.frameworkInstanceId); + if (!instance) return false; + if (m.requirementId) { + const reqFwId = platformReqFwById.get(m.requirementId); + return Boolean(reqFwId) && reqFwId === instance.frameworkId; + } + if (m.customRequirementId) { + const reqFwId = customReqFwById.get(m.customRequirementId); + return Boolean(reqFwId) && reqFwId === instance.customFrameworkId; + } + return false; + }); + + if (validMappings.length === 0) { + throw new BadRequestException('No valid requirements to link'); + } + + const result = await db.requirementMap.createMany({ + data: validMappings.map((m) => ({ + controlId, + frameworkInstanceId: m.frameworkInstanceId, + requirementId: m.requirementId ?? null, + customRequirementId: m.customRequirementId ?? null, + })), + skipDuplicates: true, + }); + + return { count: result.count }; + } + + async linkDocumentTypes( + controlId: string, + organizationId: string, + formTypes: EvidenceFormType[], + ) { + await this.ensureControl(controlId, organizationId); + const result = await db.controlDocumentType.createMany({ + data: formTypes.map((formType) => ({ controlId, formType })), + skipDuplicates: true, + }); + return { count: result.count }; + } + + async unlinkDocumentType( + controlId: string, + organizationId: string, + formType: EvidenceFormType, + ) { + await this.ensureControl(controlId, organizationId); + await db.controlDocumentType.deleteMany({ + where: { controlId, formType }, + }); + return { success: true }; + } + async delete(controlId: string, organizationId: string) { const control = await db.control.findUnique({ where: { diff --git a/apps/api/src/controls/dto/create-control.dto.ts b/apps/api/src/controls/dto/create-control.dto.ts index b899ce6180..ac0735837f 100644 --- a/apps/api/src/controls/dto/create-control.dto.ts +++ b/apps/api/src/controls/dto/create-control.dto.ts @@ -1,17 +1,33 @@ +import { EvidenceFormType } from '@db'; import { ApiProperty } from '@nestjs/swagger'; import { IsString, IsOptional, IsArray, IsNotEmpty, + IsEnum, ValidateNested, } from 'class-validator'; import { Type } from 'class-transformer'; class RequirementMappingDto { - @ApiProperty({ description: 'Requirement ID' }) + @ApiProperty({ + description: + 'Platform requirement ID (exactly one of requirementId / customRequirementId must be set)', + required: false, + }) + @IsOptional() + @IsString() + requirementId?: string; + + @ApiProperty({ + description: + 'Org-custom requirement ID (exactly one of requirementId / customRequirementId must be set)', + required: false, + }) + @IsOptional() @IsString() - requirementId: string; + customRequirementId?: string; @ApiProperty({ description: 'Framework instance ID' }) @IsString() @@ -60,4 +76,15 @@ export class CreateControlDto { @ValidateNested({ each: true }) @Type(() => RequirementMappingDto) requirementMappings?: RequirementMappingDto[]; + + @ApiProperty({ + description: 'Evidence form types to require on this control', + required: false, + enum: EvidenceFormType, + isArray: true, + }) + @IsOptional() + @IsArray() + @IsEnum(EvidenceFormType, { each: true }) + documentTypes?: EvidenceFormType[]; } diff --git a/apps/api/src/controls/dto/link-document-types.dto.ts b/apps/api/src/controls/dto/link-document-types.dto.ts new file mode 100644 index 0000000000..b3b32c859e --- /dev/null +++ b/apps/api/src/controls/dto/link-document-types.dto.ts @@ -0,0 +1,19 @@ +import { EvidenceFormType } from '@db'; +import { ApiProperty } from '@nestjs/swagger'; +import { + ArrayMinSize, + IsArray, + IsEnum, +} from 'class-validator'; + +export class LinkDocumentTypesDto { + @ApiProperty({ + description: 'Evidence form types to require for this control', + enum: EvidenceFormType, + isArray: true, + }) + @IsArray() + @ArrayMinSize(1) + @IsEnum(EvidenceFormType, { each: true }) + formTypes: EvidenceFormType[]; +} diff --git a/apps/api/src/controls/dto/link-policies.dto.ts b/apps/api/src/controls/dto/link-policies.dto.ts new file mode 100644 index 0000000000..5465f1cbe2 --- /dev/null +++ b/apps/api/src/controls/dto/link-policies.dto.ts @@ -0,0 +1,10 @@ +import { ArrayMinSize, IsArray, IsString } from 'class-validator'; +import { ApiProperty } from '@nestjs/swagger'; + +export class LinkPoliciesDto { + @ApiProperty({ description: 'Policy IDs to link to the control', type: [String] }) + @IsArray() + @ArrayMinSize(1) + @IsString({ each: true }) + policyIds: string[]; +} diff --git a/apps/api/src/controls/dto/link-requirements.dto.ts b/apps/api/src/controls/dto/link-requirements.dto.ts new file mode 100644 index 0000000000..8cff1a0c50 --- /dev/null +++ b/apps/api/src/controls/dto/link-requirements.dto.ts @@ -0,0 +1,43 @@ +import { + ArrayMinSize, + IsArray, + IsOptional, + IsString, + ValidateNested, +} from 'class-validator'; +import { Type } from 'class-transformer'; +import { ApiProperty } from '@nestjs/swagger'; + +export class LinkRequirementMappingDto { + @ApiProperty({ + description: 'Platform requirement ID', + required: false, + }) + @IsOptional() + @IsString() + requirementId?: string; + + @ApiProperty({ + description: 'Org-custom requirement ID', + required: false, + }) + @IsOptional() + @IsString() + customRequirementId?: string; + + @ApiProperty({ description: 'Framework instance ID' }) + @IsString() + frameworkInstanceId: string; +} + +export class LinkRequirementsToControlDto { + @ApiProperty({ + description: 'Requirement + framework instance pairs to link', + type: [LinkRequirementMappingDto], + }) + @IsArray() + @ArrayMinSize(1) + @ValidateNested({ each: true }) + @Type(() => LinkRequirementMappingDto) + requirements: LinkRequirementMappingDto[]; +} diff --git a/apps/api/src/controls/dto/link-tasks.dto.ts b/apps/api/src/controls/dto/link-tasks.dto.ts new file mode 100644 index 0000000000..c85833a8c5 --- /dev/null +++ b/apps/api/src/controls/dto/link-tasks.dto.ts @@ -0,0 +1,10 @@ +import { ArrayMinSize, IsArray, IsString } from 'class-validator'; +import { ApiProperty } from '@nestjs/swagger'; + +export class LinkTasksDto { + @ApiProperty({ description: 'Task IDs to link to the control', type: [String] }) + @IsArray() + @ArrayMinSize(1) + @IsString({ each: true }) + taskIds: string[]; +} diff --git a/apps/api/src/framework-editor/control-template/control-template.controller.ts b/apps/api/src/framework-editor/control-template/control-template.controller.ts index 6b8003ecf1..16968b192e 100644 --- a/apps/api/src/framework-editor/control-template/control-template.controller.ts +++ b/apps/api/src/framework-editor/control-template/control-template.controller.ts @@ -44,11 +44,8 @@ export class ControlTemplateController { @Post() @ApiOperation({ summary: 'Create a control template' }) @UsePipes(new ValidationPipe({ whitelist: true, transform: true })) - async create( - @Body() dto: CreateControlTemplateDto, - @Query('frameworkId') frameworkId?: string, - ) { - return this.service.create(dto, frameworkId); + async create(@Body() dto: CreateControlTemplateDto) { + return this.service.create(dto); } @Patch(':id') diff --git a/apps/api/src/framework-editor/control-template/control-template.service.spec.ts b/apps/api/src/framework-editor/control-template/control-template.service.spec.ts new file mode 100644 index 0000000000..7f55e94222 --- /dev/null +++ b/apps/api/src/framework-editor/control-template/control-template.service.spec.ts @@ -0,0 +1,85 @@ +jest.mock('@db', () => ({ + db: { + frameworkEditorControlTemplate: { + create: jest.fn(), + findUnique: jest.fn(), + findMany: jest.fn(), + update: jest.fn(), + delete: jest.fn(), + }, + frameworkEditorRequirement: { + findMany: jest.fn(), + }, + }, + Prisma: { PrismaClientKnownRequestError: class {} }, +})); + +import { db } from '@db'; +import { ControlTemplateService } from './control-template.service'; + +const mockDb = db as jest.Mocked; + +describe('ControlTemplateService', () => { + let service: ControlTemplateService; + + beforeEach(() => { + service = new ControlTemplateService(); + jest.clearAllMocks(); + (mockDb.frameworkEditorControlTemplate.create as jest.Mock).mockResolvedValue({ + id: 'frk_ct_new', + name: 'New Control', + }); + }); + + describe('create', () => { + const baseDto = { + name: 'New Control', + description: 'Some description', + }; + + // Regression test for CS-271: creating a control used to auto-link every + // requirement in the caller-supplied framework. The `frameworkId` parameter + // has been removed, and this test guards against the behavior coming back + // even if the requirements table is populated. + it('never queries or auto-links framework requirements on create (CS-271)', async () => { + (mockDb.frameworkEditorRequirement.findMany as jest.Mock).mockResolvedValue([ + { id: 'frk_req_1' }, + { id: 'frk_req_2' }, + ]); + + await service.create(baseDto); + + expect(mockDb.frameworkEditorRequirement.findMany).not.toHaveBeenCalled(); + const createArgs = (mockDb.frameworkEditorControlTemplate.create as jest.Mock).mock + .calls[0][0]; + expect(createArgs.data).not.toHaveProperty('requirements'); + }); + + it('persists name and description', async () => { + await service.create(baseDto); + + const createArgs = (mockDb.frameworkEditorControlTemplate.create as jest.Mock).mock + .calls[0][0]; + expect(createArgs.data).toMatchObject({ + name: 'New Control', + description: 'Some description', + }); + }); + + it('persists documentTypes when provided', async () => { + await service.create({ ...baseDto, documentTypes: ['penetration-test'] }); + + const createArgs = (mockDb.frameworkEditorControlTemplate.create as jest.Mock).mock + .calls[0][0]; + expect(createArgs.data.documentTypes).toEqual(['penetration-test']); + }); + + it('omits documentTypes when not provided', async () => { + await service.create(baseDto); + + const createArgs = (mockDb.frameworkEditorControlTemplate.create as jest.Mock).mock + .calls[0][0]; + expect(createArgs.data).not.toHaveProperty('documentTypes'); + }); + }); +}); diff --git a/apps/api/src/framework-editor/control-template/control-template.service.ts b/apps/api/src/framework-editor/control-template/control-template.service.ts index 1e91e4073f..565fc4a151 100644 --- a/apps/api/src/framework-editor/control-template/control-template.service.ts +++ b/apps/api/src/framework-editor/control-template/control-template.service.ts @@ -54,16 +54,7 @@ export class ControlTemplateService { return ct; } - async create(dto: CreateControlTemplateDto, frameworkId?: string) { - const requirementIds = frameworkId - ? await db.frameworkEditorRequirement - .findMany({ - where: { frameworkId }, - select: { id: true }, - }) - .then((reqs) => reqs.map((r) => ({ id: r.id }))) - : []; - + async create(dto: CreateControlTemplateDto) { const ct = await db.frameworkEditorControlTemplate.create({ data: { name: dto.name, @@ -71,9 +62,6 @@ export class ControlTemplateService { ...(dto.documentTypes && { documentTypes: dto.documentTypes as EvidenceFormType[], }), - ...(requirementIds.length > 0 && { - requirements: { connect: requirementIds }, - }), }, }); this.logger.log(`Created control template: ${ct.name} (${ct.id})`); diff --git a/apps/api/src/frameworks/dto/create-custom-framework.dto.ts b/apps/api/src/frameworks/dto/create-custom-framework.dto.ts new file mode 100644 index 0000000000..a703007927 --- /dev/null +++ b/apps/api/src/frameworks/dto/create-custom-framework.dto.ts @@ -0,0 +1,21 @@ +import { IsOptional, IsString, MaxLength, MinLength } from 'class-validator'; +import { ApiProperty } from '@nestjs/swagger'; + +export class CreateCustomFrameworkDto { + @ApiProperty({ description: 'Framework name', example: 'Internal Controls' }) + @IsString() + @MinLength(1) + @MaxLength(120) + name: string; + + @ApiProperty({ description: 'Framework description' }) + @IsString() + @MaxLength(2000) + description: string; + + @ApiProperty({ description: 'Version', required: false, example: '1.0' }) + @IsOptional() + @IsString() + @MaxLength(40) + version?: string; +} diff --git a/apps/api/src/frameworks/dto/create-custom-requirement.dto.ts b/apps/api/src/frameworks/dto/create-custom-requirement.dto.ts new file mode 100644 index 0000000000..27e5c9620e --- /dev/null +++ b/apps/api/src/frameworks/dto/create-custom-requirement.dto.ts @@ -0,0 +1,21 @@ +import { IsString, MaxLength, MinLength } from 'class-validator'; +import { ApiProperty } from '@nestjs/swagger'; + +export class CreateCustomRequirementDto { + @ApiProperty({ description: 'Requirement name', example: 'Access Review' }) + @IsString() + @MinLength(1) + @MaxLength(200) + name: string; + + @ApiProperty({ description: 'Identifier', example: '10.3' }) + @IsString() + @MinLength(1) + @MaxLength(80) + identifier: string; + + @ApiProperty({ description: 'Description' }) + @IsString() + @MaxLength(4000) + description: string; +} diff --git a/apps/api/src/frameworks/dto/link-controls.dto.ts b/apps/api/src/frameworks/dto/link-controls.dto.ts new file mode 100644 index 0000000000..0b6a2fa71a --- /dev/null +++ b/apps/api/src/frameworks/dto/link-controls.dto.ts @@ -0,0 +1,14 @@ +import { ArrayMinSize, IsArray, IsString } from 'class-validator'; +import { ApiProperty } from '@nestjs/swagger'; + +export class LinkControlsDto { + @ApiProperty({ + description: + 'Existing org Control IDs to map to the requirement on this framework instance', + type: [String], + }) + @IsArray() + @ArrayMinSize(1) + @IsString({ each: true }) + controlIds: string[]; +} diff --git a/apps/api/src/frameworks/dto/link-requirements.dto.ts b/apps/api/src/frameworks/dto/link-requirements.dto.ts new file mode 100644 index 0000000000..d735e9c93e --- /dev/null +++ b/apps/api/src/frameworks/dto/link-requirements.dto.ts @@ -0,0 +1,14 @@ +import { ArrayMinSize, IsArray, IsString } from 'class-validator'; +import { ApiProperty } from '@nestjs/swagger'; + +export class LinkRequirementsDto { + @ApiProperty({ + description: + 'IDs of existing FrameworkEditorRequirement rows to clone into this framework', + type: [String], + }) + @IsArray() + @ArrayMinSize(1) + @IsString({ each: true }) + requirementIds: string[]; +} diff --git a/apps/api/src/frameworks/frameworks-scores.helper.ts b/apps/api/src/frameworks/frameworks-scores.helper.ts index 04041b2cba..9c233b1640 100644 --- a/apps/api/src/frameworks/frameworks-scores.helper.ts +++ b/apps/api/src/frameworks/frameworks-scores.helper.ts @@ -206,11 +206,14 @@ export async function getCurrentMember(organizationId: string, userId: string) { return member; } +interface ControlForScoring { + id: string; + policies: { id: string; status: string }[]; + controlDocumentTypes?: { formType: string }[]; +} + interface FrameworkWithControlsForScoring { - controls: { - id: string; - policies: { id: string; status: string }[]; - }[]; + controls: ControlForScoring[]; } interface TaskWithControls { @@ -219,40 +222,79 @@ interface TaskWithControls { controls: { id: string }[]; } -export function computeFrameworkComplianceScore( - framework: FrameworkWithControlsForScoring, +interface EvidenceSubmissionForScoring { + formType: string; + submittedAt: Date | string; +} + +function hasAnyArtifact( + control: ControlForScoring, tasks: TaskWithControls[], -): number { - const controls = framework.controls ?? []; +): boolean { + const policies = control.policies ?? []; + const documentTypes = control.controlDocumentTypes ?? []; + const controlTasks = tasks.filter((t) => + t.controls.some((c) => c.id === control.id), + ); + return ( + policies.length > 0 || controlTasks.length > 0 || documentTypes.length > 0 + ); +} - // Deduplicate policies by id across all controls - const uniquePoliciesMap = new Map(); - for (const c of controls) { - for (const p of c.policies || []) { - uniquePoliciesMap.set(p.id, p); - } - } - const uniquePolicies = Array.from(uniquePoliciesMap.values()); +function isControlCompleted( + control: ControlForScoring, + tasks: TaskWithControls[], + evidenceSubmissions: EvidenceSubmissionForScoring[], +): boolean { + const policies = control.policies ?? []; + const documentTypes = control.controlDocumentTypes ?? []; + const controlTasks = tasks.filter((t) => + t.controls.some((c) => c.id === control.id), + ); - const totalPolicies = uniquePolicies.length; - const publishedPolicies = uniquePolicies.filter( - (p) => p.status === 'published', - ).length; - const policyRatio = totalPolicies > 0 ? publishedPolicies / totalPolicies : 0; + const policiesComplete = + policies.length === 0 || + policies.every((p) => p.status === 'published'); - const controlIds = controls.map((c) => c.id); - const uniqueTaskMap = new Map(); - for (const t of tasks) { - if (t.controls.some((c) => controlIds.includes(c.id))) { - uniqueTaskMap.set(t.id, t); + const tasksComplete = + controlTasks.length === 0 || + controlTasks.every( + (t) => t.status === 'done' || t.status === 'not_relevant', + ); + + let documentsComplete = true; + if (documentTypes.length > 0) { + const sorted = [...evidenceSubmissions].sort( + (a, b) => + new Date(b.submittedAt).getTime() - new Date(a.submittedAt).getTime(), + ); + const now = Date.now(); + for (const dt of documentTypes) { + const latest = sorted.find((es) => es.formType === dt.formType); + if ( + !latest || + now - new Date(latest.submittedAt).getTime() > SIX_MONTHS_MS + ) { + documentsComplete = false; + break; + } } } - const uniqueTasks = Array.from(uniqueTaskMap.values()); - const totalTasks = uniqueTasks.length; - const doneTasks = uniqueTasks.filter( - (t) => t.status === 'done' || t.status === 'not_relevant', - ).length; - const taskRatio = totalTasks > 0 ? doneTasks / totalTasks : 1; - return Math.round(((policyRatio + taskRatio) / 2) * 100); + return policiesComplete && tasksComplete && documentsComplete; +} + +export function computeFrameworkComplianceScore( + framework: FrameworkWithControlsForScoring, + tasks: TaskWithControls[], + evidenceSubmissions: EvidenceSubmissionForScoring[] = [], +): number { + const controls = (framework.controls ?? []).filter((c) => + hasAnyArtifact(c, tasks), + ); + if (controls.length === 0) return 0; + const completed = controls.filter((c) => + isControlCompleted(c, tasks, evidenceSubmissions), + ).length; + return Math.round((completed / controls.length) * 100); } diff --git a/apps/api/src/frameworks/frameworks.controller.ts b/apps/api/src/frameworks/frameworks.controller.ts index b8e5c3002d..cee958679f 100644 --- a/apps/api/src/frameworks/frameworks.controller.ts +++ b/apps/api/src/frameworks/frameworks.controller.ts @@ -22,6 +22,10 @@ import { AuthContext, OrganizationId } from '../auth/auth-context.decorator'; import type { AuthContext as AuthContextType } from '../auth/types'; import { FrameworksService } from './frameworks.service'; import { AddFrameworksDto } from './dto/add-frameworks.dto'; +import { CreateCustomFrameworkDto } from './dto/create-custom-framework.dto'; +import { CreateCustomRequirementDto } from './dto/create-custom-requirement.dto'; +import { LinkRequirementsDto } from './dto/link-requirements.dto'; +import { LinkControlsDto } from './dto/link-controls.dto'; @ApiTags('Frameworks') @ApiBearerAuth() @@ -53,8 +57,8 @@ export class FrameworksController { summary: 'List available frameworks (requires session, no active org needed — used during onboarding)', }) - async findAvailable() { - const data = await this.frameworksService.findAvailable(); + async findAvailable(@OrganizationId() organizationId?: string) { + const data = await this.frameworksService.findAvailable(organizationId); return { data, count: data.length }; } @@ -106,6 +110,64 @@ export class FrameworksController { ); } + @Post('custom') + @RequirePermission('framework', 'create') + @ApiOperation({ summary: 'Create a custom framework for this organization' }) + async createCustom( + @OrganizationId() organizationId: string, + @Body() dto: CreateCustomFrameworkDto, + ) { + return this.frameworksService.createCustom(organizationId, dto); + } + + @Post(':id/requirements') + @RequirePermission('framework', 'update') + @ApiOperation({ summary: 'Add a custom requirement to a framework instance' }) + async createRequirement( + @OrganizationId() organizationId: string, + @Param('id') id: string, + @Body() dto: CreateCustomRequirementDto, + ) { + return this.frameworksService.createRequirement(id, organizationId, dto); + } + + @Post(':id/requirements/link') + @RequirePermission('framework', 'update') + @ApiOperation({ + summary: + 'Link (clone) existing requirements from another framework into this one', + }) + async linkRequirements( + @OrganizationId() organizationId: string, + @Param('id') id: string, + @Body() dto: LinkRequirementsDto, + ) { + return this.frameworksService.linkRequirements( + id, + organizationId, + dto.requirementIds, + ); + } + + @Post(':id/requirements/:requirementKey/controls/link') + @RequirePermission('framework', 'update') + @ApiOperation({ + summary: 'Link existing org controls to a requirement', + }) + async linkControls( + @OrganizationId() organizationId: string, + @Param('id') id: string, + @Param('requirementKey') requirementKey: string, + @Body() dto: LinkControlsDto, + ) { + return this.frameworksService.linkControlsToRequirement( + id, + requirementKey, + organizationId, + dto.controlIds, + ); + } + @Delete(':id') @RequirePermission('framework', 'delete') @ApiOperation({ summary: 'Delete a framework instance' }) diff --git a/apps/api/src/frameworks/frameworks.service.spec.ts b/apps/api/src/frameworks/frameworks.service.spec.ts index f714be36bb..5d46c16012 100644 --- a/apps/api/src/frameworks/frameworks.service.spec.ts +++ b/apps/api/src/frameworks/frameworks.service.spec.ts @@ -9,6 +9,25 @@ jest.mock('@db', () => ({ findUnique: jest.fn(), delete: jest.fn(), }, + frameworkEditorRequirement: { + findMany: jest.fn(), + findFirst: jest.fn(), + }, + customRequirement: { + findMany: jest.fn(), + findFirst: jest.fn(), + create: jest.fn(), + }, + requirementMap: { + findMany: jest.fn(), + createMany: jest.fn(), + }, + task: { + findMany: jest.fn(), + }, + evidenceSubmission: { + findMany: jest.fn(), + }, }, })); @@ -58,7 +77,7 @@ describe('FrameworksService', () => { expect(result).toEqual(mockInstances); expect(mockDb.frameworkInstance.findMany).toHaveBeenCalledWith({ where: { organizationId: 'org_1' }, - include: { framework: true }, + include: { framework: true, customFramework: true }, }); }); @@ -135,4 +154,85 @@ describe('FrameworksService', () => { }); }); }); + + // Regression coverage for the cross-tenant leak that existed on this branch + // before the split: previously both findOne and findRequirement read from + // FrameworkEditorRequirement without filtering by organizationId, so an org's + // request could surface another org's custom requirements sharing a framework. + // With the split, a custom framework instance reads from `customRequirement` + // (which is always org-scoped) and a platform framework instance reads from + // the global `frameworkEditorRequirement`. There is no shared table to leak. + describe('custom-framework isolation', () => { + it('findOne on a custom FI reads only that org\'s custom requirements', async () => { + (mockDb.frameworkInstance.findUnique as jest.Mock).mockResolvedValue({ + id: 'fi_custom', + organizationId: 'org_A', + frameworkId: null, + customFrameworkId: 'cfrm_A', + customFramework: { id: 'cfrm_A', name: 'A Custom' }, + framework: null, + requirementsMapped: [], + }); + (mockDb.customRequirement.findMany as jest.Mock).mockResolvedValue([ + { id: 'creq_1', name: 'R1', identifier: 'R1', description: '' }, + ]); + (mockDb.task.findMany as jest.Mock).mockResolvedValue([]); + (mockDb.requirementMap.findMany as jest.Mock).mockResolvedValue([]); + (mockDb.evidenceSubmission.findMany as jest.Mock).mockResolvedValue([]); + + const result = await service.findOne('fi_custom', 'org_A'); + + expect(mockDb.customRequirement.findMany).toHaveBeenCalledWith({ + where: { customFrameworkId: 'cfrm_A' }, + orderBy: { name: 'asc' }, + }); + expect( + mockDb.frameworkEditorRequirement.findMany, + ).not.toHaveBeenCalled(); + expect(result.requirementDefinitions).toHaveLength(1); + }); + + it('findOne on a platform FI reads only FrameworkEditorRequirement', async () => { + (mockDb.frameworkInstance.findUnique as jest.Mock).mockResolvedValue({ + id: 'fi_platform', + organizationId: 'org_A', + frameworkId: 'frk_soc2', + customFrameworkId: null, + framework: { id: 'frk_soc2', name: 'SOC 2' }, + customFramework: null, + requirementsMapped: [], + }); + ( + mockDb.frameworkEditorRequirement.findMany as jest.Mock + ).mockResolvedValue([ + { id: 'frk_rq_1', name: 'CC1', identifier: 'cc1-1', description: '' }, + ]); + (mockDb.task.findMany as jest.Mock).mockResolvedValue([]); + (mockDb.requirementMap.findMany as jest.Mock).mockResolvedValue([]); + (mockDb.evidenceSubmission.findMany as jest.Mock).mockResolvedValue([]); + + await service.findOne('fi_platform', 'org_A'); + + expect(mockDb.frameworkEditorRequirement.findMany).toHaveBeenCalledWith({ + where: { frameworkId: 'frk_soc2' }, + orderBy: { name: 'asc' }, + }); + expect(mockDb.customRequirement.findMany).not.toHaveBeenCalled(); + }); + + it('createRequirement rejects a platform framework instance', async () => { + (mockDb.frameworkInstance.findUnique as jest.Mock).mockResolvedValue({ + customFrameworkId: null, + }); + + await expect( + service.createRequirement('fi_platform', 'org_A', { + name: 'x', + identifier: 'x', + description: 'x', + }), + ).rejects.toThrow(/Cannot add custom requirements/); + expect(mockDb.customRequirement.create).not.toHaveBeenCalled(); + }); + }); }); diff --git a/apps/api/src/frameworks/frameworks.service.ts b/apps/api/src/frameworks/frameworks.service.ts index 311783ec2c..6947aa8524 100644 --- a/apps/api/src/frameworks/frameworks.service.ts +++ b/apps/api/src/frameworks/frameworks.service.ts @@ -11,8 +11,52 @@ import { } from './frameworks-scores.helper'; import { upsertOrgFrameworkStructure } from './frameworks-upsert.helper'; +type RequirementDef = { + id: string; + name: string; + identifier: string; + description: string; + frameworkId: string | null; + customFrameworkId: string | null; +}; + @Injectable() export class FrameworksService { + private async loadRequirementDefinitions(fi: { + frameworkId: string | null; + customFrameworkId: string | null; + }): Promise { + if (fi.customFrameworkId) { + const rows = await db.customRequirement.findMany({ + where: { customFrameworkId: fi.customFrameworkId }, + orderBy: { name: 'asc' }, + }); + return rows.map((r) => ({ + id: r.id, + name: r.name, + identifier: r.identifier, + description: r.description, + frameworkId: null, + customFrameworkId: r.customFrameworkId, + })); + } + if (fi.frameworkId) { + const rows = await db.frameworkEditorRequirement.findMany({ + where: { frameworkId: fi.frameworkId }, + orderBy: { name: 'asc' }, + }); + return rows.map((r) => ({ + id: r.id, + name: r.name, + identifier: r.identifier, + description: r.description, + frameworkId: r.frameworkId, + customFrameworkId: null, + })); + } + return []; + } + async findAll( organizationId: string, options?: { includeControls?: boolean; includeScores?: boolean }, @@ -24,6 +68,7 @@ export class FrameworksService { where: { organizationId }, include: { framework: true, + customFramework: true, ...(includeControls && { requirementsMapped: { include: { @@ -32,6 +77,7 @@ export class FrameworksService { policies: { select: { id: true, name: true, status: true }, }, + controlDocumentTypes: true, requirementsMapped: true, }, }, @@ -45,7 +91,6 @@ export class FrameworksService { return frameworkInstances; } - // Deduplicate controls from requirementsMapped const frameworksWithControls = frameworkInstances.map((fi: any) => { const controlsMap = new Map(); for (const rm of fi.requirementsMapped || []) { @@ -66,18 +111,27 @@ export class FrameworksService { return frameworksWithControls; } - // Fetch tasks for scoring - const tasks = await db.task.findMany({ - where: { - organizationId, - controls: { some: { organizationId } }, - }, - include: { controls: true }, - }); + const [tasks, evidenceSubmissions] = await Promise.all([ + db.task.findMany({ + where: { + organizationId, + controls: { some: { organizationId } }, + }, + include: { controls: true }, + }), + db.evidenceSubmission.findMany({ + where: { organizationId }, + select: { formType: true, submittedAt: true }, + }), + ]); return frameworksWithControls.map((fw: any) => ({ ...fw, - complianceScore: computeFrameworkComplianceScore(fw, tasks), + complianceScore: computeFrameworkComplianceScore( + fw, + tasks, + evidenceSubmissions, + ), })); } @@ -86,6 +140,7 @@ export class FrameworksService { where: { id: frameworkInstanceId, organizationId }, include: { framework: true, + customFramework: true, requirementsMapped: { include: { control: { @@ -106,7 +161,6 @@ export class FrameworksService { throw new NotFoundException('Framework instance not found'); } - // Deduplicate controls const controlsMap = new Map(); for (const rm of fi.requirementsMapped) { if (rm.control && !controlsMap.has(rm.control.id)) { @@ -121,7 +175,6 @@ export class FrameworksService { } const { requirementsMapped: _, ...rest } = fi; - // Collect all required evidence form types across all controls const allFormTypes = new Set(); for (const control of controlsMap.values()) { for (const dt of control.controlDocumentTypes) { @@ -129,35 +182,28 @@ export class FrameworksService { } } - const [ - requirementDefinitions, - tasks, - requirementMaps, - evidenceSubmissions, - ] = await Promise.all([ - db.frameworkEditorRequirement.findMany({ - where: { frameworkId: fi.frameworkId }, - orderBy: { name: 'asc' }, - }), - db.task.findMany({ - where: { organizationId, controls: { some: { organizationId } } }, - include: { controls: true }, - }), - db.requirementMap.findMany({ - where: { frameworkInstanceId }, - include: { control: true }, - }), - allFormTypes.size > 0 - ? db.evidenceSubmission.findMany({ - where: { - organizationId, - formType: { in: Array.from(allFormTypes) }, - }, - select: { id: true, formType: true, createdAt: true }, - orderBy: { createdAt: 'desc' }, - }) - : Promise.resolve([]), - ]); + const [requirementDefinitions, tasks, requirementMaps, evidenceSubmissions] = + await Promise.all([ + this.loadRequirementDefinitions(fi), + db.task.findMany({ + where: { organizationId, controls: { some: { organizationId } } }, + include: { controls: true }, + }), + db.requirementMap.findMany({ + where: { frameworkInstanceId }, + include: { control: true }, + }), + allFormTypes.size > 0 + ? db.evidenceSubmission.findMany({ + where: { + organizationId, + formType: { in: Array.from(allFormTypes) }, + }, + select: { id: true, formType: true, submittedAt: true }, + orderBy: { submittedAt: 'desc' }, + }) + : Promise.resolve([]), + ]); return { ...rest, @@ -169,12 +215,198 @@ export class FrameworksService { }; } - async findAvailable() { - const frameworks = await db.frameworkEditorFramework.findMany({ - where: { visible: true }, - include: { requirements: true }, + async findAvailable(organizationId?: string) { + const [platform, custom] = await Promise.all([ + db.frameworkEditorFramework.findMany({ + where: { visible: true }, + include: { requirements: true }, + }), + organizationId + ? db.customFramework.findMany({ + where: { organizationId }, + include: { requirements: true }, + }) + : Promise.resolve([]), + ]); + + return [ + ...platform.map((f) => ({ ...f, isCustom: false as const })), + ...custom.map((f) => ({ ...f, visible: true, isCustom: true as const })), + ]; + } + + async createCustom( + organizationId: string, + input: { name: string; description: string; version?: string }, + ) { + return db.$transaction(async (tx) => { + const customFramework = await tx.customFramework.create({ + data: { + name: input.name, + description: input.description, + version: input.version ?? '1.0.0', + organizationId, + }, + }); + + const instance = await tx.frameworkInstance.create({ + data: { organizationId, customFrameworkId: customFramework.id }, + include: { framework: true, customFramework: true }, + }); + + return instance; + }); + } + + async createRequirement( + frameworkInstanceId: string, + organizationId: string, + input: { name: string; identifier: string; description: string }, + ) { + const fi = await db.frameworkInstance.findUnique({ + where: { id: frameworkInstanceId, organizationId }, + select: { customFrameworkId: true }, + }); + if (!fi) { + throw new NotFoundException('Framework instance not found'); + } + if (!fi.customFrameworkId) { + throw new BadRequestException( + 'Cannot add custom requirements to a platform framework', + ); + } + + return db.customRequirement.create({ + data: { + name: input.name, + identifier: input.identifier, + description: input.description, + customFrameworkId: fi.customFrameworkId, + organizationId, + }, + }); + } + + async linkRequirements( + frameworkInstanceId: string, + organizationId: string, + requirementIds: string[], + ) { + const fi = await db.frameworkInstance.findUnique({ + where: { id: frameworkInstanceId, organizationId }, + select: { customFrameworkId: true }, + }); + if (!fi) { + throw new NotFoundException('Framework instance not found'); + } + if (!fi.customFrameworkId) { + throw new BadRequestException( + 'Cannot link requirements into a platform framework', + ); + } + + // Sources may come from either the platform editor table or this org's + // custom requirements. + const [platformSources, customSources] = await Promise.all([ + db.frameworkEditorRequirement.findMany({ + where: { id: { in: requirementIds } }, + select: { name: true, identifier: true, description: true }, + }), + db.customRequirement.findMany({ + where: { id: { in: requirementIds }, organizationId }, + select: { name: true, identifier: true, description: true }, + }), + ]); + const sources = [...platformSources, ...customSources]; + if (sources.length === 0) { + throw new BadRequestException('No valid requirements to link'); + } + + const existing = await db.customRequirement.findMany({ + where: { + customFrameworkId: fi.customFrameworkId, + identifier: { in: sources.map((r) => r.identifier) }, + }, + select: { identifier: true }, + }); + const existingIdentifiers = new Set(existing.map((r) => r.identifier)); + const toCreate = sources.filter( + (r) => !existingIdentifiers.has(r.identifier), + ); + if (toCreate.length === 0) { + return { count: 0, requirements: [] }; + } + + const created = await db.customRequirement.createManyAndReturn({ + data: toCreate.map((r) => ({ + name: r.name, + identifier: r.identifier, + description: r.description, + customFrameworkId: fi.customFrameworkId!, + organizationId, + })), + }); + + return { count: created.length, requirements: created }; + } + + async linkControlsToRequirement( + frameworkInstanceId: string, + requirementKey: string, + organizationId: string, + controlIds: string[], + ) { + const fi = await db.frameworkInstance.findUnique({ + where: { id: frameworkInstanceId, organizationId }, + select: { id: true, frameworkId: true, customFrameworkId: true }, + }); + if (!fi) { + throw new NotFoundException('Framework instance not found'); + } + + let requirementKind: 'platform' | 'custom'; + if (fi.customFrameworkId) { + const req = await db.customRequirement.findFirst({ + where: { + id: requirementKey, + customFrameworkId: fi.customFrameworkId, + organizationId, + }, + select: { id: true }, + }); + if (!req) throw new NotFoundException('Requirement not found'); + requirementKind = 'custom'; + } else if (fi.frameworkId) { + const req = await db.frameworkEditorRequirement.findFirst({ + where: { id: requirementKey, frameworkId: fi.frameworkId }, + select: { id: true }, + }); + if (!req) throw new NotFoundException('Requirement not found'); + requirementKind = 'platform'; + } else { + throw new NotFoundException('Requirement not found'); + } + + const controls = await db.control.findMany({ + where: { id: { in: controlIds }, organizationId }, + select: { id: true }, }); - return frameworks; + if (controls.length === 0) { + throw new BadRequestException('No valid controls to link'); + } + + const result = await db.requirementMap.createMany({ + data: controls.map((c) => ({ + controlId: c.id, + frameworkInstanceId, + ...(requirementKind === 'custom' + ? { customRequirementId: requirementKey } + : { requirementId: requirementKey }), + })), + skipDuplicates: true, + }); + + return { count: result.count }; } async getScores(organizationId: string, userId?: string) { @@ -220,19 +452,26 @@ export class FrameworksService { ) { const fi = await db.frameworkInstance.findUnique({ where: { id: frameworkInstanceId, organizationId }, - select: { id: true, frameworkId: true }, + select: { id: true, frameworkId: true, customFrameworkId: true }, }); - if (!fi) { throw new NotFoundException('Framework instance not found'); } - const [allReqDefs, relatedControls, tasks] = await Promise.all([ - db.frameworkEditorRequirement.findMany({ - where: { frameworkId: fi.frameworkId }, - }), + const allReqDefs = await this.loadRequirementDefinitions(fi); + const requirement = allReqDefs.find((r) => r.id === requirementKey); + if (!requirement) { + throw new NotFoundException('Requirement not found'); + } + + const [relatedControls, tasks] = await Promise.all([ db.requirementMap.findMany({ - where: { frameworkInstanceId, requirementId: requirementKey }, + where: { + frameworkInstanceId, + ...(fi.customFrameworkId + ? { customRequirementId: requirementKey } + : { requirementId: requirementKey }), + }, include: { control: { include: { @@ -250,12 +489,6 @@ export class FrameworksService { }), ]); - const requirement = allReqDefs.find((r) => r.id === requirementKey); - if (!requirement) { - throw new NotFoundException('Requirement not found'); - } - - // Collect evidence form types for related controls const formTypes = new Set(); for (const rc of relatedControls) { for (const dt of rc.control.controlDocumentTypes || []) { @@ -270,8 +503,8 @@ export class FrameworksService { organizationId, formType: { in: Array.from(formTypes) }, }, - select: { id: true, formType: true, createdAt: true }, - orderBy: { createdAt: 'desc' }, + select: { id: true, formType: true, submittedAt: true }, + orderBy: { submittedAt: 'desc' }, }) : []; diff --git a/apps/api/src/policies/policies.controller.ts b/apps/api/src/policies/policies.controller.ts index 7cde6fa40c..0ee238b346 100644 --- a/apps/api/src/policies/policies.controller.ts +++ b/apps/api/src/policies/policies.controller.ts @@ -225,20 +225,43 @@ export class PoliciesController { const instances = await db.frameworkInstance.findMany({ where: { organizationId }, - include: { framework: true }, + include: { framework: true, customFramework: true }, }); + // Normalize platform + org-custom frameworks into a single shape so the AI + // context reflects every framework the org has enabled, not just platform. + const normalized = instances.map((fi) => { + if (fi.framework) { + return { + id: fi.framework.id, + name: fi.framework.name, + version: fi.framework.version, + description: fi.framework.description, + visible: fi.framework.visible, + createdAt: fi.framework.createdAt, + updatedAt: fi.framework.updatedAt, + }; + } + if (fi.customFramework) { + return { + id: fi.customFramework.id, + name: fi.customFramework.name, + version: fi.customFramework.version, + description: fi.customFramework.description, + visible: true, + createdAt: fi.customFramework.createdAt, + updatedAt: fi.customFramework.updatedAt, + }; + } + return null; + }); const uniqueFrameworks = Array.from( - new Map(instances.map((fi) => [fi.framework.id, fi.framework])).values(), - ).map((f) => ({ - id: f.id, - name: f.name, - version: f.version, - description: f.description, - visible: f.visible, - createdAt: f.createdAt, - updatedAt: f.updatedAt, - })); + new Map( + normalized + .filter((f): f is NonNullable => f !== null) + .map((f) => [f.id, f]), + ).values(), + ); const contextEntries = await db.context.findMany({ where: { organizationId }, diff --git a/apps/api/src/tasks/tasks.service.ts b/apps/api/src/tasks/tasks.service.ts index f9a89caf3d..dd19c9dda6 100644 --- a/apps/api/src/tasks/tasks.service.ts +++ b/apps/api/src/tasks/tasks.service.ts @@ -75,9 +75,12 @@ export class TasksService { where: { organizationId }, include: { framework: { select: { name: true } }, + customFramework: { select: { name: true } }, }, }); - return instances.map((fi) => fi.framework.name); + return instances + .map((fi) => fi.framework?.name ?? fi.customFramework?.name) + .filter((name): name is string => Boolean(name)); } /** @@ -315,6 +318,7 @@ export class TasksService { where: { organizationId }, include: { framework: { select: { id: true, name: true } }, + customFramework: { select: { id: true, name: true } }, requirementsMapped: { select: { controlId: true } }, }, }), diff --git a/apps/api/src/trigger/policies/update-policy.ts b/apps/api/src/trigger/policies/update-policy.ts index ae10c57284..93caa20eb7 100644 --- a/apps/api/src/trigger/policies/update-policy.ts +++ b/apps/api/src/trigger/policies/update-policy.ts @@ -25,6 +25,7 @@ export const updatePolicy = schemaTask({ version: z.string(), description: z.string(), visible: z.boolean(), + organizationId: z.string().nullable().default(null), createdAt: z.date(), updatedAt: z.date(), }), diff --git a/apps/app/src/app/(app)/[orgId]/controls/[controlId]/components/PoliciesTable.tsx b/apps/app/src/app/(app)/[orgId]/controls/[controlId]/components/PoliciesTable.tsx index 1a4a0212d7..0b79f44172 100644 --- a/apps/app/src/app/(app)/[orgId]/controls/[controlId]/components/PoliciesTable.tsx +++ b/apps/app/src/app/(app)/[orgId]/controls/[controlId]/components/PoliciesTable.tsx @@ -87,7 +87,14 @@ export function PoliciesTable({ policies, orgId }: PoliciesTableProps) { } }} > - {policy.name} + + + {policy.name} + + {new Date(policy.createdAt).toLocaleDateString()} diff --git a/apps/app/src/app/(app)/[orgId]/controls/[controlId]/components/RequirementsTable.tsx b/apps/app/src/app/(app)/[orgId]/controls/[controlId]/components/RequirementsTable.tsx index 10ca7f8d0c..e25b2d4ca8 100644 --- a/apps/app/src/app/(app)/[orgId]/controls/[controlId]/components/RequirementsTable.tsx +++ b/apps/app/src/app/(app)/[orgId]/controls/[controlId]/components/RequirementsTable.tsx @@ -74,6 +74,7 @@ export function RequirementsTable({ requirements, orgId }: RequirementsTableProp + Identifier Name Description @@ -81,39 +82,53 @@ export function RequirementsTable({ requirements, orgId }: RequirementsTableProp {filteredRequirements.length === 0 ? ( - + No requirements found. ) : ( - filteredRequirements.map((requirement) => ( - handleRowClick(requirement)} - onKeyDown={(event) => { - if (event.key === 'Enter' || event.key === ' ') { - event.preventDefault(); - handleRowClick(requirement); - } - }} - > - - - {requirement.requirement.name} - - - - - {requirement.requirement.description} - - - - )) + filteredRequirements.map((requirement) => { + const identifier = requirement.requirement.identifier?.trim(); + const name = requirement.requirement.name; + const description = requirement.requirement.description; + return ( + handleRowClick(requirement)} + onKeyDown={(event) => { + if (event.key === 'Enter' || event.key === ' ') { + event.preventDefault(); + handleRowClick(requirement); + } + }} + > + + {identifier || '—'} + + + + {name} + + + + + {description} + + + + ); + }) )}
diff --git a/apps/app/src/app/(app)/[orgId]/controls/[controlId]/components/TasksTable.tsx b/apps/app/src/app/(app)/[orgId]/controls/[controlId]/components/TasksTable.tsx index 99f1c92727..e39cd14bce 100644 --- a/apps/app/src/app/(app)/[orgId]/controls/[controlId]/components/TasksTable.tsx +++ b/apps/app/src/app/(app)/[orgId]/controls/[controlId]/components/TasksTable.tsx @@ -89,9 +89,21 @@ export function TasksTable({ tasks, orgId }: TasksTableProps) { } }} > - {task.title} - {task.description} + + {task.title} + + + + + {task.description} + diff --git a/apps/app/src/app/(app)/[orgId]/controls/hooks/useControls.ts b/apps/app/src/app/(app)/[orgId]/controls/hooks/useControls.ts index b46df5eedd..1ddce147bc 100644 --- a/apps/app/src/app/(app)/[orgId]/controls/hooks/useControls.ts +++ b/apps/app/src/app/(app)/[orgId]/controls/hooks/useControls.ts @@ -9,15 +9,17 @@ interface ControlsApiResponse { pageCount: number; } +type RequirementMappingPayload = + | { requirementId: string; customRequirementId?: never; frameworkInstanceId: string } + | { requirementId?: never; customRequirementId: string; frameworkInstanceId: string }; + interface CreateControlPayload { name: string; description: string; policyIds?: string[]; taskIds?: string[]; - requirementMappings?: { - requirementId: string; - frameworkInstanceId: string; - }[]; + requirementMappings?: RequirementMappingPayload[]; + documentTypes?: string[]; } export const controlsKey = () => ['/v1/controls'] as const; diff --git a/apps/app/src/app/(app)/[orgId]/frameworks/[frameworkInstanceId]/components/AddCustomRequirementSheet.tsx b/apps/app/src/app/(app)/[orgId]/frameworks/[frameworkInstanceId]/components/AddCustomRequirementSheet.tsx new file mode 100644 index 0000000000..3ca26750d8 --- /dev/null +++ b/apps/app/src/app/(app)/[orgId]/frameworks/[frameworkInstanceId]/components/AddCustomRequirementSheet.tsx @@ -0,0 +1,160 @@ +'use client'; + +import { apiClient } from '@/lib/api-client'; +import { usePermissions } from '@/hooks/use-permissions'; +import { + Button, + Sheet, + SheetBody, + SheetContent, + SheetHeader, + SheetTitle, +} from '@trycompai/design-system'; +import { Add } from '@trycompai/design-system/icons'; +import { + Form, + FormControl, + FormField, + FormItem, + FormLabel, + FormMessage, +} from '@trycompai/ui/form'; +import { Input } from '@trycompai/ui/input'; +import { Textarea } from '@trycompai/ui/textarea'; +import { zodResolver } from '@hookform/resolvers/zod'; +import { useRouter } from 'next/navigation'; +import { useEffect, useState } from 'react'; +import { useForm } from 'react-hook-form'; +import { toast } from 'sonner'; +import { z } from 'zod'; + +const schema = z.object({ + identifier: z.string().min(1, 'Identifier is required').max(80), + name: z.string().min(1, 'Name is required').max(200), + description: z.string().max(4000), +}); + +type FormValues = z.infer; + +export function AddCustomRequirementSheet({ + frameworkInstanceId, +}: { + frameworkInstanceId: string; +}) { + const { hasPermission } = usePermissions(); + const router = useRouter(); + const [isOpen, setIsOpen] = useState(false); + const [isSubmitting, setIsSubmitting] = useState(false); + + const form = useForm({ + resolver: zodResolver(schema), + defaultValues: { identifier: '', name: '', description: '' }, + mode: 'onChange', + }); + + useEffect(() => { + if (!isOpen) { + form.reset({ identifier: '', name: '', description: '' }); + } + }, [isOpen, form]); + + if (!hasPermission('framework', 'update')) return null; + + const handleSubmit = async (values: FormValues) => { + if (isSubmitting) return; + setIsSubmitting(true); + try { + const response = await apiClient.post( + `/v1/frameworks/${frameworkInstanceId}/requirements`, + values, + ); + if (response.error) throw new Error(response.error); + toast.success('Requirement added'); + setIsOpen(false); + form.reset(); + router.refresh(); + } catch (error) { + toast.error( + error instanceof Error ? error.message : 'Failed to add requirement', + ); + } finally { + setIsSubmitting(false); + } + }; + + return ( + <> + + + + + Add Custom Requirement + + +
+ + ( + + Identifier + + + + + + )} + /> + ( + + Name + + + + + + )} + /> + ( + + Description + +