Skip to content

Commit db82a0e

Browse files
authored
Merge pull request #980 from reyoung/feature/add_const_in_gradient_machine_eval
Add const to GradientMachine::eval
2 parents c1b294a + 4d5a0b0 commit db82a0e

File tree

9 files changed

+16
-16
lines changed

9 files changed

+16
-16
lines changed

paddle/gserver/gradientmachines/GradientMachine.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,12 +181,12 @@ class GradientMachine {
181181
/**
182182
* Create an evaluator which can be used for eval()
183183
*/
184-
virtual Evaluator* makeEvaluator() = 0;
184+
virtual Evaluator* makeEvaluator() const = 0;
185185

186186
/**
187187
* evaluate using the given evaluator
188188
*/
189-
virtual void eval(Evaluator* evaluator) = 0;
189+
virtual void eval(Evaluator* evaluator) const = 0;
190190

191191
std::vector<ParameterPtr>& getParameters() { return parameters_; }
192192

paddle/gserver/gradientmachines/MultiGradientMachine.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -327,11 +327,11 @@ void MultiGradientMachine::finish() {
327327
}
328328
}
329329

330-
Evaluator* MultiGradientMachine::makeEvaluator() {
330+
Evaluator* MultiGradientMachine::makeEvaluator() const {
331331
return threads_[0]->getGradientMachine()->makeEvaluator();
332332
}
333333

334-
void MultiGradientMachine::eval(Evaluator* evaluator) {
334+
void MultiGradientMachine::eval(Evaluator* evaluator) const {
335335
for (auto& thread : threads_) {
336336
SetDevice device(thread->getDeviceId());
337337
thread->getGradientMachine()->eval(evaluator);

paddle/gserver/gradientmachines/MultiGradientMachine.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,9 @@ class MultiGradientMachine : public GradientMachine {
193193

194194
virtual void finish();
195195

196-
virtual Evaluator* makeEvaluator();
196+
virtual Evaluator* makeEvaluator() const;
197197

198-
virtual void eval(Evaluator* evaluator);
198+
virtual void eval(Evaluator* evaluator) const;
199199

200200
bool useGpu() const { return useGpu_; }
201201

paddle/gserver/gradientmachines/MultiNetwork.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ class MultiCombinedEvaluator : public Evaluator {
171171
std::vector<std::unique_ptr<Evaluator>> evaluators_;
172172
};
173173

174-
Evaluator* MultiNetwork::makeEvaluator() {
174+
Evaluator* MultiNetwork::makeEvaluator() const {
175175
MultiCombinedEvaluator* multiCombinedEvaluator = new MultiCombinedEvaluator();
176176
for (size_t i = 0; i < subNetworks_.size(); i++) {
177177
std::unique_ptr<Evaluator> evaluator(subNetworks_[i]->makeEvaluator());
@@ -180,6 +180,6 @@ Evaluator* MultiNetwork::makeEvaluator() {
180180
return multiCombinedEvaluator;
181181
}
182182

183-
void MultiNetwork::eval(Evaluator* evaluator) { evaluator->eval(*this); }
183+
void MultiNetwork::eval(Evaluator* evaluator) const { evaluator->eval(*this); }
184184

185185
} // namespace paddle

paddle/gserver/gradientmachines/MultiNetwork.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ class MultiNetwork : public NeuralNetwork {
4646

4747
virtual void onPassEnd();
4848

49-
virtual Evaluator* makeEvaluator();
49+
virtual Evaluator* makeEvaluator() const;
5050

51-
virtual void eval(Evaluator* evaluator);
51+
virtual void eval(Evaluator* evaluator) const;
5252

5353
const std::vector<std::unique_ptr<NeuralNetwork>>& getSubNetworks() const {
5454
return subNetworks_;

paddle/gserver/gradientmachines/NeuralNetwork.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ class CombinedEvaluator : public Evaluator {
348348
std::vector<std::unique_ptr<Evaluator>> evaluators_;
349349
};
350350

351-
Evaluator* NeuralNetwork::makeEvaluator() {
351+
Evaluator* NeuralNetwork::makeEvaluator() const {
352352
CombinedEvaluator* combinedEvaluator = new CombinedEvaluator();
353353
auto subModelConfig = std::find_if(config_.sub_models().begin(),
354354
config_.sub_models().end(),
@@ -383,7 +383,7 @@ Evaluator* NeuralNetwork::makeEvaluator() {
383383
return combinedEvaluator;
384384
}
385385

386-
void NeuralNetwork::eval(Evaluator* evaluator) { evaluator->eval(*this); }
386+
void NeuralNetwork::eval(Evaluator* evaluator) const { evaluator->eval(*this); }
387387

388388
void NeuralNetwork::setOutputGrad(const std::vector<Argument>& args) {
389389
CHECK_GE(outputLayers_.size(), args.size());

paddle/gserver/gradientmachines/NeuralNetwork.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,9 @@ class NeuralNetwork : public GradientMachine {
9696

9797
virtual void onPassEnd();
9898

99-
virtual Evaluator* makeEvaluator();
99+
virtual Evaluator* makeEvaluator() const;
100100

101-
virtual void eval(Evaluator* evaluator);
101+
virtual void eval(Evaluator* evaluator) const;
102102
virtual void resetState();
103103
virtual void setOutputGrad(const std::vector<Argument>& args);
104104

paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,7 @@ void RecurrentGradientMachine::forwardBackward(
593593
LOG(FATAL) << "should not use this function";
594594
}
595595

596-
void RecurrentGradientMachine::eval(Evaluator* evaluator) {
596+
void RecurrentGradientMachine::eval(Evaluator* evaluator) const {
597597
// call printers frame by frame
598598
for (int i = 0; i < maxSequenceLength_; ++i) {
599599
LOG(INFO) << "Recurrent Layer Group eval frame " << i << " begin";

paddle/gserver/gradientmachines/RecurrentGradientMachine.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class RecurrentGradientMachine : public NeuralNetwork {
6363
const UpdateCallback& callback);
6464

6565
virtual void resetState() {}
66-
virtual void eval(Evaluator* evaluator);
66+
virtual void eval(Evaluator* evaluator) const;
6767

6868
const std::vector<int>& getParameterIds() { return parameterIds_; }
6969

0 commit comments

Comments
 (0)