July 7, 2013

Reducers explained (through Python)

Last year, Clojure introduced a new library called core.reducers, which represented a new, efficient way to deal with operations across collections in functional languages. It’s since been picked up by Elixir, and libraries have been written for some other languages.

Today, I want to explain a bit about what reducers are, why they exist, how they can be more efficient than other functional methods of handling collections, and how they can help you write prettier code.

Part 1: Reduce as a fundamental collection operator

First, some helper functions that we’ll use for demonstration.

def add(a, b):
  return a + b

def inc(a):
  return a + 1

def even(a):
  return a % 2 == 0

Alright, we’re ready to go. Let’s start by defining reduce (we’ll call it fold, so as not to clash with Python’s built-in reduce).

Fold will take a function, a collection/iterator, and an initial value.

# Fold can be defined recursively:
def fold_recursive(f, it, val):
  "Reduce, implemented recursively"
  it = iter(it) # iterators return themselves, so this is safe
  try:
    return fold_recursive(f, it, f(val, next(it)))
  except StopIteration: # Thrown by next()
    return val

fold_recursive(add, [1, 2, 3], 0) # => 6

Two Python-specific problems:

  • It’s sort of ugly, with that iter in there
  • Python doesn’t optimize tail calls, so every element costs a stack frame and long inputs hit the recursion limit:
try:
  fold_recursive(add, range(1000), 0) # Exceeds the default recursion limit
except RuntimeError:
  pass

So, here’s the Pythonic version, with a plain for loop:

def fold(f, it, val):
  "Reduce, implemented iteratively"
  for head in it:
    try:
      val = f(val, head)
    except StopIteration: # We'll see why this is here later.
      return val
  return val

fold(add, range(1000), 0) # => 499500
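As a quick sanity check, fold should agree with the standard library's reduce, which lives in functools (in Python 3 it is no longer a built-in). This is only a sketch: fold is repeated so the snippet runs on its own, and operator.add stands in for the add helper.

```python
from functools import reduce  # the standard library's reduce
import operator

def fold(f, it, val):
  "Same fold as above, repeated so this snippet is self-contained"
  for head in it:
    try:
      val = f(val, head)
    except StopIteration:
      return val
  return val

via_fold = fold(operator.add, range(1000), 0)
via_reduce = reduce(operator.add, range(1000), 0)
print(via_fold == via_reduce)  # True
```

Both take the function first, then the iterable, then the initial value.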

Great! But did you know you can define the other list processing functions in terms of fold?

Let’s start with map:

def fold_map(f, it):
  "Apply a function to each element in a collection and return the result."
  def _mapper(acc, a): # Define an inner function that accepts acc and a...
    return acc + [f(a)] # ... and appends f(a) onto acc
  return fold(_mapper, it, [])

fold_map(inc, [1, 2, 3]) # => [2, 3, 4]
# = _mapper(_mapper(_mapper([], 1), 2), 3)

Filter is even easier!

def fold_filter(f, it):
  "Retain only elements in a collection for which f(<el>) is truthy"
  def _filterer(acc, a):
    if f(a):             # If f(a) is truthy...
      return acc + [a]   # ... then append a...
    return acc           # ... otherwise don't!
  return fold(_filterer, it, [])

fold_filter(even, [1, 2, 3, 4]) # => [2, 4]

We can make other stuff too.

def fold_take(n, it):
  "Get the first n elements in a collection"
  def _taker(acc, a):
    ii, a = a              # Get the index ii of element a
    if ii < n:
      return acc + [a]
    raise StopIteration()  # Stop reducing when we've had enough
  return fold(_taker, enumerate(it), [])

fold_take(10, [1, 2, 3]) # => [1, 2, 3]
fold_take(3, range(1000)) # => [0, 1, 2]

def fold_flatten(it):
  "Flatten any number of nested collections into one list"
  def _flattener(acc, a):
    try:
      return acc + list(fold_flatten(iter(a))) # Recursion
    except TypeError: # a is not iterable
      return acc + [a]
  return fold(_flattener, it, [])

fold_flatten([1, [2, [3, 4], 5]]) # => [1, 2, 3, 4, 5]

Part 2: A problem

If you’re not familiar with these functions, they’re great for pipeline-style programming:

def sum(it):
  return fold(add, it, 0)

def squares(it):
  return fold_map(lambda x: x * x, it)

def sum_of_squares(it):
  return sum(squares(it))

sum_of_squares([1, 2, 3]) # => 14 (that's 1 + 4 + 9)

Look how pretty sum_of_squares is compared to:

def ugly_sum_of_squares(it):
  ii = 0
  for thing in it:
    ii += thing * thing
  return ii

Ok, so it’s not the worst, but as complexity increases you’ll benefit more from the pipeline’d version.

But, there’s an obvious problem with this situation: ugly_sum_of_squares goes through the list once, but sum_of_squares goes through the whole thing twice!
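To make the two passes visible, here's a small sketch that counts loop steps. The steps counter and the counting_ names are made up for this demonstration; they are instrumented copies of fold and fold_map, not part of the code above.

```python
steps = {"count": 0}  # hypothetical counter, only for this demonstration

def counting_fold(f, it, val):
  "fold, instrumented to count every element it visits"
  for head in it:
    steps["count"] += 1
    val = f(val, head)
  return val

def counting_map(f, it):
  "fold_map built on the instrumented fold"
  return counting_fold(lambda acc, a: acc + [f(a)], it, [])

def counting_sum_of_squares(it):
  squares = counting_map(lambda x: x * x, it)           # pass 1: builds a list
  return counting_fold(lambda a, b: a + b, squares, 0)  # pass 2: sums that list

total = counting_sum_of_squares([1, 2, 3])
print(total, steps["count"])  # 14 6
```

Three elements, six loop steps: one pass over the input and a second over the intermediate list.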

Now, why did we go through implementing that stuff with fold? Well, it turns out there’s a solution to the problem that comes up naturally when you consider that fold (reduce) can be a fundamental building block of your collection functions. What if we could just modify the function used by reduce? Then, the operation could be done in one pass!

This is the idea behind reducers. If you take your collection function (map, filter, flatten, etc.), and have it modify the reducing function instead of building a new list, you can perform any number and combination of these operations without having to repeatedly iterate through the list.
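Before building any machinery, here's the bare idea as a sketch. The squaring helper is hypothetical (it isn't used elsewhere in this post), and functools.reduce stands in for fold to keep the snippet short.

```python
from functools import reduce

def add(a, b):
  return a + b

def squaring(reducing_fn):
  "Hypothetical helper: wrap a reducing function so it sees squared values"
  def _wrapped(acc, v):
    return reducing_fn(acc, v * v)
  return _wrapped

# One pass over the list, no intermediate list of squares
total = reduce(squaring(add), [1, 2, 3], 0)
print(total)  # 14
```

The squaring step rides along inside the reducing function, so the list is only walked once.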

Of course, this only works if your collection remembers how its reducer has been modified as you pass it around. We’ll call this reducer-aware collection a “Reducible”.

Part 3: Enter reducers

Let’s first define an identity function, since it will come up:

def ridentity(r):
  return r

We’re going to define a Reducible object. The Reducible is just a wrapper around two things:

  • An iterable of some sort
  • A function that’s going to modify the reducing function. This is the eponymous “reducer”

Let’s get some terms straight. “Reducing function” is the function that actually gets passed to fold. The “reducer” is the function that will operate on the reducing function to modify it. It’s the “reducer” which gets tied to a collection/iterator to create a “reducible”. Note also that “reducers” can be chained, since they accept a reducing function and return one, too.

Our Reducible will also have a reduce method, which will use fold above but will first modify the function passed to it with the reducer function.

class Reducible(object):
  # Reducible has an identity reducer by default
  def __init__(self, it, reducer=ridentity):

    # We'll let Reducible wrap other reducibles
    if isinstance(it, Reducible):
      self.it = it

      # self.reducer always has the signature f(r_fn), 
      # where r_fn is the actual function that we'll
      # eventually reduce with
      self.reducer = lambda r: it.reducer(reducer(r))

    else:
      self.it = it
      self.reducer = reducer

  # Reducibles are iterable  
  def __iter__(self):
    return iter(self.it)

  # Reduce will use fold, but modify f first
  def reduce(self, f, val):
    return fold(self.reducer(f), self.it, val)

# Let's take it for a spin!
Reducible([1, 2, 3, 4]).reduce(add, 0) # => 10

The secret sauce is this line:

self.reducer = lambda r: it.reducer(reducer(r))

This lets the Reducible wrap other Reducibles, and chains their reducers together. It’s just the composition of it.reducer and reducer.
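The order of that composition matters, and it's easy to get backwards. Here's a sketch with plain functions; mapping is a hypothetical helper that builds a map-style reducer, and functools.reduce stands in for fold.

```python
from functools import reduce

def add(a, b):
  return a + b

def mapping(f):
  "Hypothetical helper: build a reducer that applies f to each value"
  def _reducer(r):
    def _mapper(acc, v):
      return r(acc, f(v))
    return _mapper
  return _reducer

square_first = mapping(lambda x: x * x)  # plays the role of it.reducer
then_inc = mapping(lambda x: x + 1)      # plays the role of reducer

chained = lambda r: square_first(then_inc(r))  # same shape as the line above
swapped = lambda r: then_inc(square_first(r))  # the other order, for contrast

chained_total = reduce(chained(add), [1, 2, 3], 0)
swapped_total = reduce(swapped(add), [1, 2, 3], 0)
print(chained_total)  # 17 = (1 + 1) + (4 + 1) + (9 + 1)
print(swapped_total)  # 29 = 4 + 9 + 16
```

The wrapped Reducible's reducer sits on the outside, so its transformation touches each value first, which is what you'd expect from a pipeline.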

That’s it! We can hide the details a little by making a new version of fold that secretly uses a Reducible:

def rreduce(f, it, val):
  return Reducible(it, lambda r: r).reduce(f, val)

Now, we can go on to implement our other collection functions. Remember, these will only need to modify a reducing function. They’re actually a bit simpler for it:

def rmap(f, it):
  # We'll use _reducer as the input to reducible
  def _reducer(f1):
    # _mapper is the function that _reducer will produce
    def _mapper(ret, v):
      return f1(ret, f(v))
    return _mapper
  return Reducible(it, _reducer)

Did you get that? _reducer returns _mapper, and _mapper just applies f to the value before calling the overall reducing function on (ret, f(v)). Notice that it doesn’t do a single thing about actually making a new list; that is left to whatever reducing function is eventually passed to reduce.

So, let’s try it out:

rmap(inc, [1, 2, 3]) # <Reducible object>

That’s right! We won’t get anything out until we actually call reduce on it. This is the catch to reducers, but with the right set of utilities it’s no trouble at all.
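You can watch the laziness happen with a function that records its calls. This is a sketch with a pared-down Reducible (no nesting support) so it runs on its own; seen and noisy_inc are names invented for the demonstration.

```python
def fold(f, it, val):
  for head in it:
    val = f(val, head)
  return val

class Reducible(object):
  "Pared-down Reducible: no nesting support, just enough for this demo"
  def __init__(self, it, reducer):
    self.it = it
    self.reducer = reducer

  def reduce(self, f, val):
    return fold(self.reducer(f), self.it, val)

def rmap(f, it):
  def _reducer(f1):
    def _mapper(ret, v):
      return f1(ret, f(v))
    return _mapper
  return Reducible(it, _reducer)

seen = []  # records every value the mapping function is called with

def noisy_inc(a):
  seen.append(a)
  return a + 1

lazy = rmap(noisy_inc, [1, 2, 3])
seen_before = list(seen)  # snapshot taken before any reduce call
total = lazy.reduce(lambda a, b: a + b, 0)
print(seen_before, seen, total)  # [] [1, 2, 3] 9
```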

Note now that we can chain maps:

rreduce(add, rmap(inc, [1, 2, 3]), 0) # 9 (That's 2 + 3 + 4)
rreduce(add, rmap(inc, rmap(inc, [1, 2, 3])), 0) # 12 (That's 3 + 4 + 5)

Great, it seems to have worked!

But how about a utility to help us get a list out of rreduce? Let’s just rip something off wholesale from Clojure.

First, we need an add-a-thing-to-a-collection function that dispatches on the collection’s type. We’ll call it conj:

def conj(coll, thing):

  if isinstance(coll, list):
    coll.append(thing)

  elif isinstance(coll, str):
    return coll + str(thing)

  elif isinstance(coll, tuple):
    return coll + (thing,)

  elif isinstance(coll, dict):
    coll[thing[0]] = thing[1]

  return coll
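If you're on Python 3.4 or later, the same type dispatch can be sketched with functools.singledispatch, which picks an implementation based on the type of the first argument. The name conj_sd is made up to keep it apart from conj, and unlike conj this version raises for unknown types instead of returning the collection untouched.

```python
from functools import singledispatch  # Python 3.4+

@singledispatch
def conj_sd(coll, thing):
  "Fallback for types with no registered implementation"
  raise TypeError("don't know how to conj onto {0}".format(type(coll).__name__))

@conj_sd.register(list)
def _(coll, thing):
  coll.append(thing)  # mutates in place, like conj
  return coll

@conj_sd.register(str)
def _(coll, thing):
  return coll + str(thing)

@conj_sd.register(tuple)
def _(coll, thing):
  return coll + (thing,)

@conj_sd.register(dict)
def _(coll, thing):
  key, value = thing
  coll[key] = value
  return coll

# Supporting another type is one more registration, not another elif
@conj_sd.register(set)
def _(coll, thing):
  coll.add(thing)
  return coll

print(conj_sd((1, 2), 3))  # (1, 2, 3)
```

The rest of this post sticks with the isinstance version.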

Now, let’s see what happens if we reduce with conj:

def into(coll1, coll2):
  return rreduce(conj, coll2, coll1)

into([], [1, 2, 3]) # => [1, 2, 3]
into(['a', 'b'], [1, 2, 3]) # => ['a', 'b', 1, 2, 3]
into("", [1, "2", 3]) # => "123"
into((), [1, 2, 3]) # => (1, 2, 3)
into({}, [('a', 1), ('b', 2)]) # => {'a': 1, 'b': 2}

# Super useful! Now we can examine rmap's output:
into([], rmap(inc, [6, 5, 4])) # => [7, 6, 5]

As for sum of squares:

def rsquares(it):
  return rmap(lambda x: x * x, it)

def rsum(it):
  return rreduce(add, it, 0)

def rsum_of_squares(it):
  return rsum(rsquares(it))

rsum_of_squares([1, 2, 3]) # 14, again

But, now we’ve done it in one pass!
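If you'd like to see that for yourself, here's a sketch that counts loop steps while reducing a pipeline of two chained maps. The steps counter and counting_fold are invented for the demonstration, and this Reducible is a trimmed copy wired to the instrumented fold.

```python
steps = {"count": 0}  # hypothetical counter, only for this demonstration

def counting_fold(f, it, val):
  "fold, instrumented to count every element it visits"
  for head in it:
    steps["count"] += 1
    val = f(val, head)
  return val

class Reducible(object):
  "Same shape as the Reducible above, but reducing with counting_fold"
  def __init__(self, it, reducer):
    self.it = it
    if isinstance(it, Reducible):
      self.reducer = lambda r: it.reducer(reducer(r))
    else:
      self.reducer = reducer

  def __iter__(self):
    return iter(self.it)

  def reduce(self, f, val):
    return counting_fold(self.reducer(f), self.it, val)

def rmap(f, it):
  def _reducer(f1):
    return lambda ret, v: f1(ret, f(v))
  return Reducible(it, _reducer)

# Square, then increment, then sum
pipeline = rmap(lambda x: x + 1, rmap(lambda x: x * x, [1, 2, 3]))
total = pipeline.reduce(lambda a, b: a + b, 0)
print(total, steps["count"])  # 17 3
```

Two maps and a sum, and the loop body still runs only once per element.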

Here are a few more examples of list functions implemented with reducers:

def rreverse(it):
  "Reverse a list"
  def _reverse_inner(ret, v):
    return [v] + ret
  return rreduce(_reverse_inner, it, [])

rreverse([1, 2, 3]) # => [3, 2, 1]

def rfilter(f, it):
  "Return a Reducible that filters `it` with `f`"
  def _reducer(f1):
    def _filterer(ret, v):
      return f1(ret, v) if f(v) else ret
    return _filterer
  return Reducible(it, _reducer)

into([], rfilter(even, range(10))) # [0, 2, 4, 6, 8]

def rflatten(it):
  "Return a Reducible that flattens `it`"
  def _reducer(f1):
    def _flattener(ret, val):
      try:
        return rreduce(f1, rflatten(iter(val)), ret)
      except TypeError: # val is not iterable; nothing to flatten
        return f1(ret, val)
    return _flattener
  return Reducible(it, _reducer)

into([], rflatten([1, [2, [3, 4, [5]]]])) # [1, 2, 3, 4, 5]
# It even handles mixed types, no special thought required.
into((), rflatten([1, (2, 3, set([4, 5]))])) # (1, 2, 3, 4, 5)

# Sometimes you need to store state
class Counter(object):
  def __init__(self):
    self.count = 0

  def inc(self):
    c = self.count
    self.count += 1
    return c

def rdrop(n, it):
  "Return a Reducible that skips the first n elements of `it`"
  def _reducer(f1):
    cnt = Counter() # Fresh counter for every reduce call
    def _dropper(ret, val):
      if cnt.inc() >= n:
        return f1(ret, val)
      return ret
    return _dropper
  return Reducible(it, _reducer)

into([], rdrop(7, range(10))) # [7, 8, 9]

# Oh, and remember the StopIteration in the original fold?
# It lets us be lazy!

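def rtake(n, it):
  "Return a Reducible that keeps only the first n elements of `it`"
  def _reducer(f1):
    cnt = Counter() # Fresh counter for every reduce call
    def _taker(ret, val):
      if cnt.inc() < n:
        return f1(ret, val)
      raise StopIteration() # Caught by fold, which stops early
    return _taker
  return Reducible(it, _reducer)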

def inf_range():
  ii = 0
  while True:
    yield ii
    ii += 1

into([], rtake(3, range(10))) # => [0, 1, 2]
into([], rtake(3, inf_range())) # => [0, 1, 2]

So we’ve seen that reducers are a way to be incredibly general about how you handle collections. We’ve taken Python’s iterable interface and turned it into a full-featured collection-processing beast! Reducers may not be as efficient as Python’s C builtins, but they’re pretty good when it comes to pure Python.
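For comparison, here's roughly how the same jobs look with Python's own lazy tools, generator expressions and itertools. This is a side-by-side sketch, not a benchmark.

```python
from itertools import count, islice

# One pass, no intermediate list: the generator feeds sum() one value at a time
total = sum(x * x for x in [1, 2, 3])

# Lazily take three items from an infinite iterator
first_three = list(islice(count(), 3))

# Filter and map chained lazily, then collected once at the end
evens_plus_one = list(x + 1 for x in range(10) if x % 2 == 0)

print(total, first_three, evens_plus_one)  # 14 [0, 1, 2] [1, 3, 5, 7, 9]
```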

Of course, languages such as Clojure and Elixir have some other properties that make the reducer pattern very effective:

  1. They are functional to begin with, and so provide better performance for code written in a functional style.

  2. They have efficient implementations of immutable collections.

  3. They use Protocols, which are a way of adding new functionality onto existing classes. So, instead of wrapping everything in Reducible, they can instead add a reduce method onto an existing collection type.

  4. With some conditions on the inputs, they can easily parallelize fold.
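To give a feel for that last point, here's a sketch of a chunked fold in Python. It assumes the reducing function is associative and that the initial value is its identity element; parallel_fold, chunked, and chunk_size are names made up for this example. Threads are used only to show the shape of the computation, since CPython's global interpreter lock keeps pure-Python arithmetic from actually running in parallel.

```python
from concurrent.futures import ThreadPoolExecutor
from functools import reduce

def add(a, b):
  return a + b

def chunked(seq, size):
  "Split a list into consecutive slices of at most `size` elements"
  return [seq[i:i + size] for i in range(0, len(seq), size)]

def parallel_fold(f, seq, identity, chunk_size=250):
  """Fold each chunk separately, then fold the partial results.

  Assumes f is associative and `identity` is its identity element;
  otherwise the answer can differ from a sequential fold.
  """
  chunks = chunked(seq, chunk_size)
  with ThreadPoolExecutor() as pool:
    # pool.map keeps the chunks in order, so the partials line up
    partials = list(pool.map(lambda chunk: reduce(f, chunk, identity), chunks))
  return reduce(f, partials, identity)

total = parallel_fold(add, list(range(1000)), 0)
print(total)  # 499500
```

Addition qualifies because regrouping the additions doesn't change the sum. Something like subtraction would not.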

Ok, that’s it for today. I hope you learned something cool!