from __future__ import annotations

import logging
from datetime import date, datetime, timezone, timedelta

from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.date import DateTrigger
from telegram.constants import ParseMode

from .db import Database, parse_iso
from .stats import day_range_for, local_date
from .texts_sk import reminder_hours, reminder_minutes

logger = logging.getLogger(__name__)


class ReminderScheduler:
    """Schedule and deliver usage reminders for active sessions.

    Two kinds of one-shot APScheduler jobs are managed per session:

    * daily-threshold jobs, one per entry of ``DAILY_THRESHOLDS`` (minutes
      accumulated during one local day), and
    * a single hourly job that first fires ``HOURLY_START`` hours after the
      session start and re-arms itself for every following full hour.

    Reminders are delivered best-effort: a direct message to the user is tried
    first and the originating chat is used as fallback. Delivery failures are
    logged, never raised.
    """

    # First full hour of a running session that triggers an hourly reminder.
    HOURLY_START = 3
    # Daily totals (in minutes) that trigger a reminder.
    DAILY_THRESHOLDS = (60, 110, 120)
    # Daily total (in minutes) whose reminder is also forwarded to chat admins.
    ADMIN_ALERT_THRESHOLD = 120
    # Column of the daily-notification row that records each threshold.
    _NOTIFIED_COLUMNS = {60: "notified_60", 110: "notified_110", 120: "notified_120"}

    def __init__(self, db: Database, bot, reminder_dm: bool, tz_name: str) -> None:
        """Store collaborators and create a UTC-based asyncio scheduler.

        Args:
            db: Database facade used to read sessions/stats and to persist
                notification flags.
            bot: Telegram bot object; ``send_message`` and
                ``get_chat_administrators`` are awaited on it.
            reminder_dm: Stored on the instance; no method of this class
                reads it.
            tz_name: Time-zone name passed to ``local_date`` and
                ``day_range_for`` to determine the local day.
        """
        self.db = db
        self.bot = bot
        self.reminder_dm = reminder_dm
        self.tz_name = tz_name
        self.scheduler = AsyncIOScheduler(timezone=timezone.utc)

    def start(self) -> None:
        """Start the underlying scheduler."""
        self.scheduler.start()

    async def shutdown(self) -> None:
        """Stop the scheduler without waiting for running jobs.

        Calling this on a scheduler that is not running is a no-op, so it is
        safe to invoke from cleanup paths even if ``start`` was never reached.
        """
        if self.scheduler.running:
            self.scheduler.shutdown(wait=False)

    async def schedule_daily_thresholds(self, session_id: int, chat_id: int, user_id: int, start_ts_iso: str) -> None:
        """Schedule one job per daily threshold the user has not reached yet.

        The total accumulated between the start of the local day (the day the
        session starts in) and the session start is read from the database.
        For each threshold still ahead, a job is scheduled at the instant the
        running session would push the total across that threshold.

        Args:
            session_id: Identifier of the session that just started.
            chat_id: Chat the session belongs to.
            user_id: User the session belongs to.
            start_ts_iso: Session start timestamp in ISO format.
        """
        start_ts = parse_iso(start_ts_iso).astimezone(timezone.utc)
        local_day = local_date(start_ts, self.tz_name)
        day_range = day_range_for(local_day, self.tz_name)
        day_key = str(local_day)

        stats_row = await self.db.get_user_stats(
            chat_id, user_id, day_range.start_utc.isoformat(), start_ts.isoformat()
        )
        total_before = int(stats_row["total_sec"])

        for threshold in self.DAILY_THRESHOLDS:
            remaining = threshold * 60 - total_before
            if remaining <= 0:
                # Threshold was already reached before this session started.
                continue
            self._schedule_daily_job(
                session_id, chat_id, user_id, day_key, threshold,
                start_ts + timedelta(seconds=remaining),
            )

    async def schedule_hourly_threshold(self, session_id: int, start_ts_iso: str) -> None:
        """Schedule the first hourly reminder ``HOURLY_START`` hours after start.

        Args:
            session_id: Identifier of the session that just started.
            start_ts_iso: Session start timestamp in ISO format.
        """
        start_ts = parse_iso(start_ts_iso).astimezone(timezone.utc)
        self._schedule_hourly_job(session_id, start_ts + timedelta(hours=self.HOURLY_START))

    def cancel_session(self, session_id: int) -> None:
        """Remove every pending reminder job of ``session_id``.

        Jobs that cannot be removed (for example because they were never
        scheduled or have already run) are skipped; the failure is logged at
        debug level instead of being raised.
        """
        job_ids = [self._job_id_daily(session_id, threshold) for threshold in self.DAILY_THRESHOLDS]
        job_ids.append(self._job_id_hourly(session_id))
        for job_id in job_ids:
            try:
                self.scheduler.remove_job(job_id)
            except Exception:
                logger.debug("Could not remove job %s", job_id, exc_info=True)

    async def rehydrate(self) -> None:
        """Re-create reminder jobs for all active sessions.

        For each session returned by ``db.get_active_sessions`` this sends the
        reminders that are already due and schedules the ones still ahead.
        """
        sessions = await self.db.get_active_sessions()
        for session in sessions:
            await self._rehydrate_daily_thresholds(session)
            await self._rehydrate_hourly_thresholds(session)

    @staticmethod
    def _job_id_daily(session_id: int, threshold: int) -> str:
        """Return the scheduler job id of a session's daily-threshold job."""
        return f"session-{session_id}-daily-{threshold}"

    @staticmethod
    def _job_id_hourly(session_id: int) -> str:
        """Return the scheduler job id of a session's hourly job."""
        return f"session-{session_id}-hourly"

    @classmethod
    def _already_notified(cls, notif, threshold: int) -> bool:
        """Return True if ``notif`` records a sent reminder for ``threshold``.

        ``notif`` is the (possibly missing) daily-notification row; a falsy
        row or a threshold without a known column counts as not notified.
        """
        if not notif:
            return False
        column = cls._NOTIFIED_COLUMNS.get(threshold)
        return bool(column is not None and notif[column])

    def _schedule_daily_job(
        self, session_id: int, chat_id: int, user_id: int, day_key: str, threshold: int, run_at: datetime
    ) -> None:
        """Add (or replace) the one-shot daily-threshold job firing at ``run_at``."""
        self.scheduler.add_job(
            self._send_daily_threshold,
            trigger=DateTrigger(run_date=run_at),
            args=[session_id, chat_id, user_id, day_key, threshold],
            id=self._job_id_daily(session_id, threshold),
            replace_existing=True,
        )

    def _schedule_hourly_job(self, session_id: int, run_at: datetime) -> None:
        """Add (or replace) the one-shot hourly job firing at ``run_at``."""
        self.scheduler.add_job(
            self._send_hourly_threshold,
            trigger=DateTrigger(run_date=run_at),
            args=[session_id],
            id=self._job_id_hourly(session_id),
            replace_existing=True,
        )

    async def _total_seconds_for_day(self, session, chat_id: int, user_id: int, day_obj: date, now_utc=None) -> int:
        """Return the user's total seconds for ``day_obj`` including a running session.

        The stored total for the whole local day is read from the database;
        if ``session`` has no ``end_ts`` the time elapsed since its start is
        added. When ``now_utc`` is None the current time is sampled after the
        database call.
        """
        day_range = day_range_for(day_obj, self.tz_name)
        stats_row = await self.db.get_user_stats(
            chat_id, user_id, day_range.start_utc.isoformat(), day_range.end_utc.isoformat()
        )
        total = int(stats_row["total_sec"])
        if session["end_ts"] is None:
            if now_utc is None:
                now_utc = datetime.now(timezone.utc)
            start_ts = parse_iso(session["start_ts"]).astimezone(timezone.utc)
            total += int((now_utc - start_ts).total_seconds())
        return total

    async def _deliver_daily_reminder(self, session, chat_id: int, user_id: int, day_key: str, threshold: int) -> None:
        """Send the daily reminder for ``threshold`` and persist that it was sent.

        The reminder reaching ``ADMIN_ALERT_THRESHOLD`` is additionally
        forwarded to the chat administrators.
        """
        name = session["display_name"] or session["username"] or str(user_id)
        text = reminder_minutes(name, threshold)
        await self._send_dm_with_fallback(user_id, chat_id, text)
        if threshold == self.ADMIN_ALERT_THRESHOLD:
            await self._notify_admins(chat_id, text)
        await self.db.set_daily_notified(chat_id, user_id, day_key, threshold)

    async def _send_daily_threshold(self, session_id: int, chat_id: int, user_id: int, day_key: str, threshold: int) -> None:
        """Job callback: send the daily reminder for ``threshold`` if it is due.

        Nothing is sent when the session is no longer active, when the
        reminder was already recorded for ``day_key``, or when the day's total
        is still below the threshold.
        """
        session = await self.db.get_active_session_by_id(session_id)
        if not session:
            return

        try:
            day_obj = date.fromisoformat(day_key)
        except ValueError:
            # Unparseable key: fall back to the local day of the session start.
            day_obj = local_date(parse_iso(session["start_ts"]), self.tz_name)
        total_today = await self._total_seconds_for_day(session, chat_id, user_id, day_obj)

        notif = await self.db.get_daily_notification(chat_id, user_id, day_key)
        if self._already_notified(notif, threshold) or total_today < threshold * 60:
            return

        await self._deliver_daily_reminder(session, chat_id, user_id, day_key, threshold)

    async def _rehydrate_daily_thresholds(self, session) -> None:
        """Send overdue daily reminders for ``session`` and schedule pending ones.

        The current local day is used as reference. Thresholds already
        recorded as notified are skipped; thresholds already reached are
        delivered immediately; the rest get a job at the projected crossing
        time.
        """
        chat_id = session["chat_id"]
        user_id = session["user_id"]
        now_utc = datetime.now(timezone.utc)
        today = local_date(now_utc, self.tz_name)
        day_key = str(today)
        total_today = await self._total_seconds_for_day(session, chat_id, user_id, today, now_utc)

        # One lookup is enough: a flag written inside the loop only concerns
        # the threshold that was just handled, which is not examined again.
        notif = await self.db.get_daily_notification(chat_id, user_id, day_key)

        for threshold in self.DAILY_THRESHOLDS:
            if self._already_notified(notif, threshold):
                continue
            remaining = threshold * 60 - total_today
            if remaining <= 0:
                await self._deliver_daily_reminder(session, chat_id, user_id, day_key, threshold)
            else:
                # Use a fresh "now" so awaited I/O above cannot push the run
                # date into the past.
                run_at = datetime.now(timezone.utc) + timedelta(seconds=remaining)
                self._schedule_daily_job(session["id"], chat_id, user_id, day_key, threshold, run_at)

    async def _send_hourly_threshold(self, session_id: int) -> None:
        """Job callback: send the hourly reminder and re-arm for the next hour.

        If the job fires before ``HOURLY_START`` full hours have elapsed, or
        for an hour that was already reported, it only re-schedules itself.
        """
        session = await self.db.get_active_session_by_id(session_id)
        if not session:
            return

        now_utc = datetime.now(timezone.utc)
        start_ts = parse_iso(session["start_ts"]).astimezone(timezone.utc)
        elapsed_hours = int((now_utc - start_ts).total_seconds() // 3600)
        if elapsed_hours < self.HOURLY_START:
            self._schedule_hourly_job(session_id, start_ts + timedelta(hours=self.HOURLY_START))
            return

        notified_hour = session["notified_hour"] or 0
        if elapsed_hours <= notified_hour:
            next_hour = max(self.HOURLY_START, notified_hour + 1)
            self._schedule_hourly_job(session_id, start_ts + timedelta(hours=next_hour))
            return

        # The guards above guarantee elapsed_hours >= HOURLY_START here.
        hours_to_report = elapsed_hours
        name = session["display_name"] or session["username"] or str(session["user_id"])
        text = reminder_hours(name, hours_to_report)
        await self._send_dm_with_fallback(session["user_id"], session["chat_id"], text)
        await self.db.set_session_notified_hour(session_id, hours_to_report)

        self._schedule_hourly_job(session_id, start_ts + timedelta(hours=hours_to_report + 1))

    async def _rehydrate_hourly_thresholds(self, session) -> None:
        """Send an overdue hourly reminder for ``session`` and schedule the next.

        A reminder is sent right away when at least ``HOURLY_START`` full
        hours have elapsed and the current hour was not reported yet. The
        next job is always placed on a full-hour boundary after the session
        start.
        """
        now_utc = datetime.now(timezone.utc)
        start_ts = parse_iso(session["start_ts"]).astimezone(timezone.utc)
        elapsed_hours = int((now_utc - start_ts).total_seconds() // 3600)
        notified_hour = session["notified_hour"] or 0

        if elapsed_hours >= self.HOURLY_START and elapsed_hours > notified_hour:
            name = session["display_name"] or session["username"] or str(session["user_id"])
            text = reminder_hours(name, elapsed_hours)
            await self._send_dm_with_fallback(session["user_id"], session["chat_id"], text)
            await self.db.set_session_notified_hour(session["id"], elapsed_hours)
            notified_hour = elapsed_hours

        next_hour = max(self.HOURLY_START, max(elapsed_hours, notified_hour) + 1)
        self._schedule_hourly_job(session["id"], start_ts + timedelta(hours=next_hour))

    async def _send_text(self, target_chat_id: int, text: str) -> None:
        """Send ``text`` as an HTML message without link preview to one chat."""
        await self.bot.send_message(
            chat_id=target_chat_id,
            text=text,
            parse_mode=ParseMode.HTML,
            disable_web_page_preview=True,
        )

    async def _send_dm_with_fallback(self, user_id: int, chat_id: int, text: str) -> None:
        """Send ``text`` to the user privately, falling back to the chat.

        Delivery is best-effort: failures are logged and never raised, so a
        failed reminder cannot break the calling job.
        """
        try:
            await self._send_text(user_id, text)
        except Exception:
            logger.info(
                "Direct message to user %s failed; falling back to chat %s",
                user_id, chat_id, exc_info=True,
            )
        else:
            return

        try:
            await self._send_text(chat_id, text)
        except Exception:
            logger.warning("Could not deliver reminder to chat %s", chat_id, exc_info=True)

    async def _notify_admins(self, chat_id: int, text: str) -> None:
        """Send ``text`` privately to every non-bot admin of ``chat_id``.

        If no admin could be reached (including when the admin list cannot be
        fetched), the text is posted to the chat itself. All failures are
        logged and never raised.
        """
        try:
            admins = await self.bot.get_chat_administrators(chat_id)
        except Exception:
            logger.warning("Could not fetch administrators of chat %s", chat_id, exc_info=True)
            admins = []

        sent_any = False
        for admin in admins:
            if admin.user.is_bot:
                continue
            try:
                await self._send_text(admin.user.id, text)
            except Exception:
                logger.info(
                    "Could not notify admin %s of chat %s",
                    admin.user.id, chat_id, exc_info=True,
                )
            else:
                sent_any = True

        if sent_any:
            return
        try:
            await self._send_text(chat_id, text)
        except Exception:
            logger.warning("Could not deliver admin notice to chat %s", chat_id, exc_info=True)
