用TensorFlow进行逻辑回归(三)

逻辑回归Logistic regression

这个脚本展示如何用TensorFlow求解逻辑回归。 =(×+)y=sigmoid(A×x+b)

我们使用低出生重量数据,特别地:

```

y = 0 or 1 = low birth weight

x = demographic and medical history data

import matplotlib.pyplot as plt

import numpy as np

import tensorflow as tf

import requests

from tensorflow.python.framework import ops

import os.path

import csv

ops.reset_default_graph()

#tf.set_random_seed(42)

np.random.seed(42)

# name of data file

birth_weight_file = 'birth_weight1.csv'

# download data and create data file if file does not exist in current directory

#if not os.path.exists(birth_weight_file):

   

 #   birthdata_url = 'https://github.com/nfmcclure/tensorflow_cookbook/raw/master/01_Introduction/07_Working_with_Data_Sources/birthweight_data/birthweight.dat'

  #  birth_file = requests.get(birthdata_url)

   # birth_data = birth_file.text.split('\r\n')

    #birth_header = birth_data[0].split('\t')

    #birth_data = [[float(x) for x in y.split('\t') if len(x)>=1] for y in birth_data[1:] if len(y)>=1]

    #with open(birth_weight_file, 'w', newline='') as f:

     #   writer = csv.writer(f)

      #  writer.writerow(birth_header)

       # writer.writerows(birth_data)

        #f.close()

# read birth weight data into memory

birth_data = []

with open(birth_weight_file, newline='') as csvfile:

     csv_reader = csv.reader(csvfile)

     birth_header = next(csv_reader)

     for row in csv_reader:

         birth_data.append(row)

birth_data = [[float(x) for x in row] for row in birth_data]

# Pull out target variable

y_vals = np.array([x[0] for x in birth_data])

# Pull out predictor variables (not id, not target, and not birthweight)

x_vals = np.array([x[1:8] for x in birth_data])

# set for reproducible results

seed = 99

np.random.seed(seed)

#tf.set_random_seed(seed)

# Split data into train/test = 80%/20%

train_indices = np.random.choice(len(x_vals), round(len(x_vals)*0.8), replace=False)

test_indices = np.array(list(set(range(len(x_vals))) - set(train_indices)))

x_vals_train = x_vals[train_indices]

x_vals_test = x_vals[test_indices]

y_vals_train = y_vals[train_indices]

y_vals_test = y_vals[test_indices]

# Normalize by column (min-max norm)

def normalize_cols(m, col_min=np.array([None]), col_max=np.array([None])):

    if not col_min[0]:

        col_min = m.min(axis=0)

    if not col_max[0]:

        col_max = m.max(axis=0)

    return (m-col_min) / (col_max - col_min), col_min, col_max

   

x_vals_train, train_min, train_max = np.nan_to_num(normalize_cols(x_vals_train))

x_vals_test, _, _ = np.nan_to_num(normalize_cols(x_vals_test, train_min, train_max))

def model(x,w,b):

    # Declare model operations

    model_output = tf.add(tf.matmul(x, w), b)

    return model_output

def loss1(x,y,w,b):

    # Declare Deming loss function

    loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=model(x,w,b), labels=y))

    return loss

def grad1(x,y,w,b):

    with tf.GradientTape() as tape:

        loss_1 = loss1(x,y,w,b)

return tape.gradient(loss_1,[w,b])

# Declare batch size

# Declare batch size

batch_size = 25

learning_rate = 0.25 # Will not converge with learning rate at 0.4

iterations = 50

# Create variables for linear regression

w1 = tf.Variable(tf.random.normal(shape=[7,1]),tf.float32)

b1 = tf.Variable(tf.random.normal(shape=[1,1]),tf.float32)

optimizer = tf.optimizers.Adam(learning_rate)

# Training loop

# Training loop

loss_vec = []

train_acc = []

test_acc = []

for i in range(5000):

    rand_index = np.random.choice(len(x_vals_train), size=batch_size)

    rand_x = x_vals_train[rand_index]

    rand_y = np.transpose([y_vals_train[rand_index]])

    x=tf.cast(rand_x,tf.float32)

    y=tf.cast(rand_y,tf.float32)

    grads1=grad1(x,y,w1,b1)

    optimizer.apply_gradients(zip(grads1,[w1,b1]))

    #sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

    temp_loss1 = loss1(x, y,w1,b1).numpy()

    #sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})

    loss_vec.append(temp_loss1)

   

    # Actual Prediction

    #prediction = tf.round(tf.sigmoid(model_output))

    #predictions_correct = tf.cast(tf.equal(prediction, y_target), tf.float32)

    #accuracy = tf.reduce_mean(predictions_correct)

   

   

    prediction1 = tf.round(tf.sigmoid(model(tf.cast(x_vals_train,tf.float32),w1,b1)))

    predictions_correct1 = tf.cast(tf.equal(prediction1, tf.cast(np.transpose([y_vals_train]),tf.float32)), tf.float32)

    temp_acc_train = tf.reduce_mean(predictions_correct1)

    train_acc.append(temp_acc_train)

   

    prediction2 = tf.round(tf.sigmoid(model(tf.cast(x_vals_test,tf.float32),w1,b1)))

    predictions_correct2 = tf.cast(tf.equal(prediction2, tf.cast(np.transpose([y_vals_test]),tf.float32)), tf.float32)

    temp_acc_test=tf.reduce_mean(predictions_correct2)

    test_acc.append(temp_acc_test)

   

    if (i+1)%25==0:

        print('Step #' + str(i+1) + ' A = ' + str(w1.numpy()) + ' b = ' + str(b1.numpy()))

        print('Loss = ' + str(temp_loss1))

%matplotlib inline

# Plot loss over time

plt.plot(loss_vec, 'k-')

plt.title('Cross Entropy Loss per Generation')

plt.xlabel('Generation')

plt.ylabel('Cross Entropy Loss')

plt.show()

# Plot train and test accuracy

plt.plot(train_acc, 'k-', label='Train Set Accuracy')

plt.plot(test_acc, 'r--', label='Test Set Accuracy')

plt.title('Train and Test Accuracy')

plt.xlabel('Generation')

plt.ylabel('Accuracy')

plt.legend(loc='lower right')

plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/89202.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/89202.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

mingw 编译 assimp v6.0.2 解决编译报错

mingw 编译 assimp v6.0.2 理论上看这个就能满足:在Windows下使用CMakeMinGW64编译Assimp库 环境变量问题 i386 architecture of input file CMakeFiles\assimp.dir/objects.a(assimp.rc.obj)’ is incompatible with i386:x86-64 output collect2.exe: error: ld r…

Windows 11清理C盘方法大全:磁盘清理/禁用休眠/系统还原点/优化大师使用教程

Windows 11清理C盘方法1. 使用磁盘清理工具步骤:按 Win S 搜索“磁盘清理”,打开工具。选择C盘,点击“确定”。勾选需要清理的文件类型(如临时文件、系统错误内存转储等),点击“确定”。确认删除操作&…

Rabbitmq Direct Exchange(直连交换机)多个消费者,配置相同的key ,队列,可以保证只有一个消费者消费吗

思考可以保证消费不被重复消费,因为通过轮询一个消息只会投递给一个消费者。但是不是一个消费者消费,而是多个轮询消费在 RabbitMQ 中,如果多个消费者(Consumers)同时订阅 同一个队列(Queue)&am…

设计模式是什么呢?

1.掌握设计模式的层次第一层:刚刚学编程不久,听说过什么是设计模式。第二层:有很长时间的编程经验,自己写过很多代码,其中用到了设计模式,但是自己不知道。第三层:学习过设计模式,发…

ThreadLocal使用详解-从源码层面分析

从demo入手看效果 代码Demostatic ThreadLocal tl1 new ThreadLocal();static ThreadLocal tl2 new ThreadLocal();static ThreadLocal tl3 new ThreadLocal();public static void main(String[] args) {tl1.set("123");tl2.set("456");tl3.set("4…

CPO:对比偏好优化—突破大型语言模型在机器翻译中的性能边界

温馨提示: 本篇文章已同步至"AI专题精讲" CPO:对比偏好优化—突破大型语言模型在机器翻译中的性能边界 摘要 中等规模的大型语言模型(LLMs),如参数量为 7B 或 13B 的模型,在机器翻译&#xff0…

执行shell 脚本 如何将日志全部输出到文件

在执行 Shell 脚本时,如果需要将 所有输出(包括标准输出 stdout 和错误输出 stderr) 重定向到日志文件,可以使用以下方法:方法 1:直接重定向(推荐) /appdata/mysql_backup_dump.sh &…

Postman接口测试实现UI自动化测试

Selenium底层原理 3天精通Postman接口测试,全套项目实战教程!!运行代码,启动浏览器后,webdriver会将浏览器绑定到特定的端口,作为webdriver的remote server(远程服务端),…

CSS动画与变换全解析:从原理到性能优化的深度指南

引言:现代Web动画的技术革命 在当今的Web体验中,流畅的动画效果已成为用户交互的核心要素。根据Google的研究,60fps的动画可以使用户参与度提升53%,而卡顿的界面会导致跳出率增加40%。本文将深入剖析CSS动画(animation…

NPM组件 @ivy-shared-components/iconslibrary 等窃取主机敏感信息

【高危】NPM组件 ivy-shared-components/iconslibrary 等窃取主机敏感信息 漏洞描述 当用户安装受影响版本的 ivy-shared-components/iconslibrary 等NPM组件包时会窃取用户的主机名、用户名、工作目录、IP地址等信息并发送到攻击者可控的服务器地址。 MPS编号MPS-zh19-e78w…

Fail2ban防止暴力破解工具使用教程

Fail2ban防止暴力破解工具使用教程场景Fail2ban安装和配置安装配置原理遇到的问题以及解决办法问题1:设置的策略是10分钟内ssh连接失败2次的ip进行封禁,日志中实际却出现4次连接。问题2:策略设置为1分钟内失败两次,封禁ip。但通过…

亚远景科技助力长城汽车,开启智能研发新征程

亚远景科技助力长城汽车,开启智能研发新征程在汽车智能化飞速发展的当下,软件研发管理成为车企决胜未来的关键。近日,亚远景科技胡浩老师应邀为长城汽车开展了一场主题深刻且极具实用价值的培训。本次培训聚焦软件研发管理导论 - 建立机器学习…

图算法在前端的复杂交互

引言 图算法是处理复杂关系和交互的强大工具,在前端开发中有着广泛应用。从社交网络的推荐系统到流程图编辑器的路径优化,再到权限依赖的拓扑排序,图算法能够高效解决数据之间的复杂关联问题。随着 Web 应用交互复杂度的增加,如实…

Prometheus Operator:Kubernetes 监控自动化实践

在云原生时代,Kubernetes 已成为容器编排的事实标准。然而,在高度动态的 Kubernetes 环境中,传统的监控工具往往难以跟上服务的快速变化。Prometheus Operator 应运而生,它将 Prometheus 及其生态系统与 Kubernetes 深度融合&…

一种融合人工智能与图像处理的发票OCR技术,将人力从繁琐的票据处理中解放

在数字化浪潮席卷全球的今天,发票OCR技术正悄然改变着企业财务流程的运作模式。这项融合了人工智能与图像处理的前沿技术,已成为财务自动化不可或缺的核心引擎。核心技术:OCR驱动的智能识别引擎发票OCR技术的核心在于光学字符识别&#xff08…

时空大数据:数字时代的“时空罗盘“

引言:为何需要“时空大数据”?“大数据”早已成为热词,但“时空大数据”的提出却暗含深刻逻辑。中国工程院王家耀院士指出,早期社会存在三大认知局限:过度关注商业大数据而忽视科学决策需求;忽视数据的时空…

PySide笔记之信号连接信号

PySide笔记之信号连接信号code review! 在 PySide6(以及 Qt 的其他绑定,如 PyQt)中,信号可以连接到信号。也就是说,可以把一个信号的发射,作为另一个信号的触发条件。这样做的效果是:当第一个信…

Linux操作系统之线程:线程概念

目录 前言: 一、进程与线程 二、线程初体验 三、分页式存储管理初谈 总结: 前言: 大家好啊,今天我们就要开始翻阅我们linux操作系统的另外一座大山:线程了。 对于线程,大体结构上我们是划分为两部分…

windows利用wsl安装qemu

首先需要安装wsl,然后在swl中启动一个子系统。这里我启动一个ubuntu22.04。 接下来的操作全部为在子系统中的操作。 检查虚拟化 在开始安装之前,让我们检查一下你的机器是否支持虚拟化。 要做到这一点,请使用以下命令: sean@DESKTOP-PPNPJJ3:~$ LC_ALL=C lscpu | grep …

如何使用 OpenCV 打开指定摄像头

在计算机视觉应用中,经常需要从特定的摄像头设备获取视频流。例如,在多摄像头环境中,当使用 OpenCV 的 cv::VideoCapture 类打开摄像头时,如果不指定摄像头的 ID,可能会随机打开系统中的某个摄像头,或者按照…