Долго думал, о чём бы сегодня рассказать. В общем, надо рассказать, как замедлять нейронные сети :)

Думаю, что у каждого практикующего DL инженера возникала out-of-memory (OOM) ситуация. Иногда надо просто поставить батч поменьше или исправить неправильную операцию. Однако не во всех случаях это возможно.

Давайте представим, что мы файнтюним самую большую модель из семейства YOLOv5 (можете заменить 5 на любой другой релиз). Специфика нашей задачи такова, что в процессе обучения необходимо передавать изображения размером 1024*1024 или даже 2048*2048.

Возникает проблема, что размер батча в этом случае стремится к 1, а видеокарты с большим объемом видеопамяти у нас под рукой нет. BatchNorm слои с небольшими батчами работают плохо. Как быть?

Один из методов снижения потребляемой видеопамяти, о котором я сегодня хотел бы рассказать — это gradient checkpointing. Для того, чтобы понять, как работает этот метод, надо сначала разобраться, на что тратится видеопамять в процессе обучения.

Первое, что съедает видеопамять — это параметры модели. Они хранятся на GPU. Также, если мы используем Adam для обучения, он хранит по две бегущие статистики на каждый вес нашей сети. Также, после вызова loss.backward(), у параметров сети появляется атрибут .grad, который тоже хранится в видеопамяти.

Но для сверточных нейронных сетей до 500М параметров больше всего видеопамяти тратится в forward проходе модели. Для быстрой реализации backward прохода, в процессе forward прохода PyTorch запоминает множество промежуточных значений, вычисляемых внутри forward. Например, обычная 2D свёртка сохраняет входной (N, C, H, W) тензор для того, чтобы вычислить значение градиента весов этой свёртки. Так как эти промежуточные значения зависят от входного разрешения, все эти промежуточные тензоры, которые хранятся до вызова loss.backward(), съедают много видеопамяти.

Переходим к замедлению обучения. Чтобы не хранить входы свёрток, можно в процессе backward прохода делать по одному forward проходу, не сохраняя никаких промежуточных тензоров. Когда мы добираемся до очередной свёртки в вычислительном графе, которой нужен градиент, мы используем полученное в текущем forward значение входа этой свёртки для вычисления градиента весов этой свёртки. Затем считаем градиент лосса по входу этой свёртки и переходим к предыдущей операции, для которой снова нужно делать forward проход, чтобы посчитать входные значения для этой операции.

В общем, все бы ничего, только это все имеет квадратичную сложность в зависимости от числа слоёв в сети. Если сеть неглубокая, но все слои там огромные — такой вариант имеет право существовать, но это не про Deep Learning.

Получается, нужен некоторый баланс между тем, сколько будет сохраняться промежуточных тензоров (которые будут использоваться в backward), и сколько из них будет пересчитываться заново в самом backward проходе. Механизм gradient checkpointing позволяет это настроить.

Для этого часть сети заключается в отдельную checkpoint операцию, которая в процессе forward прохода сохранит входные значения только для первой операции данной секции; при этом входные значения для второй и последующих операций сохраняться не будут. В процессе backward прогона, когда алгоритм autograd доберется до этой checkpoint операции, он сделает для неё полный forward прогон, теперь уже сохраняя все промежуточные значения, и потом сделает backward проход по всем операциям внутри данной части сети.

При корректном использовании, данный механизм позволяет сократить потребление памяти в 2-4 раза, при этом замедляя обучение не более чем в 2 раза. Ну а если возможно сократить потребление памяти в 4 раза, то можно и увеличить размер батча в 4 раза. И тогда уже BatchNorm будет работать нормально.