本文参考了经典的LeNet-5卷积神经网络模型对mnist数据集进行训练。LeNet-5模型是大神Yann
LeCun于1998年在论文"Gradient-based learning applied to document
recognition"中提出来的,它是第一个成功应用于数字识别问题的卷积神经网络。下图展示了LeNet-5模型的架构。




文中所使用的卷积神经网络结构依次为输入层,卷积层1,池化层1,卷积层2,池化层2,全连接层1,全连接层2,输出层。
"""A very simple MNIST classifier. See extensive documentation at
https://www.tensorflow.org/get_started/mnist/beginners """ from __future__
import absolute_import from __future__ import division from __future__ import
print_function import argparse import sys from
tensorflow.examples.tutorials.mnist import input_data import numpy as np import
tensorflow as tf data_dir = './data/' mnist =
input_data.read_data_sets(data_dir, one_hot=True) #第一层卷积层尺寸和深度 CONV_1_SIZE = 3
CONV_1_DEEP = 32 INPUT_CHANNELS = 1 #输入通道数 #第二层卷积层尺寸和深度 CONV_2_SIZE = 3
CONV_2_DEEP = 64 #每批次数据集的大小 BATCH_SIZE = 100 #学习率 LEARNING_RATE_INIT = 1e-3
#学习率初始值 x = tf.placeholder(tf.float32, [None, 784]) y_ =
tf.placeholder(tf.float32, [None, 10]) #对输入向量x转换成图像矩阵形式 with
tf.variable_scope('reshape'): x_image = tf.reshape(x, [-1, 28, 28, 1])
#因为数据的条数未知,所以为-1 #卷积层1 with tf.variable_scope('conv1'): initial_value =
tf.truncated_normal([CONV_1_SIZE,CONV_1_SIZE,INPUT_CHANNELS,CONV_1_DEEP],
stddev=0.1) conv_1_w = tf.Variable(initial_value=initial_value,
collections=[tf.GraphKeys.GLOBAL_VARIABLES, 'WEIGHTS']) conv_1_b =
tf.Variable(initial_value=tf.constant(0.1, shape=[CONV_1_DEEP])) conv_1_l =
tf.nn.conv2d(x_image, conv_1_w, strides=[1,1,1,1], padding='SAME') + conv_1_b
conv_1_h = tf.nn.relu(conv_1_l) #池化层1 with tf.variable_scope('pool1'): pool_1_h
= tf.nn.max_pool(conv_1_h, ksize=[1,2,2,1], strides=[1,2,2,1], padding='VALID')
#卷积层2 with tf.variable_scope('conv2'): conv_2_w =
tf.Variable(tf.truncated_normal([CONV_2_SIZE,CONV_2_SIZE,CONV_1_DEEP,CONV_2_DEEP],
stddev=0.1), collections=[tf.GraphKeys.GLOBAL_VARIABLES, 'WEIGHTS']) conv_2_b =
tf.Variable(tf.constant(0.1, shape=[CONV_2_DEEP])) conv_2_l =
tf.nn.conv2d(pool_1_h, conv_2_w, strides=[1,1,1,1], padding='SAME') + conv_2_b
conv_2_h = tf.nn.relu(conv_2_l) #池化层2 with tf.name_scope('pool2'): pool_2_h =
tf.nn.max_pool(conv_2_h, ksize=[1,2,2,1], strides=[1,2,2,1], padding='VALID')
#全连接层1 with tf.name_scope('fc1'): # fc_1_w =
tf.Variable(tf.truncated_normal([7*7*64, 1024], stddev=0.1)) fc_1_b =
tf.Variable(tf.constant(0.1, shape=[1024]))
#全连接层的输入为向量,而池化层2的输出为7x7x64的矩阵,所以这里要将矩阵转化成一个向量 pool_2_h_flat =
tf.reshape(pool_2_h, [-1,7*7*64]) fc_1_h = tf.nn.relu(tf.matmul(pool_2_h_flat,
fc_1_w) + fc_1_b) #dropout在训练时会随机将部分节点的输出改为0,以避免过拟合问题,从而使得模型在测试数据上的效果更好
#dropout一般只在全连接层而不是卷积层或者池化层使用 with tf.name_scope('dropout'): keep_prob =
tf.placeholder(tf.float32) fc_1_h_drop = tf.nn.dropout(fc_1_h, keep_prob)
#全连接层2 And 输出层 with tf.name_scope('fc2'): fc_2_w =
tf.Variable(tf.truncated_normal([1024,10], stddev=0.1),
collections=[tf.GraphKeys.GLOBAL_VARIABLES, 'WEIGHTS']) fc_2_b =
tf.Variable(tf.constant(0.1, shape=[10])) y = tf.matmul(fc_1_h_drop, fc_2_w) +
fc_2_b #交叉熵 cross_entropy =
tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
#l2正则项 l2_loss = tf.add_n([tf.nn.l2_loss(w) for w in
tf.get_collection('WEIGHTS')]) #代价函数 = 交叉熵加上惩罚项 total_loss = cross_entropy +
7e-5*l2_loss #定义一个Adam优化器 train_step =
tf.train.AdamOptimizer(LEARNING_RATE_INIT).minimize(total_loss) sess =
tf.InteractiveSession() init_op = tf.global_variables_initializer()
sess.run(init_op) #Train for step in range(5000): batch_xs, batch_ys =
mnist.train.next_batch(BATCH_SIZE) _, loss, l2_loss_value, total_loss_value =
sess.run( [train_step, cross_entropy, l2_loss, total_loss], feed_dict={x:
batch_xs, y_:batch_ys, keep_prob:0.5}) correct_prediction =
tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) accuracy =
tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) # if (step+1)%200 == 0:
#每隔200步评估一下训练集和测试集 train_accuracy = accuracy.eval(feed_dict={x:batch_xs,
y_:batch_ys, keep_prob:1.0}) test_accuracy =
accuracy.eval(feed_dict={x:mnist.test.images, y_:mnist.test.labels,
keep_prob:1.0}) print("step:%d, loss:%f, train_acc:%f, test_acc:%f" % (step,
total_loss_value, train_accuracy, test_accuracy))
输出:



友情链接
ioDraw流程图
API参考文档
OK工具箱
云服务器优惠
阿里云优惠券
腾讯云优惠券
华为云优惠券
站点信息
问题反馈
邮箱:ixiaoyang8@qq.com
QQ群:637538335
关注微信