diff --git a/cogs/music/db_manager.py b/cogs/music/db_manager.py new file mode 100644 index 0000000..3191b54 --- /dev/null +++ b/cogs/music/db_manager.py @@ -0,0 +1,315 @@ +""" +Database Manager for Groovy-Zilean +Centralizes all database operations and provides a clean interface. +Makes future PostgreSQL migration much easier. +""" + +import sqlite3 +from contextlib import contextmanager +from typing import Optional, List, Tuple, Any +import config + + +class DatabaseManager: + """Manages database connections and operations""" + + def __init__(self): + self.db_path = config.get_db_path() + + @contextmanager + def get_connection(self): + """ + Context manager for database connections. + Automatically handles commit/rollback and closing. + + Usage: + with db.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(...) + """ + conn = sqlite3.connect(self.db_path) + try: + yield conn + conn.commit() + except Exception as e: + conn.rollback() + raise e + finally: + conn.close() + + def initialize_tables(self): + """Create database tables if they don't exist""" + with self.get_connection() as conn: + cursor = conn.cursor() + + # Create servers table + cursor.execute('''CREATE TABLE IF NOT EXISTS servers ( + server_id TEXT PRIMARY KEY, + is_playing INTEGER DEFAULT 0, + song_name TEXT, + song_url TEXT, + song_thumbnail TEXT, + loop_mode TEXT DEFAULT 'off', + volume INTEGER DEFAULT 100, + effect TEXT DEFAULT 'none', + song_start_time REAL DEFAULT 0, + song_duration INTEGER DEFAULT 0 + );''') + + # Set all to not playing on startup + cursor.execute("UPDATE servers SET is_playing = 0;") + + # Migrations for existing databases - add columns if missing + migrations = [ + ("loop_mode", "TEXT DEFAULT 'off'"), + ("volume", "INTEGER DEFAULT 100"), + ("effect", "TEXT DEFAULT 'none'"), + ("song_start_time", "REAL DEFAULT 0"), + ("song_duration", "INTEGER DEFAULT 0"), + ("song_thumbnail", "TEXT DEFAULT ''"), + ("song_url", "TEXT DEFAULT ''") + ] + + for col_name, col_type in migrations: + try: + cursor.execute(f"ALTER TABLE servers ADD COLUMN {col_name} {col_type};") + except sqlite3.OperationalError: + # Column already exists, skip + pass + + # Create songs/queue table + cursor.execute('''CREATE TABLE IF NOT EXISTS songs ( + server_id TEXT NOT NULL, + song_link TEXT, + queued_by TEXT, + position INTEGER NOT NULL, + title TEXT, + thumbnail TEXT, + duration INTEGER, + PRIMARY KEY (position), + FOREIGN KEY (server_id) REFERENCES servers(server_id) + );''') + + # Clear all songs on startup + cursor.execute("DELETE FROM songs;") + + # =================================== + # Server Operations + # =================================== + + def ensure_server_exists(self, server_id: str) -> None: + """Add server to database if it doesn't exist""" + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute('SELECT COUNT(*) FROM servers WHERE server_id = ?', (server_id,)) + if cursor.fetchone()[0] == 0: + cursor.execute('''INSERT INTO servers (server_id, loop_mode, volume, effect, song_thumbnail, song_url) + VALUES (?, 'off', 100, 'none', '', '')''', (server_id,)) + + def set_server_playing(self, server_id: str, playing: bool) -> None: + """Update server playing status""" + with self.get_connection() as conn: + cursor = conn.cursor() + self.ensure_server_exists(server_id) + val = 1 if playing else 0 + cursor.execute("UPDATE servers SET is_playing = ? WHERE server_id = ?", (val, server_id)) + + def is_server_playing(self, server_id: str) -> bool: + """Check if server is currently playing""" + with self.get_connection() as conn: + cursor = conn.cursor() + self.ensure_server_exists(server_id) + cursor.execute("SELECT is_playing FROM servers WHERE server_id = ?", (server_id,)) + res = cursor.fetchone() + return True if res and res[0] == 1 else False + + def set_current_song(self, server_id: str, title: str, url: str, thumbnail: str = "", duration: int = 0, start_time: float = 0) -> None: + """Update currently playing song information""" + with self.get_connection() as conn: + cursor = conn.cursor() + # Ensure duration is an integer + try: + duration = int(duration) + except: + duration = 0 + + cursor.execute(''' UPDATE servers + SET song_name = ?, song_url = ?, song_thumbnail = ?, song_start_time = ?, song_duration = ? + WHERE server_id = ?''', + (title, url, thumbnail, start_time, duration, server_id)) + + def get_current_song(self, server_id: str) -> dict: + """Get current song info""" + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute(''' SELECT song_name, song_thumbnail, song_url FROM servers WHERE server_id = ? LIMIT 1;''', (server_id,)) + result = cursor.fetchone() + + if result: + return {'title': result[0], 'thumbnail': result[1], 'url': result[2]} + return {'title': "Nothing", 'thumbnail': None, 'url': ''} + + def get_current_progress(self, server_id: str) -> Tuple[int, int, float]: + """Get playback progress (elapsed, duration, percentage)""" + import time + + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute('''SELECT song_start_time, song_duration, is_playing FROM servers WHERE server_id = ? LIMIT 1;''', (server_id,)) + result = cursor.fetchone() + + if not result or result[2] == 0: + return 0, 0, 0.0 + + start_time, duration, _ = result + + if duration is None or duration == 0: + return 0, 0, 0.0 + + elapsed = int(time.time() - start_time) + elapsed = min(elapsed, duration) + percentage = (elapsed / duration) * 100 if duration > 0 else 0 + + return elapsed, duration, percentage + + # =================================== + # Queue Operations + # =================================== + + def add_song(self, server_id: str, song_link: str, queued_by: str, title: str, thumbnail: str = "", duration: int = 0, position: Optional[int] = None) -> int: + """Add song to queue, returns position""" + with self.get_connection() as conn: + cursor = conn.cursor() + self.ensure_server_exists(server_id) + + if position is None: + # Add to end + cursor.execute("SELECT MAX(position) FROM songs WHERE server_id = ?", (server_id,)) + max_pos = cursor.fetchone()[0] + position = (max_pos + 1) if max_pos is not None else 0 + else: + # Insert at specific position (shift others down) + cursor.execute("UPDATE songs SET position = position + 1 WHERE server_id = ? AND position >= ?", + (server_id, position)) + + cursor.execute("""INSERT INTO songs VALUES (?, ?, ?, ?, ?, ?, ?)""", + (server_id, song_link, queued_by, position, title, thumbnail, duration)) + + return position + + def get_next_song(self, server_id: str) -> Optional[Tuple]: + """Get the next song in queue (doesn't remove it)""" + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute('''SELECT * FROM songs WHERE server_id = ? ORDER BY position LIMIT 1;''', (server_id,)) + return cursor.fetchone() + + def remove_song(self, server_id: str, position: int) -> None: + """Remove song at position from queue""" + with self.get_connection() as conn: + cursor = conn.cursor() + cursor.execute('''DELETE FROM songs WHERE server_id = ? AND position = ?''', (server_id, position)) + + def get_queue(self, server_id: str, limit: int = 10) -> Tuple[int, List[Tuple]]: + """Get songs in queue (returns max_position, list of songs)""" + with self.get_connection() as conn: + cursor = conn.cursor() + self.ensure_server_exists(server_id) + + cursor.execute("SELECT title, duration, queued_by FROM songs WHERE server_id = ? ORDER BY position LIMIT ?", + (server_id, limit)) + songs = cursor.fetchall() + + cursor.execute("SELECT MAX(position) FROM songs WHERE server_id = ?", (server_id,)) + max_pos = cursor.fetchone()[0] + max_pos = max_pos if max_pos is not None else -1 + + return max_pos, songs + + def clear_queue(self, server_id: str) -> None: + """Clear all songs from queue""" + with self.get_connection() as conn: + cursor = conn.cursor() + self.ensure_server_exists(server_id) + cursor.execute("DELETE FROM songs WHERE server_id = ?", (server_id,)) + + def shuffle_queue(self, server_id: str) -> bool: + """Shuffle the queue randomly, returns success""" + import random + + with self.get_connection() as conn: + cursor = conn.cursor() + self.ensure_server_exists(server_id) + + cursor.execute("SELECT position, song_link, queued_by, title, thumbnail, duration FROM songs WHERE server_id = ? ORDER BY position", + (server_id,)) + songs = cursor.fetchall() + + if len(songs) <= 1: + return False + + random.shuffle(songs) + cursor.execute("DELETE FROM songs WHERE server_id = ?", (server_id,)) + + for i, s in enumerate(songs): + cursor.execute("INSERT INTO songs VALUES (?, ?, ?, ?, ?, ?, ?)", + (server_id, s[1], s[2], i, s[3], s[4], s[5])) + + return True + + # =================================== + # Settings Operations + # =================================== + + def get_loop_mode(self, server_id: str) -> str: + """Get loop mode: 'off', 'song', or 'queue'""" + with self.get_connection() as conn: + cursor = conn.cursor() + self.ensure_server_exists(server_id) + cursor.execute("SELECT loop_mode FROM servers WHERE server_id = ?", (server_id,)) + res = cursor.fetchone() + return res[0] if res else 'off' + + def set_loop_mode(self, server_id: str, mode: str) -> None: + """Set loop mode: 'off', 'song', or 'queue'""" + with self.get_connection() as conn: + cursor = conn.cursor() + self.ensure_server_exists(server_id) + cursor.execute("UPDATE servers SET loop_mode = ? WHERE server_id = ?", (mode, server_id)) + + def get_volume(self, server_id: str) -> int: + """Get volume (0-200)""" + with self.get_connection() as conn: + cursor = conn.cursor() + self.ensure_server_exists(server_id) + cursor.execute("SELECT volume FROM servers WHERE server_id = ?", (server_id,)) + res = cursor.fetchone() + return res[0] if res else 100 + + def set_volume(self, server_id: str, volume: int) -> int: + """Set volume (0-200), returns the set volume""" + with self.get_connection() as conn: + cursor = conn.cursor() + self.ensure_server_exists(server_id) + cursor.execute("UPDATE servers SET volume = ? WHERE server_id = ?", (volume, server_id)) + return volume + + def get_effect(self, server_id: str) -> str: + """Get current audio effect""" + with self.get_connection() as conn: + cursor = conn.cursor() + self.ensure_server_exists(server_id) + cursor.execute("SELECT effect FROM servers WHERE server_id = ?", (server_id,)) + res = cursor.fetchone() + return res[0] if res else 'none' + + def set_effect(self, server_id: str, effect: str) -> None: + """Set audio effect""" + with self.get_connection() as conn: + cursor = conn.cursor() + self.ensure_server_exists(server_id) + cursor.execute("UPDATE servers SET effect = ? WHERE server_id = ?", (effect, server_id)) + + +# Global instance +db = DatabaseManager()