Michael Sproul

Traits as Higher-order Functions

Written by Michael Sproul on July 9, 2015.

Although Rust shares many features with modern functional languages, several things are missing that make Haskell-style functional programming impractical. Functions can be passed to and returned from functions, but the distinctions between unique function types, function pointers and different kinds of closures make everything somewhat more heavyweight than in Haskell. For example, there’s currently no neat way to create new functions by composition (no (.)) and no currying. This is fine: Rust is a systems programming language, and precise control over memory and performance mostly justifies the complexity. If, however, you still find yourself wanting to glue lots of pure functions together, read on!
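
Composition and currying can still be written by hand. What makes them heavyweight shows up in the types. The sketch below uses hypothetical helpers (compose, add, apply_all) and the toy functions add_one and double. It relies on modern impl Trait syntax, which was stabilised in Rust 1.26, after this post was written.

```rust
/// Each function item such as `add_one` has its own unique, zero-sized type;
/// it only becomes a function pointer when coerced to `fn(i32) -> i32`.
pub fn add_one(x: i32) -> i32 {
    x + 1
}

pub fn double(x: i32) -> i32 {
    x * 2
}

/// Hypothetical stand-in for Haskell's `(.)`: `compose(f, g)(x) == f(g(x))`.
/// It must be generic over both argument types and return an opaque closure
/// type, because every function item and closure has a distinct type.
pub fn compose<A, B, C>(f: impl Fn(B) -> C, g: impl Fn(A) -> B) -> impl Fn(A) -> C {
    move |x| f(g(x))
}

/// Currying by hand: a function that returns a closure capturing `x`.
pub fn add(x: i32) -> impl Fn(i32) -> i32 {
    move |y| x + y
}

/// Distinct function items only fit in one collection after coercion
/// to a common function pointer type.
pub fn apply_all(x: i32) -> Vec<i32> {
    let fns: [fn(i32) -> i32; 2] = [add_one, double];
    fns.iter().map(|f| f(x)).collect()
}
```

For example, compose(double, add_one)(3) evaluates double(add_one(3)), which is 8. Each call to compose yields a fresh opaque type, so storing several composed functions together would again need boxing (Box<dyn Fn(i32) -> i32>) or a common pointer type.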

The setting for our functional adventure is my radix trie library. I had written it mostly in a functional style and was busy bolting new features onto it. The problem was that each new feature required duplicating essentially the same logic. I noticed that all the important operations on my trie followed the same structure; here’s insert with the details removed:

fn insert(&mut self, key: K, value: V, mut key_fragments: NibbleVec) -> Option<V> {
    if key_fragments.len() == 0 {
        // Do something with the root and return early.
    }

    let bucket = key_fragments.get(0) as usize;

    let result = match self.children[bucket] {
        None => unimplemented!(), // Do something and return early.
        Some(ref mut child) => {
            match match_keys(&key_fragments, &child.key) {
                KeyMatch::Full => unimplemented!(), // Some action for full key matches.
                KeyMatch::Partial(idx) => unimplemented!(), // Some action for partial key matches.
                KeyMatch::FirstPrefix => unimplemented!(), // Some action for prefix matches.
                // Split the key and recurse.
                KeyMatch::SecondPrefix => {
                    let new_key = key_fragments.split(child.key.len());
                    child.insert(key, value, new_key)
                }
            }
        }
    };

    // Process the intermediate result and return something of the same type.
    f(result)
}
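
The only part of insert that chooses between the different actions is match_keys. Its exact behaviour isn't shown here, so the following is a guess inferred from how insert uses the result. For instance, a SecondPrefix match splits the search key at the child key's length and recurses. The sketch works on plain u8 nibble slices rather than the library's NibbleVec, and both KeyMatch and match_keys below are hypothetical stand-ins.

```rust
/// Outcome of comparing a search key against a child's key.
/// (Variant meanings are inferred from how `insert` uses them.)
#[derive(Debug, PartialEq)]
pub enum KeyMatch {
    /// Both keys are identical.
    Full,
    /// The keys agree on the first `idx` nibbles, then differ.
    Partial(usize),
    /// The search key is a strict prefix of the child key.
    FirstPrefix,
    /// The child key is a strict prefix of the search key: split and recurse.
    SecondPrefix,
}

/// Hypothetical stand-in for the library's `match_keys`, over nibble slices.
pub fn match_keys(first: &[u8], second: &[u8]) -> KeyMatch {
    // Length of the longest common prefix of the two keys.
    let common = first
        .iter()
        .zip(second.iter())
        .take_while(|&(a, b)| a == b)
        .count();

    if common == first.len() && common == second.len() {
        KeyMatch::Full
    } else if common == first.len() {
        KeyMatch::FirstPrefix
    } else if common == second.len() {
        KeyMatch::SecondPrefix
    } else {
        KeyMatch::Partial(common)
    }
}
```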

Feel free to just skim these code samples. The only thing I want to highlight in insert is that each case gets its own action. Now, under the veneer of respectability, I’m actually a horrendous programmer, so my first sincere attempt to fix this problem was to create an uber function with no fewer than ELEVEN type parameters. It looked more or less like this:

fn uber<K, V, Key, Value, Output,
        RootFn, NoChildFn, FullMatchFn, PartialMatchFn, FirstPrefixFn, ActionFn>
(
    trie: &mut Trie<K, V>,
    key: Key,
    value: Value,
    mut key_fragments: NibbleVec,
    root_fn: RootFn,
    no_child_fn: NoChildFn,
    full_match_fn: FullMatchFn,
    partial_match_fn: PartialMatchFn,
    first_prefix_fn: FirstPrefixFn,
    action_fn: ActionFn
) -> Output
where
    RootFn: Fn(&mut Trie<K, V>, Key, Value) -> Output,
    NoChildFn: Fn(&mut Trie<K, V>, Key, Value, NibbleVec, usize) -> Output,
    FullMatchFn: Fn(&mut Trie<K, V>, Key, Value, NibbleVec) -> Output,
    PartialMatchFn: Fn(&mut Trie<K, V>, Key, Value, NibbleVec, usize) -> Output,
    FirstPrefixFn: Fn(&mut Trie<K, V>, Key, Value, NibbleVec) -> Output,
    ActionFn: Fn(&mut Trie<K, V>, Output, usize) -> Output
{
    if key_fragments.len() == 0 {
        return root_fn(trie, key, value);
    }

    let bucket = key_fragments.get(0) as usize;

    let intermediate = match trie.children[bucket] {
        None => return no_child_fn(trie, key, value, key_fragments, bucket),
        Some(ref mut child) => {
            match match_keys(&key_fragments, &child.key) {
                KeyMatch::Full =>
                    full_match_fn(child, key, value, key_fragments),
                KeyMatch::Partial(idx) =>
                    partial_match_fn(child, key, value, key_fragments, idx),
                KeyMatch::FirstPrefix =>
                    first_prefix_fn(child, key, value, key_fragments),
                // Split the key and recurse.
                KeyMatch::SecondPrefix => {
                    let new_key = key_fragments.split(child.key.len());
                    uber(child, key, value, new_key, root_fn,
                         no_child_fn, full_match_fn, partial_match_fn,
                         first_prefix_fn, action_fn)
                }
            }
        }
    };

    action_fn(trie, intermediate, bucket)
}

Thankfully, this horrendous mess doesn’t even compile. The reason is that action_fn is moved into the recursive uber call and then used again once that call returns. Adding a Copy or Clone bound to ActionFn lets uber compile, but at the time it made uber more or less useless, because closures couldn’t be Copy or Clone. (Since Rust 1.26, a closure is Copy or Clone whenever everything it captures is.) With closures ruled out, it’s still possible to pass a regular function, but then we may as well use fn types directly to reduce noise in the type signature. While mucking around with this, I started replacing all of the higher-order type parameters with plain fn types. I noticed that my calls to uber started to look a lot like trait implementations:

fn insert<K, V>(trie: &mut Trie<K, V>, key: K, value: V) -> Option<V> where K: TrieKey {
    // Oh look, loads of functions defined inside a thing...
    fn root_fn<K, V>(trie: &mut Trie<K, V>, key: K, value: V) -> Option<V> {
        trie.replace_value(key, value)
    }
    fn action_fn<K, V>(trie: &mut Trie<K, V>, x: Option<V>, _: usize) -> Option<V> {
        if x.is_none() {
            trie.length += 1;
        }
        x
    }
    // ... more function definitions ...
    uber(...)
}
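
The move problem behind uber's compile error, and the switch to fn types, can be reproduced on a toy recursion. The function calls itself and then uses one of its function arguments again, just as uber does with action_fn. This is a sketch with hypothetical names (walk, walk_fn, walk_dyn, square, add). The Copy bound in walk accepts capturing closures only because, in Rust 1.26 and later, a closure is Copy when everything it captures is. walk_dyn takes &dyn Fn trait objects, an alternative the post doesn't use.

```rust
/// Toy recursion mirroring `uber`: recurse first, then combine this level's
/// `step(depth)` with the result via `finish` (as `uber` calls `action_fn`).
/// Without the `Copy` bounds, `walk(depth - 1, step, finish)` would move both
/// arguments and the later `finish(...)` call would be a "use of moved value".
pub fn walk<S, F>(depth: u32, step: S, finish: F) -> u32
where
    S: Fn(u32) -> u32 + Copy,
    F: Fn(u32, u32) -> u32 + Copy,
{
    if depth == 0 {
        return 0;
    }
    let below = walk(depth - 1, step, finish);
    finish(step(depth), below)
}

/// The same walk with plain function pointers: `fn` types are always `Copy`,
/// and the signature loses its type parameters and trait bounds.
pub fn walk_fn(depth: u32, step: fn(u32) -> u32, finish: fn(u32, u32) -> u32) -> u32 {
    if depth == 0 {
        return 0;
    }
    let below = walk_fn(depth - 1, step, finish);
    finish(step(depth), below)
}

/// Trait objects behind shared references also work: `&dyn Fn` is `Copy`,
/// accepts any closure, and needs no new instantiation per recursive call.
pub fn walk_dyn(depth: u32, step: &dyn Fn(u32) -> u32, finish: &dyn Fn(u32, u32) -> u32) -> u32 {
    if depth == 0 {
        return 0;
    }
    let below = walk_dyn(depth - 1, step, finish);
    finish(step(depth), below)
}

pub fn square(x: u32) -> u32 {
    x * x
}

pub fn add(a: u32, b: u32) -> u32 {
    a + b
}
```

With the toy functions, walk_fn(3, square, add) computes 9 + 4 + 1 = 14. Dropping the Copy bound and passing &finish into the recursive call of a generic walk does not help. Every level would instantiate walk at a new type (&F, &&F, ...) until the compiler hits its recursion limit.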

This led me down the path of creating a trait for traversals (a better name than uber), with a few nice side-effects:

  • Traits allow for default implementations - no need to specify every function for every traversal.
  • All the library traversals do the same thing in root_fn and full_match_fn, so full_match_fn (renamed to child_match_fn) can default to calling root_fn (renamed to match_fn).
  • The processing function, action_fn, can default to returning the intermediate result.
  • Fewer type parameters in type signatures, thanks to associated types and the removal of Fn trait bounds.

With a few more unrelated simplifications, this gives a traversal trait that looks like this:

pub trait TraversalMut<'a, K: 'a, V: 'a> where K: TrieKey {
    type Input: 'a;
    type Output;

    fn default_output() -> Self::Output;

    fn match_fn(trie: &mut Trie<K, V>, input: Self::Input) -> Self::Output {
        Self::default_output()
    }

    fn no_child_fn(trie: &mut Trie<K, V>, input: Self::Input,
                   nv: NibbleVec, bucket: usize) -> Self::Output {
        Self::default_output()
    }

    fn child_match_fn(child: &mut Trie<K, V>, input: Self::Input,
                      nv: NibbleVec) -> Self::Output {
        Self::match_fn(child, input)
    }

    fn partial_match_fn(child: &mut Trie<K, V>, input: Self::Input,
                        nv: NibbleVec, idx: usize) -> Self::Output {
        Self::default_output()
    }

    fn first_prefix_fn(trie: &mut Trie<K, V>, input: Self::Input,
                       nv: NibbleVec) -> Self::Output {
        Self::default_output()
    }

    fn action_fn(trie: &mut Trie<K, V>, intermediate: Self::Output,
                 bucket: usize) -> Self::Output {
        intermediate
    }

    fn run(trie: &mut Trie<K, V>, input: Self::Input,
           mut key_fragments: NibbleVec) -> Self::Output {

        if key_fragments.len() == 0 {
            return Self::match_fn(trie, input);
        }

        let bucket = key_fragments.get(0) as usize;

        let intermediate = match trie.children[bucket] {
            None => return Self::no_child_fn(trie, input, key_fragments, bucket),
            Some(ref mut child) => {
                match match_keys(&key_fragments, &child.key) {
                    KeyMatch::Full =>
                        Self::child_match_fn(child, input, key_fragments),
                    KeyMatch::Partial(i) =>
                        Self::partial_match_fn(child, input, key_fragments, i),
                    KeyMatch::FirstPrefix =>
                        Self::first_prefix_fn(child, input, key_fragments),
                    KeyMatch::SecondPrefix => {
                        let new_key = key_fragments.split(child.key.len());
                        Self::run(child, input, new_key)
                    }
                }
            }
        };

        Self::action_fn(trie, intermediate, bucket)
    }
}

All the nice traversal logic is bundled up in run. A minimal (albeit useless) implementation only requires definitions for Input, Output and default_output. None of these functions takes a self parameter, so I used empty enums as the implementing types. They can never be instantiated and serve only to select a set of functions:

enum Insert {}
impl<'a, K: 'a, V: 'a> TraversalMut<'a, K, V> for Insert where K: TrieKey { ... }
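
To see the whole pattern in one place, here is a self-contained miniature. It is a sketch under simplifying assumptions, not the library's code. The hypothetical Trie<V> has no key compression, keys are plain u8 nibble slices instead of NibbleVec, and the lifetime parameter and TrieKey bound are dropped, so run only distinguishes three cases. Insert, Contains and Remove are empty enums, and each overrides only the cases it cares about.

```rust
/// Hypothetical, simplified trie: keys are slices of nibbles (0..16) and there
/// is no key compression, so `run` needs three cases rather than five.
pub struct Trie<V> {
    pub value: Option<V>,
    pub children: Vec<Option<Box<Trie<V>>>>,
}

impl<V> Trie<V> {
    pub fn new() -> Self {
        Trie { value: None, children: (0..16).map(|_| None).collect() }
    }
    fn is_empty(&self) -> bool {
        self.value.is_none() && self.children.iter().all(|c| c.is_none())
    }
}

/// In the spirit of `TraversalMut`: no `self` parameter, associated types,
/// and default cases that implementors override only where needed.
pub trait Traversal<V> {
    type Input;
    type Output;

    fn default_output() -> Self::Output;

    /// The key is exhausted at this node.
    fn match_fn(_trie: &mut Trie<V>, _input: Self::Input) -> Self::Output {
        Self::default_output()
    }

    /// No child exists for the next nibble; `rest` is the key after it.
    fn no_child_fn(_trie: &mut Trie<V>, _input: Self::Input,
                   _rest: &[u8], _bucket: usize) -> Self::Output {
        Self::default_output()
    }

    /// Post-processes the result coming back from the child at `bucket`.
    fn action_fn(_trie: &mut Trie<V>, intermediate: Self::Output,
                 _bucket: usize) -> Self::Output {
        intermediate
    }

    /// The shared recursion that every traversal reuses.
    fn run(trie: &mut Trie<V>, input: Self::Input, key: &[u8]) -> Self::Output {
        let (bucket, rest) = match key.split_first() {
            None => return Self::match_fn(trie, input),
            Some((&first, rest)) => (first as usize, rest),
        };
        let intermediate = match trie.children[bucket].as_mut() {
            None => return Self::no_child_fn(trie, input, rest, bucket),
            Some(child) => Self::run(child, input, rest),
        };
        Self::action_fn(trie, intermediate, bucket)
    }
}

/// Inserts a value, returning the value it replaced (if any).
pub enum Insert {}
impl<V> Traversal<V> for Insert {
    type Input = V;
    type Output = Option<V>;
    fn default_output() -> Option<V> { None }
    fn match_fn(trie: &mut Trie<V>, value: V) -> Option<V> {
        trie.value.replace(value)
    }
    // Create the missing child, then continue the traversal inside it.
    fn no_child_fn(trie: &mut Trie<V>, value: V, rest: &[u8], bucket: usize) -> Option<V> {
        let mut child: Box<Trie<V>> = Box::new(Trie::new());
        let old = Self::run(&mut child, value, rest);
        trie.children[bucket] = Some(child);
        old
    }
}

/// Lookup: only `match_fn` is overridden; a missing child yields `false`.
pub enum Contains {}
impl<V> Traversal<V> for Contains {
    type Input = ();
    type Output = bool;
    fn default_output() -> bool { false }
    fn match_fn(trie: &mut Trie<V>, _: ()) -> bool { trie.value.is_some() }
}

/// Removal: overrides `action_fn` to prune children left empty.
pub enum Remove {}
impl<V> Traversal<V> for Remove {
    type Input = ();
    type Output = Option<V>;
    fn default_output() -> Option<V> { None }
    fn match_fn(trie: &mut Trie<V>, _: ()) -> Option<V> { trie.value.take() }
    fn action_fn(trie: &mut Trie<V>, removed: Option<V>, bucket: usize) -> Option<V> {
        if trie.children[bucket].as_ref().map_or(false, |c| c.is_empty()) {
            trie.children[bucket] = None;
        }
        removed
    }
}
```

A call such as Insert::run(&mut trie, value, &key) never constructs an Insert. The enum only selects which set of functions run dispatches to, and that choice is resolved entirely at compile time.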

Implementing insert and remove operations in terms of this trait was quite a pleasant experience. When it came to implementing lookups that return references, however, I realised I needed to change the lifetime and mutability of the input trie. That’s a story for another time (if you like macros, you can check out the full source).

In conclusion, Rust has some nice facilities for dealing with higher-order functions, and doing insanely higher-order things is definitely possible with just fn types and Fn traits. However, you can avoid a fair amount of mucking around by defining a trait whose sole role is to run a bunch of functions. Plus, with traits you get default implementations, and you can call your thing a computational structure!

Thanks for reading!

Michael


Tagged: rust, functional-programming, algorithms