脚本之家,脚本语言编程技术及教程分享平台!
分类导航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|

服务器之家 - 脚本之家 - Python - pytorch中的hook机制register_forward_hook

pytorch中的hook机制register_forward_hook

2022-10-28 10:58机器学习入坑者 Python

这篇文章主要介绍了pytorch中的hook机制register_forward_hook,手动在forward之前注册hook,hook在forward执行以后被自动执行,下面详细的内容介绍,需要的小伙伴可以参考一下

1、hook背景

Hook被成为钩子机制,这不是pytorch的首创,在Windows的编程中已经被普遍采用,包括进程内钩子和全局钩子。按照自己的理解,hook的作用是通过系统来维护一个链表,使得用户拦截(获取)通信消息,用于处理事件。

pytorch中包含forwardbackward两个钩子注册函数,用于获取forward和backward中输入和输出,按照自己不全面的理解,应该目的是“不改变网络的定义代码,也不需要在forward函数中return某个感兴趣层的输出,这样代码太冗杂了”。

2、源码阅读

register_forward_hook()函数必须在forward()函数调用之前被使用,因为这个函数源码注释显示这个函数“ it will not have effect on forward since this is called after :func:`forward` is called”,也就是这个函数在forward()之后就没有作用了!!!):

作用:获取forward过程中每层的输入和输出,用于对比hook是不是正确记录。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def register_forward_hook(self, hook):
        r"""Registers a forward hook on the module.
        The hook will be called every time after :func:`forward` has computed an output.
        It should have the following signature::
            hook(module, input, output) -> None or modified output
        The hook can modify the output. It can modify the input inplace but
        it will not have effect on forward since this is called after
        :func:`forward` is called.
 
        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._forward_hooks)
        self._forward_hooks[handle.id] = hook
        return handle

3、定义一个用于测试hooker的类

如果随机的初始化每个层,那么就无法测试出自己获取的输入输出是不是forward中的输入输出了,所以需要将每一层的权重和偏置设置为可识别的值(比如全部初始化为1)。网络包含两层(Linear有需要求导的参数被称为一个层,而ReLU没有需要求导的参数不被称作一层),__init__()中调用initialize函数对所有层进行初始化。

注意:在forward()函数返回各个层的输出,但是ReLU6没有返回,因为后续测试的时候不对这一层进行注册hook。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class TestForHook(nn.Module):
    def __init__(self):
        super().__init__()
 
        self.linear_1 = nn.Linear(in_features=2, out_features=2)
        self.linear_2 = nn.Linear(in_features=2, out_features=1)
        self.relu = nn.ReLU()
        self.relu6 = nn.ReLU6()
        self.initialize()
 
    def forward(self, x):
        linear_1 = self.linear_1(x)
        linear_2 = self.linear_2(linear_1)
        relu = self.relu(linear_2)
        relu_6 = self.relu6(relu)
        layers_in = (x, linear_1, linear_2)
        layers_out = (linear_1, linear_2, relu)
        return relu_6, layers_in, layers_out
    def initialize(self):
        """ 定义特殊的初始化,用于验证是不是获取了权重"""
        self.linear_1.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1], [1, 1]]))
        self.linear_1.bias = torch.nn.Parameter(torch.FloatTensor([1, 1]))
        self.linear_2.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1]]))
        self.linear_2.bias = torch.nn.Parameter(torch.FloatTensor([1]))
        return True

4、定义hook函数

hook()函数是register_forward_hook()函数必须提供的参数,好处是“用户可以自行决定拦截了中间信息之后要做什么!”,比如自己想单纯的记录网络的输入输出(也可以进行修改等更加复杂的操作)。

首先定义几个容器用于记录:

定义用于获取网络各层输入输出tensor的容器:

?
1
2
3
4
5
# 并定义module_name用于记录相应的module名字
module_name = []
features_in_hook = []
features_out_hook = []
hook函数需要三个参数,这三个参数是系统传给hook函数的,自己不能修改这三个参数:

hook函数负责将获取的输入输出添加到feature列表中;并提供相应的module名字

?
1
2
3
4
5
6
def hook(module, fea_in, fea_out):
    print("hooker working")
    module_name.append(module.__class__)
    features_in_hook.append(fea_in)
    features_out_hook.append(fea_out)
    return None

5、对需要的层注册hook

注册钩子必须在forward()函数被执行之前,也就是定义网络进行计算之前就要注册,下面的代码对网络除去ReLU6以外的层都进行了注册(也可以选定某些层进行注册):

注册钩子可以对某些层单独进行:

?
1
2
3
4
5
net = TestForHook()
net_chilren = net.children()
for child in net_chilren:
    if not isinstance(child, nn.ReLU6):
        child.register_forward_hook(hook=hook)

6、测试forward()返回的特征和hook记录的是否一致

6.1 测试forward()提供的输入输出特征

由于前面的forward()函数返回了需要记录的特征,这里可以直接测试:

?
1
2
3
4
5
out, features_in_forward, features_out_forward = net(x)
print("*"*5+"forward return features"+"*"*5)
print(features_in_forward)
print(features_out_forward)
print("*"*5+"forward return features"+"*"*5)

得到下面的输出是理所当然的:

*****forward return features*****
(tensor([[0.1000, 0.1000],
        [0.1000, 0.1000]]), tensor([[1.2000, 1.2000],
        [1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
        [3.4000]], grad_fn=<AddmmBackward>))
(tensor([[1.2000, 1.2000],
        [1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
        [3.4000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
        [3.4000]], grad_fn=<ThresholdBackward0>))
*****forward return features*****

6.2 hook记录的输入特征和输出特征

hook通过list结构进行记录,所以可以直接print

测试features_in是不是存储了输入:

?
1
2
3
4
5
print("*"*5+"hook record features"+"*"*5)
print(features_in_hook)
print(features_out_hook)
print(module_name)
print("*"*5+"hook record features"+"*"*5)

得到和forward一样的结果:

*****hook record features*****
[(tensor([[0.1000, 0.1000],
        [0.1000, 0.1000]]),), (tensor([[1.2000, 1.2000],
        [1.2000, 1.2000]], grad_fn=<AddmmBackward>),), (tensor([[3.4000],
        [3.4000]], grad_fn=<AddmmBackward>),)]
[tensor([[1.2000, 1.2000],
        [1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
        [3.4000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
        [3.4000]], grad_fn=<ThresholdBackward0>)]
[<class 'torch.nn.modules.linear.Linear'>, 
<class 'torch.nn.modules.linear.Linear'>,
 <class 'torch.nn.modules.activation.ReLU'>]
*****hook record features*****

6.3 把hook记录的和forward做减法

如果害怕会有小数点后面的数值不一致,或者数据类型的不匹配,可以对hook记录的特征和forward记录的特征做减法:

测试forward返回的feautes_in是不是和hook记录的一致:

?
1
2
3
print("sub result'")
for forward_return, hook_record in zip(features_in_forward, features_in_hook):
    print(forward_return-hook_record[0])

得到的全部都是0,说明hook没问题:

?
1
2
3
4
5
6
7
sub result
tensor([[0., 0.],
        [0., 0.]])
tensor([[0., 0.],
        [0., 0.]], grad_fn=<SubBackward0>)
tensor([[0.],
        [0.]], grad_fn=<SubBackward0>)

7、完整代码

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
import torch.nn as nn
 
 
class TestForHook(nn.Module):
    def __init__(self):
        super().__init__()
 
        self.linear_1 = nn.Linear(in_features=2, out_features=2)
        self.linear_2 = nn.Linear(in_features=2, out_features=1)
        self.relu = nn.ReLU()
        self.relu6 = nn.ReLU6()
        self.initialize()
 
    def forward(self, x):
        linear_1 = self.linear_1(x)
        linear_2 = self.linear_2(linear_1)
        relu = self.relu(linear_2)
        relu_6 = self.relu6(relu)
        layers_in = (x, linear_1, linear_2)
        layers_out = (linear_1, linear_2, relu)
        return relu_6, layers_in, layers_out
 
    def initialize(self):
        """ 定义特殊的初始化,用于验证是不是获取了权重"""
        self.linear_1.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1], [1, 1]]))
        self.linear_1.bias = torch.nn.Parameter(torch.FloatTensor([1, 1]))
        self.linear_2.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1]]))
        self.linear_2.bias = torch.nn.Parameter(torch.FloatTensor([1]))
        return True

定义用于获取网络各层输入输出tensor的容器,并定义module_name用于记录相应的module名字

?
1
2
3
module_name = []
features_in_hook = []
features_out_hook = []

hook函数负责将获取的输入输出添加到feature列表中,并提供相应的module名字

?
1
2
3
4
5
6
def hook(module, fea_in, fea_out):
    print("hooker working")
    module_name.append(module.__class__)
    features_in_hook.append(fea_in)
    features_out_hook.append(fea_out)
    return None

定义全部是1的输入:

?
1
x = torch.FloatTensor([[0.1, 0.1], [0.1, 0.1]])

注册钩子可以对某些层单独进行:

?
1
2
3
4
5
net = TestForHook()
net_chilren = net.children()
for child in net_chilren:
    if not isinstance(child, nn.ReLU6):
        child.register_forward_hook(hook=hook)

测试网络输出:

out, features_in_forward, features_out_forward = net(x)
print("*"*5+"forward return features"+"*"*5)
print(features_in_forward)
print(features_out_forward)
print("*"*5+"forward return features"+"*"*5)

测试features_in是不是存储了输入:

?
1
2
3
4
5
print("*"*5+"hook record features"+"*"*5)
print(features_in_hook)
print(features_out_hook)
print(module_name)
print("*"*5+"hook record features"+"*"*5)

测试forward返回的feautes_in是不是和hook记录的一致:

print("sub result")
for forward_return, hook_record in zip(features_in_forward, features_in_hook):
    print(forward_return-hook_record[0])

 到此这篇关于pytorch中的hook机制register_forward_hook的文章就介绍到这了,更多相关pytorch中的hook机制内容请搜索服务器之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持服务器之家!

原文链接:https://zhuanlan.zhihu.com/p/87853615

延伸 · 阅读

精彩推荐
  • Pythonpython的变量与赋值详细分析

    python的变量与赋值详细分析

    这篇文章主要介绍了python的变量与赋值详细分析,具有一定参考价值,需要的朋友可以了解下。...

    ZHANGONE5752020-12-16
  • Python利用Python实现端口扫描器的全过程

    利用Python实现端口扫描器的全过程

    这篇文章主要给大家介绍了关于如何利用Python实现端口扫描器的相关资料,用来检测目标服务器上有哪些端口开放,本文适用于有 Python和计算机网络语言基础...

    tigeriaf8842021-12-20
  • PythonPycharm挂代理后依旧插件下载慢的完美解决方法

    Pycharm挂代理后依旧插件下载慢的完美解决方法

    狠多朋友在使用Pycharm插件时,反应下载速度很慢,挂载了代理还是不够,怎么解决这一问题呢,下面小编给大家代理了Pycharm插件下载慢的完美解决方法,...

    Mr..Nobody5242021-12-17
  • PythonPython3爬虫使用Fidder实现APP爬取示例

    Python3爬虫使用Fidder实现APP爬取示例

    这篇文章主要介绍了Python3爬虫使用Fidder实现APP爬取示例,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧...

    TM08317522021-04-22
  • Pythonpython 每天如何定时启动爬虫任务(实现方法分享)

    python 每天如何定时启动爬虫任务(实现方法分享)

    python 每天如何定时启动爬虫任务?今天小编就为大家分享一篇python 实现每天定时启动爬虫任务的方法。具有很好的参考价值,希望对大家有所帮助。一起...

    大蛇王9182021-02-22
  • PythonPython中单例模式总结

    Python中单例模式总结

    单例模式(Singleton Pattern)是一种常用的软件设计模式,该模式的主要目的是确保某一个类只有一个实例存在。当你希望在整个系统中,某个类只能出现一...

    孟庆健4512021-01-16
  • Pythonvirtualenv 指定 python 解释器的版本方法

    virtualenv 指定 python 解释器的版本方法

    今天小编就为大家分享一篇virtualenv 指定 python 解释器的版本方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...

    Inside_Zhang9482021-04-12
  • PythonPython实现检测服务器是否可以ping通的2种方法

    Python实现检测服务器是否可以ping通的2种方法

    这篇文章主要介绍了Python实现检测服务器是否可以ping通的2种方法,本文分别讲解了使用ping和fping命令检测服务器是否可以ping通,需要的朋友可以参考下 ...

    脚本之家13452020-05-18