Skip to content

mirascope.integrations.middleware

middleware_factory

A factory method for creating middleware decorators.

Example:

from mirascope.core import openai, prompt_template
from mirascope.integrations import middleware_factory

def with_saving():
    """Saves some data after a Mirascope call."""

    return middleware_factory(
        custom_context_manager=custom_context_manager,
        custom_decorator=custom_decorator,
        handle_call_response=handle_call_response,
        handle_call_response_async=handle_call_response_async,
        handle_stream=handle_stream,
        handle_stream_async=handle_stream_async,
        handle_response_model=handle_response_model,
        handle_response_model_async=handle_response_model_async,
        handle_structured_stream=handle_structured_stream,
        handle_structured_stream_async=handle_structured_stream_async,
    )


@with_saving()
@openai.call("gpt-4o-mini")
@prompt_template("Recommend a {genre} book")
def recommend_book(genre: str):
    ...


response = recommend_book("fantasy")  # `with_saving` automatically run
print(response.content)
Source code in mirascope/integrations/_middleware_factory.py
def middleware_factory(
    # Factory for the per-call context object that is passed as the last
    # argument to every handler; it is entered around each handler call (and
    # around the full iteration of a stream, see the generator wrappers below).
    custom_context_manager: Callable[
        [SyncFunc | AsyncFunc], AbstractContextManager[_T]
    ] = default_context_manager,
    # Optional extra decorator applied to the generated wrapper function and
    # to the patched stream `__iter__`/`__aiter__` methods.
    custom_decorator: Callable | None = None,
    # One hook per result type; the `*_async` variants are awaited when the
    # decorated function is a coroutine function.
    handle_call_response: Callable[
        [BaseCallResponse, SyncFunc | AsyncFunc, _T | None], None
    ]
    | None = None,
    handle_call_response_async: Callable[
        [BaseCallResponse, SyncFunc | AsyncFunc, _T | None], Awaitable[None]
    ]
    | None = None,
    handle_stream: Callable[[BaseStream, SyncFunc | AsyncFunc, _T | None], None]
    | None = None,
    handle_stream_async: Callable[
        [BaseStream, SyncFunc | AsyncFunc, _T | None], Awaitable[None]
    ]
    | None = None,
    handle_response_model: Callable[
        [ResponseModel, SyncFunc | AsyncFunc, _T | None], None
    ]
    | None = None,
    handle_response_model_async: Callable[
        [ResponseModel, SyncFunc | AsyncFunc, _T | None], Awaitable[None]
    ]
    | None = None,
    handle_structured_stream: Callable[
        [BaseStructuredStream, SyncFunc | AsyncFunc, _T | None], None
    ]
    | None = None,
    handle_structured_stream_async: Callable[
        [BaseStructuredStream, SyncFunc | AsyncFunc, _T | None], Awaitable[None]
    ]
    | None = None,
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
    '''A factory method for creating middleware decorators.

    The returned decorator wraps a Mirascope call and dispatches on the runtime
    type of its result: `BaseCallResponse`, `BaseStream`, `ResponseModel`, or
    `BaseStructuredStream`. The matching handler receives the result, the
    decorated function, and the value yielded by `custom_context_manager`.
    Response and response-model handlers run immediately after the call
    returns; stream handlers run only after the stream has been fully
    iterated (the stream's iterator is patched lazily via a dynamic subclass).
    The wrapped function's return value is passed through unchanged.

    Example:

    ```python
    from mirascope.core import openai, prompt_template
    from mirascope.integrations import middleware_factory

    def with_saving():
        """Saves some data after a Mirascope call."""

        return middleware_factory(
            custom_context_manager=custom_context_manager,
            custom_decorator=custom_decorator,
            handle_call_response=handle_call_response,
            handle_call_response_async=handle_call_response_async,
            handle_stream=handle_stream,
            handle_stream_async=handle_stream_async,
            handle_response_model=handle_response_model,
            handle_response_model_async=handle_response_model_async,
            handle_structured_stream=handle_structured_stream,
            handle_structured_stream_async=handle_structured_stream_async,
        )


    @with_saving()
    @openai.call("gpt-4o-mini")
    @prompt_template("Recommend a {genre} book")
    def recommend_book(genre: str):
        ...


    response = recommend_book("fantasy")  # `with_saving` automatically run
    print(response.content)
    ```
    '''

    @overload
    def decorator(fn: Callable[_P, _R]) -> Callable[_P, _R]: ...

    @overload
    def decorator(fn: Callable[_P, Awaitable[_R]]) -> Callable[_P, Awaitable[_R]]: ...

    def decorator(
        fn: Callable[_P, _R | Awaitable[_R]],
    ) -> Callable[_P, _R | Awaitable[_R]]:
        # Pick the async or sync wrapper based on how `fn` is defined; the two
        # branches below mirror each other.
        if inspect.iscoroutinefunction(fn):

            @wraps(fn)
            async def wrapper_async(*args: _P.args, **kwargs: _P.kwargs) -> _R:
                result = await fn(*args, **kwargs)
                if (
                    isinstance(result, BaseCallResponse)
                    and handle_call_response_async is not None
                ):
                    # Non-streaming response: run the handler right away,
                    # inside a fresh context.
                    with custom_context_manager(fn) as context:
                        await handle_call_response_async(result, fn, context)
                elif isinstance(result, BaseStream):
                    original_class = type(result)
                    # Capture the bound method now so the wrapper below still
                    # delegates to the original iteration logic after we swap
                    # `result.__class__`.
                    original_aiter = result.__aiter__

                    def new_stream_aiter(
                        self: Any,  # noqa: ANN401
                    ) -> AsyncGenerator[tuple[Any, Any | None], Any]:  # noqa: ANN401
                        async def generator() -> (
                            AsyncGenerator[tuple[Any, Any | None], Any]
                        ):
                            # Enter the context manager lazily, when the user
                            # actually iterates; the handler fires only after
                            # the stream is exhausted.
                            with custom_context_manager(fn) as context:
                                async for chunk, tool in original_aiter():
                                    yield chunk, tool
                                if handle_stream_async is not None:
                                    await handle_stream_async(result, fn, context)

                        return generator()

                    # Dynamic subclass + `__class__` reassignment reroutes only
                    # this instance's iteration through the middleware hook,
                    # without mutating the provider's stream class globally.
                    class MiddlewareAsyncStream(original_class):
                        __aiter__ = (
                            custom_decorator(fn)(new_stream_aiter)
                            if custom_decorator
                            else new_stream_aiter
                        )

                    result.__class__ = MiddlewareAsyncStream
                elif (
                    isinstance(result, ResponseModel)
                    and handle_response_model_async is not None
                ):
                    with custom_context_manager(fn) as context:
                        await handle_response_model_async(result, fn, context)
                elif isinstance(result, BaseStructuredStream):
                    # Same lazy-wrapping scheme as BaseStream above, but
                    # structured streams yield single chunks (no tool).
                    original_class = type(result)
                    original_aiter = result.__aiter__

                    def new_aiter(
                        self: Any,  # noqa: ANN401
                    ) -> AsyncGenerator[tuple[Any, Any | None], Any]:  # noqa: ANN401
                        async def generator() -> (
                            AsyncGenerator[tuple[Any, Any | None], Any]
                        ):
                            with custom_context_manager(fn) as context:
                                async for chunk in original_aiter():
                                    yield chunk
                                if handle_structured_stream_async is not None:
                                    await handle_structured_stream_async(
                                        result, fn, context
                                    )

                        return generator()

                    class MiddlewareAsyncStructuredStream(original_class):
                        __aiter__ = (
                            custom_decorator(fn)(new_aiter)
                            if custom_decorator
                            else new_aiter
                        )

                    result.__class__ = MiddlewareAsyncStructuredStream
                # The result itself is returned unchanged (aside from the
                # possible class swap for streams).
                return cast(_R, result)

            return (
                custom_decorator(fn)(wrapper_async)
                if custom_decorator
                else wrapper_async
            )
        else:

            @wraps(fn)
            def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
                result = fn(*args, **kwargs)
                if (
                    isinstance(result, BaseCallResponse)
                    and handle_call_response is not None
                ):
                    # Non-streaming response: run the handler right away,
                    # inside a fresh context.
                    with custom_context_manager(fn) as context:
                        handle_call_response(result, fn, context)
                elif isinstance(result, BaseStream):
                    original_class = type(result)
                    # Capture the bound method before the `__class__` swap so
                    # the wrapper still delegates to the original iteration.
                    original_iter = result.__iter__

                    def new_stream_iter(
                        self: Any,  # noqa: ANN401
                    ) -> Generator[tuple[Any, Any | None], None, None]:  # noqa: ANN401
                        # Create the context manager when user iterates over the stream
                        with custom_context_manager(fn) as context:
                            yield from original_iter()
                            if handle_stream is not None:
                                handle_stream(result, fn, context)

                    # Per-instance subclass so only this stream is rerouted
                    # through the middleware hook.
                    class MiddlewareStream(original_class):
                        __iter__ = (
                            custom_decorator(fn)(new_stream_iter)
                            if custom_decorator
                            else new_stream_iter
                        )

                    result.__class__ = MiddlewareStream
                elif (
                    isinstance(result, ResponseModel)
                    and handle_response_model is not None
                ):
                    with custom_context_manager(fn) as context:
                        handle_response_model(result, fn, context)
                elif isinstance(result, BaseStructuredStream):
                    # Same lazy-wrapping scheme as BaseStream above for
                    # structured (response-model) streams.
                    original_class = type(result)
                    original_iter = result.__iter__

                    def new_iter(
                        self: Any,  # noqa: ANN401
                    ) -> Generator[
                        Generator[tuple[Any, Any | None], None, None]
                        | Generator[Any, None, None],
                        Any,
                        None,
                    ]:  # noqa: ANN401
                        # Create the context manager when user iterates over the stream
                        with custom_context_manager(fn) as context:
                            yield from original_iter()
                            if handle_structured_stream is not None:
                                handle_structured_stream(result, fn, context)

                    class MiddlewareStructuredStream(original_class):
                        __iter__ = (
                            custom_decorator(fn)(new_iter)
                            if custom_decorator
                            else new_iter
                        )

                    result.__class__ = MiddlewareStructuredStream
                # The result itself is returned unchanged (aside from the
                # possible class swap for streams).
                return cast(_R, result)

            return custom_decorator(fn)(wrapper) if custom_decorator else wrapper

    return decorator