diff --git a/application/build.gradle b/application/build.gradle index b93e391f..3f0b7d88 100644 --- a/application/build.gradle +++ b/application/build.gradle @@ -24,4 +24,8 @@ tasks.withType(GroovyCompile).configureEach { configurations.each { it.exclude group: "org.slf4j", module: "slf4j-log4j12" it.exclude group: "org.springframework.boot", module: "spring-boot-starter-logging" -} \ No newline at end of file +} + +bootJar { + mainClass = 'javasabr.mqtt.broker.application.MqttBrokerApplication' +} diff --git a/application/src/main/java/javasabr/mqtt/broker/application/config/MqttBrokerSpringConfig.java b/application/src/main/java/javasabr/mqtt/broker/application/config/MqttBrokerSpringConfig.java index eb4c4470..756014ba 100644 --- a/application/src/main/java/javasabr/mqtt/broker/application/config/MqttBrokerSpringConfig.java +++ b/application/src/main/java/javasabr/mqtt/broker/application/config/MqttBrokerSpringConfig.java @@ -93,13 +93,14 @@ CredentialSource credentialSource( @Bean AuthenticationService authenticationService( CredentialSource credentialSource, - @Value("${authentication.allow.anonymous:false}") boolean allowAnonymousAuth) { + @Value("${authentication.allow.anonymous:false}") + boolean allowAnonymousAuth) { return new SimpleAuthenticationService(credentialSource, allowAnonymousAuth); } @Bean - SubscriptionService subscriptionService() { - return new InMemorySubscriptionService(); + SubscriptionService subscriptionService(PublishDeliveringService publishDeliveringService) { + return new InMemorySubscriptionService(publishDeliveringService); } @Bean @@ -153,10 +154,7 @@ MqttInMessageHandler publishMqttInMessageHandler( PublishReceivingService publishReceivingService, MessageOutFactoryService messageOutFactoryService, TopicService topicService) { - return new PublishMqttInMessageHandler( - publishReceivingService, - messageOutFactoryService, - topicService); + return new PublishMqttInMessageHandler(publishReceivingService, messageOutFactoryService, topicService); } @Bean @@ -187,10 +185,7 @@ MqttInMessageHandler unsubscribeMqttInMessageHandler( SubscriptionService subscriptionService, MessageOutFactoryService messageOutFactoryService, TopicService topicService) { - return new UnsubscribeMqttInMessageHandler( - subscriptionService, - messageOutFactoryService, - topicService); + return new UnsubscribeMqttInMessageHandler(subscriptionService, messageOutFactoryService, topicService); } @Bean @@ -199,24 +194,18 @@ ConnectionService externalMqttConnectionService(Collection deliverRetainedMessages(TopicFilter topicFilter, SingleSubscriber subscriber); } diff --git a/core-service/src/main/java/javasabr/mqtt/service/SubscriptionService.java b/core-service/src/main/java/javasabr/mqtt/service/SubscriptionService.java index aead2590..750fe0f1 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/SubscriptionService.java +++ b/core-service/src/main/java/javasabr/mqtt/service/SubscriptionService.java @@ -4,7 +4,6 @@ import javasabr.mqtt.model.reason.code.UnsubscribeAckReasonCode; import javasabr.mqtt.model.session.MqttSession; import javasabr.mqtt.model.subscriber.SingleSubscriber; -import javasabr.mqtt.model.subscriber.Subscriber; import javasabr.mqtt.model.subscription.Subscription; import javasabr.mqtt.model.topic.TopicFilter; import javasabr.mqtt.model.topic.TopicName; @@ -17,8 +16,6 @@ */ public interface SubscriptionService { - NetworkMqttUser resolveClient(Subscriber subscriber); - default Array findSubscribers(TopicName topicName) { return findSubscribersTo(MutableArray.ofType(SingleSubscriber.class), topicName); } diff --git a/core-service/src/main/java/javasabr/mqtt/service/impl/DefaultPublishDeliveringService.java b/core-service/src/main/java/javasabr/mqtt/service/impl/DefaultPublishDeliveringService.java index 8e0a7529..9c45b9b7 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/impl/DefaultPublishDeliveringService.java +++ b/core-service/src/main/java/javasabr/mqtt/service/impl/DefaultPublishDeliveringService.java @@ -4,9 +4,13 @@ import javasabr.mqtt.model.QoS; import javasabr.mqtt.model.publishing.Publish; import javasabr.mqtt.model.subscriber.SingleSubscriber; +import javasabr.mqtt.model.topic.TopicFilter; +import javasabr.mqtt.model.topic.tree.ConcurrentRetainedMessageTree; import javasabr.mqtt.service.PublishDeliveringService; import javasabr.mqtt.service.publish.handler.MqttPublishOutMessageHandler; import javasabr.mqtt.service.publish.handler.PublishHandlingResult; +import javasabr.rlib.collections.array.Array; +import javasabr.rlib.collections.array.MutableArray; import lombok.AccessLevel; import lombok.CustomLog; import lombok.experimental.FieldDefaults; @@ -18,6 +22,7 @@ public class DefaultPublishDeliveringService implements PublishDeliveringService @Nullable MqttPublishOutMessageHandler[] publishOutMessageHandlers; + ConcurrentRetainedMessageTree retainedMessageTree; public DefaultPublishDeliveringService( Collection knownPublishOutHandlers) { @@ -39,7 +44,7 @@ public DefaultPublishDeliveringService( } handlers[qos.level()] = knownPublishOutHandler; } - + this.retainedMessageTree = new ConcurrentRetainedMessageTree(); this.publishOutMessageHandlers = handlers; log.info(publishOutMessageHandlers, DefaultPublishDeliveringService::buildServiceDescription); } @@ -47,6 +52,7 @@ public DefaultPublishDeliveringService( @Override public PublishHandlingResult startDelivering(Publish publish, SingleSubscriber subscriber) { try { + retainedMessageTree.retainMessage(publish); //noinspection DataFlowIssue return publishOutMessageHandlers[subscriber.qos().level()].handle(publish, subscriber); } catch (IndexOutOfBoundsException | NullPointerException ex) { @@ -55,6 +61,16 @@ public PublishHandlingResult startDelivering(Publish publish, SingleSubscriber s } } + @Override + public Array deliverRetainedMessages(TopicFilter topicFilter, SingleSubscriber subscriber) { + Array retainedMessage = retainedMessageTree.getRetainedMessage(topicFilter); + MutableArray result = MutableArray.ofType(PublishHandlingResult.class); + for (Publish message : retainedMessage) { + result.add(startDelivering(message, subscriber)); + } + return result; + } + private static String buildServiceDescription( @Nullable MqttPublishOutMessageHandler[] publishOutMessageHandlers) { var builder = new StringBuilder(); diff --git a/core-service/src/main/java/javasabr/mqtt/service/impl/InMemorySubscriptionService.java b/core-service/src/main/java/javasabr/mqtt/service/impl/InMemorySubscriptionService.java index fcc635ba..1c2e3e74 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/impl/InMemorySubscriptionService.java +++ b/core-service/src/main/java/javasabr/mqtt/service/impl/InMemorySubscriptionService.java @@ -1,22 +1,26 @@ package javasabr.mqtt.service.impl; +import static javasabr.mqtt.model.SubscribeRetainHandling.SEND; +import static javasabr.mqtt.model.SubscribeRetainHandling.SEND_IF_SUBSCRIPTION_DOES_NOT_EXIST; import static javasabr.mqtt.model.reason.code.UnsubscribeAckReasonCode.NO_SUBSCRIPTION_EXISTED; import static javasabr.mqtt.model.reason.code.UnsubscribeAckReasonCode.SUCCESS; import javasabr.mqtt.model.MqttClientConnectionConfig; +import javasabr.mqtt.model.MqttUser; import javasabr.mqtt.model.reason.code.SubscribeAckReasonCode; import javasabr.mqtt.model.reason.code.UnsubscribeAckReasonCode; import javasabr.mqtt.model.session.ActiveSubscriptions; import javasabr.mqtt.model.session.MqttSession; import javasabr.mqtt.model.subscriber.SingleSubscriber; -import javasabr.mqtt.model.subscriber.Subscriber; import javasabr.mqtt.model.subscriber.tree.ConcurrentSubscriberTree; import javasabr.mqtt.model.subscription.Subscription; import javasabr.mqtt.model.topic.SharedTopicFilter; import javasabr.mqtt.model.topic.TopicFilter; import javasabr.mqtt.model.topic.TopicName; import javasabr.mqtt.network.user.NetworkMqttUser; +import javasabr.mqtt.service.PublishDeliveringService; import javasabr.mqtt.service.SubscriptionService; +import javasabr.mqtt.service.publish.handler.PublishHandlingResult; import javasabr.rlib.collections.array.Array; import javasabr.rlib.collections.array.ArrayFactory; import javasabr.rlib.collections.array.MutableArray; @@ -31,18 +35,12 @@ @FieldDefaults(level = AccessLevel.PRIVATE, makeFinal = true) public class InMemorySubscriptionService implements SubscriptionService { + PublishDeliveringService publishDeliveringService; ConcurrentSubscriberTree subscriberTree; - public InMemorySubscriptionService() { + public InMemorySubscriptionService(PublishDeliveringService publishDeliveringService) { this.subscriberTree = new ConcurrentSubscriberTree(); - } - - @Override - public NetworkMqttUser resolveClient(Subscriber subscriber) { - if (subscriber instanceof SingleSubscriber single) { - return (NetworkMqttUser) single.user(); - } - throw new IllegalArgumentException("Unexpected subscriber: " + subscriber); + this.publishDeliveringService = publishDeliveringService; } @Override @@ -84,6 +82,10 @@ private SubscribeAckReasonCode addSubscription(NetworkMqttUser user, MqttSession if (previous != null) { activeSubscriptions.remove(previous.subscription()); } + if ((subscription.retainHandling() == SEND_IF_SUBSCRIPTION_DOES_NOT_EXIST && previous != null) + || subscription.retainHandling() == SEND) { + sendRetainedMessages(user, subscription); + } activeSubscriptions.add(subscription); return subscription.qos().subscribeAckReasonCode(); } @@ -137,4 +139,39 @@ public void restoreSubscriptions(NetworkMqttUser user, MqttSession session) { subscriberTree.subscribe(user, subscription); } } + + private void sendRetainedMessages(MqttUser user, Subscription subscription) { + int count = 0; + PublishHandlingResult errorResult = null; + if (subscription + .qos() + .subscribeAckReasonCode() + .ordinal() > 2) { + // TODO handle error ? + return; + } + SingleSubscriber singleSubscriber = new SingleSubscriber(user, subscription); + var results = publishDeliveringService.deliverRetainedMessages(subscription.topicFilter(), singleSubscriber); + for (PublishHandlingResult result : results) { + if (result.error()) { + errorResult = result; + } else if (result == PublishHandlingResult.SUCCESS) { + count++; + } + if (errorResult != null) { + log.debug( + user.clientId(), + errorResult, + "[%s] Found final error:[%s] during sending retained messages"::formatted); + // TODO handleError(client, publish, errorResult); + } else { + log.debug( + user.clientId(), + count, + "[%s] Successfully started delivering retained messages to [%s] subscribers"::formatted); + // TODO handleSuccessfulResult(client, publish, count); + } + + } + } } diff --git a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/AbstractMqttPublishOutMessageHandler.java b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/AbstractMqttPublishOutMessageHandler.java index 613e1f96..2b329bc6 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/AbstractMqttPublishOutMessageHandler.java +++ b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/AbstractMqttPublishOutMessageHandler.java @@ -3,10 +3,10 @@ import javasabr.mqtt.model.MqttProperties; import javasabr.mqtt.model.publishing.Publish; import javasabr.mqtt.model.subscriber.SingleSubscriber; +import javasabr.mqtt.model.subscriber.Subscriber; import javasabr.mqtt.network.message.out.MqttOutMessage; import javasabr.mqtt.network.user.NetworkMqttUser; import javasabr.mqtt.service.MessageOutFactoryService; -import javasabr.mqtt.service.SubscriptionService; import javasabr.mqtt.service.publish.handler.MqttPublishOutMessageHandler; import javasabr.mqtt.service.publish.handler.PublishHandlingResult; import lombok.AccessLevel; @@ -22,12 +22,18 @@ public abstract class AbstractMqttPublishOutMessageHandler expectedUser; - SubscriptionService subscriptionService; MessageOutFactoryService messageOutFactoryService; + private static NetworkMqttUser resolveClient(Subscriber subscriber) { + if (subscriber instanceof SingleSubscriber single) { + return (NetworkMqttUser) single.user(); + } + throw new IllegalArgumentException("Unexpected subscriber: " + subscriber); + } + @Override public PublishHandlingResult handle(Publish publish, SingleSubscriber subscriber) { - NetworkMqttUser user = subscriptionService.resolveClient(subscriber); + NetworkMqttUser user = resolveClient(subscriber); if (!expectedUser.isInstance(user)) { log.warning(user, "Accepted not expected client:[%s]"::formatted); return PublishHandlingResult.NOT_EXPECTED_CLIENT; diff --git a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/PersistedMqttPublishOutMessageHandler.java b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/PersistedMqttPublishOutMessageHandler.java index 13a5b252..9f572f5b 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/PersistedMqttPublishOutMessageHandler.java +++ b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/PersistedMqttPublishOutMessageHandler.java @@ -8,7 +8,6 @@ import javasabr.mqtt.network.session.NetworkMqttSession.PendingMessageHandler; import javasabr.mqtt.network.user.NetworkMqttUser; import javasabr.mqtt.service.MessageOutFactoryService; -import javasabr.mqtt.service.SubscriptionService; import javasabr.mqtt.service.publish.handler.PublishHandlingResult; import lombok.AccessLevel; import lombok.experimental.FieldDefaults; @@ -20,15 +19,14 @@ public abstract class PersistedMqttPublishOutMessageHandler extends PendingMessageHandler pendingMessageHandler; - protected PersistedMqttPublishOutMessageHandler( - SubscriptionService subscriptionService, - MessageOutFactoryService messageOutFactoryService) { - super(ExternalNetworkMqttUser.class, subscriptionService, messageOutFactoryService); + protected PersistedMqttPublishOutMessageHandler(MessageOutFactoryService messageOutFactoryService) { + super(ExternalNetworkMqttUser.class, messageOutFactoryService); this.pendingMessageHandler = new PendingMessageHandler() { @Override public boolean handleResponse(NetworkMqttUser user, TrackableMqttMessage response) { return handleReceivedResponse(user, response); } + @Override public void resend(NetworkMqttUser user, Publish publish) { tryToDeliverAgain(user, publish); @@ -45,10 +43,7 @@ protected Publish reconstruct(NetworkMqttUser user, Publish original) { } return original.with( // generate new uniq packet id per client - session.generateMessageId(), - qos(), - false, - MqttProperties.TOPIC_ALIAS_NOT_SET); + session.generateMessageId(), qos(), false, MqttProperties.TOPIC_ALIAS_NOT_SET); } @Override diff --git a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishOutMessageHandler.java b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishOutMessageHandler.java index ad09e51d..e31dd20b 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishOutMessageHandler.java +++ b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishOutMessageHandler.java @@ -4,15 +4,12 @@ import javasabr.mqtt.model.publishing.Publish; import javasabr.mqtt.network.impl.ExternalNetworkMqttUser; import javasabr.mqtt.service.MessageOutFactoryService; -import javasabr.mqtt.service.SubscriptionService; import javasabr.mqtt.service.publish.handler.PublishHandlingResult; public class Qos0MqttPublishOutMessageHandler extends AbstractMqttPublishOutMessageHandler { - public Qos0MqttPublishOutMessageHandler( - SubscriptionService subscriptionService, - MessageOutFactoryService messageOutFactoryService) { - super(ExternalNetworkMqttUser.class, subscriptionService, messageOutFactoryService); + public Qos0MqttPublishOutMessageHandler(MessageOutFactoryService messageOutFactoryService) { + super(ExternalNetworkMqttUser.class, messageOutFactoryService); } @Override diff --git a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishOutMessageHandler.java b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishOutMessageHandler.java index 4aec8dd7..826ee01d 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishOutMessageHandler.java +++ b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishOutMessageHandler.java @@ -5,14 +5,11 @@ import javasabr.mqtt.network.message.in.PublishAckMqttInMessage; import javasabr.mqtt.network.user.NetworkMqttUser; import javasabr.mqtt.service.MessageOutFactoryService; -import javasabr.mqtt.service.SubscriptionService; public class Qos1MqttPublishOutMessageHandler extends PersistedMqttPublishOutMessageHandler { - public Qos1MqttPublishOutMessageHandler( - SubscriptionService subscriptionService, - MessageOutFactoryService messageOutFactoryService) { - super(subscriptionService, messageOutFactoryService); + public Qos1MqttPublishOutMessageHandler(MessageOutFactoryService messageOutFactoryService) { + super(messageOutFactoryService); } @Override diff --git a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishOutMessageHandler.java b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishOutMessageHandler.java index cb9584f3..ecb65c0e 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishOutMessageHandler.java +++ b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishOutMessageHandler.java @@ -8,14 +8,11 @@ import javasabr.mqtt.network.message.in.PublishReceivedMqttInMessage; import javasabr.mqtt.network.user.NetworkMqttUser; import javasabr.mqtt.service.MessageOutFactoryService; -import javasabr.mqtt.service.SubscriptionService; public class Qos2MqttPublishOutMessageHandler extends PersistedMqttPublishOutMessageHandler { - public Qos2MqttPublishOutMessageHandler( - SubscriptionService subscriptionService, - MessageOutFactoryService messageOutFactoryService) { - super(subscriptionService, messageOutFactoryService); + public Qos2MqttPublishOutMessageHandler(MessageOutFactoryService messageOutFactoryService) { + super(messageOutFactoryService); } @Override diff --git a/core-service/src/test/groovy/javasabr/mqtt/service/IntegrationServiceSpecification.groovy b/core-service/src/test/groovy/javasabr/mqtt/service/IntegrationServiceSpecification.groovy index f1054d75..577388b3 100644 --- a/core-service/src/test/groovy/javasabr/mqtt/service/IntegrationServiceSpecification.groovy +++ b/core-service/src/test/groovy/javasabr/mqtt/service/IntegrationServiceSpecification.groovy @@ -44,8 +44,7 @@ abstract class IntegrationServiceSpecification extends Specification { @Shared def defaultTopicService = new DefaultTopicService() - @Shared - def defaultSubscriptionService = new InMemorySubscriptionService() + @Shared def defaultMessageOutFactoryService = new DefaultMessageOutFactoryService([ @@ -55,11 +54,14 @@ abstract class IntegrationServiceSpecification extends Specification { @Shared def defaultPublishDeliveringService = new DefaultPublishDeliveringService([ - new Qos0MqttPublishOutMessageHandler(defaultSubscriptionService, defaultMessageOutFactoryService), - new Qos1MqttPublishOutMessageHandler(defaultSubscriptionService, defaultMessageOutFactoryService), - new Qos2MqttPublishOutMessageHandler(defaultSubscriptionService, defaultMessageOutFactoryService) + new Qos0MqttPublishOutMessageHandler(defaultMessageOutFactoryService), + new Qos1MqttPublishOutMessageHandler(defaultMessageOutFactoryService), + new Qos2MqttPublishOutMessageHandler(defaultMessageOutFactoryService) ]) + @Shared + def defaultSubscriptionService = new InMemorySubscriptionService(defaultPublishDeliveringService) + @Shared def qos0MqttPublishInMessageHandler = new Qos0MqttPublishInMessageHandler( defaultSubscriptionService, diff --git a/core-service/src/test/groovy/javasabr/mqtt/service/impl/InMemorySubscriptionServiceTest.groovy b/core-service/src/test/groovy/javasabr/mqtt/service/impl/InMemorySubscriptionServiceTest.groovy index 6e97e1de..e473f024 100644 --- a/core-service/src/test/groovy/javasabr/mqtt/service/impl/InMemorySubscriptionServiceTest.groovy +++ b/core-service/src/test/groovy/javasabr/mqtt/service/impl/InMemorySubscriptionServiceTest.groovy @@ -12,8 +12,6 @@ import javasabr.rlib.collections.array.Array class InMemorySubscriptionServiceTest extends IntegrationServiceSpecification { - SubscriptionService subscriptionService = new InMemorySubscriptionService() - def "should subscribe with expected results in default settings"() { given: def serverConfig = defaultExternalServerConnectionConfig @@ -49,7 +47,7 @@ class InMemorySubscriptionServiceTest extends IntegrationServiceSpecification { true, true)) when: - def result = subscriptionService + def result = defaultSubscriptionService .subscribe(mqttUser, mqttUser.session(), subscriptions) then: result.size() == 4 @@ -104,7 +102,7 @@ class InMemorySubscriptionServiceTest extends IntegrationServiceSpecification { true, true)) when: - def result = subscriptionService + def result = defaultSubscriptionService .subscribe(mqttUser, mqttUser.session(), subscriptions) then: result.size() == 5 @@ -152,7 +150,7 @@ class InMemorySubscriptionServiceTest extends IntegrationServiceSpecification { true) def subscriptions = Array.of(sub1, sub2, sub3, sub4) when: - def result = subscriptionService + def result = defaultSubscriptionService .subscribe(mqttUser, mqttUser.session(), subscriptions) then: result.size() == 4 @@ -200,14 +198,14 @@ class InMemorySubscriptionServiceTest extends IntegrationServiceSpecification { SubscribeRetainHandling.SEND, true, true)) - subscriptionService.subscribe(mqttUser, mqttUser.session(), subscriptions) + defaultSubscriptionService.subscribe(mqttUser, mqttUser.session(), subscriptions) def topicsToUnsubscribe = Array.of( defaultTopicService.createTopicFilter(mqttUser, "topic/filter/1"), defaultTopicService.createTopicFilter(mqttUser, "topic/filter/3"), defaultTopicService.createTopicFilter(mqttUser, "topic/filter/notexist"), defaultTopicService.createTopicFilter(mqttUser, "topic/filter/invalid##")) when: - def result = subscriptionService + def result = defaultSubscriptionService .unsubscribe(mqttUser, mqttUser.session(), topicsToUnsubscribe) then: result.size() == 4 @@ -251,13 +249,13 @@ class InMemorySubscriptionServiceTest extends IntegrationServiceSpecification { defaultTopicService.createTopicFilter(mqttUser, "topic/filter/1"), defaultTopicService.createTopicFilter(mqttUser, "topic/filter/3")) when: - subscriptionService.subscribe(mqttUser, mqttUser.session(), subscriptions) + defaultSubscriptionService.subscribe(mqttUser, mqttUser.session(), subscriptions) def storedSubscriptions = activeSubscriptions.subscriptions() then: storedSubscriptions.size() == 3 storedSubscriptions == subscriptions when: - subscriptionService.unsubscribe(mqttUser, mqttUser.session(), topicsToUnsubscribe) + defaultSubscriptionService.unsubscribe(mqttUser, mqttUser.session(), topicsToUnsubscribe) storedSubscriptions = activeSubscriptions.subscriptions() then: storedSubscriptions.size() == 1 @@ -313,13 +311,13 @@ class InMemorySubscriptionServiceTest extends IntegrationServiceSpecification { subscriptions.get(1), subscriptions2.get(1)) when: - subscriptionService.subscribe(mqttUser, mqttUser.session(), subscriptions) + defaultSubscriptionService.subscribe(mqttUser, mqttUser.session(), subscriptions) def storedSubscriptions = activeSubscriptions.subscriptions() then: storedSubscriptions.size() == 3 storedSubscriptions == subscriptions when: - subscriptionService.subscribe(mqttUser, mqttUser.session(), subscriptions2) + defaultSubscriptionService.subscribe(mqttUser, mqttUser.session(), subscriptions2) storedSubscriptions = activeSubscriptions.subscriptions() then: storedSubscriptions.size() == 3 diff --git a/core-service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/UnsubscribeMqttInMessageHandlerTest.groovy b/core-service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/UnsubscribeMqttInMessageHandlerTest.groovy index 2ebca07c..09a94c0a 100644 --- a/core-service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/UnsubscribeMqttInMessageHandlerTest.groovy +++ b/core-service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/UnsubscribeMqttInMessageHandlerTest.groovy @@ -69,7 +69,7 @@ class UnsubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecificatio def "should response with expected results"() { given: def mqttConnection = mockedExternalConnection(MqttVersion.MQTT_5) - def subscriptionService = new InMemorySubscriptionService() + def subscriptionService = new InMemorySubscriptionService(defaultPublishDeliveringService) def messageHandler = new UnsubscribeMqttInMessageHandler( subscriptionService, defaultMessageOutFactoryService, diff --git a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java index 4a6579c2..fdcad600 100644 --- a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java +++ b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java @@ -39,7 +39,7 @@ class SubscriberNode extends SubscriberTreeBase { * @return the previous subscription from the same owner */ @Nullable - public SingleSubscriber subscribe(int level, MqttUser owner, Subscription subscription, TopicFilter topicFilter) { + protected SingleSubscriber subscribe(int level, MqttUser owner, Subscription subscription, TopicFilter topicFilter) { if (level == topicFilter.levelsCount()) { return addSubscriber(getOrCreateSubscribers(), owner, subscription, topicFilter); } @@ -47,7 +47,7 @@ public SingleSubscriber subscribe(int level, MqttUser owner, Subscription subscr return childNode.subscribe(level + 1, owner, subscription, topicFilter); } - public boolean unsubscribe(int level, MqttUser owner, TopicFilter topicFilter) { + protected boolean unsubscribe(int level, MqttUser owner, TopicFilter topicFilter) { if (level == topicFilter.levelsCount()) { return removeSubscriber(subscribers(), owner, topicFilter); } @@ -56,51 +56,28 @@ public boolean unsubscribe(int level, MqttUser owner, TopicFilter topicFilter) { } protected void matchesTo(int level, TopicName topicName, int lastLevel, MutableArray container) { - exactlyTopicMatch(level, topicName, lastLevel, container); - singleWildcardTopicMatch(level, topicName, lastLevel, container); - multiWildcardTopicMatch(container); + collectMatchingSubscribers(topicName.segment(level), level, topicName, lastLevel, container); + collectMatchingSubscribers(TopicFilter.SINGLE_LEVEL_WILDCARD, level, topicName, lastLevel, container); + collectMatchingSubscribers(TopicFilter.MULTI_LEVEL_WILDCARD, level, topicName, lastLevel, container); } - private void exactlyTopicMatch( + private void collectMatchingSubscribers( + String segment, int level, TopicName topicName, int lastLevel, MutableArray result) { - String segment = topicName.segment(level); - SubscriberNode subscriberNode = childNode(segment); + SubscriberNode subscriberNode = getChildNode(segment); if (subscriberNode == null) { return; } - if (level == lastLevel) { + if (level == lastLevel || TopicFilter.MULTI_LEVEL_WILDCARD.equals(segment)) { appendSubscribersTo(result, subscriberNode); } else if (level < lastLevel) { subscriberNode.matchesTo(level + 1, topicName, lastLevel, result); } } - private void singleWildcardTopicMatch( - int level, - TopicName topicName, - int lastLevel, - MutableArray result) { - SubscriberNode subscriberNode = childNode(TopicFilter.SINGLE_LEVEL_WILDCARD); - if (subscriberNode == null) { - return; - } - if (level == lastLevel) { - appendSubscribersTo(result, subscriberNode); - } else if (level < lastLevel) { - subscriberNode.matchesTo(level + 1, topicName, lastLevel, result); - } - } - - private void multiWildcardTopicMatch(MutableArray result) { - SubscriberNode subscriberNode = childNode(TopicFilter.MULTI_LEVEL_WILDCARD); - if (subscriberNode != null) { - appendSubscribersTo(result, subscriberNode); - } - } - private SubscriberNode getOrCreateChildNode(String segment) { LockableRefToRefDictionary childNodes = getOrCreateChildNodes(); long stamp = childNodes.readLock(); @@ -121,41 +98,47 @@ private SubscriberNode getOrCreateChildNode(String segment) { } @Nullable - private SubscriberNode childNode(String segment) { - LockableRefToRefDictionary childNodes = childNodes(); - if (childNodes == null) { + private SubscriberNode getChildNode(String segment) { + LockableRefToRefDictionary localChildNodes = childNodes; + if (localChildNodes == null) { return null; } - long stamp = childNodes.readLock(); + long stamp = localChildNodes.readLock(); try { - return childNodes.get(segment); + return localChildNodes.get(segment); } finally { - childNodes.readUnlock(stamp); + localChildNodes.readUnlock(stamp); } } private LockableRefToRefDictionary getOrCreateChildNodes() { - if (childNodes == null) { - synchronized (this) { - if (childNodes == null) { - childNodes = DictionaryFactory.stampedLockBasedRefToRefDictionary(); - } + LockableRefToRefDictionary localChildNodes = childNodes; + if (localChildNodes != null) { + return localChildNodes; + } + synchronized (this) { + localChildNodes = childNodes; + if (localChildNodes == null) { + localChildNodes = DictionaryFactory.stampedLockBasedRefToRefDictionary(); + childNodes = localChildNodes; } + return localChildNodes; } - //noinspection ConstantConditions - return childNodes; } private LockableArray getOrCreateSubscribers() { - if (subscribers == null) { - synchronized (this) { - if (subscribers == null) { - subscribers = ArrayFactory.stampedLockBasedArray(Subscriber.class); - } + LockableArray localSubscribers = subscribers; + if (localSubscribers != null) { + return localSubscribers; + } + synchronized (this) { + localSubscribers = subscribers; + if (localSubscribers == null) { + localSubscribers = ArrayFactory.stampedLockBasedArray(Subscriber.class); + subscribers = localSubscribers; } + return localSubscribers; } - //noinspection ConstantConditions - return subscribers; } @Override diff --git a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberTreeBase.java b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberTreeBase.java index 972b696b..9ae8c953 100644 --- a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberTreeBase.java +++ b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberTreeBase.java @@ -45,9 +45,7 @@ protected static SingleSubscriber addSubscriber( } @Nullable - private static SingleSubscriber removePreviousIfExist( - LockableArray subscribers, - MqttUser user) { + private static SingleSubscriber removePreviousIfExist(LockableArray subscribers, MqttUser user) { int index = subscribers.indexOf(Subscriber::resolveUser, user); if (index < 0) { return null; @@ -84,10 +82,7 @@ protected static void appendSubscribersTo(MutableArray result, long stamp = subscribers.readLock(); try { for (Subscriber subscriber : subscribers) { - SingleSubscriber singleSubscriber = subscriber.resolveSingle(); - if (removeDuplicateWithLowerQoS(result, singleSubscriber)) { - result.add(singleSubscriber); - } + addOrReplaceIfLowerQos(result, subscriber); } } finally { subscribers.readUnlock(stamp); @@ -141,23 +136,20 @@ private static boolean isSharedSubscriberWithGroup(Subscriber subscriber, String return subscriber instanceof SharedSubscriber shared && Objects.equals(group, shared.group()); } - private static boolean removeDuplicateWithLowerQoS( - MutableArray result, SingleSubscriber candidate) { - + private static void addOrReplaceIfLowerQos(MutableArray result, Subscriber subscriber) { + SingleSubscriber candidate = subscriber.resolveSingle(); int found = result.indexOf(SingleSubscriber::user, candidate.user()); if (found == -1) { - return true; + result.add(candidate); + return; } - QoS candidateQos = candidate.qos(); - SingleSubscriber exist = result.get(found); - QoS existeQos = exist.qos(); - - if (existeQos.ordinal() < candidateQos.ordinal()) { + QoS existedQos = result + .get(found) + .qos(); + if (existedQos.ordinal() < candidateQos.ordinal()) { result.remove(found); - return true; + result.add(candidate); } - - return false; } } diff --git a/model/src/main/java/javasabr/mqtt/model/topic/tree/ConcurrentRetainedMessageTree.java b/model/src/main/java/javasabr/mqtt/model/topic/tree/ConcurrentRetainedMessageTree.java new file mode 100644 index 00000000..6f38475f --- /dev/null +++ b/model/src/main/java/javasabr/mqtt/model/topic/tree/ConcurrentRetainedMessageTree.java @@ -0,0 +1,31 @@ +package javasabr.mqtt.model.topic.tree; + +import javasabr.mqtt.model.publishing.Publish; +import javasabr.mqtt.model.topic.TopicFilter; +import javasabr.rlib.collections.array.Array; +import javasabr.rlib.collections.array.MutableArray; +import javasabr.rlib.common.ThreadSafe; +import lombok.AccessLevel; +import lombok.experimental.FieldDefaults; + +@FieldDefaults(level = AccessLevel.PRIVATE, makeFinal = true) +public class ConcurrentRetainedMessageTree implements ThreadSafe { + + RetainedMessageNode rootNode; + + public ConcurrentRetainedMessageTree() { + this.rootNode = new RetainedMessageNode(); + } + + public void retainMessage(Publish message) { + if (message.retained()) { + rootNode.retainMessage(0, message, message.topicName()); + } + } + + public Array getRetainedMessage(TopicFilter topicFilter) { + var resultArray = MutableArray.ofType(Publish.class); + rootNode.collectRetainedMessages(0, topicFilter, resultArray); + return resultArray; + } +} diff --git a/model/src/main/java/javasabr/mqtt/model/topic/tree/RetainedMessageNode.java b/model/src/main/java/javasabr/mqtt/model/topic/tree/RetainedMessageNode.java new file mode 100644 index 00000000..d02fc9e4 --- /dev/null +++ b/model/src/main/java/javasabr/mqtt/model/topic/tree/RetainedMessageNode.java @@ -0,0 +1,156 @@ +package javasabr.mqtt.model.topic.tree; + +import java.util.LinkedList; +import java.util.Queue; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; +import javasabr.mqtt.base.util.DebugUtils; +import javasabr.mqtt.model.publishing.Publish; +import javasabr.mqtt.model.topic.TopicFilter; +import javasabr.mqtt.model.topic.TopicName; +import javasabr.rlib.collections.array.ArrayFactory; +import javasabr.rlib.collections.array.MutableArray; +import javasabr.rlib.collections.dictionary.DictionaryFactory; +import javasabr.rlib.collections.dictionary.LockableRefToRefDictionary; +import lombok.AccessLevel; +import lombok.Getter; +import lombok.experimental.Accessors; +import lombok.experimental.FieldDefaults; +import org.jspecify.annotations.Nullable; + +@Getter(AccessLevel.PACKAGE) +@Accessors(fluent = true, chain = false) +@FieldDefaults(level = AccessLevel.PRIVATE) +class RetainedMessageNode { + + private final static Supplier TOPIC_NODE_FACTORY = RetainedMessageNode::new; + + static { + DebugUtils.registerIncludedFields("childNodes", "retainedMessage"); + } + + @Nullable + volatile LockableRefToRefDictionary childNodes; + final AtomicReference<@Nullable Publish> retainedMessage = new AtomicReference<>(); + + public void retainMessage(int level, Publish message, TopicName topicName) { + var child = getOrCreateChildNode(topicName.segment(level)); + boolean isLastLevel = (level + 1 == topicName.levelsCount()); + if (isLastLevel) { + child.retainedMessage.set(message.payload().length == 0 ? null : message); + } else { + child.retainMessage(level + 1, message, topicName); + } + } + + public void collectRetainedMessages(int level, TopicFilter topicFilter, MutableArray result) { + if (level == topicFilter.levelsCount()) { + Publish publish = retainedMessage.get(); + if (publish != null) { + result.add(publish); + } + return; + } + String segment = topicFilter.segment(level); + boolean isOneCharSegment = segment.length() == 1; + if (isOneCharSegment && segment.charAt(0) == TopicFilter.MULTI_LEVEL_WILDCARD_CHAR) { + collectAllMessages(this, result); + return; + } + if (isOneCharSegment && segment.charAt(0) == TopicFilter.SINGLE_LEVEL_WILDCARD_CHAR) { + var localChildNodes = childNodes; + if (localChildNodes != null) { + var nextChildNodes = ArrayFactory.mutableArray(RetainedMessageNode.class); + long stamp = localChildNodes.readLock(); + try { + localChildNodes.values(nextChildNodes); + } finally { + localChildNodes.readUnlock(stamp); + } + for (RetainedMessageNode childNode : nextChildNodes) { + childNode.collectRetainedMessages(level + 1, topicFilter, result); + } + } + } else { + RetainedMessageNode retainedMessageNode = getChildNode(segment); + if (retainedMessageNode != null) { + retainedMessageNode.collectRetainedMessages(level + 1, topicFilter, result); + } + } + } + + private void collectAllMessages(RetainedMessageNode node, MutableArray result) { + Queue queue = new LinkedList<>(); + queue.add(node); + while (!queue.isEmpty()) { + RetainedMessageNode poll = queue.poll(); + Publish message = poll.retainedMessage.get(); + if (message != null) { + result.add(message); + } + var childNodes = poll.childNodes(); + if (childNodes == null) { + continue; + } + long stamp = childNodes.readLock(); + try { + childNodes.values(queue); + } finally { + childNodes.readUnlock(stamp); + } + } + } + + @Nullable + private RetainedMessageNode getChildNode(String segment) { + var childNodes = childNodes(); + if (childNodes == null) { + return null; + } + long stamp = childNodes.readLock(); + try { + return childNodes.get(segment); + } finally { + childNodes.readUnlock(stamp); + } + } + + private RetainedMessageNode getOrCreateChildNode(String segment) { + var childNodes = getOrCreateChildNodes(); + long stamp = childNodes.readLock(); + try { + RetainedMessageNode topicFilterNode = childNodes.get(segment); + if (topicFilterNode != null) { + return topicFilterNode; + } + } finally { + childNodes.readUnlock(stamp); + } + stamp = childNodes.writeLock(); + try { + return childNodes.getOrCompute(segment, TOPIC_NODE_FACTORY); + } finally { + childNodes.writeUnlock(stamp); + } + } + + private LockableRefToRefDictionary getOrCreateChildNodes() { + var current = childNodes; + if (current != null) { + return current; + } + synchronized (this) { + current = childNodes; + if (current == null) { + current = DictionaryFactory.stampedLockBasedRefToRefDictionary(); + childNodes = current; + } + return current; + } + } + + @Override + public String toString() { + return DebugUtils.toJsonString(this); + } +} diff --git a/model/src/main/java/javasabr/mqtt/model/topic/tree/package-info.java b/model/src/main/java/javasabr/mqtt/model/topic/tree/package-info.java new file mode 100644 index 00000000..1df48806 --- /dev/null +++ b/model/src/main/java/javasabr/mqtt/model/topic/tree/package-info.java @@ -0,0 +1,4 @@ +@NullMarked +package javasabr.mqtt.model.topic.tree; + +import org.jspecify.annotations.NullMarked; diff --git a/model/src/test/groovy/javasabr/mqtt/model/topic/tree/RetainedMessageTreeTest.groovy b/model/src/test/groovy/javasabr/mqtt/model/topic/tree/RetainedMessageTreeTest.groovy new file mode 100644 index 00000000..1baf0582 --- /dev/null +++ b/model/src/test/groovy/javasabr/mqtt/model/topic/tree/RetainedMessageTreeTest.groovy @@ -0,0 +1,122 @@ +package javasabr.mqtt.model.topic.tree + + +import javasabr.mqtt.model.PayloadFormat +import javasabr.mqtt.model.QoS +import javasabr.mqtt.model.publishing.Publish +import javasabr.mqtt.model.topic.TopicFilter +import javasabr.mqtt.model.topic.TopicName +import javasabr.mqtt.test.support.UnitSpecification +import javasabr.rlib.collections.array.Array +import javasabr.rlib.collections.array.IntArray + +import static java.nio.charset.StandardCharsets.UTF_8 + +class RetainedMessageTreeTest extends UnitSpecification { + + def "should fetch retained messages by topic filter"( + List messages, + String topicFilter, + List expectedMessages) { + given: + ConcurrentRetainedMessageTree retainedMessageTree = new ConcurrentRetainedMessageTree(); + messages.eachWithIndex { Publish message, int i -> + retainedMessageTree.retainMessage(message) + } + when: + def retainedMessages = retainedMessageTree.getRetainedMessage(TopicFilter.valueOf(topicFilter)) + .collect { it } + then: + retainedMessages.size() == expectedMessages.size() + for (int i = 0; i < retainedMessages.size(); i++) { + assert retainedMessages.get(i).topicName() == expectedMessages.get(i).topicName() + } + where: + topicFilter << [ + "/topic/segment1", + "/topic/segment2", + "/topic/segment3", + "/topic/+/segment2", + "/topic/#" + ] + messages << [ + [ + makePublish("/topic/segment1"), + makePublish("/topic/segment2"), + makePublish("/topic/segment1/segment2"), + makePublish("/topic/"), + makePublish("/topic") + ], + [ + makePublish("/topic/segment1"), + makePublish("/topic/segment2"), + makePublish("/topic/segment1/segment2"), + makePublish("/topic/"), + makePublish("/topic/segment2"), + makePublish("/"), + makePublish("/topic/segment2/segment1") + ], + [ + makePublish("/topic/segment1"), + makePublish("/topic/segment2"), + makePublish("/topic/segment3"), + makePublish("/topic/segment3"), + makePublish("/topic/segment3"), + makePublish("/topic/segment3") + ], + [ + makePublish("/topic/segment1"), + makePublish("/topic/segment2"), + makePublish("/topic/segment1/segment2"), + makePublish("/topic/segment500/segment2"), + makePublish("/topic/"), + makePublish("/topic") + ], + [ + makePublish("/topic1/segment1"), + makePublish("/topic/segment2"), + makePublish("/topic2/segment1/segment2"), + makePublish("/topic/segment3"), + makePublish("/topic/segment1/segment2") + ] + ] + expectedMessages << [ + [ + makePublish("/topic/segment1") + ], + [ + makePublish("/topic/segment2") + ], + [ + makePublish("/topic/segment3") + ], + [ + makePublish("/topic/segment1/segment2"), + makePublish("/topic/segment500/segment2") + ], + [ + makePublish("/topic/segment2"), + makePublish("/topic/segment3"), + makePublish("/topic/segment1/segment2") + ] + ] + } + + static def makePublish(String topicName) { + return new Publish( + 1, + QoS.AT_MOST_ONCE, + TopicName.valueOf(topicName), + null, + "payload".getBytes(UTF_8), + false, + true, + null, + IntArray.of(30), + null, + 60000, + 1, + PayloadFormat.UTF8_STRING, + Array.of()); + } +} diff --git a/network/src/test/groovy/javasabr/mqtt/network/message/out/ConnectAckMqtt5OutMessageTest.groovy b/network/src/test/groovy/javasabr/mqtt/network/message/out/ConnectAckMqtt5OutMessageTest.groovy index 8d94d698..789e0077 100644 --- a/network/src/test/groovy/javasabr/mqtt/network/message/out/ConnectAckMqtt5OutMessageTest.groovy +++ b/network/src/test/groovy/javasabr/mqtt/network/message/out/ConnectAckMqtt5OutMessageTest.groovy @@ -1,6 +1,5 @@ package javasabr.mqtt.network.message.out - import javasabr.mqtt.model.MqttVersion import javasabr.mqtt.model.QoS import javasabr.mqtt.model.message.MqttMessageType