[ PyTorch ] Custom Model 제작 - part 5
Updated:
hook
- 패키지화된 코드에서 custom 코드를 중간에 실행시킬 수 있도록 만들어 놓은 인터페이스
- 사용 목적
- 프로그램 실행 로직 분석
- 프로그램 추가 기능 제공
- 종류
- pre-hook
- 기존의 프로그램이 실행되기 전 수행
- hook
- 기존의 프로그램이 실행되고 난 후 수행
- pre-hook
-
사용 방법
# hook 함수 정의 def hook_func(): """ hook이 실행할 내용 작성 """ # hook을 사용할 객체(tensor, module) 생성 model = Model() # 해당 객체에 hook 등록 model.register_forward_hook(hook_func)
Tensor에 적용하는 hook
backward hook
- 종류
- hook → register_hook()
-
예시
class Model(nn.Module): def __init__(self): super().__init__() self.W = Parameter(torch.Tensor([5])) def forward(self, x1, x2): output = x1 * x2 output = output * self.W return output # 모델 생성 model = Model() # Model의 Parameter W의 gradient 값을 저장하기 위한 list answer = [] def tensor_hook(grad): answer.extend(grad) # 내가 정의한 hook을 실제로 모듈 내의 tensor(여기서는 weight)에 적용 model.W.register_hook(tensor_hook) x1 = torch.rand(1, requires_grad=True) x2 = torch.rand(1, requires_grad=True) output = model(x1, x2) output.backward() if answer == [model.W.grad]: print(True) else: print(False) """ 출력결과 : True """
Module에 적용하는 hook
forward hook
- 종류
- pre-hook → register_forward_pre_hook()
- hook → register_forward_hook()
-
예시
class Add(nn.Module): def __init__(self): super().__init__() def forward(self, x1, x2): output = torch.add(x1, x2) return output # 전파되는 output 값에 5를 더한다. def hook(module, output): output += 5 # 모델 생성 add = Add() # 내가 정의한 hook을 실제로 모듈에 적용하는 과정 add.register_forward_hook(hook) x1, x2 = torch.rand(1), torch.rand(1) print(x1,",",x2) output = add(x1, x2) print(output) """ 실행결과 : x1 + x2의 값에 5가 더해진 값이 출력된다. """
backward hook
- 종류
- hook → register_full_backward_hook()
-
예시
class Model(nn.Module): def __init__(self): super().__init__() self.W = Parameter(torch.Tensor([5])) def forward(self, x1, x2): output = x1 * x2 output = output * self.W return output # 모델 생성 model = Model() # x1.grad, x2.grad, output.grad 순서로 list에 넣는다. answer = [] def module_hook(module, grad_input, grad_output): answer.extend(grad_input) answer.extend(grad_output) # 내가 정의한 hook을 실제로 모듈에 적용시키는 과정 model.register_full_backward_hook(module_hook) x1 = torch.rand(1, requires_grad=True) x2 = torch.rand(1, requires_grad=True) output = model(x1, x2) output.retain_grad() output.backward() if answer == [x1.grad, x2.grad, output.grad]: print(True) else: print(False) """ 출력결과 : True """