본문 바로가기

DeepLearning/파이토치(pytorch)

[torch] DDP 모드에서 훈련된 모델 불러오기.

728x90
반응형
def load_DDP_model(path,model_frame):
    '''
    path : /path/to/DDP-trained-model.pt
    model_frame : model structure
    '''
    state_dict=torch.load(path)
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:] # remove 'module.' of DataParallel/DistributedDataParallel
        new_state_dict[name] = v
    model_frame.load_state_dict(new_state_dict)
    return model_frame

 

DDP 모드로 훈련된 모델들은 이유를 모르겠지만 각 레이어 이름 앞에 'module.' 가 붙는다.

따라서 그것들을 제거해줘야 load_state_dict가 작동한다.

이를 위해서 위의 기능을 사용하면 된다.

728x90
반응형