PyTorch provides two ways to register backward hooks on a module: register_full_backward_hook and the older register_backward_hook. I will describe each of them in turn.
register_full_backward_hook
The registration function is defined as register_full_backward_hook(hook, prepend=False), where the hook function follows the signature:
hook(module, grad_input, grad_output) -> tuple(Tensor) or None
The parameters of the hook are as follows:
- module: the module to which the hook is registered.
- grad_input: the gradient of the variable on which backward() was called (typically the loss) with respect to the module's inputs.
- grad_output: the gradient of that same variable with respect to the module's outputs.
If the hook returns something other than None, the return value must be a tuple of tensors, and it will be used in place of grad_input in subsequent gradient computations.
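As a minimal sketch (the layer, tensor, and hook names below are only illustrative, not part of the example that follows), a full backward hook that halves the gradient flowing into the module's input could look like this:

import torch
import torch.nn as nn

def halve_grad(module, grad_input, grad_output):
    # grad_input and grad_output are tuples of tensors (entries may be None)
    return tuple(g * 0.5 if g is not None else None for g in grad_input)

layer = nn.Linear(4, 2)
layer.register_full_backward_hook(halve_grad)

x = torch.randn(3, 4, requires_grad=True)
layer(x).sum().backward()
print(x.grad)  # half of the gradient that would arrive without the hook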
We use the following example to illustrate this function.
Function
$$ \begin{aligned} &\text{Linear}\\ & & x_1 = \underbrace{0.3}_{w_0}x +\underbrace{0.1}_{b_0} \\ &\text{SLinear} \\ & & s = \underbrace{0.4}_{w_1}x_1 +\underbrace{0.2}_{b_1} \\ & & y = s^2 \\ &\text{L (Loss function)} \\ && L = (y - 1)^2 \end{aligned} $$

We set $x = 0.3$.
Forward
$$ \begin{aligned} &\text{Linear }\\ & & x_1 = \underbrace{0.3}_{w_0}x +\underbrace{0.1}_{b_0} = 0.3 * 0.3 + 0.1 = 0.19 \\ &\text{SLinear }\\ & & s = \underbrace{0.4}_{w_1}x_1 +\underbrace{0.2}_{b_1} = 0.4 * 0.19 + 0.2 = 0.276 \\ & & y = s^2 = 0.076176 \\ &\text{L (Loss function) }\\ & & L = (y - 1)^2 = 0.85345 \end{aligned} $$

Backward
$$ \begin{aligned} \frac{\partial L}{\partial y} &= 2(y - 1) = 2 * (0.076176 - 1) = -1.847648 & \text{grad\_output}\\ \frac{\partial y}{\partial s} &= 2s = 0.552 \\ \frac{\partial s}{\partial x_1} &= w_1 = 0.4 \\ \frac{\partial x_1}{\partial x} &= w_0 = 0.3 \\ \frac{\partial x_1}{\partial w_0} &= x = 0.3 \\ \frac{\partial x_1}{\partial b_0} &= 1 \\ \end{aligned} $$

According to the chain rule, we obtain
$$ \begin{aligned} \frac{\partial L}{\partial x_1} &= \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial s} \cdot \frac{\partial s}{\partial x_1} &= -1.847648 * 0.552 * 0.4 &= - 0.40796 & \text{grad\_input}\\ \frac{\partial L}{\partial x} &= \frac{\partial L}{\partial x_1} \cdot \frac{\partial x_1}{\partial x} &= - 0.40796 * 0.3 &= -0.12238 \\ \frac{\partial L}{\partial w_0} &= \frac{\partial L}{\partial x_1} \cdot \frac{\partial x_1}{\partial w_0} &= - 0.40796 * 0.3 &= -0.12238 \\ \frac{\partial L}{\partial b_0} &= \frac{\partial L}{\partial x_1} \cdot \frac{\partial x_1}{\partial b_0} &= - 0.40796 * 1 &= - 0.40796 \\ \end{aligned} $$

If we double grad_input, the gradients of $w_0$ and $b_0$ will become twice their original values, as shown below:
$$ \begin{aligned} \frac{\partial L}{\partial w_0} &= \frac{\partial L}{\partial x_1} \cdot \frac{\partial x_1}{\partial w_0} &= 2 * (-0.40796) * 0.3 &= -0.244776 \\ \frac{\partial L}{\partial b_0} &= \frac{\partial L}{\partial x_1} \cdot \frac{\partial x_1}{\partial b_0} &= 2 * (-0.40796) * 1 &= -0.81592 \\ \end{aligned} $$

Code
The Python code is as follows:
import torch
import torch.nn as nn

def rescale_hook(layer, grad_input, grad_out):
    print('grad_input: ', grad_input)
    print('grad_output: ', grad_out)
    #return (grad_input[0] * 2, )

class SLinear(nn.Module):
    def __init__(self):
        super(SLinear, self).__init__()
        self.f = nn.Linear(1, 1, bias=True)
        # fix the parameters to w_1 = 0.4, b_1 = 0.2
        with torch.no_grad():
            self.f.weight[0, 0] = 0.4
            self.f.bias[0] = 0.2

    def forward(self, x):
        s = self.f(x)
        y = s**2
        return y

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.f0 = nn.Linear(1, 1, bias=True)
        # fix the parameters to w_0 = 0.3, b_0 = 0.1
        with torch.no_grad():
            self.f0.weight[0, 0] = 0.3
            self.f0.bias[0] = 0.1
        self.f1 = SLinear()
        #self.f1.register_backward_hook(rescale_hook)
        self.f1.register_full_backward_hook(rescale_hook)

    def forward(self, x):
        x1 = self.f0(x)
        x2 = self.f1(x1)
        return x2
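The snippet above only defines the hook and the modules; the driver that produced the output below is not shown. A minimal reconstruction (inferred from the printed lines, so the exact print formatting is an assumption) is:

# driver reconstructed from the printed output (formatting approximate)
print(torch.__version__)

net = Net()
x = torch.tensor([0.3], requires_grad=True)

y = net(x)
print('y:', y)
loss = (y - 1) ** 2
print('loss:', loss)
print(net.f1)

loss.backward()
print('x.grad:', x.grad)
print('f1 weight:', net.f1.f.weight.data[0], 'bias:', net.f1.f.bias.data[0],
      'f1 weight grad:', net.f1.f.weight.grad, 'f1 bias grad:', net.f1.f.bias.grad)
print('f0 weight:', net.f0.weight.data[0], 'bias:', net.f0.bias.data[0],
      'f0 weight grad:', net.f0.weight.grad, 'f0 bias grad:', net.f0.bias.grad)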
The output of this code snippet is as follows:
1.11.0
y: tensor([0.0762], grad_fn=<BackwardHookFunctionBackward>)
loss: tensor([0.8535], grad_fn=<PowBackward0>)
SLinear(
(f): Linear(in_features=1, out_features=1, bias=True)
)
grad_input: (tensor([-0.4080]),)
grad_output: (tensor([-1.8476]),)
x.grad: tensor([-0.1224])
f1 weight: tensor([0.4000]) bias: tensor(0.2000) f1 weight grad: tensor([[-0.1938]]) f1 bias grad: tensor([-1.0199])
f0 weight: tensor([0.3000]) bias: tensor(0.1000) f0 weight grad: tensor([[-0.1224]]) f0 bias grad: tensor([-0.4080])
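The printed grad_input and grad_output match the hand derivation. As an independent sanity check (a sketch that redoes the chain rule with plain tensors instead of the modules), the same numbers can be reproduced directly:

import torch

x = torch.tensor([0.3], requires_grad=True)
x1 = 0.3 * x + 0.1        # Linear
x1.retain_grad()          # keep dL/dx1, which is SLinear's grad_input
s = 0.4 * x1 + 0.2        # affine part of SLinear
y = s**2
y.retain_grad()           # keep dL/dy, which is SLinear's grad_output
loss = (y - 1)**2
loss.backward()

print(y.grad)   # tensor([-1.8476])  = grad_output
print(x1.grad)  # tensor([-0.4080])  = grad_input
print(x.grad)   # tensor([-0.1224])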
After uncommenting #return (grad_input[0] * 2, ) in rescale_hook, the output changes to the following:
1.11.0
y: tensor([0.0762], grad_fn=<BackwardHookFunctionBackward>)
loss: tensor([0.8535], grad_fn=<PowBackward0>)
SLinear(
(f): Linear(in_features=1, out_features=1, bias=True)
)
grad_input: (tensor([-0.4080]),)
grad_output: (tensor([-1.8476]),)
x.grad: tensor([-0.2448])
f1 weight: tensor([0.4000]) bias: tensor(0.2000) f1 weight grad: tensor([[-0.1938]]) f1 bias grad: tensor([-1.0199])
f0 weight: tensor([0.3000]) bias: tensor(0.1000) f0 weight grad: tensor([[-0.2448]]) f0 bias grad: tensor([-0.8159])
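Note that x.grad also doubles, from -0.1224 to -0.2448: the replaced grad_input keeps propagating backward through f0, so $\frac{\partial L}{\partial x} = 2 \cdot (-0.40796) \cdot 0.3 = -0.2448$, while the parameters inside f1, which sit upstream of the hook in the backward pass, keep their original gradients.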
register_backward_hook
The difference between register_backward_hook and register_full_backward_hook is that the grad_input reported by register_backward_hook is the gradient with respect to the input of the last operation inside the module (here $s$, the input of $y = s^2$), whereas register_full_backward_hook reports the gradient with respect to the module's own input. For register_backward_hook, the grad_input is therefore $\frac{\partial L}{\partial s} = -1.847648 \cdot 0.552 = -1.0199$.
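In the example code, this corresponds to swapping the two registration lines in Net.__init__:

#self.f1.register_full_backward_hook(rescale_hook)
self.f1.register_backward_hook(rescale_hook)   # deprecated; emits a UserWarning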
After this change (register_backward_hook registered, register_full_backward_hook commented out), the output is as follows:
1.11.0
/Users/lifan/opt/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py:1033: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
warnings.warn("Using a non-full backward hook when the forward contains multiple autograd Nodes "
y: tensor([0.0762], grad_fn=<PowBackward0>)
loss: tensor([0.8535], grad_fn=<PowBackward0>)
SLinear(
(f): Linear(in_features=1, out_features=1, bias=True)
)
grad_input: (tensor([-1.0199]),)
grad_output: (tensor([-1.8476]),)
x.grad: tensor([-0.1224])
f1 weight: tensor([0.4000]) bias: tensor(0.2000) f1 weight grad: tensor([[-0.1938]]) f1 bias grad: tensor([-1.0199])
f0 weight: tensor([0.3000]) bias: tensor(0.1000) f0 weight grad: tensor([[-0.1224]]) f0 bias grad: tensor([-0.4080])