diff --git a/backend/apps/api/internal/extensions/__init__.py b/backend/apps/api/internal/extensions/__init__.py new file mode 100644 index 0000000000..4b350753ac --- /dev/null +++ b/backend/apps/api/internal/extensions/__init__.py @@ -0,0 +1 @@ +"""Strawberry extensions.""" diff --git a/backend/apps/api/internal/extensions/cache.py b/backend/apps/api/internal/extensions/cache.py new file mode 100644 index 0000000000..1d2396ad12 --- /dev/null +++ b/backend/apps/api/internal/extensions/cache.py @@ -0,0 +1,106 @@ +"""Strawberry cache extension.""" + +import hashlib +import json +from functools import lru_cache + +from django.conf import settings +from django.core.cache import cache +from django.core.serializers.json import DjangoJSONEncoder +from strawberry.extensions import SchemaExtension +from strawberry.permission import PermissionExtension +from strawberry.schema import Schema +from strawberry.utils.str_converters import to_camel_case + + +@lru_cache(maxsize=1) +def get_protected_fields(schema: Schema) -> tuple[str, ...]: + """Get protected field names. + + Args: + schema (Schema): The GraphQL schema. + + Returns: + tuple[str, ...]: Tuple of protected field names in camelCase. + + """ + return tuple( + to_camel_case(field.name) + for field in getattr( + getattr(schema.schema_converter.type_map.get("Query"), "definition", None), + "fields", + (), + ) + if any(isinstance(ext, PermissionExtension) for ext in field.extensions) + ) + + +def generate_key(field_name: str, field_args: dict) -> str: + """Generate a unique cache key for a query. + + Args: + field_name (str): The GraphQL field name. + field_args (dict): The field's arguments. + + Returns: + str: The unique cache key. + + """ + key = f"{field_name}:{json.dumps(field_args, cls=DjangoJSONEncoder, sort_keys=True)}" + return f"{settings.GRAPHQL_RESOLVER_CACHE_PREFIX}-{hashlib.sha256(key.encode()).hexdigest()}" + + +def invalidate_cache(field_name: str, field_args: dict) -> bool: + """Invalidate a specific GraphQL query from the resolver cache. + + Args: + field_name: The GraphQL field name (e.g., 'getProgram'). + field_args: The field's arguments as a dict (e.g., {'programKey': 'my-program'}). + + Returns: + True if cache was invalidated, False if key didn't exist. + + """ + return cache.delete(generate_key(field_name, field_args)) + + +def invalidate_program_cache(program_key: str) -> None: + """Invalidate all GraphQL caches related to a program. + + Args: + program_key: The program's key identifier. + + """ + invalidate_cache("getProgram", {"programKey": program_key}) + invalidate_cache("getProgramModules", {"programKey": program_key}) + + +def invalidate_module_cache(module_key: str, program_key: str) -> None: + """Invalidate all GraphQL caches related to a module. + + Args: + module_key: The module's key identifier. + program_key: The program's key identifier. + + """ + invalidate_cache("getModule", {"moduleKey": module_key, "programKey": program_key}) + invalidate_program_cache(program_key) + + +class CacheExtension(SchemaExtension): + """Cache extension.""" + + def resolve(self, _next, root, info, *args, **kwargs): + """Wrap the resolver to provide caching.""" + if ( + info.field_name.startswith("__") + or info.parent_type.name != "Query" + or info.field_name in get_protected_fields(self.execution_context.schema) + ): + return _next(root, info, *args, **kwargs) + + return cache.get_or_set( + generate_key(info.field_name, kwargs), + lambda: _next(root, info, *args, **kwargs), + settings.GRAPHQL_RESOLVER_CACHE_TIME_SECONDS, + ) diff --git a/backend/apps/common/extensions.py b/backend/apps/common/extensions.py deleted file mode 100644 index 21e9373950..0000000000 --- a/backend/apps/common/extensions.py +++ /dev/null @@ -1,67 +0,0 @@ -"""Strawberry extensions.""" - -import hashlib -import json -from functools import lru_cache - -from django.conf import settings -from django.core.cache import cache -from strawberry.extensions import SchemaExtension -from strawberry.permission import PermissionExtension -from strawberry.schema import Schema -from strawberry.utils.str_converters import to_camel_case - - -@lru_cache(maxsize=1) -def get_protected_fields(schema: Schema) -> tuple[str, ...]: - """Get protected field names. - - Args: - schema (Schema): The GraphQL schema. - - Returns: - tuple[str, ...]: Tuple of protected field names in camelCase. - - """ - query_type = schema.schema_converter.type_map.get("Query") - fields = getattr(getattr(query_type, "definition", None), "fields", ()) - return tuple( - to_camel_case(field.name) - for field in fields - if any(isinstance(ext, PermissionExtension) for ext in field.extensions) - ) - - -class CacheExtension(SchemaExtension): - """CacheExtension class.""" - - def generate_key(self, field_name: str, field_args: dict) -> str: - """Generate a unique cache key for a query. - - Args: - field_name (str): The GraphQL field name. - field_args (dict): The field's arguments. - - Returns: - str: The unique cache key. - - """ - key = f"{field_name}:{json.dumps(field_args, sort_keys=True)}" - return ( - f"{settings.GRAPHQL_RESOLVER_CACHE_PREFIX}-{hashlib.sha256(key.encode()).hexdigest()}" - ) - - def resolve(self, _next, root, info, *args, **kwargs): - """Wrap the resolver to provide caching.""" - if ( - info.field_name.startswith("__") - or info.parent_type.name != "Query" - or info.field_name in get_protected_fields(self.execution_context.schema) - ): - return _next(root, info, *args, **kwargs) - - return cache.get_or_set( - self.generate_key(info.field_name, kwargs), - lambda: _next(root, info, *args, **kwargs), - settings.GRAPHQL_RESOLVER_CACHE_TIME_SECONDS, - ) diff --git a/backend/apps/mentorship/api/internal/mutations/module.py b/backend/apps/mentorship/api/internal/mutations/module.py index 96e95c3dd4..2c15cb797a 100644 --- a/backend/apps/mentorship/api/internal/mutations/module.py +++ b/backend/apps/mentorship/api/internal/mutations/module.py @@ -8,6 +8,7 @@ from django.db import transaction from django.utils import timezone +from apps.api.internal.extensions.cache import invalidate_module_cache, invalidate_program_cache from apps.github.models import User as GithubUser from apps.mentorship.api.internal.nodes.module import ( CreateModuleInput, @@ -38,6 +39,7 @@ def resolve_mentors_from_logins(logins: list[str]) -> set[Mentor]: msg = f"GitHub user '{login}' not found." logger.warning(msg, exc_info=True) raise ValueError(msg) from e + return mentors @@ -119,6 +121,8 @@ def create_module(self, info: strawberry.Info, input_data: CreateModuleInput) -> mentors_to_set.add(creator_as_mentor) module.mentors.set(list(mentors_to_set)) + transaction.on_commit(lambda: invalidate_program_cache(program.key)) + return module @strawberry.mutation(permission_classes=[IsAuthenticated]) @@ -330,6 +334,7 @@ def update_module(self, info: strawberry.Info, input_data: UpdateModuleInput) -> module = Module.objects.select_related("program").get( key=input_data.key, program__key=input_data.program_key ) + old_module_key = module.key except Module.DoesNotExist as e: raise ObjectDoesNotExist(msg=MODULE_NOT_FOUND_MSG) from e @@ -400,4 +405,13 @@ def update_module(self, info: strawberry.Info, input_data: UpdateModuleInput) -> module.program.save(update_fields=["experience_levels"]) + program_key = module.program.key + + def _invalidate(): + invalidate_module_cache(old_module_key, program_key) + if module.key != old_module_key: + invalidate_module_cache(module.key, program_key) + + transaction.on_commit(_invalidate) + return module diff --git a/backend/apps/mentorship/api/internal/mutations/program.py b/backend/apps/mentorship/api/internal/mutations/program.py index 0905e28c53..88bba96ad8 100644 --- a/backend/apps/mentorship/api/internal/mutations/program.py +++ b/backend/apps/mentorship/api/internal/mutations/program.py @@ -6,6 +6,7 @@ from django.core.exceptions import ObjectDoesNotExist, PermissionDenied, ValidationError from django.db import transaction +from apps.api.internal.extensions.cache import invalidate_program_cache from apps.mentorship.api.internal.mutations.module import resolve_mentors_from_logins from apps.mentorship.api.internal.nodes.enum import ProgramStatusEnum from apps.mentorship.api.internal.nodes.program import ( @@ -76,6 +77,7 @@ def update_program(self, info: strawberry.Info, input_data: UpdateProgramInput) try: program = Program.objects.get(key=input_data.key) + old_key = program.key except Program.DoesNotExist as err: msg = f"Program with key '{input_data.key}' not found." logger.warning(msg, exc_info=True) @@ -133,6 +135,13 @@ def update_program(self, info: strawberry.Info, input_data: UpdateProgramInput) admins_to_set = resolve_mentors_from_logins(input_data.admin_logins) program.admins.set(admins_to_set) + def _invalidate(): + invalidate_program_cache(old_key) + if program.key != old_key: + invalidate_program_cache(program.key) + + transaction.on_commit(_invalidate) + return program @strawberry.mutation(permission_classes=[IsAuthenticated]) @@ -161,6 +170,8 @@ def update_program_status( program.status = input_data.status.value program.save() + transaction.on_commit(lambda: invalidate_program_cache(program.key)) + logger.info("Updated status of program '%s' to '%s'", program.key, program.status) return program diff --git a/backend/settings/graphql.py b/backend/settings/graphql.py index 6020741d74..39eb24ab67 100644 --- a/backend/settings/graphql.py +++ b/backend/settings/graphql.py @@ -4,9 +4,9 @@ from strawberry.extensions import QueryDepthLimiter from strawberry_django.optimizer import DjangoOptimizerExtension +from apps.api.internal.extensions.cache import CacheExtension from apps.api.internal.mutations import ApiMutations from apps.api.internal.queries import ApiKeyQueries -from apps.common.extensions import CacheExtension from apps.github.api.internal.queries import GithubQuery from apps.mentorship.api.internal.mutations import ( ModuleMutation, diff --git a/backend/tests/apps/api/internal/extensions/__init__.py b/backend/tests/apps/api/internal/extensions/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/backend/tests/apps/common/extensions_test.py b/backend/tests/apps/api/internal/extensions/cache_test.py similarity index 56% rename from backend/tests/apps/common/extensions_test.py rename to backend/tests/apps/api/internal/extensions/cache_test.py index 3b7bb1c982..d1d0056129 100644 --- a/backend/tests/apps/common/extensions_test.py +++ b/backend/tests/apps/api/internal/extensions/cache_test.py @@ -1,51 +1,67 @@ """Tests for CacheExtension.""" +import datetime from unittest.mock import MagicMock, patch +from uuid import UUID import pytest +from django.conf import settings from strawberry.permission import PermissionExtension -from apps.common.extensions import CacheExtension, get_protected_fields +from apps.api.internal.extensions.cache import ( + CacheExtension, + generate_key, + get_protected_fields, + invalidate_module_cache, + invalidate_program_cache, +) class TestGenerateKey: - """Test cases for the generate_key method.""" + """Test cases for the generate_key function.""" - @pytest.fixture - def extension(self): - """Return a CacheExtension instance.""" - return CacheExtension() - - def test_creates_deterministic_hash(self, extension): + def test_creates_deterministic_hash(self): """Test that generate_key creates a deterministic hash key.""" - key1 = extension.generate_key("chapter", {"key": "germany"}) - key2 = extension.generate_key("chapter", {"key": "germany"}) + key1 = generate_key("chapter", {"key": "germany"}) + key2 = generate_key("chapter", {"key": "germany"}) assert key1 == key2 assert key1.startswith("graphql-") assert len(key1.split("-")[-1]) == 64 # SHA256 hex digest length - def test_differs_for_different_field_names(self, extension): + def test_differs_for_different_field_names(self): """Test that different field names produce different keys.""" - key1 = extension.generate_key("chapter", {"key": "germany"}) - key2 = extension.generate_key("project", {"key": "germany"}) + key1 = generate_key("chapter", {"key": "germany"}) + key2 = generate_key("project", {"key": "germany"}) assert key1 != key2 - def test_differs_for_different_args(self, extension): + def test_differs_for_different_args(self): """Test that different arguments produce different keys.""" - key1 = extension.generate_key("chapter", {"key": "germany"}) - key2 = extension.generate_key("chapter", {"key": "canada"}) + key1 = generate_key("chapter", {"key": "germany"}) + key2 = generate_key("chapter", {"key": "canada"}) assert key1 != key2 - def test_sorts_args_for_consistency(self, extension): + def test_sorts_args_for_consistency(self): """Test that argument order doesn't affect the key.""" - key1 = extension.generate_key("chapter", {"a": "1", "b": "2"}) - key2 = extension.generate_key("chapter", {"b": "2", "a": "1"}) + key1 = generate_key("chapter", {"a": "1", "b": "2"}) + key2 = generate_key("chapter", {"b": "2", "a": "1"}) assert key1 == key2 + def test_serializes_datetime_and_uuid_args(self): + """Test that generate_key handles datetime and UUID in args without error.""" + dt = datetime.datetime(2025, 1, 15, 12, 0, 0, tzinfo=datetime.UTC) + uid = UUID("550e8400-e29b-41d4-a716-446655440000") + field_args = {"at": dt, "id": uid} + + key1 = generate_key("someField", field_args) + key2 = generate_key("someField", field_args) + + assert key1 == key2 + assert key1.startswith(f"{settings.GRAPHQL_RESOLVER_CACHE_PREFIX}-") + class TestGetProtectedFields: """Test cases for the get_protected_fields function.""" @@ -90,7 +106,9 @@ class TestResolve: @pytest.fixture(autouse=True) def mock_protected_fields(self): """Patch get_protected_fields for all tests.""" - with patch("apps.common.extensions.get_protected_fields", return_value=("apiKeys",)): + with patch( + "apps.api.internal.extensions.cache.get_protected_fields", return_value=("apiKeys",) + ): yield @pytest.fixture @@ -140,7 +158,7 @@ def test_skips_protected_fields(self, extension, mock_info, mock_next): mock_next.assert_called_once() assert result == mock_next.return_value - @patch("apps.common.extensions.cache") + @patch("apps.api.internal.extensions.cache.cache") def test_returns_cached_result_on_hit(self, mock_cache, extension, mock_info, mock_next): """Test that cached result is returned on cache hit.""" cached_result = {"name": "Cached OWASP"} @@ -152,7 +170,7 @@ def test_returns_cached_result_on_hit(self, mock_cache, extension, mock_info, mo mock_cache.get_or_set.assert_called_once() mock_next.assert_not_called() - @patch("apps.common.extensions.cache") + @patch("apps.api.internal.extensions.cache.cache") def test_caches_result_on_miss(self, mock_cache, extension, mock_info, mock_next): """Test that result is cached on cache miss.""" mock_cache.get_or_set.side_effect = lambda _key, default, _timeout: default() @@ -161,3 +179,41 @@ def test_caches_result_on_miss(self, mock_cache, extension, mock_info, mock_next mock_next.assert_called_once() mock_cache.get_or_set.assert_called_once() + + +class TestInvalidationHelpers: + """Test cases for invalidation helper functions.""" + + @patch("apps.api.internal.extensions.cache.cache") + @patch("apps.api.internal.extensions.cache.generate_key") + def test_invalidate_program_cache_uses_camel_case_keys(self, mock_generate_key, mock_cache): + """Test that invalidate_program_cache uses correct camelCase keys.""" + mock_generate_key.side_effect = lambda name, _args: f"{name}-hashed" + + invalidate_program_cache("my-program") + + # Verify calls to generate_key use camelCase 'programKey' + assert mock_generate_key.call_count == 2 + mock_generate_key.assert_any_call("getProgram", {"programKey": "my-program"}) + mock_generate_key.assert_any_call("getProgramModules", {"programKey": "my-program"}) + + assert mock_cache.delete.call_count == 2 + mock_cache.delete.assert_any_call("getProgram-hashed") + mock_cache.delete.assert_any_call("getProgramModules-hashed") + + @patch("apps.api.internal.extensions.cache.cache") + @patch("apps.api.internal.extensions.cache.generate_key") + def test_invalidate_module_cache_uses_camel_case_keys(self, mock_generate_key, mock_cache): + """Test that invalidate_module_cache uses correct camelCase keys.""" + mock_generate_key.side_effect = lambda name, _args: f"{name}-hashed" + + invalidate_module_cache("module-1", "program-1") + + assert mock_generate_key.call_count == 3 + mock_generate_key.assert_any_call( + "getModule", {"moduleKey": "module-1", "programKey": "program-1"} + ) + mock_generate_key.assert_any_call("getProgram", {"programKey": "program-1"}) + mock_generate_key.assert_any_call("getProgramModules", {"programKey": "program-1"}) + + assert mock_cache.delete.call_count == 3 diff --git a/frontend/__tests__/unit/pages/CreateProgram.test.tsx b/frontend/__tests__/unit/pages/CreateProgram.test.tsx index 9cdac1d5a6..1165d4a9a6 100644 --- a/frontend/__tests__/unit/pages/CreateProgram.test.tsx +++ b/frontend/__tests__/unit/pages/CreateProgram.test.tsx @@ -160,19 +160,23 @@ describe('CreateProgramPage (comprehensive tests)', () => { fireEvent.submit(screen.getByText('Save').closest('form')) await waitFor(() => { - expect(mockCreateProgram).toHaveBeenCalledWith({ - variables: { - input: { - name: 'Test Program', - description: 'A description', - menteesLimit: 0, - startedAt: '2025-01-01', - endedAt: '2025-12-31', - tags: ['tag1', 'tag2'], - domains: ['domain1', 'domain2'], + expect(mockCreateProgram).toHaveBeenCalledWith( + expect.objectContaining({ + variables: { + input: { + name: 'Test Program', + description: 'A description', + menteesLimit: 0, + startedAt: '2025-01-01', + endedAt: '2025-12-31', + tags: ['tag1', 'tag2'], + domains: ['domain1', 'domain2'], + }, }, - }, - }) + awaitRefetchQueries: true, + refetchQueries: expect.any(Array), + }) + ) expect(mockRouterPush).toHaveBeenCalledWith('/my/mentorship') }) diff --git a/frontend/src/app/mentorship/programs/[programKey]/modules/[moduleKey]/page.tsx b/frontend/src/app/mentorship/programs/[programKey]/modules/[moduleKey]/page.tsx index 476ff243fb..43b474052f 100644 --- a/frontend/src/app/mentorship/programs/[programKey]/modules/[moduleKey]/page.tsx +++ b/frontend/src/app/mentorship/programs/[programKey]/modules/[moduleKey]/page.tsx @@ -19,6 +19,7 @@ const ModuleDetailsPage = () => { error, loading: isLoading, } = useQuery(GetProgramAdminsAndModulesDocument, { + fetchPolicy: 'cache-and-network', variables: { programKey, moduleKey, @@ -34,7 +35,7 @@ const ModuleDetailsPage = () => { } }, [error]) - if (isLoading) return + if (isLoading && !data) return if (error) { return ( diff --git a/frontend/src/app/mentorship/programs/[programKey]/page.tsx b/frontend/src/app/mentorship/programs/[programKey]/page.tsx index 9f4f571c81..f17e47dd4f 100644 --- a/frontend/src/app/mentorship/programs/[programKey]/page.tsx +++ b/frontend/src/app/mentorship/programs/[programKey]/page.tsx @@ -18,8 +18,9 @@ const ProgramDetailsPage = () => { error: graphQLRequestError, loading: isLoading, } = useQuery(GetProgramAndModulesDocument, { - variables: { programKey }, + fetchPolicy: 'cache-and-network', skip: !programKey, + variables: { programKey }, }) const program = data?.getProgram @@ -29,9 +30,9 @@ const ProgramDetailsPage = () => { if (graphQLRequestError) { handleAppError(graphQLRequestError) } - }, [graphQLRequestError, programKey]) + }, [graphQLRequestError]) - if (isLoading) return + if (isLoading && !data) return if (graphQLRequestError) { return ( diff --git a/frontend/src/app/my/mentorship/programs/[programKey]/edit/page.tsx b/frontend/src/app/my/mentorship/programs/[programKey]/edit/page.tsx index dde49d9f1a..940acde789 100644 --- a/frontend/src/app/my/mentorship/programs/[programKey]/edit/page.tsx +++ b/frontend/src/app/my/mentorship/programs/[programKey]/edit/page.tsx @@ -9,12 +9,16 @@ import { useState, useEffect } from 'react' import { ErrorDisplay, handleAppError } from 'app/global-error' import { ProgramStatusEnum } from 'types/__generated__/graphql' import { UpdateProgramDocument } from 'types/__generated__/programsMutations.generated' -import { GetProgramDetailsDocument } from 'types/__generated__/programsQueries.generated' +import { + GetMyProgramsDocument, + GetProgramDetailsDocument, +} from 'types/__generated__/programsQueries.generated' import type { ExtendedSession } from 'types/auth' import { formatDateForInput } from 'utils/dateFormatter' import { parseCommaSeparated } from 'utils/parser' import LoadingSpinner from 'components/LoadingSpinner' import ProgramForm from 'components/ProgramForm' + const EditProgramPage = () => { const router = useRouter() const { programKey } = useParams<{ programKey: string }>() @@ -104,7 +108,11 @@ const EditProgramPage = () => { status: formData.status, } - const result = await updateProgram({ variables: { input } }) + const result = await updateProgram({ + awaitRefetchQueries: true, + refetchQueries: [{ query: GetMyProgramsDocument }], + variables: { input }, + }) const updatedProgramKey = result.data?.updateProgram?.key || programKey addToast({ diff --git a/frontend/src/app/my/mentorship/programs/[programKey]/modules/[moduleKey]/edit/page.tsx b/frontend/src/app/my/mentorship/programs/[programKey]/modules/[moduleKey]/edit/page.tsx index bad0b4cc47..3a925886f3 100644 --- a/frontend/src/app/my/mentorship/programs/[programKey]/modules/[moduleKey]/edit/page.tsx +++ b/frontend/src/app/my/mentorship/programs/[programKey]/modules/[moduleKey]/edit/page.tsx @@ -9,6 +9,7 @@ import { ErrorDisplay, handleAppError } from 'app/global-error' import { ExperienceLevelEnum } from 'types/__generated__/graphql' import { UpdateModuleDocument } from 'types/__generated__/moduleMutations.generated' import { GetProgramAdminsAndModulesDocument } from 'types/__generated__/moduleQueries.generated' +import { GetProgramAndModulesDocument } from 'types/__generated__/programsQueries.generated' import type { ExtendedSession } from 'types/auth' import type { ModuleFormData } from 'types/mentorship' import { formatDateForInput } from 'utils/dateFormatter' @@ -111,7 +112,11 @@ const EditModulePage = () => { tags: parseCommaSeparated(formData.tags), } - const result = await updateModule({ variables: { input } }) + const result = await updateModule({ + awaitRefetchQueries: true, + refetchQueries: [{ query: GetProgramAndModulesDocument, variables: { programKey } }], + variables: { input }, + }) const updatedModuleKey = result.data?.updateModule?.key || moduleKey addToast({ diff --git a/frontend/src/app/my/mentorship/programs/[programKey]/modules/[moduleKey]/page.tsx b/frontend/src/app/my/mentorship/programs/[programKey]/modules/[moduleKey]/page.tsx index eb1f1aa5b1..7475f3444d 100644 --- a/frontend/src/app/my/mentorship/programs/[programKey]/modules/[moduleKey]/page.tsx +++ b/frontend/src/app/my/mentorship/programs/[programKey]/modules/[moduleKey]/page.tsx @@ -21,6 +21,7 @@ const ModuleDetailsPage = () => { error, loading: isLoading, } = useQuery(GetProgramAdminsAndModulesDocument, { + fetchPolicy: 'cache-and-network', variables: { programKey, moduleKey, @@ -36,7 +37,7 @@ const ModuleDetailsPage = () => { } }, [data, error]) - if (isLoading) return + if (isLoading && !data) return if (!module) { return ( diff --git a/frontend/src/app/my/mentorship/programs/[programKey]/modules/create/page.tsx b/frontend/src/app/my/mentorship/programs/[programKey]/modules/create/page.tsx index 54dea5d0a5..d61ad014b1 100644 --- a/frontend/src/app/my/mentorship/programs/[programKey]/modules/create/page.tsx +++ b/frontend/src/app/my/mentorship/programs/[programKey]/modules/create/page.tsx @@ -4,7 +4,7 @@ import { addToast } from '@heroui/toast' import { useRouter, useParams } from 'next/navigation' import { useSession } from 'next-auth/react' import React, { useEffect, useState } from 'react' -import { ErrorDisplay, handleAppError } from 'app/global-error' +import { ErrorDisplay } from 'app/global-error' import { ExperienceLevelEnum } from 'types/__generated__/graphql' import { CreateModuleDocument } from 'types/__generated__/moduleMutations.generated' import { @@ -100,30 +100,9 @@ const CreateModulePage = () => { } await createModule({ + awaitRefetchQueries: true, + refetchQueries: [{ query: GetProgramAndModulesDocument, variables: { programKey } }], variables: { input }, - update: (cache, { data: mutationData }) => { - const created = mutationData?.createModule - if (!created) return - try { - const existing = cache.readQuery({ - query: GetProgramAndModulesDocument, - variables: { programKey }, - }) - if (existing?.getProgram && existing?.getProgramModules) { - cache.writeQuery({ - query: GetProgramAndModulesDocument, - variables: { programKey }, - data: { - getProgram: existing.getProgram, - getProgramModules: [created, ...existing.getProgramModules], - }, - }) - } - } catch (error) { - handleAppError(error) - return - } - }, }) addToast({ diff --git a/frontend/src/app/my/mentorship/programs/[programKey]/page.tsx b/frontend/src/app/my/mentorship/programs/[programKey]/page.tsx index e25859a092..2ec882c417 100644 --- a/frontend/src/app/my/mentorship/programs/[programKey]/page.tsx +++ b/frontend/src/app/my/mentorship/programs/[programKey]/page.tsx @@ -4,13 +4,12 @@ import { addToast } from '@heroui/toast' import { capitalize } from 'lodash' import { useParams } from 'next/navigation' import { useSession } from 'next-auth/react' -import { useEffect, useMemo, useState } from 'react' +import { useMemo } from 'react' import { ErrorDisplay, handleAppError } from 'app/global-error' import { ProgramStatusEnum } from 'types/__generated__/graphql' import { UpdateProgramStatusDocument } from 'types/__generated__/programsMutations.generated' import { GetProgramAndModulesDocument } from 'types/__generated__/programsQueries.generated' import type { ExtendedSession } from 'types/auth' -import type { Module, Program } from 'types/mentorship' import { titleCaseWord } from 'utils/capitalize' import { formatDate } from 'utils/dateFormatter' import DetailsCard from 'components/CardDetailsPage' @@ -22,9 +21,6 @@ const ProgramDetailsPage = () => { const { data: session } = useSession() const username = (session as ExtendedSession)?.user?.login - const [program, setProgram] = useState(null) - const [modules, setModules] = useState([]) - const [updateProgram] = useMutation(UpdateProgramStatusDocument, { onError: handleAppError, }) @@ -32,10 +28,13 @@ const ProgramDetailsPage = () => { const { data, loading: isQueryLoading } = useQuery(GetProgramAndModulesDocument, { variables: { programKey }, skip: !programKey, + fetchPolicy: 'cache-and-network', notifyOnNetworkStatusChange: true, }) const isLoading = isQueryLoading + const program = data?.getProgram ?? null + const modules = data?.getProgramModules ?? [] const isAdmin = useMemo( () => !!program?.admins?.some((admin) => admin.login === username), @@ -82,14 +81,7 @@ const ProgramDetailsPage = () => { } } - useEffect(() => { - if (data?.getProgram) { - setProgram(data.getProgram) - setModules(data.getProgramModules || []) - } - }, [data]) - - if (isLoading) return + if (isLoading && !data) return if (!program && !isLoading) { return ( diff --git a/frontend/src/app/my/mentorship/programs/create/page.tsx b/frontend/src/app/my/mentorship/programs/create/page.tsx index a9cc0a9dda..ef0cba9728 100644 --- a/frontend/src/app/my/mentorship/programs/create/page.tsx +++ b/frontend/src/app/my/mentorship/programs/create/page.tsx @@ -6,6 +6,7 @@ import { useSession } from 'next-auth/react' import React, { useEffect, useState } from 'react' import { CreateProgramDocument } from 'types/__generated__/programsMutations.generated' +import { GetMyProgramsDocument } from 'types/__generated__/programsQueries.generated' import { ExtendedSession } from 'types/auth' import { parseCommaSeparated } from 'utils/parser' import LoadingSpinner from 'components/LoadingSpinner' @@ -61,7 +62,11 @@ const CreateProgramPage = () => { domains: parseCommaSeparated(formData.domains), } - await createProgram({ variables: { input } }) + await createProgram({ + awaitRefetchQueries: true, + refetchQueries: [{ query: GetMyProgramsDocument }], + variables: { input }, + }) addToast({ description: 'Program created successfully!',