PyTorch 简单框架示例:网络定义(mynn.py)、Dataset 重定义(mydataset.py)、训练主程序(main.py)与天文数据预处理脚本

网络搭建:

mynn.py:

import torch
from torch import nn
class mynn(nn.Module):
    """Fully-connected spectrum classifier: 3520 input features -> 3 class logits.

    Eleven stacked blocks. Every hidden block is Linear -> BatchNorm1d -> ReLU;
    the final block is a bare Linear that emits raw logits (CrossEntropyLoss
    applies log-softmax itself, so no activation here).
    """

    def __init__(self):
        super(mynn, self).__init__()
        # (in_features, out_features) of each hidden block, in order.
        # Building them in a loop registers layer1..layer10 in the same order
        # (and therefore the same parameter-init order) as writing them out.
        widths = [(3520, 4096), (4096, 4096), (4096, 4096), (4096, 4096),
                  (4096, 3072), (3072, 2048), (2048, 1024), (1024, 256),
                  (256, 64), (64, 32)]
        for i, (n_in, n_out) in enumerate(widths, start=1):
            block = nn.Sequential(
                nn.Linear(n_in, n_out),
                nn.BatchNorm1d(n_out),
                nn.ReLU(True),
            )
            setattr(self, 'layer%d' % i, block)
        # Output head: logits for the 3 classes.
        self.layer11 = nn.Sequential(nn.Linear(32, 3))

    def forward(self, x):
        """Push a batch of shape (N, 3520) through all blocks; returns (N, 3) logits."""
        for i in range(1, 12):
            x = getattr(self, 'layer%d' % i)(x)
        return x

Dataset重定义:
mydataset.py

import os
from torch.utils import data
import numpy as np
from astropy.io import fits
from torchvision import transforms as T
from PIL import Image
import pandas as pd

class mydataset(data.Dataset):
    """Dataset backed by a CSV file where each row is: features..., label.

    The whole CSV is loaded eagerly into a 2-D numpy array; __getitem__
    returns ``(label, features)`` for one row (note the label-first order,
    which the training loop relies on).
    """

    def __init__(self, csv_file, root_dir=None, transform=None):
        # Pass the path straight to np.loadtxt: the original opened the file
        # itself (open(csv_file, "rb")) and never closed the handle (leak).
        self.landmarks_frame = np.loadtxt(csv_file, delimiter=",")
        self.root_dir = root_dir        # kept for API compatibility; unused here
        self.transform = transform      # kept for API compatibility; unused here

    def __len__(self):
        # Number of rows in the loaded matrix.
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        row = self.landmarks_frame[idx, :]
        # Last column is the class label; everything before it is the feature vector.
        label = row[-1]
        features = row[:-1]
        return label, features
主程序:
main.py
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.autograd import Variable  # legacy API; no longer used below, kept for compatibility
#from models import Mynet, my_AlexNet, my_VGG
import mydataset  # NOTE(review): module file is mydataset.py; original `from sdata import mydataset` did not match
import time
import numpy as np
import mynn       # NOTE(review): module file is mynn.py; original `from model import mynn` did not match

if __name__ == '__main__':  # guard required when DataLoader uses num_workers > 0 (Windows spawn)

    data_train = mydataset.mydataset(csv_file="G:\\DATA\\train.csv", root_dir=None, transform=None)
    data_test = mydataset.mydataset(csv_file="G:\\DATA\\test.csv", root_dir=None, transform=None)
    data_loader_train = torch.utils.data.DataLoader(dataset=data_train,
                                                    batch_size=256,
                                                    shuffle=True,
                                                    num_workers=0,
                                                    pin_memory=True)
    data_loader_test = torch.utils.data.DataLoader(dataset=data_test,
                                                   batch_size=256,
                                                   shuffle=True,
                                                   num_workers=0,
                                                   pin_memory=True)
    print("**dataloader done**")

    # Pick the device once.  The original called .to('cuda') / .cuda() on every
    # batch unconditionally and crashed on CPU-only machines despite checking
    # torch.cuda.is_available() for the model.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = mynn.mynn().to(device)

    # CrossEntropyLoss expects raw logits and integer (long) class labels.
    criterion = nn.CrossEntropyLoss()
    #optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999),
                           eps=1e-08, weight_decay=1e-4)
    n_epochs = 1000

    global_train_acc = []

    s_time = time.time()

    for epoch in range(n_epochs):
        running_loss = 0.0
        running_correct = 0.0
        print('Epoch {}/{}'.format(epoch, n_epochs))

        model.train()  # BatchNorm in training mode, gradients on
        for label, datafit in data_loader_train:
            x_train = datafit.float().to(device)
            y_train = label.long().to(device)
            outputs = model(x_train)
            _, pred = torch.max(outputs.data, 1)
            optimizer.zero_grad()
            loss = criterion(outputs, y_train)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_correct += torch.sum(pred == y_train.data)

        # Evaluation: the original left the model in train mode (BatchNorm kept
        # updating its running stats on test data) and tracked gradients.
        model.eval()
        testing_correct = 0.0
        with torch.no_grad():
            for label, datafit in data_loader_test:
                x_test = datafit.float().to(device)
                y_test = label.long().to(device)
                outputs = model(x_test)
                _, pred = torch.max(outputs.data, 1)
                testing_correct += torch.sum(pred == y_test.data)

        print('Loss is:{:.4f}, Train Accuracy is:{:.4f}%, Test Accuracy '
              'is:{:.4f}'.format(running_loss / len(data_train),
                                 100 * running_correct / len(data_train),
                                 100 * testing_correct / len(data_test)))

    e_time = time.time()
    print('time_run is :', e_time - s_time)
    print('*******done******')

将天文数据写入 csv 中(数据预处理脚本;原文同样命名为 main.py,与上面的训练脚本重名,实际使用时建议另存为其他文件名,例如 prepare_data.py):
main.py
# -*- coding: utf-8 -*-
"""
Spyder Editor

This is a temporary script file.
"""

import matplotlib.pyplot as plt
from astropy.io import fits
import os
import matplotlib
matplotlib.use('Qt5Agg')
from astropy.io import fits
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn import svm
from sklearn.decomposition import PCA
def getData(fitPath, cla):
    """Read every .fit spectrum under *fitPath* into one array of labelled rows.

    Each row holds 3521 values: 3520 flux samples starting at the wavelength
    bin closest to 4000 Angstrom, with the class id *cla* overwriting the last
    slot as the label column.

    Returns a 2-D (n_files, 3521) array; quirks of the original are preserved:
    a single matching file yields a 1-D vector, and no matching file yields
    ``np.ones(3521)``.
    """
    # Collect the spectrum files.  The original test `f[-4:-1] == ".fi"` is
    # kept verbatim (it matches ".fit" and any other 4-char ".fi?" suffix).
    fileList = []
    for f in os.listdir(fitPath):
        full = fitPath + '/' + f
        if os.path.isfile(full) and f[-4:-1] == ".fi":
            fileList.append(full)

    rows = []
    for path in fileList:
        # Close each FITS file deterministically; the original leaked handles.
        with fits.open(path) as hdul:
            header = hdul[0].header
            NAXIS1 = header['NAXIS1']   # number of spectral bins
            COEFF0 = header['COEFF0']   # log10 wavelength of bin 0
            COEFF1 = header['COEFF1']   # log10 wavelength step per bin
            # Wavelength of every bin in one vectorized expression
            # (replaces two Python loops over NAXIS1 elements).
            wave = 10 ** (COEFF0 + np.arange(NAXIS1) * COEFF1)
            # Index of the bin straddling 4000 Angstrom.  Like the original
            # scan, the LAST crossing wins and 0 is used if none is found.
            start = 0
            for i in range(NAXIS1 - 1):
                if wave[i] <= 4000 and wave[i + 1] >= 4000:
                    start = i
            # First row of the primary HDU image is the flux vector.
            spec = np.array(hdul[0].data[0, :])
            spec = spec[start:start + 3521]
            # NOTE(review): assumes at least 3521 bins remain after `start`,
            # as the original did — an IndexError here means a short spectrum.
            spec[3520] = cla  # last column carries the class label
            rows.append(spec)

    if not rows:
        return np.ones(3521)   # original fallback when no file matched
    if len(rows) == 1:
        return rows[0]         # original returned a 1-D vector for one file
    # Stack once at the end instead of the original O(n^2) row_stack loop.
    return np.vstack(rows)

if __name__ == '__main__':
    # Raw strings for the Windows paths: "\D", "\Q", "\G" are invalid escape
    # sequences in ordinary literals (DeprecationWarning today, SyntaxError in
    # future Pythons).  np.vstack replaces the deprecated np.row_stack alias
    # (removed in NumPy 2.0); they are the same function.
    star = getData(r"G:\DATA\STAR", 0)          # class 0: stars
    star_train, star_test = train_test_split(star, test_size=0.1, random_state=0)

    qso = getData(r"G:\DATA\QSO", 1)            # class 1: quasars
    qso_train, qso_test = train_test_split(qso, test_size=0.1, random_state=0)
    x_train = np.vstack((star_train, qso_train))
    x_test = np.vstack((star_test, qso_test))

    galaxy = getData(r"G:\DATA\GALAXY", 2)      # class 2: galaxies
    gal_train, gal_test = train_test_split(galaxy, test_size=0.1, random_state=0)
    x_train = np.vstack((x_train, gal_train))
    x_test = np.vstack((x_test, gal_test))

    # Splitting each class 90/10 separately keeps the class balance in both files.
    np.savetxt("G:\\DATA\\train.csv", x_train, delimiter=',')
    np.savetxt("G:\\DATA\\test.csv", x_test, delimiter=',')


转载于:https://www.cnblogs.com/invisible2/p/11523330.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/246979.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java 基础安装和Tomcat8配置

初识 java&#xff0c;基础安装的说明。 下载 在oracle官网一般在同一个java版本会提供2个版本&#xff0c; 一个是Java SE Development Kit 7u80&#xff0c;此版本包含JDK开发环境版本&#xff1b; 另外一个是 Java SE Runtime Environment 7u80&#xff0c;此为只包含JR…

TypeScript React

环境搭建 我们当然可以先用脚手架搭建React项目&#xff0c;然后手动配置成支持TypeScript的环境&#xff0c;虽然比较麻烦&#xff0c;但可以让你更清楚整个过程。这里比较麻烦&#xff0c;就不演示了&#xff0c;直接用命令配置好。 npx create-react-app appname --typescri…

matlab内存溢出的解决方案

&#xff08;1&#xff09; 增加虚拟内存&#xff1a;cmd -> taskmgr 打开任务管理器&#xff0c;查看物理内存和虚拟内存&#xff0c;可观察matlab在运行过程中是否超过物理内存和虚拟内存。若超过&#xff0c;增加虚拟内存的方法是不可行的。物理内存不足的时候可以通过将…

c++MMMMM:oo

1.union&#xff0c;struct和class的区别 转载于:https://www.cnblogs.com/invisible2/p/11524465.html

matlab调用Java程序时出现 Java.lang.OutOfMemoryErrot: GC overhead limit exceeded

matlab调用Java程序时出现 java.lang.OutOfMemoryError: GC overhead limit exceeded JDK1.6.0_37和JDK_1.7.0_60版本&#xff0c;这2个版本中JVM默认启动的时候-XX:UseGCOverheadLimit&#xff0c;即启用了该特性。这其实是JVM的一种推断&#xff0c;如果垃圾回收耗费了98%的…

[FY20 创新人才班 ASE] 第 1 次作业成绩

作业概况 条目备注作业链接【ASE高级软件工程】热身作业&#xff01;提交人数19未完成人数2满分10分作业情况总结 本次作业作为大家软工课程的第一次作业&#xff0c;完成度相当不错&#xff08;尤其是在国外暑研/赶论文的同学也在尽力完成作业&#xff0c;很感动&#xff09;。…

JVM 参数设置

1、集成开发环境下启动并使用JVM&#xff0c;如eclipse需要修改根目录文件eclipse.ini&#xff1b; 2、Windows服务器下安装版Tomcat&#xff0c;可使用Tomcat8w.exe工具&#xff08;tomcat目录下&#xff09;和直接修改注册表两种方式修改Jvm参数&#xff1b; 3、Windows服务…

从筛选简历和面试流程讲起,再给培训班出身的程序员一些建议

本人最近几年一直在外企和互联网公司承担Java技术面试官的职责&#xff0c;大多面试的是Java初级和高级开发&#xff0c;其中有不少是培训班出身的候选人。 在我之前的博文里&#xff0c;从面试官的角度聊聊培训班对程序员的帮助&#xff0c;同时给培训班出身的程序员一些建议&…

马普所机器学习课程 CMU701

马普所机器学习课程 Max-Planck-Institut fr Informatik: Machine Learning https://www.mpi-inf.mpg.de/departments/computer-vision-and-multimodal-computing/teaching/courses/ 马普所 GVV project http://gvv.mpi-inf.mpg.de/GVV_Projects.html CMU701 Tom Mitchel…

Random Forest 实用经验(转)

总结两条关于random forest的实用经验。给定数据和问题&#xff0c;对于算法选择有参考价值。 小样本劣势&#xff0c;大样本优势 小样本情况下&#xff08;1k~100k&#xff09;&#xff1a; RF相对与经典算法&#xff08;SVM or Boosting&#xff09;没优势&#xff0c;一般…

pytorch 查看中间变量的梯度

pytorch 为了节省显存&#xff0c;在反向传播的过程中只针对计算图中的叶子结点(leaf variable)保留了梯度值(gradient)。但对于开发者来说&#xff0c;有时我们希望探测某些中间变量(intermediate variable) 的梯度来验证我们的实现是否有误&#xff0c;这个过程就需要用到 te…

hbase数据迁移到hive中

描述&#xff1a; 原先数据是存储在hbase中的&#xff0c;但是直接查询hbase速度慢&#xff08;hbase是宽表结构&#xff09;&#xff0c;所以想把数据迁移到hive中&#xff1b; 1.先hbase 和 hive创建 外部表链接&#xff0c; 可以在hive直接查询&#xff1b; 2.利用创建的外部…

weka 学习总结(持续)

机器学习之 Weka学习&#xff08;一&#xff09;weka介绍&#xff0c;安装和配置环境变量 机器学习之 Weka学习&#xff08;二&#xff09;算法说明 Weka数据挖掘处理流程介绍 机器学习之 weka学习&#xff08;五&#xff09;示例用法 Weka数据处理结果分析 http://blog.c…

作为IT面试官,我如何考核计算机专业毕业生?作为培训班老师,我又如何提升他们?...

我最近几年一直在做技术面试官&#xff0c;除了面试有一定工作经验的社会人员外&#xff0c;有时还会面试在校实习生和刚毕业的大学生。同时&#xff0c;我也在学校里做过兼职讲师&#xff0c;上些政府补贴课程&#xff08;这些课程有补贴&#xff0c;学生不用出钱&#xff09;…

memcpy函数的实现

1.按1个字节拷贝 &#xff08;1&#xff09;不要直接使用形参&#xff0c;要转换成char* &#xff08;2&#xff09;目标地址要实现保存 &#xff08;3&#xff09;要考虑源和目标内存重叠的情况 void * mymemcpy(void *dest, const void *src, size_t count) {if (dest NULL …

MATLAB中调用Weka设置方法(转)及示例

本文转自&#xff1a; http://blog.sina.com.cn/s/blog_890c6aa30101av9x.html MATLAB命令行下验证Java版本命令 version -java 配置MATLAB调用Java库 Finish Java codes.Create Java library file, i.e., .jar file.Put created .jar file to one of directories Matlab …

webpack4配置基础

前言 为什么要使用构建工具&#xff1f; 1.转换ES6语法&#xff08;很多老版本的浏览器不支持新语法&#xff09; 2.转换JSX 3.CSS前缀补全/预处理器 4.压缩混淆&#xff08;将代码逻辑尽可能地隐藏起来&#xff09; 5.图片压缩 6. .... 为什么选择webpack&#xff1f; 社区…

RESTful API概述

什么是REST REST与技术无关&#xff0c;代表的是一种软件架构风格&#xff0c;REST是Representational State Transfer的简称&#xff0c;中文翻译为“表征状态转移”。这里说的表征性&#xff0c;就是指资源&#xff0c;通常我们称为资源状态转移。 什么是资源&#xff1f; 网…

AI 《A PROPOSAL FOR THE DARTMOUTH SUMMER RESEARCH PROJECT ON ARTIFICIAL INTELLIGENCE》读后总结

本文转载&#xff1a; http://www.cnblogs.com/SnakeHunt2012/archive/2013/02/18/2916242.html 《A Proposal for the Dartmouth Summer Research Project on Artificial Intelligence》&#xff0c;这是AI领域的开山之作&#xff0c;是当年达特茅斯会议上各路大牛们为期两个月…

第94:受限玻尔兹曼机

转载于:https://www.cnblogs.com/invisible2/p/11565179.html