From e8e3e2344ed8f5ba3d506db039cdcdf72b595ffa Mon Sep 17 00:00:00 2001 From: Yunze Xu Date: Wed, 3 Dec 2025 20:35:03 +0800 Subject: [PATCH 01/12] Initial commit --- include/pulsar/EncryptionContext.h | 119 +++++++++++++++++++++++ include/pulsar/Message.h | 22 ++--- lib/Commands.cc | 14 +-- lib/Commands.h | 3 +- lib/ConsumerImpl.cc | 72 +++++++------- lib/ConsumerImpl.h | 14 ++- lib/EncryptionContext.cc | 48 +++++++++ lib/Message.cc | 68 ++----------- lib/MessageBatch.cc | 3 +- lib/MessageCrypto.cc | 35 +++---- lib/MessageCrypto.h | 14 +-- lib/MessageImpl.cc | 51 ++++++++++ lib/MessageImpl.h | 38 +++++--- tests/EncryptionTests.cc | 150 +++++++++++++++++++++++++++++ tests/PulsarFriend.h | 2 + 15 files changed, 498 insertions(+), 155 deletions(-) create mode 100644 include/pulsar/EncryptionContext.h create mode 100644 lib/EncryptionContext.cc create mode 100644 tests/EncryptionTests.cc diff --git a/include/pulsar/EncryptionContext.h b/include/pulsar/EncryptionContext.h new file mode 100644 index 00000000..3576410a --- /dev/null +++ b/include/pulsar/EncryptionContext.h @@ -0,0 +1,119 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#pragma once + +#include +#include +#include +#include + +#include "CompressionType.h" +#include "defines.h" + +namespace pulsar { + +namespace proto { +class MessageMetadata; +} + +class Message; + +struct PULSAR_PUBLIC EncryptionKey { + std::string key; + std::string value; + std::unordered_map metadata; + + explicit EncryptionKey() = default; + + // Support in-place construction + EncryptionKey(const std::string& key, const std::string& value, + const decltype(EncryptionKey::metadata)& metadata) + : key(key), value(value), metadata(metadata) {} +}; + +/** + * It contains encryption and compression information in it using which application can decrypt consumed + * message with encrypted-payload. + */ +class PULSAR_PUBLIC EncryptionContext { + public: + explicit EncryptionContext() + : compressionType_(CompressionNone), + uncompressedMessageSize_(0), + batchSize_(-1), + isDecryptionFailed_(false) {} + EncryptionContext(const EncryptionContext&) = default; + EncryptionContext(EncryptionContext&&) noexcept = default; + EncryptionContext(const proto::MessageMetadata& metadata, bool isDecryptionFailed); + + using KeysType = std::vector; + + /** + * @return the map of encryption keys used for the message + */ + const KeysType& keys() const noexcept { return keys_; } + + /** + * @return the encryption parameter used for the message + */ + const std::string& param() const noexcept { return param_; } + + /** + * @return the encryption algorithm used for the message + */ + const std::string& algorithm() const noexcept { return algorithm_; } + + /** + * @return the compression type used for the message + */ + CompressionType compressionType() const noexcept { return compressionType_; } + + /** + * @return the uncompressed message size if the message is compressed, 0 otherwise + */ + uint32_t uncompressedMessageSize() const noexcept { return uncompressedMessageSize_; } + + /** + * @return the batch size if the message is part of a batch, -1 otherwise + */ + int32_t batchSize() const noexcept { return batchSize_; } + + /** + * When the `ConsumerConfiguration#getCryptoFailureAction` is set to `CONSUME`, the message will still be + * returned even if the decryption failed, in this case, the message payload is still not decrypted but + * users have no way to know that. This method is provided to let users know whether the decryption + * failed. + * + * @return whether the decryption failed + */ + bool isDecryptionFailed() const noexcept { return isDecryptionFailed_; } + + private: + KeysType keys_; + std::string param_; + std::string algorithm_; + CompressionType compressionType_; + uint32_t uncompressedMessageSize_; + int32_t batchSize_; + bool isDecryptionFailed_; + + friend class ConsumerImpl; +}; + +} // namespace pulsar diff --git a/include/pulsar/Message.h b/include/pulsar/Message.h index ea4c4ab4..0a5ba7e0 100644 --- a/include/pulsar/Message.h +++ b/include/pulsar/Message.h @@ -19,22 +19,18 @@ #ifndef MESSAGE_HPP_ #define MESSAGE_HPP_ +#include #include #include #include +#include #include #include "KeyValue.h" #include "MessageId.h" namespace pulsar { -namespace proto { -class CommandMessage; -class BrokerEntryMetadata; -class MessageMetadata; -class SingleMessageMetadata; -} // namespace proto class SharedBuffer; class MessageBuilder; @@ -202,19 +198,19 @@ class PULSAR_PUBLIC Message { */ const std::string& getProducerName() const noexcept; + /** + * @return the optional encryption context that is present when the message is encrypted, the pointer is + * valid as the Message instance is alive + */ + std::optional getEncryptionContext() const; + bool operator==(const Message& msg) const; protected: typedef std::shared_ptr MessageImplPtr; MessageImplPtr impl_; - Message(MessageImplPtr& impl); - Message(const MessageId& messageId, proto::BrokerEntryMetadata& brokerEntryMetadata, - proto::MessageMetadata& metadata, SharedBuffer& payload); - /// Used for Batch Messages - Message(const MessageId& messageId, proto::BrokerEntryMetadata& brokerEntryMetadata, - proto::MessageMetadata& metadata, SharedBuffer& payload, - proto::SingleMessageMetadata& singleMetadata, const std::shared_ptr& topicName); + Message(const MessageImplPtr& impl); friend class PartitionedProducerImpl; friend class MultiTopicsConsumerImpl; friend class MessageBuilder; diff --git a/lib/Commands.cc b/lib/Commands.cc index 3c687c0a..f244db69 100644 --- a/lib/Commands.cc +++ b/lib/Commands.cc @@ -906,7 +906,8 @@ uint64_t Commands::serializeSingleMessagesToBatchPayload(SharedBuffer& batchPayl } Message Commands::deSerializeSingleMessageInBatch(Message& batchedMessage, int32_t batchIndex, - int32_t batchSize, const BatchMessageAckerPtr& acker) { + int32_t batchSize, const BatchMessageAckerPtr& acker, + const optional& encryptionContext) { SharedBuffer& uncompressedPayload = batchedMessage.impl_->payload; // Format of batch message @@ -926,12 +927,13 @@ Message Commands::deSerializeSingleMessageInBatch(Message& batchedMessage, int32 const MessageId& m = batchedMessage.impl_->messageId; auto messageId = MessageIdBuilder::from(m).batchIndex(batchIndex).batchSize(batchSize).build(); auto batchedMessageId = std::make_shared(*(messageId.impl_), acker); - Message singleMessage(MessageId{batchedMessageId}, batchedMessage.impl_->brokerEntryMetadata, - batchedMessage.impl_->metadata, payload, metadata, - batchedMessage.impl_->topicName_); - singleMessage.impl_->cnx_ = batchedMessage.impl_->cnx_; - return singleMessage; + auto msgImpl = std::make_shared(messageId, batchedMessage.impl_->brokerEntryMetadata, + batchedMessage.impl_->metadata, payload, metadata, + batchedMessage.impl_->topicName_, encryptionContext); + msgImpl->cnx_ = batchedMessage.impl_->cnx_; + + return Message(msgImpl); } MessageIdImplPtr Commands::getMessageIdImpl(const MessageId& messageId) { return messageId.impl_; } diff --git a/lib/Commands.h b/lib/Commands.h index 8403d6e2..be778456 100644 --- a/lib/Commands.h +++ b/lib/Commands.h @@ -155,7 +155,8 @@ class Commands { const std::vector& messages); static Message deSerializeSingleMessageInBatch(Message& batchedMessage, int32_t batchIndex, - int32_t batchSize, const BatchMessageAckerPtr& acker); + int32_t batchSize, const BatchMessageAckerPtr& acker, + const optional& encryptionContext); static MessageIdImplPtr getMessageIdImpl(const MessageId& messageId); diff --git a/lib/ConsumerImpl.cc b/lib/ConsumerImpl.cc index 4781e966..74d2ffc1 100644 --- a/lib/ConsumerImpl.cc +++ b/lib/ConsumerImpl.cc @@ -548,25 +548,27 @@ void ConsumerImpl::messageReceived(const ClientConnectionPtr& cnx, const proto:: bool& isChecksumValid, proto::BrokerEntryMetadata& brokerEntryMetadata, proto::MessageMetadata& metadata, SharedBuffer& payload) { LOG_DEBUG(getName() << "Received Message -- Size: " << payload.readableBytes()); - - if (!decryptMessageIfNeeded(cnx, msg, metadata, payload)) { - // Message was discarded or not consumed due to decryption failure - return; - } - if (!isChecksumValid) { // Message discarded for checksum error discardCorruptedMessage(cnx, msg.message_id(), CommandAck_ValidationError_ChecksumMismatch); return; } - auto redeliveryCount = msg.redelivery_count(); - const bool isMessageUndecryptable = - metadata.encryption_keys_size() > 0 && !config_.getCryptoKeyReader().get() && - config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::CONSUME; + auto encryptionContext = metadata.encryption_keys_size() > 0 + ? optional{std::in_place, metadata, false} + : std::nullopt; + auto decryptResult = decryptMessageIfNeeded(cnx, encryptionContext, payload, msg.message_id()); + if (decryptResult == FAILED) { + // Message was discarded due to decryption failure or not consumed due to decryption failure + return; + } else if (decryptResult == CONSUME_ENCRYPTED) { + encryptionContext->isDecryptionFailed_ = true; + } + + auto redeliveryCount = msg.redelivery_count(); const bool isChunkedMessage = metadata.num_chunks_from_msg() > 1; - if (!isMessageUndecryptable && !isChunkedMessage) { + if (decryptResult == DECRYPTED && !isChunkedMessage) { if (!uncompressMessageIfNeeded(cnx, msg.message_id(), metadata, payload, true)) { // Message was discarded on decompression error return; @@ -586,9 +588,9 @@ void ConsumerImpl::messageReceived(const ClientConnectionPtr& cnx, const proto:: } } - Message m(messageId, brokerEntryMetadata, metadata, payload); + Message m{std::make_shared(messageId, brokerEntryMetadata, metadata, payload, std::nullopt, + getTopicPtr(), std::move(encryptionContext))}; m.impl_->cnx_ = cnx.get(); - m.impl_->setTopicName(getTopicPtr()); m.impl_->setRedeliveryCount(msg.redelivery_count()); if (metadata.has_schema_version()) { @@ -610,14 +612,16 @@ void ConsumerImpl::messageReceived(const ClientConnectionPtr& cnx, const proto:: return; } - if (metadata.has_num_messages_in_batch()) { + // When the decryption failed, the whole batch message will be treated as a single message. + if (metadata.has_num_messages_in_batch() && decryptResult == DECRYPTED) { BitSet::Data words(msg.ack_set_size()); for (int i = 0; i < words.size(); i++) { words[i] = msg.ack_set(i); } BitSet ackSet{std::move(words)}; Lock lock(mutex_); - numOfMessageReceived = receiveIndividualMessagesFromBatch(cnx, m, ackSet, msg.redelivery_count()); + numOfMessageReceived = + receiveIndividualMessagesFromBatch(cnx, m, ackSet, msg.redelivery_count(), encryptionContext); } else { // try convert key value data. m.impl_->convertPayloadToKeyValue(config_.getSchema()); @@ -742,9 +746,9 @@ void ConsumerImpl::notifyPendingReceivedCallback(Result result, Message& msg, } // Zero Queue size is not supported with Batch Messages -uint32_t ConsumerImpl::receiveIndividualMessagesFromBatch(const ClientConnectionPtr& cnx, - Message& batchedMessage, const BitSet& ackSet, - int redeliveryCount) { +uint32_t ConsumerImpl::receiveIndividualMessagesFromBatch( + const ClientConnectionPtr& cnx, Message& batchedMessage, const BitSet& ackSet, int redeliveryCount, + const optional& encryptionContext) { auto batchSize = batchedMessage.impl_->metadata.num_messages_in_batch(); LOG_DEBUG("Received Batch messages of size - " << batchSize << " -- msgId: " << batchedMessage.getMessageId()); @@ -756,7 +760,8 @@ uint32_t ConsumerImpl::receiveIndividualMessagesFromBatch(const ClientConnection std::vector possibleToDeadLetter; for (int i = 0; i < batchSize; i++) { // This is a cheap copy since message contains only one shared pointer (impl_) - Message msg = Commands::deSerializeSingleMessageInBatch(batchedMessage, i, batchSize, acker); + Message msg = + Commands::deSerializeSingleMessageInBatch(batchedMessage, i, batchSize, acker, encryptionContext); msg.impl_->setRedeliveryCount(redeliveryCount); msg.impl_->setTopicName(batchedMessage.impl_->topicName_); msg.impl_->convertPayloadToKeyValue(config_.getSchema()); @@ -812,50 +817,51 @@ uint32_t ConsumerImpl::receiveIndividualMessagesFromBatch(const ClientConnection return batchSize - skippedMessages; } -bool ConsumerImpl::decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const proto::CommandMessage& msg, - const proto::MessageMetadata& metadata, SharedBuffer& payload) { - if (!metadata.encryption_keys_size()) { - return true; +auto ConsumerImpl::decryptMessageIfNeeded(const ClientConnectionPtr& cnx, + const optional& context, SharedBuffer& payload, + const proto::MessageIdData& msgId) -> DecryptResult { + if (!context.has_value()) { + return DECRYPTED; } // If KeyReader is not configured throw exception based on config param if (!config_.isEncryptionEnabled()) { if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::CONSUME) { LOG_WARN(getName() << "CryptoKeyReader is not implemented. Consuming encrypted message."); - return true; + return CONSUME_ENCRYPTED; } else if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::DISCARD) { LOG_WARN(getName() << "Skipping decryption since CryptoKeyReader is not implemented and config " "is set to discard"); - discardCorruptedMessage(cnx, msg.message_id(), CommandAck_ValidationError_DecryptionError); + discardCorruptedMessage(cnx, msgId, CommandAck_ValidationError_DecryptionError); } else { LOG_ERROR(getName() << "Message delivery failed since CryptoKeyReader is not implemented to " "consume encrypted message"); - auto messageId = MessageIdBuilder::from(msg.message_id()).build(); + auto messageId = MessageIdBuilder::from(msgId).build(); unAckedMessageTrackerPtr_->add(messageId); } - return false; + return FAILED; } SharedBuffer decryptedPayload; - if (msgCrypto_->decrypt(metadata, payload, config_.getCryptoKeyReader(), decryptedPayload)) { + if (msgCrypto_->decrypt(*context, payload, config_.getCryptoKeyReader(), decryptedPayload)) { payload = decryptedPayload; - return true; + return DECRYPTED; } if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::CONSUME) { // Note, batch message will fail to consume even if config is set to consume LOG_WARN( getName() << "Decryption failed. Consuming encrypted message since config is set to consume."); - return true; + return CONSUME_ENCRYPTED; } else if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::DISCARD) { LOG_WARN(getName() << "Discarding message since decryption failed and config is set to discard"); - discardCorruptedMessage(cnx, msg.message_id(), CommandAck_ValidationError_DecryptionError); + discardCorruptedMessage(cnx, msgId, CommandAck_ValidationError_DecryptionError); } else { LOG_ERROR(getName() << "Message delivery failed since unable to decrypt incoming message"); - auto messageId = MessageIdBuilder::from(msg.message_id()).build(); + auto messageId = MessageIdBuilder::from(msgId).build(); unAckedMessageTrackerPtr_->add(messageId); } - return false; + return FAILED; } bool ConsumerImpl::uncompressMessageIfNeeded(const ClientConnectionPtr& cnx, diff --git a/lib/ConsumerImpl.h b/lib/ConsumerImpl.h index c1df0804..02f9cb1b 100644 --- a/lib/ConsumerImpl.h +++ b/lib/ConsumerImpl.h @@ -64,6 +64,7 @@ using UnAckedMessageTrackerPtr = std::shared_ptr namespace proto { class CommandMessage; class BrokerEntryMetadata; +class MessageIdData; class MessageMetadata; } // namespace proto @@ -190,13 +191,20 @@ class ConsumerImpl : public ConsumerImplBase { void increaseAvailablePermits(const Message& msg); void drainIncomingMessageQueue(size_t count); uint32_t receiveIndividualMessagesFromBatch(const ClientConnectionPtr& cnx, Message& batchedMessage, - const BitSet& ackSet, int redeliveryCount); + const BitSet& ackSet, int redeliveryCount, + const optional& encryptionContext); bool isPriorBatchIndex(int32_t idx); bool isPriorEntryIndex(int64_t idx); void brokerConsumerStatsListener(Result, BrokerConsumerStatsImpl, const BrokerConsumerStatsCallback&); - bool decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const proto::CommandMessage& msg, - const proto::MessageMetadata& metadata, SharedBuffer& payload); + enum DecryptResult + { + DECRYPTED, + CONSUME_ENCRYPTED, + FAILED + }; + DecryptResult decryptMessageIfNeeded(const ClientConnectionPtr&, const optional&, + SharedBuffer& payload, const proto::MessageIdData&); // TODO - Convert these functions to lambda when we move to C++11 Result receiveHelper(Message& msg); diff --git a/lib/EncryptionContext.cc b/lib/EncryptionContext.cc new file mode 100644 index 00000000..5376f062 --- /dev/null +++ b/lib/EncryptionContext.cc @@ -0,0 +1,48 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "PulsarApi.pb.h" + +namespace pulsar { + +static EncryptionContext::KeysType encryptedKeysFromMetadata(const proto::MessageMetadata& msgMetadata) { + EncryptionContext::KeysType keys; + for (auto&& key : msgMetadata.encryption_keys()) { + decltype(EncryptionKey::metadata) metadata; + for (int i = 0; i < key.metadata_size(); i++) { + const auto& entry = key.metadata(i); + metadata[entry.key()] = entry.value(); + } + keys.emplace_back(key.key(), key.value(), std::move(metadata)); + } + return keys; +} + +EncryptionContext::EncryptionContext(const proto::MessageMetadata& msgMetadata, bool isDecryptionFailed) + + : keys_(encryptedKeysFromMetadata(msgMetadata)), + param_(msgMetadata.encryption_param()), + algorithm_(msgMetadata.encryption_algo()), + compressionType_(static_cast(msgMetadata.compression())), + uncompressedMessageSize_(msgMetadata.uncompressed_size()), + batchSize_(msgMetadata.has_num_messages_in_batch() ? msgMetadata.num_messages_in_batch() : -1), + isDecryptionFailed_(isDecryptionFailed) {} + +} // namespace pulsar diff --git a/lib/Message.cc b/lib/Message.cc index 1e26b521..5faf9c35 100644 --- a/lib/Message.cc +++ b/lib/Message.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include #include @@ -24,10 +25,7 @@ #include #include "Int64SerDes.h" -#include "KeyValueImpl.h" #include "MessageImpl.h" -#include "PulsarApi.pb.h" -#include "SharedBuffer.h" using namespace pulsar; @@ -68,62 +66,7 @@ std::string Message::getDataAsString() const { return std::string((const char*)g Message::Message() : impl_() {} -Message::Message(MessageImplPtr& impl) : impl_(impl) {} - -Message::Message(const MessageId& messageId, proto::BrokerEntryMetadata& brokerEntryMetadata, - proto::MessageMetadata& metadata, SharedBuffer& payload) - : impl_(std::make_shared()) { - impl_->messageId = messageId; - impl_->brokerEntryMetadata = brokerEntryMetadata; - impl_->metadata = metadata; - impl_->payload = payload; -} - -Message::Message(const MessageId& messageID, proto::BrokerEntryMetadata& brokerEntryMetadata, - proto::MessageMetadata& metadata, SharedBuffer& payload, - proto::SingleMessageMetadata& singleMetadata, const std::shared_ptr& topicName) - : impl_(std::make_shared()) { - impl_->messageId = messageID; - impl_->brokerEntryMetadata = brokerEntryMetadata; - impl_->metadata = metadata; - impl_->payload = payload; - impl_->metadata.mutable_properties()->CopyFrom(singleMetadata.properties()); - impl_->topicName_ = topicName; - - impl_->metadata.clear_properties(); - if (singleMetadata.properties_size() > 0) { - impl_->metadata.mutable_properties()->Reserve(singleMetadata.properties_size()); - for (int i = 0; i < singleMetadata.properties_size(); i++) { - auto keyValue = proto::KeyValue().New(); - *keyValue = singleMetadata.properties(i); - impl_->metadata.mutable_properties()->AddAllocated(keyValue); - } - } - - if (singleMetadata.has_partition_key()) { - impl_->metadata.set_partition_key(singleMetadata.partition_key()); - } else { - impl_->metadata.clear_partition_key(); - } - - if (singleMetadata.has_ordering_key()) { - impl_->metadata.set_ordering_key(singleMetadata.ordering_key()); - } else { - impl_->metadata.clear_ordering_key(); - } - - if (singleMetadata.has_event_time()) { - impl_->metadata.set_event_time(singleMetadata.event_time()); - } else { - impl_->metadata.clear_event_time(); - } - - if (singleMetadata.has_sequence_id()) { - impl_->metadata.set_sequence_id(singleMetadata.sequence_id()); - } else { - impl_->metadata.clear_sequence_id(); - } -} +Message::Message(const MessageImplPtr& impl) : impl_(impl) {} const MessageId& Message::getMessageId() const { if (!impl_) { @@ -220,6 +163,13 @@ const std::string& Message::getProducerName() const noexcept { return impl_->metadata.producer_name(); } +std::optional Message::getEncryptionContext() const { + if (!impl_ || !impl_->encryptionContext_.has_value()) { + return std::nullopt; + } + return {&(*impl_->encryptionContext_)}; +} + bool Message::operator==(const Message& msg) const { return getMessageId() == msg.getMessageId(); } KeyValue Message::getKeyValueData() const { return KeyValue(impl_->keyValuePtr); } diff --git a/lib/MessageBatch.cc b/lib/MessageBatch.cc index f2c1cb29..4678e766 100644 --- a/lib/MessageBatch.cc +++ b/lib/MessageBatch.cc @@ -49,7 +49,8 @@ MessageBatch& MessageBatch::parseFrom(const SharedBuffer& payload, uint32_t batc auto acker = BatchMessageAckerImpl::create(batchSize); for (int i = 0; i < batchSize; ++i) { - batch_.push_back(Commands::deSerializeSingleMessageInBatch(batchMessage_, i, batchSize, acker)); + batch_.push_back( + Commands::deSerializeSingleMessageInBatch(batchMessage_, i, batchSize, acker, std::nullopt)); } return *this; } diff --git a/lib/MessageCrypto.cc b/lib/MessageCrypto.cc index b06ff652..daa492ea 100644 --- a/lib/MessageCrypto.cc +++ b/lib/MessageCrypto.cc @@ -394,13 +394,13 @@ bool MessageCrypto::encrypt(const std::set& encKeys, const CryptoKe return true; } -bool MessageCrypto::decryptDataKey(const proto::EncryptionKeys& encKeys, const CryptoKeyReader& keyReader) { - const auto& keyName = encKeys.key(); - const auto& encryptedDataKey = encKeys.value(); - const auto& encKeyMeta = encKeys.metadata(); +bool MessageCrypto::decryptDataKey(const EncryptionKey& encKeys, const CryptoKeyReader& keyReader) { + const auto& keyName = encKeys.key; + const auto& encryptedDataKey = encKeys.value; + const auto& encKeyMeta = encKeys.metadata; StringMap keyMeta; for (auto iter = encKeyMeta.begin(); iter != encKeyMeta.end(); iter++) { - keyMeta[iter->key()] = iter->value(); + keyMeta[iter->first] = iter->second; } // Read the private key info using callback @@ -451,11 +451,10 @@ bool MessageCrypto::decryptDataKey(const proto::EncryptionKeys& encKeys, const C return true; } -bool MessageCrypto::decryptData(const std::string& dataKeySecret, const proto::MessageMetadata& msgMetadata, +bool MessageCrypto::decryptData(const std::string& dataKeySecret, const EncryptionContext& context, SharedBuffer& payload, SharedBuffer& decryptedPayload) { // unpack iv and encrypted data - msgMetadata.encryption_param().copy(reinterpret_cast(iv_.get()), - msgMetadata.encryption_param().size()); + context.param().copy(reinterpret_cast(iv_.get()), context.param().size()); EVP_CIPHER_CTX* cipherCtx = NULL; decryptedPayload = SharedBuffer::allocate(payload.readableBytes() + EVP_MAX_BLOCK_LENGTH + tagLen_); @@ -518,15 +517,14 @@ bool MessageCrypto::decryptData(const std::string& dataKeySecret, const proto::M return true; } -bool MessageCrypto::getKeyAndDecryptData(const proto::MessageMetadata& msgMetadata, SharedBuffer& payload, +bool MessageCrypto::getKeyAndDecryptData(const EncryptionContext& context, SharedBuffer& payload, SharedBuffer& decryptedPayload) { SharedBuffer decryptedData; bool dataDecrypted = false; - for (auto iter = msgMetadata.encryption_keys().begin(); iter != msgMetadata.encryption_keys().end(); - iter++) { - const std::string& keyName = iter->key(); - const std::string& encDataKey = iter->value(); + for (auto&& kv : context.keys()) { + const std::string& keyName = kv.key; + const std::string& encDataKey = kv.value; unsigned char keyDigest[EVP_MAX_MD_SIZE]; unsigned int digestLen = 0; getDigest(keyName, encDataKey.c_str(), encDataKey.size(), keyDigest, digestLen); @@ -539,7 +537,7 @@ bool MessageCrypto::getKeyAndDecryptData(const proto::MessageMetadata& msgMetada // retruns a different key, decryption fails. At this point, we would // call decryptDataKey to refresh the cache and come here again to decrypt. auto dataKeyEntry = dataKeyCacheIter->second; - if (decryptData(dataKeyEntry.first, msgMetadata, payload, decryptedPayload)) { + if (decryptData(dataKeyEntry.first, context, payload, decryptedPayload)) { dataDecrypted = true; break; } @@ -552,17 +550,16 @@ bool MessageCrypto::getKeyAndDecryptData(const proto::MessageMetadata& msgMetada return dataDecrypted; } -bool MessageCrypto::decrypt(const proto::MessageMetadata& msgMetadata, SharedBuffer& payload, +bool MessageCrypto::decrypt(const EncryptionContext& context, SharedBuffer& payload, const CryptoKeyReaderPtr& keyReader, SharedBuffer& decryptedPayload) { // Attempt to decrypt using the existing key - if (getKeyAndDecryptData(msgMetadata, payload, decryptedPayload)) { + if (getKeyAndDecryptData(context, payload, decryptedPayload)) { return true; } // Either first time, or decryption failed. Attempt to regenerate data key bool isDataKeyDecrypted = false; - for (int index = 0; index < msgMetadata.encryption_keys_size(); index++) { - const proto::EncryptionKeys& encKeys = msgMetadata.encryption_keys(index); + for (auto&& encKeys : context.keys()) { if (decryptDataKey(encKeys, *keyReader)) { isDataKeyDecrypted = true; break; @@ -574,7 +571,7 @@ bool MessageCrypto::decrypt(const proto::MessageMetadata& msgMetadata, SharedBuf return false; } - return getKeyAndDecryptData(msgMetadata, payload, decryptedPayload); + return getKeyAndDecryptData(context, payload, decryptedPayload); } } /* namespace pulsar */ diff --git a/lib/MessageCrypto.h b/lib/MessageCrypto.h index cd07bf55..84075ce9 100644 --- a/lib/MessageCrypto.h +++ b/lib/MessageCrypto.h @@ -26,10 +26,10 @@ #include #include #include +#include #include #include -#include #include #include #include @@ -90,15 +90,15 @@ class MessageCrypto { /* * Decrypt the payload using the data key. Keys used to encrypt data key can be retrieved from msgMetadata * - * @param msgMetadata Message Metadata + * @param context the context of encryption * @param payload Message which needs to be decrypted * @param keyReader KeyReader implementation to retrieve key value * @param decryptedPayload Contains decrypted payload if success * * @return true if success */ - bool decrypt(const proto::MessageMetadata& msgMetadata, SharedBuffer& payload, - const CryptoKeyReaderPtr& keyReader, SharedBuffer& decryptedPayload); + bool decrypt(const EncryptionContext& context, SharedBuffer& payload, const CryptoKeyReaderPtr& keyReader, + SharedBuffer& decryptedPayload); private: typedef std::unique_lock Lock; @@ -137,10 +137,10 @@ class MessageCrypto { Result addPublicKeyCipher(const std::string& keyName, const CryptoKeyReaderPtr& keyReader); - bool decryptDataKey(const proto::EncryptionKeys& encKeys, const CryptoKeyReader& keyReader); - bool decryptData(const std::string& dataKeySecret, const proto::MessageMetadata& msgMetadata, + bool decryptDataKey(const EncryptionKey& encKeys, const CryptoKeyReader& keyReader); + bool decryptData(const std::string& dataKeySecret, const EncryptionContext& context, SharedBuffer& payload, SharedBuffer& decPayload); - bool getKeyAndDecryptData(const proto::MessageMetadata& msgMetadata, SharedBuffer& payload, + bool getKeyAndDecryptData(const EncryptionContext& context, SharedBuffer& payload, SharedBuffer& decryptedPayload); std::string stringToHex(const std::string& inputStr, size_t len); std::string stringToHex(const char* inputStr, size_t len); diff --git a/lib/MessageImpl.cc b/lib/MessageImpl.cc index 17239a82..da3b8902 100644 --- a/lib/MessageImpl.cc +++ b/lib/MessageImpl.cc @@ -18,8 +18,59 @@ */ #include "MessageImpl.h" +#include + +#include "PulsarApi.pb.h" + namespace pulsar { +MessageImpl::MessageImpl(const MessageId& messageId, const proto::BrokerEntryMetadata& brokerEntryMetadata, + const proto::MessageMetadata& metadata, const SharedBuffer& payload, + const optional& singleMetadata, + const std::shared_ptr& topicName, + optional encryptionContext) + : messageId(messageId), + brokerEntryMetadata(brokerEntryMetadata), + metadata(metadata), + payload(payload), + topicName_(topicName), + encryptionContext_(std::move(encryptionContext)) { + if (singleMetadata.has_value()) { + this->metadata.clear_properties(); + if (singleMetadata->properties_size() > 0) { + this->metadata.mutable_properties()->Reserve(singleMetadata->properties_size()); + for (int i = 0; i < singleMetadata->properties_size(); i++) { + auto keyValue = proto::KeyValue().New(); + *keyValue = singleMetadata->properties(i); + this->metadata.mutable_properties()->AddAllocated(keyValue); + } + } + if (singleMetadata->has_partition_key()) { + this->metadata.set_partition_key(singleMetadata->partition_key()); + } else { + this->metadata.clear_partition_key(); + } + + if (singleMetadata->has_ordering_key()) { + this->metadata.set_ordering_key(singleMetadata->ordering_key()); + } else { + this->metadata.clear_ordering_key(); + } + + if (singleMetadata->has_event_time()) { + this->metadata.set_event_time(singleMetadata->event_time()); + } else { + this->metadata.clear_event_time(); + } + + if (singleMetadata->has_sequence_id()) { + this->metadata.set_sequence_id(singleMetadata->sequence_id()); + } else { + this->metadata.clear_sequence_id(); + } + } +} + const Message::StringMap& MessageImpl::properties() { if (properties_.size() == 0) { for (int i = 0; i < metadata.properties_size(); i++) { diff --git a/lib/MessageImpl.h b/lib/MessageImpl.h index 6467b359..8c8eb016 100644 --- a/lib/MessageImpl.h +++ b/lib/MessageImpl.h @@ -19,14 +19,19 @@ #ifndef LIB_MESSAGEIMPL_H_ #define LIB_MESSAGEIMPL_H_ +#include #include #include +#include +#include + #include "KeyValueImpl.h" #include "PulsarApi.pb.h" #include "SharedBuffer.h" -using namespace pulsar; +using std::optional; + namespace pulsar { class PulsarWrapper; @@ -35,19 +40,13 @@ class BatchMessageContainer; class MessageImpl { public: - const Message::StringMap& properties(); + explicit MessageImpl() : encryptionContext_(std::nullopt) {} + MessageImpl(const MessageId& messageId, const proto::BrokerEntryMetadata& brokerEntryMetadata, + const proto::MessageMetadata& metadata, const SharedBuffer& payload, + const optional& singleMetadata, + const std::shared_ptr& topicName, optional encryptionContext); - proto::BrokerEntryMetadata brokerEntryMetadata; - proto::MessageMetadata metadata; - SharedBuffer payload; - std::shared_ptr keyValuePtr; - MessageId messageId; - ClientConnection* cnx_; - std::shared_ptr topicName_; - int redeliveryCount_; - bool hasSchemaVersion_; - const std::string* schemaVersion_; - std::weak_ptr consumerPtr_; + const Message::StringMap& properties(); const std::string& getPartitionKey() const; bool hasPartitionKey() const; @@ -81,6 +80,19 @@ class MessageImpl { friend class PulsarWrapper; friend class MessageBuilder; + MessageId messageId; + proto::BrokerEntryMetadata brokerEntryMetadata; + proto::MessageMetadata metadata; + SharedBuffer payload; + std::shared_ptr keyValuePtr; + ClientConnection* cnx_; + std::shared_ptr topicName_; + int redeliveryCount_; + bool hasSchemaVersion_; + const std::string* schemaVersion_; + std::weak_ptr consumerPtr_; + const optional encryptionContext_; + private: void setReplicationClusters(const std::vector& clusters); void setProperty(const std::string& name, const std::string& value); diff --git a/tests/EncryptionTests.cc b/tests/EncryptionTests.cc new file mode 100644 index 00000000..bbbf31ee --- /dev/null +++ b/tests/EncryptionTests.cc @@ -0,0 +1,150 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include + +#include + +#include "PulsarApi.pb.h" +#include "lib/CompressionCodec.h" +#include "lib/MessageCrypto.h" +#include "lib/SharedBuffer.h" +#include "tests/PulsarFriend.h" + +static std::string lookupUrl = "pulsar://localhost:6650"; + +using namespace pulsar; + +static CryptoKeyReaderPtr getDefaultCryptoKeyReader() { + return std::make_shared(TEST_CONF_DIR "/public-key.client-rsa.pem", + TEST_CONF_DIR "/private-key.client-rsa.pem"); +} + +static std::vector decryptValue(const Message& message) { + if (!message.getEncryptionContext().has_value()) { + return {message.getDataAsString()}; + } + auto context = message.getEncryptionContext().value(); + if (!context->isDecryptionFailed()) { + return {message.getDataAsString()}; + } + + MessageCrypto crypto{"test", false}; + auto msgImpl = PulsarFriend::getMessageImplPtr(message); + SharedBuffer decryptedPayload; + auto originalPayload = + SharedBuffer::copy(static_cast(message.getData()), message.getLength()); + if (!crypto.decrypt(*context, originalPayload, getDefaultCryptoKeyReader(), decryptedPayload)) { + throw std::runtime_error("Decryption failed"); + } + + SharedBuffer uncompressedPayload; + if (!CompressionCodecProvider::getCodec(context->compressionType()) + .decode(decryptedPayload, context->uncompressedMessageSize(), uncompressedPayload)) { + throw std::runtime_error("Decompression failed"); + } + + std::vector values; + if (auto batchSize = message.getEncryptionContext().value()->batchSize(); batchSize > 0) { + for (decltype(batchSize) i = 0; i < batchSize; i++) { + auto singleMetaSize = uncompressedPayload.readUnsignedInt(); + proto::SingleMessageMetadata singleMeta; + singleMeta.ParseFromArray(uncompressedPayload.data(), singleMetaSize); + uncompressedPayload.consume(singleMetaSize); + + auto payload = uncompressedPayload.slice(0, singleMeta.payload_size()); + uncompressedPayload.consume(payload.readableBytes()); + values.emplace_back(payload.data(), payload.readableBytes()); + } + } else { + // non-batched message + values.emplace_back(uncompressedPayload.data(), uncompressedPayload.readableBytes()); + } + return values; +} + +static void testDecryption(Client& client, const std::string& topic, bool decryptionSucceed, + int numMessageReceived) { + ProducerConfiguration producerConf; + producerConf.setCompressionType(CompressionLZ4); + producerConf.addEncryptionKey("client-rsa.pem"); + producerConf.setCryptoKeyReader(getDefaultCryptoKeyReader()); + + Producer producer; + ASSERT_EQ(ResultOk, client.createProducer(topic, producerConf, producer)); + + std::vector sentValues; + auto send = [&producer, &sentValues](const std::string& value) { + Message msg = MessageBuilder().setContent(value).build(); + producer.sendAsync(msg, nullptr); + sentValues.emplace_back(value); + }; + + for (int i = 0; i < 5; i++) { + send("msg-" + std::to_string(i)); + } + producer.flush(); + send("last-msg"); + producer.flush(); + + ASSERT_EQ(ResultOk, client.createProducer(topic, producer)); + send("unencrypted-msg"); + producer.flush(); + producer.close(); + + ConsumerConfiguration consumerConf; + consumerConf.setSubscriptionInitialPosition(InitialPositionEarliest); + if (decryptionSucceed) { + consumerConf.setCryptoKeyReader(getDefaultCryptoKeyReader()); + } else { + consumerConf.setCryptoFailureAction(ConsumerCryptoFailureAction::CONSUME); + } + Consumer consumer; + ASSERT_EQ(ResultOk, client.subscribe(topic, "sub", consumerConf, consumer)); + + std::vector values; + for (int i = 0; i < numMessageReceived; i++) { + Message msg; + ASSERT_EQ(ResultOk, consumer.receive(msg, 3000)); + if (i < numMessageReceived - 1) { + ASSERT_TRUE(msg.getEncryptionContext().has_value()); + } + for (auto&& value : decryptValue(msg)) { + values.emplace_back(value); + } + } + ASSERT_EQ(values, sentValues); + consumer.close(); +} + +TEST(EncryptionTests, testDecryptionSuccess) { + Client client{lookupUrl}; + std::string topic = "test-decryption-success-" + std::to_string(time(nullptr)); + testDecryption(client, topic, true, 7); + client.close(); +} + +TEST(EncryptionTests, testDecryptionFailure) { + Client client{lookupUrl}; + std::string topic = "test-decryption-failure-" + std::to_string(time(nullptr)); + // The 1st batch that has 5 messages cannot be decrypted, so they can be received only once + testDecryption(client, topic, false, 3); + client.close(); +} diff --git a/tests/PulsarFriend.h b/tests/PulsarFriend.h index e7084050..780ec2a9 100644 --- a/tests/PulsarFriend.h +++ b/tests/PulsarFriend.h @@ -217,6 +217,8 @@ class PulsarFriend { return waitUntil(std::chrono::seconds(3), [producerImpl] { return !producerImpl->getCnx().expired(); }); } + + static auto getMessageImplPtr(const Message& message) { return message.impl_; } }; } // namespace pulsar From 27e0092bcc3f006a407d0cf949fb87b9e6ca6bf3 Mon Sep 17 00:00:00 2001 From: Yunze Xu Date: Fri, 5 Dec 2025 10:18:10 +0800 Subject: [PATCH 02/12] Revert "Initial commit" This reverts commit e8e3e2344ed8f5ba3d506db039cdcdf72b595ffa. --- include/pulsar/EncryptionContext.h | 119 ----------------------- include/pulsar/Message.h | 22 +++-- lib/Commands.cc | 14 ++- lib/Commands.h | 3 +- lib/ConsumerImpl.cc | 72 +++++++------- lib/ConsumerImpl.h | 14 +-- lib/EncryptionContext.cc | 48 --------- lib/Message.cc | 68 +++++++++++-- lib/MessageBatch.cc | 3 +- lib/MessageCrypto.cc | 35 ++++--- lib/MessageCrypto.h | 14 +-- lib/MessageImpl.cc | 51 ---------- lib/MessageImpl.h | 38 +++----- tests/EncryptionTests.cc | 150 ----------------------------- tests/PulsarFriend.h | 2 - 15 files changed, 155 insertions(+), 498 deletions(-) delete mode 100644 include/pulsar/EncryptionContext.h delete mode 100644 lib/EncryptionContext.cc delete mode 100644 tests/EncryptionTests.cc diff --git a/include/pulsar/EncryptionContext.h b/include/pulsar/EncryptionContext.h deleted file mode 100644 index 3576410a..00000000 --- a/include/pulsar/EncryptionContext.h +++ /dev/null @@ -1,119 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#pragma once - -#include -#include -#include -#include - -#include "CompressionType.h" -#include "defines.h" - -namespace pulsar { - -namespace proto { -class MessageMetadata; -} - -class Message; - -struct PULSAR_PUBLIC EncryptionKey { - std::string key; - std::string value; - std::unordered_map metadata; - - explicit EncryptionKey() = default; - - // Support in-place construction - EncryptionKey(const std::string& key, const std::string& value, - const decltype(EncryptionKey::metadata)& metadata) - : key(key), value(value), metadata(metadata) {} -}; - -/** - * It contains encryption and compression information in it using which application can decrypt consumed - * message with encrypted-payload. - */ -class PULSAR_PUBLIC EncryptionContext { - public: - explicit EncryptionContext() - : compressionType_(CompressionNone), - uncompressedMessageSize_(0), - batchSize_(-1), - isDecryptionFailed_(false) {} - EncryptionContext(const EncryptionContext&) = default; - EncryptionContext(EncryptionContext&&) noexcept = default; - EncryptionContext(const proto::MessageMetadata& metadata, bool isDecryptionFailed); - - using KeysType = std::vector; - - /** - * @return the map of encryption keys used for the message - */ - const KeysType& keys() const noexcept { return keys_; } - - /** - * @return the encryption parameter used for the message - */ - const std::string& param() const noexcept { return param_; } - - /** - * @return the encryption algorithm used for the message - */ - const std::string& algorithm() const noexcept { return algorithm_; } - - /** - * @return the compression type used for the message - */ - CompressionType compressionType() const noexcept { return compressionType_; } - - /** - * @return the uncompressed message size if the message is compressed, 0 otherwise - */ - uint32_t uncompressedMessageSize() const noexcept { return uncompressedMessageSize_; } - - /** - * @return the batch size if the message is part of a batch, -1 otherwise - */ - int32_t batchSize() const noexcept { return batchSize_; } - - /** - * When the `ConsumerConfiguration#getCryptoFailureAction` is set to `CONSUME`, the message will still be - * returned even if the decryption failed, in this case, the message payload is still not decrypted but - * users have no way to know that. This method is provided to let users know whether the decryption - * failed. - * - * @return whether the decryption failed - */ - bool isDecryptionFailed() const noexcept { return isDecryptionFailed_; } - - private: - KeysType keys_; - std::string param_; - std::string algorithm_; - CompressionType compressionType_; - uint32_t uncompressedMessageSize_; - int32_t batchSize_; - bool isDecryptionFailed_; - - friend class ConsumerImpl; -}; - -} // namespace pulsar diff --git a/include/pulsar/Message.h b/include/pulsar/Message.h index 0a5ba7e0..ea4c4ab4 100644 --- a/include/pulsar/Message.h +++ b/include/pulsar/Message.h @@ -19,18 +19,22 @@ #ifndef MESSAGE_HPP_ #define MESSAGE_HPP_ -#include #include #include #include -#include #include #include "KeyValue.h" #include "MessageId.h" namespace pulsar { +namespace proto { +class CommandMessage; +class BrokerEntryMetadata; +class MessageMetadata; +class SingleMessageMetadata; +} // namespace proto class SharedBuffer; class MessageBuilder; @@ -198,19 +202,19 @@ class PULSAR_PUBLIC Message { */ const std::string& getProducerName() const noexcept; - /** - * @return the optional encryption context that is present when the message is encrypted, the pointer is - * valid as the Message instance is alive - */ - std::optional getEncryptionContext() const; - bool operator==(const Message& msg) const; protected: typedef std::shared_ptr MessageImplPtr; MessageImplPtr impl_; - Message(const MessageImplPtr& impl); + Message(MessageImplPtr& impl); + Message(const MessageId& messageId, proto::BrokerEntryMetadata& brokerEntryMetadata, + proto::MessageMetadata& metadata, SharedBuffer& payload); + /// Used for Batch Messages + Message(const MessageId& messageId, proto::BrokerEntryMetadata& brokerEntryMetadata, + proto::MessageMetadata& metadata, SharedBuffer& payload, + proto::SingleMessageMetadata& singleMetadata, const std::shared_ptr& topicName); friend class PartitionedProducerImpl; friend class MultiTopicsConsumerImpl; friend class MessageBuilder; diff --git a/lib/Commands.cc b/lib/Commands.cc index f244db69..3c687c0a 100644 --- a/lib/Commands.cc +++ b/lib/Commands.cc @@ -906,8 +906,7 @@ uint64_t Commands::serializeSingleMessagesToBatchPayload(SharedBuffer& batchPayl } Message Commands::deSerializeSingleMessageInBatch(Message& batchedMessage, int32_t batchIndex, - int32_t batchSize, const BatchMessageAckerPtr& acker, - const optional& encryptionContext) { + int32_t batchSize, const BatchMessageAckerPtr& acker) { SharedBuffer& uncompressedPayload = batchedMessage.impl_->payload; // Format of batch message @@ -927,13 +926,12 @@ Message Commands::deSerializeSingleMessageInBatch(Message& batchedMessage, int32 const MessageId& m = batchedMessage.impl_->messageId; auto messageId = MessageIdBuilder::from(m).batchIndex(batchIndex).batchSize(batchSize).build(); auto batchedMessageId = std::make_shared(*(messageId.impl_), acker); + Message singleMessage(MessageId{batchedMessageId}, batchedMessage.impl_->brokerEntryMetadata, + batchedMessage.impl_->metadata, payload, metadata, + batchedMessage.impl_->topicName_); + singleMessage.impl_->cnx_ = batchedMessage.impl_->cnx_; - auto msgImpl = std::make_shared(messageId, batchedMessage.impl_->brokerEntryMetadata, - batchedMessage.impl_->metadata, payload, metadata, - batchedMessage.impl_->topicName_, encryptionContext); - msgImpl->cnx_ = batchedMessage.impl_->cnx_; - - return Message(msgImpl); + return singleMessage; } MessageIdImplPtr Commands::getMessageIdImpl(const MessageId& messageId) { return messageId.impl_; } diff --git a/lib/Commands.h b/lib/Commands.h index be778456..8403d6e2 100644 --- a/lib/Commands.h +++ b/lib/Commands.h @@ -155,8 +155,7 @@ class Commands { const std::vector& messages); static Message deSerializeSingleMessageInBatch(Message& batchedMessage, int32_t batchIndex, - int32_t batchSize, const BatchMessageAckerPtr& acker, - const optional& encryptionContext); + int32_t batchSize, const BatchMessageAckerPtr& acker); static MessageIdImplPtr getMessageIdImpl(const MessageId& messageId); diff --git a/lib/ConsumerImpl.cc b/lib/ConsumerImpl.cc index 74d2ffc1..4781e966 100644 --- a/lib/ConsumerImpl.cc +++ b/lib/ConsumerImpl.cc @@ -548,27 +548,25 @@ void ConsumerImpl::messageReceived(const ClientConnectionPtr& cnx, const proto:: bool& isChecksumValid, proto::BrokerEntryMetadata& brokerEntryMetadata, proto::MessageMetadata& metadata, SharedBuffer& payload) { LOG_DEBUG(getName() << "Received Message -- Size: " << payload.readableBytes()); - if (!isChecksumValid) { - // Message discarded for checksum error - discardCorruptedMessage(cnx, msg.message_id(), CommandAck_ValidationError_ChecksumMismatch); + + if (!decryptMessageIfNeeded(cnx, msg, metadata, payload)) { + // Message was discarded or not consumed due to decryption failure return; } - auto encryptionContext = metadata.encryption_keys_size() > 0 - ? optional{std::in_place, metadata, false} - : std::nullopt; - - auto decryptResult = decryptMessageIfNeeded(cnx, encryptionContext, payload, msg.message_id()); - if (decryptResult == FAILED) { - // Message was discarded due to decryption failure or not consumed due to decryption failure + if (!isChecksumValid) { + // Message discarded for checksum error + discardCorruptedMessage(cnx, msg.message_id(), CommandAck_ValidationError_ChecksumMismatch); return; - } else if (decryptResult == CONSUME_ENCRYPTED) { - encryptionContext->isDecryptionFailed_ = true; } auto redeliveryCount = msg.redelivery_count(); + const bool isMessageUndecryptable = + metadata.encryption_keys_size() > 0 && !config_.getCryptoKeyReader().get() && + config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::CONSUME; + const bool isChunkedMessage = metadata.num_chunks_from_msg() > 1; - if (decryptResult == DECRYPTED && !isChunkedMessage) { + if (!isMessageUndecryptable && !isChunkedMessage) { if (!uncompressMessageIfNeeded(cnx, msg.message_id(), metadata, payload, true)) { // Message was discarded on decompression error return; @@ -588,9 +586,9 @@ void ConsumerImpl::messageReceived(const ClientConnectionPtr& cnx, const proto:: } } - Message m{std::make_shared(messageId, brokerEntryMetadata, metadata, payload, std::nullopt, - getTopicPtr(), std::move(encryptionContext))}; + Message m(messageId, brokerEntryMetadata, metadata, payload); m.impl_->cnx_ = cnx.get(); + m.impl_->setTopicName(getTopicPtr()); m.impl_->setRedeliveryCount(msg.redelivery_count()); if (metadata.has_schema_version()) { @@ -612,16 +610,14 @@ void ConsumerImpl::messageReceived(const ClientConnectionPtr& cnx, const proto:: return; } - // When the decryption failed, the whole batch message will be treated as a single message. - if (metadata.has_num_messages_in_batch() && decryptResult == DECRYPTED) { + if (metadata.has_num_messages_in_batch()) { BitSet::Data words(msg.ack_set_size()); for (int i = 0; i < words.size(); i++) { words[i] = msg.ack_set(i); } BitSet ackSet{std::move(words)}; Lock lock(mutex_); - numOfMessageReceived = - receiveIndividualMessagesFromBatch(cnx, m, ackSet, msg.redelivery_count(), encryptionContext); + numOfMessageReceived = receiveIndividualMessagesFromBatch(cnx, m, ackSet, msg.redelivery_count()); } else { // try convert key value data. m.impl_->convertPayloadToKeyValue(config_.getSchema()); @@ -746,9 +742,9 @@ void ConsumerImpl::notifyPendingReceivedCallback(Result result, Message& msg, } // Zero Queue size is not supported with Batch Messages -uint32_t ConsumerImpl::receiveIndividualMessagesFromBatch( - const ClientConnectionPtr& cnx, Message& batchedMessage, const BitSet& ackSet, int redeliveryCount, - const optional& encryptionContext) { +uint32_t ConsumerImpl::receiveIndividualMessagesFromBatch(const ClientConnectionPtr& cnx, + Message& batchedMessage, const BitSet& ackSet, + int redeliveryCount) { auto batchSize = batchedMessage.impl_->metadata.num_messages_in_batch(); LOG_DEBUG("Received Batch messages of size - " << batchSize << " -- msgId: " << batchedMessage.getMessageId()); @@ -760,8 +756,7 @@ uint32_t ConsumerImpl::receiveIndividualMessagesFromBatch( std::vector possibleToDeadLetter; for (int i = 0; i < batchSize; i++) { // This is a cheap copy since message contains only one shared pointer (impl_) - Message msg = - Commands::deSerializeSingleMessageInBatch(batchedMessage, i, batchSize, acker, encryptionContext); + Message msg = Commands::deSerializeSingleMessageInBatch(batchedMessage, i, batchSize, acker); msg.impl_->setRedeliveryCount(redeliveryCount); msg.impl_->setTopicName(batchedMessage.impl_->topicName_); msg.impl_->convertPayloadToKeyValue(config_.getSchema()); @@ -817,51 +812,50 @@ uint32_t ConsumerImpl::receiveIndividualMessagesFromBatch( return batchSize - skippedMessages; } -auto ConsumerImpl::decryptMessageIfNeeded(const ClientConnectionPtr& cnx, - const optional& context, SharedBuffer& payload, - const proto::MessageIdData& msgId) -> DecryptResult { - if (!context.has_value()) { - return DECRYPTED; +bool ConsumerImpl::decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const proto::CommandMessage& msg, + const proto::MessageMetadata& metadata, SharedBuffer& payload) { + if (!metadata.encryption_keys_size()) { + return true; } // If KeyReader is not configured throw exception based on config param if (!config_.isEncryptionEnabled()) { if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::CONSUME) { LOG_WARN(getName() << "CryptoKeyReader is not implemented. Consuming encrypted message."); - return CONSUME_ENCRYPTED; + return true; } else if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::DISCARD) { LOG_WARN(getName() << "Skipping decryption since CryptoKeyReader is not implemented and config " "is set to discard"); - discardCorruptedMessage(cnx, msgId, CommandAck_ValidationError_DecryptionError); + discardCorruptedMessage(cnx, msg.message_id(), CommandAck_ValidationError_DecryptionError); } else { LOG_ERROR(getName() << "Message delivery failed since CryptoKeyReader is not implemented to " "consume encrypted message"); - auto messageId = MessageIdBuilder::from(msgId).build(); + auto messageId = MessageIdBuilder::from(msg.message_id()).build(); unAckedMessageTrackerPtr_->add(messageId); } - return FAILED; + return false; } SharedBuffer decryptedPayload; - if (msgCrypto_->decrypt(*context, payload, config_.getCryptoKeyReader(), decryptedPayload)) { + if (msgCrypto_->decrypt(metadata, payload, config_.getCryptoKeyReader(), decryptedPayload)) { payload = decryptedPayload; - return DECRYPTED; + return true; } if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::CONSUME) { // Note, batch message will fail to consume even if config is set to consume LOG_WARN( getName() << "Decryption failed. Consuming encrypted message since config is set to consume."); - return CONSUME_ENCRYPTED; + return true; } else if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::DISCARD) { LOG_WARN(getName() << "Discarding message since decryption failed and config is set to discard"); - discardCorruptedMessage(cnx, msgId, CommandAck_ValidationError_DecryptionError); + discardCorruptedMessage(cnx, msg.message_id(), CommandAck_ValidationError_DecryptionError); } else { LOG_ERROR(getName() << "Message delivery failed since unable to decrypt incoming message"); - auto messageId = MessageIdBuilder::from(msgId).build(); + auto messageId = MessageIdBuilder::from(msg.message_id()).build(); unAckedMessageTrackerPtr_->add(messageId); } - return FAILED; + return false; } bool ConsumerImpl::uncompressMessageIfNeeded(const ClientConnectionPtr& cnx, diff --git a/lib/ConsumerImpl.h b/lib/ConsumerImpl.h index 02f9cb1b..c1df0804 100644 --- a/lib/ConsumerImpl.h +++ b/lib/ConsumerImpl.h @@ -64,7 +64,6 @@ using UnAckedMessageTrackerPtr = std::shared_ptr namespace proto { class CommandMessage; class BrokerEntryMetadata; -class MessageIdData; class MessageMetadata; } // namespace proto @@ -191,20 +190,13 @@ class ConsumerImpl : public ConsumerImplBase { void increaseAvailablePermits(const Message& msg); void drainIncomingMessageQueue(size_t count); uint32_t receiveIndividualMessagesFromBatch(const ClientConnectionPtr& cnx, Message& batchedMessage, - const BitSet& ackSet, int redeliveryCount, - const optional& encryptionContext); + const BitSet& ackSet, int redeliveryCount); bool isPriorBatchIndex(int32_t idx); bool isPriorEntryIndex(int64_t idx); void brokerConsumerStatsListener(Result, BrokerConsumerStatsImpl, const BrokerConsumerStatsCallback&); - enum DecryptResult - { - DECRYPTED, - CONSUME_ENCRYPTED, - FAILED - }; - DecryptResult decryptMessageIfNeeded(const ClientConnectionPtr&, const optional&, - SharedBuffer& payload, const proto::MessageIdData&); + bool decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const proto::CommandMessage& msg, + const proto::MessageMetadata& metadata, SharedBuffer& payload); // TODO - Convert these functions to lambda when we move to C++11 Result receiveHelper(Message& msg); diff --git a/lib/EncryptionContext.cc b/lib/EncryptionContext.cc deleted file mode 100644 index 5376f062..00000000 --- a/lib/EncryptionContext.cc +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include - -#include "PulsarApi.pb.h" - -namespace pulsar { - -static EncryptionContext::KeysType encryptedKeysFromMetadata(const proto::MessageMetadata& msgMetadata) { - EncryptionContext::KeysType keys; - for (auto&& key : msgMetadata.encryption_keys()) { - decltype(EncryptionKey::metadata) metadata; - for (int i = 0; i < key.metadata_size(); i++) { - const auto& entry = key.metadata(i); - metadata[entry.key()] = entry.value(); - } - keys.emplace_back(key.key(), key.value(), std::move(metadata)); - } - return keys; -} - -EncryptionContext::EncryptionContext(const proto::MessageMetadata& msgMetadata, bool isDecryptionFailed) - - : keys_(encryptedKeysFromMetadata(msgMetadata)), - param_(msgMetadata.encryption_param()), - algorithm_(msgMetadata.encryption_algo()), - compressionType_(static_cast(msgMetadata.compression())), - uncompressedMessageSize_(msgMetadata.uncompressed_size()), - batchSize_(msgMetadata.has_num_messages_in_batch() ? msgMetadata.num_messages_in_batch() : -1), - isDecryptionFailed_(isDecryptionFailed) {} - -} // namespace pulsar diff --git a/lib/Message.cc b/lib/Message.cc index 5faf9c35..1e26b521 100644 --- a/lib/Message.cc +++ b/lib/Message.cc @@ -16,7 +16,6 @@ * specific language governing permissions and limitations * under the License. */ -#include #include #include #include @@ -25,7 +24,10 @@ #include #include "Int64SerDes.h" +#include "KeyValueImpl.h" #include "MessageImpl.h" +#include "PulsarApi.pb.h" +#include "SharedBuffer.h" using namespace pulsar; @@ -66,7 +68,62 @@ std::string Message::getDataAsString() const { return std::string((const char*)g Message::Message() : impl_() {} -Message::Message(const MessageImplPtr& impl) : impl_(impl) {} +Message::Message(MessageImplPtr& impl) : impl_(impl) {} + +Message::Message(const MessageId& messageId, proto::BrokerEntryMetadata& brokerEntryMetadata, + proto::MessageMetadata& metadata, SharedBuffer& payload) + : impl_(std::make_shared()) { + impl_->messageId = messageId; + impl_->brokerEntryMetadata = brokerEntryMetadata; + impl_->metadata = metadata; + impl_->payload = payload; +} + +Message::Message(const MessageId& messageID, proto::BrokerEntryMetadata& brokerEntryMetadata, + proto::MessageMetadata& metadata, SharedBuffer& payload, + proto::SingleMessageMetadata& singleMetadata, const std::shared_ptr& topicName) + : impl_(std::make_shared()) { + impl_->messageId = messageID; + impl_->brokerEntryMetadata = brokerEntryMetadata; + impl_->metadata = metadata; + impl_->payload = payload; + impl_->metadata.mutable_properties()->CopyFrom(singleMetadata.properties()); + impl_->topicName_ = topicName; + + impl_->metadata.clear_properties(); + if (singleMetadata.properties_size() > 0) { + impl_->metadata.mutable_properties()->Reserve(singleMetadata.properties_size()); + for (int i = 0; i < singleMetadata.properties_size(); i++) { + auto keyValue = proto::KeyValue().New(); + *keyValue = singleMetadata.properties(i); + impl_->metadata.mutable_properties()->AddAllocated(keyValue); + } + } + + if (singleMetadata.has_partition_key()) { + impl_->metadata.set_partition_key(singleMetadata.partition_key()); + } else { + impl_->metadata.clear_partition_key(); + } + + if (singleMetadata.has_ordering_key()) { + impl_->metadata.set_ordering_key(singleMetadata.ordering_key()); + } else { + impl_->metadata.clear_ordering_key(); + } + + if (singleMetadata.has_event_time()) { + impl_->metadata.set_event_time(singleMetadata.event_time()); + } else { + impl_->metadata.clear_event_time(); + } + + if (singleMetadata.has_sequence_id()) { + impl_->metadata.set_sequence_id(singleMetadata.sequence_id()); + } else { + impl_->metadata.clear_sequence_id(); + } +} const MessageId& Message::getMessageId() const { if (!impl_) { @@ -163,13 +220,6 @@ const std::string& Message::getProducerName() const noexcept { return impl_->metadata.producer_name(); } -std::optional Message::getEncryptionContext() const { - if (!impl_ || !impl_->encryptionContext_.has_value()) { - return std::nullopt; - } - return {&(*impl_->encryptionContext_)}; -} - bool Message::operator==(const Message& msg) const { return getMessageId() == msg.getMessageId(); } KeyValue Message::getKeyValueData() const { return KeyValue(impl_->keyValuePtr); } diff --git a/lib/MessageBatch.cc b/lib/MessageBatch.cc index 4678e766..f2c1cb29 100644 --- a/lib/MessageBatch.cc +++ b/lib/MessageBatch.cc @@ -49,8 +49,7 @@ MessageBatch& MessageBatch::parseFrom(const SharedBuffer& payload, uint32_t batc auto acker = BatchMessageAckerImpl::create(batchSize); for (int i = 0; i < batchSize; ++i) { - batch_.push_back( - Commands::deSerializeSingleMessageInBatch(batchMessage_, i, batchSize, acker, std::nullopt)); + batch_.push_back(Commands::deSerializeSingleMessageInBatch(batchMessage_, i, batchSize, acker)); } return *this; } diff --git a/lib/MessageCrypto.cc b/lib/MessageCrypto.cc index daa492ea..b06ff652 100644 --- a/lib/MessageCrypto.cc +++ b/lib/MessageCrypto.cc @@ -394,13 +394,13 @@ bool MessageCrypto::encrypt(const std::set& encKeys, const CryptoKe return true; } -bool MessageCrypto::decryptDataKey(const EncryptionKey& encKeys, const CryptoKeyReader& keyReader) { - const auto& keyName = encKeys.key; - const auto& encryptedDataKey = encKeys.value; - const auto& encKeyMeta = encKeys.metadata; +bool MessageCrypto::decryptDataKey(const proto::EncryptionKeys& encKeys, const CryptoKeyReader& keyReader) { + const auto& keyName = encKeys.key(); + const auto& encryptedDataKey = encKeys.value(); + const auto& encKeyMeta = encKeys.metadata(); StringMap keyMeta; for (auto iter = encKeyMeta.begin(); iter != encKeyMeta.end(); iter++) { - keyMeta[iter->first] = iter->second; + keyMeta[iter->key()] = iter->value(); } // Read the private key info using callback @@ -451,10 +451,11 @@ bool MessageCrypto::decryptDataKey(const EncryptionKey& encKeys, const CryptoKey return true; } -bool MessageCrypto::decryptData(const std::string& dataKeySecret, const EncryptionContext& context, +bool MessageCrypto::decryptData(const std::string& dataKeySecret, const proto::MessageMetadata& msgMetadata, SharedBuffer& payload, SharedBuffer& decryptedPayload) { // unpack iv and encrypted data - context.param().copy(reinterpret_cast(iv_.get()), context.param().size()); + msgMetadata.encryption_param().copy(reinterpret_cast(iv_.get()), + msgMetadata.encryption_param().size()); EVP_CIPHER_CTX* cipherCtx = NULL; decryptedPayload = SharedBuffer::allocate(payload.readableBytes() + EVP_MAX_BLOCK_LENGTH + tagLen_); @@ -517,14 +518,15 @@ bool MessageCrypto::decryptData(const std::string& dataKeySecret, const Encrypti return true; } -bool MessageCrypto::getKeyAndDecryptData(const EncryptionContext& context, SharedBuffer& payload, +bool MessageCrypto::getKeyAndDecryptData(const proto::MessageMetadata& msgMetadata, SharedBuffer& payload, SharedBuffer& decryptedPayload) { SharedBuffer decryptedData; bool dataDecrypted = false; - for (auto&& kv : context.keys()) { - const std::string& keyName = kv.key; - const std::string& encDataKey = kv.value; + for (auto iter = msgMetadata.encryption_keys().begin(); iter != msgMetadata.encryption_keys().end(); + iter++) { + const std::string& keyName = iter->key(); + const std::string& encDataKey = iter->value(); unsigned char keyDigest[EVP_MAX_MD_SIZE]; unsigned int digestLen = 0; getDigest(keyName, encDataKey.c_str(), encDataKey.size(), keyDigest, digestLen); @@ -537,7 +539,7 @@ bool MessageCrypto::getKeyAndDecryptData(const EncryptionContext& context, Share // retruns a different key, decryption fails. At this point, we would // call decryptDataKey to refresh the cache and come here again to decrypt. auto dataKeyEntry = dataKeyCacheIter->second; - if (decryptData(dataKeyEntry.first, context, payload, decryptedPayload)) { + if (decryptData(dataKeyEntry.first, msgMetadata, payload, decryptedPayload)) { dataDecrypted = true; break; } @@ -550,16 +552,17 @@ bool MessageCrypto::getKeyAndDecryptData(const EncryptionContext& context, Share return dataDecrypted; } -bool MessageCrypto::decrypt(const EncryptionContext& context, SharedBuffer& payload, +bool MessageCrypto::decrypt(const proto::MessageMetadata& msgMetadata, SharedBuffer& payload, const CryptoKeyReaderPtr& keyReader, SharedBuffer& decryptedPayload) { // Attempt to decrypt using the existing key - if (getKeyAndDecryptData(context, payload, decryptedPayload)) { + if (getKeyAndDecryptData(msgMetadata, payload, decryptedPayload)) { return true; } // Either first time, or decryption failed. Attempt to regenerate data key bool isDataKeyDecrypted = false; - for (auto&& encKeys : context.keys()) { + for (int index = 0; index < msgMetadata.encryption_keys_size(); index++) { + const proto::EncryptionKeys& encKeys = msgMetadata.encryption_keys(index); if (decryptDataKey(encKeys, *keyReader)) { isDataKeyDecrypted = true; break; @@ -571,7 +574,7 @@ bool MessageCrypto::decrypt(const EncryptionContext& context, SharedBuffer& payl return false; } - return getKeyAndDecryptData(context, payload, decryptedPayload); + return getKeyAndDecryptData(msgMetadata, payload, decryptedPayload); } } /* namespace pulsar */ diff --git a/lib/MessageCrypto.h b/lib/MessageCrypto.h index 84075ce9..cd07bf55 100644 --- a/lib/MessageCrypto.h +++ b/lib/MessageCrypto.h @@ -26,10 +26,10 @@ #include #include #include -#include #include #include +#include #include #include #include @@ -90,15 +90,15 @@ class MessageCrypto { /* * Decrypt the payload using the data key. Keys used to encrypt data key can be retrieved from msgMetadata * - * @param context the context of encryption + * @param msgMetadata Message Metadata * @param payload Message which needs to be decrypted * @param keyReader KeyReader implementation to retrieve key value * @param decryptedPayload Contains decrypted payload if success * * @return true if success */ - bool decrypt(const EncryptionContext& context, SharedBuffer& payload, const CryptoKeyReaderPtr& keyReader, - SharedBuffer& decryptedPayload); + bool decrypt(const proto::MessageMetadata& msgMetadata, SharedBuffer& payload, + const CryptoKeyReaderPtr& keyReader, SharedBuffer& decryptedPayload); private: typedef std::unique_lock Lock; @@ -137,10 +137,10 @@ class MessageCrypto { Result addPublicKeyCipher(const std::string& keyName, const CryptoKeyReaderPtr& keyReader); - bool decryptDataKey(const EncryptionKey& encKeys, const CryptoKeyReader& keyReader); - bool decryptData(const std::string& dataKeySecret, const EncryptionContext& context, + bool decryptDataKey(const proto::EncryptionKeys& encKeys, const CryptoKeyReader& keyReader); + bool decryptData(const std::string& dataKeySecret, const proto::MessageMetadata& msgMetadata, SharedBuffer& payload, SharedBuffer& decPayload); - bool getKeyAndDecryptData(const EncryptionContext& context, SharedBuffer& payload, + bool getKeyAndDecryptData(const proto::MessageMetadata& msgMetadata, SharedBuffer& payload, SharedBuffer& decryptedPayload); std::string stringToHex(const std::string& inputStr, size_t len); std::string stringToHex(const char* inputStr, size_t len); diff --git a/lib/MessageImpl.cc b/lib/MessageImpl.cc index da3b8902..17239a82 100644 --- a/lib/MessageImpl.cc +++ b/lib/MessageImpl.cc @@ -18,59 +18,8 @@ */ #include "MessageImpl.h" -#include - -#include "PulsarApi.pb.h" - namespace pulsar { -MessageImpl::MessageImpl(const MessageId& messageId, const proto::BrokerEntryMetadata& brokerEntryMetadata, - const proto::MessageMetadata& metadata, const SharedBuffer& payload, - const optional& singleMetadata, - const std::shared_ptr& topicName, - optional encryptionContext) - : messageId(messageId), - brokerEntryMetadata(brokerEntryMetadata), - metadata(metadata), - payload(payload), - topicName_(topicName), - encryptionContext_(std::move(encryptionContext)) { - if (singleMetadata.has_value()) { - this->metadata.clear_properties(); - if (singleMetadata->properties_size() > 0) { - this->metadata.mutable_properties()->Reserve(singleMetadata->properties_size()); - for (int i = 0; i < singleMetadata->properties_size(); i++) { - auto keyValue = proto::KeyValue().New(); - *keyValue = singleMetadata->properties(i); - this->metadata.mutable_properties()->AddAllocated(keyValue); - } - } - if (singleMetadata->has_partition_key()) { - this->metadata.set_partition_key(singleMetadata->partition_key()); - } else { - this->metadata.clear_partition_key(); - } - - if (singleMetadata->has_ordering_key()) { - this->metadata.set_ordering_key(singleMetadata->ordering_key()); - } else { - this->metadata.clear_ordering_key(); - } - - if (singleMetadata->has_event_time()) { - this->metadata.set_event_time(singleMetadata->event_time()); - } else { - this->metadata.clear_event_time(); - } - - if (singleMetadata->has_sequence_id()) { - this->metadata.set_sequence_id(singleMetadata->sequence_id()); - } else { - this->metadata.clear_sequence_id(); - } - } -} - const Message::StringMap& MessageImpl::properties() { if (properties_.size() == 0) { for (int i = 0; i < metadata.properties_size(); i++) { diff --git a/lib/MessageImpl.h b/lib/MessageImpl.h index 8c8eb016..6467b359 100644 --- a/lib/MessageImpl.h +++ b/lib/MessageImpl.h @@ -19,19 +19,14 @@ #ifndef LIB_MESSAGEIMPL_H_ #define LIB_MESSAGEIMPL_H_ -#include #include #include -#include -#include - #include "KeyValueImpl.h" #include "PulsarApi.pb.h" #include "SharedBuffer.h" -using std::optional; - +using namespace pulsar; namespace pulsar { class PulsarWrapper; @@ -40,14 +35,20 @@ class BatchMessageContainer; class MessageImpl { public: - explicit MessageImpl() : encryptionContext_(std::nullopt) {} - MessageImpl(const MessageId& messageId, const proto::BrokerEntryMetadata& brokerEntryMetadata, - const proto::MessageMetadata& metadata, const SharedBuffer& payload, - const optional& singleMetadata, - const std::shared_ptr& topicName, optional encryptionContext); - const Message::StringMap& properties(); + proto::BrokerEntryMetadata brokerEntryMetadata; + proto::MessageMetadata metadata; + SharedBuffer payload; + std::shared_ptr keyValuePtr; + MessageId messageId; + ClientConnection* cnx_; + std::shared_ptr topicName_; + int redeliveryCount_; + bool hasSchemaVersion_; + const std::string* schemaVersion_; + std::weak_ptr consumerPtr_; + const std::string& getPartitionKey() const; bool hasPartitionKey() const; @@ -80,19 +81,6 @@ class MessageImpl { friend class PulsarWrapper; friend class MessageBuilder; - MessageId messageId; - proto::BrokerEntryMetadata brokerEntryMetadata; - proto::MessageMetadata metadata; - SharedBuffer payload; - std::shared_ptr keyValuePtr; - ClientConnection* cnx_; - std::shared_ptr topicName_; - int redeliveryCount_; - bool hasSchemaVersion_; - const std::string* schemaVersion_; - std::weak_ptr consumerPtr_; - const optional encryptionContext_; - private: void setReplicationClusters(const std::vector& clusters); void setProperty(const std::string& name, const std::string& value); diff --git a/tests/EncryptionTests.cc b/tests/EncryptionTests.cc deleted file mode 100644 index bbbf31ee..00000000 --- a/tests/EncryptionTests.cc +++ /dev/null @@ -1,150 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include -#include - -#include - -#include "PulsarApi.pb.h" -#include "lib/CompressionCodec.h" -#include "lib/MessageCrypto.h" -#include "lib/SharedBuffer.h" -#include "tests/PulsarFriend.h" - -static std::string lookupUrl = "pulsar://localhost:6650"; - -using namespace pulsar; - -static CryptoKeyReaderPtr getDefaultCryptoKeyReader() { - return std::make_shared(TEST_CONF_DIR "/public-key.client-rsa.pem", - TEST_CONF_DIR "/private-key.client-rsa.pem"); -} - -static std::vector decryptValue(const Message& message) { - if (!message.getEncryptionContext().has_value()) { - return {message.getDataAsString()}; - } - auto context = message.getEncryptionContext().value(); - if (!context->isDecryptionFailed()) { - return {message.getDataAsString()}; - } - - MessageCrypto crypto{"test", false}; - auto msgImpl = PulsarFriend::getMessageImplPtr(message); - SharedBuffer decryptedPayload; - auto originalPayload = - SharedBuffer::copy(static_cast(message.getData()), message.getLength()); - if (!crypto.decrypt(*context, originalPayload, getDefaultCryptoKeyReader(), decryptedPayload)) { - throw std::runtime_error("Decryption failed"); - } - - SharedBuffer uncompressedPayload; - if (!CompressionCodecProvider::getCodec(context->compressionType()) - .decode(decryptedPayload, context->uncompressedMessageSize(), uncompressedPayload)) { - throw std::runtime_error("Decompression failed"); - } - - std::vector values; - if (auto batchSize = message.getEncryptionContext().value()->batchSize(); batchSize > 0) { - for (decltype(batchSize) i = 0; i < batchSize; i++) { - auto singleMetaSize = uncompressedPayload.readUnsignedInt(); - proto::SingleMessageMetadata singleMeta; - singleMeta.ParseFromArray(uncompressedPayload.data(), singleMetaSize); - uncompressedPayload.consume(singleMetaSize); - - auto payload = uncompressedPayload.slice(0, singleMeta.payload_size()); - uncompressedPayload.consume(payload.readableBytes()); - values.emplace_back(payload.data(), payload.readableBytes()); - } - } else { - // non-batched message - values.emplace_back(uncompressedPayload.data(), uncompressedPayload.readableBytes()); - } - return values; -} - -static void testDecryption(Client& client, const std::string& topic, bool decryptionSucceed, - int numMessageReceived) { - ProducerConfiguration producerConf; - producerConf.setCompressionType(CompressionLZ4); - producerConf.addEncryptionKey("client-rsa.pem"); - producerConf.setCryptoKeyReader(getDefaultCryptoKeyReader()); - - Producer producer; - ASSERT_EQ(ResultOk, client.createProducer(topic, producerConf, producer)); - - std::vector sentValues; - auto send = [&producer, &sentValues](const std::string& value) { - Message msg = MessageBuilder().setContent(value).build(); - producer.sendAsync(msg, nullptr); - sentValues.emplace_back(value); - }; - - for (int i = 0; i < 5; i++) { - send("msg-" + std::to_string(i)); - } - producer.flush(); - send("last-msg"); - producer.flush(); - - ASSERT_EQ(ResultOk, client.createProducer(topic, producer)); - send("unencrypted-msg"); - producer.flush(); - producer.close(); - - ConsumerConfiguration consumerConf; - consumerConf.setSubscriptionInitialPosition(InitialPositionEarliest); - if (decryptionSucceed) { - consumerConf.setCryptoKeyReader(getDefaultCryptoKeyReader()); - } else { - consumerConf.setCryptoFailureAction(ConsumerCryptoFailureAction::CONSUME); - } - Consumer consumer; - ASSERT_EQ(ResultOk, client.subscribe(topic, "sub", consumerConf, consumer)); - - std::vector values; - for (int i = 0; i < numMessageReceived; i++) { - Message msg; - ASSERT_EQ(ResultOk, consumer.receive(msg, 3000)); - if (i < numMessageReceived - 1) { - ASSERT_TRUE(msg.getEncryptionContext().has_value()); - } - for (auto&& value : decryptValue(msg)) { - values.emplace_back(value); - } - } - ASSERT_EQ(values, sentValues); - consumer.close(); -} - -TEST(EncryptionTests, testDecryptionSuccess) { - Client client{lookupUrl}; - std::string topic = "test-decryption-success-" + std::to_string(time(nullptr)); - testDecryption(client, topic, true, 7); - client.close(); -} - -TEST(EncryptionTests, testDecryptionFailure) { - Client client{lookupUrl}; - std::string topic = "test-decryption-failure-" + std::to_string(time(nullptr)); - // The 1st batch that has 5 messages cannot be decrypted, so they can be received only once - testDecryption(client, topic, false, 3); - client.close(); -} diff --git a/tests/PulsarFriend.h b/tests/PulsarFriend.h index 780ec2a9..e7084050 100644 --- a/tests/PulsarFriend.h +++ b/tests/PulsarFriend.h @@ -217,8 +217,6 @@ class PulsarFriend { return waitUntil(std::chrono::seconds(3), [producerImpl] { return !producerImpl->getCnx().expired(); }); } - - static auto getMessageImplPtr(const Message& message) { return message.impl_; } }; } // namespace pulsar From 5775987be4ebd3f5845b452af0b4f1b4758e381b Mon Sep 17 00:00:00 2001 From: Yunze Xu Date: Fri, 5 Dec 2025 11:23:57 +0800 Subject: [PATCH 03/12] simplify the implementation --- include/pulsar/EncryptionContext.h | 119 +++++++++++++++++++++++ include/pulsar/Message.h | 8 ++ lib/ConsumerImpl.cc | 48 +++++----- lib/ConsumerImpl.h | 11 ++- lib/EncryptionContext.cc | 48 ++++++++++ lib/Message.cc | 7 ++ lib/MessageCrypto.cc | 35 ++++--- lib/MessageCrypto.h | 14 +-- lib/MessageImpl.h | 3 + tests/EncryptionTest.cc | 147 +++++++++++++++++++++++++++++ tests/PulsarFriend.h | 2 + 11 files changed, 393 insertions(+), 49 deletions(-) create mode 100644 include/pulsar/EncryptionContext.h create mode 100644 lib/EncryptionContext.cc create mode 100644 tests/EncryptionTest.cc diff --git a/include/pulsar/EncryptionContext.h b/include/pulsar/EncryptionContext.h new file mode 100644 index 00000000..7726dfa4 --- /dev/null +++ b/include/pulsar/EncryptionContext.h @@ -0,0 +1,119 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#pragma once + +#include +#include +#include +#include + +#include "CompressionType.h" +#include "defines.h" + +namespace pulsar { + +namespace proto { +class MessageMetadata; +} + +class Message; + +struct PULSAR_PUBLIC EncryptionKey { + std::string key; + std::string value; + std::unordered_map metadata; + + explicit EncryptionKey() = default; + + EncryptionKey(const std::string& key, const std::string& value, + const decltype(EncryptionKey::metadata)& metadata) + : key(key), value(value), metadata(metadata) {} +}; + +/** + * It contains encryption and compression information in it using which application can decrypt consumed + * message with encrypted-payload. + */ +class PULSAR_PUBLIC EncryptionContext { + public: + explicit EncryptionContext() = default; + EncryptionContext(const EncryptionContext&) = default; + EncryptionContext(EncryptionContext&&) noexcept = default; + EncryptionContext& operator=(const EncryptionContext&) = default; + EncryptionContext& operator=(EncryptionContext&&) noexcept = default; + + using KeysType = std::vector; + + /** + * @return the map of encryption keys used for the message + */ + const KeysType& keys() const noexcept { return keys_; } + + /** + * @return the encryption parameter used for the message + */ + const std::string& param() const noexcept { return param_; } + + /** + * @return the encryption algorithm used for the message + */ + const std::string& algorithm() const noexcept { return algorithm_; } + + /** + * @return the compression type used for the message + */ + CompressionType compressionType() const noexcept { return compressionType_; } + + /** + * @return the uncompressed message size if the message is compressed, 0 otherwise + */ + uint32_t uncompressedMessageSize() const noexcept { return uncompressedMessageSize_; } + + /** + * @return the batch size if the message is part of a batch, -1 otherwise + */ + int32_t batchSize() const noexcept { return batchSize_; } + + /** + * When the `ConsumerConfiguration#getCryptoFailureAction` is set to `CONSUME`, the message will still be + * returned even if the decryption failed, in this case, the message payload is still not decrypted but + * users have no way to know that. This method is provided to let users know whether the decryption + * failed. + * + * @return whether the decryption failed + */ + bool isDecryptionFailed() const noexcept { return isDecryptionFailed_; } + + // It should be used only internally but it's exposed so that `std::make_optional` can construct the + // object in place with this constructor. + EncryptionContext(const proto::MessageMetadata& metadata, bool isDecryptionFailed); + + private: + KeysType keys_; + std::string param_; + std::string algorithm_; + CompressionType compressionType_{CompressionNone}; + uint32_t uncompressedMessageSize_{0}; + int32_t batchSize_{-1}; + bool isDecryptionFailed_{false}; + + friend class ConsumerImpl; +}; + +} // namespace pulsar diff --git a/include/pulsar/Message.h b/include/pulsar/Message.h index ea4c4ab4..f52879e8 100644 --- a/include/pulsar/Message.h +++ b/include/pulsar/Message.h @@ -19,10 +19,12 @@ #ifndef MESSAGE_HPP_ #define MESSAGE_HPP_ +#include #include #include #include +#include #include #include "KeyValue.h" @@ -202,6 +204,12 @@ class PULSAR_PUBLIC Message { */ const std::string& getProducerName() const noexcept; + /** + * @return the optional encryption context that is present when the message is encrypted, the pointer is + * valid as the Message instance is alive + */ + std::optional getEncryptionContext() const; + bool operator==(const Message& msg) const; protected: diff --git a/lib/ConsumerImpl.cc b/lib/ConsumerImpl.cc index 4781e966..5b695597 100644 --- a/lib/ConsumerImpl.cc +++ b/lib/ConsumerImpl.cc @@ -19,6 +19,7 @@ #include "ConsumerImpl.h" #include +#include #include #include @@ -549,24 +550,27 @@ void ConsumerImpl::messageReceived(const ClientConnectionPtr& cnx, const proto:: proto::MessageMetadata& metadata, SharedBuffer& payload) { LOG_DEBUG(getName() << "Received Message -- Size: " << payload.readableBytes()); - if (!decryptMessageIfNeeded(cnx, msg, metadata, payload)) { - // Message was discarded or not consumed due to decryption failure - return; - } - if (!isChecksumValid) { // Message discarded for checksum error discardCorruptedMessage(cnx, msg.message_id(), CommandAck_ValidationError_ChecksumMismatch); return; } - auto redeliveryCount = msg.redelivery_count(); - const bool isMessageUndecryptable = - metadata.encryption_keys_size() > 0 && !config_.getCryptoKeyReader().get() && - config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::CONSUME; + auto encryptionContext = metadata.encryption_keys_size() > 0 + ? optional(std::in_place, metadata, false) + : std::nullopt; + const auto decryptionResult = decryptMessageIfNeeded(cnx, msg, encryptionContext, payload); + if (decryptionResult == FAILED) { + // Message was discarded or not consumed due to decryption failure + return; + } else if (decryptionResult == CONSUME_ENCRYPTED && encryptionContext.has_value()) { + // Message is encrypted, but we let the application consume it as-is + encryptionContext->isDecryptionFailed_ = true; + } + auto redeliveryCount = msg.redelivery_count(); const bool isChunkedMessage = metadata.num_chunks_from_msg() > 1; - if (!isMessageUndecryptable && !isChunkedMessage) { + if (decryptionResult == SUCCESS && !isChunkedMessage) { if (!uncompressMessageIfNeeded(cnx, msg.message_id(), metadata, payload, true)) { // Message was discarded on decompression error return; @@ -590,6 +594,7 @@ void ConsumerImpl::messageReceived(const ClientConnectionPtr& cnx, const proto:: m.impl_->cnx_ = cnx.get(); m.impl_->setTopicName(getTopicPtr()); m.impl_->setRedeliveryCount(msg.redelivery_count()); + m.impl_->encryptionContext_ = std::move(encryptionContext); if (metadata.has_schema_version()) { m.impl_->setSchemaVersion(metadata.schema_version()); @@ -610,7 +615,7 @@ void ConsumerImpl::messageReceived(const ClientConnectionPtr& cnx, const proto:: return; } - if (metadata.has_num_messages_in_batch()) { + if (metadata.has_num_messages_in_batch() && decryptionResult == SUCCESS) { BitSet::Data words(msg.ack_set_size()); for (int i = 0; i < words.size(); i++) { words[i] = msg.ack_set(i); @@ -812,17 +817,18 @@ uint32_t ConsumerImpl::receiveIndividualMessagesFromBatch(const ClientConnection return batchSize - skippedMessages; } -bool ConsumerImpl::decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const proto::CommandMessage& msg, - const proto::MessageMetadata& metadata, SharedBuffer& payload) { - if (!metadata.encryption_keys_size()) { - return true; +auto ConsumerImpl::decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const proto::CommandMessage& msg, + const optional& context, SharedBuffer& payload) + -> DecryptionResult { + if (!context.has_value()) { + return SUCCESS; } // If KeyReader is not configured throw exception based on config param if (!config_.isEncryptionEnabled()) { if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::CONSUME) { LOG_WARN(getName() << "CryptoKeyReader is not implemented. Consuming encrypted message."); - return true; + return CONSUME_ENCRYPTED; } else if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::DISCARD) { LOG_WARN(getName() << "Skipping decryption since CryptoKeyReader is not implemented and config " "is set to discard"); @@ -833,20 +839,20 @@ bool ConsumerImpl::decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const auto messageId = MessageIdBuilder::from(msg.message_id()).build(); unAckedMessageTrackerPtr_->add(messageId); } - return false; + return FAILED; } SharedBuffer decryptedPayload; - if (msgCrypto_->decrypt(metadata, payload, config_.getCryptoKeyReader(), decryptedPayload)) { + if (msgCrypto_->decrypt(*context, payload, config_.getCryptoKeyReader(), decryptedPayload)) { payload = decryptedPayload; - return true; + return SUCCESS; } if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::CONSUME) { // Note, batch message will fail to consume even if config is set to consume LOG_WARN( getName() << "Decryption failed. Consuming encrypted message since config is set to consume."); - return true; + return CONSUME_ENCRYPTED; } else if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::DISCARD) { LOG_WARN(getName() << "Discarding message since decryption failed and config is set to discard"); discardCorruptedMessage(cnx, msg.message_id(), CommandAck_ValidationError_DecryptionError); @@ -855,7 +861,7 @@ bool ConsumerImpl::decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const auto messageId = MessageIdBuilder::from(msg.message_id()).build(); unAckedMessageTrackerPtr_->add(messageId); } - return false; + return FAILED; } bool ConsumerImpl::uncompressMessageIfNeeded(const ClientConnectionPtr& cnx, diff --git a/lib/ConsumerImpl.h b/lib/ConsumerImpl.h index c1df0804..04e3a9c8 100644 --- a/lib/ConsumerImpl.h +++ b/lib/ConsumerImpl.h @@ -195,8 +195,15 @@ class ConsumerImpl : public ConsumerImplBase { bool isPriorEntryIndex(int64_t idx); void brokerConsumerStatsListener(Result, BrokerConsumerStatsImpl, const BrokerConsumerStatsCallback&); - bool decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const proto::CommandMessage& msg, - const proto::MessageMetadata& metadata, SharedBuffer& payload); + enum DecryptionResult + { + SUCCESS, + CONSUME_ENCRYPTED, + FAILED + }; + DecryptionResult decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const proto::CommandMessage& msg, + const optional& context, + SharedBuffer& payload); // TODO - Convert these functions to lambda when we move to C++11 Result receiveHelper(Message& msg); diff --git a/lib/EncryptionContext.cc b/lib/EncryptionContext.cc new file mode 100644 index 00000000..5376f062 --- /dev/null +++ b/lib/EncryptionContext.cc @@ -0,0 +1,48 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "PulsarApi.pb.h" + +namespace pulsar { + +static EncryptionContext::KeysType encryptedKeysFromMetadata(const proto::MessageMetadata& msgMetadata) { + EncryptionContext::KeysType keys; + for (auto&& key : msgMetadata.encryption_keys()) { + decltype(EncryptionKey::metadata) metadata; + for (int i = 0; i < key.metadata_size(); i++) { + const auto& entry = key.metadata(i); + metadata[entry.key()] = entry.value(); + } + keys.emplace_back(key.key(), key.value(), std::move(metadata)); + } + return keys; +} + +EncryptionContext::EncryptionContext(const proto::MessageMetadata& msgMetadata, bool isDecryptionFailed) + + : keys_(encryptedKeysFromMetadata(msgMetadata)), + param_(msgMetadata.encryption_param()), + algorithm_(msgMetadata.encryption_algo()), + compressionType_(static_cast(msgMetadata.compression())), + uncompressedMessageSize_(msgMetadata.uncompressed_size()), + batchSize_(msgMetadata.has_num_messages_in_batch() ? msgMetadata.num_messages_in_batch() : -1), + isDecryptionFailed_(isDecryptionFailed) {} + +} // namespace pulsar diff --git a/lib/Message.cc b/lib/Message.cc index 1e26b521..9505565b 100644 --- a/lib/Message.cc +++ b/lib/Message.cc @@ -220,6 +220,13 @@ const std::string& Message::getProducerName() const noexcept { return impl_->metadata.producer_name(); } +std::optional Message::getEncryptionContext() const { + if (!impl_ || !impl_->encryptionContext_.has_value()) { + return std::nullopt; + } + return &impl_->encryptionContext_.value(); +} + bool Message::operator==(const Message& msg) const { return getMessageId() == msg.getMessageId(); } KeyValue Message::getKeyValueData() const { return KeyValue(impl_->keyValuePtr); } diff --git a/lib/MessageCrypto.cc b/lib/MessageCrypto.cc index b06ff652..daa492ea 100644 --- a/lib/MessageCrypto.cc +++ b/lib/MessageCrypto.cc @@ -394,13 +394,13 @@ bool MessageCrypto::encrypt(const std::set& encKeys, const CryptoKe return true; } -bool MessageCrypto::decryptDataKey(const proto::EncryptionKeys& encKeys, const CryptoKeyReader& keyReader) { - const auto& keyName = encKeys.key(); - const auto& encryptedDataKey = encKeys.value(); - const auto& encKeyMeta = encKeys.metadata(); +bool MessageCrypto::decryptDataKey(const EncryptionKey& encKeys, const CryptoKeyReader& keyReader) { + const auto& keyName = encKeys.key; + const auto& encryptedDataKey = encKeys.value; + const auto& encKeyMeta = encKeys.metadata; StringMap keyMeta; for (auto iter = encKeyMeta.begin(); iter != encKeyMeta.end(); iter++) { - keyMeta[iter->key()] = iter->value(); + keyMeta[iter->first] = iter->second; } // Read the private key info using callback @@ -451,11 +451,10 @@ bool MessageCrypto::decryptDataKey(const proto::EncryptionKeys& encKeys, const C return true; } -bool MessageCrypto::decryptData(const std::string& dataKeySecret, const proto::MessageMetadata& msgMetadata, +bool MessageCrypto::decryptData(const std::string& dataKeySecret, const EncryptionContext& context, SharedBuffer& payload, SharedBuffer& decryptedPayload) { // unpack iv and encrypted data - msgMetadata.encryption_param().copy(reinterpret_cast(iv_.get()), - msgMetadata.encryption_param().size()); + context.param().copy(reinterpret_cast(iv_.get()), context.param().size()); EVP_CIPHER_CTX* cipherCtx = NULL; decryptedPayload = SharedBuffer::allocate(payload.readableBytes() + EVP_MAX_BLOCK_LENGTH + tagLen_); @@ -518,15 +517,14 @@ bool MessageCrypto::decryptData(const std::string& dataKeySecret, const proto::M return true; } -bool MessageCrypto::getKeyAndDecryptData(const proto::MessageMetadata& msgMetadata, SharedBuffer& payload, +bool MessageCrypto::getKeyAndDecryptData(const EncryptionContext& context, SharedBuffer& payload, SharedBuffer& decryptedPayload) { SharedBuffer decryptedData; bool dataDecrypted = false; - for (auto iter = msgMetadata.encryption_keys().begin(); iter != msgMetadata.encryption_keys().end(); - iter++) { - const std::string& keyName = iter->key(); - const std::string& encDataKey = iter->value(); + for (auto&& kv : context.keys()) { + const std::string& keyName = kv.key; + const std::string& encDataKey = kv.value; unsigned char keyDigest[EVP_MAX_MD_SIZE]; unsigned int digestLen = 0; getDigest(keyName, encDataKey.c_str(), encDataKey.size(), keyDigest, digestLen); @@ -539,7 +537,7 @@ bool MessageCrypto::getKeyAndDecryptData(const proto::MessageMetadata& msgMetada // retruns a different key, decryption fails. At this point, we would // call decryptDataKey to refresh the cache and come here again to decrypt. auto dataKeyEntry = dataKeyCacheIter->second; - if (decryptData(dataKeyEntry.first, msgMetadata, payload, decryptedPayload)) { + if (decryptData(dataKeyEntry.first, context, payload, decryptedPayload)) { dataDecrypted = true; break; } @@ -552,17 +550,16 @@ bool MessageCrypto::getKeyAndDecryptData(const proto::MessageMetadata& msgMetada return dataDecrypted; } -bool MessageCrypto::decrypt(const proto::MessageMetadata& msgMetadata, SharedBuffer& payload, +bool MessageCrypto::decrypt(const EncryptionContext& context, SharedBuffer& payload, const CryptoKeyReaderPtr& keyReader, SharedBuffer& decryptedPayload) { // Attempt to decrypt using the existing key - if (getKeyAndDecryptData(msgMetadata, payload, decryptedPayload)) { + if (getKeyAndDecryptData(context, payload, decryptedPayload)) { return true; } // Either first time, or decryption failed. Attempt to regenerate data key bool isDataKeyDecrypted = false; - for (int index = 0; index < msgMetadata.encryption_keys_size(); index++) { - const proto::EncryptionKeys& encKeys = msgMetadata.encryption_keys(index); + for (auto&& encKeys : context.keys()) { if (decryptDataKey(encKeys, *keyReader)) { isDataKeyDecrypted = true; break; @@ -574,7 +571,7 @@ bool MessageCrypto::decrypt(const proto::MessageMetadata& msgMetadata, SharedBuf return false; } - return getKeyAndDecryptData(msgMetadata, payload, decryptedPayload); + return getKeyAndDecryptData(context, payload, decryptedPayload); } } /* namespace pulsar */ diff --git a/lib/MessageCrypto.h b/lib/MessageCrypto.h index cd07bf55..4052066d 100644 --- a/lib/MessageCrypto.h +++ b/lib/MessageCrypto.h @@ -26,10 +26,10 @@ #include #include #include +#include #include #include -#include #include #include #include @@ -90,15 +90,15 @@ class MessageCrypto { /* * Decrypt the payload using the data key. Keys used to encrypt data key can be retrieved from msgMetadata * - * @param msgMetadata Message Metadata + * @param context the encryption context * @param payload Message which needs to be decrypted * @param keyReader KeyReader implementation to retrieve key value * @param decryptedPayload Contains decrypted payload if success * * @return true if success */ - bool decrypt(const proto::MessageMetadata& msgMetadata, SharedBuffer& payload, - const CryptoKeyReaderPtr& keyReader, SharedBuffer& decryptedPayload); + bool decrypt(const EncryptionContext& context, SharedBuffer& payload, const CryptoKeyReaderPtr& keyReader, + SharedBuffer& decryptedPayload); private: typedef std::unique_lock Lock; @@ -137,10 +137,10 @@ class MessageCrypto { Result addPublicKeyCipher(const std::string& keyName, const CryptoKeyReaderPtr& keyReader); - bool decryptDataKey(const proto::EncryptionKeys& encKeys, const CryptoKeyReader& keyReader); - bool decryptData(const std::string& dataKeySecret, const proto::MessageMetadata& msgMetadata, + bool decryptDataKey(const EncryptionKey& encKeys, const CryptoKeyReader& keyReader); + bool decryptData(const std::string& dataKeySecret, const EncryptionContext& context, SharedBuffer& payload, SharedBuffer& decPayload); - bool getKeyAndDecryptData(const proto::MessageMetadata& msgMetadata, SharedBuffer& payload, + bool getKeyAndDecryptData(const EncryptionContext& context, SharedBuffer& payload, SharedBuffer& decryptedPayload); std::string stringToHex(const std::string& inputStr, size_t len); std::string stringToHex(const char* inputStr, size_t len); diff --git a/lib/MessageImpl.h b/lib/MessageImpl.h index 6467b359..1b53c448 100644 --- a/lib/MessageImpl.h +++ b/lib/MessageImpl.h @@ -21,10 +21,12 @@ #include #include +#include #include "KeyValueImpl.h" #include "PulsarApi.pb.h" #include "SharedBuffer.h" +#include "pulsar/EncryptionContext.h" using namespace pulsar; namespace pulsar { @@ -48,6 +50,7 @@ class MessageImpl { bool hasSchemaVersion_; const std::string* schemaVersion_; std::weak_ptr consumerPtr_; + std::optional encryptionContext_; const std::string& getPartitionKey() const; bool hasPartitionKey() const; diff --git a/tests/EncryptionTest.cc b/tests/EncryptionTest.cc new file mode 100644 index 00000000..8fd83e9b --- /dev/null +++ b/tests/EncryptionTest.cc @@ -0,0 +1,147 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include + +#include +#include + +#include "PulsarApi.pb.h" +#include "lib/CompressionCodec.h" +#include "lib/MessageCrypto.h" +#include "lib/SharedBuffer.h" +#include "tests/PulsarFriend.h" + +static std::string lookupUrl = "pulsar://localhost:6650"; + +using namespace pulsar; + +static CryptoKeyReaderPtr getDefaultCryptoKeyReader() { + return std::make_shared(TEST_CONF_DIR "/public-key.client-rsa.pem", + TEST_CONF_DIR "/private-key.client-rsa.pem"); +} + +static std::vector decryptValue(const char* data, size_t length, + std::optional context) { + if (!context.has_value()) { + return {std::string(data, length)}; + } + if (!context.value()->isDecryptionFailed()) { + return {std::string(data, length)}; + } + + MessageCrypto crypto{"test", false}; + SharedBuffer decryptedPayload; + auto originalPayload = SharedBuffer::copy(data, length); + if (!crypto.decrypt(*context.value(), originalPayload, getDefaultCryptoKeyReader(), decryptedPayload)) { + throw std::runtime_error("Decryption failed"); + } + + SharedBuffer uncompressedPayload; + if (!CompressionCodecProvider::getCodec(context.value()->compressionType()) + .decode(decryptedPayload, context.value()->uncompressedMessageSize(), uncompressedPayload)) { + throw std::runtime_error("Decompression failed"); + } + + std::vector values; + if (auto batchSize = context.value()->batchSize(); batchSize > 0) { + for (decltype(batchSize) i = 0; i < batchSize; i++) { + auto singleMetaSize = uncompressedPayload.readUnsignedInt(); + proto::SingleMessageMetadata singleMeta; + singleMeta.ParseFromArray(uncompressedPayload.data(), singleMetaSize); + uncompressedPayload.consume(singleMetaSize); + + auto payload = uncompressedPayload.slice(0, singleMeta.payload_size()); + uncompressedPayload.consume(payload.readableBytes()); + values.emplace_back(payload.data(), payload.readableBytes()); + } + } else { + // non-batched message + values.emplace_back(uncompressedPayload.data(), uncompressedPayload.readableBytes()); + } + return values; +} + +static void testDecryption(Client& client, const std::string& topic, bool decryptionSucceed, + int numMessageReceived) { + ProducerConfiguration producerConf; + producerConf.setCompressionType(CompressionLZ4); + producerConf.addEncryptionKey("client-rsa.pem"); + producerConf.setCryptoKeyReader(getDefaultCryptoKeyReader()); + + Producer producer; + ASSERT_EQ(ResultOk, client.createProducer(topic, producerConf, producer)); + + std::vector sentValues; + auto send = [&producer, &sentValues](const std::string& value) { + Message msg = MessageBuilder().setContent(value).build(); + producer.sendAsync(msg, nullptr); + sentValues.emplace_back(value); + }; + + for (int i = 0; i < 5; i++) { + send("msg-" + std::to_string(i)); + } + producer.flush(); + send("last-msg"); + producer.flush(); + + ASSERT_EQ(ResultOk, client.createProducer(topic, producer)); + send("unencrypted-msg"); + producer.flush(); + producer.close(); + + ConsumerConfiguration consumerConf; + consumerConf.setSubscriptionInitialPosition(InitialPositionEarliest); + if (decryptionSucceed) { + consumerConf.setCryptoKeyReader(getDefaultCryptoKeyReader()); + } else { + consumerConf.setCryptoFailureAction(ConsumerCryptoFailureAction::CONSUME); + } + Consumer consumer; + ASSERT_EQ(ResultOk, client.subscribe(topic, "sub", consumerConf, consumer)); + + std::vector values; + for (int i = 0; i < numMessageReceived; i++) { + Message msg; + ASSERT_EQ(ResultOk, consumer.receive(msg, 3000)); + for (auto&& value : decryptValue(static_cast(msg.getData()), msg.getLength(), + msg.getEncryptionContext())) { + values.emplace_back(value); + } + } + ASSERT_EQ(values, sentValues); + consumer.close(); +} + +TEST(EncryptionTests, testDecryptionSuccess) { + Client client{lookupUrl}; + std::string topic = "test-decryption-success-" + std::to_string(time(nullptr)); + testDecryption(client, topic, true, 7); + client.close(); +} + +TEST(EncryptionTests, testDecryptionFailure) { + Client client{lookupUrl}; + std::string topic = "test-decryption-failure-" + std::to_string(time(nullptr)); + // The 1st batch that has 5 messages cannot be decrypted, so they can be received only once + testDecryption(client, topic, false, 3); + client.close(); +} diff --git a/tests/PulsarFriend.h b/tests/PulsarFriend.h index e7084050..780ec2a9 100644 --- a/tests/PulsarFriend.h +++ b/tests/PulsarFriend.h @@ -217,6 +217,8 @@ class PulsarFriend { return waitUntil(std::chrono::seconds(3), [producerImpl] { return !producerImpl->getCnx().expired(); }); } + + static auto getMessageImplPtr(const Message& message) { return message.impl_; } }; } // namespace pulsar From aec221524458ac38b94f7b3c07ea5066336e9337 Mon Sep 17 00:00:00 2001 From: Yunze Xu Date: Fri, 5 Dec 2025 12:42:21 +0800 Subject: [PATCH 04/12] fix format --- lib/MessageImpl.h | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/MessageImpl.h b/lib/MessageImpl.h index 1b53c448..a234ca45 100644 --- a/lib/MessageImpl.h +++ b/lib/MessageImpl.h @@ -21,6 +21,7 @@ #include #include + #include #include "KeyValueImpl.h" From dcaf9207783f811545aa4a7fa15180b8d7657740 Mon Sep 17 00:00:00 2001 From: Yunze Xu Date: Fri, 5 Dec 2025 13:47:02 +0800 Subject: [PATCH 05/12] Fix windows example --- win-examples/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/win-examples/CMakeLists.txt b/win-examples/CMakeLists.txt index 3998c43a..c8d74b60 100644 --- a/win-examples/CMakeLists.txt +++ b/win-examples/CMakeLists.txt @@ -20,6 +20,7 @@ cmake_minimum_required(VERSION 3.4) project(pulsar-cpp-win-examples) +set(CMAKE_CXX_STANDARD 17) find_path(PULSAR_INCLUDES NAMES "pulsar/Client.h") if (PULSAR_INCLUDES) message(STATUS "PULSAR_INCLUDES: " ${PULSAR_INCLUDES}) From 5e159b133391e4a7d6d07cbaea4b76fa00da6168 Mon Sep 17 00:00:00 2001 From: Yunze Xu Date: Fri, 5 Dec 2025 15:53:50 +0800 Subject: [PATCH 06/12] Fix encryption context not set for batched messages --- lib/Commands.cc | 1 + tests/BasicEndToEndTest.cc | 4 ++++ tests/EncryptionTest.cc | 6 +++++- tests/PulsarFriend.h | 2 -- 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/lib/Commands.cc b/lib/Commands.cc index 3c687c0a..30f5bf1a 100644 --- a/lib/Commands.cc +++ b/lib/Commands.cc @@ -930,6 +930,7 @@ Message Commands::deSerializeSingleMessageInBatch(Message& batchedMessage, int32 batchedMessage.impl_->metadata, payload, metadata, batchedMessage.impl_->topicName_); singleMessage.impl_->cnx_ = batchedMessage.impl_->cnx_; + singleMessage.impl_->encryptionContext_ = batchedMessage.impl_->encryptionContext_; return singleMessage; } diff --git a/tests/BasicEndToEndTest.cc b/tests/BasicEndToEndTest.cc index c9a8faa9..9a02df0c 100644 --- a/tests/BasicEndToEndTest.cc +++ b/tests/BasicEndToEndTest.cc @@ -1465,6 +1465,10 @@ TEST(BasicEndToEndTest, testRSAEncryption) { expected << msgContent << msgNum; ASSERT_EQ(expected.str(), msgReceived.getDataAsString()); ASSERT_EQ(ResultOk, consumer.acknowledge(msgReceived)); + auto context = msgReceived.getEncryptionContext(); + ASSERT_TRUE(context.has_value()); + ASSERT_EQ(context.value()->keys().size(), 1); + ASSERT_EQ(context.value()->keys()[0].key, "client-rsa.pem"); } ASSERT_EQ(ResultOk, consumer.unsubscribe()); diff --git a/tests/EncryptionTest.cc b/tests/EncryptionTest.cc index 8fd83e9b..97f60179 100644 --- a/tests/EncryptionTest.cc +++ b/tests/EncryptionTest.cc @@ -27,7 +27,6 @@ #include "lib/CompressionCodec.h" #include "lib/MessageCrypto.h" #include "lib/SharedBuffer.h" -#include "tests/PulsarFriend.h" static std::string lookupUrl = "pulsar://localhost:6650"; @@ -122,6 +121,11 @@ static void testDecryption(Client& client, const std::string& topic, bool decryp for (int i = 0; i < numMessageReceived; i++) { Message msg; ASSERT_EQ(ResultOk, consumer.receive(msg, 3000)); + if (i < numMessageReceived - 1) { + ASSERT_TRUE(msg.getEncryptionContext().has_value()); + } else { + ASSERT_FALSE(msg.getEncryptionContext().has_value()); + } for (auto&& value : decryptValue(static_cast(msg.getData()), msg.getLength(), msg.getEncryptionContext())) { values.emplace_back(value); diff --git a/tests/PulsarFriend.h b/tests/PulsarFriend.h index 780ec2a9..e7084050 100644 --- a/tests/PulsarFriend.h +++ b/tests/PulsarFriend.h @@ -217,8 +217,6 @@ class PulsarFriend { return waitUntil(std::chrono::seconds(3), [producerImpl] { return !producerImpl->getCnx().expired(); }); } - - static auto getMessageImplPtr(const Message& message) { return message.impl_; } }; } // namespace pulsar From 832feda240f1251fe4840188333b28598738cf6e Mon Sep 17 00:00:00 2001 From: Yunze Xu Date: Fri, 5 Dec 2025 16:28:29 +0800 Subject: [PATCH 07/12] Simplify header --- include/pulsar/EncryptionContext.h | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/include/pulsar/EncryptionContext.h b/include/pulsar/EncryptionContext.h index 7726dfa4..c97f0fe2 100644 --- a/include/pulsar/EncryptionContext.h +++ b/include/pulsar/EncryptionContext.h @@ -52,12 +52,6 @@ struct PULSAR_PUBLIC EncryptionKey { */ class PULSAR_PUBLIC EncryptionContext { public: - explicit EncryptionContext() = default; - EncryptionContext(const EncryptionContext&) = default; - EncryptionContext(EncryptionContext&&) noexcept = default; - EncryptionContext& operator=(const EncryptionContext&) = default; - EncryptionContext& operator=(EncryptionContext&&) noexcept = default; - using KeysType = std::vector; /** @@ -100,9 +94,11 @@ class PULSAR_PUBLIC EncryptionContext { */ bool isDecryptionFailed() const noexcept { return isDecryptionFailed_; } - // It should be used only internally but it's exposed so that `std::make_optional` can construct the - // object in place with this constructor. - EncryptionContext(const proto::MessageMetadata& metadata, bool isDecryptionFailed); + /** + * It should be used only internally but it's exposed so that `std::make_optional` can construct the + * object in place with this constructor. + */ + EncryptionContext(const proto::MessageMetadata&, bool); private: KeysType keys_; From ee909bedc8b202ec5997609779d799858af65677 Mon Sep 17 00:00:00 2001 From: Yunze Xu Date: Fri, 5 Dec 2025 16:36:16 +0800 Subject: [PATCH 08/12] Use MessageBatch to de-batch payload --- tests/EncryptionTest.cc | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/tests/EncryptionTest.cc b/tests/EncryptionTest.cc index 97f60179..84d5c109 100644 --- a/tests/EncryptionTest.cc +++ b/tests/EncryptionTest.cc @@ -19,11 +19,11 @@ #include #include #include +#include #include #include -#include "PulsarApi.pb.h" #include "lib/CompressionCodec.h" #include "lib/MessageCrypto.h" #include "lib/SharedBuffer.h" @@ -61,15 +61,9 @@ static std::vector decryptValue(const char* data, size_t length, std::vector values; if (auto batchSize = context.value()->batchSize(); batchSize > 0) { - for (decltype(batchSize) i = 0; i < batchSize; i++) { - auto singleMetaSize = uncompressedPayload.readUnsignedInt(); - proto::SingleMessageMetadata singleMeta; - singleMeta.ParseFromArray(uncompressedPayload.data(), singleMetaSize); - uncompressedPayload.consume(singleMetaSize); - - auto payload = uncompressedPayload.slice(0, singleMeta.payload_size()); - uncompressedPayload.consume(payload.readableBytes()); - values.emplace_back(payload.data(), payload.readableBytes()); + MessageBatch batch; + for (auto&& msg : batch.parseFrom(uncompressedPayload, batchSize).messages()) { + values.emplace_back(msg.getDataAsString()); } } else { // non-batched message From 99cf8822b14402e6d722b85816200487f74d0e1e Mon Sep 17 00:00:00 2001 From: Yunze Xu Date: Fri, 5 Dec 2025 17:43:38 +0800 Subject: [PATCH 09/12] address copilot comments --- include/pulsar/EncryptionContext.h | 11 +++++------ lib/ConsumerImpl.cc | 22 +++++++++++----------- lib/ConsumerImpl.h | 2 +- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/include/pulsar/EncryptionContext.h b/include/pulsar/EncryptionContext.h index c97f0fe2..89023cfa 100644 --- a/include/pulsar/EncryptionContext.h +++ b/include/pulsar/EncryptionContext.h @@ -32,15 +32,11 @@ namespace proto { class MessageMetadata; } -class Message; - struct PULSAR_PUBLIC EncryptionKey { std::string key; std::string value; std::unordered_map metadata; - explicit EncryptionKey() = default; - EncryptionKey(const std::string& key, const std::string& value, const decltype(EncryptionKey::metadata)& metadata) : key(key), value(value), metadata(metadata) {} @@ -95,8 +91,9 @@ class PULSAR_PUBLIC EncryptionContext { bool isDecryptionFailed() const noexcept { return isDecryptionFailed_; } /** - * It should be used only internally but it's exposed so that `std::make_optional` can construct the - * object in place with this constructor. + * This constructor is public to allow in-place construction via std::optional + * (e.g., `std::optional(std::in_place, metadata, false)`), + * but should not be used directly in application code. */ EncryptionContext(const proto::MessageMetadata&, bool); @@ -109,6 +106,8 @@ class PULSAR_PUBLIC EncryptionContext { int32_t batchSize_{-1}; bool isDecryptionFailed_{false}; + void setDecryptionFailed(bool failed) noexcept { isDecryptionFailed_ = failed; } + friend class ConsumerImpl; }; diff --git a/lib/ConsumerImpl.cc b/lib/ConsumerImpl.cc index 5b695597..430b8512 100644 --- a/lib/ConsumerImpl.cc +++ b/lib/ConsumerImpl.cc @@ -560,17 +560,17 @@ void ConsumerImpl::messageReceived(const ClientConnectionPtr& cnx, const proto:: ? optional(std::in_place, metadata, false) : std::nullopt; const auto decryptionResult = decryptMessageIfNeeded(cnx, msg, encryptionContext, payload); - if (decryptionResult == FAILED) { + if (decryptionResult == DecryptionResult::FAILED) { // Message was discarded or not consumed due to decryption failure return; - } else if (decryptionResult == CONSUME_ENCRYPTED && encryptionContext.has_value()) { + } else if (decryptionResult == DecryptionResult::CONSUME_ENCRYPTED && encryptionContext.has_value()) { // Message is encrypted, but we let the application consume it as-is - encryptionContext->isDecryptionFailed_ = true; + encryptionContext->setDecryptionFailed(true); } auto redeliveryCount = msg.redelivery_count(); const bool isChunkedMessage = metadata.num_chunks_from_msg() > 1; - if (decryptionResult == SUCCESS && !isChunkedMessage) { + if (decryptionResult == DecryptionResult::SUCCESS && !isChunkedMessage) { if (!uncompressMessageIfNeeded(cnx, msg.message_id(), metadata, payload, true)) { // Message was discarded on decompression error return; @@ -615,7 +615,7 @@ void ConsumerImpl::messageReceived(const ClientConnectionPtr& cnx, const proto:: return; } - if (metadata.has_num_messages_in_batch() && decryptionResult == SUCCESS) { + if (metadata.has_num_messages_in_batch() && decryptionResult == DecryptionResult::SUCCESS) { BitSet::Data words(msg.ack_set_size()); for (int i = 0; i < words.size(); i++) { words[i] = msg.ack_set(i); @@ -821,14 +821,14 @@ auto ConsumerImpl::decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const const optional& context, SharedBuffer& payload) -> DecryptionResult { if (!context.has_value()) { - return SUCCESS; + return DecryptionResult::SUCCESS; } // If KeyReader is not configured throw exception based on config param if (!config_.isEncryptionEnabled()) { if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::CONSUME) { LOG_WARN(getName() << "CryptoKeyReader is not implemented. Consuming encrypted message."); - return CONSUME_ENCRYPTED; + return DecryptionResult::CONSUME_ENCRYPTED; } else if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::DISCARD) { LOG_WARN(getName() << "Skipping decryption since CryptoKeyReader is not implemented and config " "is set to discard"); @@ -839,20 +839,20 @@ auto ConsumerImpl::decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const auto messageId = MessageIdBuilder::from(msg.message_id()).build(); unAckedMessageTrackerPtr_->add(messageId); } - return FAILED; + return DecryptionResult::FAILED; } SharedBuffer decryptedPayload; if (msgCrypto_->decrypt(*context, payload, config_.getCryptoKeyReader(), decryptedPayload)) { payload = decryptedPayload; - return SUCCESS; + return DecryptionResult::SUCCESS; } if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::CONSUME) { // Note, batch message will fail to consume even if config is set to consume LOG_WARN( getName() << "Decryption failed. Consuming encrypted message since config is set to consume."); - return CONSUME_ENCRYPTED; + return DecryptionResult::CONSUME_ENCRYPTED; } else if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::DISCARD) { LOG_WARN(getName() << "Discarding message since decryption failed and config is set to discard"); discardCorruptedMessage(cnx, msg.message_id(), CommandAck_ValidationError_DecryptionError); @@ -861,7 +861,7 @@ auto ConsumerImpl::decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const auto messageId = MessageIdBuilder::from(msg.message_id()).build(); unAckedMessageTrackerPtr_->add(messageId); } - return FAILED; + return DecryptionResult::FAILED; } bool ConsumerImpl::uncompressMessageIfNeeded(const ClientConnectionPtr& cnx, diff --git a/lib/ConsumerImpl.h b/lib/ConsumerImpl.h index 04e3a9c8..63eb51d6 100644 --- a/lib/ConsumerImpl.h +++ b/lib/ConsumerImpl.h @@ -195,7 +195,7 @@ class ConsumerImpl : public ConsumerImplBase { bool isPriorEntryIndex(int64_t idx); void brokerConsumerStatsListener(Result, BrokerConsumerStatsImpl, const BrokerConsumerStatsCallback&); - enum DecryptionResult + enum class DecryptionResult : uint8_t { SUCCESS, CONSUME_ENCRYPTED, From 9ea5a406ff65301f3465fa7457b8d8ebc390a814 Mon Sep 17 00:00:00 2001 From: Yunze Xu Date: Fri, 5 Dec 2025 17:45:39 +0800 Subject: [PATCH 10/12] improve naming --- tests/EncryptionTest.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/EncryptionTest.cc b/tests/EncryptionTest.cc index 84d5c109..ff5cb98e 100644 --- a/tests/EncryptionTest.cc +++ b/tests/EncryptionTest.cc @@ -72,7 +72,7 @@ static std::vector decryptValue(const char* data, size_t length, return values; } -static void testDecryption(Client& client, const std::string& topic, bool decryptionSucceed, +static void testDecryption(Client& client, const std::string& topic, bool withDecryption, int numMessageReceived) { ProducerConfiguration producerConf; producerConf.setCompressionType(CompressionLZ4); @@ -103,7 +103,7 @@ static void testDecryption(Client& client, const std::string& topic, bool decryp ConsumerConfiguration consumerConf; consumerConf.setSubscriptionInitialPosition(InitialPositionEarliest); - if (decryptionSucceed) { + if (withDecryption) { consumerConf.setCryptoKeyReader(getDefaultCryptoKeyReader()); } else { consumerConf.setCryptoFailureAction(ConsumerCryptoFailureAction::CONSUME); From 70ab53f3d46e680b3c62ad836992e74c3865d94d Mon Sep 17 00:00:00 2001 From: Yunze Xu Date: Mon, 8 Dec 2025 18:11:46 +0800 Subject: [PATCH 11/12] Update include/pulsar/EncryptionContext.h Co-authored-by: Zike Yang --- include/pulsar/EncryptionContext.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/include/pulsar/EncryptionContext.h b/include/pulsar/EncryptionContext.h index 89023cfa..bbed32e3 100644 --- a/include/pulsar/EncryptionContext.h +++ b/include/pulsar/EncryptionContext.h @@ -82,8 +82,7 @@ class PULSAR_PUBLIC EncryptionContext { /** * When the `ConsumerConfiguration#getCryptoFailureAction` is set to `CONSUME`, the message will still be - * returned even if the decryption failed, in this case, the message payload is still not decrypted but - * users have no way to know that. This method is provided to let users know whether the decryption + * returned even if the decryption failed. This method is provided to let users know whether the decryption * failed. * * @return whether the decryption failed From 2fca472b62b3ac31fdd3c4baa1e7055ca7959bb7 Mon Sep 17 00:00:00 2001 From: Yunze Xu Date: Mon, 8 Dec 2025 20:19:59 +0800 Subject: [PATCH 12/12] fix format --- include/pulsar/EncryptionContext.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/pulsar/EncryptionContext.h b/include/pulsar/EncryptionContext.h index bbed32e3..ac7ebf91 100644 --- a/include/pulsar/EncryptionContext.h +++ b/include/pulsar/EncryptionContext.h @@ -82,8 +82,8 @@ class PULSAR_PUBLIC EncryptionContext { /** * When the `ConsumerConfiguration#getCryptoFailureAction` is set to `CONSUME`, the message will still be - * returned even if the decryption failed. This method is provided to let users know whether the decryption - * failed. + * returned even if the decryption failed. This method is provided to let users know whether the + * decryption failed. * * @return whether the decryption failed */