From 26129f5b43a95555141670643585734bb813222c Mon Sep 17 00:00:00 2001 From: Maksim Kashapov Date: Wed, 19 Nov 2025 13:42:49 +0100 Subject: [PATCH 1/9] [broker-30] Implement delivering retained messages --- .../config/MqttBrokerSpringConfig.java | 6 +- .../tree/ConcurrentSubscriptionTree.java} | 10 +- .../tree/TopicFilterNode.java} | 50 +++++----- .../tree/TopicFilterTreeBase.java} | 12 +-- .../model/subscribtion/tree/package-info.java | 4 + .../mqtt/model/topic/TopicFilter.java | 4 +- .../tree/ConcurrentRetainedMessageTree.java | 31 ++++++ .../model/topic/tree/TopicMessageNode.java | 94 +++++++++++++++++++ .../model/topic/tree/TopicTreeTest.groovy | 19 ++-- .../service/PublishDeliveringService.java | 3 + .../impl/DefaultPublishDeliveringService.java | 14 ++- .../impl/InMemorySubscriptionService.java | 8 +- .../impl/SubscribeMqttInMessageHandler.java | 43 ++++++++- .../SubscribeMqttInMessageHandlerTest.groovy | 87 +++++++++++++++-- 14 files changed, 321 insertions(+), 64 deletions(-) rename model/src/main/java/javasabr/mqtt/model/{topic/tree/ConcurrentTopicTree.java => subscribtion/tree/ConcurrentSubscriptionTree.java} (86%) rename model/src/main/java/javasabr/mqtt/model/{topic/tree/TopicNode.java => subscribtion/tree/TopicFilterNode.java} (71%) rename model/src/main/java/javasabr/mqtt/model/{topic/tree/TopicTreeBase.java => subscribtion/tree/TopicFilterTreeBase.java} (93%) create mode 100644 model/src/main/java/javasabr/mqtt/model/subscribtion/tree/package-info.java create mode 100644 model/src/main/java/javasabr/mqtt/model/topic/tree/ConcurrentRetainedMessageTree.java create mode 100644 model/src/main/java/javasabr/mqtt/model/topic/tree/TopicMessageNode.java diff --git a/application/src/main/java/javasabr/mqtt/broker/application/config/MqttBrokerSpringConfig.java b/application/src/main/java/javasabr/mqtt/broker/application/config/MqttBrokerSpringConfig.java index ddf6f743..b809f587 100644 --- a/application/src/main/java/javasabr/mqtt/broker/application/config/MqttBrokerSpringConfig.java +++ b/application/src/main/java/javasabr/mqtt/broker/application/config/MqttBrokerSpringConfig.java @@ -5,6 +5,7 @@ import javasabr.mqtt.model.MqttProperties; import javasabr.mqtt.model.MqttServerConnectionConfig; import javasabr.mqtt.model.QoS; +import javasabr.mqtt.model.topic.tree.ConcurrentRetainedMessageTree; import javasabr.mqtt.network.MqttClientFactory; import javasabr.mqtt.network.MqttConnection; import javasabr.mqtt.network.MqttConnectionFactory; @@ -178,8 +179,9 @@ MqttInMessageHandler disconnectMqttInMessageHandler(MessageOutFactoryService mes MqttInMessageHandler subscribeMqttInMessageHandler( SubscriptionService subscriptionService, MessageOutFactoryService messageOutFactoryService, - TopicService topicService) { - return new SubscribeMqttInMessageHandler(subscriptionService, messageOutFactoryService, topicService); + TopicService topicService, + PublishDeliveringService publishDeliveringService) { + return new SubscribeMqttInMessageHandler(subscriptionService, messageOutFactoryService, topicService, publishDeliveringService); } @Bean diff --git a/model/src/main/java/javasabr/mqtt/model/topic/tree/ConcurrentTopicTree.java b/model/src/main/java/javasabr/mqtt/model/subscribtion/tree/ConcurrentSubscriptionTree.java similarity index 86% rename from model/src/main/java/javasabr/mqtt/model/topic/tree/ConcurrentTopicTree.java rename to model/src/main/java/javasabr/mqtt/model/subscribtion/tree/ConcurrentSubscriptionTree.java index b095419a..f55a5548 100644 --- a/model/src/main/java/javasabr/mqtt/model/topic/tree/ConcurrentTopicTree.java +++ b/model/src/main/java/javasabr/mqtt/model/subscribtion/tree/ConcurrentSubscriptionTree.java @@ -1,4 +1,4 @@ -package javasabr.mqtt.model.topic.tree; +package javasabr.mqtt.model.subscribtion.tree; import javasabr.mqtt.model.subscriber.SingleSubscriber; import javasabr.mqtt.model.subscribtion.Subscription; @@ -13,12 +13,12 @@ import org.jspecify.annotations.Nullable; @FieldDefaults(level = AccessLevel.PRIVATE, makeFinal = true) -public class ConcurrentTopicTree implements ThreadSafe { +public class ConcurrentSubscriptionTree implements ThreadSafe { - TopicNode rootNode; + TopicFilterNode rootNode; - public ConcurrentTopicTree() { - this.rootNode = new TopicNode(); + public ConcurrentSubscriptionTree() { + this.rootNode = new TopicFilterNode(); } @Nullable diff --git a/model/src/main/java/javasabr/mqtt/model/topic/tree/TopicNode.java b/model/src/main/java/javasabr/mqtt/model/subscribtion/tree/TopicFilterNode.java similarity index 71% rename from model/src/main/java/javasabr/mqtt/model/topic/tree/TopicNode.java rename to model/src/main/java/javasabr/mqtt/model/subscribtion/tree/TopicFilterNode.java index 03a624b1..431e1d87 100644 --- a/model/src/main/java/javasabr/mqtt/model/topic/tree/TopicNode.java +++ b/model/src/main/java/javasabr/mqtt/model/subscribtion/tree/TopicFilterNode.java @@ -1,4 +1,4 @@ -package javasabr.mqtt.model.topic.tree; +package javasabr.mqtt.model.subscribtion.tree; import java.util.function.Supplier; import javasabr.mqtt.base.util.DebugUtils; @@ -22,16 +22,16 @@ @Getter(AccessLevel.PACKAGE) @Accessors(fluent = true, chain = false) @FieldDefaults(level = AccessLevel.PRIVATE) -class TopicNode extends TopicTreeBase { +class TopicFilterNode extends TopicFilterTreeBase { - private final static Supplier TOPIC_NODE_FACTORY = TopicNode::new; + private final static Supplier TOPIC_NODE_FACTORY = TopicFilterNode::new; static { DebugUtils.registerIncludedFields("childNodes", "subscribers"); } @Nullable - volatile LockableRefToRefDictionary childNodes; + volatile LockableRefToRefDictionary childNodes; @Nullable volatile LockableArray subscribers; @@ -43,7 +43,7 @@ public SingleSubscriber subscribe(int level, SubscriptionOwner owner, Subscripti if (level == topicFilter.levelsCount()) { return addSubscriber(getOrCreateSubscribers(), owner, subscription, topicFilter); } - TopicNode childNode = getOrCreateChildNode(topicFilter.segment(level)); + TopicFilterNode childNode = getOrCreateChildNode(topicFilter.segment(level)); return childNode.subscribe(level + 1, owner, subscription, topicFilter); } @@ -51,7 +51,7 @@ public boolean unsubscribe(int level, SubscriptionOwner owner, TopicFilter topic if (level == topicFilter.levelsCount()) { return removeSubscriber(subscribers(), owner, topicFilter); } - TopicNode childNode = getOrCreateChildNode(topicFilter.segment(level)); + TopicFilterNode childNode = getOrCreateChildNode(topicFilter.segment(level)); return childNode.unsubscribe(level + 1, owner, topicFilter); } @@ -67,14 +67,14 @@ private void exactlyTopicMatch( int lastLevel, MutableArray result) { String segment = topicName.segment(level); - TopicNode topicNode = childNode(segment); - if (topicNode == null) { + TopicFilterNode topicFilterNode = childNode(segment); + if (topicFilterNode == null) { return; } if (level == lastLevel) { - appendSubscribersTo(result, topicNode); + appendSubscribersTo(result, topicFilterNode); } else if (level < lastLevel) { - topicNode.matchesTo(level + 1, topicName, lastLevel, result); + topicFilterNode.matchesTo(level + 1, topicName, lastLevel, result); } } @@ -83,31 +83,31 @@ private void singleWildcardTopicMatch( TopicName topicName, int lastLevel, MutableArray result) { - TopicNode topicNode = childNode(TopicFilter.SINGLE_LEVEL_WILDCARD); - if (topicNode == null) { + TopicFilterNode topicFilterNode = childNode(TopicFilter.SINGLE_LEVEL_WILDCARD); + if (topicFilterNode == null) { return; } if (level == lastLevel) { - appendSubscribersTo(result, topicNode); + appendSubscribersTo(result, topicFilterNode); } else if (level < lastLevel) { - topicNode.matchesTo(level + 1, topicName, lastLevel, result); + topicFilterNode.matchesTo(level + 1, topicName, lastLevel, result); } } private void multiWildcardTopicMatch(MutableArray result) { - TopicNode topicNode = childNode(TopicFilter.MULTI_LEVEL_WILDCARD); - if (topicNode != null) { - appendSubscribersTo(result, topicNode); + TopicFilterNode topicFilterNode = childNode(TopicFilter.MULTI_LEVEL_WILDCARD); + if (topicFilterNode != null) { + appendSubscribersTo(result, topicFilterNode); } } - private TopicNode getOrCreateChildNode(String segment) { - LockableRefToRefDictionary childNodes = getOrCreateChildNodes(); + private TopicFilterNode getOrCreateChildNode(String segment) { + LockableRefToRefDictionary childNodes = getOrCreateChildNodes(); long stamp = childNodes.readLock(); try { - TopicNode topicNode = childNodes.get(segment); - if (topicNode != null) { - return topicNode; + TopicFilterNode topicFilterNode = childNodes.get(segment); + if (topicFilterNode != null) { + return topicFilterNode; } } finally { childNodes.readUnlock(stamp); @@ -122,8 +122,8 @@ private TopicNode getOrCreateChildNode(String segment) { } @Nullable - private TopicNode childNode(String segment) { - LockableRefToRefDictionary childNodes = childNodes(); + private TopicFilterNode childNode(String segment) { + LockableRefToRefDictionary childNodes = childNodes(); if (childNodes == null) { return null; } @@ -135,7 +135,7 @@ private TopicNode childNode(String segment) { } } - private LockableRefToRefDictionary getOrCreateChildNodes() { + private LockableRefToRefDictionary getOrCreateChildNodes() { if (childNodes == null) { synchronized (this) { if (childNodes == null) { diff --git a/model/src/main/java/javasabr/mqtt/model/topic/tree/TopicTreeBase.java b/model/src/main/java/javasabr/mqtt/model/subscribtion/tree/TopicFilterTreeBase.java similarity index 93% rename from model/src/main/java/javasabr/mqtt/model/topic/tree/TopicTreeBase.java rename to model/src/main/java/javasabr/mqtt/model/subscribtion/tree/TopicFilterTreeBase.java index 7f0dbb23..f035052c 100644 --- a/model/src/main/java/javasabr/mqtt/model/topic/tree/TopicTreeBase.java +++ b/model/src/main/java/javasabr/mqtt/model/subscribtion/tree/TopicFilterTreeBase.java @@ -1,4 +1,4 @@ -package javasabr.mqtt.model.topic.tree; +package javasabr.mqtt.model.subscribtion.tree; import java.util.Objects; import javasabr.mqtt.model.QoS; @@ -18,7 +18,7 @@ @RequiredArgsConstructor @FieldDefaults(level = AccessLevel.PROTECTED, makeFinal = true) -abstract class TopicTreeBase { +abstract class TopicFilterTreeBase { /** * @return previous subscriber with the same owner @@ -66,7 +66,7 @@ private static void addSharedSubscriber( String group = sharedTopicFilter.shareName(); SharedSubscriber sharedSubscriber = (SharedSubscriber) subscribers .iterations() - .findAny(group, TopicTreeBase::isSharedSubscriberWithGroup); + .findAny(group, TopicFilterTreeBase::isSharedSubscriberWithGroup); if (sharedSubscriber == null) { sharedSubscriber = new SharedSubscriber(sharedTopicFilter); @@ -76,8 +76,8 @@ private static void addSharedSubscriber( sharedSubscriber.addSubscriber(new SingleSubscriber(owner, subscription)); } - protected static void appendSubscribersTo(MutableArray result, TopicNode topicNode) { - LockableArray subscribers = topicNode.subscribers(); + protected static void appendSubscribersTo(MutableArray result, TopicFilterNode topicFilterNode) { + LockableArray subscribers = topicFilterNode.subscribers(); if (subscribers == null) { return; } @@ -125,7 +125,7 @@ private static boolean removeSharedSubscriber( String group = sharedTopicFilter.shareName(); SharedSubscriber sharedSubscriber = (SharedSubscriber) subscribers .iterations() - .findAny(group, TopicTreeBase::isSharedSubscriberWithGroup); + .findAny(group, TopicFilterTreeBase::isSharedSubscriberWithGroup); if (sharedSubscriber != null) { boolean removed = sharedSubscriber.removeSubscriberWithOwner(owner); if (sharedSubscriber.isEmpty()) { diff --git a/model/src/main/java/javasabr/mqtt/model/subscribtion/tree/package-info.java b/model/src/main/java/javasabr/mqtt/model/subscribtion/tree/package-info.java new file mode 100644 index 00000000..692a51cf --- /dev/null +++ b/model/src/main/java/javasabr/mqtt/model/subscribtion/tree/package-info.java @@ -0,0 +1,4 @@ +@NullMarked +package javasabr.mqtt.model.subscribtion.tree; + +import org.jspecify.annotations.NullMarked; \ No newline at end of file diff --git a/model/src/main/java/javasabr/mqtt/model/topic/TopicFilter.java b/model/src/main/java/javasabr/mqtt/model/topic/TopicFilter.java index fa9c1b6f..da653f8a 100644 --- a/model/src/main/java/javasabr/mqtt/model/topic/TopicFilter.java +++ b/model/src/main/java/javasabr/mqtt/model/topic/TopicFilter.java @@ -11,9 +11,9 @@ public class TopicFilter extends AbstractTopic { public static final String MULTI_LEVEL_WILDCARD = "#"; - public static final char MULTI_LEVEL_WILDCARD_CHAR = '#'; + public static final char MULTI_LEVEL_WILDCARD_CHAR = MULTI_LEVEL_WILDCARD.charAt(0); public static final String SINGLE_LEVEL_WILDCARD = "+"; - public static final char SINGLE_LEVEL_WILDCARD_CHAR = '+'; + public static final char SINGLE_LEVEL_WILDCARD_CHAR = SINGLE_LEVEL_WILDCARD.charAt(0); public static final String SPECIAL = "$"; public static final TopicFilter INVALID_TOPIC_FILTER = new TopicFilter("$invalid$") { diff --git a/model/src/main/java/javasabr/mqtt/model/topic/tree/ConcurrentRetainedMessageTree.java b/model/src/main/java/javasabr/mqtt/model/topic/tree/ConcurrentRetainedMessageTree.java new file mode 100644 index 00000000..a5bed489 --- /dev/null +++ b/model/src/main/java/javasabr/mqtt/model/topic/tree/ConcurrentRetainedMessageTree.java @@ -0,0 +1,31 @@ +package javasabr.mqtt.model.topic.tree; + +import javasabr.mqtt.model.publishing.Publish; +import javasabr.mqtt.model.topic.TopicFilter; +import javasabr.mqtt.model.topic.TopicName; +import javasabr.rlib.common.ThreadSafe; +import lombok.AccessLevel; +import lombok.experimental.FieldDefaults; +import org.jspecify.annotations.Nullable; + +@FieldDefaults(level = AccessLevel.PRIVATE, makeFinal = true) +public class ConcurrentRetainedMessageTree implements ThreadSafe { + + TopicMessageNode rootNode; + + public ConcurrentRetainedMessageTree() { + this.rootNode = new TopicMessageNode(); + } + + public void retainMessage(Publish message) { + rootNode.retainMessage(0, message, message.topicName()); + } + + public @Nullable Publish getRetainedMessage(TopicName topicName) { + return rootNode.getRetainedMessage(0, topicName); + } + + public @Nullable Publish getRetainedMessage(TopicFilter topicFilter) { + return rootNode.getRetainedMessage(0, topicFilter); + } +} diff --git a/model/src/main/java/javasabr/mqtt/model/topic/tree/TopicMessageNode.java b/model/src/main/java/javasabr/mqtt/model/topic/tree/TopicMessageNode.java new file mode 100644 index 00000000..55cad0ec --- /dev/null +++ b/model/src/main/java/javasabr/mqtt/model/topic/tree/TopicMessageNode.java @@ -0,0 +1,94 @@ +package javasabr.mqtt.model.topic.tree; + +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; +import javasabr.mqtt.base.util.DebugUtils; +import javasabr.mqtt.model.publishing.Publish; +import javasabr.mqtt.model.topic.TopicFilter; +import javasabr.mqtt.model.topic.TopicName; +import javasabr.rlib.collections.dictionary.DictionaryFactory; +import javasabr.rlib.collections.dictionary.LockableRefToRefDictionary; +import lombok.AccessLevel; +import lombok.Getter; +import lombok.experimental.Accessors; +import lombok.experimental.FieldDefaults; +import org.jspecify.annotations.Nullable; + +@Getter(AccessLevel.PACKAGE) +@Accessors(fluent = true, chain = false) +@FieldDefaults(level = AccessLevel.PRIVATE) +class TopicMessageNode { + + private final static Supplier TOPIC_NODE_FACTORY = TopicMessageNode::new; + + static { + DebugUtils.registerIncludedFields("childNodes", "retainedMessage"); + } + + @Nullable + volatile LockableRefToRefDictionary childNodes; + final AtomicReference<@Nullable Publish> retainedMessage = new AtomicReference<>(); + + public void retainMessage(int level, Publish message, TopicName topicFilter) { + if (level + 1 == topicFilter.levelsCount()) { + retainedMessage.set(message); + return; + } + TopicMessageNode childNode = getOrCreateChildNode(topicFilter.segment(level)); + childNode.retainMessage(level + 1, message, topicFilter); + } + + @Nullable + public Publish getRetainedMessage(int level, TopicName topicName) { + if (level + 1 == topicName.levelsCount()) { + return retainedMessage.get(); + } + TopicMessageNode childNode = getOrCreateChildNode(topicName.segment(level)); + return childNode.getRetainedMessage(level + 1, topicName); + } + + @Nullable + public Publish getRetainedMessage(int level, TopicFilter topicName) { + if (level + 1 == topicName.levelsCount()) { + return retainedMessage.get(); + } + TopicMessageNode childNode = getOrCreateChildNode(topicName.segment(level)); + return childNode.getRetainedMessage(level + 1, topicName); + } + + private TopicMessageNode getOrCreateChildNode(String segment) { + LockableRefToRefDictionary childNodes = getOrCreateChildNodes(); + long stamp = childNodes.readLock(); + try { + TopicMessageNode topicFilterNode = childNodes.get(segment); + if (topicFilterNode != null) { + return topicFilterNode; + } + } finally { + childNodes.readUnlock(stamp); + } + stamp = childNodes.writeLock(); + try { + return childNodes.getOrCompute(segment, TOPIC_NODE_FACTORY); + } finally { + childNodes.writeUnlock(stamp); + } + } + + private LockableRefToRefDictionary getOrCreateChildNodes() { + if (childNodes == null) { + synchronized (this) { + if (childNodes == null) { + childNodes = DictionaryFactory.stampedLockBasedRefToRefDictionary(); + } + } + } + //noinspection ConstantConditions + return childNodes; + } + + @Override + public String toString() { + return DebugUtils.toJsonString(this); + } +} diff --git a/model/src/test/groovy/javasabr/mqtt/model/topic/tree/TopicTreeTest.groovy b/model/src/test/groovy/javasabr/mqtt/model/topic/tree/TopicTreeTest.groovy index fe5aa51f..1b220202 100644 --- a/model/src/test/groovy/javasabr/mqtt/model/topic/tree/TopicTreeTest.groovy +++ b/model/src/test/groovy/javasabr/mqtt/model/topic/tree/TopicTreeTest.groovy @@ -6,6 +6,7 @@ import javasabr.mqtt.model.SubscribeRetainHandling import javasabr.mqtt.model.subscriber.SingleSubscriber import javasabr.mqtt.model.subscribtion.Subscription import javasabr.mqtt.model.subscribtion.SubscriptionOwner +import javasabr.mqtt.model.subscribtion.tree.ConcurrentSubscriptionTree import javasabr.mqtt.model.subscription.TestSubscriptionOwner import javasabr.mqtt.model.topic.SharedTopicFilter import javasabr.mqtt.model.topic.TopicFilter @@ -20,7 +21,7 @@ class TopicTreeTest extends UnitSpecification { String topicName, List expectedOwners) { given: - ConcurrentTopicTree topicTree = new ConcurrentTopicTree() + ConcurrentSubscriptionTree topicTree = new ConcurrentSubscriptionTree() subscriptions.eachWithIndex { Subscription subscription, int i -> topicTree.subscribe(owners.get(i), subscription) } @@ -108,7 +109,7 @@ class TopicTreeTest extends UnitSpecification { String topicName, List expectedOwners) { given: - ConcurrentTopicTree topicTree = new ConcurrentTopicTree() + ConcurrentSubscriptionTree topicTree = new ConcurrentSubscriptionTree() subscriptions.eachWithIndex { Subscription subscription, int i -> topicTree.subscribe(owners.get(i), subscription) } @@ -213,7 +214,7 @@ class TopicTreeTest extends UnitSpecification { String topicName, List expectedOwners) { given: - ConcurrentTopicTree topicTree = new ConcurrentTopicTree() + ConcurrentSubscriptionTree topicTree = new ConcurrentSubscriptionTree() subscriptions.eachWithIndex { Subscription subscription, int i -> topicTree.subscribe(owners.get(i), subscription) } @@ -327,7 +328,7 @@ class TopicTreeTest extends UnitSpecification { String topicName, List expectedSubscribers) { given: - ConcurrentTopicTree topicTree = new ConcurrentTopicTree() + ConcurrentSubscriptionTree topicTree = new ConcurrentSubscriptionTree() subscriptions.eachWithIndex { Subscription subscription, int i -> topicTree.subscribe(owners.get(i), subscription) } @@ -434,7 +435,7 @@ class TopicTreeTest extends UnitSpecification { given: def group1 = ["id1", "id2", "id3", "id4", "id5"] def group2 = ["id6", "id7", "id8", "id9", "id10"] - ConcurrentTopicTree topicTree = new ConcurrentTopicTree() + ConcurrentSubscriptionTree topicTree = new ConcurrentSubscriptionTree() topicTree.subscribe(makeOwner("id1"), makeSharedSubscription('$share/group1/topic/name1')) topicTree.subscribe(makeOwner("id2"), makeSharedSubscription('$share/group1/topic/name1')) topicTree.subscribe(makeOwner("id3"), makeSharedSubscription('$share/group1/topic/name1')) @@ -467,7 +468,7 @@ class TopicTreeTest extends UnitSpecification { def "should subscribe and unsubscribe simple topic correctly correctly"() { given: - ConcurrentTopicTree topicTree = new ConcurrentTopicTree() + ConcurrentSubscriptionTree topicTree = new ConcurrentSubscriptionTree() topicTree.subscribe(makeOwner("id1"), makeSubscription('topic/name1')) topicTree.subscribe(makeOwner("id2"), makeSubscription('topic/name1')) topicTree.subscribe(makeOwner("id3"), makeSubscription('topic/name1')) @@ -504,7 +505,7 @@ class TopicTreeTest extends UnitSpecification { def "should subscribe and unsubscribe shared topic correctly correctly"() { given: - ConcurrentTopicTree topicTree = new ConcurrentTopicTree() + ConcurrentSubscriptionTree topicTree = new ConcurrentSubscriptionTree() topicTree.subscribe(makeOwner("id1"), makeSharedSubscription('$share/group1/topic/name1')) topicTree.subscribe(makeOwner("id2"), makeSharedSubscription('$share/group1/topic/name1')) topicTree.subscribe(makeOwner("id3"), makeSharedSubscription('$share/group1/topic/name1')) @@ -541,7 +542,7 @@ class TopicTreeTest extends UnitSpecification { def "should replace the same subscriptions"() { given: - ConcurrentTopicTree topicTree = new ConcurrentTopicTree() + ConcurrentSubscriptionTree topicTree = new ConcurrentSubscriptionTree() def owner1 = makeOwner("id1") def originalSub = makeSubscription('topic/name1') def replacementSub = makeSubscription('topic/name1') @@ -569,7 +570,7 @@ class TopicTreeTest extends UnitSpecification { def "should extend shared subscription group on multiply subscribing by the same topic"() { given: - ConcurrentTopicTree topicTree = new ConcurrentTopicTree() + ConcurrentSubscriptionTree topicTree = new ConcurrentSubscriptionTree() def owner1 = makeOwner("id1") def owner2 = makeOwner("id2") topicTree.subscribe(owner1, makeSharedSubscription('$share/group1/topic/name1')) diff --git a/service/src/main/java/javasabr/mqtt/service/PublishDeliveringService.java b/service/src/main/java/javasabr/mqtt/service/PublishDeliveringService.java index 8110b7bf..f051d8a3 100644 --- a/service/src/main/java/javasabr/mqtt/service/PublishDeliveringService.java +++ b/service/src/main/java/javasabr/mqtt/service/PublishDeliveringService.java @@ -2,9 +2,12 @@ import javasabr.mqtt.model.publishing.Publish; import javasabr.mqtt.model.subscriber.SingleSubscriber; +import javasabr.mqtt.model.topic.TopicFilter; import javasabr.mqtt.service.publish.handler.PublishHandlingResult; public interface PublishDeliveringService { PublishHandlingResult startDelivering(Publish publish, SingleSubscriber subscriber); + + PublishHandlingResult deliverRetainedMessages(TopicFilter topicFilter, SingleSubscriber subscriber); } diff --git a/service/src/main/java/javasabr/mqtt/service/impl/DefaultPublishDeliveringService.java b/service/src/main/java/javasabr/mqtt/service/impl/DefaultPublishDeliveringService.java index 8e0a7529..c66d3f49 100644 --- a/service/src/main/java/javasabr/mqtt/service/impl/DefaultPublishDeliveringService.java +++ b/service/src/main/java/javasabr/mqtt/service/impl/DefaultPublishDeliveringService.java @@ -4,6 +4,8 @@ import javasabr.mqtt.model.QoS; import javasabr.mqtt.model.publishing.Publish; import javasabr.mqtt.model.subscriber.SingleSubscriber; +import javasabr.mqtt.model.topic.TopicFilter; +import javasabr.mqtt.model.topic.tree.ConcurrentRetainedMessageTree; import javasabr.mqtt.service.PublishDeliveringService; import javasabr.mqtt.service.publish.handler.MqttPublishOutMessageHandler; import javasabr.mqtt.service.publish.handler.PublishHandlingResult; @@ -18,6 +20,7 @@ public class DefaultPublishDeliveringService implements PublishDeliveringService @Nullable MqttPublishOutMessageHandler[] publishOutMessageHandlers; + ConcurrentRetainedMessageTree topicTree; public DefaultPublishDeliveringService( Collection knownPublishOutHandlers) { @@ -39,7 +42,7 @@ public DefaultPublishDeliveringService( } handlers[qos.level()] = knownPublishOutHandler; } - + this.topicTree = new ConcurrentRetainedMessageTree(); this.publishOutMessageHandlers = handlers; log.info(publishOutMessageHandlers, DefaultPublishDeliveringService::buildServiceDescription); } @@ -47,6 +50,9 @@ public DefaultPublishDeliveringService( @Override public PublishHandlingResult startDelivering(Publish publish, SingleSubscriber subscriber) { try { + if (publish.retained()) { + topicTree.retainMessage(publish); + } //noinspection DataFlowIssue return publishOutMessageHandlers[subscriber.qos().level()].handle(publish, subscriber); } catch (IndexOutOfBoundsException | NullPointerException ex) { @@ -55,6 +61,12 @@ public PublishHandlingResult startDelivering(Publish publish, SingleSubscriber s } } + @Override + public PublishHandlingResult deliverRetainedMessages(TopicFilter topicFilter, SingleSubscriber subscriber) { + Publish retainedMessage = topicTree.getRetainedMessage(topicFilter); + return startDelivering(retainedMessage, subscriber); + } + private static String buildServiceDescription( @Nullable MqttPublishOutMessageHandler[] publishOutMessageHandlers) { var builder = new StringBuilder(); diff --git a/service/src/main/java/javasabr/mqtt/service/impl/InMemorySubscriptionService.java b/service/src/main/java/javasabr/mqtt/service/impl/InMemorySubscriptionService.java index cb1d44ff..3cb1aa4b 100644 --- a/service/src/main/java/javasabr/mqtt/service/impl/InMemorySubscriptionService.java +++ b/service/src/main/java/javasabr/mqtt/service/impl/InMemorySubscriptionService.java @@ -12,7 +12,7 @@ import javasabr.mqtt.model.topic.SharedTopicFilter; import javasabr.mqtt.model.topic.TopicFilter; import javasabr.mqtt.model.topic.TopicName; -import javasabr.mqtt.model.topic.tree.ConcurrentTopicTree; +import javasabr.mqtt.model.subscribtion.tree.ConcurrentSubscriptionTree; import javasabr.mqtt.network.MqttClient; import javasabr.mqtt.network.session.ActiveSubscriptions; import javasabr.mqtt.network.session.MqttSession; @@ -25,16 +25,16 @@ import lombok.experimental.FieldDefaults; /** - * In memory subscription service based on {@link ConcurrentTopicTree} + * In memory subscription service based on {@link ConcurrentSubscriptionTree} */ @CustomLog @FieldDefaults(level = AccessLevel.PRIVATE, makeFinal = true) public class InMemorySubscriptionService implements SubscriptionService { - ConcurrentTopicTree topicTree; + ConcurrentSubscriptionTree topicTree; public InMemorySubscriptionService() { - this.topicTree = new ConcurrentTopicTree(); + this.topicTree = new ConcurrentSubscriptionTree(); } @Override diff --git a/service/src/main/java/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandler.java b/service/src/main/java/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandler.java index 9e5528d7..949bdf48 100644 --- a/service/src/main/java/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandler.java +++ b/service/src/main/java/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandler.java @@ -9,6 +9,7 @@ import javasabr.mqtt.model.QoS; import javasabr.mqtt.model.reason.code.DisconnectReasonCode; import javasabr.mqtt.model.reason.code.SubscribeAckReasonCode; +import javasabr.mqtt.model.subscriber.SingleSubscriber; import javasabr.mqtt.model.subscribtion.RequestedSubscription; import javasabr.mqtt.model.subscribtion.Subscription; import javasabr.mqtt.model.topic.TopicFilter; @@ -20,8 +21,10 @@ import javasabr.mqtt.network.session.MessageTacker; import javasabr.mqtt.network.session.MqttSession; import javasabr.mqtt.service.MessageOutFactoryService; +import javasabr.mqtt.service.PublishDeliveringService; import javasabr.mqtt.service.SubscriptionService; import javasabr.mqtt.service.TopicService; +import javasabr.mqtt.service.publish.handler.PublishHandlingResult; import javasabr.rlib.collections.array.Array; import javasabr.rlib.collections.array.ArrayFactory; import javasabr.rlib.collections.array.MutableArray; @@ -38,16 +41,19 @@ public class SubscribeMqttInMessageHandler extends SHARED_SUBSCRIPTIONS_NOT_SUPPORTED, WILDCARD_SUBSCRIPTIONS_NOT_SUPPORTED); + PublishDeliveringService publishDeliveringService; SubscriptionService subscriptionService; TopicService topicService; public SubscribeMqttInMessageHandler( SubscriptionService subscriptionService, MessageOutFactoryService messageOutFactoryService, - TopicService topicService) { + TopicService topicService, + PublishDeliveringService publishDeliveringService) { super(ExternalMqttClient.class, SubscribeMqttInMessage.class, messageOutFactoryService); this.subscriptionService = subscriptionService; this.topicService = topicService; + this.publishDeliveringService = publishDeliveringService; } @Override @@ -92,6 +98,7 @@ protected void processValidMessage( .subscribe(client, session, subscriptions); sendSubscribeResults(client, session, subscribeMessage, subscribeResults); + sendRetainedMessages(client, subscribeMessage, subscribeResults, subscriptions); SubscribeAckReasonCode anyReasonToDisconnect = subscribeResults .iterations() @@ -174,4 +181,38 @@ private void sendSubscribeResults( .inMessageTracker() .remove(messageId)); } + + private void sendRetainedMessages( + ExternalMqttClient client, + SubscribeMqttInMessage subscribeMessage, + Array subscribeResults, + Array subs) { + int count = 0; + PublishHandlingResult errorResult = null; + Array subscriptions = subscribeMessage.subscriptions(); + for (int i = 0; i < subscribeMessage.subscriptionsCount(); i++) { + RequestedSubscription requestedSubscription = subscriptions.get(i); + SubscribeAckReasonCode subscribeAckReasonCode = subscribeResults.get(i); + Subscription subscription = subs.get(i); + if (subscribeAckReasonCode.ordinal() < 3) { + TopicFilter topicFilter = TopicFilter.valueOf(requestedSubscription.rawTopicFilter()); + SingleSubscriber singleSubscriber = new SingleSubscriber(client, subscription); + PublishHandlingResult result = publishDeliveringService.deliverRetainedMessages(topicFilter, singleSubscriber); + if (result.error()) { + errorResult = result; + } else if(result == PublishHandlingResult.SUCCESS) { + count++; + } + if (errorResult != null) { + log.debug(client.clientId(), errorResult, + "[%s] Found final error:[%s] during sending retained messages"::formatted); + // handleError(client, publish, errorResult); + } else { + log.debug(client.clientId(), count, + "[%s] Successfully started delivering retained messages to [%s] subscribers"::formatted); + // handleSuccessfulResult(client, publish, count); + } + } + } + } } diff --git a/service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandlerTest.groovy b/service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandlerTest.groovy index c740fdc3..1ab9d29e 100644 --- a/service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandlerTest.groovy +++ b/service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandlerTest.groovy @@ -1,22 +1,32 @@ package javasabr.mqtt.service.message.handler.impl import javasabr.mqtt.model.MqttVersion +import javasabr.mqtt.model.PayloadFormat import javasabr.mqtt.model.QoS +import javasabr.mqtt.model.SubscribeRetainHandling +import javasabr.mqtt.model.publishing.Publish import javasabr.mqtt.model.reason.code.DisconnectReasonCode import javasabr.mqtt.model.reason.code.SubscribeAckReasonCode +import javasabr.mqtt.model.subscriber.SingleSubscriber import javasabr.mqtt.model.subscribtion.RequestedSubscription +import javasabr.mqtt.model.subscribtion.Subscription +import javasabr.mqtt.model.topic.TopicName import javasabr.mqtt.network.message.in.SubscribeMqttInMessage import javasabr.mqtt.network.message.out.DisconnectMqtt5OutMessage +import javasabr.mqtt.network.message.out.PublishMqtt5OutMessage import javasabr.mqtt.network.message.out.SubscribeAckMqtt5OutMessage import javasabr.mqtt.network.util.ExtraErrorReasons import javasabr.mqtt.service.IntegrationServiceSpecification import javasabr.mqtt.service.TestExternalMqttClient import javasabr.rlib.collections.array.Array +import javasabr.rlib.collections.array.IntArray import javasabr.rlib.collections.array.MutableArray import javasabr.rlib.common.util.ThreadUtils import javasabr.rlib.logger.api.LoggerLevel import javasabr.rlib.logger.api.LoggerManager +import static java.nio.charset.StandardCharsets.UTF_8 + class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification { static { @@ -30,7 +40,8 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def messageHandler = new SubscribeMqttInMessageHandler( defaultSubscriptionService, defaultMessageOutFactoryService, - defaultTopicService) + defaultTopicService, + publishDeliveringService) def mqttClient = mqttConnection.client() as TestExternalMqttClient mqttClient.session(null) when: @@ -49,7 +60,8 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def messageHandler = new SubscribeMqttInMessageHandler( defaultSubscriptionService, defaultMessageOutFactoryService, - defaultTopicService) + defaultTopicService, + publishDeliveringService) def expectedMessageId = 15 def mqttClient = mqttConnection.client() as TestExternalMqttClient def session = mqttClient.session() @@ -80,7 +92,8 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def messageHandler = new SubscribeMqttInMessageHandler( defaultSubscriptionService, defaultMessageOutFactoryService, - defaultTopicService) + defaultTopicService, + publishDeliveringService) def expectedMessageId = 15 def mqttClient = mqttConnection.client() as TestExternalMqttClient when: @@ -110,7 +123,8 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def messageHandler = new SubscribeMqttInMessageHandler( defaultSubscriptionService, defaultMessageOutFactoryService, - defaultTopicService) + defaultTopicService, + publishDeliveringService) def expectedMessageId = 15 def mqttClient = mqttConnection.client() as TestExternalMqttClient when: @@ -140,7 +154,8 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def messageHandler = new SubscribeMqttInMessageHandler( defaultSubscriptionService, defaultMessageOutFactoryService, - defaultTopicService) + defaultTopicService, + publishDeliveringService) def expectedMessageId = 15 def mqttClient = mqttConnection.client() as TestExternalMqttClient when: @@ -173,7 +188,8 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def messageHandler = new SubscribeMqttInMessageHandler( defaultSubscriptionService, defaultMessageOutFactoryService, - defaultTopicService) + defaultTopicService, + publishDeliveringService) def expectedMessageId = 15 def mqttClient = mqttConnection.client() as TestExternalMqttClient when: @@ -204,7 +220,8 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def messageHandler = new SubscribeMqttInMessageHandler( defaultSubscriptionService, defaultMessageOutFactoryService, - defaultTopicService) + defaultTopicService, + publishDeliveringService) def mqttClient = mqttConnection.client() as TestExternalMqttClient when: def subscribeMessage = new SubscribeMqttInMessage(0 as byte) @@ -222,7 +239,8 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def messageHandler = new SubscribeMqttInMessageHandler( defaultSubscriptionService, defaultMessageOutFactoryService, - defaultTopicService) + defaultTopicService, + publishDeliveringService) def expectedMessageId = 15 def mqttClient = mqttConnection.client() as TestExternalMqttClient when: @@ -260,7 +278,8 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def messageHandler = new SubscribeMqttInMessageHandler( defaultSubscriptionService, defaultMessageOutFactoryService, - defaultTopicService) + defaultTopicService, + publishDeliveringService) def expectedMessageId = 15 def mqttClient = mqttConnection.client() as TestExternalMqttClient mqttClient.returnCompletedFeatures(false) @@ -292,4 +311,54 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification && reasonCodes2.get(0) == SubscribeAckReasonCode.PACKET_IDENTIFIER_IN_USE && subscribeAck2.messageId() == expectedMessageId } + + def "should deliver retained messages"() { + given: + def mqttConnection = mockedExternalConnection(MqttVersion.MQTT_5) + def messageHandler = new SubscribeMqttInMessageHandler( + defaultSubscriptionService, + defaultMessageOutFactoryService, + defaultTopicService, + publishDeliveringService) + def expectedMessageId = 15 + def mqttClient = mqttConnection.client() as TestExternalMqttClient + mqttClient.returnCompletedFeatures(false) + when: + Publish publish = new Publish( + 1, + QoS.AT_MOST_ONCE, + TopicName.valueOf("topic2"), + null, + "payload".getBytes(UTF_8), + false, + true, + null, + IntArray.of(30), + null, + 60000, + 1, + PayloadFormat.UTF8_STRING, + Array.of()); + Subscription subscription = new Subscription( + defaultTopicService.createTopicFilter(mqttClient, "topic2"), + 30, + QoS.EXACTLY_ONCE, + SubscribeRetainHandling.SEND, + true, + true); + SingleSubscriber subscriber = new SingleSubscriber(mqttClient, subscription); + publishDeliveringService.startDelivering(publish, subscriber) + + def subscribeMessage = new SubscribeMqttInMessage(SubscribeMqttInMessage.MESSAGE_FLAGS) {{ + this.messageId = 1 + this.subscriptions = MutableArray.ofType(RequestedSubscription) + this.subscriptions.addAll(Array.of(RequestedSubscription.minimal("topic2", QoS.EXACTLY_ONCE))) + }} + messageHandler.processValidMessage(mqttConnection, subscribeMessage) + then: + mqttClient.nextSentMessage(PublishMqtt5OutMessage) + mqttClient.nextSentMessage(SubscribeAckMqtt5OutMessage) + def retainedMessageDelivery = mqttClient.nextSentMessage(PublishMqtt5OutMessage) + retainedMessageDelivery.messageId() == 2 + } } From c6ada4d461288fca6306978a485bc9402f4384d5 Mon Sep 17 00:00:00 2001 From: Maksim Kashapov Date: Wed, 19 Nov 2025 19:05:56 +0100 Subject: [PATCH 2/9] [broker-30] Rewrite retained messages collecting --- .../mqtt/model/topic/AbstractTopic.java | 2 +- .../tree/ConcurrentRetainedMessageTree.java | 22 +-- .../model/topic/tree/RetainedMessageNode.java | 150 ++++++++++++++++++ .../model/topic/tree/TopicMessageNode.java | 94 ----------- .../service/PublishDeliveringService.java | 3 +- .../impl/DefaultPublishDeliveringService.java | 20 ++- .../impl/SubscribeMqttInMessageHandler.java | 18 ++- 7 files changed, 187 insertions(+), 122 deletions(-) create mode 100644 model/src/main/java/javasabr/mqtt/model/topic/tree/RetainedMessageNode.java delete mode 100644 model/src/main/java/javasabr/mqtt/model/topic/tree/TopicMessageNode.java diff --git a/model/src/main/java/javasabr/mqtt/model/topic/AbstractTopic.java b/model/src/main/java/javasabr/mqtt/model/topic/AbstractTopic.java index 7ff9a93e..9cdb4370 100644 --- a/model/src/main/java/javasabr/mqtt/model/topic/AbstractTopic.java +++ b/model/src/main/java/javasabr/mqtt/model/topic/AbstractTopic.java @@ -41,7 +41,7 @@ public int levelsCount() { return segments.length; } - String lastSegment() { + public String lastSegment() { return segments[segments.length - 1]; } diff --git a/model/src/main/java/javasabr/mqtt/model/topic/tree/ConcurrentRetainedMessageTree.java b/model/src/main/java/javasabr/mqtt/model/topic/tree/ConcurrentRetainedMessageTree.java index a5bed489..78c78f11 100644 --- a/model/src/main/java/javasabr/mqtt/model/topic/tree/ConcurrentRetainedMessageTree.java +++ b/model/src/main/java/javasabr/mqtt/model/topic/tree/ConcurrentRetainedMessageTree.java @@ -2,30 +2,30 @@ import javasabr.mqtt.model.publishing.Publish; import javasabr.mqtt.model.topic.TopicFilter; -import javasabr.mqtt.model.topic.TopicName; +import javasabr.rlib.collections.array.Array; +import javasabr.rlib.collections.array.MutableArray; import javasabr.rlib.common.ThreadSafe; import lombok.AccessLevel; import lombok.experimental.FieldDefaults; -import org.jspecify.annotations.Nullable; @FieldDefaults(level = AccessLevel.PRIVATE, makeFinal = true) public class ConcurrentRetainedMessageTree implements ThreadSafe { - TopicMessageNode rootNode; + RetainedMessageNode rootNode; public ConcurrentRetainedMessageTree() { - this.rootNode = new TopicMessageNode(); + this.rootNode = new RetainedMessageNode(); } public void retainMessage(Publish message) { - rootNode.retainMessage(0, message, message.topicName()); + if (message.retained()) { + rootNode.retainMessage(0, message, message.topicName()); + } } - public @Nullable Publish getRetainedMessage(TopicName topicName) { - return rootNode.getRetainedMessage(0, topicName); - } - - public @Nullable Publish getRetainedMessage(TopicFilter topicFilter) { - return rootNode.getRetainedMessage(0, topicFilter); + public Array getRetainedMessage(TopicFilter topicFilter) { + var resultArray = MutableArray.ofType(Publish.class); + rootNode.collectRetainedMessages(0, topicFilter, topicFilter.levelsCount() - 1, resultArray); + return resultArray; } } diff --git a/model/src/main/java/javasabr/mqtt/model/topic/tree/RetainedMessageNode.java b/model/src/main/java/javasabr/mqtt/model/topic/tree/RetainedMessageNode.java new file mode 100644 index 00000000..0c362fc9 --- /dev/null +++ b/model/src/main/java/javasabr/mqtt/model/topic/tree/RetainedMessageNode.java @@ -0,0 +1,150 @@ +package javasabr.mqtt.model.topic.tree; + +import static javasabr.mqtt.model.topic.TopicFilter.MULTI_LEVEL_WILDCARD; +import static javasabr.mqtt.model.topic.TopicFilter.SINGLE_LEVEL_WILDCARD; + +import java.util.Objects; +import java.util.PriorityQueue; +import java.util.Queue; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; +import javasabr.mqtt.base.util.DebugUtils; +import javasabr.mqtt.model.publishing.Publish; +import javasabr.mqtt.model.topic.TopicFilter; +import javasabr.mqtt.model.topic.TopicName; +import javasabr.rlib.collections.array.MutableArray; +import javasabr.rlib.collections.dictionary.DictionaryFactory; +import javasabr.rlib.collections.dictionary.LockableRefToRefDictionary; +import lombok.AccessLevel; +import lombok.Getter; +import lombok.experimental.Accessors; +import lombok.experimental.FieldDefaults; +import org.jspecify.annotations.Nullable; + +@Getter(AccessLevel.PACKAGE) +@Accessors(fluent = true, chain = false) +@FieldDefaults(level = AccessLevel.PRIVATE) +class RetainedMessageNode { + + private final static Supplier TOPIC_NODE_FACTORY = RetainedMessageNode::new; + + static { + DebugUtils.registerIncludedFields("childNodes", "retainedMessage"); + } + + @Nullable + volatile LockableRefToRefDictionary childNodes; + final AtomicReference<@Nullable Publish> retainedMessage = new AtomicReference<>(); + + public void retainMessage(int level, Publish message, TopicName topicName) { + if (level + 1 == topicName.levelsCount()) { + retainedMessage.set(message.payload().length == 0 ? null : message); + return; + } + RetainedMessageNode childNode = getOrCreateChildNode(topicName.segment(level)); + childNode.retainMessage(level + 1, message, topicName); + } + + public void collectRetainedMessages(int level, TopicFilter topicFilter, int lastLevel, MutableArray result) { + String segment = topicFilter.segment(level); + Publish publish = retainedMessage.get(); + if (Objects.equals(segment, MULTI_LEVEL_WILDCARD)) { + collectAllMessages(this, result); + } else if (Objects.equals(segment, SINGLE_LEVEL_WILDCARD)) { + var childNodes = childNodes(); + if (childNodes == null) { + return; + } + long stamp = childNodes.readLock(); + try { + for (RetainedMessageNode n : childNodes) { + n.collectRetainedMessages(level + 1, topicFilter, lastLevel, result); + } + } finally { + childNodes.readUnlock(stamp); + } + } else if (level == lastLevel && publish != null && Objects.equals(segment, publish.topicName().lastSegment())) { + result.add(publish); + } else { + RetainedMessageNode topicFilterNode = childNode(segment); + if (topicFilterNode == null) { + return; + } + topicFilterNode.collectRetainedMessages(level + 1, topicFilter, lastLevel, result); + } + } + + private void collectAllMessages(RetainedMessageNode node, MutableArray result) { + Queue queue = new PriorityQueue<>(); + queue.add(node); + while (!queue.isEmpty()) { + RetainedMessageNode poll = queue.poll(); + Publish message = poll.retainedMessage.get(); + if (message != null) { + result.add(message); + } + var childNodes = poll.childNodes(); + if (childNodes == null) { + continue; + } + long stamp = childNodes.readLock(); + try { + for (RetainedMessageNode n : childNodes) { + queue.add(n); + } + } finally { + childNodes.readUnlock(stamp); + } + } + } + + @Nullable + private RetainedMessageNode childNode(String segment) { + LockableRefToRefDictionary childNodes = childNodes(); + if (childNodes == null) { + return null; + } + long stamp = childNodes.readLock(); + try { + return childNodes.get(segment); + } finally { + childNodes.readUnlock(stamp); + } + } + + private RetainedMessageNode getOrCreateChildNode(String segment) { + LockableRefToRefDictionary childNodes = getOrCreateChildNodes(); + long stamp = childNodes.readLock(); + try { + RetainedMessageNode topicFilterNode = childNodes.get(segment); + if (topicFilterNode != null) { + return topicFilterNode; + } + } finally { + childNodes.readUnlock(stamp); + } + stamp = childNodes.writeLock(); + try { + return childNodes.getOrCompute(segment, TOPIC_NODE_FACTORY); + } finally { + childNodes.writeUnlock(stamp); + } + } + + private LockableRefToRefDictionary getOrCreateChildNodes() { + if (childNodes == null) { + synchronized (this) { + if (childNodes == null) { + childNodes = DictionaryFactory.stampedLockBasedRefToRefDictionary(); + } + } + } + //noinspection ConstantConditions + return childNodes; + } + + @Override + public String toString() { + return DebugUtils.toJsonString(this); + } +} diff --git a/model/src/main/java/javasabr/mqtt/model/topic/tree/TopicMessageNode.java b/model/src/main/java/javasabr/mqtt/model/topic/tree/TopicMessageNode.java deleted file mode 100644 index 55cad0ec..00000000 --- a/model/src/main/java/javasabr/mqtt/model/topic/tree/TopicMessageNode.java +++ /dev/null @@ -1,94 +0,0 @@ -package javasabr.mqtt.model.topic.tree; - -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Supplier; -import javasabr.mqtt.base.util.DebugUtils; -import javasabr.mqtt.model.publishing.Publish; -import javasabr.mqtt.model.topic.TopicFilter; -import javasabr.mqtt.model.topic.TopicName; -import javasabr.rlib.collections.dictionary.DictionaryFactory; -import javasabr.rlib.collections.dictionary.LockableRefToRefDictionary; -import lombok.AccessLevel; -import lombok.Getter; -import lombok.experimental.Accessors; -import lombok.experimental.FieldDefaults; -import org.jspecify.annotations.Nullable; - -@Getter(AccessLevel.PACKAGE) -@Accessors(fluent = true, chain = false) -@FieldDefaults(level = AccessLevel.PRIVATE) -class TopicMessageNode { - - private final static Supplier TOPIC_NODE_FACTORY = TopicMessageNode::new; - - static { - DebugUtils.registerIncludedFields("childNodes", "retainedMessage"); - } - - @Nullable - volatile LockableRefToRefDictionary childNodes; - final AtomicReference<@Nullable Publish> retainedMessage = new AtomicReference<>(); - - public void retainMessage(int level, Publish message, TopicName topicFilter) { - if (level + 1 == topicFilter.levelsCount()) { - retainedMessage.set(message); - return; - } - TopicMessageNode childNode = getOrCreateChildNode(topicFilter.segment(level)); - childNode.retainMessage(level + 1, message, topicFilter); - } - - @Nullable - public Publish getRetainedMessage(int level, TopicName topicName) { - if (level + 1 == topicName.levelsCount()) { - return retainedMessage.get(); - } - TopicMessageNode childNode = getOrCreateChildNode(topicName.segment(level)); - return childNode.getRetainedMessage(level + 1, topicName); - } - - @Nullable - public Publish getRetainedMessage(int level, TopicFilter topicName) { - if (level + 1 == topicName.levelsCount()) { - return retainedMessage.get(); - } - TopicMessageNode childNode = getOrCreateChildNode(topicName.segment(level)); - return childNode.getRetainedMessage(level + 1, topicName); - } - - private TopicMessageNode getOrCreateChildNode(String segment) { - LockableRefToRefDictionary childNodes = getOrCreateChildNodes(); - long stamp = childNodes.readLock(); - try { - TopicMessageNode topicFilterNode = childNodes.get(segment); - if (topicFilterNode != null) { - return topicFilterNode; - } - } finally { - childNodes.readUnlock(stamp); - } - stamp = childNodes.writeLock(); - try { - return childNodes.getOrCompute(segment, TOPIC_NODE_FACTORY); - } finally { - childNodes.writeUnlock(stamp); - } - } - - private LockableRefToRefDictionary getOrCreateChildNodes() { - if (childNodes == null) { - synchronized (this) { - if (childNodes == null) { - childNodes = DictionaryFactory.stampedLockBasedRefToRefDictionary(); - } - } - } - //noinspection ConstantConditions - return childNodes; - } - - @Override - public String toString() { - return DebugUtils.toJsonString(this); - } -} diff --git a/service/src/main/java/javasabr/mqtt/service/PublishDeliveringService.java b/service/src/main/java/javasabr/mqtt/service/PublishDeliveringService.java index f051d8a3..b5f65426 100644 --- a/service/src/main/java/javasabr/mqtt/service/PublishDeliveringService.java +++ b/service/src/main/java/javasabr/mqtt/service/PublishDeliveringService.java @@ -4,10 +4,11 @@ import javasabr.mqtt.model.subscriber.SingleSubscriber; import javasabr.mqtt.model.topic.TopicFilter; import javasabr.mqtt.service.publish.handler.PublishHandlingResult; +import javasabr.rlib.collections.array.Array; public interface PublishDeliveringService { PublishHandlingResult startDelivering(Publish publish, SingleSubscriber subscriber); - PublishHandlingResult deliverRetainedMessages(TopicFilter topicFilter, SingleSubscriber subscriber); + Array deliverRetainedMessages(TopicFilter topicFilter, SingleSubscriber subscriber); } diff --git a/service/src/main/java/javasabr/mqtt/service/impl/DefaultPublishDeliveringService.java b/service/src/main/java/javasabr/mqtt/service/impl/DefaultPublishDeliveringService.java index c66d3f49..9c45b9b7 100644 --- a/service/src/main/java/javasabr/mqtt/service/impl/DefaultPublishDeliveringService.java +++ b/service/src/main/java/javasabr/mqtt/service/impl/DefaultPublishDeliveringService.java @@ -9,6 +9,8 @@ import javasabr.mqtt.service.PublishDeliveringService; import javasabr.mqtt.service.publish.handler.MqttPublishOutMessageHandler; import javasabr.mqtt.service.publish.handler.PublishHandlingResult; +import javasabr.rlib.collections.array.Array; +import javasabr.rlib.collections.array.MutableArray; import lombok.AccessLevel; import lombok.CustomLog; import lombok.experimental.FieldDefaults; @@ -20,7 +22,7 @@ public class DefaultPublishDeliveringService implements PublishDeliveringService @Nullable MqttPublishOutMessageHandler[] publishOutMessageHandlers; - ConcurrentRetainedMessageTree topicTree; + ConcurrentRetainedMessageTree retainedMessageTree; public DefaultPublishDeliveringService( Collection knownPublishOutHandlers) { @@ -42,7 +44,7 @@ public DefaultPublishDeliveringService( } handlers[qos.level()] = knownPublishOutHandler; } - this.topicTree = new ConcurrentRetainedMessageTree(); + this.retainedMessageTree = new ConcurrentRetainedMessageTree(); this.publishOutMessageHandlers = handlers; log.info(publishOutMessageHandlers, DefaultPublishDeliveringService::buildServiceDescription); } @@ -50,9 +52,7 @@ public DefaultPublishDeliveringService( @Override public PublishHandlingResult startDelivering(Publish publish, SingleSubscriber subscriber) { try { - if (publish.retained()) { - topicTree.retainMessage(publish); - } + retainedMessageTree.retainMessage(publish); //noinspection DataFlowIssue return publishOutMessageHandlers[subscriber.qos().level()].handle(publish, subscriber); } catch (IndexOutOfBoundsException | NullPointerException ex) { @@ -62,9 +62,13 @@ public PublishHandlingResult startDelivering(Publish publish, SingleSubscriber s } @Override - public PublishHandlingResult deliverRetainedMessages(TopicFilter topicFilter, SingleSubscriber subscriber) { - Publish retainedMessage = topicTree.getRetainedMessage(topicFilter); - return startDelivering(retainedMessage, subscriber); + public Array deliverRetainedMessages(TopicFilter topicFilter, SingleSubscriber subscriber) { + Array retainedMessage = retainedMessageTree.getRetainedMessage(topicFilter); + MutableArray result = MutableArray.ofType(PublishHandlingResult.class); + for (Publish message : retainedMessage) { + result.add(startDelivering(message, subscriber)); + } + return result; } private static String buildServiceDescription( diff --git a/service/src/main/java/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandler.java b/service/src/main/java/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandler.java index 949bdf48..67eab259 100644 --- a/service/src/main/java/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandler.java +++ b/service/src/main/java/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandler.java @@ -194,23 +194,27 @@ private void sendRetainedMessages( RequestedSubscription requestedSubscription = subscriptions.get(i); SubscribeAckReasonCode subscribeAckReasonCode = subscribeResults.get(i); Subscription subscription = subs.get(i); - if (subscribeAckReasonCode.ordinal() < 3) { - TopicFilter topicFilter = TopicFilter.valueOf(requestedSubscription.rawTopicFilter()); - SingleSubscriber singleSubscriber = new SingleSubscriber(client, subscription); - PublishHandlingResult result = publishDeliveringService.deliverRetainedMessages(topicFilter, singleSubscriber); + if (subscribeAckReasonCode.ordinal() > 2) { + // TODO handle error + continue; + } + TopicFilter topicFilter = TopicFilter.valueOf(requestedSubscription.rawTopicFilter()); + SingleSubscriber singleSubscriber = new SingleSubscriber(client, subscription); + var results = publishDeliveringService.deliverRetainedMessages(topicFilter, singleSubscriber); + for (PublishHandlingResult result : results) { if (result.error()) { errorResult = result; - } else if(result == PublishHandlingResult.SUCCESS) { + } else if (result == PublishHandlingResult.SUCCESS) { count++; } if (errorResult != null) { log.debug(client.clientId(), errorResult, "[%s] Found final error:[%s] during sending retained messages"::formatted); - // handleError(client, publish, errorResult); + // TODO handleError(client, publish, errorResult); } else { log.debug(client.clientId(), count, "[%s] Successfully started delivering retained messages to [%s] subscribers"::formatted); - // handleSuccessfulResult(client, publish, count); + // TODO handleSuccessfulResult(client, publish, count); } } } From 01bac5bffc8ec7d39213170e28e6f3565f01e159 Mon Sep 17 00:00:00 2001 From: Maksim Kashapov Date: Wed, 19 Nov 2025 21:48:58 +0100 Subject: [PATCH 3/9] [broker-30] Cleanup code after merge --- .../mqtt/model/subscriber/tree/ConcurrentSubscriberTree.java | 1 - .../javasabr/mqtt/model/subscriber/tree/SubscriberNode.java | 5 ----- .../src/main/java/javasabr/mqtt/model/topic/TopicFilter.java | 4 ++-- .../java/javasabr/mqtt/model/topic/tree/package-info.java | 2 +- 4 files changed, 3 insertions(+), 9 deletions(-) diff --git a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/ConcurrentSubscriberTree.java b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/ConcurrentSubscriberTree.java index 5edb2233..c5e52e00 100644 --- a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/ConcurrentSubscriberTree.java +++ b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/ConcurrentSubscriberTree.java @@ -13,7 +13,6 @@ import org.jspecify.annotations.Nullable; @FieldDefaults(level = AccessLevel.PRIVATE, makeFinal = true) - public class ConcurrentSubscriberTree implements ThreadSafe { SubscriberNode rootNode; diff --git a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java index a3ac5beb..ddd62086 100644 --- a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java +++ b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java @@ -4,7 +4,6 @@ import javasabr.mqtt.base.util.DebugUtils; import javasabr.mqtt.model.subscriber.SingleSubscriber; import javasabr.mqtt.model.subscriber.Subscriber; -import javasabr.mqtt.model.subscriber.tree.SubscriberTreeBase; import javasabr.mqtt.model.subscribtion.Subscription; import javasabr.mqtt.model.subscribtion.SubscriptionOwner; import javasabr.mqtt.model.topic.TopicFilter; @@ -44,7 +43,6 @@ public SingleSubscriber subscribe(int level, SubscriptionOwner owner, Subscripti if (level == topicFilter.levelsCount()) { return addSubscriber(getOrCreateSubscribers(), owner, subscription, topicFilter); } - SubscriberNode childNode = getOrCreateChildNode(topicFilter.segment(level)); return childNode.subscribe(level + 1, owner, subscription, topicFilter); } @@ -53,7 +51,6 @@ public boolean unsubscribe(int level, SubscriptionOwner owner, TopicFilter topic if (level == topicFilter.levelsCount()) { return removeSubscriber(subscribers(), owner, topicFilter); } - SubscriberNode childNode = getOrCreateChildNode(topicFilter.segment(level)); return childNode.unsubscribe(level + 1, owner, topicFilter); } @@ -86,7 +83,6 @@ private void singleWildcardTopicMatch( TopicName topicName, int lastLevel, MutableArray result) { - SubscriberNode subscriberNode = childNode(TopicFilter.SINGLE_LEVEL_WILDCARD); if (subscriberNode == null) { return; @@ -99,7 +95,6 @@ private void singleWildcardTopicMatch( } private void multiWildcardTopicMatch(MutableArray result) { - SubscriberNode subscriberNode = childNode(TopicFilter.MULTI_LEVEL_WILDCARD); if (subscriberNode != null) { appendSubscribersTo(result, subscriberNode); diff --git a/model/src/main/java/javasabr/mqtt/model/topic/TopicFilter.java b/model/src/main/java/javasabr/mqtt/model/topic/TopicFilter.java index da653f8a..fa9c1b6f 100644 --- a/model/src/main/java/javasabr/mqtt/model/topic/TopicFilter.java +++ b/model/src/main/java/javasabr/mqtt/model/topic/TopicFilter.java @@ -11,9 +11,9 @@ public class TopicFilter extends AbstractTopic { public static final String MULTI_LEVEL_WILDCARD = "#"; - public static final char MULTI_LEVEL_WILDCARD_CHAR = MULTI_LEVEL_WILDCARD.charAt(0); + public static final char MULTI_LEVEL_WILDCARD_CHAR = '#'; public static final String SINGLE_LEVEL_WILDCARD = "+"; - public static final char SINGLE_LEVEL_WILDCARD_CHAR = SINGLE_LEVEL_WILDCARD.charAt(0); + public static final char SINGLE_LEVEL_WILDCARD_CHAR = '+'; public static final String SPECIAL = "$"; public static final TopicFilter INVALID_TOPIC_FILTER = new TopicFilter("$invalid$") { diff --git a/model/src/main/java/javasabr/mqtt/model/topic/tree/package-info.java b/model/src/main/java/javasabr/mqtt/model/topic/tree/package-info.java index 95d9e9e6..1df48806 100644 --- a/model/src/main/java/javasabr/mqtt/model/topic/tree/package-info.java +++ b/model/src/main/java/javasabr/mqtt/model/topic/tree/package-info.java @@ -1,4 +1,4 @@ @NullMarked package javasabr.mqtt.model.topic.tree; -import org.jspecify.annotations.NullMarked; \ No newline at end of file +import org.jspecify.annotations.NullMarked; From c903510a753ac3b60e7d2c4129ec26444c879043 Mon Sep 17 00:00:00 2001 From: Maksim Kashapov Date: Thu, 20 Nov 2025 08:15:35 +0100 Subject: [PATCH 4/9] [broker-30] Fix corner cases in retained messages --- .../tree/ConcurrentRetainedMessageTree.java | 2 +- .../model/topic/tree/RetainedMessageNode.java | 45 ++++--- .../topic/tree/RetainedMessageTreeTest.groovy | 122 ++++++++++++++++++ 3 files changed, 151 insertions(+), 18 deletions(-) create mode 100644 model/src/test/groovy/javasabr/mqtt/model/topic/tree/RetainedMessageTreeTest.groovy diff --git a/model/src/main/java/javasabr/mqtt/model/topic/tree/ConcurrentRetainedMessageTree.java b/model/src/main/java/javasabr/mqtt/model/topic/tree/ConcurrentRetainedMessageTree.java index 78c78f11..6f38475f 100644 --- a/model/src/main/java/javasabr/mqtt/model/topic/tree/ConcurrentRetainedMessageTree.java +++ b/model/src/main/java/javasabr/mqtt/model/topic/tree/ConcurrentRetainedMessageTree.java @@ -25,7 +25,7 @@ public void retainMessage(Publish message) { public Array getRetainedMessage(TopicFilter topicFilter) { var resultArray = MutableArray.ofType(Publish.class); - rootNode.collectRetainedMessages(0, topicFilter, topicFilter.levelsCount() - 1, resultArray); + rootNode.collectRetainedMessages(0, topicFilter, resultArray); return resultArray; } } diff --git a/model/src/main/java/javasabr/mqtt/model/topic/tree/RetainedMessageNode.java b/model/src/main/java/javasabr/mqtt/model/topic/tree/RetainedMessageNode.java index 0c362fc9..cc6b4b68 100644 --- a/model/src/main/java/javasabr/mqtt/model/topic/tree/RetainedMessageNode.java +++ b/model/src/main/java/javasabr/mqtt/model/topic/tree/RetainedMessageNode.java @@ -3,6 +3,7 @@ import static javasabr.mqtt.model.topic.TopicFilter.MULTI_LEVEL_WILDCARD; import static javasabr.mqtt.model.topic.TopicFilter.SINGLE_LEVEL_WILDCARD; +import java.util.LinkedList; import java.util.Objects; import java.util.PriorityQueue; import java.util.Queue; @@ -37,19 +38,22 @@ class RetainedMessageNode { final AtomicReference<@Nullable Publish> retainedMessage = new AtomicReference<>(); public void retainMessage(int level, Publish message, TopicName topicName) { - if (level + 1 == topicName.levelsCount()) { - retainedMessage.set(message.payload().length == 0 ? null : message); - return; + var child = getOrCreateChildNode(topicName.segment(level)); + boolean isLeaf = (level + 1 == topicName.levelsCount()); + if (isLeaf) { + if (Objects.equals(message.topicName().lastSegment(), topicName.lastSegment())) { + child.retainedMessage.set(message.payload().length == 0 ? null : message); + } + } else { + child.retainMessage(level + 1, message, topicName); } - RetainedMessageNode childNode = getOrCreateChildNode(topicName.segment(level)); - childNode.retainMessage(level + 1, message, topicName); } - public void collectRetainedMessages(int level, TopicFilter topicFilter, int lastLevel, MutableArray result) { + public void collectRetainedMessages(int level, TopicFilter topicFilter, MutableArray result) { String segment = topicFilter.segment(level); - Publish publish = retainedMessage.get(); if (Objects.equals(segment, MULTI_LEVEL_WILDCARD)) { collectAllMessages(this, result); + return; } else if (Objects.equals(segment, SINGLE_LEVEL_WILDCARD)) { var childNodes = childNodes(); if (childNodes == null) { @@ -57,25 +61,32 @@ public void collectRetainedMessages(int level, TopicFilter topicFilter, int last } long stamp = childNodes.readLock(); try { - for (RetainedMessageNode n : childNodes) { - n.collectRetainedMessages(level + 1, topicFilter, lastLevel, result); + for (RetainedMessageNode childNode : childNodes) { + childNode.collectRetainedMessages(level + 1, topicFilter, result); } } finally { childNodes.readUnlock(stamp); } - } else if (level == lastLevel && publish != null && Objects.equals(segment, publish.topicName().lastSegment())) { - result.add(publish); - } else { - RetainedMessageNode topicFilterNode = childNode(segment); - if (topicFilterNode == null) { - return; + return; + } + int lastLevel = topicFilter.levelsCount() - 1; + RetainedMessageNode retainedMessageNode = childNode(segment); + if (retainedMessageNode == null || level > lastLevel) { + return; + } + boolean isLeaf = (level == lastLevel); + if (isLeaf) { + Publish publish = retainedMessageNode.retainedMessage.get(); + if(publish != null && Objects.equals(segment, publish.topicName().lastSegment())){ + result.add(publish); } - topicFilterNode.collectRetainedMessages(level + 1, topicFilter, lastLevel, result); + } else { + retainedMessageNode.collectRetainedMessages(level + 1, topicFilter, result); } } private void collectAllMessages(RetainedMessageNode node, MutableArray result) { - Queue queue = new PriorityQueue<>(); + Queue queue = new LinkedList<>(); queue.add(node); while (!queue.isEmpty()) { RetainedMessageNode poll = queue.poll(); diff --git a/model/src/test/groovy/javasabr/mqtt/model/topic/tree/RetainedMessageTreeTest.groovy b/model/src/test/groovy/javasabr/mqtt/model/topic/tree/RetainedMessageTreeTest.groovy new file mode 100644 index 00000000..1baf0582 --- /dev/null +++ b/model/src/test/groovy/javasabr/mqtt/model/topic/tree/RetainedMessageTreeTest.groovy @@ -0,0 +1,122 @@ +package javasabr.mqtt.model.topic.tree + + +import javasabr.mqtt.model.PayloadFormat +import javasabr.mqtt.model.QoS +import javasabr.mqtt.model.publishing.Publish +import javasabr.mqtt.model.topic.TopicFilter +import javasabr.mqtt.model.topic.TopicName +import javasabr.mqtt.test.support.UnitSpecification +import javasabr.rlib.collections.array.Array +import javasabr.rlib.collections.array.IntArray + +import static java.nio.charset.StandardCharsets.UTF_8 + +class RetainedMessageTreeTest extends UnitSpecification { + + def "should fetch retained messages by topic filter"( + List messages, + String topicFilter, + List expectedMessages) { + given: + ConcurrentRetainedMessageTree retainedMessageTree = new ConcurrentRetainedMessageTree(); + messages.eachWithIndex { Publish message, int i -> + retainedMessageTree.retainMessage(message) + } + when: + def retainedMessages = retainedMessageTree.getRetainedMessage(TopicFilter.valueOf(topicFilter)) + .collect { it } + then: + retainedMessages.size() == expectedMessages.size() + for (int i = 0; i < retainedMessages.size(); i++) { + assert retainedMessages.get(i).topicName() == expectedMessages.get(i).topicName() + } + where: + topicFilter << [ + "/topic/segment1", + "/topic/segment2", + "/topic/segment3", + "/topic/+/segment2", + "/topic/#" + ] + messages << [ + [ + makePublish("/topic/segment1"), + makePublish("/topic/segment2"), + makePublish("/topic/segment1/segment2"), + makePublish("/topic/"), + makePublish("/topic") + ], + [ + makePublish("/topic/segment1"), + makePublish("/topic/segment2"), + makePublish("/topic/segment1/segment2"), + makePublish("/topic/"), + makePublish("/topic/segment2"), + makePublish("/"), + makePublish("/topic/segment2/segment1") + ], + [ + makePublish("/topic/segment1"), + makePublish("/topic/segment2"), + makePublish("/topic/segment3"), + makePublish("/topic/segment3"), + makePublish("/topic/segment3"), + makePublish("/topic/segment3") + ], + [ + makePublish("/topic/segment1"), + makePublish("/topic/segment2"), + makePublish("/topic/segment1/segment2"), + makePublish("/topic/segment500/segment2"), + makePublish("/topic/"), + makePublish("/topic") + ], + [ + makePublish("/topic1/segment1"), + makePublish("/topic/segment2"), + makePublish("/topic2/segment1/segment2"), + makePublish("/topic/segment3"), + makePublish("/topic/segment1/segment2") + ] + ] + expectedMessages << [ + [ + makePublish("/topic/segment1") + ], + [ + makePublish("/topic/segment2") + ], + [ + makePublish("/topic/segment3") + ], + [ + makePublish("/topic/segment1/segment2"), + makePublish("/topic/segment500/segment2") + ], + [ + makePublish("/topic/segment2"), + makePublish("/topic/segment3"), + makePublish("/topic/segment1/segment2") + ] + ] + } + + static def makePublish(String topicName) { + return new Publish( + 1, + QoS.AT_MOST_ONCE, + TopicName.valueOf(topicName), + null, + "payload".getBytes(UTF_8), + false, + true, + null, + IntArray.of(30), + null, + 60000, + 1, + PayloadFormat.UTF8_STRING, + Array.of()); + } +} From e1f5c2a4fc148be5c421d3df3041f65a4cad69f3 Mon Sep 17 00:00:00 2001 From: Maksim Kashapov Date: Thu, 20 Nov 2025 10:35:31 +0100 Subject: [PATCH 5/9] [broker-30] Handle SubscribeRetainHandling --- .../config/MqttBrokerSpringConfig.java | 9 ++-- .../impl/InMemorySubscriptionService.java | 49 ++++++++++++++++++- .../impl/SubscribeMqttInMessageHandler.java | 47 +----------------- 3 files changed, 52 insertions(+), 53 deletions(-) diff --git a/application/src/main/java/javasabr/mqtt/broker/application/config/MqttBrokerSpringConfig.java b/application/src/main/java/javasabr/mqtt/broker/application/config/MqttBrokerSpringConfig.java index b809f587..e8e7597f 100644 --- a/application/src/main/java/javasabr/mqtt/broker/application/config/MqttBrokerSpringConfig.java +++ b/application/src/main/java/javasabr/mqtt/broker/application/config/MqttBrokerSpringConfig.java @@ -99,8 +99,8 @@ AuthenticationService authenticationService( } @Bean - SubscriptionService subscriptionService() { - return new InMemorySubscriptionService(); + SubscriptionService subscriptionService(PublishDeliveringService publishDeliveringService) { + return new InMemorySubscriptionService(publishDeliveringService); } @Bean @@ -179,9 +179,8 @@ MqttInMessageHandler disconnectMqttInMessageHandler(MessageOutFactoryService mes MqttInMessageHandler subscribeMqttInMessageHandler( SubscriptionService subscriptionService, MessageOutFactoryService messageOutFactoryService, - TopicService topicService, - PublishDeliveringService publishDeliveringService) { - return new SubscribeMqttInMessageHandler(subscriptionService, messageOutFactoryService, topicService, publishDeliveringService); + TopicService topicService) { + return new SubscribeMqttInMessageHandler(subscriptionService, messageOutFactoryService, topicService); } @Bean diff --git a/service/src/main/java/javasabr/mqtt/service/impl/InMemorySubscriptionService.java b/service/src/main/java/javasabr/mqtt/service/impl/InMemorySubscriptionService.java index 78647c51..3128e574 100644 --- a/service/src/main/java/javasabr/mqtt/service/impl/InMemorySubscriptionService.java +++ b/service/src/main/java/javasabr/mqtt/service/impl/InMemorySubscriptionService.java @@ -1,5 +1,7 @@ package javasabr.mqtt.service.impl; +import static javasabr.mqtt.model.SubscribeRetainHandling.SEND; +import static javasabr.mqtt.model.SubscribeRetainHandling.SEND_IF_SUBSCRIPTION_DOES_NOT_EXIST; import static javasabr.mqtt.model.reason.code.UnsubscribeAckReasonCode.NO_SUBSCRIPTION_EXISTED; import static javasabr.mqtt.model.reason.code.UnsubscribeAckReasonCode.SUCCESS; @@ -8,15 +10,17 @@ import javasabr.mqtt.model.reason.code.UnsubscribeAckReasonCode; import javasabr.mqtt.model.subscriber.SingleSubscriber; import javasabr.mqtt.model.subscriber.Subscriber; +import javasabr.mqtt.model.subscriber.tree.ConcurrentSubscriberTree; import javasabr.mqtt.model.subscribtion.Subscription; import javasabr.mqtt.model.topic.SharedTopicFilter; import javasabr.mqtt.model.topic.TopicFilter; import javasabr.mqtt.model.topic.TopicName; -import javasabr.mqtt.model.subscriber.tree.ConcurrentSubscriberTree; import javasabr.mqtt.network.MqttClient; import javasabr.mqtt.network.session.ActiveSubscriptions; import javasabr.mqtt.network.session.MqttSession; +import javasabr.mqtt.service.PublishDeliveringService; import javasabr.mqtt.service.SubscriptionService; +import javasabr.mqtt.service.publish.handler.PublishHandlingResult; import javasabr.rlib.collections.array.Array; import javasabr.rlib.collections.array.ArrayFactory; import javasabr.rlib.collections.array.MutableArray; @@ -31,10 +35,12 @@ @FieldDefaults(level = AccessLevel.PRIVATE, makeFinal = true) public class InMemorySubscriptionService implements SubscriptionService { + PublishDeliveringService publishDeliveringService; ConcurrentSubscriberTree subscriberTree; - public InMemorySubscriptionService() { + public InMemorySubscriptionService(PublishDeliveringService publishDeliveringService) { this.subscriberTree = new ConcurrentSubscriberTree(); + this.publishDeliveringService = publishDeliveringService; } @Override @@ -84,6 +90,10 @@ private SubscribeAckReasonCode addSubscription(MqttClient client, MqttSession se if (previous != null) { activeSubscriptions.remove(previous.subscription()); } + if ((subscription.retainHandling() == SEND_IF_SUBSCRIPTION_DOES_NOT_EXIST && previous != null) + || subscription.retainHandling() == SEND) { + sendRetainedMessages(client, subscription); + } activeSubscriptions.add(subscription); return subscription.qos().subscribeAckReasonCode(); } @@ -137,4 +147,39 @@ public void restoreSubscriptions(MqttClient client, MqttSession session) { subscriberTree.subscribe(client, subscription); } } + + private void sendRetainedMessages(MqttClient client, Subscription subscription) { + int count = 0; + PublishHandlingResult errorResult = null; + if (subscription + .qos() + .subscribeAckReasonCode() + .ordinal() > 2) { + // TODO handle error ? + return; + } + SingleSubscriber singleSubscriber = new SingleSubscriber(client, subscription); + var results = publishDeliveringService.deliverRetainedMessages(subscription.topicFilter(), singleSubscriber); + for (PublishHandlingResult result : results) { + if (result.error()) { + errorResult = result; + } else if (result == PublishHandlingResult.SUCCESS) { + count++; + } + if (errorResult != null) { + log.debug( + client.clientId(), + errorResult, + "[%s] Found final error:[%s] during sending retained messages"::formatted); + // TODO handleError(client, publish, errorResult); + } else { + log.debug( + client.clientId(), + count, + "[%s] Successfully started delivering retained messages to [%s] subscribers"::formatted); + // TODO handleSuccessfulResult(client, publish, count); + } + + } + } } diff --git a/service/src/main/java/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandler.java b/service/src/main/java/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandler.java index 67eab259..9e5528d7 100644 --- a/service/src/main/java/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandler.java +++ b/service/src/main/java/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandler.java @@ -9,7 +9,6 @@ import javasabr.mqtt.model.QoS; import javasabr.mqtt.model.reason.code.DisconnectReasonCode; import javasabr.mqtt.model.reason.code.SubscribeAckReasonCode; -import javasabr.mqtt.model.subscriber.SingleSubscriber; import javasabr.mqtt.model.subscribtion.RequestedSubscription; import javasabr.mqtt.model.subscribtion.Subscription; import javasabr.mqtt.model.topic.TopicFilter; @@ -21,10 +20,8 @@ import javasabr.mqtt.network.session.MessageTacker; import javasabr.mqtt.network.session.MqttSession; import javasabr.mqtt.service.MessageOutFactoryService; -import javasabr.mqtt.service.PublishDeliveringService; import javasabr.mqtt.service.SubscriptionService; import javasabr.mqtt.service.TopicService; -import javasabr.mqtt.service.publish.handler.PublishHandlingResult; import javasabr.rlib.collections.array.Array; import javasabr.rlib.collections.array.ArrayFactory; import javasabr.rlib.collections.array.MutableArray; @@ -41,19 +38,16 @@ public class SubscribeMqttInMessageHandler extends SHARED_SUBSCRIPTIONS_NOT_SUPPORTED, WILDCARD_SUBSCRIPTIONS_NOT_SUPPORTED); - PublishDeliveringService publishDeliveringService; SubscriptionService subscriptionService; TopicService topicService; public SubscribeMqttInMessageHandler( SubscriptionService subscriptionService, MessageOutFactoryService messageOutFactoryService, - TopicService topicService, - PublishDeliveringService publishDeliveringService) { + TopicService topicService) { super(ExternalMqttClient.class, SubscribeMqttInMessage.class, messageOutFactoryService); this.subscriptionService = subscriptionService; this.topicService = topicService; - this.publishDeliveringService = publishDeliveringService; } @Override @@ -98,7 +92,6 @@ protected void processValidMessage( .subscribe(client, session, subscriptions); sendSubscribeResults(client, session, subscribeMessage, subscribeResults); - sendRetainedMessages(client, subscribeMessage, subscribeResults, subscriptions); SubscribeAckReasonCode anyReasonToDisconnect = subscribeResults .iterations() @@ -181,42 +174,4 @@ private void sendSubscribeResults( .inMessageTracker() .remove(messageId)); } - - private void sendRetainedMessages( - ExternalMqttClient client, - SubscribeMqttInMessage subscribeMessage, - Array subscribeResults, - Array subs) { - int count = 0; - PublishHandlingResult errorResult = null; - Array subscriptions = subscribeMessage.subscriptions(); - for (int i = 0; i < subscribeMessage.subscriptionsCount(); i++) { - RequestedSubscription requestedSubscription = subscriptions.get(i); - SubscribeAckReasonCode subscribeAckReasonCode = subscribeResults.get(i); - Subscription subscription = subs.get(i); - if (subscribeAckReasonCode.ordinal() > 2) { - // TODO handle error - continue; - } - TopicFilter topicFilter = TopicFilter.valueOf(requestedSubscription.rawTopicFilter()); - SingleSubscriber singleSubscriber = new SingleSubscriber(client, subscription); - var results = publishDeliveringService.deliverRetainedMessages(topicFilter, singleSubscriber); - for (PublishHandlingResult result : results) { - if (result.error()) { - errorResult = result; - } else if (result == PublishHandlingResult.SUCCESS) { - count++; - } - if (errorResult != null) { - log.debug(client.clientId(), errorResult, - "[%s] Found final error:[%s] during sending retained messages"::formatted); - // TODO handleError(client, publish, errorResult); - } else { - log.debug(client.clientId(), count, - "[%s] Successfully started delivering retained messages to [%s] subscribers"::formatted); - // TODO handleSuccessfulResult(client, publish, count); - } - } - } - } } From a1b87dd99f6272247635bcb2500eeebc37c8515d Mon Sep 17 00:00:00 2001 From: Maksim Kashapov <56276969+crazyrokr@users.noreply.github.com> Date: Mon, 1 Dec 2025 15:10:45 +0100 Subject: [PATCH 6/9] [broker-30] Fix build --- .../config/MqttBrokerSpringConfig.java | 69 ++++---------- .../mqtt/service/SubscriptionService.java | 3 - .../impl/InMemorySubscriptionService.java | 9 -- .../AbstractMqttPublishOutMessageHandler.java | 12 ++- ...PersistedMqttPublishOutMessageHandler.java | 13 +-- .../Qos0MqttPublishOutMessageHandler.java | 7 +- .../Qos1MqttPublishOutMessageHandler.java | 7 +- .../Qos2MqttPublishOutMessageHandler.java | 7 +- .../IntegrationServiceSpecification.groovy | 12 +-- .../InMemorySubscriptionServiceTest.groovy | 20 ++--- .../SubscribeMqttInMessageHandlerTest.groovy | 90 ++----------------- ...UnsubscribeMqttInMessageHandlerTest.groovy | 2 +- 12 files changed, 64 insertions(+), 187 deletions(-) diff --git a/application/src/main/java/javasabr/mqtt/broker/application/config/MqttBrokerSpringConfig.java b/application/src/main/java/javasabr/mqtt/broker/application/config/MqttBrokerSpringConfig.java index a6961877..756014ba 100644 --- a/application/src/main/java/javasabr/mqtt/broker/application/config/MqttBrokerSpringConfig.java +++ b/application/src/main/java/javasabr/mqtt/broker/application/config/MqttBrokerSpringConfig.java @@ -93,7 +93,8 @@ CredentialSource credentialSource( @Bean AuthenticationService authenticationService( CredentialSource credentialSource, - @Value("${authentication.allow.anonymous:false}") boolean allowAnonymousAuth) { + @Value("${authentication.allow.anonymous:false}") + boolean allowAnonymousAuth) { return new SimpleAuthenticationService(credentialSource, allowAnonymousAuth); } @@ -153,10 +154,7 @@ MqttInMessageHandler publishMqttInMessageHandler( PublishReceivingService publishReceivingService, MessageOutFactoryService messageOutFactoryService, TopicService topicService) { - return new PublishMqttInMessageHandler( - publishReceivingService, - messageOutFactoryService, - topicService); + return new PublishMqttInMessageHandler(publishReceivingService, messageOutFactoryService, topicService); } @Bean @@ -187,10 +185,7 @@ MqttInMessageHandler unsubscribeMqttInMessageHandler( SubscriptionService subscriptionService, MessageOutFactoryService messageOutFactoryService, TopicService topicService) { - return new UnsubscribeMqttInMessageHandler( - subscriptionService, - messageOutFactoryService, - topicService); + return new UnsubscribeMqttInMessageHandler(subscriptionService, messageOutFactoryService, topicService); } @Bean @@ -199,24 +194,18 @@ ConnectionService externalMqttConnectionService(Collection findSubscribers(TopicName topicName) { return findSubscribersTo(MutableArray.ofType(SingleSubscriber.class), topicName); } diff --git a/core-service/src/main/java/javasabr/mqtt/service/impl/InMemorySubscriptionService.java b/core-service/src/main/java/javasabr/mqtt/service/impl/InMemorySubscriptionService.java index c1234fe4..1c2e3e74 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/impl/InMemorySubscriptionService.java +++ b/core-service/src/main/java/javasabr/mqtt/service/impl/InMemorySubscriptionService.java @@ -12,7 +12,6 @@ import javasabr.mqtt.model.session.ActiveSubscriptions; import javasabr.mqtt.model.session.MqttSession; import javasabr.mqtt.model.subscriber.SingleSubscriber; -import javasabr.mqtt.model.subscriber.Subscriber; import javasabr.mqtt.model.subscriber.tree.ConcurrentSubscriberTree; import javasabr.mqtt.model.subscription.Subscription; import javasabr.mqtt.model.topic.SharedTopicFilter; @@ -44,14 +43,6 @@ public InMemorySubscriptionService(PublishDeliveringService publishDeliveringSer this.publishDeliveringService = publishDeliveringService; } - @Override - public NetworkMqttUser resolveClient(Subscriber subscriber) { - if (subscriber instanceof SingleSubscriber single) { - return (NetworkMqttUser) single.user(); - } - throw new IllegalArgumentException("Unexpected subscriber: " + subscriber); - } - @Override public Array findSubscribersTo(MutableArray container, TopicName topicName) { Array matched = subscriberTree.matches(topicName); diff --git a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/AbstractMqttPublishOutMessageHandler.java b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/AbstractMqttPublishOutMessageHandler.java index 613e1f96..2b329bc6 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/AbstractMqttPublishOutMessageHandler.java +++ b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/AbstractMqttPublishOutMessageHandler.java @@ -3,10 +3,10 @@ import javasabr.mqtt.model.MqttProperties; import javasabr.mqtt.model.publishing.Publish; import javasabr.mqtt.model.subscriber.SingleSubscriber; +import javasabr.mqtt.model.subscriber.Subscriber; import javasabr.mqtt.network.message.out.MqttOutMessage; import javasabr.mqtt.network.user.NetworkMqttUser; import javasabr.mqtt.service.MessageOutFactoryService; -import javasabr.mqtt.service.SubscriptionService; import javasabr.mqtt.service.publish.handler.MqttPublishOutMessageHandler; import javasabr.mqtt.service.publish.handler.PublishHandlingResult; import lombok.AccessLevel; @@ -22,12 +22,18 @@ public abstract class AbstractMqttPublishOutMessageHandler expectedUser; - SubscriptionService subscriptionService; MessageOutFactoryService messageOutFactoryService; + private static NetworkMqttUser resolveClient(Subscriber subscriber) { + if (subscriber instanceof SingleSubscriber single) { + return (NetworkMqttUser) single.user(); + } + throw new IllegalArgumentException("Unexpected subscriber: " + subscriber); + } + @Override public PublishHandlingResult handle(Publish publish, SingleSubscriber subscriber) { - NetworkMqttUser user = subscriptionService.resolveClient(subscriber); + NetworkMqttUser user = resolveClient(subscriber); if (!expectedUser.isInstance(user)) { log.warning(user, "Accepted not expected client:[%s]"::formatted); return PublishHandlingResult.NOT_EXPECTED_CLIENT; diff --git a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/PersistedMqttPublishOutMessageHandler.java b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/PersistedMqttPublishOutMessageHandler.java index 13a5b252..9f572f5b 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/PersistedMqttPublishOutMessageHandler.java +++ b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/PersistedMqttPublishOutMessageHandler.java @@ -8,7 +8,6 @@ import javasabr.mqtt.network.session.NetworkMqttSession.PendingMessageHandler; import javasabr.mqtt.network.user.NetworkMqttUser; import javasabr.mqtt.service.MessageOutFactoryService; -import javasabr.mqtt.service.SubscriptionService; import javasabr.mqtt.service.publish.handler.PublishHandlingResult; import lombok.AccessLevel; import lombok.experimental.FieldDefaults; @@ -20,15 +19,14 @@ public abstract class PersistedMqttPublishOutMessageHandler extends PendingMessageHandler pendingMessageHandler; - protected PersistedMqttPublishOutMessageHandler( - SubscriptionService subscriptionService, - MessageOutFactoryService messageOutFactoryService) { - super(ExternalNetworkMqttUser.class, subscriptionService, messageOutFactoryService); + protected PersistedMqttPublishOutMessageHandler(MessageOutFactoryService messageOutFactoryService) { + super(ExternalNetworkMqttUser.class, messageOutFactoryService); this.pendingMessageHandler = new PendingMessageHandler() { @Override public boolean handleResponse(NetworkMqttUser user, TrackableMqttMessage response) { return handleReceivedResponse(user, response); } + @Override public void resend(NetworkMqttUser user, Publish publish) { tryToDeliverAgain(user, publish); @@ -45,10 +43,7 @@ protected Publish reconstruct(NetworkMqttUser user, Publish original) { } return original.with( // generate new uniq packet id per client - session.generateMessageId(), - qos(), - false, - MqttProperties.TOPIC_ALIAS_NOT_SET); + session.generateMessageId(), qos(), false, MqttProperties.TOPIC_ALIAS_NOT_SET); } @Override diff --git a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishOutMessageHandler.java b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishOutMessageHandler.java index ad09e51d..e31dd20b 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishOutMessageHandler.java +++ b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishOutMessageHandler.java @@ -4,15 +4,12 @@ import javasabr.mqtt.model.publishing.Publish; import javasabr.mqtt.network.impl.ExternalNetworkMqttUser; import javasabr.mqtt.service.MessageOutFactoryService; -import javasabr.mqtt.service.SubscriptionService; import javasabr.mqtt.service.publish.handler.PublishHandlingResult; public class Qos0MqttPublishOutMessageHandler extends AbstractMqttPublishOutMessageHandler { - public Qos0MqttPublishOutMessageHandler( - SubscriptionService subscriptionService, - MessageOutFactoryService messageOutFactoryService) { - super(ExternalNetworkMqttUser.class, subscriptionService, messageOutFactoryService); + public Qos0MqttPublishOutMessageHandler(MessageOutFactoryService messageOutFactoryService) { + super(ExternalNetworkMqttUser.class, messageOutFactoryService); } @Override diff --git a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishOutMessageHandler.java b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishOutMessageHandler.java index 4aec8dd7..826ee01d 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishOutMessageHandler.java +++ b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishOutMessageHandler.java @@ -5,14 +5,11 @@ import javasabr.mqtt.network.message.in.PublishAckMqttInMessage; import javasabr.mqtt.network.user.NetworkMqttUser; import javasabr.mqtt.service.MessageOutFactoryService; -import javasabr.mqtt.service.SubscriptionService; public class Qos1MqttPublishOutMessageHandler extends PersistedMqttPublishOutMessageHandler { - public Qos1MqttPublishOutMessageHandler( - SubscriptionService subscriptionService, - MessageOutFactoryService messageOutFactoryService) { - super(subscriptionService, messageOutFactoryService); + public Qos1MqttPublishOutMessageHandler(MessageOutFactoryService messageOutFactoryService) { + super(messageOutFactoryService); } @Override diff --git a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishOutMessageHandler.java b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishOutMessageHandler.java index cb9584f3..ecb65c0e 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishOutMessageHandler.java +++ b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishOutMessageHandler.java @@ -8,14 +8,11 @@ import javasabr.mqtt.network.message.in.PublishReceivedMqttInMessage; import javasabr.mqtt.network.user.NetworkMqttUser; import javasabr.mqtt.service.MessageOutFactoryService; -import javasabr.mqtt.service.SubscriptionService; public class Qos2MqttPublishOutMessageHandler extends PersistedMqttPublishOutMessageHandler { - public Qos2MqttPublishOutMessageHandler( - SubscriptionService subscriptionService, - MessageOutFactoryService messageOutFactoryService) { - super(subscriptionService, messageOutFactoryService); + public Qos2MqttPublishOutMessageHandler(MessageOutFactoryService messageOutFactoryService) { + super(messageOutFactoryService); } @Override diff --git a/core-service/src/test/groovy/javasabr/mqtt/service/IntegrationServiceSpecification.groovy b/core-service/src/test/groovy/javasabr/mqtt/service/IntegrationServiceSpecification.groovy index f1054d75..577388b3 100644 --- a/core-service/src/test/groovy/javasabr/mqtt/service/IntegrationServiceSpecification.groovy +++ b/core-service/src/test/groovy/javasabr/mqtt/service/IntegrationServiceSpecification.groovy @@ -44,8 +44,7 @@ abstract class IntegrationServiceSpecification extends Specification { @Shared def defaultTopicService = new DefaultTopicService() - @Shared - def defaultSubscriptionService = new InMemorySubscriptionService() + @Shared def defaultMessageOutFactoryService = new DefaultMessageOutFactoryService([ @@ -55,11 +54,14 @@ abstract class IntegrationServiceSpecification extends Specification { @Shared def defaultPublishDeliveringService = new DefaultPublishDeliveringService([ - new Qos0MqttPublishOutMessageHandler(defaultSubscriptionService, defaultMessageOutFactoryService), - new Qos1MqttPublishOutMessageHandler(defaultSubscriptionService, defaultMessageOutFactoryService), - new Qos2MqttPublishOutMessageHandler(defaultSubscriptionService, defaultMessageOutFactoryService) + new Qos0MqttPublishOutMessageHandler(defaultMessageOutFactoryService), + new Qos1MqttPublishOutMessageHandler(defaultMessageOutFactoryService), + new Qos2MqttPublishOutMessageHandler(defaultMessageOutFactoryService) ]) + @Shared + def defaultSubscriptionService = new InMemorySubscriptionService(defaultPublishDeliveringService) + @Shared def qos0MqttPublishInMessageHandler = new Qos0MqttPublishInMessageHandler( defaultSubscriptionService, diff --git a/core-service/src/test/groovy/javasabr/mqtt/service/impl/InMemorySubscriptionServiceTest.groovy b/core-service/src/test/groovy/javasabr/mqtt/service/impl/InMemorySubscriptionServiceTest.groovy index 6e97e1de..e473f024 100644 --- a/core-service/src/test/groovy/javasabr/mqtt/service/impl/InMemorySubscriptionServiceTest.groovy +++ b/core-service/src/test/groovy/javasabr/mqtt/service/impl/InMemorySubscriptionServiceTest.groovy @@ -12,8 +12,6 @@ import javasabr.rlib.collections.array.Array class InMemorySubscriptionServiceTest extends IntegrationServiceSpecification { - SubscriptionService subscriptionService = new InMemorySubscriptionService() - def "should subscribe with expected results in default settings"() { given: def serverConfig = defaultExternalServerConnectionConfig @@ -49,7 +47,7 @@ class InMemorySubscriptionServiceTest extends IntegrationServiceSpecification { true, true)) when: - def result = subscriptionService + def result = defaultSubscriptionService .subscribe(mqttUser, mqttUser.session(), subscriptions) then: result.size() == 4 @@ -104,7 +102,7 @@ class InMemorySubscriptionServiceTest extends IntegrationServiceSpecification { true, true)) when: - def result = subscriptionService + def result = defaultSubscriptionService .subscribe(mqttUser, mqttUser.session(), subscriptions) then: result.size() == 5 @@ -152,7 +150,7 @@ class InMemorySubscriptionServiceTest extends IntegrationServiceSpecification { true) def subscriptions = Array.of(sub1, sub2, sub3, sub4) when: - def result = subscriptionService + def result = defaultSubscriptionService .subscribe(mqttUser, mqttUser.session(), subscriptions) then: result.size() == 4 @@ -200,14 +198,14 @@ class InMemorySubscriptionServiceTest extends IntegrationServiceSpecification { SubscribeRetainHandling.SEND, true, true)) - subscriptionService.subscribe(mqttUser, mqttUser.session(), subscriptions) + defaultSubscriptionService.subscribe(mqttUser, mqttUser.session(), subscriptions) def topicsToUnsubscribe = Array.of( defaultTopicService.createTopicFilter(mqttUser, "topic/filter/1"), defaultTopicService.createTopicFilter(mqttUser, "topic/filter/3"), defaultTopicService.createTopicFilter(mqttUser, "topic/filter/notexist"), defaultTopicService.createTopicFilter(mqttUser, "topic/filter/invalid##")) when: - def result = subscriptionService + def result = defaultSubscriptionService .unsubscribe(mqttUser, mqttUser.session(), topicsToUnsubscribe) then: result.size() == 4 @@ -251,13 +249,13 @@ class InMemorySubscriptionServiceTest extends IntegrationServiceSpecification { defaultTopicService.createTopicFilter(mqttUser, "topic/filter/1"), defaultTopicService.createTopicFilter(mqttUser, "topic/filter/3")) when: - subscriptionService.subscribe(mqttUser, mqttUser.session(), subscriptions) + defaultSubscriptionService.subscribe(mqttUser, mqttUser.session(), subscriptions) def storedSubscriptions = activeSubscriptions.subscriptions() then: storedSubscriptions.size() == 3 storedSubscriptions == subscriptions when: - subscriptionService.unsubscribe(mqttUser, mqttUser.session(), topicsToUnsubscribe) + defaultSubscriptionService.unsubscribe(mqttUser, mqttUser.session(), topicsToUnsubscribe) storedSubscriptions = activeSubscriptions.subscriptions() then: storedSubscriptions.size() == 1 @@ -313,13 +311,13 @@ class InMemorySubscriptionServiceTest extends IntegrationServiceSpecification { subscriptions.get(1), subscriptions2.get(1)) when: - subscriptionService.subscribe(mqttUser, mqttUser.session(), subscriptions) + defaultSubscriptionService.subscribe(mqttUser, mqttUser.session(), subscriptions) def storedSubscriptions = activeSubscriptions.subscriptions() then: storedSubscriptions.size() == 3 storedSubscriptions == subscriptions when: - subscriptionService.subscribe(mqttUser, mqttUser.session(), subscriptions2) + defaultSubscriptionService.subscribe(mqttUser, mqttUser.session(), subscriptions2) storedSubscriptions = activeSubscriptions.subscriptions() then: storedSubscriptions.size() == 3 diff --git a/core-service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandlerTest.groovy b/core-service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandlerTest.groovy index 6a3bb5a0..df3539ba 100644 --- a/core-service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandlerTest.groovy +++ b/core-service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/SubscribeMqttInMessageHandlerTest.groovy @@ -1,36 +1,23 @@ package javasabr.mqtt.service.message.handler.impl import javasabr.mqtt.model.MqttVersion -import javasabr.mqtt.model.PayloadFormat import javasabr.mqtt.model.QoS -import javasabr.mqtt.model.SubscribeRetainHandling -import javasabr.mqtt.model.publishing.Publish -import javasabr.mqtt.model.reason.code.DisconnectReasonCode -import javasabr.mqtt.model.reason.code.SubscribeAckReasonCode -import javasabr.mqtt.model.subscriber.SingleSubscriber -import javasabr.mqtt.model.subscribtion.RequestedSubscription -import javasabr.mqtt.model.subscribtion.Subscription -import javasabr.mqtt.model.topic.TopicName import javasabr.mqtt.model.message.MqttMessageType import javasabr.mqtt.model.reason.code.DisconnectReasonCode import javasabr.mqtt.model.reason.code.SubscribeAckReasonCode import javasabr.mqtt.model.subscription.RequestedSubscription import javasabr.mqtt.network.message.in.SubscribeMqttInMessage import javasabr.mqtt.network.message.out.DisconnectMqtt5OutMessage -import javasabr.mqtt.network.message.out.PublishMqtt5OutMessage import javasabr.mqtt.network.message.out.SubscribeAckMqtt5OutMessage import javasabr.mqtt.network.util.ExtraErrorReasons import javasabr.mqtt.service.IntegrationServiceSpecification import javasabr.mqtt.service.TestExternalNetworkMqttUser import javasabr.rlib.collections.array.Array -import javasabr.rlib.collections.array.IntArray import javasabr.rlib.collections.array.MutableArray import javasabr.rlib.common.util.ThreadUtils import javasabr.rlib.logger.api.LoggerLevel import javasabr.rlib.logger.api.LoggerManager -import static java.nio.charset.StandardCharsets.UTF_8 - class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification { static { @@ -44,8 +31,7 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def messageHandler = new SubscribeMqttInMessageHandler( defaultSubscriptionService, defaultMessageOutFactoryService, - defaultTopicService, - publishDeliveringService) + defaultTopicService) def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser mqttUser.session(null) when: @@ -64,8 +50,7 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def messageHandler = new SubscribeMqttInMessageHandler( defaultSubscriptionService, defaultMessageOutFactoryService, - defaultTopicService, - publishDeliveringService) + defaultTopicService) def expectedMessageId = 15 def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser def session = mqttUser.session() @@ -97,8 +82,7 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def messageHandler = new SubscribeMqttInMessageHandler( defaultSubscriptionService, defaultMessageOutFactoryService, - defaultTopicService, - publishDeliveringService) + defaultTopicService) def expectedMessageId = 15 def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser when: @@ -128,8 +112,7 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def messageHandler = new SubscribeMqttInMessageHandler( defaultSubscriptionService, defaultMessageOutFactoryService, - defaultTopicService, - publishDeliveringService) + defaultTopicService) def expectedMessageId = 15 def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser when: @@ -159,8 +142,7 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def messageHandler = new SubscribeMqttInMessageHandler( defaultSubscriptionService, defaultMessageOutFactoryService, - defaultTopicService, - publishDeliveringService) + defaultTopicService) def expectedMessageId = 15 def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser when: @@ -193,8 +175,7 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def messageHandler = new SubscribeMqttInMessageHandler( defaultSubscriptionService, defaultMessageOutFactoryService, - defaultTopicService, - publishDeliveringService) + defaultTopicService) def expectedMessageId = 15 def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser when: @@ -225,8 +206,7 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def messageHandler = new SubscribeMqttInMessageHandler( defaultSubscriptionService, defaultMessageOutFactoryService, - defaultTopicService, - publishDeliveringService) + defaultTopicService) def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser when: def subscribeMessage = new SubscribeMqttInMessage(0 as byte) @@ -244,8 +224,7 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def messageHandler = new SubscribeMqttInMessageHandler( defaultSubscriptionService, defaultMessageOutFactoryService, - defaultTopicService, - publishDeliveringService) + defaultTopicService) def expectedMessageId = 15 def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser when: @@ -283,8 +262,7 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification def messageHandler = new SubscribeMqttInMessageHandler( defaultSubscriptionService, defaultMessageOutFactoryService, - defaultTopicService, - publishDeliveringService) + defaultTopicService) def expectedMessageId = 15 def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser mqttUser.returnCompletedFeatures(false) @@ -316,54 +294,4 @@ class SubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecification reasonCodes2.get(0) == SubscribeAckReasonCode.PACKET_IDENTIFIER_IN_USE subscribeAck2.messageId() == expectedMessageId } - - def "should deliver retained messages"() { - given: - def mqttConnection = mockedExternalConnection(MqttVersion.MQTT_5) - def messageHandler = new SubscribeMqttInMessageHandler( - defaultSubscriptionService, - defaultMessageOutFactoryService, - defaultTopicService, - publishDeliveringService) - def expectedMessageId = 15 - def mqttClient = mqttConnection.client() as TestExternalMqttClient - mqttClient.returnCompletedFeatures(false) - when: - Publish publish = new Publish( - 1, - QoS.AT_MOST_ONCE, - TopicName.valueOf("topic2"), - null, - "payload".getBytes(UTF_8), - false, - true, - null, - IntArray.of(30), - null, - 60000, - 1, - PayloadFormat.UTF8_STRING, - Array.of()); - Subscription subscription = new Subscription( - defaultTopicService.createTopicFilter(mqttClient, "topic2"), - 30, - QoS.EXACTLY_ONCE, - SubscribeRetainHandling.SEND, - true, - true); - SingleSubscriber subscriber = new SingleSubscriber(mqttClient, subscription); - publishDeliveringService.startDelivering(publish, subscriber) - - def subscribeMessage = new SubscribeMqttInMessage(SubscribeMqttInMessage.MESSAGE_FLAGS) {{ - this.messageId = 1 - this.subscriptions = MutableArray.ofType(RequestedSubscription) - this.subscriptions.addAll(Array.of(RequestedSubscription.minimal("topic2", QoS.EXACTLY_ONCE))) - }} - messageHandler.processValidMessage(mqttConnection, subscribeMessage) - then: - mqttClient.nextSentMessage(PublishMqtt5OutMessage) - mqttClient.nextSentMessage(SubscribeAckMqtt5OutMessage) - def retainedMessageDelivery = mqttClient.nextSentMessage(PublishMqtt5OutMessage) - retainedMessageDelivery.messageId() == 2 - } } diff --git a/core-service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/UnsubscribeMqttInMessageHandlerTest.groovy b/core-service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/UnsubscribeMqttInMessageHandlerTest.groovy index 2ebca07c..09a94c0a 100644 --- a/core-service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/UnsubscribeMqttInMessageHandlerTest.groovy +++ b/core-service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/UnsubscribeMqttInMessageHandlerTest.groovy @@ -69,7 +69,7 @@ class UnsubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecificatio def "should response with expected results"() { given: def mqttConnection = mockedExternalConnection(MqttVersion.MQTT_5) - def subscriptionService = new InMemorySubscriptionService() + def subscriptionService = new InMemorySubscriptionService(defaultPublishDeliveringService) def messageHandler = new UnsubscribeMqttInMessageHandler( subscriptionService, defaultMessageOutFactoryService, From be8070bc9ce241846e7b174fd0ae9eb6302c6072 Mon Sep 17 00:00:00 2001 From: Maksim Kashapov <56276969+crazyrokr@users.noreply.github.com> Date: Tue, 2 Dec 2025 23:56:47 +0100 Subject: [PATCH 7/9] [broker-30] SubscriberNode refactoring --- .../model/subscriber/tree/SubscriberNode.java | 85 ++++++++----------- .../subscriber/tree/SubscriberTreeBase.java | 30 +++---- 2 files changed, 45 insertions(+), 70 deletions(-) diff --git a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java index 4a6579c2..3a0ec8bc 100644 --- a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java +++ b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java @@ -22,7 +22,7 @@ @Getter(AccessLevel.PACKAGE) @Accessors(fluent = true, chain = false) @FieldDefaults(level = AccessLevel.PRIVATE) -class SubscriberNode extends SubscriberTreeBase { +public class SubscriberNode extends SubscriberTreeBase { private final static Supplier SUBSCRIBER_NODE_FACTORY = SubscriberNode::new; @@ -39,7 +39,7 @@ class SubscriberNode extends SubscriberTreeBase { * @return the previous subscription from the same owner */ @Nullable - public SingleSubscriber subscribe(int level, MqttUser owner, Subscription subscription, TopicFilter topicFilter) { + protected SingleSubscriber subscribe(int level, MqttUser owner, Subscription subscription, TopicFilter topicFilter) { if (level == topicFilter.levelsCount()) { return addSubscriber(getOrCreateSubscribers(), owner, subscription, topicFilter); } @@ -47,7 +47,7 @@ public SingleSubscriber subscribe(int level, MqttUser owner, Subscription subscr return childNode.subscribe(level + 1, owner, subscription, topicFilter); } - public boolean unsubscribe(int level, MqttUser owner, TopicFilter topicFilter) { + protected boolean unsubscribe(int level, MqttUser owner, TopicFilter topicFilter) { if (level == topicFilter.levelsCount()) { return removeSubscriber(subscribers(), owner, topicFilter); } @@ -56,51 +56,28 @@ public boolean unsubscribe(int level, MqttUser owner, TopicFilter topicFilter) { } protected void matchesTo(int level, TopicName topicName, int lastLevel, MutableArray container) { - exactlyTopicMatch(level, topicName, lastLevel, container); - singleWildcardTopicMatch(level, topicName, lastLevel, container); - multiWildcardTopicMatch(container); + collectMatchingSubscribers(topicName.segment(level), level, topicName, lastLevel, container); + collectMatchingSubscribers(TopicFilter.SINGLE_LEVEL_WILDCARD, level, topicName, lastLevel, container); + collectMatchingSubscribers(TopicFilter.MULTI_LEVEL_WILDCARD, level, topicName, lastLevel, container); } - private void exactlyTopicMatch( + private void collectMatchingSubscribers( + String segment, int level, TopicName topicName, int lastLevel, MutableArray result) { - String segment = topicName.segment(level); SubscriberNode subscriberNode = childNode(segment); if (subscriberNode == null) { return; } - if (level == lastLevel) { + if (level == lastLevel || TopicFilter.MULTI_LEVEL_WILDCARD.equals(segment)) { appendSubscribersTo(result, subscriberNode); } else if (level < lastLevel) { subscriberNode.matchesTo(level + 1, topicName, lastLevel, result); } } - private void singleWildcardTopicMatch( - int level, - TopicName topicName, - int lastLevel, - MutableArray result) { - SubscriberNode subscriberNode = childNode(TopicFilter.SINGLE_LEVEL_WILDCARD); - if (subscriberNode == null) { - return; - } - if (level == lastLevel) { - appendSubscribersTo(result, subscriberNode); - } else if (level < lastLevel) { - subscriberNode.matchesTo(level + 1, topicName, lastLevel, result); - } - } - - private void multiWildcardTopicMatch(MutableArray result) { - SubscriberNode subscriberNode = childNode(TopicFilter.MULTI_LEVEL_WILDCARD); - if (subscriberNode != null) { - appendSubscribersTo(result, subscriberNode); - } - } - private SubscriberNode getOrCreateChildNode(String segment) { LockableRefToRefDictionary childNodes = getOrCreateChildNodes(); long stamp = childNodes.readLock(); @@ -122,40 +99,46 @@ private SubscriberNode getOrCreateChildNode(String segment) { @Nullable private SubscriberNode childNode(String segment) { - LockableRefToRefDictionary childNodes = childNodes(); - if (childNodes == null) { + LockableRefToRefDictionary localChildNodes = childNodes; + if (localChildNodes == null) { return null; } - long stamp = childNodes.readLock(); + long stamp = localChildNodes.readLock(); try { - return childNodes.get(segment); + return localChildNodes.get(segment); } finally { - childNodes.readUnlock(stamp); + localChildNodes.readUnlock(stamp); } } private LockableRefToRefDictionary getOrCreateChildNodes() { - if (childNodes == null) { - synchronized (this) { - if (childNodes == null) { - childNodes = DictionaryFactory.stampedLockBasedRefToRefDictionary(); - } + LockableRefToRefDictionary localChildNodes = childNodes; + if (localChildNodes != null) { + return localChildNodes; + } + synchronized (this) { + localChildNodes = childNodes; + if (localChildNodes == null) { + localChildNodes = DictionaryFactory.stampedLockBasedRefToRefDictionary(); + childNodes = localChildNodes; } + return localChildNodes; } - //noinspection ConstantConditions - return childNodes; } private LockableArray getOrCreateSubscribers() { - if (subscribers == null) { - synchronized (this) { - if (subscribers == null) { - subscribers = ArrayFactory.stampedLockBasedArray(Subscriber.class); - } + LockableArray localSubscribers = subscribers; + if (localSubscribers != null) { + return localSubscribers; + } + synchronized (this) { + localSubscribers = subscribers; + if (localSubscribers == null) { + localSubscribers = ArrayFactory.stampedLockBasedArray(Subscriber.class); + subscribers = localSubscribers; } + return localSubscribers; } - //noinspection ConstantConditions - return subscribers; } @Override diff --git a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberTreeBase.java b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberTreeBase.java index 972b696b..9ae8c953 100644 --- a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberTreeBase.java +++ b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberTreeBase.java @@ -45,9 +45,7 @@ protected static SingleSubscriber addSubscriber( } @Nullable - private static SingleSubscriber removePreviousIfExist( - LockableArray subscribers, - MqttUser user) { + private static SingleSubscriber removePreviousIfExist(LockableArray subscribers, MqttUser user) { int index = subscribers.indexOf(Subscriber::resolveUser, user); if (index < 0) { return null; @@ -84,10 +82,7 @@ protected static void appendSubscribersTo(MutableArray result, long stamp = subscribers.readLock(); try { for (Subscriber subscriber : subscribers) { - SingleSubscriber singleSubscriber = subscriber.resolveSingle(); - if (removeDuplicateWithLowerQoS(result, singleSubscriber)) { - result.add(singleSubscriber); - } + addOrReplaceIfLowerQos(result, subscriber); } } finally { subscribers.readUnlock(stamp); @@ -141,23 +136,20 @@ private static boolean isSharedSubscriberWithGroup(Subscriber subscriber, String return subscriber instanceof SharedSubscriber shared && Objects.equals(group, shared.group()); } - private static boolean removeDuplicateWithLowerQoS( - MutableArray result, SingleSubscriber candidate) { - + private static void addOrReplaceIfLowerQos(MutableArray result, Subscriber subscriber) { + SingleSubscriber candidate = subscriber.resolveSingle(); int found = result.indexOf(SingleSubscriber::user, candidate.user()); if (found == -1) { - return true; + result.add(candidate); + return; } - QoS candidateQos = candidate.qos(); - SingleSubscriber exist = result.get(found); - QoS existeQos = exist.qos(); - - if (existeQos.ordinal() < candidateQos.ordinal()) { + QoS existedQos = result + .get(found) + .qos(); + if (existedQos.ordinal() < candidateQos.ordinal()) { result.remove(found); - return true; + result.add(candidate); } - - return false; } } From 0c219dd553fdac05d1084630c7147063902a578c Mon Sep 17 00:00:00 2001 From: Maksim Kashapov <56276969+crazyrokr@users.noreply.github.com> Date: Tue, 2 Dec 2025 23:59:20 +0100 Subject: [PATCH 8/9] [broker-30] Revert wrong access level modifier --- .../javasabr/mqtt/model/subscriber/tree/SubscriberNode.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java index 3a0ec8bc..2bdb4314 100644 --- a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java +++ b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java @@ -22,7 +22,7 @@ @Getter(AccessLevel.PACKAGE) @Accessors(fluent = true, chain = false) @FieldDefaults(level = AccessLevel.PRIVATE) -public class SubscriberNode extends SubscriberTreeBase { +class SubscriberNode extends SubscriberTreeBase { private final static Supplier SUBSCRIBER_NODE_FACTORY = SubscriberNode::new; From 62900cac43c9f24237aa9df796de5acad0155b10 Mon Sep 17 00:00:00 2001 From: Maksim Kashapov <56276969+crazyrokr@users.noreply.github.com> Date: Wed, 3 Dec 2025 08:50:12 +0100 Subject: [PATCH 9/9] [broker-30] Avoid tree branch locking --- .../model/subscriber/tree/SubscriberNode.java | 4 +- .../mqtt/model/topic/AbstractTopic.java | 2 +- .../model/topic/tree/RetainedMessageNode.java | 91 +++++++++---------- 3 files changed, 46 insertions(+), 51 deletions(-) diff --git a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java index 2bdb4314..fdcad600 100644 --- a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java +++ b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java @@ -67,7 +67,7 @@ private void collectMatchingSubscribers( TopicName topicName, int lastLevel, MutableArray result) { - SubscriberNode subscriberNode = childNode(segment); + SubscriberNode subscriberNode = getChildNode(segment); if (subscriberNode == null) { return; } @@ -98,7 +98,7 @@ private SubscriberNode getOrCreateChildNode(String segment) { } @Nullable - private SubscriberNode childNode(String segment) { + private SubscriberNode getChildNode(String segment) { LockableRefToRefDictionary localChildNodes = childNodes; if (localChildNodes == null) { return null; diff --git a/model/src/main/java/javasabr/mqtt/model/topic/AbstractTopic.java b/model/src/main/java/javasabr/mqtt/model/topic/AbstractTopic.java index 737b0bf6..cea0f54d 100644 --- a/model/src/main/java/javasabr/mqtt/model/topic/AbstractTopic.java +++ b/model/src/main/java/javasabr/mqtt/model/topic/AbstractTopic.java @@ -41,7 +41,7 @@ public int levelsCount() { return segments.length; } - public String lastSegment() { + String lastSegment() { return segments[segments.length - 1]; } diff --git a/model/src/main/java/javasabr/mqtt/model/topic/tree/RetainedMessageNode.java b/model/src/main/java/javasabr/mqtt/model/topic/tree/RetainedMessageNode.java index cc6b4b68..d02fc9e4 100644 --- a/model/src/main/java/javasabr/mqtt/model/topic/tree/RetainedMessageNode.java +++ b/model/src/main/java/javasabr/mqtt/model/topic/tree/RetainedMessageNode.java @@ -1,11 +1,6 @@ package javasabr.mqtt.model.topic.tree; -import static javasabr.mqtt.model.topic.TopicFilter.MULTI_LEVEL_WILDCARD; -import static javasabr.mqtt.model.topic.TopicFilter.SINGLE_LEVEL_WILDCARD; - import java.util.LinkedList; -import java.util.Objects; -import java.util.PriorityQueue; import java.util.Queue; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; @@ -13,6 +8,7 @@ import javasabr.mqtt.model.publishing.Publish; import javasabr.mqtt.model.topic.TopicFilter; import javasabr.mqtt.model.topic.TopicName; +import javasabr.rlib.collections.array.ArrayFactory; import javasabr.rlib.collections.array.MutableArray; import javasabr.rlib.collections.dictionary.DictionaryFactory; import javasabr.rlib.collections.dictionary.LockableRefToRefDictionary; @@ -39,49 +35,47 @@ class RetainedMessageNode { public void retainMessage(int level, Publish message, TopicName topicName) { var child = getOrCreateChildNode(topicName.segment(level)); - boolean isLeaf = (level + 1 == topicName.levelsCount()); - if (isLeaf) { - if (Objects.equals(message.topicName().lastSegment(), topicName.lastSegment())) { - child.retainedMessage.set(message.payload().length == 0 ? null : message); - } + boolean isLastLevel = (level + 1 == topicName.levelsCount()); + if (isLastLevel) { + child.retainedMessage.set(message.payload().length == 0 ? null : message); } else { child.retainMessage(level + 1, message, topicName); } } public void collectRetainedMessages(int level, TopicFilter topicFilter, MutableArray result) { - String segment = topicFilter.segment(level); - if (Objects.equals(segment, MULTI_LEVEL_WILDCARD)) { - collectAllMessages(this, result); - return; - } else if (Objects.equals(segment, SINGLE_LEVEL_WILDCARD)) { - var childNodes = childNodes(); - if (childNodes == null) { - return; - } - long stamp = childNodes.readLock(); - try { - for (RetainedMessageNode childNode : childNodes) { - childNode.collectRetainedMessages(level + 1, topicFilter, result); - } - } finally { - childNodes.readUnlock(stamp); + if (level == topicFilter.levelsCount()) { + Publish publish = retainedMessage.get(); + if (publish != null) { + result.add(publish); } return; } - int lastLevel = topicFilter.levelsCount() - 1; - RetainedMessageNode retainedMessageNode = childNode(segment); - if (retainedMessageNode == null || level > lastLevel) { + String segment = topicFilter.segment(level); + boolean isOneCharSegment = segment.length() == 1; + if (isOneCharSegment && segment.charAt(0) == TopicFilter.MULTI_LEVEL_WILDCARD_CHAR) { + collectAllMessages(this, result); return; } - boolean isLeaf = (level == lastLevel); - if (isLeaf) { - Publish publish = retainedMessageNode.retainedMessage.get(); - if(publish != null && Objects.equals(segment, publish.topicName().lastSegment())){ - result.add(publish); + if (isOneCharSegment && segment.charAt(0) == TopicFilter.SINGLE_LEVEL_WILDCARD_CHAR) { + var localChildNodes = childNodes; + if (localChildNodes != null) { + var nextChildNodes = ArrayFactory.mutableArray(RetainedMessageNode.class); + long stamp = localChildNodes.readLock(); + try { + localChildNodes.values(nextChildNodes); + } finally { + localChildNodes.readUnlock(stamp); + } + for (RetainedMessageNode childNode : nextChildNodes) { + childNode.collectRetainedMessages(level + 1, topicFilter, result); + } } } else { - retainedMessageNode.collectRetainedMessages(level + 1, topicFilter, result); + RetainedMessageNode retainedMessageNode = getChildNode(segment); + if (retainedMessageNode != null) { + retainedMessageNode.collectRetainedMessages(level + 1, topicFilter, result); + } } } @@ -100,9 +94,7 @@ private void collectAllMessages(RetainedMessageNode node, MutableArray } long stamp = childNodes.readLock(); try { - for (RetainedMessageNode n : childNodes) { - queue.add(n); - } + childNodes.values(queue); } finally { childNodes.readUnlock(stamp); } @@ -110,8 +102,8 @@ private void collectAllMessages(RetainedMessageNode node, MutableArray } @Nullable - private RetainedMessageNode childNode(String segment) { - LockableRefToRefDictionary childNodes = childNodes(); + private RetainedMessageNode getChildNode(String segment) { + var childNodes = childNodes(); if (childNodes == null) { return null; } @@ -124,7 +116,7 @@ private RetainedMessageNode childNode(String segment) { } private RetainedMessageNode getOrCreateChildNode(String segment) { - LockableRefToRefDictionary childNodes = getOrCreateChildNodes(); + var childNodes = getOrCreateChildNodes(); long stamp = childNodes.readLock(); try { RetainedMessageNode topicFilterNode = childNodes.get(segment); @@ -143,15 +135,18 @@ private RetainedMessageNode getOrCreateChildNode(String segment) { } private LockableRefToRefDictionary getOrCreateChildNodes() { - if (childNodes == null) { - synchronized (this) { - if (childNodes == null) { - childNodes = DictionaryFactory.stampedLockBasedRefToRefDictionary(); - } + var current = childNodes; + if (current != null) { + return current; + } + synchronized (this) { + current = childNodes; + if (current == null) { + current = DictionaryFactory.stampedLockBasedRefToRefDictionary(); + childNodes = current; } + return current; } - //noinspection ConstantConditions - return childNodes; } @Override