파이썬3
Cluster index permutation test
TTSR
2023. 10. 4. 15:19
728x90
반응형
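샘플들의 라벨을 알고 있을 때 clustering이 잘 되었다면, 같은 그룹에 속한 샘플 간의 거리(intra-subtype distance)는 가깝고 다른 그룹에 속한 샘플 간의 거리(inter-subtype distance)는 멀 것이다.
이 점을 이용해 cluster index(= 그룹별 평균 intra-subtype distance의 합 − 그룹 쌍별 평균 inter-subtype distance의 합)를 정의하면, 라벨을 무작위로 섞는 permutation test로 clustering이 잘 되었는지에 대한 p-value를 얻을 수 있다.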
출처: https://doi.org/10.1038/s41467-020-17139-y
def calc_cluster_index(mm, class_list, distance_method='correlation'):
    """Compute the cluster index of a labelled grouping of samples.

    The index is ``sum(mean intra-group distance) - sum(mean inter-group
    distance)``. The intra term is taken once per group, over all unordered
    sample pairs inside that group. The inter term is taken once per
    unordered pair of groups, over all cross-group sample pairs. Lower values
    mean samples are closer to their own group than to other groups.

    Parameters
    ----------
    mm : pandas.DataFrame
        Feature matrix (sample x feature) indexed by sample name. Rows not
        listed in ``class_list`` are ignored.
    class_list : dict
        ``{group_name: [sample, ...]}``, e.g.
        ``{'subtypeA': ['s1', 's2'], 'Subtype2': ['s3', 's4']}``.
        Any iterable of sample names is accepted per group.
    distance_method : {'euclidean', 'correlation'}
        'euclidean' uses the Euclidean distance between sample rows.
        'correlation' uses 1 - Pearson correlation between sample rows.

    Returns
    -------
    float
        The cluster index. A group with fewer than two samples has no
        within-group pairs and contributes NaN, which makes the result NaN.

    Raises
    ------
    ValueError
        If ``distance_method`` is not supported or fewer than two samples
        are listed in ``class_list``.
    """
    from itertools import chain, combinations

    import numpy as np
    from scipy.spatial.distance import pdist, squareform

    groups = [list(members) for members in class_list.values()]
    samples = list(chain.from_iterable(groups))
    if len(samples) < 2:
        raise ValueError("at least two samples are required to compute distances")

    values = mm.loc[samples, :].to_numpy(dtype=float)
    if distance_method == 'euclidean':
        dist = squareform(pdist(values, metric='euclidean'))
    elif distance_method == 'correlation':
        # Rows are samples, so corrcoef yields the sample x sample correlation.
        dist = 1 - np.corrcoef(values)
    else:
        raise ValueError(
            f"distance_method must be 'euclidean' or 'correlation', got {distance_method!r}"
        )

    # Rows/columns of `dist` follow `samples`, i.e. the groups concatenated in
    # order, so every group occupies one contiguous slice.
    bounds = np.cumsum([0] + [len(g) for g in groups])
    blocks = [slice(start, stop) for start, stop in zip(bounds[:-1], bounds[1:])]

    # Mean distance over the upper triangle (distinct pairs) of each group's block.
    intra_con = []
    for blk in blocks:
        within = dist[blk, blk]
        rows, cols = np.triu_indices(within.shape[0], k=1)
        intra_con.append(within[rows, cols].mean() if rows.size else np.nan)

    # Mean distance over all cross pairs for every unordered pair of groups.
    inter_con = []
    for blk_a, blk_b in combinations(blocks, 2):
        between = dist[blk_a, blk_b]
        inter_con.append(between.mean() if between.size else np.nan)

    return float(sum(intra_con) - sum(inter_con))
p-value는 샘플 라벨을 무작위로 섞어(permutation) 계산한 cluster index들 중 실제 cluster index보다 작은 값의 비율로, 아래와 같이 계산한다.
def cluster_index_P(mm, class_list, distance_method='correlation', perm=10000, plot=False, seed=None):
    """Permutation-test p-value for the cluster index of a labelled grouping.

    The sample labels are shuffled ``perm`` times, and each shuffle is scored
    with ``calc_cluster_index``. The p-value is the fraction of shuffled
    scores that are strictly smaller than the observed score. A smaller index
    means tighter clusters. Ties are not counted, and the result can be
    exactly 0.

    Parameters
    ----------
    mm : pandas.DataFrame
        Feature matrix (sample x feature). Rows not listed in ``class_list``
        are ignored.
    class_list : dict
        ``{group_name: [sample, ...]}``.
    distance_method : str
        Passed through to ``calc_cluster_index``.
    perm : int
        Number of permutations; must be >= 1.
    plot : bool
        If True, show a KDE of the permuted indices with the observed index
        marked by a dashed red line.
    seed : int or None
        Seed for a private RNG, for reproducible permutations. ``None`` keeps
        the original behaviour of drawing from the global ``random`` state.

    Returns
    -------
    float
        Fraction of permuted indices strictly below the observed index.

    Raises
    ------
    ValueError
        If ``perm`` < 1.
    """
    import random

    import numpy as np

    try:
        from tqdm import tqdm
    except ImportError:  # progress bar is cosmetic; run without it
        def tqdm(iterable):
            return iterable

    if perm < 1:
        raise ValueError(f"perm must be >= 1, got {perm}")

    rng = random.Random(seed) if seed is not None else random

    # Observed index with the true labels.
    real_value = calc_cluster_index(mm=mm, class_list=class_list, distance_method=distance_method)

    samples = [s for members in class_list.values() for s in members]
    # Keep only the labelled samples so the shuffled index below always has
    # the same length as the frame, even when mm carries extra rows.
    mmx = mm.loc[samples].copy()

    perm_value = []
    for _ in tqdm(range(perm)):
        # Re-assigning shuffled names to fixed rows is equivalent to
        # shuffling which group each row belongs to.
        mmx.index = rng.sample(samples, k=len(samples))
        perm_value.append(
            calc_cluster_index(mm=mmx, class_list=class_list, distance_method=distance_method)
        )

    perm_value = np.asarray(perm_value)
    pvalue = np.count_nonzero(perm_value < real_value) / perm

    if plot:
        import matplotlib.pyplot as plt
        import seaborn as sns
        sns.kdeplot(x=perm_value)
        plt.xlabel('Cluster index')
        plt.axvline(x=real_value, color='red', linestyle='--')
        plt.show()
    return pvalue
728x90
반응형