diff --git a/sdk-serde-kotlinx/src/main/kotlin/dev/restate/serde/kotlinx/DefaultJsonSchemaFactory.kt b/sdk-serde-kotlinx/src/main/kotlin/dev/restate/serde/kotlinx/DefaultJsonSchemaFactory.kt index f1b1c88b..a14368ac 100644 --- a/sdk-serde-kotlinx/src/main/kotlin/dev/restate/serde/kotlinx/DefaultJsonSchemaFactory.kt +++ b/sdk-serde-kotlinx/src/main/kotlin/dev/restate/serde/kotlinx/DefaultJsonSchemaFactory.kt @@ -23,7 +23,6 @@ import io.github.smiley4.schemakenerator.jsonschema.jsonDsl.array import io.github.smiley4.schemakenerator.serialization.SerializationSteps.analyzeTypeUsingKotlinxSerialization import io.github.smiley4.schemakenerator.serialization.SerializationSteps.initial import io.github.smiley4.schemakenerator.serialization.SerializationSteps.renameMembers -import kotlin.collections.set import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.KSerializer import kotlinx.serialization.json.Json @@ -67,7 +66,7 @@ object DefaultJsonSchemaFactory : KotlinSerializationSerdeFactory.JsonSchemaFact // Add $schema rootNode.properties.put( "\$schema", - JsonTextValue("https://json-schema.org/draft/2020-12/schema"), + JsonTextValue("https://json-schema.org/draft/2020-12/schema"), ) // Add $defs val definitions = @@ -109,7 +108,7 @@ object DefaultJsonSchemaFactory : KotlinSerializationSerdeFactory.JsonSchemaFact (schema.json as JsonObject).properties["title"] == null ) { (schema.json as JsonObject).properties["title"] = - JsonTextValue(TitleBuilder.BUILDER_SIMPLE(schema.typeData, this.typeDataById)) + JsonTextValue(TitleBuilder.BUILDER_SIMPLE(schema.typeData, this.typeDataById)) } } } @@ -126,8 +125,10 @@ object DefaultJsonSchemaFactory : KotlinSerializationSerdeFactory.JsonSchemaFact private fun JsonObject.fixRefsPrefix(rootDefinition: String) { this.properties.computeIfPresent("\$ref") { key, node -> if (node is JsonTextValue) { - if (node.value.startsWith(rootDefinition)) { - JsonTextValue("#/" + node.value.removePrefix(rootDefinition)) + if (node.value == rootDefinition) { + JsonTextValue("#/") + } else if (node.value.startsWith("$rootDefinition/")) { + JsonTextValue("#/" + node.value.removePrefix("$rootDefinition/")) } else { JsonTextValue("#/\$defs/" + node.value.removePrefix("#/definitions/")) } diff --git a/sdk-serde-kotlinx/src/test/kotlin/dev/restate/serde/kotlinx/KotlinxSerdeTest.kt b/sdk-serde-kotlinx/src/test/kotlin/dev/restate/serde/kotlinx/KotlinxSerdeTest.kt index 4583a389..a5de8386 100644 --- a/sdk-serde-kotlinx/src/test/kotlin/dev/restate/serde/kotlinx/KotlinxSerdeTest.kt +++ b/sdk-serde-kotlinx/src/test/kotlin/dev/restate/serde/kotlinx/KotlinxSerdeTest.kt @@ -176,6 +176,71 @@ class KotlinxSerdeTest { ) } + @Serializable + enum class TaskStatus { + TODO, + IN_PROGRESS, + DONE, + } + + @Serializable + enum class PriorityOrder { + HIGH, + MID, + LOW, + } + + @Serializable + data class Task(val title: String, val status: TaskStatus, val priority: PriorityOrder) + + @Test + fun schemaGenWithExternalEnum() { + testSchemaGen( + $$""" + { + "type": "object", + "required": [ + "title", + "status", + "priority" + ], + "properties": { + "title": { + "type": "string" + }, + "priority": { + "$ref": "#/$defs/PriorityOrder" + }, + "status": { + "$ref": "#/$defs/TaskStatus" + } + }, + "title": "Task", + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$defs": { + "TaskStatus": { + "enum": [ + "TODO", + "IN_PROGRESS", + "DONE" + ], + "title": "TaskStatus" + }, + "PriorityOrder": { + "enum": [ + "HIGH", + "MID", + "LOW" + ], + "title": "PriorityOrder" + } + } + } + """ + .trimIndent() + ) + } + inline fun testSchemaGen(expectedSchema: String) { val expectedJsonElement = Json.decodeFromString(expectedSchema) val actualSchema =