# File info (from the original listing): 145 lines, 5.3 KiB, Python.
import base64
import hashlib
import os
import secrets
from datetime import datetime, timedelta, timezone
from functools import wraps

import jwt
from dotenv import load_dotenv
from flask import Flask, abort, flash, redirect, render_template, request, url_for
from flask_login import LoginManager, UserMixin, login_required, login_user, logout_user
from ldap3 import MODIFY_REPLACE, Connection, Server
from ldap3.core.exceptions import LDAPException
from ldap3.utils.conv import escape_filter_chars


# Flask setup

# Merge .env into the environment BEFORE reading configuration from it;
# otherwise a FLASK_SECRET_KEY defined only in .env would be ignored.
# load_dotenv() does not override variables that are already set, so calling
# it again later is harmless.
load_dotenv()

app = Flask(__name__)

# The secret key signs the Flask session cookie that Flask-Login trusts, so it
# must never fall back to a publicly known constant. Without configuration a
# random per-process key is used: sessions then do not survive a restart and
# are not shared between worker processes -- set FLASK_SECRET_KEY for real use.
app.secret_key = os.getenv("FLASK_SECRET_KEY") or secrets.token_hex(32)

# Keep the session cookie off cross-site POSTs (CSRF hardening for the
# state-changing /profile form). Flask's default for this setting is None.
app.config["SESSION_COOKIE_SAMESITE"] = "Lax"

# Flask-Login setup
login_manager = LoginManager()
login_manager.init_app(app)
# Send unauthenticated visitors of @login_required views to the login page
# instead of answering with a bare 401.
login_manager.login_view = "login"


# Runtime configuration. load_dotenv() merges a .env file into os.environ
# (without overriding variables that are already set) before the reads below.
load_dotenv()

# LDAP server address handed to ldap3.Server().
LDAP_SERVER = os.environ.get("LDAP_SERVER")
# Base DN under which user entries live as uid=<username>,USER_DN.
USER_DN = os.environ.get("USER_DN")
# Search base for the group-membership lookup on the profile page.
GROUP_DN = os.environ.get("GROUP_DN")
# Service account used for the profile page's LDAP reads and writes.
LDAP_SERVICE_USER = os.environ.get("LDAP_SERVICE_USER")
LDAP_SERVICE_PASSWORD = os.environ.get("LDAP_SERVICE_PASSWORD")

# Session-token settings: signing key, algorithm and lifetime (one hour).
JWT_SECRET = os.environ.get("JWT_SECRET")
JWT_ALGORITHM = "HS256"
JWT_EXPIRATION_SECONDS = 60 * 60


# User model
class User(UserMixin):
    """Flask-Login user identified by its LDAP ``uid``.

    ``id`` is the username Flask-Login keeps in the session; ``dn`` is the
    distinguished name of the user's LDAP entry.
    """

    def __init__(self, id, dn):
        # ``id`` shadows the builtin but is part of the public constructor
        # signature; UserMixin.get_id() reads ``self.id``.
        self.id, self.dn = id, dn


@login_manager.user_loader
def load_user(user_id):
    """Rebuild the session user from the id Flask-Login stored at login.

    No LDAP lookup is performed; the DN is derived from ``user_id`` and
    USER_DN.
    """
    dn = "uid={},{}".format(user_id, USER_DN)
    return User(user_id, dn)


# JWT utility functions
def generate_jwt(payload):
    """Return a signed JWT carrying ``payload`` plus an ``exp`` claim.

    The token expires JWT_EXPIRATION_SECONDS from now and is signed with
    JWT_SECRET using JWT_ALGORITHM. The caller's ``payload`` dict is left
    unmodified.
    """
    # Timezone-aware "now" replaces the deprecated, naive datetime.utcnow().
    expiration = datetime.now(timezone.utc) + timedelta(seconds=JWT_EXPIRATION_SECONDS)
    # Build a new dict instead of payload.update(): mutating the argument would
    # leak the "exp" claim back into the caller's data.
    claims = {**payload, "exp": expiration}
    return jwt.encode(claims, JWT_SECRET, algorithm=JWT_ALGORITHM)


def decode_jwt(token):
    """Verify ``token`` and return its claims, or ``None`` if it is unusable.

    PyJWT checks the signature (JWT_SECRET), the algorithm (JWT_ALGORITHM only)
    and expiry. Tokens without an ``exp`` claim are rejected too, so no token
    can stay valid forever.
    """
    try:
        return jwt.decode(
            token,
            JWT_SECRET,
            algorithms=[JWT_ALGORITHM],
            options={"require": ["exp"]},
        )
    except jwt.InvalidTokenError:
        # Base class of ExpiredSignatureError and MissingRequiredClaimError,
        # so a single handler covers bad, expired and incomplete tokens.
        return None


# Decorator for JWT authentication
def jwt_required(f):
    """Require a valid JWT in the ``jwt`` cookie before running view ``f``.

    Missing, invalid or expired tokens, and tokens without a non-empty string
    ``username`` claim, redirect to the login page with a flashed error.
    On success the token's username is exposed to the view as
    ``request.user``.
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        token = request.cookies.get("jwt")
        if not token:
            flash("You must log in to access this page.", "error")
            return redirect(url_for("login"))

        payload = decode_jwt(token)
        username = payload.get("username") if payload else None
        # A token without a usable username would make views build DNs such
        # as "uid=None,..."; treat it like any other invalid token.
        if not isinstance(username, str) or not username:
            flash("Session expired or invalid. Please log in again.", "error")
            return redirect(url_for("login"))

        request.user = username
        return f(*args, **kwargs)

    return decorated_function


# Routes

# Characters that are special in LDAP DN values (RFC 4514) or search-filter
# values (RFC 4515). Usernames are interpolated into both, so such usernames
# are refused outright ('#' is rejected anywhere, conservatively).
_LDAP_SPECIAL_CHARS = frozenset(',+"\\<>;=#*()\x00')


def _is_safe_username(username):
    """Return True if ``username`` can be used verbatim as a ``uid`` RDN value."""
    return (
        bool(username)
        and username == username.strip()
        and not _LDAP_SPECIAL_CHARS.intersection(username)
    )


def _ldap_credentials_valid(user_dn, password):
    """Return the result of a simple bind as ``user_dn``.

    The connection is only used to verify the password and is always
    released. Transport problems propagate as ldap3 ``LDAPException``.
    """
    conn = Connection(Server(LDAP_SERVER), user=user_dn, password=password)
    try:
        return conn.bind()
    finally:
        conn.unbind()


@app.route("/", methods=["GET", "POST"])
def login():
    """Show the login form (GET) or authenticate against LDAP (POST).

    On a successful bind as ``uid=<username>,USER_DN`` the user is logged in
    with Flask-Login, a JWT is stored in the ``jwt`` cookie and the browser is
    redirected to the profile page. Otherwise the form is shown again with a
    flashed error.
    """
    if request.method == "POST":
        username = request.form.get("username", "")
        password = request.form.get("password", "")

        # An empty password may be treated by the server as an
        # "unauthenticated bind" (RFC 4513 section 5.1.2). A username with DN
        # metacharacters could address a different entry.
        if not password or not _is_safe_username(username):
            flash("Invalid credentials", "error")
            return render_template("login.html")

        user_dn = f"uid={username},{USER_DN}"
        try:
            authenticated = _ldap_credentials_valid(user_dn, password)
        except LDAPException:
            flash("Authentication service unavailable. Please try again later.", "error")
            return render_template("login.html")

        if authenticated:
            login_user(User(username, user_dn))
            token = generate_jwt({"username": username})
            response = redirect(url_for("profile"))
            # SameSite=Lax keeps the token off cross-site POSTs (CSRF on /profile).
            response.set_cookie("jwt", token, httponly=True, secure=True, samesite="Lax")
            return response

        flash("Invalid credentials", "error")

    return render_template("login.html")


@app.route("/logout")
@login_required
def logout():
    """Log the user out of Flask-Login, drop the ``jwt`` cookie, go to login."""
    logout_user()
    login_page = url_for("login")
    response_obj = redirect(login_page)
    response_obj.delete_cookie("jwt")
    return response_obj


def _fetch_user_entry(conn, user_dn):
    """Return the ``klUser`` entry at ``user_dn`` with all user attributes.

    Aborts the request with 404 if ``conn`` finds no such entry, instead of
    failing with an IndexError.
    """
    conn.search(user_dn, "(objectClass=klUser)", attributes=["*"])
    if not conn.entries:
        abort(404)
    return conn.entries[0]


def _is_protected_account(entry):
    """Return True if ``entry`` has ``protectedAccount`` set to TRUE.

    The value may arrive as Python ``True`` or as the raw string "TRUE",
    depending on whether ldap3 knows the attribute's syntax from the server
    schema. Both are accepted so the check fails closed.
    """
    if "protectedAccount" not in entry:
        return False
    return str(entry["protectedAccount"].value).strip().upper() == "TRUE"


def _parse_ssh_keys(text):
    """Return the stripped, non-empty, non-comment lines of ``text``, de-duplicated.

    dict.fromkeys keeps first-occurrence order, so the stored order is
    deterministic (unlike de-duplicating through a set).
    """
    lines = (line.strip() for line in text.split("\n"))
    return list(dict.fromkeys(line for line in lines if line and not line.startswith("#")))


def _ssha_hash(password):
    """Return an ``{SSHA}`` userPassword value: base64(SHA-1(password + salt) + salt)."""
    # A random salt per password replaces the previous unsalted {SHA} scheme.
    salt = os.urandom(8)
    digest = hashlib.sha1(password.encode() + salt).digest()
    return "{SSHA}" + base64.b64encode(digest + salt).decode()


def _apply_profile_form(conn, user_dn):
    """Write the submitted profile form to ``user_dn`` and flash the outcome.

    The password fields are validated before anything is written: if both are
    filled in but differ, nothing is changed. Each attribute is then replaced
    with its own modify operation. Attributes the server rejects are reported
    instead of a success message. A missing required form field raises
    Flask's BadRequestKeyError (400).
    """
    form = request.form
    new_password = form.get("password", "")
    password_repeat = form.get("password_repeat", "")
    change_password = bool(new_password.strip() and password_repeat.strip())
    if change_password and new_password != password_repeat:
        flash("Passwords do not match", "error")
        return

    changes = {key: [form[key]] for key in ("sn", "cn", "givenName", "sipPassword")}
    changes["sshPublicKey"] = _parse_ssh_keys(form.get("sshPublicKey", ""))
    if change_password:
        changes["userPassword"] = [_ssha_hash(new_password)]

    failed = []
    for attribute, values in changes.items():
        # modify() returns False when the server rejects the change.
        if not conn.modify(user_dn, {attribute: [(MODIFY_REPLACE, values)]}):
            failed.append(attribute)

    if failed:
        flash("Could not update: " + ", ".join(failed), "error")
    else:
        flash("Profile updated successfully", "success")


@app.route("/profile", methods=["GET", "POST"])
@login_required
@jwt_required
def profile():
    """Show (GET) or update (POST) the logged-in user's LDAP profile.

    All LDAP access uses the service account. The username is taken from the
    JWT via ``request.user``, which ``jwt_required`` sets.
    """
    username = request.user
    user_dn = f"uid={username},{USER_DN}"
    service_conn = Connection(
        Server(LDAP_SERVER),
        user=LDAP_SERVICE_USER,
        password=LDAP_SERVICE_PASSWORD,
        auto_bind=True,
    )
    try:
        user_attrs = _fetch_user_entry(service_conn, user_dn)

        if request.method == "POST":
            if _is_protected_account(user_attrs):
                flash("User has protectedAccount: TRUE. Cannot edit!", "error")
            else:
                _apply_profile_form(service_conn, user_dn)
            # Re-read so the page shows what is actually stored now.
            user_attrs = _fetch_user_entry(service_conn, user_dn)

        # user_dn contains user-derived text; escape it for the filter (RFC 4515).
        service_conn.search(
            GROUP_DN,
            f"(member={escape_filter_chars(user_dn)})",
            attributes=["cn"],
        )
        groups = [entry.cn.value for entry in service_conn.entries]
    finally:
        # auto_bind=True binds immediately and nothing else unbinds this
        # connection, so release it on every path (including abort()).
        service_conn.unbind()

    return render_template("profile.html", user_attrs=user_attrs, groups=groups)


if __name__ == "__main__":
    # The Werkzeug debugger allows arbitrary code execution from the browser,
    # so it is opt-in: set FLASK_DEBUG to 1/true/yes to enable it.
    debug_enabled = os.getenv("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes"}
    app.run(debug=debug_enabled, host="127.0.0.1", port=5000)