Skip to content

Commit 7c74294

Browse files
authored
Mlu notifier reuse (#825)
1 parent 9ec7969 commit 7c74294

File tree

5 files changed

+438
-28
lines changed

5 files changed

+438
-28
lines changed

backends/mlu/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ message(STATUS "C compiler: ${CMAKE_C_COMPILER}, version: "
2525
message(STATUS "AR tools: ${CMAKE_AR}")
2626

2727
# custom runtime
28-
set(CUSTOM_MLU_SRCS runtime/runtime.cc)
28+
set(CUSTOM_MLU_SRCS runtime/runtime.cc runtime/CNRTEvent.h)
2929
add_definitions(-DPADDLE_WITH_CUSTOM_DEVICE)
3030
# TODO(qiil93): avoid compile error, to be removed
3131
add_definitions(-DPADDLE_WITH_CUSTOM_KERNEL)

backends/mlu/runtime/CNRTEvent.h

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
#pragma once
15+
16+
#include <cnrt.h>
17+
18+
#include "runtime/runtime.h"
19+
20+
/*
21+
* CNRTEvents are movable not copyable wrappers around CNRT's events.
22+
*/
23+
struct CNRTEvent {
24+
// Constructors
25+
// Default value for "flags" is specified below - it's
26+
// CNRT_NOTIFIER_DISABLE_TIMING_ALL
27+
CNRTEvent() noexcept = default;
28+
explicit CNRTEvent(cnrtNotifierFlags flag, int dev_idx) noexcept
29+
: flag_{flag}, dev_idx_(dev_idx) {
30+
createEvent();
31+
}
32+
33+
~CNRTEvent() {
34+
try {
35+
if (is_created_) {
36+
PADDLE_ENFORCE_MLU_SUCCESS(cnrtNotifierDestroy(event_));
37+
}
38+
} catch (...) { /* No throw */
39+
}
40+
}
41+
42+
CNRTEvent(const CNRTEvent&) = delete;
43+
CNRTEvent& operator=(const CNRTEvent&) = delete;
44+
45+
CNRTEvent(CNRTEvent&& other) noexcept { moveHelper(std::move(other)); }
46+
CNRTEvent& operator=(CNRTEvent&& other) noexcept {
47+
if (this != &other) {
48+
moveHelper(std::move(other));
49+
}
50+
return *this;
51+
}
52+
53+
operator cnrtNotifier_t() const { return event(); }
54+
55+
// Less than operator (to allow use in sets)
56+
friend bool operator<(const CNRTEvent& left, const CNRTEvent& right) {
57+
return left.event_ < right.event_;
58+
}
59+
60+
int device() const {
61+
if (is_created_) {
62+
return dev_idx_;
63+
} else {
64+
return -1;
65+
}
66+
}
67+
68+
bool isCreated() const { return is_created_; }
69+
bool wasRecorded() const { return was_recorded_; }
70+
bool isCompleted() const { return is_completed_; }
71+
int device_index() const { return dev_idx_; }
72+
cnrtNotifier_t event() const { return event_; }
73+
74+
bool query() {
75+
if (!is_created_) {
76+
return true;
77+
}
78+
79+
cnrtRet_t ret_err = cnrtQueryNotifier(event_);
80+
if (ret_err == cnrtSuccess) {
81+
is_completed_ = true;
82+
} else {
83+
(void)cnrtGetLastError();
84+
}
85+
86+
return is_completed_ ? true : false;
87+
}
88+
89+
void record(const int device, cnrtQueue_t stream) {
90+
PADDLE_ENFORCE_EQ(
91+
dev_idx_,
92+
device,
93+
phi::errors::InvalidArgument(
94+
"Event device %d does not match recording stream's device %d.",
95+
dev_idx_,
96+
device));
97+
PADDLE_ENFORCE_MLU_SUCCESS(
98+
cnrtPlaceNotifier(reinterpret_cast<cnrtNotifier_t>(event_), stream));
99+
was_recorded_ = true;
100+
}
101+
102+
void block(const cnrtQueue_t stream) {
103+
if (is_created_) {
104+
PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueWaitNotifier(
105+
reinterpret_cast<cnrtNotifier_t>(event_), stream, 0));
106+
}
107+
}
108+
109+
void synchronize() const {
110+
if (is_created_) {
111+
PADDLE_ENFORCE_MLU_SUCCESS(
112+
cnrtWaitNotifier(reinterpret_cast<cnrtNotifier_t>(event_)));
113+
}
114+
}
115+
116+
private:
117+
cnrtNotifierFlags flag_ =
118+
CNRT_NOTIFIER_DISABLE_TIMING_ALL; // no time elapsing
119+
bool is_created_ = false;
120+
bool was_recorded_ = false;
121+
bool is_completed_ = false;
122+
int dev_idx_ = -1;
123+
cnrtNotifier_t event_ = nullptr;
124+
125+
void createEvent() {
126+
// dev_idx_ = dev_idx;
127+
PADDLE_ENFORCE_MLU_SUCCESS(cnrtNotifierCreateWithFlags(&event_, flag_));
128+
is_created_ = true;
129+
}
130+
131+
void moveHelper(CNRTEvent&& other) {
132+
std::swap(flag_, other.flag_);
133+
std::swap(is_created_, other.is_created_);
134+
std::swap(was_recorded_, other.was_recorded_);
135+
std::swap(dev_idx_, other.dev_idx_);
136+
std::swap(event_, other.event_);
137+
}
138+
};

backends/mlu/runtime/flags.h

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
// Part of the following code in this file is from
16+
// https://github.com/google/glog/blob/master/src/base/commandlineflags.h
17+
// Git commit hash: 9f0b7d3bfe1542848f784e8d1c545b916cec6b3e
18+
// Retain the following license from the original files:
19+
20+
// Copyright (c) 2008, Google Inc.
21+
// All rights reserved.
22+
//
23+
// Redistribution and use in source and binary forms, with or without
24+
// modification, are permitted provided that the following conditions are
25+
// met:
26+
//
27+
// * Redistributions of source code must retain the above copyright
28+
// notice, this list of conditions and the following disclaimer.
29+
// * Redistributions in binary form must reproduce the above
30+
// copyright notice, this list of conditions and the following disclaimer
31+
// in the documentation and/or other materials provided with the
32+
// distribution.
33+
// * Neither the name of Google Inc. nor the names of its
34+
// contributors may be used to endorse or promote products derived from
35+
// this software without specific prior written permission.
36+
//
37+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
38+
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
39+
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
40+
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
41+
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
42+
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
43+
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
44+
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
45+
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
46+
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
47+
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
48+
49+
#ifndef BACKENDS_MLU_RUNTIME_FLAGS_H_
50+
#define BACKENDS_MLU_RUNTIME_FLAGS_H_
51+
52+
#include "gflags/gflags.h"
53+
54+
#define FLAGS_DEFINE_bool(name, value, meaning) \
55+
DEFINE_bool(name, EnvToBool("FLAGS_" #name, value), meaning)
56+
57+
#define FLAGS_DEFINE_int32(name, value, meaning) \
58+
DEFINE_int32(name, EnvToInt("FLAGS_" #name, value), meaning)
59+
60+
#define FLAGS_DEFINE_uint32(name, value, meaning) \
61+
DEFINE_uint32(name, EnvToUInt("FLAGS_" #name, value), meaning)
62+
63+
#define FLAGS_DEFINE_uint64(name, value, meaning) \
64+
DEFINE_uint64(name, EnvToUInt("FLAGS_" #name, value), meaning)
65+
66+
#define FLAGS_DEFINE_string(name, value, meaning) \
67+
DEFINE_string(name, EnvToString("FLAGS_" #name, value), meaning)
68+
69+
#define EnvToString(envname, dflt) (!getenv(envname) ? (dflt) : getenv(envname))
70+
71+
#define EnvToBool(envname, dflt) \
72+
(!getenv(envname) ? (dflt) : memchr("tTyY1\0", getenv(envname)[0], 6) != NULL)
73+
74+
#define EnvToInt(envname, dflt) \
75+
(!getenv(envname) ? (dflt) : strtol(getenv(envname), NULL, 10))
76+
77+
#define EnvToUInt(envname, dflt) \
78+
(!getenv(envname) ? (dflt) : strtoul(getenv(envname), NULL, 10))
79+
80+
#endif // BACKENDS_MLU_RUNTIME_FLAGS_H_

0 commit comments

Comments
 (0)