1行代码消除PyTorch的CUDA内存溢出报错,这个GitHub项目刚发布就揽星600+

news/2024/7/8 6:52:08

点击上方“视学算法”,选择加"星标"或“置顶

重磅干货,第一时间送达38cb427b8c08d884975fe8c0b1752e10.png

丰色 发自 凹非寺
量子位 报道 | 公众号 QbitAI

CUDA error: out of memory.

多少人用PyTorch“炼丹”时都会被这个bug困扰。

94da560952c7918cccb63ed87bbd5ff5.png

一般情况下,你得找出当下占显存的没用的程序,然后kill掉。

如果不行,还需手动调整batch size到合适的大小……

有点麻烦。

现在,有人写了一个PyTorch wrapper,用一行代码就能“无痛”消除这个bug。

514acbcfa999580b705ff00fd6652c6f.gif

有多厉害?

相关项目在GitHub才发布没几天就收获了600+星。

1cff14cb7731ef5cc8bbd9332be9b492.png

一行代码解决内存溢出错误

软件包名叫koila,已经上传PyPI,先安装一下:

pip install koila

现在,假如你面对这样一个PyTorch项目:构建一个神经网络来对FashionMNIST数据集中的图像进行分类。

先定义input、label和model:

# A batch of MNIST image
input = torch.randn(8, 28, 28)# A batch of labels
label = torch.randn(0, 10, [8])class NeuralNetwork(Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = Flatten()self.linear_relu_stack = Sequential(Linear(28 * 28, 512),ReLU(),Linear(512, 512),ReLU(),Linear(512, 10),)def forward(self, x):x = self.flatten(x)logits = self.linear_relu_stack(x)return logits

然后定义loss函数、计算输出和losses。

loss_fn = CrossEntropyLoss()# Calculate losses
out = nn(t)
loss = loss_fn(out, label)# Backward pass
nn.zero_grad()
loss.backward()

好了,如何使用koila来防止内存溢出?

超级简单!

只需在第一行代码,也就是把输入用lazy张量wrap起来,并指定bacth维度——

koila就能自动帮你计算剩余的GPU内存并使用正确的batch size了。

在本例中,batch=0,则修改如下:

input = lazy(torch.randn(8, 28, 28), batch=0)

完事儿!就这样和PyTorch“炼丹”时的OOM报错说拜拜。

灵感来自TensorFlow的静态/懒惰评估

下面就来说说koila背后的工作原理。

“CUDA error: out of memory”这个报错通常发生在前向传递(forward pass)中,因为这时需要保存很多临时变量。

koila的灵感来自TensorFlow的静态/懒惰评估(static/lazy evaluation)。

它通过构建图,并仅在必要时运行访问所有相关信息,来确定模型真正需要多少资源。

而只需计算临时变量的shape就能计算各变量的内存使用情况;而知道了在前向传递中使用了多少内存,koila也就能自动选择最佳batch size了。

又是算shape又是算内存的,koila听起来就很慢?

c90010d09d39173138fe20767ab67209.png

NO。

即使是像GPT-3这种具有96层的巨大模型,其计算图中也只有几百个节点。

而Koila的算法是在线性时间内运行,任何现代计算机都能够立即处理这样的图计算;再加上大部分计算都是单个张量,所以,koila运行起来一点也不慢。

你又会问了,PyTorch Lightning的batch size搜索功能不是也可以解决这个问题吗?

是的,它也可以。

但作者表示,该功能已深度集成在自己那一套生态系统中,你必须得用它的DataLoader,从他们的模型中继承子类,才能训练自己的模型,太麻烦了。

koila灵活又轻量,只需一行代码就能解决问题,非常“大快人心”有没有。

不过目前,koila还不适用于分布式数据的并行训练方法(DDP),未来才会支持多GPU

cb6aa84eb2d87d51a1aa22fd87eeb867.png

以及现在只适用于常见的nn.Module类。

6d2e9ac160ee16848369669a2d693dbb.png

ps. koila作者是一位叫做RenChu Wang的小哥。

756f188844feba6136a11ec33e707e1b.png

项目地址:
https://github.com/rentruewang/koila

参考链接:
https://www.reddit.com/r/MachineLearning/comments/r4zaut/p_eliminate_pytorchs_cuda_error_out_of_memory/

本文系网易新闻•网易号特色内容激励计划签约账号【量子位】原创内容,未经账号授权,禁止随意转载。9c0562c0213977fe4559239217283c61.png

outside_default.png

点个在看 paper不断!


http://lihuaxi.xjx100.cn/news/262648.html

相关文章

分析6千万条GitHub帖子,发现你的工作状态与表情符号强相关

作者 | 凌霄出品 | AI科技大本营(ID:rgznai100)新冠疫情使得远程办公的人数大幅度增加,然而,当越来越多的人远程工作时,人们的情绪和心理健康状态也难以通过日常面对面的交流来观察,雇主们也就无法获得员工…

Flask框架中Cookie与Session用法详解

1、Cookie 1.1 设置cookie from flask import Flask, make_responseapp Flask(__name__)app.route(/cookie) def set_cookie():resp make_response(set cookie ok)resp.set_cookie(username, itcast)return resp1.2 设置cookie有效期 from flask import Flask, make_respo…

起飞,会了这4个 Intellij IDEA 调试魔法,阅读源码都简单了

前言上一篇文章 IntelliJ IDEA 高级调试之Stream Trace 算是 IntelliJ IDEA 高级调试技巧的开胃菜,很多小伙伴被这个小技巧征服。趁热打铁,今天给大家带来几个我日常工作以及阅读源码必备的 IntelliJ IDEA 高级调试技巧,分分钟要起飞的节奏断…

SpringCloud的服务网关zuul

演示如何使用api网关屏蔽各服务来源 一、概念和定义 1、zuul最终还是使用Ribbon的,顺便测试一下Hystrix断路保护2、zuul也是一个EurekaClient,访问服务注册中心,获取元数据,使用本地的Ribbon负载均衡,Hystrix断路保护&…

Java的JVM,GC是什么?

JVM是Java Virtual Machine(Java虚拟机)的缩写。GC是垃圾收集的意思(Gabage Collection) JVM是一种用于计算设备的规范,它是一个虚构出来的计算机,是通过在实际的计算机上仿真模拟各种计算机功能来实现的。 write once&#xff…

第五篇:协调和协定之选举算法

目录 选举 基础 领导者 要求 表现 基于环的选举 算法 要求 表现 故障或者失败 霸凌算法 故障检测 算法 要求 表现 选举 选择独特的流程来扮演领导者的角色,承担特殊的任务 例如 基于服务器的互斥算法需要选举一个服务器进程伯克利算法 基础 任何进…

Linux之bash编程基本语法

在Linux运维工作中,我们为了提高工作效率通常会用bash编写脚本来完成某工作。今天就来为大家介绍bash的一些常见的基本语法。在讲解bash语法之前首先介绍一下bash。bash环境主要是由解释器来完成的。【解释器】:解释命令:词法分析、语法分析、…

《数据竞赛白皮书》发布:竞赛核心价值及促进人才数字化转型

近年来,“数据竞赛”已经成为大数据与人工智能领域的热门话题。据不完全统计,2014年开始,全球赛事超1000场,仅中国的竞赛场次年均增长达108.8%,累计超120万人次参加,奖金累计达到2.8亿人民币。拥有这样的增…