用TensorFlow进行逻辑回归(六)

import tensorflow as tf

import numpy as np

from tensorflow.keras.datasets import mnist

import time

# MNIST数据集参数

num_classes = 10  # 数字0到9, 10类

num_features = 784  # 28*28

# 训练参数

learning_rate = 0.01

training_steps = 1000

batch_size = 256

display_step =50

# 预处理数据集

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 转为float32

x_train, x_test = np.array(x_train, np.float32), np.array(x_test, np.float32)

# 转为一维向量

x_train, x_test = x_train.reshape([-1, num_features]), x_test.reshape([-1, num_features])

# [0, 255] 到 [0, 1]

x_train, x_test = x_train / 255, x_test / 255

# tf.data.Dataset.from_tensor_slices 是使用x_train, y_train构建数据集

train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))

# 将数据集打乱,并设置batch_size大小

train_data = train_data.repeat().shuffle(5000).batch(batch_size).prefetch(1)

# 权重[748, 10],图片大小28*28,类数

W = tf.Variable(tf.ones([num_features, num_classes]), name="weight")

# 偏置[10],共10类

b = tf.Variable(tf.zeros([num_classes]), name="bias")

# 逻辑回归函数

def logistic_regression(x):

    return tf.nn.softmax(tf.matmul(x, W) + b)

# 损失函数

def cross_entropy(y_pred, y_true):

    # tf.one_hot()函数的作用是将一个值化为一个概率分布的向量

    y_true = tf.one_hot(y_true, depth=num_classes)

    # tf.clip_by_value将y_pred的值控制在1e-9和1.0之间

    y_pred = tf.clip_by_value(y_pred, 1e-9, 1.0)

    return tf.reduce_mean(-tf.reduce_sum(y_true * tf.math.log(y_pred)))

# 计算精度

def accuracy(y_pred, y_true):

    # tf.cast作用是类型转换

    correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.cast(y_true, tf.int64))

    return tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# 优化器采用随机梯度下降

optimizer = tf.optimizers.SGD(learning_rate)

# 梯度下降

def run_optimization(x, y):

    with tf.GradientTape() as g:

        pred = logistic_regression(x)

        loss = cross_entropy(pred, y)

    # 计算梯度

    gradients = g.gradient(loss, [W, b])

    # 更新梯度

    optimizer.apply_gradients(zip(gradients, [W, b]))

# 开始训练

start = time.perf_counter()

for epoch in range(5):

    for step, (batch_x, batch_y) in enumerate(train_data.take(training_steps), 1):

        run_optimization(batch_x, batch_y)

        if step % display_step == 0:

            pred = logistic_regression(batch_x)

            loss = cross_entropy(pred, batch_y)

            acc = accuracy(pred, batch_y)

            print("step: %i, loss: %f, accuracy: %f" % (step, loss, acc))

   

# 测试模型的准确率

pred = logistic_regression(x_test)

print("Test Accuracy: %f" % accuracy(pred, y_test))

elapsed = (time.perf_counter() - start)

print("Time used:",elapsed)

例3

import  matplotlib.pyplot as plt

import  numpy as np

import tensorflow as tf

print(tf.__version__)

%matplotlib inline

mnist = tf.keras.datasets.mnist

(train_images,train_labels),(test_images,test_labels)=mnist.load_data()

total_num=len(train_images)

valid_split=0.2

train_num =int(total_num*(1-valid_split))

train_x=train_images[:train_num]

train_y=train_labels[:train_num]

valid_x=train_images[train_num:]

valid_y=train_labels[train_num:]

test_x=test_images

test_y=test_labels

train_x=train_x.reshape(-1,784)

valid_x=valid_x.reshape(-1,784)

test_x=test_x.reshape(-1,784)

train_x=tf.cast(train_x/255.0,tf.float32)

valid_x=tf.cast(valid_x/255.0,tf.float32)

test_x=tf.cast(test_x/255.0,tf.float32)

train_y=tf.one_hot(train_y,depth=10)

valid_y=tf.one_hot(valid_y,depth=10)

test_y=tf.one_hot(test_y,depth=10)

#定义模型函数

def model(x,w,b):

    pred=tf.matmul(x,w)+b

    return tf.nn.softmax(pred)

np.random.seed(612)

W = tf.Variable(np.random.randn(784,10),dtype=tf.float32)

B = tf.Variable(np.random.randn(10),dtype=tf.float32)

def loss(x,y,w,b):

    pred = model(x,w,b)

    loss_=tf.keras.losses.categorical_crossentropy(y_true=y,y_pred=pred)

    return tf.reduce_mean(loss_)

#设置迭代次数和学习率

train_epochs = 100

batch_size=50

learning_rate = 0.001

def grad(x,y,w,b):

    with tf.GradientTape() as tape:

        loss_ = loss(x,y,w,b)

    return tape.gradient(loss_,[w,b])

optimizer= tf.keras.optimizers.Adam(learning_rate=learning_rate)

def accuracy(x,y,w,b):

     pred = model(x,w,b)

     correct_prediction=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))

     return tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

#构建线性函数的斜率和截距

total_step=int(train_num/batch_size)

loss_list_train = []

loss_list_valid = []

acc_list_train = []

acc_valid_train = []

training_epochs=100

#开始训练,轮数为epoch,采用SGD随机梯度下降优化方法

for epoch in range(training_epochs):

    for step in range(total_step):

        xs=train_x[step*batch_size:(step+1)*batch_size]

        ys=train_y[step*batch_size:(step+1)*batch_size]

        #计算损失,并保存本次损失计算结果

        grads=grad(xs,ys,W,B)

        optimizer.apply_gradients(zip(grads,[W,B]))

    loss_train =loss(train_x,train_y,W,B).numpy()

    loss_valid =loss(valid_x,valid_y,W,B).numpy()

    acc_train=accuracy(train_x,train_y,W,B).numpy()

    acc_valid=accuracy(valid_x,valid_y,W,B).numpy()

    loss_list_train.append(loss_train)

    loss_list_valid.append(loss_valid)

    acc_list_train.append(acc_train)

    acc_valid_train.append(acc_valid)

print("epoch={:3d},train_loss={:.4f},train_acc={:.4f},val_loss={:.4f},val_acc={:.4f}".format(epoch+1,loss_train,acc_train,loss_valid,acc_valid))

plt.xlabel("Epochs")

plt.ylabel("Loss")

plt.plot(loss_list_train,'blue',label="Train Loss")

plt.plot(loss_list_valid,'red',label="Valid Loss")

plt.xlabel("Epochs")

plt.ylabel("Accuracy")

plt.plot(acc_list_train,'blue',label="Train Acc")

plt.plot(acc_valid_train,'red',label="Valid Acc")

acc_test=accuracy(test_x,test_y,W,B).numpy

print("Test accuracy:",acc_test)

def predict(x,w,b):

    pred=model(x,w,b)

    result=tf.argmax(pred,1).numpy

return result

pred_test=predict(test_x,W,B)

pred_test

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/91293.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/91293.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【HTTP版本演变】

在浏览器中输入URL并按回车之后会发生什么1. 输入URL并解析输入URL后,浏览器会解析出协议、主机、端口、路径等信息,并构造一个HTTP请求(浏览器会根据请求头判断是否又HTTP缓存,并根据是否有缓存决定从服务器获取资源还是使用缓存…

Android 16系统源码_窗口动画(一)窗口过渡动画层级图分析

一 窗口过渡动画 1.1 案例效果图1.2 案例源码 1.2.1 添加权限 (AndroidManifest.xml) <!-- 系统悬浮窗权限&#xff08;Android 6.0需动态请求&#xff09; --> <uses-permission android:name"android.permission.SYSTEM_ALERT_WINDOW" />1.2.2 窗口显示…

腾讯云WAF域名分级防护实战笔记

基于业务风险等级、合规要求及腾讯云最佳实践&#xff0c;提供可直接落地的配置方案&#xff0c;供学习借鉴&#xff1a;一、域名分级与防护原则1. ​域名分级清单&#xff08;核心资产&#xff09;​​​主域名​​业务类型​​风险等级​​合规要求​​防护等级​example.com…

1. 请说出你知道的水平垂直居中的方法

总结 容器 flex 布局&#xff0c;jsutify-content: center; align-items: center;容器 flex 布局&#xff0c;子项 margin: auto;容器 relative 布局&#xff0c;子项 absolute 布局&#xff0c;left: 50%; top: 50%; transform: translate(-50%, -50%);子项 absolute 布局&…

VS Code `launch.json` 完整配置指南:参数详解 + 配置实例

文章目录&#x1f4e6; 一、基本结构&#x1f50d; 二、单个配置项详解示例配置&#xff1a;&#x1f9e9; 三、字段说明与可选值&#x1f4c1; 四、常用变量&#xff08;宏替换&#xff09;&#x1f6e0;️ 五、常见配置实例1️⃣ 调试当前打开的 .py 文件2️⃣ 调试 Jupyter …

使用浏览器inspect调试wx小程序

edge://inspect/#devices调试wx小程序 背景&#xff1a; 在开发混合项目的过程中&#xff0c;常常需要在app环境排查问题&#xff0c;接口可以使用fiddler等工具来抓包&#xff0c;但是js错误就不好抓包了&#xff0c;这里介绍一种调试工具-浏览器。 调试过程 首先电脑打开edg…

【论文阅读】-《Simple Black-box Adversarial Attacks》

简单黑盒对抗攻击 Chuan Guo Jacob R. Gardner Yurong You Andrew Gordon Wilson Kilian Q. Weinberger 摘要 我们提出了一种在黑盒&#xff08;black-box&#xff09;场景下构建对抗样本&#xff08;adversarial images&#xff09;的极其简单的方法。与白盒&#xff08;…

基于ASP.NET+SQL Server实现(Web)企业进销存管理系统

企业进销存管理系统的设计和实现一、摘要进销存管理是现代企业生产经营中的重要环节&#xff0c;是完成企业资源配置的重要管理工作&#xff0c;对企业生产经营效率的最大化发挥着重要作用。本文以我国中小企业的进销存管理为研究对象&#xff0c;描述了企业进销存管理系统从需…

(LeetCode 面试经典 150 题 ) 15. 三数之和 (排序+双指针)

题目&#xff1a;15. 三数之和 思路&#xff1a;排序双指针&#xff0c;时间复杂度0(n^2nlogn)。 先将数组nums升序排序&#xff0c;方便去重和使用双指针。第一层for循环来枚举第一位数&#xff0c;后面使用双指针来找到第二个、第三个数即可&#xff0c;细节看注释。 C版本…

easy-springdoc

介绍 简化springdoc的使用&#xff08;可以搭配knife4j-openapi3-jakarta-spring-boot-starter一起使用&#xff09; maven引用 <dependency><groupId>io.github.xiaoyudeguang</groupId><artifactId>easy-springdoc</artifactId><version>…

配置nodejs,若依

1.配置node.js环境 Node.js — Download Node.js 1.下载好一路下一步&#xff0c;可以安装到d盘 装完之后执行 npm -v 显示版本号即安装成功 2.安装好后新建两个文件夹&#xff0c;node_cache和node_global 3.配置环境变量 新建变量 在path里编辑变量 4.配置用户变量 5.…

Python学习之路(十二)-开发和优化处理大数据量接口

文章目录一、接口设计原则二、性能优化策略1. 数据库优化2. 缓存机制3. 并发模型三、内存管理技巧1. 内存优化实践2. 避免内存泄漏四、接口测试与监控1. 性能测试2. 日志与监控3. 错误处理与限流五、代码示例&#xff08;Flask 流式处理&#xff09;六、部署建议一、接口设计原…

【实时Linux实战系列】实时数据流的网络传输

在实时系统中&#xff0c;数据流的实时传输是许多应用场景的核心需求之一。无论是工业自动化中的传感器数据、金融交易中的高频数据&#xff0c;还是多媒体应用中的视频流&#xff0c;都需要在严格的时间约束内完成数据的传输。实时数据流的传输不仅要求高吞吐量&#xff0c;还…

C#数组(一维数组、多维数组、交错数组、参数数组)

在 C# 中&#xff0c;数组是一种用于存储固定大小的相同类型元素的集合。数组可以包含值类型、引用类型或对象类型的元素&#xff0c;并且在内存中是连续存储的。以下是关于 C# 数组的详细介绍&#xff1a;1. 一维数组声明与初始化// 声明数组 int[] numbers; // 声…

Dify离线安装包-集成全部插件、模板和依赖组件,方便安可内网使用

项目介绍 Dify一键离线安装包&#xff0c;集成安装了全部插件、模板&#xff0c;并集成了dify全部插件所需的依赖组件。方便你在内网、安可环境等离线状态下使用。 Dify是一个开源的LLM应用开发平台。其直观的界面结合了AI工作流、RAG管道、Agent、模型管理、可观测性功能等&…

面试150 翻转二叉树

思路 采用先序遍历&#xff0c;可以通过新建根节点node&#xff0c;将原来root的右子树连到去node的左子树中&#xff0c;root的左子树连到去node的右子树中。 # Definition for a binary tree node. # class TreeNode: # def __init__(self, val0, leftNone, rightNone): …

C++-linux系统编程 3.gcc编译工具

GCC编译工具链完全指南 GCC&#xff08;GNU Compiler Collection&#xff09;是Linux系统下最常用的编译器套件&#xff0c;支持C、C、Objective-C等多种编程语言。本章将深入讲解GCC的编译流程、常用选项及项目实战技巧。 一、GCC编译的四个核心阶段 GCC编译一个程序需要经过四…

uView UI 组件大全

uView UI 是一个基于 uni-app 的高质量 UI 组件库&#xff0c;提供丰富的跨平台组件&#xff08;支持 H5、小程序、App 等&#xff09;。以下是其核心组件的分类大全及功能说明&#xff0c;结合最新版本&#xff08;1.2.10&#xff09;整理&#xff1a; &#x1f4e6; 一、基础…

QWidget 和 QML 的本质和使用上的区别

QWidget 和 QML 是 Qt 框架中两种不同的 UI 开发技术&#xff0c;它们在底层实现、设计理念和使用场景上有显著区别。以下是它们的本质和主要差异&#xff1a;1. 本质区别特性QWidgetQML (Qt Modeling Language)技术基础基于 C 的面向对象控件库基于声明式语言&#xff08;类似…

中转模型服务的风险

最近发现一些 AI 相关帖子下&#xff0c;存在低质 claude code 中转的小广告。 其中转的基本原理就是 claude code 允许自己提供 API endpoint 和 key&#xff0c;可以使用任意一个 OpenAI API 兼容的供应商&#xff0c;就这么简单。 进一点 claude token&#xff0c;再混入一点…