Notes regarding dynamic programming and performance in Haskell

Posted on April 2, 2022

As some of you may know, some competitive programming sites allow you to use “exotic” languages. Recently I spent some of my spare time solving HackerRank’s Interview Preparation Kit challenges in Haskell.

In this post I will talk about writing dynamic programming algorithms in Haskell, using HackerRank’s Decibinary Numbers problem as an example.

The solution

Let’s go straight to the solution.

Let f(i, j) be the number of decibinary numbers that have value i and whose highest digit position has weight at most 2^j (in other words, numbers with at most j+1 digits). We can compute f(i, j) with the following recurrence:

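-- Direct recursive definition: overlapping subproblems are recomputed every time.
f :: Int -> Int -> Int
f i j | 0 <= i && i <= 9 && j == 0  = 1
      | j == 0                      = 0
      | j > 0                       = sum [f (i - k*2^j) (j - 1) | k <- [1 .. (min 9 (i `div` (2^j)))]]
                                      + f i (j - 1)
      | otherwise                   = error "Impossible"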
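In conventional notation the same recurrence reads as follows. The f(i, j-1) term corresponds to putting digit 0 at position j, and the k-th summand to putting digit k (worth k·2^j) there; the upper limit keeps k a valid digit and keeps i - k·2^j non-negative.

```latex
f(i, 0) =
\begin{cases}
  1 & 0 \le i \le 9 \\
  0 & \text{otherwise}
\end{cases}
\qquad
f(i, j) = f(i, j-1) + \sum_{k=1}^{\min\left(9,\ \lfloor i / 2^j \rfloor\right)} f\left(i - k \cdot 2^j,\ j-1\right) \quad (j > 0)
```

For example, f(2, 1) = f(2, 0) + f(0, 0) = 1 + 1 = 2, which counts the two decibinaries 2 and 10.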

How to get to the final answer from f(i, j) is out of the scope of this post. Let’s see how we can compute f(i, j) efficiently in Haskell.

Memoization

Most of the time, straightforward (or naive) implementations were fast enough for these challenges: instead of manually dispatching the order of evaluation, you can usually get away with memoization.

Memoization in Haskell is easy with lazy arrays. You create a cache array whose elements are computed by the function f, and then write the definition of f by replacing every recursive call in the equation with an index into the cache array:

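import Data.Array

-- m = ...
-- n = ...

-- Data.Array is boxed and lazy: every cell starts out as an unevaluated
-- call to f and is computed at most once, the first time it is read.
cache :: Array (Int, Int) Int
cache = listArray ((0, 0), (m, n)) [f i j | i <- [0..m], j <- [0..n]]

-- The recurrence from above, with each recursive call replaced by a cache lookup.
f :: Int -> Int -> Int
f i j | 0 <= i && i <= 9 && j == 0  = 1
      | j == 0                      = 0
      | j > 0                       = sum [cache ! (i - k*2^j, j - 1) | k <- [1 .. (min 9 (i `div` (2^j)))]]
                                      + cache ! (i, j - 1)
      | otherwise                   = error "Impossible"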
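When m and n are only known at run time, the same lazy-array trick also works with the table as a local binding. The sketch below is an illustration, not the code used for the timings in this post; decibinaryTable and cell are hypothetical names. It relies on Data.Array being boxed and lazy: a UArray would force every cell while the array is being built, so a self-referential UArray definition would loop.

```haskell
import Data.Array (Array, listArray, range, (!))

-- Memo table for 0 <= i <= m, 0 <= j <= n.
-- Each cell is a thunk evaluated at most once, on first access, so
-- `cell` can read from `table` while `table` is still being defined.
decibinaryTable :: Int -> Int -> Array (Int, Int) Int
decibinaryTable m n = table
  where
    bnds = ((0, 0), (m, n))
    -- range yields indices in the same row-major order listArray fills them
    table = listArray bnds [cell i j | (i, j) <- range bnds]

    cell i 0 = if i <= 9 then 1 else 0
    cell i j =
      table ! (i, j - 1)                  -- digit 0 at position j
        + sum [table ! (i - k * 2 ^ j, j - 1) | k <- [1 .. min 9 (i `div` 2 ^ j)]]
```

For example, decibinaryTable 10 3 ! (10, 1) is 5, one for each of 18, 26, 34, 42 and 50.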

For the largest test case this implementation exceeds the time limit (TLE): it takes 2.363s on my PC with the -O2 optimization flag.

Manual dispatching (the idiomatic Haskell way?)

I have to admit that doing manual dispatching in a pure functional way is hard. For simpler DP problems, like the one in this Stack Overflow question, where you can compute the cache matrix row by row and only need the final result, the accepted answer:

{-# LANGUAGE BangPatterns #-}
import Data.Vector.Unboxed
import Prelude hiding (replicate, tail, scanl)

pascal :: Int -> Int
pascal !n = go 1 ((replicate (n+1) 1) :: Vector Int) where
  go !i !prevRow
    | i <= n    = go (i+1) (scanl f 1 (tail prevRow))
    | otherwise = prevRow ! n
  f x y = (x + y) `rem` 1000000

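works well. However, in this case we need to construct the full table. Unfortunately, the vector library has no convenient way to initialize a two-dimensional table while referring to previously computed cells: the closest thing is constructN, which builds a one-dimensional vector by passing the already-constructed prefix to a generator function, so a 2D table has to be flattened by hand. Otherwise, AFAIK the best you can do is to create each row and append them together, which requires a lot of memory copying and is not very efficient.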

For this HackerRank challenge I didn’t bother to try this method, since I would have to schedule the execution order manually anyway, and then there is really no advantage over implementing it in an imperative style.
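For comparison, here is a sketch of the pure approach using constructN from Data.Vector.Unboxed, which repeatedly calls a generator function on the part of the vector built so far. The table is flattened so that cell (i, j) lives at index i * (n + 1) + j. decibinaryVec is a hypothetical name, and this version is not benchmarked here.

```haskell
import qualified Data.Vector.Unboxed as U

-- Flattened (m+1) x (n+1) table: cell (i, j) is stored at index i * (n + 1) + j.
-- U.constructN passes `gen` the prefix built so far, so `gen` may only read
-- cells with a smaller flat index. Both (i, j-1) and (i - k*2^j, j-1) with
-- k >= 1 satisfy that in this row-major layout.
decibinaryVec :: Int -> Int -> U.Vector Int
decibinaryVec m n = U.constructN ((m + 1) * (n + 1)) gen
  where
    gen prefix =
      let idx    = U.length prefix              -- flat index of the cell being built
          (i, j) = idx `divMod` (n + 1)
          at a b = prefix U.! (a * (n + 1) + b)
      in if j == 0
           then (if i <= 9 then 1 else 0)
           else at i (j - 1)
                  + sum [at (i - k * 2 ^ j) (j - 1) | k <- [1 .. min 9 (i `div` 2 ^ j)]]
```

A cell is read back with v U.! (i * (n + 1) + j). Unlike the row-appending approach, no intermediate vectors are copied.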

Manual dispatching (imperative ST monad)

The following code also contains a couple of other optimizations, namely:

  • computing the running sum of f over i (fSum) at the same time;
  • breaking out of the table generation once maxQueries is reached (this is needed because we no longer rely on lazy evaluation to compute only the cells that are actually used).

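import Control.Monad.ST (ST, runST)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Maybe (MaybeT (..))
import Data.Array.ST (STUArray, newArray, readArray, writeArray)
import Data.Array.Unboxed (UArray)
import Data.Array.Unsafe (unsafeFreeze)
import Data.Foldable (for_)

-- m, n, maxQueries and maxBinaryIndex are defined elsewhere (not shown).

f :: UArray (Int, Int) Int
fSum :: UArray Int Int
(f, fSum) = runST $ do
  -- DP table, zero-initialised; base cases f(i, 0) = 1 for i <= 9 and f(0, j) = 1
  f <- newArray ((0, 0), (m, n)) 0 :: ST s (STUArray s (Int, Int) Int)
  for_ [0..9] $ \i -> writeArray f (i, 0) 1
  for_ [0..n] $ \j -> writeArray f (0, j) 1
  -- fSum ! i accumulates f(i, maxBinaryIndex i) over 0..i;
  -- cells past the point where the loop stops keep the initial value maxQueries
  fSum <- newArray (0, m) maxQueries :: ST s (STUArray s Int Int)
  writeArray fSum 0 1

  _ <- runMaybeT $ for_ [1..] $ \i -> do
    prevSum <- lift $ readArray fSum (i - 1)
    if prevSum >= maxQueries
      then MaybeT $ return Nothing   -- enough values computed: break out of the loop
      else lift $
        for_ [0..n] $ \j -> do
          case (i, j) of
            (i, j) | i > 9 && j == 0 -> writeArray f (i, j) 0
                   | j == 0 -> writeArray f (i, j) 1
                   | otherwise -> do part1 <- readArray f (i, j-1)
                                     part2 <- sum <$> sequence [readArray f (i - k*2^j, j - 1) | k <- [1 .. (min 9 (i `div` (2^j)))]]
                                     writeArray f (i, j) (part1 + part2)
    cur <- lift $ readArray f (i, maxBinaryIndex i)
    lift $ writeArray fSum i (prevSum + cur)

  -- Convert to immutable arrays without copying; safe because neither array is written afterwards
  ff <- unsafeFreeze f
  ffSum <- unsafeFreeze fSum
  return (ff, ffSum)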

This implementation takes only 0.768s.

Lack of “break” in for loops

Loop constructs such as for_ are really just traverse_ with its arguments flipped. To break out of a loop, we can layer the MaybeT monad transformer over the ST monad: an iteration that returns Nothing skips all the remaining ones.

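_ <- runMaybeT $ for_ [1..] $ \i -> do
  if earlyBreakCondition
    then MaybeT $ return Nothing  -- "break": the remaining iterations never run
    else lift $ do
      -- ... loop body, running in the underlying ST monad ...
      return ()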
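A complete, runnable instance of this pattern is sketched below; countUntil is a hypothetical example unrelated to the decibinary problem. It adds 1, 2, 3 and so on to an STRef and stops the otherwise infinite for_ as soon as the total reaches a limit. This terminates because for_ is built lazily with foldr, and once one step returns Nothing, MaybeT's bind never runs the rest.

```haskell
import Control.Monad.ST (runST)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Maybe (MaybeT (..))
import Data.Foldable (for_)
import Data.STRef (modifySTRef', newSTRef, readSTRef)

-- Count how many leading elements of [1 ..] get added before the running
-- total reaches `limit`. The loop over the infinite list is cut short by MaybeT.
countUntil :: Int -> Int
countUntil limit = runST $ do
  total <- newSTRef 0
  count <- newSTRef 0
  _ <- runMaybeT $ for_ [1 ..] $ \i -> do
    t <- lift $ readSTRef total
    if t >= limit
      then MaybeT $ return Nothing      -- "break": no further iterations run
      else lift $ do
        modifySTRef' total (+ i)
        modifySTRef' count (+ 1)
  readSTRef count
```

For example, countUntil 10 evaluates to 4, since 1 + 2 + 3 + 4 = 10.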
A few other performance notes:

  • using unboxed arrays (STUArray and UArray) saves around 0.4s compared to the boxed versions (STArray and Array)
  • unsafeFreeze: it avoids copying the array when freezing it, although a single copy is insignificant compared to the cost of the dynamic programming algorithm itself
  • for large inputs you should always use ByteString instead of String-based IO; it saved me around 0.3s in this case
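A minimal sketch of ByteString-based input parsing, assuming the input is a sequence of whitespace-separated integers; readInts is a hypothetical helper name. Data.ByteString.Char8.readInt parses an Int from the front of a ByteString and returns it together with the remaining input.

```haskell
import qualified Data.ByteString.Char8 as B
import Data.Maybe (mapMaybe)

-- Split on whitespace (spaces and newlines) and parse each token as an Int.
-- Tokens that do not start with an integer are silently dropped by mapMaybe.
readInts :: B.ByteString -> [Int]
readInts = map fst . mapMaybe B.readInt . B.words
```

For example, readInts <$> B.getContents reads every integer from standard input in one go.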