"""Async gunicorn worker for aiohttp.web"""
import asyncio
import os
import re
import signal
import sys
from types
import FrameType
from typing
import Any, Awaitable, Callable, Optional, Union
# noqa
from gunicorn.config
import AccessLogFormat
as GunicornAccessLogFormat
from gunicorn.workers
import base
from aiohttp
import web
from .helpers
import set_result
from .web_app
import Application
from .web_log
import AccessLogger
try:
import ssl
SSLContext = ssl.SSLContext
except ImportError:
# pragma: no cover
ssl =
None # type: ignore[assignment]
SSLContext = object
# type: ignore[misc,assignment]
__all__ = (
"GunicornWebWorker",
"GunicornUVLoopWebWorker")
class GunicornWebWorker(base.Worker):
# type: ignore[misc,no-any-unimported]
DEFAULT_AIOHTTP_LOG_FORMAT = AccessLogger.LOG_FORMAT
DEFAULT_GUNICORN_LOG_FORMAT = GunicornAccessLogFormat.default
def __init__(self, *args: Any, **kw: Any) ->
None:
# pragma: no cover
super().__init__(*args, **kw)
self._task: Optional[asyncio.Task[
None]] =
None
self.exit_code = 0
self._notify_waiter: Optional[asyncio.Future[bool]] =
None
def init_process(self) ->
None:
# create new event_loop after fork
asyncio.get_event_loop().close()
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
super().init_process()
def run(self) ->
None:
self._task = self.loop.create_task(self._run())
try:
# ignore all finalization problems
self.loop.run_until_complete(self._task)
except Exception:
self.log.exception(
"Exception in gunicorn worker")
self.loop.run_until_complete(self.loop.shutdown_asyncgens())
self.loop.close()
sys.exit(self.exit_code)
async
def _run(self) ->
None:
runner =
None
if isinstance(self.wsgi, Application):
app = self.wsgi
elif asyncio.iscoroutinefunction(self.wsgi):
wsgi = await self.wsgi()
if isinstance(wsgi, web.AppRunner):
runner = wsgi
app = runner.app
else:
app = wsgi
else:
raise RuntimeError(
"wsgi app should be either Application or "
"async function returning Application, got {}".format(self.wsgi)
)
if runner
is None:
access_log = self.log.access_log
if self.cfg.accesslog
else None
runner = web.AppRunner(
app,
logger=self.log,
keepalive_timeout=self.cfg.keepalive,
access_log=access_log,
access_log_format=self._get_valid_log_format(
self.cfg.access_log_format
),
shutdown_timeout=self.cfg.graceful_timeout / 100 * 95,
)
await runner.setup()
ctx = self._create_ssl_context(self.cfg)
if self.cfg.is_ssl
else None
runner = runner
assert runner
is not None
server = runner.server
assert server
is not None
for sock
in self.sockets:
site = web.SockSite(
runner,
sock,
ssl_context=ctx,
)
await site.start()
# If our parent changed then we shut down.
pid = os.getpid()
try:
while self.alive:
# type: ignore[has-type]
self.notify()
cnt = server.requests_count
if self.max_requests
and cnt > self.max_requests:
self.alive =
False
self.log.info(
"Max requests, shutting down: %s", self)
elif pid == os.getpid()
and self.ppid != os.getppid():
self.alive =
False
self.log.info(
"Parent changed, shutting down: %s", self)
else:
await self._wait_next_notify()
except BaseException:
pass
await runner.cleanup()
def _wait_next_notify(self) ->
"asyncio.Future[bool]":
self._notify_waiter_done()
loop = self.loop
assert loop
is not None
self._notify_waiter = waiter = loop.create_future()
self.loop.call_later(1.0, self._notify_waiter_done, waiter)
return waiter
def _notify_waiter_done(
self, waiter: Optional[
"asyncio.Future[bool]"] =
None
) ->
None:
if waiter
is None:
waiter = self._notify_waiter
if waiter
is not None:
set_result(waiter,
True)
if waiter
is self._notify_waiter:
self._notify_waiter =
None
def init_signals(self) ->
None:
# Set up signals through the event loop API.
self.loop.add_signal_handler(
signal.SIGQUIT, self.handle_quit, signal.SIGQUIT,
None
)
self.loop.add_signal_handler(
signal.SIGTERM, self.handle_exit, signal.SIGTERM,
None
)
self.loop.add_signal_handler(
signal.SIGINT, self.handle_quit, signal.SIGINT,
None
)
self.loop.add_signal_handler(
signal.SIGWINCH, self.handle_winch, signal.SIGWINCH,
None
)
self.loop.add_signal_handler(
signal.SIGUSR1, self.handle_usr1, signal.SIGUSR1,
None
)
self.loop.add_signal_handler(
signal.SIGABRT, self.handle_abort, signal.SIGABRT,
None
)
# Don't let SIGTERM and SIGUSR1 disturb active requests
# by interrupting system calls
signal.siginterrupt(signal.SIGTERM,
False)
signal.siginterrupt(signal.SIGUSR1,
False)
# Reset signals so Gunicorn doesn't swallow subprocess return codes
# See: https://github.com/aio-libs/aiohttp/issues/6130
def handle_quit(self, sig: int, frame: Optional[FrameType]) ->
None:
self.alive =
False
# worker_int callback
self.cfg.worker_int(self)
# wakeup closing process
self._notify_waiter_done()
def handle_abort(self, sig: int, frame: Optional[FrameType]) ->
None:
self.alive =
False
self.exit_code = 1
self.cfg.worker_abort(self)
sys.exit(1)
@staticmethod
def _create_ssl_context(cfg: Any) ->
"SSLContext":
"""Creates SSLContext instance for usage in asyncio.create_server.
See ssl.SSLSocket.__init__
for more details.
"""
if ssl
is None:
# pragma: no cover
raise RuntimeError(
"SSL is not supported.")
ctx = ssl.SSLContext(cfg.ssl_version)
ctx.load_cert_chain(cfg.certfile, cfg.keyfile)
ctx.verify_mode = cfg.cert_reqs
if cfg.ca_certs:
ctx.load_verify_locations(cfg.ca_certs)
if cfg.ciphers:
ctx.set_ciphers(cfg.ciphers)
return ctx
def _get_valid_log_format(self, source_format: str) -> str:
if source_format == self.DEFAULT_GUNICORN_LOG_FORMAT:
return self.DEFAULT_AIOHTTP_LOG_FORMAT
elif re.search(r
"%\([^\)]+\)", source_format):
raise ValueError(
"Gunicorn's style options in form of `%(name)s` are not "
"supported for the log formatting. Please use aiohttp's "
"format specification to configure access log formatting: "
"http://docs.aiohttp.org/en/stable/logging.html"
"#format-specification"
)
else:
return source_format
class GunicornUVLoopWebWorker(GunicornWebWorker):
def init_process(self) ->
None:
import uvloop
# Close any existing event loop before setting a
# new policy.
asyncio.get_event_loop().close()
# Setup uvloop policy, so that every
# asyncio.get_event_loop() will create an instance
# of uvloop event loop.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
super().init_process()