diff --git a/README.md b/README.md index 57b11c3..24749ac 100644 --- a/README.md +++ b/README.md @@ -240,6 +240,33 @@ Add to your Claude Desktop configuration (`claude_desktop_config.json`): ### Claude Code CLI +#### Windows + +Add the following to `%USERPROFILE%\.claude.json`: + +```json +{ + "mcpServers": { + "sqlserver": { + "command": "mcp-sqlserver", + "env": { + "SQLSERVER_HOST": "your-server", + "SQLSERVER_USER": "your-username", + "SQLSERVER_PASSWORD": "your-password", + "SQLSERVER_DATABASE": "your-database", + "SQLSERVER_PORT": "1433", + "SQLSERVER_ENCRYPT": "true", + "SQLSERVER_TRUST_CERT": "false" + } + } + } +} +``` + +#### macOS/Linux + +Add the server using the Claude Code CLI: + ```bash # Set environment variables export SQLSERVER_HOST="your-server" @@ -250,6 +277,27 @@ export SQLSERVER_PASSWORD="your-password" claude mcp add sqlserver mcp-sqlserver ``` +Or configure via `~/.claude.json`: + +```json +{ + "mcpServers": { + "sqlserver": { + "command": "mcp-sqlserver", + "env": { + "SQLSERVER_HOST": "your-server", + "SQLSERVER_USER": "your-username", + "SQLSERVER_PASSWORD": "your-password", + "SQLSERVER_DATABASE": "your-database", + "SQLSERVER_PORT": "1433", + "SQLSERVER_ENCRYPT": "true", + "SQLSERVER_TRUST_CERT": "false" + } + } + } +} +``` + ### VSCode with MCP Extension Install the MCP extension for VSCode and add the server configuration. @@ -325,12 +373,35 @@ npm test ## Troubleshooting ### Connection Issues + +Use the `--test-connection` flag to diagnose connectivity problems: + +```bash +# Test your connection (shows detailed diagnostics) +mcp-sqlserver --test-connection +# or use the short flag +mcp-sqlserver -t +``` + +The command provides helpful error suggestions based on the error type: + +| Error Type | Likely Cause | Suggested Fix | +|------------|-------------|---------------| +| `ENOTFOUND` / Server not found | Incorrect hostname | Check `SQLSERVER_HOST` | +| Login failed (18456) | Wrong username/password | Verify credentials | +| Timeout | Server unreachable or wrong port | Check port (default: 1433) | +| SSL/Certificate error | Self-signed certificate | Set `SQLSERVER_TRUST_CERT=true` | +| Connection refused | Firewall or SQL Server not running | Check server is running | + +### Before Troubleshooting + 1. Verify server hostname and port 2. Check if encryption/certificate settings match your SQL Server configuration 3. Ensure user has appropriate read permissions 4. Test connection using SQL Server Management Studio first ### Permission Issues + The user account needs at minimum: - `CONNECT` permission to the database - `SELECT` permission on tables/views you want to query diff --git a/jest.config.js b/jest.config.js new file mode 100644 index 0000000..b655cc2 --- /dev/null +++ b/jest.config.js @@ -0,0 +1,12 @@ +export default { + testEnvironment: 'node', + extensionsToTreatAsEsm: ['.ts'], + moduleNameMapper: { + '^(\\.{1,2}/.*)\\.js$': '$1', + }, + transform: { + '^.+\\.ts$': ['ts-jest', { useESM: true }], + }, + testMatch: ['**/test/**/*.test.ts'], + collectCoverageFrom: ['src/**/*.ts', '!src/**/*.d.ts', '!src/__tests__/**'], +}; diff --git a/package.json b/package.json index 63424e7..e46ca3b 100644 --- a/package.json +++ b/package.json @@ -54,16 +54,18 @@ "zod": "^3.23.8" }, "devDependencies": { + "@types/jest": "^30.0.0", "@types/mssql": "^9.1.5", "@types/node": "^22.0.0", "@typescript-eslint/eslint-plugin": "^8.0.0", "@typescript-eslint/parser": "^8.0.0", "eslint": "^9.0.0", "jest": "^29.7.0", + "ts-jest": "^29.4.9", "tsx": "^4.16.0", "typescript": "^5.5.0" }, "engines": { "node": ">=18" } -} \ No newline at end of file +} diff --git a/src/cli.ts b/src/cli.ts index d1aa0dc..40bfd1b 100644 --- a/src/cli.ts +++ b/src/cli.ts @@ -1,6 +1,7 @@ #!/usr/bin/env node import { ConnectionConfigSchema } from './types.js'; +import { SqlServerConnection } from './connection.js'; function showHelp() { console.log(` @@ -11,12 +12,17 @@ USAGE: ENVIRONMENT VARIABLES: SQLSERVER_HOST SQL Server hostname (required) - SQLSERVER_USER Database username (required) + SQLSERVER_USER Database username (required) SQLSERVER_PASSWORD Database password (required) SQLSERVER_DATABASE Database name (optional, default: master) SQLSERVER_PORT Port number (optional, default: 1433) SQLSERVER_ENCRYPT Enable encryption (optional, default: true) - SQLSERVER_TRUST_CERT Trust server certificate (optional, default: true) + SQLSERVER_TRUST_CERT Trust server certificate (optional, default: false) + +OPTIONS: + --help, -h Show this help message + --version, -v Show version number + --test-connection Test SQL Server connection and exit EXAMPLES: # Set environment variables and run @@ -25,6 +31,9 @@ EXAMPLES: export SQLSERVER_PASSWORD="your-password" mcp-sqlserver + # Test connection + mcp-sqlserver --test-connection + # Using Claude Desktop (add to claude_desktop_config.json): { "mcpServers": { @@ -62,13 +71,140 @@ For more information, visit: https://github.com/bilims/mcp-sqlserver function showVersion() { // Read version from package.json - console.log('2.0.1'); + console.log('2.0.3'); +} + +async function testConnection(): Promise { + const config = { + server: process.env.SQLSERVER_HOST || 'localhost', + user: process.env.SQLSERVER_USER || '', + password: process.env.SQLSERVER_PASSWORD || '', + database: process.env.SQLSERVER_DATABASE, + port: parseInt(process.env.SQLSERVER_PORT || '1433'), + encrypt: process.env.SQLSERVER_ENCRYPT !== 'false', + trustServerCertificate: process.env.SQLSERVER_TRUST_CERT !== 'false', + connectionTimeout: parseInt(process.env.SQLSERVER_CONNECTION_TIMEOUT || '15000'), + requestTimeout: parseInt(process.env.SQLSERVER_REQUEST_TIMEOUT || '30000'), + maxRows: parseInt(process.env.SQLSERVER_MAX_ROWS || '1000'), + }; + + // Validate config + try { + ConnectionConfigSchema.parse(config); + } catch (error) { + console.error('āŒ Configuration error:'); + console.error(error); + process.exit(1); + } + + if (!config.user || !config.password) { + console.error('āŒ Missing credentials:'); + console.error(' SQLSERVER_USER and SQLSERVER_PASSWORD are required'); + process.exit(1); + } + + console.log(`\nšŸ” Testing connection to ${config.server}:${config.port}/${config.database || 'default'}`); + console.log(` Encryption: ${config.encrypt ? 'enabled' : 'disabled'}`); + console.log(` Trust Certificate: ${config.trustServerCertificate ? 'yes' : 'no'}`); + console.log(''); + + const startTime = Date.now(); + const connection = new SqlServerConnection(config); + + try { + await connection.connect(); + const connectionTime = Date.now() - startTime; + + console.log('āœ… Connection successful!'); + console.log(` Connection time: ${connectionTime}ms`); + console.log(` Connected: ${connection.isConnected()}`); + + // Try to get server info + try { + const result = await connection.query(` + SELECT + @@SERVERNAME as serverName, + @@VERSION as version, + DB_NAME() as currentDatabase + `); + + if (result.recordset.length > 0) { + const info = result.recordset[0]; + console.log(`\nšŸ“‹ Server Info:`); + console.log(` Server: ${info.serverName}`); + console.log(` Database: ${info.currentDatabase}`); + // Version is multi-line, just show first line + const versionLine = info.version.split('\n')[0]; + console.log(` Version: ${versionLine}`); + } + } catch (queryError) { + // Server info is nice-to-have, don't fail on this + } + + await connection.disconnect(); + process.exit(0); + + } catch (error) { + const connectionTime = Date.now() - startTime; + const errorMessage = error instanceof Error ? error.message : String(error); + + console.error('āŒ Connection failed!'); + console.error(` Error: ${errorMessage}`); + console.error(` Time: ${connectionTime}ms`); + + // Provide helpful suggestions based on error + const suggestion = getConnectionSuggestion(errorMessage); + if (suggestion) { + console.error(`\nšŸ’” ${suggestion}`); + } + + // Show environment info for debugging + console.error(`\nšŸ”§ Environment:`); + console.error(` Host: ${config.server}:${config.port}`); + console.error(` Database: ${config.database || '(default)'}`); + console.error(` Encrypt: ${config.encrypt}`); + console.error(` Trust Cert: ${config.trustServerCertificate}`); + + await connection.disconnect().catch(() => {}); + process.exit(1); + } +} + +function getConnectionSuggestion(errorMessage: string): string | null { + const msg = errorMessage.toLowerCase(); + + if (msg.includes('login failed') || msg.includes('18456')) { + return 'Check your username and password. SQL Server authentication failed.'; + } + if (msg.includes('enotfound') || msg.includes('server was not found') || msg.includes('could not be located') || msg.includes('name or service not known')) { + return 'Check your server hostname. The SQL Server instance could not be found.'; + } + if (msg.includes('timeout') || msg.includes('-2') || msg.includes('timed out')) { + return 'Connection timed out. Check if the server is reachable and the port is correct.'; + } + if (msg.includes('ssl') || msg.includes('certificate') || msg.includes('tls') || msg.includes('ssl_ctx')) { + return 'SSL/Certificate error. Try setting SQLSERVER_TRUST_CERT=true if using a self-signed cert.'; + } + if (msg.includes('encrypt') || msg.includes('handshake') || msg.includes('ssl_negotiate')) { + return 'Encryption handshake failed. Try SQLSERVER_ENCRYPT=false or check certificate configuration.'; + } + if (msg.includes('port') || msg.includes('connection refused')) { + return 'Connection refused. Check your port number and ensure SQL Server is running.'; + } + if (msg.includes('named pipes')) { + return 'Named pipes error. Try using TCP/IP connection instead.'; + } + if (msg.includes('permission') || msg.includes('access') || msg.includes('denied')) { + return 'Permission denied. Your user may not have access to this database.'; + } + + return null; } function validateEnvironment(): boolean { const required = ['SQLSERVER_HOST', 'SQLSERVER_USER', 'SQLSERVER_PASSWORD']; const missing = required.filter(env => !process.env[env]); - + if (missing.length > 0) { console.error('āŒ Missing required environment variables:'); missing.forEach(env => { @@ -100,17 +236,22 @@ function validateEnvironment(): boolean { export function handleCliArgs(): boolean { const args = process.argv.slice(2); - + if (args.includes('--help') || args.includes('-h')) { showHelp(); return false; } - + if (args.includes('--version') || args.includes('-v')) { showVersion(); return false; } + if (args.includes('--test-connection') || args.includes('-t')) { + testConnection(); + return false; // testConnection exits on its own + } + if (!validateEnvironment()) { process.exit(1); } diff --git a/src/connection.ts b/src/connection.ts index 5f4f112..01e632f 100644 --- a/src/connection.ts +++ b/src/connection.ts @@ -68,6 +68,9 @@ export class SqlServerConnection { } getConfig(): Readonly { - return { ...this.config }; + // Return config with password redacted for security + const redactedConfig = { ...this.config }; + redactedConfig.password = '********'; + return redactedConfig; } } \ No newline at end of file diff --git a/src/index.ts b/src/index.ts index 2d4395c..418a0fc 100644 --- a/src/index.ts +++ b/src/index.ts @@ -147,7 +147,7 @@ async function runServer() { // Don't connect immediately in MCP mode - defer connection until first tool use // This prevents the server from failing startup if SQL Server is temporarily unavailable console.error(`MCP SQL Server initialized for ${config.server}:${config.port || 1433}`); - console.error(`Database: ${config.database || 'default'}, User: ${config.user}`); + console.error(`Database: ${config.database || 'default'}`); this.initializeTools(config.maxRows || 1000); } catch (error) { diff --git a/src/security.ts b/src/security.ts index 4747004..db25e6b 100644 --- a/src/security.ts +++ b/src/security.ts @@ -17,7 +17,6 @@ export class QueryValidator { 'TRUNCATE', 'EXEC', 'EXECUTE', - 'SP_', 'XP_', 'OPENROWSET', 'OPENDATASOURCE', @@ -26,42 +25,67 @@ export class QueryValidator { 'GRANT', 'REVOKE', 'DENY', + 'BACKUP', + 'RESTORE', + 'KILL', + 'DBCC', + 'SHUTDOWN', + 'RAISERROR', + 'WAITFOR', + // Catch stored procedure prefixes and dangerous patterns + 'SP_', + 'XP_', + // Prevent @ prefix variables being injected + '@', + // Prevent char/ASCII injection for password cracking + 'CHAR(', + 'ASCII(', ]; static validateQuery(query: string): { isValid: boolean; error?: string } { - const normalizedQuery = query.trim().toUpperCase(); + let normalizedQuery = query.trim().toUpperCase(); if (!normalizedQuery) { return { isValid: false, error: 'Empty query not allowed' }; } + // Strip SQL comments FIRST to prevent comment-based bypass + // e.g., SEL/**/ECT or SELECT--comment should be handled + // Replace comments with spaces to prevent word merging (SEL/**/ECT -> SEL ECT, not SELECT) + normalizedQuery = normalizedQuery + .replace(/--[\s\S]*?$/gm, ' ') // Replace -- comments with space + .replace(/\/\*[\s\S]*?\*\//g, ' ') // Replace /* */ comments with space + .replace(/\s+/g, ' ') // Normalize multiple spaces + .trim(); + + // Check for forbidden keywords BEFORE "starts with" check + // This ensures RESTORE, KILL, SHUTDOWN etc. are caught regardless of starting keyword + for (const forbidden of this.FORBIDDEN_KEYWORDS) { + if (normalizedQuery.includes(forbidden)) { + return { + isValid: false, + error: `Forbidden keyword detected: ${forbidden}` + }; + } + } + // Check if query starts with allowed statement - const startsWithAllowed = this.ALLOWED_STATEMENTS.some(stmt => + const startsWithAllowed = this.ALLOWED_STATEMENTS.some(stmt => normalizedQuery.startsWith(stmt) ); if (!startsWithAllowed) { - return { - isValid: false, - error: `Query must start with one of: ${this.ALLOWED_STATEMENTS.join(', ')}` + return { + isValid: false, + error: `Query must start with one of: ${this.ALLOWED_STATEMENTS.join(', ')}` }; } - // Check for forbidden keywords - for (const forbidden of this.FORBIDDEN_KEYWORDS) { - if (normalizedQuery.includes(forbidden)) { - return { - isValid: false, - error: `Forbidden keyword detected: ${forbidden}` - }; - } - } - - // Additional security checks + // Additional security checks for injection patterns if (this.containsSqlInjectionPatterns(normalizedQuery)) { - return { - isValid: false, - error: 'Potential SQL injection pattern detected' + return { + isValid: false, + error: 'Potential SQL injection pattern detected' }; } @@ -89,6 +113,28 @@ export class QueryValidator { `; } + static isValidCallbackUrl(url: string): boolean { + try { + const parsed = new URL(url); + // Only allow HTTPS and localhost/127.0.0.1 + if (parsed.protocol !== 'https:' && parsed.protocol !== 'http:') { + return false; + } + // Block external URLs - only allow localhost for security + // Strip brackets from IPv6 addresses (URL parser keeps them) + let hostname = parsed.hostname.toLowerCase(); + hostname = hostname.replace(/[\[\]]/g, ''); // Remove [ ] from IPv6 + const allowedHosts = ['localhost', '127.0.0.1', '::1', '::ffff:127.0.0.1']; + if (!allowedHosts.includes(hostname)) { + console.warn(`Blue Prompt callback blocked: external URL not allowed (${hostname})`); + return false; + } + return true; + } catch { + return false; + } + } + static async validateQueryWithBluePrompt(query: string): Promise<{ isValid: boolean; error?: string }> { // Step 1: Perform initial static validation const staticValidation = this.validateQuery(query); @@ -99,6 +145,11 @@ export class QueryValidator { // Step 2: Use the host AI platform for validation if a callback URL is provided const callbackUrl = process.env.BLUE_PROMPT_CALLBACK_URL; if (callbackUrl) { + // Validate URL before making callback to prevent SSRF + if (!this.isValidCallbackUrl(callbackUrl)) { + return { isValid: false, error: 'Invalid Blue Prompt callback URL - only localhost allowed for security' }; + } + try { const bluePrompt = this.generateBluePrompt(query); const response = await fetch(callbackUrl, { @@ -129,16 +180,31 @@ export class QueryValidator { } private static containsSqlInjectionPatterns(query: string): boolean { + // Remove SQL comments first (both -- and /* */) to prevent comment-based bypass + const queryWithoutComments = query + .replace(/--[\s\S]*?$/gm, '') // Remove -- comments (including end-of-line variants) + .replace(/\/\*[\s\S]*?\*\//g, ''); // Remove /* */ comments + const patterns = [ - /--/, // SQL comments - /\/\*/, // Multi-line comments - /;.*SELECT/, // Statement injection - /UNION.*SELECT/, // Union injection - /'\s*OR\s*'.*'/, // OR injection - /'\s*AND\s*'.*'/, // AND injection + // Statement injection - semicolon followed by any statement + /;\s*(SELECT|INSERT|UPDATE|DELETE|DROP|CREATE|ALTER|EXEC|EXECUTE)/i, + // Union-based injection + /UNION\s+(ALL\s+)?SELECT/i, + // OR/AND injection patterns with quote escaping + /'\s*(OR|AND)\s+['"\d]/i, + /['"\d]\s*(OR|AND)\s+['"\d]/i, + // Time-based blind injection + /WAITFOR\s+DELAY/i, + /BENCHMARK\s*\(/i, + /SLEEP\s*\(/i, + // Heavy hex/ascii operations for data exfiltration + /CHAR\s*\(\s*\d+\s*\)/i, + /ASCII\s*\(\s*SUBSTRING/i, + // Hex encoding attempts + /0x[0-9a-f]+/i, ]; - return patterns.some(pattern => pattern.test(query)); + return patterns.some(pattern => pattern.test(queryWithoutComments)); } static sanitizeQuery(query: string): string { diff --git a/src/test/security.test.ts b/src/test/security.test.ts new file mode 100644 index 0000000..f4630c3 --- /dev/null +++ b/src/test/security.test.ts @@ -0,0 +1,187 @@ +import { QueryValidator } from '../security.js'; + +describe('QueryValidator', () => { + describe('validateQuery - Allowed Statements', () => { + const allowedQueries = [ + 'SELECT * FROM users', + 'SELECT id, name FROM users WHERE id = 1', + 'WITH cte AS (SELECT 1) SELECT * FROM cte', + 'SHOW TABLES', + 'DESCRIBE users', + 'EXPLAIN SELECT * FROM users', + ]; + + test.each(allowedQueries)('should allow: %s', (query) => { + const result = QueryValidator.validateQuery(query); + expect(result.isValid).toBe(true); + }); + }); + + describe('validateQuery - Forbidden Keywords', () => { + const forbiddenQueries = [ + { query: 'INSERT INTO users VALUES (1)', keyword: 'INSERT' }, + { query: 'UPDATE users SET name = "test"', keyword: 'UPDATE' }, + { query: 'DELETE FROM users', keyword: 'DELETE' }, + { query: 'DROP TABLE users', keyword: 'DROP' }, + { query: 'CREATE TABLE users (id INT)', keyword: 'CREATE' }, + { query: 'ALTER TABLE users ADD col INT', keyword: 'ALTER' }, + { query: 'TRUNCATE TABLE users', keyword: 'TRUNCATE' }, + { query: 'EXEC sp_executesql', keyword: 'EXEC' }, + { query: 'EXECUTE sp_executesql', keyword: 'EXEC' }, // EXECUTE contains EXEC + { query: 'BACKUP DATABASE db TO DISK', keyword: 'BACKUP' }, + { query: 'RESTORE DATABASE db FROM DISK', keyword: 'RESTORE' }, + { query: 'KILL 1', keyword: 'KILL' }, + { query: 'SHUTDOWN', keyword: 'SHUTDOWN' }, + ]; + + test.each(forbiddenQueries)('should reject query with $keyword', ({ query, keyword }) => { + const result = QueryValidator.validateQuery(query); + expect(result.isValid).toBe(false); + expect(result.error).toContain(keyword); + }); + }); + + describe('validateQuery - Case Insensitive Bypass Prevention', () => { + const caseBypassQueries = [ + 'InSeRt INTO users VALUES (1)', + 'UpDaTe users SET name = "test"', + 'DeLeTe FROM users', + 'DrOp TABLE users', + ]; + + test.each(caseBypassQueries)('should reject case obfuscated: %s', (query) => { + const result = QueryValidator.validateQuery(query); + expect(result.isValid).toBe(false); + }); + }); + + describe('validateQuery - SQL Injection Patterns', () => { + const injectionQueries = [ + { query: "SELECT * FROM users WHERE id = 1; DROP TABLE users--", reason: 'Statement injection with DROP' }, + { query: "SELECT * FROM users WHERE id = 1 OR 1=1", reason: 'OR injection' }, + { query: "SELECT * FROM users WHERE id = 1' OR '1'='1", reason: 'OR injection with quotes' }, + { query: "SELECT * FROM users UNION SELECT * FROM passwords", reason: 'UNION injection' }, + { query: "SELECT CHAR(65)", reason: 'CHAR injection for exfiltration' }, + { query: "SELECT ASCII(SUBSTRING('password',1,1))", reason: 'ASCII blind injection' }, + { query: "SELECT 0x73756c656374", reason: 'Hex encoding attempt' }, + { query: "SELECT * FROM users; SELECT * FROM passwords", reason: 'Multiple statements' }, + ]; + + test.each(injectionQueries)('should reject: $reason', ({ query }) => { + const result = QueryValidator.validateQuery(query); + expect(result.isValid).toBe(false); + }); + }); + + describe('validateQuery - Comment-Based Bypass Prevention', () => { + it('should reject inline comment obfuscation', () => { + // SEL/**/ECT should be rejected - comments stripped leaves invalid query + const result = QueryValidator.validateQuery('SEL/**/ECT * FROM users'); + expect(result.isValid).toBe(false); + }); + + it('should reject comment between forbidden keyword', () => { + // DR/**/OP should be rejected - comments stripped leaves DROP + const result = QueryValidator.validateQuery('DR/**/OP TABLE users'); + expect(result.isValid).toBe(false); + }); + + it('should allow valid query with trailing comment', () => { + // Valid SELECT query with trailing comment should pass (comment stripped) + const result = QueryValidator.validateQuery('SELECT * FROM users WHERE id = 1--this is a comment'); + expect(result.isValid).toBe(true); + }); + }); + + describe('validateQuery - Variable Injection Prevention', () => { + const variableQueries = [ + { query: 'SELECT @variable FROM users', reason: 'T-SQL variable' }, + { query: 'EXEC sp_executesql @sql', reason: 'Variable in EXEC' }, + ]; + + test.each(variableQueries)('should reject: $reason', ({ query }) => { + const result = QueryValidator.validateQuery(query); + expect(result.isValid).toBe(false); + }); + }); + + describe('validateQuery - Empty Query', () => { + it('should reject empty query', () => { + expect(QueryValidator.validateQuery('').isValid).toBe(false); + }); + + it('should reject whitespace-only query', () => { + expect(QueryValidator.validateQuery(' ').isValid).toBe(false); + }); + }); + + describe('sanitizeQuery', () => { + it('should trim whitespace', () => { + expect(QueryValidator.sanitizeQuery(' SELECT * FROM users ')).toBe('SELECT * FROM users'); + }); + + it('should normalize multiple spaces', () => { + expect(QueryValidator.sanitizeQuery('SELECT * FROM users')).toBe('SELECT * FROM users'); + }); + + it('should remove trailing semicolon', () => { + expect(QueryValidator.sanitizeQuery('SELECT * FROM users;')).toBe('SELECT * FROM users'); + }); + }); + + describe('addRowLimit', () => { + it('should add TOP clause to SELECT', () => { + const result = QueryValidator.addRowLimit('SELECT * FROM users', 100); + expect(result).toBe('SELECT TOP 100 * FROM users'); + }); + + it('should not modify if TOP already exists', () => { + const result = QueryValidator.addRowLimit('SELECT TOP 50 * FROM users', 100); + expect(result).toBe('SELECT TOP 50 * FROM users'); + }); + + it('should handle SELECT with leading whitespace', () => { + const result = QueryValidator.addRowLimit(' SELECT * FROM users', 100); + expect(result).toBe(' SELECT TOP 100 * FROM users'); + }); + }); + + describe('isValidCallbackUrl - SSRF Prevention', () => { + it('should allow localhost HTTPS', () => { + expect(QueryValidator.isValidCallbackUrl('https://localhost:8080/callback')).toBe(true); + }); + + it('should allow 127.0.0.1 HTTP', () => { + expect(QueryValidator.isValidCallbackUrl('http://127.0.0.1:8080/callback')).toBe(true); + }); + + it('should allow ::1 (IPv6 localhost)', () => { + expect(QueryValidator.isValidCallbackUrl('http://[::1]:8080/callback')).toBe(true); + }); + + it('should block external HTTPS URLs', () => { + expect(QueryValidator.isValidCallbackUrl('https://api.example.com/callback')).toBe(false); + }); + + it('should block external HTTP URLs', () => { + expect(QueryValidator.isValidCallbackUrl('http://api.example.com/callback')).toBe(false); + }); + + it('should block internal IP addresses other than localhost', () => { + expect(QueryValidator.isValidCallbackUrl('http://192.168.1.1/callback')).toBe(false); + expect(QueryValidator.isValidCallbackUrl('http://10.0.0.1/callback')).toBe(false); + }); + + it('should block file:// protocol', () => { + expect(QueryValidator.isValidCallbackUrl('file:///etc/passwd')).toBe(false); + }); + + it('should block data:// protocol', () => { + expect(QueryValidator.isValidCallbackUrl('data:text/html,')).toBe(false); + }); + + it('should return false for invalid URLs', () => { + expect(QueryValidator.isValidCallbackUrl('not-a-url')).toBe(false); + }); + }); +}); diff --git a/src/tools/get-foreign-keys.ts b/src/tools/get-foreign-keys.ts index 374048b..1918759 100644 --- a/src/tools/get-foreign-keys.ts +++ b/src/tools/get-foreign-keys.ts @@ -1,5 +1,6 @@ import { BaseTool } from './base.js'; import { ForeignKeyInfo } from '../types.js'; +import { ParameterValidator } from '../validation.js'; export class GetForeignKeysTool extends BaseTool { getName(): string { @@ -29,10 +30,10 @@ export class GetForeignKeysTool extends BaseTool { } async execute(params: { table_name?: string; schema?: string }): Promise { - const { table_name, schema = 'dbo' } = params; + const { table_name, schema } = ParameterValidator.validateForeignKeyParameters(params); let query = ` - SELECT + SELECT fk.name as constraint_name, OBJECT_SCHEMA_NAME(fk.parent_object_id) as table_schema, OBJECT_NAME(fk.parent_object_id) as table_name, @@ -41,18 +42,23 @@ export class GetForeignKeysTool extends BaseTool { OBJECT_NAME(fk.referenced_object_id) as referenced_table_name, COL_NAME(fkc.referenced_object_id, fkc.referenced_column_id) as referenced_column_name FROM sys.foreign_keys fk - INNER JOIN sys.foreign_key_columns fkc + INNER JOIN sys.foreign_key_columns fkc ON fk.object_id = fkc.constraint_object_id `; const conditions = []; - + if (table_name) { - conditions.push(`OBJECT_NAME(fk.parent_object_id) = '${table_name.replace(/'/g, "''")}'`); + const escapedTableName = ParameterValidator.escapeIdentifier(table_name); + conditions.push(`OBJECT_NAME(fk.parent_object_id) = ${escapedTableName}`); } - + if (schema && table_name) { - conditions.push(`OBJECT_SCHEMA_NAME(fk.parent_object_id) = '${schema.replace(/'/g, "''")}'`); + const escapedSchema = ParameterValidator.escapeIdentifier(schema); + conditions.push(`OBJECT_SCHEMA_NAME(fk.parent_object_id) = ${escapedSchema}`); + } else if (schema) { + const escapedSchema = ParameterValidator.escapeIdentifier(schema); + conditions.push(`OBJECT_SCHEMA_NAME(fk.parent_object_id) = ${escapedSchema}`); } if (conditions.length > 0) { diff --git a/src/tools/get-table-stats.ts b/src/tools/get-table-stats.ts index e0d7ca0..2b4ec6e 100644 --- a/src/tools/get-table-stats.ts +++ b/src/tools/get-table-stats.ts @@ -1,5 +1,6 @@ import { BaseTool } from './base.js'; import { TableStats } from '../types.js'; +import { ParameterValidator } from '../validation.js'; export class GetTableStatsTool extends BaseTool { getName(): string { @@ -29,10 +30,10 @@ export class GetTableStatsTool extends BaseTool { } async execute(params: { table_name?: string; schema?: string }): Promise { - const { table_name, schema = 'dbo' } = params; + const { table_name, schema } = ParameterValidator.validateTableStatsParameters(params); let query = ` - SELECT + SELECT s.name as table_schema, t.name as table_name, p.rows as row_count, @@ -50,13 +51,15 @@ export class GetTableStatsTool extends BaseTool { `; const conditions = []; - + if (table_name) { - conditions.push(`t.name = '${table_name.replace(/'/g, "''")}'`); + const escapedTableName = ParameterValidator.escapeIdentifier(table_name); + conditions.push(`t.name = ${escapedTableName}`); } - - if (schema && table_name) { - conditions.push(`s.name = '${schema.replace(/'/g, "''")}'`); + + if (schema) { + const escapedSchema = ParameterValidator.escapeIdentifier(schema); + conditions.push(`s.name = ${escapedSchema}`); } if (conditions.length > 0) { diff --git a/src/tools/list-views.ts b/src/tools/list-views.ts index 6979887..3813dce 100644 --- a/src/tools/list-views.ts +++ b/src/tools/list-views.ts @@ -1,5 +1,6 @@ import { BaseTool } from './base.js'; import { ViewInfo } from '../types.js'; +import { ParameterValidator } from '../validation.js'; export class ListViewsTool extends BaseTool { getName(): string { @@ -24,10 +25,10 @@ export class ListViewsTool extends BaseTool { } async execute(params: { schema?: string }): Promise { - const { schema } = params; + const { schema } = ParameterValidator.validateListTablesParameters(params); let query = ` - SELECT + SELECT TABLE_CATALOG as table_catalog, TABLE_SCHEMA as table_schema, TABLE_NAME as table_name, @@ -38,7 +39,8 @@ export class ListViewsTool extends BaseTool { `; if (schema) { - query += ` WHERE TABLE_SCHEMA = '${schema.replace(/'/g, "''")}'`; + const escapedSchema = ParameterValidator.escapeIdentifier(schema); + query += ` WHERE TABLE_SCHEMA = ${escapedSchema}`; } query += ' ORDER BY TABLE_SCHEMA, TABLE_NAME'; diff --git a/src/types.ts b/src/types.ts index d571e1c..5419215 100644 --- a/src/types.ts +++ b/src/types.ts @@ -7,7 +7,7 @@ export const ConnectionConfigSchema = z.object({ password: z.string(), port: z.number().optional().default(1433), encrypt: z.boolean().optional().default(true), - trustServerCertificate: z.boolean().optional().default(true), + trustServerCertificate: z.boolean().optional().default(false), connectionTimeout: z.number().optional().default(30000), requestTimeout: z.number().optional().default(60000), maxRows: z.number().optional().default(1000), diff --git a/src/validation.ts b/src/validation.ts index 5b116c9..62ec0bb 100644 --- a/src/validation.ts +++ b/src/validation.ts @@ -149,6 +149,26 @@ export class ParameterValidator { return result; } + // Validate table stats parameters + static validateTableStatsParameters(params: { table_name?: string; schema?: string }): { + table_name?: string; + schema?: string; + } { + const result: { table_name?: string; schema?: string } = {}; + + if (params.schema) { + result.schema = this.validateSchemaName(params.schema); + } else { + result.schema = 'dbo'; // Default schema + } + + if (params.table_name) { + result.table_name = this.validateTableName(params.table_name); + } + + return result; + } + // Validate list tables parameters static validateListTablesParameters(params: { schema?: string }): { schema?: string;