In PyTorch, there are two ways to register backward hooks on a module: register_full_backward_hook and register_backward_hook. I will describe each of them in turn.

register_full_backward_hook

The registration function is defined as register_full_backward_hook(hook, prepend=False), where the hook function follows the signature:

hook(module, grad_input, grad_output) -> tuple(Tensor) or None

The parameters of this function are as follows:

  • module: The module to which the hook is registered.
  • grad_input: The gradient of the variable on which backward() was called (here, the loss) with respect to the module’s inputs, given as a tuple of tensors.
  • grad_output: The gradient of the variable on which backward() was called with respect to the module’s outputs, also given as a tuple of tensors.

If the hook returns a value other than None, that value (a tuple of tensors) is used in place of grad_input in subsequent gradient computations.
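For instance, a minimal sketch (the layer, hook name, and shapes below are illustrative and separate from the worked example that follows) that simply prints the gradients flowing through a single linear layer looks like this:

import torch
import torch.nn as nn

def print_hook(module, grad_input, grad_output):
    print('grad_input: ', grad_input)
    print('grad_output:', grad_output)
    # returning None leaves grad_input unchanged

layer = nn.Linear(2, 1)
layer.register_full_backward_hook(print_hook)

x = torch.randn(1, 2, requires_grad=True)
layer(x).sum().backward()   # the hook fires during this backward pass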

We use the following example to illustrate this function.

Function

$$ \begin{aligned} &\text{Linear}\\ & & x_1 = \underbrace{0.3}_{w_0}x +\underbrace{0.1}_{b_0} \\ &\text{SLinear} \\ & & s = \underbrace{0.4}_{w_1}x_1 +\underbrace{0.2}_{b_1} \\ & & y = s^2 \\ &\text{L (Loss function)} \\ && L = (y - 1)^2 \end{aligned} $$

We set $x = 0.3$.

Forward

$$ \begin{aligned} &\text{Linear }\\ & & x_1 = \underbrace{0.3}_{w_0}x +\underbrace{0.1}_{b_0} = 0.3 * 0.3 + 0.1 = 0.19 \\ &\text{SLinear }\\ & & s = \underbrace{0.4}_{w_1}x_1 +\underbrace{0.2}_{b_1} = 0.4 * 0.19 + 0.2 = 0.276 \\ & & y = s^2 = 0.076176 \\ &\text{L (Loss function) }\\ & & L = (y - 1)^2 = 0.85345 & & \end{aligned} $$
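The same forward arithmetic can be reproduced in plain Python (a quick sanity check, separate from the module code further below):

w0, b0 = 0.3, 0.1
w1, b1 = 0.4, 0.2
x = 0.3

x1 = w0 * x + b0      # 0.19
s  = w1 * x1 + b1     # 0.276
y  = s ** 2           # 0.076176
L  = (y - 1) ** 2     # 0.85345...
print(x1, s, y, L)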

Backward

$$ \begin{aligned} \frac{\partial L}{\partial y} &= 2(y - 1) = 2 * (0.076176 - 1) = -1.847648 & \text{grad\_output}\\ \frac{\partial y}{\partial s} &= 2s = 0.552 \\ \frac{\partial s}{\partial x_1} &= w_1 = 0.4 \\ \frac{\partial x_1}{\partial x} &= w_0 = 0.3 \\ \frac{\partial x_1}{\partial w_0} &= x = 0.3 \\ \frac{\partial x_1}{\partial b_0} &= 1 \\ \end{aligned} $$

According to the chain rule, we obtain

$$ \begin{aligned} \frac{\partial L}{\partial x_1} &= \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial s} \cdot \frac{\partial s}{\partial x_1} &= -1.847648 * 0.552 * 0.4 &= - 0.40796 & \text{grad\_input}\\ \frac{\partial L}{\partial x} &= \frac{\partial L}{\partial x_1} \cdot \frac{\partial x_1}{\partial x} &= - 0.40796 * 0.3 &= -0.12238 \\ \frac{\partial L}{\partial w_0} &= \frac{\partial L}{\partial x_1} \cdot \frac{\partial x_1}{\partial w_0} &= - 0.40796 * 0.3 &= -0.12238 \\ \frac{\partial L}{\partial b_0} &= \frac{\partial L}{\partial x_1} \cdot \frac{\partial x_1}{\partial b_0} &= - 0.40796 * 1 &= - 0.40796 \\ \end{aligned} $$
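These hand-computed values can be cross-checked directly with autograd (a small standalone sketch, independent of the module code in the next section):

import torch

x = torch.tensor([0.3], requires_grad=True)
w0, b0 = 0.3, 0.1
w1, b1 = 0.4, 0.2

x1 = w0 * x + b0
x1.retain_grad()          # keep dL/dx1 so we can inspect SLinear's grad_input
s = w1 * x1 + b1
y = s ** 2
L = (y - 1) ** 2
L.backward()

print(x1.grad)            # ≈ -0.4080, matching grad_input above
print(x.grad)             # ≈ -0.1224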

If we double grad_input (i.e., $\frac{\partial L}{\partial x_1}$), the gradients of $w_0$ and $b_0$ become twice their original values, as shown below:

$$ \begin{aligned} \frac{\partial L}{\partial w_0} &= \frac{\partial L}{\partial x_1} \cdot \frac{\partial x_1}{\partial w_0} &= 2 * - 0.40796 * 0.3 &= -0.244776 \\ \frac{\partial L}{\partial b_0} &= \frac{\partial L}{\partial x_1} \cdot \frac{\partial x_1}{\partial b_0} &= 2 * - 0.40796 * 1 &= -0.81592 \\ \end{aligned} $$

Code

The Python code is as follows:

import torch
import torch.nn as nn

def rescale_hook(layer, grad_input, grad_out):
    print('grad_input: ', grad_input)
    print('grad_output: ', grad_out)
    #return (grad_input[0] * 2, )

class SLinear(nn.Module):
    def __init__(self):
        super(SLinear, self).__init__()
        self.f = nn.Linear(1, 1, bias=True)
        with torch.no_grad():
            self.f.weight[0,0] = 0.4
            self.f.bias[0] = 0.2
        
    def forward(self, x):
        s = self.f(x) 
        y = s**2
        return y
    
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.f0 = nn.Linear(1, 1, bias=True)
        with torch.no_grad():
            self.f0.weight[0,0] = 0.3
            self.f0.bias[0] = 0.1
        self.f1 = SLinear()
        #self.f1.register_backward_hook(rescale_hook)
        self.f1.register_full_backward_hook(rescale_hook)

    def forward(self, x):
        x1 = self.f0(x)
        x2 = self.f1(x1)
        return x2
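The driver code that produces the output below is not shown above; a sketch along the following lines (appended to the snippet above; the exact print statements are assumptions) reproduces it:

print(torch.__version__)

net = Net()
x = torch.tensor([0.3], requires_grad=True)

y = net(x)
print('y: ', y)
loss = (y - 1) ** 2
print('loss:', loss)
print(net.f1)

loss.backward()

print('x.grad: ', x.grad)
print('f1 weight: ', net.f1.f.weight.data[0], 'bias: ', net.f1.f.bias.data[0],
      'f1 weight grad: ', net.f1.f.weight.grad, 'f1 bias grad: ', net.f1.f.bias.grad)
# the "f1" labels in the next line follow the original output verbatim
print('f0 weight: ', net.f0.weight.data[0], 'bias: ', net.f0.bias.data[0],
      'f1 weight grad: ', net.f0.weight.grad, 'f1 bias grad: ', net.f0.bias.grad)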

The output of this code snippet is as follows:

1.11.0
y:  tensor([0.0762], grad_fn=<BackwardHookFunctionBackward>)
loss: tensor([0.8535], grad_fn=<PowBackward0>)
SLinear(
  (f): Linear(in_features=1, out_features=1, bias=True)
)
grad_input:  (tensor([-0.4080]),)
grad_output:  (tensor([-1.8476]),)
x.grad:  tensor([-0.1224])
f1 weight:  tensor([0.4000]) bias:  tensor(0.2000) f1 weight grad:  tensor([[-0.1938]]) f1 bias grad:  tensor([-1.0199])
f0 weight:  tensor([0.3000]) bias:  tensor(0.1000) f1 weight grad:  tensor([[-0.1224]]) f1 bias grad:  tensor([-0.4080])

After uncommenting #return (grad_input[0] * 2, ), the output changes to the following:

1.11.0
y:  tensor([0.0762], grad_fn=<BackwardHookFunctionBackward>)
loss: tensor([0.8535], grad_fn=<PowBackward0>)
SLinear(
  (f): Linear(in_features=1, out_features=1, bias=True)
)
grad_input:  (tensor([-0.4080]),)
grad_output:  (tensor([-1.8476]),)
x.grad:  tensor([-0.2448])
f1 weight:  tensor([0.4000]) bias:  tensor(0.2000) f1 weight grad:  tensor([[-0.1938]]) f1 bias grad:  tensor([-1.0199])
f0 weight:  tensor([0.3000]) bias:  tensor(0.1000) f1 weight grad:  tensor([[-0.2448]]) f1 bias grad:  tensor([-0.8159])
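For clarity, the rescale_hook with the return statement enabled reads:

def rescale_hook(layer, grad_input, grad_out):
    print('grad_input: ', grad_input)
    print('grad_output: ', grad_out)
    # the returned tuple replaces grad_input for the rest of the backward pass
    return (grad_input[0] * 2, )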

register_backward_hook

The difference between register_backward_hook and register_full_backward_hook is that, when the module’s forward pass contains several autograd operations, the grad_input of register_backward_hook refers only to the input of the last operation executed in the module (here, the squaring $y = s^2$), whereas in register_full_backward_hook it refers to the module’s actual input.

For register_backward_hook, the grad_input is therefore $\frac{\partial L}{\partial s} = -1.0199$.
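This value follows directly from the chain rule:

$$ \frac{\partial L}{\partial s} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial s} = -1.847648 * 0.552 = -1.0199 $$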

After uncommenting register_backward_hook and commenting out register_full_backward_hook, the output changes as follows:

1.11.0
/Users/lifan/opt/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py:1033: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
  warnings.warn("Using a non-full backward hook when the forward contains multiple autograd Nodes "
y:  tensor([0.0762], grad_fn=<PowBackward0>)
loss: tensor([0.8535], grad_fn=<PowBackward0>)
SLinear(
  (f): Linear(in_features=1, out_features=1, bias=True)
)
grad_input:  (tensor([-1.0199]),)
grad_output:  (tensor([-1.8476]),)
x.grad:  tensor([-0.1224])
f1 weight:  tensor([0.4000]) bias:  tensor(0.2000) f1 weight grad:  tensor([[-0.1938]]) f1 bias grad:  tensor([-1.0199])
f0 weight:  tensor([0.3000]) bias:  tensor(0.1000) f1 weight grad:  tensor([[-0.1224]]) f1 bias grad:  tensor([-0.4080])
