
Linear Regression example with NumPy

import csv
import matplotlib.pyplot as plt
import numpy as np

# Helper to check whether a string can be parsed as an integer
def representsInt(s):
    try:
        int(s)
        return True
    except ValueError:
        return False

# Load training data
with open('train.csv') as f:
	dataset = [{k: int(v) if representsInt(v) else v for k, v in row.items()}
		for row in csv.DictReader(f, skipinitialspace=True)]

# Simple data scaling without normalization
for row in dataset:
    row['SalePrice'] = row['SalePrice'] / 1000

feature = 'TotalBsmtSF' # Input we provide (basement area in sqft)
target = 'SalePrice'    # Value we want to predict

# Cost function: mean squared error (halved), J = 1/(2n) * sum((h(x) - y)^2)
def cost():
    errorTotal = 0
    for item in dataset:
        errorTotal += (h(item[feature]) - item[target]) ** 2
    return errorTotal / (2 * len(dataset))

# Parameters of the line, both start at zero
m = 0
b = 0

# Hypothesis function: the line itself, y = m*x + b
def h(x):
    return m * x + b

learningRate = 0.00000001 # The feature is not normalized, so anything larger overshoots

# One step of batch gradient descent: move b and m against the gradient of the cost
def descend():
    global m, b

    # Partial derivative of the cost with respect to b
    errorTotal = 0
    for item in dataset:
        errorTotal += h(item[feature]) - item[target]
    newB = b - learningRate * errorTotal / len(dataset)

    # Partial derivative of the cost with respect to m
    errorTotal = 0
    for item in dataset:
        errorTotal += (h(item[feature]) - item[target]) * item[feature]
    newM = m - learningRate * errorTotal / len(dataset)

    # Assign both at the end so newM was computed with the old b
    b = newB
    m = newM

# Track the cost after every iteration, starting from the untrained model
errors = [
    {'error': cost(), 'iteration': 0}
]

# 300 steps of gradient descent reach a reasonable convergence here
for i in range(300):
    descend()
    errors.append({'error': cost(), 'iteration': i + 1})

# Figure 1: the fitted line over the training data
plt.figure(1)
plt.title('Houses')
xs = np.arange(0, 6000, 0.1)
plt.plot(xs, h(xs), 'red')
plt.scatter([item[feature] for item in dataset], [item[target] for item in dataset])
plt.xlabel('Area (sqft)')
plt.ylabel('Price ($k)')

# Figure 2: the cost falling as gradient descent iterates
plt.figure(2)
plt.title('Error rate')
plt.plot([error['iteration'] for error in errors], [error['error'] for error in errors], 'red')
plt.xlabel('Iteration')
plt.ylabel('Error')

plt.show()

Screenshot: the fitted line over the house data (figure 1) and the error per iteration (figure 2).

This is a Linear Regression example written in Python with the help of NumPy, plus Matplotlib for plotting the data. It uses Kaggle's training data as an example. The script should be directly runnable provided train.csv exists in the same directory. It also requires Matplotlib and NumPy installed.

I implemented this after taking Andrew Ng's course on Coursera, and since I am only getting started, it might not be the best approach. Also note that some parts of the code are intentionally inefficient; they keep the code simple and easy to read, since it exists purely for learning purposes.

Here's how it works step by step:

  • First it imports the required modules and defines the representsInt helper.
  • Then it loads the CSV file located at ./train.csv into dataset as a list of dicts, converting integer-like values to int.
  • Then it scales the SalePrice column. It is in dollars, so we divide by 1000 to express it in thousands of dollars ($k).
  • Two variables are defined:
    • feature, which names the input we provide (the basement area)
    • target, which names the value we want to predict (the sale price)
  • Then, a cost function is defined. The cost function is the Mean Squared Error (halved, which simplifies the derivatives), a common choice for Linear Regression: J(m, b) = \frac{1}{2n} \sum_{i=1}^{n} (h(x_i) - y_i)^2
  • Then m and b are defined for our hypothesis function. The way we approach this problem is by finding the line that fits the data best (see screenshot). Given that the simplest form of a line is y = mx + b, all we need to tune and find good values for are m and b.
  • Then, we define our hypothesis function h, which just represents the line.
  • Then we implement Gradient Descent, an algorithm that minimizes a function by repeatedly stepping opposite its derivatives. It uses a small learning rate because our data is not well normalized; anything larger overshoots. Each step updates both parameters at once: b := b - \frac{\alpha}{n} \sum_i (h(x_i) - y_i) and m := m - \frac{\alpha}{n} \sum_i (h(x_i) - y_i) \, x_i (the derivation is sketched after this list).
  • Then errors is just a list that tracks the cost of our hypothesis function as we apply Gradient Descent.
  • We do 300 steps of Gradient Descent so that it reaches a reasonable convergence.
  • Finally, we plot the data: the fitted line over the houses (figure 1) and the error per iteration (figure 2). For comparison, a vectorized NumPy version of the training loop is also sketched below.
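For reference, here is how the update rules follow from the cost. This is the standard derivation, written out here rather than taken from the original post:

J(m, b) = \frac{1}{2n} \sum_{i=1}^{n} (m x_i + b - y_i)^2

\frac{\partial J}{\partial b} = \frac{1}{n} \sum_{i=1}^{n} (m x_i + b - y_i) \qquad \frac{\partial J}{\partial m} = \frac{1}{n} \sum_{i=1}^{n} (m x_i + b - y_i) \, x_i

b \leftarrow b - \alpha \frac{\partial J}{\partial b} \qquad m \leftarrow m - \alpha \frac{\partial J}{\partial m}

with \alpha the learning rate. These two partial derivatives are exactly the two loops inside descend.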
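Since the example is billed as using NumPy, here is a minimal vectorized sketch of the same training loop. It assumes dataset, feature and target are already defined as above; the array names, alpha, and the use of mean() are my choices, not part of the original script, and it is meant as a standalone sketch rather than something to paste into the file:

import numpy as np

x = np.array([item[feature] for item in dataset], dtype=float)
y = np.array([item[target] for item in dataset], dtype=float)

alpha = 0.00000001  # same learning rate as the script above
m, b = 0.0, 0.0

for _ in range(300):
    residuals = m * x + b - y                      # h(x) - y for every sample at once
    b = b - alpha * residuals.mean()               # gradient step for b
    m = m - alpha * (residuals * x).mean()         # gradient step for m

As a sanity check, NumPy can also solve this least-squares problem in closed form; with enough iterations and a suitable learning rate, gradient descent should approach these values:

m_best, b_best = np.polyfit(x, y, 1)  # degree-1 least-squares fit: slope, intercept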