Skip to content

Commit 0e46f5e

Browse files
authored
Merge pull request PaddlePaddle#4094 from lcy-seso/fix_cross_entropy_op_output_shape
fix shape of output tensor of cross_entropy_op.
2 parents 012db97 + e0ca4d7 commit 0e46f5e

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

paddle/operators/cross_entropy_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
2929
PADDLE_ENFORCE_EQ(X->dims().size(), 2, "X's dimension must be 2.");
3030
PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label's dimension must be 1.");
3131
PADDLE_ENFORCE_EQ(X->dims()[0], label->dims()[0]);
32-
ctx.Output<Tensor>("Y")->Resize({X->dims()[0]});
32+
ctx.Output<Tensor>("Y")->Resize({X->dims()[0], 1});
3333
}
3434
};
3535

python/paddle/v2/framework/tests/test_cross_entropy_op.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,22 @@ def setUp(self):
88
self.op_type = "onehot_cross_entropy"
99
batch_size = 30
1010
class_num = 10
11+
1112
X = numpy.random.uniform(0.1, 1.0,
1213
[batch_size, class_num]).astype("float32")
13-
label = (class_num / 2) * numpy.ones(batch_size).astype("int32")
14-
self.inputs = {'X': X, 'label': label}
15-
Y = []
16-
for i in range(0, batch_size):
17-
Y.append(-numpy.log(X[i][label[i]]))
18-
self.outputs = {'Y': numpy.array(Y).astype("float32")}
14+
labels = numpy.random.randint(0, class_num, batch_size, dtype="int32")
15+
16+
cross_entropy = numpy.asmatrix(
17+
[[-numpy.log(X[i][labels[i]])] for i in range(X.shape[0])],
18+
dtype="float32")
19+
self.inputs = {"X": X, "label": labels}
20+
self.outputs = {"Y": cross_entropy}
1921

2022
def test_check_output(self):
2123
self.check_output()
2224

2325
def test_check_grad(self):
24-
self.check_grad(['X'], 'Y')
26+
self.check_grad(["X"], "Y")
2527

2628

2729
if __name__ == "__main__":

0 commit comments

Comments (0)