# Source code for vstutils.api.endpoint

import typing as _t
import logging
import traceback
import functools
import string
import _string  # C helper used by string.Formatter to split "0[a].b" field names
from concurrent.futures import ThreadPoolExecutor, Executor
from collections import OrderedDict

import orjson
from django.conf import settings
from django.db import transaction
from django.http import HttpResponse, HttpRequest
from django.contrib.auth.models import AbstractUser
from django.test.client import Client, ClientHandler
from django.test.utils import modify_settings
from drf_yasg.views import SPEC_RENDERERS
from rest_framework import serializers, views, versioning, request as drf_request
from rest_framework.authentication import (
    SessionAuthentication,
    BasicAuthentication,
    TokenAuthentication,
    BaseAuthentication
)

from . import responses
from .decorators import cache_method_result
from .serializers import DataSerializer
from .validators import UrlQueryStringValidator
from .renderers import ORJSONRenderer
from ..utils import Dict, raise_context, patch_gzip_response
from ..middleware import BaseMiddleware
from ..middleware import BaseMiddleware

# Either a DRF request or a plain Django request.
RequestType = _t.Union[drf_request.Request, HttpRequest]
logger: logging.Logger = logging.getLogger('vstutils')

# Configuration read once from Django settings when the module is imported.
# Worker count for bulk operations that may run on a thread pool (falsy = sequential).
THREADS_COUNT = settings.BULK_THREADS
API_URL: _t.Text = settings.API_URL
DEFAULT_VERSION = settings.VST_API_VERSION
# Upper-cased HTTP method names allowed by DRF's APIView.
REST_METHODS: _t.List[_t.Text] = list(map(str.upper, views.APIView.http_method_names))

# Authentication classes this endpoint treats as the default ones.
default_authentication_classes = (SessionAuthentication, BasicAuthentication, TokenAuthentication)

# Unbound ``list.append``, used as ``append_to_list(some_list, item)``.
append_to_list = list.append
# Response headers copied from each operation's response into its bulk result.
response_headers_to_pass = ("ETag", "Location", "Pagination-Identifiers")


@functools.singledispatch
def _get_request_data(request_data: _t.Iterable) -> _t.Union[_t.List, _t.Tuple]:
    """Normalise a bulk request payload into a sequence of operations.

    Lists and tuples are returned unchanged; a single ``dict`` is wrapped into
    a one-element list; a ``str`` is parsed as JSON and normalised again.

    :raises AssertionError: if the payload is not a list/tuple (or a dict or
        JSON string that normalises to one). Raised explicitly rather than via
        ``assert`` so the check is not stripped when Python runs with ``-O``.
    """
    if not isinstance(request_data, (list, tuple)):
        raise AssertionError('Request data must be list or tuple.')
    return request_data


@_get_request_data.register(dict)
def _get_request_data_dict(request_data):
    # A single operation object is treated as a bulk of one.
    return [request_data]


@_get_request_data.register(str)
def _get_request_data_str(request_data):
    # Raw JSON text: decode it and dispatch again on the decoded type.
    return _get_request_data(orjson.loads(request_data))  # nocv


def _iter_request(request, operation_handler, context):
    """Yield ``operation_handler(operation, context)`` for each bulk operation.

    POST and PUT requests, and every request when ``THREADS_COUNT`` is falsy,
    are processed one by one in the calling thread; other methods are fanned
    out over a ``ThreadPoolExecutor`` with ``THREADS_COUNT`` workers. Results
    are yielded in the same order as the operations.
    """
    run_in_pool = request.method not in ('POST', 'PUT') and bool(THREADS_COUNT)
    pool_factory = ThreadPoolExecutor if run_in_pool else _DummyExecutor

    def run_one(operation):
        return operation_handler(operation, context)

    with pool_factory(max_workers=THREADS_COUNT) as pool:
        for outcome in pool.map(run_one, _get_request_data(request.data)):
            yield outcome


def _join_paths(*args) -> _t.Text:
    """Build a URL path out of fragments.

    Every fragment is converted with ``str()`` and stripped of leading and
    trailing slashes, then all fragments are joined with ``/``.

    :param args: path fragments such as ``'/a/b/c'``, ``'b/c/'``, ``1`` or ``'v1'``.
    :returns: the joined path, which always starts and ends with ``/``.
    """
    fragments = [str(fragment).strip('/') for fragment in args]
    return '/' + '/'.join(fragments) + '/'


class _DummyExecutor(Executor):
    """Synchronous drop-in replacement for ``ThreadPoolExecutor``.

    Constructor arguments (e.g. ``max_workers``) are accepted and ignored;
    all work runs in the calling thread.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

    def submit(self, fn, *args, **kwargs):
        """Run ``fn(*args, **kwargs)`` immediately and return a finished Future.

        An exception raised by ``fn`` is stored on the Future (and re-raised by
        ``Future.result()``) instead of propagating, as with a real executor.
        """
        from concurrent.futures import Future  # pylint: disable=import-outside-toplevel

        future = Future()
        try:
            result = fn(*args, **kwargs)
        except Exception as exc:  # pylint: disable=broad-except
            future.set_exception(exc)
        else:
            future.set_result(result)
        return future

    def map(self, fn, *iterables, timeout=None, chunksize=1):
        """Lazily apply ``fn`` in the calling thread.

        ``timeout`` and ``chunksize`` exist for interface compatibility with
        ``Executor.map`` and are ignored.
        """
        return map(fn, *iterables)


class ParseResponseDict(dict):
    """Result of one bulk operation, built from its HTTP response.

    The dict holds ``path``, ``status``, ``data``, ``method`` and the
    whitelisted ``headers`` (see ``response_headers_to_pass``); the
    ``Response-Time`` header is stored separately as :attr:`timing`.
    """
    __slots__ = ('timing',)
    # Value of the ``Response-Time`` header as float, 0.0 when absent.
    timing: _t.SupportsFloat

    def __init__(self, path: _t.Text, method: _t.Text, response: HttpResponse):
        super().__init__(
            path=path,
            status=response.status_code,
            data=self._get_rendered(response),
            method=method,
            # Look each whitelisted name up on the response's case-insensitive
            # header mapping, so headers set with different casing (e.g.
            # "location") are still passed on, under their canonical names.
            headers={
                header: response.headers[header]
                for header in response_headers_to_pass
                if header in response.headers
            }
        )
        self.timing = float(response.get('Response-Time', '0.0'))

    def _get_rendered(self, response: _t.Union[HttpResponse, responses.BaseResponseClass]):
        """Extract the operation's payload from ``response``.

        Tried in order:

        1. ``response.data`` when it is a dict (wrapped in ``Dict``);
        2. ``response.rendered_content`` when the response was rendered by
           ``ORJSONRenderer``;
        3. the JSON-decoded ``rendered_content`` for non-404 responses that
           have it;
        4. otherwise ``Dict(detail=<body decoded as UTF-8>)``.

        Steps 1 and 2 run under ``raise_context()``; presumably it suppresses
        errors such as a missing ``data`` attribute so the next step is tried
        (confirm against ``vstutils.utils.raise_context``).
        """
        with raise_context():
            result = response.data  # type: ignore
            if isinstance(result, dict):
                return Dict(result)
        with raise_context():
            if isinstance(response.accepted_renderer, ORJSONRenderer) and response.is_rendered:  # type: ignore
                return response.rendered_content  # type: ignore
        if response.status_code != 404 and getattr(response, "rendered_content", False):  # nocv
            return orjson.loads(response.rendered_content.decode())  # type: ignore
        return Dict(detail=str(response.content.decode('utf-8')))


class BulkRequestType(drf_request.Request, HttpRequest):  # type: ignore
    """Request type used in annotations and casts by the bulk endpoint.

    Combines the DRF and Django request interfaces and declares the extra
    attributes the endpoint reads.
    """
    # pylint: disable=abstract-method
    data: _t.List[_t.Dict[str, _t.Any]]  # type: ignore
    version: _t.Optional[str]
    successful_authenticator: _t.Optional[BaseAuthentication]


class BulkMiddleware(BaseMiddleware):
    """Marks requests as bulk and restores attributes passed through the environ.

    ``user``, ``language`` and ``session`` entries found in ``request.META``
    (the keys ``BulkClient.request`` puts into the environ) are removed from
    META and set as attributes of the request.
    """
    __slots__ = ()

    def request_handler(self, request: HttpRequest) -> HttpRequest:
        meta = request.META
        request.is_bulk = True  # type: ignore

        if 'user' in meta:
            request.user = meta.pop('user')
            # Pre-fill the cache presumably consulted by Django's lazy
            # request.user so the forwarded user is reused as-is.
            # pylint: disable=protected-access
            request._cached_user = request.user  # type: ignore

        for attr_name in ('language', 'session'):
            if attr_name in meta:
                setattr(request, attr_name, meta.pop(attr_name))

        return request


class BulkClientHandler(ClientHandler):
    """Test-client handler for bulk operations, with CSRF checks disabled.

    While ``__init__`` runs, ``MIDDLEWARE`` is modified by
    ``settings.MIDDLEWARE_ENDPOINT_CONTROL``; the middleware chain is built
    eagerly only when the instance's class is named exactly
    ``BulkClientHandler`` (subclasses skip the eager load).
    """
    __slots__ = ()

    @modify_settings(MIDDLEWARE=settings.MIDDLEWARE_ENDPOINT_CONTROL)
    def __init__(self, *args, **kwargs):
        super().__init__(*args, enforce_csrf_checks=False, **kwargs)
        is_base_handler = self.__class__.__name__ == 'BulkClientHandler'
        if is_base_handler:
            self.load_middleware()


class BulkClient(Client):
    """Django test client used to replay bulk operations in-process.

    ``user``, ``language`` and ``session`` keyword arguments are taken out of
    the environ defaults, kept on the client, and injected (when truthy) into
    the environ of every request it sends. All instances share one
    class-level ``BulkClientHandler``.
    """
    __slots__ = ('user', 'language', 'session', 'exc_info')
    handler: BulkClientHandler = BulkClientHandler()
    user: _t.Optional[AbstractUser]

    def __init__(self, enforce_csrf_checks=False, **defaults):
        # pylint: disable=bad-super-call
        for attr_name in ('user', 'language', 'session'):
            setattr(self, attr_name, defaults.pop(attr_name, None))
        # Deliberately skip Client.__init__ (which would create its own
        # handler) and initialise only the RequestFactory part.
        super(Client, self).__init__(**defaults)
        self.exc_info = None

    def request(self, **request):
        for attr_name in ('user', 'language', 'session'):
            value = getattr(self, attr_name)
            if value:
                request[attr_name] = value
        environ = self._base_environ(**request)
        response = self.handler(environ)
        if response.cookies:
            self.cookies.update(response.cookies)
        return response


class _SafeTemplateFormatter(string.Formatter):
    """``str.format``-compatible formatter that refuses private attribute access.

    Bulk templates come from the request body. With plain ``str.format`` a
    client could walk attributes such as ``var.__init__.__globals__`` and copy
    module globals (e.g. Django settings) into a request. Item access
    (``<<0[data][id]>>``) and public attributes keep working; any attribute
    name starting with ``_`` raises :class:`ValueError`. Nested fields inside
    format specs are resolved through this same method.
    """

    def get_field(self, field_name, args, kwargs):
        first, rest = _string.formatter_field_name_split(field_name)
        obj = self.get_value(first, args, kwargs)
        for is_attr, key in rest:
            if is_attr:
                if str(key).startswith('_'):
                    raise ValueError(f'Access to private attribute "{key}" is not allowed in templates.')
                obj = getattr(obj, key)
            else:
                obj = obj[key]
        return obj, first


# Shared, stateless formatter instance for "<< >>" templates.
_template_formatter = _SafeTemplateFormatter()


class FormatDataFieldMixin:
    """
    Mixin for fields that can format "<< >>" templates inside strings.

    ``<<N...>>`` refers to the N-th previous operation result
    (``context['results']``), ``<<name...>>`` to a value stored with ``let``
    (``context['variables']``). Strings that already contain both ``{`` and
    ``}`` are left untouched. A formatted result that is valid JSON is
    returned decoded.
    """
    __slots__ = ()
    requires_context: bool = True
    context: _t.Dict

    def to_internal_value(self, data) -> _t.Any:
        result = super().to_internal_value(data)  # type: ignore

        is_template = (
            isinstance(result, str)
            and '<<' in result
            and '>>' in result
            and not ('{' in result and '}' in result)
            and 'results' in self.context
        )
        if is_template:
            result = _template_formatter.vformat(
                result.replace('<<', '{').replace('>>', '}'),
                tuple(self.context['results']),
                self.context.get('variables', {}),
            )
            with raise_context():
                return orjson.loads(result)

        return result


class TemplateStringField(FormatDataFieldMixin, serializers.CharField):
    """
    Character field whose value may contain "<< >>" templates, expanded by
    FormatDataFieldMixin after normal CharField validation.
    """
    __slots__ = ()


class RequestDataField(FormatDataFieldMixin, DataSerializer):
    """
    Field that accepts basic data types and recursively formats "<< >>"
    template strings inside them (list items, dict keys and dict values).
    """
    __slots__ = ()

    def to_internal_value(self, data):
        # Bound here because zero-argument super() does not work inside the
        # generator expression used for dicts below.
        convert_scalar = super().to_internal_value

        if isinstance(data, str):
            return convert_scalar(data)

        if isinstance(data, (list, tuple)):
            # Sequences always come back as lists.
            return [self.to_internal_value(item) for item in data]

        if isinstance(data, (dict, OrderedDict)):
            # Keys are formatted as scalars, values recursively; the mapping
            # type of the input is preserved.
            return type(data)(
                (convert_scalar(key), self.to_internal_value(value))
                for key, value in data.items()
            )

        return convert_scalar(data)


class MethodChoicesField(serializers.ChoiceField):
    """Field for an HTTP method; input is upper-cased before validation.

    Falls back to ``REST_METHODS`` when ``choices`` is empty or omitted.
    """
    __slots__ = ()

    def __init__(self, choices: _t.Optional[_t.List] = None, **kwargs):
        if not choices:
            choices = REST_METHODS
        super().__init__(choices, **kwargs)

    def to_internal_value(self, data):
        normalized = str(data).upper()
        return super().to_internal_value(normalized)


class PathField(TemplateStringField):
    """Operation path, given as one fragment or as a list/tuple of fragments.

    Each fragment is validated and template-expanded like a
    ``TemplateStringField``; the results are joined into ``/a/b/c/`` form.
    """
    __slots__ = ()

    def to_internal_value(self, data):
        # Any non-sequence value (str, int, ...) is a single fragment. Wrapping
        # only ``str`` made numeric paths fail with "'int' object is not
        # iterable" and iterated mappings over their keys; a mapping is now
        # validated (and rejected) as a single string fragment instead.
        if not isinstance(data, (list, tuple)):
            data = (data,)

        # Bound here because zero-argument super() does not work inside the
        # generator expression below.
        expand_fragment = super().to_internal_value
        return _join_paths(*(expand_fragment(fragment) for fragment in data))


class OperationSerializer(serializers.Serializer):
    """Validate one bulk operation and execute it with the bulk client.

    The serializer context is expected to provide ``client`` (an object whose
    lower-case HTTP method names are callables, such as ``BulkClient``),
    ``request`` (the outer DRF request) and, when ``let`` is used,
    ``variables`` (a mutable mapping that receives the operation result).
    """
    # pylint: disable=abstract-method
    __slots__ = ()
    # Renders operation payloads to JSON bytes before they are sent.
    renderer = ORJSONRenderer()

    path = PathField(required=True)
    method = MethodChoicesField(required=True)
    # ``default=dict`` is called for each validation, so operations never
    # share (and can never accidentally mutate) one default mapping.
    headers = serializers.DictField(child=TemplateStringField(), default=dict)
    data = RequestDataField(required=False, default=None, allow_null=True)  # type: ignore
    status = serializers.IntegerField(read_only=True, default=500)
    info = serializers.CharField(read_only=True)
    query = TemplateStringField(required=False,
                                allow_blank=True,
                                validators=[UrlQueryStringValidator()],
                                write_only=True)
    let = TemplateStringField(required=False,
                              write_only=True)
    version = serializers.ChoiceField(choices=list(settings.API.keys()),
                                      default=settings.VST_API_VERSION,
                                      write_only=True)

    def to_representation(self, instance: _t.Dict[_t.Text, _t.Any]) -> Dict:
        """Serialize ``instance`` and wrap the result in ``Dict``."""
        return Dict(super().to_representation(instance))

    def get_operation_method(self, method: _t.Text) -> _t.Callable:
        """Return the context client's callable named after ``method`` (lower-cased)."""
        return getattr(self.context.get('client'), method.lower())

    def create(self, validated_data: _t.Dict[_t.Text, _t.Union[_t.Text, _t.Mapping]]) -> ParseResponseDict:
        """Send the operation through the client and return its parsed response.

        The URL is ``/<API_URL>/<version>/<path>/`` plus ``?<query>`` when a
        query is given. Non-GET operations are wrapped in ``transaction.atomic``
        and their truthy payloads are rendered to JSON. When ``let`` is set,
        the result is also stored in ``context['variables'][let]``.
        """
        # pylint: disable=protected-access
        method_name = str(validated_data['method']).lower()
        method = self.get_operation_method(method_name)
        url = _join_paths(API_URL, validated_data['version'], validated_data['path'])
        if 'query' in validated_data and validated_data['query']:
            url += '?' + str(validated_data['query'])
        if method_name != 'get':
            method = transaction.atomic()(method)
        data = validated_data['data']
        if method_name != 'get':
            if data:
                data = self.renderer.render(data, media_type=self.renderer.media_type)
            elif data is None and method_name != 'post':
                # Django's RequestFactory turns a ``None`` body into
                # ``force_bytes(None) == b'None'`` for PUT/PATCH/DELETE/OPTIONS;
                # an empty string sends no body. (POST already maps None to {}.)
                data = ''
        result = ParseResponseDict(
            path=url,
            method=method_name,
            response=method(  # type: ignore
                url,
                content_type='application/json',
                secure=self.context['request']._request.is_secure(),
                data=data,
                **validated_data['headers']
            )
        )
        if 'let' in validated_data:
            self.context['variables'][validated_data['let']] = result
        return result


class EndpointViewSet(views.APIView):
    """
    Default API-endpoint viewset.

    ``GET`` serves the OpenAPI UI/schema, ``POST`` executes a transactional
    bulk request, ``PUT`` and ``PATCH`` execute non-transactional bulk requests.
    """
    __slots__ = ('results',)
    throttle_classes = []  # type: ignore
    schema = None  # type: ignore
    versioning_class = versioning.QueryParameterVersioning  # type: ignore
    renderer_classes = list(views.APIView.renderer_classes) + list(SPEC_RENDERERS)
    session_cookie_name: _t.ClassVar[_t.Text] = settings.SESSION_COOKIE_NAME
    # Environ keys of the outer request copied into the internal client's environ.
    client_environ_keys_copy: _t.List[_t.Text] = [
        "SCRIPT_NAME",
        "SERVER_NAME",
        "SERVER_PORT",
        "SERVER_PROTOCOL",
        "SERVER_SOFTWARE",
        "REMOTE_ADDR",
        settings.SECURE_PROXY_SSL_HEADER[0],
        "HTTP_HOST",
        "HTTP_USER_AGENT",
    ]

    #: One operation serializer class.
    serializer_class: _t.ClassVar[_t.Type[OperationSerializer]] = OperationSerializer

    def get_client(self, request: BulkRequestType) -> BulkClient:
        """
        Returns test client and guarantees that if bulk request comes authenticated
        then test client will be authenticated with the same user
        """
        return BulkClient(**self.original_environ_data(request=request))

    def original_environ_data(self, request: BulkRequestType, *args) -> _t.Dict:
        """Collect ``BulkClient`` keyword arguments from the outer request.

        Copies the non-empty environ values listed in
        ``client_environ_keys_copy`` (plus any extra keys in ``args``) as
        strings. For authenticated requests it forwards the session cookie or
        the ``Authorization`` header (depending on the successful
        authenticator) and the user; ``language`` and ``session`` are always
        forwarded (``None`` when the request lacks them).
        """
        get_environ = request.META.get
        kwargs = {}
        for env_var in tuple(self.client_environ_keys_copy) + args:
            value = get_environ(env_var, None)
            if value:
                kwargs[env_var] = str(value)

        if request.user.is_authenticated:
            if isinstance(request.successful_authenticator, SessionAuthentication):
                kwargs['HTTP_COOKIE'] = str(request.META.get('HTTP_COOKIE'))
            elif isinstance(request.successful_authenticator, (BasicAuthentication, TokenAuthentication)):
                kwargs['HTTP_AUTHORIZATION'] = str(request.META.get('HTTP_AUTHORIZATION'))
            kwargs['user'] = request.user  # type: ignore

        kwargs['language'] = getattr(request, 'language', None)
        kwargs['session'] = getattr(request, 'session', None)
        return kwargs

    def get_serializer(self, *args, **kwargs) -> OperationSerializer:
        """
        Return the serializer instance that should be used for validating and
        deserializing input, and for serializing output.
        """
        serializer_class = self.get_serializer_class()
        kwargs['context'] = self.get_serializer_context(kwargs.get('context', {}))
        return serializer_class(*args, **kwargs)

    @cache_method_result
    def get_serializer_class(self):
        """
        Return the class to use for the serializer.
        Defaults to using `self.serializer_class`.

        You may want to override this if you need to provide different
        serializations depending on the incoming request.

        (Eg. admins get full serialization, others get basic serialization)
        """
        assert self.serializer_class is not None, (
            "'%s' should either include a `serializer_class` attribute, "
            "or override the `get_serializer_class()` method."
            % self.__class__.__name__
        )
        return self.serializer_class

    def get_serializer_context(self, context) -> dict:
        """
        Extra context provided to the serializer class.

        A bulk client is created for the current request when ``context``
        does not already contain one; the passed mapping is not mutated.
        """
        if 'client' not in context:  # nocv
            context = context.copy()
            context['client'] = self.get_client(_t.cast(BulkRequestType, self.request))
        return {
            'request': self.request,
            'view': self,
            **context
        }

    def operate(self, operation_data: _t.Dict, context: _t.Dict) -> _t.Tuple[_t.Dict, _t.SupportsFloat]:
        """Method used to handle one operation and return result of it

        Returns ``(result, timing)``. Any exception is turned into a
        status-500 result describing the failure (timing 0.0) instead of
        propagating.
        """
        serializer = self.get_serializer(data=operation_data, context=context)
        try:
            serializer.is_valid(raise_exception=True)
            return serializer.to_representation(serializer.save()), serializer.instance.timing  # type: ignore
        except Exception as err:  # pylint: disable=broad-except
            return {
                'path': 'bulk',
                'info': {
                    # ``_errors`` exists once ``is_valid`` has run (empty after
                    # successful validation); otherwise the traceback is used.
                    'errors': getattr(serializer, '_errors', traceback.format_exc()),
                    'operation_data': operation_data
                },
                'status': 500,
                'data': {'detail': f'Error in bulk request data. See info. Original message: {str(err)}'}
            }, 0.0

    def get(self, request: BulkRequestType) -> HttpResponse:
        """Returns response with swagger ui or openapi json schema if ?format=openapi"""
        version = getattr(request, "version", DEFAULT_VERSION) or DEFAULT_VERSION
        # Built from API_URL like the operation URLs, instead of a hardcoded "/api/".
        url = _join_paths(API_URL, version, '_openapi')
        if request.query_params.get('format') == 'openapi':  # type: ignore
            url += '?format=openapi'
            should_gzip = True
        else:
            should_gzip = False
        response = self.get_client(request).get(url, secure=request.is_secure())
        if should_gzip:
            patch_gzip_response(response, request)
        return response

    def post(self, request: BulkRequestType) -> responses.BaseResponseClass:
        """Execute transactional bulk request.

        All operations run inside one database transaction. The first
        operation with a status outside 100-399 (or any other error) rolls the
        whole request back, and a 502 response with the results collected so
        far is returned.
        """
        try:
            with transaction.atomic():
                return self.put(request, allow_fail=False)
        except Exception:  # pylint: disable=broad-except
            logger.debug(traceback.format_exc())
            return responses.HTTP_502_BAD_GATEWAY(self.results)

    def put(self, request: BulkRequestType, allow_fail=True) -> responses.BaseResponseClass:
        """Execute non transaction bulk request

        :param allow_fail: when false, raise as soon as an operation returns a
            status outside 100-399.
        :returns: 200 response with all results and per-operation timings;
            cookies set on the internal client that differ from the incoming
            request's cookies are copied to the response.
        """
        context: _t.Dict[_t.Text, _t.Union[_t.List, _t.Dict, BulkClient]] = {
            'client': self.get_client(request),
            'results': self.results,
            'variables': {},
        }
        timings: _t.List = []
        for result, timing in _iter_request(request, self.operate, context):
            append_to_list(self.results, result)
            append_to_list(timings, timing)
            if not allow_fail and not (100 <= result.get('status', 500) < 400):
                raise Exception(f'Execute transaction stopped. Error message: {str(result)}')
        response = responses.HTTP_200_OK(
            self.results,
            timings={f'op{i}': float(j) for i, j in enumerate(timings)}
        )
        for cookie_name, cookie_value in context['client'].cookies.items():  # type: ignore
            if cookie_value.value != request.COOKIES.get(cookie_name, None):
                response.cookies[cookie_name] = cookie_value
        return response

    def patch(self, request: BulkRequestType) -> responses.BaseResponseClass:
        """Same as ``put``: execute a non-transactional bulk request."""
        return self.put(request)

    def initial(self, request: drf_request.Request, *args, **kwargs):
        super().initial(request, *args, **kwargs)
        # Per-request list of operation results, shared with operation templates.
        self.results: _t.List[_t.Dict[_t.Text, _t.Any]] = []

    def finalize_response(self, request: drf_request.Request, *args, **kwargs):
        # Unless the request was authenticated by one of the default classes,
        # call ``logout()`` on a fresh bulk client built for this request.
        # NOTE(review): presumably clears session state left by the
        # operations; confirm intent.
        if not isinstance(request.successful_authenticator, default_authentication_classes):
            self.get_client(_t.cast(BulkRequestType, self.request)).logout()
        return super().finalize_response(request, *args, **kwargs)