- Size: 4KB  File type: .rar  Coins: 2  Downloads: 0  Published: 2021-06-03
- Language: Python
- Tags: CNN, Tensorflow, CIFAR10, Deep Learning, Image Classification
Resource Description
The original code in the archive uses a CNN to classify the CIFAR10 dataset and reaches an accuracy of 0.67. The optimized version adds weight regularization, data augmentation, and an extra fully connected layer, raising the accuracy to 0.85.
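For reference, the data augmentation mentioned above corresponds to the per-image distortions that cifar10_input.distorted_inputs applies in the standard TensorFlow 1.x CIFAR-10 tutorial code. The snippet below is a minimal illustration of that kind of pipeline, not code taken from this archive; the helper name augment_image is invented for the example.

import tensorflow as tf

def augment_image(image):
    # Training-time augmentation for a single 32x32x3 CIFAR-10 image (TF1 style)
    image = tf.random_crop(image, [24, 24, 3])                      # random 24x24 crop
    image = tf.image.random_flip_left_right(image)                  # random horizontal flip
    image = tf.image.random_brightness(image, max_delta=63)         # brightness jitter
    image = tf.image.random_contrast(image, lower=0.2, upper=1.8)   # contrast jitter
    return tf.image.per_image_standardization(image)                # zero mean, unit variance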

Code Snippet and File Information
# -*- coding: utf-8 -*-
"""
Created on Tue Jan  8 14:04:20 2019
@author: shihui
"""
import tensorflow as tf
import numpy as np
import cifar10, cifar10_input
import time
'''
Weight initialization helper: creates a truncated-normal variable and, when w1
is given, adds its L2 penalty to the "losses" collection.
'''
def variable_with_weight_loss(shape, std, w1):
    var = tf.Variable(tf.truncated_normal(shape, stddev=std), dtype=tf.float32)
    if w1 is not None:
        weight_loss = tf.multiply(tf.nn.l2_loss(var), w1, name="weight_loss")
        tf.add_to_collection("losses", weight_loss)
    return var
'''
Loss function: softmax cross-entropy plus the collected weight penalties.
'''
def loss_func(logits, labels):
    labels = tf.cast(labels, tf.int32)
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                               labels=labels, name="cross_entropy_per_example")
    cross_entropy_mean = tf.reduce_mean(tf.reduce_sum(cross_entropy))
    tf.add_to_collection("losses", cross_entropy_mean)
    return tf.add_n(tf.get_collection("losses"), name="total_loss")
if __name__ == "__main__":
    # Maximum number of training steps
    max_steps = 10000
    # Batch size for each training step
    batch_size = 128
    # Download and extract the dataset
    cifar10.maybe_download_and_extract()
    # Directory where the dataset is stored
    cifar10_dir = "C:/Users/29811/Desktop/cifar10/dataset/cifar-10-batches-bin"
    # Training data with data augmentation applied
    images_train, labels_train = cifar10_input.distorted_inputs(cifar10_dir, batch_size)
    # Cropped test data
    images_test, labels_test = cifar10_input.inputs(eval_data=True, data_dir=cifar10_dir,
                                                    batch_size=batch_size)
    # Placeholders for the model's input images and labels
    image_holder = tf.placeholder(dtype=tf.float32, shape=[batch_size, 24, 24, 3])
    label_holder = tf.placeholder(dtype=tf.int32, shape=[batch_size])
    # First convolutional layer
    weight1 = variable_with_weight_loss(shape=[5, 5, 3, 64], std=5e-2, w1=0)
    kernel1 = tf.nn.conv2d(image_holder, weight1, [1, 1, 1, 1], padding="SAME")
    bais1 = tf.Variable(tf.constant(0.0, dtype=tf.float32, shape=[64]))
    conv1 = tf.nn.relu(tf.nn.bias_add(kernel1, bais1))
    pool1 = tf.nn.max_pool(conv1, [1, 3, 3, 1], [1, 2, 2, 1], padding="SAME")
    norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9, beta=0.75)
    # Second convolutional layer
    weight2 = variable_with_weight_loss(shape=[5, 5, 64, 64], std=5e-2, w1=0)
    kernel2 = tf.nn.conv2d(norm1, weight2, [1, 1, 1, 1], padding="SAME")
    bais2 = tf.Variable(tf.constant(0.1, dtype=tf.float32, shape=[64]))
    conv2 = tf.nn.relu(tf.nn.bias_add(kernel2, bais2))
    norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.01 / 9, beta=0.75)
    pool2 = tf.nn.max_pool(norm2, [1, 3, 3, 1], [1, 2, 2, 1], padding="SAME")
    # First fully connected layer
    reshape = tf.reshape(pool2, [batch_size, -1])
    dim = reshape.get_shape()[1].value
    weight3 = variable_with_weight_loss([dim, 384], std=0.04, w1=0.004)
    bais3 = tf.Variable(tf.constant(0.1, shape=[384], dtype=tf.float32))
    local3 = tf.nn.relu(tf.matmul(reshape, weight3) + bais3)
    # Second fully connected layer
    weight4 = variable_with_weight_loss([384, 192], std=0.04, w1=0.004)
    bais4 = tf.Variable(tf.constant(0.1, shape=[192], dtype=tf.float32))
    local4 = tf.nn.relu(tf.matm
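The site preview cuts off in the middle of the second fully connected layer. As rough orientation only, here is a minimal sketch of how such a TensorFlow 1.x script is usually completed (final linear layer, total loss, optimizer, and training loop), assuming it follows the same tutorial pattern as the code above; the names weight5, bais5, logits, train_op, and top_k_op are placeholders for this sketch and are not taken from the archive.

    # Presumed completion of the truncated statement, mirroring local3
    local4 = tf.nn.relu(tf.matmul(local3, weight4) + bais4)
    # Final linear layer producing 10-class logits
    weight5 = variable_with_weight_loss([192, 10], std=1 / 192.0, w1=0)
    bais5 = tf.Variable(tf.constant(0.0, shape=[10], dtype=tf.float32))
    logits = tf.add(tf.matmul(local4, weight5), bais5)
    # Total loss (cross-entropy plus collected L2 penalties), optimizer, accuracy op
    loss = loss_func(logits, label_holder)
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
    top_k_op = tf.nn.in_top_k(logits, label_holder, 1)
    # Training loop, feeding batches produced by the CIFAR-10 input queues
    sess = tf.InteractiveSession()
    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()
    for step in range(max_steps):
        image_batch, label_batch = sess.run([images_train, labels_train])
        _, loss_value = sess.run([train_op, loss],
                                 feed_dict={image_holder: image_batch,
                                            label_holder: label_batch})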
 Attribute    Size   Date        Time   Name
 ----------  ------  ----------  -----  ----------
 File          5367  2020-05-14  18:03  優(yōu)化后.py
 File          5751  2020-05-14  19:18  原始.py
 ----------  ------  ----------  -----  ----------
 Total        11118                     2 files