
It's a simple regression problem, but no matter what I try, I can't get the answer I want. I'm expecting the prediction to be 32 (4 * 8), but the code returns 25. Why is that?

This is my full source code:

import torch
import torch.nn as nn
import torch.optim as op

# Training data: each target is the product of the two inputs (y = x1 * x2).
X = torch.FloatTensor([[1., 2.], [2., 4.], [3., 6.]])
Y = torch.FloatTensor([[2.], [8.], [18.]])

# A single linear layer mapping the two inputs to one output.
class TEST(nn.Module):
    def __init__(self):
        super(TEST, self).__init__()
        self.l1 = nn.Linear(2, 1)

    def forward(self, input):
        x = self.l1(input)
        return x

epochs = 2000
lr = 0.001

model = TEST()
loss_func = nn.MSELoss()
optimizer = op.SGD(model.parameters(), lr=lr)

# Full-batch gradient descent on the MSE loss.
for epoch in range(epochs):
    optimizer.zero_grad()
    output = model(X)
    loss = loss_func(output, Y)
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print('loss[{}] : {}'.format(epoch, loss))

# Predict on an unseen point; I expect 4 * 8 = 32.
XX = torch.FloatTensor([[4., 8.]])

print(model(XX))

This is the output of the code:

loss[1920] : 0.8891088366508484
loss[1930] : 0.8890921473503113
loss[1940] : 0.8890781402587891
loss[1950] : 0.8890655636787415
loss[1960] : 0.8890505433082581
loss[1970] : 0.8890388011932373
loss[1980] : 0.889029324054718
loss[1990] : 0.8890181183815002
tensor([[25.3124]], grad_fn=<AddmmBackward>)

1 Answer


You are trying to approximate y = x1*x2, but you are using a single linear layer, i.e. a purely linear model. What happens is that you learn a weight a, a weight b, and a bias c such that y = a*x1 + b*x2 + c. This model cannot represent the mapping (x1, x2) -> x1*x2, because the product of the inputs is not a linear function of them. Worse, in your training data x2 = 2*x1 in every row, so the model can only fit a straight line in x1, and the best line through (1, 2), (2, 8), (3, 18) predicts about 25.33 at x1 = 4, which is exactly where your network ends up.
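
You can verify that 25.33 is the least-squares optimum, not a training failure, with a quick closed-form check (a sketch using torch.linalg.lstsq; variable names are illustrative):

import torch

# Because x2 = 2*x1 in every training row, the linear layer can only realize
# y = w*x1 + c on this data. Fit that line in closed form (columns: x1, bias).
A = torch.tensor([[1., 1.],
                  [2., 1.],
                  [3., 1.]])
Y = torch.tensor([[2.], [8.], [18.]])

w_c = torch.linalg.lstsq(A, Y).solution
print(w_c.flatten())                    # ~ [8.0000, -6.6667]

# Best linear prediction at x1 = 4 (i.e. the test point [4, 8]):
print(torch.tensor([[4., 1.]]) @ w_c)   # ~ 25.33, matching the trained net

So the network has converged; the model class simply cannot do better.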


1 Comment

This is the correct answer: he's trying to approximate a non-linear function with a linear model.
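
One possible fix, not from the thread itself but a minimal sketch: feed the product x1*x2 to the layer as a third input feature, so the target becomes exactly linear in the features. The helper name with_product is made up for illustration, and Adam is swapped in for SGD only because it copes better with the ill-conditioned features here:

import torch
import torch.nn as nn
import torch.optim as op

X = torch.FloatTensor([[1., 2.], [2., 4.], [3., 6.]])
Y = torch.FloatTensor([[2.], [8.], [18.]])

# Append the product x1*x2 as a third feature; y = x1*x2 is then an exact
# linear function of the features, so a linear layer can represent it.
def with_product(x):
    return torch.cat([x, x[:, :1] * x[:, 1:]], dim=1)

model = nn.Linear(3, 1)
loss_func = nn.MSELoss()
optimizer = op.Adam(model.parameters(), lr=0.01)

for epoch in range(5000):
    optimizer.zero_grad()
    loss = loss_func(model(with_product(X)), Y)
    loss.backward()
    optimizer.step()

print(model(with_product(torch.FloatTensor([[4., 8.]]))))  # should approach 32

An alternative is to insert a hidden layer with a nonlinearity, but a small MLP trained on three points will not necessarily extrapolate to 32 at (4, 8); the product feature makes the relationship exactly linear, which is why it recovers the expected answer.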
