| | |
| | |
| |
|
| | |
| | |
| |
|
| | import copy |
| | import os |
| | from datetime import datetime |
| |
|
| | import gradio as gr |
| |
|
| | os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1" |
| | import tempfile |
| |
|
| | import cv2 |
| | import matplotlib.pyplot as plt |
| | import numpy as np |
| | import spaces |
| |
|
| | import torch |
| |
|
| | from moviepy.editor import ImageSequenceClip |
| | from PIL import Image |
| | from sam2.build_sam import build_sam2_video_predictor |
| |
|
| | |
# HTML banner rendered at the top of the demo page.
title = "<center><strong><font size='8'>EdgeTAM<font></strong> <a href='https://github.com/facebookresearch/EdgeTAM'><font size='6'>[GitHub]</font></a> </center>"

# Markdown usage instructions shown in the left column of the UI.
description_p = """# EdgeTAM is now in transformers, find it [here](https://huggingface.co/yonigozlan/EdgeTAM-hf)
<ol>
<li> Upload one video or click one example video</li>
<li> Click 'include' point type, select the object to segment and track</li>
<li> Click 'exclude' point type (optional), select the area you want to avoid segmenting and tracking</li>
<li> Click the 'Track' button to obtain the masked video </li>
</ol>
"""

# Example videos offered in the gr.Examples gallery (paths relative to the app root).
examples = [
    ["examples/trimmed/01_dog.mp4"],
    ["examples/trimmed/02_cups.mp4"],
    ["examples/trimmed/03_blocks.mp4"],
    ["examples/trimmed/04_coffee.mp4"],
    ["examples/trimmed/05_default_juggle.mp4"],
    ["examples/trimmed/01_breakdancer.mp4"],
    ["examples/trimmed/02_hummingbird.mp4"],
    ["examples/trimmed/03_skateboarder.mp4"],
    ["examples/trimmed/04_octopus.mp4"],
    ["examples/trimmed/05_landing_dog_soccer.mp4"],
    ["examples/trimmed/06_pingpong.mp4"],
    ["examples/trimmed/07_snowboarder.mp4"],
    ["examples/trimmed/08_driving.mp4"],
    ["examples/trimmed/09_birdcartoon.mp4"],
    ["examples/trimmed/10_cloth_magic.mp4"],
    ["examples/trimmed/11_polevault.mp4"],
    ["examples/trimmed/12_hideandseek.mp4"],
    ["examples/trimmed/13_butterfly.mp4"],
    ["examples/trimmed/14_social_dog_training.mp4"],
    ["examples/trimmed/15_cricket.mp4"],
    ["examples/trimmed/16_robotarm.mp4"],
    ["examples/trimmed/17_childrendancing.mp4"],
    ["examples/trimmed/18_threedogs.mp4"],
    ["examples/trimmed/19_cyclist.mp4"],
    ["examples/trimmed/20_doughkneading.mp4"],
    ["examples/trimmed/21_biker.mp4"],
    ["examples/trimmed/22_dogskateboarder.mp4"],
    ["examples/trimmed/23_racecar.mp4"],
    ["examples/trimmed/24_clownfish.mp4"],
]

# Single-object demo: every prompt and mask uses this one object id.
OBJ_ID = 0
# EdgeTAM weights/config; the predictor is built once at import time, on CPU.
# (It is moved to "cuda" only inside the GPU-decorated propagate_to_all.)
sam2_checkpoint = "checkpoints/edgetam.pt"
model_cfg = "edgetam.yaml"
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
| |
|
| |
|
def get_video_fps(video_path):
    """Return the frames-per-second of the video at *video_path*.

    Args:
        video_path: Path of a video file readable by OpenCV.

    Returns:
        The FPS as reported by OpenCV, or ``None`` if the file cannot
        be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        return None
    try:
        return cap.get(cv2.CAP_PROP_FPS)
    finally:
        # Fix: the original leaked the capture handle — release it on
        # every path, including success.
        cap.release()
| |
|
| |
|
def reset_state(inference_state):
    """Wipe all prompts, cached outputs, and object registries from the
    predictor's inference state, returning the same dict ready for fresh
    prompts on the already-loaded video."""
    # Per-object prompt stores: clear each object's inner store first.
    for prompt_key in ("point_inputs_per_obj", "mask_inputs_per_obj"):
        for obj_store in inference_state[prompt_key].values():
            obj_store.clear()

    # Per-object output caches hold conditioned / non-conditioned frames.
    for output_key in ("output_dict_per_obj", "temp_output_dict_per_obj"):
        for obj_store in inference_state[output_key].values():
            obj_store["cond_frame_outputs"].clear()
            obj_store["non_cond_frame_outputs"].clear()

    # Global output caches share the same two-bucket layout.
    for global_key in ("output_dict", "consolidated_frame_inds"):
        bucket = inference_state[global_key]
        bucket["cond_frame_outputs"].clear()
        bucket["non_cond_frame_outputs"].clear()

    inference_state["tracking_has_started"] = False

    # Finally drop the object registries and the (now-empty) per-object maps.
    for container_key in (
        "frames_already_tracked",
        "obj_id_to_idx",
        "obj_idx_to_id",
        "obj_ids",
        "point_inputs_per_obj",
        "mask_inputs_per_obj",
        "output_dict_per_obj",
        "temp_output_dict_per_obj",
    ):
        inference_state[container_key].clear()
    return inference_state
| |
|
| |
|
def reset(
    first_frame,
    all_frames,
    input_points,
    input_labels,
    inference_state,
):
    """Return the demo to its initial state: drop the video, the cached
    frames, every point prompt, and the predictor state.

    Returns updates for (video_in, video_in_drawer, points_map,
    output_image, output_video) followed by the five cleared state values.
    """
    cleared_points = []
    cleared_labels = []
    return (
        None,                                  # video_in: remove the video
        gr.update(open=True),                  # reopen the input accordion
        None,                                  # points_map: blank
        None,                                  # output_image: blank
        gr.update(value=None, visible=False),  # hide the output video
        None,                                  # first_frame state
        None,                                  # all_frames state
        cleared_points,                        # input_points state
        cleared_labels,                        # input_labels state
        None,                                  # inference_state
    )
| |
|
| |
|
def clear_points(
    first_frame,
    all_frames,
    input_points,
    input_labels,
    inference_state,
):
    """Remove every point prompt while keeping the loaded video.

    Shows the clean first frame again in the prompt map, blanks the mask
    preview, hides any previously rendered output video, and — if tracking
    already ran — wipes the predictor's stale outputs via reset_state.
    """
    fresh_points = []
    fresh_labels = []
    if inference_state and inference_state["tracking_has_started"]:
        # Tracking results reference the old prompts; reset them too.
        inference_state = reset_state(inference_state)
    hidden_video = gr.update(value=None, visible=False)
    return (
        first_frame,   # points_map: back to the unannotated first frame
        None,          # output_image: clear the mask preview
        hidden_video,  # output_video: hide
        first_frame,
        all_frames,
        fresh_points,
        fresh_labels,
        inference_state,
    )
| |
|
| |
|
def preprocess_video_in(
    video_path,
    first_frame,
    all_frames,
    input_points,
    input_labels,
    inference_state,
):
    """Load a newly selected video: cache its RGB frames, reset the point
    prompts, and initialize the predictor's inference state.

    Args:
        video_path: Path of the uploaded/selected video, or None.
        first_frame, all_frames, input_points, input_labels, inference_state:
            Current per-session state values (passed back unchanged on
            failure).

    Returns:
        Updates for (video_in_drawer, points_map, output_image,
        output_video) followed by the five refreshed state values. On
        failure the accordion stays open and state is untouched.
    """
    def _unchanged():
        # Shared failure payload: keep the accordion open, blank the UI,
        # and leave all state exactly as it was.
        return (
            gr.update(open=True),
            None,
            None,
            gr.update(value=None, visible=False),
            first_frame,
            all_frames,
            input_points,
            input_labels,
            inference_state,
        )

    if video_path is None:
        return _unchanged()

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        return _unchanged()

    all_frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # cv2 decodes BGR; the UI and PIL expect RGB. cvtColor already
        # returns an ndarray, so the original's extra np.array copy is
        # unnecessary and has been removed.
        all_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    cap.release()

    # Independent copy of frame 0 for point-prompt overlays (None when the
    # video contained no decodable frames, matching the original behavior).
    first_frame = copy.deepcopy(all_frames[0]) if all_frames else None
    input_points = []
    input_labels = []

    # Prompting runs on CPU; offload both video and state to keep GPU free.
    predictor.to("cpu")
    inference_state = predictor.init_state(
        offload_video_to_cpu=True,
        offload_state_to_cpu=True,
        video_path=video_path,
    )

    return [
        gr.update(open=False),  # collapse the input accordion
        first_frame,            # points_map shows the first frame
        None,                   # clear any previous mask preview
        gr.update(value=None, visible=False),  # hide old output video
        first_frame,
        all_frames,
        input_points,
        input_labels,
        inference_state,
    ]
| |
|
| |
|
def segment_with_points(
    point_type,
    first_frame,
    all_frames,
    input_points,
    input_labels,
    inference_state,
    evt: gr.SelectData,
):
    """Add one click prompt on the first frame and preview the mask.

    The clicked pixel (evt.index) is appended to the prompt lists with
    label 1 ("include") or 0 ("exclude"), all prompts are drawn as colored
    dots, and the predictor produces a first-frame mask preview.
    """
    # Prompting stays on CPU; keep the state's device in sync.
    predictor.to("cpu")
    if inference_state:
        inference_state["device"] = predictor.device

    input_points.append(evt.index)
    print(f"TRACKING INPUT POINT: {input_points}")

    if point_type == "include":
        input_labels.append(1)
    elif point_type == "exclude":
        input_labels.append(0)
    print(f"TRACKING INPUT LABEL: {input_labels}")

    background = Image.fromarray(first_frame).convert("RGBA")
    w, h = background.size

    # Dot radius scales with the smaller image dimension (1%).
    radius = int(0.01 * min(w, h))

    # Draw every prompt so far onto a transparent RGBA layer:
    # green = include (label 1), red = exclude (label 0).
    dot_layer = np.zeros((h, w, 4), dtype=np.uint8)
    for idx, pt in enumerate(input_points):
        dot_color = (0, 255, 0, 255) if input_labels[idx] == 1 else (255, 0, 0, 255)
        cv2.circle(dot_layer, pt, radius, dot_color, -1)

    selected_point_map = Image.alpha_composite(
        background, Image.fromarray(dot_layer, "RGBA")
    )

    # Feed all accumulated prompts to the predictor for frame 0.
    prompt_xy = np.array(input_points, dtype=np.float32)
    prompt_lbl = np.array(input_labels, dtype=np.int32)
    _, _, out_mask_logits = predictor.add_new_points(
        inference_state=inference_state,
        frame_idx=0,
        obj_id=OBJ_ID,
        points=prompt_xy,
        labels=prompt_lbl,
    )

    # Threshold logits to a boolean mask and composite it over the frame.
    mask_image = show_mask((out_mask_logits[0] > 0.0).cpu().numpy())
    first_frame_output = Image.alpha_composite(background, mask_image)

    torch.cuda.empty_cache()
    return (
        selected_point_map,
        first_frame_output,
        first_frame,
        all_frames,
        input_points,
        input_labels,
        inference_state,
    )
| |
|
| |
|
def show_mask(mask, obj_id=None, random_color=False, convert_to_image=True):
    """Render a binary mask as a translucent RGBA overlay.

    Args:
        mask: Array whose last two dims are (H, W); nonzero = masked.
        obj_id: Index into the tab10 colormap (0 when None). Ignored if
            random_color is True.
        random_color: Pick a random RGB color instead of the colormap.
        convert_to_image: Return a PIL RGBA Image instead of a uint8 array.

    Returns:
        (H, W, 4) uint8 array, or a PIL Image when convert_to_image.
    """
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        palette = plt.get_cmap("tab10")
        idx = obj_id if obj_id is not None else 0
        rgba = np.array([*palette(idx)[:3], 0.6])

    h, w = mask.shape[-2:]
    # Broadcast the (r, g, b, 0.6) color over masked pixels, then quantize.
    colored = mask.reshape(h, w, 1) * rgba.reshape(1, 1, -1)
    colored = (colored * 255).astype(np.uint8)
    if convert_to_image:
        return Image.fromarray(colored, "RGBA")
    return colored
| |
|
| |
|
@spaces.GPU(duration=60)
def propagate_to_all(
    video_in,
    all_frames,
    input_points,
    inference_state,
):
    """Propagate the prompted mask through the whole video and write an mp4.

    Args:
        video_in: Path of the source video (used only to recover its FPS).
        all_frames: RGB frames cached by preprocess_video_in.
        input_points: Point prompts already added via segment_with_points.
        inference_state: Predictor state dict from predictor.init_state.

    Returns:
        A gr.update pointing at the rendered mp4, or None when there is
        nothing to track yet (no prompts, no video, or no state).
    """
    # Enable TF32 on Ampere and newer (compute capability >= 8).
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        predictor.to("cuda")
        if inference_state:
            # Keep the state's device in sync with the model after the move.
            inference_state["device"] = predictor.device

        if len(input_points) == 0 or video_in is None or inference_state is None:
            return None
        # Maps frame index -> {obj_id: boolean mask} for every tracked frame.
        video_segments = (
            {}
        )
        print("starting propagate_in_video")
        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
            inference_state
        ):
            video_segments[out_frame_idx] = {
                out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                for i, out_obj_id in enumerate(out_obj_ids)
            }

        # Render every frame (stride 1 keeps the output at full frame rate).
        vis_frame_stride = 1

        output_frames = []
        for out_frame_idx in range(0, len(video_segments), vis_frame_stride):
            transparent_background = Image.fromarray(all_frames[out_frame_idx]).convert(
                "RGBA"
            )
            out_mask = video_segments[out_frame_idx][OBJ_ID]
            mask_image = show_mask(out_mask)
            output_frame = Image.alpha_composite(transparent_background, mask_image)
            output_frame = np.array(output_frame)
            output_frames.append(output_frame)

        torch.cuda.empty_cache()

        # Re-encode at the source video's original frame rate.
        original_fps = get_video_fps(video_in)
        fps = original_fps
        clip = ImageSequenceClip(output_frames, fps=fps)
        # Timestamped name in the system temp dir to avoid collisions
        # between concurrent sessions.
        unique_id = datetime.now().strftime("%Y%m%d%H%M%S")
        final_vid_output_path = f"output_video_{unique_id}.mp4"
        final_vid_output_path = os.path.join(
            tempfile.gettempdir(), final_vid_output_path
        )

        clip.write_videofile(final_vid_output_path, codec="libx264")

        return gr.update(value=final_vid_output_path)
|
| |
|
def update_ui():
    """Reveal the output-video component before tracking starts."""
    return gr.update(visible=True)
| |
|
| |
|
with gr.Blocks() as demo:
    # Per-session gr.State holders threaded through every callback.
    first_frame = gr.State()
    all_frames = gr.State()
    input_points = gr.State([])
    input_labels = gr.State([])
    inference_state = gr.State()

    with gr.Column():
        # Page banner.
        gr.Markdown(title)
        with gr.Row():
            with gr.Column():
                # Usage instructions.
                gr.Markdown(description_p)

                # Collapsible video input; auto-collapses once a video loads.
                with gr.Accordion("Input Video", open=True) as video_in_drawer:
                    video_in = gr.Video(label="Input Video", format="mp4")

                with gr.Row():
                    point_type = gr.Radio(
                        label="point type",
                        choices=["include", "exclude"],
                        value="include",
                        scale=2,
                    )
                    propagate_btn = gr.Button("Track", scale=1, variant="primary")
                    clear_points_btn = gr.Button("Clear Points", scale=1)
                    reset_btn = gr.Button("Reset", scale=1)

                # First frame shown here; clicking it adds point prompts.
                points_map = gr.Image(
                    label="Frame with Point Prompt", type="numpy", interactive=False
                )

            with gr.Column():
                gr.Markdown("# Try some of the examples below ⬇️")
                gr.Examples(
                    examples=examples,
                    inputs=[
                        video_in,
                    ],
                    examples_per_page=8,
                )
                # Spacer blocks to push the mask preview down the column.
                gr.Markdown("\n\n\n\n\n\n\n\n\n\n\n")
                gr.Markdown("\n\n\n\n\n\n\n\n\n\n\n")
                gr.Markdown("\n\n\n\n\n\n\n\n\n\n\n")
                output_image = gr.Image(label="Reference Mask")

    # Hidden until the user presses Track (revealed by update_ui).
    output_video = gr.Video(visible=False)

    # Loading a video (upload or example click) re-initializes everything.
    video_in.upload(
        fn=preprocess_video_in,
        inputs=[
            video_in,
            first_frame,
            all_frames,
            input_points,
            input_labels,
            inference_state,
        ],
        outputs=[
            video_in_drawer,
            points_map,
            output_image,
            output_video,
            first_frame,
            all_frames,
            input_points,
            input_labels,
            inference_state,
        ],
        queue=False,
    )

    # Same handler on .change so example selections also trigger a reload.
    video_in.change(
        fn=preprocess_video_in,
        inputs=[
            video_in,
            first_frame,
            all_frames,
            input_points,
            input_labels,
            inference_state,
        ],
        outputs=[
            video_in_drawer,
            points_map,
            output_image,
            output_video,
            first_frame,
            all_frames,
            input_points,
            input_labels,
            inference_state,
        ],
        queue=False,
    )

    # Clicking the frame adds an include/exclude point and previews the mask.
    points_map.select(
        fn=segment_with_points,
        inputs=[
            point_type,
            first_frame,
            all_frames,
            input_points,
            input_labels,
            inference_state,
        ],
        outputs=[
            points_map,
            output_image,
            first_frame,
            all_frames,
            input_points,
            input_labels,
            inference_state,
        ],
        queue=False,
    )

    # Remove all prompts but keep the loaded video.
    clear_points_btn.click(
        fn=clear_points,
        inputs=[
            first_frame,
            all_frames,
            input_points,
            input_labels,
            inference_state,
        ],
        outputs=[
            points_map,
            output_image,
            output_video,
            first_frame,
            all_frames,
            input_points,
            input_labels,
            inference_state,
        ],
        queue=False,
    )

    # Full reset back to the initial page state.
    reset_btn.click(
        fn=reset,
        inputs=[
            first_frame,
            all_frames,
            input_points,
            input_labels,
            inference_state,
        ],
        outputs=[
            video_in,
            video_in_drawer,
            points_map,
            output_image,
            output_video,
            first_frame,
            all_frames,
            input_points,
            input_labels,
            inference_state,
        ],
        queue=False,
    )

    # Track: first reveal the output-video slot, then run propagation.
    propagate_btn.click(
        fn=update_ui,
        inputs=[],
        outputs=output_video,
        queue=False,
    ).then(
        fn=propagate_to_all,
        inputs=[
            video_in,
            all_frames,
            input_points,
            inference_state,
        ],
        outputs=[
            output_video,
        ],
        concurrency_limit=10,
        queue=False,
    )


demo.launch()
| |
|