Skip to content
113 changes: 113 additions & 0 deletions include/pulsar/EncryptionContext.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#pragma once

#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

#include "CompressionType.h"
#include "defines.h"

namespace pulsar {

namespace proto {
class MessageMetadata;
}

struct PULSAR_PUBLIC EncryptionKey {
std::string key;
std::string value;
std::unordered_map<std::string, std::string> metadata;

EncryptionKey(const std::string& key, const std::string& value,
const decltype(EncryptionKey::metadata)& metadata)
: key(key), value(value), metadata(metadata) {}
};

/**
* It contains encryption and compression information in it using which application can decrypt consumed
* message with encrypted-payload.
*/
class PULSAR_PUBLIC EncryptionContext {
public:
using KeysType = std::vector<EncryptionKey>;

/**
* @return the map of encryption keys used for the message
*/
const KeysType& keys() const noexcept { return keys_; }

/**
* @return the encryption parameter used for the message
*/
const std::string& param() const noexcept { return param_; }

/**
* @return the encryption algorithm used for the message
*/
const std::string& algorithm() const noexcept { return algorithm_; }

/**
* @return the compression type used for the message
*/
CompressionType compressionType() const noexcept { return compressionType_; }

/**
* @return the uncompressed message size if the message is compressed, 0 otherwise
*/
uint32_t uncompressedMessageSize() const noexcept { return uncompressedMessageSize_; }

/**
* @return the batch size if the message is part of a batch, -1 otherwise
*/
int32_t batchSize() const noexcept { return batchSize_; }

/**
* When the `ConsumerConfiguration#getCryptoFailureAction` is set to `CONSUME`, the message will still be
* returned even if the decryption failed. This method is provided to let users know whether the
* decryption failed.
*
* @return whether the decryption failed
*/
bool isDecryptionFailed() const noexcept { return isDecryptionFailed_; }

/**
* This constructor is public to allow in-place construction via std::optional
* (e.g., `std::optional<EncryptionContext>(std::in_place, metadata, false)`),
* but should not be used directly in application code.
*/
EncryptionContext(const proto::MessageMetadata&, bool);

private:
KeysType keys_;
std::string param_;
std::string algorithm_;
CompressionType compressionType_{CompressionNone};
uint32_t uncompressedMessageSize_{0};
int32_t batchSize_{-1};
bool isDecryptionFailed_{false};

void setDecryptionFailed(bool failed) noexcept { isDecryptionFailed_ = failed; }

friend class ConsumerImpl;
};

} // namespace pulsar
8 changes: 8 additions & 0 deletions include/pulsar/Message.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
#ifndef MESSAGE_HPP_
#define MESSAGE_HPP_

#include <pulsar/EncryptionContext.h>
#include <pulsar/defines.h>

#include <map>
#include <memory>
#include <optional>
#include <string>

#include "KeyValue.h"
Expand Down Expand Up @@ -202,6 +204,12 @@ class PULSAR_PUBLIC Message {
*/
const std::string& getProducerName() const noexcept;

/**
* @return the optional encryption context that is present when the message is encrypted, the pointer is
* valid as the Message instance is alive
*/
std::optional<const EncryptionContext*> getEncryptionContext() const;

bool operator==(const Message& msg) const;

protected:
Expand Down
1 change: 1 addition & 0 deletions lib/Commands.cc
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,7 @@ Message Commands::deSerializeSingleMessageInBatch(Message& batchedMessage, int32
batchedMessage.impl_->metadata, payload, metadata,
batchedMessage.impl_->topicName_);
singleMessage.impl_->cnx_ = batchedMessage.impl_->cnx_;
singleMessage.impl_->encryptionContext_ = batchedMessage.impl_->encryptionContext_;

return singleMessage;
}
Expand Down
48 changes: 27 additions & 21 deletions lib/ConsumerImpl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "ConsumerImpl.h"

#include <pulsar/DeadLetterPolicyBuilder.h>
#include <pulsar/EncryptionContext.h>
#include <pulsar/MessageIdBuilder.h>

#include <algorithm>
Expand Down Expand Up @@ -549,24 +550,27 @@ void ConsumerImpl::messageReceived(const ClientConnectionPtr& cnx, const proto::
proto::MessageMetadata& metadata, SharedBuffer& payload) {
LOG_DEBUG(getName() << "Received Message -- Size: " << payload.readableBytes());

if (!decryptMessageIfNeeded(cnx, msg, metadata, payload)) {
// Message was discarded or not consumed due to decryption failure
return;
}

if (!isChecksumValid) {
// Message discarded for checksum error
discardCorruptedMessage(cnx, msg.message_id(), CommandAck_ValidationError_ChecksumMismatch);
return;
}

auto redeliveryCount = msg.redelivery_count();
const bool isMessageUndecryptable =
metadata.encryption_keys_size() > 0 && !config_.getCryptoKeyReader().get() &&
config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::CONSUME;
auto encryptionContext = metadata.encryption_keys_size() > 0
? optional<EncryptionContext>(std::in_place, metadata, false)
: std::nullopt;
const auto decryptionResult = decryptMessageIfNeeded(cnx, msg, encryptionContext, payload);
if (decryptionResult == DecryptionResult::FAILED) {
// Message was discarded or not consumed due to decryption failure
return;
} else if (decryptionResult == DecryptionResult::CONSUME_ENCRYPTED && encryptionContext.has_value()) {
// Message is encrypted, but we let the application consume it as-is
encryptionContext->setDecryptionFailed(true);
}

auto redeliveryCount = msg.redelivery_count();
const bool isChunkedMessage = metadata.num_chunks_from_msg() > 1;
if (!isMessageUndecryptable && !isChunkedMessage) {
if (decryptionResult == DecryptionResult::SUCCESS && !isChunkedMessage) {
if (!uncompressMessageIfNeeded(cnx, msg.message_id(), metadata, payload, true)) {
// Message was discarded on decompression error
return;
Expand All @@ -590,6 +594,7 @@ void ConsumerImpl::messageReceived(const ClientConnectionPtr& cnx, const proto::
m.impl_->cnx_ = cnx.get();
m.impl_->setTopicName(getTopicPtr());
m.impl_->setRedeliveryCount(msg.redelivery_count());
m.impl_->encryptionContext_ = std::move(encryptionContext);

if (metadata.has_schema_version()) {
m.impl_->setSchemaVersion(metadata.schema_version());
Expand All @@ -610,7 +615,7 @@ void ConsumerImpl::messageReceived(const ClientConnectionPtr& cnx, const proto::
return;
}

if (metadata.has_num_messages_in_batch()) {
if (metadata.has_num_messages_in_batch() && decryptionResult == DecryptionResult::SUCCESS) {
BitSet::Data words(msg.ack_set_size());
for (int i = 0; i < words.size(); i++) {
words[i] = msg.ack_set(i);
Expand Down Expand Up @@ -812,17 +817,18 @@ uint32_t ConsumerImpl::receiveIndividualMessagesFromBatch(const ClientConnection
return batchSize - skippedMessages;
}

bool ConsumerImpl::decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const proto::CommandMessage& msg,
const proto::MessageMetadata& metadata, SharedBuffer& payload) {
if (!metadata.encryption_keys_size()) {
return true;
auto ConsumerImpl::decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const proto::CommandMessage& msg,
const optional<EncryptionContext>& context, SharedBuffer& payload)
-> DecryptionResult {
if (!context.has_value()) {
return DecryptionResult::SUCCESS;
}

// If KeyReader is not configured throw exception based on config param
if (!config_.isEncryptionEnabled()) {
if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::CONSUME) {
LOG_WARN(getName() << "CryptoKeyReader is not implemented. Consuming encrypted message.");
return true;
return DecryptionResult::CONSUME_ENCRYPTED;
} else if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::DISCARD) {
LOG_WARN(getName() << "Skipping decryption since CryptoKeyReader is not implemented and config "
"is set to discard");
Expand All @@ -833,20 +839,20 @@ bool ConsumerImpl::decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const
auto messageId = MessageIdBuilder::from(msg.message_id()).build();
unAckedMessageTrackerPtr_->add(messageId);
}
return false;
return DecryptionResult::FAILED;
}

SharedBuffer decryptedPayload;
if (msgCrypto_->decrypt(metadata, payload, config_.getCryptoKeyReader(), decryptedPayload)) {
if (msgCrypto_->decrypt(*context, payload, config_.getCryptoKeyReader(), decryptedPayload)) {
payload = decryptedPayload;
return true;
return DecryptionResult::SUCCESS;
}

if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::CONSUME) {
// Note, batch message will fail to consume even if config is set to consume
LOG_WARN(
getName() << "Decryption failed. Consuming encrypted message since config is set to consume.");
return true;
return DecryptionResult::CONSUME_ENCRYPTED;
} else if (config_.getCryptoFailureAction() == ConsumerCryptoFailureAction::DISCARD) {
LOG_WARN(getName() << "Discarding message since decryption failed and config is set to discard");
discardCorruptedMessage(cnx, msg.message_id(), CommandAck_ValidationError_DecryptionError);
Expand All @@ -855,7 +861,7 @@ bool ConsumerImpl::decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const
auto messageId = MessageIdBuilder::from(msg.message_id()).build();
unAckedMessageTrackerPtr_->add(messageId);
}
return false;
return DecryptionResult::FAILED;
}

bool ConsumerImpl::uncompressMessageIfNeeded(const ClientConnectionPtr& cnx,
Expand Down
11 changes: 9 additions & 2 deletions lib/ConsumerImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,15 @@ class ConsumerImpl : public ConsumerImplBase {
bool isPriorEntryIndex(int64_t idx);
void brokerConsumerStatsListener(Result, BrokerConsumerStatsImpl, const BrokerConsumerStatsCallback&);

bool decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const proto::CommandMessage& msg,
const proto::MessageMetadata& metadata, SharedBuffer& payload);
enum class DecryptionResult : uint8_t
{
SUCCESS,
CONSUME_ENCRYPTED,
FAILED
};
DecryptionResult decryptMessageIfNeeded(const ClientConnectionPtr& cnx, const proto::CommandMessage& msg,
const optional<EncryptionContext>& context,
SharedBuffer& payload);

// TODO - Convert these functions to lambda when we move to C++11
Result receiveHelper(Message& msg);
Expand Down
48 changes: 48 additions & 0 deletions lib/EncryptionContext.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <pulsar/EncryptionContext.h>

#include "PulsarApi.pb.h"

namespace pulsar {

static EncryptionContext::KeysType encryptedKeysFromMetadata(const proto::MessageMetadata& msgMetadata) {
EncryptionContext::KeysType keys;
for (auto&& key : msgMetadata.encryption_keys()) {
decltype(EncryptionKey::metadata) metadata;
for (int i = 0; i < key.metadata_size(); i++) {
const auto& entry = key.metadata(i);
metadata[entry.key()] = entry.value();
}
keys.emplace_back(key.key(), key.value(), std::move(metadata));
}
return keys;
}

EncryptionContext::EncryptionContext(const proto::MessageMetadata& msgMetadata, bool isDecryptionFailed)

: keys_(encryptedKeysFromMetadata(msgMetadata)),
param_(msgMetadata.encryption_param()),
algorithm_(msgMetadata.encryption_algo()),
compressionType_(static_cast<CompressionType>(msgMetadata.compression())),
uncompressedMessageSize_(msgMetadata.uncompressed_size()),
batchSize_(msgMetadata.has_num_messages_in_batch() ? msgMetadata.num_messages_in_batch() : -1),
isDecryptionFailed_(isDecryptionFailed) {}

} // namespace pulsar
7 changes: 7 additions & 0 deletions lib/Message.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,13 @@ const std::string& Message::getProducerName() const noexcept {
return impl_->metadata.producer_name();
}

std::optional<const EncryptionContext*> Message::getEncryptionContext() const {
if (!impl_ || !impl_->encryptionContext_.has_value()) {
return std::nullopt;
}
return &impl_->encryptionContext_.value();
}

bool Message::operator==(const Message& msg) const { return getMessageId() == msg.getMessageId(); }

KeyValue Message::getKeyValueData() const { return KeyValue(impl_->keyValuePtr); }
Expand Down
Loading
Loading