import torch
# internal helpers from torch.autograd.gradcheck as of the PyTorch version
# referenced in the disclaimer below; their names and signatures may differ
# in newer releases
from torch.autograd.gradcheck import get_analytical_jacobian, get_numerical_jacobian, zero_gradients, make_jacobian


def get_analytical_jacobian_params(output, target):
    """
    Computes the analytical Jacobian of `output` with respect to all tensors in `target`, which can hold some or
    all of the parameters of the module used to compute `output`.

    output: torch.Tensor from which to backpropagate the gradients
    target: sequence of torch.Tensor (e.g. ``list(module.parameters())``) with respect to which the Jacobian is
        computed; every tensor must require grad and take part in the computation of `output`
    """

    jacobian = make_jacobian(target, output.numel())
    grad_output = torch.zeros_like(output)
    flat_grad_output = grad_output.view(-1)

    for i in range(flat_grad_output.numel()):
        # a one-hot grad_output selects the i-th output element, so the backward
        # pass below yields the i-th column of each tensor's Jacobian
        flat_grad_output.zero_()
        flat_grad_output[i] = 1

        zero_gradients(target)
        torch.autograd.backward(output, grad_output, retain_graph=True)

        for j in range(len(jacobian)):
            jacobian[j][:, i] = target[j].grad.clone().flatten()

    return jacobian
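

def _example_param_jacobian():
    """
    Minimal usage sketch of `get_analytical_jacobian_params` (illustrative addition, not part of the original
    gradcheck code): the linear layer, input shape and double precision below are assumptions chosen for the example.
    """
    lin = torch.nn.Linear(3, 2).double()
    x = torch.randn(4, 3, dtype=torch.float64)
    out = lin(x)
    jacobian = get_analytical_jacobian_params(out, list(lin.parameters()))
    # one (param.numel(), out.numel()) matrix per parameter: 6x8 for the weight, 2x8 for the bias
    return jacobian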


def gradcheck(m, input, eps=1e-6, atol=1e-5, rtol=1e-3, raise_exception=True):
    """
    Compares the analytical gradients of a module to numerical gradients computed via central finite differences,
    both with respect to the module inputs and with respect to the module parameters.

    Disclaimer: this is a modified version of torch.autograd.gradcheck::gradcheck
    (https://pytorch.org/docs/stable/_modules/torch/autograd/gradcheck.html, 2019-06-04) distributed under license
    https://github.com/pytorch/pytorch/blob/master/LICENSE
    :param m: torch.nn.Module (or callable exposing ``parameters()``) whose gradients are checked
    :param input: tuple of inputs passed to `m`; tensors to be checked must be of floating point type and have
        ``requires_grad=True``
    :param eps: perturbation used for the central finite differences
    :param atol: absolute tolerance passed to ``torch.allclose`` when comparing Jacobians
    :param rtol: relative tolerance passed to ``torch.allclose`` when comparing Jacobians
    :param raise_exception: if True, raise a RuntimeError on the first failed check; otherwise return False
    :return: True if all checks pass
    """
    def fail_test(msg):
        if raise_exception:
            raise RuntimeError(msg)
        return False

    def fn(input):
        return m(*input)

    output = fn(input)

    # the checks below assume that `m` returns a single tensor; wrap it so the
    # loop iterates over outputs rather than over rows of the output tensor
    outputs = (output,) if isinstance(output, torch.Tensor) else output

    for i, o in enumerate(outputs):
        if not o.requires_grad:
            continue

        # compare input gradients
        analytical, reentrant, correct_grad_sizes = get_analytical_jacobian(input, o)
        numerical = get_numerical_jacobian(fn, input, eps=eps)

        if not correct_grad_sizes:
            return fail_test('Analytical gradient has incorrect size')

        for j, (a, n) in enumerate(zip(analytical, numerical)):
            if a.numel() != 0 or n.numel() != 0:
                if not torch.allclose(a, n, rtol, atol):
                    return fail_test('Jacobian mismatch for output %d with respect to input %d,\n'
                                     'numerical:%s\nanalytical:%s\n' % (i, j, n, a))

        if not reentrant:
            return fail_test('Backward is not reentrant, i.e., running backward with same '
                             'input and grad_output multiple times gives different values, '
                             'although analytical gradient matches numerical gradient')

    # compare parameter gradients
    pars = list(m.parameters())

    if pars:
        numerical = get_numerical_jacobian(fn, input, target=pars, eps=eps)
        analytical = get_analytical_jacobian_params(output, pars)

        for j, (a, n) in enumerate(zip(analytical, numerical)):
            if a.numel() != 0 or n.numel() != 0:
                if not torch.allclose(a, n, rtol, atol):
                    return fail_test('Jacobian mismatch for output %d with respect to parameter %d,\n'
                                     'numerical:%s\nanalytical:%s\n' % (i, j, n, a))
    return True
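

if __name__ == '__main__':
    # hedged usage sketch (illustrative addition, not part of the original module): run the check on a small
    # linear layer; double precision keeps the finite-difference error well below the default tolerances
    torch.manual_seed(0)
    module = torch.nn.Linear(3, 2).double()
    inputs = (torch.randn(4, 3, dtype=torch.float64, requires_grad=True),)
    print('gradcheck passed:', gradcheck(module, inputs))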