教程 | 基于LSTM实现手写数字识别

news/2024/7/4 3:40:01

点击上方“小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

基于tensorflow,如何实现一个简单的循环神经网络,完成手写数字识别,附完整演示代码。

01 LSTM网络构建

基于tensorflow实现简单的LSTM网络,完成mnist手写数字数据集训练与识别。这个其中最重要的构建一个LSTM网络,tensorflow已经给我们提供相关的API, 我们只要使用相关API就可以轻松构建一个简单的LSTM网络。

首先定义输入与目标标签

# create RNN network
X = tf.placeholder(shape=[None, time_steps, num_features], dtype=tf.float32)
Y = tf.placeholder(shape=[None, 10], dtype=tf.float32)

其中

  • None: 表示batchsize的大小或者数目

  • time_steps: 网络把输出重新输入的次数

  • num_features: 输入矩阵/神经元

构建LSTM单元

lstm_cell = rnn.BasicLSTMCell(num_hidden, forget_bias=1.0)
outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)

其中:

lstm_cell 表示 LSTM 的单元
num_hidden : 隐藏层节点数目
forget_bias: 遗忘门中要加上的增益偏置

outputs: 网络输出

states:状态

这样我们就构建好一个LSTM循环神经网络了,它的执行过程是很魔幻的。简直是神奇!以后再说。

02 代码程序执行与输出

完整的代码演示分为如下几个部分:

  • 加载数据集

  • 创建LSTM网络

  • 训练网络

  • 执行测试

import tensorflow as tf
from tensorflow.contrib import rnn
import numpy as npfrom tensorflow.examples.tutorials.mnist import input_data
print(tf.__version__)
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)num_hidden = 128
time_steps = 28
num_features = 28
num_classes = 10
batch_size = 128# create RNN network
X = tf.placeholder(shape=[None, time_steps, num_features], dtype=tf.float32)
Y = tf.placeholder(shape=[None, 10], dtype=tf.float32)# Define weights
weights = {'out': tf.Variable(tf.random_normal([num_hidden, num_classes]))
}
biases = {'out': tf.Variable(tf.random_normal([num_classes]))
}def rnn_network(x, weights, biases):x = tf.unstack(x, time_steps, 1)lstm_cell = rnn.BasicLSTMCell(num_hidden, forget_bias=1.0)outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)return tf.matmul(outputs[-1], weights['out']) + biases['out']# 输入预测
logits = rnn_network(X, weights, biases)
prediction = tf.nn.softmax(logits)# 定义损失函数与优化器
loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y))
optimizer = tf.train.AdamOptimizer()
train_op = optimizer.minimize(loss_op)# 计算识别精度
correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))# 开始训练
with tf.Session() as sess:sess.run(tf.global_variables_initializer())for step in range(1, 5001):batch_x, batch_y = mnist.train.next_batch(batch_size)# Reshape data to get 28 seq of 28 elementsbatch_x = batch_x.reshape((batch_size, time_steps, num_features))# Run optimization op (backprop)sess.run(train_op, feed_dict={X: batch_x, Y: batch_y})if step % 1000 == 0 or step == 1:# Calculate batch loss and accuracyloss, acc = sess.run([loss_op, accuracy], feed_dict={X: batch_x,Y: batch_y})print("Step " + str(step) + ", Loss= " + \"{:.4f}".format(loss) + ", Training Accuracy= " + \"{:.3f}".format(acc))print("Optimization Finished!")# 使用测试数据集测试训练号的模型, 测试128张手写数字图像test_len = 128test_data = mnist.test.images[:test_len].reshape((-1, time_steps, num_features))test_label = mnist.test.labels[:test_len]print("Testing Accuracy:", \sess.run(accuracy, feed_dict={X: test_data, Y: test_label}))

运行输出如下:

4dcbcaaccecca62a34851d394491b4a6.png

下载1:OpenCV-Contrib扩展模块中文版教程

在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲

在「小白学视觉」公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲

在「小白学视觉」公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群

欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~

3700a4677c996e2d6eb117b13db3f440.png

c251f820a4e26d91cd3d5ca6e7dabf5d.png


http://lihuaxi.xjx100.cn/news/269849.html

相关文章

C++11 带来的新特性 (2)—— 统一初始化(Uniform Initialization)

1 统一初始化(Uniform Initialization) 在C 11之前,所有对象的初始化方式是不同的,经常让写代码的我们感到困惑。C 11努力创造一个统一的初始化方式。 其语法是使用{}和std::initializer_list,先看示例。 int values[]…

《大话西游》20年后重映(附影评:《大话西游》你真的看懂了吗?)

2014-10-25 02:43:24 来源: 北京日报(北京)本报讯 (记者 周南焱)“电影里的台词差点儿都能背,但在影院里再看还是会笑。看到最后紫霞仙子死的时候。还是忍不住落泪!”昨天下午,经典老片《大话西游》在海航活力天宝影城…

php libev pthreads,libuv 与 libev 的对比

05 January 2013libuv和libev,两个名字相当相近的 I/O Library,最近有幸用两个 Library 都写了一些东西,下面就来说一说我本人对两者共同与不同点的主观表述。高性能网络编程这个话题已经被讨论烂了。异步,异步,还是异…

跨平台网络游戏趋势和优势

跨平台网络游戏趋势和优势 前几年还是网页游戏蓬勃发展的状态,就有分析指出从明年开始网页游戏市场已经饱和,想想几年前客户端游戏也是同样的窘境,如果将桌面、移动设备、网页统称一个词汇的话,那就是终端,现在各种的终…

计算摄影 | 计算机如何学会自动裁剪图片(自动构图)?

点击上方“小白学视觉”,选择加"星标"或“置顶”重磅干货,第一时间送达1 自动构图基础1.1 什么是构图自动裁剪用摄影的话语来说,就是自动构图。构图来源于绘画,最初指绘画时根据题材和主题思想的要求,把要表…

阿里云蒋江伟:我们致力于为世界提供70%的算力 | 凌云时刻

导读:6月9日,2020阿里云峰会在云端召开,阿里巴巴合伙人、阿里云智能基础产品事业部高级研究员蒋江伟出席峰会并做了题为《新基建,新算力:阿里云基础设施算力全新升级》的重磅发布。(以下内容为演讲实录&…

关于Redis缓存,这3个问题一定要知道!

点击上方蓝色“方志朋”,选择“设为星标”回复“666”获取独家整理的学习资料!来源:https://4m.cn/e3JwR最近都没看Redis,现在回来温习下,现在从Redis的三大缓存开始重新探一探有多深有多浅(^▽^)让我来开始知识的醍醐…

第十一周作业关于json

json文件的实例: json文件:{ "name":"王小二", "age":25.2, "birthday":"1990-01-01", "school":"蓝翔", "major(技能)":["理发",&q…