文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

Python tensorflow与pytorch的浮点运算数怎么计算

2023-07-04 15:24

关注

这篇文章主要讲解了“Python tensorflow与pytorch的浮点运算数怎么计算”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“Python tensorflow与pytorch的浮点运算数怎么计算”吧!

1. 引言

FLOPs 是 floating point operations 的缩写,指浮点运算数,可以用来衡量模型/算法的计算复杂度。本文主要讨论如何在 tensorflow 1.x, tensorflow 2.x 以及 pytorch 中利用相关工具计算对应模型的 FLOPs。

2. 模型结构

为了说明方便,先搭建一个简单的神经网络模型,其模型结构以及主要参数如表1 所示。

表 1 模型结构及主要参数

LayerschannelsKernelsStridesUnitsActivation
Conv2D32(4,4)(1,2)\relu
GRU\\\96\
Dense\\\256sigmoid

用 tensorflow(实际使用 tensorflow 中的 keras 模块)实现该模型的代码为:

from tensorflow.keras.layers import *from tensorflow.keras.models import load_model, Modeldef test_model_tf(Input_shape):    # shape: [B, C, T, F]    main_input = Input(batch_shape=Input_shape, name='main_inputs')    conv = Conv2D(32, kernel_size=(4, 4), strides=(1, 2), activation='relu', data_format='channels_first', name='conv')(main_input)    # shape: [B, T, FC]    gru = Reshape((conv.shape[2], conv.shape[1] * conv.shape[3]))(conv)    gru = GRU(units=96, reset_after=True, return_sequences=True, name='gru')(gru)    output = Dense(256, activation='sigmoid', name='output')(gru)    model = Model(inputs=[main_input], outputs=[output])    return model

用 pytorch 实现该模型的代码为:

import torchimport torch.nn as nnclass test_model_torch(nn.Module):    def __init__(self):        super(test_model_torch, self).__init__()        self.conv2d = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(4,4), stride=(1,2))        self.relu = nn.ReLU()        self.gru = nn.GRU(input_size=4064, hidden_size=96)        self.fc = nn.Linear(96, 256)        self.sigmoid = nn.Sigmoid()    def forward(self, inputs):        # shape: [B, C, T, F]        out = self.conv2d(inputs)        out = self.relu(out)        # shape: [B, T, FC]        batch, channel, frame, freq = out.size()        out = torch.reshape(out, (batch, frame, freq*channel))        out, _ = self.gru(out)        out = self.fc(out)        out = self.sigmoid(out)        return out

3. 计算模型的 FLOPs

本节讨论的版本具体为:tensorflow 1.12.0, tensorflow 2.3.1 以及 pytorch 1.10.1+cu102。

3.1. tensorflow 1.12.0

在 tensorflow 1.12.0 环境中,可以使用以下代码计算模型的 FLOPs:

import tensorflow as tfimport tensorflow.keras.backend as Kdef get_flops(model):    run_meta = tf.RunMetadata()    opts = tf.profiler.ProfileOptionBuilder.float_operation()    flops = tf.profiler.profile(graph=K.get_session().graph,                                run_meta=run_meta, cmd='op', options=opts)    return flops.total_float_opsif __name__ == "__main__":    x = K.random_normal(shape=(1, 1, 100, 256))    model = test_model_tf(x.shape)    print('FLOPs of tensorflow 1.12.0:', get_flops(model))

3.2. tensorflow 2.3.1

在 tensorflow 2.3.1 环境中,可以使用以下代码计算模型的 FLOPs :

import tensorflow.compat.v1 as tfimport tensorflow.compat.v1.keras.backend as Ktf.disable_eager_execution()def get_flops(model):    run_meta = tf.RunMetadata()    opts = tf.profiler.ProfileOptionBuilder.float_operation()    flops = tf.profiler.profile(graph=K.get_session().graph,                                run_meta=run_meta, cmd='op', options=opts)    return flops.total_float_opsif __name__ == "__main__":    x = K.random_normal(shape=(1, 1, 100, 256))    model = test_model_tf(x.shape)    print('FLOPs of tensorflow 2.3.1:', get_flops(model))

3.3. pytorch 1.10.1+cu102

在 pytorch 1.10.1+cu102 环境中,可以使用以下代码计算模型的 FLOPs(需要安装 thop):

import thopx = torch.randn(1, 1, 100, 256)model = test_model_torch()flops, _ = thop.profile(model, inputs=(x,))print('FLOPs of pytorch 1.10.1:', flops * 2)

需要注意的是,thop 返回的是 MACs (Multiply–Accumulate Operations),其等于 2 2 2 倍的 FLOPs,所以上述代码有乘 2 2 2 操作。

3.4. 结果对比

三者计算出的 FLOPs 分别为:

tensorflow 1.12.0:

Python tensorflow与pytorch的浮点运算数怎么计算

tensorflow 2.3.1:

Python tensorflow与pytorch的浮点运算数怎么计算

pytorch 1.10.1:

Python tensorflow与pytorch的浮点运算数怎么计算

可以看到 tensorflow 1.12.0 和 tensorflow 2.3.1 的结果基本在同一个量级,而与 pytorch 1.10.1 计算出来的相差甚远。但如果将上述模型结构改为只包含第一层 Conv2D,三者计算出来的 FLOPs 却又是一致的。所以推断差异主要来自于 GRU 的 FLOPs。

感谢各位的阅读,以上就是“Python tensorflow与pytorch的浮点运算数怎么计算”的内容了,经过本文的学习后,相信大家对Python tensorflow与pytorch的浮点运算数怎么计算这一问题有了更深刻的体会,具体使用情况还需要大家实践验证。这里是编程网,小编将为大家推送更多相关知识点的文章,欢迎关注!

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯