2013-03-26

Monads in Scala Part Two: More Maybes

In the first installment of this series, we translated the examples from the first page of the chapter on monads in the Haskell wikibook from Haskell to Scala. In this blog post, we will continue with the second page. This page continues with more Maybe examples, which will help us generalize our monadic laws test for Maybes a bit. We'll also learn about Kleisli composition, which composes functions suitable for passing to flatMap.

All the source code presented here is available on GitHub. I'm compiling against Scala 2.10.1. Some of this code will not compile under Scala 2.9.x.

Safe Functions

The first section of the wiki page discusses safe functions - versions of mathematical functions that produce a Just wrapping the result when the input is within the domain of the function, and MaybeNot when it is not. (As a reminder, I am using MaybeNot where the Haskell examples use Nothing, to avoid confusion with the Nothing found in the Scala library.) This will allow us to "floor out" when chaining mathematical operations - anything along the way that produces a MaybeNot will cause the MaybeNot to propagate right on down the chain.

I'm not sure what Haskell's numeric types look like, but we do already get this kind of behavior when using Java doubles that underlie Scala Doubles. For example, sqrt(-1.0d) will produce double value NaN, which stands for "not a number". If we then take the log of that result - log(NaN) - we will still get NaN. Basically any math operation where one of the operands is NaN will produce a NaN. So safe math functions may not be as useful in Java or Scala, but it's a good example, so let's roll with it.
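A quick check of that NaN propagation, using only the standard scala.math functions:

```scala
import scala.math.{log, sqrt}

// The square root of a negative number is NaN ("not a number").
val root: Double = sqrt(-1.0)
// Any further math with a NaN operand just produces NaN again.
val logOfRoot: Double = log(root)
val plusOne: Double = logOfRoot + 1.0

println(s"$root $logOfRoot $plusOne") // prints NaN NaN NaN
```

Note that NaN is not equal to itself, so code that checks for it has to use isNaN rather than ==.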

Here are my safe versions of sqrt and log:
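A minimal sketch of what these might look like. It assumes a Maybe with Just and MaybeNot cases along the lines of part one, redefined here so the snippet stands alone. The domain checks are the ordinary mathematical ones: d >= 0 for sqrt and d > 0 for log.

```scala
// Assumed Maybe type from part one, redefined so this snippet stands alone.
sealed trait Maybe[+A] {
  def flatMap[B](f: A => Maybe[B]): Maybe[B] = this match {
    case Just(a)  => f(a)
    case MaybeNot => MaybeNot
  }
  def map[B](f: A => B): Maybe[B] = flatMap(a => Just(f(a)))
}
final case class Just[+A](value: A) extends Maybe[A]
case object MaybeNot extends Maybe[Nothing]

// sqrt is only defined for non-negative inputs (NaN also fails this test).
def safeSqrt(d: Double): Maybe[Double] =
  if (d >= 0) Just(math.sqrt(d)) else MaybeNot

// log is only defined for strictly positive inputs.
def safeLog(d: Double): Maybe[Double] =
  if (d > 0) Just(math.log(d)) else MaybeNot
```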

Now when mathematicians say things like "log square root", they mean you should take the square root first, and then take the log of that. Some people, like myself, might find this confusing, so I wanted to point that out. The initial implementation of safeLogSqrt in the example simply checks the outcome of sqrt before passing it on to log, like so:
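A sketch of that explicit check. The Maybe types and safe functions are again assumed and redefined so the snippet compiles on its own.

```scala
// Assumed Maybe type from part one, redefined so this snippet stands alone.
sealed trait Maybe[+A] {
  def flatMap[B](f: A => Maybe[B]): Maybe[B] = this match {
    case Just(a)  => f(a)
    case MaybeNot => MaybeNot
  }
}
final case class Just[+A](value: A) extends Maybe[A]
case object MaybeNot extends Maybe[Nothing]

def safeSqrt(d: Double): Maybe[Double] = if (d >= 0) Just(math.sqrt(d)) else MaybeNot
def safeLog(d: Double): Maybe[Double] = if (d > 0) Just(math.log(d)) else MaybeNot

// Inspect the outcome of safeSqrt by hand before passing it on to safeLog.
def safeLogSqrt(d: Double): Maybe[Double] = safeSqrt(d) match {
  case Just(root) => safeLog(root)
  case MaybeNot   => MaybeNot
}
```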

This is a mouthful, especially when the unsafe version is as simple as this:
    def unsafeLogSqrt(d: Double): Double = log(sqrt(d))
Of course, we want to use flatMap (>>= in Haskell) instead of manually testing if safeSqrt produced a Just or a MaybeNot. The Haskell example presented here looks like this:
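Here is a sketch of that chain translated to Scala. Just stands in for Haskell's return and flatMap for >>=, and Maybe and the safe functions are redefined so the snippet is self-contained.

```scala
// Assumed Maybe type from part one, redefined so this snippet stands alone.
sealed trait Maybe[+A] {
  def flatMap[B](f: A => Maybe[B]): Maybe[B] = this match {
    case Just(a)  => f(a)
    case MaybeNot => MaybeNot
  }
}
final case class Just[+A](value: A) extends Maybe[A]
case object MaybeNot extends Maybe[Nothing]

def safeSqrt(d: Double): Maybe[Double] = if (d >= 0) Just(math.sqrt(d)) else MaybeNot
def safeLog(d: Double): Maybe[Double] = if (d > 0) Just(math.log(d)) else MaybeNot

// Haskell: safeLogSqrt x = return x >>= safeSqrt >>= safeLog
def safeLogSqrt(d: Double): Maybe[Double] =
  Just(d).flatMap(safeSqrt).flatMap(safeLog)
```

Because Just(d).flatMap(f) is simply f(d), the leading Just is redundant.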

Which, according to the "left unit" Monad law, can be simplified to this:
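In Scala, the simplified form is just a flatMap on the result of safeSqrt. Here is a self-contained sketch with the assumed Maybe types:

```scala
// Assumed Maybe type from part one, redefined so this snippet stands alone.
sealed trait Maybe[+A] {
  def flatMap[B](f: A => Maybe[B]): Maybe[B] = this match {
    case Just(a)  => f(a)
    case MaybeNot => MaybeNot
  }
}
final case class Just[+A](value: A) extends Maybe[A]
case object MaybeNot extends Maybe[Nothing]

def safeSqrt(d: Double): Maybe[Double] = if (d >= 0) Just(math.sqrt(d)) else MaybeNot
def safeLog(d: Double): Maybe[Double] = if (d > 0) Just(math.log(d)) else MaybeNot

// Haskell: safeLogSqrt x = safeSqrt x >>= safeLog
def safeLogSqrt(d: Double): Maybe[Double] = safeSqrt(d).flatMap(safeLog)
```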

I'm not sure why the Haskell example uses the longer form here. My guess is that it's because return x >>= safeSqrt >>= safeLog looks sexier than safeSqrt(x) >>= safeLog, but that's just my best guess.

Before moving on, here's a version of safeLogSqrt that uses a for loop:
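Here is a sketch of the for loop (for comprehension) version. Scala desugars it into flatMap and map calls, so this assumed Maybe needs a map method as well.

```scala
// Assumed Maybe type from part one, redefined so this snippet stands alone.
sealed trait Maybe[+A] {
  def flatMap[B](f: A => Maybe[B]): Maybe[B] = this match {
    case Just(a)  => f(a)
    case MaybeNot => MaybeNot
  }
  // Needed because the for loop's final yield desugars to map.
  def map[B](f: A => B): Maybe[B] = flatMap(a => Just(f(a)))
}
final case class Just[+A](value: A) extends Maybe[A]
case object MaybeNot extends Maybe[Nothing]

def safeSqrt(d: Double): Maybe[Double] = if (d >= 0) Just(math.sqrt(d)) else MaybeNot
def safeLog(d: Double): Maybe[Double] = if (d > 0) Just(math.log(d)) else MaybeNot

// Desugars to safeSqrt(d).flatMap(root => safeLog(root).map(l => l)).
def safeLogSqrt(d: Double): Maybe[Double] =
  for {
    root      <- safeSqrt(d)
    logOfRoot <- safeLog(root)
  } yield logOfRoot
```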

Kleisli Composition

The Haskell wikibook points out that you can produce an unsafe logSqrt function by simply composing the unsafe functions log and sqrt. For function composition, mathematicians use a small circle, as in log ∘ sqrt. In Haskell, the composition is done like this:
    unsafeLogSqrt = log . sqrt
The function on the right is the one that gets called first. In Scala, it would look something like this:
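Here is a sketch of that composition using the compose method on Function1. The parentheses around each eta-expanded method only make the grouping explicit.

```scala
import scala.math.{log, sqrt}

// log _ and sqrt _ turn the methods into Function1 values;
// compose builds a function that applies sqrt first, then log.
val unsafeLogSqrt: Double => Double = (log _) compose (sqrt _)
```

This version works on plain Doubles, so bad inputs produce NaN (or negative infinity for 0.0) rather than a MaybeNot.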

Let's step back and examine what's going on here. In Scala, when we follow a method's name with an underscore, we are indicating that we are talking about the function itself, and not trying to apply it. The expression scala.math.log _ is a scala.Function1, and as such, has a compose method that takes another scala.Function1 as an argument. The result is a new function that calls sqrt first, and then calls log on the result.

At this point, the wikibook brings in the Kleisli composition operator, <=<, and shows how you can define a safeLogSqrt with similar brevity, like so:
    safeLogSqrt = safeLog <=< safeSqrt
This is similar to function composition, but ordinary function composition will not work, because the output from safeSqrt is Maybe[Double], and the input to safeLog is Double. So we really want to chain these together in a flatMap style, as in the body of a Scala for loop.

As a first attempt at a Kleisli composition operator in Scala, let's write a method that takes two functions as arguments and produces their Kleisli composition:
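Here is a minimal sketch of such a method. The parameter order (f is applied second and g first, mirroring f <=< g) matches the explanation that follows. The parseInt and reciprocal helpers are hypothetical and exist only so that A, B and C are three different types.

```scala
// Assumed Maybe type from part one, redefined so this snippet stands alone.
sealed trait Maybe[+A] {
  def flatMap[B](f: A => Maybe[B]): Maybe[B] = this match {
    case Just(a)  => f(a)
    case MaybeNot => MaybeNot
  }
}
final case class Just[+A](value: A) extends Maybe[A]
case object MaybeNot extends Maybe[Nothing]

// Kleisli composition: apply g to get a Maybe[B], then flatMap it through f.
def kleisliCompose[A, B, C](f: B => Maybe[C], g: A => Maybe[B]): A => Maybe[C] =
  a => g(a).flatMap(f)

// Hypothetical helpers: String => Maybe[Int] and Int => Maybe[Double].
def parseInt(s: String): Maybe[Int] =
  if (s.nonEmpty && s.length <= 9 && s.forall(_.isDigit)) Just(s.toInt) else MaybeNot
def reciprocal(i: Int): Maybe[Double] =
  if (i != 0) Just(1.0 / i) else MaybeNot

// A = String, B = Int, C = Double.
val parseReciprocal: String => Maybe[Double] = kleisliCompose(reciprocal _, parseInt _)
```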

We are producing a function that takes an A as input and produces a Maybe[C]. We first apply g, which takes an A as input and produces a Maybe[B]. Then we flatMap the Maybe[B] with f, which takes a B as input and produces a Maybe[C]. Now we can define safeLogSqrt as follows:
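Here is a self-contained sketch of that definition. It repeats the assumed Maybe types and the kleisliCompose method so that it compiles on its own.

```scala
// Assumed Maybe type from part one, redefined so this snippet stands alone.
sealed trait Maybe[+A] {
  def flatMap[B](f: A => Maybe[B]): Maybe[B] = this match {
    case Just(a)  => f(a)
    case MaybeNot => MaybeNot
  }
}
final case class Just[+A](value: A) extends Maybe[A]
case object MaybeNot extends Maybe[Nothing]

def safeSqrt(d: Double): Maybe[Double] = if (d >= 0) Just(math.sqrt(d)) else MaybeNot
def safeLog(d: Double): Maybe[Double] = if (d > 0) Just(math.log(d)) else MaybeNot

// Kleisli composition: apply g, then flatMap its result through f.
def kleisliCompose[A, B, C](f: B => Maybe[C], g: A => Maybe[B]): A => Maybe[C] =
  a => g(a).flatMap(f)

// Reads like safeLog <=< safeSqrt: safeSqrt runs first, then safeLog.
val safeLogSqrt: Double => Maybe[Double] = kleisliCompose(safeLog _, safeSqrt _)
```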

This is nice, but wouldn't it be better if we could achieve this using an infix notation, to match scala.math.log _ compose scala.math.sqrt _? We cannot simply add a kleisliCompose method to Function1. What we do in Scala >= 2.10 in these circumstances is build an implicit class that contains the desired new method. Let's take a look at our implicit class, and then break it down:
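Here is a sketch of such an implicit class. It is wrapped in a hypothetical MaybeSyntax object because in Scala 2 an implicit class cannot be a top-level definition. Callers bring it into scope with import MaybeSyntax._.

```scala
// Assumed Maybe type from part one, redefined so this snippet stands alone.
sealed trait Maybe[+A] {
  def flatMap[B](f: A => Maybe[B]): Maybe[B] = this match {
    case Just(a)  => f(a)
    case MaybeNot => MaybeNot
  }
}
final case class Just[+A](value: A) extends Maybe[A]
case object MaybeNot extends Maybe[Nothing]

def safeSqrt(d: Double): Maybe[Double] = if (d >= 0) Just(math.sqrt(d)) else MaybeNot
def safeLog(d: Double): Maybe[Double] = if (d > 0) Just(math.log(d)) else MaybeNot

object MaybeSyntax {
  // Wraps any B => Maybe[C]; B, C and f are now class-level, and only A
  // remains a type parameter of the method itself.
  implicit class MaybeFunction[B, C](f: B => Maybe[C]) {
    def kleisliCompose[A](g: A => Maybe[B]): A => Maybe[C] =
      a => g(a).flatMap(f)
  }
}
```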

The kleisliCompose method itself is the same, except that type parameters B and C, and parameter f, have been lifted out of the method and into the enclosing class. The MaybeFunction constructor takes a single Function1 as argument - in particular, a Function1 that returns a Maybe.

Let's take note that MaybeFunction is implicit, and that it has a method kleisliCompose. Now, when the compiler encounters something like safeLog _ kleisliCompose safeSqrt _ and tries to resolve the method call, it does not immediately find a method kleisliCompose in Function1, which is the type of safeLog _. Before it gives up, it looks for any implicit conversion it could use to transform safeLog _ into something that has a method named kleisliCompose with an applicable signature. Assuming our MaybeFunction class is in scope, it will implicitly construct a MaybeFunction from the Function1 and call kleisliCompose on that.

This Pimp my Library pattern is available in languages like Ruby and Groovy via meta-programming. In Scala, it's done at compile-time in a type-safe way. While implicit classes are newly available in Scala 2.10, the same effect has been achieved for years with implicit functions. The behavior of implicit classes is more controlled than that of implicit functions, as the target of the type conversion is constrained to be the new, implicit class.

So let's go ahead and use this pimped Function1 to define our safeLogSqrt function:
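Here is a self-contained sketch of that definition. It repeats the assumed types and the hypothetical MaybeSyntax wrapper so the snippet compiles alone. The compiler rewrites the infix call into roughly MaybeFunction(safeLog _).kleisliCompose(safeSqrt _).

```scala
// Assumed Maybe type from part one, redefined so this snippet stands alone.
sealed trait Maybe[+A] {
  def flatMap[B](f: A => Maybe[B]): Maybe[B] = this match {
    case Just(a)  => f(a)
    case MaybeNot => MaybeNot
  }
}
final case class Just[+A](value: A) extends Maybe[A]
case object MaybeNot extends Maybe[Nothing]

def safeSqrt(d: Double): Maybe[Double] = if (d >= 0) Just(math.sqrt(d)) else MaybeNot
def safeLog(d: Double): Maybe[Double] = if (d > 0) Just(math.log(d)) else MaybeNot

object MaybeSyntax {
  implicit class MaybeFunction[B, C](f: B => Maybe[C]) {
    def kleisliCompose[A](g: A => Maybe[B]): A => Maybe[C] = a => g(a).flatMap(f)
  }
}
import MaybeSyntax._

// Infix form, parallel to (log _) compose (sqrt _).
val safeLogSqrt: Double => Maybe[Double] = (safeLog _) kleisliCompose (safeSqrt _)
```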

Testing Safe Functions

What about testing all the safe functions we have defined above? First of all, we could write tests that assure that all of the versions of safeLogSqrt presented above produce the same results. To do this, we first need to define a set of test data to use. To be sure to test every outcome, I've included input values for which safeLogSqrt will return both a Just and a MaybeNot. I've also included a value (0) for which safeSqrt will return a Just, but safeLogSqrt will return a MaybeNot. So I'm covering the three major code paths: Double to Just[Double] to Just[Double]; Double to Just[Double] to MaybeNot; and Double to MaybeNot to MaybeNot.
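Here is a sketch of such a test data set. The particular numbers are hypothetical, but each of the three code paths is represented.

```scala
// Assumed Maybe type from part one, redefined so this snippet stands alone.
sealed trait Maybe[+A] {
  def flatMap[B](f: A => Maybe[B]): Maybe[B] = this match {
    case Just(a)  => f(a)
    case MaybeNot => MaybeNot
  }
}
final case class Just[+A](value: A) extends Maybe[A]
case object MaybeNot extends Maybe[Nothing]

def safeSqrt(d: Double): Maybe[Double] = if (d >= 0) Just(math.sqrt(d)) else MaybeNot
def safeLog(d: Double): Maybe[Double] = if (d > 0) Just(math.log(d)) else MaybeNot

// Negative values: safeSqrt fails (Double -> MaybeNot -> MaybeNot).
// Zero: safeSqrt succeeds, safeLog fails (Double -> Just -> MaybeNot).
// Positive values: both succeed (Double -> Just -> Just).
val safeMathTestData: List[Double] = List(-4.0, -1.0, 0.0, 0.25, 1.0, 9.0)
```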

Now the test. Just a word of note here, I am switching my tests from just using asserts to using ScalaTest's ShouldMatchers. I'm using "org.scalatest" %% "scalatest" % "2.0.M5b", but any recent version should work. I have 5 versions of safeLogSqrt, so testing four pairs will do:
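To show the idea without a ScalaTest dependency, here is a sketch that uses plain checks instead of ShouldMatchers. It defines five hypothetical variants, v1 through v5, corresponding to the versions discussed above, and compares each adjacent pair.

```scala
// Assumed Maybe type from part one, redefined so this snippet stands alone.
sealed trait Maybe[+A] {
  def flatMap[B](f: A => Maybe[B]): Maybe[B] = this match {
    case Just(a)  => f(a)
    case MaybeNot => MaybeNot
  }
  def map[B](f: A => B): Maybe[B] = flatMap(a => Just(f(a)))
}
final case class Just[+A](value: A) extends Maybe[A]
case object MaybeNot extends Maybe[Nothing]

def safeSqrt(d: Double): Maybe[Double] = if (d >= 0) Just(math.sqrt(d)) else MaybeNot
def safeLog(d: Double): Maybe[Double] = if (d > 0) Just(math.log(d)) else MaybeNot

// 1. Explicit pattern match on the result of safeSqrt.
def v1(d: Double): Maybe[Double] = safeSqrt(d) match {
  case Just(root) => safeLog(root)
  case MaybeNot   => MaybeNot
}
// 2. Haskell style: return x >>= safeSqrt >>= safeLog.
def v2(d: Double): Maybe[Double] = Just(d).flatMap(safeSqrt).flatMap(safeLog)
// 3. Simplified by the left unit law: safeSqrt x >>= safeLog.
def v3(d: Double): Maybe[Double] = safeSqrt(d).flatMap(safeLog)
// 4. For loop (for comprehension).
def v4(d: Double): Maybe[Double] =
  for { root <- safeSqrt(d); result <- safeLog(root) } yield result
// 5. Kleisli composition via a plain method.
def kleisliCompose[A, B, C](f: B => Maybe[C], g: A => Maybe[B]): A => Maybe[C] =
  a => g(a).flatMap(f)
val v5: Double => Maybe[Double] = kleisliCompose(safeLog _, safeSqrt _)

val safeMathTestData: List[Double] = List(-4.0, -1.0, 0.0, 0.25, 1.0, 9.0)
val versions: List[Double => Maybe[Double]] = List(v1 _, v2 _, v3 _, v4 _, v5)

// Equality is transitive, so four adjacent pairs cover all five versions.
val allVersionsAgree: Boolean = versions.zip(versions.tail).forall {
  case (left, right) => safeMathTestData.forall(d => left(d) == right(d))
}
println(s"all versions agree: $allVersionsAgree")
```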

In the last blog post, we developed tests to assure that Maybe obeyed the monadic laws with regard to some sample Person data. We now have another data set that we can test the Maybe monad against. But the old version of the test is hard-coded to use Person data. Let's generalize that into a function that takes the required test data as arguments. Then we can call this function for different sets of test data. We can easily abstract over the types involved. We also pass a String description to include in a FlatSpec behavior clause. The function opens like this:
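Here is a sketch of what such a generalized function might look like in full. The name monadLawFailures is hypothetical. So is the choice to return a list of failed laws rather than register FlatSpec clauses, which lets the example run without ScalaTest.

```scala
// Assumed Maybe type from part one, redefined so this snippet stands alone.
sealed trait Maybe[+A] {
  def flatMap[B](f: A => Maybe[B]): Maybe[B] = this match {
    case Just(a)  => f(a)
    case MaybeNot => MaybeNot
  }
}
final case class Just[+A](value: A) extends Maybe[A]
case object MaybeNot extends Maybe[Nothing]

def safeSqrt(d: Double): Maybe[Double] = if (d >= 0) Just(math.sqrt(d)) else MaybeNot
def safeLog(d: Double): Maybe[Double] = if (d > 0) Just(math.log(d)) else MaybeNot

// Generic over the value type A and the result types B and C of f and g.
def monadLawFailures[A, B, C](
    description: String,
    values: List[A],
    f: A => Maybe[B],
    g: B => Maybe[C]): List[String] = {
  // Left unit: return x >>= f == f x, for every raw test value.
  val leftUnit = values.forall(a => Just(a).flatMap(f) == f(a))
  // Every Maybe[A] we can build from the data: a Just per value, plus MaybeNot.
  val maybes: List[Maybe[A]] = MaybeNot :: values.map(a => Just(a))
  // Right unit: m >>= return == m.
  val rightUnit = maybes.forall(m => m.flatMap(a => Just(a)) == m)
  // Associativity: (m >>= f) >>= g == m >>= (\x -> f x >>= g).
  val associativity = maybes.forall { m =>
    m.flatMap(f).flatMap(g) == m.flatMap(a => f(a).flatMap(g))
  }
  val results = List("left unit" -> leftUnit, "right unit" -> rightUnit, "associativity" -> associativity)
  results.collect { case (law, false) => s"$description: $law fails" }
}

val failures = monadLawFailures("safe math functions", List(-4.0, 0.0, 0.25, 9.0), safeSqrt _, safeLog _)
println(if (failures.isEmpty) "all monad laws hold" else failures.mkString("\n"))
```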

As you will recall, the left unit Monad law states that return x >>= f = f x; in Scala terms, Just(x) flatMap f should equal f(x). Let's assert this over all of our test data:
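Over the safe math data, the left unit check might be sketched like this, with Just as return and flatMap as >>=. The test values and the leftUnitHolds helper are hypothetical.

```scala
// Assumed Maybe type from part one, redefined so this snippet stands alone.
sealed trait Maybe[+A] {
  def flatMap[B](f: A => Maybe[B]): Maybe[B] = this match {
    case Just(a)  => f(a)
    case MaybeNot => MaybeNot
  }
}
final case class Just[+A](value: A) extends Maybe[A]
case object MaybeNot extends Maybe[Nothing]

def safeSqrt(d: Double): Maybe[Double] = if (d >= 0) Just(math.sqrt(d)) else MaybeNot
def safeLog(d: Double): Maybe[Double] = if (d > 0) Just(math.log(d)) else MaybeNot

val safeMathTestData: List[Double] = List(-4.0, -1.0, 0.0, 0.25, 1.0, 9.0)

// Left unit: Just(x).flatMap(f) must equal f(x) for every test value x.
def leftUnitHolds[B](f: Double => Maybe[B]): Boolean =
  safeMathTestData.forall(x => Just(x).flatMap(f) == f(x))

println(leftUnitHolds(safeSqrt _) && leftUnitHolds(safeLog _))
```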

For right unit and associativity, we need to seed a list of all possible Maybes. We'll throw MaybeNot in with the rest of our test data wrapped in Justs:
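Here is a sketch of seeding that list and checking right unit and associativity, with safeSqrt and safeLog as the two functions. The test values are hypothetical.

```scala
// Assumed Maybe type from part one, redefined so this snippet stands alone.
sealed trait Maybe[+A] {
  def flatMap[B](f: A => Maybe[B]): Maybe[B] = this match {
    case Just(a)  => f(a)
    case MaybeNot => MaybeNot
  }
}
final case class Just[+A](value: A) extends Maybe[A]
case object MaybeNot extends Maybe[Nothing]

def safeSqrt(d: Double): Maybe[Double] = if (d >= 0) Just(math.sqrt(d)) else MaybeNot
def safeLog(d: Double): Maybe[Double] = if (d > 0) Just(math.log(d)) else MaybeNot

val safeMathTestData: List[Double] = List(-4.0, -1.0, 0.0, 0.25, 1.0, 9.0)

// Every Maybe worth checking: MaybeNot plus a Just around each test value.
val maybes: List[Maybe[Double]] = MaybeNot :: safeMathTestData.map(d => Just(d))

// Right unit: m >>= return == m.
val rightUnitHolds: Boolean = maybes.forall(m => m.flatMap(d => Just(d)) == m)

// Associativity: (m >>= f) >>= g == m >>= (\x -> f x >>= g).
val associativityHolds: Boolean = maybes.forall { m =>
  m.flatMap(safeSqrt).flatMap(safeLog) == m.flatMap(d => safeSqrt(d).flatMap(safeLog))
}
```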

We have a couple more Monad tests from last time. I won't go into the details, and you can check it out in the source code. I want to get down to business, which is calling this function over our two data sets:

Lookup Tables

The next major section of the wikibook page is on lookup tables. A couple of things perplexed me here. First, they are using a list of pairs as a lookup table. I did check, and it seems Haskell has a perfectly serviceable Map type (in Data.Map). Using a list of pairs instead of a library collection class would never occur to me. I'm probably spoiled by the quality of Scala's collection library. The second thing that seemed quite different to me was the fact that they are putting raw strings and numbers into the map. There are two maps in their example that are semantically different, but both have the same type signature: string to string. Haskell has such a rich type system, it seems like they could easily do better than this. They probably made these choices to keep the example simple, and to focus the discussion on the use of Monads. But it wouldn't be natural Scala without using case classes and maps. Typing the data is trivial in Scala:
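Here is a sketch of what that typing might look like. The wrapper names Name, PhoneNumber, Registration and TaxOwed, and the sample entries, are hypothetical.

```scala
// Hypothetical wrapper types: each kind of key and value gets its own type,
// so the compiler rejects, say, looking up a Name in a map keyed by PhoneNumber.
final case class Name(value: String)
final case class PhoneNumber(value: String)
final case class Registration(value: String)
final case class TaxOwed(amount: Double)

// With raw strings both maps would be Map[String, String]; here they differ.
val phonebook: Map[Name, PhoneNumber] =
  Map(Name("Alice") -> PhoneNumber("555-0100"))
val governmentDatabase: Map[PhoneNumber, Registration] =
  Map(PhoneNumber("555-0100") -> Registration("REG-1"))
```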

Before writing any lookup methods, we need to seed the maps. First, we'll generate a hundred or so pieces of test data for each type:
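Here is a sketch of such a generator. The generator functions and the string formats they produce are hypothetical. The point is that the same Int always yields values that "belong together".

```scala
// Hypothetical wrapper types for the lookup example.
final case class Name(value: String)
final case class PhoneNumber(value: String)
final case class Registration(value: String)
final case class TaxOwed(amount: Double)

// Hypothetical generators: name(7), phone(7), registration(7) and
// taxOwed(7) all relate back to the same Int, 7.
def name(i: Int): Name = Name(s"person-$i")
def phone(i: Int): PhoneNumber = PhoneNumber(s"555-$i")
def registration(i: Int): Registration = Registration(s"REG-$i")
def taxOwed(i: Int): TaxOwed = TaxOwed(i * 100.0)

val ints: List[Int] = (1 to 100).toList
val names: List[Name] = ints.map(name)
val phones: List[PhoneNumber] = ints.map(phone)
val registrations: List[Registration] = ints.map(registration)
val taxes: List[TaxOwed] = ints.map(taxOwed)
```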

We're going to generate the maps in such a way that there is at least one test value for every combination of Just and MaybeNot results. First, we agree that in every key/value pair in a map, the key and the value relate back to the same Int. Then, we filter each map down to the pairs whose underlying Int is divisible by some small prime. Using 2, 3, and 5 over a hundred-element test set will give all the combinations. Here's how it looks:
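Here is a sketch of that construction. Which prime filters which map is an assumption: 2 for the phone book, 3 for the government database and 5 for the tax database. The map names are also hypothetical.

```scala
final case class Name(value: String)
final case class PhoneNumber(value: String)
final case class Registration(value: String)
final case class TaxOwed(amount: Double)

// Hypothetical generators: values built from the same Int belong together.
def name(i: Int): Name = Name(s"person-$i")
def phone(i: Int): PhoneNumber = PhoneNumber(s"555-$i")
def registration(i: Int): Registration = Registration(s"REG-$i")
def taxOwed(i: Int): TaxOwed = TaxOwed(i * 100.0)

val ints: List[Int] = (1 to 100).toList

// Each map keeps only the pairs whose underlying Int is divisible by its prime.
val phonebook: Map[Name, PhoneNumber] =
  ints.filter(_ % 2 == 0).map(i => name(i) -> phone(i)).toMap
val governmentDatabase: Map[PhoneNumber, Registration] =
  ints.filter(_ % 3 == 0).map(i => phone(i) -> registration(i)).toMap
val taxDatabase: Map[Registration, TaxOwed] =
  ints.filter(_ % 5 == 0).map(i => registration(i) -> taxOwed(i)).toMap
```

With these filters, a lookup chain starting from name(i) fails at the first step when i is odd. It fails at the second step when i is even but not divisible by 3, and at the third when i is divisible by 6 but not by 5. It succeeds all the way through when i is divisible by 30.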

We need to add a conversion from Option to Maybe to translate the result of Map#get. A natural place for this is in the Maybe companion object:
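Here is a sketch of that conversion. The method name fromOption is hypothetical.

```scala
// Assumed Maybe type from part one, redefined so this snippet stands alone.
sealed trait Maybe[+A] {
  def flatMap[B](f: A => Maybe[B]): Maybe[B] = this match {
    case Just(a)  => f(a)
    case MaybeNot => MaybeNot
  }
}
final case class Just[+A](value: A) extends Maybe[A]
case object MaybeNot extends Maybe[Nothing]

object Maybe {
  // Translate the Option returned by Map#get into our own Maybe.
  def fromOption[A](option: Option[A]): Maybe[A] = option match {
    case Some(a) => Just(a)
    case None    => MaybeNot
  }
}
```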

We'll define the single-level lookup methods in terms of a generic lookup function that also takes the Map as parameter:
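Here is a sketch of the generic lookup and the single-level lookups built on it. The names lookup, lookupPhone and lookupRegistration are hypothetical, as are the tiny hand-written maps.

```scala
// Assumed Maybe type from part one, with a companion conversion from Option.
sealed trait Maybe[+A] {
  def flatMap[B](f: A => Maybe[B]): Maybe[B] = this match {
    case Just(a)  => f(a)
    case MaybeNot => MaybeNot
  }
}
final case class Just[+A](value: A) extends Maybe[A]
case object MaybeNot extends Maybe[Nothing]
object Maybe {
  def fromOption[A](option: Option[A]): Maybe[A] = option match {
    case Some(a) => Just(a)
    case None    => MaybeNot
  }
}

final case class Name(value: String)
final case class PhoneNumber(value: String)
final case class Registration(value: String)

val phonebook: Map[Name, PhoneNumber] = Map(Name("Alice") -> PhoneNumber("555-0100"))
val governmentDatabase: Map[PhoneNumber, Registration] =
  Map(PhoneNumber("555-0100") -> Registration("REG-1"))

// Generic single-level lookup: any map, any key, result as a Maybe.
def lookup[K, V](map: Map[K, V])(key: K): Maybe[V] = Maybe.fromOption(map.get(key))

// Single-level lookups, each fixing the map argument.
def lookupPhone(name: Name): Maybe[PhoneNumber] = lookup(phonebook)(name)
def lookupRegistration(phone: PhoneNumber): Maybe[Registration] =
  lookup(governmentDatabase)(phone)
```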

Finally, we can implement the lookup methods that span multiple dictionaries with for loops or flatMaps:
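Here is a sketch of the multi-level lookups over small hypothetical maps: one written with flatMap and one with a for loop. The function names registrationForName and taxOwedForName are also hypothetical.

```scala
// Assumed Maybe type from part one, with a companion conversion from Option.
sealed trait Maybe[+A] {
  def flatMap[B](f: A => Maybe[B]): Maybe[B] = this match {
    case Just(a)  => f(a)
    case MaybeNot => MaybeNot
  }
  def map[B](f: A => B): Maybe[B] = flatMap(a => Just(f(a)))
}
final case class Just[+A](value: A) extends Maybe[A]
case object MaybeNot extends Maybe[Nothing]
object Maybe {
  def fromOption[A](option: Option[A]): Maybe[A] = option match {
    case Some(a) => Just(a)
    case None    => MaybeNot
  }
}

final case class Name(value: String)
final case class PhoneNumber(value: String)
final case class Registration(value: String)
final case class TaxOwed(amount: Double)

// Hypothetical sample data: Alice resolves all the way through, Bob has no
// tax record, Carol has no registration, and anyone else has no phone number.
val phonebook: Map[Name, PhoneNumber] = Map(
  Name("Alice") -> PhoneNumber("555-0100"),
  Name("Bob")   -> PhoneNumber("555-0101"),
  Name("Carol") -> PhoneNumber("555-0102"))
val governmentDatabase: Map[PhoneNumber, Registration] = Map(
  PhoneNumber("555-0100") -> Registration("REG-1"),
  PhoneNumber("555-0101") -> Registration("REG-2"))
val taxDatabase: Map[Registration, TaxOwed] = Map(
  Registration("REG-1") -> TaxOwed(1200.0))

def lookup[K, V](map: Map[K, V])(key: K): Maybe[V] = Maybe.fromOption(map.get(key))
def lookupPhone(name: Name): Maybe[PhoneNumber] = lookup(phonebook)(name)
def lookupRegistration(phone: PhoneNumber): Maybe[Registration] = lookup(governmentDatabase)(phone)
def lookupTaxOwed(registration: Registration): Maybe[TaxOwed] = lookup(taxDatabase)(registration)

// Two levels, chained with flatMap: a MaybeNot at either step propagates.
def registrationForName(name: Name): Maybe[Registration] =
  lookupPhone(name).flatMap(lookupRegistration)

// Three levels, written as a for loop (desugars to flatMap and map).
def taxOwedForName(name: Name): Maybe[TaxOwed] =
  for {
    phone        <- lookupPhone(name)
    registration <- lookupRegistration(phone)
    taxOwed      <- lookupTaxOwed(registration)
  } yield taxOwed
```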