Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 31 additions & 18 deletions Firestore/core/src/remote/datastore.cc
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ void Datastore::RunPipeline(
const StatusOr<AuthToken>& auth_token,
const std::string& app_check_token) mutable {
if (!auth_token.ok()) {
// result_callback(auth_token.status());
result_callback(auth_token.status());
return;
}
RunPipelineWithCredentials(auth_token.ValueOrDie(), app_check_token,
Expand All @@ -338,27 +338,40 @@ void Datastore::RunPipelineWithCredentials(
LOG_DEBUG("Run Pipeline: %s", request.ToString());

grpc::ByteBuffer message = MakeByteBuffer(request);
std::unique_ptr<GrpcUnaryCall> call_owning = grpc_connection_.CreateUnaryCall(
kRpcNameExecutePipeline, auth_token, app_check_token, std::move(message));
GrpcUnaryCall* call = call_owning.get();
std::unique_ptr<GrpcStreamingReader> call_owning =
grpc_connection_.CreateStreamingReader(kRpcNameExecutePipeline,
auth_token, app_check_token,
std::move(message));
GrpcStreamingReader* call = call_owning.get();
active_calls_.push_back(std::move(call_owning));

call->Start(
[this, db = pipeline.firestore(), call, callback = std::move(callback)](
const StatusOr<grpc::ByteBuffer>& result) {
LogGrpcCallFinished("ExecutePipeline", call, result.status());
HandleCallStatus(result.status());
auto responses_callback = [this, db = pipeline.firestore(), callback](
const std::vector<grpc::ByteBuffer>& result) {
if (result.empty()) {
callback(util::Status(Error::kErrorInternal,
"Received empty response for RunPipeline"));
return;
}

if (result.ok()) {
auto response = datastore_serializer_.DecodeExecutePipelineResponse(
result.ValueOrDie(), std::move(db));
callback(response);
} else {
callback(result.status());
}
auto response = datastore_serializer_.MergeExecutePipelineResponses(
result, std::move(db));
callback(response);
};

RemoveGrpcCall(call);
});
auto close_callback = [this, call, callback](const util::Status& status,
bool callback_fired) {
if (!callback_fired) {
callback(status);
}
if (!status.ok()) {
LogGrpcCallFinished("ExecutePipeline", call, status);
HandleCallStatus(status);
}
RemoveGrpcCall(call);
};

call->Start(util::Status(Error::kErrorUnknown, "Unknown response count"),
responses_callback, close_callback);
}

void Datastore::ResumeRpcWithCredentials(const OnCredentials& on_credentials) {
Expand Down
7 changes: 4 additions & 3 deletions Firestore/core/src/remote/grpc_streaming_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ GrpcStreamingReader::GrpcStreamingReader(
request_{request} {
}

void GrpcStreamingReader::Start(size_t expected_response_count,
void GrpcStreamingReader::Start(util::StatusOr<size_t> expected_response_count,
ResponsesCallback&& responses_callback,
CloseCallback&& close_callback) {
expected_response_count_ = expected_response_count;
expected_response_count_ = std::move(expected_response_count);
responses_callback_ = std::move(responses_callback);
close_callback_ = std::move(close_callback);
stream_->Start();
Expand All @@ -72,7 +72,8 @@ void GrpcStreamingReader::OnStreamRead(const grpc::ByteBuffer& message) {
// Accumulate responses, responses_callback_ will be fired if
// GrpcStreamingReader has received all the responses.
responses_.push_back(message);
if (responses_.size() == expected_response_count_) {
if (expected_response_count_.ok() &&
responses_.size() == expected_response_count_.ValueOrDie()) {
callback_fired_ = true;
responses_callback_(responses_);
}
Expand Down
5 changes: 3 additions & 2 deletions Firestore/core/src/remote/grpc_streaming_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "Firestore/core/src/remote/grpc_stream_observer.h"
#include "Firestore/core/src/util/status.h"
#include "Firestore/core/src/util/status_fwd.h"
#include "Firestore/core/src/util/statusor.h"
#include "Firestore/core/src/util/warnings.h"
#include "grpcpp/client_context.h"
#include "grpcpp/support/byte_buffer.h"
Expand Down Expand Up @@ -62,7 +63,7 @@ class GrpcStreamingReader : public GrpcCall, public GrpcStreamObserver {
* results of the call. If the call fails, the `callback` will be invoked with
* a non-ok status.
*/
void Start(size_t expected_response_count,
void Start(util::StatusOr<size_t> expected_response_count,
ResponsesCallback&& responses_callback,
CloseCallback&& close_callback);

Expand Down Expand Up @@ -103,7 +104,7 @@ class GrpcStreamingReader : public GrpcCall, public GrpcStreamObserver {
std::unique_ptr<GrpcStream> stream_;
grpc::ByteBuffer request_;

size_t expected_response_count_;
util::StatusOr<size_t> expected_response_count_;
bool callback_fired_ = false;
ResponsesCallback responses_callback_;
CloseCallback close_callback_;
Expand Down
44 changes: 44 additions & 0 deletions Firestore/core/src/remote/remote_objc_bridge.cc
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,50 @@ DatastoreSerializer::DecodeExecutePipelineResponse(
return snapshot;
}

util::StatusOr<api::PipelineSnapshot>
DatastoreSerializer::MergeExecutePipelineResponses(
const std::vector<grpc::ByteBuffer>& responses,
std::shared_ptr<api::Firestore> db) const {
std::vector<api::PipelineResult> all_results;
model::SnapshotVersion execution_time = model::SnapshotVersion::None();

for (const auto& response : responses) {
ByteBufferReader reader{response};
auto message =
Message<google_firestore_v1_ExecutePipelineResponse>::TryParse(&reader);
if (!reader.ok()) {
return reader.status();
}

// DecodePipelineResponse decodes the whole message into a Snapshot.
// We can reuse it to get the partial results and execution time.
auto partial_snapshot =
serializer_.DecodePipelineResponse(reader.context(), message);
if (!reader.ok()) {
return reader.status();
}

// Accumulate results
// PipelineSnapshot::results() returns a const ref. We need to copy.
// But PipelineResult should be copyable/movable.
for (const auto& result : partial_snapshot.results()) {
all_results.push_back(result);
}

// Update execution time if present.
// DecodePipelineResponse returns SnapshotVersion::None() if not present?
// Let's assume the last non-None execution time is the correct one, or just
// update it.
if (partial_snapshot.execution_time() != model::SnapshotVersion::None()) {
execution_time = partial_snapshot.execution_time();
}
}

api::PipelineSnapshot merged_snapshot{std::move(all_results), execution_time};
merged_snapshot.SetFirestore(std::move(db));
return merged_snapshot;
}

} // namespace remote
} // namespace firestore
} // namespace firebase
4 changes: 4 additions & 0 deletions Firestore/core/src/remote/remote_objc_bridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ class DatastoreSerializer {
const grpc::ByteBuffer& response,
std::shared_ptr<api::Firestore> db) const;

util::StatusOr<api::PipelineSnapshot> MergeExecutePipelineResponses(
const std::vector<grpc::ByteBuffer>& responses,
std::shared_ptr<api::Firestore> db) const;

private:
Serializer serializer_;
};
Expand Down
95 changes: 82 additions & 13 deletions Firestore/core/test/unit/remote/grpc_streaming_reader_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ class GrpcStreamingReaderTest : public testing::Test {
tester.KeepPollingGrpcQueue();
}

void StartReader(size_t expected_response_count) {
void StartReader(util::StatusOr<size_t> expected_response_count) {
worker_queue->EnqueueBlocking([&] {
reader->Start(
expected_response_count,
std::move(expected_response_count),
[&](std::vector<ResponsesT> result) {
responses = std::move(result);
},
Expand All @@ -101,7 +101,7 @@ TEST_F(GrpcStreamingReaderTest, FinishImmediatelyIsIdempotent) {
worker_queue->EnqueueBlocking(
[&] { EXPECT_NO_THROW(reader->FinishImmediately()); });

StartReader(0);
StartReader(util::StatusOr<size_t>(0));

KeepPollingGrpcQueue();
worker_queue->EnqueueBlocking([&] {
Expand All @@ -114,12 +114,12 @@ TEST_F(GrpcStreamingReaderTest, FinishImmediatelyIsIdempotent) {
// Method prerequisites -- correct usage of `GetResponseHeaders`

TEST_F(GrpcStreamingReaderTest, CanGetResponseHeadersAfterStarting) {
StartReader(0);
StartReader(util::StatusOr<size_t>(0));
EXPECT_NO_THROW(reader->GetResponseHeaders());
}

TEST_F(GrpcStreamingReaderTest, CanGetResponseHeadersAfterFinishing) {
StartReader(0);
StartReader(util::StatusOr<size_t>(0));

KeepPollingGrpcQueue();
worker_queue->EnqueueBlocking([&] {
Expand All @@ -139,7 +139,7 @@ TEST_F(GrpcStreamingReaderTest, CannotFinishAndNotifyBeforeStarting) {
// Normal operation

TEST_F(GrpcStreamingReaderTest, OneSuccessfulRead) {
StartReader(1);
StartReader(util::StatusOr<size_t>(1));

ForceFinishAnyTypeOrder({
{Type::Write, CompletionResult::Ok},
Expand All @@ -158,7 +158,7 @@ TEST_F(GrpcStreamingReaderTest, OneSuccessfulRead) {
}

TEST_F(GrpcStreamingReaderTest, TwoSuccessfulReads) {
StartReader(2);
StartReader(util::StatusOr<size_t>(2));

ForceFinishAnyTypeOrder({
{Type::Write, CompletionResult::Ok},
Expand All @@ -178,7 +178,7 @@ TEST_F(GrpcStreamingReaderTest, TwoSuccessfulReads) {
}

TEST_F(GrpcStreamingReaderTest, FinishWhileReading) {
StartReader(1);
StartReader(util::StatusOr<size_t>(1));

ForceFinishAnyTypeOrder({{Type::Write, CompletionResult::Ok},
{Type::Read, CompletionResult::Ok}});
Expand All @@ -194,7 +194,7 @@ TEST_F(GrpcStreamingReaderTest, FinishWhileReading) {
// Errors

TEST_F(GrpcStreamingReaderTest, ErrorOnWrite) {
StartReader(1);
StartReader(util::StatusOr<size_t>(1));

bool failed_write = false;
auto future = tester.ForceFinishAsync([&](GrpcCompletion* completion) {
Expand Down Expand Up @@ -230,7 +230,7 @@ TEST_F(GrpcStreamingReaderTest, ErrorOnWrite) {
}

TEST_F(GrpcStreamingReaderTest, ErrorOnFirstRead) {
StartReader(1);
StartReader(util::StatusOr<size_t>(1));

ForceFinishAnyTypeOrder({
{Type::Write, CompletionResult::Ok},
Expand All @@ -245,7 +245,7 @@ TEST_F(GrpcStreamingReaderTest, ErrorOnFirstRead) {
}

TEST_F(GrpcStreamingReaderTest, ErrorOnSecondRead) {
StartReader(2);
StartReader(util::StatusOr<size_t>(2));

ForceFinishAnyTypeOrder({
{Type::Write, CompletionResult::Ok},
Expand All @@ -259,12 +259,81 @@ TEST_F(GrpcStreamingReaderTest, ErrorOnSecondRead) {
EXPECT_TRUE(responses.empty());
}

TEST_F(GrpcStreamingReaderTest,
UnknownResponseCountReceivesAllMessagesOnFinish) {
// Use Status(Error::kErrorUnknown) to signify unknown response count
StartReader(util::Status(Error::kErrorUnknown, "Unknown response count"));

// Send some messages
ForceFinishAnyTypeOrder({
{Type::Write, CompletionResult::Ok},
{Type::Read, MakeByteBuffer("msg1")},
{Type::Read, MakeByteBuffer("msg2")},
/*Read after last*/ {Type::Read, CompletionResult::Error},
});

// At this point, responses_callback_ should NOT have been fired because
// expected_response_count_ is not 'ok'.
EXPECT_TRUE(responses.empty());
EXPECT_FALSE(status.has_value());

// Now, finish the stream successfully. This should trigger the
// responses_callback_ with all accumulated messages.
ForceFinish({{Type::Finish, grpc::Status::OK}});

ASSERT_TRUE(status.has_value());
EXPECT_EQ(status.value(), Status::OK());
ASSERT_EQ(responses.size(), 2);
EXPECT_EQ(ByteBufferToString(responses[0]), std::string{"msg1"});
EXPECT_EQ(ByteBufferToString(responses[1]), std::string{"msg2"});
}

TEST_F(GrpcStreamingReaderTest,
UnknownResponseCountReceivesEmptyOnFinishWithNoReads) {
StartReader(util::Status(Error::kErrorUnknown, "Unknown response count"));

ForceFinishAnyTypeOrder({
{Type::Write, CompletionResult::Ok},
/*Read after last*/ {Type::Read, CompletionResult::Error},
});

EXPECT_TRUE(responses.empty());
EXPECT_FALSE(status.has_value());

ForceFinish({{Type::Finish, grpc::Status::OK}});

ASSERT_TRUE(status.has_value());
EXPECT_EQ(status.value(), Status::OK());
ASSERT_TRUE(responses.empty()); // Should still be empty, but callback fired
}

TEST_F(GrpcStreamingReaderTest, UnknownResponseCountErrorOnFinish) {
StartReader(util::Status(Error::kErrorUnknown, "Unknown response count"));

ForceFinishAnyTypeOrder({
{Type::Write, CompletionResult::Ok},
{Type::Read, MakeByteBuffer("msg1")},
/*Read after last*/ {Type::Read, CompletionResult::Error},
});

EXPECT_TRUE(responses.empty());
EXPECT_FALSE(status.has_value());

grpc::Status error_status{grpc::StatusCode::DATA_LOSS, "Bad stream"};
ForceFinish({{Type::Finish, error_status}});

ASSERT_TRUE(status.has_value());
EXPECT_EQ(status.value().code(), Error::kErrorDataLoss);
EXPECT_TRUE(
responses.empty()); // responses_callback_ should not be fired on error
}

// Callback destroys reader

TEST_F(GrpcStreamingReaderTest, CallbackCanDestroyReaderOnSuccess) {
worker_queue->EnqueueBlocking([&] {
reader->Start(
1, [&](std::vector<ResponsesT>) {},
util::StatusOr<size_t>(1), [&](std::vector<ResponsesT>) {},
[&](const util::Status&, bool) { reader.reset(); });
});

Expand All @@ -282,7 +351,7 @@ TEST_F(GrpcStreamingReaderTest, CallbackCanDestroyReaderOnSuccess) {
TEST_F(GrpcStreamingReaderTest, CallbackCanDestroyReaderOnError) {
worker_queue->EnqueueBlocking([&] {
reader->Start(
1, [&](std::vector<ResponsesT>) {},
util::StatusOr<size_t>(1), [&](std::vector<ResponsesT>) {},
[&](const util::Status&, bool) { reader.reset(); });
});

Expand Down
Loading