人工智能学习57-TF训练

人工智能学习概述—快手视频
人工智能学习57-TF训练—快手视频
人工智能学习58-TF训练—快手视频
人工智能学习59-TF预测—快手视频

训练示例代码

#导入keras.utils 工具包 
import keras.utils 
#导入mnist数据集 
from keras.datasets import mnist 
#引入tensorflow 类库 
import tensorflow.compat.v1 as tf 
#关闭tensorflow 版本2的功能,仅使用tensorflow版本1的功能 
tf.disable_v2_behavior() 
#引用numpy 处理矩阵操作 
import numpy as np 
#引用图形处理类库 
import matplotlib.pyplot as plt 
import matplotlib 
#引入操作系统类库,方便处理文件与目录 
import os 
#避免多库依赖警告信息 
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 
#设置tensorflow 训练模型所在目录 
model_path = '../log/model.ckpt' 
#设置神经网络分类数量,0-9个数字需要10个分类 
num_classes = 10 
#从数据集mnist装入训练数据集和测试数据集,mnist提供load_data方法 
(x_train, y_train), (x_test, y_test) = mnist.load_data() 
#灰度图编码范围0-255,将编码归一化,转化为0-1之间数值 
x_train = x_train.astype('float32') / 255.0 
x_test = x_test.astype('float32') / 255.0 
#将训练和测试标注数据转化为张量(batch,num_classes) 
y_train = keras.utils.to_categorical(y_train, num_classes) 
y_test = keras.utils.to_categorical(y_test, num_classes) 
#定义输入张量,分配地址 
x = tf.placeholder(tf.float32, [None, 784]) 
y = tf.placeholder(tf.float32, [None, 10]) 
#初始化第一层权重W矩阵 
w1 = tf.Variable(tf.random_normal([784, 128])) 
#初始化第一层偏置B向量 
b1 = tf.Variable(tf.zeros([128])) 
#定义第一层激活函数输入值 X*W+B 
hc1 = tf.add(tf.matmul(x,w1),b1) 
#定义第一层输出,调用激活函数sigmoid 
h1 = tf.sigmoid(hc1) 
#初始化第二层权重W矩阵 
w2 = tf.Variable(tf.random_normal([128, 10])) 
#初始化第二层偏置B向量 
b2 = tf.Variable(tf.zeros([10])) 
#使用激活函数softmax,预测第二层输出 
pred = tf.nn.softmax(tf.matmul(h1, w2) + b2) 
#使用交叉熵定义代价函数 
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1)) 
#定义学习率 
learn_rate = 0.01 
#使用梯度下降法优化网络 
optimizer = tf.train.GradientDescentOptimizer(learn_rate).minimize(cost) 
epoch_list = [] 
cost_list = [] train_epoch = 30 
batch_size = 100 
display_step = 1 
#定义神经网络模型保存对象 
saver = tf.train.Saver() 
#触发tensorflow初始化,为定义变量赋值 
init = tf.global_variables_initializer() 
#启动tensorflow会话 
with tf.Session() as sess: if os.path.exists('../log/model.ckpt.meta'): saver.restore(sess, model_path) #如果存在网络模型,在会话中装入网络模型 else: sess.run(init) #如果不存在网络模型,会话执行初始化工作 #循环遍历每次训练 for epoch in range(train_epoch): #定义平均损失 avg_cost = 0. #计算总批次 total_batch = int(x_train.shape[0] / batch_size) #循环每批次样本数据 for i in range(total_batch): #读取每批次训练样本数据 batch_xs = x_train[i*batch_size: (i+1)*batch_size] #读取每批次训练标签样本数据 batch_ys = y_train[i*batch_size: (i+1)*batch_size] #转化样本数据格式,添加第一维度代表样本数量 batch_xs = np.reshape(batch_xs, (100, -1)) #启动tensorflow进行样本数据训练 _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs, y: 
batch_ys}) #累计每批次的平均损失 avg_cost += c / total_batch #记录平均损失 epoch_list.append(epoch+1) cost_list.append(avg_cost) #每隔display_step次训练,输出一次统计信息 if (epoch+1) % display_step == 0: print('Epochs: ','%04d' % (epoch+1), 'Costs ', 
'{:.9f}'.format(avg_cost)) print('Train Finished') #保存tensorflow训练的模型 
save_dir = saver.save(sess, model_path) 
print('TensorFlow model save as file %s' % save_dir) 
#图形显示训练结果 
matplotlib.rcParams['font.family'] = 'SimHei' 
plt.plot(epoch_list, cost_list, '.') 
plt.title('Train Model') 
plt.xlabel('Epoch') 
plt.ylabel('Cost') 
plt.show() 

在这里插入图片描述
在这里插入图片描述

测试示例代码

#导入keras.utils 工具包 
import keras.utils 
#导入mnist数据集 
from keras.datasets import mnist 
#引入tensorflow 类库 
import tensorflow.compat.v1 as tf 
#关闭tensorflow 版本2的功能,仅使用tensorflow版本1的功能 
tf.disable_v2_behavior() 
#引用图形处理类库 
import matplotlib.pyplot as plt 
#引用numpy处理矩阵操作 
import numpy as np 
#引入操作系统类库,方便处理文件与目录 
import os 
#避免多库依赖警告信息 
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 
#设置tensorflow 训练模型所在目录 
model_path = '../log/model.ckpt' 
#设置神经网络分类数量,0-9个数字需要10个分类 
num_classes = 10 
#从数据集mnist装入训练数据集和测试数据集,mnist提供load_data方法 
(x_train, y_train), (x_test, y_test) = mnist.load_data() 
#灰度图编码范围0-255,将编码归一化,转化为0-1之间数值 
x_train = x_train.astype('float32') / 255.0 
x_test = x_test.astype('float32') / 255.0 
#将训练和测试标注数据转化为张量(batch,num_classes) 
y_train = keras.utils.to_categorical(y_train, num_classes) 
y_test = keras.utils.to_categorical(y_test, num_classes) 
#定义输入张量,分配地址 
x = tf.placeholder(tf.float32, [None, 784]) 
y = tf.placeholder(tf.float32, [None, 10]) 
#初始化第一层权重W矩阵 
w1 = tf.Variable(tf.random_normal([784, 128])) 
#初始化第一层偏置B向量 
b1 = tf.Variable(tf.zeros([128])) 
#定义第一层激活函数输入值 X*W+B 
hc1 = tf.add(tf.matmul(x,w1),b1) 
#定义第一层输出,调用激活函数sigmoid 
h1 = tf.sigmoid(hc1) 
#初始化第二层权重W矩阵 
w2 = tf.Variable(tf.random_normal([128, 10])) 
#初始化第二层偏置B向量 
b2 = tf.Variable(tf.zeros([10])) 
#使用激活函数softmax,预测第二层输出 
pred = tf.nn.softmax(tf.matmul(h1, w2) + b2) 
#测试样本集数量 
test_num = x_test.shape[0] 
#定义获取随机整数函数,返回0- test_num 
def rand_int(): 
rand = np.random.RandomState(None) 
return rand.randint(low=0, high=test_num) 
#定义神经网络模型保存对象 
saver = tf.train.Saver() 
n = rand_int() 
#启动tensorflow 会话 
with tf.Session() as sess: 
#在会话中装入网络模型 
saver.restore(sess, model_path) 
#读取测试样本数据 
batch_xs = x_test[n: n+2] 
#读取测试标签样本数据 
batch_ys = y_test[n: n+2] 
#转化样本数据格式,添加第一维度代表样本数量 
batch_xs = np.reshape(batch_xs, (2, -1)) 
#定义模型预测输出概率最大品类 
output = tf.argmax(pred, 1) 
#使用模型预算 
outputv, predv = sess.run([output, pred], feed_dict={x: batch_xs}) 
#图形输出 
plt.figure(figsize=(2, 3)) 
for i in range(batch_xs.ndim): 
plt.subplot(1, 2, i+1) 
plt.subplots_adjust(wspace=2) 
t = batch_xs[i].reshape(28, 28) 
plt.imshow(t, cmap='gray') 
if outputv[i] == batch_ys[i].argmax(): 
plt.title('%d,%d' 
% 
color='green') 
else: 
(outputv[i], 
batch_ys[i].argmax()), 
plt.title('%d,%d' % (outputv[i], batch_ys[i].argmax()), color='red') 
plt.xticks([]) 
plt.yticks([]) 
plt.show() 

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/84569.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/84569.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL(83)如何设置密码复杂度策略?

在 MySQL 中,可以通过配置密码策略来设置密码的复杂度要求。MySQL 提供了一些参数和插件来帮助管理员强制实施密码复杂度策略,确保数据库用户使用强密码。下面将详细介绍如何设置密码复杂度策略,并结合代码示例进行说明。 1. 使用 validate_…

如何使用postman做接口自动化测试?

🍅 点击文末小卡片,免费获取软件测试全套资料,资料在手,涨薪更快 本文适合已经掌握 Postman 基本用法的读者,即对接口相关概念有一定了解、已经会使用Postman 进行模拟请求等基本操作。 工作环境与版本: …

面试-操作系统

用户态和内核态的区别 内核态:在内核态下,CPU可以执行所有的指令和访问所有的硬件资源。 用户态:在用户态下,CPU只能执行部分指令集,无法直接访问硬件资源。 内核态的底层操作主要包括:内存管理、进程管理…

【基础算法】二分(二分查找 + 二分答案)

文章目录 一、二分查找1. 【案例】在排序数组中查找元素的第一个和最后一个位置 ⭐(1) 二分查找的引入(2) 解题细节(important)(3) 代码示例(4) 【模板】二分查找(5) STL 中的二分查找 2. 牛可乐和封印魔法 ⭐⭐(1) 解题思路(2) 代码实现 3. A-B 数对 ⭐…

多协议物联网关的方案测试-基于米尔全志T536开发板

本文将介绍基于米尔电子MYD-LT536开发板(米尔基于全志T536开发板)的多协议物联网关方案的开发测试。 摘自优秀创作者-ALSET 米尔基于全志T536开发板 为了充分的应用该开发板,结合T536处理器的特点,这里进一步的进行软件开发&…

echarts的还原,下载图片失效(空白图片,还原白屏)

echarts的toolbox.feature. restore 和toolbox.feature. saveAsImage 失效 也没有任何报错, 只需要修改: // chart.setOption(op); chart.setOption(op,true);

56-Oracle SQL Tuning Advisor(STA)

各位小伙伴,一般都用哪些优化工具,Oracle SQL Tuning Advisor (STA)用的多吗,Profile就是它的其中1个产物,下一期再弄Profile,STA 的核心功能是自动化诊断高负载SQL的性能瓶颈​(如全表扫描、缺失索引&…

修改element-plus的主题色css变量

提示:本文仅是记录我修改element-plus等组件库的css变量, 具体【实现主题色切换看这篇】即可 文章目录 1.文件划分2.src/style/index.scss入口文件3.src/style/theme.scss主题色切换维护4.src/style/_color-utils.scss动态生成element-plus的scss变量5.…

Vibe Coding - 进阶 Cursor Rules

文章目录 为什么要配置 .cursorrules使用 .cursorrules 的五大优势 如何创建与应用 .cursorrules✅ 基础步骤🛠 创建方式: 高质量 .cursorrules 文件,应包含以下内容配置示例Java 项目TypeScript React 项目总结 cursorrules 推荐网站 为什么…

腾讯云自动化助手(TAT)技术评估报告

摘要 腾讯云自动化助手(TAT)作为云服务器(CVM)与轻量应用服务器(Lighthouse)的原生运维工具,通过无密码批量命令执行(Shell/Python/PowerShell)、交互式会话管理及公共命…

【simulink】IEEE5节点系统潮流仿真模型(2机5节点全功能基础模型)

主要内容 该模型为simulink仿真模型,主要实现的内容如下: 模型是基于 Simulink 搭建的电力系统潮流计算仿真模型,围绕2 台发电机、5 个节点的拓扑结构构建,用于电力系统稳态分析,是电力系统研究、教学及工程实践中…

责任链模式详解

责任链模式 场景 顾名思义,责任链模式(Chain of Responsibility Pattern)为请求创建了一个接收者对象的链。这种模式给予请求的类型,对请求的发送者和接收者进行解耦。这种类型的设计模式属于行为型模式。 在这种模式中&#x…

Taro 跨端应用性能优化全攻略:从原理到实践

引言:为什么需要性能优化? 在当今移动互联网时代,用户体验已经成为决定产品成败的关键因素。根据 Google 的研究,页面加载时间每增加 1 秒,移动端转化率就会下降 20%。对于使用 Taro 开发的跨端应用来说,性…

Git集成Jenkins通过Pipeline方式实现一键部署

Docker方式部署Jenkins 部署自定义Docker网络 部署Docker网络的作用: 隔离性便于同一网络内容器相互通信 # 创建名为jenkins的docker网络 docker network create --subnet 172.18.0.0/16 --gateway 172.18.0.1 jenkins# 查看docker网络列表 docker network ls# …

磐基PaaS平台MongoDB组件SSPL许可证风险与合规性分析(下)

#作者:任少近 3.7.条款六:非源代码形式分发 官方原文如下: 原文关键部分:“You may not impose any further restrictions on the exercise of the rights granted or affirmed under this License.” 解读:“您不得…

桌面小屏幕实战课程:DesktopScreen 2 第一个工程

飞书文档http://https://x509p6c8to.feishu.cn/docx/doxcnkGhtbxcv8ge5wKFkunsgmm 一、创建工程 cd ~/esp cp -r esp-idf/examples/get-started/hello_world . cd ~/esp/hello_world//设置目标板卡相关 idf.py set-target esp32//可配置工程属性 idf.py menuconfig 工程源码…

华为云Flexus+DeepSeek征文|体验华为云ModelArts快速搭建Dify-LLM应用开发平台并搭建查询数据库的大模型工作流

华为云FlexusDeepSeek征文|体验华为云ModelArts快速搭建Dify-LLM应用开发平台并搭建查询数据库的大模型工作流 什么是华为云ModelArts 华为云ModelArts ModelArts是华为云提供的全流程AI开发平台,覆盖从数据准备到模型部署的全生命周期管理&#xff0c…

【深度学习】TensorFlow全面指南:从核心概念到工业级应用

TensorFlow全面指南:从核心概念到工业级应用 一、TensorFlow:人工智能时代的计算引擎1.1 核心特性与优势 二、安装与环境配置2.1 版本选择建议2.2 GPU支持关键组件 三、TensorFlow核心概念解析3.1 数据流图(Data Flow Graph)3.2 张量(Tensor)&#xff1a…

在VTK中捕捉体绘制图像进阶(同步操作)

0. 概要 这段代码实现了一个VTK(Visualization Toolkit)应用程序,主要功能是: 读取DICOM医学图像序列并进行体绘制(Volume Rendering)创建一个主窗口显示3D体绘制结果创建一个副窗口显示主窗口的2D截图将副窗口中的交互操作(如旋转、缩放等)转发到主窗口,而不影响副窗…

使用NPOI库导出多个Excel并压缩zip包

使用NPOI库导出Excel文件可以按照以下步骤进行: 添加NPOI库的引用:在项目中添加对NPOI库的引用。 创建一个新的Excel文件对象:使用NPOI中的HSSFWorkbook(对应.xls格式)或XSSFWorkbook(对应.xlsx格式&#…