Monad transformer commutativity
What is commutativity?
Commutativity means we can switch the order of two things before we smoosh them together and get the same result as if we smooshed them together in their original order. A more rigorous definition is not required for the contents of this post.
Addition over the reals is commutative - here is one example:
1 + 2 = 2 + 1
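To make the idea concrete, here is a minimal sketch that contrasts a commutative operation with one that is not. The function names addCommutes and appendCommutes are made up for this illustration, and Int stands in for the reals.

```haskell
-- Addition gives the same result in either order.
addCommutes :: Int -> Int -> Bool
addCommutes x y = x + y == y + x

-- List concatenation generally does not.
appendCommutes :: String -> String -> Bool
appendCommutes xs ys = xs ++ ys == ys ++ xs
```

In GHCi, addCommutes 1 2 evaluates to True, while appendCommutes "ab" "cd" evaluates to False because "abcd" and "cdab" differ.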
What is monad transformer commutativity?
If we stack up two monad transformers, the transformers commute if the stack’s result type has the same “shape” regardless of the order in which we stacked the transformers.
For me, this is easier to understand with examples. Here is a script we can load directly into GHCi via stack <script_name>.hs:
#!/usr/bin/env stack
-- stack --resolver lts-9.1 --install-ghc exec ghci --package transformers
import Control.Monad.Trans.Identity (IdentityT(..))
import qualified Control.Monad.Trans.Identity as Identity
import Control.Monad.Trans.Writer (Writer(..), WriterT(..))
import qualified Control.Monad.Trans.Writer as Writer
import Data.Functor.Identity (Identity(..))
identityWriter :: IdentityT (Writer w) a -> (a, w)
identityWriter m = Writer.runWriter (Identity.runIdentityT m)

writerIdentity :: WriterT w Identity a -> (a, w)
writerIdentity m = runIdentity (Writer.runWriterT m)
In the above, we have implemented the two different ways we can stack IdentityT and WriterT w. The first one has Writer w as the base monad while the second one has Identity as the base monad. The bodies of our functions do not do anything extra besides running the two transformers in the stack.
Note that the result type is (a, w) for both functions. The order of our stacks did not change the result type, so stacking IdentityT and WriterT is commutative. This conceptually makes sense - IdentityT wraps a monad but does not do anything else. This means the result type can not suddenly change if we throw an IdentityT into our stacks. We could even say that IdentityT commutes with all other transformers!
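As a rough check on this claim, the sketch below builds the same logging computation in both stacking orders and runs each one. The names logOverWriter, logOverIdentity, runIdentityWriter and runWriterIdentity are hypothetical and exist only for this example; the run functions have the same bodies as the ones in the script above.

```haskell
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Identity (IdentityT(..))
import Control.Monad.Trans.Writer (Writer, WriterT, runWriter, runWriterT, tell)
import Data.Functor.Identity (Identity, runIdentity)

-- IdentityT on top of Writer: tell belongs to the base monad, so we lift it.
logOverWriter :: IdentityT (Writer [String]) Int
logOverWriter = do
  lift (tell ["start"])
  lift (tell ["done"])
  pure 42

-- WriterT on top of Identity: tell is available directly.
logOverIdentity :: WriterT [String] Identity Int
logOverIdentity = do
  tell ["start"]
  tell ["done"]
  pure 42

-- Run IdentityT first, then the Writer underneath it.
runIdentityWriter :: IdentityT (Writer w) a -> (a, w)
runIdentityWriter m = runWriter (runIdentityT m)

-- Run WriterT first, then the Identity underneath it.
runWriterIdentity :: WriterT w Identity a -> (a, w)
runWriterIdentity m = runIdentity (runWriterT m)
```

Both runIdentityWriter logOverWriter and runWriterIdentity logOverIdentity evaluate to (42, ["start", "done"]). The only visible difference between the two stacks is where lift is needed.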
Let’s look at an example that does not use IdentityT:
#!/usr/bin/env stack
-- stack --resolver lts-9.1 --install-ghc exec ghci --package transformers
import Control.Monad.Trans.Except (Except(..), ExceptT(..))
import qualified Control.Monad.Trans.Except as Except
import Control.Monad.Trans.State (State(..), StateT(..))
import qualified Control.Monad.Trans.State as State
stateExcept :: StateT s (Except e) a -> s -> Either e (a, s)
stateExcept m s = Except.runExcept (State.runStateT m s)

exceptState :: ExceptT e (State s) a -> s -> (Either e a, s)
exceptState m s = State.runState (Except.runExceptT m) s
In the above, we implemented the two different ways we can stack StateT s and ExceptT e. We can immediately see from the return types that stacking these transformers is not commutative.
If we choose Except e as our base monad, our result type is Either e (a, s). This means we either have an exception value of type e or we have a pair of our monadic computation’s result and final state. If we do wind up with a Left <something>, we have no access to our computation’s intermediate state.
If we choose State s as our base monad, our result type is (Either e a, s). This means we have a pair where the first element is either an exception value of type e or our monadic computation’s result, and the second element is our computation’s final state. Even if the first element is a Left, we still get the state as it stood when the exception was thrown.
Which one we choose depends on our use case.
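To see the difference in behavior and not just in types, here is a small sketch that bumps a counter and then throws, once in each stacking order. The names failingStateExcept, failingExceptState, stateExceptResult and exceptStateResult are invented for this example, as are the Int state and String exception types.

```haskell
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Except (Except, ExceptT, runExcept, runExceptT, throwE)
import Control.Monad.Trans.State (State, StateT, modify, runState, runStateT)

-- StateT over Except: throwE belongs to the base monad, so we lift it.
failingStateExcept :: StateT Int (Except String) ()
failingStateExcept = do
  modify (+ 1)
  lift (throwE "boom")

-- ExceptT over State: modify belongs to the base monad, so we lift it.
failingExceptState :: ExceptT String (State Int) ()
failingExceptState = do
  lift (modify (+ 1))
  throwE "boom"

-- Run both computations from an initial state of 0.
stateExceptResult :: Either String ((), Int)
stateExceptResult = runExcept (runStateT failingStateExcept 0)

exceptStateResult :: (Either String (), Int)
exceptStateResult = runState (runExceptT failingExceptState) 0
```

Here stateExceptResult evaluates to Left "boom", so the incremented counter is gone, while exceptStateResult evaluates to (Left "boom", 1), so the counter survives the exception.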
Practice, Practice, Practice
As with most things we learn, gaining an intuition for how to stack transformers in the desired order takes time and practice. When I was first exploring them, I found it very helpful to stack together pairs of transformers and let GHC tell me if I had figured out the correct return types.
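One way to let GHC report a return type, sketched here as a suggestion and not as the method used for this post, is to leave a wildcard in the signature. The PartialTypeSignatures extension is an assumption of this sketch, and the name maybeStateGuess is hypothetical.

```haskell
{-# LANGUAGE PartialTypeSignatures #-}
import Control.Monad.Trans.Maybe (MaybeT(..))
import Control.Monad.Trans.State (State, runState)

-- The wildcard marks the result type we want GHC to fill in.
-- Loading this file produces a warning that names the inferred type.
maybeStateGuess :: MaybeT (State s) a -> s -> _
maybeStateGuess m s = runState (runMaybeT m) s
```

With the extension enabled, GHC accepts the definition and warns that the wildcard stands for (Maybe a, s). Without the extension the same wildcard is a compile error, but the error message still names the inferred type, so either way we can compare it against our guess.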
The rest of this post is a big GHCi script that enumerates the different ways we can stack the most common transformers up to stacks of size 2. I recommend writing out this enumeration yourself (it is tedious but effective) and only using the included script as a reference.
See if you can spot the stacking orders that are not commutative!
#!/usr/bin/env stack
-- stack --resolver lts-9.1 --install-ghc exec ghci --package transformers
import Control.Monad.Trans.Except (Except(..), ExceptT(..))
import qualified Control.Monad.Trans.Except as Except
import Control.Monad.Trans.Identity (IdentityT(..))
import qualified Control.Monad.Trans.Identity as Identity
import Control.Monad.Trans.Maybe (MaybeT(..))
import qualified Control.Monad.Trans.Maybe as Maybe
import Control.Monad.Trans.RWS (RWS(..), RWST(..))
import qualified Control.Monad.Trans.RWS as RWS
import Control.Monad.Trans.Reader (Reader(..), ReaderT(..))
import qualified Control.Monad.Trans.Reader as Reader
import Control.Monad.Trans.State (State(..), StateT(..))
import qualified Control.Monad.Trans.State as State
import Control.Monad.Trans.Writer (Writer(..), WriterT(..))
import qualified Control.Monad.Trans.Writer as Writer
import Data.Functor.Identity (Identity(..))
import qualified Data.Functor.Identity as Identity
import Data.Maybe (Maybe(..))
import qualified Data.Maybe as Maybe
-- Base monads on their own, for reference.
plainIdentity :: Identity a -> a
plainIdentity m = runIdentity m

plainReader :: Reader r a -> r -> a
plainReader m = Reader.runReader m

plainWriter :: Writer w a -> (a, w)
plainWriter m = Writer.runWriter m

plainState :: State s a -> s -> (a, s)
plainState m s = State.runState m s

plainMaybe :: Maybe a -> b -> (a -> b) -> b
plainMaybe m d f = maybe d f m

plainExcept :: Except e a -> Either e a
plainExcept m = Except.runExcept m

plainRWS :: RWS r w s a -> r -> s -> (a, s, w)
plainRWS m r s = RWS.runRWS m r s
-- IdentityT over each base monad.
identityIdentity :: IdentityT Identity a -> a
identityIdentity m = runIdentity (Identity.runIdentityT m)

identityReader :: IdentityT (Reader r) a -> r -> a
identityReader m r = Reader.runReader (Identity.runIdentityT m) r

identityWriter :: IdentityT (Writer w) a -> (a, w)
identityWriter m = Writer.runWriter (Identity.runIdentityT m)

identityState :: IdentityT (State s) a -> s -> (a, s)
identityState m s = State.runState (Identity.runIdentityT m) s

identityMaybe :: IdentityT Maybe a -> b -> (a -> b) -> b
identityMaybe m d f = maybe d f (Identity.runIdentityT m)

identityExcept :: IdentityT (Except e) a -> Either e a
identityExcept m = Except.runExcept (Identity.runIdentityT m)

identityRWS :: IdentityT (RWS r w s) a -> r -> s -> (a, s, w)
identityRWS m r s = RWS.runRWS (Identity.runIdentityT m) r s
-- ReaderT over each base monad.
readerIdentity :: ReaderT r Identity a -> r -> a
readerIdentity m r = runIdentity (Reader.runReaderT m r)

readerReader :: ReaderT r (Reader r') a -> r -> r' -> a
readerReader m r r' = Reader.runReader (Reader.runReaderT m r) r'

readerWriter :: ReaderT r (Writer w) a -> r -> (a, w)
readerWriter m r = Writer.runWriter (Reader.runReaderT m r)

readerState :: ReaderT r (State s) a -> r -> s -> (a, s)
readerState m r s = State.runState (Reader.runReaderT m r) s

readerMaybe :: ReaderT r Maybe a -> r -> b -> (a -> b) -> b
readerMaybe m r d f = maybe d f (Reader.runReaderT m r)

readerExcept :: ReaderT r (Except e) a -> r -> Either e a
readerExcept m r = Except.runExcept (Reader.runReaderT m r)

readerRWS :: ReaderT r (RWS r' w s) a -> r -> r' -> s -> (a, s, w)
readerRWS m r r' s = RWS.runRWS (Reader.runReaderT m r) r' s
-- WriterT over each base monad.
writerIdentity :: WriterT w Identity a -> (a, w)
writerIdentity m = runIdentity (Writer.runWriterT m)

writerReader :: WriterT w (Reader r) a -> r -> (a, w)
writerReader m r = Reader.runReader (Writer.runWriterT m) r

writerWriter :: WriterT w (Writer w') a -> ((a, w), w')
writerWriter m = Writer.runWriter (Writer.runWriterT m)

writerState :: WriterT w (State s) a -> s -> ((a, w), s)
writerState m s = State.runState (Writer.runWriterT m) s

writerMaybe :: WriterT w Maybe a -> (b, w') -> ((a, w) -> (b, w')) -> (b, w')
writerMaybe m d f = maybe d f (Writer.runWriterT m)

writerExcept :: WriterT w (Except e) a -> Either e (a, w)
writerExcept m = Except.runExcept (Writer.runWriterT m)

writerRWS :: WriterT w (RWS r w' s) a -> r -> s -> ((a, w), s, w')
writerRWS m r s = RWS.runRWS (Writer.runWriterT m) r s
-- StateT over each base monad.
stateIdentity :: StateT s Identity a -> s -> (a, s)
stateIdentity m s = runIdentity (State.runStateT m s)

stateReader :: StateT s (Reader r) a -> s -> r -> (a, s)
stateReader m s r = Reader.runReader (State.runStateT m s) r

stateWriter :: StateT s (Writer w) a -> s -> ((a, s), w)
stateWriter m s = Writer.runWriter (State.runStateT m s)

stateState :: StateT s (State s') a -> s -> s' -> ((a, s), s')
stateState m s s' = State.runState (State.runStateT m s) s'

stateMaybe :: StateT s Maybe a -> s -> (b, s') -> ((a, s) -> (b, s')) -> (b, s')
stateMaybe m s d f = maybe d f (State.runStateT m s)

stateExcept :: StateT s (Except e) a -> s -> Either e (a, s)
stateExcept m s = Except.runExcept (State.runStateT m s)

stateRWS :: StateT s (RWS r w s') a -> s -> r -> s' -> ((a, s), s', w)
stateRWS m s r s' = RWS.runRWS (State.runStateT m s) r s'
-- MaybeT over each base monad.
maybeIdentity :: MaybeT Identity a -> Maybe a
maybeIdentity m = runIdentity (Maybe.runMaybeT m)

maybeReader :: MaybeT (Reader r) a -> r -> Maybe a
maybeReader m r = Reader.runReader (Maybe.runMaybeT m) r

maybeWriter :: MaybeT (Writer w) a -> (Maybe a, w)
maybeWriter m = Writer.runWriter (Maybe.runMaybeT m)

maybeState :: MaybeT (State s) a -> s -> (Maybe a, s)
maybeState m s = State.runState (Maybe.runMaybeT m) s

maybeMaybe :: MaybeT Maybe a -> Maybe b -> (Maybe a -> Maybe b) -> Maybe b
maybeMaybe m d f = maybe d f (Maybe.runMaybeT m)

maybeExcept :: MaybeT (Except e) a -> Either e (Maybe a)
maybeExcept m = Except.runExcept (Maybe.runMaybeT m)

maybeRWS :: MaybeT (RWS r w s) a -> r -> s -> (Maybe a, s, w)
maybeRWS m r s = RWS.runRWS (Maybe.runMaybeT m) r s
-- ExceptT over each base monad.
exceptIdentity :: ExceptT e Identity a -> Either e a
exceptIdentity m = runIdentity (Except.runExceptT m)

exceptReader :: ExceptT e (Reader r) a -> r -> Either e a
exceptReader m r = Reader.runReader (Except.runExceptT m) r

exceptWriter :: ExceptT e (Writer w) a -> (Either e a, w)
exceptWriter m = Writer.runWriter (Except.runExceptT m)

exceptState :: ExceptT e (State s) a -> s -> (Either e a, s)
exceptState m s = State.runState (Except.runExceptT m) s

exceptMaybe :: ExceptT e Maybe a -> Either e' b -> (Either e a -> Either e' b) -> Either e' b
exceptMaybe m d f = maybe d f (Except.runExceptT m)

exceptExcept :: ExceptT e (Except e') a -> Either e' (Either e a)
exceptExcept m = Except.runExcept (Except.runExceptT m)

exceptRWS :: ExceptT e (RWS r w s) a -> r -> s -> (Either e a, s, w)
exceptRWS m r s = RWS.runRWS (Except.runExceptT m) r s
-- RWST over each base monad.
rwsIdentity :: RWST r w s Identity a -> r -> s -> (a, s, w)
rwsIdentity m r s = runIdentity (RWS.runRWST m r s)

rwsReader :: RWST r w s (Reader r') a -> r -> s -> r' -> (a, s, w)
rwsReader m r s r' = Reader.runReader (RWS.runRWST m r s) r'

rwsWriter :: RWST r w s (Writer w') a -> r -> s -> ((a, s, w), w')
rwsWriter m r s = Writer.runWriter (RWS.runRWST m r s)

rwsState :: RWST r w s (State s') a -> r -> s -> s' -> ((a, s, w), s')
rwsState m r s s' = State.runState (RWS.runRWST m r s) s'

rwsMaybe :: RWST r w s Maybe a -> r -> s -> (b, s', w') -> ((a, s, w) -> (b, s', w')) -> (b, s', w')
rwsMaybe m r s d f = maybe d f (RWS.runRWST m r s)

rwsExcept :: RWST r w s (Except e) a -> r -> s -> Either e (a, s, w)
rwsExcept m r s = Except.runExcept (RWS.runRWST m r s)

rwsRWS :: RWST r w s (RWS r' w' s') a -> r -> s -> r' -> s' -> ((a, s, w), s', w')
rwsRWS m r s r' s' = RWS.runRWS (RWS.runRWST m r s) r' s'
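For anyone who wants to check one pair from the enumeration against concrete values, here is a hedged sketch for WriterT and MaybeT. It logs one message and then fails, once in each stacking order. The names logThenFailWM, logThenFailMW, writerMaybeResult and maybeWriterResult are made up for this example.

```haskell
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Maybe (MaybeT(..))
import Control.Monad.Trans.Writer (Writer, WriterT, runWriter, runWriterT, tell)

-- WriterT over Maybe: a failure in the base monad discards the log.
logThenFailWM :: WriterT [String] Maybe ()
logThenFailWM = do
  tell ["step 1"]
  lift Nothing

-- MaybeT over Writer: the log written before the failure survives.
logThenFailMW :: MaybeT (Writer [String]) ()
logThenFailMW = do
  lift (tell ["step 1"])
  MaybeT (pure Nothing)

-- Run each stack down to a plain value.
writerMaybeResult :: Maybe ((), [String])
writerMaybeResult = runWriterT logThenFailWM

maybeWriterResult :: (Maybe (), [String])
maybeWriterResult = runWriter (runMaybeT logThenFailMW)
```

Here writerMaybeResult evaluates to Nothing, while maybeWriterResult evaluates to (Nothing, ["step 1"]). The result types differ in the same way the StateT and ExceptT pair does, so this pair does not commute either.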