tensorflow实现手写识别

前言

看过一段时间tensorflow的教程,看代码能知个大概,也算入了门。当时学的时候发现最蛋疼的就是如何调输入数据的格式了。自学能力差到一定的极致,就是看了几天教程这一个问题还是没得到解决。但是我还是对A1课程上的手写识别做了简单的训练预测。代表我入门过。(仅仅是入门,代码简单大神再见)

简介

使用 TensorFlow, 你必须明白 TensorFlow:

  • 使用图 (graph) 来表示计算任务.
  • 在被称之为 会话 (Session) 的上下文 (context) 中执行图.
  • 使用 tensor 表示数据.
  • 通过 变量 (Variable) 维护状态.
  • 使用 feed 和 fetch 可以为任意的操作(arbitrary operation) 赋值或者从其中获取数据.

从我个人理解出发,tensorflow的框架使用也大致分为三个部分:网络结构定义,数据格式调整及输入,数据预测。python只是tensorflow的一个数据输入接口,实际做训练预测都是通过C语言来实现的,这是为了提高速度和效率。接下来就分这三个模块介绍下我的代码。

网络结构定义

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
x = tf.placeholder(tf.float32, [None, 400]) # 图像输入向量
W = tf.Variable(tf.zeros([400, 10])) # 权重,初始化值为全零
b = tf.Variable(tf.zeros([10])) # 偏置,初始化值为全零
# 进行模型计算,y是预测,y_ 是实际
y = tf.nn.softmax(tf.matmul(x, W) + b)
y_ = tf.placeholder("float", [None, 10])
# 计算交叉熵
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
# 接下来使用BP算法来进行微调,以0.01的学习速率
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
# 上面设置好了模型,添加初始化创建变量的操作
init = tf.initialize_all_variables()
# 启动创建的模型,并初始化变量
sess = tf.Session()
sess.run(init)

数据格式调整及输入

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# 从老师提供的csv数据文件中读取数据
data_matrix = np.loadtxt(open('data.csv', 'rb'), delimiter=',')
data_labels = np.loadtxt(open('dataLabels.csv', 'rb'))
data_indice = range(5000)
random.shuffle(data_indice)
train_data = []
test_data = []
labels = []
train_labels = []
test_labels = []
# 需要将数据转换成one hot向量形式
for i in range(data_labels.size):
labels.append([0] * 10)
labels[i][int(data_labels[i])] = 1
count = 0
for i in data_indice:
if (count < 4500):
train_data.append(data_matrix[i])
train_labels.append(labels[i])
else:
test_data.append(data_matrix[i])
test_labels.append(labels[i])
count += 1

数据预测

代码如下:

1
2
3
4
5
6
7
8
9
print ("train beginning...")
sess.run(train_step, feed_dict={x: train_data, y_: train_labels})
''' 进行模型评估 '''
# 判断预测标签和实际标签是否匹配
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
# 计算所学习到的模型在测试数据集上面的正确率
print(sess.run(accuracy, feed_dict={x: test_data, y_: test_labels}))
print ("ok")

结果显示

简单网络手写识别结果
因为使用了最简单的网络,重点在于入门,所以准确率很低,才61%,也无伤大雅吧(尴尬的笑)。