I have a very large tensor L (millions of elements), from which I gather a relatively small subtensor S (about a thousand elements).
I then apply my model to S, compute the loss, and backpropagate to S and to L, with the intent of updating only the selected elements of L. The problem is that PyTorch makes L's gradient a dense tensor with the same shape as L, so it basically doubles L's memory usage.
Is there an easy way to compute and apply the gradient to L without doubling memory usage?
Sample code to illustrate the problem:
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
net = nn.Sequential(
    nn.Linear(1, 64),
    nn.ReLU(),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Linear(64, 1))
L = Parameter(torch.zeros([1024*1024*256], dtype=torch.float32))
L.data.uniform_(-1, 1)
indices = torch.randint(high=256*1024*1024, size=[1024])
S = torch.unsqueeze(L[indices], dim=1)
out = net(S)
loss = out.sum()
loss.backward()
print(loss)
g = L.grad
print(g.shape) # this is huge!
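The same effect can be seen at a much smaller scale. This is only a sketch: the tensor size (1000) and the three indices are made up for illustration. It shows that the gradient has the full shape of L even though only the gathered entries are non-zero.

```python
import torch

# Small stand-in for the huge parameter (size chosen only for illustration).
L = torch.nn.Parameter(torch.zeros(1000))

# Gather three elements and backpropagate through a trivial "loss".
indices = torch.tensor([3, 17, 42])
S = L[indices]
S.sum().backward()

grad_shape = L.grad.shape                  # same shape as L
nonzero = int(L.grad.count_nonzero())      # only the gathered entries are non-zero

print(grad_shape)
print(nonzero)
```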
You don't actually need requires_grad on L, since the gradients will be computed and applied manually. Instead, set it on S. That stops backpropagation at S, so no gradient the size of L is ever allocated.
Then you can update the values of L using S.grad and your preferred optimization method. Something along these lines:
L = torch.zeros([1024*1024*256], dtype=torch.float32)
...
S = torch.unsqueeze(L[indices], dim=1)
S.requires_grad_()
out = net(S)
loss = torch.abs(out).sum()
loss.backward()
with torch.no_grad():
    # Apply the update only to the gathered elements of L.
    L[indices] -= learning_rate * torch.squeeze(S.grad)
    S.grad.zero_()
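For reference, here is a self-contained, scaled-down sketch of the same idea. The tensor size, the network widths, and the value of learning_rate are assumptions chosen so that it runs quickly; they are not taken from the original code. It also swaps the in-place indexed subtraction for index_add_, because randomly drawn indices can repeat, and with L[indices] -= ... only one of the repeated contributions is kept, whereas index_add_ accumulates them.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Small model; layer widths are illustrative only.
net = nn.Sequential(nn.Linear(1, 8), nn.ReLU(), nn.Linear(8, 1))

# Small stand-in for the huge tensor. It is a plain tensor, not a Parameter,
# so autograd never allocates a gradient for it.
L = torch.empty(10_000).uniform_(-1, 1)
L_before = L.clone()
learning_rate = 0.1  # hypothetical value

indices = torch.randint(high=L.numel(), size=(64,))

# Advanced indexing copies the selected values, so S is a fresh leaf tensor
# on which gradients can be requested.
S = L[indices].unsqueeze(1).requires_grad_()

loss = net(S).abs().sum()
loss.backward()

with torch.no_grad():
    # Accumulate the (negative, scaled) gradient into the selected slots of L.
    L.index_add_(0, indices, -learning_rate * S.grad.squeeze(1))

# Positions of L that actually changed.
changed = (L != L_before).nonzero().flatten()
print(S.grad.shape, changed.numel())
```

Only S.grad (one value per gathered element) is held in memory; L itself never gets a .grad attribute.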