Since the major point of this blog post series is to get our feet wet playing with monads, it only makes sense that we implement a list. I'm going to use a basic linked list implementation, and call my type Lisst, to avoid confusion with the List type in the Scala collections library. Just to be clear, this Lisst is by no means intended as an alternative to Scala's List. It is only intended as an exercise to experiment with the monadic aspects of lists.
All the source code presented here is freely available on GitHub. Most of the new code is under package monads.lisst. Be sure to take a look under both src/main and src/test. I recommend you get a local copy, run ~ test in sbt, and play around with it.
This post builds upon the two previous posts in this series - here and here. If you haven't read them yet, I highly recommend you start there.
Monadic operations on lists
As we might expect, the Haskell wikibook page on the list monad begins by defining the unit function and the binding operation for a list. We recall from our first blog post that the unit function is a function from the contained type to the container type, and that the binding operation is the flatMap method in Scala. We will once again put the unit function in the companion object. Let's just nail down the signatures for these methods before we provide an implementation:
```scala
object Lisst {
  def apply[A](a: A): Lisst[A] = ???
}

sealed trait Lisst[+A] {
  def flatMap[B](f: A => Lisst[B]): Lisst[B] = ???
}
```
Linked list data structure
In order to start implementing things, we need to flesh out our linked list data type. I'll create a NonEmptyLisst case class that has a head and a tail, and an EmptyLisst case object, which is used to terminate the linked list. (If you've never implemented a linked list before, or it's been a while, you may want to freshen up. I'm implementing a basic singly linked list here.) With this basic data structure defined, we can easily implement our unit function:
```scala
object Lisst {
  def apply[A](a: A): Lisst[A] = NonEmptyLisst(a, EmptyLisst)
}

private case class NonEmptyLisst[+A](
    head: A,
    tail: Lisst[A])
  extends Lisst[A]

private case object EmptyLisst extends Lisst[Nothing]
```
Varargs factory method
Before proceeding to implement flatMap, I'd like to make one further change. Instead of just providing a unit function, I would like to provide a Lisst factory method that can take any number of arguments. This will allow me to construct items of type Lisst[Int] as in any of the following examples:
- Lisst[Int]()
- Lisst(1)
- Lisst(1, 2, 3, 4, 5)
Within our implementation, we get a Seq for the repeated parameters. But what kind of Seq we get is not specified in the Scala Language Specification (section 4.6.2). A little experimentation shows that we can indeed end up with nearly any sequence type inside our method. Since these sequences have different performance characteristics, we have to be careful not to turn what should be an O(n) operation into an O(n^2) one. For instance, if we get a List, indexing each element in turn will give us O(n^2) behavior. On the other hand, if we get an array-backed sequence such as a WrappedArray, whose tail copies the remaining elements, then walking it with head and tail will be O(n^2). The safest way to proceed is to use foreach and reverse, both of which should be O(n) operations on any Seq type:
```scala
object Lisst {
  def apply[A](elems: A*): Lisst[A] = reverseSeq(elems.reverse)

  // Prepends each element in turn, so the resulting Lisst holds elems in reverse order.
  private def reverseSeq[A](elems: Seq[A]): Lisst[A] = {
    var result: Lisst[A] = EmptyLisst
    elems.foreach { elem => result = NonEmptyLisst(elem, result) }
    result
  }
}
```
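To make the O(n) versus O(n^2) concern concrete, here is a small, self-contained sketch that contrasts an indexed build with the reverse-then-foreach build used above. It redeclares a minimal Lisst inside an object so that it compiles on its own. The names VarargsSketch, fromSeqIndexed, fromSeqReversed, and toList are hypothetical and exist only for this illustration.

```scala
object VarargsSketch {
  sealed trait Lisst[+A]
  final case class NonEmptyLisst[+A](head: A, tail: Lisst[A]) extends Lisst[A]
  case object EmptyLisst extends Lisst[Nothing]

  // Pitfall: indexed access, walking from the last index down to 0.
  // Each elems(i) is fast on an IndexedSeq but costs O(i) on a List,
  // so for a List argument the whole loop is O(n^2).
  def fromSeqIndexed[A](elems: Seq[A]): Lisst[A] = {
    var result: Lisst[A] = EmptyLisst
    var i = elems.length - 1
    while (i >= 0) {
      result = NonEmptyLisst(elems(i), result)
      i -= 1
    }
    result
  }

  // Safe: one reverse plus one foreach, each a single linear pass on any Seq.
  def fromSeqReversed[A](elems: Seq[A]): Lisst[A] = {
    var result: Lisst[A] = EmptyLisst
    elems.reverse.foreach { elem => result = NonEmptyLisst(elem, result) }
    result
  }

  // Hypothetical helper: copies a Lisst into a standard List for inspection.
  def toList[A](lisst: Lisst[A]): List[A] = {
    @annotation.tailrec
    def loop(rest: Lisst[A], acc: List[A]): List[A] = rest match {
      case NonEmptyLisst(head, tail) => loop(tail, head :: acc)
      case EmptyLisst                => acc.reverse
    }
    loop(lisst, Nil)
  }
}
```

Both functions build the same Lisst. The only difference is how their running time depends on which kind of Seq they are given.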
Implementing map and flatten
The Haskell wikibook implements the bind operation in terms of the Haskell list operations concat and map. In Scala terms, that means implementing flatMap as map followed by flatten. However, we recall from the first blog post that while we can define flatMap in terms of flatten and map, we can conversely define flatten and map in terms of flatMap and the unit function. We're going to take the latter approach, as we did with Maybe. Let's fill in map and flatten now:
```scala
sealed trait Lisst[+A] {
  def flatMap[B](f: A => Lisst[B]): Lisst[B] = ???

  def map[B](f: A => B): Lisst[B] =
    flatMap { a => Lisst(f(a)) }

  def flatten[B](
      implicit asLisstLisst: Lisst[A] <:< Lisst[Lisst[B]])
      : Lisst[B] =
    asLisstLisst(this) flatMap identity
}
```
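You can try the same two derivations on Scala's built-in List, which already provides flatMap. The sketch below expresses map and flatten using nothing but flatMap and a unit function that wraps a value in a one-element List, mirroring the Lisst definitions above. DerivedOpsSketch and its methods are hypothetical names, and List stands in for Lisst.

```scala
object DerivedOpsSketch {
  // The unit function for List: wrap a single value.
  def unit[A](a: A): List[A] = List(a)

  // map, expressed with flatMap and unit only.
  def mapViaFlatMap[A, B](xs: List[A])(f: A => B): List[B] =
    xs.flatMap(a => unit(f(a)))

  // flatten, expressed with flatMap and the identity function.
  def flattenViaFlatMap[A](xss: List[List[A]]): List[A] =
    xss.flatMap(inner => inner)
}
```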
Implementing foreach and flatMap
Now we're nearly ready to implement flatMap. To do this, we are going to introduce an instance method foreach that walks through a Lisst, applying a side-effecting function f to every member of the Lisst in turn. It's a straightforward recursive method on a linked list. You'll notice that we've annotated this method with @tailrec, which makes compilation fail unless the compiler can eliminate the tail recursion. That guarantees the method won't blow out our stack on very long lists. The @tailrec annotation turns out to be very useful here, because it tells us that we have to make the method final for the Scala compiler to apply tail recursion elimination. The compiler is worried that the method might get overridden, even though the Lisst trait is sealed: sealed only restricts where direct subtypes of Lisst can be declared, not whether they can override its methods.
```scala
import scala.annotation.tailrec

/** Walks through the lisst, applying function f to every element in turn. */
@tailrec
final def foreach(f: A => Unit): Unit = this match {
  case NonEmptyLisst(head, tail) => {
    f(head)
    tail.foreach(f)
  }
  case EmptyLisst =>
}
```
With foreach in hand, flatMap walks through this Lisst, applies f to each element, and then walks through each resulting Lisst in turn. It accumulates the results by prepending them to a Scala List, which costs O(1) per element but leaves them in reverse order. A final call to reverseSeq restores the original order:
```scala
def flatMap[B](f: A => Lisst[B]): Lisst[B] = {
  var reverseBs = List[B]()
  this.foreach { a =>
    f(a).foreach { b =>
      reverseBs = b +: reverseBs
    }
  }
  Lisst.reverseSeq(reverseBs)
}
```
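If you'd like to experiment before cloning the repository, the sketch below assembles the fragments shown so far into one compilable unit. It is not the repository's file, and it differs in a few ways:

- everything is nested in a hypothetical LisstSketch object;
- the case classes are public, so the example can be exercised from outside;
- flatten uses an explicit lambda where the post uses identity;
- a toList helper, not in the post, is added for inspecting results.

```scala
import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer

object LisstSketch {
  object Lisst {
    def apply[A](elems: A*): Lisst[A] = reverseSeq(elems.reverse)

    // Prepends each element in turn, so the result holds elems in reverse order.
    private def reverseSeq[A](elems: Seq[A]): Lisst[A] = {
      var result: Lisst[A] = EmptyLisst
      elems.foreach { elem => result = NonEmptyLisst(elem, result) }
      result
    }
  }

  sealed trait Lisst[+A] {
    /** Walks through the lisst, applying function f to every element in turn. */
    @tailrec
    final def foreach(f: A => Unit): Unit = this match {
      case NonEmptyLisst(head, tail) =>
        f(head)
        tail.foreach(f)
      case EmptyLisst =>
    }

    // Collects results in reverse with O(1) prepends, then reverses once.
    def flatMap[B](f: A => Lisst[B]): Lisst[B] = {
      var reverseBs = List[B]()
      this.foreach { a =>
        f(a).foreach { b => reverseBs = b +: reverseBs }
      }
      Lisst.reverseSeq(reverseBs)
    }

    def map[B](f: A => B): Lisst[B] =
      flatMap { a => Lisst(f(a)) }

    def flatten[B](implicit asLisstLisst: Lisst[A] <:< Lisst[Lisst[B]]): Lisst[B] =
      asLisstLisst(this) flatMap { inner => inner }

    // Not in the post: copies the elements into a standard List.
    def toList: List[A] = {
      val buffer = ListBuffer[A]()
      foreach { a => buffer += a }
      buffer.toList
    }
  }

  final case class NonEmptyLisst[+A](head: A, tail: Lisst[A]) extends Lisst[A]
  case object EmptyLisst extends Lisst[Nothing]
}
```

In this sketch, foreach is tail recursive and reverseSeq is a plain loop. As a result, map and flatMap should also cope with long lists, such as one built from 1 to 100000, without growing the stack.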
Testing Lisst[A]
Before moving further on into the Haskell wiki page, let's take a moment to write some tests for the above code. We're going to split testing into two parts. First, we'll write some simple tests to make sure Lisst.flatMap is acting as expected. Then we'll go on to generalize our code for testing Maybe with respect to the monad laws, and apply that testing code to Lisst as well. I'm not going to dwell on the Lisst.flatMap spec, but you can look at it on GitHub here and here. It may be useful to review if you are still not entirely clear on what it means to flatMap a list.
To test Lisst with respect to the monad laws, we will generalize some of the Maybe monad tests that we developed for the previous blog post. Let's quickly review what we did there. Starting from the monads-in-scala-2 Git tag for the project, we launch sbt and run test-only *MaybeMonadLawsSpec. The output from the test is as follows:
[info] MaybeMonadLawsSpec:
[info] Maybe monad with respect to Person data
[info] - should obey left unit monadic law
[info] - should obey right unit monadic law
[info] - should obey associativity monadic law
[info] - should flatten a Maybe[Maybe[_]] according to monadic laws
[info] - should flatMap equivalently to calling map and then flatten
[info] Maybe monad with respect to safe math operations
[info] - should obey left unit monadic law
[info] - should obey right unit monadic law
[info] - should obey associativity monadic law
[info] - should flatten a Maybe[Maybe[_]] according to monadic laws
[info] - should flatMap equivalently to calling map and then flatten
[info] Maybe monad with respect to looking up Numbers and Registrations by Name
[info] - should obey left unit monadic law
[info] - should obey right unit monadic law
[info] - should obey associativity monadic law
[info] - should flatten a Maybe[Maybe[_]] according to monadic laws
[info] - should flatMap equivalently to calling map and then flatten
[info] Maybe monad with respect to looking up Registrations and TaxesOwed by Number
[info] - should obey left unit monadic law
[info] - should obey right unit monadic law
[info] - should obey associativity monadic law
[info] - should flatten a Maybe[Maybe[_]] according to monadic laws
[info] - should flatMap equivalently to calling map and then flatten
As you will recall, we exercise various monad laws with respect to Maybe over 4 different data sets. For each data set, there are five tests. The first three exercise the monad laws proper, and the remaining two are specific to the Maybe type. We will generalize the code that exercises the monad laws proper, refactor the MaybeMonadLawsSpec to use the generalized code, and then write some basic monad laws tests for Lisst.
Abstracting a Monad Type
In order to write generalized test code for a monad, let's first review what a monad is. A monad is defined as having three parts:
- A type constructor M (a higher-kinded type), which yields a type M[A] for any given base type A;
- A unit function, which takes an A and returns an M[A]; and
- A binding operation (i.e., flatMap), which takes an M[A] and a function from A to M[B], and returns an M[B].
These requirements readily translate into a Scala trait:
```scala
import scala.language.higherKinds

trait Monad {
  type M[_]
  def unitFunction[A](a: A): M[A]
  def bindingOperation[A, B](m: M[A], f: (A) => M[B]): M[B]
}
```
It's easy to define singleton objects MaybeMonad and LisstMonad that implement the Monad trait:
```scala
object MaybeMonad extends Monad {
  type M[A] = Maybe[A]
  def unitFunction[A](a: A): M[A] = Just(a)
  def bindingOperation[A, B](m: M[A], f: (A) => M[B]): M[B] =
    m flatMap f
}

object LisstMonad extends Monad {
  type M[A] = Lisst[A]
  def unitFunction[A](a: A): M[A] = Lisst(a)
  def bindingOperation[A, B](m: M[A], f: (A) => M[B]): M[B] =
    m flatMap f
}
```
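Nothing in the Monad trait is tied to Maybe or Lisst. As a quick check, here is a hypothetical OptionMonad instance over Scala's built-in Option. The trait is repeated so that the sketch compiles on its own. OptionMonadSketch and safeDivide are illustrative names, not part of the project.

```scala
import scala.language.higherKinds

object OptionMonadSketch {
  // The Monad trait from above, repeated so this sketch is self-contained.
  trait Monad {
    type M[_]
    def unitFunction[A](a: A): M[A]
    def bindingOperation[A, B](m: M[A], f: (A) => M[B]): M[B]
  }

  // Hypothetical instance for Scala's built-in Option.
  object OptionMonad extends Monad {
    type M[A] = Option[A]
    def unitFunction[A](a: A): M[A] = Some(a)
    def bindingOperation[A, B](m: M[A], f: (A) => M[B]): M[B] =
      m flatMap f
  }

  // A function suitable for bindingOperation: yields None for a zero divisor.
  val safeDivide: Int => Option[Int] = n => if (n == 0) None else Some(100 / n)
}
```

Code written against Monad only ever calls unitFunction and bindingOperation, so it works unchanged with this instance.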
A Method to Test Monad Laws
Let's step back and review some of the testing code we developed previously for Maybe and the monad laws. The function takes these arguments: a description of the test; some test items of type A, a function from A to Maybe[B], and a function from B to Maybe[C]. Within the test, we also generate a test set of Maybes by applying the unit function to every test item, and prepending MaybeNot:
```scala
def maybeShouldObeyMonadLaws[A, B, C](
    testDataDescription: String,
    testItems: Seq[A],
    f: Function1[A, Maybe[B]],
    g: Function1[B, Maybe[C]]): Unit = {
  val maybes = MaybeNot +: (testItems map { Just(_) })
  // perform the tests...
}
```
To generalize this, we will need to add three more arguments: the monad; a name for the monad to use in the test descriptions; and a list of test items of type M[A]. With Maybe, it was easy to generate this kind of test data, but in general, we will need to leave this up to the caller. Here is the generalized method:
```scala
def monadShouldObeyMonadLaws[A, B, C](
    monadName: String,
    monad: Monad)(
    testDataDescription: String,
    testAs: Seq[A],
    testMs: Seq[monad.M[A]],
    f: Function1[A, monad.M[B]],
    g: Function1[B, monad.M[C]]): Unit = {

  behavior of monadName + " monad with respect to " + testDataDescription

  it should "obey left unit monadic law" in {
    testAs foreach { a =>
      { monad.bindingOperation(monad.unitFunction(a), f)
      } should equal {
        f(a)
      }
    }
  }

  it should "obey right unit monadic law" in {
    testMs foreach { m =>
      { monad.bindingOperation(
          m,
          { a: A => monad.unitFunction(a) })
      } should equal {
        m
      }
    }
  }

  it should "obey associativity monadic law" in {
    testMs foreach { m =>
      { monad.bindingOperation(
          monad.bindingOperation(m, f),
          g)
      } should equal {
        monad.bindingOperation(
          m,
          { a: A => monad.bindingOperation(f(a), g) })
      }
    }
  }
}
```
The most interesting thing about this code is that we are using path-dependent types in the signature. Specifically, we refer to monad.M[A], monad.M[B], and monad.M[C], where monad is the Monad instance passed in and M is its higher-kinded type member. For our purposes, monad will be either MaybeMonad or LisstMonad, defined above, and monad.M will resolve to Maybe or Lisst respectively. Note that to use path-dependent types in argument lists this way, we need to declare monad in a separate argument list from the arguments whose types mention monad.M.
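To see the separate-argument-list requirement in action without ScalaTest, here is a sketch of the same three law checks that returns a Boolean instead of registering tests. obeysMonadLaws, LawCheckSketch, and the two Option-based instances are hypothetical names for this illustration. BrokenOptionMonad deliberately has a wrong unit function, so the checks have something to fail on. In Scala 2, moving monad into the same argument list as testMs would be rejected as an illegal dependent method type.

```scala
import scala.language.higherKinds

object LawCheckSketch {
  trait Monad {
    type M[_]
    def unitFunction[A](a: A): M[A]
    def bindingOperation[A, B](m: M[A], f: (A) => M[B]): M[B]
  }

  object OptionMonad extends Monad {
    type M[A] = Option[A]
    def unitFunction[A](a: A): M[A] = Some(a)
    def bindingOperation[A, B](m: M[A], f: (A) => M[B]): M[B] = m flatMap f
  }

  // Deliberately broken: unit discards its argument, violating both unit laws.
  object BrokenOptionMonad extends Monad {
    type M[A] = Option[A]
    def unitFunction[A](a: A): M[A] = None
    def bindingOperation[A, B](m: M[A], f: (A) => M[B]): M[B] = m flatMap f
  }

  // monad sits in its own argument list so that the second list can refer to monad.M.
  def obeysMonadLaws[A, B, C](monad: Monad)(
      testAs: Seq[A],
      testMs: Seq[monad.M[A]],
      f: A => monad.M[B],
      g: B => monad.M[C]): Boolean = {
    val leftUnit = testAs.forall { a =>
      monad.bindingOperation(monad.unitFunction(a), f) == f(a)
    }
    val rightUnit = testMs.forall { m =>
      monad.bindingOperation(m, (a: A) => monad.unitFunction(a)) == m
    }
    val associativity = testMs.forall { m =>
      monad.bindingOperation(monad.bindingOperation(m, f), g) ==
        monad.bindingOperation(m, (a: A) => monad.bindingOperation(f(a), g))
    }
    leftUnit && rightUnit && associativity
  }
}
```

Like the ScalaTest version, this compares monadic values with ==, so it relies on the monad type having structural equality.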
LisstMonadLawsSpec
It's straightforward to rework the MaybeMonadLawsSpec to use this new method. You can view the details here. To test Lisst, we provide all combinations of methods f and g that return Lissts of length 0, 1, and greater than one. For types A, B, and C, we use Int, String, and Double. We construct our test data roughly as follows:
```scala
def toZeroStrings(i: Int): Lisst[String] = Lisst[String]()
def toOneString(i: Int): Lisst[String] = Lisst(i.toString)
def toThreeStrings(i: Int): Lisst[String] = {
  val s = i.toString
  Lisst(s, s, s)
}

def toZeroDoubles(s: String): Lisst[Double] = Lisst[Double]()
def toOneDouble(s: String): Lisst[Double] = Lisst(s.toDouble)
def toThreeDoubles(s: String): Lisst[Double] = {
  val d = s.toDouble
  Lisst(d, d, d)
}

val fgPairs = for (
  f <- Seq(toZeroStrings _, toOneString _, toThreeStrings _);
  g <- Seq(toZeroDoubles _, toOneDouble _, toThreeDoubles _))
  yield (f, g)

val testInts = Seq(0, 1, 2)
val testLissts = Seq(Lisst[Int](), Lisst(0), Lisst(0, 1, 2))
```
Finally, we need to add a little something to produce different test descriptions for each data set, as required by ScalaTest. So we just zipWithIndex the fgPairs, and insert the index into the test descriptions:
```scala
fgPairs.zipWithIndex.foreach {
  case ((f, g), i) =>
    monadShouldObeyMonadLaws(
      "Lisst",
      LisstMonad)(
      "toString and toDouble conversions, " +
        "permutation number " + (i + 1),
      testInts,
      testLissts,
      f,
      g)
}
```
Bunny Invasion
The next section of the Haskell wiki presents a simple example using a function that replicates a single value into a list. This example is rather trivial compared to the work we've gone through to build and test Lisst, so the translation into Scala is presented here without further comment:
```scala
def generation(s: String) = Lisst(s, s, s)

{ Lisst("bunny") flatMap generation
} should equal {
  Lisst("bunny", "bunny", "bunny")
}

{ Lisst("bunny") flatMap generation flatMap generation
} should equal {
  Lisst(
    "bunny", "bunny", "bunny",
    "bunny", "bunny", "bunny",
    "bunny", "bunny", "bunny")
}
```
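Each flatMap with generation triples the number of bunnies, so n generations starting from a single bunny yield 3^n of them. The sketch below checks that pattern using Scala's built-in List in place of Lisst. BunnySketch and afterGenerations are hypothetical names for this illustration.

```scala
object BunnySketch {
  // Same shape as the post's generation, but producing a standard List.
  val generation: String => List[String] = s => List(s, s, s)

  // Applies generation n times via flatMap, starting from one bunny.
  def afterGenerations(n: Int): List[String] =
    (1 to n).foldLeft(List("bunny")) { (bunnies, _) => bunnies flatMap generation }
}
```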
Tic Tac Toe
In the tic tac toe example (noughts and crosses) in the Haskell wiki, the type Board and the function nextConfigs are never defined. This is understandable, as these are implementation details of the game that are not directly relevant to the monadic treatment of lists. We don't want to spend the time to write a full-blown tic tac toe game in Scala either, but we do want to be able to run and test the code. So let's cook up just enough of an implementation to allow us to write and test the nextConfigs method.
The Implementation
First of all, there are 9 positions on the board where a turn can be played. We don't want to worry about any of the details of the positions themselves, just to provide a safe way to enumerate them. Here's some defensive code that satisfies these requirements:
```scala
object Position {
  val minPositionIndex = 1
  val maxPositionIndex = 9

  lazy val allPositions: Seq[Position] =
    (minPositionIndex to maxPositionIndex) map
      { Position(_) }
}

case class Position(positionIndex: Int) {
  assert(
    positionIndex >= Position.minPositionIndex &&
    positionIndex <= Position.maxPositionIndex)
}
```
Because X always goes first, a sequence of the positions played will fully describe the game state. We follow the Haskell example and name this type the Board. We'll keep the constructor and the position sequence private. We'll give the user access to an initial board state, and let them generate new board states with nextConfigs:
```scala
object Board {
  def apply(): Board = new Board(Seq())
}

class Board private (private val moves: Seq[Position]) {
  def nextConfigs: Lisst[Board] = ???
  // ...
}
```
The nextConfigs method itself is quite simple and elegant to write in Scala, and really shows the power of a rich collections API built out from a monadic base. We start with all possible positions; filter out any positions already taken; add each remaining position to the sequence of moves so far; construct a new Board for each new sequence of moves; and finally, construct a Lisst of those new Boards.
```scala
def nextConfigs: Lisst[Board] = {
  val nextBoardsSeq: Seq[Board] = {
    Position.allPositions
  } filterNot {
    moves contains _
  } map {
    moves :+ _
  } map {
    new Board(_)
  }
  Lisst(nextBoardsSeq: _*)
}
```
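If you want to poke at the game logic without the Lisst machinery, here is a sketch of the same idea over Scala's built-in List. It simplifies positions to plain Ints from 1 to 9. It also makes Board a case class, so that two Boards with the same moves compare equal. That equality matters for the should equal checks later in the post, which compare Boards produced by separate calls. The post's Board elides the rest of its body (// ...), so we assume it supplies equivalent equality there. TicTacToeSketch, emptyBoard, and thirdConfigs are hypothetical names.

```scala
object TicTacToeSketch {
  // Simplification: positions are plain Ints 1 to 9 instead of a Position class.
  val allPositions: List[Int] = (1 to 9).toList

  // A case class, so Boards holding the same moves compare equal.
  final case class Board(moves: Vector[Int]) {
    // Every unplayed position, appended to the moves so far.
    def nextConfigs: List[Board] =
      allPositions filterNot { moves contains _ } map { p => Board(moves :+ p) }
  }

  val emptyBoard: Board = Board(Vector.empty)

  // Mirrors thirdConfigsVersion4 from the post, over List instead of Lisst.
  def thirdConfigs(board0: Board): List[Board] =
    for {
      board1 <- board0.nextConfigs
      board2 <- board1.nextConfigs
      board3 <- board2.nextConfigs
    } yield board3
}
```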
Four Versions of thirdConfigs
The Haskell example goes on to define the function thirdConfigs in four different ways. thirdConfigs takes a Board state as input, and produces a list of all the Boards three moves out. These readily translate into Scala:
```scala
def tick(boards: Lisst[Board]): Lisst[Board] =
  boards flatMap { _.nextConfigs }

def thirdConfigsVersion1(board: Board): Lisst[Board] =
  tick(tick(tick(Lisst(board))))

def thirdConfigsVersion2(board: Board): Lisst[Board] =
  Lisst(board) flatMap
    { _.nextConfigs } flatMap
    { _.nextConfigs } flatMap
    { _.nextConfigs }

def thirdConfigsVersion3(board: Board): Lisst[Board] =
  for (
    board0 <- Lisst(board);
    board1 <- board0.nextConfigs;
    board2 <- board1.nextConfigs;
    board3 <- board2.nextConfigs)
  yield board3

def thirdConfigsVersion4(board0: Board): Lisst[Board] =
  for (
    board1 <- board0.nextConfigs;
    board2 <- board1.nextConfigs;
    board3 <- board2.nextConfigs)
  yield board3
```
Let's make sure these four versions produce the same results:
```scala
{ thirdConfigsVersion1(Board())
} should equal {
  thirdConfigsVersion2(Board())
}

{ thirdConfigsVersion1(Board())
} should equal {
  thirdConfigsVersion3(Board())
}

{ thirdConfigsVersion1(Board())
} should equal {
  thirdConfigsVersion4(Board())
}
```
While we won't bother to confirm the full contents of the results, let's at least make sure that the size of the results is right. The number of boards three turns out from an empty board is the number of permutations of 3 positions chosen from the 9 available, P(9, 3) = 9 × 8 × 7 = 504:
```scala
def permutation(n: Int, r: Int): Int =
  Range(n, n - r, -1).fold(1)(_ * _)

{ thirdConfigsVersion1(Board()).size
} should equal {
  permutation(9, 3)
}
```
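The size check above relies on a size method that isn't among the Lisst methods shown in this post. Here is one way it could be provided, sketched as an assumption and using foreach alone. SizeSketch is a hypothetical wrapper that redeclares a minimal Lisst so that the example compiles on its own. The permutation helper is copied from the post.

```scala
import scala.annotation.tailrec

object SizeSketch {
  sealed trait Lisst[+A] {
    @tailrec
    final def foreach(f: A => Unit): Unit = this match {
      case NonEmptyLisst(head, tail) =>
        f(head)
        tail.foreach(f)
      case EmptyLisst =>
    }

    // Hypothetical: counts elements using only foreach.
    def size: Int = {
      var count = 0
      foreach { _ => count += 1 }
      count
    }
  }

  final case class NonEmptyLisst[+A](head: A, tail: Lisst[A]) extends Lisst[A]
  case object EmptyLisst extends Lisst[Nothing]

  // As in the post: n * (n - 1) * ... * (n - r + 1).
  def permutation(n: Int, r: Int): Int =
    Range(n, n - r, -1).fold(1)(_ * _)
}
```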
thirdConfigs with Kleisli Composition
When reading over the Haskell variations on thirdConfigs, I couldn't help thinking that this was a natural fit for Kleisli composition, introduced in the previous blog post (and the previous page of the wikibook).
In the previous blog post, we used an implicit class to pimp a function suitable for flatMap with a kleisliCompose method:
```scala
implicit class MaybeFunction[B, C](f: (B) => Maybe[C]) {
  def kleisliCompose[A](g: (A) => Maybe[B]): (A) => Maybe[C] = {
    a: A =>
      for (b <- g(a); c <- f(b)) yield c
  }
}
```
Now that we have a Monad trait, we can generalize this. Moving the implicit class into the trait, and implementing kleisliCompose with bindingOperation, gives every Monad instance Kleisli composition for free:
```scala
trait Monad {
  // ...
  implicit class KleisliComposition[B, C](f: (B) => M[C]) {
    def kleisliCompose[A](g: (A) => M[B]): (A) => M[C] = {
      a: A => bindingOperation(g(a), f)
    }
  }
}
```
With KleisliComposition in scope for LisstMonad, thirdConfigs is just nextConfigs composed with itself three times:
```scala
import LisstMonad.KleisliComposition

val nc: (Board) => Lisst[Board] = _.nextConfigs
val nnc = nc kleisliCompose nc
val nnnc = nnc kleisliCompose nc
val thirdConfigsVersion5 = nnnc
```
As with the other variations, it should produce the same boards as thirdConfigsVersion1:
```scala
{ thirdConfigsVersion1(Board())
} should equal {
  thirdConfigsVersion5(Board())
}
```
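To see the composition order in isolation, here is a self-contained sketch with the generalized trait and a hypothetical ListMonad instance over Scala's built-in List. In f kleisliCompose g, g runs first, and each of its results is fed to f through bindingOperation. Composing with the unit function on the right leaves f unchanged. KleisliSketch, ListMonad, divisors, plusMinus, signedDivisors, and withUnitFirst are illustrative names only.

```scala
import scala.language.higherKinds

object KleisliSketch {
  trait Monad {
    type M[_]
    def unitFunction[A](a: A): M[A]
    def bindingOperation[A, B](m: M[A], f: (A) => M[B]): M[B]

    implicit class KleisliComposition[B, C](f: (B) => M[C]) {
      // g runs first; each of its results is fed to f via bindingOperation.
      def kleisliCompose[A](g: (A) => M[B]): (A) => M[C] =
        (a: A) => bindingOperation(g(a), f)
    }
  }

  // Hypothetical instance over Scala's built-in List.
  object ListMonad extends Monad {
    type M[A] = List[A]
    def unitFunction[A](a: A): M[A] = List(a)
    def bindingOperation[A, B](m: M[A], f: (A) => M[B]): M[B] = m flatMap f
  }

  import ListMonad.KleisliComposition

  val divisors: Int => List[Int] = n => (1 to n).filter(n % _ == 0).toList
  val plusMinus: Int => List[Int] = n => List(n, -n)

  // divisors runs first, then plusMinus is applied to each divisor.
  val signedDivisors: Int => List[Int] = plusMinus kleisliCompose divisors

  // Running the unit function first should leave divisors unchanged.
  val withUnitFirst: Int => List[Int] =
    divisors kleisliCompose ((n: Int) => ListMonad.unitFunction(n))
}
```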
List Comprehensions
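A Haskell list comprehension such as [x * y | x <- [1, 2, 3], y <- [10, 20]] maps directly onto a Scala for-comprehension, which the compiler rewrites into flatMap and map calls. The sketch below uses Scala's built-in List; ComprehensionSketch is a hypothetical name. Lisst would accept the same generator-only for-comprehension, because it defines flatMap and map. An if guard, however, would also need withFilter (or filter), which Lisst does not define.

```scala
object ComprehensionSketch {
  // Haskell: [x * y | x <- [1, 2, 3], y <- [10, 20]]
  val products: List[Int] =
    for {
      x <- List(1, 2, 3)
      y <- List(10, 20)
    } yield x * y

  // What the for-comprehension above is rewritten into.
  val productsDesugared: List[Int] =
    List(1, 2, 3).flatMap { x => List(10, 20).map { y => x * y } }
}
```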
Scala doesn't have list comprehensions. That's a good thing. It doesn't need them. For-comprehensions and the collections API methods are enough. List comprehensions probably make sense in Haskell, but I'm really not qualified to say.
Wrapping Up
In this post, we've seen how to implement a singly linked list in the context of a monad. We came up with a generalization of what a monad is with the Scala trait Monad, and used that trait to generalize our testing framework for testing monad laws. We also provided a variety of examples of the Lisst monad in action, including in LisstFlatMapSpec, as well as the bunny invasion and tic tac toe examples provided in the Haskell wiki. Amazingly enough, we accomplished all this more or less using only the unit function and binding operation of the Lisst monad. We did make liberal use of the varargs factory method, but this was mainly used for constructing expected values in our tests. This goes to show that there is a lot we can do with lists without even appending, prepending, or concatenating.
At this point, we've thoroughly covered the two most canonical uses of monads: lists and optional values. I'm looking forward to covering slightly less mainstream monads such as State in the near future. We'll most likely skip over the Haskell wiki page on do notation entirely, or at best merge anything interesting there into a post on the IO monad, as the do notation wiki page is mostly devoted to describing a Haskell-specific language feature. Which means IO is next! (Interestingly enough, a moon of Jupiter that goes by the same name is undergoing a major volcanic eruption at the moment.)