Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pydatastructs/trees/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
BinarySearchTree,
BinaryTreeTraversal,
AVLTree,
BinaryIndexedTree
BinaryIndexedTree,
SplayTree
)
__all__.extend(binary_trees.__all__)

Expand Down
120 changes: 119 additions & 1 deletion pydatastructs/trees/binary_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
'BinaryTree',
'BinarySearchTree',
'BinaryTreeTraversal',
'BinaryIndexedTree'
'BinaryIndexedTree',
'SplayTree'
]

class BinaryTree(object):
Expand Down Expand Up @@ -1059,3 +1060,120 @@ def get_sum(self, left_index, right_index):
self.get_prefix_sum(left_index - 1)
else:
return self.get_prefix_sum(right_index)

class SplayTree(SelfBalancingBinaryTree):
"""
Represents Splay Trees.

References
==========

.. [1] https://en.wikipedia.org/wiki/Splay_tree

"""
def _zig(self, x, p):
if self.tree[p].left == x:
super(SplayTree, self)._right_rotate(p, x)
else:
super(SplayTree, self)._left_rotate(p, x)

def _zig_zig(self, x, p):
super(SplayTree, self)._right_rotate(self.tree[p].parent, p)
super(SplayTree, self)._right_rotate(p, x)

def _zig_zag(self, x, p):
super(SplayTree, self)._left_right_rotate(self.tree[p].parent, p)

def _zag_zag(self, x, p):
super(SplayTree, self)._left_rotate(self.tree[p].parent, p)
super(SplayTree, self)._left_rotate(p, x)

def _zag_zig(self, x, p):
super(SplayTree, self)._right_left_rotate(self.tree[p].parent, p)

def splay(self, x, p):
while self.tree[x].parent is not None:
if self.tree[p].parent is None:
self._zig(x, p)
elif self.tree[p].left == x and self.tree[self.tree[p].parent].left == p:
self._zig_zig(x, p)
elif self.tree[p].right == x and self.tree[self.tree[p].parent].right == p:
self._zag_zag(x, p)
elif self.tree[p].left == x and self.tree[self.tree[p].parent].right == p:
self._zag_zig(x, p)
else:
self._zig_zag(x, p)
p = self.tree[x].parent

def insert(self, key, x):
super(SelfBalancingBinaryTree, self).insert(key, x)
e, p = super(SelfBalancingBinaryTree, self).search(key, parent=True)
self.tree[self.size-1].parent = p;
self.splay(e, p)

def delete(self, x):
e, p = super(SelfBalancingBinaryTree, self).search(x, parent=True)
if e is None:
return
self.splay(e, p)
b = super(SelfBalancingBinaryTree, self).delete(x, balancing_info=True)
return True

def join(self, other):
"""
Joins two trees current and other such that all elements of
the current splay tree are smaller than the elements of the other tree.

Parameters
==========

other: SplayTree
SplayTree which needs to be joined with the self tree.

"""
maxm = self.root_idx
while self.tree[maxm].right is not None:
maxm = self.tree[maxm].right
self.splay(maxm, self.tree[maxm].parent)
traverse = BinaryTreeTraversal(other)
elements = traverse.depth_first_search(order='pre_order', node=other.root_idx)
for i in range(len(elements)):
super(SelfBalancingBinaryTree, self).insert(elements[i].key, elements[i].data)
for j in range(len(elements) - 1, -1, -1):
e, p = super(SelfBalancingBinaryTree, other).search(elements[j].key, parent=True)
other.tree[e] = None

def split(self, x):
"""
Splits current splay tree into two trees such that one tree contains nodes
with key less than or equal to x and the other tree containing
nodes with key greater than x.

Parameters
==========

x: key
Key of the element on the basis of which split is performed.

Returns
=======

other: SplayTree
SplayTree containing elements with key greater than x.

"""
e, p = super(SelfBalancingBinaryTree, self).search(x, parent=True)
if e is None:
return
self.splay(e, p)
other = SplayTree(None, None)
if self.tree[self.root_idx].right is not None:
traverse = BinaryTreeTraversal(self)
elements = traverse.depth_first_search(order='pre_order', node=self.tree[self.root_idx].right)
for i in range(len(elements)):
super(SelfBalancingBinaryTree, other).insert(elements[i].key, elements[i].data)
for j in range(len(elements) - 1, -1, -1):
e, p = super(SelfBalancingBinaryTree, self).search(elements[j].key, parent=True)
self.tree[e] = None
self.tree[self.root_idx].right = None
return other
28 changes: 27 additions & 1 deletion pydatastructs/trees/tests/test_binary_trees.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pydatastructs.trees.binary_trees import (
BinarySearchTree, BinaryTreeTraversal, AVLTree,
ArrayForTrees, BinaryIndexedTree, SelfBalancingBinaryTree)
ArrayForTrees, BinaryIndexedTree, SelfBalancingBinaryTree, SplayTree)
from pydatastructs.utils.raises_util import raises
from pydatastructs.utils.misc_util import TreeNode
from copy import deepcopy
Expand Down Expand Up @@ -348,3 +348,29 @@ def test_issue_234():
tree.insert(4.56, 4.56)
tree._left_rotate(5, 8)
assert tree.tree[tree.tree[8].parent].left == 8

def test_SplayTree():
t = SplayTree(100, 100)
t.insert(50, 50)
t.insert(200, 200)
t.insert(40, 40)
t.insert(30, 30)
t.insert(20, 20)
t.insert(55, 55)

assert str(t) == ("[(None, 100, 100, None), (None, 50, 50, None), (0, 200, 200, None), (None, 40, 40, 1), (5, 30, 30, 3), (None, 20, 20, None), (4, 55, 55, 2)]")
t.delete(40)
assert str(t) == ("[(None, 100, 100, None), '', (0, 200, 200, None), (4, 50, 50, 6), (5, 30, 30, None), (None, 20, 20, None), (None, 55, 55, 2)]")
t.delete(150)
assert str(t) == ("[(None, 100, 100, None), '', (0, 200, 200, None), (4, 50, 50, 6), (5, 30, 30, None), (None, 20, 20, None), (None, 55, 55, 2)]")
t1 = SplayTree(1000, 1000)
t1.insert(2000, 2000)

assert str(t1) == ("[(None, 1000, 1000, None), (0, 2000, 2000, None)]")

t.join(t1)
assert str(t) == ("[(None, 100, 100, None), '', (6, 200, 200, 7), (4, 50, 50, None), (5, 30, 30, None), (None, 20, 20, None), (3, 55, 55, 0), (8, 2000, 2000, None), (None, 1000, 1000, None)]")
s = t.split(200)

assert str(s) == ("[(1, 2000, 2000, None), (None, 1000, 1000, None)]")
assert str(t) == ("[(None, 100, 100, None), '', (6, 200, 200, None), (4, 50, 50, None), (5, 30, 30, None), (None, 20, 20, None), (3, 55, 55, 0), '', '']")