RESTful API 开发
2026/3/20 · 阅读约 13 分钟
RESTful API 开发
第一章:REST 基础
REST 设计原则
REST(Representational State Transfer)是一种软件架构风格,核心原则包括:
- 统一接口:使用标准 HTTP 方法和状态码
- 无状态:每个请求包含所有必要信息
- 资源导向:URL 代表资源,动词由 HTTP 方法表示
- 分层系统:客户端无需知道自己是直接连接到服务器,还是经过了代理、网关等中间层
- 可缓存:响应需要显式或隐式地声明能否被缓存,以减少不必要的客户端-服务器交互
HTTP 方法与 CRUD
| HTTP 方法 | CRUD 操作 | 描述 | 示例 |
|---|---|---|---|
| GET | Read | 获取资源 | GET /users |
| POST | Create | 创建资源 | POST /users |
| PUT | Update | 完整替换资源(需提供全部字段) | PUT /users/1 |
| PATCH | Partial Update | 部分更新资源 | PATCH /users/1 |
| DELETE | Delete | 删除资源 | DELETE /users/1 |
HTTP 状态码
# 成功响应
200 OK # 请求成功
201 Created # 资源创建成功
204 No Content # 成功但无返回内容
# 重定向
301 Moved Permanently # 永久重定向
302 Found # 临时重定向
304 Not Modified # 资源未修改
# 客户端错误
400 Bad Request # 请求格式错误
401 Unauthorized # 未认证
403 Forbidden # 无权限
404 Not Found # 资源不存在
405 Method Not Allowed # 方法不允许
409 Conflict # 资源冲突
422 Unprocessable Entity # 验证失败
429 Too Many Requests # 请求过多
# 服务器错误
500 Internal Server Error # 服务器错误
502 Bad Gateway # 网关错误
503 Service Unavailable # 服务不可用
第二章:Flask 原生 API
基础 API 视图
from flask import Flask, jsonify, request, abort, make_response
app = Flask(__name__)
# Mock database: an in-memory list of user dicts. It is a module-level
# object, so it is shared by every request served by this process and is
# lost on restart.
users = [
    {'id': 1, 'username': 'alice', 'email': 'alice@example.com'},
    {'id': 2, 'username': 'bob', 'email': 'bob@example.com'},
]
# GET - list all users
@app.route('/api/users', methods=['GET'])
def get_users():
    """Return every user together with the total count."""
    payload = {'users': users}
    payload['total'] = len(users)
    return jsonify(payload)
# GET - fetch a single user
@app.route('/api/users/<int:user_id>', methods=['GET'])
def get_user(user_id):
    """Return the first user whose id equals *user_id*; 404 if there is none."""
    for candidate in users:
        if candidate['id'] == user_id:
            return jsonify(candidate)
    abort(404)
# POST - create a user
@app.route('/api/users', methods=['POST'])
def create_user():
    """Create a user from a JSON object containing 'username' and 'email'.

    Returns:
        The new user and status 201.

    Aborts with 400 when the body is missing, empty, not a JSON object, or
    lacks a required field.
    """
    # get_json(silent=True) returns None instead of raising 415 for a
    # non-JSON Content-Type, so every malformed body is answered with 400.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or not data:
        abort(400)
    if 'username' not in data or 'email' not in data:
        abort(400, description='Missing required fields')
    # Largest existing id + 1 (1 for an empty store); unlike users[-1]['id']
    # this does not depend on the list being ordered by id.
    new_id = max((u['id'] for u in users), default=0) + 1
    user = {
        'id': new_id,
        'username': data['username'],
        'email': data['email'],
    }
    users.append(user)
    return jsonify(user), 201
# PUT - replace a user (full update)
@app.route('/api/users/<int:user_id>', methods=['PUT'])
def update_user(user_id):
    """Replace the writable fields of user *user_id* with the JSON body.

    PUT has replace semantics, so both 'username' and 'email' are required;
    partial updates go through PATCH.

    Aborts with 404 for an unknown user and with 400 for a missing, empty or
    non-object body or a missing field.
    """
    user = next((u for u in users if u['id'] == user_id), None)
    if user is None:
        abort(404)
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or not data:
        abort(400)
    if 'username' not in data or 'email' not in data:
        abort(400, description='Missing required fields')
    user['username'] = data['username']
    user['email'] = data['email']
    return jsonify(user)
# PATCH - partially update a user
@app.route('/api/users/<int:user_id>', methods=['PATCH'])
def patch_user(user_id):
    """Apply only the fields present in the JSON body to user *user_id*.

    Recognised fields are 'username' and 'email'; other keys are ignored.
    Aborts with 404 for an unknown user and with 400 for a missing, empty or
    non-object body.
    """
    user = next((u for u in users if u['id'] == user_id), None)
    if user is None:
        abort(404)
    # silent=True: a non-JSON body yields None (-> 400) instead of a 415, and
    # the isinstance check stops a JSON list from reaching data[field].
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or not data:
        abort(400)
    for field in ('username', 'email'):
        if field in data:
            user[field] = data[field]
    return jsonify(user)
# DELETE - remove a user
@app.route('/api/users/<int:user_id>', methods=['DELETE'])
def delete_user(user_id):
    """Delete user *user_id* and answer 204 with an empty body (404 if absent)."""
    matches = [u for u in users if u['id'] == user_id]
    if not matches:
        abort(404)
    users.remove(matches[0])
    return '', 204
API 错误处理
from flask import Flask, jsonify
from werkzeug.exceptions import HTTPException
# Separate application object for the error-handling example.
app = Flask(__name__)
# Custom API exception
class APIError(Exception):
    """Application error that carries an HTTP status and an optional payload.

    Args:
        message: Human-readable error text; also the exception message, so
            ``str(err)`` and tracebacks show it.
        status_code: HTTP status to respond with (default 400).
        payload: Optional mapping of extra fields merged into the JSON body.
    """

    def __init__(self, message, status_code=400, payload=None):
        super().__init__(message)
        self.message = message
        self.status_code = status_code
        self.payload = payload

    def to_dict(self):
        """Return the response body: *payload* plus 'error' and 'status' keys."""
        rv = dict(self.payload or ())
        rv['error'] = self.message
        rv['status'] = self.status_code
        return rv
@app.errorhandler(APIError)
def handle_api_error(error):
    """Serialise an APIError to JSON using its own status code."""
    return jsonify(error.to_dict()), error.status_code


# Catch-all for werkzeug HTTP exceptions (404, 405, ...)
@app.errorhandler(HTTPException)
def handle_http_exception(error):
    """Render any HTTPException as {'error': description, 'status': code}."""
    body = {'error': error.description, 'status': error.code}
    return jsonify(body), error.code


# 500 handler
@app.errorhandler(500)
def handle_internal_error(error):
    """Answer internal errors with a generic JSON body that hides details."""
    body = {'error': 'Internal server error', 'status': 500}
    return jsonify(body), 500
# Using the custom exception
@app.route('/api/users/<int:user_id>')
def get_user(user_id):
    """Return one user as JSON; raise APIError(404) when it does not exist."""
    user = User.query.get(user_id)
    if user:
        return jsonify(user.to_dict())
    raise APIError('User not found', status_code=404)
API 响应格式化
from flask import Flask, jsonify, make_response, request
from functools import wraps
app = Flask(__name__)


def api_response(data=None, message=None, status=200, **kwargs):
    """Build the standard JSON envelope returned by every endpoint.

    Args:
        data: Payload placed under 'data'; omitted when None.
        message: Optional human-readable message; omitted when falsy.
        status: HTTP status code. 'success' is True for 2xx and 3xx.
        **kwargs: Extra top-level keys (e.g. 'pagination'). They are merged
            last, so they can override 'success' and 'status'.

    Returns:
        A ``(Response, status)`` tuple that a view can return directly.
    """
    response = {
        'success': 200 <= status < 400,
        'status': status
    }
    if message:
        response['message'] = message
    if data is not None:
        response['data'] = data
    response.update(kwargs)
    return jsonify(response), status
# Paginated response
def paginated_response(items, page, per_page, total):
    """Wrap one page of *items* in the standard envelope plus paging metadata.

    'pages' is ceil(total / per_page). A non-positive *per_page* reports
    0 pages instead of raising ZeroDivisionError.
    """
    pages = (total + per_page - 1) // per_page if per_page > 0 else 0
    return api_response(
        data=items,
        pagination={
            'page': page,
            'per_page': per_page,
            'total': total,
            'pages': pages,
            # Equivalent to page * per_page < total for per_page > 0.
            'has_next': page < pages,
            'has_prev': page > 1,
        },
    )
# Usage example
@app.route('/api/users')
def get_users():
    """List users one page at a time in the standard paginated envelope.

    Query args: ``page`` (default 1) and ``per_page`` (default 20, capped at
    100 so a single request cannot pull the whole table).
    """
    page = request.args.get('page', 1, type=int)
    per_page = request.args.get('per_page', 20, type=int)
    pagination = User.query.paginate(page=page, per_page=per_page, max_per_page=100)
    # Report what paginate() actually used, which may differ from the raw
    # query args after clamping.
    return paginated_response(
        items=[u.to_dict() for u in pagination.items],
        page=pagination.page,
        per_page=pagination.per_page,
        total=pagination.total
    )
@app.route('/api/users', methods=['POST'])
def create_user():
    """Sketch of returning a 201 through api_response.

    NOTE(review): ``user`` is not defined in this snippet; the elided
    creation step must bind it, otherwise this raises NameError.
    """
    # ... create the user ...
    return api_response(
        data=user.to_dict(),
        message='User created successfully',
        status=201
    )
第三章:Flask-RESTful
安装与基础配置
pip install flask-restful
from flask import Flask
from flask_restful import Api, Resource

app = Flask(__name__)
api = Api(app)


class HelloWorld(Resource):
    """Minimal resource: GET / answers with a greeting."""

    def get(self):
        greeting = 'Hello, World!'
        return {'message': greeting}


api.add_resource(HelloWorld, '/')
资源类
from flask_restful import Resource, Api, reqparse, fields, marshal_with
from flask import Flask, url_for
app = Flask(__name__)
api = Api(app)


class UserUri(fields.Raw):
    """Output the URI of a user object, built from the 'user' endpoint.

    ``fields.Url('user')`` cannot be used here. It forwards the object's
    attributes to ``url_for``, but the route variable is ``user_id`` while
    the model attribute is ``id``, so URL building fails. Flask-RESTful's
    default endpoint name is also 'userresource', not 'user'.
    """

    def output(self, key, obj):
        return url_for('user', user_id=obj.id)


# Output schema shared by both user resources
user_fields = {
    'id': fields.Integer,
    'username': fields.String,
    'email': fields.String,
    'created_at': fields.DateTime(dt_format='iso8601'),
    'uri': UserUri(),  # e.g. /api/users/1
}


class UserResource(Resource):
    """Single user: read, update (non-empty fields only) and delete."""

    @marshal_with(user_fields)
    def get(self, user_id):
        user = User.query.get_or_404(user_id)
        return user

    @marshal_with(user_fields)
    def put(self, user_id):
        parser = reqparse.RequestParser()
        parser.add_argument('username', type=str)
        parser.add_argument('email', type=str)
        args = parser.parse_args()
        user = User.query.get_or_404(user_id)
        # Only non-empty values overwrite the stored ones.
        if args['username']:
            user.username = args['username']
        if args['email']:
            user.email = args['email']
        db.session.commit()
        return user

    def delete(self, user_id):
        user = User.query.get_or_404(user_id)
        db.session.delete(user)
        db.session.commit()
        return '', 204


class UserListResource(Resource):
    """User collection: list all users and create new ones."""

    @marshal_with(user_fields)
    def get(self):
        return User.query.all()

    @marshal_with(user_fields)
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument('username', type=str, required=True)
        parser.add_argument('email', type=str, required=True)
        parser.add_argument('password', type=str, required=True)
        args = parser.parse_args()
        user = User(username=args['username'], email=args['email'])
        user.set_password(args['password'])
        db.session.add(user)
        db.session.commit()
        return user, 201


# Register resources; endpoint='user' is the name UserUri builds URLs from.
api.add_resource(UserListResource, '/api/users')
api.add_resource(UserResource, '/api/users/<int:user_id>', endpoint='user')
请求解析
from flask_restful import reqparse
# Basic parser
parser = reqparse.RequestParser()
parser.add_argument('username', type=str, required=True, help='Username is required')
parser.add_argument('email', type=str, required=True)
parser.add_argument('age', type=int, default=0)  # 0 when the argument is absent
parser.add_argument('tags', type=str, action='append')  # multi-valued: every occurrence is collected into a list
# Read arguments from specific locations of the request
parser.add_argument('token', location='headers')
parser.add_argument('page', type=int, location='args')  # query string
parser.add_argument('data', type=dict, location='json')  # JSON body
# Custom validation
def email_validator(value):
    """Return *value* unchanged if it has the shape ``local@domain.tld``.

    Intended as a reqparse ``type`` callable; reqparse reports the
    ValueError message to the client.

    Raises:
        ValueError: if *value* is not a plausible e-mail address.
    """
    import re
    # fullmatch instead of match + '$': '$' also matches just before a
    # trailing newline, which let 'a@b.c\n' through. Whitespace is rejected.
    if not re.fullmatch(r'[^@\s]+@[^@\s]+\.[^@\s]+', value):
        raise ValueError('Invalid email format')
    return value
# Swap the earlier 'email' argument for one that uses the custom validator.
# replace_argument keeps a single 'email' definition; add_argument would
# register a second one. required=True preserves the original requirement.
parser.replace_argument('email', type=email_validator, required=True)
# Parser inheritance: copy() duplicates the shared paging arguments
base_parser = reqparse.RequestParser()
base_parser.add_argument('page', type=int, default=1)
base_parser.add_argument('per_page', type=int, default=20)
user_parser = base_parser.copy()
user_parser.add_argument('username', type=str)
user_parser.add_argument('email', type=str)
输出格式化
from flask_restful import fields, marshal_with, marshal
# Nested fields
address_fields = {
    'city': fields.String,
    'street': fields.String,
    'zip_code': fields.String(attribute='zip')  # output key 'zip_code', read from the source's 'zip'
}
user_fields = {
    'id': fields.Integer,
    'username': fields.String,
    'email': fields.String,
    'address': fields.Nested(address_fields),  # 'address' marshalled with address_fields
    'posts': fields.List(fields.Nested({  # each post reduced to id and title
        'id': fields.Integer,
        'title': fields.String
    }))
}
# Custom field
class GravatarUrl(fields.Raw):
    """Render an e-mail address as its Gravatar avatar URL."""

    def format(self, value):
        import hashlib
        # Gravatar hashes the address after trimming surrounding whitespace
        # and lower-casing it.
        normalized = value.strip().lower()
        digest = hashlib.md5(normalized.encode()).hexdigest()
        return f'https://www.gravatar.com/avatar/{digest}'


user_fields['avatar'] = GravatarUrl(attribute='email')
# Conditional field
class ConditionalField(fields.Raw):
    """Delegate formatting to *field* only while ``condition()`` is truthy.

    Args:
        field: Field instance whose ``format`` produces the value.
        condition: Zero-argument callable evaluated on every format call.
    """

    def __init__(self, field, condition, **kwargs):
        super().__init__(**kwargs)
        self.field = field
        self.condition = condition

    def format(self, value):
        return self.field.format(value) if self.condition() else None
# Manual marshalling
@app.route('/api/users/<int:user_id>')
def get_user(user_id):
    """Return user *user_id* marshalled through user_fields (404 if absent)."""
    return marshal(User.query.get_or_404(user_id), user_fields)
第四章:Flask-RESTX(推荐)
安装与配置
pip install flask-restx
from flask import Flask
from flask_restx import Api, Resource, Namespace, fields
app = Flask(__name__)
# API instance; it also serves the Swagger UI at the `doc` path.
api = Api(
    app,
    version='1.0',
    title='My API',
    description='A sample API',
    doc='/docs'  # Swagger UI path
)
# Namespaces group related resources; each gets its own section in the docs.
ns_users = api.namespace('users', description='User operations')
ns_posts = api.namespace('posts', description='Post operations')
模型定义与 Swagger 文档
from flask_restx import fields, Namespace, Resource
ns = Namespace('users', description='User operations')
# Models describe payloads in the Swagger docs. @ns.expect only validates a
# request body against them when validate=True is passed (or RESTX_VALIDATE
# is enabled).
# Output representation of a user
user_model = ns.model('User', {
    'id': fields.Integer(readonly=True, description='User ID'),  # readonly: documented as output-only
    'username': fields.String(required=True, description='Username'),
    'email': fields.String(required=True, description='Email address'),
    'created_at': fields.DateTime(description='Creation time')
})
# Body accepted when creating a user
user_input_model = ns.model('UserInput', {
    'username': fields.String(required=True, min_length=3, max_length=80),
    'email': fields.String(required=True),
    'password': fields.String(required=True, min_length=8)
})
# Body accepted when updating a user: every field optional
user_update_model = ns.model('UserUpdate', {
    'username': fields.String(min_length=3, max_length=80),
    'email': fields.String()
})
# Pagination metadata
pagination_model = ns.model('Pagination', {
    'page': fields.Integer(description='Current page'),
    'per_page': fields.Integer(description='Items per page'),
    'total': fields.Integer(description='Total items'),
    'pages': fields.Integer(description='Total pages')
})
# Envelope for a page of users: {'users': [...], 'pagination': {...}}
user_list_model = ns.model('UserList', {
    'users': fields.List(fields.Nested(user_model)),
    'pagination': fields.Nested(pagination_model)
})
资源定义
from flask_restx import Namespace, Resource, reqparse
from flask import request
ns = Namespace('users', description='User operations')
# Query-string parser for the user listing (all values come from ?...)
list_parser = reqparse.RequestParser()
list_parser.add_argument('page', type=int, default=1, location='args')
list_parser.add_argument('per_page', type=int, default=20, location='args')
list_parser.add_argument('search', type=str, location='args')  # substring match on username
@ns.route('/')
class UserList(Resource):
    # Collection endpoint: paginated, searchable listing and user creation.
    # (Method docstrings stay as-is: RESTX shows them in Swagger.)

    @ns.doc('list_users')
    @ns.expect(list_parser)
    # The handler returns a {'users': [...], 'pagination': {...}} envelope, so
    # it is marshalled with user_list_model; marshal_list_with(user_model)
    # would treat the envelope itself as one user and emit null fields.
    @ns.marshal_with(user_list_model)
    def get(self):
        """获取用户列表"""
        args = list_parser.parse_args()
        query = User.query
        if args['search']:
            query = query.filter(User.username.contains(args['search']))
        # Cap page size (same limit as the posts listing) so clients cannot
        # fetch the whole table in one request.
        pagination = query.paginate(
            page=args['page'],
            per_page=args['per_page'],
            max_per_page=100
        )
        return {
            'users': pagination.items,
            'pagination': {
                'page': pagination.page,
                'per_page': pagination.per_page,
                'total': pagination.total,
                'pages': pagination.pages
            }
        }

    @ns.doc('create_user')
    # validate=True rejects bodies that do not match user_input_model (missing
    # fields, too-short values) with 400 before the handler runs, so the
    # data[...] lookups below cannot raise KeyError.
    @ns.expect(user_input_model, validate=True)
    @ns.marshal_with(user_model, code=201)
    @ns.response(400, 'Validation Error')
    def post(self):
        """创建新用户"""
        data = request.json
        if User.query.filter_by(email=data['email']).first():
            ns.abort(400, 'Email already registered')
        user = User(
            username=data['username'],
            email=data['email']
        )
        user.set_password(data['password'])
        db.session.add(user)
        db.session.commit()
        return user, 201
@ns.route('/<int:id>')
@ns.response(404, 'User not found')
@ns.param('id', 'User ID')
class UserResource(Resource):
    # Single-user endpoint: read, update and delete by id.
    # (Method docstrings stay as-is: RESTX shows them in Swagger.)

    @ns.doc('get_user')
    @ns.marshal_with(user_model)
    def get(self, id):
        """获取用户详情"""
        user = User.query.get_or_404(id)
        return user

    @ns.doc('update_user')
    # validate=True: bodies that do not match user_update_model get a 400.
    @ns.expect(user_update_model, validate=True)
    @ns.marshal_with(user_model)
    def put(self, id):
        """更新用户信息"""
        user = User.query.get_or_404(id)
        data = request.json or {}
        if 'username' in data:
            user.username = data['username']
        if 'email' in data and data['email'] != user.email:
            # Same uniqueness rule that UserList.post enforces on creation.
            if User.query.filter_by(email=data['email']).first():
                ns.abort(400, 'Email already registered')
            user.email = data['email']
        db.session.commit()
        return user

    @ns.doc('delete_user')
    @ns.response(204, 'User deleted')
    def delete(self, id):
        """删除用户"""
        user = User.query.get_or_404(id)
        db.session.delete(user)
        db.session.commit()
        return '', 204
认证装饰器
from flask import g, request
from flask_restx import Api, Namespace, Resource
from functools import wraps
import jwt
# NOTE(review): `app`, `User` and `user_model` are assumed to come from the
# earlier sections; they are not defined in this snippet.
ns = Namespace('users', description='User operations')
# Swagger security scheme: the token travels in the Authorization header.
authorizations = {
    'Bearer': {
        'type': 'apiKey',
        'in': 'header',
        'name': 'Authorization',
        'description': 'Bearer token authentication'
    }
}
api = Api(
    app,
    authorizations=authorizations,
    security='Bearer'
)


def token_required(f):
    """Run the view only for requests carrying ``Authorization: Bearer <jwt>``.

    Aborts with 401 when the header is missing or malformed, the token fails
    verification, or its 'sub' matches no user. On success the user is
    stored on ``g.current_user`` for the view.
    """
    @wraps(f)
    def decorated(*args, **kwargs):
        header = request.headers.get('Authorization')
        if not header:
            ns.abort(401, 'Token is missing')
        scheme, _, token = header.partition(' ')
        # Authentication schemes are case-insensitive (RFC 7235).
        if scheme.lower() != 'bearer' or not token:
            ns.abort(401, 'Token is invalid')
        try:
            data = jwt.decode(token, app.config['SECRET_KEY'], algorithms=['HS256'])
        except jwt.InvalidTokenError:
            ns.abort(401, 'Token is invalid')
        user_id = data.get('sub')
        user = User.query.get(user_id) if user_id is not None else None
        if user is None:
            ns.abort(401, 'Token is invalid')
        # A local variable here would be invisible to the view; g lives for
        # the rest of the request.
        g.current_user = user
        return f(*args, **kwargs)
    return decorated


@ns.route('/profile')
class Profile(Resource):
    @ns.doc('get_profile', security='Bearer')
    @token_required
    @ns.marshal_with(user_model)
    def get(self):
        """获取当前用户资料"""
        return g.current_user
第五章:API 版本控制
URL 版本控制
from flask import Flask, Blueprint
from flask_restx import Api, Namespace, Resource
app = Flask(__name__)
# V1 API: mounted under /api/v1 with its own Api (and Swagger document)
v1_blueprint = Blueprint('api_v1', __name__, url_prefix='/api/v1')
api_v1 = Api(v1_blueprint, version='1.0', title='API V1')
users_ns_v1 = Namespace('users', description='User operations')
api_v1.add_namespace(users_ns_v1)


@users_ns_v1.route('/')
class UsersV1(Resource):
    def get(self):
        # v1 shape: users at the top level
        return {'version': 'v1', 'users': []}


# V2 API: same resource, new response shape, mounted under /api/v2
v2_blueprint = Blueprint('api_v2', __name__, url_prefix='/api/v2')
api_v2 = Api(v2_blueprint, version='2.0', title='API V2')
users_ns_v2 = Namespace('users', description='User operations')
api_v2.add_namespace(users_ns_v2)


@users_ns_v2.route('/')
class UsersV2(Resource):
    def get(self):
        # v2 shape: users nested under 'data'
        return {'version': 'v2', 'data': {'users': []}}


# Register both blueprints; each version keeps its own URL space.
app.register_blueprint(v1_blueprint)
app.register_blueprint(v2_blueprint)
请求头版本控制
from flask import Flask, request, g
app = Flask(__name__)
# Registry: logical endpoint name -> {version string -> view function}
version_routes = {}


def version(v, name=None):
    """Register the decorated function as implementation *v* of *name*.

    *name* defaults to the function's ``__name__``. Pass a shared name so
    several implementations can be dispatched through one URL rule.
    The function is returned unchanged.
    """
    def decorator(f):
        version_routes.setdefault(name or f.__name__, {})[v] = f
        return f
    return decorator


@app.before_request
def get_api_version():
    """Read the requested version from the API-Version header (default '1.0')."""
    g.api_version = request.headers.get('API-Version', '1.0')


def versioned_route(func_name):
    """Build a view that calls the *func_name* implementation for g.api_version.

    Responds with 400 when no implementation is registered for that version.
    """
    def wrapper(*args, **kwargs):
        requested = g.api_version
        impl = version_routes.get(func_name, {}).get(requested)
        if impl is None:
            return {'error': f'Version {requested} not supported'}, 400
        return impl(*args, **kwargs)
    return wrapper


# Usage: register each implementation under one logical name and bind a
# single URL rule to the dispatcher. Decorating both functions with
# @app.route('/api/users') would add two rules for the same URL, and only
# one of them would ever be reached, whatever the header says.
@version('1.0', name='get_users')
def get_users_v1():
    return {'version': '1.0', 'users': []}


@version('2.0', name='get_users')
def get_users_v2():
    return {'version': '2.0', 'data': {'users': []}}


app.add_url_rule('/api/users', 'get_users', versioned_route('get_users'))
Accept 头版本控制
from flask import Flask, request
from werkzeug.exceptions import NotAcceptable
import re
app = Flask(__name__)

# Vendor media type carrying the version, e.g. application/vnd.myapi.v1+json
_ACCEPT_VERSION_RE = re.compile(r'application/vnd\.myapi\.v(\d+)\+json')


def get_version_from_accept():
    """Return the API version requested in the Accept header.

    Falls back to version 1 when no vendor media type is present.
    """
    accept_header = request.headers.get('Accept', '')
    found = _ACCEPT_VERSION_RE.search(accept_header)
    return int(found.group(1)) if found else 1


@app.route('/api/users')
def get_users():
    """Serve the users payload in the layout of the negotiated version (406 otherwise)."""
    requested = get_version_from_accept()
    if requested == 1:
        return {'users': []}
    if requested == 2:
        return {'data': {'users': []}, 'meta': {}}
    raise NotAcceptable('API version not supported')
第六章:API 认证方案
API Key 认证
from flask import Flask, request, g
from functools import wraps
app = Flask(__name__)

# API key store: key -> client record (use a database in production)
API_KEYS = {
    'api_key_123': {'user_id': 1, 'permissions': ['read', 'write']},
    'api_key_456': {'user_id': 2, 'permissions': ['read']},
}


def require_api_key(f):
    """Reject requests without a known X-API-Key header (401).

    The matching client record is exposed to the view as ``g.api_client``.
    """
    @wraps(f)
    def decorated(*args, **kwargs):
        presented = request.headers.get('X-API-Key')
        if not presented:
            return {'error': 'API key is missing'}, 401
        try:
            client = API_KEYS[presented]
        except KeyError:
            return {'error': 'Invalid API key'}, 401
        g.api_client = client
        return f(*args, **kwargs)
    return decorated


def require_permission(permission):
    """Decorator factory: a valid API key AND *permission* are required (else 403)."""
    def decorator(f):
        def check_permission(*args, **kwargs):
            granted = g.api_client['permissions']
            if permission not in granted:
                return {'error': 'Permission denied'}, 403
            return f(*args, **kwargs)
        # Authenticate first, then check the permission; keep f's metadata.
        return wraps(f)(require_api_key(check_permission))
    return decorator


@app.route('/api/data')
@require_api_key
def get_data():
    # Any authenticated client may read.
    return {'data': 'secret data'}


@app.route('/api/data', methods=['POST'])
@require_permission('write')
def create_data():
    # Only clients holding 'write' may create.
    return {'message': 'Data created'}, 201
OAuth 2.0
from flask import Flask, request, jsonify
from authlib.integrations.flask_oauth2 import ResourceProtector, current_token
from authlib.oauth2.rfc6750 import BearerTokenValidator
app = Flask(__name__)


class MyBearerTokenValidator(BearerTokenValidator):
    """Resolve bearer tokens against the application's Token table."""

    def authenticate_token(self, token_string):
        # Look the token up in the database; None means "unknown token".
        return Token.query.filter_by(access_token=token_string).first()

    # NOTE(review): request_invalid/token_revoked are Authlib 0.x hooks;
    # Authlib 1.x calls token.is_expired()/is_revoked() on the token model
    # instead. Confirm which version is installed.
    def request_invalid(self, request):
        return False

    def token_revoked(self, token):
        return token.revoked


require_oauth = ResourceProtector()
require_oauth.register_token_validator(MyBearerTokenValidator())


@app.route('/api/profile')
@require_oauth()
def profile():
    """Return the owner of the presented access token.

    ``current_token`` (imported above) is the token validated by
    require_oauth for this request.
    """
    user = current_token.user
    return jsonify(user.to_dict())


@app.route('/api/admin')
@require_oauth('admin')  # requires the 'admin' scope
def admin():
    return jsonify({'message': 'Admin area'})
JWT 认证完整示例
import os
from flask import Flask, request, jsonify, g
from flask_restx import Api, Resource, Namespace
import jwt
from datetime import datetime, timedelta
from functools import wraps
app = Flask(__name__)
# Anyone who knows the signing key can mint valid tokens, so take it from
# the environment; the literal is only a development fallback.
app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY', 'your-secret-key')
# Document Bearer auth in Swagger: clients send "Authorization: Bearer <token>".
api = Api(app, authorizations={
    'Bearer': {
        'type': 'apiKey',
        'in': 'header',
        'name': 'Authorization'
    }
})
auth_ns = Namespace('auth', description='Authentication')
api.add_namespace(auth_ns)
def create_token(user_id, token_type='access'):
    """Encode a signed HS256 JWT for *user_id*.

    Access tokens live 1 hour; any other *token_type* (e.g. 'refresh')
    lives 30 days. The type is stored in a custom 'type' claim so access
    and refresh tokens cannot be used in place of each other.

    Returns:
        The encoded token string.
    """
    from datetime import timezone

    lifetime = timedelta(hours=1) if token_type == 'access' else timedelta(days=30)
    # Timezone-aware "now" (datetime.utcnow() is deprecated); one timestamp
    # for both claims keeps exp - iat exactly equal to the lifetime.
    issued_at = datetime.now(timezone.utc)
    payload = {
        # RFC 7519 defines 'sub' as a string, and PyJWT >= 2.10 rejects a
        # non-string subject on decode, so integer ids are converted.
        'sub': str(user_id),
        'type': token_type,
        'iat': issued_at,
        'exp': issued_at + lifetime
    }
    return jwt.encode(payload, app.config['SECRET_KEY'], algorithm='HS256')
def jwt_required(f):
    """Allow the view only with a valid, unexpired *access* token.

    Expects ``Authorization: Bearer <token>``. On success the token subject
    is stored on ``g.current_user_id``. It is the string form of the id,
    because 'sub' is a string claim. Otherwise a JSON error with status
    401 is returned.
    """
    @wraps(f)
    def decorated(*args, **kwargs):
        auth_header = request.headers.get('Authorization')
        if not auth_header:
            return {'error': 'Token is missing'}, 401
        scheme, _, token = auth_header.partition(' ')
        # The scheme is case-insensitive (RFC 7235); anything other than
        # "Bearer <token>" is malformed.
        if scheme.lower() != 'bearer' or not token:
            return {'error': 'Invalid token format'}, 401
        try:
            payload = jwt.decode(
                token,
                app.config['SECRET_KEY'],
                algorithms=['HS256']
            )
        except jwt.ExpiredSignatureError:
            return {'error': 'Token has expired'}, 401
        except jwt.InvalidTokenError:
            return {'error': 'Invalid token'}, 401
        # .get(): a correctly signed token lacking these claims must produce
        # a 401, not a KeyError.
        if payload.get('type') != 'access':
            return {'error': 'Invalid token type'}, 401
        subject = payload.get('sub')
        if subject is None:
            return {'error': 'Invalid token'}, 401
        g.current_user_id = subject
        return f(*args, **kwargs)
    return decorated
@auth_ns.route('/login')
class Login(Resource):
    # Exchanges username/password for an access + refresh token pair.
    # (The Chinese docstring is kept: RESTX shows it in Swagger.)
    def post(self):
        """用户登录"""
        # get_json(silent=True) yields None for a missing or non-JSON body,
        # which previously crashed on .get() with AttributeError.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        username = data.get('username')
        password = data.get('password')
        # Without this check, filter_by(username=None) would query IS NULL.
        if not username or not password:
            return {'error': 'Username and password are required'}, 400
        user = User.query.filter_by(username=username).first()
        if not user or not user.check_password(password):
            return {'error': 'Invalid credentials'}, 401
        return {
            'access_token': create_token(user.id, 'access'),
            'refresh_token': create_token(user.id, 'refresh'),
            'user': user.to_dict()
        }
@auth_ns.route('/refresh')
class Refresh(Resource):
    # Issues a new access token from a valid refresh token.
    # (The Chinese docstring is kept: RESTX shows it in Swagger.)
    def post(self):
        """刷新令牌"""
        data = request.get_json(silent=True)
        refresh_token = data.get('refresh_token') if isinstance(data, dict) else None
        if not refresh_token:
            return {'error': 'Invalid refresh token'}, 401
        try:
            payload = jwt.decode(
                refresh_token,
                app.config['SECRET_KEY'],
                algorithms=['HS256']
            )
        except jwt.ExpiredSignatureError:
            return {'error': 'Refresh token has expired'}, 401
        except jwt.InvalidTokenError:
            return {'error': 'Invalid refresh token'}, 401
        # .get(): a signed token without these claims must not raise KeyError.
        if payload.get('type') != 'refresh':
            return {'error': 'Invalid token type'}, 401
        if payload.get('sub') is None:
            return {'error': 'Invalid refresh token'}, 401
        return {
            'access_token': create_token(payload['sub'], 'access')
        }
第七章:API 高级特性
数据验证(Marshmallow)
pip install marshmallow flask-marshmallow marshmallow-sqlalchemy
from flask import Flask, request
from flask_sqlalchemy import SQLAlchemy
from flask_marshmallow import Marshmallow
from marshmallow import validates, ValidationError, fields
app = Flask(__name__)
db = SQLAlchemy(app)
# Initialise Marshmallow after SQLAlchemy so its SQLAlchemy integration
# (SQLAlchemyAutoSchema) is available.
ma = Marshmallow(app)
class User(db.Model):
    # Users; username and email are unique columns.
    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(80), unique=True)
    email = db.Column(db.String(120), unique=True)
    posts = db.relationship('Post', backref='author')  # backref adds Post.author


class Post(db.Model):
    # NOTE(review): the pagination example later sorts by 'created_at' and
    # filters on Post.category_id, but neither column is declared here.
    id = db.Column(db.Integer, primary_key=True)
    title = db.Column(db.String(200))
    content = db.Column(db.Text)
    user_id = db.Column(db.Integer, db.ForeignKey('user.id'))
# Schema definitions
class UserSchema(ma.SQLAlchemyAutoSchema):
    """(De)serialise User rows; load() returns User instances."""

    class Meta:
        model = User
        load_instance = True  # load() builds a User instead of a dict
        include_relationships = True  # also expose User.posts

    email = fields.Email(required=True)

    @validates('username')
    def validate_username(self, value, **kwargs):
        # **kwargs: marshmallow 4 passes data_key to @validates methods, so
        # accepting it keeps this working on both 3.x and 4.x.
        if len(value) < 3:
            raise ValidationError('Username must be at least 3 characters')
        if User.query.filter_by(username=value).first():
            raise ValidationError('Username already exists')


class PostSchema(ma.SQLAlchemyAutoSchema):
    """(De)serialise Post rows, embedding a short author record."""

    class Meta:
        model = Post
        load_instance = True
        include_fk = True  # include user_id

    author = ma.Nested(UserSchema, only=('id', 'username'))


user_schema = UserSchema()
users_schema = UserSchema(many=True)
post_schema = PostSchema()
posts_schema = PostSchema(many=True)
# Usage
@app.route('/api/users', methods=['POST'])
def create_user():
    """Validate the JSON body with UserSchema and persist the new user (201)."""
    try:
        new_user = user_schema.load(request.json)
        db.session.add(new_user)
        db.session.commit()
        serialized = user_schema.dump(new_user)
    except ValidationError as err:
        return {'errors': err.messages}, 400
    return serialized, 201


@app.route('/api/users/<int:id>')
def get_user(id):
    """Serialise user *id* with UserSchema (404 if missing)."""
    return user_schema.dump(User.query.get_or_404(id))
分页与过滤
from flask import request
from flask_restx import Namespace, Resource, reqparse
ns = Namespace('posts', description='Post operations')
# Paging, sorting and filtering parameters
list_parser = reqparse.RequestParser()
list_parser.add_argument('page', type=int, default=1)
list_parser.add_argument('per_page', type=int, default=20)
list_parser.add_argument('sort', type=str, choices=['created_at', 'title'])  # whitelist: the value is passed to getattr(Post, ...)
list_parser.add_argument('order', type=str, choices=['asc', 'desc'], default='desc')
list_parser.add_argument('search', type=str)
list_parser.add_argument('category', type=int)
list_parser.add_argument('author', type=int)
@ns.route('/')
class PostList(Resource):
    # Post listing with search, filters, sorting and capped paging.
    # (The docstring is kept: RESTX shows it in Swagger.)
    @ns.expect(list_parser)
    def get(self):
        """获取文章列表(支持分页和过滤)"""
        args = list_parser.parse_args()
        query = Post.query

        # Case-insensitive search over title and content
        term = args['search']
        if term:
            pattern = f"%{term}%"
            query = query.filter(
                db.or_(Post.title.ilike(pattern), Post.content.ilike(pattern))
            )

        # Filters (falsy values mean "no filter")
        if args['category']:
            query = query.filter(Post.category_id == args['category'])
        if args['author']:
            query = query.filter(Post.user_id == args['author'])

        # Sorting; 'order' only applies when a sort column is given
        sort_key = args['sort']
        if sort_key:
            column = getattr(Post, sort_key)
            query = query.order_by(column.desc() if args['order'] == 'desc' else column)

        # Paging: at most 100 per page; error_out=False means invalid page
        # arguments do not abort with 404.
        page_obj = query.paginate(
            page=args['page'],
            per_page=min(args['per_page'], 100),
            error_out=False
        )
        items = posts_schema.dump(page_obj.items)
        meta = {
            name: getattr(page_obj, name)
            for name in ('page', 'per_page', 'total', 'pages', 'has_next', 'has_prev')
        }
        return {'items': items, 'pagination': meta}
缓存
from flask import request
from flask_caching import Cache
from functools import wraps
import hashlib
import json
cache = Cache(app, config={'CACHE_TYPE': 'redis'})


def _generation_key(key_prefix):
    """Cache key holding the current generation counter of *key_prefix*."""
    return f"{key_prefix}:generation"


def cache_response(timeout=300, key_prefix='api'):
    """Cache a view's JSON-serialisable result per path and query string.

    Keys embed a generation counter, so invalidate_cached_responses() can
    drop every cached variant of a prefix at once by bumping it.
    """
    def decorator(f):
        @wraps(f)
        def decorated(*args, **kwargs):
            generation = cache.get(_generation_key(key_prefix)) or 0
            query_hash = hashlib.md5(request.query_string).hexdigest()
            cache_key = f"{key_prefix}:{generation}:{request.path}:{query_hash}"
            cached = cache.get(cache_key)
            if cached is not None:
                return json.loads(cached)
            result = f(*args, **kwargs)
            cache.set(cache_key, json.dumps(result), timeout=timeout)
            return result
        return decorated
    return decorator


def invalidate_cached_responses(key_prefix='api'):
    """Make every response cached under *key_prefix* stale.

    ``cache.delete_memoized`` cannot do this: it only knows keys created by
    ``@cache.memoize``, not the keys built by cache_response. This relies on
    the backend keeping the counter at least as long as the cached entries.
    Redis INCR sets no expiry.
    """
    cache.inc(_generation_key(key_prefix))


@app.route('/api/posts')
@cache_response(timeout=60, key_prefix='posts')
def get_posts():
    posts = Post.query.all()
    return posts_schema.dump(posts)


# Invalidate the cache after a write
@app.route('/api/posts', methods=['POST'])
def create_post():
    post = post_schema.load(request.json)
    db.session.add(post)
    db.session.commit()
    # Drops every cached /api/posts variant, whatever its query string.
    invalidate_cached_responses('posts')
    return post_schema.dump(post), 201
速率限制
from flask import g
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
# Keyword arguments work across Flask-Limiter versions. From 3.0, key_func is
# the first positional parameter, so the older Limiter(app, key_func=...)
# call raises TypeError.
limiter = Limiter(
    key_func=get_remote_address,
    app=app,
    default_limits=["200 per day", "50 per hour"]
)


def get_user_identifier():
    """Rate-limit key: the authenticated user id when known, else the client IP.

    NOTE(review): the limiter evaluates this before the view body runs, so
    g.current_user_id is only present if authentication already happened
    (e.g. in a before_request hook). A view-level @jwt_required runs too late.
    """
    if hasattr(g, 'current_user_id'):
        return f"user:{g.current_user_id}"
    return get_remote_address()


@app.route('/api/search')
@limiter.limit("30 per minute")
def search():
    return {'results': []}


@app.route('/api/upload', methods=['POST'])
@limiter.limit("5 per minute", key_func=get_user_identifier)
def upload():
    return {'message': 'Uploaded'}


# JSON body for requests that exceed a limit
@app.errorhandler(429)
def ratelimit_handler(e):
    return {
        'error': 'Rate limit exceeded',
        'message': str(e.description)
    }, 429
CORS 跨域
from flask_cors import CORS
# Global CORS for /api/*. PATCH must be listed because the API exposes PATCH
# endpoints; browsers would otherwise fail the preflight for it.
CORS(app, resources={
    r"/api/*": {
        "origins": ["http://localhost:3000", "https://example.com"],
        "methods": ["GET", "POST", "PUT", "PATCH", "DELETE"],
        "allow_headers": ["Content-Type", "Authorization"]
    }
})
# Per-route CORS
from flask_cors import cross_origin


@app.route('/api/public')
@cross_origin()
def public_endpoint():
    # Default cross_origin(): any origin allowed
    return {'message': 'Public data'}


@app.route('/api/specific')
@cross_origin(origins=['https://trusted.com'])
def specific_endpoint():
    return {'message': 'Specific origin only'}
第八章:API 文档与测试
Swagger/OpenAPI 文档
from flask_restx import Api, Resource, fields
# prefix='/api' puts every route under /api; Swagger UI is served at /docs.
api = Api(
    app,
    version='1.0',
    title='My API',
    description='A comprehensive API documentation',
    doc='/docs',
    prefix='/api'
)
# Detailed model: description/example values appear in the Swagger docs.
user_model = api.model('User', {
    'id': fields.Integer(
        readonly=True,
        description='User unique identifier',
        example=1
    ),
    'username': fields.String(
        required=True,
        description='Username',
        example='john_doe',
        min_length=3,
        max_length=80
    ),
    'email': fields.String(
        required=True,
        description='Email address',
        example='john@example.com'
    ),
    'created_at': fields.DateTime(
        readonly=True,
        description='Creation timestamp'
    )
})


@api.route('/users/<int:id>')
@api.doc(params={'id': 'User ID'})
class UserResource(Resource):
    # The method docstring below is runtime text: RESTX turns it into the
    # operation summary/description in Swagger, so it is left unchanged.
    @api.doc(
        description='Get a user by ID',
        responses={
            200: ('Success', user_model),
            404: 'User not found'
        }
    )
    @api.marshal_with(user_model)
    def get(self, id):
        """获取用户详情
        返回指定 ID 的用户信息,包含用户名、邮箱等基本信息。
        """
        return User.query.get_or_404(id)
API 测试
# tests/test_api.py
import pytest
from app import create_app, db
from app.models import User
@pytest.fixture
def app():
    """Yield an app built with the 'testing' config and a fresh schema.

    Tables are created before and dropped after each test, inside a single
    app context.
    """
    app = create_app('testing')
    with app.app_context():
        db.create_all()
        yield app
        # Close the scoped session first: an open session can hold locks
        # that make drop_all() block on some databases.
        db.session.remove()
        db.drop_all()


@pytest.fixture
def client(app):
    """Flask test client bound to the app fixture."""
    return app.test_client()


@pytest.fixture
def auth_headers(client):
    """Create a user, log in, and return an Authorization header dict."""
    user = User(username='test', email='test@example.com')
    user.set_password('password')
    db.session.add(user)
    db.session.commit()
    response = client.post('/api/auth/login', json={
        'username': 'test',
        'password': 'password'
    })
    # Fail here with the server's response instead of a KeyError below.
    assert response.status_code == 200, response.json
    token = response.json['access_token']
    return {'Authorization': f'Bearer {token}'}
class TestUserAPI:
    """End-to-end tests for the /api/users endpoints."""

    @staticmethod
    def _add_user(username, email):
        """Insert and commit a user, returning it."""
        record = User(username=username, email=email)
        db.session.add(record)
        db.session.commit()
        return record

    def test_get_users(self, client):
        resp = client.get('/api/users')
        assert resp.status_code == 200
        assert 'users' in resp.json

    def test_create_user(self, client):
        body = {
            'username': 'newuser',
            'email': 'new@example.com',
            'password': 'password123'
        }
        resp = client.post('/api/users', json=body)
        assert resp.status_code == 201
        assert resp.json['username'] == 'newuser'

    def test_get_user(self, client, auth_headers):
        target = self._add_user('alice', 'alice@example.com')
        resp = client.get(f'/api/users/{target.id}', headers=auth_headers)
        assert resp.status_code == 200
        assert resp.json['username'] == 'alice'

    def test_update_user(self, client, auth_headers):
        target = self._add_user('bob', 'bob@example.com')
        resp = client.put(
            f'/api/users/{target.id}',
            headers=auth_headers,
            json={'username': 'bobby'}
        )
        assert resp.status_code == 200
        assert resp.json['username'] == 'bobby'

    def test_delete_user(self, client, auth_headers):
        target = self._add_user('charlie', 'charlie@example.com')
        resp = client.delete(f'/api/users/{target.id}', headers=auth_headers)
        assert resp.status_code == 204

    def test_unauthorized(self, client):
        resp = client.get('/api/profile')
        assert resp.status_code == 401
总结
本章详细介绍了 Flask RESTful API 开发:
- REST 基础:设计原则、HTTP 方法、状态码
- 原生 API:基础视图、错误处理、响应格式化
- Flask-RESTful:资源类、请求解析、输出格式化
- Flask-RESTX:Swagger 文档、模型定义
- 版本控制:URL、请求头、Accept 头
- 认证方案:API Key、OAuth 2.0、JWT
- 高级特性:数据验证、分页过滤、缓存、限流、CORS
- 文档与测试:OpenAPI 文档、API 测试
下一章我们将学习 Flask 扩展生态。