Ruby’s Enumerator Demystified
In the past couple of weeks I’ve discussed Ruby’s #enum_for method with a few people, so I thought it would be a good idea to write up a description of what the method is and how it works.
To motivate the discussion, let’s consider a simple problem: given an array, we want to return a new array that contains the sum of each consecutive, non-overlapping pair of elements in the original array. We’ll call our method #pair_sums. Here’s how the method should work:
[1, 2, 3, 4].pair_sums
# => [3, 7]
[1, 2, 3, 4, 5, 6].pair_sums
# => [3, 7, 11]
A Simple Solution
Below is a first pass solution to this problem. For simplicity, I’m assuming the array contains only integers and has an even number of elements.
class Array
  def pair_sums
    pair_sums = []
    self.each_slice(2) do |a, b|
      pair_sums << a + b
    end
    pair_sums
  end
end
The code above is alright, but it’s annoying that we have to create a new array in a local variable. It would be much nicer if we could combine #each_slice with #map to build up a new array of the sums without having to create a local variable.
A Better Solution
Now let’s use #enum_for to get rid of that pesky local variable from the previous example.
class Array
  def pair_sums
    self.enum_for(:each_slice, 2).map do |a, b|
      a + b
    end
  end
end
So what the heck is going on here? What is the #enum_for method doing? In a nutshell, it’s creating an instance of the Enumerator class and returning it. The Enumerator instance returned includes the Enumerable module and defines the #each method. The trick is that the Enumerator’s #each method is defined as calling #each_slice on our array with the argument 2.
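Before building our own version, it may help to poke at the object that the built-in #enum_for hands back. This is only a quick sketch; the variable name enum is just for illustration.

```ruby
# Wrap Array#each_slice(2) in an Enumerator without running it yet.
enum = [1, 2, 3, 4].enum_for(:each_slice, 2)

enum.class                  # => Enumerator
enum.is_a?(Enumerable)      # => true

# The Enumerator's #each behaves like each_slice(2) on the array,
# so every Enumerable method sees the array two items at a time.
enum.to_a                   # => [[1, 2], [3, 4]]
enum.map { |a, b| a + b }   # => [3, 7]
```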
I know, I know, it’s confusing. To help, let’s implement our own #my_enum_for method and MyEnumerator class.
Our Implementation
First, let’s define #my_enum_for on Array. It will take a method and some number of arguments and will return an instance of MyEnumerator.
class Array
  def my_enum_for(method, *args)
    MyEnumerator.new(self, method, *args)
  end
end
Note that we instantiate MyEnumerator with the instance of Array on which we called #my_enum_for. We also pass along the method and arguments.
Next, we’ll have a look at the implementation of MyEnumerator. Take a second to read it over and then read the detailed description below.
class MyEnumerator
  include Enumerable

  def initialize(object, method, *args)
    @object, @method, @args = object, method, args
  end

  def each
    @object.send(@method, *@args) do |*block_args|
      yield(*block_args)
    end
  end
end
The first thing we do in our MyEnumerator class is mix in the Enumerable module. When a class mixes in the Enumerable module, it agrees to a contract: the class must implement the #each method. If this contract is met, the Enumerable module provides a large number of enumeration methods (such as #map, #select, #inject, etc.) for free based on the implementation of #each.
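As a quick aside, here is the Enumerable contract in isolation. The Countdown class is hypothetical and exists only for this illustration; the point is that it defines nothing but #each, yet #map, #select and #inject all work.

```ruby
# Hypothetical class, used only to illustrate the Enumerable contract.
class Countdown
  include Enumerable

  def initialize(start)
    @start = start
  end

  # The one method Enumerable needs from us.
  def each
    @start.downto(1) { |n| yield n }
  end
end

countdown = Countdown.new(3)
countdown.map { |n| n * 2 }   # => [6, 4, 2]
countdown.select(&:odd?)      # => [3, 1]
countdown.inject(:+)          # => 6
```

MyEnumerator relies on exactly the same contract; the only difference is what its #each does.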
In the #initialize method, we simply store the object, method and arguments in instance variables. We’ll use these in our #each method.
Finally, let’s look at the implementation of #each. Using the instance variables we stored in #initialize, we call the given method on our object with the provided arguments. It is expected that the given method also take a block. We then yield the block’s arguments. This means that we are defining #each on our MyEnumerator instance to call the specified method on the object we passed in. In our #pair_sums example, the MyEnumerator instance will call #each_slice with the argument of 2 on our array. Then, for each slice of two items, it will yield the values. Therefore, the behavior of #each mimics the behavior of #each_slice. Because the behavior of #map is dependent on the implementation of #each, we can now #map across each set of two items and sum them up.
With the above definitions, we can redefine #pair_sums using the method we created.
class Array
  def pair_sums
    self.my_enum_for(:each_slice, 2).map do |a, b|
      a + b
    end
  end
end
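Nothing in MyEnumerator is specific to #each_slice. As a rough sketch, the same class can wrap other block-taking methods too. The definitions are repeated here so the snippet runs on its own, and the letters array is made up for the example.

```ruby
class MyEnumerator
  include Enumerable

  def initialize(object, method, *args)
    @object, @method, @args = object, method, args
  end

  def each
    @object.send(@method, *@args) do |*block_args|
      yield(*block_args)
    end
  end
end

class Array
  def my_enum_for(method, *args)
    MyEnumerator.new(self, method, *args)
  end
end

# Overlapping pairs: #each now behaves like each_cons(2).
cons_sums = [1, 2, 3, 4].my_enum_for(:each_cons, 2).map { |a, b| a + b }
# => [3, 5, 7]

# Index-aware mapping: #each now behaves like each_with_index.
letters = %w[a b c]
labels = letters.my_enum_for(:each_with_index).map { |letter, i| "#{i}:#{letter}" }
# => ["0:a", "1:b", "2:c"]
```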
An Even Better Solution
We can make things even nicer. In Ruby 1.8.7 and later, many Enumerable methods called without a block will instead return an Enumerator instance wrapping the called method. You can imagine #each_slice to be defined as follows:
def each_slice(*args)
  if block_given?
    # implementation goes here ...
  else
    # No block given: return an Enumerator that wraps this same call.
    enum_for(:each_slice, *args)
  end
end
We can take advantage of this feature to implement #pair_sums in an even more concise manner.
class Array
  def pair_sums
    self.each_slice(2).map do |a, b|
      a + b
    end
  end
end
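To see the blockless form on its own, here is a short sketch. The last line assumes Ruby 2.4 or later, where Array#sum is available; it is a variation, not part of the solution above, and happens to cope with an odd number of elements.

```ruby
# Calling each_slice without a block returns an Enumerator.
slices = [1, 2, 3, 4].each_slice(2)
slices.class   # => Enumerator
slices.to_a    # => [[1, 2], [3, 4]]

# Each slice is itself an array, so it can be summed directly.
# With five elements the final slice is [5], which sums to 5.
odd_sums = [1, 2, 3, 4, 5].each_slice(2).map(&:sum)
# => [3, 7, 5]
```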
Some Final Notes
Just a few quick notes for wrap-up:
- The MyEnumerator class above is not the actual implementation of Enumerator. However, I find it helpful to think of the behavior this way.
- The Enumerator class has another instantiation form in which a block is passed to the #new method. It allows the Enumerator to be used similarly to a Python generator. That behavior is beyond the scope of this post.
Drew