Friday, May 26, 2017

The Probability Monad (part 1)

This is an interesting idea: probability distributions can be modeled as monads. The canonical description lives here, but it's very Haskell-heavy. So, if you are trying to grok the probability monad, you might prefer to look at a Scala implementation here.

[Aside: I tried looking at the Haskell code but first had to run
cabal install --dependencies-only --enable-tests
(see here for more information).]

The monad in this Scala library is Distribution[T], where T can be, say, a Double, as in a Gaussian distribution:

  object normal extends Distribution[Double] {
    override def get = rand.nextGaussian()
  }
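To make the idea concrete, here is a minimal sketch of the shape of such a type. It is not the library's source: the trait is stripped down to a protected get plus a sample helper, and the seeded scala.util.Random named rand is our own assumption, chosen so runs are repeatable.

```scala
import scala.util.Random

// A stripped-down, hypothetical Distribution: something that hands out
// a random value of type A each time get is called.
trait Distribution[A] {
  protected def get: A

  // Draw n independent values.
  def sample(n: Int): List[A] = List.fill(n)(get)
}

// Assumed shared random source, seeded so runs are repeatable.
val rand = new Random(42)

// Standard normal distribution: each draw is one Gaussian variate.
object normal extends Distribution[Double] {
  override def get = rand.nextGaussian()
}

val xs = normal.sample(100000)
val mean = xs.sum / xs.size
```

Averaging many draws should give something close to 0, the mean of the standard normal.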

T could be something more interesting. For instance, here the Distribution monad is parameterized with a List[Int]:

  /**
   * You roll a 6-sided die and keep a running sum. What is the probability the
   * sum reaches exactly 30?
   */
  def dieSum(rolls: Int): Distribution[List[Int]] = {
    always(List(0)).markov(rolls)(runningSum => for {
      d <- die
    } yield (d + runningSum.head) :: runningSum)
  }

  def runDieSum = dieSum(30).pr(_ contains 30)
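To see what a single draw from dieSum(30) looks like, here is a plain-Scala sketch that skips the monad entirely. The helper oneDieSumSample and the fixed seed are hypothetical; only the update rule, (d + runningSum.head) :: runningSum, comes from the example.

```scala
import scala.util.Random

// Hypothetical helper (not part of the library): one draw of dieSum(rolls),
// simulated directly. Each step prepends the new running total, mirroring
// `(d + runningSum.head) :: runningSum`, so the list is newest-first.
def oneDieSumSample(rolls: Int, rng: Random): List[Int] =
  (1 to rolls).foldLeft(List(0)) { (runningSum, _) =>
    val d = rng.nextInt(6) + 1 // a fair six-sided die: 1 to 6
    (d + runningSum.head) :: runningSum
  }

// 31 running totals: the latest total first and the seed 0 last.
val sample = oneDieSumSample(30, new Random(7))
val hitThirty = sample.contains(30)
```

Whether this particular draw "counts" for the question is just whether 30 appears somewhere in the list.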

Simulation, simulation, simulation

The method pr runs a simulation: it draws a configurable number of samples from the distribution (10 000 by default). We then filter them for those whose running sums include exactly 30 and calculate the resulting probability.

Filtering means traversing the 10 000 samples and testing each one for a running sum of exactly 30. Each sample is produced by a recursive structure 30 levels deep, one level per roll of the die. Any more rolls would be pointless: every roll adds at least 1, so after 30 rolls the total is already at least 30 and further rolls can only push it past 30.
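The filter-and-count step can be sketched without any monads at all. This is a hypothetical stand-in (estimateHit is our name, not the library's): it builds each run's list of running totals and counts the runs that contain the target.

```scala
import scala.util.Random

// Hypothetical, monad-free estimate in the spirit of dieSum(30).pr(_ contains 30):
// simulate `samples` independent runs and return the fraction in which the
// running total hits `target` exactly.
def estimateHit(target: Int, rolls: Int, samples: Int, rng: Random): Double = {
  def oneRun(): List[Int] =
    (1 to rolls).foldLeft(List(0)) { (runningSum, _) =>
      (rng.nextInt(6) + 1 + runningSum.head) :: runningSum
    }
  val hits = List.fill(samples)(oneRun()).count(_.contains(target))
  hits.toDouble / samples
}

// 10 000 runs of 30 rolls each, as in the text.
val p = estimateHit(target = 30, rolls = 30, samples = 10000, rng = new Random(1))
```

Because the result is a sample proportion, rerunning with a different seed gives a slightly different answer.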

That's the high-level description. Let's drill down.

State monads again

This recursive structure is a state monad. The monads are created by the recursive calls to markov(). Each call creates a new monad by calling flatMap on the current one. The get method of this new, inner monad takes the value of its outer monad, passes it to the function given to flatMap, and in turn calls get on the result.
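Here is a minimal sketch of that mechanism, modeled on the description above rather than copied from the library: flatMap wraps the current distribution in a new one whose get pulls a value from the outer distribution, and markov stacks one such layer per step. The seeded rng, always and die definitions are our own assumptions for the sketch.

```scala
import scala.util.Random

trait Distribution[A] { self =>
  protected def get: A

  def sample(n: Int): List[A] = List.fill(n)(get)

  // map: the new distribution transforms each value drawn from self.
  def map[B](f: A => B): Distribution[B] = new Distribution[B] {
    override def get: B = f(self.get)
  }

  // flatMap: the new, inner distribution's get takes a value from its outer
  // distribution (self), passes it to f, then calls get on f's result.
  def flatMap[B](f: A => Distribution[B]): Distribution[B] = new Distribution[B] {
    override def get: B = f(self.get).get
  }

  // markov: wrap this distribution in one flatMap layer, then recurse on
  // the new layer until n layers have been added.
  def markov(n: Int)(f: A => Distribution[A]): Distribution[A] =
    if (n == 0) this else this.flatMap(f).markov(n - 1)(f)
}

// Assumed helpers for the sketch.
val rng = new Random(3)

def always[A](value: A): Distribution[A] = new Distribution[A] {
  override def get: A = value
}

val die: Distribution[Int] = new Distribution[Int] {
  override def get: Int = rng.nextInt(6) + 1
}

// Same shape as the library example: the for/yield desugars to die.map(...).
def dieSum(rolls: Int): Distribution[List[Int]] =
  always(List(0)).markov(rolls)(runningSum => for {
    d <- die
  } yield (d + runningSum.head) :: runningSum)

val draws = dieSum(30).sample(100)
```

Nothing is rolled while dieSum(30) is being built; the 30 nested get calls only happen inside sample.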

Having created this inner monad, markov() is called on it and we start the next level of recursion, until we have done so 30 times. It is this chain of get calling get, once a value is finally requested, that builds up the state.

Consequently, the outermost monad is a constant Distribution that holds List(0); this is what a call to the outermost get returns. However, get is not publicly accessible, so we can only reach it indirectly by calling the monad's public methods.

In short, we have something a little like a linked list. The outermost monad contains the "seed", List(0). Each inner monad holds a reference to its outer monad, whose value it reaches via get. The links only point outwards: an outer monad holds no reference to the monads built on top of it.

Note that it is the innermost monad that is returned to the caller of dieSum, in effect turning the structure inside out.

Anyway, the next job is to filter the structure. This creates a new monad (referencing the erstwhile innermost monad) to do the job, but remember that these monads are lazy, so nothing happens yet. It's only when we call a sample method on this monad that something starts to happen. At this point, get is called and we work our way up the get-chain until we reach the outermost monad, which contains List(0). Then we "pop" each monad, executing the runningSum => function on the result of the monad before. This is where we roll the die, add it to the running total and prepend the new total to the List.

If the filter monad's given predicate is not met, we keep trying the whole thing again until it is.
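Both points, laziness and retrying, can be seen in a minimal sketch. It assumes a stripped-down trait of our own, and to keep it short it filters a single die roll rather than the full dieSum chain; the rolls counter is hypothetical instrumentation, not part of the library.

```scala
import scala.util.Random

trait Distribution[A] { self =>
  protected def get: A

  def sample(n: Int): List[A] = List.fill(n)(get)

  // filter: a new distribution whose get keeps drawing from self, throwing
  // away any value that fails pred and trying again until one passes.
  def filter(pred: A => Boolean): Distribution[A] = new Distribution[A] {
    override def get: A = {
      var s = self.get
      while (!pred(s)) s = self.get
      s
    }
  }
}

val rng = new Random(5)
var rolls = 0 // how many times the die has actually been rolled

val die: Distribution[Int] = new Distribution[Int] {
  override def get: Int = { rolls += 1; rng.nextInt(6) + 1 }
}

// Building the filtered distribution rolls nothing yet...
val evenRoll = die.filter(_ % 2 == 0)
val rollsBeforeSampling = rolls // still 0

// ...only sampling drives get, and each rejected odd roll costs an extra roll.
val evens = evenRoll.sample(1000)
```

This is rejection sampling: it is simple and always correct, but the cost grows as the predicate becomes harder to satisfy.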

Finally, we count the results that meet our predicate and divide by the total number of runs. Evidently, we've taken a frequentist approach to probabilities here.
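For comparison, this particular question also has an exact answer that needs no sampling. Here is a minimal sketch by dynamic programming (exactHit is our own hypothetical helper): the running sum can only land on n by first landing on one of n-1, ..., n-6 and then rolling the difference, each with probability 1/6.

```scala
// Exact probability that the running sum of fair six-sided die rolls,
// starting from 0, ever equals `target`:
//   hit(0) = 1,  hit(n) = (hit(n-1) + ... + hit(n-6)) / 6,  ignoring negative indices.
def exactHit(target: Int): Double = {
  val hit = Array.fill(target + 1)(0.0)
  hit(0) = 1.0 // the sum starts at 0
  for (n <- 1 to target) {
    hit(n) = (1 to 6).filter(k => n - k >= 0).map(k => hit(n - k)).sum / 6
  }
  hit(target)
}

val exact = exactHit(30)
```

The result is close to 2/7 (about 0.286), the reciprocal of the die's mean of 3.5, so a 10 000-sample estimate should usually land within about one percentage point of it.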
