import glob
import hashlib
import hmac
import html
import os
import zipfile

from flask import Flask, request, abort, jsonify, make_response
from flask_cors import CORS
from flask_jwt_extended import JWTManager, create_access_token, jwt_required, get_jwt_identity,create_refresh_token

from database_api import CrisisEventsDatabase
#from database_implementation import sqlite_api
from runtime_import import runtime_import

#summarizers
import summarizer_implementations.t5 as t5
import summarizer_implementations.nltk_summarizer as nltk
import summarizer_implementations.bert as bert

app = Flask(__name__)
# Signing key for all issued JWTs. Read from the JWT_SECRET_KEY environment
# variable so real deployments do not run on the hard-coded development
# fallback 'PI' (anyone knowing it can forge tokens).
app.config['JWT_SECRET_KEY'] = os.environ.get('JWT_SECRET_KEY', 'PI')
jwt = JWTManager(app)
# Requires: pip install flask-cors Flask-JWT-Extended
CORS(app)


# Database implementation used by every route; it is None until the
# `if __name__ == '__main__':` block of this module assigns it.
database:CrisisEventsDatabase = None

def database_debug_view():
    """Render an HTML debug page showing sample collections and users.

    Reads the module-level ``database``: its ``get_info()`` description, every
    row of ``get_sample_of_collections()`` and every row of
    ``get_sample_of_users()`` are rendered with ``str()`` inside ``<code>``
    elements. All database-derived text is HTML-escaped, so values containing
    ``<``, ``>`` or ``&`` are displayed literally instead of breaking the page
    or injecting markup.

    Returns:
        The complete HTML page as a string.
    """
    def escape(value):
        # quote=False: the text only appears in element content, never in attributes.
        return html.escape(str(value), quote=False)

    db_html = "".join(
        f"<br/><code>{escape(row)}</code>"
        for row in database.get_sample_of_collections()
    )
    # NOTE(review): the user rows are rendered under a "Hash" header, so this
    # unauthenticated page appears to expose password hashes; consider removing
    # that column or restricting the route.
    user_db_html = "".join(
        f"<br/><code>{escape(row)}</code>"
        for row in database.get_sample_of_users()
    )

    return f"""
    <html>
    <head></head>
    <body>
        <p>This service uses {escape(database.get_info())} as its database</p>
        <code>CollectionID, UserID, CollectionData, CollectionSummary, CollectionName</code>
        {db_html}
        <br/><br/><br/>
        <code>UserID, UserName, Hash</code>
        {user_db_html}
    </body>
    </html>
    """
@app.route('/login', methods = ['POST'])
def login():
    """Authenticate a user and issue JWT access and refresh tokens.

    Expects a JSON body ``{"authenticate": {"username": str, "password": str}}``.
    The MD5 hex digest of the password is compared with the hash stored at
    index 2 of the row returned by ``database.get_user_by_username``.

    Returns:
        200 with ``{"status": "success", "access_token", "refresh_token"}``;
        401 with ``{"status": "User does not exist"}`` for an unknown user;
        401 with ``{"status": "Incorrect Password"}`` on a hash mismatch.
        Aborts with 400 when the body or its credential fields are malformed.
    """
    data = request.json  # Retrieve JSON data from the request
    if not isinstance(data, dict):
        abort(400)

    credentials = data.get("authenticate")
    if not isinstance(credentials, dict):
        abort(400)

    username = credentials.get('username')
    password = credentials.get('password')
    # Both must be strings: the password is hashed via str.encode() below.
    if not isinstance(username, str) or not isinstance(password, str):
        abort(400)

    # Never log the request body or credentials dict: they contain the plaintext password.
    print(f'Login attempt for user: {username!r}')

    # NOTE(review): unsalted MD5 is not a safe password hash; moving to a
    # proper KDF requires migrating the stored hashes compared below.
    user_hash = hashlib.md5(password.encode()).hexdigest()

    user_data = database.get_user_by_username(username)

    # NOTE(review): distinct messages for "unknown user" vs "wrong password"
    # allow username enumeration; kept because clients may rely on them.
    if user_data is None:
        print(f'Username or password is incorrect for user: {username!r}')
        return {"status":"User does not exist"}, 401

    stored_hash = user_data[2]

    # Constant-time comparison so response timing does not reveal how much of
    # the hash matched; a non-string stored value can never match.
    if not isinstance(stored_hash, str) or not hmac.compare_digest(
        user_hash.encode(), stored_hash.encode()
    ):
        print(f'Password is incorrect for user: {username!r}')
        return {"status":"Incorrect Password"}, 401

    access_token = create_access_token(username)
    refresh_token = create_refresh_token(username)
    print('Login Successful!\n')
    return jsonify({"status":"success", "access_token": access_token, "refresh_token": refresh_token}), 200

@app.route('/refresh', methods = ['POST'])
@jwt_required(refresh = True)
def refresh():
    """Exchange a valid refresh token for a fresh access token.

    Returns:
        A 200 JSON response ``{"access_token": ...}`` whose identity is the
        identity carried by the presented refresh token.
    """
    token = create_access_token(identity=get_jwt_identity())
    body = jsonify({"access_token": token})
    return make_response(body, 200)

@app.route('/database_service', methods=['POST','GET'])
def database_service():
    """Legacy command endpoint. GET renders the HTML database debug page.

    POST expects a JSON object ``{"command": ..., "data": ...}``:
        create_collection
            - calls ``database.create_collection(0, data["collection_name"])``;
              the leading 0 is hard-coded (presumably a placeholder user id)
        get_collection
            - returns ``str(database.get_collection(0, 0))``
        get_collections
            - returns ``str(database.get_collections(0))``, i.e. every
              collection of user 0

    Unknown commands and malformed bodies abort with 400.
    """
    if request.method == 'GET':
        return database_debug_view(), 200

    payload = request.json
    if not isinstance(payload, dict) or "command" not in payload:
        print("Got poorly formatted request")
        abort(400)

    command = payload["command"]
    result = ''
    if command == "create_collection":
        data = payload.get("data")
        if not isinstance(data, dict) or "collection_name" not in data:
            print("Got poorly formatted request")
            abort(400)
        # NOTE(review): /api/v1/create_collection also passes a type; confirm
        # create_collection accepts this two-argument form.
        database.create_collection(0, data["collection_name"])
    elif command == "get_collection":
        # NOTE(review): /api/v1/get_collection passes a single collection id;
        # confirm this two-argument call matches the database API.
        result = str(database.get_collection(0, 0))
    elif command == "get_collections":
        # Lists all of user 0's collections, matching the docstring and the
        # get_collections(user_id) call used by /api/v1/get_collections.
        result = str(database.get_collections(0))
    else:
        print("Invalid command")
        abort(400)

    return result, 200

@app.route('/api/v1/get_collections', methods=['GET'])
def get_collections():
    """List the collections belonging to one user.

    Example: ``GET http://127.0.0.1:5000/api/v1/get_collections?user=0``

    The raw ``user`` query-string value is forwarded to
    ``database.get_collections``; the result is returned as
    ``{"status": "success", "collections": ...}`` with 200. Aborts with 400
    when ``user`` is missing or empty.
    """
    requested_user = request.args.get("user")
    if requested_user:
        owned = database.get_collections(requested_user)
        return {"status": "success", "collections": owned}, 200
    abort(400)

@app.route('/api/v1/get_collection', methods=['GET'])
def get_collection():
    """Fetch a single collection by id.

    Example: ``GET http://127.0.0.1:5000/api/v1/get_collection?collection=0``

    Returns ``{"status": "success", "collection": ...}`` when
    ``database.get_collection`` yields a truthy value, otherwise
    ``{"status": "failure"}``; both with HTTP 200. Aborts with 400 when
    ``collection`` is missing or empty.
    """
    requested_id = request.args.get("collection")
    if not requested_id:
        abort(400)

    found = database.get_collection(requested_id)
    body = {"status": "success", "collection": found} if found else {"status": "failure"}
    return body, 200
    
@app.route('/api/v1/create_collection', methods=['POST'])
def v1_create_collection():
    """Create a collection and return its id.

    Expects ``{"collection_info": {"collection_name": ..., "type": ...}}``.
    The name and type are forwarded to ``database.create_collection``.

    Returns:
        ``{"status": "success", "collection_id": <id>}`` with 200. Aborts with
        400 when the body or ``collection_info`` is not a JSON object or a
        required key is missing.
    """
    data = request.json
    if not isinstance(data, dict):
        abort(400)

    collection_info = data.get("collection_info")
    if not isinstance(collection_info, dict):
        abort(400)

    if "collection_name" not in collection_info or "type" not in collection_info:
        abort(400)

    # NOTE(review): the first argument is hard-coded to 0 (presumably the
    # owning user id) rather than taken from the authenticated user.
    collection_id = database.create_collection(
        0, collection_info["collection_name"], collection_info["type"]
    )

    return {
        "status":"success",
        "collection_id":collection_id
        }, 200

@app.route('/api/v1/update_collection_glob', methods=['POST'])
def v1_update_collection_glob():
    """Replace the stored text glob of a collection.

    Expects ``{"collection_info": {"collection_id": ..., "glob": ...}}`` and
    forwards both values to ``database.update_collection_glob``.

    Returns:
        ``{"status": "success"}`` with 200. Aborts with 400 when the body or
        ``collection_info`` is not a JSON object or a required key is missing.
    """
    data = request.json
    if not isinstance(data, dict):
        abort(400)

    collection_info = data.get("collection_info")
    if not isinstance(collection_info, dict):
        abort(400)

    if "collection_id" not in collection_info:
        print("Missing ID")
        abort(400)

    if "glob" not in collection_info:
        print("Missing Glob")
        abort(400)

    database.update_collection_glob(collection_info["collection_id"], collection_info["glob"])
    return {
        "status":"success",
        }, 200

def _summarize_collection(summarize):
    """Summarize a collection's raw text files with *summarize* and store the result.

    Shared implementation of the ``/api/v1/summarize/*`` endpoints. Expects a
    JSON body ``{"collection_info": {"collection_id": ...}}``. Reads every file
    listed by ``database.get_raw_text_files`` (UTF-8), joins their contents
    with newlines, passes the text to *summarize*, and saves a truthy result
    with ``database.update_collection_summary``.

    Args:
        summarize: callable taking the collection text and returning a summary;
            a falsy return value is treated as failure.

    Returns:
        ``({"status": "success"}, 200)`` after the summary is saved, or
        ``({"status": "failure"}, 500)`` when *summarize* returns a falsy value.
        Aborts with 400 when the body or ``collection_info`` is not a JSON
        object or ``collection_id`` is missing.
    """
    data = request.json
    if not isinstance(data, dict):
        abort(400)

    collection_info = data.get("collection_info")
    if not isinstance(collection_info, dict):
        abort(400)

    if "collection_id" not in collection_info:
        print("Missing ID")
        abort(400)

    collection_id = collection_info["collection_id"]

    texts = []
    for file_data in database.get_raw_text_files(collection_id):
        with open(file_data["path"], "r", encoding="utf8") as f:
            texts.append(f.read())
    # Separate files with a newline so the last word of one file is not fused
    # with the first word of the next when a file lacks a trailing newline.
    text = "\n".join(texts)

    summary = summarize(text)
    if not summary:
        return {
            "status":"failure",
            }, 500

    database.update_collection_summary(collection_id, summary)
    return {
        "status":"success",
        }, 200


@app.route('/api/v1/summarize/t5', methods=['POST'])
def v1_summarize_t5():
    """Summarize a collection with the T5 summarizer (see ``_summarize_collection``)."""
    return _summarize_collection(t5.summarize)


@app.route('/api/v1/summarize/nltk', methods=['POST'])
def v1_summarize_nltk():
    """Summarize a collection with the NLTK summarizer (see ``_summarize_collection``)."""
    return _summarize_collection(nltk.summarize)


@app.route('/api/v1/summarize/bert', methods=['POST'])
def v1_summarize_bert():
    """Summarize a collection with the BERT summarizer (see ``_summarize_collection``)."""
    return _summarize_collection(bert.summarize)
    

@app.route('/api/v1/upload_raw_text', methods=['POST'])
def v1_upload_raw_text():
    """Save and extract uploaded zip archives of text files for a collection.

    Example: ``POST http://127.0.0.1:5000/api/v1/upload_raw_text?collection=0``
    with one or more multipart file fields. Each upload is saved as
    ``./backend/storage/<collection>/<field name>``, opened as a zip archive
    and extracted into that directory. Then every ``*.txt`` file directly in
    the directory that is not yet registered for the collection is registered
    with ``database.create_raw_text_file``.

    Returns:
        ``{"status": "success"}`` with 200. Aborts with 400 when ``collection``
        or a file field name is missing or not a single plain path component,
        or when an upload is not a valid zip archive.
    """
    def is_plain_name(name):
        # A single non-empty path component that cannot escape ./backend/storage.
        return bool(name) and name not in (".", "..") and not any(c in name for c in "/\\\0")

    collection_id = request.args.get("collection")
    if not is_plain_name(collection_id):
        abort(400)
    # Validate every field name before writing anything to disk.
    if not all(is_plain_name(field_name) for field_name in request.files):
        abort(400)

    storage_dir = f"./backend/storage/{collection_id}"
    os.makedirs(storage_dir, exist_ok=True)
    # glob.escape keeps characters such as '[' in the id from acting as wildcards.
    glob_path = glob.escape(storage_dir) + "/*.txt"

    # Skip paths already registered for this collection so that several
    # archives in one request, or repeated uploads, do not create duplicate
    # records. Assumes get_raw_text_files returns each "path" exactly as it
    # was passed to create_raw_text_file; TODO confirm.
    registered = {file_data["path"] for file_data in database.get_raw_text_files(collection_id)}

    for field_name, upload in request.files.items():
        path = f"{storage_dir}/{field_name}"
        upload.save(path)

        try:
            with zipfile.ZipFile(path, 'r') as zip_ref:
                zip_ref.extractall(storage_dir)
        except zipfile.BadZipFile:
            print(f"Not a valid zip archive: {path}")
            abort(400)

        for file_path in glob.glob(glob_path):
            if file_path not in registered:
                database.create_raw_text_file(collection_id, file_path)
                registered.add(file_path)

    return {"status":"success"}, 200

@app.route('/api/v1/upload_url_file', methods=['POST'])
def v1_upload_url_file():
    """Save uploaded URL-list files into a collection's storage directory.

    Example: ``POST http://127.0.0.1:5000/api/v1/upload_url_file?collection=0``
    with one or more multipart file fields. Each upload is saved unchanged as
    ``./backend/storage/<collection>/<field name>``.

    Returns:
        ``{"status": "success"}`` with 200. Aborts with 400 when ``collection``
        or a file field name is missing or not a single plain path component.
    """
    def is_plain_name(name):
        # A single non-empty path component that cannot escape ./backend/storage.
        return bool(name) and name not in (".", "..") and not any(c in name for c in "/\\\0")

    collection_id = request.args.get("collection")
    if not is_plain_name(collection_id):
        abort(400)
    # Validate every field name before writing anything to disk.
    if not all(is_plain_name(field_name) for field_name in request.files):
        abort(400)

    storage_dir = f"./backend/storage/{collection_id}"
    os.makedirs(storage_dir, exist_ok=True)

    for field_name, upload in request.files.items():
        path = f"{storage_dir}/{field_name}"
        upload.save(path)
        print(f"Saved: {path}")

    return {"status":"success"}, 200
  
@app.route('/api/v1/get_items', methods=['GET'])
def v1_get_items():
    """List the items stored for a collection.

    Example: ``GET http://127.0.0.1:5000/api/v1/get_items?collection=0``

    For a collection whose ``type`` is ``"text"``, the items are the paths of
    the ``*.txt`` files in ``./backend/storage/<collection>``. For ``"url"``,
    the items are the lines of all those files (UTF-8, trailing whitespace
    stripped). Any other type yields an empty list.

    Returns:
        ``{"status": "success", "files": [...]}`` with 200, or
        ``{"status": "failure"}`` with 404 when the collection does not exist.
        Aborts with 400 when ``collection`` is missing or not a single plain
        path component.
    """
    collection_id = request.args.get("collection")
    # Reject ids that could point outside ./backend/storage (e.g. "../x").
    if (not collection_id or collection_id in (".", "..")
            or any(c in collection_id for c in "/\\\0")):
        abort(400)

    collection = database.get_collection(collection_id)
    if not collection:
        return {"status":"failure"}, 404

    c_type = collection["type"]
    # glob.escape keeps characters such as '[' in the id from acting as wildcards.
    glob_path = glob.escape(f"./backend/storage/{collection_id}") + "/*.txt"

    items = []
    if c_type == "text":
        items = glob.glob(glob_path)
    elif c_type == "url":
        for file_path in glob.glob(glob_path):
            with open(file_path, encoding="utf8") as file:
                # Accumulate across files; assigning here would keep only the last file.
                items.extend(line.rstrip() for line in file)

    print(items)

    return {"status":"success","files":items}, 200


@app.route('/testing', methods=['POST','GET'])
def testing():
    """Smoke-test endpoint that returns a fixed JSON payload.

    GET logs the query arguments and returns ``[{"test": "hello"}]``; POST
    returns ``["Success Post request"]``. Both respond with 200.
    """
    method = request.method
    if method == 'GET':
        print("get request: ",request.args)
        return [{"test":"hello"}], 200
    if method == 'POST':
        return ["Success Post request"], 200
    abort(400)

 



if __name__ == '__main__':
    # Instantiate IMPLEMENTATION from the first module returned by
    # runtime_import("database_implementation/"), initialize it, create two
    # demo users, then start Flask's development server.
    # NOTE(review): `database` is only assigned here, so serving this module
    # any other way (e.g. `flask run`) leaves it None.
    loaded_modules = runtime_import("database_implementation/")
    database = loaded_modules[0].IMPLEMENTATION()
    database.initialize()
    for demo_user in ("test_user", "test_user2"):
        database.create_user(demo_user, "12345")

    app.run()