本文實例講述了Python利用邏輯回歸模型解決MNIST手寫數字識別問題。分享給大家供大家參考,具體如下:
MNIST手寫數字識別問題:輸入黑白的手寫阿拉伯數字,通過機器學習判斷輸入的是幾。可以通過TensorFLow下載MNIST手寫數據集,通過import引入MNIST數據集并進行讀取,會自動從網上下載所需文件。
%matplotlib inlineimport tensorflow as tfimport tensorflow.examples.tutorials.mnist.input_data as input_datamnist=input_data.read_data_sets('MNIST_data/',one_hot=True)import matplotlib.pyplot as plt def plot_image(image): #圖片顯示函數 plt.imshow(image.reshape(28,28),cmap='binary') plt.show() print("訓練集數量:",mnist.train.num_examples, "特征值組成:",mnist.train.images.shape, "標簽組成:",mnist.train.labels.shape) batch_images,batch_labels=mnist.train.next_batch(batch_size=10) #批量讀取數據print(batch_images.shape,batch_labels.shape) print('標簽值:',np.argmax(mnist.train.labels[1000]),end=' ') #np.argmax()得到實際值print('獨熱編碼表示:',mnist.train.labels[1000])plot_image(mnist.train.images[1000]) #顯示數據集中第1000張圖片
輸出訓練集 的數量有55000個,并打印特征值的shape為(55000,784),其中784代表每張圖片由28*28個像素點組成,由于是黑白圖片,每個像素點只有黑白單通道,即通過784個數可以描述一張圖片的特征值。可以將圖片在Jupyter中輸出,將784個特征值reshape為28×28的二維數組,傳給plt.imshow()函數,之后再通過show()輸出。
MNIST提供next_batch()方法用于批量讀取數據集,例如上面批量讀取10個對應的images與labels數據并分別返回。該方法會按順序一直往后讀取,直到結束后會自動打亂數據,重新繼續讀取。
在打開mnist數據集時,第二個參數設置one_hot,表示采用獨熱編碼方式打開。獨熱編碼是一種稀疏向量,其中一個元素為1,其他元素均為0,常用于表示有限個可能的組合情況。例如數字6的獨熱編碼為第7個分量為1,其他為0的數組。可以通過np.argmax()函數返回數組最大值的下標,即獨熱編碼表示的實際數字。通過獨熱編碼可以將離散特征的某個取值對應歐氏空間的某個點,有利于機器學習中特征之間的距離計算
數據集的劃分,一種劃分為訓練集用于模型的訓練,測試集用于結果的測試,要求集合數量足夠大,而且具有代表性。但是在多次執行后,會導致模型向測試集數據進行擬合,從而導致測試集數據失去了測試的效果。因此將數據集進一步劃分為訓練集、驗證集、測試集,將訓練后的模型用驗證集驗證,當多次迭代結束之后再拿測試集去測試。MNIST數據集中的訓練集為mnist.train,驗證集為mnist.validation,測試集為mnist.test
與線性回歸相對比,房價預測是根據多個輸入參數x與對應權重w相乘再加上b得到線性的輸出房價。而還有許多問題的輸出是非線性的、控制在[0,1]之間的,比如判斷郵件是否為垃圾郵件,手寫數字為0~9等,邏輯回歸就是用于處理此類問題。例如電子郵件分類器輸出0.8,表示該郵件為垃圾郵件的概率是0.8.
邏輯回歸通過Sigmoid函數保證輸出的值在[0,1]之間,該函數可以將全體實數映射到[0,1],從而將線性的輸出轉換為[0,1]的數。其定義與圖像如下:
新聞熱點
疑難解答