Skip to content

Commit a3d64cc

Browse files
authored
feat: add 4k support (#36)
1 parent 69f7776 commit a3d64cc

File tree

12 files changed

+196
-23
lines changed

12 files changed

+196
-23
lines changed

.env.example

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ REPLICATE_API_TOKEN=
88
REPLICATE_USERNAME=
99
REPLICATE_MAX_TRAIN_STEPS=3000
1010
REPLICATE_NEGATIVE_PROMPT="cropped face, cover face, cover visage, mutated hands"
11+
REPLICATE_HD_VERSION_MODEL_ID=
1112
NEXT_PUBLIC_REPLICATE_INSTANCE_TOKEN=
1213
SECRET=
1314
EMAIL_FROM=

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ REPLICATE_API_TOKEN=
5757
REPLICATE_USERNAME=
5858
REPLICATE_MAX_TRAIN_STEPS=3000
5959
REPLICATE_NEGATIVE_PROMPT=
60+
REPLICATE_HD_VERSION_MODEL_ID=
6061
6162
// Replicate instance token (should be rare)
6263
NEXT_PUBLIC_REPLICATE_INSTANCE_TOKEN=
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
-- CreateEnum
2+
CREATE TYPE "HdStatus" AS ENUM ('NO', 'PENDING', 'PROCESSED');
3+
4+
-- AlterTable
5+
ALTER TABLE "Shot" ADD COLUMN "hdStatus" "HdStatus" NOT NULL DEFAULT 'NO';
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
-- AlterTable
2+
ALTER TABLE "Shot" ADD COLUMN "hdPredictionId" TEXT;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
-- AlterTable
2+
ALTER TABLE "Shot" ADD COLUMN "hdOutputUrl" TEXT;

prisma/schema.prisma

+10-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@ generator client {
77
provider = "prisma-client-js"
88
}
99

10+
enum HdStatus {
11+
NO
12+
PENDING
13+
PROCESSED
14+
}
15+
1016
model Account {
1117
id String @id @default(cuid())
1218
userId String @map("user_id")
@@ -76,7 +82,7 @@ model Project {
7682
userId String?
7783
shots Shot[]
7884
credits Int @default(100)
79-
promptWizardCredits Int @default(30)
85+
promptWizardCredits Int @default(20)
8086
Payment Payment[]
8187
}
8288

@@ -93,6 +99,9 @@ model Shot {
9399
bookmarked Boolean? @default(false)
94100
blurhash String?
95101
seed Int?
102+
hdStatus HdStatus @default(NO)
103+
hdPredictionId String?
104+
hdOutputUrl String?
96105
}
97106

98107
model Payment {

src/components/home/Pricing.tsx

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ const Pricing = () => {
7474
<b>1</b> Studio with a <b>custom trained model</b>
7575
</CheckedListItem>
7676
<CheckedListItem>
77-
<b>{process.env.NEXT_PUBLIC_STUDIO_SHOT_AMOUNT}</b> avatars
77+
<b>{process.env.NEXT_PUBLIC_STUDIO_SHOT_AMOUNT}</b> avatars 4K
7878
generation
7979
</CheckedListItem>
8080
<CheckedListItem>

src/components/projects/FormPayment.tsx

+2-2
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ const FormPayment = ({
7676
<b>1</b> Studio with a <b>custom trained model</b>
7777
</CheckedListItem>
7878
<CheckedListItem>
79-
<b>{process.env.NEXT_PUBLIC_STUDIO_SHOT_AMOUNT}</b> avatars
80-
generation (512x512 resolution)
79+
<b>{process.env.NEXT_PUBLIC_STUDIO_SHOT_AMOUNT}</b> avatars 4K
80+
generation
8181
</CheckedListItem>
8282
<CheckedListItem>
8383
<b>30</b> AI prompt assists

src/components/projects/shot/ShotCard.tsx

+95-17
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,27 @@ import { BsHeart, BsHeartFill } from "react-icons/bs";
2323
import { HiDownload } from "react-icons/hi";
2424
import { IoMdCheckmarkCircleOutline } from "react-icons/io";
2525
import { MdOutlineModelTraining } from "react-icons/md";
26+
import { Ri4KFill } from "react-icons/ri";
2627
import { useMutation, useQuery } from "react-query";
2728
import ShotImage from "./ShotImage";
2829
import { TbFaceIdError } from "react-icons/tb";
2930

31+
const getHdLabel = (shot: Shot, isHd: boolean) => {
32+
if (shot.hdStatus === "NO") {
33+
return "Generate in 4K";
34+
}
35+
36+
if (shot.hdStatus === "PENDING") {
37+
return "4K in progress";
38+
}
39+
40+
if (shot.hdStatus === "PROCESSED" && isHd) {
41+
return "Show standard resolution";
42+
}
43+
44+
return "Show 4K";
45+
};
46+
3047
const ShotCard = ({
3148
shot: initialShot,
3249
handleSeed,
@@ -37,6 +54,7 @@ const ShotCard = ({
3754
const { onCopy, hasCopied } = useClipboard(initialShot.prompt);
3855
const { query } = useRouter();
3956
const [shot, setShot] = useState(initialShot);
57+
const [isHd, setIsHd] = useState(Boolean(shot.hdOutputUrl));
4058

4159
const { mutate: bookmark, isLoading } = useMutation(
4260
`update-shot-${initialShot.id}`,
@@ -54,6 +72,19 @@ const ShotCard = ({
5472
}
5573
);
5674

75+
const { mutate: createdHd, isLoading: isCreatingHd } = useMutation(
76+
`create-hd-${initialShot.id}`,
77+
() =>
78+
axios.post<{ shot: Shot }>(
79+
`/api/projects/${query.id}/predictions/${initialShot.id}/hd`
80+
),
81+
{
82+
onSuccess: (response) => {
83+
setShot(response.data.shot);
84+
},
85+
}
86+
);
87+
5788
useQuery(
5889
`shot-${initialShot.id}`,
5990
() =>
@@ -65,10 +96,33 @@ const ShotCard = ({
6596
{
6697
refetchInterval: (data) => (data?.shot.outputUrl ? false : 5000),
6798
refetchOnWindowFocus: false,
68-
enabled: !initialShot.outputUrl,
99+
enabled: !initialShot.outputUrl && initialShot.status !== "failed",
100+
initialData: { shot: initialShot },
101+
onSuccess: (response) => {
102+
setShot(response.shot);
103+
},
104+
}
105+
);
106+
107+
useQuery(
108+
`shot-hd-${initialShot.id}`,
109+
() =>
110+
axios
111+
.get<{ shot: Shot }>(
112+
`/api/projects/${query.id}/predictions/${initialShot.id}/hd`
113+
)
114+
.then((res) => res.data),
115+
{
116+
refetchInterval: (data) =>
117+
data?.shot.hdStatus !== "PENDING" ? false : 5000,
118+
refetchOnWindowFocus: false,
119+
enabled: shot.hdStatus === "PENDING",
69120
initialData: { shot: initialShot },
70121
onSuccess: (response) => {
71122
setShot(response.shot);
123+
if (response.shot.hdOutputUrl) {
124+
setIsHd(true);
125+
}
72126
},
73127
}
74128
);
@@ -82,7 +136,7 @@ const ShotCard = ({
82136
position="relative"
83137
>
84138
{shot.outputUrl ? (
85-
<ShotImage shot={shot} />
139+
<ShotImage isHd={isHd} shot={shot} />
86140
) : (
87141
<Box>
88142
<AspectRatio ratio={1}>
@@ -104,10 +158,7 @@ const ShotCard = ({
104158
</Box>
105159
)}
106160
<Flex position="relative" p={3} flexDirection="column">
107-
<Flex alignItems="center" justifyContent="space-between">
108-
<Text color="blackAlpha.700" fontSize="xs">
109-
{formatRelative(new Date(shot.createdAt), new Date())}
110-
</Text>
161+
<Flex alignItems="center" justifyContent="flex-end">
111162
<Box>
112163
{shot.seed && shot.outputUrl && (
113164
<Tooltip hasArrow label="Re-use style">
@@ -129,17 +180,41 @@ const ShotCard = ({
129180
</Tooltip>
130181
)}
131182
{shot.outputUrl && (
132-
<IconButton
133-
size="sm"
134-
as={Link}
135-
href={shot.outputUrl}
136-
target="_blank"
137-
variant="ghost"
138-
aria-label="Download"
139-
fontSize="md"
140-
icon={<HiDownload />}
141-
/>
183+
<>
184+
<IconButton
185+
size="sm"
186+
as={Link}
187+
href={isHd ? shot.hdOutputUrl : shot.outputUrl}
188+
target="_blank"
189+
variant="ghost"
190+
aria-label="Download"
191+
fontSize="md"
192+
icon={<HiDownload />}
193+
/>
194+
<Tooltip hasArrow label={getHdLabel(shot, isHd)}>
195+
<IconButton
196+
icon={<Ri4KFill />}
197+
color={isHd ? "red.400" : "gray.600"}
198+
isLoading={shot.hdStatus === "PENDING" || isCreatingHd}
199+
onClick={() => {
200+
if (shot.hdStatus === "NO") {
201+
createdHd();
202+
} else if (
203+
shot.hdStatus === "PROCESSED" &&
204+
shot.hdOutputUrl
205+
) {
206+
setIsHd(!isHd);
207+
}
208+
}}
209+
size="sm"
210+
variant="ghost"
211+
aria-label="Make 4K"
212+
fontSize="lg"
213+
/>
214+
</Tooltip>
215+
</>
142216
)}
217+
143218
<Tooltip
144219
hasArrow
145220
label={`${shot.bookmarked ? "Remove" : "Add"} to your gallery`}
@@ -168,7 +243,10 @@ const ShotCard = ({
168243
{shot.prompt}
169244
</Text>
170245

171-
<HStack mt={4}>
246+
<HStack justifyContent="space-between" mt={4}>
247+
<Text color="beige.400" fontSize="xs">
248+
{formatRelative(new Date(shot.createdAt), new Date())}
249+
</Text>
172250
<Button
173251
rightIcon={hasCopied ? <IoMdCheckmarkCircleOutline /> : undefined}
174252
colorScheme="beige"

src/components/projects/shot/ShotImage.tsx

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import { useRouter } from "next/router";
55
import React from "react";
66
import { Controlled as ControlledZoom } from "react-medium-image-zoom";
77

8-
const ShotImage = ({ shot }: { shot: Shot }) => {
8+
const ShotImage = ({ shot, isHd = false }: { shot: Shot; isHd?: boolean }) => {
99
const { push, query } = useRouter();
1010
const { onOpen, onClose, isOpen: isZoomed } = useDisclosure();
1111

@@ -34,7 +34,7 @@ const ShotImage = ({ shot }: { shot: Shot }) => {
3434
placeholder="blur"
3535
blurDataURL={shot.blurhash || "placeholder"}
3636
alt={shot.prompt}
37-
src={shot.outputUrl!}
37+
src={isHd ? shot.hdOutputUrl! : shot.outputUrl!}
3838
width={512}
3939
height={512}
4040
unoptimized
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import replicateClient from "@/core/clients/replicate";
2+
import db from "@/core/db";
3+
import { NextApiRequest, NextApiResponse } from "next";
4+
import { getSession } from "next-auth/react";
5+
6+
const handler = async (req: NextApiRequest, res: NextApiResponse) => {
7+
const projectId = req.query.id as string;
8+
const predictionId = req.query.predictionId as string;
9+
10+
const session = await getSession({ req });
11+
12+
if (!session?.user) {
13+
return res.status(401).json({ message: "Not authenticated" });
14+
}
15+
16+
const project = await db.project.findFirstOrThrow({
17+
where: { id: projectId, userId: session.userId },
18+
});
19+
20+
let shot = await db.shot.findFirstOrThrow({
21+
where: { projectId: project.id, id: predictionId },
22+
});
23+
24+
if (req.method === "POST") {
25+
if (shot.hdStatus !== "NO") {
26+
return res.status(400).json({ message: "4K already applied" });
27+
}
28+
29+
const { data } = await replicateClient.post(
30+
`https://api.replicate.com/v1/predictions`,
31+
{
32+
input: {
33+
image: shot.outputUrl,
34+
upscale: 8,
35+
face_upsample: true,
36+
codeformer_fidelity: 1,
37+
},
38+
version: process.env.REPLICATE_HD_VERSION_MODEL_ID,
39+
}
40+
);
41+
42+
shot = await db.shot.update({
43+
where: { id: shot.id },
44+
data: { hdStatus: "PENDING", hdPredictionId: data.id },
45+
});
46+
47+
return res.json({ shot });
48+
}
49+
50+
if (req.method === "GET") {
51+
if (shot.hdStatus !== "PENDING") {
52+
return res.status(400).json({ message: "4K already applied" });
53+
}
54+
55+
const { data: prediction } = await replicateClient.get(
56+
`https://api.replicate.com/v1/predictions/${shot.hdPredictionId}`
57+
);
58+
59+
if (prediction.output) {
60+
shot = await db.shot.update({
61+
where: { id: shot.id },
62+
data: {
63+
hdStatus: "PROCESSED",
64+
hdOutputUrl: prediction.output,
65+
},
66+
});
67+
}
68+
69+
return res.json({ shot });
70+
}
71+
72+
return res.status(405).json({ message: "Method not allowed" });
73+
};
74+
75+
export default handler;

0 commit comments

Comments
 (0)