Skip to content

Commit 7341577

Browse files
committed
[ISSUE #10214] Fix TOCTOU race condition in MQClientInstance.brokerVersionTable
Fixes #10214
1 parent 932588d commit 7341577

2 files changed

Lines changed: 127 additions & 11 deletions

File tree

client/src/main/java/org/apache/rocketmq/client/impl/factory/MQClientInstance.java

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -654,10 +654,8 @@ public boolean sendHeartbeatToBroker(long id, String brokerName, String addr, bo
654654
private boolean sendHeartbeatToBroker(long id, String brokerName, String addr, HeartbeatData heartbeatData) {
655655
try {
656656
int version = this.mQClientAPIImpl.sendHeartbeat(addr, heartbeatData, clientConfig.getMqClientApiTimeout());
657-
if (!this.brokerVersionTable.containsKey(brokerName)) {
658-
this.brokerVersionTable.put(brokerName, new ConcurrentHashMap<>(4));
659-
}
660-
this.brokerVersionTable.get(brokerName).put(addr, version);
657+
this.brokerVersionTable.computeIfAbsent(brokerName, k -> new ConcurrentHashMap<>(4))
658+
.put(addr, version);
661659
long times = this.sendHeartbeatTimesTotal.getAndIncrement();
662660
if (times % 20 == 0) {
663661
log.info("send heart beat to broker[{} {} {}] success", brokerName, id, addr);
@@ -734,10 +732,8 @@ private boolean sendHeartbeatToBrokerV2(long id, String brokerName, String addr,
734732
log.info("sendHeartbeatToAllBrokerV2 normal brokerName: {} subChange: {} brokerAddrHeartbeatFingerprintTable: {}", brokerName, heartbeatV2Result.isSubChange(), JSON.toJSONString(brokerAddrHeartbeatFingerprintTable));
735733
}
736734
version = heartbeatV2Result.getVersion();
737-
if (!this.brokerVersionTable.containsKey(brokerName)) {
738-
this.brokerVersionTable.put(brokerName, new ConcurrentHashMap<>(4));
739-
}
740-
this.brokerVersionTable.get(brokerName).put(addr, version);
735+
this.brokerVersionTable.computeIfAbsent(brokerName, k -> new ConcurrentHashMap<>(4))
736+
.put(addr, version);
741737
long times = this.sendHeartbeatTimesTotal.getAndIncrement();
742738
if (times % 20 == 0) {
743739
log.info("send heart beat to broker[{} {} {}] success", brokerName, id, addr);
@@ -1301,9 +1297,11 @@ public FindBrokerResult findBrokerAddressInSubscribe(
13011297
}
13021298

13031299
private int findBrokerVersion(String brokerName, String brokerAddr) {
1304-
if (this.brokerVersionTable.containsKey(brokerName)) {
1305-
if (this.brokerVersionTable.get(brokerName).containsKey(brokerAddr)) {
1306-
return this.brokerVersionTable.get(brokerName).get(brokerAddr);
1300+
ConcurrentHashMap<String, Integer> brokerVersions = this.brokerVersionTable.get(brokerName);
1301+
if (brokerVersions != null) {
1302+
Integer version = brokerVersions.get(brokerAddr);
1303+
if (version != null) {
1304+
return version;
13071305
}
13081306
}
13091307
//To do need to fresh the version

client/src/test/java/org/apache/rocketmq/client/impl/factory/MQClientInstanceTest.java

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,4 +554,122 @@ public void testSendHeartbeatToAllBrokerConcurrently() {
554554
fail("failed: " + e.getMessage());
555555
}
556556
}
557+
558+
@Test
559+
public void testFindBrokerVersionWhenVersionExists() throws Exception {
560+
String brokerName = "broker-a";
561+
String brokerAddr = "127.0.0.1:10911";
562+
563+
// Populate brokerAddrTable so findBrokerAddressInSubscribe reaches findBrokerVersion
564+
ConcurrentMap<String, HashMap<Long, String>> brokerAddrTable =
565+
(ConcurrentMap<String, HashMap<Long, String>>) FieldUtils.readDeclaredField(
566+
mqClientInstance, "brokerAddrTable", true);
567+
HashMap<Long, String> addrMap = new HashMap<>();
568+
addrMap.put(MixAll.MASTER_ID, brokerAddr);
569+
brokerAddrTable.put(brokerName, addrMap);
570+
571+
// Populate brokerVersionTable with a known version
572+
ConcurrentMap<String, ConcurrentHashMap<String, Integer>> brokerVersionTable =
573+
(ConcurrentMap<String, ConcurrentHashMap<String, Integer>>) FieldUtils.readDeclaredField(
574+
mqClientInstance, "brokerVersionTable", true);
575+
ConcurrentHashMap<String, Integer> versionMap = new ConcurrentHashMap<>();
576+
versionMap.put(brokerAddr, 401);
577+
brokerVersionTable.put(brokerName, versionMap);
578+
579+
FindBrokerResult result = mqClientInstance.findBrokerAddressInSubscribe(brokerName, MixAll.MASTER_ID, false);
580+
assertNotNull(result);
581+
assertEquals(brokerAddr, result.getBrokerAddr());
582+
assertEquals(401, result.getBrokerVersion());
583+
584+
// Cleanup
585+
brokerAddrTable.remove(brokerName);
586+
brokerVersionTable.remove(brokerName);
587+
}
588+
589+
@Test
590+
public void testFindBrokerVersionWhenBrokerNameNotInVersionTable() throws Exception {
591+
String brokerName = "broker-a";
592+
String brokerAddr = "127.0.0.1:10911";
593+
594+
// Populate brokerAddrTable so findBrokerAddressInSubscribe reaches findBrokerVersion
595+
ConcurrentMap<String, HashMap<Long, String>> brokerAddrTable =
596+
(ConcurrentMap<String, HashMap<Long, String>>) FieldUtils.readDeclaredField(
597+
mqClientInstance, "brokerAddrTable", true);
598+
HashMap<Long, String> addrMap = new HashMap<>();
599+
addrMap.put(MixAll.MASTER_ID, brokerAddr);
600+
brokerAddrTable.put(brokerName, addrMap);
601+
602+
// Do NOT populate brokerVersionTable - simulates broker not yet heartbeated
603+
ConcurrentMap<String, ConcurrentHashMap<String, Integer>> brokerVersionTable =
604+
(ConcurrentMap<String, ConcurrentHashMap<String, Integer>>) FieldUtils.readDeclaredField(
605+
mqClientInstance, "brokerVersionTable", true);
606+
brokerVersionTable.remove(brokerName);
607+
608+
FindBrokerResult result = mqClientInstance.findBrokerAddressInSubscribe(brokerName, MixAll.MASTER_ID, false);
609+
assertNotNull(result);
610+
assertEquals(brokerAddr, result.getBrokerAddr());
611+
// findBrokerVersion returns 0 when brokerName not in version table
612+
assertEquals(0, result.getBrokerVersion());
613+
614+
// Cleanup
615+
brokerAddrTable.remove(brokerName);
616+
}
617+
618+
@Test
619+
public void testFindBrokerVersionWhenAddrNotInVersionTable() throws Exception {
620+
String brokerName = "broker-a";
621+
String brokerAddr = "127.0.0.1:10911";
622+
623+
// Populate brokerAddrTable
624+
ConcurrentMap<String, HashMap<Long, String>> brokerAddrTable =
625+
(ConcurrentMap<String, HashMap<Long, String>>) FieldUtils.readDeclaredField(
626+
mqClientInstance, "brokerAddrTable", true);
627+
HashMap<Long, String> addrMap = new HashMap<>();
628+
addrMap.put(MixAll.MASTER_ID, brokerAddr);
629+
brokerAddrTable.put(brokerName, addrMap);
630+
631+
// Populate brokerVersionTable with broker name but different address
632+
ConcurrentMap<String, ConcurrentHashMap<String, Integer>> brokerVersionTable =
633+
(ConcurrentMap<String, ConcurrentHashMap<String, Integer>>) FieldUtils.readDeclaredField(
634+
mqClientInstance, "brokerVersionTable", true);
635+
ConcurrentHashMap<String, Integer> versionMap = new ConcurrentHashMap<>();
636+
versionMap.put("127.0.0.1:99999", 401); // different address
637+
brokerVersionTable.put(brokerName, versionMap);
638+
639+
FindBrokerResult result = mqClientInstance.findBrokerAddressInSubscribe(brokerName, MixAll.MASTER_ID, false);
640+
assertNotNull(result);
641+
// findBrokerVersion returns 0 when addr not found in the version map
642+
assertEquals(0, result.getBrokerVersion());
643+
644+
// Cleanup
645+
brokerAddrTable.remove(brokerName);
646+
brokerVersionTable.remove(brokerName);
647+
}
648+
649+
@Test
650+
public void testBrokerVersionTableComputeIfAbsent() throws Exception {
651+
ConcurrentMap<String, ConcurrentHashMap<String, Integer>> brokerVersionTable =
652+
(ConcurrentMap<String, ConcurrentHashMap<String, Integer>>) FieldUtils.readDeclaredField(
653+
mqClientInstance, "brokerVersionTable", true);
654+
655+
String brokerName = "broker-computeIfAbsent-test";
656+
657+
// Use computeIfAbsent as the fix does
658+
ConcurrentHashMap<String, Integer> inner = brokerVersionTable.computeIfAbsent(
659+
brokerName, k -> new ConcurrentHashMap<>(4));
660+
inner.put("127.0.0.1:10911", 401);
661+
662+
// Verify the entry was created atomically
663+
assertNotNull(brokerVersionTable.get(brokerName));
664+
assertEquals(Integer.valueOf(401), brokerVersionTable.get(brokerName).get("127.0.0.1:10911"));
665+
666+
// Calling computeIfAbsent again should return the same map, not create a new one
667+
ConcurrentHashMap<String, Integer> inner2 = brokerVersionTable.computeIfAbsent(
668+
brokerName, k -> new ConcurrentHashMap<>(4));
669+
assertEquals(inner, inner2);
670+
assertEquals(Integer.valueOf(401), inner2.get("127.0.0.1:10911"));
671+
672+
// Cleanup
673+
brokerVersionTable.remove(brokerName);
674+
}
557675
}

0 commit comments

Comments
 (0)