Skip to content

prefect.server.api.middleware

CsrfMiddleware

Bases: BaseHTTPMiddleware

Middleware for CSRF protection. When CSRF protection is enabled (`PREFECT_SERVER_CSRF_PROTECTION_ENABLED`), this middleware checks the `Prefect-Csrf-Token` and `Prefect-Csrf-Client` headers of any POST, PUT, PATCH, or DELETE request. If either header is missing, or the token does not match the token stored in the database for that client, the request is rejected with a 403 status code.

Source code in prefect/server/api/middleware.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
class CsrfMiddleware(BaseHTTPMiddleware):
    """
    Middleware for CSRF protection.

    When `PREFECT_SERVER_CSRF_PROTECTION_ENABLED` is set, every POST, PUT,
    PATCH, or DELETE request must carry a `Prefect-Csrf-Token` header and a
    `Prefect-Csrf-Client` header. If either header is missing, or the token
    does not match the token stored in the database for that client, the
    request is rejected with a 403 status code.
    """

    async def dispatch(
        self, request: Request, call_next: NextMiddlewareFunction
    ) -> Response:
        """
        Validate the CSRF headers of a state-changing request.

        Requests using other HTTP methods, and all requests when CSRF
        protection is disabled, are passed straight to `call_next`.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request to the next
                middleware or route handler.

        Returns:
            A 403 JSON response when the token header or the client header is
            missing, or when the token does not match the one stored for the
            client; otherwise the response produced by `call_next`.
        """
        # Imported here so this block does not rely on a module-level import.
        import hmac

        request_needs_csrf_protection = request.method in {
            "POST",
            "PUT",
            "PATCH",
            "DELETE",
        }

        if (
            settings.PREFECT_SERVER_CSRF_PROTECTION_ENABLED.value()
            and request_needs_csrf_protection
        ):
            incoming_token = request.headers.get("Prefect-Csrf-Token")
            incoming_client = request.headers.get("Prefect-Csrf-Client")

            # Missing headers are rejected before any database lookup is made.
            if incoming_token is None:
                return JSONResponse(
                    {"detail": "Missing CSRF token."},
                    status_code=status.HTTP_403_FORBIDDEN,
                )

            if incoming_client is None:
                return JSONResponse(
                    {"detail": "Missing client identifier."},
                    status_code=status.HTTP_403_FORBIDDEN,
                )

            db = provide_database_interface()
            async with db.session_context() as session:
                token = await models.csrf_token.read_token_for_client(
                    session=session, client=incoming_client
                )

                # Compare in constant time so response latency does not reveal
                # how many leading characters of the stored token were guessed
                # correctly. Both sides are encoded to bytes because
                # `hmac.compare_digest` raises TypeError for non-ASCII `str`
                # input, which a client-supplied header may contain.
                # Assumes `token.token` is a `str`, as the header comparison
                # implies.
                if token is None or not hmac.compare_digest(
                    token.token.encode("utf-8"), incoming_token.encode("utf-8")
                ):
                    # Only this rejection carries a permissive CORS header.
                    return JSONResponse(
                        {"detail": "Invalid CSRF token or client identifier."},
                        status_code=status.HTTP_403_FORBIDDEN,
                        headers={"Access-Control-Allow-Origin": "*"},
                    )

        return await call_next(request)

dispatch async

Dispatch method for the middleware. For POST, PUT, PATCH, and DELETE requests (when CSRF protection is enabled), this method reads the `Prefect-Csrf-Token` and `Prefect-Csrf-Client` headers and compares the token to the one stored in the database for that client. If either header is missing or the token does not match, the request is rejected with a 403 status code; otherwise the request is passed on to the next handler.

Source code in prefect/server/api/middleware.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
async def dispatch(
    self, request: Request, call_next: NextMiddlewareFunction
) -> Response:
    """
    Validate the CSRF headers of a state-changing request.

    Only POST, PUT, PATCH, and DELETE requests are checked, and only when
    `PREFECT_SERVER_CSRF_PROTECTION_ENABLED` is set; every other request is
    passed straight to `call_next`.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request to the next middleware
            or route handler.

    Returns:
        A 403 JSON response when the `Prefect-Csrf-Token` or
        `Prefect-Csrf-Client` header is missing, or when the token does not
        match the one stored for the client; otherwise the response produced
        by `call_next`.
    """
    # Imported here so this block does not rely on a module-level import.
    import hmac

    request_needs_csrf_protection = request.method in {
        "POST",
        "PUT",
        "PATCH",
        "DELETE",
    }

    if (
        settings.PREFECT_SERVER_CSRF_PROTECTION_ENABLED.value()
        and request_needs_csrf_protection
    ):
        incoming_token = request.headers.get("Prefect-Csrf-Token")
        incoming_client = request.headers.get("Prefect-Csrf-Client")

        # Missing headers are rejected before any database lookup is made.
        if incoming_token is None:
            return JSONResponse(
                {"detail": "Missing CSRF token."},
                status_code=status.HTTP_403_FORBIDDEN,
            )

        if incoming_client is None:
            return JSONResponse(
                {"detail": "Missing client identifier."},
                status_code=status.HTTP_403_FORBIDDEN,
            )

        db = provide_database_interface()
        async with db.session_context() as session:
            token = await models.csrf_token.read_token_for_client(
                session=session, client=incoming_client
            )

            # Constant-time comparison: a plain `!=` stops at the first
            # differing character, so response timing could leak how much of
            # the stored token a guess got right. Encoding to bytes avoids the
            # TypeError `hmac.compare_digest` raises for non-ASCII `str`
            # values, which a client-supplied header may contain.
            # Assumes `token.token` is a `str`, as the header comparison
            # implies.
            if token is None or not hmac.compare_digest(
                token.token.encode("utf-8"), incoming_token.encode("utf-8")
            ):
                # Only this rejection carries a permissive CORS header.
                return JSONResponse(
                    {"detail": "Invalid CSRF token or client identifier."},
                    status_code=status.HTTP_403_FORBIDDEN,
                    headers={"Access-Control-Allow-Origin": "*"},
                )

    return await call_next(request)