Skip to content

Commit 8bbff9f

Browse files
committed
[ISSUE #10214] Fix TOCTOU race condition in MQClientInstance.brokerVersionTable
Fixes #10214
1 parent 932588d commit 8bbff9f

2 files changed

Lines changed: 79 additions & 11 deletions

File tree

client/src/main/java/org/apache/rocketmq/client/impl/factory/MQClientInstance.java

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -654,10 +654,8 @@ public boolean sendHeartbeatToBroker(long id, String brokerName, String addr, bo
654654
private boolean sendHeartbeatToBroker(long id, String brokerName, String addr, HeartbeatData heartbeatData) {
655655
try {
656656
int version = this.mQClientAPIImpl.sendHeartbeat(addr, heartbeatData, clientConfig.getMqClientApiTimeout());
657-
if (!this.brokerVersionTable.containsKey(brokerName)) {
658-
this.brokerVersionTable.put(brokerName, new ConcurrentHashMap<>(4));
659-
}
660-
this.brokerVersionTable.get(brokerName).put(addr, version);
657+
this.brokerVersionTable.computeIfAbsent(brokerName, k -> new ConcurrentHashMap<>(4))
658+
.put(addr, version);
661659
long times = this.sendHeartbeatTimesTotal.getAndIncrement();
662660
if (times % 20 == 0) {
663661
log.info("send heart beat to broker[{} {} {}] success", brokerName, id, addr);
@@ -734,10 +732,8 @@ private boolean sendHeartbeatToBrokerV2(long id, String brokerName, String addr,
734732
log.info("sendHeartbeatToAllBrokerV2 normal brokerName: {} subChange: {} brokerAddrHeartbeatFingerprintTable: {}", brokerName, heartbeatV2Result.isSubChange(), JSON.toJSONString(brokerAddrHeartbeatFingerprintTable));
735733
}
736734
version = heartbeatV2Result.getVersion();
737-
if (!this.brokerVersionTable.containsKey(brokerName)) {
738-
this.brokerVersionTable.put(brokerName, new ConcurrentHashMap<>(4));
739-
}
740-
this.brokerVersionTable.get(brokerName).put(addr, version);
735+
this.brokerVersionTable.computeIfAbsent(brokerName, k -> new ConcurrentHashMap<>(4))
736+
.put(addr, version);
741737
long times = this.sendHeartbeatTimesTotal.getAndIncrement();
742738
if (times % 20 == 0) {
743739
log.info("send heart beat to broker[{} {} {}] success", brokerName, id, addr);
@@ -1301,9 +1297,11 @@ public FindBrokerResult findBrokerAddressInSubscribe(
13011297
}
13021298

13031299
private int findBrokerVersion(String brokerName, String brokerAddr) {
1304-
if (this.brokerVersionTable.containsKey(brokerName)) {
1305-
if (this.brokerVersionTable.get(brokerName).containsKey(brokerAddr)) {
1306-
return this.brokerVersionTable.get(brokerName).get(brokerAddr);
1300+
ConcurrentHashMap<String, Integer> brokerVersions = this.brokerVersionTable.get(brokerName);
1301+
if (brokerVersions != null) {
1302+
Integer version = brokerVersions.get(brokerAddr);
1303+
if (version != null) {
1304+
return version;
13071305
}
13081306
}
13091307
//To do need to fresh the version

client/src/test/java/org/apache/rocketmq/client/impl/factory/MQClientInstanceTest.java

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,4 +554,74 @@ public void testSendHeartbeatToAllBrokerConcurrently() {
554554
fail("failed: " + e.getMessage());
555555
}
556556
}
557+
558+
@Test
559+
public void testFindBrokerVersionWhenBrokerNameNotExist() throws Exception {
560+
ConcurrentMap<String, ConcurrentHashMap<String, Integer>> brokerVersionTable =
561+
(ConcurrentMap<String, ConcurrentHashMap<String, Integer>>) FieldUtils.readDeclaredField(
562+
mqClientInstance, "brokerVersionTable", true);
563+
brokerVersionTable.clear();
564+
565+
FindBrokerResult result = mqClientInstance.findBrokerAddressInSubscribe("nonExistentBroker", 0, false);
566+
assertNull(result);
567+
}
568+
569+
@Test
570+
public void testFindBrokerVersionWhenBrokerAddrNotExist() throws Exception {
571+
ConcurrentMap<String, ConcurrentHashMap<String, Integer>> brokerVersionTable =
572+
(ConcurrentMap<String, ConcurrentHashMap<String, Integer>>) FieldUtils.readDeclaredField(
573+
mqClientInstance, "brokerVersionTable", true);
574+
brokerVersionTable.clear();
575+
brokerVersionTable.put("broker-a", new ConcurrentHashMap<>());
576+
577+
FindBrokerResult result = mqClientInstance.findBrokerAddressInSubscribe("broker-a", 0, false);
578+
assertNull(result);
579+
}
580+
581+
@Test
582+
public void testFindBrokerVersionConcurrentRemoval() throws Exception {
583+
ConcurrentMap<String, ConcurrentHashMap<String, Integer>> brokerVersionTable =
584+
(ConcurrentMap<String, ConcurrentHashMap<String, Integer>>) FieldUtils.readDeclaredField(
585+
mqClientInstance, "brokerVersionTable", true);
586+
brokerVersionTable.clear();
587+
588+
String brokerName = "broker-a";
589+
ConcurrentHashMap<String, Integer> versionMap = new ConcurrentHashMap<>(4);
590+
versionMap.put("127.0.0.1:10911", 401);
591+
brokerVersionTable.put(brokerName, versionMap);
592+
593+
// Simulate concurrent removal between containsKey and get
594+
// With the fix using local variable caching, this should not cause NPE
595+
brokerVersionTable.remove(brokerName);
596+
597+
// findBrokerVersion is private, but it's called via findBrokerAddressInSubscribe
598+
// The key test is that no NPE is thrown when brokerName is absent
599+
FindBrokerResult result = mqClientInstance.findBrokerAddressInSubscribe(brokerName, 0, false);
600+
assertNull(result);
601+
}
602+
603+
@Test
604+
public void testBrokerVersionTableComputeIfAbsent() throws Exception {
605+
ConcurrentMap<String, ConcurrentHashMap<String, Integer>> brokerVersionTable =
606+
(ConcurrentMap<String, ConcurrentHashMap<String, Integer>>) FieldUtils.readDeclaredField(
607+
mqClientInstance, "brokerVersionTable", true);
608+
brokerVersionTable.clear();
609+
610+
String brokerName = "broker-a";
611+
612+
// Use computeIfAbsent as the fix does
613+
ConcurrentHashMap<String, Integer> inner = brokerVersionTable.computeIfAbsent(
614+
brokerName, k -> new ConcurrentHashMap<>(4));
615+
inner.put("127.0.0.1:10911", 401);
616+
617+
// Verify the entry was created atomically
618+
assertNotNull(brokerVersionTable.get(brokerName));
619+
assertEquals(Integer.valueOf(401), brokerVersionTable.get(brokerName).get("127.0.0.1:10911"));
620+
621+
// Calling computeIfAbsent again should return the same map, not create a new one
622+
ConcurrentHashMap<String, Integer> inner2 = brokerVersionTable.computeIfAbsent(
623+
brokerName, k -> new ConcurrentHashMap<>(4));
624+
assertEquals(inner, inner2);
625+
assertEquals(Integer.valueOf(401), inner2.get("127.0.0.1:10911"));
626+
}
557627
}

0 commit comments

Comments
 (0)