Saturday, May 8, 2010

Fixed point combinator in Java

In combinatory logic there is an interesting construct called the Y combinator, which is used to implement recursion. It is also known as a fixed point combinator because it computes a fixed point of a function, i.e. a value x such that f(x) = x.

The combinator Y satisfies Y f = f (Y f), i.e. Y f is a fixed point of f, and it can be defined as Y = S S K (S (K (S S (S (S S K)))) K). The definition of the fixed point combinator in Haskell is much simpler: fix f = f (fix f). The fix function evaluates to f (f (f (...))), which seems to loop indefinitely. However, thanks to lazy evaluation it neither hangs nor causes a stack overflow: the argument fix f passed to f is evaluated only if f actually needs it. For example, fix (\unused -> 0) evaluates to 0.
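Java evaluates arguments eagerly. The closest analog of this laziness is to pass the recursive argument as a thunk. Below is a minimal sketch that assumes Java 8's java.util.function, which did not exist when this post was written. LazyFix is a made-up class name.

```java
import java.util.function.Function;
import java.util.function.Supplier;

public class LazyFix {
    // fix f = f (fix f), with the argument wrapped in a Supplier (a thunk):
    // fix(f) is computed again only if f calls get() on its argument.
    static <T> T fix(Function<Supplier<T>, T> f) {
        return f.apply(() -> fix(f));
    }

    public static void main(String[] args) {
        // Analog of fix (\unused -> 0): the thunk is never forced.
        Integer zero = fix(unused -> 0);
        System.out.println(zero); // 0
    }
}
```

The lambda never calls get() on its argument, so fix is not re-entered and the result is 0, as with the Haskell version.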

Function

factorial = fix (\f n -> if (n==0) then 1 else n * f (n-1))

computes the factorial of n. The lambda passed to fix takes a function f, which will be its own fixed point (the factorial function itself), and a number n. fix passes only one argument to a function that takes two, and the result is a function that takes the remaining argument. If n == 0 it returns 1; otherwise it computes the factorial of n - 1 and multiplies it by n. Interestingly, the recursion is built from an anonymous function, which has no name and therefore cannot call itself directly. With explicit recursion, the factorial function can be written as

factorial = \n -> if (n==0) then 1 else n * factorial (n-1). 
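For comparison, here is a sketch of the same factorial in Java 8 or later, which postdates this post. It uses a hypothetical class LambdaFix whose fix is specialised to functions and eta-expanded. Returning x -> f.apply(fix(f)).apply(x) instead of f.apply(fix(f)) delays the recursive call until the function is actually applied. This stands in for Haskell's laziness.

```java
import java.util.function.Function;

public class LambdaFix {
    // fix specialised to functions. The body is eta-expanded (x -> ...), so
    // fix(f) is only evaluated when the resulting function is applied to x.
    static <A, B> Function<A, B> fix(Function<Function<A, B>, Function<A, B>> f) {
        return x -> f.apply(fix(f)).apply(x);
    }

    // Counterpart of fix (\f n -> if (n==0) then 1 else n * f (n-1)):
    // the anonymous body never names the factorial; it receives itself as self.
    static final Function<Long, Long> FACTORIAL =
            fix(self -> n -> n == 0 ? 1L : n * self.apply(n - 1));

    public static void main(String[] args) {
        System.out.println(FACTORIAL.apply(5L)); // 120
    }
}
```

As in Haskell, the lambda passed to fix is curried. fix supplies only self, and the result is a function still waiting for n.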

I wondered how a fixed point combinator could be implemented in an object-oriented imperative language. I mostly program in Java, so I chose it for the experiment.

Making Java functional

Implementing the fixed point combinator requires laziness. In Java, methods (functions) are not first-class citizens, and the language does not support partial application. So before implementing the fixed point combinator we need to overcome these problems.
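To see why laziness matters, here is a minimal sketch of a literal, strict transcription of fix f = f (fix f). It uses Java 8's UnaryOperator for brevity, and StrictFix is a made-up name. Java evaluates fix(f) before calling f, so even a function that ignores its argument is never reached.

```java
import java.util.function.UnaryOperator;

public class StrictFix {
    // Literal transcription of fix f = f (fix f). Arguments are evaluated
    // eagerly, so fix(f) calls itself before f is ever applied.
    static <T> T fix(UnaryOperator<T> f) {
        return f.apply(fix(f));
    }

    // Returns true if fix(unused -> 0) exhausts the stack, as expected
    // under strict evaluation.
    static boolean overflows() {
        try {
            Integer result = fix(unused -> 0);
            return false;
        } catch (StackOverflowError e) {
            return true;
        }
    }

    public static void main(String[] args) {
        System.out.println(overflows() ? "StackOverflowError" : "terminated");
    }
}
```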

To make functions first-class citizens we must represent them as objects. There are several ways to do this. The first is to create an interface with a single method that represents the function, and then implement it. The interface can be implemented either by an ordinary class or by an anonymous class; in the latter case we can create closures. Here is an example:


interface Function {
    public Object func(Object[] params);
}

// inside some method:
final int addendum = 5; // captured by the anonymous class below, so it must be final
Function addFiveAndMultiply = new Function() {
    public Object func(Object[] params) {
        // parameters arrive as Objects and have to be unpacked and cast by hand
        Integer a = (Integer) params[0];
        Integer b = (Integer) params[1];
        return (addendum + a) * b;
    }
};
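The following is a minimal runnable sketch of how such a function object is called. It wraps the snippet above in a made-up class ClosureDemo, with the captured value passed in as a parameter. Arguments go in as an Object[] and the result must be cast back.

```java
public class ClosureDemo {
    // Same single-method interface as in the snippet above.
    interface Function {
        Object func(Object[] params);
    }

    // Returns a closure that captures addendum (must be final before Java 8).
    static Function makeAddAndMultiply(final int addendum) {
        return new Function() {
            public Object func(Object[] params) {
                Integer a = (Integer) params[0];
                Integer b = (Integer) params[1];
                return (addendum + a) * b;
            }
        };
    }

    public static void main(String[] args) {
        Function addFiveAndMultiply = makeAddAndMultiply(5);
        // Arguments are boxed into an Object[]; the result has to be cast back.
        Integer result = (Integer) addFiveAndMultiply.func(new Object[]{2, 3});
        System.out.println(result); // (5 + 2) * 3 = 21
    }
}
```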

This approach has two drawbacks. First, we need to write boilerplate code to unpack the parameters. Second, the parameters must be objects rather than primitives, although autoboxing mitigates this. Both can be overcome by introducing a separate interface for each signature. For example, for a function that takes an int and a String and returns a String we would create an interface such as FunctionIntStringToString. This gives more type safety and lets us accept parameters in the usual way, but if we use many functions with different signatures there will be even more boilerplate than before.
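Below is a sketch of the per-signature alternative. The interface name FunctionIntStringToString comes from the paragraph above. The method name apply, the class TypedFunctionDemo and the repeat example are hypothetical choices made for illustration.

```java
public class TypedFunctionDemo {
    // One interface per signature: (int, String) -> String.
    // The method name "apply" is an arbitrary choice for this sketch.
    interface FunctionIntStringToString {
        String apply(int n, String s);
    }

    // A closure that repeats s n times, joined by the captured separator.
    static FunctionIntStringToString makeRepeat(final String separator) {
        return new FunctionIntStringToString() {
            public String apply(int n, String s) {
                StringBuilder sb = new StringBuilder();
                for (int i = 0; i < n; i++) {
                    if (i > 0) {
                        sb.append(separator);
                    }
                    sb.append(s);
                }
                return sb.toString();
            }
        };
    }

    public static void main(String[] args) {
        // No Object[] and no casts: the compiler checks argument and result types.
        System.out.println(makeRepeat("-").apply(3, "ab")); // ab-ab-ab
    }
}
```

The call site is now type-checked, but each new signature needs another interface like this one.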

So I decided to try another approach. Java's reflection API has the class Method, which represents a method. I wrapped Method in a class Partial, which also handles partial application. This mechanism does not allow creating closures in place, because we cannot conveniently get a Method for an anonymous class. On the other hand, the functions are written as ordinary methods, and we can wrap any static method in a Partial, even methods of other classes. There is less boilerplate code, and it is kept separate from the functions. The restriction to static methods is only a technical limitation of my implementation and could be removed.
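The following minimal sketch illustrates the reflection calls that such a wrapper builds on, using Class.getMethod and Method.invoke directly. The class ReflectionDemo and its methods add and invokeStatic are made up for the example.

```java
import java.lang.reflect.Method;

public class ReflectionDemo {
    // An ordinary static method: no interface or anonymous class needed.
    public static int add(int a, int b) {
        return a + b;
    }

    // Looks up a public static method by name and parameter types and invokes it.
    // The receiver is null because the method is static; boxed arguments are
    // unboxed on the way in, and a primitive result comes back boxed.
    static Object invokeStatic(Class<?> owner, String name, Class<?>[] types, Object... args) {
        try {
            Method method = owner.getMethod(name, types);
            return method.invoke(null, args);
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException(e);
        }
    }

    public static void main(String[] args) {
        Class<?>[] twoInts = {int.class, int.class};
        System.out.println(invokeStatic(ReflectionDemo.class, "add", twoInts, 2, 3)); // 5
        // Static methods of other classes work the same way.
        System.out.println(invokeStatic(Math.class, "max", twoInts, 7, 4)); // 7
    }
}
```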

Function wrapper 


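import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;

class Partial {
    private final Method method;            // the wrapped static method
    private final List<Object> parameters;  // arguments applied so far

    public Partial(Method method) {
        this.method = method;
        this.parameters = new ArrayList<Object>();
    }

    // Returns a new Partial with one more argument; nothing is invoked yet.
    public Partial apply(Object param) {
        Partial result = new Partial(method);
        result.parameters.addAll(parameters);
        result.parameters.add(param);
        return result;
    }

    // Invokes the method once enough arguments have been applied.
    public Object eval() throws Exception {
        int arity = method.getParameterTypes().length;
        if (parameters.size() < arity) {
            return this; // not enough arguments yet
        }
        Object result = method.invoke(null, parameters.subList(0, arity).toArray());
        if (parameters.size() > arity) {
            // The method returned a function that consumes the remaining arguments.
            // They are added through apply() so the returned Partial is not mutated.
            Partial p = (Partial) result;
            for (Object extra : parameters.subList(arity, parameters.size())) {
                p = p.apply(extra);
            }
            return p.eval();
        } else {
            return result;
        }
    }
}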

This class stores a method and the list of arguments applied to it so far. The apply method adds an argument lazily: it returns a new Partial without calling anything. The eval method tries to execute the function with the arguments applied earlier.


The most interesting part of this class is the eval method, so let's look at it more closely. First it computes the arity, i.e. the number of arguments the function takes. If we have fewer arguments than the arity, nothing can be computed yet and we simply return this Partial. If there are enough arguments, we invoke the function. Note that there can be more arguments than the function consumes. This is possible when the function returns a Partial, which will consume the remaining arguments. In this case we cast the result to Partial, add the remaining arguments, and call eval recursively. The recursion could be replaced with a loop, but that would make the code more verbose. The last branch is executed when the number of arguments equals the arity; then we just return the result.
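The sketch below exercises the three branches of eval. It uses a compact copy of Partial as a nested class, which applies the extra arguments one by one via apply(). The methods add and adder and the helpers wrap and run are hypothetical examples, not part of the original framework.

```java
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;

public class PartialDemo {
    // Compact copy of the Partial wrapper.
    static class Partial {
        private final Method method;
        private final List<Object> parameters = new ArrayList<Object>();

        Partial(Method method) { this.method = method; }

        Partial apply(Object param) {
            Partial result = new Partial(method);
            result.parameters.addAll(parameters);
            result.parameters.add(param);
            return result;
        }

        Object eval() throws Exception {
            int arity = method.getParameterTypes().length;
            if (parameters.size() < arity) return this;             // too few: stay unevaluated
            Object result = method.invoke(null, parameters.subList(0, arity).toArray());
            if (parameters.size() == arity) return result;          // exact arity: plain call
            Partial p = (Partial) result;                           // too many: result must be a Partial
            for (Object extra : parameters.subList(arity, parameters.size())) p = p.apply(extra);
            return p.eval();
        }
    }

    // Wraps a public static method of this class, turning the checked exception into an unchecked one.
    static Partial wrap(String name, Class<?>... types) {
        try {
            return new Partial(PartialDemo.class.getMethod(name, types));
        } catch (NoSuchMethodException e) {
            throw new IllegalArgumentException(e);
        }
    }

    static final Partial ADD = wrap("add", int.class, int.class);
    static final Partial ADDER = wrap("adder", int.class);

    public static int add(int a, int b) { return a + b; }

    // Arity 1, but returns a Partial that still expects one more argument.
    public static Partial adder(int a) { return ADD.apply(a); }

    static Object run(Partial p) {
        try {
            return p.eval();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(run(ADD.apply(2)) instanceof Partial); // true: not enough arguments
        System.out.println(run(ADD.apply(2).apply(3)));           // 5: exact arity
        System.out.println(run(ADDER.apply(2).apply(3)));         // 5: extra argument passed on
    }
}
```

The last call shows the over-application case. adder consumes 2 and returns a Partial, and eval hands it the leftover 3.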

Fixed point

Finally we have a small functional framework, and everything is ready for the initial task. Here we go!


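// Here fix is a static Partial field wrapping this very method
// (it is declared in the full program below).
public static Partial fix(Partial method) {
    return method.apply(fix.apply(method));
}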

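This function is quite close to the Haskell definition fix f = f (fix f). It does not recurse forever because fix.apply(method) does not call fix: it only builds a Partial, and fix is invoked only when eval is later called on that Partial with enough arguments.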

Functional programming in Java

Now that we have functional abstractions and a fixed point combinator, let's write a functional program that computes and prints factorials.


import java.util.ArrayList;
import java.util.List;

public class Functional {
    // Function objects wrapping the static methods below.
    static Partial fix;
    static Partial facBody;
    static Partial mapBody;
    static Partial printFactorial;

    public static void main(String[] args) throws Exception {
        fix = new Partial(Functional.class.getMethod("fix", new Class[]{Partial.class}));
        facBody = new Partial(Functional.class.getMethod("facBody", new Class[]{Partial.class, long.class}));
        mapBody = new Partial(Functional.class.getMethod("mapBody", new Class[]{Partial.class, Partial.class, List.class}));
        printFactorial = new Partial(Functional.class.getMethod("printFactorial", new Class[]{long.class}));

        ArrayList<Long> xs = new ArrayList<Long>();
        for (long i = 0; i < 16; i++) {
            xs.add(i);
        }
        map(printFactorial, xs);
    }

    // fix f = f (fix f); fix.apply(method) only builds a Partial, so nothing loops here
    public static Partial fix(Partial method) {
        return method.apply(fix.apply(method));
    }

    // map f xs = fix mapBody f xs
    public static List map(Partial f, List xs) throws Exception {
        return (List) (fix(mapBody).apply(f)).apply(xs).eval();
    }

    // Non-recursive body of map; p stands for the recursive call.
    public static List mapBody(Partial p, Partial f, List xs) throws Exception {
        if (xs.isEmpty()) return new ArrayList();
        Object head = f.apply(xs.get(0)).eval();
        List rest = (List) (p.apply(f)).apply(xs.subList(1, xs.size())).eval();
        rest.add(0, head);
        return rest;
    }

    public static void printFactorial(long n) throws Exception {
        System.out.printf("factorial(%d) = %d%n", n, factorial(n));
    }

    // factorial = fix facBody
    public static long factorial(long n) throws Exception {
        return (Long) fix(facBody).apply(n).eval();
    }

    // Non-recursive body of factorial; partial stands for the recursive call.
    public static long facBody(Partial partial, long n) throws Exception {
        return n == 0 ? 1 : n * ((Long) (partial.apply(n - 1).eval()));
    }
}

Here the loop that prints the factorials is replaced with the function map. The program could be made to look even more functional, for example by replacing Java conditionals with a function cond, or by generating the list with an analog of unfoldr.
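Below is a rough sketch of those two ideas. To keep it short, it uses Java 8 lambdas rather than the Partial class. The names cond and unfoldr come from the paragraph above, but their signatures, the Map.Entry pair encoding, and the helpers range and MoreFunctional are assumptions of this sketch.

```java
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.Supplier;

public class MoreFunctional {
    // cond as a function: both branches are thunks, so only the chosen one runs.
    static <T> T cond(boolean test, Supplier<T> ifTrue, Supplier<T> ifFalse) {
        return test ? ifTrue.get() : ifFalse.get();
    }

    // Analog of Haskell's unfoldr: step returns empty to stop, or an
    // (element, next seed) pair to continue. Written as a loop to avoid deep recursion.
    static <A, B> List<A> unfoldr(Function<B, Optional<Map.Entry<A, B>>> step, B seed) {
        List<A> result = new ArrayList<A>();
        Optional<Map.Entry<A, B>> next = step.apply(seed);
        while (next.isPresent()) {
            result.add(next.get().getKey());
            next = step.apply(next.get().getValue());
        }
        return result;
    }

    // The list [from, until) produced by unfoldr instead of an explicit loop at the call site.
    static List<Long> range(long from, final long until) {
        Function<Long, Optional<Map.Entry<Long, Long>>> step = i -> i < until
                ? Optional.<Map.Entry<Long, Long>>of(new AbstractMap.SimpleImmutableEntry<Long, Long>(i, i + 1))
                : Optional.<Map.Entry<Long, Long>>empty();
        return unfoldr(step, from);
    }

    // Factorial with cond instead of the ?: operator.
    static long factorial(final long n) {
        return cond(n == 0, () -> 1L, () -> n * factorial(n - 1));
    }

    public static void main(String[] args) {
        range(0, 16).forEach(n -> System.out.printf("factorial(%d) = %d%n", n, factorial(n)));
    }
}
```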

Conclusion: Java is able to express some concepts of functional programming and combinatory logic.