본문 바로가기

파이썬3

Cluster index permutation test

728x90
반응형

우리가 샘플들의 라벨을 알고 있을 때, clustering이 잘 됐다면 같은 그룹의 샘플들의 거리 (intra-subtype distance) 는 가깝고 다른 그룹의 샘플들의 거리 (inter-subtype distance)는 멀 것이다.

이 점을 이용하면, 샘플 라벨을 무작위로 섞는 permutation test를 통해 clustering이 잘 되었는지에 대한 p-value를 얻을 수 있다.

출처 : https://doi.org/10.1038/s41467-020-17139-y

 

def calc_cluster_index(mm, class_list, distance_method='correlation'):
    '''
    Compute a cluster index for samples whose class labels are known.

    The index is the sum, over classes, of the mean pairwise distance
    between samples of the same class, minus the sum, over class pairs,
    of the mean distance between samples of the two classes. A smaller
    value therefore means tighter and better separated classes.

    mm : feature matrix (sample x feature), a DataFrame indexed by sample name
    class_list : dictionary {'subtypeA':['s1','s2'],'Subtype2':['s3','s4']}
    distance_method : ['euclidean','correlation']

    Returns the cluster index. The result is NaN when a class has fewer
    than two samples, because that class has no intra-class pair to average.

    Raises ValueError when distance_method is not a supported name.
    '''
    from itertools import combinations

    import numpy as np
    from scipy.spatial.distance import pdist, squareform

    if distance_method not in ('euclidean', 'correlation'):
        raise ValueError(
            "distance_method must be 'euclidean' or 'correlation', got %r"
            % (distance_method,)
        )

    # Rows are pulled class by class, so each class occupies one contiguous
    # run of positions in the distance matrix built below.
    ordered_samples = [s for members in class_list.values() for s in members]
    x = mm.loc[ordered_samples, :]

    # Sample-by-sample distance matrix.
    if x.shape[0] == 1:
        # A lone sample has no pairwise distance; correlation is undefined
        # here, so use the trivial 1x1 zero matrix for both methods.
        dist = np.zeros((1, 1))
    elif distance_method == 'euclidean':
        dist = squareform(pdist(x, metric='euclidean'))
    else:
        # Rows of x are samples, so corrcoef's default (rows are variables)
        # yields the sample-by-sample correlation matrix.
        dist = 1 - np.corrcoef(x.values)

    # Positions of every class inside the distance matrix.
    positions = {}
    offset = 0
    for key, members in class_list.items():
        positions[key] = list(range(offset, offset + len(members)))
        offset += len(members)

    # Intra connectivity: mean of the strict upper triangle of each class block.
    intra_con = []
    for key in class_list:
        idx = positions[key]
        block = dist[np.ix_(idx, idx)]
        rows, cols = np.triu_indices(n=block.shape[0], k=1)
        intra_con.append(block[rows, cols].mean())

    # Inter connectivity: mean of the rectangular block for each class pair.
    inter_con = [
        np.mean(dist[np.ix_(positions[k1], positions[k2])])
        for k1, k2 in combinations(class_list, 2)
    ]

    cluster_index = sum(intra_con) - sum(inter_con)
    return cluster_index

p-value는 라벨을 무작위로 섞어 얻은 cluster index 중 실제 값보다 작은 것의 비율이며, 아래와 같이 계산한다.

def cluster_index_P(mm,class_list, distance_method='correlation',perm=10000,plot=False):
    '''
    Permutation p-value for the cluster index of a labelled sample set.

    The observed cluster index is compared with the indices obtained after
    randomly reassigning the sample names to the rows. The p-value is the
    fraction of permuted indices that are strictly smaller than the
    observed one.

    mm : feature matrix (sample x feature), a DataFrame indexed by sample name
    class_list : dictionary {'subtypeA':['s1','s2'],'Subtype2':['s3','s4']}
    distance_method : ['euclidean','correlation']
    perm : number of permutations, must be at least 1
    plot : when True, draw the density of the permuted values with the
           observed value marked by a dashed red line

    Returns the p-value.

    Raises ValueError when perm is smaller than 1.
    '''
    import random

    import numpy as np
    from tqdm import tqdm

    if perm < 1:
        raise ValueError('perm must be a positive integer, got %r' % (perm,))

    # Get real value
    real_value = calc_cluster_index(
        mm=mm, class_list=class_list, distance_method=distance_method
    )

    samples = [s for members in class_list.values() for s in members]
    # Keep only the labelled rows, so the shuffled names always match the
    # number of rows even when mm holds samples absent from class_list.
    mmx = mm.loc[samples, :].copy()

    # Calculate permutated values
    perm_value = []
    for _ in tqdm(range(perm)):
        mmx.index = random.sample(samples, k=len(samples))
        perm_value.append(
            calc_cluster_index(
                mm=mmx, class_list=class_list, distance_method=distance_method
            )
        )

    # calculate p-value
    pvalue = np.sum(np.array(perm_value) < real_value) / perm

    # plotting
    if plot:
        import matplotlib.pyplot as plt
        import pandas as pd
        import seaborn as sns
        sns.kdeplot(x=pd.Series(perm_value))
        plt.xlabel('Cluster index')
        plt.axvline(x=real_value,color='red',linestyle='--')
        plt.show()
    return pvalue
728x90
반응형