Snippet

Binary Search — Every Variant

The three forms of binary search you need to know: exact match, leftmost insertion point, and rightmost insertion point.

2 min read

Binary search runs in $O(\log n)$, but there are three importantly different variants. Get an off-by-one wrong and you're in trouble.

Exact Match

def binary_search(arr: list[int], target: int) -> int:
    """Return the index of target in sorted arr, or -1 if it is absent.

    arr must be sorted in ascending order. If target occurs more than once,
    the index of *some* occurrence is returned, not necessarily the first or
    the last. Use lower_bound / upper_bound when the exact position matters.

    The search keeps a closed interval [lo, hi] of candidate indices. That is
    why the loop runs while lo <= hi and both bounds step strictly past mid.
    """
    lo, hi = 0, len(arr) - 1
    while lo <= hi:
        # Python ints are arbitrary-precision, so lo + hi cannot overflow.
        # The lo + (hi - lo) // 2 spelling only matters in fixed-width
        # languages (C, C++, Java); here both forms yield the same mid.
        mid = (lo + hi) // 2
        if arr[mid] == target:
            return mid
        elif arr[mid] < target:
            lo = mid + 1
        else:
            hi = mid - 1
    return -1

Leftmost Insertion Point (lower_bound)

Find the leftmost position where target could be inserted to keep the array sorted — equivalent to C++ std::lower_bound.

def lower_bound(arr: list[int], target: int) -> int:
    """Return the first index i with arr[i] >= target (len(arr) if none).

    The search window is half-open, [left, right). right starts at len(arr)
    so that "insert at the very end" is a reachable answer, and the loop
    ends as soon as the window is empty. When arr[middle] is not below
    target, middle itself may be the answer. The window therefore shrinks to
    [left, middle), never to [left, middle - 1).
    """
    left, right = 0, len(arr)
    while right > left:
        middle = (left + right) // 2
        too_small = arr[middle] < target
        left, right = (middle + 1, right) if too_small else (left, middle)
    return left

Rightmost Insertion Point (upper_bound)

def upper_bound(arr: list[int], target: int) -> int:
    """Return the first index i with arr[i] > target (len(arr) if none).

    This is the same half-open [left, right) search as lower_bound, except
    for the comparison. Elements equal to target count as "still on the left
    side", so the result lands just past the last copy of target.
    """
    left, right = 0, len(arr)
    while right > left:
        middle = (left + right) // 2
        if arr[middle] <= target:
            # middle and everything before it are <= target: discard them.
            left = middle + 1
            continue
        # arr[middle] > target, so middle stays a candidate answer.
        right = middle
    return left

TypeScript Generic Version

binarySearch.ts
/**
 * Return the leftmost index i such that cmp(arr[i], target) >= 0. This is the
 * first position where target could be inserted while keeping arr sorted, or
 * arr.length if every element compares less than target.
 *
 * arr must be sorted consistently with cmp. cmp returns a negative number
 * when a sorts before b, a positive number when it sorts after, and zero when
 * they are equal. The default comparator uses the built-in < and > operators.
 */
function lowerBound<T>(
  arr: T[],
  target: T,
  cmp: (a: T, b: T) => number = (a, b) => (a < b ? -1 : a > b ? 1 : 0)
): number {
  let lo = 0, hi = arr.length
  while (lo < hi) {
    // (lo + hi) >> 1 converts to a signed 32-bit int and goes negative once
    // lo + hi >= 2**31. Halving the window width with an unsigned shift
    // stays correct for every valid array length (< 2**32).
    const mid = lo + ((hi - lo) >>> 1)
    if (cmp(arr[mid], target) < 0) lo = mid + 1
    else hi = mid
  }
  return lo
}
💡

Use lower_bound and upper_bound together: the number of occurrences of target in a sorted array is exactly upper_bound(arr, target) - lower_bound(arr, target).

Count occurrences in $O(\log n)$

def count(arr: list[int], target: int) -> int:
    """Return how many times target occurs in sorted arr, in O(log n).

    Every copy of target lies in the half-open slice
    [lower_bound(arr, target), upper_bound(arr, target)), so the width of
    that slice is the count. It is 0 when target is absent, because both
    bounds then coincide.
    """
    first = lower_bound(arr, target)
    past_last = upper_bound(arr, target)
    return past_last - first