Commit 2e1017f

Support detached mean and var input tensors in batchnorm
Summary:
In PopTorch batchnorm, both 'running_mean' and 'running_var' are model
parameters with the 'requires_grad' flag set to False. As such, both are
passed in as detached tensors to prevent gradient updates. PopART must
handle that scenario and still be able to update the input tensors via
Copy. !ci_branch_mk2

Test Plan: CI

Reviewers: #popart, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, matthewha

Reviewed By: #popart, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, matthewha

Subscribers: matthewha

JIRA Issues: AFS-358

Differential Revision: https://phabricator.sourcevertex.net/D84992
1 parent: 4bfb66a
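For context, here is a minimal NumPy sketch (not part of the commit) of the running-statistics update that the Copy write-back performs during training. It assumes the common exponential-moving-average convention in which momentum weights the new batch statistic (torch.nn.BatchNorm2d weights it the same way, using an unbiased batch variance); the function name is illustrative:

    import numpy as np

    def update_running_stats(x, running_mean, running_var, momentum=0.1):
        # Per-channel statistics over the batch and spatial axes of an NCHW tensor.
        batch_mean = x.mean(axis=(0, 2, 3))
        batch_var = x.var(axis=(0, 2, 3))
        # Exponential moving average: momentum weights the new batch statistic.
        new_mean = (1 - momentum) * running_mean + momentum * batch_mean
        new_var = (1 - momentum) * running_var + momentum * batch_var
        return new_mean, new_var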

2 files changed: +69 −0 lines changed

tests/integration/operators_test/batchnorm_test.py

Lines changed: 58 additions & 0 deletions
@@ -1324,3 +1324,61 @@ def reference(_):

    op_tester.atol = 1e-05
    op_tester.run(init_builder, reference, "infer")


# PopTorch detaches both the mean and var input tensors to prevent gradient
# updates, but batchnorm should still be able to update them when
# track_running_stats is True.
def test_batchnorm_test_detached_mean_var(op_tester):
    # create test data
    d1 = np.array(
        [[[[1, 1], [1, 1]], [[1, 1], [1, 1]]], [[[1, 0], [0, 1]], [[1, 0], [0, 1]]]],
        dtype=np.float32,
    )

    scale = np.ones(2).astype(np.float32)
    b = np.zeros(2).astype(np.float32)
    mean = np.zeros(2).astype(np.float32)
    var = np.ones(2).astype(np.float32)
    epsilon = 1e-05
    momentum = 0.1

    def init_builder(builder):
        i1 = builder.addInputTensor(d1)
        iScale = builder.addInputTensor(scale)
        iB = builder.addInputTensor(b)
        iMean = builder.addInputTensor(mean)
        dMean = builder.aiGraphcore.detach([iMean])
        iVar = builder.addInputTensor(var)
        dVar = builder.aiGraphcore.detach([iVar])
        (o_y,) = builder.aiOnnx.batchnormalization(
            [i1, iScale, iB, dMean, dVar], 1, epsilon, momentum
        )
        builder.addOutputTensor(o_y)
        return [o_y]

    def reference(_):  # ref_data is an unused argument
        _input = torch.tensor(d1, requires_grad=False)
        _weight = torch.tensor(scale, requires_grad=False)
        _bias = torch.tensor(b, requires_grad=False)
        _mean = torch.tensor(mean, requires_grad=False)
        _var = torch.tensor(var, requires_grad=False)

        m = torch.nn.BatchNorm2d(
            2, eps=epsilon, momentum=momentum, track_running_stats=True
        )
        m.state_dict()["weight"].copy_(_weight)
        m.state_dict()["bias"].copy_(_bias)
        m.state_dict()["running_mean"].copy_(_mean)
        m.state_dict()["running_var"].copy_(_var)

        m.eval()

        _y = m(_input)

        return [_y]

    op_tester.setPatterns(
        ["PreUniRepl", "ReciprocalGradOp"], enableRuntimeAsserts=False
    )
    op_tester.run(init_builder, reference, "infer")
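As a quick sanity check on the reference above, a standalone NumPy evaluation (not part of the test) of the inference-mode formula: with zero mean, unit variance, unit scale, and zero bias, the output is approximately the input.

    import numpy as np

    def batchnorm_infer(x, scale, bias, mean, var, eps=1e-05):
        # Inference-mode batchnorm on an NCHW tensor, broadcasting the
        # per-channel parameters over axis 1.
        shape = (1, -1, 1, 1)
        x_hat = (x - mean.reshape(shape)) / np.sqrt(var.reshape(shape) + eps)
        return scale.reshape(shape) * x_hat + bias.reshape(shape)

    x = np.ones((2, 2, 2, 2), dtype=np.float32)
    y = batchnorm_infer(x, np.ones(2, np.float32), np.zeros(2, np.float32),
                        np.zeros(2, np.float32), np.ones(2, np.float32))
    np.testing.assert_allclose(y, x, atol=1e-05)  # identity stats: y ≈ x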

willow/src/op/batchnorm.cpp

Lines changed: 11 additions & 0 deletions
@@ -74,6 +74,17 @@ void BatchNormOp::setup() {
  auto meanTensor = inTensor(getMeanInIndex());
  auto varTensor = inTensor(getVarInIndex());

  // The input might have been detached by PopTorch if its requires_grad
  // flag was False.
  if (meanTensor->hasProducer() &&
      meanTensor->getProducer()->opid.type == "Detach") {
    meanTensor = meanTensor->getProducer()->inTensor(0);
  }

  if (varTensor->hasProducer() &&
      varTensor->getProducer()->opid.type == "Detach") {
    varTensor = varTensor->getProducer()->inTensor(0);
  }

  if (meanTensor->tensorType() == TensorType::Variable) {
    meanTensor->setVariableUpdateType(VariableUpdateType::Copy);
    meanTensor->setCopyFromTensor(outId(getMeanOutIndex()));
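The look-through above generalizes to a small helper; here is a hypothetical Python analogue (not in the commit) of the same pattern on a toy producer graph, with ToyOp, ToyTensor, and look_through_detach as illustrative names:

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class ToyOp:
        type: str
        inputs: List["ToyTensor"]

    @dataclass
    class ToyTensor:
        name: str
        producer: Optional[ToyOp] = None

    def look_through_detach(t: ToyTensor) -> ToyTensor:
        # Mirrors the C++ above: if the producer is a Detach op,
        # return the tensor feeding it; otherwise return the tensor itself.
        if t.producer is not None and t.producer.type == "Detach":
            return t.producer.inputs[0]
        return t

    mean = ToyTensor("running_mean")
    detached = ToyTensor("running_mean/detached", ToyOp("Detach", [mean]))
    assert look_through_detach(detached) is mean
    assert look_through_detach(mean) is mean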
