脚本之家,脚本语言编程技术及教程分享平台!
分类导航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|

服务器之家 - 脚本之家 - Python - pytorch和tensorflow计算Flops和params的详细过程

pytorch和tensorflow计算Flops和params的详细过程

2022-08-17 16:55qq_40840829 Python

这篇文章主要介绍了pytorch和tensorflow计算Flops和params,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下

pytorch和tensorflow计算Flops和params

1.只计算params

?
1
2
3
net = model()  # 定义好的网络模型
total = sum([param.nelement() for param in net.parameters()])
print("Number of parameter: %.2fM" % total)

这是网上很常见的直接用自带方法计算params,基本不会出错。胜在简洁。

2.计算flops和params

要计算flops,目前没见到用自带方法计算的,基本都是要安装别的库。
这边我们安装thop库。

?
1
pip install thop # 安装thop库
?
1
2
3
4
5
6
7
8
import torch
from thop import profile
net = model()  # 定义好的网络模型
img1 = torch.randn(1, 3, 512, 512)
img2 = torch.randn(1, 3, 512, 512)
img3 = torch.randn(1, 3, 512, 512)
macs, params = profile(net, (img1,img2,img3))
print('flops: ', 2*macs, 'params: ', params)

这边和其他网上教程的区别便是,他们macs和flops不分。因为macs表示乘加累积操作数一个乘法加上一个加法才算一个macs。而flops表示浮点运算次数,每一个加、减、乘、除操作都算1FLOPs操作。所以很明显,在数值上,1flops=2macs。此外,(img1,img2,img3)就表示你如果有三个输入要输入模型,就这样写

另外,要注意,params只和模型参数量相关,而和输入tensor大小无关。但flops和输入图片大小是相关的.

3.tensorflow计算params和flops

此处是我找到的一些用于tensorflow计算params和flops的方法,仅供参考,不保证效果。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def get_flops_params():
    sess = tf.compat.v1.Session()
    graph = sess.graph
    flops = tf.compat.v1.profiler.profile(graph, options=tf.compat.v1.profiler.ProfileOptionBuilder.float_operation())
    params = tf.compat.v1.profiler.profile(graph,
                                           options=tf.compat.v1.profiler.ProfileOptionBuilder.trainable_variables_parameter())
    print('FLOPs: {};    Trainable params: {}'.format(flops.total_float_ops, params.total_parameters))
def count2():
    print(np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]))
def get_nb_params_shape(shape):
    '''
    Computes the total number of params for a given shap.
    Works for any number of shapes etc [D,F] or [W,H,C] computes D*F and W*H*C.
    '''
    nb_params = 1
    for dim in shape:
        nb_params = nb_params * int(dim)
    return nb_params
def count3():
    tot_nb_params = 0
    for trainable_variable in tf.trainable_variables():
        shape = trainable_variable.get_shape()  # e.g [D,F] or [W,H,C]
        current_nb_params = get_nb_params_shape(shape)
        tot_nb_params = tot_nb_params + current_nb_params
    print(tot_nb_params)
import tensorflow.compat.v1 as tf
tf.compat.v1.disable_eager_execution()
from model import Model
import keras.backend as K
def get_flops(model):
    run_meta = tf.RunMetadata()
    opts = tf.profiler.ProfileOptionBuilder.float_operation()
    # We use the Keras session graph in the call to the profiler.
    flops = tf.profiler.profile(graph=K.get_session().graph,
                                run_meta=run_meta, cmd='op', options=opts)
    return flops.total_float_ops  # Prints the "flops" of the model.
# .... Define your model here ....
M = Model(BATCH_SIZE=1, INPUT_H=268, INPUT_W=360, is_training=False)
print(get_flops(M))

到此这篇关于pytorch和tensorflow计算Flops和params的文章就介绍到这了,更多相关pytorch和tensorflow计算内容请搜索服务器之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持服务器之家!

原文链接:https://blog.csdn.net/qq_40840829/article/details/126334037

延伸 · 阅读

精彩推荐
  • PythonLyScript获取上一条与下一条汇编指令的方法详解

    LyScript获取上一条与下一条汇编指令的方法详解

    LyScript 插件默认并没有提供上一条与下一条汇编指令的获取功能,当然你可以使用LyScriptTools工具包直接调用内置命令得到,本文就为大家详细讲讲如何实现...

    lyshark4652022-07-28
  • PythonPython检测网站链接是否已存在

    Python检测网站链接是否已存在

    Python是一种解释型、面向对象、动态数据类型的高级程序设计语言。通过本文给大家介绍Python检测网站链接是否已存在的相关内容,需要的朋友一起学习吧...

    jerrylsxu10492020-08-18
  • Pythonpython docx 中文字体设置的操作方法

    python docx 中文字体设置的操作方法

    今天小编就为大家分享一篇python docx 中文字体设置的操作方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...

    bsh_csn9052021-02-19
  • Pythonpython测试框架unittest和pytest区别

    python测试框架unittest和pytest区别

    这篇文章主要介绍了python测试框架unittest和pytest区别,帮助大家更好的理解和学习使用python进行自动化测试,感兴趣的朋友可以了解下...

    蘇小柒10142021-10-15
  • PythonPython同时处理多个异常的方法

    Python同时处理多个异常的方法

    这篇文章主要介绍了Python同时处理多个异常的方法,文中讲解非常细致,代码帮助大家更好的理解和学习,感兴趣的朋友可以了解下...

    David Beazley7362020-07-29
  • Python在Python中操作时间之tzset()方法的使用教程

    在Python中操作时间之tzset()方法的使用教程

    这篇文章主要介绍了在Python中操作时间之tzset()方法的使用教程,是Python学习中的基础知识,需要的朋友可以参考下...

    Python教程网3732020-07-06
  • Python详解opencv Python特征检测及K-最近邻匹配

    详解opencv Python特征检测及K-最近邻匹配

    这篇文章主要介绍了详解opencv Python特征检测及K-最近邻匹配,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧...

    天煞孤星0严11282021-05-19
  • Pythonpython根据url地址下载小文件的实例

    python根据url地址下载小文件的实例

    今天小编就为大家分享一篇python根据url地址下载小文件的实例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...

    insisted_search13302021-05-03