수학적 접근

[머신러닝] MNIST 사용 기초 - Dataset 가져와서 조회하기 본문

개발/AI

[머신러닝] MNIST 사용 기초 - Dataset 가져와서 조회하기

평등수렴 2020. 1. 22. 11:17
반응형

MNIST Database란 손글씨 숫자 이미지 집합으로, 머신러닝 분야의 트레이닝 및 테스트에 널리 사용되는 데이터셋이다.

 

MNIST Dataset 다운로드

https://drive.google.com/open?id=1IQXvFigDTVKlcZAP2oTgTMOXqdhFOsZ_

 

dataset.zip

 

drive.google.com

위 파일의 압축을 풀어서 나오는 폴더를 코드를 작성할 파이썬 파일이 있는 경로에 위치시킨다.

 

그리고 이 데이터셋을 파이썬 파일에서 불러오기 위해 다음과 같이 코드를 작성한다.

import sys, os
sys.path.append("./dataset") # 이때, dataset 폴더는 실행하는 py 파일의 경로와 일치해야 한다.
import numpy as np
import pickle
from dataset.mnist import load_mnist
import matplotlib.pylab as plt

load_mnist 는 불러온 데이터셋을 np.array 로 변환해주는 함수이다.

 

만약

 

cannot import name 'load_mnist' from 'mnist'

 

라는 오류가 뜬다면, dataset 폴더의 경로를 일치시키지 않아 파일이 제대로 인식되지 않은 경우이므로

폴더를 정확하게 위치시킨다.

 

 

그리고 load_mnist() 함수를 사용하여 아래와 같이 데이터를 변환한다.

(train_image_data, train_label_data), (test_image_data, test_label_data) = load_mnist(flatten = True, normalize = False)

train은 기계에게 학습시킬 데이터, test는 학습된 결과를 확인할 때 사용할 데이터를 의미한다.

 

load_mnist 함수의 파라미터로는 다음과 같은 것이 있다.

 

normalize: 0~255 gray-scale 0~1 사이의 값으로 변환할지 여부 (True/False)

 

flatten: 입력 이미지를 1차원으로 저장할지에 대한 여부 (True/False)

- flatten을 True로 할 경우

print(train_image_data.shape)
print(train_label_data.shape)
print(test_image_data.shape)
print(test_label_data.shape)

출력
(60000, 784) # 784 = 28 * 28
(60000,)
(10000, 784)
(10000,)

- flatten을 False로 할 경우

print(train_image_data.shape)
print(train_label_data.shape)
print(test_image_data.shape)
print(test_label_data.shape)

출력
(60000, 1, 28, 28)
(60000,)
(10000, 1, 28, 28)
(10000,)

 

one_hot_label : 정답을 뜻하는 원소만 1이고 나머지는 모두 0 배열로 저장한다.

예를 들어, [0, 0, 1, 0, 0, 0, 0, 0, 0, 0] 2 의미한다

 

- one_hot_label을 True로 할 경우 (flatten = True 상태)

print(train_image_data.shape)
print(train_label_data.shape)
print(test_image_data.shape)
print(test_label_data.shape)

출력
(60000, 784)
(60000, 10)
(10000, 784)
(10000, 10)

 

 

여기까지 하면 MNIST의 Dataset을 우리가 사용하기 쉽게 불러온 것이다.

 

 

위의 print 결과에서 알 수 있듯이, train 데이터는 6만 개, test 데이터는 1만 개 존재한다.

 

그리고 image_data, label_data가 구분되어 있는데,

 

image_data는 실제로 저장되어 있는 이미지의 데이터이고,

 

label_data는 그 이미지가 실제로 어떤 값을 나타내는지가 저장되어 있는 데이터이다.

 

image_datalabel_data에 접근하여 이미지 및 실제 값을 읽어올 수 있다.

 

 

이를 위하여 다음과 같이 함수를 작성한다.

import matplotlib.pylab as plt
def mnist_show(n) :
    image = train_image_data[n]
    image_reshaped = image.reshape(28, 28)
    image_reshaped.shape
    label = train_label_data[n]
    plt.figure(figsize = (4, 4))
    plt.title("sample of " + str(label))
    plt.imshow(image_reshaped, cmap="gray")
    plt.show()

n(0~59,999)을 입력값으로 주면, 그 번호에 맞는 label과 image를 가져와서, 그걸 그림으로 나타내는 함수이다.

 

다음과 같이 실행하면

get_image(2747)

 

아래와 같이 이미지를 불러온다.

 

 

 

 

+++++

 

내가 마시고 있는 생수는 안전한 생수일까?

15년도부터 생수 수질기준 부적합 판정을 받은 이력을 확인할 수 있는 앱, <바른생수>를 출시하였습니다.

 

관심 한번씩만 부탁드립니다 ^^

 

바로가기!

 

 

반응형
Comments