Inspired by Kartik Kukreja's implementation, but storing all data in a tree-node class rather than in an array. I find it much easier to follow this way.
/**
 * A single node of a segment tree.
 *
 * <p>Each node covers the inclusive index range {@code [origLowIndex, origHighIndex]} of the
 * original array. It stores the combined value of that range in {@code aggregateValue}.
 */
class SegmentTreeNode {
    /** Combined value (e.g. the sum) of all original elements in this node's range. */
    public int aggregateValue;

    /** Child covering the lower part of this node's range; {@code null} if there is none. */
    public SegmentTreeNode left;
    /** Child covering the upper part of this node's range; {@code null} if there is none. */
    public SegmentTreeNode right;

    /** Inclusive lower bound of the covered range, as an index into the original array. */
    public int origLowIndex;
    /** Inclusive upper bound of the covered range, as an index into the original array. */
    public int origHighIndex;
}
/**
 * Static operations on a sum segment tree built from {@link SegmentTreeNode}s: build, range query
 * and point update. {@link #main} runs a small demonstration.
 *
 * <p>Values are combined with {@code int} addition, so very large sums wrap around silently.
 */
class SegmentTreeTest {

    /**
     * Combines the aggregate values of two adjacent ranges.
     *
     * <p>This tree computes sums. Replacing this with min, max, or any other associative function
     * changes the kind of segment tree; this is the only place the operation is defined.
     */
    private static int combine(int leftValue, int rightValue) {
        return leftValue + rightValue;
    }

    /**
     * Recursively builds a segment tree over {@code a[low..high]} (both ends inclusive).
     *
     * @param a the source array; values are copied into the nodes
     * @param low first array index covered by the returned node
     * @param high last array index covered by the returned node
     * @return the root node covering {@code [low, high]}
     * @throws NullPointerException if {@code a} is null
     * @throws IllegalArgumentException if {@code low > high} or the range is not inside {@code a}
     *     (this includes an empty array)
     */
    public static SegmentTreeNode buildSegmentTree(int[] a, int low, int high) {
        java.util.Objects.requireNonNull(a, "a");
        if (low < 0 || high >= a.length || low > high) {
            throw new IllegalArgumentException(
                    "invalid range [" + low + ", " + high + "] for array of length " + a.length);
        }
        SegmentTreeNode n = new SegmentTreeNode();
        n.origLowIndex = low;
        n.origHighIndex = high;
        if (low == high) {
            n.aggregateValue = a[low];
            return n;
        }
        // Overflow-safe midpoint. When the range length is odd, the left child gets the extra
        // element. Recursive calls always pass a valid, non-empty sub-range.
        int mid = (high - low) / 2 + low;
        n.left = buildSegmentTree(a, low, mid);
        n.right = buildSegmentTree(a, mid + 1, high);
        n.aggregateValue = combine(n.left.aggregateValue, n.right.aggregateValue);
        return n;
    }

    /**
     * Returns a node whose {@code aggregateValue} is the combined value of the original elements
     * in {@code [low, high]} (inclusive).
     *
     * <p>If the range matches an existing node exactly, that live tree node is returned, so treat
     * the result as read-only. Otherwise a new, detached node with no children is returned.
     *
     * @param n root of the (sub)tree to query
     * @param low first index of the query range
     * @param high last index of the query range
     * @return a node holding the aggregate for {@code [low, high]}
     * @throws NullPointerException if {@code n} is null
     * @throws IllegalArgumentException if {@code low > high} or the range is not inside the range
     *     covered by {@code n}
     */
    public static SegmentTreeNode getAggregateValue(SegmentTreeNode n, int low, int high) {
        java.util.Objects.requireNonNull(n, "n");
        if (low > high || low < n.origLowIndex || high > n.origHighIndex) {
            throw new IllegalArgumentException("query range [" + low + ", " + high
                    + "] is not inside node range [" + n.origLowIndex + ", " + n.origHighIndex + "]");
        }
        if (n.origLowIndex == low && n.origHighIndex == high) {
            return n;
        }
        // n is not an exact match, so it cannot be a leaf: its range strictly contains the query,
        // and both children exist.
        SegmentTreeNode leftChild = n.left;
        SegmentTreeNode rightChild = n.right;
        // The interval is fully contained within the left node.
        if (low >= leftChild.origLowIndex && high <= leftChild.origHighIndex) {
            return getAggregateValue(leftChild, low, high);
        }
        // The interval is fully contained within the right node.
        if (low >= rightChild.origLowIndex && high <= rightChild.origHighIndex) {
            return getAggregateValue(rightChild, low, high);
        }
        // The interval straddles the midpoint: split it into a left and a right query.
        SegmentTreeNode leftResult = getAggregateValue(leftChild, low, leftChild.origHighIndex);
        SegmentTreeNode rightResult = getAggregateValue(rightChild, rightChild.origLowIndex, high);
        SegmentTreeNode result = new SegmentTreeNode();
        result.origLowIndex = low;
        result.origHighIndex = high;
        result.aggregateValue = combine(leftResult.aggregateValue, rightResult.aggregateValue);
        return result;
    }

    /**
     * Replaces the original element at {@code index} with {@code val}. It then recomputes the
     * aggregates of every node on the path back up to {@code n}.
     *
     * @param n root of the (sub)tree containing {@code index}
     * @param index array index to overwrite
     * @param val the new value (it replaces the old value; it is not added to it)
     * @throws NullPointerException if {@code n} is null
     * @throws IllegalArgumentException if {@code index} is outside the range covered by {@code n}
     */
    public static void update(SegmentTreeNode n, int index, int val) {
        java.util.Objects.requireNonNull(n, "n");
        if (index < n.origLowIndex || index > n.origHighIndex) {
            throw new IllegalArgumentException("index " + index + " is not inside node range ["
                    + n.origLowIndex + ", " + n.origHighIndex + "]");
        }
        if (n.origLowIndex == index && n.origHighIndex == index) {
            n.aggregateValue = val;
            return;
        }
        if (n.left.origLowIndex <= index && index <= n.left.origHighIndex) {
            update(n.left, index, val);
        } else {
            update(n.right, index, val);
        }
        n.aggregateValue = combine(n.left.aggregateValue, n.right.aggregateValue);
    }

    /** Small demonstration: builds a tree, runs range queries, and updates one element. */
    public static void main(String[] args) {
        //                   0  1  2  3  4  5  6  7  8   9  10
        int[] a = new int[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11 };
        SegmentTreeNode r = buildSegmentTree(a, 0, a.length - 1);
        System.out.println(r.aggregateValue);                         // 66
        System.out.println(getAggregateValue(r, 1, 3).aggregateValue); // 9
        System.out.println(getAggregateValue(r, 4, 7).aggregateValue); // 26
        update(r, 2, 10);
        System.out.println(r.aggregateValue);                         // 73
        System.out.println(getAggregateValue(r, 4, 7).aggregateValue); // 26
        System.out.println(getAggregateValue(r, 1, 3).aggregateValue); // 16
    }
}