from __future__ import annotations

import aiosqlite
from datetime import datetime, timezone


def utcnow_iso() -> str:
    """Return the current time in UTC as an ISO-8601 string.

    The value comes from a timezone-aware ``datetime``, so the string
    carries an explicit ``+00:00`` offset.
    """
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()


def parse_iso(value: str) -> datetime:
    """Parse an ISO-8601 string into a ``datetime``.

    Parsing is delegated to ``datetime.fromisoformat``; whether the result
    is timezone-aware depends on whether *value* contains a UTC offset.
    """
    parsed = datetime.fromisoformat(value)
    return parsed


class Database:
    def __init__(self, path: str) -> None:
        self.path = path
        self.conn: aiosqlite.Connection | None = None

    async def connect(self) -> None:
        self.conn = await aiosqlite.connect(self.path)
        self.conn.row_factory = aiosqlite.Row
        await self.conn.execute("PRAGMA journal_mode=WAL")
        await self.conn.execute("PRAGMA foreign_keys=ON")
        await self.init_db()

    async def close(self) -> None:
        if self.conn:
            await self.conn.close()

    async def init_db(self) -> None:
        assert self.conn
        await self.conn.executescript(
            """
            CREATE TABLE IF NOT EXISTS sessions (
              id INTEGER PRIMARY KEY,
              chat_id INTEGER NOT NULL,
              user_id INTEGER NOT NULL,
              username TEXT,
              display_name TEXT,
              task_name TEXT,
              start_ts TEXT NOT NULL,
              end_ts TEXT NULL,
              duration_sec INTEGER NULL,
              notified_60m INTEGER DEFAULT 0,
              created_at TEXT NOT NULL
            );
            CREATE TABLE IF NOT EXISTS daily_notifications (
              chat_id INTEGER NOT NULL,
              user_id INTEGER NOT NULL,
              day TEXT NOT NULL,
              notified_60 INTEGER DEFAULT 0,
              notified_110 INTEGER DEFAULT 0,
              notified_120 INTEGER DEFAULT 0,
              PRIMARY KEY (chat_id, user_id, day)
            );
            CREATE INDEX IF NOT EXISTS idx_sessions_chat_user_start
              ON sessions(chat_id, user_id, start_ts);
            CREATE INDEX IF NOT EXISTS idx_sessions_chat_end
              ON sessions(chat_id, end_ts);
            CREATE INDEX IF NOT EXISTS idx_sessions_user_start
              ON sessions(user_id, start_ts);
            """
        )
        try:
            await self.conn.execute("ALTER TABLE sessions ADD COLUMN task_name TEXT")
            await self.conn.commit()
        except Exception:
            pass
        try:
            await self.conn.execute("ALTER TABLE sessions ADD COLUMN notified_hour INTEGER DEFAULT 0")
            await self.conn.commit()
        except Exception:
            pass
        await self.conn.commit()

    async def get_active_session(self, chat_id: int, user_id: int):
        assert self.conn
        cursor = await self.conn.execute(
            """
            SELECT * FROM sessions
            WHERE chat_id = ? AND user_id = ? AND end_ts IS NULL
            ORDER BY start_ts DESC
            LIMIT 1
            """,
            (chat_id, user_id),
        )
        row = await cursor.fetchone()
        await cursor.close()
        return row

    async def get_active_session_by_id(self, session_id: int):
        assert self.conn
        cursor = await self.conn.execute(
            """
            SELECT * FROM sessions
            WHERE id = ? AND end_ts IS NULL
            """,
            (session_id,),
        )
        row = await cursor.fetchone()
        await cursor.close()
        return row

    async def get_session_by_id(self, session_id: int):
        assert self.conn
        cursor = await self.conn.execute(
            """
            SELECT * FROM sessions
            WHERE id = ?
            """,
            (session_id,),
        )
        row = await cursor.fetchone()
        await cursor.close()
        return row

    async def get_active_session_any(self, user_id: int):
        assert self.conn
        cursor = await self.conn.execute(
            """
            SELECT * FROM sessions
            WHERE user_id = ? AND end_ts IS NULL
            ORDER BY start_ts DESC
            LIMIT 1
            """,
            (user_id,),
        )
        row = await cursor.fetchone()
        await cursor.close()
        return row

    async def create_session(
        self,
        chat_id: int,
        user_id: int,
        username: str | None,
        display_name: str,
        start_ts: str,
        task_name: str | None = None,
    ) -> int:
        assert self.conn
        created_at = utcnow_iso()
        cursor = await self.conn.execute(
            """
            INSERT INTO sessions
              (chat_id, user_id, username, display_name, task_name, start_ts, end_ts, duration_sec, notified_60m, created_at)
            VALUES (?, ?, ?, ?, ?, ?, NULL, NULL, 0, ?)
            """,
            (chat_id, user_id, username, display_name, task_name, start_ts, created_at),
        )
        await self.conn.commit()
        return cursor.lastrowid

    async def end_session(self, session_id: int, end_ts: str, duration_sec: int) -> None:
        assert self.conn
        await self.conn.execute(
            """
            UPDATE sessions
            SET end_ts = ?, duration_sec = ?
            WHERE id = ?
            """,
            (end_ts, duration_sec, session_id),
        )
        await self.conn.commit()

    async def set_session_notified_hour(self, session_id: int, hour: int) -> None:
        assert self.conn
        await self.conn.execute(
            """
            UPDATE sessions
            SET notified_hour = ?
            WHERE id = ?
            """,
            (hour, session_id),
        )
        await self.conn.commit()

    async def mark_notified(self, session_id: int) -> None:
        assert self.conn
        await self.conn.execute(
            """
            UPDATE sessions
            SET notified_60m = 1
            WHERE id = ?
            """,
            (session_id,),
        )
        await self.conn.commit()

    async def get_active_sessions(self):
        assert self.conn
        cursor = await self.conn.execute(
            """
            SELECT * FROM sessions
            WHERE end_ts IS NULL
            """
        )
        rows = await cursor.fetchall()
        await cursor.close()
        return rows

    async def list_active_sessions_by_chat(self, chat_id: int):
        assert self.conn
        cursor = await self.conn.execute(
            """
            SELECT * FROM sessions
            WHERE chat_id = ? AND end_ts IS NULL
            ORDER BY start_ts ASC
            """,
            (chat_id,),
        )
        rows = await cursor.fetchall()
        await cursor.close()
        return rows

    async def list_user_sessions_in_range(self, chat_id: int, user_id: int, start_ts: str, end_ts: str):
        assert self.conn
        cursor = await self.conn.execute(
            """
            SELECT * FROM sessions
            WHERE chat_id = ? AND user_id = ?
              AND start_ts <= ?
              AND (end_ts IS NULL OR end_ts >= ?)
            ORDER BY start_ts DESC
            """,
            (chat_id, user_id, end_ts, start_ts),
        )
        rows = await cursor.fetchall()
        await cursor.close()
        return rows

    async def list_user_sessions_in_range_all_chats(self, user_id: int, start_ts: str, end_ts: str):
        assert self.conn
        cursor = await self.conn.execute(
            """
            SELECT * FROM sessions
            WHERE user_id = ?
              AND start_ts <= ?
              AND (end_ts IS NULL OR end_ts >= ?)
            ORDER BY start_ts DESC
            """,
            (user_id, end_ts, start_ts),
        )
        rows = await cursor.fetchall()
        await cursor.close()
        return rows

    async def get_daily_notification(self, chat_id: int, user_id: int, day: str):
        assert self.conn
        cursor = await self.conn.execute(
            """
            SELECT * FROM daily_notifications
            WHERE chat_id = ? AND user_id = ? AND day = ?
            """,
            (chat_id, user_id, day),
        )
        row = await cursor.fetchone()
        await cursor.close()
        return row

    async def ensure_daily_notification_row(self, chat_id: int, user_id: int, day: str) -> None:
        assert self.conn
        await self.conn.execute(
            """
            INSERT OR IGNORE INTO daily_notifications (chat_id, user_id, day)
            VALUES (?, ?, ?)
            """,
            (chat_id, user_id, day),
        )
        await self.conn.commit()

    async def set_daily_notified(self, chat_id: int, user_id: int, day: str, level: int) -> None:
        assert self.conn
        await self.ensure_daily_notification_row(chat_id, user_id, day)
        if level == 60:
            col = "notified_60"
        elif level == 110:
            col = "notified_110"
        elif level == 120:
            col = "notified_120"
        else:
            return
        await self.conn.execute(
            f"UPDATE daily_notifications SET {col} = 1 WHERE chat_id = ? AND user_id = ? AND day = ?",
            (chat_id, user_id, day),
        )
        await self.conn.commit()

    async def get_user_stats(
        self,
        chat_id: int | None,
        user_id: int,
        start_ts: str,
        end_ts: str,
    ):
        assert self.conn
        if chat_id is None:
            query = """
                SELECT
                  COUNT(*) AS sessions_count,
                  COALESCE(SUM(duration_sec), 0) AS total_sec
                FROM sessions
                WHERE user_id = ?
                  AND end_ts IS NOT NULL
                  AND end_ts >= ? AND end_ts <= ?
            """
            params = (user_id, start_ts, end_ts)
        else:
            query = """
                SELECT
                  COUNT(*) AS sessions_count,
                  COALESCE(SUM(duration_sec), 0) AS total_sec
                FROM sessions
                WHERE chat_id = ? AND user_id = ?
                  AND end_ts IS NOT NULL
                  AND end_ts >= ? AND end_ts <= ?
            """
            params = (chat_id, user_id, start_ts, end_ts)

        cursor = await self.conn.execute(query, params)
        row = await cursor.fetchone()
        await cursor.close()
        return row

    async def get_user_stats_no_task(
        self,
        chat_id: int,
        user_id: int,
        start_ts: str,
        end_ts: str,
    ):
        assert self.conn
        cursor = await self.conn.execute(
            """
            SELECT
              COUNT(*) AS sessions_count,
              COALESCE(SUM(duration_sec), 0) AS total_sec
            FROM sessions
            WHERE chat_id = ? AND user_id = ?
              AND end_ts IS NOT NULL
              AND (task_name IS NULL OR task_name = '')
              AND end_ts >= ? AND end_ts <= ?
            """,
            (chat_id, user_id, start_ts, end_ts),
        )
        row = await cursor.fetchone()
        await cursor.close()
        return row

    async def get_group_totals_for_user(
        self,
        user_id: int,
        start_ts: str,
        end_ts: str,
        limit: int = 3,
    ):
        assert self.conn
        cursor = await self.conn.execute(
            """
            SELECT chat_id, COALESCE(SUM(duration_sec), 0) AS total_sec
            FROM sessions
            WHERE user_id = ?
              AND end_ts IS NOT NULL
              AND end_ts >= ? AND end_ts <= ?
            GROUP BY chat_id
            ORDER BY total_sec DESC
            LIMIT ?
            """,
            (user_id, start_ts, end_ts, limit),
        )
        rows = await cursor.fetchall()
        await cursor.close()
        return rows

    async def get_leaderboard(
        self,
        chat_id: int,
        start_ts: str,
        end_ts: str,
        limit: int = 10,
    ):
        assert self.conn
        cursor = await self.conn.execute(
            """
            SELECT user_id, COALESCE(MAX(display_name), '') AS display_name,
                   COALESCE(SUM(duration_sec), 0) AS total_sec
            FROM sessions
            WHERE chat_id = ?
              AND end_ts IS NOT NULL
              AND end_ts >= ? AND end_ts <= ?
            GROUP BY user_id
            ORDER BY total_sec DESC
            LIMIT ?
            """,
            (chat_id, start_ts, end_ts, limit),
        )
        rows = await cursor.fetchall()
        await cursor.close()
        return rows

    async def get_group_user_stats(self, chat_id: int, start_ts: str, end_ts: str):
        assert self.conn
        cursor = await self.conn.execute(
            """
            SELECT user_id,
                   COALESCE(MAX(display_name), '') AS display_name,
                   COUNT(*) AS sessions_count,
                   COALESCE(SUM(duration_sec), 0) AS total_sec
            FROM sessions
            WHERE chat_id = ?
              AND end_ts IS NOT NULL
              AND end_ts >= ? AND end_ts <= ?
            GROUP BY user_id
            ORDER BY total_sec DESC
            """,
            (chat_id, start_ts, end_ts),
        )
        rows = await cursor.fetchall()
        await cursor.close()
        return rows

    async def get_group_user_stats_all_chats(self, chat_id: int, start_ts: str, end_ts: str):
        assert self.conn
        cursor = await self.conn.execute(
            """
            WITH users AS (
              SELECT DISTINCT user_id
              FROM sessions
              WHERE chat_id = ?
                AND end_ts IS NOT NULL
                AND end_ts >= ? AND end_ts <= ?
            )
            SELECT s.user_id,
                   COALESCE(MAX(s.display_name), '') AS display_name,
                   COUNT(*) AS sessions_count,
                   COALESCE(SUM(s.duration_sec), 0) AS total_sec
            FROM sessions s
            JOIN users u ON u.user_id = s.user_id
            WHERE s.end_ts IS NOT NULL
              AND s.end_ts >= ? AND s.end_ts <= ?
            GROUP BY s.user_id
            ORDER BY total_sec DESC
            """,
            (chat_id, start_ts, end_ts, start_ts, end_ts),
        )
        rows = await cursor.fetchall()
        await cursor.close()
        return rows

    async def get_leaderboard_all_chats(self, chat_id: int, start_ts: str, end_ts: str, limit: int = 10):
        assert self.conn
        cursor = await self.conn.execute(
            """
            WITH users AS (
              SELECT DISTINCT user_id
              FROM sessions
              WHERE chat_id = ?
                AND end_ts IS NOT NULL
                AND end_ts >= ? AND end_ts <= ?
                AND (task_name IS NULL OR task_name = '')
            )
            SELECT s.user_id,
                   COALESCE(MAX(s.display_name), '') AS display_name,
                   COALESCE(SUM(s.duration_sec), 0) AS total_sec
            FROM sessions s
            JOIN users u ON u.user_id = s.user_id
            WHERE s.end_ts IS NOT NULL
              AND (s.task_name IS NULL OR s.task_name = '')
              AND s.end_ts >= ? AND s.end_ts <= ?
            GROUP BY s.user_id
            ORDER BY total_sec DESC
            LIMIT ?
            """,
            (chat_id, start_ts, end_ts, start_ts, end_ts, limit),
        )
        rows = await cursor.fetchall()
        await cursor.close()
        return rows

    async def get_user_task_totals(self, user_id: int, start_ts: str, end_ts: str, limit: int = 50):
        assert self.conn
        cursor = await self.conn.execute(
            """
            SELECT COALESCE(task_name, '') AS task_name,
                   COALESCE(SUM(duration_sec), 0) AS total_sec,
                   COUNT(*) AS sessions_count
            FROM sessions
            WHERE user_id = ?
              AND end_ts IS NOT NULL
              AND task_name IS NOT NULL
              AND task_name != ''
              AND end_ts >= ? AND end_ts <= ?
            GROUP BY task_name
            ORDER BY total_sec DESC
            LIMIT ?
            """,
            (user_id, start_ts, end_ts, limit),
        )
        rows = await cursor.fetchall()
        await cursor.close()
        return rows

    async def list_sessions(
        self,
        limit: int = 200,
        chat_id: int | None = None,
    ):
        assert self.conn
        if chat_id is None:
            query = """
                SELECT * FROM sessions
                ORDER BY start_ts DESC
                LIMIT ?
            """
            params = (limit,)
        else:
            query = """
                SELECT * FROM sessions
                WHERE chat_id = ?
                ORDER BY start_ts DESC
                LIMIT ?
            """
            params = (chat_id, limit)
        cursor = await self.conn.execute(query, params)
        rows = await cursor.fetchall()
        await cursor.close()
        return rows

    async def list_sessions_in_range(self, start_ts: str, end_ts: str):
        assert self.conn
        cursor = await self.conn.execute(
            """
            SELECT * FROM sessions
            WHERE start_ts <= ?
              AND (end_ts IS NULL OR end_ts >= ?)
            ORDER BY start_ts DESC
            """,
            (end_ts, start_ts),
        )
        rows = await cursor.fetchall()
        await cursor.close()
        return rows

    async def list_sessions_in_range_by_chat(self, chat_id: int, start_ts: str, end_ts: str):
        assert self.conn
        cursor = await self.conn.execute(
            """
            SELECT * FROM sessions
            WHERE chat_id = ?
              AND start_ts <= ?
              AND (end_ts IS NULL OR end_ts >= ?)
            ORDER BY start_ts DESC
            """,
            (chat_id, end_ts, start_ts),
        )
        rows = await cursor.fetchall()
        await cursor.close()
        return rows

    async def list_sessions_by_name(self, name: str):
        assert self.conn
        cursor = await self.conn.execute(
            """
            SELECT * FROM sessions
            WHERE display_name = ? OR username = ?
            ORDER BY start_ts DESC
            """,
            (name, name),
        )
        rows = await cursor.fetchall()
        await cursor.close()
        return rows

    async def get_summary(self, start_ts: str, end_ts: str):
        assert self.conn
        cursor = await self.conn.execute(
            """
            SELECT
              COUNT(*) AS sessions_count,
              COALESCE(SUM(duration_sec), 0) AS total_sec
            FROM sessions
            WHERE end_ts IS NOT NULL
              AND end_ts >= ? AND end_ts <= ?
            """,
            (start_ts, end_ts),
        )
        row = await cursor.fetchone()
        await cursor.close()
        return row
