Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pydatastructs/trees/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
BinarySearchTree,
BinaryTreeTraversal,
AVLTree,
BinaryIndexedTree
BinaryIndexedTree,
SplayTree
)
__all__.extend(binary_trees.__all__)

Expand Down
137 changes: 136 additions & 1 deletion pydatastructs/trees/binary_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
OneDimensionalArray, DynamicOneDimensionalArray)
from pydatastructs.linear_data_structures.arrays import ArrayForTrees
from collections import deque as Queue
from copy import deepcopy

__all__ = [
'AVLTree',
'BinaryTree',
'BinarySearchTree',
'BinaryTreeTraversal',
'BinaryIndexedTree'
'BinaryIndexedTree',
'SplayTree'
]

class BinaryTree(object):
Expand Down Expand Up @@ -754,6 +756,139 @@ def delete(self, key, **kwargs):
self._balance_deletion(a, key)
return True

class SplayTree(SelfBalancingBinaryTree):
"""
Represents Splay Trees.

References
==========

.. [1] https://en.wikipedia.org/wiki/Splay_tree

"""
def _zig(self, x, p):
if self.tree[p].left == x:
super(SplayTree, self)._right_rotate(p, x)
else:
super(SplayTree, self)._left_rotate(p, x)

def _zig_zig(self, x, p):
super(SplayTree, self)._right_rotate(self.tree[p].parent, p)
super(SplayTree, self)._right_rotate(p, x)

def _zig_zag(self, p):
super(SplayTree, self)._left_right_rotate(self.tree[p].parent, p)

def _zag_zag(self, x, p):
super(SplayTree, self)._left_rotate(self.tree[p].parent, p)
super(SplayTree, self)._left_rotate(p, x)

def _zag_zig(self, p):
super(SplayTree, self)._right_left_rotate(self.tree[p].parent, p)

def splay(self, x, p):
while self.tree[x].parent is not None:
if self.tree[p].parent is None:
self._zig(x, p)
elif self.tree[p].left == x and \
self.tree[self.tree[p].parent].left == p:
self._zig_zig(x, p)
elif self.tree[p].right == x and \
self.tree[self.tree[p].parent].right == p:
self._zag_zag(x, p)
elif self.tree[p].left == x and \
self.tree[self.tree[p].parent].right == p:
self._zag_zig(p)
else:
self._zig_zag(p)
p = self.tree[x].parent

def insert(self, key, x):
super(SelfBalancingBinaryTree, self).insert(key, x)
e, p = super(SelfBalancingBinaryTree, self).search(key, parent=True)
self.tree[self.size-1].parent = p
self.splay(e, p)

def delete(self, x):
e, p = super(SelfBalancingBinaryTree, self).search(x, parent=True)
if e is None:
return
self.splay(e, p)
status = super(SelfBalancingBinaryTree, self).delete(x)
return status

def join(self, other):
"""
Joins two trees current and other such that all elements of
the current splay tree are smaller than the elements of the other tree.

Parameters
==========

other: SplayTree
SplayTree which needs to be joined with the self tree.

"""
maxm = self.root_idx
while self.tree[maxm].right is not None:
maxm = self.tree[maxm].right
minm = other.root_idx
while other.tree[minm].left is not None:
minm = other.tree[minm].left
if not self.comparator(self.tree[maxm].key,
other.tree[minm].key):
raise ValueError("Elements of %s aren't less "
"than that of %s"%(self, other))
self.splay(maxm, self.tree[maxm].parent)
idx_update = self.tree._size
for node in other.tree:
if node is not None:
node_copy = TreeNode(node.key, node.data)
if node.left is not None:
node_copy.left = node.left + idx_update
if node.right is not None:
node_copy.right = node.right + idx_update
self.tree.append(node_copy)
else:
self.tree.append(node)
self.tree[self.root_idx].right = \
other.root_idx + idx_update

def split(self, x):
"""
Splits current splay tree into two trees such that one tree contains nodes
with key less than or equal to x and the other tree containing
nodes with key greater than x.

Parameters
==========

x: key
Key of the element on the basis of which split is performed.

Returns
=======

other: SplayTree
SplayTree containing elements with key greater than x.

"""
e, p = super(SelfBalancingBinaryTree, self).search(x, parent=True)
if e is None:
return
self.splay(e, p)
other = SplayTree(None, None)
if self.tree[self.root_idx].right is not None:
traverse = BinaryTreeTraversal(self)
elements = traverse.depth_first_search(order='pre_order', node=self.tree[self.root_idx].right)
for i in range(len(elements)):
super(SelfBalancingBinaryTree, other).insert(elements[i].key, elements[i].data)
for j in range(len(elements) - 1, -1, -1):
e, p = super(SelfBalancingBinaryTree, self).search(elements[j].key, parent=True)
self.tree[e] = None
self.tree[self.root_idx].right = None
return other

class BinaryTreeTraversal(object):
"""
Represents the traversals possible in
Expand Down
36 changes: 35 additions & 1 deletion pydatastructs/trees/tests/test_binary_trees.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pydatastructs.trees.binary_trees import (
BinarySearchTree, BinaryTreeTraversal, AVLTree,
ArrayForTrees, BinaryIndexedTree, SelfBalancingBinaryTree)
ArrayForTrees, BinaryIndexedTree, SelfBalancingBinaryTree, SplayTree)
from pydatastructs.utils.raises_util import raises
from pydatastructs.utils.misc_util import TreeNode
from copy import deepcopy
Expand Down Expand Up @@ -348,3 +348,37 @@ def test_issue_234():
tree.insert(4.56, 4.56)
tree._left_rotate(5, 8)
assert tree.tree[tree.tree[8].parent].left == 8

def test_SplayTree():
t = SplayTree(100, 100)
t.insert(50, 50)
t.insert(200, 200)
t.insert(40, 40)
t.insert(30, 30)
t.insert(20, 20)
t.insert(55, 55)

assert str(t) == ("[(None, 100, 100, None), (None, 50, 50, None), "
"(0, 200, 200, None), (None, 40, 40, 1), (5, 30, 30, 3), "
"(None, 20, 20, None), (4, 55, 55, 2)]")
t.delete(40)
assert str(t) == ("[(None, 100, 100, None), '', (0, 200, 200, None), "
"(4, 50, 50, 6), (5, 30, 30, None), (None, 20, 20, None), "
"(None, 55, 55, 2)]")
t.delete(150)
assert str(t) == ("[(None, 100, 100, None), '', (0, 200, 200, None), (4, 50, 50, 6), "
"(5, 30, 30, None), (None, 20, 20, None), (None, 55, 55, 2)]")

t1 = SplayTree(1000, 1000)
t1.insert(2000, 2000)
assert str(t1) == ("[(None, 1000, 1000, None), (0, 2000, 2000, None)]")

t.join(t1)
assert str(t) == ("[(None, 100, 100, None), '', (6, 200, 200, 8), (4, 50, 50, None), "
"(5, 30, 30, None), (None, 20, 20, None), (3, 55, 55, 0), (None, 1000, 1000, None), "
"(7, 2000, 2000, None), '']")

s = t.split(200)
assert str(s) == ("[(1, 2000, 2000, None), (None, 1000, 1000, None)]")
assert str(t) == ("[(None, 100, 100, None), '', (6, 200, 200, None), (4, 50, 50, None), "
"(5, 30, 30, None), (None, 20, 20, None), (3, 55, 55, 0), '', '', '']")