728x90
반응형
def load_DDP_model(path,model_frame):
'''
path : /path/to/DDP-trained-model.pt
model_frame : model structure
'''
state_dict=torch.load(path)
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove 'module.' of DataParallel/DistributedDataParallel
new_state_dict[name] = v
model_frame.load_state_dict(new_state_dict)
return model_frame
DDP 모드로 훈련된 모델들은 이유를 모르겠지만 각 레이어 이름 앞에 'module.' 가 붙는다.
따라서 그것들을 제거해줘야 load_state_dict가 작동한다.
이를 위해서 위의 기능을 사용하면 된다.
728x90
반응형
'DeepLearning > 파이토치(pytorch)' 카테고리의 다른 글
[torch] multi-gpu (DDP)로 사용하기 튜토리얼 (0) | 2024.03.06 |
---|---|
[torch] cross-validation을 위한 기능 만들기 (multi-gpu) (0) | 2024.03.06 |
파이토치 시드 (seed)값 고정법 (0) | 2023.03.17 |