Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_qwen_image_edit_pre_process_func(
vae_config = json.load(f)
vae_scale_factor = 2 ** len(vae_config["temporal_downsample"]) if "temporal_downsample" in vae_config else 8

image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2)
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2, do_convert_rgb=True)
latent_channels = vae_config.get("z_dim", 16)

def pre_process_func(
Expand Down Expand Up @@ -136,7 +136,7 @@ def get_qwen_image_edit_post_process_func(
vae_config = json.load(f)
vae_scale_factor = 2 ** len(vae_config["temporal_downsample"]) if "temporal_downsample" in vae_config else 8

image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2)
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2, do_convert_rgb=True)

def post_process_func(
images: torch.Tensor,
Expand Down Expand Up @@ -261,7 +261,7 @@ def __init__(

self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_convert_rgb=True)
self.tokenizer_max_length = 1024
# Edit prompt template - different from generation template
self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" # noqa: E501
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def get_qwen_image_edit_plus_pre_process_func(
vae_config = json.load(f)
vae_scale_factor = 2 ** len(vae_config["temporal_downsample"]) if "temporal_downsample" in vae_config else 8

image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2)
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2, do_convert_rgb=True)
latent_channels = vae_config.get("z_dim", 16)

def pre_process_func(
Expand Down Expand Up @@ -158,7 +158,7 @@ def get_qwen_image_edit_plus_post_process_func(
vae_config = json.load(f)
vae_scale_factor = 2 ** len(vae_config["temporal_downsample"]) if "temporal_downsample" in vae_config else 8

image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2)
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2, do_convert_rgb=True)

def post_process_func(
images: torch.Tensor,
Expand Down Expand Up @@ -214,7 +214,7 @@ def __init__(

self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_convert_rgb=True)
self.tokenizer_max_length = 1024
# Edit prompt template - different from generation template, supports multiple images
self.prompt_template_encode = (
Expand Down