基于 Flask的深度学习模型部署服务端详解

基于 Flask 的深度学习模型部署服务端详解

在深度学习领域,训练出一个高精度的模型只是第一步,将其部署到生产环境中,为实际业务提供服务才是最终目标。本文将详细解析一个基于 Flask 和 PyTorch 的深度学习模型部署服务端代码,帮助你理解如何将训练好的模型以 API 形式提供给客户端使用。

一、整体概述

这段代码的主要功能是搭建一个基于 Flask 的 Web 服务,用于接收客户端发送的图像数据,使用预训练的 PyTorch 模型对图像进行分类预测,并将预测结果以 JSON 格式返回给客户端。

二、代码详细解析

1. 导入必要的库

import io
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms, models
  • io:用于处理二进制数据,这里主要用于将客户端发送的图像二进制数据转换为图像对象。
  • flask:一个轻量级的 Web 框架,用于搭建 Web 服务。
  • torchtorch.nn.functional:PyTorch 的核心库,用于深度学习模型的构建和计算。
  • PIL.Image:Python Imaging Library(PIL)的一部分,用于处理图像文件。
  • torch.nn:用于定义神经网络的层和模块。
  • torchvision.transformstorchvision.modelstransforms 用于图像预处理,models 提供了预训练的深度学习模型。

2. 初始化 Flask 应用和模型相关变量

app = flask.Flask(__name__)
model = None
use_gpu = False
  • app = flask.Flask(__name__):创建一个新的 Flask 应用实例,__name__ 参数用于确定应用的根路径。
  • model:用于存储加载的深度学习模型,初始化为 None
  • use_gpu:一个布尔变量,用于控制是否使用 GPU 进行模型推理,初始化为 False

3. 加载模型

def load_model():global modelmodel = models.resnet18()num_ftrs = model.fc.in_featuresmodel.fc = nn.Sequential(nn.Linear(num_ftrs, 102))checkpoint = torch.load('best.pth')model.load_state_dict(checkpoint['state_dict'])model.eval()if use_gpu:model.cuda()
  • global model:声明 model 为全局变量,以便在函数内部修改它。
  • model = models.resnet18():加载预训练的 ResNet-18 模型。
  • num_ftrs = model.fc.in_features:获取 ResNet-18 模型最后一层全连接层的输入特征数。
  • model.fc = nn.Sequential(nn.Linear(num_ftrs, 102)):修改最后一层全连接层,将输出维度改为 102,这里的 102 可以根据实际任务的类别数进行调整。
  • checkpoint = torch.load('best.pth'):从文件 best.pth 中加载训练好的模型参数。
  • model.load_state_dict(checkpoint['state_dict']):将加载的参数应用到模型中。
  • model.eval():将模型设置为评估模式,关闭一些在训练时使用的特殊层(如 Dropout)。
  • if use_gpu: model.cuda():如果 use_gpuTrue,将模型移动到 GPU 上。

4. 图像预处理

def prepare_image(image, target_size):if image.mode != 'RGB':image = image.convert('RGB')image = transforms.Resize(target_size)(image)image = transforms.ToTensor()(image)image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)image = image[None]if use_gpu:image = image.cuda()return torch.tensor(image)
  • if image.mode != 'RGB': image = image.convert('RGB'):确保输入图像为 RGB 格式。
  • image = transforms.Resize(target_size)(image):将图像调整为指定的大小。
  • image = transforms.ToTensor()(image):将图像转换为 PyTorch 张量。
  • image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image):对图像进行归一化处理,使用的均值和标准差是在 ImageNet 数据集上计算得到的。
  • image = image[None]:增加一个维度,将图像转换为批量输入的格式。
  • if use_gpu: image = image.cuda():如果 use_gpuTrue,将图像移动到 GPU 上。

5. 定义预测接口

@app.route('/predict', methods=['POST'])
def predict():data = {'success': False}if flask.request.method == 'POST':if flask.request.files.get('image'):image = flask.request.files['image'].read()image = Image.open(io.BytesIO(image))image = prepare_image(image, target_size=(224, 224))preds = F.softmax(model(image), dim=1)results = torch.topk(preds.cpu().data, k=3, dim=1)results = (results[0].cpu().numpy(), results[1].cpu().numpy())data['prediction'] = list()for prob, label in zip(results[0][0], results[1][0]):r = {'label': str(label), 'probability': float(prob)}data['prediction'].append(r)data['success'] = Truereturn flask.jsonify(data)
  • @app.route('/predict', methods=['POST']):使用 Flask 的装饰器定义一个路由,当客户端向 /predict 路径发送 POST 请求时,会调用 predict 函数。
  • data = {'success': False}:初始化一个字典,用于存储预测结果和状态信息,初始状态为 success = False
  • if flask.request.method == 'POST':检查请求方法是否为 POST。
  • if flask.request.files.get('image'):检查请求中是否包含名为 image 的文件。
  • image = flask.request.files['image'].read():读取客户端发送的图像文件内容。
  • image = Image.open(io.BytesIO(image)):将二进制数据转换为图像对象。
  • image = prepare_image(image, target_size=(224, 224)):对图像进行预处理。
  • preds = F.softmax(model(image), dim=1):使用模型进行预测,并通过 softmax 函数将输出转换为概率分布。
  • results = torch.topk(preds.cpu().data, k=3, dim=1):获取概率最大的前 3 个结果。
  • results = (results[0].cpu().numpy(), results[1].cpu().numpy()):将结果转换为 NumPy 数组。
  • data['prediction'] = list():初始化一个列表,用于存储预测结果。
  • for prob, label in zip(results[0][0], results[1][0]):遍历前 3 个结果,将标签和概率封装成字典,并添加到 data['prediction'] 列表中。
  • data['success'] = True:将状态信息设置为 success = True,表示预测成功。
  • return flask.jsonify(data):将结果以 JSON 格式返回给客户端。

6. 启动服务

if __name__ == '__main__':print('Loading PyTorch model and Flask starting server ...')print('Please wait until server has fully started')load_model()app.run(host='192.168.1.20', port=5012)
  • if __name__ == '__main__':确保代码作为主程序运行时才执行以下操作。
  • print('Loading PyTorch model and Flask starting server ...')print('Please wait until server has fully started'):打印启动信息。
  • load_model():调用 load_model 函数加载模型。
  • app.run(host='192.168.1.20', port=5012):启动 Flask 服务,监听 192.168.1.20 地址的 5012 端口。运行结果如下
  • 在这里插入图片描述

三、总结

通过上述代码,我们成功搭建了一个基于 Flask 和 PyTorch 的深度学习模型部署服务端。客户端可以通过向 /predict 路径发送包含图像文件的 POST 请求,获取图像分类的预测结果。在实际应用中,可以根据需要对代码进行扩展,如增加更多的模型、优化图像预处理流程、添加错误处理机制等。希望本文能帮助你更好地理解深度学习模型的部署过程。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/82269.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue3 + Node.js 实现客服实时聊天系统(WebSocket + Socket.IO 详解)

Node.js 实现客服实时聊天系统(WebSocket Socket.IO 详解) 一、为什么选择 WebSocket? 想象一下淘宝客服的聊天窗口:你发消息,客服立刻就能看到并回复。这种即时通讯效果是如何实现的呢?我们使用 Vue3 作…

MySQL数据库与表结构操作指南

前言:本文系统梳理MySQL核心操作语句。内容覆盖建库建表、结构调整、数据迁移全流程(包含创建/修改/删除/备份场景)。希望它们能帮你快速解决问题。 库结构操作 一、库的创建 一个库的简单创建: create database 库名; 注意&am…

【WEB3】区块链、隐私计算、AI和Web3.0——数据民主化(1)

区块链、隐私计算、AI,是未来Web3.0至关重要的三项技术。 1.数据民主化问题 数据在整个生命周期(生产、传输、处理、存储)内的隐私安全,则是Web3.0在初始阶段首要解决的问题。 数据民主化旨在打破数据垄断,让个体能…

C语言—指针2

1. const 修饰变量 1.1 const修饰变量 变量被const修饰时,变量此时为常变量,本质为常量,语法上不可被修改,但是如果此时需要修改变量值,可以通过指针的方式修改。 虽然此时通过指针的方式确实修改了变量的值&#xff…

高级架构软考之网络OSI网络模型

高级架构软考之网络: 1.OSI网络模型: a.物理层: a.物理传输介质物理连接,负责数据传输,并监控数据 b.传输单位:bit c.协议: d:对应设备:中继器、集线器 b.数据链路层: a.…

el-table计算表头列宽,不换行显示

1、在utils.js中封装renderHeader方法 2、在el-table-column中引入: 3、页面展示:

MySQL OCP和Oracle OCP怎么选?

近期oracle 为庆祝 MySQL 数据库发布 30 周年,Oracle 官方推出限时福利:2025 年 4 月 20 日至 7 月 31 日期间,所有人均可免费报考 MySQL OCP(Oracle Certified Professional)认证考试(具体可查看MySQL OCP…

2025最新免费视频号下载工具!支持Win/Mac,一键解析原画质+封面

软件介绍 适用于Windows 2025 最新5月蝴蝶视频号下载工具,免费使用,无广告且免费,支持对原视频和封面进行解析下载,亲测可用,现在很多工具都失效了,难得的几款下载视频号工具,大家且用且珍…

Python学习之路(八)-多线程和多进程浅析

在 Python 中,多线程(Multithreading) 和 多进程(Multiprocessing) 是实现并发编程的两种主要方式。它们各有优劣,适用于不同的场景。 一、基本概念 特性多线程(threading)多进程(multiprocessing)并发模型线程共享内存空间每个进程拥有独立内存空间GIL(全局解释器锁…

Spark缓存--persist方法

1. 功能本质 persist:这是一个通用的持久化方法,能够指定多种不同的存储级别。存储级别决定了数据的存储位置(如内存、磁盘)以及存储形式(如是否序列化)。 2. 存储级别指定 persist:可以通过传入…

裸辞8年前端的面试笔记——JavaScript篇(一)

裸辞后的第二个月开始准备找工作,今天是第三天目前还没有面试,现在的行情是一言难尽,都在疯狂的压价。 下边是今天复习的个人笔记 一、事件循环 JavaScript 的事件循环(Event Loop)是其实现异步编程的关键机制。 从…

什么是死信队列?死信队列是如何导致的?

死信交换机(Dead Letter Exchange,DLX) 定义:死信交换机是一种特殊的交换机,专门用于**接收从其他队列中因特定原因变成死信的消息**。它的本质还是交换机,遵循RabbitMQ中交换机的基本工作原理&#xff0c…

9. 从《蜀道难》学CSS基础:三种选择器的实战解析

引言:当古诗遇上现代网页设计 今天我们通过李白的经典诗作《蜀道难》来学习CSS的三种核心选择器。这种古今结合的学习方式,既能感受中华诗词的魅力,又能掌握实用的网页设计技能。让我们开始这场穿越时空的技术之旅吧! 一、HTML骨架…

三角网格减面算法及其代表的算法库都有哪些?

以下是三角网格减面算法及其代表库/工具的详细分类,涵盖经典算法和现代实现: ​​1. 顶点聚类(Vertex Clustering)​​ ​​原理​​:将网格空间划分为体素栅格,合并每个栅格内的顶点。​​特点​​&#…

URP - 屏幕图像(_CameraOpaqueTexture)

首先需要在unity中开启屏幕图像开关才可以使用该纹理 同样只有不透明对象才能被渲染到屏幕图像中 若想要该对象不被渲染到屏幕图像中,可以将其Shader的渲染队列改为 "Queue" "Transparent" 如何在Shader中使用_CameraOpaqueTexture&#xf…

vue 和 html 的区别

使用 Vue.js 和原生 HTML 开发 Web 应用有显著的区别,主要体现在开发模式、功能扩展、性能优化和维护性等方面。以下是两者的对比分析: 🧱 原生 HTML(HTML CSS JavaScript) 特点: 静态结构:H…

LeetCode[226] 翻转二叉树

思路: 使用递归,归根结底还是左右节点互相倒,那么肯定需要一个temp节点在中间传递,最后就是递归,没什么说的 代码: /*** Definition for a binary tree node.* public class TreeNode {* int …

幂等的几种解决方案以及实践

目录 什么是幂等? 解决幂等的常见解决方案: 唯一标识符案例 数据库唯一约束 案例 乐观锁案例 分布式锁(Distributed Locking) 实践精选方案 首先 为什么不直接使用分布式锁呢? 自定义实现幂等组件&#xff01…

PowerShell中的Json处理

1.定义JSON字符串变量 PS C:\WINDOWS\system32> $body {"Method": "POST","Body": {"model": "deepseek-r1","messages": [{"content": "why is the sky blue?","role"…

奥威BI:AI+BI深度融合,重塑智能AI数据分析新标杆

在数字化浪潮席卷全球的今天,企业正面临着前所未有的数据挑战与机遇。如何高效、精准地挖掘数据价值,已成为推动业务增长、提升竞争力的核心议题。奥威BI,作为智能AI数据分析领域的领军者,凭借其创新的AIBI融合模式,正…