카테고리 없음

[mnist] Anomaly detection에서의 boosting 기법 효과

TTSR 2024. 10. 8. 08:20
728x90
반응형

1. 개요

 이상치 탐지 (anomaly detection) 중 비지도 이상치 탐지는 주로 정상 데이터만 있는 경우에 자주 사용되는 기법임. 문제는 비지도 학습의 경우 지도학습에 비해 성능이 다소 떨어지는 것이 사실임. 이러한 문제를 boosting 기법을 적용하면 좀 더 성능이 향상될지가 궁금했음.

 

2. 방법

a) mnist 데이터셋에 대해 각 숫자들을 정상으로 두고 나머지는 비정상으로 두는 방식으로 인위적인 normal/abnormal 데이터셋을 만듬. 훈련데이터의 abnormal은 validation set으로 사용하여 autoencoder 모델의 ROC-AUC값이 최대치에 다다르면 해당 모델을 사용하도록 만듬.

b) 편의상 single model이나 ensemble model이나 모델의 형태는 동일하게 만들었지만 single-model의 경우는 전체 훈련 데이터에 대해 훈련하게 만듬.

c) 반면, ensemble model은 첫 번 째 모델을 제외하고는 나머지 모델들은 이전 모델의 약점을 좀 더 보강할 수 있도록 학습 파라미터를 조정함. 방식은 t시점의 모델에서 어떤 훈련 데이터의 reconstruction error값이 클수록 t+1시점의 모델은 해당 데이터의 loss값에 집중하도록 만듬

d) reconstruction error는 훈련을 모두 마친 t시점의 모델에서 모든 데이터셋에 대해 계산했으며 샘플의 가중치는 (sample_error/total_error)x(N_of_samples)로 하여 계산함. 가중치의 합은 1이 되게 만들었지만 훈련데이터의 크기가 너무 작아질 경우 가중치가 너무 작아지기 때문에 훈련데이터의 크기를 곱함.

def calc_resampling_prob(x):
	'''
    재구축 오차값을 기반으로 다음 모델이 집중해야할 샘플의 가중치를 조정함.
    resapling prob으로 썼지만 새로운 데이터를 만들
    필요없이 train-set의 가중치를 조절하는 것으로도 됨.
    resampling을 하게 된다면 normal만의 validation set이 생기므로 이것은 유저의 선택영역임.
    t=0시점에서는 모두 균일한 가중치가 적용됨.
    '''
    a=x/x.sum()
    a=a.detach()
    return a.flatten()

def train_ensemble_by_boost(n=5,train_set=pdata['train-dataset']):
    '''
    training
    '''
    import copy
    import numpy as np
    models={}
    valid_n=None
    original_idx=list(range(train_set['Normal'].shape[0]))
    for i in range(n):
        if i==0:
            weights=None
        models[i]=train_single_model(train_d=train_set['Normal'],
                                    valid_n=train_set['Normal'],
                                    valid_a=train_set['Anormal'],
                                    weights=weights)
        # 학습 후 train-set에 대해 잘 못 맞춘 정도를 계산하기
        m=models[i][0]['relu']['AE'].to('cpu').eval()
        real_img=train_set['Normal'].reshape(len(original_idx),-1)
        reco_img=m(real_img)
        errors=((reco_img-real_img)**2).sum(axis=1)**(1/2)
        weights=calc_resampling_prob(x=errors)
    return models

 

 

3. 결과

3회 반복으로 ensemble 모델은 3개를 만들고 single model은 1개를 만들어서 비교를 했을 때, 모든 숫자에 대해 ensemble모델이 더 좋은 결과가 나타남.

728x90
반응형