📕 Transfer Learning Notes (PyTorch)

What is Transfer Learning?

A way to train CNNs even when only a small dataset is available.

A convolutional neural network (ConvNet) is first pretrained on a very large dataset (e.g. ImageNet, which contains 1.2 million images across 1000 categories). That ConvNet is then used either as the initialization or as a fixed feature extractor for the task of interest.

์‹œ๋‚˜๋ฆฌ์˜ค 2๊ฐ€์ง€

  1. fine tuning - ์‚ฌ์ „ ํ•™์Šต๋œ ๋ชจ๋ธ์„ ์ƒˆ๋กœ์šด ๋ฌธ์ œ์— ์ ์šฉํ•˜๊ธฐ ์œ„ํ•ด ์ผ๋ถ€ ๊ฐ€์ค‘์น˜๋ฅผ ์กฐ์ ˆํ•˜๋Š” ํ•™์Šต ๊ณผ์ •
  2. ๊ณ ์ •๋œ ํŠน์ง• ์ถ”์ถœ๊ธฐ๋กœ์จ์˜ ํ•ฉ์„ฑ๊ณฑ ์‹ ๊ฒฝ๋ง- ๋งˆ์ง€๋ง‰์— ์™„์ „ํžˆ ์—ฐ๊ฒฐ ๋œ ๊ณ„์ธต์„ ์ œ์™ธํ•œ ๋ชจ๋“  ์‹ ๊ฒฝ๋ง์˜ ๊ฐ€์ค‘์น˜๋ฅผ ๊ณ ์ •. ์ด ๋งˆ์ง€๋ง‰์˜ ์™„์ „ํžˆ ์—ฐ๊ฒฐ๋œ ๊ณ„์ธต์€ ์ƒˆ๋กœ์šด ๋ฌด์ž‘์œ„์˜ ๊ฐ€์ค‘์น˜๋ฅผ ๊ฐ–๋Š” ๊ณ„์ธต์œผ๋กœ ๋Œ€์ฒด๋˜์–ด ์ด ๊ณ„์ธต๋งŒ ํ•™์Šต.
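
The difference between the two scenarios comes down to which parameters receive gradients. Here is a minimal sketch, assuming a tiny stand-in network rather than a real pretrained ResNet. The `make_net` and `count_trainable` helpers and all layer sizes are hypothetical and chosen only for illustration.

```python
import torch.nn as nn


def make_net() -> nn.Sequential:
    """Tiny stand-in for a pretrained ConvNet: a conv 'backbone' plus a linear head."""
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),  # plays the pretrained layers
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 10),  # plays the original classifier head
    )


def count_trainable(model: nn.Module) -> int:
    """Number of scalar parameters that will receive gradients."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Scenario 1 (fine-tuning): swap the head, leave every parameter trainable.
finetune_net = make_net()
finetune_net[4] = nn.Linear(8, 2)

# Scenario 2 (fixed feature extractor): freeze everything, then swap the head.
frozen_net = make_net()
for param in frozen_net.parameters():
    param.requires_grad = False
frozen_net[4] = nn.Linear(8, 2)  # new module, so requires_grad=True by default

finetune_trainable = count_trainable(finetune_net)
frozen_trainable = count_trainable(frozen_net)
print(finetune_trainable, frozen_trainable)
```

In this stand-in, the frozen variant trains only the 18 parameters of the new head, while the fine-tuning variant also trains the 224 convolution parameters.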

Key code summary

Fine-tuning the ConvNet

import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import models

# `device` and `train_model` are assumed to be defined elsewhere.
# Recent torchvision releases deprecate pretrained=True in favor of the weights= argument.
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
# Here the size of each output sample is set to 2.
# Alternatively, generalize it to nn.Linear(num_ftrs, len(class_names)).
model_ft.fc = nn.Linear(num_ftrs, 2)

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Note that all parameters are handed to the optimizer
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Multiply the learning rate by 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

# Train and evaluate
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)
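
With `step_size=7` and `gamma=0.1`, `StepLR` multiplies the learning rate by 0.1 once every 7 epochs (it does not subtract 0.1). Here is a minimal sketch of that arithmetic in plain Python, assuming the scheduler is stepped once per epoch. The `step_lr` helper is hypothetical and only mirrors the closed-form schedule; it is not part of PyTorch.

```python
def step_lr(base_lr: float, epoch: int, step_size: int = 7, gamma: float = 0.1) -> float:
    """Learning rate in effect during `epoch` (0-indexed) under step decay."""
    return base_lr * gamma ** (epoch // step_size)


# Learning rate for each of the 25 epochs, starting from lr=0.001.
schedule = [step_lr(0.001, epoch) for epoch in range(25)]

for epoch in (0, 6, 7, 14, 21, 24):
    print(f"epoch {epoch:2d}: lr = {schedule[epoch]:.0e}")
```

Under these assumptions the rate stays at 1e-3 for epochs 0 to 6, then drops to 1e-4, 1e-5, and 1e-6 at epochs 7, 14, and 21.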

ConvNet as a fixed feature extractor

import torchvision

model_conv = torchvision.models.resnet18(pretrained=True)
# Freeze every pretrained parameter so that no gradients are computed for it
for param in model_conv.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

# Note that, unlike before, only the parameters of the final layer are optimized
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Multiply the learning rate by 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

# Train and evaluate
model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=25)
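
One way to check that freezing behaves as intended is to run a single optimizer step and compare the weights before and after. The sketch below assumes a small stand-in model and random data instead of ResNet-18 and a real dataset. The names `backbone` and `head` are hypothetical: `backbone` plays the frozen ConvNet and `head` plays the new `fc` layer.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Frozen stand-in for the pretrained layers.
backbone = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
for param in backbone.parameters():
    param.requires_grad = False

# Freshly created head; its parameters require gradients by default.
head = nn.Linear(8, 2)
model = nn.Sequential(backbone, head)

criterion = nn.CrossEntropyLoss()
# Only the head's parameters are handed to the optimizer.
optimizer = optim.SGD(head.parameters(), lr=0.1, momentum=0.9)

backbone_before = [p.detach().clone() for p in backbone.parameters()]
head_before = [p.detach().clone() for p in head.parameters()]

# One training step on random data (8 samples, 2 classes).
inputs = torch.randn(8, 4)
labels = torch.randint(0, 2, (8,))
optimizer.zero_grad()
loss = criterion(model(inputs), labels)
loss.backward()
optimizer.step()

backbone_unchanged = all(
    torch.equal(before, after)
    for before, after in zip(backbone_before, backbone.parameters())
)
head_changed = any(
    not torch.equal(before, after)
    for before, after in zip(head_before, head.parameters())
)
no_backbone_grads = all(p.grad is None for p in backbone.parameters())
print(backbone_unchanged, head_changed, no_backbone_grads)
```

Because the frozen parameters never receive a gradient, autograd skips them during the backward pass, so training only the last layer needs less backward computation than full fine-tuning.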