Python实战SSO系统
2026/3/20 · 大约 13 分钟
Python 实战 SSO 系统
项目概述
本章将使用 Python 从零实现一个完整的 SSO 系统,包括:
- SSO 认证服务器(IdP)
- 客户端应用(SP)
- 客户端 SDK
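项目架构
目录结构
以下目录结构根据各代码文件开头的路径注释整理(服务端与客户端均使用包内导入,需从项目根目录运行):
sso_server/              # SSO 认证服务器(IdP)
├── app.py
├── config.py
├── models/
│   └── user.py
├── services/
│   ├── auth_service.py
│   └── ticket_service.py
├── routes/
│   └── auth.py
└── templates/           # login.html、login_success.html、profile.html
sso_client/              # 客户端 SDK
├── client.py
└── middleware.py
demo_app/                # 示例应用(SP)
├── app.py
└── templates/           # index.html、profile.html、admin.html
requirements.txt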
SSO 服务端实现
配置文件
# sso_server/config.py
import os
from datetime import timedelta
class Config:
    """Settings for the SSO server, loaded via ``app.config.from_object``.

    Most values can be overridden by an environment variable of the same
    name. Plain integer durations are in seconds; the JWT lifetimes are
    ``timedelta`` objects.
    """

    # --- Flask ---
    # The fallback secret is only a placeholder: always set SECRET_KEY in production.
    SECRET_KEY = os.environ.get("SECRET_KEY", "your-super-secret-key-change-in-production")
    DEBUG = os.environ.get("DEBUG", "False").lower() == "true"

    # --- Database ---
    SQLALCHEMY_DATABASE_URI = os.environ.get("DATABASE_URI", "sqlite:///sso.db")
    SQLALCHEMY_TRACK_MODIFICATIONS = False

    # --- Redis (ticket and login-lockout storage) ---
    REDIS_HOST = os.environ.get("REDIS_HOST", "localhost")
    REDIS_PORT = int(os.environ.get("REDIS_PORT", "6379"))
    REDIS_DB = int(os.environ.get("REDIS_DB", "0"))
    REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD")

    # --- SSO tickets and sessions (seconds) ---
    SSO_SERVER_URL = os.environ.get("SSO_SERVER_URL", "http://localhost:5000")
    TGT_EXPIRE = int(os.environ.get("TGT_EXPIRE", str(8 * 60 * 60)))       # 8 hours
    ST_EXPIRE = int(os.environ.get("ST_EXPIRE", str(5 * 60)))              # 5 minutes
    SESSION_EXPIRE = int(os.environ.get("SESSION_EXPIRE", str(30 * 60)))   # 30 minutes

    # --- Security policy ---
    PASSWORD_MIN_LENGTH = 8
    MAX_LOGIN_ATTEMPTS = 5
    LOCKOUT_DURATION = 30 * 60  # 30 minutes

    # --- JWT (the signing key falls back to SECRET_KEY) ---
    JWT_SECRET_KEY = os.environ.get("JWT_SECRET_KEY", SECRET_KEY)
    JWT_ACCESS_TOKEN_EXPIRES = timedelta(minutes=15)
    JWT_REFRESH_TOKEN_EXPIRES = timedelta(days=7)
数据模型
# sso_server/models/user.py
from datetime import datetime
from flask_sqlalchemy import SQLAlchemy
from werkzeug.security import generate_password_hash, check_password_hash
# Shared Flask-SQLAlchemy handle. The models below subclass db.Model, and
# create_app() binds it to the Flask application with db.init_app(app).
db = SQLAlchemy()
class User(db.Model):
    """A user account, stored in the ``users`` table."""

    __tablename__ = 'users'

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(80), unique=True, nullable=False, index=True)
    email = db.Column(db.String(120), unique=True, nullable=False)
    password_hash = db.Column(db.String(256), nullable=False)
    display_name = db.Column(db.String(100))
    is_active = db.Column(db.Boolean, default=True)
    is_admin = db.Column(db.Boolean, default=False)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    updated_at = db.Column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    last_login_at = db.Column(db.DateTime)
    last_login_ip = db.Column(db.String(45))  # 45 chars: room for a textual IPv6 address

    # Links to UserRole rows. lazy='dynamic' makes ``self.roles`` a query
    # object rather than a loaded list.
    roles = db.relationship('UserRole', backref='user', lazy='dynamic')

    def set_password(self, password: str):
        """Hash *password* with Werkzeug and keep only the hash."""
        hashed = generate_password_hash(password)
        self.password_hash = hashed

    def check_password(self, password: str) -> bool:
        """Return True when *password* matches the stored hash."""
        matches = check_password_hash(self.password_hash, password)
        return matches

    def to_dict(self) -> dict:
        """Serialize the public profile fields.

        The password hash is never included. ``roles`` is the list of the
        linked ``UserRole.role_name`` values.
        """
        return dict(
            id=self.id,
            username=self.username,
            email=self.email,
            display_name=self.display_name,
            is_active=self.is_active,
            roles=[link.role_name for link in self.roles],
        )
class Role(db.Model):
    """A named role, stored in the ``roles`` table.

    Note: UserRole links users to roles by ``role_name`` text; it does not
    use a foreign key into this table.
    """

    __tablename__ = 'roles'

    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(length=50), nullable=False, unique=True)
    description = db.Column(db.String(length=200))
class UserRole(db.Model):
    """Association row granting the role ``role_name`` to user ``user_id``."""

    __tablename__ = 'user_roles'

    id = db.Column(db.Integer, primary_key=True)
    # Owning user; User.roles exposes these rows through the 'user' backref.
    user_id = db.Column(db.Integer, db.ForeignKey('users.id'), nullable=False)
    # Role name stored as plain text (no foreign key to roles.name).
    role_name = db.Column(db.String(length=50), nullable=False)
class ServiceProvider(db.Model):
    """A registered client application (service provider)."""

    __tablename__ = 'service_providers'

    id = db.Column(db.Integer, primary_key=True)
    client_id = db.Column(db.String(64), unique=True, nullable=False, index=True)
    client_secret = db.Column(db.String(128), nullable=False)
    name = db.Column(db.String(100), nullable=False)
    description = db.Column(db.String(500))
    redirect_uris = db.Column(db.Text)  # JSON-encoded array of allowed URIs
    logout_uri = db.Column(db.String(500))
    is_active = db.Column(db.Boolean, default=True)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)

    def get_redirect_uris(self) -> list:
        """Decode the ``redirect_uris`` JSON column (an empty list when unset)."""
        import json

        raw = self.redirect_uris or '[]'
        return json.loads(raw)

    def is_valid_redirect_uri(self, uri: str) -> bool:
        """Return True when *uri* exactly equals one of the registered URIs."""
        allowed = self.get_redirect_uris()
        return uri in allowed
票据服务
# sso_server/services/ticket_service.py
import secrets
import json
from datetime import datetime, timedelta
from typing import Optional
import redis
class TicketService:
    """Ticket service: manages TGTs (Ticket Granting Tickets) and STs
    (Service Tickets) stored in Redis.

    Redis keys written by this class:

    * ``tgt:<TGT id>``: JSON ``{user_id, user_data, created_at, services}``
    * ``user_tgt:<user_id>``: id of the TGT created by the user's latest login
    * ``st:<ST id>``: JSON ``{tgt_id, user_id, user_data, service_url,
      created_at, used}``
    """

    # Compare-and-delete: remove KEYS[1] only while it still holds ARGV[1].
    # Running it as one Lua script makes the check and the delete atomic.
    _DELETE_IF_EQUALS = """
    if redis.call('GET', KEYS[1]) == ARGV[1] then
        return redis.call('DEL', KEYS[1])
    end
    return 0
    """

    def __init__(self, redis_client: redis.Redis, config: dict):
        """
        Args:
            redis_client: Redis connection used for all ticket keys.
            config: mapping read for ``TGT_EXPIRE`` and ``ST_EXPIRE``
                (lifetimes in seconds, defaults 28800 and 300).
        """
        self.redis = redis_client
        self.tgt_prefix = "tgt:"
        self.st_prefix = "st:"
        self.tgt_expire = config.get('TGT_EXPIRE', 28800)
        self.st_expire = config.get('ST_EXPIRE', 300)

    # ==================== TGT management ====================

    def create_tgt(self, user_id: int, user_data: dict) -> str:
        """Create a Ticket Granting Ticket for an authenticated user.

        Args:
            user_id: user ID.
            user_data: user attributes; copied into every ST issued from
                this TGT.

        Returns:
            The new TGT id (``"TGT-"`` followed by a random URL-safe token).
        """
        tgt_id = f"TGT-{secrets.token_urlsafe(32)}"
        key = f"{self.tgt_prefix}{tgt_id}"
        tgt_data = {
            "user_id": user_id,
            "user_data": user_data,
            "created_at": datetime.utcnow().isoformat(),
            "services": []  # services logged in through this TGT (for single logout)
        }
        self.redis.setex(key, self.tgt_expire, json.dumps(tgt_data))
        # Point the user at this TGT. This overwrites the pointer left by any
        # earlier login of the same user.
        user_tgt_key = f"user_tgt:{user_id}"
        self.redis.setex(user_tgt_key, self.tgt_expire, tgt_id)
        return tgt_id

    def get_tgt(self, tgt_id: str) -> Optional[dict]:
        """Return the decoded TGT data, or None if it is missing or expired."""
        key = f"{self.tgt_prefix}{tgt_id}"
        data = self.redis.get(key)
        if data:
            return json.loads(data)
        return None

    def validate_tgt(self, tgt_id: str) -> Optional[dict]:
        """Return the TGT data and extend its lifetime (sliding expiration).

        Returns None when the TGT does not exist.
        """
        tgt = self.get_tgt(tgt_id)
        if not tgt:
            return None
        # Sliding expiration: each successful validation restarts the TTL.
        key = f"{self.tgt_prefix}{tgt_id}"
        self.redis.expire(key, self.tgt_expire)
        return tgt

    def destroy_tgt(self, tgt_id: str) -> bool:
        """Delete a TGT (logout).

        Returns:
            True if the TGT existed, False otherwise.
        """
        tgt = self.get_tgt(tgt_id)
        if not tgt:
            return False
        key = f"{self.tgt_prefix}{tgt_id}"
        self.redis.delete(key)
        # Drop the user's TGT pointer only if it still refers to *this* TGT.
        # A later login (e.g. in another browser) overwrote the pointer in
        # create_tgt, and that newer session must keep it.
        user_tgt_key = f"user_tgt:{tgt['user_id']}"
        self.redis.eval(self._DELETE_IF_EQUALS, 1, user_tgt_key, tgt_id)
        return True

    def add_service_to_tgt(self, tgt_id: str, service_url: str, session_id: str):
        """Record a service that logged in through this TGT (for single logout).

        This is a read-modify-write without locking, so concurrent calls for
        the same TGT can lose an entry. The TGT's remaining TTL is preserved.
        """
        tgt = self.get_tgt(tgt_id)
        if not tgt:
            return
        services = tgt.get('services', [])
        services.append({
            'service_url': service_url,
            'session_id': session_id,
            'login_time': datetime.utcnow().isoformat()
        })
        tgt['services'] = services
        key = f"{self.tgt_prefix}{tgt_id}"
        ttl = self.redis.ttl(key)
        if ttl > 0:
            self.redis.setex(key, ttl, json.dumps(tgt))

    def get_tgt_services(self, tgt_id: str) -> list:
        """Return the services recorded on the TGT (empty if the TGT is gone)."""
        tgt = self.get_tgt(tgt_id)
        if tgt:
            return tgt.get('services', [])
        return []

    # ==================== ST management ====================

    def create_st(self, tgt_id: str, service_url: str) -> Optional[str]:
        """Issue a Service Ticket bound to *service_url*.

        Args:
            tgt_id: TGT id; it is validated, and its TTL refreshed, first.
            service_url: the exact service URL the ticket will be redeemed for.

        Returns:
            The ST id, or None if the TGT is invalid.
        """
        tgt = self.validate_tgt(tgt_id)
        if not tgt:
            return None

        st_id = f"ST-{secrets.token_urlsafe(32)}"
        key = f"{self.st_prefix}{st_id}"
        st_data = {
            "tgt_id": tgt_id,
            "user_id": tgt["user_id"],
            "user_data": tgt["user_data"],
            "service_url": service_url,
            "created_at": datetime.utcnow().isoformat(),
            "used": False
        }
        self.redis.setex(key, self.st_expire, json.dumps(st_data))
        return st_id

    def validate_st(self, st_id: str, service_url: str) -> Optional[dict]:
        """Validate and consume a Service Ticket (single use).

        Args:
            st_id: ST id.
            service_url: must equal the URL the ST was issued for.

        Returns:
            ``{"user_id", "user_data", "tgt_id"}``, or None if the ticket is
            unknown, expired, already used, issued for another service, or
            its TGT has been destroyed since it was issued.
        """
        key = f"{self.st_prefix}{st_id}"
        # The Lua script makes check-and-mark-used atomic, so two concurrent
        # validations cannot both succeed. The used copy is kept for 60s.
        # cjson may re-encode empty arrays as {} in that stored copy; that is
        # harmless because the script returns the original payload.
        lua_script = """
        local data = redis.call('GET', KEYS[1])
        if not data then
            return nil
        end
        local st = cjson.decode(data)
        if st.used then
            return nil
        end
        -- the ticket is bound to one exact service URL
        if st.service_url ~= ARGV[1] then
            return nil
        end
        -- mark as used
        st.used = true
        redis.call('SETEX', KEYS[1], 60, cjson.encode(st))
        return data
        """
        result = self.redis.eval(lua_script, 1, key, service_url)
        if not result:
            return None
        st_data = json.loads(result)
        # A logout between issuing and redeeming the ST destroys the TGT;
        # the outstanding ticket must not log the user back in.
        if not self.redis.exists(f"{self.tgt_prefix}{st_data['tgt_id']}"):
            return None
        return {
            "user_id": st_data["user_id"],
            "user_data": st_data["user_data"],
            "tgt_id": st_data["tgt_id"]
        }
认证服务
# sso_server/services/auth_service.py
from typing import Optional, Tuple
from datetime import datetime
import redis
from ..models.user import db, User
class AuthService:
    """Authentication service: password checks, brute-force lockout, user creation.

    Failed-attempt counters and lockout flags are kept in Redis under
    ``login_attempts:<username>`` and ``lockout:<username>``.
    """

    def __init__(self, redis_client: redis.Redis, config: dict):
        """
        Args:
            redis_client: Redis connection for attempt counters and lockout flags.
            config: mapping read for MAX_LOGIN_ATTEMPTS, LOCKOUT_DURATION
                (seconds) and PASSWORD_MIN_LENGTH.
        """
        self.redis = redis_client
        self.config = config
        self.max_attempts = config.get('MAX_LOGIN_ATTEMPTS', 5)
        self.lockout_duration = config.get('LOCKOUT_DURATION', 1800)
        # Config.PASSWORD_MIN_LENGTH was defined but never enforced anywhere.
        self.password_min_length = config.get('PASSWORD_MIN_LENGTH', 8)

    def authenticate(
        self,
        username: str,
        password: str,
        ip_address: str
    ) -> Tuple[Optional[User], str]:
        """Check a username/password pair.

        Args:
            username: login name.
            password: plaintext password.
            ip_address: client IP, stored as ``last_login_ip`` on success.

        Returns:
            ``(user, "")`` on success, or ``(None, error_message)``.
        """
        if self._is_locked(username):
            return None, "账户已锁定,请稍后重试"

        user = User.query.filter_by(username=username).first()
        # Verify the password before revealing anything about the account.
        # Checking is_active first told anyone, without credentials, that
        # the username exists and is disabled.
        # NOTE(review): unknown usernames skip the password hash, so response
        # timing can still reveal whether a username exists.
        if user is None or not user.check_password(password):
            self._record_failed_attempt(username, ip_address)
            return None, "用户名或密码错误"

        if not user.is_active:
            return None, "账户已被禁用"

        # Success: reset the failure counter and record the login.
        self._clear_failed_attempts(username)
        user.last_login_at = datetime.utcnow()
        user.last_login_ip = ip_address
        db.session.commit()
        return user, ""

    def _is_locked(self, username: str) -> bool:
        """Return True while a lockout flag exists for *username*."""
        # redis.exists() returns a count (int), not a bool.
        return bool(self.redis.exists(f"lockout:{username}"))

    def _record_failed_attempt(self, username: str, ip_address: str):
        """Count a failed login; lock the account once the threshold is reached.

        ``ip_address`` is currently not used.
        """
        key = f"login_attempts:{username}"
        attempts = self.redis.incr(key)
        self.redis.expire(key, 3600)  # counter expires 1 hour after the last failure

        if attempts >= self.max_attempts:
            self.redis.setex(f"lockout:{username}", self.lockout_duration, "1")
            # Start a fresh window once the lockout ends. The counter (1h TTL)
            # outlives the lockout (30 min by default), so without this reset
            # the first wrong password after unlocking would re-lock at once.
            self.redis.delete(key)

    def _clear_failed_attempts(self, username: str):
        """Reset the failed-attempt counter for *username*."""
        self.redis.delete(f"login_attempts:{username}")

    def get_user_by_id(self, user_id: int) -> Optional[User]:
        """Return the user with primary key *user_id*, or None."""
        # Query.get() is a legacy API in SQLAlchemy 2.x; Session.get() replaces it.
        return db.session.get(User, user_id)

    def create_user(
        self,
        username: str,
        email: str,
        password: str,
        display_name: str = None
    ) -> Tuple[Optional[User], str]:
        """Create and persist a user.

        Returns:
            ``(user, "")`` on success, or ``(None, error_message)`` when the
            password is too short or the username/email is already taken.
        """
        if len(password or '') < self.password_min_length:
            return None, f"密码长度不能少于 {self.password_min_length} 位"
        if User.query.filter_by(username=username).first():
            return None, "用户名已存在"
        if User.query.filter_by(email=email).first():
            return None, "邮箱已存在"
        # NOTE(review): these existence checks race with concurrent signups;
        # the unique constraints will still raise on commit in that case.

        user = User(
            username=username,
            email=email,
            display_name=display_name or username
        )
        user.set_password(password)
        db.session.add(user)
        db.session.commit()
        return user, ""
路由实现
# sso_server/routes/auth.py
from flask import Blueprint, request, redirect, render_template, make_response, jsonify, url_for
from urllib.parse import urlencode, urlparse
from ..services.auth_service import AuthService
from ..services.ticket_service import TicketService
from ..models.user import ServiceProvider
# Blueprint holding the SSO endpoints; create_app() registers it under the
# '/sso' URL prefix.
auth_bp = Blueprint('auth', __name__)
def get_services():
    """Return the ``(AuthService, TicketService)`` pair for the current app.

    create_app() stores both instances in ``app.config``. A real
    dependency-injection mechanism (or ``app.extensions``) would be cleaner.
    """
    from flask import current_app

    cfg = current_app.config
    return cfg['auth_service'], cfg['ticket_service']
def _append_ticket(service_url: str, st: str) -> str:
    """Return *service_url* with ``ticket=<st>`` added to its query string.

    The ticket is placed before any ``#fragment``. Plain string
    concatenation would put it inside the fragment
    (``http://app/#top?ticket=...``), which browsers never send to the server.
    """
    from urllib.parse import urlsplit, urlunsplit

    parts = urlsplit(service_url)
    query = f"{parts.query}&ticket={st}" if parts.query else f"ticket={st}"
    return urlunsplit(parts._replace(query=query))


@auth_bp.route('/login', methods=['GET', 'POST'])
def login():
    """SSO login endpoint.

    GET shows the login form and POST checks the credentials. If the browser
    already has a valid TGC cookie, no credentials are requested and a
    Service Ticket is issued immediately.

    Query args:
        service: absolute http(s) URL of the application to return to. On
            success a ``ticket`` parameter is added to it.
    """
    from flask import current_app

    auth_service, ticket_service = get_services()

    service_url = request.args.get('service', '')
    if service_url:
        parsed = urlparse(service_url)
        # Only absolute http(s) URLs may receive a ticket.
        # TODO: also require a registered ServiceProvider redirect URI
        # (ServiceProvider.is_valid_redirect_uri). Without that check, any
        # site can send a logged-in user here and obtain an ST, and therefore
        # the user's attributes via /validate.
        if parsed.scheme not in ('http', 'https') or not parsed.netloc:
            return render_template(
                'login.html',
                service='',
                error='无效的服务地址'
            ), 400

    # Already logged in (TGC cookie holds a live TGT)?
    tgc = request.cookies.get('TGC')
    if tgc:
        tgt = ticket_service.validate_tgt(tgc)
        if tgt:
            if service_url:
                st = ticket_service.create_st(tgc, service_url)
                if st:
                    return redirect(_append_ticket(service_url, st))
            # No target service: show the "already logged in" page.
            return render_template(
                'login_success.html',
                user=tgt['user_data']
            )

    if request.method == 'GET':
        return render_template(
            'login.html',
            service=service_url,
            error=None
        )

    # POST: check credentials.
    username = request.form.get('username', '')
    password = request.form.get('password', '')
    remember = request.form.get('remember', False)

    user, error = auth_service.authenticate(
        username,
        password,
        request.remote_addr
    )
    if not user:
        return render_template(
            'login.html',
            service=service_url,
            error=error
        ), 401

    tgt_id = ticket_service.create_tgt(user.id, user.to_dict())

    st = ticket_service.create_st(tgt_id, service_url) if service_url else None
    if st:
        response = make_response(redirect(_append_ticket(service_url, st)))
    else:
        response = make_response(redirect(url_for('auth.profile')))

    # Non-"remember" cookies follow the server-side TGT lifetime
    # (TGT_EXPIRE) instead of a duplicated literal.
    # NOTE(review): with "remember" the 7-day cookie still outlives the TGT,
    # which expires after TGT_EXPIRE seconds without validation. A real
    # "remember me" needs a longer TGT as well.
    tgt_lifetime = current_app.config.get('TGT_EXPIRE', 28800)
    max_age = 604800 if remember else tgt_lifetime
    response.set_cookie(
        'TGC',
        tgt_id,
        max_age=max_age,
        httponly=True,
        secure=request.is_secure,
        samesite='Lax'
    )
    return response
@auth_bp.route('/logout', methods=['GET', 'POST'])
def logout():
    """SSO logout: notify every service logged in through the TGT, destroy
    the TGT, clear the TGC cookie, then redirect.

    Query args:
        service: optional absolute http(s) URL to redirect to afterwards. Any
            other value is ignored to avoid an open redirect, and the user is
            sent to the login page instead.
    """
    from flask import current_app

    _, ticket_service = get_services()

    service_url = request.args.get('service', '')
    if service_url:
        parsed = urlparse(service_url)
        if parsed.scheme not in ('http', 'https') or not parsed.netloc:
            service_url = ''

    tgc = request.cookies.get('TGC')
    if tgc:
        # Back-channel notification of each recorded service.
        # NOTE(review): this is synchronous; a slow service delays the logout.
        for service in ticket_service.get_tgt_services(tgc):
            try:
                notify_service_logout(service['service_url'], service['session_id'])
            except Exception:
                current_app.logger.exception(
                    "登出通知失败: %s", service.get('service_url')
                )
        ticket_service.destroy_tgt(tgc)

    if service_url:
        response = make_response(redirect(service_url))
    else:
        response = make_response(redirect(url_for('auth.login')))

    response.delete_cookie('TGC')
    return response
@auth_bp.route('/validate', methods=['GET'])
def validate_ticket():
    """Validate (and consume) a Service Ticket; reply with CAS-style XML.

    Query args:
        ticket: the ST to check.
        service: the service URL the ST must have been issued for.
    """
    _auth, tickets = get_services()

    ticket = request.args.get('ticket', '')
    service = request.args.get('service', '')
    if not (ticket and service):
        return make_cas_failure_response('INVALID_REQUEST', '缺少必要参数')

    outcome = tickets.validate_st(ticket, service)
    if not outcome:
        return make_cas_failure_response('INVALID_TICKET', '票据无效或已过期')

    # Record the service on the TGT for single logout; the ST doubles as
    # the service's session identifier.
    tgt_id = outcome.get('tgt_id')
    if tgt_id:
        tickets.add_service_to_tgt(tgt_id, service, ticket)

    return make_cas_success_response(outcome['user_data'])
@auth_bp.route('/profile')
def profile():
    """SSO user centre: the current user and the services logged in via the TGT.

    Redirects to the login page when there is no TGC cookie or its TGT is no
    longer valid.
    """
    _auth, tickets = get_services()

    tgc = request.cookies.get('TGC')
    tgt = tickets.validate_tgt(tgc) if tgc else None
    if not tgt:
        return redirect(url_for('auth.login'))

    return render_template(
        'profile.html',
        user=tgt['user_data'],
        services=tickets.get_tgt_services(tgc),
    )
def make_cas_success_response(user_data: dict):
    """Build a CAS-style ``authenticationSuccess`` XML response.

    Every value is XML-escaped. ``username``, ``email`` and ``display_name``
    come from user data, and interpolating them raw let a value such as
    ``A & B`` break the document (the client then fails to parse it), or a
    value containing markup inject extra elements.

    Args:
        user_data: the TGT's user dict (produced by ``User.to_dict()``):
            id, username, email, display_name, roles.

    Returns:
        A Flask response with ``Content-Type: application/xml``.
    """
    from xml.sax.saxutils import escape

    def text(key):
        value = user_data.get(key)
        return escape('' if value is None else str(value))

    roles = escape(','.join(str(role) for role in user_data.get('roles') or []))
    xml = f'''<?xml version="1.0" encoding="UTF-8"?>
<cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas">
    <cas:authenticationSuccess>
        <cas:user>{text("username")}</cas:user>
        <cas:attributes>
            <cas:user_id>{text("id")}</cas:user_id>
            <cas:email>{text("email")}</cas:email>
            <cas:display_name>{text("display_name")}</cas:display_name>
            <cas:roles>{roles}</cas:roles>
        </cas:attributes>
    </cas:authenticationSuccess>
</cas:serviceResponse>'''
    response = make_response(xml)
    response.headers['Content-Type'] = 'application/xml'
    return response
def make_cas_failure_response(code: str, message: str):
    """Build a CAS-style ``authenticationFailure`` XML response.

    *code* goes into an XML attribute (``quoteattr``) and *message* into
    element text (``escape``). Both are escaped, like in the success
    response, so a caller can never produce malformed XML.

    Returns:
        A Flask response with ``Content-Type: application/xml``.
    """
    from xml.sax.saxutils import escape, quoteattr

    xml = f'''<?xml version="1.0" encoding="UTF-8"?>
<cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas">
    <cas:authenticationFailure code={quoteattr(code)}>
        {escape(message)}
    </cas:authenticationFailure>
</cas:serviceResponse>'''
    response = make_response(xml)
    response.headers['Content-Type'] = 'application/xml'
    return response
def notify_service_logout(service_url: str, session_id: str):
    """Back-channel logout: POST ``{"session_id": ...}`` to the service.

    The recorded *service_url* is the full URL the ticket was issued for
    (e.g. ``http://app/sso/callback``). Appending ``/sso/logout-callback`` to
    it produced ``http://app/sso/callback/sso/logout-callback``. The
    callback is instead addressed at the service's origin, which is where
    SSOMiddleware registers ``/sso/logout-callback``.

    This is best-effort: network errors and non-2xx replies are logged and
    not raised.
    """
    import logging
    from urllib.parse import urlsplit

    import requests

    log = logging.getLogger(__name__)
    parts = urlsplit(service_url)
    if not parts.scheme or not parts.netloc:
        log.warning("登出通知跳过, 无效的服务地址: %s", service_url)
        return

    # TODO: prefer the registered ServiceProvider.logout_uri when available.
    callback_url = f"{parts.scheme}://{parts.netloc}/sso/logout-callback"
    try:
        response = requests.post(
            callback_url,
            json={"session_id": session_id},
            timeout=5
        )
        response.raise_for_status()
    except requests.RequestException as exc:
        log.warning("登出通知失败 %s: %s", callback_url, exc)
应用入口
# sso_server/app.py
from flask import Flask
import redis
from .config import Config
from .models.user import db
from .services.auth_service import AuthService
from .services.ticket_service import TicketService
from .routes.auth import auth_bp
def create_app(config_class=Config):
    """Application factory for the SSO server.

    Loads *config_class*, binds the database, creates a Redis client
    (``decode_responses=True``, so values come back as ``str``), stores the
    AuthService/TicketService instances in ``app.config``, registers the
    auth blueprint under ``/sso`` and creates any missing tables.
    """
    app = Flask(__name__)
    app.config.from_object(config_class)
    cfg = app.config

    db.init_app(app)

    # REDIS_HOST/PORT/DB/PASSWORD map onto redis.Redis(host=, port=, db=, password=).
    redis_settings = {
        name.lower(): cfg[f'REDIS_{name}']
        for name in ('HOST', 'PORT', 'DB', 'PASSWORD')
    }
    redis_client = redis.Redis(**redis_settings, decode_responses=True)

    # Service instances are read back by routes.auth.get_services().
    cfg['auth_service'] = AuthService(redis_client, cfg)
    cfg['ticket_service'] = TicketService(redis_client, cfg)

    app.register_blueprint(auth_bp, url_prefix='/sso')

    with app.app_context():
        db.create_all()

    return app
if __name__ == '__main__':
    # Run from the project root as a module (python -m sso_server.app). The
    # package-relative imports above fail under "python app.py".
    # Debug mode follows Config.DEBUG (env DEBUG=true) instead of being forced
    # on: the Werkzeug debugger console on 0.0.0.0 allows remote code execution.
    app = create_app()
    app.run(host='0.0.0.0', port=5000, debug=app.config.get('DEBUG', False))
SSO 客户端 SDK
客户端实现
# sso_client/client.py
import requests
import xml.etree.ElementTree as ET
from typing import Optional, Dict, Any
from urllib.parse import urlencode, urljoin
from dataclasses import dataclass
@dataclass
class SSOUser:
    """The SSO user as seen by a client application."""
    user_id: int
    username: str
    email: str
    display_name: str
    roles: list

    @classmethod
    def from_dict(cls, data: dict) -> 'SSOUser':
        """Build an SSOUser from parsed CAS attributes or from session data.

        ``roles`` may be a comma-separated string (the CAS XML form) or a
        list (what the Flask middleware stores in the session). The old
        implementation called ``.split`` on the list and crashed. Missing
        or ``None`` values, e.g. empty XML elements, become ``0`` / ``''`` /
        ``[]`` instead of raising.
        """
        raw_roles = data.get('roles') or []
        if isinstance(raw_roles, str):
            raw_roles = raw_roles.split(',')
        roles = [str(role).strip() for role in raw_roles if str(role).strip()]
        return cls(
            user_id=int(data.get('user_id') or 0),
            username=data.get('username') or '',
            email=data.get('email') or '',
            display_name=data.get('display_name') or '',
            roles=roles,
        )
class SSOClient:
    """Client for the SSO server's CAS-style HTTP endpoints.

    Builds login/logout redirect URLs and validates Service Tickets with
    ``GET {sso_server_url}/sso/validate``.
    """

    def __init__(
        self,
        sso_server_url: str,
        service_url: str,
        client_id: str = None,
        client_secret: str = None
    ):
        """
        Args:
            sso_server_url: SSO server base URL (a trailing '/' is removed).
            service_url: this application's base URL (a trailing '/' is
                removed). It is the default ``service`` and the base for
                relative return URLs.
            client_id: optional client ID (stored; not used by this class yet).
            client_secret: optional client secret (stored; not used yet).
        """
        self.sso_server_url = sso_server_url.rstrip('/')
        self.service_url = service_url.rstrip('/')
        self.client_id = client_id
        self.client_secret = client_secret

    def _absolute_url(self, url: str) -> str:
        """Resolve *url* against ``service_url``; absolute URLs pass through.

        The SSO server redirects the browser to ``service``, so a relative
        value such as ``'/'`` would otherwise land on the SSO server's host.
        """
        return urljoin(self.service_url + '/', url)

    def get_login_url(self, return_url: str = None) -> str:
        """Return the SSO login URL.

        Args:
            return_url: where to come back to after login (absolute, or
                relative to ``service_url``); defaults to ``service_url``.
        """
        service = self._absolute_url(return_url) if return_url else self.service_url
        return f"{self.sso_server_url}/sso/login?{urlencode({'service': service})}"

    def get_logout_url(self, return_url: str = None) -> str:
        """Return the SSO logout URL.

        Args:
            return_url: optional page to return to after logout (absolute, or
                relative to ``service_url``).
        """
        if not return_url:
            return f"{self.sso_server_url}/sso/logout"
        params = urlencode({'service': self._absolute_url(return_url)})
        return f"{self.sso_server_url}/sso/logout?{params}"

    def validate_ticket(self, ticket: str, service_url: str = None) -> Optional["SSOUser"]:
        """Validate a Service Ticket with the SSO server.

        Args:
            ticket: the ST received in the ``ticket`` query parameter.
            service_url: the exact service URL the ST was issued for
                (defaults to ``service_url``).

        Returns:
            SSOUser on success, None on any failure.
        """
        import logging

        service = service_url or self.service_url
        validate_url = f"{self.sso_server_url}/sso/validate"
        try:
            response = requests.get(
                validate_url,
                params={'ticket': ticket, 'service': service},
                timeout=10
            )
        except requests.RequestException as exc:
            logging.getLogger(__name__).warning("票据验证请求失败: %s", exc)
            return None
        if response.status_code != 200:
            return None
        return self._parse_cas_response(response.text)

    def _parse_cas_response(self, xml_text: str) -> Optional["SSOUser"]:
        """Parse a CAS ``serviceResponse`` document into an SSOUser.

        Returns None for failure responses or malformed XML. Empty elements
        (``text is None``) are read as ``''``.
        """
        import logging

        ns = {'cas': 'http://www.yale.edu/tp/cas'}
        try:
            root = ET.fromstring(xml_text)
        except ET.ParseError as exc:
            logging.getLogger(__name__).warning("XML 解析错误: %s", exc)
            return None

        success = root.find('.//cas:authenticationSuccess', ns)
        if success is None:
            return None

        user_elem = success.find('cas:user', ns)
        user_data = {
            'username': (user_elem.text or '') if user_elem is not None else ''
        }
        attrs = success.find('cas:attributes', ns)
        if attrs is not None:
            prefix = f"{{{ns['cas']}}}"
            for child in attrs:
                user_data[child.tag.replace(prefix, "")] = child.text or ''
        return SSOUser.from_dict(user_data)
Flask 中间件
# sso_client/middleware.py
from functools import wraps
from flask import request, redirect, session, g, current_app
from typing import Callable
from .client import SSOClient, SSOUser
class SSOMiddleware:
    """Flask integration for the SSO client.

    It registers the login/logout/callback routes and the back-channel
    logout endpoint. It also adds a ``before_request`` hook that loads
    ``g.current_user`` from the Flask session, redeems a ``?ticket=`` that
    the SSO server appended, and sends anonymous users to the SSO login page.
    """

    def __init__(
        self,
        app=None,
        sso_server_url: str = None,
        service_url: str = None,
        login_path: str = '/login',
        logout_path: str = '/logout',
        callback_path: str = '/sso/callback',
        excluded_paths: list = None
    ):
        """
        Args:
            app: Flask app to initialise immediately (optional).
            sso_server_url: SSO server base URL (falls back to SSO_SERVER_URL).
            service_url: this app's base URL (falls back to SERVICE_URL).
            login_path / logout_path / callback_path: local route paths.
            excluded_paths: extra path prefixes reachable without login.
        """
        self.sso_server_url = sso_server_url
        self.service_url = service_url
        self.login_path = login_path
        self.logout_path = logout_path
        self.callback_path = callback_path
        self.excluded_paths = excluded_paths or []
        self.client = None

        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Create the SSOClient, register the routes and the request hook."""
        self.sso_server_url = self.sso_server_url or app.config.get('SSO_SERVER_URL')
        self.service_url = self.service_url or app.config.get('SERVICE_URL')
        self.client = SSOClient(
            sso_server_url=self.sso_server_url,
            service_url=self.service_url
        )

        app.add_url_rule(self.login_path, 'sso_login', self._login)
        app.add_url_rule(self.logout_path, 'sso_logout', self._logout)
        app.add_url_rule(self.callback_path, 'sso_callback', self._callback)
        app.add_url_rule('/sso/logout-callback', 'sso_logout_callback',
                         self._logout_callback, methods=['POST'])

        app.before_request(self._before_request)

    # ------------------------------------------------------------------ hook

    def _before_request(self):
        """Populate ``g.current_user`` and enforce login on protected paths."""
        user_data = session.get('sso_user')
        # Always set the attribute (None when anonymous), so views on
        # excluded paths can read it too.
        g.current_user = SSOUser.from_dict(user_data) if user_data else None

        if self._is_excluded_path(request.path) or g.current_user is not None:
            return None

        ticket = request.args.get('ticket')
        if ticket:
            # The SSO server redirected back with ?ticket=... appended to the
            # URL we sent as `service`. Previously the request was just let
            # through unvalidated, so g.current_user was never set and the view
            # crashed. Redeem the ticket here for that exact URL.
            service_url = self._strip_ticket(request.url, ticket)
            user = self.client.validate_ticket(ticket, service_url)
            if user is None:
                # Expired or reused ticket (e.g. a bookmarked URL): get a new one.
                # NOTE(review): if validation keeps failing (SSO server
                # unreachable), this redirects in a loop.
                return redirect(self.client.get_login_url(service_url))
            self._store_user(user)
            return redirect(service_url)

        return redirect(self.client.get_login_url(request.url))

    def _is_excluded_path(self, path: str) -> bool:
        """Return True if *path* is, or lies under, an excluded path.

        Matching is per path segment: '/health' excludes '/health' and
        '/health/x' but not '/healthcheck-admin'.
        """
        excluded = [
            self.login_path,
            self.logout_path,
            self.callback_path,
            '/sso/logout-callback',
            '/static'
        ] + self.excluded_paths
        return any(
            path == prefix or path.startswith(prefix.rstrip('/') + '/')
            for prefix in excluded
        )

    # --------------------------------------------------------------- helpers

    @staticmethod
    def _strip_ticket(url: str, ticket: str) -> str:
        """Remove the ``ticket=<ticket>`` pair the SSO server appended to *url*.

        Every other query parameter is kept byte-for-byte, so the result is
        exactly the ``service`` string the ticket was issued for.
        """
        base, sep, query = url.partition('?')
        if not sep:
            return url
        kept = [pair for pair in query.split('&') if pair != f"ticket={ticket}"]
        return f"{base}?{'&'.join(kept)}" if kept else base

    @staticmethod
    def _local_path(url: str) -> str:
        """Return *url* if it is a same-site path such as '/page', else '/'.

        This prevents ``return_url`` from acting as an open redirect.
        """
        if url and url.startswith('/') and not url.startswith(('//', '/\\')):
            return url
        return '/'

    def _absolute_url(self, path: str) -> str:
        """Resolve a local *path* against the configured service URL."""
        from urllib.parse import urljoin

        return urljoin(self.service_url.rstrip('/') + '/', path)

    def _store_user(self, user: SSOUser):
        """Start a fresh session holding the authenticated user."""
        # Drop any pre-login session content (limits session fixation).
        session.clear()
        session['sso_user'] = {
            'user_id': user.user_id,
            'username': user.username,
            'email': user.email,
            'display_name': user.display_name,
            # Stored comma-separated: SSOUser.from_dict parses roles from
            # that string form.
            'roles': ','.join(user.roles)
        }

    # ---------------------------------------------------------------- routes

    def _login(self):
        """Redirect to the SSO login page, returning to ``?return_url``."""
        return_url = self._absolute_url(self._local_path(request.args.get('return_url', '/')))
        return redirect(self.client.get_login_url(return_url))

    def _logout(self):
        """Clear the local session, then log out at the SSO server."""
        session.clear()
        # Must be absolute: the SSO server redirects the browser to it.
        return_url = self._absolute_url(self._local_path(request.args.get('return_url', '/')))
        return redirect(self.client.get_logout_url(return_url))

    def _callback(self):
        """Redeem a ticket delivered to the callback path."""
        ticket = request.args.get('ticket')
        if not ticket:
            return redirect(self.client.get_login_url())

        # The ST is bound to the full service URL, including any return_url
        # query parameter; only the appended ticket must be removed
        # (splitting at '?' dropped return_url and always failed to match).
        service_url = self._strip_ticket(request.url, ticket)
        user = self.client.validate_ticket(ticket, service_url)
        if not user:
            return redirect(self.client.get_login_url())

        self._store_user(user)
        return redirect(self._local_path(request.args.get('return_url', '/')))

    def _logout_callback(self):
        """Receive the SSO server's back-channel logout notification.

        NOTE(review): this is a no-op. Flask's default session lives in a
        client-side cookie, so it cannot be revoked here. Real single logout
        needs a server-side session store keyed by the ``session_id`` (ST).
        """
        return '', 200
def sso_required(f: Callable) -> Callable:
    """View decorator: redirect to ``/login`` unless a user is signed in.

    "Signed in" means the middleware set a non-None ``g.current_user``.
    """
    @wraps(f)
    def decorated(*args, **kwargs):
        current = getattr(g, 'current_user', None)
        if current is None:
            return redirect('/login')
        return f(*args, **kwargs)

    return decorated
def role_required(*roles: str) -> Callable:
    """View decorator factory: require at least one of *roles*.

    Anonymous users are redirected to ``/login``. A signed-in user without
    any of the roles gets ``({'error': '权限不足'}, 403)``.
    """
    def decorator(f: Callable) -> Callable:
        @wraps(f)
        def decorated(*args, **kwargs):
            current = getattr(g, 'current_user', None)
            if current is None:
                return redirect('/login')
            granted = current.roles
            allowed = any(role in granted for role in roles)
            if allowed:
                return f(*args, **kwargs)
            return {'error': '权限不足'}, 403

        return decorated

    return decorator
示例应用
客户端应用示例
# demo_app/app.py
from flask import Flask, render_template, g, jsonify
from sso_client.middleware import SSOMiddleware, sso_required, role_required
import os
import secrets

app = Flask(__name__)
# SSOMiddleware stores the authenticated user (including roles) in Flask's
# signed cookie session, so whoever knows this key can forge a login. A key
# hard-coded in a published tutorial is effectively public. Read it from
# the environment; otherwise use a random per-process key, which means
# sessions do not survive a restart unless SECRET_KEY is set.
app.secret_key = os.getenv('SECRET_KEY') or secrets.token_hex(32)

# SSO settings (overridable through environment variables).
app.config['SSO_SERVER_URL'] = os.getenv('SSO_SERVER_URL', 'http://localhost:5000')
app.config['SERVICE_URL'] = os.getenv('SERVICE_URL', 'http://localhost:5001')

# Initialise the SSO middleware; the listed paths are reachable without login.
sso = SSOMiddleware(
    app,
    excluded_paths=['/api/public', '/health']
)
@app.route('/')
def index():
    """Home page showing the user the SSO middleware placed in ``g``."""
    current = g.current_user
    return render_template('index.html', user=current)
@app.route('/profile')
@sso_required
def profile():
    """Profile page; requires a signed-in user."""
    current = g.current_user
    return render_template('profile.html', user=current)
@app.route('/admin')
@role_required('admin')
def admin():
    """Admin page; requires the 'admin' role."""
    current = g.current_user
    return render_template('admin.html', user=current)
@app.route('/api/user')
@sso_required
def api_user():
    """JSON API: the signed-in user's id, username, email and roles."""
    current = g.current_user
    payload = {
        'user_id': current.user_id,
        'username': current.username,
        'email': current.email,
        'roles': current.roles,
    }
    return jsonify(payload)
@app.route('/api/public')
def api_public():
    """Public JSON API (listed in the middleware's excluded paths)."""
    payload = {'message': '这是公开接口'}
    return jsonify(payload)
@app.route('/health')
def health():
    """Health check (listed in the middleware's excluded paths)."""
    payload = {'status': 'ok'}
    return jsonify(payload)
if __name__ == '__main__':
    import os

    # The Werkzeug debugger console allows arbitrary code execution, so do
    # not enable it by default while listening on 0.0.0.0; opt in with DEBUG=true.
    debug = os.getenv('DEBUG', 'False').lower() == 'true'
    app.run(host='0.0.0.0', port=5001, debug=debug)
HTML 模板
<!-- demo_app/templates/index.html -->
<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <!-- The page contains Chinese text: declare the encoding explicitly. -->
    <meta charset="UTF-8">
    <title>Demo App</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            margin: 40px;
        }
        .user-info {
            background: #f5f5f5;
            padding: 20px;
            border-radius: 8px;
        }
        .nav {
            margin-bottom: 20px;
        }
        .nav a {
            margin-right: 15px;
        }
    </style>
</head>
<body>
    {# `user` is the SSOUser passed in as g.current_user by the index view #}
    <div class="nav">
        <a href="/">首页</a>
        <a href="/profile">个人中心</a>
        {% if 'admin' in user.roles %}
        <a href="/admin">管理后台</a>
        {% endif %}
        <a href="/logout">登出</a>
    </div>
    <h1>欢迎, {{ user.display_name }}!</h1>
    <div class="user-info">
        <h3>用户信息</h3>
        <p><strong>用户ID:</strong> {{ user.user_id }}</p>
        <p><strong>用户名:</strong> {{ user.username }}</p>
        <p><strong>邮箱:</strong> {{ user.email }}</p>
        <p><strong>角色:</strong> {{ user.roles | join(', ') }}</p>
    </div>
</body>
</html>
运行说明
安装依赖
# requirements.txt
# Pinned runtime dependencies shared by sso_server, sso_client and demo_app.
# NOTE(review): these pins come from the original article; check each one
# against current security advisories before any real deployment.
flask==3.0.0
flask-sqlalchemy==3.1.1
redis==5.0.1
requests==2.31.0
werkzeug==3.0.1
启动服务
# 1. Start Redis
redis-server
# 2. Start the SSO server (second terminal, from the project root).
#    sso_server uses package-relative imports (from .config import ...), so it
#    must run as a module; "cd sso_server && python app.py" fails with
#    "attempted relative import with no known parent package".
python -m sso_server.app
# 3. Start the demo app (third terminal, also from the project root, so that
#    "from sso_client.middleware import ..." can be resolved).
python -m demo_app.app
测试流程
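本章没有提供注册页面,首次测试前需先创建一个测试用户。可在项目根目录执行:
python -c "from sso_server.app import create_app; app = create_app(); app.app_context().push(); print(app.config['auth_service'].create_user('alice', 'alice@example.com', 'password123'))"
- 访问 http://localhost:5001(示例应用)
- 自动跳转到 http://localhost:5000/sso/login(SSO 登录)
- 输入用户名和密码登录
- 自动跳转回示例应用,显示用户信息
- 点击登出,完成单点登出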
本章小结
本章实现了一个完整的 SSO 系统,包括:
- SSO 服务端:用户认证、票据管理、单点登出
- 客户端 SDK:票据验证、会话管理、Flask 中间件
- 示例应用:展示 SSO 集成方式