Source code for priority_search_tree.ps_set

from typing import Callable
from typing import Iterable
from typing import Iterator
from typing import MutableSet
from typing import Optional
from typing import TypeVar

from .ps_tree import _KEY
from .ps_tree import _PRIORITY
from .ps_tree import PrioritySearchTree

_V = TypeVar("_V")


[docs] class PrioritySearchSet(MutableSet): """Mutable Set that maintains priority search tree properties. PrioritySearchSet can be used to store any type of objects. 2 functions should be passed to PrioritySearchSet constructor: * ``key_func`` to extract **key** for the object * ``priority_func`` to extract **priority** for the object extracted **key**, **priority** values will be used in underlying PrioritySearchTree Example:: @dataclass class Point: x: int y: int # create new set with Point.x as key and Point.y as priority pss = PrioritySearchSet(key_func=lambda p: p.x, priority_func=lambda p: p.y, iterable=[Point(1, 1), Point(2, 2)]) # add new item to set pss.add(Point(3, 3)) # 3-sided query result = pss.query(Point(1, 1), Point(3, 1), Point(2, 2)) # result = [Point(x=3, y=3), Point(x=2, y=2)] Args: key_func (Callable): Specifies a function of one argument that is used to extract a comparison **key** (for example, ``key_func=lambda p: p.x``). priority_func (Callable): Specifies a function of one argument that is used to extract a *priority* value (for example, ``priority_func=lambda p: p.y``). iterable (Iterable): Initial values to build priority search set. The default value is ``None``. Raises: KeyError: in case if iterable contains values with not unique **key** Complexity: `O(N*log(N))` where **N** is number of items to be added to new PSS """ __slots__ = ["_values", "key_func", "priority_func", "_pst"]
[docs] def __init__( self, key_func: Callable[[_V], _KEY], priority_func: Callable[[_V], _PRIORITY], iterable: Optional[Iterable[_V]] = None ) -> None: self.key_func = key_func self.priority_func = priority_func self._values = {} key_priorities = [] if iterable: for item in iterable: key = key_func(item) priority = priority_func(item) self._values[key] = item key_priorities.append((key, priority)) self._pst = PrioritySearchTree(key_priorities)
[docs] def get_with_max_priority(self) -> _KEY: """Return the item with the largest **priority** from the PSS. Returns: item with the largest **priority** Raises: KeyError: If the PSS is empty Complexity: Amortized `O(1)` """ return self._values[self._pst.get_with_max_priority()]
[docs] def pop(self) -> _V: """Remove and return the item with the largest **priority** from the PSS. Returns: item with the largest **priority** Raises: KeyError: If the PSS is empty Complexity: `O(log(N))` where **N** is number of items in PSS """ return self._values.pop(self._pst.popitem()[0])
[docs] def add(self, value: _V) -> None: """Add new item to priority search Set. Args: value: Value to insert into PSS Raises: KeyError: in case if value already exists in PSS Complexity: `O(log(N))` where **N** is number of items in PSS Note: this function is using ``key_func(value)`` to compare the items """ key = self.key_func(value) priority = self.priority_func(value) self._pst[key] = priority self._values[key] = value
[docs] def remove(self, value: _V) -> None: """Remove value from PSS. Args: value: Value to remove from PSS Raises: KeyError: in case if value not exists in PSS Complexity: `O(log(N))` where **N** is number of items in PSS Note: this function is using ``key_func(value)`` to compare the items """ key = self.key_func(value) del self._pst[key] del self._values[key]
[docs] def query(self, left: _V, right: _V, bottom: _V) -> list: """Performs 3 sided query on PSS. This function returns list of items that meet the following criteria: 1. items have **key** grater or equal to **key** of `left` argument 2. items have **key** smaller or equal to **key** of `right` argument 3. items have **priority** grater or equal to **priority** of `bottom` argument Args: left: Left bound for query (**key** is used to compare). right: Right bound for query (**key** is used to compare). bottom: Bottom bound for query (**priority** is used to compare). Returns: List: list of items that satisfy criteria, or empty list if no items found Complexity: `O(log(N)+K)` where **N** is number of items in PSS and **K** is number of reported items """ key_left = self.key_func(left) key_right = self.key_func(right) priority_bottom = self.priority_func(bottom) return [self._values[x] for x in self._pst.query(key_left, key_right, priority_bottom)]
[docs] def sorted_query(self, left: _V, right: _V, bottom: _V, items_limit: int = 0) -> list: """Performs sorted 3 sided query on PSS. This function returns list of items that meet the following criteria: 1. items have **key** grater or equal to **key** of `left` argument 2. items have **key** smaller or equal to **key** of `right` argument 3. items have **priority** grater or equal to **priority** of `bottom` argument Args: left: Left bound for query (**key** is used). right: Right bound for query (**key** is used). bottom: Bottom bound for query (**priority** is used). items_limit (int): Number of items to return. Default value is ``0`` - no limit. Returns: List: list of items that satisfy criteria and sorted by **priority** (in case of limit, items with largest **priority** will be returned), or empty list if no items found. Complexity: `O(log(N)+K*log(K))` where **N** is number of items in PSS and **K** is number of returned items """ key_left = self.key_func(left) key_right = self.key_func(right) priority_bottom = self.priority_func(bottom) return [self._values[x] for x in self._pst.sorted_query(key_left, key_right, priority_bottom, items_limit)]
[docs] def __len__(self) -> int: """Implements the built-in function len() Returns: int: Number of items in PSS. Complexity: `O(1)` """ return len(self._pst)
[docs] def __contains__(self, value: _V) -> bool: """Membership test operator Args: value: Value to test for membership Returns: bool: ``True`` if value is in ``self``, ``False`` otherwise. Complexity: `O(log(N))` where **N** is number of items in PSS Note: this function is using ``key_func(value)`` to compare the items """ return self.key_func(value) in self._values
[docs] def clear(self) -> None: """Removes **all** items from priority search set. Complexity: `O(1)` """ self._pst.clear() self._values.clear()
[docs] def discard(self, value) -> None: """Remove value from PSS if it exists. Args: value: Value to remove from PSS Complexity: `O(log(N))` where **N** is number of items in PSS Note: this function is using ``key_func(value)`` to compare the items """ key = self.key_func(value) if key in self._values: del self._pst[key] del self._values[key]
[docs] def __iter__(self) -> Iterator: """Create an iterator that iterates values in sorted by **key** order Returns: Iterator: in order iterator """ for key in self._pst: yield self._values[key]