FastAPI Beginner Tutorial

Chapter 5 - Advanced Features

🎯 Goals of This Chapter

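  • Implement statistics endpoints (overview and per-class statistics)
  • Implement score rankings and per-student score analysis
  • Improve error handling with global exception handlers
  • Add logging, request logging, and a test-data initialization script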

1️⃣ Statistics API Routes

app/api/statistics.py

"""
Statistics API routes
"""
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import desc, func
from sqlalchemy.orm import Session

from app.database import get_db
from app.models import Class, Student, Score
from app.utils.response import success

router = APIRouter()

@router.get("/statistics/overview", summary="Overall statistics")
def get_overview(db: Session = Depends(get_db)):
    """
    Get system-wide statistics

    Includes:
    - Number of classes
    - Number of students
    - Number of score records
    - Average score per subject
    - Gender distribution
    """
    # Basic counts
    class_count = db.query(Class).count()
    student_count = db.query(Student).count()
    score_count = db.query(Score).count()
    
    # Average score and number of records per subject
    subject_avg = db.query(
        Score.subject,
        func.avg(Score.score).label("avg_score"),
        func.count(Score.id).label("count")
    ).group_by(Score.subject).all()
    
    subject_stats = [
        {
            "subject": item.subject,
            # AVG can come back as Decimal on some databases; float() keeps the JSON value numeric
            "avg_score": round(float(item.avg_score), 2) if item.avg_score is not None else 0,
            "count": item.count
        }
        for item in subject_avg
    ]
    
    # Gender distribution: {gender: number of students}
    gender_dist = db.query(
        Student.gender,
        func.count(Student.id).label("count")
    ).group_by(Student.gender).all()
    
    gender_stats = {item.gender: item.count for item in gender_dist}
    
    return success(data={
        "class_count": class_count,
        "student_count": student_count,
        "score_count": score_count,
        "subject_stats": subject_stats,
        "gender_distribution": gender_stats
    })

@router.get("/statistics/class/{class_id}", summary="Class statistics")
def get_class_statistics(class_id: int, db: Session = Depends(get_db)):
    """
    Get statistics for one class

    Includes:
    - Number of students
    - Per-subject statistics (average, highest and lowest score)
    - Score distribution by band
    """
    # IDs of the students in this class (the list length is also the student count)
    student_ids = [
        row.id for row in db.query(Student.id).filter(Student.class_id == class_id).all()
    ]
    student_count = len(student_ids)
    
    if not student_ids:
        return success(data={
            "student_count": 0,
            "subject_stats": [],
            "score_distribution": {}
        })
    
    # Per-subject statistics
    subject_stats = db.query(
        Score.subject,
        func.avg(Score.score).label("avg"),
        func.max(Score.score).label("max"),
        func.min(Score.score).label("min"),
        func.count(Score.id).label("count")
    ).filter(
        Score.student_id.in_(student_ids)
    ).group_by(Score.subject).all()
    
    subject_data = [
        {
            "subject": item.subject,
            "avg": round(float(item.avg), 2) if item.avg is not None else 0,
            "max": item.max if item.max is not None else 0,
            "min": item.min if item.min is not None else 0,
            "count": item.count
        }
        for item in subject_stats
    ]
    
    # Score distribution by band (counts every score record of the class)
    scores = db.query(Score.score).filter(Score.student_id.in_(student_ids)).all()
    distribution = {
        "excellent (90-100)": 0,
        "good (80-89)": 0,
        "fair (70-79)": 0,
        "pass (60-69)": 0,
        "fail (0-59)": 0
    }
    
    for score in scores:
        s = score.score
        if s >= 90:
            distribution["excellent (90-100)"] += 1
        elif s >= 80:
            distribution["good (80-89)"] += 1
        elif s >= 70:
            distribution["fair (70-79)"] += 1
        elif s >= 60:
            distribution["pass (60-69)"] += 1
        else:
            distribution["fail (0-59)"] += 1
    
    return success(data={
        "student_count": student_count,
        "subject_stats": subject_data,
        "score_distribution": distribution
    })

@router.get("/statistics/ranking", summary="Score ranking")
def get_ranking(
    subject: Optional[str] = Query(default=None, description="Subject (omit to rank by total score)"),
    class_id: Optional[int] = Query(default=None, description="Class ID (omit for a school-wide ranking)"),
    limit: int = Query(default=10, ge=1, le=100, description="Number of entries to return"),
    db: Session = Depends(get_db)
):
    """
    Get a score ranking

    - Rank by a single subject or by total score
    - Restrict to one class or rank across the whole school

    Ranks are positional (1, 2, 3, ...), so students with equal scores
    still receive different ranks.
    """
    if subject:
        # Single-subject ranking. A student with several records for the same
        # subject (e.g. several exams) appears once per record.
        query = db.query(
            Student.id,
            Student.name,
            Student.student_no,
            Score.score
        ).join(Score, Student.id == Score.student_id).filter(
            Score.subject == subject
        )
        
        if class_id is not None:
            query = query.filter(Student.class_id == class_id)
        
        results = query.order_by(desc(Score.score)).limit(limit).all()
        
        ranking = [
            {
                "rank": idx + 1,
                "student_id": item.id,
                "student_name": item.name,
                "student_no": item.student_no,
                "score": item.score
            }
            for idx, item in enumerate(results)
        ]
    else:
        # Total-score ranking
        query = db.query(
            Student.id,
            Student.name,
            Student.student_no,
            func.sum(Score.score).label("total_score"),
            func.count(Score.id).label("subject_count")
        ).join(Score, Student.id == Score.student_id)
        
        if class_id is not None:
            query = query.filter(Student.class_id == class_id)
        
        # Group by every selected Student column so the query also works on
        # databases that require all non-aggregated columns in GROUP BY
        results = (
            query.group_by(Student.id, Student.name, Student.student_no)
            .order_by(desc("total_score"))
            .limit(limit)
            .all()
        )
        
        ranking = [
            {
                "rank": idx + 1,
                "student_id": item.id,
                "student_name": item.name,
                "student_no": item.student_no,
                "total_score": item.total_score or 0,
                "subject_count": item.subject_count,
                # float() guards against Decimal sums on some databases
                "avg_score": round(float(item.total_score) / item.subject_count, 2) if item.subject_count else 0
            }
            for idx, item in enumerate(results)
        ]
    
    return success(data={
        "subject": subject or "total",
        "class_id": class_id,
        "ranking": ranking
    })

@router.get("/statistics/student/{student_id}", summary="Student score analysis")
def get_student_analysis(student_id: int, db: Session = Depends(get_db)):
    """
    Get a score analysis for one student

    Includes:
    - Score in each subject
    - Total and average score
    - Rank within the class (by total score)
    - School-wide rank (grade_rank: by total score across all classes, not only the student's grade)
    """
    # Look up the student; a missing student is reported as 404
    # (turned into the unified JSON format by the global HTTPException handler in section 3)
    student = db.query(Student).filter(Student.id == student_id).first()
    if not student:
        raise HTTPException(status_code=404, detail="Student not found")
    
    # All score records of this student
    scores = db.query(Score).filter(Score.student_id == student_id).all()
    
    if not scores:
        return success(data={
            "student_name": student.name,
            "student_no": student.student_no,
            "scores": [],
            "total_score": 0,
            "avg_score": 0,
            "class_rank": None,
            "grade_rank": None
        })
    
    # Score in each subject
    score_list = [
        {
            "subject": s.subject,
            "score": s.score,
            "exam_date": s.exam_date.isoformat() if s.exam_date else None
        }
        for s in scores
    ]
    
    # Total and average score
    total_score = sum(s.score for s in scores)
    avg_score = round(total_score / len(scores), 2)
    
    # Class rank by total score (only students with at least one score are ranked)
    class_rank = None
    if student.class_id:
        class_students = db.query(Student.id).filter(Student.class_id == student.class_id).all()
        class_student_ids = [s.id for s in class_students]
        
        class_totals = db.query(
            Score.student_id,
            func.sum(Score.score).label("total")
        ).filter(
            Score.student_id.in_(class_student_ids)
        ).group_by(Score.student_id).order_by(desc("total")).all()
        
        for idx, item in enumerate(class_totals):
            if item.student_id == student_id:
                class_rank = idx + 1
                break
    
    # School-wide rank by total score (all classes and grades)
    all_totals = db.query(
        Score.student_id,
        func.sum(Score.score).label("total")
    ).group_by(Score.student_id).order_by(desc("total")).all()
    
    grade_rank = None
    for idx, item in enumerate(all_totals):
        if item.student_id == student_id:
            grade_rank = idx + 1
            break
    
    return success(data={
        "student_name": student.name,
        "student_no": student.student_no,
        "class_name": student.class_info.name if student.class_info else None,
        "scores": score_list,
        "total_score": total_score,
        "avg_score": avg_score,
        "subject_count": len(scores),
        "class_rank": class_rank,
        "grade_rank": grade_rank,
        # Number of students with at least one score record
        "total_students": len(all_totals)
    })
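
The ranking code above assigns positions with `idx + 1`, so two students with the same score get different ranks. If tied students should share a rank, one option is standard competition ranking ("1224"). The following is a minimal standalone sketch, not part of the tutorial's code; the helper name `competition_ranks` and its input shape (a list of `(student_id, score)` pairs already sorted by score, highest first) are assumptions for illustration.

```python
from typing import List, Tuple


def competition_ranks(rows: List[Tuple[int, float]]) -> List[Tuple[int, float, int]]:
    """Assign competition ranks ("1224") to rows sorted by score, highest first.

    Hypothetical helper: rows are (student_id, score) pairs. Tied scores share
    a rank, and the next distinct score jumps to its 1-based position.
    """
    ranked = []
    prev_score = None
    rank = 0
    for position, (student_id, score) in enumerate(rows, start=1):
        if score != prev_score:
            # A new score value takes its position in the list as its rank
            rank = position
            prev_score = score
        ranked.append((student_id, score, rank))
    return ranked


if __name__ == "__main__":
    rows = [(3, 270), (1, 255), (5, 255), (2, 240)]
    for student_id, score, rank in competition_ranks(rows):
        print(rank, student_id, score)
```

Here students 1 and 5 both get rank 2 and student 2 gets rank 4. Databases with window-function support can compute the same thing in SQL with `RANK() OVER (ORDER BY ...)`, which SQLAlchemy exposes as `func.rank().over(order_by=...)`. Note that with `limit`, a tie at the cutoff can still be split.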

2️⃣ Register the New Router

Update app/api/__init__.py:

"""
Aggregates the API router modules
"""
from app.api import classes, students, scores, statistics

__all__ = ["classes", "students", "scores", "statistics"]

Update app/main.py to register the statistics router:

from app.api import classes, students, scores, statistics

# ... other code ...

# Register routers
app.include_router(classes.router, prefix=settings.API_PREFIX, tags=["Classes"])
app.include_router(students.router, prefix=settings.API_PREFIX, tags=["Students"])
app.include_router(scores.router, prefix=settings.API_PREFIX, tags=["Scores"])
app.include_router(statistics.router, prefix=settings.API_PREFIX, tags=["Statistics"])

3️⃣ Global Exception Handling

app/utils/exceptions.py

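"""
Global exception handlers
"""
import logging

from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
# Handle Starlette's HTTPException (the base class of FastAPI's) so that errors
# raised by the framework itself, such as 404 for unknown paths, are covered too
from starlette.exceptions import HTTPException as StarletteHTTPException

logger = logging.getLogger("student_management")

async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    """Handle HTTP exceptions, e.g. raise HTTPException(status_code=404, ...)"""
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "code": exc.status_code,
            "message": exc.detail,
            "data": None
        },
        # Keep any headers set on the exception (e.g. WWW-Authenticate)
        headers=getattr(exc, "headers", None)
    )

async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Handle request validation errors (422)"""
    errors = exc.errors()
    error_messages = []
    for error in errors:
        # loc is a tuple such as ("body", "age"); join it into "body.age"
        field = ".".join(str(loc) for loc in error["loc"])
        message = error["msg"]
        error_messages.append(f"{field}: {message}")
    
    return JSONResponse(
        status_code=422,
        content={
            "code": 422,
            "message": "Parameter validation failed",
            "data": {
                "errors": error_messages
            }
        }
    )

async def general_exception_handler(request: Request, exc: Exception):
    """Catch-all for unhandled exceptions"""
    # Record the traceback server-side; the client only gets a generic message
    logger.error("Unhandled error on %s %s", request.method, request.url.path, exc_info=exc)
    return JSONResponse(
        status_code=500,
        content={
            "code": 500,
            "message": "Internal server error",
            "data": None
        }
    )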
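
`validation_exception_handler` joins each error's `loc` tuple into a dotted path, and those paths start with where the value came from, such as `body.age` or `query.limit`. If you would rather show only the field name, a small helper can strip that prefix. The sketch below is hypothetical (`format_errors` and `SOURCE_PREFIXES` are illustrative names); it works on plain dicts shaped like the entries returned by `exc.errors()` (`loc`, `msg`), so it can be tried without FastAPI. The sample messages are illustrative only.

```python
from typing import Any, Dict, List

# Location markers that FastAPI puts at the start of `loc`
SOURCE_PREFIXES = {"body", "query", "path", "header", "cookie"}


def format_errors(errors: List[Dict[str, Any]]) -> List[str]:
    """Turn validation error dicts into 'field: message' strings."""
    messages = []
    for error in errors:
        loc = list(error.get("loc", ()))
        # Drop the leading source marker such as "body" or "query"
        if loc and loc[0] in SOURCE_PREFIXES:
            loc = loc[1:]
        # A bare ("body",) location refers to the request as a whole
        field = ".".join(str(part) for part in loc) or "request"
        messages.append(f"{field}: {error.get('msg', '')}")
    return messages


if __name__ == "__main__":
    sample = [
        {"loc": ("body", "age"), "msg": "Input should be a valid integer"},
        {"loc": ("query", "limit"), "msg": "Input should be less than or equal to 100"},
    ]
    print(format_errors(sample))
```

Inside the handler, `format_errors(exc.errors())` would then replace the loop that builds `error_messages`.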

Register the handlers in app/main.py:

from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException

from app.utils.exceptions import (
    http_exception_handler,
    validation_exception_handler,
    general_exception_handler
)

# ... create app ...

# Register exception handlers. Starlette's HTTPException also matches
# FastAPI's HTTPException (a subclass) and the framework's own 404s.
app.add_exception_handler(StarletteHTTPException, http_exception_handler)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
app.add_exception_handler(Exception, general_exception_handler)
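
Registering both an `HTTPException` handler and a catch-all `Exception` handler works because Starlette chooses a handler by walking the raised exception's class hierarchy (its MRO) and using the first class that has a registered handler. The sketch below is a simplified, standalone model of that lookup, not Starlette's actual code; `HTTPError`, `NotFound`, `ValidationFailed`, and `find_handler` are stand-in names.

```python
from typing import Callable, Dict, Optional, Type


class HTTPError(Exception):
    """Stand-in for starlette.exceptions.HTTPException."""


class NotFound(HTTPError):
    """Stand-in for a subclass, the way fastapi.HTTPException subclasses Starlette's."""


class ValidationFailed(Exception):
    """Stand-in for RequestValidationError."""


def find_handler(exc: BaseException,
                 handlers: Dict[Type[BaseException], Callable]) -> Optional[Callable]:
    # Walk from the most specific class to the most general one;
    # the first class with a registered handler wins
    for cls in type(exc).__mro__:
        if cls in handlers:
            return handlers[cls]
    return None


handlers = {
    HTTPError: lambda e: "http",
    ValidationFailed: lambda e: "validation",
    Exception: lambda e: "general",
}

if __name__ == "__main__":
    for exc in (NotFound(), ValidationFailed(), KeyError("x")):
        print(type(exc).__name__, "->", find_handler(exc, handlers)(exc))
```

The same lookup explains why registering the handler for Starlette's `HTTPException` is the safer choice: FastAPI's `HTTPException` is a subclass of it, so both match, whereas a handler registered only for FastAPI's class does not match the 404 that Starlette raises for unknown paths.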

4️⃣ Add Logging

app/utils/logger.py

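"""
Logging configuration
"""
import logging
import sys

# Application logger
logger = logging.getLogger("student_management")
# The logger accepts DEBUG and above; the console handler below only prints INFO and above
logger.setLevel(logging.DEBUG)

# Attach the handler only once. If this module runs more than once (e.g. via
# importlib.reload or when imported under two module paths), adding another
# handler would print every message twice.
if not logger.handlers:
    # Console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)

    # Formatter
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S"
    )
    console_handler.setFormatter(formatter)

    logger.addHandler(console_handler)

def log_info(message: str):
    logger.info(message)

def log_error(message: str):
    logger.error(message)

def log_debug(message: str):
    # Not printed to the console unless console_handler's level is lowered to DEBUG
    logger.debug(message)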
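
`logging.getLogger` returns the same logger object for the same name, so code that attaches handlers should make sure it does so only once; otherwise each message is emitted once per handler. Below is a minimal sketch of that pattern wrapped in a function, with an in-memory stream so the effect is easy to check. `setup_logger` is a hypothetical name, not part of the tutorial's code.

```python
import io
import logging


def setup_logger(name: str, stream=None, level: int = logging.INFO) -> logging.Logger:
    """Return the named logger, attaching a stream handler only on the first call."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)  # the logger passes everything; the handler filters
    if not logger.handlers:
        handler = logging.StreamHandler(stream)  # stream=None means sys.stderr
        handler.setLevel(level)
        handler.setFormatter(logging.Formatter("%(levelname)s - %(message)s"))
        logger.addHandler(handler)
    return logger


if __name__ == "__main__":
    buffer = io.StringIO()
    log = setup_logger("demo_app", stream=buffer)
    setup_logger("demo_app", stream=buffer)  # second call adds no handler
    log.info("hello")
    log.debug("hidden")  # below the handler's INFO level, so not written
    print(repr(buffer.getvalue()))
```

The buffer ends up holding a single `INFO - hello` line: the second call did not add a handler, and the DEBUG message was filtered out by the handler's level.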

5️⃣ Request Logging Middleware

Add the following to app/main.py, after `app` has been created:

import time
from fastapi import Request
from app.utils.logger import log_info

@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log the method, path, status code and processing time of every request"""
    # perf_counter is a monotonic clock, better suited to measuring durations than time.time()
    start_time = time.perf_counter()
    
    response = await call_next(request)
    
    process_time = time.perf_counter() - start_time
    log_info(
        f"{request.method} {request.url.path} "
        f"- Status: {response.status_code} "
        f"- Time: {process_time:.3f}s"
    )
    
    return response
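
One gap in this middleware: if an endpoint raises an unhandled exception, `await call_next(request)` raises as well and the log line is never written. A common fix is to assume status 500 up front and write the log line in a `finally` block, which runs whether or not the call raised. The sketch below models this with plain `asyncio` so it runs on its own; `timed_call`, the `(method, path)` tuple standing in for a request, the integer status standing in for a response, and the `records` list are all assumptions, not FastAPI objects.

```python
import asyncio
import time
from typing import Awaitable, Callable, List, Tuple

Request = Tuple[str, str]  # stand-in for a request: (method, path)


async def timed_call(
    request: Request,
    call_next: Callable[[Request], Awaitable[int]],
    records: List[str],
) -> int:
    """Run call_next and log method, path, status and duration, even if it raises."""
    method, path = request
    start = time.perf_counter()
    status = 500  # reported if call_next raises
    try:
        status = await call_next(request)
        return status
    finally:
        elapsed = time.perf_counter() - start
        records.append(f"{method} {path} - Status: {status} - Time: {elapsed:.3f}s")


async def ok_endpoint(request: Request) -> int:
    return 200


async def failing_endpoint(request: Request) -> int:
    raise RuntimeError("boom")


async def main() -> List[str]:
    records: List[str] = []
    await timed_call(("GET", "/ok"), ok_endpoint, records)
    try:
        await timed_call(("GET", "/fail"), failing_endpoint, records)
    except RuntimeError:
        pass  # the error still propagates after being logged
    return records


if __name__ == "__main__":
    for line in asyncio.run(main()):
        print(line)
```

In the real middleware the same shape applies: set `status_code = 500` before the `try`, take `response.status_code` after `call_next` returns, and call `log_info` in `finally`.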

6️⃣ Test Data Initialization Script

scripts/init_data.py

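"""
Initialize test data
"""
import random
import sys
from datetime import date
from pathlib import Path

# Make the project root importable regardless of the current working directory
sys.path.append(str(Path(__file__).resolve().parent.parent))

from app.database import SessionLocal
from app.models import Class, Student, Score

def init_data():
    db = SessionLocal()
    
    try:
        # Skip if data already exists: inserting again may violate
        # unique constraints (e.g. on student_no)
        if db.query(Class).count() > 0:
            print("ℹ️ Data already exists, skipping initialization")
            return
        
        # Create classes (e.g. "一年级1班" = Grade 1, Class 1; teachers Zhang, Li, Wang)
        classes = [
            Class(name="一年级1班", grade="一年级", teacher="张老师"),
            Class(name="一年级2班", grade="一年级", teacher="李老师"),
            Class(name="二年级1班", grade="二年级", teacher="王老师"),
        ]
        db.add_all(classes)
        db.commit()
        
        print("✅ Classes created")
        
        # Create students, using the generated class IDs instead of hard-coding 1 and 2
        class1_id, class2_id = classes[0].id, classes[1].id
        students = [
            Student(student_no="2024001", name="张三", gender="男", age=18, class_id=class1_id),
            Student(student_no="2024002", name="李四", gender="女", age=17, class_id=class1_id),
            Student(student_no="2024003", name="王五", gender="男", age=18, class_id=class1_id),
            Student(student_no="2024004", name="赵六", gender="女", age=17, class_id=class2_id),
            Student(student_no="2024005", name="钱七", gender="男", age=19, class_id=class2_id),
        ]
        db.add_all(students)
        db.commit()
        
        print("✅ Students created")
        
        # Create one random score (60-100) per student per subject
        subjects = ["语文", "数学", "英语"]  # Chinese, Math, English
        
        for student in students:
            for subject in subjects:
                db.add(Score(
                    student_id=student.id,
                    subject=subject,
                    score=random.randint(60, 100),
                    exam_date=date(2024, 1, 15),
                    exam_type="期中考试"  # midterm exam
                ))
        db.commit()
        
        print("✅ Scores created")
        print("\n🎉 All test data initialized!")
        
    except Exception as e:
        db.rollback()
        print(f"❌ Initialization failed: {e}")
    finally:
        db.close()

if __name__ == "__main__":
    init_data()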
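
The script draws scores with the module-level `random` functions, so each fresh database gets different numbers, which makes the statistics and ranking outputs hard to compare between runs. One option is a seeded `random.Random` instance. The sketch below builds the score rows as plain dicts; `make_scores`, the default seed, and the parameter names are illustrative assumptions. Because the keys match the `Score(...)` arguments used in the script, each row could be passed as `Score(**row)`.

```python
import random
from datetime import date
from typing import Dict, List


def make_scores(student_ids: List[int], subjects: List[str], seed: int = 42,
                low: int = 60, high: int = 100) -> List[Dict]:
    """Build one score row per (student, subject) pair; the same seed gives the same scores."""
    rng = random.Random(seed)  # private generator; leaves the global random state alone
    rows = []
    for student_id in student_ids:
        for subject in subjects:
            rows.append({
                "student_id": student_id,
                "subject": subject,
                "score": rng.randint(low, high),  # both ends inclusive
                "exam_date": date(2024, 1, 15),
                "exam_type": "期中考试",  # "midterm exam", same value as the script
            })
    return rows


if __name__ == "__main__":
    first = make_scores([1, 2], ["语文", "数学", "英语"])
    second = make_scores([1, 2], ["语文", "数学", "英语"])
    print(len(first), first == second)
```

Running it prints `6 True`: two students times three subjects, and identical rows for identical seeds.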

Run the initialization script:

python scripts/init_data.py

📝 Summary

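In this chapter we completed:

  1. ✅ Statistics endpoints (overview and per-class statistics)
  2. ✅ Score ranking
  3. ✅ Per-student score analysis
  4. ✅ Global exception handling
  5. ✅ Logging
  6. ✅ Data initialization script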

🏃 Next Steps

All the features are now in place. Finally, let's look at the complete project code!

👉 Chapter 6 - Complete Code

Last updated: 2025/12/26 10:15
Contributors: 王长安