buranun/database.py

160 lines
4.6 KiB
Python

import enum
import random
import threading
import unicodedata
import sqlite3
from collections import namedtuple
from passlib.hash import argon2
import config
# Immutable record describing one user, built by get_user_info() from a users row
UserInfo = namedtuple('UserInfo', 'id parent status username email comment')
# ------------------------------------------------------------------
# General
# ------------------------------------------------------------------
class userstatus(enum.Enum):
    """Account status of a user.

    add_user() stores ``status.value`` (the integer) in the users table, and
    get_user_info() turns the stored integer back into a member with
    ``userstatus(...)``. This is a plain Enum, so members do not compare equal
    to raw ints; compare stored values against ``member.value``.
    """
    # These will be stored in the database, be mindful of not changing the numbers
    deleted = 0
    normal = 1
    admin = 2
# Cryptographically secure RNG (random.SystemRandom draws from os.urandom);
# add_user() uses it to pick unpredictable user IDs.
csprng = random.SystemRandom()
def connect():
    """Open and return a new sqlite3 connection to config.database_file.

    config.load() has to be called before this function is used."""
    path = config.database_file
    connection = sqlite3.connect(path)
    return connection
# ------------------------------------------------------------------
# Users
# ------------------------------------------------------------------
# Held by add_user() around its "check username/ID availability, then INSERT"
# sequence so that threads of this process do not interleave those steps.
# NOTE(review): add_user() does not commit, so the lock is released before the
# insert is committed, and it does nothing for other processes — confirm the
# users table also enforces uniqueness at the schema level.
user_modify_lock = threading.Lock()
def add_user(db, *, username, password, email, parent, status):
    """Add a user to the database.

    Returns True if the user was added successfully and False if the username
    was already in use.
    Will not commit the changes itself, so run .commit() on the database object
    yourself.

    Keyword-only parameters:
        username -- str; NFKC-normalized before the uniqueness check and storage
        password -- str; NFKC-normalized, then stored as an argon2 hash
        email    -- str
        parent   -- int ID of the parent user, or None
        status   -- a userstatus member; stored as its integer value

    The insert is intended to fail if parent does not reference an existing user
    (via PRAGMA foreign_keys; assumes the users schema declares that foreign key).
    """
    # Argument checks (note: assert statements are stripped under `python -O`)
    assert type(username) == str
    assert type(password) == str
    assert type(email) == str
    assert type(parent) == int or parent is None
    # isinstance rather than `status in userstatus`: before 3.12 `in` raises
    # TypeError for non-members, and on 3.12+ `1 in userstatus` is True, which
    # would let a raw int through and crash at `.value` below
    assert isinstance(status, userstatus)

    # Unicode normalize the username so compatibility-equivalent spellings
    # map to the same stored name
    username = unicodedata.normalize('NFKC', username)
    # First unicode normalize the password, then hash it with argon2.
    # Done before taking the lock: hashing is slow and touches no shared state.
    password = argon2.hash(unicodedata.normalize('NFKC', password))
    # Convert status into an int for storage
    status = status.value

    # We don't want any changes to the database to occur while we check if ID and username are available
    with user_modify_lock:
        cursor = db.cursor()

        # Check that the username is unique
        cursor.execute('SELECT id FROM users WHERE username = ?;', (username,))
        if cursor.fetchone() is not None:
            return False  # Username is already in use

        # Generate a user ID
        while True:
            # SQLite INTEGER is a signed 64-bit int, so IDs must lie in
            # [0, 2**63 - 1]. randrange's bound is exclusive, so 1 << 63 is the
            # right limit (2 << 63 == 2**64 could yield IDs SQLite cannot store).
            userid = csprng.randrange(1 << 63)
            # Check that the user ID is unique
            cursor.execute('SELECT id FROM users WHERE id = ?;', (userid,))
            if cursor.fetchone() is None:
                break  # It is unique

        # Fail if we insert a user with bogus parent field.
        # NOTE(review): SQLite documents this pragma as a no-op inside an open
        # transaction, so it may not take effect if db already has uncommitted
        # changes when add_user() is called.
        cursor.execute('PRAGMA foreign_keys = ON;')
        # Positional insert: assumes the column order is
        # (id, parent, status, password, username, email, comment)
        cursor.execute('INSERT INTO users VALUES (?, ?, ?, ?, ?, ?, ?);', (userid, parent, status, password, username, email, ''))
        return True
def get_userid(db, username):
    """Look up the user ID that belongs to a username.

    The username is NFKC-normalized before the lookup. Returns the ID when
    exactly one row matches, and None otherwise."""
    normalized = unicodedata.normalize('NFKC', username)
    cur = db.cursor()
    cur.execute('SELECT id FROM users WHERE username = ?;', (normalized,))
    rows = cur.fetchall()
    # Only an unambiguous single match counts as found
    if len(rows) == 1:
        return rows[0][0]
    return None
def check_password(db, userid, password):
    """Check the password for the given user ID.

    The password is NFKC-normalized and verified against the stored argon2 hash.
    Returns True if the password matches and False otherwise. It also returns
    False when no single user has that ID, or when the user's status is
    userstatus.deleted.
    """
    # Unicode normalize the password
    password = unicodedata.normalize('NFKC', password)
    # Get the password hash and status
    cursor = db.cursor()
    cursor.execute('SELECT password, status FROM users WHERE id = ?;', (userid,))
    results = cursor.fetchall()
    # If there is no user with that ID, fail
    if len(results) != 1:
        return False
    hashed, status = results[0]
    # status comes back from the database as a plain int, and userstatus is a
    # plain Enum whose members never compare equal to ints. Compare against
    # .value; the old `status == userstatus.deleted` was always False, so
    # deleted users could still log in.
    if status == userstatus.deleted.value:
        return False
    # Check the password
    return argon2.verify(password, hashed)
def get_user_info(db, userid):
    """Fetch one user's stored data as a UserInfo.

    The integer status column is converted back into a userstatus member.
    Returns None unless exactly one row has the given ID."""
    cur = db.cursor()
    cur.execute('SELECT id, parent, status, username, email, comment FROM users WHERE id = ?;', (userid,))
    rows = cur.fetchall()
    if len(rows) == 1:
        row_id, parent, raw_status, username, email, comment = rows[0]
        return UserInfo(row_id, parent, userstatus(raw_status), username, email, comment)
    # Missing (or ambiguous) user
    return None
# ------------------------------------------------------------------
# Boards
# ------------------------------------------------------------------
def list_boards(db):
    """Return a list with the name of every board currently in the database."""
    cur = db.cursor()
    cur.execute('SELECT name FROM boards;')
    # Each row is a 1-tuple such as ('foo',); unpack it to the bare name
    return [name for (name,) in cur.fetchall()]