sets/sets.py

181 lines
4.8 KiB
Python

import functools
import threading
import weakref
# Intern table: maps an element tuple to the InternedSetObject built from it.
# Values are held weakly, so an entry goes away once nothing else references
# the set object.
set_table = weakref.WeakValueDictionary()
# Held by _new_set() around its lookup-or-insert sequence on set_table.
set_table_lock = threading.Lock()
@functools.total_ordering
class InternedSetObject:
    """Set backed by a sorted, duplicate-free tuple of elements.

    Instances are normally obtained through the module's factory functions,
    which return one shared object per distinct element tuple.  Equality,
    ordering and hashing are therefore all based on object identity, not on
    the contents.
    """

    def __init__(self, elements):
        """Wrap ``elements``, which must be exactly a ``tuple``.

        The tuple is expected to be sorted and free of duplicates already;
        that is not verified here.
        """
        # Internal invariant check (tuple subclasses are rejected as well).
        assert type(elements) is tuple
        self.elements = elements

    def __contains__(self, element):
        """Return True if some stored element compares equal to ``element``."""
        # Plain equality scan: ``element`` is never ordered against the
        # stored values, only compared with ``==``.
        return any(member == element for member in self.elements)

    def __iter__(self):
        """Iterate over the elements in their stored (sorted) order."""
        return iter(self.elements)

    def _merge_sorted(self, other, keep_own, keep_shared, keep_other):
        """Walk both sorted element tuples in lockstep and pick elements.

        ``keep_own`` selects elements found only in ``self``, ``keep_shared``
        elements found in both, and ``keep_other`` elements found only in
        ``other``.  Returns the selected elements as a sorted tuple.
        """
        mine = self.elements
        theirs = other.elements
        own_index = 0
        other_index = 0
        picked = []
        while own_index < len(mine) and other_index < len(theirs):
            left = mine[own_index]
            right = theirs[other_index]
            if left == right:
                # Present on both sides; consuming both at once is what
                # prevents duplicates in the result.
                if keep_shared:
                    picked.append(left)
                own_index += 1
                other_index += 1
            elif left < right:
                # Only we have it (other has already moved past this value).
                if keep_own:
                    picked.append(left)
                own_index += 1
            else:
                # Only other has it.
                if keep_other:
                    picked.append(right)
                other_index += 1
        # At least one side is exhausted here, so at most one of these
        # slices is non-empty; whatever remains is exclusive to that side.
        if keep_own:
            picked.extend(mine[own_index:])
        if keep_other:
            picked.extend(theirs[other_index:])
        return tuple(picked)

    def union(self, other):
        """Return the interned set of elements in ``self`` or ``other``."""
        assert isinstance(other, InternedSetObject)
        return _new_set(self._merge_sorted(other, True, True, True))

    def intersection(self, other):
        """Return the interned set of elements in both ``self`` and ``other``."""
        assert isinstance(other, InternedSetObject)
        return _new_set(self._merge_sorted(other, False, True, False))

    def difference(self, other):
        """Return the interned set of elements in ``self`` but not ``other``."""
        assert isinstance(other, InternedSetObject)
        return _new_set(self._merge_sorted(other, True, False, False))

    def __eq__(self, other):
        """Identity comparison: two sets are equal only if they are one object."""
        return self is other

    def __lt__(self, other):
        """Order by ``id()``; arbitrary, but fixed while both objects are alive."""
        return id(self) < id(other)

    def __hash__(self):
        """Hash by identity, consistent with ``__eq__``."""
        return hash(id(self))

    def __repr__(self):
        """Return ``InternedSetObject(<elements tuple>)``, module-qualified
        unless the module is running as ``__main__``."""
        name = 'InternedSetObject'
        if __name__ != '__main__':
            name = '%s.%s' % (__name__, name)
        return '%s(%s)' % (name, repr(self.elements))

    def __str__(self):
        """Return the elements as ``{a, b, c}``."""
        if self.elements:
            return '{%s}' % ', '.join(map(str, self.elements))
        # NOTE(review): the empty set renders as an empty string; confirm
        # this is intended rather than a lost placeholder.
        return ''
def dedup_elements(elements):
    """Collapse runs of equal neighbours in a sorted iterable; return a tuple."""
    kept = []
    for candidate in elements:
        # Keep the item unless it equals the one most recently kept.
        if not kept or not (candidate == kept[-1]):
            kept.append(candidate)
    return tuple(kept)
def _new_set(elements):
    """Return the interned set for ``elements``, creating it if necessary.

    ``elements`` must be a tuple that is already sorted and deduplicated; it
    is used directly as the key into ``set_table``.
    """
    with set_table_lock:
        # A single .get() rather than "in" followed by indexing: the table
        # holds its values weakly, so an entry seen by the membership test
        # can be gone by the time it is indexed.  Binding the result to a
        # local keeps the object alive until it is returned.
        interned = set_table.get(elements)
        if interned is None:
            interned = InternedSetObject(elements)
            set_table[elements] = interned
        return interned
def set(elements = None):
    """Build an InternedSetObject from an iterable (or the empty set if omitted).

    Equivalent inputs always yield the very same object.
    """
    if elements is None:
        return _new_set(())
    # Canonical form: sort first, then drop adjacent duplicates.
    canonical = dedup_elements(sorted(elements))
    return _new_set(canonical)