Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

User Registration #29

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions chatbot-core/custom_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import requests
from langchain.embeddings.base import Embeddings

from utils import EmbeddingModelType


Expand Down
14 changes: 14 additions & 0 deletions common/dataloaders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import typing

from django.db import models

DjangoModel = typing.TypeVar("DjangoModel", bound=models.Model)


def load_model_objects(
Model: typing.Type[DjangoModel],
keys: list[int],
) -> list[DjangoModel]:
qs = Model.objects.filter(id__in=keys)
_map = {obj.pk: obj for obj in qs}
return [_map[key] for key in keys]
Empty file added main/graphql/__init__.py
Empty file.
16 changes: 16 additions & 0 deletions main/graphql/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from dataclasses import dataclass

from strawberry.django.context import StrawberryDjangoContext
from strawberry.types import Info as _Info

from .dataloaders import GlobalDataLoader


@dataclass
class GraphQLContext(StrawberryDjangoContext):
dl: GlobalDataLoader


# NOTE: This is for type support only, There is a better way?
class Info(_Info):
context: GraphQLContext # type: ignore[reportIncompatibleMethodOverride]
10 changes: 10 additions & 0 deletions main/graphql/dataloaders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from django.utils.functional import cached_property

from user.dataloaders import UserDataLoader


class GlobalDataLoader:

@cached_property
def user(self):
return UserDataLoader()
80 changes: 80 additions & 0 deletions main/graphql/enums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import dataclasses

import strawberry

from user.enums import enum_map as user_enum_map

ENUM_TO_STRAWBERRY_ENUM_MAP: dict[str, type] = {
**user_enum_map,
}


class AppEnumData:
def __init__(self, enum):
self.enum = enum

@property
def key(self):
return self.enum

@property
def label(self):
return str(self.enum.label)


def generate_app_enum_collection_data(name):
return type(
name,
(),
{field_name: [AppEnumData(e) for e in enum] for field_name, enum in ENUM_TO_STRAWBERRY_ENUM_MAP.items()},
)


AppEnumCollectionData = generate_app_enum_collection_data("AppEnumCollectionData")


def generate_type_for_enum(name, Enum):
return strawberry.type(
dataclasses.make_dataclass(
f"AppEnumCollection{name}",
[
("key", Enum),
("label", str),
],
)
)


def _enum_type(name, Enum):
EnumType = generate_type_for_enum(name, Enum)

@strawberry.field
def _field() -> list[EnumType]: # type: ignore[reportGeneralTypeIssues]
return [
EnumType(
key=e,
label=e.label,
)
for e in Enum
]

return list[EnumType], _field


def generate_type_for_enums():
enum_fields = [
(
enum_field_name,
*_enum_type(enum_field_name, enum),
)
for enum_field_name, enum in ENUM_TO_STRAWBERRY_ENUM_MAP.items()
]
return strawberry.type(
dataclasses.make_dataclass(
"AppEnumCollection",
enum_fields,
)
)


AppEnumCollection = generate_type_for_enums()
14 changes: 14 additions & 0 deletions main/graphql/permissions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import typing

from asgiref.sync import sync_to_async
from strawberry.permission import BasePermission
from strawberry.types import Info


class IsAuthenticated(BasePermission):
message = "User is not authenticated"

@sync_to_async
def has_permission(self, source: typing.Any, info: Info, **_) -> bool:
user = info.context.request.user
return bool(user and user.is_authenticated)
69 changes: 69 additions & 0 deletions main/graphql/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import strawberry
from strawberry.django.views import AsyncGraphQLView

from user import mutations as user_mutations
from user import queries as user_queries

from .context import GraphQLContext
from .dataloaders import GlobalDataLoader
from .enums import AppEnumCollection, AppEnumCollectionData
from .permissions import IsAuthenticated


class CustomAsyncGraphQLView(AsyncGraphQLView):
async def get_context(self, *args, **kwargs) -> GraphQLContext:
return GraphQLContext(
*args,
**kwargs,
dl=GlobalDataLoader(),
)


@strawberry.type
class PublicQuery(
user_queries.PublicQuery,
):
id: strawberry.ID = strawberry.ID("public")


@strawberry.type
class PrivateQuery(
user_queries.PrivateQuery,
):
id: strawberry.ID = strawberry.ID("private")


@strawberry.type
class PublicMutation(
user_mutations.PublicMutation,
):
id: strawberry.ID = strawberry.ID("public")


@strawberry.type
class PrivateMutation:
id: strawberry.ID = strawberry.ID("private")


@strawberry.type
class Query:
public: PublicQuery = strawberry.field(resolver=lambda: PublicQuery())
private: PrivateQuery = strawberry.field(permission_classes=[IsAuthenticated], resolver=lambda: PrivateQuery())
enums: AppEnumCollection = strawberry.field( # type: ignore[reportGeneralTypeIssues]
resolver=lambda: AppEnumCollectionData()
)


@strawberry.type
class Mutation:
public: PublicMutation = strawberry.field(resolver=lambda: PublicMutation())
private: PrivateMutation = strawberry.field(
resolver=lambda: PrivateMutation(),
permission_classes=[IsAuthenticated],
)


schema = strawberry.Schema(
query=Query,
mutation=Mutation,
)
41 changes: 41 additions & 0 deletions main/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,18 +115,59 @@
"user",
"common",
"content",
"rest_framework",
"corsheaders",
]

MIDDLEWARE = [
"django.middleware.security.SecurityMiddleware",
"django.contrib.sessions.middleware.SessionMiddleware",
"corsheaders.middleware.CorsMiddleware",
"django.middleware.common.CommonMiddleware",
"django.middleware.csrf.CsrfViewMiddleware",
"django.contrib.auth.middleware.AuthenticationMiddleware",
"django.contrib.messages.middleware.MessageMiddleware",
"django.middleware.clickjacking.XFrameOptionsMiddleware",
]

# CORS
if not env("DJANGO_CORS_ORIGIN_REGEX_WHITELIST"):
CORS_ORIGIN_ALLOW_ALL = True
else:
# Example ^https://[\w-]+\.mapswipe\.org$
CORS_ORIGIN_REGEX_WHITELIST = env("DJANGO_CORS_ORIGIN_REGEX_WHITELIST")

CORS_ALLOW_CREDENTIALS = True
CORS_URLS_REGEX = r"(^/media/.*$)|(^/graphql/$)"
CORS_ALLOW_METHODS = (
"DELETE",
"GET",
"OPTIONS",
"PATCH",
"POST",
"PUT",
)

CORS_ALLOW_HEADERS = (
"accept",
"accept-encoding",
"authorization",
"content-type",
"dnt",
"origin",
"user-agent",
"x-csrftoken",
"x-requested-with",
"sentry-trace",
)


# Strawberry
# -- Pagination
STRAWBERRY_ENUM_TO_STRAWBERRY_ENUM_MAP = "main.graphql.enums.ENUM_TO_STRAWBERRY_ENUM_MAP"
STRAWBERRY_DEFAULT_PAGINATION_LIMIT = 50
STRAWBERRY_MAX_PAGINATION_LIMIT = 100

ROOT_URLCONF = "main.urls"

TEMPLATES = [
Expand Down
14 changes: 14 additions & 0 deletions main/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,25 @@
from django.conf.urls.static import static
from django.contrib import admin
from django.urls import path
from django.views.decorators.csrf import csrf_exempt

from main.graphql.schema import CustomAsyncGraphQLView
from main.graphql.schema import schema as graphql_schema

urlpatterns = [
path("admin/", admin.site.urls),
path(
"graphql/",
csrf_exempt(
CustomAsyncGraphQLView.as_view(
schema=graphql_schema,
graphiql=False,
)
),
),
]
if settings.DEBUG:
urlpatterns.append(path("graphiql/", csrf_exempt(CustomAsyncGraphQLView.as_view(schema=graphql_schema))))

# Static and media file URLs
urlpatterns += static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT)
Expand Down
Loading
Loading