python之tensorflow手把手實(shí)例講解貓狗識(shí)別實(shí)現(xiàn)
作為tensorflow初學(xué)的大三學(xué)生,本次課程作業(yè)的使用貓狗數(shù)據(jù)集做一個(gè)二分類模型。
一,貓狗數(shù)據(jù)集數(shù)目構(gòu)成
| train | cats:1000 ,dogs:1000 |
|---|---|
| test | cats: 500,dogs:500 |
| validation | cats:500,dogs:500 |
二,數(shù)據(jù)導(dǎo)入
train_dir = 'Data/train' test_dir = 'Data/test' validation_dir = 'Data/validation' train_datagen = ImageDataGenerator(rescale=1/255, rotation_range=10, width_shift_range=0.2, #圖片水平偏移的角度 height_shift_range=0.2, #圖片數(shù)值偏移的角度 shear_range=0.2, #剪切強(qiáng)度 zoom_range=0.2,#隨機(jī)縮放的幅度 horizontal_flip=True,#是否進(jìn)行隨機(jī)水平翻轉(zhuǎn) # fill_mode='nearest' ) train_generator = train_datagen.flow_from_directory(train_dir, (224,224),batch_size=1,class_mode='binary',shuffle=False) test_datagen = ImageDataGenerator(rescale=1/255) test_generator = test_datagen.flow_from_directory(test_dir, (224,224),batch_size=1,class_mode='binary',shuffle=True) validation_datagen = ImageDataGenerator(rescale=1/255) validation_generator = validation_datagen.flow_from_directory( validation_dir,(224,224),batch_size=1,class_mode='binary') print(train_datagen) print(test_datagen) print(train_datagen)
三,數(shù)據(jù)集構(gòu)建
我這里是將ImageDataGenerator類里的數(shù)據(jù)提取出來,將數(shù)據(jù)與標(biāo)簽分別存放在兩個(gè)列表,后面在轉(zhuǎn)為np.array,也可以使用model.fit_generator,我將數(shù)據(jù)放在內(nèi)存為了后續(xù)調(diào)參數(shù)時(shí)模型訓(xùn)練能更快讀取到數(shù)據(jù),不用每次訓(xùn)練一整輪都去讀一次數(shù)據(jù)(應(yīng)該是這樣的…我是這樣理解…)
注意我這里的數(shù)據(jù)集構(gòu)建后,三種數(shù)據(jù)都是存放在內(nèi)存中的,我電腦內(nèi)存是16g的可以存放下。
train_data=[] train_labels=[] a=0 for data_train, labels_train in train_generator: train_data.append(data_train) train_labels.append(labels_train) a=a+1 if a>1999: break x_train=np.array(train_data) y_train=np.array(train_labels) x_train=x_train.reshape(2000,224,224,3)
test_data=[] test_labels=[] a=0 for data_test, labels_test in test_generator: test_data.append(data_test) test_labels.append(labels_test) a=a+1 if a>999: break x_test=np.array(test_data) y_test=np.array(test_labels) x_test=x_test.reshape(1000,224,224,3)
validation_data=[] validation_labels=[] a=0 for data_validation, labels_validation in validation_generator: validation_data.append(data_validation) validation_labels.append(labels_validation) a=a+1 if a>999: break x_validation=np.array(validation_data) y_validation=np.array(validation_labels) x_validation=x_validation.reshape(1000,224,224,3)
四,模型搭建
model1 = tf.keras.models.Sequential([ # 第一層卷積,卷積核為,共16個(gè),輸入為150*150*1 tf.keras.layers.Conv2D(16,(3,3),activation='relu',padding='same',input_shape=(224,224,3)), tf.keras.layers.MaxPooling2D((2,2)), # 第二層卷積,卷積核為3*3,共32個(gè), tf.keras.layers.Conv2D(32,(3,3),activation='relu',padding='same'), tf.keras.layers.MaxPooling2D((2,2)), # 第三層卷積,卷積核為3*3,共64個(gè), tf.keras.layers.Conv2D(64,(3,3),activation='relu',padding='same'), tf.keras.layers.MaxPooling2D((2,2)), # 數(shù)據(jù)鋪平 tf.keras.layers.Flatten(), tf.keras.layers.Dense(64,activation='relu'), tf.keras.layers.Dropout(0.5), tf.keras.layers.Dense(1,activation='sigmoid') ]) print(model1.summary())
模型summary:

五,模型訓(xùn)練
model1.compile(optimize=tf.keras.optimizers.SGD(0.00001),
loss=tf.keras.losses.binary_crossentropy,
metrics=['acc'])
history1=model1.fit(x_train,y_train,
# validation_split=(0~1)選擇一定的比例用于驗(yàn)證集,可被validation_data覆蓋
validation_data=(x_validation,y_validation),
batch_size=10,
shuffle=True,
epochs=10)
model1.save('cats_and_dogs_plain1.h5')
print(history1)

plt.plot(history1.epoch,history1.history.get('acc'),label='acc')
plt.plot(history1.epoch,history1.history.get('val_acc'),label='val_acc')
plt.title('正確率')
plt.legend()

可以看到我們的模型泛化能力還是有點(diǎn)差,測試集的acc能達(dá)到0.85以上,驗(yàn)證集卻在0.65~0.70之前跳動(dòng)。
六,模型測試
model1.evaluate(x_validation,y_validation)

最后我們的模型在測試集上的正確率為0.67,可以說還不夠好,有點(diǎn)過擬合,可能是訓(xùn)練數(shù)據(jù)不夠多,后續(xù)可以數(shù)據(jù)增廣或者從驗(yàn)證集、測試集中調(diào)取一部分?jǐn)?shù)據(jù)用于訓(xùn)練模型,可能效果好一些。
到此這篇關(guān)于python之tensorflow手把手實(shí)例講解貓狗識(shí)別實(shí)現(xiàn)的文章就介紹到這了,更多相關(guān)python tensorflow 貓狗識(shí)別內(nèi)容請(qǐng)搜索本站以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持本站!
版權(quán)聲明:本站文章來源標(biāo)注為YINGSOO的內(nèi)容版權(quán)均為本站所有,歡迎引用、轉(zhuǎn)載,請(qǐng)保持原文完整并注明來源及原文鏈接。禁止復(fù)制或仿造本網(wǎng)站,禁止在非maisonbaluchon.cn所屬的服務(wù)器上建立鏡像,否則將依法追究法律責(zé)任。本站部分內(nèi)容來源于網(wǎng)友推薦、互聯(lián)網(wǎng)收集整理而來,僅供學(xué)習(xí)參考,不代表本站立場,如有內(nèi)容涉嫌侵權(quán),請(qǐng)聯(lián)系alex-e#qq.com處理。
關(guān)注官方微信