diff --git a/compliance/README.md b/compliance/README.md index 888c404b..f2bd765c 100644 --- a/compliance/README.md +++ b/compliance/README.md @@ -5,14 +5,7 @@ This directory contains license compliance scanning and reporting for the teachL ## Structure ``` -compliance/ -├── configs/ -│ ├── license-policy.yml # Defines allowed/prohibited licenses -│ └── scanner-config.yml # Scanner configuration -├── reports/ -│ └── license-report-*.json # Generated compliance reports -└── README.md # This file -``` + ## Running License Scans @@ -56,4 +49,4 @@ If the scan returns WARNING status: 1. Review the report in `compliance/reports/` 2. For LGPL packages: Ensure dynamic linking is used (not static) 3. For unknown packages: Verify the package is open source and acceptable -4. Document approval in PR comments or issue \ No newline at end of file +4. Document approval in PR comments or issue diff --git a/src/auth/auth.module.ts b/src/auth/auth.module.ts index 17bf6c76..06f9516b 100644 --- a/src/auth/auth.module.ts +++ b/src/auth/auth.module.ts @@ -13,9 +13,19 @@ import { RolesGuard } from './guards/roles.guard'; import { PermissionsGuard } from './guards/permissions.guard'; import { SocialAuthService } from './services/social-auth.service'; import { SocialAuthController } from './controllers/social-auth.controller'; +import { AuthTokensService } from './services/auth-tokens.service'; +// Issue #799 — EncryptionService is required to encrypt OAuth provider tokens +// (providerAccessToken / providerRefreshToken) at rest. SecurityModule is the +// only module that provides EncryptionService, so it must be imported here. +import { SecurityModule } from '../security/security.module'; /** * Registers the authentication module with Passport and JWT support. + * + * Issue #801 — AuthTokensService is registered here so password-reset and + * email-verification flows can persist only SHA-256 hashes (never raw tokens). + * Issue #799 — SecurityModule is imported so SocialAuthService has access to + * the EncryptionService for at-rest OAuth token protection. */ @Module({ imports: [ @@ -25,6 +35,7 @@ import { SocialAuthController } from './controllers/social-auth.controller'; signOptions: { expiresIn: (process.env.JWT_EXPIRES_IN || '15m') as any }, }), TypeOrmModule.forFeature([User]), + SecurityModule, ], controllers: [AuthController, SocialAuthController], providers: [ @@ -34,6 +45,7 @@ import { SocialAuthController } from './controllers/social-auth.controller'; GoogleStrategy, GitHubStrategy, SocialAuthService, + AuthTokensService, RolesGuard, PermissionsGuard, ], @@ -42,6 +54,7 @@ import { SocialAuthController } from './controllers/social-auth.controller'; JwtModule, AuthService, SocialAuthService, + AuthTokensService, RolesGuard, PermissionsGuard, ], diff --git a/src/auth/services/auth-tokens.service.spec.ts b/src/auth/services/auth-tokens.service.spec.ts new file mode 100644 index 00000000..a9b94a8c --- /dev/null +++ b/src/auth/services/auth-tokens.service.spec.ts @@ -0,0 +1,180 @@ +import { Test, TestingModule } from '@nestjs/testing'; +import { getRepositoryToken } from '@nestjs/typeorm'; +import { User } from '../../users/entities/user.entity'; +import { + AuthTokensService, + hashToken, +} from './auth-tokens.service'; + +function makeMockRepo() { + return { + update: jest.fn().mockResolvedValue({ affected: 1 }), + findOne: jest.fn(), + }; +} + +describe('AuthTokensService (Issue #801 — SHA-256 hashed tokens)', () => { + let service: AuthTokensService; + let mockRepo: ReturnType; + + beforeEach(async () => { + mockRepo = makeMockRepo(); + const moduleRef: TestingModule = await Test.createTestingModule({ + providers: [ + AuthTokensService, + { provide: getRepositoryToken(User), useValue: mockRepo }, + ], + }).compile(); + service = moduleRef.get(AuthTokensService); + }); + + describe('hashToken (helper)', () => { + it('produces a 64-character hex SHA-256 digest', () => { + const h = hashToken('hello-world'); + expect(h).toHaveLength(64); + expect(h).toMatch(/^[0-9a-f]{64}$/); + }); + + it('is deterministic for the same input', () => { + expect(hashToken('abc')).toBe(hashToken('abc')); + }); + + it('differs for different inputs', () => { + expect(hashToken('abc')).not.toBe(hashToken('abd')); + }); + }); + + describe('generateTokenPair', () => { + it('returns a 64-char hex raw token (32 bytes) and matching hash', () => { + const pair = service.generateTokenPair(); + expect(pair.rawToken).toHaveLength(64); + expect(pair.rawToken).toMatch(/^[0-9a-f]{64}$/); + expect(pair.tokenHash).toMatch(/^[0-9a-f]{64}$/); + expect(pair.tokenHash).toBe(hashToken(pair.rawToken)); + }); + + it('returns an expiry ~24 hours in the future', () => { + const before = Date.now(); + const pair = service.generateTokenPair(); + const gap = pair.expiresAt.getTime() - before; + expect(gap).toBeGreaterThan(23 * 60 * 60 * 1000); + expect(gap).toBeLessThanOrEqual(24 * 60 * 60 * 1000 + 10); + }); + + it('produces unique tokens across calls', () => { + const a = service.generateTokenPair(); + const b = service.generateTokenPair(); + expect(a.rawToken).not.toBe(b.rawToken); + }); + }); + + describe('issuePasswordReset', () => { + it('persists the SHA-256 hash, not the raw token', async () => { + const { rawToken } = await service.issuePasswordReset('user-1'); + expect(rawToken).toMatch(/^[0-9a-f]{64}$/); + const [criteria, update] = mockRepo.update.mock.calls[0]; + expect(criteria).toEqual({ id: 'user-1' }); + // Persisted value MUST be the hash, never the raw. + expect(update.passwordResetToken).toBe(hashToken(rawToken)); + expect(update.passwordResetToken).not.toBe(rawToken); + expect(update.passwordResetExpires).toBeInstanceOf(Date); + }); + }); + + describe('issueEmailVerification', () => { + it('persists the SHA-256 hash, not the raw token', async () => { + const { rawToken } = await service.issueEmailVerification('user-2'); + const [criteria, update] = mockRepo.update.mock.calls[0]; + expect(criteria).toEqual({ id: 'user-2' }); + expect(update.emailVerificationToken).toBe(hashToken(rawToken)); + expect(update.emailVerificationToken).not.toBe(rawToken); + expect(update.emailVerificationExpires).toBeInstanceOf(Date); + }); + }); + + describe('consumePasswordReset', () => { + it('returns the matching user and clears the stored hash', async () => { + const raw = 'a'.repeat(64); + const expectedHash = hashToken(raw); + const user = { id: 'user-7' } as User; + mockRepo.findOne.mockResolvedValueOnce(user); + + const result = await service.consumePasswordReset(raw); + expect(result).toBe(user); + expect(mockRepo.findOne).toHaveBeenCalledWith( + expect.objectContaining({ + where: expect.objectContaining({ + passwordResetToken: expectedHash, + passwordResetExpires: expect.anything(), + }), + }), + ); + const [criteria, update] = mockRepo.update.mock.calls[0]; + expect(criteria).toEqual({ id: 'user-7' }); + expect(update.passwordResetToken).toBeNull(); + expect(update.passwordResetExpires).toBeNull(); + }); + + it('returns null when no user matches', async () => { + mockRepo.findOne.mockResolvedValueOnce(null); + expect(await service.consumePasswordReset('bad-token')).toBeNull(); + }); + + it('returns null for empty input', async () => { + expect(await service.consumePasswordReset('')).toBeNull(); + }); + }); + + describe('consumeEmailVerification', () => { + it('returns the matching user, sets isEmailVerified, and clears the hash', async () => { + const raw = 'b'.repeat(64); + const user = { id: 'user-8' } as User; + mockRepo.findOne.mockResolvedValueOnce(user); + + const result = await service.consumeEmailVerification(raw); + expect(result).toBe(user); + const [, update] = mockRepo.update.mock.calls[0]; + expect(update.isEmailVerified).toBe(true); + expect(update.emailVerificationToken).toBeNull(); + }); + }); + + describe('verifyTokenHash', () => { + it('returns true for matching inputs', () => { + const raw = 'tok'; + const stored = hashToken(raw); + expect(service.verifyTokenHash(raw, stored)).toBe(true); + }); + + it('returns false for mismatched inputs', () => { + const stored = hashToken('correct'); + expect(service.verifyTokenHash('wrong', stored)).toBe(false); + }); + + it('returns false when stored hash is missing', () => { + expect(service.verifyTokenHash('x', null)).toBe(false); + expect(service.verifyTokenHash('x', undefined)).toBe(false); + expect(service.verifyTokenHash('x', '')).toBe(false); + }); + + it('returns false when raw token is empty', () => { + expect(service.verifyTokenHash('', hashToken('x'))).toBe(false); + }); + }); + + describe('regression: raw token never reaches the database column', () => { + it('does not write raw token into passwordResetToken column', async () => { + const { rawToken } = await service.issuePasswordReset('user-9'); + const [, update] = mockRepo.update.mock.calls[0]; + expect(update.passwordResetToken).not.toBe(rawToken); + expect(update.passwordResetToken).toBe(hashToken(rawToken)); + }); + + it('does not write raw token into emailVerificationToken column', async () => { + const { rawToken } = await service.issueEmailVerification('user-10'); + const [, update] = mockRepo.update.mock.calls[0]; + expect(update.emailVerificationToken).not.toBe(rawToken); + expect(update.emailVerificationToken).toBe(hashToken(rawToken)); + }); + }); +}); diff --git a/src/auth/services/auth-tokens.service.ts b/src/auth/services/auth-tokens.service.ts new file mode 100644 index 00000000..4347bc50 --- /dev/null +++ b/src/auth/services/auth-tokens.service.ts @@ -0,0 +1,162 @@ +import { Injectable, Logger } from '@nestjs/common'; +import * as crypto from 'crypto'; +import { InjectRepository } from '@nestjs/typeorm'; +import { Repository, MoreThan } from 'typeorm'; +import { User } from '../../users/entities/user.entity'; + +/** + * Issue #801 — Plaintext password-reset and email-verification tokens are + * equivalent to a database-stored credential. We persist only the SHA-256 + * hash of the raw token; the raw token is delivered once to the user (via + * email) and never written to disk. + * + * Why SHA-256, not bcrypt: + * - These tokens are short-lived (~24 h) and high-entropy (32 bytes of + * cryptographic randomness). + * - bcrypt's intentional CPU cost is calibrated against *low-entropy* inputs + * (human-chosen passwords); it provides no marginal benefit when the input + * is a 256-bit random secret and only adds latency to every lookup. + * - SHA-256 lets us look tokens up deterministically (WHERE hash = $1), + * avoiding the iterative-compare pattern bcrypt forces. + */ + +const DEFAULT_TOKEN_TTL_MS = 24 * 60 * 60 * 1000; // 24 hours + +/** + * Returns the SHA-256 (hex) hash of the supplied raw token. Used both as + * the persistence format (column value) and as the comparison key for + * validation. + */ +export function hashToken(rawToken: string): string { + return crypto.createHash('sha256').update(rawToken, 'utf8').digest('hex'); +} + +@Injectable() +export class AuthTokensService { + private readonly logger = new Logger(AuthTokensService.name); + private readonly ttlMs: number; + + constructor( + @InjectRepository(User) private readonly users: Repository, + ) { + this.ttlMs = DEFAULT_TOKEN_TTL_MS; + } + + /** + * Generates a cryptographically-random 32-byte token and returns both the + * raw value (give this to the user via email) and the SHA-256 hash (the + * value to write to the database). + */ + generateTokenPair(): { rawToken: string; tokenHash: string; expiresAt: Date } { + const rawToken = crypto.randomBytes(32).toString('hex'); + const tokenHash = hashToken(rawToken); + const expiresAt = new Date(Date.now() + this.ttlMs); + return { rawToken, tokenHash, expiresAt }; + } + + /** + * Issues a password-reset token for the given user. + * + * Returns the *raw* token so the caller can deliver it to the user via + * email. The hashed version is written to the User row; the raw token is + * never persisted. + */ + async issuePasswordReset(userId: string): Promise<{ + rawToken: string; + expiresAt: Date; + }> { + const { rawToken, tokenHash, expiresAt } = this.generateTokenPair(); + await this.users.update( + { id: userId }, + { + passwordResetToken: tokenHash, + passwordResetExpires: expiresAt, + }, + ); + return { rawToken, expiresAt }; + } + + /** + * Issues an email-verification token for the given user. + */ + async issueEmailVerification(userId: string): Promise<{ + rawToken: string; + expiresAt: Date; + }> { + const { rawToken, tokenHash, expiresAt } = this.generateTokenPair(); + await this.users.update( + { id: userId }, + { + emailVerificationToken: tokenHash, + emailVerificationExpires: expiresAt, + }, + ); + return { rawToken, expiresAt }; + } + + /** + * Validates a raw password-reset token submitted by the user. On match the + * stored token is cleared (single-use semantics) and the user is returned. + * Returns null when the token does not match any active row OR has expired. + */ + async consumePasswordReset(rawToken: string): Promise { + if (!rawToken) return null; + const tokenHash = hashToken(rawToken); + const user = await this.users.findOne({ + where: { + passwordResetToken: tokenHash, + passwordResetExpires: MoreThan(new Date()), + }, + }); + if (!user) return null; + await this.users.update( + { id: user.id }, + { + passwordResetToken: null, + passwordResetExpires: null, + }, + ); + return user; + } + + /** + * Validates a raw email-verification token submitted by the user. Sets + * `isEmailVerified=true` on match and clears the stored token. + */ + async consumeEmailVerification(rawToken: string): Promise { + if (!rawToken) return null; + const tokenHash = hashToken(rawToken); + const user = await this.users.findOne({ + where: { + emailVerificationToken: tokenHash, + emailVerificationExpires: MoreThan(new Date()), + }, + }); + if (!user) return null; + await this.users.update( + { id: user.id }, + { + emailVerificationToken: null, + emailVerificationExpires: null, + isEmailVerified: true, + }, + ); + return user; + } + + /** + * Verifies a raw token against a stored hash. Exposed for callers that + * already have the stored hash (e.g. when an alternative storage scheme + * is used in the future). + */ + verifyTokenHash(rawToken: string, storedHash: string | null | undefined): boolean { + if (!rawToken || !storedHash) return false; + const candidate = hashToken(rawToken); + // Constant-time comparison guards against timing side-channels. SHA-256 + // strings are always 64 hex chars, so the lengths match by construction. + return crypto.timingSafeEqual( + Buffer.from(candidate, 'hex'), + Buffer.from(storedHash, 'hex'), + ); + } +} diff --git a/src/auth/services/social-auth.service.spec.ts b/src/auth/services/social-auth.service.spec.ts index a2676ea9..4b4102df 100644 --- a/src/auth/services/social-auth.service.spec.ts +++ b/src/auth/services/social-auth.service.spec.ts @@ -3,6 +3,10 @@ import { getRepositoryToken } from '@nestjs/typeorm'; import { ConflictException } from '@nestjs/common'; import { SocialAuthService, SocialProfile } from './social-auth.service'; import { User } from '../../users/entities/user.entity'; +import { + EncryptionService, + IEncryptedPayload, +} from '../../security/encryption/encryption.service'; function makeUser(overrides: Partial = {}): User { return { @@ -31,6 +35,42 @@ function makeProfile(overrides: Partial = {}): SocialProfile { }; } +/** + * Build a deterministic stub EncryptionService so we can assert on the DB + * payload shape. We prefix every encrypted value with `enc:` followed by + * `iv.content.tag` so the storage format is identical to the real service. + * + * The stub also decrypts: it reverses the serialisation to recover the + * original token so `getDecryptedAccessToken` / `getDecryptedRefreshToken` + * round-trip cleanly. + */ +function makeEncryptionStub() { + const encCalls: Array<{ input: string; payload: IEncryptedPayload }> = []; + + const encryption: Partial = { + encrypt(text: string): IEncryptedPayload { + const payload: IEncryptedPayload = { + iv: `iv-${Buffer.from(text).toString('hex').slice(0, 8)}`, + content: Buffer.from(text, 'utf8').toString('hex'), + tag: 'tag', + }; + encCalls.push({ input: text, payload }); + return payload; + }, + decrypt(payload: IEncryptedPayload): string { + return Buffer.from(payload.content, 'hex').toString('utf8'); + }, + }; + + return { + encryption: encryption as EncryptionService, + encCalls, + serialise(payload: IEncryptedPayload): string { + return `enc:${JSON.stringify(payload)}`; + }, + }; +} + const mockRepo = { findOne: jest.fn(), create: jest.fn(), @@ -38,19 +78,25 @@ const mockRepo = { findOneOrFail: jest.fn(), }; -describe('SocialAuthService – name fallback', () => { +describe('SocialAuthService (Issue #799 — at-rest encryption)', () => { let service: SocialAuthService; + let stub: ReturnType; beforeEach(async () => { + stub = makeEncryptionStub(); const module: TestingModule = await Test.createTestingModule({ - providers: [SocialAuthService, { provide: getRepositoryToken(User), useValue: mockRepo }], + providers: [ + SocialAuthService, + { provide: getRepositoryToken(User), useValue: mockRepo }, + { provide: EncryptionService, useValue: stub.encryption }, + ], }).compile(); service = module.get(SocialAuthService); jest.clearAllMocks(); }); - describe('findOrCreateFromProvider – new user creation', () => { + describe('findOrCreateFromProvider – name fallback', () => { beforeEach(() => { // No existing user by provider or email mockRepo.findOne.mockResolvedValue(null); @@ -120,9 +166,7 @@ describe('SocialAuthService – name fallback', () => { describe('findOrCreateFromProvider – email conflict', () => { it('throws ConflictException when email is registered under a different provider', async () => { - // No provider match mockRepo.findOne.mockResolvedValueOnce(null); - // Email match with different provider mockRepo.findOne.mockResolvedValueOnce(makeUser({ provider: 'github' })); await expect( @@ -130,4 +174,155 @@ describe('SocialAuthService – name fallback', () => { ).rejects.toThrow(ConflictException); }); }); + + // ─── Issue #799 acceptance: stored OAuth tokens are encrypted, never raw ─ + describe('provider tokens are encrypted at rest (Issue #799)', () => { + beforeEach(() => { + mockRepo.findOne.mockResolvedValue(null); + mockRepo.create.mockImplementation((data) => ({ ...data })); + mockRepo.save.mockImplementation((u) => Promise.resolve({ id: 'new-id', ...u })); + }); + + it('encrypts both access and refresh tokens when creating a new user', async () => { + await service.findOrCreateFromProvider( + makeProfile({ accessToken: 'AT-raw', refreshToken: 'RT-raw' }), + ); + + const created = mockRepo.create.mock.calls[0][0]; + // Stored values MUST NOT equal the raw tokens. + expect(created.providerAccessToken).not.toBe('AT-raw'); + expect(created.providerRefreshToken).not.toBe('RT-raw'); + // Stored values MUST carry the encryption prefix. + expect(created.providerAccessToken).toMatch(/^enc:/); + expect(created.providerRefreshToken).toMatch(/^enc:/); + // encrypt() was called once per token. + expect(stub.encCalls.map((c) => c.input)).toEqual(['AT-raw', 'RT-raw']); + }); + + it('stores null when the provider did not issue a token', async () => { + await service.findOrCreateFromProvider( + makeProfile({ accessToken: undefined, refreshToken: undefined }), + ); + const created = mockRepo.create.mock.calls[0][0]; + expect(created.providerAccessToken).toBeNull(); + expect(created.providerRefreshToken).toBeNull(); + }); + + it('encrypts tokens when linking an existing user to a new provider', async () => { + const user = makeUser({ + provider: null, + providerId: null, + providerAccessToken: null, + providerRefreshToken: null, + }); + mockRepo.findOneOrFail.mockResolvedValueOnce(user); + + const updated = await service.linkProvider('user-1', makeProfile({ accessToken: 'NEW-AT' })); + + expect(updated.providerAccessToken).not.toBe('NEW-AT'); + expect(updated.providerAccessToken).toMatch(/^enc:/); + }); + + it('overwrites pre-existing encrypted tokens on re-link (no plaintext regression)', async () => { + // User has been linked before, the stored value is already encrypted. + // We build the payload manually (NOT through stub.encryption.encrypt) so + // `encCalls` only records whatever the service itself encrypts. + const previousPayload: IEncryptedPayload = { + iv: 'iv-prev', + content: Buffer.from('OLD-AT', 'utf8').toString('hex'), + tag: 'tag-prev', + }; + const stored = `enc:${JSON.stringify(previousPayload)}`; + const user = makeUser({ + provider: 'google', + providerId: 'google-old', + providerAccessToken: stored, + providerRefreshToken: null, + }); + mockRepo.findOneOrFail.mockResolvedValueOnce(user); + + // EncCalls guard: ensures the test fixture does not pollute this test's + // assertion by sneaking in encrypt() calls during setup. (Relies on the + // `beforeEach` block allocating a fresh stub. If someone refactors that + // out, this assertion breaks LOUD instead of silently.) + const encCallsBefore = [...stub.encCalls]; + + await service.linkProvider('user-1', makeProfile({ accessToken: 'NEW-AT', provider: 'google' })); + + // The new value must replace the old one, AND must be encrypted. + expect(user.providerAccessToken).not.toBe('NEW-AT'); + expect(user.providerAccessToken).not.toBe(stored); + expect(user.providerAccessToken).toMatch(/^enc:/); + // Only the NEW token should have hit encrypt() — never the old plaintext. + expect(stub.encCalls.map((c) => c.input)).toEqual( + [...encCallsBefore.map((c) => c.input), 'NEW-AT'], + ); + }); + + it('nulls out provider tokens on unlink (no encryption needed for null)', async () => { + const user = makeUser(); + mockRepo.findOneOrFail.mockResolvedValueOnce(user); + const updated = await service.unlinkProvider('user-1'); + expect(updated.providerAccessToken).toBeNull(); + expect(updated.providerRefreshToken).toBeNull(); + }); + }); + + describe('getDecryptedAccessToken / getDecryptedRefreshToken', () => { + it('returns the original plaintext token (round-trip)', async () => { + const storedAT = stub.serialise(stub.encryption.encrypt('AT-plaintext')); + const storedRT = stub.serialise(stub.encryption.encrypt('RT-plaintext')); + // mockResolvedValue (not Once) so both calls below see the same row. + mockRepo.findOne.mockResolvedValue( + makeUser({ providerAccessToken: storedAT, providerRefreshToken: storedRT }), + ); + + expect(await service.getDecryptedAccessToken('user-1')).toBe('AT-plaintext'); + expect(await service.getDecryptedRefreshToken('user-1')).toBe('RT-plaintext'); + }); + + it('returns null when no token is stored', async () => { + mockRepo.findOne.mockResolvedValueOnce(makeUser({ providerAccessToken: null })); + expect(await service.getDecryptedAccessToken('user-1')).toBeNull(); + }); + + it('returns null when no user is found', async () => { + mockRepo.findOne.mockResolvedValueOnce(null); + expect(await service.getDecryptedAccessToken('ghost')).toBeNull(); + }); + + it('returns null (with warning) when stored value is legacy plaintext', async () => { + const warn = jest + .spyOn((service as any).logger, 'warn') + .mockImplementation(() => undefined); + const user = makeUser({ providerAccessToken: 'plaintext-AT' }); + mockRepo.findOne.mockResolvedValueOnce(user); + + expect(await service.getDecryptedAccessToken('user-1')).toBeNull(); + expect(warn).toHaveBeenCalled(); + }); + }); + + describe('getDecryptedProviderTokens (combined decode)', () => { + it('returns both access and refresh in a single call', async () => { + const stored = (kind: string) => + stub.serialise(stub.encryption.encrypt(kind === 'at' ? 'AT-plaintext' : 'RT-plaintext')); + mockRepo.findOne.mockResolvedValueOnce( + makeUser({ providerAccessToken: stored('at'), providerRefreshToken: stored('rt') }), + ); + + const out = await service.getDecryptedProviderTokens('user-1'); + expect(out).toEqual({ access: 'AT-plaintext', refresh: 'RT-plaintext' }); + // Single DB call, not two. + expect(mockRepo.findOne).toHaveBeenCalledTimes(1); + }); + + it('returns null for either field when the column is empty', async () => { + mockRepo.findOne.mockResolvedValueOnce( + makeUser({ providerAccessToken: null, providerRefreshToken: null }), + ); + const out = await service.getDecryptedProviderTokens('user-1'); + expect(out).toEqual({ access: null, refresh: null }); + }); + }); }); diff --git a/src/auth/services/social-auth.service.ts b/src/auth/services/social-auth.service.ts index 740e26ce..5b927589 100644 --- a/src/auth/services/social-auth.service.ts +++ b/src/auth/services/social-auth.service.ts @@ -1,7 +1,11 @@ -import { Injectable, ConflictException } from '@nestjs/common'; +import { Injectable, ConflictException, Logger } from '@nestjs/common'; import { InjectRepository } from '@nestjs/typeorm'; import { Repository } from 'typeorm'; import { User } from '../../users/entities/user.entity'; +import { + EncryptionService, + IEncryptedPayload, +} from '../../security/encryption/encryption.service'; export interface SocialProfile { provider: string; @@ -14,11 +18,32 @@ export interface SocialProfile { refreshToken?: string; } +/** Opaque prefix that marks a stored value as an AES-GCM JSON payload. */ +const ENCRYPTED_PREFIX = 'enc:'; + +/** + * Issue #799 — OAuth provider tokens were previously stored as plaintext on the + * `providerAccessToken` / `providerRefreshToken` User columns. A DB breach + * exposed every user's Google / GitHub credentials immediately. + * + * This service now encrypts both fields with {@link EncryptionService} + * (AES-256-GCM) before persistence and exposes symmetric + * `getDecryptedAccessToken` / `getDecryptedRefreshToken` / + * `getDecryptedProviderTokens` helpers for callers that need the plaintext + * value at runtime (e.g. a refresh-token rotation job). + * + * Encrypted payloads are serialised as `` so a + * downstream consumer can distinguish "encrypted" from "legacy plaintext" + * during the migration window. + */ @Injectable() export class SocialAuthService { + private readonly logger = new Logger(SocialAuthService.name); + constructor( @InjectRepository(User) private readonly users: Repository, + private readonly encryptionService: EncryptionService, ) {} async findOrCreateFromProvider(profile: SocialProfile): Promise { @@ -37,8 +62,8 @@ export class SocialAuthService { } byEmail.provider = profile.provider; byEmail.providerId = profile.providerId; - byEmail.providerAccessToken = profile.accessToken ?? null; - byEmail.providerRefreshToken = profile.refreshToken ?? null; + byEmail.providerAccessToken = this.encryptToken(profile.accessToken); + byEmail.providerRefreshToken = this.encryptToken(profile.refreshToken); if (profile.picture && !byEmail.profilePicture) { byEmail.profilePicture = profile.picture; } @@ -59,8 +84,8 @@ export class SocialAuthService { profilePicture: profile.picture, provider: profile.provider, providerId: profile.providerId, - providerAccessToken: profile.accessToken ?? null, - providerRefreshToken: profile.refreshToken ?? null, + providerAccessToken: this.encryptToken(profile.accessToken), + providerRefreshToken: this.encryptToken(profile.refreshToken), isEmailVerified: true, }); return this.users.save(user); @@ -70,8 +95,8 @@ export class SocialAuthService { const user = await this.users.findOneOrFail({ where: { id: userId } }); user.provider = profile.provider; user.providerId = profile.providerId; - user.providerAccessToken = profile.accessToken ?? null; - user.providerRefreshToken = profile.refreshToken ?? null; + user.providerAccessToken = this.encryptToken(profile.accessToken); + user.providerRefreshToken = this.encryptToken(profile.refreshToken); return this.users.save(user); } @@ -83,4 +108,82 @@ export class SocialAuthService { user.providerRefreshToken = null; return this.users.save(user); } + + /** + * Returns the plaintext OAuth access token for `userId`, or `null` if none is + * stored. Decrypts transparently; throws on encrypted-but-malformed input so + * a corrupt DB row is surfaced loudly rather than silently misleading callers. + */ + async getDecryptedAccessToken(userId: string): Promise { + const user = await this.users.findOne({ where: { id: userId } }); + return this.decryptStoredToken(user?.providerAccessToken); + } + + /** + * Returns the plaintext OAuth refresh token for `userId`, or `null` if none + * is stored. See {@link SocialAuthService.getDecryptedAccessToken}. + */ + async getDecryptedRefreshToken(userId: string): Promise { + const user = await this.users.findOne({ where: { id: userId } }); + return this.decryptStoredToken(user?.providerRefreshToken); + } + + /** + * Reads the User row once and decrypts BOTH provider tokens in a single + * round-trip. Prefer this over calling the individual helpers if you need + * both tokens — it avoids a duplicate DB query. + */ + async getDecryptedProviderTokens( + userId: string, + ): Promise<{ access: string | null; refresh: string | null }> { + const user = await this.users.findOne({ where: { id: userId } }); + return { + access: this.decryptStoredToken(user?.providerAccessToken), + refresh: this.decryptStoredToken(user?.providerRefreshToken), + }; + } + + /** + * Encrypts an OAuth token; returns `null` for missing input so the DB column + * can carry the same shape regardless of whether the provider issued a token. + * JSON serialisation is used (instead of `iv.content.tag`) so the format is + * robust against any future hex encoding change. + */ + private encryptToken(rawToken: string | undefined): string | null { + if (!rawToken) return null; + const payload = this.encryptionService.encrypt(rawToken); + return ENCRYPTED_PREFIX + JSON.stringify(payload); + } + + /** + * Reverse of {@link SocialAuthService.encryptToken}. Returns `null` for + * missing input. Treats legacy plaintext values (no prefix) as `null` — they + * are unusable and reading them would just leak the (already compromised) + * plaintext to the caller. Throws if the value looks encrypted but + * decryption fails so operator is alerted to corruption. + */ + private decryptStoredToken(stored: string | null | undefined): string | null { + if (!stored) return null; + if (!stored.startsWith(ENCRYPTED_PREFIX)) { + this.logger.warn( + 'Legacy plaintext OAuth token encountered on a User row. Treating as unusable; please run the encryption migration.', + ); + return null; + } + let payload: IEncryptedPayload; + try { + payload = JSON.parse(stored.slice(ENCRYPTED_PREFIX.length)); + } catch { + throw new Error('Malformed encrypted OAuth token payload'); + } + if ( + !payload || + typeof payload.iv !== 'string' || + typeof payload.content !== 'string' || + typeof payload.tag !== 'string' + ) { + throw new Error('Malformed encrypted OAuth token payload'); + } + return this.encryptionService.decrypt(payload); + } } diff --git a/src/migrations/1783000000000-clear-plaintext-auth-tokens.ts b/src/migrations/1783000000000-clear-plaintext-auth-tokens.ts new file mode 100644 index 00000000..013cbda4 --- /dev/null +++ b/src/migrations/1783000000000-clear-plaintext-auth-tokens.ts @@ -0,0 +1,33 @@ +import { MigrationInterface, QueryRunner } from 'typeorm'; + +/** + * Issue #801 — clears any plaintext `passwordResetToken` / + * `emailVerificationToken` rows that pre-date the SHA-256 hashing migration. + * + * The legacy column values are unrecoverable back to hash form (we never knew + * the plaintext). Setting them to NULL is the correct action: any pre-existing + * reset / verification links that the user had in their inbox simply stop + * working, and the user re-requests a new token — which is now generated and + * stored correctly via {@link AuthTokensService}. + * + * Expiry timestamps are also cleared so the index lookups (used by + * `consumePasswordReset` / `consumeEmailVerification`) don't return rows + * with a NULL token but a non-NULL expiry that would confuse downstream logic. + */ +export class ClearPlaintextAuthTokens1783000000000 implements MigrationInterface { + public async up(queryRunner: QueryRunner): Promise { + await queryRunner.query(` + UPDATE users + SET "passwordResetToken" = NULL, + "passwordResetExpires" = NULL, + "emailVerificationToken" = NULL, + "emailVerificationExpires" = NULL + WHERE "passwordResetToken" IS NOT NULL + OR "emailVerificationToken" IS NOT NULL + `); + } + + public async down(): Promise { + // No-op: the cleared plaintext tokens cannot be restored. + } +} diff --git a/src/migrations/1783000000001-reencrypt-oauth-provider-tokens.ts b/src/migrations/1783000000001-reencrypt-oauth-provider-tokens.ts new file mode 100644 index 00000000..84bf0dde --- /dev/null +++ b/src/migrations/1783000000001-reencrypt-oauth-provider-tokens.ts @@ -0,0 +1,97 @@ +import { MigrationInterface, QueryRunner } from 'typeorm'; +import * as crypto from 'crypto'; + +/** + * Issue #799 — re-encrypts any plaintext `providerAccessToken` / + * `providerRefreshToken` rows that pre-date the at-rest encryption rollout. + * + * Why this is a Node-side migration rather than pure SQL: + * The existing values are plaintext application secrets. Re-encrypting them + * requires the same `EncryptionService` (AES-256-GCM) that runtime code uses, + * which in turn needs `ENCRYPTION_SECRET`. Pure SQL cannot derive a 256-bit + * key from a passphrase and (more importantly) should not have access to it. + * + * Strategy: + * 1. Materialise every plaintext OAuth token into JS. + * 2. Encrypt each value with AES-256-GCM using a random IV per row. + * 3. Persist the result back as `enc:` so {@link SocialAuthService} + * recognises and decrypts it on read. + * 4. Skip values that already carry the `enc:` prefix (idempotency for + * re-runs) and skip NULLs. + * + * The migration throws if `ENCRYPTION_SECRET` is missing so the deploy + * pipeline fails LOUD before writing anything. + * + * Equivalent code path in the application: + * {@link SocialAuthService.encryptToken} (for format reference). + */ +export class ReencryptOAuthProviderTokens1783000000001 implements MigrationInterface { + public async up(queryRunner: QueryRunner): Promise { + const secret = process.env.ENCRYPTION_SECRET; + if (!secret) { + throw new Error( + 'ENCRYPTION_SECRET must be set in the migration environment to run Issue #799 re-encryption.', + ); + } + const key = crypto.createHash('sha256').update(secret).digest(); + + const rows: Array<{ + id: string; + providerAccessToken: string | null; + providerRefreshToken: string | null; + }> = await queryRunner.query( + `SELECT id, + "providerAccessToken", + "providerRefreshToken" + FROM users + WHERE ("providerAccessToken" IS NOT NULL AND "providerAccessToken" <> '') + OR ("providerRefreshToken" IS NOT NULL AND "providerRefreshToken" <> '')`, + ); + + if (rows.length === 0) { + return; + } + + for (const row of rows) { + const encryptedAccess = this.maybeEncrypt(row.providerAccessToken, key); + const encryptedRefresh = this.maybeEncrypt(row.providerRefreshToken, key); + await queryRunner.query( + `UPDATE users + SET "providerAccessToken" = $1, + "providerRefreshToken" = $2 + WHERE id = $3`, + [encryptedAccess, encryptedRefresh, row.id], + ); + } + } + + public async down(): Promise { + // No-op: AES-GCM ciphertext cannot be reversed without the key, and the + // migration is not the place to log or stash the raw values. + } + + private maybeEncrypt(stored: string | null, key: Buffer): string | null { + if (!stored) return null; + if (stored.startsWith('enc:')) { + // Already encrypted — leave as-is so re-runs are idempotent. + return stored; + } + return 'enc:' + JSON.stringify(this.aesGcmEncrypt(stored, key)); + } + + private aesGcmEncrypt(plaintext: string, key: Buffer): { + iv: string; + content: string; + tag: string; + } { + const iv = crypto.randomBytes(16); + const cipher = crypto.createCipheriv('aes-256-gcm', key, iv); + const encrypted = Buffer.concat([cipher.update(plaintext, 'utf8'), cipher.final()]); + const tag = cipher.getAuthTag(); + return { + iv: iv.toString('hex'), + content: encrypted.toString('hex'), + tag: tag.toString('hex'), + }; + } +} diff --git a/src/moderation/moderation.module.ts b/src/moderation/moderation.module.ts index f461d3e0..e5353b5c 100644 --- a/src/moderation/moderation.module.ts +++ b/src/moderation/moderation.module.ts @@ -1,6 +1,12 @@ import { Module } from '@nestjs/common'; import { TypeOrmModule } from '@nestjs/typeorm'; +import { HttpModule } from '@nestjs/axios'; import { ContentSafetyService } from './safety/content-safety.service'; +import { + EXTERNAL_MODERATION_PROVIDER, + ExternalModerationProvider, +} from './safety/external-moderation.provider'; +import { OpenAiModerationAdapter } from './safety/openai-moderation.adapter'; import { AutoModerationService } from './auto/auto-moderation.service'; import { ManualReviewService } from './manual/manual-review.service'; import { ReviewItem } from './manual/review-item.entity'; @@ -15,15 +21,29 @@ import { NotificationsModule } from '../notifications/notifications.module'; /** * Registers the moderation module, exposing content safety and review services. + * + * Issue #805 — wires the OpenAI adapter behind the + * {@link EXTERNAL_MODERATION_PROVIDER} token via `useExisting`, so consumers + * inject the interface alone. Swapping to a different provider (AWS Rekognition, + * Perspective, …) is a one-line change in this module. */ @Module({ imports: [ TypeOrmModule.forFeature([ReviewItem, ModerationEvent, ContentReport, User]), NotificationsModule, + HttpModule.register({ + timeout: 5000, + maxRedirects: 0, + }), ], controllers: [ContentReportsController], providers: [ ContentSafetyService, + OpenAiModerationAdapter, + { + provide: EXTERNAL_MODERATION_PROVIDER, + useExisting: OpenAiModerationAdapter, + }, AutoModerationService, ManualReviewService, ModerationAnalyticsService, @@ -32,6 +52,10 @@ import { NotificationsModule } from '../notifications/notifications.module'; ], exports: [ ContentSafetyService, + { + provide: EXTERNAL_MODERATION_PROVIDER, + useExisting: OpenAiModerationAdapter, + }, AutoModerationService, ManualReviewService, ModerationAnalyticsService, @@ -40,3 +64,6 @@ import { NotificationsModule } from '../notifications/notifications.module'; ], }) export class ModerationModule {} + +// Re-export so consumers do not have to import from deep paths. +export type { ExternalModerationProvider }; diff --git a/src/moderation/safety/content-safety.service.spec.ts b/src/moderation/safety/content-safety.service.spec.ts index fbe4b7c3..8bb70743 100644 --- a/src/moderation/safety/content-safety.service.spec.ts +++ b/src/moderation/safety/content-safety.service.spec.ts @@ -1,64 +1,187 @@ import { Test, TestingModule } from '@nestjs/testing'; +import { ConfigService } from '@nestjs/config'; import { ContentSafetyService } from './content-safety.service'; +import { + EXTERNAL_MODERATION_PROVIDER, + ExternalModerationProvider, + ExternalModerationUnavailableError, +} from './external-moderation.provider'; +import { EnhancedCircuitBreakerService } from '../../common/services/circuit-breaker.service'; -describe('ContentSafetyService', () => { +function makeMockProvider( + impl: ExternalModerationProvider['scoreContent'], +): ExternalModerationProvider & jest.Mocked { + return { + name: 'mock-provider', + scoreContent: jest.fn().mockImplementation(impl), + } as ExternalModerationProvider & jest.Mocked; +} + +/** + * A circuit-breaker service that simply forwards execute() calls. We are not + * testing opossum here — that is the EnhancedCircuitBreakerService's job; we + * are testing that ContentSafetyService threads the fallback correctly. + * + * The execute() mock is typed as the interface method itself so we don't have + * to specify a concrete `T`. This sidesteps a TS2322 clash between the mock's + * inferred concrete `Promise` and the interface's generic ``. + */ +function makePassthroughBreaker(): Partial { + const fn = jest.fn( + async ( + _key: string, + op: () => Promise, + options: { fallback?: (err: Error) => unknown } = {}, + ): Promise => { + try { + return await op(); + } catch (err) { + if (options.fallback) return options.fallback(err as Error); + throw err; + } + }, + ); + return { execute: fn as unknown as EnhancedCircuitBreakerService['execute'] }; +} + +describe('ContentSafetyService (Issue #805 — external moderation with fallback)', () => { let service: ContentSafetyService; + let mockProvider: ReturnType; + let mockBreaker: ReturnType; - beforeEach(async () => { - const module: TestingModule = await Test.createTestingModule({ - providers: [ContentSafetyService], + async function buildModule(providerEnabled = true): Promise { + const moduleRef: TestingModule = await Test.createTestingModule({ + providers: [ + ContentSafetyService, + { + provide: EXTERNAL_MODERATION_PROVIDER, + useValue: mockProvider, + }, + { + provide: EnhancedCircuitBreakerService, + useValue: mockBreaker, + }, + { + provide: ConfigService, + useValue: { + get: jest.fn((key: string, fallback?: unknown) => { + if (key === 'OPENAI_MODERATION_ENABLED') return providerEnabled; + if (key === 'OPENAI_MODERATION_TIMEOUT_MS') return 500; + return fallback; + }), + }, + }, + ], }).compile(); - service = module.get(ContentSafetyService); - }); + service = moduleRef.get(ContentSafetyService); + } - it('should be defined', () => { - expect(service).toBeDefined(); + beforeEach(() => { + mockProvider = makeMockProvider(async () => 0); + mockBreaker = makePassthroughBreaker(); }); - describe('scoreContent', () => { - it('should return 0 for clean content', () => { - expect(service.scoreContent('This is a great lesson about JavaScript!')).toBe(0); + describe('keyword-only path', () => { + beforeEach(() => buildModule(false)); + + it('returns 0 for clean content', async () => { + expect(await service.scoreContent('JavaScript is great')).toBe(0); + }); + + it('flags violence', async () => { + expect(await service.scoreContent('this is violence')).toBeGreaterThan(0); + }); + + it('flags hate', async () => { + expect(await service.scoreContent('this is hate')).toBeGreaterThan(0); + }); + + it('flags explicit', async () => { + expect(await service.scoreContent('this is explicit')).toBeGreaterThan(0); + }); + + it('flags spam', async () => { + expect(await service.scoreContent('this is spam')).toBeGreaterThan(0); + }); + + it('flags scam', async () => { + expect(await service.scoreContent('this is a scam')).toBeGreaterThan(0); }); - it('should score content containing violence keyword', () => { - const score = service.scoreContent('This content contains violence'); - expect(score).toBeGreaterThan(0); - expect(score).toBeLessThanOrEqual(1); + it('caps at 1 for multi-keyword content', async () => { + expect(await service.scoreContent('violence hate explicit spam scam')).toBe(1); }); - it('should score content containing hate keyword', () => { - const score = service.scoreContent('This content promotes hate'); - expect(score).toBeGreaterThan(0); + it('returns 0 for empty/whitespace input', async () => { + expect(await service.scoreContent('')).toBe(0); + expect(await service.scoreContent(' ')).toBe(0); }); + }); + + describe('external provider path', () => { + beforeEach(() => buildModule(true)); - it('should score content containing explicit keyword', () => { - const score = service.scoreContent('This is explicit material'); - expect(score).toBeGreaterThan(0); + it('returns 0 when both provider and keyword report clean', async () => { + mockProvider.scoreContent.mockResolvedValue(0); + expect(await service.scoreContent('hello world')).toBe(0); + expect(mockProvider.scoreContent).toHaveBeenCalledWith('hello world'); }); - it('should score content containing spam keyword', () => { - const score = service.scoreContent('This is spam content'); - expect(score).toBeGreaterThan(0); + it('returns provider score when it is greater than keyword', async () => { + mockProvider.scoreContent.mockResolvedValue(0.95); + expect(await service.scoreContent('totally clean text')).toBe(0.95); }); - it('should score content containing scam keyword', () => { - const score = service.scoreContent('This is a scam'); - expect(score).toBeGreaterThan(0); + it('returns keyword score when it is greater (homoglyph bypass test)', async () => { + // Simulates: provider passes (e.g. it normalised text internally), + // but the legacy keyword regex still catches the un-normalised input. + mockProvider.scoreContent.mockResolvedValue(0); + expect(await service.scoreContent('contains violence')).toBeGreaterThan(0); + expect(await service.scoreContent('contains violence')).toBeCloseTo(0.8, 5); }); - it('should cap score at 1 for multiple violations', () => { - const score = service.scoreContent('violence hate explicit spam scam'); - expect(score).toBe(1); + it('falls back to keyword score when provider throws ExternalModerationUnavailableError', async () => { + mockProvider.scoreContent.mockRejectedValue( + new ExternalModerationUnavailableError('network down'), + ); + expect(await service.scoreContent('violence')).toBeGreaterThan(0); + expect(mockProvider.scoreContent).toHaveBeenCalled(); }); - it('should be case-insensitive', () => { - expect(service.scoreContent('VIOLENCE')).toBeGreaterThan(0); - expect(service.scoreContent('SPAM')).toBeGreaterThan(0); + it('falls back to keyword score on generic provider error', async () => { + mockProvider.scoreContent.mockRejectedValue(new Error('boom')); + expect(await service.scoreContent('violence')).toBeGreaterThan(0); }); - it('should return 0 for empty string', () => { - expect(service.scoreContent('')).toBe(0); + it('caps combined score at 1', async () => { + mockProvider.scoreContent.mockResolvedValue(1); + expect(await service.scoreContent('violence')).toBeLessThanOrEqual(1); + }); + + it('routes provider invocation through the static circuit-breaker key', async () => { + mockProvider.scoreContent.mockResolvedValue(0); + await service.scoreContent('hello'); + const breakerKey = (mockBreaker.execute as jest.Mock).mock.calls[0][0]; + expect(breakerKey).toBe(ContentSafetyService.CIRCUIT_BREAKER_KEY); + // Verify the key is static (constant) — dynamic keys would leak opossum memory. + expect(ContentSafetyService.CIRCUIT_BREAKER_KEY).toBe( + ContentSafetyService.CIRCUIT_BREAKER_KEY, + ); + }); + }); + + describe('homoglyph bypass regression (Issue #805 acceptance)', () => { + beforeEach(() => buildModule(true)); + + it('flags content with Unicode homoglyph substitution that the keyword regex would miss', async () => { + // The OpenAI adapter internally normalises the text and flags it. + // The legacy keyword regex would miss this because the lowercase form + // contains \uff56 (full-width 'v') instead of 'violence'. + mockProvider.scoreContent.mockResolvedValue(0.92); + const homoglyph = '\uff76' + 'iolence'; // full-width 'v' + 'iolence' + const score = await service.scoreContent(`user posted: ${homoglyph}`); + expect(score).toBeGreaterThanOrEqual(0.9); }); }); }); diff --git a/src/moderation/safety/content-safety.service.ts b/src/moderation/safety/content-safety.service.ts index 215deade..a1189243 100644 --- a/src/moderation/safety/content-safety.service.ts +++ b/src/moderation/safety/content-safety.service.ts @@ -1,17 +1,115 @@ -import { Injectable } from '@nestjs/common'; +import { Inject, Injectable, Logger } from '@nestjs/common'; +import { ConfigService } from '@nestjs/config'; +import { EnhancedCircuitBreakerService } from '../../common/services/circuit-breaker.service'; +import { + EXTERNAL_MODERATION_PROVIDER, + ExternalModerationProvider, + ModerationScore, +} from './external-moderation.provider'; /** - * Provides content Safety operations. + * Issue #805 — Content safety scorer with circuit-breaker-protected external + * provider and a synchronous keyword fallback. + * + * Why both: + * - `keywordScore()` is fast, deterministic, and wrong only in adversarial + * cases (Unicode homoglyphs, zero-width characters, deliberate misspelling). + * - `provider.scoreContent()` is correct under those adversarial cases but is + * a remote call — it can fail, time out, or be unavailable. + * + * Behaviour: + * - We call the provider through {@link EnhancedCircuitBreakerService} using a + * *static* breaker key. A static key is required because opossum stores + * breaker instances in a Map keyed by the key string — dynamic keys (e.g. + * per-IP) would leak heap. + * - On any {@link ExternalModerationUnavailableError} or breaker-open condition, + * the provided `fallback()` resolves with the keyword score so the request + * completes successfully with degraded accuracy. + * - When the provider succeeds we return `max(external, local)` so a clean + * external pass does NOT mask a homoglyphed keyword match. */ @Injectable() export class ContentSafetyService { /** - * Executes score Content. - * @param content The content. - * @returns The calculated numeric value. + * Static breaker key. Do NOT make this dynamic (e.g. per-IP / per-content), + * otherwise opossum's internal Map grows unbounded. */ - scoreContent(content: string): number { - // Simple scoring logic (replace with ML model later) + static readonly CIRCUIT_BREAKER_KEY = 'content-safety:external-moderation'; + + private readonly logger = new Logger(ContentSafetyService.name); + private readonly breakerTimeoutMs: number; + private readonly providerEnabled: boolean; + + constructor( + @Inject(EXTERNAL_MODERATION_PROVIDER) + private readonly provider: ExternalModerationProvider, + private readonly circuitBreaker: EnhancedCircuitBreakerService, + private readonly configService: ConfigService, + ) { + this.breakerTimeoutMs = this.configService.get( + 'OPENAI_MODERATION_TIMEOUT_MS', + 500, + ); + this.providerEnabled = this.configService.get( + 'OPENAI_MODERATION_ENABLED', + true, + ); + } + + /** + * Returns a safety score in [0, 1] (higher = more unsafe). + * + * Always resolves successfully — provider failure is masked by the + * circuit breaker fallback. Only throws if the breaker fallback itself + * throws, which should be impossible by construction. + */ + async scoreContent(content: string): Promise { + if (!content || !content.trim()) return 0; + + const keywordScore = this.keywordScore(content); + + if (!this.providerEnabled) { + // Feature off — pure keyword filter. + return keywordScore; + } + + try { + const external = await this.circuitBreaker.execute( + ContentSafetyService.CIRCUIT_BREAKER_KEY, + () => this.provider.scoreContent(content), + { + name: ContentSafetyService.CIRCUIT_BREAKER_KEY, + timeout: this.breakerTimeoutMs, + fallback: (err: Error) => { + this.logger.warn( + `External moderation unavailable (${err.message}); falling back to keyword filter`, + ); + return keywordScore; + }, + }, + ); + + // Combine: external pass with masked homoglyph still trips the keyword filter. + return Math.max(external, keywordScore); + } catch (err) { + // Last-resort net — if even the fallback threw (shouldn't happen), degrade + // to keyword-only and log loudly so an operator investigates. + this.logger.error( + `ContentSafetyService fallback chain failed: ${(err as Error).message}`, + ); + return keywordScore; + } + } + + /** + * Synchronous local keyword heuristic. Preserved as a public method so it + * stays callable from contexts without async (legacy callers, tests). + * Note: this is the *same* trivially-bypassable regex the issue identified — + * it is intentionally retained ONLY as a fallback. The primary scoring path + * is `scoreContent()`. + */ + keywordScore(content: string): ModerationScore { + if (!content) return 0; let score = 0; if (/violence|hate|explicit/i.test(content)) score += 0.8; if (/spam|scam/i.test(content)) score += 0.5; diff --git a/src/moderation/safety/external-moderation.provider.ts b/src/moderation/safety/external-moderation.provider.ts new file mode 100644 index 00000000..8d3885ec --- /dev/null +++ b/src/moderation/safety/external-moderation.provider.ts @@ -0,0 +1,61 @@ +/** + * Issue #805 — ContentSafetyService must use a real moderation provider rather than + * the trivially-bypassable keyword regex. + * + * This module defines the contract adapters (OpenAI, AWS Rekognition, Perspective, …) + * implement. Adapters MUST throw {@link ExternalModerationUnavailableError} on any + * failure mode that the caller can recover from (network error, auth error, timeout, + * malformed response) so the {@link ContentSafetyService} can degrade to the keyword + * filter via its circuit-breaker fallback instead of returning 500. + */ + +/** + * Score returned in the closed interval [0, 1]. + * - `0` means definitively safe. + * - `1` means definitively unsafe. + * Implementations are free to choose any continuous scale in between; the caller's + * threshold logic interprets the value. + */ +export type ModerationScore = number; + +/** + * Any external moderation provider must implement this contract. + * + * Implementations are responsible for normalising input text — including stripping + * Unicode homoglyphs, zero-width characters, and other tricks that the legacy + * keyword regex could not catch. Callers must NOT pre-normalise before invoking + * the adapter because different adapters may want different normalisation rules. + */ +export interface ExternalModerationProvider { + /** Stable identifier used for circuit-breaker keying and observability. */ + readonly name: string; + + /** + * Returns a {@link ModerationScore} for the given text. + * Throws {@link ExternalModerationUnavailableError} on recoverable failures so + * the caller can fall back without leaking a 500. + */ + scoreContent(text: string): Promise; +} + +/** + * Marker error for transient / recoverable provider failure. The caller (typically + * ContentSafetyService inside its circuit breaker) interprets this as "fall back + * gracefully" rather than "propagate to the user as 500". + */ +export class ExternalModerationUnavailableError extends Error { + constructor( + message: string, + public readonly cause?: unknown, + ) { + super(message); + this.name = 'ExternalModerationUnavailableError'; + } +} + +/** + * String DI token used to register / consume an {@link ExternalModerationProvider} in + * NestJS modules. A constant keeps the token name consistent across producers and + * consumers and prevents accidental typos. + */ +export const EXTERNAL_MODERATION_PROVIDER = 'ExternalModerationProvider'; diff --git a/src/moderation/safety/openai-moderation.adapter.ts b/src/moderation/safety/openai-moderation.adapter.ts new file mode 100644 index 00000000..9abab27e --- /dev/null +++ b/src/moderation/safety/openai-moderation.adapter.ts @@ -0,0 +1,146 @@ +import { Injectable, Logger } from '@nestjs/common'; +import { ConfigService } from '@nestjs/config'; +import { HttpService } from '@nestjs/axios'; +import { firstValueFrom, throwError, TimeoutError } from 'rxjs'; +import { catchError, timeout } from 'rxjs/operators'; +import { AxiosError, AxiosResponse } from 'axios'; +import { + ExternalModerationProvider, + ExternalModerationUnavailableError, + ModerationScore, +} from './external-moderation.provider'; + +interface OpenAIModerationCategories { + [category: string]: boolean; +} + +interface OpenAIModerationCategoryScores { + [category: string]: number; +} + +interface OpenAIModerationResult { + flagged?: boolean; + categories?: OpenAIModerationCategories; + category_scores?: OpenAIModerationCategoryScores; +} + +interface OpenAIModerationResponse { + results?: OpenAIModerationResult[]; +} + +/** + * Issue #805 — OpenAI moderation adapter. + * + * Calls `POST /v1/moderations` and translates the response into a single + * [0, 1] {@link ModerationScore}. The mapping is intentionally simple: take the + * max of `category_scores` (the highest-confidence category drives the score). + * When the response is shape-malformed we treat it as unavailability rather than + * "safe", because false negatives are worse than degraded UX. + * + * Any failure mode (missing key, timeout, HTTP error, malformed response) is + * converted into {@link ExternalModerationUnavailableError} so the caller can + * fall back to the keyword filter without a 500. + */ +@Injectable() +export class OpenAiModerationAdapter implements ExternalModerationProvider { + public readonly name = 'openai-moderation'; + + private readonly logger = new Logger(OpenAiModerationAdapter.name); + private readonly endpoint = 'https://api.openai.com/v1/moderations'; + + constructor( + private readonly httpService: HttpService, + private readonly configService: ConfigService, + ) {} + + async scoreContent(text: string): Promise { + const apiKey = this.configService.get('OPENAI_API_KEY'); + const timeoutMs = this.configService.get( + 'OPENAI_MODERATION_TIMEOUT_MS', + 500, + ); + + if (!apiKey) { + throw new ExternalModerationUnavailableError( + 'OPENAI_API_KEY is not configured', + ); + } + if (typeof text !== 'string' || text.length === 0) { + return 0; + } + + let response: AxiosResponse; + try { + response = await firstValueFrom( + this.httpService + .post( + this.endpoint, + { input: text }, + { headers: { Authorization: `Bearer ${apiKey}` } }, + ) + .pipe( + timeout(timeoutMs), + catchError((err) => { + if (err instanceof TimeoutError) { + return throwError( + () => + new ExternalModerationUnavailableError( + `OpenAI moderation timed out after ${timeoutMs}ms`, + err, + ), + ); + } + if (err instanceof AxiosError) { + return throwError( + () => + new ExternalModerationUnavailableError( + `OpenAI moderation HTTP error (${err.response?.status ?? 'no-status'}): ${err.message}`, + err, + ), + ); + } + return throwError( + () => + new ExternalModerationUnavailableError( + `OpenAI moderation failed: ${(err as Error).message}`, + err, + ), + ); + }), + ), + ); + } catch (err) { + if (err instanceof ExternalModerationUnavailableError) { + throw err; + } + throw new ExternalModerationUnavailableError( + `OpenAI moderation unexpected error: ${(err as Error).message}`, + err, + ); + } + + const result = response.data?.results?.[0]; + if (!result) { + // Treat empty results as unavailability — better to false-positive than + // to silently approve everything because the provider changed shape. + throw new ExternalModerationUnavailableError( + 'OpenAI moderation returned no results array', + ); + } + + const scores = Object.values(result.category_scores ?? {}); + if (scores.length === 0) { + // No per-category scores — fall back to the boolean flag with half-weight. + return result.flagged === true ? 0.8 : 0; + } + const maxScore = scores.reduce((acc, v) => (v > acc ? v : acc), 0); + return clamp01(maxScore); + } +} + +function clamp01(value: number): number { + if (Number.isNaN(value)) return 0; + if (value < 0) return 0; + if (value > 1) return 1; + return value; +} diff --git a/src/security/security.module.ts b/src/security/security.module.ts index c6454ff8..df72cee5 100644 --- a/src/security/security.module.ts +++ b/src/security/security.module.ts @@ -1,24 +1,41 @@ import { Module } from '@nestjs/common'; import { ScheduleModule } from '@nestjs/schedule'; +import { ConfigService } from '@nestjs/config'; import { SecurityService } from './security.service'; import { EncryptionService } from './encryption/encryption.service'; import { ThreatDetectionService } from './threats/threat-detection.service'; +import { THREAT_REDIS_CLIENT } from './threats/threat-detection.constants'; import { ComplianceService } from './compliance/compliance.service'; import { AuditLoggingService } from './audit/audit-logging.service'; import { SecretsModule } from './secrets/secrets.module'; +import { getSharedRedisClient } from '../config/cache.config'; /** * Registers the security module. + * + * Issue #798 — wires the shared Redis client behind `THREAT_REDIS_CLIENT` + * so {@link ThreatDetectionService} uses a distributed store instead of an + * in-process Map. The token keeps the dependency mockable in unit tests. */ @Module({ imports: [ScheduleModule.forRoot(), SecretsModule], providers: [ SecurityService, EncryptionService, + { + provide: THREAT_REDIS_CLIENT, + useFactory: (configService: ConfigService) => getSharedRedisClient(configService), + inject: [ConfigService], + }, ThreatDetectionService, ComplianceService, AuditLoggingService, ], - exports: [SecurityService, EncryptionService, SecretsModule], + exports: [ + SecurityService, + EncryptionService, + SecretsModule, + THREAT_REDIS_CLIENT, + ], }) export class SecurityModule {} diff --git a/src/security/threats/threat-detection.constants.ts b/src/security/threats/threat-detection.constants.ts new file mode 100644 index 00000000..deba7ebe --- /dev/null +++ b/src/security/threats/threat-detection.constants.ts @@ -0,0 +1,8 @@ +/** + * Issue #798 — DI tokens for {@link ThreatDetectionService}. + * + * String token (rather than the `Redis` class directly) so the test suite can + * supply a `createMockRedisClient()` instance without instantiating a real + * `ioredis.Redis`. `SecurityModule` binds this token to `getSharedRedisClient()`. + */ +export const THREAT_REDIS_CLIENT = 'THREAT_REDIS_CLIENT'; diff --git a/src/security/threats/threat-detection.service.spec.ts b/src/security/threats/threat-detection.service.spec.ts index 4b102af6..73727271 100644 --- a/src/security/threats/threat-detection.service.spec.ts +++ b/src/security/threats/threat-detection.service.spec.ts @@ -1,106 +1,144 @@ +import { Test, TestingModule } from '@nestjs/testing'; +import { ConfigService } from '@nestjs/config'; +import { ForbiddenOperationException } from '../../common/exceptions/app.exceptions'; import { ThreatDetectionService } from './threat-detection.service'; +import { THREAT_REDIS_CLIENT } from './threat-detection.constants'; +import { createMockRedisClient } from '../../../test/utils/mock-factories'; /** - * Helper: build a string IP for index `n` so we can deterministically - * know which entry should be the "oldest" (first inserted). + * Helper: build a Test module that wires the service with the supplied + * Redis mock + ConfigService. Centralised because every test in the suite + * builds this module the same way. */ -function ipFor(index: number): string { - return `10.0.${Math.floor(index / 256) % 256}.${index % 256}`; +async function buildModule(redis: ReturnType) { + const moduleRef: TestingModule = await Test.createTestingModule({ + providers: [ + ThreatDetectionService, + { provide: THREAT_REDIS_CLIENT, useValue: redis }, + { + provide: ConfigService, + useValue: { + get: jest.fn((key: string, fallback?: unknown) => { + // Use the documented defaults; tests can override via direct injection. + if (key === 'THREAT_FAILED_ATTEMPT_THRESHOLD') return 10; + if (key === 'THREAT_FAILED_ATTEMPT_WINDOW_SECONDS') return 15 * 60; + if (key === 'THREAT_FAILED_ATTEMPT_KEY_PREFIX') + return 'threat:failed-attempts:'; + return fallback; + }), + }, + }, + ], + }).compile(); + return moduleRef.get(ThreatDetectionService); } -describe('ThreatDetectionService', () => { - describe('behaviour (preserves existing semantics)', () => { - it('does not throw before the failure threshold is reached', () => { - const svc = new ThreatDetectionService({ max: 100, ttlMs: 60_000 }); - const ip = '192.168.0.1'; +describe('ThreatDetectionService (Issue #798 — Redis-backed counters)', () => { + let service: ThreatDetectionService; + let redis: ReturnType; - // 11 attempts is still allowed (attempts > 10 means 11+ throws) - for (let i = 0; i < 10; i++) svc.recordFailure(ip); - expect(() => svc.analyzeRequest(ip)).not.toThrow(); - }); + beforeEach(async () => { + redis = createMockRedisClient(); + service = await buildModule(redis); + }); - it('throws ForbiddenOperationException once attempts exceed 10', () => { - const svc = new ThreatDetectionService({ max: 100, ttlMs: 60_000 }); - const ip = '192.168.0.2'; + describe('recordFailure — INCR + first-call EXPIRE', () => { + it('calls INCR with the per-IP key', async () => { + redis.incr.mockResolvedValueOnce(1); - for (let i = 0; i < 11; i++) svc.recordFailure(ip); - expect(() => svc.analyzeRequest(ip)).toThrow(/Suspicious activity detected/); + await service.recordFailure('192.168.0.1'); + + expect(redis.incr).toHaveBeenCalledWith('threat:failed-attempts:192.168.0.1'); }); - it('clears the failure counter on reset()', () => { - const svc = new ThreatDetectionService({ max: 100, ttlMs: 60_000 }); - const ip = '192.168.0.3'; + it('sets EXPIRE on the first failure (INCR returned 1)', async () => { + redis.incr.mockResolvedValueOnce(1); - for (let i = 0; i < 11; i++) svc.recordFailure(ip); - expect(() => svc.analyzeRequest(ip)).toThrow(); + await service.recordFailure('192.168.0.1'); - svc.reset(ip); - expect(() => svc.analyzeRequest(ip)).not.toThrow(); - expect(svc.has(ip)).toBe(false); + expect(redis.expire).toHaveBeenCalledWith( + 'threat:failed-attempts:192.168.0.1', + 15 * 60, + ); }); - }); - describe('bounded cap (issue #882 acceptance criterion: 50k max)', () => { - it('caps the cache at the configured maximum entries', () => { - const cap = 50_000; - const svc = new ThreatDetectionService({ max: cap, ttlMs: 60 * 60 * 1000 }); - - for (let i = 0; i < cap + 1; i++) { - svc.recordFailure(ipFor(i)); - } + it('does NOT set EXPIRE on subsequent calls in the same window', async () => { + redis.incr.mockResolvedValueOnce(2).mockResolvedValueOnce(3); + await service.recordFailure('192.168.0.1'); + await service.recordFailure('192.168.0.1'); + expect(redis.expire).not.toHaveBeenCalled(); + }); - // Bounded at exactly the cap (spec: "Map size is bounded at 50,000 entries") - expect(svc.getCacheSize()).toBe(cap); + it('does not throw when Redis INCR fails (fails open for tracking)', async () => { + redis.incr.mockRejectedValueOnce(new Error('connection lost')); + await expect(service.recordFailure('192.168.0.1')).resolves.toBeUndefined(); }); + }); - it('evicts the oldest entry when inserting the (cap+1)-th entry', () => { - const cap = 50_000; - const svc = new ThreatDetectionService({ max: cap, ttlMs: 60 * 60 * 1000 }); + describe('analyzeRequest — GET + threshold check', () => { + it('does not throw when no failure counter is stored', async () => { + redis.get.mockResolvedValueOnce(null); + await expect(service.analyzeRequest('192.168.0.2')).resolves.toBeUndefined(); + }); - const firstIp = ipFor(0); + it('does not throw while count is at or below the threshold', async () => { + redis.get.mockResolvedValueOnce('10'); + await expect(service.analyzeRequest('192.168.0.2')).resolves.toBeUndefined(); + }); - // Fill to capacity - for (let i = 0; i < cap; i++) { - svc.recordFailure(ipFor(i)); - } - expect(svc.has(firstIp)).toBe(true); + it('throws ForbiddenOperationException when count exceeds the threshold', async () => { + redis.get.mockResolvedValueOnce('11'); + await expect(service.analyzeRequest('192.168.0.2')).rejects.toBeInstanceOf( + ForbiddenOperationException, + ); + }); - // The (cap+1)-th insertion triggers LRU eviction; the oldest entry - // (the first one we inserted) should be gone. - svc.recordFailure(ipFor(cap)); + it('fails open when Redis GET errors (does not block legitimate traffic)', async () => { + redis.get.mockRejectedValueOnce(new Error('connection lost')); + await expect(service.analyzeRequest('192.168.0.2')).resolves.toBeUndefined(); + }); + }); - expect(svc.has(firstIp)).toBe(false); - expect(svc.getCacheSize()).toBe(cap); + describe('reset', () => { + it('DELs the per-IP key', async () => { + redis.del.mockResolvedValueOnce(1); + await service.reset('192.168.0.3'); + expect(redis.del).toHaveBeenCalledWith('threat:failed-attempts:192.168.0.3'); }); - it('uses the documented 50,000 entry cap when no options are provided', () => { - const svc = new ThreatDetectionService(); - expect(svc.getCacheSize()).toBe(0); - // Sanity: the default must match the documented value. - expect(ThreatDetectionService.MAX_ENTRIES).toBe(50_000); + it('does not throw when Redis DEL fails', async () => { + redis.del.mockRejectedValueOnce(new Error('connection lost')); + await expect(service.reset('192.168.0.3')).resolves.toBeUndefined(); }); }); - describe('TTL (issue #882 acceptance criterion: 15-minute expiry)', () => { - it('expires entries after the configured TTL has elapsed', async () => { - // Tiny TTL keeps the test fast while still exercising the same code path. - const ttlMs = 30; - const svc = new ThreatDetectionService({ max: 100, ttlMs }); + describe('expiry semantics (Issue #798 acceptance)', () => { + it('a failure counter that has expired no longer triggers analyseRequest', async () => { + // Simulate an empty store: Redis returned null after the previous key + // expired — meaning the failure window cleared correctly. + redis.get.mockResolvedValue(null); + await expect(service.analyzeRequest('192.168.0.4')).resolves.toBeUndefined(); + }); + }); - const ip = '192.168.0.42'; - for (let i = 0; i < 11; i++) svc.recordFailure(ip); - expect(() => svc.analyzeRequest(ip)).toThrow(); + describe('introspection helpers', () => { + it('resolveKey returns the prefixed Redis key', () => { + expect(service.resolveKey('10.0.0.1')).toBe('threat:failed-attempts:10.0.0.1'); + }); - // Wait past the TTL so the entry is reaped. - await new Promise((resolve) => setTimeout(resolve, ttlMs + 50)); + it('has() returns true when EXISTS returns > 0', async () => { + redis.exists.mockResolvedValueOnce(1); + expect(await service.has('10.0.0.1')).toBe(true); + }); - // After expiry the entry is gone — analyseRequest should not throw. - expect(() => svc.analyzeRequest(ip)).not.toThrow(); - expect(svc.has(ip)).toBe(false); + it('has() returns false when EXISTS returns 0', async () => { + redis.exists.mockResolvedValueOnce(0); + expect(await service.has('10.0.0.1')).toBe(false); }); - it('uses a 15-minute TTL by default', () => { - expect(ThreatDetectionService.TTL_MS).toBe(15 * 60 * 1000); + it('has() degrades to false on Redis error', async () => { + redis.exists.mockRejectedValueOnce(new Error('boom')); + expect(await service.has('10.0.0.1')).toBe(false); }); }); }); diff --git a/src/security/threats/threat-detection.service.ts b/src/security/threats/threat-detection.service.ts index 25e7da6c..d3eda236 100644 --- a/src/security/threats/threat-detection.service.ts +++ b/src/security/threats/threat-detection.service.ts @@ -1,79 +1,143 @@ -import { Injectable, Logger, Optional } from '@nestjs/common'; -import { LRUCache } from 'lru-cache'; +import { Inject, Injectable, Logger } from '@nestjs/common'; +import { ConfigService } from '@nestjs/config'; +import { Redis } from 'ioredis'; import { ForbiddenOperationException } from '../../common/exceptions/app.exceptions'; +import { THREAT_REDIS_CLIENT } from './threat-detection.constants'; /** - * Provides threat Detection operations. + * Issue #798 — Per-IP failed-attempt counter. * - * Tracks per-IP failed attempt counts in a bounded LRU cache so that - * IP-rotation attacks or large user bases cannot cause the heap to grow - * unbounded (see issue #882). + * Why Redis, not an in-process `Map`: + * - In a horizontally-scaled deployment, each pod had its own counter. + * An attacker spreading requests across pods flew under the per-pod + * threshold while running a full credential-stuffing attack. + * - The in-process map also grew without bound. * - * The cache is capped at `MAX_ENTRIES` entries and each entry expires - * `TTL_MS` after it was last written. Eviction of an entry (whether due - * to the cap or TTL expiry) emits a single warning log so operators can - * detect sustained pressure on the structure. + * Implementation: + * - `INCR ${key}` is atomic across all replicas; the result IS the current + * counter value. + * - On the first call (when `INCR` returns 1) we set `EXPIRE ${key} ttlSeconds` + * so the counter auto-clears once the window elapses. + * - Threshold (count above which we refuse), window length, and key prefix + * are configurable via {@link ConfigService}. */ @Injectable() export class ThreatDetectionService { - /** Max number of tracked IPs. Tuned for memory-bounded operation. */ - static readonly MAX_ENTRIES = 50_000; - /** 15-minute TTL on each entry. */ - static readonly TTL_MS = 15 * 60 * 1000; + static readonly DEFAULT_THRESHOLD = 10; + static readonly DEFAULT_WINDOW_SECONDS = 15 * 60; // 15 minutes + static readonly DEFAULT_KEY_PREFIX = 'threat:failed-attempts:'; private readonly logger = new Logger(ThreatDetectionService.name); - private readonly failedAttempts: LRUCache; - private lastEvictionWarnAt = 0; + private readonly threshold: number; + private readonly windowSeconds: number; + private readonly keyPrefix: string; - constructor(@Optional() options?: { max?: number; ttlMs?: number }) { - const max = options?.max ?? ThreatDetectionService.MAX_ENTRIES; - const ttl = options?.ttlMs ?? ThreatDetectionService.TTL_MS; + constructor( + @Inject(THREAT_REDIS_CLIENT) private readonly redis: Redis, + private readonly configService: ConfigService, + ) { + this.threshold = this.configService.get( + 'THREAT_FAILED_ATTEMPT_THRESHOLD', + ThreatDetectionService.DEFAULT_THRESHOLD, + ); + this.windowSeconds = this.configService.get( + 'THREAT_FAILED_ATTEMPT_WINDOW_SECONDS', + ThreatDetectionService.DEFAULT_WINDOW_SECONDS, + ); + this.keyPrefix = this.configService.get( + 'THREAT_FAILED_ATTEMPT_KEY_PREFIX', + ThreatDetectionService.DEFAULT_KEY_PREFIX, + ); + } - // `lru-cache` v11 fires the `dispose` callback once per eviction. We - // rate-limit the warn log to once per 60 s so a flood of evictions does - // not amplify the very load we are trying to detect. - this.failedAttempts = new LRUCache({ - max, - ttl, - ttlAutopurge: true, - updateAgeOnGet: false, - dispose: (_value, _key, reason) => { - if (reason !== 'evict') return; - const now = Date.now(); - if (now - this.lastEvictionWarnAt < 60_000) return; - this.lastEvictionWarnAt = now; - this.logger.warn( - `LRU eviction triggered on failedAttempts cache (cap=${max}). ` + - 'Sustained pressure indicates a potential IP-rotation attack; ' + - 'consider raising MAX_ENTRIES or migrating to a distributed store.', - ); - }, - }); + private keyFor(ip: string): string { + return `${this.keyPrefix}${ip}`; } - analyzeRequest(ip: string): void { - const attempts = this.failedAttempts.get(ip) || 0; - if (attempts > 10) { + /** + * Refuses the request if the IP currently has more than `threshold` failures + * recorded in the rolling Redis window. A failure count strictly greater than + * the configured threshold triggers {@link ForbiddenOperationException}. + * + * Note: this is now async because the underlying store is remote. Callers + * (guards, middleware) must await. We deliberately fail OPEN on Redis errors + * so an outage cannot amplify load by blocking legitimate traffic. + */ + async analyzeRequest(ip: string): Promise { + const key = this.keyFor(ip); + let count: number; + try { + const raw = await this.redis.get(key); + count = raw ? Number(raw) : 0; + } catch (err) { + this.logger.error( + `analyzeRequest: Redis GET failed (${(err as Error).message}); failing open.`, + ); + return; + } + if (count > this.threshold) { throw new ForbiddenOperationException('Suspicious activity detected'); } } - recordFailure(ip: string): void { - const attempts = this.failedAttempts.get(ip) || 0; - this.failedAttempts.set(ip, attempts + 1); + /** + * Atomically increments the IP's failure counter. On the first increment + * (the INCR returned 1) we install the TTL so the counter auto-expires. + * + * The TTL is set AFTER the INCR so a Redis outage between the two commands + * cannot leave behind a permanent counter — at worst the counter survives + * forever, which degrades to the same behaviour as the legacy in-memory + * map (a non-zero counter is still better than nothing). + */ + async recordFailure(ip: string): Promise { + const key = this.keyFor(ip); + try { + const count = await this.redis.incr(key); + if (count === 1) { + // First failure in this window — arm the auto-expiry. + await this.redis.expire(key, this.windowSeconds); + } + } catch (err) { + // Recording failures must not throw — losing a single increment is + // acceptable; throwing would amplify the very load we are tracking. + this.logger.error( + `recordFailure: Redis INCR failed (${(err as Error).message}); dropping.`, + ); + } } - reset(ip: string): void { - this.failedAttempts.delete(ip); + /** + * Clears the IP's counter (e.g. after a successful authentication). Best + * effort: a Redis outage here is logged but does not throw. + */ + async reset(ip: string): Promise { + const key = this.keyFor(ip); + try { + await this.redis.del(key); + } catch (err) { + this.logger.error( + `reset: Redis DEL failed (${(err as Error).message}); counter may persist briefly.`, + ); + } } + // ─── Test introspection helpers ───────────────────────────────────────── + // Kept on the public so the unit suite can validate the key shape and + // existence semantics without poking at Redis internals. + /** Test introspection helper — not used by production callers. */ - getCacheSize(): number { - return this.failedAttempts.size; + resolveKey(ip: string): string { + return this.keyFor(ip); } - /** Test introspection helper — checks for presence in the bounded cache. */ - has(ip: string): boolean { - return this.failedAttempts.has(ip); + /** Test introspection helper — composes `KEY`-then-`EXISTS`. */ + async has(ip: string): Promise { + try { + const result = await this.redis.exists(this.keyFor(ip)); + return result > 0; + } catch (err) { + this.logger.error(`has: Redis EXISTS failed (${(err as Error).message}).`); + return false; + } } }