DL&ML/papers

Siamese Neutral Networks for One-shot Image Recognition

식피두 2020. 10. 27. 14:30

논문 링크 ; www.cs.cmu.edu/~rsalakhu/papers/oneshot1.pdf

(링크 클릭이 안되는 경우 제목으로 검색)

 

이전에 Extrative Summarization as Text Matching 논문에서

siamese-BERT 아키텍쳐가 제안되었는데, 

 

siamese-network가 무엇인지 궁금해서 관련 논문을 찾아보았다.

 

One-shot Image Recognition

딥러닝을 이용해서 특정 태스크의 문제를 해결하기 위한 좋은 피쳐, 표현을 얻기 위해선

양질의 데이터와 비싼 컴퓨팅 자원이 필요하다.

 

하지만, 사람은 조금 과장해서 A가 그려진 단 한장의 이미지만을 보고도

A의 변형에 대해서 같은 것임을 분류해 낼 수 있다.

 

여기서, One-shot Image Recognition 이란 개념이 등장하는데,

이는 각 클래스별로 한 개의 이미지 밖에 없는 상황에서,

테스트 데이터가 주어졌을 때 가지고 있는 데이터 중 어떤 이미지와 같은 것인지 찾는 문제(분류)를 말한다.

 

본 논문에선 siamese network를 제안해서 두 개의 인풋 사이의 유사도에 대해서 

랭킹을 매길 수 있도록 네트워크를 학습시킬 수 있는 방법을 제안한다.

 

학습된 네트워크는 유사도 측정에 최적화된 양질의 feature를 뽑을 수 있고,

이 feature를 이용하여 one-shot image recognition을 수행하는데 활용될 수 있다. 

 

*One-shot Learning ; 각 클래스별로 1개의 example만 보고 test 데이터를 분류하는 태스크

 

Approach 

1. 먼저 siamese neural network를 supervised metric-based 방법으로 학습을 시킨다.

   이미지 쌍이 주어지면 네트워크는 같은 클래스인지 아닌지 여부를 학습한다. (0~1 확률)

   이를, Verification Model이라고 한다.

 

* 가정 ; verification을 잘 하는 네트워크는 one-shot 분류도 잘 할 것이다.

 

2. 학습된 네트워크의 피쳐, 표현(representation)을 '추가 학습 없이' one-shot learning을 수행하는데 활용한다.

   테스트 이미지가 주어지면, 각 클래스의 1 example의 이미지를 verification model에 넣어서

   가장 높은 점수(확률)을 고른다.

   (feature만 미리 뽑아두어, feature에 대해 직접 유사도 비교를 빠르게 처리하고 마는 줄 알았는데... 아닌가보다)

 

논문에선 omniglot (50개국 언어의 알파벳 데이터셋)을 활용했다.

 

Verification Model에 의해 학습된 피쳐를 통해서

특정 언어의 알파벳이 같은지/다른지를 충분히 구분해낼 수 있고,

다른 다양한 언어의 알파벳에 대해서도 마찬가지로 잘 구분해 낼 수 있도록 학습 되었다면

새로운 알파벳에 대해서도 잘 동작할 것이라 가정해 볼 수 있겠다.

 

Deep Siamese Network

Siamese 네트워크는 이미 (Bromley, LeCun et al. 1993) 에서 제안된 바 있다.

두 개의 개별 인풋이 Twin Networks를 통과하고 energy function에 의해 합쳐진다.

Energy function에서는 추출된 각 피쳐간의 메트릭을 계산한다.

 

Weight이 공유되는 Twin Netoworks(symmetric)기 때문에

완전히 같은 인풋이 들어왔을 경우에 Feature Space 상에서 서로 다른 공간에 매핑 될 수가 없다.

 

LeCun et al. 에서는 Contrastive Energy Function을 사용해서

두 개의 입력이 주어졌을 때, 같은 쌍에 대해선 에너지를 감소시키고

다른 쌍에 대해선 에너지를 증가시키게 하였다.

 

이 논문에서는 Weighted L1 Distance를 거쳐서 Sigmoid 를 통과하여 0~1 사이 값을 뽑았다.

(따라서, cross entropy objective를 사용)

 

Siamese Neutral Networks for One-shot Image Recognition, Koch et al.

논문에선 앞단에 CNN을 두었고,

마지막 Conv 레이어에서 flatten된 피쳐에 대해 FC Layer(256*6*6 x 4096)+sigmoid를 통과 시킨 뒤

이미지당 4096 차원의 피쳐를 추출했다. 

 

그 이후에 두 피쳐의 L1 디스턴스를 계산 하고, 

weighted factor를 곱해준 결과를 얻은 후에 sigmoid를 취해 Prediction Vector 'P'를 얻는다.

(이 부분이 그림에서 마지막 FC Layer를 의미한다. 논문에서는 명확히 FC라고 명시를 하지 않음)

 

마지막 레이어에서 두 피쳐 벡터간의 유사도를 계산

자세한 튜닝 방법은 논문을 참고하자.

아, 참고로 Twin Network를 사용하기 때문에, 각각 출력에 대해 계산되는 gradient는 누적되어 학습이 된다.

그 외에 layer 별 서로 다른 LR, LR decay,... 다양한 것들을 튜닝해줬다.

 

Experiment

omniglot 데이터셋을 이용하였으며,

총 50 종류의 알파벳 중

40 종류는 background set

10 종류는 evaluation set

으로 분리하였다.

 

background set은 verification model을 위한 training / validation / test에 사용되었고

evaluation set은 one-shot classification 성능을 측정하는데 사용되었다.

 

Verficiation model

같은/다른 클래스 쌍을 랜덤 샘플링하여 학습하였고,

데이터셋 크기를 다르게 하여 실험이 진행되었다.

 

알파벳에 대한 표현이 동등한 기회를 갖고 학습될 수 있도록 알파벳 별 uniform 한 개수의

학습셋을 만들어 줬다고 함.

 

Siamese Neutral Networks for One-shot Image Recognition, Koch et al.

* One-shot Learning

가장 궁금했던 부분.

 

이제, verification model이 준비되었으니 one-shot image recognition을 해봄으로써

학습된 피쳐의 discriminative potential을 측정해볼 수 있다.

 

테스트 이미지 x가 있고,

각 클래스에 대한 Example이 c개(Category, x_1, x_2, ... x_c) 만큼 있다면

 

x, x_c 를 네트워크에 넣고 출력 확률 값이 가장 높은 것을 택하면 된다.

* 효율적으로 처리하기 위해선 c개의 미니배치로 구성하여 네트워크에 전달해준다.

 

앞서 분리해 놓은 evaluation dataset을 이용해서 (10 alphabets)

특정 알파벳 기준 20-way 분류 문제를 만들어서 평가를 수행한다.

 

1. uniformly 랜덤하게 alphabet 하나 선택

2. 20명의 drawers(태깅한 사람) 중 2명의 characters 선택

3. 첫 번째 사람의 character는 test 셋으로 생각, 두 번째 사람의 characters는 20-example로 생각하여 test 데이터의 클래스를 예측

4. 1-3 프로세스를 전체 알파벳에 대해 2번 반복

 

각 알파벳 당 40번의 one-shot learning trials

전체적으론 400번의 one-shot learning trials

 

characters 혹은 stroke에 대한 extra prior knowledge 없이

Human이나 HBPL에 근접한 성능을 보이며,

나머지 베이스라인보다 월등한 성능을 보여줌.

 

Siamese Neutral Networks for One-shot Image Recognition, Koch et al.

 

 

참고할만한 코드.. (contrastive loss를 사용하였다.)

github.com/delijati/pytorch-siamese

 

delijati/pytorch-siamese

Siamese Network implementation using Pytorch. Contribute to delijati/pytorch-siamese development by creating an account on GitHub.

github.com