# File info: 145 lines, 5.3 KiB, Python

import base64
import hashlib
import os
import secrets
from datetime import datetime, timedelta, timezone
from functools import wraps

import jwt
from dotenv import load_dotenv
from flask import Flask, request, redirect, url_for, render_template, flash
from flask_login import LoginManager, UserMixin, login_user, login_required, logout_user
from ldap3 import Server, Connection, MODIFY_REPLACE
from ldap3.core.exceptions import LDAPException
from ldap3.utils.conv import escape_filter_chars
from ldap3.utils.dn import escape_rdn
# Flask setup
# Load .env before any os.getenv() call below, so values defined there
# (including FLASK_SECRET_KEY) are actually picked up.
load_dotenv()

app = Flask(__name__)
# No hard-coded fallback: a publicly known secret key lets anyone forge the
# signed session cookie that Flask-Login relies on. Without FLASK_SECRET_KEY a
# random per-process key is used, so sessions do not survive a restart.
app.secret_key = os.getenv("FLASK_SECRET_KEY") or secrets.token_hex(32)
# Don't send the session cookie on cross-site POSTs (basic CSRF mitigation;
# the views in this module do not validate a CSRF token).
app.config["SESSION_COOKIE_SAMESITE"] = "Lax"

# Flask-Login setup
login_manager = LoginManager()
login_manager.init_app(app)
# Redirect anonymous users of @login_required views to the login form instead
# of answering with a bare 401, matching what jwt_required does.
login_manager.login_view = "login"

# LDAP / JWT configuration, read from the environment (or .env).
LDAP_SERVER = os.getenv("LDAP_SERVER")
USER_DN = os.getenv("USER_DN")  # users live at uid=<username>,USER_DN
GROUP_DN = os.getenv("GROUP_DN")  # search base for group membership lookups
LDAP_SERVICE_USER = os.getenv("LDAP_SERVICE_USER")  # bind DN used by profile()
LDAP_SERVICE_PASSWORD = os.getenv("LDAP_SERVICE_PASSWORD")
JWT_SECRET = os.getenv("JWT_SECRET")  # HMAC key; required for generate_jwt()
JWT_ALGORITHM = "HS256"
JWT_EXPIRATION_SECONDS = 3600  # token lifetime: one hour
# User model
class User(UserMixin):
    """Authenticated user as seen by Flask-Login.

    Attributes:
        id: the username (LDAP uid); UserMixin reads ``self.id`` and
            load_user() rebuilds a User from it.
        dn: the user's LDAP distinguished name.
    """

    def __init__(self, id, dn):
        self.id, self.dn = id, dn
@login_manager.user_loader
def load_user(user_id):
    """Rebuild the session's User from the id Flask-Login stored at login.

    No LDAP lookup is performed; the DN is derived from USER_DN. The uid is
    RDN-escaped so a name containing DN special characters (",", "=", "+", ...)
    still maps to a single uid=... RDN rather than altering the DN structure.
    """
    return User(user_id, f"uid={escape_rdn(user_id)},{USER_DN}")
# JWT utility functions
def generate_jwt(payload):
    """Return a signed JWT carrying *payload* plus an "exp" claim.

    The token expires JWT_EXPIRATION_SECONDS from now and is signed with
    JWT_SECRET using JWT_ALGORITHM. *payload* itself is left unmodified; an
    "exp" key already present in it is overridden.
    """
    expiration = datetime.now(timezone.utc) + timedelta(seconds=JWT_EXPIRATION_SECONDS)
    # Build a new dict rather than payload.update(): the caller's dict must
    # not silently gain an "exp" key.
    claims = {**payload, "exp": expiration}
    return jwt.encode(claims, JWT_SECRET, algorithm=JWT_ALGORITHM)
def decode_jwt(token):
    """Verify *token* and return its claims dict, or None if it is not trusted.

    Expired tokens and otherwise invalid ones (bad signature, malformed,
    wrong algorithm) both yield None. ExpiredSignatureError derives from
    InvalidTokenError, so a single handler covers both cases.
    """
    try:
        claims = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
    except jwt.InvalidTokenError:
        return None
    return claims
# Decorator for JWT authentication
def jwt_required(f):
    """Decorate view *f* so it only runs with a valid JWT in the "jwt" cookie.

    If the cookie is missing, or the token is invalid, expired, or carries no
    "username" claim, an error is flashed and the client is redirected to the
    login page. Otherwise the token's "username" claim is stored on
    ``request.user`` for the view to read.
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        token = request.cookies.get("jwt")
        if not token:
            flash("You must log in to access this page.", "error")
            return redirect(url_for("login"))
        payload = decode_jwt(token)
        # A signed token without a username must not let the view run with
        # request.user = None.
        username = payload.get("username") if payload else None
        if not username:
            flash("Session expired or invalid. Please log in again.", "error")
            return redirect(url_for("login"))
        request.user = username
        return f(*args, **kwargs)
    return decorated_function
# Routes
def _ldap_bind_succeeds(user_dn, password):
    """Return True if a simple bind to LDAP_SERVER as *user_dn* succeeds.

    The connection is always closed afterwards. LDAP client errors (e.g.
    server unreachable) count as a failed bind instead of propagating.
    """
    try:
        conn = Connection(Server(LDAP_SERVER), user=user_dn, password=password)
    except LDAPException:
        return False
    try:
        return conn.bind()
    except LDAPException:
        return False
    finally:
        conn.unbind()


@app.route("/", methods=["GET", "POST"])
def login():
    """Show the login form (GET) and authenticate against LDAP (POST).

    On POST the submitted credentials are verified by binding as
    ``uid=<username>,USER_DN``. On success the user is logged in with
    Flask-Login, a JWT (see generate_jwt) is stored in the HttpOnly, Secure
    "jwt" cookie, and the client is redirected to the profile page. Any
    failure (blank field, wrong password, LDAP error) flashes
    "Invalid credentials" and re-renders the form.
    """
    if request.method == "POST":
        username = request.form["username"]
        password = request.form["password"]
        # Escape the uid so characters such as "," or "=" in the username
        # cannot change which DN we bind as (DN injection).
        user_dn = f"uid={escape_rdn(username)},{USER_DN}"
        # A non-empty DN with an empty password is an LDAP "unauthenticated
        # bind", which some servers report as success; reject it up front.
        if username and password and _ldap_bind_succeeds(user_dn, password):
            login_user(User(username, user_dn))
            token = generate_jwt({"username": username})
            response = redirect(url_for("profile"))
            # The cookie lives as long as the token it carries.
            response.set_cookie(
                "jwt",
                token,
                httponly=True,
                secure=True,
                samesite="Lax",
                max_age=JWT_EXPIRATION_SECONDS,
            )
            return response
        flash("Invalid credentials", "error")
    return render_template("login.html")
@app.route("/logout")
@login_required
def logout():
    """End the Flask-Login session, drop the "jwt" cookie, and go to login."""
    logout_user()
    login_page = url_for("login")
    resp = redirect(login_page)
    resp.delete_cookie("jwt")
    return resp
def _fetch_user_entry(conn, user_dn):
    """Return the klUser entry at *user_dn* with all user attributes, or None."""
    conn.search(user_dn, "(objectClass=klUser)", attributes=["*"])
    return conn.entries[0] if conn.entries else None


def _is_protected(entry):
    """Return True if *entry* has protectedAccount set to TRUE.

    Accepts both a schema-formatted Python bool and the raw LDAP string
    "TRUE", so the check cannot silently fail open when the value arrives
    unconverted.
    """
    if "protectedAccount" not in entry:
        return False
    return str(entry.protectedAccount.value).upper() == "TRUE"


def _parse_ssh_keys(text):
    """Split the sshPublicKey form text into a list of keys.

    Lines are stripped (which also removes the "\\r" of CRLF input); blank
    lines and lines starting with "#" are dropped. Duplicates are removed,
    keeping first-seen order so the stored order is deterministic.
    """
    stripped = (line.strip() for line in text.split("\n"))
    return list(dict.fromkeys(k for k in stripped if k and not k.startswith("#")))


def _ssha_hash(password):
    """Return a salted ``{SSHA}`` userPassword value for *password*.

    Format: "{SSHA}" + base64(SHA1(password_utf8 + salt) + salt), with a
    random 8-byte salt. Unlike unsalted {SHA}, equal passwords do not produce
    equal hashes.
    """
    salt = os.urandom(8)
    digest = hashlib.sha1(password.encode() + salt).digest()
    return "{SSHA}" + base64.b64encode(digest + salt).decode()


def _build_profile_changes(form):
    """Translate the submitted profile form into an ldap3 ``modify`` changes dict.

    Returns ``(changes, passwords_mismatch)``:

    * sn, cn, givenName and sipPassword are replaced with the submitted value;
      an empty value leaves the attribute unchanged (empty strings are not
      valid values for most LDAP syntaxes and would fail the whole modify).
      A missing field raises KeyError via ``form[...]`` (a 400 under Flask).
    * sshPublicKey is always replaced by the parsed key list; an empty list
      removes all keys.
    * userPassword is replaced when either password field is non-blank and
      both are identical; if they differ, ``passwords_mismatch`` is True and
      the password is left unchanged.
    """
    changes = {}
    for attr in ("sn", "cn", "givenName", "sipPassword"):
        value = form[attr]
        if value:
            changes[attr] = [(MODIFY_REPLACE, [value])]

    ssh_keys = _parse_ssh_keys(form.get("sshPublicKey", ""))
    changes["sshPublicKey"] = [(MODIFY_REPLACE, ssh_keys)]

    password = form.get("password", "")
    password_repeat = form.get("password_repeat", "")
    passwords_mismatch = False
    # Filling in only one of the two fields is treated as a mismatch rather
    # than silently ignored.
    if password.strip() or password_repeat.strip():
        if password == password_repeat:
            changes["userPassword"] = [(MODIFY_REPLACE, [_ssha_hash(password)])]
        else:
            passwords_mismatch = True
    return changes, passwords_mismatch


@app.route("/profile", methods=["GET", "POST"])
@login_required
@jwt_required
def profile():
    """Show and edit the current user's LDAP entry.

    The user is identified by ``request.user`` (the JWT "username" claim set
    by jwt_required). All LDAP access uses the service account
    LDAP_SERVICE_USER, and the connection is always closed.

    GET renders profile.html with the user's entry and the ``cn`` of every
    group under GROUP_DN whose ``member`` is the user. POST, unless the entry
    has protectedAccount TRUE, applies the form via one LDAP modify (all
    changes succeed or none do) and reports the outcome with flash messages.
    If the user's entry does not exist, the user is logged out.
    """
    username = request.user
    # Escape the uid so the username always forms a single RDN value.
    user_dn = f"uid={escape_rdn(username)},{USER_DN}"
    server = Server(LDAP_SERVER)
    # Connect as the service account
    service_conn = Connection(server, user=LDAP_SERVICE_USER, password=LDAP_SERVICE_PASSWORD, auto_bind=True)
    try:
        user_attrs = _fetch_user_entry(service_conn, user_dn)
        if user_attrs is None:
            flash("User entry not found.", "error")
            return redirect(url_for("logout"))

        if request.method == "POST":
            if _is_protected(user_attrs):
                flash("User has protectedAccount: TRUE. Cannot edit!", "error")
            else:
                changes, passwords_mismatch = _build_profile_changes(request.form)
                if passwords_mismatch:
                    flash("Passwords do not match", "error")
                # A single modify request is applied atomically by the server,
                # and its result is checked instead of assuming success.
                if service_conn.modify(user_dn, changes):
                    flash("Profile updated successfully", "success")
                else:
                    flash(f"Profile update failed: {service_conn.result.get('description')}", "error")
            # Re-read so the page shows what the directory now holds.
            refreshed = _fetch_user_entry(service_conn, user_dn)
            if refreshed is not None:
                user_attrs = refreshed

        # Escape the DN for use inside a search filter ("(", ")", "*", "\\").
        member_filter = f"(member={escape_filter_chars(user_dn)})"
        service_conn.search(GROUP_DN, member_filter, attributes=["cn"])
        groups = [entry.cn.value for entry in service_conn.entries]
    finally:
        service_conn.unbind()
    return render_template("profile.html", user_attrs=user_attrs, groups=groups)
if __name__ == "__main__":
    # The Werkzeug debugger lets whoever reaches the page execute arbitrary
    # Python, so it is opt-in via FLASK_DEBUG instead of always on.
    debug_enabled = os.getenv("FLASK_DEBUG", "").lower() in ("1", "true", "yes")
    app.run(debug=debug_enabled, host="127.0.0.1", port=5000)