Monday, June 30, 2008

Matrix Multiplication in Scala

I recently started experimenting with the Scala programming language after hearing the Java Posse guys raving about it. I found the online tutorials to be relatively shallow, so I bought the Programming in Scala textbook. The book is a tutorial-style introduction to the language and is quite well written. It isn't finished yet, so I bought the work-in-progress PDF and will receive the full textbook by mail when it is published.

Scala has a lot to like so far, although at first I found it quite daunting. I remember feeling that I was ready to get started with Python after doing the Python tutorial, but with Scala, it's clear that I have yet to learn about many of the features of the language.

For my first steps, I decided to make a simple dense matrix class based on Scala Lists, which the textbook describes as "the workhorses of Scala". Initially I wanted to make the matrix generic, but discovered that I would need some special voodoo to ensure that it could only be parametrized with types that support the appropriate operators (+, -, * and so on). That is way beyond me at this stage, so I went for a concrete implementation using Doubles.
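For reference, the "voodoo" here is a type-class constraint. Later versions of Scala's standard library include scala.math.Numeric, and a class can require an implicit Numeric[T] so that it only compiles for element types with arithmetic defined. The sketch below is illustrative only: the name GenericMatrix is hypothetical, it covers just addition and multiplication, and it assumes Scala 2.13 or 3 (for lazyZip).

```scala
// Hypothetical sketch: a matrix generic in its element type T. The implicit
// Numeric[T] parameter is the constraint: GenericMatrix[T] only compiles when
// the compiler can find a Numeric instance for T (the standard library
// provides them for Int, Long, Double, BigInt and others).
class GenericMatrix[T](val elements: List[List[T]])(implicit num: Numeric[T]) {

  // Element-wise addition of two matrices with the same shape.
  def +(that: GenericMatrix[T]): GenericMatrix[T] =
    new GenericMatrix(
      elements.lazyZip(that.elements).map((r1, r2) => r1.lazyZip(r2).map((x, y) => num.plus(x, y)))
    )

  // Row-by-column multiplication; .sum picks up the same implicit Numeric[T].
  def *(that: GenericMatrix[T]): GenericMatrix[T] = {
    val cols = that.elements.transpose
    new GenericMatrix(
      for (row <- elements) yield {
        for (col <- cols) yield row.lazyZip(col).map((x, y) => num.times(x, y)).sum
      }
    )
  }
}
```

With this design, new GenericMatrix(List(List(1, 2), List(3, 4))) compiles as a GenericMatrix[Int], while a matrix of Strings is rejected at compile time because there is no Numeric[String].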

Here is a blow-by-blow account of the class:
The Matrix class takes a List of its rows as a parameter, and the rows are then stored in a val named elements.
Here, the number of rows and columns of the matrix is calculated from the list of elements, and a call to require checks that all rows of the matrix have the same number of elements (i.e. the same number of columns); if they don't, require throws an IllegalArgumentException. The forall method returns true if a given predicate holds for every member of a List. In this case, the predicate is written with placeholder syntax as _.length == nCols, which is equivalent to the function literal (x: List[Double]) => x.length == nCols.
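A minimal sketch of the constructor and dimension check described in the last two paragraphs, assuming current Scala syntax. Treating an empty row list as a matrix with 0 columns is an assumption, not something taken from the post.

```scala
// Minimal sketch: a dense matrix of Doubles stored as a List of rows.
class Matrix(els: List[List[Double]]) {
  // The rows, kept in a val named elements.
  val elements: List[List[Double]] = els

  // Dimensions derived from the row list; an empty matrix gets 0 columns.
  val nRows: Int = elements.length
  val nCols: Int = if (elements.isEmpty) 0 else elements.head.length

  // Every row must have nCols entries; _.length == nCols is shorthand for
  // (x: List[Double]) => x.length == nCols.
  require(elements.forall(_.length == nCols), "all rows must have the same length")
}
```

Constructing new Matrix(List(List(1.0), List(1.0, 2.0))) fails the require check and throws an IllegalArgumentException.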
These methods define addition and subtraction by combining corresponding rows of the two matrices element by element, while transpose comes for free from List. As before, the underscore notation is shorthand for a function literal. For example, _+_ is in this case equivalent to (x: Double, y: Double) => x + y.
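To keep this example self-contained, the sketch below writes addition and subtraction as standalone functions over List[List[Double]] (the names add and subtract are hypothetical) rather than as methods on Matrix. It assumes Scala 2.13 or later for lazyZip.

```scala
// Element-wise addition: pair up corresponding rows, then corresponding
// entries. _ + _ is shorthand for (x: Double, y: Double) => x + y.
def add(a: List[List[Double]], b: List[List[Double]]): List[List[Double]] = {
  require(a.length == b.length && a.zip(b).forall { case (ra, rb) => ra.length == rb.length },
    "shapes must match")
  a.lazyZip(b).map((ra, rb) => ra.lazyZip(rb).map(_ + _))
}

// Element-wise subtraction, same structure with _ - _.
def subtract(a: List[List[Double]], b: List[List[Double]]): List[List[Double]] = {
  require(a.length == b.length && a.zip(b).forall { case (ra, rb) => ra.length == rb.length },
    "shapes must match")
  a.lazyZip(b).map((ra, rb) => ra.lazyZip(rb).map(_ - _))
}
```

Transposition needs no code of its own: List(List(1.0, 2.0), List(3.0, 4.0)).transpose already gives List(List(1.0, 3.0), List(2.0, 4.0)).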
This is a really fun way to multiply matrices, because it directly implements the idea of multiplying rows of one matrix by columns of the other. For a given row of the left matrix, the inner yield produces a List[Double] of the dot products of that row with each column of the right matrix; this is a single row of the result matrix. The outer yield then collects all of the rows of the result matrix.
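A minimal sketch of the nested for/yield multiplication, written as a standalone function (the name multiply is hypothetical) over List[List[Double]] and assuming Scala 2.13 or later. The columns of the right matrix are obtained as the rows of its transpose.

```scala
// Row-by-column matrix multiplication with nested for/yield.
def multiply(a: List[List[Double]], b: List[List[Double]]): List[List[Double]] = {
  // Each row of a must be as long as b has rows.
  require(a.forall(_.length == b.length), "inner dimensions must agree")
  // The columns of b are the rows of b.transpose.
  val bCols = b.transpose
  // Outer yield: one result row per row of a.
  for (row <- a) yield {
    // Inner yield: dot product of this row with every column of b.
    for (col <- bCols) yield row.lazyZip(col).map(_ * _).sum
  }
}
```

For example, multiplying the 1 x 3 matrix List(List(1.0, 2.0, 3.0)) by the 3 x 1 matrix List(List(1.0), List(2.0), List(3.0)) gives List(List(14.0)).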
For pretty-printing, this overrides the toString method.
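One possible toString, as a minimal sketch: the layout (one line per row, entries separated by spaces) is an assumption, since the post doesn't say how matrices are printed.

```scala
// Minimal sketch: a matrix class that only overrides toString.
class Matrix(val elements: List[List[Double]]) {
  // One line per row, entries separated by single spaces.
  override def toString: String =
    elements.map(_.mkString(" ")).mkString("\n")
}
```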
This object adds the ability to construct a Matrix from an array of Doubles provided to the apply method.
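A minimal sketch of such a factory. The signature is an assumption: here apply takes one Array[Double] per row as varargs, which lets callers skip the new keyword.

```scala
// Stripped-down matrix class, just enough for the factory to build.
class Matrix(val elements: List[List[Double]])

object Matrix {
  // Hypothetical signature: one Array[Double] per row, passed as varargs,
  // so callers can write Matrix(Array(1.0, 2.0), Array(3.0, 4.0)).
  def apply(rows: Array[Double]*): Matrix =
    new Matrix(rows.map(_.toList).toList)
}
```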

If the above code is in a file called Matrix.scala, then we can do the following in the Scala interpreter:
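A minimal consolidated sketch of the whole class, assembled from the pieces described above. It is a reconstruction, not the original listing, and it assumes Scala 2.13 or 3 (for lazyZip) plus an apply method that takes one Array[Double] per row.

```scala
// Minimal consolidated sketch of a List-based dense matrix of Doubles.
class Matrix(els: List[List[Double]]) {
  val elements: List[List[Double]] = els
  val nRows: Int = elements.length
  val nCols: Int = if (elements.isEmpty) 0 else elements.head.length
  require(elements.forall(_.length == nCols), "all rows must have the same length")

  // Combine two same-shaped matrices entry by entry with f.
  private def combine(that: Matrix)(f: (Double, Double) => Double): Matrix = {
    require(nRows == that.nRows && nCols == that.nCols, "shapes must match")
    new Matrix(elements.lazyZip(that.elements).map((r1, r2) => r1.lazyZip(r2).map(f)))
  }

  def +(that: Matrix): Matrix = combine(that)(_ + _)
  def -(that: Matrix): Matrix = combine(that)(_ - _)

  // Transpose comes straight from List.
  def transpose: Matrix = new Matrix(elements.transpose)

  // Each result row holds the dot products of one row with every column.
  def *(that: Matrix): Matrix = {
    require(nCols == that.nRows, "inner dimensions must agree")
    val cols = that.elements.transpose
    new Matrix(
      for (row <- elements) yield {
        for (col <- cols) yield row.lazyZip(col).map(_ * _).sum
      }
    )
  }

  // One line per row, entries separated by spaces.
  override def toString: String = elements.map(_.mkString(" ")).mkString("\n")
}

object Matrix {
  // Assumed factory: one Array[Double] per row.
  def apply(rows: Array[Double]*): Matrix = new Matrix(rows.map(_.toList).toList)
}
```

With a = Matrix(Array(1.0, 2.0), Array(3.0, 4.0)) and b = Matrix(Array(5.0, 6.0), Array(7.0, 8.0)), println(a * b) prints two lines: 19.0 22.0 and 43.0 50.0.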
So, dense matrices with multiplication, addition, subtraction and transposition can be implemented fairly easily in Scala. Next steps could include making the Matrix class parametrizable (so we could have a Matrix[Int] or a Matrix[Double]). It would also be interesting to compare the performance of Array-based and List-based matrix operations. I expect array-based implementations to win for operations that need frequent random element access (like Gaussian elimination), since reaching the i-th element of a List means walking past every element before it, while indexing an Array takes constant time. For operations like multiplication and addition, which traverse the rows in order, lists may do no worse.
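For comparison, here is a minimal sketch of Array-based multiplication (the name multiplyArrays is hypothetical). Every a(i)(p) and b(p)(j) is a constant-time index into nested arrays, which is the property that matters for random-access algorithms. This sketch makes no claim about which representation is actually faster for multiplication.

```scala
// Array-based dense multiplication using direct indexing.
def multiplyArrays(a: Array[Array[Double]], b: Array[Array[Double]]): Array[Array[Double]] = {
  val n = a.length                          // rows of the result
  val inner = b.length                      // shared inner dimension
  val m = if (inner == 0) 0 else b(0).length // columns of the result
  require(a.forall(_.length == inner), "inner dimensions must agree")
  // Array.tabulate builds an n x m array, calling the function for each (i, j).
  Array.tabulate(n, m) { (i, j) =>
    var sum = 0.0
    var p = 0
    while (p < inner) { sum += a(i)(p) * b(p)(j); p += 1 }
    sum
  }
}
```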
