인공지능공부/프레임워크

pytorch lightning multi-gpu사용하기

컴공누나 2023. 9. 12. 21:32
728x90
반응형

pytorch lightning에 대한 전반적인 내용은 아래 포스팅을 참고해주세요. 

 

pytorch lightning을 사용해보자

pytorch는 가장 널리 쓰이는 딥러닝 프레임워크중 하나이죠. 딥러닝, 머신러닝을 공부하시는 분이라면 들어보셨을겁니다. 써보신 분들은 아시겠지만 pytorch 자체가 굉장히 자유도가 높은 프레임워

jaeyoon-95.tistory.com

 

pytorch lightning은 multi-gpu또한 쉽게 사용하실 수 있습니다. 

방법을 찾아보면 아래와 같이 trainer 선언 부에 gpu의 개수를 적어주면 되는데요. 

trainer = pl.Trainer(max_epochs=10,gpus=4,num_sanity_val_steps=1)

 

저도 위와 같이 적어주어 GPU 4개인 환경에서 돌리려고 하니 아래와 같이 오류가 났습니다. 

OverflowError: cannot serialize a bytes object larger than 4 GiB

OverflowError인데.. GPU 1개 짜리에서도 잘 돌던 코드인데, 이상하다 싶었습니다. 

 

여기서 제가 빠뜨린 부분은 dp(data parallel), ddp(distributed data parallel)설정입니다. 

dp같은 경우에는 하나의 GPU에서 모델을 병렬처리 할 경우 사용하고, ddp 같은 경우 여러개의 GPU에서 병렬처리 할 경우 사용합니다. 

 

설정 방법은 아래와 같습니다. 

from pytorch_lightning.strategies import DDPStrategy

plugins = DDPStrategy(find_unused_parameters=False)
trainer = pl.Trainer(max_epochs=10,gpus=4,num_sanity_val_steps=1,strategy=plugins)
trainer.fit(model)

간단하죠?

 

ddp이외에 다른 strategy들도 많은데요. 다른 strategy들은 문서를 참고해주세요:)

반응형