Computing

Compression - 2 : PyTorch Pruning Tutorial 및 계산 속도가 빨라지지 않는 이유 본문

Deep Learning/Optimization (Algorithm)

Compression - 2 : PyTorch Pruning Tutorial 및 계산 속도가 빨라지지 않는 이유

jhson989 2022. 4. 29. 21:20

이전글
2022.03.29 - [Deep Learning/Optimization (Algorithm)] - Compression - 1 : Overview

Pruning 기법 소개

이전글에서 소개한 Pruning 기법에 대하여 PyTorch framework를 통해 구현해 보고, 성능에 대하여 분석해보고자 한다.

많은 컴퓨터공학 전공자들은 Pruning(가지치기)이라는 개념을 tree 자료구조에서 탐색할 노드의 개수를 줄이는 방법에 대해 배울 때 배웠을 것이다. 딥러닝에서도 비슷한 개념으로 Fig 1.과 같이 딥러닝 layer(=graph)의 node(=feature)사이의 edge(=weights)를 제거하여 총 계산할 node(feature)의 개수를 줄이고자 한다.

 

Fig 1. Pruning 기법 [5]

 

Pruning은 딥러닝 네트워크의 feature들끼리의 edges, 즉 weights(=layer의 paramters) 개수를 줄인다. 따라서 사라진 weights만큼의 계산량과 메모리 사용량이 감소하기에, 이론적으로는 pruning 과정을 통해 학습 및 추론 속도가 빠른 경량화된 네트워크를 얻을 수 있다.

Pruning은 layer의 parameters 개수를 줄이기에 네트워크 성능에 큰 영향을 미칠 수밖에 없다. 특히 parameters를 너무 많이 없애면 네트워크의 학습능력이 떨어져 문제를 해결하지 못할 수도 있다. 따라서 pruning은 over-parameterized network에 사용하기 좋은 경량화 기법이다. 문제에 비해 너무 많은 parameters를 가지는 딥러닝 네트워크를 over-parameterized network라고 하는데, 그렇기에 더 적은 수의 parameters를 가지게끔 pruning하여도 성능이 어느정도 유지될 수 있다. (경우에 따라서는 더 좋아질수도 있을 것 같은데, over-parameterized network는 어느정도 학습 데이터에 overfitting 될 수 있기 때문이다.)

Over-parameterized network인지 아닌지 판단하는 방법은 실험적으로 보이는 수 밖에는 없을 것이다. 그 문제를 풀기 위한 최적의 네트워크 형태를 알 수 있다면 가능하겠지만, (아직은?) 불가능하기에 자신의 네트워크의 크기(parameters의 개수)를 점점 줄여가며 성능이 유지되는 지를 확인해봐야 한다. 크기를 줄여서 다시 학습시켜 보았는데도 어느정도 성능이 유지된다면 pruning을 진행하여도 괜찮은 over-parameterized network라고 판단할 수 있을 것이다.

어떤 weights를 없애는 지도 pruning 이후의 네트워크 성능 유지에 중요한 요인이 된다. 일반적으로 네트워크 성능에 영향이 적은 weights를 없애는 것이 성능 유지에 좋을 것인데, weight의 L1 norm 또는 L2 norm 크기가 작은 weights를 없애는 것이 일반적이다. 즉 weights의 절대값이 0에 가까울수록 네트워크 성능에 영향이 덜 미친다고 판단하여 0에 가까운 weights들을 제거한다. 랜덤하게 제거하는 방법도 있지만 이 경우 pruning이 성능에 미치는 영향도 랜덤할 것이다.

네트워크의 성능에 영향이 적을 것이라고 예상되는 절대값이 작은 weights를 제거해도 어느정도 네트워크의 성능은 변할 수 밖에 없다. 애초에 절대값이 작은 weights가 네트워크 성능에 영향이 무조건 적다는 것도 증명되지 않았기 때문에(그럴것 같지만), 네트워크의 성능은 중간 계산 과정이 생략됨에 따라 변하게 된다. 따라서 일반적으로 pruning된 네트워크는 다시 학습을 반복하여 정확도 향상을 시도한다. 즉 학습 -> pruning 및 성능 테스트 -> 학습 -> pruning 및 성능 테스트 -> ... 의 방식으로 반복 학습을 진행한다. 이 과정을 통햐 목표한 만큼 경량화되고 성능 또한 유지되는 네트워크를 얻을 수 있다.

 

 

PyTorch Pruning Tutorial

자료 [1]은 PyTorch framework에서 어떻게 pruning하는 지를 잘 보여준다. PyTorch에서 pruning 기법을 사용하기 위해서는 다음 package를 import하여야 한다.

 

import torch.nn.utils.prune as prune

 

딥러닝 네트워크는 다음과 같은 방식으로 pruning할 수 있다. 예시에서는 l1_unstructed() 함수를 이용하여 해당 모듈의 parameters를 pruning한다.

 

# Pruning할 모델 선언
model = LeNet()

# Model의 module(layer) 중 Conv2d, Linear module에 대하여 pruning 진행
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
    	# layer의 weight에 대하여 L1 norm 기준 pruning 진행
        # 이때 amount는 전체 parameters 중 제거되는 parameters의 비율
        # 이 예제에서는 90%의 parameters를 제거함
        prune.l1_unstructured(module, name='weight', amount=0.9)
        
        # layer의 bais에 대하여 L1 norm 기준 pruning 진행
        prune.l1_unstructured(module, name='bias', amount=0.9)

 

 

이때 중요한 것은 PyTorch는 실제 parameters를 제거하는 것이 아닌, parameters의 값을 0으로 만드는 방식으로 pruning한다. 이는 parameters가 tensor로 저장되었기에 배열의 한 element(=parameter)를 물리적으로 제거하는 것이 쉽지 않기 때문이다. (Tensor는 index(좌표)로 위치를 표현하기에, 중간의 element 하나를 제거하는 것이 불가능하다.) Parameters를 0으로 만드는 것은 해당 parameter의 영향이 없는 것과 동일하기에(곱했을 때 0이 되기에) 이 방식을 사용한다. 

 

구체적으로 예시를 들면 다음과 같다. 위 코드 예시에서 weight를 pruning할 경우 conv2d 모듈 내부에서는 weight tensor는 사라지고 weight_origin, weight_mask tensor가 생성된다. weight_origin은 pruning 되기 전의 weight 값을 저장하며, weight_mask는 bitmask로 해당 parameter가 pruning되는 지를 표시한다. 매 forward 과정마다 weight_origin과 weight_mask tensor를 point-wise 곱하여 weight tensor를 생성하고, 이를 이용하여 forward 과정을 계산한다. weight를 만드는 함수는 모듈의 _forward_pre_hooks에 저장되고, forward 시 module이 호출될 때마다 _forward_pre_hooks에 저장된 hook 함수에 따라 prune된 weight를 다시 계산한다.

 

Pruning이 완료된 딥러닝 네트워크는 remove()함수를 이용해 결과를 확정한다.

 

# Pruning된 모든 모듈에 대하여 remove 실행
for name, module in self.model.named_modules():
    if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
    	# weight parameter에 대하여 remove 실행
        prune.remove(module, 'weight')
        # bias parameter에 대하여 remove 실행
        prune.remove(module, 'bias')

 

remove()함수는 weight_origin과 weight_mask를 이용해 weight를 다시 생성하고, weight_origin, weight_mask tensor 및 _forward_pre_hooks 내 pruning 함수를 제거한다.

 

다음 github repository에 PyTorch prune package를 사용한 pruning 예시를 올려놓았다. mnist 분류 문제를 위한 over-parameterized LeNet을 구현하고, 이를 pruning하여 성능을 측정하였다. 

 

 

GitHub - jhson989/pytorch-pruning-basic: pytorch pruning tutorial

pytorch pruning tutorial. Contribute to jhson989/pytorch-pruning-basic development by creating an account on GitHub.

github.com

 

 

 

Pruning 성능 분석

Pruing의 원래 의도는 네트워크를 경량화하여 적은 메모리 사용량, 빠른 계산 속도를 달성하고자 하는 것이다. 하지만 PyTorch에서는 두개 다 불가능하다[2.3.4]. 이는 네트워크의 parameters를 실제로 없애는 것이 아닌 0으로 만드는 방식으로 작동하기 때문이다. 일단 parameters의 개수가 줄지도 않을 뿐더러 학습과정에서는 original parameter뿐만 아니라mask tensor 또한 저장해야 하기에 학습 중 메모리 사용량은 오히려 늘어나게 된다. 또한 parameters의 개수 자체가 줄어들지 않아 연산량 자체는 줄어들지 않기에 학습 및 추론 속도 또한 증가하지 않는다.

만약 pruning을 많이 할 경우(95%이상?) parameters들의 대부분이 0인 sparse tensor를 얻을 수 있다. 따라서 sparsity 성질을 활용할 수 있는 계산 알고리즘[6]이나 하드웨어를 활용한다면 계산 속도 향상 및 메모리 사용량 감소를 달성할 수 있을 것이다. 아직까지는 PyTorch에서 pruning 기법을 실험적으로만 사용해 볼 수 있지만 (네트워크의 parameters가 제거되어도 정확도가 유지될 수 있는 지 정도는 확인할 수 있다), 추후 PyTorch가 업데이트 된다면 pruning 기법을 통해 물리적인 네트워크 경량화가 가능할 것이다.

 

 

Reference

[1] https://pytorch.org/tutorials/intermediate/pruning_tutorial.html
[2] https://discuss.pytorch.org/t/pruning-doesnt-affect-speed-nor-memory-for-resnet-101/75814
[3] https://discuss.pytorch.org/t/discussion-of-practical-speedup-for-pruning/75212
[4] https://stackoverflow.com/questions/62326683/prunning-model-doesnt-improve-inference-speed-or-reduce-model-size
[5] Song Han, Efficient Methods and Hardware for Deep Learning, May 25, 2017, Stanford University
[6] 2022.03.22 - [Parallel Computing/알고리즘] - SpMM - 1 : Introduction