๐ฉ ๋ฅ๋ฌ๋๊ณผ ์ธ๊ณต ์ ๊ฒฝ๋ง ์๊ณ ๋ฆฌ์ฆ์ ์ดํดํ๊ณ ํ ์ํ๋ก๋ฅผ ์ฌ์ฉํด ๊ฐ๋จํ ์ธ๊ณต ์ ๊ฒฝ๋ง ๋ชจ๋ธ์ ๋ง๋ค์ด ๋ด ์๋ค.
ํจ์ MNIST ๋ฐ์ดํฐ์ ๐
๋ฅ๋ฌ๋์์๋ MNIST ๋ฐ์ดํฐ์ ์ด ์ ๋ช ํฉ๋๋ค.
๐ MNIST : ์์ผ๋ก ์ด ์ซ์(0~9)๋ค๋ก ์ด๋ฃจ์ด์ง ๋ํ ๋ฐ์ดํฐ๋ฒ ์ด์ค
๐ ํ ์ํ๋ก
ํนํ, ํจ์ MNIST ๋ฐ์ดํฐ๋ ์๋ ์ ๋ช ํ๊ธฐ ๋๋ฌธ์ ๋ค์ํ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ์๋๋ฐ,
์ด ์ค ํ ์ํ๋ก(TensorFlow)๋ผ๋ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉํด ์ด ๋ฐ์ดํฐ๋ฅผ ๋ถ๋ฌ์ค๊ฒ ์ต๋๋ค!
๐ ์ผ๋ผ์ค
์ฐ๋ฆฌ๊ฐ ์ด์ฉํ ๋ฐ์ดํฐ์ ์ 10์ข ๋ฅ์ ํจ์ ์์ดํ ์ผ๋ก ๊ตฌ์ฑ๋์ด ์์ต๋๋ค.
๊ทธ๋ผ, ํ ์ํ๋ก์ ์ผ๋ผ์ค ํจํค์ง๋ฅผ ์ํฌํธํ ํ, ํจ์ MNIST ๋ฐ์ดํฐ๋ฅผ ๋ค์ด๋ก๋ํฉ๋๋ค.
๐ ์ผ๋ผ์ค keras.datasets.fashion_mnist ๋ชจ๋์ load_data() : ํ๋ จ ๋ฐ์ดํฐ์ ํ ์คํธ ๋ฐ์ดํฐ๋ฅผ ๋๋์ด ๋ฐํ
from tensorflow import keras
(train_input, train_target), (test_input, test_target) =\
keras.datasets.fashion_mnist.load_data()
๋ค์ด๋ก๋ ๋ฐ์ ๋ฐ์ดํฐ๋ ์ ํด๋ ์์ด์ฝ์ ํด๋ฆญํด์ ํ์ธํ ์ ์์ต๋๋ค.
๋ฐ์ดํฐ์ ํฌ๊ธฐ๋ ํ์ธํ๊ฒ ์ต๋๋ค~_~
ํ๋ จ ๋ฐ์ดํฐ ์ค ์ ๋ ฅ์ 60,000๊ฐ์ ์ด๋ฏธ์ง๋ก, ๊ฐ ์ด๋ฏธ์ง๋ 28x28 ํฌ๊ธฐ์ ๋๋ค.
ํ๋ จ ๋ฐ์ดํฐ ์ค ํ๊น๋ 60,000๊ฐ์ ์์๊ฐ ์๋ 1์ฐจ์ ๋ฐฐ์ด์ธ ๊ฒ์ ํ์ธํ ์ ์๋ค์~!
โถ๏ธ ํ๋ จ ๋ฐ์ดํฐ๋ 60,000๊ฐ
ํ ์คํธ ๋ฐ์ดํฐ ์ค ์ ๋ ฅ์ 10,000๊ฐ์ ์ด๋ฏธ์ง๋ก, ๊ฐ ์ด๋ฏธ์ง๋ 28x28 ํฌ๊ธฐ
ํ ์คํธ ๋ฐ์ดํฐ ์ค ํ๊น๋ 10,000๊ฐ์ ์์๊ฐ ์๋ 1์ฐจ์ ๋ฐฐ์ด
โถ๏ธ ํ ์คํธ ๋ฐ์ดํฐ๋ 10,000๊ฐ
๐ ๋ฐ์ดํฐ ๊ทธ๋ฆผ์ผ๋ก ์ถ๋ ฅ
๋งทํ๋กฏ๋ฆฝ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ก ํ๋ จ ๋ฐ์ดํฐ์์ ๋ช ๊ฐ์ ์ํ์ ๊ทธ๋ฆผ์ผ๋ก ์ถ๋ ฅํด๋ณด๊ฒ ์ต๋๋ค!
๐ ๋งทํ๋กฏ๋ฆฝ์ subplots() : ํ๋ฒ์ ์ฌ๋ฌ ๊ฐ์ ๊ทธ๋ํ ๊ทธ๋ฆฌ๊ธฐ (figure ์ axes ๋ฅผ ๋ฐํ)
๐ figure : ๊ทธ๋ํ๊ฐ ๊ทธ๋ ค์ง๋ ํ๋ ์ (subplot ์ ๊ทธ๋ํ์ด๊ณ , ํ figure ์์ ์ฌ๋ฌ ๊ฐ์ ๊ทธ๋ํ๋ฅผ ๊ทธ๋ฆด ์ ์์)
๐ axes : ์ค์ ๋ฐ์ดํฐ๊ฐ ๊ทธ๋ํ๋ก ๊ทธ๋ ค์ง๋ ์บ๋ฒ์ค
ex) fig, axs = plt.subplots(2,2) : 4๊ฐ(2x2)์ ax๋ค์ ๊ฐ์ง๋ ํ๋์ figure ์์ฑ
import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 10, figsize=(10, 10))
for i in range(10):
axs[i].imshow(train_input[i], cmap="gray_r")
axs[i].axis("off")
plt.show()
๐ ํ๊น ๊ตฌ์ฑ
ํจ์ MNIST์ ํ๊น์ ๊ฐ๊ฐ 0~9๊น์ง์ ์ซ์ ๋ ์ด๋ธ๋ก ๋์ด ์๊ณ , ๊ฐ ์ซ์๋ง๋ค ์๋์ ๊ฐ์ด ๊ตฌ์ฑ๋ฉ๋๋ค. (10๊ฐ ๋ถ๋ฅ)
Label |0 |1 |2 |3 |4 |5 |6 |7 |8 |9
Description |ํฐ์
์ธ |๋ฐ์ง |์ค์จํฐ |๋๋ ์ค |์ฝํธ |์๋ฌ |์
์ธ |์ค๋์ปค์ฆ |๊ฐ๋ฐฉ |์ตํด ๋ถ์ธ
๊ทธ๋ผ ์ ๊ทธ๋ฆผ ์ํ์ ํ๊น๊ฐ์ ์ถ๋ ฅํด๋ณด๊ฒ ์ต๋๋ค.
์ซ์ 5๋ก ๊ฐ์ ๋ง์ง๋ง 2๊ฐ ์ํ์ด ๋ณด์ด๋ค์!
์ฆ, ๋ง์ง๋ง 2๊ฐ ์ํ์ ์ ์ฌ์ง์์ ๋ดค์ ๋ ๊ฐ์ ์ข ๋ฅ์ ์ ๋ฐ(์๋ฌ)์ธ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
๋ํ, ๊ฐ ๋ ์ด๋ธ๋ง๋ค์ ์ํ ๊ฐ์๋ฅผ ์ถ๋ ฅํด๋ณด๋ 6,000๊ฐ์ ์ํ์ด ๋ค์ด์๋ ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค.
๐ ๋ก์ง์คํฑ ํ๊ท๋ก ํจ์ ์์ดํ ๋ถ๋ฅํ๊ธฐ
๐ก ํด๋น ํ๋ จ ์ํ์ 60,000๊ฐ๋ ๋์์ฃ ..! ์ด๋ ๊ฒ ๋ฐ์ดํฐ ๊ฐ์๊ฐ ๋ง์ผ๋ฉด, ์ ์ฒด ๋ฐ์ดํฐ๋ฅผ ํ๊บผ๋ฒ์ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ํ๋ จํ๋ ๊ฒ๋ณด๋จ ์ํ์ ํ๋์ฉ ๊บผ๋ด์ ํ๋ จํ๋ ๋ฐฉ๋ฒ์ด ํจ์จ์ ์ ๋๋ค.
โจ ์ํ์ ํ๋์ฉ ๊บผ๋ด์ ํ๋ จํ๋ค..?! ํ๋ฅ ์ ๊ฒฝ์ฌ ํ๊ฐ๋ฒ์ด ๋ฑ์ด๋ค์! ๊ทธ๋ ๊ธฐ์ 4์ฅ์์ ๋ฐฐ์ด SGDClassifier ๋ฅผ ์ฌ์ฉํ ์์ ์ด๊ณ , 4์ฅ๊ณผ ๊ด๋ จ๋ ๋ด์ฉ์ ์๋๋ฅผ ์ฐธ๊ณ ํด์ฃผ์ธ์.๐
[์ธ๊ณต์ง๋ฅ/ํผ๊ณต๋จธ์ ] 04-2. ํ๋ฅ ์ ๊ฒฝ์ฌ ํ๊ฐ๋ฒ (1)
์ ์ง์ ์ธ ํ์ต์ ์ํ ๋ฌธ์ ์ธ์ ๋ชจ๋ธ์ด ๋งค๋ฒ ํ๋ จ ๋ฐ์ดํฐ๋ฅผ ๋ค์ ์๋กญ๊ฒ ํ๋ จํ๋ ๋ฐ์๋ ์๋์ ๊ฐ์ ๋ฌธ์ ๋ค์ด ์์ต๋๋ค. ํ๋ จ ๋ฐ์ดํฐ๊ฐ ํ ๋ฒ์ ์ค๋น๋๋ ๊ฒ์ด ์๋๋ผ ์กฐ๊ธ์ฉ ์ ๋ฌ๋๋ค๋ฉด, ๊ณ
avoc-o-d.tistory.com
๐ ์ํ ์ค๋นํ๊ธฐ
๐04 ์ฅ ๋ฆฌ๋ง์ธ๋ !
- SGDClassifier ํด๋์ค์ loss ๋งค๊ฐ๋ณ์๋ฅผ 'log'๋ก ์ง์ ํ์์ต๋๋ค.
- ๐ค ์ด์ ๋? ๋ก์ง์คํฑ ์์ค ํจ์๋ฅผ ์ต์ํํ๋ ํ๋ฅ ์ ๊ฒฝ์ฌ ํ๊ฐ๋ฒ ๋ชจ๋ธ์ ๋ง๋ค๊ธฐ ์ํด์ ์ ๋๋ค.
- ํ์คํ ์ ์ฒ๋ฆฌ๋ ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํ์ต๋๋ค.
- ๐ค ์ด์ ๋? ํ๋ฅ ์ ๊ฒฝ์ฌ ํ๊ฐ๋ฒ์ ์ฌ๋ฌ ํน์ฑ ์ค ๊ธฐ์ธ๊ธฐ๊ฐ ๊ฐ์ฅ ๊ฐํ๋ฅธ ๋ฐฉํฅ์ ๋ฐ๋ผ ์ด๋ํ๋๊น, ๋ง์ฝ ํน์ฑ๋ง๋ค ๊ฐ์ ๋ฒ์๊ฐ ๋ง์ด ๋ค๋ฅด๋ฉด ์ฌ๋ฐ๋ฅด๊ฒ ์์ค ํจ์์ ๊ฒฝ์ฌ๋ฅผ ๋ด๋ ค์ฌ ์ ์๊ธฐ ๋๋ฌธ์ ๋๋ค.
๐ก ๊ทธ๋ ๋ค๋ฉด ํจ์ MNIST ๋ฐ์ดํฐ๋ฅผ ์ด์ฉํ์ฌ ๋ชจ๋ธ์ ๋ง๋ค๊ธฐ ์ํด, ํจ์ MNIST ๋ฐ์ดํฐ๋ฅผ ํ์คํ ์ ์ฒ๋ฆฌํ๋ ๋ฐฉ๋ฒ์?
ํจ์ MNIST์ ๊ฒฝ์ฐ ๊ฐ ํฝ์ ์ 0~255 ์ฌ์ด์ ์ ์ซ๊ฐ์ ๊ฐ์ง๋๋ค. ๊ทธ๋ ๋ค๋ฉด, 255๋ก ๋๋์ด 0~1 ์ฌ์ด์ ๊ฐ์ผ๋ก ์ ๊ทํํฉ๋๋ค.
(์ฐธ๊ณ ) ํ์คํ๋ ์๋์ง๋ง, ์์๊ฐ์ผ๋ก ์ด๋ฃจ์ด์ง ์ด๋ฏธ์ง๋ฅผ ์ ์ฒ๋ฆฌํ ๋ ๋๋ฆฌ ์ฌ์ฉ๋๋ ๋ฐฉ๋ฒ์ ๋๋ค.
๐ก SGDClassifier ๋ 2์ฐจ์ ์ ๋ ฅ์ ๋ค๋ฃจ์ง ๋ชปํ๊ธฐ ๋๋ฌธ์, ํ์ฌ 2์ฐจ์ ๋ฐฐ์ด์ธ ๊ฐ ์ํ์ 1์ฐจ์ ๋ฐฐ์ด๋ก ํผ์ณ์ผ ํฉ๋๋ค.
28x28 ์ด๋ฏธ์ง๋ฅผ ํผ์ณ์ ๊ธธ์ด๊ฐ 784์ธ 1์ฐจ์ ๋ฐฐ์ด๋ก ๋ง๋ญ๋๋ค.
train_scaled = train_input / 255.0
train_scaled = train_scaled.reshape(-1, 28*28) # ์ฒซ ๋ฒ์งธ ์ฐจ์(์ํ ๊ฐ์)์ ๋ณํ์ง ์๊ณ ์๋ณธ ๋ฐ์ดํฐ์ ๋ ๋ฒ์จฐ, ์ธ ๋ฒ์จฐ ์ฐจ์์ด 1์ฐจ์์ผ๋ก ํฉ์ณ์ง
784๊ฐ์ ํฝ์ ๋ก ์ด๋ฃจ์ด์ง 60,000๊ฐ์ ์ํ์ด ์ค๋น๋์์ต๋๋ค! ์ฑ๋ฅ์ ํ์ธํด๋ณด๊ฒ ์ต๋๋ค.
๐ ์ฑ๋ฅ ํ์ธ
SGDClassifier ํด๋์ค์ cross_validate() ๋ฅผ ์ด์ฉํด ์ค๋น๋ ์ํ์์ ๊ต์ฐจ ๊ฒ์ฆ์ผ๋ก ์ฑ๋ฅ์ ํ์ธํฉ๋๋ค.
- SGDClassifier ํด๋์ค์ loss ๋งค๊ฐ๋ณ์๋ฅผ 'log'๋ก ์ง์ ํฉ๋๋ค.
๐ค ํ์ฌ ๋ฌธ์ ๋ 10๊ฐ ํด๋์ค(0~9)๋ฅผ ๋ถ๋ฅํ๋ ๋ค์ค ๋ถ๋ฅ ๋ฌธ์ ์ธ๋ฐ, ์ ๋ก์ง์คํฑ ์์ค ํจ์(์ด์ง ํฌ๋ก์ค์ํธ๋กํผ ์์ค ํจ์ : ์ด์ง ๋ถ๋ฅ์์ ์ฌ์ฉํ๋ ์์ค ํจ์) ์ฐ๋ ๊ฑด๊ฐ์?
๐ก ์ฌ์ดํท๋ฐ SGDClassifier์ ๋ก์ง์คํฑ ์์ค ํจ์๋ ์์ง๋ง, ํฌ๋ก์ค์ํธ๋กํผ ์์ค ํจ์(๋ค์ค ๋ถ๋ฅ์์ ์ฌ์ฉํ๋ ์์ค ํจ์)๋ฅผ ์ง์ ํ ๊ณณ์ด ์์ต๋๋ค.
๐ค ๊ทธ๋ผ ๋ก์ง์คํฑ ์์ค ํจ์๋ฅผ ์ด๋ป๊ฒ ์ฒ๋ฆฌํ๋์?
๐ก(์๋ ๊ทธ๋ฆผ)
๐(์ฐธ๊ณ ์ฉ) 04์ฅ ๋ก์ง์คํฑ ํ๊ท๋ก ๋ค์ค ๋ถ๋ฅ ์ํํ๊ธฐ ๋ฆฌ๋ง์ธ๋
๐ก ์ฆ, ์ด์ง ๋ถ๋ฅ๋ ๋ค์ค ๋ถ๋ฅ๋ ์๊ด ์์ด SGDClassifier ํด๋์ค์ loss ๋งค๊ฐ๋ณ์๋ฅผ 'log'๋ก ์ง์ ํ๋ฉด ๊ฐ ๋ถ๋ฅ์ ๋ง๊ฒ ํ๋ จ์ด ๋ฉ๋๋ค.
from sklearn.model_selection import cross_validate
from sklearn.linear_model import SGDClassifier
sc = SGDClassifier(loss="log", max_iter=5, random_state=42) # ๋ฐ๋ณต ํ์ 5๋ฒ (ํ์ ๋๋ ค๋ ์ฑ๋ฅ ํฅ์์ ๋ฏธ๋ฏธํจ)
scores = cross_validate(sc, train_scaled, train_target, n_jobs=-1)
์ ํ๋ 82%๋ผ๋! ์ฑ๋ฅ์ด ์ ์ข๋ค์. ๐
์ฑ๋ฅ์ ํฅ์์ํค๋ ๊ณผ์ ์ ์ดํดํ๊ธฐ ์ํด! ์ฐ์ ๋ก์ง์คํฑ ํ๊ท ๊ณต์์ ๋ํด ๋ฆฌ๋ง์ธ๋ ํ ํ์๊ฐ ์์ต๋๋ค.
๐04 ์ฅ ๋ฆฌ๋ง์ธ๋ !
- ๋ก์ง์คํฑ ํ๊ท ๊ณต์
z = a x (๋ฌด๊ฒ) + b x (๊ธธ์ด) + c x (๋๊ฐ์ ๊ธธ์ด) + d x (๋์ด) + e x (๋๋น) + f
๐ก ๋ก์ง์คํฑ ํ๊ท ๊ณต์์ ํจ์ MNIST ๋ฐ์ดํฐ์ ๋ง๊ฒ ๋ณํํ๋ฉด?
- ์ด 784๊ฐ์ ํฝ์ , ์ฆ 784๊ฐ์ ํน์ฑ์ด ์์
- ๊ฐ ๋ ์ด๋ธ(0~9)๋ง๋ค ๋ฐฉ์ ์์ ์ธ์
- โจ ์ฃผ์! ๊ฐ ๋ ์ด๋ธ๋ง๋ค ๊ฐ์ค์น์ ์ ํธ์ ๋ค๋ฅธ ๊ฐ์ ์ฌ์ฉํด์ผ ํจ.
๐ค ์? ๋ ์ด๋ธ ๋ชจ๋ ๊ฐ์ ๊ฐ์ค์น๋ฅผ ์ฌ์ฉํ๋ค๋ฉด ๊ตฌ๋ถํ ์ ์์ ๋ฆฌ๊ฐ ์๊ฒ ์ฃ !- [๋ ์ด๋ธ 0] z_ํฐ์ ์ธ = w1 x (ํฝ์ 1) + w2 x (ํฝ์ 2) + ... + w784 x (ํฝ์ 784) + b
- [๋ ์ด๋ธ 1] z_๋ฐ์ง = w1' x (ํฝ์ 1) + w2' x (ํฝ์ 2) + ... + w784' x (ํฝ์ 784) + b'
- ...
โถ๏ธ 4์ฅ์์ ๋ฐฐ์ ๋ ๊ฒ์ฒ๋ผ,,! 10๊ฐ์ ํด๋์ค(ํฐ์ ์ธ , ๋ฐ์ง, ...)์ ๋ํ ์ ํ ๋ฐฉ์ ์์ ๋ชจ๋ ๊ณ์ฐ(z_ํฐ์ ์ธ , z_๋ฐ์ง, ...)ํ ๋ค์ ์ํํธ๋งฅ์ค ํจ์๋ฅผ ํต๊ณผํ์ฌ ๊ฐ ํด๋์ค์ ๋ํ ํ๋ฅ ์ ์ป์ ์ ์์ต๋๋ค~! ๐ฆ๐ฆ
์ง๊ธ๊น์ง ํ๋ฅ ์ ๊ฒฝ์ฌ ํ๊ฐ๋ฒ์ ์ฌ์ฉํ๋ ๋ก์ง์คํฑ ํ๊ท์ ๋ํด ๋ณต์ต์ ํ์ต๋๋ค.
๊ทธ๋ฐ๋ฐ! ๊ฐ์ฅ ๊ธฐ๋ณธ์ ์ธ ์ธ๊ณต ์ ๊ฒฝ๋ง์ ๋ก์ง์คํฑ ํ๊ท์ ๊ฐ๋ค๋ ์ฌ์ค..!!! ๐
๋ค์ ๊ธ์์ ์ธ๊ณต ์ ๊ฒฝ๋ง์ ๋ํด ์ด์ด ์์ฑํ๊ฒ ์ต๋๋ค!
'๐ป My Work > ๐ง AI' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[์ธ๊ณต์ง๋ฅ/ํผ๊ณต๋จธ์ ] 07-1. ์ธ๊ณต ์ ๊ฒฝ๋ง (3) (0) | 2023.01.02 |
---|---|
[์ธ๊ณต์ง๋ฅ/ํผ๊ณต๋จธ์ ] 07-1. ์ธ๊ณต ์ ๊ฒฝ๋ง (2) (0) | 2023.01.01 |
[์ธ๊ณต์ง๋ฅ/ํผ๊ณต๋จธ์ ] 04-2. ํ๋ฅ ์ ๊ฒฝ์ฌ ํ๊ฐ๋ฒ (2) (0) | 2022.12.09 |
[์ธ๊ณต์ง๋ฅ/ํผ๊ณต๋จธ์ ] 04-2. ํ๋ฅ ์ ๊ฒฝ์ฌ ํ๊ฐ๋ฒ (1) (0) | 2022.12.09 |
[์ธ๊ณต์ง๋ฅ/ํผ๊ณต๋จธ์ ] 04-1. ๋ก์ง์คํฑ ํ๊ท (2) | 2022.12.08 |