
End of Line blog

Thoughts on software development, by Adam Ruka

Graal Truffle tutorial part 3 – specializations with Truffle DSL, TypeSystem

This article is part of a tutorial on GraalVM's Truffle language implementation framework.


In the previous article, we implemented efficient addition of JavaScript numbers, with specializations handling 32-bit integer and floating-point double cases. However, in order to accomplish that, we had to write over a hundred lines of fairly repetitive code. If we wanted to implement a similar scheme for the other operations like subtraction, multiplication, etc., the amount of code would quickly balloon into thousands of lines.

Fortunately, Truffle ships with a solution to that problem: the Truffle DSL.

The Truffle DSL is a Java annotation processor. If you haven’t heard about that technology before, it’s a standard Java mechanism (JSR 269) for hooking into the Java compilation process and generating additional code that gets compiled alongside your hand-written code. Many libraries in the Java ecosystem use this technique, the best known probably being Lombok, Dagger 2, Google’s Auto, and Immutables.

The idea behind the Truffle DSL is to spare you from writing, in every Node, the boilerplate code that manages the state machine keeping track of the active specializations. Instead, you write only the logic of your specialized operations, and all of the code dealing with the state machine transitions is generated automatically by the annotation processor.

Using the Truffle DSL

To use the Truffle DSL, you have to structure your Node code a little differently than we did in the previous article. Because annotation processors can only generate new Java classes, and not modify existing ones, we need to make the Node classes that use the DSL abstract, and the concrete class will be generated by the DSL.

(And yes, I know that Lombok manages to break this “no modifications of existing classes” restriction. However, that’s done in a way that is not allowed by the annotation processor standard.)

Since our Node class is now abstract, we can’t simply declare fields in our class and annotate them with @Child, as the generator would not know how to initialize them correctly. So, instead of using @Child on fields to denote AST subtrees, we use the @NodeChild annotation on the class itself:

import com.oracle.truffle.api.dsl.NodeChild;

@NodeChild("leftNode")
@NodeChild("rightNode")
public abstract class AdditionNode extends EasyScriptNode {
    // ...
}

This tells the annotation processor that it should generate a Node class with two @Child fields, named leftNode and rightNode. Since AdditionNode extends EasyScriptNode, the type of those fields will be EasyScriptNode – you can change that with the type attribute of the @NodeChild annotation.

If you’re writing your interpreter in a Java version before 8, you need to use the @NodeChildren annotation instead, as Java 8 was the first version that allowed repeating the same annotation multiple times on the same target:

import com.oracle.truffle.api.dsl.NodeChild;
import com.oracle.truffle.api.dsl.NodeChildren;

@NodeChildren({
    @NodeChild("leftNode"),
    @NodeChild("rightNode")
})
public abstract class AdditionNode extends EasyScriptNode {
    // ...
}

Now, in the body of the node class, you don’t implement any of the execute*() methods. Instead, you write instance methods that implement the specializations you want the node to handle. Each of those methods takes as arguments the values obtained from executing that node’s children, with the assumption that the given specialization is active. Because of that, the number of arguments to each method must be the same as the number of @NodeChild annotations placed on the class (in our case, that’s 2). The actual execute*() methods will be implemented in the generated class by delegating to our methods – after evaluating the child nodes, and verifying the given specialization is active, of course.
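To make that delegation concrete, here is a plain-Java sketch of the pattern with no Truffle dependency. Everything in it is hypothetical: the children are modelled as `Supplier<Object>` instead of `EasyScriptNode`, the "generated" subclass is written by hand, and there is no state machine, no `rewriteOn` handling and no locking. It only shows an execute method that evaluates the children first, and then calls the specialization method whose parameter types match the values.

```java
import java.util.function.Supplier;

public class DelegationSketch {
    // The hand-written part: only the specialization methods, like the abstract AdditionNode.
    abstract static class AdditionLogic {
        protected int addInts(int left, int right) {
            return Math.addExact(left, right);
        }

        protected double addDoubles(double left, double right) {
            return left + right;
        }

        abstract Object execute();
    }

    // Hypothetical, hand-written stand-in for the class the annotation processor would generate.
    static final class AdditionLogicGen extends AdditionLogic {
        private final Supplier<Object> leftChild;
        private final Supplier<Object> rightChild;

        AdditionLogicGen(Supplier<Object> leftChild, Supplier<Object> rightChild) {
            this.leftChild = leftChild;
            this.rightChild = rightChild;
        }

        @Override
        Object execute() {
            // 1. Evaluate the children.
            Object left = leftChild.get();
            Object right = rightChild.get();
            // 2. Delegate to the specialization whose parameter types match the values.
            if (left instanceof Integer && right instanceof Integer) {
                return addInts((Integer) left, (Integer) right);
            }
            if (left instanceof Double && right instanceof Double) {
                return addDoubles((Double) left, (Double) right);
            }
            throw new IllegalArgumentException("no specialization for " + left + " and " + right);
        }
    }

    public static void main(String[] args) {
        AdditionLogic ints = new AdditionLogicGen(() -> 2, () -> 3);
        AdditionLogic doubles = new AdditionLogicGen(() -> 1.5, () -> 2.0);
        System.out.println(ints.execute());    // 5
        System.out.println(doubles.execute()); // 3.5
    }
}
```

The point of the split is that the abstract class stays readable and contains only the semantics, while everything mechanical lives in the subclass.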

The methods need to have the @Specialization annotation placed on them, which allows you to customize the generated code; we’ll discuss some of those customizations below. The implementations of the methods themselves are pretty simple; you perform the operation, return the result, and that’s pretty much it (they are meant to be inlined during JIT compilation).

With that said, here’s what AdditionNode looks like in EasyScript, handling 32-bit integer and double addition:

import com.oracle.truffle.api.dsl.NodeChild;
import com.oracle.truffle.api.dsl.Specialization;

@NodeChild("leftNode")
@NodeChild("rightNode")
public abstract class AdditionNode extends EasyScriptNode {
    @Specialization(rewriteOn = ArithmeticException.class)
    protected int addInts(int leftValue, int rightValue) {
        return Math.addExact(leftValue, rightValue);
    }

    @Specialization(replaces = "addInts")
    protected double addDoubles(double leftValue, double rightValue) {
        return leftValue + rightValue;
    }
}

That’s it! These few lines of code are equivalent to the more than 100 lines of the AdditionNode class from the previous article.

We used a couple of customizations that @Specialization offers. First, we used the rewriteOn attribute to deactivate the int specialization when addInts() throws ArithmeticException. If we didn’t do that, the ArithmeticException thrown by Math.addExact() on overflow would propagate out of the node and simply terminate our interpreter! Second, with the replaces attribute we indicated that the double specialization is a superset of the int one; otherwise, both of them could be active at the same time, which would result in sub-optimal machine code being generated.
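The reason rewriteOn is needed is that Math.addExact() reports overflow by throwing ArithmeticException, instead of silently wrapping around like the + operator on ints does. The plain-Java sketch below (class and method names are made up for illustration) shows that difference, and roughly mimics the effect of rewriteOn by falling back to double addition when the exception is thrown. Unlike the generated node, it does not remember that the int path failed, so it retries it on every call.

```java
public class AddExactSketch {
    // The int fast path: throws ArithmeticException on overflow, like addInts above.
    static int addInts(int left, int right) {
        return Math.addExact(left, right);
    }

    // The double path, like addDoubles above.
    static double addDoubles(double left, double right) {
        return left + right;
    }

    // Roughly what rewriteOn achieves: try the int path, fall back to doubles on overflow.
    static Object add(int left, int right) {
        try {
            return addInts(left, right);
        } catch (ArithmeticException e) {
            return addDoubles(left, right);
        }
    }

    public static void main(String[] args) {
        System.out.println(Integer.MAX_VALUE + 1);     // plain int addition wraps: -2147483648
        System.out.println(add(1, 2));                 // 3
        System.out.println(add(Integer.MAX_VALUE, 1)); // 2.147483648E9
    }
}
```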

Now, you might be curious: how do you actually use this new class? The DSL generates it with a predictable name, formed by appending Gen to the name of the abstract class. So, in our example, the generated class is called AdditionNodeGen. That class contains a static factory method, create(), which returns an AdditionNode, and you use that method to create instances of the generated node class. create() takes as many arguments as there are @NodeChild annotations on your abstract class.

Here’s a simple test, showing our addition working correctly:

import com.oracle.truffle.api.CallTarget;

import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;

public class OverflowTest {
    @Test
    public void adding_1_to_int_max_does_not_overflow() {
        EasyScriptNode exprNode = AdditionNodeGen.create(
                new DoubleLiteralNode(Integer.MAX_VALUE),
                new DoubleLiteralNode(1));
        var rootNode = new EasyScriptRootNode(exprNode);
        CallTarget callTarget = rootNode.getCallTarget();

        var result = callTarget.call();

        assertEquals(Integer.MAX_VALUE + 1D, result);
    }
}

Truffle DSL code

If you’re curious, you can look at the code generated by the DSL for AdditionNodeGen:

import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
import com.oracle.truffle.api.dsl.GeneratedBy;
import com.oracle.truffle.api.dsl.UnsupportedSpecializationException;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.Node;
import com.oracle.truffle.api.nodes.NodeCost;
import com.oracle.truffle.api.nodes.UnexpectedResultException;
import java.util.concurrent.locks.Lock;

@GeneratedBy(AdditionNode.class)
public final class AdditionNodeGen extends AdditionNode {
    public static AdditionNode create(EasyScriptNode leftNode, EasyScriptNode rightNode) {
        return new AdditionNodeGen(leftNode, rightNode);
    }

    @Child private EasyScriptNode leftNode_;
    @Child private EasyScriptNode rightNode_;
    @CompilationFinal private int state_;
    @CompilationFinal private int exclude_;

    private AdditionNodeGen(EasyScriptNode leftNode, EasyScriptNode rightNode) {
        this.leftNode_ = leftNode;
        this.rightNode_ = rightNode;
    }

    @Override
    public double executeDouble(VirtualFrame frameValue) {
        int state = state_;
        double leftNodeValue_ = this.leftNode_.executeDouble(frameValue);
        double rightNodeValue_ = this.rightNode_.executeDouble(frameValue);
        if ((state & 0b10) != 0 /* is-active addDoubles(double, double) */) {
            return addDoubles(leftNodeValue_, rightNodeValue_);
        }
        CompilerDirectives.transferToInterpreterAndInvalidate();
        return (double) executeAndSpecialize(leftNodeValue_, rightNodeValue_);
    }

    @Override
    public Object executeGeneric(VirtualFrame frameValue) {
        int state = state_;
        if ((state & 0b10) == 0 /* only-active addInts(int, int) */ && state != 0  /* is-not addInts(int, int) && addDoubles(double, double) */) {
            return executeGeneric_int_int0(frameValue, state);
        } else if ((state & 0b1) == 0 /* only-active addDoubles(double, double) */ && state != 0  /* is-not addInts(int, int) && addDoubles(double, double) */) {
            return executeGeneric_double_double1(frameValue, state);
        } else {
            return executeGeneric_generic2(frameValue, state);
        }
    }

    private Object executeGeneric_int_int0(VirtualFrame frameValue, int state) {
        // skipped for brevity - very similar to executeInt()...
    }

    private Object executeGeneric_double_double1(VirtualFrame frameValue, int state) {
        // skipped for brevity - very similar to executeDouble()...
    }

    private Object executeGeneric_generic2(VirtualFrame frameValue, int state) {
        // skipped for brevity - very similar to executeAndSpecialize()...
    }

    @Override
    public int executeInt(VirtualFrame frameValue) throws UnexpectedResultException {
        int state = state_;
        int leftNodeValue_;
        try {
            leftNodeValue_ = this.leftNode_.executeInt(frameValue);
        } catch (UnexpectedResultException ex) {
            Object rightNodeValue = this.rightNode_.executeGeneric(frameValue);
            return expectInteger(executeAndSpecialize(ex.getResult(), rightNodeValue));
        }
        int rightNodeValue_;
        try {
            rightNodeValue_ = this.rightNode_.executeInt(frameValue);
        } catch (UnexpectedResultException ex) {
            return expectInteger(executeAndSpecialize(leftNodeValue_, ex.getResult()));
        }
        if ((state & 0b1) != 0 /* is-active addInts(int, int) */) {
            try {
                return addInts(leftNodeValue_, rightNodeValue_);
            } catch (ArithmeticException ex) {
                // implicit transferToInterpreterAndInvalidate()
                Lock lock = getLock();
                lock.lock();
                try {
                    this.exclude_ = this.exclude_ | 0b1 /* add-excluded addInts(int, int) */;
                    this.state_ = this.state_ & 0xfffffffe /* remove-active addInts(int, int) */;
                } finally {
                    lock.unlock();
                }
                return expectInteger(executeAndSpecialize(leftNodeValue_, rightNodeValue_));
            }
        }
        CompilerDirectives.transferToInterpreterAndInvalidate();
        return expectInteger(executeAndSpecialize(leftNodeValue_, rightNodeValue_));
    }

    private Object executeAndSpecialize(Object leftNodeValue, Object rightNodeValue) {
        Lock lock = getLock();
        boolean hasLock = true;
        lock.lock();
        int state = state_;
        int exclude = exclude_;
        try {
            if ((exclude) == 0 /* is-not-excluded addInts(int, int) */ && leftNodeValue instanceof Integer) {
                int leftNodeValue_ = (int) leftNodeValue;
                if (rightNodeValue instanceof Integer) {
                    int rightNodeValue_ = (int) rightNodeValue;
                    this.state_ = state = state | 0b1 /* add-active addInts(int, int) */;
                    try {
                        lock.unlock();
                        hasLock = false;
                        return addInts(leftNodeValue_, rightNodeValue_);
                    } catch (ArithmeticException ex) {
                        // implicit transferToInterpreterAndInvalidate()
                        lock.lock();
                        try {
                            this.exclude_ = this.exclude_ | 0b1 /* add-excluded addInts(int, int) */;
                            this.state_ = this.state_ & 0xfffffffe /* remove-active addInts(int, int) */;
                        } finally {
                            lock.unlock();
                        }
                        return executeAndSpecialize(leftNodeValue_, rightNodeValue_);
                    }
                }
            }
            if (leftNodeValue instanceof Double) {
                double leftNodeValue_ = (double) leftNodeValue;
                if (rightNodeValue instanceof Double) {
                    double rightNodeValue_ = (double) rightNodeValue;
                    this.exclude_ = exclude = exclude | 0b1 /* add-excluded addInts(int, int) */;
                    state = state & 0xfffffffe /* remove-active addInts(int, int) */;
                    this.state_ = state = state | 0b10 /* add-active addDoubles(double, double) */;
                    lock.unlock();
                    hasLock = false;
                    return addDoubles(leftNodeValue_, rightNodeValue_);
                }
            }
            throw new UnsupportedSpecializationException(this, new Node[] {this.leftNode_, this.rightNode_}, leftNodeValue, rightNodeValue);
        } finally {
            if (hasLock) {
                lock.unlock();
            }
        }
    }

    @Override
    public NodeCost getCost() {
        int state = state_;
        if (state == 0b0) {
            return NodeCost.UNINITIALIZED;
        } else if ((state & (state - 1)) == 0 /* is-single-active  */) {
            return NodeCost.MONOMORPHIC;
        }
        return NodeCost.POLYMORPHIC;
    }

    private static int expectInteger(Object value) throws UnexpectedResultException {
        if (value instanceof Integer) {
            return (int) value;
        }
        throw new UnexpectedResultException(value);
    }
}

I’ve omitted a few of the more repetitive fragments of the code to save space, but you can see it’s pretty similar to the AdditionNode we’ve written manually in the previous article.

There are a few obvious differences. It uses a bitset kept in an int to store the different states, instead of an enum, as that’s faster and smaller, and doesn’t require an additional class for each Node type. The actual bit mask representing each state (0b1, 0b10, etc.) is repeated every time it’s used, but hey – this is generated code, we don’t care about that sort of duplication! There are also two bitsets: one for the active specializations, and one for the excluded ones. The latter is needed because of the rewriteOn and replaces attributes of the @Specialization annotation: the int specialization becomes excluded either after it throws ArithmeticException, or when the double specialization replaces it. Naturally, both bitset fields are marked with @CompilationFinal, like our enum-based state field was in the hand-written code.
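The bitset bookkeeping can be reproduced in a few lines of plain Java. This is a hypothetical sketch, not the DSL’s output: it reuses the masks from AdditionNodeGen (0b1 for addInts, 0b10 for addDoubles), the same "remove-active" mask 0xfffffffe, and the same (state & (state - 1)) == 0 test that getCost() uses to check whether exactly one specialization is active. The String return values of cost() stand in for the NodeCost enum.

```java
public class SpecializationBits {
    static final int ADD_INTS = 0b1;     // bit for addInts(int, int)
    static final int ADD_DOUBLES = 0b10; // bit for addDoubles(double, double)

    int state;   // active specializations
    int exclude; // specializations that must never be activated again

    // Returns false, and changes nothing, if the specialization was excluded earlier.
    boolean activate(int bit) {
        if ((exclude & bit) != 0) {
            return false;
        }
        state |= bit;
        return true;
    }

    // What happens when addInts hits its rewriteOn exception, or gets replaced by addDoubles.
    void excludeAddInts() {
        exclude |= ADD_INTS;
        state &= 0xfffffffe; // same as state & ~ADD_INTS
    }

    String cost() {
        if (state == 0) {
            return "UNINITIALIZED";
        }
        // Clearing the lowest set bit leaves 0 only if exactly one bit was set.
        if ((state & (state - 1)) == 0) {
            return "MONOMORPHIC";
        }
        return "POLYMORPHIC";
    }

    public static void main(String[] args) {
        SpecializationBits bits = new SpecializationBits();
        System.out.println(bits.cost()); // UNINITIALIZED
        bits.activate(ADD_INTS);
        System.out.println(bits.cost()); // MONOMORPHIC
        bits.excludeAddInts();
        bits.activate(ADD_DOUBLES);
        System.out.println(bits.state + " " + bits.exclude + " " + bits.cost()); // 2 1 MONOMORPHIC
    }
}
```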

You might also notice that all mutations of the state fields are protected with a Lock provided by Truffle’s Node superclass. This is extra assurance in case the interpreter runs a multi-threaded program. This clearly shows the advantage of generated code – using the lock requires a lot of boilerplate code (you have to store a reference to it in a local variable, you need to unlock it in a finally block, etc.), and so I skipped it in the manually written version, but the DSL can go this extra mile, taking care of all the boilerplate code for us.
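For a sense of that boilerplate, here is the locking pattern around each state change as a standalone sketch using java.util.concurrent.locks.ReentrantLock. The GuardedState class is hypothetical; in a real node, the lock would come from Node’s getLock() rather than a field.

```java
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

public class GuardedState {
    private final Lock stateLock = new ReentrantLock(); // a real node would call getLock()
    private int state;   // active specializations
    private int exclude; // excluded specializations

    // Same shape as the catch (ArithmeticException) block in the generated executeInt().
    void excludeSpecialization(int bit) {
        Lock lock = stateLock; // keep the lock in a local variable
        lock.lock();
        try {
            exclude |= bit;
            state &= ~bit;
        } finally {
            lock.unlock(); // released even if the body throws
        }
    }

    // Activates a specialization unless it has been excluded, again under the lock.
    void activateSpecialization(int bit) {
        Lock lock = stateLock;
        lock.lock();
        try {
            if ((exclude & bit) == 0) {
                state |= bit;
            }
        } finally {
            lock.unlock();
        }
    }

    int state() {
        return state;
    }

    int exclude() {
        return exclude;
    }
}
```

Every mutation needs the same lock, try and finally scaffolding, which is exactly the kind of repetition that is tedious by hand and free for a code generator.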

Apart from those differences, if you compare the executeInt() and executeDouble() methods between the manually written and generated versions, you’ll see they’re basically identical.

Even though the generated code tries to be as efficient as possible, you can clearly see that some care has been taken to make it nicer to read: it includes comments showing which bitmask corresponds to which state, as well as the small helper method expectInteger().

The TypeSystem class

Now, if you look closely at the test I’ve written above, you can see that I’ve cheated a little: I’ve used DoubleLiteralNode, even though the values I used, Integer.MAX_VALUE and 1, could be used with IntLiteralNode.

If we attempt to write a test for overflow using IntLiteralNode:

import com.oracle.truffle.api.CallTarget;

import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;

public class OverflowTest {
    @Test
    public void adding_1_to_int_max_does_not_overflow() {
        EasyScriptNode exprNode = AdditionNodeGen.create(
                new IntLiteralNode(Integer.MAX_VALUE),
                new IntLiteralNode(1));
        var rootNode = new EasyScriptRootNode(exprNode);
        CallTarget callTarget = rootNode.getCallTarget();

        var result = callTarget.call();

        assertEquals(Integer.MAX_VALUE + 1D, result);
    }
}

It will actually fail with an exception:

com.oracle.truffle.api.dsl.UnsupportedSpecializationException:
Unexpected values provided for AdditionNodeGen@62230c58: [2147483647, 1], [Integer,Integer]

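This is thrown from the executeAndSpecialize() method we’ve seen above in AdditionNodeGen. Let’s rewrite that method slightly, dropping the locking and replacing the bitmask operations with pseudocode (excludedStates, activeStates, INT_STATE and DOUBLE_STATE don’t exist in the real generated class), to make it easier to see exactly what the problem is: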

import com.oracle.truffle.api.dsl.UnsupportedSpecializationException;
import com.oracle.truffle.api.nodes.Node;

public final class AdditionNodeGen extends AdditionNode {
    // ...

    private Object executeAndSpecialize(Object leftNodeValue, Object rightNodeValue) {
        if (!this.excludedStates.contains(INT_STATE) &&
                leftNodeValue instanceof Integer &&
                rightNodeValue instanceof Integer) {
            int leftNodeValue_  = (int) leftNodeValue;
            int rightNodeValue_ = (int) rightNodeValue;
            this.activeStates.add(INT_STATE);
            try {
                return this.addInts(leftNodeValue_, rightNodeValue_);
            } catch (ArithmeticException ex) {
                this.excludedStates.add(INT_STATE);
                this.activeStates.remove(INT_STATE);
                return this.executeAndSpecialize(leftNodeValue_, rightNodeValue_);
            }
        }

        if (leftNodeValue instanceof Double &&
                rightNodeValue instanceof Double) {
            double leftNodeValue_  = (double) leftNodeValue;
            double rightNodeValue_ = (double) rightNodeValue;
            this.excludedStates.add(INT_STATE);
            this.activeStates.remove(INT_STATE);
            this.activeStates.add(DOUBLE_STATE);
            return this.addDoubles(leftNodeValue_, rightNodeValue_);
        }

        throw new UnsupportedSpecializationException(this, new Node[] {this.leftNode_, this.rightNode_}, leftNodeValue, rightNodeValue);
    }

    // ...
}

With this, we can clearly see what happens:

  1. The first if succeeds, and addInts is called.
  2. addInts calls Math.addExact(), which throws ArithmeticException.
  3. The catch executes, excluding the int specialization, and calls executeAndSpecialize recursively.
  4. The first if now fails, because the int specialization has been excluded.
  5. But the second if also fails, because integers are not instances of Double!
  6. UnsupportedSpecializationException is thrown.
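The steps above can be reproduced with a self-contained, deliberately simplified Java model of executeAndSpecialize(). The intsExcluded flag and the IllegalStateException standing in for UnsupportedSpecializationException are assumptions of this sketch; what matters is that the double check is a plain instanceof Double, so two Integers can never reach the double branch.

```java
public class SpecializeWithoutTypeSystem {
    private boolean intsExcluded = false;

    Object executeAndSpecialize(Object left, Object right) {
        // Step 1: the int branch is taken while it is not excluded.
        if (!intsExcluded && left instanceof Integer && right instanceof Integer) {
            try {
                // Step 2: addExact throws ArithmeticException on overflow.
                return Math.addExact((int) left, (int) right);
            } catch (ArithmeticException e) {
                // Step 3: exclude the int branch and try again.
                intsExcluded = true;
                return executeAndSpecialize(left, right);
            }
        }
        // Step 5: an Integer is not an instance of Double, so this check fails too.
        if (left instanceof Double && right instanceof Double) {
            return (double) left + (double) right;
        }
        // Step 6: stand-in for UnsupportedSpecializationException.
        throw new IllegalStateException("Unexpected values provided: [" + left + ", " + right + "]");
    }
}
```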

So, the crucial change we need to make is in step 5. We need to somehow tell the Truffle DSL that it should treat Integer objects as valid Doubles. And there is a way to do that – with the @TypeSystem annotation.

@TypeSystem is an annotation you place on an abstract class. Inside that class, you can write static methods that allow you to express various relationships between the types in the language you’re implementing.

The first kind is a type check. Instead of the DSL generating simple instanceof expressions for checking whether a value is of a given type, like we’ve seen above, the DSL will instead call your method. The method needs to be annotated with the @TypeCheck annotation, which takes one attribute – the class of the type we’re checking for. The method must take a single argument of type Object, and return a boolean answering whether the provided value is of the type provided in @TypeCheck. In our case, we should return true if it’s either a Double, or an Integer:

import com.oracle.truffle.api.dsl.TypeCheck;
import com.oracle.truffle.api.dsl.TypeSystem;

@TypeSystem
public abstract class EasyScriptTypeSystem {
    @TypeCheck(double.class)
    public static boolean isDouble(Object value) {
        return value instanceof Double || value instanceof Integer;
    }

    // ...
}

The second kind of method is a type cast. Once the type check denotes that a given value is considered a given type, the DSL will attempt to convert it to that type. By default, it just uses the built-in Java cast (like (double) myValue). Of course, that default cast fails when attempting to cast an Integer to double. Because of that, the Truffle DSL allows us to implement our own conversion logic, in a static method annotated with @TypeCast:

import com.oracle.truffle.api.dsl.TypeCast;
import com.oracle.truffle.api.dsl.TypeSystem;

@TypeSystem
public abstract class EasyScriptTypeSystem {
    // ...

    @TypeCast(double.class)
    public static double asDouble(Object value) {
        if (value instanceof Integer) {
            return ((Integer) value).doubleValue();
        } else {
            return (double) value;
        }
    }
}
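You can see why both methods are needed without involving Truffle at all. In plain Java, casting an Object that holds an Integer straight to double throws ClassCastException, because the cast is compiled as a cast to Double followed by unboxing; going through check and cast methods written like isDouble() and asDouble() above works. The TypeSystemSketch class below copies their logic without the Truffle annotations.

```java
public class TypeSystemSketch {
    // Same logic as the @TypeCheck method above.
    static boolean isDouble(Object value) {
        return value instanceof Double || value instanceof Integer;
    }

    // Same logic as the @TypeCast method above.
    static double asDouble(Object value) {
        if (value instanceof Integer) {
            return ((Integer) value).doubleValue();
        } else {
            return (double) value;
        }
    }

    public static void main(String[] args) {
        Object boxedInt = Integer.MAX_VALUE;
        try {
            double d = (double) boxedInt; // the default cast: Object -> Double -> double
            System.out.println(d);
        } catch (ClassCastException e) {
            System.out.println("default cast failed: " + e.getClass().getSimpleName());
        }
        if (isDouble(boxedInt) && isDouble(1)) {
            System.out.println(asDouble(boxedInt) + asDouble(1)); // 2.147483648E9
        }
    }
}
```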

So, this defines a type system, but how do we actually use it? By placing the @TypeSystemReference annotation on the Node classes, with the class annotated with @TypeSystem as its only attribute. Since subclasses inherit this annotation, it’s usually placed on the abstract Node superclass:

import com.oracle.truffle.api.dsl.TypeSystemReference;
import com.oracle.truffle.api.nodes.Node;

@TypeSystemReference(EasyScriptTypeSystem.class)
public abstract class EasyScriptNode extends Node {
    // ...
}

If you now run the above adding_1_to_int_max_does_not_overflow test, it should pass! If you examine the AdditionNodeGen class, you should see that executeAndSpecialize() now uses our methods:

import com.oracle.truffle.api.dsl.UnsupportedSpecializationException;
import com.oracle.truffle.api.nodes.Node;

public final class AdditionNodeGen extends AdditionNode {
    // ...

    private Object executeAndSpecialize(Object leftNodeValue, Object rightNodeValue) {
        if (!this.excludedStates.contains(INT_STATE) &&
                leftNodeValue instanceof Integer &&
                rightNodeValue instanceof Integer) {
            int leftNodeValue_  = (int) leftNodeValue;
            int rightNodeValue_ = (int) rightNodeValue;
            this.activeStates.add(INT_STATE);
            try {
                return this.addInts(leftNodeValue_, rightNodeValue_);
            } catch (ArithmeticException ex) {
                this.excludedStates.add(INT_STATE);
                this.activeStates.remove(INT_STATE);
                return this.executeAndSpecialize(leftNodeValue_, rightNodeValue_);
            }
        }

        if (EasyScriptTypeSystem.isDouble(leftNodeValue) &&
                EasyScriptTypeSystem.isDouble(rightNodeValue)) {
            double leftNodeValue_ = EasyScriptTypeSystem.asDouble(leftNodeValue);
            double rightNodeValue_ = EasyScriptTypeSystem.asDouble(rightNodeValue);
            this.excludedStates.add(INT_STATE);
            this.activeStates.remove(INT_STATE);
            this.activeStates.add(DOUBLE_STATE);
            return this.addDoubles(leftNodeValue_, rightNodeValue_);
        }

        throw new UnsupportedSpecializationException(this, new Node[] {this.leftNode_, this.rightNode_}, leftNodeValue, rightNodeValue);
    }

    // ...
}

(instanceof is still used for ints, as we didn’t instruct the Truffle DSL to do anything special for them)

@ImplicitCast

While our two-method combination of @TypeCheck and @TypeCast works, it’s not ideal, for two reasons:

  1. We had to write two separate methods with pretty coupled logic, which violates the DRY principle.
  2. The complicated check logic like v instanceof Double || v instanceof Integer might be difficult for Graal to optimize away, and thus have an adverse effect on the performance of the generated native code when JITting.

For these reasons, there is a different way to express in the @TypeSystem class the fact that a given type can be treated as another type: the @ImplicitCast annotation. It can be placed on a static method that performs the conversion between the types:

import com.oracle.truffle.api.dsl.ImplicitCast;
import com.oracle.truffle.api.dsl.TypeSystem;

@TypeSystem
public abstract class EasyScriptTypeSystem {
    @ImplicitCast
    public static double castIntToDouble(int value) {
        return value;
    }
}

This has pretty much the same effect as the two methods we had before, but with a single method, which gives the DSL what it needs to generate efficient code for the JIT compiler.
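One way to picture why a single conversion method helps: the generated code, not your check method, decides which source types to test for, so it can remember which ones a given node has actually seen and test only those. The plain-Java sketch below is a rough model of that idea, not the code the DSL generates; the SEEN_* bits and the cast()/specializeAndCast() methods are hypothetical, and only castIntToDouble mirrors the @ImplicitCast method above. Once a node has only ever seen ints, the fast path boils down to a single instanceof Integer test.

```java
public class ImplicitCastSketch {
    // The single hand-written conversion, mirroring the @ImplicitCast method above.
    static double castIntToDouble(int value) {
        return value;
    }

    // Hypothetical bits recording which source types this node has actually seen.
    static final int SEEN_DOUBLE = 0b01;
    static final int SEEN_INT = 0b10;

    private int seenTypes; // in a real node, this would be @CompilationFinal state

    // Fast path: only test for the source types recorded so far.
    double cast(Object value) {
        if ((seenTypes & SEEN_DOUBLE) != 0 && value instanceof Double) {
            return (Double) value;
        }
        if ((seenTypes & SEEN_INT) != 0 && value instanceof Integer) {
            return castIntToDouble((Integer) value);
        }
        return specializeAndCast(value);
    }

    // Slow path: find out which conversion applies, and remember it.
    private double specializeAndCast(Object value) {
        if (value instanceof Double) {
            seenTypes |= SEEN_DOUBLE;
            return (Double) value;
        }
        if (value instanceof Integer) {
            seenTypes |= SEEN_INT;
            return castIntToDouble((Integer) value);
        }
        throw new IllegalArgumentException("not convertible to double: " + value);
    }

    int seenTypes() {
        return seenTypes;
    }
}
```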

Using @ImplicitCast has a pretty dramatic effect on the generated code: it creates an entirely new class that extends the abstract @TypeSystem class with a bunch of static utility methods, and the code in AdditionNodeGen looks quite a bit different, as it now has to handle the possibility of an int being passed anywhere a double was previously expected.

Summary

So, this is how to use the Truffle DSL to implement Node classes with specializations. Compared to the previous article, it takes very little code to achieve the same result as the manually written version, because the DSL generates most of the boilerplate needed for managing the active specializations of a given Node. We’ve only scratched the surface of the DSL’s capabilities; we’ll see more of them later, as we add more features to our EasyScript language.

As always, all code from the article is available on GitHub.

In the next part of the series, we’ll talk about parsing, and introduce GraalVM’s polyglot API.

