Friday, March 25, 2016

Segment tree implementation in Java

Inspired by Kartik Kukreja's implementation, but storing all data in a tree node class, rather than an array. I find it a lot easier to comprehend this way.
class SegmentTreeNode {
    public int aggregateValue;
    public SegmentTreeNode left;
    public SegmentTreeNode right;
    public int origLowIndex;
    public int origHighIndex;
}

class SegmentTreeTest {

    public static SegmentTreeNode buildSegmentTree(int[] a, int low, int high) {
        SegmentTreeNode n = new SegmentTreeNode();
        n.origLowIndex = low;
        n.origHighIndex = high;

        if(low == high) {
            n.aggregateValue = a[low];
            return n;
        }

        int mid = (high - low)/2 + low;

        n.left = buildSegmentTree(a, low, mid);
        n.right = buildSegmentTree(a, mid+1, high);

        // This segment tree is for summation. Could also be min, max, or any other associative func.
        n.aggregateValue = n.left.aggregateValue + n.right.aggregateValue;

        return n;
    }

    public static SegmentTreeNode getAggregateValue(SegmentTreeNode n, int low, int high) {
        if(n.origLowIndex == low && n.origHighIndex == high) {
            return n;
        }

        // The interval is fully contained within the left node
        if(low >= n.left.origLowIndex && high <=  n.left.origHighIndex) {
            return getAggregateValue(n.left, low, high);
        }

        // The interval is fully contained within the right node
        if(low >= n.right.origLowIndex && high <=  n.right.origHighIndex) {
            return getAggregateValue(n.right, low, high);
        }

        // Split into queries on the left subtree and the right subtree
        SegmentTreeNode leftResult = getAggregateValue(n.left, low, n.left.origHighIndex);
        SegmentTreeNode rightResult = getAggregateValue(n.right, n.right.origLowIndex, high);
        SegmentTreeNode result = new SegmentTreeNode();
        result.origLowIndex = low;
        result.origHighIndex  = high;

        // This segment tree is for summation. Could also be min, max, or any other associative func.
        result.aggregateValue = leftResult.aggregateValue + rightResult.aggregateValue;

        return result;
    }

    public static void update(SegmentTreeNode n, int index, int val) {
        if(n.origLowIndex == index && n.origHighIndex == index) {
            n.aggregateValue = val;
            return;
        }

        if(n.left.origLowIndex <= index && index <= n.left.origHighIndex) {
            update(n.left, index, val);
        } else {
            update(n.right, index, val);
        }

        // This segment tree is for summation. Could also be min, max, or any other associative func.
        n.aggregateValue = n.left.aggregateValue + n.right.aggregateValue;
    }

    public static void main(String[] args) {
        //                    0  1  2  3  4  5  6  7  8   9  10
        int[] a = new int[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11 };
        SegmentTreeNode r = buildSegmentTree(a, 0, a.length-1);
        System.out.println(r.aggregateValue);

        System.out.println(getAggregateValue(r, 1, 3).aggregateValue);
        System.out.println(getAggregateValue(r, 4, 7).aggregateValue);

        update(r, 2, 10);
        System.out.println(r.aggregateValue);

        System.out.println(getAggregateValue(r, 4, 7).aggregateValue);
        System.out.println(getAggregateValue(r, 1, 3).aggregateValue);
    }
}