In your code, I altered the printouts a bit to better visualize (at least in my opinion) what's going on:
```python
import torch.nn as nn


class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 5)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.fc(x))


a = A().to('cuda')
print("\nfrom `A().to('cuda')`")
print(f"{id(a.fc.weight)=}, {a.fc.weight.data_ptr()=}")
print(f"{id(a.fc.bias)=}, {a.fc.bias.data_ptr()=}")
# from `A().to('cuda')`
# id(a.fc.weight)=138293720716624, a.fc.weight.data_ptr()=138293368850432
# id(a.fc.bias)=138293720716720, a.fc.bias.data_ptr()=138293368850944

weight = {}
for key, value in a.state_dict().items():
    weight[key] = value

print("\nfrom `weight`")
for key, value in weight.items():
    print(f"{key}, {id(value)=}, {value.data_ptr()=}")
# from `weight`
# fc.weight, id(value)=138293720716816, value.data_ptr()=138293368850432
# fc.bias, id(value)=138293720716528, value.data_ptr()=138293368850944

a.to('cpu')
print("\nfrom `a.to('cpu')`")
print(f"{id(a.fc.weight)=}, {a.fc.weight.data_ptr()=}")
print(f"{id(a.fc.bias)=}, {a.fc.bias.data_ptr()=}")
# from `a.to('cpu')`
# id(a.fc.weight)=138293720716624, a.fc.weight.data_ptr()=101884008832832
# id(a.fc.bias)=138293720716720, a.fc.bias.data_ptr()=101884008983616
```
If you compare the first two blocks of printouts (the one from `A().to('cuda')` and the one from the `weight` dict), you get:
```python
# from `A().to('cuda')`
# id(a.fc.weight)=138293720716624, a.fc.weight.data_ptr()=138293368850432
# id(a.fc.bias)=138293720716720, a.fc.bias.data_ptr()=138293368850944

# from `weight`
# fc.weight, id(value)=138293720716816, value.data_ptr()=138293368850432
# fc.bias, id(value)=138293720716528, value.data_ptr()=138293368850944
```
The IDs are different, but the data pointers are the same. What this means: the `weight` dict contains shallow copies of the tensors in `a` ("copies" because they have a new ID, "shallow" because they point to the same memory, namely the one on the GPU). This is in line with `state_dict()`, which you are using to produce the `weight` dict, and which is documented to produce shallow copies.
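A quick way to see what "shallow" means in practice: an in-place change to a parameter is visible through the corresponding dict entry, because both point to the same memory. A small sketch, using a fresh model (here called `m`, so that `a` from above stays untouched):

```python
import torch

m = A()                      # fresh model, parameters on the CPU
snapshot = m.state_dict()    # shallow copies: new tensor objects, same underlying storage

with torch.no_grad():
    m.fc.bias.zero_()        # in-place change to the parameter's memory

print(snapshot['fc.bias'])   # all zeros as well, since the dict entry shares that memory
```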
If you compare the first and last block of printouts (the one from `A().to('cuda')` and the one from `a.to('cpu')`), you have the opposite situation:
```python
# from `A().to('cuda')`
# id(a.fc.weight)=138293720716624, a.fc.weight.data_ptr()=138293368850432
# id(a.fc.bias)=138293720716720, a.fc.bias.data_ptr()=138293368850944

# from `a.to('cpu')`
# id(a.fc.weight)=138293720716624, a.fc.weight.data_ptr()=101884008832832
# id(a.fc.bias)=138293720716720, a.fc.bias.data_ptr()=101884008983616
```
The IDs are the same, but the data pointers are different. What this means: the parameters of your model `a` (`a.fc.weight` and `a.fc.bias`) still refer to the same tensor objects (same IDs), but in the meantime, the tensors' underlying memory has been replaced (different data pointers, now pointing to CPU memory).
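You can make this in-place behavior explicit by keeping a reference to the parameter object across the move. A small sketch on a fresh model (here called `b`), assuming a CUDA device is available and PyTorch's default conversion behavior:

```python
b = A()                                # fresh model, parameters on the CPU
w_ref = b.fc.weight                    # keep a handle on the Parameter object
ptr_before = w_ref.data_ptr()

b.to('cuda')                           # nn.Module.to() moves parameters in place

print(w_ref is b.fc.weight)            # True: still the very same Python object
print(w_ref.data_ptr() == ptr_before)  # False: its underlying memory was replaced
print(w_ref.device)                    # cuda:0
```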
Your code ends with
print("a.state_dict() device:", [t.device for t in a.state_dict().values()]) # in CPU
print("weight device:", [t.device for t in weight.values()]) # still in GPU
So you are comparing
- the tensors of your model `a`'s parameters (or rather, a new shallow copy of them, since you call `a.state_dict()` once more), whose memory, by now, has been moved to the CPU, with
- the shallow copies from earlier on (the items in the `weight` dict, which result from your first call of `a.state_dict()`), whose memory still resides on the GPU.
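If what you actually want is a snapshot of the weights that is unaffected by any later `.to(...)` call on the model, copy the data rather than just the tensor objects. Two common options, as a sketch (`weight_snapshot` and `weight_cpu` are just names I made up here):

```python
import copy

# independent copy on the tensors' current device
weight_snapshot = copy.deepcopy(a.state_dict())

# independent copy forced onto the CPU, regardless of where the model lives
weight_cpu = {k: v.to('cpu', copy=True) for k, v in a.state_dict().items()}
```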