tensorflow 实现逻辑回归——原以为TensorFlow不擅长做线性回归或者逻辑回归,原来是这么简单哇!...

实现的是预测 低 出生 体重 的 概率。
尼克·麦克卢尔(Nick McClure). TensorFlow机器学习实战指南 (智能系统与技术丛书) (Kindle 位置 1060-1061). Kindle 版本.

# Logistic Regression
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve logistic regression.
# y = sigmoid(Ax + b)
#
# We will use the low birth weight data, specifically:
#  y = 0 or 1 = low birth weight
#  x = demographic and medical history dataimport matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import requests
from tensorflow.python.framework import ops
import os.path
import csvops.reset_default_graph()# Create graph
sess = tf.Session()###
# Obtain and prepare data for modeling
#### Set name of data file
birth_weight_file = 'birth_weight.csv'# Download data and create data file if file does not exist in current directory
if not os.path.exists(birth_weight_file):birthdata_url = 'https://github.com/nfmcclure/tensorflow_cookbook/raw/master/01_Introduction/07_Working_with_Data_Sources/birthweight_data/birthweight.dat'birth_file = requests.get(birthdata_url)birth_data = birth_file.text.split('\r\n')birth_header = birth_data[0].split('\t')birth_data = [[float(x) for x in y.split('\t') if len(x)>=1] for y in birth_data[1:] if len(y)>=1]with open(birth_weight_file, 'w', newline='') as f:writer = csv.writer(f)writer.writerow(birth_header)writer.writerows(birth_data)f.close()# Read birth weight data into memory
birth_data = []
with open(birth_weight_file, newline='') as csvfile:csv_reader = csv.reader(csvfile)birth_header = next(csv_reader)for row in csv_reader:birth_data.append(row)birth_data = [[float(x) for x in row] for row in birth_data]# Pull out target variable
y_vals = np.array([x[0] for x in birth_data])
# Pull out predictor variables (not id, not target, and not birthweight)
x_vals = np.array([x[1:8] for x in birth_data])# Set for reproducible results
seed = 99
np.random.seed(seed)
tf.set_random_seed(seed)# Split data into train/test = 80%/20%
train_indices = np.random.choice(len(x_vals), round(len(x_vals)*0.8), replace=False)
test_indices = np.array(list(set(range(len(x_vals))) - set(train_indices)))
x_vals_train = x_vals[train_indices]
x_vals_test = x_vals[test_indices]
y_vals_train = y_vals[train_indices]
y_vals_test = y_vals[test_indices]# Normalize by column (min-max norm)
def normalize_cols(m):col_max = m.max(axis=0)col_min = m.min(axis=0)return (m-col_min) / (col_max - col_min)x_vals_train = np.nan_to_num(normalize_cols(x_vals_train))
x_vals_test = np.nan_to_num(normalize_cols(x_vals_test))###
# Define Tensorflow computational graph¶
#### Declare batch size
batch_size = 25# Initialize placeholders
x_data = tf.placeholder(shape=[None, 7], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[7,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)# Declare loss function (Cross Entropy loss)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=model_output, labels=y_target))# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)###
# Train model
#### Initialize variables
init = tf.global_variables_initializer()
sess.run(init)# Actual Prediction
prediction = tf.round(tf.sigmoid(model_output))
predictions_correct = tf.cast(tf.equal(prediction, y_target), tf.float32)
accuracy = tf.reduce_mean(predictions_correct)# Training loop
loss_vec = []
train_acc = []
test_acc = []
for i in range(15000):rand_index = np.random.choice(len(x_vals_train), size=batch_size)rand_x = x_vals_train[rand_index]rand_y = np.transpose([y_vals_train[rand_index]])sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})loss_vec.append(temp_loss)temp_acc_train = sess.run(accuracy, feed_dict={x_data: x_vals_train, y_target: np.transpose([y_vals_train])})train_acc.append(temp_acc_train)temp_acc_test = sess.run(accuracy, feed_dict={x_data: x_vals_test, y_target: np.transpose([y_vals_test])})test_acc.append(temp_acc_test)if (i+1)%300==0:print('Loss = ' + str(temp_loss))###
# Display model performance
#### Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title('Cross Entropy Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Cross Entropy Loss')
plt.show()# Plot train and test accuracy
plt.plot(train_acc, 'k-', label='Train Set Accuracy')
plt.plot(test_acc, 'r--', label='Test Set Accuracy')
plt.title('Train and Test Accuracy')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/390510.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

剑指 Offer 66. 构建乘积数组

给定一个数组 A[0,1,…,n-1],请构建一个数组 B[0,1,…,n-1],其中 B[i] 的值是数组 A 中除了下标 i 以外的元素的积, 即 B[i]A[0]A[1]…A[i-1]A[i1]…A[n-1]。不能使用除法。 示例: 输入: [1,2,3,4,5] 输出: [120,60,40,30,24] 提示: 所有…

amazeui学习笔记--css(基本样式3)--文字排版Typography

amazeui学习笔记--css(基本样式3)--文字排版Typography 一、总结 1、字体:amaze默认非 衬线字体(sans-serif) 2、引用块blockquote和定义列表:引用块blockquote和定义列表(dl dt)注意…

ELK学习记录三 :elasticsearch、logstash及kibana的安装与配置(windows)

注意事项: 1.ELK版本要求5.X以上 2.Elasticsearch5.x版本必须基于jdk1.8,安装环境必须使用jdk1.8 3.操作系统windows10作为测试环境,其他环境命令有差异,请注意 4.本教程适合完全离线安装 5.windows版本ELK安装包下载路径&#xf…

【quickhybrid】JSBridge的实现

前言 本文介绍quick hybrid框架的核心JSBridge的实现 由于在最新版本中,已经没有考虑iOS7等低版本,因此在选用方案时没有采用url scheme方式,而是直接基于WKWebView实现 交互原理 具体H5和Native的交互原理可以参考前文的H5和Native交互原理 …

面试题 10.02. 变位词组

编写一种方法,对字符串数组进行排序,将所有变位词组合在一起。变位词是指字母相同,但排列不同的字符串。 注意:本题相对原题稍作修改 示例: 输入: [“eat”, “tea”, “tan”, “ate”, “nat”, “bat”], 输出: [ [“ate”,…

智能合约设计模式

2019独角兽企业重金招聘Python工程师标准>>> 设计模式是许多开发场景中的首选解决方案,本文将介绍五种经典的智能合约设计模式并给出以太坊solidity实现代码:自毁合约、工厂合约、名称注册表、映射表迭代器和提款模式。 1、自毁合约 合约自毁…

「CodePlus 2017 12 月赛」火锅盛宴

n<100000种食物&#xff0c;给每个食物煮熟时间&#xff0c;有q<500000个操作&#xff1a;在某时刻插入某个食物&#xff1b;查询熟食中编号最小的并删除之&#xff1b;查询是否有编号为id的食物&#xff0c;如果有查询是否有编号为id的熟食&#xff0c;如果有熟食删除之…

5815. 扣分后的最大得分

给你一个 m x n 的整数矩阵 points &#xff08;下标从 0 开始&#xff09;。一开始你的得分为 0 &#xff0c;你想最大化从矩阵中得到的分数。 你的得分方式为&#xff1a;每一行 中选取一个格子&#xff0c;选中坐标为 (r, c) 的格子会给你的总得分 增加 points[r][c] 。 然…

您有一个上云锦囊尚未领取!

前期&#xff0c;我们通过文章《确认过眼神&#xff1f;上云之路需要遇上对的人&#xff01;》向大家详细介绍了阿里云咨询与设计场景下的五款专家服务产品&#xff0c;企业可以通过这些专家服务产品解决了上云前的痛点。那么&#xff0c;当完成上云前的可行性评估与方案设计后…

Python os.chdir() 方法

概述 os.chdir() 方法用于改变当前工作目录到指定的路径。 语法 chdir()方法语法格式如下&#xff1a; os.chdir(path) 参数 path -- 要切换到的新路径。 返回值 如果允许访问返回 True , 否则返回False。 实例 以下实例演示了 chdir() 方法的使用&#xff1a; #!/usr/bin/pyth…

More DETAILS! PBR的下一个发展在哪里?

最近几年图形学社区对PBR的关注非常高&#xff0c;也许是由于Disney以及一些游戏引擎大厂的助推&#xff0c;也许是因为它可以被轻松集成进实时渲染的游戏引擎当中&#xff0c;也许是因为许多人发现现在只需要调几个参数就能实现具有非常精细细节的表面着色了。反正现在网络上随…

sql server 2008 身份验证失败 18456

双击打开后加上 ;-m 然后以管理员方式 打开 SQLSERVER 2008 就可以已window身份登录 不过还没有完 右键 属性 》安全性 更改为 sql server 和 window身份验证模式 没有sql server登陆账号的话创建一个 然后把-m去掉就可以用帐号登录了 转载于:https://www.cnblogs.com/R…

Java逆向基础之AspectJ的获取成员变量的值

注意&#xff1a;由于JVM优化的原因&#xff0c;方法里面的局部变量是不能通过AspectJ拦截并获取其中的值的&#xff0c;但是成员变量可以在逆向中&#xff0c;我们经常要跟踪某些类的成员变量的值&#xff0c;这里以获取ZKM9中的qs类的成员变量g为例进行说明在StackOverFlow上…

Spark 键值对RDD操作

https://www.cnblogs.com/yongjian/p/6425772.html 概述 键值对RDD是Spark操作中最常用的RDD&#xff0c;它是很多程序的构成要素&#xff0c;因为他们提供了并行操作各个键或跨界点重新进行数据分组的操作接口。 创建 Spark中有许多中创建键值对RDD的方式&#xff0c;其中包括…

Python高级网络编程系列之基础篇

一、Socket简介 1、不同电脑上的进程如何通信&#xff1f; 进程间通信的首要问题是如何找到目标进程&#xff0c;也就是操作系统是如何唯一标识一个进程的&#xff01; 在一台电脑上是只通过进程号PID&#xff0c;但在网络中是行不通的&#xff0c;因为每台电脑的IP可能都是不一…

剑指 Offer 67. 把字符串转换成整数

写一个函数 StrToInt&#xff0c;实现把字符串转换成整数这个功能。不能使用 atoi 或者其他类似的库函数。 首先&#xff0c;该函数会根据需要丢弃无用的开头空格字符&#xff0c;直到寻找到第一个非空格的字符为止。 当我们寻找到的第一个非空字符为正或者负号时&#xff0c…

4.RabbitMQ Linux安装

这里使用的Linux是CentOS6.2 将/etc/yum.repo.d/目录下的所有repo文件删除 先下载epel源 # wget -O /etc/yum.repos.d/epel-erlang.repo http://repos.fedorapeople.org/repos/peter/erlang/epel-erlang.repo 修改epel-erlang.repo文件&#xff0c;如下图 添加CentOS 的下载源…

二、数据库设计与操作

一、 数据库设计仿QQ数据库一共包括5张数据表&#xff0c;每张数据表结构如下&#xff1a;1、 tb_User&#xff08;用户信息表&#xff09;这张表主要用来存储用户的好友关系与信息字段名数据类型是否Null值默认值绑定描述IDint否用户账号PwdVarchar(50)否用户密码Frie…

剑指 Offer 36. 二叉搜索树与双向链表

输入一棵二叉搜索树&#xff0c;将该二叉搜索树转换成一个排序的循环双向链表。要求不能创建任何新的节点&#xff0c;只能调整树中节点指针的指向。 为了让您更好地理解问题&#xff0c;以下面的二叉搜索树为例&#xff1a; 我们希望将这个二叉搜索树转化为双向循环链表。链表…

1736. 替换隐藏数字得到的最晚时间

给你一个字符串 time &#xff0c;格式为 hh:mm&#xff08;小时&#xff1a;分钟&#xff09;&#xff0c;其中某几位数字被隐藏&#xff08;用 ? 表示&#xff09;。 有效的时间为 00:00 到 23:59 之间的所有时间&#xff0c;包括 00:00 和 23:59 。 替换 time 中隐藏的数…