
I just started learning machine learning with Siraj Raval's videos on YouTube and tried the challenge from the video "Intro - The Math of Intelligence", which is to perform linear regression with gradient descent on a dataset from kaggle.com. This is my code:

"""
An Example of a Linear Regression model.

Here I am taking an example from https://www.kaggle.com/alopez247/pokemon
to find a relation between the variables "Total" and "HP".

"""
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import sys
import os

data = pd.read_csv("./pokemon_alopez247.csv")
d = {"Total": data['Total'],
     "HP": data['HP']}
smallData = pd.DataFrame(d)
test = smallData.values
epsilon = 0.001


def compute_error_for_line(b, m, points):
    """Return the Error for Line given the points."""
    totalError = 0
    for i in range(0, len(points)):
        # Read from the points argument, not the global test array
        x = points[i, 0]
        y = points[i, 1]
        totalError += (y - (m * x + b)) ** 2
    return totalError / float(len(points))


def step_gradient(b_current, m_current, points, learningRate):
    """Return the new b and m points."""
    b_gradient = 0
    m_gradient = 0
    N = float(len(points))
    for i in range(0, len(points)):
        x = points[i, 0]
        y = points[i, 1]
        error = y - ((m_current * x) + b_current)
        b_gradient += -(2 / N) * error
        m_gradient += -(2 / N) * x * error
    new_b = b_current - (learningRate * b_gradient)
    new_m = m_current - (learningRate * m_gradient)
    return [new_b, new_m]


def main():
    """Run gradient descent and plot the fitted line at each step."""
    plt.figure(num=None, figsize=(20, 10), dpi=80,
               facecolor='w', edgecolor='k')
    plt.axis([0, 780, 0, 260])
    plt.ylabel("Total")
    plt.xlabel("HP")
    plt.scatter(test[:, [1]], test[:, [0]], c='r', s=1)

    m = 0.3
    b = -30
    x = np.arange(800)
    y = m * x + b
    for i in range(30):
        error = compute_error_for_line(b, m, test)
        print("error :", error)
        if error > epsilon:
            y = m * x + b
            plt.plot(x, y)
            b, m = step_gradient(b, m, test, 0.0001)
            print("b , m :", b, ",", m)
            plt.pause(0.01)

    plt.show()

    plt.pause(0.001)

if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        print('Interrupted')
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)

and the output is:

error : 193676.072288
b , m : -29.91451362 , 6.46934413315
/usr/local/lib/python3.5/dist-packages/matplotlib/backend_bases.py:2445: MatplotlibDeprecationWarning: Using default event loop until function specific to this GUI is implemented
  warnings.warn(str, mplDeprecation)
error : 16427.2683093
b , m : -29.9134163218 , 6.04491523016
error : 15588.2873385
b , m : -29.9065147511 , 6.07401898958
error : 15583.8939554
b , m : -29.9000125838 , 6.07192788394
error : 15583.4489928
b , m : -29.8934831191 , 6.07198242461
error : 15583.0227312
b , m : -29.8869557061 , 6.07188938575
error : 15582.5965792
b , m : -29.8804283262 , 6.07180649992
error : 15582.1704489
b , m : -29.8739011182 , 6.07172291798
error : 15581.74434
b , m : -29.8673740726 , 6.07163938615
error : 15581.3182523
b , m : -29.86084719 , 6.0715558531
error : 15580.8921858
b , m : -29.8543204704 , 6.07147232236
error : 15580.4661407
b , m : -29.8477939138 , 6.0713887937
error : 15580.0401168
b , m : -29.8412675201 , 6.07130526712
error : 15579.6141143
b , m : -29.8347412894 , 6.07122174263
error : 15579.1881329
b , m : -29.8282152217 , 6.07113822022
error : 15578.7621729
b , m : -29.821689317 , 6.0710546999
error : 15578.3362341
b , m : -29.8151635752 , 6.07097118166
error : 15577.9103166
b , m : -29.8086379963 , 6.07088766551
error : 15577.4844204
b , m : -29.8021125804 , 6.07080415145
error : 15577.0585455
b , m : -29.7955873275 , 6.07072063947
error : 15576.6326918
b , m : -29.7890622375 , 6.07063712957
error : 15576.2068594
b , m : -29.7825373104 , 6.07055362176
error : 15575.7810482
b , m : -29.7760125462 , 6.07047011604
error : 15575.3552583
b , m : -29.769487945 , 6.0703866124
error : 15574.9294897
b , m : -29.7629635067 , 6.07030311084
error : 15574.5037423
b , m : -29.7564392314 , 6.07021961138
error : 15574.0780162
b , m : -29.7499151189 , 6.07013611399
error : 15573.6523114
b , m : -29.7433911694 , 6.07005261869
error : 15573.2266278
b , m : -29.7368673827 , 6.06996912548
error : 15572.8009655
b , m : -29.730343759 , 6.06988563435
[Finished in 73.209s]

So the output suggests that everything is going according to plan: the error goes down at every step. But look at the plot: the red points are the original values, and the lines are getting farther away from them! I tried rewriting the compute_error_for_line and step_gradient functions, but nothing changed. Thanks for reading to the end.

So how can I obtain the parameters for the line that best fits my sample space?

Link to my csv file here (this file will expire in 22 hours).
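One way to find out what the best-fit parameters should be is to compare gradient descent against the closed-form least-squares line, for example `np.polyfit(x, y, 1)`. The sketch below is a vectorized version of the same mean-squared-error update that `step_gradient` performs, run on standardized x values. The standardization is an assumption added here, motivated by the output above: m settles near 6.07 within a few steps, while b moves by only about 0.0065 per step. The function names `gradient_descent` and `fit_standardized` and the synthetic data are hypothetical, not taken from the question.

```python
import numpy as np


def gradient_descent(x, y, b=0.0, m=0.0, learning_rate=0.1, iterations=1000):
    """Minimize the mean squared error of y ~ m*x + b with batch gradient descent."""
    n = float(len(x))
    for _ in range(iterations):
        error = y - (m * x + b)
        # Partial derivatives of the MSE with respect to b and m (same as step_gradient)
        b_gradient = -(2 / n) * error.sum()
        m_gradient = -(2 / n) * (x * error).sum()
        b -= learning_rate * b_gradient
        m -= learning_rate * m_gradient
    return b, m


def fit_standardized(x, y, **kwargs):
    """Run gradient descent on standardized x, then map (b, m) back to raw x units."""
    mu, sigma = x.mean(), x.std()
    z = (x - mu) / sigma
    b_z, m_z = gradient_descent(z, y, **kwargs)
    # y = m_z * (x - mu) / sigma + b_z, so slope = m_z / sigma, intercept = b_z - m_z * mu / sigma
    m = m_z / sigma
    b = b_z - m_z * mu / sigma
    return b, m


# Hypothetical synthetic data on roughly the same scale as "Total" (x) and "HP" (y)
rng = np.random.default_rng(0)
x = rng.uniform(180, 780, size=500)
y = 0.15 * x + 5 + rng.normal(0, 10, size=500)

b_gd, m_gd = fit_standardized(x, y)
m_ls, b_ls = np.polyfit(x, y, 1)  # closed-form least squares; coefficients highest degree first
print("gradient descent:", b_gd, m_gd)
print("np.polyfit:      ", b_ls, m_ls)
```

With raw x values in the hundreds, the loss is far more sensitive to m than to b, so the learning rate has to be tiny to keep m stable, and that same tiny rate makes b crawl. Standardizing x puts both parameters on a similar scale, so one learning rate suits both.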

  • Your gradient function does not compute error the same way compute_error_for_line does (gradient function does not square error, other function does square error). Is that on purpose? Commented Jul 11, 2017 at 22:24
  • @kbrose Actually it does. I am taking the partial derivatives in the step_gradient function (with respect to b and m, respectively). Commented Jul 11, 2017 at 22:27
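The exchange above can be made concrete with the derivatives. Differentiating the mean squared error that compute_error_for_line returns removes the square, which is why step_gradient accumulates the plain (unsquared) error. Here $\alpha$ stands for learningRate:

```latex
E(m,b) = \frac{1}{N}\sum_{i=1}^{N}\bigl(y_i - (m x_i + b)\bigr)^2

\frac{\partial E}{\partial b} = -\frac{2}{N}\sum_{i=1}^{N}\bigl(y_i - (m x_i + b)\bigr),
\qquad
\frac{\partial E}{\partial m} = -\frac{2}{N}\sum_{i=1}^{N} x_i\bigl(y_i - (m x_i + b)\bigr)

b \leftarrow b - \alpha\,\frac{\partial E}{\partial b},
\qquad
m \leftarrow m - \alpha\,\frac{\partial E}{\partial m}
```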

1 Answer

    plt.scatter(test[:, [1]], test[:, [0]], c='r', s=1)

Looks like you have the x and y values swapped. If you change [1] to [0] and vice versa, the plot looks pretty good.
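A minimal sketch of the fix. It assumes, as in the question's compute_error_for_line and step_gradient, that column 0 of the array holds the x values and column 1 holds the y values the line is fitted to. The point is that the scatter plot and the fitted line must read the same columns for x and y. The helper name `plot_fit` and the sample data are hypothetical. The `Agg` backend is used only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import numpy as np
from matplotlib import pyplot as plt


def plot_fit(points, b, m):
    """Scatter the data and draw y = m*x + b using the regression's column order."""
    x_data = points[:, 0]  # the column step_gradient treats as x
    y_data = points[:, 1]  # the column step_gradient treats as y
    fig, ax = plt.subplots()
    scatter = ax.scatter(x_data, y_data, c='r', s=1)
    x_line = np.linspace(x_data.min(), x_data.max(), 100)
    ax.plot(x_line, m * x_line + b)
    return fig, ax, scatter


# Hypothetical sample: column 0 is x, column 1 is y
points = np.array([[200.0, 40.0], [400.0, 70.0], [600.0, 100.0]])
fig, ax, scatter = plot_fit(points, b=10.0, m=0.15)
```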


2 Comments

Any intuition on how to find these types of bugs?
I looked at your gradient update step and it seemed fine, so I thought maybe there was an outlier. Then I looked at the code where you were plotting, to see whether you were omitting any of the data, and that's when I saw it. I'm not sure I have any general tips, though.
