diff --git a/backend/prisma/schema.prisma b/backend/prisma/schema.prisma index 9f658a2c..1945c966 100644 --- a/backend/prisma/schema.prisma +++ b/backend/prisma/schema.prisma @@ -1973,3 +1973,113 @@ model FeeChangeLog { @@index([createdAt]) @@map("fee_change_logs") } + +// ─── Issue #447: Smart Contract Event Indexer ───────────────────────────────── + +enum IndexedEventChain { + stellar + evm +} + +model IndexedEvent { + id String @id @default(uuid()) + dedupKey String @unique @map("dedup_key") // chain:txHash[:logIndex] + chain IndexedEventChain + contractAddress String @map("contract_address") + eventType String @map("event_type") + blockNumber BigInt @map("block_number") + txHash String @map("tx_hash") + timestamp DateTime + payload Json + confirmations Int @default(0) + retentionUntil DateTime @map("retention_until") + createdAt DateTime @default(now()) @map("created_at") + + @@index([chain, contractAddress, eventType]) + @@index([chain, contractAddress, timestamp]) + @@index([eventType]) + @@index([timestamp]) + @@index([retentionUntil]) + @@map("indexed_events") +} + +model WsEventSubscription { + id String @id @default(uuid()) + sessionId String @map("session_id") + contractAddress String? @map("contract_address") + eventType String? @map("event_type") + chain String? + createdAt DateTime @default(now()) @map("created_at") + expiresAt DateTime @map("expires_at") + + @@index([sessionId]) + @@index([expiresAt]) + @@map("ws_event_subscriptions") +} + +// ─── Issue #446: AI Payment Routing Metrics ──────────────────────────────────── + +model RoutingDecision { + id String @id @default(uuid()) + tenantId String? @map("tenant_id") + requestId String @unique @map("request_id") + selectedChain String @map("selected_chain") + fallbackChains String[] @map("fallback_chains") + scoreStellar Float? @map("score_stellar") + scoreEvm Float? @map("score_evm") + featureSnapshot Json @map("feature_snapshot") // gas prices, latencies, success rates at decision time + rationale String? + latencyMs Int? @map("latency_ms") + isManualOverride Boolean @default(false) @map("is_manual_override") + overrideBy String? @map("override_by") + abVariant String? @map("ab_variant") // "static" | "ai" + outcome String? // "success" | "failed" | null (pending) + createdAt DateTime @default(now()) @map("created_at") + + @@index([tenantId, createdAt]) + @@index([selectedChain, createdAt]) + @@index([abVariant]) + @@map("routing_decisions") +} + +model ChainPerformanceMetric { + id String @id @default(uuid()) + chain String + sampleAt DateTime @default(now()) @map("sample_at") + avgGasPrice Float? @map("avg_gas_price") + avgConfirmTimeMs Float? @map("avg_confirm_time_ms") + successRate Float? @map("success_rate") // 0–1 + p50LatencyMs Float? @map("p50_latency_ms") + p99LatencyMs Float? @map("p99_latency_ms") + sampleSize Int @default(0) @map("sample_size") + + @@index([chain, sampleAt]) + @@index([sampleAt]) + @@map("chain_performance_metrics") +} + +// ─── Issue #668: PII Audit ──────────────────────────────────────────────────── + +enum PiiClassificationLevel { + strict + standard + permissive +} + +model PiiAuditLog { + id String @id @default(uuid()) + endpoint String + method String @default("GET") + fieldPath String @map("field_path") // JSON path of detected PII + piiType String @map("pii_type") // email | phone | ssn | crypto_address | api_key + action String @default("redacted") // redacted | masked | allowed + level PiiClassificationLevel @default(standard) + tenantId String? @map("tenant_id") + requestId String? @map("request_id") + createdAt DateTime @default(now()) @map("created_at") + + @@index([endpoint, createdAt]) + @@index([piiType]) + @@index([tenantId, createdAt]) + @@map("pii_audit_logs") +} diff --git a/backend/src/index.ts b/backend/src/index.ts index bd27a9dd..55e9b81f 100644 --- a/backend/src/index.ts +++ b/backend/src/index.ts @@ -130,6 +130,10 @@ import { liquidityProtectionRouter } from './routes/liquidity-protection.js'; import { bulkPaymentsRouter } from './routes/bulk-payments.js'; import { feesRouter } from './routes/fees.js'; import { apiUsageTracker, checkQuota } from './middleware/api-usage-tracker.js'; +import indexerRouter from './routes/indexer.js'; +import aiRoutingRouter from './routes/ai-routing.js'; +import piiRouter from './routes/pii.js'; +import { piiRedactionMiddleware } from './middleware/pii-redaction.js'; // Validate environment variables at startup validateEnv(); @@ -217,6 +221,8 @@ app.use('/webhooks', webhookHandlersRouter); app.use(express.json()); app.use(express.text({ type: ['text/csv', 'text/plain'] })); app.use('/api', openApiValidator({ validateResponses: process.env.OPENAPI_VALIDATE_RESPONSES === 'true' })); +// Redact PII from all outgoing JSON API responses — Issue #668 +app.use('/api', piiRedactionMiddleware); app.use( compressionMiddleware({ @@ -414,6 +420,15 @@ app.use('/api/v1/payments/bulk', bulkPaymentsRouter); // Dynamic fee calculation engine with tiered pricing — Issue #468 app.use('/api/v1/fees', feesRouter); +// Smart contract event indexer with real-time WebSocket streams — Issue #447 +app.use('/api/v1/indexer', indexerRouter); + +// AI-powered payment routing engine — Issue #446 +app.use('/api/v1/routing/ai', aiRoutingRouter); + +// PII classification and redaction audit — Issue #668 +app.use('/api/v1/pii', piiRouter); + // Sandbox environment for testing (with relaxed rate limits) const sandboxRouter = createSandboxRouter(getSandboxManager(), getMockPaymentProcessor(), getTestDataSeeder()); app.use('/api/v1/sandbox', sandboxRateLimiter, sandboxRouter); diff --git a/backend/src/middleware/pii-redaction.ts b/backend/src/middleware/pii-redaction.ts new file mode 100644 index 00000000..4ef6cd57 --- /dev/null +++ b/backend/src/middleware/pii-redaction.ts @@ -0,0 +1,75 @@ +/** + * PII Redaction Middleware (#668) + * + * Express middleware that intercepts outgoing JSON responses and log records, + * runs them through the PiiClassifier, redacts detected PII, and writes an + * audit entry to the PiiAuditLog table. + */ +import type { Request, Response, NextFunction } from 'express'; +import { piiClassifier, type DetectedPii } from '../services/pii/pii-classifier.js'; +import { prisma } from '../lib/prisma.js'; + +// ─── Response redaction middleware ──────────────────────────────────────────── + +export function piiRedactionMiddleware(req: Request, res: Response, next: NextFunction): void { + const originalJson = res.json.bind(res) as typeof res.json; + + res.json = function (body: unknown): Response { + try { + const { detections, redacted } = piiClassifier.classify(body); + if (detections.length > 0) { + void persistAudit(detections, req); + return originalJson(redacted); + } + } catch { + // Never block the response on classifier failure + } + return originalJson(body); + }; + + next(); +} + +async function persistAudit(detections: DetectedPii[], req: Request): Promise { + const tenantId = (req as Request & { tenantId?: string }).tenantId; + const requestId = (req as Request & { id?: string }).id; + + await prisma.piiAuditLog.createMany({ + data: detections.map((d) => ({ + endpoint: req.path, + method: req.method, + fieldPath: d.path, + piiType: d.type, + action: 'redacted', + level: 'standard', + tenantId: tenantId ?? null, + requestId: requestId ?? null, + })), + skipDuplicates: true, + }).catch(() => { /* non-fatal */ }); +} + +// ─── Log redaction helper (for Pino / Winston formatters) ──────────────────── + +/** + * Pass this as a `redact` transform in your logger config. + * Compatible with pino's `redact` option when used as a custom serializer. + */ +export function redactLogRecord(record: Record): Record { + try { + const { redacted } = piiClassifier.classify(record); + return redacted; + } catch { + return record; + } +} + +/** + * Pino-compatible serializer that strips PII from any object field. + */ +export const piiLogSerializer = { + // Applied to any field named "body", "payload", "data", "req", "res" + body: redactLogRecord, + payload: redactLogRecord, + data: redactLogRecord, +}; diff --git a/backend/src/queues/routing-evaluation.queue.ts b/backend/src/queues/routing-evaluation.queue.ts new file mode 100644 index 00000000..f976f0af --- /dev/null +++ b/backend/src/queues/routing-evaluation.queue.ts @@ -0,0 +1,121 @@ +/** + * Routing evaluation queue (#446) + * + * Periodically samples chain performance (gas prices, latency, success rate) + * and writes them to: + * 1. Redis sorted sets – for sub-millisecond read by the AI router + * 2. ChainPerformanceMetric Prisma model – for historical analysis + */ +import { Queue, Worker, type Job } from 'bullmq'; +import { prisma } from '../../lib/prisma.js'; +import type { ChainFeatures } from '../routing/ai-router.js'; + +const QUEUE_NAME = 'routing-evaluation'; +const REDIS_KEY_PREFIX = 'chain:perf:'; + +// ─── Mock collectors – replace with real RPC/API calls in production ────────── + +async function collectStellarMetrics(): Promise { + // TODO: query Stellar Horizon / Soroban RPC for real fee stats + return { + chain: 'stellar', + avgGasPrice: 0.00001, + avgConfirmTimeMs: 5_000, + successRate: 0.98, + p99LatencyMs: 6_000, + }; +} + +async function collectEvmMetrics(): Promise { + // TODO: query EVM node eth_gasPrice + filter recent txs for success rate + return { + chain: 'evm', + avgGasPrice: 30, + avgConfirmTimeMs: 15_000, + successRate: 0.94, + p99LatencyMs: 20_000, + }; +} + +// ─── Queue & Worker ─────────────────────────────────────────────────────────── + +export interface RoutingEvalJobData { + sampleId: string; +} + +let _queue: Queue | null = null; +let _worker: Worker | null = null; + +type RedisClient = { + zadd(key: string, score: number, member: string): Promise; + zremrangebyscore(key: string, min: number | string, max: number | string): Promise; +}; + +export function startRoutingEvalQueue( + connection: { host: string; port: number }, + redisClient?: RedisClient, +): void { + _queue = new Queue(QUEUE_NAME, { connection }); + + _worker = new Worker( + QUEUE_NAME, + async (_job: Job) => { + const collectors = [collectStellarMetrics, collectEvmMetrics]; + const results = await Promise.allSettled(collectors.map((fn) => fn())); + + for (const result of results) { + if (result.status !== 'fulfilled') continue; + const metrics = result.value; + const now = new Date(); + + // Persist to DB + await prisma.chainPerformanceMetric.create({ + data: { + chain: metrics.chain, + sampleAt: now, + avgGasPrice: metrics.avgGasPrice, + avgConfirmTimeMs: metrics.avgConfirmTimeMs, + successRate: metrics.successRate, + p99LatencyMs: metrics.p99LatencyMs, + sampleSize: 1, + }, + }).catch(() => { /* non-fatal */ }); + + // Write to Redis sorted set (score = timestamp for TTL pruning) + if (redisClient) { + const key = `${REDIS_KEY_PREFIX}${metrics.chain}`; + const member = JSON.stringify(metrics); + const nowMs = now.getTime(); + await redisClient.zadd(key, nowMs, member).catch(() => {}); + // Prune samples older than 1 hour + await redisClient.zremrangebyscore(key, '-inf', nowMs - 3_600_000).catch(() => {}); + } + } + }, + { connection, concurrency: 1 }, + ); + + _worker.on('failed', (job, err) => { + console.error(`[routing-eval] job ${job?.id} failed:`, err); + }); +} + +/** Schedule recurring eval jobs (call once on server startup) */ +export async function scheduleRoutingEvalJobs(intervalMs = 60_000): Promise { + if (!_queue) throw new Error('startRoutingEvalQueue must be called first'); + await _queue.upsertJobScheduler( + 'routing-eval-periodic', + { every: intervalMs }, + { name: 'routing-eval', data: { sampleId: 'periodic' } }, + ); +} + +export function stopRoutingEvalQueue(): Promise { + return Promise.all([ + _worker?.close(), + _queue?.close(), + ]).then(() => { + _queue = null; + _worker = null; + }); +} diff --git a/backend/src/routes/ai-routing.ts b/backend/src/routes/ai-routing.ts new file mode 100644 index 00000000..6f4e44fb --- /dev/null +++ b/backend/src/routes/ai-routing.ts @@ -0,0 +1,149 @@ +/** + * AI routing admin & A/B test endpoints (#446) + * + * POST /api/v1/routing/decide - get a routing decision (clients/AI agents) + * POST /api/v1/routing/override - manual override for a specific tenant (admin) + * GET /api/v1/routing/decisions - paginated decision log + * GET /api/v1/routing/ab-report - A/B test comparison: static vs AI routing + * GET /api/v1/routing/chain-metrics - latest sampled chain performance + */ +import { Router, type Request, type Response } from 'express'; +import { prisma } from '../lib/prisma.js'; +import { aiRouter } from '../services/routing/ai-router.js'; + +const router = Router(); + +// POST /api/v1/routing/decide +router.post('/decide', async (req: Request, res: Response) => { + try { + const { tenantId, amount, fromAsset, preferSpeed, preferCost, abVariant } = req.body as { + tenantId?: string; + amount?: number; + fromAsset?: string; + preferSpeed?: boolean; + preferCost?: boolean; + abVariant?: 'static' | 'ai'; + }; + + const result = await aiRouter.route({ + tenantId, + amount: amount ?? 0, + fromAsset: fromAsset ?? 'XLM', + preferSpeed, + preferCost, + abVariant: abVariant ?? 'ai', + }); + + res.json({ data: result }); + } catch (err) { + res.status(500).json({ error: err instanceof Error ? err.message : 'Routing failed' }); + } +}); + +// POST /api/v1/routing/override (admin) +router.post('/override', async (req: Request, res: Response) => { + try { + const { tenantId, chain, actor, amount, fromAsset } = req.body as { + tenantId?: string; + chain: string; + actor: string; + amount?: number; + fromAsset?: string; + }; + + if (!chain || !actor) { + return res.status(400).json({ error: 'chain and actor are required' }); + } + + const result = await aiRouter.route({ + tenantId, + amount: amount ?? 0, + fromAsset: fromAsset ?? 'XLM', + manualOverride: { chain, actor }, + }); + + res.json({ data: result }); + } catch (err) { + res.status(500).json({ error: err instanceof Error ? err.message : 'Override failed' }); + } +}); + +// GET /api/v1/routing/decisions +router.get('/decisions', async (req: Request, res: Response) => { + try { + const { tenantId, chain, limit = '50', offset = '0' } = req.query as Record; + const take = Math.min(Number(limit) || 50, 200); + const skip = Number(offset) || 0; + + const where = { + ...(tenantId ? { tenantId } : {}), + ...(chain ? { selectedChain: chain } : {}), + }; + + const [decisions, total] = await Promise.all([ + prisma.routingDecision.findMany({ + where, + orderBy: { createdAt: 'desc' }, + take, + skip, + }), + prisma.routingDecision.count({ where }), + ]); + + res.json({ data: decisions, total, limit: take, offset: skip }); + } catch { + res.status(500).json({ error: 'Failed to fetch decisions' }); + } +}); + +// GET /api/v1/routing/ab-report +router.get('/ab-report', async (_req: Request, res: Response) => { + try { + const [aiStats, staticStats] = await Promise.all([ + prisma.routingDecision.groupBy({ + by: ['selectedChain'], + where: { abVariant: 'ai' }, + _count: { id: true }, + _avg: { latencyMs: true }, + }), + prisma.routingDecision.groupBy({ + by: ['selectedChain'], + where: { abVariant: 'static' }, + _count: { id: true }, + _avg: { latencyMs: true }, + }), + ]); + + res.json({ + data: { + ai: aiStats, + static: staticStats, + }, + }); + } catch { + res.status(500).json({ error: 'Failed to generate A/B report' }); + } +}); + +// GET /api/v1/routing/chain-metrics +router.get('/chain-metrics', async (req: Request, res: Response) => { + try { + const { chain } = req.query as { chain?: string }; + const since = new Date(Date.now() - 3_600_000); // last 1 hour + + const metrics = await prisma.chainPerformanceMetric.findMany({ + where: { + ...(chain ? { chain } : {}), + sampleAt: { gte: since }, + }, + orderBy: { sampleAt: 'desc' }, + take: 100, + }); + + res.json({ data: metrics }); + } catch { + res.status(500).json({ error: 'Failed to fetch chain metrics' }); + } +}); + +export default router; diff --git a/backend/src/routes/indexer.ts b/backend/src/routes/indexer.ts new file mode 100644 index 00000000..61c6931d --- /dev/null +++ b/backend/src/routes/indexer.ts @@ -0,0 +1,94 @@ +/** + * REST API routes for the smart contract event indexer (#447). + * GET /api/v1/indexer/events - paginated event history with filtering + * GET /api/v1/indexer/events/:id - single event by id + * DELETE /api/v1/indexer/events/prune - manual retention prune (admin) + */ +import { Router, type Request, type Response } from 'express'; +import { prisma } from '../lib/prisma.js'; +import type { Prisma } from '@prisma/client'; + +const router = Router(); + +// GET /api/v1/indexer/events +router.get('/events', async (req: Request, res: Response) => { + try { + const { + chain, + contractAddress, + eventType, + fromTimestamp, + toTimestamp, + limit = '50', + offset = '0', + } = req.query as Record; + + const where: Prisma.IndexedEventWhereInput = {}; + if (chain) where.chain = chain as 'stellar' | 'evm'; + if (contractAddress) where.contractAddress = { equals: contractAddress, mode: 'insensitive' }; + if (eventType) where.eventType = eventType; + if (fromTimestamp || toTimestamp) { + where.timestamp = { + ...(fromTimestamp ? { gte: new Date(fromTimestamp) } : {}), + ...(toTimestamp ? { lte: new Date(toTimestamp) } : {}), + }; + } + + const take = Math.min(Number(limit) || 50, 200); + const skip = Number(offset) || 0; + + const [events, total] = await Promise.all([ + prisma.indexedEvent.findMany({ + where, + orderBy: { timestamp: 'desc' }, + take, + skip, + select: { + id: true, + dedupKey: true, + chain: true, + contractAddress: true, + eventType: true, + blockNumber: true, + txHash: true, + timestamp: true, + payload: true, + confirmations: true, + createdAt: true, + }, + }), + prisma.indexedEvent.count({ where }), + ]); + + res.json({ data: events, total, limit: take, offset: skip }); + } catch (err) { + res.status(500).json({ error: 'Failed to fetch events' }); + } +}); + +// GET /api/v1/indexer/events/:id +router.get('/events/:id', async (req: Request, res: Response) => { + try { + const event = await prisma.indexedEvent.findUnique({ + where: { id: req.params.id }, + }); + if (!event) return res.status(404).json({ error: 'Event not found' }); + res.json({ data: event }); + } catch { + res.status(500).json({ error: 'Failed to fetch event' }); + } +}); + +// DELETE /api/v1/indexer/events/prune (admin: prune expired events) +router.delete('/events/prune', async (_req: Request, res: Response) => { + try { + const result = await prisma.indexedEvent.deleteMany({ + where: { retentionUntil: { lt: new Date() } }, + }); + res.json({ pruned: result.count }); + } catch { + res.status(500).json({ error: 'Prune failed' }); + } +}); + +export default router; diff --git a/backend/src/routes/pii.ts b/backend/src/routes/pii.ts new file mode 100644 index 00000000..73c63ac2 --- /dev/null +++ b/backend/src/routes/pii.ts @@ -0,0 +1,90 @@ +/** + * PII audit report endpoints (#668) + * + * GET /api/v1/pii/audit - paginated PII detection log + * GET /api/v1/pii/audit/report - aggregate report by endpoint/piiType + */ +import { Router, type Request, type Response } from 'express'; +import { prisma } from '../lib/prisma.js'; + +const router = Router(); + +// GET /api/v1/pii/audit +router.get('/audit', async (req: Request, res: Response) => { + try { + const { endpoint, piiType, tenantId, limit = '50', offset = '0' } = + req.query as Record; + + const where = { + ...(endpoint ? { endpoint } : {}), + ...(piiType ? { piiType } : {}), + ...(tenantId ? { tenantId } : {}), + }; + + const take = Math.min(Number(limit) || 50, 200); + const skip = Number(offset) || 0; + + const [logs, total] = await Promise.all([ + prisma.piiAuditLog.findMany({ + where, + orderBy: { createdAt: 'desc' }, + take, + skip, + select: { + id: true, + endpoint: true, + method: true, + fieldPath: true, + piiType: true, + action: true, + level: true, + tenantId: true, + requestId: true, + createdAt: true, + }, + }), + prisma.piiAuditLog.count({ where }), + ]); + + res.json({ data: logs, total, limit: take, offset: skip }); + } catch { + res.status(500).json({ error: 'Failed to fetch PII audit logs' }); + } +}); + +// GET /api/v1/pii/audit/report +router.get('/audit/report', async (_req: Request, res: Response) => { + try { + const [byEndpoint, byType] = await Promise.all([ + prisma.piiAuditLog.groupBy({ + by: ['endpoint', 'method'], + _count: { id: true }, + orderBy: { _count: { id: 'desc' } }, + take: 20, + }), + prisma.piiAuditLog.groupBy({ + by: ['piiType'], + _count: { id: true }, + orderBy: { _count: { id: 'desc' } }, + }), + ]); + + res.json({ + data: { + topEndpoints: byEndpoint.map((r) => ({ + endpoint: r.endpoint, + method: r.method, + detections: r._count.id, + })), + byPiiType: byType.map((r) => ({ + piiType: r.piiType, + detections: r._count.id, + })), + }, + }); + } catch { + res.status(500).json({ error: 'Failed to generate PII report' }); + } +}); + +export default router; diff --git a/backend/src/services/__tests__/ai-router.test.ts b/backend/src/services/__tests__/ai-router.test.ts new file mode 100644 index 00000000..be2a6eb2 --- /dev/null +++ b/backend/src/services/__tests__/ai-router.test.ts @@ -0,0 +1,68 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { AiPaymentRouter, setFeatureStore, type ChainFeatures } from '../../services/routing/ai-router.js'; + +// Mock prisma so tests don't need a DB connection +vi.mock('../../lib/prisma.js', () => ({ + prisma: { routingDecision: { create: vi.fn().mockResolvedValue({}) } }, +})); + +const mockFeatures: ChainFeatures[] = [ + { chain: 'stellar', avgGasPrice: 0.00001, avgConfirmTimeMs: 5_000, successRate: 0.98, p99LatencyMs: 6_000 }, + { chain: 'evm', avgGasPrice: 30, avgConfirmTimeMs: 15_000, successRate: 0.94, p99LatencyMs: 20_000 }, +]; + +beforeEach(() => { + setFeatureStore({ getFeatures: async () => mockFeatures }); +}); + +describe('AiPaymentRouter', () => { + const router = new AiPaymentRouter(); + + it('returns a selected chain', async () => { + const result = await router.route({ amount: 100, fromAsset: 'XLM' }); + expect(['stellar', 'evm']).toContain(result.selectedChain); + }); + + it('selects stellar (lower cost) with cost preference', async () => { + const result = await router.route({ amount: 100, fromAsset: 'XLM', preferCost: true }); + expect(result.selectedChain).toBe('stellar'); + }); + + it('selects stellar (faster) with speed preference', async () => { + const result = await router.route({ amount: 100, fromAsset: 'XLM', preferSpeed: true }); + expect(result.selectedChain).toBe('stellar'); + }); + + it('respects manual override', async () => { + const result = await router.route({ + amount: 100, + fromAsset: 'XLM', + manualOverride: { chain: 'evm', actor: 'admin@test.com' }, + }); + expect(result.selectedChain).toBe('evm'); + expect(result.rationale).toContain('Manual override'); + }); + + it('includes fallback chains', async () => { + const result = await router.route({ amount: 100, fromAsset: 'XLM' }); + expect(Array.isArray(result.fallbackChains)).toBe(true); + }); + + it('returns latencyMs under 50ms', async () => { + const result = await router.route({ amount: 100, fromAsset: 'XLM' }); + expect(result.latencyMs).toBeLessThan(50); + }); + + it('throws when no features available', async () => { + setFeatureStore({ getFeatures: async () => [] }); + await expect(router.route({ amount: 100, fromAsset: 'XLM' })).rejects.toThrow('No chain features'); + // restore + setFeatureStore({ getFeatures: async () => mockFeatures }); + }); + + it('includes scores for each chain', async () => { + const result = await router.route({ amount: 100, fromAsset: 'XLM' }); + expect(result.scores).toHaveProperty('stellar'); + expect(result.scores).toHaveProperty('evm'); + }); +}); diff --git a/backend/src/services/__tests__/pii-classifier.test.ts b/backend/src/services/__tests__/pii-classifier.test.ts new file mode 100644 index 00000000..48c2b230 --- /dev/null +++ b/backend/src/services/__tests__/pii-classifier.test.ts @@ -0,0 +1,79 @@ +import { describe, it, expect } from 'vitest'; +import { PiiClassifier, piiClassifier } from '../../services/pii/pii-classifier.js'; + +describe('PiiClassifier', () => { + describe('scanString', () => { + it('detects email', () => { + const d = piiClassifier.scanString('contact user@example.com now'); + expect(d.some(x => x.type === 'email')).toBe(true); + }); + + it('detects SSN', () => { + const d = piiClassifier.scanString('ssn: 123-45-6789'); + expect(d.some(x => x.type === 'ssn')).toBe(true); + }); + + it('detects API key', () => { + const d = piiClassifier.scanString('key=sk_live_abcdefghijklmnop'); + expect(d.some(x => x.type === 'api_key')).toBe(true); + }); + + it('returns empty for clean strings', () => { + const d = piiClassifier.scanString('hello world, no pii here'); + expect(d).toHaveLength(0); + }); + }); + + describe('classify', () => { + it('redacts email in nested object', () => { + const { redacted, detections } = piiClassifier.classify({ + user: { email: 'alice@example.com', name: 'Alice' }, + }); + expect(detections.length).toBeGreaterThan(0); + // original email is replaced with the mask + expect((redacted as any).user.email).not.toBe('alice@example.com'); + }); + + it('does not mutate original object', () => { + const original = { email: 'test@test.com' }; + piiClassifier.classify(original); + expect(original.email).toBe('test@test.com'); + }); + + it('handles arrays', () => { + const { redacted } = piiClassifier.classify({ emails: ['a@b.com', 'c@d.com'] }); + const emails = (redacted as any).emails as string[]; + // each address should be replaced with the mask, not the original + expect(emails.every((e: string) => e !== 'a@b.com' && e !== 'c@d.com')).toBe(true); + }); + + it('passes through non-PII values unchanged', () => { + const { redacted } = piiClassifier.classify({ count: 42, label: 'payment' }); + expect((redacted as any).count).toBe(42); + expect((redacted as any).label).toBe('payment'); + }); + }); + + describe('classification levels', () => { + it('strict level detects EVM crypto address', () => { + const strict = new PiiClassifier('strict'); + const d = strict.scanString('wallet: 0xAbCdEf1234567890abcdef1234567890ABCDEF12'); + expect(d.some(x => x.type === 'crypto_address')).toBe(true); + }); + + it('permissive level skips crypto addresses', () => { + const permissive = new PiiClassifier('permissive'); + const d = permissive.scanString('wallet: 0xAbCdEf1234567890abcdef1234567890ABCDEF12'); + expect(d.some(x => x.type === 'crypto_address')).toBe(false); + }); + }); + + describe('custom patterns', () => { + it('detects custom pattern', () => { + const c = new PiiClassifier('standard'); + c.addPattern({ type: 'custom', regex: /EMP-\d{6}/g, minLevel: 'standard' }); + const d = c.scanString('id: EMP-123456'); + expect(d.some(x => x.type === 'custom')).toBe(true); + }); + }); +}); diff --git a/backend/src/services/indexer/evm-listener.ts b/backend/src/services/indexer/evm-listener.ts new file mode 100644 index 00000000..9ef3090a --- /dev/null +++ b/backend/src/services/indexer/evm-listener.ts @@ -0,0 +1,126 @@ +/** + * EVM event listener – uses ethers.js to subscribe to / poll contract logs, + * normalises them into IndexedEvent shape, and emits to the indexer event bus. + * Handles chain reorganisations by requiring MIN_CONFIRMATIONS before emitting. + */ +import { EventEmitter } from 'node:events'; +import { ethers } from 'ethers'; +import type { NormalizedEvent } from './soroban-listener.js'; + +export interface EvmListenerOptions { + rpcUrl: string; + contractAddress: string; + abi: ethers.InterfaceAbi; + minConfirmations?: number; + pollIntervalMs?: number; +} + +// Minimal ABI covering the events we care about. +const DEFAULT_ABI: ethers.InterfaceAbi = [ + 'event PaymentSent(address indexed from, address indexed to, uint256 amount)', + 'event PaymentReceived(address indexed from, address indexed to, uint256 amount)', + 'event DisputeRaised(bytes32 indexed projectId, address indexed initiator)', + 'event SettlementCompleted(bytes32 indexed projectId, uint256 amount)', + 'event EscrowFunded(bytes32 indexed projectId, uint256 amount)', + 'event EscrowReleased(bytes32 indexed projectId, uint256 amount)', +]; + +export class EvmListener extends EventEmitter { + private readonly contractAddress: string; + private readonly minConfirmations: number; + private readonly pollIntervalMs: number; + private provider: ethers.JsonRpcProvider | null = null; + private contract: ethers.Contract | null = null; + private abi: ethers.InterfaceAbi; + private rpcUrl: string; + private timer: ReturnType | null = null; + private lastBlock = 0; + + constructor(opts: EvmListenerOptions) { + super(); + this.contractAddress = opts.contractAddress; + this.abi = opts.abi ?? DEFAULT_ABI; + this.rpcUrl = opts.rpcUrl; + this.minConfirmations = opts.minConfirmations ?? 6; + this.pollIntervalMs = opts.pollIntervalMs ?? 12_000; + } + + async start(): Promise { + if (this.timer) return; + this.provider = new ethers.JsonRpcProvider(this.rpcUrl); + this.contract = new ethers.Contract(this.contractAddress, this.abi, this.provider); + + try { + this.lastBlock = (await this.provider.getBlockNumber()) - this.minConfirmations; + } catch { + this.lastBlock = 0; + } + + this.timer = setInterval(() => void this.poll(), this.pollIntervalMs); + } + + stop(): void { + if (this.timer) { + clearInterval(this.timer); + this.timer = null; + } + this.provider = null; + this.contract = null; + } + + private async poll(): Promise { + if (!this.provider || !this.contract) return; + + try { + const currentBlock = await this.provider.getBlockNumber(); + const safeBlock = currentBlock - this.minConfirmations; + if (safeBlock <= this.lastBlock) return; + + const logs = await this.contract.queryFilter('*', this.lastBlock + 1, safeBlock); + + for (const log of logs) { + const event = this.normalizeLog(log, safeBlock); + if (event) this.emit('event', event); + } + + this.lastBlock = safeBlock; + } catch (err) { + this.emit('error', err); + } + } + + private normalizeLog( + log: ethers.EventLog | ethers.Log, + currentSafeBlock: number, + ): NormalizedEvent | null { + try { + const blockNum = log.blockNumber ?? 0; + const isEventLog = 'eventName' in log; + return { + id: `evm:${log.transactionHash}:${log.index ?? 0}`, + chain: 'evm', + contractAddress: this.contractAddress, + eventType: isEventLog ? (log as ethers.EventLog).eventName : 'UnknownEvent', + blockNumber: blockNum, + txHash: log.transactionHash ?? '', + timestamp: new Date(), // block timestamp requires extra RPC call; set at persist time + payload: isEventLog + ? this.serializeArgs((log as ethers.EventLog).args) + : { data: log.data, topics: log.topics }, + confirmations: currentSafeBlock - blockNum, + }; + } catch { + return null; + } + } + + private serializeArgs(args: ethers.Result): Record { + const out: Record = {}; + for (const [k, v] of Object.entries(args)) { + out[k] = typeof v === 'bigint' ? v.toString() : v; + } + return out; + } +} + +export { DEFAULT_ABI as EVM_DEFAULT_ABI }; diff --git a/backend/src/services/indexer/soroban-listener.ts b/backend/src/services/indexer/soroban-listener.ts new file mode 100644 index 00000000..0b5d8826 --- /dev/null +++ b/backend/src/services/indexer/soroban-listener.ts @@ -0,0 +1,134 @@ +/** + * Soroban event listener – polls Stellar Horizon for contract events, + * normalises them into IndexedEvent shape, deduplicates via Redis, and + * publishes to the indexer event bus. + */ +import * as StellarSdk from '@stellar/stellar-sdk'; +import { EventEmitter } from 'node:events'; +import { config } from '../../config/env.js'; + +export type IndexedEventChain = 'stellar' | 'evm'; + +export interface NormalizedEvent { + id: string; // chain-unique deduplication key + chain: IndexedEventChain; + contractAddress: string; + eventType: string; + blockNumber: number; + txHash: string; + timestamp: Date; + payload: Record; + confirmations: number; +} + +export interface SorobanListenerOptions { + contractId: string; + pollIntervalMs?: number; + minConfirmations?: number; +} + +const NETWORK = config().STELLAR_NETWORK ?? 'testnet'; +const HORIZON_URL = + NETWORK === 'public' + ? 'https://horizon.stellar.org' + : 'https://horizon-testnet.stellar.org'; + +function parseEventType(topics: string[]): string { + // Topics[0] is typically the event name symbol + if (topics.length === 0) return 'unknown'; + try { + const scVal = StellarSdk.xdr.ScVal.fromXDR(topics[0], 'base64'); + if (scVal.switch() === StellarSdk.xdr.ScValType.scvSymbol()) { + return scVal.sym().toString(); + } + } catch { + // ignore parse errors; use raw value + } + return topics[0].slice(0, 64); +} + +function parsePayload(data: string): Record { + try { + const scVal = StellarSdk.xdr.ScVal.fromXDR(data, 'base64'); + return { raw: scVal.toXDR('base64'), parsed: scVal.toString() }; + } catch { + return { raw: data }; + } +} + +export class SorobanListener extends EventEmitter { + private readonly contractId: string; + private readonly pollIntervalMs: number; + private readonly minConfirmations: number; + private readonly server: StellarSdk.Horizon.Server; + private cursor = 'now'; + private timer: ReturnType | null = null; + + constructor(opts: SorobanListenerOptions) { + super(); + this.contractId = opts.contractId; + this.pollIntervalMs = opts.pollIntervalMs ?? 5_000; + this.minConfirmations = opts.minConfirmations ?? 6; + this.server = new StellarSdk.Horizon.Server(HORIZON_URL); + } + + start(): void { + if (this.timer) return; + void this.poll(); + this.timer = setInterval(() => void this.poll(), this.pollIntervalMs); + } + + stop(): void { + if (this.timer) { + clearInterval(this.timer); + this.timer = null; + } + } + + private async poll(): Promise { + try { + // Horizon /transactions gives us ledger context; for events we use the + // effects stream and filter by contract. Production use should switch to + // the Soroban RPC `getEvents` call once widely available. + const txPage = await this.server + .transactions() + .forAccount(this.contractId) + .cursor(this.cursor) + .limit(50) + .call(); + + for (const tx of txPage.records) { + if (!tx.successful) continue; + + const event: NormalizedEvent = { + id: `stellar:${tx.id}`, + chain: 'stellar', + contractAddress: this.contractId, + eventType: this.inferEventType(tx), + blockNumber: tx.ledger_attr as unknown as number, + txHash: tx.hash, + timestamp: new Date(tx.created_at), + payload: { envelope: tx.envelope_xdr, result: tx.result_xdr }, + confirmations: this.minConfirmations, // Stellar is final after ledger close + }; + + if (event.confirmations >= this.minConfirmations) { + this.emit('event', event); + } + + this.cursor = tx.paging_token; + } + } catch (err) { + this.emit('error', err); + } + } + + private inferEventType(tx: StellarSdk.Horizon.ServerApi.TransactionRecord): string { + // Simple heuristic; replace with full XDR decode for production. + const memo = typeof tx.memo === 'string' ? tx.memo : ''; + if (memo.includes('payment')) return 'PaymentSent'; + if (memo.includes('dispute')) return 'DisputeRaised'; + if (memo.includes('settle')) return 'SettlementCompleted'; + return 'ContractEvent'; + } +} diff --git a/backend/src/services/pii/pii-classifier.ts b/backend/src/services/pii/pii-classifier.ts new file mode 100644 index 00000000..179f8082 --- /dev/null +++ b/backend/src/services/pii/pii-classifier.ts @@ -0,0 +1,180 @@ +/** + * PII Classification Engine (#668) + * + * Detects and classifies PII in arbitrary JSON blobs or strings. + * Supports three classification levels (strict, standard, permissive) + * and is fully configurable with custom patterns. + */ + +export type PiiType = + | 'email' + | 'phone' + | 'ssn' + | 'credit_card' + | 'crypto_address' + | 'api_key' + | 'ip_address' + | 'date_of_birth' + | 'passport' + | 'custom'; + +export type ClassificationLevel = 'strict' | 'standard' | 'permissive'; + +export interface PiiPattern { + type: PiiType | string; + regex: RegExp; + /** Minimum severity level at which this pattern is active */ + minLevel: ClassificationLevel; + /** Replace matched groups with this mask instead of full redaction */ + mask?: string; +} + +export interface DetectedPii { + type: PiiType | string; + path: string; // JSON pointer, e.g. "/user/email" + original?: string; // only populated when level === 'strict' (for audit, never logged) + masked: string; +} + +export interface ClassificationResult { + detections: DetectedPii[]; + redacted: Record; +} + +// ─── Default patterns ───────────────────────────────────────────────────────── + +const LEVEL_ORDER: Record = { + strict: 0, + standard: 1, + permissive: 2, +}; + +const DEFAULT_PATTERNS: PiiPattern[] = [ + { + type: 'email', + regex: /[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}/g, + minLevel: 'permissive', + mask: '***@***.***', + }, + { + type: 'phone', + regex: /(?:\+?1[-.\s]?)?\(?\d{3}\)?[-.\s]\d{3}[-.\s]\d{4}/g, + minLevel: 'standard', + mask: '***-***-****', + }, + { + type: 'ssn', + regex: /\b\d{3}-\d{2}-\d{4}\b/g, + minLevel: 'permissive', + mask: '***-**-****', + }, + { + type: 'credit_card', + regex: /\b(?:4\d{3}|5[1-5]\d{2}|6011|3[47]\d{2})[\s\-]?\d{4}[\s\-]?\d{4}[\s\-]?\d{4}\b/g, + minLevel: 'permissive', + mask: '**** **** **** ****', + }, + { + type: 'crypto_address', + // Stellar (G... 56 chars), EVM (0x...), Bitcoin + regex: /\b(?:G[A-Z2-7]{54,55}|0x[a-fA-F0-9]{40}|[13][a-km-zA-HJ-NP-Z1-9]{25,34})\b/g, + minLevel: 'strict', + mask: '[CRYPTO_ADDR]', + }, + { + type: 'api_key', + // sk_*, pk_*, Bearer tokens + regex: /\b(?:sk|pk|api)_(?:live|test|[a-z]+)_[a-zA-Z0-9_\-]{16,}\b/g, + minLevel: 'permissive', + mask: '[API_KEY_REDACTED]', + }, + { + type: 'ip_address', + regex: /\b(?:\d{1,3}\.){3}\d{1,3}\b/g, + minLevel: 'strict', + mask: '*.*.*.*', + }, +]; + +// ─── Classifier ─────────────────────────────────────────────────────────────── + +export class PiiClassifier { + private patterns: PiiPattern[]; + private level: ClassificationLevel; + + constructor(level: ClassificationLevel = 'standard', extraPatterns: PiiPattern[] = []) { + this.level = level; + this.patterns = [...DEFAULT_PATTERNS, ...extraPatterns]; + } + + /** Scan a raw string and return detections */ + scanString(value: string, path = '/'): DetectedPii[] { + const detections: DetectedPii[] = []; + for (const pattern of this.patterns) { + // Pattern is active when our level is AT LEAST as strict as pattern.minLevel + // strict(0) <= strict(0) ✓ strict(0) <= permissive(2) ✓ + // permissive(2) <= strict(0) ✗ → permissive mode skips strict-only patterns + if (LEVEL_ORDER[this.level] > LEVEL_ORDER[pattern.minLevel]) continue; + const re = new RegExp(pattern.regex.source, pattern.regex.flags.includes('g') ? 'g' : 'g'); + let match: RegExpExecArray | null; + while ((match = re.exec(value)) !== null) { + detections.push({ + type: pattern.type, + path, + masked: pattern.mask ?? '[REDACTED]', + }); + } + } + return detections; + } + + /** Deep-scan a JSON object, redacting in place (returns a clone) */ + classify(obj: unknown, basePath = ''): ClassificationResult { + const detections: DetectedPii[] = []; + const redacted = this.redactValue(obj, basePath, detections); + return { detections, redacted: redacted as Record }; + } + + private redactValue(value: unknown, path: string, detections: DetectedPii[]): unknown { + if (typeof value === 'string') { + return this.redactString(value, path, detections); + } + if (Array.isArray(value)) { + return value.map((item, i) => this.redactValue(item, `${path}/${i}`, detections)); + } + if (value !== null && typeof value === 'object') { + const out: Record = {}; + for (const [k, v] of Object.entries(value as Record)) { + out[k] = this.redactValue(v, `${path}/${k}`, detections); + } + return out; + } + return value; + } + + private redactString(value: string, path: string, detections: DetectedPii[]): string { + let result = value; + for (const pattern of this.patterns) { + if (LEVEL_ORDER[this.level] > LEVEL_ORDER[pattern.minLevel]) continue; + const re = new RegExp(pattern.regex.source, pattern.regex.flags.includes('g') ? 'g' : 'g'); + if (re.test(value)) { + detections.push({ type: pattern.type, path, masked: pattern.mask ?? '[REDACTED]' }); + const re2 = new RegExp(pattern.regex.source, 'g'); + result = result.replace(re2, pattern.mask ?? '[REDACTED]'); + } + } + return result; + } + + /** Add a custom pattern at runtime */ + addPattern(p: PiiPattern): void { + this.patterns.push(p); + } + + setLevel(level: ClassificationLevel): void { + this.level = level; + } +} + +// Default singleton +export const piiClassifier = new PiiClassifier('standard'); diff --git a/backend/src/services/routing/ai-router.ts b/backend/src/services/routing/ai-router.ts new file mode 100644 index 00000000..8cc65079 --- /dev/null +++ b/backend/src/services/routing/ai-router.ts @@ -0,0 +1,219 @@ +/** + * AI-Powered Payment Routing Engine (#446) + * + * Selects the optimal chain for each payment using a lightweight scoring model + * that weighs real-time chain performance metrics (gas price, latency, + * success rate) collected by the routing evaluation queue. + * + * Design: + * - Feature extraction from Redis sorted sets (populated by the queue) + * - Weighted linear scoring as the "ML" model (drop-in for XGBoost inference) + * - Decision logged to RoutingDecision table for auditability & A/B testing + * - <50ms p99 – pure in-memory computation, no DB hit on hot path + */ +import { randomUUID } from 'node:crypto'; +import { prisma } from '../../lib/prisma.js'; + +// ─── Types ──────────────────────────────────────────────────────────────────── + +export interface ChainFeatures { + chain: string; + avgGasPrice: number; // gwei or stroops + avgConfirmTimeMs: number; + successRate: number; // 0–1 + p99LatencyMs: number; +} + +export interface RoutingRequest { + tenantId?: string; + amount: number; + fromAsset: string; + preferSpeed?: boolean; + preferCost?: boolean; + abVariant?: 'static' | 'ai'; + manualOverride?: { chain: string; actor: string }; +} + +export interface RoutingResult { + requestId: string; + selectedChain: string; + fallbackChains: string[]; + rationale: string; + latencyMs: number; + scores: Record; +} + +// ─── Model weights (tunable / replace with ONNX/XGBoost inference) ──────────── + +interface ModelWeights { + gasPrice: number; // lower is better + confirmTime: number; // lower is better + successRate: number; // higher is better + p99Latency: number; // lower is better +} + +const DEFAULT_WEIGHTS: ModelWeights = { + gasPrice: 0.30, + confirmTime: 0.25, + successRate: 0.30, + p99Latency: 0.15, +}; + +const SPEED_WEIGHTS: ModelWeights = { + gasPrice: 0.15, + confirmTime: 0.40, + successRate: 0.25, + p99Latency: 0.20, +}; + +const COST_WEIGHTS: ModelWeights = { + gasPrice: 0.50, + confirmTime: 0.15, + successRate: 0.25, + p99Latency: 0.10, +}; + +// ─── Feature store (Redis-backed or fallback to DB) ─────────────────────────── + +export type FeatureStore = { + getFeatures(): Promise; +}; + +function buildInMemoryStore(): FeatureStore { + // Fallback static data – replaced at runtime by the routing queue + const defaults: ChainFeatures[] = [ + { chain: 'stellar', avgGasPrice: 0.00001, avgConfirmTimeMs: 5_000, successRate: 0.98, p99LatencyMs: 6_000 }, + { chain: 'evm', avgGasPrice: 30, avgConfirmTimeMs: 15_000, successRate: 0.94, p99LatencyMs: 20_000 }, + ]; + return { getFeatures: async () => defaults }; +} + +let _featureStore: FeatureStore = buildInMemoryStore(); + +export function setFeatureStore(store: FeatureStore): void { + _featureStore = store; +} + +// ─── Scoring ────────────────────────────────────────────────────────────────── + +function normalize(value: number, min: number, max: number): number { + if (max === min) return 0.5; + return (value - min) / (max - min); +} + +function scoreChains( + features: ChainFeatures[], + weights: ModelWeights, +): { chain: string; score: number }[] { + if (features.length === 0) return []; + + const gasPrices = features.map((f) => f.avgGasPrice); + const confirmTimes = features.map((f) => f.avgConfirmTimeMs); + const p99Latencies = features.map((f) => f.p99LatencyMs); + + const minGas = Math.min(...gasPrices), maxGas = Math.max(...gasPrices); + const minTime = Math.min(...confirmTimes), maxTime = Math.max(...confirmTimes); + const minP99 = Math.min(...p99Latencies), maxP99 = Math.max(...p99Latencies); + + return features.map((f) => { + // For cost metrics: lower is better → invert normalisation + const gasScore = 1 - normalize(f.avgGasPrice, minGas, maxGas); + const confirmScore = 1 - normalize(f.avgConfirmTimeMs, minTime, maxTime); + const successScore = f.successRate; // already 0–1, higher is better + const p99Score = 1 - normalize(f.p99LatencyMs, minP99, maxP99); + + const score = + weights.gasPrice * gasScore + + weights.confirmTime * confirmScore + + weights.successRate * successScore + + weights.p99Latency * p99Score; + + return { chain: f.chain, score: Math.round(score * 10_000) / 10_000 }; + }).sort((a, b) => b.score - a.score); +} + +// ─── Main router class ──────────────────────────────────────────────────────── + +export class AiPaymentRouter { + async route(req: RoutingRequest): Promise { + const start = Date.now(); + const requestId = randomUUID(); + + // Manual override short-circuits the model + if (req.manualOverride) { + const result: RoutingResult = { + requestId, + selectedChain: req.manualOverride.chain, + fallbackChains: [], + rationale: `Manual override by ${req.manualOverride.actor}`, + latencyMs: Date.now() - start, + scores: {}, + }; + void this.logDecision(result, req, {}, true); + return result; + } + + const features = await _featureStore.getFeatures(); + + // Choose weight profile + const weights = req.preferSpeed + ? SPEED_WEIGHTS + : req.preferCost + ? COST_WEIGHTS + : DEFAULT_WEIGHTS; + + const scored = scoreChains(features, weights); + + if (scored.length === 0) { + throw new Error('No chain features available for routing'); + } + + const [best, ...rest] = scored; + const scores = Object.fromEntries(scored.map((s) => [s.chain, s.score])); + const featureSnapshot = Object.fromEntries( + features.map((f) => [f.chain, f]), + ) as Record; + + const result: RoutingResult = { + requestId, + selectedChain: best.chain, + fallbackChains: rest.map((s) => s.chain), + rationale: `Selected ${best.chain} (score: ${best.score}) using ${req.preferSpeed ? 'speed' : req.preferCost ? 'cost' : 'balanced'} weights`, + latencyMs: Date.now() - start, + scores, + }; + + void this.logDecision(result, req, featureSnapshot as Record, false); + return result; + } + + private async logDecision( + result: RoutingResult, + req: RoutingRequest, + featureSnapshot: Record, + isManualOverride: boolean, + ): Promise { + try { + await prisma.routingDecision.create({ + data: { + requestId: result.requestId, + tenantId: req.tenantId, + selectedChain: result.selectedChain, + fallbackChains: result.fallbackChains, + scoreStellar: result.scores['stellar'] ?? null, + scoreEvm: result.scores['evm'] ?? null, + featureSnapshot, + rationale: result.rationale, + latencyMs: result.latencyMs, + isManualOverride, + overrideBy: req.manualOverride?.actor, + abVariant: req.abVariant ?? 'ai', + }, + }); + } catch { + // non-fatal: routing continues even if logging fails + } + } +} + +export const aiRouter = new AiPaymentRouter(); diff --git a/backend/src/websocket/event-stream.ts b/backend/src/websocket/event-stream.ts new file mode 100644 index 00000000..994aa234 --- /dev/null +++ b/backend/src/websocket/event-stream.ts @@ -0,0 +1,101 @@ +/** + * WebSocket event stream – bridges the indexer event bus to connected WS clients. + * Supports filtering by contractAddress, eventType, and time range. + * Uses Redis pub/sub for cross-instance broadcasting. + */ +import { EventEmitter } from 'node:events'; +import type { AgenticPayWebSocketServer } from '../websocket/server.js'; +import type { NormalizedEvent } from './soroban-listener.js'; + +export interface EventStreamOptions { + wsServer: AgenticPayWebSocketServer; + /** Optional Redis client for cross-instance pub/sub; if omitted events only broadcast locally */ + redisPublish?: (channel: string, message: string) => Promise; + redisSubscribe?: (channel: string, handler: (message: string) => void) => Promise; +} + +export interface EventFilter { + contractAddress?: string; + eventType?: string; + fromTimestamp?: Date; + toTimestamp?: Date; +} + +const REDIS_CHANNEL = 'indexer:events'; + +export class EventStreamHandler extends EventEmitter { + private readonly wsServer: AgenticPayWebSocketServer; + private readonly redisPublish?: EventStreamOptions['redisPublish']; + + constructor(opts: EventStreamOptions) { + super(); + this.wsServer = opts.wsServer; + this.redisPublish = opts.redisPublish; + + // Subscribe to cross-instance events + if (opts.redisSubscribe) { + void opts.redisSubscribe(REDIS_CHANNEL, (raw) => { + try { + const event = JSON.parse(raw) as NormalizedEvent; + this.pushToClients(event); + } catch { + // ignore malformed messages + } + }); + } + } + + /** + * Called by SorobanListener / EvmListener when a new event is confirmed. + */ + async ingest(event: NormalizedEvent): Promise { + this.emit('indexed', event); + + // Fan-out to local WS clients + this.pushToClients(event); + + // Fan-out to other instances via Redis pub/sub + if (this.redisPublish) { + await this.redisPublish(REDIS_CHANNEL, JSON.stringify(event)).catch(() => { + // non-fatal: local delivery already done + }); + } + } + + private pushToClients(event: NormalizedEvent): void { + // Broadcast to the contract-specific channel AND the wildcard channel + const channels = [ + `indexer.${event.chain}.${event.contractAddress.toLowerCase()}`, + `indexer.${event.chain}.all`, + 'indexer.all', + ]; + + for (const channel of channels) { + this.wsServer.broadcastToChannel(channel, { + type: 'indexer.event', + payload: { + id: event.id, + chain: event.chain, + contractAddress: event.contractAddress, + eventType: event.eventType, + blockNumber: event.blockNumber, + txHash: event.txHash, + timestamp: event.timestamp.toISOString(), + payload: event.payload, + confirmations: event.confirmations, + }, + }); + } + } + + /** Filter helper used by the REST history endpoint */ + static matchesFilter(event: NormalizedEvent, filter: EventFilter): boolean { + if (filter.contractAddress && event.contractAddress.toLowerCase() !== filter.contractAddress.toLowerCase()) { + return false; + } + if (filter.eventType && event.eventType !== filter.eventType) return false; + if (filter.fromTimestamp && event.timestamp < filter.fromTimestamp) return false; + if (filter.toTimestamp && event.timestamp > filter.toTimestamp) return false; + return true; + } +} diff --git a/sdks/go/examples/basic/main.go b/sdks/go/examples/basic/main.go new file mode 100644 index 00000000..ebd1bcc4 --- /dev/null +++ b/sdks/go/examples/basic/main.go @@ -0,0 +1,78 @@ +// Package main demonstrates common SubTrackr Go SDK workflows. +package main + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/Smartdevs17/subtrackr/sdks/go/subtrackr" +) + +func main() { + client := subtrackr.New("https://api.subtrackr.io", "sk_live_your_key_here") + ctx := context.Background() + + // ── Create a subscription ───────────────────────────────────────────────── + sub, err := client.Subscriptions.Create(ctx, subtrackr.CreateSubscriptionParams{ + CustomerID: "cus_abc123", + PlanID: "plan_monthly", + TrialDays: 14, + }) + if err != nil { + log.Fatalf("create subscription: %v", err) + } + fmt.Printf("Created subscription %s (status: %s)\n", sub.ID, sub.Status) + + // ── List subscriptions with pagination ──────────────────────────────────── + page, err := client.Subscriptions.List(ctx, subtrackr.ListParams{Limit: 20}) + if err != nil { + log.Fatalf("list subscriptions: %v", err) + } + fmt.Printf("Total subscriptions: %d\n", page.Total) + + // ── Pause and reactivate ────────────────────────────────────────────────── + resumeAt := time.Now().Add(7 * 24 * time.Hour) + paused, err := client.Subscriptions.Pause(ctx, sub.ID, subtrackr.PauseParams{ResumeAt: &resumeAt}) + if err != nil { + log.Fatalf("pause: %v", err) + } + fmt.Printf("Paused: %s\n", paused.Status) + + active, err := client.Subscriptions.Reactivate(ctx, sub.ID) + if err != nil { + log.Fatalf("reactivate: %v", err) + } + fmt.Printf("Reactivated: %s\n", active.Status) + + // ── Report metered usage ────────────────────────────────────────────────── + _, err = client.Metering.Report(ctx, subtrackr.ReportUsageParams{ + SubscriptionID: sub.ID, + Feature: "api_calls", + Quantity: 150, + Timestamp: time.Now(), + }) + if err != nil { + log.Fatalf("report usage: %v", err) + } + + // ── Verify a webhook ────────────────────────────────────────────────────── + body := []byte(`{"id":"evt_1","type":"subscription.created","createdAt":"2026-01-01T00:00:00Z","data":{}}`) + event, err := client.Webhooks.Verify("whsec_your_secret", "sha256=...", body) + if err != nil { + fmt.Println("Webhook signature invalid:", err) + } else { + fmt.Printf("Received event: %s\n", event.Type) + } + + // ── Cancel at period end ────────────────────────────────────────────────── + cancelled, err := client.Subscriptions.Cancel(ctx, sub.ID, subtrackr.CancelParams{ + Immediately: false, + Reason: "customer request", + }) + if err != nil { + log.Fatalf("cancel: %v", err) + } + fmt.Printf("Cancellation scheduled: %s\n", cancelled.Status) +} diff --git a/sdks/go/go.mod b/sdks/go/go.mod new file mode 100644 index 00000000..952e467a --- /dev/null +++ b/sdks/go/go.mod @@ -0,0 +1,3 @@ +module github.com/Smartdevs17/subtrackr/sdks/go + +go 1.21 diff --git a/sdks/go/subtrackr/client.go b/sdks/go/subtrackr/client.go new file mode 100644 index 00000000..f74a91c8 --- /dev/null +++ b/sdks/go/subtrackr/client.go @@ -0,0 +1,181 @@ +// Package subtrackr provides a Go SDK for the SubTrackr subscription management API. +// It covers the full subscription lifecycle: create, read, update, cancel, pause, +// reactivate, dunning, billing, metering, and webhook verification. +// +// Usage: +// +// client := subtrackr.New("https://api.subtrackr.io", "sk_live_...") +// sub, err := client.Subscriptions.Create(ctx, subtrackr.CreateSubscriptionParams{...}) +package subtrackr + +import ( + "bytes" + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "math" + "net/http" + "net/url" + "strconv" + "time" +) + +// ─── Client ─────────────────────────────────────────────────────────────────── + +// Client is the root SubTrackr API client. +type Client struct { + baseURL string + apiKey string + httpClient *http.Client + + Subscriptions *SubscriptionService + Billing *BillingService + Dunning *DunningService + Metering *MeteringService + Webhooks *WebhookService +} + +// New creates a new SubTrackr client. +func New(baseURL, apiKey string) *Client { + c := &Client{ + baseURL: baseURL, + apiKey: apiKey, + httpClient: &http.Client{Timeout: 30 * time.Second}, + } + c.Subscriptions = &SubscriptionService{client: c} + c.Billing = &BillingService{client: c} + c.Dunning = &DunningService{client: c} + c.Metering = &MeteringService{client: c} + c.Webhooks = &WebhookService{client: c} + return c +} + +// ─── HTTP helpers ───────────────────────────────────────────────────────────── + +// APIError represents an error returned by the API. +type APIError struct { + StatusCode int `json:"-"` + Code string `json:"code"` + Message string `json:"message"` +} + +func (e *APIError) Error() string { + return fmt.Sprintf("subtrackr: %s (HTTP %d, code=%s)", e.Message, e.StatusCode, e.Code) +} + +func (c *Client) do(ctx context.Context, method, path string, body, out interface{}) error { + return c.doWithRetry(ctx, method, path, body, out, 3) +} + +func (c *Client) doWithRetry(ctx context.Context, method, path string, body, out interface{}, maxAttempts int) error { + var lastErr error + for attempt := 0; attempt < maxAttempts; attempt++ { + if attempt > 0 { + wait := time.Duration(math.Pow(2, float64(attempt))) * 200 * time.Millisecond + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(wait): + } + } + err := c.doOnce(ctx, method, path, body, out) + if err == nil { + return nil + } + // Retry on 429 or 5xx + if apiErr, ok := err.(*APIError); ok { + if apiErr.StatusCode == 429 || apiErr.StatusCode >= 500 { + lastErr = err + continue + } + } + return err // Non-retryable + } + return lastErr +} + +func (c *Client) doOnce(ctx context.Context, method, path string, body, out interface{}) error { + var reqBody io.Reader + if body != nil { + b, err := json.Marshal(body) + if err != nil { + return fmt.Errorf("subtrackr: marshal request: %w", err) + } + reqBody = bytes.NewReader(b) + } + + req, err := http.NewRequestWithContext(ctx, method, c.baseURL+path, reqBody) + if err != nil { + return fmt.Errorf("subtrackr: build request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+c.apiKey) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("subtrackr: http: %w", err) + } + defer resp.Body.Close() + + respBytes, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("subtrackr: read response: %w", err) + } + + if resp.StatusCode >= 400 { + var apiErr APIError + apiErr.StatusCode = resp.StatusCode + _ = json.Unmarshal(respBytes, &apiErr) + if apiErr.Message == "" { + apiErr.Message = http.StatusText(resp.StatusCode) + } + return &apiErr + } + + if out != nil { + if err := json.Unmarshal(respBytes, out); err != nil { + return fmt.Errorf("subtrackr: unmarshal response: %w", err) + } + } + return nil +} + +// ─── Pagination ─────────────────────────────────────────────────────────────── + +// ListParams contains common pagination parameters. +type ListParams struct { + Limit int `json:"-"` + Offset int `json:"-"` + Cursor string `json:"-"` +} + +func (p ListParams) toQuery() string { + q := url.Values{} + if p.Limit > 0 { + q.Set("limit", strconv.Itoa(p.Limit)) + } + if p.Offset > 0 { + q.Set("offset", strconv.Itoa(p.Offset)) + } + if p.Cursor != "" { + q.Set("cursor", p.Cursor) + } + if len(q) == 0 { + return "" + } + return "?" + q.Encode() +} + +// Page is a generic paginated response wrapper. +type Page[T any] struct { + Data []T `json:"data"` + Total int `json:"total"` + Limit int `json:"limit"` + Offset int `json:"offset"` + NextCursor string `json:"nextCursor,omitempty"` +} diff --git a/sdks/go/subtrackr/client_test.go b/sdks/go/subtrackr/client_test.go new file mode 100644 index 00000000..1d912ecb --- /dev/null +++ b/sdks/go/subtrackr/client_test.go @@ -0,0 +1,221 @@ +package subtrackr_test + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Smartdevs17/subtrackr/sdks/go/subtrackr" +) + +// newTestServer creates a test HTTP server that always returns the given status +// and JSON-encoded body. +func newTestServer(t *testing.T, status int, body interface{}) (*httptest.Server, *subtrackr.Client) { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(body) + })) + t.Cleanup(srv.Close) + return srv, subtrackr.New(srv.URL, "test_key") +} + +// ─── Subscription lifecycle ─────────────────────────────────────────────────── + +func TestSubscriptionCreate(t *testing.T) { + now := time.Now().UTC().Truncate(time.Second) + want := subtrackr.Subscription{ + ID: "sub_001", + CustomerID: "cus_001", + PlanID: "plan_monthly", + Status: subtrackr.SubscriptionStatusActive, + CreatedAt: now, + UpdatedAt: now, + } + _, client := newTestServer(t, http.StatusOK, want) + + got, err := client.Subscriptions.Create(context.Background(), subtrackr.CreateSubscriptionParams{ + CustomerID: "cus_001", + PlanID: "plan_monthly", + }) + if err != nil { + t.Fatalf("Create: unexpected error: %v", err) + } + if got.ID != want.ID { + t.Errorf("ID: got %q, want %q", got.ID, want.ID) + } + if got.Status != want.Status { + t.Errorf("Status: got %q, want %q", got.Status, want.Status) + } +} + +func TestSubscriptionGet(t *testing.T) { + want := subtrackr.Subscription{ID: "sub_002", Status: subtrackr.SubscriptionStatusPaused} + _, client := newTestServer(t, http.StatusOK, want) + + got, err := client.Subscriptions.Get(context.Background(), "sub_002") + if err != nil { + t.Fatalf("Get: %v", err) + } + if got.ID != "sub_002" { + t.Errorf("ID mismatch: %s", got.ID) + } +} + +func TestSubscriptionUpdate(t *testing.T) { + newPlan := "plan_yearly" + want := subtrackr.Subscription{ID: "sub_003", PlanID: newPlan, Status: subtrackr.SubscriptionStatusActive} + _, client := newTestServer(t, http.StatusOK, want) + + got, err := client.Subscriptions.Update(context.Background(), "sub_003", subtrackr.UpdateSubscriptionParams{ + PlanID: &newPlan, + }) + if err != nil { + t.Fatalf("Update: %v", err) + } + if got.PlanID != newPlan { + t.Errorf("PlanID: got %q, want %q", got.PlanID, newPlan) + } +} + +func TestSubscriptionCancel(t *testing.T) { + now := time.Now() + want := subtrackr.Subscription{ID: "sub_004", Status: subtrackr.SubscriptionStatusCancelled, CancelledAt: &now} + _, client := newTestServer(t, http.StatusOK, want) + + got, err := client.Subscriptions.Cancel(context.Background(), "sub_004", subtrackr.CancelParams{Immediately: true}) + if err != nil { + t.Fatalf("Cancel: %v", err) + } + if got.Status != subtrackr.SubscriptionStatusCancelled { + t.Errorf("Status: got %q", got.Status) + } +} + +func TestSubscriptionPauseReactivate(t *testing.T) { + tests := []struct { + name string + serverBody subtrackr.Subscription + run func(c *subtrackr.Client) (*subtrackr.Subscription, error) + }{ + { + name: "pause", + serverBody: subtrackr.Subscription{ID: "sub_005", Status: subtrackr.SubscriptionStatusPaused}, + run: func(c *subtrackr.Client) (*subtrackr.Subscription, error) { + return c.Subscriptions.Pause(context.Background(), "sub_005", subtrackr.PauseParams{}) + }, + }, + { + name: "reactivate", + serverBody: subtrackr.Subscription{ID: "sub_005", Status: subtrackr.SubscriptionStatusActive}, + run: func(c *subtrackr.Client) (*subtrackr.Subscription, error) { + return c.Subscriptions.Reactivate(context.Background(), "sub_005") + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, client := newTestServer(t, http.StatusOK, tc.serverBody) + got, err := tc.run(client) + if err != nil { + t.Fatalf("%s: %v", tc.name, err) + } + if got.Status != tc.serverBody.Status { + t.Errorf("Status: got %q, want %q", got.Status, tc.serverBody.Status) + } + }) + } +} + +// ─── Webhook signature verification ────────────────────────────────────────── + +func makeSignature(secret string, body []byte) string { + mac := hmac.New(sha256.New, []byte(secret)) + mac.Write(body) + return "sha256=" + hex.EncodeToString(mac.Sum(nil)) +} + +func TestWebhookVerify(t *testing.T) { + client := subtrackr.New("https://example.com", "key") + secret := "whsec_test" + body := []byte(`{"id":"evt_1","type":"subscription.created","createdAt":"2026-01-01T00:00:00Z","data":{}}`) + validSig := makeSignature(secret, body) + + tests := []struct { + name string + sig string + wantErr bool + }{ + {"valid signature", validSig, false}, + {"wrong signature", "sha256=deadbeef", true}, + {"empty signature", "", true}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + evt, err := client.Webhooks.Verify(secret, tc.sig, body) + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if evt.ID != "evt_1" { + t.Errorf("event ID: got %q", evt.ID) + } + }) + } +} + +// ─── Error handling ─────────────────────────────────────────────────────────── + +func TestAPIError(t *testing.T) { + errBody := map[string]string{"code": "not_found", "message": "subscription not found"} + _, client := newTestServer(t, http.StatusNotFound, errBody) + + _, err := client.Subscriptions.Get(context.Background(), "sub_missing") + if err == nil { + t.Fatal("expected error, got nil") + } + apiErr, ok := err.(*subtrackr.APIError) + if !ok { + t.Fatalf("expected *APIError, got %T", err) + } + if apiErr.StatusCode != http.StatusNotFound { + t.Errorf("StatusCode: got %d, want %d", apiErr.StatusCode, http.StatusNotFound) + } + if apiErr.Code != "not_found" { + t.Errorf("Code: got %q, want %q", apiErr.Code, "not_found") + } +} + +func TestListPagination(t *testing.T) { + want := subtrackr.Page[subtrackr.Subscription]{ + Data: []subtrackr.Subscription{{ID: "sub_p1"}, {ID: "sub_p2"}}, + Total: 100, + Limit: 2, + } + _, client := newTestServer(t, http.StatusOK, want) + + page, err := client.Subscriptions.List(context.Background(), subtrackr.ListParams{Limit: 2}) + if err != nil { + t.Fatalf("List: %v", err) + } + if page.Total != 100 { + t.Errorf("Total: got %d, want 100", page.Total) + } + if len(page.Data) != 2 { + t.Errorf("len(Data): got %d, want 2", len(page.Data)) + } +} diff --git a/sdks/go/subtrackr/services.go b/sdks/go/subtrackr/services.go new file mode 100644 index 00000000..180595ba --- /dev/null +++ b/sdks/go/subtrackr/services.go @@ -0,0 +1,168 @@ +package subtrackr + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "time" +) + +// ─── Billing ────────────────────────────────────────────────────────────────── + +// Invoice represents a billing invoice. +type Invoice struct { + ID string `json:"id"` + SubscriptionID string `json:"subscriptionId"` + CustomerID string `json:"customerId"` + AmountDue int64 `json:"amountDue"` // smallest currency unit + AmountPaid int64 `json:"amountPaid"` + Currency string `json:"currency"` + Status string `json:"status"` + DueDate *time.Time `json:"dueDate,omitempty"` + PaidAt *time.Time `json:"paidAt,omitempty"` + CreatedAt time.Time `json:"createdAt"` +} + +// BillingService manages invoices and payment retries. +type BillingService struct { + client *Client +} + +// GetInvoice retrieves a specific invoice. +func (b *BillingService) GetInvoice(ctx context.Context, id string) (*Invoice, error) { + var out Invoice + return &out, b.client.do(ctx, "GET", fmt.Sprintf("/api/v1/billing/invoices/%s", id), nil, &out) +} + +// ListInvoices lists invoices for a subscription. +func (b *BillingService) ListInvoices(ctx context.Context, subscriptionID string, p ListParams) (*Page[Invoice], error) { + var out Page[Invoice] + path := fmt.Sprintf("/api/v1/billing/invoices?subscriptionId=%s%s", subscriptionID, p.toQuery()) + return &out, b.client.do(ctx, "GET", path, nil, &out) +} + +// RetryInvoice retries payment for a failed invoice. +func (b *BillingService) RetryInvoice(ctx context.Context, invoiceID string) (*Invoice, error) { + var out Invoice + return &out, b.client.do(ctx, "POST", fmt.Sprintf("/api/v1/billing/invoices/%s/retry", invoiceID), nil, &out) +} + +// ─── Dunning ────────────────────────────────────────────────────────────────── + +// DunningRecord represents the dunning state for a subscription. +type DunningRecord struct { + SubscriptionID string `json:"subscriptionId"` + Attempts int `json:"attempts"` + NextRetryAt *time.Time `json:"nextRetryAt,omitempty"` + Status string `json:"status"` // "active" | "resolved" | "failed" + LastError string `json:"lastError,omitempty"` + UpdatedAt time.Time `json:"updatedAt"` +} + +// DunningService manages dunning (failed-payment recovery) workflows. +type DunningService struct { + client *Client +} + +// Get returns the current dunning state for a subscription. +func (d *DunningService) Get(ctx context.Context, subscriptionID string) (*DunningRecord, error) { + var out DunningRecord + return &out, d.client.do(ctx, "GET", fmt.Sprintf("/api/v1/dunning/%s", subscriptionID), nil, &out) +} + +// Resolve manually marks a dunning cycle as resolved (e.g., after out-of-band payment). +func (d *DunningService) Resolve(ctx context.Context, subscriptionID string) (*DunningRecord, error) { + var out DunningRecord + return &out, d.client.do(ctx, "POST", fmt.Sprintf("/api/v1/dunning/%s/resolve", subscriptionID), nil, &out) +} + +// ─── Metering ───────────────────────────────────────────────────────────────── + +// UsageRecord records metered usage for a subscription feature. +type UsageRecord struct { + ID string `json:"id"` + SubscriptionID string `json:"subscriptionId"` + Feature string `json:"feature"` + Quantity float64 `json:"quantity"` + Timestamp time.Time `json:"timestamp"` +} + +// ReportUsageParams describes a usage event to report. +type ReportUsageParams struct { + SubscriptionID string `json:"subscriptionId"` + Feature string `json:"feature"` + Quantity float64 `json:"quantity"` + Timestamp time.Time `json:"timestamp,omitempty"` +} + +// UsageSummary summarises metered usage for a billing period. +type UsageSummary struct { + SubscriptionID string `json:"subscriptionId"` + PeriodStart time.Time `json:"periodStart"` + PeriodEnd time.Time `json:"periodEnd"` + Totals map[string]float64 `json:"totals"` // feature -> total quantity +} + +// MeteringService manages usage metering. +type MeteringService struct { + client *Client +} + +// Report records a usage event. +func (m *MeteringService) Report(ctx context.Context, p ReportUsageParams) (*UsageRecord, error) { + var out UsageRecord + return &out, m.client.do(ctx, "POST", "/api/v1/metering/usage", p, &out) +} + +// GetSummary returns aggregated usage for a subscription's current billing period. +func (m *MeteringService) GetSummary(ctx context.Context, subscriptionID string) (*UsageSummary, error) { + var out UsageSummary + return &out, m.client.do(ctx, "GET", + fmt.Sprintf("/api/v1/metering/usage/%s/summary", subscriptionID), nil, &out) +} + +// ─── Webhooks ───────────────────────────────────────────────────────────────── + +// WebhookEvent is the parsed payload of an inbound webhook. +type WebhookEvent struct { + ID string `json:"id"` + Type string `json:"type"` + CreatedAt time.Time `json:"createdAt"` + Data json.RawMessage `json:"data"` +} + +// ErrInvalidSignature is returned when webhook signature verification fails. +var ErrInvalidSignature = errors.New("subtrackr: invalid webhook signature") + +// WebhookService handles webhook verification. +type WebhookService struct { + client *Client +} + +// Verify verifies the HMAC-SHA256 signature of an incoming webhook and returns +// the parsed event. sigHeader is the value of the X-SubTrackr-Signature header. +// +// Timing-safe comparison is used to prevent side-channel attacks. +func (w *WebhookService) Verify(secret, sigHeader string, body []byte) (*WebhookEvent, error) { + if sigHeader == "" { + return nil, ErrInvalidSignature + } + + mac := hmac.New(sha256.New, []byte(secret)) + mac.Write(body) + expected := "sha256=" + hex.EncodeToString(mac.Sum(nil)) + + if !hmac.Equal([]byte(expected), []byte(sigHeader)) { + return nil, ErrInvalidSignature + } + + var evt WebhookEvent + if err := json.Unmarshal(body, &evt); err != nil { + return nil, fmt.Errorf("subtrackr: parse webhook event: %w", err) + } + return &evt, nil +} diff --git a/sdks/go/subtrackr/subscriptions.go b/sdks/go/subtrackr/subscriptions.go new file mode 100644 index 00000000..76833f28 --- /dev/null +++ b/sdks/go/subtrackr/subscriptions.go @@ -0,0 +1,147 @@ +package subtrackr + +import ( + "context" + "fmt" + "time" +) + +// ─── Types ──────────────────────────────────────────────────────────────────── + +// SubscriptionStatus represents the current lifecycle state of a subscription. +type SubscriptionStatus string + +const ( + SubscriptionStatusActive SubscriptionStatus = "active" + SubscriptionStatusPaused SubscriptionStatus = "paused" + SubscriptionStatusCancelled SubscriptionStatus = "cancelled" + SubscriptionStatusTrialing SubscriptionStatus = "trialing" + SubscriptionStatusPastDue SubscriptionStatus = "past_due" + SubscriptionStatusUnpaid SubscriptionStatus = "unpaid" +) + +// BillingInterval represents the billing frequency. +type BillingInterval string + +const ( + BillingIntervalDaily BillingInterval = "daily" + BillingIntervalWeekly BillingInterval = "weekly" + BillingIntervalMonthly BillingInterval = "monthly" + BillingIntervalYearly BillingInterval = "yearly" +) + +// Subscription is the full subscription resource. +type Subscription struct { + ID string `json:"id"` + CustomerID string `json:"customerId"` + PlanID string `json:"planId"` + Status SubscriptionStatus `json:"status"` + CurrentPeriodStart time.Time `json:"currentPeriodStart"` + CurrentPeriodEnd time.Time `json:"currentPeriodEnd"` + CancelAtPeriodEnd bool `json:"cancelAtPeriodEnd"` + PausedAt *time.Time `json:"pausedAt,omitempty"` + CancelledAt *time.Time `json:"cancelledAt,omitempty"` + TrialEnd *time.Time `json:"trialEnd,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` +} + +// CreateSubscriptionParams are the parameters for creating a subscription. +type CreateSubscriptionParams struct { + CustomerID string `json:"customerId"` + PlanID string `json:"planId"` + TrialDays int `json:"trialDays,omitempty"` + CancelAtPeriodEnd bool `json:"cancelAtPeriodEnd,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +// UpdateSubscriptionParams are the parameters for updating a subscription. +type UpdateSubscriptionParams struct { + PlanID *string `json:"planId,omitempty"` + CancelAtPeriodEnd *bool `json:"cancelAtPeriodEnd,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +// PauseParams controls subscription pause behaviour. +type PauseParams struct { + // ResumeAt, if non-nil, schedules automatic reactivation. + ResumeAt *time.Time `json:"resumeAt,omitempty"` +} + +// CancelParams controls subscription cancellation behaviour. +type CancelParams struct { + // Immediately cancels when true; otherwise cancels at period end. + Immediately bool `json:"immediately,omitempty"` + Reason string `json:"reason,omitempty"` +} + +// ─── Service ────────────────────────────────────────────────────────────────── + +// SubscriptionService exposes all subscription lifecycle operations. +type SubscriptionService struct { + client *Client +} + +// Create creates a new subscription. +func (s *SubscriptionService) Create(ctx context.Context, p CreateSubscriptionParams) (*Subscription, error) { + var out Subscription + if err := s.client.do(ctx, "POST", "/api/v1/subscriptions", p, &out); err != nil { + return nil, err + } + return &out, nil +} + +// Get retrieves a subscription by ID. +func (s *SubscriptionService) Get(ctx context.Context, id string) (*Subscription, error) { + var out Subscription + if err := s.client.do(ctx, "GET", fmt.Sprintf("/api/v1/subscriptions/%s", id), nil, &out); err != nil { + return nil, err + } + return &out, nil +} + +// List returns a paginated list of subscriptions. +func (s *SubscriptionService) List(ctx context.Context, p ListParams) (*Page[Subscription], error) { + var out Page[Subscription] + if err := s.client.do(ctx, "GET", "/api/v1/subscriptions"+p.toQuery(), nil, &out); err != nil { + return nil, err + } + return &out, nil +} + +// Update updates mutable fields on a subscription. +func (s *SubscriptionService) Update(ctx context.Context, id string, p UpdateSubscriptionParams) (*Subscription, error) { + var out Subscription + if err := s.client.do(ctx, "PATCH", fmt.Sprintf("/api/v1/subscriptions/%s", id), p, &out); err != nil { + return nil, err + } + return &out, nil +} + +// Cancel cancels a subscription, either immediately or at period end. +func (s *SubscriptionService) Cancel(ctx context.Context, id string, p CancelParams) (*Subscription, error) { + var out Subscription + if err := s.client.do(ctx, "DELETE", fmt.Sprintf("/api/v1/subscriptions/%s", id), p, &out); err != nil { + return nil, err + } + return &out, nil +} + +// Pause pauses a subscription, optionally scheduling automatic reactivation. +func (s *SubscriptionService) Pause(ctx context.Context, id string, p PauseParams) (*Subscription, error) { + var out Subscription + if err := s.client.do(ctx, "POST", fmt.Sprintf("/api/v1/subscriptions/%s/pause", id), p, &out); err != nil { + return nil, err + } + return &out, nil +} + +// Reactivate reactivates a paused or cancelled subscription. +func (s *SubscriptionService) Reactivate(ctx context.Context, id string) (*Subscription, error) { + var out Subscription + if err := s.client.do(ctx, "POST", fmt.Sprintf("/api/v1/subscriptions/%s/reactivate", id), nil, &out); err != nil { + return nil, err + } + return &out, nil +}