脚本之家,脚本语言编程技术及教程分享平台!
分类导航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|

服务器之家 - 脚本之家 - Python - PyTorch实现MNIST数据集手写数字识别详情

PyTorch实现MNIST数据集手写数字识别详情

2022-09-07 14:00长浔 Python

这篇文章主要介绍了PyTorch实现MNIST数据集手写数字识别详情,文章围绕主题展开详细的内容戒杀,具有一定的参考价值,需要的朋友可以参考一下

前言:

本篇文章基于卷积神经网络CNN,使用PyTorch实现MNIST数据集手写数字识别。

一、PyTorch是什么?

PyTorch 是一个 Torch7 团队开源的 Python 优先的深度学习框架,提供两个高级功能:

  • 强大的 GPU 加速 Tensor 计算(类似 numpy)
  • 构建基于 tape 的自动升级系统上的深度神经网络

你可以重用你喜欢的 python 包,如 numpy、scipy 和 Cython ,在需要时扩展 PyTorch。

二、程序示例

下面案例可供运行参考

1.引入必要库

?
1
2
3
4
import torchvision
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F

2.下载数据集

这里设置download=True,将会自动下载数据集,并存储在./data文件夹。

?
1
2
train_data = torchvision.datasets.MNIST(root="./data",train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.MNIST(root="./data",train=False,transform=torchvision.transforms.ToTensor(),download=True)

3.加载数据集

batch_size=32表示每一个batch中包含32张手写数字图片,shuffle=True表示打乱测试集(data和target仍一一对应)

?
1
2
train_loader = DataLoader(train_data,batch_size=32,shuffle=True)
test_loader = DataLoader(test_data,batch_size=32,shuffle=False)

4.搭建CNN模型并实例化

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class Net(torch.nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.con1 = torch.nn.Conv2d(1,10,kernel_size=5)
        self.con2 = torch.nn.Conv2d(10,20,kernel_size=5)
        self.pooling = torch.nn.MaxPool2d(2)
        self.fc = torch.nn.Linear(320,10)
    def forward(self,x):
        batch_size = x.size(0)
        x = F.relu(self.pooling(self.con1(x)))
        x = F.relu(self.pooling(self.con2(x)))
        x = x.view(batch_size,-1)
        x = self.fc(x)
        return x
#模型实例化       
model = Net()

5.交叉熵损失函数损失函数及SGD算法优化器

?
1
2
lossfun = torch.nn.CrossEntropyLoss()
opt = torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.5)

6.训练函数

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
def train(epoch):
    running_loss = 0.0
    for i,(inputs,targets) in enumerate(train_loader,0):
        # inputs,targets = inputs.to(device),targets.to(device)
        opt.zero_grad()
        outputs = model(inputs)
        loss = lossfun(outputs,targets)
        loss.backward()
        opt.step()
 
        running_loss += loss.item()
        if i % 300 == 299:
            print('[%d,%d] loss:%.3f' % (epoch+1,i+1,running_loss/300))
            running_loss = 0.0

7.测试函数

?
1
2
3
4
5
6
7
8
9
10
11
def test():
    total = 0
    correct = 0
    with torch.no_grad():
        for (inputs,targets) in test_loader:
            # inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _,predicted = torch.max(outputs.data,dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    print(100*correct/total)

8.运行

?
1
2
3
4
if __name__ == '__main__':
    for epoch in range(20):
        train(epoch)
        test()

三、总结

到此这篇关于PyTorch实现MNIST数据集手写数字识别详情的文章就介绍到这了,更多相关PyTorch MNIST 内容请搜索服务器之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持服务器之家!

原文链接:https://blog.csdn.net/qq_41664447/article/details/126698428

延伸 · 阅读

精彩推荐
  • Pythonpython编程实现12306的一个小爬虫实例

    python编程实现12306的一个小爬虫实例

    这篇文章主要介绍了python编程实现12306的一个小爬虫实例,具有一定借鉴价值,需要的朋友可以参考下。...

    sentimental_dog7112020-12-29
  • Pythonpython 批量压缩图片的脚本

    python 批量压缩图片的脚本

    用Python编写的批量压缩图片的脚本,可以自定义压缩质量,有批量图片压缩需求的朋友可以直接拿来用...

    Mario-Hero5392021-11-19
  • PythonPython中字符编码简介、方法及使用建议

    Python中字符编码简介、方法及使用建议

    这篇文章主要介绍了Python中字符编码简介、方法及使用建议,需要的朋友可以参考下 ...

    脚本之家5002020-05-18
  • Pythonpython算法表示概念扫盲教程

    python算法表示概念扫盲教程

    这篇文章主要为大家详细介绍了python算法表示概念扫盲教程,具有一定的参考价值,感兴趣的小伙伴们可以参考一下...

    金角大王2472020-09-29
  • Python详解Pandas 处理缺失值指令大全

    详解Pandas 处理缺失值指令大全

    这篇文章主要介绍了详解Pandas 处理缺失值指令大全,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下...

    fitness suite9992020-07-30
  • PythonPython中常见的数据类型小结

    Python中常见的数据类型小结

    这篇文章主要对Python中常见的数据类型进行了总结归纳,很有参考借鉴价值,需要的朋友可以参考下...

    Python教程网5682020-07-30
  • PythonPython简单实现控制电脑的方法

    Python简单实现控制电脑的方法

    这篇文章主要介绍了Python简单实现控制电脑的方法,涉及Python基于os及win32api等模块调用系统命令操作电脑的相关实现技巧,需要的朋友可以参考下...

    Lovephysics15842021-01-07
  • PythonPython3 修改默认环境的方法

    Python3 修改默认环境的方法

    今天小编就为大家分享一篇Python3 修改默认环境的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...

    极客点儿7722021-05-29