Skip to content

Commit 4d1cee3

Browse files
authored
fix(core): harden parsing and cache edge cases (spiffe#399)
Signed-off-by: Max Lambrecht <maxlambrecht@gmail.com>
1 parent 8bf98fb commit 4d1cee3

10 files changed

Lines changed: 219 additions & 11 deletions

File tree

java-spiffe-core/src/main/java/io/spiffe/bundle/jwtbundle/JwtBundleSet.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ public static JwtBundleSet of(Collection<JwtBundle> bundles) {
3939
}
4040
final Map<TrustDomain, JwtBundle> bundleMap = new ConcurrentHashMap<>();
4141
for (JwtBundle bundle : bundles) {
42+
Objects.requireNonNull(bundle, "bundle must not be null");
4243
bundleMap.put(bundle.getTrustDomain(), bundle);
4344
}
4445
return new JwtBundleSet(bundleMap);

java-spiffe-core/src/main/java/io/spiffe/bundle/x509bundle/X509BundleSet.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ public static X509BundleSet of(Collection<X509Bundle> bundles) {
4040

4141
final Map<TrustDomain, X509Bundle> bundleMap = new ConcurrentHashMap<>();
4242
for (X509Bundle bundle : bundles) {
43+
Objects.requireNonNull(bundle, "bundle must not be null");
4344
bundleMap.put(bundle.getTrustDomain(), bundle);
4445
}
4546
return new X509BundleSet(bundleMap);

java-spiffe-core/src/main/java/io/spiffe/spiffeid/TrustDomain.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,11 @@ public static TrustDomain parse(String idOrName) {
3434
throw new IllegalArgumentException("Trust domain is missing");
3535
}
3636

37-
// Something looks kinda like a scheme separator, let's try to parse as
38-
// an ID. We use :/ instead of :// since the diagnostics are better for
39-
// a bad input like spiffe:/trustdomain.
37+
// Heuristic: if the input resembles a SPIFFE ID or a URI scheme
38+
// (e.g. spiffe://..., spiffe:/..., or <scheme>:/...), delegate parsing
39+
// to SpiffeId.parse() so scheme-related errors are reported consistently.
4040
if (idOrName.contains(":/")) {
41-
SpiffeId spiffeId = SpiffeId.parse(idOrName);
42-
return spiffeId.getTrustDomain();
41+
return SpiffeId.parse(idOrName).getTrustDomain();
4342
}
4443

4544
validateTrustDomainName(idOrName);

java-spiffe-core/src/main/java/io/spiffe/svid/jwtsvid/JwtSvid.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,9 @@ private static SpiffeId getSpiffeIdOfSubject(final JWTClaimsSet claimsSet) throw
393393

394394
// expected audiences must be a subset of the audience claim in the token
395395
private static void validateAudience(List<String> audClaim, Set<String> expectedAudiences) throws JwtSvidException {
396+
if (audClaim == null || audClaim.isEmpty()) {
397+
throw new JwtSvidException("Token missing audience claim");
398+
}
396399
if (!audClaim.containsAll(expectedAudiences)) {
397400
throw new JwtSvidException(String.format("expected audience in %s (audience=%s)", expectedAudiences, audClaim));
398401
}

java-spiffe-core/src/main/java/io/spiffe/workloadapi/CachedJwtSource.java

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import io.spiffe.svid.jwtsvid.JwtSvid;
1111
import org.apache.commons.lang3.tuple.ImmutablePair;
1212

13-
import java.io.Closeable;
1413
import java.io.IOException;
1514
import java.time.Clock;
1615
import java.time.Duration;
@@ -228,7 +227,7 @@ private List<JwtSvid> getJwtSvids(SpiffeId subject, String audience, String... e
228227
ImmutablePair<SpiffeId, Set<String>> cacheKey = new ImmutablePair<>(subject, audiencesSet);
229228

230229
List<JwtSvid> svidList = jwtSvids.get(cacheKey);
231-
if (svidList != null && !isTokenPastHalfLifetime(svidList.get(0))) {
230+
if (svidList != null && !svidList.isEmpty() && !isTokenPastHalfLifetime(svidList.get(0))) {
232231
return svidList;
233232
}
234233

@@ -238,7 +237,7 @@ private List<JwtSvid> getJwtSvids(SpiffeId subject, String audience, String... e
238237
// If it does not exist or the JWT-SVID has passed half its lifetime, call the Workload API to fetch new JWT-SVIDs,
239238
// add them to the cache map, and return the list of JWT-SVIDs.
240239
svidList = jwtSvids.get(cacheKey);
241-
if (svidList != null && !isTokenPastHalfLifetime(svidList.get(0))) {
240+
if (svidList != null && !svidList.isEmpty() && !isTokenPastHalfLifetime(svidList.get(0))) {
242241
return svidList;
243242
}
244243

@@ -247,6 +246,9 @@ private List<JwtSvid> getJwtSvids(SpiffeId subject, String audience, String... e
247246
} else {
248247
svidList = workloadApiClient.fetchJwtSvids(cacheKey.left, audience, extraAudiences);
249248
}
249+
if (svidList == null || svidList.isEmpty()) {
250+
throw new JwtSvidException("Workload API returned empty JWT SVID list");
251+
}
250252
jwtSvids.put(cacheKey, svidList);
251253
return svidList;
252254
}
@@ -333,4 +335,16 @@ private static WorkloadApiClient createClient(JwtSourceOptions options)
333335
void setClock(Clock clock) {
334336
this.clock = clock;
335337
}
338+
339+
// Visible for testing only.
340+
// This method exists to allow deterministic testing of cache edge cases
341+
// (e.g. empty cached lists) without relying on reflection or timing-based
342+
// behavior, which would be more brittle and less safe.
343+
void putCachedJwtSvidsForTest(SpiffeId subject, Set<String> audiences, List<JwtSvid> svids) {
344+
Objects.requireNonNull(subject, "subject must not be null");
345+
Objects.requireNonNull(audiences, "audiences must not be null");
346+
Objects.requireNonNull(svids, "svids must not be null");
347+
ImmutablePair<SpiffeId, Set<String>> cacheKey = new ImmutablePair<>(subject, new HashSet<>(audiences));
348+
jwtSvids.put(cacheKey, new ArrayList<>(svids));
349+
}
336350
}

java-spiffe-core/src/test/java/io/spiffe/bundle/jwtbundle/JwtBundleSetTest.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import static org.junit.jupiter.api.Assertions.assertEquals;
1313
import static org.junit.jupiter.api.Assertions.assertFalse;
1414
import static org.junit.jupiter.api.Assertions.assertNotNull;
15+
import static org.junit.jupiter.api.Assertions.assertThrows;
1516
import static org.junit.jupiter.api.Assertions.assertTrue;
1617
import static org.junit.jupiter.api.Assertions.fail;
1718

@@ -153,4 +154,12 @@ void add_null_throwsNullPointerException() {
153154
assertEquals("jwtBundle must not be null", e.getMessage());
154155
}
155156
}
156-
}
157+
158+
@Test
159+
void testOf_nullElementInCollection_throwsNullPointerException() {
160+
JwtBundle jwtBundle1 = new JwtBundle(TrustDomain.parse("example.org"));
161+
List<JwtBundle> bundles = Arrays.asList(jwtBundle1, null);
162+
NullPointerException exception = assertThrows(NullPointerException.class, () -> JwtBundleSet.of(bundles));
163+
assertEquals("bundle must not be null", exception.getMessage());
164+
}
165+
}

java-spiffe-core/src/test/java/io/spiffe/bundle/x509bundle/X509BundleSetTest.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import static org.junit.jupiter.api.Assertions.assertEquals;
1313
import static org.junit.jupiter.api.Assertions.assertFalse;
1414
import static org.junit.jupiter.api.Assertions.assertNotNull;
15+
import static org.junit.jupiter.api.Assertions.assertThrows;
1516
import static org.junit.jupiter.api.Assertions.assertTrue;
1617
import static org.junit.jupiter.api.Assertions.fail;
1718

@@ -149,4 +150,12 @@ void testgetBundleForTrustDomain_nullTrustDomain_throwsException() throws Bundle
149150
assertEquals("trustDomain must not be null", e.getMessage());
150151
}
151152
}
152-
}
153+
154+
@Test
155+
void testOf_nullElementInCollection_throwsNullPointerException() {
156+
X509Bundle x509Bundle1 = new X509Bundle(TrustDomain.parse("example.org"));
157+
List<X509Bundle> bundles = Arrays.asList(x509Bundle1, null);
158+
NullPointerException exception = assertThrows(NullPointerException.class, () -> X509BundleSet.of(bundles));
159+
assertEquals("bundle must not be null", exception.getMessage());
160+
}
161+
}

java-spiffe-core/src/test/java/io/spiffe/spiffeid/TrustDomainTest.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import static io.spiffe.spiffeid.SpiffeIdTest.TD_CHARS;
1212
import static org.junit.jupiter.api.Assertions.assertEquals;
13+
import static org.junit.jupiter.api.Assertions.assertThrows;
1314
import static org.junit.jupiter.api.Assertions.fail;
1415

1516
class TrustDomainTest {
@@ -104,4 +105,28 @@ void test_toIdString() {
104105
final TrustDomain trustDomain = TrustDomain.parse("domain.test");
105106
assertEquals("spiffe://domain.test", trustDomain.toIdString());
106107
}
108+
109+
@Test
110+
void testParseFromSpiffeIdWithPath_extractsTrustDomain() {
111+
TrustDomain trustDomain = TrustDomain.parse("spiffe://example.org/foo");
112+
assertEquals("example.org", trustDomain.getName());
113+
}
114+
115+
@Test
116+
void testParseInvalidScheme_spiffeWithSingleSlash_throwsInvalidScheme() {
117+
assertThrows(InvalidSpiffeIdException.class,
118+
() -> TrustDomain.parse("spiffe:/example.org"));
119+
}
120+
121+
@Test
122+
void testParseInvalidScheme_httpScheme_throwsInvalidScheme() {
123+
assertThrows(InvalidSpiffeIdException.class,
124+
() -> TrustDomain.parse("http://example.org"));
125+
}
126+
127+
@Test
128+
void testParseColonNotFollowedBySlash_validatesAsTrustDomain() {
129+
assertThrows(InvalidSpiffeIdException.class,
130+
() -> TrustDomain.parse("trustdomain:test"));
131+
}
107132
}

java-spiffe-core/src/test/java/io/spiffe/svid/jwtsvid/JwtSvidParseAndValidateTest.java

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,22 @@ static Stream<Arguments> provideSuccessScenarios() {
179179
TestUtils.generateToken(claims, key3, "authority3"),
180180
""
181181
))
182+
.build()),
183+
Arguments.of(TestCase.builder()
184+
.name("audience contains expected - success")
185+
.jwtBundle(jwtBundle)
186+
.expectedAudience(Collections.singleton("audience1"))
187+
.generateToken(() -> TestUtils.generateToken(claims, key1, "authority1"))
188+
.expectedException(null)
189+
.expectedJwtSvid(newJwtSvidInstance(
190+
trustDomain.newSpiffeId("host"),
191+
audience,
192+
issuedAt,
193+
expiration,
194+
claims.getClaims(),
195+
TestUtils.generateToken(claims, key1, "authority1"),
196+
null
197+
))
182198
.build())
183199
);
184200
}
@@ -243,6 +259,27 @@ static Stream<Arguments> provideFailureScenarios() {
243259
.generateToken(() -> TestUtils.generateToken(claims, key1, "authority1"))
244260
.expectedException(new JwtSvidException("expected audience in [another] (audience=[audience2, audience1])"))
245261
.build()),
262+
Arguments.of(TestCase.builder()
263+
.name("missing audience claim")
264+
.jwtBundle(jwtBundle)
265+
.expectedAudience(audience)
266+
.generateToken(() -> TestUtils.generateToken(new JWTClaimsSet.Builder()
267+
.subject(spiffeId.toString())
268+
.expirationTime(expiration)
269+
.build(), key1, "authority1"))
270+
.expectedException(new JwtSvidException("Token missing audience claim"))
271+
.build()),
272+
Arguments.of(TestCase.builder()
273+
.name("empty audience claim")
274+
.jwtBundle(jwtBundle)
275+
.expectedAudience(audience)
276+
.generateToken(() -> TestUtils.generateToken(new JWTClaimsSet.Builder()
277+
.subject(spiffeId.toString())
278+
.expirationTime(expiration)
279+
.audience(Collections.emptyList())
280+
.build(), key1, "authority1"))
281+
.expectedException(new JwtSvidException("Token missing audience claim"))
282+
.build()),
246283
Arguments.of(TestCase.builder()
247284
.name("invalid subject claim")
248285
.jwtBundle(jwtBundle)
@@ -388,4 +425,4 @@ public TestCase build() {
388425
}
389426
}
390427
}
391-
}
428+
}

java-spiffe-core/src/test/java/io/spiffe/workloadapi/CachedJwtSourceTest.java

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
import java.time.Instant;
2121
import java.time.ZoneId;
2222
import java.util.ArrayList;
23+
import java.util.Collections;
2324
import java.util.List;
25+
import java.util.Set;
2426
import java.util.concurrent.ExecutionException;
2527
import java.util.concurrent.ExecutorService;
2628
import java.util.concurrent.Executors;
@@ -30,6 +32,9 @@
3032
import static org.junit.jupiter.api.Assertions.*;
3133

3234
class CachedJwtSourceTest {
35+
private static final SpiffeId TEST_SUBJECT = SpiffeId.parse("spiffe://example.org/workload-server");
36+
private static final String TEST_AUDIENCE = "aud1";
37+
3338
private CachedJwtSource jwtSource;
3439
private WorkloadApiClientStub workloadApiClient;
3540
private WorkloadApiClientErrorStub workloadApiClientErrorStub;
@@ -519,4 +524,109 @@ void newSource_noSocketAddress() throws Exception {
519524
}
520525
});
521526
}
527+
528+
@Test
529+
void testFetchJwtSvids_cacheContainsEmptyList_refetchesFromWorkloadApi() throws JwtSvidException, JwtSourceException, SocketEndpointAddressException {
530+
// Test that if cache somehow contains empty list (edge case), it refetches
531+
JwtSourceOptions options = JwtSourceOptions.builder()
532+
.workloadApiClient(workloadApiClient)
533+
.initTimeout(Duration.ofSeconds(0))
534+
.build();
535+
CachedJwtSource customJwtSource = (CachedJwtSource) CachedJwtSource.newSource(options);
536+
customJwtSource.setClock(clock);
537+
538+
try {
539+
// Seed cache with empty list to simulate edge case
540+
Set<String> audiences = Collections.singleton(TEST_AUDIENCE);
541+
customJwtSource.putCachedJwtSvidsForTest(TEST_SUBJECT, audiences, Collections.emptyList());
542+
543+
int initialCallCount = workloadApiClient.getFetchJwtSvidCallCount();
544+
545+
// Fetch should refetch from Workload API (empty list in cache triggers refetch)
546+
List<JwtSvid> svids = customJwtSource.fetchJwtSvids(TEST_SUBJECT, TEST_AUDIENCE);
547+
assertNotNull(svids);
548+
assertEquals(1, svids.size());
549+
// Should have called Workload API
550+
assertEquals(initialCallCount + 1, workloadApiClient.getFetchJwtSvidCallCount());
551+
552+
// Subsequent fetch should NOT call Workload API again (proves valid list was cached after refetch)
553+
List<JwtSvid> svids2 = customJwtSource.fetchJwtSvids(TEST_SUBJECT, TEST_AUDIENCE);
554+
assertNotNull(svids2);
555+
assertEquals(1, svids2.size());
556+
assertEquals(initialCallCount + 1, workloadApiClient.getFetchJwtSvidCallCount());
557+
} finally {
558+
customJwtSource.close();
559+
}
560+
}
561+
562+
@Test
563+
void testFetchJwtSvids_workloadApiReturnsEmptyList_throwsJwtSvidException() throws JwtSourceException, SocketEndpointAddressException {
564+
// Create a custom client that always returns empty list
565+
WorkloadApiClientStub emptyListClient = new WorkloadApiClientStub() {
566+
@Override
567+
public List<JwtSvid> fetchJwtSvids(SpiffeId subject, String audience, String... extraAudience) throws JwtSvidException {
568+
super.fetchJwtSvids(subject, audience, extraAudience); // increment counter
569+
return Collections.emptyList();
570+
}
571+
};
572+
emptyListClient.setClock(clock);
573+
574+
JwtSourceOptions options = JwtSourceOptions.builder()
575+
.workloadApiClient(emptyListClient)
576+
.initTimeout(Duration.ofSeconds(0))
577+
.build();
578+
CachedJwtSource customJwtSource = (CachedJwtSource) CachedJwtSource.newSource(options);
579+
customJwtSource.setClock(clock);
580+
581+
try {
582+
JwtSvidException exception = assertThrows(JwtSvidException.class,
583+
() -> customJwtSource.fetchJwtSvids(TEST_SUBJECT, TEST_AUDIENCE));
584+
assertEquals("Workload API returned empty JWT SVID list", exception.getMessage());
585+
} finally {
586+
customJwtSource.close();
587+
}
588+
}
589+
590+
@Test
591+
void testFetchJwtSvids_emptyListNeverCached() throws JwtSvidException, JwtSourceException, SocketEndpointAddressException {
592+
// Create a custom client that returns empty list on first call, then valid SVIDs
593+
final int[] callCount = new int[1];
594+
WorkloadApiClientStub customClient = new WorkloadApiClientStub() {
595+
@Override
596+
public List<JwtSvid> fetchJwtSvids(SpiffeId subject, String audience, String... extraAudience) throws JwtSvidException {
597+
callCount[0]++;
598+
if (callCount[0] == 1) {
599+
return Collections.emptyList();
600+
} else {
601+
return super.fetchJwtSvids(subject, audience, extraAudience);
602+
}
603+
}
604+
};
605+
customClient.setClock(clock);
606+
607+
JwtSourceOptions options = JwtSourceOptions.builder()
608+
.workloadApiClient(customClient)
609+
.initTimeout(Duration.ofSeconds(0))
610+
.build();
611+
CachedJwtSource customJwtSource = (CachedJwtSource) CachedJwtSource.newSource(options);
612+
customJwtSource.setClock(clock);
613+
614+
try {
615+
// First call returns empty, should throw (empty list is not cached)
616+
JwtSvidException exception = assertThrows(JwtSvidException.class,
617+
() -> customJwtSource.fetchJwtSvids(TEST_SUBJECT, TEST_AUDIENCE));
618+
assertEquals("Workload API returned empty JWT SVID list", exception.getMessage());
619+
620+
// Verify empty list was not cached: second call should fetch again and succeed
621+
int callCountBeforeSecondCall = callCount[0];
622+
List<JwtSvid> svids = customJwtSource.fetchJwtSvids(TEST_SUBJECT, TEST_AUDIENCE);
623+
assertNotNull(svids);
624+
assertEquals(1, svids.size());
625+
// Verify that second call actually made a fetch (callCount increased)
626+
assertEquals(callCountBeforeSecondCall + 1, callCount[0]);
627+
} finally {
628+
customJwtSource.close();
629+
}
630+
}
522631
}
632+

0 commit comments

Comments
 (0)