2023 ๋ฐ์ด์คํฌ์ค ๋ฐ์ดํฐ ๊ฒฝ์ง๋ํ - ์น์ํ ๋ถ์ผ (๋ฐ์ด์คํฌ์ค ํ์ ์ตํฉ๋ํ 7๊ฐ ๋ํ ์ฌํ์ ๋ถ๋ฌธ)
$USER/RESULT
โโโโโTRAIN/
โ โโโ TL_AdamW.py
โ โโโ FT_RMSprop.py
โโโโโTEST/
โ โโโ predict_AdamWRMSprop.py
โโโ README.md
โโโ FT_RMSdragon.pth
โโโ BEST_TL_AdamWdragon.pth
โโโ fin_tuned_model_AdamWRMSdragon2.csv
โโโ TRAIN : ํ์ต์ ํ์ํ TL_AdamW.py, FT_RMSprop.py ํ์ผ์ ๋ด์ ํด๋
โ โโโ TL_AdamW.py :
โ โ AdamW optimizer๋ฅผ ์ด์ฉํด Transfer learning์ ์งํ(๊ณต์๋ฌธ์์ฐธ๊ณ ๋ฐ ์คํ์ ์ผ๋ก ํ๋ผ๋ฏธํฐ ๊ฒฐ์ )**
โ โ 1. classification layer์ ์ด์ง๋ถ๋ฅ ๊ฒฐ๊ณผ๋ฅผ ๋์ถํ๋๋ก layer๋ฅผ ์ถ๊ฐ
โ โ 2. 150 epoch ์ค f1-score๊ฐ ๊ฐ์ฅ ๋์ ๋์ pth ์ ์ฅ
โ โโโ FT_RMSprop.py :
โ TL_AdamW.py์์ ์ป์ pth์ RMSprop optimizer๋ฅผ ์ด์ฉํด ๋ฏธ์ธ์กฐ์ **
โ 1. densenet121์ ๋ชจ๋ ํ๋ผ๋ฏธํฐ ๋๊ฒฐ(freeze) ํ classification layer๋ง ๋๊ฒฐํด์
โ 2. ๊ณต์๋ฌธ์๋ฅผ ์ฐธ๊ณ ํ์ฌ ์คํ์ ์ผ๋ก ์ป์ ํ๋ผ๋ฏธํฐ์ ํ์ต๋ฅ ์ ํตํด scheduler ์ธํ
โ 3. ์ต์ข
์ ์ผ๋ก ์ป์ pth ์์ฑ
โ
โโโ TEST : ์ถ๋ก ์ ํ์ํ predict_AdamWRMSprop.py ํ์ผ์ ๋ด์ ํด๋
โ โโโ predict_AdamWRMSprop.py:
โ baseline predict.py์ ๋์ผ. (๊ฒฝ๋ก๋ง ๋ค๋ฆ)**
โ 1. TRAIN์ ๋ ํ์ผ์ ํตํด ์ต์ข
์ ์ผ๋ก ์ป์ pth๋ก ์๋ก์ด ๋ฐ์ดํฐ์ ๋ํ ์์ธก ์ํ
โ 2. ์์ธก ๊ฒฐ๊ณผ 0/1์ low/high๋ก ๋ฐ๊พธ์ด csv ํํ๋ก ์ ์ฅ
โ (fin_tuned_model_AdamWRMSdragon2.csv)
โ
โโโ FT_RMSdragon.pth : ์ต์ข
์ ์ผ๋ก ํ์ต๋ ๋ชจ๋ธ
โโโ BEST_TL_AdamWdragon.pth : ์ต์ ์ฑ๋ฅ์ผ๋ ์ ์ฅ๋ ๋ชจ๋ธ
โโโ fin_tuned_model_AdamWRMSdragon2.csv : ์ถ๋ก ์ ๋๋ฆฐ ๊ฒฐ๊ณผ
- 'python3.9 TL_AdamW.py' ํ์ผ ์คํ(๋จผ์ ์ํ)
- 'python3.9 FT_RMSprop.py' ํ์ผ ์คํ
- 'USER/RESULT' ๋ด์ ๊ฒฐ๊ณผ๊ฐ ์ ์ฅ๋จ
- 'python3.9 predict_AdamWRMSprop.py' ์คํ
- 'USER/RESULT/' ๋ด์ ๊ฒฐ๊ณผ ํ์ผ(fin_tuned_model_AdamWRMSdragon2.csv)์ด ์ ์ฅ๋จ
'AdamW'๋ฅผ ์ฌ์ฉํ ์ด๊ธฐ ํ์ต ๋จ๊ณ๋ถํฐ ์์ํ ๋ค์ 'RMSprop'์ ์ฌ์ฉํ์ฌ ๋ฏธ์ธ ์กฐ์ ํ๋ ์ด ํ๋ก์ธ์ค์ ์ต์ข ๊ฒฐ๊ณผ๋ ์ด๋ก ์ ์ผ๋ก ์๋ ๋ชจ๋ธ์ ํ์ต์ ๊ณ์ํ๋ ๊ฒ๊ณผ ์ ์ฌํ์ง๋ง ์ตํฐ๋ง์ด์ ๋ฅผ ์ค๊ฐ์ ์ ํํ๋ค๋ ์ฐจ๋ณ์ ์ด ์๋ค.
- densenet121 : https://pytorch.org/vision/main/models/generated/torchvision.models.densenet121.html
- RMSprop : https://pytorch.org/docs/stable/generated/torch.optim.RMSprop.html
- AdamW : https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
-
๋ผ์ด๋ธ๋ฌ๋ฆฌ
- torch
- torchvision
- torch.optim์ RMSprop, lr_scheduler
- PIL
- tqdm
- sklearn
- pandas
- numpy
-
ํ๊ฒฝ
- python : 3.9.18
- cuda : runtimeAPI 11.3, driverAPI 11.7
- os : Ubuntu 20.04 LTS