from typing import Iterable
from typing import Iterator
from typing import MutableMapping
from typing import Optional
from typing import Tuple
from typing import TypeVar
from .ps_tree_node import Node
# Generic type parameters used in the PrioritySearchTree annotations:
# ``_KEY`` for the ordered tree keys, ``_PRIORITY`` for the heap priorities.
_KEY, _PRIORITY = TypeVar("_KEY"), TypeVar("_PRIORITY")
class PrioritySearchTree(MutableMapping):
    """Class that represents a priority search tree (PST).

    PrioritySearchTree is a mutable mapping that stores **keys** and corresponding **priorities**.

    * Keys are stored in a balanced binary search tree (red/black tree) that allows the following
      operations to be performed efficiently:

      * in-order traversal
      * finding min/max keys
      * finding next/previous keys

    * Priorities and keys form a max priority queue that allows the following operations to be
      performed efficiently:

      * finding the element with max priority
      * removing the element with max priority
      * updating the priority for a given key

    * It supports 3-sided queries.

    Example::

        # create new tree
        pst = PrioritySearchTree([(1,1),(2,2)])
        # add key 3 to the tree with priority 5
        pst[3] = 5
        # perform 3 sided query
        result = pst.query(0,4,2)

    Args:
        iterable (Iterable): Initial values to build the priority search tree.
            Each item in the iterable must itself be an iterable with exactly two objects.
            The first object of each item becomes a **key** in the new PST, and the second object
            the corresponding **priority**. The default value is ``None``.

    Raises:
        KeyError: if the iterable contains more than one item with the same **key**
        ValueError: if an item of the iterable does not consist of exactly two objects

    Complexity:
        `O(N*log(N))` where **N** is the number of items to be added to the new PST
    """

    # Layout used throughout this class:
    # * keys live in the leaves; an internal node's ``tree_key`` is the smallest key of its
    #   right subtree, so searches go left when ``key < tree_key`` and right otherwise;
    # * every node also carries a ``heap_key`` = (priority, key) pair, or a pair of
    #   ``Node.PLACEHOLDER_VALUE`` when the slot is empty. An item is always stored on the
    #   search path of its own key, and larger items sit above smaller ones.
    __slots__ = ["_root", "_len"]

    def _push_down(self, node: Node, heap_key: Tuple[_PRIORITY, _KEY]) -> None:
        """Store ``heap_key`` at ``node`` unconditionally.

        The item previously held by ``node`` (if any) is pushed one level down along its own
        key's search path, recursively, until it lands in a placeholder slot.
        """
        if node.heap_key[0] == Node.PLACEHOLDER_VALUE:
            node.heap_key = heap_key
            return
        if node.heap_key[1] < node.tree_key:
            self._push_down(node.left, node.heap_key)
        else:
            self._push_down(node.right, node.heap_key)
        node.heap_key = heap_key

    def _sift_down(self, node: Node, heap_key: Tuple[_PRIORITY, _KEY]) -> None:
        """Insert ``heap_key`` into the heap of the subtree rooted at ``node``.

        Walks down the item's key path until it finds either an empty (placeholder) slot or a
        node holding a smaller item; in the latter case the item takes that slot and the old
        item is pushed down.
        """
        if node.heap_key[0] == Node.PLACEHOLDER_VALUE:
            node.heap_key = heap_key
            return
        if heap_key > node.heap_key:
            self._push_down(node, heap_key)
            return
        if heap_key[1] < node.tree_key:
            self._sift_down(node.left, heap_key)
        else:
            self._sift_down(node.right, heap_key)

    def _push_up(self, node: Node) -> None:
        """Remove the item stored at ``node``.

        The larger of the two children's items is promoted into ``node`` and the process
        repeats in that child. A node whose children are both empty becomes a placeholder.
        """
        if node.heap_key[0] == Node.PLACEHOLDER_VALUE:
            return
        if node.left.heap_key[0] == Node.PLACEHOLDER_VALUE:
            node.heap_key = node.right.heap_key
            self._push_up(node.right)
            return
        if node.right.heap_key[0] == Node.PLACEHOLDER_VALUE:
            node.heap_key = node.left.heap_key
            self._push_up(node.left)
            return
        if node.left.heap_key >= node.right.heap_key:
            node.heap_key = node.left.heap_key
            self._push_up(node.left)
        else:
            node.heap_key = node.right.heap_key
            self._push_up(node.right)

    def __init__(self, iterable: Optional[Iterable[Tuple[_KEY, _PRIORITY]]] = None) -> None:
        self._root: Node = Node.NULL_NODE
        self._len: int = 0
        if iterable is None:
            return
        # Unpacking enforces the documented "exactly two objects" item shape and also accepts
        # one-shot iterators. Sorting on the key alone means priorities never need to be
        # comparable with each other.
        items = sorted(((key, priority) for key, priority in iterable), key=lambda item: item[0])
        if not items:
            # An empty generator/iterator is truthy, so this must be checked after materializing.
            return
        for current, following in zip(items, items[1:]):
            if current[0] == following[0]:
                raise KeyError(f"More than one item with key:{current[0]}")

        count = len(items)
        # Pair up just enough leading items that the bottom level holds exactly 2**k subtrees,
        # which can then be merged pairwise into a perfectly balanced tree.
        pairs = count - 2 ** (count.bit_length() - 1)
        item_iter = iter(items)
        level = []  # (subtree root, smallest key in the subtree)
        for _ in range(pairs):
            left_key, left_priority = next(item_iter)
            right_key, right_priority = next(item_iter)
            # the two leaves of a pair use the default Node color; their parent is black
            left_leaf = Node(tree_key=left_key, heap_key=(left_priority, left_key))
            right_leaf = Node(tree_key=right_key, heap_key=(right_priority, right_key))
            parent = Node(tree_key=right_key, heap_key=(right_priority, right_key), color=0)
            parent.set_left(left_leaf)
            parent.set_right(right_leaf)
            # parent starts with a copy of an item; _push_up replaces it with the larger child
            # item and empties the leaf it came from
            self._push_up(parent)
            level.append((parent, left_key))
        for key, priority in item_iter:
            leaf = Node(tree_key=key, heap_key=(priority, key), color=0)
            level.append((leaf, key))

        while len(level) > 1:
            next_level = []
            for i in range(0, len(level), 2):
                left_node, left_min = level[i]
                right_node, right_min = level[i + 1]
                parent = Node(tree_key=right_min, heap_key=left_node.heap_key, color=0)
                parent.set_left(left_node)
                parent.set_right(right_node)
                self._push_up(parent)
                next_level.append((parent, left_min))
            level = next_level
        self._root = level[0][0]
        self._len = count

    def get_with_max_priority(self) -> _KEY:
        """Returns the **key** with the largest **priority** in PST.

        Returns:
            **key** with the largest **priority**

        Raises:
            KeyError: If the PST is empty

        Complexity:
            `O(1)`
        """
        if self._root == Node.NULL_NODE:
            raise KeyError
        # the root always holds the largest (priority, key) item
        return self._root.heap_key[1]

    def popitem(self) -> Tuple[_KEY, _PRIORITY]:
        """Remove and return the (key, priority) pair with the max **priority** from the PST.

        Returns:
            Tuple: **key** and **priority** pair

        Raises:
            KeyError: If the PST is empty

        Complexity:
            `O(log(N))` where **N** is the number of items in PST
        """
        if self._root == Node.NULL_NODE:
            raise KeyError
        result = self._root.heap_key
        del self[result[1]]
        return result[1], result[0]

    def _fix_delete(self, node: Node) -> None:
        """Restore red/black properties after a black node above ``node`` was removed.

        The two top-level branches are mirror images (``node`` is a left or a right child).
        """
        while node != self._root and node.color == 0:
            if node == node.parent.left:
                s_node = node.parent.right
                if s_node.color == 1:
                    s_node.color = 0
                    node.parent.color = 1
                    self._rotate_left(node.parent)
                    s_node = node.parent.right
                if s_node.left.color == 0 and s_node.right.color == 0:
                    s_node.color = 1
                    node = node.parent
                else:
                    if s_node.right.color == 0:
                        s_node.left.color = 0
                        s_node.color = 1
                        self._rotate_right(s_node)
                        s_node = node.parent.right
                    s_node.color = node.parent.color
                    node.parent.color = 0
                    s_node.right.color = 0
                    self._rotate_left(node.parent)
                    node = self._root
            else:
                s_node = node.parent.left
                if s_node.color == 1:
                    s_node.color = 0
                    node.parent.color = 1
                    self._rotate_right(node.parent)
                    s_node = node.parent.left
                if s_node.right.color == 0 and s_node.left.color == 0:
                    s_node.color = 1
                    node = node.parent
                else:
                    if s_node.left.color == 0:
                        s_node.right.color = 0
                        s_node.color = 1
                        self._rotate_left(s_node)
                        s_node = node.parent.left
                    s_node.color = node.parent.color
                    node.parent.color = 0
                    s_node.left.color = 0
                    self._rotate_right(node.parent)
                    node = self._root
        node.color = 0

    def _transplant(self, u: Node, v: Node) -> None:
        """Put subtree ``v`` in the place of node ``u`` under ``u``'s parent (or as the root)."""
        if u.parent is None:
            self._root = v
            self._root.parent = None
        elif u == u.parent.left:
            u.parent.set_left(v)
        else:
            u.parent.set_right(v)

    def query(self, key_left: _KEY, key_right: _KEY, priority_bottom: _PRIORITY) -> list:
        """Performs a 3-sided query on PST.

        This function returns a list of items that meet the following criteria:

        1. items have **key** greater than or equal to the `key_left` argument
        2. items have **key** smaller than or equal to the `key_right` argument
        3. items have **priority** greater than or equal to the `priority_bottom` argument

        Args:
            key_left: Left bound for query (**key** is used to compare).
            key_right: Right bound for query (**key** is used to compare).
            priority_bottom: Bottom bound for query (**priority** is used to compare).

        Returns:
            List: list of **keys** that satisfy the criteria (no particular order is
            guaranteed), or an empty list if no items are found

        Complexity:
            `O(log(N)+K)` where **N** is the number of items in PST and **K** is the number of
            reported items
        """
        result = []

        def _query_node(node) -> None:
            if node == Node.NULL_NODE or node.heap_key[0] == Node.PLACEHOLDER_VALUE:
                return
            if node.heap_key[0] >= priority_bottom:
                if key_left <= node.heap_key[1] <= key_right:
                    result.append(node.heap_key[1])
            else:
                # heap order: nothing below this node can reach priority_bottom
                return
            if key_right < node.tree_key:
                _query_node(node.left)
            elif key_left >= node.tree_key:
                _query_node(node.right)
            else:
                _query_node(node.left)
                _query_node(node.right)

        _query_node(self._root)
        return result

    def sorted_query(self, key_left: _KEY, key_right: _KEY, priority_bottom: _PRIORITY, items_limit: int = 0) -> list:
        """Performs a 3-sided query on PST.

        This function returns a list of items that meet the following criteria:

        1. items have **key** greater than or equal to the `key_left` argument
        2. items have **key** smaller than or equal to the `key_right` argument
        3. items have **priority** greater than or equal to the `priority_bottom` argument

        Args:
            key_left: Left bound for query (**key** is used to compare).
            key_right: Right bound for query (**key** is used to compare).
            priority_bottom: Bottom bound for query (**priority** is used to compare).
            items_limit (int): Maximum number of items to return. Default value is ``0`` - no limit.

        Returns:
            List: list of **keys** that satisfy the criteria, sorted by descending **priority**
            (in case of a limit, the items with the largest **priority** are returned), or an
            empty list if no items are found

        Complexity:
            `O(log(N)+K*log(K))` where **N** is the number of items in PST and **K** is the number
            of returned items
        """
        if items_limit <= 0:
            items_limit = self._len

        def _sorted_query_node(node, limit) -> list:
            # Returns at most ``limit`` matching heap items from the subtree rooted at ``node``,
            # in descending (priority, key) order.
            if node == Node.NULL_NODE or node.heap_key[0] == Node.PLACEHOLDER_VALUE or limit <= 0:
                return []
            result = []
            if node.heap_key[0] >= priority_bottom:
                if key_left <= node.heap_key[1] <= key_right:
                    # this node's item dominates everything below it, so it comes first
                    result.append(node.heap_key)
                    limit -= 1
            else:
                return result
            if key_right < node.tree_key:
                result.extend(_sorted_query_node(node.left, limit))
            elif key_left >= node.tree_key:
                result.extend(_sorted_query_node(node.right, limit))
            else:
                left = _sorted_query_node(node.left, limit)
                right = _sorted_query_node(node.right, limit)
                # Merge both descending lists, taking no more than the remaining local budget.
                # Capping by the global limit instead would let a subtree return more than
                # requested and overflow the caller's result.
                i, j = 0, 0
                while limit > 0 and i < len(left) and j < len(right):
                    if left[i] >= right[j]:
                        result.append(left[i])
                        i += 1
                    else:
                        result.append(right[j])
                        j += 1
                    limit -= 1
                tail = left[i:] if i < len(left) else right[j:]
                result.extend(tail[:limit])
            return result

        return [heap_key[1] for heap_key in _sorted_query_node(self._root, items_limit)]

    def _fix_insert(self, node: Node) -> None:
        """Restore red/black properties after inserting the red leaf ``node``."""
        while node.parent.color == 1:
            if node.parent.parent.right == node.parent:
                u = node.parent.parent.left
                if u.color == 1:
                    u.color = 0
                    node.parent.color = 0
                    node.parent.parent.color = 1
                    node = node.parent.parent
                else:
                    if node.parent.left == node:
                        node = node.parent
                        self._rotate_right(node)
                    node.parent.color = 0
                    node.parent.parent.color = 1
                    self._rotate_left(node.parent.parent)
            else:
                u = node.parent.parent.right
                if u.color == 1:
                    u.color = 0
                    node.parent.color = 0
                    node.parent.parent.color = 1
                    node = node.parent.parent
                else:
                    if node.parent.right == node:
                        node = node.parent
                        self._rotate_left(node)
                    node.parent.color = 0
                    node.parent.parent.color = 1
                    self._rotate_right(node.parent.parent)
            if node == self._root:
                break
        self._root.color = 0

    def _rotate_right(self, x: Node) -> None:
        """Rotate right around ``x`` while keeping heap items valid.

        ``x``'s item moves into ``y`` (the new subtree root, whose own item is pushed down),
        then ``x`` refills its slot from its new children.
        """
        y = x.left
        self._push_down(y, x.heap_key)
        x.set_left(y.right)
        if not x.parent:
            self._root = y
            y.parent = None
        elif x == x.parent.left:
            x.parent.set_left(y)
        else:
            x.parent.set_right(y)
        y.set_right(x)
        self._push_up(x)

    def _rotate_left(self, x: Node) -> None:
        """Mirror image of :meth:`_rotate_right`."""
        y = x.right
        self._push_down(y, x.heap_key)
        x.set_right(y.left)
        if not x.parent:
            self._root = y
            y.parent = None
        elif x == x.parent.left:
            x.parent.set_left(y)
        else:
            x.parent.set_right(y)
        y.set_left(x)
        self._push_up(x)

    def __len__(self) -> int:
        """Implements the built-in function len()

        Returns:
            int: Number of items in PST.

        Complexity:
            `O(1)`
        """
        return self._len

    def clear(self) -> None:
        """Removes **all** items from PST.

        Complexity:
            `O(1)`
        """
        self._root = Node.NULL_NODE
        self._len = 0

    def _find_heap_node(self, key: _KEY) -> Node:
        """Return the node that currently stores the heap item for ``key``.

        Items are routed along their own key's search path (see ``_push_down`` and
        ``_sift_down``), so only that path is scanned; a placeholder slot ends the populated
        part of the path.

        Raises:
            KeyError: if **key** is not in the PST
        """
        node = self._root
        while node.heap_key[0] != Node.PLACEHOLDER_VALUE:
            if key == node.heap_key[1]:
                return node
            node = node.left if key < node.tree_key else node.right
        raise KeyError(f"Key not found:{key}")

    def update_priority(self, key: _KEY, priority: _PRIORITY) -> _PRIORITY:
        """Updates the priority for the given key.

        Args:
            key: **key** to update
            priority: new **priority** value

        Returns:
            old **priority** value

        Raises:
            KeyError: in case **key** does not exist in PST

        Complexity:
            `O(log(N))` where **N** is the number of items in PST
        """
        heap_node = self._find_heap_node(key)
        old_priority = heap_node.heap_key[0]
        # remove the old item, then reinsert it with the new priority from the top
        self._push_up(heap_node)
        self._sift_down(self._root, (priority, key))
        return old_priority

    def __setitem__(self, key: _KEY, priority: _PRIORITY) -> None:
        """Implements the assignment operation.

        Args:
            key: **key** to add/update
            priority: new **priority**

        Complexity:
            `O(log(N))` where **N** is the number of items in PST
        """
        if self._root == Node.NULL_NODE:
            self._root = Node(tree_key=key, heap_key=(priority, key), color=0)
            self._len = 1
            return
        prev = None
        node = self._root
        while node != Node.NULL_NODE:
            prev = node
            if key < node.tree_key:
                node = node.left
            elif key == node.tree_key:
                # existing key: only its priority changes
                self.update_priority(key, priority)
                return
            else:
                node = node.right
        # ``prev`` is a leaf. Turn it into an internal node with two fresh (empty) leaves: one
        # for the new key and one for prev's old key. Its tree_key ends up equal to the key of
        # its right leaf, keeping the "smallest key of the right subtree" routing rule.
        new_placeholder = Node(tree_key=key, heap_key=(Node.PLACEHOLDER_VALUE, Node.PLACEHOLDER_VALUE))
        prev_placeholder = Node(tree_key=prev.tree_key, heap_key=(Node.PLACEHOLDER_VALUE, Node.PLACEHOLDER_VALUE))
        if key < prev.tree_key:
            prev.set_right(prev_placeholder)
            prev.set_left(new_placeholder)
        else:
            prev.tree_key = key
            prev.set_right(new_placeholder)
            prev.set_left(prev_placeholder)
        self._sift_down(self._root, (priority, key))
        self._fix_insert(new_placeholder)
        self._len += 1

    def __delitem__(self, key: _KEY) -> None:
        """Remove **key** from PST.

        Args:
            key: **key** to remove

        Raises:
            KeyError: in case **key** does not exist in PST

        Complexity:
            `O(log(N))` where **N** is the number of items in PST
        """
        # First node on the search path whose tree_key equals ``key``: either the leaf itself
        # or an internal node whose right subtree starts with ``key``.
        tree_node = None
        node = self._root
        while node != Node.NULL_NODE:
            if key == node.tree_key:
                tree_node = node
                break
            elif key < node.tree_key:
                node = node.left
            else:
                node = node.right
        if tree_node is None:
            raise KeyError(f"Key not found:{key}")
        # descend from tree_node to the leaf holding ``key``
        leaf_node = None
        node = tree_node
        while node != Node.NULL_NODE:
            leaf_node = node
            if key < node.tree_key:
                node = node.left
            else:
                node = node.right
        if leaf_node == self._root:
            # the tree held a single key
            self._root = Node.NULL_NODE
            self._len = 0
            return
        # remove heap value: it is stored somewhere on the path from the leaf to the root
        node = leaf_node
        while node.heap_key[1] != key:
            node = node.parent
        self._push_up(node)
        # Remove the leaf together with its parent (``cut_node``); the leaf's sibling
        # (``fix_node``) takes the parent's place.
        if tree_node.left == Node.NULL_NODE:  # tree_node is the leaf itself (a left child)
            cut_node = tree_node.parent
            fix_node = tree_node.parent.right
        elif tree_node.right == leaf_node:  # the leaf is tree_node's right child
            cut_node = tree_node
            fix_node = tree_node.left
        else:  # the leaf is the leftmost leaf of tree_node's right subtree
            # the next key of that subtree becomes tree_node's routing key
            tree_node.tree_key = leaf_node.parent.tree_key
            cut_node = leaf_node.parent
            fix_node = leaf_node.parent.right
        # move cut_node's item (if any) down into fix_node's subtree before cut_node disappears
        self._push_down(cut_node, cut_node.heap_key)
        self._transplant(cut_node, fix_node)
        if cut_node.color == 0:
            # removing a black node shortened the black height of this path
            self._fix_delete(fix_node)
        self._len -= 1

    def __getitem__(self, key: _KEY) -> _PRIORITY:
        """Returns the **priority** of the given **key** in PST.

        Args:
            key: **key** to find

        Returns:
            **priority** value of the given **key**

        Raises:
            KeyError: in case **key** does not exist in PST

        Complexity:
            `O(log(N))` where **N** is the number of items in PST
        """
        return self._find_heap_node(key).heap_key[0]

    def __iter__(self) -> Iterator:
        """Create an iterator that iterates **keys** in sorted order.

        Keys are stored only in the leaves, so a left-to-right depth-first walk that reports
        leaves yields each key exactly once, in ascending order.

        Returns:
            Iterator: in order iterator
        """
        if self._root == Node.NULL_NODE:
            return
        stack = [self._root]
        while stack:
            node = stack.pop()
            if node.left == Node.NULL_NODE:
                yield node.tree_key
            else:
                # push right first so that the left subtree is visited first
                stack.append(node.right)
                stack.append(node.left)