In PyTorch, there are two versions of module backward hooks: register_full_backward_hook and the older, now-deprecated register_backward_hook. I will describe each of them in turn.

register_full_backward_hook

The registration function is defined as register_full_backward_hook(hook, prepend=False), where the hook function follows the signature:

hook(module, grad_input, grad_output) -> tuple(Tensor) or None

The parameters of the hook are as follows:

  • module: the module on which the hook is registered.
  • grad_input: the gradients of the variable on which backward() was called (here, the loss) with respect to the module's inputs, passed as a tuple of tensors.
  • grad_output: the gradients of the variable on which backward() was called with respect to the module's outputs, also passed as a tuple of tensors.

If the hook returns a value other than None, it must be a tuple of tensors with the same structure as grad_input, and it will be used in place of grad_input in the rest of the backward pass.
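As a minimal sketch of how such a hook is registered and fired (the layer, tensor shapes, and hook name here are only illustrative; the worked example below uses its own module):

import torch
import torch.nn as nn

def print_grads(module, grad_input, grad_output):
    # grad_input:  gradients of the loss w.r.t. the module's input(s)
    # grad_output: gradients of the loss w.r.t. the module's output(s)
    print('grad_input: ', grad_input)
    print('grad_output: ', grad_output)
    # Returning a tuple of tensors here would replace grad_input;
    # returning None (implicitly) leaves it unchanged.

layer = nn.Linear(2, 1)
layer.register_full_backward_hook(print_grads)

out = layer(torch.randn(3, 2, requires_grad=True))
out.sum().backward()  # the hook fires during this backward call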

We use the following example to illustrate this function.

Function

$$ \begin{aligned} &\text{Linear}\\ & & x_1 = \underbrace{0.3}_{w_0}x +\underbrace{0.1}_{b_0} \\ &\text{SLinear} \\ & & s = \underbrace{0.4}_{w_1}x_1 +\underbrace{0.2}_{b_1} \\ & & y = s^2 \\ &\text{L (Loss function)} \\ && L = (y - 1)^2 \end{aligned} $$

We set $x = 0.3$.

Forward

$$ \begin{aligned} &\text{Linear }\\ & & x_1 = \underbrace{0.3}_{w_0}x +\underbrace{0.1}_{b_0} = 0.3 * 0.3 + 0.1 = 0.19 \\ &\text{SLinear }\\ & & s = \underbrace{0.4}_{w_1}x_1 +\underbrace{0.2}_{b_1} = 0.4 * 0.19 + 0.2 = 0.276 \\ & & y = s^2 = 0.076176 \\ &\text{L (Loss function) }\\ & & L = (y - 1)^2 = 0.85345 & & \end{aligned} $$
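These numbers can be reproduced with a few lines of plain Python arithmetic (a quick hand check, independent of PyTorch):

x = 0.3
x1 = 0.3 * x + 0.1    # 0.19
s = 0.4 * x1 + 0.2    # 0.276
y = s ** 2            # 0.076176
L = (y - 1) ** 2      # 0.85345...
print(x1, s, y, L)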

Backward

$$ \begin{aligned} \frac{\partial L}{\partial y} &= 2(y - 1) = 2 * (0.076176 - 1) = -1.847648 & \text{grad\_output}\\ \frac{\partial y}{\partial s} &= 2s = 0.552 \\ \frac{\partial s}{\partial x_1} &= w_1 = 0.4 \\ \frac{\partial x_1}{\partial x} &= w_0 = 0.3 \\ \frac{\partial x_1}{\partial w_0} &= x = 0.3 \\ \frac{\partial x_1}{\partial b_0} &= 1 \\ \end{aligned} $$

According to the chain rule, we obtain

$$ \begin{aligned} \frac{\partial L}{\partial x_1} &= \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial s} \cdot \frac{\partial s}{\partial x_1} &= -1.847648 * 0.552 * 0.4 &= - 0.40796 & \text{grad\_input}\\ \frac{\partial L}{\partial x} &= \frac{\partial L}{\partial x_1} \cdot \frac{\partial x_1}{\partial x} &= - 0.40796 * 0.3 &= -0.12238 \\ \frac{\partial L}{\partial w_0} &= \frac{\partial L}{\partial x_1} \cdot \frac{\partial x_1}{\partial w_0} &= - 0.40796 * 0.3 &= -0.12238 \\ \frac{\partial L}{\partial b_0} &= \frac{\partial L}{\partial x_1} \cdot \frac{\partial x_1}{\partial b_0} &= - 0.40796 * 1 &= - 0.40796 \\ \end{aligned} $$
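These gradients can be checked with plain autograd and no hooks at all (a sanity-check sketch; the variable names below are only for this check):

import torch

x = torch.tensor(0.3, requires_grad=True)
w0 = torch.tensor(0.3, requires_grad=True)
b0 = torch.tensor(0.1, requires_grad=True)

x1 = w0 * x + b0      # Linear
s = 0.4 * x1 + 0.2    # SLinear, affine part
y = s ** 2            # SLinear, square
L = (y - 1) ** 2      # loss

# gradients of L w.r.t. x1, x, w0, b0
print(torch.autograd.grad(L, (x1, x, w0, b0)))
# expected: roughly (-0.4080, -0.1224, -0.1224, -0.4080)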

If we double grad_input by returning it scaled by 2 from the hook, every gradient computed upstream of SLinear is doubled as well; in particular, the gradients of $w_0$ and $b_0$ become twice their original values, as shown below:

$$ \begin{aligned} \frac{\partial L}{\partial w_0} &= \frac{\partial L}{\partial x_1} \cdot \frac{\partial x_1}{\partial w_0} &= 2 * - 0.40796 * 0.3 &= -0.244776 \\ \frac{\partial L}{\partial b_0} &= \frac{\partial L}{\partial x_1} \cdot \frac{\partial x_1}{\partial b_0} &= 2 * - 0.40796 * 1 &= -0.81592 \\ \end{aligned} $$

Code

The Python code is as follows:

import torch
import torch.nn as nn
print(torch.__version__)

def rescale_hook(layer, grad_input, grad_out):
    print(layer)
    print('grad_input: ', grad_input)
    print('grad_output: ', grad_out)
    # Uncommenting the next line doubles grad_input for the rest of the backward pass:
    #return (grad_input[0] * 2, )

class SLinear(nn.Module):
    def __init__(self):
        super(SLinear, self).__init__()
        self.f = nn.Linear(1, 1, bias=True)
        with torch.no_grad():
            self.f.weight[0,0] = 0.4
            self.f.bias[0] = 0.2
        
    def forward(self, x):
        s = self.f(x) 
        y = s**2
        return y
    
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.f0 = nn.Linear(1, 1, bias=True)
        with torch.no_grad():
            self.f0.weight[0,0] = 0.3
            self.f0.bias[0] = 0.1
        self.f1 = SLinear()
        #self.f1.register_backward_hook(rescale_hook)
        self.f1.register_full_backward_hook(rescale_hook)

    def forward(self, x):
        x1 = self.f0(x)
        x2 = self.f1(x1)
        return x2

if __name__ == "__main__":
    torch.manual_seed(0)
    model = Net()
    x = torch.tensor([0.3])
    x.requires_grad = True
    #register_backward_hook(model)
    y = model(x)
    print('y: ', y)
    l = (y - 1)**2
    print("loss:", l)
    l.backward()
    print("x.grad: ", x.grad)
    print("f1 weight: ", model.f1.f.weight.data[0], "bias: ", model.f1.f.bias.data[0], "f1 weight grad: ", model.f1.f.weight.grad,  "f1 bias grad: ", model.f1.f.bias.grad)
    print("f0 weight: ", model.f0.weight.data[0], "bias: ", model.f0.bias.data[0], "f0 weight grad: ", model.f0.weight.grad,  "f0 bias grad: ", model.f0.bias.grad)

The output of this code snippet is as follows:

1.11.0
y:  tensor([0.0762], grad_fn=<BackwardHookFunctionBackward>)
loss: tensor([0.8535], grad_fn=<PowBackward0>)
SLinear(
  (f): Linear(in_features=1, out_features=1, bias=True)
)
grad_input:  (tensor([-0.4080]),)
grad_output:  (tensor([-1.8476]),)
x.grad:  tensor([-0.1224])
f1 weight:  tensor([0.4000]) bias:  tensor(0.2000) f1 weight grad:  tensor([[-0.1938]]) f1 bias grad:  tensor([-1.0199])
f0 weight:  tensor([0.3000]) bias:  tensor(0.1000) f0 weight grad:  tensor([[-0.1224]]) f0 bias grad:  tensor([-0.4080])
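As a quick consistency check on the printed numbers (a hand computation, not part of the script's output), the gradients of SLinear's own parameters are

$$ \frac{\partial L}{\partial b_1} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial s} \cdot \frac{\partial s}{\partial b_1} = -1.847648 * 0.552 * 1 \approx -1.0199, \qquad \frac{\partial L}{\partial w_1} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial s} \cdot x_1 \approx -1.0199 * 0.19 \approx -0.1938, $$

which match the printed f1 bias and weight gradients.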

After uncommenting the return (grad_input[0] * 2, ) line in rescale_hook, the output changes to the following:

1.11.0
y:  tensor([0.0762], grad_fn=<BackwardHookFunctionBackward>)
loss: tensor([0.8535], grad_fn=<PowBackward0>)
SLinear(
  (f): Linear(in_features=1, out_features=1, bias=True)
)
grad_input:  (tensor([-0.4080]),)
grad_output:  (tensor([-1.8476]),)
x.grad:  tensor([-0.2448])
f1 weight:  tensor([0.4000]) bias:  tensor(0.2000) f1 weight grad:  tensor([[-0.1938]]) f1 bias grad:  tensor([-1.0199])
f0 weight:  tensor([0.3000]) bias:  tensor(0.1000) f0 weight grad:  tensor([[-0.2448]]) f0 bias grad:  tensor([-0.8159])

register_backward_hook

The difference between register_backward_hook and register_full_backward_hook is that for register_backward_hook, grad_input is the gradient with respect to the input of the last autograd operation executed inside the module, whereas for register_full_backward_hook it is the gradient with respect to the module's actual input. This is why register_backward_hook is deprecated and triggers the warning shown in the output below.

For register_backward_hook, the last operation inside SLinear is $y = s^2$, so grad_input is $\frac{\partial L}{\partial s} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial s} = -1.847648 * 0.552 \approx -1.0199$.

After uncommenting register_backward_hook and commenting out register_full_backward_hook, the output changes as follows:

1.11.0
/Users/lifan/opt/anaconda3/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py:1033: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
  warnings.warn("Using a non-full backward hook when the forward contains multiple autograd Nodes "
y:  tensor([0.0762], grad_fn=<PowBackward0>)
loss: tensor([0.8535], grad_fn=<PowBackward0>)
SLinear(
  (f): Linear(in_features=1, out_features=1, bias=True)
)
grad_input:  (tensor([-1.0199]),)
grad_output:  (tensor([-1.8476]),)
x.grad:  tensor([-0.1224])
f1 weight:  tensor([0.4000]) bias:  tensor(0.2000) f1 weight grad:  tensor([[-0.1938]]) f1 bias grad:  tensor([-1.0199])
f0 weight:  tensor([0.3000]) bias:  tensor(0.1000) f0 weight grad:  tensor([[-0.1224]]) f0 bias grad:  tensor([-0.4080])

register_forward_hook

register_forward_hook registers a hook with the signature hook(module, input, output) that runs after the module's forward pass; if the hook returns a value other than None, that value replaces the module's output. Here we use such a hook to double the output of SLinear, which changes the forward pass as follows:

Forward

$$ \begin{aligned} &\text{Linear }\\ & & x_1 = \underbrace{0.3}_{w_0}x +\underbrace{0.1}_{b_0} = 0.3 * 0.3 + 0.1 = 0.19 \\ &\text{SLinear }\\ & & s = \underbrace{0.4}_{w_1}x_1 +\underbrace{0.2}_{b_1} = 0.4 * 0.19 + 0.2 = 0.276 \\ & & y_{\text{ori}} = s^2 = 0.076176 \\ & & y = 2 * y_{\text{ori}} = 0.1524 & {\quad \text{hook}} \\ &\text{L (Loss function) }\\ & & L = (y - 1)^2 = 0.7185 \\ \end{aligned} $$

Code

import torch 
import torch.nn as nn
print(torch.__version__)
def rescale_hook(layer, grad_input, grad_out):
    print(layer)
    print('grad_input: ', grad_input)
    print('grad_output: ', grad_out)
    #return (grad_input[0] * 2, )

class SLinear(nn.Module):
    def __init__(self):
        super(SLinear, self).__init__()
        self.f = nn.Linear(1, 1, bias=True)
        with torch.no_grad():
            self.f.weight[0,0] = 0.4
            self.f.bias[0] = 0.2
        
    def forward(self, x):
        s = self.f(x) 
        y = s**2
        return y
    


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.f0 = nn.Linear(1, 1, bias=True)
        with torch.no_grad():
            self.f0.weight[0,0] = 0.3
            self.f0.bias[0] = 0.1
        self.f1 = SLinear()

        #self.f1.register_backward_hook(rescale_hook)
        #self.f1.register_full_backward_hook(rescale_hook)

    def forward(self, x):
        x1 = self.f0(x)
        x2 = self.f1(x1)
        return x2

def forward_hook(module, input, output):
    print(f"Inside forward hook for {module.__class__.__name__}")
    print(f"Input shape: {input[0].shape}, \t", input)
    print(f"Output shape: {output.shape}, \t", output)
    print("--------")
    return output * 2

if __name__ == "__main__":
    torch.manual_seed(0)
    model = Net()
    model.f1.register_forward_hook(forward_hook)
    x = torch.tensor([0.3])
    x.requires_grad = True
    #register_backward_hook(model)
    y = model(x)
    print('y: ', y)
    l = (y - 1)**2
    print("loss:", l)
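With the hook doubling SLinear's output, the printed y should be about 0.1524 and the loss about 0.7185, matching the forward computation above. One practical note (a small variant of the script, not shown above): register_forward_hook returns a handle, which can be kept to detach the hook when it is no longer needed:

handle = model.f1.register_forward_hook(forward_hook)
y = model(x)       # the hook fires here and doubles SLinear's output
handle.remove()    # later forward passes run without the hook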
