Skip to content

Commit f565723

Browse files
improve inplace performance
Summary: !ci_branch_mk2 When inplace open, poprithm memory graph will repeatedly constructed to query alias (a view) of a tensor. This is time consuming. The naive idea is to reuse the the alias model and poprithm graph grower to save compilation time. Reviewers: #popart, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, jamesn, shirazb Reviewed By: #popart, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, jamesn, shirazb Subscribers: markk, shirazb Maniphest Tasks: T74421 Differential Revision: https://phabricator.sourcevertex.net/D85549
1 parent 4ccb2f2 commit f565723

File tree

6 files changed

+190
-22
lines changed

6 files changed

+190
-22
lines changed

willow/include/popart/ir.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#include <string>
1313
#include <typeindex>
1414
#include <vector>
15+
#include <popart/alias/aliasmodel.hpp>
16+
#include <popart/alias/aliasmodelgrower.hpp>
1517
#include <popart/bimap.hpp>
1618
#include <popart/dataflow.hpp>
1719
#include <popart/inputshapeinfo.hpp>

willow/include/popart/op.hpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <utility>
1717
#include <vector>
1818
#include <poprithms/memory/inplace/proposal.hpp>
19+
#include <popart/alias/aliasmodelgrower.hpp>
1920
#include <popart/attributes.hpp>
2021
#include <popart/basicoptionals.hpp>
2122
#include <popart/bwdgraphinfo.hpp>
@@ -1678,6 +1679,16 @@ class Op : public Vertex {
16781679
*/
16791680
bool inputUnmodifiable(InIndex in) const;
16801681

1682+
/**
1683+
* Check if the input index is unmodifiable or aliases an unmodifiable tensor
1684+
* with given poprithm graph.
1685+
*
1686+
* \param in The input index to check.
1687+
* \returns `true` if any connected variable tensor has a non-empty alias
1688+
* chain and is unmodifiable, `false` otherwise.
1689+
*/
1690+
bool inputUnmodifiableFor(InIndex in, const AliasModel *popMem) const;
1691+
16811692
/**
16821693
* Check if output is modified by any consumer.
16831694
*
@@ -1687,6 +1698,15 @@ class Op : public Vertex {
16871698
*/
16881699
bool hasAliasedModifiers(OutIndex out) const;
16891700

1701+
/**
1702+
* Check if output is modified by any consumer with the given poprithm graph.
1703+
*
1704+
* \param out The output index to check.
1705+
* \returns `true` if any consumer of any aliased tensor downstream modifies
1706+
* a non-empty region, `false` otherwise.
1707+
*/
1708+
bool hasAliasedModifiersFor(OutIndex out, const AliasModel *popMem) const;
1709+
16901710
// Helper functions for probing graph structure.
16911711
/**
16921712
* Check if the graph is a parent of the op.

willow/include/popart/tensor.hpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <string>
1212
#include <utility>
1313
#include <vector>
14+
#include <popart/alias/aliasmodelgrower.hpp>
1415
#include <popart/dataflow.hpp>
1516
#include <popart/error.hpp>
1617
#include <popart/names.hpp>
@@ -221,6 +222,11 @@ class Tensor : public Vertex {
221222
// Returns true if the tensor or any of it's aliases fulfill the predicate
222223
bool anyAlias(std::function<bool(Tensor *)> predicate) const;
223224

225+
// Returns true if the tensor or any of it's aliases fulfill the predicate in
226+
// the given poprithm memory graph
227+
bool anyAliasFor(std::function<bool(Tensor *)> predicate,
228+
const AliasModel &popMem) const;
229+
224230
void setTensorDataFromCopyOf(const void *src, std::size_t size);
225231
void setTensorDataFromViewOf(void *src, std::size_t size);
226232
void setTensorDataByEmplaceOf(std::vector<char> &&data);
@@ -291,6 +297,12 @@ class Tensor : public Vertex {
291297
*/
292298
std::set<Op *, POpCmp> getInplaceModifiers() const;
293299

300+
/**
301+
* Find operations that modify a tensor with the given poprithm graph
302+
* \return All operations that (direct and indirectly) modify this tensor
303+
*/
304+
std::set<Op *, POpCmp> getInplaceModifiersFor(const AliasModel *popMem) const;
305+
294306
// Backtrack through input and parent graph tensors in order to get data from
295307
// initializer tensors (if they exist).
296308
// When ops are performed on initializers (e.g. slice), the
@@ -311,6 +323,10 @@ class Tensor : public Vertex {
311323

312324
int getBatchAxisFromOp(Op *, bool, int) const;
313325

326+
bool anyAliasImpl(std::function<bool(Tensor *)> predicate,
327+
const AliasModel &popMem,
328+
const char *scopeDesc) const;
329+
314330
const TensorDebugInfo di;
315331

316332
/**

willow/src/ir.cpp

Lines changed: 90 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
#include <graphfromlosstolossupdater.hpp>
99
#include <onnxutil.hpp>
1010
#include <poprithms/logging/timepartitionlogger.hpp>
11-
#include <popart/alias/aliasmodelgrower.hpp>
1211
#include <popart/ces/constexpr.hpp>
1312
#include <popart/devicemanager.hpp>
1413
#include <popart/error.hpp>
@@ -48,7 +47,6 @@
4847
#include <poprithms/memory/inplace/proposal.hpp>
4948
#include <poprithms/memory/inplace/result.hpp>
5049
#include <poprithms/util/typedinteger.hpp>
51-
#include <popart/alias/aliasmodel.hpp>
5250
#include <popart/dotvisualizer.hpp>
5351
#include <popart/op/copyvarupdate.hpp>
5452
#include <popart/op/ipucopy.hpp>
@@ -157,6 +155,55 @@ std::ostream &operator<<(std::ostream &ost, const OpsBeforeKey &o) {
157155
return ost;
158156
}
159157

158+
struct PopMemGrower {
159+
std::unique_ptr<AliasModel> popMem_;
160+
std::unique_ptr<AliasModelGrower> aliasModelGrower_;
161+
162+
using Proposal = poprithms::memory::inplace::Proposal;
163+
using OpeningStatus = poprithms::memory::inplace::OpeningStatus;
164+
165+
PopMemGrower();
166+
167+
void init();
168+
169+
void growFullGraph(Graph &graph,
170+
DataDependenciesOnly dep = DataDependenciesOnly::No) {
171+
aliasModelGrower_->growFullGraph(graph, dep);
172+
}
173+
174+
AliasModel &getAliasModelRef() { return *(popMem_.get()); }
175+
176+
Proposal mapInplaceProposal(Op *op, OperatorIdentifier identifier) {
177+
return op->mapInplaceProposal(*popMem_, identifier);
178+
}
179+
180+
OpeningStatus tryOpening(Proposal &proposal) {
181+
return popMem_->g.tryOpening(
182+
proposal,
183+
poprithms::memory::inplace::CheckParallelWriteable::No,
184+
poprithms::memory::inplace::AllowMultiGateAlias::No);
185+
}
186+
187+
void reset(Graph &graph) {
188+
init();
189+
growFullGraph(graph, DataDependenciesOnly::Yes);
190+
}
191+
};
192+
193+
PopMemGrower::PopMemGrower()
194+
: popMem_(std::make_unique<AliasModel>()),
195+
aliasModelGrower_(std::make_unique<AliasModelGrower>(*popMem_)) {}
196+
197+
void PopMemGrower::init() {
198+
if (popMem_ == nullptr) {
199+
popMem_ = std::make_unique<AliasModel>();
200+
aliasModelGrower_ = std::make_unique<AliasModelGrower>(*popMem_);
201+
} else {
202+
popMem_.reset(new AliasModel());
203+
aliasModelGrower_.reset(new AliasModelGrower(*popMem_));
204+
}
205+
}
206+
160207
poprithms::logging::TimePartitionLogger &Ir::timePartitionLogger() const {
161208
return *timePartitionLogger_;
162209
}
@@ -3466,12 +3513,16 @@ void Ir::applyUpdateInplacePrioritiesForIpu() {
34663513
}
34673514

34683515
void Ir::applyInplacePattern(Graph &graph) {
3469-
34703516
// The decision of where topological constraints need to be inserted is made
34713517
// by a poprithms Graph whose Ops mirror those in \a graph.
3472-
AliasModel popMem;
3473-
AliasModelGrower aliasModelGrower{popMem};
3474-
aliasModelGrower.growFullGraph(graph, DataDependenciesOnly::No);
3518+
// Create poprithm memroy graph growers for this graph
3519+
PopMemGrower popMemGrowerOfSubgraph;
3520+
PopMemGrower popMemGrowerOfTensor;
3521+
3522+
popMemGrowerOfSubgraph.growFullGraph(graph, DataDependenciesOnly::No);
3523+
popMemGrowerOfTensor.growFullGraph(graph, DataDependenciesOnly::Yes);
3524+
3525+
AliasModel &popMem = *(popMemGrowerOfSubgraph.popMem_.get());
34753526

34763527
Inplace inplace;
34773528

@@ -3591,6 +3642,8 @@ void Ir::applyInplacePattern(Graph &graph) {
35913642
continue;
35923643
}
35933644

3645+
auto tProposal = popMemGrowerOfTensor.mapInplaceProposal(op, identifier);
3646+
35943647
// Convert poprithms topological constraints into popart constraints
35953648
OpsBeforeKey newTopoCons;
35963649
for (auto from_to : result.constraints()) {
@@ -3695,14 +3748,22 @@ void Ir::applyInplacePattern(Graph &graph) {
36953748
};
36963749

36973750
bool restoreInplaceIn =
3698-
op->input->tensor(in_index.first)->anyAlias(restoreInplaceTensor);
3699-
bool restoreInplaceOut = op->output->tensor(out_index.first)
3700-
->anyAlias(restoreInplaceTensor);
3751+
op->input->tensor(in_index.first)
3752+
->anyAliasFor(restoreInplaceTensor,
3753+
popMemGrowerOfTensor.getAliasModelRef());
3754+
bool restoreInplaceOut =
3755+
op->output->tensor(out_index.first)
3756+
->anyAliasFor(restoreInplaceTensor,
3757+
popMemGrowerOfTensor.getAliasModelRef());
37013758

37023759
bool conflictIn =
3703-
op->input->tensor(in_index.first)->anyAlias(isConflictTensor);
3760+
op->input->tensor(in_index.first)
3761+
->anyAliasFor(isConflictTensor,
3762+
popMemGrowerOfTensor.getAliasModelRef());
37043763
bool conflictOut =
3705-
op->output->tensor(out_index.first)->anyAlias(isConflictTensor);
3764+
op->output->tensor(out_index.first)
3765+
->anyAliasFor(isConflictTensor,
3766+
popMemGrowerOfTensor.getAliasModelRef());
37063767

37073768
// Check that no conflict tensors, through aliasing, can be consumed
37083769
// by a RestoreInplaceOp
@@ -3726,10 +3787,13 @@ void Ir::applyInplacePattern(Graph &graph) {
37263787

37273788
// Unmodifiable
37283789
// 1. Is the input unmodifiable?
3729-
bool unmodifiable = op->inputUnmodifiable(in_index.first);
3790+
bool unmodifiable = op->inputUnmodifiableFor(
3791+
in_index.first, popMemGrowerOfTensor.popMem_.get());
37303792
// 2. Does it indirectly modify this tensor and alias it?
37313793
bool indirectModify =
3732-
(op->hasAliasedModifiers(out_index.first) && opAliases);
3794+
(opAliases &&
3795+
op->hasAliasedModifiersFor(out_index.first,
3796+
popMemGrowerOfTensor.popMem_.get()));
37333797
// 3. Does it directly modify a weight?
37343798
bool directModify = inplaceOp->modifiesIndex(in_index.first);
37353799
// If ((1 and 2) or 3) : do not inplace.
@@ -3748,7 +3812,8 @@ void Ir::applyInplacePattern(Graph &graph) {
37483812

37493813
if ((indirectModify || directModify) &&
37503814
op->input->tensor(in_index.first)
3751-
->anyAlias(isImplicitRecomputeTensor)) {
3815+
->anyAliasFor(isImplicitRecomputeTensor,
3816+
popMemGrowerOfTensor.getAliasModelRef())) {
37523817
logging::pattern::trace("[Inplacing] Not inplacing {} with {} as "
37533818
"it would be modified by a recomputation "
37543819
"{} -> {} ",
@@ -3881,7 +3946,8 @@ void Ir::applyInplacePattern(Graph &graph) {
38813946
};
38823947

38833948
for (auto tensor : currentInsOuts) {
3884-
tensor->anyAlias(populateConsumersInIndices);
3949+
tensor->anyAliasFor(populateConsumersInIndices,
3950+
popMemGrowerOfTensor.getAliasModelRef());
38853951
}
38863952

38873953
for (const auto &consumerInIndices : consumersInIndices) {
@@ -3944,6 +4010,15 @@ void Ir::applyInplacePattern(Graph &graph) {
39444010
// The Op in graph has changed, mirror the change in the poprithms
39454011
// Graph
39464012
popMem.update(id, opOutput->getProducer()->id);
4013+
4014+
const auto status = popMemGrowerOfTensor.tryOpening(tProposal);
4015+
if (status != PopMemGrower::OpeningStatus::Valid) {
4016+
popMemGrowerOfTensor.reset(graph);
4017+
} else {
4018+
// The Op in graph has changed, mirror the change in the poprithms
4019+
// Graph
4020+
popMemGrowerOfTensor.popMem_->update(id, opOutput->getProducer()->id);
4021+
}
39474022
}
39484023
}
39494024
}

willow/src/op.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1352,12 +1352,26 @@ bool Op::producesGraphOutput() const {
13521352
}
13531353

13541354
bool Op::inputUnmodifiable(InIndex in) const {
1355+
return inputUnmodifiableFor(in, nullptr);
1356+
}
1357+
1358+
bool Op::inputUnmodifiableFor(InIndex in, const AliasModel *popMemPtr) const {
13551359
auto t = input->tensor(in);
13561360
// If the input tensor itself, or any of it's aliases, are unmodifiable
1357-
return t->anyAlias([](Tensor *tensor) { return tensor->isUnmodifiable(); });
1361+
if (popMemPtr == nullptr) {
1362+
return t->anyAlias([](Tensor *tensor) { return tensor->isUnmodifiable(); });
1363+
} else {
1364+
return t->anyAliasFor(
1365+
[](Tensor *tensor) { return tensor->isUnmodifiable(); }, *popMemPtr);
1366+
}
13581367
}
13591368

13601369
bool Op::hasAliasedModifiers(OutIndex out) const {
1370+
return hasAliasedModifiersFor(out, nullptr);
1371+
}
1372+
1373+
bool Op::hasAliasedModifiersFor(OutIndex out,
1374+
const AliasModel *popMemPtr) const {
13611375
auto t = output->tensor(out);
13621376

13631377
auto checkConsumers = [](Tensor *t_in) {
@@ -1372,7 +1386,11 @@ bool Op::hasAliasedModifiers(OutIndex out) const {
13721386
return false;
13731387
};
13741388

1375-
return t->anyAlias(checkConsumers);
1389+
if (popMemPtr == nullptr) {
1390+
return t->anyAlias(checkConsumers);
1391+
} else {
1392+
return t->anyAliasFor(checkConsumers, *popMemPtr);
1393+
}
13761394
}
13771395

13781396
bool Op::isParentOf(const Op *op) const {

willow/src/tensor.cpp

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -307,8 +307,14 @@ view::Regions Tensor::modifiedRegionsByOps(std::vector<Op *> ops,
307307
}
308308

309309
std::set<Op *, POpCmp> Tensor::getInplaceModifiers() const {
310+
return getInplaceModifiersFor(nullptr);
311+
}
312+
313+
std::set<Op *, POpCmp>
314+
Tensor::getInplaceModifiersFor(const AliasModel *popMem) const {
310315
std::set<Op *, POpCmp> ops;
311-
anyAlias([&ops](Tensor *t) {
316+
317+
auto cb = [&ops](Tensor *t) {
312318
auto consumers = t->consumers.getOps();
313319
for (auto c : consumers) {
314320
if (c->modifies()) {
@@ -317,7 +323,13 @@ std::set<Op *, POpCmp> Tensor::getInplaceModifiers() const {
317323
}
318324
// Continue until all aliases have been visited
319325
return false;
320-
});
326+
};
327+
328+
if (popMem == nullptr) {
329+
anyAlias(cb);
330+
} else {
331+
anyAliasFor(cb, *popMem);
332+
}
321333
return ops;
322334
}
323335

@@ -1122,9 +1134,6 @@ bool Tensor::isRootAnchor() const { return graph.getIr().isRootAnchor(id); }
11221134
bool Tensor::anyAlias(std::function<bool(Tensor *)> predicate) const {
11231135

11241136
constexpr const char *const ctxt{"Tensor::anyAlias"};
1125-
logging::ir::trace("{} for Tensor {},", ctxt, str());
1126-
1127-
auto scopedStopwatch = getIr().timePartitionLogger().scopedStopwatch(ctxt);
11281137

11291138
// First check if this tensor itself satisfies the predicate. If so, we need
11301139
// not bother constructing a poprithms graph to check for alias tensors.
@@ -1138,6 +1147,34 @@ bool Tensor::anyAlias(std::function<bool(Tensor *)> predicate) const {
11381147
AliasModelGrower aliasModelGrower{popMem};
11391148
aliasModelGrower.growPartialGraph(graph, id, DataDependenciesOnly::Yes);
11401149

1150+
return anyAliasImpl(predicate, popMem, ctxt);
1151+
}
1152+
1153+
bool Tensor::anyAliasFor(std::function<bool(Tensor *)> predicate,
1154+
const AliasModel &popMem) const {
1155+
1156+
constexpr const char *const ctxt{
1157+
"Tensor::anyAlias with prebuild AliasModel and AliasModelGrower."};
1158+
1159+
return anyAliasImpl(predicate, popMem, ctxt);
1160+
}
1161+
1162+
bool Tensor::anyAliasImpl(std::function<bool(Tensor *)> predicate,
1163+
const AliasModel &popMem,
1164+
const char *scopeDesc) const {
1165+
1166+
logging::ir::trace("{} for Tensor {},", scopeDesc, str());
1167+
1168+
auto scopedStopwatch =
1169+
getIr().timePartitionLogger().scopedStopwatch(scopeDesc);
1170+
1171+
// First check if this tensor itself satisfies the predicate. If so, we need
1172+
// not bother constructing a poprithms graph to check for alias tensors.
1173+
Tensor *t = graph.getTensors().get(id);
1174+
if (predicate(t)) {
1175+
return true;
1176+
}
1177+
11411178
// Get the identifier used to represent this tensor in poprithms.
11421179
auto poprithmsTensorId = popMem.getPoprithmsTensorId(id);
11431180

0 commit comments

Comments
 (0)