One vs All logit to recognize handwritten digits

12 Mar

This blog has been neglected for a long time. I graduated from my Master’s program, landed a job and worked on some really fun assignments at work; I wish I could talk about them more here, but I shouldn’t.

I have always been enthusiastic about data and about solving problems with simple data science algorithms, so I recently signed up for Andrew Ng’s machine learning course on Coursera. However, the course uses Octave, and as a die-hard R fan who doesn’t really care about certificates anyway (I learn more when I’m relaxed and doing it just for pleasure), I decided to have a go at all the assignments in R.

I’m not presenting anything path-breaking here, just some of my thoughts as I work through these assignments as an amateur enthusiast.

First off, Andrew Ng is a wonderful teacher who is extremely knowledgeable and yet has a simple, clear way of presenting complex concepts. I recommend his course to anybody who wants to have fun with data.

In the third assignment, we were given the extremely interesting task of recognizing handwritten digits from 0 to 9 using a machine learning algorithm, specifically one-vs-all logistic regression. The input consisted of the pixel intensities of images of the digits, with each digit image stored as one row of the input matrix.

Displaying a random selection of rows of the input dataset as images gave me this:


Logistic regression is a common statistical modelling technique (or machine learning algorithm) used as a binary classifier. It is a generalized linear model with the logit as its link function; equivalently, the predicted probability is the sigmoid of a linear combination of the inputs. The task here was to extend the binary classifier into a multiclass classifier that could learn to recognize each of the 10 different digits. The idea behind the extension is pretty simple: pick one class (here, one digit) as the ‘one’ and group all the other classes (digits) together as the ‘all’. The logit model then gives you the probability that the image is of the selected digit rather than any of the other digits. Building such a model for each digit gives you 10 different models, each giving the probability that the image is of that model’s digit. You then pick the largest of these probabilities and assign the image to that class: if the maximum probability comes from the model for the digit 6, the image is most likely a 6.
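To make the recipe concrete, here is a minimal sketch in Python (not the R code I used) of one-vs-all logistic regression on a made-up 2-D toy dataset with three classes standing in for the ten digits. Everything here is an assumption for illustration: the helper names (`train_binary`, `train_one_vs_all`, `predict_proba`, `predict`), the toy points, and the use of plain batch gradient descent with a fixed learning rate of 0.5 rather than a packaged optimizer.

```python
import math


def sigmoid(z):
    # Numerically safe logistic function 1 / (1 + e^-z)
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def train_binary(X, y, lr=0.5, iters=2000):
    """Fit one binary logistic regression by batch gradient descent.

    X: list of feature rows (the intercept is added here); y: list of 0/1 labels.
    """
    n_params = len(X[0]) + 1
    w = [0.0] * n_params
    m = len(X)
    for _ in range(iters):
        grad = [0.0] * n_params
        for row, target in zip(X, y):
            x = [1.0] + list(row)
            err = sigmoid(sum(wi * xi for wi, xi in zip(w, x))) - target
            for j in range(n_params):
                grad[j] += err * x[j] / m
        w = [wi - lr * gj for wi, gj in zip(w, grad)]
    return w


def train_one_vs_all(X, labels, classes):
    # One binary model per class: that class is the "one" (1), every other class is the "all" (0)
    return {c: train_binary(X, [1 if lab == c else 0 for lab in labels]) for c in classes}


def predict_proba(models, row):
    # Probability from each class's own model; these are NOT forced to sum to 1
    x = [1.0] + list(row)
    return {c: sigmoid(sum(wi * xi for wi, xi in zip(w, x))) for c, w in models.items()}


def predict(models, row):
    # Assign the class whose one-vs-all model gives the highest probability
    probs = predict_proba(models, row)
    return max(probs, key=probs.get)


# Hypothetical toy data: three clusters in 2-D standing in for digit images
X = [(0.0, 0.0), (0.5, 0.3), (0.2, 0.6),
     (4.0, 0.0), (4.5, 0.4), (3.8, 0.2),
     (0.0, 4.0), (0.3, 4.5), (0.5, 3.7)]
labels = [0, 0, 0, 1, 1, 1, 2, 2, 2]

models = train_one_vs_all(X, labels, classes=[0, 1, 2])
print(predict(models, (4.1, 0.2)))
print(predict_proba(models, (4.1, 0.2)))
```

With ten digit classes the same loop would simply build ten models instead of three. Because each model is fitted on its own, the per-class probabilities from `predict_proba` are not constrained to add up to one; only their ranking is used for the final decision.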

At first glance the above method seems simple and elegant. I wondered why, in our econometrics classes, we moved straight to multinomial logistic regression as soon as we faced problems with more than two classes, and why we never paused to consider several separate binary logistic regressions. I also felt very stupid for not having thought of this myself. However, as I thought about it more, I realized that there is a key difference between modelling for machine learning and modelling to study a social science like economics, even though the models are very similar. In machine learning, as far as I have seen, the focus is on prediction, and the parameters of the model are mere tools for getting to an accurate prediction. In a social science (and in business) there is equal if not more emphasis on the parameters themselves, as decision tools that tell us how important each of the independent variables in the model is.

In a multinomial logistic regression, the probabilities of belonging to the various classes are estimated jointly rather than in separate, stratified models. (This also means you can get by with a single model with n-1 sets of parameters for n classes, instead of n models for n classes as in the one-vs-all example above.) As a result we get a clean set of probabilities for the mutually exclusive classes that add up to one. This makes it easy to calculate important marginal effects in one go, coherently for the entire system: for example, how the odds of each group relative to the base group change when one of the independent variables changes. I’m not sure how I would interpret marginal effects coming out of a one-vs-all model in a sensible way, since its per-class probabilities are estimated separately and need not add up to one. Maybe someone can clear that up for me.
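For contrast, here is a minimal Python sketch of how a multinomial logit turns n-1 coefficient vectors into n probabilities that add up to one, with class 0 as the base group. The function name and the coefficients are hypothetical numbers for a single independent variable, not estimates from any data. The point is the structure: the log-odds of class k versus the base is linear in x, so a one-unit increase in x multiplies those odds by exp(slope).

```python
import math


def multinomial_probs(betas, x):
    """Class probabilities from a multinomial logit with class 0 as the base.

    betas: n-1 coefficient vectors (intercept first), one per non-base class.
    x: the independent variables, without the intercept term.
    """
    xs = [1.0] + list(x)
    # The base class score is pinned at 0, which is why n classes need only n-1 vectors
    scores = [0.0] + [sum(b * v for b, v in zip(beta, xs)) for beta in betas]
    top = max(scores)  # subtract the largest score before exponentiating, for stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical coefficients for one independent variable x: [intercept, slope]
betas = [[0.5, 1.0],    # class 1 vs base class 0
         [-0.2, 2.0]]   # class 2 vs base class 0

p_at_1 = multinomial_probs(betas, [1.0])
p_at_2 = multinomial_probs(betas, [2.0])
print(p_at_1, sum(p_at_1))  # one probability per class; they sum to 1 up to rounding

# Odds of class 2 relative to the base, before and after a one-unit increase in x
odds_ratio = (p_at_2[2] / p_at_2[0]) / (p_at_1[2] / p_at_1[0])
print(odds_ratio, math.exp(2.0))  # both are exp(slope of class 2), up to rounding
```

The same odds-ratio reading applies to every non-base class from the one jointly estimated model, which is exactly what the separately fitted one-vs-all models do not give you.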

Another realization that hit me was that plain gradient descent is not a great way to learn the parameters of a logistic regression, because choosing the learning rate is a huge headache. In many scenarios it is nearly impossible to choose a ‘goldilocks’ learning rate. In a moment of misplaced idealism I had decided not to use the packaged optimization functions Prof Ng asked us to use, and tried my own hack of a gradient descent function instead. Either I would end up with a learning rate that was too large, leading to an algorithm that wouldn’t converge and, as I think was mentioned in class, would dance around a local minimum without ever actually hitting it, looking something like this (lol):

[Figure: gradient descent dancing around a local min]

Or I would have a very small learning rate, which took millions of iterations to converge. In the end I used a packaged optimization function, though not the MATLAB function that was coded and included in the exercise resources specifically for this exercise (I was too lazy to translate it into R). I got an in-sample accuracy of 91%, which is lower than the accuracy the assignment handout said to expect; I attribute this to not using the specified optimization function.
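Here is a minimal Python sketch of the learning-rate problem, using the toy cost f(w) = w^2 (gradient 2w) as a stand-in for the logistic regression cost; the function name and the learning rates tried are arbitrary choices for illustration. Each step multiplies w by (1 - 2·lr), so a tiny rate crawls towards the minimum, a moderate one gets there fast, a rate between 0.5 and 1 flips sign every step while shrinking (the ‘dancing’ above), and anything above 1 overshoots further each time and diverges.

```python
def gradient_descent(lr, w0=1.0, steps=20):
    """Run gradient descent on the toy cost f(w) = w**2, whose gradient is 2*w.

    Returns every iterate so the path can be inspected.
    """
    path = [w0]
    w = w0
    for _ in range(steps):
        w = w - lr * 2 * w  # each step multiplies w by (1 - 2*lr)
        path.append(w)
    return path


for lr in (0.01, 0.45, 0.9, 1.1):
    path = gradient_descent(lr)
    first = [round(w, 3) for w in path[:5]]
    print(f"lr={lr}: first iterates {first}, after 20 steps w = {path[-1]:.3g}")
```

On a real logistic regression cost the curvature is not known in advance, which is why the usable range of learning rates has to be found by trial and error; packaged optimizers sidestep this by choosing step sizes themselves (for example via a line search).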

Overall it was a fun exercise, and I hope to learn a lot more. I also wish I had set aside a test sample to see how my one-vs-all classifier would perform on new data. However, given how long R takes to run the optimization function on my laptop, I’m totally chickening out of that for now.
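Setting aside a test sample could look something like this minimal Python sketch. The helper names are hypothetical: `train_test_split` shuffles the row indices with a fixed seed and holds back 20% of the rows, and `accuracy` takes any function that maps a row to a label, such as a one-vs-all classifier trained only on the training rows.

```python
import random


def train_test_split(rows, labels, test_fraction=0.2, seed=42):
    """Shuffle row indices once (fixed seed) and hold back a fraction as a test set."""
    idx = list(range(len(rows)))
    random.Random(seed).shuffle(idx)
    n_test = round(len(idx) * test_fraction)
    test_idx, train_idx = idx[:n_test], idx[n_test:]
    return ([rows[i] for i in train_idx], [labels[i] for i in train_idx],
            [rows[i] for i in test_idx], [labels[i] for i in test_idx])


def accuracy(predict, rows, labels):
    """Fraction of rows whose predicted label matches the true label."""
    hits = sum(1 for row, label in zip(rows, labels) if predict(row) == label)
    return hits / len(rows)


# Toy usage: ten rows with a parity label and a "classifier" that happens to be perfect
rows = list(range(10))
labels = [r % 2 for r in rows]
train_rows, train_labels, test_rows, test_labels = train_test_split(rows, labels)


def parity(row):
    return row % 2


print(len(train_rows), len(test_rows))
print(accuracy(parity, test_rows, test_labels))
```

Comparing the held-out accuracy with the in-sample figure shows how much of the fit carries over to digits the model has never seen. Shuffling before splitting matters if the rows are stored in any kind of order, since a plain head/tail split could then leave whole classes out of the training set.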

PS: I’ve attached the R code for one-vs-all that I used, in case anyone wants to read it and lambaste its inefficiencies.


Posted on March 12, 2014 in Data Analysis


