【深度学习】在 MNIST实现自动编码器实践教程

一、说明

        自动编码器是一种无监督学习的神经网络模型,主要用于降维或特征提取。常见的自动编码器包括基本的单层自动编码器、深度自动编码器、卷积自动编码器和变分自动编码器等。

        其中,基本的单层自动编码器由一个编码器和一个解码器组成,编码器将输入数据压缩成低维数据,解码器将低维数据还原成原始数据。深度自动编码器是在单层自动编码器的基础上增加了多个隐藏层,可以实现更复杂的特征提取。卷积自动编码器则是针对图像等数据特征提取的一种自动编码器,它使用卷积神经网络进行特征提取和重建。变分自动编码器则是一种生成式模型,可以用于生成新的数据样本。

        总的来说,不同类型的自动编码器适用于不同类型的数据和问题,选择合适的自动编码器可以提高模型的性能。

二、在Minist数据集实现自动编码器

2.1 概述

        本文中的代码用于在 MNIST 数据集上训练自动编码器。自动编码器是一种旨在重建其输入的神经网络。在此脚本中,自动编码器由两个较小的网络组成:编码器和解码器。编码器获取输入图像,将其压缩为 64 个特征,并将编码表示传递给解码器,然后解码器重建输入图像。自动编码器通过最小化重建图像和原始图像之间的均方误差来训练。该脚本首先加载 MNIST 数据集并规范化像素值。然后,它将图像重塑为一维表示,以便可以将其输入神经网络。之后,使用tensorflow.keras库中的输入层和密集层创建编码器和解码器模型。自动编码器模型是通过链接编码器和解码器模型创建的。然后使用亚当优化器和均方误差损失函数编译自动编码器。最后,自动编码器在归一化和重塑的MNIST图像上训练25个epoch。通过绘制训练集和测试集在 epoch 上的损失来监控训练进度。训练后,脚本绘制一些测试图像及其相应的重建。此外,还计算了原始图像和重建图像之间的均方误差和结构相似性指数(SSIM)。

        下图显示了模型的良好拟合,可以看到模型的良好拟合。

训练和测试数据的模型丢失

        该代码比较两个图像,一个来自测试集的原始图像和一个由自动编码器生成的预测图像。它使用该函数计算两个图像之间的均方误差 (MSE),并使用 scikit-image 库中的函数计算两个图像之间的结构相似性指数 (SSIM)。根据 mse 和 ssim 代码检索test_labels以打印测试图像的值。msessim

2.2 代码实现

import numpy as np
import tensorflow
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Flatten
from tensorflow.keras.layers import Layer 
from skimage import metrics
## import os can be skipped if there is nocompatibility issue 
## with the OpenMP library and TensorFlow 
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"# Load the MNIST dataset
(x_train, train_labels), (x_test, test_labels) = mnist.load_data()# Normalize the data
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.# Flatten the images
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))# Randomize both the training and test
permutation = np.random.permutation(len(x_train))
x_train, train_labels = x_train[permutation], train_labels[permutation]
permutation = np.random.permutation(len(x_test))
x_test, test_labels = x_test[permutation], test_labels[permutation]
# Create the encoderlist_xtest = [ [x_test[i], test_labels[i]] for i in test_labels] 
print(len(list_xtest)) encoder_input = Input(shape=(784,))
encoded = Dense(64, activation='relu')(encoder_input)
encoder = Model(encoder_input, encoded)# Create the decoder
decoder_input = Input(shape=(64,))
decoded = Dense(784, activation='sigmoid')(decoder_input)
decoder = Model(decoder_input, decoded)# Create the autoencoder
autoencoder = Model(encoder_input, decoder(encoder(encoder_input)))lr_schedule = tensorflow.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate = 5e-01, decay_steps = 2500, decay_rate = 0.75,staircase=True) 
tensorflow.keras.optimizers.Adam(learning_rate = lr_schedule,beta_1=0.95,beta_2=0.99,epsilon=1e-01)
autoencoder.compile(optimizer='adam', loss='mean_squared_error')# Train the autoencoder
history = autoencoder.fit(x_train, x_train,epochs=25,batch_size=512,shuffle=True,validation_data=(x_test, x_test))# Plot the training history
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper right')
plt.show()# Plot the test figures vs. predicted figures
decoded_imgs = autoencoder.predict(x_test)def mse(imageA, imageB):err = np.sum((imageA.astype("float") - imageB.astype("float")) ** 2)err /= float(imageA.shape[0])return errdef ssim(imageA, imageB):return metrics.structural_similarity(imageA, imageB,channel_axis=None)decomser = [] 
decossimr = [] 
n = 10
list_xtestn = [ [x_test[i], test_labels[i]] for i in range(10)] 
print([list_xtestn[i][1] for i in range(n)]) 
plt.figure(figsize=(20, 4))
for i in range(n):# Display originalax = plt.subplot(2, n, i + 1)plt.imshow(x_test[i].reshape(28, 28))plt.gray()ax.get_xaxis().set_visible(False)ax.get_yaxis().set_visible(False)# Display reconstructionax = plt.subplot(2, n, i + 1 + n)plt.imshow(decoded_imgs[i].reshape(28, 28))plt.gray()ax.get_xaxis().set_visible(False)ax.get_yaxis().set_visible(False)if mse(list_xtestn[i][0],decoded_imgs[i]) <= 0.01: msel = mse(list_xtestn[i][0],decoded_imgs[i])decomser.append(list_xtestn[i][1])  if ssim(list_xtestn[i][0],decoded_imgs[i]) > 0.85:ssiml = ssim(list_xtestn[i][0],decoded_imgs[i])decossimr.append(list_xtestn[i][1])   print("mse and ssim for image %s are %s and %s" %(i,msel,ssiml)) 
plt.show() print(decomser)
print(decossimr)

三、实验的部分结果示例 

        该模型可以预测手写数据,如下所示。

原始数据和预测数据

        此外,使用MSE和ssim方法将预测图像与测试图像进行比较,可以访问test_labels并打印预测数据。

预测和测试图像的 MSE 和 SSM 值,以及 SSE 和 SSIM 方法test_labels返回的数字列表

        此代码演示如何使用自动编码器通过图像比较教程来训练和建立手写识别网络。一开始,训练和测试图像是随机的,因此每次运行的图像集都不同。

        在另一篇文章中,我们将展示如何使用 Padé 近似值作为自动编码器 (link.medium.com/cqiP5bd9ixb) 的激活函数。

引用:

  1. 原始的MNIST数据集:LeCun,Y.,Cortes,C.和Burges,C.J.(2010)。MNIST手写数字数据库。AT&T 实验室 [在线]。可用: http://yann。莱昆。com/exdb/mnist/
  2. 自动编码器概念和应用:Hinton,G.E.和Salakhutdinov,R.R.(2006)。使用神经网络降低数据的维数。科学, 313(5786), 504–507.
  3. 使用自动编码器进行图像重建:Masci,J.,Meier,U.,Cireşan,D.和Schmidhuber,J.(2011年52月)。用于分层特征提取的堆叠卷积自动编码器。在人工神经网络国际会议(第 59-<> 页)中。施普林格,柏林,海德堡。
  4. The tensorflow.keras library: Chollet, F. (2018).使用 Python 进行深度学习。纽约州谢尔特岛:曼宁出版公司
  5. 均方误差损失函数和亚当优化器:Kingma,D.P.和Ba,J.(2014)。Adam:一种随机优化的方法。arXiv预印本arXiv:1412.6980。
  6. 结构相似性指数(SSIM):Wang,Z.,Bovik,A.C.,Sheikh,H.R.和Simoncelli,E.P.(2004)。图像质量评估:从错误可见性到结构相似性。IEEE图像处理事务,13(4),600-612。
  7. 弗朗西斯·贝尼斯坦特

    ·

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/23263.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

K8S 部署 RocketMQ

文章目录 添加模板部署本地访问 集群使用 kubesphere 作为工具 添加模板 添加 helm 模板 helm repo add rocketmq-repo https://helm-charts.itboon.top/rocketmq helm repo update rocketmq-repo编写 value.yaml 文件 配置主从节点的个数&#xff0c;例子为单节点 broker:…

Tool Documentation Enables Zero-Shot Tool-Usage with Large Language Models

本文是LLM系列文章的内容&#xff0c;针对《Tool Documentation Enables Zero-Shot Tool-Usage with Large Language Models》的翻译。 工具文档赋能大模型零样本的工具使用 摘要1 引言2 相关工作3 实验设置3.1 常规的工作流3.2 工具使用提示方法3.3 评估任务 4 实证研究结果4…

node debian 镜像 new Date 获取时间少 8 小时问题

问题 在 node debian 镜像中&#xff0c;用 (new Date()).getHours() 与系统时间&#xff08;东 8 区&#xff09;少了 8 小时 系统时间 $ node > (new Date()).getHours() 11容器中的时间 $ node > (new Date()).getHours() 3原 Dockerfile FROM node:20.5-bullsey…

snap xxx has “install-snap“ change in progress

error description * 系重复安装&#xff0c;进程冲突 solution 展示snap的改变 然后sudo snap abort 22即可终止该进程 之后重新运行install command&#xff5e;&#xff5e; PS: ubuntu有时候加载不出来&#xff0c;执行resolvectl flush-caches&#xff0c;清除dns缓存…

用python编写的软件有哪些,编写python 用什么软件

大家好&#xff0c;小编来为大家解答以下问题&#xff0c;用python编写的软件有哪些&#xff0c;编写python 用什么软件&#xff0c;现在让我们一起来看看吧&#xff01; 随着互联网的迅速发展&#xff0c;新技术不断创新&#xff0c;万物互联的时代&#xff0c;企业对IT人员的…

【博客687】k8s informer的list-watch机制剖析

k8s informer的list-watch机制剖析 1、list-watch场景&#xff1a; client-go中的reflector模块首先会list apiserver获取某个资源的全量信息&#xff0c;然后根据list到的rv来watch资源的增量信息。希望使用client-go编写的控制器组件在与apiserver发生连接异常时&#xff0c…

【51单片机】晨启科技,酷黑版,音乐播放器

四、音乐播放器 任务要求&#xff1a; 设计制作一个简易音乐播放器&#xff08;通过手柄板上的蜂鸣器发声&#xff0c;播放2到4首音乐&#xff09;&#xff0c;同时LED模块闪烁&#xff0c;给人视、听觉美的感受。 评分细则&#xff1a; 按下播放按键A6开始播放音乐&#xff0…

SpringCloud入门Day01-服务注册与发现、服务通信、负载均衡与算法

SpringCloudNetflix入门 一、应用架构的演变 伴随互联网的发展&#xff0c;使用互联网的人群越来越多&#xff0c;软件应用的体量越来越大和复杂。而传统单体应用 可能不足以支撑大数据量以及发哦并发场景应用的框架也随之进行演变从最开始的单体应用架构到分布式&#xff08…

18. SpringBoot 如何在 POM 中引入本地 JAR 包

❤️ 个人主页&#xff1a;水滴技术 &#x1f338; 订阅专栏&#xff1a;成功解决 BUG 合集 &#x1f680; 支持水滴&#xff1a;点赞&#x1f44d; 收藏⭐ 留言&#x1f4ac; Spring Boot 是一种基于 Spring 框架的轻量级应用程序开发框架&#xff0c;它提供了快速开发应用程…

Palo Alto Networks® PA-220R 下一代防火墙 确保恶劣工况下的网络安全

一、主要安全功能 1、每时每刻在各端口对全部应用进行分类 • 将 App-ID 用于工业协议和应用&#xff0c;例如 Modbus、 DNP3、IEC 60870-5-104、Siemens S7、OSIsoft PI 等。 • 不论采用何种端口、SSL/SSH 加密或者其他规避技术&#xff0c;都会识别应用。 • 使用…

面试之快速学习c++11-函数模版的默认模版参数,可变模版,tuple

//学习地址&#xff1a; http://c.biancheng.net/view/3730.html 函数模版的默认模版参数 在 C98/03 标准中&#xff0c;类模板可以有默认的模板参数&#xff0c;如下&#xff1a; template <typename T, typename U int, U N 0> struct TestTemplateStruct {};但是…

CSS水平垂直居中

1.利用定位 margin:auto 2.flex布局 3.grid布局 一、利用positionmargin:auto <style>.outer {position: relative; /*父亲相对定位*/width: 200px;height: 200px;background-color: red;}.inner {position: absolute; /*儿子绝对定位*/top: 0;bottom: 0;left: 0;ri…

Python GUI编程(Tkinter)

Python GUI编程(Tkinter) Python 提供了多个图形开发界面的库&#xff0c;几个常用 Python GUI 库如下&#xff1a; Tkinter&#xff1a; Tkinter 模块(Tk 接口)是 Python 的标准 Tk GUI 工具包的接口 .Tk 和 Tkinter 可以在大多数的 Unix 平台下使用,同样可以应用在 Windows …

【硬件设计】模拟电子基础三--放大电路

模拟电子基础三--放大电路 一、集成运算放大器1.1 定义、组成与性能1.2 电流源电路1.3 差动放大电路1.4 理想运算放大器 二、集成运算放大器的应用2.1 反向比例运算电路2.2 同向比例运算电路2.3 反向加法运算电路2.4 反向减法运算电路2.5 积分运算电路2.6 微分运算电路2.7电压比…

weaviate入门安装与部署

官方教程 weaviate Introduction 安装 参考weaviate/installation&#xff0c;weaviate的安装有三种方案&#xff1a; 在本地安装docker build拉取镜像安装买WCS的weaviate服务 实践中&#xff0c;为了加快部署速度&#xff0c;我们用docker build拉取了一个镜像安装&#…

浅析大数据时代下的视频技术发展趋势以及AI加持下视频场景应用

视频技术的发展可以追溯到19世纪初期的早期实验。到20世纪初期&#xff0c;电视技术的发明和普及促进了视频技术的进一步发展。 1&#xff09;数字化&#xff1a;数字化技术的发明和发展使得视频技术更加先进。数字电视信号具有更高的清晰度和更大的带宽&#xff0c;可以更快地…

软考高级架构师——2、操作系统

一、进程管理 • 进程的状态&#xff08;★&#xff09; • 进程的同步与互斥&#xff08;★★★★&#xff09; 临界资源&#xff1a;诸进程间需要互斥方式对其进行共享的资源&#xff0c;如打印机、磁带机等 临界区&#xff1a;每个进程中访问临界资源的那段代码称为临界区…

STM32(HAL)串口中断接收

目录 1、简介 2 基础配置 2.1.1 SYS配置 2.1.2 RCC配置 2.2 串口外设配置 2.3 项目生成 3、KEIL端程序整合 1、简介 本文对HAL串口中断函数进行介绍。 2 基础配置 2.1.1 SYS配置 2.1.2 RCC配置 2.2 串口外设配置 2.3 项目生成 3、KEIL端程序整合 首先在main.c文件中进行…

【云原生】k8s中Contrainer 生命周期回调/策略/指针学习

个人主页&#xff1a;征服bug-CSDN博客 kubernetes专栏&#xff1a;kubernetes_征服bug的博客-CSDN博客 目录 1 容器生命周期 2 容器生命周期回调/事件/钩子 3 容器重启策略 4 自定义容器启动命令 5 容器探针 1 容器生命周期 Kubernetes 会跟踪 Pod 中每个容器的状态&am…

根据VCF文件探测差异SNP并设计引物序列

根据VCF设计Marker序列 问题描述&#xff1a;如果有两个样品的测序数据&#xff0c;并通过GATK等上游分析得到了变异位点信息&#xff0c;现在我想要找任意两个样品的差异SNP&#xff0c;并提取该位置上下游50bp序列&#xff0c;用于设计引物&#xff0c;应该怎么做&#xff1f…