pytorch-手写数字识别之全连接层实现

目录

  • 1. 背景
  • 2. nn.Linear线性层
  • 2. 实现MLP网络
  • 3. train
  • 4. 完整代码

1. 背景

上一篇https://blog.csdn.net/wyw0000/article/details/137622977?spm=1001.2014.3001.5502中实现手撸代码的方式实现了手写数字识别,本文将使用pytorch的API实现。

2. nn.Linear线性层

相当于x = x@w1.t() + b1
因此可以使用nn.Linear代替,上一篇文中中的

w1, b1 = torch.randn(200, 784, requires_grad=True),\torch.zeros(200, requires_grad=True)
x = x@w1.t() + b1

使用nn.Linear定义的三层网络,如下图所示:
在这里插入图片描述
增加激活函数relu
在这里插入图片描述

2. 实现MLP网络

  • 实现__init__函数
  • 将网络各层放到序列化容器Sequential中
    见下图使用nn.Sequential创建一个小的model,运行时输入首先传给nn.Linear(784, 200),nn.Linear(784, 200)的输出再传给nn.ReLU(inplace=True),这样依次传递下去,直至结束。
  • 实现forward
    调用将输入x作为model的参数调用model,并将结果返回。
    在这里插入图片描述

3. train

在这里插入图片描述

4. 完整代码

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transformsbatch_size=200
learning_rate=0.01
epochs=10train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.model = nn.Sequential(nn.Linear(784, 200),nn.ReLU(inplace=True),nn.Linear(200, 200),nn.ReLU(inplace=True),nn.Linear(200, 10),nn.ReLU(inplace=True),)def forward(self, x):x = self.model(x)return xnet = MLP()
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criteon = nn.CrossEntropyLoss()for epoch in range(epochs):for batch_idx, (data, target) in enumerate(train_loader):data = data.view(-1, 28*28)logits = net(data)loss = criteon(logits, target)optimizer.zero_grad()loss.backward()# print(w1.grad.norm(), w2.grad.norm())optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))test_loss = 0correct = 0for data, target in test_loader:data = data.view(-1, 28 * 28)logits = net(data)test_loss += criteon(logits, target).item()pred = logits.data.max(1)[1]correct += pred.eq(target.data).sum()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/825664.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

开发一个农场小游戏需要多少钱

开发一个农场小游戏的费用因多个因素而异,包括但不限于游戏的规模、复杂性、功能需求、设计复杂度、开发团队的规模和经验,以及项目的时间周期等。因此,无法给出确切的费用数字。 具体来说,游戏的复杂程度和包含的功能特性数量会直…

企业文档知识库建设,数据安全如何保障?

随着现代市场经济的高速发展,企业的竞争优势越来越多体现在人才和科技的优势。而随着员工流动率的提升,随之流失的则是员工积累多年的宝贵工作经验,如果缺乏有效的内部知识库的建设和管理,企业的竞争优势将难以维系。「企业网盘」…

Claude和chatgpt的区别

ChatGPT是OpenAI开发的人工智能的聊天机器人,它可以生成文章、代码并执行各种任务。是Open AI发布的第一款大语言模型,GPT4效果相比chatgpt大幅提升。尤其是最新版的模型,OpenAI几天前刚刚发布的GPT-4-Turbo-2024-04-09版本,大幅超…

架构设计-流程引擎的架构设计

1、什么是流程引擎 流程引擎是一个底层支撑平台,是为提供流程处理而开发设计的。流程引擎和流程应用,以及应用程序的关系如下图所示。 常见的支撑场景有:Workflow、BPM、流程编排等。本次分享,主要从 BPM 流程引擎切入&#xff0…

【前端】3. CSS【万字长文】

CSS 是什么 层叠样式表 (Cascading Style Sheets). CSS 能够对网页中元素位置的排版进行像素级精确控制, 实现美化页面的效果. 能够做到页面的样式和结构分离. CSS 就是 “东方四大邪术” 之化妆术. 基本语法规范 选择器 {一条/N条声明} 选择器决定针对谁修改 (找谁)声明决…

钉钉直播回放怎么下载到本地

钉钉直播回放如何下载到本地,本文就给大家解密如何下载到本地 工具我已经给大家打包好了 钉钉直播回放下载软件链接:https://pan.baidu.com/s/1_4NZLfENDxswI2ANsQVvpw?pwd1234 提取码:1234 --来自百度网盘超级会员V10的分享 1.首先解压好我给大家…

【Qt】Qt Hello World 程序

文章目录 1、Qt Hello World 程序1.1 使用按钮实现1.1.1 使用可视化方式实现 1.1.2 纯代码方式实现 label创建堆(内存泄漏)或者栈问题Qt基础类(Qstring、Qvector、Qlist)乱码问题零散知识 1、Qt Hello World 程序 1.1 使用按钮实…

Swin Transformer 浅析

Swin Transformer 浅析 文章目录 Swin Transformer 浅析引言Swin Transformer 的网络结构W-MSA 窗口多头注意力机制SW-MSA 滑动窗口多头注意力机制Patch Merging 图块合并 引言 因为ViT无法实现CNN中的层次化构建以及局部信息,由此微软团队提出了Swin Transformer来…

C语言(二维数组)

Hi~!这里是奋斗的小羊,很荣幸各位能阅读我的文章,诚请评论指点,关注收藏,欢迎欢迎~~ 💥个人主页:小羊在奋斗 💥所属专栏:C语言 本系列文章为个人学习笔记&#x…

15.7 2011年42题真题讲解

2,4,6,8,11,13,15,17,19,20 可以推出题目的一个隐含条件:偶数个元素的中位数是靠前的那一个 应试技巧:如果实在想不出高效的算法,那…

基于springboot+vue+Mysql的房产销售平台

开发语言:Java框架:springcloudJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包&#xff1a…

【详细讲解CentOS常用的命令】

🌈个人主页: 程序员不想敲代码啊 🏆CSDN优质创作者,CSDN实力新星,CSDN博客专家 👍点赞⭐评论⭐收藏 🤝希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共…

SQLite FTS5 扩展(三十)

返回:SQLite—系列文章目录 上一篇:SQLite的知名用户(二十九) 下一篇:SQLite—系列文章目录 1. FTS5概述 FTS5 是一个 SQLite 虚拟表模块,它为数据库应用程序提供全文搜索功能。在最基本的形式中, 全文搜索引擎允许用户有…

Dinov2 + Faiss 图片检索

MetaAI 通过开源 DINOv2,在计算机视觉领域取得了一个显着的里程碑,这是一个在包含1.42 亿张图像的令人印象深刻的数据集上训练的模型。产生适用于图像级视觉任务(图像分类、实例检索、视频理解)以及像素级视觉任务(深度…

【leetcode面试经典150题】57. 环形链表(C++)

【leetcode面试经典150题】专栏系列将为准备暑期实习生以及秋招的同学们提高在面试时的经典面试算法题的思路和想法。本专栏将以一题多解和精简算法思路为主,题解使用C语言。(若有使用其他语言的同学也可了解题解思路,本质上语法内容一致&…

vivado 使用 JTAG-to-AXI Master 调试核进行硬件系统通信

使用 JTAG-to-AXI Master 调试核进行硬件系统通信 JTAG-to-AXI Master 调试核为可自定义核 , 可在运行时生成 AXI 传输事务并驱动 FPGA 内部的 AXI 信号。该核支持所 有存储器映射型 AXI 接口和 AXI4-Lite 接口 , 并且可支持位宽为 32 或 64 …

免费的数据恢复软件有哪些?推荐10款免费的数据恢复软件!

通过使用功能强大的免费和最好的数据恢复软件,您可以取消删除重要文件和文档。丢失文件是每个人在许多情况下面临的常见问题,这些数据恢复程序可以充当完美的救星。 我们编制了 2024年的最佳软件列表。这些工具易于使用,您可以通过如何在 PC…

Spring Boot 目前还是最先进的吗?

当谈到现代Java开发框架时,Spring Boot一直处于领先地位。它目前不仅是最先进的,而且在Java生态系统中拥有着巨大的影响力。 1. 什么是Spring Boot? Spring Boot是由Spring团队开发的开源框架,旨在简化基于Spring的应用程序的开…

接收区块链的CCF会议--ICSOC 2024 截止7.24

ICSOC是CCF B类会议(软件工程/系统软件/程序设计语言) 2023年长文短文录用率22% Focus Area 4: Emerging Technologies Quantum Service Computing Digital Twins 3D Printing/additive Manufacturing Techniques Blockchain Robotic Process Autom…

N元语言模型

第1关:预测句子概率 任务描述 本关任务:利用二元语言模型计算句子的概率 相关知识 为了完成本关任务,你需要掌握:1.条件概率计算方式。 2.二元语言模型相关知识。 条件概率计算公式 条件概率是指事件A在事件B发生的条件下发…