Skip to content

Commit d0af0b3

Browse files
committed
[broker-12] add concurrent access support to subscription and authentication services
1 parent e442a76 commit d0af0b3

File tree

9 files changed

+86
-41
lines changed

9 files changed

+86
-41
lines changed

src/main/java/com/ss/mqtt/broker/model/SubscribeTopicFilter.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
public class SubscribeTopicFilter {
99

1010
/**
11-
* The subscriber's topic filter.
11+
* The subscriber's topic name.
1212
*/
13-
private final String topicFilter;
13+
private final String topicName;
1414

1515
/**
1616
* Maximum QoS field. This gives the maximum QoS level at which the Server

src/main/java/com/ss/mqtt/broker/service/SubscriptionService.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import com.ss.mqtt.broker.model.SubscribeAckReasonCode;
44
import com.ss.mqtt.broker.model.SubscribeTopicFilter;
5+
import com.ss.mqtt.broker.model.Subscriber;
56
import com.ss.mqtt.broker.model.UnsubscribeAckReasonCode;
67
import com.ss.mqtt.broker.network.client.MqttClient;
78
import com.ss.rlib.common.util.array.Array;
@@ -42,5 +43,5 @@ public interface SubscriptionService {
4243
* @param topicName topic name
4344
* @return array of topic subscribers
4445
*/
45-
@NotNull Array<MqttClient> getSubscribers(@NotNull String topicName);
46+
@NotNull Array<Subscriber> getSubscribers(@NotNull String topicName);
4647
}

src/main/java/com/ss/mqtt/broker/service/Subscriptions.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import com.ss.mqtt.broker.model.SubscribeAckReasonCode;
44
import com.ss.mqtt.broker.model.SubscribeTopicFilter;
5+
import com.ss.mqtt.broker.model.Subscriber;
56
import com.ss.mqtt.broker.model.UnsubscribeAckReasonCode;
67
import com.ss.mqtt.broker.network.client.MqttClient;
78
import com.ss.rlib.common.util.array.Array;
@@ -18,7 +19,7 @@ public interface Subscriptions {
1819
* @param topicName topic name on which subscribers should be returned
1920
* @return array of MQTT clients
2021
*/
21-
@NotNull Array<MqttClient> getSubscribers(@NotNull String topicName);
22+
@NotNull Array<Subscriber> getSubscribers(@NotNull String topicName);
2223

2324
/**
2425
* Returns result of subscription adding

src/main/java/com/ss/mqtt/broker/service/impl/AbstractCredentialSource.java

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,28 @@
11
package com.ss.mqtt.broker.service.impl;
22

33
import com.ss.mqtt.broker.service.CredentialSource;
4+
import com.ss.rlib.common.util.dictionary.ConcurrentObjectDictionary;
5+
import com.ss.rlib.common.util.dictionary.Dictionary;
6+
import com.ss.rlib.common.util.dictionary.DictionaryFactory;
7+
import com.ss.rlib.common.util.dictionary.ObjectDictionary;
48
import org.jetbrains.annotations.NotNull;
59
import reactor.core.publisher.Mono;
610

7-
import java.nio.charset.StandardCharsets;
811
import java.util.Arrays;
9-
import java.util.HashMap;
10-
import java.util.Map;
1112

1213
public abstract class AbstractCredentialSource implements CredentialSource {
1314

14-
private final Map<String, byte[]> credentials = new HashMap<>();
15+
private final ConcurrentObjectDictionary<String, byte[]> credentials =
16+
DictionaryFactory.newConcurrentStampedLockObjectDictionary();
1517

1618
abstract void init();
1719

18-
void putCredentials(@NotNull Object user, @NotNull Object pass) {
19-
credentials.put(user.toString(), pass.toString().getBytes(StandardCharsets.UTF_8));
20+
void putAll(@NotNull Dictionary<String, byte[]> creds) {
21+
credentials.runInWriteLock(creds, Dictionary::put);
22+
}
23+
24+
void put(@NotNull String user, @NotNull byte[] pass) {
25+
credentials.runInWriteLock(user, pass, ObjectDictionary::put);
2026
}
2127

2228
@Override

src/main/java/com/ss/mqtt/broker/service/impl/FileCredentialsSource.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
package com.ss.mqtt.broker.service.impl;
22

33
import com.ss.mqtt.broker.exception.CredentialsSourceException;
4+
import com.ss.rlib.common.util.dictionary.Dictionary;
5+
import com.ss.rlib.common.util.dictionary.DictionaryCollectors;
46
import org.jetbrains.annotations.NotNull;
57

68
import java.io.FileInputStream;
79
import java.io.IOException;
10+
import java.nio.charset.StandardCharsets;
811
import java.util.Properties;
912

1013
public class FileCredentialsSource extends AbstractCredentialSource {
@@ -25,7 +28,16 @@ void init() {
2528
try {
2629
var credentialsProperties = new Properties();
2730
credentialsProperties.load(new FileInputStream(credentialUrl.getPath()));
28-
credentialsProperties.forEach(this::putCredentials);
31+
32+
Dictionary<String, byte[]> creds = credentialsProperties.entrySet()
33+
.stream()
34+
.collect(DictionaryCollectors.toObjectDictionary(
35+
entry -> entry.getKey().toString(),
36+
entry -> entry.getValue().toString().getBytes(StandardCharsets.UTF_8)
37+
));
38+
39+
putAll(creds);
40+
2941
} catch (IOException e) {
3042
throw new CredentialsSourceException(e);
3143
}

src/main/java/com/ss/mqtt/broker/service/impl/SimplePublishingService.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package com.ss.mqtt.broker.service.impl;
22

33
import com.ss.mqtt.broker.model.PublishAckReasonCode;
4-
import com.ss.mqtt.broker.network.client.MqttClient;
4+
import com.ss.mqtt.broker.model.Subscriber;
55
import com.ss.mqtt.broker.network.packet.in.PublishInPacket;
66
import com.ss.mqtt.broker.service.PublishingService;
77
import com.ss.mqtt.broker.service.SubscriptionService;
@@ -17,10 +17,11 @@ public class SimplePublishingService implements PublishingService {
1717
private final @NotNull SubscriptionService subscriptionService;
1818

1919
private static @NotNull PublishAckReasonCode send(
20-
@NotNull MqttClient mqttClient,
20+
@NotNull Subscriber subscriber,
2121
@NotNull PublishInPacket publish
2222
) {
2323

24+
var mqttClient = subscriber.getMqttClient();
2425
mqttClient.send(mqttClient.getPacketOutFactory().newPublish(
2526
mqttClient,
2627
publish.getPacketId(),
@@ -50,7 +51,7 @@ public class SimplePublishingService implements PublishingService {
5051
}
5152

5253
var success = subscribers.stream()
53-
.map(targetMqttClient -> send(targetMqttClient, publish))
54+
.map(subscriber -> send(subscriber, publish))
5455
.allMatch(ackReasonCode -> ackReasonCode.equals(PublishAckReasonCode.SUCCESS));
5556

5657
return success ? PublishAckReasonCode.SUCCESS : PublishAckReasonCode.UNSPECIFIED_ERROR;

src/main/java/com/ss/mqtt/broker/service/impl/SimpleSubscriptionService.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import com.ss.mqtt.broker.model.SubscribeAckReasonCode;
44
import com.ss.mqtt.broker.model.SubscribeTopicFilter;
5+
import com.ss.mqtt.broker.model.Subscriber;
56
import com.ss.mqtt.broker.model.UnsubscribeAckReasonCode;
67
import com.ss.mqtt.broker.network.client.MqttClient;
78
import com.ss.mqtt.broker.service.SubscriptionService;
@@ -40,7 +41,7 @@ public class SimpleSubscriptionService implements SubscriptionService {
4041
}
4142

4243
@Override
43-
public @NotNull Array<MqttClient> getSubscribers(@NotNull String topicName) {
44+
public @NotNull Array<Subscriber> getSubscribers(@NotNull String topicName) {
4445
return subscriptions.getSubscribers(topicName);
4546
}
4647
}

src/main/java/com/ss/mqtt/broker/service/impl/SimpleSubscriptions.java

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,50 +6,73 @@
66
import com.ss.mqtt.broker.model.UnsubscribeAckReasonCode;
77
import com.ss.mqtt.broker.network.client.MqttClient;
88
import com.ss.mqtt.broker.service.Subscriptions;
9+
import com.ss.rlib.common.function.NotNullSupplier;
910
import com.ss.rlib.common.util.array.Array;
10-
import com.ss.rlib.common.util.array.ArrayCollectors;
11-
import com.ss.rlib.common.util.array.impl.FastArraySet;
11+
import com.ss.rlib.common.util.array.ConcurrentArray;
12+
import com.ss.rlib.common.util.dictionary.ConcurrentObjectDictionary;
13+
import com.ss.rlib.common.util.dictionary.DictionaryFactory;
14+
import com.ss.rlib.common.util.dictionary.ObjectDictionary;
1215
import org.jetbrains.annotations.NotNull;
1316

14-
import java.util.HashMap;
15-
import java.util.Map;
16-
1717
/**
1818
* Simple container of subscriptions
1919
*/
2020
public class SimpleSubscriptions implements Subscriptions {
2121

22-
private final @NotNull Map<String, Array<Subscriber>> subscriptions = new HashMap<>();
22+
private final static @NotNull NotNullSupplier<ConcurrentArray<Subscriber>> SUBSCRIBER_ARRAY_SUPPLIER =
23+
ConcurrentArray.supplier(Subscriber.class);
24+
25+
private final @NotNull ConcurrentObjectDictionary<String, ConcurrentArray<Subscriber>> subscriptions =
26+
DictionaryFactory.newConcurrentStampedLockObjectDictionary();
2327

24-
public @NotNull Array<MqttClient> getSubscribers(@NotNull String topicName) {
25-
return subscriptions.get(topicName)
26-
.stream()
27-
.map(Subscriber::getMqttClient)
28-
.collect(ArrayCollectors.toArray(MqttClient.class));
28+
public @NotNull Array<Subscriber> getSubscribers(@NotNull String topicName) {
29+
30+
var subscribers = subscriptions.getInReadLock(topicName, ObjectDictionary::get);
31+
if (subscribers == null) {
32+
return Array.empty();
33+
}
34+
35+
//noinspection ConstantConditions
36+
return subscribers.getInReadLock(Array::of);
2937
}
3038

3139
public @NotNull SubscribeAckReasonCode addSubscription(
3240
@NotNull SubscribeTopicFilter topicFilter,
3341
@NotNull MqttClient mqttClient
3442
) {
3543
var subscriber = new Subscriber(mqttClient, topicFilter);
36-
var subscribers = subscriptions.computeIfAbsent(
37-
topicFilter.getTopicFilter(),
38-
key -> new FastArraySet<>(Subscriber.class)
39-
);
40-
subscribers.add(subscriber);
44+
var subscribers = subscriptions.getInReadLock(topicFilter.getTopicName(), ObjectDictionary::get);
45+
46+
if (subscribers == null) {
47+
subscribers = subscriptions.getInWriteLock(
48+
topicFilter.getTopicName(),
49+
SUBSCRIBER_ARRAY_SUPPLIER,
50+
ObjectDictionary::getOrCompute
51+
);
52+
}
53+
54+
//noinspection ConstantConditions
55+
subscribers.runInWriteLock(subscriber, Array::add);
56+
4157
return topicFilter.getQos().getSubscribeAckReasonCode();
4258
}
4359

4460
public @NotNull UnsubscribeAckReasonCode removeSubscription(
4561
@NotNull String topicName,
4662
@NotNull MqttClient mqttClient
4763
) {
48-
var subscribers = subscriptions.getOrDefault(topicName, Array.empty());
49-
if (subscribers.removeIf(subscriber -> mqttClient.equals(subscriber.getMqttClient()))) {
50-
return UnsubscribeAckReasonCode.SUCCESS;
51-
} else {
64+
var subscribers = subscriptions.getInReadLock(topicName, ObjectDictionary::get);
65+
66+
if (subscribers == null) {
5267
return UnsubscribeAckReasonCode.NO_SUBSCRIPTION_EXISTED;
68+
} else {
69+
//noinspection ConstantConditions
70+
boolean removed = subscribers.getInWriteLock(
71+
mqttClient,
72+
(subs, client) -> subs.removeIf(subscriber -> client.equals(subscriber.getMqttClient()))
73+
);
74+
75+
return removed ? UnsubscribeAckReasonCode.SUCCESS : UnsubscribeAckReasonCode.NO_SUBSCRIPTION_EXISTED;
5376
}
5477
}
5578
}

src/test/groovy/com/ss/mqtt/broker/test/network/in/SubscribeInPacketTest.groovy

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ class SubscribeInPacketTest extends BaseInPacketTest {
2929
result
3030
packet.topicFilters.size() == 2
3131
packet.topicFilters.get(0).getQos() == QoS.AT_LEAST_ONCE_DELIVERY
32-
packet.topicFilters.get(0).getTopicFilter() == topicFilter
32+
packet.topicFilters.get(0).getTopicName() == topicFilter
3333
packet.topicFilters.get(0).isNoLocal()
3434
!packet.topicFilters.get(0).isRetainAsPublished()
3535
packet.topicFilters.get(0).getRetainHandling() == SubscribeRetainHandling.SEND_AT_THE_TIME_OF_SUBSCRIBE
3636
packet.topicFilters.get(1).getQos() == QoS.EXACTLY_ONCE_DELIVERY
37-
packet.topicFilters.get(1).getTopicFilter() == topicFilter2
37+
packet.topicFilters.get(1).getTopicName() == topicFilter2
3838
packet.topicFilters.get(1).isNoLocal()
3939
!packet.topicFilters.get(1).isRetainAsPublished()
4040
packet.topicFilters.get(1).getRetainHandling() == SubscribeRetainHandling.SEND_AT_THE_TIME_OF_SUBSCRIBE
@@ -69,12 +69,12 @@ class SubscribeInPacketTest extends BaseInPacketTest {
6969
result
7070
packet.topicFilters.size() == 2
7171
packet.topicFilters.get(0).getQos() == QoS.AT_LEAST_ONCE_DELIVERY
72-
packet.topicFilters.get(0).getTopicFilter() == topicFilter
72+
packet.topicFilters.get(0).getTopicName() == topicFilter
7373
!packet.topicFilters.get(0).isNoLocal()
7474
packet.topicFilters.get(0).isRetainAsPublished()
7575
packet.topicFilters.get(0).getRetainHandling() == SubscribeRetainHandling.SEND_AT_THE_TIME_OF_SUBSCRIBE
7676
packet.topicFilters.get(1).getQos() == QoS.EXACTLY_ONCE_DELIVERY
77-
packet.topicFilters.get(1).getTopicFilter() == topicFilter2
77+
packet.topicFilters.get(1).getTopicName() == topicFilter2
7878
packet.topicFilters.get(1).isNoLocal()
7979
!packet.topicFilters.get(1).isRetainAsPublished()
8080
packet.topicFilters.get(1).getRetainHandling() == SubscribeRetainHandling.SEND_AT_SUBSCRIBE_ONLY_IF_THE_SUBSCRIPTION_DOES_NOT_CURRENTLY_EXIST
@@ -98,12 +98,12 @@ class SubscribeInPacketTest extends BaseInPacketTest {
9898
result
9999
packet.topicFilters.size() == 2
100100
packet.topicFilters.get(0).getQos() == QoS.AT_LEAST_ONCE_DELIVERY
101-
packet.topicFilters.get(0).getTopicFilter() == topicFilter
101+
packet.topicFilters.get(0).getTopicName() == topicFilter
102102
!packet.topicFilters.get(0).isNoLocal()
103103
!packet.topicFilters.get(0).isRetainAsPublished()
104104
packet.topicFilters.get(0).getRetainHandling() == SubscribeRetainHandling.SEND_AT_THE_TIME_OF_SUBSCRIBE
105105
packet.topicFilters.get(1).getQos() == QoS.EXACTLY_ONCE_DELIVERY
106-
packet.topicFilters.get(1).getTopicFilter() == topicFilter2
106+
packet.topicFilters.get(1).getTopicName() == topicFilter2
107107
!packet.topicFilters.get(1).isNoLocal()
108108
!packet.topicFilters.get(1).isRetainAsPublished()
109109
packet.topicFilters.get(1).getRetainHandling() == SubscribeRetainHandling.SEND_AT_THE_TIME_OF_SUBSCRIBE

0 commit comments

Comments
 (0)