Apply gradient to a tensor in a sparse way in PyTorch

March 24, 2022, at 02:00 AM

I have a very large tensor L (millions of elements), from which I gather a relatively small subtensor S (about a thousand elements).

I then apply my model to S, compute the loss, and backpropagate through S to L, with the intent of updating only the selected elements of L. The problem is that PyTorch makes L's gradient a dense tensor with the same shape as L, so it basically doubles L's memory usage.

Is there an easy way to compute and apply the gradient to L without doubling memory usage?
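For scale, here is a quick calculation using the sizes from the sample code (256 × 1024 × 1024 float32 values). It assumes only that a float32 takes 4 bytes and that a dense gradient has the same shape and dtype as L:

```python
# Size of L in the sample code: 256 * 1024 * 1024 float32 values.
n_elements = 1024 * 1024 * 256
bytes_per_float32 = 4
param_gib = n_elements * bytes_per_float32 / 2**30

# A dense .grad has the same shape and dtype as L, so it costs the same again.
grad_gib = param_gib
print(param_gib, grad_gib, param_gib + grad_gib)  # 1.0 1.0 2.0
```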

Sample code to illustrate the problem:

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
net = nn.Sequential(
  nn.Linear(1, 64),
  nn.ReLU(),
  nn.Linear(64,64),
  nn.ReLU(),
  nn.Linear(64, 1))
L = Parameter(torch.zeros([1024*1024*256], dtype=torch.float32))
L.data.uniform_(-1, 1)
indices = torch.randint(high=256*1024*1024, size=[1024])
S = torch.unsqueeze(L[indices], dim=1)
out = net(S)
loss = out.sum()
loss.backward()
print(loss)
g = L.grad
print(g.shape)  # this is huge!
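A scaled-down version of the same setup shows where the memory goes. The 1000-element L and the three indices are hypothetical values chosen only for illustration. Autograd scatters the gradient of the indexing step into a dense tensor shaped like L, which is zero everywhere except at the gathered positions:

```python
import torch

# Scaled-down version of the question's setup (sizes are illustrative).
L = torch.nn.Parameter(torch.rand(1000))
indices = torch.tensor([3, 17, 512])
S = L[indices].unsqueeze(1)        # gather: shape [3, 1]
loss = (S * 2).sum()
loss.backward()

# The gradient of the indexing op lands in a dense tensor shaped like L.
print(L.grad.shape)                # torch.Size([1000])
print(L.grad.layout)               # torch.strided (dense, not sparse)
print(int((L.grad != 0).sum()))    # 3: only the gathered positions are non-zero
```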
Answer 1

You don't actually need requires_grad on L, because you will compute and apply the update to L manually. Set it on S instead. S is then a leaf of the autograd graph, so backpropagation stops at S and no gradient the size of L is ever allocated.
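A small sketch of this idea follows. The sizes and the toy network are illustrative and not the question's exact model. It checks that S becomes a leaf with its own small gradient while L never gets one:

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(1, 8), nn.ReLU(), nn.Linear(8, 1))
L = torch.rand(1000)                # plain tensor: requires_grad is False
indices = torch.tensor([3, 17, 512])
S = L[indices].unsqueeze(1)         # indexing an untracked tensor gives a new leaf
S.requires_grad_()                  # gradients are now tracked from S onward
net(S).sum().backward()

print(S.is_leaf, S.grad.shape)      # True torch.Size([3, 1])
print(L.grad)                       # None: nothing the size of L was allocated
```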

Then you can update the values of L using S.grad and your preferred optimization rule, along these lines:

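import torch

learning_rate = 0.01  # example value; the answer leaves the step size unspecified
L = torch.zeros([1024*1024*256], dtype=torch.float32)  # plain tensor, no requires_grad
...
S = torch.unsqueeze(L[indices], dim=1)  # gathered copy, shape [1024, 1]
S.requires_grad_()  # S is a leaf, so backpropagation stops here
out = net(S)
loss = torch.abs(out).sum()
loss.backward()
with torch.no_grad():
  # index_add_ writes the update back into L and sums the contributions of
  # repeated indices, which assignment through L[indices] -= ... does not do
  L.index_add_(0, indices, -learning_rate * torch.squeeze(S.grad, dim=1))
  S.grad.zero_()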
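A different option, not part of the answer above, keeps L as a Parameter and lets autograd produce a sparse gradient instead. torch.nn.functional.embedding with sparse=True performs the same gather as L[indices], but its gradient is a sparse COO tensor that stores only the looked-up rows. This sketch assumes L can be stored with shape [N, 1] (one row per element) and uses a scaled-down N. Only some optimizers accept sparse gradients: the PyTorch docs list SGD, SparseAdam and, on CPU, Adagrad. Plain Adam would therefore not work here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(1, 8), nn.ReLU(), nn.Linear(8, 1))
N = 1000                                   # scaled down from 256 * 1024 * 1024
# Store L as an [N, 1] table so each lookup returns one row.
L = nn.Parameter(torch.rand(N, 1) * 2 - 1)
indices = torch.tensor([3, 17, 512])

S = F.embedding(indices, L, sparse=True)   # shape [3, 1], like S in the question
loss = torch.abs(net(S)).sum()
loss.backward()
print(L.grad.is_sparse)                    # True: only looked-up rows are stored

# SGD accepts sparse gradients and only touches the looked-up rows of L.
opt = torch.optim.SGD([L], lr=0.1)
L_before = L.detach().clone()
opt.step()
```

The trade-off is the restricted choice of optimizer. With the manual update from the answer, you can write any update rule by hand.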