# Apply gradient to a tensor in a sparse way in PyTorch

March 24, 2022, at 02:00 AM

I have a very large tensor `L` (millions of elements), from which I gather a relatively small subtensor `S` (maybe a thousand elements).

I then apply my model to `S`, compute the loss, and backpropagate to `S` and to `L`, intending to update only the selected elements of `L`. The problem is that PyTorch makes `L`'s gradient a dense tensor the same size as `L`, so it basically doubles `L`'s memory usage.

Is there an easy way to compute and apply gradient to `L` without doubling memory usage?

Sample code to illustrate the problem:

```python
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

net = nn.Sequential(
    nn.Linear(1, 64),
    nn.ReLU(),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Linear(64, 1))

# 256M float32 elements = 1 GiB
L = Parameter(torch.zeros([1024*1024*256], dtype=torch.float32))
L.data.uniform_(-1, 1)
indices = torch.randint(high=256*1024*1024, size=[1024])
S = torch.unsqueeze(L[indices], dim=1)  # gathered subtensor, shape (1024, 1)
out = net(S)
loss = out.sum()
loss.backward()
print(loss)
g = L.grad
print(g.shape)  # this is huge: same size as L
```
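To see the problem without allocating 1 GiB, here is a scaled-down reproduction. The size `N`, the smaller network and the seed are arbitrary choices for illustration, not taken from the question. It checks that `L.grad` is an ordinary dense (strided) tensor with as many elements as `L`, even though only the gathered positions can be non-zero.

```python
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

torch.manual_seed(0)

N = 10_000  # small stand-in for the huge tensor (arbitrary size)
net = nn.Sequential(nn.Linear(1, 8), nn.ReLU(), nn.Linear(8, 1))
L = Parameter(torch.empty(N).uniform_(-1, 1))
indices = torch.randint(high=N, size=[16])

S = torch.unsqueeze(L[indices], dim=1)  # gather: S has shape (16, 1)
loss = net(S).sum()
loss.backward()

# L.grad is a full dense tensor the size of L; at most 16 entries are non-zero.
print(L.grad.shape, L.grad.layout)
print(int((L.grad != 0).sum()), "non-zero gradient entries out of", N)
```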

You don't actually need `requires_grad` on `L`, since the gradient will be computed and applied manually. Set it on `S` instead. Backpropagation then stops at `S`, and no gradient the size of `L` is ever allocated.
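A small sketch of why this works. The tensor sizes and the toy loss `(S ** 2).sum()` are illustrative assumptions. Indexing a tensor that does not require grad produces a copy with no autograd history, so calling `requires_grad_()` on it makes `S` a leaf. After `backward()`, only `S.grad` is filled and `L.grad` stays `None`.

```python
import torch

L = torch.randn(1000)              # plain tensor: requires_grad is False
indices = torch.tensor([3, 7, 42])

# The gathered copy has no history back to L; requires_grad_() makes it a leaf.
S = L[indices].unsqueeze(1).requires_grad_()
print(S.is_leaf)                   # True

loss = (S ** 2).sum()
loss.backward()

print(S.grad.shape)                # torch.Size([3, 1]): gradient only for S
print(L.grad)                      # None: nothing flows back into L
```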

Then update the selected values of `L` using `S.grad` and your preferred optimization method. Something along these lines:

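```python
L = torch.zeros([1024*1024*256], dtype=torch.float32)  # no requires_grad on L
...
S = torch.unsqueeze(L[indices], dim=1).requires_grad_()  # S becomes a leaf
out = net(S)
loss = torch.abs(out).sum()
loss.backward()
# S.grad now holds dLoss/dS; L.grad stays None
```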
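The update step itself is left open above. One possible sketch assumes plain SGD with a hypothetical learning rate `lr`, a small stand-in size `N`, and a toy network. It scatters `-lr * S.grad` back into `L` with `Tensor.index_add_`. That method sums the contributions of repeated entries in `indices`, whereas `L[indices] -= ...` would not accumulate duplicates. The model's own parameters are trained with an ordinary `torch.optim.SGD` optimizer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

N = 10_000      # small stand-in for the huge tensor L
lr = 0.1        # hypothetical learning rate (plain SGD)

net = nn.Sequential(nn.Linear(1, 8), nn.ReLU(), nn.Linear(8, 1))
opt = torch.optim.SGD(net.parameters(), lr=lr)  # updates the model only

L = torch.empty(N).uniform_(-1, 1)  # plain tensor: requires_grad is False
L_before = L.clone()                # copy kept only to inspect the result
touched = set()                     # every index gathered during training

for _ in range(3):
    indices = torch.randint(high=N, size=[16])
    touched.update(indices.tolist())

    # Advanced indexing copies; requires_grad_() makes S a leaf tensor.
    S = torch.unsqueeze(L[indices], dim=1).requires_grad_()
    loss = torch.abs(net(S)).sum()

    opt.zero_grad()
    loss.backward()   # fills S.grad (shape (16, 1)) and the net's grads
    opt.step()

    with torch.no_grad():
        # Manual SGD step on the gathered entries of L. index_add_ sums the
        # contributions of repeated indices instead of overwriting them.
        L.index_add_(0, indices, -lr * S.grad.squeeze(1))

changed = (L != L_before).nonzero().squeeze(1)
print(changed.numel(), "of", N, "entries of L were updated")
```

Memory stays proportional to the number of gathered elements: the only gradient buffer for `L`'s values is `S.grad`. A stateful optimizer such as Adam would need its own per-element state for `L`, which this sketch does not cover.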