david wong

Hey! I'm David, cofounder of zkSecurity and the author of the Real-World Cryptography book. I was previously a crypto architect at O(1) Labs (working on the Mina cryptocurrency), before that I was the security lead for Diem (formerly Libra) at Novi (Facebook), and a security consultant for the Cryptography Services of NCC Group. This is my blog about cryptography and security and other related topics that I find interesting.

State monads in OCaml posted November 2022

Previously I talked about monads, which are just a way to create a "container" type (that contains some value) and let people chain computations within that container. The pattern seems mostly useful in functional languages, as they are often limited in the ways they can do things.

Little did I know that there's more to monads, or at least that monads are so general that they can be used in all sorts of ways. One example of this is state monads.

A state monad is a monad which is defined on a type that looks like this:

type 'a t = state -> 'a * state

In other words, the type is actually a function that performs a state transition and also returns a value (of type 'a).

When we act on state monads, we're not really modifying a value but a function instead, which can be brain-melting.
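To make this concrete, here is a minimal sketch, assuming a state that is just an int counter (a simplification chosen for illustration, not the state used later in this post). The hypothetical tick value has type int t, yet defining it runs nothing: it is only a function, and it does something only once we apply it to an initial state.

```ocaml
(* Assumption: the state is a bare int counter, for illustration only. *)
type state = int

(* A state-monad value is a function from a state to a value and a new state. *)
type 'a t = state -> 'a * state

(* Hypothetical [tick]: returns the current counter and increments it.
   Defining it does not run anything, it is just a function. *)
let tick : int t = fun state -> (state, state + 1)

(* "Running" the computation means applying it to an initial state. *)
let () =
  let value, new_state = tick 10 in
  Printf.printf "value: %d, new state: %d\n" value new_state
```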


The bind and return functions are defined very differently due to this.

The return function should return a function (respecting our monad type) that does nothing with the state:

let return a = fun state -> (a, state)
let return a state = (a, state) (* same as above *)

This has the correct type signature, val return : 'a -> 'a t (where, remember, 'a t is state -> 'a * state). So all good.
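As a quick sanity check, the sketch below (assuming the same { next : int } counter state used in the full example later in this post) runs a return computation and confirms that the state that comes out is the one that went in.

```ocaml
(* Assumption: a counter state, as in the full example of this post. *)
type state = { next : int }
type 'a t = state -> 'a * state

(* [return] wraps a value in a state transition that leaves the state untouched. *)
let return (a : 'a) : 'a t = fun state -> (a, state)

let () =
  let s0 = { next = 5 } in
  let value, s1 = (return "hello") s0 in
  (* the state that comes out is the same as the state that went in *)
  assert (s1 = s0);
  print_endline value
```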

The bind function is much harder to parse. First, recall its type signature:

val bind : 'a t -> f:('a -> 'b t) -> 'b t

which we can expand, to help us understand what it means when we're dealing with a monad type that holds a function:

val bind : (state -> 'a * state) -> f:('a -> (state -> 'b * state)) -> (state -> 'b * state)

You should probably spend a few minutes internalizing this type signature. To put it in other words: bind takes a state transition function t, and another function f that takes the value returned by t and produces another state transition (which returns a value of type 'b).

The result is a new state transition function. That new state transition function can be seen as the chaining of the first function and the additional one f.

OK let's write it down now:

let bind t ~f = fun state ->
    (* apply the first state transition first *)
    let a, transient_state = t state in
    (* and then the second *)
    let b, final_state = f a transient_state in
    (* return these *)
    (b, final_state)

Hopefully that makes sense: we're really just using bind to chain state transitions and produce a larger and larger state transition function (our monad type t).
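To see the chaining in action, here is a minimal sketch, again assuming an int counter as the state and a hypothetical tick transition that returns the counter and increments it. Binding two ticks shows that the second transition sees the state left behind by the first. The bind below returns f a transient_state directly, which is equivalent to destructuring and rebuilding the pair as above.

```ocaml
(* Assumption: a bare int counter as the state. *)
type state = int
type 'a t = state -> 'a * state

let bind (t : 'a t) ~(f : 'a -> 'b t) : 'b t =
 fun state ->
  (* run the first transition, then feed its value and state to [f] *)
  let a, transient_state = t state in
  f a transient_state

(* Hypothetical transition: read the counter and bump it by one. *)
let tick : int t = fun state -> (state, state + 1)

(* Chain two ticks: the second one runs on the state left by the first. *)
let two_ticks : (int * int) t =
  bind tick ~f:(fun first ->
      bind tick ~f:(fun second -> fun state -> ((first, second), state)))

let () =
  let (first, second), final_state = two_ticks 0 in
  Printf.printf "first=%d second=%d final=%d\n" first second final_state
```

Starting from 0, the first tick returns 0, the second returns 1, and the final state is 2.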


What does that look like in practice? Most likely, when a state transition returns a value, we want to make that value available to the rest of the scope. That is, we really want to write code that looks like this:

let run state =
    (* use the state to create a new variable *)
    let (a, state) = new_var () state in
    (* use the state to negate variable a *)
    let (b, state) = negate a state in
    (* use the state to add a and b together *)
    let (c, state) = add a b state in
    (* return c and the final state *)
    (c, state)

where run is a function that takes a state, applies a number of state transitions to it, and returns the new state along with a value produced during that computation. The important takeaway is that we want to apply these state transition functions to values that were created earlier, at different points in the computation.

Also, if that helps, here are the signatures of our imaginary state transition functions:

val new_var : unit -> state -> var * state
val negate : var -> state -> var * state
val add : var -> var -> state -> var * state
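Once their non-state arguments are applied, each of these functions is exactly a value of our monad type: new_var () has type var t, negate a has type var t, and so on. The sketch below lets the compiler check that claim, assuming a { next : int } counter as the state and int as the var type; the Transitions module type and the Impl module are hypothetical names introduced for illustration.

```ocaml
(* Assumptions: a concrete state and var type, chosen only to make this compile. *)
type state = { next : int }
type var = int
type 'a t = state -> 'a * state

(* The imaginary signatures, rewritten in terms of the monad type [t]. *)
module type Transitions = sig
  val new_var : unit -> var t
  val negate : var -> var t
  val add : var -> var -> var t
end

(* A hypothetical implementation; the compiler checks it matches [Transitions]. *)
module Impl : Transitions = struct
  let new_var () state = (state.next, { next = state.next + 1 })
  let negate v state = (-v, state)
  let add v1 v2 state = (v1 + v2, state)
end
```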

Rewriting the previous example with our state monad, we should have something like this:

let run =
    bind (new_var ()) ~f:(fun a ->
        bind (negate a) ~f:(fun b -> bind (add a b) ~f:(fun c -> 
            return c)))

As I explained in my previous post on monads, this can be written more clearly with a let%bind syntax extension (for example the one provided by Jane Street's ppx_let):

let run =
    let%bind a = new_var () in
    let%bind b = negate a in
    let%bind c = add a b in
    return c
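If you'd rather not depend on a preprocessor, OCaml 4.08 and later support binding operators in the language itself. The following is a sketch, not the post's own code: it defines ( let* ) on top of the same bind and uses simple counter-based versions of new_var, negate and add (assumptions made so the example runs on its own).

```ocaml
(* Assumption: a counter state, as in the full example of this post. *)
type state = { next : int }
type 'a t = state -> 'a * state

let return (a : 'a) : 'a t = fun state -> (a, state)

let bind (t : 'a t) ~(f : 'a -> 'b t) : 'b t =
 fun state ->
  let a, state = t state in
  f a state

(* Binding operator (OCaml >= 4.08): [let* x = t in body] means [bind t ~f:(fun x -> body)]. *)
let ( let* ) t f = bind t ~f

(* Assumed transitions: allocate a fresh variable, negate, add. *)
let new_var () : int t = fun state -> (state.next, { next = state.next + 1 })
let negate var : int t = fun state -> (-var, state)
let add var1 var2 : int t = fun state -> (var1 + var2, state)

let run : int t =
  let* a = new_var () in
  let* b = negate a in
  let* c = add a b in
  return c

let () =
  let c, final_state = run { next = 2 } in
  Printf.printf "c: %d, next: %d\n" c final_state.next
```

Because let* desugars to a call to ( let* ), this is the same nested chain of bind calls as the ppx version, just spelled with standard syntax.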

And so now we see the difference: monads are really just a way to do things we can already do, but without having to pass the state around explicitly.

It can be really hard to internalize how the previous code is equivalent to the non-monadic example. So here is a complete example you can play with, which also inlines the logic of bind and return to show how they successfully thread the state through. (It probably looks nicer on GitHub.)

type state = { next : int }
(** a state is just a counter *)

type 'a t = state -> 'a * state
(** our monad is a state transition *)

(* now we write our monad API *)

let bind (t : 'a t) ~(f : 'a -> 'b t) : 'b t =
 fun state ->
  (* apply the first state transition first *)
  let a, transient_state = t state in
  (* and then the second *)
  let b, final_state = f a transient_state in
  (* return these *)
  (b, final_state)

let return (a : 'a) : 'a t = fun state -> (a, state)

(* here's some state transition functions to help drive the example *)

let new_var _ (state : state) =
  let var = state.next in
  let state = { next = state.next + 1 } in
  (var, state)

let negate var (state : state) = (0 - var, state)
let add var1 var2 state = (var1 + var2, state)

(* Now we write things in an imperative way, without monads.
   Notice that we pass the state and return the state all the time, which can be tedious.
*)

let () =
  let run state =
    (* use the state to create a new variable *)
    let a, state = new_var () state in
    (* use the state to negate variable a *)
    let b, state = negate a state in
    (* use the state to add a and b together *)
    let c, state = add a b state in
    (* return c and the final state *)
    (c, state)
  in
  let init_state = { next = 2 } in
  let c, _ = run init_state in
  Format.printf "c: %d\n" c

(* We can write the same with our monad type [t]: *)

let () =
  let run =
    bind (new_var ()) ~f:(fun a ->
        bind (negate a) ~f:(fun b -> bind (add a b) ~f:(fun c -> return c)))
  in
  let init_state = { next = 2 } in
  let c, _ = run init_state in
  Format.printf "c2: %d\n" c

(* To understand what the above code gets translated to, we can inline the logic of the [bind] and [return] functions.
   But to do that more cleanly, we should start from the end and work backwards.
*)
let () =
  let run =
    (* fun c -> return c *)
    let _f1 c = return c in
    (* same as *)
    let f1 c state = (c, state) in
    (* fun b -> bind (add a b) ~f:f1 *)
    (* remember, [a] is in scope, so we emulate it by passing it as an argument to [f2] *)
    let f2 a b state =
      let c, state = add a b state in
      f1 c state
    in
    (* fun a -> bind (negate a) ~f:(f2 a) *)
    let f3 a state =
      let b, state = negate a state in
      f2 a b state
    in
    (* bind (new_var ()) ~f:f3 *)
    let f4 state =
      let a, state = new_var () state in
      f3 a state
    in
    f4
  in
  let init_state = { next = 2 } in
  let c, _ = run init_state in
  Format.printf "c3: %d\n" c

(* If we didn't work backwards, it would look like this: *)
let () =
  let run state =
    let a, state = new_var () state in
    (fun state ->
      let b, state = negate a state in
      (fun state ->
        let c, state = add a b state in
        (fun state -> (c, state)) state)
        state)
      state
  in
  let init_state = { next = 2 } in
  let c, _ = run init_state in
  Format.printf "c4: %d\n" c
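The helpers above build each new state by hand. A common alternative is to define two primitives, here called get (read the state) and put (replace the state), and express every other transition with bind and return. The sketch below reimplements new_var this way; get and put are hypothetical helpers that are not part of the example above.

```ocaml
(* Assumption: the same counter state as in the example above. *)
type state = { next : int }
type 'a t = state -> 'a * state

let return (a : 'a) : 'a t = fun state -> (a, state)

let bind (t : 'a t) ~(f : 'a -> 'b t) : 'b t =
 fun state ->
  let a, state = t state in
  f a state

(* Hypothetical primitives: [get] returns the state, [put] replaces it. *)
let get : state t = fun state -> (state, state)
let put (new_state : state) : unit t = fun _ -> ((), new_state)

(* [new_var] rebuilt from [get] and [put], without touching the state by hand. *)
let new_var () : int t =
  bind get ~f:(fun state ->
      bind (put { next = state.next + 1 }) ~f:(fun () -> return state.next))

let () =
  let a, state = new_var () { next = 2 } in
  Printf.printf "a: %d, next: %d\n" a state.next
```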