import functools
import threading
import weakref

# Intern table: maps a sorted, deduplicated tuple of elements to the single
# InternedSetObject representing it. Values are held weakly, so an entry
# disappears once nothing else references that set object.
set_table = weakref.WeakValueDictionary()
# Guards lookups/insertions in set_table so concurrent callers cannot create
# two distinct objects for the same contents.
set_table_lock = threading.Lock()


@functools.total_ordering
class InternedSetObject:
    """Immutable set whose instances are interned by content.

    Elements are stored as a sorted, duplicate-free tuple. Instances are
    intended to be obtained via the module-level ``set()`` / ``_new_set()``
    factory, which returns the same object for equal contents; that is why
    equality and hashing below are by identity, and ordering is by ``id``
    (which lets interned sets be sorted and used as elements of other
    interned sets).
    """

    def __init__(self, elements):
        """Wrap *elements*, a tuple that must already be sorted and deduplicated."""
        # Exact type check on purpose: the tuple is also the intern-table key.
        assert type(elements) == tuple
        self.elements = elements

    def __contains__(self, element):
        # Linear scan with == rather than bisect: *element* need not be
        # orderable relative to the stored elements.
        return any(item == element for item in self.elements)

    def __iter__(self):
        return iter(self.elements)

    def _merge(self, other, keep_own, keep_common, keep_other):
        """Merge the two sorted element tuples and return the interned result.

        Walks both tuples in lockstep (the merge step of merge sort) and keeps
        elements found only in self (``keep_own``), in both (``keep_common``)
        or only in other (``keep_other``). Relies on both tuples having been
        sorted with the same ordering.
        """
        own = self.elements
        theirs = other.elements
        own_len = len(own)
        other_len = len(theirs)
        own_index = 0
        other_index = 0
        merged = []
        while own_index < own_len and other_index < other_len:
            mine = own[own_index]
            their = theirs[other_index]
            if mine == their:
                # Present in both; advancing both indices also deduplicates.
                if keep_common:
                    merged.append(mine)
                own_index += 1
                other_index += 1
            elif mine < their:
                # Ours sorts first, so other cannot contain it.
                if keep_own:
                    merged.append(mine)
                own_index += 1
            else:
                # Theirs sorts first, so we cannot contain it.
                if keep_other:
                    merged.append(their)
                other_index += 1
        # At most one side has leftovers, and they are unique to that side.
        if keep_own:
            merged.extend(own[own_index:])
        if keep_other:
            merged.extend(theirs[other_index:])
        return _new_set(tuple(merged))

    def union(self, other):
        """Return the interned set of elements in self or in other."""
        assert isinstance(other, InternedSetObject)
        return self._merge(other, keep_own=True, keep_common=True, keep_other=True)

    def intersection(self, other):
        """Return the interned set of elements in both self and other."""
        assert isinstance(other, InternedSetObject)
        return self._merge(other, keep_own=False, keep_common=True, keep_other=False)

    def difference(self, other):
        """Return the interned set of elements in self but not in other."""
        assert isinstance(other, InternedSetObject)
        return self._merge(other, keep_own=True, keep_common=False, keep_other=False)

    def symmetric_difference(self, other):
        """Return the interned set of elements in exactly one of self and other."""
        assert isinstance(other, InternedSetObject)
        return self._merge(other, keep_own=True, keep_common=False, keep_other=True)

    def __eq__(self, other):
        # Interning means equal contents share one object, so identity is
        # equality.
        return self is other

    def __lt__(self, other):
        # Arbitrary but stable (for the object's lifetime) total order, used
        # when interned sets are themselves sorted as elements.
        return id(self) < id(other)

    def __hash__(self):
        return hash(id(self))

    def __repr__(self):
        name = 'InternedSetObject'
        if __name__ != '__main__':
            name = '%s.%s' % (__name__, name)
        return '%s(%s)' % (name, repr(self.elements))

    def __str__(self):
        if self.elements:
            return '{%s}' % ', '.join(map(str, self.elements))
        return '∅'


def dedup_elements(elements):
    """Deduplicates a sorted iterable and returns it as a tuple.

    Only adjacent duplicates are removed, so *elements* must already be sorted
    (or at least have equal items grouped together).
    """
    deduplicated = []
    for element in elements:
        if deduplicated and element == deduplicated[-1]:
            continue
        deduplicated.append(element)
    return tuple(deduplicated)
def _new_set(elements):
    """Return the interned set for *elements*, creating it if needed.

    *elements* must be a tuple that is already sorted and deduplicated; it is
    used directly as the intern-table key, so its items must be hashable.
    """
    with set_table_lock:
        # Use a single .get() so we obtain a strong reference atomically.
        # An `in` check followed by `set_table[elements]` is racy on a
        # WeakValueDictionary: another thread may drop the last reference
        # between the two calls (the lock does not prevent that), making the
        # indexing raise KeyError.
        set_object = set_table.get(elements)
        if set_object is None:
            set_object = InternedSetObject(elements)
            set_table[elements] = set_object
        return set_object


def set(elements=None):
    """Returns an InternedSetObject. The same object will be returned for all equivalent sets.

    *elements* may be any iterable (or None for the empty set); its items
    must be mutually orderable (they are sorted) and hashable (they form the
    intern-table key). Note: this deliberately shadows the builtin ``set``
    within this module.
    """
    if elements is None:
        elements = ()
    else:
        elements = dedup_elements(sorted(elements))
    return _new_set(elements)