|
2 | 2 | import requests |
3 | 3 | import boto3 |
4 | 4 | import os |
| 5 | +import io |
| 6 | + |
| 7 | +from streamlit_cropper import st_cropper |
| 8 | +from PIL import Image, ImageDraw |
5 | 9 |
|
6 | 10 |
|
7 | 11 | ALB_URL = os.environ.get('ALB_URL') |
|
12 | 16 | s3 = boto3.client('s3') |
13 | 17 |
|
14 | 18 | st.set_page_config( |
15 | | - page_title='Gen AI - Image Variation', |
| 19 | + page_title='Gen AI - Image Replace', |
16 | 20 | page_icon = 'images/aws_favi.png', |
17 | 21 | # layout = 'wide' |
18 | 22 | ) |
19 | 23 | st.title('이미지 교체') |
20 | 24 | st.write('주변 배경과 일치하도록 변경하여 이미지를 수정합니다.') |
21 | 25 |
|
| 26 | +if "mask_enable" not in st.session_state: |
| 27 | + st.session_state.mask_enable = False |
| 28 | + |
22 | 29 | uploaded_file = st.file_uploader("파일을 선택하세요", type=['png', 'jpg']) |
23 | 30 | if uploaded_file is not None: |
24 | | - bytes_data = uploaded_file.getvalue() |
25 | | - st.image(bytes_data) |
| 31 | + # print(st.session_state.mask_enable) |
| 32 | + st.checkbox("이미지 마스크 지정", key="mask_enable") |
| 33 | + if st.session_state.mask_enable: |
| 34 | + img = Image.open(uploaded_file) |
| 35 | + width, height = img.size |
26 | 36 |
|
27 | | -m_input = st.text_area( |
28 | | - '이미지에서 남기고 싶은 오브젝트를 서술합니다 예) car, phone, bag', |
29 | | - '', |
30 | | - # height=100 |
31 | | -) |
32 | | -q_input = st.text_area( |
33 | | - '남기고 싶은 오프젝트 이외에 배경에 대해서 정의합니다', |
34 | | - '', |
35 | | - # height=100 |
36 | | -) |
| 37 | + cropped_box = st_cropper( |
| 38 | + img, |
| 39 | + realtime_update=True, |
| 40 | + box_color='#0000FF', |
| 41 | + aspect_ratio=None, |
| 42 | + return_type='box' |
| 43 | + ) |
| 44 | + |
| 45 | + left = cropped_box['left'] |
| 46 | + top = cropped_box['top'] |
| 47 | + right = left + cropped_box['width'] |
| 48 | + bottom = top + cropped_box['height'] |
| 49 | + shape = (left, top, right, bottom) |
| 50 | + |
| 51 | + masked_image = Image.new('RGB', (width, height), color=(255, 255, 255)) |
| 52 | + temp_image = ImageDraw.Draw(masked_image) |
| 53 | + temp_image.rectangle(shape, fill=(0, 0, 0)) |
| 54 | + |
| 55 | + # st.image(masked_image) |
| 56 | + |
| 57 | + st.write('이미지에서 남기고 싶은 오브젝트') |
| 58 | + cropped_image = img.crop(shape) |
| 59 | + _ = cropped_image.thumbnail((150,150)) |
| 60 | + st.image(cropped_image) |
| 61 | + else: |
| 62 | + bytes_data = uploaded_file.getvalue() |
| 63 | + st.image(bytes_data) |
| 64 | + |
| 65 | + m_input = st.text_area( |
| 66 | + '이미지에서 남기고 싶은 오브젝트를 지정합니다 예) car, phone, bag', |
| 67 | + '', |
| 68 | + # height=100 |
| 69 | + ) |
| 70 | + |
| 71 | + q_input = st.text_area( |
| 72 | + '지정한 오프젝트가 보일 배경에 대해서 정의합니다', |
| 73 | + '', |
| 74 | + # height=100 |
| 75 | + ) |
37 | 76 |
|
38 | 77 | with st.form('submit_form', clear_on_submit=True): |
39 | 78 | submitted = st.form_submit_button('Submit') |
40 | 79 | if submitted: |
41 | 80 | with st.spinner('Loading...'): |
| 81 | + uploaded_file.seek(0) |
| 82 | + s3.upload_fileobj( |
| 83 | + uploaded_file, |
| 84 | + BUCKET_NAME, |
| 85 | + f'images/{uploaded_file.name}' |
| 86 | + ) |
| 87 | + |
| 88 | + if st.session_state.mask_enable: |
| 89 | + in_mem_file = io.BytesIO() |
| 90 | + masked_image.save(in_mem_file, format='PNG') |
| 91 | + in_mem_file.seek(0) |
| 92 | + |
| 93 | + file_name, ext = os.path.splitext(uploaded_file.name) |
| 94 | + masked_image_name = f'images/masked_{file_name}.png' |
42 | 95 | s3.upload_fileobj( |
43 | | - uploaded_file, |
| 96 | + in_mem_file, |
44 | 97 | BUCKET_NAME, |
45 | | - f'images/{uploaded_file.name}' |
| 98 | + masked_image_name |
46 | 99 | ) |
47 | 100 |
|
| 101 | + data = { |
| 102 | + 'name': f'images/{uploaded_file.name}', |
| 103 | + 'prompt': q_input, |
| 104 | + 'mask_image': masked_image_name |
| 105 | + } |
| 106 | + else: |
48 | 107 | data = { |
49 | 108 | 'name': f'images/{uploaded_file.name}', |
50 | 109 | 'prompt': q_input, |
51 | 110 | 'mask_prompt': m_input |
52 | 111 | } |
53 | | - response = requests.post(API_URL, json=data) |
54 | | - result = response.text |
| 112 | + |
| 113 | + response = requests.post(API_URL, json=data) |
| 114 | + result = response.text |
55 | 115 |
|
56 | | - print(result) |
57 | | - image_object = s3.get_object(Bucket=BUCKET_NAME, Key=f'images/{result}') |
58 | | - st.image(image_object['Body'].read()) |
| 116 | + print(result) |
| 117 | + image_object = s3.get_object(Bucket=BUCKET_NAME, Key=f'images/{result}') |
| 118 | + st.image(image_object['Body'].read()) |
59 | 119 |
|
0 commit comments