From 67d651ed4a798c5ef9530fb8aa8f98385425161a Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Fri, 5 Jun 2026 15:01:24 -0700 Subject: [PATCH] feat: message and concurrent payload visitors --- temporal-sdk/build.gradle | 29 + .../payload/visitor/GeneratedVisitor.java | 11 + .../payload/visitor/MessageRegistryEntry.java | 18 + .../payload/visitor/MessageVisitor.java | 23 + .../visitor/MessageVisitorOptions.java | 59 ++ .../payload/visitor/MessageVisitors.java | 43 + .../payload/visitor/PayloadVisitor.java | 26 + .../visitor/PayloadVisitorContext.java | 44 + .../visitor/PayloadVisitorOptions.java | 141 +++ .../payload/visitor/PayloadVisitors.java | 54 ++ .../internal/payload/visitor/Traversal.java | 233 +++++ .../payload/visitor/VisitorException.java | 15 + .../visitor/gen/PayloadVisitorGenerator.java | 689 ++++++++++++++ .../payload/visitor/MessageVisitorTest.java | 193 ++++ .../payload/visitor/PayloadVisitorTest.java | 889 ++++++++++++++++++ .../payload/visitor/TestVisitorException.java | 12 + 16 files changed, 2479 insertions(+) create mode 100644 temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/GeneratedVisitor.java create mode 100644 temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageRegistryEntry.java create mode 100644 temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitor.java create mode 100644 temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitorOptions.java create mode 100644 temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitors.java create mode 100644 temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitor.java create mode 100644 temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorContext.java create mode 100644 temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorOptions.java create mode 100644 temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitors.java create mode 100644 temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/Traversal.java create mode 100644 temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/VisitorException.java create mode 100644 temporal-sdk/src/payloadVisitorGenerator/java/io/temporal/internal/payload/visitor/gen/PayloadVisitorGenerator.java create mode 100644 temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/MessageVisitorTest.java create mode 100644 temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/PayloadVisitorTest.java create mode 100644 temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/TestVisitorException.java diff --git a/temporal-sdk/build.gradle b/temporal-sdk/build.gradle index 9e914e31f4..6b54852ed7 100644 --- a/temporal-sdk/build.gradle +++ b/temporal-sdk/build.gradle @@ -65,6 +65,35 @@ dependencies { java21Implementation files(sourceSets.main.output.classesDirs) { builtBy compileJava } } +// --- Payload visitor code generation --- +// A build-time generator (compiled in its own source set against the proto classes from +// temporal-serviceclient) emits GeneratedPayloadVisitor.java, which knows how to walk every +// payload-bearing Temporal API message. The generated source is added to the main source set. +sourceSets { + payloadVisitorGenerator { + java { + srcDirs = ['src/payloadVisitorGenerator/java'] + } + } +} + +dependencies { + payloadVisitorGeneratorImplementation project(':temporal-serviceclient') +} + +def generatedPayloadVisitorDir = layout.buildDirectory.dir('generated/payloadvisitor/java') + +def generatePayloadVisitor = tasks.register('generatePayloadVisitor', JavaExec) { + dependsOn 'compilePayloadVisitorGeneratorJava' + classpath = sourceSets.payloadVisitorGenerator.runtimeClasspath + mainClass = 'io.temporal.internal.payload.visitor.gen.PayloadVisitorGenerator' + args generatedPayloadVisitorDir.get().asFile.absolutePath + inputs.files(sourceSets.payloadVisitorGenerator.runtimeClasspath) + outputs.dir(generatedPayloadVisitorDir) +} + +sourceSets.main.java.srcDir(generatePayloadVisitor) + tasks.named('compileJava17Java') { options.release = 17 } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/GeneratedVisitor.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/GeneratedVisitor.java new file mode 100644 index 0000000000..4e8325ba54 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/GeneratedVisitor.java @@ -0,0 +1,11 @@ +package io.temporal.internal.payload.visitor; + +import com.google.protobuf.Message; + +/** + * Generated traversal for one message type: visits the message's payload fields and recurses into + * its child messages. There is one per message type that can contain a payload. + */ +interface GeneratedVisitor { + void visit(Traversal traversal, Message.Builder builder); +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageRegistryEntry.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageRegistryEntry.java new file mode 100644 index 0000000000..2510878c62 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageRegistryEntry.java @@ -0,0 +1,18 @@ +package io.temporal.internal.payload.visitor; + +import com.google.protobuf.Message; +import java.util.function.Supplier; + +/** + * How to traverse one message type, and how to create an empty builder for it (used to unpack + * {@code google.protobuf.Any} values). + */ +final class MessageRegistryEntry { + final GeneratedVisitor visitor; + final Supplier newBuilder; + + MessageRegistryEntry(GeneratedVisitor visitor, Supplier newBuilder) { + this.visitor = visitor; + this.newBuilder = newBuilder; + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitor.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitor.java new file mode 100644 index 0000000000..4bb6083e3e --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitor.java @@ -0,0 +1,23 @@ +package io.temporal.internal.payload.visitor; + +import com.google.protobuf.MessageOrBuilder; + +/** + * Callback invoked when traversal enters a proto message. The returned value becomes the contextual + * value in scope for that message and everything within it, and is restored to the enclosing value + * once traversal leaves the message. The message is provided as a builder and may be inspected or + * mutated. + * + * @param type of the contextual value + */ +@FunctionalInterface +public interface MessageVisitor { + /** + * Handles a message being entered and returns the contextual value for it and its contents. + * + * @param current the contextual value in scope from the enclosing message + * @param message the message being entered + * @return the contextual value to use for this message and its contents + */ + C onEnter(C current, MessageOrBuilder message); +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitorOptions.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitorOptions.java new file mode 100644 index 0000000000..da3f6eeada --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitorOptions.java @@ -0,0 +1,59 @@ +package io.temporal.internal.payload.visitor; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +/** + * Options for visiting the messages of a proto message, without visiting individual payloads. + * + * @param type of the contextual value supplied to the visitor + */ +public final class MessageVisitorOptions { + private final @Nonnull MessageVisitor messageVisitor; + private final @Nullable C initialContext; + + private MessageVisitorOptions(Builder b) { + this.messageVisitor = b.messageVisitor; + this.initialContext = b.initialContext; + } + + public static Builder newBuilder() { + return new Builder<>(); + } + + @Nonnull + public MessageVisitor getMessageVisitor() { + return messageVisitor; + } + + @Nullable + public C getInitialContext() { + return initialContext; + } + + public static final class Builder { + private MessageVisitor messageVisitor; + private C initialContext; + + private Builder() {} + + /** Required. The message visitor. */ + public Builder setMessageVisitor(@Nonnull MessageVisitor messageVisitor) { + this.messageVisitor = messageVisitor; + return this; + } + + /** Optional. The contextual value in scope before any message is entered. */ + public Builder setInitialContext(@Nullable C initialContext) { + this.initialContext = initialContext; + return this; + } + + public MessageVisitorOptions build() { + if (messageVisitor == null) { + throw new IllegalArgumentException("messageVisitor is required"); + } + return new MessageVisitorOptions<>(this); + } + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitors.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitors.java new file mode 100644 index 0000000000..7361f998d6 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/MessageVisitors.java @@ -0,0 +1,43 @@ +package io.temporal.internal.payload.visitor; + +import com.google.protobuf.Message; +import javax.annotation.Nonnull; + +/** + * Visits the messages within a proto message, invoking the message visitor on each, without + * visiting individual payloads. Only messages that can contain a payload are visited. + * + *

This is an SDK-internal utility; it is not part of the public API. + */ +public final class MessageVisitors { + private MessageVisitors() {} + + /** Visits the messages in {@code builder} in place. */ + public static void visit( + @Nonnull Message.Builder builder, @Nonnull MessageVisitorOptions options) { + Traversal traversal = + new Traversal( + null, + options.getMessageVisitor(), + options.getInitialContext(), + /* skipSearchAttributes= */ false, + /* skipHeaders= */ false, + 1, + null, + GeneratedPayloadVisitor.REGISTRY); + traversal.dispatch(builder); + traversal.execute(); + } + + /** + * Visits the messages in {@code message}, returning a copy with any changes applied; the input is + * unchanged. + */ + @SuppressWarnings("unchecked") + public static T visit( + @Nonnull T message, @Nonnull MessageVisitorOptions options) { + Message.Builder builder = message.toBuilder(); + visit(builder, options); + return (T) builder.build(); + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitor.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitor.java new file mode 100644 index 0000000000..e293f48e5a --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitor.java @@ -0,0 +1,26 @@ +package io.temporal.internal.payload.visitor; + +import io.temporal.api.common.v1.Payload; +import java.util.List; + +/** + * Callback for a sequence of payloads found in a proto message. The returned list replaces those + * payloads; return the same list to leave them unchanged. + * + *

When the visited field holds a single payload the list has one element and the visitor must + * return exactly one payload. With a concurrency limit greater than one, visits may run on multiple + * threads, so implementations must be thread-safe. + * + * @param type of the contextual value supplied to each visit + */ +@FunctionalInterface +public interface PayloadVisitor { + /** + * Visits a sequence of payloads and returns their replacements. + * + * @param context the location of these payloads and the contextual value in scope + * @param payloads the payloads found at this location + * @return the replacement payloads + */ + List visit(PayloadVisitorContext context, List payloads); +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorContext.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorContext.java new file mode 100644 index 0000000000..f65453cf04 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorContext.java @@ -0,0 +1,44 @@ +package io.temporal.internal.payload.visitor; + +import com.google.protobuf.MessageOrBuilder; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +/** + * The context for one payload visitor call: the contextual value in scope and the message that + * contains the payloads being visited. + * + * @param type of the contextual value + */ +public final class PayloadVisitorContext { + private final @Nullable C context; + private final @Nonnull MessageOrBuilder parent; + private final boolean singlePayloadRequired; + + PayloadVisitorContext( + @Nullable C context, @Nonnull MessageOrBuilder parent, boolean singlePayloadRequired) { + this.context = context; + this.parent = parent; + this.singlePayloadRequired = singlePayloadRequired; + } + + /** The contextual value in scope at this location, or {@code null} if none. */ + @Nullable + public C getContext() { + return context; + } + + /** The message that directly contains the payloads being visited. */ + @Nonnull + public MessageOrBuilder getParent() { + return parent; + } + + /** + * Whether the visited field holds a single payload. When {@code true}, the visitor must return + * exactly one payload. + */ + public boolean isSinglePayloadRequired() { + return singlePayloadRequired; + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorOptions.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorOptions.java new file mode 100644 index 0000000000..03c99f513f --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitorOptions.java @@ -0,0 +1,141 @@ +package io.temporal.internal.payload.visitor; + +import java.util.concurrent.Executor; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +/** + * Options for visiting the payloads of a proto message. + * + * @param type of the contextual value supplied to the visitor + */ +public final class PayloadVisitorOptions { + private final @Nonnull PayloadVisitor payloadVisitor; + private final @Nullable MessageVisitor messageVisitor; + private final @Nullable C initialContext; + private final boolean skipSearchAttributes; + private final boolean skipHeaders; + private final int concurrency; + private final @Nullable Executor executor; + + private PayloadVisitorOptions(Builder b) { + this.payloadVisitor = b.payloadVisitor; + this.messageVisitor = b.messageVisitor; + this.initialContext = b.initialContext; + this.skipSearchAttributes = b.skipSearchAttributes; + this.skipHeaders = b.skipHeaders; + this.concurrency = b.concurrency; + this.executor = b.executor; + } + + public static Builder newBuilder() { + return new Builder<>(); + } + + @Nonnull + public PayloadVisitor getPayloadVisitor() { + return payloadVisitor; + } + + @Nullable + public MessageVisitor getMessageVisitor() { + return messageVisitor; + } + + @Nullable + public C getInitialContext() { + return initialContext; + } + + /** Whether search attribute payloads are skipped. */ + public boolean isSkipSearchAttributes() { + return skipSearchAttributes; + } + + /** Whether header payloads are skipped. */ + public boolean isSkipHeaders() { + return skipHeaders; + } + + /** Maximum number of visits that may run concurrently; {@code 1} is sequential. */ + public int getConcurrency() { + return concurrency; + } + + /** Executor for concurrent visits; {@code null} when concurrency is {@code 1}. */ + @Nullable + public Executor getExecutor() { + return executor; + } + + public static final class Builder { + private PayloadVisitor payloadVisitor; + private MessageVisitor messageVisitor; + private C initialContext; + private boolean skipSearchAttributes; + private boolean skipHeaders; + private int concurrency = 1; + private Executor executor; + + private Builder() {} + + /** Required. The payload visitor. */ + public Builder setPayloadVisitor(@Nonnull PayloadVisitor payloadVisitor) { + this.payloadVisitor = payloadVisitor; + return this; + } + + /** Optional. A callback invoked when entering each message. */ + public Builder setMessageVisitor(@Nullable MessageVisitor messageVisitor) { + this.messageVisitor = messageVisitor; + return this; + } + + /** Optional. The contextual value in scope before any message is entered. */ + public Builder setInitialContext(@Nullable C initialContext) { + this.initialContext = initialContext; + return this; + } + + /** Whether to skip search attribute payloads. */ + public Builder setSkipSearchAttributes(boolean skipSearchAttributes) { + this.skipSearchAttributes = skipSearchAttributes; + return this; + } + + /** Whether to skip header payloads. */ + public Builder setSkipHeaders(boolean skipHeaders) { + this.skipHeaders = skipHeaders; + return this; + } + + /** + * Maximum number of concurrent visits; must be at least {@code 1} (the default, sequential). A + * value greater than {@code 1} requires an executor (see {@link #setExecutor}). + */ + public Builder setConcurrency(int concurrency) { + this.concurrency = concurrency; + return this; + } + + /** Executor for concurrent visits. Required when concurrency is greater than {@code 1}. */ + public Builder setExecutor(@Nullable Executor executor) { + this.executor = executor; + return this; + } + + public PayloadVisitorOptions build() { + if (payloadVisitor == null) { + throw new IllegalArgumentException("payloadVisitor is required"); + } + if (concurrency < 1) { + throw new IllegalArgumentException("concurrency must be at least 1, got " + concurrency); + } + if (concurrency > 1 && executor == null) { + throw new IllegalArgumentException( + "executor is required when concurrency is greater than 1"); + } + return new PayloadVisitorOptions<>(this); + } + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitors.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitors.java new file mode 100644 index 0000000000..8cfe80f5a4 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/PayloadVisitors.java @@ -0,0 +1,54 @@ +package io.temporal.internal.payload.visitor; + +import com.google.protobuf.Message; +import javax.annotation.Nonnull; + +/** + * Visits every payload within a proto message. A message with no payloads is returned unchanged. + * + *

This is an SDK-internal utility; it is not part of the public API. + * + *

{@code
+ * RespondWorkflowTaskCompletedRequest result =
+ *     PayloadVisitors.visit(
+ *         request,
+ *         PayloadVisitorOptions.newBuilder()
+ *             .setPayloadVisitor((ctx, payloads) -> encode(ctx.getContext(), payloads))
+ *             .setMessageVisitor((current, msg) -> msg instanceof Command.Builder
+ *                 ? CommandInfo.of((Command.Builder) msg) : current)
+ *             .setConcurrency(4)
+ *             .build());
+ * }
+ */ +public final class PayloadVisitors { + private PayloadVisitors() {} + + /** Visits the payloads in {@code builder} in place. */ + public static void visit( + @Nonnull Message.Builder builder, @Nonnull PayloadVisitorOptions options) { + Traversal traversal = + new Traversal( + options.getPayloadVisitor(), + options.getMessageVisitor(), + options.getInitialContext(), + options.isSkipSearchAttributes(), + options.isSkipHeaders(), + options.getConcurrency(), + options.getExecutor(), + GeneratedPayloadVisitor.REGISTRY); + traversal.dispatch(builder); + traversal.execute(); + } + + /** + * Visits the payloads in {@code message}, returning a copy with replacements applied; the input + * is unchanged. + */ + @SuppressWarnings("unchecked") + public static T visit( + @Nonnull T message, @Nonnull PayloadVisitorOptions options) { + Message.Builder builder = message.toBuilder(); + visit(builder, options); + return (T) builder.build(); + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/Traversal.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/Traversal.java new file mode 100644 index 0000000000..c64ce29ada --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/Traversal.java @@ -0,0 +1,233 @@ +package io.temporal.internal.payload.visitor; + +import com.google.protobuf.Any; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Message; +import com.google.protobuf.MessageOrBuilder; +import io.temporal.api.common.v1.Payload; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.Semaphore; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +/** + * Mutable state for one traversal, called into by the generated per-message visitors. + * + *

A single-threaded walk records a visit job and a write-back for each payload sequence it + * finds; {@link #execute()} then runs the visitor calls (optionally with bounded concurrency) and + * finally applies the write-backs in walk order, so the non-thread-safe builders are never mutated + * concurrently. + */ +final class Traversal { + // The payload visitor is null for a message-only traversal (see MessageVisitors); in that case + // the payload seams are skipped and only the per-message MessageVisitor fires. + private final PayloadVisitor payloadVisitor; + private final MessageVisitor messageVisitor; + private final Map registry; + final boolean skipSearchAttributes; + final boolean skipHeaders; + private final int concurrency; + private final Executor executor; + + private final List jobs = new ArrayList<>(); + private final List writeBacks = new ArrayList<>(); + private Object currentContext; + + @SuppressWarnings("unchecked") + Traversal( + PayloadVisitor payloadVisitor, + MessageVisitor messageVisitor, + Object initialContext, + boolean skipSearchAttributes, + boolean skipHeaders, + int concurrency, + Executor executor, + Map registry) { + this.payloadVisitor = (PayloadVisitor) payloadVisitor; + this.messageVisitor = (MessageVisitor) messageVisitor; + this.currentContext = initialContext; + this.skipSearchAttributes = skipSearchAttributes; + this.skipHeaders = skipHeaders; + this.concurrency = concurrency; + this.executor = executor; + this.registry = registry; + } + + // --- Structural walk: called by generated code --- + + /** Dispatch to the generated visitor for {@code builder}'s type; no-op if it has no payloads. */ + void dispatch(Message.Builder builder) { + MessageRegistryEntry entry = registry.get(builder.getDescriptorForType().getFullName()); + if (entry != null) { + entry.visitor.visit(this, builder); + } + } + + /** + * Run the message visitor for {@code message}, narrowing the scoped context; returns the value to + * restore. + */ + Object enter(MessageOrBuilder message) { + Object previous = currentContext; + if (messageVisitor != null) { + currentContext = messageVisitor.onEnter(previous, message); + } + return previous; + } + + /** Restore the scoped context to {@code previous} when leaving a message's subtree. */ + void exit(Object previous) { + currentContext = previous; + } + + /** Record a visit of a payload sequence ({@code Payloads} or {@code repeated Payload}). */ + void payloads(MessageOrBuilder parent, List batch, Consumer> writeBack) { + if (payloadVisitor == null) { + return; // message-only traversal: payload seams are inert + } + LeafJob job = new LeafJob(batch, currentContext, parent, false); + jobs.add(job); + writeBacks.add(() -> writeBack.accept(job.result)); + } + + /** + * Record a visit of a singular payload field. The visitor must return exactly one payload for + * such a field (enforced in {@link #runJob}), which the consumer writes back. + */ + void singlePayload(MessageOrBuilder parent, Payload value, Consumer writeBack) { + if (payloadVisitor == null) { + return; // message-only traversal: payload seams are inert + } + LeafJob job = new LeafJob(Collections.singletonList(value), currentContext, parent, true); + jobs.add(job); + writeBacks.add(() -> writeBack.accept(job.result.get(0))); + } + + /** Append a deferred write-back, applied (single-threaded) after all visits and in walk order. */ + void deferWriteBack(Runnable writeBack) { + writeBacks.add(writeBack); + } + + /** Unpack a {@code google.protobuf.Any}, traverse its contents, and re-pack it after visits. */ + void any(Any.Builder anyBuilder) { + String typeUrl = anyBuilder.getTypeUrl(); + int slash = typeUrl.lastIndexOf('/'); + String fullName = slash >= 0 ? typeUrl.substring(slash + 1) : typeUrl; + MessageRegistryEntry entry = registry.get(fullName); + if (entry == null) { + // Unknown type, or a type with no payloads; leave the Any untouched. + return; + } + Message.Builder inner = entry.newBuilder.get(); + try { + inner.mergeFrom(anyBuilder.getValue()); + } catch (InvalidProtocolBufferException e) { + throw new VisitorException("failed to unpack Any of type " + fullName, e); + } + entry.visitor.visit(this, inner); + deferWriteBack(() -> anyBuilder.setValue(inner.build().toByteString())); + } + + // --- Execution: visitor calls (phase 2) then write-backs (phase 3) --- + + void execute() { + if (jobs.isEmpty()) { + return; + } + if (concurrency <= 1 || jobs.size() == 1) { + for (LeafJob job : jobs) { + runJob(job); + } + } else { + executeConcurrently(); + } + for (Runnable writeBack : writeBacks) { + writeBack.run(); + } + } + + private void executeConcurrently() { + // executor is non-null here: PayloadVisitorOptions requires it when concurrency > 1, and the + // message-only entry point always uses concurrency 1 (this path is not reached). + Executor pool = executor; + Semaphore semaphore = new Semaphore(concurrency); + AtomicReference firstError = new AtomicReference<>(); + List> futures = new ArrayList<>(jobs.size()); + for (LeafJob job : jobs) { + if (firstError.get() != null) { + break; + } + try { + semaphore.acquire(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + firstError.compareAndSet(null, e); + break; + } + if (firstError.get() != null) { + semaphore.release(); + break; + } + futures.add( + CompletableFuture.runAsync( + () -> { + try { + runJob(job); + } catch (Throwable t) { + firstError.compareAndSet(null, t); + } finally { + semaphore.release(); + } + }, + pool)); + } + CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); + Throwable error = firstError.get(); + if (error instanceof RuntimeException) { + throw (RuntimeException) error; + } + if (error instanceof Error) { + throw (Error) error; + } + if (error != null) { + // The only checked exception that can reach here is an InterruptedException from acquiring + // the semaphore. + throw new VisitorException("payload visit interrupted", error); + } + } + + private void runJob(LeafJob job) { + List result = + payloadVisitor.visit( + new PayloadVisitorContext<>(job.context, job.parent, job.single), job.input); + if (result == null) { + throw new IllegalStateException("payload visitor returned null"); + } + if (job.single && result.size() != 1) { + throw new IllegalStateException( + "single-payload field requires exactly 1 returned payload, got " + result.size()); + } + job.result = result; + } + + /** A single recorded visitor call and the slot its result is written into. */ + private static final class LeafJob { + final List input; + final Object context; + final MessageOrBuilder parent; + final boolean single; + volatile List result; + + LeafJob(List input, Object context, MessageOrBuilder parent, boolean single) { + this.input = input; + this.context = context; + this.parent = parent; + this.single = single; + } + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/VisitorException.java b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/VisitorException.java new file mode 100644 index 0000000000..e8e54efce6 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/payload/visitor/VisitorException.java @@ -0,0 +1,15 @@ +package io.temporal.internal.payload.visitor; + +/** + * Thrown when visiting the payloads or messages of a proto message fails. The original failure, if + * any, is available via {@link #getCause()}. + */ +public final class VisitorException extends RuntimeException { + VisitorException(String message, Throwable cause) { + super(message, cause); + } + + VisitorException(String message) { + super(message); + } +} diff --git a/temporal-sdk/src/payloadVisitorGenerator/java/io/temporal/internal/payload/visitor/gen/PayloadVisitorGenerator.java b/temporal-sdk/src/payloadVisitorGenerator/java/io/temporal/internal/payload/visitor/gen/PayloadVisitorGenerator.java new file mode 100644 index 0000000000..0c036abbd3 --- /dev/null +++ b/temporal-sdk/src/payloadVisitorGenerator/java/io/temporal/internal/payload/visitor/gen/PayloadVisitorGenerator.java @@ -0,0 +1,689 @@ +package io.temporal.internal.payload.visitor.gen; + +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.Descriptors.FileDescriptor; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Deque; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; + +/** + * Build-time generator that emits {@code GeneratedPayloadVisitor}. + * + *

Starting from the WorkflowService and OperatorService file descriptors, it walks the proto + * closure, determines which message types can transitively contain a {@code Payload} (or a {@code + * google.protobuf.Any}, treated conservatively as payload-bearing), and emits one {@code visit_*} + * method per such type plus a registry keyed by descriptor full name. + * + *

Usage: {@code PayloadVisitorGenerator }. + */ +public final class PayloadVisitorGenerator { + + static final String PAYLOAD = "temporal.api.common.v1.Payload"; + static final String PAYLOADS = "temporal.api.common.v1.Payloads"; + static final String ANY = "google.protobuf.Any"; + static final String SEARCH_ATTRIBUTES = "temporal.api.common.v1.SearchAttributes"; + static final String HEADER = "temporal.api.common.v1.Header"; + + static final String OUTPUT_PACKAGE = "io.temporal.internal.payload.visitor"; + static final String OUTPUT_CLASS = "GeneratedPayloadVisitor"; + static final String PAYLOADS_FQN = "io.temporal.api.common.v1.Payloads"; + static final int REGISTER_CHUNK = 40; + + enum Kind { + SINGLE_PAYLOAD, + REPEATED_PAYLOAD, + PAYLOADS_SINGLE, + PAYLOADS_REPEATED, + MAP_PAYLOAD, + MAP_PAYLOADS, + ANY_SINGLE, + ANY_REPEATED, + MAP_ANY, + MESSAGE_SINGLE, + MESSAGE_REPEATED, + MAP_MESSAGE, + IGNORE + } + + /** Classification of a field: how it should be traversed, and its child message type if any. */ + static final class FieldPlan { + final Kind kind; + final Descriptor child; // child message descriptor for MESSAGE_* / MAP_MESSAGE, else null + + FieldPlan(Kind kind, Descriptor child) { + this.kind = kind; + this.child = child; + } + } + + public static void main(String[] args) throws Exception { + if (args.length < 1) { + throw new IllegalArgumentException("usage: PayloadVisitorGenerator "); + } + new PayloadVisitorGenerator().run(Paths.get(args[0])); + } + + private final Map reachesCache = new HashMap<>(); + + void run(Path outputRoot) throws IOException { + List seeds = + Arrays.asList( + io.temporal.api.workflowservice.v1.ServiceProto.getDescriptor(), + io.temporal.api.operatorservice.v1.ServiceProto.getDescriptor()); + + Set files = fileClosure(seeds); + List allMessages = collectMessages(files); + + computeReachability(allMessages); + + // Deterministic output order, keyed by descriptor full name. + Map emitted = new TreeMap<>(); + for (Descriptor d : allMessages) { + if (reaches(d)) { + emitted.put(d.getFullName(), d); + } + } + + verifyAccessors(emitted.values()); + + String source = emit(emitted); + + Path dir = outputRoot; + for (String part : OUTPUT_PACKAGE.split("\\.", -1)) { + dir = dir.resolve(part); + } + Files.createDirectories(dir); + Path out = dir.resolve(OUTPUT_CLASS + ".java"); + Files.write(out, source.getBytes(StandardCharsets.UTF_8)); + System.out.println("PayloadVisitorGenerator: wrote " + emitted.size() + " visitors to " + out); + } + + // --- Descriptor discovery --- + + private Set fileClosure(List seeds) { + Set seen = new LinkedHashSet<>(); + Deque queue = new ArrayDeque<>(seeds); + while (!queue.isEmpty()) { + FileDescriptor f = queue.poll(); + if (seen.add(f)) { + queue.addAll(f.getDependencies()); + } + } + return seen; + } + + private List collectMessages(Set files) { + List result = new ArrayList<>(); + for (FileDescriptor f : files) { + for (Descriptor d : f.getMessageTypes()) { + collectMessages(d, result); + } + } + return result; + } + + private void collectMessages(Descriptor d, List out) { + if (d.getOptions().getMapEntry()) { + return; // synthetic map entry type; handled via the owning map field + } + out.add(d); + for (Descriptor nested : d.getNestedTypes()) { + collectMessages(nested, out); + } + } + + // --- Reachability + classification --- + + /** + * Whether {@code d} can transitively contain a payload; valid after {@link #computeReachability}. + */ + private boolean reaches(Descriptor d) { + return Boolean.TRUE.equals(reachesCache.get(d.getFullName())); + } + + /** + * Least-fixpoint reachability over the message-reference graph. A message reaches a payload if it + * has a direct payload/Any field, or it references (via a message or map-message field) another + * message that does. Iterating to a fixpoint handles cycles (e.g. {@code Failure.cause}) + * correctly without over-approximating payload-free cycles. + */ + private void computeReachability(List all) { + Map> children = new HashMap<>(); + for (Descriptor d : all) { + boolean direct = false; + List refs = new ArrayList<>(); + for (FieldDescriptor f : d.getFields()) { + FieldPlan plan = classify(f); + switch (plan.kind) { + case SINGLE_PAYLOAD: + case REPEATED_PAYLOAD: + case PAYLOADS_SINGLE: + case PAYLOADS_REPEATED: + case MAP_PAYLOAD: + case MAP_PAYLOADS: + case ANY_SINGLE: + case ANY_REPEATED: + case MAP_ANY: + direct = true; + break; + case MESSAGE_SINGLE: + case MESSAGE_REPEATED: + case MAP_MESSAGE: + refs.add(plan.child); + break; + default: + break; + } + } + reachesCache.put(d.getFullName(), direct); + children.put(d.getFullName(), refs); + } + boolean changed = true; + while (changed) { + changed = false; + for (Descriptor d : all) { + if (reachesCache.get(d.getFullName())) { + continue; + } + for (Descriptor c : children.get(d.getFullName())) { + if (Boolean.TRUE.equals(reachesCache.get(c.getFullName()))) { + reachesCache.put(d.getFullName(), true); + changed = true; + break; + } + } + } + } + } + + static FieldPlan classify(FieldDescriptor f) { + if (f.isMapField()) { + FieldDescriptor value = f.getMessageType().findFieldByNumber(2); + if (value.getJavaType() == FieldDescriptor.JavaType.MESSAGE) { + String name = value.getMessageType().getFullName(); + if (PAYLOAD.equals(name)) { + return new FieldPlan(Kind.MAP_PAYLOAD, null); + } + if (PAYLOADS.equals(name)) { + return new FieldPlan(Kind.MAP_PAYLOADS, null); + } + if (ANY.equals(name)) { + return new FieldPlan(Kind.MAP_ANY, null); + } + if (isTemporal(value.getMessageType())) { + return new FieldPlan(Kind.MAP_MESSAGE, value.getMessageType()); + } + return new FieldPlan(Kind.IGNORE, null); + } + return new FieldPlan(Kind.IGNORE, null); + } + if (f.getJavaType() != FieldDescriptor.JavaType.MESSAGE) { + return new FieldPlan(Kind.IGNORE, null); + } + String name = f.getMessageType().getFullName(); + boolean repeated = f.isRepeated(); + if (PAYLOAD.equals(name)) { + return new FieldPlan(repeated ? Kind.REPEATED_PAYLOAD : Kind.SINGLE_PAYLOAD, null); + } + if (PAYLOADS.equals(name)) { + return new FieldPlan(repeated ? Kind.PAYLOADS_REPEATED : Kind.PAYLOADS_SINGLE, null); + } + if (ANY.equals(name)) { + return new FieldPlan(repeated ? Kind.ANY_REPEATED : Kind.ANY_SINGLE, null); + } + if (!isTemporal(f.getMessageType())) { + // Non-Temporal messages (google well-known types, etc.) never carry Temporal payloads + // except inside an Any, which is handled separately. + return new FieldPlan(Kind.IGNORE, null); + } + return new FieldPlan( + repeated ? Kind.MESSAGE_REPEATED : Kind.MESSAGE_SINGLE, f.getMessageType()); + } + + static boolean isTemporal(Descriptor d) { + return d.getFullName().startsWith("temporal."); + } + + // --- Java naming --- + + /** Mirrors protoc's UnderscoresToCamelCase used to derive Java accessor names. */ + static String camel(String input, boolean capNext) { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < input.length(); i++) { + char c = input.charAt(i); + if (c >= 'a' && c <= 'z') { + sb.append(capNext ? Character.toUpperCase(c) : c); + capNext = false; + } else if (c >= 'A' && c <= 'Z') { + if (i == 0 && !capNext) { + sb.append(Character.toLowerCase(c)); + } else { + sb.append(c); + } + capNext = false; + } else if (c >= '0' && c <= '9') { + sb.append(c); + capNext = true; + } else { + capNext = true; + } + } + return sb.toString(); + } + + /** Capitalized accessor base, e.g. {@code schedule_activity} -> {@code ScheduleActivity}. */ + static String base(FieldDescriptor f) { + return camel(f.getName(), true); + } + + static String javaPackage(Descriptor d) { + String pkg = d.getFile().getOptions().getJavaPackage(); + if (pkg == null || pkg.isEmpty()) { + throw new IllegalStateException("message " + d.getFullName() + " has no java_package option"); + } + return pkg; + } + + /** + * Source-form class name, e.g. {@code io.temporal.api.common.v1.Payload.ExternalPayloadDetails}. + */ + static String sourceClassName(Descriptor d) { + Deque names = new ArrayDeque<>(); + for (Descriptor c = d; c != null; c = c.getContainingType()) { + names.addFirst(c.getName()); + } + return javaPackage(d) + "." + String.join(".", names); + } + + /** Binary class name (nested types joined with {@code $}) for reflective verification. */ + static String binaryClassName(Descriptor d) { + Deque names = new ArrayDeque<>(); + for (Descriptor c = d; c != null; c = c.getContainingType()) { + names.addFirst(c.getName()); + } + return javaPackage(d) + "." + String.join("$", names); + } + + static String methodName(String full) { + return "visit_" + full.replace('.', '_'); + } + + // --- Accessor verification (build-time safety net for the naming rules) --- + + private void verifyAccessors(Iterable descriptors) { + for (Descriptor d : descriptors) { + Class builder; + try { + builder = Class.forName(binaryClassName(d) + "$Builder"); + } catch (ClassNotFoundException e) { + throw new IllegalStateException("no builder class for " + d.getFullName(), e); + } + Set methods = new HashSet<>(); + for (java.lang.reflect.Method m : builder.getMethods()) { + methods.add(m.getName()); + } + for (FieldDescriptor f : d.getFields()) { + for (String required : requiredMethods(classify(f).kind, base(f))) { + if (!methods.contains(required)) { + throw new IllegalStateException( + "expected builder method " + + builder.getName() + + "#" + + required + + " for field " + + d.getFullName() + + "." + + f.getName() + + " (" + + classify(f).kind + + ")"); + } + } + } + } + } + + static List requiredMethods(Kind kind, String base) { + switch (kind) { + case SINGLE_PAYLOAD: + return Arrays.asList("has" + base, "get" + base, "set" + base); + case REPEATED_PAYLOAD: + return Arrays.asList("get" + base + "List", "clear" + base, "addAll" + base); + case PAYLOADS_SINGLE: + return Arrays.asList("has" + base, "get" + base, "set" + base); + case PAYLOADS_REPEATED: + return Arrays.asList("get" + base, "get" + base + "Count", "set" + base); + case MAP_PAYLOAD: + return Arrays.asList("get" + base + "Map", "put" + base); + case MAP_PAYLOADS: + case MAP_ANY: + case MAP_MESSAGE: + return Arrays.asList("get" + base + "Map", "put" + base); + case ANY_SINGLE: + case MESSAGE_SINGLE: + return Arrays.asList("has" + base, "get" + base + "Builder"); + case ANY_REPEATED: + case MESSAGE_REPEATED: + return Arrays.asList("get" + base + "BuilderList"); + default: + return Arrays.asList(); + } + } + + // --- Emission --- + + private String emit(Map emitted) { + StringBuilder sb = new StringBuilder(); + sb.append("// Code generated by PayloadVisitorGenerator; DO NOT EDIT.\n"); + sb.append("package ").append(OUTPUT_PACKAGE).append(";\n\n"); + sb.append("import java.util.ArrayList;\n"); + sb.append("import java.util.HashMap;\n"); + sb.append("import java.util.Map;\n\n"); + sb.append("@SuppressWarnings(\"deprecation\")\n"); + sb.append("final class ").append(OUTPUT_CLASS).append(" {\n"); + sb.append(" private ").append(OUTPUT_CLASS).append("() {}\n\n"); + + List list = new ArrayList<>(emitted.values()); + + sb.append(" static final Map REGISTRY = buildRegistry();\n\n"); + sb.append(" private static Map buildRegistry() {\n"); + sb.append(" Map m = new HashMap<>(") + .append(Math.max(16, list.size() * 2)) + .append(");\n"); + int chunks = (list.size() + REGISTER_CHUNK - 1) / REGISTER_CHUNK; + for (int i = 0; i < chunks; i++) { + sb.append(" register").append(i).append("(m);\n"); + } + sb.append(" return m;\n"); + sb.append(" }\n\n"); + + for (int i = 0; i < chunks; i++) { + sb.append(" private static void register") + .append(i) + .append("(Map m) {\n"); + int start = i * REGISTER_CHUNK; + int end = Math.min(start + REGISTER_CHUNK, list.size()); + for (int j = start; j < end; j++) { + Descriptor d = list.get(j); + String src = sourceClassName(d); + String mn = methodName(d.getFullName()); + sb.append(" m.put(\"") + .append(d.getFullName()) + .append("\", new MessageRegistryEntry((t, b) -> ") + .append(mn) + .append("(t, (") + .append(src) + .append(".Builder) b), ") + .append(src) + .append("::newBuilder));\n"); + } + sb.append(" }\n\n"); + } + + for (Descriptor d : list) { + emitVisitMethod(sb, d); + } + + sb.append("}\n"); + return sb.toString(); + } + + private void emitVisitMethod(StringBuilder sb, Descriptor d) { + String src = sourceClassName(d); + sb.append(" static void ") + .append(methodName(d.getFullName())) + .append("(Traversal t, ") + .append(src) + .append(".Builder b) {\n"); + sb.append(" Object __c = t.enter(b);\n"); + int fi = 0; + for (FieldDescriptor f : d.getFields()) { + FieldPlan plan = classify(f); + if (plan.kind == Kind.IGNORE) { + continue; + } + if ((plan.kind == Kind.MESSAGE_SINGLE + || plan.kind == Kind.MESSAGE_REPEATED + || plan.kind == Kind.MAP_MESSAGE) + && !reaches(plan.child)) { + continue; + } + emitField(sb, f, plan, fi++); + } + sb.append(" t.exit(__c);\n"); + sb.append(" }\n\n"); + } + + private void emitField(StringBuilder sb, FieldDescriptor f, FieldPlan plan, int fi) { + String B = base(f); + String k = "__key" + fi; + String v = "__v" + fi; + switch (plan.kind) { + case SINGLE_PAYLOAD: + sb.append(" if (b.has").append(B).append("()) {\n"); + sb.append(" t.singlePayload(b, b.get") + .append(B) + .append("(), p -> b.set") + .append(B) + .append("(p));\n"); + sb.append(" }\n"); + break; + case REPEATED_PAYLOAD: + sb.append(" t.payloads(b, b.get").append(B).append("List(), pl -> {\n"); + sb.append(" b.clear").append(B).append("();\n"); + sb.append(" b.addAll").append(B).append("(pl);\n"); + sb.append(" });\n"); + break; + case PAYLOADS_SINGLE: + sb.append(" if (b.has").append(B).append("()) {\n"); + sb.append(" t.payloads(b, b.get").append(B).append("().getPayloadsList(),\n"); + sb.append(" pl -> b.set") + .append(B) + .append("(") + .append(PAYLOADS_FQN) + .append(".newBuilder().addAllPayloads(pl).build()));\n"); + sb.append(" }\n"); + break; + case PAYLOADS_REPEATED: + sb.append(" for (int ") + .append(v) + .append(" = 0; ") + .append(v) + .append(" < b.get") + .append(B) + .append("Count(); ") + .append(v) + .append("++) {\n"); + sb.append(" final int ").append(k).append(" = ").append(v).append(";\n"); + sb.append(" t.payloads(b, b.get") + .append(B) + .append("(") + .append(k) + .append(").getPayloadsList(),\n"); + sb.append(" pl -> b.set") + .append(B) + .append("(") + .append(k) + .append(", ") + .append(PAYLOADS_FQN) + .append(".newBuilder().addAllPayloads(pl).build()));\n"); + sb.append(" }\n"); + break; + case MAP_PAYLOAD: + sb.append(" for (String ") + .append(k) + .append(" : new ArrayList<>(b.get") + .append(B) + .append("Map().keySet())) {\n"); + sb.append(" final String ").append(v).append(" = ").append(k).append(";\n"); + sb.append(" t.singlePayload(b, b.get") + .append(B) + .append("Map().get(") + .append(v) + .append("), p -> b.put") + .append(B) + .append("(") + .append(v) + .append(", p));\n"); + sb.append(" }\n"); + break; + case MAP_PAYLOADS: + sb.append(" for (String ") + .append(k) + .append(" : new ArrayList<>(b.get") + .append(B) + .append("Map().keySet())) {\n"); + sb.append(" final String ").append(v).append(" = ").append(k).append(";\n"); + sb.append(" t.payloads(b, b.get") + .append(B) + .append("Map().get(") + .append(v) + .append(").getPayloadsList(),\n"); + sb.append(" pl -> b.put") + .append(B) + .append("(") + .append(v) + .append(", ") + .append(PAYLOADS_FQN) + .append(".newBuilder().addAllPayloads(pl).build()));\n"); + sb.append(" }\n"); + break; + case ANY_SINGLE: + sb.append(" if (b.has").append(B).append("()) {\n"); + sb.append(" t.any(b.get").append(B).append("Builder());\n"); + sb.append(" }\n"); + break; + case ANY_REPEATED: + sb.append(" for (com.google.protobuf.Any.Builder ") + .append(v) + .append(" : b.get") + .append(B) + .append("BuilderList()) {\n"); + sb.append(" t.any(").append(v).append(");\n"); + sb.append(" }\n"); + break; + case MAP_ANY: + sb.append(" for (String ") + .append(k) + .append(" : new ArrayList<>(b.get") + .append(B) + .append("Map().keySet())) {\n"); + sb.append(" final String ").append(v).append(" = ").append(k).append(";\n"); + sb.append(" com.google.protobuf.Any.Builder ab") + .append(fi) + .append(" = b.get") + .append(B) + .append("Map().get(") + .append(v) + .append(").toBuilder();\n"); + sb.append(" t.any(ab").append(fi).append(");\n"); + sb.append(" t.deferWriteBack(() -> b.put") + .append(B) + .append("(") + .append(v) + .append(", ab") + .append(fi) + .append(".build()));\n"); + sb.append(" }\n"); + break; + case MESSAGE_SINGLE: + { + String guard = childGuard(plan.child); + sb.append(" if (").append(guard).append("b.has").append(B).append("()) {\n"); + sb.append(" ") + .append(methodName(plan.child.getFullName())) + .append("(t, b.get") + .append(B) + .append("Builder());\n"); + sb.append(" }\n"); + } + break; + case MESSAGE_REPEATED: + { + String childSrc = sourceClassName(plan.child); + String guard = childGuard(plan.child); + if (!guard.isEmpty()) { + sb.append(" if (").append(guard.substring(0, guard.length() - 4)).append(") {\n "); + } + sb.append(" for (") + .append(childSrc) + .append(".Builder ") + .append(v) + .append(" : b.get") + .append(B) + .append("BuilderList()) {\n"); + sb.append(" ") + .append(methodName(plan.child.getFullName())) + .append("(t, ") + .append(v) + .append(");\n"); + sb.append(" }\n"); + if (!guard.isEmpty()) { + sb.append(" }\n"); + } + } + break; + case MAP_MESSAGE: + { + String childSrc = sourceClassName(plan.child); + sb.append(" for (String ") + .append(k) + .append(" : new ArrayList<>(b.get") + .append(B) + .append("Map().keySet())) {\n"); + sb.append(" final String ").append(v).append(" = ").append(k).append(";\n"); + sb.append(" ") + .append(childSrc) + .append(".Builder vb") + .append(fi) + .append(" = b.get") + .append(B) + .append("Map().get(") + .append(v) + .append(").toBuilder();\n"); + sb.append(" ") + .append(methodName(plan.child.getFullName())) + .append("(t, vb") + .append(fi) + .append(");\n"); + sb.append(" t.deferWriteBack(() -> b.put") + .append(B) + .append("(") + .append(v) + .append(", vb") + .append(fi) + .append(".build()));\n"); + sb.append(" }\n"); + } + break; + default: + break; + } + } + + /** Optional {@code &&}-terminated guard expression for SearchAttributes/Header skipping. */ + private String childGuard(Descriptor child) { + String name = child.getFullName(); + if (SEARCH_ATTRIBUTES.equals(name)) { + return "!t.skipSearchAttributes && "; + } + if (HEADER.equals(name)) { + return "!t.skipHeaders && "; + } + return ""; + } +} diff --git a/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/MessageVisitorTest.java b/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/MessageVisitorTest.java new file mode 100644 index 0000000000..a33fecb725 --- /dev/null +++ b/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/MessageVisitorTest.java @@ -0,0 +1,193 @@ +package io.temporal.internal.payload.visitor; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import com.google.protobuf.ByteString; +import io.temporal.api.command.v1.Command; +import io.temporal.api.command.v1.CompleteWorkflowExecutionCommandAttributes; +import io.temporal.api.command.v1.RecordMarkerCommandAttributes; +import io.temporal.api.command.v1.ScheduleActivityTaskCommandAttributes; +import io.temporal.api.common.v1.Memo; +import io.temporal.api.common.v1.Payload; +import io.temporal.api.common.v1.Payloads; +import io.temporal.api.enums.v1.CommandType; +import io.temporal.api.workflowservice.v1.RespondWorkflowTaskCompletedRequest; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.junit.Test; + +/** Tests for {@link MessageVisitors}: message traversal with scoped context, and validation. */ +public class MessageVisitorTest { + + static Payload p(String s) { + return Payload.newBuilder().setData(ByteString.copyFromUtf8(s)).build(); + } + + static Command activity(String id, String... inputs) { + Payloads.Builder in = Payloads.newBuilder(); + for (String s : inputs) { + in.addPayloads(p(s)); + } + return Command.newBuilder() + .setCommandType(CommandType.COMMAND_TYPE_SCHEDULE_ACTIVITY_TASK) + .setScheduleActivityTaskCommandAttributes( + ScheduleActivityTaskCommandAttributes.newBuilder().setActivityId(id).setInput(in)) + .build(); + } + + @Test + public void visitsEachMessageWithScopedContext() { + // Three commands with distinct types exercise per-command scoping and scope restoration + // between siblings. + RespondWorkflowTaskCompletedRequest request = + RespondWorkflowTaskCompletedRequest.newBuilder() + .addCommands(activity("a", "x")) + .addCommands( + Command.newBuilder() + .setCommandType(CommandType.COMMAND_TYPE_COMPLETE_WORKFLOW_EXECUTION) + .setCompleteWorkflowExecutionCommandAttributes( + CompleteWorkflowExecutionCommandAttributes.newBuilder()) + .build()) + .addCommands( + Command.newBuilder() + .setCommandType(CommandType.COMMAND_TYPE_RECORD_MARKER) + .setRecordMarkerCommandAttributes( + RecordMarkerCommandAttributes.newBuilder().setMarkerName("m")) + .build()) + .build(); + + // MessageVisitors traversal is single-threaded, so the entered messages have a stable order. + List entered = new ArrayList<>(); + List contextOnEnter = new ArrayList<>(); + + MessageVisitorOptions opts = + MessageVisitorOptions.newBuilder() + .setMessageVisitor( + (current, msg) -> { + entered.add(msg.getDescriptorForType().getFullName()); + contextOnEnter.add(current); + return msg instanceof Command.Builder + ? ((Command.Builder) msg).getCommandType() + : current; + }) + .build(); + + MessageVisitors.visit(request, opts); + + // Exact order: the root, then each (repeated) command followed by its oneof attributes message. + assertEquals( + Arrays.asList( + "temporal.api.workflowservice.v1.RespondWorkflowTaskCompletedRequest", + "temporal.api.command.v1.Command", + "temporal.api.command.v1.ScheduleActivityTaskCommandAttributes", + "temporal.api.command.v1.Command", + "temporal.api.command.v1.CompleteWorkflowExecutionCommandAttributes", + "temporal.api.command.v1.Command", + "temporal.api.command.v1.RecordMarkerCommandAttributes"), + entered); + // Each command is entered with scope reset to null (restored between siblings), then its own + // type flows down into its attributes message. + assertEquals( + Arrays.asList( + null, + null, + CommandType.COMMAND_TYPE_SCHEDULE_ACTIVITY_TASK, + null, + CommandType.COMMAND_TYPE_COMPLETE_WORKFLOW_EXECUTION, + null, + CommandType.COMMAND_TYPE_RECORD_MARKER), + contextOnEnter); + } + + @Test + public void messageOnlyVisitorValidatesPerMessageType() { + // A message-only validator that mirrors a server-side limit: at most two memo fields. + int maxMemoFields = 2; + MessageVisitorOptions opts = + MessageVisitorOptions.newBuilder() + .setMessageVisitor( + (current, msg) -> { + if (msg instanceof Memo.Builder + && ((Memo.Builder) msg).getFieldsCount() > maxMemoFields) { + throw new TestVisitorException("too many memo fields"); + } + return current; + }) + .build(); + + Memo ok = Memo.newBuilder().putFields("a", p("1")).putFields("b", p("2")).build(); + MessageVisitors.visit(ok, opts); // no throw + + Memo tooMany = + Memo.newBuilder() + .putFields("a", p("1")) + .putFields("b", p("2")) + .putFields("c", p("3")) + .build(); + assertThrows(TestVisitorException.class, () -> MessageVisitors.visit(tooMany, opts)); + } + + @Test + public void visitsBuilderInPlace() { + Memo.Builder builder = Memo.newBuilder().putFields("k", p("v")); + List entered = new ArrayList<>(); + MessageVisitors.visit( + builder, + MessageVisitorOptions.newBuilder() + .setMessageVisitor( + (current, msg) -> { + entered.add(msg.getDescriptorForType().getFullName()); + return current; + }) + .build()); + assertEquals(Arrays.asList("temporal.api.common.v1.Memo"), entered); + } + + @Test + public void messageVisitorMutatesInPlace() { + RespondWorkflowTaskCompletedRequest request = + RespondWorkflowTaskCompletedRequest.newBuilder().addCommands(activity("orig", "x")).build(); + + RespondWorkflowTaskCompletedRequest result = + MessageVisitors.visit( + request, + MessageVisitorOptions.newBuilder() + .setMessageVisitor( + (current, msg) -> { + if (msg instanceof ScheduleActivityTaskCommandAttributes.Builder) { + ((ScheduleActivityTaskCommandAttributes.Builder) msg) + .setActivityId("rewritten"); + } + return current; + }) + .build()); + + assertEquals( + "rewritten", + result.getCommands(0).getScheduleActivityTaskCommandAttributes().getActivityId()); + } + + @Test + public void initialContextObservedAtRoot() { + Memo memo = Memo.newBuilder().putFields("k", p("v")).build(); + List observed = new ArrayList<>(); + MessageVisitors.visit( + memo, + MessageVisitorOptions.newBuilder() + .setInitialContext("root") + .setMessageVisitor( + (current, msg) -> { + observed.add(current); + return current; + }) + .build()); + assertEquals(Arrays.asList("root"), observed); + } + + @Test + public void rejectsMissingMessageVisitor() { + assertThrows(IllegalArgumentException.class, () -> MessageVisitorOptions.newBuilder().build()); + } +} diff --git a/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/PayloadVisitorTest.java b/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/PayloadVisitorTest.java new file mode 100644 index 0000000000..810f6c9a73 --- /dev/null +++ b/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/PayloadVisitorTest.java @@ -0,0 +1,889 @@ +package io.temporal.internal.payload.visitor; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import com.google.protobuf.Any; +import com.google.protobuf.ByteString; +import io.temporal.api.command.v1.Command; +import io.temporal.api.command.v1.RecordMarkerCommandAttributes; +import io.temporal.api.command.v1.ScheduleActivityTaskCommandAttributes; +import io.temporal.api.command.v1.ScheduleNexusOperationCommandAttributes; +import io.temporal.api.command.v1.StartChildWorkflowExecutionCommandAttributes; +import io.temporal.api.command.v1.UpsertWorkflowSearchAttributesCommandAttributes; +import io.temporal.api.common.v1.Header; +import io.temporal.api.common.v1.Memo; +import io.temporal.api.common.v1.Payload; +import io.temporal.api.common.v1.Payloads; +import io.temporal.api.common.v1.SearchAttributes; +import io.temporal.api.enums.v1.CommandType; +import io.temporal.api.failure.v1.ApplicationFailureInfo; +import io.temporal.api.failure.v1.Failure; +import io.temporal.api.protocol.v1.Message; +import io.temporal.api.query.v1.WorkflowQueryResult; +import io.temporal.api.workflowservice.v1.CountWorkflowExecutionsResponse; +import io.temporal.api.workflowservice.v1.RespondWorkflowTaskCompletedRequest; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.BrokenBarrierException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class PayloadVisitorTest { + + static Payload p(String s) { + return Payload.newBuilder().setData(ByteString.copyFromUtf8(s)).build(); + } + + static String data(Payload p) { + return p.getData().toStringUtf8(); + } + + static Payloads payloads(String... values) { + Payloads.Builder b = Payloads.newBuilder(); + for (String v : values) { + b.addPayloads(p(v)); + } + return b.build(); + } + + /** + * Records every payload seen (in order) and the number of visit calls, leaving payloads + * unchanged. + */ + static final class CollectingVisitor implements PayloadVisitor { + final List seen = Collections.synchronizedList(new ArrayList<>()); + final AtomicInteger visits = new AtomicInteger(); + + @Override + public List visit(PayloadVisitorContext ctx, List payloads) { + visits.incrementAndGet(); + for (Payload p : payloads) { + seen.add(data(p)); + } + return payloads; + } + } + + static PayloadVisitorOptions options(PayloadVisitor visitor) { + return PayloadVisitorOptions.newBuilder().setPayloadVisitor(visitor).build(); + } + + static Command scheduleActivity(String activityId, Payloads input) { + return Command.newBuilder() + .setCommandType(CommandType.COMMAND_TYPE_SCHEDULE_ACTIVITY_TASK) + .setScheduleActivityTaskCommandAttributes( + ScheduleActivityTaskCommandAttributes.newBuilder() + .setActivityId(activityId) + .setInput(input)) + .build(); + } + + @Test + public void visitsAndMutatesAllPayloads() { + RespondWorkflowTaskCompletedRequest request = + RespondWorkflowTaskCompletedRequest.newBuilder() + .addCommands(scheduleActivity("a", payloads("one", "two"))) + .addCommands(scheduleActivity("b", payloads("three"))) + .build(); + + CollectingVisitor counter = new CollectingVisitor(); + RespondWorkflowTaskCompletedRequest unchanged = + PayloadVisitors.visit(request, options(counter)); + assertEquals(java.util.Arrays.asList("one", "two", "three"), counter.seen); + // Two Payloads sequences (one per command's input): two visits, three payloads. + assertEquals(2, counter.visits.get()); + assertEquals(request, unchanged); + + // Mutating: uppercase every payload's data. + RespondWorkflowTaskCompletedRequest mutated = + PayloadVisitors.visit( + request, + options( + (ctx, pls) -> + pls.stream() + .map( + p -> + p.toBuilder() + .setData(ByteString.copyFromUtf8(data(p).toUpperCase())) + .build()) + .collect(Collectors.toList()))); + assertEquals( + payloads("ONE", "TWO"), + mutated.getCommands(0).getScheduleActivityTaskCommandAttributes().getInput()); + assertEquals( + payloads("THREE"), + mutated.getCommands(1).getScheduleActivityTaskCommandAttributes().getInput()); + } + + @Test + public void visitsSinglePayloadField() { + Command command = + Command.newBuilder() + .setScheduleNexusOperationCommandAttributes( + ScheduleNexusOperationCommandAttributes.newBuilder().setInput(p("nexus"))) + .build(); + + CollectingVisitor counter = new CollectingVisitor(); + Command result = PayloadVisitors.visit(command, options(counter)); + assertEquals(Collections.singletonList("nexus"), counter.seen); + + // singlePayloadRequired must be set for this field. + Command observed = + PayloadVisitors.visit( + command, + options( + (ctx, pls) -> { + assertTrue(ctx.isSinglePayloadRequired()); + return Collections.singletonList(p("replaced")); + })); + assertEquals( + "replaced", + observed.getScheduleNexusOperationCommandAttributes().getInput().getData().toStringUtf8()); + assertEquals(Collections.singletonList("nexus"), counter.seen); + assertEquals(command, result); + } + + @Test + public void singlePayloadFieldRequiresExactlyOnePayload() { + Command command = + Command.newBuilder() + .setScheduleNexusOperationCommandAttributes( + ScheduleNexusOperationCommandAttributes.newBuilder().setInput(p("nexus"))) + .build(); + + // Returning zero payloads for a single-payload field is rejected. + assertThrows( + IllegalStateException.class, + () -> PayloadVisitors.visit(command, options((ctx, pls) -> Collections.emptyList()))); + + // Returning more than one payload for a single-payload field is rejected. + assertThrows( + IllegalStateException.class, + () -> + PayloadVisitors.visit( + command, options((ctx, pls) -> java.util.Arrays.asList(p("a"), p("b"))))); + } + + @Test + public void visitsMapOfPayloads() { + Command command = + Command.newBuilder() + .setUpsertWorkflowSearchAttributesCommandAttributes( + UpsertWorkflowSearchAttributesCommandAttributes.newBuilder() + .setSearchAttributes( + SearchAttributes.newBuilder() + .putIndexedFields("k1", p("v1")) + .putIndexedFields("k2", p("v2")))) + .build(); + + CollectingVisitor counter = new CollectingVisitor(); + PayloadVisitors.visit(command, options(counter)); + // A map is visited once per entry; map iteration order is unspecified, so + // assert the exact visit count and the value set rather than positional offsets. + assertEquals(2, counter.visits.get()); + assertEquals(new HashSet<>(java.util.Arrays.asList("v1", "v2")), new HashSet<>(counter.seen)); + + Command mutated = + PayloadVisitors.visit( + command, options((ctx, pls) -> Collections.singletonList(p(data(pls.get(0)) + "!")))); + Map fields = + mutated + .getUpsertWorkflowSearchAttributesCommandAttributes() + .getSearchAttributes() + .getIndexedFieldsMap(); + assertEquals("v1!", data(fields.get("k1"))); + assertEquals("v2!", data(fields.get("k2"))); + } + + @Test + public void visitsMapOfPayloadsSequences() { + Command command = + Command.newBuilder() + .setRecordMarkerCommandAttributes( + RecordMarkerCommandAttributes.newBuilder() + .setMarkerName("m") + .putDetails("d1", payloads("x", "y"))) + .build(); + + CollectingVisitor counter = new CollectingVisitor(); + PayloadVisitors.visit(command, options(counter)); + // A single map entry is one sequence: one visit, two payloads. + assertEquals(1, counter.visits.get()); + assertEquals(java.util.Arrays.asList("x", "y"), counter.seen); + } + + @Test + public void visitsMapOfMessages() { + // RespondWorkflowTaskCompletedRequest.query_results is map, whose + // values carry payloads: exercises the map-of-messages path (rebuild value + write back). + RespondWorkflowTaskCompletedRequest request = + RespondWorkflowTaskCompletedRequest.newBuilder() + .putQueryResults( + "q1", WorkflowQueryResult.newBuilder().setAnswer(payloads("a")).build()) + .build(); + + CollectingVisitor counter = new CollectingVisitor(); + PayloadVisitors.visit(request, options(counter)); + assertEquals(Collections.singletonList("a"), counter.seen); + + RespondWorkflowTaskCompletedRequest mutated = + PayloadVisitors.visit( + request, options((ctx, pls) -> Collections.singletonList(p(data(pls.get(0)) + "!")))); + assertEquals(payloads("a!"), mutated.getQueryResultsMap().get("q1").getAnswer()); + } + + @Test + public void visitsRepeatedPayloadField() { + // CountWorkflowExecutionsResponse.AggregationGroup.group_values is a bare repeated Payload. + CountWorkflowExecutionsResponse response = + CountWorkflowExecutionsResponse.newBuilder() + .addGroups( + CountWorkflowExecutionsResponse.AggregationGroup.newBuilder() + .addGroupValues(p("g1")) + .addGroupValues(p("g2"))) + .build(); + + CollectingVisitor counter = new CollectingVisitor(); + PayloadVisitors.visit(response, options(counter)); + // A repeated Payload is one sequence: one visit, two payloads. + assertEquals(1, counter.visits.get()); + assertEquals(java.util.Arrays.asList("g1", "g2"), counter.seen); + + CountWorkflowExecutionsResponse mutated = + PayloadVisitors.visit( + response, + options( + (ctx, pls) -> + pls.stream().map(pl -> p(data(pl) + "!")).collect(Collectors.toList()))); + assertEquals("g1!", data(mutated.getGroups(0).getGroupValues(0))); + assertEquals("g2!", data(mutated.getGroups(0).getGroupValues(1))); + } + + @Test + public void visitsPayloadsAsRoot() { + Payloads root = payloads("a", "b"); + + CollectingVisitor counter = new CollectingVisitor(); + Payloads unchanged = PayloadVisitors.visit(root, options(counter)); + // The repeated Payload inside Payloads is one sequence: one visit, two payloads. + assertEquals(1, counter.visits.get()); + assertEquals(java.util.Arrays.asList("a", "b"), counter.seen); + assertEquals(root, unchanged); + + Payloads mutated = + PayloadVisitors.visit(root, options((ctx, pls) -> Collections.singletonList(p("x")))); + assertEquals(payloads("x"), mutated); + } + + @Test + public void visitsBuilderInPlace() { + RespondWorkflowTaskCompletedRequest.Builder builder = + RespondWorkflowTaskCompletedRequest.newBuilder() + .addCommands(scheduleActivity("a", payloads("x"))); + + PayloadVisitors.visit(builder, options((ctx, pls) -> Collections.singletonList(p("y")))); + + assertEquals( + payloads("y"), + builder.getCommands(0).getScheduleActivityTaskCommandAttributes().getInput()); + } + + @Test + public void visitCountDistinguishesSequencesFromMapEntries() { + // A Memo with two fields is visited once per entry: two visits, two payloads. + Memo memo = Memo.newBuilder().putFields("a", p("1")).putFields("b", p("2")).build(); + CollectingVisitor memoVisitor = new CollectingVisitor(); + PayloadVisitors.visit(memo, options(memoVisitor)); + // Memo fields are a map (unspecified order): assert visit count and the value set. + assertEquals(2, memoVisitor.visits.get()); + assertEquals(new HashSet<>(java.util.Arrays.asList("1", "2")), new HashSet<>(memoVisitor.seen)); + + // An activity command with two inputs is one Payloads sequence: one visit, two payloads. + Command command = + Command.newBuilder() + .setScheduleActivityTaskCommandAttributes( + ScheduleActivityTaskCommandAttributes.newBuilder().setInput(payloads("1", "2"))) + .build(); + CollectingVisitor inputVisitor = new CollectingVisitor(); + PayloadVisitors.visit(command, options(inputVisitor)); + // A Payloads sequence preserves order, so assert the exact ordered values. + assertEquals(1, inputVisitor.visits.get()); + assertEquals(java.util.Arrays.asList("1", "2"), inputVisitor.seen); + } + + @Test + public void skipsSearchAttributesAndHeaders() { + Command command = + Command.newBuilder() + .setScheduleActivityTaskCommandAttributes( + ScheduleActivityTaskCommandAttributes.newBuilder() + .setInput(payloads("in")) + .setHeader(Header.newBuilder().putFields("h", p("hv")))) + .build(); + + CollectingVisitor counter = new CollectingVisitor(); + PayloadVisitors.visit( + command, + PayloadVisitorOptions.newBuilder().setPayloadVisitor(counter).setSkipHeaders(true).build()); + assertEquals(Collections.singletonList("in"), counter.seen); + + SearchAttributes sa = SearchAttributes.newBuilder().putIndexedFields("k", p("v")).build(); + Command withSa = + Command.newBuilder() + .setUpsertWorkflowSearchAttributesCommandAttributes( + UpsertWorkflowSearchAttributesCommandAttributes.newBuilder() + .setSearchAttributes(sa)) + .build(); + CollectingVisitor counter2 = new CollectingVisitor(); + PayloadVisitors.visit( + withSa, + PayloadVisitorOptions.newBuilder() + .setPayloadVisitor(counter2) + .setSkipSearchAttributes(true) + .build()); + assertTrue(counter2.seen.isEmpty()); + } + + @Test + public void contextHookScopesPerCommand() { + RespondWorkflowTaskCompletedRequest request = + RespondWorkflowTaskCompletedRequest.newBuilder() + .addCommands(scheduleActivity("a", payloads("act"))) + .addCommands( + Command.newBuilder() + .setCommandType(CommandType.COMMAND_TYPE_START_CHILD_WORKFLOW_EXECUTION) + .setStartChildWorkflowExecutionCommandAttributes( + StartChildWorkflowExecutionCommandAttributes.newBuilder() + .setInput(payloads("child"))) + .build()) + .build(); + + // Default concurrency (1) visits the two repeated commands in declaration order. + List dataOrder = new ArrayList<>(); + List contextOrder = new ArrayList<>(); + PayloadVisitorOptions opts = + PayloadVisitorOptions.newBuilder() + .setPayloadVisitor( + (ctx, pls) -> { + for (Payload p : pls) { + dataOrder.add(data(p)); + contextOrder.add(ctx.getContext()); + } + return pls; + }) + .setMessageVisitor( + (current, msg) -> + msg instanceof Command.Builder + ? ((Command.Builder) msg).getCommandType() + : current) + .build(); + + PayloadVisitors.visit(request, opts); + assertEquals(java.util.Arrays.asList("act", "child"), dataOrder); + assertEquals( + java.util.Arrays.asList( + CommandType.COMMAND_TYPE_SCHEDULE_ACTIVITY_TASK, + CommandType.COMMAND_TYPE_START_CHILD_WORKFLOW_EXECUTION), + contextOrder); + } + + @Test + public void initialContextUsedOutsideAnyHookScope() { + Command command = + Command.newBuilder() + .setScheduleActivityTaskCommandAttributes( + ScheduleActivityTaskCommandAttributes.newBuilder().setInput(payloads("x"))) + .build(); + List observed = new ArrayList<>(); + PayloadVisitorOptions opts = + PayloadVisitorOptions.newBuilder() + .setInitialContext("root") + .setPayloadVisitor( + (ctx, pls) -> { + observed.add(ctx.getContext()); + return pls; + }) + .build(); + PayloadVisitors.visit(command, opts); + assertEquals(Collections.singletonList("root"), observed); + } + + @Test + public void limitsStyleValidatorComposesBothSeams() { + // The payload-limits feature is a read-only validator using both seams of PayloadVisitors: + // - per-payload (blob size) on the payload seam + // - per-message (e.g. memo field count) on the message seam + int blobLimit = 8; + int maxMemoFields = 2; + + PayloadVisitorOptions validator = + PayloadVisitorOptions.newBuilder() + .setPayloadVisitor( + (ctx, pls) -> { + for (Payload pl : pls) { + if (pl.getData().size() > blobLimit) { + throw new TestVisitorException("blob too large"); + } + } + return pls; // read-only + }) + .setMessageVisitor( + (current, msg) -> { + if (msg instanceof Memo.Builder + && ((Memo.Builder) msg).getFieldsCount() > maxMemoFields) { + throw new TestVisitorException("too many memo fields"); + } + return current; + }) + .build(); + + // Within limits: passes. + RespondWorkflowTaskCompletedRequest ok = + RespondWorkflowTaskCompletedRequest.newBuilder() + .addCommands(scheduleActivity("a", payloads("small"))) + .build(); + PayloadVisitors.visit(ok, validator); + + // Oversized blob trips the payload seam. + RespondWorkflowTaskCompletedRequest bigBlob = + RespondWorkflowTaskCompletedRequest.newBuilder() + .addCommands(scheduleActivity("a", payloads("way-too-large-payload"))) + .build(); + assertThrows(TestVisitorException.class, () -> PayloadVisitors.visit(bigBlob, validator)); + } + + @Test + public void traversesCyclicFailureChain() { + Failure failure = + Failure.newBuilder() + .setMessage("outer") + .setApplicationFailureInfo( + ApplicationFailureInfo.newBuilder().setDetails(payloads("d1"))) + .setCause( + Failure.newBuilder() + .setMessage("inner") + .setApplicationFailureInfo( + ApplicationFailureInfo.newBuilder().setDetails(payloads("d2")))) + .build(); + + CollectingVisitor counter = new CollectingVisitor(); + PayloadVisitors.visit(failure, options(counter)); + assertEquals(2, counter.seen.size()); + assertTrue(counter.seen.contains("d1")); + assertTrue(counter.seen.contains("d2")); + } + + @Test + public void roundTripsPayloadInsideAny() throws Exception { + Memo memo = Memo.newBuilder().putFields("k", p("inside-any")).build(); + Message message = Message.newBuilder().setBody(Any.pack(memo)).build(); + + CollectingVisitor counter = new CollectingVisitor(); + Message result = PayloadVisitors.visit(message, options(counter)); + assertEquals(Collections.singletonList("inside-any"), counter.seen); + + // Mutating through the Any re-packs correctly. + Message mutated = + PayloadVisitors.visit( + message, options((ctx, pls) -> Collections.singletonList(p("changed")))); + Memo unpacked = mutated.getBody().unpack(Memo.class); + assertEquals("changed", data(unpacked.getFieldsMap().get("k"))); + // Unrelated content unchanged. + assertEquals(result.getBody().getTypeUrl(), mutated.getBody().getTypeUrl()); + } + + @Test + public void leavesUnknownAnyUntouched() throws Exception { + // An Any whose type is not in the registry is left as-is. + Message message = + Message.newBuilder() + .setBody( + Any.newBuilder() + .setTypeUrl("type.googleapis.com/some.unknown.Type") + .setValue(ByteString.copyFromUtf8("opaque"))) + .build(); + CollectingVisitor counter = new CollectingVisitor(); + Message result = PayloadVisitors.visit(message, options(counter)); + assertTrue(counter.seen.isEmpty()); + assertEquals(message, result); + } + + @Test + public void messageWithoutPayloadsReturnedUnchanged() { + Command command = + Command.newBuilder() + .setCancelWorkflowExecutionCommandAttributes( + io.temporal.api.command.v1.CancelWorkflowExecutionCommandAttributes.newBuilder()) + .build(); + CollectingVisitor counter = new CollectingVisitor(); + Command result = PayloadVisitors.visit(command, options(counter)); + assertTrue(counter.seen.isEmpty()); + assertEquals(command, result); + } + + @Test + public void propagatesVisitorError() { + RespondWorkflowTaskCompletedRequest request = + RespondWorkflowTaskCompletedRequest.newBuilder() + .addCommands(scheduleActivity("a", payloads("x"))) + .build(); + TestVisitorException boom = new TestVisitorException("boom"); + TestVisitorException thrown = + assertThrows( + TestVisitorException.class, + () -> + PayloadVisitors.visit( + request, + options( + (ctx, pls) -> { + throw boom; + }))); + assertSame(boom, thrown); + } + + @Test + public void contextHookErrorPropagates() { + Command command = + Command.newBuilder() + .setScheduleActivityTaskCommandAttributes( + ScheduleActivityTaskCommandAttributes.newBuilder().setInput(payloads("x"))) + .build(); + TestVisitorException boom = new TestVisitorException("hook boom"); + TestVisitorException thrown = + assertThrows( + TestVisitorException.class, + () -> + PayloadVisitors.visit( + command, + PayloadVisitorOptions.newBuilder() + .setPayloadVisitor((ctx, pls) -> pls) + .setMessageVisitor( + (current, msg) -> { + throw boom; + }) + .build())); + assertSame(boom, thrown); + } + + @Test + public void registryCoversPayloadBearingTypesAndExcludesOthers() { + Map registry = GeneratedPayloadVisitor.REGISTRY; + // Representative payload-bearing types must be present. + for (String fullName : + new String[] { + "temporal.api.command.v1.Command", + "temporal.api.command.v1.ScheduleActivityTaskCommandAttributes", + "temporal.api.command.v1.RecordMarkerCommandAttributes", + "temporal.api.failure.v1.Failure", + "temporal.api.common.v1.Memo", + "temporal.api.common.v1.Header", + "temporal.api.common.v1.SearchAttributes", + "temporal.api.common.v1.Payloads", + "temporal.api.protocol.v1.Message", + "temporal.api.workflowservice.v1.RespondWorkflowTaskCompletedRequest" + }) { + assertTrue("missing visitor for " + fullName, registry.containsKey(fullName)); + } + // Types without reachable payloads must be excluded. + assertFalse(registry.containsKey("temporal.api.common.v1.Payload")); + assertFalse(registry.containsKey("temporal.api.common.v1.WorkflowExecution")); + assertFalse(registry.containsKey("google.protobuf.DescriptorProto")); + } + + @Test + public void rejectsConcurrencyBelowOne() { + assertThrows( + IllegalArgumentException.class, + () -> + PayloadVisitorOptions.newBuilder() + .setPayloadVisitor((ctx, pls) -> pls) + .setConcurrency(0) + .build()); + } + + @Test + public void rejectsMissingVisitor() { + assertThrows(IllegalArgumentException.class, () -> PayloadVisitorOptions.newBuilder().build()); + } + + @Test + public void nullReturnFromVisitorFails() { + Command command = + Command.newBuilder() + .setScheduleActivityTaskCommandAttributes( + ScheduleActivityTaskCommandAttributes.newBuilder().setInput(payloads("x"))) + .build(); + assertThrows( + IllegalStateException.class, + () -> PayloadVisitors.visit(command, options((ctx, pls) -> null))); + } + + // --- Concurrency and executor --- + + private ExecutorService executor; + + @Before + public void setUpExecutor() { + executor = Executors.newCachedThreadPool(); + } + + @After + public void tearDownExecutor() { + executor.shutdownNow(); + } + + @Test + public void rejectsConcurrencyAboveOneWithoutExecutor() { + assertThrows( + IllegalArgumentException.class, + () -> + PayloadVisitorOptions.newBuilder() + .setPayloadVisitor((ctx, pls) -> pls) + .setConcurrency(2) + .build()); + } + + /** + * A request with {@code n} activity commands, each carrying one distinct single-payload input. + */ + static RespondWorkflowTaskCompletedRequest requestWithInputs(int n) { + RespondWorkflowTaskCompletedRequest.Builder b = + RespondWorkflowTaskCompletedRequest.newBuilder(); + for (int i = 0; i < n; i++) { + b.addCommands(scheduleActivity("a" + i, payloads("p" + i))); + } + return b.build(); + } + + @Test + public void concurrencyEqualToWorkAllowsFullOverlap() throws Exception { + int n = 4; + RespondWorkflowTaskCompletedRequest request = requestWithInputs(n); + CyclicBarrier barrier = new CyclicBarrier(n); + + // Each of the n visits must reach the barrier simultaneously, proving n concurrent visits. + PayloadVisitors.visit( + request, + PayloadVisitorOptions.newBuilder() + .setConcurrency(n) + .setExecutor(executor) + .setPayloadVisitor( + (ctx, pls) -> { + try { + barrier.await(5, TimeUnit.SECONDS); + } catch (InterruptedException | BrokenBarrierException | TimeoutException e) { + throw new RuntimeException(e); + } + return pls; + }) + .build()); + } + + @Test + public void boundedConcurrencyNeverExceedsLimit() { + int n = 8; + int limit = 3; + RespondWorkflowTaskCompletedRequest request = requestWithInputs(n); + + AtomicInteger inFlight = new AtomicInteger(); + AtomicInteger maxInFlight = new AtomicInteger(); + + PayloadVisitors.visit( + request, + PayloadVisitorOptions.newBuilder() + .setConcurrency(limit) + .setExecutor(executor) + .setPayloadVisitor( + (ctx, pls) -> { + int now = inFlight.incrementAndGet(); + maxInFlight.accumulateAndGet(now, Math::max); + try { + Thread.sleep(20); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + inFlight.decrementAndGet(); + return pls; + }) + .build()); + + assertTrue( + "max in-flight " + maxInFlight.get() + " > limit " + limit, maxInFlight.get() <= limit); + assertTrue("expected some overlap, got " + maxInFlight.get(), maxInFlight.get() > 1); + } + + @Test + public void sequentialConcurrencyVisitsOneAtATimeInOrder() { + int n = 5; + RespondWorkflowTaskCompletedRequest request = requestWithInputs(n); + + AtomicInteger inFlight = new AtomicInteger(); + AtomicInteger maxInFlight = new AtomicInteger(); + List order = new ArrayList<>(); + + PayloadVisitors.visit( + request, + PayloadVisitorOptions.newBuilder() + .setConcurrency(1) + .setPayloadVisitor( + (ctx, pls) -> { + int now = inFlight.incrementAndGet(); + maxInFlight.accumulateAndGet(now, Math::max); + order.add(pls.get(0).getData().toStringUtf8()); + inFlight.decrementAndGet(); + return pls; + }) + .build()); + + assertEquals(1, maxInFlight.get()); + List expected = new ArrayList<>(); + for (int i = 0; i < n; i++) { + expected.add("p" + i); + } + assertEquals(expected, order); + } + + @Test + public void concurrentVisitorErrorPropagates() { + RespondWorkflowTaskCompletedRequest request = requestWithInputs(8); + TestVisitorException boom = new TestVisitorException("boom"); + TestVisitorException thrown = + assertThrows( + TestVisitorException.class, + () -> + PayloadVisitors.visit( + request, + PayloadVisitorOptions.newBuilder() + .setConcurrency(4) + .setExecutor(executor) + .setPayloadVisitor( + (ctx, pls) -> { + if (pls.get(0).getData().toStringUtf8().equals("p5")) { + throw boom; + } + return pls; + }) + .build())); + assertSame(boom, thrown); + } + + @Test + public void mutationsAppliedCorrectlyUnderConcurrency() { + int n = 16; + RespondWorkflowTaskCompletedRequest request = requestWithInputs(n); + RespondWorkflowTaskCompletedRequest mutated = + PayloadVisitors.visit( + request, + PayloadVisitorOptions.newBuilder() + .setConcurrency(8) + .setExecutor(executor) + .setPayloadVisitor( + (ctx, pls) -> { + Payload p = pls.get(0); + return Collections.singletonList( + p.toBuilder() + .setData(ByteString.copyFromUtf8(p.getData().toStringUtf8() + "!")) + .build()); + }) + .build()); + for (int i = 0; i < n; i++) { + assertEquals( + "p" + i + "!", + mutated + .getCommands(i) + .getScheduleActivityTaskCommandAttributes() + .getInput() + .getPayloads(0) + .getData() + .toStringUtf8()); + } + } + + @Test + public void concurrentVisitsRunOnProvidedExecutor() throws InterruptedException { + AtomicInteger threadSeq = new AtomicInteger(); + ThreadFactory factory = + r -> { + Thread t = new Thread(r); + t.setName("pv-exec-" + threadSeq.incrementAndGet()); + return t; + }; + ExecutorService pool = Executors.newFixedThreadPool(4, factory); + try { + RespondWorkflowTaskCompletedRequest request = requestWithInputs(12); + Set visitThreads = ConcurrentHashMap.newKeySet(); + + PayloadVisitors.visit( + request, + PayloadVisitorOptions.newBuilder() + .setConcurrency(4) + .setExecutor(pool) + .setPayloadVisitor( + (ctx, pls) -> { + visitThreads.add(Thread.currentThread().getName()); + return pls; + }) + .build()); + + assertTrue("no visits recorded", !visitThreads.isEmpty()); + for (String name : visitThreads) { + assertTrue("visit ran off the provided executor: " + name, name.startsWith("pv-exec-")); + } + } finally { + pool.shutdownNow(); + pool.awaitTermination(5, TimeUnit.SECONDS); + } + } + + @Test + public void sequentialConcurrencyIgnoresExecutorAndRunsOnCallingThread() { + // With concurrency == 1 the executor must never be touched and visits run inline. + AtomicInteger submittedToExecutor = new AtomicInteger(); + Executor tripwire = + command -> { + submittedToExecutor.incrementAndGet(); + command.run(); + }; + + RespondWorkflowTaskCompletedRequest request = requestWithInputs(5); + String callingThread = Thread.currentThread().getName(); + AtomicReference sawDifferentThread = new AtomicReference<>(); + + PayloadVisitors.visit( + request, + PayloadVisitorOptions.newBuilder() + .setConcurrency(1) + .setExecutor(tripwire) + .setPayloadVisitor( + (ctx, pls) -> { + if (!Thread.currentThread().getName().equals(callingThread)) { + sawDifferentThread.set(Thread.currentThread().getName()); + } + return pls; + }) + .build()); + + assertEquals("executor was used for sequential traversal", 0, submittedToExecutor.get()); + assertEquals("visit ran off the calling thread", null, sawDifferentThread.get()); + } +} diff --git a/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/TestVisitorException.java b/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/TestVisitorException.java new file mode 100644 index 0000000000..ecbfab7829 --- /dev/null +++ b/temporal-sdk/src/test/java/io/temporal/internal/payload/visitor/TestVisitorException.java @@ -0,0 +1,12 @@ +package io.temporal.internal.payload.visitor; + +/** + * Exception thrown only by test visitor/message callbacks. Using a dedicated type keeps "the + * visitor threw" assertions from being satisfied by an unrelated {@link IllegalStateException} that + * production code might raise. + */ +class TestVisitorException extends RuntimeException { + TestVisitorException(String message) { + super(message); + } +}