diff --git a/.github/workflows/build-debug-apk.yml b/.github/workflows/build-debug-apk.yml new file mode 100644 index 00000000..e45db48f --- /dev/null +++ b/.github/workflows/build-debug-apk.yml @@ -0,0 +1,35 @@ +name: Build Debug APK + +on: + pull_request: + branches: [ main, master ] + workflow_dispatch: + +jobs: + build: + name: Build Debug APK + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up JDK 17 + uses: actions/setup-java@v4 + with: + java-version: '17' + distribution: 'temurin' + cache: gradle + + - name: Grant execute permission for gradlew + run: chmod +x gradlew + + - name: Build Debug APK + run: ./gradlew assembleDebug --no-daemon + + - name: Upload Debug APK + uses: actions/upload-artifact@v4 + with: + name: app-debug + path: app/build/outputs/apk/debug/app-debug.apk + retention-days: 14 diff --git a/app/build.gradle.kts b/app/build.gradle.kts index a0366e91..f3b0a1b8 100644 --- a/app/build.gradle.kts +++ b/app/build.gradle.kts @@ -143,6 +143,10 @@ dependencies { // Debug debugImplementation(libs.androidx.compose.ui.tooling) + + // Tests + testImplementation(libs.junit) + testImplementation(libs.org.json) } fun getProperty(value: String): String { @@ -154,4 +158,4 @@ fun getProperty(value: String): String { } else { System.getenv(value) ?: "\"sample_val\"" } -} \ No newline at end of file +} diff --git a/app/src/main/java/com/dark/tool_neuron/activity/MainActivity.kt b/app/src/main/java/com/dark/tool_neuron/activity/MainActivity.kt index bcb6a4aa..58045d57 100644 --- a/app/src/main/java/com/dark/tool_neuron/activity/MainActivity.kt +++ b/app/src/main/java/com/dark/tool_neuron/activity/MainActivity.kt @@ -17,6 +17,7 @@ import androidx.compose.runtime.getValue import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.remember import androidx.compose.runtime.rememberCoroutineScope +import androidx.compose.runtime.saveable.rememberSaveable import androidx.compose.runtime.setValue import androidx.hilt.lifecycle.viewmodel.compose.hiltViewModel import androidx.navigation.compose.NavHost @@ -25,7 +26,9 @@ import androidx.navigation.compose.rememberNavController import com.dark.tool_neuron.data.TermsDataStore import com.dark.tool_neuron.di.AppContainer import com.dark.tool_neuron.engine.EmbeddingEngine +import com.dark.tool_neuron.ui.screen.IntroScreen import com.dark.tool_neuron.ui.screen.EmbeddingSetupScreen +import com.dark.tool_neuron.ui.screen.McpServersScreen import com.dark.tool_neuron.ui.screen.ModelConfigEditorScreen import com.dark.tool_neuron.ui.screen.ModelStoreScreen import com.dark.tool_neuron.ui.screen.TermsAndConditionsScreen @@ -77,6 +80,8 @@ class MainActivity : ComponentActivity() { val hasAcceptedTerms by termsDataStore.hasAcceptedTerms.collectAsState(initial = true) val scope = rememberCoroutineScope() + var showIntro by rememberSaveable { mutableStateOf(true) } + // Start background download if model not present LaunchedEffect(Unit) { withContext(Dispatchers.IO) { @@ -86,22 +91,26 @@ class MainActivity : ComponentActivity() { } } - if (!hasAcceptedTerms) { - TermsAndConditionsScreen( - onAccept = { - scope.launch { - termsDataStore.acceptTerms() - } - } - ) + if (showIntro) { + IntroScreen(onFinished = { showIntro = false }) } else { - val chatViewModel: ChatViewModel = hiltViewModel() - val llmModelViewModel: LLMModelViewModel = hiltViewModel() - - AppNavigation( - chatViewModel = chatViewModel, - llmModelViewModel = llmModelViewModel - ) + if (!hasAcceptedTerms) { + TermsAndConditionsScreen( + onAccept = { + scope.launch { + termsDataStore.acceptTerms() + } + } + ) + } else { + val chatViewModel: ChatViewModel = hiltViewModel() + val llmModelViewModel: LLMModelViewModel = hiltViewModel() + + AppNavigation( + chatViewModel = chatViewModel, + llmModelViewModel = llmModelViewModel + ) + } } } } @@ -121,6 +130,7 @@ sealed class Screen(val route: String) { object Store : Screen("store") object Editor : Screen("editor") object VaultManager: Screen("vault_manager") + object McpServers: Screen("mcp_servers") } @Composable @@ -176,6 +186,9 @@ fun AppNavigation( onVaultManagerClick = { navController.navigate(Screen.VaultManager.route) }, + onMcpServersClick = { + navController.navigate(Screen.McpServers.route) + }, chatViewModel = chatViewModel, llmModelViewModel = llmModelViewModel ) @@ -196,5 +209,11 @@ fun AppNavigation( composable(Screen.VaultManager.route) { VaultDashboard() } + + composable(Screen.McpServers.route) { + McpServersScreen(onBackClick = { + navController.popBackStack() + }) + } } } \ No newline at end of file diff --git a/app/src/main/java/com/dark/tool_neuron/database/AppDatabase.kt b/app/src/main/java/com/dark/tool_neuron/database/AppDatabase.kt index 8ac7b2d0..5490feb6 100644 --- a/app/src/main/java/com/dark/tool_neuron/database/AppDatabase.kt +++ b/app/src/main/java/com/dark/tool_neuron/database/AppDatabase.kt @@ -7,17 +7,19 @@ import androidx.room.RoomDatabase import androidx.room.TypeConverters import androidx.room.migration.Migration import androidx.sqlite.db.SupportSQLiteDatabase +import com.dark.tool_neuron.database.dao.McpServerDao import com.dark.tool_neuron.database.dao.ModelConfigDao import com.dark.tool_neuron.database.dao.ModelDao import com.dark.tool_neuron.database.dao.RagDao import com.dark.tool_neuron.models.converters.Converters import com.dark.tool_neuron.models.table_schema.InstalledRag +import com.dark.tool_neuron.models.table_schema.McpServer import com.dark.tool_neuron.models.table_schema.Model import com.dark.tool_neuron.models.table_schema.ModelConfig @Database( - entities = [Model::class, ModelConfig::class, InstalledRag::class], - version = 4, + entities = [Model::class, ModelConfig::class, InstalledRag::class, McpServer::class], + version = 5, exportSchema = false ) @TypeConverters(Converters::class) @@ -25,6 +27,7 @@ abstract class AppDatabase : RoomDatabase() { abstract fun modelDao(): ModelDao abstract fun modelConfigDao(): ModelConfigDao abstract fun ragDao(): RagDao + abstract fun mcpServerDao(): McpServerDao companion object { @Volatile @@ -114,6 +117,28 @@ abstract class AppDatabase : RoomDatabase() { } } + private val MIGRATION_4_5 = object : Migration(4, 5) { + override fun migrate(db: SupportSQLiteDatabase) { + // Create mcp_servers table + db.execSQL(""" + CREATE TABLE IF NOT EXISTS mcp_servers ( + id TEXT PRIMARY KEY NOT NULL, + name TEXT NOT NULL, + url TEXT NOT NULL, + transportType TEXT NOT NULL, + apiKey TEXT, + isEnabled INTEGER NOT NULL, + lastError TEXT, + createdAt INTEGER NOT NULL, + updatedAt INTEGER NOT NULL, + lastConnectedAt INTEGER, + description TEXT NOT NULL, + customHeadersJson TEXT + ) + """.trimIndent()) + } + } + fun getDatabase(context: Context): AppDatabase { return INSTANCE ?: synchronized(this) { val instance = Room.databaseBuilder( @@ -121,7 +146,7 @@ abstract class AppDatabase : RoomDatabase() { AppDatabase::class.java, "llm_models_database" ) - .addMigrations(MIGRATION_1_2, MIGRATION_2_3, MIGRATION_3_4) + .addMigrations(MIGRATION_1_2, MIGRATION_2_3, MIGRATION_3_4, MIGRATION_4_5) .fallbackToDestructiveMigration() .build() INSTANCE = instance diff --git a/app/src/main/java/com/dark/tool_neuron/database/dao/McpServerDao.kt b/app/src/main/java/com/dark/tool_neuron/database/dao/McpServerDao.kt new file mode 100644 index 00000000..5525881d --- /dev/null +++ b/app/src/main/java/com/dark/tool_neuron/database/dao/McpServerDao.kt @@ -0,0 +1,42 @@ +package com.dark.tool_neuron.database.dao + +import androidx.room.* +import com.dark.tool_neuron.models.table_schema.McpServer +import kotlinx.coroutines.flow.Flow + +@Dao +interface McpServerDao { + + @Query("SELECT * FROM mcp_servers ORDER BY name ASC") + fun getAllServers(): Flow> + + @Query("SELECT * FROM mcp_servers WHERE isEnabled = 1 ORDER BY name ASC") + fun getEnabledServers(): Flow> + + @Query("SELECT * FROM mcp_servers WHERE id = :id") + suspend fun getServerById(id: String): McpServer? + + @Insert(onConflict = OnConflictStrategy.REPLACE) + suspend fun insertServer(server: McpServer) + + @Update + suspend fun updateServer(server: McpServer) + + @Delete + suspend fun deleteServer(server: McpServer) + + @Query("DELETE FROM mcp_servers WHERE id = :id") + suspend fun deleteServerById(id: String) + + @Query("UPDATE mcp_servers SET isEnabled = :isEnabled, updatedAt = :updatedAt WHERE id = :id") + suspend fun updateServerEnabled(id: String, isEnabled: Boolean, updatedAt: Long) + + @Query("UPDATE mcp_servers SET lastConnectedAt = :timestamp, updatedAt = :updatedAt WHERE id = :id") + suspend fun updateLastConnected(id: String, timestamp: Long, updatedAt: Long) + + @Query("SELECT COUNT(*) FROM mcp_servers") + fun getServerCount(): Flow + + @Query("SELECT COUNT(*) FROM mcp_servers WHERE isEnabled = 1") + fun getEnabledServerCount(): Flow +} diff --git a/app/src/main/java/com/dark/tool_neuron/di/AppContainer.kt b/app/src/main/java/com/dark/tool_neuron/di/AppContainer.kt index 687be60e..14e24fbf 100644 --- a/app/src/main/java/com/dark/tool_neuron/di/AppContainer.kt +++ b/app/src/main/java/com/dark/tool_neuron/di/AppContainer.kt @@ -4,7 +4,9 @@ import android.app.Application import android.content.Context import com.dark.tool_neuron.database.AppDatabase import com.dark.tool_neuron.repo.ChatRepository +import com.dark.tool_neuron.repo.McpServerRepository import com.dark.tool_neuron.repo.ModelRepository +import com.dark.tool_neuron.service.McpClientService import com.dark.tool_neuron.vault.VaultHelper import com.dark.tool_neuron.viewmodel.factory.ChatListViewModelFactory import com.dark.tool_neuron.viewmodel.factory.ChatViewModelFactory @@ -21,6 +23,8 @@ object AppContainer { private lateinit var database: AppDatabase private lateinit var modelRepository: ModelRepository private lateinit var chatRepository: ChatRepository + private lateinit var mcpServerRepository: McpServerRepository + private lateinit var mcpClientService: McpClientService private lateinit var llmModelViewModelFactory: LLMModelViewModelFactory private lateinit var chatListViewModelFactory: ChatListViewModelFactory private lateinit var chatViewModelFactory: ChatViewModelFactory @@ -38,10 +42,17 @@ object AppContainer { ) chatRepository = ChatRepository() + mcpServerRepository = McpServerRepository(database.mcpServerDao()) + mcpClientService = McpClientService() llmModelViewModelFactory = LLMModelViewModelFactory(application, modelRepository) chatListViewModelFactory = ChatListViewModelFactory(chatManager) - chatViewModelFactory = ChatViewModelFactory(chatManager, generationManager) + chatViewModelFactory = ChatViewModelFactory( + chatManager, + generationManager, + mcpServerRepository, + mcpClientService + ) initVault(context) } @@ -84,4 +95,4 @@ object AppContainer { fun isVaultReady(): Boolean = vaultInitialized -} \ No newline at end of file +} diff --git a/app/src/main/java/com/dark/tool_neuron/di/HiltModules.kt b/app/src/main/java/com/dark/tool_neuron/di/HiltModules.kt index e8765672..8be54ff1 100644 --- a/app/src/main/java/com/dark/tool_neuron/di/HiltModules.kt +++ b/app/src/main/java/com/dark/tool_neuron/di/HiltModules.kt @@ -4,8 +4,10 @@ import com.dark.tool_neuron.database.AppDatabase import com.dark.tool_neuron.engine.EmbeddingEngine import com.dark.tool_neuron.repo.ChatRepository + import com.dark.tool_neuron.repo.McpServerRepository import com.dark.tool_neuron.repo.ModelRepository import com.dark.tool_neuron.repo.RagRepository + import com.dark.tool_neuron.service.McpClientService import com.dark.tool_neuron.worker.ChatManager import com.dark.tool_neuron.worker.GenerationManager import com.dark.tool_neuron.worker.RagVaultIntegration @@ -65,6 +67,14 @@ context = context ) } + + @Provides + @Singleton + fun provideMcpServerRepository(database: AppDatabase): McpServerRepository { + return McpServerRepository( + mcpServerDao = database.mcpServerDao() + ) + } } @Module @@ -78,6 +88,17 @@ } } + @Module + @InstallIn(SingletonComponent::class) + object ServiceModule { + + @Provides + @Singleton + fun provideMcpClientService(): McpClientService { + return McpClientService() + } + } + @Module @InstallIn(SingletonComponent::class) object WorkerModule { diff --git a/app/src/main/java/com/dark/tool_neuron/models/converters/Converters.kt b/app/src/main/java/com/dark/tool_neuron/models/converters/Converters.kt index d801f7f3..33dd1528 100644 --- a/app/src/main/java/com/dark/tool_neuron/models/converters/Converters.kt +++ b/app/src/main/java/com/dark/tool_neuron/models/converters/Converters.kt @@ -3,6 +3,7 @@ package com.dark.tool_neuron.models.converters import androidx.room.TypeConverter import com.dark.tool_neuron.models.enums.PathType import com.dark.tool_neuron.models.enums.ProviderType +import com.dark.tool_neuron.models.table_schema.McpTransportType class Converters { @TypeConverter @@ -16,4 +17,10 @@ class Converters { @TypeConverter fun toPathType(value: String): PathType = PathType.valueOf(value) + + @TypeConverter + fun fromMcpTransportType(value: McpTransportType): String = value.name + + @TypeConverter + fun toMcpTransportType(value: String): McpTransportType = McpTransportType.valueOf(value) } \ No newline at end of file diff --git a/app/src/main/java/com/dark/tool_neuron/models/table_schema/McpServer.kt b/app/src/main/java/com/dark/tool_neuron/models/table_schema/McpServer.kt new file mode 100644 index 00000000..a8087cd4 --- /dev/null +++ b/app/src/main/java/com/dark/tool_neuron/models/table_schema/McpServer.kt @@ -0,0 +1,69 @@ +package com.dark.tool_neuron.models.table_schema + +import androidx.room.Entity +import androidx.room.PrimaryKey + +/** + * Transport type for MCP server connections + */ +enum class McpTransportType { + SSE, // Server-Sent Events (HTTP) + STREAMABLE_HTTP // Streamable HTTP transport +} + +/** + * Connection status of an MCP server (runtime only, not persisted) + */ +enum class McpConnectionStatus { + DISCONNECTED, + CONNECTING, + CONNECTED, + ERROR +} + +/** + * Entity representing a remote MCP (Model Context Protocol) server configuration. + * MCP servers provide tools, resources, and prompts to LLM applications. + */ +@Entity(tableName = "mcp_servers") +data class McpServer( + @PrimaryKey + val id: String, + + /** Display name for the server */ + val name: String, + + /** Server URL (e.g., "https://api.example.com/mcp") */ + val url: String, + + /** Transport type for the connection */ + val transportType: McpTransportType = McpTransportType.SSE, + + /** Optional API key for authentication */ + val apiKey: String? = null, + + /** Whether the server is enabled */ + val isEnabled: Boolean = true, + + /** Last error message if connection failed */ + val lastError: String? = null, + + /** Timestamp when the server was added */ + val createdAt: Long = System.currentTimeMillis(), + + /** Timestamp when the server was last modified */ + val updatedAt: Long = System.currentTimeMillis(), + + /** Timestamp when last successfully connected */ + val lastConnectedAt: Long? = null, + + /** Optional description */ + val description: String = "", + + /** Custom headers as JSON string (e.g., for additional auth) */ + val customHeadersJson: String? = null +) { + companion object { + fun generateId(): String = java.util.UUID.randomUUID().toString() + } +} diff --git a/app/src/main/java/com/dark/tool_neuron/repo/McpServerRepository.kt b/app/src/main/java/com/dark/tool_neuron/repo/McpServerRepository.kt new file mode 100644 index 00000000..bb8c5726 --- /dev/null +++ b/app/src/main/java/com/dark/tool_neuron/repo/McpServerRepository.kt @@ -0,0 +1,153 @@ +package com.dark.tool_neuron.repo + +import android.util.Log +import com.dark.tool_neuron.database.dao.McpServerDao +import com.dark.tool_neuron.models.table_schema.McpConnectionStatus +import com.dark.tool_neuron.models.table_schema.McpServer +import com.dark.tool_neuron.models.table_schema.McpTransportType +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.asStateFlow +import java.net.URI +import javax.inject.Inject +import javax.inject.Singleton + +/** + * Repository for managing MCP (Model Context Protocol) server configurations + */ +@Singleton +class McpServerRepository @Inject constructor( + private val mcpServerDao: McpServerDao +) { + companion object { + private const val TAG = "McpServerRepository" + } + // Runtime connection status tracking (not persisted) + private val _connectionStatuses = MutableStateFlow>(emptyMap()) + val connectionStatuses: StateFlow> = _connectionStatuses.asStateFlow() + + /** + * Get all configured MCP servers + */ + fun getAllServers(): Flow> = mcpServerDao.getAllServers() + + /** + * Get only enabled MCP servers + */ + fun getEnabledServers(): Flow> = mcpServerDao.getEnabledServers() + + /** + * Get a specific server by ID + */ + suspend fun getServerById(id: String): McpServer? = mcpServerDao.getServerById(id) + + /** + * Add a new MCP server + * @throws IllegalArgumentException if the URL is not valid + */ + suspend fun addServer( + name: String, + url: String, + transportType: McpTransportType = McpTransportType.SSE, + apiKey: String? = null, + description: String = "" + ): McpServer { + val trimmedUrl = url.trim() + + // Validate URL format + val validatedUrl = try { + val uri = URI(trimmedUrl) + if (uri.scheme.isNullOrBlank() || uri.host.isNullOrBlank()) { + throw IllegalArgumentException("Invalid server URL: missing scheme or host") + } + if (uri.scheme != "http" && uri.scheme != "https") { + throw IllegalArgumentException("Invalid server URL scheme: ${uri.scheme}") + } + trimmedUrl + } catch (e: IllegalArgumentException) { + throw e + } catch (e: Exception) { + throw IllegalArgumentException("Invalid server URL format: '$trimmedUrl'", e) + } + + val server = McpServer( + id = McpServer.generateId(), + name = name, + url = validatedUrl, + transportType = transportType, + apiKey = apiKey?.trim()?.takeIf { it.isNotEmpty() }, + description = description.trim(), + isEnabled = true, + createdAt = System.currentTimeMillis(), + updatedAt = System.currentTimeMillis() + ) + mcpServerDao.insertServer(server) + return server + } + + /** + * Update an existing MCP server + */ + suspend fun updateServer(server: McpServer) { + mcpServerDao.updateServer(server.copy(updatedAt = System.currentTimeMillis())) + } + + /** + * Delete an MCP server + */ + suspend fun deleteServer(id: String) { + mcpServerDao.deleteServerById(id) + // Remove from runtime status tracking + _connectionStatuses.value = _connectionStatuses.value - id + } + + /** + * Toggle server enabled/disabled state + */ + suspend fun setServerEnabled(id: String, enabled: Boolean) { + mcpServerDao.updateServerEnabled(id, enabled, System.currentTimeMillis()) + if (!enabled) { + // When disabled, set status to disconnected + updateConnectionStatus(id, McpConnectionStatus.DISCONNECTED) + } + } + + /** + * Update the runtime connection status of a server + * @param serverId The ID of the server + * @param status The new connection status + * @param error Optional error message when status is ERROR + */ + fun updateConnectionStatus(serverId: String, status: McpConnectionStatus, error: String? = null) { + if (error != null && status == McpConnectionStatus.ERROR) { + Log.w(TAG, "MCP server $serverId connection error: $error") + } + _connectionStatuses.value = _connectionStatuses.value + (serverId to status) + } + + /** + * Update last connected timestamp + */ + suspend fun updateLastConnected(id: String) { + val now = System.currentTimeMillis() + mcpServerDao.updateLastConnected(id, now, now) + } + + /** + * Get the count of all servers + */ + fun getServerCount(): Flow = mcpServerDao.getServerCount() + + /** + * Get the count of enabled servers + */ + fun getEnabledServerCount(): Flow = mcpServerDao.getEnabledServerCount() + + /** + * Get the current connection status for a server + */ + fun getConnectionStatus(serverId: String): McpConnectionStatus { + return _connectionStatuses.value[serverId] ?: McpConnectionStatus.DISCONNECTED + } +} diff --git a/app/src/main/java/com/dark/tool_neuron/service/McpClientService.kt b/app/src/main/java/com/dark/tool_neuron/service/McpClientService.kt new file mode 100644 index 00000000..6f319d5a --- /dev/null +++ b/app/src/main/java/com/dark/tool_neuron/service/McpClientService.kt @@ -0,0 +1,373 @@ +package com.dark.tool_neuron.service + +import android.util.Log +import com.dark.tool_neuron.models.table_schema.McpServer +import com.dark.tool_neuron.models.table_schema.McpTransportType +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.withContext +import okhttp3.MediaType.Companion.toMediaType +import okhttp3.OkHttpClient +import okhttp3.Request +import okhttp3.RequestBody.Companion.toRequestBody +import org.json.JSONObject +import java.util.concurrent.TimeUnit +import javax.inject.Inject +import javax.inject.Singleton + +/** + * MCP Client response data + */ +data class McpToolInfo( + val name: String, + val description: String?, + val inputSchema: String? +) + +data class McpTestResult( + val success: Boolean, + val message: String, + val tools: List = emptyList(), + val serverInfo: String? = null +) + +/** + * Client service for connecting to remote MCP (Model Context Protocol) servers. + * Supports both SSE (Server-Sent Events) and Streamable HTTP transport types. + * + * Transport Types: + * - SSE: Uses text/event-stream for responses (commonly used by servers like Zapier MCP) + * - Streamable HTTP: Uses standard JSON responses + */ +@Singleton +class McpClientService @Inject constructor() { + + companion object { + private const val TAG = "McpClientService" + private const val CONNECT_TIMEOUT_SECONDS = 15L + private const val READ_TIMEOUT_SECONDS = 30L + private const val MCP_PROTOCOL_VERSION = "2024-11-05" + private const val CLIENT_NAME = "ToolNeuron" + private const val CLIENT_VERSION = "1.0.0" + private val JSON_MEDIA_TYPE = "application/json".toMediaType() + // Accept headers for different transport types + private const val ACCEPT_HEADER_SSE = "application/json, text/event-stream" + private const val ACCEPT_HEADER_HTTP = "application/json" + } + + private val httpClient = OkHttpClient.Builder() + .connectTimeout(CONNECT_TIMEOUT_SECONDS, TimeUnit.SECONDS) + .readTimeout(READ_TIMEOUT_SECONDS, TimeUnit.SECONDS) + .build() + + /** + * Clean up resources associated with the underlying OkHttpClient. + * This should be called when the McpClientService is no longer needed. + */ + fun close() { + try { + // Shut down the executor service used by the dispatcher + httpClient.dispatcher.executorService.shutdown() + // Evict all connections from the connection pool + httpClient.connectionPool.evictAll() + // Close any configured cache + httpClient.cache?.close() + } catch (e: Exception) { + Log.w(TAG, "Error while closing OkHttpClient resources", e) + } + } + + /** + * Get the appropriate Accept header based on transport type + */ + private fun getAcceptHeader(transportType: McpTransportType): String { + return when (transportType) { + McpTransportType.SSE -> ACCEPT_HEADER_SSE + McpTransportType.STREAMABLE_HTTP -> ACCEPT_HEADER_HTTP + } + } + + /** + * Parse response body, handling SSE format automatically. + * Some MCP servers return SSE-formatted responses regardless of the declared transport type, + * so we detect and parse SSE format for both transport types. + */ + private fun parseResponse(responseBody: String, transportType: McpTransportType): String { + // Always try to parse SSE format first, as some servers return SSE regardless of transport type + // The parseSseResponse function will return the original body if it's not SSE format + return parseSseResponse(responseBody) + } + + /** + * Parse SSE (Server-Sent Events) response format. + * SSE responses come as "event: message\ndata: {...json...}\n\n" + * This handles single-event responses commonly used in MCP request/response patterns. + * + * Note: For streaming scenarios, this parser extracts the last complete event. + * In MCP's request/response pattern, this is typically the only event. + */ + private fun parseSseResponse(responseBody: String): String { + // Check if this is an SSE response + if (!responseBody.contains("data:")) { + // Not SSE format, return as-is + return responseBody + } + + // Split by double newlines to separate events + val events = responseBody.split("\n\n") + + // Find the last event with data (for request/response pattern) + for (event in events.reversed()) { + val lines = event.lines() + val dataLines = lines.filter { it.startsWith("data:") } + + if (dataLines.isNotEmpty()) { + // Extract JSON from "data: {...}" format + // Multiple data lines in same event should be joined with newlines per SSE spec + val joinedData = dataLines.joinToString("\n") { it.removePrefix("data:").trim() } + + // Validate that the joined data is valid JSON to avoid propagating malformed JSON-RPC + return try { + JSONObject(joinedData) + joinedData + } catch (e: Exception) { + Log.w(TAG, "SSE data is not valid JSON; returning raw SSE response body", e) + responseBody + } + } + } + + // Fallback: return original response + return responseBody + } + + /** + * Test connection to an MCP server and retrieve server capabilities + */ + suspend fun testConnection(server: McpServer): McpTestResult = withContext(Dispatchers.IO) { + try { + Log.d(TAG, "Testing connection to MCP server: ${server.name} at ${server.url} (transport: ${server.transportType})") + + // Build the initialize request according to MCP protocol + val initializeRequest = JSONObject().apply { + put("jsonrpc", "2.0") + put("id", 1) + put("method", "initialize") + put("params", JSONObject().apply { + put("protocolVersion", MCP_PROTOCOL_VERSION) + put("capabilities", JSONObject()) + put("clientInfo", JSONObject().apply { + put("name", CLIENT_NAME) + put("version", CLIENT_VERSION) + }) + }) + } + + val requestBuilder = Request.Builder() + .url(server.url) + .post(initializeRequest.toString().toRequestBody(JSON_MEDIA_TYPE)) + .addHeader("Content-Type", "application/json") + .addHeader("Accept", getAcceptHeader(server.transportType)) + + // Add API key if provided + server.apiKey?.let { key -> + requestBuilder.addHeader("Authorization", "Bearer $key") + } + + val response = httpClient.newCall(requestBuilder.build()).execute() + + if (!response.isSuccessful) { + return@withContext McpTestResult( + success = false, + message = "Server returned error: ${response.code} ${response.message}" + ) + } + + val rawResponseBody = response.body?.string() + if (rawResponseBody.isNullOrBlank()) { + return@withContext McpTestResult( + success = false, + message = "Server returned empty response" + ) + } + + // Parse response based on transport type + val responseBody = parseResponse(rawResponseBody, server.transportType) + + // Parse JSON response with specific error handling + val jsonResponse = try { + JSONObject(responseBody) + } catch (e: org.json.JSONException) { + Log.e(TAG, "Failed to parse MCP response as JSON: ${e.message}") + return@withContext McpTestResult( + success = false, + message = "Server returned invalid JSON response. The server may not be a valid MCP server." + ) + } + + // Check for JSON-RPC error + if (jsonResponse.has("error")) { + val error = jsonResponse.getJSONObject("error") + return@withContext McpTestResult( + success = false, + message = "Server error: ${error.optString("message", "Unknown error")}" + ) + } + + // Parse the result + val result = jsonResponse.optJSONObject("result") + val serverInfo = result?.optJSONObject("serverInfo") + val serverName = serverInfo?.optString("name", "Unknown Server") ?: "Unknown Server" + val serverVersion = serverInfo?.optString("version", "") ?: "" + + // Now list available tools + val tools = listTools(server) + + McpTestResult( + success = true, + message = "Connected successfully", + tools = tools, + serverInfo = if (serverVersion.isNotEmpty()) "$serverName v$serverVersion" else serverName + ) + + } catch (e: Exception) { + Log.e(TAG, "Failed to connect to MCP server: ${e.message}", e) + McpTestResult( + success = false, + message = "Connection failed: ${e.message ?: "Unknown error"}" + ) + } + } + + /** + * List available tools from an MCP server + */ + suspend fun listTools(server: McpServer): List = withContext(Dispatchers.IO) { + try { + val listToolsRequest = JSONObject().apply { + put("jsonrpc", "2.0") + put("id", 2) + put("method", "tools/list") + put("params", JSONObject()) + } + + val requestBuilder = Request.Builder() + .url(server.url) + .post(listToolsRequest.toString().toRequestBody(JSON_MEDIA_TYPE)) + .addHeader("Content-Type", "application/json") + .addHeader("Accept", getAcceptHeader(server.transportType)) + + server.apiKey?.let { key -> + requestBuilder.addHeader("Authorization", "Bearer $key") + } + + val response = httpClient.newCall(requestBuilder.build()).execute() + + if (!response.isSuccessful) { + return@withContext emptyList() + } + + val rawResponseBody = response.body?.string() ?: return@withContext emptyList() + // Parse response based on transport type + val responseBody = parseResponse(rawResponseBody, server.transportType) + val jsonResponse = JSONObject(responseBody) + + if (jsonResponse.has("error")) { + return@withContext emptyList() + } + + val result = jsonResponse.optJSONObject("result") ?: return@withContext emptyList() + val toolsArray = result.optJSONArray("tools") ?: return@withContext emptyList() + + val tools = mutableListOf() + for (i in 0 until toolsArray.length()) { + val tool = toolsArray.getJSONObject(i) + tools.add(McpToolInfo( + name = tool.getString("name"), + description = tool.optString("description", null), + inputSchema = tool.optJSONObject("inputSchema")?.toString() + )) + } + + tools + + } catch (e: Exception) { + Log.e(TAG, "Failed to list tools: ${e.message}", e) + emptyList() + } + } + + /** + * Call a tool on an MCP server + */ + suspend fun callTool( + server: McpServer, + toolName: String, + arguments: Map + ): Result = callToolInternal(server, toolName, JSONObject(arguments)) + + suspend fun callTool( + server: McpServer, + toolName: String, + argumentsJson: String + ): Result { + val parsedArguments = try { + if (argumentsJson.isBlank()) JSONObject() else JSONObject(argumentsJson) + } catch (e: Exception) { + return Result.failure(Exception("Invalid tool arguments JSON: ${e.message}")) + } + return callToolInternal(server, toolName, parsedArguments) + } + + private suspend fun callToolInternal( + server: McpServer, + toolName: String, + arguments: JSONObject + ): Result = withContext(Dispatchers.IO) { + try { + val callToolRequest = JSONObject().apply { + put("jsonrpc", "2.0") + put("id", System.currentTimeMillis()) + put("method", "tools/call") + put("params", JSONObject().apply { + put("name", toolName) + put("arguments", arguments) + }) + } + + val requestBuilder = Request.Builder() + .url(server.url) + .post(callToolRequest.toString().toRequestBody(JSON_MEDIA_TYPE)) + .addHeader("Content-Type", "application/json") + .addHeader("Accept", getAcceptHeader(server.transportType)) + + server.apiKey?.let { key -> + requestBuilder.addHeader("Authorization", "Bearer $key") + } + + val response = httpClient.newCall(requestBuilder.build()).execute() + + if (!response.isSuccessful) { + return@withContext Result.failure(Exception("Server returned: ${response.code}")) + } + + val rawResponseBody = response.body?.string() + ?: return@withContext Result.failure(Exception("Empty response")) + + // Parse response based on transport type + val responseBody = parseResponse(rawResponseBody, server.transportType) + val jsonResponse = JSONObject(responseBody) + + if (jsonResponse.has("error")) { + val error = jsonResponse.getJSONObject("error") + return@withContext Result.failure(Exception(error.optString("message", "Unknown error"))) + } + + val result = jsonResponse.optJSONObject("result") + return@withContext Result.success(result?.toString() ?: responseBody) + + } catch (e: Exception) { + Log.e(TAG, "Failed to call tool: ${e.message}", e) + return@withContext Result.failure(e) + } + } +} diff --git a/app/src/main/java/com/dark/tool_neuron/service/McpToolMapper.kt b/app/src/main/java/com/dark/tool_neuron/service/McpToolMapper.kt new file mode 100644 index 00000000..7495c1d8 --- /dev/null +++ b/app/src/main/java/com/dark/tool_neuron/service/McpToolMapper.kt @@ -0,0 +1,68 @@ +package com.dark.tool_neuron.service + +import com.dark.tool_neuron.models.table_schema.McpServer +import org.json.JSONArray +import org.json.JSONObject + +data class McpToolReference( + val server: McpServer, + val toolName: String +) + +data class McpToolMapping( + val toolsJson: String, + val toolRegistry: Map +) + +object McpToolMapper { + fun sanitizeIdentifier(value: String): String { + return value.lowercase() + .replace(Regex("[^a-z0-9]+"), "_") + .trim('_') + } + + fun buildMapping(serverTools: Map>): McpToolMapping { + val toolsArray = JSONArray() + val registry = mutableMapOf() + + serverTools.forEach { (server, tools) -> + val serverPrefix = sanitizeIdentifier(server.name).ifBlank { "mcp" } + tools.forEach { tool -> + val toolSlug = sanitizeIdentifier(tool.name).ifBlank { "tool" } + val toolId = "${serverPrefix}_${toolSlug}" + toolsArray.put(buildToolDefinition(toolId, tool)) + registry[toolId] = McpToolReference(server, tool.name) + } + } + + return McpToolMapping( + toolsJson = toolsArray.toString(), + toolRegistry = registry + ) + } + + private fun buildToolDefinition(toolId: String, tool: McpToolInfo): JSONObject { + val function = JSONObject().apply { + put("name", toolId) + tool.description?.takeIf { it.isNotBlank() }?.let { put("description", it) } + put("parameters", buildParameters(tool.inputSchema)) + } + + return JSONObject().apply { + put("type", "function") + put("function", function) + } + } + + private fun buildParameters(inputSchema: String?): JSONObject { + val parsedSchema = inputSchema?.takeIf { it.isNotBlank() }?.let { + runCatching { JSONObject(it) }.getOrNull() + } + + return (parsedSchema ?: JSONObject()).apply { + if (!has("type")) { + put("type", "object") + } + } + } +} diff --git a/app/src/main/java/com/dark/tool_neuron/ui/screen/IntroScreen.kt b/app/src/main/java/com/dark/tool_neuron/ui/screen/IntroScreen.kt index d336044e..3d843871 100644 --- a/app/src/main/java/com/dark/tool_neuron/ui/screen/IntroScreen.kt +++ b/app/src/main/java/com/dark/tool_neuron/ui/screen/IntroScreen.kt @@ -52,7 +52,7 @@ import kotlinx.coroutines.withContext import java.io.File @Composable -fun IntroScreen() { +fun IntroScreen(onFinished: () -> Unit) { val context = LocalContext.current var progress by remember { mutableFloatStateOf(0f) } @@ -100,6 +100,7 @@ fun IntroScreen() { delay(delayTime) progress = i / 1000f } + onFinished() } Scaffold(Modifier.fillMaxSize()) { _ -> diff --git a/app/src/main/java/com/dark/tool_neuron/ui/screen/McpServersScreen.kt b/app/src/main/java/com/dark/tool_neuron/ui/screen/McpServersScreen.kt new file mode 100644 index 00000000..877efafe --- /dev/null +++ b/app/src/main/java/com/dark/tool_neuron/ui/screen/McpServersScreen.kt @@ -0,0 +1,865 @@ +package com.dark.tool_neuron.ui.screen + +import androidx.compose.animation.* +import androidx.compose.animation.core.* +import androidx.compose.foundation.background +import androidx.compose.foundation.clickable +import androidx.compose.foundation.layout.* +import androidx.compose.foundation.lazy.LazyColumn +import androidx.compose.foundation.lazy.items +import androidx.compose.foundation.shape.CircleShape +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.foundation.text.KeyboardOptions +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.* +import androidx.compose.material3.* +import androidx.compose.runtime.* +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.draw.clip +import androidx.compose.ui.graphics.Color +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.text.input.KeyboardType +import androidx.compose.ui.text.input.PasswordVisualTransformation +import androidx.compose.ui.text.input.VisualTransformation +import androidx.compose.ui.text.style.TextOverflow +import androidx.compose.ui.unit.dp +import androidx.hilt.navigation.compose.hiltViewModel +import androidx.lifecycle.compose.collectAsStateWithLifecycle +import com.dark.tool_neuron.models.table_schema.McpConnectionStatus +import com.dark.tool_neuron.models.table_schema.McpServer +import com.dark.tool_neuron.models.table_schema.McpTransportType +import com.dark.tool_neuron.service.McpTestResult +import com.dark.tool_neuron.ui.components.ActionButton +import com.dark.tool_neuron.ui.components.ActionTextButton +import com.dark.tool_neuron.ui.components.CuteSwitch +import com.dark.tool_neuron.ui.theme.rDp +import com.dark.tool_neuron.viewmodel.McpServerUiState +import com.dark.tool_neuron.viewmodel.McpServerViewModel +import java.text.SimpleDateFormat +import java.util.* + +// Success color for connected/successful states +private val SuccessGreen = Color(0xFF4CAF50) + +@OptIn(ExperimentalMaterial3Api::class) +@Composable +fun McpServersScreen( + onBackClick: () -> Unit, + viewModel: McpServerViewModel = hiltViewModel() +) { + val servers by viewModel.servers.collectAsStateWithLifecycle() + val serverCount by viewModel.serverCount.collectAsStateWithLifecycle() + val enabledServerCount by viewModel.enabledServerCount.collectAsStateWithLifecycle() + val showAddDialog by viewModel.showAddDialog.collectAsStateWithLifecycle() + val showEditDialog by viewModel.showEditDialog.collectAsStateWithLifecycle() + val selectedServer by viewModel.selectedServer.collectAsStateWithLifecycle() + val testingServerId by viewModel.testingServerId.collectAsStateWithLifecycle() + val testResult by viewModel.testResult.collectAsStateWithLifecycle() + val isLoading by viewModel.isLoading.collectAsStateWithLifecycle() + val error by viewModel.error.collectAsStateWithLifecycle() + + Scaffold( + topBar = { + CenterAlignedTopAppBar( + title = { + Column(horizontalAlignment = Alignment.CenterHorizontally) { + Text( + "MCP Servers", + style = MaterialTheme.typography.titleMedium, + fontWeight = FontWeight.SemiBold + ) + Text( + "$enabledServerCount active / $serverCount total", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + }, + navigationIcon = { + ActionTextButton( + onClickListener = onBackClick, + icon = Icons.Default.ChevronLeft, + text = "Back", + modifier = Modifier.padding(start = rDp(6.dp)) + ) + }, + actions = { + ActionButton( + onClickListener = { viewModel.showAddServerDialog() }, + icon = Icons.Default.Add, + modifier = Modifier.padding(end = rDp(6.dp)) + ) + } + ) + } + ) { padding -> + Box( + modifier = Modifier + .fillMaxSize() + .padding(padding) + ) { + if (servers.isEmpty()) { + EmptyServersState(onAddServer = { viewModel.showAddServerDialog() }) + } else { + ServersList( + servers = servers, + testingServerId = testingServerId, + onServerClick = { viewModel.showEditServerDialog(it.server) }, + onToggleEnabled = { server, enabled -> + viewModel.toggleServerEnabled(server.server.id, enabled) + }, + onTestConnection = { viewModel.testConnection(it.server) }, + onDeleteServer = { viewModel.deleteServer(it.server.id) } + ) + } + + // Loading overlay + AnimatedVisibility( + visible = isLoading, + enter = fadeIn(), + exit = fadeOut() + ) { + Box( + modifier = Modifier + .fillMaxSize() + .background(MaterialTheme.colorScheme.surface.copy(alpha = 0.8f)), + contentAlignment = Alignment.Center + ) { + CircularProgressIndicator() + } + } + + // Error snackbar + error?.let { errorMessage -> + Snackbar( + modifier = Modifier + .align(Alignment.BottomCenter) + .padding(rDp(16.dp)), + action = { + TextButton(onClick = { viewModel.clearError() }) { + Text("Dismiss") + } + } + ) { + Text(errorMessage) + } + } + } + } + + // Add Server Dialog + if (showAddDialog) { + AddEditServerDialog( + server = null, + isTesting = testingServerId == "new", + testResult = testResult, + onDismiss = { viewModel.hideAddServerDialog() }, + onSave = { name, url, transportType, apiKey, description -> + viewModel.addServer(name, url, transportType, apiKey, description) + }, + onTestConnection = { name, url, transportType, apiKey -> + viewModel.testConnectionWithParams(name, url, transportType, apiKey) + }, + onClearTestResult = { viewModel.clearTestResult() } + ) + } + + // Edit Server Dialog + if (showEditDialog && selectedServer != null) { + AddEditServerDialog( + server = selectedServer, + isTesting = testingServerId == selectedServer?.id, + testResult = testResult, + onDismiss = { viewModel.hideEditServerDialog() }, + onSave = { name, url, transportType, apiKey, description -> + selectedServer?.let { server -> + viewModel.updateServer( + server.copy( + name = name, + url = url, + transportType = transportType, + apiKey = apiKey?.takeIf { it.isNotBlank() }, + description = description + ) + ) + } + }, + onTestConnection = { name, url, transportType, apiKey -> + viewModel.testConnectionWithParams(name, url, transportType, apiKey) + }, + onClearTestResult = { viewModel.clearTestResult() } + ) + } +} + +@Composable +private fun EmptyServersState(onAddServer: () -> Unit) { + Box( + modifier = Modifier.fillMaxSize(), + contentAlignment = Alignment.Center + ) { + Column( + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.spacedBy(rDp(16.dp)) + ) { + Icon( + imageVector = Icons.Default.Cloud, + contentDescription = null, + modifier = Modifier.size(rDp(72.dp)), + tint = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.4f) + ) + Text( + "No MCP Servers", + style = MaterialTheme.typography.titleMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + Text( + "Connect to remote MCP servers to extend\nyour AI capabilities with external tools", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.7f), + modifier = Modifier.padding(horizontal = rDp(32.dp)) + ) + Spacer(modifier = Modifier.height(rDp(8.dp))) + ActionTextButton( + onClickListener = onAddServer, + icon = Icons.Default.Add, + text = "Add Server", + shape = RoundedCornerShape(rDp(12.dp)) + ) + } + } +} + +@Composable +private fun ServersList( + servers: List, + testingServerId: String?, + onServerClick: (McpServerUiState) -> Unit, + onToggleEnabled: (McpServerUiState, Boolean) -> Unit, + onTestConnection: (McpServerUiState) -> Unit, + onDeleteServer: (McpServerUiState) -> Unit +) { + LazyColumn( + modifier = Modifier.fillMaxSize(), + contentPadding = PaddingValues(rDp(16.dp)), + verticalArrangement = Arrangement.spacedBy(rDp(12.dp)) + ) { + // Info card + item { + InfoCard() + } + + items(servers, key = { it.server.id }) { serverState -> + ServerCard( + serverState = serverState, + isTesting = testingServerId == serverState.server.id, + onClick = { onServerClick(serverState) }, + onToggleEnabled = { enabled -> onToggleEnabled(serverState, enabled) }, + onTestConnection = { onTestConnection(serverState) }, + onDelete = { onDeleteServer(serverState) } + ) + } + } +} + +@Composable +private fun InfoCard() { + Surface( + modifier = Modifier.fillMaxWidth(), + shape = RoundedCornerShape(rDp(12.dp)), + color = MaterialTheme.colorScheme.primaryContainer.copy(alpha = 0.3f) + ) { + Row( + modifier = Modifier.padding(rDp(16.dp)), + horizontalArrangement = Arrangement.spacedBy(rDp(12.dp)), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + imageVector = Icons.Default.Info, + contentDescription = null, + tint = MaterialTheme.colorScheme.primary, + modifier = Modifier.size(rDp(24.dp)) + ) + Column(modifier = Modifier.weight(1f)) { + Text( + text = "MCP (Model Context Protocol)", + style = MaterialTheme.typography.bodyMedium, + fontWeight = FontWeight.SemiBold, + color = MaterialTheme.colorScheme.onSurface + ) + Text( + text = "Connect to remote MCP servers to access external tools, resources, and capabilities for your AI conversations.", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + } +} + +@Composable +private fun ServerCard( + serverState: McpServerUiState, + isTesting: Boolean, + onClick: () -> Unit, + onToggleEnabled: (Boolean) -> Unit, + onTestConnection: () -> Unit, + onDelete: () -> Unit +) { + val server = serverState.server + val status = serverState.connectionStatus + + Card( + modifier = Modifier + .fillMaxWidth() + .clickable { onClick() }, + shape = RoundedCornerShape(rDp(16.dp)), + colors = CardDefaults.cardColors( + containerColor = when (status) { + McpConnectionStatus.CONNECTED -> MaterialTheme.colorScheme.primaryContainer.copy(alpha = 0.2f) + McpConnectionStatus.ERROR -> MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.2f) + McpConnectionStatus.CONNECTING -> MaterialTheme.colorScheme.tertiaryContainer.copy(alpha = 0.2f) + else -> MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.5f) + } + ) + ) { + Column( + modifier = Modifier + .fillMaxWidth() + .padding(rDp(16.dp)) + ) { + // Header row + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + Row( + verticalAlignment = Alignment.CenterVertically, + modifier = Modifier.weight(1f) + ) { + // Status indicator + StatusIndicator(status = status, isTesting = isTesting) + + Spacer(modifier = Modifier.width(rDp(12.dp))) + + Column { + Text( + text = server.name, + style = MaterialTheme.typography.bodyLarge, + fontWeight = FontWeight.SemiBold, + maxLines = 1, + overflow = TextOverflow.Ellipsis + ) + Text( + text = server.url, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + maxLines = 1, + overflow = TextOverflow.Ellipsis + ) + } + } + + CuteSwitch( + checked = server.isEnabled, + onCheckedChange = onToggleEnabled + ) + } + + // Transport type badge + Spacer(modifier = Modifier.height(rDp(12.dp))) + Row( + horizontalArrangement = Arrangement.spacedBy(rDp(8.dp)), + verticalAlignment = Alignment.CenterVertically + ) { + TransportBadge(transportType = server.transportType) + + if (server.apiKey != null) { + Badge( + containerColor = MaterialTheme.colorScheme.secondary.copy(alpha = 0.1f), + contentColor = MaterialTheme.colorScheme.secondary + ) { + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(rDp(4.dp)), + modifier = Modifier.padding(horizontal = rDp(4.dp)) + ) { + Icon( + Icons.Default.Key, + contentDescription = null, + modifier = Modifier.size(rDp(12.dp)) + ) + Text("Auth", style = MaterialTheme.typography.labelSmall) + } + } + } + + server.lastConnectedAt?.let { lastConnected -> + Text( + text = "Last connected: ${formatTimestamp(lastConnected)}", + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.7f) + ) + } + } + + // Description + if (server.description.isNotBlank()) { + Spacer(modifier = Modifier.height(rDp(8.dp))) + Text( + text = server.description, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + maxLines = 2, + overflow = TextOverflow.Ellipsis + ) + } + + // Actions + Spacer(modifier = Modifier.height(rDp(12.dp))) + HorizontalDivider(color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.1f)) + Spacer(modifier = Modifier.height(rDp(12.dp))) + + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + ActionTextButton( + onClickListener = onTestConnection, + icon = Icons.Default.Refresh, + text = if (isTesting) "Testing..." else "Test Connection", + shape = RoundedCornerShape(rDp(12.dp)) + ) + + IconButton( + onClick = onDelete, + modifier = Modifier.size(rDp(36.dp)) + ) { + Icon( + Icons.Default.Delete, + contentDescription = "Delete", + tint = MaterialTheme.colorScheme.error, + modifier = Modifier.size(rDp(20.dp)) + ) + } + } + } + } +} + +@Composable +private fun StatusIndicator(status: McpConnectionStatus, isTesting: Boolean) { + val color = when { + isTesting -> MaterialTheme.colorScheme.tertiary + status == McpConnectionStatus.CONNECTED -> SuccessGreen + status == McpConnectionStatus.ERROR -> MaterialTheme.colorScheme.error + status == McpConnectionStatus.CONNECTING -> MaterialTheme.colorScheme.tertiary + else -> MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.3f) + } + + val infiniteTransition = rememberInfiniteTransition(label = "pulse") + val alpha by infiniteTransition.animateFloat( + initialValue = 1f, + targetValue = 0.3f, + animationSpec = infiniteRepeatable( + animation = tween(1000), + repeatMode = RepeatMode.Reverse + ), + label = "pulseAlpha" + ) + + Box( + modifier = Modifier + .size(rDp(12.dp)) + .clip(CircleShape) + .background( + if (isTesting || status == McpConnectionStatus.CONNECTING) { + color.copy(alpha = alpha) + } else { + color + } + ) + ) +} + +@Composable +private fun TransportBadge(transportType: McpTransportType) { + Badge( + containerColor = MaterialTheme.colorScheme.primary.copy(alpha = 0.1f), + contentColor = MaterialTheme.colorScheme.primary + ) { + Text( + text = when (transportType) { + McpTransportType.SSE -> "SSE" + McpTransportType.STREAMABLE_HTTP -> "HTTP" + }, + style = MaterialTheme.typography.labelSmall, + modifier = Modifier.padding(horizontal = rDp(4.dp)) + ) + } +} + +@OptIn(ExperimentalMaterial3Api::class) +@Composable +private fun AddEditServerDialog( + server: McpServer?, + isTesting: Boolean, + testResult: McpTestResult?, + onDismiss: () -> Unit, + onSave: (name: String, url: String, transportType: McpTransportType, apiKey: String?, description: String) -> Unit, + onTestConnection: (name: String, url: String, transportType: McpTransportType, apiKey: String?) -> Unit, + onClearTestResult: () -> Unit +) { + var name by remember { mutableStateOf(server?.name ?: "") } + var url by remember { mutableStateOf(server?.url ?: "") } + var transportType by remember { mutableStateOf(server?.transportType ?: McpTransportType.SSE) } + var apiKey by remember { mutableStateOf(server?.apiKey ?: "") } + var description by remember { mutableStateOf(server?.description ?: "") } + var showApiKey by remember { mutableStateOf(false) } + + val isValid = name.isNotBlank() && url.isNotBlank() && + (url.startsWith("http://") || url.startsWith("https://")) + + ModalBottomSheet( + onDismissRequest = onDismiss, + sheetState = rememberModalBottomSheetState(skipPartiallyExpanded = true), + containerColor = MaterialTheme.colorScheme.surface, + dragHandle = { + Box( + Modifier + .padding(vertical = rDp(12.dp)) + .width(rDp(40.dp)) + .height(rDp(4.dp)) + .clip(RoundedCornerShape(rDp(2.dp))) + .background(MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.4f)) + ) + } + ) { + Column( + modifier = Modifier + .fillMaxWidth() + .padding(horizontal = rDp(24.dp)) + .padding(bottom = rDp(32.dp)) + ) { + // Header + Text( + text = if (server == null) "Add MCP Server" else "Edit MCP Server", + style = MaterialTheme.typography.titleLarge, + fontWeight = FontWeight.Bold + ) + + Text( + text = "Configure a remote MCP server connection", + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + + Spacer(modifier = Modifier.height(rDp(24.dp))) + + // Name field + OutlinedTextField( + value = name, + onValueChange = { + name = it + onClearTestResult() + }, + label = { Text("Server Name") }, + placeholder = { Text("My MCP Server") }, + singleLine = true, + modifier = Modifier.fillMaxWidth(), + shape = RoundedCornerShape(rDp(12.dp)), + leadingIcon = { + Icon(Icons.Default.Label, contentDescription = null) + } + ) + + Spacer(modifier = Modifier.height(rDp(16.dp))) + + // URL field + val isInsecureUrl = url.startsWith("http://") && !url.startsWith("https://") + val showSecurityWarning = isInsecureUrl && apiKey.isNotBlank() + + OutlinedTextField( + value = url, + onValueChange = { + url = it + onClearTestResult() + }, + label = { Text("Server URL") }, + placeholder = { Text("https://api.example.com/mcp") }, + singleLine = true, + modifier = Modifier.fillMaxWidth(), + shape = RoundedCornerShape(rDp(12.dp)), + leadingIcon = { + Icon(Icons.Default.Link, contentDescription = null) + }, + trailingIcon = if (showSecurityWarning) { + { + Icon( + Icons.Default.Warning, + contentDescription = "Security warning", + tint = MaterialTheme.colorScheme.error + ) + } + } else null, + isError = url.isNotBlank() && !url.startsWith("http://") && !url.startsWith("https://"), + supportingText = when { + url.isNotBlank() && !url.startsWith("http://") && !url.startsWith("https://") -> { + { Text("URL must start with http:// or https://") } + } + showSecurityWarning -> { + { + Text( + "Warning: Using HTTP with an API key is insecure. Use HTTPS for secure connections.", + color = MaterialTheme.colorScheme.error + ) + } + } + isInsecureUrl -> { + { + Text( + "Consider using HTTPS for secure connections", + color = MaterialTheme.colorScheme.tertiary + ) + } + } + else -> null + }, + keyboardOptions = KeyboardOptions(keyboardType = KeyboardType.Uri) + ) + + Spacer(modifier = Modifier.height(rDp(16.dp))) + + // Transport type selector + Text( + text = "Transport Type", + style = MaterialTheme.typography.bodyMedium, + fontWeight = FontWeight.Medium + ) + Spacer(modifier = Modifier.height(rDp(8.dp))) + Row( + horizontalArrangement = Arrangement.spacedBy(rDp(8.dp)) + ) { + FilterChip( + selected = transportType == McpTransportType.SSE, + onClick = { + transportType = McpTransportType.SSE + onClearTestResult() + }, + label = { Text("SSE (Server-Sent Events)") }, + leadingIcon = if (transportType == McpTransportType.SSE) { + { Icon(Icons.Default.Check, contentDescription = null, modifier = Modifier.size(rDp(16.dp))) } + } else null + ) + FilterChip( + selected = transportType == McpTransportType.STREAMABLE_HTTP, + onClick = { + transportType = McpTransportType.STREAMABLE_HTTP + onClearTestResult() + }, + label = { Text("Streamable HTTP") }, + leadingIcon = if (transportType == McpTransportType.STREAMABLE_HTTP) { + { Icon(Icons.Default.Check, contentDescription = null, modifier = Modifier.size(rDp(16.dp))) } + } else null + ) + } + + Spacer(modifier = Modifier.height(rDp(16.dp))) + + // API Key field + OutlinedTextField( + value = apiKey, + onValueChange = { + apiKey = it + onClearTestResult() + }, + label = { Text("API Key (Optional)") }, + placeholder = { Text("Bearer token or API key") }, + singleLine = true, + modifier = Modifier.fillMaxWidth(), + shape = RoundedCornerShape(rDp(12.dp)), + leadingIcon = { + Icon(Icons.Default.Key, contentDescription = null) + }, + trailingIcon = { + IconButton(onClick = { showApiKey = !showApiKey }) { + Icon( + if (showApiKey) Icons.Default.VisibilityOff else Icons.Default.Visibility, + contentDescription = if (showApiKey) "Hide" else "Show" + ) + } + }, + visualTransformation = if (showApiKey) VisualTransformation.None else PasswordVisualTransformation() + ) + + Spacer(modifier = Modifier.height(rDp(16.dp))) + + // Description field + OutlinedTextField( + value = description, + onValueChange = { description = it }, + label = { Text("Description (Optional)") }, + placeholder = { Text("What this server provides...") }, + minLines = 2, + maxLines = 3, + modifier = Modifier.fillMaxWidth(), + shape = RoundedCornerShape(rDp(12.dp)) + ) + + Spacer(modifier = Modifier.height(rDp(16.dp))) + + // Test result + AnimatedVisibility( + visible = testResult != null, + enter = fadeIn() + expandVertically(), + exit = fadeOut() + shrinkVertically() + ) { + testResult?.let { result -> + TestResultCard(result = result) + Spacer(modifier = Modifier.height(rDp(16.dp))) + } + } + + // Action buttons + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.spacedBy(rDp(12.dp)) + ) { + OutlinedButton( + onClick = { + onTestConnection(name, url, transportType, apiKey.takeIf { it.isNotBlank() }) + }, + enabled = isValid && !isTesting, + modifier = Modifier.weight(1f), + shape = RoundedCornerShape(rDp(12.dp)) + ) { + if (isTesting) { + CircularProgressIndicator( + modifier = Modifier.size(rDp(16.dp)), + strokeWidth = rDp(2.dp) + ) + Spacer(modifier = Modifier.width(rDp(8.dp))) + } + Text(if (isTesting) "Testing..." else "Test Connection") + } + + Button( + onClick = { + onSave(name, url, transportType, apiKey.takeIf { it.isNotBlank() }, description) + }, + enabled = isValid, + modifier = Modifier.weight(1f), + shape = RoundedCornerShape(rDp(12.dp)) + ) { + Icon(Icons.Default.Save, contentDescription = null) + Spacer(modifier = Modifier.width(rDp(8.dp))) + Text(if (server == null) "Add Server" else "Save Changes") + } + } + } + } +} + +@Composable +private fun TestResultCard(result: McpTestResult) { + Surface( + modifier = Modifier.fillMaxWidth(), + shape = RoundedCornerShape(rDp(12.dp)), + color = if (result.success) { + MaterialTheme.colorScheme.primaryContainer.copy(alpha = 0.3f) + } else { + MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.3f) + } + ) { + Column( + modifier = Modifier.padding(rDp(16.dp)) + ) { + Row( + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.spacedBy(rDp(8.dp)) + ) { + Icon( + imageVector = if (result.success) Icons.Default.CheckCircle else Icons.Default.Error, + contentDescription = null, + tint = if (result.success) { + SuccessGreen + } else { + MaterialTheme.colorScheme.error + }, + modifier = Modifier.size(rDp(20.dp)) + ) + Text( + text = if (result.success) "Connection Successful" else "Connection Failed", + style = MaterialTheme.typography.bodyMedium, + fontWeight = FontWeight.SemiBold, + color = if (result.success) { + SuccessGreen + } else { + MaterialTheme.colorScheme.error + } + ) + } + + if (result.serverInfo != null) { + Spacer(modifier = Modifier.height(rDp(4.dp))) + Text( + text = "Server: ${result.serverInfo}", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + + if (!result.success) { + Spacer(modifier = Modifier.height(rDp(4.dp))) + Text( + text = result.message, + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.error + ) + } + + if (result.tools.isNotEmpty()) { + Spacer(modifier = Modifier.height(rDp(8.dp))) + Text( + text = "Available Tools (${result.tools.size}):", + style = MaterialTheme.typography.labelMedium, + fontWeight = FontWeight.Medium + ) + Spacer(modifier = Modifier.height(rDp(4.dp))) + result.tools.take(5).forEach { tool -> + Text( + text = "• ${tool.name}", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + if (result.tools.size > 5) { + Text( + text = "... and ${result.tools.size - 5} more", + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant.copy(alpha = 0.7f) + ) + } + } + } + } +} + +private fun formatTimestamp(timestamp: Long): String { + val now = System.currentTimeMillis() + val diff = now - timestamp + + return when { + diff < 60_000 -> "just now" + diff < 3600_000 -> "${diff / 60_000}m ago" + diff < 86400_000 -> "${diff / 3600_000}h ago" + diff < 604800_000 -> "${diff / 86400_000}d ago" + else -> { + val sdf = SimpleDateFormat("MMM dd", Locale.getDefault()) + sdf.format(Date(timestamp)) + } + } +} diff --git a/app/src/main/java/com/dark/tool_neuron/ui/screen/home_screen/HomeDrawerScreen.kt b/app/src/main/java/com/dark/tool_neuron/ui/screen/home_screen/HomeDrawerScreen.kt index 6ee260bc..baacd1f2 100644 --- a/app/src/main/java/com/dark/tool_neuron/ui/screen/home_screen/HomeDrawerScreen.kt +++ b/app/src/main/java/com/dark/tool_neuron/ui/screen/home_screen/HomeDrawerScreen.kt @@ -16,6 +16,7 @@ import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.material.icons.Icons import androidx.compose.material.icons.filled.Add import androidx.compose.material.icons.filled.Close +import androidx.compose.material.icons.filled.Cloud import androidx.compose.material.icons.filled.Delete import androidx.compose.material3.CircularProgressIndicator import androidx.compose.material3.ExperimentalMaterial3Api @@ -56,6 +57,7 @@ import java.util.Locale fun HomeDrawerScreen( onChatSelected: (String) -> Unit, onVaultManagerClick: () -> Unit, + onMcpServersClick: () -> Unit, viewModel: ChatListViewModel = hiltViewModel() ) { val chats by viewModel.chats.collectAsStateWithLifecycle() @@ -77,8 +79,9 @@ fun HomeDrawerScreen( ), containerColor = MaterialTheme.colorScheme.background, topBar = { - TopBar( + DrawerTopBar( onVaultManagerClick, + onMcpServersClick, onCreateNewChat = { viewModel.createNewChat { chatId -> onChatSelected(chatId) @@ -120,10 +123,20 @@ fun HomeDrawerScreen( } } +/** + * Top app bar used in the home drawer screen. + * Provides quick access actions for managing vaults, configuring MCP servers, + * and creating a new chat session. + * + * @param onVaultManagerClick Invoked when the vault manager action is selected. + * @param onMcpServersClick Invoked when the MCP servers action is selected. + * @param onCreateNewChat Invoked when the user requests to create a new chat. + */ @OptIn(ExperimentalMaterial3Api::class) @Composable -private fun TopBar( +private fun DrawerTopBar( onVaultManagerClick: () -> Unit, + onMcpServersClick: () -> Unit, onCreateNewChat: () -> Unit ) { TopAppBar( @@ -135,6 +148,11 @@ private fun TopBar( }, actions = { Row{ + ActionButton( + onClickListener = onMcpServersClick, + icon = Icons.Filled.Cloud, + modifier = Modifier.padding(end = rDp(6.dp)) + ) ActionButton( onClickListener = onVaultManagerClick, icon = R.drawable.smart_temp_message, diff --git a/app/src/main/java/com/dark/tool_neuron/ui/screen/home_screen/HomeScreen.kt b/app/src/main/java/com/dark/tool_neuron/ui/screen/home_screen/HomeScreen.kt index 33f83d8c..a5f9a5ed 100644 --- a/app/src/main/java/com/dark/tool_neuron/ui/screen/home_screen/HomeScreen.kt +++ b/app/src/main/java/com/dark/tool_neuron/ui/screen/home_screen/HomeScreen.kt @@ -89,6 +89,7 @@ fun HomeScreen( onStoreButtonClicked: () -> Unit, onModelEditor: () -> Unit, onVaultManagerClick: () -> Unit, + onMcpServersClick: () -> Unit, chatViewModel: ChatViewModel, llmModelViewModel: LLMModelViewModel ) { @@ -103,14 +104,23 @@ fun HomeScreen( ModalNavigationDrawer( drawerState = drawerState, drawerContent = { ModalDrawerSheet { - HomeDrawerScreen(onVaultManagerClick = { - onVaultManagerClick() - }, onChatSelected = { - chatViewModel.loadChat(it) - scope.launch { - drawerState.close() + HomeDrawerScreen( + onVaultManagerClick = { + onVaultManagerClick() + }, + onMcpServersClick = { + scope.launch { + drawerState.close() + } + onMcpServersClick() + }, + onChatSelected = { + chatViewModel.loadChat(it) + scope.launch { + drawerState.close() + } } - }) + ) } }) { Scaffold( diff --git a/app/src/main/java/com/dark/tool_neuron/viewmodel/ChatViewModel.kt b/app/src/main/java/com/dark/tool_neuron/viewmodel/ChatViewModel.kt index 0029b221..c0c3c929 100644 --- a/app/src/main/java/com/dark/tool_neuron/viewmodel/ChatViewModel.kt +++ b/app/src/main/java/com/dark/tool_neuron/viewmodel/ChatViewModel.kt @@ -13,7 +13,12 @@ import com.dark.tool_neuron.models.messages.MessageContent import com.dark.tool_neuron.models.messages.Messages import com.dark.tool_neuron.models.messages.RagResultItem import com.dark.tool_neuron.models.messages.Role +import com.dark.tool_neuron.models.table_schema.McpServer import com.dark.tool_neuron.models.table_schema.ModelConfig +import com.dark.tool_neuron.repo.McpServerRepository +import com.dark.tool_neuron.service.McpClientService +import com.dark.tool_neuron.service.McpToolMapper +import com.dark.tool_neuron.service.McpToolReference import com.dark.tool_neuron.state.AppStateManager import com.dark.tool_neuron.worker.ChatManager import com.dark.tool_neuron.worker.DiffusionConfig @@ -26,12 +31,15 @@ import jakarta.inject.Inject import kotlinx.coroutines.Job import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.first import kotlinx.coroutines.launch @HiltViewModel class ChatViewModel @Inject constructor( private val chatManager: ChatManager, - private val generationManager: GenerationManager + private val generationManager: GenerationManager, + private val mcpServerRepository: McpServerRepository, + private val mcpClientService: McpClientService ) : ViewModel() { private val _messages = mutableStateListOf() @@ -103,6 +111,13 @@ class ChatViewModel @Inject constructor( private val _currentRagResults = MutableStateFlow>(emptyList()) val currentRagResults: StateFlow> = _currentRagResults + private data class ToolCallInfo( + val name: String, + val argsJson: String + ) + + private var mcpToolRegistry: Map = emptyMap() + // ==================== RAG Controls ==================== fun setRagEnabled(enabled: Boolean) { @@ -240,11 +255,13 @@ class ChatViewModel @Inject constructor( val tokenBatchSize = 3 try { + var pendingToolCall: ToolCallInfo? = null // Prepend RAG context if available val finalPrompt = _currentRagContext.value?.let { ragContext -> "$ragContext\n\n### User Query:\n$prompt" } ?: prompt + syncMcpTools() generationManager.generateTextStreaming(finalPrompt, maxTokens).collect { event -> when (event) { is GenerationEvent.Token -> { @@ -265,6 +282,11 @@ class ChatViewModel @Inject constructor( } is GenerationEvent.Done -> { + val toolCall = pendingToolCall + if (toolCall != null) { + handleToolCallForNewChat(prompt, toolCall) + return@collect + } _streamingAssistantMessage.value = currentGeneratedContent // Don't set _isGenerating.value = false here // It will be set in resetStreamingState() after messages are added @@ -279,7 +301,13 @@ class ChatViewModel @Inject constructor( currentMetrics = event.metrics } - is GenerationEvent.ToolCall -> {} + is GenerationEvent.ToolCall -> { + pendingToolCall = ToolCallInfo(event.name, event.args) + currentGeneratedContent = "" + tokenBuffer.clear() + tokenCount = 0 + _streamingAssistantMessage.value = "" + } } } } catch (e: Exception) { @@ -305,6 +333,7 @@ class ChatViewModel @Inject constructor( val tokenBatchSize = 3 try { + var pendingToolCall: ToolCallInfo? = null var conversationPrompt = generationManager.buildConversationPrompt( _messages, userMessage.content.content ) @@ -314,6 +343,7 @@ class ChatViewModel @Inject constructor( conversationPrompt = "$ragContext\n\n$conversationPrompt" } + syncMcpTools() generationManager.generateTextStreaming(conversationPrompt, maxTokens) .collect { event -> when (event) { @@ -335,6 +365,11 @@ class ChatViewModel @Inject constructor( } is GenerationEvent.Done -> { + val toolCall = pendingToolCall + if (toolCall != null) { + handleToolCallExistingChat(chatId, userMessage, toolCall) + return@collect + } _streamingAssistantMessage.value = currentGeneratedContent // Add user message first if not already added @@ -383,7 +418,13 @@ class ChatViewModel @Inject constructor( currentMetrics = event.metrics } - is GenerationEvent.ToolCall -> {} + is GenerationEvent.ToolCall -> { + pendingToolCall = ToolCallInfo(event.name, event.args) + currentGeneratedContent = "" + tokenBuffer.clear() + tokenCount = 0 + _streamingAssistantMessage.value = "" + } } } } catch (e: Exception) { @@ -524,6 +565,105 @@ class ChatViewModel @Inject constructor( } } + // ==================== MCP Tool Integration ==================== + + private suspend fun syncMcpTools() { + try { + val enabledServers = mcpServerRepository.getEnabledServers().first() + if (enabledServers.isEmpty()) { + mcpToolRegistry = emptyMap() + LlmModelWorker.clearGgufTools() + return + } + + val serverTools = mutableMapOf>() + enabledServers.forEach { server -> + val tools = mcpClientService.listTools(server) + if (tools.isNotEmpty()) { + serverTools[server] = tools + } + } + + if (serverTools.isEmpty()) { + mcpToolRegistry = emptyMap() + LlmModelWorker.clearGgufTools() + return + } + + val mapping = McpToolMapper.buildMapping(serverTools) + mcpToolRegistry = mapping.toolRegistry + + if (mapping.toolRegistry.isEmpty()) { + LlmModelWorker.clearGgufTools() + return + } + + val success = LlmModelWorker.setGgufToolsJson(mapping.toolsJson) + if (!success) { + LlmModelWorker.clearGgufTools() + } + } catch (e: Exception) { + val message = "Failed to refresh MCP tools: ${e.message}" + _error.value = message + AppStateManager.setError(message) + } + } + + private suspend fun handleToolCallForNewChat(prompt: String, toolCall: ToolCallInfo) { + val response = resolveToolCallResponse(toolCall) + _streamingAssistantMessage.value = response + createChatWithMessages(prompt, response, null) + } + + private suspend fun handleToolCallExistingChat( + chatId: String, + userMessage: Messages, + toolCall: ToolCallInfo + ) { + val response = resolveToolCallResponse(toolCall) + _streamingAssistantMessage.value = response + + if (!userMessageAdded) { + _messages.add(userMessage) + userMessageAdded = true + } + + val assistantMessage = Messages( + role = Role.Assistant, + content = MessageContent( + contentType = ContentType.Text, + content = response + ) + ) + _messages.add(assistantMessage) + + chatManager.addAssistantMessage(chatId, response, null) + AppStateManager.setGenerationComplete() + resetStreamingState() + } + + private suspend fun resolveToolCallResponse(toolCall: ToolCallInfo): String { + return executeToolCall(toolCall).fold( + onSuccess = { result -> formatToolResult(toolCall.name, result) }, + onFailure = { error -> + val message = "Tool ${toolCall.name} failed: ${error.message ?: "Unknown error"}" + _error.value = message + AppStateManager.setError(message) + message + } + ) + } + + private suspend fun executeToolCall(toolCall: ToolCallInfo): Result { + val reference = mcpToolRegistry[toolCall.name] + ?: return Result.failure(Exception("Tool not found: ${toolCall.name}")) + return mcpClientService.callTool(reference.server, reference.toolName, toolCall.argsJson) + } + + private fun formatToolResult(toolName: String, result: String): String { + return "Tool $toolName result:\n$result" + } + private fun generateImageForNewChat( prompt: String, negativePrompt: String, @@ -1018,4 +1158,4 @@ class ChatViewModel @Inject constructor( fun hideModelList() { _showModelList.value = false } -} \ No newline at end of file +} diff --git a/app/src/main/java/com/dark/tool_neuron/viewmodel/McpServerViewModel.kt b/app/src/main/java/com/dark/tool_neuron/viewmodel/McpServerViewModel.kt new file mode 100644 index 00000000..64688ef3 --- /dev/null +++ b/app/src/main/java/com/dark/tool_neuron/viewmodel/McpServerViewModel.kt @@ -0,0 +1,288 @@ +package com.dark.tool_neuron.viewmodel + +import androidx.lifecycle.ViewModel +import androidx.lifecycle.viewModelScope +import com.dark.tool_neuron.models.table_schema.McpConnectionStatus +import com.dark.tool_neuron.models.table_schema.McpServer +import com.dark.tool_neuron.models.table_schema.McpTransportType +import com.dark.tool_neuron.repo.McpServerRepository +import com.dark.tool_neuron.service.McpClientService +import com.dark.tool_neuron.service.McpTestResult +import dagger.hilt.android.lifecycle.HiltViewModel +import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.SharingStarted +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.flow.combine +import kotlinx.coroutines.flow.stateIn +import kotlinx.coroutines.launch +import javax.inject.Inject + +/** + * UI state for a single MCP server with runtime status + */ +data class McpServerUiState( + val server: McpServer, + val connectionStatus: McpConnectionStatus = McpConnectionStatus.DISCONNECTED +) + +/** + * ViewModel for managing MCP (Model Context Protocol) servers + */ +@HiltViewModel +class McpServerViewModel @Inject constructor( + private val repository: McpServerRepository, + private val mcpClientService: McpClientService +) : ViewModel() { + + companion object { + private const val ERROR_DISPLAY_DURATION_MS = 5000L + } + + // All servers with their runtime status + val servers: StateFlow> = combine( + repository.getAllServers(), + repository.connectionStatuses + ) { servers, statuses -> + servers.map { server -> + McpServerUiState( + server = server, + connectionStatus = statuses[server.id] ?: McpConnectionStatus.DISCONNECTED + ) + } + }.stateIn(viewModelScope, SharingStarted.WhileSubscribed(5000), emptyList()) + + // Server count + val serverCount: StateFlow = repository.getServerCount() + .stateIn(viewModelScope, SharingStarted.WhileSubscribed(5000), 0) + + // Enabled server count + val enabledServerCount: StateFlow = repository.getEnabledServerCount() + .stateIn(viewModelScope, SharingStarted.WhileSubscribed(5000), 0) + + // Currently selected server for editing + private val _selectedServer = MutableStateFlow(null) + val selectedServer: StateFlow = _selectedServer.asStateFlow() + + // Dialog state + private val _showAddDialog = MutableStateFlow(false) + val showAddDialog: StateFlow = _showAddDialog.asStateFlow() + + private val _showEditDialog = MutableStateFlow(false) + val showEditDialog: StateFlow = _showEditDialog.asStateFlow() + + // Test result for the current dialog + private val _testingServerId = MutableStateFlow(null) + val testingServerId: StateFlow = _testingServerId.asStateFlow() + + private val _testResult = MutableStateFlow(null) + val testResult: StateFlow = _testResult.asStateFlow() + + // Loading state + private val _isLoading = MutableStateFlow(false) + val isLoading: StateFlow = _isLoading.asStateFlow() + + // Error state + private val _error = MutableStateFlow(null) + val error: StateFlow = _error.asStateFlow() + private var errorClearJob: kotlinx.coroutines.Job? = null + + /** + * Set an error message that auto-clears after a timeout. + * Cancels any previous auto-clear job to prevent race conditions. + */ + private fun setError(message: String) { + // Cancel any pending error clear job + errorClearJob?.cancel() + + _error.value = message + + // Start new clear job + errorClearJob = viewModelScope.launch { + delay(ERROR_DISPLAY_DURATION_MS) + _error.value = null + } + } + + /** + * Show the add server dialog + */ + fun showAddServerDialog() { + _selectedServer.value = null + _testResult.value = null + _showAddDialog.value = true + } + + /** + * Hide the add server dialog + */ + fun hideAddServerDialog() { + _showAddDialog.value = false + _testResult.value = null + } + + /** + * Show the edit server dialog + */ + fun showEditServerDialog(server: McpServer) { + _selectedServer.value = server + _testResult.value = null + _showEditDialog.value = true + } + + /** + * Hide the edit server dialog + */ + fun hideEditServerDialog() { + _showEditDialog.value = false + _selectedServer.value = null + _testResult.value = null + } + + /** + * Add a new MCP server + */ + fun addServer( + name: String, + url: String, + transportType: McpTransportType = McpTransportType.SSE, + apiKey: String? = null, + description: String = "" + ) { + viewModelScope.launch { + try { + _isLoading.value = true + repository.addServer(name, url, transportType, apiKey, description) + hideAddServerDialog() + } catch (e: Exception) { + setError("Failed to add server: ${e.message}") + } finally { + _isLoading.value = false + } + } + } + + /** + * Update an existing MCP server + */ + fun updateServer(server: McpServer) { + viewModelScope.launch { + try { + _isLoading.value = true + repository.updateServer(server) + hideEditServerDialog() + } catch (e: Exception) { + setError("Failed to update server: ${e.message}") + } finally { + _isLoading.value = false + } + } + } + + /** + * Delete an MCP server + */ + fun deleteServer(serverId: String) { + viewModelScope.launch { + try { + repository.deleteServer(serverId) + } catch (e: Exception) { + setError("Failed to delete server: ${e.message}") + } + } + } + + /** + * Toggle server enabled state + */ + fun toggleServerEnabled(serverId: String, enabled: Boolean) { + viewModelScope.launch { + try { + repository.setServerEnabled(serverId, enabled) + } catch (e: Exception) { + setError("Failed to update server: ${e.message}") + } + } + } + + /** + * Test connection to a server + */ + fun testConnection(server: McpServer) { + viewModelScope.launch { + try { + _testingServerId.value = server.id + _testResult.value = null + repository.updateConnectionStatus(server.id, McpConnectionStatus.CONNECTING) + + val result = mcpClientService.testConnection(server) + _testResult.value = result + + if (result.success) { + repository.updateConnectionStatus(server.id, McpConnectionStatus.CONNECTED) + repository.updateLastConnected(server.id) + } else { + repository.updateConnectionStatus(server.id, McpConnectionStatus.ERROR, result.message) + } + } catch (e: Exception) { + _testResult.value = McpTestResult( + success = false, + message = "Test failed: ${e.message}" + ) + repository.updateConnectionStatus(server.id, McpConnectionStatus.ERROR, e.message) + } finally { + _testingServerId.value = null + } + } + } + + /** + * Test connection with provided parameters (for add/edit dialog) + */ + fun testConnectionWithParams( + name: String, + url: String, + transportType: McpTransportType, + apiKey: String? + ) { + viewModelScope.launch { + try { + _testingServerId.value = "new" + _testResult.value = null + + val tempServer = McpServer( + id = "test", + name = name, + url = url, + transportType = transportType, + apiKey = apiKey?.takeIf { it.isNotBlank() } + ) + + val result = mcpClientService.testConnection(tempServer) + _testResult.value = result + } catch (e: Exception) { + _testResult.value = McpTestResult( + success = false, + message = "Test failed: ${e.message}" + ) + } finally { + _testingServerId.value = null + } + } + } + + /** + * Clear error message + */ + fun clearError() { + _error.value = null + } + + /** + * Clear test result + */ + fun clearTestResult() { + _testResult.value = null + } +} diff --git a/app/src/main/java/com/dark/tool_neuron/viewmodel/factory/ChatViewModelFactory.kt b/app/src/main/java/com/dark/tool_neuron/viewmodel/factory/ChatViewModelFactory.kt index a134d583..e9b39c66 100644 --- a/app/src/main/java/com/dark/tool_neuron/viewmodel/factory/ChatViewModelFactory.kt +++ b/app/src/main/java/com/dark/tool_neuron/viewmodel/factory/ChatViewModelFactory.kt @@ -2,18 +2,28 @@ package com.dark.tool_neuron.viewmodel.factory import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModelProvider +import com.dark.tool_neuron.repo.McpServerRepository +import com.dark.tool_neuron.service.McpClientService import com.dark.tool_neuron.viewmodel.ChatViewModel import com.dark.tool_neuron.worker.ChatManager import com.dark.tool_neuron.worker.GenerationManager class ChatViewModelFactory( - private val chatManager: ChatManager, private val generationManager: GenerationManager + private val chatManager: ChatManager, + private val generationManager: GenerationManager, + private val mcpServerRepository: McpServerRepository, + private val mcpClientService: McpClientService ) : ViewModelProvider.Factory { @Suppress("UNCHECKED_CAST") override fun create(modelClass: Class): T { if (modelClass.isAssignableFrom(ChatViewModel::class.java)) { - return ChatViewModel(chatManager, generationManager) as T + return ChatViewModel( + chatManager, + generationManager, + mcpServerRepository, + mcpClientService + ) as T } throw IllegalArgumentException("Unknown ViewModel class") } -} \ No newline at end of file +} diff --git a/app/src/main/java/com/dark/tool_neuron/worker/LlmModelWorker.kt b/app/src/main/java/com/dark/tool_neuron/worker/LlmModelWorker.kt index ffbb9c93..17da6bd7 100644 --- a/app/src/main/java/com/dark/tool_neuron/worker/LlmModelWorker.kt +++ b/app/src/main/java/com/dark/tool_neuron/worker/LlmModelWorker.kt @@ -289,7 +289,7 @@ object LlmModelWorker { } awaitClose { - // Optional: stop generation if flow is cancelled + ggufStopGeneration() } }.buffer(Channel.UNLIMITED) .flowOn(Dispatchers.IO) @@ -309,6 +309,25 @@ object LlmModelWorker { return service?.modelInfoGguf } + suspend fun setGgufToolsJson(toolsJson: String): Boolean { + val svc = ensureServiceBound() + return try { + svc.setToolsJsonGguf(toolsJson) + } catch (e: Exception) { + Log.e(TAG, "Failed to set GGUF tools JSON", e) + false + } + } + + fun clearGgufTools() { + try { + service?.clearToolsGguf() + Log.i(TAG, "Cleared GGUF tools") + } catch (e: Exception) { + Log.e(TAG, "Failed to clear GGUF tools", e) + } + } + // ==================== Diffusion Methods ==================== /** @@ -491,7 +510,7 @@ object LlmModelWorker { } awaitClose { - // Flow cancelled + stopDiffusionGeneration() } }.buffer(Channel.UNLIMITED) .flowOn(Dispatchers.IO) @@ -618,4 +637,4 @@ object LlmModelWorker { Log.i(TAG, "Embedding model download started in background") } -} \ No newline at end of file +} diff --git a/app/src/test/java/com/dark/tool_neuron/integration/McpServerTest.kt b/app/src/test/java/com/dark/tool_neuron/integration/McpServerTest.kt new file mode 100644 index 00000000..63bb7cef --- /dev/null +++ b/app/src/test/java/com/dark/tool_neuron/integration/McpServerTest.kt @@ -0,0 +1,235 @@ +package com.dark.tool_neuron.integration + +import com.dark.tool_neuron.models.table_schema.McpServer +import com.dark.tool_neuron.models.table_schema.McpTransportType +import com.dark.tool_neuron.service.McpToolInfo +import com.dark.tool_neuron.service.McpToolMapper +import org.json.JSONArray +import org.json.JSONObject +import org.junit.Assert.* +import org.junit.Test + +/** + * Unit tests for MCP server-related functionality. + * These tests validate McpToolMapper functionality, JSON parsing, + * and configuration objects without connecting to real MCP servers. + */ +class McpServerTest { + + // Helper function to parse SSE response format. + // This is a simplified version for tests that extracts JSON from single-event SSE responses. + // The production code in McpClientService.parseSseResponse() handles multiple events and validates JSON. + private fun parseSseData(sseResponse: String): String { + val dataLine = sseResponse.lines().find { it.startsWith("data:") } + ?: return sseResponse + return dataLine.removePrefix("data:").trim() + } + + /** + * Test that McpServer can be created with the correct configuration + * for connecting to Zapier's MCP endpoint. + */ + @Test + fun createZapierMcpServerConfiguration() { + val zapierUrl = "https://mcp.zapier.com/api/v1/connect?token=example-token" + + val server = McpServer( + id = McpServer.generateId(), + name = "Zapier MCP", + url = zapierUrl, + transportType = McpTransportType.SSE, + apiKey = null, // Token is in URL + description = "Zapier MCP integration for Google Docs tools" + ) + + assertNotNull(server.id) + assertEquals("Zapier MCP", server.name) + assertEquals(zapierUrl, server.url) + assertEquals(McpTransportType.SSE, server.transportType) + assertTrue(server.isEnabled) + } + + /** + * Test parsing of MCP initialize response in SSE format using helper function. + */ + @Test + fun parseMcpInitializeResponse() { + val sseResponse = """event: message +data: {"result":{"protocolVersion":"2024-11-05","capabilities":{"tools":{"listChanged":true}},"serverInfo":{"name":"zapier","title":"Zapier MCP","version":"1.0.0"}},"jsonrpc":"2.0","id":1}""" + + // Use helper function to extract JSON from SSE format + val jsonStr = parseSseData(sseResponse) + val json = JSONObject(jsonStr) + + assertEquals("2.0", json.getString("jsonrpc")) + assertEquals(1, json.getInt("id")) + + val result = json.getJSONObject("result") + assertEquals("2024-11-05", result.getString("protocolVersion")) + + val serverInfo = result.getJSONObject("serverInfo") + assertEquals("zapier", serverInfo.getString("name")) + assertEquals("1.0.0", serverInfo.getString("version")) + } + + /** + * Test parsing of MCP tools/list response. + */ + @Test + fun parseMcpToolsListResponse() { + val sseResponse = """event: message +data: {"result":{"tools":[{"name":"google_docs_create_document_from_text","description":"Create a new document from text.","inputSchema":{"type":"object","properties":{"title":{"type":"string"}},"required":[]}}]},"jsonrpc":"2.0","id":2}""" + + // Use helper function to extract JSON from SSE format + val jsonStr = parseSseData(sseResponse) + val json = JSONObject(jsonStr) + + val result = json.getJSONObject("result") + val tools = result.getJSONArray("tools") + + assertEquals(1, tools.length()) + + val tool = tools.getJSONObject(0) + assertEquals("google_docs_create_document_from_text", tool.getString("name")) + assertEquals("Create a new document from text.", tool.getString("description")) + + val inputSchema = tool.getJSONObject("inputSchema") + assertEquals("object", inputSchema.getString("type")) + } + + /** + * Test that McpToolMapper correctly maps Zapier tools to the LLM format. + */ + @Test + fun mapZapierToolsToLlmFormat() { + val server = McpServer( + id = "zapier-1", + name = "Zapier MCP", + url = "https://mcp.zapier.com/api/v1/connect", + transportType = McpTransportType.SSE + ) + + val tools = listOf( + McpToolInfo( + name = "google_docs_create_document_from_text", + description = "Create a new document from text. Also supports limited HTML.", + inputSchema = """{"type":"object","properties":{"instructions":{"type":"string","description":"Instructions for running this tool"},"title":{"type":"string","description":"Document Name"},"file":{"type":"string","description":"Document Content"}},"required":["instructions"]}""" + ), + McpToolInfo( + name = "google_docs_find_a_document", + description = "Search for a specific document by name.", + inputSchema = """{"type":"object","properties":{"instructions":{"type":"string","description":"Instructions for running this tool"},"title":{"type":"string","description":"Document Name"}},"required":["instructions"]}""" + ) + ) + + val mapping = McpToolMapper.buildMapping(mapOf(server to tools)) + + // Check that tools JSON is valid + val toolsArray = JSONArray(mapping.toolsJson) + assertEquals(2, toolsArray.length()) + + // Check first tool structure + val firstTool = toolsArray.getJSONObject(0) + assertEquals("function", firstTool.getString("type")) + + val function = firstTool.getJSONObject("function") + // Verify exact tool name format: "zapier_mcp_google_docs_create_document_from_text" + assertEquals("zapier_mcp_google_docs_create_document_from_text", function.getString("name")) + assertTrue(function.has("description")) + + // Check tool registry size and contents + assertEquals(2, mapping.toolRegistry.size) + + // Verify exact tool name mapping in registry + val toolNames = mapping.toolRegistry.values.map { it.toolName }.toSet() + assertEquals( + setOf("google_docs_create_document_from_text", "google_docs_find_a_document"), + toolNames + ) + + // Verify all entries reference the same server + mapping.toolRegistry.values.forEach { entry -> + assertEquals(server, entry.server) + } + } + + /** + * Test that tool call request is properly formatted for MCP protocol. + */ + @Test + fun formatMcpToolCallRequest() { + val toolName = "google_docs_create_document_from_text" + val arguments = JSONObject().apply { + put("instructions", "Create a document titled 'Test' with content 'Hello World'") + put("output_hint", "just the document URL") + put("title", "Test Document") + put("file", "Hello World") + } + + // Use fixed ID for deterministic test behavior + val request = JSONObject().apply { + put("jsonrpc", "2.0") + put("id", 123L) + put("method", "tools/call") + put("params", JSONObject().apply { + put("name", toolName) + put("arguments", arguments) + }) + } + + assertEquals("2.0", request.getString("jsonrpc")) + assertEquals(123L, request.getLong("id")) + assertEquals("tools/call", request.getString("method")) + + val params = request.getJSONObject("params") + assertEquals(toolName, params.getString("name")) + + val args = params.getJSONObject("arguments") + assertEquals("Test Document", args.getString("title")) + assertEquals("Hello World", args.getString("file")) + } + + /** + * Test that both transport types can be assigned to McpServer. + */ + @Test + fun verifyTransportTypeAssignment() { + // SSE transport type + val sseServer = McpServer( + id = "server-sse", + name = "SSE Server", + url = "https://mcp.example.com/sse", + transportType = McpTransportType.SSE + ) + assertEquals(McpTransportType.SSE, sseServer.transportType) + + // Streamable HTTP transport type + val httpServer = McpServer( + id = "server-http", + name = "HTTP Server", + url = "https://mcp.example.com/http", + transportType = McpTransportType.STREAMABLE_HTTP + ) + assertEquals(McpTransportType.STREAMABLE_HTTP, httpServer.transportType) + } + + /** + * Test that server ID generation produces unique UUIDs. + */ + @Test + fun generateUniqueServerIds() { + val ids = mutableSetOf() + // Generate 10 IDs to demonstrate uniqueness with reasonable confidence + repeat(10) { + ids.add(McpServer.generateId()) + } + + // All 10 IDs should be unique + assertEquals(10, ids.size) + + // Verify IDs are valid UUID format (lowercase hexadecimal) + ids.forEach { id -> + assertTrue("ID should be a valid UUID format", id.matches(Regex("[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"))) + } + } +} diff --git a/app/src/test/java/com/dark/tool_neuron/service/McpToolMapperTest.kt b/app/src/test/java/com/dark/tool_neuron/service/McpToolMapperTest.kt new file mode 100644 index 00000000..2f975569 --- /dev/null +++ b/app/src/test/java/com/dark/tool_neuron/service/McpToolMapperTest.kt @@ -0,0 +1,38 @@ +package com.dark.tool_neuron.service + +import com.dark.tool_neuron.models.table_schema.McpServer +import com.dark.tool_neuron.models.table_schema.McpTransportType +import org.json.JSONArray +import org.junit.Assert.assertEquals +import org.junit.Assert.assertNotNull +import org.junit.Test + +class McpToolMapperTest { + @Test + fun buildMappingCreatesToolRegistry() { + val server = McpServer( + id = "server-1", + name = "Zapier MCP", + url = "https://example.com/mcp", + transportType = McpTransportType.SSE + ) + val tool = McpToolInfo( + name = "send-email", + description = "Send an email", + inputSchema = """{"type":"object","properties":{"to":{"type":"string"}}}""" + ) + + val mapping = McpToolMapper.buildMapping(mapOf(server to listOf(tool))) + val toolsArray = JSONArray(mapping.toolsJson) + + assertEquals(1, toolsArray.length()) + val function = toolsArray.getJSONObject(0).getJSONObject("function") + assertEquals("zapier_mcp_send_email", function.getString("name")) + assertEquals("object", function.getJSONObject("parameters").getString("type")) + + val reference = mapping.toolRegistry["zapier_mcp_send_email"] + assertNotNull(reference) + assertEquals(server, reference?.server) + assertEquals("send-email", reference?.toolName) + } +} diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index f709a8d7..d3811ee3 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -33,6 +33,7 @@ xz = "1.11" androidx-espresso-core = "3.7.0" androidx-junit = "1.3.0" junit = "4.13.2" +org-json = "20240303" [libraries] androidx-activity-compose = { group = "androidx.activity", name = "activity-compose", version.ref = "androidx-activity-compose" } @@ -77,6 +78,7 @@ xz = { group = "org.tukaani", name = "xz", version.ref = "xz" } androidx-espresso-core = { group = "androidx.test.espresso", name = "espresso-core", version.ref = "androidx-espresso-core" } androidx-junit = { group = "androidx.test.ext", name = "junit", version.ref = "androidx-junit" } junit = { group = "junit", name = "junit", version.ref = "junit" } +org-json = { group = "org.json", name = "json", version.ref = "org-json" } [plugins] android-application = { id = "com.android.application", version.ref = "agp" } @@ -85,4 +87,4 @@ google-dagger-hilt = { id = "com.google.dagger.hilt.android", version.ref = "dag kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" } kotlin-compose = { id = "org.jetbrains.kotlin.plugin.compose", version.ref = "kotlin" } kotlin-ksp = { id = "com.google.devtools.ksp", version.ref = "ksp" } -kotlin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "kotlin" } \ No newline at end of file +kotlin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "kotlin" } diff --git a/memory-vault/build.gradle.kts b/memory-vault/build.gradle.kts index 3dd1a1cd..1bec47f5 100644 --- a/memory-vault/build.gradle.kts +++ b/memory-vault/build.gradle.kts @@ -27,6 +27,10 @@ android { buildConfigField("String", "ALIAS", getProperty("ALIAS")) } + buildFeatures { + buildConfig = true + } + buildTypes { release { isMinifyEnabled = false @@ -71,4 +75,4 @@ fun getProperty(value: String): String { } else { System.getenv(value) ?: "\"sample_val\"" } -} \ No newline at end of file +}