import sqlite3
import shutil
import os
import json
import hashlib

import database_api

# All paths below are relative, so they resolve against the process's current
# working directory, not against the location of this module.

# SQL script that initialize() reads and runs with executescript().
PATH_INIT_SCRIPT = "./init_database.sql"
# Storage directory; within this module it is only referenced by the
# commented-out reset code in initialize().
PATH_FILESYSTEM = "./storage"
# SQLite database file opened by every connection.
PATH_DATABASE = "./crisis_events.db"
# JSON file holding the "Users" and "Collections" id counters.
PATH_METADATA = "./crisis_events_database_metadata.json"

class SqliteDatabaseAPI(database_api.CrisisEventsDatabase):
    """SQLite-backed implementation of the crisis-events database API.

    Every public query method opens its own connection to ``PATH_DATABASE``,
    runs a single statement, commits and closes the connection again.

    User and collection ids are not generated by SQLite: they are taken from
    counters stored in the JSON file at ``PATH_METADATA``.
    """

    def _connect(self):
        """Open a new connection to ``PATH_DATABASE``.

        Returns:
            A ``(connection, cursor)`` tuple.
        """
        connection = sqlite3.connect(PATH_DATABASE)
        return connection, connection.cursor()

    def _disconnect(self, connection, cursor):
        """Close ``cursor``, commit pending changes and close ``connection``."""
        try:
            cursor.close()
            connection.commit()
        finally:
            # Release the connection even when the commit itself fails.
            connection.close()

    def _run_statement(self, sql, params=()):
        """Run one SQL statement on a fresh connection and commit it.

        Args:
            sql: Statement text using ``?`` placeholders.
            params: Values bound to the placeholders.

        Returns:
            A ``(rows, rowcount)`` tuple: the rows fetched after executing the
            statement (an empty list when it produces none) and the cursor's
            ``rowcount`` read right after execution.

        Raises:
            sqlite3.Error: If the statement fails. The connection is closed
                without committing before the error propagates.
        """
        connection, cursor = self._connect()
        try:
            cursor.execute(sql, params)
            rowcount = cursor.rowcount
            rows = cursor.fetchall()
        except BaseException:
            # Closing without a commit discards anything the failed statement
            # left pending and keeps the connection from leaking.
            connection.close()
            raise
        self._disconnect(connection, cursor)
        return rows, rowcount

    def _take_next_id(self, counter_name):
        """Return the current value of a metadata counter and increment it.

        The counter file is read and rewritten in two separate steps with no
        locking, so concurrent callers can be handed the same id.

        Args:
            counter_name: Key in the metadata JSON ("Users" or "Collections").

        Returns:
            The counter value before the increment.
        """
        with open(PATH_METADATA, "r") as f:
            metadata = json.load(f)

        next_id = metadata[counter_name]
        metadata[counter_name] = next_id + 1

        with open(PATH_METADATA, "w") as f:
            json.dump(metadata, f)

        return next_id

    def get_next_user_id(self):
        """Reserve and return the next user id from the metadata file."""
        return self._take_next_id("Users")

    def get_next_collection_id(self):
        """Reserve and return the next collection id from the metadata file."""
        return self._take_next_id("Collections")

    def create_collection(self, user_id: int, collection_name: str, collection_type: str):
        """Insert a new collection row with placeholder summaries.

        Args:
            user_id: Id stored as the collection's owner.
            collection_name: Name of the collection.
            collection_type: Value stored in the collection's type column.

        Returns:
            The id reserved for the new collection.
        """
        collection_id = self.get_next_collection_id()
        self._run_statement(
            "INSERT INTO event_collections VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
            (
                collection_id,
                user_id,
                collection_type,
                "summary",
                "t5 summary",
                " bert summary",
                "nltk summary",
                collection_name,
            ),
        )
        print(f"Created new collection '{collection_name}' owned by {user_id}")
        return collection_id

    def delete_collection(self, collection_name):
        """Delete every collection whose name equals ``collection_name``.

        Returns:
            True if at least one row was deleted, False otherwise.
        """
        _, deleted = self._run_statement(
            "DELETE FROM event_collections WHERE collection_name = ?",
            (collection_name,),
        )
        if deleted > 0:
            print(f"Deleted collection '{collection_name}'")
            return True

        print(f"Collection '{collection_name}' not found")
        return False

    def create_raw_text_file(self, collection_id, path):
        """Register a raw text file path under a collection.

        The row id is passed as NULL in the insert rather than taken from the
        metadata counters.

        Returns:
            The ``collection_id`` that was passed in (not the id of the new
            row); kept as-is because callers may rely on it.
        """
        self._run_statement(
            "INSERT INTO raw_text_files VALUES (NULL, ?, ?)",
            (collection_id, path),
        )
        print(f"Created new raw_text_file '{path}' owned by collection {collection_id}")
        return collection_id

    def get_raw_text_files(self, collection_id):
        """Return the raw text files of a collection.

        Returns:
            A list of dicts with the keys "id", "collection_id" and "path".
        """
        rows, _ = self._run_statement(
            "SELECT * FROM raw_text_files WHERE collection_id == ?",
            (collection_id,),
        )
        return [{"id": r[0], "collection_id": r[1], "path": r[2]} for r in rows]

    def create_user(self, username: str, password: str):
        """Insert a new user and return the id reserved for it.

        Args:
            username: Name stored for the user.
            password: Plain-text password; only its hash is stored.

        Returns:
            The id reserved for the new user.
        """
        user_id = self.get_next_user_id()
        # NOTE(security): unsalted MD5 is not suitable for password storage.
        # It is kept because existing rows (and whatever code verifies
        # passwords) use this format; migrating needs a coordinated change.
        # The hash is deliberately no longer printed to stdout.
        user_hash = hashlib.md5(password.encode()).hexdigest()
        self._run_statement(
            "INSERT INTO users VALUES (?, ?, ?)",
            (user_id, username, user_hash),
        )
        print(f"Created new user '{username}'")
        return user_id

    def get_user_by_username(self, username: str):
        """Return the first ``users`` row matching ``username``, or None."""
        rows, _ = self._run_statement(
            "SELECT * FROM users WHERE user_name == ?;",
            (username,),
        )
        return rows[0] if rows else None

    def get_collection(self, collection_id: int):
        """Return one collection as a dict, or None if it does not exist.

        The dict has the keys "collection_id", "user_id", "type",
        "collection_summary", "t5_summary", "bert_summary", "nltk_summary"
        and "collection_name", mapped from the row's columns by position.
        """
        rows, _ = self._run_statement(
            "SELECT * FROM event_collections WHERE collection_id == ?;",
            (collection_id,),
        )
        if not rows:
            return None

        row = rows[0]
        return {
            "collection_id": row[0],
            "user_id": row[1],
            "type": row[2],
            "collection_summary": row[3],
            "t5_summary": row[4],
            "bert_summary": row[5],
            "nltk_summary": row[6],
            "collection_name": row[7],
        }

    def get_collections(self, user_id: int):
        """Return id and name of every collection owned by ``user_id``.

        Returns:
            A list of dicts with the keys "collection_id" and
            "collection_name"; empty if the user owns no collections.
        """
        rows, _ = self._run_statement(
            "SELECT * FROM event_collections WHERE owner_id == ?;",
            (user_id,),
        )
        return [
            {"collection_id": row[0], "collection_name": row[7]}
            for row in rows
        ]

    def set_collection(self, collection_json):
        """Overwrite the editable fields of a collection.

        Args:
            collection_json: Dict in the shape returned by get_collection().
                The row is matched on both "collection_id" and "user_id"
                (compared against the owner column); "type",
                "collection_summary", "t5_summary", "bert_summary",
                "nltk_summary" and "collection_name" are written.

        Raises:
            KeyError: If any of the keys above is missing.
        """
        self._run_statement(
            "UPDATE event_collections SET type = ?, collection_summary = ?,t5_summary = ?,bert_summary = ?,nltk_summary = ?, collection_name = ? WHERE collection_id == ? AND owner_id == ?;",
            (
                collection_json["type"],
                collection_json["collection_summary"],
                collection_json["t5_summary"],
                collection_json["bert_summary"],
                collection_json["nltk_summary"],
                collection_json["collection_name"],
                collection_json["collection_id"],
                collection_json["user_id"],
            ),
        )

    def update_collection_summary(self, collection_id: int, summary: str):
        """Replace the ``collection_summary`` of the given collection."""
        self._run_statement(
            "UPDATE event_collections SET collection_summary = ? WHERE collection_id == ?;",
            (summary, collection_id),
        )

    def get_sample_of_collections(self):
        """Return up to 10 raw rows from ``event_collections``."""
        rows, _ = self._run_statement("SELECT * FROM event_collections LIMIT 10;")
        return rows

    def get_sample_of_users(self):
        """Return up to 10 raw rows from ``users``."""
        rows, _ = self._run_statement("SELECT * FROM users LIMIT 10;")
        return rows

    def initialize(self, reset=False):
        """Run the init SQL script and reset the id counters to zero.

        Args:
            reset: Currently unused; accepted for interface compatibility.

        Raises:
            OSError: If the init script or the metadata file cannot be opened.
            sqlite3.Error: If the init script fails.
        """
        # Read the script before connecting so that a missing file cannot
        # leave an open connection behind.
        with open(PATH_INIT_SCRIPT, "r") as f:
            sql = f.read()

        connection, cursor = self._connect()
        try:
            cursor.executescript(sql)
        except BaseException:
            connection.close()
            raise
        self._disconnect(connection, cursor)

        with open(PATH_METADATA, "w") as f:
            json.dump({"Users": 0, "Collections": 0}, f)

    def get_info(self) -> str:
        """Return a short human-readable name for this backend."""
        return "SQLITE Backend"

# Module-level alias for the backend class defined in this file.
# NOTE(review): presumably looked up by whatever loads database backends;
# confirm against the caller.
IMPLEMENTATION = SqliteDatabaseAPI

if __name__ == "__main__":
    # Manual smoke test. It runs the init script and resets the id counters,
    # so only run it against a database you are willing to re-initialize.
    crisis_db = SqliteDatabaseAPI()
    crisis_db.initialize()

    # Burn a few ids to show that the counter advances.
    print(crisis_db.get_next_user_id())
    print(crisis_db.get_next_user_id())
    print(crisis_db.get_next_user_id())
    print(crisis_db.get_next_user_id())
    user_id = crisis_db.get_next_user_id()

    # create_collection() requires a collection type. "demo" is an arbitrary
    # value; adjust it if the schema restricts the allowed types.
    collection_id = crisis_db.create_collection(user_id, "My collection", "demo")

    collection = crisis_db.get_collection(collection_id)
    print(collection)

    # set_collection() takes the dict shape returned by get_collection().
    collection["collection_summary"] = "new summary"
    collection["collection_name"] = "new name"
    crisis_db.set_collection(collection)

    print(crisis_db.get_collection(collection_id))