※ hackernoon.com/facial-similarity-with-siamese-networks-in-pytorch-9642aa9db2f7의 도움을 아주 많이 받았습니다.
이전 포스팅에서 샴네트워크와 One Shot Learning을 배웠습니다. 이번 포스팅에서는 Contrastive loss와 실제 구현 코드를 함께 살펴보겠습니다.
[Contrastive Loss]
CNN학습 결과로 나온 벡터를 이용하여 유클리드 거리를 이용하는 식, 기억나시나요? 다음과 같은 식이었습니다.
새로운 사람을 타겟 이미지와 전혀 다른 사람과 비교하여 유클리드 거리를 구하는 식입니다. 이런 방식을 삼중항 손실이라고 부르는데요, 설명은 hipolarbear.tistory.com/20포스팅을 살펴보시면 되겠습니다.
이 삼중항 손실을 이용하여 긍정이미지와는 거리가 가깝도록, 부정 이미지와는 거리가 멀도록 하는 식이 있습니다. 이 식을 Contrastive Loss라고 부릅니다. 한국어로는 대조 손실이라고 번역하는 것 같습니다.
긍정 이미지와 같은 이미지는 1의 라벨(Y값은 1)을, 다른 이미지는 0의 라벨(Y값은 0)을 갖습니다.
Y값이 1이라면, +기호를 중심으로 왼쪽의 식이 사라지고 오른쪽 식이 살아남습니다.
그렇게 되면 0이상의 값을 가지는 m(margin)보다 유클리드 거리가 크다면 m-Dw는 음수가 되어 Contrastive loss의 값은 0이 됩니다. max함수 덕분이죠. Dw값이 margin내에 있다면 이 유클리드 거리값은 그대로 살아남게 되어 최종적으로 긍정이미지와의 거리는 가까워지겠습니다.
Y값이 0이라면, +기호를 중심으로 왼쪽의 식만 살아남게 됩니다.
이 경우에는 어떠한 처리를 해주지 않아 날것 그대로의 유클리드 거리가 쌓이게 될 것입니다.
Margin의 크기 설정이 중요해 보입니다.
코드는 다음과 같습니다.
class ContrastiveLoss(torch.nn.Module):
"""
Contrastive loss function.
Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
"""
def __init__(self, margin=2.0):
super(ContrastiveLoss, self).__init__()
self.margin = margin
def forward(self, output1, output2, label):
euclidean_distance = F.pairwise_distance(output1, output2, keepdim = True)
loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
(label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
return loss_contrastive
라벨을 이용하여 Y값을 설정하여 계산하는 모습을 보실 수 있습니다.
[샴네트워크 코드]
class SiameseNetwork(nn.Module):
def __init__(self):
super(SiameseNetwork, self).__init__()
self.cnn1 = nn.Sequential(
nn.ReflectionPad2d(1),
nn.Conv2d(1, 4, kernel_size=3),
nn.ReLU(inplace=True),
nn.BatchNorm2d(4),
nn.ReflectionPad2d(1),
nn.Conv2d(4, 8, kernel_size=3),
nn.ReLU(inplace=True),
nn.BatchNorm2d(8),
nn.ReflectionPad2d(1),
nn.Conv2d(8, 8, kernel_size=3),
nn.ReLU(inplace=True),
nn.BatchNorm2d(8),
)
self.fc1 = nn.Sequential(
nn.Linear(8*100*100, 500),
nn.ReLU(inplace=True),
nn.Linear(500, 500),
nn.ReLU(inplace=True),
nn.Linear(500, 5))
def forward_once(self, x):
output = self.cnn1(x)
output = output.view(output.size()[0], -1)
output = self.fc1(output)
return output
def forward(self, input1, input2):
output1 = self.forward_once(input1)
output2 = self.forward_once(input2)
return output1, output2
샴 네트워크 코드입니다. 코드를 보시면 엥?? 실컷 설명한 Contrastive loss도 없고... 뭐가 없습니다.
샴 네트워크 코드라면서 왜 CNN을 하고 있나요??
이 코드는 긍정 이미지에 대한 공유되는 CNN을 구성하기 위한 네트워크입니다. 이전 포스팅에서 이미지를 비교할 때 미리 학습해놓은 긍정 이미지의 CNN 학습 결과 잣대로 비교한다고 하였습니다. 우선 긍정 이미지를 학습하는 부분이라고 할 수 있겠습니다.
[데이터 로딩]
class SiameseNetworkDataset(Dataset):
def __init__(self,imageFolderDataset,transform=None,should_invert=True):
self.imageFolderDataset = imageFolderDataset
self.transform = transform
self.should_invert = should_invert
def __getitem__(self,index):
img0_tuple = random.choice(self.imageFolderDataset.imgs)
#we need to make sure approx 50% of images are in the same class
should_get_same_class = random.randint(0,1)
if should_get_same_class:
while True:
#keep looping till the same class image is found
img1_tuple = random.choice(self.imageFolderDataset.imgs)
if img0_tuple[1]==img1_tuple[1]:
break
else:
while True:
#keep looping till a different class image is found
img1_tuple = random.choice(self.imageFolderDataset.imgs)
if img0_tuple[1] !=img1_tuple[1]:
break
img0 = Image.open(img0_tuple[0])
img1 = Image.open(img1_tuple[0])
img0 = img0.convert("L")
img1 = img1.convert("L")
if self.should_invert:
img0 = PIL.ImageOps.invert(img0)
img1 = PIL.ImageOps.invert(img1)
if self.transform is not None:
img0 = self.transform(img0)
img1 = self.transform(img1)
return img0, img1 , torch.from_numpy(np.array([int(img1_tuple[1]!=img0_tuple[1])],dtype=np.float32))
def __len__(self):
return len(self.imageFolderDataset.imgs)
이미지를 불러와 라벨별로 처리하는 모습을 볼 수 있습니다.
[학습]
net = SiameseNetwork().cuda()
criterion = ContrastiveLoss()
optimizer = optim.Adam(net.parameters(),lr = 0.0005 )
counter = []
loss_history = []
iteration_number= 0
for epoch in range(0,Config.train_number_epochs):
for i, data in enumerate(train_dataloader, 0):
img0, img1 , label = data
img0, img1 , label = img0.cuda(), img1.cuda() , label.cuda()
optimizer.zero_grad()
output1,output2 = net(img0,img1)
loss_contrastive = criterion(output1,output2,label)
loss_contrastive.backward()
optimizer.step()
if i %10 == 0 :
print("Epoch number {}\n Current loss {}\n".format(epoch,loss_contrastive.item()))
iteration_number +=10
counter.append(iteration_number)
loss_history.append(loss_contrastive.item())
show_plot(counter,loss_history)
네트워크와 옵티마이저 등 초기화 후 학습을 진행합니다.
학습 순서는 다음과 같습니다. for문 내부입니다.
- 첫번째 이미지를 네트워크에 통과시킵니다.
- 두번째 이미지를 네트워크에 통과시킵니다.
- 첫번째 이미지와 두번째 이미지의 Contrastive Loss를 계산합니다.
- 역전파하고 네트워크를 개선합니다.
다음과 같은 그래프를 그리며 loss가 감소하는 모습을 볼 수 있었습니다.
[학습 결과 보기]
testing_directory = "custom_data/test_data/"
folder_dataset_test = dset.ImageFolder(root=testing_directory)
siamese_dataset = SiameseNetworkDataset(imageFolderDataset=folder_dataset_test,
transform=transforms.Compose([transforms.Resize((100,100)),
transforms.ToTensor()
])
,should_invert=False)
test_dataloader = DataLoader(siamese_dataset,num_workers=0,batch_size=1,shuffle=True)
dataiter = iter(test_dataloader)
x0,_,_ = next(dataiter)
for i in range(10):
_,x1,label2 = next(dataiter)
concatenated = torch.cat((x0,x1),0)
output1,output2 = net(Variable(x0).cuda(),Variable(x1).cuda())
euclidean_distance = F.pairwise_distance(output1, output2)
imshow(torchvision.utils.make_grid(concatenated),'Dissimilarity: {:.2f}'.format(euclidean_distance.item()))
테스트 데이터를 따로 구성해주고 실제로 유클리드 거리를 계산하며 학습이 잘됐는지 확인해봅시다. 제 얼굴과 원빈, 고릴라를 비교한 결과는 다음과 같았습니다.
학습이 잘 된 것 같습니다!! 다만 노트북 성능이 좋지 않아 batch size를 6으로 줄이고, 학습 데이터를 20장으로, worker를 1로 줄이면서 정확한 학습이나 학습속도 측면에서 아쉬운 부분이 있기도 했습니다.
[폴더 경로 보기]
이 프로젝트를 진행하면서 폴더 경로를 형성하는 부분에서 애를 많이 먹었습니다. 그래서 위 코드에 사용된 폴더 경로의 모습을 기록하려고 합니다.
["Origin_data_2/" 폴더]
이 부분에서 사용한 경로입니다.
["custom_data\test_data\" 폴더]
※ hackernoon.com/facial-similarity-with-siamese-networks-in-pytorch-9642aa9db2f7 의 포스팅과 코드를 구현하며 많은 도움이 되었습니다.
github.com/BUZZINGPolarBear/Why_Am_I_ALONE
전체 코드를 보실 수 있습니다.
'머신러닝' 카테고리의 다른 글
[Pytorch] Siamese network를 이용하여 나의 외모를 점검해보자_1탄 (1) | 2021.02.17 |
---|---|
[Pytorch] 데이터를 뻥튀기하자! Data Augmentation (8) | 2021.02.13 |
[입문] Pytorch 와 Linear Regression (0) | 2021.01.15 |